浅谈keras中Dropout在预测过程中是否仍要起作用

 更新时间:2020年07月09日 09:04:53   作者:zyl681327  
这篇文章主要介绍了浅谈keras中Dropout在预测过程中是否仍要起作用,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

因为需要,要重写训练好的keras模型,虽然只具备预测功能,但是发现还是有很多坑要趟过。其中Dropout这个坑,我记忆犹新。

一开始,我以为预测时要保持和训练时完全一样的网络结构,也就是预测时用的网络也是有丢弃的网络节点,但是这样想就掉进了一个大坑!因为无法通过已经训练好的模型,来获取其训练时随机丢弃的网络节点是那些,这本身就根本不可能。

更重要的是:我发现每一个迭代周期丢弃的神经元也不完全一样。

假若迭代500次,网络共有1000个神经元, 在第n(1<= n <500)个迭代周期内,从1000个神经元里随机丢弃了200个神经元,在n+1个迭代周期内,会在这1000个神经元里(不是在剩余得800个)重新随机丢弃200个神经元。

训练过程中,使用Dropout,其实就是对部分权重和偏置在某次迭代训练过程中,不参与计算和更新而已,并不是不再使用这些权重和偏置了(预测时,会使用全部的神经元,包括使用训练时丢弃的神经元)。

也就是说在预测过程中完全没有Dropout什么事了,他只是在训练时有用,特别是针对训练集比较小时防止过拟合非常有用。

补充知识:TensorFlow直接使用ckpt模型predict不用restore

我就废话不多说了,大家还是直接看代码吧~

# -*- coding: utf-8 -*-
# from util import *
import cv2
import numpy as np
import tensorflow as tf
# from tensorflow.python.framework import graph_util
import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
image_path = './8760.pgm'

input_checkpoint = './model/xu_spatial_model_1340.ckpt'

sess = tf.Session()
saver = tf.train.import_meta_graph(input_checkpoint + '.meta')
saver.restore(sess, input_checkpoint)

# input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
input_image_tensor = sess.graph.get_tensor_by_name("coef_input:0")
is_training = sess.graph.get_tensor_by_name('is_training:0')
batch_size = sess.graph.get_tensor_by_name('batch_size:0')
# 定义输出的张量名称
output_tensor_name = sess.graph.get_tensor_by_name("xuNet/logits:0") # xuNet/Logits/logits
image = cv2.imread(image_path, 0)
# 读取测试图片
out = sess.run(output_tensor_name, feed_dict={input_image_tensor: np.reshape(image, (1, 512, 512, 1)),
                       is_training: False,
                       batch_size: 1})
print(out)

ckpt模型中的所有节点名称,可以这样查看

[n.name for n in tf.get_default_graph().as_graph_def().node]

以上这篇浅谈keras中Dropout在预测过程中是否仍要起作用就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python实现发送QQ邮件(可加附件)

    python实现发送QQ邮件(可加附件)

    这篇文章主要为大家详细介绍了python实现发送QQ邮件,可添加附件功能,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-12-12
  • 基于python中theano库的线性回归

    基于python中theano库的线性回归

    这篇文章主要为大家详细介绍了基于python中theano库的线性回归,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-08-08
  • 使Python代码流畅无缝连接的链式调用技巧

    使Python代码流畅无缝连接的链式调用技巧

    链式调用是一种编程风格,它允许将多个方法调用连接在一起,形成一个连贯的操作链,在Python中,链式调用常常用于使代码更简洁、易读,尤其在处理数据处理和函数式编程中应用广泛
    2024-01-01
  • python 定时器每天就执行一次的实现代码

    python 定时器每天就执行一次的实现代码

    这篇文章主要介绍了python 定时器每天就执行一次的实现代码,代码简单易懂非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-08-08
  • python执行shell获取硬件参数写入mysql的方法

    python执行shell获取硬件参数写入mysql的方法

    这篇文章主要介绍了python执行shell获取硬件参数写入mysql的方法,可实现对服务器硬件信息的读取及写入数据库的功能,非常具有实用价值,需要的朋友可以参考下
    2014-12-12
  • Numpy中Meshgrid函数基本用法及2种应用场景

    Numpy中Meshgrid函数基本用法及2种应用场景

    NumPy包含很多实用的数学函数,涵盖线性代数运算、傅里叶变换和随机数生成等功能,下面这篇文章主要给大家介绍了关于Numpy中Meshgrid函数基本用法及2种应用场景的相关资料,需要的朋友可以参考下
    2022-08-08
  • python爬虫 线程池创建并获取文件代码实例

    python爬虫 线程池创建并获取文件代码实例

    这篇文章主要介绍了python爬虫 线程池创建并获取文件代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • 详细分析Python collections工具库

    详细分析Python collections工具库

    这篇文章主要介绍了详解Python collections工具库的相关资料,文中讲解非常细致,代码帮助大家更好的理解和学习,感兴趣的朋友可以了解下
    2020-07-07
  • Python中OTSU算法的原理与实现详解

    Python中OTSU算法的原理与实现详解

    OTSU算法是大津展之提出的阈值分割方法,又叫最大类间方差法,本文主要为大家详细介绍了OTSU算法的原理与Python实现,感兴趣的小伙伴可以了解下
    2023-12-12
  • Pytorch中膨胀卷积的用法详解

    Pytorch中膨胀卷积的用法详解

    今天小编就为大家分享一篇Pytorch中膨胀卷积的用法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01

最新评论