深度学习分类、识别等任务常用的余弦距离和对应的PyTorch代码

mac2025-02-17  8

余弦距离常常在人脸识别,图像分类,行人重识别中应用。自从centerNet可视化了softmax loss之后,人们得知神经网络的输出空间原来是呈现原点向外发散状,分类结果是可以通过判断两个样本在输出空间对应的向量之间的夹角来得知是否是同一类样本。这个夹角就是所谓的余弦距离,夹角越小,两个样本越相似。

预备的数学知识

cos曲线:

比如现在有样本A,B,对应在输出空间的特征向量分别是, , 先对这两个特征值除以各自的模。

根据求向量之间夹角公式,A,B之间的角度的cos值就是:

这个值越大,说明,向量夹角越小,说明越相似。

 

Pytorch代码

from torch.nn import functional as F def calculate_cos_distance(a,b): a = F.normalize(a, dim=-1) b = F.normalize(b, dim=-1) cose = torch.mm(a,b) return 1 - cose

 

最新回复(0)