图像去模糊:跑通DeblurGAN-v2

目录

一.环境的配置

二.跑通测试predict.py

三.跑通训练train.py

1.数据准备

2.数据增强方式

3.加载预训练模型

4.模型训练结果存储问题


工程:https://github.com/VITA-Group/DeblurGANv2

一.环境的配置

直接使用python train.py缺什么库,就安装什么库;

二.跑通测试predict.py

需要设置以下参数:

--img_pattern
/media/XXX/test/LR/2021-01-28_11-21-04_white.jpg
--mask_pattern
None
--weights_path
/media/XXX/deblur/DeblurGANv2-master/weights/fpn_inception.h5
--out_dir
submit/
--side_by_side
False
--video
False

三.跑通训练train.py

配置config中的config.yaml参数:

---
project: deblur_gan
experiment_desc: fpn #日志存储文件夹

train:
  files_a: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/train/LR/*.jpg #&FILES_A /datasets/my_dataset/**/*.jpg #low quality/blury images
  files_b: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/train/HR/*.jpg #*FILES_A #clean files
  size: &SIZE 256
  crop: random #裁剪方式选择,备选项为:center
  preload: &PRELOAD false
  preload_size: &PRELOAD_SIZE 0
  bounds: [0, .9]
  scope: geometric
  corrupt: &CORRUPT
    - name: cutout
      prob: 0.5 #数据增强概率
      num_holes: 3
      max_h_size: 25
      max_w_size: 25
    - name: jpeg #增强方式选择,配合aug.py 函数def _resolve_aug_fn(name)中查看挑选需要的增强方式
      quality_lower: 70
      quality_upper: 90
    - name: motion_blur #增强方式选择,配合aug.py 函数def _resolve_aug_fn(name)中查看挑选需要的增强方式
    - name: median_blur
    - name: gamma
    - name: rgb_shift
    - name: hsv_shift
    - name: sharpen

val:
  files_a: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/test/LR/*.jpg #*FILES_A
  files_b: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/test/HR/*.jpg #*FILES_A
  size: *SIZE
  scope: geometric
  crop: center
  preload: *PRELOAD
  preload_size: *PRELOAD_SIZE
  bounds: [.9, 1]
  corrupt: *CORRUPT

phase: train
warmup_num: 3
model:
  g_name: fpn_inception
  blocks: 9
  d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scale
  d_layers: 3
  content_loss: perceptual
  adv_lambda: 0.001
  disc_loss: wgan-gp
  learn_residual: True
  norm_layer: instance
  dropout: True

num_epochs: 200
train_batches_per_epoch: 1000 #训练进度条长度
val_batches_per_epoch: 100 #验证时进度条长度
batch_size: 1
image_size: [256, 256] #图像推理尺寸

optimizer:
  name: adam
  lr: 0.0001
scheduler:
  name: linear
  start_epoch: 50
  min_lr: 0.0000001

1.数据准备

注意:训练时数据推理尺寸为256*256,为了防止图像变形,所以使用的训练连样本都是宽高相等的图片;

准备自己的数据时,HR和LR图像的尺寸要相等,这个有别于超分辨率准备的数据,当HR和LR图像尺寸不相等时,模型训练的精度会一直起不来,本人训练时PSNR一直在16徘徊,跑了一晚上才醒悟(训练有问题啊);

2.数据增强方式

该项目中使用的是albumentations库,结合config中的config.yaml进行参数配置;

目前有这么多种增强方式,可修改源码

选其一:aug.py def get_transform中
albu.HorizontalFlip(always_apply=True), #左右翻转
albu.ShiftScaleRotate(always_apply=True),#随机仿射变换
albu.Transpose(always_apply=True),#转置
albu.OpticalDistortion(always_apply=True),#非刚体变换方法
albu.ElasticTransform(always_apply=True)#非刚体变换方法
albu.RandomCrop#随机裁剪
albu.CenterCrop#中心裁剪
选其一:aug.py def _resolve_aug_fn(name):
albu.Cutout,#随机擦除
albu.RGBShift,#对图像RGB的每个通道随机移动值
albu.HueSaturationValue,#随机更改图像的颜色,饱和度和值
albu.MotionBlur,#
albu.MedianBlur,
albu.RandomSnow,
albu.RandomShadow,
albu.RandomFog,#随机雾化
albu.RandomBrightnessContrast,
albu.RandomGamma,#随机灰度系数
albu.RandomSunFlare,
albu.Sharpen,
albu.ImageCompression,
albu.ToGray,
albu.Downscale,

3.加载预训练模型

train.py中的def _init_params(self):
self.criterionG, criterionD = get_loss(self.config['model'])
self.netG, netD = get_nets(self.config['model'])
self.netG.load_state_dict(torch.load("weights/fpn_inception.h5", map_location='cpu')['model'])

4.模型训练结果存储问题

按照原工程中的设置,日志文件存储在fpn文件夹下;

训练模型只存储最新一个和最好的一个模型,而且是存储在工程根目录下,没有另起一个文件夹存储,可以修改def train(self)中的代码:

原代码为:

if self.metric_counter.update_best_model():
    torch.save({'model': self.netG.state_dict()}, 
         'best_{}.h5'.format(self.config['experiment_desc']))
    torch.save({'model': self.netG.state_dict()
            }, 'last_{}.h5'.format(self.config['experiment_desc']))

修改为:

if self.metric_counter.update_best_model():
    torch.save({'model': self.netG.state_dict()},self.config['experiment_desc']+'/best_{}.h5'.format(self.config['experiment_desc']))
    torch.save({'model': self.netG.state_dict()}, self.config['experiment_desc']+'/last_{}.h5'.format(self.config['experiment_desc']))

if epoch // 50:
    torch.save({'model': self.netG.state_dict()}, self.config['experiment_desc']+'/epoch_{}.h5'.format(epoch))

其他链接:

图像去模糊之DeblurGAN-v2_年轻即出发,-CSDN博客_deblurganv2

版权声明:本文为CSDN博主「猫猫与橙子」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_22764813/article/details/121762177

猫猫与橙子

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

单目3D目标检测调研

单目3D目标检测调研 一、 简介 现有的单目3D目标检测方案主要方案主要分为两类,分别为基于图片的方法和基于伪雷达点云的方法。   基于图片的方法一般通过2D-3D之间的几何约束来学习,包括目标形状信息&#xff0