CS336 Lecture Notes 3

3184 字
16 分钟
CS336 Lecture Notes 3
Warning

本文为本人学习相关开源课程过程中,整理的个人学习笔记及作业解答,核心目的仅用于记录个人学习轨迹、巩固所学知识、梳理学习思路,全程为个人自主学习使用,不具备任何商业用途,也不构成任何形式的课程辅导或标准答案参考。

需特别说明的是,由于本人学习进度及知识储备有限,笔记内容及作业解答中可能存在大量纰漏、思路偏差甚至错误,仅代表本人当时的学习理解,不具备权威性和准确性。

在此郑重提醒:请勿将本文中的任何作业解答复制粘贴,作为自身所修课程的提交答案。任何因抄袭本文内容导致的课程成绩问题、学术诚信问题,均由抄袭者自行承担全部责任,本人不承担任何相关连带责任。

同时,本文所分享的内容均基于开源课程的公开内容整理,尊重原课程创作者的知识产权,若涉及相关内容的版权问题,请及时联系本人,本人将第一时间进行调整或删除。

感谢各位读者的理解与支持,也欢迎大家针对笔记及解答中的问题提出宝贵建议,共同交流学习、共同进步。

Important

GPUs, TPUs#

GPU 与 CPU 的区别#

  • CPU:优化少量快速线程,强调延迟优化(每个线程快速完成)
  • GPU:优化大量并行线程,强调吞吐量优化(总处理数据量大)
  • GPU 拥有许多小型计算单元(ALU),对分支(控制、缓存)的支持较少

GPU 架构#

执行单元#

  • SM (Streaming Multiprocessor):GPU 拥有多个 SM,每个 SM 独立执行”块”(blocks/jobs)
  • SP (Streaming Processor):每个 SM 包含多个 SP,可并行执行”线程”(threads)

内存层次#

距离 SM 越近的内存越快:

  • L1 / 共享内存:位于 SM 内部,最快
  • L2 缓存:位于芯片上
  • 全局内存:GPU 旁边的内存芯片

SRAM(共享/缓存内存)比 DRAM(全局内存)贵约 100 倍,但快约 8 倍

GPU 执行模型#

三个重要概念:

  • Thread(线程):并行执行工作单元,所有线程执行相同指令但输入不同(SIMT 模型)
  • Block(块):线程组,每个块在一个 SM 上运行,拥有自己的共享内存
  • Warp(线程束):32 个连续编号的线程总是作为一个 warp 执行

GPU 内存模型#

  • 每个线程可以访问自己的寄存器和块内的共享内存
  • 跨块的信息需要通过全局内存读写(较慢)

TPU 简介#

TPU 与 GPU 在高层类似:

  • 核心结构:轻量控制、快速(大型)矩阵乘法单元、快速内存
  • 差异:加速器的网络方式不同,没有 warp(只有 blocks)
特性GPUTPU
计算单元更多 SM更少 TC(但矩阵乘法性能相近)
执行单位Warp (32 线程)Block

GPU 性能优化#

Roofline 模型#

关键问题:如何避免成为内存瓶颈?

优化技巧 1:低精度计算#

更少的位数意味着更少的数据移动。

示例:元素级 ReLU(向量大小 n)

  • Float32:内存访问 8 bytes/FLOP
  • Float16:内存访问 4 bytes/FLOP

Tensor Coores(V、T 系列引入)是专门的矩阵乘法电路,矩阵乘法比其他浮点运算快 10 倍以上。

前沿低精度技术

  • FP8:不同权衡
  • MXFP8(Blackwell):每个缩放因子对应 32 个元素
  • MXFP4:更极致的量化

优化技巧 2:算子融合#

将多个操作合并到单个 CUDA kernel 中,减少内存访问。

示例sin2x+cos2x\sin^2 x + \cos^2 x 的计算

  • 朴素方法:启动 5 个 CUDA kernel(反复读写内存)
  • 融合方法:合并为单个 kernel 调用

简单融合可由编译器自动完成(如 torch.compile)。

优化技巧 3:重计算#

在反向传播中,存储激活值可能很昂贵。有时丢弃激活值并重计算反而是最优的(减少内存访问)。

示例:3 个 sigmoid 堆叠

  • 朴素方法:8 次内存读写
  • 重计算方法:5/8 的内存访问量

优化技巧 4:内存合并#

DRAM 以”突发模式”读取——每次读取给出多个字节!

内存访问被合并的条件:warp 中的所有线程访问都在同一突发范围内。

矩阵乘法中的合并:对于行优先矩阵,沿行移动的线程不会被合并。

优化技巧 5:分块(Tiling)#

分块是分组和排序线程以最小化全局内存访问的技术。

矩阵乘法分块

  1. 将矩阵切成小块(tiles),加载到共享内存
  2. 分阶段计算矩阵乘法
  3. 重复读取现在访问共享内存而非全局内存

分块数学

  • 非分块:每个输入从全局内存读取 N 次
  • 分块:每个输入从全局内存读取 N/T 次,在 tile 内读取 T 次
  • 结果:全局内存访问减少 T 倍

分块复杂问题

  • Tile 大小可能不整除矩阵维度,导致利用率低
  • 合并访问可能无法实现(需要填充)
  • 需要考虑:共享内存大小、矩阵维度可除性

矩阵尺寸与性能的神秘关系#

为什么更大的矩阵更快?

  1. 分块影响:分块通过对齐产生重大影响
  2. 波量化(Wave Quantization):周期性行为
    • 示例:使用 256×128 的 tile,1792 大小产生 98 个 tiles
    • 增加到 1793 产生 120 个 tiles
    • A100 有 108 个 SM,无法同时执行所有 120 个 tiles

Flash Attention 解析#

Flash Attention [Dao et al] 大幅加速注意力计算,核心是分块 + 在线 softmax。

注意力计算回顾#

注意力计算:3 个矩阵乘法(K、Q、V)中间有 softmax。

分块策略#

  1. KQV 矩阵乘法的分块:标准的矩阵乘法分块
  2. 增量 softmax 计算:通过在线更新 max 值和望远镜求和,逐 tile 计算 softmax

在线 Softmax 公式(来自 Mikailov & Gimelshein 2018):

  • 跟踪最大值
  • 增量更新最大值,设置望远镜求和

Flash Attention 前向传播总结#

  • Tile 级计算内积(S)
  • 融合指数运算
  • 通过在线望远镜求和技巧逐 tile 计算 softmax

总结#

GPU 性能优化的核心思路:

策略方法
减少内存访问合并、融合
将内存移到共享内存分块
用内存换计算/精度量化、重计算

硬件驱动扩展,底层细节决定什么能扩展。当前 GPU 计算强烈鼓励思考矩阵乘法 + 数据移动。

Kernels, Triton, XLA#

GPU 硬件规格#

加速器A100H100B200
# SMs108132148
Register size (per SM)256 KB256 KB256 KB
L1 cache + shared memory (per SM)192 KB256 KB256 KB
L2 cache size40 MB50 MB96-126 MB
HBM size80 GB80 GB192 GB
Register bandwidth~116 TB/s~401 TB/s~447 TB/s
L1 + shared memory bandwidth~19 TB/s~33 TB/s~19 TB/s
L2 cache bandwidth~5-8 TB/s~12 TB/s~9 TB/s
HBM bandwidth2 TB/s3.35 TB/s8 TB/s

B200 还有 Tensor Memory (TMEM) 用于 tensor cores,位于寄存器和共享内存之间,对程序员不可见。

编程模型#

  • Thread(线程):在数据的小部分上执行代码
  • Thread block / CTA:线程组,可以访问同一共享内存
  • Grid:线程块的集合

为什么需要 Thread Block?

  • 元素级操作(如 GeLU):线程最自然,每个线程处理一个元素
  • 非元素级操作(如 softmax、矩阵乘法):线程需要通信
  • 从 HBM 读写很慢,使用共享内存(SM 本地)
  • Thread block:访问相同共享内存的线程集合,在一个 SM 上调度

程序模型与硬件交互#

编程模型提供硬件抽象,理论上只需关注编程模型(保证正确性)。但性能高度依赖硬件细节,需要理解硬件才能获得高性能。

Warps(线程束)#

  • Thread block 内的线程分组为 warps(每 warp 32 线程)
  • Warp 内所有线程必须以 lockstep 执行相同指令
  • 控制分歧(Control divergence):不同线程需要执行不同指令时,必须串行执行(性能下降)

Warp Occupancy(占用率)#

  • 每个线程可使用 0-255 个寄存器
  • 线程使用更多寄存器 → SM 上可调度的线程更少(低占用率)
  • 低占用率不一定坏事,如果每个线程做更多工作
  • 示例:block 有 128 线程,每线程 160 寄存器,SM 有 65536 寄存器
    • num_blocks = 65536 / (128 × 160) = 3
    • occupancy = 3 × 128 / 32 / 64 = 12.5%

Bank Conflicts(共享内存银行冲突)#

  • 共享内存分为 32 个 banks,每个 4 字节宽
  • 每个 cycle,每个 bank 只能被一个线程访问
  • 多线程访问同一 bank → 访问串行化(bank conflict)
  • 最坏情况:矩阵每行横跨所有 banks;32 线程访问第一列 → 32-way bank conflict
  • 解决方案:swizzling 重排共享内存(如 row xor col)

Memory Coalescing(内存合并)#

  • Warp 中 32 线程访问 HBM 时,内存访问合并为 128 字节的 cache line 事务
  • 最佳情况:完全合并,所有线程访问同一 cache line

Block Occupancy(块占用率)#

  • Thread blocks 以 waves 形式调度到 SMs
  • B200 有 148 SMs,启动 160 blocks → 第一波 148,第二波 12
  • Wave quantization 问题:最后一波 blocks 更少,部分 SM 空闲
  • 解决方案:使 block 数量能整除 SM 数量

Benchmarking 与 Profiling#

成功配方#

  1. Benchmark 和 profile 代码
  2. 做改动
  3. 再次 benchmark 和 profile

Benchmarking#

测量操作的 wall-clock 时间,只给出端到端时间。

用途:

  • 比较不同实现(哪个更快)
  • 理解性能如何随维度扩展

关键技巧:

  • Warmup:首次运行可能较慢(编译等)
  • CUDA events:准确 GPU 计时(避免 CPU 开销)
  • torch.cuda.synchronize():等待 CUDA 线程完成

Profiling#

查看时间花费在哪里,帮助理解底层发生了什么。

PyTorch 内置 profiler,更详细可用 nsight。

CUDA kernel 名称解读

cutlass3x_sm100_simt_sgemm_f32_f32_f32_f32_f32_64x64x16_1x1x1_3_nnn_align1_bi...
  • cutlass: NVIDIA 的 CUDA 线性代数库
  • sm100: Blackwell 架构 (B200)
  • f32: float32
  • 64x64x16: tile 形状

Triton 简介#

特性CUDATriton
开发者NVIDIAOpenAI
编程单位ThreadThread block
优点细粒度控制更简单,自动管理共享内存
缺点需管理更多细节灵活性略低

Triton 概念框架:加载数据到共享内存 → 操作(可融合)→ 写回 HBM

Triton Kernel 示例#

GeLU(元素级操作)#

@triton.jit
def triton_gelu_kernel(x_ptr, y_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0) # Block ID
start = pid * BLOCK_SIZE
offsets = start + tl.arange(0, BLOCK_SIZE)
mask = offsets < num_elements
x = tl.load(x_ptr + offsets, mask=mask)
# GeLU 计算
a = 0.79788456 * (x + 0.044715 * x * x * x)
exp = tl.exp(2 * a)
tanh = (exp - 1) / (exp + 1)
y = 0.5 * x * (1 + tanh)
tl.store(y_ptr + offsets, y, mask=mask)

PTX 观察要点

  • ld.global.* / st.global.*: 全局内存读写
  • %ctaid.x: block index, %tid.x: thread index
  • %f*: 浮点寄存器, %r*: 整数寄存器
  • 一个线程处理 8 个元素(thread coarsening)

Softmax(Reduce 操作,行 fits in block)#

朴素实现:5MN + M reads, 3MN + 2M writes

Triton fused softmax:MN reads, MN writes(提速 ~4x)

@triton.jit
def triton_softmax_kernel(x_ptr, y_ptr, x_row_stride, y_row_stride, num_cols, BLOCK_SIZE):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
x_row = tl.load(x_ptr + row_idx * x_row_stride + col_offsets,
mask=col_offsets < num_cols, other=float("-inf"))
x_row = x_row - tl.max(x_row, axis=0) # 减去 max
numerator = tl.exp(x_row)
denominator = tl.sum(numerator, axis=0)
y_row = numerator / denominator
tl.store(y_ptr + row_idx * y_row_stride + col_offsets, y_row,
mask=col_offsets < num_cols)

Row Sum(Reduce 操作,行不 fits in block)#

当行太大无法放入单个 block 时:

策略

  • 将行分成多个 tiles
  • 每个线程遍历 tiles 并累加
  • 最后对每个线程的累加器做 reduction
@triton.jit
def row_sum_kernel(x_ptr, out_ptr, N, BLOCK_SIZE):
row = tl.program_id(0)
acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for start in range(0, N, BLOCK_SIZE): # 遍历 tiles
cols = start + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x = tl.load(x_ptr + row * N + cols, mask=mask, other=0.0)
acc += x
result = tl.sum(acc, axis=0) # 最终 reduction
tl.store(out_ptr + row, result)

Matmul ReLU(Tiling + Fusion)#

朴素方法

  • 固定 (m, n),遍历 k
  • 每次 read A[m,k] 和 B[k,n]
  • MKN reads, MN writes
  • Arithmetic intensity: O(1)

理想方法

  • 加载全部 A 和 B 到共享内存
  • MK + KN reads, MN writes
  • Arithmetic intensity: O(N)
  • 问题:矩阵通常太大

Tiling 解决方案

将矩阵 C 分成 output tiles(thread blocks):

  • 固定一个 output tile
  • 遍历 A 的行 tile 和 B 的列 tile
  • 加载对应 A tile 和 B tile 到共享内存
  • 执行 tile 级矩阵乘法
  • 累加部分和
  • 写 output tile 到 HBM

Arithmetic intensity: O(tile_size)

@triton.jit
def matmul_relu_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
stride_am, stride_ak, stride_bk, stride_bn,
stride_cm, stride_cn, BLOCK_M, BLOCK_N, BLOCK_K):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
indices_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
indices_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
indices_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + indices_m[:, None] * stride_am + indices_k[None, :] * stride_ak
b_ptrs = b_ptr + indices_k[:, None] * stride_bk + indices_n[None, :] * stride_bn
acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=...)
b = tl.load(b_ptrs, mask=...)
acc += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
acc = tl.maximum(acc, 0.0) # ReLU fusion
c_ptrs = c_ptr + indices_m[:, None] * stride_cm + indices_n[None, :] * stride_cn
tl.store(c_ptrs, acc, mask=...)

总结#

层面要点
编程模型PyTorch, Triton, PTX → 正确性
硬件理解SMs, warps, occupancy, bank conflicts, coalescing → 性能
Benchmarking测量端到端时间,理解扩展规律
Profiling查看时间花费,理解底层执行
Triton 思维Thread block:读到共享内存 → 操作(fusion)→ 写回 HBM

下次:多 GPU!

支持与分享

如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!

赞助
CS336 Lecture Notes 3
https://llm-tech.com.cn/posts/cs336-lec-notes-3/
作者
Ming
发布于
2026-05-04
许可协议
CC BY-NC-SA 4.0
Profile Image of the Author
Ming
你是来找 Ming 学习的吗
🎉 欢迎来到 Ming 的博客
这里是我的个人博客,分享 AI Infra、LLM 等技术内容。欢迎关注交流!
分类
标签
站点统计
文章
19
分类
6
标签
12
总字数
69,591
运行时长
0
最后活动
0 天前

目录