Plotting a mixture distribution in sympy.stats

Submitted by 时光怂恿深爱的人放手 on 2019-12-10 23:54:39

Question


(gist of this Q here)

I'd like to create a mixture of two Gamma distributions and plot the result, evaluated over a given range.

It would appear that sympy.stats is capable of this, because it can compute the expectation of the mixture and sample from it. I'm quite new to SymPy, so I'm not sure whether there is a better way to evaluate and plot in this situation than the one I've been using.

%matplotlib inline
from matplotlib import pyplot as plt
from sympy.stats import Gamma, E, density
import numpy as np

G1 = Gamma("G1", 5, 2.5)   # Gamma(name, k, theta): shape k, scale theta
G2 = Gamma("G2", 4, 1.5)
f1 = 0.7; f2 = 1-f1        # weights of the two components
G3 = f1*G1 + f2*G2

The expectation gives me a single sensible number for all three:

In [19]: E(G1)
Out[19]: 12.5000000000000

In [20]: E(G2)
Out[20]: 6.00000000000000

In [21]: E(G3)
Out[21]: 10.5500000000000

...but plotting the mixture fails:

u = np.linspace(0, 50)
D1 = density(G1); D2 = density(G2); D3 = density(G3)
v1 = [D1.args[1].subs(D1.args[0][0], i).evalf() for i in u]
v2 = [D2.args[1].subs(D2.args[0][0], i).evalf() for i in u]
v3 = [D3.args[1].subs(D3.args[0][0], i).evalf() for i in u]

plt.plot(u, v1)
plt.plot(u, v2)
plt.plot(u, v3)  # this one fails with error 'can't convert expression to float'

The problem appears to be that the values computed for the mixture still contain free symbols:

In [44]:  v1[0].free_symbols
Out[44]:  set()

In [45]:  v3[0].free_symbols
Out[45]:  {x}

...yet, as I said, sympy.stats presumably handles this somehow when computing the expectation. So I think I need to apply the same machinery here to evaluate and plot the mixture distribution(?)
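One thing worth knowing before borrowing the expectation machinery: G3 = f1*G1 + f2*G2 is a weighted sum of two random variables, and its density is a convolution of the component densities, not the mixture density f1*pdf(G1) + f2*pdf(G2). By linearity of expectation the two have the same mean, which is why E(G3) gives a sensible-looking 10.55 either way, but their spreads differ. The following is a minimal sketch, assuming a reasonably recent SymPy and swapping the float parameters for exact Rational ones (an assumption made so the integrals come out exact), that builds the mixture pdf from the callable component densities and compares it with the weighted sum.

```python
from sympy import Rational, Symbol, integrate, oo
from sympy.stats import Gamma, E, variance, density

# Same parameters as in the question, but as exact Rationals (an assumption
# made here so that SymPy returns exact numbers instead of floats).
G1 = Gamma("G1", 5, Rational(5, 2))   # shape k = 5, scale theta = 5/2
G2 = Gamma("G2", 4, Rational(3, 2))   # shape k = 4, scale theta = 3/2
f1 = Rational(7, 10)
f2 = 1 - f1

x = Symbol("x", positive=True)

# True mixture: weight the component *densities*, then integrate directly.
mix_pdf = f1*density(G1)(x) + f2*density(G2)(x)
mix_mass = integrate(mix_pdf, (x, 0, oo))
mix_mean = integrate(x*mix_pdf, (x, 0, oo))
mix_var = integrate(x**2*mix_pdf, (x, 0, oo)) - mix_mean**2

# Weighted sum f1*G1 + f2*G2 of independent variables: linearity gives the
# mean, and independence gives Var = f1**2*Var(G1) + f2**2*Var(G2).
sum_mean = f1*E(G1) + f2*E(G2)
sum_var = f1**2*variance(G1) + f2**2*variance(G2)

print(mix_mass)             # 1
print(mix_mean, sum_mean)   # 211/20 211/20  (both equal 10.55)
print(mix_var, sum_var)     # 13379/400 6449/400  (about 33.45 vs 16.12)
```

So the means agree while the variances do not. If a genuine mixture is what's wanted, mix_pdf is the expression to evaluate and plot; if the weighted sum is really intended, density(G3) is the right object.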


Answer 1:


It looks like this was fixed. I can reproduce your error in SymPy 0.7.3, but it works fine in 0.7.4.1, the latest version at the time of writing.

First off, you don't need the finagling with .args. The objects returned by density are callable: just call D1(i).evalf() to get the numerical value of D1 at i, like this:

D1 = density(G1); D2 = density(G2); D3 = density(G3)
v1 = [D1(i).evalf() for i in u]
v2 = [D2(i).evalf() for i in u]
v3 = [D3(i).evalf() for i in u]

I've uploaded a working version to http://nbviewer.ipython.org/gist/asmeurer/8486176.
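Calling the density and then .evalf() at every grid point works, but it can be slow for larger grids. A common alternative, sketched here under the assumption that a true mixture of the two Gamma densities is wanted (the mixture pdf is built by hand from the callable component densities, and the parameters are made exact Rationals), is to compile the expression once with sympy.lambdify and call the resulting plain Python function on each point.

```python
from sympy import Rational, Symbol, lambdify
from sympy.stats import Gamma, density

G1 = Gamma("G1", 5, Rational(5, 2))   # shape k = 5, scale theta = 5/2
G2 = Gamma("G2", 4, Rational(3, 2))   # shape k = 4, scale theta = 3/2
f1 = Rational(7, 10)
f2 = 1 - f1

x = Symbol("x", positive=True)

# density(...) returns a callable; calling it on a Symbol gives an expression.
# Weighting the two densities gives the mixture pdf (assumed to be the goal).
mix_expr = f1*density(G1)(x) + f2*density(G2)(x)

# Compile the expression once into an ordinary Python function that uses
# the math module, instead of running .subs()/.evalf() at each point.
mix_pdf = lambdify(x, mix_expr, "math")

u = [50*i/49 for i in range(50)]   # same 50 points as np.linspace(0, 50)
v = [mix_pdf(t) for t in u]        # plain floats, ready for plotting
```

With matplotlib available, plt.plot(u, v) then draws the curve. Passing "numpy" instead of "math" to lambdify gives a function that accepts a NumPy array such as np.linspace(0, 50) directly.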



Source: https://stackoverflow.com/questions/21187257/plotting-a-mixture-distribution-in-sympy-stats
