当我们拿到一个numpy数组后,有时候我们并不是对整个数组的元素感兴趣,可能只想针对数组的某一个元素或者某一部分元素进行某些操作,而选中数组中的某个数据子集正是切片和索引的意义所在。
        切片的意义在于取得数组中某一“矩形”子集的数据,比如“第一行的第2到第5个元素”,“第3列的全部元素”等,提取出原数组中处于某一个“矩形”范围内的数据;而索引的意义在于提取满足一定条件的元素的子集,这个条件可能是某个具体的位置(基础的位置索引),也可能是对应位置的bool值(bool索引),也可能是打乱原有数组相对位置的位置索引(神奇索引)。在平时使用的时候往往是两者相辅相成,并没有进行非常明显的区分,因此本文中也不对两者进行进一步地区分。

1. 基础的索引与切片

numpy中的最一般的索引与切片与python中对于列表的索引和切片基本是一样的。

1.1 一维数组的索引与切片

arr1 = np.arange(12)
print(arr1)
print(arr1[1])
print(arr1[-1])
print(arr1[2: 6])
print(arr1[-5: -2])
print(arr1[:4])
print(arr1[6:])
print(arr1[6: -2])
[ 0 1 2 3 4 5 6 7 8 9 10 11]

 1

 11

 [2 3 4 5]

 [7 8 9]

 [0 1 2 3]

 [ 6 7 8 9 10 11]

 [6 7 8 9]

这里需要注意的有以下几点:

  • 数组的下标是从0开始的(这应该是常识了)
  • 可以使用负索引,-1表示最后一个元素
  • 切片[2: 6][-5: -2]都是左闭右开区间,即包括左侧边界元素不包括右侧边界元素
  • 切片中冒号(:)的左侧没有值,表明是从第一个元素开始,如果右侧没有值,则表明取到最后一个元素(包括最后一个元素)
  • 正索引可以和负索引共同使用,但是需要两者组成的实际区间有意义

另外,numpy数组的切片总是原数组的一个视图,在视图上的修改会反映到原数组上

arr1[2: 5] = 24
print(arr1)
slice1 = arr1[6: 8]
slice1[0] = 12
print(slice1)
print(arr1)
[ 0 1 24 24 24 5 6 7 8 9 10 11]

 [12 7]

 [ 0 1 24 24 24 5 12 7 8 9 10 11]

这里我发现了一个小问题:

arr1 = np.arange(12)
print(arr1)
slice2 = arr1[2: 5]
print(slice2)
slice2 = 12
print(slice2)
print(arr1)
[ 0 1 2 3 4 5 6 7 8 9 10 11]

 [2 3 4]

 12

 [ 0 1 2 3 4 5 6 7 8 9 10 11]

可以看到,此时对切片的修改并没有反映到原数组,个人感觉可能的原因是,这个切片本身有三个元素,而对切片本身进行少于三个元素的赋值只能改变切片,数组本身不知道该怎么去改变,所以就没有发生变化。(可能会想到数组的广播,由切片赋值的变量本身也是一个数组,由于广播机制,所以slice2本身发生了变化,而这里的广播可能并不能"传递"吧,个人理解,欢迎纠错)

如果需要得到的不是数组的一个视图,而是一个复制的单独的数组,可是使用copy()方法。

arr1 = np.arange(12)
print(arr1)
slice2 = arr1[2: 5].copy()
print(slice2)
slice2[0] = 12
print(slice2)
print(arr1)
[ 0 1 2 3 4 5 6 7 8 9 10 11]

 [2 3 4]

 [12 3 4]

 [ 0 1 2 3 4 5 6 7 8 9 10 11]

可以看出在使用copy()方法后,对于切片赋值的变量的修改并不会影响原数组。

1.2 二(高)维数组的索引与切片

在说明了一维数组的基础索引与切片的相关概念与操作后,对于二维或者高维数组而言,只不过是索引和切片的在二维或高维数组上的推广而已,本质上没有什么差别,但是在实际操作过程中还有些问题需要注意。

arr2 = np.array([[1, 2, 3], [4, 5, 6]])
print(arr2)
print(arr2.shape)     # 数组的形状
print(arr2[0])        # 数组的第一行
print(arr2[-1])       # 数组的最后一行
print(arr2[1, 2])     # 索引为(1, 2)的元素
print(arr2[1][2])     # 索引为(1, 2)的元素
print(arr2[1, :])     # 数组的第二行
print(arr2[:, 2])     # 数组的第三列
print(arr2[1, :2])    # 数组第二行中前两个元素
print(arr2[:-1, 2])   # 数组第三列中,除了最后一个元素
[[1 2 3]
 [4 5 6]]

 (2, 3) 

 [1 2 3]

 [4 5 6]

 6

 6

 [4 5 6]

 [3 6]

 [4 5]

 [3]

这里需要注意的几点:

  • 二维数组使用1个下标进行索引得到的是一个一维数组,用2个下标得到的是具体的某个元素。(n维数组使用1个下标得到的是n-1维数组,使用n个下标得到的是具体的元素,使用k个下标得到的是n-k维的数组)
  • [1, 2][1][2]numpy中的意义相同
  • 各个维度的切片需要注意值的越界问题,这一点在二维和更高维数组中尤其重要

二维或高维数组的切片视图问题与一位数组一样,在此不再赘述。

2. 布尔索引

当我们需要得到数组中满足一定条件的元素的集合的时候,比如“数组中大于3的所有元素”等,往往首先得到的是一个布尔数组,将这个布尔数组放置于索引的位置,就得到满足条件的元素集合,这就是布尔索引。

records = np.array(["a", "b", "c", "b", "c", "d", "a"])
data = np.random.randn(7, 4)
print(records)
print(data)
[‘a’ ‘b’ ‘c’ ‘b’ ‘c’ ‘d’ ‘a’]

 [[ 0.46956577 -0.5920976 -1.04230422 0.2345082 ]
 [ 1.115396 -0.76998952 -0.34218191 0.48640238]
 [-0.50296235 0.57976796 -0.18288928 0.38151529]
 [ 0.34595951 -0.89662156 -0.34394609 -0.77178702]
 [ 1.15546732 -1.13124287 1.00495337 -1.3925357 ]
 [-1.87583627 -0.25535481 -0.06034188 -0.64896447]
 [ 1.69924803 0.88408588 0.60332694 -0.76869852]]

这里给出了两个数组,records数组表示data数组中每一行数据所属的类别,如果这时候我们需要仅仅得到类别为"b"的数据,就需要使用到布尔索引:

records == "b"
array([False, True, False, True, False, False, False])

records == "b"得到一个布尔数组,布尔数值中的元素表示原数组中对应位置元素是否是"b"True表示是,False表示不是。需要注意的是该布尔数组的长度与records数组长度一致。

data[records == "b"]
array([[ 1.115396 , -0.76998952, -0.34218191, 0.48640238],
 [ 0.34595951, -0.89662156, -0.34394609, -0.77178702]])

可以看到,结果证实类别"b"所对应的那两行。这里需要特别注意的是布尔数组对应的是原数组的行而不是列,布尔数组的长度必须与原数组的行数相等,但是如果布尔数组的长度与原数组的行数不相等时也不会报错,但是结果就不是想要的结果。

除了简单的布尔索引之外,我们还可以进行复杂的布尔索引:

data[records == "b", 1:]  # 得到除第一列之外的元素
array([[-0.76998952, -0.34218191, 0.48640238],
 [-0.89662156, -0.34394609, -0.77178702]])
data[~(records == "b")]   # 得到不是"b"类别的数据
array([[ 0.46956577, -0.5920976 , -1.04230422, 0.2345082 ],
 [-0.50296235, 0.57976796, -0.18288928, 0.38151529],
 [ 1.15546732, -1.13124287, 1.00495337, -1.3925357 ],
 [-1.87583627, -0.25535481, -0.06034188, -0.64896447],
 [ 1.69924803, 0.88408588, 0.60332694, -0.76869852]])
data[(records == "a") | (records == "b")]      # 得到"a"或"b"类对应的数据
array([[ 0.46956577, -0.5920976 , -1.04230422, 0.2345082 ],
 [ 1.115396 , -0.76998952, -0.34218191, 0.48640238],
 [ 0.34595951, -0.89662156, -0.34394609, -0.77178702],
 [ 1.69924803, 0.88408588, 0.60332694, -0.76869852]])

这里需要注意的是:

  • 在numpy中使用~ | & 表示逻辑上的"非""或""与",而不是python中使用的"not""or""and"
  • 布尔索引生成的数组都是原数组的拷贝,不是视图

还有一个我个人经常遇到的操作,就是讲数组中大于某一数值的元素全部替换为某一数值:

data[data>0] = 10
data
array([[10. , -0.5920976 , -1.04230422, 10. ],
 [10. , -0.76998952, -0.34218191, 10. ],
 [-0.50296235, 10. , -0.18288928, 10. ],
 [10. , -0.89662156, -0.34394609, -0.77178702],
 [10. , -1.13124287, 10. , -1.3925357 ],
 [-1.87583627, -0.25535481, -0.06034188, -0.64896447],
 [10. , 10. , 10. , -0.76869852]])

3. 神奇索引

看到很多讲解numpy的书籍中都会提到numpy的神奇索引,乍一看好像很神奇的样子,其实看完之后也觉得很一般,可能神奇之处在于神奇索引得到的结果与原数组中同一元素的位置不同吧。

arr = np.random.randn(6, 4)
print(arr)
print(arr[[3, 5, 2]])     # 得到行索引为3,5,2的集合
print(arr[[-2, -1, -4]])  # 得到行索引为-2,-1,-4的集合
print(arr[[3, 4, 2], [2, 3, 1]])
print(arr[[3, 4, 2]][:, [2, 3, 1]])
[[ 0.26117421 -0.32260913 0.36688278 0.46538055]
 [ 1.10905033 0.57096941 0.04658424 -0.11849178]
 [ 1.73825248 1.08860368 0.21497244 0.18495274]
 [ 1.97775351 -1.08430494 0.6280707 1.00538162]
 [-0.2373218 0.4812642 0.25606079 2.60419572]
 [ 2.03303289 1.07603077 -0.37909895 -1.95541141]]

 [[ 1.97775351 -1.08430494 0.6280707 1.00538162]
 [ 2.03303289 1.07603077 -0.37909895 -1.95541141]
 [ 1.73825248 1.08860368 0.21497244 0.18495274]]

 [[-0.2373218 0.4812642 0.25606079 2.60419572]
 [ 2.03303289 1.07603077 -0.37909895 -1.95541141]
 [ 1.73825248 1.08860368 0.21497244 0.18495274]]

 [0.6280707 2.60419572 1.08860368] 

 [[ 0.6280707 1.00538162 -1.08430494]
 [ 0.25606079 2.60419572 0.4812642 ]
 [ 0.21497244 0.18495274 1.08860368]]

这里需要注意的是:
-神奇索引返回的是原数组的拷贝,并不是视图
-后两个结果中的arr[[3, 4, 2], [2, 3, 1]]表示得到索引为(3, 2)(4,3)(2,1)的三个元素
arr[[3, 4, 2]][:, [2, 3, 1]]表示arr[[3, 4, 2]]结果的列索引为231的三列,是两次运用了神奇索引

另外numpy中还有与神奇索引等价的方法take()put()

print(arr.take([3, 5, 2], axis=0))
print(arr.take([-2, -1, -4], axis=0))
[[ 1.97775351 -1.08430494 0.6280707 1.00538162]
 [ 2.03303289 1.07603077 -0.37909895 -1.95541141]
 [ 1.73825248 1.08860368 0.21497244 0.18495274]]

 [[-0.2373218 0.4812642 0.25606079 2.60419572]
 [ 2.03303289 1.07603077 -0.37909895 -1.95541141]
 [ 1.73825248 1.08860368 0.21497244 0.18495274]]

这里需要说明的是:
-take()方法中需要指定参数axis=0可以与神奇索引的结果一致,当指定axis为其他值的时候,可以在其他轴上进行与神奇索引等价的操作,如果不指定axis的值,那么就视为将数组进行扁平操作为一维数组,在进行在一维数组上的神奇索引操作。take()方法可以视为神奇索引的强化版本
-put()方法用来修改数组中相应位置的元素的值,没有axis参数,统一为数组扁平为一维数组后的位置

arr.put([3, 5, 2], [100, 200, 300])
print(arr)
[[ 2.61174210e-01 -3.22609125e-01 3.00000000e+02 1.00000000e+02]
 [ 1.10905033e+00 2.00000000e+02 4.65842427e-02 -1.18491782e-01]
 [ 1.73825248e+00 1.08860368e+00 2.14972440e-01 1.84952740e-01]
 [ 1.97775351e+00 -1.08430494e+00 6.28070700e-01 1.00538162e+00]
 [-2.37321803e-01 4.81264197e-01 2.56060789e-01 2.60419572e+00]
 [ 2.03303289e+00 1.07603077e+00 -3.79098951e-01 -1.95541141e+00]]