mmdetection中的dataset

1、train.py

datasets = [build_dataset(cfg.data.train)]

train.py中使用上述代码实例化一个dataset。其中传入的参数如下:
数据集配置文件

2、build_dataset ; mmdet/datasets/builder.py

调用build_from_cfg()方法。其中DATASETS是提前注册好的注册库。

def build_dataset(cfg, default_args=None):
    from .dataset_wrappers import (ConcatDataset, RepeatDataset,
                                   ClassBalancedDataset)
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'ConcatDataset':
        dataset = ConcatDataset(
            [build_dataset(c, default_args) for c in cfg['datasets']],
            cfg.get('separate_eval', True))
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif cfg['type'] == 'ClassBalancedDataset':
        dataset = ClassBalancedDataset(
            build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

3、build_from_cfg ; mmcv/utils/registry.py

从注册库中挑选相应的类(type),然后根据cfg中的参数实例化。

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')

4、CocoFmtDataset.init ; mmdet/datasets/cocofmt.py

根据传入的type参数跳转到相应的初始化方法。这里我们传入的是CocoFmtDataset。具体传入的参数如步骤一。这部分初始化过程比较复杂,传入init方法的参数如下图:
在这里插入图片描述可以发现这里只是将除了type之外的所有参数传入名字叫做type的这个类的init方法。

先搞清楚CocoFmtDataset这个类的继承关系:

  1. CocoFmtDataset继承CocoDataset;
  2. CocoDataset继承CustomDataset;
  3. CustomDataset继承Dataset。

有很对方法被重写,这里一定要先搞清楚各个类的关系。
其中Dataset是torch.utils.data中的类,我们暂且不做研究。

CocoFmtDataset.init:

  1. 给self.train_ignore_as_bg 和 self.merge_after_infer_kwargs 两个变量赋值。都是默认的参数。
  2. 调用父类的父类CustomDataset的init方法。
    def __init__(self,
                 ann_file,
                 data_root=None,
                 corner_kwargs=None,
                 train_ignore_as_bg=True,
                 noise_kwargs=None,
                 merge_after_infer_kwargs=None,
                 **kwargs):
        # add by hui, if there is not corner dataset, create one
        if corner_kwargs is not None:
            assert ann_file[-5:] == '.json', "ann_file must be a json file."
            ann_file = generate_corner_json_file_if_not_exist(ann_file, data_root, corner_kwargs)
            print("load corner dataset json file from {}".format(ann_file))
        if noise_kwargs is not None:
            if 'pseudo_wh' in noise_kwargs and noise_kwargs['pseudo_wh'] is not None:
                ann_file = generate_pesudo_bbox_for_noise_data(ann_file, data_root, noise_kwargs)
            elif 'wh_suffix' in noise_kwargs:
                from .noise_data_utils import get_new_json_file_path
                ann_file, _ = get_new_json_file_path(ann_file, data_root, noise_kwargs['sub_dir'],
                                                     noise_kwargs['wh_suffix'])
            else:
                raise ValueError('one of [pseudo_wh, wh_suffix] must be given')
            print("load noise dataset json file from {}".format(ann_file))

        self.train_ignore_as_bg = train_ignore_as_bg
        self.merge_after_infer_kwargs = merge_after_infer_kwargs

        super(CocoFmtDataset, self).__init__(
            ann_file,
            data_root=data_root,
            **kwargs
        )

CustomDataset.init:

由于CocoDataset没有初始化方法,这里跳转到CustomDataset.init。传入的参数依旧是cfg中的参数。

一、首先将相关参数保存到属性中,值得注意的是这里的self是CocoFmtDataset。

    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt

二、调用get_classes方法获取类别列表。

此时的self.CLASS=None,因为cfg配置文件中并没有设置类别名称。

self.CLASSES = self.get_classes(classes)
    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.
        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.

        Returns:
            tuple[str] or list[str]: Names of categories of the dataset.
        """
        if classes is None:
            return cls.CLASSES
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

三、拼接文件路径,这里访问文件使用的路径都是绝对路径。

由于我们传入的都是绝对路径,所以这里不需要拼接。self.data_root也是默认的None.

        if self.data_root is not None:
            if not osp.isabs(self.ann_file):
                self.ann_file = osp.join(self.data_root, self.ann_file)
            if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
                self.img_prefix = osp.join(self.data_root, self.img_prefix)
            if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
                self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
            if not (self.proposal_file is None or osp.isabs(self.proposal_file)):
                self.proposal_file = osp.join(self.data_root,self.proposal_file)

四、调用load_annotations方法加载标注文件中的图片信息,并将该信息存放在属性self.data_infos中。

这个地方就要好好记录一下:

  1. 根据标注文件路径实例化一个COCO类;
  2. 使用标签中的类别信息给self.CLASS赋值;
  3. 给self.cat_ids赋值,该属性记录类别编号,我们的数据集只有一个类别,这里self.cat_ids=[1];
  4. 给self.cat2label赋值,该属性记录类别标号的位置,这里self.cat2label={1:0};
  5. 获得所有的图片编号并存入self.img_ids中;
  6. 遍历所有图片信息,给每个图片信息新加一个filename=file_name字典,并检查该标注文件的标签和图片能否对应起来。
  7. 最后返回整理好的图片信息。

所以这里的load_annotations并不是仅仅返回图片的信息,它还给实例化的CocoFmtDataset类添加了很多的属性,并且检查了标注数据是否有问题。

self.data_infos = self.load_annotations(self.ann_file)
    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api.
        """

        self.coco = COCO(ann_file)
        if self.CLASSES is None:
            self.CLASSES = [cat['name'] for cat in self.coco.dataset['categories']]  # add by hui
        # The order of returned `cat_ids` will not
        # change with the order of the CLASSES
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)

        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        total_ann_ids = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
            ann_ids = self.coco.get_ann_ids(img_ids=[i])
            total_ann_ids.extend(ann_ids)
        assert len(set(total_ann_ids)) == len(
            total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
        return data_infos

五、给self.proposal_file赋值

这个暂时不知道他的功能,本人使用时self.proposal_file=None

        if self.proposal_file is not None:
            self.proposals = self.load_proposals(self.proposal_file)
        else:
            self.proposals = None

六、过滤掉所占像素过小的标签和没有标签的图片

这个过程也比较复杂,我们看看源码是如何操作的。

        if not test_mode:
            valid_inds = self._filter_imgs()
            self.data_infos = [self.data_infos[i] for i in valid_inds]
            if self.proposals is not None:
                self.proposals = [self.proposals[i] for i in valid_inds]
            # set group flag for the sampler
            self._set_group_flag()

1、首先检查是否是测试,在测试过程中就不能过滤掉这些信息。
2、调用_self.filter_imgs方法,获取图片列表中符合要求的索引,该方法在CocoFmtDataset中被重写。

    def _filter_imgs(self, min_size=32):
        valid_inds = super(CocoFmtDataset, self)._filter_imgs(min_size)
        print("valid image count: ", len(valid_inds))   # add by hui
        return valid_inds

这里也是添加一个默认参数min_size=32后直接调用父类(CocoDataset)的_filter_imgs方法。
下面是CocoDataset._filter_imgs方法的实现源码:

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        # obtain images that contain annotation
        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for i, class_id in enumerate(self.cat_ids):
            ids_in_cat |= set(self.coco.cat_img_map[class_id])
        # merge the image id sets of the two conditions and use the merged set
        # to filter out images if self.filter_empty_gt=True
        ids_in_cat &= ids_with_ann

        valid_img_ids = []
        for i, img_info in enumerate(self.data_infos):
            img_id = self.img_ids[i]
            if self.filter_empty_gt and img_id not in ids_in_cat:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
                valid_img_ids.append(img_id)
        self.img_ids = valid_img_ids
        return valid_inds

这个方法的操作流程:

  1. 获取所有存在标签的图片编号;
  2. 获取所有带有标签的图片编号,且这些标签都有类别;
  3. 将上面的两个set合并,表示所有有标签的图片编号;
  4. 遍历数据集中的所有图片,获取所有有标签且图片尺寸大于32的图片编号;
  5. 将合法的图片标号保存在self.img_ids中,返回合法图片的索引。

3、根据合法索引,重新给self.data_infos赋值,仅仅保留符合要求的图片信息。
4、又是self.proposals,这里的self.proposals依然是None。
5、然后调用self._set_group_flag()方法,

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i in range(len(self)):
            img_info = self.data_infos[i]
            if img_info['width'] / img_info['height'] > 1:
                self.flag[i] = 1

这里会添加一个self.flag属性,这个属性记录图片是横着的还是竖着的。

6、最后添加self.pipeline属性,这个属性十分重要,其中定义了对原始图片的增强或者说变换的流程,后面会详细介绍这个属性。

至此,dataset的初始化就结束了。

pipeline

LoadImageFromFile

关于数据管道中的类有很多,这里仅仅详细介绍以下LoadImageFromFile这个类。

@PIPELINES.register_module()
class LoadImageFromFile:
    """Load an image from file.

    Required keys are "img_prefix" and "img_info" (a dict that must contain the
    key "filename"). Added or updated keys are "filename", "img", "img_shape",
    "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
    "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
            Defaults to 'color'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk')):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def __call__(self, results):
        """Call functions to load image and get image meta information.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results['img_prefix'] is not None:
            filename = osp.join(results['img_prefix'],
                                results['img_info']['filename'])
        else:
            filename = results['img_info']['filename']

        img_bytes = self.file_client.get(filename)
        img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
        if self.to_float32:
            img = img.astype(np.float32)
        # add by hui ####################################################
        if 'corner' in results['img_info']:
            l, u, r, b = results['img_info']['corner']
            img = img[u:b, l:r]
            assert img.shape[0] * img.shape[1] > 0
            results['corner'] = results['img_info']['corner']
        # ###############################################################

        results['filename'] = filename
        results['ori_filename'] = results['img_info']['filename']
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        results['img_fields'] = ['img']
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32}, '
                    f"color_type='{self.color_type}', "
                    f'file_client_args={self.file_client_args})')
        return repr_str

与mmdetection中其它模块一样,该模块使用@PIPELINES.register_module()在PIPELINES中被注册。

LoadImageFromFile一共有三个方法:

  1. 初始化方法init():
    初始化方法中声明一些属性,这些属性会在call方法中使用。

  2. 重载 () 运算符的call()方法:

  3. 用于显示实例化对象信息的 repr()方法:
    该方法就是构造一个与实例化对象相关的字符串,然后返回该字符串。这个方法也是为了在调试过程中查看实例化对象的相关信息的。

这里call方法执行过程如下:

  1. 获取一个文件读取器,这个类可以自定义,如果用户没有定义就会选择默认的mmcv.FileClient(‘backend’=‘disk’);
  2. 根据方法的输入,构造图片的绝对路径;
  3. 根据路径将图片加载到内存;
  4. 根据图片中的corner信息裁减图片;
  5. 整理图片信息并返回。

这里需要注意:1、本人在调试过程中使用的tinyPerson数据集,该数据集中的标注信息有corner。2、call方法的输入是从dataloader获取的一个字典,该字典中包含图像、标注等相关信息。

LoadAnnotations

这个类的方法比较多,除了init、call和repr之外,又多出来一些为call方法服务的方法。

init和repr方法和LoadImageFromFile中的同名方法的功能一样,这里就不再赘述。

call方法的功能是根据init中初始化的标志位读取相关标注信息并返回。代码如下:

    def __call__(self, results):
        """Call function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded bounding box, label, mask and
                semantic segmentation annotations.
        """

        if self.with_bbox:
            results = self._load_bboxes(results)
            if results is None:
                return None
        if self.with_label:
            results = self._load_labels(results)
        if self.with_mask:
            results = self._load_masks(results)
        if self.with_seg:
            results = self._load_semantic_seg(results)
        return results

这里判断的标志位都在init初始化的时候根据cfg被赋值。

_load_bboxes

该方法就是将藏在result深处的真实标注框拿出来,拿到外层来。( results[‘gt_bboxes’] = results[‘ann_info’][‘bboxes’].copy() )

_load_labels

同上。( results[‘gt_labels’] = results[‘ann_info’][‘labels’].copy() )

我的实验是关于目标检测的,所以并没有跳入_load_masks、_load_semantic_seg方法。

Resize

我们直接来看Ressize.call()

  1. 给result添加scale_factor=1.0的键值对(这里也是因为tinyPerson数据集的特殊性);
  2. 计算缩放后图片的尺度,这里scale=1.0(特殊数据集);
  3. 将图片缩放到期望的大小;
  4. 将标注框也缩放到相同的尺度;
  5. 返回整理好的result。

RandomFlip

一定概率进行水平反转。

Normalize

调用mmcv.imnormalize方法对图像进行归一化。

Pad

调用mmcv.impad_to_multiple将图片填充到期望的大小。为什么已经resize的图片还需要pad?原因是有些情况下我们在resize的时候需要保持图片的长宽比,这样的resize无法保证一定可以将图片缩放到期望的大小。但是这里由于我的数据集特殊,并不需要pad。也就是说这里作者虽然在cfg中配置的pad的过程,但是该过程并没有什么用。

DefaultFormatBundle

这个过程对result中的缺失参数进行补充,并将相关数据封装成tensor格式。

Collect

重新整理result中的参数。这里需要注意,不要把该方法与dataloader中的collet弄混了,两个方法的功能不一样。

结束

这部分内容就先总结到这里,记录是为了方便自己未来对知识点的回顾,也希望能够帮助到正在看此篇博客的你。

版权声明:本文为CSDN博主「大胡子7777」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_43403200/article/details/121584490

大胡子7777

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐