Huang Chao's Blog

Hand-Written K-means

Posted on 2019-09-22 in Python
Word count: 2.3k, reading time ≈ 2 min

kmeans.py

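"""
Hand-written k-means
"""
import typing

import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs


class KMeans:
    def __init__(self, k: int):
        self.k = k
        self._centers = None

    def fit(self, nda: np.ndarray, n_iters=10,
            callback: typing.Optional[typing.Callable] = None):
        n_features = nda.shape[1]
        centers = self.random_centers(self.k, n_features)

        for i in range(n_iters):
            # Assignment step, then update step
            labels = self.assign(centers, nda)
            centers = self.update(nda, labels, self.k)
            if callback:
                callback(nda, labels, i)

        # Store the result so that predict() and the centers property work
        self._centers = centers

    def predict(self, nda: np.ndarray):
        return self.assign(self.centers, nda)

    @property
    def centers(self):
        # The truth value of an ndarray is ambiguous, so compare with None
        if self._centers is None:
            raise AttributeError("Call 'fit' before accessing centers.")
        return self._centers

    @staticmethod
    def random_centers(k, n_features):
        # Uniform random centers in [0, 1) along every feature
        return np.random.random((k, n_features))

    @staticmethod
    def assign(centers, nda):
        n = nda.shape[0]
        # Integer labels, so that "labels == i" and plotting behave as expected
        labels = np.empty(n, dtype=int)
        for i, arr in enumerate(nda):
            labels[i] = KMeans.nearest_center(centers, arr)
        return labels

    @staticmethod
    def update(nda, labels, k):
        # One row per center, one column per feature
        centers = np.empty((k, nda.shape[1]))
        for i in range(k):
            centers[i] = KMeans.cal_center(nda, labels, i)
        return centers

    @staticmethod
    def distance(arr1, arr2):
        # Squared Euclidean distance
        return np.sum((arr1 - arr2) ** 2)

    @staticmethod
    def cal_center(nda, labels, i):
        members = nda[labels == i]
        if len(members) == 0:
            # Empty cluster: reseed from a random sample instead of returning NaN
            return nda[np.random.randint(nda.shape[0])]
        # Mean per feature (axis=0), not a single mean over all values
        return np.mean(members, axis=0)

    @staticmethod
    def nearest_center(centers, nda):
        # Here nda is a single sample (one row of the data)
        j = -1
        min_dis = np.inf  # np.PINF was removed in NumPy 2.0
        for i, center in enumerate(centers):
            dis = KMeans.distance(center, nda)
            if dis < min_dis:
                min_dis = dis
                j = i
        return j


def my_plot(nda, labels, i):
    # Save a snapshot of the clustering every 10 iterations
    if i % 10 == 0:
        plt.clf()  # start from an empty figure so snapshots do not overlap
        plt.scatter(nda[:, 0], nda[:, 1], c=labels)
        plt.title("i = %s" % i)
        plt.savefig("%s.png" % i)


def main():
    # The true labels are not needed for clustering, so they are discarded
    X, _ = make_blobs(n_samples=1000, n_features=2,
                      centers=[[-1, -1], [0, 0], [1, 1], [2, 2]],
                      cluster_std=[0.4, 0.2, 0.2, 0.2],
                      random_state=9)

    kmeans = KMeans(4)
    kmeans.fit(X, n_iters=40, callback=my_plot)


if __name__ == '__main__':
    main()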
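
The assignment step above loops over every sample and every center in Python. As a rough sketch (not part of the original script), the same nearest-center lookup can be written with NumPy broadcasting. The function name `assign_vectorized` and the toy data below are made up for illustration, and `centers` is assumed to be a 2-D array of shape (k, n_features).

```python
import numpy as np


def assign_vectorized(centers: np.ndarray, nda: np.ndarray) -> np.ndarray:
    """Return the index of the nearest center for every row of nda."""
    # Broadcasting gives an array of shape (n_samples, k, n_features)
    diff = nda[:, np.newaxis, :] - centers[np.newaxis, :, :]
    # Squared Euclidean distances, shape (n_samples, k)
    sq_dist = np.sum(diff ** 2, axis=2)
    # argmin returns the first minimum, matching the strict "<" in the loop
    return np.argmin(sq_dist, axis=1)


# Hypothetical toy data: two centers and three samples
centers = np.array([[0.0, 0.0], [2.0, 2.0]])
points = np.array([[0.1, -0.2], [1.9, 2.1], [0.4, 0.3]])
labels = assign_vectorized(centers, points)
print(labels)  # [0 1 0]
```

The trade-off is memory: the intermediate array holds n_samples × k × n_features values, which is fine for the 1000 samples used here but may need chunking for much larger data.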

Results:




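  • Author: Huang Chao
  • Link: https://huangchaosp.github.io/2019/09/22/手写Kmeans/
  • Copyright: Unless otherwise stated, all posts on this blog are licensed under BY-NC-SA. Please credit the source when reposting!
# python # machine learning # kmeans # practice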