diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 2be6e2fa..2a6458c6 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -237,14 +237,10 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No # loss check if batch_count == 0: - _dict = _check_arg_dict_list(model.loss, [output, batch_y]) - if len(_dict) != 0: - pass - loss_input = _build_args(model.loss, **output, **batch_y) - loss = model.loss(**loss_input) - if batch_count == 0: - if isinstance(loss, torch.Tensor): - pass + _check_loss(model=model, model_func=model.get_loss, check_level=check_level, + output=output, batch_y=batch_y) + loss_input = _build_args(model.get_loss, **output, **batch_y) + loss = model.get_loss(**loss_input) # check loss output if batch_count == 0: @@ -281,7 +277,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No truths[k].append(v) if batch_count+1>DEFAULT_CHECK_NUM_BATCH: break - _check_loss_evaluate(model=model, model_func=model.evaluate, check_level=check_level, + _check_loss(model=model, model_func=model.evaluate, check_level=check_level, output=outputs, batch_y=truths) refined_input = _build_args(model.evaluate, **outputs, **truths) metrics = model.evaluate(**refined_input) @@ -323,16 +319,17 @@ def _check_forward_error(model, model_func, check_level, batch_x): elif check_level == WARNING_CHECK_LEVEL: warnings.warn(message=_unused) -def _check_loss_evaluate(model, model_func, check_level, output, batch_y): +def _check_loss(model, model_func, check_level, output, batch_y): check_res = _check_arg_dict_list(model_func, [output, batch_y]) _missing = '' _unused = '' _duplicated = '' signature_str = get_func_signature(model_func) - func_signature = "{}.{}(self, {})".format(model.__class__.__name__, model_func.__name__, signature_str[1:-1]) - forward_signature_str = get_func_signature(model.forward) - forward_func_signature = "{}.forward(self, {})".format(model.__class__.__name__, 
def _check_evaluate(model, model_func, check_level, output, batch_y):
    """Check that ``output`` and ``batch_y`` together match ``model_func``'s arguments.

    The argument check itself is delegated to ``_check_arg_dict_list``; this
    function only turns its result (``missing`` / ``unused`` / ``duplicated``)
    into readable messages and decides whether to warn or raise.

    NOTE(review): in the visible call site the check for ``model.evaluate`` is
    routed through ``_check_loss`` rather than this function -- confirm which
    of the two is meant to be used.

    :param model: the model object; its class name and its ``predict`` (or,
        when absent, ``forward``) method are used only to build messages.
    :param model_func: the callable whose arguments are being checked.
    :param check_level: compared against ``STRICT_CHECK_LEVEL`` and
        ``WARNING_CHECK_LEVEL``. Unused arguments are an error under the
        former and a warning under the latter; any other value ignores them.
    :param output: dict-like of values produced by the model.
    :param batch_y: dict-like of target values.
    :raises ValueError: if arguments are missing or duplicated, or if there
        are unused arguments and ``check_level`` is ``STRICT_CHECK_LEVEL``.
    """
    check_res = _check_arg_dict_list(model_func, [output, batch_y])

    model_name = model.__class__.__name__
    # The signature string is re-wrapped as "(self, ...)", so its first and
    # last characters (presumably the surrounding parentheses) are dropped.
    func_signature = "{}.{}(self, {})".format(
        model_name, model_func.__name__, get_func_signature(model_func)[1:-1])

    # ``output`` comes from ``predict`` when the model defines it, otherwise
    # from ``forward``; messages point the user at whichever one applies.
    if hasattr(model, 'predict'):
        previous_func, previous_func_name = model.predict, 'predict'
    else:
        previous_func, previous_func_name = model.forward, 'forward'
    previous_func_signature = "{}.{}(self, {})".format(
        model_name, previous_func_name, get_func_signature(previous_func)[1:-1])

    _missing = ''
    _unused = ''
    _duplicated = ''

    if len(check_res.missing) > 0:
        _missing = "Function {} misses argument {}, only provided with {}(from {}) and " \
                   "{}." \
            .format(func_signature, check_res.missing,
                    list(output.keys()), previous_func_signature,
                    list(batch_y.keys()))

    if len(check_res.unused) > 0:
        verb = "are" if len(check_res.unused) > 1 else "is"
        _unused = "{} {} not used in function {}.\n".format(
            check_res.unused, verb, func_signature)

    if len(check_res.duplicated) > 0:
        if len(check_res.duplicated) > 1:
            noun, verb, pronoun = "keys", "are", "them"
        else:
            noun, verb, pronoun = "key", "is", "it"
        _duplicated = "Duplicated {} {} {} detected when calling function {}. \nDon't set {} as target and output " \
                      "{} in {} at the same time.\n".format(noun, check_res.duplicated, verb,
                                                            func_signature,
                                                            check_res.duplicated,
                                                            pronoun,
                                                            previous_func_signature)

    # Unused arguments are only a warning at WARNING level. Emit it even when
    # other problems are about to raise, so the notice is never lost.
    if _unused and check_level == WARNING_CHECK_LEVEL:
        warnings.warn(_unused.strip())

    # Collect only the problems that are actually reported as errors.
    problems = [msg for msg in (_missing, _duplicated) if msg]
    if _unused and check_level == STRICT_CHECK_LEVEL:
        problems.append(_unused)

    if not problems:
        return

    # TODO: a dedicated Error type may be more appropriate than ValueError here.
    if len(problems) == 1:
        raise ValueError(problems[0])

    # Several problems: number them and keep each entry on its own line
    # (the "missing" message has no trailing newline of its own).
    numbered = []
    for count, msg in enumerate(problems, start=1):
        entry = '({}).{}'.format(count, msg)
        numbered.append(entry if entry.endswith('\n') else entry + '\n')
    raise ValueError(''.join(numbered))