pytorch中的gather函数
b = torch.Tensor([[1,2,3],[4,5,6]]) print b index_1 = torch.LongTensor([[0,1],[2,0]]) index_2 = torch.LongTensor([[0,1,1],[0,0,0]]) print torch.gather(b, dim=1, index=index_1) print torch.gather(b, dim=0, index=index_2) 结果 1 2 3 4 5 6 [torch.FloatTensor of size 2x3] 1 2 6 4 [torch.FloatTensor of size 2x2] 1 5 6 1 2 3 [torch.FloatTensor of size 2x3] 有一点需要注意是,dim=1是横向看,dim=也是横向看 dim=0时,是纵向看,index是指横向纵向看的索引 来源: CSDN 作者: weixin_36411839 链接: https://blog.csdn.net/weixin_36411839/article/details/103609811