Browse Source

bug fix in LossInForward

tags/v0.2.0^2
yh 6 years ago
parent
commit
cd83866527
3 changed files with 18 additions and 13 deletions
  1. +2
    -1
      fastNLP/core/losses.py
  2. +13
    -9
      fastNLP/core/utils.py
  3. +3
    -3
      test/core/test_trainer.py

+ 2
- 1
fastNLP/core/losses.py View File

@@ -221,7 +221,8 @@ class LossInForward(LossBase):

def get_loss(self, **kwargs):
if self.loss_key not in kwargs:
check_res = CheckRes(missing=[self.loss_key],
check_res = CheckRes(missing=[self.loss_key + f"(assign to `{self.loss_key}` " \
f"in `{self.__class__.__name__}`"],
unused=[],
duplicated=[],
required=[],


+ 13
- 9
fastNLP/core/utils.py View File

@@ -257,7 +257,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
if _unused_param:
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward

module_name = ''
module_name = func_signature.split('.')[0]
if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}")
import re
@@ -265,15 +265,19 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
unmapped_missing = []
input_func_map = {}
for _miss in check_res.missing:
fun_arg, module_name = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss)
if '(' in _miss:
# if they are like 'SomeParam(assign to xxx)'
_miss = _miss.split('(')[0]
input_func_map[_miss] = fun_arg
if fun_arg == _miss:
unmapped_missing.append(_miss)
matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss)
if len(matches) == 2:
fun_arg, module_name = matches
input_func_map[_miss] = fun_arg
if fun_arg == _miss:
unmapped_missing.append(_miss)
else:
mapped_missing.append(_miss)
else:
mapped_missing.append(_miss)
unmapped_missing.append(_miss)

for _miss in mapped_missing:
if _miss in dataset:
@@ -281,7 +285,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
else:
_tmp = ''
if check_res.unused:
_tmp = f"Check key assignment for `{input_func_map[_miss]}` when initialize {module_name}."
_tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}."
if _tmp:
_tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
else:
@@ -293,11 +297,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
else:
_tmp = ''
if check_res.unused:
_tmp = f"Specify your assignment for `{input_func_map[_miss]}` when initialize {module_name}."
_tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}."
if _tmp:
_tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
else:
_tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.'
_tmp = f'Provide {_miss} in output of {prev_func_signature} or DataSet.'
suggestions.append(_tmp)

if check_res.duplicated:


+ 3
- 3
test/core/test_trainer.py View File

@@ -159,8 +159,8 @@ class TrainerTestGround(unittest.TestCase):
def test_trainer_suggestion4(self):
# 检查报错提示能否正确提醒用户
# 这里传入forward需要的数据,是否可以正确提示unused
dataset = prepare_fake_dataset2('x1', 'x_unused')
dataset.set_input('x1', 'x_unused', 'y', flag=True)
dataset = prepare_fake_dataset2('x1', 'x2')
dataset.set_input('x1', 'x2', 'y', flag=True)
class Model(nn.Module):
def __init__(self):
super().__init__()
@@ -170,7 +170,7 @@ class TrainerTestGround(unittest.TestCase):
x2 = self.fc(x2)
x = x1 + x2
loss = F.cross_entropy(x, y)
return {'loss': loss}
return {'losses': loss}

model = Model()
with self.assertRaises(NameError):


Loading…
Cancel
Save