转载: Attention为什么要除以$\sqrt{d}$

摘要: Attention 的计算公式中 Attention $(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{\top}}{\sqrt{d}}\right)$ 为什么要除以 $\sqrt{d}$ ?

原文链接: https://mp.weixin.qq.com/s/3o0NgpFPKS1RNICNuMuygg


〓 Table of Contents 〓




来自原论文的分析

〓 ReTURN 〓

《Attention is All Your Need》的原论文给出了一个粗略的答案。

While for small values of $d_k$ the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of $d_k$ [3]. We suspect that for large values of $d_k$, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by $\frac{1}{\sqrt{d_k}}$.

当 $d_k$ 的值变大的时候,softmax 函数会导致梯度消失问题,因此设置了一个 softmax 的 temperature 来缓解这个问题。这里 temperature 被设置为了 $\sqrt{d_k}$ ,也就是乘上 $\frac{1}{\sqrt{d_k}}$.
延伸讨论:

  1. 为什么会导致梯度消失?
  2. 为什么是 $\sqrt{d_k}$, 有更好的值么?



梯度消失的原因

〓 ReTURN 〓

  1. 如果 $d_k$ 变大, $q \cdot k^{\top}$ 方差会变大。
  2. 方差变大会导致向量之间元素的差值变大。
  3. 元素的差值变大会导致 softmax 退化为 argmax, 也就是最大值 softmax 后的值为 1 , 其他值则为 0 。
  4. softmax 只有一个值为 1 的元素, 其他都为 0 的话, 反向传播的梯度会变为 0 , 也就是所谓的梯度消失。

$d_k$ 变大, $q \cdot k^{\top}$ 方差会变大。

假设 $q$ 和 $k$ 的向量长度为 $d_k$ ,均值为 0 , 方差为 1 。则 $q$ 和 $k$ 的点积的方差为:
$$
\begin{aligned}
\operatorname{var}\left[q \cdot k^{\top}\right] & =\operatorname{var}\left[\sum_{i=1}^{d_k} q_i \times k_i\right] \
& =\sum_{i=1}^{d_k} \operatorname{var}\left[q_i \times k_i\right] \
& =\sum_{i=1}^{d_k} \operatorname{var}\left[q_i\right] \times \operatorname{var}\left[k_i\right] \
& =\sum_{i=1}^{d_k} 1 \
& =d_k
\end{aligned}
$$

当 $d_k$ 变大时, 方差变大。

方差变大会导致向量之间元素的差值变大。

方差变大就是代表了数据之间的差异性变大。
看上去显然, 如果非要给出证明, 可以将这个问题换一个问题来侧面回答这个问题。
新的问题假设向量是通过独立同分布的数据采样出来的 $d_k$ 个数据, 那么这 $d_k$ 个数的最大值的期望是多少?

因为分布很多, 这里只给出最常用的正态分布的证明, 详细证明见:
http://www.gautamkamath.com/writings/gaussian_max.pdf

这里只给出结论如下:
Theorem 1. Let $Y=\max _{1 \leq i \leq n} X_i$, where $X_i \sim \mathcal{N}\left(0, \sigma^2\right)$ are i.i.d. random variables. Then
$$
\frac{1}{\sqrt{\pi \log 2}} \sigma \sqrt{\log n} \leq \mathbf{E}[Y] \leq \sqrt{2} \sigma \sqrt{\log n} .
$$

从期望的下界, 可以看出, 方差越大, 最大值的期望越大。同时还有个结论就是 $d_k$ 越大, 最大值的期望也越大。由于正太分布是对称的, 最小值就是最大值取负号。

所以方差变大, 数据分布的最大最小值的差值变大了, 也就从侧面证明了向量元素之间的差值变大了。

softmax 退化为 argmax

对于 softmax 函数中的每个分量 $\operatorname{softmax}\left(x_i\right)$, 我们可以写成:
$$
\operatorname{softmax}\left(x_i\right)=\frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}}
$$

当 $x_k$ 是最大的元素时, $e^{x_k}$ 会显著大于其他 $e^{x_i}$ (其中 $i \neq k$ ), 尤其是当这些 $x_i$ 和 $x_k$ 之间的差距变得非常大时。为了更清楚地看出这一点, 我们将 $x_i$ 的每个元素表示成最大元素 $x_k$ 减去一个差值 $\Delta_i$, 即 $x_i=x_k-\Delta_i$, 其中 $\Delta_k=0$ 。

因此, softmax 函数可以重写为:
$$
\operatorname{softmax}\left(x_i\right)=\frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}}=\frac{e^{x_k-\Delta_i}}{\sum_{j=1}^n e^{x_k-\Delta_j}}
$$

由于 $e^{x_k}$ 是公因子, 根据 $\exp$ 的性质, 可以提出来并且约掉:
$$
\operatorname{softmax}\left(x_i\right)=\frac{e^{x_k} e^{-\Delta_i}}{e^{x_k} \sum_{j=1}^n e^{-\Delta_j}}=\frac{e^{-\Delta_i}}{\sum_{j=1}^n e^{-\Delta_j}}
$$

当 $\Delta_i$ 非常大(即 $x_i$ 远小于 $x_k$ )时, $e^{-\Delta_i}$ 会接近于 0。因此, 除了 $\Delta_k=0$ 以外的所有项,其他项 $e^{-\Delta_j}$ 都会非常小, 可以忽略不计(其实都不用非常小, $e^{-5}=0.0067379 \ldots, e^{-10}=4.534-05$, 只要差值大于 10 , 就可以忽略不计了)。于是, 对于 $i=k$ :
$$
\operatorname{softmax}\left(x_k\right) \approx \frac{1}{1}=1
$$

而对于 $i \neq k$ :
$$
\operatorname{softmax}\left(x_i\right) \approx 0
$$

所以说当输入向量 $\mathbf{x}$ 的方差变得非常大时, softmax 函数将会趋近于将最大的元素赋值为 1 ,

而其他元素赋值为 0 , 也就是是 argmax 函数。用公式表示的话:
$$
\lim _{\operatorname{var}(\mathbf{x}) \rightarrow \infty} \operatorname{softmax}(\mathbf{x})=\operatorname{argmax}(\mathbf{x})
$$

所以方差变大时, softmax 函数会退化为 argmax 函数。
这里我们可以做个实验看一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import numpy as np

n = 10

x1 = np.random.normal(loc=0, scale=1, size=n)
x2 = np.random.normal(loc=0, scale=np.sqrt(512), size=n)
print('x1最大值和最小值的差值:', max(x1) - min(x1))
print('x1最大值和最小值的差值:', max(x2) - min(x2))

def softmax(x):
return np.exp(x) / np.sum(np.exp(x), keepdims=True)

def softmax_grad(y):
return np.diag(y) - np.outer(y, y)

ex1 = softmax(x1)
ex2 = softmax(x2)
print('softmax(x1) =', ex1)

print('softmax(x2) =', ex2)

其结果为:

1
2
3
4
5
6
7
x1最大值和最小值的差值: 1.8973472870218264
x1最大值和最小值的差值: 66.62254341144866
softmax(x1) = [0.16704083 0.21684976 0.0579299 0.05408421 0.16109133 0.14433417
0.03252007 0.05499126 0.04213939 0.06901908]
softmax(x2) = [4.51671361e-19 2.88815837e-21 9.99999972e-01 3.02351231e-17
3.73439970e-25 8.18066523e-13 2.78385563e-08 1.16465424e-29
7.25661271e-20 3.21813750e-21]

可以看出, 在方差为 $\sqrt{512}$ 的时候, softmax 只有第三个元素接近 1 , 其他都几乎为 0 .

softmax 什么情况下会梯度消失

我们来对 softmax 函数进行求导。定义 softmax 为
$$
\operatorname{softmax}\left(z_i\right)=\frac{\exp \left(z_i\right)}{\sum_j \exp \left(z_j\right)}
$$

假设我们有一个向量 $\mathbf{z}=\left[z_1, z_2, \ldots, z_n\right]$, softmax 函数的输出是一个向量 $\mathbf{y}=\left[y_1, y_2, \ldots, y_n\right]$ , 其中:
$$
y_i=\operatorname{softmax}\left(z_i\right)=\frac{\exp \left(z_i\right)}{\sum_{j=1}^n \exp \left(z_j\right)}
$$

我们需要计算 softmax 函数的导数, 即 $\frac{\partial y_i}{\partial z_k}$, 分为两种情况:

  1. 当 $i=k$
  2. 当 $i \neq k$

首先, 计算 $y_i$ 对 $z_k$ 的导数:

  1. 当 $i=k$ 时
    $$
    \frac{\partial y_i}{\partial z_i}=\frac{\partial}{\partial z_i}\left(\frac{\exp \left(z_i\right)}{\sum_j \exp \left(z_j\right)}\right)
    $$

使用商的导数法则, 我们得到:
$$
\frac{\partial y_i}{\partial z_i}=\frac{\exp \left(z_i\right) \sum_j \exp \left(z_j\right)-\exp \left(z_i\right) \exp \left(z_i\right)}{\left(\sum_j \exp \left(z_j\right)\right)^2}
$$

化简得到:
$$
\frac{\partial y_i}{\partial z_i}=\frac{\exp \left(z_i\right)\left(\sum_j \exp \left(z_j\right)-\exp \left(z_i\right)\right)}{\left(\sum_j \exp \left(z_j\right)\right)^2}=\frac{\exp \left(z_i\right)}{\sum_j \exp \left(z_j\right)}\left(1-\frac{\exp \left(z_i\right)}{\sum_j \exp \left(z_j\right)}\right)
$$

即:
$$
\frac{\partial y_i}{\partial z_i}=y_i\left(1-y_i\right)
$$
2. 当 $i \neq k$ 时
$$
\frac{\partial y_i}{\partial z_k}=\frac{\partial}{\partial z_k}\left(\frac{\exp \left(z_i\right)}{\sum_j \exp \left(z_j\right)}\right)
$$

同样使用商的导数法则, 我们得到:
$$
\frac{\partial y_i}{\partial z_k}=\frac{0 \cdot \sum_j \exp \left(z_j\right)-\exp \left(z_i\right) \exp \left(z_k\right)}{\left(\sum_j \exp \left(z_j\right)\right)^2}=-\frac{\exp \left(z_i\right) \exp \left(z_k\right)}{\left(\sum_j \exp \left(z_j\right)\right)^2}
$$

即:
$$
\frac{\partial y_i}{\partial z_k}=-y_i y_k
$$

两种情况合并一下

将两种情况合并, softmax 的导数可以表示为:
$$
\frac{\partial y_i}{\partial z_k}=y_i\left(\delta_{i k}-y_k\right)
$$

其中, $\delta_{i k}$ 是 Kronecker delta 函数, 定义为:
$$
\delta_{i k}= \begin{cases}1, & \text { if } i=k \ 0, & \text { if } i \neq k\end{cases}
$$

最终可以用 Jacobian 矩阵表示, Jacobians 矩阵的第 $i$ 行和第 $k$ 列元素是 $\frac{\partial y_i}{\partial z_k}$ :
$$
\mathbf{J}=\left[\begin{array}{cccc}
y_1\left(1-y_1\right) & -y_1 y_2 & \cdots & -y_1 y_n \
-y_2 y_1 & y_2\left(1-y_2\right) & \cdots & -y_2 y_n \
\vdots & \vdots & \ddots & \vdots \
-y_n y_1 & -y_n y_2 & \cdots & y_n\left(1-y_n\right)
\end{array}\right]
$$

然后与第三步的世界线交汇,好玩的来了

在第三步中我们证明了当方差变大的时候, softmax 退化成了 argmax, 也就是变成一个只有一个 1 其他全为 0 的向量。

这个向量带入到上面的雅可比矩阵会发生什么? 我们发现对于任意的 $y_k=1, y_{j \neq k}=0$ 的向量来说, 雅可比矩阵变成了一个全 0 矩阵。

也就是说梯度全为 0 了。到这里才算是证明了为什么 $q \cdot k^{\top}$ 的方差不能太大, 太大了就梯度消失。

梯度实验

我们同样做个实验,看看梯度到底为多少。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np

n = 10

x1 = np.random.normal(loc=0, scale=1, size=n)
x2 = np.random.normal(loc=0, scale=np.sqrt(512), size=n)
print('x1最大值和最小值的差值:', max(x1) - min(x1))
print('x1最大值和最小值的差值:', max(x2) - min(x2))

def softmax(x):
return np.exp(x) / np.sum(np.exp(x), keepdims=True)

def softmax_grad(y):
return np.diag(y) - np.outer(y, y)

ex1 = softmax(x1)
ex2 = softmax(x2)
print('softmax(x1) =', ex1)
print('max of gradiant of softmax(x1) =', np.max(softmax_grad(ex1)))
print('softmax(x2) =', ex2)
print('max gradiant of softmax(x2) =', np.max(softmax_grad(ex2)))

其结果为:

1
2
3
4
5
6
7
8
9
x1最大值和最小值的差值: 1.8973472870218264
x1最大值和最小值的差值: 66.62254341144866
softmax(x1) = [0.16704083 0.21684976 0.0579299 0.05408421 0.16109133 0.14433417
0.03252007 0.05499126 0.04213939 0.06901908]
max of gradiant of softmax(x1) = 0.1698259433168865
softmax(x2) = [4.51671361e-19 2.88815837e-21 9.99999972e-01 3.02351231e-17
3.73439970e-25 8.18066523e-13 2.78385563e-08 1.16465424e-29
7.25661271e-20 3.21813750e-21]
max gradiant of softmax(x2) = 2.7839373695215386e-08

可以看出,在方差为 的时候,长度仅仅为10的向量x2,其梯度就已经快没有了,最大值为2.78e-8。 而如果将方差控制在1,则最大的梯度为0.1698

scale 的值为什么是 $\sqrt{d_k}$, 有更好的值么?

从上一节的第一步的证明, 可以发现, scale 的值为 $\sqrt{d_k}$ 其实是把 $q \cdot k^{\top}$ 归一化成了一个均值为 0 , 方差为 1 的向量。

至于是不是最好呢? 不好说, 因为参数的分布我们不太清楚。苏神曾经试图求解了一些常用分布的最佳 scale 值, 感兴趣的可以看下:https://spaces.ac.cn/archives/9812

转载: Attention为什么要除以$\sqrt{d}$

https://nerozac.com/2024/06/02/Attention为什么要除以根号d/

作者

Jiawei Li

发布于

2024-06-02

更新于

2024-06-02

许可协议