• 文档 >
  • torch.overrides
快捷键

torch.overrides

此模块公开了各种辅助函数用于 __torch_function__ 协议。有关 __torch_function__ 协议的更多详细信息,请参阅扩展 torch Python API。

函数 ¶

torch.overrides.get_ignored_functions()[source][source]

返回不能被 __torch_function__ .覆盖的公共函数

返回值:

一个函数元组,这些函数在 torch API 中公开可用,但不能用 __torch_function__ .覆盖。这主要是因为这些函数的参数中没有一个是张量或张量类似物。

返回类型:

set[Callable]

示例

>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
torch.overrides.get_overridable_functions()[source][source]

通过 __torch_function__ 可覆盖的函数列表

返回值:

将包含可覆盖函数的命名空间映射到该命名空间中可被覆盖的函数的字典

返回类型:

Dict[Any, List[Callable]]

torch.overrides.resolve_name(f)[source][source]

获取传递给 __torch_function__ 的函数的人类可读字符串名称

参数:

f (Callable) – 解析名称的函数。

返回值:

函数名称;如果评估它,应返回输入函数。

返回类型:

str

torch.overrides.get_testing_overrides()[source][source]

返回一个包含所有可覆盖函数的虚拟覆盖的字典

返回值:

一个将 PyTorch API 中的可覆盖函数映射到与真实函数具有相同签名且无条件返回-1 的 lambda 函数的字典。这些 lambda 函数对于测试定义 __torch_function__ 类型的 API 覆盖率很有用。

返回类型:

Dict[Callable, Callable]

示例

>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[source][source]

实现一个带有 __torch_function__ 覆盖检查的功能。

请参阅 torch::autograd::handle_torch_function,了解此函数在 C++实现中的等效函数。

参数:
  • public_api(函数)- 由公共 torch API 公开的函数,最初称为 public_api(*args, **kwargs) ,现在正在检查其参数。

  • relevant_args(可迭代对象)- 要检查__torch_function__方法的参数的可迭代对象。

  • args(元组)- 原先传递给 public_api 的任意位置参数。

  • kwargs(元组)- 原先传递给 public_api 的任意关键字参数。

返回值:

调用 implementation 或适当的 __torch_function__ 方法的结果。

返回类型:

对象

如果未找到实现,则引发 TypeError:

示例

>>> def func(a):
...     if has_torch_function_unary(a):
...         return handle_torch_function(func, (a,), a)
...     return a + 0
torch.overrides.has_torch_function()

检查可迭代对象的元素中是否存在 __torch_function__ 实现或是否启用了 __torch_function__ 模式。考虑精确的 TensorParameter 不可调度。使用此方法来保护对 handle_torch_function() 的调用;不要用它来测试是否是 Tensor-like,请使用 is_tensor_like() 代替。 :param relevant_args: 要检查 __torch_function__ 方法的可迭代对象或参数。 :type relevant_args: 可迭代对象

返回值:

如果 relevant_args 的任何元素具有 __torch_function__ 实现则为 True,否则为 False。

返回类型:

布尔型

参见

torch.is_tensor_like

检查某个对象是否是 Tensor-like 类型,包括精确的 Tensor

torch.overrides.is_tensor_like(inp)[source][source]

如果传入的输入是 Tensor-like 类型,则返回 True

目前,这发生在输入类型的 __torch_function__ 属性上。

示例

张量的一般子类通常是类似于张量的。

>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True

内置或用户类型通常不是类似于张量的。

>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False

但是,通过实现 __torch_function__ 可以使它们变为类似于张量的。

>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides.is_tensor_method_or_property(func)[source][source]

判断传入的函数是否为属于 torch.Tensor 的方法或属性的处理程序,如传入 __torch_function__

注意

对于属性,必须传入其 __get__ 方法。

这可能特别需要以下原因:

  1. 方法/属性有时不包含__module__槽位。

  2. 它们要求第一个传入的参数是 torch.Tensor 的实例。

示例

>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
返回类型:

布尔型

torch.overrides.wrap_torch_function(dispatcher)[source][source]

包装一个给定的函数,使其具有 __torch_function__ 相关的功能。

参数:

dispatcher (Callable) – 一个可调用的对象,返回一个可迭代的 Tensor-like 对象,这些对象被传递到函数中。

注意

此装饰器可能会降低您的代码性能。通常,将代码表达为一系列支持 __torch_function__ 的函数就足够了。如果您遇到罕见的情况,例如您正在包装一个底层库并且还需要它对 Tensor-likes 也能工作,那么这个函数是可用的。

示例

>>> def dispatcher(a):  # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a):  # This will make func dispatchable by __torch_function__
...     return a + 0

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源