文章目录[隐藏]
本文代码来自于https://github.com/bubbliiiing/faster-rcnn-pytorch,b站视频https://www.bilibili.com/video/BV1BK41157Vs?p=1,本文仅作学习使用
1.数据集划分
加载的数据集为VOC2007,先按 (train+val):test 的比例对其进行划分;训练时只需要生成的 2007_train.txt 和 2007_val.txt。
voc_annotation.py
import os
import random
import xml.etree.ElementTree as ET
from utils.utils import get_classes
#--------------------------------------------------------------------------------------------------------------------------------#
#   annotation_mode selects which stages run when this file is executed.
#   0: the whole pipeline - write the split txt files under VOCdevkit/VOC2007/ImageSets
#      and the training lists 2007_train.txt / 2007_val.txt
#   1: only write the split txt files under VOCdevkit/VOC2007/ImageSets
#   2: only write the training lists 2007_train.txt / 2007_val.txt
#--------------------------------------------------------------------------------------------------------------------------------#
annotation_mode = 0
#-------------------------------------------------------------------#
#   Must be edited for your data set: the class list used to turn
#   object names into class ids in 2007_train.txt / 2007_val.txt.
#   Per the original author it should be the same classes_path that
#   training and prediction use.
#   If the generated 2007_train.txt contains no boxes, the class list
#   is the first thing to check (objects whose name is not listed are
#   skipped).
#   Only matters when annotation_mode is 0 or 2.
#-------------------------------------------------------------------#
classes_path = 'model_data/voc_classes.txt'
#--------------------------------------------------------------------------------------------------------------------------------#
#   trainval_percent: fraction of all annotated images that goes to (train + val); the rest is test.
#                     With 0.9 this is (train + val) : test = 9 : 1.
#   train_percent:    fraction of (train + val) that goes to train; the rest is val.
#                     With 0.9 this is train : val = 9 : 1.
#   Only used when annotation_mode is 0 or 1.
#--------------------------------------------------------------------------------------------------------------------------------#
trainval_percent = 0.9
train_percent = 0.9
#-------------------------------------------------------#
#   Folder that holds the VOC data set.
#   The default is a relative path, resolved against the
#   current working directory.
#-------------------------------------------------------#
VOCdevkit_path = 'VOCdevkit'
# (year, image_set) pairs; one "<year>_<image_set>.txt" list is produced for each pair.
VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]
# get_classes is shown in section 1.1; it returns (class_names, count) and only the names are kept.
classes, _ = get_classes(classes_path)
def convert_annotation(year, image_id, list_file):
    """Append the usable ground-truth boxes of one image to ``list_file``.

    Parses ``<VOCdevkit_path>/VOC<year>/Annotations/<image_id>.xml`` and, for
    every ``<object>`` whose name is in the module-level ``classes`` list and
    which is not flagged as difficult, writes ``" xmin,ymin,xmax,ymax,cls_id"``
    (note the leading space) to ``list_file``. Skipped objects write nothing,
    and no newline is emitted - the caller terminates the line.

    Args:
        year: VOC release year used in the directory name, e.g. ``'2007'``.
        image_id: annotation file name without the ``.xml`` extension.
        list_file: writable text file object that receives the box fields.
    """
    xml_path = os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml' % (year, image_id))
    # The context manager closes the annotation file as soon as it is parsed;
    # without it one handle per image stays open until garbage collection.
    with open(xml_path, encoding='utf-8') as in_file:
        root = ET.parse(in_file).getroot()

    for obj in root.iter('object'):
        # A missing <difficult> tag is treated as "not difficult".
        difficult_node = obj.find('difficult')
        difficult = 0 if difficult_node is None else difficult_node.text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        # cls_id is the position of the class name in the class list.
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        # Coordinates may be stored as floats ("30.7"); they are truncated to int.
        coords = [int(float(xmlbox.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        list_file.write(" " + ",".join(str(c) for c in coords) + ',' + str(cls_id))
if __name__ == "__main__":
    # Fixed seed: the split is reproducible for a given directory listing order.
    random.seed(0)

    if annotation_mode == 0 or annotation_mode == 1:
        print("Generate txt in ImageSets.")
        xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
        saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
        total_xml = [name for name in os.listdir(xmlfilepath) if name.endswith(".xml")]

        num = len(total_xml)
        indices = range(num)
        tv = int(num * trainval_percent)
        tr = int(tv * train_percent)
        # Sample exactly as before (a range, then the sampled list) so the
        # seeded split is unchanged; sets are only used for O(1) membership.
        trainval = random.sample(indices, tv)
        train = random.sample(trainval, tr)
        trainval_ids = set(trainval)
        train_ids = set(train)

        print("train and val size",tv)
        print("train size",tr)

        # trainval = train + val, test = everything else.
        # Written as UTF-8 because the second stage reads these files as UTF-8.
        with open(os.path.join(saveBasePath, 'trainval.txt'), 'w', encoding='utf-8') as ftrainval, \
                open(os.path.join(saveBasePath, 'test.txt'), 'w', encoding='utf-8') as ftest, \
                open(os.path.join(saveBasePath, 'train.txt'), 'w', encoding='utf-8') as ftrain, \
                open(os.path.join(saveBasePath, 'val.txt'), 'w', encoding='utf-8') as fval:
            for i in indices:
                # Drop the ".xml" suffix; one image id per line.
                name = total_xml[i][:-4] + '\n'
                if i in trainval_ids:
                    ftrainval.write(name)
                    if i in train_ids:
                        ftrain.write(name)
                    else:
                        fval.write(name)
                else:
                    ftest.write(name)
        print("Generate txt in ImageSets done.")

    if annotation_mode == 0 or annotation_mode == 2:
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        # Loop-invariant: absolute data set root used as the image path prefix.
        devkit_abspath = os.path.abspath(VOCdevkit_path)
        for year, image_set in VOCdevkit_sets:
            ids_path = os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt' % (year, image_set))
            with open(ids_path, encoding='utf-8') as ids_file:
                image_ids = ids_file.read().strip().split()

            # One output line per image: "<abs image path> box box ...".
            with open('%s_%s.txt' % (year, image_set), 'w', encoding='utf-8') as list_file:
                for image_id in image_ids:
                    list_file.write('%s/VOC%s/JPEGImages/%s.jpg' % (devkit_abspath, year, image_id))
                    convert_annotation(year, image_id, list_file)
                    list_file.write('\n')
        print("Generate 2007_train.txt and 2007_val.txt for train done.")
1.1 voc_classes.txt
存放的是Voc数据集的类别
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
#---------------------------------------------------#
# 获得类
#---------------------------------------------------#
def get_classes(classes_path):
    """Read class names from a UTF-8 text file holding one name per line.

    Surrounding whitespace is stripped from every line. Blank lines are
    skipped, so a stray empty line (for example at the end of the file) does
    not become an empty-string class or inflate the class count.

    Args:
        classes_path: path of the class list file.

    Returns:
        A ``(class_names, num_classes)`` tuple, where ``class_names`` is the
        list of names in file order and ``num_classes`` is its length.
    """
    with open(classes_path, encoding='utf-8') as f:
        stripped = [line.strip() for line in f]
    class_names = [name for name in stripped if name]
    return class_names, len(class_names)
2.重写加载数据的类Dataset
重写torch.utils.data.dataset.Dataset 来返回图片和我们的bbox信息
class FRCNNDataset(Dataset):
    """Dataset that turns annotation lines into (image, boxes, labels) samples.

    Each entry of ``annotation_lines`` is handed unchanged to
    ``self.get_random_data``, which loads the image and its boxes.
    """

    def __init__(self, annotation_lines, input_shape = [600, 600], train = True):
        # input_shape is read as (height, width); only its first two items are used.
        self.input_shape = input_shape
        # train=True asks get_random_data for random augmentation.
        self.train = train
        self.annotation_lines = annotation_lines
        self.length = len(annotation_lines)

    def __len__(self):
        """Return the number of annotation lines."""
        return self.length

    def __getitem__(self, index):
        """Return ``(image, box, label)`` for the sample at ``index``.

        ``index`` wraps around the data set length. ``image`` is the
        preprocessed array with its last axis moved to the front, ``box`` holds
        the first four columns of every box row and ``label`` the last column.
        """
        annotation = self.annotation_lines[index % self.length]
        # Augment randomly while training, deterministically otherwise.
        image, y = self.get_random_data(annotation, self.input_shape[0:2], random = self.train)

        pixels = preprocess_input(np.array(image, dtype=np.float32))
        image = np.transpose(pixels, (2, 0, 1))

        n_boxes = len(y)
        box_data = np.zeros((n_boxes, 5))
        if n_boxes > 0:
            box_data[:n_boxes] = y

        return image, box_data[:, :4], box_data[:, -1]
2.1 get_random_data
def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
    """Load one annotated image and fit it, with its boxes, to ``input_shape``.

    ``annotation_line`` has the form ``"<image path> x1,y1,x2,y2,cls ..."``:
    whitespace-separated fields, the first being the image path and every
    following one a comma-separated integer box.

    With ``random=False`` the image is letterboxed: resized with its aspect
    ratio preserved, centred on a grey ``(w, h)`` canvas, and the boxes are
    moved accordingly. With ``random=True`` the image is additionally
    distorted: random aspect ratio and scale, random placement on the canvas,
    random horizontal flip, and random hue / saturation / value changes.

    Args:
        annotation_line: one line of the annotation list file.
        input_shape: ``(height, width)`` of the output canvas.
        jitter: aspect-ratio jitter; both ratio factors are drawn from
            ``[1 - jitter, 1 + jitter]``.
        hue: maximum hue shift as a fraction of the full hue circle.
        sat: maximum saturation gain; the applied factor lies in
            ``[1 / sat, sat]``.
        val: maximum value (brightness) gain; the applied factor lies in
            ``[1 / val, val]``.
        random: enable the random distortions described above.

    Returns:
        ``(image_data, box)``. ``image_data`` is a float32 RGB array of shape
        ``(h, w, 3)`` on a 0-255 scale. ``box`` is the integer box array after
        shuffling, rescaling, clipping to the canvas and dropping boxes that
        are not more than 1 pixel wide and high. Without any input box it is
        an empty array.
    """
    line = annotation_line.split()

    # Read the image and make sure it is RGB.
    image = cvtColor(Image.open(line[0]))

    # Source size (PIL order: width, height) and target size.
    iw, ih = image.size
    h, w = input_shape

    # Ground-truth boxes: each field "x1,y1,x2,y2,cls" becomes one integer row.
    box = np.array([np.array(list(map(int, fields.split(',')))) for fields in line[1:]])

    if not random:
        # Largest scale at which the whole image still fits on the canvas.
        scale = min(w/iw, h/ih)
        nw = int(iw*scale)
        nh = int(ih*scale)
        # Half of the left-over space on each axis centres the image.
        dx = (w-nw)//2
        dy = (h-nh)//2

        # Resize with bicubic interpolation and pad the rest with grey.
        image = image.resize((nw,nh), Image.BICUBIC)
        new_image = Image.new('RGB', (w,h), (128,128,128))
        new_image.paste(image, (dx, dy))
        image_data = np.array(new_image, np.float32)

        box = self._fit_boxes(box, nw, nh, iw, ih, dx, dy, w, h)
        return image_data, box

    # Random aspect ratio and scale.
    new_ar = w/h * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
    scale = self.rand(.25, 2)
    if new_ar < 1:
        nh = int(scale*h)
        nw = int(nh*new_ar)
    else:
        nw = int(scale*w)
        nh = int(nw/new_ar)
    image = image.resize((nw,nh), Image.BICUBIC)

    # Random placement on a grey canvas; the image may extend past the canvas.
    dx = int(self.rand(0, w-nw))
    dy = int(self.rand(0, h-nh))
    new_image = Image.new('RGB', (w,h), (128,128,128))
    new_image.paste(image, (dx, dy))
    image = new_image

    # Random horizontal flip.
    flip = self.rand()<.5
    if flip:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)

    # Random colour distortion in HSV space.
    hue_shift = self.rand(-hue, hue)
    sat_gain = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat)
    val_gain = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val)
    x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV)
    # For float32 input OpenCV reports hue in degrees [0, 360] and S, V in
    # [0, 1], so the shifted hue has to wrap around at 360. Wrapping at 1
    # would leave shifted hues out of range and let the clamps below pin them
    # to 0 or 360 instead of rotating them.
    x[..., 0] += hue_shift*360
    x[..., 0][x[..., 0]>360] -= 360
    x[..., 0][x[..., 0]<0] += 360
    x[..., 1] *= sat_gain
    x[..., 2] *= val_gain
    # Final safety clamps into the valid HSV ranges.
    x[x[:,:, 0]>360, 0] = 360
    x[:, :, 1:][x[:, :, 1:]>1] = 1
    x[x<0] = 0
    image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255

    box = self._fit_boxes(box, nw, nh, iw, ih, dx, dy, w, h, flip=flip)
    return image_data, box

def _fit_boxes(self, box, nw, nh, iw, ih, dx, dy, w, h, flip=False):
    """Move boxes from source-image pixels onto the ``(w, h)`` canvas.

    Rows are shuffled, scaled by ``nw/iw`` and ``nh/ih``, shifted by
    ``(dx, dy)``, optionally mirrored horizontally, clipped to the canvas, and
    rows that end up not more than 1 pixel wide and high are dropped. An empty
    ``box`` is returned unchanged.
    """
    if len(box) == 0:
        return box
    np.random.shuffle(box)
    # box is an integer array, so the scaled coordinates are truncated on assignment.
    box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
    box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
    if flip:
        # Mirroring swaps the roles of the left and right edges.
        box[:, [0,2]] = w - box[:, [2,0]]
    box[:, 0:2][box[:, 0:2]<0] = 0
    box[:, 2][box[:, 2]>w] = w
    box[:, 3][box[:, 3]>h] = h
    box_w = box[:, 2] - box[:, 0]
    box_h = box[:, 3] - box[:, 1]
    # Discard boxes that became degenerate after clipping.
    return box[np.logical_and(box_w>1, box_h>1)]
#返回从a到b的随机值
def rand(self, a=0, b=1):
    """Return a random float between ``a`` and ``b`` from NumPy's global RNG."""
    span = b - a
    return a + span * np.random.rand()
2.1.1 cvtColor
#---------------------------------------------------------#
# 将图像转换成RGB图像,防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#. https://www.cnblogs.com/haifwu/p/12825741.html
#---------------------------------------------------------#
def cvtColor(image):
    """Return ``image`` unchanged if it already has three channels, else its RGB conversion.

    An input whose array shape is ``(rows, cols, 3)`` is passed through as is;
    anything else is converted with ``image.convert('RGB')``.
    """
    dims = np.shape(image)
    has_three_channels = len(dims) == 3 and dims[2] == 3
    return image if has_three_channels else image.convert('RGB')
2.2.2 获得真实框(标注框)
# Walk-through: how one line of 2007_train.txt becomes a box array.
train_annotation_path = '../2007_train.txt'
# The list file is written as UTF-8 by voc_annotation.py, so read it back the
# same way instead of relying on the platform default encoding.
with open(train_annotation_path, encoding='utf-8') as f:
    train_lines = f.readlines()
num_train = len(train_lines)
print(train_lines)
print(num_train)

annotation_line = train_lines[0]
print('annotation_line:', annotation_line)
# Field 0 is the image path, every following field is one "x1,y1,x2,y2,cls" box.
# The image itself is not needed to inspect the box fields, so it is not opened here.
line = annotation_line.split()

# Parse the first box (field index 1) step by step.
print(line[1])
l1 = list(map(int, line[1].split(',')))
print(l1)
l1 = np.array(l1)
print(l1)

# Example output for VOC2007 image 000004:
#   annotation_line: faster-rcnn-pytorch-master/VOCdevkit/VOC2007/JPEGImages/000004.jpg 13,311,84,362,6 362,330,500,389,6 235,328,334,375,6 175,327,252,364,6 139,320,189,359,6 108,325,150,353,6 84,323,121,350,6
#   13,311,84,362,6
#   [13, 311, 84, 362, 6]
#   [ 13 311 84 362 6]
就是得到了2007_train.txt的bbox坐标
2.2.3 图片缩放
这儿的宽是h,高是w,写反了,懒得修改了。
3.使用数据集
# Wrap the annotation lines in dataset objects; only the training set asks for random augmentation.
train_dataset = FRCNNDataset(train_lines, input_shape, train=True)
val_dataset = FRCNNDataset(val_lines, input_shape, train=False)

# Batch both datasets with the same loader settings.
gen = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
    collate_fn=frcnn_dataset_collate,
)
gen_val = DataLoader(
    val_dataset,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
    collate_fn=frcnn_dataset_collate,
)
版权声明:本文为CSDN博主「微凉code」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_41921315/article/details/122523150
暂无评论