1 访问 Tensor 中某个元素
和Python中列表的下标一样,Tensor 的索引是基于零(Zero-based)开始计数的。比如我们想要获取向量/数组 a = [0, 1, 2, 3, 4] 中的第 2 个元素,我们可以使用 a[1]。
也可以这样理解:
- 一个标量是一个 0 维 Tensor;
- 一个向量是一个 1 维 Tensor;
- 一个矩阵是一个 2 维 Tensor;
- 一个 n 维数组是一个 n 维 Tensor。
而在访问 n 维 Tensor (假定为 T)的特定某个元素时,可以使用如下语法:
或
即我们要提供 共 n 个索引值,每次索引降低一个维度,最终得到 0 维数字(即标量)。
2 创建2维Tensor
创建一个需要用于后边进行索引查询的Tensor。一个三行两列的随机数矩阵torch.randn(3,2)
。
3 索引使用
3.1 获取第一行所有数据
因为第1行的索引为0,所以可以用如下方式获取第一行所有元素:
t[0]
仅设置行索引,可以获取指定行所有数据;
t[0,]
逗号后为列索引,这里列索引为空,即为所有列;
t[0,:]
逗号后为列索引,列元素索引这里使用冒号,表示的取值范围为全部,即获取所有列的数据。
3.2 获取最后一列的所有数据
可以通过以下方式获取二维Tensor中最后一列的所有数据:
t[:,-1]
逗号前为行索引,这里取的是全部行,列索引使用-1,表示取最后一列。
因为在python语法中,第一个索引值不允许为空,以下语句会报错:
t[,-1]
3.3 获取大于1的张量
这是根据筛选条件获取数据的方式。
Tensor 的变量名为 t,要获取 t 中所有大于1的元素数据,可以设置如下条件:mask = t > 1;
然后将该条件放入 t 中进行筛选。
3.4 获取筛选条件中值为true的索引
在上一步中创建的筛选条件为 mask = t > 1;
作用其实是将 t 中所有元素和 1 进行比较,如果大于1,值为 true,否则值为 false,并将这些 true 和 false 组合成为一个新的 Tensor 。可以查看该mask中的具体内容:
可以通过如下方法,获取到值为 true 的行、列索引:
torch.nonzero(mask)
运行后得到0行1列。
3.5 获取指定索引对应的元素
3.5.1 先看结果
创建要获取元素的索引:
index=torch.tensor([[0,2]])
调用方法 torch.gather(t,0,index) 对Tensor为 t 的张量根据索引采集其中的数据。
3.5.2 分析
torch.gather(t,0,index) 方法中的三个参数分别表示:
t —— 表示需要被采集数据的张量;
0 —— dim值,0表示按列采集,1表示按行采集;
index —— 表示要采集元素在张量中的索引。
在样例中,索引中的数据是使用一位下标表示,但是我们的Tensor是一个二维数组,我们都知道,如果要获取元素,必须明确元素的行下标及列下标才行。
这里使用dim,其实是为我们自动填充下标提供帮助。
如下代码中:
index=torch.tensor([[0,2]])
使用 dim=0 按列获取元素,
torch.gather(t,0,index)
那么列的下标程序会为我们自动生成:
即 [0,0] 和 [2,1]
使用 dim=1 按行获取元素,
torch.gather(t,1,index)
那么行的下标程序会为我们自动补充:
即 [0,0] 和 [0,2],程序执行后会报错,因为Tensor没有索引为2的列。
执行程序进行验证,提示下标越界
3.5.3 复杂验证
新建一个3行2列矩阵,索引也是一个3行2列的矩阵,然后使用上边的规律,分析下获得的元素会有哪些。
- 按列取
列下标根据数组中元素所在的列,自动补充。
- 按行取
行下标根据数组中元素所在的行,自动补充。
3.6 数据散布
使用方法 scatter(dim, index, src),可以将src中的数据,根据index中的索引及dim的方向,散布到 张量 中。
tensor.scatter(dim, index, src)
dim —— 与获取指定索引数据相同,0 按列,1 按行;
index —— 指定索引,执行后,被散布的张量中该索引位置元素将被替换;
src —— 要散布的张量。
3.6.1 创建一个全0的二维Tensor
使用 zeros(3,2) 创建一个全为0的二维Tensor。
3.6.2 按列散布
调用方法 x.scatter(0, index, t)
- 根据按列散布的方式生成索引序列;
- 按照从左到右,从上到下的顺序,分别从src的张量和索引序列中,依次取出元素数据、索引数据;
- 将元素数据,按照索引位置,放入到目标张量中。
3.6.3 按行散布
调用方法 x.scatter(1, index, t)
- 根据按行散布的方式生成索引序列;
- 按照从左到右,从上到下的顺序,分别从src的张量和索引序列中,依次取出元素数据、索引数据;
- 将元素数据,按照索引位置,放入到目标张量中。
3.6.4 修改还是创建
调用方法 scatter(dim, index, src) 对目标张量进行散布操作后,会生成一个新的张量,目标张量的数据并不会被修改;
如果想要对目标张量进行修改,可以调用方法 scatter_(dim, index, src),在进行散布操作后,目标张量的数据会发生修改。