PyTorch get indices of value in two-dimensional tensor

可紊 提交于 2021-02-10 05:29:05

问题


Given the following tensor (or any random tensor with two dimension), I want to get the index of '101':

tens = tensor([[  101,   146,  1176, 21806,  1116,  1105, 18621,   119,   102,     0,
             0,     0,     0],
        [  101,  1192,  1132,  1136,  1184,   146,  1354,  1128,  1127,   117,
          1463,   119,   102],
        [  101,  6816,  1905,  1132, 14918,   119,   102,     0,     0,     0,
             0,     0,     0]])

From the related answers I know that I can do something like this:

idxs = torch.tensor([(i == 101).nonzero() for i in tens])

But this seems messy and potentially quite slow. Is there a better way to do this that is fast and more torch-y?

Related questions discussing only one-dimensional tensor:

  • How Pytorch Tensor get the index of specific value
  • How Pytorch Tensor get the index of elements?

回答1:


How about (tens == 101).nonzero()[:, 1]

In [20]: from torch import tensor                                                                       

In [21]: tens = torch.tensor([[  101,   146,  1176, 21806,  1116,  1105, 18621,   119,   102,     0, 
    ...:              0,     0,     0], 
    ...:         [  101,  1192,  1132,  1136,  1184,   146,  1354,  1128,  1127,   117, 
    ...:           1463,   119,   102], 
    ...:         [  101,  6816,  1905,  1132, 14918,   119,   102,     0,     0,     0, 
    ...:              0,     0,     0]])                                                                

In [22]: (tens == 101).nonzero()[:, 1]                                                                  
Out[22]: tensor([0, 0, 0])


来源:https://stackoverflow.com/questions/59908433/pytorch-get-indices-of-value-in-two-dimensional-tensor

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!