张量的拼接
torch.cat(tensors, dim=0, out=None)
功能:将张量按维度dim进行拼接
·tensors:张量序列
·dim:要拼接的维度
import torch
t = torch.ones((2,3))
t_0 = torch.cat([t,t], dim=0)
t_1 = torch.cat([t,t], dim=1)
print('t_0:{} shape:{}\nt_1:{} shape:{}'.format(t_0,t_0.shape,t_1,t_1.shape))
torch.stack(tensors, dim=0, out=None
功能:在新创建的维度dim上进行拼接
·tensors:张量序列
·dim:要拼接的维度
t = torch.ones((2,3))
t_stack1 = torch.stack([t,t], dim=0)
t_stack2 = torch.stack([t,t], dim=2)
print('t_0:{} shape:{}\nt_1:{} shape:{}'.format(t_stack1,t_stack1.shape,t_stack2,t_stack2.shape))
张量的切分
torch.chunk(input, chunks, dim=0)
功能:将张量按维度dim进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其他张量
·input:要切分的张量
·chunks:要切分的份数
·dim:要切分的维度
t = torch.ones((2,5))
list_of_tensors = torch.chunk(t, dim=1, chunks=2)
for idx, mat in enumerate(list_of_tensors):
print('第{}个张量:{}, 维度为{}'.format(idx+1,mat,mat.shape))
torch.split(tensor, split_size_or_sections, dim=0)
功能:将张量安慰度dim进行切分
返回值:张量列表
·tensor:要切分的张量
·split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分
·dim:要切分的维度
t = torch.ones((2,5))
list_of_tensors = torch.split(t, [2,1,2], dim=1)
for idx, mat in enumerate(list_of_tensors):
print('第{}个张量:{}, 维度为{}'.format(idx+1,mat,mat.shape))
张量索引
torch.index_select(input,dim=0,index=None)
功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量
·input:要索引的张量
·dim:要索引的维度
·index:要索引数据的序号
t = torch.randint(0,9,size=(3,3))
idx = torch.tensor([0,2], dtype=torch.long) #float
t_select = torch.index_select(t, dim=0, index=idx)
print('{}\n{}'.format(t, t_select))
torch.masked_select(input, mask, out=None)
功能:按mask中的True进行索引
返回值:一维张量
·input:要索引的张量
·mask:与input同形状的布尔类型张量
t = torch.randint(0,9,size=(3,3))
#返回大小为t的矩阵,其中大于等于5的元素为True,小于5的为False
mask = t.ge(5)
t_select = torch.masked_select(t, mask)
print('t:\n{}\nmask:\n{}\nt_select:\n{}'.format(t,mask,t_select))
张量变换
torch.reshape(input, shape)
功能:变换张量形状
注意事项:当张量在内存中是连续时,新张量与input共享数据内存
·input:要变换的张量
·shape:新张量的形状
t = torch.randperm(8)
t_reshape = torch.reshape(t, (-1,2,2))
print('t:\n{}\nt_reshape:\n{}'.format(t, t_reshape))
print('t内存地址{}'.format(id(t.data)))
print('t_reshape内存地址{}'.format(id(t_reshape.data)))
torch.transpose(input, dim0, dim1)
功能:交换张量的两个维度
·input:要变换的张量
·dim0:要变换的维度
·dim1:要变换的维度
t = torch.rand((2,3,4))
t_transpose = torch.transpose(t, dim0=1, dim1=2)
print('t shape:{} t_transpose shape:{}'.format(t.shape, t_transpose.shape))
torch.t(input)
功能:2维张量转置,对矩阵而言,等价于torch.transpsoe(input,0,1)
torch.squeeze(input, dim=None, out=None)
功能:压缩长度为1的维度(轴)
·dim:若为None,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除;
t = torch.rand((1,2,3,1))
t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t, dim=0)
t_1 = torch.squeeze(t, dim=1)
print(t.shape)
print(t_sq.shape)
print(t_0.shape)
print(t_1.shape)#第二个维度是2故无法压缩掉
torch.usqueeze(input, dim, out=None)
功能:依据dim扩展维度
·dim:扩展的维度
t = torch.rand((1,2,3))
t_sq = torch.unsqueeze(t,dim=3)
print(t.shape)
print(t_sq.shape)