From 1fb1df4a31da9204412dc6f4d3b89a0b8594a9b2 Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 4 Dec 2018 10:43:40 +0800 Subject: [PATCH] =?UTF-8?q?1.=20metric=E4=BF=AE=E6=94=B9fast=5Fparam=202.?= =?UTF-8?q?=20trainer=E4=B8=ADupdate=5Fevery=E6=94=B9=E4=B8=BAprint=5Fever?= =?UTF-8?q?y,=20=E5=9B=A0=E4=B8=BAupdate=5Fevery=E5=8F=AF=E8=83=BD?= =?UTF-8?q?=E5=BC=95=E8=B5=B7optimizer=20update=E7=9A=84=E8=AF=AF=E8=A7=A3?= =?UTF-8?q?=203.=20fieldarray=20content=E6=94=AF=E6=8C=81=E4=BD=BF?= =?UTF-8?q?=E7=94=A8np.ndarray=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/fieldarray.py | 6 + fastNLP/core/metrics.py | 73 ++++++------ fastNLP/core/trainer.py | 6 +- fastNLP/core/utils.py | 12 +- test/core/test_metrics.py | 227 ++++++++++++++++++------------------- test/core/test_trainer.py | 129 ++++++++++++++++++--- 6 files changed, 282 insertions(+), 171 deletions(-) diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 14c52829..1b1a89c1 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -17,6 +17,12 @@ class FieldArray(object): :param bool is_input: If True, this FieldArray is used to the model input. """ self.name = name + if isinstance(content, list): + content = content + elif isinstance(content, np.ndarray): + content = content.tolist() + else: + raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) self.content = content self.padding_val = padding_val self.is_target = is_target diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index d83c4022..ff40e4e4 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -61,8 +61,7 @@ class MetricBase(object): if func_param not in func_args: raise NameError( f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " - f"initialization parameters, or change the signature of" - f" {get_func_signature(self.evaluate)}.") + f"initialization parameters, or change its signature.") # evaluate should not have varargs. if func_spect.varargs: @@ -79,13 +78,14 @@ class MetricBase(object): such as pred_dict has one element, target_dict has one element :param pred_dict: :param target_dict: - :return: boolean, whether to go on codes in self.__call__(). When False, don't go on. + :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. """ + fast_param = {} if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: return pred_dict.values[0] and target_dict.values[0] - return None + return fast_param - def __call__(self, pred_dict, target_dict, check=False): + def __call__(self, pred_dict, target_dict): """ This method will call self.evaluate method. @@ -96,20 +96,19 @@ class MetricBase(object): (4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) Besides, before passing params into self.evaluate, this function will filter out params from output_dict and target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering - will be conducted) + will be conducted.) + This function also support _fast_param_map. :param pred_dict: usually the output of forward or prediction function :param target_dict: usually features set as target.. - :param check: boolean, if check is True, it will force check `varargs, missing, unused, duplicated`. :return: """ if not callable(self.evaluate): raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") - if not check: - fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict) - if fast_param is not None: - self.evaluate(*fast_param) - return + fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict) + if fast_param: + self.evaluate(**fast_param) + return if not self._checked: # 1. check consistence between signature and param_map @@ -147,7 +146,7 @@ class MetricBase(object): duplicated.append(input_arg) # missing - if check or not self._checked: + if not self._checked: check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) # only check missing. missing = check_res.missing @@ -175,40 +174,49 @@ class MetricBase(object): class AccuracyMetric(MetricBase): - def __init__(self, pred=None, target=None, masks=None, seq_lens=None): + def __init__(self, pred=None, target=None, seq_lens=None): super().__init__() - self._init_param_map(pred=pred, target=target, - masks=masks, seq_lens=seq_lens) + self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) self.total = 0 self.acc_count = 0 - def _fast_call_evaluate(self, pred_dict, target_dict): + def _fast_param_map(self, pred_dict, target_dict): """ Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. such as pred_dict has one element, target_dict has one element :param pred_dict: :param target_dict: - :return: boolean, whether to go on codes in self.__call__(). When False, don't go on. + :return: dict, if dict is not None, pass it to self.evaluate. Otherwise do mapping. """ - if len(pred_dict) == 1 and len(target_dict) == 1: - pred = list(pred_dict.values())[0] - target = list(target_dict.values())[0] - self.evaluate(pred=pred, target=target) - return True - return False - - def evaluate(self, pred, target, masks=None, seq_lens=None): + fast_param = {} + targets = list(target_dict.values()) + if len(targets)==1 and isinstance(targets[0], torch.Tensor): + if len(pred_dict)==1: + pred = list(pred_dict.values())[0] + fast_param['pred'] = pred + elif len(pred_dict)==2: + pred1 = list(pred_dict.values())[0] + pred2 = list(pred_dict.values())[1] + if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): + return fast_param + if len(pred1.size())>len(pred2.size()): + fast_param['pred'] = pred1 + fast_param['seq_lens'] = pred2 + else: + return fast_param + fast_param['target'] = targets[0] + return fast_param + + def evaluate(self, pred, target, seq_lens=None): """ :param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) :param target: List of (torch.Tensor, or numpy.ndarray). Element's can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len]) - :param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be: - None, None, torch.Size([B, max_len], torch.Size([B, max_len]) :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be: None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. :return: dict({'acc': float}) @@ -221,15 +229,14 @@ class AccuracyMetric(MetricBase): raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(target)}.") - if masks is not None and not isinstance(masks, torch.Tensor): - raise TypeError(f"`masks` in {get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(masks)}.") - elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor): + if seq_lens is not None and not isinstance(seq_lens, torch.Tensor): raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(seq_lens)}.") - if masks is None and seq_lens is not None: + if seq_lens is not None: masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) + else: + masks = None if pred.size() == target.size(): pass diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 57c79369..a0069571 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -29,7 +29,7 @@ class Trainer(object): """Main Training Loop """ - def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, update_every=50, + def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, validate_every=-1, dev_data=None, use_cuda=False, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True): @@ -41,7 +41,7 @@ class Trainer(object): :param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics :param int n_epochs: the number of training epochs :param int batch_size: batch size for training and validation - :param int update_every: step interval to print next training information. Default: -1(no print). + :param int print_every: step interval to print next training information. Default: -1(no print). :param int validate_every: step interval to do next validation. Default: -1(validate every epoch). :param DataSet dev_data: the validation data :param use_cuda: @@ -106,7 +106,7 @@ class Trainer(object): self.batch_size = int(batch_size) self.use_cuda = bool(use_cuda) self.save_path = save_path - self.print_every = int(update_every) + self.print_every = int(print_every) self.validate_every = int(validate_every) self.best_metric_indicator = None self.sampler = sampler diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 9fc091a7..4fd5eaec 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -214,7 +214,7 @@ class CheckError(Exception): """ def __init__(self, check_res: CheckRes, func_signature: str): - errs = [f'The following problems occurred when calling `{func_signature}`'] + errs = [f'Problems occurred when calling `{func_signature}`'] if check_res.varargs: errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") @@ -276,8 +276,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re f"target is {list(target_dict.keys())}).") if _miss_out_dataset: _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now " - f"target is {list(target_dict.keys())}) or output it " - f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).") + f"target has {list(target_dict.keys())}) or output it " + f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).") if _unused_field: _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. " suggestions.append(_tmp) @@ -291,7 +291,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re errs.extend(unuseds) if len(errs) > 0: - errs.insert(0, f'The following problems occurred when calling {func_signature}') + errs.insert(0, f'Problems occurred when calling {func_signature}') sugg_str = "" if len(suggestions) > 1: for idx, sugg in enumerate(suggestions): @@ -341,7 +341,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): errs.extend(_unused) if len(errs) > 0: - errs.insert(0, f'The following problems occurred when calling {func_signature}') + errs.insert(0, f'Problems occurred when calling {func_signature}') sugg_str = "" if len(suggestions) > 1: for idx, sugg in enumerate(suggestions): @@ -356,7 +356,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): warnings.warn(message=_unused_warn) -def seq_lens_to_masks(seq_lens, float=True): +def seq_lens_to_masks(seq_lens, float=False): """ Convert seq_lens to masks. diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index ffc11401..1b8ae70b 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -6,131 +6,126 @@ import torch import numpy as np class TestAccuracyMetric(unittest.TestCase): - # def test_AccuracyMetric1(self): - # # (1) only input, targets passed - # pred_dict = {"pred": torch.zeros(4, 3)} - # target_dict = {'target': torch.zeros(4)} - # metric = AccuracyMetric() - # - # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) - # print(metric.get_metric()) - # - # def test_AccuracyMetric2(self): - # # (2) with corrupted size - # try: - # pred_dict = {"pred": torch.zeros(4, 3, 2)} - # target_dict = {'target': torch.zeros(4)} - # metric = AccuracyMetric() - # - # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) - # print(metric.get_metric()) - # except Exception as e: - # print(e) - # return - # self.assertTrue(True, False), "No exception catches." - # - # def test_AccuracyMetric3(self): - # # (3) with check=False , the second batch is corrupted size - # try: - # metric = AccuracyMetric() - # pred_dict = {"pred": torch.zeros(4, 3, 2)} - # target_dict = {'target': torch.zeros(4, 3)} - # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) - # - # pred_dict = {"pred": torch.zeros(4, 3, 2)} - # target_dict = {'target': torch.zeros(4)} - # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) - # - # print(metric.get_metric()) - # except Exception as e: - # print(e) - # return - # self.assertTrue(True, False), "No exception catches." - # - # def test_AccuracyMetric4(self): - # # (4) with check=True , the second batch is corrupted size - # try: - # metric = AccuracyMetric() - # pred_dict = {"pred": torch.zeros(4, 3, 2)} - # target_dict = {'target': torch.zeros(4, 3)} - # metric(pred_dict=pred_dict, target_dict=target_dict) - # - # pred_dict = {"pred": torch.zeros(4, 3, 2)} - # target_dict = {'target': torch.zeros(4)} - # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) - # - # print(metric.get_metric()) - # - # except Exception as e: - # print(e) - # return - # self.assertTrue(True, False), "No exception catches." + def test_AccuracyMetric1(self): + # (1) only input, targets passed + pred_dict = {"pred": torch.zeros(4, 3)} + target_dict = {'target': torch.zeros(4)} + metric = AccuracyMetric() + + metric(pred_dict=pred_dict, target_dict=target_dict, ) + print(metric.get_metric()) # - # def test_AccuaryMetric5(self): - # # (5) check reset - # metric = AccuracyMetric() - # pred_dict = {"pred": torch.zeros(4, 3, 2)} - # target_dict = {'target': torch.zeros(4, 3)} - # metric(pred_dict=pred_dict, target_dict=target_dict) - # self.assertDictEqual(metric.get_metric(), {'acc': 1}) + def test_AccuracyMetric2(self): + # (2) with corrupted size + try: + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4)} + metric = AccuracyMetric() + + metric(pred_dict=pred_dict, target_dict=target_dict, ) + print(metric.get_metric()) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." # - # pred_dict = {"pred": torch.zeros(4, 3, 2)} - # target_dict = {'target': torch.zeros(4, 3)+1} - # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) - # self.assertDictEqual(metric.get_metric(), {'acc':0}) + def test_AccuracyMetric3(self): + # (3) the second batch is corrupted size + try: + metric = AccuracyMetric() + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4)} + metric(pred_dict=pred_dict, target_dict=target_dict) + + print(metric.get_metric()) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." + # - # def test_AccuaryMetric6(self): - # # (6) check numpy array is not acceptable - # try: - # metric = AccuracyMetric() - # pred_dict = {"pred": np.zeros((4, 3, 2))} - # target_dict = {'target': np.zeros((4, 3))} - # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) - # self.assertDictEqual(metric.get_metric(), {'acc': 1}) - # except Exception as e: - # print(e) - # return - # self.assertTrue(True, False), "No exception catches." - - # def test_AccuaryMetric7(self): - # # (7) check map, match - # metric = AccuracyMetric(pred='predictions', target='targets') - # pred_dict = {"predictions": torch.zeros(4, 3, 2)} - # target_dict = {'targets': torch.zeros(4, 3)} - # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) - # self.assertDictEqual(metric.get_metric(), {'acc': 1}) + def test_AccuaryMetric4(self): + # (5) check reset + metric = AccuracyMetric() + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + self.assertDictEqual(metric.get_metric(), {'acc': 1}) + + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)+1} + metric(pred_dict=pred_dict, target_dict=target_dict) + self.assertDictEqual(metric.get_metric(), {'acc':0}) + + def test_AccuaryMetric5(self): + # (5) check reset + metric = AccuracyMetric() + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) + + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)+1} + metric(pred_dict=pred_dict, target_dict=target_dict) + self.assertDictEqual(metric.get_metric(), {'acc':0.5}) + # - # def test_AccuaryMetric8(self): - # # (8) check map, does not match - # try: - # metric = AccuracyMetric(pred='predictions', target='targets') - # pred_dict = {"prediction": torch.zeros(4, 3, 2)} - # target_dict = {'targets': torch.zeros(4, 3)} - # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) - # self.assertDictEqual(metric.get_metric(), {'acc': 1}) - # except Exception as e: - # print(e) - # return - # self.assertTrue(True, False), "No exception catches." - - # def test_AccuaryMetric9(self): - # # (9) check map, include unused - # try: - # metric = AccuracyMetric(pred='predictions', target='targets') - # pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} - # target_dict = {'targets': torch.zeros(4, 3)} - # metric(pred_dict=pred_dict, target_dict=target_dict) - # self.assertDictEqual(metric.get_metric(), {'acc': 1}) - # except Exception as e: - # print(e) - # return - # self.assertTrue(True, False), "No exception catches." + def test_AccuaryMetric6(self): + # (6) check numpy array is not acceptable + try: + metric = AccuracyMetric() + pred_dict = {"pred": np.zeros((4, 3, 2))} + target_dict = {'target': np.zeros((4, 3))} + metric(pred_dict=pred_dict, target_dict=target_dict) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." + + def test_AccuaryMetric7(self): + # (7) check map, match + metric = AccuracyMetric(pred='predictions', target='targets') + pred_dict = {"predictions": torch.zeros(4, 3, 2)} + target_dict = {'targets': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + self.assertDictEqual(metric.get_metric(), {'acc': 1}) + + def test_AccuaryMetric8(self): + # (8) check map, does not match. use stop_fast_param to stop fast param map + try: + metric = AccuracyMetric(pred='predictions', target='targets') + pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param":1} + target_dict = {'targets': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict, ) + self.assertDictEqual(metric.get_metric(), {'acc': 1}) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." + + def test_AccuaryMetric9(self): + # (9) check map, include unused + try: + metric = AccuracyMetric(pred='prediction', target='targets') + pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} + target_dict = {'targets': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + self.assertDictEqual(metric.get_metric(), {'acc': 1}) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." def test_AccuaryMetric10(self): # (10) check _fast_metric try: metric = AccuracyMetric() - pred_dict = {"predictions": torch.zeros(4, 3, 2)} + pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)} target_dict = {'targets': torch.zeros(4, 3)} metric(pred_dict=pred_dict, target_dict=target_dict) self.assertDictEqual(metric.get_metric(), {'acc': 1}) diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py index 2975f39c..ed4cc38d 100644 --- a/test/core/test_trainer.py +++ b/test/core/test_trainer.py @@ -1,6 +1,8 @@ import unittest import numpy as np +from torch import nn +import torch.nn.functional as F from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance @@ -11,19 +13,29 @@ from fastNLP.core.trainer import Trainer from fastNLP.models.base_model import NaiveClassifier -class TrainerTestGround(unittest.TestCase): - def test_case(self): - mean = np.array([-3, -3]) - cov = np.array([[1, 0], [0, 1]]) - class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) +def prepare_fake_dataset(): + mean = np.array([-3, -3]) + cov = np.array([[1, 0], [0, 1]]) + class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) - mean = np.array([3, 3]) - cov = np.array([[1, 0], [0, 1]]) - class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) + mean = np.array([3, 3]) + cov = np.array([[1, 0], [0, 1]]) + class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) - data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + - [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) + data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + + [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) + return data_set +def prepare_fake_dataset2(*args, size=100): + ys = np.random.randint(4, size=100) + data = {'y': ys} + for arg in args: + data[arg] = np.random.randn(size, 5) + return DataSet(data=data) + +class TrainerTestGround(unittest.TestCase): + def test_case(self): + data_set = prepare_fake_dataset() data_set.set_input("x", flag=True) data_set.set_target("y", flag=True) @@ -36,10 +48,101 @@ class TrainerTestGround(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), n_epochs=10, batch_size=32, - update_every=1, - validate_every=10, + print_every=50, + validate_every=-1, dev_data=dev_set, optimizer=SGD(lr=0.1), check_code_level=2, use_tqdm=True) - trainer.train() \ No newline at end of file + trainer.train() + + def test_trainer_suggestion1(self): + # 检查报错提示能否正确提醒用户。 + # 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 + dataset = prepare_fake_dataset2('x') + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(5, 4) + def forward(self, x1, x2, y): + x1 = self.fc(x1) + x2 = self.fc(x2) + x = x1 + x2 + loss = F.cross_entropy(x, y) + return {'loss': loss} + + model = Model() + trainer = Trainer( + train_data=dataset, + model=model + ) + """ + # 应该获取到的报错提示 + NameError: + The following problems occurred when calling Model.forward(self, x1, x2, y) + missing param: ['y', 'x1', 'x2'] + Suggestion: (1). You might need to set ['y'] as input. + (2). You need to provide ['x1', 'x2'] in DataSet and set it as input. + + """ + + def test_trainer_suggestion2(self): + # 检查报错提示能否正确提醒用户 + # 这里传入forward需要的数据,看是否可以运行 + dataset = prepare_fake_dataset2('x1', 'x2') + dataset.set_input('x1', 'x2', 'y', flag=True) + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(5, 4) + def forward(self, x1, x2, y): + x1 = self.fc(x1) + x2 = self.fc(x2) + x = x1 + x2 + loss = F.cross_entropy(x, y) + return {'loss': loss} + + model = Model() + trainer = Trainer( + train_data=dataset, + model=model, + use_tqdm=False, + print_every=2 + ) + trainer.train() + """ + # 应该正确运行 + """ + + def test_trainer_suggestion3(self): + # 检查报错提示能否正确提醒用户 + # 这里传入forward需要的数据,但是forward没有返回loss这个key + dataset = prepare_fake_dataset2('x1', 'x2') + dataset.set_input('x1', 'x2', 'y', flag=True) + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(5, 4) + def forward(self, x1, x2, y): + x1 = self.fc(x1) + x2 = self.fc(x2) + x = x1 + x2 + loss = F.cross_entropy(x, y) + return {'wrong_loss_key': loss} + + model = Model() + trainer = Trainer( + train_data=dataset, + model=model, + use_tqdm=False, + print_every=2 + ) + trainer.train() + """ + # 应该正确运行 + """ + + + def test_case2(self): + # check metrics Wrong + data_set = prepare_fake_dataset2('x1', 'x2')