1 访问 Tensor 中某个元素

和Python中列表的下标一样,Tensor 的索引是基于零(Zero-based)开始计数的。比如我们想要获取向量/数组 a = [0, 1, 2, 3, 4] 中的第 2 个元素,我们可以使用 a[1]。

也可以这样理解:

  • 一个标量是一个 0 维 Tensor;
  • 一个向量是一个 1 维 Tensor;
  • 一个矩阵是一个 2 维 Tensor;
  • 一个 n 维数组是一个 n 维 Tensor。

而在访问 n 维 Tensor (假定为 T)的特定某个元素时,可以使用如下语法:

PyTorch项目实战04——Tensor的索引_自动生成 或 PyTorch项目实战04——Tensor的索引_数据_02

即我们要提供 PyTorch项目实战04——Tensor的索引_数据_03 共 n 个索引值,每次索引降低一个维度,最终得到 0 维数字(即标量)。

2 创建2维Tensor

创建一个需要用于后边进行索引查询的Tensor。一个三行两列的随机数矩阵torch.randn(3,2)

PyTorch项目实战04——Tensor的索引_数组_04

3 索引使用

3.1 获取第一行所有数据

因为第1行的索引为0,所以可以用如下方式获取第一行所有元素:

t[0] 仅设置行索引,可以获取指定行所有数据;

t[0,] 逗号后为列索引,这里列索引为空,即为所有列;

t[0,:] 逗号后为列索引,列元素索引这里使用冒号,表示的取值范围为全部,即获取所有列的数据。

PyTorch项目实战04——Tensor的索引_数据_05

3.2 获取最后一列的所有数据

可以通过以下方式获取二维Tensor中最后一列的所有数据:

t[:,-1] 逗号前为行索引,这里取的是全部行,列索引使用-1,表示取最后一列。

因为在python语法中,第一个索引值不允许为空,以下语句会报错:

t[,-1]

PyTorch项目实战04——Tensor的索引_自动生成_06

3.3 获取大于1的张量

这是根据筛选条件获取数据的方式。

Tensor 的变量名为 t,要获取 t 中所有大于1的元素数据,可以设置如下条件:mask = t > 1;

然后将该条件放入 t 中进行筛选。

PyTorch项目实战04——Tensor的索引_自动生成_07

3.4 获取筛选条件中值为true的索引

在上一步中创建的筛选条件为 mask = t > 1;

作用其实是将 t 中所有元素和 1 进行比较,如果大于1,值为 true,否则值为 false,并将这些 true 和 false 组合成为一个新的 Tensor 。可以查看该mask中的具体内容:

PyTorch项目实战04——Tensor的索引_数组_08

可以通过如下方法,获取到值为 true 的行、列索引:

torch.nonzero(mask)

运行后得到0行1列。

PyTorch项目实战04——Tensor的索引_数组_09

3.5 获取指定索引对应的元素

3.5.1 先看结果

创建要获取元素的索引:

index=torch.tensor([[0,2]])

调用方法 torch.gather(t,0,index) 对Tensor为 t 的张量根据索引采集其中的数据。

PyTorch项目实战04——Tensor的索引_自动生成_10

3.5.2 分析

torch.gather(t,0,index) 方法中的三个参数分别表示:

t —— 表示需要被采集数据的张量;

0 —— dim值,0表示按列采集,1表示按行采集;

index —— 表示要采集元素在张量中的索引。

在样例中,索引中的数据是使用一位下标表示,但是我们的Tensor是一个二维数组,我们都知道,如果要获取元素,必须明确元素的行下标及列下标才行。

这里使用dim,其实是为我们自动填充下标提供帮助。

如下代码中:

index=torch.tensor([[0,2]])

使用 dim=0 按列获取元素,

torch.gather(t,0,index)

那么列的下标程序会为我们自动生成:

即 [0,0] 和 [2,1]

PyTorch项目实战04——Tensor的索引_数据_11

使用 dim=1 按行获取元素,

torch.gather(t,1,index)

那么行的下标程序会为我们自动补充:

即 [0,0] 和 [0,2],程序执行后会报错,因为Tensor没有索引为2的列。

PyTorch项目实战04——Tensor的索引_数据_12


执行程序进行验证,提示下标越界

PyTorch项目实战04——Tensor的索引_自动生成_13

3.5.3 复杂验证

新建一个3行2列矩阵,索引也是一个3行2列的矩阵,然后使用上边的规律,分析下获得的元素会有哪些。

PyTorch项目实战04——Tensor的索引_数据_14

  • 按列取

列下标根据数组中元素所在的列,自动补充。

PyTorch项目实战04——Tensor的索引_自动生成_15

  • 按行取

行下标根据数组中元素所在的行,自动补充。

PyTorch项目实战04——Tensor的索引_自动生成_16


3.6 数据散布

使用方法 scatter(dim, index, src),可以将src中的数据,根据index中的索引及dim的方向,散布到 张量 中。

tensor.scatter(dim, index, src)

dim —— 与获取指定索引数据相同,0 按列,1 按行;

index —— 指定索引,执行后,被散布的张量中该索引位置元素将被替换

src —— 要散布的张量。

3.6.1 创建一个全0的二维Tensor

使用 zeros(3,2) 创建一个全为0的二维Tensor。

PyTorch项目实战04——Tensor的索引_数据_17

3.6.2 按列散布

调用方法 x.scatter(0, index, t)

  • 根据按列散布的方式生成索引序列;
  • 按照从左到右,从上到下的顺序,分别从src的张量和索引序列中,依次取出元素数据、索引数据;
  • 将元素数据,按照索引位置,放入到目标张量中。

PyTorch项目实战04——Tensor的索引_数据_18

3.6.3 按行散布

调用方法 x.scatter(1, index, t)

  • 根据按行散布的方式生成索引序列;
  • 按照从左到右,从上到下的顺序,分别从src的张量和索引序列中,依次取出元素数据、索引数据;
  • 将元素数据,按照索引位置,放入到目标张量中。

PyTorch项目实战04——Tensor的索引_数组_19

3.6.4 修改还是创建

调用方法 scatter(dim, index, src) 对目标张量进行散布操作后,会生成一个新的张量,目标张量的数据并不会被修改;

PyTorch项目实战04——Tensor的索引_数据_20

如果想要对目标张量进行修改,可以调用方法 scatter_(dim, index, src),在进行散布操作后,目标张量的数据会发生修改。

PyTorch项目实战04——Tensor的索引_数组_21