torch.testing ¬
- torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)[source][source] ¬
断言
actual
和expected
是接近的。如果
actual
和expected
是步进、非量化、实值且有限的,则它们被认为是接近的。非有限值(
-inf
和inf
)只有在它们相等的情况下才被认为是接近的。NaN
只有在equal_nan
等于True
的情况下才被认为是相等的。此外,只有在它们具有相同的条件下,才被认为是接近的。
如果
check_device
是True
,dtype
(如果check_dtype
是True
),layout
(如果check_layout
是True
),并且步长(如果
check_stride
是True
)。
如果
actual
或expected
是元张量,则只执行属性检查。如果
actual
和expected
是稀疏的(无论是 COO、CSR、CSC、BSR 或 BSC 布局),则它们的带偏移量成员将单独检查。索引,即 COO 的indices
,CSR 和 BCSR 的crow_indices
和col_indices
,或 CSC 和 BSC 布局的ccol_indices
和row_indices
,始终检查是否相等,而值则根据上述定义检查接近程度。如果
actual
和expected
是量化的,则它们被认为是接近的,如果它们有相同的qscheme()
,并且dequantize()
的结果根据上述定义是接近的。actual
和expected
可以是Tensor
的,或者是从中可以构造出torch.Tensor
的任何 tensor-or-scalar-likes。除了 Python 标量外,输入类型必须直接相关。此外,actual
和expected
可以是Sequence
的或Mapping
的,在这种情况下,如果它们的结构匹配,并且所有元素根据上述定义被认为是接近的,则它们被认为是接近的。注意
Python 标量是类型关系要求的例外,因为它们的
type()
,即int
、float
和complex
,与 tensor-like 的dtype
等价。因此,不同类型的 Python 标量可以进行检查,但需要check_dtype=False
。- 参数:
实际(任何)- 实际输入。
预期(任何)- 预期输入。
allow_subclasses(布尔值)- 如果
True
(默认)并且除了 Python 标量之外,允许直接相关类型的输入。否则需要类型相等。rtol(可选[浮点数])- 相对容差。如果指定
atol
,则必须也指定dtype
。如果省略,则根据以下表格选择基于dtype
的默认值。绝对容差(Optional[float])- 绝对容差。如果指定了
rtol
,则必须也指定rtol
。如果省略,则根据以下表格选择基于dtype
的默认值。equal_nan(Union[bool, str])- 如果
True
,则将两个NaN
值视为相等。check_device(bool)- 如果
True
(默认),则断言相应的张量位于同一device
。如果禁用此检查,则将位于不同device
的张量移动到 CPU 后再进行比较。check_dtype(bool)- 如果
True
(默认),则断言相应的张量具有相同的dtype
。如果禁用此检查,则将具有不同dtype
的张量提升到公共dtype
(根据torch.promote_types()
)后再进行比较。check_layout (bool) – 如果
True
(默认),则断言相应的张量具有相同的layout
。如果此检查被禁用,则具有不同layout
的张量在比较之前将被转换为带偏移的张量。check_stride (bool) – 如果
True
和相应的张量是带偏移的,则断言它们具有相同的步长。msg (Optional[Union[str, Callable[[str], str]]]) – 可选的错误消息,用于比较失败时使用。也可以作为可调用对象传递,在这种情况下,它将使用生成的消息进行调用,并应返回新的消息。
- 引发:
ValueError – 如果无法从输入构造
torch.Tensor
。ValueError – 如果只指定了
rtol
或atol
。AssertionError – 如果对应的输入既不是 Python 标量也不是直接相关的。
AssertionError – 如果
allow_subclasses
是False
,但对应的输入既不是 Python 标量,类型也不同。AssertionError – 如果输入是
Sequence
的,但它们的长度不匹配。断言错误 - 如果输入是
Mapping
的,但它们的键集不匹配。断言错误 - 如果对应的张量没有相同的
shape
。断言错误 - 如果
check_layout
是True
,但对应的张量没有相同的layout
。断言错误 - 如果只有对应的一个张量被量化。
断言错误 - 如果对应的张量已量化,但有不同的
qscheme()
。断言错误 - 如果
check_device
是True
,但对应的张量不在同一个device
上。断言错误 - 如果
check_dtype
是True
,但对应的张量没有相同的dtype
。断言错误 - 如果
check_stride
是True
,但对应的带偏移的张量没有相同的步长。断言错误 - 如果根据上述定义,对应张量的值不相近。
下表显示了不同
dtype
的默认rtol
和atol
。在dtype
不匹配的情况下,使用两者容忍度的最大值。dtype
rtol
atol
float16
1e-3
1e-5
bfloat16
1.6e-2
1e-5
float32
1.3e-6
1e-5
float64
1e-7
1e-7
complex32
1e-3
1e-5
complex64
1.3e-6
1e-5
complex128
1e-7
1e-7
quint8
1.3e-6
1e-5
quint2x4
1.3e-6
1e-5
quint4x2
1.3e-6
1e-5
qint8
1.3e-6
1e-5
qint32
1.3e-6
1e-5
其他
0.0
0.0
注意
assert_close()
具有高度的可配置性,默认设置严格。鼓励用户根据其用例进行调整。例如,如果需要进行等价性检查,可以定义一个默认对每个dtype
使用零容忍度的assert_equal
:>>> import functools >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) >>> assert_equal(1e-9, 1e-10) Traceback (most recent call last): ... AssertionError: Scalars are not equal! Expected 1e-10 but got 1e-09. Absolute difference: 9.000000000000001e-10 Relative difference: 9.0
示例
>>> # tensor to tensor comparison >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) >>> actual = torch.acos(torch.cos(expected)) >>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison >>> import math >>> expected = math.sqrt(2.0) >>> actual = 2.0 / math.sqrt(2.0) >>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison >>> import numpy as np >>> expected = np.array([1e0, 1e-1, 1e-2]) >>> actual = np.arccos(np.cos(expected)) >>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison >>> import numpy as np >>> # The types of the sequences do not have to match. They only have to have the same >>> # length and their elements have to match. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] >>> actual = tuple(expected) >>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison >>> from collections import OrderedDict >>> import numpy as np >>> foo = torch.tensor(1.0) >>> bar = 2.0 >>> baz = np.array(3.0) >>> # The types and a possible ordering of mappings do not have to match. They only >>> # have to have the same set of keys and their elements have to match. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) >>> actual = {"baz": baz, "bar": bar, "foo": foo} >>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = expected.clone() >>> # By default, directly related instances can be compared >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) >>> # This check can be made more strict with allow_subclasses=False >>> torch.testing.assert_close( ... torch.nn.Parameter(actual), expected, allow_subclasses=False ... ) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>. >>> # If the inputs are not directly related, they are never considered close >>> torch.testing.assert_close(actual.numpy(), expected) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'> and <class 'torch.Tensor'>. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of >>> # their type if check_dtype=False. >>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default. >>> expected = torch.tensor(float("Nan")) >>> actual = expected.clone() >>> torch.testing.assert_close(actual, expected) Traceback (most recent call last): ... AssertionError: Scalars are not close! Expected nan but got nan. Absolute difference: nan (up to 1e-05 allowed) Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! >>> # If msg is a callable, it can be used to augment the generated message with >>> # extra information >>> torch.testing.assert_close( ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" ... ) Traceback (most recent call last): ... AssertionError: Header Tensor-likes are not close! Mismatched elements: 2 / 3 (66.7%) Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) Footer
- torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)[source][source]¶
创建一个具有给定
shape
,device
和dtype
的张量,并用从[low, high)
中均匀抽取的值填充。如果
low
或high
被指定且超出dtype
可表示的有限值范围,则它们将被分别夹到最低或最高可表示的有限值。如果None
,则下表描述了low
和high
的默认值,这些值取决于dtype
。dtype
low
high
布尔类型
0
2
无符号整型
0
10
有符号整型
-9
10
浮点型
-9
9
复合类型
-9
9
- 参数:
shape (Tuple[int, ...]) – 单个整数或整数序列,用于定义输出张量的形状。
dtype (
torch.dtype
) – 返回张量的数据类型。device (Union[str, torch.device]) – 返回张量的设备。
low (Optional[Number]) – 设置给定范围的最低限制(包含)。如果提供数字,则将其夹在给定数据类型的可表示的最小有限值中。当
None
(默认)时,此值根据dtype
(见上表)确定。默认:None
。高(可选[数字]) –
设置给定范围的上限(不包括)。如果提供数字,则将其限制为给定数据类型的最大可表示有限值。当为None
(默认)时,此值根据dtype
(见上表)确定。默认:None
。自版本 2.1 开始弃用:对于浮点型或复数类型,从
low==high
传递到make_tensor()
自 2.1 版本开始弃用,将在 2.3 版本中删除。请使用torch.full()
代替。requires_grad(可选[布尔值]) – 如果 autograd 应记录对返回张量的操作。默认:
False
。noncontiguous(可选[bool])- 如果为 True,则返回的 tensor 将是非连续的。如果构造的 tensor 元素少于两个,则忽略此参数。与
memory_format
互斥。exclude_zero(可选[bool])- 如果
True
,则将零替换为根据dtype
的 dtype 的小的正值。对于 bool 和整数类型,零被替换为一。对于浮点类型,它被替换为 dtype 的最小正正常数(dtype
的finfo()
对象的“tiny”值),对于复数类型,它被替换为一个实部和虚部都是复数类型能表示的最小正正常数的复数。默认False
。memory_format(可选 torch.memory_format)- 返回 tensor 的内存格式。与
noncontiguous
互斥。
- 引发:
ValueError - 如果为整型 dtype 传递了
requires_grad=True
。ValueError – 如果
low >= high
。ValueError – 如果
low
或high
是nan
。ValueError – 如果同时传递了
noncontiguous
和memory_format
。TypeError – 如果
dtype
不支持此函数。
- 返回类型:
示例
>>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0')