快捷键

torch.linalg.matrix_power

torch.linalg.matrix_power(A, n, *, out=None) Tensor

计算整数 n 的方阵 n 次幂。

支持输入 float、double、cfloat 和 cdouble 数据类型。也支持矩阵批处理,如果 A 是矩阵批,则输出具有相同的批维度。

如果 n 等于 0,则返回与 A 相同形状的恒等矩阵(或批量)。如果 n 为负数,则返回每个矩阵(如果可逆)的绝对值 n 次幂的逆矩阵。

注意

如果可能,请考虑使用 torch.linalg.solve() 来对矩阵左乘以负幂,因为如果 n > 0:

torch.linalg.solve(matrix_power(A, n), B) == matrix_power(A, -n)  @ B

在可能的情况下,始终推荐使用 solve() ,因为它比显式计算 AnA^{-n} 更快且更数值稳定。

参见

torch.linalg.solve() 使用数值稳定的算法计算 A .inverse() @ B

参数:
  • A(张量)- 形状为 (*, m, m) 的张量,其中 * 是零个或多个批处理维度。

  • n(整数)- 指数。

关键字参数:

out(张量,可选)- 输出张量。如果为 None 则忽略。默认:None。

引发:

RuntimeError - 如果 n 小于 0 并且矩阵 A 或矩阵批次 A 中的任何矩阵不可逆。

示例:

>>> A = torch.randn(3, 3)
>>> torch.linalg.matrix_power(A, 0)
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
>>> torch.linalg.matrix_power(A, 3)
tensor([[ 1.0756,  0.4980,  0.0100],
        [-1.6617,  1.4994, -1.9980],
        [-0.4509,  0.2731,  0.8001]])
>>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2)
tensor([[[ 0.2640,  0.4571, -0.5511],
        [-1.0163,  0.3491, -1.5292],
        [-0.4899,  0.0822,  0.2773]],
        [[ 0.2640,  0.4571, -0.5511],
        [-1.0163,  0.3491, -1.5292],
        [-0.4899,  0.0822,  0.2773]]])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源