|
本次源码剖析使用 2022/07/22 最新版本:
- NVIDIA PyTorch container: nvcr.io/nvidia/pytorch:22.06-py3;
- PyTorch: v1.12.0;
符号跟踪源码
import torch
from torch.fx import symbolic_trace
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
print(symbolic_traced.graph)
print(symbolic_traced.code)symbolic_trace 实现于 torch/fx/_symbolic_trace.py:827:symbolic_trace, 流程:
- 创建 torch.fx.Tracer,进行简单的初始化;
- 调用 trace() 来捕获 torch.fx.Graph;
- 创建 torch.fx.GraphModule;
符号跟踪(symblic tracing) 在 torch/fx/_symbolic_trace.py:499:trace 实现,符号跟踪整体流程如下:
- torch/fx/_symbolic_trace.py#L539:创建 torch.fx.Graph,Graph 是 FX IR 的主要数据结构,包含一系列 torch.fx.Node,Node 代表了输入/输出/算子,一系列 Node 构成了 Python 函数。Graph 初始化在 torch/fx/graph.py:600:__init__,其中创建了 root 节点(torch.fx.Node)和 CodeGen 对象。Node 初始化于 torch/fx/node.py:123:__init__,Node 包含 .op 属性,表明该节点的功能。此处主要函数执行流程:
text [P002] > torch/fx/graph.py:600:__init__ [New] [P003] > torch/fx/node.py:123:__init__ [New] [P003] < torch/fx/node.py:123:__init__ [P003] > torch/fx/graph.py:250:__init__ [New] [P003] < torch/fx/graph.py:250:__init__ [P002] < torch/fx/graph.py:600:__init__
- torch/fx/_symbolic_trace.py#L559:创建被 trace 的 root 函数(对 module 来说是 forward)的输入节点(placeholder),具体实现在 torch/fx/_symbolic_trace.py:376:create_args_for_root,流程如下:
- 对 root 函数进行自省以获取函数签名,它通过标准库的 inspect.signature() 实现,从而获得参数名称等信息;
- 对 root 函数的每个参数创建 torch.fx.Proxy 对象(torch/fx/proxy.py:232:__init__),其类型是 placeholder,含义是函数的输入。Proxy 是 Node 的 wrapper,每个 Proxy 都包含一个 Node,Proxy 在 symbolic tracing 过程中负责记录涉及到的算子。Proxy 是符号跟踪过程中的符号;
- 创建 Node 通过 graph.create_node() 实现(torch/fx/graph.py:691:create_node),Node 以双向链表的方式管理,每个 Node 都记录其前后节点,节点插入在 torch/fx/node.py:224:prepend 实现。这里 graph 是 TracerBase 的 类属性,Tracer 继承自 TracerBase;
此处主要函数执行流程:
text [P002] > torch/fx/_symbolic_trace.py:376:create_args_for_root [New] [P003] > python3.8/inspect.py:493:unwrap [New] [P003] < python3.8/inspect.py:493:unwrap [P003] > python3.8/inspect.py:3103:signature [New] [P003] < python3.8/inspect.py:3103:signature [P003] > torch/fx/_symbolic_trace.py:443:<genexpr> [New] [P004] > torch/fx/_symbolic_trace.py:403:proxy_placeholder [New] [P005] > torch/fx/proxy.py:49:create_proxy [New] [P006] > torch/fx/_symbolic_trace.py:201:create_arg [New] [P007] > torch/fx/proxy.py:107:create_arg [New] [P007] < torch/fx/proxy.py:107:create_arg [P006] < torch/fx/_symbolic_trace.py:201:create_arg [P006] > torch/fx/proxy.py:29:create_node [New] [P007] > torch/fx/graph.py:691:create_node [New] [P008] > torch/fx/node.py:123:__init__ [P008] < torch/fx/node.py:123:__init__ [P008] > torch/fx/node.py:224:prepend [New] [P009] > torch/fx/node.py:257:_remove_from_list [New] [P009] < torch/fx/node.py:257:_remove_from_list [P008] < torch/fx/node.py:224:prepend [P007] < torch/fx/graph.py:691:create_node [P006] < torch/fx/proxy.py:29:create_node [P006] > torch/fx/proxy.py:45:proxy [New] [P007] > torch/fx/proxy.py:232:__init__ [New] [P007] < torch/fx/proxy.py:232:__init__ [P006] < torch/fx/proxy.py:45:proxy [P005] < torch/fx/proxy.py:49:create_proxy [P004] < torch/fx/_symbolic_trace.py:403:proxy_placeholder [P003] < torch/fx/_symbolic_trace.py:443:<genexpr> [P003] > torch/utils/_pytree.py:126:tree_flatten [New] [P003] < torch/utils/_pytree.py:126:tree_flatten [P002] < torch/fx/_symbolic_trace.py:376:create_args_for_root
- torch/fx/_symbolic_trace.py#L581-L582:以 monkey patch 的方式给 torch.nn.Module.__getattr__ 和 torch.nn.Module.__call__ 打补丁,通过 setattr() 将它们替换为 module_getattr_wrapper 和 module_call_wrapper;
- torch/fx/_symbolic_trace.py#L583-L586:给叶子函数(leaf function)打补丁,确保在符号跟踪时直接调用原函数。叶子函数包括通过 torch.fx.wrap(fn_or_name) 标记的函数,以及在 Tracer 实例化时由参数 autowrap_modules 显示指定的 Python 模块(默认为 math 模块)和由 autowrap_function 显示指定的 Python 函数。这些叶子函数在符号跟踪的过程中不会被跟踪,会直接调用原函数;
- torch/fx/_symbolic_trace.py#L587-L588:以 Proxy 作为参数调用 root 函数,返回后为其创建 output 节点;
符号跟踪的核心正是 torch/fx/_symbolic_trace.py#L587-L588 的 fn(*args),其中 args 是 Proxy 构成的列表,fn 是上述示例代码中的 forward()。
在 forward() 中用到 self.param,而 self 是 torch.nn.Module 的实例,torch.nn.Module.__getattr__ 在此之前已被修改为 torch/fx/_symbolic_trace.py:565:module_getattr_wrapper,这里获取原属性后,转到 torch/fx/_symbolic_trace.py:473:_module_getattr。如果该属性的类型是 torch.nn.Parameter,则为其创建类型为 get_attr 的 Proxy,含义是从 module 中获取 Parameter。
然后执行 x + self.param,此时两个操作数都是 torch.fx.Proxy 类型,加法在这里需要调用 Python 中的魔术方法(magic method) __add__。Proxy.__add__ 在导入 torch.fx 模块时被设置为 torch/fx/proxy.py:383:impl,被修改的魔术方法列表由 torch/fx/graph.py#L1417 指定,包含了常见的 Python 数学运算符。在 impl() 内,找到真正的算子 operator.add,由 tracer 为 operator.add 创建新的 Proxy,类型为 call_function。创建 Proxy 实现在 torch/fx/proxy.py:49:create_proxy,两个操作数都是 Proxy 类型,获取参数(torch/fx/proxy.py#L63)会直接通过 Porxy.node 获取其对应的 Node(torch/fx/proxy.py#L147),然后通过 graph 创建新的 Node,其 .target 为 operator.add,.op 为 call_function,._input_nodes 是两个操作数的 Node,而新创建的 Node 也成了操作数 Node 的 .users。graph 以环形双向链表的方式管理 Node,新的 Node 被插入到 graph 的 root 节点前(root._prev)、最后一个节点后,每个 Node 自己负责记录 input_nodes (producer) 和 users (consumer)。 最后为新 Node 创建 Proxy。
此处主要的函数执行流程:
[P003] > torch/fx/proxy.py:383:impl [New]
[P004] > torch/fx/proxy.py:49:create_proxy
[P005] > torch/fx/_symbolic_trace.py:201:create_arg
[P006] > torch/fx/proxy.py:107:create_arg
[P007] > torch/fx/proxy.py:125:<genexpr>
[P008] > torch/fx/_symbolic_trace.py:201:create_arg
[P008] < torch/fx/_symbolic_trace.py:201:create_arg
[P007] < torch/fx/proxy.py:125:<genexpr>
[P006] < torch/fx/proxy.py:107:create_arg
[P005] < torch/fx/_symbolic_trace.py:201:create_arg
[P005] > torch/fx/proxy.py:29:create_node
[P006] > torch/fx/graph.py:691:create_node
[P007] > torch/fx/node.py:123:__init__
[P008] > torch/fx/node.py:365:__update_args_kwargs
[P008] < torch/fx/node.py:365:__update_args_kwargs
[P007] < torch/fx/node.py:123:__init__
[P006] < torch/fx/graph.py:691:create_node
[P005] < torch/fx/proxy.py:29:create_node
[P004] < torch/fx/proxy.py:49:create_proxy
[P003] < torch/fx/proxy.py:383:impltorch/fx/proxy.py:383:impl 正是符号跟踪的巧妙之处,调用加法运算符被转到了 Proxy.__add__,该函数并没有真正的执行加法运算,而是根据操作数的 Proxy 创建了新的 Proxy,其中 graph 记录了所有的 Node,每个 Node 记录了涉及到的运算符和输入 Node。
下一步执行 self.linear(x + self.param),调用 self.linear() 会执行 torch.nn.Module.__call__,而它在此前已被修改为 torch/fx/_symbolic_trace.py:570:module_call_wrapper,紧接着转到 torch/fx/_symbolic_trace.py:342:call_module。如果 module 不是 leaf module,即包含其他 torch.nn.Module 或容器的 module,则继续调用 forward() 函数,直到 module 是 leaf module,然后创建并返回新的 Proxy,类型为 call_module,创建 Proxy 的过程和前面相同。此处主要的函数执行流程:
[P003] > torch/fx/_symbolic_trace.py:570:module_call_wrapper [New]
[P004] > torch/fx/_symbolic_trace.py:759:_autowrap_check
[P004] < torch/fx/_symbolic_trace.py:759:_autowrap_check
[P004] > torch/fx/_symbolic_trace.py:342:call_module [New]
[P005] > torch/fx/_symbolic_trace.py:293:is_leaf_module [New]
[P005] < torch/fx/_symbolic_trace.py:293:is_leaf_module
[P005] > torch/fx/proxy.py:49:create_proxy
[P005] < torch/fx/proxy.py:49:create_proxy
[P004] < torch/fx/_symbolic_trace.py:342:call_module
[P003] < torch/fx/_symbolic_trace.py:570:module_call_wrapper随后执行 .clamp(min=0.0, max=1.0),因为 self.linear() 返回的是 Proxy,.clamp() 会被转到 Proxy.__getattr__(&#34;clamp&#34;)(),它的实现在 torch/fx/proxy.py:243:__getattr__,其中直接返回一个 Attribute(torch/fx/proxy.py:325:__init__),Attribute 继承自 Proxy。在调用 Attribute() 时,tracer 创建并返回了新的 Proxy,类型是 call_method,target 是 clamp。此处主要的函数执行流程:
[P003] > torch/fx/proxy.py:243:__getattr__ [New]
[P004] > torch/fx/proxy.py:325:__init__ [New]
[P004] < torch/fx/proxy.py:325:__init__
[P003] < torch/fx/proxy.py:243:__getattr__
[P003] > torch/fx/proxy.py:340:__call__ [New]
[P004] > torch/fx/proxy.py:49:create_proxy
[P004] < torch/fx/proxy.py:49:create_proxy
[P003] < torch/fx/proxy.py:340:__call__到此,示例代码中的 MyModule.forward() 执行完毕。Tracer.trace() 的最后一步是建立 output 节点,类型为 output,一个完整的 symbolic tracing 到此结束。
torch.fx.symbolic_trace 的最后一步是创建 torch.fx.GraphModule,实现于 graph_module.py:293:__new__。GraphModule 继承自 torch.nn.Module,包含从 graph 中生成的 .graph, .code, .forward 属性。GraphModule 在初始化的过程中,会通过 setattr() 引用原 module 的属性。在 tracing 过程中捕获的 graph,也会设置为 GraphModule.graph,此时会触发 torch/fx/graph_module.py:624:recompile,其功能是根据 torch.fx.Graph 重新编译 GraphModule,每次对 .graph 的修改,都需要重新编译 GraphModule。
gm.recompile() 的第一步是 Python 代码生成(torch/fx/graph_module.py#L634),代码生成的核心由 torch.fx.CodeGen 实现,具体过程详见 torch/fx/graph.py:297:_gen_python_code。它首先把一些内置名称添加到全局命名空间,例如 inf, None, torch。然后以逆序方式遍历图中的节点(torch/fx/graph.py#L389-L391),找到每个节点最后被使用的地方,从而在代码生成的过程中及时释放不用的节点。
生成 Python 代码在 torch/fx/graph.py#L475-L479,代码生成逐节点进行,依次为每个 Node 生成对应的 Python 代码。上述示例代码的代码生成过程如下:
Index | Node | Op | Target | Args | Kwargs | Body | Unused | 0 | x | placeholder | x | () | {} | 1 | param | get_attr | param | () | {} | param = self.param | 2 | add | call_function | operator.add | (x, param) | {} | add = x + param; | x = param = None | 3 | linear | call_module | linear | (add,) | {} | linear = self.linear(add); | add = None | 4 | clamp | call_method | clamp | (linear,) | {&#39;min&#39;: 0.0, &#39;max&#39;: 1.0} | clamp = linear.clamp(min = 0.0, max = 1.0); | linear = None | 5 | output | output | output | (clamp,) | {} | return clamp | 以生成 add 节点对应的代码为例,进入到 emit_node() 后,node.op 是 call_function,node.target 是 operator.add,是 Python 中的魔术方法 __add__。torch/fx/graph.py#L1400 定义 add 的模板是 &#39;{} + {}&#39;,通过 format() 向上述模板中填入两个参数 x 和 param,最终得到字符串 add = x + params。在 delete_unused_values() 中,变量 x 和 param 是他们在 add 节点最后被用到的地方,因此生成代码 ; x = param = None。add 节点的代码生成完毕,按照此流程即可生成所有节点对应的代码。
body.append(f&#39;{repr(node)}{maybe_type_annotation} = &#39;
f&#39;{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}&#39;)如果存在通过 torch.fx.wrap() 定义的叶子函数,则在 torch/fx/graph.py#L490-L491 为其生成 wrap 语句。torch/fx/graph.py#L501 生成函数定义 def forward(self, x):。最后,一份完整的 Python 代码字符串生成完毕。
gm.recompile() 的第二步是 代码编译,它通过 torch/fx/graph_module.py#L69-L80 实现。核心代码是 exec(compile(src, key, &#39;exec&#39;), globals),它调用 Python 内建函数 compile() 把生成的 Python 代码字符串编译为 Python 字节码,并调用 Python 内建函数 exec() 执行字节码。因为生成的代码是 forward() 函数的定义,所以在 exec() 之后,globals 中存在一个名为 forward 的可执行函数,该函数被直接设置为 graph module 的 forward 函数(torch/fx/graph_module.py#L638)。torch/fx/graph_module.py#L654 的 cls.__call__ = call_wrapped 将 GraphModule() 转发到 call_wrapped,最终会转到 forward() 函数。
到此,GraphModule 初始化完成,初始化 GraphModule 过程中的重要函数执行流程:
[P001] > torch/fx/graph_module.py:293:__new__ [New]
[P002] > torch/fx/graph_module.py:307:GraphModuleImpl [New]
[P002] < torch/fx/graph_module.py:307:GraphModuleImpl
[P001] < torch/fx/graph_module.py:293:__new__
[P001] > torch/fx/graph_module.py:311:__init__ [New]
[P002] > torch/fx/graph.py:632:nodes [New]
[P003] > torch/fx/graph.py:221:__init__ [New]
[P003] < torch/fx/graph.py:221:__init__
[P002] < torch/fx/graph.py:632:nodes
[P002] > torch/fx/graph.py:229:__iter__ [New]
[P002] < torch/fx/graph.py:229:__iter__
[P002] > torch/fx/graph_module.py:182:_copy_attr [New]
[P002] < torch/fx/graph_module.py:182:_copy_attr
[P002] > torch/nn/modules/module.py:1210:__setattr__
[P003] > torch/fx/graph_module.py:394:graph [New]
[P004] > torch/fx/graph_module.py:624:recompile [New]
[P005] > torch/fx/graph.py:1091:python_code [New]
[P006] > torch/fx/graph.py:115:__init__
[P006] < torch/fx/graph.py:115:__init__
[P006] > python3.8/contextlib.py:211:contextmanager [New]
[P006] < python3.8/contextlib.py:211:contextmanager
[P006] > python3.8/contextlib.py:108:__enter__ [New]
[P006] < python3.8/contextlib.py:108:__enter__
[P006] > torch/fx/graph.py:1153:_python_code [New]
[P007] > torch/fx/graph.py:297:_gen_python_code [New]
[P008] > torch/fx/graph.py:306:add_global [New]
[P008] < torch/fx/graph.py:306:add_global
[P008] > torch/fx/node.py:592:map_arg
[P008] < torch/fx/node.py:592:map_arg
[P008] > torch/fx/graph.py:412:emit_node
[P008] < torch/fx/graph.py:412:emit_node
[P008] > torch/fx/graph.py:393:delete_unused_values
[P008] < torch/fx/graph.py:393:delete_unused_values
[P008] > torch/fx/graph.py:290:additional_globals [New]
[P008] < torch/fx/graph.py:290:additional_globals
[P008] > torch/fx/graph.py:253:gen_fn_def [New]
[P008] < torch/fx/graph.py:253:gen_fn_def
[P007] < torch/fx/graph.py:297:_gen_python_code
[P006] < torch/fx/graph.py:1153:_python_code
[P006] > python3.8/contextlib.py:117:__exit__ [New]
[P006] < python3.8/contextlib.py:117:__exit__
[P005] < torch/fx/graph.py:1091:python_code
[P005] > torch/nn/modules/module.py:1210:__setattr__
[P005] < torch/nn/modules/module.py:1210:__setattr__
[P005] > torch/fx/graph_module.py:74:_forward_from_src [New]
[P006] > torch/fx/graph_module.py:69:_exec_with_source [New]
[P007] > torch/fx/graph_module.py:28:cache [New]
[P007] < torch/fx/graph_module.py:28:cache
[P007] > <eval_with_key>.0:4:<module> [New]
[P007] < <eval_with_key>.0:4:<module>
[P006] < torch/fx/graph_module.py:69:_exec_with_source
[P005] < torch/fx/graph_module.py:74:_forward_from_src
[P005] > torch/fx/graph_module.py:227:__init__ [New]
[P005] < torch/fx/graph_module.py:227:__init__
[P004] < torch/fx/graph_module.py:624:recompile
[P003] < torch/fx/graph_module.py:394:graph
[P002] < torch/nn/modules/module.py:1210:__setattr__
[P002] > torch/fx/graph_module.py:387:graph [New]
[P002] < torch/fx/graph_module.py:387:graph
[P001] < torch/fx/graph_module.py:311:__init__到此,symbolic_trace() 的最后一步,整个符号跟踪完成。
符号跟踪总结:
- 在初始化阶段,通过 monkey patch 把针对 Tensor 的函数修改为针对 Proxy 的函数;
- 在跟踪阶段,用 Proxy 替换 Tensor 作为函数输入,依次在 Proxy 的方法中建立新的 Proxy,其中的 Node 构成了 Graph;
- 在代码生成阶段,依次遍历 Node,逐个生成对应的 Python 代码字符串;
- 在编译阶段,通过 Python 内建函数 exec(compile()) 将字符串编译为可执行函数;
|
|