Code Practice: Faster R-CNN

Environment and setup:

Windows 10
Anaconda3
PyTorch 1.2
CUDA 10.0
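
To confirm that a local setup matches the versions listed above, a quick check like the following may help. This is a minimal sketch and not part of the repository; the function name describe_environment is made up for illustration, and the PyTorch fields are left as None when torch is not installed.

```python
import importlib.util
import platform
import sys


def describe_environment():
    """Collect the version information relevant to the setup listed above."""
    info = {
        "os": platform.platform(),
        "python": sys.version.split()[0],
        "torch": None,
        "cuda_build": None,
        "cuda_available": None,
    }
    # Only import torch if it is actually installed.
    if importlib.util.find_spec("torch") is not None:
        import torch

        info["torch"] = torch.__version__
        info["cuda_build"] = torch.version.cuda  # None for CPU-only builds
        info["cuda_available"] = torch.cuda.is_available()
    return info


for key, value in describe_environment().items():
    print("{}: {}".format(key, value))
```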

$ git clone https://github.com/supernotman/Faster-RCNN-with-torchvision.git
$ cd Faster-RCNN-with-torchvision
$ pip install -r requirements.txt

Preparing the training data

Download the COCO 2017 dataset from the following addresses:
http://images.cocodataset.org/zips/train2017.zip
http://images.cocodataset.org/annotations/annotations_trainval2017.zip
http://images.cocodataset.org/zips/val2017.zip
http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
http://images.cocodataset.org/zips/test2017.zip
http://images.cocodataset.org/annotations/image_info_test2017.zip
https://blog.csdn.net/watermelon1123/article/details/100095940
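
The archives can also be fetched and unpacked from Python. The sketch below uses only the standard library; the helper names archive_name and download_and_extract and the directory arguments are illustrative assumptions, not something the repository provides. The code only defines the helpers and downloads nothing until download_and_extract is called.

```python
import os
import urllib.request
import zipfile
from urllib.parse import urlparse

# The six COCO 2017 archives listed above.
COCO_2017_URLS = [
    "http://images.cocodataset.org/zips/train2017.zip",
    "http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
    "http://images.cocodataset.org/zips/val2017.zip",
    "http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip",
    "http://images.cocodataset.org/zips/test2017.zip",
    "http://images.cocodataset.org/annotations/image_info_test2017.zip",
]


def archive_name(url):
    """Return the file name part of a download URL, e.g. 'val2017.zip'."""
    return os.path.basename(urlparse(url).path)


def download_and_extract(url, download_dir, extract_dir):
    """Download one archive (skipped if already present) and unzip it."""
    os.makedirs(download_dir, exist_ok=True)
    os.makedirs(extract_dir, exist_ok=True)
    zip_path = os.path.join(download_dir, archive_name(url))
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(extract_dir)
    return zip_path
```

A call such as download_and_extract(url, "downloads", "coco/2017") for each entry of COCO_2017_URLS would be one way to use it. The training image archive alone is many gigabytes, so a download manager with resume support may be more practical for the larger files.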
After downloading, arrange the data in the following structure:

  coco/
    2017/
      annotations/
      test2017/
      train2017/
      val2017/
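
Before starting a long training run, it can be worth checking that the layout above is in place. The following is a minimal sketch that assumes the coco/ folder sits in the current working directory; missing_coco_dirs is a hypothetical helper name.

```python
from pathlib import Path

# Subdirectories expected under coco/2017/, as shown in the layout above.
EXPECTED_SUBDIRS = ("annotations", "test2017", "train2017", "val2017")


def missing_coco_dirs(root="coco/2017"):
    """Return the expected subdirectories that do not exist under root."""
    root = Path(root)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]


missing = missing_coco_dirs()
if missing:
    print("Missing under coco/2017:", ", ".join(missing))
else:
    print("Dataset layout looks complete.")
```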

Model training

$ python -m torch.distributed.launch --nproc_per_node=$gpus --use_env train.py --world-size $gpus --b 4
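
In the command above, $gpus is a shell variable holding the number of GPUs to use. With --use_env, torch.distributed.launch hands each worker process its identity through the environment variables RANK, WORLD_SIZE, and LOCAL_RANK instead of a --local_rank command-line argument. The sketch below shows how a training script might read them. It is an illustration only, not code taken from the repository's train.py, and read_distributed_env is a hypothetical helper.

```python
import os


def read_distributed_env(environ=None):
    """Read the process identity set by the launcher, if any."""
    environ = os.environ if environ is None else environ
    if "RANK" not in environ or "WORLD_SIZE" not in environ:
        # Not started by the launcher: fall back to single-process settings.
        return {"distributed": False, "rank": 0, "world_size": 1, "local_rank": 0}
    return {
        "distributed": True,
        "rank": int(environ["RANK"]),              # global index of this process
        "world_size": int(environ["WORLD_SIZE"]),  # total number of processes
        "local_rank": int(environ.get("LOCAL_RANK", 0)),  # GPU index on this machine
    }


print(read_distributed_env())
```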

During training, an evaluation result is reported once per epoch, in the following form:

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.352
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.573
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.375
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.207
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.387
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.448
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.296
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.474
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.498
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.312
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.538
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631

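Here AP is average precision and AR is average recall. The first row is the overall mAP of the trained model (AP averaged over IoU thresholds from 0.50 to 0.95), and rows four, five, and six give the mAP for small, medium, and large objects respectively.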
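
To make the terms in the table concrete, the sketch below computes the IoU of two boxes, lists the ten thresholds behind the "IoU=0.50:0.95" rows, and assigns an object to a size group. The function names are illustrative. The 32x32 and 96x96 pixel limits are the standard COCO area boundaries; how values exactly on a boundary are handled is simplified here.

```python
def box_iou(box_a, box_b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    ix1 = max(box_a[0], box_b[0])
    iy1 = max(box_a[1], box_b[1])
    ix2 = min(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# Thresholds averaged by the "IoU=0.50:0.95" rows: 0.50, 0.55, ..., 0.95
IOU_THRESHOLDS = [round(0.50 + 0.05 * i, 2) for i in range(10)]


def size_bucket(area):
    """COCO size group for an object area measured in pixels."""
    if area < 32 ** 2:
        return "small"
    if area < 96 ** 2:
        return "medium"
    return "large"


# A prediction shifted one pixel to the right of a 10x10 ground-truth box.
iou = box_iou((0, 0, 10, 10), (1, 0, 11, 10))
passed = sum(iou >= t for t in IOU_THRESHOLDS)
print("IoU = {:.3f}, counts as a match at {} of 10 thresholds".format(iou, passed))
```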

Single-image detection

$ python detect.py --model_path result/model_13.pth --image_path imgs/1.jpg

model_path is the path to the trained model checkpoint, and image_path is the path to the test image.
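
For readers writing their own detection script, the two flags can be parsed with argparse as sketched below. This is an assumption about how such a command line could be declared, not a copy of the repository's detect.py; in particular, whether the real script makes the flags required or gives them defaults is not known here.

```python
import argparse


def build_parser():
    """Declare the two command-line flags used in the command above."""
    parser = argparse.ArgumentParser(description="Run Faster R-CNN on one image")
    parser.add_argument("--model_path", type=str, required=True,
                        help="path to the trained model checkpoint (.pth)")
    parser.add_argument("--image_path", type=str, required=True,
                        help="path to the image to run detection on")
    return parser


# Parse the same arguments as in the example command.
args = build_parser().parse_args(
    ["--model_path", "result/model_13.pth", "--image_path", "imgs/1.jpg"]
)
print(args.model_path, args.image_path)
```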
The assets folder in the repository contains detection results for 11 images selected from the COCO 2017 test set.

Core code:

"""
Object detection with Faster R-CNN
"""

import os
import time
import torch
import torchvision.transforms as transforms
import torchvision
from PIL import Image
from matplotlib import pyplot as plt

# Directory containing this script
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# COCO class names, indexed by category ID ('N/A' marks unused IDs)
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


if __name__ == "__main__":

    # Path to the image to run detection on
    path_img = os.path.join(BASE_DIR, "moto.jpg")

    # Preprocessing: convert the PIL image to a CHW float tensor
    preprocess = transforms.Compose([
        transforms.ToTensor(),
    ])

    input_image = Image.open(path_img).convert("RGB")
    img_chw = preprocess(input_image)

    # Load the pretrained model
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.eval()

    # Move the input and the model to the selected device (GPU if available)
    img_chw = img_chw.to(device)
    model.to(device)

    # Forward pass
    input_list = [img_chw]
    with torch.no_grad():
        tic = time.time()
        print("input img tensor shape:{}".format(input_list[0].shape))
        output_list = model(input_list)
        output_dict = output_list[0]
        print("pass: {:.3f}s".format(time.time() - tic))
        # Print the raw outputs
        for k, v in output_dict.items():
            print("key:{}, value:{}".format(k, v))

    # Extract the results and move them to the CPU
    out_boxes = output_dict["boxes"].cpu()
    out_scores = output_dict["scores"].cpu()
    out_labels = output_dict["labels"].cpu()

    # Visualization
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(input_image, aspect='equal')

    num_boxes = out_boxes.shape[0]
    max_vis = 400
    thres = 0.6

    # Draw every box whose score reaches the threshold
    for idx in range(0, min(num_boxes, max_vis)):
        score = float(out_scores[idx])
        bbox = out_boxes[idx].numpy()
        class_name = COCO_INSTANCE_CATEGORY_NAMES[int(out_labels[idx])]

        if score < thres:
            continue

        ax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False,
                                   edgecolor='red', linewidth=3.5))
        ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')
    # Set the title once, outside the loop
    ax.set_title("A simple Faster R-CNN test", fontsize=28, color='blue')
    plt.show()
    plt.close()
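
The score thresholding in the drawing loop above can also be factored into a small helper that returns the kept detections as plain Python data, which is convenient for logging or saving results. This is a sketch operating on ordinary lists; filter_detections is a hypothetical name, and the sample values are made up purely to show the call.

```python
def filter_detections(boxes, scores, labels, class_names, score_thres=0.6):
    """Keep detections whose score reaches score_thres.

    boxes: list of [x1, y1, x2, y2]; scores: list of floats;
    labels: list of integer indices into class_names.
    """
    kept = []
    for box, score, label in zip(boxes, scores, labels):
        if score < score_thres:
            continue
        kept.append({
            "class_name": class_names[label],
            "score": score,
            "box": list(box),
        })
    return kept


# Made-up sample values, for illustration only.
names = ["__background__", "person", "bicycle"]
sample = filter_detections(
    boxes=[[10, 20, 110, 220], [5, 5, 50, 50]],
    scores=[0.92, 0.31],
    labels=[1, 2],
    class_names=names,
)
print(sample)
```

With the script above, the inputs could be obtained as out_boxes.tolist(), out_scores.tolist(), and out_labels.tolist(), together with COCO_INSTANCE_CATEGORY_NAMES and thres.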

Results

[Figures: two detection result images]

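Copyright notice: This is an original article by CSDN blogger 「刹那永恒HB」, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.
Original link: https://blog.csdn.net/qq_43165081/article/details/117885432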