keras多显卡训练方式

 更新时间:2020年06月10日 15:32:16   作者:深夜虫鸣  
这篇文章主要介绍了keras多显卡训练方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

 generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 聊聊pytorch测试的时候为何要加上model.eval()

    聊聊pytorch测试的时候为何要加上model.eval()

    这篇文章主要介绍了聊聊pytorch测试的时候为何要加上model.eval()的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • Python图形验证码识别教程详解

    Python图形验证码识别教程详解

    这篇文章主要介绍了Python图形验证码识别,目前,许多网站采取各种各样的措施来反爬虫,其中一个措施便是使用验证码。随着技术的发展,验证码的花样越来越多。验证码最初是几个数字组合的简单的图形验证码,后来加入了英文字母和混淆曲线
    2023-02-02
  • Python 变量类型实例详解

    Python 变量类型实例详解

    这篇文章主要介绍了Python 变量类型实例详解,基于变量的数据类型,解释器会分配指定内存,并决定什么数据可以被存储在内存中,接下来更多详细内容需要的小伙伴可以参考下面文章,希望对你有所帮助
    2022-02-02
  • 报错No module named numpy问题的解决办法

    报错No module named numpy问题的解决办法

    之前安装了Python,后来因为练习使用Python写科学计算的东西,又安装了Anaconda,但是安装Anaconda之后又出现了一个问题,下面这篇文章主要给大家介绍了关于报错No module named numpy问题的解决办法,需要的朋友可以参考下
    2022-08-08
  • Python标准库copy的具体使用

    Python标准库copy的具体使用

    copy模块是Python标准库中用于对象拷贝的核心模块,提供了浅拷贝(copy)和深拷贝(deepcopy)两种对象复制机制,本文主要介绍了Python标准库copy的具体使用,感兴趣的可以了解一下
    2025-04-04
  • 使用python采集Excel表中某一格数据

    使用python采集Excel表中某一格数据

    这篇文章主要介绍了使用python采集Excel表中某一格数据,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • 对django中render()与render_to_response()的区别详解

    对django中render()与render_to_response()的区别详解

    今天小编就为大家分享一篇对django中render()与render_to_response()的区别详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python 中的 list、tuple、set、dict的底层实现小结

    Python 中的 list、tuple、set、dict的底层实现小结

    本文详细介绍了Python中四种常用数据结构——list、tuple、set和dict的底层实现,包括它们的存储方式、性能特点以及适用场景,感兴趣的朋友一起看看吧
    2025-03-03
  • 关于windos10环境下编译python3版pjsua库的问题

    关于windos10环境下编译python3版pjsua库的问题

    pjsua默认绑定的python版本是python 2.4,使用起来有诸多限制,希望可以使用python3调用pjsua的库实现软电话的基础功能。这篇文章主要介绍了windos10环境下编译python3版pjsua库,需要的朋友可以参考下
    2021-10-10
  • Tensorflow 合并通道及加载子模型的方法

    Tensorflow 合并通道及加载子模型的方法

    今天小编就为大家分享一篇Tensorflow 合并通道及加载子模型的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07

最新评论