什么是单元测试?
单元测试,就是对单元进行测试,英文叫 unit testing,是指对软件中的最小可测试单元进行检查和验证,比如一个函数,一个类。
Python 真的需要单元测试吗?
单元测试无关语言,逻辑简单,一眼就可以看出有无 bug 的程序,没必要单元测试。但现实世界的程序往往都不是图样图森破的,因此非常有必要进行单元测试。 单元测试是保证所写代码的稳定、高效、无误的关键。因此学会合理地使用单元测试,正是帮助我们写出高质量代码这一目标的重要路径。
单元测试是解决什么问题的?
先想一想自己没接触单元测试之前是如何保证自己写的程序是正确的,一般写完程序后,自己在 main 函数给一定的输入,打印输出,看程序是否按预期打印信息,然后接着写其他模块,最后整体运行,发现程序有问题,debug。
这样做有三个缺点:
一是 mian 函数里不能把测试用例和真的代码分离,需要经常添加测试代码再删除,二是没有打印出测试结果和期望结果,例如,expected: 3628800, but actual: 123456,三是很难编写一组通用的测试代码。
单元测试就是解决这种测试问题的,我们可以把多个测试用例写在一个类里,每次修改代码直接后运行一下单元测试,通过了就没问题,这里使用官方库 unittest 举个例子。
如何使用 unittest 库
import unittest
# 将要被测试的排序函数
def sort(arr):
l = len(arr)
for i in range(0, l):
for j in range(i + 1, l):
if arr[i] >= arr[j]:
tmp = arr[i]
arr[i] = arr[j]
arr[j] = tmp
# 编写子类继承unittest.TestCase
class TestSort(unittest.TestCase):
# 以test开头的函数将会被测试
def test_sort(self):
arr = [3, 4, 1, 5, 6]
sort(arr)
# assert 结果跟我们期待的一样
self.assertEqual(arr, [1, 3, 4, 5, 6])
if __name__ == '__main__':
## 如果在Jupyter下,请用如下方式运行单元测试
unittest.main(argv=['first-arg-is-ignored'], exit=False)
## 如果是命令行下运行,则:
## unittest.main()
## 输出
..
----------------------------------------------------------------------
Ran 2 tests in 0.002s
OK
这里我们使用了一个测试用例测试了排序函数,代码有详细的注释,相信你都可以看懂,首先,我们需要创建一个类继承 unittest.TestCase,然后,在这个类中定义相应的测试函数 test_sort(),进行测试。注意,测试函数要以 test 开头,而测试函数的内部,通常使用 assertEqual()、assertTrue()、assertFalse() 和 assertRaise() 等 assert 语句对结果进行验证。
如果你使用 IPython 或者 Jupyter 请使用下面的代码:
unittest.main(argv=['first-arg-is-ignored'], exit=False)
而如果你用的是命令行,直接使用 unittest.main() 就可以了。你可以看到,运行结果输出 OK
这是比较简单的,如果我们测试的函数有其他的依赖,如数据库等,网络接口等,我们就需要借助 mock。
如何使用 mock
mock 的英文含义是模拟,当我们的代码涉及数据库,文件,api 接口,其他服务时,单元测试将变的困难起来,有时候为了测试一个函数,我们需要启动 Mysql,Redis,ElstaticSearch,如果这些服务不在一台机器上,我们还要确保网络互通,测试成本太高,有了 mock,我们就可以模拟这些服务,将测试的精力集中在我们的单元测试上。
在 mock 模块中,两个常用的类型为 Mock,MagicMock,两个类的关系是 MagicMock 继承自 Mock,最重要的两个属性是 return_value, side_effect。
>>> from mock import Mock>>> mock_obj = Mock()>>>mock_obj.return_value = 'This is a mock object'>>> mock_obj()'This is a mock object'
通过 Mock() 可以创建一个 mock 对象,通过 renturn_value 指定它的返回值。即当下文出现 mock_obj() 会返回其 return_value 所指定的值。这里再给出一段 mock 示例:我们要测试的方法 m1 依赖方法 m2 的返回值,并使用 m2 返回值调用 m3,我们只需要测试 m1 逻辑的正确性,代码如下:
import unittest
from unittest.mock import MagicMock
class A(unittest.TestCase):
def m1(self):
val = self.m2()
print("do something")
self.m3(val)
def m2(self):
pass
def m3(self, val):
pass
def test_m1(self):
a = A()
a.m2 = MagicMock(return_value="custom_val")
a.m3 = MagicMock()
a.m1()
self.assertTrue(a.m2.called) #验证m2被call过
a.m3.assert_called_with("custom_val") #验证m3被指定参数call过
if __name__ == '__main__':
unittest.main(argv=['first-arg-is-ignored'], exit=False)
## 输出
..
----------------------------------------------------------------------
Ran 2 tests in 0.002s
OK
这里的 m1 为了代码方便,没有单独拿出来,实际上 m1 并不是类 A 的成员函数。
如何使用 Mock Side Effect
Mock Side Effect,这个概念很好理解,就是 mock 的函数,属性是可以根据不同的输入,返回不同的数值,而不只是一个 return_value。比如下面这个示例,例子很简单,测试的是输入参数是否为负数,输入小于 0 则输出为 1 ,否则输出为 2。代码很简短,你一定可以看懂,这便是 Mock Side Effect 的用法。
from unittest.mock import MagicMock
def side_effect(arg):
if arg < 0:
return 1
else:
return 2
mock = MagicMock()
mock.side_effect = side_effect
mock(-1)
1
mock(1)
2
也可以通过 side_effect 指定它的副作用,这个副作用就是当你调用这个 mock 对象是会调用的函数,也可以选择抛出一个异常,来对程序的错误状态进行测试。
>>>def b():
... print 'This is b'
...
>>>mock_obj.side_effect = b
>>>mock_obj()
This is b
>>>mock_obj.side_effect = KeyError('This is b')
>>>mock_obj()
...
KeyError: 'This is b'
如果要模拟一个对象而不是函数,你可以直接在 mock 对象上添加属性和方法,并且每一个添加的属性都是一个 mock 对象【注意,这种方式很有用】,也就是说可以对这些属性进行配置,并且可以一直递归的定义下去。
>>>mock_obj.mock_a.return_value = 'This is mock_obj.mock_a'>>>mock_obj.mock_a()'This is mock_obj.mock_a'
上述代码片段中 mock_obj 是一个 mock 对象,而 mock_obj.mock_a 的这种形式使得 mock_a 变成了 mock_obj 的一个属性,作用是在 mock_obj.mock_a() 调用时会返回其 return_value。 另外也可以通过为 side_effect 指定一个列表,这样在每次调用时会依次返回,如下:
>>> mock_obj = Mock(side_effect = [1, 2, 3])>>>mock_obj()1>>>mock_obj()2>>>mock_obj()3
在单元测试中给对象打补丁
patch 用于单元测试中需要给指定的对象打补丁, 用来断言它们在测试中的期望行为(比如,断言被调用时的参数个数,访问指定的属性等)。
patch 给开发者提供了非常便利的函数 mock 方法。它可以应用 Python 的 decoration 模式或是 context manager 概念,快速自然地 mock 所需的函数。它的用法也不难,我们来看代码:
from unittest.mock import patch
import example
@patch('example.func')
def test1(x, mock_func):
example.func(x) # Uses patched example.func
mock_func.assert_called_with(x)
它还可以被当做一个上下文管理器:
with patch('example.func') as mock_func: example.func(x) # Uses patched example.func mock_func.assert_called_with(x)
最后,你还可以手动的使用它打补丁:
p = patch('example.func')mock_func = p.start()example.func(x)mock_func.assert_called_with(x)p.stop()
如果可能的话,你能够叠加装饰器和上下文管理器来给多个对象打补丁。例如:
@patch('example.func1')
@patch('example.func2')
@patch('example.func3')
def test1(mock1, mock2, mock3):
...
def test2():
with patch('example.patch1') as mock1, \
patch('example.patch2') as mock2, \
patch('example.patch3') as mock3:
...
patch() 接受一个已存在对象的全路径名,将其替换为一个新的值。 原来的值会在装饰器函数或上下文管理器完成后自动恢复回来。 默认情况下,所有值会被 MagicMock 实例替代。例如
>>> x = 42
>>> with patch('__main__.x'):
... print(x)
...
<MagicMock name='x' id='4314230032'>
>>> x
42
>>>
不过,你可以通过给 patch() 提供第二个参数来将值替换成任何你想要的:
>>> x42>>> with patch('__main__.x', 'patched_value'):... print(x)...patched_value>>> x42>>>
一个实例
如果你还不理解,那么我们举个实用的例子。假设你已经有了像下面这样的函数,文件名:example.py,并对它进行单元测试。
# example.py
from urllib.request import urlopen
import csv
def dowprices():
u = urlopen('http://finance.yahoo.com/d/quotes.csv?s=@^DJI&f=sl1')
lines = (line.decode('utf-8') for line in u)
rows = (row for row in csv.reader(lines) if len(row) == 2)
prices = { name:float(price) for name, price in rows }
return prices
这个函数需要网络连接,函数会使用 urlopen() 从Web上面获取数据并解析它。如果在内网开发,怎么测试呢,就需要 mock。
在单元测试中,你可以给它一个预先定义好的数据集。下面是使用补丁操作的例子:
import unittest
from unittest.mock import patch
import io
import example
sample_data = io.BytesIO(b'''\
"IBM",91.1\r
"AA",13.25\r
"MSFT",27.72\r
\r
''')
class Tests(unittest.TestCase):
@patch('example.urlopen', return_value=sample_data)
def test_dowprices(self, mock_urlopen):
p = example.dowprices()
self.assertTrue(mock_urlopen.called)
self.assertEqual(p,
{'IBM': 91.1,
'AA': 13.25,
'MSFT' : 27.72})
if __name__ == '__main__':
unittest.main()
在打补丁时我们使用了 example.urlopen 来代替 urllib.request.urlopen 。 当你创建补丁的时候,你必须使用它们在测试代码中的名称。 由于测试代码使用了 from urllib.request import urlopen ,那么 dowprices() 函数 中使用的 urlopen() 函数实际上就位于 example 模块了。