PyTorch框架学习三——张量操作
- 一、拼接
- 1.torch.cat()
- 2.torch.stack()
- 二、切分
- 1.torch.chunk()
- 2.torch.split()
- 三、索引
- 1.torch.index_select()
- 2.torch.masked_select()
- 四、变换
- 1.torch.reshape()
- 2.torch.transpace()
- 3.torch.t()
- 4.torch.squeeze()
- 5.torch.unsqueeze()
一、拼接
1.torch.cat()
功能:将tensor按照维度dim进行拼接,除了需要拼接的维度外,其余维度尺寸得是相同的。
torch.cat(tensors, dim=0, out=None)
看一下所有的参数:
- tensors:需要被拼接的张量序列。
- dim:(int,可选)被拼接的维度,默认为0。
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,
-1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,
-0.5790, 0.1497]])
2.torch.stack()
功能:在新创建的维度dim上进行拼接,所有的张量必须是相同的维度。
torch.stack(tensors, dim=0, out=None)
注意:stack()会创建一个新的维度。
t = torch.ones((2, 3))
t_stack = torch.stack([t, t, t], dim=2)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
原来t的维度是(2, 3),本来是没有第三维的,但是stack()会构建新的dim=2,就是先构建第三维dim=2,然后在该维度上进行拼接。
二、切分
1.torch.chunk()
功能:将tensor按维度dim进行平均切分。如果不能整除,最后一份tensor在该维度上的长度小于其他tensor。
torch.chunk(input, chunks, dim=0)
- input:要切分的张量。
- chunks:要切分的份数。
- dim:要切分的维度,默认为0。
a = torch.ones((2, 7)) # 7
list_of_tensors = torch.chunk(a, dim=1, chunks=3) # 3
for idx, t in enumerate(list_of_tensors):
print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
2.torch.split()
功能:将tensor按dim进行切分。
torch.split(tensor, split_size_or_sections, dim=0)
- tensor:要切分的张量。
- split_size_or_sections:(int或list(int))为int时,表示每一份的长度,如果不能整除,最后一份的长度要小于其他的张量,为list时,按list元素来切分。
- dim:同上。
>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
[2, 3]]),
tensor([[4, 5],
[6, 7]]),
tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
三、索引
1.torch.index_select()
功能:在dim上,按照index索引数据,返回一个依据index索引数据拼接的张量。
torch.index_select(input, dim, index, out=None)
- input:要索引的张量。
- dim:被索引的维度。
- index:一维张量,包括了要索引的数据序号。(long,不能是float)
- out:输出张量(可选)。
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]])
2.torch.masked_select()
功能:按照mask中的True进行索引,返回一个一维张量。
torch.masked_select(input, mask, out=None)
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297, 0.3477],
[-1.2035, 1.2252, 0.5002, 0.6248],
[ 0.1307, -2.0608, 0.1244, 2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[False, False, False, False],
[False, True, True, True],
[False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252, 0.5002, 0.6248, 2.0139])
四、变换
1.torch.reshape()
功能:变换张量的形状。
torch.reshape(input, shape)
- input:输入张量。
- shape:新张量的形状。当某个维度为-1时,表示该维度不用关心,可以从别的维度计算得到。
>>> a = torch.arange(4.)
>>> torch.reshape(a, (2, 2))
tensor([[ 0., 1.],
[ 2., 3.]])
>>> b = torch.tensor([[0, 1], [2, 3]])
>>> torch.reshape(b, (-1,))
tensor([ 0, 1, 2, 3])
2.torch.transpace()
功能:交换tensor的两个维度。
torch.transpose(input, dim0, dim1)
- input:输入张量。
- dim0和dim1:要交换的两个维度。
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893, 0.5809],
[-0.1669, 0.7299, 0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
[-0.9893, 0.7299],
[ 0.5809, 0.4942]])
3.torch.t()
功能:2维tensor转置,对矩阵而言。等价于torch.transpose(input, 0, 1)。
torch.t(input)
>>> x = torch.randn(())
>>> x
tensor(0.1995)
>>> torch.t(x)
tensor(0.1995)
>>> x = torch.randn(3)
>>> x
tensor([ 2.4320, -0.4608, 0.7702])
>>> torch.t(x)
tensor([ 2.4320, -0.4608, 0.7702])
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.4875, 0.9158, -0.5872],
[ 0.3938, -0.6929, 0.6932]])
>>> torch.t(x)
tensor([[ 0.4875, 0.3938],
[ 0.9158, -0.6929],
[-0.5872, 0.6932]])
注意:只对矩阵会转置,对标量和向量都不会。
4.torch.squeeze()
功能:压缩长度为1的维度(轴)。
torch.squeeze(input, dim=None, out=None)
- input:输入张量。
- dim:(可选)若为None,移除所有长度为1的轴,若指定轴,当且仅当该轴长度为1时移除。
- out:输出张量。
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
5.torch.unsqueeze()
功能:返回一个新的张量,对输入的指定位置插入维度 1。
torch.unsqueeze(input, dim)
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1, 2, 3, 4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
[ 2],
[ 3],
[ 4]])