Suppose I have a `torch.Tensor` `t` of shape `(10, 20, 20, 10)`. I want to index along the first and last dimensions only.
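"Indexing along the first and last dimensions" can mean a few different things in PyTorch. Below is a minimal sketch of three common readings. The index tensors `first_idx` and `last_idx` and the slice bounds are hypothetical values chosen for illustration. The paired case follows the NumPy-style advanced-indexing rule: when two advanced indices are separated by slices, their broadcast dimension moves to the front of the result.

```python
import torch

# Example tensor with the shape from the question.
t = torch.randn(10, 20, 20, 10)

# Hypothetical index tensors for dim 0 and dim 3.
first_idx = torch.tensor([0, 3, 7])
last_idx = torch.tensor([1, 2, 9])

# 1) Basic slicing on both ends keeps all four dimensions.
sliced = t[2:5, :, :, 4:8]        # shape (3, 20, 20, 4)
sliced_e = t[2:5, ..., 4:8]       # same result, using an Ellipsis for the middle dims

# 2) Outer (cross-product) selection: every chosen first index
#    is combined with every chosen last index.
outer = t.index_select(0, first_idx).index_select(3, last_idx)  # shape (3, 20, 20, 3)

# 3) Paired (element-wise) selection: the two index tensors are broadcast
#    together, so result[k] == t[first_idx[k], :, :, last_idx[k]].
#    Because the advanced indices are separated by slices, the broadcast
#    dimension of size 3 ends up first.
paired = t[first_idx, :, :, last_idx]  # shape (3, 20, 20)

print(sliced.shape, outer.shape, paired.shape)
```

Which form is right depends on whether you want a rectangular sub-block (slicing or `index_select`) or one entry per index pair (advanced indexing).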
In this case, since I k