diff --git a/fastNLP/action/metrics.py b/fastNLP/action/metrics.py index 18e06d5d..ad22eed5 100644 --- a/fastNLP/action/metrics.py +++ b/fastNLP/action/metrics.py @@ -57,6 +57,7 @@ def _check_data(y_true, y_pred): ''' check if y_true and y_pred is same type of data e.g both binary or multiclass ''' + y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred) if not _check_same_len(y_true, y_pred): raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred)) type_true, y_true = _label_types(y_true) @@ -100,9 +101,10 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): if y_type != 'binary': raise ValueError("data type is {} but use average type {}".format(y_type, average)) else: - pos = y_true == pos_label - tp = np.logical_and((y_true == y_pred), pos) - return tp.sum() / pos.sum() + pos = (y_true == pos_label) + tp = np.logical_and((y_true == y_pred), pos).sum() + pos_sum = pos.sum() + return tp / pos_sum if pos_sum > 0 else 0 elif average == None: y_labels = set(list(np.unique(y_true))) if labels is None: @@ -111,25 +113,67 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): for i in labels: if i not in y_labels: warnings.warn('label {} is not contained in data'.format(i), UserWarning) - + if y_type in ['binary', 'multiclass']: y_pred_right = y_true == y_pred pos_list = [y_true == i for i in labels] - return [np.logical_and(y_pred_right, pos_i).sum() / pos_i.sum() if pos_i.sum() != 0 else 0 for pos_i in pos_list] + pos_sum_list = [pos_i.sum() for pos_i in pos_list] + return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ + for pos_i, sum_i in zip(pos_list, pos_sum_list)]) elif y_type == 'multilabel': y_pred_right = y_true == y_pred - pos = y_true == pos_label - tp = np.logical_and(y_pred_right, pos) - return [tp[:,i].sum() / pos[:,i].sum() if pos[:,i].sum() != 0 else 0 for i in labels] + pos = (y_true == pos_label) + tp = np.logical_and(y_pred_right, pos).sum(0) + pos_sum = pos.sum(0) + return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels]) else: raise ValueError('not support targets type {}'.format(y_type)) raise ValueError('not support for average type {}'.format(average)) def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): - raise NotImplementedError + y_type, y_true, y_pred = _check_data(y_true, y_pred) + if average == 'binary': + if y_type != 'binary': + raise ValueError("data type is {} but use average type {}".format(y_type, average)) + else: + pos = (y_true == pos_label) + tp = np.logical_and((y_true == y_pred), pos).sum() + pos_pred = (y_pred == pos_label).sum() + return tp / pos_pred if pos_pred > 0 else 0 + elif average == None: + y_labels = set(list(np.unique(y_true))) + if labels is None: + labels = list(y_labels) + else: + for i in labels: + if i not in y_labels: + warnings.warn('label {} is not contained in data'.format(i), UserWarning) + + if y_type in ['binary', 'multiclass']: + y_pred_right = y_true == y_pred + pos_list = [y_true == i for i in labels] + pos_sum_list = [(y_pred == i).sum() for i in labels] + return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ + for pos_i, sum_i in zip(pos_list, pos_sum_list)]) + elif y_type == 'multilabel': + y_pred_right = y_true == y_pred + pos = (y_true == pos_label) + tp = np.logical_and(y_pred_right, pos).sum(0) + pos_sum = (y_pred == pos_label).sum(0) + return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels]) + else: + raise ValueError('not support targets type {}'.format(y_type)) + raise ValueError('not support for average type {}'.format(average)) def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): - raise NotImplementedError + precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) + recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) + if isinstance(precision, np.ndarray): + res = 2 * precision * recall / (precision + recall) + res[(precision + recall) <= 0] = 0 + return res + return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2): raise NotImplementedError diff --git a/test/test_metrics.py b/test/test_metrics.py index 3cc7c286..47007106 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -25,6 +25,7 @@ class TestMetrics(unittest.TestCase): def test_recall_score(self): for y_true, y_pred in self.fake_data: + # print(y_true.shape) labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None ans = skmetrics.recall_score(y_true, y_pred,labels=labels, average=None) test = metrics.recall_score(y_true, y_pred, labels=labels, average=None) @@ -34,7 +35,59 @@ class TestMetrics(unittest.TestCase): for a, b in zip(test, ans): # print('{}, {}'.format(a, b)) self.assertAlmostEqual(a, b, delta=self.delta) + # test binary + y_true, y_pred = generate_fake_label(0, 2, 1000) + ans = skmetrics.recall_score(y_true, y_pred) + test = metrics.recall_score(y_true, y_pred) + self.assertAlmostEqual(ans, test, delta=self.delta) + def test_precision_score(self): + for y_true, y_pred in self.fake_data: + # print(y_true.shape) + labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None + ans = skmetrics.precision_score(y_true, y_pred,labels=labels, average=None) + test = metrics.precision_score(y_true, y_pred, labels=labels, average=None) + ans, test = list(ans), list(test) + for a, b in zip(test, ans): + # print('{}, {}'.format(a, b)) + self.assertAlmostEqual(a, b, delta=self.delta) + # test binary + y_true, y_pred = generate_fake_label(0, 2, 1000) + ans = skmetrics.precision_score(y_true, y_pred) + test = metrics.precision_score(y_true, y_pred) + self.assertAlmostEqual(ans, test, delta=self.delta) + + def test_precision_score(self): + for y_true, y_pred in self.fake_data: + # print(y_true.shape) + labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None + ans = skmetrics.precision_score(y_true, y_pred,labels=labels, average=None) + test = metrics.precision_score(y_true, y_pred, labels=labels, average=None) + ans, test = list(ans), list(test) + for a, b in zip(test, ans): + # print('{}, {}'.format(a, b)) + self.assertAlmostEqual(a, b, delta=self.delta) + # test binary + y_true, y_pred = generate_fake_label(0, 2, 1000) + ans = skmetrics.precision_score(y_true, y_pred) + test = metrics.precision_score(y_true, y_pred) + self.assertAlmostEqual(ans, test, delta=self.delta) + + def test_f1_score(self): + for y_true, y_pred in self.fake_data: + # print(y_true.shape) + labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None + ans = skmetrics.f1_score(y_true, y_pred,labels=labels, average=None) + test = metrics.f1_score(y_true, y_pred, labels=labels, average=None) + ans, test = list(ans), list(test) + for a, b in zip(test, ans): + # print('{}, {}'.format(a, b)) + self.assertAlmostEqual(a, b, delta=self.delta) + # test binary + y_true, y_pred = generate_fake_label(0, 2, 1000) + ans = skmetrics.f1_score(y_true, y_pred) + test = metrics.f1_score(y_true, y_pred) + self.assertAlmostEqual(ans, test, delta=self.delta) if __name__ == '__main__': unittest.main() \ No newline at end of file