模拟张量 ¶
代码:fake_tensor.py
动机 ¶
在进行 Dynamo 符号评估和编译器传递时,我们通常希望能够运行张量操作来了解输出大小/数据类型/设备等信息,而不实际运行这些操作(或破坏现有的张量),这会更快(如果你在进行大量计算)并且占用大量内存(如果编译器需要在编译程序时使用 GPU 内存,那就更糟糕了)。模拟张量在所有方面都像真实张量一样,只是实际上没有任何数据。例如,当我们进行 Dynamo 跟踪时,我们需要跟踪用户张量代码并回答有关中间结果的问题(例如,如果用户对一个中间张量进行条件操作)。没有模拟张量,我们就无法为这些查询提供准确的信息。
同样地,假设你想为张量存储元数据,例如在 FX IR 节点上(meta[‘val’])。你可以在节点上直接存储一个假张量,这将为你提供所需的所有张量元数据,包括你可能没有处理过的细微信息(例如别名关系)。
总体架构
所有伪造张量都与一个 FakeTensorMode 相关联。因为伪造张量的主要用例是对真实张量进行分析,所以一般的工作流程是:你有一堆真实张量,分配一个 FakeTensorMode,然后使用 from_real_tensor 将所有这些真实张量转换为伪造张量,然后对伪造张量进行操作。特别是,FakeTensorMode 会持久化地维护一个备忘表,将张量(和存储)映射到相同的存储。如果你多次伪造相同的张量,你会得到相同的伪造张量;如果你伪造了两个相互引用的张量,你会得到两个相互引用的伪造张量,它们引用相同的伪造存储。FakeTensors 是张量子类,所以如果你对它们进行操作,你会自动得到一个伪造张量,但通常你会在 FakeTensorMode 激活的情况下对伪造张量进行操作(例如,如果你正在运行 FX 传递);张量操作将自动打开伪造张量模式并再次尝试。
模拟张量被表示为元张量的一个 __torch_dispatch__ 张量子类。这意味着在底层,模拟张量是元设备张量;然后它们使用额外的可扩展钩子,特别是 dispatch_device,来欺骗张量的实际设备。这曾是早期模拟张量中较为容易出错的部分之一:有时,模拟张量在欺骗自己是 CPU/CUDA 等方面过于出色,以至于你会在使用模拟张量尝试解引用数据指针时调用 CPU 内核,这显然是不可行的。如果你在模拟张量代码中遇到段错误,这是你应该首先检查的事情:C++堆栈跟踪是否在 CPU 内核(意外!)或元内核(预期!)中。元内核就像一个真正的内核,但它所做的只是分配输出,不进行任何数据计算。
张量子类必须定义如何实现各种操作。以下是一般模拟张量的配方:
在输入的假张量上运行元内核,将它们重新解释为元张量。这是通过_in_kernel_invocation_manager_魔法上下文管理器完成的,它指示 PyTorch 将假张量视为其底层元张量,而不是将假张量“展开”为元张量(假张量是元张量)。以这种方式表示假张量是为了避免需要保持两组元数据同步(元张量的元数据和假张量的元数据);“是”关系确保只有一个规范副本的元数据。
如果您是工厂函数,则将调用底层工厂函数,并将 device 参数设置为’meta’。
将生成的元张量转换为假张量,计算张量的输出设备应该是哪个(这通常很简单,但有时并非如此,例如,cpu 标量提升或设备转换操作。)
API:重要部分
非 PT2 用法(更多示例请查看 test/test_fake_tensor.py):
# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
# Do some operations on the fake tensors
fake_y = fake_x * 2
# Factory operations automatically get fakeified in the context manager
fake_z = torch.empty(20)
问:为什么你有真实张量作为输入?
答:在 PT2 的上下文中,这是因为你通常是在即时编译,所以对于你正在编译的图的所有输入,你已经有“真实”的输入,因为你在执行程序时进行编译。
PT2 预-AOTAutograd 用法(这很不常见,你可能不想这样做):
# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
... # do stuff with the fake args, if needed ...
detect_fake_mode 将搜索多个位置以尝试找到与生命周期相关的“假”张量模式。通常,它将从跟踪上下文中提取出来。
PT2 AOTAutograd 后使用:
# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
fake_mode = detect_fake_mode(example_inputs)
# But in general you don't have to turn it on
其他有用信息:
from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
... # fake mode is disabled here, you can do real tensor compute
在什么情况下你可能想要禁用假张量模式?通常你不需要这样做。我们找到一个有用的特例是,在假张量模式中实现常量传播:在这种情况下,即使在假张量模式下,我们也需要进行一些实际的张量计算。
import FakeTensorProp from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
详情
是否自动转换?最初,FakeTensorMode 不会在 FakeTensorMode 区域内尝试对真实张量进行计算时自动进行伪造。背后的动机是为了防止以下错误:
with FakeTensorMode():
real_tensor.t_()
这段代码应该做什么?如果我们实际上修改了真实张量的元数据,这可能会令人惊讶。但与此同时,也没有明显的创建 FakeTensor 的机会。因此,我们保守地决定抛出错误:“在 FakeTensorMode 中使用非 Fake Tensor 输入调用操作符尚不支持。请先将所有 Tensor 转换为 FakeTensor。”
实际上,这个错误相当令人烦恼。例如,假设你有一个真实的 nn.Module,并且你想通过它传递伪造的张量。你需要以某种方式伪造 nn.Module。这促使了 FakeCopyMode 的诞生。
最终,我们放弃了,并添加了自动的伪造功能。然而,在许多 FakeTensorMode 的使用中,这仍然不是默认启用的。
关于伪造张量的元数据变更:如果你有一个伪造的张量,并且你对其执行 t_()操作,伪造张量的元数据会发生变化。表面上看起来这是合理的,但有时你希望也将伪造张量作为 FX 节点上的元数据存储;修改伪造张量是错误的,因为这将使旧的元数据失效!
事实上,这里存在一个基本的矛盾,即伪造张量维护着关于张量的极其精确的元数据,包括对象标识符。如果 FX 图中的对象元数据随时间变化,实际上并没有任何方法来表示这种变化。大多数时候,我们进行的严肃 FX 分析都是在没有这种功能的函数化图上进行的,但偶尔你需要在非函数化图上进行分析。也许将伪造张量放入 meta['val']中是一个错误。
关于张量子类
模拟张量同时使用子类和模式张量子类模式,其中 FakeTensor.__torch_dispatch__启用了与模拟张量关联的 FakeTensorMode,然后重新分发(依赖于 FakeTensorMode 进行繁重的工作)。如果模拟张量操作收到它不认识的子类参数,它将返回 NotImplemented,给其他子类一个先运行的机会(希望简化为普通张量操作),然后再尝试一次。这可能导致无限循环。
每个操作符是如何实现的?
不幸的是,任何给定操作符可能实现的地点相当复杂。一些需要了解的重要情况:
当元素数量非常小的时候,张量子类支持有限的常量传播(这有助于处理我们立即在这样张量上调用 item()的一些情况。)
我们为某些操作实现了一些 fastpath,这些实现完全在 fake tensor 中进行,出于性能考虑。
如果您使用@custom_op 生成自定义 tensor,这些将直接将 impl_abstract 注册到 fake tensor。
Fake tensor 本身对设备转换操作有一些硬编码的特殊情况。
如果没有元实现也没有任何分解,我们将生成真实的零填充 tensor 并尝试直接运行操作以找出结果。如果操作尝试使用数据执行索引,这可能会导致段错误,因此我们默认不为此自定义操作启用此功能。
转换器是如何工作的?
由于假张量被用于对张量的精确属性非常敏感的情况,因此假张量在转换时非常小心,保留了叶节点性、requires_grad 属性、别名以及一大堆其他属性。大部分繁重的工作都在 MetaConverter 中完成。
性能特性
你可能会认为假张量运行得很快,因为它们不做任何张量计算。但在小张量尺寸下,我们实际上完全是受开销限制的,而且,假张量是用 Python 实现的,我们通常要做很多工作才能完成一个张量操作(因为它们是作为分解实现的)。所以实际上假张量在实践中相当慢,尤其是在涉及符号形状时。在假张量中,我们目前有两个重要的快速路径,这在实践中有很大的影响:
点对点操作不通过 PrimTorch 分解,而是我们手动编写了它们的传播规则。
如果可能的话,我们应该这样做。
假张量中的假张量?¶
有兴趣将假张量作为用户输入发送到 PT2 堆栈中,这意味着我们需要能够创建一个假张量的假张量。目前这并不真正支持,但也许并不太难实现。
与动态形状的交互
每个 FakeTensorMode 都包含一个 ShapeEnv,它跟踪所有符号形状信息。它们的生命周期通常是相互关联的:它们共同生存和死亡。
因为 FakeTensorMode 有一个 ShapeEnv(但元实现没有),依赖于数据的元函数和需要分配无后盾 SymInt 的情况存在于假张量中。假张量还负责缓存无后盾的 SymInts,例如,如果你在同一个假张量上两次调用 nonzero(),你会得到相同的符号大小。