How can I implement EM-GMM in Python?

面向向阳花 2020-12-06 23:06

I tried to implement the EM algorithm for a GMM following the post "GMMs and Maximum Likelihood Optimization Using NumPy", but without success. Here is my code:

import numpy as np

de         
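
For reference, the standard EM updates for a one-dimensional GMM with k components are summarized below. The notation is introduced here and is not taken from the post: points $x_1, \dots, x_n$, mixing weights $\pi_j$, means $\mu_j$, variances $\sigma_j^2$, and responsibilities $\gamma_{ij}$. Note that each point's responsibilities are normalized across the k components, and the Gaussian density is parameterized by the variance $\sigma_j^2$.

$$
\mathcal{N}(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\!\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)
$$

$$
\gamma_{ij} = \frac{\pi_j \, \mathcal{N}(x_i \mid \mu_j, \sigma_j^2)}{\sum_{l=1}^{k} \pi_l \, \mathcal{N}(x_i \mid \mu_l, \sigma_l^2)}
\qquad \text{(E-step)}
$$

$$
N_j = \sum_{i=1}^{n} \gamma_{ij}, \quad
\pi_j = \frac{N_j}{n}, \quad
\mu_j = \frac{1}{N_j} \sum_{i=1}^{n} \gamma_{ij} \, x_i, \quad
\sigma_j^2 = \frac{1}{N_j} \sum_{i=1}^{n} \gamma_{ij} \, (x_i - \mu_j)^2
\qquad \text{(M-step)}
$$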


        
2 Answers
  •  眼角桃花
    2020-12-06 23:09

    As I mentioned in the comment, the critical issue I see is the initialization of the means. Following the default behavior of sklearn's GaussianMixture (init_params='kmeans'), I replaced the random initialization with KMeans.

    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    sns.set_theme()  # apply seaborn's default plot style
    
    eps=1e-8 
    
    def PDF(data, means, variances):
        return 1/(np.sqrt(2 * np.pi * variances) + eps) * np.exp(-1/2 * (np.square(data - means) / (variances + eps)))
    
    def EM_GMM(data, k=3, iterations=100, init_strategy='kmeans'):
        weights = np.ones((k, 1)) / k # shape=(k, 1)
        
        if init_strategy=='kmeans':
            from sklearn.cluster import KMeans
            
            km = KMeans(n_clusters=k, n_init=10).fit(data[:, None])  # KMeans expects a 2-D (n, 1) array
            means = km.cluster_centers_ # shape=(k, 1)
            
        else: # init_strategy=='random'
            means = np.random.choice(data, k)[:, np.newaxis] # shape=(k, 1)
        
        variances = np.random.random_sample(size=k)[:, np.newaxis] # shape=(k, 1)
    
        data = np.repeat(data[np.newaxis, :], k, 0) # shape=(k, n)
    
        for step in range(iterations):
            # E-step: likelihood of every point under every component
            # (PDF receives the standard deviation np.sqrt(variances) in its variance argument)
            likelihood = PDF(data, means, np.sqrt(variances)) # shape=(k, n)
    
            # Weight the likelihoods, then normalize each component's row over the n points (axis=1)
            b = likelihood * weights # shape=(k, n)
            b /= np.sum(b, axis=1)[:, np.newaxis] + eps
    
            # M-step: update means, variances, and weights
            means = np.sum(b * data, axis=1)[:, np.newaxis] / (np.sum(b, axis=1)[:, np.newaxis] + eps)
            variances = np.sum(b * np.square(data - means), axis=1)[:, np.newaxis] / (np.sum(b, axis=1)[:, np.newaxis] + eps)
            weights = np.mean(b, axis=1)[:, np.newaxis]
            
        return means, variances
    

    This seems to yield the desired output much more consistently:

    s = np.array([25.31      , 24.31      , 24.12      , 43.46      , 41.48666667,
                  41.48666667, 37.54      , 41.175     , 44.81      , 44.44571429,
                  44.44571429, 44.44571429, 44.44571429, 44.44571429, 44.44571429,
                  44.44571429, 44.44571429, 44.44571429, 44.44571429, 44.44571429,
                  44.44571429, 44.44571429, 39.71      , 26.69      , 34.15      ,
                  24.94      , 24.75      , 24.56      , 24.38      , 35.25      ,
                  44.62      , 44.94      , 44.815     , 44.69      , 42.31      ,
                  40.81      , 44.38      , 44.56      , 44.44      , 44.25      ,
                  43.66666667, 43.66666667, 43.66666667, 43.66666667, 43.66666667,
                  40.75      , 32.31      , 36.08      , 30.135     , 24.19      ])
    k=3
    n_iter=100
    
    means, variances = EM_GMM(s, k, n_iter)
    print(means, variances)
    # [[44.42596231]
    #  [24.509301  ]
    #  [35.4137508 ]]
    # [[0.07568723]
    #  [0.10583743]
    #  [0.52125856]]
    
    # Plotting the results
    colors = ['green', 'red', 'blue', 'yellow']
    bins = np.linspace(np.min(s)-2, np.max(s)+2, 100)
    
    plt.figure(figsize=(10,7))
    plt.xlabel('$x$')
    plt.ylabel('pdf')
    
    sns.scatterplot(x=s, y=[0.05] * len(s), color='navy', s=40, marker=2, label='Series data')
    
    for i, (m, v) in enumerate(zip(means, variances)):
        sns.lineplot(x=bins, y=PDF(bins, m, v), color=colors[i], label=f'Cluster {i+1}')
    
    plt.legend()
    plt.show()
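
    For comparison, here is a minimal sketch of the textbook EM iteration for a 1-D GMM (numpy only). In the standard formulation the E-step normalizes each point's responsibilities across the k components (axis 0 here), whereas EM_GMM above normalizes each component's row over the data points; the density is also evaluated with the variance rather than its square root. The sketch works in log space and records the log-likelihood at every iteration, which should never decrease and therefore makes a useful sanity check. The function name em_gmm_1d, the linspace initialization, and the min_var floor are assumptions made for this illustration.

```python
import numpy as np

def em_gmm_1d(x, init_means, n_iter=100, min_var=1e-6):
    """Textbook EM for a 1-D Gaussian mixture, started from init_means.

    Returns (means, variances, weights, log_likelihoods); the parameters have shape (k,).
    """
    x = np.asarray(x, dtype=float)
    means = np.asarray(init_means, dtype=float).copy()
    k, n = means.size, x.size
    variances = np.full(k, x.var())  # every component starts with the overall variance
    weights = np.full(k, 1.0 / k)    # uniform mixing weights
    log_liks = []
    for _ in range(n_iter):
        # E-step in log space: log(pi_j * N(x_i | mu_j, var_j)), shape (k, n)
        log_dens = (np.log(weights)[:, None]
                    - 0.5 * np.log(2 * np.pi * variances)[:, None]
                    - 0.5 * (x[None, :] - means[:, None]) ** 2 / variances[:, None])
        log_mix = np.logaddexp.reduce(log_dens, axis=0)  # log mixture density of each point, shape (n,)
        log_liks.append(log_mix.sum())                   # log-likelihood of the current parameters
        resp = np.exp(log_dens - log_mix)                # responsibilities: each column sums to 1 over the k components
        # M-step: closed-form updates weighted by the responsibilities
        nk = resp.sum(axis=1) + 1e-12                    # effective number of points per component
        weights = nk / n
        means = resp @ x / nk
        variances = np.maximum((resp * (x[None, :] - means[:, None]) ** 2).sum(axis=1) / nk, min_var)
    return means, variances, weights, np.array(log_liks)

# The series used in this answer
s = np.array([25.31, 24.31, 24.12, 43.46, 41.48666667,
              41.48666667, 37.54, 41.175, 44.81, 44.44571429,
              44.44571429, 44.44571429, 44.44571429, 44.44571429, 44.44571429,
              44.44571429, 44.44571429, 44.44571429, 44.44571429, 44.44571429,
              44.44571429, 44.44571429, 39.71, 26.69, 34.15,
              24.94, 24.75, 24.56, 24.38, 35.25,
              44.62, 44.94, 44.815, 44.69, 42.31,
              40.81, 44.38, 44.56, 44.44, 44.25,
              43.66666667, 43.66666667, 43.66666667, 43.66666667, 43.66666667,
              40.75, 32.31, 36.08, 30.135, 24.19])

# Deterministic start spread evenly over the data range (an assumption for this sketch)
means, variances, weights, ll = em_gmm_1d(s, init_means=np.linspace(s.min(), s.max(), 3))
print(np.round(means, 3), np.round(variances, 3), np.round(weights, 3))
print(ll[:5], ll[-1])
```

    Printing ll should show the log-likelihood rising and then flattening; a drop would point to an error in the updates. Because many points in s are exact duplicates, a component can shrink onto one of them, and the min_var floor is what keeps its variance from reaching zero.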
    

    Finally, we can see that purely random initialization produces different results from run to run. Here are the resulting means over five runs:

    for _ in range(5):
        print(EM_GMM(s, k, n_iter, init_strategy='random')[0], '\n')
    
    # [[44.42596231]
    #  [44.42596231]
    #  [44.42596231]]
    #
    # [[44.42596231]
    #  [24.509301  ]
    #  [30.1349997 ]]
    #
    # [[44.42596231]
    #  [35.4137508 ]
    #  [44.42596231]]
    #
    # [[44.42596231]
    #  [30.1349997 ]
    #  [44.42596231]]
    #
    # [[44.42596231]
    #  [44.42596231]
    #  [44.42596231]]
    

    These results differ considerably between runs. In some runs two or even all three means end up at the same value (44.42596231), meaning that the initialization picked similar starting values that barely moved during the iterations. Adding some print statements inside EM_GMM makes this visible.
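
    sklearn's GaussianMixture, mentioned at the start of this answer, guards against bad random starts with its n_init parameter: it runs EM from several initializations and keeps the fit with the best lower bound on the log-likelihood. The sketch below fits it to the same series with init_params='random'; n_init=10 and random_state=0 are arbitrary choices for illustration, not values from the original post.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# The series used in this answer
s = np.array([25.31, 24.31, 24.12, 43.46, 41.48666667,
              41.48666667, 37.54, 41.175, 44.81, 44.44571429,
              44.44571429, 44.44571429, 44.44571429, 44.44571429, 44.44571429,
              44.44571429, 44.44571429, 44.44571429, 44.44571429, 44.44571429,
              44.44571429, 44.44571429, 39.71, 26.69, 34.15,
              24.94, 24.75, 24.56, 24.38, 35.25,
              44.62, 44.94, 44.815, 44.69, 42.31,
              40.81, 44.38, 44.56, 44.44, 44.25,
              43.66666667, 43.66666667, 43.66666667, 43.66666667, 43.66666667,
              40.75, 32.31, 36.08, 30.135, 24.19])

# Random initialization, but with 10 restarts; the best-scoring fit is kept
gm = GaussianMixture(n_components=3, init_params='random', n_init=10, random_state=0)
gm.fit(s.reshape(-1, 1))  # sklearn expects a 2-D array of shape (n_samples, n_features)

order = np.argsort(gm.means_.ravel())
print(gm.means_.ravel()[order])        # component means, sorted
print(gm.covariances_.ravel()[order])  # with one feature, each 'full' covariance is just a variance
print(gm.weights_[order])              # mixing weights
```

    Because every restart is scored by its final lower bound, a run in which several components pile onto the same mode tends to lose to a run that separates them, which targets the failure mode seen in the random-initialization runs above.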
