Basic tensor dimension transformations in PyTorch

Submitted by 旧巷老猫 on 2020-02-08 04:44:58

Let's learn some of the common tensor dimension transformations directly from code:

import torch

torch.manual_seed(2020)

x = torch.rand(1, 2, 3)
print(x)
# tensor([[[0.4869, 0.1052, 0.5883],
#          [0.1161, 0.4949, 0.2824]]])

print(x.view(-1, 3).size())      # torch.Size([2, 3])
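# view returns a reshaped view of the same data; a -1 entry is inferred from the
# element count. A small sketch of this, added here (not part of the original post):
print(x.view(2, -1).size())  # torch.Size([2, 3]), the -1 is inferred as 3
print(x.view(-1).size())     # torch.Size([6]), flatten; the total element count must stay the same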

print('\ntranspose:')
print(torch.transpose(x, 0, 1))
print(x.transpose(0, 1).size())  # torch.Size([2, 1, 3])
print(x.transpose(1, 2).size())  # torch.Size([1, 3, 2])
# transpose requires explicitly specifying the two dimensions to swap
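# Two related points, sketched here as an addition to the post: transpose returns a
# non-contiguous view, so .contiguous() (or .reshape) is needed before calling .view
# on the result, and permute reorders several dimensions at once.
xt = x.transpose(1, 2)                     # torch.Size([1, 3, 2]), a view of x
print(xt.is_contiguous())                  # False
print(xt.contiguous().view(-1, 2).size())  # torch.Size([3, 2]); calling xt.view(...) directly would fail
print(x.permute(2, 0, 1).size())           # torch.Size([3, 1, 2]), reorder all dimensions at once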

print('\ntorch.cat:')
y = torch.rand(1, 1, 3)
print(torch.cat((x, y), dim=1).size())  # torch.Size([1, 3, 3])
# dim specifies the dimension to concatenate along; apart from that dimension, the shapes
# of the tensors being concatenated must match (or a tensor may be empty)
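# A couple of extra cat cases, added here for illustration: concatenation works along
# any dimension, but the non-concatenated dimensions have to agree.
print(torch.cat((x, x), dim=0).size())  # torch.Size([2, 2, 3])
print(torch.cat((x, x), dim=2).size())  # torch.Size([1, 2, 6])
# torch.cat((x, y), dim=0) would raise an error: x and y differ in dimension 1 (2 vs 1)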

print('\ntorch.chunk:')
x_chunks = torch.chunk(x, chunks=2, dim=1)  # x_chunks is a tuple
print(x_chunks)
# (tensor([[[0.4869, 0.1052, 0.5883]]]), tensor([[[0.1161, 0.4949, 0.2824]]]))
print(x_chunks[0].size(), x_chunks[1].size())
# torch.Size([1, 1, 3]) torch.Size([1, 1, 3])
print(torch.chunk(x, 2, 2))  # when the size is not evenly divisible, the last chunk is smaller
# (tensor([[[0.4869, 0.1052], [0.1161, 0.4949]]]), 
#  tensor([[[0.5883], [0.2824]]]))
print(torch.chunk(x, 4, 2))  # when chunks exceeds the size of dim, each returned chunk has size 1 along that dimension
# (tensor([[[0.4869], [0.1161]]]), 
#  tensor([[[0.1052], [0.4949]]]), 
#  tensor([[[0.5883], [0.2824]]]))
# torch.chunk splits the tensor into (up to) chunks pieces along dimension dim
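# Two further details, sketched here as an addition to the post: chunk can return fewer
# pieces than requested, and (per the PyTorch docs) each piece is a view sharing storage with the input.
print(len(torch.chunk(x, 4, 2)))  # 3, fewer than the 4 chunks requested
x_chunks[0][0, 0, 0] = 0.0        # chunks are views, so this write also changes x
print(x[0, 0, 0])                 # tensor(0.)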

print('\ntorch.split:')
z = torch.rand(4, 6, 8)
z_split = torch.split(z, split_size_or_sections=2, dim=1)
print([z_split[i].size() for i in range(len(z_split))])
# [torch.Size([4, 2, 8]), torch.Size([4, 2, 8]), torch.Size([4, 2, 8])]
z_split = torch.split(z, split_size_or_sections=4, dim=1)
print([z_split[i].size() for i in range(len(z_split))])
# [torch.Size([4, 4, 8]), torch.Size([4, 2, 8])]
z_split = torch.split(z, split_size_or_sections=[3, 3], dim=1)
print([z_split[i].size() for i in range(len(z_split))])
# [torch.Size([4, 3, 8]), torch.Size([4, 3, 8])]
# torch.split also splits a tensor into pieces along the given dimension; the difference from torch.chunk:
# torch.chunk specifies how many chunks to produce, while torch.split specifies the size of each piece
# both torch.chunk and torch.split can be viewed as the inverse of torch.cat
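# To make the "inverse of torch.cat" remark concrete, a small check (added here, not in the
# original post): concatenating the split pieces along the same dimension recovers the original tensor.
print(torch.equal(torch.cat(z_split, dim=1), z))  # True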

print('\ntorch.stack:')
a = torch.rand(2, 3, 4)
b = torch.rand(2, 3, 4)
print(torch.stack((a, b), dim=0).size())  # torch.Size([2, 2, 3, 4])
c = torch.rand(2, 3, 4)
print(torch.stack((a, b, c), dim=0).size())  # torch.Size([3, 2, 3, 4])
# difference between torch.stack and torch.cat: the former concatenates along a new dimension; the latter along an existing one
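# As an added illustration of that difference, torch.stack((a, b), dim=0) gives the same
# result as unsqueezing each tensor and then concatenating:
s1 = torch.stack((a, b), dim=0)
s2 = torch.cat((a.unsqueeze(0), b.unsqueeze(0)), dim=0)
print(torch.equal(s1, s2))  # True
# note that torch.stack requires all input tensors to have exactly the same shape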

print('\ntorch.squeeze:')
d = torch.rand(1, 2, 3, 1)
print(torch.squeeze(d).size())  # torch.Size([2, 3])
print(torch.squeeze(d, dim=3).size())    # torch.Size([1, 2, 3])
# torch.squeeze removes dimensions of size 1; dim defaults to None, which removes every size-1 dimension;
# with dim given, only that dimension is removed, and only if its size is 1; otherwise the call has no effect
print(torch.unsqueeze(d, dim=0).size())  # torch.Size([1, 1, 2, 3, 1])
print(torch.unsqueeze(d, dim=-1).size()) # torch.Size([1, 2, 3, 1, 1])
# torch.unsqueeze inserts a dimension of size 1 at the specified position
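# A short added sketch of both behaviours (the tensor img is made up for illustration):
# squeezing a dimension whose size is not 1 is a no-op, and unsqueeze is commonly used to add a batch dimension.
print(torch.squeeze(d, dim=1).size())  # torch.Size([1, 2, 3, 1]), unchanged: dim 1 has size 2
img = torch.rand(3, 32, 32)            # hypothetical single image
print(img.unsqueeze(0).size())         # torch.Size([1, 3, 32, 32]), batch dimension added in front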