目标检测部署(卡牌识别)

        最近在折腾yolov5,训练了一个识别纸牌的模型,最后使用onnxruntime进行部署,感兴趣的可以到GitHub上clone下来玩玩,模型的权重文件上传到了百度网盘,链接和提取码写在readme里。

        模型的训练使用了yolov5l的权重模型,训练的时候使用的batchsize为8(理论可以设置的更大,gpu的占用还没吃满),训练了200个epoch,取了效果最好的权重模型。

        从git上下载下来后的文件结构大致是这样的:

         function文件夹中存放了一些工具函数,image存放待检测的图片,model则存放模型的权重文件。

config.py

# Class labels for all 52 cards (jokers excluded). The index order MUST match
# the class order the model was trained with, so the rank sequence below is
# deliberate and must not be re-sorted. Ranks are numeric only
# (A=1, J=11, Q=12, K=13); suits: H=hearts, C=clubs, D=diamonds, S=spades.
_RANKS = ('10', '2', '3', '4', '5', '6', '7', '8', '9', '1', '11', '13', '12')
_SUITS = ('C', 'D', 'H', 'S')
LABEL_DICT = {'cardlabel': [rank + suit for rank in _RANKS for suit in _SUITS]}

        这里主要是包括52张牌(不包括大小王)的label,J,Q,K,A我都用对应的数字进行替代,数字后面的H,C,D,S对应的是花色,分别是红桃,草花,方片和黑桃。

utils.py(只贴部分代码)

class LoadImages:
    """Dataloader that iterates over image and video files.

    Yields ``(path, img, img0, cap)`` per frame, where ``img`` is the
    letterboxed CHW RGB array ready for inference, ``img0`` the original
    BGR frame, and ``cap`` the cv2.VideoCapture (None for still images).
    Images are yielded first, then videos, frame by frame.
    """

    def __init__(self, path, img_size=640, stride=32, auto=True):
        # *path* may be a glob pattern, a directory, or a single file.
        p = str(Path(path).resolve())
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob pattern
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # directory contents
        elif os.path.isfile(p):
            files = [p]  # single file
        else:
            raise Exception(f'ERROR: {p} does not exist')

        # Partition by extension; IMG_FORMATS / VID_FORMATS are
        # module-level lists of accepted lowercase extensions.
        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size  # target inference size for letterbox
        self.stride = stride  # model stride the letterbox pads to
        self.files = images + videos
        self.nf = ni + nv  # total number of files
        self.video_flag = [False] * ni + [True] * nv  # parallel to self.files
        self.mode = 'image'
        self.auto = auto  # letterbox 'auto' (minimum-rectangle) flag
        if any(videos):
            # Open the first video up front so __next__ can read immediately.
            self.new_video(videos[0])
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

    def __iter__(self):
        # Reset the file cursor so the dataset can be iterated repeatedly.
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                # Current video exhausted: advance to the next file or stop.
                self.count += 1
                self.cap.release()
                if self.count == self.nf:
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize to self.img_size, keeping aspect ratio.
        img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]

        # Convert HWC to CHW, then reverse the channel axis: BGR -> RGB.
        img = img.transpose((2, 0, 1))[::-1]
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        # Open *path* and reset the per-video frame counter.
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        # Number of files (not frames) in the dataset.
        return self.nf

        这个类主要负责读取图像数据。

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Run Non-Maximum Suppression on raw YOLO network output.

    :param prediction: tensor of shape (batch, n_boxes, 5 + nc) holding
        xywh box, objectness score and nc per-class scores.
    :param conf_thres: confidence threshold in [0, 1].
    :param iou_thres: IoU threshold for NMS in [0, 1].
    :param classes: optional list of class ids to keep.
    :param agnostic: if True, suppress across classes (class-agnostic NMS).
    :param multi_label: allow multiple labels per box (only when nc > 1).
    :param labels: optional a-priori labels per image (autolabelling).
    :param max_det: maximum detections kept per image.
    :return: list of length batch; entry i is an (n, 6) tensor of
        [x1, y1, x2, y2, conf, cls] detections for image i.
    """
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidate mask from objectness

    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    min_wh, max_wh = 2, 4096  # box width/height bounds (max_wh doubles as class offset)
    max_nms = 30000  # cap on boxes fed into torchvision NMS
    time_limit = 10.0  # seconds before bailing out of the batch loop
    redundant = True  # require redundant detections when merging
    multi_label &= nc > 1  # multi-label only makes sense with >1 class
    merge = False  # weighted box merging (disabled)

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):
        # Apply constraints: keep only objectness candidates for this image
        x = x[xc[xi]]

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # boxes
            v[:, 4] = 1.0  # objectness
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # one-hot class score
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf = objectness * class score
        x[:, 5:] *= x[:, 4:5]
        box = xywh2xyxy(x[:, :4])  # center-xywh -> corner-xyxy

        if multi_label:
            # One row per (box, class) pair above threshold
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:
            # Best class only per box
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]
        if not n:
            continue
        elif n > max_nms:
            # Keep only the top-confidence boxes before NMS
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Batched NMS: offset boxes by class so one NMS call stays per-class
        c = x[:, 5:6] * (0 if agnostic else max_wh)
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)
        if i.shape[0] > max_det:
            i = i[:max_det]
        if merge and (1 < n < 3E3):
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres
            weights = iou * scores[None]  # IoU-masked confidence weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)
            if redundant:
                i = i[iou.sum(1) > 1]  # require overlap support

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break
    return output

        这里主要是非极大值抑制,输出最终的预测结果,包括坐标值、置信度以及标签索引值。

# -*- coding: utf-8 -*-
"""
Time:     2021.10.26
Author:   Athrunsunny
Version:  V 0.1
File:     inference.py
Describe: Functions in this file is use to inference
"""

import cv2
import torch
import time
import onnxruntime
import numpy as np
from function.utils import LoadImages, Annotator, colors, check_img_size, non_max_suppression, scale_coords
from function import config as CFG


def load_model(weights, **options):
    """Create an ONNX Runtime inference session for *weights*.

    Keyword options: ``imgsz`` (default 640) inference size, ``stride``
    (default 64) model stride used to validate the size.
    Returns ``(session, validated_imgsz, stride)``.
    """
    img_size = options.pop('imgsz', 640)
    stride = options.pop('stride', 64)

    # A list of weight files is allowed; only the first entry is used.
    weight_file = weights[0] if isinstance(weights, list) else weights
    session = onnxruntime.InferenceSession(str(weight_file), None)
    img_size = check_img_size(img_size, s=stride)
    return session, img_size, stride


def image_process(img):
    """Normalize an image array for ONNX inference.

    Converts to float32, scales pixel values from [0, 255] to [0.0, 1.0],
    and prepends a batch dimension when given a single (unbatched) 3-D
    CHW image.

    :param img: np.ndarray, typically uint8 CHW (or already batched NCHW).
    :return: float32 array; shape (1, C, H, W) for 3-D input, otherwise
        the normalized array with its shape unchanged.
    :raises TypeError: if *img* is not a numpy array. (The original used
        ``assert``, which is silently stripped under ``python -O``.)
    """
    if not isinstance(img, np.ndarray):
        raise TypeError(f'expected np.ndarray, got {type(img).__name__}')
    img = img.astype('float32')
    img /= 255.0  # 0-255 -> 0.0-1.0
    if img.ndim == 3:
        img = img[None]  # CHW -> 1xCHW (add batch dimension)
    return img


def inference(session, img, **options):
    """Run the ONNX session on a preprocessed image and apply NMS.

    Keyword options: ``conf_thres``, ``iou_thres``, ``classes``,
    ``agnostic``, ``max_det`` — forwarded to non_max_suppression.
    Returns the per-image detection list produced by NMS.
    """
    conf_thres = options.pop('conf_thres', 0.25)
    iou_thres = options.pop('iou_thres', 0.45)
    classes = options.pop('classes', None)
    agnostic = options.pop('agnostic', False)
    max_det = options.pop('max_det', 1000)

    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    raw_output = session.run([output_name], {input_name: img})
    pred = torch.tensor(raw_output)
    return non_max_suppression(pred, conf_thres=conf_thres, iou_thres=iou_thres,
                               classes=classes, max_det=max_det, agnostic=agnostic)


def post_process(pred, img, im0s, dataset, **options):
    """Rescale detections to the original image, draw boxes and collect labels.

    Keyword options: ``showImg`` (display window), ``hide_conf``,
    ``hide_labels``, ``line_thickness``, ``labelDict`` (must contain the
    'cardlabel' class-name list).
    Returns the list of detected label names (without confidences).
    """
    showImg = options.pop('showImg', False)
    hide_conf = options.pop('hide_conf', False)
    hide_labels = options.pop('hide_labels', False)
    line_thickness = options.pop('line_thickness', 1)
    labelDict = options.pop('labelDict', None)

    labels = labelDict['cardlabel']
    res_label = []
    for det_idx, det in enumerate(pred):
        s = ''
        im0 = im0s.copy()
        frame = getattr(dataset, 'frame', 0)
        annotator = Annotator(im0, line_width=line_thickness, example=str(labels))
        if len(det):
            # Map boxes from the letterboxed input back to original coordinates.
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)
                if hide_labels:
                    label = None
                    label_no_conf = None
                elif hide_conf:
                    label = labels[c]
                    label_no_conf = labels[c]
                else:
                    label = f'{labels[c]} {conf:.2f}'
                    label_no_conf = f'{labels[c]}'
                res_label.append(label_no_conf)
                annotator.box_label(xyxy, label, color=colors(c, True))
        print(f'{s}')
        im0 = annotator.result()
        if showImg:
            cv2.imshow('result', im0)
            cv2.waitKey(0)
    return res_label


def run(weights, source, **options):
    """Load the ONNX model, detect cards in *source* and return the labels.

    Note: the result variable is overwritten for every image, so for a
    multi-image source only the labels of the last image are returned
    (each image is still displayed/printed as it is processed).
    """
    conf_thres = options.pop('conf_thres', 0.25)  # confidence threshold
    iou_thres = options.pop('iou_thres', 0.45)  # NMS IoU threshold
    classes = options.pop('classes', None)  # keep only these class ids
    agnostic = options.pop('agnostic', False)  # class-agnostic NMS
    max_det = options.pop('max_det', 1000)  # maximum detections per image
    hide_conf = options.pop('hide_conf', False)  # hide confidences on labels
    hide_labels = options.pop('hide_labels', False)  # hide labels entirely
    line_thickness = options.pop('line_thickness', 1)  # box thickness (pixels)
    imgsz = options.pop('imgsz', 640)  # inference size (pixels)
    showImg = options.pop('showImg', False)  # show results in a window
    labelDict = options.pop('labelDict', CFG.LABEL_DICT)  # class-name mapping

    session, imgsz, stride = load_model(weights=weights, imgsz=imgsz)
    dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=False)
    res = []
    for path, img, im0s, vid_cap in dataset:
        img = image_process(img)
        start = time.time()
        pred = inference(session, img, conf_thres=conf_thres, iou_thres=iou_thres,
                         max_det=max_det, classes=classes, agnostic=agnostic)
        elapsed = time.time() - start
        print('Inference time:%.3fs' % elapsed)
        res = post_process(pred, img, im0s, dataset, hide_conf=hide_conf,
                           hide_labels=hide_labels, line_thickness=line_thickness,
                           showImg=showImg, labelDict=labelDict)
    return res


if __name__ == '__main__':
    # Demo entry point: detect cards in a sample image, show the annotated
    # result and print the recognized labels.
    model_weights = 'model/weight.onnx'
    sample_image = 'image/1.jpg'
    detected = run(model_weights, sample_image, showImg=True)
    print(detected)

        该项目到这里也就结束了,代码量也比较少,比较容易理解,以下附一张实际检测的效果图

        由于用的是yolov5l的模型,最后检测的时候,也是比较耗时,cpu上平均检测耗时为500ms左右。

版权声明:本文为CSDN博主「athrunsunny」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/athrunsunny/article/details/120983318

athrunsunny

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

YOLO-V3-SPP详细解析

YOLO-V3-SPP 继前两篇简单的YOLO博文 YOLO-V1 论文理解《You Only Look Once: Unified, Real-Time Object Detection》YOLO-V2论文理解《YOLO9000: Bet