文章目录

  • Pytorch基本数据类型及操作(二)
  • 1. 索引选取
  • 2. 切片选取
  • 3. 步长选取
  • 4. 用...选取
  • 5. 使用mask来索引
  • 6. 使用 take 打成一维


Pytorch基本数据类型及操作(二)

1. 索引选取

a = torch.rand(4, 3, 28, 28)     # 定义a是一个 4张28*28的RGB图 的张量

# 单个选取
print(a.shape)          # torch.Size([4, 3, 28, 28])
print(a[0].shape)       # torch.Size([3, 28, 28])
print(a[0, 0].shape)    # torch.Size([28, 28])
print(a[0, 0, 0].shape) # torch.Size([28])
print(a[0, 0, 2, 4].shape)  # torch.Size([])
  • a[0]:理解为取第 0 张图片,这张图片有 3 个通道,每个通道都是 28 * 28 的
  • a[0, 0]:理解为取第 0 张图片的第 0 个通道,这个通道是 28 * 28 的
  • a[0, 0, 0]:理解为取第 0 张图片的第 0 个通道的第 0 行像素点,这一行一共有 28 个像素点
  • a[0, 0, 2, 4]:理解为取第 0 张图片的第 0 个通道的第 2 行第 4 列的像素点,它是一个点
  • a[:2]:理解为取第 1 张和第 2 张和第 3 张图片,每张图片有 3 个通道,每个通道都是 28 * 28 的

2. 切片选取

a = torch.rand(4, 3, 28, 28)     # 定义a是一个 4张28*28的RGB图 的张量

# 连续选取
print(a[:2].shape)              # torch.Size([2, 3, 28, 28])
print(a[:2, :1, :, :].shape)    # torch.Size([2, 1, 28, 28])
print(a[:2, 1:, :, :].shape)    # torch.Size([2, 2, 28, 28])
print(a[:2, -1:, :, :].shape)   # torch.Size([2, 1, 28, 28])
  • a[:2, :1]:理解为取第 1 张和第 2 张和第 3 张图片,每张图片取第 0 个通道,每个通道都是 28 * 28 的
  • a[:2, 1:, :, :]:理解为取第 1 张和第 2 张和第 3 张图片,每张图片取第 1 个和第 2 个通道,每个通道都是 28 * 28 的
  • a[:2, -1:, :, :]:理解为取第 1 张和第 2 张和第 3 张图片,每张图片取第 2 个通道,每个通道都是 28 * 28 的

3. 步长选取

a = torch.rand(4, 3, 28, 28)     # 定义a是一个 4张28*28的RGB图 的张量

# 间隔选取
print(a[:, :, 0:28:2, 0:28:2].shape)    # torch.Size([4, 3, 14, 14])
print(a[:, :, ::2, ::2].shape)          # torch.Size([4, 3, 14, 14])
  • a[:, :, 0:28:2, 0:28:2]:理解为取第 1 张和第 2 张和第 3 张图片,每张图片有 3 个通道,对行和列像素点以步长为2从0至28取行和列
  • a[:, :, ::2, ::2]:理解为取第 1 张和第 2 张和第 3 张图片,每张图片有 3 个通道,对行和列像素点以步长为2从0至28取行和列
def index_select(dim, index)
  • dim:表示要操作的维度
  • index:表示该维度下的哪些值,这里的index必须是一个tensor,不能直接是一个list
a = torch.rand(4, 3, 28, 28)     # 定义a是一个 4张28*28的RGB图 的张量

print(a.index_select(0, torch.tensor([0, 2])).shape)    # torch.Size([2, 3, 28, 28])
print(a.index_select(1, torch.tensor([1, 2])).shape)    # torch.Size([4, 2, 28, 28])
print(a.index_select(2, torch.arange(28)).shape)        # torch.Size([4, 3, 28, 28])
print(a.index_select(2, torch.arange(8)).shape)         # torch.Size([4, 3, 8, 28])
  • a.index_select(0, torch.tensor([0, 2])):表示要操作第 0 维度,即图片张数维度,取第 0 张和第 2 张图片
  • a.index_select(1, torch.tensor([1, 2])):表示要操作第 1 维度,即图片通道维度,取第 1 个和第 2 个通道
  • a.index_select(2, torch.arange(28)):表示要操作第 2 个维度,即图片高度像素点维度,取前 28 个行
  • a.index_select(2, torch.arange(8)):表示要操作第 2 个维度,即图片高度像素点维度,取前 8 个行

4. 用…选取

  • ...可以取代:,增加方便性
print(a[...].shape) 		# torch.Size([4, 3, 28, 28])
print(a[0, ...].shape) 		# torch.Size([3, 28, 28])
print(a[:, 1, ...].shape) 	# torch.Size([4, 28, 28])
print(a[..., :2].shape) 	# torch.Size([4, 3, 28, 2])
  • a[…]:表示选取所有维度(即 4 个维度)上的所有值,相当于a[:, :, :, :]
  • a[0, …]:表示选取后3个维度的所有值
  • a[:, 1, …]:表示选取第 1 和维度图片张数中所有值,选取第二个维度RGB三个通道中第1个通道,选取三四维度所有值
  • a[…, :2]:表示选取前三个维度所有值,第四个维度宽度像素点维度中取前2个值

5. 使用mask来索引

x = torch.randn(3, 4)
print(x)
'''
tensor([[-0.4146, -0.1112, -0.6213, -0.3464],
        [-1.0482,  0.2925,  1.0796,  0.1143],
        [-0.7203,  0.5699,  1.3800, -0.3570]])
'''

mask = x.ge(0.5)	
print(mask)
'''
tensor([[False, False, False, False],
        [False, False,  True, False],
        [False,  True,  True, False]])
'''

c = torch.masked_select(x, mask)
print(c)         # tensor([1.0796, 0.5699, 1.3800])
print(c.shape)   # torch.Size([3])
print(c.dim())   # 1
  • mask = x.ge(0.5):表示将 x 中所有大于等于 0.5 的设置为 True ,反之设置为 False ,生成由 True 和 False 组成的Tensor
  • masked_select 函数是从 mask 中取出来所有值为 True 的值,形成一个 1 维的Tensor

6. 使用 take 打成一维

def take(input, index) -> Tensor
  • input:是一个Tensor ,将input变成一维的Tensor
  • index:是一个Tensor,表示从打平后的一维input中取出哪些下标的值
  • 返回值:由index为下标的值组成的一维Tensor
# src是一个二维的 2 * 3 的Tensor
src = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])

dst = torch.take(src, torch.tensor([0, 2, 5]))  
print(dst)          # tensor([1, 3, 6])
print(dst.shape)    # torch.Size([3])
print(dst.dim())    # 1
  • torch.take(src, torch.tensor([0, 2, 5])) :表示将 src 打平成一维后,取出下标为 0、2、5的值组织成一个一维Tensor