CS336 Lecture Notes 3
本文为本人学习相关开源课程过程中,整理的个人学习笔记及作业解答,核心目的仅用于记录个人学习轨迹、巩固所学知识、梳理学习思路,全程为个人自主学习使用,不具备任何商业用途,也不构成任何形式的课程辅导或标准答案参考。
需特别说明的是,由于本人学习进度及知识储备有限,笔记内容及作业解答中可能存在大量纰漏、思路偏差甚至错误,仅代表本人当时的学习理解,不具备权威性和准确性。
在此郑重提醒:请勿将本文中的任何作业解答复制粘贴,作为自身所修课程的提交答案。任何因抄袭本文内容导致的课程成绩问题、学术诚信问题,均由抄袭者自行承担全部责任,本人不承担任何相关连带责任。
同时,本文所分享的内容均基于开源课程的公开内容整理,尊重原课程创作者的知识产权,若涉及相关内容的版权问题,请及时联系本人,本人将第一时间进行调整或删除。
感谢各位读者的理解与支持,也欢迎大家针对笔记及解答中的问题提出宝贵建议,共同交流学习、共同进步。
- 课程网站: https://cs336.stanford.edu/
- Lec05 资料: lecture_05.pdf
- Lec06 资料: lecture_06.py
GPUs, TPUs
GPU 与 CPU 的区别
- CPU:优化少量快速线程,强调延迟优化(每个线程快速完成)
- GPU:优化大量并行线程,强调吞吐量优化(总处理数据量大)
- GPU 拥有许多小型计算单元(ALU),对分支(控制、缓存)的支持较少
GPU 架构
执行单元
- SM (Streaming Multiprocessor):GPU 拥有多个 SM,每个 SM 独立执行”块”(blocks/jobs)
- SP (Streaming Processor):每个 SM 包含多个 SP,可并行执行”线程”(threads)
内存层次
距离 SM 越近的内存越快:
- L1 / 共享内存:位于 SM 内部,最快
- L2 缓存:位于芯片上
- 全局内存:GPU 旁边的内存芯片
SRAM(共享/缓存内存)比 DRAM(全局内存)贵约 100 倍,但快约 8 倍
GPU 执行模型
三个重要概念:
- Thread(线程):并行执行工作单元,所有线程执行相同指令但输入不同(SIMT 模型)
- Block(块):线程组,每个块在一个 SM 上运行,拥有自己的共享内存
- Warp(线程束):32 个连续编号的线程总是作为一个 warp 执行
GPU 内存模型
- 每个线程可以访问自己的寄存器和块内的共享内存
- 跨块的信息需要通过全局内存读写(较慢)
TPU 简介
TPU 与 GPU 在高层类似:
- 核心结构:轻量控制、快速(大型)矩阵乘法单元、快速内存
- 差异:加速器的网络方式不同,没有 warp(只有 blocks)
| 特性 | GPU | TPU |
|---|---|---|
| 计算单元 | 更多 SM | 更少 TC(但矩阵乘法性能相近) |
| 执行单位 | Warp (32 线程) | Block |
GPU 性能优化
Roofline 模型
关键问题:如何避免成为内存瓶颈?
优化技巧 1:低精度计算
更少的位数意味着更少的数据移动。
示例:元素级 ReLU(向量大小 n)
- Float32:内存访问 8 bytes/FLOP
- Float16:内存访问 4 bytes/FLOP
Tensor Coores(V、T 系列引入)是专门的矩阵乘法电路,矩阵乘法比其他浮点运算快 10 倍以上。
前沿低精度技术:
- FP8:不同权衡
- MXFP8(Blackwell):每个缩放因子对应 32 个元素
- MXFP4:更极致的量化
优化技巧 2:算子融合
将多个操作合并到单个 CUDA kernel 中,减少内存访问。
示例: 的计算
- 朴素方法:启动 5 个 CUDA kernel(反复读写内存)
- 融合方法:合并为单个 kernel 调用
简单融合可由编译器自动完成(如 torch.compile)。
优化技巧 3:重计算
在反向传播中,存储激活值可能很昂贵。有时丢弃激活值并重计算反而是最优的(减少内存访问)。
示例:3 个 sigmoid 堆叠
- 朴素方法:8 次内存读写
- 重计算方法:5/8 的内存访问量
优化技巧 4:内存合并
DRAM 以”突发模式”读取——每次读取给出多个字节!
内存访问被合并的条件:warp 中的所有线程访问都在同一突发范围内。
矩阵乘法中的合并:对于行优先矩阵,沿行移动的线程不会被合并。
优化技巧 5:分块(Tiling)
分块是分组和排序线程以最小化全局内存访问的技术。
矩阵乘法分块:
- 将矩阵切成小块(tiles),加载到共享内存
- 分阶段计算矩阵乘法
- 重复读取现在访问共享内存而非全局内存
分块数学:
- 非分块:每个输入从全局内存读取 N 次
- 分块:每个输入从全局内存读取 N/T 次,在 tile 内读取 T 次
- 结果:全局内存访问减少 T 倍
分块复杂问题:
- Tile 大小可能不整除矩阵维度,导致利用率低
- 合并访问可能无法实现(需要填充)
- 需要考虑:共享内存大小、矩阵维度可除性
矩阵尺寸与性能的神秘关系
为什么更大的矩阵更快?
- 分块影响:分块通过对齐产生重大影响
- 波量化(Wave Quantization):周期性行为
- 示例:使用 256×128 的 tile,1792 大小产生 98 个 tiles
- 增加到 1793 产生 120 个 tiles
- A100 有 108 个 SM,无法同时执行所有 120 个 tiles
Flash Attention 解析
Flash Attention [Dao et al] 大幅加速注意力计算,核心是分块 + 在线 softmax。
注意力计算回顾
注意力计算:3 个矩阵乘法(K、Q、V)中间有 softmax。
分块策略
- KQV 矩阵乘法的分块:标准的矩阵乘法分块
- 增量 softmax 计算:通过在线更新 max 值和望远镜求和,逐 tile 计算 softmax
在线 Softmax 公式(来自 Mikailov & Gimelshein 2018):
- 跟踪最大值
- 增量更新最大值,设置望远镜求和
Flash Attention 前向传播总结
- Tile 级计算内积(S)
- 融合指数运算
- 通过在线望远镜求和技巧逐 tile 计算 softmax
总结
GPU 性能优化的核心思路:
| 策略 | 方法 |
|---|---|
| 减少内存访问 | 合并、融合 |
| 将内存移到共享内存 | 分块 |
| 用内存换计算/精度 | 量化、重计算 |
硬件驱动扩展,底层细节决定什么能扩展。当前 GPU 计算强烈鼓励思考矩阵乘法 + 数据移动。
Kernels, Triton, XLA
GPU 硬件规格
| 加速器 | A100 | H100 | B200 |
|---|---|---|---|
| # SMs | 108 | 132 | 148 |
| Register size (per SM) | 256 KB | 256 KB | 256 KB |
| L1 cache + shared memory (per SM) | 192 KB | 256 KB | 256 KB |
| L2 cache size | 40 MB | 50 MB | 96-126 MB |
| HBM size | 80 GB | 80 GB | 192 GB |
| Register bandwidth | ~116 TB/s | ~401 TB/s | ~447 TB/s |
| L1 + shared memory bandwidth | ~19 TB/s | ~33 TB/s | ~19 TB/s |
| L2 cache bandwidth | ~5-8 TB/s | ~12 TB/s | ~9 TB/s |
| HBM bandwidth | 2 TB/s | 3.35 TB/s | 8 TB/s |
B200 还有 Tensor Memory (TMEM) 用于 tensor cores,位于寄存器和共享内存之间,对程序员不可见。
编程模型
- Thread(线程):在数据的小部分上执行代码
- Thread block / CTA:线程组,可以访问同一共享内存
- Grid:线程块的集合
为什么需要 Thread Block?
- 元素级操作(如 GeLU):线程最自然,每个线程处理一个元素
- 非元素级操作(如 softmax、矩阵乘法):线程需要通信
- 从 HBM 读写很慢,使用共享内存(SM 本地)
- Thread block:访问相同共享内存的线程集合,在一个 SM 上调度
程序模型与硬件交互
编程模型提供硬件抽象,理论上只需关注编程模型(保证正确性)。但性能高度依赖硬件细节,需要理解硬件才能获得高性能。
Warps(线程束)
- Thread block 内的线程分组为 warps(每 warp 32 线程)
- Warp 内所有线程必须以 lockstep 执行相同指令
- 控制分歧(Control divergence):不同线程需要执行不同指令时,必须串行执行(性能下降)
Warp Occupancy(占用率)
- 每个线程可使用 0-255 个寄存器
- 线程使用更多寄存器 → SM 上可调度的线程更少(低占用率)
- 低占用率不一定坏事,如果每个线程做更多工作
- 示例:block 有 128 线程,每线程 160 寄存器,SM 有 65536 寄存器
- num_blocks = 65536 / (128 × 160) = 3
- occupancy = 3 × 128 / 32 / 64 = 12.5%
Bank Conflicts(共享内存银行冲突)
- 共享内存分为 32 个 banks,每个 4 字节宽
- 每个 cycle,每个 bank 只能被一个线程访问
- 多线程访问同一 bank → 访问串行化(bank conflict)
- 最坏情况:矩阵每行横跨所有 banks;32 线程访问第一列 → 32-way bank conflict
- 解决方案:swizzling 重排共享内存(如 row xor col)
Memory Coalescing(内存合并)
- Warp 中 32 线程访问 HBM 时,内存访问合并为 128 字节的 cache line 事务
- 最佳情况:完全合并,所有线程访问同一 cache line
Block Occupancy(块占用率)
- Thread blocks 以 waves 形式调度到 SMs
- B200 有 148 SMs,启动 160 blocks → 第一波 148,第二波 12
- Wave quantization 问题:最后一波 blocks 更少,部分 SM 空闲
- 解决方案:使 block 数量能整除 SM 数量
Benchmarking 与 Profiling
成功配方
- Benchmark 和 profile 代码
- 做改动
- 再次 benchmark 和 profile
Benchmarking
测量操作的 wall-clock 时间,只给出端到端时间。
用途:
- 比较不同实现(哪个更快)
- 理解性能如何随维度扩展
关键技巧:
- Warmup:首次运行可能较慢(编译等)
- CUDA events:准确 GPU 计时(避免 CPU 开销)
torch.cuda.synchronize():等待 CUDA 线程完成
Profiling
查看时间花费在哪里,帮助理解底层发生了什么。
PyTorch 内置 profiler,更详细可用 nsight。
CUDA kernel 名称解读:
cutlass3x_sm100_simt_sgemm_f32_f32_f32_f32_f32_64x64x16_1x1x1_3_nnn_align1_bi...cutlass: NVIDIA 的 CUDA 线性代数库sm100: Blackwell 架构 (B200)f32: float3264x64x16: tile 形状
Triton 简介
| 特性 | CUDA | Triton |
|---|---|---|
| 开发者 | NVIDIA | OpenAI |
| 编程单位 | Thread | Thread block |
| 优点 | 细粒度控制 | 更简单,自动管理共享内存 |
| 缺点 | 需管理更多细节 | 灵活性略低 |
Triton 概念框架:加载数据到共享内存 → 操作(可融合)→ 写回 HBM
Triton Kernel 示例
GeLU(元素级操作)
@triton.jitdef triton_gelu_kernel(x_ptr, y_ptr, num_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) # Block ID start = pid * BLOCK_SIZE offsets = start + tl.arange(0, BLOCK_SIZE) mask = offsets < num_elements
x = tl.load(x_ptr + offsets, mask=mask)
# GeLU 计算 a = 0.79788456 * (x + 0.044715 * x * x * x) exp = tl.exp(2 * a) tanh = (exp - 1) / (exp + 1) y = 0.5 * x * (1 + tanh)
tl.store(y_ptr + offsets, y, mask=mask)PTX 观察要点:
ld.global.*/st.global.*: 全局内存读写%ctaid.x: block index,%tid.x: thread index%f*: 浮点寄存器,%r*: 整数寄存器- 一个线程处理 8 个元素(thread coarsening)
Softmax(Reduce 操作,行 fits in block)
朴素实现:5MN + M reads, 3MN + 2M writes
Triton fused softmax:MN reads, MN writes(提速 ~4x)
@triton.jitdef triton_softmax_kernel(x_ptr, y_ptr, x_row_stride, y_row_stride, num_cols, BLOCK_SIZE): row_idx = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE)
x_row = tl.load(x_ptr + row_idx * x_row_stride + col_offsets, mask=col_offsets < num_cols, other=float("-inf"))
x_row = x_row - tl.max(x_row, axis=0) # 减去 max numerator = tl.exp(x_row) denominator = tl.sum(numerator, axis=0) y_row = numerator / denominator
tl.store(y_ptr + row_idx * y_row_stride + col_offsets, y_row, mask=col_offsets < num_cols)Row Sum(Reduce 操作,行不 fits in block)
当行太大无法放入单个 block 时:
策略:
- 将行分成多个 tiles
- 每个线程遍历 tiles 并累加
- 最后对每个线程的累加器做 reduction
@triton.jitdef row_sum_kernel(x_ptr, out_ptr, N, BLOCK_SIZE): row = tl.program_id(0) acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for start in range(0, N, BLOCK_SIZE): # 遍历 tiles cols = start + tl.arange(0, BLOCK_SIZE) mask = cols < N x = tl.load(x_ptr + row * N + cols, mask=mask, other=0.0) acc += x
result = tl.sum(acc, axis=0) # 最终 reduction tl.store(out_ptr + row, result)Matmul ReLU(Tiling + Fusion)
朴素方法:
- 固定 (m, n),遍历 k
- 每次 read A[m,k] 和 B[k,n]
- MKN reads, MN writes
- Arithmetic intensity: O(1)
理想方法:
- 加载全部 A 和 B 到共享内存
- MK + KN reads, MN writes
- Arithmetic intensity: O(N)
- 问题:矩阵通常太大
Tiling 解决方案:
将矩阵 C 分成 output tiles(thread blocks):
- 固定一个 output tile
- 遍历 A 的行 tile 和 B 的列 tile
- 加载对应 A tile 和 B tile 到共享内存
- 执行 tile 级矩阵乘法
- 累加部分和
- 写 output tile 到 HBM
Arithmetic intensity: O(tile_size)
@triton.jitdef matmul_relu_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M, BLOCK_N, BLOCK_K): pid_m = tl.program_id(0) pid_n = tl.program_id(1)
indices_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) indices_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) indices_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + indices_m[:, None] * stride_am + indices_k[None, :] * stride_ak b_ptrs = b_ptr + indices_k[:, None] * stride_bk + indices_n[None, :] * stride_bn
acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
for k in range(0, K, BLOCK_K): a = tl.load(a_ptrs, mask=...) b = tl.load(b_ptrs, mask=...) acc += tl.dot(a, b) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk
acc = tl.maximum(acc, 0.0) # ReLU fusion
c_ptrs = c_ptr + indices_m[:, None] * stride_cm + indices_n[None, :] * stride_cn tl.store(c_ptrs, acc, mask=...)总结
| 层面 | 要点 |
|---|---|
| 编程模型 | PyTorch, Triton, PTX → 正确性 |
| 硬件理解 | SMs, warps, occupancy, bank conflicts, coalescing → 性能 |
| Benchmarking | 测量端到端时间,理解扩展规律 |
| Profiling | 查看时间花费,理解底层执行 |
| Triton 思维 | Thread block:读到共享内存 → 操作(fusion)→ 写回 HBM |
下次:多 GPU!
支持与分享
如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!