Pytorch模型概览(Pytorch Summary)

  • A+
所属分类:深度学习
摘要关于查看Pytorch生成的模型简介的一个库, pytorch-summary, 可以用来查看参数个数, 对网络占用的缓存大小进行估计.

简介

在使用Pytorch创建完模型之后, 我们希望可以有一个工具可以看到模型的参数个数, 模型占用的内容. 这个就类似于在Keras中的model.summary()的功能. 但是在Pytorch中, 本身并没有实现这个功能, 但是有一个库实现了该功能.  pytorch-summary. 这里简单介绍一下.

官方链接: pytorch-summary

 

Pytorch Summary简单使用介绍

安装直接使用pip进行安装即可.

  1. pip install torchsummary

对于任何一个模型来说, 可以使用下面的方式来看summary

  1. from torchsummary import summary
  2. summary(your_model, input_size=(channels, H, W))

其中input_size需要可以符合你模型的输入要求, 可以用来进行前向传播.

下面看一个具体的例子和最终的输出结果.

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torchsummary import summary
  5. class Net(nn.Module):
  6.     def __init__(self):
  7.         super(Net, self).__init__()
  8.         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
  9.         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
  10.         self.conv2_drop = nn.Dropout2d()
  11.         self.fc1 = nn.Linear(320, 50)
  12.         self.fc2 = nn.Linear(50, 10)
  13.     def forward(self, x):
  14.         x = F.relu(F.max_pool2d(self.conv1(x), 2))
  15.         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
  16.         x = x.view(-1, 320)
  17.         x = F.relu(self.fc1(x))
  18.         x = F.dropout(x, training=self.training)
  19.         x = self.fc2(x)
  20.         return F.log_softmax(x, dim=1)
  21. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
  22. model = Net().to(device)
  23. summary(model, (1, 28, 28))

最终的结果如下所示, 可以看到每一层的具体变化.

Pytorch模型概览(Pytorch Summary)
  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: