2. trainer中update_every改为print_every, 因为update_every可能引起optimizer update的误解 3. fieldarray content支持使用np.ndarray初始化tags/v0.2.0^2
@@ -17,6 +17,12 @@ class FieldArray(object): | |||
:param bool is_input: If True, this FieldArray is used to the model input. | |||
""" | |||
self.name = name | |||
if isinstance(content, list): | |||
content = content | |||
elif isinstance(content, np.ndarray): | |||
content = content.tolist() | |||
else: | |||
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | |||
self.content = content | |||
self.padding_val = padding_val | |||
self.is_target = is_target | |||
@@ -61,8 +61,7 @@ class MetricBase(object): | |||
if func_param not in func_args: | |||
raise NameError( | |||
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " | |||
f"initialization parameters, or change the signature of" | |||
f" {get_func_signature(self.evaluate)}.") | |||
f"initialization parameters, or change its signature.") | |||
# evaluate should not have varargs. | |||
if func_spect.varargs: | |||
@@ -79,13 +78,14 @@ class MetricBase(object): | |||
such as pred_dict has one element, target_dict has one element | |||
:param pred_dict: | |||
:param target_dict: | |||
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on. | |||
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | |||
""" | |||
fast_param = {} | |||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||
return pred_dict.values[0] and target_dict.values[0] | |||
return None | |||
return fast_param | |||
def __call__(self, pred_dict, target_dict, check=False): | |||
def __call__(self, pred_dict, target_dict): | |||
""" | |||
This method will call self.evaluate method. | |||
@@ -96,20 +96,19 @@ class MetricBase(object): | |||
(4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) | |||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | |||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | |||
will be conducted) | |||
will be conducted.) | |||
This function also support _fast_param_map. | |||
:param pred_dict: usually the output of forward or prediction function | |||
:param target_dict: usually features set as target.. | |||
:param check: boolean, if check is True, it will force check `varargs, missing, unused, duplicated`. | |||
:return: | |||
""" | |||
if not callable(self.evaluate): | |||
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | |||
if not check: | |||
fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict) | |||
if fast_param is not None: | |||
self.evaluate(*fast_param) | |||
return | |||
fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict) | |||
if fast_param: | |||
self.evaluate(**fast_param) | |||
return | |||
if not self._checked: | |||
# 1. check consistence between signature and param_map | |||
@@ -147,7 +146,7 @@ class MetricBase(object): | |||
duplicated.append(input_arg) | |||
# missing | |||
if check or not self._checked: | |||
if not self._checked: | |||
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) | |||
# only check missing. | |||
missing = check_res.missing | |||
@@ -175,40 +174,49 @@ class MetricBase(object): | |||
class AccuracyMetric(MetricBase): | |||
def __init__(self, pred=None, target=None, masks=None, seq_lens=None): | |||
def __init__(self, pred=None, target=None, seq_lens=None): | |||
super().__init__() | |||
self._init_param_map(pred=pred, target=target, | |||
masks=masks, seq_lens=seq_lens) | |||
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) | |||
self.total = 0 | |||
self.acc_count = 0 | |||
def _fast_call_evaluate(self, pred_dict, target_dict): | |||
def _fast_param_map(self, pred_dict, target_dict): | |||
""" | |||
Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | |||
such as pred_dict has one element, target_dict has one element | |||
:param pred_dict: | |||
:param target_dict: | |||
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on. | |||
:return: dict, if dict is not None, pass it to self.evaluate. Otherwise do mapping. | |||
""" | |||
if len(pred_dict) == 1 and len(target_dict) == 1: | |||
pred = list(pred_dict.values())[0] | |||
target = list(target_dict.values())[0] | |||
self.evaluate(pred=pred, target=target) | |||
return True | |||
return False | |||
def evaluate(self, pred, target, masks=None, seq_lens=None): | |||
fast_param = {} | |||
targets = list(target_dict.values()) | |||
if len(targets)==1 and isinstance(targets[0], torch.Tensor): | |||
if len(pred_dict)==1: | |||
pred = list(pred_dict.values())[0] | |||
fast_param['pred'] = pred | |||
elif len(pred_dict)==2: | |||
pred1 = list(pred_dict.values())[0] | |||
pred2 = list(pred_dict.values())[1] | |||
if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): | |||
return fast_param | |||
if len(pred1.size())>len(pred2.size()): | |||
fast_param['pred'] = pred1 | |||
fast_param['seq_lens'] = pred2 | |||
else: | |||
return fast_param | |||
fast_param['target'] = targets[0] | |||
return fast_param | |||
def evaluate(self, pred, target, seq_lens=None): | |||
""" | |||
:param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: | |||
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) | |||
:param target: List of (torch.Tensor, or numpy.ndarray). Element's can be: | |||
torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len]) | |||
:param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be: | |||
None, None, torch.Size([B, max_len], torch.Size([B, max_len]) | |||
:param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be: | |||
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. | |||
:return: dict({'acc': float}) | |||
@@ -221,15 +229,14 @@ class AccuracyMetric(MetricBase): | |||
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(target)}.") | |||
if masks is not None and not isinstance(masks, torch.Tensor): | |||
raise TypeError(f"`masks` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(masks)}.") | |||
elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor): | |||
if seq_lens is not None and not isinstance(seq_lens, torch.Tensor): | |||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_lens)}.") | |||
if masks is None and seq_lens is not None: | |||
if seq_lens is not None: | |||
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) | |||
else: | |||
masks = None | |||
if pred.size() == target.size(): | |||
pass | |||
@@ -29,7 +29,7 @@ class Trainer(object): | |||
"""Main Training Loop | |||
""" | |||
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, update_every=50, | |||
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||
validate_every=-1, dev_data=None, use_cuda=False, save_path=None, | |||
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | |||
metric_key=None, sampler=RandomSampler(), use_tqdm=True): | |||
@@ -41,7 +41,7 @@ class Trainer(object): | |||
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics | |||
:param int n_epochs: the number of training epochs | |||
:param int batch_size: batch size for training and validation | |||
:param int update_every: step interval to print next training information. Default: -1(no print). | |||
:param int print_every: step interval to print next training information. Default: -1(no print). | |||
:param int validate_every: step interval to do next validation. Default: -1(validate every epoch). | |||
:param DataSet dev_data: the validation data | |||
:param use_cuda: | |||
@@ -106,7 +106,7 @@ class Trainer(object): | |||
self.batch_size = int(batch_size) | |||
self.use_cuda = bool(use_cuda) | |||
self.save_path = save_path | |||
self.print_every = int(update_every) | |||
self.print_every = int(print_every) | |||
self.validate_every = int(validate_every) | |||
self.best_metric_indicator = None | |||
self.sampler = sampler | |||
@@ -214,7 +214,7 @@ class CheckError(Exception): | |||
""" | |||
def __init__(self, check_res: CheckRes, func_signature: str): | |||
errs = [f'The following problems occurred when calling `{func_signature}`'] | |||
errs = [f'Problems occurred when calling `{func_signature}`'] | |||
if check_res.varargs: | |||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | |||
@@ -276,8 +276,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
f"target is {list(target_dict.keys())}).") | |||
if _miss_out_dataset: | |||
_tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now " | |||
f"target is {list(target_dict.keys())}) or output it " | |||
f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).") | |||
f"target has {list(target_dict.keys())}) or output it " | |||
f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).") | |||
if _unused_field: | |||
_tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. " | |||
suggestions.append(_tmp) | |||
@@ -291,7 +291,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
errs.extend(unuseds) | |||
if len(errs) > 0: | |||
errs.insert(0, f'The following problems occurred when calling {func_signature}') | |||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||
sugg_str = "" | |||
if len(suggestions) > 1: | |||
for idx, sugg in enumerate(suggestions): | |||
@@ -341,7 +341,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
errs.extend(_unused) | |||
if len(errs) > 0: | |||
errs.insert(0, f'The following problems occurred when calling {func_signature}') | |||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||
sugg_str = "" | |||
if len(suggestions) > 1: | |||
for idx, sugg in enumerate(suggestions): | |||
@@ -356,7 +356,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
warnings.warn(message=_unused_warn) | |||
def seq_lens_to_masks(seq_lens, float=True): | |||
def seq_lens_to_masks(seq_lens, float=False): | |||
""" | |||
Convert seq_lens to masks. | |||
@@ -6,131 +6,126 @@ import torch | |||
import numpy as np | |||
class TestAccuracyMetric(unittest.TestCase): | |||
# def test_AccuracyMetric1(self): | |||
# # (1) only input, targets passed | |||
# pred_dict = {"pred": torch.zeros(4, 3)} | |||
# target_dict = {'target': torch.zeros(4)} | |||
# metric = AccuracyMetric() | |||
# | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# print(metric.get_metric()) | |||
# | |||
# def test_AccuracyMetric2(self): | |||
# # (2) with corrupted size | |||
# try: | |||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
# target_dict = {'target': torch.zeros(4)} | |||
# metric = AccuracyMetric() | |||
# | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# print(metric.get_metric()) | |||
# except Exception as e: | |||
# print(e) | |||
# return | |||
# self.assertTrue(True, False), "No exception catches." | |||
# | |||
# def test_AccuracyMetric3(self): | |||
# # (3) with check=False , the second batch is corrupted size | |||
# try: | |||
# metric = AccuracyMetric() | |||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
# target_dict = {'target': torch.zeros(4, 3)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# | |||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
# target_dict = {'target': torch.zeros(4)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# | |||
# print(metric.get_metric()) | |||
# except Exception as e: | |||
# print(e) | |||
# return | |||
# self.assertTrue(True, False), "No exception catches." | |||
# | |||
# def test_AccuracyMetric4(self): | |||
# # (4) with check=True , the second batch is corrupted size | |||
# try: | |||
# metric = AccuracyMetric() | |||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
# target_dict = {'target': torch.zeros(4, 3)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# | |||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
# target_dict = {'target': torch.zeros(4)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# | |||
# print(metric.get_metric()) | |||
# | |||
# except Exception as e: | |||
# print(e) | |||
# return | |||
# self.assertTrue(True, False), "No exception catches." | |||
def test_AccuracyMetric1(self): | |||
# (1) only input, targets passed | |||
pred_dict = {"pred": torch.zeros(4, 3)} | |||
target_dict = {'target': torch.zeros(4)} | |||
metric = AccuracyMetric() | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
print(metric.get_metric()) | |||
# | |||
# def test_AccuaryMetric5(self): | |||
# # (5) check reset | |||
# metric = AccuracyMetric() | |||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
# target_dict = {'target': torch.zeros(4, 3)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
def test_AccuracyMetric2(self): | |||
# (2) with corrupted size | |||
try: | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4)} | |||
metric = AccuracyMetric() | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
print(metric.get_metric()) | |||
except Exception as e: | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
# | |||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
# target_dict = {'target': torch.zeros(4, 3)+1} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# self.assertDictEqual(metric.get_metric(), {'acc':0}) | |||
def test_AccuracyMetric3(self): | |||
# (3) the second batch is corrupted size | |||
try: | |||
metric = AccuracyMetric() | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
print(metric.get_metric()) | |||
except Exception as e: | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
# | |||
# def test_AccuaryMetric6(self): | |||
# # (6) check numpy array is not acceptable | |||
# try: | |||
# metric = AccuracyMetric() | |||
# pred_dict = {"pred": np.zeros((4, 3, 2))} | |||
# target_dict = {'target': np.zeros((4, 3))} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
# except Exception as e: | |||
# print(e) | |||
# return | |||
# self.assertTrue(True, False), "No exception catches." | |||
# def test_AccuaryMetric7(self): | |||
# # (7) check map, match | |||
# metric = AccuracyMetric(pred='predictions', target='targets') | |||
# pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||
# target_dict = {'targets': torch.zeros(4, 3)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
def test_AccuaryMetric4(self): | |||
# (5) check reset | |||
metric = AccuracyMetric() | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4, 3)+1} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc':0}) | |||
def test_AccuaryMetric5(self): | |||
# (5) check reset | |||
metric = AccuracyMetric() | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4, 3)+1} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc':0.5}) | |||
# | |||
# def test_AccuaryMetric8(self): | |||
# # (8) check map, does not match | |||
# try: | |||
# metric = AccuracyMetric(pred='predictions', target='targets') | |||
# pred_dict = {"prediction": torch.zeros(4, 3, 2)} | |||
# target_dict = {'targets': torch.zeros(4, 3)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
# except Exception as e: | |||
# print(e) | |||
# return | |||
# self.assertTrue(True, False), "No exception catches." | |||
# def test_AccuaryMetric9(self): | |||
# # (9) check map, include unused | |||
# try: | |||
# metric = AccuracyMetric(pred='predictions', target='targets') | |||
# pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} | |||
# target_dict = {'targets': torch.zeros(4, 3)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
# except Exception as e: | |||
# print(e) | |||
# return | |||
# self.assertTrue(True, False), "No exception catches." | |||
def test_AccuaryMetric6(self): | |||
# (6) check numpy array is not acceptable | |||
try: | |||
metric = AccuracyMetric() | |||
pred_dict = {"pred": np.zeros((4, 3, 2))} | |||
target_dict = {'target': np.zeros((4, 3))} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
except Exception as e: | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_AccuaryMetric7(self): | |||
# (7) check map, match | |||
metric = AccuracyMetric(pred='predictions', target='targets') | |||
pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
def test_AccuaryMetric8(self): | |||
# (8) check map, does not match. use stop_fast_param to stop fast param map | |||
try: | |||
metric = AccuracyMetric(pred='predictions', target='targets') | |||
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param":1} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
except Exception as e: | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_AccuaryMetric9(self): | |||
# (9) check map, include unused | |||
try: | |||
metric = AccuracyMetric(pred='prediction', target='targets') | |||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
except Exception as e: | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_AccuaryMetric10(self): | |||
# (10) check _fast_metric | |||
try: | |||
metric = AccuracyMetric() | |||
pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
@@ -1,6 +1,8 @@ | |||
import unittest | |||
import numpy as np | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
@@ -11,19 +13,29 @@ from fastNLP.core.trainer import Trainer | |||
from fastNLP.models.base_model import NaiveClassifier | |||
class TrainerTestGround(unittest.TestCase): | |||
def test_case(self): | |||
mean = np.array([-3, -3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
def prepare_fake_dataset(): | |||
mean = np.array([-3, -3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
mean = np.array([3, 3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
mean = np.array([3, 3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||
return data_set | |||
def prepare_fake_dataset2(*args, size=100): | |||
ys = np.random.randint(4, size=100) | |||
data = {'y': ys} | |||
for arg in args: | |||
data[arg] = np.random.randn(size, 5) | |||
return DataSet(data=data) | |||
class TrainerTestGround(unittest.TestCase): | |||
def test_case(self): | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x", flag=True) | |||
data_set.set_target("y", flag=True) | |||
@@ -36,10 +48,101 @@ class TrainerTestGround(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
n_epochs=10, | |||
batch_size=32, | |||
update_every=1, | |||
validate_every=10, | |||
print_every=50, | |||
validate_every=-1, | |||
dev_data=dev_set, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=True) | |||
trainer.train() | |||
trainer.train() | |||
def test_trainer_suggestion1(self): | |||
# 检查报错提示能否正确提醒用户。 | |||
# 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 | |||
dataset = prepare_fake_dataset2('x') | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(5, 4) | |||
def forward(self, x1, x2, y): | |||
x1 = self.fc(x1) | |||
x2 = self.fc(x2) | |||
x = x1 + x2 | |||
loss = F.cross_entropy(x, y) | |||
return {'loss': loss} | |||
model = Model() | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model | |||
) | |||
""" | |||
# 应该获取到的报错提示 | |||
NameError: | |||
The following problems occurred when calling Model.forward(self, x1, x2, y) | |||
missing param: ['y', 'x1', 'x2'] | |||
Suggestion: (1). You might need to set ['y'] as input. | |||
(2). You need to provide ['x1', 'x2'] in DataSet and set it as input. | |||
""" | |||
def test_trainer_suggestion2(self): | |||
# 检查报错提示能否正确提醒用户 | |||
# 这里传入forward需要的数据,看是否可以运行 | |||
dataset = prepare_fake_dataset2('x1', 'x2') | |||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(5, 4) | |||
def forward(self, x1, x2, y): | |||
x1 = self.fc(x1) | |||
x2 = self.fc(x2) | |||
x = x1 + x2 | |||
loss = F.cross_entropy(x, y) | |||
return {'loss': loss} | |||
model = Model() | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
trainer.train() | |||
""" | |||
# 应该正确运行 | |||
""" | |||
def test_trainer_suggestion3(self): | |||
# 检查报错提示能否正确提醒用户 | |||
# 这里传入forward需要的数据,但是forward没有返回loss这个key | |||
dataset = prepare_fake_dataset2('x1', 'x2') | |||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(5, 4) | |||
def forward(self, x1, x2, y): | |||
x1 = self.fc(x1) | |||
x2 = self.fc(x2) | |||
x = x1 + x2 | |||
loss = F.cross_entropy(x, y) | |||
return {'wrong_loss_key': loss} | |||
model = Model() | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
trainer.train() | |||
""" | |||
# 应该正确运行 | |||
""" | |||
def test_case2(self): | |||
# check metrics Wrong | |||
data_set = prepare_fake_dataset2('x1', 'x2') |