Indexing a multi-dimensional tensor with a tensor in PyTorch

Submitted by 梦想与她 on 2019-12-10 03:38:02

Question


I have the following code:

a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])

I have a multi-dimensional index b and want to use it to select a single cell in a. If b weren't a tensor, I could do:

a[1,1,1,1]

Which returns the correct cell, but:

a[b]

Doesn't work, because it indexes only the first dimension and so just selects a[1] four times.
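
To make the difference concrete, here is a minimal sketch comparing the two forms. It assumes a reasonably recent PyTorch, and it swaps torch.randint for torch.arange (my substitution, not part of the question) so that the values are predictable:

```python
import torch

# Assumption: arange instead of randint, so each cell holds its own flat offset.
a = torch.arange(3 * 3 * 3 * 3).view(3, 3, 3, 3)
b = torch.LongTensor([1, 1, 1, 1])

# A single LongTensor index is applied to the first dimension only,
# so a[b] stacks a[1] four times.
rows = a[b]
print(rows.shape)                   # torch.Size([4, 3, 3, 3])
print(torch.equal(rows[0], a[1]))   # True

# Four separate integer indices address one cell.
cell = a[1, 1, 1, 1]
print(cell)                         # tensor(40)
```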

How can I do this? Thanks


Answer 1:


A more elegant (and simpler) solution might be to simply convert b to a tuple:

a[tuple(b)]
Out[10]: tensor(5.)

I was curious to see how this works with "regular" numpy, and found a related article explaining this quite well here.
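
As a rough illustration of why the tuple form works: iterating over a 1-D tensor yields one 0-dim tensor per element, and a tuple of indices is interpreted as one index per dimension. The sketch below uses torch.arange and a non-uniform index of my own choosing (both are assumptions for illustration) so that a wrong dimension order would be visible:

```python
import torch

# Assumption: arange values and the index [1, 2, 0, 1] are illustrative only.
a = torch.arange(3 * 3 * 3 * 3).view(3, 3, 3, 3)
b = torch.LongTensor([1, 2, 0, 1])

# tuple(b) gives four 0-dim tensors, one per dimension of a.
idx = tuple(b)
cell = a[idx]
print(cell)                          # tensor(46)

# Converting to plain Python ints first gives the same cell.
cell_from_ints = a[tuple(b.tolist())]
print(torch.equal(cell, cell_from_ints))   # True
```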




Answer 2:


You can split b into 4 using chunk, and then use the chunked b to index the specific element you want:

>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)]   # here's the trick!
Out[24]: tensor([[40, 80,  0]])

What's nice about it is that it generalizes easily to any number of dimensions of a: you just need to make the number of chunks equal to the number of dimensions of a.
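
One way to sketch that generalization is a small helper that takes the number of chunks from a.dim(). The name select_cells and the [N, a.dim()] layout of the index tensor are hypothetical choices made for this example, not something from the answer:

```python
import torch

def select_cells(a, idx):
    """Hypothetical helper: idx is a LongTensor of shape [N, a.dim()],
    one row of coordinates per cell to pick from a."""
    # Transpose to [a.dim(), N], then split into a.dim() tensors of shape [1, N].
    parts = idx.t().chunk(chunks=a.dim(), dim=0)
    # Indexing with a tuple of tensors picks one cell per column; drop the extra dim.
    return a[parts].squeeze(0)

# Assumption: a 3-D tensor, to show the helper is not tied to 4 dimensions.
a = torch.arange(2 * 3 * 4).view(2, 3, 4)
idx = torch.LongTensor([[1, 2, 3], [0, 0, 0], [1, 0, 2]])

picked = select_cells(a, idx)
print(picked)                        # tensor([23,  0, 14])

# unbind along dim 1 builds the same per-dimension index tensors directly.
picked_unbind = a[idx.unbind(dim=1)]
print(torch.equal(picked, picked_unbind))   # True
```

The unbind variant avoids the transpose and the trailing squeeze, since each index tensor already has shape [N].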



Source: https://stackoverflow.com/questions/52092230/indexing-a-multi-dimensional-tensor-with-a-tensor-in-pytorch
