keras建模的3种方式详解

 更新时间:2023年08月23日 10:48:02   作者:月疯  
这篇文章主要介绍了keras建模的3种方式详解,keras是Google公司于2016年发布的以tensorflow为后端的用于深度学习网络训练的高阶API,因接口设计非常人性化,深受程序员的喜爱,需要的朋友可以参考下

keras建模的3种方式

keras是google公司2016年发布的tensorflow为后端的深度学习网络的高级接口。

三种建模方式:

  1. 序列模型
  2. 函数模型
  3. 子类模型

第一种序列模型:

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.models import load_model
from keras.layers import Dense
#加载数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images,mnist.train.labels,
    valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
    test_x,test_y=mnist.test.images,mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y
#序列模型
def DNN(train_x,train_y,valid_x,valid_y):
    #創建模型
    model=Sequential()
    model.add(Dense(64,input_dim=784,activation='relu'))
    model.add(Dense(128,activation='relu'))
    model.add(Dense(10,activation='softmax'))
    #查看网络模型
    model.summary()
    #编译模型
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    #训练模型
    model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y))
    #保存模型
    model.save('sequential.h5')
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y)
model=load_model('sequential.h5')  #下载模型
pre=model.predict(test_x)  #测试验证
#计算验证集精度
a=np.argmax(pre,1)
b=np.argmax(test_y,1)
t=(a==b).astype(int)
acc=np.sum(t)/len(a)
print(acc)

 第二种函数模型

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.models import load_model
from keras.layers import Input,Dense
#加载数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images,mnist.train.labels,
    valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
    test_x,test_y=mnist.test.images,mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y
#函数模型
def DNN(train_x,train_y,valid_x,valid_y):
    #创建模型
    inputs=Input(shape=(784,))
    x=Dense(64,activation='relu')(inputs)
    x=Dense(128,activation='relu')(x)
    output=Dense(10,activation='softmax')(x)
    model=Model(input=inputs,output=output)
    #查看网络结构
    model.summary()
    #编译模型
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    #训练模型
    model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y))
    #保存模型
    model.save('fun_model.h5')
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y)
model=load_model('fun_model.h5')  #下载模型
pre=model.predict(test_x)  #验证数据集
#验证数据集准确度
a=np.argmax(pre,1)
b=np.argmax(test_y,1)
t=(a==b).astype(int)
acc=np.sum(t)/len(a)
print(acc)

第三种子类模型

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import Dense
#加载数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images,mnist.train.labels,
    valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
    test_x,test_y=mnist.test.images,mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y
#子类模型
class DNN(Model):
    def __init__(self,train_x,train_y,valid_x,valid_y):
        super(DNN,self).__init__()
        #初始化网络模型
        self.dense1=Dense(64,input_dim=784,activation='relu')
        self.dense2=Dense(128,activation='relu')
        self.dense3=Dense(10,activation='softmax')
    def call(self,inputs):  #回调順序
        x=self.dense1(inputs)
        x=self.dense2(x)
        x=self.dense3(x)
        return x
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
model=DNN(train_x,train_y,valid_x,valid_y)
#编译模型(学习率、损失函数、模型评估)
model.compile(optimizer='adam(lr=0.001)',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y))
#查看网络结构
model.summary()
pre=model.predict(test_x)  #验证数据集
#计算验证数据集的准确度
a=np.argmax(pre,1)
b=np.argmax(test_y,1)
t=(a==b).astype(int)
acc=np.sum(t)/len(a)
print(acc)

常用的损失函数: 

mse #均方差(回归)

mae #绝对误差(回归)

binary_crossentropy #二值交叉熵(二分类,逻辑回归)

categorical_crossentropy #交叉熵(多分类)

到此这篇关于keras建模的3种方式详解的文章就介绍到这了,更多相关keras建模方式内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • jupyter notebook保存文件默认路径更改方法汇总(亲测可以)

    jupyter notebook保存文件默认路径更改方法汇总(亲测可以)

    安装Anaconda后,新建文件的默认存储路径一般在C系统盘,那么路径是什么呢?如何更改jupyter notebook保存文件默认路径呢?今天小编就这一问题通过两种方法给大家讲解,需要的朋友跟随小编一起看看吧
    2021-06-06
  • Python中比较特别的除法运算和幂运算介绍

    Python中比较特别的除法运算和幂运算介绍

    这篇文章主要介绍了Python中比较特别的除法运算和幂运算介绍,“/”这个是除法运算,那么这个“//”呢?“*”这个是乘法运算,那么这个“**”呢?本文就讲解这些运算的不同,需要的朋友可以参考下
    2015-04-04
  • Python中的lambda和apply用法及说明

    Python中的lambda和apply用法及说明

    这篇文章主要介绍了Python中的lambda和apply用法及说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-12-12
  • python中把元组转换为namedtuple方法

    python中把元组转换为namedtuple方法

    在本篇文章里小编给大家整理的是一篇关于python中把元组转换为namedtuple方法,有兴趣的朋友们可以参考下。
    2020-12-12
  • Python中的请求重试策略深入探讨

    Python中的请求重试策略深入探讨

    在网络通信中,由于各种原因,请求可能会失败,为了增加程序的健壮性和可靠性,实现一个优雅的请求重试策略是至关重要的,本文将深入探讨如何在Python中实现优雅的请求重试,通过丰富的示例代码和详细的解释,帮助大家更好地理解和应用重试机制
    2024-01-01
  • python实现beta分布概率密度函数的方法

    python实现beta分布概率密度函数的方法

    今天小编就为大家分享一篇python实现beta分布概率密度函数的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • Python正则表达式经典入门教程

    Python正则表达式经典入门教程

    这篇文章主要介绍了Python正则表达式,结合具体实例形式归纳总结了Python正则表达式中常用的各种函数与相关使用技巧,需要的朋友可以参考下
    2017-05-05
  • Python教程之生产者消费者模式解析

    Python教程之生产者消费者模式解析

    在并发编程中使用生产者和消费者模式能够解决大不多的并发问题。该模式通过平衡生产线程和消费线程的工作能力来提高程序的整体处理数据的速度
    2021-09-09
  • ansible作为python模块库使用的方法实例

    ansible作为python模块库使用的方法实例

    ansible是一个python package,是个完全的unpack and play软件,对客户端唯一的要求是有ssh有python,并且装了python-simplejson包,部署上简单到发指。下面这篇文章就给大家主要介绍了ansible作为python模块库使用的方法实例,需要的朋友可以参考借鉴。
    2017-01-01
  • Python封装git命令的流程步骤

    Python封装git命令的流程步骤

    在日常的 Android 项目开发中,一般只会使用到: git add, git commit, git push, git pull, git rebase, git merge, git diff等常规命令,但是使用 git 命令,还可以做一些特别的事情,下面将介绍使用 python 封装 git 命令,需要的朋友可以参考下
    2024-01-01

最新评论