Python利用DNN实现宝石识别

 更新时间:2022年01月28日 10:01:02   作者:Livingbody  
深度神经网络(Deep Neural Networks,简称DNN)是深度学习的基础,其结构为input、hidden(可有多层)、output,每层均为全连接。本文将利用DNN实现宝石识别,感兴趣的可以了解一下

任务描述

本次实践是一个多分类任务,需要将照片中的宝石分别进行识别,完成宝石的识别

实践平台:百度AI实训平台-AI Studio、PaddlePaddle1.8.0 动态图

深度神经网络(DNN)

深度神经网络(Deep Neural Networks,简称DNN)是深度学习的基础,其结构为input、hidden(可有多层)、output,每层均为全连接。

数据集介绍

  • 数据集文件名为archive_train.zip,archive_test.zip。
  • 该数据集包含25个类别不同宝石的图像。
  • 这些类别已经分为训练和测试数据。
  • 图像大小不一,格式为.jpeg。

# 查看当前挂载的数据集目录, 该目录下的变更重启环境后会自动还原
# View dataset directory. This directory will be recovered automatically after resetting environment. 
!ls /home/aistudio/data

data55032 dataset

#导入需要的包
import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Image
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
import matplotlib.pyplot as plt

1.数据准备

'''
参数配置
'''
train_parameters = {
    "input_size": [3, 64, 64],                           #输入图片的shape
    "class_dim": -1,                                     #分类数
    'augment_path' : '/home/aistudio/augment',           #数据增强图片目录
    "src_path":"data/data55032/archive_train.zip",       #原始数据集路径
    "target_path":"/home/aistudio/data/dataset",        #要解压的路径 
    "train_list_path": "./train_data.txt",              #train_data.txt路径
    "eval_list_path": "./val_data.txt",                  #eval_data.txt路径
    "label_dict":{},                                    #标签字典
    "readme_path": "/home/aistudio/data/readme.json",   #readme.json路径
    "num_epochs": 20,                                    #训练轮数
    "train_batch_size": 64,                             #批次的大小
    "learning_strategy": {                              #优化函数相关的配置
        "lr": 0.001                                     #超参数学习率
    } 
}
def unzip_data(src_path,target_path):

    '''
    解压原始数据集,将src_path路径下的zip包解压至data/dataset目录下
    '''

    if(not os.path.isdir(target_path)):    
        z = zipfile.ZipFile(src_path, 'r')
        z.extractall(path=target_path)
        z.close()
    else:
        print("文件已解压")
def get_data_list(target_path,train_list_path,eval_list_path, augment_path):
    '''
    生成数据列表
    '''
    #存放所有类别的信息
    class_detail = []
    #获取所有类别保存的文件夹名称
    data_list_path=target_path
    class_dirs = os.listdir(data_list_path)
    if '__MACOSX' in class_dirs:
        class_dirs.remove('__MACOSX')
    # #总的图像数量
    all_class_images = 0
    # #存放类别标签
    class_label=0
    # #存放类别数目
    class_dim = 0
    # #存储要写进eval.txt和train.txt中的内容
    trainer_list=[]
    eval_list=[]
    #读取每个类别
    for class_dir in class_dirs:
        if class_dir != ".DS_Store":
            class_dim += 1
            #每个类别的信息
            class_detail_list = {}
            eval_sum = 0
            trainer_sum = 0
            #统计每个类别有多少张图片
            class_sum = 0
            #获取类别路径 
            path = os.path.join(data_list_path,class_dir)
            # print(path)
            # 获取所有图片
            img_paths = os.listdir(path)
            for img_path in img_paths:                                  # 遍历文件夹下的每个图片
                if img_path =='.DS_Store':
                    continue
                name_path = os.path.join(path,img_path)                       # 每张图片的路径
                if class_sum % 15 == 0:                                 # 每10张图片取一个做验证数据
                    eval_sum += 1                                       # eval_sum为测试数据的数目
                    eval_list.append(name_path + "\t%d" % class_label + "\n")
                else:
                    trainer_sum += 1 
                    trainer_list.append(name_path + "\t%d" % class_label + "\n")#trainer_sum测试数据的数目
                class_sum += 1                                          #每类图片的数目
                all_class_images += 1                                   #所有类图片的数目 
            # ----------------------------------数据增强----------------------------------
            aug_path = os.path.join(augment_path, class_dir)
            for img_path in os.listdir(aug_path):                                  # 遍历文件夹下的每个图片
                name_path = os.path.join(aug_path,img_path)                       # 每张图片的路径
                trainer_sum += 1 
                trainer_list.append(name_path + "\t%d" % class_label + "\n")#trainer_sum测试数据的数目
                all_class_images += 1                                   #所有类图片的数目
            # ----------------------------------------------------------------------------
            # 说明的json文件的class_detail数据
            class_detail_list['class_name'] = class_dir             #类别名称
            class_detail_list['class_label'] = class_label          #类别标签
            class_detail_list['class_eval_images'] = eval_sum       #该类数据的测试集数目
            class_detail_list['class_trainer_images'] = trainer_sum #该类数据的训练集数目
            class_detail.append(class_detail_list)  
            #初始化标签列表
            train_parameters['label_dict'][str(class_label)] = class_dir
            class_label += 1
            
    #初始化分类数
    train_parameters['class_dim'] = class_dim
    print(train_parameters)
    #乱序  
    random.shuffle(eval_list)
    with open(eval_list_path, 'a') as f:
        for eval_image in eval_list:
            f.write(eval_image) 
    #乱序        
    random.shuffle(trainer_list) 
    with open(train_list_path, 'a') as f2:
        for train_image in trainer_list:
            f2.write(train_image) 

    # 说明的json文件信息
    readjson = {}
    readjson['all_class_name'] = data_list_path                  #文件父目录
    readjson['all_class_images'] = all_class_images
    readjson['class_detail'] = class_detail
    jsons = json.dumps(readjson, sort_keys=True, indent=4, separators=(',', ': '))
    with open(train_parameters['readme_path'],'w') as f:
        f.write(jsons)
    print ('生成数据列表完成!')
def data_reader(file_list):
    '''
    自定义data_reader
    '''
    def reader():
        with open(file_list, 'r') as f:
            lines = [line.strip() for line in f]
            for line in lines:
                img_path, lab = line.strip().split('\t')
                img = Image.open(img_path) 
                if img.mode != 'RGB': 
                    img = img.convert('RGB') 
                img = img.resize((64, 64), Image.BILINEAR)
                img = np.array(img).astype('float32') 
                img = img.transpose((2, 0, 1))  # HWC to CHW 
                img = img/255                   # 像素值归一化 
                yield img, int(lab) 
    return reader
!pip install Augmentor
Looking in indexes: https://mirror.baidu.com/pypi/simple/
Requirement already satisfied: Augmentor in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.2.8)
Requirement already satisfied: tqdm>=4.9.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (4.36.1)
Requirement already satisfied: future>=0.16.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (0.18.0)
Requirement already satisfied: numpy>=1.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (1.16.4)
Requirement already satisfied: Pillow>=5.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (7.1.2)

'''
参数初始化
'''
src_path=train_parameters['src_path']
target_path=train_parameters['target_path']
train_list_path=train_parameters['train_list_path']
eval_list_path=train_parameters['eval_list_path']
batch_size=train_parameters['train_batch_size']
augment_path = train_parameters['augment_path']
'''
解压原始数据到指定路径
'''
unzip_data(src_path,target_path)

文件已解压

def proc_img(src):
    for root, dirs, files in os.walk(src):
        if '__MACOSX' in root:continue
        for file in files:            
            src=os.path.join(root,file)
            img=Image.open(src)
            if img.mode != 'RGB': 
                    img = img.convert('RGB') 
                    img.save(src)            


if __name__=='__main__':
    proc_img(r"data/dataset")
import os, Augmentor
import shutil, glob

if not os.path.exists(augment_path): # 控制不重复增强数据
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in dirs:
            path_ = os.path.join(root, name)
            if '__MACOSX' in path_:continue
            print('数据增强:',os.path.join(root, name))
            print('image:',os.path.join(root, name))
            p = Augmentor.Pipeline(os.path.join(root, name),output_directory='output')
            p.rotate(probability=0.6, max_left_rotation=2, max_right_rotation=2)
            p.zoom(probability=0.6, min_factor=0.9, max_factor=1.1)
            p.random_distortion(probability=0.4, grid_height=2, grid_width=2, magnitude=1)

            count = 1000 - len(glob.glob(pathname=path_+'/*.jpg'))
            p.sample(count, multi_threaded=False)
            p.process()

    print('将生成的图片拷贝到正确的目录')
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in files:
            path_ = os.path.join(root, name)
            if path_.rsplit('/',3)[2] == 'output':
                type_ = path_.rsplit('/',3)[1]
                dest_dir = os.path.join(augment_path ,type_) 
                if not os.path.exists(dest_dir):os.makedirs(dest_dir) 
                dest_path_ = os.path.join(augment_path ,type_, name) 
                shutil.move(path_, dest_path_)
    print('删除所有output目录')
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in dirs:
            if name == 'output':
                path_ = os.path.join(root, name)
                shutil.rmtree(path_)
    print('完成数据增强')
Processing kunzite_20.jpg:   1%|          | 11/968 [00:00<00:14, 65.61 Samples/s]

数据增强: data/dataset/Kunzite
image: data/dataset/Kunzite
Initialised with 32 image(s) found.
Output directory set to data/dataset/Kunzite/output.

Processing kunzite_14.jpg:   2%|▏         | 24/968 [00:00<00:17, 54.43 Samples/s]Processing kunzite_15.jpg: 100%|██████████| 968/968 [00:15<00:00, 61.57 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=350x366 at 0x7F7060EB06D0>: 100%|██████████| 32/32 [00:00<00:00, 269.33 Samples/s]                  
Processing almandine_5.jpg:   1%|          | 6/969 [00:00<00:20, 45.91 Samples/s] 

数据增强: data/dataset/Almandine
image: data/dataset/Almandine
Initialised with 31 image(s) found.
Output directory set to data/dataset/Almandine/output.

Processing almandine_2.jpg:   1%|▏         | 14/969 [00:00<00:27, 34.12 Samples/s] Processing almandine_25.jpg: 100%|██████████| 969/969 [00:22<00:00, 42.25 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705E020C90>: 100%|██████████| 31/31 [00:00<00:00, 173.21 Samples/s]                
Processing emerald_2.jpg:   1%|          | 10/964 [00:00<00:16, 58.72 Samples/s]

数据增强: data/dataset/Emerald
image: data/dataset/Emerald
Initialised with 36 image(s) found.
Output directory set to data/dataset/Emerald/output.

Processing emerald_36.jpg:   2%|▏         | 20/964 [00:00<00:17, 54.08 Samples/s]Processing emerald_15.jpg: 100%|██████████| 964/964 [00:26<00:00, 36.49 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=460x460 at 0x7F705DED0110>: 100%|██████████| 36/36 [00:00<00:00, 149.48 Samples/s]                   
Processing sapphire blue_9.jpg:   1%|          | 10/966 [00:00<00:13, 68.91 Samples/s]

数据增强: data/dataset/Sapphire Blue
image: data/dataset/Sapphire Blue
Initialised with 34 image(s) found.
Output directory set to data/dataset/Sapphire Blue/output.

Processing sapphire blue_16.jpg:   2%|▏         | 22/966 [00:00<00:16, 56.52 Samples/s]Processing sapphire blue_30.jpg: 100%|██████████| 966/966 [00:18<00:00, 53.08 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=450x450 at 0x7F706885B810>: 100%|██████████| 34/34 [00:00<00:00, 177.29 Samples/s]                  
Processing malachite_2.jpg:   1%|          | 10/972 [00:00<00:20, 47.64 Samples/s]

数据增强: data/dataset/Malachite
image: data/dataset/Malachite
Initialised with 28 image(s) found.
Output directory set to data/dataset/Malachite/output.

Processing malachite_16.jpg:   2%|▏         | 18/972 [00:00<00:20, 47.14 Samples/s]Processing malachite_22.jpg: 100%|██████████| 972/972 [00:18<00:00, 52.32 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=376x262 at 0x7F7060E93D10>: 100%|██████████| 28/28 [00:00<00:00, 173.34 Samples/s]
Processing alexandrite_0.jpg:   1%|          | 6/966 [00:00<00:24, 39.61 Samples/s] 

数据增强: data/dataset/Alexandrite
image: data/dataset/Alexandrite
Initialised with 34 image(s) found.
Output directory set to data/dataset/Alexandrite/output.

Processing alexandrite_23.jpg:   2%|▏         | 18/966 [00:00<00:21, 44.52 Samples/s]Processing alexandrite_20.jpg: 100%|██████████| 966/966 [00:20<00:00, 48.06 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=500x500 at 0x7F705E025B10>: 100%|██████████| 34/34 [00:00<00:00, 129.49 Samples/s]                 
Processing zircon_8.jpg:   1%|          | 5/967 [00:00<00:33, 28.43 Samples/s] 

数据增强: data/dataset/Zircon
image: data/dataset/Zircon
Initialised with 33 image(s) found.
Output directory set to data/dataset/Zircon/output.

Processing zircon_23.jpg:   1%|          | 6/967 [00:00<00:33, 28.43 Samples/s]Processing zircon_24.jpg: 100%|██████████| 967/967 [00:24<00:00, 38.88 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=500x500 at 0x7F705DEAC3D0>: 100%|██████████| 33/33 [00:00<00:00, 134.76 Samples/s]                 
Processing onyx black_16.jpg:   1%|          | 8/972 [00:00<00:13, 69.17 Samples/s]

数据增强: data/dataset/Onyx Black
image: data/dataset/Onyx Black
Initialised with 28 image(s) found.
Output directory set to data/dataset/Onyx Black/output.

Processing onyx black_6.jpg:   2%|▏         | 18/972 [00:00<00:18, 51.84 Samples/s] Processing onyx black_2.jpg: 100%|██████████| 972/972 [00:18<00:00, 53.19 Samples/s] 
Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DEE1910>: 100%|██████████| 28/28 [00:00<00:00, 131.50 Samples/s]                 
Processing rhodochrosite_29.jpg:   1%|          | 10/971 [00:00<00:18, 53.20 Samples/s]

数据增强: data/dataset/Rhodochrosite
image: data/dataset/Rhodochrosite
Initialised with 29 image(s) found.
Output directory set to data/dataset/Rhodochrosite/output.

Processing rhodochrosite_21.jpg:   2%|▏         | 21/971 [00:00<00:16, 58.01 Samples/s]Processing rhodochrosite_15.jpg: 100%|██████████| 971/971 [00:20<00:00, 46.42 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=373x356 at 0x7F705E011910>: 100%|██████████| 29/29 [00:00<00:00, 243.76 Samples/s]                  
Processing diamond_16.jpg:   1%|          | 5/969 [00:00<00:28, 34.31 Samples/s]

数据增强: data/dataset/Diamond
image: data/dataset/Diamond
Initialised with 31 image(s) found.
Output directory set to data/dataset/Diamond/output.

Processing diamond_6.jpg:   1%|          | 11/969 [00:00<00:26, 35.79 Samples/s] Processing diamond_20.jpg: 100%|██████████| 969/969 [00:24<00:00, 40.22 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=400x400 at 0x7F705DE6CCD0>: 100%|██████████| 31/31 [00:00<00:00, 150.83 Samples/s]                
Processing benitoite_29.jpg:   1%|          | 7/969 [00:00<00:15, 63.04 Samples/s]

数据增强: data/dataset/Benitoite
image: data/dataset/Benitoite
Initialised with 31 image(s) found.
Output directory set to data/dataset/Benitoite/output.

Processing benitoite_2.jpg:   2%|▏         | 24/969 [00:00<00:16, 57.15 Samples/s] Processing benitoite_12.jpg: 100%|██████████| 969/969 [00:17<00:00, 55.09 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=472x433 at 0x7F705DFE9290>: 100%|██████████| 31/31 [00:00<00:00, 178.70 Samples/s]                 
Processing pearl_0.jpg:   1%|          | 6/967 [00:00<00:25, 38.13 Samples/s] 

数据增强: data/dataset/Pearl
image: data/dataset/Pearl
Initialised with 33 image(s) found.
Output directory set to data/dataset/Pearl/output.

Processing pearl_32.jpg:   2%|▏         | 21/967 [00:00<00:20, 47.09 Samples/s]Processing pearl_12.jpg: 100%|██████████| 967/967 [00:17<00:00, 54.49 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=301x301 at 0x7F705E020A50>: 100%|██████████| 33/33 [00:00<00:00, 205.47 Samples/s]                 
Processing beryl golden_39.jpg:   1%|          | 11/964 [00:00<00:12, 79.36 Samples/s]

数据增强: data/dataset/Beryl Golden
image: data/dataset/Beryl Golden
Initialised with 36 image(s) found.
Output directory set to data/dataset/Beryl Golden/output.

Processing beryl golden_29.jpg:   2%|▏         | 22/964 [00:00<00:14, 63.92 Samples/s]Processing beryl golden_2.jpg: 100%|██████████| 964/964 [00:16<00:00, 58.61 Samples/s] 
Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DE6F910>: 100%|██████████| 36/36 [00:00<00:00, 273.71 Samples/s]                
Processing labradorite_16.jpg:   1%|          | 9/960 [00:00<00:17, 55.49 Samples/s]

数据增强: data/dataset/Labradorite
image: data/dataset/Labradorite
Initialised with 40 image(s) found.
Output directory set to data/dataset/Labradorite/output.

Processing labradorite_17.jpg:   2%|▏         | 20/960 [00:00<00:18, 52.03 Samples/s]Processing labradorite_11.jpg: 100%|██████████| 960/960 [00:21<00:00, 45.63 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=400x400 at 0x7F705DE70F10>: 100%|██████████| 40/40 [00:00<00:00, 117.40 Samples/s]                 
Processing fluorite_23.jpg:   1%|          | 11/968 [00:00<00:14, 65.24 Samples/s]

数据增强: data/dataset/Fluorite
image: data/dataset/Fluorite
Initialised with 32 image(s) found.
Output directory set to data/dataset/Fluorite/output.

Processing fluorite_4.jpg:   1%|▏         | 14/968 [00:00<00:19, 49.03 Samples/s] Processing fluorite_4.jpg: 100%|██████████| 968/968 [00:21<00:00, 44.39 Samples/s] 
Processing <PIL.Image.Image image mode=RGB size=500x442 at 0x7F705DE87CD0>: 100%|██████████| 32/32 [00:00<00:00, 169.43 Samples/s]                 
Processing iolite_2.jpg:   1%|          | 7/968 [00:00<00:24, 39.15 Samples/s] 

数据增强: data/dataset/Iolite
image: data/dataset/Iolite
Initialised with 32 image(s) found.
Output directory set to data/dataset/Iolite/output.

Processing iolite_35.jpg:   2%|▏         | 23/968 [00:00<00:18, 51.39 Samples/s]Processing iolite_23.jpg: 100%|██████████| 968/968 [00:16<00:00, 57.22 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DE764D0>: 100%|██████████| 32/32 [00:00<00:00, 373.16 Samples/s]                  
Processing quartz beer_24.jpg:   1%|          | 12/965 [00:00<00:16, 57.87 Samples/s]

数据增强: data/dataset/Quartz Beer
image: data/dataset/Quartz Beer
Initialised with 35 image(s) found.
Output directory set to data/dataset/Quartz Beer/output.

Processing quartz beer_28.jpg:   2%|▏         | 24/965 [00:00<00:14, 65.30 Samples/s]Processing quartz beer_30.jpg: 100%|██████████| 965/965 [00:16<00:00, 59.48 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=300x300 at 0x7F705DE82DD0>: 100%|██████████| 35/35 [00:00<00:00, 173.58 Samples/s]                 
Processing garnet red_21.jpg:   1%|          | 7/964 [00:00<00:34, 27.76 Samples/s]

数据增强: data/dataset/Garnet Red
image: data/dataset/Garnet Red
Initialised with 36 image(s) found.
Output directory set to data/dataset/Garnet Red/output.

Processing garnet red_2.jpg:   2%|▏         | 17/964 [00:00<00:28, 33.50 Samples/s] Processing garnet red_2.jpg: 100%|██████████| 964/964 [00:20<00:00, 46.97 Samples/s] 
Processing <PIL.Image.Image image mode=RGB size=301x301 at 0x7F705E020090>: 100%|██████████| 36/36 [00:00<00:00, 197.00 Samples/s]                 
Processing danburite_35.jpg:   1%|          | 8/968 [00:00<00:16, 58.65 Samples/s]

数据增强: data/dataset/Danburite
image: data/dataset/Danburite
Initialised with 32 image(s) found.
Output directory set to data/dataset/Danburite/output.

Processing danburite_32.jpg:   2%|▏         | 17/968 [00:00<00:19, 49.88 Samples/s]Processing danburite_23.jpg: 100%|██████████| 968/968 [00:19<00:00, 50.58 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705DE78390>: 100%|██████████| 32/32 [00:00<00:00, 144.25 Samples/s]                 
Processing cats eye_7.jpg:   1%|          | 8/969 [00:00<00:24, 39.01 Samples/s] 

数据增强: data/dataset/Cats Eye
image: data/dataset/Cats Eye
Initialised with 31 image(s) found.
Output directory set to data/dataset/Cats Eye/output.

Processing cats eye_26.jpg:   2%|▏         | 15/969 [00:00<00:23, 41.33 Samples/s]Processing cats eye_33.jpg: 100%|██████████| 969/969 [00:25<00:00, 38.19 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=401x401 at 0x7F706AF09510>: 100%|██████████| 31/31 [00:00<00:00, 214.03 Samples/s]                 
Processing hessonite_1.jpg:   0%|          | 3/970 [00:00<00:33, 28.84 Samples/s] 

数据增强: data/dataset/Hessonite
image: data/dataset/Hessonite
Initialised with 30 image(s) found.
Output directory set to data/dataset/Hessonite/output.

Processing hessonite_19.jpg:   1%|▏         | 13/970 [00:00<00:31, 30.34 Samples/s]Processing hessonite_33.jpg: 100%|██████████| 970/970 [00:20<00:00, 47.73 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=301x301 at 0x7F705E020610>: 100%|██████████| 30/30 [00:00<00:00, 162.33 Samples/s]                 
Processing carnelian_12.jpg:   1%|          | 5/967 [00:00<00:28, 34.19 Samples/s]

数据增强: data/dataset/Carnelian
image: data/dataset/Carnelian
Initialised with 33 image(s) found.
Output directory set to data/dataset/Carnelian/output.

Processing carnelian_32.jpg:   1%|          | 12/967 [00:00<00:29, 32.65 Samples/s]Processing carnelian_31.jpg: 100%|██████████| 967/967 [00:24<00:00, 39.93 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=425x425 at 0x7F705DE840D0>: 100%|██████████| 33/33 [00:00<00:00, 147.85 Samples/s]                 
Processing jade_26.jpg:   1%|          | 9/972 [00:00<00:25, 38.24 Samples/s]

数据增强: data/dataset/Jade
image: data/dataset/Jade
Initialised with 28 image(s) found.
Output directory set to data/dataset/Jade/output.

Processing jade_20.jpg:   2%|▏         | 22/972 [00:00<00:19, 47.93 Samples/s]Processing jade_18.jpg: 100%|██████████| 972/972 [00:18<00:00, 51.18 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DE8B050>: 100%|██████████| 28/28 [00:00<00:00, 331.02 Samples/s]                  
Processing variscite_22.jpg:   1%|          | 5/970 [00:00<00:25, 37.31 Samples/s]

数据增强: data/dataset/Variscite
image: data/dataset/Variscite
Initialised with 30 image(s) found.
Output directory set to data/dataset/Variscite/output.

Processing variscite_10.jpg:   1%|▏         | 13/970 [00:00<00:26, 35.70 Samples/s]Processing variscite_31.jpg: 100%|██████████| 970/970 [00:21<00:00, 45.58 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705DE7BE50>: 100%|██████████| 30/30 [00:00<00:00, 157.22 Samples/s]                 
Processing tanzanite_2.jpg:   1%|          | 5/964 [00:00<00:31, 30.52 Samples/s] 

数据增强: data/dataset/Tanzanite
image: data/dataset/Tanzanite
Initialised with 36 image(s) found.
Output directory set to data/dataset/Tanzanite/output.

Processing tanzanite_15.jpg:   2%|▏         | 15/964 [00:00<00:25, 36.60 Samples/s]Processing tanzanite_37.jpg: 100%|██████████| 964/964 [00:25<00:00, 38.41 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705E00E4D0>: 100%|██████████| 36/36 [00:00<00:00, 144.18 Samples/s]                 


将生成的图片拷贝到正确的目录
删除所有output目录
完成数据增强


#每次生成数据列表前,首先清空train.txt和eval.txt
with open(train_list_path, 'w') as f: 
    f.seek(0)
    f.truncate() 
with open(eval_list_path, 'w') as f: 
    f.seek(0)
    f.truncate() 
    
#生成数据列表   
get_data_list(target_path,train_list_path,eval_list_path,augment_path)

'''
构造数据提供器
'''
train_reader = paddle.batch(data_reader(train_list_path),
                            batch_size=batch_size,
                            drop_last=True)
eval_reader = paddle.batch(data_reader(eval_list_path),
                            batch_size=batch_size,
                            drop_last=True)
{'input_size': [3, 64, 64], 'class_dim': 25, 'augment_path': '/home/aistudio/augment', 'src_path': 'data/data55032/archive_train.zip', 'target_path': '/home/aistudio/data/dataset', 'train_list_path': './train_data.txt', 'eval_list_path': './val_data.txt', 'label_dict': {'0': 'Kunzite', '1': 'Almandine', '2': 'Emerald', '3': 'Sapphire Blue', '4': 'Malachite', '5': 'Alexandrite', '6': 'Zircon', '7': 'Onyx Black', '8': 'Rhodochrosite', '9': 'Diamond', '10': 'Benitoite', '11': 'Pearl', '12': 'Beryl Golden', '13': 'Labradorite', '14': 'Fluorite', '15': 'Iolite', '16': 'Quartz Beer', '17': 'Garnet Red', '18': 'Danburite', '19': 'Cats Eye', '20': 'Hessonite', '21': 'Carnelian', '22': 'Jade', '23': 'Variscite', '24': 'Tanzanite'}, 'readme_path': '/home/aistudio/data/readme.json', 'num_epochs': 20, 'train_batch_size': 64, 'learning_strategy': {'lr': 0.001}}
生成数据列表完成!
Batch=0
Batchs=[]
all_train_accs=[]
def draw_train_acc(Batchs, train_accs):
    title="training accs"
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel("acc", fontsize=14)
    plt.plot(Batchs, train_accs, color='green', label='training accs')
    plt.legend()
    plt.grid()
    plt.show()

all_train_loss=[]
def draw_train_loss(Batchs, train_loss):
    title="training loss"
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel("loss", fontsize=14)
    plt.plot(Batchs, train_loss, color='red', label='training loss')
    plt.legend()
    plt.grid()
    plt.show()

2.定义模型

###在以下cell中完成DNN网络的定义###

#定义网络
class MyDNN(fluid.dygraph.Layer):
    '''
    卷积神经网络
    '''
    def __init__(self):
        super(MyDNN,self).__init__()
        self.hidden1=fluid.dygraph.Linear(3*64*64,1000, act='relu')
        self.hidden2=fluid.dygraph.Linear(1000,500, act='relu')
        self.hidden3=fluid.dygraph.Linear(500,100, act='relu')
        self.out = fluid.dygraph.Linear(input_dim=100, output_dim=25, act='softmax')

    def forward(self,input):
        x = fluid.layers.reshape(input,shape=[-1,3*64*64])
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)
        return x

3.训练模型

with fluid.dygraph.guard(place = fluid.CUDAPlace(0)):
    print(train_parameters['class_dim'])
    print(train_parameters['label_dict'])
    model=MyDNN() #模型实例化
    model.train() #训练模式
    opt=fluid.optimizer.SGDOptimizer(learning_rate=train_parameters['learning_strategy']['lr'], parameter_list=model.parameters())#优化器选用SGD随机梯度下降,学习率为0.001.
    epochs_num=train_parameters['num_epochs'] #迭代次数
    
    for pass_num in range(epochs_num):
        for batch_id,data in enumerate(train_reader()):
            images = np.array([x[0] for x in data]).astype('float32').reshape(-1, 3,64,64)
            labels = np.array([x[1] for x in data]).astype('int64')
            labels = labels[:, np.newaxis]

            image=fluid.dygraph.to_variable(images)
            label=fluid.dygraph.to_variable(labels)

            predict=model(image) #数据传入model
            
            loss=fluid.layers.cross_entropy(predict,label)
            avg_loss=fluid.layers.mean(loss)#获取loss值
            
            acc=fluid.layers.accuracy(predict,label)#计算精度
            
            if batch_id!=0 and batch_id%5==0:
                Batch = Batch+5 
                Batchs.append(Batch)
                all_train_loss.append(avg_loss.numpy()[0])
                all_train_accs.append(acc.numpy()[0])
                
                print("train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(pass_num,batch_id,avg_loss.numpy(),acc.numpy()))
            
            avg_loss.backward()       
            opt.minimize(avg_loss)    #优化器对象的minimize方法对参数进行更新 
            model.clear_gradients()   #model.clear_gradients()来重置梯度
    fluid.save_dygraph(model.state_dict(),'MyDNN')#保存模型

draw_train_acc(Batchs,all_train_accs)
draw_train_loss(Batchs,all_train_loss)

train_pass:19,batch_id:400,train_loss:[0.24890603],train_acc:[0.96875]

4.模型评估

#模型评估
with fluid.dygraph.guard():
    accs = []
    model_dict, _ = fluid.load_dygraph('MyDNN')
    model = MyDNN()
    model.load_dict(model_dict) #加载模型参数
    model.eval() #训练模式
    for batch_id,data in enumerate(eval_reader()):#测试集
        images = np.array([x[0] for x in data]).astype('float32').reshape(-1, 3,64,64)
        labels = np.array([x[1] for x in data]).astype('int64')
        labels = labels[:, np.newaxis]
        image=fluid.dygraph.to_variable(images)
        label=fluid.dygraph.to_variable(labels)       
        predict=model(image)       
        acc=fluid.layers.accuracy(predict,label)
        accs.append(acc.numpy()[0])
        avg_acc = np.mean(accs)
    print(avg_acc)

0.96875

5.模型预测

import os
import zipfile

def unzip_infer_data(src_path,target_path):
    '''
    解压预测数据集
    '''
    if(not os.path.isdir(target_path)):     
        z = zipfile.ZipFile(src_path, 'r')
        z.extractall(path=target_path)
        z.close()


def load_image(img_path):
    '''
    预测图片预处理
    '''
    img = Image.open(img_path) 
    if img.mode != 'RGB': 
        img = img.convert('RGB') 
    img = img.resize((64, 64), Image.BILINEAR)
    img = np.array(img).astype('float32') 
    img = img.transpose((2, 0, 1))  # HWC to CHW 
    img = img/255                # 像素值归一化 
    return img


infer_src_path = '/home/aistudio/data/data55032/archive_test.zip'
infer_dst_path = '/home/aistudio/data/archive_test'
unzip_infer_data(infer_src_path,infer_dst_path)
label_dic = train_parameters['label_dict']

'''
模型预测
'''
with fluid.dygraph.guard():
    model_dict, _ = fluid.load_dygraph('MyDNN')
    model = MyDNN()
    model.load_dict(model_dict) #加载模型参数
    model.eval() #训练模式
    
    #展示预测图片
    infer_path='data/archive_test/alexandrite_3.jpg'
    img = Image.open(infer_path)
    plt.imshow(img)          #根据数组绘制图像
    plt.show()               #显示图像

    #对预测图片进行预处理
    infer_imgs = []
    infer_imgs.append(load_image(infer_path))
    infer_imgs = np.array(infer_imgs)
   
    for i in range(len(infer_imgs)):
        data = infer_imgs[i]
        dy_x_data = np.array(data).astype('float32')
        dy_x_data=dy_x_data[np.newaxis,:, : ,:]
        img = fluid.dygraph.to_variable(dy_x_data)
        out = model(img)
        lab = np.argmax(out.numpy())  #argmax():返回最大数的索引

        print("第{}个样本,被预测为:{},真实标签为:{}".format(i+1,label_dic[str(lab)],infer_path.split('/')[-1].split("_")[0]))
        
print("结束")

第1个样本,被预测为:Malachite,真实标签为:alexandrite 结束

以上就是Python利用DNN实现宝石识别的详细内容,更多关于Python DNN宝石识别的资料请关注脚本之家其它相关文章!

相关文章

  • Python字符编码转码之GBK,UTF8互转

    Python字符编码转码之GBK,UTF8互转

    说到python的编码,一句话总结,说多了都是泪啊,这个在以后的python的开发中绝对是一件令人头疼的事情。所以有必要输入理解
    2020-02-02
  • Python私有属性私有方法应用实例解析

    Python私有属性私有方法应用实例解析

    这篇文章主要介绍了Python私有属性私有方法应用场景解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-09-09
  • Python爬虫实现抓取京东店铺信息及下载图片功能示例

    Python爬虫实现抓取京东店铺信息及下载图片功能示例

    这篇文章主要介绍了Python爬虫实现抓取京东店铺信息及下载图片功能,涉及Python页面请求、响应、解析等相关操作技巧,需要的朋友可以参考下
    2018-08-08
  • 用python写一个定时提醒程序的实现代码

    用python写一个定时提醒程序的实现代码

    今天小编就为大家分享一篇用python写一个定时提醒程序的实现代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • python生成式的send()方法(详解)

    python生成式的send()方法(详解)

    下面小编就为 大家带来一篇python生成式的send()方法(详解)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-05-05
  • python修改linux中文件(文件夹)的权限属性操作

    python修改linux中文件(文件夹)的权限属性操作

    这篇文章主要介绍了python修改linux中文件(文件夹)的权限属性操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • Pycharm中无法使用pip安装的包问题解决方案

    Pycharm中无法使用pip安装的包问题解决方案

    本文主要介绍了Pycharm中无法使用pip安装的包问题解决方案,在终端通过pip装好包以后,在pycharm中导入包时,依然会报错,下面就来介绍一下解决方法
    2023-09-09
  • 在tensorflow实现直接读取网络的参数(weight and bias)的值

    在tensorflow实现直接读取网络的参数(weight and bias)的值

    这篇文章主要介绍了在tensorflow实现直接读取网络的参数(weight and bias)的值,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • Python实现中文数字转换为阿拉伯数字的方法示例

    Python实现中文数字转换为阿拉伯数字的方法示例

    这篇文章主要介绍了Python实现中文数字转换为阿拉伯数字的方法,涉及Python字符串遍历、转换相关操作技巧,需要的朋友可以参考下
    2017-05-05
  • python基本数据类型练习题

    python基本数据类型练习题

    这篇文章主要介绍了python基本数据类型,Python 中的变量不需要声明。每个变量在使用前都必须赋值,变量赋值以后该变量才会被创建。在 Python 中,变量就是变量,它没有类型,我们所说的"类型"是变量所指的内存中对象的类型。下面举例说明改内容,,需要的朋友可以参考一下
    2022-01-01

最新评论