快速傅里叶变换(FFT)以 O(n log n)的时间复杂度计算离散傅里叶变换。由于它使得在信号的“频域”中工作与在空间或时间域中工作一样容易,因此它是众多数值算法和信号处理技术的基石。
作为 PyTorch 支持硬件加速深度学习和科学计算的目标的一部分,我们投资于改进我们的 FFT 支持,并在 PyTorch 1.8 版本中发布了 torch.fft
模块。此模块实现了与 NumPy 的 np.fft
模块相同的函数,但支持加速器,如 GPU,和自动微分。
开始使用
无论您是否熟悉 NumPy 的 np.fft
模块,开始使用新的 torch.fft
模块都很简单。虽然该模块中每个函数的完整文档可以在此处找到,但以下是对其提供的功能的概述:
-
fft
,它在一个维度上计算复杂的 FFT,以及其逆变换ifft
- 更通用的
fftn
和ifftn
,支持多维度 - “实数”FFT 函数,
rfft
,irfft
,rfftn
,irfftn
,旨在处理时域中值为实数的信号 - “厄米特”FFT 函数,
hfft
和ihfft
,旨在处理频域中值为实数的信号 - 辅助函数,如
fftfreq
、rfftfreq
、fftshift
、ifftshift
,使信号操作更加便捷
我们认为这些函数提供了一个直观的 FFT 功能接口,这已经得到了 NumPy 社区的验证,尽管我们始终欢迎反馈和建议!
为了更好地说明从 NumPy 的 np.fft
模块迁移到 PyTorch 的 torch.fft
模块有多简单,让我们来看一个 NumPy 实现的简单低通滤波器,该滤波器可以从二维图像中移除高频方差,这是一种降噪或模糊的形式:
import numpy as np
import numpy.fft as fft
def lowpass_np(input, limit):
pass1 = np.abs(fft.rfftfreq(input.shape[-1])) < limit
pass2 = np.abs(fft.fftfreq(input.shape[-2])) < limit
kernel = np.outer(pass2, pass1)
fft_input = fft.rfft2(input)
return fft.irfft2(fft_input * kernel, s=input.shape[-2:])
现在让我们看看同样的滤波器在 PyTorch 中的实现:
import torch
import torch.fft as fft
def lowpass_torch(input, limit):
pass1 = torch.abs(fft.rfftfreq(input.shape[-1])) < limit
pass2 = torch.abs(fft.fftfreq(input.shape[-2])) < limit
kernel = torch.outer(pass2, pass1)
fft_input = fft.rfft2(input)
return fft.irfft2(fft_input * kernel, s=input.shape[-2:])
不仅 NumPy 的 np.fft
模块的当前用法可以直接转换为 torch.fft
,而且 torch.fft
操作也支持在加速器(如 GPU 和 autograd)上的张量。这使得(其他事物中)使用 FFT 开发新的神经网络模块成为可能。
性能
torch.fft
模块不仅易于使用,而且速度也非常快!PyTorch 原生支持 Intel 的 MKL-FFT 库在 Intel CPU 上,以及 NVIDIA 的 cuFFT 库在 CUDA 设备上,我们已仔细优化了使用这些库的方式,以最大化性能。虽然您的结果将取决于您的 CPU 和 CUDA 硬件,但在 CUDA 设备上计算快速傅里叶变换(FFT)可能比在 CPU 上快得多,尤其是在处理较大信号时。
未来,我们可能会添加对其他数学库的支持,以支持更多硬件。有关如何请求额外硬件支持的信息,请参阅以下内容。
更新到较旧的 PyTorch 版本
一些 PyTorch 用户可能知道,PyTorch 的旧版本也提供了 FFT 功能,使用 torch.fft()
函数。不幸的是,由于该函数的名称与新模块的名称冲突,不得不将其删除,我们认为新的功能是使用 PyTorch 中的快速傅里叶变换的最佳方式。特别是, torch.fft()
是在 PyTorch 支持复数张量之前开发的,而 torch.fft
模块是为了与它们一起工作而设计的。
PyTorch 还有一个“短时傅里叶变换”, torch.stft
,及其逆变换 torch.istft
。这些函数将被保留,但将更新以支持复数张量。
未来
如前所述,PyTorch 1.8 提供了 torch.fft 模块,这使得在加速器和支持 autograd 的情况下使用快速傅里叶变换(FFT)变得容易。我们鼓励您尝试一下!
虽然这个模块到目前为止是模仿 NumPy 的 np.fft
模块,但我们不会就此止步。我们渴望听到来自您,我们社区的声音,关于您需要的 FFT 相关功能,我们鼓励您在我们的论坛 https://discuss.pytorch.org/上创建帖子,或者在 Github 上提交反馈和请求的问题。早期用户已经开始询问离散余弦变换以及更多硬件平台的支持,例如,我们目前正在调查这些功能。
我们期待听到您的反馈,并看到社区如何使用 PyTorch 的新 FFT 功能!