Pytorch基础(二)Tensor的索引和切片
Tensor的index和select
- **Dim 0 first:**多维张量的索引默认为第一个维度索引
a = torch.Tensor(4, 3, 28, 28)
print(a[0].shape) # torch.Size([3,28,28])
print(a[0,0].shape) # troch.Size([28,28])
- 选择前N个或后N个
-
:
代表全部 -
n:
代表从第n个到最后(包括第n个) -
:n
代表从第一个到第n个(不包括第n个) -
n:m
代表从第n个到第m个(包括第n个,不包括第m个) -
n:m:x
代表从第n个到第m个,每隔x个取一个(包括第n个,不包括第m个) - 通用形式为:
start:end:step
(不包括end)
a = torch.Tensor(4, 3, 28, 28)
print(a[:2].shape) # torch.Size([2,3,28,28]) 这里:2代表从0到2(不包括2)
print(a[2:].shape) # torch.Size([2,3,28,28]) 这里2:代表从2到最后(包括2)
print(a[-2:].shape) # torch.Size([2,3,28,28]) 这里-2:代表从倒数第二个到最后(包括倒数第二个)
print(a[:].shape) # torch.Size([4,3,28,28]) 这里:代表这维度的所有元素
print(a[:,:,0:14]) # torch.Size([4,3,14,28]) 这里0:14代表从0到14(不包括14)
print(a[:,:,0:28:2])# torch.Size([4,3,14,28]) 这里0:28:2代表从0到28(不包括28),每两个取一次
- 选择特定的维度
.index_select(dim, index)
a = torch.Tensor(4, 3, 28, 28)
print(a.index_select(0,torch.tensor([2, 3])).shape) # 沿着第0个维度进行切片,取第2和第3个tensor。
print(a.index_select(2,torch.arange(14)).shape) # 沿着第2个维度切片,取前14个tensor
注意,index这个参数必须是
torch.tensor
不能使用python中的list
- 用省略号
...
代表任意维度
这里的
...
代表维度需要根据具体情况进行推测所以这里的
...
必须是可以推测出的维度,比如最左/右或中间的维度
a = torch.Tensor(4, 3, 28, 28)
a[...].shape # torch.Size([4,3,28,28]) 这里代表所有维度
a[:,1,...].shape # torch.Size([4,28,28]) 这里最右边的所有维度
- 通过mask(掩码)来进行筛选
torch.masked_select()
注意,使用
torch.masked_select()
会将数据的维度打平,返回的tensor维度为1,长度不定
x = torch.randn(3,4)
print(x)
# out:
# tensor([[ 0.6797, -0.1078, 0.7623, 0.2214],
# [-1.2354, 0.6120, 2.3871, -1.1993],
# [-0.2460, -1.2034, 0.7166, 0.2186]])
mask = x.ge(0.5)
print(mask)
# out:
# tensor([[ True, False, True, False],
# [False, True, True, False],
# [False, False, True, False]])
a = torch.masked_select(x,mask)
print(a)
# tensor([0.6797, 0.7623, 0.6120, 2.3871, 0.7166])
print(a.shape)
# torch.Size([5])
- 通过将tensor的维度打平来进行select
torch.take()
a = torch.tensor([[4,3,5]
[6,7,8]])
torch.take(a,torch.tensor([0,2])) # tensor([4,5,8])
Tensor的维度变换
常用的API:
- 形状变换:View/reshape
在pytorch0.3之后,view和reshape这两个函数功能完全相同,但要确保前后的numel一致
适合全连接层将维度打平时使用
但在维度改变时,会丢失原先tensor各个维度的意义。例如[b,c,h,w]打平后会破坏各个维度的顺序,还原后顺序会被改变
- 挤压和扩充:Squeeze/unsqueeze
squeeze只能将shape为1的维度压缩
unsqueeze不会改变具体的数据,只是将tensor的维度进行扩充,被扩充的维度的shape还是1
- 转置:Transpose/t/permute
transpose只能将其中某两个维度进行交换
要多个维度交换可以使用多次transpose,也可以直接使用一次permute
不论是transpose还是permute,维度交换后内存地址还是不连续的,一般可以后面添加
.contiguous()
使内存地址连续,这样再进行其他操作减小了报错的可能性
- 维度扩展:Expand/repeat