torch.library ¬
torch.library 是一个用于扩展 PyTorch 核心库操作符的 API 集合。它包含用于测试自定义操作符、创建新的自定义操作符以及扩展使用 PyTorch 的 C++ 操作符注册 API 定义的操作符(例如 aten 操作符)的实用工具。
有关如何有效使用这些 API 的详细指南,请参阅 PyTorch 自定义操作符着陆页以获取更多详细信息。
测试自定义操作符 ¶
使用 torch.library.opcheck()
测试自定义操作符对于 Python torch.library 和/或 C++ TORCH_LIBRARY API 的错误使用。如果您的操作符支持训练,请使用 torch.autograd.gradcheck()
测试梯度是否在数学上是正确的。
- torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True, atol=None, rtol=None)[source][source]¶
给定一个操作符和一些示例参数,测试操作符是否正确注册。
即使用 torch.library/TORCH_LIBRARY API 创建自定义操作时,您指定了自定义操作的元数据(例如可变性信息),这些 API 要求您传递给它们的函数满足某些属性(例如在假/元/抽象内核中不允许访问数据指针)
opcheck
测试这些元数据和属性。具体来说,我们测试以下内容:
test_schema:如果模式与操作符的实现匹配。例如:如果模式指定 Tensor 被修改,则我们检查实现修改了 Tensor。如果模式指定我们返回一个新的 Tensor,则我们检查实现返回了一个新的 Tensor(而不是现有的一个或现有的一个视图)。
test_autograd_registration:如果操作符支持训练(autograd):我们检查其 autograd 公式是否通过 torch.library.register_autograd 或手动注册到一或多个 DispatchKey::Autograd 键进行注册。任何其他基于 DispatchKey 的注册可能导致未定义的行为。
测试_faketensor:如果算子具有 FakeTensor 内核(并且是正确的)。FakeTensor 内核对于算子使用 PyTorch 编译 API(torch.compile/export/FX)是必要的(但不是充分的)。我们检查算子是否已注册了 FakeTensor 内核(有时也称为元内核)并且它是正确的。此测试将算子在实际张量上运行的结果与在 FakeTensors 上运行的结果进行比较,并检查它们具有相同的 Tensor 元数据(大小/步长/数据类型/设备等)。
测试_aot_dispatch_dynamic:如果算子与 PyTorch 编译 API(torch.compile/export/FX)具有正确的行为。这检查在 eager-mode PyTorch 和 torch.compile 下输出(以及如果适用则梯度)是否相同。此测试是
test_faketensor
的超集,是一个端到端测试;它还测试算子支持函数化,以及反向传播(如果存在)也支持 FakeTensor 和函数化。
为了获得最佳结果,请多次调用
opcheck
并使用一组代表性的输入。如果您的操作符支持自动微分,请使用opcheck
并使用带有requires_grad = True
的输入;如果您的操作符支持多个设备(例如 CPU 和 CUDA),请使用opcheck
并在所有支持的设备上使用输入。- 参数:
op (Union[OpOverload, OpOverloadPacket, CustomOpDef]) – 操作符。必须是带有
torch.library.custom_op()
装饰的函数,或者在 torch.ops.* 中找到的 OpOverload/OpOverloadPacket(例如 torch.ops.aten.sin, torch.ops.mylib.foo)args (tuple[Any, ...]) – 操作符的参数
kwargs (Optional[dict[str, Any]]) – 操作符的可选关键字参数
test_utils (Union[str, Sequence[str]]) – 要运行的测试。默认:所有测试。示例:(“test_schema”, “test_faketensor”)
raise_exception (bool) – 是否在首次错误时引发异常。如果为 False,则返回包含每个测试是否通过的信息的字典。
rtol (Optional[float]) – 浮点数比较的相对容差。如果指定,则必须也指定
atol
。如果省略,则根据dtype
选择默认值(见torch.testing.assert_close()
中的表格)。atol (Optional[float]) – 浮点数比较的绝对容差。如果指定,则必须也指定
rtol
。如果省略,则根据dtype
选择默认值(见torch.testing.assert_close()
中的表格)。
- 返回类型:
dict[str, str]
警告
opcheck 和
torch.autograd.gradcheck()
测试不同的事情;opcheck 测试您对 torch.library API 的使用是否正确,而torch.autograd.gradcheck()
测试您的 autograd 公式是否在数学上是正确的。使用两者来测试支持梯度计算的自定义操作。示例
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_mul(x: Tensor, y: float) -> Tensor: >>> x_np = x.numpy(force=True) >>> z_np = x_np * y >>> return torch.from_numpy(z_np).to(x.device) >>> >>> @numpy_mul.register_fake >>> def _(x, y): >>> return torch.empty_like(x) >>> >>> def setup_context(ctx, inputs, output): >>> y, = inputs >>> ctx.y = y >>> >>> def backward(ctx, grad): >>> return grad * ctx.y, None >>> >>> numpy_mul.register_autograd(backward, setup_context=setup_context) >>> >>> sample_inputs = [ >>> (torch.randn(3), 3.14), >>> (torch.randn(2, 3, device='cuda'), 2.718), >>> (torch.randn(1, 10, requires_grad=True), 1.234), >>> (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18), >>> ] >>> >>> for args in sample_inputs: >>> torch.library.opcheck(numpy_mul, args)
在 Python 中创建新的自定义操作
使用 torch.library.custom_op()
创建新的自定义操作
- torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None)[source]¶
将函数封装成自定义操作符。
创建自定义操作符的原因可能包括:- 将第三方库或自定义内核封装,以便与 PyTorch 子系统(如 Autograd)协同工作。- 防止 torch.compile/export/FX 追踪查看您的函数。
此 API 用作函数的装饰器(请参阅示例)。提供的函数必须具有类型提示;这些提示是用于与 PyTorch 的各种子系统交互所必需的。
- 参数:
name (str) – 自定义操作的名称,形如 “{namespace}::{name}”,例如 “mylib::my_linear”。该名称用作 PyTorch 子系统中操作的不变标识符(例如 torch.export、FX 图)。为了避免名称冲突,请使用您的项目名称作为命名空间;例如,pytorch/fbgemm 中所有自定义操作都使用 “fbgemm” 作为命名空间。
mutates_args (Iterable[str] 或 "unknown") – 函数修改的参数名称。这必须准确无误,否则行为是未定义的。如果为 “unknown”,则悲观地假设操作的所有输入都被修改。
device_types (None | str | Sequence[str]) – 函数有效的设备类型。如果没有提供设备类型,则函数用作所有设备类型的默认实现。示例:“cpu”、“cuda”。在注册不接受张量的操作器的特定设备实现时,我们要求操作器具有 “device: torch.device” 参数。
schema (None | str) – 操作符的 schema 字符串。如果为 None(推荐),我们将从操作符的类型注解中推断出 schema。除非您有特殊原因不这样做,我们建议您让我们推断 schema。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。
- 返回类型:
Union[Callable[[Callable[[…], object]], CustomOpDef], CustomOpDef]
注意
我们建议不要传递
schema
参数,而是让我们从类型注解中推断它。自己编写 schema 容易出错。如果您对我们的类型注解的解释不符合您的期望,您可能希望提供自己的 schema。有关如何编写 schema 字符串的更多信息,请参阅此处。- 示例::
>>> import torch >>> from torch import Tensor >>> from torch.library import custom_op >>> import numpy as np >>> >>> @custom_op("mylib::numpy_sin", mutates_args=()) >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> x = torch.randn(3) >>> y = numpy_sin(x) >>> assert torch.allclose(y, x.sin()) >>> >>> # Example of a custom op that only works for one device type. >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu") >>> def numpy_sin_cpu(x: Tensor) -> Tensor: >>> x_np = x.numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np) >>> >>> x = torch.randn(3) >>> y = numpy_sin_cpu(x) >>> assert torch.allclose(y, x.sin()) >>> >>> # Example of a custom op that mutates an input >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu") >>> def numpy_sin_inplace(x: Tensor) -> None: >>> x_np = x.numpy() >>> np.sin(x_np, out=x_np) >>> >>> x = torch.randn(3) >>> expected = x.sin() >>> numpy_sin_inplace(x) >>> assert torch.allclose(x, expected) >>> >>> # Example of a factory function >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu") >>> def bar(device: torch.device) -> Tensor: >>> return torch.ones(3) >>> >>> bar("cpu")
- torch.library.triton_op(name, fn=None, /, *, mutates_args, schema=None)[source]¶
创建一个由 1+个 triton 内核支持的定制操作符。
这是一种更结构化的使用 triton 内核与 PyTorch 的方式。建议使用没有
torch.library
定制操作符包装器(如torch.library.custom_op()
,torch.library.triton_op()
)的 triton 内核,因为这更简单;只有当你想要创建一个像 PyTorch 内置操作符那样行为的操作符时,才使用torch.library.custom_op()
/torch.library.triton_op()
。例如,你可以使用torch.library
包装器 API 来定义当传入 tensor 子类或处于 TorchDispatchMode 下时 triton 内核的行为。当实现由 1+个 triton 内核组成时,使用
torch.library.triton_op()
而不是torch.library.custom_op()
。torch.library.custom_op()
将定制操作符视为不透明的(torch.compile()
和torch.export.export()
永远不会追踪到它们),但triton_op
使实现对这些子系统可见,允许它们优化 triton 内核(s)。注意
fn
必须仅包含 PyTorch 理解的运算符和 triton 内核的调用。在fn
内部调用的任何 triton 内核都必须被torch.library.wrap_triton()
的调用所包装。- 参数:
name (str) – 自定义运算符的名称,格式为“{namespace}::{name}”,例如“mylib::my_linear”。该名称用作运算符在 PyTorch 子系统(例如 torch.export、FX 图)中的稳定标识符。为了避免名称冲突,请使用您的项目名称作为命名空间;例如,pytorch/fbgemm 中的所有自定义运算符都使用“fbgemm”作为命名空间。
mutates_args (Iterable[str] or "unknown") – 函数修改的参数名称。这必须准确,否则行为是未定义的。如果为“unknown”,则悲观地假设运算符的所有输入都被修改。
schema (None | str) – 运算符的 schema 字符串。如果为 None(推荐),我们将从其类型注解中推断运算符的 schema。我们建议让我们推断 schema,除非您有特定的理由不这样做。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。
- 返回类型:
示例:
>>> import torch >>> from torch.library import triton_op, wrap_triton >>> >>> import triton >>> from triton import language as tl >>> >>> @triton.jit >>> def add_kernel( >>> in_ptr0, >>> in_ptr1, >>> out_ptr, >>> n_elements, >>> BLOCK_SIZE: "tl.constexpr", >>> ): >>> pid = tl.program_id(axis=0) >>> block_start = pid * BLOCK_SIZE >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) >>> mask = offsets < n_elements >>> x = tl.load(in_ptr0 + offsets, mask=mask) >>> y = tl.load(in_ptr1 + offsets, mask=mask) >>> output = x + y >>> tl.store(out_ptr + offsets, output, mask=mask) >>> >>> @triton_op("mylib::add", mutates_args={}) >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: >>> output = torch.empty_like(x) >>> n_elements = output.numel() >>> >>> def grid(meta): >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) >>> >>> # NB: we need to wrap the triton kernel in a call to wrap_triton >>> wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16) >>> return output >>> >>> @torch.compile >>> def f(x, y): >>> return add(x, y) >>> >>> x = torch.randn(3, device="cuda") >>> y = torch.randn(3, device="cuda") >>> >>> z = f(x, y) >>> assert torch.allclose(z, x + y)
- torch.library.wrap_triton(triton_kernel, /)[源代码] ¶
允许通过 make_fx 或非严格
torch.export
将 triton 内核捕获到图中。这些技术执行基于 Dispatcher 的跟踪(通过
__torch_dispatch__
),无法看到对原始 triton 内核的调用。wrap_triton
API 将 triton 内核包装成可调用的形式,实际上可以将其跟踪到图中。请与
torch.library.triton_op()
一起使用此 API。示例
>>> import torch >>> import triton >>> from triton import language as tl >>> from torch.fx.experimental.proxy_tensor import make_fx >>> from torch.library import wrap_triton >>> >>> @triton.jit >>> def add_kernel( >>> in_ptr0, >>> in_ptr1, >>> out_ptr, >>> n_elements, >>> BLOCK_SIZE: "tl.constexpr", >>> ): >>> pid = tl.program_id(axis=0) >>> block_start = pid * BLOCK_SIZE >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) >>> mask = offsets < n_elements >>> x = tl.load(in_ptr0 + offsets, mask=mask) >>> y = tl.load(in_ptr1 + offsets, mask=mask) >>> output = x + y >>> tl.store(out_ptr + offsets, output, mask=mask) >>> >>> def add(x, y): >>> output = torch.empty_like(x) >>> n_elements = output.numel() >>> >>> def grid_fn(meta): >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) >>> >>> wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16) >>> return output >>> >>> x = torch.randn(3, device="cuda") >>> y = torch.randn(3, device="cuda") >>> gm = make_fx(add)(x, y) >>> print(gm.code) >>> # def forward(self, x_1, y_1): >>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False) >>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation( >>> # kernel_idx = 0, constant_args_idx = 0, >>> # grid = [(1, 1, 1)], kwargs = { >>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like, >>> # 'n_elements': 3, 'BLOCK_SIZE': 16 >>> # }) >>> # return empty_like
- 返回类型:
扩展自定义操作(由 Python 或 C++创建)¶
使用 register.*方法,例如 torch.library.register_kernel()
和 torch.library.register_fake()
,为任何操作添加实现(它们可能使用 torch.library.custom_op()
或通过 PyTorch 的 C++操作注册 API 创建)。
- torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)[source][source]¶
为此运营商注册设备类型的实现。
一些有效的设备类型有:“cpu”,“cuda”,“xla”,“mps”,“ipu”,“xpu”。此 API 可以用作装饰器。
- 参数:
op (str | OpOverload) – 要注册实现的运算符。
device_types (None | str | Sequence[str]) – 要注册实现的设备类型。如果为 None,我们将注册到所有设备类型 – 请仅在您的实现真正与设备类型无关时使用此选项。
将作为给定设备类型实现注册的函数(Callable)–
如果提供,则此注册的生命周期(Optional[Library])
- 示例::
>>> import torch >>> from torch import Tensor >>> from torch.library import custom_op >>> import numpy as np >>> >>> # Create a custom op that works on cpu >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np) >>> >>> # Add implementations for the cuda device >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda") >>> def _(x): >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> x_cpu = torch.randn(3) >>> x_cuda = x_cpu.cuda() >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
- torch.library.register_autocast(op, device_type, cast_inputs, /, *, lib=None)[source][source]¶
为此自定义操作注册自动派发规则。
有效的设备类型包括:“cpu”和“cuda”。
- 参数:
op (str | OpOverload) – 要注册自动派发规则的运算符。
device_type (str) – 要使用的设备类型。‘cuda’或‘cpu’。类型与
torch.device
的类型属性相同。因此,您可以使用 Tensor.device.type 获取张量的设备类型。cast_inputs (
torch.dtype
) – 当自定义操作在自动类型转换启用的区域内运行时,将传入的浮点张量转换为目标数据类型(非浮点张量不受影响),然后以自动类型转换禁用的方式执行自定义操作。lib (Optional[Library]) – 如果提供,则此注册的生命周期
- 示例::
>>> import torch >>> from torch import Tensor >>> from torch.library import custom_op >>> >>> # Create a custom op that works on cuda >>> @torch.library.custom_op("mylib::my_sin", mutates_args=()) >>> def my_sin(x: Tensor) -> Tensor: >>> return torch.sin(x) >>> >>> # Register autocast dispatch rule for the cuda device >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16) >>> >>> x = torch.randn(3, dtype=torch.float32, device="cuda") >>> with torch.autocast("cuda", dtype=torch.float16): >>> y = torch.ops.mylib.my_sin(x) >>> assert y.dtype == torch.float16
- torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)[source][source]¶
为此自定义操作注册一个反向公式。
为了使操作符能够与 autograd 一起工作,您需要注册一个反向公式:1. 您必须通过提供一个“backward”函数来告诉我们如何在反向传递期间计算梯度。2. 如果您需要任何前向值来计算梯度,您可以使用 setup_context 来保存这些值用于反向计算。
backward
在反向传递期间运行。它接受(ctx, *grads)
: -grads
是一个或多个梯度。梯度的数量与操作符的输出数量相匹配。ctx
对象与torch.autograd.Function
使用的 ctx 对象相同。backward_fn
的语义与torch.autograd.Function.backward()
相同。setup_context(ctx, inputs, output)
在前向传递期间运行。请通过torch.autograd.function.FunctionCtx.save_for_backward()
或将它们作为ctx
的属性赋值,将需要用于反向的量保存到ctx
对象上。如果您的自定义操作只有 kwarg-only 参数,我们期望setup_context
的签名是setup_context(ctx, inputs, keyword_only_inputs, output)
。setup_context_fn
和backward_fn
必须可追踪。也就是说,它们不能直接访问torch.Tensor.data_ptr()
,并且不能依赖于或修改全局状态。如果您需要不可追踪的回溯,您可以将它作为一个独立的 custom_op 在backward_fn
内部调用。如果您需要在不同的设备上使用不同的 autograd 行为,那么我们建议为每个需要不同行为的设备创建两个不同的自定义算子,并在运行时在这两个算子之间切换。
示例
>>> import torch >>> import numpy as np >>> from torch import Tensor >>> >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> def setup_context(ctx, inputs, output) -> Tensor: >>> x, = inputs >>> ctx.save_for_backward(x) >>> >>> def backward(ctx, grad): >>> x, = ctx.saved_tensors >>> return grad * x.cos() >>> >>> torch.library.register_autograd( ... "mylib::numpy_sin", backward, setup_context=setup_context ... ) >>> >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_sin(x) >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) >>> assert torch.allclose(grad_x, x.cos()) >>> >>> # Example with a keyword-only arg >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = x_np * val >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: >>> ctx.val = keyword_only_inputs["val"] >>> >>> def backward(ctx, grad): >>> return grad * ctx.val >>> >>> torch.library.register_autograd( ... "mylib::numpy_mul", backward, setup_context=setup_context ... ) >>> >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_mul(x, val=3.14) >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
- torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1)[source][source]¶
为此算子注册 FakeTensor 实现(“fake impl”)。
也被称为“元内核”、“抽象实现”。
“FakeTensor 实现”指定了此算子对不携带数据(“FakeTensor”)的 Tensors 的行为。给定一些具有某些属性(大小/步长/存储偏移/设备)的输入 Tensors,它指定了输出 Tensors 的属性。
FakeTensor 实现与算子的签名相同。它对 FakeTensors 和元张量都运行。要编写 FakeTensor 实现,假设算子的所有 Tensor 输入都是常规 CPU/CUDA/元张量,但它们没有存储,并且你正在尝试返回常规 CPU/CUDA/元张量作为输出。FakeTensor 实现必须仅由 PyTorch 操作组成(并且不能直接访问任何输入或中间张量的存储或数据)。
此 API 可以用作装饰器(请参阅示例)。
关于自定义操作的详细指南,请参阅 https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
示例
>>> import torch >>> import numpy as np >>> from torch import Tensor >>> >>> # Example 1: an operator without data-dependent output shape >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=()) >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: >>> raise NotImplementedError("Implementation goes here") >>> >>> @torch.library.register_fake("mylib::custom_linear") >>> def _(x, weight, bias): >>> assert x.dim() == 2 >>> assert weight.dim() == 2 >>> assert bias.dim() == 1 >>> assert x.shape[1] == weight.shape[1] >>> assert weight.shape[0] == bias.shape[0] >>> assert x.device == weight.device >>> >>> return (x @ weight.t()) + bias >>> >>> with torch._subclasses.fake_tensor.FakeTensorMode(): >>> x = torch.randn(2, 3) >>> w = torch.randn(3, 3) >>> b = torch.randn(3) >>> y = torch.ops.mylib.custom_linear(x, w, b) >>> >>> assert y.shape == (2, 3) >>> >>> # Example 2: an operator with data-dependent output shape >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=()) >>> def custom_nonzero(x: Tensor) -> Tensor: >>> x_np = x.numpy(force=True) >>> res = np.stack(np.nonzero(x_np), axis=1) >>> return torch.tensor(res, device=x.device) >>> >>> @torch.library.register_fake("mylib::custom_nonzero") >>> def _(x): >>> # Number of nonzero-elements is data-dependent. >>> # Since we cannot peek at the data in an fake impl, >>> # we use the ctx object to construct a new symint that >>> # represents the data-dependent size. >>> ctx = torch.library.get_ctx() >>> nnz = ctx.new_dynamic_size() >>> shape = [nnz, x.dim()] >>> result = x.new_empty(shape, dtype=torch.int64) >>> return result >>> >>> from torch.fx.experimental.proxy_tensor import make_fx >>> >>> x = torch.tensor([0, 1, 2, 3, 4, 0]) >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x) >>> trace.print_readable() >>> >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
- torch.library.register_vmap(op, func=None, /, *, lib=None)[source][source]¶
注册一个 vmap 实现以支持对此自定义操作使用
torch.vmap()
。此 API 可以用作装饰器(请参阅示例)。
为了操作员能够使用
torch.vmap()
,您可能需要在以下签名中注册一个 vmap 实现:vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)
,其中
*args
和**kwargs
是op
的参数和 kwargs。我们不支持仅 kwargs 的 Tensor 参数。它指定了如何计算给定具有额外维度(由
in_dims
指定)的输入的op
的批处理版本。对于
args
中的每个参数,in_dims
对应一个Optional[int]
。如果参数不是 Tensor 或者参数没有被 vmapped,则是None
,否则,它是一个整数,指定了 Tensor 被 vmapped 的维度。info
是一组可能有助于的附加元数据:info.batch_size
指定了正在 vmapped 的维度的尺寸,而info.randomness
是传递给torch.vmap()
的randomness
选项。函数
func
的返回值是一个(output, out_dims)
的元组。类似于in_dims
,out_dims
应该与output
具有相同的结构,并且每个输出包含一个out_dim
,指定输出是否具有 vmapped 维度以及它在其中的索引。示例
>>> import torch >>> import numpy as np >>> from torch import Tensor >>> from typing import Tuple >>> >>> def to_numpy(tensor): >>> return tensor.cpu().numpy() >>> >>> lib = torch.library.Library("mylib", "FRAGMENT") >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: >>> x_np = to_numpy(x) >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) >>> return torch.tensor(x_np ** 3, device=x.device), dx >>> >>> def numpy_cube_vmap(info, in_dims, x): >>> result = numpy_cube(x) >>> return result, (in_dims[0], in_dims[0]) >>> >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap) >>> >>> x = torch.randn(3) >>> torch.vmap(numpy_cube)(x) >>> >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) >>> >>> @torch.library.register_vmap("mylib::numpy_mul") >>> def numpy_mul_vmap(info, in_dims, x, y): >>> x_bdim, y_bdim = in_dims >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) >>> result = x * y >>> result = result.movedim(-1, 0) >>> return result, 0 >>> >>> >>> x = torch.randn(3) >>> y = torch.randn(3) >>> torch.vmap(numpy_mul)(x, y)
注意
vmap 函数应旨在保留整个自定义算子的语义。也就是说,
grad(vmap(op))
应该可以被grad(map(op))
替换。如果您的自定义算子在反向传播中具有任何自定义行为,请记住这一点。
- torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)[source][source]¶
此 API 在 PyTorch 2.4 中被重命名为
torch.library.register_fake()
。请使用它代替。
- torch.library.get_ctx()[source][source]¶
get_ctx() 返回当前 AbstractImplCtx 对象。
在模拟实现内部调用
get_ctx()
才有效(更多使用详情请参阅torch.library.register_fake()
)。- 返回类型:
模拟实现上下文
- torch.library.register_torch_dispatch(op, torch_dispatch_class, func=None, /, *, lib=None)[source][source]¶
为给定操作和
torch_dispatch_class
注册 torch_dispatch 规则。这允许在不修改
torch_dispatch_class
或操作员直接的情况下,开放注册以指定操作员与torch_dispatch_class
之间的行为。torch_dispatch_class
可以是具有__torch_dispatch__
的 Tensor 子类或 TorchDispatchMode。如果是 Tensor 子类,我们期望
func
具有以下签名:(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
如果是 TorchDispatchMode,我们期望
func
具有以下签名:(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
args
和kwargs
将与__torch_dispatch__
中的方式相同进行规范化(参见 __torch_dispatch__ 调用约定)。示例
>>> import torch >>> >>> @torch.library.custom_op("mylib::foo", mutates_args={}) >>> def foo(x: torch.Tensor) -> torch.Tensor: >>> return x.clone() >>> >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode): >>> def __torch_dispatch__(self, func, types, args=(), kwargs=None): >>> return func(*args, **kwargs) >>> >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode) >>> def _(mode, func, types, args, kwargs): >>> x, = args >>> return x + 1 >>> >>> x = torch.randn(3) >>> y = foo(x) >>> assert torch.allclose(y, x) >>> >>> with MyMode(): >>> y = foo(x) >>> assert torch.allclose(y, x + 1)
- torch.library.infer_schema(prototype_function, /, *, mutates_args, op_name=None)[source]¶
解析给定函数的方案,其中包含类型提示。方案是从函数的类型提示中推断出来的,可以用来定义新的操作符。
我们做出以下假设:
没有任何输出别名与输入或彼此相同。
- 不指定库的字符串类型注解“device, dtype, Tensor, types”被认为是 torch.*。同样,不指定库的字符串类型注解“Optional, List, Sequence, Union”被认为是 typing.*。没有指定库的字符串类型注解“Optional, List, Sequence, Union”被认为是 typing.*。没有指定库的字符串类型注解“Optional, List, Sequence, Union”被认为是 typing.*。
- 仅对列出的
mutates_args
参数进行修改。如果mutates_args
是“未知”,则假定操作符的所有输入都被修改。
调用者(例如自定义操作 API)负责检查这些假设。
- 参数:
prototype_function(可调用)- 从其类型注解中推断模式的函数。
op_name (Optional[str]) – 模式中的操作员名称。如果
name
为 None,则名称不包括在推断的方案中。请注意,torch.library.Library.define
的输入方案需要操作员名称。mutates_args ("unknown" | Iterable[str]) – 函数中发生变化的参数。
- 返回:
推断的方案。
- 返回类型:
示例
>>> def foo_impl(x: torch.Tensor) -> torch.Tensor: >>> return x.sin() >>> >>> infer_schema(foo_impl, op_name="foo", mutates_args={}) foo(Tensor x) -> Tensor >>> >>> infer_schema(foo_impl, mutates_args={}) (Tensor x) -> Tensor
- class torch._library.custom_ops.CustomOpDef(namespace, name, schema, fn)[source][source]¶
CustomOpDef 是一个将函数包装成自定义操作的包装器。
它为注册此自定义操作的其他行为提供了各种方法。
不应直接实例化 CustomOpDef;相反,请使用
torch.library.custom_op()
API。- set_kernel_enabled(device_type, enabled=True)[source][source]¶
禁用或重新启用此自定义操作符已注册的内核。
如果内核已禁用/启用,则此操作不执行任何操作。
注意
如果首先禁用内核然后重新注册,则内核将保持禁用状态,直到再次启用。
- 参数:
device_type (str) – 要禁用/启用内核的设备类型。
disable (bool) – 是否禁用或启用内核。
示例
>>> inp = torch.randn(1) >>> >>> # define custom op `f`. >>> @custom_op("mylib::f", mutates_args=()) >>> def f(x: Tensor) -> Tensor: >>> return torch.zeros(1) >>> >>> print(f(inp)) # tensor([0.]), default kernel >>> >>> @f.register_kernel("cpu") >>> def _(x): >>> return torch.ones(1) >>> >>> print(f(inp)) # tensor([1.]), CPU kernel >>> >>> # temporarily disable the CPU kernel >>> with f.set_kernel_enabled("cpu", enabled = False): >>> print(f(inp)) # tensor([0.]) with CPU kernel disabled
低级 APIs
以下 API 是 PyTorch 的 C++底层操作符注册 API 的直接绑定。
警告
低级操作符注册 API 和 PyTorch 调度器是 PyTorch 的一个复杂概念。我们建议您尽可能使用上述高级 API(不需要 torch.library.Library 对象)。这篇博客文章(http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/)是了解 PyTorch 调度器的一个好起点。
一篇教程,通过一些示例向您展示如何使用此 API,可在 Google Colab 上找到。
- class torch.library.Library(ns, kind, dispatch_key='')
一个用于创建库的类,这些库可以用来注册新的操作符或覆盖现有库中的操作符。用户可以选择传递一个调度键名,如果他们只想注册与特定调度键相对应的内核。
要覆盖现有库(名称为 ns)中的操作符,将类型设置为“IMPL”。要创建一个新库(名称为 ns)以注册新操作符,将类型设置为“DEF”。要创建一个可能存在的库片段以注册操作符(并绕过给定命名空间只有一个库的限制),将类型设置为“FRAGMENT”。
- 参数:
ns – 库名称
类型 – “DEF”,“IMPL”(默认:“IMPL”),“FRAGMENT”
dispatch_key – PyTorch 调度键(默认:“”)
- define(schema, alias_analysis='', *, tags=())[source][source]¶
在 ns 命名空间中定义一个新的操作符及其语义。
- 参数:
schema – 定义新操作符的函数模式。
别名分析(可选)- 表示操作符参数的别名属性是否可以从模式中推断出来(默认行为)或不能推断出来(“CONSERVATIVE”)。
标签(Tag | Sequence[Tag])- 应用到此操作符的一个或多个 torch.Tag。对操作符进行标记会改变操作符在各种 PyTorch 子系统中的行为;请在应用 torch.Tag 之前仔细阅读相关文档。
- 返回:
从模式中推断出的操作符名称。
- 示例::
>>> my_lib = Library("mylib", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor")
- fallback(fn, dispatch_key='', *, with_keyset=False)[source][source]
将函数实现注册为给定键的回退。
此函数仅适用于具有全局命名空间(“_”)的库。
- 参数:
fn – 用作给定分发键回退的函数或
fallthrough_kernel()
以注册跳转。dispatch_key – 输入函数应注册的分发键。默认情况下,它使用库创建时的分发键。
with_keyset – 控制当前调度器调用是否将 keyset 作为第一个参数传递给
fn
。这应该用于为 redispatch 调用创建适当的 keyset。
- 示例::
>>> my_lib = Library("_", "IMPL") >>> def fallback_kernel(op, *args, **kwargs): >>> # Handle all autocast ops generically >>> # ... >>> my_lib.fallback(fallback_kernel, "Autocast")
- impl(op_name, fn, dispatch_key='', *, with_keyset=False)[source][source]¶
注册库中定义的操作符的功能实现。
- 参数:
op_name – 操作符名称(包括重载)或 OpOverload 对象。
fn – 函数,是输入分发键的运算符实现,或
fallthrough_kernel()
用于注册跳转。dispatch_key – 输入函数应注册的分发键。默认情况下,它使用库创建时使用的分发键。
with_keyset – 控制当前分发器调用键集是否应作为第一个参数传递给
fn
。这应用于创建适当的键集以进行重新分发调用。
- 示例::
>>> my_lib = Library("aten", "IMPL") >>> def div_cpu(self, other): >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
- torch.library.fallthrough_kernel()[source][source]
一个用于传递给
Library.impl
以注册跳转的虚拟函数。
- torch.library.define(qualname, schema, *, lib=None, tags=())[source][source]¶
- torch.library.define(lib, schema, alias_analysis='')
定义了一个新的操作符。
在 PyTorch 中,定义一个操作符(简称“算子”)是一个两步过程:- 我们需要定义算子(通过提供算子名称和模式)- 我们需要实现算子与各种 PyTorch 子系统(如 CPU/CUDA 张量、Autograd 等)交互的行为。
此入口定义了自定义算子(第一步),然后您必须通过调用各种
impl_*
API 执行第二步,如torch.library.impl()
或torch.library.register_fake()
。- 参数:
qualname(字符串)- 算子的限定名称。应该是一个类似于“命名空间::名称”的字符串,例如“aten::sin”。PyTorch 中的算子需要一个命名空间以避免名称冲突;给定的算子只能创建一次。如果您正在编写 Python 库,我们建议命名空间为您的顶级模块的名称。
schema(字符串)- 算子的模式。例如,对于接受一个张量并返回一个张量的算子,“(Tensor x) -> Tensor”。它不包含算子名称(通过
qualname
传入)。lib(可选[库])- 如果提供,此操作符的生存期将与库对象的生存期绑定。
tags(Tag | Sequence[Tag])- 应用到此操作符的一个或多个 torch.Tag。对操作符进行标记会改变操作符在各种 PyTorch 子系统中的行为;请在应用 torch.Tag 之前仔细阅读相关文档。
- 示例::
>>> import torch >>> import numpy as np >>> >>> # Define the operator >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the operator >>> @torch.library.impl("mylib::sin", "cpu") >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> # Call the new operator from torch.ops. >>> x = torch.randn(3) >>> y = torch.ops.mylib.sin(x) >>> assert torch.allclose(y, x.sin())
- torch.library.impl(lib, name, dispatch_key='')[source][source]
- torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Literal[None] = None, *, lib: Optional[Library] = None) Callable[[Callable[..., object]], None]
- torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Callable[..., object], *, lib: Optional[Library] = None) None
- torch.library.impl(lib: Library, name: str, dispatch_key: str = '') Callable[[Callable[_P, _T]], Callable[_P, _T]]
注册此算子的设备类型实现。
您可以为
types
传递“default”,以将此实现注册为所有设备类型的默认实现。请仅在实现确实支持所有设备类型时使用此选项;例如,如果它是内置 PyTorch 算子的组合,则这是正确的。此 API 可以用作装饰器。您可以使用嵌套装饰器,只要它们返回一个函数并且放在此 API 内部即可(参见示例 2)。
一些有效的类型有:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。
- 参数:
qualname(字符串)- 应该是一个类似于“namespace::operator_name”的字符串。
types(字符串 | 字符串序列)- 将 impl 注册到的设备类型。
lib(可选[库])- 如果提供,则此注册的生命周期将与库对象的生命周期绑定。
示例
>>> import torch >>> import numpy as np >>> # Example 1: Register function. >>> # Define the operator >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the cpu device >>> @torch.library.impl("mylib::mysin", "cpu") >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> x = torch.randn(3) >>> y = torch.ops.mylib.mysin(x) >>> assert torch.allclose(y, x.sin()) >>> >>> # Example 2: Register function with decorator. >>> def custom_decorator(func): >>> def wrapper(*args, **kwargs): >>> return func(*args, **kwargs) + 1 >>> return wrapper >>> >>> # Define the operator >>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the operator >>> @torch.library.impl("mylib::sin_plus_one", "cpu") >>> @custom_decorator >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> # Call the new operator from torch.ops. >>> x = torch.randn(3) >>> >>> y1 = torch.ops.mylib.sin_plus_one(x) >>> y2 = torch.sin(x) + 1 >>> assert torch.allclose(y1, y2)