JitScalarType¶
- class torch.onnx.JitScalarType(value, names=<未指定>, *values, module=None, qualname=None, type=None, start=1, boundary=None) ¶
PyTorch 中定义的标量类型。
使用
JitScalarType
将 PyTorch 和 JIT 标量类型转换为 ONNX 标量类型。示例
>>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type() TensorProtoDataType.FLOAT
>>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type() TensorProtoDataType.FLOAT
>>> JitScalarType.from_dtype(torch.get_default_dtype).onnx_type() TensorProtoDataType.FLOAT
- dtype()[来源][来源]
将 JitScalarType 转换为 PyTorch 数据类型。
- 返回类型:
- classmethod from_dtype(dtype)[source][source]¶
将 torch 数据类型转换为 JitScalarType。
- 注意:当 dtype 来自 torch._C.Value.type()调用时,请不要使用此 API。
在形状信息不存在的情况下,可能会在以下几种场景中引发“RuntimeError: INTERNAL ASSERT FAILED at “../aten/src/ATen/core/jit_type_base.h”错误。请改用 from_value API,它更安全。
- 参数:
dtype (torch.dtype | None) – 用于创建 JitScalarType 的 torch.dtype
- 返回值:
JitScalarType
- 引发:
OnnxExporterError – 如果 dtype 不是一个有效的 torch.dtype 或者为 None。
- 返回类型:
- from_onnx_type(onnx_type)[source][source] ¶
将 ONNX 数据类型转换为 JitScalarType。
- 参数:
onnx_type (int | torch._C._onnx.TensorProtoDataType | None) – 要创建 JitScalarType 的 torch._C._onnx.TensorProtoDataType
- 返回值:
JitScalarType
- 引发:
OnnxExporterError – 如果 dtype 不是有效的 torch.dtype 或者为 None。
- 返回类型:
- classmethod from_value(value, default=None)[source][source]¶
从值的标量类型创建一个 JitScalarType。
- 参数:
从对象中获取标量类型的对象(None | torch.Value | torch.Tensor)
默认值 - 如果无法从 value 中获取有效的标量类型,则返回的 JitScalarType
- 返回值:
JitScalarType.
- 引发:
OnnxExporterError - 如果 value 没有有效的标量类型且默认值为 None。
符号值错误 – 当 value.type() 的信息为空且默认值为 None 时
- 返回类型:
- onnx_type()[来源][来源] ¶
将 JitScalarType 转换为 ONNX 数据类型。
- 返回类型:
TensorProtoDataType
- scalar_name()[来源][来源] ¶
将 JitScalarType 转换为 JIT 标量类型名称。
- 返回类型:
Literal[‘字节’,‘字符’,‘双精度’,‘浮点’,‘半精度’,‘整型’,‘长整型’,‘短整型’,‘布尔’,‘复数半精度’,‘复数浮点’,‘复数双精度’,‘QInt8’,‘QUInt8’,‘QInt32’,‘BFloat16’,‘Float8E5M2’,‘Float8E4M3FN’,‘Float8E5M2FNUZ’,‘Float8E4M3FNUZ’,‘未定义’]
- torch_name()[来源][来源]
将 JitScalarType 转换为 torch 类型名称。
- 返回类型:
Literal[‘布尔’,‘uint8_t’,‘int8_t’,‘double’,‘float’,‘half’,‘int’,‘int64_t’,‘int16_t’,‘complex32’,‘complex64’,‘complex128’,‘qint8’,‘quint8’,‘qint32’,‘bfloat16’,‘float8_e5m2’,‘float8_e4m3fn’,‘float8_e5m2fnuz’,‘float8_e4m3fnuz’]