diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 7c8a6bec..ad22eed5 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -5,4 +5,179 @@ To do: 建议是每种metric写成一个函数 (由Tester的evaluate函数调用) 参数表里只需考虑基本的参数即可,可以没有像它那么多的参数配置 + support numpy array and torch tensor """ +import numpy as np +import torch +import sklearn.metrics as M +import warnings + +def _conver_numpy(x): + ''' + converte input data to numpy array + ''' + if isinstance(x, np.ndarray): + return x + elif isinstance(x, torch.Tensor): + return x.numpy() + elif isinstance(x, list): + return np.array(x) + raise TypeError('cannot accept obejct: {}'.format(x)) + +def _check_same_len(*arrays, axis=0): + ''' + check if input array list has same length for one dimension + ''' + lens = set([x.shape[axis] for x in arrays if x is not None]) + return len(lens) == 1 + + +def _label_types(y): + ''' + determine the type + "binary" + "multiclass" + "multiclass-multioutput" + "multilabel" + "unknown" + ''' + # never squeeze the first dimension + y = np.squeeze(y, list(range(1, len(y.shape)))) + shape = y.shape + if len(shape) < 1: + raise ValueError('cannot accept data: {}'.format(y)) + if len(shape) == 1: + return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y + if len(shape) == 2: + return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y + return 'unknown', y + + +def _check_data(y_true, y_pred): + ''' + check if y_true and y_pred is same type of data e.g both binary or multiclass + ''' + y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred) + if not _check_same_len(y_true, y_pred): + raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred)) + type_true, y_true = _label_types(y_true) + type_pred, y_pred = _label_types(y_pred) + + type_set = set(['binary', 'multiclass']) + if type_true in type_set and type_pred in type_set: + return type_true if type_true == type_pred else 'multiclass', y_true, y_pred + + type_set = set(['multiclass-multioutput', 'multilabel']) + if type_true in type_set and type_pred in type_set: + return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred + + raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred)) + + +def _weight_sum(y, normalize=True, sample_weight=None): + if normalize: + return np.average(y, weights=sample_weight) + if sample_weight is None: + return y.sum() + else: + return np.dot(y, sample_weight) + + +def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None): + y_type, y_true, y_pred = _check_data(y_true, y_pred) + if y_type == 'multiclass-multioutput': + raise ValueError('cannot accept data type {0}'.format(y_type)) + if y_type == 'multilabel': + equel = (y_true == y_pred).sum(1) + count = equel == y_true.shape[1] + else: + count = y_true == y_pred + return _weight_sum(count, normalize=normalize, sample_weight=sample_weight) + + +def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): + y_type, y_true, y_pred = _check_data(y_true, y_pred) + if average == 'binary': + if y_type != 'binary': + raise ValueError("data type is {} but use average type {}".format(y_type, average)) + else: + pos = (y_true == pos_label) + tp = np.logical_and((y_true == y_pred), pos).sum() + pos_sum = pos.sum() + return tp / pos_sum if pos_sum > 0 else 0 + elif average == None: + y_labels = set(list(np.unique(y_true))) + if labels is None: + labels = list(y_labels) + else: + for i in labels: + if i not in y_labels: + warnings.warn('label {} is not contained in data'.format(i), UserWarning) + + if y_type in ['binary', 'multiclass']: + y_pred_right = y_true == y_pred + pos_list = [y_true == i for i in labels] + pos_sum_list = [pos_i.sum() for pos_i in pos_list] + return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ + for pos_i, sum_i in zip(pos_list, pos_sum_list)]) + elif y_type == 'multilabel': + y_pred_right = y_true == y_pred + pos = (y_true == pos_label) + tp = np.logical_and(y_pred_right, pos).sum(0) + pos_sum = pos.sum(0) + return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels]) + else: + raise ValueError('not support targets type {}'.format(y_type)) + raise ValueError('not support for average type {}'.format(average)) + +def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): + y_type, y_true, y_pred = _check_data(y_true, y_pred) + if average == 'binary': + if y_type != 'binary': + raise ValueError("data type is {} but use average type {}".format(y_type, average)) + else: + pos = (y_true == pos_label) + tp = np.logical_and((y_true == y_pred), pos).sum() + pos_pred = (y_pred == pos_label).sum() + return tp / pos_pred if pos_pred > 0 else 0 + elif average == None: + y_labels = set(list(np.unique(y_true))) + if labels is None: + labels = list(y_labels) + else: + for i in labels: + if i not in y_labels: + warnings.warn('label {} is not contained in data'.format(i), UserWarning) + + if y_type in ['binary', 'multiclass']: + y_pred_right = y_true == y_pred + pos_list = [y_true == i for i in labels] + pos_sum_list = [(y_pred == i).sum() for i in labels] + return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ + for pos_i, sum_i in zip(pos_list, pos_sum_list)]) + elif y_type == 'multilabel': + y_pred_right = y_true == y_pred + pos = (y_true == pos_label) + tp = np.logical_and(y_pred_right, pos).sum(0) + pos_sum = (y_pred == pos_label).sum(0) + return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels]) + else: + raise ValueError('not support targets type {}'.format(y_type)) + raise ValueError('not support for average type {}'.format(average)) + +def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): + precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) + recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) + if isinstance(precision, np.ndarray): + res = 2 * precision * recall / (precision + recall) + res[(precision + recall) <= 0] = 0 + return res + return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + + +def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2): + raise NotImplementedError + +if __name__ == '__main__': + y = np.array([1,0,1,0,1,1]) + print(_label_types(y)) \ No newline at end of file diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py new file mode 100644 index 00000000..b493e3f0 --- /dev/null +++ b/fastNLP/core/optimizer.py @@ -0,0 +1,5 @@ +''' +use optimizer from Pytorch +''' + +from torch.optim import * \ No newline at end of file diff --git a/fastNLP/core/optimizor.py b/fastNLP/core/optimizor.py deleted file mode 100644 index becdc499..00000000 --- a/fastNLP/core/optimizor.py +++ /dev/null @@ -1,50 +0,0 @@ -from torch import optim - - -def get_torch_optimizer(params, alg_name='sgd', **args): - """ - construct PyTorch optimizer by algorithm's name - optimizer's arguments can be specified, for different optimizer's arguments, please see PyTorch doc - - usage: - optimizer = get_torch_optimizer(model.parameters(), 'SGD', lr=0.01) - - """ - - name = alg_name.lower() - if name == 'adadelta': - return optim.Adadelta(params, **args) - elif name == 'adagrad': - return optim.Adagrad(params, **args) - elif name == 'adam': - return optim.Adam(params, **args) - elif name == 'adamax': - return optim.Adamax(params, **args) - elif name == 'asgd': - return optim.ASGD(params, **args) - elif name == 'lbfgs': - return optim.LBFGS(params, **args) - elif name == 'rmsprop': - return optim.RMSprop(params, **args) - elif name == 'rprop': - return optim.Rprop(params, **args) - elif name == 'sgd': - # SGD's parameter lr is required - if 'lr' not in args: - args['lr'] = 0.01 - return optim.SGD(params, **args) - elif name == 'sparseadam': - return optim.SparseAdam(params, **args) - else: - raise TypeError('no such optimizer named {}'.format(alg_name)) - - -if __name__ == '__main__': - from torch.nn.modules import Linear - - net = Linear(2, 5) - - test1 = get_torch_optimizer(net.parameters(), 'adam', lr=1e-2, weight_decay=1e-3) - print(test1) - test2 = get_torch_optimizer(net.parameters(), 'SGD') - print(test2) diff --git a/test/test_metrics.py b/test/test_metrics.py new file mode 100644 index 00000000..47007106 --- /dev/null +++ b/test/test_metrics.py @@ -0,0 +1,93 @@ +import sys, os +sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path + +from fastNLP.action import metrics +from sklearn import metrics as skmetrics +import unittest +import numpy as np +from numpy import random + +def generate_fake_label(low, high, size): + return random.randint(low, high, size), random.randint(low, high, size) + +class TestMetrics(unittest.TestCase): + delta = 1e-5 + # test for binary, multiclass, multilabel + data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)] + fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types] + def test_accuracy_score(self): + for y_true, y_pred in self.fake_data: + for normalize in [True, False]: + for sample_weight in [None, random.rand(y_true.shape[0])]: + ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) + test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) + self.assertAlmostEqual(test, ans, delta=self.delta) + + def test_recall_score(self): + for y_true, y_pred in self.fake_data: + # print(y_true.shape) + labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None + ans = skmetrics.recall_score(y_true, y_pred,labels=labels, average=None) + test = metrics.recall_score(y_true, y_pred, labels=labels, average=None) + ans = list(ans) + if not isinstance(test, list): + test = list(test) + for a, b in zip(test, ans): + # print('{}, {}'.format(a, b)) + self.assertAlmostEqual(a, b, delta=self.delta) + # test binary + y_true, y_pred = generate_fake_label(0, 2, 1000) + ans = skmetrics.recall_score(y_true, y_pred) + test = metrics.recall_score(y_true, y_pred) + self.assertAlmostEqual(ans, test, delta=self.delta) + + def test_precision_score(self): + for y_true, y_pred in self.fake_data: + # print(y_true.shape) + labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None + ans = skmetrics.precision_score(y_true, y_pred,labels=labels, average=None) + test = metrics.precision_score(y_true, y_pred, labels=labels, average=None) + ans, test = list(ans), list(test) + for a, b in zip(test, ans): + # print('{}, {}'.format(a, b)) + self.assertAlmostEqual(a, b, delta=self.delta) + # test binary + y_true, y_pred = generate_fake_label(0, 2, 1000) + ans = skmetrics.precision_score(y_true, y_pred) + test = metrics.precision_score(y_true, y_pred) + self.assertAlmostEqual(ans, test, delta=self.delta) + + def test_precision_score(self): + for y_true, y_pred in self.fake_data: + # print(y_true.shape) + labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None + ans = skmetrics.precision_score(y_true, y_pred,labels=labels, average=None) + test = metrics.precision_score(y_true, y_pred, labels=labels, average=None) + ans, test = list(ans), list(test) + for a, b in zip(test, ans): + # print('{}, {}'.format(a, b)) + self.assertAlmostEqual(a, b, delta=self.delta) + # test binary + y_true, y_pred = generate_fake_label(0, 2, 1000) + ans = skmetrics.precision_score(y_true, y_pred) + test = metrics.precision_score(y_true, y_pred) + self.assertAlmostEqual(ans, test, delta=self.delta) + + def test_f1_score(self): + for y_true, y_pred in self.fake_data: + # print(y_true.shape) + labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None + ans = skmetrics.f1_score(y_true, y_pred,labels=labels, average=None) + test = metrics.f1_score(y_true, y_pred, labels=labels, average=None) + ans, test = list(ans), list(test) + for a, b in zip(test, ans): + # print('{}, {}'.format(a, b)) + self.assertAlmostEqual(a, b, delta=self.delta) + # test binary + y_true, y_pred = generate_fake_label(0, 2, 1000) + ans = skmetrics.f1_score(y_true, y_pred) + test = metrics.f1_score(y_true, y_pred) + self.assertAlmostEqual(ans, test, delta=self.delta) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file