Python 常因运行速度慢而被诟病。其实优化代码性能的方案有很多,最主要的一般是两个方向:
- 从算法本身优化,这是最根本和彻底的优化;
- 从语言本身和工具层面优化,不外乎类型化(静态类型声明)、空间换时间(缓存)等。
刚好看到一篇文章讲得比较好,在此整理一下。
一、Fibonacci 函数
常见的 Fibonacci 函数实现有两种:一种是递归,一种是非递归,即迭代(好像是废话……)。
- 常规代码
def fib(n):
    """Return the n-th Fibonacci number using naive double recursion.

    Any ``n`` below 2 is returned unchanged, so ``fib(0) == 0`` and
    ``fib(1) == 1``. Every other call spawns two further recursive calls,
    so the amount of work grows exponentially with ``n``.
    """
    return n if n < 2 else fib(n - 1) + fib(n - 2)
def fib_seq(n):
    """Return the n-th Fibonacci number using a simple loop.

    Any ``n`` below 2 is returned unchanged. Otherwise the two most recent
    terms are advanced ``n - 1`` times, starting from F(1) = 1 and F(0) = 0.
    """
    if n < 2:
        return n
    current = 1
    previous = 0
    for _ in range(n - 1):
        following = current + previous
        previous = current
        current = following
    return current
# Time the recursive and the loop-based pure-Python versions for n = 20.
%timeit fib(20)
%timeit fib_seq(20)
100 loops, best of 3: 2.28 ms per loop
1000000 loops, best of 3: 1.04 µs per loop
- cython 版本
%%cython
def fib_cython(n):
    """Return the n-th Fibonacci number using naive double recursion.

    Untyped variant meant to be compiled in a Cython cell. Any ``n`` below 2
    is returned unchanged.
    """
    if n < 2:
        return n
    one_back = fib_cython(n - 1)
    two_back = fib_cython(n - 2)
    return one_back + two_back
def fib_seq_cython(n):
    """Return the n-th Fibonacci number using a simple loop.

    Untyped variant meant to be compiled in a Cython cell. Any ``n`` below 2
    is returned unchanged; otherwise the pair of latest terms is advanced
    ``n - 1`` times.
    """
    if n < 2:
        return n
    latest, older = 1, 0
    for _ in range(n - 1):
        latest, older = latest + older, latest
    return latest
cpdef long fib_seq_cython_type(long n):
    """Return the n-th Fibonacci number using a loop over C ``long`` values.

    Any ``n`` below 2 is returned unchanged. The argument, the running pair
    and the loop index are all declared as C ``long``.

    NOTE: the result is a C ``long``, so values that do not fit in that type
    cannot be represented (presumably n > 92 with a 64-bit ``long``).
    """
    cdef long a, b, i
    if n < 2:
        return n
    a, b = 1, 0
    for i in range(n - 1):
        # The second slot must take the previous value of ``a``; assigning
        # ``b`` to itself would leave it at 0 and the result stuck at 1.
        a, b = a + b, a
    return a
# Time the three functions defined in the Cython cell for n = 20.
%timeit fib_cython(20)
%timeit fib_seq_cython(20)
%timeit fib_seq_cython_type(20)
1000 loops, best of 3: 741 µs per loop
1000000 loops, best of 3: 599 ns per loop
10000000 loops, best of 3: 33 ns per loop
- cache 版本
from functools import lru_cache as cache
@cache(maxsize=None)
def fib_cache(n):
    """Return the n-th Fibonacci number using memoised double recursion.

    The function is wrapped in an unbounded ``lru_cache``, so each distinct
    ``n`` is computed only once and later calls are answered from the cache.
    Any ``n`` below 2 is returned unchanged.
    """
    return n if n < 2 else fib_cache(n - 1) + fib_cache(n - 2)
@cache(maxsize=None)
def fib_seq_cache(n):
    """Return the n-th Fibonacci number using a loop, with results memoised.

    The function is wrapped in an unbounded ``lru_cache``, so a repeated call
    with the same ``n`` is answered from the cache. Any ``n`` below 2 is
    returned unchanged.
    """
    if n < 2:
        return n
    current = 1
    previous = 0
    for _ in range(n - 1):
        following = current + previous
        previous = current
        current = following
    return current
# Time the lru_cache-wrapped versions for n = 20.
# NOTE(review): both functions use an unbounded cache, so once the first call
# has stored the result, repeated calls with the same argument presumably
# measure a cache lookup rather than the Fibonacci computation itself.
%timeit fib_cache(20)
%timeit fib_seq_cache(20)
10000000 loops, best of 3: 84.1 ns per loop
10000000 loops, best of 3: 87.9 ns per loop
- numba 版本
from numba import jit
@jit
def fib_numba(n):
    """Return the n-th Fibonacci number using naive double recursion.

    The function is decorated with numba's ``jit``. Any ``n`` below 2 is
    returned unchanged; every other call recurses on ``n - 1`` and ``n - 2``.
    """
    # Base case first: this gives the function a return path that does not
    # recurse.
    if n<2:
        return n
    return fib_numba(n-1)+fib_numba(n-2)
@jit
def fib_seq_numba(n):
    """Return the n-th Fibonacci number using a simple loop.

    The function is decorated with numba's ``jit``. Any ``n`` below 2 is
    returned unchanged; otherwise the pair of latest terms is advanced
    ``n - 1`` times, starting from F(1) = 1 and F(0) = 0.
    """
    if n < 2:
        return n
    current, previous = 1, 0
    for _ in range(n - 1):
        current, previous = current + previous, current
    return current
# Time the numba-decorated versions for n = 20.
# NOTE(review): the first call of a @jit function presumably triggers
# compilation; the "best of 3" figure should exclude that cost - confirm.
%timeit fib_numba(20)
%timeit fib_seq_numba(20)
10000 loops, best of 3: 47.6 µs per loop
10000000 loops, best of 3: 164 ns per loop
二、快速排序
- 常规代码
def qsort_kernel(a, lo, hi):
    """Quicksort ``a[lo..hi]`` (both bounds inclusive) in place; return ``a``.

    Each pass partitions the current range around the value at its middle
    index, recurses into the left part, and then carries on with the right
    part inside the same loop instead of making a second recursive call.
    """
    left, right = lo, hi
    while left < hi:
        pivot = a[(lo + hi) // 2]
        while left <= right:
            # Move each cursor past the elements already on its correct side.
            while a[left] < pivot:
                left += 1
            while a[right] > pivot:
                right -= 1
            if left > right:
                break
            a[left], a[right] = a[right], a[left]
            left += 1
            right -= 1
        # Left part [lo, right] needs sorting only if it holds 2+ elements.
        if lo < right:
            qsort_kernel(a, lo, right)
        # Continue with the right part [left, hi] on the next pass.
        lo, right = left, hi
    return a
import random
# Build the benchmark input; range(1, 5000) yields 4999 random floats.
lst = [ random.random() for i in range(1,5000) ]
# NOTE(review): qsort_kernel sorts lst in place, so every timeit iteration
# after the first receives input that is already sorted.
%timeit qsort_kernel(lst, 0, len(lst)-1)
100 loops, best of 3: 4.95 ms per loop
- cython 版本
%%cython
def qsort_cython(a, lo, hi):
    """Quicksort ``a[lo..hi]`` (both bounds inclusive) in place; return ``a``.

    Untyped variant meant to be compiled in a Cython cell. Each pass
    partitions the current range around the value at its middle index,
    recurses into the left part, and then continues with the right part in
    the same loop.
    """
    left, right = lo, hi
    while left < hi:
        pivot = a[(lo + hi) // 2]
        while left <= right:
            # Move each cursor past the elements already on its correct side.
            while a[left] < pivot:
                left += 1
            while a[right] > pivot:
                right -= 1
            if left > right:
                break
            a[left], a[right] = a[right], a[left]
            left += 1
            right -= 1
        # Left part [lo, right] needs sorting only if it holds 2+ elements.
        if lo < right:
            qsort_cython(a, lo, right)
        # Continue with the right part [left, hi] on the next pass.
        lo, right = left, hi
    return a
# Fresh benchmark input; range(1, 5000) yields 4999 random floats.
lst = [ random.random() for i in range(1,5000) ]
# NOTE(review): qsort_cython sorts lst in place, so every timeit iteration
# after the first receives input that is already sorted.
%timeit qsort_cython(lst, 0, len(lst)-1)
100 loops, best of 3: 2.47 ms per loop
- numba 版本
@jit
def qsort_numba(a, lo, hi):
    """Quicksort ``a[lo..hi]`` (both bounds inclusive) in place; return ``a``.

    The function is decorated with numba's ``jit``. Each pass partitions the
    current range around the value at its middle index, recurses into the
    left part, and then continues with the right part in the same loop.
    """
    i = lo
    j = hi
    while i < hi:
        # Pivot value is read from the middle index of the current range.
        pivot = a[(lo+hi) // 2]
        while i <= j:
            # Move each cursor past the elements already on its correct side.
            while a[i] < pivot:
                i += 1
            while a[j] > pivot:
                j -= 1
            if i <= j:
                # Exchange the out-of-place pair and step both cursors inward.
                a[i], a[j] = a[j], a[i]
                i += 1
                j -= 1
        # Left part [lo, j] needs sorting only if it holds 2+ elements.
        if lo < j:
            qsort_numba(a, lo, j)
        # Continue with the right part [i, hi] on the next pass.
        lo = i
        j = hi
    return a
# Fresh benchmark input; range(1, 5000) yields 4999 random floats.
lst = [ random.random() for i in range(1,5000) ]
# NOTE(review): qsort_numba swaps elements of its argument in place, so later
# timeit iterations presumably receive already-sorted input - confirm how
# numba passes a Python list into and out of the compiled function.
%timeit qsort_numba(lst, 0, len(lst)-1)
10000 loops, best of 3: 150 µs per loop