Pytorch训练模型得到输出后计算F1-Score 和AUC的操作

 更新时间:2021年05月14日 08:40:23   作者:烟雨人长安  
这篇文章主要介绍了Pytorch训练模型得到输出后计算F1-Score 和AUC的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

1、计算F1-Score

对于二分类来说,假设batch size 大小为64的话,那么模型一个batch的输出应该是torch.size([64,2]),所以首先做的是得到这个二维矩阵的每一行的最大索引值,然后添加到一个列表中,同时把标签也添加到一个列表中,最后使用sklearn中计算F1的工具包进行计算,代码如下

import numpy as np
import sklearn.metrics import f1_score
prob_all = []
lable_all = []
for i, (data,label) in tqdm(train_data_loader):
    prob = model(data) #表示模型的预测输出
    prob = prob.cpu().numpy() #先把prob转到CPU上,然后再转成numpy,如果本身在CPU上训练的话就不用先转成CPU了
    prob_all.extend(np.argmax(prob,axis=1)) #求每一行的最大值索引
    label_all.extend(label)
print("F1-Score:{:.4f}".format(f1_score(label_all,prob_all)))

2、计算AUC

计算AUC的时候,本次使用的是sklearn中的roc_auc_score () 方法

输入参数:

y_true:真实的标签。形状 (n_samples,) 或 (n_samples, n_classes)。二分类的形状 (n_samples,1),而多标签情况的形状 (n_samples, n_classes)。

y_score:目标分数。形状 (n_samples,) 或 (n_samples, n_classes)。二分类情况形状 (n_samples,1),“分数必须是具有较大标签的类的分数”,通俗点理解:模型打分的第二列。举个例子:模型输入的得分是一个数组 [0.98361117 0.01638886],索引是其类别,这里 “较大标签类的分数”,指的是索引为 1 的分数:0.01638886,也就是正例的预测得分。

average='macro':二分类时,该参数可以忽略。用于多分类,' micro ':将标签指标矩阵的每个元素看作一个标签,计算全局的指标。' macro ':计算每个标签的指标,并找到它们的未加权平均值。这并没有考虑标签的不平衡。' weighted ':计算每个标签的指标,并找到它们的平均值,根据支持度 (每个标签的真实实例的数量) 进行加权。

sample_weight=None:样本权重。形状 (n_samples,),默认 = 无。

max_fpr=None

multi_class='raise':(多分类的问题在下一篇文章中解释)

labels=None

输出:

auc:是一个 float 的值。

import numpy as np
import sklearn.metrics import roc_auc_score
prob_all = []
lable_all = []
for i, (data,label) in tqdm(train_data_loader):
    prob = model(data) #表示模型的预测输出
    prob_all.extend(prob[:,1].cpu().numpy()) #prob[:,1]返回每一行第二列的数,根据该函数的参数可知,y_score表示的较大标签类的分数,因此就是最大索引对应的那个值,而不是最大索引值
    label_all.extend(label)
print("AUC:{:.4f}".format(roc_auc_score(label_all,prob_all)))

补充:pytorch训练模型的一些坑

1. 图像读取

opencv的python和c++读取的图像结果不一致,是因为python和c++采用的opencv版本不一样,从而使用的解码库不同,导致读取的结果不同。

2. 图像变换

PIL和pytorch的图像resize操作,与opencv的resize结果不一样,这样会导致训练采用PIL,预测时采用opencv,结果差别很大,尤其是在检测和分割任务中比较明显。

3. 数值计算

pytorch的torch.exp与c++的exp计算,10e-6的数值时候会有10e-3的误差,对于高精度计算需要特别注意,比如

两个输入5.601597, 5.601601, 经过exp计算后变成270.85862343143174, 270.85970686809225

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python之基础函数案例详解

    Python之基础函数案例详解

    这篇文章主要介绍了Python之基础函数案例详解,本篇文章通过简要的案例,讲解了该项技术的了解与使用,以下就是详细内容,需要的朋友可以参考下
    2021-08-08
  • Pytorch随机数生成常用的4种方法汇总

    Pytorch随机数生成常用的4种方法汇总

    随机数广泛应用在科学研究,但是计算机无法产生真正的随机数,一般成为伪随机数,下面这篇文章主要给大家介绍了关于Pytorch随机数生成常用的4种方法,需要的朋友可以参考下
    2023-05-05
  • Python如何在脚本中设置环境变量

    Python如何在脚本中设置环境变量

    环境变量是与系统进程交互的一种深入方式,它允许用户获得有关系统属性、路径和已经存在的变量的更详细信息,下面我们就来看看Python是如何通过脚本来设置环境变量的吧
    2023-10-10
  • 利用Opencv实现图片的油画特效实例

    利用Opencv实现图片的油画特效实例

    这篇文章主要给大家介绍了关于利用Opencv实现图片的油画特效的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-02-02
  • Python写的Socks5协议代理服务器

    Python写的Socks5协议代理服务器

    这篇文章主要介绍了Python写的Socks5协议代理服务器,代码来自网上,需要的朋友可以参考下
    2014-08-08
  • Python使用enumerate获取迭代元素下标

    Python使用enumerate获取迭代元素下标

    这篇文章主要介绍了python使用enumerate获取迭代元素下标,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-02-02
  • 基于Python实现将列表数据生成折线图

    基于Python实现将列表数据生成折线图

    这篇文章主要介绍了如何利用Python中的pandas库和matplotlib库,实现将列表数据生成折线图,文中的示例代码简洁易懂,需要的可以参考一下
    2022-03-03
  • Flask与FastAPI对比选择最佳Python Web框架的超详细指南

    Flask与FastAPI对比选择最佳Python Web框架的超详细指南

    Flask和FastAPI都是流行的Python Web框架,各有特点,Flask轻量级、灵活,适合小型项目和原型开发但不支持异步操作,FastAPI高性能、支持异步,内置数据验证和自动生成API文档,适合高并发和API开发,需要的朋友可以参考下
    2025-02-02
  • Python循环语句之while循环和for循环详解

    Python循环语句之while循环和for循环详解

    在Python中,循环语句用于重复执行一段代码,直到满足某个条件为止,在Python中,有两种主要的循环语句:for循环和while循环,本文就来给大家介绍一下这两个循环的用法,需要的朋友可以参考下
    2023-08-08
  • Python+django实现文件下载

    Python+django实现文件下载

    本文是python+django系列的第二篇文章,主要是讲述是先文件下载的方法和代码,有需要的小伙伴可以参考下。
    2016-01-01

最新评论