文章目录

  • 一、算法流程
  • 二、推导
  • 三、python代码实现

一、算法流程

首先有2点需要注意:

  1. k-meansk-nn(k邻近法) 是不一样的,不要混淆。
  2. k-means本身理解起来不难,讲这个是为了引出后面的EM算法,两者有共通之处。

算法流程如下所示,这里以 k = 3 k=3 k=3, 数据维数 D = 2 D=2 D=2, 数据样本个数 N = 500 N=500 N=500为例:

  1. 给定K个聚类中心,用 μ \mu μ 来表示, 并且适当进行初始化

  2. 现在对于给定的 μ = ( μ 1 , μ 2 , μ 3 ) \mu = (\mu_1,\mu_2,\mu_3) μ=(μ1,μ2,μ3), 找出500个样本数据中距离 μ k \mu_k μk最近的一批数据标注为类别 k k k

  3. 对于属于k类的所有数据进行求平均,将这个平均值作为新的 μ k \mu_k μk使用,同时也得到新的 μ = ( μ 1 , μ 2 , μ 3 ) \mu = (\mu_1, \mu_2, \mu_3) μ=(μ1,μ2,μ3)

  4. 对比更新前后的 μ \mu μ值,如果其差量收束的话停止更新,否则重复第2步

二、推导

  • x x x: D D D维的数据
  • X = { x 1 , x 2 , . . , x N } X = \{x_1, x_2, .., x_N\} X={ x1,x2,..,xN}: N N N个数据样本
  • K K K: 聚类数,已知
  • μ k ( k = 1 , 2 , . . , K ) \mu_k(k=1,2,..,K) μk(k=1,2,..,K): D D D 维的聚类中心(centriod)
  • r n k r_{nk} rnk: 第 n n n个样本属于 k k k类的话该值为1,否则为0

损失函数定义如下:
J = ∑ n = 1 N ∑ k = 1 K r n k ∣ ∣ x n − μ k ∣ ∣ 2 J = \sum_{n=1}^{N}\sum_{k=1}^{K}r_{nk}||x_n-\mu_k||^2 J=n=1Nk=1Krnkxnμk2

该损失函数的最优化流程如下:

  1. 固定 μ k \mu_k μk关于 r n k r_{nk} rnk求偏微分,将其最小化
    d n k = ∣ ∣ x n − μ k ∣ ∣ 2 d_{nk}=||x_n-\mu_k||^2 dnk=xnμk2
    J = ∑ n = 1 N ( r n 1 d n 1 + r n 2 d n 2 + . . . + r n k d n k ) J = \sum_{n=1}^{N}(r_{n1}d_{n1}+r_{n2}d_{n2}+…+r_{nk}d_{nk}) J=n=1N(rn1dn1+rn2dn2+...+rnkdnk)
    为了实现 J J J的最小化,等同于将 ( r n 1 d n 1 + r n 2 d n 2 + . . . + r n k d n k ) (r_{n1}d_{n1}+r_{n2}d_{n2}+…+r_{nk}d_{nk}) (rn1dn1+rn2dn2+...+rnkdnk)最小化,因为 r n k r_{nk} rnk取值为 { 0 , 1 } {\{0,1\}} { 0,1},所以该最小值等同于最小的 d n k d_{nk} dnk,因此除了最小的 d n k d_{nk} dnk r n k r_{nk} rnk为1,其他 r n k r_{nk} rnk均为0
    r n k = { 1    ( k = a r g m i n ∣ ∣ x n − μ k ∣ ∣ 2 ) 0    ( o t h e r w i s e ) r_{nk} = \left\{\begin{array}{l}1\space\space(k=argmin||x_n-\mu_k||^2)\\0\space\space(otherwise)\end{array}\right. rnk={ 1  (k=argminxnμk2)0  (otherwise)

  2. 固定 r n k r_{nk} rnk关于 μ k \mu_k μk求变微分,将其最小化
    ∂ J ∂ μ k = ∂ ∂ μ k ∑ n = 1 N ∑ k = 1 K r n k ∣ ∣ x n − μ k ∣ ∣ 2                                      = ∂ ∂ μ k ∑ n = 1 N ∑ k = 1 K r n k ( − 2 x n T μ k + μ k T μ k )                                                  = ∑ n = 1 N ∑ k = 1 K r n k ( − 2 ∂ ∂ μ k x n T μ k + ∂ ∂ μ k μ k T μ k )      = − 2 ∑ n = 1 N r n k ( x n − μ k ) = 0 \frac{\partial J}{\partial \mu_k}=\frac{\partial}{\partial \mu_k}\sum_{n=1}^{N}\sum_{k=1}^{K}r_{nk}||x_n-\mu_k||^2 \\ \;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\ =\frac{\partial}{\partial \mu_k}\sum_{n=1}^{N}\sum_{k=1}^{K}r_{nk}(-2x_n^T\mu_k+\mu_k^T\mu_k) \\ \;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;= \sum_{n=1}^{N}\sum_{k=1}^{K}r_{nk}(-2\frac{\partial}{\partial \mu_k}x_n^T\mu_k+\frac{\partial}{\partial \mu_k}\mu_k^T\mu_k) \\ \;\;=-2\sum_{n=1}^{N}r_{nk}(x_n-\mu_k)=0 μkJ=μkn=1Nk=1Krnkxnμk2 =μkn=1Nk=1Krnk(2xnTμk+μkTμk)=n=1Nk=1Krnk(2μkxnTμk+μkμkTμk)=2n=1Nrnk(xnμk)=0
    最后再展开一下
    ∑ n = 1 N r n k x n = ∑ n = 1 N r n k μ k \sum_{n=1}^Nr_{nk}x_n = \sum_{n=1}^Nr_{nk}\mu_k n=1Nrnkxn=n=1Nrnkμk
    μ k ∗ = ∑ n = 1 N r n k x n ∑ n = 1 N r n k \mu_k^*=\frac{\sum_{n=1}^{N}r_{nk}x_n}{\sum_{n=1}^{N}r_{nk}} μk=n=1Nrnkn=1Nrnkxn

最后这个最优的 μ k ∗ \mu_k^* μk也很好理解,就是 k k k类所有数据样本的平均值。

三、python代码实现

具体的代码实现可以参考我的github,这里只贴出一部分核心代码

# k-means算法实现
from collections import Counter
iterations = 100
for iter in range(iterations):
  r = np.zeros(N)
  for i in range(N):
    r[i] = np.argmin([np.linalg.norm(data[i] - mu[k]) for k in range(K)])
  
  if iter % 10 == 0:
    print(mu)
    plt.figure()
    for i in range(N):
      plt.scatter(data[i, 0], data[i, 1], s = 30, c = color_dict[r[i]], alpha = 0.5, marker = "+")
    
    for i in range(K):
      ax = plt.axes()
      ax.arrow(mu[i, 0], mu[i, 1], mu_true[i, 0] - mu[i, 0], mu_true[i, 1] - mu[i, 1], lw = 0.8, head_width = 0.02)
      plt.scatter(mu[i,0], mu[i,1], c = color_dict[i], marker = "o", edgecolors = "k", linewidths=1)
      plt.scatter(mu_true[i, 0], mu_true[i, 1], c = color_dict[i], marker = "o", edgecolors = "k", linewidths=2)
  
    plt.title("iter:{}".format(iter))
    plt.show()
  
  mu_prev = mu.copy()
  # cnt = dict(Counter(r))
  # N_k = [cnt[k] for k in range(K)]
  mu = np.array([np.sum(data[r == k], axis = 0)/len(data[r == k]) for k in range(K)])
  diff = mu - mu_prev

本文参考了这篇日文博客

本文地址:https://blog.csdn.net/dhj_tsukuba/article/details/110392431