RuntimeError: CUDA error: device-side assert triggered when training an object detector with torch: cause and fix

While training torchvision's Mask R-CNN on my own data, I got the following error:

Traceback (most recent call last):
  File "main_train_detection.py", line 232, in <module>
    main(params)
  File "main_train_detection.py", line 201, in main
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
  File "/raid/huaqing/tyler/suzhou/code/utils/engine.py", line 37, in train_one_epoch
    loss_dict = model(images, targets)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/generalized_rcnn.py", line 97, in forward
    detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py", line 760, in forward
    loss_classifier, loss_box_reg = fastrcnn_loss(
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py", line 40, in fastrcnn_loss
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
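Because the stack trace may point at the wrong line, two cheap checks help. First, rerun with CUDA_LAUNCH_BLOCKING=1, as the message suggests. Second, validate the labels before they reach the model, ideally before the targets are moved to the GPU. The sketch below assumes the targets are the list of per-image dicts (each with a "labels" entry) that torchvision detection models take. find_out_of_range_labels is a hypothetical helper, not part of torch or torchvision.

```python
import os

# Put this at the very top of the training script, before torch initializes
# CUDA, so kernels launch synchronously and the traceback points at the op
# that actually failed (it slows training, so use it only for debugging).
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")


def find_out_of_range_labels(targets, num_classes):
    """Return (image_index, label) pairs whose label is outside [0, num_classes).

    `targets` is the list of per-image dicts passed to a torchvision detection
    model; each "labels" entry may be a Python list or a 1-D int64 tensor.
    (Hypothetical helper for illustration.)
    """
    bad = []
    for image_index, target in enumerate(targets):
        for label in target["labels"]:
            value = int(label)  # works for Python ints and 0-d tensors alike
            if not 0 <= value < num_classes:
                bad.append((image_index, value))
    return bad


# Usage inside the training loop, right before `loss_dict = model(images, targets)`:
# bad = find_out_of_range_labels(targets, num_classes=3)
# assert not bad, f"labels out of range: {bad}"
```

With the original mapping and num_classes = 3, every image containing a 'band' box would show up in the returned list with label 3.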

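The root cause turned out to be the class label numbering: the labels I assigned ran past the number of classes the model was built with.
There are 3 kinds of objects to detect, so I set the total number of classes (num_classes) to 3, and mapped class names to labels as follows:

cls_dict = {'holes': 1, 'marker': 2, 'band': 3}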
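The overflow can be reproduced without the detector. The sketch below assumes only that the box head scores each sampled proposal with num_classes logits and applies cross-entropy to them, as torchvision's fastrcnn_loss does. check_classification_loss is a hypothetical helper, and the logits are random stand-ins for the model's output. On the CPU the same bounds violation surfaces immediately as a readable IndexError instead of an asynchronous device-side assert.

```python
import torch
import torch.nn.functional as F


def check_classification_loss(labels, num_classes):
    """Run cross-entropy on random logits; return the error text, or None if it succeeds.

    (Hypothetical helper for illustration.)
    """
    labels = torch.as_tensor(labels, dtype=torch.int64)
    # Stand-in for the box head's class scores: one row of num_classes logits per proposal.
    class_logits = torch.randn(len(labels), num_classes)
    try:
        # On the CPU an out-of-range target raises IndexError; on the GPU the
        # same bounds check fires as "device-side assert triggered".
        F.cross_entropy(class_logits, labels)
    except IndexError as err:
        return str(err)
    return None


cls_dict = {"holes": 1, "marker": 2, "band": 3}
labels = [0] + list(cls_dict.values())  # 0 = a background proposal, then one box per class
print(check_classification_loss(labels, num_classes=3))  # error text about target 3
print(check_classification_loss(labels, num_classes=4))  # None
```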

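A model built with num_classes = 3 has exactly 3 classification outputs, so the only valid labels are 0, 1 and 2. There is no label 3, so every 'band' box overflows that range and the bounds check in the classification loss fails on the GPU. Because CUDA errors are reported asynchronously, the traceback points at torch.where(labels > 0), the next call that synchronizes with the GPU, rather than at the kernel that actually failed (most likely the cross-entropy computed just before it in fastrcnn_loss).

Renumbering the classes from 0 makes the assert go away:

cls_dict = {'holes': 0, 'marker': 1, 'band': 2}

However, torchvision's detection models reserve label 0 for background: num_classes is documented as including the background class, and fastrcnn_loss only treats labels > 0 as positives (the torch.where line in the traceback). With the mapping above, every 'holes' box would be learned as background. The fix that matches torchvision's convention is to keep cls_dict = {'holes': 1, 'marker': 2, 'band': 3} and set num_classes = 4 (3 object classes plus background).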
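torchvision's detection models reserve label 0 for background and count it in num_classes, so a mapping that keeps all three object classes numbers them from 1 and sets num_classes to 4. A minimal sketch, assuming the three class names from this post; CLASS_NAMES and encode_labels are hypothetical names for illustration.

```python
CLASS_NAMES = ["holes", "marker", "band"]

# Label 0 is reserved for background, so object classes are numbered from 1
# and num_classes counts the background class too.
cls_dict = {name: idx for idx, name in enumerate(CLASS_NAMES, start=1)}
num_classes = len(CLASS_NAMES) + 1


def encode_labels(names):
    """Map one image's annotation class names to integer labels (hypothetical helper)."""
    return [cls_dict[name] for name in names]


print(cls_dict)                          # {'holes': 1, 'marker': 2, 'band': 3}
print(num_classes)                       # 4
print(encode_labels(["band", "holes"]))  # [3, 1]
```

The resulting num_classes is what gets passed to the model, for example torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=num_classes) when the detection heads are not loaded from COCO weights; when fine-tuning COCO-pretrained weights, the usual route is to keep the pretrained model and swap its box and mask predictors for ones sized to num_classes. The encoded labels go into each target dict as a torch.int64 tensor.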

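Copyright notice: This is an original article by CSDN blogger "Huatsing Liu", licensed under the CC 4.0 BY-SA license. When reposting, please include a link to the original article and this notice.
Original link: https://blog.csdn.net/u014466076/article/details/121989323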
