一文读懂 AlphaTensor 论文

共 4751字,需浏览 10分钟

 ·

2022-10-19 11:19

前言

这篇文章的主要内容是,解读 AlphaTensor 这篇论文的主要思想,如何通过强化学习来探索发现更高效的矩阵乘算法。

1、二进制加法和乘法

这一节简单介绍一下计算机是怎么实现加法和乘法的。

2 + 52 * 5 为例。

我们知道数字在计算机中是以二进制形式表示的。

整数2的二进制表示为:0010

整数5的二进制表示为:0101

1.1、二进制加法

二进制加法很简单,也就是两个二进制数按位相加,如下图所示:

当然具体到硬件实现其实是包含了异或运算和运算,具体细节可以阅读文末参考的资料。

1.2、二进制乘法

二进制乘法其实也是通过二进制加法来实现的,如下图所示:

乘法在硬件上的实现本质是移位相加

对于二进制数来说乘数和被乘数的每一位非01

所以相当于乘数中的每一位从低位到高位,分别和被乘数的每一位进行与运算并产生其相应的局部乘积,再将这些局部乘积左移一位与上次的和相加。

从乘数的最低位开始:

若为1,则复制被乘数,并左移一位与上一次的和相加;

若为0,则直接将0左移一位与上一次的和相加;

如此循环至乘数的最高位。

从二进制乘法的实现也可以看出来,加法比乘法操作要快。

1.3、用加法替换乘法的简单例子

上面这个公式相信大家都很熟悉了,式子两边是等价的

左边包含了2次乘法和1次加法(减法也可以看成加法)

右边则包含了1次乘法和2次加法

可以看到通过数学上的等价变换,增加了加法的次数同时减少了乘法的次数。

2、矩阵乘算法

对于两个大小分别为 Q x RR x P 的矩阵相乘,通用的实现就需要 Q * P * R 次乘法操作(输出矩阵大小 Q x P,总共 Q * P 个元素,每个元素计算需要 R 次乘法操作)。

根据前面 1.2内容可知,乘法比加法慢,所以如果能减少的乘法次数就能有效加速矩阵乘的运算。

2.1、通用矩阵乘算法

首先来看一下通用的矩阵乘算法:

如上图所示,两个大小为2x2矩阵做乘法,总共需要8次乘法和4次加法。

2.2、Strassen 矩阵乘算法

上图所示即为 Strassen 矩阵乘算法,和通用矩阵乘算法不一样的地方是,引入了7个中间变量 m,只有在计算这7个中间变量才会用到乘法。

简单用 c1 验证一下:

可以看到 Strassen 算法总共包含7次乘法和18次加法,通过数学上的等价变换减少了1次乘法同时增加了14次加法。

3、AlphaTensor 核心思想解读

3.1、将矩阵乘表示为3维张量

首先来看下论文中的一张图

图中下方是3维张量,每个立方体表示3维张量一个坐标点。

其中张量每个位置的值只能是 0 或者 1,透明的立方体表示 0,紫色的立方体表示 1

现在将图简化一下,以[a,b,c]这样的维度顺序,将张量以维度a平摊开,这样更容易理解:

这个3维张量怎么理解呢?

比如对于 c1,我们知道 c1 的计算需要用到 a1,a2,b1,b3,对应到3维张量就是:

而从上图可知,对于两个 2 x 2 的矩阵相乘,3维张量大小为 4 x 4 x 4

一般的,对于两个 n x n 的矩阵相乘,3维张量大小为 n^2 x n^2 x n^2

更一般的,对于两个 n x mm x p 的矩阵相乘,3维张量大小为 n*m x m*p x n*p

然后论文中为了简化理解,都是以 n x n 矩阵乘来讲解的,论文中以

表示 n x n 矩阵乘的3维张量,下文中为了方便写作以 Tn 来表示。

3.2、3维张量分解

然后论文中提出了一个假设:

如果能将3维张量 Tn 分解为 R 个秩1的3维张量(R rank-one terms)的和的话,那么对于任意的 n x n 矩阵乘计算就只需要 R 次乘法。

如上图公式所示,就是表示的这个分解,其中的

就表示的一个秩1的3维张量,是由 u^(r)v^(r) 和  w^(r) 这3个一维向量做外积得到的。

这具体怎么什么理解呢?我们回去看一下 Strassen 矩阵乘算法:

上图左边就是 Strassen 矩阵乘算法的计算过程,右边的 UVW 3个矩阵,各自分别对应左边 U -> aV -> bW -> m

具体又怎么理解这三个矩阵呢?

我们在图上加一些标注来解释,其中 UVW 矩阵每一列从左到右按顺序,就对应上文提到的,u^(r)v^(r) 和  w^(r) 这3个一维向量。

然后矩阵 U 每一列和 [a1,a2,a3,a4] 做内积,矩阵 V 每一列和 [b1,b2,b3,b4] 做内积,然后内积结果相乘就得到 [m1,m2,m3,m4,m5,m6,m7]了。

最后矩阵 W 每一行和 [m1,m2,m3,m4,m5,m6,m7] 做内积就得到 [c1,c2,c3,c4]

接着再看一下的 UVW 这三个矩阵第一列的外积结果

如下图所示:

可以看到 UVW 三个矩阵每一列对应的外积的结果就是一个3维张量,那么这些3维张量全部加起来就会得到 Tn 么?下面我们来验证一下:

可以看到这些外积的结果全部加起来就恰好等于 Tn

所以也就证实了开头的假设:

如果能将表示矩阵乘的3维张量 Tn 分解为 R 个秩1的3维张量(R rank-one terms)的和,那么对于任意的 n x n 矩阵乘计算就只需要 R 次乘法。

因此也就很自然的可以想到,如果能找到更优的张量分解,也就是让 R 更小的话,那么就相当于找到乘法次数更小的矩阵乘算法了。

通过强化学习探索更优的3维张量分解

将探索3维张量分解过程变成游戏

论文中是采用了强化学习这个框架,来探索对3维张量Tn的更优的分解。强化学习的环境是一个单玩家的游戏(a single-player game, TensorGame)。

首先定义这个游戏进行 t 步之后的状态为 St

然后初始状态 S0 就设置为要分解的3维张量 Tn

对于游戏中的每一步t,玩家(就是本论文提出的 AlphaTensor)会根据当前的状态选择下一步的行动,也就是通过生成新的三个一维向量从而得到新的秩1张量:

接着更新状态 St减去这个秩1张量:

玩家的目标就是,让最终状态 St=0同时尽量的减少游戏的步数。

当到达最终状态 St=0 之后,也就找到了3维张量Tn的一个分解了:

还有些细节是,对于玩家每一步的选择都是给一个 -1 的分数奖励,其实也很容易理解,也就是玩的步数越多,奖励越低,从而鼓励玩家用更少的步数完成游戏。

而且对于一维向量的生成,也做了限制

就是生成这些一维向量的值,只限定在比如 [−2, −1, 0, 1, 2] 这5个离散值之内。

AlphaTensor 简要解读

论文中是怎么说的,在游戏过程中玩家 AlphaTensor 是通过一个深度神经网络来指导蒙特卡洛树搜索(MonteCarlo tree search)。关于这个蒙特卡洛树搜索,我不是很了解这里就不做解读了,有兴趣的读者可以阅读文末参考资料。

首先看下深渡神经网络部分:

深度神经网络的输入是当前的状态 St也就是需要分解的张量(上图中的最右边的粉红色立方体)。输出包含两个部分,分别是 Policy headValue head

其中 Policy head 的输出是对于当前状态可以采取的潜在下一步行动,也就是一维向量(u(t), v(t), w(t)) 的候选分布,然后通过采样得到下一步的行动。

然后 Value head 应该是对于给定的当前的状态 St ,估计游戏完成之后的最终奖励分数的分布。

接下来简要解读一下整个游戏的流程,还有深度神经网络是如何训练的:

先看流程图的上方 Acting 那个方框内,表示的是用训练好的网络做推理玩游戏的过程。

可以看到最左边绿色的立方体,也就是待分解的3维张量 Tn变换到粉红色立方体,论文中提到是作了基的变换,但是这块感觉如果不是去复现就不用了解的那么深入,而且我也没去细看这块就跳过吧。

然后从最初待分解的 Tn 开始,输入到神经网络,通过蒙特卡洛树搜索得到秩1张量,然后减去该张量之后,继续将相减的结果输入到网路中,继续这个过程直到张量相减的结果为0。

将游戏过程记录下来,就是流程图最右边的 Played game

然后流程图下方的 Learning 方框表示的就是训练过程,训练数据有两个部分,一个是已经玩过的游戏记录 Played games buffer 还有就是通过人工生成的数据。

人工怎么生成训练数据呢?

论文中提到,尽管张量分解是个 NP-hard 的问题,给定一个 Tn 要找其分解很难。但是我们可以反过来用秩1张量来构造出一个待分解的张量嘛!简单来说就是采样R个秩1张量,然后加起来就能的到分解的张量了。

因为对于强化学习这块我不是了解的并不深入,所以也就只能作粗浅的解读。

实验结果

最后看一下实验结果

表格最左边一列表示矩阵乘的规模,最右边三列表示矩阵乘算法乘法次数。

第一列表示目前为止,数学家找到的最优乘法次数。

第2和3列就是 AlphaTensor 找到的最优乘法次数。

可以看到其中有5个规模,AlphaTensor 能找到更优的乘法次数(标红的部分):

两个 4 x 44 x 4 的矩阵乘,AlphaTensor 搜索出47次乘法;

两个 5 x 55 x 5 的矩阵乘,AlphaTensor 搜索出96次乘法;

两个 3 x 44 x 5 的矩阵乘,AlphaTensor 搜索出47次乘法;

两个 4 x 44 x 5 的矩阵乘,AlphaTensor 搜索出63次乘法;

两个 4 x 55 x 5 的矩阵乘,AlphaTensor 搜索出76次乘法;

参考资料

  • https://www.nature.com/articles/s41586-022-05172-4
  • https://www.youtube.com/watch?v=3N3Bl5AA5QU&ab_channel=YannicKilcher
  • https://www.youtube.com/watch?v=gpYnDls4PdQ&ab_channel=HarvardMedicalAI%7CRajpurkarLab
  • https://www.jobilize.com/course/section/hardware-for-addition-and-subtraction-by-openstax
  • https://www.eet-china.com/mp/a94582.html
  • https://baike.baidu.com/item/%E7%A1%AC%E4%BB%B6%E4%B9%98%E6%B3%95%E5%99%A8/4865151
  • https://blog.csdn.net/SunnyYoona/article/details/43570853
  • https://nikcheerla.github.io/deeplearningschool/2018/01/01/AlphaZero-Explained/
  • https://www.youtube.com/watch?v=hmQogtp6-fs&ab_channel=GauravSen
  • https://www.youtube.com/watch?v=62nq4Zsn8vc&ab_channel=JoshVarty
  • https://www.youtube.com/watch?v=J3I3WaJei_E&ab_channel=%E8%B5%B0%E6%AD%AA%E7%9A%84%E5%B7%A5%E7%A8%8B%E5%B8%ABJames


浏览 68
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报