How can two rows be swapped in a torch tensor?

Submitted by 南笙酒味 on 2021-02-16 20:22:07

Question


import torch

var = [[0, 1, -4, 8],
       [2, -3, 2, 1],
       [5, -8, 7, 1]]

var = torch.Tensor(var)

Here, var is a 3 x 4 (2D) tensor. How can the first and second rows be swapped to get the following 2D tensor?

2, -3, 2, 1 
0, 1, -4, 8
5, -8, 7, 1

Answer 1:


The other answer does not work, because some rows get overwritten before they are copied:

>>> var = [[0, 1, -4, 8],
...        [2, -3, 2, 1],
...        [5, -8, 7, 1]]
>>> x = torch.tensor(var)
>>> index = torch.LongTensor([1, 0, 2])
>>> x[index] = x
>>> x
tensor([[ 0,  1, -4,  8],
        [ 0,  1, -4,  8],
        [ 5, -8,  7,  1]])

For me, it suffices to create a new tensor (with separate underlying storage) to hold the result:

>>> x = torch.tensor(var)
>>> index = torch.LongTensor([1, 0, 2])
>>> y = torch.zeros_like(x)
>>> y[index] = x

Alternatively, you can use index_copy_ (https://pytorch.org/docs/stable/tensors.html#torch.Tensor.index_copy_), following the explanation on discuss.pytorch.org, although I don't currently see an advantage of one approach over the other.
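As a minimal sketch of the index_copy_ route mentioned above (the variable names x, index, and y are carried over from this answer; writing into a fresh zeros_like tensor is an assumption made here to avoid aliasing), row i of x is copied into row index[i] of y:

```python
import torch

x = torch.tensor([[0, 1, -4, 8],
                  [2, -3, 2, 1],
                  [5, -8, 7, 1]])
index = torch.tensor([1, 0, 2])

# Start from a separate tensor so no row of x is overwritten while copying.
y = torch.zeros_like(x)

# Along dim 0, copy row i of x into row index[i] of y (in place on y).
y.index_copy_(0, index, x)

print(y)
```

Since y has its own storage, x is left unchanged and can still be used afterwards.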




Answer 2:


Generate the permutation index you desire:

index = torch.LongTensor([1,0,2])

Apply the permutation:

var[index] = var
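Because the assignment above reads from and writes to the same tensor, Answer 1 reports that rows can be overwritten before they are copied. A hedged sketch that sidesteps this is to index on the right-hand side instead, since indexing with an index tensor returns a new tensor. The name swapped is introduced here only for illustration:

```python
import torch

var = torch.Tensor([[0, 1, -4, 8],
                    [2, -3, 2, 1],
                    [5, -8, 7, 1]])
index = torch.LongTensor([1, 0, 2])

# Indexing with a tensor gathers rows into a new tensor with its own storage:
# swapped[i] = var[index[i]].
swapped = var[index]

# Optionally write the result back into the original tensor.
var.copy_(swapped)

print(var)
```

Note that this gathers rows (row i of the result is row index[i] of the input), whereas var[index] = ... scatters them. For a plain swap such as [1, 0, 2] the two give the same result.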



Answer 3:


The other answers suggest that the permutation index should itself be a tensor, but this is not necessary. You can swap the first and second rows like this:

>>> var
tensor([[ 0,  1, -4,  8],
        [ 2, -3,  2,  1],
        [ 5, -8,  7,  1]])

>>> var[[0, 1]] = var[[1, 0]]

>>> var
tensor([[ 2, -3,  2,  1],
        [ 0,  1, -4,  8],
        [ 5, -8,  7,  1]])

Here, var can be a NumPy array or a PyTorch tensor.
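To illustrate that the same statement works for both types, here is a small sketch. The helper name swap_rows is hypothetical and not part of either library; it simply wraps the list-index assignment shown above:

```python
import numpy as np
import torch


def swap_rows(a, i, j):
    """Swap rows i and j of a in place (hypothetical helper)."""
    # a[[j, i]] is evaluated first and yields a copy of the two rows,
    # so the assignment does not read rows that were already overwritten.
    a[[i, j]] = a[[j, i]]
    return a


data = [[0, 1, -4, 8],
        [2, -3, 2, 1],
        [5, -8, 7, 1]]

t = swap_rows(torch.tensor(data), 0, 1)  # PyTorch tensor
n = swap_rows(np.array(data), 0, 1)      # NumPy array

print(t)
print(n)
```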



Source: https://stackoverflow.com/questions/44935176/how-two-rows-can-be-swapped-in-a-torch-tensor
