PyTorch实现卷积神经网络的搭建详解

 更新时间:2022年05月07日 08:56:01   作者:Bubbliiiing  
这篇文章主要为大家介绍了PyTorch实现卷积神经网络的搭建详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

PyTorch中实现卷积的重要基础函数

1、nn.Conv2d:

nn.Conv2d在pytorch中用于实现卷积。

nn.Conv2d(
    in_channels=32,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
)

1、in_channels为输入通道数。

2、out_channels为输出通道数。

3、kernel_size为卷积核大小。

4、stride为步数。

5、padding为padding情况。

6、dilation表示空洞卷积情况。

2、nn.MaxPool2d(kernel_size=2)

nn.MaxPool2d在pytorch中用于实现最大池化。

具体使用方式如下:

MaxPool2d(kernel_size, 
		stride=None, 
		padding=0, 
		dilation=1, 
		return_indices=False, 
		ceil_mode=False)

1、kernel_size为池化核的大小

2、stride为步长

3、padding为填充情况

3、nn.ReLU()

nn.ReLU()用来实现Relu函数,实现非线性。

4、x.view()

x.view用于reshape特征层的形状。

全部代码

这是一个简单的CNN模型,用于预测mnist手写体。

import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# 循环世代
EPOCH = 20
BATCH_SIZE = 50
# 下载mnist数据集
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,)
# (60000, 28, 28)
print(train_data.train_data.size())                 
# (60000)
print(train_data.train_labels.size())               
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 测试集
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# (2000, 1, 28, 28)
# 标准化
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
# 建立pytorch神经网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        #----------------------------#
        #   第一部分卷积
        #----------------------------#
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=2,
                dilation=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        #----------------------------#
        #   第二部分卷积
        #----------------------------#
        self.conv2 = nn.Sequential( 
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        #----------------------------#
        #   全连接+池化+全连接
        #----------------------------#
        self.ful1 = nn.Linear(64 * 7 * 7, 512)
        self.drop = nn.Dropout(0.5)
        self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax())
    #----------------------------#
    #   前向传播
    #----------------------------#   
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.ful1(x)
        x = self.drop(x)
        output = self.ful2(x)
        return output
cnn = CNN()
# 指定优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3) 
# 指定loss函数
loss_func = nn.CrossEntropyLoss()
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader): 
        #----------------------------#
        #   计算loss并修正权值
        #----------------------------#   
        output = cnn(b_x)
        loss = loss_func(output, b_y) 
        optimizer.zero_grad() 
        loss.backward() 
        optimizer.step() 
        #----------------------------#
        #   打印
        #----------------------------#   
        if step % 50 == 0:
            test_output = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1].data.numpy()
            accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
            print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)

以上就是PyTorch实现卷积神经网络的搭建详解的详细内容,更多关于PyTorch搭建卷积神经网络的资料请关注脚本之家其它相关文章!

相关文章

  • Python+Selenium实现一键摸鱼&采集数据

    Python+Selenium实现一键摸鱼&采集数据

    将Selenium程序编写为 .bat 可执行文件,从此一键启动封装好的Selenium程序,省时省力还可以复用,岂不美哉。所以本文将利用Selenium实现一键摸鱼&一键采集数据,需要的可以参考一下
    2022-08-08
  • python与c语言的语法有哪些不一样的

    python与c语言的语法有哪些不一样的

    在本篇内容里小编给大家整理的是一篇关于python与c语法区别的相关内容,有兴趣的朋友们可以参考下。
    2020-09-09
  • python 爬虫如何正确的使用cookie

    python 爬虫如何正确的使用cookie

    这篇文章主要介绍了python 爬虫如何使用cookie,帮助大家绕过网站设置的登录规则以及登录时的验证码识别,完成自身的爬取需求,感兴趣的朋友可以了解下
    2020-10-10
  • selenium 安装与chromedriver安装的方法步骤

    selenium 安装与chromedriver安装的方法步骤

    这篇文章主要介绍了selenium 安装与chromedriver安装的方法步骤,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-06-06
  • Python+Opencv实现把图片、视频互转的示例

    Python+Opencv实现把图片、视频互转的示例

    这篇文章主要介绍了Python+Opencv实现把图片、视频互转的示例,帮助大家更好的理解和实用python,感兴趣的朋友可以了解下
    2020-12-12
  • Python使用Selenium自动进行百度搜索的实现

    Python使用Selenium自动进行百度搜索的实现

    我们今天介绍一个非常适合新手的python自动化小项目,这个例子非常适合新手学习Python网络自动化,不仅能够了解如何使用Selenium,而且还能知道一些超级好用的小工具。感兴趣的可以了解一下
    2021-07-07
  • Python使用pymysql从MySQL数据库中读出数据的方法

    Python使用pymysql从MySQL数据库中读出数据的方法

    今天小编就为大家分享一篇Python使用pymysql从MySQL数据库中读出数据的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • python初学定义函数

    python初学定义函数

    这篇文章主要为大家介绍了python的定义函数,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助,希望能够给你带来帮助
    2021-11-11
  • python生成1行四列全2矩阵的方法

    python生成1行四列全2矩阵的方法

    今天小编就为大家分享一篇python生成1行四列全2矩阵的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-08-08
  • 如何用python批量发送工资条邮件

    如何用python批量发送工资条邮件

    大家好,本篇文章主要讲的是如何用python批量发送工资条邮件,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下
    2022-01-01

最新评论