Pytorch模型概览(Pytorch Summary)

王 茂南 2020年4月8日07:33:33
评论
1 1814字阅读6分2秒
摘要关于查看Pytorch生成的模型简介的一个库, pytorch-summary, 可以用来查看参数个数, 对网络占用的缓存大小进行估计.

简介

在使用 Pytorch 创建完模型之后, 我们希望可以有一个工具可以看到模型的参数个数, 模型占用的内容. 这个就类似于在 Keras 中的 model.summary() 的功能. 但是在 Pytorch 中, 本身并没有实现这个功能, 但是有一个库实现了该功能.  pytorch-summary. 这里简单介绍一下。

【更新】目前官方已经推荐使用另外一个库,Github-torchinfo,来完成相同的功能。如果想要看到模型的结构的可视化,可以直接写入 tensorboard,参考这个链接,在 Pytorch 中使用 Tensorboard 进行可视化

 

Torchinfo 简单使用介绍

安装直接使用 pip 进行安装即可:

  1. pip install torchinfo

同样是使用 summary 来查看模型。这里可以选择 input_size 或是 input_data 均可;下面的例子是指定 input_size 获得结果:

  1. from torchinfo import summary
  2. model = ConvNet()
  3. batch_size = 16
  4. summary(model, input_size=(batch_size, 1, 28, 28))

我们也可以指定 input_data,这里需要放在一个 list 里面:

  1. input_data = torch.randn(1, 300)
  2. other_input_data = torch.randn(1, 300).long()
  3. model = MultipleInputNetDifferentDtypes()
  4. summary(model, input_data=[input_data, other_input_data, ...])

最终可视化的结果类似下面:

Pytorch模型概览(Pytorch Summary)

 

Pytorch Summary 简单使用介绍

安装直接使用 pip 进行安装即可:

  1. pip install torchsummary

对于任何一个模型来说,可以使用下面的方式来看 summary

  1. from torchsummary import summary
  2. summary(your_model, input_size=(channels, H, W))

其中 input_size 需要可以符合你模型的输入要求, 可以用来进行前向传播.

下面看一个具体的例子和最终的输出结果.

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torchsummary import summary
  5. class Net(nn.Module):
  6.     def __init__(self):
  7.         super(Net, self).__init__()
  8.         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
  9.         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
  10.         self.conv2_drop = nn.Dropout2d()
  11.         self.fc1 = nn.Linear(320, 50)
  12.         self.fc2 = nn.Linear(50, 10)
  13.     def forward(self, x):
  14.         x = F.relu(F.max_pool2d(self.conv1(x), 2))
  15.         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
  16.         x = x.view(-1, 320)
  17.         x = F.relu(self.fc1(x))
  18.         x = F.dropout(x, training=self.training)
  19.         x = self.fc2(x)
  20.         return F.log_softmax(x, dim=1)
  21. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
  22. model = Net().to(device)
  23. summary(model, (1, 28, 28))

最终的结果如下所示, 可以看到每一层的具体变化.

Pytorch模型概览(Pytorch Summary)

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南
  • 本文由 发表于 2020年4月8日07:33:33
  • 转载请务必保留本文链接:https://mathpretty.com/12052.html
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: