如何计算 keras 中的接收操作特性 (ROC) 和 AUC?

本文介绍了如何计算 keras 中的接收操作特性 (ROC) 和 AUC?的处理方法,对大家解决问题具有一定的参考价值

问题描述

我有一个用 keras 编写的多输出(200)二元分类模型.

I have a multi output(200) binary classification model which I wrote in keras.

在此模型中,我想添加其他指标,例如 ROC 和 AUC,但据我所知,keras 没有内置的 ROC 和 AUC 指标函数.

In this model I want to add additional metrics such as ROC and AUC but to my knowledge keras dosen't have in-built ROC and AUC metric functions.

我尝试从 scikit-learn 导入 ROC、AUC 函数

I tried to import ROC, AUC functions from scikit-learn

from sklearn.metrics import roc_curve, auc
from keras.models import Sequential
from keras.layers import Dense
.
.
.
model.add(Dense(200, activation='relu'))
model.add(Dense(300, activation='relu'))
model.add(Dense(400, activation='relu'))
model.add(Dense(300, activation='relu'))
model.add(Dense(200,init='normal', activation='softmax')) #outputlayer

model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy','roc_curve','auc'])

但它给出了这个错误:

Exception: Invalid metric: roc_curve

我应该如何将 ROC、AUC 添加到 keras?

How should I add ROC, AUC to keras?

推荐答案

由于不能通过 mini-batches 计算 ROC&AUC,所以只能在一个 epoch 结束时计算.jamartinh 有一个解决方案,为了方便起见,我修补了以下代码:

Due to that you can't calculate ROC&AUC by mini-batches, you can only calculate it on the end of one epoch. There is a solution from jamartinh, I patch the codes below for convenience:

from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback
class RocCallback(Callback):
    def __init__(self,training_data,validation_data):
        self.x = training_data[0]
        self.y = training_data[1]
        self.x_val = validation_data[0]
        self.y_val = validation_data[1]


    def on_train_begin(self, logs={}):
        return

    def on_train_end(self, logs={}):
        return

    def on_epoch_begin(self, epoch, logs={}):
        return

    def on_epoch_end(self, epoch, logs={}):
        y_pred_train = self.model.predict_proba(self.x)
        roc_train = roc_auc_score(self.y, y_pred_train)
        y_pred_val = self.model.predict_proba(self.x_val)
        roc_val = roc_auc_score(self.y_val, y_pred_val)
        print('
roc-auc_train: %s - roc-auc_val: %s' % (str(round(roc_train,4)),str(round(roc_val,4))),end=100*' '+'
')
        return

    def on_batch_begin(self, batch, logs={}):
        return

    def on_batch_end(self, batch, logs={}):
        return

roc = RocCallback(training_data=(X_train, y_train),
                  validation_data=(X_test, y_test))

model.fit(X_train, y_train, 
          validation_data=(X_test, y_test),
          callbacks=[roc])

使用 tf.contrib.metrics.streaming_auc 的一种更易破解的方法:

A more hackable way using tf.contrib.metrics.streaming_auc:

import numpy as np
import tensorflow as tf
from sklearn.metrics import roc_auc_score
from sklearn.datasets import make_classification
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils
from keras.callbacks import Callback, EarlyStopping


# define roc_callback, inspired by https://github.com/keras-team/keras/issues/6050#issuecomment-329996505
def auc_roc(y_true, y_pred):
    # any tensorflow metric
    value, update_op = tf.contrib.metrics.streaming_auc(y_pred, y_true)

    # find all variables created for this metric
    metric_vars = [i for i in tf.local_variables() if 'auc_roc' in i.name.split('/')[1]]

    # Add metric variables to GLOBAL_VARIABLES collection.
    # They will be initialized for new session.
    for v in metric_vars:
        tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, v)

    # force to update metric values
    with tf.control_dependencies([update_op]):
        value = tf.identity(value)
        return value

# generation a small dataset
N_all = 10000
N_tr = int(0.7 * N_all)
N_te = N_all - N_tr
X, y = make_classification(n_samples=N_all, n_features=20, n_classes=2)
y = np_utils.to_categorical(y, num_classes=2)

X_train, X_valid = X[:N_tr, :], X[N_tr:, :]
y_train, y_valid = y[:N_tr, :], y[N_tr:, :]

# model & train
model = Sequential()
model.add(Dense(2, activation="softmax", input_shape=(X.shape[1],)))

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy', auc_roc])

my_callbacks = [EarlyStopping(monitor='auc_roc', patience=300, verbose=1, mode='max')]

model.fit(X, y,
          validation_split=0.3,
          shuffle=True,
          batch_size=32, nb_epoch=5, verbose=1,
          callbacks=my_callbacks)

# # or use independent valid set
# model.fit(X_train, y_train,
#           validation_data=(X_valid, y_valid),
#           batch_size=32, nb_epoch=5, verbose=1,
#           callbacks=my_callbacks)

这篇关于如何计算 keras 中的接收操作特性 (ROC) 和 AUC?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,WP2

admin_action_{$_REQUEST[‘action’]}

do_action( "admin_action_{$_REQUEST[‘action’]}" )动作钩子::在发送“Action”请求变量时激发。Action Hook: Fires when an ‘action’ request variable is sent.目录锚点:#说明#源码说明(Description)钩子名称的动态部分$_REQUEST['action']引用从GET或POST请求派生的操作。源码(Source)更新版本源码位置使用被使用2.6.0 wp-admin/admin.php:...

日期:2020-09-02 17:44:16 浏览:1173

admin_footer-{$GLOBALS[‘hook_suffix’]}

do_action( "admin_footer-{$GLOBALS[‘hook_suffix’]}", string $hook_suffix )操作挂钩:在默认页脚脚本之后打印脚本或数据。Action Hook: Print scripts or data after the default footer scripts.目录锚点:#说明#参数#源码说明(Description)钩子名的动态部分,$GLOBALS['hook_suffix']引用当前页的全局钩子后缀。参数(Parameters)参数类...

日期:2020-09-02 17:44:20 浏览:1071

customize_save_{$this->id_data[‘base’]}

do_action( "customize_save_{$this->id_data[‘base’]}", WP_Customize_Setting $this )动作钩子::在调用WP_Customize_Setting::save()方法时激发。Action Hook: Fires when the WP_Customize_Setting::save() method is called.目录锚点:#说明#参数#源码说明(Description)钩子名称的动态部分,$this->id_data...

日期:2020-08-15 15:47:24 浏览:809

customize_value_{$this->id_data[‘base’]}

apply_filters( "customize_value_{$this->id_data[‘base’]}", mixed $default )过滤器::过滤未作为主题模式或选项处理的自定义设置值。Filter Hook: Filter a Customize setting value not handled as a theme_mod or option.目录锚点:#说明#参数#源码说明(Description)钩子名称的动态部分,$this->id_date['base'],指的是设置...

日期:2020-08-15 15:47:24 浏览:901

get_comment_author_url

过滤钩子:过滤评论作者的URL。Filter Hook: Filters the comment author’s URL.目录锚点:#源码源码(Source)更新版本源码位置使用被使用 wp-includes/comment-template.php:32610...

日期:2020-08-10 23:06:14 浏览:932

network_admin_edit_{$_GET[‘action’]}

do_action( "network_admin_edit_{$_GET[‘action’]}" )操作挂钩:启动请求的处理程序操作。Action Hook: Fires the requested handler action.目录锚点:#说明#源码说明(Description)钩子名称的动态部分$u GET['action']引用请求的操作的名称。源码(Source)更新版本源码位置使用被使用3.1.0 wp-admin/network/edit.php:3600...

日期:2020-08-02 09:56:09 浏览:880

network_sites_updated_message_{$_GET[‘updated’]}

apply_filters( "network_sites_updated_message_{$_GET[‘updated’]}", string $msg )筛选器挂钩:在网络管理中筛选特定的非默认站点更新消息。Filter Hook: Filters a specific, non-default site-updated message in the Network admin.目录锚点:#说明#参数#源码说明(Description)钩子名称的动态部分$_GET['updated']引用了非默认的...

日期:2020-08-02 09:56:03 浏览:865

pre_wp_is_site_initialized

过滤器::过滤在访问数据库之前是否初始化站点的检查。Filter Hook: Filters the check for whether a site is initialized before the database is accessed.目录锚点:#源码源码(Source)更新版本源码位置使用被使用 wp-includes/ms-site.php:93910...

日期:2020-07-29 10:15:38 浏览:834

WordPress 的SEO 教学:如何在网站中加入关键字(Meta Keywords)与Meta 描述(Meta Description)?

你想在WordPress 中添加关键字和meta 描述吗?关键字和meta 描述使你能够提高网站的SEO。在本文中,我们将向你展示如何在WordPress 中正确添加关键字和meta 描述。为什么要在WordPress 中添加关键字和Meta 描述?关键字和说明让搜寻引擎更了解您的帖子和页面的内容。关键词是人们寻找您发布的内容时,可能会搜索的重要词语或片语。而Meta Description则是对你的页面和文章的简要描述。如果你想要了解更多关于中继标签的资讯,可以参考Google的说明。Meta 关键字和描...

日期:2020-10-03 21:18:25 浏览:1738

谷歌的SEO是什么

SEO (Search Engine Optimization)中文是搜寻引擎最佳化,意思近于「关键字自然排序」、「网站排名优化」。简言之,SEO是以搜索引擎(如Google、Bing)为曝光媒体的行销手法。例如搜寻「wordpress教学」,会看到本站的「WordPress教学:12个课程…」排行Google第一:关键字:wordpress教学、wordpress课程…若搜寻「网站架设」,则会看到另一个网页排名第1:关键字:网站架设、架站…以上两个网页,每月从搜寻引擎导入自然流量,达2万4千:每月「有机搜...

日期:2020-10-30 17:23:57 浏览:1309