
1. metrics: modify fast_param handling
2. trainer: rename update_every to print_every, since update_every could be misread as the optimizer update interval
3. fieldarray: content can now be initialized from an np.ndarray
tags/v0.2.0^2
Author: yh
Commit: 1fb1df4a31
6 changed files with 282 additions and 171 deletions:

  1. fastNLP/core/fieldarray.py   (+6, -0)
  2. fastNLP/core/metrics.py      (+40, -33)
  3. fastNLP/core/trainer.py      (+3, -3)
  4. fastNLP/core/utils.py        (+6, -6)
  5. test/core/test_metrics.py    (+111, -116)
  6. test/core/test_trainer.py    (+116, -13)

fastNLP/core/fieldarray.py (+6, -0)

@@ -17,6 +17,12 @@ class FieldArray(object):
         :param bool is_input: If True, this FieldArray is used to the model input.
         """
         self.name = name
+        if isinstance(content, list):
+            content = content
+        elif isinstance(content, np.ndarray):
+            content = content.tolist()
+        else:
+            raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content)))
         self.content = content
         self.padding_val = padding_val
         self.is_target = is_target
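With this hunk, FieldArray accepts either a plain list or a numpy array as content (an array is converted with tolist()); any other type now raises a TypeError. A minimal sketch of the new behaviour, assuming the constructor takes name and content as its first two arguments, as the surrounding context suggests:

import numpy as np
from fastNLP.core.fieldarray import FieldArray

# a plain list works as before
fa_list = FieldArray("words", content=[[1, 2], [3, 4]])

# a numpy array is now accepted and converted to a nested list internally
fa_arr = FieldArray("features", content=np.arange(12).reshape(4, 3))

# anything else should hit the new TypeError
try:
    FieldArray("bad", content="not a list")
except TypeError as e:
    print(e)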


fastNLP/core/metrics.py (+40, -33)

@@ -61,8 +61,7 @@ class MetricBase(object):
             if func_param not in func_args:
                 raise NameError(
                     f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
-                    f"initialization parameters, or change the signature of"
-                    f" {get_func_signature(self.evaluate)}.")
+                    f"initialization parameters, or change its signature.")

         # evaluate should not have varargs.
         if func_spect.varargs:
@@ -79,13 +78,14 @@
             such as pred_dict has one element, target_dict has one element
         :param pred_dict:
         :param target_dict:
-        :return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
+        :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
         """
+        fast_param = {}
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             return pred_dict.values[0] and target_dict.values[0]
-        return None
+        return fast_param

-    def __call__(self, pred_dict, target_dict, check=False):
+    def __call__(self, pred_dict, target_dict):
         """

         This method will call self.evaluate method.
@@ -96,20 +96,19 @@ class MetricBase(object):
             (4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning)
         Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
             target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
-            will be conducted)
+            will be conducted.)
+        This function also support _fast_param_map.
         :param pred_dict: usually the output of forward or prediction function
         :param target_dict: usually features set as target..
-        :param check: boolean, if check is True, it will force check `varargs, missing, unused, duplicated`.
         :return:
         """
         if not callable(self.evaluate):
             raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

-        if not check:
-            fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
-            if fast_param is not None:
-                self.evaluate(*fast_param)
-                return
+        fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
+        if fast_param:
+            self.evaluate(**fast_param)
+            return

         if not self._checked:
             # 1. check consistence between signature and param_map
@@ -147,7 +146,7 @@ class MetricBase(object):
                         duplicated.append(input_arg)

         # missing
-        if check or not self._checked:
+        if not self._checked:
             check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
             # only check missing.
             missing = check_res.missing
@@ -175,40 +174,49 @@ class MetricBase(object):


 class AccuracyMetric(MetricBase):
-    def __init__(self, pred=None, target=None, masks=None, seq_lens=None):
+    def __init__(self, pred=None, target=None, seq_lens=None):
         super().__init__()

-        self._init_param_map(pred=pred, target=target,
-                             masks=masks, seq_lens=seq_lens)
+        self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)

         self.total = 0
         self.acc_count = 0

-    def _fast_call_evaluate(self, pred_dict, target_dict):
+    def _fast_param_map(self, pred_dict, target_dict):
         """

         Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
             such as pred_dict has one element, target_dict has one element
         :param pred_dict:
         :param target_dict:
-        :return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
+        :return: dict, if dict is not None, pass it to self.evaluate. Otherwise do mapping.
         """
-        if len(pred_dict) == 1 and len(target_dict) == 1:
-            pred = list(pred_dict.values())[0]
-            target = list(target_dict.values())[0]
-            self.evaluate(pred=pred, target=target)
-            return True
-        return False
-
-    def evaluate(self, pred, target, masks=None, seq_lens=None):
+        fast_param = {}
+        targets = list(target_dict.values())
+        if len(targets)==1 and isinstance(targets[0], torch.Tensor):
+            if len(pred_dict)==1:
+                pred = list(pred_dict.values())[0]
+                fast_param['pred'] = pred
+            elif len(pred_dict)==2:
+                pred1 = list(pred_dict.values())[0]
+                pred2 = list(pred_dict.values())[1]
+                if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)):
+                    return fast_param
+                if len(pred1.size())>len(pred2.size()):
+                    fast_param['pred'] = pred1
+                    fast_param['seq_lens'] = pred2
+            else:
+                return fast_param
+            fast_param['target'] = targets[0]
+        return fast_param
+
+    def evaluate(self, pred, target, seq_lens=None):
         """

         :param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
                 torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
         :param target: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
-        :param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be:
-                None, None, torch.Size([B, max_len], torch.Size([B, max_len])
         :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
         :return: dict({'acc': float})
@@ -221,15 +229,14 @@ class AccuracyMetric(MetricBase):
             raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")

-        if masks is not None and not isinstance(masks, torch.Tensor):
-            raise TypeError(f"`masks` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
-                            f"got {type(masks)}.")
-        elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
+        if seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
             raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_lens)}.")

-        if masks is None and seq_lens is not None:
+        if seq_lens is not None:
             masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
+        else:
+            masks = None

         if pred.size() == target.size():
             pass
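Taken together, the _fast_param_map rework means that when target_dict holds a single tensor and pred_dict holds either one tensor, or two tensors of different rank (the higher-rank one is used as pred, the other as seq_lens), AccuracyMetric fills in the evaluate arguments itself, so no pred/target key mapping is required; the check keyword on __call__ is gone. A rough sketch mirroring the tests further down (import path taken from the file layout above; exact accuracy values depend on the inputs):

import torch
from fastNLP.core.metrics import AccuracyMetric

metric = AccuracyMetric()  # no pred/target key names given

# one pred tensor and one target tensor: the fast path builds the kwargs itself
pred_dict = {"some_output": torch.zeros(4, 3)}   # [B, n_classes]
target_dict = {"some_label": torch.zeros(4)}     # [B]
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())                       # all-zero inputs -> full accuracy

# explicit key mapping still works the same way as before
mapped = AccuracyMetric(pred='predictions', target='targets')
mapped(pred_dict={"predictions": torch.zeros(4, 3, 2)},
       target_dict={"targets": torch.zeros(4, 3)})
print(mapped.get_metric())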


fastNLP/core/trainer.py (+3, -3)

@@ -29,7 +29,7 @@ class Trainer(object):
     """Main Training Loop

     """
-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, update_every=50,
+    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
                  metric_key=None, sampler=RandomSampler(), use_tqdm=True):
@@ -41,7 +41,7 @@ class Trainer(object):
         :param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
         :param int n_epochs: the number of training epochs
         :param int batch_size: batch size for training and validation
-        :param int update_every: step interval to print next training information. Default: -1(no print).
+        :param int print_every: step interval to print next training information. Default: -1(no print).
         :param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
         :param DataSet dev_data: the validation data
         :param use_cuda:
@@ -106,7 +106,7 @@ class Trainer(object):
         self.batch_size = int(batch_size)
         self.use_cuda = bool(use_cuda)
         self.save_path = save_path
-        self.print_every = int(update_every)
+        self.print_every = int(print_every)
         self.validate_every = int(validate_every)
         self.best_metric_indicator = None
         self.sampler = sampler
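For callers the only visible change is the keyword name: the printing interval is now print_every, while the optimizer still steps on every batch, which is what the old name update_every wrongly suggested. A hedged sketch of a call with the renamed argument, mirroring test_trainer_suggestion2 further down (the DataSet/Trainer import paths are taken from that test file):

import numpy as np
from torch import nn
import torch.nn.functional as F

from fastNLP.core.dataset import DataSet
from fastNLP.core.trainer import Trainer

# a tiny dataset and model, as in test_trainer_suggestion2 below
dataset = DataSet(data={'y': np.random.randint(4, size=100),
                        'x1': np.random.randn(100, 5),
                        'x2': np.random.randn(100, 5)})
dataset.set_input('x1', 'x2', 'y', flag=True)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 4)
    def forward(self, x1, x2, y):
        x = self.fc(x1) + self.fc(x2)
        return {'loss': F.cross_entropy(x, y)}

trainer = Trainer(train_data=dataset,
                  model=Model(),
                  use_tqdm=False,
                  print_every=2)   # formerly `update_every`
trainer.train()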


fastNLP/core/utils.py (+6, -6)

@@ -214,7 +214,7 @@ class CheckError(Exception):
     """

     def __init__(self, check_res: CheckRes, func_signature: str):
-        errs = [f'The following problems occurred when calling `{func_signature}`']
+        errs = [f'Problems occurred when calling `{func_signature}`']

         if check_res.varargs:
             errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
@@ -276,8 +276,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
                         f"target is {list(target_dict.keys())}).")
             if _miss_out_dataset:
                 _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
-                        f"target is {list(target_dict.keys())}) or output it "
-                        f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
+                        f"target has {list(target_dict.keys())}) or output it "
+                        f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).")
             if _unused_field:
                 _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
             suggestions.append(_tmp)
@@ -291,7 +291,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         errs.extend(unuseds)

     if len(errs) > 0:
-        errs.insert(0, f'The following problems occurred when calling {func_signature}')
+        errs.insert(0, f'Problems occurred when calling {func_signature}')
         sugg_str = ""
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
@@ -341,7 +341,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
         errs.extend(_unused)

     if len(errs) > 0:
-        errs.insert(0, f'The following problems occurred when calling {func_signature}')
+        errs.insert(0, f'Problems occurred when calling {func_signature}')
         sugg_str = ""
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
@@ -356,7 +356,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
             warnings.warn(message=_unused_warn)


-def seq_lens_to_masks(seq_lens, float=True):
+def seq_lens_to_masks(seq_lens, float=False):
     """

     Convert seq_lens to masks.
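The last hunk flips the default of the float flag on seq_lens_to_masks to False; note that AccuracyMetric.evaluate above keeps asking for float masks explicitly with float=True. A small sketch of what the helper is expected to produce for a batch of lengths (the exact dtypes returned are assumed, since the function body is not shown in this diff):

import torch
from fastNLP.core.utils import seq_lens_to_masks

seq_lens = torch.LongTensor([3, 1])

masks = seq_lens_to_masks(seq_lens)                    # default is now float=False
float_masks = seq_lens_to_masks(seq_lens, float=True)  # what AccuracyMetric requests

# both should have shape [batch_size, max_len], roughly:
# [[1, 1, 1],
#  [1, 0, 0]]
print(masks)
print(float_masks)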


test/core/test_metrics.py (+111, -116)

@@ -6,131 +6,126 @@ import torch
 import numpy as np

 class TestAccuracyMetric(unittest.TestCase):
-    # def test_AccuracyMetric1(self):
-    #     # (1) only input, targets passed
-    #     pred_dict = {"pred": torch.zeros(4, 3)}
-    #     target_dict = {'target': torch.zeros(4)}
-    #     metric = AccuracyMetric()
-    #
-    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #     print(metric.get_metric())
-    #
-    # def test_AccuracyMetric2(self):
-    #     # (2) with corrupted size
-    #     try:
-    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    #         target_dict = {'target': torch.zeros(4)}
-    #         metric = AccuracyMetric()
-    #
-    #         metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #         print(metric.get_metric())
-    #     except Exception as e:
-    #         print(e)
-    #         return
-    #     self.assertTrue(True, False), "No exception catches."
-    #
-    # def test_AccuracyMetric3(self):
-    #     # (3) with check=False , the second batch is corrupted size
-    #     try:
-    #         metric = AccuracyMetric()
-    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    #         target_dict = {'target': torch.zeros(4, 3)}
-    #         metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #
-    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    #         target_dict = {'target': torch.zeros(4)}
-    #         metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #
-    #         print(metric.get_metric())
-    #     except Exception as e:
-    #         print(e)
-    #         return
-    #     self.assertTrue(True, False), "No exception catches."
-    #
-    # def test_AccuracyMetric4(self):
-    #     # (4) with check=True , the second batch is corrupted size
-    #     try:
-    #         metric = AccuracyMetric()
-    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    #         target_dict = {'target': torch.zeros(4, 3)}
-    #         metric(pred_dict=pred_dict, target_dict=target_dict)
-    #
-    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    #         target_dict = {'target': torch.zeros(4)}
-    #         metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #
-    #         print(metric.get_metric())
-    #
-    #     except Exception as e:
-    #         print(e)
-    #         return
-    #     self.assertTrue(True, False), "No exception catches."
+    def test_AccuracyMetric1(self):
+        # (1) only input, targets passed
+        pred_dict = {"pred": torch.zeros(4, 3)}
+        target_dict = {'target': torch.zeros(4)}
+        metric = AccuracyMetric()
+
+        metric(pred_dict=pred_dict, target_dict=target_dict, )
+        print(metric.get_metric())
     #
-    # def test_AccuaryMetric5(self):
-    #     # (5) check reset
-    #     metric = AccuracyMetric()
-    #     pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    #     target_dict = {'target': torch.zeros(4, 3)}
-    #     metric(pred_dict=pred_dict, target_dict=target_dict)
-    #     self.assertDictEqual(metric.get_metric(), {'acc': 1})
+    def test_AccuracyMetric2(self):
+        # (2) with corrupted size
+        try:
+            pred_dict = {"pred": torch.zeros(4, 3, 2)}
+            target_dict = {'target': torch.zeros(4)}
+            metric = AccuracyMetric()
+
+            metric(pred_dict=pred_dict, target_dict=target_dict, )
+            print(metric.get_metric())
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."
     #
-    #     pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    #     target_dict = {'target': torch.zeros(4, 3)+1}
-    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #     self.assertDictEqual(metric.get_metric(), {'acc':0})
+    def test_AccuracyMetric3(self):
+        # (3) the second batch is corrupted size
+        try:
+            metric = AccuracyMetric()
+            pred_dict = {"pred": torch.zeros(4, 3, 2)}
+            target_dict = {'target': torch.zeros(4, 3)}
+            metric(pred_dict=pred_dict, target_dict=target_dict)
+
+            pred_dict = {"pred": torch.zeros(4, 3, 2)}
+            target_dict = {'target': torch.zeros(4)}
+            metric(pred_dict=pred_dict, target_dict=target_dict)
+
+            print(metric.get_metric())
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."
+
     #
-    # def test_AccuaryMetric6(self):
-    #     # (6) check numpy array is not acceptable
-    #     try:
-    #         metric = AccuracyMetric()
-    #         pred_dict = {"pred": np.zeros((4, 3, 2))}
-    #         target_dict = {'target': np.zeros((4, 3))}
-    #         metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #         self.assertDictEqual(metric.get_metric(), {'acc': 1})
-    #     except Exception as e:
-    #         print(e)
-    #         return
-    #     self.assertTrue(True, False), "No exception catches."
-
-    # def test_AccuaryMetric7(self):
-    #     # (7) check map, match
-    #     metric = AccuracyMetric(pred='predictions', target='targets')
-    #     pred_dict = {"predictions": torch.zeros(4, 3, 2)}
-    #     target_dict = {'targets': torch.zeros(4, 3)}
-    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #     self.assertDictEqual(metric.get_metric(), {'acc': 1})
+    def test_AccuaryMetric4(self):
+        # (5) check reset
+        metric = AccuracyMetric()
+        pred_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc': 1})
+
+        pred_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)+1}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc':0})
+
+    def test_AccuaryMetric5(self):
+        # (5) check reset
+        metric = AccuracyMetric()
+        pred_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1})
+
+        pred_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)+1}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc':0.5})
+
     #
-    # def test_AccuaryMetric8(self):
-    #     # (8) check map, does not match
-    #     try:
-    #         metric = AccuracyMetric(pred='predictions', target='targets')
-    #         pred_dict = {"prediction": torch.zeros(4, 3, 2)}
-    #         target_dict = {'targets': torch.zeros(4, 3)}
-    #         metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #         self.assertDictEqual(metric.get_metric(), {'acc': 1})
-    #     except Exception as e:
-    #         print(e)
-    #         return
-    #     self.assertTrue(True, False), "No exception catches."
-
-    # def test_AccuaryMetric9(self):
-    #     # (9) check map, include unused
-    #     try:
-    #         metric = AccuracyMetric(pred='predictions', target='targets')
-    #         pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1}
-    #         target_dict = {'targets': torch.zeros(4, 3)}
-    #         metric(pred_dict=pred_dict, target_dict=target_dict)
-    #         self.assertDictEqual(metric.get_metric(), {'acc': 1})
-    #     except Exception as e:
-    #         print(e)
-    #         return
-    #     self.assertTrue(True, False), "No exception catches."
+    def test_AccuaryMetric6(self):
+        # (6) check numpy array is not acceptable
+        try:
+            metric = AccuracyMetric()
+            pred_dict = {"pred": np.zeros((4, 3, 2))}
+            target_dict = {'target': np.zeros((4, 3))}
+            metric(pred_dict=pred_dict, target_dict=target_dict)
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."
+
+    def test_AccuaryMetric7(self):
+        # (7) check map, match
+        metric = AccuracyMetric(pred='predictions', target='targets')
+        pred_dict = {"predictions": torch.zeros(4, 3, 2)}
+        target_dict = {'targets': torch.zeros(4, 3)}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc': 1})
+
+    def test_AccuaryMetric8(self):
+        # (8) check map, does not match. use stop_fast_param to stop fast param map
+        try:
+            metric = AccuracyMetric(pred='predictions', target='targets')
+            pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param":1}
+            target_dict = {'targets': torch.zeros(4, 3)}
+            metric(pred_dict=pred_dict, target_dict=target_dict, )
+            self.assertDictEqual(metric.get_metric(), {'acc': 1})
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."
+
+    def test_AccuaryMetric9(self):
+        # (9) check map, include unused
+        try:
+            metric = AccuracyMetric(pred='prediction', target='targets')
+            pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1}
+            target_dict = {'targets': torch.zeros(4, 3)}
+            metric(pred_dict=pred_dict, target_dict=target_dict)
+            self.assertDictEqual(metric.get_metric(), {'acc': 1})
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."

     def test_AccuaryMetric10(self):
         # (10) check _fast_metric
         try:
             metric = AccuracyMetric()
-            pred_dict = {"predictions": torch.zeros(4, 3, 2)}
+            pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)}
             target_dict = {'targets': torch.zeros(4, 3)}
             metric(pred_dict=pred_dict, target_dict=target_dict)
             self.assertDictEqual(metric.get_metric(), {'acc': 1})


test/core/test_trainer.py (+116, -13)

@@ -1,6 +1,8 @@
 import unittest

 import numpy as np
+from torch import nn
+import torch.nn.functional as F

 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
@@ -11,19 +13,29 @@ from fastNLP.core.trainer import Trainer
 from fastNLP.models.base_model import NaiveClassifier


-class TrainerTestGround(unittest.TestCase):
-    def test_case(self):
-        mean = np.array([-3, -3])
-        cov = np.array([[1, 0], [0, 1]])
-        class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+def prepare_fake_dataset():
+    mean = np.array([-3, -3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

-        mean = np.array([3, 3])
-        cov = np.array([[1, 0], [0, 1]])
-        class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+    mean = np.array([3, 3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

-        data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
-                           [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                       [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+    return data_set
+
+
+def prepare_fake_dataset2(*args, size=100):
+    ys = np.random.randint(4, size=100)
+    data = {'y': ys}
+    for arg in args:
+        data[arg] = np.random.randn(size, 5)
+    return DataSet(data=data)
+
+
+class TrainerTestGround(unittest.TestCase):
+    def test_case(self):
+        data_set = prepare_fake_dataset()
         data_set.set_input("x", flag=True)
         data_set.set_target("y", flag=True)

@@ -36,10 +48,101 @@ class TrainerTestGround(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           n_epochs=10,
                           batch_size=32,
-                          update_every=1,
-                          validate_every=10,
+                          print_every=50,
+                          validate_every=-1,
                           dev_data=dev_set,
                           optimizer=SGD(lr=0.1),
                           check_code_level=2,
                           use_tqdm=True)
-        trainer.train()
+        trainer.train()
+
+    def test_trainer_suggestion1(self):
+        # Check that the error message gives the user the right hint.
+        # The data needed by forward is not provided here; the trainer should tell the user how to set it up.
+        dataset = prepare_fake_dataset2('x')
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'loss': loss}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model
+        )
+        """
+        # The expected error message:
+        NameError:
+        The following problems occurred when calling Model.forward(self, x1, x2, y)
+            missing param: ['y', 'x1', 'x2']
+            Suggestion: (1). You might need to set ['y'] as input.
+                        (2). You need to provide ['x1', 'x2'] in DataSet and set it as input.
+
+        """
+
+    def test_trainer_suggestion2(self):
+        # Check that the error message gives the user the right hint.
+        # The data needed by forward is provided here; training should run.
+        dataset = prepare_fake_dataset2('x1', 'x2')
+        dataset.set_input('x1', 'x2', 'y', flag=True)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'loss': loss}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            use_tqdm=False,
+            print_every=2
+        )
+        trainer.train()
+        """
+        # Should run correctly
+        """
+
+    def test_trainer_suggestion3(self):
+        # Check that the error message gives the user the right hint.
+        # The data needed by forward is provided here, but forward does not return the key `loss`.
+        dataset = prepare_fake_dataset2('x1', 'x2')
+        dataset.set_input('x1', 'x2', 'y', flag=True)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'wrong_loss_key': loss}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            use_tqdm=False,
+            print_every=2
+        )
+        trainer.train()
+        """
+        # Should run correctly
+        """
+
+
+    def test_case2(self):
+        # check metrics Wrong
+        data_set = prepare_fake_dataset2('x1', 'x2')
