• 文档 >
  • 复数
快捷键

复数 ¶

复数是能够表示为 a 和 b 为实数,j 为虚数单位的形式 a + bi 的数,其中虚数单位满足方程 j^2 = -1。复数在数学和工程学中经常出现,尤其是在信号处理等主题中。传统上,许多用户和库(例如 TorchAudio)通过将数据表示为形状为 (2,) 的浮点张量来处理复数,其中最后一个维度包含实部和虚部值。

复数类型的张量提供了更自然的用户体验,在处理复数时。对复数张量(例如,加法、乘法)的操作可能比模拟它们的浮点张量操作更快、更节省内存。PyTorch 中对复数的操作已优化,以使用向量化的汇编指令和专用内核(例如 LAPACK、cuBlas)。

注意

torch.fft 模块中的频谱操作支持原生复数张量。

警告

复数张量是一个测试功能,可能会发生变化。

创建复数张量

我们支持两种复数数据类型:torch.cfloat 和 torch.cdouble

>>> x = torch.randn(2,2, dtype=torch.cfloat)
>>> x
tensor([[-0.4621-0.0303j, -0.2438-0.5874j],
     [ 0.7706+0.1421j,  1.2110+0.1918j]])

注意

复杂张量的默认数据类型由默认的浮点数据类型决定。如果默认的浮点数据类型是 torch.float64,则复数推断为 torch.complex128 数据类型,否则假定复数的数据类型为 torch.complex64。

除了 torch.linspace()torch.logspace()torch.arange() 之外的所有工厂函数都支持复数张量。

从旧表示法的过渡 ¶

目前使用形状为 (...,2)(..., 2) 的实张量绕过复数张量缺失的用户可以轻松地使用代码中的复数张量进行切换,使用 torch.view_as_complex()torch.view_as_real() 。请注意,这些函数不执行任何复制,并返回输入张量的视图。

>>> x = torch.randn(3, 2)
>>> x
tensor([[ 0.6125, -0.1681],
     [-0.3773,  1.3487],
     [-0.0861, -0.7981]])
>>> y = torch.view_as_complex(x)
>>> y
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
>>> torch.view_as_real(y)
tensor([[ 0.6125, -0.1681],
     [-0.3773,  1.3487],
     [-0.0861, -0.7981]])

访问实部和虚部 ¶

复张量的实部和虚部可以通过使用 realimag 来访问。

注意

访问实部和虚部属性不会分配任何内存,对实部和虚部张量的就地更新将更新原始复数张量。此外,返回的实部和虚部张量不是连续的。

>>> y.real
tensor([ 0.6125, -0.3773, -0.0861])
>>> y.imag
tensor([-0.1681,  1.3487, -0.7981])

>>> y.real.mul_(2)
tensor([ 1.2250, -0.7546, -0.1722])
>>> y
tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j])
>>> y.real.stride()
(2,)

角度和模长 ¶

复合张量的角度和绝对值可以使用 torch.angle()torch.abs() 进行计算。

>>> x1=torch.tensor([3j, 4+4j])
>>> x1.abs()
tensor([3.0000, 5.6569])
>>> x1.angle()
tensor([1.5708, 0.7854])

线性代数 ¶

许多线性代数操作,如 torch.matmul()torch.linalg.svd()torch.linalg.solve() 等,支持复数。如果您想请求我们目前不支持的操作,请先搜索是否已提交相关问题,如果没有,请提交一个问题。

序列化 ¶

复杂张量可以被序列化,允许将数据保存为复数值。

>>> torch.save(y, 'complex_tensor.pt')
>>> torch.load('complex_tensor.pt')
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])

自动微分 ¶

PyTorch 支持复数张量的自动微分。计算出的梯度是共轭 Wirtinger 导数,其负值正是梯度下降算法中使用的最速下降方向。因此,所有现有的优化器都可以直接用于复数参数。更多详情,请查看关于复数自动微分的笔记。

优化器 ¶

从语义上讲,我们将使用复数参数的 PyTorch 优化器进行迭代定义为与在复数参数的 torch.view_as_real() 等价形式上迭代相同的优化器等价。更具体地说:

>>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)]
>>> real_params = [torch.view_as_real(p) for p in params]

>>> complex_optim = torch.optim.AdamW(params)
>>> real_optim = torch.optim.AdamW(real_params)

real_optim 和 complex_optim 将在参数上计算相同的更新,尽管这两个优化器之间可能存在轻微的数值差异,类似于 foreach 和 forloop 优化器以及可捕获和默认优化器之间的数值差异。更多详情,请参阅 https://pytorch.org/docs/stable/notes/numerical_accuracy.html。

具体来说,虽然你可以将我们的优化器处理复杂张量的方式视为分别优化它们的 p.real 和 p.imag 部分,但实现细节并非完全如此。请注意, torch.view_as_real() 等效于将一个复数张量转换为形状为 (...,2)(..., 2) 的实数张量,而将复数张量拆分为两个张量则是 2 个大小为 (...)(...) 的张量。这种区别对点对点优化器(如 AdamW)没有影响,但会导致全局减少优化器(如 LBFGS)产生轻微的差异。我们目前没有执行逐张量减少的优化器,因此尚未定义此行为。如果您有需要精确定义此行为的用例,请提交一个问题。

我们不完全支持以下子系统:

  • 量化

  • JIT

  • 稀疏张量

  • 分布式

如果这些中的任何一项能帮助您的用例,请搜索是否已提交相关问题,如果没有,请提交一个问题。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源