本周,我们正式发布了 PyTorch 1.1 版本,这是 PyTorch 1.0 的一个大型功能更新。我们新增的一个新特性是,通过 TorchScript(PyTorch JIT)对快速、自定义循环神经网络(fastrnns)提供了更好的支持(https://maskerprc.github.io/docs/stable/jit.html)。
RNNs 是流行的模型,在各种形状和大小的 NLP 任务上表现出良好的性能。PyTorch 实现了其中一些最受欢迎的模型,包括 Elman RNN、GRU 和 LSTM,以及多层和双向变体。
然而,许多用户希望实现自己的自定义 RNN,从近期文献中汲取灵感。将层归一化应用于 LSTM 就是一个这样的用例。由于 PyTorch CUDA LSTM 实现使用融合内核,因此很难插入归一化或修改基本 LSTM 实现。许多用户已经转向使用标准 PyTorch 运算符编写自定义实现,但这种代码存在高开销:大多数 PyTorch 运算至少在 GPU 上启动一个内核,而 RNN 由于其递归性质通常运行许多运算。然而,我们可以应用 TorchScript 来融合运算并自动优化我们的代码,在 GPU 上启动更少、更优化的内核。
我们的目标是让用户能够在不编写专门的 CUDA 内核以实现类似性能的情况下,使用 TorchScript 编写快速的自定义 RNN。在这篇文章中,我们将提供如何使用 TorchScript 编写自己的快速 RNN 的教程。为了更好地理解 TorchScript 应用的优化,我们将检查这些优化在标准 LSTM 实现上的工作方式,但大多数优化可以应用于一般的 RNN。
编写自定义 RNN
要开始,您可以使用此文件作为模板来编写您自己的自定义 RNN。
我们正在不断改进我们的基础设施,以使性能更好。如果您想获得 TorchScript 当前提供的速度/优化(如算子融合、批量矩阵乘法等),请遵循以下指南。下一节将深入解释优化。
-
如果自定义操作都是逐元素操作,那就太好了,因为您可以自动获得 PyTorch JIT 的算子融合的好处!
-
如果您有更复杂的操作(例如,将 reduce 操作与逐元素操作混合),请考虑将 reduce 操作和逐元素操作分别分组,以便将逐元素操作融合到一个单独的融合组中。
-
如果你想了解你的自定义 RNN 中融合了什么,你可以通过使用
graph_for
来检查操作的优化图。以下是一个使用LSTMCell
的例子:# get inputs and states for LSTMCell inputs = get_lstm_inputs() # instantiate a ScriptModule cell = LSTMCell(input_size, hidden_size) # print the optimized graph using graph_for out = cell(inputs) print(cell.graph_for(inputs))
这将生成针对你提供的特定输入的优化 TorchScript 图(即 PyTorch JIT IR):
graph(%x : Float(*, *), %hx : Float(*, *), %cx : Float(*, *), %w_ih : Float(*, *), %w_hh : Float(*, *), %b_ih : Float(*), %b_hh : Float(*)): %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih) %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy) return (%30) with prim::DifferentiableGraph_0 = graph(%13 : Float(*, *), %29 : Float(*), %33 : Float(*), %40 : Float(*, *), %43 : Float(*, *), %45 : Float(*, *), %48 : Float(*, *)): %49 : Float(*, *) = aten::t(%48) %47 : Float(*, *) = aten::mm(%45, %49) %44 : Float(*, *) = aten::t(%43) %42 : Float(*, *) = aten::mm(%40, %44) ...some broadcast sizes operations... %hy : Float(*, *), %287 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%13, %346, %345, %344, %343) ...some broadcast sizes operations... return (%hy, %cy, %49, %44, %196, %199, %340, %192, %325, %185, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %395, %396, %287) with prim::FusionGroup_0 = graph(%13 : Float(*, *), %71 : Tensor, %76 : Tensor, %81 : Tensor, %86 : Tensor): ...some chunks, constants, and add operations... %ingate.1 : Float(*, *) = aten::sigmoid(%38) %forgetgate.1 : Float(*, *) = aten::sigmoid(%34) %cellgate.1 : Float(*, *) = aten::tanh(%30) %outgate.1 : Float(*, *) = aten::sigmoid(%26) %14 : Float(*, *) = aten::mul(%forgetgate.1, %13) %11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) %cy : Float(*, *) = aten::add(%14, %11, %69) %4 : Float(*, *) = aten::tanh(%cy) %hy : Float(*, *) = aten::mul(%outgate.1, %4) return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1)
从上面的图中我们可以看到,它有一个 prim::FusionGroup_0
子图,该子图融合了 LSTMCell 中的所有逐元素操作(转置和矩阵乘法不是逐元素操作)。一些图节点可能一开始就难以理解,但在优化部分我们将解释其中的一些,我们还在本文中省略了一些只是为了正确性而存在的冗长操作符。
变长序列的最佳实践
TorchScript 不支持 PackedSequence。通常,当处理可变长度的序列时,最好将它们填充到单个张量中,然后将该张量通过 TorchScript LSTM 传递。以下是一个示例:
sequences = [...] # List[Tensor], each Tensor is T' x C
padded = torch.utils.rnn.pad_sequence(sequences)
lengths = [seq.size(0) for seq in sequences]
padded # T x N x C, where N is batch size and T is the max of all T'
model = LSTM(...)
output, hiddens = model(padded)
output # T x N x C
当然, output
可能会在填充区域有一些垃圾数据;使用 lengths
来跟踪你不需要的部分。
优化
现在我们将解释 PyTorch JIT 执行的优化以提高自定义 RNN 的速度。我们将使用一个简单的自定义 LSTM 模型在 TorchScript 中说明优化,但其中许多优化是通用的,适用于其他 RNN。
为了说明我们进行的优化以及我们从这些优化中获得的好处,我们将运行一个简单的自定义 LSTM 模型,该模型是用 TorchScript 编写的(您可以在 custom_lstm.py 或下面的代码片段中查看代码),并计时我们的更改。
我们在一台配备 2 个英特尔至强芯片和一块 Nvidia P100 显卡的机器上搭建了环境,并安装了 cuDNN v7.3 和 CUDA 9.2。LSTM 模型的初步设置如下:
input_size = 512
hidden_size = 512
mini_batch = 64
numLayers = 1
seq_length = 100
PyTorch JIT 最重要的功能是将 Python 程序编译成 PyTorch JIT IR,这是一种中间表示,用于表示程序的图结构。这种 IR 可以受益于整个程序优化、硬件加速,并且总体上具有提供大量计算增益的潜力。在这个例子中,我们运行了初始的 TorchScript 模型,只使用了 JIT 提供的编译器优化传递,包括常见子表达式消除、常量池化、常量传播、死代码消除和一些窥孔优化。在预热后,我们对模型进行了 100 次训练,并平均了训练时间。模型前向时间大约为 27ms,后向时间大约为 64ms,这与 PyTorch cuDNN LSTM 提供的结果有一定差距。接下来,我们将解释我们在训练或推理性能提升方面所做的主要优化,从 LSTMCell 和 LSTMLayer 开始,以及一些其他优化。
LSTM 单元(正向)
LSTM 中的几乎所有计算都在 LSTMCell 中发生,因此了解它包含的计算以及如何提高其速度对我们来说非常重要。以下是 TorchScript 中的 LSTMCell 示例实现:
class LSTMCell(jit.ScriptModule):
def __init__(self, input_size, hidden_size):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
self.bias_ih = Parameter(torch.randn(4 * hidden_size))
self.bias_hh = Parameter(torch.randn(4 * hidden_size))
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
hx, cx = state
gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(cy)
return hy, (hy, cy)
TorchScript 生成的此图表示(IR)可以启用多种优化和可扩展的计算。除了我们通常可以做的编译器优化(如循环展开、常量传播等)之外,我们还可以运行其他 IR 转换来使我们的代码运行得更快。
- 元素级算子融合。PyTorch JIT 会自动融合元素级操作,因此当您有相邻的元素级操作时,JIT 会自动将这些操作组合成一个 FusionGroup,然后这个 FusionGroup 可以由单个 GPU/CPU 内核启动并一次性执行。这避免了每个操作昂贵的内存读写。
- 重新排列块和逐点操作以实现更多融合。LSTM 单元将门一起添加(逐点操作),然后将门分成四部分:ifco 门。然后,它对 ifco 门执行逐点操作,如上所述。这在实践中导致两个融合组:一个用于预块逐元素操作的融合组,一个用于后块逐元素操作的融合组。这里有趣的是要注意的是,逐点操作与
torch.chunk
交换:我们可以在某些输入张量上执行逐点操作并对输出进行分块,而不是在输出张量上执行相同的逐点操作并分块输入张量。通过将分块移动到第一个融合组之前,我们可以将第一个和第二个融合组合并成一个大的组。

- CPU 上创建张量代价高昂,但正在进行工作以使其更快。目前,一个 LSTMCell 运行三个 CUDA 内核:两个
gemm
内核和一个用于单个 pointwise 组的内核。我们注意到的一个问题是,第二个gemm
的完成和单个 pointwise 组的开始之间存在很大的差距。这个差距是 GPU 空闲且未执行任何操作的一段时间。进一步调查后,我们发现问题是torch.chunk
构建新的张量,而张量构建并不像可能的那样快。我们不是构建新的 Tensor 对象,而是教会了融合编译器如何操作数据指针和步长,在将其发送到融合内核之前完成torch.chunk
,从而减少了第二个 gemm 和元素级融合组启动之间的空闲时间。这使我们 LSTM 正向传递的速度提高了约 1.2 倍。
通过上述技巧,我们能够将几乎所有的 LSTMCell
前向图(除了两个 gemm 内核)融合成一个单一的融合组,这对应于上述 IR 图中的 prim::FusionGroup_0
。然后它将被启动为一个单一的融合内核进行执行。通过这些优化,模型性能显著提高,平均前向时间减少了约 17ms(1.7 倍速度提升)至 10ms,平均反向时间减少了 37ms 至 27ms(1.37 倍速度提升)。
LSTM 层(前向)
class LSTMLayer(jit.ScriptModule):
def __init__(self, cell, *cell_args):
super(LSTMLayer, self).__init__()
self.cell = cell(*cell_args)
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
inputs = input.unbind(0)
outputs = torch.jit.annotate(List[Tensor], [])
for i in range(len(inputs)):
out, state = self.cell(inputs[i], state)
outputs += [out]
return torch.stack(outputs), state
我们对为 TorchScript LSTM 生成的 IR 进行了多项技巧优化,以下是一些示例优化:
- 循环展开:我们自动展开代码中的循环(对于大循环,我们展开其中一小部分),这使我们能够进一步优化 for 循环的控制流。例如,fuser 可以将循环体迭代之间的操作融合在一起,这对于控制流密集型模型如 LSTM 来说,可以带来良好的性能提升。
- 批量矩阵乘法:对于输入已预乘的 RNN(即模型中有大量相同的 LHS 或 RHS 的矩阵乘法),我们可以将这些操作有效地批量组合成一个单一的矩阵乘法,同时将输出分块以实现等效语义。
通过应用这些技术,我们在正向传播中额外减少了 1.6ms 到 8.4ms(1.2 倍速度提升),反向传播中减少了 7ms 到大约 20ms(1.35 倍速度提升)。
LSTM 层(反向)
-
“树”批量矩阵乘法:在 LSTM 反向图中,单个权重通常被多次重用,形成一个叶子为矩阵乘法、节点为加法的树。这些节点可以通过在不同维度上连接 LHS 和 RHS 来组合在一起,然后作为一个单一的矩阵乘法来计算。等效公式的表示如下:
$L1 * R1 + L2 * R2 = torch.cat((L1, L2), dim=1) * torch.cat((R1, R2), dim=0)$
-
自动微分是 PyTorch 成为优雅的机器学习框架的关键组成部分。因此,我们在 PyTorch JIT 中也采用了这种思想,但使用了一种新的在 IR(中间表示)级别上工作的自动微分(AD)机制。JIT 自动微分将正向图切割成符号可微分的子图,并为这些子图生成反向节点。以上述 IR 为例,我们将图节点分组为单个
prim::DifferentiableGraph_0
,用于具有 AD 公式的操作。对于尚未添加到 AD 公式中的操作,在执行过程中将回退到 Autograd。 -
优化反向路径很困难,隐式广播语义使得自动微分优化更加困难。PyTorch 通过为你广播张量,使得编写张量操作变得方便,无需担心形状。在反向传播中,性能的痛点在于我们需要对这种可广播的操作进行求和。这导致每个可广播操作的导数后面都跟着一个求和操作。由于我们目前无法融合 reduce 操作,这导致 FusionGroups 分解成多个小组,导致性能下降。要解决这个问题,请参考托马斯·维曼姆(Thomas Viehmann)写的这篇优秀的文章。
其他优化
- 除了上述步骤之外,我们还消除了 CUDA 内核启动之间的开销以及不必要的张量分配。一个例子是在你进行张量设备查找时。这可能会在初始阶段提供一些性能较差的结果,伴随着大量的不必要的分配。当我们移除这些操作后,内核启动之间的时间从毫秒减少到纳秒。
- 最后,在自定义 LSTMCell 中可能应用了归一化,例如 LayerNorm。由于 LayerNorm 和其他归一化操作包含 reduce 操作,将其完全融合比较困难。因此,我们自动将 LayerNorm 分解为统计计算(reduce 操作)+逐元素变换,然后将这些逐元素部分融合在一起。截至本文发布时,我们的自动微分和图融合基础设施存在一些限制,这限制了当前的支持仅限于推理模式。我们计划在未来的版本中添加反向支持。
通过上述操作融合、循环展开、批量矩阵乘法和一些其他优化,我们可以从以下图中清楚地看到我们自定义的 TorchScript LSTM 正向和反向的性能提升:

本文未涵盖的许多额外优化。除了本文中提到的优化之外,我们现在看到我们的自定义 LSTM 正向传播与 cuDNN 相当。我们还在优化反向传播,并期待在未来版本中看到改进。除了 TorchScript 提供的速度外,我们还引入了一个更加灵活的 API,使您能够手动设计更多自定义 RNN,这是 cuDNN 无法提供的。