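# torch.squeeze(input, dim=None, *, out=None) -> Tensor
# squeeze: return `input` with its size-1 dimensions removed -- all of them
# when dim is None, otherwise only dimension `dim`, and only if its size is 1.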
# Self-checking torch.squeeze examples: each expected shape is asserted
# instead of only being written in a comment, so running the snippet
# verifies the notes.
import torch

x = torch.zeros(2, 1, 2, 1, 2)
assert x.size() == torch.Size([2, 1, 2, 1, 2])

# No dim given: every size-1 dimension (here dims 1 and 3) is removed.
y = torch.squeeze(x)
assert y.size() == torch.Size([2, 2, 2])

# dim=0 has size 2, not 1, so squeezing it is a no-op: the shape is unchanged.
y = torch.squeeze(x, 0)
assert y.size() == torch.Size([2, 1, 2, 1, 2])

# dim=1 has size 1, so only that dimension is removed; dim 3 is kept.
y = torch.squeeze(x, 1)
assert y.size() == torch.Size([2, 2, 1, 2])
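# torch.unsqueeze(input, dim) -> Tensor
# unsqueeze: return `input` with a new dimension of size 1 inserted at
# position `dim`.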
# Self-checking torch.unsqueeze examples: the results are bound and compared
# against the expected tensors instead of being discarded or shown only in a
# comment / dead string literal.
import torch

x = torch.tensor([1, 2, 3, 4])

# dim=0 inserts a new leading axis: shape (4,) -> (1, 4), i.e. one row.
row = torch.unsqueeze(x, 0)
assert row.shape == (1, 4)
assert torch.equal(row, torch.tensor([[1, 2, 3, 4]]))

# dim=1 inserts a new trailing axis: shape (4,) -> (4, 1), i.e. one column.
col = torch.unsqueeze(x, 1)
assert col.shape == (4, 1)
assert torch.equal(col, torch.tensor([[1], [2], [3], [4]]))