用 PyTorch 写一个 NanoGPT (3): 归一化和残差连接

仓库链接: https://github.com/Davidwadesmith/NanoGPT

本项目参照 NanoGPT,实现了一个简单易用的 GPT 语言模型,适合用于学习和实验生成式预训练变换器 (GPT) 的核心原理。


归一化:LayerNorm / RMSNorm

归一化是指将一组样本的均值变为 $0$、方差变为 $1$。LayerNorm 在隐藏层维度做归一化,使每个 token 的特征分布更稳定。

在 LayerNorm 之前常见的是 BatchNorm。BatchNorm 倾向于要求“同一特征在不同 token 上”分布稳定,但在 LLM 里会遇到一些实际问题:训练时有 padding token,推理时 batch 往往很小,且 BatchNorm 不太利于算子融合。因此在隐藏层维度做归一化通常更合适。

LayerNorm 的形式为:

它将向量方向和长度的调节解耦:归一化部分约束数值尺度,$\gamma, \beta$ 保留可学习的缩放和平移能力。常见初始化是 $\gamma = 1, \beta = 0$。

RMSNorm可以看作去掉均值项的 LayerNorm:

直观上,很多 Transformer 激活在统计上已接近零均值,显式减均值带来的收益有限;只要用 RMS 把尺度压住,就可以显著改善梯度稳定性,同时减少计算量。

Pre-LN 梯度有界(形式化推导)

Theorem: Gradient Norm Bound in Pre-LN Architecture

本节旨在从形式化角度证明,在 Transformer 的 Pre-LN 架构中,随着网络层数 $L$ 的加深,其反向传播梯度的期望范数被严格有界约束,从而在理论上规避了梯度爆炸问题。

Definition 1 (Pre-LN Transformer Layer)

设 $x_l \in \mathbb{R}^d$ 为第 $l$ 层的输入表征,Pre-LN 架构的前向传播形式化定义为:

其中,$f_l: \mathbb{R}^d \to \mathbb{R}^d$ 表示残差分支(如 Multi-Head Attention 或 FFN),$\text{LN}(\cdot)$ 表示层归一化操作。

Definition 2 (Layer Normalization)

对于任意输入向量 $x \in \mathbb{R}^d$,层归一化操作定义为:

其中,均值 $\mu = \frac{1}{d}x^T\mathbf{1}$,方差 $\sigma^2 = \frac{1}{d}|x - \mu\mathbf{1}|^2_2$。为分析初始化阶段的优化动力学,假设仿射变换参数初始化为 $\gamma = \mathbf{1}$ 且 $\beta = \mathbf{0}$。

Lemma 1 (Jacobian of Layer Normalization)

归一化算子 $\text{LN}(x)$ 对输入 $x$ 的雅可比矩阵(Jacobian Matrix)$J_{\text{LN}}(x) \in \mathbb{R}^{d \times d}$ 解析式为:

其中 $\tilde{x} = \frac{x - \mu\mathbf{1}}{\sigma}$,$I$ 为单位矩阵。

Corollary 1.1: 雅可比矩阵 $J{\text{LN}}(x)$ 的谱范数与其输入的标准差严格成反比,即 $|J{\text{LN}}(x)|_2 \propto \mathcal{O}(\frac{1}{\sigma})$。

Assumption 1 (Variance Accumulation at Initialization)

假设在模型初始化阶段,输入 $x_0$ 与所有残差分支 $f_l$ 的权重服从独立同分布。对于任意标准正态输入 $y$,满足 $\mathbb{E}[f_l(y)] = 0$ 且 $\text{Var}(f_l(y)) = \eta^2$。

由于 $\text{LN}(x_l)$ 恒定输出标准差为 $1$ 的表征,残差路径在各层的输出方差具有可加性:

据此可导出,第 $l$ 层输入表征的标准差满足 $\sigma_l \propto \sqrt{l}$。

Proof of the Theorem

设 $\mathcal{L}$ 为全局损失函数。根据多元微积分的链式法则,损失 $\mathcal{L}$ 对第 $l$ 层输入 $x_l$ 的梯度流可表示为:

Definition 1 中的前向传播方程求导,可得单层的雅可比矩阵:

其中 $J{f_k} = \frac{\partial f_k(y)}{\partial y}\big|{y=\text{LN}(x_k)}$。

将其代入全局梯度流连乘方程中:

为分析上述连乘积的数值稳定性,考察扰动项 $J{f_k} J{\text{LN}}(x_k)$ 的谱范数。由柯西-施瓦茨不等式及 Corollary 1.1 可得:

引入 Assumption 1 中关于前向方差累加的结论 $\sigma_k \propto \sqrt{k}$,可得:

综上所述,反向传播的梯度连乘积由恒等映射主干 $I$ 与受动态阻尼系数 $\mathcal{O}(\frac{1}{\sqrt{k}})$ 约束的残差分支构成。由于扰动项的范数随网络深度 $k$ 的增加而严格衰减,连乘矩阵的谱半径被控制在 $\mathcal{O}(L^{c})$ ($c$ 为极小常数),彻底消除了单纯残差网络中 $\mathcal{O}(e^L)$ 的指数级增长。

因此,$\left| \frac{\partial \mathcal{L}}{\partial x_l} \right|_2$ 被动态收敛于一个安全的有界区间,确保证明成立。$\blacksquare$

无归一化残差网络中的梯度爆炸(形式化推导)

Theorem: Exponential Gradient Explosion in Unnormalized Residual Networks

本节旨在证明,在缺乏归一化算子的纯残差网络中,反向传播的梯度方差会随着网络层数 $L$ 的增加呈指数级(Exponential)爆炸,导致深层网络在理论上无法收敛。

Definition 1 (Unnormalized Forward Pass)

设 $x_l \in \mathbb{R}^d$ 为第 $l$ 层的输入表征。剥离 Layer Norm 后,纯残差网络的前向传播形式化定义为:

其中,$f_l: \mathbb{R}^d \to \mathbb{R}^d$ 为第 $l$ 层的非线性变换模块(如包含可学习权重 $W_l$ 的线性层及非线性激活函数)。

Lemma 1 (Undamped Jacobian Matrix)

对上述无归一化的前向方程求导,可得单层的雅可比矩阵:

其中 $J_{f_l}(x_l) = \frac{\partial f_l(x_l)}{\partial x_l}$。

关键差异: 由于失去了 Layer Norm 中 $\frac{1}{\sigma}$ 的缩放系数,这里的输入 $xl$ 未经任何尺度重置。在标准的参数初始化(期望为 $0$,方差为 $\eta^2$)下,雅可比矩阵 $J{f_l}(x_l)$ 的元素方差完全独立于层数 $l$。因此,其期望谱范数不再具有随深度衰减的性质,即:

Proof of the Theorem

设全局损失函数为 $\mathcal{L}$。根据链式法则,损失 $\mathcal{L}$ 对初始输入 $x_0$ 的反向传播梯度流为:

为衡量梯度的离散程度,我们考察梯度向量在反向传播过程中的协方差矩阵(Covariance Matrix)的迹(Trace),即梯度的期望二范数平方 $\mathbb{E}\left[ \left| \frac{\partial \mathcal{L}}{\partial x_l} \right|^2_2 \right]$。

假设在初始化阶段,各层的雅可比矩阵 $J_{f_l}$ 元素均值为 $0$,方差为 $v^2$,且不同层之间相互独立。当梯度从第 $l+1$ 层传至第 $l$ 层时,其方差的递推关系为:

展开矩阵乘法,由于交叉项期望为 $0$,递推式简化为:

其中 $d$ 为隐藏层维度,$(1 + d \cdot v^2) > 1$ 是一个大于 $1$ 的严格常数放大因子。

将此递推关系从顶层 $L$ 展开至最底层 $0$:

在数学上,上述结果等价于:

综上所述,由于残差分支缺少 $\mathcal{O}(\frac{1}{\sqrt{l}})$ 的动态阻尼项,微小的局部梯度方差在连乘效应下被不断放大,最终整个网络的梯度范数随层数 $L$ 呈现 $\mathcal{O}(e^L)$ 的指数级增长,必然导致数值溢出(NaN)与训练崩溃。证明完毕。$\blacksquare$


残差连接

残差连接本质上也是一个和梯度稳定性相关的优化。由于底层参数的梯度计算涉及很长的连乘,梯度容易衰减;加入残差后,连乘中的项会更接近恒等映射,从而提升稳定性。

从函数拟合角度看,每层学习“在恒等映射上的增量”通常比直接学习一个全新映射更容易。同时,残差结构可以近似看作不同深度子网络的集成,常常会让最终拟合函数更平滑。

残差连接改善梯度传播(推导)

在没有残差连接的普通深层网络(Plain Network)中,假设第 $L$ 层的输出为 $x_L$,通过链式法则,损失函数 $\mathcal{L}$ 对底层参数 $w_l$ 的梯度为:

对于普通网络 $x{i+1} = \sigma(w_i x_i)$,雅可比矩阵 $\frac{\partial x{i+1}}{\partial x_i}$ 的模长通常小于 $1$。当 $L$ 很大时,连乘项会呈指数级衰减,导致梯度消失

残差网络的改进:

引入残差结构 $x_{i+1} = F(x_i, w_i) + x_i$ 后,梯度变为:

其中 $I$ 是单位矩阵。代入链式法则:

结论: 即使中间层的梯度 $\frac{\partial F}{\partial x}$ 极小,由于单位阵 $I$ 的存在,梯度仍可直接回传到浅层,梯度范数不易无限制衰减。

总结

这一部分的核心结论是:归一化(尤其 Pre-LN / RMSNorm)负责控制数值尺度,残差连接负责提供稳定的梯度通路。两者组合在一起,基本构成了现代 Transformer 稳定训练的底座。