跳转至

TileLang & Mega-kernel

导言

  • triton 虽然主流,大部分硬件都支持,虽然能快速拿到一部分收益,但是却较难极致性能。
  • 但是Ascend C / PyPTO 又过于Ascend定制化,(学了怎么跳槽啊)
  • 寻找一种更底层,支持极致性能的通用算子编程语言,是极致性能优化里不可或缺的一环。

TileLang、Mega-kernel 与 Triton 的核心差异

这里的 Triton 指 triton-lang 这个 GPU kernel DSL,不是 NVIDIA Triton Inference Server。三者最好不要放在同一抽象层比较:

  • Triton:面向单个 GPU kernel 的 Python-like DSL,主抽象是 block-level program。
  • TileLang:同样面向自定义 kernel,但主抽象是 tile-level dataflow,并更显式暴露共享内存、寄存器 fragment、layout、pipeline。
  • Mega-kernel:不是一种语言,而是一种执行范式。目标是把多个算子,甚至整个 LLM 推理迭代,融合进一个长期运行的 persistent kernel 中。

一句话概括:

Triton 让你更容易写高性能单 kernel;TileLang 让你更可控地写 tile 级高性能 kernel;Mega-kernel 则试图消除多个 kernel 之间的边界,把跨算子调度搬到 GPU 内部。

Triton 基线

Triton 的核心思想是 blocked program, scalar threads。开发者写的是一个 program instance 处理一个 block/tile,典型代码里会使用:

pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_SIZE)
x = tl.load(ptr + offs, mask=...)
y = ...
tl.store(out + offs, y, mask=...)

对于 GEMM、attention 等场景,则通过 tl.dottl.loadtl.storenum_warpsnum_stages、autotune 等手段表达块级计算。Triton 编译器负责把 block 内部的并行性映射到线程、warp、向量化、shared memory、tensor core 等底层机制。

Triton 的优势是:

  • 生态成熟,PyTorch/TorchInductor 集成好。
  • 写 fused elementwise、reduction、custom attention、small GEMM 很方便。
  • 编译器自动处理较多底层细节,如 coalescing、vectorization、shared memory allocation、tensor core instruction selection。
  • 当前官方仓库已标注支持 Linux、NVIDIA GPU Compute Capability 8.0+、AMD GPU ROCm 6.2+,CPU backend 仍在开发中。

但 Triton 的限制也来自其抽象:

  • 大多数时候不直接暴露 thread-level mapping。
  • shared memory layout、bank conflict、fragment layout、custom pipeline 控制相对间接。
  • 对极端定制场景,例如量化 GEMM、特殊 dequant layout、Hopper TMA/WGMMA 复杂流水,有时需要等待编译器支持、写 inline asm,或绕过抽象。

TileLang

TileLang 是一个 tile-level DSL,底层基于 TVM/TIR。它的目标不是替代 PyTorch 图编译,而是让用户更容易写出接近手写 CUDA/CUTLASS 水平的 AI kernel。

TileLang 的典型写法如下:

with T.Kernel(grid_x, grid_y, threads=128) as (bx, by):
    A_shared = T.alloc_shared((BM, BK), dtype)
    B_shared = T.alloc_shared((BK, BN), dtype)
    C_local  = T.alloc_fragment((BM, BN), accum_dtype)

    T.clear(C_local)

    for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
        T.copy(A[by * BM, ko * BK], A_shared)
        T.copy(B[ko * BK, bx * BN], B_shared)
        T.gemm(A_shared, B_shared, C_local)

    T.copy(C_local, C[by * BM, bx * BN])

它的关键抽象有:

  • T.alloc_shared:显式申请 shared memory。
  • T.alloc_fragment:显式表达寄存器/fragment 层级的累加器。
  • T.copy:表达 global/shared/fragment 之间的数据搬运。
  • T.gemmT.reduceT.atomic:tile-level operator。
  • T.Pipelined:表达软件流水。
  • T.annotate_layoutT.use_swizzle:控制或引导 layout。
  • Layout Inference:自动推导 thread binding、vectorized access、shared memory swizzle 等。

TileLang 和 Triton 的本质区别在于:

维度 Triton TileLang
主抽象 block-level program tile-level dataflow
编程入口 @triton.jit @tilelang.jit / @T.prim_func
内存层级 相对隐式,编译器推导较多 显式声明 shared、fragment、global
数据搬运 tl.load / tl.store T.copy / T.async_copy
矩阵计算 tl.dot T.gemm
layout 控制 主要靠编译器与 hints layout annotation + layout inference
pipeline num_stages 等参数 T.Pipelined,并可进一步显式控制
适合场景 常规 fused op、attention、reduction、GEMM 极致 GEMM、dequant GEMM、FlashAttention、MLA、复杂 layout kernel
学习成本 较低 略高,但比手写 CUDA/CUTLASS 低

TileLang 论文中强调,它把 dataflowscheduling space 分离:用户主要描述 tile 之间如何移动和计算,编译器再推导 thread binding、layout、tensorization、pipeline 等细节。相比 Triton,它更接近 “Python 语法下的 CUTLASS/CuTe/TVM 风格 tile 编程”。

需要注意的是,TileLang 生态仍比 Triton 年轻。Triton 胜在成熟度、PyTorch 集成、社区代码量;TileLang 胜在对 memory hierarchy、layout、pipeline 的显式表达能力。

Mega-kernel

Mega-kernel,也常被称为 persistent kernel,在 MPK/Mirage 这类系统里指:

用一个长期运行的 GPU kernel 执行多个算子,甚至整个模型推理迭代,把原本 CPU 端发起的 kernel launch、算子间 barrier、部分调度逻辑移动到 GPU 内部。

传统 LLM 推理通常是 kernel-per-operator:

RMSNorm kernel
-> QKV GEMM kernel
-> RoPE kernel
-> Attention kernel
-> O GEMM kernel
-> AllReduce kernel
-> MLP kernel
...

每个 kernel launch 之间天然存在边界。即使用 CUDA Graph 降低 launch overhead,kernel 边界仍然存在,跨算子的 fine-grained overlap 仍受限。

Mega-kernel 的思路是:

launch mega_kernel once

inside GPU:
    scheduler warps poll events
    worker SMs execute tasks:
        matmul tile
        attention tile
        allreduce chunk
        rmsnorm tile
        mlp tile
    tasks signal events
    dependent tasks become ready

以 Mirage Persistent Kernel 为例,它的设计包括:

  • SM-level task graph:把模型图拆成以 SM 为粒度的 task,而不是以 operator 为粒度。
  • event dependency:task 完成后触发 event,依赖满足后下游 task 可以执行。
  • in-kernel runtime:在 mega-kernel 内部划分 worker SM 和 scheduler warp。
  • JIT/AOT hybrid task launch:对动态耗时任务用 JIT 调度,对稳定任务提前放入队列。
  • paged shared memory:把 shared memory 分页,使跨 task pipeline 成为可能。
  • compute-communication overlap:在多 GPU 场景下,把 AllReduce/AllGather 这类通信拆成 GPU 内部 task,与计算 task 细粒度重叠。
  • PyTorch backend:论文中描述可通过 torch.compile(backend=MPK) 形式接入。

它和 Triton 的关系是正交的:

维度 Triton Mega-kernel / MPK
抽象层 单 kernel DSL 模型图/运行时执行范式
优化单位 一个 kernel 或少量 fused op 多个算子,甚至整个推理迭代
调度位置 CPU launch kernel,GPU 执行 kernel GPU 内部 scheduler/worker 执行 task
kernel 边界 通常一个 @triton.jit 对应一个 kernel 尽量消除多个 kernel 边界
跨算子 pipeline 受 kernel barrier 限制 设计目标就是跨 task pipeline
通信重叠 通常依赖外部 runtime/NCCL/NVSHMEM 调度 可把通信作为 task 融进同一个 mega-kernel
典型收益 单算子/单 kernel 性能 端到端 latency、CPU overhead、kernel barrier、通信重叠
主要代价 需要写 kernel 编译/运行时复杂度、调试难度、硬件绑定更强

要特别强调:Triton 可以手写某些 persistent kernel,但 Triton 本身不是端到端 mega-kernel 编译器。MPK 论文明确指出,PyTorch、Triton、TVM 等系统通常不支持自动生成端到端 mega-kernel。Mega-kernel 的关键不是 “一个 kernel 写得很长”,而是有一套图级依赖分析、SM 级任务分解、GPU 内调度 runtime 和共享内存管理机制。

差异表

维度 Triton TileLang Mega-kernel
类型 Kernel DSL/compiler Tile-level kernel DSL/compiler 执行范式/编译运行时
优化粒度 单 kernel 单 kernel,但 tile/memory 更显式 多算子/整模型迭代
主要目标 降低 CUDA 编程门槛 在高生产力下暴露更多底层控制 消除 kernel 边界与 CPU 调度开销
用户写什么 block program、pointer arithmetic、tl.dot tile dataflow、memory scope、pipeline、layout 模型图、task 配置、runtime 参数
内存控制 较隐式 显式 shared/fragment/global GPU 内部任务共享与分页 shared memory
线程控制 多由编译器推导 Layout inference + 可手动干预 SM 级 worker/scheduler
跨算子融合 有限,通常局部 fusion 主要仍是 kernel 内 fusion 核心能力,追求端到端 fusion
多 GPU 通信 通常交给外部 runtime 可写相关 kernel,但不是系统级调度 可把通信 task 融进 persistent runtime
成熟度 最高 快速发展中 研究/早期系统为主
最适合 快速写自定义 op 极致单算子、复杂 layout kernel 低延迟 LLM serving、small batch、多 GPU overlap

选型建议

如果目标是 快速替换 PyTorch 中的某个热点算子,优先选 Triton。它开发快,社区样例多,和 PyTorch 工具链结合最好。

如果目标是 把 GEMM、FlashAttention、MLA、dequant GEMM 写到接近手写 CUDA/CUTLASS 的水平,并且需要控制 shared memory layout、fragment、pipeline、bank conflict、TMA/WGMMA 等细节,可以考虑 TileLang。

如果目标是 端到端 LLM decode latency,尤其是 batch size 小、kernel launch 多、CPU 调度开销明显、或者多 GPU 通信需要和计算细粒度重叠,那么 mega-kernel/MPK 这类方案更有吸引力。

但 mega-kernel 不一定适合所有场景:

  • 大 batch、单个 GEMM 已经占主导时,收益可能有限。
  • 动态 shape/control flow 复杂时,图特化与调度复杂度会上升。
  • 调试、profiling、维护比 Triton/TileLang 单 kernel 更难。
  • 当前公开 MPK 评测主要集中在 NVIDIA A100/H100/B200 这类 GPU 上,跨硬件通用性仍需验证。

结论

可以用三句话总结:

  • Triton:面向单 kernel 的高生产力 GPU DSL,适合大多数自定义算子开发。
  • TileLang:更 tile-centric、更显式控制 memory/layout/pipeline,适合追求极致性能的复杂 AI kernel。
  • Mega-kernel:不是 Triton 的替代语言,而是图级执行范式,试图把多个算子放进一个 persistent kernel 里,优化端到端延迟和跨算子重叠。

实践中它们并不完全互斥。一个合理的系统可能是:用 Triton 或 TileLang 生成高性能 task/kernel,再用 mega-kernel runtime 进行端到端编排。区别在于,Triton/TileLang 主要解决 “单个 kernel 怎么写快”,mega-kernel 主要解决 “一整个模型执行过程怎么少停顿、少 launch、少 barrier”。

参考资料

  • TileLang 论文:TileLang: A Composable Tiled Programming Model for AI Systems 1
  • TileLang 文档:Language Basics 2
  • TileLang GitHub:tile-ai/tilelang 3
  • Triton 文档:Programming Guide Introduction 4
  • Triton GitHub:triton-lang/triton 5
  • Mirage Persistent Kernel 论文:Mirage Persistent Kernel: A Compiler and Runtime for Mega-Kernelizing Tensor Programs 6
  • Mirage GitHub:mirage-project/mirage 7