Hardshrink
-
class torch.nn.Hardshrink(lambd=0.5)[source][source]
逐元素应用硬收缩(Hard Shrinkage)函数。
硬收缩定义为:
HardShrink(x)=⎩⎨⎧x,x,0, if x>λ if x<−λ otherwise
- 参数:
lambd (浮点数) – 硬收缩公式的 λ 值。默认:0.5
- 形状:
-
示例:
>>> m = nn.Hardshrink()
>>> input = torch.randn(2)
>>> output = m(input)