PyTorch中的gather函数

在PyTorch中,gather函数是一个非常有用的函数,用于从一个张量(Tensor)中按照指定的索引提取元素。gather函数的功能类似于数组的索引操作,可以在一个张量中根据指定的索引位置获取对应的元素或子集。本文将介绍gather函数的使用方法,以及它在实际深度学习任务中的应用。

1. gather函数的基本用法

gather函数的用法如下:

torch.gather(input, dim, index, out=None, sparse_grad=False) -> Tensor

其中,参数的含义如下:

  • input:输入张量,即需要从中提取元素的张量。
  • dim:指定提取元素的维度,即在哪个维度上进行索引操作。
  • index:索引张量,即指定提取元素的索引位置。
  • out:输出张量,存放提取的元素。
  • sparse_grad:是否开启稀疏梯度计算。

举个例子,假设我们有一个形状为(4, 3)的输入张量input,和一个形状为(2, 2)的索引张量index,我们可以使用gather函数在dim=0的维度上提取对应索引位置的元素。具体代码如下:

import torch

input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([[0, 1], [2, 3]])

output = torch.gather(input, 0, index)
print(output)

运行上述代码,将得到如下输出:

tensor([[ 1,  5],
        [ 7, 12]])

可以看到,输出的张量output的形状为(2, 2),其中每个元素都是根据索引张量index从输入张量input中提取出来的。

我们也可以在其他维度上进行索引操作,比如在dim=1的维度上提取元素。具体代码如下:

output = torch.gather(input, 1, index)
print(output)

运行上述代码,将得到如下输出:

tensor([[ 1,  2],
        [ 6,  7],
        [ 9, 10],
        [11, 12]])

可以看到,输出的张量output的形状为(4, 2),其中每个元素都是根据索引张量index从输入张量input中提取出来的。

2. gather函数的应用场景

gather函数在实际的深度学习任务中有很多应用场景,下面我们将介绍两个常见的应用场景。

2.1. 按索引取出指定元素

首先,gather函数可以用于按索引取出指定位置的元素。在目标检测任务中,我们通常需要从一个包含所有检测框坐标的张量中,提取出指定索引的检测框坐标。具体代码如下:

import torch

# 输入张量,形状为(100, 4)
boxes = torch.randn(100, 4)

# 索引张量,形状为(10,)
indices = torch.tensor([5, 13, 27, 45, 62, 71, 80, 92, 98, 99])

# 提取指定索引位置的检测框坐标
selected_boxes = torch.gather(boxes, 0, indices.unsqueeze(1).expand(-1, 4))

上述代码中,我们首先定义了一个形状为(100, 4)的输入张量boxes,表示了100个检测框的坐标。然后,我们定义了一个形状为(10,)的索引张量indices,其中包含了我们要提取的检测框的索引。

最后,我们使用gather函数将索引张量indices应用于输入张量boxes,并通过`unsqueeze(1).expand