torch.nn.utils.prune.identity¶
- torch.nn.utils.prune.identity(module, name)[source][source]¶
应用剪枝重参数化而不剪枝任何单元。
对名为
name
的参数对应的张量进行剪枝重参数化,实际上不剪除任何单元。就地修改模块(并返回修改后的模块)的方式是:添加一个名为
name+'_mask'
的缓冲区,对应于剪枝方法应用于参数name
的二进制掩码。用其剪枝版本替换参数
name
,而原始(未剪枝)参数存储在名为name+'_orig'
的新参数中。
注意
掩码是一个全为 1 的张量。
- 参数:
包含要剪枝的张量的模块(nn.Module)
名称(str)- 在
module
中参数名称,剪枝将在此参数上执行。
- 返回值:
输入模块的修改版(即剪枝版)
- 返回类型:
模块(nn.Module)
示例
>>> m = prune.identity(nn.Linear(2, 3), 'bias') >>> print(m.bias_mask) tensor([1., 1., 1.])