• 文档 >
  • torch.onnx >
  • 基于 TorchScript 的 ONNX 导出器 >
  • JitScalarType
快捷键

JitScalarType

class torch.onnx.JitScalarType(value, names=<未指定>, *values, module=None, qualname=None, type=None, start=1, boundary=None) ¶

PyTorch 中定义的标量类型。

使用 JitScalarType 将 PyTorch 和 JIT 标量类型转换为 ONNX 标量类型。

示例

>>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type()
TensorProtoDataType.FLOAT
>>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type()
TensorProtoDataType.FLOAT
>>> JitScalarType.from_dtype(torch.get_default_dtype).onnx_type()
TensorProtoDataType.FLOAT
dtype()[来源][来源]

将 JitScalarType 转换为 PyTorch 数据类型。

返回类型:

数据类型

classmethod from_dtype(dtype)[source][source]

将 torch 数据类型转换为 JitScalarType。

注意:当 dtype 来自 torch._C.Value.type()调用时,请不要使用此 API。

在形状信息不存在的情况下,可能会在以下几种场景中引发“RuntimeError: INTERNAL ASSERT FAILED at “../aten/src/ATen/core/jit_type_base.h”错误。请改用 from_value API,它更安全。

参数:

dtype (torch.dtype | None) – 用于创建 JitScalarType 的 torch.dtype

返回值:

JitScalarType

引发:

OnnxExporterError – 如果 dtype 不是一个有效的 torch.dtype 或者为 None。

返回类型:

JitScalarType

from_onnx_type(onnx_type)[source][source] ¶

将 ONNX 数据类型转换为 JitScalarType。

参数:

onnx_type (int | torch._C._onnx.TensorProtoDataType | None) – 要创建 JitScalarType 的 torch._C._onnx.TensorProtoDataType

返回值:

JitScalarType

引发:

OnnxExporterError – 如果 dtype 不是有效的 torch.dtype 或者为 None。

返回类型:

JitScalarType

classmethod from_value(value, default=None)[source][source]

从值的标量类型创建一个 JitScalarType。

参数:
  • 从对象中获取标量类型的对象(None | torch.Value | torch.Tensor)

  • 默认值 - 如果无法从 value 中获取有效的标量类型,则返回的 JitScalarType

返回值:

JitScalarType.

引发:
  • OnnxExporterError - 如果 value 没有有效的标量类型且默认值为 None。

  • 符号值错误 – 当 value.type() 的信息为空且默认值为 None 时

返回类型:

JitScalarType

onnx_compatible()[source][source]

返回此 JitScalarType 是否与 ONNX 兼容。

返回类型:

布尔型

onnx_type()[来源][来源] ¶

将 JitScalarType 转换为 ONNX 数据类型。

返回类型:

TensorProtoDataType

scalar_name()[来源][来源] ¶

将 JitScalarType 转换为 JIT 标量类型名称。

返回类型:

Literal[‘字节’,‘字符’,‘双精度’,‘浮点’,‘半精度’,‘整型’,‘长整型’,‘短整型’,‘布尔’,‘复数半精度’,‘复数浮点’,‘复数双精度’,‘QInt8’,‘QUInt8’,‘QInt32’,‘BFloat16’,‘Float8E5M2’,‘Float8E4M3FN’,‘Float8E5M2FNUZ’,‘Float8E4M3FNUZ’,‘未定义’]

torch_name()[来源][来源]

将 JitScalarType 转换为 torch 类型名称。

返回类型:

Literal[‘布尔’,‘uint8_t’,‘int8_t’,‘double’,‘float’,‘half’,‘int’,‘int64_t’,‘int16_t’,‘complex32’,‘complex64’,‘complex128’,‘qint8’,‘quint8’,‘qint32’,‘bfloat16’,‘float8_e5m2’,‘float8_e4m3fn’,‘float8_e5m2fnuz’,‘float8_e4m3fnuz’]


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源