余弦相似度 ¶
- class torch.nn.CosineSimilarity(dim=1, eps=1e-08)[source][source]¶
在 dim 维度上计算 和 之间的余弦相似度。
- 参数:
dim(int,可选)- 计算余弦相似度的维度。默认:1
eps(浮点数,可选)- 避免除以零的小值。默认:1e-8
- 形状:
Input1: 其中 D 位于 dim 位置
- Input2: ,与 x1 相同的维度数量,
并且在其他维度上可以广播与 x1 兼容。
输出:
- 示例::
>>> input1 = torch.randn(100, 128) >>> input2 = torch.randn(100, 128) >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6) >>> output = cos(input1, input2)