【DL】torch小技巧之网络参数统计 torchstat & torchsummary

1 torchstat 优缺点

优点:直接就可以使用,打印所以结构,信息详细,名字详细。
缺点:按照类的顺序进行打印,循环中未使用的结果也打印了。

1.1 安装工具包torchstat

pip install torchstat

1.2 测试参数

from torchstat import stat

# 导入模型,输入一张输入图片的尺寸
stat(model, (3, 224, 224))

1.3 结果

CCMBlk(
  (relu): ReLU(inplace=True)
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
      module name  input shape output shape   params memory(MB)          MAdd         Flops  MemRead(B)  MemWrite(B) duration[%]  MemR+W(B)
0            relu   64  32  32   64  32  32      0.0       0.25      65,536.0      65,536.0    262144.0     262144.0       0.00%   524288.0
1           conv1    3  32  32   64  32  32   1792.0       0.25   3,538,944.0   1,835,008.0     19456.0     262144.0      40.03%   281600.0
2           conv2   64  32  32   64  32  32  36928.0       0.25  75,497,472.0  37,814,272.0    409856.0     262144.0      19.99%   672000.0
3         maxpool   64  32  32   64  16  16      0.0       0.06     131,072.0      65,536.0    262144.0      65536.0      39.98%   327680.0
total                                        38720.0       0.81  79,233,024.0  39,780,352.0    262144.0      65536.0     100.00%  1805568.0
===========================================================================================================================================
Total params: 38,720
-------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 0.81MB
Total MAdd: 79.23MMAdd
Total Flops: 39.78MFlops
Total MemR+W: 1.72MB

2 torchsummary优缺点(*准确)

优点:需要把模型放在Cuda,只打印使用的结构。
缺点:名字使用的内部名字

2.1 安装工具包torchsummary

pip install torchsummary 

2.2 测试参数

from torchsummary import summary

# 导入模型,输入一张输入图片的尺寸
summary(model.cuda(), input_size=(3, 32, 32), batch_size=-1)

2.3 结果

CCMBlk(
  (relu): ReLU(inplace=True)
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]           1,792
              ReLU-2           [-1, 64, 32, 32]               0
            Conv2d-3           [-1, 64, 32, 32]          36,928
              ReLU-4           [-1, 64, 32, 32]               0
         MaxPool2d-5           [-1, 64, 16, 16]               0
================================================================
Total params: 38,720
Trainable params: 38,720
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 2.12
Params size (MB): 0.15
Estimated Total Size (MB): 2.28
----------------------------------------------------------------

版权声明:本文为CSDN博主「张林克」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_45292794/article/details/108227437

张林克

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

单目3D目标检测调研

单目3D目标检测调研 一、 简介 现有的单目3D目标检测方案主要方案主要分为两类,分别为基于图片的方法和基于伪雷达点云的方法。   基于图片的方法一般通过2D-3D之间的几何约束来学习,包括目标形状信息&#xff0

mmdetection特征图可视化

mmdetection对特征图进行可视化 思路:在前向传播时将四个stage的特征图返回出来(更简单的方法在我下一篇博客,欢迎阅读) 1.two_stage.py修改 我修改的地方都