文章目录(Table of Contents)
简介
在使用 Pytorch
创建完模型之后, 我们希望可以有一个工具可以看到模型的参数个数, 模型占用的内容. 这个就类似于在 Keras
中的 model.summary()
的功能. 但是在 Pytorch
中, 本身并没有实现这个功能, 但是有一个库实现了该功能. pytorch-summary. 这里简单介绍一下。
【更新】目前官方已经推荐使用另外一个库,Github-torchinfo,来完成相同的功能。如果想要看到模型的结构的可视化,可以直接写入 tensorboard,参考这个链接,在 Pytorch 中使用 Tensorboard 进行可视化。
Torchinfo 简单使用介绍
安装直接使用 pip
进行安装即可:
- pip install torchinfo
同样是使用 summary
来查看模型。这里可以选择 input_size
或是 input_data
均可;下面的例子是指定 input_size
获得结果:
- from torchinfo import summary
- model = ConvNet()
- batch_size = 16
- summary(model, input_size=(batch_size, 1, 28, 28))
我们也可以指定 input_data
,这里需要放在一个 list
里面:
- input_data = torch.randn(1, 300)
- other_input_data = torch.randn(1, 300).long()
- model = MultipleInputNetDifferentDtypes()
- summary(model, input_data=[input_data, other_input_data, ...])
最终可视化的结果类似下面:
data:image/s3,"s3://crabby-images/814f0/814f01771286a9efd4d1b3bdef63e2adaf35a593" alt="Pytorch模型概览(Pytorch Summary)"
Pytorch Summary 简单使用介绍
安装直接使用 pip 进行安装即可:
- pip install torchsummary
对于任何一个模型来说,可以使用下面的方式来看 summary
:
- from torchsummary import summary
- summary(your_model, input_size=(channels, H, W))
其中 input_size 需要可以符合你模型的输入要求, 可以用来进行前向传播.
下面看一个具体的例子和最终的输出结果.
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torchsummary import summary
- class Net(nn.Module):
- def __init__(self):
- super(Net, self).__init__()
- self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
- self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
- self.conv2_drop = nn.Dropout2d()
- self.fc1 = nn.Linear(320, 50)
- self.fc2 = nn.Linear(50, 10)
- def forward(self, x):
- x = F.relu(F.max_pool2d(self.conv1(x), 2))
- x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
- x = x.view(-1, 320)
- x = F.relu(self.fc1(x))
- x = F.dropout(x, training=self.training)
- x = self.fc2(x)
- return F.log_softmax(x, dim=1)
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
- model = Net().to(device)
- summary(model, (1, 28, 28))
最终的结果如下所示, 可以看到每一层的具体变化.
data:image/s3,"s3://crabby-images/97abd/97abd89cf27385db155082c83e7aff8d5ac7c811" alt="Pytorch模型概览(Pytorch Summary)"
- 微信公众号
- 关注微信公众号
-
- QQ群
- 我们的QQ群号
-
评论