mmdetection官方文档
环境搭建
docker
找了一个torch版本为1.5.1+cu101的docker环境,然后安装mmdetection环境
# Install mmcv-full, then clone Swin-Transformer-Object-Detection and install it in editable mode.
# NOTE(review): mmcv-full must match the torch/CUDA build in the container (torch 1.5.1+cu101 here);
# an unpinned `pip install mmcv-full` may pick an incompatible wheel — confirm the version.
pip install mmcv-full
git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
# `git clone` creates the directory "Swin-Transformer-Object-Detection" (no "-master" suffix;
# that suffix only appears when the repository is downloaded as a GitHub zip archive).
cd Swin-Transformer-Object-Detection
pip install -r requirements/build.txt
pip install -v -e .
安装apex
# Install NVIDIA apex from source.
# NOTE(review): presumably needed by the Swin configs' fp16 optimizer hook (use_fp16) — confirm
# whether it is still required when use_fp16=False.
git clone https://github.com/NVIDIA/apex
cd apex
pip install -r requirements.txt
# Build and install apex with its C++ extensions enabled (--cpp_ext).
python setup.py install --cpp_ext
安装成功
Processing dependencies for apex==0.1
Finished processing dependencies for apex==0.1
mmdetection中各组件对应的源码目录:
- backbone:mmdet/models/backbones
- neck:mmdet/models/necks
- head:mmdet/models/roi_heads
- BBox Assigner:mmdet/core/bbox/assigners
- BBox Sampler:mmdet/core/bbox/samplers
- BBox Encoder:mmdet/core/bbox/coder
- BBox Decoder:mmdet/core/bbox/coder
- Loss:mmdet/models/losses
- BBox PostProcess:mmdet/core/post_processing
在"Swin-Transformer-Object-Detection-master/configs/swin/"目录下,可以看到模型文件,选择对应的修改
以"cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py"为例:
# Example: the RoI head (excerpt — only the first bbox_head entry is shown, closing brackets omitted)
roi_head=dict(
bbox_head=[
dict(
type='ConvFCBBoxHead',
num_shared_convs=4,
num_shared_fcs=1,
in_channels=256,
conv_out_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=15, # change to the number of classes; must equal len(CLASSES) (15 custom names)
# NOTE(review): bbox_head is a list (presumably one entry per cascade stage in this Cascade Mask R-CNN
# config); every entry, and the mask head, likely need num_classes updated too — confirm in the full config.
# Choose the BN type according to the number of GPUs: the commented-out SyncBN line is the multi-GPU
# option, plain BN is used here.
# norm_cfg=dict(type='SyncBN', requires_grad=True),
norm_cfg=dict(type='BN', requires_grad=True),
# Adjust the learning rate and related optimizer parameters; the author's rule: lr = 0.00125 * batch_size.
# NOTE(review): 0.00125 per image is mmdetection's SGD linear-scaling baseline (0.02 for 16 images);
# the upstream Swin AdamW configs use a much smaller lr (about 1e-4 for a total batch of 16) — verify
# before training, as this value may be far too high for AdamW.
# _delete_=True replaces the optimizer inherited from the base config instead of merging into it.
optimizer = dict(_delete_=True, type='AdamW', lr=0.00125, betas=(0.9, 0.999), weight_decay=0.05,
# No weight decay on absolute/relative position embeddings and normalization layers.
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
# Set the total number of training epochs.
runner = dict(type='EpochBasedRunner', max_epochs=20)
# Do not use fp16 (mixed precision): set use_fp16 to False.
# NOTE(review): presumably the fp16 path of DistOptimizerHook is what relies on apex — confirm.
optimizer_config = dict(
type="DistOptimizerHook",
update_interval=1,
grad_clip=None,
coalesce=True,
bucket_size_mb=-1,
use_fp16=False,
)
在"configs/_base_/datasets/coco_instance.py"中根据需要修改:
# Set the dataset type and the dataset root directory.
dataset_type = 'CocoDataset'
data_root = '/home/coco/'
# Adjust img_size and similar parameters; reducing the image scale is one way to deal with
# "CUDA out of memory". (Excerpt: the remaining pipeline steps are omitted.)
# NOTE(review): the "mstrain_480-800" Swin config may define its own train_pipeline that overrides
# this base one — confirm which file actually controls the resize.
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
# Originally 1333x800:
#dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='Resize', img_scale=(416, 416), keep_ratio=True),
# Set the batch size (excerpt: val/test entries and closing bracket omitted).
data = dict(
samples_per_gpu=1, # samples per GPU; total batch_size = number of GPUs * this value
workers_per_gpu=1, # data-loading workers per GPU
# The train split as an example
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json', # annotation file path
img_prefix=data_root + 'train2017/', # training image directory
pipeline=train_pipeline),
修改类别:需同时修改mmdet/datasets/coco.py(CocoDataset.CLASSES)和mmdet/core/evaluation/class_names.py(coco_classes())两个文件,两处的类别名称和顺序需保持一致,且类别数需与配置文件中的num_classes一致。
# Excerpt of mmdet/datasets/coco.py: only the CLASSES attribute is changed.
class CocoDataset(CustomDataset):
# Original 80 COCO class names, kept commented out for reference:
#CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
# 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
# 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
# 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
# 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
# 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
# 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
# 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
# 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
# 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
# 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
# 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
# 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
# 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
# Custom 15-class label set: the tuple order defines the label indices, and its length
# must equal num_classes (15) in the model config.
CLASSES = ('person', 'tool_vehicle', 'bicycle', 'motorbike', 'pedal_tricycle', 'car', 'passenger_car',
'truck', 'police_car', 'ambulance', 'bus', 'dump_truck', 'tanker', 'roadblock', 'fire_car')
def coco_classes():
    """Return the custom dataset's 15 class names as a new list.

    The names and their order must match ``CocoDataset.CLASSES`` in
    mmdet/datasets/coco.py; the count must equal ``num_classes`` (15)
    in the model config.
    """
    names = (
        'person', 'tool_vehicle', 'bicycle', 'motorbike', 'pedal_tricycle',
        'car', 'passenger_car', 'truck', 'police_car', 'ambulance', 'bus',
        'dump_truck', 'tanker', 'roadblock', 'fire_car',
    )
    return list(names)
修改"./tools/train.py"文件
# Choose one mode: non-distributed (single process, MMDataParallel) or distributed
# (single-node multi-GPU or multi-node, MMDistributedDataParallel).
# NOTE(review): presumably the default 'none' selects the non-distributed path and the other
# launchers the distributed one — confirm in tools/train.py.
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
模型预训练、权重加载及保存相关参数,在"configs/_base_/default_runtime.py"文件中修改:
checkpoint_config = dict(interval=1) # save a checkpoint after every training epoch
# NOTE(review): in mmdetection, load_from loads a whole-model checkpoint as initialization (without
# restoring epoch/optimizer state); backbone-only pretrained weights are normally set in the model
# config — confirm before relying on the original note "load backbone weights".
load_from = None # checkpoint used to initialize the model weights
resume_from = None # checkpoint to resume an interrupted training run from
训练模型
使用编号为3的单个gpu训练
python ./tools/train.py configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py --gpu-ids 3
使用多gpu(分布式)训练,命令末尾的4为使用的GPU数量;注意此处示例使用的是swin_tiny配置,需换成实际训练所用的配置文件
tools/dist_train.sh configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py 4
训练Log及权重
保存在"Swin-Transformer-Object-Detection-master/work_dirs/"中
coco测试(--eval segm 表示评估实例分割结果;权重路径需换成自己训练得到的权重,如work_dirs中的latest.pth)
python tools/test.py configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py cascade_mask_rcnn_swin_small_patch4_window7.pth --eval segm
输出demo:对文件夹中的每张图片在../output/下生成一个txt文件,每行一个检测框,格式为 cls x1 y1 x2 y2(空格分隔,cls为类别索引,只保留得分高于--score-thr的框)
from argparse import ArgumentParser
from mmdet.apis import inference_detector, init_detector
import numpy as np
import os
from tqdm import tqdm
def main():
    """Run a trained detector over every file in a folder and dump one txt per image.

    For each file in ``--img-path`` a ``<image stem>.txt`` is written to
    ``../output/``.  Each line is ``label x1 y1 x2 y2`` (space separated,
    integer pixel coordinates, ``label`` = class index).  When
    ``--score-thr`` > 0, only boxes scoring above it are kept.  Images with
    no kept detection still get an (empty) file.
    """
    parser = ArgumentParser()
    parser.add_argument('--img-path', default='/data/wj/test/', help='Image file')
    parser.add_argument('--config', default='../work_dirs/cascade_rcnn_x101_64x4d_fpn_20e_coco/cascade_rcnn_x101_64x4d_fpn_20e_coco.py', help='Config file')
    parser.add_argument('--checkpoint', default='../work_dirs/cascade_rcnn_x101_64x4d_fpn_20e_coco/latest.pth', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    args = parser.parse_args()

    imgs_path = args.img_path
    save_path = '../output/'
    # The output directory was never created before, so the first open() failed.
    os.makedirs(save_path, exist_ok=True)

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    for img_path in tqdm(os.listdir(imgs_path)):
        img = os.path.join(imgs_path, img_path)
        result = inference_detector(model, img)
        # Mask models (e.g. the Cascade Mask R-CNN trained in this guide) return
        # (bbox_result, segm_result); only the per-class bbox arrays are used here.
        if isinstance(result, tuple):
            result = result[0]
        bboxes = np.vstack(result)
        # result is indexed by class, so each row's label is its list index.
        labels = np.concatenate([
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(result)
        ])
        score_thr = args.score_thr
        if score_thr > 0:
            assert bboxes.shape[1] == 5
            keep = bboxes[:, -1] > score_thr
            bboxes = bboxes[keep, :]
            labels = labels[keep]

        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), unlike split('.')[0].
        stem = os.path.splitext(img_path)[0]
        txt_path = os.path.join(save_path, '{}.txt'.format(stem))
        # Write each file once in 'w' mode: re-running no longer appends duplicate
        # lines, and images without detections still get an empty file.
        with open(txt_path, 'w') as f:
            for bbox, label in zip(bboxes, labels):
                x1, y1, x2, y2 = bbox[:4].astype(np.int32)
                f.write("{} {} {} {} {}\n".format(label, x1, y1, x2, y2))


if __name__ == '__main__':
    main()
踩过的坑及解决方案:
error with env var RANK
参考:
轻松掌握 MMDetection 整体构建流程(一)
版权声明:本文为CSDN博主「云端一散仙」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_44347020/article/details/116153242
暂无评论