kmeans算法原理及代码实现
完整的实验代码在我的github上👉QYHcrossover/ML-numpy: 机器学习算法numpy实现 (github.com) 欢迎star⭐
kmeans算法原理
在上一篇文章中,我们介绍了Mean Shift聚类算法的原理和代码实现。不同于Mean Shift聚类的基于密度的方法,k均值聚类是一种基于距离的聚类算法。它将数据集划分为k个簇,每个簇包含最接近它们的数据点。与Mean Shift聚类不同,k均值聚类需要事先指定簇的数量k。
该算法的实现过程包括以下几个步骤:
- 随机选择k个点作为初始中心点。
- 对于每个数据点,计算它与k个中心点的距离,并将其归为距离最近的中心点所在的簇。
- 对于每个簇,重新计算它们的中心点。
- 重复步骤2和3,直到簇不再发生变化或达到预定的迭代次数。
公式
kmeans算法的代价函数为:
$J(c, \mu) = \sum_{i=1}{m}|x - \mu_{c{(i)}}|$
其中,$c{(i)}$表示第$i$个样本所属的簇,$\mu_{c{(i)}}$表示第$i$个样本所属簇的中心点。
代码实现
代码实现步骤为:
在数据集中随机选择k个点作为初始中心点。
ci = np.random.choice(len(dataset),k,replace=False)
centers = data[ci]
对于每个数据点,计算它与k个中心点的距离,并将其归为距离最近的中心点所在的簇。
distances = np.hstack([np.sum((data-center)**2,axis=1)[:,np.newaxis] for center in centers])
mink = np.argmin(distances,axis=1)
对于每个簇,重新计算它们的中心点。
newcenters = np.array([np.mean(data[mink==i],axis=0) for i in range(k)])
重复步骤2和3,直到簇不再发生变化或达到预定的迭代次数。
delta = np.sum(np.abs(newcenters - centers))
if delta < 1e-5 or it >10000:
return centers,mink,it
centers = newcenters
it += 1
具体实现代码如下:
import numpy as np
def kmeans(data,k):
ci = np.random.choice(len(dataset),k,replace=False)
centers = data[ci]
it = 0
while True:
#计算所有点到聚类中心的距离
distances = np.hstack([np.sum((data-center)**2,axis=1)[:,np.newaxis] for center in centers])
mink = np.argmin(distances,axis=1)
#确定下一轮聚类中心
newcenters = np.array([np.mean(data[mink==i],axis=0) for i in range(k)])
#判断是否需要下次迭代
delta = np.sum(np.abs(newcenters - centers))
if delta < 1e-5 or it >10000:
return centers,mink,it
#下次迭代
centers = newcenters
it += 1
if __name__ == "__main__":
#读取数据
data = []
with open("data.txt") as f:
for line in f:
x,y = line.strip().split()
data.append((float(x),float(y)))
data = np.array(dataset)
#kmeans聚类
centers,index,it = kmeans(data,2)
print(centers)
print(index)
print(it)
总结
本文介绍了kmeans聚类算法的原理和代码实现。kmeans是一种基于距离的聚类算法,它将数据集划分为k个簇,每个簇包含最接近它们的数据点。与Mean Shift聚类不同,k均值聚类需要事先指定簇的数量k。在实现过程中,首先随机选择k个点作为初始中心点,然后计算每个数据点与k个中心点的距离,并将其归为距离最近的中心点所在的簇。之后重新计算每个簇的中心点,并重复上述步骤,直到簇不再发生变化或达到预定的迭代次数。本文还给出了算法的代价函数和具体代码实现。Kmeans算法是一种简单易用的聚类算法,在实际应用中具有广泛的应用前景。
完整的实验代码在我的github上👉QYHcrossover/ML-numpy: 机器学习算法numpy实现 (github.com) 欢迎star⭐