问题背景
我在推导注意力机制公式:
发现除dk这一项在很多文章中指出是【预防梯度消失】,但大多没有说明原因;
推导过程
通过计算梯度(商法则),得到公式:
i对i的梯度
注意到,当注意力分数QKT项较大时,如果没有缩放dK,则会导致其中1项(指数最大的哪一项)为1,其他输出项Si非常小,约等于0;
根据梯度计算公式计算得到
Zi输入:[100, 90, 80, 70]
Softmax输出 s: [1.00000000e+00 3.72007598e-05 1.38389653e-09 5.14820022e-14]
s之和: 1.0000000000000002
梯度矩阵 J:
[[ 0.00000000e+00 -3.72007598e-05 -1.38389653e-09 -5.14820022e-14]
[-3.72007598e-05 3.71933795e-05 -5.14820020e-14 -1.91568262e-18]
[-1.38389653e-09 -5.14820020e-14 1.38389653e-09 -7.12486585e-23]
[-5.14820022e-14 -1.91568262e-18 -7.12486585e-23 5.14820022e-14]]
梯度范数: 5.258048925507793e-05 # 几乎为零!