第 4 章 model.py 精读·上:RMSNorm、RoPE、GQA Attention¶
本章覆盖 model.py 中除 MoE 与 DecoderLayer 之外的"底层零件":
Linear(自定义的hk.Linear子类)RMSNorm与hk_rms_normrotate_half与RotaryEmbeddingMultiHeadAttentionmake_attention_mask
读完本章你应该能在脑子里把"一次 attention 调用"全程跑一遍:输入是 [B, T, d],输出是 [B, T, d] + 更新后的 KV cache,中间经过哪些 reshape、哪些 RoPE 角度、哪些 partition 约束。
本章的写作策略是"每一段先把概念讲清楚,再贴代码,再做 PyTorch 对照"。普通 MLE 读到 JAX/Haiku 那几行陌生 API 时,往往会因为不熟悉而跳过;本章会把每一处 hk.get_parameter、with_sharding_constraint、jax.lax.dynamic_update_slice_in_dim 都对应到 PyTorch 里你已经熟悉的写法,让你能用"翻译"的方式把代码读下去。
Transformer 块的标准结构
在进入 Grok-1 的具体代码之前,先把一个标准 decoder-only Transformer 块的结构在脑中过一遍。每一层(也叫 block)做两件事:
- 自注意力子层(self-attention sub-layer):把输入
x经过Norm → Attention → 残差加回,让每个 token 与同一序列里的其他 token 交换信息。 - 前馈子层(FFN sub-layer):把上一步的输出再经过
Norm → FFN → 残差加回,让每个 token 单独做一次非线性变换。MoE 模型把这里的稠密 FFN 换成"路由器 + 多个专家"。
输入 embedding 经过 L 层这样的 block 之后,再做一次 Norm,最后乘 unembedding(通常和输入 embedding 共享权重)得到 vocab 上的 logit。Grok-1 的 L = 64,每层 attention 子层用 GQA,FFN 子层用 8-expert top-2 MoE,归一化用 sandwich norm。本章只看 attention 子层里的底层零件,FFN/MoE 留到第 5 章,整体的 sandwich 装配留到第 6 章。
KV cache(Key/Value Cache)
自回归生成时每生成一个新 token,都要让它的 query 与前面所有 token 的 key/value 做 attention。如果每步都重算所有历史 token 的 K、V,复杂度是 O(T²),T 一长就吃不消。KV cache 的做法是把每层 attention 算过的 K、V 存下来,下一步只算新 token 的 K、V 然后追加到 cache 末尾,attention 就退化成一次 [1, T] × [T, d] 的乘法。
Grok-1 的 KV cache 形状是 [B, T=8192, num_kv_heads=8, key_size=128],每层一份,64 层加起来 batch=1 时约 2 GB;GQA 把 KV 头数从 48 砍到 8 直接给 cache 减重 6 倍。
PyTorch 这边对应的实现是 HuggingFace transformers 里的 DynamicCache 或 StaticCache,本质都是一组按层组织的 K/V 张量。Grok-1 因为是 JAX,KV cache 是显式的 KVMemory dataclass(详见第 6 章),每一步推理都要把它作为参数传进去再接住返回值,不能像 PyTorch 那样隐式挂在 module 上。
4.1 Linear:内置 8-bit 反量化¶
model.py:525-584。Grok-1 自己写了一个 Linear,继承 hk.Linear,主要为了:
- 支持
QuantizedWeight8bit- 权重存 int8 + per-channel fp32 scale - 让 sharding 直接写在构造参数里
# model.py:525-584
class Linear(hk.Linear):
def __init__(
self,
output_size: int,
with_bias: bool = True,
sharding: Optional[P] = None,
mesh: Any = None,
name: Optional[str] = None,
shard_axis: int = 0,
):
super().__init__(
output_size=output_size,
with_bias=with_bias,
name=name,
)
self.sharding = sharding
self.mesh = mesh
self.shard_axis = shard_axis
def __call__(self, inputs: jax.Array) -> jax.Array:
fprop_dtype = inputs.dtype
...
w = hk.get_parameter(
"w", [input_size, output_size], jnp.float32, init=hk.initializers.Constant(0)
)
if hasattr(w, "scales"):
shape = inputs.shape
inputs = jnp.reshape(inputs, (-1, shape[-1]))
@functools.partial(
shard_map,
mesh=self.mesh,
in_specs=(self.sharding, self.sharding),
out_specs=self.sharding,
check_rep=False,
)
def mul(w, s):
return w.astype(s.dtype) * s
w = mul(w.weight, w.scales)
out = jnp.dot(inputs, w.astype(fprop_dtype))
if self.with_bias:
b = hk.get_parameter(
"b", [self.output_size], jnp.float32, init=hk.initializers.Constant(0)
)
b = jnp.broadcast_to(b, out.shape)
out = out + b.astype(fprop_dtype)
return out
关键观察:
hasattr(w, "scales")是一个"运行时鸭子类型"检查。hk.get_parameter在加载 ckpt 时会把参数替换成QuantizedWeight8bit实例(这个 dataclass 注册为 pytree node 在model.py:47-51),它就有.scales字段。如果没量化,直接 dense matmul- 反量化在
shard_map内做 - 这样每个 device 只乘自己那一片的 scale,省一次 all-gather - 反量化后的
wdtype 与 scales 一致(fp32),然后再 cast 到 fprop dtype 做 matmul
hk.get_parameter
Haiku 的参数获取入口,对应 PyTorch 里的 self.weight = nn.Parameter(...) 加 self.weight 访问。差别是 PyTorch 把参数挂在 module 实例上,谁需要谁直接拿;Haiku 是函数式的,参数其实存在外部的 params pytree 里,调用 hk.get_parameter("w", shape, dtype, init) 时框架会去当前 hk.transform 的上下文中按 module name + 参数名查这棵树,首次调用做 init、后续调用返回已有值。
hk.transform 之后参数全部展平为 {"module_name/param_name": value} 的扁平字典,加载 ckpt 时只要把 ckpt 里同名键塞进去就行。这种设计的副作用是参数名完全由 module name 决定,重命名 module 会导致 ckpt 加载失败 - 第 7 章会展开讲 Grok-1 的 ckpt key 与 module name 是怎样精确对应的。
shard_map
JAX 里写显式 SPMD 程序的 API。一段被 shard_map 包起来的函数会在每张 device 上并行执行一次,每张 device 只看到自己那一片的局部张量,所有跨 device 通信(all-reduce、all-gather 等)要靠 jax.lax.psum、jax.lax.all_gather 等 collective 显式写出来。
对照 PyTorch:shard_map 类似 torch.distributed 加上手写 all_reduce 的组合,但形状管理由 JAX 自动处理 - 你写的函数签名是"局部 shape",框架自动把全局张量切片后送进来。pjit(参见第 1 章 1.6.2)是另一个极端,它让 XLA 编译器自己决定通信怎么插;shard_map 是"手写并行",更底层但表达力强。Grok-1 这里的 Linear 反量化用 shard_map 是因为 weight 已经按 sharding 切到各 device,scale 也按相同方式切了,元素级乘法在每张 device 上完全独立,不需要任何跨 device 通信,写成 shard_map 既清晰又零通信开销。
与 Llama2 实现的区别:Llama2 用 nn.Linear 直接、没有内置量化支持;想要 int8/int4 需要靠 bitsandbytes 或 GPTQ 后处理。Grok-1 是训练时就为量化推理做好了一等公民支持。
这种"训练时就为量化做好支持"的设计,在 2023 年还不算主流。当时 GPTQ、AWQ 这些训练后量化方法刚成熟,业界标准做法是先训 bf16,再量化。Grok-1 反其道而行 - 让 Linear 内置反量化路径,这暗示 xAI 在训练 / 推理边界做了量化感知训练(QAT)或者至少量化感知推理的设计。
ckpt 里的 QuantizedWeight8bit 大概率是这样产生的:训练 bf16 → 训练结束做 per-channel int8 量化(每个 output channel 一个 fp32 scale)→ 保存为 (weight: int8, scales: fp32) 二元组。这种"per-channel 静态量化"是工业界最稳的 int8 方案,质量损失通常 <0.5% PPL。
per-channel vs per-tensor 量化
int8 量化最关键的选择是 scale 的粒度。per-tensor 是整张权重共用一个 scale(一个 fp32 数),最省存储但精度最差,因为权重不同行的数值范围常常差好几个数量级,统一缩放会把小值压成零。per-channel 是每个 output channel(在 Linear 中对应权重矩阵的每一行 / 每一列)配一个 scale,存储多一个 [output_size] 的 fp32 向量,但每个 channel 内部的数值都在差不多的量级,量化误差小。
Grok-1 的 QuantizedWeight8bit.scales 形状就是 [output_size] 这一档,对应 per-channel 静态量化。再细一档是 per-group(每 128 个 channel 一个 scale,AWQ 默认设置)、per-token activation 量化(动态)等等,那些方案在 2024 年后逐渐流行,但 Grok-1 的 ckpt 仍然是最稳的 per-channel 静态方案。
4.1.1 QuantizedWeight8bit¶
model.py:37-51:
# model.py:37-51
@dataclass
class QuantizedWeight8bit:
weight: jnp.array
scales: jnp.array
@property
def shape(self):
return self.weight.shape
tree_util.register_pytree_node(
QuantizedWeight8bit,
lambda qw: ([qw.weight, qw.scales], ()),
lambda _, children: QuantizedWeight8bit(children[0], children[1]),
)
weight 是 int8(实际 dtype 由 ckpt 决定)、scales 是 fp32 或 bf16。register_pytree_node 把它变成 JAX 可以遍历的 pytree。这样 jax.tree_map(...) 等操作能自动展开。
run.py:17 重导出为 QW8Bit,但 run.py 没真的用到 8-bit - 默认是全精度 bf16 加载。
bf16 / fp32 混合精度
bf16 是 16 位浮点,1 位符号 + 8 位指数 + 7 位尾数,动态范围和 fp32 一致但精度低。大模型推理用 bf16 做矩阵乘是因为它能让 H100/A100 的 tensor core 跑满,吞吐是 fp32 的 8 倍多,显存也减半。但 bf16 的尾数只有 7 位,在 softmax、norm、累加这种"小数差异要保留"的地方会丢精度。
工程做法是"大头 bf16,敏感点 fp32":参数和 activation 在 bf16 下流动,遇到 RMSNorm 把输入 cast 到 fp32 算完再 cast 回来,attention logit cast 到 fp32 做 softmax 再 cast 回来,bias 也用 fp32 维护。Grok-1 这一套精度策略贯穿全文 - Linear 的 fprop_dtype = inputs.dtype 和 RMSNorm 的 inputs.astype(jnp.float32) 都是这个套路的局部体现。
4.2 RMSNorm:fp32 中间计算¶
model.py:587-624:
# model.py:587-624
class RMSNorm(hk.RMSNorm):
def __init__(
self,
axis: Union[int, Sequence[int], slice],
eps: float = 1e-5,
name: Optional[str] = None,
create_scale: bool = True,
sharding: Optional[P] = None,
):
super().__init__(axis, eps, create_scale=create_scale, name=name)
self.sharding = sharding
def __call__(self, inputs: jax.Array):
fprop_dtype = inputs.dtype
param_shape = (inputs.shape[-1],)
if self.create_scale:
scale = hk.get_parameter(
"scale",
param_shape,
dtype=jnp.float32,
init=hk.initializers.Constant(0),
)
if self.sharding:
scale = with_sharding_constraint(scale, self.sharding)
scale = jnp.broadcast_to(scale.astype(jnp.float32), inputs.shape)
else:
scale = 1.0
inputs = inputs.astype(jnp.float32)
scale = scale.astype(jnp.float32)
mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)
mean_squared = jnp.broadcast_to(mean_squared, inputs.shape)
normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps)
outputs = scale * normed_inputs
return outputs.astype(fprop_dtype)
数学上:
RMSNorm(Root Mean Square Norm)
LayerNorm 做两件事:先把激活减去均值、除以标准差,再乘一个可学的 scale 加一个 bias。RMSNorm 把减均值和加 bias 都砍掉,只留"除以 RMS 再乘 scale"这一步,RMS 是 sqrt(mean(x^2)),比标准差好算一点。
省掉减均值这一步的依据来自原 paper 的消融:transformer 里"中心化"对最终质量的贡献几乎可以忽略,真正起作用的是"按尺度归一化"。RMSNorm 的训练曲线和 LayerNorm 几乎重合,但参数和 FLOPs 都少一截。在 LLaMA 之后这套基本是新 dense / MoE base 模型的默认 norm 写法,看到 LayerNorm 反而要追究一下是不是有别的考量。
把 RMSNorm 的公式逐步拆出来:
设输入张量某个 token 位置的 hidden vector 为 \(x = (x_1, x_2, \dots, x_d) \in \mathbb{R}^d\),其中 \(d = 6144\)。
第一步:计算均方(mean square)。 把 \(d\) 个分量平方求平均:
代码里对应的是 mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True),沿最后一维(hidden 维)求平均,输出形状从 [B, T, d] 变成 [B, T, 1]。
第二步:加 epsilon 防止开方为零,再求倒数平方根。 公式上是 \(\frac{1}{\sqrt{\text{MS}(x) + \epsilon}}\),代码里写成 jax.lax.rsqrt(mean_squared + self.eps) - rsqrt 是"reciprocal square root",硬件上有专门的快指令,比先 sqrt 再 1/x 快一倍。
第三步:按位归一化。 把每个分量除以这个倒数平方根的逆向缩放,等价于乘以 rsqrt:
代码里对应 normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps),注意这里没有减均值,所以和 LayerNorm 的核心差别就在这一步。
第四步:乘可学习的逐维 scale \(\gamma\)。 \(\gamma \in \mathbb{R}^d\) 是个 [d] 形状的参数:
整张公式压在一起:
PyTorch 里 LLaMA 风格的 RMSNorm 写法基本一样:
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x):
input_dtype = x.dtype
x = x.to(torch.float32)
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x.to(input_dtype)
PyTorch 版直接用 torch.rsqrt,JAX 版用 jax.lax.rsqrt,逻辑一一对应。最主要的差别是 PyTorch 把 scale 挂在 module 实例上(self.weight),Haiku 用 hk.get_parameter 从外部 params 树里取。
注意几件事:
- eps = 1e-5(默认)- 与 Llama2 一致
- 整个计算上 fp32,最后才 cast 回 fprop dtype(bf16)。这是数值稳定的标配
scale默认 init 是 Constant(0) - 仅做加载占位,真实值从 ckpt 读。注意 RMSNorm 的 scale 在训练时一般初始化为 1,如果真用 0 初始化训练,整个 layer 输出永远是 0。这印证了 init 仅是占位- 没有 mean 减法、没有 bias - 与原始 LayerNorm 不同。这是 RMS 而非 mean-variance norm
hk_rms_norm(model.py:489-496)是个 wrapper:
# model.py:489-496
def hk_rms_norm(
x: jax.Array,
fixed_scale=False,
sharding=P(None),
) -> jax.Array:
"""Applies a unique LayerNorm to x with default settings."""
ln = RMSNorm(axis=-1, create_scale=not fixed_scale, sharding=sharding)
return ln(x)
fixed_scale=False 默认会创建 scale 参数。Grok-1 的所有 RMSNorm 都有可学习 scale。
4.2.0 为什么 RMSNorm 而非 LayerNorm¶
LayerNorm 计算均值再减、计算方差再除:
RMSNorm 省略了均值减法和 bias:
理论分析(RMSNorm paper)发现:LayerNorm 的"均值减"步骤对模型质量的贡献很小,把它从 norm 操作中删除之后,在多个语言建模任务上 perplexity 的回退都不到 0.1 PPL,但同时可以节省约 15% 的 norm 计算量。在大模型时代这种"质量几乎不变、计算明显减少"的改动几乎没有理由不采用,所以从 LLaMA 起,所有主流 base 模型都把 LayerNorm 换成了 RMSNorm。
bias 项 \(\beta\) 同理 - 把它从 norm 中删除对质量影响极小,但每个 norm 因此少 6144 个可学参数(在 Grok-1 中 64 层 × 每层 4 个 norm,总计可省 1.5M 参数)。这个数字在 314B 总参里几乎可以忽略,但它体现的是一种"凡是能省的参数都尽量省"的设计倾向。
还有一个少被提到、但同样重要的理由:RMSNorm 对张量并行更友好。LayerNorm 算均值时需要在 hidden 维度上做求和,如果 hidden 维度被切到多张卡上(张量并行的标准切法),求均值需要一次 all-reduce 才能拿到正确的全局均值;RMSNorm 同样需要在 hidden 维度上做平方和,也需要 all-reduce,但它只需要一次通信,而 LayerNorm 要做"均值 → 减均值 → 方差 → 除标准差"两轮 reduce。在 64 层规模下这种通信成本累加起来不可忽视,省一半 reduce 是真金白银的吞吐提升。
Grok-1 的张量切分(参见 model.py:142-149 的 partition rules)确实把 hidden 维沿 model 轴切了 8 份,所以 RMSNorm 的"少一次 reduce"在它的部署 setup 下是真实的工程收益。
4.2.1 与 LLaMA-2 / Mistral 的对比¶
- LLaMA-2 同样使用 RMSNorm,其中 scale 参数在训练开始时初始化为 1,eps 设为 1e-6 防止除以 0
- Mistral / Mixtral 在 RMSNorm 配置上完全沿用了 LLaMA-2 的两项选择,scale 初值与 eps 都不变
- Grok-1 把 eps 默认值放大到 1e-5,比 LLaMA-2 大一个量级;推理代码里 scale 的初始化只用
Constant(0)占位,实际数值从 ckpt 加载,所以"训练时是不是从 1 开始"无从直接判断 - Grok-1 每层 RMSNorm 数量是 LLaMA-2 的两倍:LLaMA-2 一层只有 attention 前和 FFN 前共 2 个 norm,Grok-1 在两个子层的前后各放一个 RMSNorm,每层 4 个 RMSNorm,对应 sandwich norm 布局
第 6 章会展开这个 sandwich norm。
残差流(Residual Stream)
transformer 里每个子层(attention、FFN/MoE)的输出不是替换掉输入,而是和输入相加:x_new = x + sublayer(x)。这条"一直被加东西的主线张量"就叫 residual stream。从输入 embedding 一路传到最后一层 norm 前,中间每个子层只是往这条流上贡献自己的"修正量"。
这个视角对理解 sandwich norm 很关键:sandwich norm 在乎的不是子层内部的数值,而是"加进残差流的那一份量级要可控"。Grok-1 的残差流维度是 6144,要从 64 层、每层 4 个 RMSNorm 的连续累加里保持稳定,所以 xAI 把每个子层输出都先 RMSNorm 一次再加回去,等价于强制每次贡献的量级都在 ~1 附近。
RoPE(Rotary Position Embedding)
早期 Transformer 用 sinusoidal 或学得的绝对位置 embedding,做法是把一组位置向量直接加在 token embedding 上,让位置信息和内容信息共享同一组维度。RoPE 换了个思路:位置不再加在 embedding 上,而是在 attention 里对 Q、K 做一组按位置变化的旋转。第 m 个 token 的 Q 向量按 m 角度转、第 n 个 token 的 K 按 n 角度转,做内积时旋转量自动相减,最后 attention score 只看 Q 和 K 的相对位置差,绝对位置自然消掉。
这个改动看起来只是数学技巧,但实践收益很明显:不再有"绝对位置参数"这种东西,外推到训练时没见过的长度上不会立刻失效;旋转是逐 head 内部做的,不挤占模型其他维度。在 LLaMA 之后这套基本是 dense 与 MoE base 模型的默认位置编码方案,差别只剩下 base 取多大、要不要做长度外推插值这些二阶问题。
为什么 RoPE 比绝对位置编码好
绝对位置编码(无论是 sinusoidal 还是 learned)有几个根本痛点:
- 外推性差。learned absolute position 在训练时见过的最大位置是 4096,到 8192 的位置完全没学过,模型行为不可预测;sinusoidal 至少能算出 8192 位置的向量,但它从未参与训练,模型不知道该怎么用。
- 位置信息和内容信息共享维度。把 position vector 加在 embedding 上,等于让前几个维度同时承担"我是第几个 token"和"我是哪个 token"两件事,两类信号在 Q/K 投影时会相互干扰。
- 只有绝对位置,没有相对位置。但 attention 真正在意的几乎永远是"我距离那个 token 多远",绝对位置只是为了能计算相对位置而存在的中间产物。
RoPE 一次性解决了这三件事:把位置信息放进"旋转"这个独立的几何变换里,不挤占任何 hidden 维度;Q、K 做完点积之后只剩下角度差(相对位置),绝对位置自动消失;外推时只要旋转角度公式不变,新位置照样能算(虽然太远会有频率混叠,需要 base 调整或 YaRN 之类的插值修正)。
PyTorch 这边对应的实现在 HuggingFace transformers.models.llama.modeling_llama.LlamaRotaryEmbedding,逻辑和 Grok-1 几乎一样:预计算 inv_freq、按位置生成 cos / sin、用 apply_rotary_pos_emb 把它们应用到 Q、K 上。如果你读过 LLaMA 的 PyTorch 实现,再看 Grok-1 的 JAX 版本会觉得非常熟悉。
4.3 RoPE:rotate_half 与 RotaryEmbedding¶
model.py:627-691。
# model.py:627-633
def rotate_half(x: jax.Array) -> jax.Array:
"""Obtain the rotated counterpart of each feature"""
x1, x2 = jnp.split(x, 2, axis=-1)
return jnp.concatenate((-x2, x1), axis=-1)
注意 jnp.split(x, 2, axis=-1) 把最后一维前半 / 后半两块。这是 "GPT-NeoX 风格" 的 RoPE 排列,不是 "interleaved" 风格。
| 风格 | 拆分方式 | 代表实现 |
|---|---|---|
| GPT-NeoX | x = [x_first_half, x_second_half] 然后 rotate |
Grok-1, GPT-J, GPT-NeoX, LLaMA, Mistral, Mixtral |
| Interleaved | x = [x_0, x_1, x_2, x_3, ...] 偶数维度配对 |
原始 RoPE 论文, RoFormer |
这两种风格的 RoPE 数值不等价 - ckpt 是用 NeoX 风格训练的,加载到 interleaved 实现会得到错误结果。这是社区迁移 ckpt 时最常见的坑之一。
4.3.1 RotaryEmbedding¶
# model.py:635-691
class RotaryEmbedding(hk.Module):
def __init__(
self,
dim: int,
name: Optional[str] = None,
base_exponent: int = 10000,
):
super().__init__(name)
self.dim = dim
self.base_exponent = base_exponent
assert self.dim % 2 == 0
def __call__(
self,
x: jax.Array,
seq_dim: int,
offset: jax.Array,
const_position: Optional[int] = None,
t: Optional[jax.Array] = None,
) -> jax.Array:
fprop_dtype = x.dtype
exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
inv_freq = jnp.asarray(
1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32
)
if jnp.shape(offset) == ():
offset = jnp.expand_dims(offset, 0)
if const_position:
t = const_position * jnp.ones(
(1, x.shape[seq_dim],), dtype=jnp.float32,
)
elif t is None:
t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)
phase = jnp.einsum("bi,j->bij", t, inv_freq)
phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]
x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
x = x.astype(fprop_dtype)
return x
逐段:
Line 665-668:频率表
exponents = [0, 2, 4, ..., d_h - 2],除以 dim = d_h = 128,得到 64 个频率。
Line 670-684:每个位置的相位
offset 是 KV cache 中已经缓存的 token 数。t 是从 0 开始的相对位置 + offset。在 prefill 阶段 offset = 0,t = [0, 1, ..., T-1];在 decode 阶段 offset = cur_len, t = [cur_len]。
phase 形状是 [B, T, d_h/2],然后 jnp.tile(phase, reps=(1, 2)) 复制成 [B, T, d_h](前后两半相同),再加上 head 维度变成 [B, T, 1, d_h]。
Line 688:旋转应用
这等价于对每个 (dim_i, dim_{i+d/2}) 对做 2D 旋转。
这一行写成纯数学就是:对配对的两个分量 \((x_i, x_{i+d/2})\)(i 在前一半),构造一个 2x2 旋转矩阵
把它作用到列向量 \((x_i, x_{i+d/2})^T\) 上。展开:
把这些等式按 i 排起来,正好等价于 x * cos(phase) + rotate_half(x) * sin(phase):rotate_half 把前后两半互换并对前半取负号,让加法的两路恰好对应旋转矩阵的两行。
PyTorch 里 LLaMA 系列的对应实现长这样:
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
rotate_half 与 Grok-1 完全一致;apply_rotary_pos_emb 等价于 Grok-1 的最后一行。所以 LLaMA、Mistral、Mixtral、Grok-1 在 RoPE 上是同一份代码,只是写在不同框架里。
4.3.1.1 GPT-NeoX 风格还是 interleaved,到底有什么实质区别¶
很多人不理解为什么"前后半"和"奇偶交错"会让 RoPE 数学等价但实现不可互换。我们来推导。
设 head_dim = 4,hidden vector \(x = (x_0, x_1, x_2, x_3)\),频率 inv_freq = \((\omega_0, \omega_1)\),位置 \(t\):
Interleaved(原始 RoPE 论文):
把 (0, 1) 当做一个 2D 复数旋转、(2, 3) 当做另一个 2D 复数旋转。
NeoX(Grok-1):
注意配对的不是 (0,1) 而是 (0, 2) - 即 (前半第 i 位, 后半第 i 位) 配对。
这两种风格在 Q @ K 内积上是等价的吗?
数学上是等价的 - 它们都实现"位置 t 的旋转矩阵作用于配对维度"。但具体的 Q、K 向量值不同,因为坐标排列不同。
后果:
- 一个用 NeoX 风格训练的 ckpt 里,Q 的 dim 0 对应"位置无关分量",dim 2 对应它的"配对位置感知分量"
- 如果你拿这个 ckpt 用 interleaved 实现做 forward,会把 dim 0 和 dim 1 配对 - 完全错乱
这是个坐标系问题,不是数学问题。两种实现都"对",但 ckpt 不能跨。
4.3.2 base = 10000¶
base_exponent = 10000 是 RoPE 的"标准"基数。这意味着:
- 最高频维度:周期约 \(2\pi\) token
- 最低频维度(i = d_h/2 - 1):周期约 \(2\pi \cdot 10000^{(d_h - 2)/d_h} \approx 2\pi \cdot 8500\) token
在 sequence_len = 8192 的范围内,最低频维度甚至没有完成一个完整周期 - 距离 RoPE 频率发生混叠的临界点还有相当大的余量。所以 Grok-1 沿用标准 base 10000 是合理的,没有必要像 LLaMA-3.1 那样将 base 提升到 500000 来支持更长的上下文外推。
RoPE 的频率谱与 base 的含义
把 RoPE 想成一组"信号采样"会更直观。每个 head 的 \(d_h = 128\) 维被分成 64 对,每对配一个频率 \(\omega_i = \text{base}^{-2i/d_h}\)。base = 10000、\(d_h = 128\) 时,频率从最高的 \(\omega_0 = 1\) 一直递减到最低的 \(\omega_{63} \approx 10000^{-126/128} \approx 1.17 \times 10^{-4}\),呈现指数下降。
每个频率对应一个周期 \(T_i = 2\pi / \omega_i\):最高频每隔约 6.28 token 转一圈,最低频每隔约 53500 token 才转一圈。在 8192 上下文长度下,最低频的整个周期都还没走完一半 - 也就是说,"距离很远的两个 token"在低频通道里旋转角度差不到 \(\pi\),仍然能被 attention 区分。
一旦上下文长度变成 32k、128k,base = 10000 就不够了:最低频的旋转角度差会超过 \(\pi\) 进入"折回"区域,模型分不清 "距离 30k" 和 "距离 60k" 是同一个角度。LLaMA-3.1 把 base 提到 500000,相当于把所有频率整体降低 50 倍,让低频通道在 128k 范围内还能保持单调,外推到长上下文不会立刻失效。Grok-1 因为只到 8k,标准 base 完全够用。
4.3.3 与 Llama / Mistral 的对比¶
| 项 | Grok-1 | LLaMA-2 | Mistral 7B | Mixtral 8x7B |
|---|---|---|---|---|
| 风格 | NeoX | NeoX | NeoX | NeoX |
| base | 10000 | 10000 | 10000 | 1000000 |
| 应用位置 | 在 KV cache update 之前 | 一致 | 一致 | 一致 |
| 应用对象 | Q + K(不对 V) | 一致 | 一致 | 一致 |
注意 Mixtral 8x7B 把 RoPE base 提升到了 1000000,目的是支持 32k 上下文的外推 - base 越大、低频维度的周期越长,能覆盖的位置范围也越大。Grok-1 的上下文上限只有 8k,标准 base 10000 已经足够覆盖。
4.4 注意 RoPE 在 KV cache 更新前后的位置¶
model.py:800-844 是 attention 调用里的关键段:
# model.py:800-844
rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
key_heads = rotate(key_heads, seq_dim=1, offset=(kv_memory.step if kv_memory else 0))
query_heads = rotate(query_heads, seq_dim=1, offset=(kv_memory.step if kv_memory else 0))
@functools.partial(jax.vmap)
def update_into(mem, start, update):
return jax.lax.dynamic_update_slice_in_dim(mem, update, start, axis=0)
if kv_memory:
if mesh is not None:
@functools.partial(
shard_map,
mesh=mesh,
in_specs=(
P("data", None, "model"),
P("data"),
P("data", None, "model"),
),
out_specs=P("data", None, "model"),
check_rep=False,
)
def update_into_shmap(mems, starts, updates):
return update_into(mems, starts, updates)
key_heads = update_into_shmap(kv_memory.k, kv_memory.step, key_heads)
value_heads = update_into_shmap(kv_memory.v, kv_memory.step, value_heads)
else:
key_heads = update_into(kv_memory.k, kv_memory.step, key_heads)
value_heads = update_into(kv_memory.v, kv_memory.step, value_heads)
new_step = kv_memory.step + sequence_length
memory_mask = jnp.arange(kv_memory.k.shape[1]) < new_step[:, None]
memory_mask = memory_mask[:, None, None, :] # [B, H, T, T]
if mask is not None:
mask = memory_mask * mask
else:
mask = memory_mask
new_memory = KVMemory(
k=key_heads,
v=value_heads,
step=new_step,
)
执行顺序:
- 投影 Q/K/V(在
model.py:774-799) - 对 K 和 Q 先应用 RoPE(
model.py:801-803) - 然后把新 K/V 写回 KV cache(
model.py:826-830) - 构造 memory mask(已写入的位置才能被 attend)
注意 RoPE 应用在 KV cache update 之前 - 这意味着 cache 里存的是已经过 RoPE 的 K。这是和 LLaMA 一致的做法。
update_into 用 jax.lax.dynamic_update_slice_in_dim 把新 K/V 写到 kv_memory.step 指定的位置。jax.vmap 把它沿 batch 维 vmap,因为不同 batch 元素可能有不同的 step。
dynamic_update_slice / dynamic_slice
JAX 在 jit 编译时要求所有 tensor 的形状是静态的,但写 KV cache 这种"把新算好的 [B, 1, H, D] 段塞到 [B, 8192, H, D] 大 cache 的某个动态位置"操作,slice 的 start index 是个运行时变量。普通的 Python 切片 cache[:, step] 在 jit 下不允许。
jax.lax.dynamic_update_slice_in_dim(mem, update, start, axis=0) 就是 JAX 提供的"静态形状、动态起点"的 in-place 写入原语,对应的读取是 dynamic_slice。底层 XLA 编译时会把它编成一个带 offset 参数的 memcpy。Grok-1 用它沿 axis=1(seq 维)把新 token 的 K/V 写入 kv_memory.step 指定的位置,再通过 jax.vmap 让 batch 中每个元素的 step 可以不同。
memory_mask:构造一个 [T] 的 0/1 mask,"位置 < 当前 step" 的为 1。然后 expand 成 [B, 1, 1, T],再和 causal mask 相乘。
causal mask(因果遮罩)
自回归语言模型要求位置 i 的 token 只能看到位置 ≤ i 的 token,否则训练时模型会"偷看"未来的答案。实现方式是给 attention 的 [T, T] logit 矩阵叠一个下三角 mask:保留下三角和对角线(i ≥ j 处为 1),上三角(i < j 处)置成 -1e30 或 -inf,softmax 后就变成 0。
Grok-1 在 Transformer.__call__ 里用 jnp.tril(jnp.ones((1, 1, seq_len, seq_len))) 直接构造一个下三角 mask,再和 padding mask 相乘得到最终 [B, 1, T, T] 的 mask。decode 阶段每步只有 1 个新 query,causal mask 退化成 [1, 1],真正起作用的是 memory_mask(只允许 attend 到 cache 里已写入的位置)。
4.4.1 mask 矩阵直观示意¶
prefill 阶段(一次性处理整段输入 prompt),attention 的 logit 矩阵形状是 [T, T],causal mask 是一个下三角矩阵:
K_0 K_1 K_2 K_3 K_4 K_5
Q_0 [ 1 0 0 0 0 0 ]
Q_1 [ 1 1 0 0 0 0 ]
Q_2 [ 1 1 1 0 0 0 ]
Q_3 [ 1 1 1 1 0 0 ]
Q_4 [ 1 1 1 1 1 0 ]
Q_5 [ 1 1 1 1 1 1 ]
矩阵里 1 的位置 attention 正常计算,0 的位置在 softmax 之前会被加上 -1e30 强行抹零。可以读出来 Q_i 只能 attend K_0..K_i 这一段历史,对应于"位置 i 的 token 生成时不能看到未来"这条因果约束。
decode 阶段(自回归生成第 T+1 个 token),attention 矩阵退化成 [1, T+1] - 只有当前这一行需要算。KV cache 里已经存了过去 0..T 时刻所有 token 的 K/V,新 query Q_{T+1} 直接与整个 cache 做一次 (1, T+1) 的乘法。这时候 causal mask 已经简化成全 1 的 [1, 1] 矩阵(query 看自己当然允许),真正起作用的是 memory_mask,它由当前 batch 各样本的 step 计数器算出:step 之前的位置是有效 K/V,step 及之后的位置是没写入的 pad,必须屏蔽。
graph LR
Q["Q new<br/>位置 T+1 的 query"]
K0["K/V 位置 0"]
K1["K/V 位置 1"]
Kd["...已写入的中间位置..."]
KT["K/V 位置 T"]
PAD["K/V 位置 T+1..T_max<br/>未写入<br/>memory_mask 屏蔽"]
A["attention scores<br/>形状 1 x T+1"]
SM["softmax + value 聚合"]
O["attention output"]
Q --> A
K0 --> A
K1 --> A
Kd --> A
KT --> A
PAD -. 0 mask .-> A
A --> SM --> O
这张图也解释了为什么 KV cache 一旦预分配出来([B, T_max, H_kv, d_h]),剩下的 decode 步骤只是把 step 指针往前推、把对应位置的 K/V 写进去,cache 的总形状不变 - 这是 JAX 静态形状要求的工程妥协,第 6.6 节会再展开。
4.5 MultiHeadAttention.__call__ 全程¶
model.py:720-911 是 attention 主体。我们已经看过 RoPE + cache 部分(4.4),现在看 attention 本身的计算。
4.5.0 GQA 的工程动机¶
GQA(Grouped Query Attention)
标准 multi-head attention 里 Q 头数等于 KV 头数,每个 head 都有自己独立的一组 K、V。一旦头数和层数变多,推理时 KV cache 占的显存会膨胀到很难接受 - 因为 cache 大小正比于 层数 × 头数 × seq_len × head_dim,每一项都是几十量级。GQA 的思路是把 Q 头分成若干组,组内共享一组 K、V:比如 48 个 Q 头分 8 组,每组 6 个 Q 共用同一组 K、V,cache 立刻只剩原来的 1/6。
这背后是一种"不对称的表达力假设" - 模型对"提问的方向"需要保留 multi-head 的多样性,但同一份 K/V 上挂多组 Q 一起去检索,质量损失很小。MQA(所有 head 共享 1 组 K/V)是这个思路的极端版,质量明显下降;GQA 介于 MHA 和 MQA 中间,是 LLaMA-2 70B 起几乎所有大模型 attention 的默认选项,只是各家在分组比例上小有差异。
GQA(Grouped Query Attention)是 MHA(Multi-Head Attention)和 MQA(Multi-Query Attention)的折中。
- MHA:标准多头注意力,每个 head 都拥有独立的 Q、K、V 投影矩阵,attention 质量在三种方案中最高,但 KV cache 的体积与 head 数成正比,在长上下文场景下显存压力最大
- MQA:所有 head 共享同一组 K/V(即
num_kv_heads = 1),只有 Q 仍按 head 拆分。KV cache 体积是 MHA 的 \(1/H\),在三种方案中最小,但因为所有 head 看到的是同一份 K/V,表达能力明显下降,在大模型上往往伴随可观的质量回退 - GQA:把 Q 头分成若干组,组内共享同一份 K/V。组数介于 MHA(组数 = 头数)和 MQA(组数 = 1)之间,KV cache 体积和模型质量也都落在两者之间
对于 Grok-1,64 层 × 8192 token × KV cache:
- MHA(如果 num_kv_heads = 48):64 × 8192 × 48 × 128 × 2 (k+v) × 2 bytes = 12 GB / batch
- GQA 48:8:64 × 8192 × 8 × 128 × 2 × 2 = 2 GB / batch
- MQA:64 × 8192 × 1 × 128 × 2 × 2 = 256 MB / batch
Grok-1 选用 6:1 比例的 GQA(48 Q 头对 8 KV 头),KV cache 相对完全 MHA 节省约 6 倍存储,相对 MQA 则大 8 倍。这是在 attention 质量和 cache 体积之间取的一个中间平衡点。
LLaMA-2 70B 采用的是 8:1 GQA(64 Q : 8 KV),Mixtral 8x7B 采用的是 4:1 GQA(32 Q : 8 KV)。Grok-1 的 6:1 介于这两者之间,KV 头数同为 8,但 Q 头数比 LLaMA-2 少、比 Mixtral 多。
MQA(Multi-Query Attention)
标准 MHA 给每个 attention head 都配独立的 K、V 投影,head 数等于 query 头数。Noam Shazeer 在 2019 发现一个简化方案:让所有 head 共用同一组 K、V(num_kv_heads = 1),只有 Q 还按 head 分。这样 KV cache 缩小到原来的 1/num_heads,长序列推理时显存压力骤减,但因为所有 head 看的是同一份 K/V,表达能力明显下降。
GQA 是 MQA 和 MHA 的折中 - 把 head 分组,组内共享 K/V。Grok-1 选 48 Q : 8 KV 的 6:1 GQA,KV cache 比 MQA 大 8 倍但比 MHA 小 6 倍,质量回到接近 MHA 的水平。
4.5.1 Q/K/V 投影¶
# model.py:773-799
assert self.num_q_heads % self.num_kv_heads == 0
query_heads = projection(
query, self.key_size, self.num_q_heads, name="query",
sharding=P("data", "model"), mesh=mesh,
) # [B, T', H, Q=K]
key_heads = projection(
key, self.key_size, self.num_kv_heads, name="key",
sharding=P("data", "model"), mesh=mesh,
) # [B, T, H, K]
value_heads = projection(
value, self.value_size, self.num_kv_heads, name="value",
sharding=P("data", "model"), mesh=mesh,
) # [B, T, H, V]
_linear_projection(model.py:893-911):
# model.py:893-911
def _linear_projection(
self, x, head_size, num_heads, sharding=None, name=None, mesh=None,
):
y = Linear(
num_heads * head_size,
with_bias=False,
name=name, sharding=sharding, mesh=mesh,
)(x)
*leading_dims, _ = x.shape
return y.reshape((*leading_dims, num_heads, head_size))
Q/K/V projection 都是 with_bias=False - 注意这点 - 与 MultiHeadAttention.__init__ 的 with_bias=True 默认不冲突,因为这里硬编码 False。
输出 shape:
- Q:
[B, T, 48, 128] - K:
[B, T, 8, 128] - V:
[B, T, 8, 128]
4.5.2 GQA reshape 与 attention 计算¶
# model.py:846-866
query_heads = with_sharding_constraint(query_heads, P(self.data_axis, None, "model", None))
key_heads = with_sharding_constraint(key_heads, P(self.data_axis, None, "model", None))
value_heads = with_sharding_constraint(value_heads, P(self.data_axis, None, "model", None))
b, t, h, d = query_heads.shape
_, _, kv_h, _ = key_heads.shape
assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"
query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))
query_heads = with_sharding_constraint(
query_heads, P(self.data_axis, None, "model", None, None)
)
# Compute attention weights.
# Attention softmax is always carried out in fp32.
attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
jnp.float32
)
attn_logits *= self.attn_output_multiplier
max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)
Q 被 reshape 成 [B, T, 8, 6, 128] - 8 个 KV group,每 group 6 个 Q 头。
einsum(Einstein Summation)
jnp.einsum 用一个字符串描述任意阶张量的乘加:箭头左边写输入张量的下标,右边写输出的下标,凡是左边出现但右边没出现的下标都被 sum 掉,相同字母代表 broadcast 或共缩。比如 "ij,jk->ik" 就是普通矩阵乘,"bij,bjk->bik" 是带 batch 的矩阵乘。
einsum 在 attention 实现里特别好用,因为 Q/K/V 是 4-5 阶张量、混合了 batch、head、seq、dim 多个维度。Grok-1 GQA 的核心一行 "...thHd,...Thd->...hHtT" 里 K 没有 H 维(q_per_kv),einsum 自动让 K 在 H 上 broadcast - 等于免费实现了"每 6 个 Q 头共享 1 组 K/V"的 GQA 语义,不用手动 tile K。
注意 einsum:"...thHd,...Thd->...hHtT"
- 输入 Q:
[..., t, h=kv_h, H=q_per_kv, d=128] - 输入 K:
[..., T, h=kv_h, d=128](无 H 维度,因为 K 在 group 内共享) - 输出:
[..., h=kv_h, H=q_per_kv, t, T]
也就是说 K 在 q_per_kv 上是 broadcast,不需要显式复制 K。这就是 GQA 的精髓 - 节省 6 倍的 K/V 内存与 IO。
PyTorch 里实现等价 GQA 的常见做法是手动把 K、V 沿 head 维 repeat_interleave 复制 q_per_kv 份,然后照常做 MHA:
# transformers/models/llama/modeling_llama.py
def repeat_kv(hidden_states, n_rep):
batch, num_kv_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_kv_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
PyTorch 这种写法会显式构造一个"复制 6 份"的临时张量,虽然 expand 本身是 view(不真的拷贝内存),但后续 reshape 会触发实际拷贝。Grok-1 的 einsum 方案完全跳过这一步,让 broadcast 在 matmul 内部隐式发生,省一次 reshape 加一次大张量内存分配。在 314B 这个规模下,省这一份临时张量对内存压力的缓解相当明显。
接着:
这是第 3 章 3.3.2 说过的软裁剪。
为什么 attention scale 是 \(1/\sqrt{d_h}\)
标准 scaled dot-product attention 的公式是 \(\text{softmax}(QK^T / \sqrt{d_h}) V\),里面的 \(1/\sqrt{d_h}\) 不是随手加的常数,它有具体推导。
假设 Q 和 K 的每个分量都是均值 0、方差 1 的独立随机变量,那么 Q 和 K 做内积得到的 \(QK^T\) 是 \(d_h\) 个独立乘积的求和。由独立同分布求和的方差公式,\(QK^T\) 的方差等于 \(d_h\),标准差等于 \(\sqrt{d_h}\)。也就是说,hidden head dim 越大,attention logit 的"自然量级"也越大。
softmax 对输入量级极其敏感:输入量级 ~1 时分布相对平滑,量级 ~10 时已经接近 one-hot。如果不做缩放,\(d_h = 128\) 时 logit 量级会是 \(\sqrt{128} \approx 11.3\),softmax 出来基本是 one-hot,梯度近乎全部为 0,训练根本走不动。除以 \(\sqrt{d_h}\) 把 logit 量级拉回 1 附近,softmax 输出回到可学习的平滑区间。
Grok-1 的 attn_output_multiplier = 0.08838834764831845 正好等于 \(1/\sqrt{128}\) - 这个数字不是巧合,是把 scaled dot-product 的 \(1/\sqrt{d_h}\) 显式写在配置里。把它做成超参也方便你在外部实验时调整(虽然实际几乎没人改)。
attention logit 软裁剪(soft-cap)
Attention 里 Q 和 K 算完点积、除以 sqrt(d_k) 之后会得到一组 logit,正常分布在 ±十几的量级。但训练时偶尔会出现某个 logit 异常爆掉到几十甚至上百 - 通常是某对 Q/K 的内容恰好在某些维度上严重共线,再叠上 bf16 的低精度,单个值就走偏了。一旦这种 logit 进了 softmax,exp 一下立刻顶到 fp16/bf16 的上限,结果出 NaN,反传梯度爆炸,整轮训练报废。
硬截断(直接 clip 到 ±cap)能挡住 NaN 但不可导,截断点附近梯度为 0。软裁剪改写成 cap * tanh(x / cap):小 logit 几乎不动(tanh 在 0 附近就是恒等映射),大 logit 被慢慢压到 ±cap 这个上界,全程光滑可导。代价是真的需要极大 logit 才能区分某两个 K 时,差别会被 tanh 抹掉一点,但实际训练里这种极端情形相当罕见。Gemma 2 后来同款思路用 cap = 50,做的是同一件事。
4.5.3 mask 与 softmax¶
# model.py:867-876
mask = mask[:, :, None, :, :]
if mask is not None:
if mask.ndim != attn_logits.ndim:
raise ValueError(...)
attn_logits = jnp.where(mask, attn_logits, -1e30)
attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype) # [H, T', T]
mask 在传入时是 [B, 1, T, T],加一个 None 维度变成 [B, 1, 1, T, T],与 logit 的 [B, kv_h, q_per_kv, T, T] 广播。
softmax 是 fp32 计算(因为 logit 是 fp32),然后 cast 回 query dtype(bf16)。这是数值稳定的标配。
softmax 公式与数值稳定性
softmax 把一个向量 \(z = (z_1, \dots, z_n)\) 转成概率分布:
直接按公式算会有数值溢出问题:bf16 的 exp 在输入超过约 88 时就 overflow 成 inf。工程上的标准做法是减去最大值再 exp,结果数学等价但所有 exp 输入都 ≤ 0:
jax.nn.softmax 内部已经做了这个减最大值处理,但只在指定的 dtype 上做。如果 logit 是 bf16,exp 的精度损失仍然存在(bf16 尾数只有 7 位,小的 logit 差异在 exp 后会完全丢失)。所以 Grok-1 先把 logit cast 到 fp32 再做 softmax,全程 fp32 累加,最后才 cast 回 bf16。
PyTorch 里 LLaMA 的 attention 同样有 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 这一行,逻辑完全对应。这种"fp32 算 softmax + bf16 存输出"的混合精度策略,从 GPT-3 时代起就是大模型 attention 的事实标准。
mask 应用:jnp.where vs additive mask
PyTorch 里常见的 mask 应用方式有两种。Additive mask:构造一个 [T, T] 矩阵,允许位置填 0、禁止位置填 -1e9(或 -inf),直接加到 logit 上,softmax 后禁止位置自然变成 0。Multiplicative mask:构造一个 0/1 mask,允许位置填 1、禁止位置填 0,用 masked_fill 把禁止位置的 logit 替换成 -inf。
Grok-1 走的是 jnp.where(mask, attn_logits, -1e30),本质和 PyTorch 的 masked_fill 相同:mask 为 True 的位置保留原 logit,False 的位置替换成 -1e30。注意这里用的不是 -inf 而是 -1e30,原因是 bf16 的最小有限值约为 -3.4e38,-1e30 比 -inf 更"温和",不容易在后续运算里产生 NaN(inf 减 inf 会得到 NaN)。
4.5.4 加权求和与最终投影¶
# model.py:878-891
attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
attn = with_sharding_constraint(attn, P(self.data_axis, None, "model", None, None))
leading_dims = attn.shape[:2]
attn = jnp.reshape(attn, (*leading_dims, -1)) # [T', H*V]
attn = with_sharding_constraint(attn, P(self.data_axis, None, "model"))
final_projection = Linear(
self.model_size,
with_bias=False,
sharding=P("model", "data"),
mesh=mesh,
)
return MHAOutput(final_projection(attn), new_memory)
einsum "...hHtT,...Thd->...thHd" 同样让 V 在 H(q_per_kv)维 broadcast。
reshape 把 [B, T, kv_h, q_per_kv, V] 压成 [B, T, kv_h * q_per_kv * V] = [B, T, 48 * 128] = [B, T, 6144]。
最终 projection (6144, 6144) with no bias,输出 [B, T, 6144]。
整段从 Q/K/V projection 到 final projection 串起来,对照 PyTorch 里 LLaMA 的 LlamaAttention.forward 几乎可以一一对应:
| 步骤 | Grok-1 (JAX) | LLaMA (PyTorch) |
|---|---|---|
| Q/K/V projection | _linear_projection 三次调用,配 with_sharding_constraint |
self.q_proj(x) 等三次调用,靠 tensor_parallel wrapper 处理切分 |
reshape 到 [B, T, H, D] |
y.reshape((*leading_dims, num_heads, head_size)) |
x.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| RoPE 应用 | RotaryEmbedding(...)(query_heads, ...) |
apply_rotary_pos_emb(q, k, cos, sin) |
| KV cache 更新 | update_into_shmap(kv_memory.k, kv_memory.step, key_heads) |
past_key_value.update(key_states, value_states, layer_idx) |
| 注意力计算 | jnp.einsum("...thHd,...Thd->...hHtT", q, k) |
torch.matmul(q, k.transpose(2, 3)) |
| scale & soft-cap | *= attn_output_multiplier,再 30 * tanh(.../30) |
/ math.sqrt(self.head_dim),LLaMA 没有 soft-cap |
| mask + softmax | jnp.where(mask, logits, -1e30) → jax.nn.softmax |
attn_weights += attention_mask → nn.functional.softmax |
| 加权求和 V | jnp.einsum("...hHtT,...Thd->...thHd", w, v) |
torch.matmul(attn_weights, value_states) |
reshape 回 [B, T, d] |
jnp.reshape(attn, (*leading_dims, -1)) |
attn_output.transpose(1, 2).reshape(bsz, q_len, -1) |
| output projection | Linear(self.model_size, ...) |
self.o_proj(attn_output) |
两个实现的核心算子几乎完全一致,差别集中在以下几处:
- GQA 的实现:Grok-1 用 einsum 隐式 broadcast,LLaMA 用
repeat_kv显式复制 - sharding 表达:Grok-1 用
with_sharding_constraint注解,LLaMA 依赖accelerate、tensor_parallel等外部库自动切 - soft-cap:Grok-1 有
30 * tanh(x / 30),LLaMA 没有 - KV cache 写入:Grok-1 用
dynamic_update_slice_in_dim+ vmap,LLaMA 用past_key_value.update()把新 K/Vtorch.cat到老 cache 上
理解了这张对照表,Grok-1 这段 attention 代码对一个熟悉 PyTorch 的 MLE 来说就不再陌生了。
4.6 与 LLaMA、Mistral、GPT-NeoX 的差异速查¶
| 项 | Grok-1 | LLaMA-2 70B | Mistral 7B / Mixtral | GPT-NeoX |
|---|---|---|---|---|
| Q 头数 | 48 | 64 | 32 | 64 |
| KV 头数 | 8 | 8 | 8 | 64 (no GQA) |
| Q/KV 比 | 6 | 8 | 4 | 1 |
| head dim | 128 | 128 | 128 | 96 |
| RoPE 风格 | NeoX | NeoX | NeoX | NeoX |
| RoPE base | 10000 | 10000 | 10000 / 1000000 | 10000 |
| Attention logit 软裁剪 | 有(30·tanh) | 无 | 无 | 无 |
| QKV bias | 否 | 否 | 否 | 是 |
| Output proj bias | 否 | 否 | 否 | 是 |
| Norm | RMSNorm,pre+post | RMSNorm,pre | RMSNorm,pre | LayerNorm,pre |
Grok-1 的差异点:
- GQA 比例 6:1:Grok-1 把 48 Q 头映射到 8 KV 头,每组 6 个 Q 共享一组 K/V,比 LLaMA-2 70B 的 8:1 略密集一些。代价是 KV cache 相对 MHA 节省 6 倍而不是 8 倍,体积稍大;收益是每组共享的 Q 数少了一点,attention 的表达能力相对更接近 MHA
- 30·tanh 软裁剪:Q·K 内积除以 \(\sqrt{d_h}\) 后再经过
30 * tanh(x/30),把 attention logit 平滑约束在 \(\pm 30\) 区间内。这是本书覆盖的所有模型中唯一一处显式的 logit 裁剪,Gemma 2 后来用了同款 soft-cap,但 LLaMA-2、Mistral、Mixtral 都没有 - sandwich norm:每个 sub-layer 前后各做一次 RMSNorm,每层共 4 个 RMSNorm。这一布局只见于 Grok-1 与 Cohere Command R 系列,LLaMA-2、Mistral、Mixtral 都是标准 pre-norm,每层只有 2 个。详细对比留到第 6 章 DecoderLayer 精读时再展开
下一章进入本仓库最难的部分:MoE 路由的 shard_map 实现。
延伸阅读¶
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints - GQA 的原始论文
- RoFormer: Enhanced Transformer with Rotary Position Embedding - RoPE 原始论文
- Root Mean Square Layer Normalization - RMSNorm 原始论文
- Mixtral of Experts - 同代 GQA + MoE 的最直接对照