1. 小试牛刀
前面已经说过,广播和矢量化是 NumPy 的精髓所在。所谓广播,就是将对数组的操作映射到每个数组元素上;矢量化可以理解为代码中没有显式的循环、索引等。如果用循环结构遍历 NumPy 数组,显然不符合 NumPy 的思想。可以说,使用 Numpy 的最高境界就是避免使用循环。如果代码中存在遍历 NumPy 数组的结构,就不是好的代码,就一定有优化空间。
下面,我们先用三个小问题来演示如何避免遍历 NumPy 数组。乍一看这三个问题,似乎不用循环就无法实现,但真正理解广播和矢量化的特性之后,我们就可以轻松解决。这三个问题,分别代表了 NumPy 数组最经典的三种应用模式。如果你已经读过《每天15分钟,5天学会NumPy》的前 4 篇,建议先尝试着不用循环的方法解决,然后再继续阅读后面的答案。
- ds 是随机数组,令小于0.5的元素为0,其余为1;
- ds 是随机整型数组,实现索引序号为偶数的元素自乘,其他元素不变;
- 酒精温度计是用液面高度映射温度的。一维数组 ca 表示液面高度和温度的对照关系:索引序号表示液面高度,对应的元素值表示温度。ds 是排成10行10列的100支酒精温度计的液面高度数据组成的整型组,值域范围是[0,100)mm。请将数组a的液面高度数据转换成温度数据。ca 可以用下面的算式模拟,ds 可以随机生成模拟数据。这个问题是我在处理卫星数据时遇到的:对近5亿个数据点做定标,使用循环结构来处理的话,耗时40分钟,利用 NumPy 数组的特性,只需要15秒钟。
ca = np.log(np.arange(100)+1)*3+15
问题1的参考答案
>>> ds = np.random.random(10)>>> dsarray([0.53173073, 0.20320829, 0.29968805, 0.25571728, 0.99315841, 0.99734874, 0.87743861, 0.69224332, 0.69648324, 0.67377661])>>> ds[ds<0.5] = 0>>> ds[ds>=0.5] = 1>>> dsarray([1., 0., 0., 0., 1., 1., 1., 1., 1., 1.])>>> ds = np.random.random(10) # 尝试另外一种方法>>> np.where(ds<0.5, 0, 1) # 使用where更简洁array([0, 1, 0, 1, 1, 0, 0, 1, 0, 1])
问题2的参考答案
>>> ds = np.random.randint(1, 10, 8)>>> dsarray([4, 9, 1, 4, 4, 7, 5, 5])>>> ds[::2] *= ds[::2]>>> dsarray([16, 9, 1, 4, 16, 7, 25, 5])
问题3的参考答案
>>> ca = np.log(np.arange(100)+1)*3+15>>> ds= np.random.randint(0, 100, (10, 10), dtype=np.uint8)>>> ca[ds]array([[27.1291538 , 24.40648265, 23.98719682, 28.03141627, 25.75055682, 23.1241506 , 27.61407786, 26.73606902, 28.36304189, 27.38140316], [28.72413294, 17.07944154, 28.1460799 , 26.61360303, 25.75055682, 28.56536573, 25.75055682, 25.83275374, 27.42940418, 25.91275848], [28.03141627, 22.45471995, 27.99220002, 26.7954769 , 27.23261233, 25.83275374, 27.95246434, 27.38140316, 28.36304189, 28.62988435], [27.91219528, 26.3525689 , 27.70231951, 28.81551056, 27.52316181, 28.03141627, 24.27312736, 26.91087574, 25.66604418, 24.53416149], [27.65852312, 27.65852312, 28.03141627, 25.66604418, 25.83275374, 21.23832463, 27.74548573, 26.1407162 , 28.75490244, 27.56896423], [23.83331694, 27.95246434, 28.32795377, 22.19368582, 25.83275374, 19.15888308, 17.07944154, 27.95246434, 28.75490244, 27.18132903], [23.98719682, 27.74548573, 24.77428961, 28.2924504 , 28.32795377, 27.1291538 , 22.45471995, 21.90775528, 28.66163067, 28.1460799 ], [19.15888308, 23.83331694, 28.03141627, 27.74548573, 28.49942901, 28.56536573, 25.91275848, 28.07012648, 21.90775528, 25.57908157], [27.99220002, 22.19368582, 28.59779848, 28.22015774, 24.65662747, 28.56536573, 28.49942901, 27.33262159, 24.77428961, 24.77428961], [27.61407786, 28.81551056, 22.69484807, 28.10834356, 22.69484807, 27.87137832, 17.07944154, 24.99661353, 28.32795377, 27.65852312]])那么,是否存在不用循环就解决不了的问题呢?这样的问题肯定存在,但我们可以自定义广播函数,将这个函数广播到数组的每一个元素上,以此来回避循环结构。只是,自定义的广播函数,执行效率不一定比循坏结构高。
2.
有一次高数考试,题出得特别难,全班都没有考好,及格的同学没几人。正当大家惶恐不安的时候,数学老师说,别紧张,每个人的成绩都做一次开方乘10,这样估计大家都及格了。可不是嘛,原先36分的同学,开方乘10,正好60,考了64分的同学,瞬间变成了80。于是乎,皆大欢喜,这位老师成功收获了每一位同学的拥戴。
严格说,当年我的数学老师的算法,不是开方乘10,而是:
- 得分除以100(归一化)
- 开方(提升)或乘方(降低)
- 乘以100(值域还原)
这个思路用在
下图左侧是一张老照片,暗区什么也看不出来。右侧是经过伽马校正之后的

from PIL import Imageimport numpy as npdef gamma_adjust(im, gamma=1.0): """伽马矫正""" return (np.power(im.astype(np.float32)/255, 1/gamma)*255).astype(np.uint8)im = Image.open(r'D:\CSDN\DSC_8510.jpg') # 打开照片im = np.array(im) # 将PIL对象转成NumPy数组im = gamma_adjust(im, gamma=2) # 伽马矫正,得到新的NumPy数组im = Image.fromarray(im, mode='RGB') # 将NumPy数组转成PIL对象im.save(r'D:\CSDN\DSC_8510_gamma15.jpg') # 将PIL对象保存成jpg图片这张测试用的照片大约有1200万像素,伽马矫正几乎是瞬间完成的。用循环方式遍历每一个像素,在我的电脑上耗时超过了60秒。有兴趣的同学可以自己找一张照片测试一下。
除了伽马矫正,几乎所有的
3. 代码加速
我曾经以“统计小于1亿的素数个数”为题,写过一篇讨论如何用 NumPy 数组代替 Python 列表实现加速运算的文章。下面的代码,使用 Python 列表,采用两层嵌套循环,找出100万以内的质数,大约需要3秒钟。查找范围再增加一个量级的话,耗时会远远超过30秒(我最终没有等到程序结束,就终止了运行)。
# -*- coding: utf-8 -*-import sys, timefrom math import sqrtdef find_prime(upper): """找出小于upper的所有质数""" prime_list = list() # 存放找到的质数 for i in range(2, upper): # 从2开始,逐一甄别是否是质数 is_prime = True # 假设当前数值i是质数 for p in prime_list: # 遍历当前已经找到的质数列表 if i%p == 0: is_prime = False break elif p > sqrt(i): break if is_prime: prime_list.append(i) return prime_listupper = 1000000t0 = time.time()prime_list = find_prime(upper)t1 = time.time()print('查找%d以内的质数耗时%0.3f秒,共找到%d个质数'%(upper, t1-t0, len(prime_list)))将 Python 列表替换为 NumPy 数组,可以不用循环的方式,从数组 nums 中将素数 p 的倍数全部设置为0:
nums[p::p] = 0下面的代码使用了 NumPy 数组,和使用 Python 列表的代码相比,少了内层的循环,找出1千万以内的质数,大约需要12秒钟,速度至少提高了一个量级。
import sys, timeimport numpy as npdef find_prime(upper): """找出小于upper的所有质数""" prime_list = list() # 空数组,用于存放找到的质素 mid = int(np.sqrt(upper)) # 判断100是否是质数,只需要分别用2,3...等质素去除100,看能否被整除,最多做到100的平方福根就够了 nums = np.arange(upper) # 生成0到上限的数组 nums[1] = 0 # 数组第1和元素置0,从2开始,都是非0的 while True: # 循环 primes = nums[nums>0] # 找出所有非0的元素 if primes.any(): # 如果能找到 p = primes[0] # 则第一个元素为质数 prime_list.append(p) # 保存第一个元素到返回的数组 nums[p::p] = 0 # 这个质数的所有倍数,都置为0(表示非质素) if p > mid: # 如果找到的质数大于上限的平方根 break # 则退出循环 else: break # 全部0,也退出循环 prime_list.extend(nums[nums>0].tolist()) # nums里面剩余的非0元素,都是质数,合并到返回的数组中 return prime_list # 返回结果upper = 10000000t0 = time.time()prime_list = find_prime(upper)t1 = time.time()print('查找%d以内的质数耗时%0.3f秒,共找到%d个质数'%(upper, t1-t0, len(prime_list)))
4. 旋转矩阵
在数学上,矩阵是一个复杂概念,但在 NumPy 中,矩阵对象 matrix 继承自数组对象 ndarray,这意味着,矩阵本质上是一个数组,拥有数组的所有属性和方法。同时,矩阵又有一些不同于数组的特性和方法。
- 首先,矩阵总是二维的,不能像数组一样可以幻化成任意维度,即使展平或者切片,返回也是二维的;
- 其次,矩阵和数组混合运算时,运算结果总是返回矩阵;
- 第三,矩阵的乘法不同于数组乘法。
在讲广播和矢量化的时候,我们已经知道,两个 NumPy 数组相乘,就是对应元素相乘,条件是两个数组的shape相同。事实上,即使两个数组的shape不同,只要满足特定条件,也能做乘法。请看演示:
>>> a = np.random.randint(0,10,(2,3))>>> aarray([[7, 2, 3], [1, 2, 7]])>>> b = np.random.randint(0,10,3)>>> barray([9, 6, 2])>>> a*barray([[63, 12, 6], [ 9, 12, 14]])>>> b*aarray([[63, 12, 6], [ 9, 12, 14]])实际上,数组还有另一种乘法,使用dot()函数的乘法
>>> np.dot(a,b)array([81, 35])我习惯把数组对应元素相乘,叫做星乘,把dot()函数实现的乘法,叫做点乘。NumPy 的点乘,就是矩阵乘法。这里需要说明一下,本课程中的星乘、点乘,是我个人的习惯用法,和纯数学里的点乘、叉乘并没有对应关系。对于数组而言,星乘和点乘,是完全不同的两种乘法,对于矩阵来说,不管是星乘还是点乘,结果都是一样的,都是点乘。矩阵没有对应元素相乘这个概念。那么,点乘,或者说矩阵乘法,究竟是怎么运算的呢?

不是所有的矩阵都能相乘。我们来看,矩阵A乘以矩阵B,二者可以相乘的条件是:A的列数必须等于B的行数。比如,a是4行2列,b是2行3列,axb,4223,没问题,但是反过来,bxa,2342,就无法运算了。可见,矩阵乘法,不满足交换律。再来看看乘法规则。概括说,就是A的各行逐一去乘B的各列。比如,A的第1行和b的第2列,元素个数一定相等,对应元素相乘后求和,作为结果矩阵第1行第2列的值。再比如,a的第3行和b的第3列,对应元素相乘后求和,作为结果矩阵第3行第3列的值。以此类推,我们就得到了矩阵A乘以矩阵B的结果矩阵。
那么,这个眼花缭乱的矩阵乘法,有什么实用价值吗?答案是:有,不但有,而且有非常大的使用价值。对于程序员来说,矩阵乘法最常见的应用是

下面,我们应用这个推导结果,定义一个函数,返回平面上的点围绕原点旋转给定角度后的坐标:
>>> def rotate(p,d): a = np.radians(d) m = np.array([[np.cos(a), np.sin(a)],[-np.sin(a), np.cos(a)]]) return np.dot(np.array(p), m)>>> rotate((5.7,2.8), 35) # 旋转35°array([3.06315263, 5.56301141])>>> rotate((5.7,2.8), 90) # 旋转90°array([-2.8, 5.7])>>> rotate((5.7,2.8), 180) # 旋转180°array([-5.7, -2.8])>>> rotate((5.7,2.8), 360) # 旋转360°array([5.7, 2.8])关于矩阵及其乘法,我们就讨论这么多。难度应该不大,但对于有数学恐惧症的程序员来说,会感到紧张。没关系,只要记住旋转矩阵的使用方法,即使不懂数学,也照样可以成为优秀的 Python程序员。
5. 求解线性方程组
求解线性方程组应该是小学高年级或初中一年级的数学问题,但要是用代码求解的话,其实并不容易,不相信的话,你可以去尝试一下。NumPy 的线性代数模块 linalg 提供了高效的解决方案。这个模块通常用来解决逆矩阵、特征值、线性方程组以及行列式等问题。我们来演示一下求解如图所示的方程组。

>>> a = np.array([[1,-2,1],[0,2,-8],[-4,5,9]])>>> b = np.array([0,8,-9])>>> np.linalg.solve(a, b)array([29., 16., 3.]) # x=29, y=16, z=3对于多元线性方程组,这是一个通用的解法,只要写出系数矩阵和常数项矩阵,调用NumPy的linalg.solve()即可一步求解
6. 求解非线性方程(组)
参加数学建模竞赛,是大学生最热衷的活动之一。工作中也需要数学建模。建模过程中,有时需要对一些稀奇古怪的方程(组)求解。求解非线性方程(组),Matlab是一个很好的选择,但Matlab是收费的,SciPy 则为我们提供了免费的午餐!
scipy.optimize库中的fsolve函数的基本调用形式:
fsolve(func, x0)func(x)是计算方程组误差的函数,它的参数x是一个矢量,表示方程组的各个未知数的一组可能解,func返回将x代入方程组之后得到的误差;x0为未知数矢量的初始值。
我们先从简单的非线性方程开始。

>>> from scipy.optimize import fsolve>>> def f(A): return [np.sin(np.radians(A[0])) - np.cos(np.radians(A[0])) - 0.2]>>> result = fsolve(f, (0,))>>> print(result) # 得到方程解为53.13010235°[53.13010235]>>> f(result) # 代入方程验证,误差小于在10的-15次方[-1.6653345369377348e-16]再来看非线性方程组的求解。

>>> def f(A): dev0 = 4*A[0]*A[0] - 2*np.sin(A[1]*A[2]) dev1 = 5*A[1] + 3 dev2 = A[0]*A[2] - 1.5 return [dev0, dev1, dev2]>>> result = fsolve(f, [1,1,1])>>> print(result) # x=0.27341748, y=-0.6, z=5.48611603[ 0.27341748 -0.6 5.48611603]>>> f(result) # 代入方程验证,误差小于在10的-14次方[-2.609024107869118e-15, 0.0, 5.10702591327572e-15]对于非线性方程组,这是一个通用的解法,先定义误差函数,再给出一组初始值,一般情况下就会得到非常近似的一组解。
7. 数值积分
我们就以求解半径为1的圆的面积为例,演示数值积分。先来复习一下经典微分法。所谓经典微分法,就是j将积分区间分成n等分,用n个矩形区域的面积之和近似曲线和积分区间封闭起来的区域面积。
>>> def circle(x): return np.power(1-x*x, 0.5)>>> x = np.linspace(-1, 1, 1000, endpoint=False) # 将[-,1]等分成1000段>>> dx = 2/1000>>> y = circle(x)>>> np.sum(y)*dx*2 # 圆面积(即pi)略小于3.14153.1414874770021406>>> x = np.linspace(-1, 1, 10000, endpoint=False) # 将[-,1]等分成10000段>>> dx = 2/10000>>> y = circle(x)>>> np.sum(y)*dx*2 # 圆面积(即pi)略小于3.141593.1415893274305816

Scipy提供了一个积分模块,使用积分模块,我们可以直接得到高精度的近似值:
>>> from scipy import integrate>>> def circle(x): return np.power(1-x*x, 0.5)>>> pi_half, err = integrate.quad(circle, -1, 1)>>> pi_half*2 # 这个精度比经典微分法分成10000段还要高很多3.1415926535897967
8. 数据插值
数据插值是数据处理过程中经常用到的技术,常用的插值有一维插值、二维插值、高阶插值等,常见的算法有线性插值、B样条插值、临近插值等。还有一种插值,网上讨论的较少,但在科研和工程领域使用率非常高,那就是散列数据插值到网格。
8.1 一维插值
下面的代码,展示了一维插值不同算法的效果:
# -*- coding: utf-8 -*-import numpy as npfrom scipy import interpolateimport matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['FangSong']plt.rcParams['axes.unicode_minus'] = Falsex = np.linspace(0,10,11)y = np.exp(-x/3.0)x_new = np.linspace(0,10,100) # 期望在0-10之间变成100个数据点f1 = interpolate.interp1d(x, y, kind='linear')f2 = interpolate.interp1d(x, y, kind='nearest')f3 = interpolate.interp1d(x, y, kind='zero')f4 = interpolate.interp1d(x, y, kind='slinear')f5 = interpolate.interp1d(x, y, kind='cubic')f6 = interpolate.interp1d(x, y, kind='quadratic')f7 = interpolate.interp1d(x, y, kind='previous')f8 = interpolate.interp1d(x, y, kind='next')plt.figure('Demo', facecolor='#eaeaea')plt.subplot(221)plt.plot(x, y, "o", label=u"原始数据")plt.plot(x_new, f2(x_new), label=u"临近点插值")plt.plot(x_new, f7(x_new), label=u"前点插值")plt.plot(x_new, f8(x_new), label=u"后点线性插值")plt.legend()plt.subplot(222)plt.plot(x, y, "o", label=u"原始数据")plt.plot(x_new, f1(x_new), label=u"线性插值")plt.plot(x_new, f3(x_new), label=u"零阶样条插值")plt.plot(x_new, f4(x_new), label=u"一阶样条插值")plt.legend()plt.subplot(223)plt.plot(x, y, "o", label=u"原始数据")plt.plot(x_new, f1(x_new), label=u"线性插值")plt.plot(x_new, f5(x_new), label=u"三阶样条插值")plt.legend()plt.subplot(224)plt.plot(x, y, "o", label=u"原始数据")plt.plot(x_new, f1(x_new), label=u"线性插值")plt.plot(x_new, f6(x_new), label=u"五阶样条插值")plt.legend()plt.show()

8.2 二维插值
咱们再来看看二维插值。既然是二维数据,通常总是对应着一个网格,比如,经纬度网格。如果插值对象只有一个二维数组,那么我们可以用数组的行列号来构造网格。
# -*- coding: utf-8 -*-import numpy as npfrom scipy import interpolateimport matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['FangSong']plt.rcParams['axes.unicode_minus'] = Falsey, x = np.mgrid[-2:2:20j,-3:3:30j]z = x*np.exp(-x**2-y**2)y_new, x_new = np.mgrid[-2:2:80j,-3:3:120j]f1 = interpolate.interp2d(x[0,:], y[:,0], z, kind='linear') # 线性插值f2 = interpolate.interp2d(x[0,:], y[:,0], z, kind='cubic') # 三阶样条f3 = interpolate.interp2d(x[0,:], y[:,0], z, kind='quintic') # 五阶样条z1 = f1(x_new[0,:], y_new[:,0])z2 = f2(x_new[0,:], y_new[:,0])z3 = f3(x_new[0,:], y_new[:,0])plt.subplot(221)plt.pcolor(x, y, z, cmap=plt.cm.hsv)plt.colorbar()plt.axis('equal')plt.subplot(222)plt.pcolor(x_new, y_new, z1, cmap=plt.cm.hsv)plt.colorbar()plt.axis('equal')plt.subplot(223)plt.pcolor(x_new, y_new, z2, cmap=plt.cm.hsv)plt.colorbar()plt.axis('equal')plt.subplot(224)plt.pcolor(x_new, y_new, z3, cmap=plt.cm.hsv)plt.colorbar()plt.axis('equal')plt.show()不同插值算法的效果比较如下图:

8.3
散列数据插值到网格
8.4
高阶快插
9. 拟合
插值和拟合看起来有一些相似,所以初学者比较容易混淆,实际上二者是完全不同的概念。最常见的拟合方法是多项式拟合,寻找多项式最佳系数的方法叫做最小二乘法。对此话题感兴趣的同学,请阅读文章《从寻找谷神星的过程,谈最小二乘法实现多项式拟合》
10. 数据平滑
11. k-means均值算法
12. HDF 和 netCDF
HDF(Hierarchical Data File),意为多层数据文件,是美国国家高级计算应用中心(National Center for Supercomputing Application, NCSA)为了满足各种领域研究需求而研制的一种能高效存储和分发科学数据的新型数据格式。HDF可以表示出科学数据存储和分布的许多必要条件。HDF提供6种基本数据类型:光栅
NetCDF(network Common Data Form),意为网络通用数据格式,是由美国大学大气研究协会(University Corporation for Atmospheric Research,UCAR)的Unidata项目科学家针对科学数据的特点开发的,是一种面向数组型并适于网络共享的数据的描述和编码标准。目前,NetCDF广泛应用于大气科学、水文、海洋学、环境模拟、地球物理等诸多领域。
HDF 和 netCDF 不仅仅美国人在用,我们中国也在用,尤其是空间科学、大气科学、地球物理等领域,几乎所有的数据分发,都依赖这两种格式的文件。解读这两种数据文件,需要安装两个模块:h5py和netCDF4,下载地址:https://
咱们先演示一下如何解读HDF数据。
接下来,咱们再用一个实际的应用问题演示一下 netCDF 数据文件的处理。该问题是一位研一在读的学生发给我的,他花费了很长时间也没有解决好,最后求助于我。问题是这样的:现有若干极轨卫星的 netCDF 格式的数据文件,每个文件有时间、经度、纬度、海平面高度异常数据等4个数据集,要求从全部文件中提取指定经纬度范围和指定时间范围内的所有数据。我们先看一下单个数据文件的数据集结构和值域范围、数据类型:
给定经纬度范围和时间范围,我们就可以从 netCDF4 数据文件中提取出符合条件的数据了。函数fetch_from_nc()实现了该功能:
def fetch_from_nc(dlon, dlat, dtime, nc): """从单个nc文件中提取符合要求的数据 dlon - 经度范围,元组 dlat - 纬度范围,元组 dtime - 时间范围,元组 nc - 数据文件名 return - 二维数组 """ with netCDF4.Dataset(nc, mode='r', format="NETCDF4") as fp: ssha = fp.variables['ssha'][:] time = fp.variables['time'][:] lon = fp.variables['lon'][:] lat = fp.variables['lat'][:] filter = (lon>=dlon[0])&(lon<=dlon[1])&(lat>=dlat[0])&(lat<=dlat[1])&(time>=dtime[0])&(time<=dtime[1]) return np.stack((time[filter], lon[filter], lat[filter], ssha[filter]), axis=1
















