1. 手撕 k-means 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import matplotlib.pyplot as plt

def kmeans(data, K, max_iterations=100):
# 随机初始化K个簇中心点
centers = data[np.random.choice(data.shape[0], K, replace=False)]

# 迭代更新中心点
for _ in range(max_iterations):
# 计算每个样本到各个中心点的距离
distances = np.linalg.norm(data[:, None] - centers, axis=2)
# 等价于distances = np.linalg.norm(X[:, np.newaxis, :] - centers, axis=2)

# 分配每个数据点到最近的簇
labels = np.argmin(distances, axis=1)

# 更新中心点为每个簇的平均值
new_centers = np.array([data[labels == k].mean(axis=0) for k in range(K)])

if np.all(centers == new_centers):
break

centers = new_centers

return labels, centers

# 示例输入
data = np.random.rand(100, 2) # 100个样本,每个样本有两个特征
K = 3 # 聚类数为3
labels, centers = kmeans(data, K)

# 可视化结果
plt.scatter(data[:, 0], data[:, 1], c=labels, cmap='viridis')
plt.scatter(centers[:, 0], centers[:, 1], c='red', marker='x')
plt.show()