Python如何加载模型并查看网络

 更新时间:2022年07月15日 10:30:42   作者:ShuqiaoS  
这篇文章主要介绍了Python如何加载模型并查看网络,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

在这里插入图片描述

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 使用Python的networkx绘制精美网络图教程

    使用Python的networkx绘制精美网络图教程

    今天小编就为大家分享一篇使用Python的networkx绘制精美网络图教程,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • pytorch 彩色图像转灰度图像实例

    pytorch 彩色图像转灰度图像实例

    今天小编就为大家分享一篇pytorch 彩色图像转灰度图像实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • python垃圾回收机制(GC)原理解析

    python垃圾回收机制(GC)原理解析

    这篇文章主要介绍了python垃圾回收机制(GC)原理解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-12-12
  • Python中取整的几种方法小结

    Python中取整的几种方法小结

    这篇文章主要介绍了Python中取整的几种方法,其中包括向下取整、四舍五入取整、向上取整以及分别取整数部分和小数部分。分别都给出了示例代码,相信对大家的理解和学习具有一定的参考借鉴价值,需要的朋友可以参考借鉴。
    2017-01-01
  • Django + Taro 前后端分离项目实现企业微信登录功能

    Django + Taro 前后端分离项目实现企业微信登录功能

    这篇文章主要介绍了Django + Taro 前后端分离项目实现企业微信登录功能,本文记录一下企业微信登录的流程,结合示例代码给大家分享实现思路,需要的朋友可以参考下
    2022-04-04
  • 基于多进程中APScheduler重复运行的解决方法

    基于多进程中APScheduler重复运行的解决方法

    今天小编就为大家分享一篇基于多进程中APScheduler重复运行的解决方法,具有很好的价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • Python实现将内容写入文件的五种方法总结

    Python实现将内容写入文件的五种方法总结

    本篇带你详细看一下python将内容写入文件的方法以及细节,主要包括write()方法、writelines() 方法、print() 函数、使用 csv 模块、使用 json 模块,需要的可以参考一下
    2023-04-04
  • Python机器学习iris数据集预处理和模型训练方式

    Python机器学习iris数据集预处理和模型训练方式

    iris数据集包含150个样本,每个样本有4个特征及其类别信息,本文介绍了iris数据集的基本操作和如何使用knn模型进行花卉种类预测,是机器学习中的经典案例,适用于监督式学习
    2024-10-10
  • 如何使用Typora+MinIO+Python代码打造舒适协作环境

    如何使用Typora+MinIO+Python代码打造舒适协作环境

    这篇文章主要介绍了如何使用Typora+MinIO+Python代码打造舒适协作环境,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-05-05
  • 如何用python合并多个excel文件

    如何用python合并多个excel文件

    这篇文章主要介绍了如何用python合并多个excel文件,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下
    2021-03-03

最新评论