Attention Free

在本站文章《注意力机制及其线性化之路》中我们简单梳理了注意力机制的原理、不同评分函数的设计以及走在科研道路上的线性化之路。归根结底,各种 Attention 的变体,其本质都是为了对于给定的输入查询qq ,针对每一个kk 得出一个注意力权重,然后将kk 对应的vv 按这个注意力权重加权平均得到输出。

所以,我们其实也没必要将注意力计算的方式固定在缩放点积注意力上。只要我们能够以比 Transformer 更低的复杂度比较有效的计算出注意力权值,那么便可以解决Transformer带来的高复杂度问题。因此,由Apple公司提出的一种新型的神经网络模型Attention Free Transformer (AFT) 就聚焦于这一点,尝试设计一个高效的权值计算方式。

AFT的论文中提出了四个不同的Attention Free的计算方式,分别是AFT-full,AFT-local,AFT-simple和AFT-conv,下面我们来分别介绍他们。

AFT-full

AFT-full 是 AFT 的原始形态,它首先也是和自注意力机制一样先将输入通过三个线性层得到Q,K,V.

学习 Transformer-XL 引入了相对位置编码wt,tw_{t,t'} ,它是一个可学习的标量参数并且对 multi-head 是共享的,用来表示第tt 个输入与第tt' 个输入的位置关系。

然后直接就对加入位置编码偏置Kt+wt,t\mathbf K_{t^{\prime}}+w_{t, t^{\prime}} 过一个softmax\mathrm{softmax} ,从而得到后续给vv 加权的权值:

Weighted(Kt)=exp(Kt+wt,t)t=1Texp(Kt+wt,t)\text{Weighted}(\mathbf K_{t^{\prime}})=\frac{\exp \left(\mathbf K_{t^{\prime}}+w_{t, t^{\prime}}\right)}{\sum_{t^{\prime}=1}^T \exp \left(\mathbf K_{t^{\prime}}+w_{t, t^{\prime}}\right)}

开始给vv 加权:

t=1TWeighted(Kt)    Vt\sum_{t'=1}^{T}\text{Weighted}(\mathbf K_{t^{\prime}})\;\odot\;\mathbf V_{t'}

最后再和使用 Sigmoid 非线性化后的Qt{\bf Q}_t 按元素相乘,得到最终输出Yt{\bf Y}_t

Yt=σq(Qt)t=1Texp(Kt+wt,t)Vtt=1Texp(Kt+wt,t)\mathbf Y_t=\sigma_q\left(\mathbf Q_t\right) \odot \frac{\sum_{t^{\prime}=1}^T \exp \left(\mathbf K_{t^{\prime}}+w_{t, t^{\prime}}\right) \odot \mathbf V_{t^{\prime}}}{\sum_{t^{\prime}=1}^T \exp \left(\mathbf K_{t^{\prime}}+w_{t, t^{\prime}}\right)}

实际上,如果把σq(Qt)\sigma_q\left(\mathbf Q_t\right) 写到最外层的求和符号里面,那么vv 前面的一坨由q,kq,k 构成的式子我们仍然可以看做是一种注意力权重。

AFT-local

AFT-full 的实现难点在于t=1Texp(Kt+wt,t)Vt\sum_{t^{\prime}=1}^T \exp \left(\mathbf K_{t^{\prime}}+w_{t, t^{\prime}}\right) \odot \mathbf V_{t^{\prime}} 的计算,而这个其实可以通过矩阵乘法实现,有:

exp(W)×(exp(K)V)\exp(\mathbf W)\times(\exp(\mathbf K)\odot\mathbf V)

实际上不难得出它的复杂度和 Transformer 一样也是O(T2d)\mathcal O(T^2d) ,不过本身的速度还是相比 Transformer 有所优化。而且这个二次复杂性的根源在于矩阵wt,:\boldsymbol w_{t,:}
为了化简,乃至引入局部信息能力,可以设置一个窗口大小s<Ts\lt T 来约束位置编码:

wt,t={wt,t, if tt<s0, otherwise w_{t, t^{\prime}}= \begin{cases}w_{t, t^{\prime}}, & \text { if }\left|t-t^{\prime}\right|<s \\ 0, & \text { otherwise }\end{cases}

这也就是 AFT-local。

🤔这不就类似于 Local Attention 吗?

AFT-conv

作者进一步扩展到局部权重共享,即卷积方法上,用以使用视觉任务。具体来说,将w\boldsymbol w 作为可学习参数(其实就是卷积核)与K\bf K 做卷积,同时又保证之前的全局计算保留下来,有:

Yti=σq(Qti)conv(exp(Ki)Vi,exp(wi)1)+t=1Texp(Kti)Vticonv(exp(Ki),exp(wi)1)+t=1Texp(Kti)\mathbf Y_t^i=\sigma_q\left(\mathbf Q_t^i\right) \odot \frac{\operatorname{conv}\left(\exp \left(\mathbf K^i\right) \odot \mathbf V^i, \exp \left(\boldsymbol w^i\right)-1\right)+\sum_{t^{\prime}=1}^T \exp \left(\mathbf K_{t^{\prime}}^i\right) \odot \mathbf V_{t^{\prime}}^i}{\operatorname{conv}\left(\exp \left(\mathbf K^i\right), \exp \left(\boldsymbol w^i\right)-1\right)+\sum_{t^{\prime}=1}^T \exp \left(\mathbf K_{t^{\prime}}^i\right)}

这里的上标ii 在 Transformer 中被理解成ii-th multi-head,在视觉任务中可以理解成ii-th channel。

AFT-simple

如果去掉位置编码的话,我们会有形如Y=σq(Q)(softmax(K)V)\mathbf Y=\sigma_q(\mathbf Q)\odot\sum\big(\text{softmax}(\mathbf K)\odot\mathbf V\big) 形式的表达,这其实与线性注意力进化之路的思想如出一辙。这便是AFT的极简版本。在这个版本中,上下文缩减进一步简化为逐元素操作和全局池化。

RWKV-4

RWKV(读作 RWaKuV)是一种具有 GPT 级大型语言模型(LLM)性能的RNN,也可以像 GPT Transformer 一样直接训练(可并行化)。项目最初由彭博(Bo Peng ,BlinkDL)提出,随着项目被外界关注,RWKV 项目逐渐发展成一个开源社区。

RWKV 模型架构有 RWKV-4、RWKV-5(代号 Eagle)、RWKV-6(代号 Finch) 三个正式版本,更早以前的 RWKV 1/2/3 为实验版本。
其中,RWKV-4 是 RWKV 模型的首个正式版本,论文由 RWKV 作者彭博和 RWKV 社区共同完成,初次发表于 2023 年 5 月 22 日。

RWKV 架构的名称来源于时间混合通道混合块中使用的四个主要模型参数,分别如下:

  • RRReceptance,作为过去信息的接受程度的接受向量
  • WWWeight,位置权重衰减向量,可训练的模型参数
  • KK:键(Key)是类似于传统注意力中 KK 的向量
  • VV:值(Value)是类似于传统注意力中 VV 的向量

上图是 RWKV-V4 论文中的模型架构概览,其中:

  • 左侧:RWKV-V4 的 time-mixing 和 channel-mixing 模块
  • 右侧:RWKV-V4 的语言建模流程

🏷️官方社区: RWKV 中国

Token Shift

与 Transformer 的 Self-attention 类似,RWKV 中的两个关键 mixing 模块也运用了三个从原始输入投影而来的变量(R,K,VR,K,V for time mixing;R,KR',K' for channel mixing)。不同的是,这些变量除了由当前时间步tt 的输入,还包含了上一时间步t1t-1 的输入(在LLM中就是当前读到的单词的 token)。

类似于 RNN 的遗忘门操作,设立了固定的参数μ\mu 来平衡上一时间步的信息需要保留多少。

rt=Wr(μrxt+(1μr)xt1),kt=Wk(μkxt+(1μk)xt1),vt=Wv(μvxt+(1μv)xt1)rt=Wr(μrxt+(1μr)xt1),kt=Wk(μkxt+(1μk)xt1)\begin{aligned} \boldsymbol{r_{t}} & =\mathbf W_{r} \cdot( \mu_{r} \odot\boldsymbol x_{t}+( 1-\mu_{r} ) \odot\boldsymbol x_{t-1} ), \\ \boldsymbol{k_{t}} & =\mathbf W_{k} \cdot( \mu_{k} \odot\boldsymbol x_{t}+( 1-\mu_{k} ) \odot\boldsymbol x_{t-1} ), \\ \boldsymbol{v_{t}} & =\mathbf W_{v} \cdot( \mu_{v} \odot\boldsymbol x_{t}+( 1-\mu_{v} ) \odot\boldsymbol x_{t-1} ) \\\\ \boldsymbol{r_{t}'} & =\mathbf W_{r} \cdot( \mu_{r}' \odot\boldsymbol x_{t}+( 1-\mu_{r}' ) \odot\boldsymbol x_{t-1} ), \\ \boldsymbol{k_{t}'} & =\mathbf W_{k} \cdot( \mu_{k}' \odot\boldsymbol x_{t}+( 1-\mu_{k}' ) \odot\boldsymbol x_{t-1} ) \end{aligned}

NOTE: Channel mixing 模块的rt\boldsymbol{r}_t'kt\boldsymbol{k}_t' 不是从原始的输入xt\boldsymbol{x}_t 来的,公式里的xt\boldsymbol{x}_t 实际上是由 Time mixing 模块输出后经过 LayerNorm 得到的。(详见模型架构图)

Time Mixing

Time mixing 模块主要是受理所有输入数据并捕捉到 token 之间的时间依赖关系。RWKV 中则是借鉴了 AFT 的方法(说白了就是利用一种带位置偏置的线性注意力方法)来捕捉这种关系,然后输出隐变量。

值得注意的是,RWKV 中对位置偏置{wt,i}RT×T\{w_{t,i}\}\in\mathbb R^{T\times T} 不再将其全部作为可学习参数了,而是只学习参数wR(0)w\in\mathbb R_{(\geq 0)} ,每个 channel(这里是指输入向量的每一个维度) 都有一个这样的参数,称其为 a channel-wise time decay vector。

将其乘以当前时间步tttoken 和之前的时间步ii 相对位置来进行衰减,即对于其中一个 channel 来说,有:

wt,i=(ti)ww_{t,i}=-(t-i)w

此外,为了避免位置偏置W\bf W 产生的可能的潜在退化,RWKV还多引入一个单独关注当前 token 的向量。对前t1t-1 个时间步做位置偏置衰减,然后再保留一个当前时间步的信息。

最终 RWKV 延续了将注意力以RNN推理的思路,给出了从第一个时间步(i=1i=1)开始到当前时间步的wkvt\boldsymbol{wkv}_t 操作:

wkvt=i=1t1exp(ki(ti)w)vi+exp(kt+u)vii=1t1exp(ki(ti)w)+exp(kt+u)\boldsymbol{wkv}_t=\frac{\sum_{i=1}^{t-1} \exp \left(\boldsymbol k_{i}-(t-i)w\right) \odot \boldsymbol v_{i}+\exp(\boldsymbol k_t+u)\odot \boldsymbol v_{i}}{\sum_{i=1}^{t-1} \exp \left(\boldsymbol k_{i}-(t-i)w\right)+\exp(\boldsymbol k_t+u)}

和 AFT 一样,要得到 Time mixing 的输出,只需要将wkvt\boldsymbol{wkv}_t 再乘上过了 Sigmoid 的σ(rt)\sigma(\boldsymbol r_t) 即可(再加一个线性投影):

ot=Wo(σ(rt)wkvt)\boldsymbol o_t=\mathbf W_o\big(\sigma(\boldsymbol r_t)\odot \boldsymbol{wkv}_t\big)

Channel Mixing

Channel mixing 模块因为在 Time mixing 模块之后使用,该过程对拿到的输入重新做了线性投影,并且引入了非线性激活,这些操作充分让每一个 token 内部的 d_model 维向量数据充分混合。

ot=σ(rt)(Wvmax(kt,0)2)\boldsymbol o_t'=\sigma(\boldsymbol r_t')\odot\big(\mathbf W_v'\cdot \max( \boldsymbol{k}_t',0)^2\big)

其中max(,0)2\max(\cdot,0)^2 即为 平方 ReLU。

Conclusion

RWKV 的优点:结合了 Transformer 和 RNN 的优势,训练时能够像 Transformer 那样并行计算,推理时又能像 RNN 那样高效。尤其是后者,对于降低模型成本,尤其是在端侧部署有重要意义。另外 RWKV 的计算量与上下文长度无关,对于更长的上下文可能有更好的扩展性。

RWKV 的缺点:和 RNN 一样,历史信息是靠隐状态(WKV)来记忆的,对于长距离历史信息的记忆不如 Transformer。这个很容易理解,因为 RWKV 的历史信息是存在一个向量里,时间越久衰减就越厉害,与 full attention 比自然是有局限性的。这个局限性也使得 prompt engineering 对 RWKV 更加重要。与 Transformer 相比,由于 RWKV 对很长的上下文记忆能力有限,如何设计提示词对模型的性能会有很大影响。

RWKV-5

RWKV-5 是 RWKV-4 架构的改良版本,版本代号“Eagle(鹰)”。

RWKV-5 和 RWKV-6 这两个架构在同一篇论文 《Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence》中发布。

该论文初次发表于 2024 年 4 月 9 日。同年 10 月,RWKV 5/6 架构论文被 LLM 领域顶级会议 COLM 2024 收录。

上图是论文中 RWKV 5/6 的架构概览,其中:

  • 左侧:RWKV 的 time-mixing 和 channel-mixing 模块
  • 右上角:作为 RNN 单元时的 RWKV time-mixing 模块,虚线箭头表示 RWKV-V6 架构的连接,但在 RWKV-5 中不存在
  • 底部的中间:前向传播模式下 RWKV-5 time-mixing 的 token-shift 模块
  • 右下角:前向传播模式下 RWKV-6 time-mixing 的 token-shift 模块

相对 RWKV-4, RWKV-5 的最重点改动在于引入了多头的基于矩阵值的状态(state),即论文中的 “multi-headed matrix-valued states”。

在 RWKV-4 架构的 time mixing 计算中,uuwwkkvv 参数都是维度为DD 的向量,而 head size 是 1 :

ttRWKV-4u,w,kt,vtRDu, w, k_t, v_t \in \mathbb{R}^D, head size 1
0σ(r0)uk0v0uk0\sigma(r_0) \odot \frac{u \odot k_0 \odot v_0}{u \odot k_0}
1σ(r1)uk1v1+k0v0uk1+k0\sigma(r_1) \odot \frac{u \odot k_1 \odot v_1 + k_0 \odot v_0}{u \odot k_1 + k_0}
2σ(r2)uk2v2+k1v1+wk0v0uk2+k1+wk0\sigma(r_2) \odot \frac{u \odot k_2 \odot v_2 + k_1 \odot v_1 + w \odot k_0 \odot v_0}{u \odot k_2 + k_1 + w \odot k_0}
3σ(r3)uk3v3+k2v2+wk1v1+w2k0v0uk3+k2+wk1+w2k0\sigma(r_3) \odot \frac{u \odot k_3 \odot v_3 + k_2 \odot v_2 + w \odot k_1 \odot v_1 + w^2 \odot k_0 \odot v_0}{u \odot k_3 + k_2 + w \odot k_1 + w^2 \odot k_0}

而 RWKV-5 则将uuww参数分别进行对角化,而kkvv 从维度为DD向量转化为维度为64×6464\times64矩阵, head size 大小改为固定的 64 。

RWKV-V 的 time-mixing 时间步:

ttEagle (RWKV-5)diag(u)\text{diag}(u),diag(w)\text{diag}(w),ktk_t,vtR64×64v_t \in \mathbb{R}^{64 \times 64} for each head, head size 64
0r0(diag(u)k0v0)r_0 \cdot (\text{diag}(u) \cdot k_0^\top \cdot v_0)
1r1(diag(u)k1v1+k0v0)r_1 \cdot (\text{diag}(u) \cdot k_1^\top \cdot v_1 + k_0^\top \cdot v_0)
2r2(diag(u)k2v2+k1v1+diag(w)k0v0)r_2 \cdot (\text{diag}(u) \cdot k_2^\top \cdot v_2 + k_1^\top \cdot v_1 + \text{diag}(w) \cdot k_0^\top \cdot v_0)
3r3(diag(u)k3v3+k2v2+diag(w)k1v1+diag(w2)k0v0)r_3 \cdot (\text{diag}(u) \cdot k_3^\top \cdot v_3 + k_2^\top \cdot v_2 + \text{diag}(w) \cdot k_1^\top \cdot v_1 + \text{diag}(w^2) \cdot k_0^\top \cdot v_0)

RWKV-5 前向传播(推理过程)的 time-mixing 计算公式:

t=lerp(xt,xt1)W,{r,k,v,g}\square_t = \text{lerp}_{\square}(x_t, x_{t-1}) W_{\square}, \quad \square \in \{ r, k, v, g \}

w=exp(exp(ω))w = \exp(-\exp(\omega))

wkvt=diag(u)ktvt+i=1t1diag(w)t1ikiviR(D/h)×(D/h)wkv_t = \text{diag}(u) \cdot k_t^\top \cdot v_t + \sum_{i=1}^{t-1} \text{diag}(w)^{t-1-i} \cdot k_i^\top \cdot v_i \in \mathbb{R}^{(D/h) \times (D/h)}

ot=concat(SiLU(gt)LayerNorm(rtwkvt))WoRDo_t = \text{concat}(\text{SiLU}(g_t) \odot \text{LayerNorm}(r_t \cdot wkv_t)) W_o \in \mathbb{R}^D

NOTE: RWKV-5 的 Token Shift 和 RWKV-4 类似,是一个非常简单的线性插值(linear interpolation - lerp),且这个线性插值是数据无关的(data-independent),只由参数 μ\mu 决定当前 token 和前一个 token 混合到模型输入的占比。有:

lerp(a,b)=a+(ba)μ\text{lerp}_{\Box}(a, b) = a + (b - a) \odot \mu_{\Box}

其实和 RWKV-4 没区别.

通过将 RWKV-4 的向量变成矩阵,RWKV-5 的 state 计算从“基于向量”变成了“基于 64×64 的矩阵值”,即 “matrix-valued states” 。假设当前 RWKV 模型的维度是 512 ,则可以说有 512/64 = 8 个头 (八头×64 维),这就是 RWKV-5 的“多头-multi-headed” 概念。

因此,我们可以把 RWKV-5 的优化细节总结为:

  1. 消除了归一化项(RWKV-4 公式中的分母)
  2. 引入了矩阵值状态代替以往的向量值状态
  3. 引入了 Time-mixing 门控机制,即在 time-mixing 模块中添加额外的矩阵 WgW_g 和门控激活函数 SiLU 实现
  4. 引入了对角衰减矩阵,也就是将 uuww 向量参数分别进行对角化

通过这种方式,RWKV-5 巧妙地扩大了 state 的规模,使得 RWKV 模型有更好的记忆力和模型容量。

RWKV-6

RWKV-6 则借鉴了 Low-Rank Adaptation(LoRA)的技术,将 RWKV-4/5 中静态的参数μ\mu (学习完之后对所有的相邻步都采取同样的比例进行 Token Shift)替换成了动态的 LoRA,这个比率由每次输入的新旧数据的变化而变化,即它是数据依赖型的。

具体来说,RWKV-6 的 Token Shift 中的线性插值公式如下:

lora(x)=λ+tanh(xA)B\text{lora}_{\Box}(x) = \lambda_{\Box} + \tanh(x A_{\Box}) B_{\Box}

ddlerp(a,b)=a+(ba)lora(a+(ba)μx)\text{ddlerp}_{\Box}(a, b) = a + (b - a) \odot \text{lora}_{\Box}(a + (b - a) \odot \mu_x)

相对 ,RWKV-6 这种增强了数据依赖性的新型插值方法(data-dependent linear interpolation,ddlerp)有效地扩展模型的能力,每个通道分配的新旧数据量取决于当前和之前时间步骤的输入。

通俗地理解,这种动态递归机制/数据依赖性 使“重要的信息”可以有效地标记自身,以待在后续任务使用;而“不重要的信息”也可以标记自身,以减少或完全避免进入后续的数据流,从而为更重要的现有数据保留空间。

此外,如果某些信息对于特定任务没有用,那么动态递归机制可以允许这些信息预先被过滤掉。

然后,RWKV-6 也把原来的静态位置向量wtw_t 也换成 LoRA 式了:

dt=lorad(ddlerpd(xt,xt1))d_t = \mathrm{lora}_d( \mathrm{ddlerp}_d ( x_t, x_{t-1} ) )

wt=exp(exp(dt))w_t = \exp(-\exp(d_t))

因此wkvtwkv_t 的公式也得调整:

wkvt=diag(u)ktTvt+i=1t1diag(j=1i1wj)kiTviR(D/h)×(D/h){wkv}_{t} = \mathrm{diag}(u)\cdot k_{t}^\mathrm{T} \cdot v_{t} + \sum_{i=1}^{t-1} \mathrm{diag}\left(\bigodot_{j=1}^{i-1}w_{j}\right) \cdot k_{i}^\mathrm{T} \cdot v_{i} \in \mathbb{R}^{(D/h) \times (D/h)}

代码梳理

📃ChatRWKV/RWKV_in_150_lines.py at main · BlinkDL/ChatRWKV

代码截取时间:2024年12月9日14:44:35

此处梳理的是官方给出的 RWKV-4 简化 PyTorch 版本 ,并且是 slow version 的 RNN 纯推理版本,即并没有利用 CUDA 实现 GPU 并行加速的内容,此代码仅做推理演示使用

RWKV_in_150_lines.py

相关初始配置,如LM所需的 tokenizer等。其中,20B_tokenizer.json 为分词内容,在这里下载;RWKV-4-Pile-430M-20220808-8066为RWKV-4的预训练权重,在这里下载。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
from torch.nn import functional as F
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("20B_tokenizer.json")

args = types.SimpleNamespace()
args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066'
args.n_layer = 24
args.n_embd = 1024

context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
NUM_TRIALS = 3 #实验推理次数
LENGTH_PER_TRIAL = 100 #每次实验语言模型输出单词数量
TEMPERATURE = 1.0
TOP_P = 0.85

RNN模型类

这部分实现了在已知预训练权重(代码中通过 self.w 读取并存储)的情况下,使用RNN模式进行前向过程的模型代码。

为了统一管理,在代码实现上,RWKV传递的了 state 这么一个二维数组参数。每个时间步 i,它都存储了5个向量:

  1. state[5*i+0] 存储上一时间步的 Channel mixing 的输入变量xx'.
  2. state[5*i+1] 存储上一时间步的 Time mixing 的输入变量xx.
  3. state[5*i+2] 存储了 aa,即上一时间步的 a,是 RWKV-4 的 RNN 版本的中间变量,下同.
  4. state[5*i+2] 存储了 bb,即上一时间步的 b.
  5. state[5*i+2] 存储了 pp,即上一时间步的 p.

state 在 Time mixing 和 Channel mixing 操作结束后的值都会更新,等待下一时间步参与计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class RWKV_RNN(torch.jit.ScriptModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode

# map_location用于重定向,比如此前模型的参数是在cpu中,希望将其加载到cuda:0中
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
# 根据key来指定预训练权重的值
for k in w.keys():
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
else: w[k] = w[k].float() # convert to f32 type

self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])

def layer_norm(self, x, w):
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

@torch.jit.script_method
def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
state[5*i+0] = x
r = torch.sigmoid(rw @ xr) # rw, kw and vw 都是线性投影weight,下同
k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
return r * (vw @ k)

@torch.jit.script_method
def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
state[5*i+1] = x
r = torch.sigmoid(rw @ xr)
k = kw @ xk
v = vw @ xv

aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = time_first + k
qq = torch.maximum(pp, ww)
e1 = torch.exp(pp - qq)
e2 = torch.exp(ww - qq)
a = e1 * aa + e2 * v
b = e1 * bb + e2
wkv = a / b
ww = pp + time_decay
qq = torch.maximum(ww, k)
e1 = torch.exp(ww - qq)
e2 = torch.exp(k - qq)
state[5*i+2] = e1 * aa + e2 * v
state[5*i+3] = e1 * bb + e2
state[5*i+4] = qq
return ow @ (r * wkv)

def forward(self, token, state):
with torch.no_grad():
if state == None: #处理初始情况(还没有state)
state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity

x = self.w.emb.weight[token]
x = self.layer_norm(x, self.w.blocks[0].ln0)
for i in range(self.args.n_layer):
att = self.w.blocks[i].att
x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay,
att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
ffn = self.w.blocks[i].ffn
x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
ffn.time_mix_k, ffn.time_mix_r,
ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
return x.float(), state

Logits输出

该步骤将模型的输出 x.float()(即 out/logit)进行 softmax 处理,并根据参数 temperaturetop_p 随机选取出一个样本,并返回所选的 token 索引。

说人话就是,这一步正在根据已喂到的文本随机挑出下一个单词的索引,当然这个单词要真正打印出来,还需要使用 tokenizer 进行解码。

1
2
3
4
5
6
7
8
9
10
11
def sample_logits(out, temperature=1.0, top_p=0.8):
probs = F.softmax(out, dim=-1).numpy()
sorted_probs = np.sort(probs)[::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return out

NOTE:代码中的 token 这个变量一直是索引的存在,从 RWKV_RNN.forward() 中也能看出,是通过调用词嵌入 self.w.emb.weight[token] 才取出的向量表示x 的。对应地,sample_logits 的输出也是随机预测到的单词的索引。

LM推理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
for token in tokenizer.encode(context).ids:
init_out, init_state = model.forward(token, init_state)

for TRIAL in range(NUM_TRIALS):
print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
all_tokens = []
out_last = 0
out, state = init_out.clone(), init_state.clone()
for i in range(LENGTH_PER_TRIAL):
token = sample_logits(out, TEMPERATURE, TOP_P)
all_tokens += [token]
tmp = tokenizer.decode(all_tokens[out_last:])
if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
print(tmp, end="", flush=True)
out_last = i + 1
out, state = model.forward(token, state)
print('\n')

参考

  1. 免注意力Transformer (AFT):使用逐元素乘积而不是点积 - 鸽鸽的书房 - 博客园
  2. 颠覆Transformer?新语言模型算法RWKV解读
  3. 【手撕LLM-RWKV】重塑RNN 效率完爆Transformer - 知乎
  4. RWKV架构及历史 - RWKV文档