Numpy是python中的一个很重要的科学计算库,而在使用numpy时,经常需要axis来指定运算的轴,在计算时会沿着指定轴进行运算。比如:np.max(), np.min(), np.mean(), np.sum()等等。

一维矩阵

一维矩阵,只有一个维度,所以只能指定axis=0或者不指定,这比较好理解。如下面的代码所示:

import numpy as np
a = np.arange(6)
print(a)
print("max: ", np.max(a, axis=0))  # 指定axis=0,沿着第0个维度计算

可以得到,沿着第0个维度的计算结果:

[0 1 2 3 4 5]
max:  5

二维矩阵

对于如下的 4x5 矩阵,第0维(轴)的方向,和第1维(轴)的方向,已经标记出来。如下图所示:

python如何找到矩阵某一元素的索引位置 python查看矩阵维数_机器学习


当我们用代码来对这个矩阵进行运算:

import numpy as np

a = np.arange(20).reshape(4, 5)
print("a矩阵为:\n", a)
print("a矩阵的维度:", a.shape)
print("-"*20)
a_0 = np.sum(a, axis=0)  # 沿着第0维运算
a_1 = np.sum(a, axis=1)  # 沿着第1维运算
print("axis = 0, sum: {}, a_0的维度为:{}".format(a_0, a_0.shape))
print("axis=1, sum: {}, a_1的维度为:{}".format(a_1, a_1.shape))

可以得到如下的运行结果:

a矩阵为:
 [[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]]
a矩阵的维度: (4, 5)
--------------------
axis = 0, sum: [30 34 38 42 46], a_0的维度为:(5,)
axis=1, sum: [10 35 60 85], a_1的维度为:(4,)

计算结果,可以这样来理解:

  1. 当指定 axis=0 时,运算是沿着第 0 个维度进行的,跨越不同的行,运算之后,第 0 个维度被“压扁”了,被消除了,只剩下第 1 个维度。
  2. 当指定 axis=1 时,运算是沿着第 1 个维度进行的,跨越不同的列,运算之后,第 1 个维度被“压扁”了,被消除了,只剩下第 0 个维度。

多维矩阵

多维矩阵的运算,我们这里使用三维矩阵为例,先来看一段代码:

import numpy as np

a = np.arange(24).reshape(2, 3, 4)
print("a矩阵为:\n", a)
print("a矩阵的维度:", a.shape)
print("="*30)
a_0 = np.sum(a, axis=0)
a_1 = np.sum(a, axis=1)
a_2 = np.sum(a, axis=2)

print("axis=0:")
print("sum: \n{}".format(a_0))
print("a_0的维度为:{}".format(a_0.shape))
print("-"*20)
print("axis=1:")
print("sum: \n{}".format(a_1))
print("a_1的维度为:{}".format(a_1.shape))
print("-"*20)
print("axis=2:")
print("sum: \n{}".format(a_2))
print("a_2的维度为:{}".format(a_2.shape))

上面这段代码运行的结果如下:

a矩阵为:
 [[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
a矩阵的维度: (2, 3, 4)
==============================
axis=0:
sum: 
[[12 14 16 18]
 [20 22 24 26]
 [28 30 32 34]]
a_0的维度为:(3, 4)
--------------------
axis=1:
sum: 
[[12 15 18 21]
 [48 51 54 57]]
a_1的维度为:(2, 4)
--------------------
axis=2:
sum: 
[[ 6 22 38]
 [54 70 86]]
a_2的维度为:(2, 3)

矩阵计算的维度变化,我们依然可以理解为:当沿着第 i 个维度计算时,这个维度会被“压扁”,最后消除,只留下其他的维度。这里a矩阵的维度为(2, 3, 4),当指定 axis=0 时,计算之后,第 0 个维度就被消除,只剩下维度为(3, 4)的二维矩阵。
但是,沿着第 i 个维度运算,这个不太好理解。什么叫沿着第 i 个维度?
这里参考了一下这个博客。二维矩阵比较好理解,可以直观地画出来,到了三维或者更高的维度,就不太好直观理解了。矩阵中每个数字都有一个索引下标,axis=i 就是沿着第 i 个下标变化的方向运算
上面的三维矩阵,当指定 axis=0 时,就是沿着第 0 个下标变化的方向运算。因为第 0 维被消除了,所以得到的矩阵维度为(3, 4)。

沿着第 0 维下标变化情况:
(0, 0, 0) -> (1, 0, 0) || (0, 0, 1) -> (1, 0, 1) || (0, 0, 2) -> (1, 0, 2) || (0, 0, 3) -> (1, 0, 3)
(0, 1, 0) -> (1, 1, 0) || (0, 1, 1) -> (1, 1, 1) || (0, 1, 2) -> (1, 1, 2) || (0, 1, 3) -> (1, 1, 3)
(0, 2, 0) -> (1, 2, 0) || (0, 2, 1) -> (1, 2, 1) || (0, 2, 2) -> (1, 2, 2) || (0, 2, 3) -> (1, 2, 3)
上面的三维矩阵,使用np.sum(a, axis=0)运算时,沿着第 0 维变化的方向相加,运算过程如下:
0+12, 1+13, 2+14, 3+15
4+16, 5+17, 6+18, 7+19
8+20, 9+21, 10+22, 11+23
最终得到运算结果:
[[12 14 16 18]
 [20 22 24 26]
 [28 30 32 34]]

当指定 axis=1 时,就是沿着第 1 个下标变化的方向运算。因为第 1 维被消除了,所以得到的矩阵维度为(2, 4)。

沿着第 1 维下标变化情况:
(0,0,0)->(0,1,0)->(0,2,0) || (0,0,1)->(0,1,1)->(0,2,1) || (0,0,2)->(0,1,2)->(0,2,2) || (0,0,3)->(0,1,3)->(0,2,3)
(1,0,0)->(1,1,0)->(1,2,0) || (1,0,1)->(1,1,1)->(1,2,1) || (1,0,2)->(1,1,2)->(1,2,2) || (1,0,3)->(1,1,3)->(1,2,3)
上面的三维矩阵,使用np.sum(a, axis=1)运算时,沿着第 1 维变化的方向相加,运算过程如下:
0+4+8, 1+5+9, 2+6+10, 3+7+11
12+16+20, 13+17+21, 14+18+22, 15+19+23
最终得到运算结果:
[[12 15 18 21]
 [48 51 54 57]]

其他维度的矩阵运算,采用相同的方法来理解即可。