PyTorch实现卷积神经网络的搭建详解

 更新时间:2022年05月07日 08:56:01   作者:Bubbliiiing  
这篇文章主要为大家介绍了PyTorch实现卷积神经网络的搭建详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

PyTorch中实现卷积的重要基础函数

1、nn.Conv2d:

nn.Conv2d在pytorch中用于实现卷积。

nn.Conv2d(
    in_channels=32,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
)

1、in_channels为输入通道数。

2、out_channels为输出通道数。

3、kernel_size为卷积核大小。

4、stride为步数。

5、padding为padding情况。

6、dilation表示空洞卷积情况。

2、nn.MaxPool2d(kernel_size=2)

nn.MaxPool2d在pytorch中用于实现最大池化。

具体使用方式如下:

MaxPool2d(kernel_size, 
		stride=None, 
		padding=0, 
		dilation=1, 
		return_indices=False, 
		ceil_mode=False)

1、kernel_size为池化核的大小

2、stride为步长

3、padding为填充情况

3、nn.ReLU()

nn.ReLU()用来实现Relu函数,实现非线性。

4、x.view()

x.view用于reshape特征层的形状。

全部代码

这是一个简单的CNN模型,用于预测mnist手写体。

import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# 循环世代
EPOCH = 20
BATCH_SIZE = 50
# 下载mnist数据集
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,)
# (60000, 28, 28)
print(train_data.train_data.size())                 
# (60000)
print(train_data.train_labels.size())               
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 测试集
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# (2000, 1, 28, 28)
# 标准化
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
# 建立pytorch神经网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        #----------------------------#
        #   第一部分卷积
        #----------------------------#
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=2,
                dilation=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        #----------------------------#
        #   第二部分卷积
        #----------------------------#
        self.conv2 = nn.Sequential( 
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        #----------------------------#
        #   全连接+池化+全连接
        #----------------------------#
        self.ful1 = nn.Linear(64 * 7 * 7, 512)
        self.drop = nn.Dropout(0.5)
        self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax())
    #----------------------------#
    #   前向传播
    #----------------------------#   
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.ful1(x)
        x = self.drop(x)
        output = self.ful2(x)
        return output
cnn = CNN()
# 指定优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3) 
# 指定loss函数
loss_func = nn.CrossEntropyLoss()
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader): 
        #----------------------------#
        #   计算loss并修正权值
        #----------------------------#   
        output = cnn(b_x)
        loss = loss_func(output, b_y) 
        optimizer.zero_grad() 
        loss.backward() 
        optimizer.step() 
        #----------------------------#
        #   打印
        #----------------------------#   
        if step % 50 == 0:
            test_output = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1].data.numpy()
            accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
            print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)

以上就是PyTorch实现卷积神经网络的搭建详解的详细内容,更多关于PyTorch搭建卷积神经网络的资料请关注脚本之家其它相关文章!

相关文章

  • Python 实现简单的电话本功能

    Python 实现简单的电话本功能

    这篇文章主要介绍了Python 实现简单的电话本功能的相关资料,包括添加联系人信息,查找姓名显示联系人,存储联系人到 TXT 文档等内容,十分的细致,有需要的小伙伴可以参考下
    2015-08-08
  • Python开启Http Server的实现步骤

    Python开启Http Server的实现步骤

    本文主要介绍了Python开启Http Server的实现步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-07-07
  • 在Python中使用matplotlib模块绘制数据图的示例

    在Python中使用matplotlib模块绘制数据图的示例

    这篇文章主要介绍了在Python中使用matplotlib模块绘制数据图的示例,matplotlib模块经常被用来实现数据的可视化,需要的朋友可以参考下
    2015-05-05
  • matplotlib作图添加表格实例代码

    matplotlib作图添加表格实例代码

    这篇文章主要介绍了matplotlib作图添加表格实例代码,实例绘制了一个简单的折线图,并且在图中添加了一个表格,小编觉得还是挺不错的,具有一定借鉴价值,需要的朋友可以参考下
    2018-01-01
  • TensorFlow Autodiff自动微分详解

    TensorFlow Autodiff自动微分详解

    这篇文章主要介绍了TensorFlow Autodiff自动微分详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-07-07
  • 解决IDEA 的 plugins 搜不到任何的插件问题

    解决IDEA 的 plugins 搜不到任何的插件问题

    这篇文章主要介绍了解决IDEA 的 plugins 搜不到任何的插件问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • Python 实现进度条的六种方式

    Python 实现进度条的六种方式

    这篇文章主要介绍了Python 实现进度条的六种方式,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2021-01-01
  • PyCharm插件开发实践之PyGetterAndSetter详解

    PyCharm插件开发实践之PyGetterAndSetter详解

    这篇文章主要介绍了PyCharm插件开发实践-PyGetterAndSetter,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-10-10
  • pyenv虚拟环境管理python多版本和软件库的方法

    pyenv虚拟环境管理python多版本和软件库的方法

    这篇文章主要介绍了pyenv虚拟环境管理python多版本和软件库,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-12-12
  • python 将列表中的字符串连接成一个长路径的方法

    python 将列表中的字符串连接成一个长路径的方法

    今天小编就为大家分享一篇python 将列表中的字符串连接成一个长路径的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10

最新评论