Question
I have a generic Python memoizer:
cache = {}

def memoize(f):
    """Memoize any function."""
    def decorated(*args):
        key = (f, str(args))
        result = cache.get(key, None)
        if result is None:
            result = f(*args)
            cache[key] = result
        return result
    return decorated
It works, but I'm not happy with it, because sometimes it's not efficient. Recently, I used it with a function that takes lists as arguments, and apparently making keys with whole lists slowed everything down. What is the best way to do that? (i.e., to efficiently compute keys, whatever the args, and however long or complex they are)
I guess the question is really about how you would efficiently produce keys from the args and the function for a generic memoizer - I have observed in one program that poor keys (too expensive to produce) had a significant impact on the runtime. My prog was taking 45s with 'str(args)', but I could reduce that to 3s with handcrafted keys. Unfortunately, the handcrafted keys are specific to this prog, but I want a fast memoizer where I won't have to roll out specific, handcrafted keys for the cache each time.
Answer 1:
First, if you're pretty sure that O(N) hashing is reasonable and necessary here, and you just want to speed things up with a faster algorithm than hash(str(x)), try this:
def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        result ^= hash(element)
    return result
Of course, this won't work for nested sequences (an unhashable element such as an inner list makes hash raise TypeError), but there's an obvious way around that:
def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            result ^= hash_seq(element)
    return result
I'm not sure this is a good-enough hash algorithm, because it returns the same value for different permutations of the same list. But I am pretty sure that no good-enough hash algorithm will be much faster, at least if it's written in C or Cython, which you'll probably ultimately want to do if you go in this direction.
Also, it's worth noting that this will be correct in many cases where str (or marshal) will not: for example, if your list may have some mutable element whose repr involves its id rather than its value. However, it's still not correct in all cases. In particular, it assumes that "iterates the same elements" means "equal" for any iterable type, which obviously isn't guaranteed to be true. False negatives aren't a huge deal, but false positives are (e.g., two dicts with the same keys but different values get the same hash and, if the memo is keyed on that hash, spuriously share a cache entry).
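The failure modes described above are easy to reproduce with the recursive hash_seq (repeated here so the snippet runs on its own). The duplicate-element case is an extra quirk of XOR itself (x ^ x == 0) rather than something the answer mentions.

```python
def hash_seq(iterable):
    """Order-insensitive XOR hash, as defined in the answer."""
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            # Unhashable element (e.g. a nested list): recurse into it.
            result ^= hash_seq(element)
    return result


# 1. Permutations collide, because XOR is commutative.
print(hash_seq([1, 2, 3]) == hash_seq([3, 2, 1]))  # True

# 2. A pair of equal elements cancels out, since x ^ x == 0.
print(hash_seq([5, 5]) == hash_seq([]))  # True

# 3. Iterating a dict yields only its keys, so the values never reach the hash.
a = {"x": 1, "y": 2}
b = {"x": 100, "y": 200}
print(a == b, hash_seq(a) == hash_seq(b))  # False True
```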
Also, it uses only constant extra space, instead of the O(N) space (with a rather large constant factor) needed to build a string.
At any rate, it's worth trying this first, and only then deciding whether it's worth analyzing for good-enough-ness and tweaking for micro-optimizations.
Here's a trivial Cython version of the shallow implementation:
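def test_cy_xor(iterable):
    # Py_ssize_t is as wide as a Python hash value; a C int would
    # raise OverflowError for most hashes on 64-bit builds.
    cdef Py_ssize_t result = hash(type(iterable))
    cdef Py_ssize_t h
    for element in iterable:
        h = hash(element)
        result ^= h
    return result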
From a quick test, the pure Python implementation is pretty slow (as you'd expect, with all that Python looping, compared to the C looping in str and marshal), but the Cython version wins easily:
test_str( 3): 0.015475
test_marshal( 3): 0.008852
test_xor( 3): 0.016770
test_cy_xor( 3): 0.004613
test_str(10000): 8.633486
test_marshal(10000): 2.735319
test_xor(10000): 24.895457
test_cy_xor(10000): 0.716340
Just iterating the sequence in Cython and doing nothing (which is effectively just N calls to PyIter_Next and some refcounting, so you're not going to do much better in native C) takes about 70% of the time of test_cy_xor. You can presumably make it faster by requiring an actual sequence instead of an iterable, and even more so by requiring a list, although either way it might require writing explicit C rather than Cython to get the benefits.
Anyway, how do we fix the ordering problem? The obvious Python solution is to hash (i, element) instead of element, but all that tuple manipulation slows down the Cython version by up to 12x. The standard solution is to multiply the running result by some number between xors. But while you're at it, it's worth trying to get the values to spread out nicely for short sequences, small int elements, and other very common edge cases. Picking the right numbers is tricky, so… I just borrowed everything from tuple's hash. Here's the complete test.
_hashtest.pyx:
cdef _test_xor(seq):
    cdef long result = 0x345678
    cdef long mult = 1000003
    cdef long h
    cdef long l = 0
    try:
        l = len(seq)
    except TypeError:
        # NOTE: This probably means very short non-len-able sequences
        # will not be spread as well as they should, but I'm not
        # sure what else to do.
        l = 100
    for element in seq:
        try:
            h = hash(element)
        except TypeError:
            h = _test_xor(element)
        result ^= h
        result *= mult
        mult += 82520 + l + l
    result += 97531
    return result

def test_xor(seq):
    return _test_xor(seq) ^ hash(type(seq))
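For experimenting without a Cython toolchain, here is a rough pure-Python transcription of _test_xor and test_xor above. It is only a sketch: Python ints never overflow, so the wraparound of a 64-bit C long is emulated with an explicit mask (an assumption about the target platform), and, like the pure Python xor test, it will be far slower than the Cython version. The names _mix, mix_hash, and MASK are made up for this example.

```python
MASK = (1 << 64) - 1  # emulate 64-bit wraparound; Python ints never overflow


def _mix(seq):
    """Order-sensitive mixing loop, following _test_xor above."""
    try:
        n = len(seq)
    except TypeError:
        n = 100  # same fallback as the Cython code for non-len-able iterables
    result = 0x345678
    mult = 1000003
    for element in seq:
        try:
            h = hash(element)
        except TypeError:
            h = _mix(element)  # recurse into unhashable (nested) elements
        # Multiplying between xors makes each element's effect depend on position.
        result = ((result ^ h) * mult) & MASK
        mult = (mult + 82520 + n + n) & MASK
    return (result + 97531) & MASK


def mix_hash(seq):
    """Like test_xor: fold in the container type so [1, 2] and (1, 2) differ."""
    return _mix(seq) ^ hash(type(seq))


# Reordering the elements should now change the hash.
print(mix_hash([1, 2, 3]), mix_hash([3, 2, 1]))
print(mix_hash([1, 2, 3]) == mix_hash([1, 2, 3]))  # True
```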
hashtest.py:
import marshal
import random
import timeit

import pyximport
pyximport.install()

import _hashtest

def test_str(seq):
    return hash(str(seq))

def test_marshal(seq):
    return hash(marshal.dumps(seq))

def test_cy_xor(seq):
    return _hashtest.test_xor(seq)

# This one is so slow that I don't bother to test it...
def test_xor(seq):
    result = hash(type(seq))
    for i, element in enumerate(seq):
        try:
            result ^= hash((i, element))
        except TypeError:
            # hash() takes a single argument, so hash the (index, subhash) pair.
            result ^= hash((i, test_xor(element)))
    return result

smalltest = [1, 2, 3]
bigtest = [random.randint(10000, 20000) for _ in range(10000)]

def run():
    for seq in smalltest, bigtest:
        for f in test_str, test_marshal, test_cy_xor:
            # __name__ works on Python 2 and 3 (func_name is Python 2 only).
            print('%16s(%5d): %9f' % (f.__name__, len(seq),
                  timeit.timeit(lambda: f(seq), number=10000)))

if __name__ == '__main__':
    run()
Output:
test_str( 3): 0.014489
test_marshal( 3): 0.008746
test_cy_xor( 3): 0.004686
test_str(10000): 8.563252
test_marshal(10000): 2.744564
test_cy_xor(10000): 0.904398
Here are some potential ways to make this faster:
- If you have lots of deep sequences, instead of using try around hash, call PyObject_Hash and check for -1 (clearing the TypeError before recursing).
- If you know you have a sequence (or, even better, specifically a list), instead of just an iterable, PySequence_ITEM (or PyList_GET_ITEM) is probably going to be faster than the PyIter_Next implicitly used above.
In either case, once you start making C API calls, it's usually easier to drop Cython and just write the function in C. (You can still use Cython to write a trivial wrapper around that C function, instead of manually coding up the extension module.) And at that point, just borrow the tuplehash code directly instead of reimplementing the same algorithm.
If you're looking for a way to avoid the O(N) in the first place, that's just not possible. If you look at how tuple.__hash__, frozenset.__hash__, and ImmutableSet.__hash__ work (the last one is pure Python and very readable, by the way), they all take O(N). However, they also all cache the hash values. So, if you're frequently hashing the same tuple (rather than non-identical-but-equal ones), it approaches constant time. (It's O(N/M), where M is the number of times you call with each tuple.)
If you can assume that your list objects never mutate between calls, you can obviously do the same thing, e.g., with a dict mapping id to hash as an external cache. But in general, that obviously isn't a reasonable assumption. (If your list objects never mutate, it would be easier to just switch to tuple objects and not bother with all this complexity.)
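A minimal sketch of that external cache, assuming (as the answer says you must) that the lists are never mutated after their first hash. The names _hash_by_id and cached_hash_seq are hypothetical, and hash(tuple(seq)) stands in for whichever O(N) hash you settle on. The cache also stores each list itself, because otherwise a list could be freed and its id reused by a different object; the price is that cached lists are never garbage-collected.

```python
_hash_by_id = {}  # id(seq) -> (seq, hash); hypothetical module-level cache


def cached_hash_seq(seq):
    """Hash seq once per object. ASSUMES seq is never mutated afterwards."""
    entry = _hash_by_id.get(id(seq))
    if entry is not None:
        return entry[1]  # O(1) on every later call with the same object
    h = hash(tuple(seq))  # O(N) the first time; elements must be hashable
    # Storing seq itself keeps it alive, so its id cannot be recycled.
    _hash_by_id[id(seq)] = (seq, h)
    return h


data = [1, 2, 3]
first = cached_hash_seq(data)
data.append(4)  # violates the no-mutation assumption...
print(cached_hash_seq(data) == first)  # True: the stale hash comes back
```

The last line is the whole reason the answer calls this assumption unreasonable in general: nothing notices the mutation.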
But you can wrap up your list objects in a subclass that adds a cached hash value member (or slot), and invalidates the cache whenever it gets a mutating call (append, __setitem__, __delitem__, etc.). Then your hash_seq can check for that.
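One way such a subclass might look, as a sketch: HashCachingList, content_hash, and the exact set of wrapped methods are assumptions rather than anything from the answer, and hash(tuple(self)) again stands in for the real O(N) hash. It deliberately does not define __hash__, because making a mutable list usable as a dict key invites exactly the bugs the cache is meant to avoid; instead, hash_seq checks for the subclass.

```python
class HashCachingList(list):
    """A list that remembers its content hash until it is mutated (hypothetical)."""

    __slots__ = ("_cached_hash",)

    def __init__(self, *args):
        super().__init__(*args)
        self._cached_hash = None

    def content_hash(self):
        if self._cached_hash is None:
            self._cached_hash = hash(tuple(self))  # O(N), only when stale
        return self._cached_hash


def _invalidating(name):
    """Wrap list.<name> so the cached hash is dropped before mutating."""
    method = getattr(list, name)

    def wrapper(self, *args, **kwargs):
        self._cached_hash = None
        return method(self, *args, **kwargs)

    wrapper.__name__ = name
    return wrapper


# Every list method that can change the contents.
for _name in ("append", "extend", "insert", "pop", "remove", "clear",
              "sort", "reverse", "__setitem__", "__delitem__",
              "__iadd__", "__imul__"):
    setattr(HashCachingList, _name, _invalidating(_name))


def hash_seq(seq):
    """Use the cached hash when available, else fall back to an O(N) hash."""
    if isinstance(seq, HashCachingList):
        return seq.content_hash()
    return hash(tuple(seq))


items = HashCachingList([1, 2, 3])
print(hash_seq(items) == hash((1, 2, 3)))  # True (computed, then cached)
items.append(4)  # invalidates the cache
print(hash_seq(items) == hash((1, 2, 3, 4)))  # True (recomputed)
```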
The end result is the same correctness and performance as with tuples: amortized O(N/M), except that for tuples, M is the number of times you call with each identical tuple, while for lists it's the number of times you call with each identical list without mutating in between.
Answer 2:
You could try a couple of things:
Using marshal.dumps instead of str might be slightly faster (at least on my machine):
>>> timeit.timeit("marshal.dumps([1,2,3])","import marshal", number=10000)
0.008287056301007567
>>> timeit.timeit("str([1,2,3])",number=10000)
0.01709315717356219
Also, if your functions are expensive to compute and could possibly return None themselves, then your memoizing function will recompute them every time (I'm possibly reaching here, but without knowing more I can only guess). Incorporating these two things gives:
import marshal

cache = {}

def memoize(f):
    """Memoize any function."""
    def decorated(*args):
        key = (f, marshal.dumps(args))
        if key in cache:
            return cache[key]
        cache[key] = f(*args)
        return cache[key]
    return decorated
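A quick check of both changes, with the memoizer repeated so the snippet runs on its own; lookup and calls are made-up names for the demo. One limitation worth knowing for a generic memoizer: marshal only serializes core built-in types, so an argument such as a plain object() makes marshal.dumps raise ValueError.

```python
import marshal

cache = {}


def memoize(f):
    """Answer 2's memoizer: marshal-based keys, membership test instead of None."""
    def decorated(*args):
        key = (f, marshal.dumps(args))
        if key in cache:
            return cache[key]
        cache[key] = f(*args)
        return cache[key]
    return decorated


calls = 0


@memoize
def lookup(items):
    global calls
    calls += 1  # count real evaluations
    return None  # a legitimate None result is still cached


data = [1, 2, 3]
lookup(data)  # list arguments work, since marshal handles lists
lookup(data)
print(calls)  # 1

try:
    memoize(lambda x: x)(object())
except ValueError as exc:
    print("marshal cannot build a key for this argument:", exc)
```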
Source: https://stackoverflow.com/questions/14074249/efficient-generic-python-memoize