Understanding Gaussian Mixture Models

后端 未结 1 1288
南方客
南方客 2020-12-13 02:48

I am trying to understand the results from the scikit-learn gaussian mixture model implementation. Take a look at the following example:

#!/opt/local/bin/python


        
相关标签:
1条回答
  • 2020-12-13 03:27

    Just in case anyone in the future is wondering about the same thing: One has to normalise the individual components, not the sum:

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.mixture import GaussianMixture
    
    # Un-normalised 1-D Gaussian: amp * exp(-(x - x0)^2 / (2 sigma^2))
    def gauss_function(x, amp, x0, sigma):
        """Evaluate a Gaussian of amplitude *amp*, centre *x0* and width *sigma* at *x*."""
        z = (x - x0) / sigma
        return amp * np.exp(-0.5 * z ** 2)
    
    # Draw a sample from a mixture of three Gaussians (means -0.5, -0.1, 0.2).
    samples = np.random.normal(-0.5, 0.2, 2000)
    samples = np.append(samples, np.random.normal(-0.1, 0.07, 5000))
    samples = np.append(samples, np.random.normal(0.2, 0.13, 10000))
    
    # Fit a 3-component GMM; GaussianMixture expects 2-D input (n_samples, n_features),
    # hence the expand_dims to shape (n, 1).
    gmm = GaussianMixture(n_components=3, covariance_type="full", tol=0.001)
    gmm = gmm.fit(X=np.expand_dims(samples, 1))
    
    # Evaluate the fitted density on a grid; score_samples returns the LOG density,
    # so exponentiate to get the probability density itself.
    gmm_x = np.linspace(-2, 1.5, 5000)
    gmm_y = np.exp(gmm.score_samples(gmm_x.reshape(-1, 1)))
    
    # Reconstruct the mixture density by hand: normalise EACH component to unit
    # area, then weight and sum. np.trapz was renamed np.trapezoid in NumPy 2.0;
    # fall back for older NumPy versions.
    _trapezoid = getattr(np, "trapezoid", None) or np.trapz
    gmm_y_sum = np.zeros_like(gmm_x)
    for m, c, w in zip(gmm.means_.ravel(), gmm.covariances_.ravel(), gmm.weights_.ravel()):
        gauss = gauss_function(x=gmm_x, amp=1, x0=m, sigma=np.sqrt(c))
        gmm_y_sum += gauss / _trapezoid(gauss, gmm_x) * w
    
    # Histogram of the raw samples, normalised to a probability density.
    # NOTE: the `normed` kwarg was removed in matplotlib 3.1 -- `density=True`
    # is the supported replacement.
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=[8, 5])
    ax.hist(samples, bins=50, density=True, alpha=0.5, color="#0070FF")
    ax.plot(gmm_x, gmm_y, color="crimson", lw=4, label="GMM")
    ax.plot(gmm_x, gmm_y_sum, color="black", lw=4, label="Gauss_sum", linestyle="dashed")
    
    # Annotate diagram
    ax.set_ylabel("Probability density")
    ax.set_xlabel("Arbitrary units")
    
    # Make legend
    plt.legend()
    
    plt.show()
    

    0 讨论(0)
提交回复
热议问题