Question
I have a list of indices:
indx = torch.LongTensor([
    [ 0,  2,  0],
    [ 0,  2,  4],
    [ 0,  4,  0],
    [ 0, 10, 14],
    [ 1,  4,  0],
    [ 1,  8,  2],
    [ 1, 12,  0]
])
And I have a tensor of 2x2 blocks:
blocks = torch.FloatTensor([
    [[1.5818, 2.3108],
     [2.6742, 3.0024]],
    [[2.0472, 1.6651],
     [3.2807, 2.7413]],
    [[1.5587, 2.1905],
     [1.9231, 3.5083]],
    [[1.6007, 2.1426],
     [2.4802, 3.0610]],
    [[1.9087, 2.1021],
     [2.7781, 3.2282]],
    [[1.5127, 2.6322],
     [2.4233, 3.6836]],
    [[1.9645, 2.3831],
     [2.8675, 3.3770]]
])
What I want to do is add each block to another tensor at its index position, i.e. so that the block's top-left corner starts at that (batch, row, column) index. Let's assume I want to add them to the following tensor:
a = torch.ones([2,18,18])
Is there an efficient way to do this? So far I have only come up with:
for i, (b, x, y) in enumerate(indx):
    a[b, x:x+2, y:y+2] += blocks[i]
This is quite inefficient. I also tried index_add, but it did not work properly.
Answer 1:
You are looking to index on three different dimensions at the same time. I had a look around in the documentation; torch.index_add will only receive a vector as index. My hopes were on torch.scatter, but it doesn't seem to fit this problem well. As it turns out, you can achieve this pretty easily with a little work; the most difficult parts are the setup and teardown. Please hang on tight.
I'll use a simplified example here, but the same can be applied with larger tensors.
>>> indx 
tensor([[ 0,  2,  0],
        [ 0,  2,  4],
        [ 0,  4,  0]])
>>> blocks
tensor([[[1.5818, 2.3108],
         [2.6742, 3.0024]],
        [[2.0472, 1.6651],
         [3.2807, 2.7413]],
        [[1.5587, 2.1905],
         [1.9231, 3.5083]]])
>>> a
tensor([[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]])
The main issue here is that you are looking to index with slicing, which is not possible in vectorized form. To work around it, you can convert your a tensor into 2x2 chunks. This will be particularly handy, since we will be able to access a sub-tensor such as a[0, 2:4, 4:6] with just a[0, 1, 2]: the 2:4 slice on dim=1 is grouped together at index=1, while the 4:6 slice on dim=2 is grouped at index=2.
First we will convert a to a tensor made up of 2x2 chunks. Then we will update it with blocks. Finally, we will stitch the resulting tensor back into the original shape.
1. Converting a to a 2x2-chunks tensor
You can use a combination of torch.chunk and torch.cat (not torch.dog), applied twice: first along the rows, then along the columns. The shape of a is (1, h, w), so we're looking for a result of shape (1, h//2, w//2, 2, 2).
To do so we will unsqueeze two axes on a:
>>> a_ = a[:, None, :, None, :]
>>> a_.shape
torch.Size([1, 1, 6, 1, 6])
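(For the record, the same unsqueezing can be written with two explicit unsqueeze calls; a quick equivalence check:)
>>> torch.equal(a_, a.unsqueeze(1).unsqueeze(3))
True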
Then make 3 chunks on dim=2, then concatenate on dim=1:
>>> a_row_chunks = torch.cat(torch.chunk(a_, 3, dim=2), dim=1)
>>> a_row_chunks.shape
torch.Size([1, 3, 2, 1, 6])
And make 3 chunks on dim=4, then concatenate on dim=3:
>>> a_col_chunks = torch.cat(torch.chunk(a_row_chunks, 3, dim=4), dim=3)
>>> a_col_chunks.shape
torch.Size([1, 3, 2, 3, 2])
Finally, swap the two middle axes so the chunk indices come before the within-chunk indices. (Careful: a plain reshape to (1, 3, 3, 2, 2) would silently scramble the values here, since the row-within-chunk axis has to be moved past the column-chunk axis rather than merely relabeled; with an all-zeros a the mistake is invisible, but it corrupts a non-uniform one.)
>>> a_chunks = a_col_chunks.permute(0, 1, 3, 2, 4)
>>> a_chunks.shape
torch.Size([1, 3, 3, 2, 2])
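As a sanity check (this shortcut is not in the original answer): the same chunked layout can be built in one step with a reshape followed by a permute, which you may find easier to reason about:
>>> torch.equal(a.reshape(1, 3, 2, 3, 2).permute(0, 1, 3, 2, 4), a_chunks)
True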
Create a new index with adjusted values for our new chunked tensor. Essentially, all values are divided by 2, except for the first column, which is the index of dim=0 in a and stays unchanged. There's some fiddling around with the types (in short: a float tensor is used in order to divide by 2, then cast back to long so the indexing works):
>>> indx_ = indx.clone().float()
>>> indx_[:, 1:] /= 2
>>> indx_ = indx_.long()
>>> indx_
tensor([[0, 1, 0],
        [0, 1, 2],
        [0, 2, 0]])
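If you'd rather skip the float round-trip, integer floor division gives the same result while keeping the long dtype; a small equivalent sketch (on PyTorch 1.8+ this can also be spelled torch.div(..., rounding_mode='floor')):
>>> indx_ = indx.clone()
>>> indx_[:, 1:] //= 2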
2. Updating with blocks
We will simply index and accumulate with:
>>> a_chunks[indx_[:, 0], indx_[:, 1], indx_[:, 2]] += blocks
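One caveat worth flagging (it doesn't arise in this example, but may in your data): if indx_ contains the same target position twice, the += above applies only one of the duplicate blocks, because advanced-indexing assignment does not accumulate. index_put_ with accumulate=True handles duplicates correctly:
>>> a_chunks.index_put_((indx_[:, 0], indx_[:, 1], indx_[:, 2]), blocks, accumulate=True)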
3. Putting it back together
I thought that was it, but actually converting a_chunks back to a 6x6 tensor is way trickier than it seems. Apparently torch.cat needs a sequence of tensors, and calling tuple() on a tensor only splits it along the first axis. I won't go into too much detail: as a workaround, you can use permute to move the axis you want to split into first position. This, combined with two torch.cat calls, will do:
>>> a_row_cat = torch.cat(tuple(a_chunks.permute(1, 0, 2, 3, 4)), dim=2)
>>> a_row_cat.shape
torch.Size([1, 3, 6, 2])
>>> A = torch.cat(tuple(a_row_cat.permute(1, 0, 2, 3)), dim=2)
>>> A.shape
torch.Size([1, 6, 6])
>>> A
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.5818, 2.3108, 0.0000, 0.0000, 2.0472, 1.6651],
         [2.6742, 3.0024, 0.0000, 0.0000, 3.2807, 2.7413],
         [1.5587, 2.1905, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.9231, 3.5083, 0.0000, 0.0000, 0.0000, 0.0000]]])
Et voilà.
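Incidentally, the teardown can also be collapsed into a single permute plus reshape: moving each within-chunk axis back next to its chunk axis and then flattening restores the original layout. A minimal equivalence check (not in the original answer):
>>> torch.equal(a_chunks.permute(0, 1, 3, 2, 4).reshape(1, 6, 6), A)
True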
If you didn't quite get how the chunks work, run this:
for x in range(0, 6, 2):
    for y in range(0, 6, 2):
        a *= 0
        a[:, x:x+2, y:y+2] = 1
        print(a)
And see for yourself: each 2x2 block of 1s corresponds to a chunk in a_chunks.
So you can do the same with:
for x in range(3):
    for y in range(3):
        a_chunks *= 0
        a_chunks[:, x, y] = 1
        print(a_chunks)
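Finally, a sketch that is not part of the answer above, for PyTorch versions with Tensor.unfold: since unfold returns a view into a, the setup and teardown can be skipped entirely and the blocks written straight through the view. This assumes the 2x2 windows never overlap (the question's even-valued indices guarantee that) and that no two rows of indx target the same position (otherwise use index_put_ with accumulate=True as shown earlier):
import torch

# `indx` and `blocks` as defined in the question
a = torch.ones([2, 18, 18])

# a (2, 9, 9, 2, 2) view onto the non-overlapping 2x2 windows of `a`
windows = a.unfold(1, 2, 2).unfold(2, 2, 2)

# adding into the view writes through to `a` itself;
# row/column indices are halved to address whole windows
windows[indx[:, 0], indx[:, 1] // 2, indx[:, 2] // 2] += blocks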
Source: https://stackoverflow.com/questions/65571114/add-blocks-of-values-to-a-tensor-at-specific-locations-in-pytorch