import os

import matplotlib.pyplot as plt
import numpy as np
# Parameters of a 2-D Gaussian mixture whose component means lie evenly
# spaced on a circle centred at the origin.
num_mixtures = 8  # number of mixture components
radius = 2.0  # distance of every component mean from the origin
# NOTE(review): despite its name, `std` is used as a per-axis *variance*: it
# fills the diagonal of each component's covariance matrix (np.diag in
# gmm_sample), so the effective per-axis standard deviation is sqrt(0.02).
std = 0.02

# num_mixtures angles evenly spaced over [0, 2*pi), endpoint excluded.
thetas = np.linspace(0, 2 * np.pi, num_mixtures, endpoint=False)
xs = radius * np.sin(thetas)  # x = r*sin(theta)
ys = radius * np.cos(thetas)  # y = r*cos(theta): first mean is at (0, radius)

mix_coeffs = (1 / num_mixtures,) * num_mixtures  # uniform mixture weights
mean = tuple(zip(xs, ys))  # one (x, y) mean per component
cov = ((std, std),) * num_mixtures  # per-component diagonal entries

# Module-level plotting state, initially empty.
ax = None
epoch = 0
fig = None
def gmm_sample(num_samples, mix_coeffs, mean, cov, rng=None):
    """Draw samples from a Gaussian mixture with diagonal covariances.

    A multinomial draw with probabilities ``mix_coeffs`` decides how many of
    the ``num_samples`` points each component contributes. Each component's
    points are then drawn from a multivariate normal. Points are written
    component by component, so the returned rows are grouped by component
    in ``mix_coeffs`` order rather than shuffled.

    Args:
        num_samples: Total number of samples to draw.
        mix_coeffs: Mixture weights, one per component; used as the
            multinomial probabilities.
        mean: Sequence of per-component mean vectors. All vectors must have
            the same length, which sets the sample dimension.
        cov: Sequence of per-component *variances* (one per dimension). Each
            row becomes the diagonal of that component's covariance matrix.
        rng: Optional ``numpy.random.Generator`` or ``RandomState`` to draw
            from. Defaults to the legacy global ``numpy.random`` state, so
            ``np.random.seed`` keeps controlling the output.

    Returns:
        ``ndarray`` of shape ``(num_samples, len(mean[0]))``.
    """
    if rng is None:
        rng = np.random
    # Convert the parameter sequences once, not on every loop iteration.
    means = np.array(mean)
    variances = np.array(cov)

    counts = rng.multinomial(num_samples, mix_coeffs)
    samples = np.zeros(shape=[num_samples, len(mean[0])])
    start = 0
    for i, count in enumerate(counts):
        stop = start + count
        samples[start:stop, :] = rng.multivariate_normal(
            mean=means[i, :],
            cov=np.diag(variances[i, :]),
            size=count)
        start = stop
    return samples
def disp_scatter(x, fig=None, ax=None, **scatter_kwargs):
    """Scatter-plot 2-D points, creating a figure and axes if none are given.

    Args:
        x: Array indexable as ``x[:, 0]`` / ``x[:, 1]`` (e.g. an ``(n, 2)``
            ndarray). Only the first two columns are plotted.
        fig: Figure to return alongside ``ax``. Ignored when ``ax`` is None,
            because a new figure is created. When ``ax`` is given without
            ``fig``, the axes' own figure is returned.
        ax: Axes to draw on. If None, a new ``plt.subplots()`` pair is
            created.
        **scatter_kwargs: Extra or overriding keyword arguments for
            ``Axes.scatter``. The defaults draw red '+' markers labelled
            'real data'.

    Returns:
        ``(fig, ax)`` tuple.
    """
    if ax is None:
        fig, ax = plt.subplots()
    elif fig is None:
        # Derive the figure from the axes so callers always get a usable
        # figure back (e.g. for fig.tight_layout()).
        fig = ax.figure
    style = dict(s=10, marker='+', color='r', alpha=0.8, label='real data')
    style.update(scatter_kwargs)
    ax.scatter(x[:, 0], x[:, 1], **style)
    ax.legend()
    return fig, ax
# Script entry: sample the mixture defined above and save a scatter plot.
num_samples = 1000
x = gmm_sample(num_samples, mix_coeffs, mean, cov)
fig, ax = disp_scatter(x, fig=None, ax=None)
fig.tight_layout()
# Join the path portably: a literal backslash is only a separator on Windows,
# and "\{" is an invalid string escape (SyntaxWarning on Python 3.12+).
# Create the directory first so savefig does not fail when it is missing.
os.makedirs("output", exist_ok=True)
fig.savefig(os.path.join("output", "{}.png".format(epoch)))
# NOTE(review): these reassignments run after the figure has been saved and
# do not affect it; they look like leftovers. They are kept only in case other
# code reads the final value of num_mixtures.
num_mixtures = 8

num_mixtures = 1

# Source (来源): https://www.cnblogs.com/gaona666/p/12446784.html