How does NumPy Sum (with axis) work?

一个人的身影 2020-12-07 09:44

I've taken it upon myself to learn how NumPy works for my own curiosity.

It seems that the simplest function is the hardest to translate to code (I un

2 answers
  •  心在旅途
    2020-12-07 10:14

    Setup

    Consider the NumPy array a:

    import numpy as np

    a = np.arange(30).reshape(2, 3, 5)
    print(a)
    
    [[[ 0  1  2  3  4]
      [ 5  6  7  8  9]
      [10 11 12 13 14]]
    
     [[15 16 17 18 19]
      [20 21 22 23 24]
      [25 26 27 28 29]]]
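
To tie the printed array to its dimensions, it can help to inspect `shape` and `ndim` directly. This is only a small sketch reusing the array from the setup; nothing in it goes beyond standard NumPy attributes.

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)

# shape lists the length of each dimension in order: dim 0, dim 1, dim 2
print(a.shape)  # (2, 3, 5)

# ndim is the number of dimensions (axes)
print(a.ndim)   # 3
```

Here dim 0 has 2 positions, dim 1 has 3, and dim 2 has 5.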
    

    Where are the dimensions?

    The dimensions and positions are highlighted in the following diagram:

                p  p  p  p  p
                o  o  o  o  o
                s  s  s  s  s
    
         dim 2  0  1  2  3  4
    
                |  |  |  |  |
      dim 0     ↓  ↓  ↓  ↓  ↓
      ----> [[[ 0  1  2  3  4]   <---- dim 1, pos 0
      pos 0   [ 5  6  7  8  9]   <---- dim 1, pos 1
              [10 11 12 13 14]]  <---- dim 1, pos 2
      dim 0
      ---->  [[15 16 17 18 19]   <---- dim 1, pos 0
      pos 1   [20 21 22 23 24]   <---- dim 1, pos 1
              [25 26 27 28 29]]] <---- dim 1, pos 2
                ↑  ↑  ↑  ↑  ↑
                |  |  |  |  |
    
         dim 2  p  p  p  p  p
                o  o  o  o  o
                s  s  s  s  s
    
                0  1  2  3  4
    

    Dimension examples:

    This becomes clearer with a few examples:

    a[0, :, :] # dim 0, pos 0
    
    [[ 0  1  2  3  4]
     [ 5  6  7  8  9]
     [10 11 12 13 14]]
    

    a[:, 1, :] # dim 1, pos 1
    
    [[ 5  6  7  8  9]
     [20 21 22 23 24]]
    

    a[:, :, 3] # dim 2, pos 3
    
    [[ 3  8 13]
     [18 23 28]]
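
One way to read the slices above: fixing a position along one dimension removes that dimension from the result. The sketch below checks the resulting shapes and shows that `np.take` with an `axis` argument selects the same slice; the variable names are illustrative only.

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)

# Fixing one index drops the corresponding dimension from the shape
print(a[0, :, :].shape)  # (3, 5)
print(a[:, 1, :].shape)  # (2, 5)
print(a[:, :, 3].shape)  # (2, 3)

# np.take picks the same slice when given the dimension as `axis`
same_slice = np.take(a, 3, axis=2)
print(np.array_equal(same_slice, a[:, :, 3]))  # True
```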
    

    Sum

    Explanation of sum and axis:
    a.sum(0) adds together all slices along dim 0, so dim 0 disappears and the result has shape (3, 5).

    a.sum(0)
    
    [[15 17 19 21 23]
     [25 27 29 31 33]
     [35 37 39 41 43]]
    

    This is the same as:

    a[0, :, :] + \
    a[1, :, :]
    
    [[15 17 19 21 23]
     [25 27 29 31 33]
     [35 37 39 41 43]]
    

    a.sum(1) is the sum of all slices along dim 1

    a.sum(1)
    
    [[15 18 21 24 27]
     [60 63 66 69 72]]
    

    This is the same as:

    a[:, 0, :] + \
    a[:, 1, :] + \
    a[:, 2, :]
    
    [[15 18 21 24 27]
     [60 63 66 69 72]]
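
For anyone translating this into plain loops, summing along dim 1 can be sketched as `out[i, k] = a[i, 0, k] + a[i, 1, k] + a[i, 2, k]`. The loop below is an illustration of that idea, not how NumPy implements `sum` internally; the names `n0`, `n1`, `n2`, and `out` are made up for the example.

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)
n0, n1, n2 = a.shape

# The result keeps dim 0 and dim 2; dim 1 is summed away
out = np.zeros((n0, n2), dtype=a.dtype)
for i in range(n0):          # position along dim 0
    for k in range(n2):      # position along dim 2
        for j in range(n1):  # walk along dim 1 and accumulate
            out[i, k] += a[i, j, k]

print(np.array_equal(out, a.sum(1)))  # True
```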
    

    a.sum(2) is the sum of all slices along dim 2

    a.sum(2)
    
    [[ 10  35  60]
     [ 85 110 135]]
    

    This is the same as:

    a[:, :, 0] + \
    a[:, :, 1] + \
    a[:, :, 2] + \
    a[:, :, 3] + \
    a[:, :, 4]
    
    [[ 10  35  60]
     [ 85 110 135]]
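
The three cases above follow one pattern, which can be written once for any axis. The helper `sum_along_axis` below is a hypothetical name for this sketch; it adds up the slices at each position along the chosen axis and is then compared against `a.sum(axis)`.

```python
import numpy as np

def sum_along_axis(arr, axis):
    """Add up the slices taken at every position along `axis` (illustrative helper)."""
    # Start from zeros shaped like a single slice along `axis`
    total = np.zeros_like(np.take(arr, 0, axis=axis))
    for pos in range(arr.shape[axis]):
        total = total + np.take(arr, pos, axis=axis)
    return total

a = np.arange(30).reshape(2, 3, 5)

for axis in range(a.ndim):
    result = sum_along_axis(a, axis)
    # The summed axis disappears from the result's shape
    print(axis, result.shape, np.array_equal(result, a.sum(axis)))
# 0 (3, 5) True
# 1 (2, 5) True
# 2 (2, 3) True
```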
    

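    The default axis is None.
    This means all axes: every number in the array is summed into a single value. (axis=-1 would instead sum along the last dimension, here dim 2.)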

    a.sum()
    
    435
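
A quick check of how calling `sum` with no axis (`axis=None`) differs from `axis=-1`, along with two related options: a tuple of axes and `keepdims`. This is a small sketch on the same array; the variable names are illustrative.

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)

total = a.sum()                      # axis=None: sum every element
last = a.sum(axis=-1)                # -1 is the last axis, same as a.sum(2)
both = a.sum(axis=(0, 1))            # sum over dim 0 and dim 1 together
kept = a.sum(axis=1, keepdims=True)  # keep the summed axis with length 1

print(total)       # 435
print(last.shape)  # (2, 3)
print(both)        # [75 81 87 93 99]
print(kept.shape)  # (2, 1, 5)
```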
    
