PyTorch中的gather函数
在PyTorch中,gather
函数是一个非常有用的函数,用于从一个张量(Tensor)中按照指定的索引提取元素。gather
函数的功能类似于数组的索引操作,可以在一个张量中根据指定的索引位置获取对应的元素或子集。本文将介绍gather
函数的使用方法,以及它在实际深度学习任务中的应用。
1. gather函数的基本用法
gather
函数的用法如下:
torch.gather(input, dim, index, out=None, sparse_grad=False) -> Tensor
其中,参数的含义如下:
input
:输入张量,即需要从中提取元素的张量。dim
:指定提取元素的维度,即在哪个维度上进行索引操作。index
:索引张量,即指定提取元素的索引位置。out
:输出张量,存放提取的元素。sparse_grad
:是否开启稀疏梯度计算。
举个例子,假设我们有一个形状为(4, 3)
的输入张量input
,和一个形状为(2, 2)
的索引张量index
,我们可以使用gather
函数在dim=0
的维度上提取对应索引位置的元素。具体代码如下:
import torch
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([[0, 1], [2, 3]])
output = torch.gather(input, 0, index)
print(output)
运行上述代码,将得到如下输出:
tensor([[ 1, 5],
[ 7, 12]])
可以看到,输出的张量output
的形状为(2, 2)
,其中每个元素都是根据索引张量index
从输入张量input
中提取出来的。
我们也可以在其他维度上进行索引操作,比如在dim=1
的维度上提取元素。具体代码如下:
output = torch.gather(input, 1, index)
print(output)
运行上述代码,将得到如下输出:
tensor([[ 1, 2],
[ 6, 7],
[ 9, 10],
[11, 12]])
可以看到,输出的张量output
的形状为(4, 2)
,其中每个元素都是根据索引张量index
从输入张量input
中提取出来的。
2. gather函数的应用场景
gather
函数在实际的深度学习任务中有很多应用场景,下面我们将介绍两个常见的应用场景。
2.1. 按索引取出指定元素
首先,gather
函数可以用于按索引取出指定位置的元素。在目标检测任务中,我们通常需要从一个包含所有检测框坐标的张量中,提取出指定索引的检测框坐标。具体代码如下:
import torch
# 输入张量,形状为(100, 4)
boxes = torch.randn(100, 4)
# 索引张量,形状为(10,)
indices = torch.tensor([5, 13, 27, 45, 62, 71, 80, 92, 98, 99])
# 提取指定索引位置的检测框坐标
selected_boxes = torch.gather(boxes, 0, indices.unsqueeze(1).expand(-1, 4))
上述代码中,我们首先定义了一个形状为(100, 4)
的输入张量boxes
,表示了100个检测框的坐标。然后,我们定义了一个形状为(10,)
的索引张量indices
,其中包含了我们要提取的检测框的索引。
最后,我们使用gather
函数将索引张量indices
应用于输入张量boxes
,并通过`unsqueeze(1).expand