diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 070b1d17..b1fc110b 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -96,7 +96,7 @@ class MetricBase(object):
             will be conducted)
         :param pred_dict: usually the output of forward or prediction function
         :param target_dict: usually features set as target..
-        :param check: boolean, if check is True, it will force check `varargs, missing, unsed, duplicated`.
+        :param check: boolean, if check is True, it will force check `varargs, missing, unused, duplicated`.
         :return:
         """
         if not callable(self.evaluate):
@@ -148,8 +148,8 @@ class MetricBase(object):
             missing = check_res.missing
             replaced_missing = list(missing)
             for idx, func_arg in enumerate(missing):
-                replaced_missing[idx] = f"`{self.param_map[func_arg]}`" + f"(assign to `{func_arg}` " \
-                                        f"in `{get_func_signature(self.evaluate)}`)"
+                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
+                                        f"in `{self.__class__.__name__}`)"

             check_res = CheckRes(missing=replaced_missing,
                                  unused=check_res.unused,
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 0e30ab9b..0ff724c0 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -51,19 +51,18 @@ class Tester(object):
         # turn on the testing mode; clean up the history
         network = self._model
         self._mode(network, is_test=True)
-        output, truths = defaultdict(list), defaultdict(list)
         data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
         eval_results = {}
         try:
             with torch.no_grad():
                 for batch_x, batch_y in data_iterator:
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
-                    prediction = self._data_forward(self._predict_func, batch_x)
-                    if not isinstance(prediction, dict):
+                    pred_dict = self._data_forward(self._predict_func, batch_x)
+                    if not isinstance(pred_dict, dict):
                         raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
-                                        f"must be `dict`, got {type(prediction)}.")
+                                        f"must be `dict`, got {type(pred_dict)}.")
                     for metric in self.metrics:
-                        metric(prediction, batch_y)
+                        metric(pred_dict, batch_y)
                 for metric in self.metrics:
                     eval_result = metric.get_metric()
                     if not isinstance(eval_result, dict):
@@ -74,7 +73,8 @@ class Tester(object):
         except CheckError as e:
             prev_func_signature = get_func_signature(self._predict_func)
             _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
-                                 check_res=e.check_res, output=output, batch_y=truths, check_level=0)
+                                 check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
+                                 dataset=self.data, check_level=0)

         if self.verbose >= 1:
             print("[tester] \n{}".format(self._format_eval_results(eval_results)))
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 20d54073..b24af193 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -311,14 +311,14 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
                              batch_x=batch_x, check_level=check_level)

         refined_batch_x = _build_args(model.forward, **batch_x)
-        output = model(**refined_batch_x)
+        pred_dict = model(**refined_batch_x)
         func_signature = get_func_signature(model.forward)
-        if not isinstance(output, dict):
-            raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")
+        if not isinstance(pred_dict, dict):
+            raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")

         # loss check
         try:
-            loss = losser(output, batch_y)
+            loss = losser(pred_dict, batch_y)
             # check loss output
             if batch_count == 0:
                 if not isinstance(loss, torch.Tensor):
@@ -333,8 +333,8 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         except CheckError as e:
             pre_func_signature = get_func_signature(model.forward)
             _check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
-                                 check_res=e.check_res, output=output, batch_y=batch_y,
-                                 check_level=check_level)
+                                 check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
+                                 dataset=dataset, check_level=check_level)
         model.zero_grad()
         if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
             break
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 95297a54..bfbeb6e5 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -229,29 +229,72 @@ WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2

 def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes,
-                         output:dict, batch_y:dict, check_level=0):
+                         pred_dict:dict, target_dict:dict, dataset, check_level=0):
     errs = []
-    _unused = []
+    unuseds = []
+    _unused_field = []
+    _unused_param = []
+    suggestions = []
     if check_res.varargs:
-        errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, "
-                    f"please delete it.)")
+        errs.append(f"\tvarargs: *{check_res.varargs}")
+        suggestions.append(f"Does not support passing positional arguments, please delete *{check_res.varargs}.")
+
+    if check_res.unused:
+        for _unused in check_res.unused:
+            if _unused in target_dict:
+                _unused_field.append(_unused)
+            else:
+                _unused_param.append(_unused)
+        if _unused_field:
+            unuseds.append(f"\tunused field: {_unused_field}")
+        if _unused_param:
+            unuseds.append(f"\tunused param: {_unused_param}")
+
     if check_res.missing:
-        errs.append(f"\tmissing param: `{check_res.missing}`, provided with `{list(output.keys())}`"
-                    f"(from output of `{prev_func_signature}`) and `{list(batch_y.keys())}`(from targets in Dataset).")
+        errs.append(f"\tmissing param: {check_res.missing}")
+        _miss_in_dataset = []
+        _miss_out_dataset = []
+        for _miss in check_res.missing:
+            if '(' in _miss:
+                # if they are like 'SomeParam(assign to xxx)'
+                _miss = _miss.split('(')[0]
+            if _miss in dataset:
+                _miss_in_dataset.append(_miss)
+            else:
+                _miss_out_dataset.append(_miss)
+
+        if _miss_in_dataset:
+            suggestions.append(f"You might need to set {_miss_in_dataset} as target (Right now "
+                               f"target is {list(target_dict.keys())}).")
+        if _miss_out_dataset:
+            _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target (Right now "
+                    f"target is {list(target_dict.keys())}) or output it "
+                    f"in {prev_func_signature} (Right now it outputs {list(pred_dict.keys())}).")
+            if _unused_field:
+                _tmp += f" You can use DataSet.rename_field() to rename the field in `unused field:`. "
+            suggestions.append(_tmp)
+
     if check_res.duplicated:
-        errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
-                    f"{check_res.duplicated} or do not set {check_res.duplicated} as targets. ")
-    if check_res.unused:
-        _unused = [f"\tunused param: {check_res.unused}"]
-        if check_level == STRICT_CHECK_LEVEL:
-            errs.extend(_unused)
+        errs.append(f"\tduplicated param: {check_res.duplicated}.")
+        suggestions.append(f"Delete {check_res.duplicated} in the output of "
") + + if check_level == STRICT_CHECK_LEVEL: + errs.extend(unuseds) if len(errs)>0: errs.insert(0, f'The following problems occurred when calling {func_signature}') - raise NameError('\n'.join(errs)) - if _unused: + sugg_str = "" + if len(suggestions)>1: + for idx, sugg in enumerate(suggestions): + sugg_str += f'({idx+1}). {sugg}' + else: + sugg_str += suggestions[0] + err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str + raise NameError(err_str) + if check_res.unused: if check_level == WARNING_CHECK_LEVEL: - _unused_warn = _unused[0] + f' in {func_signature}.' + _unused_warn = f'{check_res.unused} is not used by {func_signature}.' warnings.warn(message=_unused_warn) @@ -260,21 +303,45 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): func_signature = get_func_signature(forward_func) errs = [] + suggestions = [] _unused = [] if check_res.varargs: - errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") + errs.append(f"\tvarargs: {check_res.varargs}") + suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") if check_res.missing: - errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}. " - f"Please set {check_res.missing} as input.") + errs.append(f"\tmissing param: {check_res.missing}") + _miss_in_dataset = [] + _miss_out_dataset = [] + for _miss in check_res.missing: + if _miss in dataset: + _miss_in_dataset.append(_miss) + else: + _miss_out_dataset.append(_miss) + if _miss_in_dataset: + suggestions.append(f"You might need to set {_miss_in_dataset} as input. ") + if _miss_out_dataset: + _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. " + if check_res.unused: + _tmp += f"Or you might find it is in `unused field:`, you can use DataSet.rename_field() to " \ + f"rename the field in `unused field:`." + suggestions.append(_tmp) + if check_res.unused: - _unused = [f"\tunused param: {check_res.unused}"] + _unused = [f"\tunused field: {check_res.unused}"] if check_level == STRICT_CHECK_LEVEL: errs.extend(_unused) if len(errs)>0: errs.insert(0, f'The following problems occurred when calling {func_signature}') - raise NameError('\n'.join(errs)) + sugg_str = "" + if len(suggestions)>1: + for idx, sugg in enumerate(suggestions): + sugg_str += f'({idx+1}). {sugg}' + else: + sugg_str += suggestions[0] + err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str + raise NameError(err_str) if _unused: if check_level == WARNING_CHECK_LEVEL: _unused_warn = _unused[0] + f' in {func_signature}.'