Keras 数据增强ImageDataGenerator多输入多输出实例

 更新时间:2020年07月03日 14:14:04   作者:青盏  
这篇文章主要介绍了Keras 数据增强ImageDataGenerator多输入多输出实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

我就废话不多说了,大家还是直接看代码吧~

import os 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]=""
import sys
import gc
import time
import cv2
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

from random_eraser import get_random_eraser
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

datagen = ImageDataGenerator(
  rotation_range=20,   #旋转
  width_shift_range=0.1,  #水平位置平移
#   height_shift_range=0.2,  #上下位置平移
  shear_range=0.5,    #错切变换,让所有点的x坐标(或者y坐标)保持不变,而对应的y坐标(或者x坐标)则按比例发生平移
  zoom_range=[0.9,0.9],  # 单方向缩放,当一个数值时两个方向等比例缩放,参数为list时长宽不同程度缩放。参数大于0小于1时,执行的是放大操作,当参数大于1时,执行的是缩小操作。
  channel_shift_range = 40, #偏移通道数值,改变图片颜色,越大颜色越深
  horizontal_flip=True,  #水平翻转,垂直翻转vertical_flip
  fill_mode='nearest',   #操作导致图像缺失时填充方式。“constant”、“nearest”(默认)、“reflect”和“wrap”
  preprocessing_function = get_random_eraser(p=0.7,v_l=0,v_h=255,s_l=0.01,s_h=0.03,r_1=1,r_2=1.5,pixel_level=True)
  )

# train_generator = datagen.flow_from_directory(
#       'base/Images/',
#       save_to_dir = 'base/fake/',
#       batch_size=1
#       )
# for i in range(5):
#  train_generator.next()

# !
# df_train = pd.read_csv('base/Annotations/label.csv', header=None)
# df_train.columns = ['image_id', 'class', 'label']
# classes = ['collar_design_labels', 'neckline_design_labels', 'skirt_length_labels', 
#   'sleeve_length_labels', 'neck_design_labels', 'coat_length_labels', 'lapel_design_labels', 
#   'pant_length_labels']
# !

# classes = ['collar_design_labels']

# !
# for i in range(len(classes)):
#  gc.enable()

# #  单个分类
#  cur_class = classes[i]
#  df_load = df_train[(df_train['class'] == cur_class)].copy()
#  df_load.reset_index(inplace=True)
#  del df_load['index']

# #  print(cur_class)

# #  加载数据和label
#  n = len(df_load)
# #  n_class = len(df_load['label'][0])
# #  width = 256

# #  X = np.zeros((n,width, width, 3), dtype=np.uint8)
# #  y = np.zeros((n, n_class), dtype=np.uint8)

#  print(f'starting load trainset {cur_class} {n}')
#  sys.stdout.flush()
#  for i in tqdm(range(n)):
# #   tmp_label = df_load['label'][i]
#   img = load_img('base/{0}'.format(df_load['image_id'][i]))
#   x = img_to_array(img)
#   x = x.reshape((1,) + x.shape)
#   m=0
#   for batch in datagen.flow(x,batch_size=1):
# #    plt.imshow(array_to_img(batch[0]))
# #    print(batch)
#    array_to_img(batch[0]).save(f'base/fake/{format(df_load["image_id"][i])}-{m}.jpg')
#    m+=1
#    if m>3:
#     break
#  gc.collect()
# !  

img = load_img('base/Images/collar_design_labels/2f639f11de22076ead5fe1258eae024d.jpg')
plt.figure()
plt.imshow(img)
x = img_to_array(img)

x = x.reshape((1,) + x.shape)

i = 0
for batch in datagen.flow(x,batch_size=5):
 plt.figure()
 plt.imshow(array_to_img(batch[0]))
#  print(len(batch))
 i += 1
 if i >0:
  break
#多输入,设置随机种子
# Define the image transformations here
gen = ImageDataGenerator(horizontal_flip = True,
       vertical_flip = True,
       width_shift_range = 0.1,
       height_shift_range = 0.1,
       zoom_range = 0.1,
       rotation_range = 40)

# Here is the function that merges our two generators
# We use the exact same generator with the same random seed for both the y and angle arrays
def gen_flow_for_two_inputs(X1, X2, y):
 genX1 = gen.flow(X1,y, batch_size=batch_size,seed=666)
 genX2 = gen.flow(X1,X2, batch_size=batch_size,seed=666)
 while True:
   X1i = genX1.next()
   X2i = genX2.next()
   #Assert arrays are equal - this was for peace of mind, but slows down training
   #np.testing.assert_array_equal(X1i[0],X2i[0])
   yield [X1i[0], X2i[1]], X1i[1]
#手动构造,直接输出多label
generator = ImageDataGenerator(rotation_range=5.,
        width_shift_range=0.1, 
        height_shift_range=0.1, 
        horizontal_flip=True, 
        vertical_flip=True)

def generate_data_generator(generator, X, Y1, Y2):
 genX = generator.flow(X, seed=7)
 genY1 = generator.flow(Y1, seed=7)
 while True:
   Xi = genX.next()
   Yi1 = genY1.next()
   Yi2 = function(Y2)
   yield Xi, [Yi1, Yi2]
model.fit_generator(generate_data_generator(generator, X, Y1, Y2),
    epochs=epochs)
def batch_generator(generator,X,Y):
 Xgen = generator.flow(X)
 while True:
  yield Xgen.next(),Y
h = model.fit_generator(batch_generator(datagen, X_all, y_all), 
       steps_per_epoch=len(X_all)//32+1,
       epochs=80,workers=3,
       callbacks=[EarlyStopping(patience=3), checkpointer,ReduceLROnPlateau(monitor='val_loss',factor=0.5,patience=1)], 
       validation_data=(X_val,y_val))

补充知识:读取图片成numpy数组,裁剪并保存 和 数据增强(ImageDataGenerator)

我就废话不多说了,大家还是直接看代码吧~

from PIL import Image
import numpy as np
from PIL import Image
from keras.preprocessing import image
import matplotlib.pyplot as plt
import os
import cv2
# from scipy.misc import toimage
import matplotlib
# 生成图片地址和对应标签
file_dir = '../train/'
image_list = []
label_list = []
cate = [file_dir + x for x in os.listdir(file_dir) if os.path.isdir(file_dir + x)]
for name in cate:
 temp = name.split('/')
 path = '../train_new/' + temp[-1]
 isExists = os.path.exists(path)
 if not isExists:
  os.makedirs(path) # 目录不存在则创建
 class_path = name + "/"

 for file in os.listdir(class_path):
  print(file)
  img_obj = Image.open(class_path + file) # 读取图片
  img_array = np.array(img_obj)
  resized = cv2.resize(img_array, (256, 256)) # 裁剪
  resized = resized.astype('float32')
  resized /= 255.
  # plt.imshow(resized)
  # plt.show()
  save_path = path + '/' + file
  matplotlib.image.imsave(save_path, resized) # 保存

keras之数据增强

from PIL import Image
import numpy as np
from PIL import Image
from keras.preprocessing import image
import os
import cv2
# 生成图片地址和对应标签
file_dir = '../train/'

label_list = []
cate = [file_dir + x for x in os.listdir(file_dir) if os.path.isdir(file_dir + x)]
for name in cate:
 image_list = []
 class_path = name + "/"
 for file in os.listdir(class_path):
  image_list.append(class_path + file)
 batch_size = 64
 if len(image_list) < 10000:
  num = int(10000 / len(image_list))
 else:
  num = 0
 # 设置生成器参数
 datagen = image.ImageDataGenerator(fill_mode='wrap', # 填充模式
          rotation_range=40, # 指定旋转角度范围
          width_shift_range=0.2, # 水平位置平移
          height_shift_range=0.2, # 上下位置平移
          horizontal_flip=True, # 随机对图片执行水平翻转操作
          vertical_flip=True, # 对图片执行上下翻转操作
          shear_range=0.2,
          rescale=1./255, # 缩放
          data_format='channels_last')
 if num > 0:
  temp = name.split('/')
  path = '../train_datage/' + temp[-1]
  isExists = os.path.exists(path)
  if not isExists:
   os.makedirs(path)

  for image_path in image_list:
   i = 1
   img_obj = Image.open(image_path) # 读取图片
   img_array = np.array(img_obj)
   x = img_array.reshape((1,) + img_array.shape)  #要求为4维
   name_image = image_path.split('/')
   print(name_image)
   for batch in datagen.flow(x,
        batch_size=1,
        save_to_dir=path,
        save_prefix=name_image[-1][:-4] + '_',
        save_format='jpg'):
    i += 1
    if i > num:
     break

以上这篇Keras 数据增强ImageDataGenerator多输入多输出实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python+Delorean实现时间格式智能转换

    Python+Delorean实现时间格式智能转换

    DeLorean是一个Python的第三方模块,基于 pytz 和 dateutil 开发,用于处理Python中日期时间的格式转换。本文将详细讲讲DeLorean的使用,感兴趣的可以了解一下
    2022-04-04
  • 利用Python破解生日悖论问题

    利用Python破解生日悖论问题

    生日悖论,就是23个人在一个房间,期间必然有两个人生日相同的概率为50%,30个人的话概率是70%,60个人甚至上升到99%。本文就来用Python语言破解这一问题,感兴趣的可以了解一下
    2022-12-12
  • 解决TensorFlow调用Keras库函数存在的问题

    解决TensorFlow调用Keras库函数存在的问题

    这篇文章主要介绍了解决TensorFlow调用Keras库函数存在的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-07-07
  • Python实现简繁体转换

    Python实现简繁体转换

    很多时候简繁体转换,掌握了简体与繁体的转换,往往能够事半功倍,本文主要介绍了Python实现简繁体转换,感兴趣的可以了解一下
    2021-06-06
  • 教你用python3根据关键词爬取百度百科的内容

    教你用python3根据关键词爬取百度百科的内容

    这篇文章介绍的是利用python3根据关键词爬取百度百科的内容,注意本文用的是python3版本以及根据关键词爬取,爬取也只是单纯的爬网页信息,有需要的可以参考借鉴。
    2016-08-08
  • 解决Python字典写入文件出行首行有空格的问题

    解决Python字典写入文件出行首行有空格的问题

    下面小编就为大家带来一篇解决Python字典写入文件出行首行有空格的问题。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-09-09
  • python实现桌面气泡提示功能

    python实现桌面气泡提示功能

    这篇文章主要为大家详细介绍了python实现桌面气泡提示功能,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-07-07
  • python关于矩阵重复赋值覆盖问题的解决方法

    python关于矩阵重复赋值覆盖问题的解决方法

    这篇文章主要介绍了python关于矩阵重复赋值覆盖问题的解决方法,涉及Python深拷贝与浅拷贝相关操作与使用技巧,需要的朋友可以参考下
    2019-07-07
  • 基于keras 模型、结构、权重保存的实现

    基于keras 模型、结构、权重保存的实现

    今天小编就为大家分享一篇基于keras 模型、结构、权重保存的实现,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • python 读取文本文件的行数据,文件.splitlines()的方法

    python 读取文本文件的行数据,文件.splitlines()的方法

    今天小编就为大家分享一篇python 读取文本文件的行数据,文件.splitlines()的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07

最新评论