最近加入了一个deeplearning的学习小组开始学习pytorch,初始对这个向量切片函数index_select()感到有些疑惑,经过自己一番实验之后,应该算是懂了吧,和大家一起分享一下实验结果。

index_select有两种用法,一种是将某一个张量(tensor)作为变量传入torch.index_select()函数,还有一个是tensor的内置方法index_select。用法分别是

a=tensor.tensor([1,2])
# first use
c = torch.index_select(a, 1, torch.tensor([0]))
# second use
d = a.index_select(1 , torch.tensor([0]))

这两种用法的区别和python中的sort以及sorted函数有些类似。第一种用法 c会成为一个新的张量且不会和a共用内存,而第二种用法d会和a共用内存 

然后,如何使用这个切片函数呢?首先,我们看一下官方文档的介绍。

torch.index_select(input, dim, index, out=None) → Tensor

沿着指定维度对输入进行切片,取index中指定的相应项(index为一个LongTensor),然后返回到一个新的张量, 返回的张量与原始张量_Tensor_有相同的维度(在指定轴上)。

注意: 返回的张量不与原始张量共享内存空间。

参数:

  • input (Tensor) – 输入张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 包含索引下标的一维张量
  • out (Tensor, optional) – 目标张量

其实只要搞清楚这里的 dim和index参数的含义就好。通常我们能够写出来的,直观的数组就是二维数组,可以把它想象成一个表格,行是第零维(注意索引是从0开始的),列是第一维,因此 dim是确定维度,而index是选择该维度上的那些切片。可能还是比较晦涩,下面我们举一些例子

c=torch.tensor([[1,2,3],[4,5,6]])
print(c)
d=torch.index_select(c,0,torch.tensor([1]))
print(d)

这里我们选择了第零维,也就是行,然后index参数是1,代表获取第一行(注意从0开始算起的),因此我们得到的结果是这样的

tensor([[1, 2, 3],
        [4, 5, 6]])
tensor([[4, 5, 6]])

二维张量比较容易理解,接下来我们理解三维张量,大家可以在脑子里将三维数组想象成类似这样的

python np张量局部平均 pytorch张量切片_数据挖掘

 

这里我构造了一个3*3*3的张量,然后根据第0维进行切片

a = [[[1,2,3], [4,5,6], [7,8,9]], [[11,22,33], [44,55,66], [77,88,99]], [[111,222,333], [444,555,666], [777,888,999]]]
a = torch.tensor(a)
print(a)
b = a.index_select(0, torch.tensor([0]))

这里我们切出来了第零维里的第零个数组,由于这是3*3*3的数组,因此第零维就是第一个3,然后index是[0]的话代表是第一个三中的第零个张量,其他两个3可以不管,照抄。因此结果是这样的。

tensor([[[  1,   2,   3],
         [  4,   5,   6],
         [  7,   8,   9]],

        [[ 11,  22,  33],
         [ 44,  55,  66],
         [ 77,  88,  99]],

        [[111, 222, 333],
         [444, 555, 666],
         [777, 888, 999]]])

tensor([[[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]])

同理,我们试一下根据第一维切片。

b = a.index_select(1, torch.tensor([0,2]))
print(b)

这里的第一维代表3*3*3中的第二个3,选择其中的第零和第二组数据,其他维度照抄,因此得到的结果是这样的

tensor([[[  1,   2,   3],
         [  7,   8,   9]],

        [[ 11,  22,  33],
         [ 77,  88,  99]],

        [[111, 222, 333],
         [777, 888, 999]]])