该函数的作用为:收集指定索引位置的值。

先将函数原型写出:

torch.gather(input, dim, index, out=None) → Tensor

参数:

  • input (Tensor) – 源张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 聚合元素的下标
  • out (Tensor, optional) – 目标张量

首先来用自己的语言解释该函数的各个参数。

第一个参数就是自己要收集值的Tensor,不用多解释。

第二个参数就是指你要收集值的轴(也可以理解为行或列),如果是0,则按照横轴收集。如果是1,则按照纵轴收集。

第三个参数就是对应于你要搜集的Tensor的下标。

第四个参数一般缺省,这里不做细致的讨论。

下面直接上例子:

Torch.gather_pytorch

 首先,我们先创建一个2维的Tensor。

然后我们就采用torch.gather()函数来取这个Tensor里面的值。

Torch.gather_二维_02

 这里我们的dim取0,就是按照行进行取值,后面我们跟上一个Tensor,注意这里的Tensor一定要和我们前面的test  Tensor的维数相同,要不然会报错。【你当然也可以理解,从一个二维的Tensor取值,你当然也是进入一个二维的Tensor索引】

这里就是从上往下看,第一个取的索引值为0的值就是1,第二个取索引值为0的值就是2.【也就是第一列取了第一个,第二列取了第一个】

我们改成其他的值也可以很明显的看出来。

Torch.gather_python_03

 然后我们来看一下取得索引多一点:

Torch.gather_python_04

 这样行的就理解的差不多了,然后我们把那个dim改成1进行观察。

Torch.gather_取值_05

 这里我们进行解释一下,首先,我们还是那个tensor。

这里我们将dim改成1,当Tensor[[0,1]]时,第一个取的还是我们test Tensor中第一个值1,第二个就是索引值为1的2.

当tensor[[0,0],[0,0]]也是同理。

但是这里我们发现了不同,我们写入3个值时,却无法正常的取值。而上面的当dim=0时却可以正常的取值。那是因为这里我们是按照列进行取值的,我们这里的列数只有两个,所以当我们写第三个值的时候,就是取第三行的值,因为我们的test内的Tensor只有两行,就出现了错误。而与按行取的不同,每一次取的都是按照行索引值进行取值,所以,无论我们写多少个也不会报错。

Torch.gather_二维_06