torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

沿dim指定的轴聚集值。

对于三维张量,输出由以下公式指定:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

如果input是大小为(x0, x1…, xi−1, xi, xi+1, …, xn−1) 的n维张量并且dim = i,那么index必须是大小为(x0, x1…, xi−1, y, xi+1, …, xn−1) 的n维张量,并且 y >= 1,outindex具有相同的大小。

Parameters

  • input (Tensor) – 输入张量
  • dim (int) – 要索引的轴
  • index (LongTensor) – 要收集的元素的索引
  • sparse_grad (bool,optional) – 如果为True,梯度w.r.t。input将是一个稀疏张量。
  • out (Tensor, optional) – 目标张量

Example:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(dim=1, index=torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

dim=1 时,就是按列进行索引,dim=0 时,就是按行进行索引。
然后按照index去交换元素的位置。