在学习李沐老师的目标检测篇章 目标检测中由于负类较多,正类较少,我们可以适当的减少对负类的惩罚
因此根据视频教程我们来重新写一个损失函数
通常mxnet的损失函数需要继承Loss类
from mxnet.gluon.loss import Loss class FocalLoss(gluon.loss.Loss): def __init__(self, axis=-1, alpha=0.25, gamma=2, batch_axis=0, **kwargs): super(FocalLoss, self).__init__(None, batch_axis, **kwargs) self._axis = axis self._alpha = alpha self._gamma = gamma self._batch_axis = batch_axis然后在init方法里面初始化各个参数,包括我们的alpha,gamma等等 接着我们重写hybrid_forward这个函数 F这个参数代表的是ndarray或者是一个Symbol output是输出 label表示标签
def hybrid_forward(self, F, output, label): output = F.softmax(output) pj = output.pick(label, axis=self._axis, keepdims=True) loss = - self._alpha*((1-pj)**self._gamma)*pj.log() return loss.mean(axis=self._batch_axis, exclude=True)我们这里使用pick来根据标签选取各个概率 得到pj 接着loss就是按照我们之前定义的损失函数公式编写 最后return的时候按照batch_axis这一维度进行求均值
总的代码如下
class FocalLoss(gluon.loss.Loss): def __init__(self, axis=-1, alpha=0.25, gamma=2, batch_axis=0, **kwargs): super(FocalLoss, self).__init__(None, batch_axis, **kwargs) self._axis = axis self._alpha = alpha self._gamma = gamma self._batch_axis = batch_axis def hybrid_forward(self, F, output, label): output = F.softmax(output) pj = output.pick(label, axis=self._axis, keepdims=True) loss = - self._alpha*((1-pj)**self._gamma)*pj.log() return loss.mean(axis=self._batch_axis, exclude=True)然后使用的时候先实例化一个LOSS对象 再传入
cls_loss_v2 = FocalLoss()