Log-computations in Python

人盡茶涼 submitted on 2019-12-01 01:19:37

The sum

Σ(i=0...n) C(n, i) p^i (1-p)^(n-i) f(i)

is the expectation of f with respect to a binomial distribution with n = 5000 and p = 0.5.

You can compute this with scipy.stats.binom.expect:

import scipy.stats as stats

def f(i):
    return i
n, p = 5000, 0.5
print(stats.binom.expect(f, (n, p), lb=0, ub=n))
# 2499.99999997
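As a cross-check that needs only the standard library, the same expectation can be sketched by building each binomial probability in log space and exponentiating at the end. The helper names log_comb and binom_expect are made up for this illustration, and math.lgamma is used here as a stand-in for SciPy's machinery; the sketch assumes 0 < p < 1.

```python
import math

def log_comb(n, k):
    """Natural log of C(n, k) via log-gamma, so the huge value is never formed."""
    return math.lgamma(n + 1) - math.lgamma(n - k + 1) - math.lgamma(k + 1)

def binom_expect(f, n, p):
    """Sum f(i) * P(X = i) for X ~ Binomial(n, p), with P(X = i) built from logs."""
    log_p, log_q = math.log(p), math.log1p(-p)
    total = 0.0
    for i in range(n + 1):
        log_pmf = log_comb(n, i) + i * log_p + (n - i) * log_q
        # exp of a very negative number quietly underflows to 0.0
        total += f(i) * math.exp(log_pmf)
    return total

result = binom_expect(lambda i: i, 5000, 0.5)
print(result)  # should land very close to n * p = 2500
```

The point of the design is that only logs of the binomial coefficients are ever stored, so nothing overflows even though C(5000, 2500) itself is far outside float64 range.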

Also note that as n goes to infinity, with p fixed, the binomial distribution approaches the normal distribution with mean np and variance np*(1-p). Therefore, for large n you can instead compute:

import math
print(stats.norm.expect(f, loc=n*p, scale=math.sqrt(n*p*(1-p)), lb=0, ub=n))
# 2500.0
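To get a feel for how close the normal approximation is at n = 5000, here is a small standard-library sketch that compares the binomial probability mass with the normal density at a few points. It uses statistics.NormalDist (Python 3.8+) rather than scipy.stats, and binom_pmf is a hypothetical helper written for this comparison only.

```python
import math
from statistics import NormalDist

n, p = 5000, 0.5
approx = NormalDist(mu=n * p, sigma=math.sqrt(n * p * (1 - p)))

def binom_pmf(i):
    """P(X = i) for X ~ Binomial(n, p), evaluated through logs to avoid overflow."""
    log_pmf = (math.lgamma(n + 1) - math.lgamma(i + 1) - math.lgamma(n - i + 1)
               + i * math.log(p) + (n - i) * math.log(1 - p))
    return math.exp(log_pmf)

# Print the exact mass next to the normal density at and around the mean.
for i in (2400, 2500, 2600):
    print(i, binom_pmf(i), approx.pdf(i))
```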

By default, comb gives you a float64, which overflows and gives you inf.

But if you pass exact=True, it gives you a Python variable-sized int instead, which can't overflow (unless you get so ridiculously huge you run out of memory).

And, while np.log2 can't handle a Python int that large, Python's math.log2 can.

So:

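import math
import scipy.special

# comb lives in scipy.special in current SciPy; older releases also exposed it as scipy.misc.comb
math.log2(scipy.special.comb(5000, 2000, exact=True))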

As an alternative, you recall that n choose k is defined as n! / (k! (n-k)!), right? You can reduce that to ∏(i=1...k)((n+1-i)/i), which is simple to compute.
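A minimal sketch of that product form, kept in exact integer arithmetic so nothing overflows. The function name comb_product is made up for this example, and math.comb (Python 3.8+) appears only as an independent check.

```python
import math

def comb_product(n, k):
    """C(n, k) as the running product of (n + 1 - i) / i for i = 1..k."""
    result = 1
    for i in range(1, k + 1):
        # After this step result equals C(n, i), so the floor division is exact.
        result = result * (n + 1 - i) // i
    return result

value = comb_product(5000, 2000)
print(value == math.comb(5000, 2000))
print(math.log2(value))  # math.log2 accepts arbitrarily large Python ints
```

Multiplying before dividing keeps every intermediate value an integer, which is why plain floor division is safe here.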

Or, if you want to avoid overflow, you can do it by alternating * (n-i) and / (k-i).

Which, of course, you can also reduce to adding and subtracting logs. I think looping in Python and computing 4000 logarithms is going to be slower than looping in C and computing 4000 multiplications, but we can always vectorize it, and then, it might be faster. Let's write it and test:

In [1327]: n, k = 5000, 2000
In [1328]: %timeit math.log2(scipy.misc.comb(5000, 2000, exact=True))
100 loops, best of 3: 1.6 ms per loop
In [1329]: %timeit np.log2(np.arange(n-k+1, n+1)).sum() - np.log2(np.arange(1, k+1)).sum()
10000 loops, best of 3: 91.1 µs per loop

Of course, if you're more concerned with memory than with time… well, this obviously makes it worse. We've got 2000 8-byte floats instead of one 608-byte integer at a time. And if you go up to 100000, 20000, you get 20000 8-byte floats instead of one 9K integer. And at 1000000, 200000, it's 200000 8-byte floats vs. one 720K integer.

I'm not sure why either way is a problem for you. Especially given that you're using a listcomp instead of a genexpr, and therefore creating an unnecessary list of 5000, 100000, or 1000000 Python floats—24MB is not a problem, but 720K is? But if it is, we can obviously just do the same thing iteratively, at the cost of some speed:

r = sum(math.log2(n-i) - math.log2(k-i) for i in range(k))

This isn't too much slower than the scipy solution, and it never uses more than a small constant number of bytes (a handful of Python floats). (Unless you're on Python 2, in which case… just use xrange instead of range and it's back to constant.)
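For anyone who wants to sanity-check the iterative sum, here is a small sketch that wraps it in a function and compares it with the log of the exact integer. The name log2comb_iter is invented for this example, and math.comb (Python 3.8+) serves only as the reference. Note that the loop runs over the k factors of the numerator and denominator, so k - i never reaches zero.

```python
import math

def log2comb_iter(n, k):
    """log2 of C(n, k) as a running sum of k log differences, in constant memory."""
    return sum(math.log2(n - i) - math.log2(k - i) for i in range(k))

n, k = 5000, 2000
exact = math.log2(math.comb(n, k))  # reference value from the exact big integer
print(log2comb_iter(n, k), exact)
```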


As a side note, why are you using a list comprehension at all? A NumPy array with vectorized operations would be faster and a bit more compact, and a generator expression would use essentially no extra memory at no cost to speed.

EDIT: @unutbu has answered the real question, but I'll leave this here in case log2comb(n, k) is useful to anyone.


comb(n, k) is n! / ((n-k)! k!), and n! can be computed using the Gamma function gamma(n+1). Scipy provides the function scipy.special.gamma. Scipy also provides gammaln, which is the log (natural log, that is) of the Gamma function.

So log(comb(n, k)) can be computed as gammaln(n+1) - gammaln(n-k+1) - gammaln(k+1)

For example, log(comb(100, 8)) (after executing from scipy.special import gammaln):

In [26]: log(comb(100, 8))
Out[26]: 25.949484949043022

In [27]: gammaln(101) - gammaln(93) - gammaln(9)
Out[27]: 25.949484949042962

and log(comb(5000, 2000)):

In [28]: log(comb(5000, 2000))  # Overflow!
Out[28]: inf

In [29]: gammaln(5001) - gammaln(3001) - gammaln(2001)
Out[29]: 3360.5943053174142

(Of course, to get the base-2 logarithm, just divide by log(2).)

For convenience, you can define:

from math import log
from scipy.special import gammaln

def log2comb(n, k):
    return (gammaln(n+1) - gammaln(n-k+1) - gammaln(k+1)) / log(2) 
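If SciPy is not at hand, a rough equivalent for scalar arguments can be sketched with math.lgamma from the standard library, which also returns the natural log of the Gamma function for positive inputs. This is a stand-in, not the function defined above: the name log2comb_lgamma is made up here, and unlike gammaln it does not accept arrays.

```python
import math

def log2comb_lgamma(n, k):
    """Same formula as log2comb, with math.lgamma standing in for gammaln."""
    natural_log = math.lgamma(n + 1) - math.lgamma(n - k + 1) - math.lgamma(k + 1)
    return natural_log / math.log(2)  # convert natural log to base 2

approx = log2comb_lgamma(5000, 2000)
exact = math.log2(math.comb(5000, 2000))  # exact integer route, Python 3.8+
print(approx, exact)
```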