问题背景

我在推导注意力机制公式:

Attention(Q,K,V)=Softmax(QKTdk)VAttention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})V
Si=ezijezjS_i=\frac{e^{z_i}}{\sum_j{e^{z_j}}}

发现除dk这一项在很多文章中指出是【预防梯度消失】,但大多没有说明原因;

推导过程

通过计算梯度(商法则),得到公式:

i对i的梯度

sizi=Si(1Si)\frac{\partial{s_i}}{\partial{z_i}}=S_i{(1-S_i)}
sizj=SiSj\frac{\partial{s_i}}{\partial{z_j}}=-S_i{S_j}

注意到,当注意力分数QKT项较大时,如果没有缩放dK,则会导致其中1项(指数最大的哪一项)为1,其他输出项Si非常小,约等于0;

根据梯度计算公式计算得到

Zi输入:[100, 90, 80, 70]
Softmax输出 s: [1.00000000e+00 3.72007598e-05 1.38389653e-09 5.14820022e-14]
s之和: 1.0000000000000002

梯度矩阵 J:
[[ 0.00000000e+00 -3.72007598e-05 -1.38389653e-09 -5.14820022e-14]
 [-3.72007598e-05  3.71933795e-05 -5.14820020e-14 -1.91568262e-18]
 [-1.38389653e-09 -5.14820020e-14  1.38389653e-09 -7.12486585e-23]
 [-5.14820022e-14 -1.91568262e-18 -7.12486585e-23  5.14820022e-14]]

梯度范数: 5.258048925507793e-05  # 几乎为零!