python之tensorflow手把手实例讲解猫狗识别实现

 更新时间:2021年09月22日 16:03:07   作者:鑫xing  
要说到深度学习图像分类的经典案例之一,那就是猫狗大战了。猫和狗在外观上的差别还是挺明显的,无论是体型、四肢、脸庞和毛发等等, 都是能通过肉眼很容易区分的。那么如何让机器来识别猫和狗呢?网上已经有不少人写过这案例了,我也来尝试下练练手。

作为tensorflow初学的大三学生,本次课程作业的使用猫狗数据集做一个二分类模型。

一,猫狗数据集数目构成

train cats:1000 ,dogs:1000
test cats: 500,dogs:500
validation cats:500,dogs:500

二,数据导入

train_dir = 'Data/train'
test_dir = 'Data/test'
validation_dir = 'Data/validation'
train_datagen = ImageDataGenerator(rescale=1/255,
                                   rotation_range=10,
                                   width_shift_range=0.2,  #图片水平偏移的角度
                                   height_shift_range=0.2,  #图片数值偏移的角度
                                   shear_range=0.2,  #剪切强度 
                                   zoom_range=0.2,   #随机缩放的幅度
                                   horizontal_flip=True,   #是否进行随机水平翻转
#                                    fill_mode='nearest'
                                  )
train_generator = train_datagen.flow_from_directory(train_dir,
                 (224,224),batch_size=1,class_mode='binary',shuffle=False)
test_datagen = ImageDataGenerator(rescale=1/255)
test_generator = test_datagen.flow_from_directory(test_dir,
                 (224,224),batch_size=1,class_mode='binary',shuffle=True)
validation_datagen = ImageDataGenerator(rescale=1/255)
validation_generator = validation_datagen.flow_from_directory(
                validation_dir,(224,224),batch_size=1,class_mode='binary')
print(train_datagen)
print(test_datagen)
print(train_datagen)

三,数据集构建

我这里是将ImageDataGenerator类里的数据提取出来,将数据与标签分别存放在两个列表,后面在转为np.array,也可以使用model.fit_generator,我将数据放在内存为了后续调参数时模型训练能更快读取到数据,不用每次训练一整轮都去读一次数据(应该是这样的…我是这样理解…)
注意我这里的数据集构建后,三种数据都是存放在内存中的,我电脑内存是16g的可以存放下。

train_data=[]
train_labels=[]
a=0
for data_train, labels_train in train_generator:
    train_data.append(data_train)
    train_labels.append(labels_train)
    a=a+1
    if a>1999:
        break
x_train=np.array(train_data)
y_train=np.array(train_labels)
x_train=x_train.reshape(2000,224,224,3)
test_data=[]
test_labels=[]
a=0
for data_test, labels_test in test_generator:
    test_data.append(data_test)
    test_labels.append(labels_test)
    a=a+1
    if a>999:
        break
x_test=np.array(test_data)
y_test=np.array(test_labels)
x_test=x_test.reshape(1000,224,224,3)
validation_data=[]
validation_labels=[]
a=0
for data_validation, labels_validation in validation_generator:
    validation_data.append(data_validation)
    validation_labels.append(labels_validation)
    a=a+1
    if a>999:
        break
x_validation=np.array(validation_data)
y_validation=np.array(validation_labels)
x_validation=x_validation.reshape(1000,224,224,3)

四,模型搭建

model1 = tf.keras.models.Sequential([
    # 第一层卷积,卷积核为,共16个,输入为150*150*1
    tf.keras.layers.Conv2D(16,(3,3),activation='relu',padding='same',input_shape=(224,224,3)),
    tf.keras.layers.MaxPooling2D((2,2)),
    
    # 第二层卷积,卷积核为3*3,共32个,
    tf.keras.layers.Conv2D(32,(3,3),activation='relu',padding='same'),
    tf.keras.layers.MaxPooling2D((2,2)),
    
    # 第三层卷积,卷积核为3*3,共64个,
    tf.keras.layers.Conv2D(64,(3,3),activation='relu',padding='same'),
    tf.keras.layers.MaxPooling2D((2,2)),
    
    # 数据铺平
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64,activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1,activation='sigmoid')
])
print(model1.summary())

模型summary:

在这里插入图片描述

五,模型训练

model1.compile(optimize=tf.keras.optimizers.SGD(0.00001),
             loss=tf.keras.losses.binary_crossentropy,
             metrics=['acc'])
history1=model1.fit(x_train,y_train,
# 					validation_split=(0~1)   选择一定的比例用于验证集,可被validation_data覆盖
                  validation_data=(x_validation,y_validation),
                  batch_size=10,
                  shuffle=True,
                  epochs=10)
model1.save('cats_and_dogs_plain1.h5')
print(history1)

在这里插入图片描述

plt.plot(history1.epoch,history1.history.get('acc'),label='acc')
plt.plot(history1.epoch,history1.history.get('val_acc'),label='val_acc')
plt.title('正确率')
plt.legend()

在这里插入图片描述

可以看到我们的模型泛化能力还是有点差,测试集的acc能达到0.85以上,验证集却在0.65~0.70之前跳动。

六,模型测试

model1.evaluate(x_validation,y_validation)

在这里插入图片描述

最后我们的模型在测试集上的正确率为0.67,可以说还不够好,有点过拟合,可能是训练数据不够多,后续可以数据增广或者从验证集、测试集中调取一部分数据用于训练模型,可能效果好一些。

到此这篇关于python之tensorflow手把手实例讲解猫狗识别实现的文章就介绍到这了,更多相关python tensorflow 猫狗识别内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python包管理工具pip用法详解

    Python包管理工具pip用法详解

    本文详细讲解了Python包管理工具pip的用法,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-05-05
  • PyTorch 实现L2正则化以及Dropout的操作

    PyTorch 实现L2正则化以及Dropout的操作

    这篇文章主要介绍了PyTorch 实现L2正则化以及Dropout的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • Python正则表达式以及常用匹配实例

    Python正则表达式以及常用匹配实例

    在处理字符串时,经常会遇到查找符合某些复杂规则字符串的需求,正则表达式就是用于描述这些规则的工具,下面这篇文章主要给大家介绍了关于Python正则表达式以及常用匹配的相关资料,需要的朋友可以参考下
    2022-07-07
  • Python数据可视化Pyecharts库的使用教程

    Python数据可视化Pyecharts库的使用教程

    pyecharts是一个用于生成echarts图表的类库。echarts是百度开源的一个数据可视化库,用echarts生成的图可视化效果非常棒。使用pyechart库可以在python中生成echarts数据图。本文将详细介绍一下Pyecharts库的使用,需要的可以参考一下
    2022-02-02
  • 使用python生成云词图实现画红楼梦词云图

    使用python生成云词图实现画红楼梦词云图

    红楼梦相信大家都看过,今天给大家介绍另一种不用搞得乌漆麻黑的方式来制作红楼梦的词云图,有需要的朋友可以借鉴参考下,希望能够有所帮助
    2021-09-09
  • Python中关于元组 集合 字符串 函数 异常处理的全面详解

    Python中关于元组 集合 字符串 函数 异常处理的全面详解

    本篇文章介绍了我在学习python过程中对元组、集合、字符串、函数、异常处理的总结,通读本篇对大家的学习或工作具有一定的价值,需要的朋友可以参考下
    2021-10-10
  • 在Python的Django框架中更新数据库数据的方法

    在Python的Django框架中更新数据库数据的方法

    这篇文章主要介绍了在Python的Django框架中更新数据库数据,对此Django框架中提供了便利的插入和更新方法,需要的朋友可以参考下
    2015-07-07
  • Python 实现「食行生鲜」签到领积分功能

    Python 实现「食行生鲜」签到领积分功能

    今天我们就用 Python 来实现自动签到,省得我每天打开 APP 来操作了。感兴趣的朋友跟随小编一起看看吧
    2018-09-09
  • Python format()格式化输出方法

    Python format()格式化输出方法

    这篇文章主要介绍了Python format()格式化输出方法, Python 2.6以后,Python 中的就提供了字符串类型(str)提供了 format() 方法对字符串进行格式化,夏敏我们就来了解这个方法吧,需要的小伙伴也可以参考一下

    2021-12-12
  • 浅谈python中的__init__、__new__和__call__方法

    浅谈python中的__init__、__new__和__call__方法

    这篇文章主要给大家介绍了关于python中__init__、__new__和__call__方法的相关资料,文中通过示例代码介绍的非常详细,对大家具有一定的参考学习价值,需要的朋友可以参考学习,下面来跟着小编一起看看吧。
    2017-07-07

最新评论