文 | McGL
写深度学习网络代码,最大的挑战之一,尤其对新手来说,就是把所有的张量维度正确对齐。如果以前就有TensorSensor这个工具,相信我的头发一定比现在更浓密茂盛!
TensorSensor,码痴教授 Terence Parr 出品,他也是著名 parser 工具 ANTLR 的作者。
在包含多个张量和张量运算的复杂表达式中,张量的维数很容易忘了。即使只是将数据输入到预定义的 TensorFlow 网络层,维度也要弄对。当你要求进行错误的计算时,通常会得到一些没啥用的异常消息。为了帮助自己和其他程序员调试张量代码,Terence Parr 写了一个名叫 TensorSensor 的库(pip install tensor-sensor 直接安装) 。TensorSensor 通过增加消息和可视化 Python 代码来展示张量变量的形状,让异常更清晰(见下图)。它可以兼容 TensorFlow、PyTorch 和 Numpy以及 Keras 和 fastai 等高级库。
在张量代码中定位问题令人抓狂!
即使是专家,执行张量操作的 Python 代码行中发生异常,也很难快速定位原因。调试过程通常是在有问题的行前面添加一个 print 语句,以打出每个张量的形状。这需要编辑代码添加调试语句并重新运行训练过程。或者,我们可以使用交互式调试器手动单击或键入命令来请求所有张量形状。(这在像 PyCharm 这样的 IDE 中不太实用,因为在调试模式很慢。)下面将详细对比展示看了让人贫血的缺省异常消息和 TensorSensor 提出的方法,而不用调试器或 print 大法。
调试一个简单的线性层
让我们来看一个简单的张量计算,来说明缺省异常消息提供的信息不太理想。下面是一个包含张量维度错误的硬编码单(线性)网络层的简单 NumPy 实现。
import numpy as np
n = 200 # number of instances
d = 764 # number of instance features
n_neurons = 100 # how many neurons in this layer?
W = np.random.rand(d,n_neurons) # Ooops! Should be (n_neurons,d)
b = np.random.rand(n_neurons,1)
X = np.random.rand(n,d) # fake input matrix with n rows of d-dimensions
Y = W @ X.T + b # pass all X instances through layer
10 Y = W @ X.T + b
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)
执行该代码会触发一个异常,其重要元素如下:
...
---> 10 Y = W @ X.T + b
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)
异常显示了出错的行以及是哪个操作(matmul: 矩阵乘法),但是如果给出完整的张量维数会更有用。此外,这个异常也无法区分在 Python 的一行中的多个矩阵乘法。
接下来,让我们看看 TensorSensor 如何使调试语句更加容易的。如果我们使用 Python with 和tsensor 的 clarify()包装语句,我们将得到一个可视化和增强的错误消息。
import tsensor
with tsensor.clarify():
Y = W @ X.T + b
...
ValueError: matmul: Input operand ...
Cause: @ on tensor operand W w/shape (764, 100) and operand X.T w/shape (764, 200)
从可视化中可以清楚地看到,W 的维度应该翻转为 n _ neurons x d; W 的列必须与 X.T 的行匹配。您还可以检查一个完整的带有和不带阐明()的并排图像,以查看它在笔记本中的样子。下面是带有和没有 clarify() 的例子在notebook 中的比较。
clarify() 功能在没有异常时不会增加正在执行的程序任何开销。有异常时, clarify():
- 增加由底层张量库创建的异常对象消息。
- 给出出错操作所涉及的张量大小的可视化表示; 只突出显示异常涉及的操作对象和运算符,而其他 Python 元素则不突出显示。
TensorSensor 还区分了 PyTorch 和 TensorFlow 引发的与张量相关的异常。下面是等效的代码片段和增强的异常错误消息(Cause: @ on tensor ...)以及 TensorSensor 的可视化:
PyTorch 消息没有标识是哪个操作触发了异常,但 TensorFlow 的消息指出了是矩阵乘法。两者都显示操作对象维度。
调试复杂的张量表达式
缺省消息缺乏具体细节,在包含大量操作符的更复杂的语句中,识别出有问题的子表达式很难。例如,下面是从一个门控循环单元(GRU)实现的内部提取的一个语句:
h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
这是什么计算或者变量代表什么不重要,它们只是张量变量。有两个矩阵乘法,两个向量加法,还有一个向量逐元素修改(r*h)。如果没有增强的错误消息或可视化,我们就无法知道是哪个操作符或操作对象导致了异常。为了演示 TensorSensor 在这种情况下是如何分清异常的,我们需要给语句中使用的变量(为 h _ 赋值)一些伪定义,以得到可执行代码:
nhidden = 256
Whh_ = torch.eye(nhidden, nhidden) # Identity matrix
Uxh_ = torch.randn(d, nhidden)
bh_ = torch.zeros(nhidden, 1)
h = torch.randn(nhidden, 1) # fake previous hidden state h
r = torch.randn(nhidden, 1) # fake this computation
X = torch.rand(n,d) # fake input
with tsensor.clarify():
h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
同样,你可以忽略代码执行的实际计算,将重点放在张量变量的形状上。
对于我们大多数人来说,仅仅通过张量维数和张量代码是不可能识别问题的。当然,默认的异常消息是有帮助的,但是我们中的大多数人仍然难以定位问题。以下是默认异常消息的关键部分(注意对 C++ 代码的不太有用的引用) :
---> 10 h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
RuntimeError: size mismatch, m1: [764 x 256], m2: [764 x 200] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
我们需要知道的是哪个操作符和操作对象出错了,然后我们可以通过维数来确定问题。以下是 TensorSensor 的可视化和增强的异常消息:
---> 10 h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
RuntimeError: size mismatch, m1: [764 x 256], m2: [764 x 200] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
Cause: @ on tensor operand Uxh_ w/shape [764, 256] and operand X.T w/shape [764, 200]
人眼可以迅速锁定在指示的算子和矩阵相乘的维度上。哎呀, Uxh 的列必须与 X.T的行匹配,Uxh_的维度翻转了,应该为:
Uxh_ = torch.randn(nhidden, d)
现在,我们只在 with 代码块中使用我们自己直接指定的张量计算。那么在张量库的内置预建网络层中触发的异常又会如何呢?
理清预建层中触发的异常
TensorSensor 可视化进入你选择的张量库前的最后一段代码。例如,让我们使用标准的 PyTorch nn.Linear 线性层,但输入一个 X 矩阵维度是 n x n,而不是正确的 n x d:
L = torch.nn.Linear(d, n_neurons)
X = torch.rand(n,n) # oops! Should be n x d
with tsensor.clarify():
Y = L(X)
增强的异常信息
RuntimeError: size mismatch, m1: [200 x 200], m2: [764 x 100] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
Cause: L(X) tensor arg X w/shape [200, 200]
TensorSensor 将张量库的调用视为操作符,无论是对网络层还是对 torch.dot(a,b) 之类的简单操作的调用。在库函数中触发的异常会产生消息,消息标示了函数和任何张量参数的维数。
[1] https://explained.ai/tensor-sensor/index.html