• 教程 >
  • 剪枝教程
快捷键

剪枝教程 ¶

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

作者:米歇拉·帕甘尼尼

最先进的深度学习技术依赖于过度参数化的模型,这些模型难以部署。相反,生物神经网络被认为使用高效的稀疏连接。识别通过减少模型中的参数数量来压缩模型的最优技术对于减少内存、电池和硬件消耗至关重要,同时不牺牲准确性。这反过来又允许您在设备上部署轻量级模型,并保证通过在设备上进行的私有计算来保证隐私。在研究方面,剪枝用于研究过度参数化和欠参数化网络之间的学习动态差异,研究幸运稀疏子网络和初始化(“彩票”)作为破坏性神经架构搜索技术的角色,以及更多。

在本教程中,您将学习如何使用 torch.nn.utils.prune 来稀疏化您的神经网络,以及如何扩展它以实现您自己的自定义剪枝技术。

需求

"torch>=1.4.0a0+8e8a5e0"

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

创建一个模型

在本教程中,我们使用了 LeCun 等人于 1998 年的 LeNet 架构。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

检查一个模块

让我们来检查我们的 LeNet 模型中的(未修剪的) conv1 层。它将包含两个参数 weightbias ,目前没有缓冲区。

module = model.conv1
print(list(module.named_parameters()))
print(list(module.named_buffers()))

剪枝模块

要剪枝一个模块(例如,我们的 LeNet 架构中的 conv1 层),首先在 torch.nn.utils.prune 中(或通过继承 BasePruningMethod 实现自己的)选择一个剪枝技术。然后,指定模块以及该模块中要剪枝的参数名称。最后,使用所选剪枝技术所需的适当关键字参数指定剪枝参数。

在本例中,我们将随机剪枝 conv1 层中名为 weight 的参数的 30%的连接。模块作为函数的第一个参数传递; name 使用其字符串标识符标识该模块内的参数; amount 表示要剪枝的连接的百分比(如果它是一个介于 0.和 1.之间的浮点数),或者要剪枝的连接的绝对数量(如果它是一个非负整数)。

prune.random_unstructured(module, name="weight", amount=0.3)

剪枝操作通过从参数中移除 weight 并替换为一个新的参数 weight_orig (即在初始参数 name 后追加 "_orig" )来实现。 weight_orig 存储未剪枝的张量版本。 bias 未被剪枝,因此将保持完整。

print(list(module.named_parameters()))

上文所述的剪枝技术生成的剪枝掩码被保存为名为 weight_mask 的模块缓冲区(即在初始参数 name 后追加 "_mask" )。

print(list(module.named_buffers()))

为了使前向传播无需修改即可工作,需要存在 weight 属性。在 torch.nn.utils.prune 中实现的剪枝技术计算剪枝后的权重(通过将掩码与原始参数组合)并将它们存储在属性 weight 中。注意,这不再是 module 的参数,而是一个简单的属性。

print(module.weight)

最后,在每次前向传播之前使用 PyTorch 的 forward_pre_hooks 应用剪枝。具体来说,当 module 被剪枝时,如我们在此处所做的那样,它将为每个与之关联的剪枝参数获得一个 forward_pre_hook 。在这种情况下,由于我们迄今为止只剪枝了名为 weight 的原始参数,因此只有一个钩子存在。

print(module._forward_pre_hooks)

为了完整性,我们现在可以修剪 bias ,看看 module 的参数、缓冲区、钩子和属性如何变化。仅为了尝试另一种修剪技术,这里我们通过 l1_unstructured 修剪函数对偏差中的 3 个最小条目按 L1 范数进行修剪。

prune.l1_unstructured(module, name="bias", amount=3)

我们现在期望命名参数包括 weight_orig (之前的)和 bias_orig 。缓冲区将包括 weight_maskbias_mask 。两个张量的修剪版本将作为模块属性存在,并且模块现在将有两个 forward_pre_hooks

print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(module.bias)
print(module._forward_pre_hooks)

迭代修剪

模块中的相同参数可以被修剪多次,各种修剪调用的效果等于各种掩码按顺序应用的组合。新掩码与旧掩码的组合由 PruningContainercompute_mask 方法处理。

例如,现在我们想要进一步剪枝 module.weight ,这次使用结构化剪枝沿张量的 0 轴(0 轴对应于卷积层的输出通道, conv1 的维度为 6),基于通道的 L2 范数。这可以通过使用 ln_structured 函数,以及 n=2dim=0 来实现。

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)

相应的钩子现在将是类型 torch.nn.utils.prune.PruningContainer ,并将存储应用于 weight 参数的剪枝历史。

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

剪枝模型的序列化

所有相关张量,包括掩码缓冲区和用于计算剪枝张量的原始参数,都存储在模型的 state_dict 中,因此可以轻松序列化和保存,如果需要的话。

print(model.state_dict().keys())

移除剪枝重新参数化 ¶

要使剪枝永久化,请移除关于 weight_origweight_mask 的重新参数化,并移除 forward_pre_hook ,我们可以使用来自 torch.nn.utils.pruneremove 功能。请注意,这并不会撤销剪枝,就像它从未发生一样。它只是通过将参数 weight 重新分配给模型的剪枝版本,使其永久化。

在移除重新参数化之前:

print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(module.weight)

移除重新参数化之后:

prune.remove(module, 'weight')
print(list(module.named_parameters()))
print(list(module.named_buffers()))

模型中剪枝多个参数

通过指定所需的剪枝技术和参数,我们可以轻松地剪枝网络中的多个张量,例如根据它们的类型进行剪枝,正如我们将在本例中看到的那样。

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

全局剪枝

到目前为止,我们只看了通常所说的“局部”剪枝,即逐个剪枝模型中的张量的做法,即通过将每个条目的统计信息(权重幅度、激活、梯度等)仅与该张量中的其他条目进行比较。然而,一种常见且可能更强大的技术是一次性剪枝整个模型,例如通过移除整个模型中(例如)最低的 20%的连接,而不是移除每个层中最低的 20%的连接。这可能导致每层的剪枝百分比不同。让我们看看如何使用 global_unstructuredtorch.nn.utils.prune 中做到这一点。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

现在我们可以检查每个剪枝参数引起的稀疏性,这不会在每个层中等于 20%。然而,全局稀疏性将(大约)为 20%。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

扩展 torch.nn.utils.prune 使用自定义剪枝函数

要实现自己的剪枝函数,可以通过扩展 nn.utils.prune 模块,通过子类化 BasePruningMethod 基类来实现,就像所有其他剪枝方法一样。基类为您实现了以下方法: __call__apply_maskapplyprune ,和 remove 。除了某些特殊情况外,您不需要重新实现这些方法来为您的新的剪枝技术。但是,您将必须实现 __init__ (构造函数)和 compute_mask (根据您的剪枝技术逻辑计算给定张量掩码的说明)。此外,您还必须指定此技术实现哪种类型的剪枝(支持的选项是 globalstructuredunstructured )。这是为了确定在迭代应用剪枝的情况下如何组合掩码。换句话说,当剪枝预剪枝参数时,当前剪枝技术应作用于参数的未剪枝部分。指定 PRUNING_TYPE 将使 PruningContainer (处理剪枝掩码的迭代应用)能够正确识别要剪枝的参数部分。

假设,例如,你想实现一种剪枝技术,该技术剪除张量中的每个其他条目(或者如果张量之前已经被剪枝,则在剩余未剪枝的部分)。这将是由于它作用于层的单个连接,而不是整个单元/通道( 'structured' ),或者跨不同参数( 'global' )。

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

现在,要将此应用于一个 nn.Module 中的参数,你还应该提供一个简单的函数,该函数实例化该方法并将其应用。

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

让我们试试!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)

脚本总运行时间:(0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的画廊


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源