Sustie

主页 所有文章 文章检索

Attention层的简单讲解

输入与输出的维度

Attention 层是用来处理序列输入的,所以它接受的输入是一个二维的n×Nin \times N_i的矩阵,而输出则是一个n×Non \times N_o的矩阵。nn是输入序列的 token 个数,NiN_i则是 embedding 向量的长度,NoN_o则是输出的状态向量的长度。理论上 Attention 层对于nn没有限制,而NiN_iNoN_o则是模型的超参数。

Attention 层的超参数

除了NiN_iNoN_o以外,Attention 层还有其他超参数。

一个 Attention 层包含 3 个全连接层,分别叫做 Query、Key 和 Value,用符号QQKKVV表示。这 3 个全连接层的输入维度都是NiN_i,而输出维度都是超参数,分别为DqD_qDkD_kDvD_v。稍后会看到,Attention 层要求Dq=DkD_q=D_k以及Dv=NoD_v=N_o

综上所述,Attention 层的超参数有NiN_iNoN_oDqD_q

推理过程

我们将输入的向量依次记作x1x_1x2x_2……xnx_n,同理输出的向量记作y1y_1y2y_2……yny_n

推理的时候,首先要将输入向量都经过全连接层一次,得到Q(x1)Q(x_1)K(x1)K(x_1)V(x1)V(x_1)等等。最终的输出是由VV的结果加权求和得到的:

y_i = \sum_j \alpha_{ij} V(x_j), \forall i

因为αij\alpha_{ij}是权重,所以自然要满足:

\sum_j \alpha_{ij} = 1,\forall i

αij\alpha_{ij}是根据QQKK的结果得到的:

\alpha^\prime_{ij} = \frac{Q(x_i) \cdot K(x_j)}{\sqrt{D_q}}, \forall i, j

这里我用的符号是α\alpha^\prime而非α\alpha,因为α\alpha^\prime不满足权重和为 1 的要求,因此不是最终结果。要让和为 1,使用 softmax 层处理一下即可。

(计算α\alpha要除以Dq\sqrt{D_q}是为了归一化。)