基于mmdetection 旋转目标检测(OBB detection)+DOTA数据集&自定义数据集+配docker

这两周得益于组里的任务需求,肝了一个遥感类的飞机旋转框目标检测,在给定的4096*4096的大尺度分辨率图片上去识别检测飞机。

经过模型检测后输出结果图如下图所示:
img_output_example
可以看到最终的结果还是不错的,通过该任务的锻炼,自己对一般的目标检测工程上的问题可以说基本走了个遍,能够完成基本的目标检测、旋转框的目标检测任务等。在这里简单分享一下任务的心得。

核心思想

	一:基于mmdetection的目标检测框架
	二:DOTA数据集格式
	三:任务相关

一:基于mmdetection的目标检测框架

(1)mmdetection相关
做目标检测现在已经绕不开mmdetection了,该框架是一个基于Pytorch实现的深度学习目标检测工具箱,与MMCV进行搭配使用。目前许多SOTA的模型都在上面进行更改。
一些教程:
MMDetection中文文档—详解:https://zhuanlan.zhihu.com/p/101225733
数据处理过程:https://blog.csdn.net/u014453898/article/details/107701094
configs相关:https://zhuanlan.zhihu.com/p/102072353

由于任务是检测目标的旋转框,最终选择了s2anet模型。对于一般的目标检测任务(如coco,VOC等)则尝试使用了VarifocalNet。
s2anet:https://github.com/csuhan/s2anet
varifocalNet:https://github.com/hyz-xmaster/VarifocalNet

mmdetection使用的关键步骤在于定义config文件,框架会从config中定义好的字段中去加载使用、定义相应的函数、模型、数据加载、数据预处理、优化器、路径等。一般而言需要修改模型中的num_classes,其值为数据集中分类的类别个数(较老的mmdet的版本需要个数+1,即加一个背景类)、数据集的加载路径、work_dir等。

(2)自定义数据集
自定义的数据集类定义在mmdet/datasets中,一般而言是将数据集转换成COCO(or voc)格式,然后继承
已经写好的CocoDataset(CustomDataset)。将类中的CLASSES属性修改成自定义数据集中的类别。在自定义的类的上一行需要加入@DATASETS.register_module (mmdet版本不一样可能有所区别) 用来注册已经自定义好的类,同时需要在 datasets/__init__py中加入自定义的数据集类。如下图所示:
自定义数据集
同时可以自己重写evaluation函数满足自己的评估需求。

二:DOTA数据集格式

(1)DOTA数据集简介
DOTA数据集是一个比较著名的遥感类高分辨率数据集,包括v1.0,v1.5,v2.0三个版本的数据,一共30G左右。采用旋转框的标记方式,标记四个顶点八个坐标得到不规则四边形。具体实现是,首先标注出一个初始点,为(x1,y1),然后顺时针方向依次标注2、3、4共4个点。如下图ace所示。bcd是传统的水平标注方法,有大量的重叠区域。
DOTA数据集标注样例
标注文件的格式如下图所示:DOTA_label
其中(x1,y1)用于表示OBB的顶点起始位置,四个顶点按照顺时针进行排列。category表示目标种类,difficult表示实例的检测难度。

DOTA_devkit是官方给的配套的数据处理的配套文件,包括绘制目标边框的示例,剪裁数据集、合并检测结果、评估模型性能等。
DOTA_devkit官方github:https://github.com/CAPTAIN-WHU/DOTA_devkit

下面这个DOTA_devkit的整理(踩坑记录)一文里详细介绍了DOTA_devkit的各个py文件的作用、代码中的实际应用、剪裁、合并策略等,介绍的比较全面。
DOTA_devkit的整理(踩坑记录):https://zhuanlan.zhihu.com/p/355862906

(2)标签格式转换
由于任务所给的数据集同样是遥感图片,且是大分辨率图片,目标也是标注OBB,因此可以类比于DOTA数据集的操作。第一步是将任务的数据集从labelme的标注格式转换至DOTA标注格式,然后采用官方给的DOTA_devkit进行图片的预处理操作。
同时,DOTA_devit的dota_evaluation_task1.py中的voc_eval()、即数据集的评估函数中,还需要提供一个测试图片的name_list和储存剪裁前的图片注释文件夹label_txt,其中name_list需要自己写一个脚本生成,如下所示:

import os
dir="/home/dataset/airplaneDOTA/airplane/val/images"
img_name_list=[]
for root, dirs, files in os.walk(dir):
  for file in files:
   # print os.path.join(root,file)
   img_name=file.split(".")[0]
   img_name_list.append(img_name)

write_path="/home/dataset/airplaneDOTA/airplane/val/test_image_list.txt"
#写入文本
with open(write_path,"w") as f:
    for i in range(len(img_name_list)):
        f.write(img_name_list[i]) 
        f.write("\n")
print("end")

在s2anet中,则是在evaluation中提供。

evaluation = dict( gt_dir='/home/dataset/airplaneDOTA/airplane/val/labelTxt/',# change it to valset for offline validation
imagesetfile='/home/dataset/airplaneDOTA/airplane/val/test_image_list.txt')

(3)图片裁剪策略
由于训练成本的问题,难以直接将4096x4096的图片直接输入到网络中训练,因此需要将图片进行相应的调整。最直观的方法是在图片预处理中将图片直接resize成1024x1024(or更低)的大小,但是这种方法会使一些本就size较小的目标在训练时更小,从而导致模型的训练性能受损,因此该种方法在实际应用中不进行考虑,而是选择合适的裁剪策略。

图片剪裁策略是将4096x4096的图片裁剪成1024x1024的图片。单纯的将4096x4096的图片按比例裁剪成16张1024x1024的图片明显不是个好的方法,因为如果目标正好位于两张图片的交界位置,那么两个图片各有一半的目标会大大影响裁剪性能。一种好的策略是使得裁剪后的图片有部分重合的像素,这样能够很大概率保证待检测的目标能够在某张或者多张裁剪后的图片里完整,同时也能够起到数据增强的目的。在实际应用中,剪裁图片重合的面积越大,实际效果越好(本任务中每条边重合512个像素,50%)。

同时,为了帮助模型训练多尺度的目标,将裁剪后的图片缩放至0.5倍、1.5倍并进行存储,使得数据集中包含同一张图片的0.5、1、1.5倍三种尺度比例的图片。

然后这种策略会导致数据集扩充比较大,一张4096x4096能够裁剪出100多张图片,但是该种方式会使得模型性能提高很多。使用的是ReDet中实现的prepare_dota1_5_v2.py的代码来进行裁剪,github如下:

https://github.com/csuhan/ReDet/blob/master/DOTA_devkit/prepare_dota1_5_v2.py

三:任务相关

(1)自定义数据集格式转换
如前文所言,对于任务所给的自定义数据集,最好的方法是将其转换成现有的写好的数据集标签格式进行训练,这样就免去自定义dataloader的烦恼。如下是常见的目标检测的格式转换code:

目标检测常见数据格式转换:https://github.com/spytensor/prepare_detection_dataset
格式转换完成后,自定义类里只需要继承相应的类,并在mmdet中register一下即可。

(2)mmdetection的版本以及适配CUDA的问题
mmdetecion目前来说已经趋于稳定,但是之前每个大版本之间还是差的有点多的,比如0.几版本都不包含@DATASETS.register_module(),只能在执行setup.py的时候进行注册(无法动态加载模块)。同时还需要适配不同版本的mmcv(mmcv-full)。同时,由于mmdetecion框架安装的时候需要自拟脚本,对CUDA的版本、torch的版本还有一定的要求。笔者在跑s2anet的时候使用的是官方的10.1版本的cuda以及1.3版本的torch,但由于任务的docker要求,需要适配cuda11与torch1.7版本,在”升级的时候“需要修改一下mmdetection安装时的setup.py文件以及相应的torch版本的不同带来的问题,花了不少时间改了很多bug才适配完成。可以参考以下网址:
https://github.com/open-mmlab/mmdetection/issues/3363

https://github.com/pytorch/pytorch/issues/52669

(3)图片预测
s2anet中给出了图片inference的代码示例。给定一张图片, 返回经过检测后的画有bounding box的图片。将待检测图片存入至img_dir的路径中,在out_dir中给出预测的图片。其中图片的预处理方式则采用的是config中data.test中的方式。

import argparse
import os
import os.path as osp
import pdb
import random

import cv2
import mmcv
from mmcv import Config

from mmdet.apis import init_detector, inference_detector
from mmdet.core import rotated_box_to_poly_single
from mmdet.datasets import build_dataset

def show_result_rbox(img,
                     detections,
                     class_names,
                     scale=1.0,
                     threshold=0.2,
                     colormap=None,
                     show_label=False):
    assert isinstance(class_names, (tuple, list))
    if colormap:
        assert len(class_names) == len(colormap)
    img = mmcv.imread(img)
    color_white = (255, 255, 255)

    for j, name in enumerate(class_names):
        if colormap:
            color = colormap[j]
        else:
            color = (random.randint(0, 256), random.randint(0, 256), random.randint(0, 256))
        try:
            dets = detections[j]
        except:
            pdb.set_trace()
        # import ipdb;ipdb.set_trace()
        for det in dets:
            score = det[-1]
            det = rotated_box_to_poly_single(det[:-1])
            bbox = det[:8] * scale
            if score < threshold:
                continue
            bbox = list(map(int, bbox))
     #       print(bbox)
            #[2482, 2230, 2550, 2239, 2542, 2301, 2474, 2292]坐标
            for i in range(3):
                cv2.line(img, (bbox[i * 2], bbox[i * 2 + 1]), (bbox[(i + 1) * 2], bbox[(i + 1) * 2 + 1]), color=color,
                         thickness=2, lineType=cv2.LINE_AA)
            cv2.line(img, (bbox[6], bbox[7]), (bbox[0], bbox[1]), color=color, thickness=2, lineType=cv2.LINE_AA)
            if show_label:
                cv2.putText(img, '%s %.3f' % (class_names[j], score), (bbox[0], bbox[1] + 10),
                            color=color_white, fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.5)
    return img


def save_det_result(config_file, out_dir, checkpoint_file=None, img_dir=None, colormap=None):
    cfg = Config.fromfile(config_file)
    data_test = cfg.data.test
    dataset = build_dataset(data_test)
    classnames = dataset.CLASSES
  #  print(classnames)
    # use checkpoint path in cfg
    if not checkpoint_file:
        checkpoint_file = osp.join(cfg.work_dir, 'latest.pth')
  
    # use testset in cfg
    if not img_dir:
        img_dir = data_test.img_prefix

    model = init_detector(config_file, checkpoint_file, device='cuda:0')

    img_list = os.listdir(img_dir)
    for img_name in img_list:
        img_path = osp.join(img_dir, img_name)
        img_out_path = osp.join(out_dir, img_name)
        result = inference_detector(model, img_path)
        img = show_result_rbox(img_path,
                               result,
                               classnames,
                               scale=1.0,
                               threshold=0.5,
                               colormap=colormap)
  #      print(result)
        cv2.imwrite(img_out_path, img)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='inference demo')
    parser.add_argument('--config_file', help='input config file',default="s2anet_dota.py")
    parser.add_argument('--model', help='pretrain model',default="./work_dir/s2anet/latest.pth")
    parser.add_argument('--img_dir', help='img dir',default="example")
    parser.add_argument('--out_dir', help='output dir',default="example_result")
    args = parser.parse_args()

    dota_colormap = [
        (54, 67, 244),
        (99, 30, 233),
        (176, 39, 156),
        (183, 58, 103),
        (181, 81, 63),
        (243, 150, 33),
        (212, 188, 0),
        (136, 150, 0),
        (80, 175, 76),
        (74, 195, 139),
        (57, 220, 205),
        (59, 235, 255),
        (0, 152, 255),
        (34, 87, 255),
        (72, 85, 121)]

    hrsc2016_colormap = [(212, 188, 0)]
    save_det_result(args.config_file, args.out_dir, checkpoint_file=args.model, img_dir=args.img_dir,
                    colormap=dota_colormap)

(4)docker相关
模型跑通以后本以为打docker是件相对轻松的事情,结果最终也花了小两三天的功夫才整通。主要是需要使用官方指定的docker镜像,里面包含指定的torch版本与cuda版本(torch1.7 & cuda11)。由于模型是基于mmdetection完成,之前在服务器中对setup.py以及模型代码进行了相应的修改,使得代码已经在torch1.7 &cuda11中跑通,这里只介绍mmdetection安装docker时的一些坑。
1.适配python版本
之前没考虑环境的python版本问题。指定镜像中使用的是python3.6的版本,而在服务器中使用的是python3.7的版本,导致mmdetection编译的文件在docker中无法跑通。最终在服务器中又重新建了个环境,重新编译了一遍setup.py。
2.docker中不用编译setup.py
配置docker的环境时想当然的去执行setup.py去配置环境,报了各种奇奇怪怪的bug,一个一个改,折腾了好久后最后发现指定的镜像版本是runtime的,压根**不支持编译!**既然无法编译该文件那么是不是就说明无法配置环境?并不是,最终的做法是在服务器上执行python setup.py develop,然后把build文件夹之前copy到打docker的地方,(我们也把mmdet文件夹直接copy过来,代码在本机服务器上运行后里面有pyc文件),然后把其他需要的包pip后就能运行了。
3. 线上CUDA error
在线下用官方的test测试通过后将镜像push到线上,在线上测试时报了error: error
本来我们就是在竞赛最后两周才整代码,最后一周才提交测试,而且才发现一周就两次test的机会,报错还直接占了一次(坑!)。这个错误原因是mmdetection的问题,可能是线上的gpu的算力不支持相应的pytorch版本导致的。需要在编译setup.py时,将python setup.py develop 改成TORCH_CUDA_ARCH_LIST=“3.5 3.7 5.0 5.2 6.0 6.1 7.0 7.5” python setup.py develop,最终解决了该问题。

最后附上我们打docker的文件:

# 基础镜像, cuda为11.0,ubuntu18.04
FROM image.rsaicp.com/base/cuda:11.0-cudnn8-runtime-ubuntu18.04

# 将程序复制到容器内的/work路径下
COPY .  /work
WORKDIR /work

# 配置程序依赖环境
RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cmake \
         curl \
         ca-certificates \
         libjpeg-dev \
         libpng-dev \
         libgl1-mesa-glx \
         libglib2.0-0 \
         libsm6 \
         libxext6 \
         libxrender-dev \
         python3 \
         python3-dev \
         python3-pip && \
     rm -rf /var/lib/apt/lists/* \
     # 升级pip 安装setup tools
     && pip3 install -U pip setuptools \
     # 安装torch
     && pip3 install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html \
     # 安装依赖
     && pip3 install -r /work/requirements.txt \
     && pip3 install opencv-python-headless 
     # 容器启动命令
    CMD ["python3", "-u", "/work/main.py"]

(5)线上测试相关
在线上测试的时候我们发现,前排队伍模型的预测时间基本在一个小时以上,而我们的预测时间只需要十分钟。我们的模型在最后测试时没有进行图片的剪裁(虽然在训练和验证时有剪裁),而是直接将图片输入模型进行处理,最终导致模型的精度不高。可行的做法是类似模型训练与验证时的操作,对测试目录下的图片首先进行剪裁,进行测试后再将结果进行合并,合理利用游戏规则。

总结

通过此次任务对于目标检测、尤其是遥感目标检测的一般方法有了比较清晰的认识、对于mmdetecion也框架也有了比较深刻的理解。对目标检测各个常见数据都跑了一遍,数据处理了一遍,也掌握了自定义数据集的数据处理方法,未来可以很快上手,完成相似的任务。

版权声明:本文为CSDN博主「Big Watermonster~」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_43407090/article/details/120973936

Big Watermonster~

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐