• 教程 >
  • 模型集成
快捷键

模型集成 ¶

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

本教程演示了如何使用 torch.vmap 向量化模型集成。

什么是模型集成?

模型集成是将多个模型的预测结果组合在一起。传统上,这是通过分别对一些输入运行每个模型,然后将预测结果组合在一起来完成的。然而,如果您正在运行具有相同架构的模型,那么可能可以通过使用 torch.vmap 来将它们组合在一起。 vmap 是一个函数转换,它将函数映射到输入张量的维度上。其中一个用例是消除 for 循环并通过向量化加速它们。

让我们通过一个简单的 MLP 集成来演示如何做到这一点。

备注

本教程需要 PyTorch 2.0.0 或更高版本。

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple MLP
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

让我们生成一批虚拟数据,假装我们在处理 MNIST 数据集。因此,虚拟图像是 28x28 的,我们有一个大小为 64 的 minibatch。此外,假设我们想结合 10 个不同模型的预测。

device = 'cuda'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

我们有几种生成预测的方法。也许我们想给每个模型不同的随机 minibatch 数据。或者,也许我们想将相同的 minibatch 数据通过每个模型(例如,如果我们正在测试不同模型初始化的影响)。

选项 1:每个模型不同的 minibatch

minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

选项 2:相同的 minibatch

minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

使用 vmap 对集成进行向量化 ¶

让我们使用 vmap 来加速 for-loop。我们首先必须使用 vmap 准备模型。

首先,让我们通过堆叠每个参数来组合模型的各个状态。例如, model[i].fc1.weight 的形状为 [784, 128] ;我们将堆叠 10 个模型中的每个 .fc1.weight ,以生成一个形状为 [10, 784, 128] 的大权重。

PyTorch 提供了 torch.func.stack_module_state 便利函数来完成此操作。

from torch.func import stack_module_state

params, buffers = stack_module_state(models)

接下来,我们需要定义一个函数来 vmap 。该函数应该使用给定的参数、缓冲区和输入运行模型。我们将使用 torch.func.functional_call 来帮忙:

from torch.func import functional_call
import copy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

选项 1:使用每个模型的不同的 minibatch 获取预测。

默认情况下, vmap 将函数映射到所有输入的第一个维度上。使用 stack_module_state 之后,每个 params 和缓冲区前面都会增加一个大小为‘num_models’的维度,minibatch 的维度大小为‘num_models’。

print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the ``vmap`` predictions match the
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)

选项 2:使用相同的数据 minibatch 获取预测。

vmap 具有一个 in_dims 参数,用于指定映射的维度。通过使用 None ,我们告诉 vmap 我们希望所有 10 个模型都应用相同的 mini-batch。

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)

简要说明: vmap 可以转换的函数类型存在一些限制。最适合转换的函数是纯函数:输出仅由输入决定,没有副作用(例如,修改)。 vmap 无法处理任意 Python 数据结构的修改,但它能够处理许多 PyTorch 的就地操作。

性能 §

对性能数字感到好奇?以下是数字的显示方式。

from torch.utils.benchmark import Timer
without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals())
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')

使用 vmap !可以大幅提高速度

通常情况下,使用 vmap 进行向量化应该比使用 for 循环运行函数更快,并且与手动批处理具有竞争力。尽管如此,也有一些例外,比如如果我们还没有为特定操作实现 vmap 规则,或者底层内核没有针对旧硬件(GPU)进行优化。如果您遇到这些情况中的任何一种,请通过在 GitHub 上创建问题来告知我们。

脚本总运行时间:(0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的画廊


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源