With a Python list, we can use list.index(somevalue). How can PyTorch do this?
For example:
a=[1,2,3]
print(a.index(2))
To find the index of an element in a 1-D tensor/array, for example:
import torch

mat = torch.tensor([1, 8, 5, 3])
# to find the index of 5
five = 5
numb_of_col = len(mat)  # number of elements in mat (4)
for o in range(numb_of_col):
    if mat[o] == five:
        print(torch.tensor([o]))  # prints tensor([2])
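The loop above can also be written without a Python-level loop. Below is a minimal sketch, assuming you want the same behaviour as list.index (first match only, ValueError if the value is missing). The helper name tensor_index is hypothetical; it relies on comparing the tensor to the value and calling nonzero(as_tuple=True) on the resulting boolean mask.

```python
import torch

def tensor_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper mimicking list.index for a 1-D tensor:
    returns the first position equal to `value`, raises ValueError if absent."""
    # Boolean mask -> tensor of positions where t == value
    matches = (t == value).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        raise ValueError(f"{value} is not in tensor")
    return int(matches[0])  # first occurrence, like list.index

mat = torch.tensor([1, 8, 5, 3])
print(tensor_index(mat, 5))  # 2
```

Dropping the [0] on matches would give all positions of the value instead of only the first.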
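To find the index of an element in a 2-D/3-D tensor, convert it into 1-D first, i.e. your_tensor.view(number_of_elements). The index you get back then refers to the position in the flattened tensor.
For example: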
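mat = torch.tensor([[1, 2], [4, 3]])
# to find the index of 2
target = 2
mat = mat.view(4)  # flatten the 2x2 tensor into 4 elements
numb_of_col = 4
for o in range(numb_of_col):
    if mat[o] == target:
        print(torch.tensor([o]))  # prints tensor([1]), the index in the flattened tensor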
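If you need the original (row, column) position rather than the flat index, one option is to convert the flat index back with divmod over the number of columns; another is to call nonzero on the 2-D mask directly. This is a sketch for the 2-D case only. It uses flatten() instead of view(), because view() raises an error for tensors whose memory layout is incompatible (for example a transposed tensor), while flatten() copies when necessary.

```python
import torch

mat = torch.tensor([[1, 2], [4, 3]])
target = 2

# Option 1: flatten, find the first flat index, then map it back to (row, col).
flat = mat.flatten()
flat_idx = int((flat == target).nonzero(as_tuple=True)[0][0])
row, col = divmod(flat_idx, mat.shape[1])  # shape[1] = number of columns
print(flat_idx, (row, col))  # 1 (0, 1)

# Option 2: nonzero on the 2-D mask returns one [row, col] pair per match.
coords = (mat == target).nonzero()
print(coords)  # tensor([[0, 1]])
```

Option 2 generalises to 3-D tensors as well: each row of the result holds one index per dimension.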