yolov3选取正负样本

  • 负责预测目标网格中与ground truth的IOU最大的anchor为正样本(记住这里没有阈值的事情,否则会绕晕)
  • 剩下的anchor中,与全部ground truth的IOU都小于阈值的anchor为负样本
  • 其他是忽略样本
  • 代码未完待续
  • 获取正样本代码,参考这里
'''
targets是标签列表,长度是batch_size,元素的shape是(真实框个数*5)
anchors是[[116,90],[156,198],[373,326]]或[[30,61],[62,45],[59,119]]或[[10,13],[16,30],[33,23]]
in_h, in_w是13,13或26,26或52,52
num_classes是类别数,voc是20,COCO是80
calculate_iou这里不提供,保证输出shape是(真实框个数*3)
'''
def get_target(targets, anchors, in_h, in_w, num_classes):
    bs=len(targets)
    positive=torch.zeros(bs,len(anchors),in_h, in_w, 5+num_classes,requires_grad = False)
    negtive=torch.ones(bs,len(anchors),in_h, in_w, requires_grad = False)
    for b in range(bs):
        batch_target = torch.zeros_like(targets[b])
        # 计算该特征图上标签的值
        batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
        batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
        batch_target[:, 4] = targets[b][:, 4]
        batch_target = batch_target.cpu()
        # 计算标签和anchor的IOU
        # 这里可以随便选一个共同中心(0,0),根据高宽计算IOU
        gt_box= torch.FloatTensor(torch.cat((torch.zeros((batch_target.size(0), 2)), batch_target[:, 2:4]), 1))
        anchor_shapes=torch.FloatTensor(torch.cat((torch.zeros((len(anchors), 2)), torch.FloatTensor(anchors)), 1))
        iou=calculate_iou(gt_box, anchor_shapes)
        # 获得与标签最匹配的anchor的索引
        best_ns = torch.argmax(iou, dim=-1)
        for t, best_n in enumerate(best_ns):
            # 第t个标签中心所在网格,种类
            i = torch.floor(batch_target[t, 0]).long()
            j = torch.floor(batch_target[t, 1]).long()
            c = batch_target[t, 4].long()
            positive[b,best_n,j,i,0]=batch_target[t, 0] - i.float()
            positive[b,best_n,j,i,1]=batch_target[t, 1] - j.float()
            positive[b,best_n,j,i,2]=math.log(batch_target[t, 2] / anchors[best_n][0])
            positive[b,best_n,j,i,3]=math.log(batch_target[t, 3] / anchors[best_n][1])
            positive[b,best_n,j,i,4]=1
            positive[b,best_n,j,i,c+5]=1
            negtive[b,best_n,j,i]=0
    return positive,negtive

版权声明:本文为CSDN博主「刀么克瑟拉莫」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/random_repick/article/details/122565143

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

YoloV3 案例

学习目标 知道YoloV3模型结构及构建方法知道数据处理方法能够利用yoloV3模型进行训练和预测 1.数据获取 一部分是网络数据,可以是开源数据,也可以通过百度、Google图片爬虫得到 在接下来的课程中我们

AlexeyAB DarkNet YOLOv3 Loss计算全解析

先附上AlexeyAB大神版本的DarkNet:github 一、前言 目前还没有对yolo loss计算方法讲的很明白的资料,尤其是loss计算中是如何选取正负样本和忽略样本的。因此在这里做出详细的解释。本文是基于

PsROI Pooling 深入理解,附代码

faster rcnn 和 rfcn 的最大不同点在于rfcn采用了PsROI Pooling 保留了局部区域的位置敏感性。 输入batch_size N 的批次训练图像。 假设我们通过 RPN 层网络获取了 M 个 rois, 每个 ro