1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| import numpy as np import matplotlib.pyplot as plt
def _nearest_center(data, centers):
    """Return, for each row of ``data``, the index of its closest center."""
    # Broadcasting (n, 1, d) - (K, d) -> (n, K, d); taking the norm over the
    # last axis yields the (n, K) matrix of Euclidean distances.
    distances = np.linalg.norm(data[:, None] - centers, axis=2)
    return np.argmin(distances, axis=1)


def kmeans(data, K, max_iterations=100):
    """Cluster the rows of ``data`` into ``K`` groups with Lloyd's algorithm.

    Initial centers are ``K`` distinct rows of ``data`` drawn with
    ``np.random.choice`` (NumPy's global random state), so results depend on
    that state; seed it beforehand for reproducible output.

    Args:
        data: 2-D array-like of shape (n_points, n_features).
        K: Number of clusters; must satisfy ``1 <= K <= n_points``.
        max_iterations: Upper bound on the number of center updates. ``0``
            returns the initial centers and the labels induced by them.

    Returns:
        A ``(labels, centers)`` tuple. ``labels`` has shape (n_points,) and
        holds the index of the nearest returned center for every point;
        ``centers`` has shape (K, n_features).

    Raises:
        ValueError: If ``data`` is not 2-D or ``K`` is outside
            ``[1, n_points]``.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(
            f"data must be 2-D (n_points, n_features), got {data.ndim}-D"
        )
    n_points = data.shape[0]
    if not 1 <= K <= n_points:
        raise ValueError(
            f"K must be between 1 and the number of points ({n_points}), "
            f"got {K}"
        )

    centers = data[np.random.choice(n_points, K, replace=False)]
    # Assign once up front so labels exist even when max_iterations is 0.
    labels = _nearest_center(data, centers)

    for _ in range(max_iterations):
        updated = []
        for k in range(K):
            members = data[labels == k]
            # A cluster can lose all its points (e.g. duplicate rows chosen
            # as initial centers). The mean of an empty slice is NaN, which
            # would poison every later distance and prevent convergence, so
            # an empty cluster keeps its previous center instead.
            updated.append(members.mean(axis=0) if len(members) else centers[k])
        new_centers = np.array(updated)

        # Unchanged centers imply unchanged assignments: converged.
        if np.array_equal(centers, new_centers):
            break
        centers = new_centers
        # Re-assign so the returned labels always match the returned centers.
        labels = _nearest_center(data, centers)

    return labels, centers
# Demo: cluster 100 uniformly random 2-D points into K groups and plot them.
K = 3
data = np.random.rand(100, 2)
labels, centers = kmeans(data, K)

# Transposing gives one row per coordinate, so unpacking supplies x and y.
# Points are coloured by their assigned cluster.
plt.scatter(*data.T, c=labels, cmap='viridis')
# Overlay the cluster centers as red crosses.
plt.scatter(*centers.T, c='red', marker='x')
plt.show()
|