2021SC@SDUSC
Source Code Analysis of ppdet/data/reader.py
First, the configuration in YAML.
File: ./_base_/datasets/coco.yml (a short sketch of how its path fields combine follows the block)
'''
metric: COCO # evaluation metric used for validation; either COCO or VOC
# Number of classes in the training/validation dataset, excluding background.
# RCNN-series models do include a background class, i.e. 81 = 80 + 1 (background).
num_classes: 80

TrainDataset: # training data
  !COCODataSet # COCO-format dataset
    image_dir: train2017 # image folder relative to dataset_dir; image path = dataset_dir + image_dir + image_name
    anno_path: annotations/instances_train2017.json # annotation file relative to dataset_dir
    dataset_dir: dataset/coco # dataset root relative to the PaddleDetection directory

EvalDataset: # validation data
  !COCODataSet # COCO-format dataset
    image_dir: val2017 # image folder relative to dataset_dir; image path = dataset_dir + image_dir + image_name
    anno_path: annotations/instances_val2017.json # annotation file relative to dataset_dir
    dataset_dir: dataset/coco # dataset root relative to the PaddleDetection directory

TestDataset: # test data
  !ImageFolder
    anno_path: annotations/instances_val2017.json # annotation file relative to dataset_dir
'''
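As the comments note, an image's path is assembled from the three path fields. A minimal sketch of that composition (plain Python; the file name is a hypothetical example, and this is not PaddleDetection code):
'''
import os

dataset_dir = 'dataset/coco'  # relative to the PaddleDetection root
image_dir = 'train2017'       # relative to dataset_dir
anno_path = 'annotations/instances_train2017.json'  # relative to dataset_dir

image_name = '000000000009.jpg'  # hypothetical name read from the annotation file
print(os.path.join(dataset_dir, image_dir, image_name))  # dataset/coco/train2017/000000000009.jpg
print(os.path.join(dataset_dir, anno_path))              # dataset/coco/annotations/instances_train2017.json
'''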
Next comes ./_base_/readers/mask_fpn_reader.yml. The pipeline is basically the same across configs; the preprocessing steps are adjusted to suit each algorithm. (A sketch of the keep-ratio resize used by the eval/test pipelines follows the block.)
'''
worker_num: 2 # number of data-loading worker processes

TrainReader: # model input pipeline during training
  sample_transforms: # per-image preprocessing/augmentation ops, applied in list order
  - DecodeOp: {}
  - RandomFlipImage: {prob: 0.5, is_mask_flip: true}
  - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true}
  - Permute: {to_bgr: false, channel_first: true}
  batch_transforms: # batch-level preprocessing
  - PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: true}
  batch_size: 1 # batch size per GPU, default 1; note that each iteration consumes batch_size * device_num images
  shuffle: true # whether to shuffle the data
  drop_last: true # drop the final incomplete batch; training can fail in some cases with drop_last: false, so keep it true when training

EvalReader: # input pipeline for validation
  sample_transforms: # per-image preprocessing ops, applied in list order
  - DecodeOp: {}
  - NormalizeImageOp: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True}
  - PermuteOp: {}
  batch_transforms: # batch-level preprocessing
  - PadBatchOp: {pad_to_stride: 32, pad_gt: false}
  batch_size: 1 # batch size per GPU, default 1; each iteration consumes batch_size * device_num images
  shuffle: false # whether to shuffle the data
  drop_last: false # keep the final incomplete batch during evaluation
  drop_empty: false # whether to drop samples without ground truth

TestReader: # input pipeline for inference; preprocessing must stay consistent with training/eval
  sample_transforms: # per-image preprocessing ops, applied in list order
  - DecodeOp: {}
  - NormalizeImageOp: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True}
  - PermuteOp: {}
  batch_transforms: # batch-level preprocessing
  - PadBatchOp: {pad_to_stride: 32, pad_gt: false}
  batch_size: 1 # batch size per GPU, default 1; each iteration consumes batch_size * device_num images
  shuffle: false # whether to shuffle the data
  drop_last: false # keep the final incomplete batch
'''
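With target_size: [800, 1333] and keep_ratio: True, the short side is scaled toward 800 while the long side is capped at 1333. A rough sketch of that rule (my own reconstruction, not the actual ResizeOp code):
'''
def resize_hw(h, w, short=800, long_cap=1333):
    # pick the largest scale that aims the short side at `short`
    # without letting the long side exceed `long_cap`
    scale = min(short / min(h, w), long_cap / max(h, w))
    return int(round(h * scale)), int(round(w * scale))

print(resize_hw(480, 640))   # (800, 1067): short side reaches 800
print(resize_hw(400, 1000))  # (533, 1333): long side hits the 1333 cap
'''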
The imports:
import copy
import traceback
import six
import sys
import multiprocessing as mp
if sys.version_info >= (3, 0):
    import queue as Queue
else:
    import Queue
import numpy as np
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from ppdet.core.workspace import register, serializable, create
from . import transform
from ppdet.utils.logger import setup_logger
logger = setup_logger('reader')
class Compose(object):
    def __init__(self, transforms, num_classes=81):
        self.transforms = transforms
        self.transforms_cls = []
        # instantiate each op: {op_name: kwargs} -> transform.op_name(**kwargs)
        for t in self.transforms:
            for k, v in t.items():
                op_cls = getattr(transform, k)
                self.transforms_cls.append(op_cls(**v))
                if hasattr(op_cls, 'num_classes'):
                    op_cls.num_classes = num_classes

    def __call__(self, data):
        # apply every op to the sample in order, logging the failing op on error
        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warn("fail to map op [{}] with error: {} and stack:\n{}".
                            format(f, e, str(stack_info)))
                raise e
        return data
This class implements per-image augmentation. The individual augmentation ops live in the transforms list; the call iterates over that list, applies each op to the sample in turn, and returns the augmented result.
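A minimal usage sketch (assuming the op names exist in this version's transform module, and that the 'im_file' key matches what COCODataSet records carry):
'''
from ppdet.data.reader import Compose

sample_transforms = [
    {'DecodeOp': {}},   # each dict maps an op class name to its kwargs
    {'PermuteOp': {}},
]
compose = Compose(sample_transforms, num_classes=80)

sample = {'im_file': 'dataset/coco/train2017/000000000009.jpg'}  # hypothetical record
sample = compose(sample)  # each op transforms and returns the sample dict
'''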
class BatchCompose(Compose):
    def __init__(self, transforms, num_classes=81):
        super(BatchCompose, self).__init__(transforms, num_classes)
        self.output_fields = mp.Manager().list([])
        self.lock = mp.Lock()

    def __call__(self, data):
        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warn("fail to map op [{}] with error: {} and stack:\n{}".
                            format(f, e, str(stack_info)))
                raise e

        # parse output fields by first sample
        # **this should be fixed once paddle.io.DataLoader supports dict**
        # Since paddle.io.DataLoader does not support dict currently,
        # we need to parse the keys from the first sample.
        # BatchCompose.__call__ will be called in each worker
        # process, so a lock is needed here.
        if len(self.output_fields) == 0:
            self.lock.acquire()
            if len(self.output_fields) == 0:
                for k, v in data[0].items():
                    # FIXME(dkp): for more elegant coding
                    if k not in ['flipped', 'h', 'w']:
                        self.output_fields.append(k)
            self.lock.release()

        data = [[data[i][k] for k in self.output_fields]
                for i in range(len(data))]
        data = list(zip(*data))
        batch_data = [np.stack(d, axis=0) for d in data]
        return batch_data
This class implements batch-level augmentation. It works like Compose but on a whole batch, and it additionally records the field names of the first sample (output_fields) so each batch can be flattened from a list of dicts into stacked arrays.
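The collation at the end of __call__ is easy to miss, so here is a self-contained sketch of just that step (plain numpy, no Paddle): a list of per-sample dicts becomes a list of arrays, one stacked array per retained field.
'''
import numpy as np

data = [  # two samples after batch transforms, fields already share shapes
    {'image': np.zeros((3, 32, 32)), 'im_id': np.array([0]), 'h': 32, 'w': 32},
    {'image': np.ones((3, 32, 32)), 'im_id': np.array([1]), 'h': 32, 'w': 32},
]
output_fields = [k for k in data[0] if k not in ['flipped', 'h', 'w']]

rows = [[sample[k] for k in output_fields] for sample in data]
columns = list(zip(*rows))  # transpose: one tuple of values per field
batch_data = [np.stack(col, axis=0) for col in columns]
print(output_fields)        # ['image', 'im_id']
print(batch_data[0].shape)  # (2, 3, 32, 32)
'''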
class BaseDataLoader(object):
    __share__ = ['num_classes']

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        # sample transforms
        self._sample_transforms = Compose(
            sample_transforms, num_classes=num_classes)

        # batch transforms
        self._batch_transforms = BatchCompose(batch_transforms, num_classes)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.with_background = with_background
        self.kwargs = kwargs

    def __call__(self,
                 dataset,
                 worker_num,
                 batch_sampler=None,
                 return_list=False,
                 use_prefetch=True):
        self.dataset = dataset
        self.dataset.parse_dataset(self.with_background)
        # get data
        self.dataset.set_transform(self._sample_transforms)
        # set kwargs
        self.dataset.set_kwargs(**self.kwargs)
        # batch sampler
        if batch_sampler is None:
            self._batch_sampler = DistributedBatchSampler(
                self.dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                drop_last=self.drop_last)
        else:
            self._batch_sampler = batch_sampler

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=self._batch_sampler,
            collate_fn=self._batch_transforms,
            num_workers=worker_num,
            return_list=return_list,
            use_buffer_reader=use_prefetch,
            use_shared_memory=False)
        self.loader = iter(self.dataloader)

        return self

    def __len__(self):
        return len(self._batch_sampler)

    def __iter__(self):
        return self

    def __next__(self):
        # pack {field_name: field_data} here,
        # looking forward to dictionary support
        # in paddle.io.DataLoader
        try:
            data = next(self.loader)
            return {
                k: v
                for k, v in zip(self._batch_transforms.output_fields, data)
            }
        except StopIteration:
            self.loader = iter(self.dataloader)
            six.reraise(*sys.exc_info())

    def next(self):
        # python2 compatibility
        return self.__next__()
This is the base data-loading class. It runs the Compose and BatchCompose transforms over the data and finally yields batches through iteration (the pipeline is wired up in __call__).
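A hedged usage sketch, following the pattern in PaddleDetection's training entry points (the config path is hypothetical, and the exact wiring may differ slightly across versions):
'''
from ppdet.core.workspace import load_config, create

cfg = load_config('configs/mask_rcnn_r50_fpn_1x_coco.yml')  # hypothetical path
dataset = cfg['TrainDataset']  # the !COCODataSet instance from coco.yml
loader = create('TrainReader')(dataset, worker_num=2)  # BaseDataLoader.__call__

for batch in loader:          # __next__ repacks each batch as a dict
    images = batch['image']   # keys come from BatchCompose.output_fields
    break
'''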
'''
@register
@serializable
class IouLoss(object):
    def __init__(self):
        pass

    def __call__(self, s):
        return s
'''
The meaning of @: when the interpreter reads a decorator like this, it first evaluates the expression after the @, then passes the function or class defined on the following line to it as an argument, and finally rebinds the decorated name to the return value.
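A self-contained illustration of the mechanics (plain Python, no ppdet):
'''
def register(cls):
    print('registering', cls.__name__)  # stand-in for ppdet's bookkeeping
    return cls

@register
class IouLoss(object):  # prints "registering IouLoss" at class-definition time
    def __call__(self, s):
        return s

# The @ line is sugar for:
# IouLoss = register(IouLoss)
'''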
@register
class TrainReader(BaseDataLoader):
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TrainReader, self).__init__(inputs_def, sample_transforms,
                                          batch_transforms, batch_size, shuffle,
                                          drop_last, drop_empty, num_classes,
                                          with_background, **kwargs)
Training data loader class. The parameters from the matching section of the YAML file are passed into the class registered under the same name, and any nested members are likewise instantiated with their own parameters. With the TrainReader section of mask_fpn_reader.yml shown above, all of those keys are passed in to build the TrainReader instance, roughly as sketched below.
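A hedged sketch of that mapping (the dict below abbreviates the YAML section; this is not the config system's actual code):
'''
from ppdet.data.reader import TrainReader

reader_cfg = {  # what the config system extracts from the YAML section
    'sample_transforms': [{'DecodeOp': {}}],
    'batch_transforms': [{'PadBatchOp': {'pad_to_stride': 32, 'pad_gt': True}}],
    'batch_size': 1,
    'shuffle': True,
    'drop_last': True,
}
train_reader = TrainReader(**reader_cfg)  # roughly what create('TrainReader') does
'''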
@register
class EvalReader(BaseDataLoader):
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(EvalReader, self).__init__(inputs_def, sample_transforms,
                                         batch_transforms, batch_size, shuffle,
                                         drop_last, drop_empty, num_classes,
                                         with_background, **kwargs)
Validation data loader class. As with TrainReader, the parameters from the matching YAML section are passed into the registered class, and nested members are instantiated with their own parameters.
@register
class TestReader(BaseDataLoader):
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TestReader, self).__init__(inputs_def, sample_transforms,
                                         batch_transforms, batch_size, shuffle,
                                         drop_last, drop_empty, num_classes,
                                         with_background, **kwargs)
Test data loader class. Again, the parameters from the matching YAML section are passed into the registered class, and nested members are instantiated with their own parameters.
Copyright notice: this is an original article by CSDN blogger 「无情铁铲」, licensed under CC 4.0 BY-SA. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/qq_45684033/article/details/122148120