合并类别
此次博主为数据集设置 3 个类别:‘Car’、‘Cyclist’、‘Pedestrian’。不过标注信息中还有其他类型的车和人,直接略过有点浪费,因此博主希望将 ‘Van’、‘Truck’、‘Tram’ 合并到 ‘Car’ 类别中,将 ‘Person_sitting’ 合并到 ‘Pedestrian’ 类别中(‘Misc’ 和 ‘DontCare’ 这两类直接忽略)。参考:前辈
对原博主的代码进行了一些适当的修改
# modify_annotations_txt.py
'''
PASCAL VOC数据集总共20个类别,如果用于特定场景,20个类别确实多了。此次博主为数据集设置3个类别,
‘Car’,’Cyclist’,’Pedestrian’,只不过标注信息中还有其他类型的车和人,直接略过有点浪费,
博主希望将 ‘Van’, ‘Truck’, ‘Tram’ 合并到 ‘Car’ 类别中去,将 ‘Person_sitting’ 合并到 ‘Pedestrian’ 类别中去
(‘Misc’ 和 ‘Dontcare’ 这两类直接忽略)。
'''
# 合并类别
import glob
import os
import tqdm
# txt_list = glob.glob('G:/A_Data/kitti_voc/label/') # paths of all txt files in the label folder
# Directory holding the label files. The trailing slash matters: file names
# are appended to this string with `+` further down.
path = 'G:/A_Data/kitti_voc/label/'
# Names (not full paths) of every entry in the label directory.
txt_list = os.listdir(path)
# txt_list = os.listdir(path='G:/A_Data/kitti_voc/label/')
def show_category(txt_list, label_dir=None):
    """Print and return the distinct category names used in the label files.

    Args:
        txt_list: File names of the label files, relative to ``label_dir``.
        label_dir: Directory that holds the label files. When omitted, the
            module-level ``path`` is used, so one-argument calls behave as
            before.

    Returns:
        The set of first fields (category names) found over all readable
        files. Blank lines are ignored instead of being reported as an
        empty category.
    """
    if label_dir is None:
        label_dir = path
    categories = set()
    for item in txt_list:
        try:
            with open(os.path.join(label_dir, item)) as tdf:
                for each_line in tdf:
                    fields = each_line.split()
                    if fields:  # a blank line carries no category
                        categories.add(fields[0])  # first field is the category
        except IOError as ioerr:
            # Best effort: report the unreadable file and keep scanning.
            print('File error:'+str(ioerr))
    print(categories)  # show the distinct categories
    return categories
def merge(line):
    """Join the fields of one label record back into a single text line.

    Args:
        line: Sequence of string fields.

    Returns:
        The fields separated by single spaces and terminated by a newline;
        no space follows the last field.
    """
    return ' '.join(line) + '\n'
# Labels that are folded into one of the kept classes, and labels whose
# records are dropped from the rewritten files.
merged_into = {'Truck': 'Car', 'Van': 'Car', 'Tram': 'Car',
               'Person_sitting': 'Pedestrian'}
ignored = ('DontCare', 'Misc')

# print('before modify categories are:\n')
show_category(txt_list)
for item in tqdm.tqdm(txt_list):
    item = path + item
    new_txt = []
    try:
        # First pass: read every record and remap / filter its category.
        with open(item, 'r') as r_tdf:
            for each_line in r_tdf:
                labeldata = each_line.strip().split(' ')
                labeldata[0] = merged_into.get(labeldata[0], labeldata[0])
                if labeldata[0] in ignored:
                    continue
                new_txt.append(merge(labeldata))
        # Second pass: overwrite the same file with the kept records
        # ('w+' truncates the existing content before writing).
        with open(item,'w+') as w_tdf:
            w_tdf.writelines(new_txt)
    except IOError as ioerr:
        print('File error:'+str(ioerr))
# print('\nafter modify categories are:\n')
show_category(txt_list)
txt转xml
把路径修改正确就可以了。
# txt_to_xml.py
# encoding:utf-8
# 根据一个给定的XML Schema,使用DOM树的形式从空白文件生成一个XML
from xml.dom.minidom import Document
import cv2
import os
import tqdm
def _append_text_element(doc, parent, tag, text):
    """Append ``<tag>text</tag>`` to ``parent`` and return the new element."""
    element = doc.createElement(tag)
    element.appendChild(doc.createTextNode(str(text)))
    parent.appendChild(element)
    return element


def generate_xml(name, split_lines, img_size, class_ind,
                 out_dir='G:/A_Data/kitti_voc/Annotations/'):
    """Write a PASCAL VOC style annotation XML for one label file.

    Args:
        name: Sample name without extension. The image is recorded as
            ``name + '.png'`` and the output file is ``name + '.xml'``.
        split_lines: Lines of the label file. Field 0 is the class name and
            fields 4-7 are read as xmin, ymin, xmax, ymax.
        img_size: Image shape indexed as ``(height, width, depth)``. A
            two-element shape is written with depth 1.
        class_ind: Collection of class names to keep; records of any other
            class are not written.
        out_dir: Directory the XML file is written to. The default is the
            location this script has always used.

    Raises:
        IndexError: If a kept record has fewer than eight fields.
        ValueError: If a box coordinate is not a number.
    """
    doc = Document()  # create the DOM document
    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)

    _append_text_element(doc, annotation, 'folder', 'KITTI')
    _append_text_element(doc, annotation, 'filename', name + '.png')

    source = doc.createElement('source')
    annotation.appendChild(source)
    _append_text_element(doc, source, 'database', 'The KITTI Database')
    _append_text_element(doc, source, 'annotation', 'KITTI')

    size = doc.createElement('size')
    annotation.appendChild(size)
    _append_text_element(doc, size, 'width', img_size[1])
    _append_text_element(doc, size, 'height', img_size[0])
    depth = img_size[2] if len(img_size) > 2 else 1
    _append_text_element(doc, size, 'depth', depth)

    for split_line in split_lines:
        fields = split_line.split()
        # Skip blank lines (they used to raise IndexError) and unwanted classes.
        if not fields or fields[0] not in class_ind:
            continue
        obj = doc.createElement('object')
        annotation.appendChild(obj)
        _append_text_element(doc, obj, 'name', fields[0])
        bndbox = doc.createElement('bndbox')
        obj.appendChild(bndbox)
        for tag, index in (('xmin', 4), ('ymin', 5), ('xmax', 6), ('ymax', 7)):
            # Coordinates are stored as floats; truncate them to integers.
            _append_text_element(doc, bndbox, tag, int(float(fields[index])))

    # Write the DOM document to disk. The XML declaration names no encoding,
    # which means UTF-8, so write the file as UTF-8 explicitly.
    xml_path = os.path.join(out_dir, name + '.xml')
    with open(xml_path, 'w', encoding='utf-8') as f:
        f.write(doc.toprettyxml(indent=''))
if __name__ == '__main__':
    # Convert every label .txt file under <cur_dir>/label into a VOC XML file.
    class_ind = ('Pedestrian', 'Car', 'Cyclist')
    # cur_dir = os.getcwd()
    cur_dir = 'G:/A_Data/kitti_voc/'
    labels_dir = os.path.join(cur_dir, 'label')
    # Directory of the source images; adjust this path to the local setup.
    image_dir = 'E:/A____paper/person/data/kitti/training/image_2/'
    # os.walk yields the current directory, its sub-directories and its files.
    for parent, dirnames, filenames in os.walk(labels_dir):
        for file_name in tqdm.tqdm(filenames):
            name, ext = os.path.splitext(file_name)
            if ext.lower() != '.txt':
                continue  # only label files are converted
            full_path = os.path.join(parent, file_name)  # full path of the label file
            with open(full_path) as f:  # close the handle once the lines are read
                split_lines = f.readlines()
            img_name = name+'.png'
            img_path = os.path.join(image_dir, img_name)
            image = cv2.imread(img_path)
            if image is None:
                # cv2.imread signals an unreadable image by returning None;
                # fail with the offending path instead of an AttributeError.
                raise FileNotFoundError('Can not read image %s' % img_path)
            generate_xml(name, split_lines, image.shape, class_ind)
    print('all txts has converted into xmls')
划分训练集和验证集
# create_train_test_txt.py
# encoding:utf-8
import pdb
import glob
import os
import random
import math
def get_sample_value(txt_name, category_name,
                     label_path='G:/A_Data/kitti_voc/label/'):
    """Tell whether a label file contains an object of the given category.

    Args:
        txt_name: Label file name without the ``.txt`` extension.
        category_name: Category to look for.
        label_path: Directory that holds the label files. The default is the
            location this script has always used.

    Returns:
        ``' 1'`` when the first field of some line equals ``category_name``,
        otherwise ``'-1'``. ``'-1'`` is also returned, after the error is
        printed, when the file cannot be read, so the result can always be
        concatenated by the caller.
    """
    txt_path = os.path.join(label_path, txt_name + '.txt')
    try:
        with open(txt_path) as r_tdf:
            for each_line in r_tdf:
                fields = each_line.split()
                # Compare the category field itself; a substring test over the
                # whole file text would also match longer names.
                if fields and fields[0] == category_name:
                    return ' 1'
    except IOError as ioerr:
        print('File error:'+str(ioerr))
    return '-1'
# txt_list_path = glob.glob('./Labels/*.txt')
txt_list_path = os.listdir(path='G:/A_Data/kitti_voc/label/')
# Sample names are the label file names without directory and extension.
txt_list = sorted(os.path.splitext(os.path.basename(item))[0]
                  for item in txt_list_path)
print(txt_list, end = '\n\n')
# Some blogs suggest train:val:test = 8:1:1; try that first.
num_trainval = random.sample(txt_list, math.floor(len(txt_list)*9/10.0)) # ratio can be changed
num_trainval.sort()
print(num_trainval, end = '\n\n')
num_train = random.sample(num_trainval,math.floor(len(num_trainval)*8/9.0)) # ratio can be changed
num_train.sort()
print(num_train, end = '\n\n')
# val is what remains of trainval after train was drawn.
num_val = sorted(set(num_trainval).difference(set(num_train)))
print(num_val, end = '\n\n')
# test is what remains of all samples after trainval was drawn.
num_test = sorted(set(txt_list).difference(set(num_trainval)))
print(num_test, end = '\n\n')
# pdb.set_trace()
Main_path = 'G:/A_Data/kitti_voc/ImageSets/Main/'
train_test_name = ['trainval', 'train', 'val', 'test']
category_name = ['Car', 'Pedestrian', 'Cyclist']
# Explicit split-name -> sample-list table (replaces eval() on a built name).
split_samples = {'trainval': num_trainval,
                 'train': num_train,
                 'val': num_val,
                 'test': num_test}
# Make sure the output directory exists before any file is opened in it.
os.makedirs(Main_path, exist_ok=True)
# Write the trainval / train / val / test files in turn.
for item_train_test_name in train_test_name:
    samples = split_samples[item_train_test_name]
    train_test_txt_name = Main_path + item_train_test_name + '.txt'
    try:
        # The split file itself: one sample name per line.
        with open(train_test_txt_name, 'w') as w_tdf:
            for item in samples:
                w_tdf.write(item+'\n')
        # One <category>_<split>.txt file per category (Car, Pedestrian, Cyclist).
        for item_category_name in category_name:
            category_txt_name = Main_path + item_category_name + '_' + item_train_test_name + '.txt'
            with open(category_txt_name, 'w') as w_tdf:
                # Each line: sample name, then the value from get_sample_value.
                for item in samples:
                    w_tdf.write(item+' '+ get_sample_value(item, item_category_name)+'\n')
    except IOError as ioerr:
        print('File error:'+str(ioerr))
VOC转COCO
生成训练和验证的json,和相应的图片文件夹
# -*- coding=utf-8 -*-
#!/usr/bin/python
import sys
import os
import shutil
import numpy as np
import json
import xml.etree.ElementTree as ET
# 检测框的ID起始值
import tqdm
# First id given to a bounding-box annotation; convert() counts up from it.
START_BOUNDING_BOX_ID = 1
# The category table need not be filled in beforehand: convert() adds each
# category name the first time it meets it in the XML files.
PRE_DEFINE_CATEGORIES = {}
def get(root, name):
    """Return the list of sub-elements of ``root`` matched by ``name``.

    The list is empty when nothing matches.
    """
    return root.findall(name)
def get_and_check(root, name, length):
    """Find the sub-elements ``name`` of ``root`` and validate their count.

    Args:
        root: Element to search in.
        name: Tag (or path) passed to ``root.findall``.
        length: Expected number of matches. A value of 0 or less disables
            the count check.

    Returns:
        The single matching element when ``length`` is 1, otherwise the list
        of matching elements.

    Raises:
        NotImplementedError: If nothing matches, or if ``length`` is positive
            and the number of matches differs from it.
    """
    found = root.findall(name)
    if not found:
        raise NotImplementedError('Can not find %s in %s.'%(name, root.tag))
    if length > 0 and len(found) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.'%(name, length, len(found)))
    if length == 1:
        return found[0]
    return found
# 得到图片唯一标识号
def get_filename_as_int(filename):
    """Return the image id encoded in a file name.

    Args:
        filename: File name whose stem (the part before the extension) is an
            integer, e.g. ``'000123.png'``.

    Returns:
        The stem converted to ``int``.

    Raises:
        NotImplementedError: If the stem cannot be converted to an integer.
    """
    try:
        filename = os.path.splitext(filename)[0]
        return int(filename)
    except (TypeError, ValueError) as err:
        # Catch only conversion failures so KeyboardInterrupt / SystemExit
        # are no longer swallowed, and keep the original cause attached.
        raise NotImplementedError('Filename %s is supposed to be an integer.'%(filename)) from err
def convert(xml_list, xml_dir, json_file):
    '''
    Convert VOC style XML annotation files into one COCO style JSON file.

    :param xml_list: names of the XML files to convert, relative to xml_dir
    :param xml_dir: directory that stores the XML files
    :param json_file: path of the JSON file to write
    :return: None
    :raises NotImplementedError: raised by the lookup helpers when a required
        element is missing or duplicated, when more than one <path> element
        is found, or when the image file name is not an integer
    :raises ValueError: when a box does not satisfy xmax > xmin and
        ymax > ymin (previously an assert, which is skipped under -O)
    '''
    # Basic structure of the annotation file.
    json_dict = {"images":[],
                 "type": "instances",
                 "annotations": [],
                 "categories": []}
    # NOTE: this is the module-level dict itself, not a copy. Categories
    # found here stay registered for later calls, so a category keeps the
    # same id across calls within one run.
    categories = PRE_DEFINE_CATEGORIES
    bnd_id = START_BOUNDING_BOX_ID
    for line in xml_list:
        line = line.strip()
        # Parse the XML file.
        root = ET.parse(os.path.join(xml_dir, line)).getroot()
        # Image name: taken from <path> when present, else from <filename>.
        path = get(root, 'path')
        if len(path) == 1:
            filename = os.path.basename(path[0].text)
        elif len(path) == 0:
            filename = get_and_check(root, 'filename', 1).text
        else:
            raise NotImplementedError('%d paths found in %s'%(len(path), line))
        # The file name must be a number; it becomes the image id.
        image_id = get_filename_as_int(filename)
        size = get_and_check(root, 'size', 1)
        # Basic information about the image.
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        json_dict['images'].append({'file_name': filename,
                                    'height': height,
                                    'width': width,
                                    'id':image_id})
        # Segmentation masks are not supported; only boxes are converted.
        for obj in get(root, 'object'):
            # Category name of this box; register it on first sight with the
            # next free id (ids start at 0 when the table starts empty).
            category = get_and_check(obj, 'name', 1).text
            if category not in categories:
                categories[category] = len(categories)
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin = int(get_and_check(bndbox, 'xmin', 1).text)
            ymin = int(get_and_check(bndbox, 'ymin', 1).text)
            xmax = int(get_and_check(bndbox, 'xmax', 1).text)
            ymax = int(get_and_check(bndbox, 'ymax', 1).text)
            if xmax <= xmin or ymax <= ymin:
                raise ValueError('Degenerate box (%d, %d, %d, %d) in %s'
                                 % (xmin, ymin, xmax, ymax, line))
            o_width = xmax - xmin
            o_height = ymax - ymin
            annotation = dict()
            annotation['area'] = o_width*o_height
            annotation['iscrowd'] = 0
            annotation['image_id'] = image_id
            annotation['bbox'] = [xmin, ymin, o_width, o_height]
            annotation['category_id'] = category_id
            annotation['id'] = bnd_id
            annotation['ignore'] = 0
            # The box outline as a polygon, corner order:
            # (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin).
            annotation['segmentation'] = [[xmin,ymin,xmin,ymax,xmax,ymax,xmax,ymin]]
            json_dict['annotations'].append(annotation)
            bnd_id = bnd_id + 1
    # Write the category id table.
    for cate, cid in categories.items():
        json_dict['categories'].append(
            {'supercategory': 'none', 'id': cid, 'name': cate})
    # Export to JSON; the context manager closes the file even on error.
    with open(json_file, 'w') as json_fp:
        json.dump(json_dict, json_fp)
if __name__ == '__main__':
    # Split the XML annotations into validation and training subsets, export
    # one COCO JSON per subset and copy the matching images next to them.
    # root_path = os.getcwd()
    root_path = 'G:/A_Data/kitti_voc/'
    ori_path = 'E:/A____paper/person/data/kitti/training/image_2/'
    xml_dir = os.path.join(root_path, 'Annotations')
    xml_labels = [entry for entry in os.listdir(xml_dir)
                  if entry.endswith('.xml')]
    np.random.shuffle(xml_labels)
    # The first 10% of the shuffled files become the validation set.
    split_point = int(len(xml_labels)/10)
    # Validation is converted before training, as before; the order decides
    # which category names convert() registers first.
    subsets = (
        (xml_labels[0:split_point],
         'G:/A_Data/kitti_voc/instances_val2014.json', 'val_img'),
        (xml_labels[split_point:],
         'G:/A_Data/kitti_voc/instances_train2014.json', 'train_img'),
    )
    for xml_list, json_file, img_folder in subsets:
        convert(xml_list, xml_dir, json_file)
        img_dir = os.path.join(root_path, img_folder)
        os.makedirs(img_dir, exist_ok=True)  # shutil.copy does not create it
        for xml_file in tqdm.tqdm(xml_list):
            # The images are .png files: the XML <filename> (and hence the
            # JSON file_name) is written as <name>.png, so copy that file.
            img_name = os.path.splitext(xml_file)[0] + '.png'
            shutil.copy(os.path.join(ori_path, img_name),
                        os.path.join(img_dir, img_name))
版权声明:本文为CSDN博主「Kerwin_97」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_51486807/article/details/122097028
暂无评论