Array and __rmul__ operator in Python Numpy

太阳男子 2021-01-05 01:10

In a project, I created a class and needed an operation between this new class and a real-valued matrix (a NumPy array), so I overloaded the __rmul__ method like this:

3 Answers
  •  [愿得一人]
    2021-01-05 01:26

    The behaviour is expected.

    First of all, you have to understand how an operation like x*y is actually executed. The Python interpreter first tries x.__mul__(y). If x has no __mul__, or the call returns NotImplemented, it then tries y.__rmul__(x) (provided x and y are of different types). The exception is when the type of y is a proper subclass of the type of x and overrides __rmul__: in that case the interpreter tries y.__rmul__(x) first, and x.__mul__(y) only afterwards.
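    A minimal pure-Python sketch of these dispatch rules. The class names Declines, Reflected, Parent and Child are hypothetical and exist only for illustration; the methods return marker strings so you can see which one ran:

```python
class Declines:
    def __mul__(self, other):
        # Returning NotImplemented tells Python to try other.__rmul__(self).
        return NotImplemented


class Reflected:
    def __rmul__(self, other):
        return "Reflected.__rmul__"


class Parent:
    def __mul__(self, other):
        return "Parent.__mul__"


class Child(Parent):
    # Child is a subclass of Parent and provides its own reflected method,
    # so in Parent() * Child() Python tries Child.__rmul__ first.
    def __rmul__(self, other):
        return "Child.__rmul__"


print(Declines() * Reflected())  # Reflected.__rmul__  (left side gave up)
print(Parent() * Reflected())    # Parent.__mul__      (left side succeeded)
print(Parent() * Child())        # Child.__rmul__      (subclass rule)
```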

    Now, numpy treats its operands differently depending on whether it considers them scalars or arrays.

    When both operands are arrays, * does element-by-element multiplication (with broadcasting); when one operand is a scalar, every entry of the array is multiplied by that scalar.
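    A small illustration of the two cases with plain integer arrays (the values are arbitrary):

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[10, 20], [30, 40]])

# Two arrays of the same shape: entry (i, j) of a times entry (i, j) of b.
elementwise = a * b
# An array and a scalar: every entry of a is multiplied by 3.
scaled = a * 3

print(elementwise.tolist())  # [[10, 40], [90, 160]]
print(scaled.tolist())       # [[3, 6], [9, 12]]
```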

    In your case numpy considers foo() a scalar, so it multiplies every element of the array by foo(). Moreover, since numpy doesn't know the type foo, it returns an array with dtype=object, so the result is:

    array([[0, 0],
           [0, 0],
           [0, 0]], dtype=object)
    

    Note: the interpreter calls numpy's ndarray.__mul__ first, and it does not return NotImplemented for a foo argument; instead it performs the scalar multiplication described above. Numpy therefore multiplies each entry of the array by your "scalar" foo(), and this is where your __rmul__ method gets called: each number in the array returns NotImplemented from its own __mul__ when given a foo, so Python falls back to foo.__rmul__ once per entry.
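    A sketch that reproduces this behaviour. The question's own class is not shown, so Foo below is a hypothetical stand-in: its __rmul__ records each left operand and returns 0 (an assumption chosen to match the output above), and its __mul__ returns a marker string:

```python
import numpy as np


class Foo:
    """Hypothetical stand-in for the question's class."""

    def __init__(self):
        self.rmul_args = []  # records the left operand of each __rmul__ call

    def __mul__(self, other):
        return "Foo.__mul__"

    def __rmul__(self, other):
        self.rmul_args.append(other)
        return 0  # assumed return value


foo = Foo()
m = np.ones((3, 2))

result = m * foo                # ndarray.__mul__ runs and treats foo as a scalar
print(result.dtype)             # object
print(result.tolist())          # [[0, 0], [0, 0], [0, 0]]
print(len(foo.rmul_args))       # 6: one __rmul__ call per array entry
print(any(isinstance(x, np.ndarray) for x in foo.rmul_args))  # False

# Swapping the operands calls Foo.__mul__ directly, once, with the whole array.
print(foo * m)                  # Foo.__mul__
```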

    If you swap the operands (foo() * array), your __mul__ method is called immediately with the whole array and the problem does not arise.

    So, to answer your question, one way to handle this is to make foo inherit from ndarray, so that the subclass rule described above applies and foo.__rmul__ is tried before ndarray.__mul__:

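    import numpy as np

    class foo(np.ndarray):
        def __new__(cls):
            # you must implement __new__; e.g. build a 0-d array and view it as foo
            return np.zeros(()).view(cls)

        # __mul__ / __rmul__ as before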
    

    Be warned, however, that subclassing ndarray isn't straightforward. You might also get other side effects, since your class is now an ndarray.
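    A runnable sketch of the subclassing approach. The placeholder data (np.zeros(())) and the tuple returned by __rmul__ are assumptions for illustration only, since the question's own __rmul__ is not shown:

```python
import numpy as np


class Foo(np.ndarray):
    """Hypothetical ndarray subclass; a 0-d zero array serves as placeholder data."""

    def __new__(cls):
        # np.zeros(()) builds a 0-d float array; .view(cls) reinterprets it as Foo.
        return np.zeros(()).view(cls)

    def __rmul__(self, other):
        # Illustrative return value: shows that the whole left operand arrived here.
        return ("Foo.__rmul__", np.shape(other))


m = np.ones((3, 2))
# Foo is a subclass of ndarray and overrides __rmul__, so Python calls
# Foo.__rmul__ once, with the whole array, before trying ndarray.__mul__.
print(m * Foo())  # ('Foo.__rmul__', (3, 2))
```

    Because Foo still inherits everything else from ndarray (shape, ufunc handling, and so on), expressions that do not go through __rmul__ behave like ordinary array operations, which is the source of the side effects mentioned above.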
