• 文档 >
  • 模块代码 >
  • torch >
  • torch._库.custom_ops
快捷键

torch._library.custom_ops 的源代码

# mypy: 允许未类型化定义
导入 集合
导入 检查
导入 记录日志
导入 弱引用
from collections.abc 导入 迭代器, 序列
from contextlib 导入 contextmanager
from 打字 导入 任何, 可调用, 直接, 可选, 过载, 联合

导入 火炬
from 火炬 导入 _C, 操作符, 张量
from torch.types 导入 _dtype
from torch.utils._exposed_in 导入 exposed_in

from . 导入 自动微分, 工具


device_types_t = 可选[联盟[字符串, 序列[字符串]]]
日志 = 记录.获取日志记录器(__name__)


@overload
定义 自定义操作(
    名称: 字符串,
    函数: 直接[] = ,
    /,
    *,
    修改参数: 联盟[字符串, 迭代器[字符串]],
    设备类型: 设备类型 = ,
    架构: 可选[字符串] = ,
) -> 可调用[[可调用[..., 对象]], "自定义操作定义"]:
    ...


@overload
定义 自定义操作(
    名称: 字符串,
    函数: 可调用[..., 对象],
    /,
    *,
    修改参数: 联盟[字符串, 迭代器[字符串]],
    设备类型: 设备类型_t = ,
    架构: 可选[字符串] = ,
) -> "自定义操作定义":
    ...


@exposed_in("torch 库")
定义 自定义操作(
    名称: 字符串,
    函数: 可选[可调用] = ,
    /,
    *,
    修改参数: 联盟[字符串, 迭代器[字符串]],
    设备类型: 设备类型_t = ,
    架构: 可选[字符串] = ,
) -> 联盟[可调用[[可调用[..., 对象]], "自定义操作定义"], "自定义操作定义"]:
    将函数包装成自定义操作符。

创建自定义操作符的原因可能包括:
- 将第三方库或自定义内核包装以与 PyTorch 的 Autograd 等子系统协同工作。
- 防止 torch.compile/export/FX tracing 窥探您的函数内部。
- 防止 torch.compile/export/FX tracing 窥探您的函数内部。

该 API 用作函数的装饰器(请参阅示例)。
提供的函数必须有类型提示;这些提示是用于与 PyTorch 的各种子系统交互所必需的。
的。

参数:
名称(str):自定义操作符的名称,格式为"{namespace}::{name}"。
例如 "mylib::my_linear"。该名称用作操作的稳定标识符
在 PyTorch 子系统中(例如 torch.export、FX 图)
为避免名称冲突,请使用您的项目名称作为命名空间;
例如,pytorch/fbgemm 中所有自定义操作都使用 "fbgemm" 作为命名空间。
mutates_args (Iterable[str] 或 "unknown"): 函数修改的参数名称。
            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
device_types (None | str | Sequence[str]): 函数的设备类型(None | 字符串 | 字符串序列)。
适用于。如果没有提供设备类型,则该函数
是所有设备类型的默认实现。
示例:"cpu","cuda"。
当为接受无张量的算子注册特定设备的实现时,
我们要求操作员必须有一个 "device: torch.device" 参数。
schema (None | str): 操作员的模式字符串。如果为 None
(推荐)我们将从操作员类型推断模式
注释中推断操作员的模式。除非你有特殊需求,我们建议您让我们推断模式
有特定的理由不这么做。
例如:"(Tensor x, int y) -> (Tensor, Tensor)"。

.. 注意::
我们建议不要传递 `schema` 参数,而是让我们从类型注解中推断。自己编写模式容易出错。
它是错误倾向的,编写自己的模式。
您可能希望提供自己的模式,如果我们的解释
类型注解不是你想要的。
关于如何编写模式字符串的更多信息,请参阅
这里 

示例:
>>> 导入 torch
>>> 从 torch 导入 Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
...
        >>> @custom_op("mylib::numpy_sin", mutates_args=())
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
...
        >>> x = torch.randn(3)
        >>> y = numpy_sin(x)
        >>> assert torch.allclose(y, x.sin())
...
>>> # 示例:仅适用于一种设备类型的自定义操作。
        >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
        >>> def numpy_sin_cpu(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
...
        >>> x = torch.randn(3)
        >>> y = numpy_sin_cpu(x)
        >>> assert torch.allclose(y, x.sin())
...
>>> # 示例:自定义操作符,修改输入
        >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
        >>> def numpy_sin_inplace(x: Tensor) -> None:
        >>>     x_np = x.numpy()
>>> np.sin(x_np, out=x_np)
...
        >>> x = torch.randn(3)
        >>> expected = x.sin()
        >>> numpy_sin_inplace(x)
        >>> assert torch.allclose(x, expected)
...
>>> # 示例工厂函数
        >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
        >>> def bar(device: torch.device) -> Tensor:
        >>>     return torch.ones(3)
...
        >>> bar("cpu")

"文档"

    定义 内部(函数: 可调用[..., 对象]) -> 自定义运算定义:
        导入 火炬

        如果 架构  :
            schema_str = 火把.图书馆.infer_schema(函数, mutates_args=修改参数)
        否则:
            schema_str = 架构

        命名空间, 操作名称 = 名称.分割(双冒号)
        结果 = 自定义运算定义(命名空间, 操作名称, schema_str, 函数)
        如果 架构   :
            检查 schema 的别名注解是否与`mutates_args`匹配。
            预期 = 设置()
            for 参数  结果._opoverload._模式.参数:
                如果 arg.别名信息    并且 arg.别名信息.是否写入:
                    预期.添加(arg.名称)
            如果 预期 != 设置(修改参数):
                raise ValueError(
                    f尝试创建一个具有 `修改参数=` 的自定义操作{修改参数}"请提供需要翻译的文本,以便我进行翻译。"
                    f"并且 `schema="{架构}该模式表明操作会修改{预期}"
                    f"这与我们提供的 `mutates_args` 中的不同。"
                    f"请保持一致性。"
                )
        结果.注册内核(设备类型)(函数)
        返回 结果

    如果 fn  :
        返回 内部
    返回 内部(函数)


[文档] 自定义运算定义: CustomOpDef 是一个将函数转换为自定义操作的包装器。 它有多种方法用于注册此处的附加行为 自定义操作 您不应直接实例化 CustomOpDef;相反,请使用 `torch.library.custom_op` API。 "文档" 定义 __init__(, 命名空间: 字符串, 名称: 字符串, 架构: 字符串, 函数: 可调用) -> : 用于与 PyTorch 分发器接口的字段 ._命名空间 = 命名空间 .名称 = 名称 ._架构 = 架构 ._初始化函数 = fn ._后端函数: 字典[联盟[字符串, ], 可调用] = {} ._抽象函数: 可选[可调用] = ._设置上下文函数: 可选[可调用] = ._反向函数: 可选[可调用] = ._torch_dispatch_fns: 字典[类型, 可调用] = {} ._vmap_fn: 可选[可调用] = ._autocast_cuda_dtype: 可选[数据类型] = ._autocast_cpu_dtype: 可选[数据类型] = ._lib = 获取允许覆盖的库(._命名空间, .名称) ._注册到分发器() ._禁用内核: 设置 = 设置() OP 定义[._qualname] = self @property 定义 _qualname() -> 字符串: 返回 f"{._namespace}::{.名称}" 定义 __repr__() -> 字符串: 返回 f"<自定义操作符定义("{._qualname})>"
[文档] @contextmanager 定义 set_kernel_enabled(, 设备类型: 字符串, 启用: 布尔 = ): "" 禁用或重新启用此自定义操作符已注册的内核。 如果内核已被禁用/启用,则此操作无效果。 注意: 如果内核先被禁用然后注册,则内核将保持禁用状态,直到再次启用。 参数: device_type (str): 要禁用/启用内核的设备类型。 disable (bool): 是否禁用或启用内核。 示例: >>> inp = torch.randn(1) ... >>> # 定义自定义操作 `f`。 >>> @custom_op("mylib::f", mutates_args=()) >>> def f(x: Tensor) -> Tensor: >>> return torch.zeros(1) ... >>> print(f(inp)) # tensor([0.]), 默认内核 ... >>> @f.register_kernel("cpu") >>> def _(x): >>> return torch.ones(1) ... >>> print(f(inp)) # 索引([1.]), CPU 内核 ... >>> # 暂时禁用 CPU 内核 >>> with f.set_kernel_enabled("cpu", enabled = False): >>> print(f(inp)) # 索引([0.]) 禁用 CPU 内核后 "文档" 行动 = 启用 如果 启用 否则 禁用 最初禁用 = 设备类型 .禁用内核 如果 设备类型 .后端函数: 日志.警告( 尝试执行%s内核,但未为该设备类型注册内核。%s但没有为该设备类型注册内核。, 动作, 设备类型, ) 如果 启用: 如果 原先已禁用: 日志.警告( 尝试禁用内核%s但它已经被禁用了。, 设备类型, ) 否则: ._已禁用内核.添加(设备类型) 否则: 启用内核 如果 最初是禁用的: 日志.警告( 尝试启用内核,%s但它已经被启用。, 设备类型, ) 否则: .禁用内核.删除(设备类型) 尝试: 产生 最后: 恢复原始状态 如果 最初禁用: .禁用内核.添加(设备类型) 否则: .禁用内核.丢弃(设备类型)
定义 注册内核( , 设备类型: 设备类型_t, 函数: 可选[可调用] = , / ) -> 可调用: 为此操作员注册设备类型的实现。 一些有效的设备类型包括:"cpu"、"cuda"、"xla"、"mps"、"ipu"、"xpu"。 此 API 可以用作装饰器。 参数: fn (Callable): 将注册为实现该功能的函数 给定的设备类型。 device_types (str | Sequence[str]): 注册实现所用的设备类型 示例: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) >>> 导入 torch >>> 从 torch 导入 Tensor >>> from torch.library import custom_op >>> import numpy as np ... >>> # 创建一个在 cpu 上工作的自定义操作 >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np) ... >>> # 添加对 cuda 设备的实现 >>> @numpy_sin.register_kernel("cuda") >>> def _(x): >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) ... >>> x_cpu = torch.randn(3) >>> x_cuda = x_cpu.cuda() >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin()) "文档" 定义 内部(函数): 如果 设备类型 或者 isinstance(设备类型, 字符串): 数据类型: 列表[联盟[字符串, ]] = [设备类型] 否则: 数据类型 = 列表(设备类型) for 设备类型 数据类型: 如果 设备类型 .后端函数: 定义 后端实现(*参数, **kwargs): 结果 = .后端函数[设备类型]*参数, **kwargs) 定义 获取模块(): fn = .后端函数[设备类型] 返回 检查.获取模块(函数) 工具.检查别名约束( .名称, 工具.迭代张量(参数, kwargs), 结果, 获取模块, ) 返回 结果 如果 设备类型 : ._lib.实现( .名称, 后端实现, "复合显式自动微分" ) 否则: ._lib.实现( .名称, 后端实现, _C._设备调度键(设备类型), ) # 将函数包装以在默认实现或设备特定实现之间进行选择 # 根据内核是否禁用,选择实现 @torch.禁用 dynamo 定义 wrapped_fn(*参数, **kwargs): 如果 设备类型 .禁用内核: 返回 .初始化函数(*参数, **kwargs) 否则: 返回 函数(*参数, **kwargs) .后端函数[设备类型] = wrapped_fn 返回 fn 如果 设备类型 并且 工具.有张量参数( .操作符重载._架构 ): 设备参数索引 = 工具.获取设备参数索引(._opoverload._模式) 如果 设备参数索引 : raise ValueError( 函数没有张量输入时,必须有一个 `device: torch.device` 参数 ) .注册后端选择分发器(设备参数索引) # 查看 NOTE: [支持装饰器和非装饰器用法] 如果 fn : 返回 内部 返回 内部(函数) 定义 注册伪造(, 函数: 可调用, /) -> 可调用: r"""为这个自定义操作注册 FakeTensor 实现。 这是使操作员能够高效地与 torch.compile 一起工作的必要条件。 模拟实现(有时也称为元内核或抽象实现) 指定了此操作符在没有任何数据的张量上的行为。 给定一些具有特定属性的输入张量 (尺寸/步长/存储偏移/设备),它指定了输出张量的属性 请参阅 :func:`torch.library.impl_abstract` 获取更多详细信息。 fn (Callable):注册为 FakeTensor 的函数 参数: fn (Callable):注册为 FakeTensor 的函数 实现。 示例: >>> 导入 torch >>> import numpy as np >>> 从 torch 导入 Tensor ... >>> # 示例 1:一个没有数据依赖输出形状的操作符 >>> @torch.library.custom_op("mylib::linear", mutates_args=()) >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: >>> 返回 (x @ weight.t()) + bias ... >>> @linear.register_fake >>> def _(x, weight, bias): >>> assert x.dim() == 2 >>> assert weight.dim() == 2 >>> 断言 bias 的维度等于 1 >>> 断言 x 的形状[1] 等于 weight 的形状[1] >>> 断言 weight 的形状[0] 等于 bias 的形状[0] >>> 断言 x 的设备等于 weight 的设备 >>> 返回 x.new_empty(x.size(0), weight.size(0)) ... >>> x = torch.randn(2, 2) >>> weight = torch.randn(2, 2) >>> bias = torch.randn(2) >>> # xdoctest: +SKIP("需要 Python <= 3.11") >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias) >>> # xdoctest: +SKIP("需要 Python <= 3.11") >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias)) ... >>> # Example 2: an operator with data-dependent output shape >>> @torch.library.custom_op("mylib::nonzero", mutates_args=()) >>> def nonzero(x: Tensor) -> Tensor: >>> x_np = x.cpu().numpy() >>> res = np.stack(np.nonzero(x_np), axis=1) >>> return torch.tensor(res, device=x.device) ... >>> @nonzero.register_fake >>> def _(x): >>> # 非零元素的数量依赖于数据。 >>> # 由于我们不能在抽象实现中查看数据, >>> # 我们使用 ctx 对象来构造一个新的 symint, >>> # 以表示依赖于数据大小的尺寸。 >>> ctx = torch.library.get_ctx() >>> nnz = ctx.new_dynamic_size() >>> shape = [nnz, x.dim()] >>> result = x.new_empty(shape, dtype=torch.int64) >>> return result ... >>> x = torch.tensor([0, 1, 2, 0, 0, 1]) >>> # xdoctest: +SKIP("需要 Python <= 3.11") >>> out = torch.compile(nonzero, fullgraph=True)(x) >>> # xdoctest: +SKIP("需要 Python <= 3.11") >>> 断言 torch.allclose(out, x.nonzero()) "文档" ._abstract_fn = fn 返回 fn 定义 注册 torch 分派( , torch 调度类: 任何, 函数: 可选[可调用] = , / ) -> 可调用: r注册给定操作符和 `torch_dispatch_class` 的 torch_dispatch 规则。 这允许开放注册以指定操作员之间的行为 无需修改的 `torch_dispatch_class` 或直接操作员。 请参阅 :func:`torch.library.register_torch_dispatch` 以获取示例和更多详细信息。 "文档" 定义 注册(函数): 如果 torch_dispatch_class ._torch_dispatch_fns: 定义 内部(*参数, **kwargs): 返回 ._torch_dispatch_fns[torch 调度类[ ]( *参数, **kwargs ) ._lib._注册 torch 调度规则( .名称, torch 调度类, 内部 ) ._torch_dispatch_fns[torch 调度类] = fn 返回 fn 如果 fn : 返回 注册 否则: 返回 注册(函数) 定义 注册自动微分( , 反向: 可调用, /, *, 设置上下文: 可选[可调用] = , ) -> : r注册此自定义操作的逆向公式。 为了使操作符能够与 autograd 一起工作,您需要注册 一个反向公式: 1. 您必须告诉我们如何在反向传播过程中计算梯度 通过为我们提供一个“反向”函数。 2. 如果您需要从正向计算梯度时使用任何正向的值,您可以 使用 `setup_context` 来保存向后传递的值。 `backward_fn` 在反向传播过程中运行。它接受 `(ctx, *grads)`: - ``grads`` 是一个或多个梯度。梯度的数量与 操作符的输出数量相匹配。 ``ctx``对象与``_中使用的`ctx`对象相同。 `torch.autograd.Function`的语义与`backward_fn`相同。 `setup_context(ctx, inputs, output)`在正向传播期间运行。 `torch.autograd.Function.backward`的语义相同。 请将所需的数量反向保存到“ctx”对象中 either :meth:`torch.autograd.function.FunctionCtx.save_for_backward` 翻译为:either :meth:`torch.autograd.function.FunctionCtx.save_for_backward` 将它们作为 ``ctx`` 的属性分配。如果你的自定义操作 仅支持关键字参数,我们期望`setup_context`的签名 `setup_context(ctx, inputs, keyword_only_inputs, output)` 需要设置上下文。 `setup_context_fn` 和 `backward_fn` 必须可追踪。也就是说, 它们不能直接访问 :meth:`torch.Tensor.data_ptr`,并且不能 依赖于或修改全局状态。如果您需要不可追踪的反向操作, 你可以将其制作为一个独立的 custom_op,在 backward_fn 中调用。 如果您需要在不同的设备上实现不同的 autograd 行为,那么我们建议创建两个不同的自定义算子,每个设备一个,并在运行时切换它们。 推荐为需要不同行为的每个设备创建两个不同的自定义算子,并在运行时进行切换。 需要不同行为的设备,建议创建两个不同的自定义算子,并在运行时进行切换。 示例: >>> 导入 torch >>> import numpy as np >>> 从 torch 导入 Tensor ... >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) ... >>> def setup_context(ctx, inputs, output) -> Tensor: >>> x, = inputs >>> ctx.save_for_backward(x) ... >>> def backward(ctx, grad): >>> x, = ctx.saved_tensors >>> return grad * x.cos() ... >>> numpy_sin.register_autograd(backward, setup_context=setup_context) ... >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_sin(x) >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) >>> assert torch.allclose(grad_x, x.cos()) ... >>> # 示例:仅使用关键字参数 >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = x_np * val >>> return torch.from_numpy(y_np).to(device=x.device) ... >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: >>> ctx.val = keyword_only_inputs["val"] ... >>> def backward(ctx, grad): >>> return grad * ctx.val ... >>> numpy_mul.register_autograd(backward, setup_context=setup_context) ... >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_mul(x, val=3.14) >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) "文档" 架构 = ._opoverload._架构 如果 工具.功能性模式(架构): raise 运行时错误( f"无法为非功能性算子注册自动微分公式" f"{}与模式{架构}请创建 " f一个功能操作符并为该操作符注册一个自动求导公式。 ) .向后函数 = 向后 ._setup_context_fn = 设置上下文 定义 _register_to_dispatcher() -> : 如果 火把._运行时使用部署(): 工具.警告部署(栈级别=5) 返回 = ._lib schema_str = .名称 + ._架构 cpp_schema = _C.解析模式(schema_str) 如果 工具.只有关键字参数的张量(cpp_schema): # 如果要支持此功能,进度如下: # - 支持仅带参数的不可微分张量 # - 支持仅带参数的张量(无论是否可微分) raise 不支持的操作异常( f"自定义操作使用仅带参数的 Tensor 参数。请确保您的 " f张量不是仅作为关键字参数。得到:{schema_str}" ) .定义( schema_str, 标签=[_C.标签.pt2_compliant_tag, _C.标签.需要修正步长顺序], ) ._操作符重载 = 工具.查找操作(._qualname) 定义 模拟实现(*参数, **kwargs): 如果 ._abstract_fn : 如果 工具.can_generate_trivial_fake_impl(._opoverload): 返回 raise 运行时错误( f"没有注册任何假的 impl"{}. f"这是 torch.compile/export/fx 追踪工作所必需的。" f"请使用 `"{._init_fn.__name__}`.register_fake` 用于添加一个 "` f"伪造实现。" ) 返回 .`_抽象函数`(*参数, **kwargs) .注册模拟实现(.名称, 模拟实现, _stacklevel=4) 自动微分实现 = 自动微分.创建自动微分实现(._操作符重载, ) .实现(.名称, 自动微分实现, 自动微分, 带有密钥集=) 架构 = ._操作符重载._架构 如果 架构.可变: 被修改的索引, 被修改的键 = 工具.被修改的参数和关键字参数(架构) 定义 在位或视图实现(键集, *参数, **kwargs): for 索引 变异索引: 增加版本号(参数[索引]) for key 突变键: 增量版本(kwargs[]) _C._AutoDispatchBelowADInplaceOrView(): 返回 ._opoverload.调度( 键集 & _C._after_ADInplaceOrView_键集, *参数, **kwargs ) .实现( .名称, 放置或查看实现, ADInplaceOrView, 带有密钥集=, ) 定义 注册后端选择调度器(, 设备参数索引: 整数): "" 打开设备参数以选择正确的后端进行调度。 "文档" 定义 后端选择(键集, *参数, **kwargs): 设备 = 参数[设备参数索引].类型 如果 设备 ._后端函数: raise 运行时错误( f"{.名称}没有注册内核{设备}. "请使用 register_kernel 来这样做。" ) 分发键 = _C._设备调度键(设备) 分发键 = getattr(_C.发送键, 分发键) 返回 ._opoverload.重新调度( _C.发送键集(分发键), *参数, **kwargs ) ._lib.实现(.名称, 后端选择, 后端选择, 带有密钥集=) 定义 __调用__(, *参数, **kwargs): 返回 ._opoverload(*参数, **kwargs) 定义 注册_vmap( , 函数: 可选[可调用] = , ): r注册一个 vmap 实现,以支持此自定义操作的:func:`torch.vmap`。 此 API 可以用作装饰器。 为了使操作员能够使用 :func:`torch.vmap`,您可能需要注册 vmap 实现如下签名: vmap_func(info, in_dims: 可选[int] 元组, *args, **kwargs) ``*args`` 和 ``**kwargs`` 是 ``op`` 的参数和关键字参数。 它指定了在给定具有额外维度(由`in_dims`指定)的输入时,如何计算`op`的批处理版本。 对于`args`中的每个参数,`in_dims`都有一个对应的`Optional[int]`。它是`None`。 对于`args`中的每个参数,`in_dims`都有一个对应的`Optional[int]`。它是`None`。 如果参数不是张量或者参数没有被 vmapped,否则它是一个整数 指定正在被 vmapped 操作覆盖的张量的哪个维度 ``info`` 是一个可能包含有用额外元数据的集合: ``info.batch_size`` 指定了正在被 vmapped 覆盖的维度的尺寸,同时 ``info.randomness`` 是传递给 :func:`torch.vmap` 的 ``randomness`` 选项。 函数 ``func`` 的返回值是一个包含 ``output, out_dims`` 的元组。与 ``in_dims`` 类似, ``out_dims`` 应与 ``output`` 具有相同的结构,并包含一个 ``out_dim``, 每个输出指定该输出是否具有 vmapped 维度以及它在其中的索引。 示例: >>> 导入 torch >>> import numpy as np >>> 从 torch 导入 Tensor >>> 从 typing 导入 Tuple ... >>> 定义 to_numpy(tensor)函数 >>> 返回 tensor.cpu().numpy() ... >>> lib = torch.library.Library("mylib", "FRAGMENT") >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: >>> x_np = to_numpy(x) >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) >>> return torch.tensor(x_np ** 3, device=x.device), dx ... >>> def numpy_cube_vmap(info, in_dims, x): >>> result = numpy_cube(x) >>> 返回结果,(in_dims[0],in_dims[0]) ... >>> numpy_cube.register_vmap(numpy_cube_vmap) ... >>> x = torch.randn(3) >>> torch.vmap(numpy_cube)(x) ... >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) ... >>> @numpy_mul.register_vmap >>> def numpy_mul_vmap(info, in_dims, x, y): >>> x_bdim, y_bdim = in_dims >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) >>> result = x * y >>> result = result.movedim(-1, 0) 结果 = 结果.movedim(-1, 0) >>> 返回结果,0 ... ... >>> x = torch.randn(3) >>> y = torch.randn(3) >>> torch.vmap(numpy_mul)(x, y) "文档" from torch._functorch.autograd_function 导入 custom_function_call_vmap_helper from torch._functorch.pyfunctorch 导入 retrieve_current_functorch_interpreter 定义 注册(函数): 需要注册 = ._vmap_fn ._vmap_fn = 函数 如果 需要注册: 定义 wrapped_func(键集, *参数, **kwargs): 解释器 = 获取当前 functorch 解释器() 返回 自定义函数调用_vmap 辅助器( 解释器, ._vmap_fn, ._opoverload, *参数, **kwargs ) ._lib.实现( .名称, wrapped_func, "FuncTorchBatched", 带有密钥集=真实 ) 如果 函数 : 返回 注册 否则: 返回 注册(函数) 定义 注册自动转换( , 设备类型: 字符串, 输入设备: 数据类型, ): r注册此自定义操作的自动广播调度规则。 有效的 `device_type` 包括: "cpu" 和 "cuda"。 参数: op (str | OpOverload): 要注册自动广播调度规则的运算符。 device_type(str): 要使用的设备类型。'cuda' 或 'cpu'。 类型与 :class:`torch.device` 的 `type` 属性相同。 因此,您可以使用 `Tensor.device.type` 获取张量的设备类型。 cast_inputs (:class:`torch.dtype`): 当自定义操作在自动类型转换区域运行时, 将传入的浮点型张量转换为目标数据类型(非浮点型张量) 这些不受影响),然后执行禁用自动转换的自定义操作。 lib(可选[库]):如果提供,则此注册的生命周期 示例: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) >>> 导入 torch >>> 从 torch 导入 Tensor >>> from torch.library import custom_op ... >>> # 创建一个在 cuda 上工作的自定义操作 >>> @torch.library.custom_op("mylib::my_sin", mutates_args=()) >>> def my_sin(x: Tensor) -> Tensor: >>> return torch.sin(x) ... >>> # 注册针对 cuda 设备的 autocast 调度规则 >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16) ... >>> x = torch.randn(3, dtype=torch.float32, device="cuda") >>> with torch.autocast("cuda", dtype=torch.float16): >>> y = torch.ops.mylib.my_sin(x) >>> assert y.dtype == torch.float16 "文档" 如果 isinstance(设备类型, 字符串): raise ValueError( f预期 `device_type` 为 `str` 类型,实际得到:`{类型(设备类型)} ) 如果 设备类型 ["cpu", cuda]: raise ValueError(f"未知设备类型:"{设备类型}") 需要注册 CUDA = ._自动转换 CUDA 数据类型 需要注册 CPU = ._自动转换 CPU 数据类型 如果 设备类型 == cuda: ._autocast_cuda_dtype = 设置输入 否则: ._autocast_cpu_dtype = 设置输入 定义 内核(_, *参数, **kwargs): 断言 长度(kwargs) == 0, "自定义操作目前尚不支持 kwargs。" autocast_keyset = 火把._C.发送键集( 火把._C.发送键.自动转换 CPU ) | 火把._C.发送键集(火把._C.发送键.自动将 CUDA) 火把._C._排除调度键保护(自动铸件键集): 返回 ._opoverload(*_投射(参数, 设备类型, 输入设备)) 如果 需要注册 CUDA 并且 ._autocast_cuda_dtype: ._lib.实现(.名称, 内核, 自动转换 CUDA, 带有密钥集=) elif 需要注册 CPU 并且 .自动转换 CPU 数据类型: ._lib.实现(.名称, 内核, 自动转换 CPU, 带有密钥集=) 返回 内核
# TODO: 将此函数与 torch.amp.autocast_mode._cast 合并,并进行重构 将相关操作转换为一次性的实用函数,一旦自定义操作支持任意输入类型。 定义 _投射(, 设备类型: 字符串, 数据类型: 数据类型): 如果 isinstance(, 火把.张量): 是否符合资格 = ( .is_floating_point() 并且 .设备.类型 == 设备类型 并且 (.dtype 火把.float64) ) 返回 .(数据类型) 如果 是否符合资格 否则 elif isinstance(, (字符串, 字节)): 返回 elif isinstance(, 集合.abc.迭代器): 可迭代的 = (_投射(v, 设备类型, 数据类型) for v ) 如果 isinstance(, (列表, 元组)): 返回 类型()(可迭代对象) 否则: 返回 可迭代的 否则: 返回 定义 增加版本(val: 任何) -> : 如果 isinstance(val, 张量): 火把.自动微分..增加版本(val) elif isinstance(val, (元组, 列表)): for v val: 如果 isinstance(v, 张量): 火把.自动微分..增加版本(v) # NOTE: [支持装饰器和非装饰器用法] # 一些 API 可能既可以用作装饰器,也可以不用作装饰器。 例如: # >>> 定义函数 fn(x): >>> 返回 x 的正弦值 注释 # 使用 1:不是作为装饰器 # numpy_sin.register_kernel("cuda", fn) 注释 # >>> 使用 2:作为装饰器 # >>> @numpy_sin.register_kernel("cuda") # >>> def fn2(x): # >>> return x.sin # 我们支持这种方式是通过`register_kernel`函数接受一个可选的`fn`。 如果提供了`fn`(用法 1),则我们知道用户正在使用它不 # 作为装饰器。 如果 `fn` 未提供(用法 2),则 `register_kernel` 需要返回 装饰器。 OPDEF_TO_LIB: 字典[字符串, "torch.library.Library"] = {} OPDEFS: 弱引用.弱值字典 = 弱引用.弱值字典() 定义 获取允许覆盖的库( 命名空间: 字符串, 名称: 字符串 ) -> torch.library.Library: qualname = f"{命名空间}::{名称}" 如果 qualname OPDEF_TO_LIB: OPDEF_TO_LIB[qualname]._destroy() 删除 OPDEF_TO_LIB[qualname] = 火把.图书馆.(命名空间, "碎片") # noqa: TOR901 OPDEF_TO_LIB[qualname] = 返回 定义 可能获取操作定义( 操作: 联盟[自定义运算定义, 操作符.操作符重载, 字符串] ) -> 可选[自定义运算定义]: 如果 isinstance(操作, 自定义运算定义): 返回 操作符 如果 isinstance(操作, 操作符.操作符重载): 操作符 = 操作.名称 断言 isinstance(操作, 字符串) 如果 操作符 OPDEFS: 返回 OPDEFS[操作] 返回

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源,获取您的疑问解答

查看资源