# torch.squeeze(input, dim=None, *, out=None) → Tensor
# Returns a tensor with size-1 dimensions removed (all of them, or only `dim` if given).

# squeeze

# A 5-D tensor with two size-1 dimensions (dims 1 and 3).
x = torch.zeros(2, 1, 2, 1, 2)
assert x.size() == torch.Size([2, 1, 2, 1, 2])

# With no dim, every size-1 dimension is removed.
y = torch.squeeze(x)
assert y.size() == torch.Size([2, 2, 2])

# Dim 0 has size 2 (not 1), so squeezing it is a no-op: the shape is unchanged.
y = torch.squeeze(x, 0)
assert y.size() == torch.Size([2, 1, 2, 1, 2])

# Dim 1 has size 1, so only that dimension is removed; dim 3 (also size 1) is kept.
y = torch.squeeze(x, 1)
assert y.size() == torch.Size([2, 2, 1, 2])

# unsqueeze — inserts a new dimension of size 1 at position `dim`.

# A 1-D tensor of shape (4,).
x = torch.tensor([1, 2, 3, 4])

# Inserting a size-1 dimension at position 0 gives shape (1, 4), a single row:
# tensor([[ 1,  2,  3,  4]])
assert torch.equal(torch.unsqueeze(x, 0), torch.tensor([[1, 2, 3, 4]]))

# Inserting a size-1 dimension at position 1 gives shape (4, 1), a single column:
# tensor([[ 1],
#         [ 2],
#         [ 3],
#         [ 4]])
assert torch.equal(torch.unsqueeze(x, 1), torch.tensor([[1], [2], [3], [4]]))