1. 简介

爱因斯坦求和约定(einsum)提供了一套既简洁又优雅的规则,可实现包括但不限于:向量内积,向量外积,矩阵乘法,转置和张量收缩等张量操作,熟练运用 einsum 可以很方便的实现复杂的张量操作,而且不容易出错。通常情况下,可以在计算效果上等同于pytorch中的矩阵乘法运算等,且书写方式较为随意,主要是从理解层面进行操作。在Python中的引入如下:

1
2
import torch #只用下面一行即可引入einsum包,这里是为了下面的矩阵定义做准备
from torch import einsum

2. 矩阵乘法、转置用法及计算原理

如果你已经明白einsum中的维度对应关系,可以直接跳过2.1和2.2部分,直接开始看2.3部分。

2.1 入门理解

为了讲明白einsum的用法及计算原理,我们首先定义几个矩阵,从几个例子入手:

1
2
a = torch.randn([35])
b = torch.randn([5, 6])

用einsum实现a和b之间的矩阵乘法操作,如下代码:

1
2
mutiple = einsum('ik,kj -> ij', a, b)
print(mutiple.shape) # result:[3,6]

用einsum实现a矩阵转置操作,如下代码:

1
2
trans = einsum('ij -> ji', a)
print(trans.shape) # result:[5,3]

没错,你是否已经观察出了一些规律呢?那我们来验证一下你的观察结果是否正确,看下面的代码,猜猜看输出会是什么结果:

1
2
3
4
5
6
7
8
c = torch.randn([2, 2, 4, 5])
d = torch.randn([2, 2, 5, 5])
guess_1 = einsum('bhqd,bhdk -> bhqk', c, d)
guess_2 = einsum('bhqd -> bdhq', c)
guess_3 = einsum('bhqd,bhkd -> bhkq', c, d)
print(guess_1.shape)
print(guess_2.shape)
print(guess_3.shape)

现在验证一下你的猜测是否正确:

1
2
3
guess_1.shape = torch.Size([2, 2, 4, 5])
guess_2.shape = torch.Size([2, 5, 2, 4])
guess_3.shape = torch.Size([2, 2, 5, 4])

2.2 维度理解

我们来理解一下上述的三个guess,它的结果到底是怎么计算出来的。

  1. guess_1: 很简单,按照之前的规律,bhqd、bhdk 指代 c、d 两个输入矩阵的原始维度。也就是说,在 -> 的左边部分,bhqd 四个字母分别表示维度2 2 4 5,bhdk 四个字母分别表示维度2 2 5 5,那么 -> 右边的 bhqk 应该对应着2 2 4 5。
  2. guess_3:我们仍然将 -> 左边和右边的字母指代的维度作比较,很eazy就能得出结论:2 2 5 4。

看到这里我们发现的规律应该是这样:
在einsum函数里,我们需要关注的分为两部分,第一部分为字符串,如'bhqd,bhdk -> bhqk',第二部分为后面跟着的参数,如 c 和 d 。并且我们对字符串中的理解仍然分为两部分,这两部分由 -> 划分开。且 -> 左边的部分由逗号隔开,分别与参数一一对应。进一步理解,字符串中的每一个字符都指代其对应参数的某一个维度值。例如:einsum('bhqd,bhdk -> bhqk', c, d)中的 bhqd 分别对应着 c 矩阵的四个维度2 2 4 5,b 矩阵的对应替代关系同理。然后根据这种对应关系来算出 -> 右边的字母对应的维度,即为该函数返回的变量维度。
看到这里,你已经能准确判断出einsum函数的维度变换关系了,接下来该看看它计算的结果了。

2.3 内部计算

要探究它内部的计算关系,我们还是回到第一个例子:

1
2
3
a = torch.randn([35])
b = torch.randn([5, 6])
mutiple = einsum('ik,kj -> ij', a, b)

为什么说上述式子就实现了矩阵的乘法呢,那是因为该函数的三条基本准则:

  1. 在 -> 左边不同输入之间重复出现的索引(索引即是英文字母)表示把输入张量沿着该维度做乘法操作,比如einsum('ik,kj -> ij', a, b),k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作。
  2. 只出现在 -> 左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面例子中的 k。
  3. 在 -> 右边的索引顺序可以是任意的,比如上面的'ik,kj->ij'如果写成'ik,kj->ji',那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。

是不是觉得云里雾里,没关系,我们先不管这个准则,回到我们的三个guess里。为了方便观看,我将其贴到这里来:

1
2
3
4
5
6
7
8
c = torch.randn([2, 2, 4, 5])
d = torch.randn([2, 2, 5, 5])
guess_1 = einsum('bhqd,bhdk -> bhqk', c, d)
guess_2 = einsum('bhqd -> bdhq', c)
guess_3 = einsum('bhqd,bhkd -> bhkq', c, d)
print(guess_1.shape) # [2, 2, 4, 5]
print(guess_2.shape) # [2, 5, 2, 4]
print(guess_3.shape) # [2, 2, 5, 4]

我们将上述 bhqd 理解为四个索引,即通过这四个索引值便可以定位 c 矩阵的任意位置。(举个例子,我们可以理解为c = torch.randn([2, 2, 4, 5])是一层书架,它的四个维度分别为 bhqd,b 表示这层书架上的书本数量为2,h 理解为每本书一共有2页,q 表示每一页书一共有4行字,d 表示每一行字一共有5个字。那么我们就可以通过一个 bhqd 索引来确定这一层书架上的每一个字。读者以后也可以通过这种方式理解高维矩阵,如果是五维,可以再加一个维度是书架层数等)
那么guess_1矩阵的 bhqk 也是它的索引,而且这个索引是通过 c 矩阵的 bhqd 与 d 矩阵的 bhdk 沿着 d 这个维度(这里为什么是沿着 d 维度,参照上面三条基本准则的第一条)做内积得到的。用数学公式表达则是:

$$guess1_{[b,h,q,k]}= \sum\limits_{i=0}^{d}c_{[b,h,q,i]}*d_{[b,h,i,k]}$$

guess_1中提到的'bhqd,bhdk -> bhqk'通过直观的理解,就是在 qd 维度上与 dk 维度相乘,自然而然得到了 qk。也符合矩阵乘法的运算规律。
guess_2中不涉及到运算,只是针对维度的变换,可以参照上面关于书架的假设来理解。
guess_3中提到的'bhqd,bhkd -> bhqk'虽然也是对后两个维度进行乘法操作,但是书写方式与我们常规理解的矩阵乘法有出入(通常的书写方式为$ [i,k]*[k,j] = [i,j] $),我们可以用数学公式帮助理解:

$$guess3_{[b,h,k,q]}= \sum\limits_{i=0}^{d}c_{[b,h,q,i]}*d_{[b,h,k,i]}$$

下图是对guess_3公式中索引含义的进一步解释,因为 b 和 h 可以理解为多几个维度的切片参与计算,所以可以先理解成二维矩阵之间相乘,再拓展到高维的多层矩阵乘法即可。假设guess_3中的索引值为[b h 3 2],从guess_3的求和公式可以得出,是 c 矩阵的 q 维度为2,d 矩阵的 k 维度为3,然后再从 i=0 开始对 d 维度做内积并求和。对应到图中则是 c 矩阵中橙色的一整行与 d 矩阵的粉色一整行对应位置相乘并相加。这是二维之间相乘,对应到高维,则是 c 矩阵中后续通道中对应橙色位置的一整行与 d 矩阵中后续通道中对应粉色位置的一整行相乘再全部相加,得到最终的guess_3矩阵中后续通道的紫色位置的值。同理,通过这样的乘法方式计算出guess_3矩阵每个位置的结果即可。

图1. guess_3中矩阵乘法可视化