|
@@ -229,29 +229,72 @@ WARNING_CHECK_LEVEL = 1 |
|
|
STRICT_CHECK_LEVEL = 2 |
|
|
STRICT_CHECK_LEVEL = 2 |
|
|
|
|
|
|
|
|
def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes, |
|
|
def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes, |
|
|
output:dict, batch_y:dict, check_level=0): |
|
|
|
|
|
|
|
|
pred_dict:dict, target_dict:dict, dataset, check_level=0): |
|
|
errs = [] |
|
|
errs = [] |
|
|
_unused = [] |
|
|
|
|
|
|
|
|
unuseds = [] |
|
|
|
|
|
_unused_field = [] |
|
|
|
|
|
_unused_param = [] |
|
|
|
|
|
suggestions = [] |
|
|
if check_res.varargs: |
|
|
if check_res.varargs: |
|
|
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, " |
|
|
|
|
|
f"please delete it.)") |
|
|
|
|
|
|
|
|
errs.append(f"\tvarargs: *{check_res.varargs}") |
|
|
|
|
|
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") |
|
|
|
|
|
|
|
|
|
|
|
if check_res.unused: |
|
|
|
|
|
for _unused in check_res.unused: |
|
|
|
|
|
if _unused in target_dict: |
|
|
|
|
|
_unused_field.append(_unused) |
|
|
|
|
|
else: |
|
|
|
|
|
_unused_param.append(_unused) |
|
|
|
|
|
if _unused_field: |
|
|
|
|
|
unuseds.append([f"\tunused field: {_unused_field}"]) |
|
|
|
|
|
if _unused_param: |
|
|
|
|
|
unuseds.append([f"\tunused param: {_unused_param}"]) |
|
|
|
|
|
|
|
|
if check_res.missing: |
|
|
if check_res.missing: |
|
|
errs.append(f"\tmissing param: `{check_res.missing}`, provided with `{list(output.keys())}`" |
|
|
|
|
|
f"(from output of `{prev_func_signature}`) and `{list(batch_y.keys())}`(from targets in Dataset).") |
|
|
|
|
|
|
|
|
errs.append(f"\tmissing param: {check_res.missing}") |
|
|
|
|
|
_miss_in_dataset = [] |
|
|
|
|
|
_miss_out_dataset = [] |
|
|
|
|
|
for _miss in check_res.missing: |
|
|
|
|
|
if '(' in _miss: |
|
|
|
|
|
# if they are like 'SomeParam(assign to xxx)' |
|
|
|
|
|
_miss = _miss.split('(')[0] |
|
|
|
|
|
if _miss in dataset: |
|
|
|
|
|
_miss_in_dataset.append(_miss) |
|
|
|
|
|
else: |
|
|
|
|
|
_miss_out_dataset.append(_miss) |
|
|
|
|
|
|
|
|
|
|
|
if _miss_in_dataset: |
|
|
|
|
|
suggestions.append(f"You might need to set {_miss_in_dataset} as target(Right now " |
|
|
|
|
|
f"target is {list(target_dict.keys())}).") |
|
|
|
|
|
if _miss_out_dataset: |
|
|
|
|
|
_tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now " |
|
|
|
|
|
f"target is {list(target_dict.keys())}) or output it " |
|
|
|
|
|
f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).") |
|
|
|
|
|
if _unused_field: |
|
|
|
|
|
_tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. " |
|
|
|
|
|
suggestions.append(_tmp) |
|
|
|
|
|
|
|
|
if check_res.duplicated: |
|
|
if check_res.duplicated: |
|
|
errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of " |
|
|
|
|
|
f"{check_res.duplicated} or do not set {check_res.duplicated} as targets. ") |
|
|
|
|
|
if check_res.unused: |
|
|
|
|
|
_unused = [f"\tunused param: {check_res.unused}"] |
|
|
|
|
|
if check_level == STRICT_CHECK_LEVEL: |
|
|
|
|
|
errs.extend(_unused) |
|
|
|
|
|
|
|
|
errs.append(f"\tduplicated param: {check_res.duplicated}.") |
|
|
|
|
|
suggestions.append(f"Delete {check_res.duplicated} in the output of " |
|
|
|
|
|
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") |
|
|
|
|
|
|
|
|
|
|
|
if check_level == STRICT_CHECK_LEVEL: |
|
|
|
|
|
errs.extend(unuseds) |
|
|
|
|
|
|
|
|
if len(errs)>0: |
|
|
if len(errs)>0: |
|
|
errs.insert(0, f'The following problems occurred when calling {func_signature}') |
|
|
errs.insert(0, f'The following problems occurred when calling {func_signature}') |
|
|
raise NameError('\n'.join(errs)) |
|
|
|
|
|
if _unused: |
|
|
|
|
|
|
|
|
sugg_str = "" |
|
|
|
|
|
if len(suggestions)>1: |
|
|
|
|
|
for idx, sugg in enumerate(suggestions): |
|
|
|
|
|
sugg_str += f'({idx+1}). {sugg}' |
|
|
|
|
|
else: |
|
|
|
|
|
sugg_str += suggestions[0] |
|
|
|
|
|
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str |
|
|
|
|
|
raise NameError(err_str) |
|
|
|
|
|
if check_res.unused: |
|
|
if check_level == WARNING_CHECK_LEVEL: |
|
|
if check_level == WARNING_CHECK_LEVEL: |
|
|
_unused_warn = _unused[0] + f' in {func_signature}.' |
|
|
|
|
|
|
|
|
_unused_warn = f'{check_res.unused} is not used by {func_signature}.' |
|
|
warnings.warn(message=_unused_warn) |
|
|
warnings.warn(message=_unused_warn) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -260,21 +303,45 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): |
|
|
func_signature = get_func_signature(forward_func) |
|
|
func_signature = get_func_signature(forward_func) |
|
|
|
|
|
|
|
|
errs = [] |
|
|
errs = [] |
|
|
|
|
|
suggestions = [] |
|
|
_unused = [] |
|
|
_unused = [] |
|
|
|
|
|
|
|
|
if check_res.varargs: |
|
|
if check_res.varargs: |
|
|
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") |
|
|
|
|
|
|
|
|
errs.append(f"\tvarargs: {check_res.varargs}") |
|
|
|
|
|
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") |
|
|
if check_res.missing: |
|
|
if check_res.missing: |
|
|
errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}. " |
|
|
|
|
|
f"Please set {check_res.missing} as input.") |
|
|
|
|
|
|
|
|
errs.append(f"\tmissing param: {check_res.missing}") |
|
|
|
|
|
_miss_in_dataset = [] |
|
|
|
|
|
_miss_out_dataset = [] |
|
|
|
|
|
for _miss in check_res.missing: |
|
|
|
|
|
if _miss in dataset: |
|
|
|
|
|
_miss_in_dataset.append(_miss) |
|
|
|
|
|
else: |
|
|
|
|
|
_miss_out_dataset.append(_miss) |
|
|
|
|
|
if _miss_in_dataset: |
|
|
|
|
|
suggestions.append(f"You might need to set {_miss_in_dataset} as input. ") |
|
|
|
|
|
if _miss_out_dataset: |
|
|
|
|
|
_tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. " |
|
|
|
|
|
if check_res.unused: |
|
|
|
|
|
_tmp += f"Or you might find it is in `unused field:`, you can use DataSet.rename_field() to " \ |
|
|
|
|
|
f"rename the field in `unused field:`." |
|
|
|
|
|
suggestions.append(_tmp) |
|
|
|
|
|
|
|
|
if check_res.unused: |
|
|
if check_res.unused: |
|
|
_unused = [f"\tunused param: {check_res.unused}"] |
|
|
|
|
|
|
|
|
_unused = [f"\tunused field: {check_res.unused}"] |
|
|
if check_level == STRICT_CHECK_LEVEL: |
|
|
if check_level == STRICT_CHECK_LEVEL: |
|
|
errs.extend(_unused) |
|
|
errs.extend(_unused) |
|
|
|
|
|
|
|
|
if len(errs)>0: |
|
|
if len(errs)>0: |
|
|
errs.insert(0, f'The following problems occurred when calling {func_signature}') |
|
|
errs.insert(0, f'The following problems occurred when calling {func_signature}') |
|
|
raise NameError('\n'.join(errs)) |
|
|
|
|
|
|
|
|
sugg_str = "" |
|
|
|
|
|
if len(suggestions)>1: |
|
|
|
|
|
for idx, sugg in enumerate(suggestions): |
|
|
|
|
|
sugg_str += f'({idx+1}). {sugg}' |
|
|
|
|
|
else: |
|
|
|
|
|
sugg_str += suggestions[0] |
|
|
|
|
|
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str |
|
|
|
|
|
raise NameError(err_str) |
|
|
if _unused: |
|
|
if _unused: |
|
|
if check_level == WARNING_CHECK_LEVEL: |
|
|
if check_level == WARNING_CHECK_LEVEL: |
|
|
_unused_warn = _unused[0] + f' in {func_signature}.' |
|
|
_unused_warn = _unused[0] + f' in {func_signature}.' |
|
|