torch.fft 模块:在 PyTorch 中使用 Autograd 加速快速傅立叶变换
共 2430字,需浏览 5分钟
·
2021-06-21 14:52
快速傅立叶变换 (FFT) 在 O(n log n) 时间内计算离散傅立叶变换。它是各种数值算法和信号处理技术的基础,因为它使在信号的“频域”中工作与在空间或时间域中工作一样容易处理。
作为 PyTorch 支持硬件加速深度学习和科学计算目标的一部分,我们已投资改进我们的 FFT 支持,并且在 PyTorch 1.8 中,我们将发布该torch.fft
模块。该模块实现与 NumPynp.fft
模块相同的功能,但支持加速器,如 GPU 和 autograd。
入门
torch.fft
无论您是否熟悉 NumPy 的np.fft
模块,新模块的入门都很容易。虽然可以在此处找到模块中每个功能的完整文档,但它提供的细分如下:
fft
,它计算一个单一维度的复数 FFT,以及ifft
,它的逆更通用的
fftn
andifftn
,支持多维“真实”FFT 函数,
rfft
,irfft
,rfftn
,irfftn
设计用于处理在其时域中为实值的信号“厄米” FFT 函数
hfft
和ihfft
,设计用于处理在其频域中为实值的信号辅助函数,例如
fftfreq
,rfftfreq
,fftshift
,ifftshift
,可以更轻松地操作信号
我们认为这些函数为 FFT 功能提供了一个简单的接口,正如 NumPy 社区所审查的那样,尽管我们总是对反馈和建议感兴趣!
为了更好地说明从 NumPy 的np.fft
模块转移到 PyTorch 的torch.fft
模块是多么容易,让我们看一个简单的低通滤波器的 NumPy 实现,它从二维图像中去除高频方差,一种降噪或模糊的形式:
import numpy as np
import numpy.fft as fft
def lowpass_np(input, limit):
pass1 = np.abs(fft.rfftfreq(input.shape[-1])) < limit
pass2 = np.abs(fft.fftfreq(input.shape[-2])) < limit
kernel = np.outer(pass2, pass1)
fft_input = fft.rfft2(input)
return fft.irfft2(fft_input * kernel, s=input.shape[-2:])
现在让我们看看在 PyTorch 中实现的相同过滤器:
import torch
import torch.fft as fft
def lowpass_torch(input, limit):
pass1 = torch.abs(fft.rfftfreq(input.shape[-1])) < limit
pass2 = torch.abs(fft.fftfreq(input.shape[-2])) < limit
kernel = torch.outer(pass2, pass1)
fft_input = fft.rfft2(input)
return fft.irfft2(fft_input * kernel, s=input.shape[-2:])
不仅当前使用的 NumPynp.fft
模块直接转换为torch.fft
,torch.fft
操作还支持加速器上的张量,如 GPU 和 autograd。这使得(除其他外)使用 FFT 开发新的神经网络模块成为可能。
表现
该torch.fft
模块不仅易于使用 - 速度也很快!PyTorch 在 Intel CPU 上原生支持 Intel 的 MKL-FFT 库,在 CUDA 设备上支持 NVIDIA 的 cuFFT 库,我们已经仔细优化了我们如何使用这些库来最大化性能。虽然您自己的结果将取决于您的 CPU 和 CUDA 硬件,但在 CUDA 设备上计算快速傅立叶变换可能比在 CPU 上计算快很多倍,尤其是对于较大的信号。
将来,我们可能会添加对其他数学库的支持,以支持更多硬件。请参阅下文,了解您可以在哪里请求额外的硬件支持。
从旧的 PYTORCH 版本更新
一些 PyTorch 用户可能知道旧版本的 PyTorch 也提供了带有该torch.fft()
功能的FFT功能。不幸的是,这个函数不得不被删除,因为它的名称与新模块的名称冲突,我们认为新功能是在 PyTorch 中使用快速傅立叶变换的最佳方式。特别是,它torch.fft()
是在 PyTorch 支持复杂张量之前开发的,而该torch.fft
模块旨在与它们一起工作。
PyTorch 也有一个“短时傅立叶变换” torch.stft
,以及它的逆torch.istft
. 这些函数被保留但更新以支持复杂的张量。
未来
如前所述,PyTorch 1.8 提供了 torch.fft 模块,这使得在加速器上使用快速傅立叶变换 (FFT) 并支持 autograd 变得容易。我们鼓励您尝试一下!
虽然np.fft
到目前为止这个模块是在 NumPy 的模块之后建模的,但我们并没有就此止步。我们渴望听到您和我们的社区关于您需要哪些 FFT 相关功能的意见,我们鼓励您在https://discuss.pytorch.org/的论坛上发帖,或在我们的 Github 上将问题与您的反馈和要求。例如,早期采用者已经开始询问离散余弦变换和对更多硬件平台的支持,我们现在正在研究这些功能。