Keras实现将两个模型连接到一起

 更新时间:2020年05月23日 10:04:10   作者:木盏  
这篇文章主要介绍了Keras实现将两个模型连接到一起,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python中jsonpath的使用小结

    python中jsonpath的使用小结

    JsonPath是一种信息抽取类库,是从JSON文档中抽取指定信息的工具,提供多种语言实现版本,本文主要介绍了python中jsonpath的使用小结,具有一定的参考价值,感兴趣的可以了解一下
    2024-03-03
  • tkinter动态显示时间的两种实现方法

    tkinter动态显示时间的两种实现方法

    这篇文章主要介绍了tkinter动态显示时间的两种实现方法,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-01-01
  • python 时间的访问和转换 time示例小结

    python 时间的访问和转换 time示例小结

    Python 的 time 模块提供了各种与时间处理相关的功能,包括获取当前时间、操作日期/时间以及执行与时间相关的各种其它功能,这篇文章主要介绍了python 时间的访问和转换 time,需要的朋友可以参考下
    2024-05-05
  • 浅谈python 中类属性共享的问题

    浅谈python 中类属性共享的问题

    今天小编就为大家分享一篇浅谈python 中类属性共享的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • Python自动化部署工具Fabric的简单上手指南

    Python自动化部署工具Fabric的简单上手指南

    这篇文章主要介绍了Python自动化部署工具Fabric的简单上手指南,涵盖Fabric的安装、fabric的远程操作与维护等方面,需要的朋友可以参考下
    2016-04-04
  • Python详细讲解浅拷贝与深拷贝的使用

    Python详细讲解浅拷贝与深拷贝的使用

    这篇文章主要介绍了Python中的深拷贝和浅拷贝,通过讲解Python中的浅拷贝和深拷贝的概念和背后的原理展开全文,需要的小伙伴可以参考一下
    2022-07-07
  • Django之使用celery和NGINX生成静态页面实现性能优化

    Django之使用celery和NGINX生成静态页面实现性能优化

    这篇文章主要介绍了Django之使用celery和NGINX生成静态页面实现性能优化,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-10-10
  • 详解Python流程控制语句

    详解Python流程控制语句

    这篇文章主要介绍了Python流程控制语句的的相关资料,帮助大家更好的理解和学习python,感兴趣的朋友可以了解下
    2020-10-10
  • matplotlib基础绘图命令之imshow的使用

    matplotlib基础绘图命令之imshow的使用

    这篇文章主要介绍了matplotlib基础绘图命令之imshow的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-08-08
  • Python中单引号、双引号和三引号具体的用法及注意点

    Python中单引号、双引号和三引号具体的用法及注意点

    这篇文章主要给大家介绍了关于Python中单引号、双引号和三引号具体的用法及注意点的相关资料,Python中单引号、双引号、三引号中使用常常困惑,想弄明白这三者相同点和不同点,需要的朋友可以参考下
    2023-07-07

最新评论