本文已收录于Pytorch系列专栏: Pytorch入门与实践 专栏旨在详解Pytorch,精炼地总结重点,面向入门学习者,掌握Pytorch框架,为数据分析,机器学习及深度学习的代码能力打下坚实的基础。免费订阅,持续更新。
一、张量拼接与切分
1.1 torch.cat
功能:将张量按维度dim 进行拼接
- tensors : 张量序列
- dim: 要拼接的维度
(2,3) -> (2,6)
这里的dim维度与axis相同,0代表列,1代表行。
1.2 torch.stack
功能:在新创建的维度 dim 上进行拼接(会拓宽原有的张量维度)
- tensors:张量序列
- dim:要拼接的维度
可见,它在新的维度上进行了拼接。
参数[t, t, t]的意思就是在第n个维度上拼接成这个样子。
1.3 torch.chunk
功能:将张量按维度 dim 进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其他张量。
- input : 要切分的张量
- chunks 要切分的份数
- dim 要切分的维度
code
可知,切分是7/3向上取整,每份是3,最后剩下的维度直接输出即可。
1.4 torch.split
torch.split(Tensor, split_size_or_sections, dim)
功能:将张量按维度 dim 进行切分
返回值:张量列表
- tensor : 要切分的张量
- split_size_or_sections 为 int 时,表示 每一份的长度;为 list 时,按 list 元素切分
- dim 要切分的维度
code:
是按照指定长度list进行切分的。注意list中长度总和必须为原张量在改维度的大小,不然会报错。
二、张量索引
2.1 torch.index_select
torch.index_select(input, dim, index, out=None)
功能:在维度dim 上,按 index 索引数据
返回值:依index 索引数据拼接的张量
- input : 要索引的张量
- dim 要索引的维度
- index 要索引数据的序号
code:
可见idx是一个存储序号的张量,而torch.index_select通过该张量索引原tensor并且拼接返回。
2.2 torch.masked_select
功能:按mask 中的 True 进行索引
返回值:一维张量(无法确定true的个数,因此也就无法显示原来的形状,因此这里返回一维张量)
- input : 要索引的张量
- mask 与 input 同形状的布尔类型张量
通过掩码来索引。