前言
自YOLO(You Only Look Once)算法面世以来便得到相关从业者的广泛关注,而目前最新版本的yolov5更是将算法的性能,无论是在速度还是准确性上提升到了前所未有的高度。YOLOv5,YYDS!
相信不少小伙伴都下载过原项目,并使用官方训练好的模型跑过demo。但如果要用自己的数据集训练模型,数据格式转换以及处理过程都会相对比较麻烦,网上也有许多讲解如何处理的教程,但看着还是觉得挺麻烦,于是本小白花了几天写了个小工具,准备及训练过程被简化,也不用每次使用新的数据集都要忙忙碌碌弄一大堆脚本了。
注:目前这个小工具只能在Linux系统下使用,目前只在Ubuntu16.04,18.04,20.04上使用过。等有空改改代码写一个在Windows下也能通用的(不过说实话Windows下训练速度明显变慢)。以下步骤也许不完善,如果发现,后续有空也会更新。接下来将从安装步骤以及操作步骤两方面介绍:
安装步骤
0. 环境配置
安装本工具前,请确保系统中已经正确安装CUDA、Cudnn、Pytorch、torchvision,pyqt5,pyqt5-tools以确保本工具能正常使用。这里建议大家使用CUDA10.1以及Pytorch1.7.1。如果运行环境在conda中,需要先进入该conda环境中。
conda activate $YOUR_ENVI_NAME # 将$YOUR_ENVI_NAME替换为conda环境名称
同时,如果使用conda环境运行本项目,请使用
conda install pyqt==5.X.X
安装pyqt5而不是使用
pip3 install pyqt5 pyqt5-tools
否则大概率会报错。
本工具已上传至GitHub,项目地址:
如果访问速度过慢,可以使用镜像站下载,缺点就是貌似不能登录账号
1. 下载项目
在命令行输入
cd ~
git clone https://github.com/LSH9832/yolov5_training_tool.git
# 或者使用镜像源
git clone https://hub.fastgit.org/LSH9832/yolov5_training_tool.git
即可下载。
2. 安装项目
等待下载完成,输入
cd yolov5_training_tool # 进入目录
chmod +x ./setup.sh
./setup.sh # 安装相关依赖,创建桌面文件
安装过程中会询问是否安装pytorch,如果没有安装的话会根据系统中的CUDA版本安装相应的torch包,如果没有正确配置CUDA则安装CPU版本的torch(建议不要这样)。之后还会询问是否使用conda环境。如果使用conda环境则输入该环境的目录再按回车,比如说我服务器上的环境目录就是
/home/lsh/anaconda3/envs/yolov5
如果直接使用使用系统的python环境,则只需要直接回车即可。等待setup.sh运行完毕,此时桌面侧边栏中的Applications中会生成相应的图标,名称为YOLOv5 Train Guide Tool。可以通过搜索应用搜索到。单击即可打开。
3.下载权重文件
权重文件一共有s,m,l,x四种大小,可以根据需要下载,训练哪一种网络模型就下哪一种。
wget -O packages/yolov5/models/pt/yolov5s.pt https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s.pt
wget -O packages/yolov5/models/pt/yolov5m.pt https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5m.pt
wget -O packages/yolov5/models/pt/yolov5l.pt https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5l.pt
wget -O packages/yolov5/models/pt/yolov5x.pt https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5x.pt
同上,如果下载速度太慢,使用国内镜像源
wget -O packages/yolov5/models/pt/yolov5s.pt https://hub.fastgit.org/ultralytics/yolov5/releases/download/v5.0/yolov5s.pt
wget -O packages/yolov5/models/pt/yolov5m.pt https://hub.fastgit.org/ultralytics/yolov5/releases/download/v5.0/yolov5m.pt
wget -O packages/yolov5/models/pt/yolov5l.pt https://hub.fastgit.org/ultralytics/yolov5/releases/download/v5.0/yolov5l.pt
wget -O packages/yolov5/models/pt/yolov5x.pt https://hub.fastgit.org/ultralytics/yolov5/releases/download/v5.0/yolov5x.pt
等待下载完成即可。
操作步骤
-
点击按钮Browse选择一个放数据集的目录Location
-
填写一个自定义数据集名称Data Name,再点击Create Dir按钮,这个小工具将再目录Location下自动创建一个与自定义数据集名称同名的文件夹。
-
打开该文件夹,里面有四个子目录,如下图
- 将标签名文件 label.txt 放入该目录下(没有就自己创建),格式如下图,有多少个类别就写多少行,最后一行不要按回车变成n+1行。
5. 将图片文件放在子目录 images 下。
- 将图片对应的标签文件(PascalVOC xml格式,目前只支持该格式,如果使用其他格式请先转换为该格式。标签格式如下图,里面的folder和path错了不影响。许多数据集的标签格式都如此。自己标定的数据集可以选择使用标签工具labelImg生成这种格式的标签文件)放入子目录 Annotations 下。
7. 选择用于训练的图片的百分比,剩余的图片将用于验证。
-
点击按钮 Generate Training Data生成可用于yolov5训练的数据文件。由于需要计算Anchors,所以如果数据集标签数量较多,计算时间也会较长,需要耐心等待方差var降至0。
-
选择模型网络大小(有四种大小,模型越大速度越慢,模型最终的效果越好),默认用yolov5自带的模型训练,如果有自己训练过的模型也可选择,但必须清楚该模型对应网络的大小,选择的模型大小必须对应上,否则训练时会报错。然后选择训练次数epoch(建议300以上,这里50只是用于展示)和批处理大小batch size(需要根据自己显卡显存的大小适当调整batch size,如果训练开始就报错就调小一点,4G显存s模型可以最大调到16,11G显存大概在30-40之间,m在25左右,供参考),然后选择使用的GPU数量。注意:batch size必须为使用的GPU数量的整数倍。按下Generate Code后,所有准备工作基本完成。
-
这时在数据集目录下生成了训练脚本start_train.py以及超参数文件hyp.yaml,用文本文档打开hyp.yaml即可调整训练的超参数。
-
这里有两种训练方式,一是直接在界面上选择好是否使用conda环境以及conda环境的目录后直接点击界面上的Start Train开始训练。二是在在命令行中转到相应的环境中并按照消息栏中的提示输入如下命令即可开始训练。
cd $YOUR_DATA_PATH # 根据具体目录位置调整
python3 start_train.py # single GPU
python3 -m torch.distributed.launch --nproc_per_node GPUNUM start_train.py # multi GPU of number ${GPUNUR}
当使用多个GPU训练模型时,将GPUNUM改为需要使用的GPU个数。
After Training
训练完后,命令行或界面上的消息栏将会告诉训练得到的权重文件所在位置,有best和last两个版本,根据需求使用。
权重文件的使用方法可以参考如下代码,使用时需要将本工具目录下的packages文件夹与脚本放在同一个文件夹下。
from packages import yolov5
import cv2
from glob import glob
detector = yolov5.Detector(weights="./best.pt", conf_thres=0.3)
cam = cv2.VideoCapture(0)
if __name__ == '__main__':
while True:
success, frame = cam.read()
if success:
preds = detector.detect(frame.copy())
img_show, bbs, labels = detector.draw_bb(frame.copy(),preds)
if len(bbs):
print('bounding_box:' bbs)
print('label:', labels)
cv2.imshow('capture', img_show)
if cv2.waitKey(1) == 27: # esc
cv2.destroyAllWindows()
break
由于本人水平有限,项目中出现bug在所难免,欢迎各位大佬反馈bug。
版权声明:本文为CSDN博主「Sparks Fly」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_40266665/article/details/120827369
暂无评论