Beijing Institute of Technology | Ming-Jian Li
The following Python code accompanies the course Artificial Intelligence and Simulation Science. It clusters the data in the dataset cluster.csv (click to download) using the k-means clustering method.
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load the samples to cluster; the CSV must provide 'radius' and 'weight' columns.
data = pd.read_csv(r'cluster.csv', sep=',')

# Each sample is a 2-D point: radius on the x axis, weight on the y axis.
x_value = data['radius'].values
y_value = data['weight'].values

# Number of samples in the dataset.
M = len(data)
print(M)

# Show the raw, unclustered data.
plt.scatter(x_value, y_value, c='red')
plt.show()

# randomly generate 2 centroids
def createpoint2(count=2, scale=5):
    """Draw random starting centroids.

    Args:
        count: Number of centroids to draw (default 2).
        scale: Each coordinate is sampled uniformly from [0, scale)
            (default 5).

    Returns:
        A tuple ``(x, y)`` of two NumPy arrays of length ``count``;
        centroid ``k`` is the point ``(x[k], y[k])``.
    """
    # All x coordinates are drawn before the y coordinates, so the random
    # stream is consumed in the same order regardless of ``count``.
    x = scale * np.random.random(count)
    y = scale * np.random.random(count)
    return x, y

# Draw the two starting centroids. createpoint2() returns (xs, ys), so
# centroid k is (point2[0][k], point2[1][k]).
point2 = createpoint2()
p0 = (point2[0][0], point2[1][0])
p1 = (point2[0][1], point2[1][1])

def distance(p1, p2):
    """Return the Euclidean distance between two 2-D points.

    Args:
        p1: First point, indexable as ``(x, y)``.
        p2: Second point, indexable as ``(x, y)``.

    Returns:
        ``sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)``.
    """
    return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)

# plot centroids
plt.scatter(x_value, y_value, c='red')
plt.scatter(point2[0], point2[1], c='black', marker='x')
plt.show()

# calculate distance
# Label every sample with its nearest centroid in one vectorized comparison
# (distance() broadcasts over the coordinate arrays). The strict '<' means a
# sample equidistant from both centroids is assigned to p1.
samples = (x_value, y_value)
nearer_p0 = distance(p0, samples) < distance(p1, samples)
# One scatter call per cluster instead of one per sample.
plt.scatter(x_value[nearer_p0], y_value[nearer_p0], c='blue', marker='*')
plt.scatter(x_value[~nearer_p0], y_value[~nearer_p0], c='red')
plt.scatter(point2[0], point2[1], c='black', marker='x')
plt.show()

# Number of refinement passes (named so the builtin iter() is not shadowed).
n_iterations = 5

samples = (x_value, y_value)

for _ in range(n_iterations):
    # Assignment step: label each sample with its nearest current centroid.
    # The strict '<' sends ties to cluster 1.
    in_cluster0 = distance(p0, samples) < distance(p1, samples)
    in_cluster1 = ~in_cluster0
    plt.scatter(x_value[in_cluster0], y_value[in_cluster0], c='blue', marker='*')
    plt.scatter(x_value[in_cluster1], y_value[in_cluster1], c='red')

    # re-calculate centroids
    # Update step: move each centroid to the mean of its members. A cluster
    # with no members keeps its previous centroid; averaging an empty set
    # would yield NaN and the centroid could never attract a sample again.
    if in_cluster0.any():
        p0 = (np.average(x_value[in_cluster0]), np.average(y_value[in_cluster0]))
    if in_cluster1.any():
        p1 = (np.average(x_value[in_cluster1]), np.average(y_value[in_cluster1]))

    # First figure of the pass: old labels together with the new centroids.
    plt.scatter(p0[0], p0[1], c='black', marker='x')
    plt.scatter(p1[0], p1[1], c='black', marker='x')
    plt.show()

    # update distance
    # Second figure of the pass: samples relabelled against the new centroids.
    in_cluster0 = distance(p0, samples) < distance(p1, samples)
    plt.scatter(x_value[in_cluster0], y_value[in_cluster0], c='blue', marker='*')
    plt.scatter(x_value[~in_cluster0], y_value[~in_cluster0], c='red')
    plt.scatter(p0[0], p0[1], c='black', marker='x')
    plt.scatter(p1[0], p1[1], c='black', marker='x')
    plt.show()
The result is as follows.