LLM 推理(Inference)笔记梳理:为什么慢、慢在哪里、怎么加速

 

TL;DR

  • 自回归生成(decode)时每步只产出 1 个 token,算术强度很低,通常被 显存带宽卡住(memory-bound)。
  • MLP 在 decode 阶段可以靠增大 batch 摊薄“读权重”的成本;但 Attention 在 decode 阶段算术强度几乎锁死在 ~1,batch 只能提吞吐但很难提单 token 利用率。
  • 推理显存主要是两块:模型权重(静态) + KV Cache(随 batch/长度线性增长);KV Cache 既影响 OOM,也影响 latency/throughput 的权衡。
  • 工程上常见的四条加速路径:KV Cache 变小(GQA/MLA/Local/…)batch 做大但不浪费(continuous batching)KV Cache 管得更省(paged attention)decode 步数变少(speculative decoding)

1. 先把推理过程拆开:Prefill vs Decode

同一个模型,在推理时其实有两个阶段:

  1. Prefill(预填充):把 prompt 一次性喂进去,算出每层的隐藏状态,同时把 K/V 写入 KV Cache。这个阶段对 prompt 的 token 可并行,形态更像训练,通常更接近 compute-bound
  2. Decode(生成/解码):自回归逐 token 生成。每一步只处理 “当前 1 个 token”,需要读取(并更新)已有的 KV Cache。这个阶段通常 memory-bound,决定了 tokens/s。

prefill vs decode

2. 符号约定

  • B:batch size(并发请求数)
  • S:上下文长度(prompt + 已生成),在 decode 时随时间增长
  • T:当前这次计算处理的 token 数;prefill 时可取一段(常见 T≈S),decode 时 T=1
  • D:模型隐藏维度
  • L:层数
  • N:Query 头数
  • K:Key/Value 头数(GQA/MQA 下可能 K < N
  • H:单头维度,通常 D = N·H
  • F:MLP 中间维度(常用近似 F = 4D
  • V:词表大小

(结构示意) transformer recap

3. 一条“推理性能主线”:共享权重 vs 私有 KV

推理慢的根源可以用一句话概括:

Decode 阶段每步都要“读很多、算很少”。

而“读很多”主要来自两类数据:

  • 共享权重(Shared Weights):MLP/Attention 的权重矩阵,对所有请求/所有 token 都是同一份。
  • 私有状态(Private States):每条序列自己的 KV Cache,大小随 S、B、L 线性长。

两类层在 decode 时的性质很不一样:

  • MLP 层:主要是矩阵乘(读权重 + 少量激活),权重共享 → batch 越大越划算
  • Attention 层:除了读权重,还要读“每条序列私有”的 KV Cache → batch 变大时 KV 读也同比变大,算术强度几乎不涨

4. 为什么 decode 慢:算术强度(Arithmetic Intensity)+ Roofline

核心概念: \(AI=\frac{\text{FLOPs}}{\text{Bytes moved}}\)

硬件视角(Roofline):当 AI 低到一定程度时,性能上限由 内存带宽决定,而不是由算力(TFLOPS)决定。

以 A100 为例(数量级直觉):

  • FP16 算力 ~312 TFLOPS
  • HBM 带宽 ~2 TB/s
  • “跑满算力”需要的 AI312/2≈156(ops/byte 量级)

decode 时单请求(B=1)的 AI 往往接近 1,因此会出现“算力空转、等数据”的现象。

4.1 从一个 GEMM 看出 batch 的意义(MLP 友好)

考虑矩阵乘 Y = XW,其中 X∈R^{B×D}W∈R^{D×F}

  • FLOPs:2·B·D·F
  • Bytes(FP16,2 bytes):约 2·B·D + 2·D·F + 2·B·F

在推理常见的近似下(F≈4DB ≪ D),2·D·F(读权重)是主项,于是: \(AI \approx B\)

这解释了:MLP 想要更高利用率,本质上需要更大的 batch 来“摊薄读权重”。

4.2 MLP:Prefill 高 AI,Decode 低 AI

以 Llama 常见的 SwiGLU MLP 为例(Up/Gate/Down 三个矩阵),如果忽略算子融合带来的缓存复用,粗略估计得到:

  • FLOPs:≈ 6·B·T·D·F
  • Bytes:包含读写激活 + 读权重,其中推理常见近似是 读权重占主导

B·T ≪ D,F(模型大、当前批量小)时: \(AI_{\text{MLP}} \approx B\cdot T\)

于是:

  • PrefillT≈SAI≈B·S → 往往 compute-bound(“像训练一样快”)
  • DecodeT=1AI≈B → 往往 memory-bound(除非 batch 大到离谱)

4.3 Attention:在 decode 阶段几乎“吃不进 batch”

在 FlashAttention(避免把注意力矩阵写回 HBM)的理想化分析下:

  • FLOPs:≈ 4·B·S·T·D
  • Bytes:≈ 4·B·S·D + 4·B·T·D(读 K/V + 读/写 Q/Y)

得到算术强度: \(AI_{\text{Attn}}=\frac{S\cdot T}{S+T}\)

关键点:公式里没有 B,也没有 D。

两种典型场景:

  • PrefillT=SAI≈S/2(prompt 越长越“算得动”)
  • DecodeT=1AI≈S/(S+1)≈1(无论上下文多长都接近 1)

这就是常说的:decode 的 attention 是“天生 memory-bound”,batch 只能涨吞吐,很难涨单步利用率。

5. 推理显存 / 延迟 / 吞吐量:一个够用的心智模型

把推理资源拆成两块:参数 + KV Cache

5.1 参数量(静态)

下面是一套“够用的估算”(对应 Llama/SwiGLU/GQA 的常见记法):

num_params = 2*V*D + D*F*3*L + (2*D*N*H + 2*D*K*H)*L
  • 2*V*D:embedding + unembedding
  • D*F*3*L:SwiGLU MLP(Gate/Up/Down)
  • (2*D*N*H + 2*D*K*H)*L:Attention(Wq/Wo + Wk/Wv),其中 K < N 体现 GQA/MQA 的节省

5.2 显存占用(动态 + 静态)

以 BF16/FP16(2 bytes)为例:

parameter_bytes = num_params * 2
kv_cache_bytes_per_seq = S * (K*H) * L * 2 * 2  # key+value, bf16 bytes
total_memory = parameter_bytes + B * kv_cache_bytes_per_seq

直觉:

  • parameter_bytes 固定不变
  • KV Cache 随 BS 线性增长,是 OOM 的主要来源之一

5.3 延迟与吞吐(memory-bound 近似)

在 decode 阶段,很多系统的第一近似是:

latency_per_token = bytes_moved / memory_bandwidth
throughput_tokens_per_sec = B / latency_per_token

因此会出现经典 trade-off:

  • 小 B:单用户延迟低,但吞吐差、成本高
  • 大 B:吞吐高,但 KV Cache 更大、排队/调度带来的延迟更高,也更容易 OOM

常见工程策略:

  • 复制模型副本:吞吐近似 ×M,单 token 延迟不变(但显存也 ×M)
  • 模型分片(Tensor Parallel / Sharding):把权重拆到多卡,等价于提升总带宽 → 有机会降低单 token 延迟
  • 分阶段目标:TTFT(首 token)更多受 prefill 影响;tokens/s 更多受 decode 影响,两者的最优 batch 往往不同

6. KV Cache 变小:从主流到前沿的“瘦身路线”

目标很明确:让 B·S·L·(K·H) 这坨东西更小,从而:

  • 更不容易 OOM(能塞更大 batch)
  • decode 时读写更少(更低 latency / 更高吞吐)

6.1 GQA / MQA:用“共享 KV 头”换显存

GQA

  • MHAN 个 Q 头对 N 个 KV 头(KV 最大)
  • MQAN 个 Q 头共享 1 个 KV 头(KV 最小,但精度风险更大)
  • GQA:介于两者之间,N 个 Q,K 个 KV(1 < K < N),工业界主流折中(Llama 3 / Mistral 等)

latency throughput

要点:GQA 本身不一定让单步 attention 变 compute-bound,但它能显著降低 KV Cache,使得 允许更大的 batch / 更少的 OOM,从而系统层面吞吐上去。

参考:https://arxiv.org/pdf/2305.13245

6.2 MLA(Multi-Head Latent Attention):把 KV 投影到低秩 latent

MLA

思路:不存完整 K/V,而是存一个低维 latent(比如从 N·H 压到 C),需要用时再“解码”回来(或等价地在 latent 空间完成计算)。

工程难点之一:RoPE 对位置敏感,通常需要保留一部分不压缩的维度承载位置信息(例如“压缩 512 + RoPE 64”)。

参考:https://arxiv.org/abs/2405.04434

6.3 CLA(Cross-Layer Attention):跨层共享 KV

CLA

思路:GQA 是“同层头共享”,CLA 更激进:让不同层共享一份 KV(或某种共享结构),打破 KV Cache 随层数 L 线性增长的规律。

参考:https://arxiv.org/abs/2405.12981

6.4 Local / Sliding Window Attention:用“只看最近 W”换常数级 KV

local local2

思路:每个 token 只 attend 最近 W 个 token,于是 KV Cache 不再随 S 增长,而是被窗口 W 截断:

  • KV 规模从 O(S) 变为 O(W)(对超长上下文尤其重要)
  • 代价是长距离依赖能力下降

常见混合策略:大多数层用 local,少数层用 global(例如“每 6 层里 5 层 local,1 层 global”),用少量全局层保住“智商”。

参考(Longformer):https://arxiv.org/abs/2004.05150

7. Batch 做大但不浪费:Continuous Batching(Orca)

现实世界的请求长度不一,静态 batching 要么等人(延迟高),要么 padding(浪费算力/显存),还会被“最慢请求”拖住。

Orca 的两点核心思路:

  1. Iteration-level scheduling:把调度粒度从“整个请求”变成“每生成 1 个 token”。每一步结束就:
    • 移除已完成请求
    • 插入新到请求
  2. Selective batching:把不同算子按形态拆开处理
    • MLP/Linear:把所有请求当前 token 的隐藏向量直接拼成一个大矩阵(flatten),避免 padding
    • Attention:仍需按序列区分(或使用变长 attention kernel),避免把上下文搅在一起

orca

8. KV Cache 管得更省:vLLM / PagedAttention

在 vLLM 之前,很多系统按 max_length 预分配一整段连续 KV Cache:用户早停会导致 内部碎片;显存不连续会导致 外部碎片

frag

PagedAttention 的核心是把 KV Cache 切成固定大小 block(像 OS 分页):

  • 按需增长:生成到哪儿就分配到哪儿
  • 物理不连续:用“页表/索引”把逻辑连续映射到物理离散

paged paged2

额外红利:更容易做 共享前缀(copy-on-write),例如:

  • 多请求共享同一段 system prompt
  • best-of-N 并行采样共享同一段 prefix

share share2

参考:https://github.com/vllm-project/vllm(以及相关论文/实现)

9. Decode 步数变少:Speculative Decoding(投机采样)

投机解码利用一个不对称性:

“验证一段 token(并行)”通常比“生成这段 token(串行)”更快。

流程(草图):

  1. 用小模型(draft,分布 p)快速生成接下来 K 个 token
  2. 用大模型(target,分布 q)一次前向并行计算这 K 步的分布
  3. 逐步做 accept/reject:接受的 token 直接“白嫖”,一旦拒绝则按残差分布修正

spec

关键性质:在合适的 accept/reject 规则下,最终输出分布严格等于 q(无损)。直觉上可以理解为“多退少补”:

  • draft 过采样的 token 会以概率 q/p 被拒掉一部分(削峰)
  • draft 欠采样的 token 会从残差 max(q-p,0) 中补回来(填谷)

工程变体(方向感):

  • Medusa:不再引入独立小模型,而是在大模型上加多个预测头一次预测多步
  • EAGLE:让 draft 读取大模型中间特征,提高命中率

spec2

10. 其他常见推理优化(本文不展开)

这份原始笔记里还留了几个坑,推理系统里也确实常见:

  • 量化(Quantization):权重量化(W4A16 / W8A16 等)、KV cache 量化、激活量化;本质是“少搬点 bytes”。
  • 剪枝(Pruning):结构化剪枝减少参数/计算,但需要配合训练或校准来稳住精度。

11. 走得更远:替代 Transformer 的方向

如果把目标定为“彻底摆脱 decode 的内存墙/串行墙”,那就会出现“替代架构”。

11.1 状态空间模型(SSM)/ Mamba 系列

直觉:把历史压成固定大小的状态,推理内存从 O(S) 变为 O(1),长文本生成更友好。

  • S4 → 长序列合成任务强,但语言的“回忆/检索”能力弱
  • Mamba → 引入输入依赖的选择机制,1B 规模能对齐 Transformer 质量且推理很快
  • Jamba → Transformer + Mamba 混合(再叠 MoE),体现“混合可能是赢家”

ssm

11.2 Diffusion(文本扩散/并行生成)

思路:非自回归,不是一词一词蹦,而是多轮迭代并行“去噪”整段文本;在某些结构化任务(如代码)可能非常快。

diffusion

参考与延伸阅读