Efficient generic Python memoize

旧城冷巷雨未停 submitted on 2020-01-03 20:45:33

Question


I have a generic Python memoizer:

cache = {}

def memoize(f): 
    """Memoize any function."""

    def decorated(*args):
        key = (f, str(args))
        result = cache.get(key, None)
        if result is None:
            result = f(*args)
            cache[key] = result
        return result

    return decorated
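Here is a minimal usage sketch of the memoizer above, with a hypothetical total function that takes a list. It shows why the key goes through str(args) in the first place: the raw args tuple would contain a list, and a tuple that contains a list can't be hashed.

```python
cache = {}

def memoize(f):
    """Memoize any function (the str(args)-keyed version from the question)."""
    def decorated(*args):
        key = (f, str(args))
        result = cache.get(key, None)
        if result is None:
            result = f(*args)
            cache[key] = result
        return result
    return decorated

calls = []

@memoize
def total(numbers):
    # Hypothetical example function; records every real (uncached) call.
    calls.append(list(numbers))
    return sum(numbers)

print(total([1, 2, 3]))  # computed
print(total([1, 2, 3]))  # served from the cache
print(len(calls))        # → 1

# Keying on the raw args instead of str(args) breaks as soon as an
# argument is a list, because a tuple containing a list is unhashable.
try:
    hash((total, ([1, 2, 3],)))
except TypeError as exc:
    print("unhashable:", exc)
```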

It works, but I'm not happy with it, because sometimes it's not efficient. Recently, I used it with a function that takes lists as arguments, and apparently making keys with whole lists slowed everything down. What is the best way to do that? (i.e., to efficiently compute keys, whatever the args, and however long or complex they are)

I guess the question is really about how you would efficiently produce keys from the args and the function for a generic memoizer. I have observed in one program that poor keys (too expensive to produce) had a significant impact on the runtime: my program took 45s with 'str(args)', but I could reduce that to 3s with handcrafted keys. Unfortunately, those handcrafted keys are specific to that program, and I want a fast memoizer where I don't have to write specific, handcrafted cache keys each time.


Answer 1:


First, if you're pretty sure that O(N) hashing is reasonable and necessary here, and you just want to speed things up with a faster algorithm than hash(str(x)), try this:

def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        result ^= hash(element)
    return result

Of course this won't work for nested sequences (e.g., a list of lists), because the inner lists aren't hashable. There's an obvious way around that:

def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            result ^= hash_seq(element)
    return result
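A quick check of the recursive hash_seq above on a nested list, which plain hash() rejects. The input values are made up for illustration.

```python
def hash_seq(iterable):
    # Recursive version from above: hash elements directly, and recurse
    # into elements that are unhashable (e.g., nested lists).
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            result ^= hash_seq(element)
    return result

nested = [[1, 2], [3, [4, 5]]]

try:
    hash(tuple(nested))  # the inner lists make this tuple unhashable
except TypeError as exc:
    print("plain hash fails:", exc)

print(hash_seq(nested))  # works: each inner list is hashed recursively
print(hash_seq(nested) == hash_seq([[1, 2], [3, [4, 5]]]))  # → True
```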

I'm not sure this is a good-enough hash algorithm, because it will return the same value for different permutations of the same list. But I am pretty sure that no good-enough hash algorithm will be much faster, at least if it's written in C or Cython, which you'll probably ultimately want to do if this is the direction you're going.
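The permutation problem is easy to demonstrate. XOR is commutative and every value cancels itself, so reordering the elements never changes the result, and a pair of equal elements drops out entirely. Here is a small check using the shallow hash_seq from above:

```python
def hash_seq(iterable):
    # Shallow XOR-of-hashes version from above.
    result = hash(type(iterable))
    for element in iterable:
        result ^= hash(element)
    return result

print(hash_seq([1, 2, 3]) == hash_seq([3, 2, 1]))  # → True: order is ignored
print(hash_seq([1, 1, 5]) == hash_seq([5]))        # → True: equal pairs cancel out
```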

Also, it's worth noting that this will be correct in many cases where str (or marshal) will not. For example, your list may contain some mutable element whose repr involves its id rather than its value. However, it's still not correct in all cases. In particular, it assumes that "iterates the same elements" means "equal" for any iterable type, which obviously isn't guaranteed. False negatives aren't a huge deal, but false positives are (e.g., two dicts with the same keys but different values may spuriously compare equal and share a memo).

Also, it uses O(1) extra space (apart from recursion depth for nested sequences), whereas str has to build an O(N) string with a rather large constant factor.
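The dict false positive is easy to reproduce, because iterating a dict yields only its keys and the values never reach the hash. Below is a sketch with a hypothetical memoize_by_hash that trusts hash_seq(args) as the entire cache key. It is not something the answer proposes; it only exists to show the failure.

```python
def hash_seq(iterable):
    # Recursive version from above.
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            result ^= hash_seq(element)
    return result

cache = {}

def memoize_by_hash(f):
    # Hypothetical memoizer keyed only on hash_seq(args): equal hashes are
    # treated as equal arguments, so a collision becomes a wrong cache hit.
    def decorated(*args):
        key = (f, hash_seq(args))
        if key not in cache:
            cache[key] = f(*args)
        return cache[key]
    return decorated

@memoize_by_hash
def sum_values(d):
    return sum(d.values())

print(sum_values({"a": 1}))    # → 1
print(sum_values({"a": 100}))  # → 1, a false positive: same keys, same hash
```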

At any rate, it's worth trying this first, and only then deciding whether it's worth analyzing for good-enough-ness and tweaking for micro-optimizations.

Here's a trivial Cython version of the shallow implementation:

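def test_cy_xor(iterable):
    # Py_ssize_t is as wide as Python's hash values, so nothing is truncated
    # (a C int would drop the upper bits on 64-bit builds).
    cdef Py_ssize_t result = hash(type(iterable))
    cdef Py_ssize_t h
    for element in iterable:
        h = hash(element)
        result ^= h
    return result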

From a quick test, the pure Python implementation is pretty slow (as you'd expect, with all that Python looping, compared to the C looping in str and marshal), but the Cython version wins easily. Times are total seconds for 10,000 calls:

    test_str(    3):  0.015475
test_marshal(    3):  0.008852
    test_xor(    3):  0.016770
 test_cy_xor(    3):  0.004613
    test_str(10000):  8.633486
test_marshal(10000):  2.735319
    test_xor(10000): 24.895457
 test_cy_xor(10000):  0.716340

Just iterating the sequence in Cython and doing nothing takes about 70% as long as test_cy_xor. That is effectively just N calls to PyIter_Next and some refcounting, so you're not going to do much better in native C. You can presumably make it faster by requiring an actual sequence instead of an arbitrary iterable, and faster still by requiring a list. Either way, you might need to write explicit C rather than Cython to get the benefits.

Anyway, how do we fix the ordering problem? The obvious Python solution is to hash (i, element) instead of element, but all that tuple construction slows the Cython version down by up to 12x. The standard solution is to multiply the running result by some number after each xor. While you're at it, it's worth trying to get the values to spread out nicely for short sequences, small int elements, and other very common edge cases. Picking the right numbers is tricky, so… I just borrowed everything from CPython's tuple hash (tuplehash). Here's the complete test.

_hashtest.pyx:

cdef _test_xor(seq):
    cdef long result = 0x345678
    cdef long mult = 1000003
    cdef long h
    cdef long l = 0
    try:
        l = len(seq)
    except TypeError:
        # NOTE: This probably means very short non-len-able sequences
        # will not be spread as well as they should, but I'm not
        # sure what else to do.
        l = 100
    for element in seq:
        try:
            h = hash(element)
        except TypeError:
            h = _test_xor(element)
        result ^= h
        result *= mult
        mult += 82520 + l + l
    result += 97531
    return result

def test_xor(seq):
    return _test_xor(seq) ^ hash(type(seq))

hashtest.py:

import marshal
import random
import timeit
import pyximport
pyximport.install()
import _hashtest

def test_str(seq):
    return hash(str(seq))

def test_marshal(seq):
    return hash(marshal.dumps(seq))

def test_cy_xor(seq):
    return _hashtest.test_xor(seq)

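# This one is so slow that I don't bother to test it...
def test_xor(seq):
    result = hash(type(seq))
    for i, element in enumerate(seq):
        try:
            result ^= hash((i, element))
        except TypeError:
            # Unhashable element (e.g., a nested list): hash it recursively.
            # hash() takes a single argument, so pair (i, ...) in a tuple.
            result ^= hash((i, test_xor(element)))
    return result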

smalltest = [1,2,3]
bigtest = [random.randint(10000, 20000) for _ in range(10000)]

def run():
    for seq in smalltest, bigtest:
        for f in test_str, test_marshal, test_cy_xor:
            # __name__ works on Python 2 and 3 (func_name is Python 2 only).
            print('%16s(%5d): %9f' % (f.__name__, len(seq),
                                      timeit.timeit(lambda: f(seq), number=10000)))

if __name__ == '__main__':
    run()

Output:

    test_str(    3):  0.014489
test_marshal(    3):  0.008746
 test_cy_xor(    3):  0.004686
    test_str(10000):  8.563252
test_marshal(10000):  2.744564
 test_cy_xor(10000):  0.904398
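If you want to experiment without compiling Cython, here is a pure-Python transliteration of _test_xor and test_xor from _hashtest.pyx above. It is only a sketch: Python ints never overflow, so it masks to 64 bits to mimic C long wraparound. That assumes a 64-bit long, as on typical Linux/macOS builds (long is 32 bits on Windows), and it reports the value as unsigned. It will be far slower than the Cython version, but it shows that the multiply step makes the result order-sensitive.

```python
MASK = (1 << 64) - 1  # assumption: 64-bit C long, emulated as unsigned wraparound

def _xor_mult(seq):
    # Mirrors _test_xor: xor each element's hash into the running result,
    # multiply by a multiplier that changes every step, then add a constant.
    result = 0x345678
    mult = 1000003
    try:
        n = len(seq)
    except TypeError:
        n = 100  # same fallback as the Cython code for len-less iterables
    for element in seq:
        try:
            h = hash(element)
        except TypeError:
            h = _xor_mult(element)  # recurse into unhashable (nested) elements
        result = ((result ^ h) * mult) & MASK
        mult = (mult + 82520 + n + n) & MASK
    return (result + 97531) & MASK

def xor_mult_hash(seq):
    # Mirrors test_xor: mix in the container type at the top level only.
    return _xor_mult(seq) ^ (hash(type(seq)) & MASK)

print(xor_mult_hash([1, 2, 3]) == xor_mult_hash([1, 2, 3]))  # → True
print(xor_mult_hash([1, 2, 3]) == xor_mult_hash([3, 2, 1]))  # → False: order now matters
print(xor_mult_hash([[1, 2], [3]]))  # nested lists are handled recursively
```

Every multiplier stays odd, so each multiply step is invertible modulo 2**64. That is why a difference introduced early in the sequence can't be cancelled by a later multiplication.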

Here are some potential ways to make this faster:

  • If you have lots of nested sequences, call PyObject_Hash directly instead of wrapping hash in try/except. Check for a -1 return, and clear the TypeError it sets before recursing.
  • If you know you have a sequence (or, even better, specifically a list), instead of just an iterable, PySequence_ITEM (or PyList_GET_ITEM) is probably going to be faster than the PyIter_Next implicitly used above.

In either case, once you start calling C API calls, it's usually easier to drop Cython and just write the function in C. (You can still use Cython to write a trivial wrapper around that C function, instead of manually coding up the extension module.) And at that point, just borrow the tuplehash code directly instead of reimplementing the same algorithm.

If you're looking for a way to avoid the O(N) in the first place, that's just not possible. If you look at how tuple.__hash__, frozenset.__hash__, and ImmutableSet.__hash__ work (the last one is pure Python and very readable, by the way), they all take O(N). However, they also all cache the hash values. So, if you're frequently hashing the same tuple (rather than non-identical-but-equal ones), it approaches constant time. (It's O(N/M), where M is the number of times you call with each tuple.)

If you can assume that your list objects never mutate between calls, you can obviously do the same thing, e.g., with a dict mapping id to hash as an external cache. But in general, that obviously isn't a reasonable assumption. (If your list objects never mutate, it would be easier to just switch to tuple objects and not bother with all this complexity.)

But you can wrap up your list objects in a subclass that adds a cached hash value member (or slot), and invalidates the cache whenever it gets a mutating call (append, __setitem__, __delitem__, etc.). Then your hash_seq can check for that.

The end result is the same correctness and performance as with tuples: amortized O(N/M), except that for tuple M is the number of times you call with each identical tuple, while for list it's the number of times you call with each identical list without mutating in between.




Answer 2:


You could try a couple of things:

Using marshal.dumps instead of str might be slightly faster (at least on my machine):

>>> timeit.timeit("marshal.dumps([1,2,3])","import marshal", number=10000)
0.008287056301007567
>>> timeit.timeit("str([1,2,3])",number=10000)
0.01709315717356219

Also, suppose your functions are expensive to compute and can legitimately return None. Then your memoizer will recompute them on every call, because cache.get(key, None) can't distinguish a cached None from a missing key. I may be reaching here, but without knowing more I can only guess. Incorporating these two changes gives:

import marshal
cache = {}

def memoize(f): 
    """Memoize any function."""

    def decorated(*args):
        key = (f, marshal.dumps(args))
        if key in cache:
            return cache[key]

        cache[key] = f(*args)
        return cache[key]

    return decorated
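Here is a quick check of the None handling, using a hypothetical lookup function that always returns None and counts its calls. With the question's cache.get(key, None) test it would run on every call; with the key-in-cache test it runs only once. Keep in mind that marshal.dumps only accepts built-in types (numbers, strings, bytes, lists, tuples, dicts, sets and the like) and raises ValueError for anything else.

```python
import marshal

cache = {}

def memoize(f):
    """Memoize any function (marshal-keyed version from above)."""
    def decorated(*args):
        key = (f, marshal.dumps(args))
        if key in cache:
            return cache[key]
        cache[key] = f(*args)
        return cache[key]
    return decorated

call_count = 0

@memoize
def lookup(items):
    # Hypothetical expensive function whose answer happens to be None.
    global call_count
    call_count += 1
    return None

lookup([1, 2, 3])
lookup([1, 2, 3])
print(call_count)  # → 1: the cached None is reused
```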


Source: https://stackoverflow.com/questions/14074249/efficient-generic-python-memoize
