Is it possible to numpy.vectorize an instance method?

我在风中等你 2021-01-17 08:51

I've found that numpy.vectorize allows one to convert 'ordinary' functions, which expect a single number as input, into a function which can also convert a list of inputs

5 answers
  •  猫巷女王i
    2021-01-17 09:32

    Here's a generic decorator that works with instance methods as well as functions (refer to NumPy's documentation for the otypes and signature parameters):

    from functools import wraps
    
    import numpy as np
    
    def vectorize(otypes=None, signature=None):
        """Numpy vectorization wrapper that works with instance methods."""
        def decorator(fn):
            vectorized = np.vectorize(fn, otypes=otypes, signature=signature)
            @wraps(fn)
            def wrapper(*args):
                return vectorized(*args)
            return wrapper
        return decorator
    

    You may use it to vectorize your method as follows:

    class Dummy(object):
        def __init__(self, val=1):
            self.val = val
    
        @vectorize(signature="(),()->()")
        def f(self, x):
            if x == 0:
                return self.val
            else:
                return 2
    
    
    def test_3():
        assert list(Dummy().f([0, 1, 2])) == [1, 2, 2]
    

    The key is the signature kwarg. Parenthesized values to the left of -> describe the input parameters, and values to the right describe the outputs. () represents a scalar (0-dimensional array); (n) represents a 1-dimensional array; (m,n) represents a 2-dimensional array; (m,n,p) represents a 3-dimensional array; and so on. Here, signature="(),()->()" tells NumPy that the first parameter (self) is a scalar, the second (x) is also a scalar, and the method returns a scalar (either self.val or 2, depending on x).

    $ pytest /tmp/instance_vectorize.py
    ======================= test session starts ========================
    platform linux -- Python 3.6.5, pytest-3.5.1, py-1.5.3, pluggy-0.6.0
    rootdir: /tmp, inifile:
    collected 1 item
    
    ../../tmp/instance_vectorize.py .                            [100%]
    
    ==================== 1 passed in 0.08 seconds ======================
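
To make the signature notation described above more concrete, here is a minimal sketch in which the method takes a 1-dimensional core argument instead of a scalar. The Scaler class and its scaled_sum method are hypothetical names made up for illustration; the vectorize decorator is repeated from the answer only so that the snippet runs on its own.

```python
from functools import wraps

import numpy as np


def vectorize(otypes=None, signature=None):
    """Same wrapper as in the answer, repeated so this snippet is self-contained."""
    def decorator(fn):
        vectorized = np.vectorize(fn, otypes=otypes, signature=signature)

        @wraps(fn)
        def wrapper(*args):
            return vectorized(*args)
        return wrapper
    return decorator


class Scaler(object):  # hypothetical example class, not from the original answer
    def __init__(self, factor=2.0):
        self.factor = factor

    # "()"  -> self is treated as a scalar
    # "(n)" -> row is a 1-dimensional array of length n
    # "()"  -> one scalar is returned per row
    @vectorize(signature="(),(n)->()")
    def scaled_sum(self, row):
        return self.factor * row.sum()


rows = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
result = Scaler(factor=2.0).scaled_sum(rows)
print(result)  # one value per row: 6, 14 and 22
```

With this signature NumPy loops over the leading axis of rows and hands each length-2 row to the method, so the result has shape (3,). Note that self still needs its own () entry, because the wrapper receives it as an ordinary first positional argument.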
    
