目录
工程:https://github.com/VITA-Group/DeblurGANv2
一.环境的配置
直接使用python train.py缺什么库,就安装什么库;
二.跑通测试predict.py
需要设置以下参数:
--img_pattern
/media/XXX/test/LR/2021-01-28_11-21-04_white.jpg
--mask_pattern
None
--weights_path
/media/XXX/deblur/DeblurGANv2-master/weights/fpn_inception.h5
--out_dir
submit/
--side_by_side
False
--video
False
三.跑通训练train.py
配置config中的config.yaml参数:
---
project: deblur_gan
experiment_desc: fpn #日志存储文件夹
train:
files_a: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/train/LR/*.jpg #&FILES_A /datasets/my_dataset/**/*.jpg #low quality/blury images
files_b: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/train/HR/*.jpg #*FILES_A #clean files
size: &SIZE 256
crop: random #裁剪方式选择,备选项为:center
preload: &PRELOAD false
preload_size: &PRELOAD_SIZE 0
bounds: [0, .9]
scope: geometric
corrupt: &CORRUPT
- name: cutout
prob: 0.5 #数据增强概率
num_holes: 3
max_h_size: 25
max_w_size: 25
- name: jpeg #增强方式选择,配合aug.py 函数def _resolve_aug_fn(name)中查看挑选需要的增强方式
quality_lower: 70
quality_upper: 90
- name: motion_blur #增强方式选择,配合aug.py 函数def _resolve_aug_fn(name)中查看挑选需要的增强方式
- name: median_blur
- name: gamma
- name: rgb_shift
- name: hsv_shift
- name: sharpen
val:
files_a: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/test/LR/*.jpg #*FILES_A
files_b: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/test/HR/*.jpg #*FILES_A
size: *SIZE
scope: geometric
crop: center
preload: *PRELOAD
preload_size: *PRELOAD_SIZE
bounds: [.9, 1]
corrupt: *CORRUPT
phase: train
warmup_num: 3
model:
g_name: fpn_inception
blocks: 9
d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scale
d_layers: 3
content_loss: perceptual
adv_lambda: 0.001
disc_loss: wgan-gp
learn_residual: True
norm_layer: instance
dropout: True
num_epochs: 200
train_batches_per_epoch: 1000 #训练进度条长度
val_batches_per_epoch: 100 #验证时进度条长度
batch_size: 1
image_size: [256, 256] #图像推理尺寸
optimizer:
name: adam
lr: 0.0001
scheduler:
name: linear
start_epoch: 50
min_lr: 0.0000001
1.数据准备
注意:训练时数据推理尺寸为256*256,为了防止图像变形,所以使用的训练连样本都是宽高相等的图片;
准备自己的数据时,HR和LR图像的尺寸要相等,这个有别于超分辨率准备的数据,当HR和LR图像尺寸不相等时,模型训练的精度会一直起不来,本人训练时PSNR一直在16徘徊,跑了一晚上才醒悟(训练有问题啊);
2.数据增强方式
该项目中使用的是albumentations库,结合config中的config.yaml进行参数配置;
目前有这么多种增强方式,可修改源码
选其一:aug.py def get_transform中
albu.HorizontalFlip(always_apply=True), #左右翻转
albu.ShiftScaleRotate(always_apply=True),#随机仿射变换
albu.Transpose(always_apply=True),#转置
albu.OpticalDistortion(always_apply=True),#非刚体变换方法
albu.ElasticTransform(always_apply=True)#非刚体变换方法
albu.RandomCrop#随机裁剪
albu.CenterCrop#中心裁剪
选其一:aug.py def _resolve_aug_fn(name):
albu.Cutout,#随机擦除
albu.RGBShift,#对图像RGB的每个通道随机移动值
albu.HueSaturationValue,#随机更改图像的颜色,饱和度和值
albu.MotionBlur,#
albu.MedianBlur,
albu.RandomSnow,
albu.RandomShadow,
albu.RandomFog,#随机雾化
albu.RandomBrightnessContrast,
albu.RandomGamma,#随机灰度系数
albu.RandomSunFlare,
albu.Sharpen,
albu.ImageCompression,
albu.ToGray,
albu.Downscale,
3.加载预训练模型
train.py中的def _init_params(self):
self.criterionG, criterionD = get_loss(self.config['model'])
self.netG, netD = get_nets(self.config['model'])
self.netG.load_state_dict(torch.load("weights/fpn_inception.h5", map_location='cpu')['model'])
4.模型训练结果存储问题
按照原工程中的设置,日志文件存储在fpn文件夹下;
训练模型只存储最新一个和最好的一个模型,而且是存储在工程根目录下,没有另起一个文件夹存储,可以修改def train(self)中的代码:
原代码为:
if self.metric_counter.update_best_model():
torch.save({'model': self.netG.state_dict()},
'best_{}.h5'.format(self.config['experiment_desc']))
torch.save({'model': self.netG.state_dict()
}, 'last_{}.h5'.format(self.config['experiment_desc']))
修改为:
if self.metric_counter.update_best_model():
torch.save({'model': self.netG.state_dict()},self.config['experiment_desc']+'/best_{}.h5'.format(self.config['experiment_desc']))
torch.save({'model': self.netG.state_dict()}, self.config['experiment_desc']+'/last_{}.h5'.format(self.config['experiment_desc']))
if epoch // 50:
torch.save({'model': self.netG.state_dict()}, self.config['experiment_desc']+'/epoch_{}.h5'.format(epoch))
其他链接:
图像去模糊之DeblurGAN-v2_年轻即出发,-CSDN博客_deblurganv2
版权声明:本文为CSDN博主「猫猫与橙子」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_22764813/article/details/121762177
暂无评论