前言:此文是我从yolov5替换到yolox训练的过程,前提是我们有图片和标注文件,而且都是yolov5的txt格式的;之前在网上看了一圈,怎么用自己的数据训练yolox模型,都是需要把标注文件整理成voc格式或coco数据集格式,连文件夹的存放方式都必须一样,真是麻烦;而我之前的任务都是基于yolov5训练的,所以图片,标注文件已经有了,我也不想按voc,coco那样再去改变格式,于是就有了此文。
yolov5数据集目录如下:
一、利用yolov5标注生成xml格式的标注
利用yolov5的txt格式的标注文件生成xml格式的标注文件,在生成的时候需注意:
1、yolov5的标注是经过归一化的c_x, c_y, w, h
2、背景图片yolov5可以不用标注,即没有对应的txt文件,但yolox训练却不行
3、图片名字不要带有空格,yolov5可以正常训练验证,但yolox在验证的时候会报错。
直接上生成xml的代码,文件名yolotxt2xml.py:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2021/09/14 11:14
# @Author : lishanlu
# @File : yolotxt2xml.py
# @Software: PyCharm
# @Description:
from __future__ import absolute_import, print_function, division
import os
from xml.dom.minidom import Document
import xml.etree.ElementTree as ET
import cv2
'''
import xml
xml.dom.minidom.Document().writexml()
def writexml(self,
writer: Any,
indent: str = "",
addindent: str = "",
newl: str = "",
encoding: Any = None) -> None
'''
class YOLO2VOCConvert:
    """Convert YOLOv5-style txt annotations into Pascal-VOC-style XML files.

    YOLOv5 stores one txt per image with lines of ``class_id cx cy w h``
    (all normalised to [0, 1]); VOC stores one XML per image with absolute
    1-based pixel ``xmin/ymin/xmax/ymax`` boxes.
    """

    def __init__(self, txts_path, xmls_path, imgs_path, classes_str_list):
        self.txts_path = txts_path        # directory holding the YOLO txt label files
        self.xmls_path = xmls_path        # output directory for the VOC xml files
        self.imgs_path = imgs_path        # directory holding the images (for size info)
        self.classes = classes_str_list   # class-index -> class-name list

    def search_all_classes(self, writer=False):
        """Scan every txt label file and collect the class ids in use.

        YOLO labels store classes as integers (0, 1, ...); ``classes.txt``
        is skipped because it holds names, not annotations.

        Returns:
            list[int]: all distinct class indices found across the label files.
        """
        all_names = set()
        txts = [t for t in os.listdir(self.txts_path)
                if t.endswith('.txt') and os.path.splitext(t)[0] != 'classes']
        for txt in txts:
            with open(os.path.join(self.txts_path, txt), 'r') as f:
                for line in f:
                    fields = line.strip().split(' ')
                    # Guard against blank lines, which would crash int('').
                    if fields and fields[0]:
                        all_names.add(int(fields[0]))
        print("所有的类别标签:", all_names, "共标注数据集:%d张" % len(txts))
        return list(all_names)

    def _append_text_element(self, doc, parent, tag, text):
        """Append ``<tag>text</tag>`` under *parent* (xml text must be str)."""
        node = doc.createElement(tag)
        node.appendChild(doc.createTextNode(str(text)))
        parent.appendChild(node)
        return node

    def yolo2voc(self):
        """Write one VOC xml file per image found in ``imgs_path``.

        Images without a matching txt file (pure background) still get an
        xml with no <object> entries — YOLOX requires an annotation file
        for every image, unlike YOLOv5.
        """
        if not os.path.exists(self.xmls_path):
            os.makedirs(self.xmls_path)
        for img_name in os.listdir(self.imgs_path):
            # Read the image only to obtain its pixel dimensions.
            try:
                img = cv2.imread(os.path.join(self.imgs_path, img_name))
                height_img, width_img, depth_img = img.shape
            except Exception as e:
                print("%s read fail, %s" % (img_name, e))
                continue
            # BUGFIX: use os.path.splitext for the stem; the original
            # str.replace(ext, ...) could match the extension text mid-name,
            # and split('.')[0] broke on file names containing dots.
            stem = os.path.splitext(img_name)[0]
            txt_file = os.path.join(self.txts_path, stem + '.txt')
            all_objects = []
            if os.path.exists(txt_file):
                with open(txt_file, 'r') as f:
                    all_objects = [line.strip().split(' ') for line in f if line.strip()]
            doc = Document()
            annotation = doc.createElement("annotation")  # VOC root element
            doc.appendChild(annotation)
            # <folder> holds the image directory's base name, e.g. JPEGImages.
            self._append_text_element(doc, annotation, "folder", self.imgs_path.split('/')[-1])
            # BUGFIX: record the image's real file name; the original always
            # forced a '.jpg' suffix even for .png/.bmp images.
            self._append_text_element(doc, annotation, "filename", img_name)
            size = doc.createElement("size")
            self._append_text_element(doc, size, "width", width_img)
            self._append_text_element(doc, size, "height", height_img)
            self._append_text_element(doc, size, "depth", depth_img)
            annotation.appendChild(size)
            # One <object> per annotation line ['cls', 'cx', 'cy', 'w', 'h'].
            for info in all_objects:
                obj = doc.createElement("object")
                self._append_text_element(doc, obj, "name", self.classes[int(info[0])])
                self._append_text_element(doc, obj, "pose", "Unspecified")
                self._append_text_element(doc, obj, "truncated", "0")
                self._append_text_element(doc, obj, "difficult", "0")
                # Normalised (cx, cy, w, h) -> absolute 1-based (xmin, ymin, xmax, ymax).
                x_center = float(info[1]) * width_img + 1
                y_center = float(info[2]) * height_img + 1
                half_w = 0.5 * float(info[3]) * width_img
                half_h = 0.5 * float(info[4]) * height_img
                bndbox = doc.createElement("bndbox")
                self._append_text_element(doc, bndbox, "xmin", int(x_center - half_w))
                self._append_text_element(doc, bndbox, "ymin", int(y_center - half_h))
                self._append_text_element(doc, bndbox, "xmax", int(x_center + half_w))
                self._append_text_element(doc, bndbox, "ymax", int(y_center + half_h))
                obj.appendChild(bndbox)
                annotation.appendChild(obj)
            # BUGFIX: open with encoding='utf-8' so the declared xml encoding
            # matches the file encoding regardless of the platform default,
            # and use a context manager so the handle is always closed.
            with open(os.path.join(self.xmls_path, stem + '.xml'), 'w', encoding='utf-8') as f:
                doc.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')
if __name__ == '__main__':
    # Paths follow the yolov5 layout (images/, labels/) plus the new xmls/
    # output directory; run once per split ('train' and 'val').
    imgs_path1 = 'F:/Dataset/road/images/val'  # ['train', 'val']
    txts_path1 = 'F:/Dataset/road/labels/val'  # ['train', 'val']
    xmls_path1 = 'F:/Dataset/road/xmls/val'  # ['train', 'val']
    classes_str_list = ['road_crack','road_sag']  # class names, index order matches the txt labels
    yolo2voc_obj1 = YOLO2VOCConvert(txts_path1, xmls_path1, imgs_path1, classes_str_list)
    # Sanity check: list the class ids actually used in the label files.
    labels = yolo2voc_obj1.search_all_classes()
    print('labels: ', labels)
    # Write one xml per image into xmls_path1.
    yolo2voc_obj1.yolo2voc()
将train和val都转换生成后,目录格式如下:
二、定义数据读取文件
整个YOLOX的工程,训练过程,要想有一个大概浏览,可以见我的另一篇文章yolox训练解析
进入到YOLOX主目录
在yolox/data/datasets/目录下定义了数据的读取方式,有按coco方式读取,有按voc方式读取,另外mosaic增强也定义在这个文件夹下,我们添加新的读取方式就在这个目录下添加,添加yolo_style.py文件,代码如下:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2021/12/23 9:13
# @Author : lishanlu
# @File : yolo_style.py
# @Software: PyCharm
# @Description: 读入yolo风格的xml标注数据
from __future__ import absolute_import, print_function, division
import os
import os.path
import pickle
import xml.etree.ElementTree as ET
import cv2
import numpy as np
from yolox.evaluators.voc_eval import voc_eval
from .datasets_wrapper import Dataset
from pathlib import Path
import glob
from tqdm import tqdm
from PIL import Image, ExifTags
import torch
class AnnotationTransform(object):
    """Transform a VOC annotation (an ``ET.Element``) into an array of
    bbox coordinates and label indices.

    Arguments:
        classes_name (sequence of str): class names; each name is mapped to
            its position in this sequence.
        keep_difficult (bool, optional): keep objects marked difficult
            (default: True).
    """

    def __init__(self, classes_name, keep_difficult=True):
        self.class_to_ind = dict(zip(classes_name, range(len(classes_name))))
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """
        Arguments:
            target (ET.Element): the <annotation> root element.
        Returns:
            (res, img_info): ``res`` is a float (N, 5) ndarray of
            [xmin, ymin, xmax, ymax, label_ind] rows (0-based pixel
            coordinates); ``img_info`` is (height, width).
        """
        rows = []
        for obj in target.iter("object"):
            difficult = obj.find("difficult")
            is_difficult = difficult is not None and int(difficult.text) == 1
            if not self.keep_difficult and is_difficult:
                continue
            name = obj.find("name").text.strip()
            bbox = obj.find("bndbox")
            # VOC coordinates are 1-based; shift to 0-based pixel coords.
            bndbox = [int(bbox.find(pt).text) - 1
                      for pt in ("xmin", "ymin", "xmax", "ymax")]
            bndbox.append(self.class_to_ind[name])
            rows.append(bndbox)
        # Build the array once instead of np.vstack inside the loop
        # (O(n) instead of O(n^2)); dtype matches the original float result.
        res = np.array(rows, dtype=np.float64).reshape(-1, 5)
        width = int(target.find("size").find("width").text)
        height = int(target.find("size").find("height").text)
        return res, (height, width)
"""
generation yolo style dataloader.
"""
# Acceptable image suffixes for the dataset loader.
img_formats = 'bmp jpg jpeg png tif tiff dng webp'.split()

# Resolve the numeric EXIF tag id for 'Orientation' once at import time;
# `orientation` stays bound after the break for use in exif_size().
for orientation, tag_name in ExifTags.TAGS.items():
    if tag_name == 'Orientation':
        break
def img2xml_paths(img_paths):
    """Map image paths to their xml label paths.

    Replaces the first ``/images/`` path component with ``/xmls/`` and the
    image extension with ``.xml`` (mirrors yolov5's ``img2label_paths``).

    Args:
        img_paths (list[str]): absolute or relative image file paths.
    Returns:
        list[str]: corresponding xml file paths.
    """
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'xmls' + os.sep  # /images/ -> /xmls/
    # os.path.splitext is safer than rsplit-ing on the raw extension text,
    # which could also match the same substring earlier in the file name.
    return [os.path.splitext(x.replace(sa, sb, 1))[0] + '.xml' for x in img_paths]
def get_hash(files):
    """Return a single hash value for a list of files: the summed sizes of
    those that exist (used to detect dataset changes between runs)."""
    total = 0
    for path in files:
        if os.path.isfile(path):
            total += os.path.getsize(path)
    return total
def exif_size(img):
    """Return the EXIF-orientation-corrected (width, height) of a PIL image.

    Falls back to the raw ``img.size`` when the image has no EXIF data or
    no orientation tag.
    """
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        # 6 (rotated 270) and 8 (rotated 90) both swap width and height.
        if rotation in (6, 8):
            s = (s[1], s[0])
    except Exception:
        # No EXIF data / no orientation tag: keep the reported size.
        # (Narrowed from a bare `except:` so KeyboardInterrupt etc. escape.)
        pass
    return s
def xyxy2xywh(x):
    """Convert nx4 boxes from [x1, y1, x2, y2] (top-left / bottom-right
    corners) to [x, y, w, h] (center plus size). Accepts torch tensors or
    numpy arrays and returns the same kind; the input is left untouched."""
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    x1, y1, x2, y2 = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    out[:, 0] = (x1 + x2) / 2  # center x
    out[:, 1] = (y1 + y2) / 2  # center y
    out[:, 2] = x2 - x1        # width
    out[:, 3] = y2 - y1        # height
    return out
def segments2boxes(segments):
    """Convert segment polygons to xywh boxes using each polygon's
    min/max extents. Each segment is an (N, 2) array of xy points."""
    extents = []
    for seg in segments:
        xs, ys = seg.T  # polygon x and y coordinate vectors
        extents.append([xs.min(), ys.min(), xs.max(), ys.max()])  # xyxy
    return xyxy2xywh(np.array(extents))  # xyxy -> xywh
class YOLODetection(Dataset):
    """
    YOLO Style Detection Dataset Object (reads labels from yolo-style XML).

    Input is an image; target is its annotation.

    Args:
        data_dir (string): filepath to the data folder (containing images/ and xmls/).
        classes (sequence of str): class names, index order matching the labels.
        image_sets (list of str): imagesets to use (eg. ['train'], ['val']).
        img_size (tuple): (height, width) images are letterbox-resized to.
        preproc (callable, optional): transformation to perform on the input image.
        dataset_name (string, optional): which dataset to load (default: 'yolo_dataset').
        cache (bool, optional): placeholder for RAM caching (not implemented).
    """

    def __init__(
        self,
        data_dir,
        classes,
        image_sets=['train'],
        img_size=(416, 416),
        preproc=None,
        dataset_name="yolo_dataset",
        cache=False,
    ):
        super().__init__(img_size)
        self.root = data_dir
        self.image_set = image_sets
        self.img_size = img_size
        self.preproc = preproc
        self._classes = classes
        self.target_transform = AnnotationTransform(self._classes, keep_difficult=True)
        self.name = dataset_name
        for name in image_sets:
            rootpath = self.root
            image_dir = os.path.join(rootpath, 'images', name)
            self.image_files = [os.path.join(image_dir, image_name) for image_name in os.listdir(image_dir)]
            if name == 'val':
                # The VOC evaluator needs <root>/val.txt listing one image id
                # (file name without extension) per line.
                self.val_ids = [os.path.splitext(image_name)[0] for image_name in os.listdir(image_dir)]
                with open(os.path.join(rootpath, name + '.txt'), 'w') as f:
                    for image_id in self.val_ids:
                        f.write(image_id + '\n')
        self.xml_files = img2xml_paths(self.image_files)  # one xml path per image
        self.annotations = self._load_xml_annotations()
        self.imgs = None
        if cache:
            self._cache_images()

    def __len__(self):
        return len(self.image_files)

    def _load_xml_annotations(self):
        # Parse every xml once up front; each entry is (boxes, img_info, resized_info).
        return [self.load_anno_from_ids(idx) for idx in range(len(self.xml_files))]

    def _cache_images(self):
        # RAM caching of decoded images is not implemented for this dataset.
        pass

    def load_anno_from_ids(self, index):
        """Parse the xml for `index` and pre-scale its boxes to self.img_size.

        Returns:
            (res, img_info, resized_info): ``res`` is an (N, 5) array of
            [x1, y1, x2, y2, label]; ``img_info`` is the original
            (height, width); ``resized_info`` is the letterboxed (height, width).
        """
        xml_file = self.xml_files[index]
        target = ET.parse(xml_file).getroot()
        assert self.target_transform is not None
        res, img_info = self.target_transform(target)
        height, width = img_info
        # Same ratio the image itself is resized with in load_resized_img.
        r = min(self.img_size[0] / height, self.img_size[1] / width)
        res[:, :4] *= r
        resized_info = (int(height * r), int(width * r))
        return (res, img_info, resized_info)

    def load_anno(self, index):
        # Boxes only (already scaled), without the size tuples.
        return self.annotations[index][0]

    def load_resized_img(self, index):
        """Load image `index` resized (aspect-preserving) to fit self.img_size."""
        img = self.load_image(index)
        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        return resized_img

    def load_image(self, index):
        img = cv2.imread(self.image_files[index], cv2.IMREAD_COLOR)
        assert img is not None
        return img

    def pull_item(self, index):
        """Returns the original image and target at an index for mixup.

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            img, target, img_info, index
        """
        if self.imgs is not None:
            target, img_info, resized_info = self.annotations[index]
            pad_img = self.imgs[index]
            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
        else:
            img = self.load_resized_img(index)
            target, img_info, _ = self.annotations[index]
        return img, target, img_info, index

    @Dataset.mosaic_getitem
    def __getitem__(self, index):
        # target coords here are (x1, y1, x2, y2, cls) in resized pixel space.
        img, target, img_info, img_id = self.pull_item(index)
        if self.preproc is not None:
            # After preproc the target layout becomes (cls, cx, cy, w, h).
            img, target = self.preproc(img, target, self.input_dim)
        return img, target, img_info, img_id

    def evaluate_detections(self, all_boxes, output_dir=None):
        """
        all_boxes is a list of length number-of-classes.
        Each list element is a list of length number-of-images.
        Each of those list elements is either an empty list []
        or a numpy array of detection.

        all_boxes[class][image] = [] or np.array of shape #dets x 5
        """
        self._write_voc_results_file(all_boxes)
        # Sweep IoU thresholds 0.50:0.05:0.95 (COCO-style mAP).
        IouTh = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
        mAPs = []
        for iou in IouTh:
            mAP = self._do_python_eval(output_dir, iou)
            mAPs.append(mAP)
        print("--------------------------------------------------------------")
        print("map_5095:", np.mean(mAPs))
        print("map_50:", mAPs[0])
        print("--------------------------------------------------------------")
        return np.mean(mAPs), mAPs[0]

    def _get_voc_results_file_template(self):
        # e.g. <root>/results/comp4_det_test_<classname>.txt
        filename = "comp4_det_test" + "_{:s}.txt"
        filedir = os.path.join(self.root, "results")
        if not os.path.exists(filedir):
            os.makedirs(filedir)
        path = os.path.join(filedir, filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        """Dump detections into one VOC-format result file per class."""
        self.ids = [os.path.splitext(os.path.split(image_file)[1])[0] for image_file in self.image_files]
        for cls_ind, cls in enumerate(self._classes):
            if cls == "__background__":
                continue
            print("Writing {} VOC results file".format(cls))
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, "wt") as f:
                for im_ind, index in enumerate(self.ids):
                    dets = all_boxes[cls_ind][im_ind]
                    # BUGFIX: `dets == []` becomes an elementwise comparison
                    # once dets is an ndarray; test emptiness by length instead.
                    if len(dets) == 0:
                        continue
                    for k in range(dets.shape[0]):
                        # VOC result line: id score xmin ymin xmax ymax (1-based).
                        f.write(
                            "{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n".format(
                                index,
                                dets[k, -1],
                                dets[k, 0] + 1,
                                dets[k, 1] + 1,
                                dets[k, 2] + 1,
                                dets[k, 3] + 1,
                            )
                        )

    def _do_python_eval(self, output_dir="output", iou=0.5):
        """Run voc_eval per class at the given IoU and return the mean AP."""
        rootpath = self.root
        name = self.image_set[0]
        # NOTE(review): annopath is hard-coded to the 'val' split regardless
        # of `name` — confirm evaluation is only ever run on 'val'.
        annopath = os.path.join(rootpath, "xmls", "val", "{:s}.xml")
        imagesetfile = os.path.join(rootpath, name + ".txt")
        cachedir = os.path.join(
            self.root, "annotations_cache"
        )
        if not os.path.exists(cachedir):
            os.makedirs(cachedir)
        aps = []
        # The PASCAL VOC metric changed in 2010; the 11-point (2007) metric is used.
        use_07_metric = True
        print("Eval IoU : {:.2f}".format(iou))
        if output_dir is not None and not os.path.isdir(output_dir):
            os.makedirs(output_dir)  # makedirs also creates missing parents (os.mkdir would not)
        for i, cls in enumerate(self._classes):
            if cls == "__background__":
                continue
            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename,
                annopath,
                imagesetfile,
                cls,
                cachedir,
                ovthresh=iou,
                use_07_metric=use_07_metric,
            )
            aps += [ap]
            if iou == 0.5:
                print("AP for {} = {:.4f}".format(cls, ap))
            if output_dir is not None:
                with open(os.path.join(output_dir, cls + "_pr.pkl"), "wb") as f:
                    pickle.dump({"rec": rec, "prec": prec, "ap": ap}, f)
        if iou == 0.5:
            print("Mean AP = {:.4f}".format(np.mean(aps)))
            print("~~~~~~~~")
            print("Results:")
            for ap in aps:
                print("{:.3f}".format(ap))
            print("{:.3f}".format(np.mean(aps)))
            print("~~~~~~~~")
            print("")
            print("--------------------------------------------------------------")
            print("Results computed with the **unofficial** Python eval code.")
            print("Results should be very close to the official MATLAB eval code.")
            print("Recompute with `./tools/reval.py --matlab ...` for your paper.")
            print("-- Thanks, The Management")
            print("--------------------------------------------------------------")
        return np.mean(aps)
定义好这个文件,别忘了在yolox/data/datasets/的__init__.py文件中加入from .yolo_style import YOLODetection
三、定义训练用的配置文件
在exps/example/目录下新建一个任务目录,比如road,在这个目录下新建文件yolox_road.py,这个文件用于定义训练用的类Exp,它继承自yolox/exp/下的yolox_base.py中的Exp类,主要定义模型参数,数据集参数及数据增强参数,创建dataloader等函数。代码示例如下:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2021/12/23 8:58
# @Author : lishanlu
# @File : yolox_road.py
# @Software: PyCharm
# @Description:
from __future__ import absolute_import, print_function, division
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from yolox.data import get_yolox_datadir
from yolox.exp import Exp as MyExp
class Exp(MyExp):
    """Experiment definition for training a YOLOX-m model on a custom
    dataset read through the yolo-style XML loader (YOLODetection)."""

    def __init__(self):
        super(Exp, self).__init__()
        # ------------ model config -------------------#
        self.num_classes = 2  # set this to the number of classes in your own dataset
        self.depth = 0.67  # depth multiplier (YOLOX-m)
        self.width = 0.75  # width multiplier (YOLOX-m)
        # ---------------- dataloader config ---------------- #
        # set worker to 4 for shorter dataloader init time
        self.data_num_workers = 4
        self.input_size = (640, 640)  # (height, width)
        # Actual multiscale ranges: [640-5*32, 640+5*32].
        # To disable multiscale training, set the
        # self.multiscale_range to 0.
        self.multiscale_range = 5
        # You can uncomment this line to specify a multiscale range
        # self.random_size = (14, 26)
        self.data_dir = 'your data rootdir'  # root directory of the dataset
        self.classes_name = ('class1','class2')  # class names; order must match the label indices
        self.dataset_name = 'yolo_dataset'  # dataset name; usually no need to change
        # --------------- transform config ----------------- #
        self.mosaic_prob = 1.0
        self.mixup_prob = 1.0
        self.hsv_prob = 1.0
        self.flip_prob = 0.5
        self.degrees = 5.0
        self.translate = 0.1
        self.mosaic_scale = (0.5, 1.5)
        self.mixup_scale = (0.5, 1.5)
        self.shear = 2.0
        self.perspective = 0.0
        self.enable_mixup = False
        # -------------- training config --------------------- #
        self.warmup_epochs = 5
        self.max_epoch = 300
        self.warmup_lr = 0
        self.basic_lr_per_img = 0.01 / 64.0
        self.scheduler = "yoloxwarmcos"
        self.milestones = [70, 120, 180, 300]  # only used by the multi_step LR schedule
        self.gamma = 0.1  # only used by the multi_step LR schedule
        # NOTE(review): no_aug_epochs == max_epoch disables mosaic/mixup for
        # the entire run — confirm this is intended.
        self.no_aug_epochs = 300
        self.min_lr_ratio = 0.05
        self.ema = True
        self.weight_decay = 5e-4
        self.momentum = 0.9
        self.print_interval = 10
        self.eval_interval = 1
        # Experiment name defaults to this file's base name (yolox_road).
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
        # ----------------- testing config ------------------ #
        self.test_size = (640, 640)
        self.test_conf = 0.01
        self.nmsthre = 0.65

    def get_model(self):
        """Build the YOLOX model once (PAFPN backbone + decoupled head),
        re-init its BN epsilon/momentum and head biases, and return it."""
        from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead

        def init_yolo(M):
            # Match YOLOX's preferred BatchNorm settings on every submodule.
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if getattr(self, "model", None) is None:
            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels)
            head = YOLOXHead(self.num_classes, self.width,
                             in_channels=in_channels)  # strides=[8,16,32], in_channels=in_channels
            self.model = YOLOX(backbone, head)
        self.model.apply(init_yolo)
        self.model.head.initialize_biases(1e-2)
        return self.model

    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
        """Create the training dataloader: YOLODetection wrapped in
        MosaicDetection with an infinite batch sampler."""
        from yolox.data import (
            YOLODetection,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
            worker_init_reset_seed,
        )
        from yolox.utils import (
            wait_for_the_master,
            get_local_rank,
        )
        local_rank = get_local_rank()
        # Only one process at a time builds the dataset (writes val.txt etc.).
        with wait_for_the_master(local_rank):
            dataset = YOLODetection(data_dir=self.data_dir,
                                    classes=self.classes_name,
                                    image_sets=['train'],
                                    img_size=self.input_size,
                                    preproc=TrainTransform(
                                        max_labels=50,
                                        flip_prob=self.flip_prob,
                                        hsv_prob=self.hsv_prob),
                                    dataset_name=self.dataset_name,
                                    cache=cache_img)
        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                max_labels=120,  # mosaic stitches 4 images, so allow more labels
                flip_prob=self.flip_prob,
                hsv_prob=self.hsv_prob),
            degrees=self.degrees,
            translate=self.translate,
            mosaic_scale=self.mosaic_scale,
            mixup_scale=self.mixup_scale,
            shear=self.shear,
            perspective=self.perspective,
            enable_mixup=self.enable_mixup,
            mosaic_prob=self.mosaic_prob,
            mixup_prob=self.mixup_prob,
        )
        self.dataset = dataset
        if is_distributed:
            # Global batch size is split evenly across the workers.
            batch_size = batch_size // dist.get_world_size()
        sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            mosaic=not no_aug,
        )
        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler
        # Make sure each worker has a different, reset seed.
        dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)
        return train_loader

    def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
        """Create the validation dataloader over the 'val' image set."""
        from yolox.data import YOLODetection, ValTransform
        valdataset = YOLODetection(
            data_dir=self.data_dir,
            classes=self.classes_name,
            image_sets=['val'],
            img_size=self.test_size,
            preproc=ValTransform(legacy=legacy),
            dataset_name=self.dataset_name
        )
        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(
                valdataset, shuffle=False
            )
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)
        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
        """Return a VOC-style evaluator bound to the validation loader."""
        from yolox.evaluators import VOCEvaluator
        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy)
        evaluator = VOCEvaluator(
            dataloader=val_loader,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
        )
        return evaluator

    def get_lr_scheduler(self, lr, iters_per_epoch, **kwargs):
        """Build the LR scheduler (warmup + cosine by default)."""
        from yolox.utils import LRScheduler
        scheduler = LRScheduler(
            self.scheduler,
            lr,
            iters_per_epoch,
            self.max_epoch,
            warmup_epochs=self.warmup_epochs,
            warmup_lr_start=self.warmup_lr,
            no_aug_epochs=self.no_aug_epochs,
            min_lr_ratio=self.min_lr_ratio,
            **kwargs
        )
        return scheduler
四、启动训练
写一个sh文件train.sh,代码如下:
# Launch YOLOX training with the custom road experiment file
# (single GPU, mixed precision, fine-tuning from the YOLOX-m checkpoint).
python tools/train.py \
--experiment-name yolox_road \
--batch-size 48 \
--devices 0 \
--exp_file exps/example/road/yolox_road.py \
--fp16 \
--ckpt pre_train/yolox_m.pth
运行命令bash ./train.sh就可以启动训练
版权声明:本文为CSDN博主「lishanlu136」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/lishanlu136/article/details/122109741
暂无评论