文章目录[隐藏]
1、train.py
datasets = [build_dataset(cfg.data.train)]
train.py中使用上述代码实例化一个dataset。其中传入的参数如下:
2、build_dataset ; mmdet/datasets/builder.py
调用build_from_cfg()方法。其中DATASETS是提前注册好的注册库。
def build_dataset(cfg, default_args=None):
from .dataset_wrappers import (ConcatDataset, RepeatDataset,
ClassBalancedDataset)
if isinstance(cfg, (list, tuple)):
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
elif cfg['type'] == 'ConcatDataset':
dataset = ConcatDataset(
[build_dataset(c, default_args) for c in cfg['datasets']],
cfg.get('separate_eval', True))
elif cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(
build_dataset(cfg['dataset'], default_args), cfg['times'])
elif cfg['type'] == 'ClassBalancedDataset':
dataset = ClassBalancedDataset(
build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
elif isinstance(cfg.get('ann_file'), (list, tuple)):
dataset = _concat_dataset(cfg, default_args)
else:
dataset = build_from_cfg(cfg, DATASETS, default_args)
3、build_from_cfg ; mmcv/utils/registry.py
从注册库中挑选相应的类(type),然后根据cfg中的参数实例化。
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type')
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
try:
return obj_cls(**args)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{obj_cls.__name__}: {e}')
4、CocoFmtDataset.init ; mmdet/datasets/cocofmt.py
根据传入的type参数跳转到相应的初始化方法。这里我们传入的是CocoFmtDataset。具体传入的参数如步骤一。这部分初始化过程比较复杂,传入init方法的参数如下图:
可以发现这里只是将除了type之外的所有参数传入名字叫做type的这个类的init方法。
先搞清楚CocoFmtDataset这个类的继承关系:
- CocoFmtDataset继承CocoDataset;
- CocoDataset继承CustomDataset;
- CustomDataset继承Dataset。
有很对方法被重写,这里一定要先搞清楚各个类的关系。
其中Dataset是torch.utils.data中的类,我们暂且不做研究。
CocoFmtDataset.init:
- 给self.train_ignore_as_bg 和 self.merge_after_infer_kwargs 两个变量赋值。都是默认的参数。
- 调用父类的父类CustomDataset的init方法。
def __init__(self,
ann_file,
data_root=None,
corner_kwargs=None,
train_ignore_as_bg=True,
noise_kwargs=None,
merge_after_infer_kwargs=None,
**kwargs):
# add by hui, if there is not corner dataset, create one
if corner_kwargs is not None:
assert ann_file[-5:] == '.json', "ann_file must be a json file."
ann_file = generate_corner_json_file_if_not_exist(ann_file, data_root, corner_kwargs)
print("load corner dataset json file from {}".format(ann_file))
if noise_kwargs is not None:
if 'pseudo_wh' in noise_kwargs and noise_kwargs['pseudo_wh'] is not None:
ann_file = generate_pesudo_bbox_for_noise_data(ann_file, data_root, noise_kwargs)
elif 'wh_suffix' in noise_kwargs:
from .noise_data_utils import get_new_json_file_path
ann_file, _ = get_new_json_file_path(ann_file, data_root, noise_kwargs['sub_dir'],
noise_kwargs['wh_suffix'])
else:
raise ValueError('one of [pseudo_wh, wh_suffix] must be given')
print("load noise dataset json file from {}".format(ann_file))
self.train_ignore_as_bg = train_ignore_as_bg
self.merge_after_infer_kwargs = merge_after_infer_kwargs
super(CocoFmtDataset, self).__init__(
ann_file,
data_root=data_root,
**kwargs
)
CustomDataset.init:
由于CocoDataset没有初始化方法,这里跳转到CustomDataset.init。传入的参数依旧是cfg中的参数。
一、首先将相关参数保存到属性中,值得注意的是这里的self是CocoFmtDataset。
def __init__(self,
ann_file,
pipeline,
classes=None,
data_root=None,
img_prefix='',
seg_prefix=None,
proposal_file=None,
test_mode=False,
filter_empty_gt=True):
self.ann_file = ann_file
self.data_root = data_root
self.img_prefix = img_prefix
self.seg_prefix = seg_prefix
self.proposal_file = proposal_file
self.test_mode = test_mode
self.filter_empty_gt = filter_empty_gt
二、调用get_classes方法获取类别列表。
此时的self.CLASS=None,因为cfg配置文件中并没有设置类别名称。
self.CLASSES = self.get_classes(classes)
@classmethod
def get_classes(cls, classes=None):
"""Get class names of current dataset.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
Returns:
tuple[str] or list[str]: Names of categories of the dataset.
"""
if classes is None:
return cls.CLASSES
if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(classes)
elif isinstance(classes, (tuple, list)):
class_names = classes
else:
raise ValueError(f'Unsupported type {type(classes)} of classes.')
return class_names
三、拼接文件路径,这里访问文件使用的路径都是绝对路径。
由于我们传入的都是绝对路径,所以这里不需要拼接。self.data_root也是默认的None.
if self.data_root is not None:
if not osp.isabs(self.ann_file):
self.ann_file = osp.join(self.data_root, self.ann_file)
if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
self.img_prefix = osp.join(self.data_root, self.img_prefix)
if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
if not (self.proposal_file is None or osp.isabs(self.proposal_file)):
self.proposal_file = osp.join(self.data_root,self.proposal_file)
四、调用load_annotations方法加载标注文件中的图片信息,并将该信息存放在属性self.data_infos中。
这个地方就要好好记录一下:
- 根据标注文件路径实例化一个COCO类;
- 使用标签中的类别信息给self.CLASS赋值;
- 给self.cat_ids赋值,该属性记录类别编号,我们的数据集只有一个类别,这里self.cat_ids=[1];
- 给self.cat2label赋值,该属性记录类别标号的位置,这里self.cat2label={1:0};
- 获得所有的图片编号并存入self.img_ids中;
- 遍历所有图片信息,给每个图片信息新加一个filename=file_name字典,并检查该标注文件的标签和图片能否对应起来。
- 最后返回整理好的图片信息。
所以这里的load_annotations并不是仅仅返回图片的信息,它还给实例化的CocoFmtDataset类添加了很多的属性,并且检查了标注数据是否有问题。
self.data_infos = self.load_annotations(self.ann_file)
def load_annotations(self, ann_file):
"""Load annotation from COCO style annotation file.
Args:
ann_file (str): Path of annotation file.
Returns:
list[dict]: Annotation info from COCO api.
"""
self.coco = COCO(ann_file)
if self.CLASSES is None:
self.CLASSES = [cat['name'] for cat in self.coco.dataset['categories']] # add by hui
# The order of returned `cat_ids` will not
# change with the order of the CLASSES
self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.img_ids = self.coco.get_img_ids()
data_infos = []
total_ann_ids = []
for i in self.img_ids:
info = self.coco.load_imgs([i])[0]
info['filename'] = info['file_name']
data_infos.append(info)
ann_ids = self.coco.get_ann_ids(img_ids=[i])
total_ann_ids.extend(ann_ids)
assert len(set(total_ann_ids)) == len(
total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
return data_infos
五、给self.proposal_file赋值
这个暂时不知道他的功能,本人使用时self.proposal_file=None
if self.proposal_file is not None:
self.proposals = self.load_proposals(self.proposal_file)
else:
self.proposals = None
六、过滤掉所占像素过小的标签和没有标签的图片
这个过程也比较复杂,我们看看源码是如何操作的。
if not test_mode:
valid_inds = self._filter_imgs()
self.data_infos = [self.data_infos[i] for i in valid_inds]
if self.proposals is not None:
self.proposals = [self.proposals[i] for i in valid_inds]
# set group flag for the sampler
self._set_group_flag()
1、首先检查是否是测试,在测试过程中就不能过滤掉这些信息。
2、调用_self.filter_imgs方法,获取图片列表中符合要求的索引,该方法在CocoFmtDataset中被重写。
def _filter_imgs(self, min_size=32):
valid_inds = super(CocoFmtDataset, self)._filter_imgs(min_size)
print("valid image count: ", len(valid_inds)) # add by hui
return valid_inds
这里也是添加一个默认参数min_size=32后直接调用父类(CocoDataset)的_filter_imgs方法。
下面是CocoDataset._filter_imgs方法的实现源码:
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths."""
valid_inds = []
# obtain images that contain annotation
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
# obtain images that contain annotations of the required categories
ids_in_cat = set()
for i, class_id in enumerate(self.cat_ids):
ids_in_cat |= set(self.coco.cat_img_map[class_id])
# merge the image id sets of the two conditions and use the merged set
# to filter out images if self.filter_empty_gt=True
ids_in_cat &= ids_with_ann
valid_img_ids = []
for i, img_info in enumerate(self.data_infos):
img_id = self.img_ids[i]
if self.filter_empty_gt and img_id not in ids_in_cat:
continue
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
valid_img_ids.append(img_id)
self.img_ids = valid_img_ids
return valid_inds
这个方法的操作流程:
- 获取所有存在标签的图片编号;
- 获取所有带有标签的图片编号,且这些标签都有类别;
- 将上面的两个set合并,表示所有有标签的图片编号;
- 遍历数据集中的所有图片,获取所有有标签且图片尺寸大于32的图片编号;
- 将合法的图片标号保存在self.img_ids中,返回合法图片的索引。
3、根据合法索引,重新给self.data_infos赋值,仅仅保留符合要求的图片信息。
4、又是self.proposals,这里的self.proposals依然是None。
5、然后调用self._set_group_flag()方法,
def _set_group_flag(self):
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
self.flag = np.zeros(len(self), dtype=np.uint8)
for i in range(len(self)):
img_info = self.data_infos[i]
if img_info['width'] / img_info['height'] > 1:
self.flag[i] = 1
这里会添加一个self.flag属性,这个属性记录图片是横着的还是竖着的。
6、最后添加self.pipeline属性,这个属性十分重要,其中定义了对原始图片的增强或者说变换的流程,后面会详细介绍这个属性。
至此,dataset的初始化就结束了。
pipeline
LoadImageFromFile
关于数据管道中的类有很多,这里仅仅详细介绍以下LoadImageFromFile这个类。
@PIPELINES.register_module()
class LoadImageFromFile:
"""Load an image from file.
Required keys are "img_prefix" and "img_info" (a dict that must contain the
key "filename"). Added or updated keys are "filename", "img", "img_shape",
"ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
"scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
Defaults to 'color'.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
"""
def __init__(self,
to_float32=False,
color_type='color',
file_client_args=dict(backend='disk')):
self.to_float32 = to_float32
self.color_type = color_type
self.file_client_args = file_client_args.copy()
self.file_client = None
def __call__(self, results):
"""Call functions to load image and get image meta information.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded image and meta information.
"""
if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)
if results['img_prefix'] is not None:
filename = osp.join(results['img_prefix'],
results['img_info']['filename'])
else:
filename = results['img_info']['filename']
img_bytes = self.file_client.get(filename)
img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
if self.to_float32:
img = img.astype(np.float32)
# add by hui ####################################################
if 'corner' in results['img_info']:
l, u, r, b = results['img_info']['corner']
img = img[u:b, l:r]
assert img.shape[0] * img.shape[1] > 0
results['corner'] = results['img_info']['corner']
# ###############################################################
results['filename'] = filename
results['ori_filename'] = results['img_info']['filename']
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
results['img_fields'] = ['img']
return results
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'to_float32={self.to_float32}, '
f"color_type='{self.color_type}', "
f'file_client_args={self.file_client_args})')
return repr_str
与mmdetection中其它模块一样,该模块使用@PIPELINES.register_module()在PIPELINES中被注册。
LoadImageFromFile一共有三个方法:
-
初始化方法init():
初始化方法中声明一些属性,这些属性会在call方法中使用。 -
重载 () 运算符的call()方法:
-
用于显示实例化对象信息的 repr()方法:
该方法就是构造一个与实例化对象相关的字符串,然后返回该字符串。这个方法也是为了在调试过程中查看实例化对象的相关信息的。
这里call方法执行过程如下:
- 获取一个文件读取器,这个类可以自定义,如果用户没有定义就会选择默认的mmcv.FileClient(‘backend’=‘disk’);
- 根据方法的输入,构造图片的绝对路径;
- 根据路径将图片加载到内存;
- 根据图片中的corner信息裁减图片;
- 整理图片信息并返回。
这里需要注意:1、本人在调试过程中使用的tinyPerson数据集,该数据集中的标注信息有corner。2、call方法的输入是从dataloader获取的一个字典,该字典中包含图像、标注等相关信息。
LoadAnnotations
这个类的方法比较多,除了init、call和repr之外,又多出来一些为call方法服务的方法。
init和repr方法和LoadImageFromFile中的同名方法的功能一样,这里就不再赘述。
call方法的功能是根据init中初始化的标志位读取相关标注信息并返回。代码如下:
def __call__(self, results):
"""Call function to load multiple types annotations.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded bounding box, label, mask and
semantic segmentation annotations.
"""
if self.with_bbox:
results = self._load_bboxes(results)
if results is None:
return None
if self.with_label:
results = self._load_labels(results)
if self.with_mask:
results = self._load_masks(results)
if self.with_seg:
results = self._load_semantic_seg(results)
return results
这里判断的标志位都在init初始化的时候根据cfg被赋值。
_load_bboxes
该方法就是将藏在result深处的真实标注框拿出来,拿到外层来。( results[‘gt_bboxes’] = results[‘ann_info’][‘bboxes’].copy() )
_load_labels
同上。( results[‘gt_labels’] = results[‘ann_info’][‘labels’].copy() )
我的实验是关于目标检测的,所以并没有跳入_load_masks、_load_semantic_seg方法。
Resize
我们直接来看Ressize.call()
- 给result添加scale_factor=1.0的键值对(这里也是因为tinyPerson数据集的特殊性);
- 计算缩放后图片的尺度,这里scale=1.0(特殊数据集);
- 将图片缩放到期望的大小;
- 将标注框也缩放到相同的尺度;
- 返回整理好的result。
RandomFlip
一定概率进行水平反转。
Normalize
调用mmcv.imnormalize方法对图像进行归一化。
Pad
调用mmcv.impad_to_multiple将图片填充到期望的大小。为什么已经resize的图片还需要pad?原因是有些情况下我们在resize的时候需要保持图片的长宽比,这样的resize无法保证一定可以将图片缩放到期望的大小。但是这里由于我的数据集特殊,并不需要pad。也就是说这里作者虽然在cfg中配置的pad的过程,但是该过程并没有什么用。
DefaultFormatBundle
这个过程对result中的缺失参数进行补充,并将相关数据封装成tensor格式。
Collect
重新整理result中的参数。这里需要注意,不要把该方法与dataloader中的collet弄混了,两个方法的功能不一样。
结束
这部分内容就先总结到这里,记录是为了方便自己未来对知识点的回顾,也希望能够帮助到正在看此篇博客的你。
版权声明:本文为CSDN博主「大胡子7777」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_43403200/article/details/121584490
暂无评论