文章目录[隐藏]
本文展示了如何使用MATLAB训练Faster R-CNN目标检测器,实现对车辆的检测。本例使用一个包含295张图像的小标记数据集。每个图像包含一个或两个已标记的车辆目标。一个小的数据集对于探索 Faster R-CNN 训练过程是有用的,但在实践中,需要更多的标注图像来训练一个鲁棒的检测器。
一、数据集准备
1、数据集下载
本数据集中的许多图像来自加州理工学院1999年和2001年的汽车数据集,可以在加州理工学院计算视觉网站上获得,由Pietro Perona创建并经许可使用。
% 下载一个预先训练过的检测器,以避免等待训练完成。如果您想训练检测器,请将doTraining变量设置为true。
doTraining = false;
if ~doTraining && ~exist('fasterRCNNResNet50EndToEndVehicleExample.mat','file')
disp('Downloading pretrained detector (118 MB)...');
pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/fasterRCNNResNet50EndToEndVehicleExample.mat';
websave('fasterRCNNResNet50EndToEndVehicleExample.mat',pretrainedURL);
end
2、数据集加载
下面进行解压缩车辆图像,加载数据集:
unzip vehicleDatasetImages.zip
data = load('vehicleDatasetGroundTruth.mat');
vehicleDataset = data.vehicleDataset;
车辆数据存储在一个两列表中,其中第一列包含图像文件路径,第二列包含车辆包围框。
3、数据集划分
将数据集分解为训练集、验证集和测试集。选择60%的数据用于训练,10%用于验证,其余的用于测试训练过的检测器。
rng(0)
shuffledIndices = randperm(height(vehicleDataset));
idx = floor(0.6 * height(vehicleDataset));
trainingIdx = 1:idx;
trainingDataTbl = vehicleDataset(shuffledIndices(trainingIdx),:);
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = vehicleDataset(shuffledIndices(validationIdx),:);
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = vehicleDataset(shuffledIndices(testIdx),:);
4、创建数据存储
使用imageDatastore和boxLabelDatastore创建用于在培训和评估期间加载图像和标签数据的数据存储。
imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,'vehicle'));
imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
bldsValidation = boxLabelDatastore(validationDataTbl(:,'vehicle'));
imdsTest = imageDatastore(testDataTbl{:,'imageFilename'});
bldsTest = boxLabelDatastore(testDataTbl(:,'vehicle'));
组合图像和标签数据存储:
trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
下面来展示数据集中的一张图片和对应的标签框:
data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
原图和标签如下:
二、创建Faster R-CNN网络
Faster R-CNN目标检测网络由特征提取网络和两个子网络组成。特征提取网络通常是预先训练的CNN,如ResNet-50或Inception v3。在特征提取网络之后的第一个子网络是区域建议网络(RPN),训练它生成候选区域——图像中可能存在目标的区域。第二个子网络被训练来预测每个对象提议的实际类别。
特征提取网络通常是预先训练的CNN,本例使用ResNet-50进行特征提取。您还可以使用其他预先训练过的网络,如MobileNet v2或ResNet-18,这取决于您的应用程序需求。
1、初始化参数
使用fasterRCNNLayers创建Faster R-CNN网络,需要你指定几个参数:
1)确定网络输入图片的大小
inputSize = [224 224 3];
2)指定锚框大小和个数
numAnchors = 3;
anchorBoxes = [29 17; 46 39; 136 116];
3)确定特征提取网络
featureExtractionNetwork = resnet50;
在使用 ResNet-50 之前,需要打开附加功能资源管理器,并点击安装Deep Learning Toolbox Model for ResNet-50
Network。
2、创建网路
选择’activation_40_relu’作为特征提取层。该特征提取层输出特征映射,向下采样的因子为16。这种向下采样量是空间分辨率和提取特征强度之间的一个很好的权衡,因为进一步向下提取的特征以空间分辨率为代价编码更强的图像特征。选择最优的特征提取层需要实证分析。您可以使用analyzeNetwork来查找网络中其他潜在特征提取层的名称。
% 特征提取层
featureLayer = 'activation_40_relu';
% 定义类别
numClasses = width(vehicleDataset)-1;
% 创建网络
lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);
三、训练网络
1、设置训练选项
使用trainingOptions指定网络训练选项,设置’ValidationData’为预处理的验证数据。将CheckpointPath设置为临时位置。这使得在训练过程中能够保存部分训练过的检测器。如果训练被中断,例如断电或系统故障,您可以从保存的检查点恢复训练。
options = trainingOptions('sgdm',...
'MaxEpochs',10,...
'MiniBatchSize',2,...
'InitialLearnRate',1e-3,...
'CheckpointPath',tempdir,...
'ValidationData',validationData);
2、加载模型训练
如果doTraining为true,则使用trainFasterRCNNObjectDetector训练Fast R-CNN对象检测器。否则,加载预训练的网络。
if doTraining
[detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
'NegativeOverlapRange',[0 0.3], ...
'PositiveOverlapRange',[0.6 1]);
else
% Load pretrained detector for the example.
pretrained = load('fasterRCNNResNet50EndToEndVehicleExample.mat');
detector = pretrained.detector;
end
四、车辆目标检测
为了检查训练效果,我们读取一张测试图像,并运行训练好的检测器。
% 读取图片
I = imread(testDataTbl.imageFilename{3});
% 将图像调整为与训练图像相同的大小
I = imresize(I,inputSize(1:2));
% 运行检测器
[bboxes,scores] = detect(detector,I);
% 显示效果
I = insertObjectAnnotation(I,'rectangle',bboxes,scores);
figure
imshow(I)
显示检测结果:
版权声明:本文为CSDN博主「AI 菌」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/wjinjie/article/details/120771409
暂无评论