torch.nn.functional.cosine_similarity¶
- torch.nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8) → Tensor
返回
x1
和x2
之间的余弦相似度,沿 dim 计算。x1
和x2
必须广播到公共形状。dim
指的是这个公共形状中的维度。输出维度的dim
被压缩(见torch.squeeze()
),结果输出张量少一个维度。支持类型提升。
- 参数:
x1(张量)- 第一个输入。
x2(张量)- 第二个输入。
dim(int,可选)- 计算余弦相似度的维度。默认:1
eps(浮点数,可选)- 避免除以零的小值。默认:1e-8
示例:
>>> input1 = torch.randn(100, 128) >>> input2 = torch.randn(100, 128) >>> output = F.cosine_similarity(input1, input2) >>> print(output)