np.add.at indexing with array

孤街浪徒 2020-12-30 06:12

I'm working on cs231n and I'm having a difficult time understanding how this indexing works. Given that

import numpy as np

x = [[0,4,1], [3,2,4]]
dW = np.zeros((5,6))
dout = [         


        
2 Answers
  •  一个人的身影
    2020-12-30 06:22

    In [226]: x = [[0,4,1], [3,2,4]]
         ...: dW = np.zeros((5,6),int)
    
    In [227]: np.add.at(dW,x,1)
    In [228]: dW
    Out[228]: 
    array([[0, 0, 0, 1, 0, 0],
           [0, 0, 0, 0, 1, 0],
           [0, 0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0, 0],
           [0, 0, 1, 0, 0, 0]])
    

    With this x there are no duplicate (row, column) pairs, so add.at gives the same result as += indexing. Equivalently, we can read back the changed values with:

    In [229]: dW[x[0], x[1]]
    Out[229]: array([1, 1, 1])
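    A minimal sketch of the point above, assuming a reasonably recent NumPy. It passes tuple(x) instead of the bare list: as far as I know, NumPy 1.23 and later no longer treat a plain list of index lists as one index per axis, so the session above presumably ran on an older release. The second half deliberately repeats a (row, column) pair, which is exactly where += and add.at stop agreeing (the names rows, cols, a and b are just for illustration).

```python
import numpy as np

x = [[0, 4, 1], [3, 2, 4]]

# One index array per axis: rows x[0], columns x[1].
dW_at = np.zeros((5, 6), int)
np.add.at(dW_at, tuple(x), 1)

dW_plus = np.zeros((5, 6), int)
dW_plus[tuple(x)] += 1
print(np.array_equal(dW_at, dW_plus))  # True: no (row, col) pair repeats

# With a repeated pair the two differ: += is buffered and writes once,
# while np.add.at adds once per occurrence of the pair.
rows, cols = [0, 1, 1], [3, 4, 4]      # pair (1, 4) appears twice
a = np.zeros((5, 6), int)
a[rows, cols] += 1
b = np.zeros((5, 6), int)
np.add.at(b, (rows, cols), 1)
print(a[1, 4], b[1, 4])                # 1 2
```

    This buffering difference is the whole reason add.at exists: in a backward pass like cs231n's, the same index can legitimately receive several gradient contributions.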
    

    The indices are interpreted the same way in both cases, including broadcasting:

    In [234]: dW[...]=0
    In [235]: np.add.at(dW,[[[1],[2]],[2,4,4]],1)
    In [236]: dW
    Out[236]: 
    array([[0, 0, 0, 0, 0, 0],
           [0, 0, 1, 0, 2, 0],
           [0, 0, 1, 0, 2, 0],
           [0, 0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0, 0]])
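    To make that broadcasting visible, here is a small sketch (rows, cols and pairs are illustrative names, and the index arrays are passed as a tuple for the same reason as before). It expands the two index arrays with np.broadcast_arrays and checks that np.add.at matches a plain loop that adds 1 once per (row, column) pair.

```python
import numpy as np

rows = np.array([[1], [2]])   # shape (2, 1)
cols = np.array([2, 4, 4])    # shape (3,)

# Broadcasting the two index arrays gives a (2, 3) grid of (row, col) pairs.
r, c = np.broadcast_arrays(rows, cols)
pairs = list(zip(r.ravel().tolist(), c.ravel().tolist()))
print(r.shape)  # (2, 3)
print(pairs)    # [(1, 2), (1, 4), (1, 4), (2, 2), (2, 4), (2, 4)]

# np.add.at adds once per pair, so the repeated (1, 4) and (2, 4) accumulate...
dW_at = np.zeros((5, 6), int)
np.add.at(dW_at, (rows, cols), 1)

# ...exactly like a plain Python loop over the same pairs.
dW_loop = np.zeros((5, 6), int)
for i, j in pairs:
    dW_loop[i, j] += 1

print(np.array_equal(dW_at, dW_loop))  # True
print(dW_at[1:3].tolist())             # [[0, 0, 1, 0, 2, 0], [0, 0, 1, 0, 2, 0]]
```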
    

    Possible values

    The values must be broadcastable to the shape defined by the (broadcast) indices:

    In [112]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)))
    ...
    In [114]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)).ravel())
    ...
    ValueError: array is not broadcastable to correct shape
    In [115]: np.add.at(dW,[[[1],[2]],[2,4,4]],[1,2,3])
    
    In [117]: np.add.at(dW,[[[1],[2]],[2,4,4]],[[1],[2]])
    
    In [118]: dW
    Out[118]: 
    array([[ 0,  0,  0,  0,  0,  0],
           [ 0,  0,  3,  0,  9,  0],
           [ 0,  0,  4,  0, 11,  0],
           [ 0,  0,  0,  0,  0,  0],
           [ 0,  0,  0,  0,  0,  0]])
    

    In this case the indices broadcast to a (2,3) shape, so values of shape (2,3), (3,), (2,1), or a scalar work; a (6,) array, such as the raveled ones above, does not.
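    One way to check that shape rule without touching dW is sketched below. It assumes NumPy 1.20+ (for np.broadcast_shapes) and models the rule as "the values must broadcast to the index shape", tested with np.broadcast_to; fits is a hypothetical helper, not part of NumPy.

```python
import numpy as np

# Shape produced by broadcasting the row index (2, 1) against the column index (3,).
index_shape = np.broadcast_shapes((2, 1), (3,))
print(index_shape)  # (2, 3)

def fits(value_shape):
    """Hypothetical helper: True if values of value_shape broadcast to index_shape."""
    try:
        np.broadcast_to(np.ones(value_shape), index_shape)
        return True
    except ValueError:
        return False

for shape in [(2, 3), (3,), (2, 1), (), (6,)]:
    print(shape, fits(shape))
# (2, 3) True / (3,) True / (2, 1) True / () True / (6,) False
```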

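    In this case, add.at maps a (2,3) array of values onto a (2,2) subarray of dW (rows 1 and 2, columns 2 and 4): column index 4 appears twice, so its two contributions per row are summed.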
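    A sketch of that (2,3) to (2,2) mapping, using illustrative values 1..6 rather than the ones from the session: np.ix_ pulls out the touched block, and a loop that sums the value columns sharing a target column rebuilds it by hand.

```python
import numpy as np

rows = np.array([[1], [2]])            # shape (2, 1)
cols = np.array([2, 4, 4])             # shape (3,); column 4 is repeated
vals = np.arange(1, 7).reshape(2, 3)   # illustrative values [[1, 2, 3], [4, 5, 6]]

dW = np.zeros((5, 6), int)
np.add.at(dW, (rows, cols), vals)

# Only rows {1, 2} x columns {2, 4} are touched: a (2, 2) block of dW.
ucols, inverse = np.unique(cols, return_inverse=True)  # [2, 4] and [0, 1, 1]
block = dW[np.ix_(rows.ravel(), ucols)]
print(block.tolist())  # [[1, 5], [4, 11]]

# Rebuild the block by hand: value column k lands in block column inverse[k].
expected = np.zeros((2, len(ucols)), int)
for k, j in enumerate(inverse):
    expected[:, j] += vals[:, k]
print(np.array_equal(block, expected))  # True
```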
