torch.set_float32_matmul_precision¶
- torch.set_float32_matmul_precision(precision)[source][source]¶
设置 float32 矩阵乘法的内部精度。
在低精度下运行 float32 矩阵乘法可以显著提高性能,在某些程序中,精度损失的影响可以忽略不计。
支持三种设置:
“最高”,float32 矩阵乘法使用 float32 数据类型(24 位尾数,其中 23 位显式存储)进行内部计算。
“高”,float32 矩阵乘法要么使用 TensorFloat32 数据类型(10 位尾数显式存储),要么如果可用适当的快速矩阵乘法算法,将每个 float32 数字视为两个 bfloat16 数字的和(大约 16 位尾数,其中 14 位显式存储)。否则,float32 矩阵乘法将像“最高”精度一样计算。有关 bfloat16 方法的更多信息,请参阅下文。
“medium”,使用 float32 矩阵乘法时,如果内部可用使用 bfloat16 数据类型(8 位尾数,其中 7 位显式存储)进行计算,则使用该数据类型进行内部计算。否则,float32 矩阵乘法将按“高精度”计算。
使用“高精度”时,float32 乘法可能使用基于 bfloat16 的算法,该算法比简单地截断到一些较小的数尾数位(例如 TensorFloat32 的 10 位,bfloat16 显式存储的 7 位)更复杂。有关此算法的完整描述,请参阅[Henry2019]。在此简要解释,第一步是意识到我们可以完美地将单个 float32 数字编码为三个 bfloat16 数字之和(因为 float32 有 23 位尾数位,而 bfloat16 有 7 位显式存储,并且两者都有相同数量的指数位)。这意味着两个 float32 数字的乘积可以精确地由九个 bfloat16 数字乘积之和给出。然后我们可以通过省略一些这些乘积来以速度换取精度。具体来说,“高精度”算法仅保留三个最重要的乘积,这方便地排除了涉及任一输入最后 8 位尾数的所有乘积。这意味着我们可以将我们的输入表示为两个 bfloat16 数字之和,而不是三个。 因为 bfloat16 融合乘加(FMA)指令通常比 float32 指令快 10 倍以上,所以使用 bfloat16 精度进行三次乘法和两次加法比使用 float32 精度进行一次乘法要快。
[Henry2019]注意
这不会改变 float32 矩阵乘法的输出数据类型,它控制矩阵乘法内部计算的方式。
注意
这不会改变卷积操作的精度。其他标志,如 torch.backends.cudnn.allow_tf32,可能控制卷积操作的精度。
注意
此标志目前仅影响一种本地设备类型:CUDA。如果设置为“高”或“中”,则在计算 float32 矩阵乘法时将使用 TensorFloat32 数据类型,相当于设置 torch.backends.cuda.matmul.allow_tf32 = True。当设置为“最高”(默认值)时,内部计算将使用 float32 数据类型,相当于设置 torch.backends.cuda.matmul.allow_tf32 = False。
- 参数:
精度(字符串)- 可以设置为“最高”(默认值)、“高”或“中”(见上文)。