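- Positive samples: in the grid cell responsible for predicting the object (the cell containing its center), the anchor with the highest IoU with the ground truth is the positive sample (remember that no threshold is involved here, otherwise it gets confusing)
- Among the remaining anchors, those whose IoU with every ground truth is below the threshold are negative samples
- All other anchors are ignored samples
- The code is unfinished, to be continued
- The positive-sample code below is adapted from the reference here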
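The three rules in the list above can be condensed into a small labeling function. This is a minimal pure-Python sketch, not the author's code: `assign_labels`, the label constants and the default threshold of 0.5 (the value the YOLOv3 paper uses for ignoring predictions) are assumptions. The IoU matrix is supplied by the caller, so whether it is measured on the anchors themselves or on the boxes decoded from the predictions is left open.

```python
from typing import List

# Hypothetical label values chosen for this sketch.
POSITIVE, NEGATIVE, IGNORE = 1, 0, -1


def assign_labels(iou: List[List[float]], best_anchor: List[int],
                  ignore_thresh: float = 0.5) -> List[int]:
    """Label every anchor as POSITIVE, NEGATIVE or IGNORE.

    iou[a][g]      -- IoU between anchor a and ground truth g
    best_anchor[g] -- index of the anchor responsible for ground truth g
                      (the highest-IoU anchor in the cell holding g's center)
    """
    labels = [IGNORE] * len(iou)
    # Rule 1: the responsible anchor of each ground truth is positive; no threshold.
    for a in best_anchor:
        labels[a] = POSITIVE
    # Rule 2: a remaining anchor whose IoU with every ground truth is below
    # the threshold is negative (with no ground truths, all() is True).
    for a, row in enumerate(iou):
        if labels[a] != POSITIVE and all(v < ignore_thresh for v in row):
            labels[a] = NEGATIVE
    # Rule 3: everything else keeps the IGNORE label.
    return labels


# Four anchors, two ground truths; anchors 0 and 3 are responsible.
print(assign_labels([[0.8, 0.1], [0.6, 0.2], [0.1, 0.3], [0.0, 0.7]], [0, 3]))
# → [1, -1, 0, 1]
```

Anchor 1 overlaps the first ground truth by 0.6 without being its responsible anchor, so it is ignored instead of being pushed toward background.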
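import math
import torch

def get_target(targets, anchors, in_h, in_w, num_classes, img_size=416):
    '''
    targets: list of label tensors of length batch_size; each element has shape
        (num_gt_boxes, 5) holding normalized (x_center, y_center, w, h, class_id)
    anchors: the three (w, h) anchors of this scale in input-image pixels:
        [[116,90],[156,198],[373,326]] or [[30,61],[62,45],[59,119]] or [[10,13],[16,30],[33,23]]
    in_h, in_w: feature-map size, 13,13 or 26,26 or 52,52
    num_classes: number of classes, 20 for VOC, 80 for COCO
    img_size: network input size, assumed square (416 matches the 13/26/52 grids)
    calculate_iou is not provided here; it must return a tensor of shape (num_gt_boxes, 3)
    '''
    bs = len(targets)
    # Anchors are in pixels while the labels below are in grid units,
    # so rescale the anchors by the stride before comparing them
    stride_w, stride_h = img_size / in_w, img_size / in_h
    scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in anchors]
    positive = torch.zeros(bs, len(anchors), in_h, in_w, 5 + num_classes)
    negative = torch.ones(bs, len(anchors), in_h, in_w)
    for b in range(bs):
        if len(targets[b]) == 0:
            continue  # no objects in this image: every anchor stays negative
        batch_target = torch.zeros_like(targets[b])
        # Convert the normalized labels to feature-map (grid) coordinates
        batch_target[:, [0, 2]] = targets[b][:, [0, 2]] * in_w
        batch_target[:, [1, 3]] = targets[b][:, [1, 3]] * in_h
        batch_target[:, 4] = targets[b][:, 4]
        batch_target = batch_target.cpu()
        # IoU between labels and anchors: only the shapes matter, so put every
        # box at a common center (0, 0) and compare widths and heights
        gt_box = torch.cat((torch.zeros(batch_target.size(0), 2), batch_target[:, 2:4].float()), 1)
        anchor_shapes = torch.cat((torch.zeros(len(anchors), 2),
                                   torch.tensor(scaled_anchors, dtype=torch.float32)), 1)
        iou = calculate_iou(gt_box, anchor_shapes)
        # Index of the anchor that best matches each label
        best_ns = torch.argmax(iou, dim=-1)
        for t, best_n in enumerate(best_ns.tolist()):
            # Grid cell (column i, row j) holding the center of label t, clamped
            # in case a coordinate equals 1.0, and the class of label t
            i = min(math.floor(batch_target[t, 0].item()), in_w - 1)
            j = min(math.floor(batch_target[t, 1].item()), in_h - 1)
            c = int(batch_target[t, 4].item())
            # Offsets of the center inside the cell
            positive[b, best_n, j, i, 0] = batch_target[t, 0] - i
            positive[b, best_n, j, i, 1] = batch_target[t, 1] - j
            # Log ratio of the box size to the (grid-unit) anchor size
            positive[b, best_n, j, i, 2] = math.log(batch_target[t, 2].item() / scaled_anchors[best_n][0])
            positive[b, best_n, j, i, 3] = math.log(batch_target[t, 3].item() / scaled_anchors[best_n][1])
            positive[b, best_n, j, i, 4] = 1      # objectness
            positive[b, best_n, j, i, 5 + c] = 1  # one-hot class
            negative[b, best_n, j, i] = 0         # a positive anchor is never a negative
    return positive, negative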
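The function above relies on `calculate_iou`, which the post does not provide. Because both inputs are built as `(0, 0, w, h)`, the boxes are evidently in (center x, center y, width, height) format; here is a minimal PyTorch sketch under that assumption, returning the `(num_gt_boxes, num_anchors)` matrix the docstring asks for. It is one possible implementation, not the one used by the original reference.

```python
import torch


def calculate_iou(boxes_a: torch.Tensor, boxes_b: torch.Tensor) -> torch.Tensor:
    """Pairwise IoU between boxes given as (cx, cy, w, h).

    boxes_a: (N, 4), boxes_b: (M, 4) -> returns an (N, M) tensor.
    """
    # Convert center/size to corner coordinates.
    a_min = boxes_a[:, :2] - boxes_a[:, 2:4] / 2
    a_max = boxes_a[:, :2] + boxes_a[:, 2:4] / 2
    b_min = boxes_b[:, :2] - boxes_b[:, 2:4] / 2
    b_max = boxes_b[:, :2] + boxes_b[:, 2:4] / 2
    # Broadcast to (N, M, 2): the overlap rectangle of every pair.
    inter_min = torch.maximum(a_min[:, None, :], b_min[None, :, :])
    inter_max = torch.minimum(a_max[:, None, :], b_max[None, :, :])
    inter_wh = (inter_max - inter_min).clamp(min=0)
    inter = inter_wh[..., 0] * inter_wh[..., 1]
    area_a = (boxes_a[:, 2] * boxes_a[:, 3])[:, None]
    area_b = (boxes_b[:, 2] * boxes_b[:, 3])[None, :]
    return inter / (area_a + area_b - inter)


gt = torch.tensor([[0.0, 0.0, 4.0, 4.0]])
anchor_shapes = torch.tensor([[0.0, 0.0, 4.0, 4.0],
                              [0.0, 0.0, 2.0, 2.0],
                              [0.0, 0.0, 8.0, 2.0]])
print(calculate_iou(gt, anchor_shapes))  # IoUs of 1, 0.25 and 1/3
```

With every center at (0, 0) the intersection reduces to min(w) * min(h), so only the box shapes decide which anchor wins the `argmax`.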
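Copyright notice: This is an original article by CSDN blogger 「刀么克瑟拉莫」, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.
Original link: https://blog.csdn.net/random_repick/article/details/122565143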