发现当我使用DataLoader加载数据的时候使用Module进行前向传播是可以的,但是如果仅仅是对一个img(三维)进行前项传播是不可以的。

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [6, 3, 2, 2], 
but got 3-dimensional input of size [3, 32, 32] instead


发现Dataloader有一个批处理,使得其一个tensor里面包含多个图片,tensor是四维的。

1.增加维度

a = torch.randn(2, 28, 28)

import torch
a = torch.randn(3, 32, 32)
print(a.shape)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(1).shape)
print(a.unsqueeze(2).shape)
print(a.unsqueeze(3).shape)
print(a.unsqueeze(-1).shape)
print(a.unsqueeze(-2).shape)
print(a.unsqueeze(-3).shape)
print(a.unsqueeze(-4).shape)
print(a.unsqueeze(4).shape)


结果:

Pytorch tensor维度变化_数据

2. 删除维度

维度删除的功能并不能做到删除任意维度的数据,只能删除那些size为1的维度

import torch

a = torch.Tensor(1, 4, 1, 9)
print(a.shape)
print(a.squeeze().shape)
print(a.squeeze(0).shape))# 0号维度是1,因此能删除
print(a.squeeze(1).shape)# 1号维度是4,因此不能删除
print(a.squeeze(2).shape)
print(a.squeeze(3).shape)# 3号维度是9,因此不能删除


显示结果:

Pytorch tensor维度变化_批处理_02