1. Paper download link:
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. [paper]
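2. Code download links:
SwinT can be used for classification, detection, segmentation, and other tasks.
Original repository (this code is for image classification):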
https://github.com/microsoft/Swin-Transformer
Object detection code:
https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
3. Create a new Python virtual environment and activate it
conda create -n SwinTrans python=3.7
conda activate SwinTrans
4. Install pytorch and torchvision
pip3 install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
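Note: the authors trained their models with a newer version of torch, so the installed torch version must be 1.6.0 or later. When I first installed torch 1.4.0, loading the authors' pretrained model failed with the following error: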
RuntimeError: version_ <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /pytorch/caffe2/serialize/inline_container.cc:132, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 2. Your PyTorch installation may be too old. (init at /pytorch/caffe2/serialize/inline_container.cc:132)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7f7cdd46a193 in /home1/users/huangbo/anaconda3/envs/SwinTrans/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1f5b (0x7f7c548399eb in /home1/users/huangbo/anaconda3/envs/SwinTrans/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::string const&) + 0x64 (0x7f7c5483ac04 in /home1/users/huangbo/anaconda3/envs/SwinTrans/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x6c6536 (0x7f7c9c76b536 in /home1/users/huangbo/anaconda3/envs/SwinTrans/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #4: <unknown function> + 0x295a74 (0x7f7c9c33aa74 in /home1/users/huangbo/anaconda3/envs/SwinTrans/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #46: __libc_start_main + 0xf0 (0x7f7cee1fb840 in /lib/x86_64-linux-gnu/libc.so.6)
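Before loading a checkpoint, it can help to confirm that the installed torch meets the 1.6.0 minimum mentioned above. The following is a minimal sketch, not part of the original setup; the helper names `parse_version` and `torch_is_new_enough` are my own, made up for illustration.

```python
import re

MIN_TORCH = (1, 6, 0)  # minimum version stated in the note above


def parse_version(version):
    """Turn a string such as '1.6.0+cu101' into the tuple (1, 6, 0)."""
    release = version.split("+")[0]            # drop the local build tag, e.g. '+cu101'
    numbers = re.findall(r"\d+", release)[:3]  # keep major, minor, patch
    parts = [int(n) for n in numbers]
    parts += [0] * (3 - len(parts))            # pad short versions such as '1.6'
    return tuple(parts)


def torch_is_new_enough(version, minimum=MIN_TORCH):
    """Return True when the version string is at least the minimum."""
    return parse_version(version) >= minimum


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed in this environment")
    else:
        ok = torch_is_new_enough(torch.__version__)
        print(torch.__version__, "OK" if ok else "too old for the pretrained checkpoints")
```

Comparing tuples of integers rather than raw strings avoids the pitfall where "1.10.0" sorts before "1.6.0" as text.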
5. Install mmcv-full
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
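My CUDA version is 10.1 and my torch version is 1.6.0; change the cu101 and torch1.6.0 parts of the URL to match your own setup. Installing and compiling takes a long time, so be patient.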
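To adapt the command to another setup, the URL can be assembled from the two version numbers. This is only a sketch that follows the pattern of the URL above; the function names are hypothetical, and I am assuming other combinations follow the same pattern, which does not guarantee that a prebuilt index exists for every CUDA/torch pair.

```python
def mmcv_index_url(cuda_version, torch_version):
    """Build the mmcv-full index URL following the pattern of the command above."""
    cuda_tag = "cu" + cuda_version.replace(".", "")  # '10.1' -> 'cu101'
    return ("https://download.openmmlab.com/mmcv/dist/"
            f"{cuda_tag}/torch{torch_version}/index.html")


def mmcv_install_command(cuda_version, torch_version):
    """Return the pip command as a string; nothing is executed here."""
    return "pip install mmcv-full -f " + mmcv_index_url(cuda_version, torch_version)


if __name__ == "__main__":
    # Reproduces the command used in this post (CUDA 10.1, torch 1.6.0).
    print(mmcv_install_command("10.1", "1.6.0"))
```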
6. Install MMDetection
Enter the project directory (the object detection repository from step 2) and run:
pip install -r requirements/build.txt
python setup.py develop
7. Download the pretrained model
Password: swin
Create a checkpoints directory and put the model in it
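A small check like the one below can create the checkpoints directory and tell you whether the model file is actually in place before you run the demo. It is a sketch of my own rather than part of the repository; `checkpoint_ready` is a hypothetical helper, and the file name is the one used in the demo script in step 8.

```python
from pathlib import Path


def checkpoint_ready(path):
    """Create the parent directory if needed and report whether the file is present."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # e.g. creates 'checkpoints/'
    # An empty file usually means an interrupted download, so treat it as missing.
    return path.is_file() and path.stat().st_size > 0


if __name__ == "__main__":
    ckpt = "checkpoints/cascade_mask_rcnn_swin_small_patch4_window7.pth"
    if checkpoint_ready(ckpt):
        print("found", ckpt)
    else:
        print("missing", ckpt, "- download it and place it in checkpoints/")
```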
8. Create demo.py with the following code:
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
import cv2
config_file = 'configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py'
# download the checkpoint from model zoo and put it in `checkpoints/`
checkpoint_file = 'checkpoints/cascade_mask_rcnn_swin_small_patch4_window7.pth'
device = 'cuda:0'
# init a detector
model = init_detector(config_file, checkpoint_file, device=device)
# run inference on the demo image
image = 'demo/demo.jpg'
result = inference_detector(model, image)
show_result_pyplot(model, image, result, score_thr=0.3)
# image = model.show_result(image, result, score_thr=0.3)
#
# cv2.imshow('demo', image)
# cv2.waitKey()
9. Run python demo.py to get the prediction results
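If you want numbers rather than just a picture, the result can be summarized by counting the boxes that pass the same score_thr as in demo.py. The sketch below rests on an assumption about the result layout that I have not verified against this exact repository: for a mask model, `result` is a tuple `(bbox_result, segm_result)`, and `bbox_result` holds one entry per class with rows of the form `[x1, y1, x2, y2, score]`. `count_detections` is a hypothetical helper.

```python
def count_detections(result, score_thr=0.3):
    """Count boxes per class index whose score is at least score_thr.

    Assumes the layout described above: rows of [x1, y1, x2, y2, score],
    grouped into one sequence per class.
    """
    # Mask models are assumed to return (bbox_result, segm_result).
    bbox_result = result[0] if isinstance(result, tuple) else result
    counts = {}
    for class_id, boxes in enumerate(bbox_result):
        kept = sum(1 for box in boxes if box[4] >= score_thr)
        if kept:
            counts[class_id] = kept
    return counts


if __name__ == "__main__":
    # Hand-made stand-in for a detector result, using plain lists.
    fake_result = (
        [
            [[10, 20, 110, 220, 0.92], [15, 25, 90, 200, 0.12]],  # class 0
            [],                                                    # class 1
            [[0, 0, 50, 50, 0.45]],                                # class 2
        ],
        [[], [], []],  # placeholder for the segmentation part
    )
    print(count_detections(fake_result))
```

In demo.py this would be called as `count_detections(result, score_thr=0.3)` right after `inference_detector`, provided the layout assumption holds.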
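Copyright notice: This is an original article by CSDN blogger 「博博有个大大大的Dream」, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.
Original link: https://blog.csdn.net/qq_17783559/article/details/119381672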