
add recall metrics

tags/v0.1.0
choosewhatulike committed 7 years ago
commit e535314753

2 changed files with 58 additions and 16 deletions:
  1. fastNLP/action/metrics.py  +36 -7
  2. test/test_metrics.py       +22 -9

fastNLP/action/metrics.py  +36 -7

@@ -10,7 +10,7 @@ To do:
 import numpy as np
 import torch
-import sklearn.metrics as M
+import warnings
 
 def _conver_numpy(x):
     '''
@@ -39,6 +39,7 @@ def _label_types(y):
    "multiclass"
    "multiclass-multioutput"
    "multilabel"
+    "unknown"
    '''
    # never squeeze the first dimension
    y = np.squeeze(y, list(range(1, len(y.shape))))
@@ -93,16 +94,44 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
     return _weight_sum(count, normalize=normalize, sample_weight=sample_weight)
 
 
-def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None):
-    raise NotImplementedError
-
-def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None):
+def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
+    y_type, y_true, y_pred = _check_data(y_true, y_pred)
+    if average == 'binary':
+        if y_type != 'binary':
+            raise ValueError("data type is {} but use average type {}".format(y_type, average))
+        else:
+            pos = y_true == pos_label
+            tp = np.logical_and((y_true == y_pred), pos)
+            return tp.sum() / pos.sum()
+    elif average == None:
+        y_labels = set(list(np.unique(y_true)))
+        if labels is None:
+            labels = list(y_labels)
+        else:
+            for i in labels:
+                if i not in y_labels:
+                    warnings.warn('label {} is not contained in data'.format(i), UserWarning)
+
+        if y_type in ['binary', 'multiclass']:
+            y_pred_right = y_true == y_pred
+            pos_list = [y_true == i for i in labels]
+            return [np.logical_and(y_pred_right, pos_i).sum() / pos_i.sum() if pos_i.sum() != 0 else 0 for pos_i in pos_list]
+        elif y_type == 'multilabel':
+            y_pred_right = y_true == y_pred
+            pos = y_true == pos_label
+            tp = np.logical_and(y_pred_right, pos)
+            return [tp[:, i].sum() / pos[:, i].sum() if pos[:, i].sum() != 0 else 0 for i in labels]
+        else:
+            raise ValueError('not support targets type {}'.format(y_type))
+    raise ValueError('not support for average type {}'.format(average))
+
+def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
     raise NotImplementedError
 
-def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None):
+def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
     raise NotImplementedError
 
-def classification_report(y_true, y_pred, labels=None, target_names=None, sample_weight=None, digits=2):
+def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2):
     raise NotImplementedError
 
 if __name__ == '__main__':
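
For reference, a minimal usage sketch of the recall_score added above (not part of this commit). It assumes the module path used in this commit, fastNLP.action.metrics, and that the helper _check_data (not shown in this diff) classifies 0/1 integer arrays as "binary":

    import numpy as np
    from fastNLP.action.metrics import recall_score

    y_true = np.array([1, 1, 0, 1, 0, 0])
    y_pred = np.array([1, 0, 0, 1, 1, 0])

    # binary recall: true positives / actual positives = 2 / 3
    print(recall_score(y_true, y_pred, pos_label=1, average='binary'))

    # per-label recall as a list, one value per label found in y_true
    print(recall_score(y_true, y_pred, average=None))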


test/test_metrics.py  +22 -9

@@ -1,8 +1,8 @@
 import sys, os
-sys.path = [os.path.abspath('..')] + sys.path
+sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path
 
-from fastNLP.action.metrics import accuracy_score
-from sklearn import metrics as M
+from fastNLP.action import metrics
+from sklearn import metrics as skmetrics
 import unittest
 import numpy as np
 from numpy import random
@@ -12,15 +12,28 @@ def generate_fake_label(low, high, size):
 
 class TestMetrics(unittest.TestCase):
     delta = 1e-5
+    # test for binary, multiclass, multilabel
+    data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)]
+    fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types]
+
     def test_accuracy_score(self):
-        for shape, high_bound in [((1000,), 2), ((1000,), 10), ((1000, 10), 2)]:
-            # test for binary, multiclass, multilabel
-            y_true, y_pred = generate_fake_label(0, high_bound, shape)
+        for y_true, y_pred in self.fake_data:
             for normalize in [True, False]:
-                for sample_weight in [None, random.rand(shape[0])]:
-                    test = accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
-                    ans = M.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
+                for sample_weight in [None, random.rand(y_true.shape[0])]:
+                    ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
+                    test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
                     self.assertAlmostEqual(test, ans, delta=self.delta)
+
+    def test_recall_score(self):
+        for y_true, y_pred in self.fake_data:
+            labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None
+            ans = skmetrics.recall_score(y_true, y_pred, labels=labels, average=None)
+            test = metrics.recall_score(y_true, y_pred, labels=labels, average=None)
+            ans = list(ans)
+            if not isinstance(test, list):
+                test = list(test)
+            for a, b in zip(test, ans):
+                # print('{}, {}'.format(a, b))
+                self.assertAlmostEqual(a, b, delta=self.delta)
 
 
 if __name__ == '__main__':
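
As a rough standalone check mirroring test_recall_score above (again not part of the commit), the new implementation can be compared against scikit-learn directly. This sketch assumes it is run from the repository root so that fastNLP.action is importable, and that _check_data treats integer arrays with ten distinct labels as "multiclass":

    import numpy as np
    from numpy import random
    from sklearn import metrics as skmetrics
    from fastNLP.action import metrics

    # multiclass case: 1000 samples with labels 0..9, mirroring data_types above
    y_true = random.randint(0, 10, size=(1000,))
    y_pred = random.randint(0, 10, size=(1000,))

    labels = list(range(10))
    ans = skmetrics.recall_score(y_true, y_pred, labels=labels, average=None)   # numpy array, one recall per class
    test = metrics.recall_score(y_true, y_pred, labels=labels, average=None)    # Python list, one recall per class
    assert np.allclose(list(test), list(ans), atol=1e-5)

If the __main__ block of test_metrics.py (not shown in this hunk) calls unittest.main(), the whole module can also be run directly with python test/test_metrics.py from the repository root.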
