PyTorch 自定义操作¶
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
PyTorch 提供了一个大型操作库,这些操作可以在张量上工作(例如 torch.add
, torch.sum
等)。然而,您可能希望将新的自定义操作引入 PyTorch,并使其与子系统 torch.compile
、autograd 和 torch.vmap
协同工作。为此,您必须通过 Python torch.library 文档或 C++ TORCH_LIBRARY
API 将自定义操作注册到 PyTorch 中。
从 Python 编写自定义操作符
请参阅自定义 Python 操作符。
如果您希望将 Python 函数作为 PyTorch 的不可见调用处理,特别是与@0#和@1#相关,您可能希望从 Python(而不是 C++)编写自定义操作符:
您有一个希望 PyTorch 将其视为不可见调用的 Python 函数。
你有一些用于 C++/CUDA 内核的 Python 绑定,并希望它们与 PyTorch 子系统(如
torch.compile
或torch.autograd
)协同工作你正在使用 Python(而不是仅限 C++ 的 AOTInductor 环境等)。
将自定义 C++ 和/或 CUDA 代码与 PyTorch 集成
请参阅自定义 C++ 和 CUDA 操作符。
您可能希望从 C++(而不是 Python)中创建一个自定义操作符,如果:
您有自定义的 C++和/或 CUDA 代码。
您计划使用此代码与
AOTInductor
进行 Python 无关的推理。
自定义操作符手册 ¶
对于本教程和本页面未涵盖的信息,请参阅《自定义运算符手册》(我们正在努力将信息迁移到我们的文档网站上)。我们建议您首先阅读上述教程之一,然后使用自定义运算符手册作为参考;它不是从头到尾阅读的内容。
我应该在什么时候创建自定义运算符?
如果您的操作可以用内置的 PyTorch 运算符表示,请将其编写为 Python 函数并调用,而不是创建自定义运算符。如果您正在调用 PyTorch 不理解的一些库(例如自定义 C/C++代码、自定义 CUDA 内核或 C/C++/CUDA 扩展的 Python 绑定),请使用运算符注册 API 创建自定义运算符。
为什么我要创建自定义运算符?
通过获取张量的数据指针并将其传递给 pybind 的内核,可以使用 C/C++/CUDA 内核。然而,这种方法与 PyTorch 子系统(如 autograd、torch.compile、vmap 等)不兼容。为了使操作与 PyTorch 子系统兼容,必须通过操作注册 API 进行注册。