YoloV3 案例

学习目标

  • 知道YoloV3模型结构及构建方法
  • 知道数据处理方法
  • 能够利用yoloV3模型进行训练和预测

1.数据获取

一部分是网络数据,可以是开源数据,也可以通过百度、Google图片爬虫得到

在接下来的课程中我们就使用标注好的数据进行模型训练,模型预测。使用的工程如下所示:

主要内容是:

1.config中是网络的配置信息:anchors,类别信息

2.core中是损失函数计算,网络预测的内容

3.dateset中是对数据的处理

4.model是对模型的构建

5.utils是一些辅助文件,包括anchor,类别信息的获取等

6.weights中保存了一个使用coco数据集训练的预训练模型

2.TFrecord文件

  • TFRecord是Google官方推荐使用的数据格式化存储工具,为TensorFlow量身打造的。
  • TFRecord规范了数据的读写方式,数据读取和处理的效率都会得到显著的提高。

首先导入工具包

from dataset.vocdata_tfrecord import load_labels,write_to_tfrecord

import os

将数据写入tfrecord中的流程是:

  1. 指定要写入的数据集路径
  2. 获取所有的XML标注文件
  3. 指定tfrecord的存储位置
  4. 获取图像的路径
  5. 将数据写入到tfrecord文件中
    # 指定要写入的数据集路径
    data_path = '/Users/dataset/VOCdevkit/VOC2007'
    
    all_xml = load_labels(data_path, 'train')
    
    tfrecord_path = 'voc_train.tfrecords'
    
    voc_img_path = os.path.join(data_path, 'JPEGImages')
    
    # 将数据写入
    write_to_tfrecord(all_xml, tfrecord_path, voc_img_path)

    2.3 读取TFRecord文件 

    导入工具包:

    # 读取tfrecords文件所需的工具包
    from dataset.get_tfdata import getdata
    
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    
    
    datasets = getdata("dataset/voc_val.tfrecords")

from matplotlib.patches import Rectangle

from utils.config_utils import read_class_names
classes = read_class_names("config/classname")

plt.figure(figsize=(15, 10))

i = 0
# 从datasets中选取3个样本,获取图像,大小,框的标注信息和类别信息
for image, width, height, boxes, boxes_category in datasets.take(3):

    plt.subplot(1, 3, i+1)

    plt.imshow(image)

    ax = plt.gca()

    for j in range(boxes.shape[0]):
        # 绘制框
        rect = Rectangle((boxes[j, 0], boxes[j, 1]), boxes[j, 2] -boxes[j, 0], boxes[j, 3]-boxes[j, 1], color='r', fill=False)
        # 将框显示在图像上
        ax.add_patch(rect)
  
        label_id = boxes_category[j]
        # 获取标准信息
        label = classes.get(label_id.numpy())

        ax.text(boxes[j, 0], boxes[j, 1] + 8, label,color='w', size=11, backgroundcolor="none")
    # 下一个结果
    i += 1
# 显示图像
plt.show()

3. 数据处理 

 # 输入:原图像及图像上的标准框 # 输出:将尺度调整后的图像,及相应的目标框 image,bbox = preprocess(oriimage,oribbox,input_shape=(416,416))

对读取的数据进行处理并绘制结果


from dataset.preprocess import preprocess as ppro

plt.figure(figsize=(15,10))

i = 0
for image,width,height,boxes,boxes_category in datasets.take(3):

    image,boxes = preprocess(image,boxes)

    plt.subplot(1,3,i+1)

    plt.imshow(image[0])

    ax = plt.gca()
    for j in range(boxes.shape[0]):
        rect = Rectangle((boxes[j, 0], boxes[j, 1]), boxes[j, 2] -boxes[j, 0], boxes[j, 3]-boxes[j, 1], color='r', fill=False)
        ax.add_patch(rect)

        label_id = boxes_category[j]
        label = classes.get(label_id.numpy())
        ax.text(boxes[j, 0], boxes[j, 1] + 8, label,color='w', size=11, backgroundcolor="none")
    i+=1
plt.show()

 4.模型构建

 # 导入工具包 from model.yoloV3 import YOLOv3

# 模型实例化:指定输入图像的大小,和类别数

yolov3 = YOLOv3((416,416,3),80)

# 获取模型架构

yolov3.summary()

5.模型训练

 在计算损失函数时使用core.loss来完成:

# 导入所需的工具包 from core.loss import Loss

# 实例化 yolov3_loss = Loss((416,416,3),80)

 # 损失输入 yolov3_loss.inputs

 # 损失输出 yolov3_loss.outputs

 6.正负样本的设定

  • 正样本:首先计算目标中心点落在哪个grid上,然后计算这个grid对应的3个先验框(anchor)和目标真实位置的IOU值,取IOU值最大的先验框和目标匹配。那么该anchor 就负责预测这个目标,那这个anchor就作为正样本,将其置信度设为1,其他的目标值根据标注信息设置。
  • 负样本:所有不是正样本的anchor都是负样本,将其置信度设为0,参与损失计算,其它的值不参与损失计算,默认为0。

 将目标值绘制在图像上

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# 1.获取类别信息
from utils.config_utils import read_class_names
classes = read_class_names('config/classname')

plt.figure(figsize=(15,10))

for image,width,height,boxes,boxes_category in datasets.take(1):
    # 4.显示图像:plt.imshow()
    plt.imshow(image)
    # 5.显示box,遍历所有的bbox,rectange进行绘制
    ax = plt.gca()
    for j in range(boxes.shape[0]):
        rect = Rectangle((boxes[j, 0], boxes[j, 1]), boxes[j, 2] -boxes[j, 0], boxes[j, 3]-boxes[j, 1], color='r', fill=False)
        ax.add_patch(rect)
        # 6.显示类别
        label_id = boxes_category[j]
        label = classes.get(label_id.numpy())
        ax.text(boxes[j, 0], boxes[j, 1] + 8, label,color='w', size=11, backgroundcolor="none")
    # 7.绘制正样本的anchor的目标值
    anchor = label1[12, 12,0,0:4].numpy()
    rect2 = Rectangle((anchor[0]-anchor[2]/2, anchor[1]-anchor[3]/2), anchor[2], anchor[3],color='g', fill=False)
    ax.add_patch(rect2)
plt.show()

7. 模型训练

1、加载数据集:我们在这里使用VOC数据集,所以需要从TFrecord文件中加载VOC数据集

2、模型实例化:加载yoloV3模型和损失函数的实现

3、模型训练:计算损失函数,使用反向传播算法对模型进行训练

 # 导入 from dataset.preprocess import dataset

# 设置batch_size batch_size=1

# 获取训练集数据,并指定batchsize,返回训练集数据

trainset = dataset("dataset/voc_train.tfrecords",batch_size)

 在yoloV3模型和损失函数的计算进行实例化

# V3模型的实例化,指定输入图像的大小,即目标检测的类别个数
yolov3 = YOLOv3((416, 416, 3,), 20)
yolov3_loss = Loss((416,416,3), 20)

 8.模型训练

模型训练也就是要使用损失函数,进行反向传播,利用优化器进行参数更新,训练的流程是:

1、指定优化器:在这里我们使用加动量的SGD方法

2、设置epoch,进行遍历获取batch数据送入网络中进行预测

3、计算损失函数,使用反向传播更新参数,我们使用tf.GradientTape实现:

# 1、定义优化方法
optimizer = tf.keras.optimizers.SGD(0.1,0.9)

for epoch in range(300):
    loss_history = []
    # 遍历每一个batch的图像和目标值,进行更新
    for (batch, inputs) in enumerate(trainset):
        images, labels = inputs
        # 3.计算损失函数,使用反向传播更新参数
        # 3.1 定义上下文环境
        with tf.GradientTape() as tape:
   
            outputs = yolov3(images)
    
            loss = yolov3_loss([*outputs, *labels])
    
            grads = tape.gradient(loss, yolov3.trainable_variables)

            optimizer.apply_gradients(zip(grads, yolov3.trainable_variables))
  
            info = 'epoch: %d, batch: %d ,loss: %f'%(epoch, batch, np.mean(loss_history))
            print(info)
            loss_history.append(loss.numpy())
yolov3.save('yolov3.h5')

版权声明:本文为CSDN博主「AI-创造美好未来」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_53226226/article/details/122690814

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

目标检测指标计算

一、指标 True Positive,TP :预测为正样本(Positive),实际为正样本,则True,预测正确。 True Negative,TN &#x

Deep Learning 目标检测

对检测到的结果进行解析  #----------------------------目标检测*解析字典result------------------------------------------- from numpy import arr

YOLOV3预选框验证

对于一个输入图像,比如416*416*3,相应的会输出 13*13*3 26*26*3 52*52*3 10647 个预测框。我们希望这些预测框的信息能够尽量准确的反应出哪些位置存在对象,是哪种对