• 文档 >
  • torch >
  • torch.set_default_dtype
快捷键

torch.set_default_dtype

torch.set_default_dtype(d, /)[source][source]

将默认浮点数据类型设置为 d 。支持浮点数据类型作为输入。其他数据类型会导致 torch 抛出异常。

当 PyTorch 初始化时,其默认浮点数据类型为 torch.float32,set_default_dtype(torch.float64)的目的是为了方便 NumPy-like 类型推断。默认浮点数据类型用于:

  1. 隐式确定默认复数数据类型。当默认浮点类型为 float16 时,默认复数数据类型为 complex32。对于 float32,默认复数数据类型为 complex64。对于 float64,则为 complex128。对于 bfloat16,将抛出异常,因为没有对应于 bfloat16 的复数类型。

  2. 推断使用 Python 浮点数或复数 Python 数字构建的张量的数据类型。请参见以下示例。

  3. 确定布尔型和整型张量以及 Python 浮点数和复数 Python 数之间的类型提升结果。

参数:

d ( torch.dtype ) – 将默认的浮点类型设置为。

示例

>>> # initial default for floating point is torch.float32
>>> # Python floats are interpreted as float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # initial default for floating point is torch.complex64
>>> # Complex Python numbers are interpreted as complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python floats are now interpreted as float64
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float64
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex128
>>> torch.set_default_dtype(torch.float16)
>>> # Python floats are now interpreted as float16
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float16
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex32

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源