Tensor Indexing, Slicing, and Concatenation

1. Related Commands

Command 1: Concatenation - torch.cat()

  • Format: torch.cat(tensors, dim=0, out=None) → Tensor
  • Explanation: concatenates a sequence of tensors along the specified dimension; all inputs must have the same shape except in the concatenating dimension
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)  # concatenate along dim=0, i.e. row-wise
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)  # concatenate along dim=1
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])
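
A common pitfall worth checking (a minimal sketch, not from the original post; the tensor names are illustrative): all inputs to torch.cat() must have the same shape except along the concatenation dimension, otherwise a RuntimeError is raised.

>>> a = torch.randn(2, 3)
>>> b = torch.randn(2, 5)
>>> torch.cat((a, b), 1).shape  # only dim=1 differs, so this works
torch.Size([2, 8])
>>> torch.cat((a, b), 0)  # dim=1 sizes differ (3 vs 5): raises RuntimeError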

Command 2: Concatenation - torch.stack()

  • Format: torch.stack(tensors, dim=0, out=None) → Tensor
  • Explanation: concatenates tensors along a new dimension; all tensors in the sequence must have the same size
  • Put plainly: it turns 2-D tensors into a 3-D tensor, 3-D tensors into a 4-D one, and so on
  • Difference from torch.cat(): stack() expands each tensor with a new dimension and then concatenates along it; it is commonly used in NLP and CV (see the sketch below)
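
The "expand then concatenate" behavior can be verified directly (a minimal sketch, not from the original post): torch.stack(tensors, dim) produces the same result as unsqueezing each tensor at dim and then calling torch.cat().

>>> a, b = torch.randn(3, 4), torch.randn(3, 4)
>>> s1 = torch.stack((a, b), dim=1)
>>> s2 = torch.cat((a.unsqueeze(1), b.unsqueeze(1)), dim=1)
>>> s1.shape
torch.Size([3, 2, 4])
>>> torch.equal(s1, s2)
True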

Note: the example below is taken from the official documentation of torch.stack(), with detailed explanations and examples.

import torch

# Suppose this is the output at time step T1
T1 = torch.tensor([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
# Suppose this is the output at time step T2
T2 = torch.tensor([[10, 20, 30],
                   [40, 50, 60],
                   [70, 80, 90]])
print(torch.stack((T1, T2), dim=0))
print(torch.stack((T1, T2), dim=0).shape)
print('==================================')
print(torch.stack((T1, T2), dim=1))
print(torch.stack((T1, T2), dim=1).shape)
print('==================================')
print(torch.stack((T1, T2), dim=2))
print(torch.stack((T1, T2), dim=2).shape)
print('==================================')
print(torch.stack((T1, T2), dim=3).shape)  # dim is out of range here, see below

Output:

tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
         
torch.Size([2, 3, 3])
==================================
tensor([[[ 1,  2,  3],
         [10, 20, 30]],

        [[ 4,  5,  6],
         [40, 50, 60]],

        [[ 7,  8,  9],
         [70, 80, 90]]])
         
torch.Size([3, 2, 3])
==================================
tensor([[[ 1, 10],
         [ 2, 20],
         [ 3, 30]],

        [[ 4, 40],
         [ 5, 50],
         [ 6, 60]],

        [[ 7, 70],
         [ 8, 80],
         [ 9, 90]]])

torch.Size([3, 3, 2])
==================================
The chosen dim is outside the valid range, so an error is raised:
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

Command 3: Splitting - torch.chunk()

  • Format: torch.chunk(input, chunks, dim=0) → List of Tensors
  • chunks must be an integer
  • Explanation: can be seen as the inverse of torch.cat(); splits a tensor into chunks

Note:

If the tensor cannot be split evenly along the given dimension, the last chunk will be the smallest.

>>> c
tensor([[0.9387, 0.5666, 0.3289, 0.7775, 0.5938],
        [0.0968, 0.0961, 0.6976, 0.9121, 0.0796],
        [0.4676, 0.7772, 0.2398, 0.5254, 0.9906],
        [0.0588, 0.7729, 0.2259, 0.6438, 0.8299]], dtype=torch.float64)
>>> cc = torch.chunk(c, 2, 1)  # clearly size (4, 5) cannot be split evenly along dim=1
>>> cc  # so the two chunks have shapes (4, 3) and (4, 2)
(tensor([[0.9387, 0.5666, 0.3289],
        [0.0968, 0.0961, 0.6976],
        [0.4676, 0.7772, 0.2398],
        [0.0588, 0.7729, 0.2259]], dtype=torch.float64), tensor([[0.7775, 0.5938],
        [0.9121, 0.0796],
        [0.5254, 0.9906],
        [0.6438, 0.8299]], dtype=torch.float64))
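
A related corner case (a minimal sketch, not from the original post): when even the ceiling-based chunk size cannot yield the requested number of chunks, torch.chunk() simply returns fewer chunks.

>>> torch.chunk(torch.arange(6), 4)  # chunk size = ceil(6/4) = 2, so only 3 chunks come back
(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))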

Command 4: Splitting - torch.split()

  • Format: torch.split(tensor, split_size_or_sections, dim=0)
  • split_size_or_sections can be an int or a list. If it is an int, the tensor is cut into equally sized chunks where possible, with the last chunk being the smallest; if it is a list, the tensor is split into len(list) chunks whose sizes follow the list's elements
  • Explanation: can be seen as the inverse of torch.cat(); splits a tensor
>>> c
tensor([[0.9387, 0.5666, 0.3289, 0.7775, 0.5938],
        [0.0968, 0.0961, 0.6976, 0.9121, 0.0796],
        [0.4676, 0.7772, 0.2398, 0.5254, 0.9906],
        [0.0588, 0.7729, 0.2259, 0.6438, 0.8299]], dtype=torch.float64)
# split_size_or_sections is a list
>>> cc = torch.split(c, [2, 2, 1], 1)
>>> cc
(tensor([[0.9387, 0.5666],
        [0.0968, 0.0961],
        [0.4676, 0.7772],
        [0.0588, 0.7729]], dtype=torch.float64), tensor([[0.3289, 0.7775],
        [0.6976, 0.9121],
        [0.2398, 0.5254],
        [0.2259, 0.6438]], dtype=torch.float64), tensor([[0.5938],
        [0.0796],
        [0.9906],
        [0.8299]], dtype=torch.float64))
# split_size_or_sections is an integer
>>> cc = torch.split(c, 2, 1)
>>> cc
(tensor([[0.9387, 0.5666],
        [0.0968, 0.0961],
        [0.4676, 0.7772],
        [0.0588, 0.7729]], dtype=torch.float64), tensor([[0.3289, 0.7775],
        [0.6976, 0.9121],
        [0.2398, 0.5254],
        [0.2259, 0.6438]], dtype=torch.float64), tensor([[0.5938],
        [0.0796],
        [0.9906],
        [0.8299]], dtype=torch.float64))
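
One constraint to keep in mind (a minimal sketch, not from the original post): when split_size_or_sections is a list, its elements must sum exactly to the tensor's size along dim, otherwise a RuntimeError is raised.

>>> c.shape
torch.Size([4, 5])
>>> torch.split(c, [2, 2, 2], 1)  # 2 + 2 + 2 != 5: raises RuntimeError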

Command 5: Gathering Data Along a Specified Dimension - torch.gather()

  • Format: torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
  • Explanation: collects values from the input tensor along the given axis dim at the positions specified by index
  • index is a tensor, and the output always has the same shape as index. For a 2-D input: out[i][j] = input[index[i][j]][j] if dim == 0, and out[i][j] = input[i][index[i][j]] if dim == 1

Example: adapted from the Zhihu article 图解PyTorch中的torch.gather函数 (an illustrated guide to PyTorch's torch.gather function).

First, define a source tensor:

>>> a = torch.arange(3, 12).view(3, 3)
>>> a
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

Specify the index:

index = torch.tensor([[2, 1, 0]])

Gather along dim=0:

>>> b = torch.gather(a, 0, index)  # b[0][j] = a[index[0][j]][j]
>>> b
tensor([[9, 7, 5]])
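
For completeness (a minimal sketch, not from the Zhihu post): gathering along dim=1 instead picks one element per row, so out[i][j] = a[i][index[i][j]].

>>> index2 = torch.tensor([[2], [1], [0]])
>>> torch.gather(a, 1, index2)
tensor([[5],
        [7],
        [9]])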

Command 6: Dimension Squeezing - torch.squeeze()

  • Format: torch.squeeze(input, dim=None, out=None) → Tensor
  • Explanation: removes all dimensions of size 1 (or only the given dim, if specified)

Official examples:
If input has size (A×1×B×C×1×D), the returned tensor has size (A×B×C×D).

If the size of the specified dim is not 1, the tensor is returned unchanged.

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
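
A practical caveat (not from the original post): calling squeeze() without a dim can silently remove a batch dimension of size 1, so passing dim explicitly is usually safer.

>>> batch = torch.zeros(1, 3)  # a "batch" containing a single sample
>>> batch.squeeze().size()     # the batch dimension silently disappears
torch.Size([3])
>>> batch.squeeze(1).size()    # dim 1 has size 3, so nothing is removed
torch.Size([1, 3])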

Command 7: Dimension Restoration or Addition - torch.unsqueeze()

  • Format: torch.unsqueeze(input, dim, out=None) → Tensor
  • Explanation: inserts a dimension of size 1 at the specified position

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])
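
Equivalently (a common indexing idiom, not from the original post), indexing with None inserts a size-1 dimension at that position, exactly like unsqueeze():

>>> x[None, :].size()  # same as torch.unsqueeze(x, 0)
torch.Size([1, 4])
>>> x[:, None].size()  # same as torch.unsqueeze(x, 1)
torch.Size([4, 1])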

Command 8: Tensor Flattening - torch.flatten()

  • Format: torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

Example:

>>> t = torch.tensor([[[1, 2],
                       [3, 4]],
                      [[5, 6],
                       [7, 8]]])
>>> t.size()
torch.Size([2, 2, 2])
>>> torch.flatten(t)  # flattens from dim=0 by default
tensor([1, 2, 3, 4, 5, 6, 7, 8])

>>> torch.flatten(t, start_dim=1)  # flatten starting from dim=1
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
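
A typical use of start_dim (a minimal sketch with made-up sizes, not from the original post) is flattening everything except the batch dimension before feeding a convolutional feature map into a fully connected layer:

>>> feats = torch.randn(8, 16, 4, 4)          # (batch, channels, H, W)
>>> torch.flatten(feats, start_dim=1).size()  # keep the batch dim, flatten the rest
torch.Size([8, 256])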