How Inductor wraps compiled output in a CUDA Graph inside torch.compile

 

1. Inductor backend registration and call chain

1.1 Registration point

torch/_dynamo/backends/inductor.py is only 32 lines long and serves as the entry point that binds the string "inductor":

from torch._dynamo import register_backend

@register_backend   # registers "inductor" in the _COMPILER_FNS dict
def inductor(*args, **kwargs):
    from torch._inductor.compile_fx import compile_fx  # lazy import
    return compile_fx(*args, **kwargs)

1.2 Lookup mechanism

torch/_dynamo/backends/registry.py

_COMPILER_FNS: dict[str, CompilerFn] = {}
_default_backend: str | CompilerFn = "inductor"  # the default backend is inductor

def lookup_backend(compiler_fn: str | CompilerFn) -> CompilerFn:
    if isinstance(compiler_fn, str):
        compiler_fn = _COMPILER_FNS[compiler_fn]  # "inductor" → the function registered above
    return compiler_fn

1.3 Full call chain

torch.compile(model, backend="inductor")
  │
  ▼
torch/_dynamo/eval_frame.py
  get_compiler_fn("inductor") → lookup_backend("inductor") → _COMPILER_FNS["inductor"]
  │
  ▼
torch/_dynamo/backends/inductor.py :: inductor(gm, example_inputs)
  └─ compile_fx(gm, example_inputs)
  │
  ▼
torch/_inductor/compile_fx.py :: compile_fx()
  └─ defines fw_compiler / bw_compiler / inference_compiler
  └─ aot_autograd(fw_compiler, bw_compiler, decompositions, partition_fn)
       │
       ├─ (3a) operator decomposition (decompositions, e.g. F.linear → mm + add)
       ├─ (3b) min-cut partitioning: splits out the forward graph and the backward graph
       └─ (3c) fw_compiler(forward_gm) / bw_compiler(backward_gm)
                │
                ▼
             compile_fx_inner(gm)
                │
                ▼
             fx_codegen_and_compile(gm)
                │
                ├─ GraphLowering(gm).run()      ← FX op → Inductor IR
                ├─ graph.compile_to_fn()        ← Inductor IR → Triton/C++ kernel
                └─ CompiledFxGraph.post_compile()
                     └─ cudagraph_post_compile() ← wraps the result in a CUDA Graph

1.4 Inductor core: GraphLowering

The two key steps on the compile_fx_inner path (compile_fx.py, around line 1486):

graph = GraphLowering(gm, example_inputs=example_inputs, ...)

with V.set_graph_handler(graph):
    graph.run(*example_inputs)     # FX node → Inductor IR (TensorBox, Pointwise, Reduction...)

with dynamo_timed("GraphLowering.compile_to_fn"):
    compiled_fn = graph.compile_to_fn()  # Inductor IR → Triton kernel → compile

GraphLowering.run() walks every node of the FX graph and calls the lowering functions registered in torch/_inductor/lowering.py, turning each torch.ops.aten.xxx call into Inductor IR nodes.
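
To make the "walk the graph, look up a lowering per op" idea concrete, here is a toy model in plain Python. It is only a sketch under stated assumptions: `TOY_LOWERINGS`, `toy_register_lowering`, and `toy_run` are hypothetical names, the graph is a list of tuples rather than a real FX graph, and the IR nodes are plain tuples rather than Inductor IR classes.

```python
from typing import Callable

# Hypothetical lowering table: op name -> function that builds an IR node.
TOY_LOWERINGS: dict[str, Callable] = {}


def toy_register_lowering(op_name):
    """Decorator that records a lowering function for one op name."""
    def decorator(fn):
        TOY_LOWERINGS[op_name] = fn
        return fn
    return decorator


@toy_register_lowering("aten.add")
def lower_add(a, b):
    return ("Pointwise", "add", a, b)


@toy_register_lowering("aten.sum")
def lower_sum(a):
    return ("Reduction", "sum", a)


def toy_run(nodes):
    """Visit (name, op, arg_names) nodes in order and lower each one."""
    env = {}
    for name, op, args in nodes:
        if op == "placeholder":
            env[name] = ("InputBuffer", name)
        else:
            env[name] = TOY_LOWERINGS[op](*(env[a] for a in args))
    return env


graph = [
    ("x", "placeholder", ()),
    ("y", "placeholder", ()),
    ("z", "aten.add", ("x", "y")),
    ("out", "aten.sum", ("z",)),
]
ir = toy_run(graph)
```

After the walk, `ir["out"]` is a nested tuple describing a reduction over a pointwise add of the two inputs, which is the shape of the result the real lowering pass hands to code generation.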


2. How Inductor integrates with CUDA Graphs

2.1 Switches

# config.py, triton sub-config
cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1"  # off by default
cudagraph_trees = True   # tree-structured memory pool sharing, on by default

# How to enable
torch.compile(model, options={"triton.cudagraphs": True})
# or, from the shell:
TORCHINDUCTOR_CUDAGRAPHS=1 python train.py

At its entry point, compile_fx_inner creates a BoxedBool (a mutable container that lets downstream code disable the flag in place):

# compile_fx.py ~900
if graph_kwargs.get("cudagraphs") is None:
    graph_kwargs["cudagraphs"] = BoxedBool(config.triton.cudagraphs)
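
Why a boxed boolean rather than a plain `bool`? A plain `bool` is immutable, so a downstream check could not turn the flag off for everyone else holding it. The sketch below uses `ToyBoxedBool`, a simplified stand-in written for this note (not the class from the PyTorch source), and a hypothetical `downstream_check` helper.

```python
class ToyBoxedBool:
    """Simplified stand-in for a mutable boolean container."""

    def __init__(self, value: bool):
        self.value = value

    def __bool__(self) -> bool:
        return self.value

    @staticmethod
    def disable(obj):
        # Flip the shared box in place when given one; otherwise return False.
        if isinstance(obj, ToyBoxedBool):
            obj.value = False
            return obj
        return False


graph_kwargs = {"cudagraphs": ToyBoxedBool(True)}
flag_seen_by_caller = graph_kwargs["cudagraphs"]  # same object, not a copy
enabled_before = bool(flag_seen_by_caller)


def downstream_check(kwargs, reason):
    """Hypothetical later stage that found a reason to skip CUDA graphs."""
    if reason:
        ToyBoxedBool.disable(kwargs["cudagraphs"])


downstream_check(graph_kwargs, "incompatible op")
enabled_after = bool(flag_seen_by_caller)
```

The caller never re-reads `graph_kwargs`, yet it observes the change, because both names refer to the same box.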

2.2 Compile-time checks (what causes a skip)

These checks run during fx_codegen_and_compile. The result is stored in CompiledFxGraph.disabled_cudagraphs_reason and cached together with the compiled artifact:

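| Check | Trigger |
| --- | --- |
| Dynamic shapes | cudagraph_skip_dynamic_graphs=True and there are SymInt inputs |
| Incompatible op | get_first_incompatible_cudagraph_node(gm) returns a non-empty result |
| CPU / multiple devices | check_lowering_disable_cudagraph(device_node_mapping) |
| Input mutation | a tensor not managed by cudagraphs is mutated |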

2.3 post_compile: where the wrapping actually happens

output_code.py :: CompiledFxGraph.post_compile()

if cudagraphs:
    if self.disabled_cudagraphs_reason:
        BoxedBool.disable(cudagraphs)   # a skip reason exists, give up
    else:
        if config.graph_partition and policy is None:
            cudagraph_partition_post_compile(...)  # wrap each partition separately
        else:
            cudagraph_post_compile(...)            # wrap the whole graph

cudagraph_post_compile ultimately replaces compiled_graph.current_callable:

compiled_graph.current_callable = cudagraphify(
    current_callable,
    static_input_idxs=...,
    device_index=...,
    is_backward=...,
    is_inference=...,
    ...
)

2.4 cudagraphify: where the two paths fork

compile_fx.py ~1883:

def cudagraphify(model, static_input_idxs=(), *, device_index, ...):
    if config.triton.cudagraph_trees:
        cudagraphify_fn = functools.partial(new_cudagraphify_impl, ...)  # tree version
    else:
        cudagraphify_fn = cudagraphify_impl   # simple version

    compiled_fn = None

    def run(new_inputs):   # lazy recording: nothing is recorded until the first real call
        nonlocal compiled_fn
        if compiled_fn is None:
            compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
        return compiled_fn(new_inputs)

    return run
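
The lazy-recording closure can be exercised without a GPU. The sketch below keeps the same `nonlocal` structure but swaps the recording step for a hypothetical `fake_record` that only counts how often it is invoked; `toy_cudagraphify` is likewise an illustrative name, not the real function.

```python
def toy_cudagraphify(model, record_fn):
    """Return a callable that records on first use and reuses the result afterwards."""
    compiled_fn = None

    def run(new_inputs):
        nonlocal compiled_fn
        if compiled_fn is None:
            # Recording is deferred until real inputs are available.
            compiled_fn = record_fn(model, new_inputs)
        return compiled_fn(new_inputs)

    return run


record_calls = []


def fake_record(model, inputs):
    # Stand-in for the expensive capture step: log the call, return the model.
    record_calls.append(list(inputs))
    return model


wrapped = toy_cudagraphify(lambda xs: sum(xs), fake_record)
calls_before = len(record_calls)  # wrapping alone records nothing
first = wrapped([1, 2, 3])
second = wrapped([4, 5])
```

Wrapping is cheap and side-effect free; the cost of recording is paid once, on the first call, with the inputs that call actually received.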

3. Two CUDA Graph implementations

3.1 Simple version: cudagraphify_impl (cudagraph_trees=False)

compile_fx.py, around line 1951; no memory pool sharing:

def cudagraphify_impl(model, inputs, static_input_idxs=()):
    # 1. Allocate static_inputs
    static_inputs = [static_input(x) if idx not in static_input_idxs else x.detach()
                     for idx, x in enumerate(inputs)]

    # 2. Warm up (on a side stream)
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        model(list(static_inputs))
    torch.cuda.synchronize()

    # 3. Record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
        static_outputs = model(list(static_inputs))

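    # 4. Replay closure
    # Only non-static inputs need to be copied into their static buffers on each call
    copy_indices = [idx for idx in range(len(static_inputs)) if idx not in static_input_idxs]

    def run(new_inputs):
        for idx in copy_indices:
            index_expanded_dims_and_copy_(static_inputs[idx], new_inputs[idx], ...)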
        new_inputs.clear()
        graph.replay()
        return static_outputs

    return run
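
The essential trick in the simple version is that the recorded work is bound to fixed buffers, so each call copies fresh data into those buffers and then replays. The pure-Python model below imitates that with lists in place of GPU tensors. Everything in it is an assumption made for illustration: `ToyGraph`, `toy_cudagraphify_impl`, and the hard-coded dot product are not taken from the PyTorch source.

```python
class ToyGraph:
    """Records zero-argument steps once, then replays them verbatim."""

    def __init__(self):
        self.steps = []

    def replay(self):
        for step in self.steps:
            step()


def toy_cudagraphify_impl(inputs, static_input_idxs=()):
    # 1. Private buffers for non-static inputs; static ones are used in place.
    static_inputs = [x if i in static_input_idxs else list(x)
                     for i, x in enumerate(inputs)]
    copy_indices = [i for i in range(len(inputs)) if i not in static_input_idxs]
    static_output = [0.0]

    # 2. "Record": the step closes over the buffers, not over their values.
    graph = ToyGraph()

    def step():
        static_output[0] = sum(a * b for a, b in zip(static_inputs[0], static_inputs[1]))

    graph.steps.append(step)

    # 3. Replay closure: copy new data into the buffers, then replay.
    def run(new_inputs):
        for i in copy_indices:
            static_inputs[i][:] = new_inputs[i]
        graph.replay()
        return static_output

    return run


weights = [1.0, 2.0, 3.0]
run = toy_cudagraphify_impl([weights, [0.0, 0.0, 0.0]], static_input_idxs=(0,))
out1 = run([weights, [1.0, 1.0, 1.0]])
first_value = out1[0]
out2 = run([weights, [2.0, 0.0, 1.0]])
```

Note that `out1` and `out2` are the same buffer: the second replay overwrites the first result. The same aliasing applies to the static outputs returned by the simple version, which is one of the problems the tree version's lifetime tracking addresses.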

3.2 Tree version (cudagraph_trees=True, the default)

cudagraph_trees.py :: cudagraphify_impl caches recorded functions keyed by the int inputs (the dynamic shape values):

def cudagraphify_impl(model, inputs, static_input_idxs, *args, **kwargs):
    fn_cache: dict[tuple[int,...], Callable] = {}

    def deferred_cudagraphify(inputs):
        int_key = get_ints(inputs)
        fn = fn_cache.get(int_key)
        if fn is not None:
            return fn(inputs)
        # First call for this key: record and cache the result
        fn, out = cudagraphify(model, inputs, ...)  # → CUDAGraphTreeManager.add_function()
        fn_cache[int_key] = fn
        return out

    return deferred_cudagraphify
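
The per-shape caching can be tried out in isolation. The following sketch assumes the int inputs appear directly in the input list; `toy_get_ints`, `toy_tree_cudagraphify`, and `fake_record` are hypothetical helpers, and the recording step is replaced by a function that only logs the key it was asked to record.

```python
def toy_get_ints(inputs):
    """Collect the plain-int inputs (stand-ins for dynamic shape values)."""
    return tuple(x for x in inputs if isinstance(x, int))


def toy_tree_cudagraphify(model, record_fn):
    fn_cache = {}

    def deferred(inputs):
        key = toy_get_ints(inputs)
        fn = fn_cache.get(key)
        if fn is None:
            # First time this combination of ints is seen: record and cache.
            fn = record_fn(model, inputs)
            fn_cache[key] = fn
        return fn(inputs)

    return deferred, fn_cache


recorded_keys = []


def fake_record(model, inputs):
    recorded_keys.append(toy_get_ints(inputs))
    return model


def model(inputs):
    n, data = inputs
    return sum(data[:n])


run, cache = toy_tree_cudagraphify(model, fake_record)
a = run([2, [1, 2, 3]])
b = run([2, [5, 5, 5]])   # same int key, reuses the cached function
c = run([3, [1, 2, 3]])   # new int key, triggers a second recording
```

Different tensor contents with the same ints share one recording, while each new combination of ints costs one more.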

4. CUDAGraphTreeManager

One instance per device, stored in a thread-local TreeManagerContainer.

4.1 Initialization: creating the shared memory pool

def __init__(self, device_index):
    with torch.cuda.device(device_index):
        torch.cuda.synchronize()
        self.stream = torch.cuda.Stream()
        self.graph = torch.cuda.CUDAGraph()
        self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle()  # shared memory pool

        # Capture an empty graph to activate the pool and keep it alive
        with torch.cuda.graph(self.graph, pool=self.cuda_graphs_thread_pool, ...):
            pass

4.2 The _run() hot-path state machine

_run(new_inputs, function_id)
  │
  ├─ still recording / warming up? → try_end_curr_recording/warmup
  │
  ├─ mutation of a tensor not managed by cudagraphs detected? → run eagerly
  │
  ├─ not warmed up yet? → run_eager() → CUDAWarmupNode (runs in pool memory)
  │
  ├─ already warmed up: scan child_nodes[function_id]:
  │   ├─ check_invariants() == SUCCESS → execute_node() → graph.replay()
  │   └─ mismatch → counted as a re-record
  │       └─ above cudagraph_unexpected_rerecord_limit (128) → permanent eager fallback
  │
  └─ no matching child → record_function() → create a new CUDAGraphNode
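
A heavily reduced model of this state machine may help when reading the real code. The sketch keeps only four outcomes (warmup, record, replay, eager fallback) and omits the tree structure, mutation detection, and the in-progress recording check entirely. `ToyTreeManager` and its `invariant_key` argument (a stand-in for whatever check_invariants() compares) are hypothetical.

```python
class ToyTreeManager:
    """Toy warmup -> record -> replay state machine with a re-record limit."""

    def __init__(self, rerecord_limit=2):
        self.rerecord_limit = rerecord_limit
        self.warmed_up = set()   # function ids that already ran their eager warmup
        self.children = {}       # function id -> {invariant key: recorded fn}
        self.rerecords = {}      # function id -> unexpected re-record count
        self.log = []            # which path each call took

    def run(self, fn, inputs, function_id, invariant_key):
        if self.rerecords.get(function_id, 0) > self.rerecord_limit:
            self.log.append("eager-fallback")
            return fn(inputs)
        if function_id not in self.warmed_up:
            self.warmed_up.add(function_id)
            self.log.append("warmup")
            return fn(inputs)
        recorded = self.children.setdefault(function_id, {})
        if invariant_key in recorded:
            self.log.append("replay")
            return recorded[invariant_key](inputs)
        if recorded:
            # A recording exists, but its invariants no longer hold.
            self.rerecords[function_id] = self.rerecords.get(function_id, 0) + 1
            if self.rerecords[function_id] > self.rerecord_limit:
                self.log.append("eager-fallback")
                return fn(inputs)
        recorded[invariant_key] = fn
        self.log.append("record")
        return fn(inputs)


manager = ToyTreeManager(rerecord_limit=1)
double = lambda xs: [2 * x for x in xs]
keys = ["a", "a", "a", "b", "c", "d"]
results = [manager.run(double, [1], "f", key) for key in keys]
```

With a limit of 1, the log shows one warmup, a first recording, a replay, one tolerated re-record, and then eager execution from that point on.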

4.3 Node types

| Type | Role |
| --- | --- |
| CUDAWarmupNode | First execution (warmup): runs eagerly in pool memory to establish allocator state |
| CUDAGraphNode | Records a CUDA Graph once, then replays it |

5. CUDAGraphNode: the actual recording

The constructor calls _record() directly, using the shared memory pool:

def _record(self, model, inputs):
    with (
        preserve_rng_state(),
        torch.cuda.device(self.device),
        clear_cublas_manager(),
        torch.cuda.graph(
            self.graph,
            stream=self.stream,
            pool=self.cuda_graphs_pool,           # shared memory pool
            capture_error_mode="thread_local",
        ),
        CUDAGraphCaptureControlFlowOpDispatchMode(),
        get_history_recording(),
    ):
        static_outputs = model(inputs)
    return static_outputs

The run() method (replay path):

def run(self, new_inputs):
    self.check_static_inputs_are_stable(new_inputs)   # check that static pointers are unchanged
    self._copy_inputs_and_remove_from_src(            # copy new inputs into their static slots
        self.reconstructed_inputs, new_inputs
    )
    self.run_graph()                                   # graph.replay()
    return self.reconstruct_outputs()                  # rebuild tensors from metadata

6. Forward/backward coordination

cudagraph_post_compile handles the backward pass specially (output_code.py):

if is_backward and config.triton.cudagraph_trees:
    def compiled_artifact(new_inputs):
        manager.set_to_running_backward()  # tell the manager it is now in backward mode
        return compiled_graph_callable(new_inputs)
    compiled_graph.current_callable = compiled_artifact

set_to_running_backward() sets running_forwards_with_pending_backwards to False, which allows the manager to start a new generation (a new training iteration) once the backward pass has completed.
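
The hand-shake can be modelled with a single flag. This is a toy under stated assumptions, not the real manager: `ToyManager`, `run_forward`, `try_start_new_generation`, and `wrap_backward` are hypothetical, and the real decision to start a new generation depends on more state than this one flag.

```python
class ToyManager:
    """Toy model of the forward/backward hand-shake; not the real manager."""

    def __init__(self):
        self.running_forwards_with_pending_backwards = False
        self.generation = 0

    def run_forward(self):
        # A training forward leaves a backward pending.
        self.running_forwards_with_pending_backwards = True

    def set_to_running_backward(self):
        self.running_forwards_with_pending_backwards = False

    def try_start_new_generation(self):
        # Hypothetical helper: refuse while a backward is still pending.
        if self.running_forwards_with_pending_backwards:
            return False
        self.generation += 1
        return True


def wrap_backward(manager, compiled_backward):
    """Notify the manager before running the compiled backward."""
    def compiled_artifact(new_inputs):
        manager.set_to_running_backward()
        return compiled_backward(new_inputs)
    return compiled_artifact


manager = ToyManager()
backward = wrap_backward(manager, lambda grads: [g * 2 for g in grads])

manager.run_forward()
blocked = manager.try_start_new_generation()   # backward still pending
grads = backward([1.0, 2.0])
allowed = manager.try_start_new_generation()   # backward has run
```

Without the wrapper, the flag would stay set after the forward and the toy manager would never advance to the next generation.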


7. Key configuration

# Master switch (off by default)
torch.compile(model, options={"triton.cudagraphs": True})

# Tree-structured pool sharing (on by default)
torch.compile(model, options={"triton.cudagraph_trees": True})

# Skip CUDA graphs under dynamic shapes (default: do not skip; one graph is recorded per shape)
torch.compile(model, options={"triton.cudagraph_skip_dynamic_graphs": True})

# Maximum number of re-records allowed (default 128)
torch.compile(model, options={"triton.cudagraph_unexpected_rerecord_limit": 128})

# Require cudagraphs: raise an error instead of skipping when incompatible
torch.compile(model, options={"triton.cudagraph_or_error": True})
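
The options above are shown one per call, but they can be passed together in a single dict. The sketch below only assembles that dict; the helper `compile_with_cudagraphs` is a hypothetical wrapper, it is defined but not called here, and whether every key is accepted depends on the installed PyTorch version.

```python
# Option names are the ones listed above; the values are illustrative choices.
CUDAGRAPH_OPTIONS = {
    "triton.cudagraphs": True,                          # master switch
    "triton.cudagraph_trees": True,                     # tree-structured pool sharing
    "triton.cudagraph_skip_dynamic_graphs": True,       # do not record under dynamic shapes
    "triton.cudagraph_unexpected_rerecord_limit": 128,  # re-record budget
}


def compile_with_cudagraphs(model, strict=False):
    """Hypothetical helper: compile model with the options above.

    With strict=True, incompatibility raises an error instead of skipping.
    """
    import torch  # imported lazily so this module loads without PyTorch

    options = dict(CUDAGRAPH_OPTIONS)
    if strict:
        options["triton.cudagraph_or_error"] = True
    return torch.compile(model, options=options)
```

Keeping the options in one dict makes it easier to toggle the strict mode during debugging without touching the call sites.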

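8. Summary of responsibilities by layer

| Layer | File / component | Responsibility |
| --- | --- | --- |
| Registration | torch/_dynamo/backends/inductor.py | Binds the string "inductor" to compile_fx |
| Lookup | torch/_dynamo/backends/registry.py | lookup_backend("inductor") |
| Hand-off | torch/_dynamo/eval_frame.py | After Dynamo tracing, hands the FX graph to the backend |
| AOT layer | torch/_inductor/compile_fx.py :: compile_fx | Calls AOT Autograd, splits fw/bw |
| Inductor compilation | compile_fx_inner + GraphLowering | FX op → Inductor IR → Triton kernel |
| CUDA Graph switch | compile_fx_inner | Creates BoxedBool(config.triton.cudagraphs) |
| CUDA Graph wrapping | CompiledFxGraph.post_compile | Replaces current_callable |
| Dispatch | compile_fx.cudagraphify | Lazy recording; dispatches to the tree version or the simple version |
| Memory pool management | CUDAGraphTreeManager | Shares the pool; maintains the warmup/record/replay state machine |
| Single recording | CUDAGraphNode | Holds the torch.cuda.CUDAGraph; manages tensor lifetimes |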