From cd83866527c8b947f072d473660623343aee3919 Mon Sep 17 00:00:00 2001 From: yh Date: Thu, 6 Dec 2018 11:16:25 +0800 Subject: [PATCH] bug fix in LossInForward --- fastNLP/core/losses.py | 3 ++- fastNLP/core/utils.py | 22 +++++++++++++--------- test/core/test_trainer.py | 6 +++--- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index a4976540..fbd64e81 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -221,7 +221,8 @@ class LossInForward(LossBase): def get_loss(self, **kwargs): if self.loss_key not in kwargs: - check_res = CheckRes(missing=[self.loss_key], + check_res = CheckRes(missing=[self.loss_key + f"(assign to `{self.loss_key}` " \ + f"in `{self.__class__.__name__}`"], unused=[], duplicated=[], required=[], diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 508d5587..c58e4f71 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -257,7 +257,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re if _unused_param: unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward - module_name = '' + module_name = func_signature.split('.')[0] if check_res.missing: errs.append(f"\tmissing param: {check_res.missing}") import re @@ -265,15 +265,19 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re unmapped_missing = [] input_func_map = {} for _miss in check_res.missing: - fun_arg, module_name = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss) if '(' in _miss: # if they are like 'SomeParam(assign to xxx)' _miss = _miss.split('(')[0] - input_func_map[_miss] = fun_arg - if fun_arg == _miss: - unmapped_missing.append(_miss) + matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss) + if len(matches) == 2: + fun_arg, module_name = matches + input_func_map[_miss] = fun_arg + if fun_arg == _miss: + unmapped_missing.append(_miss) + else: + mapped_missing.append(_miss) else: - mapped_missing.append(_miss) + unmapped_missing.append(_miss) for _miss in mapped_missing: if _miss in dataset: @@ -281,7 +285,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re else: _tmp = '' if check_res.unused: - _tmp = f"Check key assignment for `{input_func_map[_miss]}` when initialize {module_name}." + _tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." if _tmp: _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.' else: @@ -293,11 +297,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re else: _tmp = '' if check_res.unused: - _tmp = f"Specify your assignment for `{input_func_map[_miss]}` when initialize {module_name}." + _tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." if _tmp: _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.' else: - _tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.' + _tmp = f'Provide {_miss} in output of {prev_func_signature} or DataSet.' suggestions.append(_tmp) if check_res.duplicated: diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py index 6f6fbbf3..2f2505e4 100644 --- a/test/core/test_trainer.py +++ b/test/core/test_trainer.py @@ -159,8 +159,8 @@ class TrainerTestGround(unittest.TestCase): def test_trainer_suggestion4(self): # 检查报错提示能否正确提醒用户 # 这里传入forward需要的数据,是否可以正确提示unused - dataset = prepare_fake_dataset2('x1', 'x_unused') - dataset.set_input('x1', 'x_unused', 'y', flag=True) + dataset = prepare_fake_dataset2('x1', 'x2') + dataset.set_input('x1', 'x2', 'y', flag=True) class Model(nn.Module): def __init__(self): super().__init__() @@ -170,7 +170,7 @@ class TrainerTestGround(unittest.TestCase): x2 = self.fc(x2) x = x1 + x2 loss = F.cross_entropy(x, y) - return {'loss': loss} + return {'losses': loss} model = Model() with self.assertRaises(NameError):