What does the gather function do in pytorch in layman terms?

情深已故 2020-12-07 10:59

I have been through the official doc and this but it is hard to understand what is going on.

I am trying to understand a DQN source code and it uses the gather funct

3 Answers
  •  忘掉有多难
    2020-12-07 11:38

    torch.gather creates a new tensor from the input tensor by collecting values along the dimension dim. The values in the torch.LongTensor passed as index specify which element to take from each "row" (each slice along dim). The output tensor has the same shape as the index tensor. The following illustration from the official docs explains it more clearly:

    (Note: In the illustration, indexing starts from 1 and not 0).
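    As a rough mental model, the rule above can be written as a short pure-Python function for 2-D inputs. This is a sketch, not PyTorch's implementation: the name gather_2d is hypothetical, nested lists stand in for tensors, indices are 0-based (as in PyTorch itself), and only dim=0 and dim=1 are handled.

```python
# Hypothetical pure-Python sketch of torch.gather for 2-D inputs (not PyTorch's code).
# For dim=0: out[i][j] = src[index[i][j]][j]   (index picks the row)
# For dim=1: out[i][j] = src[i][index[i][j]]   (index picks the column)

def gather_2d(src, dim, index):
    """Return a nested list shaped like `index`, mimicking torch.gather(src, dim, index)."""
    if dim not in (0, 1):
        raise ValueError("this sketch only handles dim=0 or dim=1")
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            # k names the source row (dim=0) or the source column (dim=1)
            out_row.append(src[k][j] if dim == 0 else src[i][k])
        out.append(out_row)
    return out
```

    Note how the loops run over index, not over src: that is why the result always has the shape of index.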

    In the first example, dim is along rows (top to bottom), so for position (1,1) of the result, the index tensor gives the source row to read, which is 1. The value of src at (1,1) is 1, so the result has 1 at (1,1). Similarly, for position (2,2) the index gives row 3; the value of src at (3,2) is 8, so the result has 8 at (2,2), and so on.
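    The row-wise case can be reproduced in a few lines of plain Python. The src and index values below are hypothetical (they are not the data from the illustration); they were only chosen so that, translated to 0-based indices, out[0][0] reads src[0][0] = 1 and out[1][1] reads src[2][1] = 8, matching the two positions walked through above.

```python
# Hypothetical 3x3 source and 2x3 index (0-based), as nested lists.
src = [[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]]
index = [[0, 1, 2],
         [2, 2, 0]]

# Gather along dim=0: each index entry names the source ROW; the column stays j.
out = [[src[index[i][j]][j] for j in range(len(index[i]))]
       for i in range(len(index))]

print(out)  # [[1, 5, 9], [7, 8, 3]]
# With PyTorch installed, torch.gather(torch.tensor(src), 0, torch.tensor(index))
# should produce the same values as a tensor.
```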

    Similarly, in the second example the indexing is along columns, so for position (2,2) of the result, the index gives column 3. The value of src at (2,3) is 6, so 6 is written to the result at (2,2).
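    The column-wise case works the same way with the roles of i and j swapped. Again the src and index values are hypothetical, picked so that (0-based) out[1][1] reads src[1][2] = 6, mirroring the (2,2) to (2,3) step described above.

```python
# Hypothetical 3x3 source and 3x2 index (0-based), as nested lists.
src = [[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]]
index = [[0, 0],
         [1, 2],
         [2, 0]]

# Gather along dim=1: each index entry names the source COLUMN; the row stays i.
out = [[src[i][index[i][j]] for j in range(len(index[i]))]
       for i in range(len(index))]

print(out)  # [[1, 1], [5, 6], [9, 7]]
```

    Picking one column per row like this is, as far as I know, how DQN code typically selects the Q-value of the action actually taken, e.g. something like q_values.gather(1, actions.unsqueeze(1)), where q_values has shape (batch, num_actions) and actions holds one int64 action per row (the variable names here are illustrative).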
