长尾分布-Overcoming Classifier Imbalance for Long-tail Object Detection with Balanced Group Softmax

参考https://blog.csdn.net/sinat_17456165/article/details/106846747
论文地址:http://openaccess.thecvf.com/content_CVPR_2020/papers/Li_Overcoming_Classifier_Imbalance_for_Long-Tail_Object_Detection_With_Balanced_Group_CVPR_2020_paper.pdf

代码地址:https://github.com/FishYuLi/BalancedGroupSoftmax

视频讲解地址:https://www.youtube.com/watch?v=ikdVuadfUo8

摘要

使用基于深度学习的模型来解决长尾 large vocabulary目标检测是一项具有挑战性而艰巨的任务,然而,这项工作尚未得到充分研究。在本文的工作中,首先对针对长尾分布问题所提出SOTA模型的性能进行了系统分析,找出其不足之处。发现当数据集极度偏斜时,现有的检测方法无法对few-shot类别进行建模,这可能导致分类器在参数大小上的不平衡。由于检测和分类之间的内在差异,将长尾分类模型直接应用于检测框架无法解决此问题。因此,在这项工作中,提出了一个新颖的balanced group softmax (BAGS)模块,用于通过逐组训练来平衡检测框架内的分类器。它隐式地调整了头和尾类的训练过程,并确保它们都得到了充分的训练,而无需对来自尾类的instance进行任何额外采样。

在最近的长尾large vocabulary目标识别任务数据集 LVIS上的大量实验表明,本文提出的BAGS大大提高了具有各种主干和框架的检测器在目标检测和实例分割上的性能。它击败了从长尾图像分类中转移过来的所有最新方法,并建立了新的方法。

介绍

LVIS 是由facebook AI研究院的研究人员们发布的一个大规模的词汇实例分割数据集(Large Vocabulary Instance Segmentation ),包含了164k图像,并针对超过1000类物体进行了约200万个高质量的实例分割标注。数据集中包含自然图像中的物体分布天然具有长尾属性。

受《Decou-pling representation and classifier for long-tailed recognition》的启发,首先将检测框架中的representation 和分类模块解耦,发现不同类别相对应的proposal 分类器的weight norm严重失衡,因为low-shot类别被激活的机会很少。通过分析,这是长尾检测器性能差的直接原因,而长尾检测器性能本质上是由数据不平衡引起的。

如图1所示,分别根据训练集中实例的数量对在COCO和LVIS上训练的模型的类别分类器权重范数进行排序。对于COCO,除了背景类(类别ID = 0)以外,相对平衡的数据分布导致所有类别的weight norm相对平衡。而对于LVIS,很明显类别weigh norm是不平衡的,并且与训练实例的数量呈正相关。这种不平衡将使low-shot 类别(尾类)的分类分数比many-shot 类别(头部类)的分类分数小得多。在标准softmax函数之后,这种不平衡会被进一步放大,因此分类器错误地抑制了预测为low-shot 类别的proposal 。
在这里插入图片描述

图1. COCO和LVIS训练集中类别的训练实例(#ins)的排序数量,以及在COCO和LVIS上训练的Faster R-CNN模型的相应分类器权重范数“ w”。x轴表示COCO和LVIS的分类索引。将80类COCO与1230类LVIS对齐,以获得更好的可视化效果。类别0表示背景。

Introduction

目标检测[31,29,25,23,21,1]是计算机视觉中最基本、最具挑战性的任务之一。最近的进展主要是由人工平衡的大规模数据集驱动的,如PASCAL VOC[9]和COCO[24]。然而在现实中,对象类别的分布通常是长尾[30]。有效的解决方案,使最先进的检测模型适应这种类不平衡的分布是非常需要的,但仍然缺乏。最近,一个长尾大词汇表对象识别数据集LVIS[14]发布了,它大大方便了更真实场景下的对象检测研究。长尾目标检测的一个简单的解决方案是,直接在长尾训练数据上训练一个成熟的检测模型(如Faster R-CNN[31])。然而,当将为相当平衡的数据集(如COCO)设计的检测器调整为长尾数据集(如LVIS)时,会观察到较大的性能下降该情况的具体原因尚不清楚。受[20]的启发,我们将检测框架内的表示模块和分类模块解耦,发现不同类别对应的proposal分类器的权重规范严重不平衡,因为low-shot类别被激活的机会很少。通过我们的分析,这是长尾检测性能差的直接原因之一,其本质是由数据不平衡引起的。如图1所示,我们分别根据训练集中的实例数对在COCO和LVIS上训练的模型的分类器权值规范进行分类排序。对于COCO来说,由于数据分布相对均衡,导致除了背景类(CID=0, CID为类别ID)外,所有类别的权重规范都相对均衡。对于LVIS来说,类别权重规范明显不平衡,且与训练实例数呈正相关。这种不平衡的分类器(w.r.t.,它们的参数规范)会使低概率分类器(尾部分类器)的分类分数比多概率分类器(头部分类器)的分类分数小得多。在标准softmax之后,这种不平衡会进一步放大,分类器会错误地抑制被预测为低射类的建议。

先回顾下通常解决长尾分布问题的方法:
1、Re-sampling:主要是在训练集上实现样本平衡,如对tail中的类别样本进行过采样,或者对head类别样本进行欠采样。基于重采样的解决方案适用于检测框架,但可能会导致训练时间增加以及对tail类别的过度拟合风险。

2、Re-weighting:主要在训练loss中,给不同的类别的loss设置不同的权重,对tail类别loss设置更大的权重。但是这种方法对超参数选择非常敏感,并且由于难以处理特殊背景类(非常多的类别)而不适用于检测框架。(我之前尝试过这种策略,但是效果也很差)

3、Learning strategy:有专门为解决少样本问题涉及的学习方法可以借鉴,如:meta-learning、metric learning、transfer learing。另外,还可以调整训练策略,将训练过程分为两步:第一步不区分head样本和tail样本,对模型正常训练;第二步,设置小的学习率,对第一步的模型使用各种样本平衡的策略进行finetune。
由上面可见,解决长尾分布问题,不管是哪一种都比较麻烦。

看下本文所提出的想法:
为了解决分类器不平衡的问题,我们在检测框架的分类头中引入了一个简单而有效的平衡组软码(BAGS)模块。我们提出将训练实例数目相似的对象分类放在同一组中,并分别计算组态软最大交叉熵损失。分别处理具有不同实例号的类别可以有效地缓解头类对尾类的支配。然而,由于每组训练都缺乏多样的负面例子,因此产生的模型存在过多的误报。因此,BAGS进一步在每组中增加一个类别others,并引入背景类别作为一个单独的组,在防止类别background和others的误报的同时,通过保持分类器的平衡,减轻了头类对尾类的抑制。我们通过实验发现BAGS的效果非常好。它尾巴上的性能提高了9% - 19%的各种框架包括更快R-CNN[31],级联R-CNN[1],面具R-CNN[16]和[4]HTC ResNet50-FPN(17、22)和resnext - 101 - x64x4d -红外系统[37]脊椎一直在长尾对象识别基准LVIS[14],与整体映射了约3% - 6%。
综上所述,本工作的贡献如下:
•通过综合分析,我们揭示了现有模型在长尾检测方面表现不佳的原因,即它们的分类器是不平衡的,训练也不是很好,从观察到的分类器权重规范不平衡可以看出。
•我们提出了一个简单而有效的平衡组softmax模块来解决这个问题。它可以很容易地与目标检测和实例分割框架相结合,以提高其长尾识别性能。
•我们使用最先进的长尾分类方法对目标进行了广泛的评估。这样的标杆研究不仅加深了我们对这些方法的理解,也解决了长尾检测所面临的独特挑战,同时也为未来这一方向的研究提供了可靠而有力的基线。

related work

Preliminary and Analysis

3.1. Preliminary

这部分主要讲了为什么目标检测器在tail类上失效,以及重采样方法为什么有效
通过所设计的对比实验发现(具体的实验细节可以参考论文原文),tail类的预测得分会先天性地低于head类,tail类的proposals 在softmax计算中与head类的proposals 竞争后,被选中的可能性会降低。这就解释了为什么目前的检测模型经常在tail类上失效。由于head类的训练实例远多于tail类的训练实例(例如,在某些极端情况下,10000:1),tail类的分类器权重更容易(频繁)被head类的权重所压制,导致训练后的weight norm不平衡。

因此,可以看出为什么重采样方法能够在长尾目标分类和分割任务中的使得tail类受益。它只是在训练过程中增加了tail类proposals 的采样频率,从而可以平等地激活或抑制不同类别的权重,从而在一定程度上平衡tail类和head类。同样,损失重新加权方法也可以通过类似的方式生效。尽管重采样策略可以减轻数据不平衡的影响,但实际上会带来新的风险,例如过度拟合tail类和额外的计算开销。同时,损失重新加权对每个类别的损失加权设计很敏感,通常在不同的框架,backbone和数据集之间会有所不同,因此很难在实际应用中进行部署。而且,基于重新加权的方法不能很好地处理检测问题中的背景类。因此,本文提出了一种简单而有效的解决方案,无需繁重的超参数工程即可平衡分类器weight norm。

4.Balanced Group Softmax

接下来看下具体的结构内容在这里插入图片描述

4.1 group softmax

如前所述,权值规范与训练样本数量之间的正相关关系会影响检测器的性能。为了解决这一问题,我们提出将类划分为几个不相关联的组,并分别进行softmax操作,使每个组内只有训练实例数量相似的类相互竞争。通过这种方式,可以在训练期间将包含显著不同数量实例的类彼此隔离。尾部类的分类器权重不会被头部类实质上抑制
根据训练实例数量将所有类别分为N组:
在这里插入图片描述
其中N(j)是训练集中类别J的标签中边界框的数量,而sl和sh是确定每组的最小和最大实例数的超参数。文中,分为四组N = 4,sl1 = 0,sl2 = 10,sl3 = 102,sl4 = 103,sh4 = +∞。
在这里插入图片描述
另外,我们手动设置G0只包含背景类别,因为G0拥有最多的训练实例(通常是对象类别的10-100倍)。对于G0我们采用s型交叉熵损失,因为它只包含一个预测,而对于其他组我们使用softmax交叉熵损失。选择softmax的原因是,softmax函数天生具有从另一个类中抑制每个类的能力,并且不太可能产生大量的误报。在培训过程中,对于ground-truth标签为c的proposal bk,会激活两个组,即后台组G0和前台组Gn,其中c∈Gn。
小结:
首先根据实例数量分组softmax,每个组内是数据差不多的类别竞争,这样就可以抑制head(头部)对tail(尾部)的权重压制,防止tail(尾部)出现检测不均衡问题。
同时上述方式也存在一定弊端。
但是,发现上述group softmax设计存在以下问题:在测试过程中,对于一个proposal,由于其类别未知将使用所有组进行预测,因此,每个组至少有一个类别将获得较高的预测分数,并且很难决定我们应该采用哪种分组预测,从而导致大量误报。为了解决这个问题,在每个组中添加了一个类别,以校准组之间的预测并抑制误报。此类别包含当前组中未包含的类别,可以是其他组中的背景类别或前景类别。对于G0来说,其他类别也代表前台类。具体来说,对于具有ground-真值标签c的提案bk,新的预测z应该是z∈R(c +1)+(N +1)。第j类的概率为
在这里插入图片描述
ground-truth标签应该在每组中重新映射。在不包含c的组中,类others将被定义为基本真理类。则最终损失函数为
在这里插入图片描述
其中yn和pn表示Gn中的标签和概率。

4.3. Balancing training samples in groups

在以上处理中,新添加的类别others将通过抑制众多实例,再次成为占主导地位的outlier 。为了平衡每组的训练样本数,仅对一定数量的others proposals进行训练,由采样率β控制

在包含标签真值的类别组中,将根据mini-batch of K proposals来按比例采样others实例。如果一组中没有激活正常类别,则所有others实例都不会激活,该组则被忽略。这样,每个组都可以保持平衡,且误报率低。添加others类别会使baseline提高2.7%。

4.4 Inference

在推理过程中,首先使用训练好的模型生成z,然后在每个组中应用softmax。 除G0外,其他所有节点均被忽略,所有类别的概率均按原始类别ID排序。G0中的p0可被视为前景proposals的概率。最后,使用 重新缩放正常类别的所有概率。这个新的概率向量将被送到后续的后处理步骤(如NMS),以产生最终的检测结果。应该注意的是,从概念上来说 不是真正的概率向量,因为它的总和不等于1,但它起着原始概率向量的作用,该向量通过选择最终boxes框来指导模型。
剩余部分为实验部分,看具体论文。

代码部分

import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import numpy as np

from mmdet.core import (delta2bbox, force_fp32,
                        multiclass_nms)
from .convfc_bbox_head import SharedFCBBoxHead
from ..builder import build_loss
from ..registry import HEADS


@HEADS.register_module
class GSBBoxHeadWith0(SharedFCBBoxHead):

    def __init__(self,
                 num_fcs=2,
                 fc_out_channels=1024,
                 gs_config=None,
                 *args,
                 **kwargs):
        super(GSBBoxHeadWith0, self).__init__(num_fcs=num_fcs,
                                         fc_out_channels=fc_out_channels,
                                         *args,
                                         **kwargs)
        # 1232, 0 for background, 1231 for foreground
        self.fc_cls = nn.Linear(self.cls_last_dim,
                                self.num_classes + gs_config.num_bins)

        # self.loss_bg = build_loss(gs_config.loss_bg)

        self.loss_bins = []
        for i in range(gs_config.num_bins):
            self.loss_bins.append(build_loss(gs_config.loss_bin))

        self.label2binlabel = torch.load(gs_config.label2binlabel).cuda()
        self.pred_slice = torch.load(gs_config.pred_slice).cuda()

        # TODO: update this ugly implementation. Save fg_split to a list and
        #  load groups by gs_config.num_bins
        with open(gs_config.fg_split, 'rb') as fin:
            fg_split = pickle.load(fin)

        self.fg_splits = []
        ######这里划分了四组(根据数据集数量划分)
        self.fg_splits.append(torch.from_numpy(fg_split['(0, 10)']).cuda())
        self.fg_splits.append(torch.from_numpy(fg_split['[10, 100)']).cuda())
        self.fg_splits.append(torch.from_numpy(fg_split['[100, 1000)']).cuda())
        self.fg_splits.append(torch.from_numpy(fg_split['[1000, ~)']).cuda())

        # self.fg_splits.append(torch.from_numpy(fg_split['(0, 5)']).cuda())
        # self.fg_splits.append(torch.from_numpy(fg_split['(5, 10)']).cuda())
        # self.fg_splits.append(torch.from_numpy(fg_split['[10, 50)']).cuda())
        # self.fg_splits.append(torch.from_numpy(fg_split['[50, 100)']).cuda())
        # self.fg_splits.append(torch.from_numpy(fg_split['[100, 500)']).cuda())
        # self.fg_splits.append(torch.from_numpy(fg_split['[500, 1000)']).cuda())
        # self.fg_splits.append(torch.from_numpy(fg_split['[1000, 5000)']).cuda())
        # self.fg_splits.append(torch.from_numpy(fg_split['[5000, ~)']).cuda())

        self.others_sample_ratio = gs_config.others_sample_ratio


    def _sample_others(self, label):

        # only works for non bg-fg bins

        fg = torch.where(label > 0, torch.ones_like(label),
                         torch.zeros_like(label))
        fg_idx = fg.nonzero(as_tuple=True)[0]
        fg_num = fg_idx.shape[0]
        if fg_num == 0:
            return torch.zeros_like(label)

        bg = 1 - fg
        bg_idx = bg.nonzero(as_tuple=True)[0]
        bg_num = bg_idx.shape[0]

        bg_sample_num = int(fg_num * self.others_sample_ratio)

        if bg_sample_num >= bg_num:
            weight = torch.ones_like(label)
        else:
            sample_idx = np.random.choice(bg_idx.cpu().numpy(),
                                          (bg_sample_num, ), replace=False)
            sample_idx = torch.from_numpy(sample_idx).cuda()
            fg[sample_idx] = 1
            weight = fg

        return weight

    def _remap_labels(self, labels):

        num_bins = self.label2binlabel.shape[0]
        new_labels = []
        new_weights = []
        new_avg = []
        for i in range(num_bins):
            mapping = self.label2binlabel[i]
            new_bin_label = mapping[labels]

            if i < 1:
                weight = torch.ones_like(new_bin_label)
                # weight = torch.zeros_like(new_bin_label)
            else:
                weight = self._sample_others(new_bin_label)
            new_labels.append(new_bin_label)
            new_weights.append(weight)

            avg_factor = max(torch.sum(weight).float().item(), 1.)
            new_avg.append(avg_factor)

        return new_labels, new_weights, new_avg

    def _remap_labels1(self, labels):

        num_bins = self.label2binlabel.shape[0]
        new_labels = []
        new_weights = []
        new_avg = []
        for i in range(num_bins):
            mapping = self.label2binlabel[i]
            new_bin_label = mapping[labels]

            weight = torch.ones_like(new_bin_label)

            new_labels.append(new_bin_label)
            new_weights.append(weight)

            avg_factor = max(torch.sum(weight).float().item(), 1.)
            new_avg.append(avg_factor)

        return new_labels, new_weights, new_avg

    def _slice_preds(self, cls_score):

        new_preds = []

        num_bins = self.pred_slice.shape[0]
        for i in range(num_bins):
            start = self.pred_slice[i, 0]
            length = self.pred_slice[i, 1]
            sliced_pred = cls_score.narrow(1, start, length)
            new_preds.append(sliced_pred)

        return new_preds

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def loss(self,
             cls_score,
             bbox_pred,
             labels,
             label_weights,
             bbox_targets,
             bbox_weights,
             reduction_override=None):
        losses = dict()

        if cls_score is not None:
            # Original label_weights is 1 for each roi.
            new_labels, new_weights, new_avgfactors = self._remap_labels(labels)
            new_preds = self._slice_preds(cls_score)

            num_bins = len(new_labels)
            for i in range(num_bins):
                losses['loss_cls_bin{}'.format(i)] = self.loss_bins[i](
                    new_preds[i],
                    new_labels[i],
                    new_weights[i],
                    avg_factor=new_avgfactors[i],
                    reduction_override=reduction_override
                )

        if bbox_pred is not None:
            pos_inds = labels > 0
            if self.reg_class_agnostic:
                pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_inds]
            else:
                pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
                                               4)[pos_inds, labels[pos_inds]]
            losses['loss_bbox'] = self.loss_bbox(
                pos_bbox_pred,
                bbox_targets[pos_inds],
                bbox_weights[pos_inds],
                avg_factor=bbox_targets.size(0),
                reduction_override=reduction_override)
        return losses

    @force_fp32(apply_to=('cls_score'))
    def _merge_score1(self, cls_score):
        '''
        Do softmax in each bin. Merge the scores directly.
        '''
        num_proposals = cls_score.shape[0]

        new_preds = self._slice_preds(cls_score)
        new_scores = [F.softmax(pred, dim=1) for pred in new_preds]

        bg_score = new_scores[0]
        fg_score = new_scores[1:]

        fg_merge = torch.zeros((num_proposals, 1231)).cuda()
        merge = torch.zeros((num_proposals, 1231)).cuda()

        for i, split in enumerate(self.fg_splits):
            fg_merge[:, split] = fg_score[i]

        merge[:, 0] = bg_score[:, 0]
        fg_idx = (bg_score[:,1] > 0.5).nonzero(as_tuple=True)[0]
        merge[fg_idx] = fg_merge[fg_idx]

        return merge

    @force_fp32(apply_to=('cls_score'))
    def _merge_score2(self, cls_score):
        '''
        Do softmax in each bin. Softmax again after merge.
        '''
        num_proposals = cls_score.shape[0]

        new_preds = self._slice_preds(cls_score)
        new_scores = [F.softmax(pred, dim=1) for pred in new_preds]

        bg_score = new_scores[0]
        fg_score = new_scores[1:]

        fg_merge = torch.zeros((num_proposals, 1231)).cuda()
        merge = torch.zeros((num_proposals, 1231)).cuda()

        for i, split in enumerate(self.fg_splits):
            fg_merge[:, split] = fg_score[i]

        merge[:, 0] = bg_score[:, 0]
        fg_idx = (bg_score[:,1] > 0.5).nonzero(as_tuple=True)[0]
        merge[fg_idx] = fg_merge[fg_idx]
        merge = F.softmax(merge)

        return merge

    @force_fp32(apply_to=('cls_score'))
    def _merge_score(self, cls_score):
        '''
        Do softmax in each bin. Decay the score of normal classes
        with the score of fg.
        From v1.
        '''

        num_proposals = cls_score.shape[0]

        new_preds = self._slice_preds(cls_score)
        new_scores = [F.softmax(pred, dim=1) for pred in new_preds]

        bg_score = new_scores[0]
        fg_score = new_scores[1:]

        fg_merge = torch.zeros((num_proposals, self.num_classes)).cuda()
        merge = torch.zeros((num_proposals, self.num_classes)).cuda()

        # import pdb
        # pdb.set_trace()
        for i, split in enumerate(self.fg_splits):
            fg_merge[:, split] = fg_score[i][:, 1:]

        weight = bg_score.narrow(1, 1, 1)

        # Whether we should add this? Test
        fg_merge = weight * fg_merge

        merge[:, 0] = bg_score[:, 0]
        merge[:, 1:] = fg_merge[:, 1:]
        # fg_idx = (bg_score[:, 1] > 0.5).nonzero(as_tuple=True)[0]
        # erge[fg_idx] = fg_merge[fg_idx]

        return merge

    @force_fp32(apply_to=('cls_score'))
    def _merge_score4(self, cls_score):
        '''
        Do softmax in each bin.
        Do softmax on merged fg classes.
        Decay the score of normal classes with the score of fg.
        From v2 and v3
        '''
        num_proposals = cls_score.shape[0]

        new_preds = self._slice_preds(cls_score)
        new_scores = [F.softmax(pred, dim=1) for pred in new_preds]

        bg_score = new_scores[0]
        fg_score = new_scores[1:]

        fg_merge = torch.zeros((num_proposals, 1231)).cuda()
        merge = torch.zeros((num_proposals, 1231)).cuda()

        for i, split in enumerate(self.fg_splits):
            fg_merge[:, split] = fg_score[i]

        fg_merge = F.softmax(fg_merge, dim=1)
        weight = bg_score.narrow(1, 1, 1)
        fg_merge = weight * fg_merge

        merge[:, 0] = bg_score[:, 0]
        merge[:, 1:] = fg_merge[:, 1:]
        # fg_idx = (bg_score[:, 1] > 0.5).nonzero(as_tuple=True)[0]
        # erge[fg_idx] = fg_merge[fg_idx]

        return merge

    @force_fp32(apply_to=('cls_score'))
    def _merge_score5(self, cls_score):
        '''
        Do softmax in each bin.
        Pick the bin with the max score for each box.
        '''
        num_proposals = cls_score.shape[0]

        new_preds = self._slice_preds(cls_score)
        new_scores = [F.softmax(pred, dim=1) for pred in new_preds]

        bg_score = new_scores[0]
        fg_score = new_scores[1:]
        max_scores = [s.max(dim=1, keepdim=True)[0] for s in fg_score]
        max_scores = torch.cat(max_scores, 1)
        max_idx = max_scores.argmax(dim=1)

        fg_merge = torch.zeros((num_proposals, 1231)).cuda()
        merge = torch.zeros((num_proposals, 1231)).cuda()

        for i, split in enumerate(self.fg_splits):
            tmp_merge = torch.zeros((num_proposals, 1231)).cuda()
            tmp_merge[:, split] = fg_score[i]
            roi_idx = torch.where(max_idx == i,
                                  torch.ones_like(max_idx),
                                  torch.zeros_like(max_idx)).nonzero(
                as_tuple=True)[0]
            fg_merge[roi_idx] = tmp_merge[roi_idx]

        merge[:, 0] = bg_score[:, 0]
        fg_idx = (bg_score[:, 1] > 0.5).nonzero(as_tuple=True)[0]
        merge[fg_idx] = fg_merge[fg_idx]

        return merge

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def get_det_bboxes(self,
                       rois,
                       cls_score,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       cfg=None):
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))

        scores = self._merge_score(cls_score)
        # scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        if bbox_pred is not None:
            bboxes = delta2bbox(rois[:, 1:], bbox_pred, self.target_means,
                                self.target_stds, img_shape)
        else:
            bboxes = rois[:, 1:].clone()
            if img_shape is not None:
                bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1] - 1)
                bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1)

        if rescale:
            if isinstance(scale_factor, float):
                bboxes /= scale_factor
            else:
                bboxes /= torch.from_numpy(scale_factor).to(bboxes.device)

        if cfg is None:
            return bboxes, scores
        else:
            det_bboxes, det_labels = multiclass_nms(bboxes, scores,
                                                    cfg.score_thr, cfg.nms,
                                                    cfg.max_per_img)

            return det_bboxes, det_labels

版权声明:本文为CSDN博主「dear_queen」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/dear_queen/article/details/122448333

dear_queen

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

Pytorch—万字入门SSD物体检测

前言 由于初入物体检测领域,我在学习SSD模型的时候遇到了很多的困难。一部分困难在于相关概念不清楚,专业词汇不知其意,相关文章不知所云;另一部分困难在于网上大部分文章要么只是简要介绍了SS