张量的合并操作
张量的合并操作类似与列表的追加元素,可以拼接、也可以堆叠。
1.拼接方法:.cat()
PyTorch中,可以使用.cat()方法实现张量的拼接,不改变张量形状,并且返回结果是原张量的视图。
(1).cat()方法的使用,第一个参数和第二个参数:为目标张量,第三个参数:0表示行数增加,1表示列数增加,注意观察张量形状
a = torch.zeros(2, 3) #创建2行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
[0., 0., 0.]])
b = torch.ones(2, 3) #创建2行3列元素全部为1(浮点型)的二维张量
#结果为:tensor([[1., 1., 1.],
[1., 1., 1.]])
c = torch.zeros(3, 3) #创建3行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
torch.cat([a, b]) # 按照行进行拼接,dim默认取值为0,行数增加,观察仔细
#结果为:tensor([[0., 0., 0.],
[0., 0., 0.],
[1., 1., 1.],
[1., 1., 1.]])
torch.cat([a, b], 1) # 按照列进行拼接,列数增加,观察仔细,矩阵的嵌套
#结果为:tensor([[0., 0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1., 1.]])
torch.cat([a, c], 1) # 形状不匹配时将报错
#报错
RuntimeError Traceback (most recent call last)
<ipython-input-153-8bdd1a857266> in <module>
----> 1 torch.cat([a, c], 1) # 形状不匹配时将报错
RuntimeError: Sizes of tensors must match except in dimension 1. Got 2 and 3 in dimension 0 (The offending index is 1)
注意理解:拼接的本质是实现元素的堆积,也就是构成a、b两个二维张量的各一维张量的堆积,最终还是构成二维向量。
2.堆叠方法:.stack()
和拼接不同,堆叠不是将元素拆分重装,而是简单的将各参与堆叠的对象分装到一个更高维度的张量里。
(1).stack()方法的使用,第一个参数和第二个参数:为目标张量,第三个参数:0表示行数增加,1表示列数增加,注意观察张量形状
a = torch.zeros(2, 3) #创建2行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
[0., 0., 0.]])
b = torch.ones(2, 3) #创建2行3列元素全部为1(浮点型)的二维张量
#结果为:tensor([[1., 1., 1.],
[1., 1., 1.]])
c = torch.zeros(3, 3) #创建3行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
torch.stack([a, b]) # 堆叠之后,生成一个三维张量
#结果为:tensor([[[0., 0., 0.],
[0., 0., 0.]],
[[1., 1., 1.],
[1., 1., 1.]]])
torch.stack([a, b]).shape #查看堆叠后的形状
#结果为:torch.Size([2, 2, 3])
表示:其是3维张量,由2个二维张量组成,每个二维张量有2个一维张量组成,每个一维张量有3个元素
torch.cat([a, b])
#结果为:tensor([[0., 0., 0.],
[0., 0., 0.],
[1., 1., 1.],
[1., 1., 1.]])
注意对比二者区别,拼接之后维度不变,堆叠之后维度升高。拼接是把一个个元素单独提取出来之后再放到二维张量中,而堆叠则是直接将两个二维张量封装到一个三维张量中,因此,堆叠的要求更高,参与堆叠的张量必须形状完全相同。
a = torch.zeros(2, 3) #创建2行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
[0., 0., 0.]])
c = torch.zeros(3, 3) #创建3行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
torch.cat([a, c]) # 横向拼接时,对行数没有一致性要求
#结果为:tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
torch.stack([a, c]) # 维度不匹配时,堆叠也会报错
#报错:
RuntimeError Traceback (most recent call last)
<ipython-input-167-0311d15e051e> in <module>
----> 1 torch.stack([a, c]) # 维度不匹配时也会报错
RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [3, 3] at entry 1
以上不是本人的浅显见解,还请他人多多指导,更正错误。