备注
点击此处下载完整示例代码
(beta) 使用 FX 构建简单的 CPU 性能分析器 ¶
创建时间:2025 年 4 月 1 日 | 最后更新时间:2025 年 4 月 1 日 | 最后验证:未验证
作者:詹姆斯·里德
在本教程中,我们将使用 FX 来完成以下操作:
以便我们可以检查并收集有关代码结构和执行情况的统计数据,我们将捕获 PyTorch Python 代码
构建一个小的类,作为简单的性能“分析器”,从实际运行中收集模型各部分的运行时统计数据
在本教程中,我们将使用 torchvision ResNet18 模型进行演示
import torch
import torch.fx
import torchvision.models as models
rn18 = models.resnet18()
rn18.eval()
现在我们有了我们的模型,我们想更深入地检查其性能。也就是说,对于以下调用,模型的哪些部分花费的时间最长?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)
回答这个问题的常见方法是通过程序源代码,添加代码在程序中的各个点收集时间戳,并比较这些时间戳之间的差异,以查看时间戳之间的区域花费了多长时间。
这种技术当然适用于 PyTorch 代码,但是如果我们不需要复制并编辑模型代码,那就更好了,尤其是我们没有编写的代码(比如这个 torchvision 模型)。因此,我们将使用 FX 来自动化这个“仪器化”过程,而无需修改任何源代码。
首先,让我们先处理一些导入(我们将在代码中稍后使用这些导入)。
import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter
备注
tabulate
是一个外部库,不是 PyTorch 的依赖项。我们将使用它来更轻松地可视化性能数据。请确保您已从您喜欢的 Python 包源安装了它。
使用符号跟踪捕获模型
接下来,我们将使用 FX 的符号跟踪机制来捕获我们模型的定义,以便我们可以在数据结构中进行操作和检查。
traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
这为我们提供了 ResNet18 模型的图表示。图由一系列相互连接的节点组成。每个节点代表 Python 代码中的调用位置(无论是函数、模块还是方法),而边(在每个节点上表示为 args
和 kwargs
)代表这些调用位置之间传递的值。有关图表示和 FX 的其余 API 的更多信息,请参阅 FX 文档 https://maskerprc.github.io/docs/master/fx.html。
创建一个分析解释器
接下来,我们将创建一个继承自 torch.fx.Interpreter
的类。尽管 GraphModule
产生的 symbolic_trace
编译的 Python 代码在调用 GraphModule
时运行,但运行 GraphModule
的另一种方法是逐个执行 Node
中的 Graph
。这正是 Interpreter
提供的功能:它逐个解释图节点。
通过继承 Interpreter
,我们可以覆盖各种功能并安装我们想要的性能分析行为。目标是有一个对象,我们可以向其传递一个模型,调用模型 1 次或多次,然后获取有关这些运行期间模型及其各个部分花费时间的统计信息。
让我们定义我们的 ProfilingInterpreter
类:
class ProfilingInterpreter(Interpreter):
def __init__(self, mod : torch.nn.Module):
# Rather than have the user symbolically trace their model,
# we're going to do it in the constructor. As a result, the
# user can pass in any ``Module`` without having to worry about
# symbolic tracing APIs
gm = torch.fx.symbolic_trace(mod)
super().__init__(gm)
# We are going to store away two things here:
#
# 1. A list of total runtimes for ``mod``. In other words, we are
# storing away the time ``mod(...)`` took each time this
# interpreter is called.
self.total_runtime_sec : List[float] = []
# 2. A map from ``Node`` to a list of times (in seconds) that
# node took to run. This can be seen as similar to (1) but
# for specific sub-parts of the model.
self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}
######################################################################
# Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
# method is the top-level entry point for execution of the model. We will
# want to intercept this so that we can record the total runtime of the
# model.
def run(self, *args) -> Any:
# Record the time we started running the model
t_start = time.time()
# Run the model by delegating back into Interpreter.run()
return_val = super().run(*args)
# Record the time we finished running the model
t_end = time.time()
# Store the total elapsed time this model execution took in the
# ``ProfilingInterpreter``
self.total_runtime_sec.append(t_end - t_start)
return return_val
######################################################################
# Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
# time it executes a single node. We will intercept this so that we
# can measure and record the time taken for each individual call in
# the model.
def run_node(self, n : torch.fx.Node) -> Any:
# Record the time we started running the op
t_start = time.time()
# Run the op by delegating back into Interpreter.run_node()
return_val = super().run_node(n)
# Record the time we finished running the op
t_end = time.time()
# If we don't have an entry for this node in our runtimes_sec
# data structure, add one with an empty list value.
self.runtimes_sec.setdefault(n, [])
# Record the total elapsed time for this single invocation
# in the runtimes_sec data structure
self.runtimes_sec[n].append(t_end - t_start)
return return_val
######################################################################
# Finally, we are going to define a method (one which doesn't override
# any ``Interpreter`` method) that provides us a nice, organized view of
# the data we have collected.
def summary(self, should_sort : bool = False) -> str:
# Build up a list of summary information for each node
node_summaries : List[List[Any]] = []
# Calculate the mean runtime for the whole network. Because the
# network may have been called multiple times during profiling,
# we need to summarize the runtimes. We choose to use the
# arithmetic mean for this.
mean_total_runtime = statistics.mean(self.total_runtime_sec)
# For each node, record summary statistics
for node, runtimes in self.runtimes_sec.items():
# Similarly, compute the mean runtime for ``node``
mean_runtime = statistics.mean(runtimes)
# For easier understanding, we also compute the percentage
# time each node took with respect to the whole network.
pct_total = mean_runtime / mean_total_runtime * 100
# Record the node's type, name of the node, mean runtime, and
# percent runtime.
node_summaries.append(
[node.op, str(node), mean_runtime, pct_total])
# One of the most important questions to answer when doing performance
# profiling is "Which op(s) took the longest?". We can make this easy
# to see by providing sorting functionality in our summary view
if should_sort:
node_summaries.sort(key=lambda s: s[2], reverse=True)
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers : List[str] = [
'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
]
return tabulate.tabulate(node_summaries, headers=headers)
备注
我们使用 Python 的 time.time
函数来获取墙钟时间戳并进行比较。这不是最准确的方式来衡量性能,只能提供一个一阶近似。我们仅为了本教程的演示目的使用这种简单技术。
调查 ResNet18 的性能
现在我们可以使用 ProfilingInterpreter
来检查我们的 ResNet18 模型的性能特征;
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
这里有两点我们应该指出:
这通常是耗时最长的部分。这是一个已知问题:https://github.com/pytorch/pytorch/issues/51393
BatchNorm2d 也占用相当多的时间。我们可以继续这条思路,并在 FX 的 Conv-BN 融合教程中对其进行优化。
结论 ¶
如我们所见,使用 FX 我们可以轻松捕获 PyTorch 程序(甚至是没有源代码的程序)并以机器可解释的格式使用,例如我们在这里所做的性能分析。FX 为使用 PyTorch 程序打开了一个令人兴奋的可能性世界。
最后,由于 FX 仍处于测试阶段,我们非常愿意听取您在使用它时的任何反馈。请随时使用 PyTorch 论坛(https://discuss.pytorch.org/)和问题跟踪器(https://github.com/pytorch/pytorch/issues)提供您可能有的任何反馈。
脚本总运行时间:(0 分钟 0.000 秒)