聚类
聚类主要内容是将样本进行归类,同种类别的样本放到一起,所有样本最终会形成K个簇,它属于无监督学习。
核心思想
根据给定的K值和K个初始质心将样本中每个点都分到距离最近的类簇中,当所有点分配完后根据每个类簇的所有点重新计算质心,一般是通过平均值计算,然后再将每个点分到距离最近的新类簇中,不断循环此操作,直到质心不再变化或达到一定的迭代次数。数学上可以证明k-means是收敛的。
伪代码
随机选择k个质心,即为簇数while(true){ 计算每个点到最近距离的质心,归为该类。 重新计算每个类的质心。 if(质心与上一次质心一样or达到最大迭代次数) break;}复制代码
缺点
- 需要事先确定类簇的数量。
- 质心的选取会影响最终的聚类结果。
代码实现
from numpy import *import matplotlib.pyplot as pltfrom sklearn.cluster import KMeansdef kmeans(dataSet, k): sampleNum, col = dataSet.shape cluster = mat(zeros((sampleNum, 2))) centroids = zeros((k, col)) ##choose centroids for i in range(k): index = int(random.uniform(0, sampleNum)) centroids[i, :] = dataSet[index, :] clusterChanged = True while clusterChanged: clusterChanged = False for i in range(sampleNum): minDist = sqrt(sum(power(centroids[0, :] - dataSet[i, :], 2))) minIndex = 0 for j in range(1,k): distance = sqrt(sum(power(centroids[j, :] - dataSet[i, :], 2))) if distance < minDist: minDist = distance minIndex = j if cluster[i, 0] != minIndex: clusterChanged = True cluster[i, :] = minIndex, minDist**2 for j in range(k): pointsInCluster = dataSet[nonzero(cluster[:, 0].A == j)[0]] centroids[j, :] = mean(pointsInCluster, axis = 0) return centroids, clusterdataSet = [[1,1],[3,1],[1,4],[2,5],[11,12],[14,11],[13,12],[11,16],[17,12],[28,10],[26,15],[27,13],[28,11],[29,15]]dataSet = mat(dataSet)k = 3centroids, cluster = kmeans(dataSet, k)sampleNum, col = dataSet.shapemark = ['or', 'ob', 'og']for i in range(sampleNum): markIndex = int(cluster[i, 0]) plt.plot(dataSet[i, 0], dataSet[i, 1], mark[markIndex])mark = ['+r', '+b', '+g']for i in range(k): plt.plot(centroids[i, 0], centroids[i, 1], mark[i], markersize=12)plt.show()复制代码
结果:
直接用机器学习库更加方便
from numpy import *import matplotlib.pyplot as pltfrom sklearn.cluster import KMeansdataSet = [[1,1],[3,1],[1,4],[2,5],[11,12],[14,11],[13,12],[11,16],[17,12],[28,10],[26,15],[27,13],[28,11],[29,15]]dataSet=mat(dataSet)k = 3markers = ['^', 'o', 'x']cls =KMeans(k).fit(dataSet)for i in range(k): members=cls.labels_==i plt.scatter(dataSet[members,0],dataSet[members,1],marker=markers[i])plt.show()复制代码
========广告时间========
鄙人的新书《Tomcat内核设计剖析》已经在京东销售了,有需要的朋友可以到 进行预定。感谢各位朋友。
=========================
欢迎关注: