tensorflow实现逻辑回归模型

 更新时间:2018年09月08日 09:42:36   作者:Missayaa  
这篇文章主要为大家详细介绍了tensorflow实现逻辑回归模型的相关资料,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

逻辑回归模型

逻辑回归是应用非常广泛的一个分类机器学习算法,它将数据拟合到一个logit函数(或者叫做logistic函数)中,从而能够完成对事件发生的概率进行预测。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#下载好的mnist数据集存在F:/mnist/data/中
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot = True)
print(mnist.train.num_examples)
print(mnist.test.num_examples)

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels

print(type(trainimg))
print(trainimg.shape,)
print(trainlabel.shape,)
print(testimg.shape,)
print(testlabel.shape,)

nsample = 5
randidx = np.random.randint(trainimg.shape[0],size = nsample)

for i in randidx:
  curr_img = np.reshape(trainimg[i,:],(28,28))
  curr_label = np.argmax(trainlabel[i,:])
  plt.matshow(curr_img,cmap=plt.get_cmap('gray'))
  plt.title(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  print(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  plt.show()


x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#
actv = tf.nn.softmax(tf.matmul(x,W)+b)
#计算损失
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
#学习率
learning_rate = 0.01
#随机梯度下降
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

#求1位置索引值 对比预测值索引与label索引是否一样,一样返回True
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
#tf.cast把True和false转换为float类型 0,1
#把所有预测结果加在一起求精度
accr = tf.reduce_mean(tf.cast(pred,"float"))
init = tf.global_variables_initializer()
"""
#测试代码 
sess = tf.InteractiveSession()
arr = np.array([[31,23,4,24,27,34],[18,3,25,4,5,6],[4,3,2,1,5,67]])
#返回数组的维数 2
print(tf.rank(arr).eval())
#返回数组的行列数 [3 6]
print(tf.shape(arr).eval())
#返回数组中每一列中最大元素的索引[0 0 1 0 0 2]
print(tf.argmax(arr,0).eval())
#返回数组中每一行中最大元素的索引[5 2 5]
print(tf.argmax(arr,1).eval()) 
J"""
#把所有样本迭代50次
training_epochs = 50
#每次迭代选择多少样本
batch_size = 100
display_step = 5

sess = tf.Session()
sess.run(init)

#循环迭代
for epoch in range(training_epochs):
  avg_cost = 0
  num_batch = int(mnist.train.num_examples/batch_size)
  for i in range(num_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    sess.run(optm,feed_dict = {x:batch_xs,y:batch_ys})
    feeds = {x:batch_xs,y:batch_ys}
    avg_cost += sess.run(cost,feed_dict = feeds)/num_batch

  if epoch % display_step ==0:
    feeds_train = {x:batch_xs,y:batch_ys}
    feeds_test = {x:mnist.test.images,y:mnist.test.labels}
    train_acc = sess.run(accr,feed_dict = feeds_train)
    test_acc = sess.run(accr,feed_dict = feeds_test)
    #每五个epoch打印一次信息
    print("Epoch:%03d/%03d cost:%.9f train_acc:%.3f test_acc: %.3f" %(epoch,training_epochs,avg_cost,train_acc,test_acc))

print("Done")

程序训练结果如下:

Epoch:000/050 cost:1.177228655 train_acc:0.800 test_acc: 0.855
Epoch:005/050 cost:0.440933891 train_acc:0.890 test_acc: 0.894
Epoch:010/050 cost:0.383387268 train_acc:0.930 test_acc: 0.905
Epoch:015/050 cost:0.357281335 train_acc:0.930 test_acc: 0.909
Epoch:020/050 cost:0.341473956 train_acc:0.890 test_acc: 0.913
Epoch:025/050 cost:0.330586549 train_acc:0.920 test_acc: 0.915
Epoch:030/050 cost:0.322370980 train_acc:0.870 test_acc: 0.916
Epoch:035/050 cost:0.315942993 train_acc:0.940 test_acc: 0.916
Epoch:040/050 cost:0.310728854 train_acc:0.890 test_acc: 0.917
Epoch:045/050 cost:0.306357428 train_acc:0.870 test_acc: 0.918
Done

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • 详解Python中的变量及其命名和打印

    详解Python中的变量及其命名和打印

    这篇文章主要介绍了Python中的变量及其命名和打印,是Python入门学习中的基础知识,需要的朋友可以参考下
    2016-03-03
  • Python 多进程原理及实现

    Python 多进程原理及实现

    这篇文章主要介绍了Python 多进程原理及实现,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2020-12-12
  • Python反编译的两种实现方式

    Python反编译的两种实现方式

    这篇文章主要介绍了Python反编译的两种实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-06-06
  • Python实现中一次读取多个值的方法

    Python实现中一次读取多个值的方法

    下面小编就为大家分享一篇Python实现中一次读取多个值的方法,具有很好的参考价值,我对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • python三方库之requests的快速上手

    python三方库之requests的快速上手

    这篇文章主要介绍了python三方库之requests的快速上手,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-03-03
  • 详解如何在Python中实现遗传算法

    详解如何在Python中实现遗传算法

    遗传算法是一种模拟自然进化过程与机制来搜索最优解的方法,这篇文章主要为大家介绍了如何在Python中实现遗传算法,感兴趣的小伙伴可以了解一下
    2023-06-06
  • python中15种3D绘图函数总结

    python中15种3D绘图函数总结

    这篇文章主要为大家详细介绍了python中15种3D绘图函数的用法,文中的示例代码讲解详细,具有一定的学习价值,感兴趣的小伙伴可以跟随小编一起了解一下
    2023-09-09
  • Python使用Selenium模块模拟浏览器抓取斗鱼直播间信息示例

    Python使用Selenium模块模拟浏览器抓取斗鱼直播间信息示例

    这篇文章主要介绍了Python使用Selenium模块模拟浏览器抓取斗鱼直播间信息,涉及Python基于Selenium模块的模拟浏览器登陆、解析、抓取信息,以及MongoDB数据库的连接、写入等相关操作技巧,需要的朋友可以参考下
    2018-07-07
  • Python中遍历字典过程中更改元素导致异常的解决方法

    Python中遍历字典过程中更改元素导致异常的解决方法

    这篇文章主要介绍了Python中遍历字典过程中更改元素导致错误的解决方法,针对增删元素后出现dictionary changed size during iteration的异常解决做出讨论和解决,需要的朋友可以参考下
    2016-05-05
  • Python本地及虚拟解释器配置过程解析

    Python本地及虚拟解释器配置过程解析

    这篇文章主要介绍了Python本地及虚拟解释器配置过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-10-10

最新评论