Browse Source

add accuracy_score; fix optimizer

tags/v0.1.0
choosewhatulike 7 years ago
parent
commit
9701ab2897
3 changed files with 129 additions and 0 deletions
  1. +102
    -0
      fastNLP/action/metrics.py
  2. +0
    -0
      fastNLP/action/optimizer.py
  3. +27
    -0
      test/test_metrics.py

+ 102
- 0
fastNLP/action/metrics.py View File

@@ -5,4 +5,106 @@ To do:
建议是每种metric写成一个函数 (由Tester的evaluate函数调用)
参数表里只需考虑基本的参数即可,可以没有像它那么多的参数配置
support numpy array and torch tensor
"""
import numpy as np
import torch
import sklearn.metrics as M


def _conver_numpy(x):
    """Convert supported input data to a numpy array.

    Args:
        x: a ``numpy.ndarray`` (returned unchanged), a ``torch.Tensor``,
            a ``list`` or a ``tuple``.

    Returns:
        numpy.ndarray: the data of ``x`` as an array.

    Raises:
        TypeError: if ``x`` is none of the supported types.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, torch.Tensor):
        # Tensor.numpy() refuses tensors that require grad or that live on
        # a non-CPU device, so detach and move to host memory first.
        return x.detach().cpu().numpy()
    if isinstance(x, (list, tuple)):
        return np.array(x)
    raise TypeError('cannot accept obejct: {}'.format(x))

def _check_same_len(*arrays, axis=0):
    """Tell whether all given arrays have the same size along ``axis``.

    Args:
        *arrays: objects exposing ``.shape``; ``None`` entries are skipped.
        axis: index of the dimension whose size is compared.

    Returns:
        bool: True when exactly one distinct size is found among the
        non-None arrays (so False when there is nothing to compare).
    """
    distinct_sizes = {arr.shape[axis] for arr in arrays if arr is not None}
    return len(distinct_sizes) == 1

def _label_types(y):
    """Determine the kind of target stored in ``y``.

    Every axis of length 1 except the first is dropped before the data is
    classified by its remaining number of dimensions and distinct values.

    Args:
        y: numpy array of labels.

    Returns:
        tuple: ``(type_name, squeezed_y)`` where ``type_name`` is one of

        - ``"binary"``: 1-D, at most two distinct values
        - ``"multiclass"``: 1-D, more than two distinct values
        - ``"multilabel"``: 2-D, at most two distinct values
        - ``"multiclass-multioutput"``: 2-D, more than two distinct values
        - ``"unknown"``: more than two dimensions remain

    Raises:
        ValueError: if ``y`` has no dimensions at all.
    """
    if len(y.shape) < 1:
        raise ValueError('cannot accept data: {}'.format(y))
    # Never squeeze the first dimension. np.squeeze documents ``axis`` as
    # None, an int or a tuple of ints and raises ValueError when a selected
    # axis is longer than 1, so pick only the trailing axes of length 1.
    unit_axes = tuple(ax for ax in range(1, len(y.shape)) if y.shape[ax] == 1)
    if unit_axes:
        y = np.squeeze(y, axis=unit_axes)
    n_dims = len(y.shape)
    if n_dims == 1:
        return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y
    if n_dims == 2:
        return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y
    return 'unknown', y

def _check_data(y_true, y_pred):
    """Check that ``y_true`` and ``y_pred`` hold the same kind of target.

    Args:
        y_true: numpy array of reference labels.
        y_pred: numpy array of predicted labels.

    Returns:
        tuple: ``(y_type, y_true, y_pred)`` with both arrays as returned by
        ``_label_types``. ``y_type`` is the common label type; when the two
        types differ inside one family the more general one is reported
        ('multiclass' for the 1-D family, 'multiclass-multioutput' for the
        2-D family).

    Raises:
        ValueError: if the arrays differ in length along the first axis, if
            2-D targets differ in shape, or if the two label types belong
            to different families.
    """
    if not _check_same_len(y_true, y_pred):
        # Report the shapes, as the message says, instead of dumping the data.
        raise ValueError('cannot accept data with different shape {0}, {1}'.format(
            np.shape(y_true), np.shape(y_pred)))
    type_true, y_true = _label_types(y_true)
    type_pred, y_pred = _label_types(y_pred)

    single_output = {'binary', 'multiclass'}
    if type_true in single_output and type_pred in single_output:
        return type_true if type_true == type_pred else 'multiclass', y_true, y_pred

    multi_output = {'multiclass-multioutput', 'multilabel'}
    if type_true in multi_output and type_pred in multi_output:
        # Only the first axis was compared above; 2-D targets must also
        # agree on the number of columns to be comparable element-wise.
        if y_true.shape != y_pred.shape:
            raise ValueError('cannot accept data with different shape {0}, {1}'.format(
                y_true.shape, y_pred.shape))
        return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred
    raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred))

def _weight_sum(y, normalize=True, sample_weight=None):
    """Reduce per-sample values ``y`` to a single number.

    Args:
        y: numpy array of per-sample values.
        normalize: when true, return the (weighted) average of ``y``;
            otherwise return the (weighted) sum.
        sample_weight: optional per-sample weights; ``None`` means
            unweighted.

    Returns:
        The average from ``np.average`` when ``normalize`` is true, else
        ``y.sum()`` without weights or ``np.dot(y, sample_weight)`` with
        weights.
    """
    if normalize:
        return np.average(y, weights=sample_weight)
    return y.sum() if sample_weight is None else np.dot(y, sample_weight)


def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
    """Compute the accuracy of ``y_pred`` against ``y_true``.

    For multilabel targets a sample counts as correct only when every one
    of its labels matches.

    Args:
        y_true: reference labels; any input accepted by ``_conver_numpy``.
        y_pred: predicted labels; any input accepted by ``_conver_numpy``.
        normalize: when true, return the (weighted) fraction of correct
            samples; otherwise return their (weighted) count.
        sample_weight: optional per-sample weights.

    Returns:
        The accuracy as described by ``normalize``.

    Raises:
        TypeError: if an input cannot be converted to a numpy array.
        ValueError: if the inputs are inconsistent or are
            'multiclass-multioutput' targets.
    """
    # Convert up front so that the shape checks and the element-wise
    # comparison below always operate on numpy arrays.
    y_true = _conver_numpy(y_true)
    y_pred = _conver_numpy(y_pred)
    y_type, y_true, y_pred = _check_data(y_true, y_pred)
    if y_type == 'multiclass-multioutput':
        raise ValueError('cannot accept data type {0}'.format(y_type))
    if y_type == 'multilabel':
        correct = (y_true == y_pred).all(axis=1)
    else:
        correct = y_true == y_pred
    return _weight_sum(correct, normalize=normalize, sample_weight=sample_weight)


def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None):
    """Placeholder: recall is not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None):
    """Placeholder: precision is not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary', sample_weight=None):
    """Placeholder: the F1 score is not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def classification_report(y_true, y_pred, labels=None, target_names=None, sample_weight=None, digits=2):
    """Placeholder: the classification report is not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

if __name__ == '__main__':
    # Manual smoke check: print the label type detected for a small 0/1 vector.
    demo_labels = np.array([1, 0, 1, 0, 1, 1])
    print(_label_types(demo_labels))

fastNLP/action/optimizor.py → fastNLP/action/optimizer.py View File


+ 27
- 0
test/test_metrics.py View File

@@ -0,0 +1,27 @@
import os
import sys

# Make the repository root importable. Resolve it from this file's location
# rather than the current working directory, so the tests can be launched
# from any directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from fastNLP.action.metrics import accuracy_score
from sklearn import metrics as M
import unittest
import numpy as np
from numpy import random

def generate_fake_label(low, high, size):
    """Draw two independent random integer label arrays.

    Args:
        low: lower bound passed to ``random.randint``.
        high: upper bound passed to ``random.randint``.
        size: shape of each generated array.

    Returns:
        tuple: two arrays of shape ``size``, usable as a fake
        ``(y_true, y_pred)`` pair.
    """
    first = random.randint(low, high, size)
    second = random.randint(low, high, size)
    return first, second

class TestMetrics(unittest.TestCase):
    """Compare fastNLP metrics against their scikit-learn counterparts."""

    # Largest absolute difference tolerated between the two implementations.
    delta = 1e-5

    def test_accuracy_score(self):
        """accuracy_score must agree with sklearn on random labels."""
        # (label array shape, exclusive upper bound of the label values):
        # binary, multiclass and multilabel inputs respectively.
        scenarios = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)]
        for shape, high_bound in scenarios:
            y_true, y_pred = generate_fake_label(0, high_bound, shape)
            for normalize in (True, False):
                for sample_weight in (None, random.rand(shape[0])):
                    options = dict(normalize=normalize, sample_weight=sample_weight)
                    actual = accuracy_score(y_true, y_pred, **options)
                    expected = M.accuracy_score(y_true, y_pred, **options)
                    self.assertAlmostEqual(actual, expected, delta=self.delta)


if __name__ == '__main__':
    # Run the test cases of this module when the file is executed directly.
    unittest.main()

Loading…
Cancel
Save