keras的siamese(孪生网络)实现案例

 更新时间:2020年06月12日 14:20:25   作者:李上花开  
这篇文章主要介绍了keras的siamese(孪生网络)实现案例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python标准库之collections包的使用教程

    Python标准库之collections包的使用教程

    这篇文章主要给大家介绍了Python标准库之collections包的使用教程,详细介绍了collections中多个集合类的使用方法,相信对大家具有一定的参考价值,需要的朋友们下面随小编一起来学习学习吧。
    2017-04-04
  • Django实战之用户认证(初始配置)

    Django实战之用户认证(初始配置)

    这篇文章主要介绍了Django实战之用户认证(初始配置),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-07-07
  • wxPython之解决闪烁的问题

    wxPython之解决闪烁的问题

    下面小编就为大家分享一篇wxPython之解决闪烁的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-01-01
  • OpenCV半小时掌握基本操作之色彩空间

    OpenCV半小时掌握基本操作之色彩空间

    这篇文章主要介绍了OpenCV基本操作之色彩空间,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-09-09
  • 基于python实现模拟数据结构模型

    基于python实现模拟数据结构模型

    这篇文章主要介绍了基于python实现模拟数据结构模型,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-06-06
  • python根据路径导入模块的方法

    python根据路径导入模块的方法

    这篇文章主要介绍了python根据路径导入模块的方法,分析了传统方法与改进方法,具有一定的实用价值,需要的朋友可以参考下
    2014-09-09
  • Python代码中如何读取键盘录入的值

    Python代码中如何读取键盘录入的值

    在本篇文章里小编给大家分享的是关于Python代码中读取键盘录入值的方法,需要的朋友们可以参考下。
    2020-05-05
  • Python计算当前日期是一年中的第几天的方法详解

    Python计算当前日期是一年中的第几天的方法详解

    在Python中,计算当前日期是一年中的第几天可以通过内置的datetime模块来实现,本文将详细介绍如何使用Python编写代码来完成这个任务,需要的可以参考下
    2023-12-12
  • Flask搭建api服务的实现步骤

    Flask搭建api服务的实现步骤

    本文主要介绍了Flask搭建api服务的实现步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-06-06
  • python3 打开外部程序及关闭的示例

    python3 打开外部程序及关闭的示例

    今天小编就为大家分享一篇python3 打开外部程序及关闭的示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11

最新评论