python机器学习之神经网络实现

 更新时间:2021年05月08日 11:48:44   作者:Yaniesta  
这篇文章主要为大家详细介绍了python机器学习之神经网络的实现方法,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

神经网络在机器学习中有很大的应用,甚至涉及到方方面面。本文主要是简单介绍一下神经网络的基本理论概念和推算。同时也会介绍一下神经网络在数据分类方面的应用。

首先,当我们建立一个回归和分类模型的时候,无论是用最小二乘法(OLS)还是最大似然值(MLE)都用来使得残差达到最小。因此我们在建立模型的时候,都会有一个loss function。

而在神经网络里也不例外,也有个类似的loss function。

对回归而言:

对分类而言:

然后同样方法,对于W开始求导,求导为零就可以求出极值来。

关于式子中的W。我们在这里以三层的神经网络为例。先介绍一下神经网络的相关参数。

第一层是输入层,第二层是隐藏层,第三层是输出层。

在X1,X2经过W1的加权后,达到隐藏层,然后经过W2的加权,到达输出层

其中,

我们有:

至此,我们建立了一个初级的三层神经网络。

当我们要求其的loss function最小时,我们需要逆向来求,也就是所谓的backpropagation。

我们要分别对W1和W2进行求导,然后求出其极值。

从右手边开始逆推,首先对W2进行求导。

代入损失函数公式:

然后,我们进行化简:

化简到这里,我们同理再对W1进行求导。

我们可以发现当我们在做bp网络时候,有一个逆推回去的误差项,其决定了loss function 的最终大小。

在实际的运算当中,我们会用到梯度求解,来求出极值点。

总结一下来说,我们使用向前推进来理顺神经网络做到回归分类等模型。而向后推进来计算他的损失函数,使得参数W有一个最优解。

当然,和线性回归等模型相类似的是,我们也可以加上正则化的项来对W参数进行约束,以免使得模型的偏差太小,而导致在测试集的表现不佳。

Python 的实现:

使用了KERAS的库

解决线性回归: 

model.add(Dense(1, input_dim=n_features, activation='linear', use_bias=True))

# Use mean squared error for the loss metric and use the ADAM backprop algorithm
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the network (learn the weights)
# We need to convert from DataFrame to NumpyArray
history = model.fit(X_train.values, y_train.values, epochs=100, 
     batch_size=1, verbose=2, validation_split=0)

解决多重分类问题: 

# create model
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_features))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
# Softmax output layer
model.add(Dense(7, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X_train.values, y_train.values, epochs=20, batch_size=16)

y_pred = model.predict(X_test.values)

y_te = np.argmax(y_test.values, axis = 1)
y_pr = np.argmax(y_pred, axis = 1)

print(np.unique(y_pr))

print(classification_report(y_te, y_pr))

print(confusion_matrix(y_te, y_pr))

当我们选取最优参数时候,有很多种解决的途径。这里就介绍一种是gridsearchcv的方法,这是一种暴力检索的方法,遍历所有的设定参数来求得最优参数。

from sklearn.model_selection import GridSearchCV

def create_model(optimizer='rmsprop'):
 model = Sequential()
 model.add(Dense(64, activation='relu', input_dim=n_features))
 model.add(Dropout(0.5))
 model.add(Dense(64, activation='relu'))
 model.add(Dropout(0.5))
 model.add(Dense(7, activation='softmax'))
 model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
 
 return model

model = KerasClassifier(build_fn=create_model, verbose=0)

optimizers = ['rmsprop']
epochs = [5, 10, 15]
batches = [128]


param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, verbose=['2'])
grid = GridSearchCV(estimator=model, param_grid=param_grid)

grid.fit(X_train.values, y_train.values)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • django中SMTP发送邮件配置详解

    django中SMTP发送邮件配置详解

    这篇文章主要介绍了django中SMTP发送邮件配置,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • Python实现像awk一样分割字符串

    Python实现像awk一样分割字符串

    这篇文章主要介绍了Python实现像awk一样分割字符串,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-09-09
  • wxpython 学习笔记 第一天

    wxpython 学习笔记 第一天

    学习wxpython的朋友,可以看下,了解下wxpython
    2009-03-03
  • python可视化分析绘制带趋势线的散点图和边缘直方图

    python可视化分析绘制带趋势线的散点图和边缘直方图

    这篇文章主要介绍了python可视化分析绘制带趋势线的散点图和边缘直方图,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-06-06
  • Python中处理表格数据的Tablib库详解

    Python中处理表格数据的Tablib库详解

    这篇文章主要介绍了Python中处理表格数据的Tablib库详解,Tablib 是一个 MIT 许可的格式不可知的表格数据集库,用 Python 编写,它允许您导入、导出和操作表格数据集,需要的朋友可以参考下
    2023-08-08
  • Python smtp邮件发送模块用法教程

    Python smtp邮件发送模块用法教程

    这篇文章主要介绍了Python smtp邮件发送模块用法教程,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-06-06
  • Python使用HTTP POST上传WAV文件的方法

    Python使用HTTP POST上传WAV文件的方法

    Python是一个非常流行的编程语言,可以用于开发不同类型的应用程序。其中,上传文件是一个非常常见的需求。具体而言,我们探讨了使用HTTP POST请求上传单个和多个WAV文件的方法。无论你是需要将音频文件上传到云存储还是服务器,这些方法都能帮助你上传文件。
    2023-06-06
  • Django 使用Ajax进行前后台交互的示例讲解

    Django 使用Ajax进行前后台交互的示例讲解

    今天小编就为大家分享一篇Django 使用Ajax进行前后台交互的示例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • Python模糊查询本地文件夹去除文件后缀的实例(7行代码)

    Python模糊查询本地文件夹去除文件后缀的实例(7行代码)

    下面小编就为大家带来一篇Python模糊查询本地文件夹去除文件后缀的实例(7行代码) 。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-11-11
  • python爬虫-模拟微博登录功能

    python爬虫-模拟微博登录功能

    这篇文章主要介绍了python爬虫-模拟微博登录功能,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-09-09

最新评论