• 文档 >
  • torch.testing
快捷键

torch.testing ¬

torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)[source][source] ¬

断言 actualexpected 是接近的。

如果 actualexpected 是步进、非量化、实值且有限的,则它们被认为是接近的。

actualexpectedatol+rtolexpected\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert

非有限值( -infinf )只有在它们相等的情况下才被认为是接近的。 NaN 只有在 equal_nan 等于 True 的情况下才被认为是相等的。

此外,只有在它们具有相同的条件下,才被认为是接近的。

  • 如果 check_deviceTrue

  • dtype (如果 check_dtypeTrue ),

  • layout (如果 check_layoutTrue ),并且

  • 步长(如果 check_strideTrue )。

如果 actualexpected 是元张量,则只执行属性检查。

如果 actualexpected 是稀疏的(无论是 COO、CSR、CSC、BSR 或 BSC 布局),则它们的带偏移量成员将单独检查。索引,即 COO 的 indices ,CSR 和 BCSR 的 crow_indicescol_indices ,或 CSC 和 BSC 布局的 ccol_indicesrow_indices ,始终检查是否相等,而值则根据上述定义检查接近程度。

如果 actualexpected 是量化的,则它们被认为是接近的,如果它们有相同的 qscheme() ,并且 dequantize() 的结果根据上述定义是接近的。

actualexpected 可以是 Tensor 的,或者是从中可以构造出 torch.Tensor 的任何 tensor-or-scalar-likes。除了 Python 标量外,输入类型必须直接相关。此外, actualexpected 可以是 Sequence 的或 Mapping 的,在这种情况下,如果它们的结构匹配,并且所有元素根据上述定义被认为是接近的,则它们被认为是接近的。

注意

Python 标量是类型关系要求的例外,因为它们的 type() ,即 intfloatcomplex ,与 tensor-like 的 dtype 等价。因此,不同类型的 Python 标量可以进行检查,但需要 check_dtype=False

参数:
  • 实际(任何)- 实际输入。

  • 预期(任何)- 预期输入。

  • allow_subclasses(布尔值)- 如果 True (默认)并且除了 Python 标量之外,允许直接相关类型的输入。否则需要类型相等。

  • rtol(可选[浮点数])- 相对容差。如果指定 atol ,则必须也指定 dtype 。如果省略,则根据以下表格选择基于 dtype 的默认值。

  • 绝对容差(Optional[float])- 绝对容差。如果指定了 rtol ,则必须也指定 rtol 。如果省略,则根据以下表格选择基于 dtype 的默认值。

  • equal_nan(Union[bool, str])- 如果 True ,则将两个 NaN 值视为相等。

  • check_device(bool)- 如果 True (默认),则断言相应的张量位于同一 device 。如果禁用此检查,则将位于不同 device 的张量移动到 CPU 后再进行比较。

  • check_dtype(bool)- 如果 True (默认),则断言相应的张量具有相同的 dtype 。如果禁用此检查,则将具有不同 dtype 的张量提升到公共 dtype (根据 torch.promote_types() )后再进行比较。

  • check_layout (bool) – 如果 True (默认),则断言相应的张量具有相同的 layout 。如果此检查被禁用,则具有不同 layout 的张量在比较之前将被转换为带偏移的张量。

  • check_stride (bool) – 如果 True 和相应的张量是带偏移的,则断言它们具有相同的步长。

  • msg (Optional[Union[str, Callable[[str], str]]]) – 可选的错误消息,用于比较失败时使用。也可以作为可调用对象传递,在这种情况下,它将使用生成的消息进行调用,并应返回新的消息。

引发:
  • ValueError – 如果无法从输入构造 torch.Tensor

  • ValueError – 如果只指定了 rtolatol

  • AssertionError – 如果对应的输入既不是 Python 标量也不是直接相关的。

  • AssertionError – 如果 allow_subclassesFalse ,但对应的输入既不是 Python 标量,类型也不同。

  • AssertionError – 如果输入是 Sequence 的,但它们的长度不匹配。

  • 断言错误 - 如果输入是 Mapping 的,但它们的键集不匹配。

  • 断言错误 - 如果对应的张量没有相同的 shape

  • 断言错误 - 如果 check_layoutTrue ,但对应的张量没有相同的 layout

  • 断言错误 - 如果只有对应的一个张量被量化。

  • 断言错误 - 如果对应的张量已量化,但有不同的 qscheme()

  • 断言错误 - 如果 check_deviceTrue ,但对应的张量不在同一个 device 上。

  • 断言错误 - 如果 check_dtypeTrue ,但对应的张量没有相同的 dtype

  • 断言错误 - 如果 check_strideTrue ,但对应的带偏移的张量没有相同的步长。

  • 断言错误 - 如果根据上述定义,对应张量的值不相近。

下表显示了不同 dtype 的默认 rtolatol 。在 dtype 不匹配的情况下,使用两者容忍度的最大值。

dtype

rtol

atol

float16

1e-3

1e-5

bfloat16

1.6e-2

1e-5

float32

1.3e-6

1e-5

float64

1e-7

1e-7

complex32

1e-3

1e-5

complex64

1.3e-6

1e-5

complex128

1e-7

1e-7

quint8

1.3e-6

1e-5

quint2x4

1.3e-6

1e-5

quint4x2

1.3e-6

1e-5

qint8

1.3e-6

1e-5

qint32

1.3e-6

1e-5

其他

0.0

0.0

注意

assert_close() 具有高度的可配置性,默认设置严格。鼓励用户根据其用例进行调整。例如,如果需要进行等价性检查,可以定义一个默认对每个 dtype 使用零容忍度的 assert_equal

>>> import functools
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
Traceback (most recent call last):
...
AssertionError: Scalars are not equal!

Expected 1e-10 but got 1e-09.
Absolute difference: 9.000000000000001e-10
Relative difference: 9.0

示例

>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.clone()
>>> # By default, directly related instances can be compared
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
>>> # This check can be made more strict with allow_subclasses=False
>>> torch.testing.assert_close(
...     torch.nn.Parameter(actual), expected, allow_subclasses=False
... )
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type
<class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
>>> # If the inputs are not directly related, they are never considered close
>>> torch.testing.assert_close(actual.numpy(), expected)
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
and <class 'torch.Tensor'>.
>>> # Exceptions to these rules are Python scalars. They can be checked regardless of
>>> # their type if check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("Nan"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
Traceback (most recent call last):
...
AssertionError: Scalars are not close!

Expected nan but got nan.
Absolute difference: nan (up to 1e-05 allowed)
Relative difference: nan (up to 1.3e-06 allowed)
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default error message can be overwritten.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
>>> # If msg is a callable, it can be used to augment the generated message with
>>> # extra information
>>> torch.testing.assert_close(
...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
... )
Traceback (most recent call last):
...
AssertionError: Header

Tensor-likes are not close!

Mismatched elements: 2 / 3 (66.7%)
Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)

Footer
torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)[source][source]

创建一个具有给定 shapedevicedtype 的张量,并用从 [low, high) 中均匀抽取的值填充。

如果 lowhigh 被指定且超出 dtype 可表示的有限值范围,则它们将被分别夹到最低或最高可表示的有限值。如果 None ,则下表描述了 lowhigh 的默认值,这些值取决于 dtype

dtype

low

high

布尔类型

0

2

无符号整型

0

10

有符号整型

-9

10

浮点型

-9

9

复合类型

-9

9

参数:
  • shape (Tuple[int, ...]) – 单个整数或整数序列,用于定义输出张量的形状。

  • dtype ( torch.dtype ) – 返回张量的数据类型。

  • device (Union[str, torch.device]) – 返回张量的设备。

  • low (Optional[Number]) – 设置给定范围的最低限制(包含)。如果提供数字,则将其夹在给定数据类型的可表示的最小有限值中。当 None (默认)时,此值根据 dtype (见上表)确定。默认: None

  • 高(可选[数字]) –

    设置给定范围的上限(不包括)。如果提供数字,则将其限制为给定数据类型的最大可表示有限值。当为 None (默认)时,此值根据 dtype (见上表)确定。默认: None

    自版本 2.1 开始弃用:对于浮点型或复数类型,从 low==high 传递到 make_tensor() 自 2.1 版本开始弃用,将在 2.3 版本中删除。请使用 torch.full() 代替。

  • requires_grad(可选[布尔值]) – 如果 autograd 应记录对返回张量的操作。默认: False

  • noncontiguous(可选[bool])- 如果为 True,则返回的 tensor 将是非连续的。如果构造的 tensor 元素少于两个,则忽略此参数。与 memory_format 互斥。

  • exclude_zero(可选[bool])- 如果 True ,则将零替换为根据 dtype 的 dtype 的小的正值。对于 bool 和整数类型,零被替换为一。对于浮点类型,它被替换为 dtype 的最小正正常数( dtypefinfo() 对象的“tiny”值),对于复数类型,它被替换为一个实部和虚部都是复数类型能表示的最小正正常数的复数。默认 False

  • memory_format(可选 torch.memory_format)- 返回 tensor 的内存格式。与 noncontiguous 互斥。

引发:
  • ValueError - 如果为整型 dtype 传递了 requires_grad=True

  • ValueError – 如果 low >= high

  • ValueError – 如果 lowhighnan

  • ValueError – 如果同时传递了 noncontiguousmemory_format

  • TypeError – 如果 dtype 不支持此函数。

返回类型:

张量

示例

>>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
tensor([ 0.1205, 0.2282, -0.6380])
>>> # Creates a bool tensor on CUDA
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
tensor([[False, False],
        [False, True]], device='cuda:0')
torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')[source][source]

警告

torch.testing.assert_allclose() 自从 1.12 被弃用以来,将在未来的版本中删除。请使用 torch.testing.assert_close() 代替。您可以在此处找到详细的升级说明。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源