Python 常因运行速度慢而被诟病。其实优化代码性能的方法有很多,主要有两个方向:

  • 从算法本身优化,这是最根本和彻底的优化;
  • 从语言本身和工具层面优化,例如类型化(Cython)、JIT 编译(Numba)、空间换时间(缓存)等。

刚好看到一篇讲得比较好的文章,在此整理一下。

一、Fibonacci 函数

常见的 Fibonacci 函数实现有两种:一种是递归,一种是非递归,即迭代(好像是废话……)。

  • 常规代码
def fib(n):
    """Return the n-th Fibonacci number by naive double recursion.

    This is the deliberately slow baseline: the call tree grows
    exponentially with ``n``. For ``n < 2`` the argument itself is returned.
    """
    return n if n < 2 else fib(n - 1) + fib(n - 2)

def fib_seq(n):
    """Return the n-th Fibonacci number iteratively in O(n) steps.

    Two consecutive Fibonacci numbers are advanced ``n - 1`` times.
    For ``n < 2`` the argument itself is returned unchanged.
    """
    if n < 2:
        return n
    prev, cur = 0, 1
    for _ in range(n - 1):
        prev, cur = cur, prev + cur
    return cur

%timeit fib(20)

%timeit fib_seq(20)

100 loops, best of 3: 2.28 ms per loop
1000000 loops, best of 3: 1.04 µs per loop
  • cython 版本
%%cython

def fib_cython(n):
    """Naive doubly-recursive Fibonacci, compiled by Cython without types.

    No C types are declared, so the arithmetic still operates on Python
    objects. For ``n < 2`` the argument itself is returned.
    """
    if n < 2:
        result = n
    else:
        result = fib_cython(n - 1) + fib_cython(n - 2)
    return result

def fib_seq_cython(n):
    """Iterative Fibonacci compiled by Cython with no type declarations.

    Returns ``n`` itself when ``n < 2``; otherwise advances a pair of
    consecutive Fibonacci numbers ``n - 1`` times.
    """
    if n < 2:
        return n
    older, newer = 0, 1
    for _ in range(1, n):
        newer, older = newer + older, newer
    return newer

cpdef long fib_seq_cython_type(long n):
    """Iterative Fibonacci with C ``long`` typing for the loop state.

    Returns ``n`` for ``n < 2``; otherwise the n-th Fibonacci number.

    The result is a C ``long``, so it silently overflows once the value no
    longer fits: n > 92 where ``long`` is 64-bit (typical Linux/macOS),
    n > 46 where it is 32-bit (e.g. Windows) -- confirm for your platform.
    """
    cdef long a, b, i
    if n < 2:
        return n
    a, b = 1, 0
    # (a, b) = (fib(k), fib(k-1)) -> (fib(k+1), fib(k)).
    # The previous version assigned ``b`` back to itself, which kept a == 1
    # forever and returned 1 for every n >= 2.
    for i in range(n - 1):
        a, b = a + b, a
    return a
 
%timeit fib_cython(20)

%timeit fib_seq_cython(20)
 
%timeit fib_seq_cython_type(20)

1000 loops, best of 3: 741 µs per loop
1000000 loops, best of 3: 599 ns per loop
10000000 loops, best of 3: 33 ns per loop
  • cache 版本
from functools import lru_cache as cache

@cache(maxsize=None)
def fib_cache(n):
    """Doubly-recursive Fibonacci memoised with an unbounded ``lru_cache``.

    Every computed value is kept, so the first call for a given ``n`` makes
    only O(n) distinct sub-calls. Later calls with the same ``n`` are pure
    cache hits, so repeated ``%timeit`` runs measure the lookup rather than
    the computation. For ``n < 2`` the argument itself is returned.
    """
    if n >= 2 or not n < 2:
        return fib_cache(n - 1) + fib_cache(n - 2)
    return n

@cache(maxsize=None)
def fib_seq_cache(n):
    """Iterative Fibonacci wrapped in an unbounded ``lru_cache``.

    The loop already runs in O(n); the cache only turns repeated calls with
    the same ``n`` into cache lookups. For ``n < 2`` the argument itself is
    returned.
    """
    if n < 2:
        return n
    prev, cur = 0, 1
    for _ in range(1, n):
        cur, prev = cur + prev, cur
    return cur
 
%timeit fib_cache(20)

%timeit fib_seq_cache(20)

10000000 loops, best of 3: 84.1 ns per loop
10000000 loops, best of 3: 87.9 ns per loop
  • numba 版本
from numba import jit

@jit
def fib_numba(n):
 """Naive doubly-recursive Fibonacci compiled with Numba's ``@jit``.

 Same algorithm as the pure-Python ``fib``: returns ``n`` for ``n < 2``,
 otherwise fib(n-1) + fib(n-2), with an exponential number of calls.
 """
 # Keep the base case as its own non-recursive ``return``: per the Numba
 # docs, self-recursion type inference needs such a return path.
 if n<2:
  return n
 return fib_numba(n-1)+fib_numba(n-2)

@jit
def fib_seq_numba(n):
    """Iterative O(n) Fibonacci compiled with Numba's ``@jit``.

    For ``n < 2`` the argument itself is returned; otherwise two consecutive
    Fibonacci numbers are advanced ``n - 1`` times.
    """
    if n < 2:
        return n
    prev, cur = 0, 1
    for _ in range(1, n):
        prev, cur = cur, prev + cur
    return cur
 
%timeit fib_numba(20)

%timeit fib_seq_numba(20)

10000 loops, best of 3: 47.6 µs per loop
10000000 loops, best of 3: 164 ns per loop

二、快速排序

  • 常规代码
def qsort_kernel(a, lo, hi):
    """Sort ``a[lo:hi + 1]`` in place with Hoare-partition quicksort; return ``a``.

    The pivot is the middle element of the current range. After each
    partition, the smaller side is sorted recursively and the larger side is
    handled by the enclosing ``while`` loop. This keeps recursion depth at
    O(log n). Always recursing into the left side can nest O(n) deep on
    badly-splitting input and exceed Python's recursion limit.

    Args:
        a: mutable sequence of mutually comparable items; modified in place.
        lo: index of the first element of the range to sort.
        hi: index of the last element of the range (inclusive),
            e.g. ``len(a) - 1`` for the whole sequence.

    Returns:
        The same object ``a``, with positions ``lo..hi`` sorted.

    Note:
        The sort is in place. Timing repeated calls on the same list
        therefore measures sorting already-sorted data after the first call.
    """
    while lo < hi:
        pivot = a[(lo + hi) // 2]
        i, j = lo, hi
        # Hoare partition: afterwards a[lo..j] <= pivot <= a[i..hi].
        while i <= j:
            while a[i] < pivot:
                i += 1
            while a[j] > pivot:
                j -= 1
            if i <= j:
                a[i], a[j] = a[j], a[i]
                i += 1
                j -= 1
        # Recurse into the smaller part, keep looping on the larger one.
        if j - lo < hi - i:
            if lo < j:
                qsort_kernel(a, lo, j)
            lo = i
        else:
            if i < hi:
                qsort_kernel(a, i, hi)
            hi = j
    return a

import random

lst = [ random.random() for i in range(1,5000) ]

%timeit qsort_kernel(lst, 0, len(lst)-1)

100 loops, best of 3: 4.95 ms per loop
  • cython 版本
%%cython

def qsort_cython(a, lo, hi):
    """Sort ``a[lo:hi + 1]`` in place with Hoare-partition quicksort; return ``a``.

    Same algorithm as the pure-Python version, compiled by Cython without
    type declarations. The pivot is the middle element of the range. After
    each partition, the smaller side is sorted recursively and the larger
    side is handled by the enclosing ``while`` loop. This bounds recursion
    depth at O(log n) instead of O(n) on badly-splitting input.

    Args:
        a: mutable sequence of mutually comparable items; modified in place.
        lo: index of the first element of the range to sort.
        hi: index of the last element of the range (inclusive).

    Returns:
        The same object ``a``, with positions ``lo..hi`` sorted.

    Note:
        Because the sort is in place, repeated timing on one list measures
        already-sorted input after the first call.
    """
    while lo < hi:
        pivot = a[(lo + hi) // 2]
        i, j = lo, hi
        # Hoare partition: afterwards a[lo..j] <= pivot <= a[i..hi].
        while i <= j:
            while a[i] < pivot:
                i += 1
            while a[j] > pivot:
                j -= 1
            if i <= j:
                a[i], a[j] = a[j], a[i]
                i += 1
                j -= 1
        # Recurse into the smaller part, keep looping on the larger one.
        if j - lo < hi - i:
            if lo < j:
                qsort_cython(a, lo, j)
            lo = i
        else:
            if i < hi:
                qsort_cython(a, i, hi)
            hi = j
    return a

lst = [ random.random() for i in range(1,5000) ]

%timeit qsort_cython(lst, 0, len(lst)-1)

100 loops, best of 3: 2.47 ms per loop
  • numba 版本
@jit
def qsort_numba(a, lo, hi):
    """Sort ``a[lo:hi + 1]`` in place with Hoare-partition quicksort; return ``a``.

    Same algorithm as the pure-Python version, compiled with Numba's
    ``@jit``. After each partition, the smaller side is sorted recursively
    and the larger side is handled by the ``while`` loop. This bounds
    recursion depth at O(log n) instead of O(n) on badly-splitting input.

    Args:
        a: mutable indexable sequence of comparable numbers; modified in place.
        lo: index of the first element of the range to sort.
        hi: index of the last element of the range (inclusive).

    Returns:
        The same object ``a``, with positions ``lo..hi`` sorted.

    NOTE(review): passing a plain Python ``list`` relies on Numba's
    "reflected list" support, which the Numba docs mark as deprecated. A
    NumPy array (or ``numba.typed.List``) is the recommended input -- confirm
    for the installed Numba version.
    """
    while lo < hi:
        pivot = a[(lo + hi) // 2]
        i, j = lo, hi
        # Hoare partition: afterwards a[lo..j] <= pivot <= a[i..hi].
        while i <= j:
            while a[i] < pivot:
                i += 1
            while a[j] > pivot:
                j -= 1
            if i <= j:
                a[i], a[j] = a[j], a[i]
                i += 1
                j -= 1
        # Recurse into the smaller part, keep looping on the larger one.
        if j - lo < hi - i:
            if lo < j:
                qsort_numba(a, lo, j)
            lo = i
        else:
            if i < hi:
                qsort_numba(a, i, hi)
            hi = j
    return a

lst = [ random.random() for i in range(1,5000) ]

%timeit qsort_numba(lst, 0, len(lst)-1)

10000 loops, best of 3: 150 µs per loop