本文将带你感受einsum的“万能”,作者通过提供从基础到高级的einsum使用范例,展示了它是怎么做到既简洁又优雅地实现多种张量操作,并轻易解决维度匹配问题。einsum is all you needed!

如果问pytorch中最强大的一个数学函数是什么?

我会说是torch.einsum:爱因斯坦求和函数。它几乎是一个"万能函数":能实现超过一万种功能的函数。

不仅如此,和其它pytorch中的函数一样,torch.einsum是支持求导和反向传播的,并且计算效率非常高。

einsum 提供了一套既简洁又优雅的规则,可实现包括但不限于:内积,外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练掌握 einsum 可以很方便的实现复杂的张量操作,而且不容易出错。

尤其是在一些包括batch维度的高阶张量的相关计算中,若使用普通的矩阵乘法、求和、转置等算子来实现很容易出现维度匹配等问题,但换成einsum则会特别简单。

套用一句深度学习paper标题当中非常时髦的话术,einsum is all you needed !

本文源码路径:

https://github.com/lyhue1991/eat_pytorch_in_20_days/blob/master/4-2,%E5%BC%A0%E9%87%8F%E7%9A%84%E6%95%B0%E5%AD%A6%E8%BF%90%E7%AE%97.md

einsum规则原理

顾名思义,einsum这个函数的思想起源于家喻户晓的小爱同学:爱因斯坦。

很久很久以前,小爱同学在捣鼓广义相对论。广义相对论表述各种物理量用的都是张量。比如描述时空有一个四维时空度规张量,描述电磁场有一个电磁张量,描述运动的有能量动量张量。

在理论物理学家中,小爱同学的数学基础不算特别好,在捣鼓这些张量的时候,他遇到了一个比较头疼的问题:公式太长太复杂了。有没有什么办法让这些张量运算公式稍微显得对人类友好一些呢,能不能减少一些那种扭曲的求和符号呢?

小爱发现,求和导致维度收缩,因此求和符号操作的指标总是只出现在公式的一边。例如在我们熟悉的矩阵乘法中

Pytorch~einsum_人工智能

k这个下标被求和了,求和导致了这个维度的消失,所以它只出现在右边而不出现在左边。这种只出现在张量公式的一边的下标被称之为哑指标,反之为自由指标。

小爱同学脑瓜子滴溜一转,反正这种只出现在一边的哑指标一定是被求和求掉的,干脆把对应的求和符号省略得了。

这就是爱因斯坦求和约定:

只出现在公式一边的指标叫做哑指标,针对哑指标的求和符号可以省略。

公式立刻清爽了很多。

Pytorch~einsum_人工智能_02

公式展现形式中除了省去了求和符号,还省去了乘法符号(代数通识)。

借鉴爱因斯坦求和约定表达张量运算的清爽整洁,numpy、tensorflow和 torch等库中都引入了 einsum这个函数。

上述矩阵乘法可以被einsum这个函数表述成

C = torch.einsum("ik,kj->ij",A,B)

这个函数的规则原理非常简洁,3句话说明白。

  • 1,用元素计算公式来表达张量运算。
  • 2,只出现在元素计算公式箭头左边的指标叫做哑指标。
  • 3,省略元素计算公式中对哑指标的求和符号。

import torch 

A = torch.tensor([[1,2],[3,4.0]])
B = torch.tensor([[5,6],[7,8.0]])

C1 = A@B
print(C1)

C2 = torch.einsum("ik,kj->ij",[A,B])
print(C2)

tensor([[19., 22.],
        [43., 50.]])
tensor([[19., 22.],
        [43., 50.]])

einsum基础范例

einsum这个函数的精髓实际上是第一条:用元素计算公式来表达张量运算。

而绝大部分张量运算都可以用元素计算公式很方便地来表达,这也是它为什么会那么神通广大。

例1,张量转置

#例1,张量转置
A = torch.randn(3,4,5)

#B = torch.permute(A,[0,2,1])
B = torch.einsum("ijk->ikj",A) 

print("before:",A.shape)
print("after:",B.shape)

before: torch.Size([3, 4, 5])
after: torch.Size([3, 5, 4])

例2,取对角元

#例2,取对角元
A = torch.randn(5,5)
#B = torch.diagonal(A)
B = torch.einsum("ii->i",A)
print("before:",A.shape)
print("after:",B.shape)

before: torch.Size([5, 5])
after: torch.Size([5])

例3,求和降维

#例3,求和降维
A = torch.randn(4,5)
#B = torch.sum(A,1)
B = torch.einsum("ij->i",A)
print("before:",A.shape)
print("after:",B.shape)

before: torch.Size([4, 5])
after: torch.Size([4])

例4,哈达玛积

#例4,哈达玛积
A = torch.randn(5,5)
B = torch.randn(5,5)
#C=A*B
C = torch.einsum("ij,ij->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)

before: torch.Size([5, 5]) torch.Size([5, 5])
after: torch.Size([5, 5])

例5,向量内积

#例5,向量内积
A = torch.randn(10)
B = torch.randn(10)
#C=torch.dot(A,B)
C = torch.einsum("i,i->",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)

before: torch.Size([10]) torch.Size([10])
after: torch.Size([])

例6,向量外积

#例6,向量外积
A = torch.randn(10)
B = torch.randn(5)
#C = torch.outer(A,B)
C = torch.einsum("i,j->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)

before: torch.Size([10]) torch.Size([5])
after: torch.Size([10, 5])

例7,矩阵乘法

#例7,矩阵乘法
A = torch.randn(5,4)
B = torch.randn(4,6)
#C = torch.matmul(A,B)
C = torch.einsum("ik,kj->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)

before: torch.Size([5, 4]) torch.Size([4, 6])
after: torch.Size([5, 6])

例8,张量缩并

#例8,张量缩并
A = torch.randn(3,4,5)
B = torch.randn(4,3,6)
#C = torch.tensordot(A,B,dims=[(0,1),(1,0)])
C = torch.einsum("ijk,jih->kh",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)

before: torch.Size([3, 4, 5]) torch.Size([4, 3, 6])
after: torch.Size([5, 6])

einsum高级范例

einsum可用于超过两个张量的计算。

例9,bilinear注意力机制

例如:双线性变换。这是向量内积的一种扩展,一种常用的注意力机制实现方式

不考虑batch维度时,双线性变换的公式如下:

Pytorch~einsum_线性变换_03

#例9,bilinear注意力机制

#====不考虑batch维度====
q = torch.randn(10) #query_features
k = torch.randn(10) #key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5) #out_features

#a = q@W@k.t()+b  
a = torch.bilinear(q,k,W,b)
print("a.shape:",a.shape)

#=====考虑batch维度====
Q = torch.randn(8,10)    #batch_size,query_features
K = torch.randn(8,10)    #batch_size,key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5)       #out_features

#A = torch.bilinear(Q,K,W,b)
A = torch.einsum('bq,oqk,bk->bo',Q,W,K) + b
print("A.shape:",A.shape)

a.shape: torch.Size([5])
A.shape: torch.Size([8, 5])

例10,scaled-dot-product注意力机制

我们也可以用einsum来实现更常见的scaled-dot-product 形式的 Attention.

不考虑batch维度时, scaled-dot-product形式的Attention用矩阵乘法公式表示如下:

 

Pytorch~einsum_人工智能_04

#例10,scaled-dot-product注意力机制

#====不考虑batch维度====
 q = torch.randn(10)  #query_features
 k = torch.randn(6,10) #key_size, key_features

 d_k = k.shape[-1]
 a = torch.softmax(q@k.t()/d_k,-1) 

 print("a.shape=",a.shape )

 #====考虑batch维度====
 Q = torch.randn(8,10)  #batch_size,query_features
 K = torch.randn(8,6,10) #batch_size,key_size,key_features

 d_k = K.shape[-1]
 A = torch.softmax(torch.einsum("in,ijn->ij",Q,K)/d_k,-1) 

 print("A.shape=",A.shape )

 a.shape= torch.Size([6])
 A.shape= torch.Size([8, 6])