Python AST 插桩技巧
现在有一些基于对 python 代码进行 AST 插桩的工作, 比如 TrainCheck。 这篇博客主要介绍如何对 Python 进行 AST 插桩。
使用 NodeTransformer 进行插桩
ast.NodeTransformer 是 ast.NodeVisitor 的子类, 它可以遍历并且修改 AST 的节点。ast.NodeVisitor 提供了一些访问 (visit) 方法, 比如 visit_FunctionDef, visit_AsyncFunctionDef, visit_Import, visit_ImportFrom, visit_For, visit_While, visit_Constant等。
根据访问方法的返回值处理原节点,返回新节点则替换,返回 None 则移除, 返回原节点则保持不变。

例 1: 使用 visit_Constant 将数字常量变成 42
下面是一个使用 NodeTransformer 的简单例子, 它把所有的int和float类型的数字常量替换成 42。
import ast
class AnswerToEverything(ast.NodeTransformer): def visit_Constant(self, node): # Check if the constant is a number if isinstance(node.value, (int, float)): return ast.Constant(value=42) return node
# Example code to transformcode = "x = 10 + 5"tree = ast.parse(code)
# Apply the transformationtransformer = AnswerToEverything()new_tree = transformer.visit(tree)
# Finalize the tree by filling in required line numbers/offsetsast.fix_missing_locations(new_tree)
print(ast.unparse(new_tree)) # Output: x = 42 + 42例 2: 使用 visit_Expr 扩展 Print
下面的例子使用 visit_Expr 拦截 print 函数调用。
关于 Expr 的结构可查看 Expressions, 这可以帮助你理解node.value.func.id 是怎么一回事儿。
下面 visit_Expr 的返回值是一个列表, 表示会进行扩展。
import ast
# 1. 待转换的原始代码source_code = """\print("Hello")x = 10print("World")"""
# 2. 自定义 Transformerclass ExpandPrintTransformer(ast.NodeTransformer): def visit_Expr(self, node): # 判断是否为 print() 调用语句 if (isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) and node.value.func.id == "print"):
# 构造一个新的 AST 节点:print(">>> 拦截到打印语句") log_node = ast.Expr( value=ast.Call( func=ast.Name(id="print", ctx=ast.Load()), args=[ast.Constant(value=">>> 拦截到打印语句")], keywords=[] ) ) # ⚠️ 手动创建的节点缺失行号/列号,必须补全,否则 compile() 会报错 ast.fix_missing_locations(log_node)
# 🔑 核心:返回一个列表,原节点会被“原地展开替换”为列表中的多个节点 # 顺序决定插入位置:[新节点, 原节点] 表示先执行日志,再执行原 print return [log_node, node]
# 非 print 语句直接返回原节点(不修改) return node
# 3. 执行转换流程tree = ast.parse(source_code)transformer = ExpandPrintTransformer()new_tree = transformer.visit(tree)
# 4. 查看转换结果print("【转换后的 Python 代码】")print(ast.unparse(new_tree)) # Python 3.9+ 支持
print("\n【实际运行输出】")exec(compile(new_tree, filename="<demo>", mode="exec"))在这个地方使用了 ast.Load(), 它是 Python AST 中用于标记标识符 (变量名/属性名) 上下文的枚举实例。
它告诉编译器: 这个名字正在被读取。
标识符的三种上下文 (ctx)
| 上下文类 | 含义 | 典型代码位置 | 编译器生成的字节码 |
|---|---|---|---|
ast.Load() | 读取值 | print(x)、y = x + 1、func(x) | LOAD_NAME / LOAD_FAST |
ast.Store() | 写入/赋值 | x = 10、for x in range(5) | STORE_NAME / STORE_FAST |
ast.Del() | 删除绑定 | del x、del obj.attr | DELETE_NAME |
例 3: generic_visit 递归遍历子节点
generic_visit 是 Python ast 模块中递归遍历子节点的核心引擎。
当你重写visit_XXX方法时, 如果希望继续处理当前节点内部的子节点 (如函数体、表达式、参数等), 就必须显式调用它, 否则遍历会在此处”断崖式停止”。
下面是一个例子:
import ast
source = """\def calculate(a, b): result = a + b print(result) return result"""
# ================= 错误示范:忘记调用 generic_visit =================class StopTransformer(ast.NodeTransformer): def visit_FunctionDef(self, node): print(f"🚨 遇到函数定义: {node.name}") # ❌ 没有调用 self.generic_visit(node) # 遍历到此停止,函数体内的赋值、打印、返回语句全被跳过 return node
def visit_Name(self, node): # 这个 visit_Name 永远不会被触发! print(f" -> 捕获到名称: {node.id}") return node
# ================= 正确示范:调用 generic_visit =================class DeepTransformer(ast.NodeTransformer): def visit_FunctionDef(self, node): print(f"✅ 遇到函数定义: {node.name}") # 🔑 关键:将控制权交还给默认遍历引擎,继续深入子节点 self.generic_visit(node) return node
def visit_Name(self, node): # 现在可以正常遍历到函数内部的所有变量/函数名了 print(f" -> 捕获到名称: {node.id}") return node
print("【错误示范输出】")tree1 = ast.parse(source)StopTransformer().visit(tree1)
print("\n【正确示范输出】")tree2 = ast.parse(source)DeepTransformer().visit(tree2)NodeTransformer 的默认行为其实是这样的:
# 父类 NodeTransformer 的隐式逻辑(伪代码)def visit(self, node): method = 'visit_' + node.__class__.__name__ visitor = getattr(self, method, self.generic_visit) return visitor(node) # 如果你没重写 visit_XXX,默认走 generic_visit当你重写了 visit_FunctionDef,你就覆盖了默认的 generic_visit 行为。此时 Python 不知道你还要不要继续往下走,必须你手动调用 self.generic_visit(node) 把接力棒传下去。
下面是 TrainCheck 项目中的摘取的源代码。

支持与分享
如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!