multiply numpy ndarray with 1d array along a given axis

死守一世寂寞 2021-01-04 06:40

It seems I am getting lost in something potentially silly. I have an n-dimensional numpy array, and I want to multiply it with a vector (1d array) along some dimension (whi

5 answers
  •  一个人的身影
    2021-01-04 07:10

    Using broadcasting and views, instead of actually copying the data N times into a new array with the appropriate shape (as the existing answers do), is far more memory efficient. Here is such a method (based on @ShuxuanXU's code):

    import numpy as np

    try:
        from numpy.exceptions import AxisError  # NumPy >= 1.25
    except ImportError:
        from numpy import AxisError  # older NumPy releases

    def mult_along_axis(A, B, axis):

        # ensure we're working with NumPy arrays (no copy if they already are)
        A = np.asarray(A)
        B = np.asarray(B)

        # shape check (negative axes count from the end, as usual in NumPy)
        if axis >= A.ndim or axis < -A.ndim:
            raise AxisError(axis, A.ndim)
        if A.shape[axis] != B.size:
            raise ValueError(
                "Length of 'A' along the given axis must be the same as B.size"
                )

        # np.broadcast_to aligns B with the last axis of the target shape,
        # so we swap the given axis with the last one to determine the
        # corresponding array shape. np.swapaxes only returns a view
        # of the supplied array, so no data is copied unnecessarily.
        shape = np.swapaxes(A, A.ndim-1, axis).shape

        # Broadcast to an array with the shape as above. Again,
        # no data is copied, we only get a new look at the existing data.
        B_brc = np.broadcast_to(B, shape)

        # Swap back the axes. As before, this only changes our "point of view".
        B_brc = np.swapaxes(B_brc, A.ndim-1, axis)

        return A * B_brc
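
For comparison, here is a minimal sketch of a different route to the same result. It is not part of the answer above, and `mult_along_axis_reshape` is a hypothetical name chosen for illustration. Instead of swapping axes, it reshapes B so that every axis except `axis` has length 1 and lets ordinary NumPy broadcasting do the rest. The last lines inspect what `np.broadcast_to` returns, to illustrate the "no data is copied" claim: the repeated axes have stride 0 and the result is read-only.

```python
import numpy as np

def mult_along_axis_reshape(A, B, axis):
    """Hypothetical alternative: multiply A by the 1-D array B along `axis`."""
    A = np.asarray(A)
    B = np.asarray(B)
    if B.ndim != 1 or A.shape[axis] != B.size:
        raise ValueError("B must be 1-D with length A.shape[axis]")
    # Target shape is all ones except along `axis`, e.g. (1, 3, 1) for axis=1.
    shape = [1] * A.ndim
    shape[axis] = B.size
    # B is reshaped to that shape, then broadcasting stretches it
    # over the remaining axes during the multiplication.
    return A * B.reshape(tuple(shape))

A = np.arange(24.0).reshape(2, 3, 4)
B = np.array([1.0, 10.0, 100.0])
result = mult_along_axis_reshape(A, B, axis=1)

# Reference result written with explicit indexing for axis=1.
expected = A * B[np.newaxis, :, np.newaxis]
print(np.array_equal(result, expected))

# np.broadcast_to returns a read-only view of B, not a filled-in copy.
view = np.broadcast_to(B, (2, 4, 3))
print(view.strides, view.flags.writeable, np.shares_memory(view, B))
```

Both routes avoid building a full-size copy of B before multiplying. The reshape version is shorter, while the swapaxes version above makes the intermediate broadcast array explicit.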
    
