From e5353147531b326523843d88359274b4323324cc Mon Sep 17 00:00:00 2001
From: choosewhatulike <1901722105@qq.com>
Date: Wed, 25 Jul 2018 23:37:25 +0800
Subject: [PATCH] add recall metrics

---
 fastNLP/action/metrics.py | 43 ++++++++++++++++++++++++++++++++-------
 test/test_metrics.py      | 31 ++++++++++++++++++++--------
 2 files changed, 58 insertions(+), 16 deletions(-)

diff --git a/fastNLP/action/metrics.py b/fastNLP/action/metrics.py
index 4ac463d5..18e06d5d 100644
--- a/fastNLP/action/metrics.py
+++ b/fastNLP/action/metrics.py
@@ -10,7 +10,7 @@ To do:
 import numpy as np
 import torch
 import sklearn.metrics as M
-
+import warnings
 
 def _conver_numpy(x):
     '''
@@ -39,6 +39,7 @@ def _label_types(y):
         "multiclass"
         "multiclass-multioutput"
         "multilabel"
+        "unknown"
     '''
     # never squeeze the first dimension
     y = np.squeeze(y, list(range(1, len(y.shape))))
@@ -93,16 +94,44 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
 
     return _weight_sum(count, normalize=normalize, sample_weight=sample_weight)
 
-def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None):
-    raise NotImplementedError
-
-def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None):
+def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
+    y_type, y_true, y_pred = _check_data(y_true, y_pred)
+    if average == 'binary':
+        if y_type != 'binary':
+            raise ValueError("data type is {} but use average type {}".format(y_type, average))
+        else:
+            pos = y_true == pos_label
+            tp = np.logical_and((y_true == y_pred), pos)
+            return tp.sum() / pos.sum()
+    elif average == None:
+        y_labels = set(list(np.unique(y_true)))
+        if labels is None:
+            labels = list(y_labels)
+        else:
+            for i in labels:
+                if i not in y_labels:
+                    warnings.warn('label {} is not contained in data'.format(i), UserWarning)
+
+        if y_type in ['binary', 'multiclass']:
+            y_pred_right = y_true == y_pred
+            pos_list = [y_true == i for i in labels]
+            return [np.logical_and(y_pred_right, pos_i).sum() / pos_i.sum() if pos_i.sum() != 0 else 0 for pos_i in pos_list]
+        elif y_type == 'multilabel':
+            y_pred_right = y_true == y_pred
+            pos = y_true == pos_label
+            tp = np.logical_and(y_pred_right, pos)
+            return [tp[:,i].sum() / pos[:,i].sum() if pos[:,i].sum() != 0 else 0 for i in labels]
+        else:
+            raise ValueError('not support targets type {}'.format(y_type))
+    raise ValueError('not support for average type {}'.format(average))
+
+def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
     raise NotImplementedError
 
-def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None):
+def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
     raise NotImplementedError
 
-def classification_report(y_true, y_pred, labels=None, target_names=None, sample_weight=None, digits=2):
+def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2):
     raise NotImplementedError
 
 if __name__ == '__main__':
diff --git a/test/test_metrics.py b/test/test_metrics.py
index 20e50940..3cc7c286 100644
--- a/test/test_metrics.py
+++ b/test/test_metrics.py
@@ -1,8 +1,8 @@
 import sys, os
-sys.path = [os.path.abspath('..')] + sys.path
+sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path
 
-from fastNLP.action.metrics import accuracy_score
-from sklearn import metrics as M
+from fastNLP.action import metrics
+from sklearn import metrics as skmetrics
 import unittest
 import numpy as np
 from numpy import random
@@ -12,15 +12,28 @@ def generate_fake_label(low, high, size):
 
 class TestMetrics(unittest.TestCase):
     delta = 1e-5
+    # test for binary, multiclass, multilabel
+    data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)]
+    fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types]
 
     def test_accuracy_score(self):
-        for shape, high_bound in [((1000,), 2), ((1000,), 10), ((1000, 10), 2)]:
-            # test for binary, multiclass, multilabel
-            y_true, y_pred = generate_fake_label(0, high_bound, shape)
+        for y_true, y_pred in self.fake_data:
             for normalize in [True, False]:
-                for sample_weight in [None, random.rand(shape[0])]:
-                    test = accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
-                    ans = M.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
+                for sample_weight in [None, random.rand(y_true.shape[0])]:
+                    ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
+                    test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
                     self.assertAlmostEqual(test, ans, delta=self.delta)
+
+    def test_recall_score(self):
+        for y_true, y_pred in self.fake_data:
+            labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None
+            ans = skmetrics.recall_score(y_true, y_pred, labels=labels, average=None)
+            test = metrics.recall_score(y_true, y_pred, labels=labels, average=None)
+            ans = list(ans)
+            if not isinstance(test, list):
+                test = list(test)
+            for a, b in zip(test, ans):
+                # print('{}, {}'.format(a, b))
+                self.assertAlmostEqual(a, b, delta=self.delta)
 
 if __name__ == '__main__':
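
Usage sketch (not part of the patch, illustration only): the snippet below exercises the new recall_score the same way test_recall_score does, assuming the patched fastNLP.action.metrics is importable and that _check_data accepts plain numpy arrays, as the unit tests imply. With average=None the per-label recalls are expected to match sklearn's.

import numpy as np
from sklearn import metrics as skmetrics
from fastNLP.action import metrics  # requires the patch above to be applied

# Small multiclass example with labels 0, 1, 2.
y_true = np.array([0, 1, 2, 2, 1, 0])
y_pred = np.array([0, 2, 2, 2, 1, 1])

# average=None returns one recall per label; for this data the label order is 0, 1, 2.
ours = metrics.recall_score(y_true, y_pred, average=None)   # plain Python list
ref = skmetrics.recall_score(y_true, y_pred, average=None)  # numpy array
print(list(ours))  # expected: [0.5, 0.5, 1.0]
print(list(ref))   # expected: [0.5, 0.5, 1.0]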