With a Python list, we can use list.index(somevalue) to find the position of a value. How can this be done with a PyTorch tensor? For example:
    a = [1, 2, 3]
    print(a.index(2))  # 1
This can be done by converting the tensor to a NumPy array, as follows:
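    import numpy as np
    import torch

    # torch.range is deprecated; torch.arange excludes the end value
    x = torch.arange(1., 5.)
    print(x)  # tensor([1., 2., 3., 4.])
    nx = x.numpy()  # shares memory with x (CPU tensors only)
    print(np.where(nx == 3)[0][0])  # 2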
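If you would rather avoid the NumPy round trip, something like the following might work as a rough sketch. It compares the tensor against the value and takes the first position returned by `nonzero`. Note that `tensor_index` is a hypothetical helper name of my own, not part of PyTorch, and the sketch assumes a 1-D tensor.

```python
import torch


def tensor_index(t, value):
    """Return the index of the first element of a 1-D tensor equal to value.

    Hypothetical helper mimicking list.index; not a PyTorch function.
    """
    # (t == value) is a boolean mask; nonzero(as_tuple=True) returns one
    # index tensor per dimension, so [0] holds the matching positions.
    matches = (t == value).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        # Mirror list.index, which raises ValueError when nothing matches.
        raise ValueError(f"{value} is not in tensor")
    return matches[0].item()


x = torch.arange(1., 5.)
idx = tensor_index(x, 3)
print(idx)
```

Like list.index, this returns only the first match and raises ValueError when the value is absent. Since the comparison is exact, float tensors holding computed values may need a tolerance-based check such as torch.isclose instead.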