基于mmdetection训练Swin Transformer Object Detection

mmdetection官方文档
环境搭建
docker
找了一个torch版本为1.5.1+cu101的docker环境,然后安装mmdetection环境

pip install mmcv-full
git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
cd Swin-Transformer-Object-Detection-master
pip install -r requirements/build.txt
pip install -v -e .

安装apex

git clone https://github.com/NVIDIA/apex
cd apex
pip install -r requirements.txt
python setup.py install --cpp_ext

安装成功

Processing dependencies for apex==0.1
Finished processing dependencies for apex==0.1
  • backbone:mmdet/models/backbones
  • neck:mmdet/models/necks
  • head:mmdet/models/roi_heads
  • BBox Assigner:mmdet/core/bbox/assigners
  • BBox Sampler:mmdet/core/bbox/samplers
  • BBox Encoder:mmdet/core/bbox/coder
  • BBox Decoder:mmdet/core/bbox/coder
  • Loss:mmdet/models/losses
  • BBox PostProcess:mmdet/core/post_processing

在"Swin-Transformer-Object-Detection-master/configs/swin/"目录下,可以看到模型文件,选择对应的修改
以"cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py"为例:

# head为例
roi_head=dict(
        bbox_head=[
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=15,  # 修改类别数量
# 根据gpu的数量,使用合适的BN
# norm_cfg=dict(type='SyncBN', requires_grad=True),
norm_cfg=dict(type='BN', requires_grad=True),

# 调整学习率等相关参数,lr = 0.00125*batch_size
optimizer = dict(_delete_=True, type='AdamW', lr=0.00125, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
                                                 'relative_position_bias_table': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.)}))
# 修改epoch
runner = dict(type='EpochBasedRunner', max_epochs=20)                                                  
# 不适用fp16,将use_fp16改为False
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=False,
)

在"configs/base/datasets/coco_instance.py"中根据需要修改

# 修改数据集的类型,路径
dataset_type = 'CocoDataset'
data_root = '/home/coco/'

# 修改img_size等参数,CUDA out of memory时可以修改
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    # 原本为1333*800
    #dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='Resize', img_scale=(416, 416), keep_ratio=True),

# 修改batch_size
data = dict(
    samples_per_gpu=1, # 每块GPU上的sample个数,batch_size = gpu数目*该参数
    workers_per_gpu=1, # 每块GPU上的workers的个数
    # 以train为例
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json', # 标注路径
        img_prefix=data_root + 'train2017/', # 训练图片路径
        pipeline=train_pipeline),

修改类别:mmdet/datasets/coco.py和 mmdet/core/evaluation/class_names.py文件

class CocoDataset(CustomDataset):

    #CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    #           'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    #           'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    #           'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
    #           'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    #           'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    #           'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    #           'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    #           'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    #           'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    #           'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
    #           'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    #           'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
    #           'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
    CLASSES = ('person', 'tool_vehicle', 'bicycle', 'motorbike', 'pedal_tricycle', 'car', 'passenger_car',
         'truck', 'police_car', 'ambulance', 'bus', 'dump_truck', 'tanker', 'roadblock', 'fire_car')
def coco_classes():
    return ['person', 'tool_vehicle', 'bicycle', 'motorbike', 'pedal_tricycle', 'car', 'passenger_car',
         'truck', 'police_car', 'ambulance', 'bus', 'dump_truck', 'tanker', 'roadblock', 'fire_car']

修改"./tools/train.py"文件

# 选取其中一种版本,单机版本 MMDataParallel、分布式(单机多卡或多机多卡)版本 MMDistributedDataParallel
parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')

模型预训练,权重加载、保存参数,config/base/default_runtime.py文件

checkpoint_config = dict(interval=1) # 每训练一个epoch,保存一次权重
load_from = None # 加载backbone权重
resume_from = None # 继续训练

训练模型
使用编号为3的单个gpu训练

python ./tools/train.py configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py --gpu-ids 3

使用多gpu训练

tools/dist_train.sh configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py 4

训练Log及权重
保存在"Swin-Transformer-Object-Detection-master/work_dirs/"中

coco测试

python tools/test.py configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py cascade_mask_rcnn_swin_small_patch4_window7.pth --eval segm

输出demo,输出为cls,x1,y1,x2,y2的txt格式

from argparse import ArgumentParser
from mmdet.apis import inference_detector, init_detector
import numpy as np
import os
from tqdm import tqdm

def main():
    parser = ArgumentParser()
    parser.add_argument('--img-path', default='/data/wj/test/',help='Image file')
    parser.add_argument('--config', default='../work_dirs/cascade_rcnn_x101_64x4d_fpn_20e_coco/cascade_rcnn_x101_64x4d_fpn_20e_coco.py' ,help='Config file')
    parser.add_argument('--checkpoint', default='../work_dirs/cascade_rcnn_x101_64x4d_fpn_20e_coco/latest.pth', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    args = parser.parse_args()
    imgs_path = args.img_path
    save_path = '../output/'

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    for img_path in tqdm(os.listdir(imgs_path)):
        img = os.path.join(imgs_path, img_path)
        result = inference_detector(model, img)
        bboxes = np.vstack(result)
        labels = [
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(result)
        ]
        labels = np.concatenate(labels)
        score_thr = args.score_thr
        if score_thr > 0:
            assert bboxes.shape[1] == 5
            scores = bboxes[:, -1]
            inds = scores > score_thr
            bboxes = bboxes[inds, :]
            labels = labels[inds]
        if len(bboxes) == 0:
            txt_path = os.path.join(save_path, '{}.txt'.format(img_path.split('.')[0]))
            with open(txt_path, 'w') as f:
                f.write("")
        for i, (bbox, label) in enumerate(zip(bboxes, labels)):
            bbox_int = bbox.astype(np.int32)
            x1, y1, x2, y2, conf = bbox_int
            txt_path = os.path.join(save_path, '{}.txt'.format(img_path.split('.')[0]))
            with open(txt_path, 'a') as f:
                f.write("{} {} {} {} {}\n".format(label, x1, y1, x2, y2))

踩过的坑及解决方案:
error with env var RANK
参考:
轻松掌握 MMDetection 整体构建流程(一)

版权声明:本文为CSDN博主「云端一散仙」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_44347020/article/details/116153242

云端一散仙

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

Pytorch—万字入门SSD物体检测

前言 由于初入物体检测领域,我在学习SSD模型的时候遇到了很多的困难。一部分困难在于相关概念不清楚,专业词汇不知其意,相关文章不知所云;另一部分困难在于网上大部分文章要么只是简要介绍了SS

YoloV5实战:手把手教物体检测——YoloV5

目录 摘要 训练 1、下载代码 2、配置环境 3、准备数据集 4、生成数据集 5、修改配置参数 6、修改train.py的参数 7、查看训练结果 测试 摘要 YOLOV5严格意义上说并不是YOLO的第五个版本&#xff0c