本文将带你感受einsum的“万能”,作者通过提供从基础到高级的einsum使用范例,展示了它是怎么做到既简洁又优雅地实现多种张量操作,并轻易解决维度匹配问题。einsum is all you needed!
如果问pytorch中最强大的一个数学函数是什么?
我会说是torch.einsum:爱因斯坦求和函数。它几乎是一个"万能函数":能实现超过一万种功能的函数。
不仅如此,和其它pytorch中的函数一样,torch.einsum是支持求导和反向传播的,并且计算效率非常高。
einsum 提供了一套既简洁又优雅的规则,可实现包括但不限于:内积,外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练掌握 einsum 可以很方便的实现复杂的张量操作,而且不容易出错。
尤其是在一些包括batch维度的高阶张量的相关计算中,若使用普通的矩阵乘法、求和、转置等算子来实现很容易出现维度匹配等问题,但换成einsum则会特别简单。
套用一句深度学习paper标题当中非常时髦的话术,einsum is all you needed !
本文源码路径:
einsum规则原理
顾名思义,einsum这个函数的思想起源于家喻户晓的小爱同学:爱因斯坦。
很久很久以前,小爱同学在捣鼓广义相对论。广义相对论表述各种物理量用的都是张量。比如描述时空有一个四维时空度规张量,描述电磁场有一个电磁张量,描述运动的有能量动量张量。
在理论物理学家中,小爱同学的数学基础不算特别好,在捣鼓这些张量的时候,他遇到了一个比较头疼的问题:公式太长太复杂了。有没有什么办法让这些张量运算公式稍微显得对人类友好一些呢,能不能减少一些那种扭曲的求和符号呢?
小爱发现,求和导致维度收缩,因此求和符号操作的指标总是只出现在公式的一边。例如在我们熟悉的矩阵乘法中
k这个下标被求和了,求和导致了这个维度的消失,所以它只出现在右边而不出现在左边。这种只出现在张量公式的一边的下标被称之为哑指标,反之为自由指标。
小爱同学脑瓜子滴溜一转,反正这种只出现在一边的哑指标一定是被求和求掉的,干脆把对应的求和符号省略得了。
这就是爱因斯坦求和约定:
只出现在公式一边的指标叫做哑指标,针对哑指标的求和符号可以省略。
公式立刻清爽了很多。
公式展现形式中除了省去了求和符号,还省去了乘法符号(代数通识)。
借鉴爱因斯坦求和约定表达张量运算的清爽整洁,numpy、tensorflow和 torch等库中都引入了 einsum这个函数。
上述矩阵乘法可以被einsum这个函数表述成
C = torch.einsum("ik,kj->ij",A,B)
这个函数的规则原理非常简洁,3句话说明白。
- 1,用元素计算公式来表达张量运算。
- 2,只出现在元素计算公式箭头左边的指标叫做哑指标。
- 3,省略元素计算公式中对哑指标的求和符号。
import torch
A = torch.tensor([[1,2],[3,4.0]])
B = torch.tensor([[5,6],[7,8.0]])
C1 = A@B
print(C1)
C2 = torch.einsum("ik,kj->ij",[A,B])
print(C2)
tensor([[19., 22.],
[43., 50.]])
tensor([[19., 22.],
[43., 50.]])
einsum基础范例
einsum这个函数的精髓实际上是第一条:用元素计算公式来表达张量运算。
而绝大部分张量运算都可以用元素计算公式很方便地来表达,这也是它为什么会那么神通广大。
例1,张量转置
#例1,张量转置
A = torch.randn(3,4,5)
#B = torch.permute(A,[0,2,1])
B = torch.einsum("ijk->ikj",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([3, 4, 5])
after: torch.Size([3, 5, 4])
例2,取对角元
#例2,取对角元
A = torch.randn(5,5)
#B = torch.diagonal(A)
B = torch.einsum("ii->i",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([5, 5])
after: torch.Size([5])
例3,求和降维
#例3,求和降维
A = torch.randn(4,5)
#B = torch.sum(A,1)
B = torch.einsum("ij->i",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([4, 5])
after: torch.Size([4])
例4,哈达玛积
#例4,哈达玛积
A = torch.randn(5,5)
B = torch.randn(5,5)
#C=A*B
C = torch.einsum("ij,ij->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([5, 5]) torch.Size([5, 5])
after: torch.Size([5, 5])
例5,向量内积
#例5,向量内积
A = torch.randn(10)
B = torch.randn(10)
#C=torch.dot(A,B)
C = torch.einsum("i,i->",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([10]) torch.Size([10])
after: torch.Size([])
例6,向量外积
#例6,向量外积
A = torch.randn(10)
B = torch.randn(5)
#C = torch.outer(A,B)
C = torch.einsum("i,j->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([10]) torch.Size([5])
after: torch.Size([10, 5])
例7,矩阵乘法
#例7,矩阵乘法
A = torch.randn(5,4)
B = torch.randn(4,6)
#C = torch.matmul(A,B)
C = torch.einsum("ik,kj->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([5, 4]) torch.Size([4, 6])
after: torch.Size([5, 6])
例8,张量缩并
#例8,张量缩并
A = torch.randn(3,4,5)
B = torch.randn(4,3,6)
#C = torch.tensordot(A,B,dims=[(0,1),(1,0)])
C = torch.einsum("ijk,jih->kh",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([3, 4, 5]) torch.Size([4, 3, 6])
after: torch.Size([5, 6])
einsum高级范例
einsum可用于超过两个张量的计算。
例9,bilinear注意力机制
例如:双线性变换。这是向量内积的一种扩展,一种常用的注意力机制实现方式
不考虑batch维度时,双线性变换的公式如下:
#例9,bilinear注意力机制
#====不考虑batch维度====
q = torch.randn(10) #query_features
k = torch.randn(10) #key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5) #out_features
#a = q@W@k.t()+b
a = torch.bilinear(q,k,W,b)
print("a.shape:",a.shape)
#=====考虑batch维度====
Q = torch.randn(8,10) #batch_size,query_features
K = torch.randn(8,10) #batch_size,key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5) #out_features
#A = torch.bilinear(Q,K,W,b)
A = torch.einsum('bq,oqk,bk->bo',Q,W,K) + b
print("A.shape:",A.shape)
a.shape: torch.Size([5])
A.shape: torch.Size([8, 5])
例10,scaled-dot-product注意力机制
我们也可以用einsum来实现更常见的scaled-dot-product 形式的 Attention.
不考虑batch维度时, scaled-dot-product形式的Attention用矩阵乘法公式表示如下:
#例10,scaled-dot-product注意力机制
#====不考虑batch维度====
q = torch.randn(10) #query_features
k = torch.randn(6,10) #key_size, key_features
d_k = k.shape[-1]
a = torch.softmax(q@k.t()/d_k,-1)
print("a.shape=",a.shape )
#====考虑batch维度====
Q = torch.randn(8,10) #batch_size,query_features
K = torch.randn(8,6,10) #batch_size,key_size,key_features
d_k = K.shape[-1]
A = torch.softmax(torch.einsum("in,ijn->ij",Q,K)/d_k,-1)
print("A.shape=",A.shape )
a.shape= torch.Size([6])
A.shape= torch.Size([8, 6])