Detectron2注册自己的COCO数据集

Detectron2注册自己的COCO数据集

1.在train.py添加

CLASS_NAMES =['background','A220', 'A330', 'A320/321', 'Boeing737-800', 'Boeing787', 'ARJ21', 'other']

# 数据集路径
DATASET_ROOT = './datasets/coco'
ANN_ROOT = os.path.join(DATASET_ROOT, 'annotations')

TRAIN_PATH = os.path.join(DATASET_ROOT, 'train2017')
VAL_PATH = os.path.join(DATASET_ROOT, 'val2017')

TRAIN_JSON = os.path.join(ANN_ROOT, 'instances_train2017.json')
#VAL_JSON = os.path.join(ANN_ROOT, 'val.json')
VAL_JSON = os.path.join(ANN_ROOT, 'instances_val2017.json')

# 声明数据集的子集
PREDEFINED_SPLITS_DATASET = {
    "coco_my_train": (TRAIN_PATH, TRAIN_JSON),
    "coco_my_val": (VAL_PATH, VAL_JSON),
}

def plain_register_dataset():
    #训练集
    DatasetCatalog.register("coco_my_train", lambda: load_coco_json(TRAIN_JSON, TRAIN_PATH))
    MetadataCatalog.get("coco_my_train").set(thing_classes=CLASS_NAMES,  # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭
                                                    evaluator_type='coco', # 指定评估方式
                                                    json_file=TRAIN_JSON,
                                                    image_root=TRAIN_PATH)

    #DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH, "coco_2017_val"))
    #验证/测试集
    DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH))
    MetadataCatalog.get("coco_my_val").set(thing_classes=CLASS_NAMES, # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭
                                                evaluator_type='coco', # 指定评估方式
                                                json_file=VAL_JSON,
                                                image_root=VAL_PATH)

注意导入from detectron2.data.datasets.coco import load_coco_json和from detectron2.data import DatasetCatalog, MetadataCatalog

2.在main()函数中调用

def main(args):
    cfg = setup(args)
    plain_register_dataset()#调用注册函数
    if args.eval_only:
        model = Trainer.build_model(cfg)
        SLRDetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        # test
        res = Trainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

3.在config文件中修改

DATASETS:
  TRAIN: ("coco_my_train",)#上文中的注册名
  TEST: ("coco_my_val",)#上文中的注册名

恭喜你注册完成!!!

https://blog.csdn.net/qq_29750461/article/details/106761382

版权声明:本文为CSDN博主「SSSLaker」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/SSSKXP/article/details/122155505

SSSLaker

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐