目标检测 YOLOV5:loss介绍

目录

1.BCEWithLogitsLoss

1.1pytorch源码中的相关代码

1.2 数学原理

2.FocalLoss

2.1 pytorch源码

2.2 数学原理


1.BCEWithLogitsLoss

1.1pytorch源码中的相关代码

class BCEWithLogitsLoss(_Loss):
    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean',
                 pos_weight: Optional[Tensor] = None) -> None:
        super(BCEWithLogitsLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)
        self.weight: Optional[Tensor]
        self.pos_weight: Optional[Tensor]

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.binary_cross_entropy_with_logits(input, target,
                                                  self.weight,
                                                  pos_weight=self.pos_weight,
                                                  reduction=self.reduction)

1.2 数学原理

BCEWithLogitsLoss是将BCELoss(BCE:Binary cross entropy)和sigmoid融合了,也就是说省略了sigmoid这个步骤;

BCELoss的数学公式:

-\frac{1}{n}\sum (y_{n}\times lnx_{n}+(1-y_{n})\times ln(1-x_{n}))

对于二分类的三个训练样本,计算方法:

import torch
import torch.nn as nn

input = torch.randn(3,3)
target = torch.FloatTensor([[0,1,1],[0,0,1],[1,0,1]])

loss = nn.BCELoss()
m = nn.Sigmoid()
input_m = m(input)
result = loss(input_m, target) 

结果 result=tensor(1.0224)

而使用BCEWithLogitsLoss

loss_1 = nn.BCEWithLogitsLoss()
result = loss(input_m, target)

结果result=tensor(1.0224)

2.FocalLoss

2.1 pytorch源码

class FocalLoss(nn.Module):
    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = torch.sigmoid(pred)  # prob from logits
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = (1.0 - p_t) ** self.gamma
        loss *= alpha_factor * modulating_factor

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss

2.2 数学原理

Focal Loss 是何恺明设计的为了解决one-stage目标检测在训练阶段前景类和背景类极度不均衡(如1:1000)的场景的损失函数。它是由二分类交叉熵改造而来的。

L_{fl}=\left\{\begin{matrix} -\alpha (1-y^{'})^{r}logy^{'}, y=1& \\ -(1-\alpha )y^{'r}log(1-y^{'}), y=0 & \end{matrix}\right.

其中,\alpha\gamma均可以调节的超参数。y^{'}为模型预测,其值介于(0-1)之间。

当y=1时,y^{'}->1,表示easy positive,它对权重的贡献->0;

当y=0时,y^{'}->0,表示easy negative,它对权重的贡献->0.

因此,Focal Loss降低了背景类的同时,也降低了easy positive和easy negative的权重;

\gamma是对Focal Loss的调节;

由标准交叉熵推理出Focal Loss:

标准交叉熵

CE(p,y)=\left\{\begin{matrix} -log(p) & if\: y=1\\ -log(1-p) &otherwise \end{matrix}\right.

其中,p是模型预测属于y=1的概率。为了方便标记,定义如下:

pt=\left\{\begin{matrix} p &if\: y=1 \\ 1-p& otherwise \end{matrix}\right.

交叉熵CE重写为:

CE(p,y)=CE(pt)=-log(pt)

\alpha-平衡交叉熵:

有一种解决类别不平衡的方法就是引入[0,1]之间的权重因子\alpha:当y=1时,取\alpha;当 y=0时,取1-\alpha.随着\alpha

的增大,会对背景类的权重进行降低,从而加大对背景类的惩罚,从而减轻背景类数量太多对训练造成的影响;\alpha类似pt ,可将\alpha-CE写为:

CE(p_{t})=-\alpha _{t}log(p_{t})

替他链接:

Pytorch详解BCELoss和BCEWithLogitsLoss_豪哥的博客-CSDN博客_bcewithlogitsloss

BCELoss()与BCEWithLogitsLoss()区别 - 知乎

Focal Loss笔记 - HOU_JUN - 博客园

版权声明:本文为CSDN博主「猫猫与橙子」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_22764813/article/details/121394547

猫猫与橙子

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

目标检测xywh格式转xyxy格式

这两天在看YOLOv1的代码,看到这边博客给了代码 传送门:动手学习深度学习pytorch版——从零开始实现YOLOv1 其中有个地方需要用到cv2.rectangle()函数来给图像

小目标检测方法介绍

目标检测发展很快,但对于小目标 的检测还是有一定的瓶颈,特别是大分辨率图像小目标检测 。比如79202160,甚至1600016000的图像,还有一些遥感图像 。 图像的分辨率很大&#xf