torch.nn.utils.prune.ln_structured¶
- torch.nn.utils.prune.ln_structured(module, name, amount, n, dim, importance_scores=None)[source][source]¶
通过移除指定维度上 L
n
-范数最低的通道来修剪张量。通过移除指定数量的(当前未剪枝的)通道,并沿着指定维度具有最低 L
n
-范数的方向剪枝与参数名为name
的module
对应的张量。就地修改模块(并返回修改后的模块):添加一个名为
name+'_mask'
的命名缓冲区,对应于通过剪枝方法应用于参数name
的二进制掩码。将参数
name
替换为其剪枝版本,而原始(未剪枝)参数存储在名为name+'_orig'
的新参数中。
- 参数:
module (nn.Module) – 包含要剪枝的张量的模块
名称(str)- 在
module
中参数名称,剪枝将在此参数上执行。数量(整数或浮点数)- 要剪枝的参数数量。如果
float
,应在 0.0 和 1.0 之间,表示要剪枝的参数比例。如果int
,表示要剪枝的参数的绝对数量。n(整数、浮点数、inf、-inf、'fro'、'nuc')- 请参阅
torch.norm()
中p
参数的有效条目文档。dim(整数)- 定义剪枝通道的 dim 索引。
importance_scores(torch.Tensor)- 用于计算剪枝掩码的重要性得分张量(与模块参数形状相同)。该张量中的值表示要剪枝的参数中对应元素的重要性。如果未指定或为 None,将使用模块参数代替。
- 返回值:
输入模块的修改版(即剪枝版)
- 返回类型:
模块(nn.Module)
示例
>>> from torch.nn.utils import prune >>> m = prune.ln_structured( ... nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf') ... )