用 PyTorch 写一个 NanoGPT (1): Attention 模块

仓库链接: https://github.com/Davidwadesmith/NanoGPT

本项目参照 NanoGPT, 实现了一个简单易用的 GPT 语言模型. 适合用于学习和实验生成式预训练变换器 (GPT) 的核心原理.

项目结构

本项目采用模块化设计, 便于扩展和实验.

顶级目录

  • train.py: 模型训练的主入口脚本.
  • eval.py: 模型评估与文本生成脚本.
  • setup.py: 安装脚本 (依赖清单).
  • README.md: 项目说明文档与路线图.

源码模块 (src/)

核心逻辑位于 src/ 目录下.

  • src/model.py: 定义 Transformer 架构 (Attention, FFN, LayerNorm).
  • src/dataloader.py: 负责数据加载与预处理.
  • src/tokenizer.py: 处理文本的分词逻辑.
  • src/config.py: 配置文件. 集中管理所有超参数和路径设置.

测试 (tests/)

包含单元测试以确保组件可靠性.

  • tests/test_attention.py: 验证 Attention 机制的输出形状.

Step 1: Attention 模块

概览

首先实现 Attention class, 这个 class 继承自 nn.Module, 这使得它自动成为一个 PyTorch 的神经网络模块, 方便进行参数的统一管理和操作.

每个模块中, forward 函数是很重要的一个方法, 它就是这个模块在前向传播时调用的函数. forward 函数需要明确输入、输出以及中间的处理环节, 其中输入输出在这里就是隐藏层的 hidden state, 形状是 $(B, S, D)$, 即 (batch_num, seq_len, hidden_dim).

明确形状后就可以开始操作了, 我们回忆 Attention 的公式:

可以看到应该先把 $Q$、$K$、$V$ 这三个处理出来, 对于最原始的 Attention 层来说, 我们可以直接让 input 经过三个线性层分别得到 $Q$、$K$、$V$. 因此先在 __init__() 中声明:

1
2
3
self.wq = nn.Linear(hidden_dim, hidden_dim, bias=False, device=cfg.device)
self.wk = nn.Linear(hidden_dim, hidden_dim, bias=False, device=cfg.device)
self.wv = nn.Linear(hidden_dim, hidden_dim, bias=False, device=cfg.device)

之后调用这三个函数即可.

计算 Scaled Dot-Product

得到 $Q$、$K$、$V$ 后就可以轻松进行矩阵运算了:

1
2
3
4
5
attention_map = (
q # (batch_size, head_n, seqlen, head_dim)
@ torch.transpose(k, -2, -1) # (batch_size, head_n, head_dim, seqlen)
/ ((self.cfg.hidden_dim // self.cfg.head_n) ** 0.5)
) # (batch_size, head_n, seqlen, seqlen)

attention_map 就是 $\frac{QK^T}{\sqrt{d_k}}$, 你可以看到有对 head 进行处理, 不过我们可以暂且认为 self.cfg.hidden_dim // self.cfg.head_n 就是 $d_k$.

可以看到我们把 $K$ 的后两个维度置换来对应 Attention 公式中的转置, 这里也可以用 torch.permute() 进行计算, torch.permute() 会更灵活一些, 可以同时对多个维度进行调换的操作. 并且要注意维度变换操作是不保证连续性的 (contiguity), 这也是为什么后面对矩阵形状进行操作的时候使用 reshape 的原因, reshape 可以自动处理连续性.

Causal Masking (因果掩码)

由于我们构造的是自注意力的 GPT 架构, 因此需要使用因果掩码 (Causal Mask). 简单来说就是第 $i$ 个 token 对应的 $Q_i$ 是看不到它之后的 $K_j$ ($i<j<\text{seqlen}$) 的.

对应到代码上, 相应位置上的 attention score 应该是 0, 也就消除了相应位置的 attention score 对 $V_i$ 的影响. 要实现这个效果需要对刚才我们得到的 attention_map 进行 mask 处理, 也就是“遮住”相应的值, 具体来说:

1
2
3
self.mask = (torch.tril(torch.ones(self.cfg.seqlen, self.cfg.seqlen)) == 0).to(
torch.device(self.cfg.device)
) # (seqlen, seqlen)

这里先构建了一个全 1 的 seqlen 长宽的矩阵, 截取下三角, 上三角变成 0, 再进行布尔操作让上三角变成 True, 下三角为 False, 这个掩码将会作用在 attention_map 上:

1
2
3
masked_attention_map = torch.masked_fill(
attention_map, self.mask, float('-inf')
) # match -1, -2 dim

Mask 为真的地方会被填入负无穷 (-inf), 这里的 torch.masked_fill 默认从后向前匹配两个维度, 所以这里就正好是 -1, -2 维进行操作. 填入负无穷的好处在于 softmax 后得到的矩阵每一行都是归一化的, 即加起来为 1:

1
2
3
masked_attention = (
self.softmax(masked_attention_map) @ v
) # (batch_size, head_n, seqlen, head_dim)

如果 softmax 后才加掩码, 每一行加起来就不是 1 了.


下一期说 Multi-head 是如何实现的.