Browse Source

对trainer中check code的报错信息进行了增强;将tester中的output修改为pred_dict

tags/v0.2.0^2
yh 6 years ago
parent
commit
77f8ac77da
4 changed files with 102 additions and 35 deletions
  1. +3
    -3
      fastNLP/core/metrics.py
  2. +6
    -6
      fastNLP/core/tester.py
  3. +6
    -6
      fastNLP/core/trainer.py
  4. +87
    -20
      fastNLP/core/utils.py

+ 3
- 3
fastNLP/core/metrics.py View File

@@ -96,7 +96,7 @@ class MetricBase(object):
will be conducted) will be conducted)
:param pred_dict: usually the output of forward or prediction function :param pred_dict: usually the output of forward or prediction function
:param target_dict: usually features set as target.. :param target_dict: usually features set as target..
:param check: boolean, if check is True, it will force check `varargs, missing, unsed, duplicated`.
:param check: boolean, if check is True, it will force check `varargs, missing, unused, duplicated`.
:return: :return:
""" """
if not callable(self.evaluate): if not callable(self.evaluate):
@@ -148,8 +148,8 @@ class MetricBase(object):
missing = check_res.missing missing = check_res.missing
replaced_missing = list(missing) replaced_missing = list(missing)
for idx, func_arg in enumerate(missing): for idx, func_arg in enumerate(missing):
replaced_missing[idx] = f"`{self.param_map[func_arg]}`" + f"(assign to `{func_arg}` " \
f"in `{get_func_signature(self.evaluate)}`)"
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"


check_res = CheckRes(missing=replaced_missing, check_res = CheckRes(missing=replaced_missing,
unused=check_res.unused, unused=check_res.unused,


+ 6
- 6
fastNLP/core/tester.py View File

@@ -51,19 +51,18 @@ class Tester(object):
# turn on the testing mode; clean up the history # turn on the testing mode; clean up the history
network = self._model network = self._model
self._mode(network, is_test=True) self._mode(network, is_test=True)
output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False) data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
eval_results = {} eval_results = {}
try: try:
with torch.no_grad(): with torch.no_grad():
for batch_x, batch_y in data_iterator: for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(self._predict_func, batch_x)
if not isinstance(prediction, dict):
pred_dict = self._data_forward(self._predict_func, batch_x)
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {get_func_signature(self._predict_func)} " raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
f"must be `dict`, got {type(prediction)}.")
f"must be `dict`, got {type(pred_dict)}.")
for metric in self.metrics: for metric in self.metrics:
metric(prediction, batch_y)
metric(pred_dict, batch_y)
for metric in self.metrics: for metric in self.metrics:
eval_result = metric.get_metric() eval_result = metric.get_metric()
if not isinstance(eval_result, dict): if not isinstance(eval_result, dict):
@@ -74,7 +73,8 @@ class Tester(object):
except CheckError as e: except CheckError as e:
prev_func_signature = get_func_signature(self._predict_func) prev_func_signature = get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, output=output, batch_y=truths, check_level=0)
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=self.data, check_level=0)


if self.verbose >= 1: if self.verbose >= 1:
print("[tester] \n{}".format(self._format_eval_results(eval_results))) print("[tester] \n{}".format(self._format_eval_results(eval_results)))


+ 6
- 6
fastNLP/core/trainer.py View File

@@ -311,14 +311,14 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
batch_x=batch_x, check_level=check_level) batch_x=batch_x, check_level=check_level)


refined_batch_x = _build_args(model.forward, **batch_x) refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)
pred_dict = model(**refined_batch_x)
func_signature = get_func_signature(model.forward) func_signature = get_func_signature(model.forward)
if not isinstance(output, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")


# loss check # loss check
try: try:
loss = losser(output, batch_y)
loss = losser(pred_dict, batch_y)
# check loss output # check loss output
if batch_count == 0: if batch_count == 0:
if not isinstance(loss, torch.Tensor): if not isinstance(loss, torch.Tensor):
@@ -333,8 +333,8 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
except CheckError as e: except CheckError as e:
pre_func_signature = get_func_signature(model.forward) pre_func_signature = get_func_signature(model.forward)
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature, _check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
check_res=e.check_res, output=output, batch_y=batch_y,
check_level=check_level)
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=dataset, check_level=check_level)
model.zero_grad() model.zero_grad()
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
break break


+ 87
- 20
fastNLP/core/utils.py View File

@@ -229,29 +229,72 @@ WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2 STRICT_CHECK_LEVEL = 2


def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes, def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes,
output:dict, batch_y:dict, check_level=0):
pred_dict:dict, target_dict:dict, dataset, check_level=0):
errs = [] errs = []
_unused = []
unuseds = []
_unused_field = []
_unused_param = []
suggestions = []
if check_res.varargs: if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, "
f"please delete it.)")
errs.append(f"\tvarargs: *{check_res.varargs}")
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")

if check_res.unused:
for _unused in check_res.unused:
if _unused in target_dict:
_unused_field.append(_unused)
else:
_unused_param.append(_unused)
if _unused_field:
unuseds.append([f"\tunused field: {_unused_field}"])
if _unused_param:
unuseds.append([f"\tunused param: {_unused_param}"])

if check_res.missing: if check_res.missing:
errs.append(f"\tmissing param: `{check_res.missing}`, provided with `{list(output.keys())}`"
f"(from output of `{prev_func_signature}`) and `{list(batch_y.keys())}`(from targets in Dataset).")
errs.append(f"\tmissing param: {check_res.missing}")
_miss_in_dataset = []
_miss_out_dataset = []
for _miss in check_res.missing:
if '(' in _miss:
# if they are like 'SomeParam(assign to xxx)'
_miss = _miss.split('(')[0]
if _miss in dataset:
_miss_in_dataset.append(_miss)
else:
_miss_out_dataset.append(_miss)

if _miss_in_dataset:
suggestions.append(f"You might need to set {_miss_in_dataset} as target(Right now "
f"target is {list(target_dict.keys())}).")
if _miss_out_dataset:
_tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
f"target is {list(target_dict.keys())}) or output it "
f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
if _unused_field:
_tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
suggestions.append(_tmp)

if check_res.duplicated: if check_res.duplicated:
errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
f"{check_res.duplicated} or do not set {check_res.duplicated} as targets. ")
if check_res.unused:
_unused = [f"\tunused param: {check_res.unused}"]
if check_level == STRICT_CHECK_LEVEL:
errs.extend(_unused)
errs.append(f"\tduplicated param: {check_res.duplicated}.")
suggestions.append(f"Delete {check_res.duplicated} in the output of "
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
if check_level == STRICT_CHECK_LEVEL:
errs.extend(unuseds)


if len(errs)>0: if len(errs)>0:
errs.insert(0, f'The following problems occurred when calling {func_signature}') errs.insert(0, f'The following problems occurred when calling {func_signature}')
raise NameError('\n'.join(errs))
if _unused:
sugg_str = ""
if len(suggestions)>1:
for idx, sugg in enumerate(suggestions):
sugg_str += f'({idx+1}). {sugg}'
else:
sugg_str += suggestions[0]
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
raise NameError(err_str)
if check_res.unused:
if check_level == WARNING_CHECK_LEVEL: if check_level == WARNING_CHECK_LEVEL:
_unused_warn = _unused[0] + f' in {func_signature}.'
_unused_warn = f'{check_res.unused} is not used by {func_signature}.'
warnings.warn(message=_unused_warn) warnings.warn(message=_unused_warn)




@@ -260,21 +303,45 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
func_signature = get_func_signature(forward_func) func_signature = get_func_signature(forward_func)


errs = [] errs = []
suggestions = []
_unused = [] _unused = []


if check_res.varargs: if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
errs.append(f"\tvarargs: {check_res.varargs}")
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
if check_res.missing: if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}. "
f"Please set {check_res.missing} as input.")
errs.append(f"\tmissing param: {check_res.missing}")
_miss_in_dataset = []
_miss_out_dataset = []
for _miss in check_res.missing:
if _miss in dataset:
_miss_in_dataset.append(_miss)
else:
_miss_out_dataset.append(_miss)
if _miss_in_dataset:
suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
if _miss_out_dataset:
_tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
if check_res.unused:
_tmp += f"Or you might find it is in `unused field:`, you can use DataSet.rename_field() to " \
f"rename the field in `unused field:`."
suggestions.append(_tmp)

if check_res.unused: if check_res.unused:
_unused = [f"\tunused param: {check_res.unused}"]
_unused = [f"\tunused field: {check_res.unused}"]
if check_level == STRICT_CHECK_LEVEL: if check_level == STRICT_CHECK_LEVEL:
errs.extend(_unused) errs.extend(_unused)


if len(errs)>0: if len(errs)>0:
errs.insert(0, f'The following problems occurred when calling {func_signature}') errs.insert(0, f'The following problems occurred when calling {func_signature}')
raise NameError('\n'.join(errs))
sugg_str = ""
if len(suggestions)>1:
for idx, sugg in enumerate(suggestions):
sugg_str += f'({idx+1}). {sugg}'
else:
sugg_str += suggestions[0]
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
raise NameError(err_str)
if _unused: if _unused:
if check_level == WARNING_CHECK_LEVEL: if check_level == WARNING_CHECK_LEVEL:
_unused_warn = _unused[0] + f' in {func_signature}.' _unused_warn = _unused[0] + f' in {func_signature}.'


Loading…
Cancel
Save