Indexing a multi-dimensional tensor with a tensor in PyTorch

旧时模样 提交于 2019-12-05 02:57:15

A more elegant (and simpler) solution might be to simply cast b as a tuple:

a[tuple(b)]
Out[10]: tensor(5.)

I was curious to see how this works with "regular" numpy, and found a related article explaining this quite well here.

You can split b into 4 using chunk, and then use the chunked b to index the specific element you want:

>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)]   # here's the trick!
Out[24]: tensor([[40, 80,  0]])

What's nice about it is that it can be easily generalized to any dimension of a, you just need to make number of chucks equal the dimension of a.

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!