线性 ¶
-
class torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)[source][source]
对传入数据进行仿射线性变换: y=xAT+b 。
此模块支持 TensorFloat32。
在某些 ROCm 设备上,当使用 float16 输入时,该模块将使用不同的精度进行反向操作。
- 参数:
in_features (int) – 每个输入样本的大小
out_features (int) – 每个输出样本的大小
bias (bool) – 如果设置为 False ,则层将不会学习加性偏置。默认: True
- 形状:
输入: (∗,Hin) 其中 ∗ 表示包括零在内的任意维数和 Hin=in_features 。
输出: (∗,Hout) 其中除了最后一个维度外,所有维度都与输入相同形状, Hout=out_features 。
- 变量:
weight (torch.Tensor) – 该模块的可学习权重,形状为 (out_features,in_features) 。值从 U(−k,k) 初始化,其中 k=in_features1
偏差 - 该模块的形状为 (out_features) 的可学习偏差。如果 bias 是 True ,则值从 U(−k,k) 初始化,其中 k=in_features1
示例:
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])