yolov5训练模型(数据集的整理)——数据xml转换成yolo数据集txt格式

下载数据

下载得到的数据集,一般情况下是给出两个文件夹,分别是Anntations,JPEGImages,有.xml文件和.jpg文件。
以下链接是yolov下的关于人头检测的数据集:
https://github.com/HCIILAB/SCUT-HEAD-Dataset-Release

xml文件转换生成txt文件

一般情况下,下载下来的数据集会保存在root_path文件夹中,在运行这个代码的时候,记得在该目录下创建worktxt文件夹,来进行存放转换过来的同名txt文件。


import xml.etree.ElementTree as ET
import os
import shutil
import random
root_path=os.path.abspath('...')#这里是以下内容的绝对地址
def convert_annotation(image_id):
    classes=['person']#标签名
    in_file = open(root_path,'Annotations/%s.xml' % (image_id), encoding='UTF-8')
    out_file = open(root_path,'worktxt/%s.txt' % (image_id), 'w')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    size_width = int(size.find('width').text)
    size_height = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = [float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),float(xmlbox.find('ymax').text)]
        # 标注越界修正
        if b[1] > size_width:
            b[1] = size_width
        if b[3] > size_height:
            b[3] = size_height
        txt_data=[((b[0]+b[1])/2.0-1)/size_width,((b[2]+b[3])/2.0-1)/size_height,(b[1]-b[0])/size_width,(b[3]-b[2])/size_height]
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in txt_data]) + '\n')   

img_path=os.path.join(root_path,'JPEGImages')               
imglist=os.listdir(img_path)

for img_id in imglist:
    img_id=img_id[:-4]    
    convert_annotation(img_id)

将数据集划分成训练集和验证集

一般在进行训练模型的时候,只会进行调用训练集和验证集,进行测试的时候,我会使用其他数据进行训练,或者将整个数据集当做测试集来进行测试训练。
在root_path的绝对路径下创建一个文件夹,命名为data,再在其目录下创建train和val文件夹,分别在这两个文件夹中创建名为images和labels文件夹。(看你自己意愿啦,记得跟代码的文件夹同名就可以啦)

'''
import os
import random
root_path=os.path.abspath('...')#数据集保存的绝对路径
img_path=os.path.join(root_path,'JPEGImages')               
imglist=os.listdir(img_path)
'''#这块内容上面的代码有哦,要是单独训练的话,记得自己改一下哦
txt_path=os.path.join(root_path,'worktxt')
train_image_path=os.path.join(root_path,'data/train/images')
train_txt_path=os.path.join(root_path,'data/train/labels')
val_image_path=os.path.join(root_path,'data/val/images')
val_txt_path=os.path.join(root_path,'data/val/labels')
for img in imglist:
    img=img[:-4]
    shutil.copy(os.path.join(img_path,img+'.jpg'),os.path.join(train_image_path,img+'.jpg'))
    shutil.copy(os.path.join(txt_path,img+'.txt'),os.path.join(train_txt_path,img+'.txt'))
random_list=random.sample(imglist,int(0.3*len(imglist))) 
for img in random_list:
    img=img[:-4]
    shutil.copy(os.path.join(train_image_path,img+'.jpg'),os.path.join(val_image_path,img+'.jpg'))
    shutil.copy(os.path.join(train_txt_path,img+'.txt'),os.path.join(val_txt_path,img+'.txt'))
    os.remove(os.path.join(train_image_path,img+'.jpg'))
    os.remove(os.path.join(train_txt_path,img+'.txt'))  

版权声明:本文为CSDN博主「「已注销」」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_43869349/article/details/115733272

「已注销」

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

GiraffeDet:Heavy Neck的目标检测框架

关注并星标 从此不迷路 计算机视觉研究院 公众号ID|ComputerVisionGzq 学习群|扫码在主页获取加入方式 获取论文:关注并回复“GD” 计算机视觉研究院专栏 作者:Edison_G 在传统的目标检测框架中,从图像识别模型继承的主