b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)
结果
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
1 2
6 4
[torch.FloatTensor of size 2x2]
1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
有一点需要注意是,dim=1是横向看,dim=也是横向看
dim=0时,是纵向看,index是指横向纵向看的索引
来源:CSDN
作者:weixin_36411839
链接:https://blog.csdn.net/weixin_36411839/article/details/103609811