最近加入了一个deeplearning的学习小组开始学习pytorch,初始对这个向量切片函数index_select()感到有些疑惑,经过自己一番实验之后,应该算是懂了吧,和大家一起分享一下实验结果。
index_select有两种用法,一种是将某一个张量(tensor)作为变量传入torch.index_select()函数,还有一个是tensor的内置方法index_select。用法分别是
a=tensor.tensor([1,2])
# first use
c = torch.index_select(a, 1, torch.tensor([0]))
# second use
d = a.index_select(1 , torch.tensor([0]))
这两种用法的区别和python中的sort以及sorted函数有些类似。第一种用法 c会成为一个新的张量且不会和a共用内存,而第二种用法d会和a共用内存
然后,如何使用这个切片函数呢?首先,我们看一下官方文档的介绍。
torch.index_select(input, dim, index, out=None) → Tensor
沿着指定维度对输入进行切片,取index
中指定的相应项(index
为一个LongTensor),然后返回到一个新的张量, 返回的张量与原始张量_Tensor_有相同的维度(在指定轴上)。
注意: 返回的张量不与原始张量共享内存空间。
参数:
- input (Tensor) – 输入张量
- dim (int) – 索引的轴
- index (LongTensor) – 包含索引下标的一维张量
- out (Tensor, optional) – 目标张量
其实只要搞清楚这里的 dim和index参数的含义就好。通常我们能够写出来的,直观的数组就是二维数组,可以把它想象成一个表格,行是第零维(注意索引是从0开始的),列是第一维,因此 dim是确定维度,而index是选择该维度上的那些切片。可能还是比较晦涩,下面我们举一些例子
c=torch.tensor([[1,2,3],[4,5,6]])
print(c)
d=torch.index_select(c,0,torch.tensor([1]))
print(d)
这里我们选择了第零维,也就是行,然后index参数是1,代表获取第一行(注意从0开始算起的),因此我们得到的结果是这样的
tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[4, 5, 6]])
二维张量比较容易理解,接下来我们理解三维张量,大家可以在脑子里将三维数组想象成类似这样的
这里我构造了一个3*3*3的张量,然后根据第0维进行切片
a = [[[1,2,3], [4,5,6], [7,8,9]], [[11,22,33], [44,55,66], [77,88,99]], [[111,222,333], [444,555,666], [777,888,999]]]
a = torch.tensor(a)
print(a)
b = a.index_select(0, torch.tensor([0]))
这里我们切出来了第零维里的第零个数组,由于这是3*3*3的数组,因此第零维就是第一个3,然后index是[0]的话代表是第一个三中的第零个张量,其他两个3可以不管,照抄。因此结果是这样的。
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[ 11, 22, 33],
[ 44, 55, 66],
[ 77, 88, 99]],
[[111, 222, 333],
[444, 555, 666],
[777, 888, 999]]])
tensor([[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]])
同理,我们试一下根据第一维切片。
b = a.index_select(1, torch.tensor([0,2]))
print(b)
这里的第一维代表3*3*3中的第二个3,选择其中的第零和第二组数据,其他维度照抄,因此得到的结果是这样的
tensor([[[ 1, 2, 3],
[ 7, 8, 9]],
[[ 11, 22, 33],
[ 77, 88, 99]],
[[111, 222, 333],
[777, 888, 999]]])