|
@@ -237,14 +237,10 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No |
|
|
|
|
|
|
|
|
# loss check |
|
|
# loss check |
|
|
if batch_count == 0: |
|
|
if batch_count == 0: |
|
|
_dict = _check_arg_dict_list(model.loss, [output, batch_y]) |
|
|
|
|
|
if len(_dict) != 0: |
|
|
|
|
|
pass |
|
|
|
|
|
loss_input = _build_args(model.loss, **output, **batch_y) |
|
|
|
|
|
loss = model.loss(**loss_input) |
|
|
|
|
|
if batch_count == 0: |
|
|
|
|
|
if isinstance(loss, torch.Tensor): |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
_check_loss(model=model, model_func=model.get_loss, check_level=check_level, |
|
|
|
|
|
output=output, batch_y=batch_y) |
|
|
|
|
|
loss_input = _build_args(model.get_loss, **output, **batch_y) |
|
|
|
|
|
loss = model.get_loss(**loss_input) |
|
|
|
|
|
|
|
|
# check loss output |
|
|
# check loss output |
|
|
if batch_count == 0: |
|
|
if batch_count == 0: |
|
@@ -281,7 +277,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No |
|
|
truths[k].append(v) |
|
|
truths[k].append(v) |
|
|
if batch_count+1>DEFAULT_CHECK_NUM_BATCH: |
|
|
if batch_count+1>DEFAULT_CHECK_NUM_BATCH: |
|
|
break |
|
|
break |
|
|
_check_loss_evaluate(model=model, model_func=model.evaluate, check_level=check_level, |
|
|
|
|
|
|
|
|
_check_loss(model=model, model_func=model.evaluate, check_level=check_level, |
|
|
output=outputs, batch_y=truths) |
|
|
output=outputs, batch_y=truths) |
|
|
refined_input = _build_args(model.evaluate, **outputs, **truths) |
|
|
refined_input = _build_args(model.evaluate, **outputs, **truths) |
|
|
metrics = model.evaluate(**refined_input) |
|
|
metrics = model.evaluate(**refined_input) |
|
@@ -323,16 +319,17 @@ def _check_forward_error(model, model_func, check_level, batch_x): |
|
|
elif check_level == WARNING_CHECK_LEVEL: |
|
|
elif check_level == WARNING_CHECK_LEVEL: |
|
|
warnings.warn(message=_unused) |
|
|
warnings.warn(message=_unused) |
|
|
|
|
|
|
|
|
def _check_loss_evaluate(model, model_func, check_level, output, batch_y): |
|
|
|
|
|
|
|
|
def _check_loss(model, model_func, check_level, output, batch_y): |
|
|
check_res = _check_arg_dict_list(model_func, [output, batch_y]) |
|
|
check_res = _check_arg_dict_list(model_func, [output, batch_y]) |
|
|
_missing = '' |
|
|
_missing = '' |
|
|
_unused = '' |
|
|
_unused = '' |
|
|
_duplicated = '' |
|
|
_duplicated = '' |
|
|
signature_str = get_func_signature(model_func) |
|
|
signature_str = get_func_signature(model_func) |
|
|
func_signature = "{}.{}(self, {})".format(model.__class__.__name__, model_func.__name__, signature_str[1:-1]) |
|
|
|
|
|
forward_signature_str = get_func_signature(model.forward) |
|
|
|
|
|
forward_func_signature = "{}.forward(self, {})".format(model.__class__.__name__, forward_signature_str[1:-1]) |
|
|
|
|
|
model_name = model.__class__.__name__ |
|
|
model_name = model.__class__.__name__ |
|
|
|
|
|
model_func_name = model_func.__name__ |
|
|
|
|
|
func_signature = "{}.{}(self, {})".format(model_name, model_func_name, signature_str[1:-1]) |
|
|
|
|
|
forward_signature_str = get_func_signature(model.forward) |
|
|
|
|
|
forward_func_signature = "{}.forward(self, {})".format(model_name, forward_signature_str[1:-1]) |
|
|
if len(check_res.missing)>0: |
|
|
if len(check_res.missing)>0: |
|
|
_missing = "Function {} misses argument {}, only provided with {}(from {}) and " \ |
|
|
_missing = "Function {} misses argument {}, only provided with {}(from {}) and " \ |
|
|
"{}." \ |
|
|
"{}." \ |
|
@@ -384,6 +381,77 @@ def _check_loss_evaluate(model, model_func, check_level, output, batch_y): |
|
|
if _error_str: |
|
|
if _error_str: |
|
|
raise ValueError(_error_str) |
|
|
raise ValueError(_error_str) |
|
|
|
|
|
|
|
|
|
|
|
def _check_evaluate(model, model_func, check_level, output, batch_y): |
|
|
|
|
|
|
|
|
|
|
|
check_res = _check_arg_dict_list(model_func, [output, batch_y]) |
|
|
|
|
|
_missing = '' |
|
|
|
|
|
_unused = '' |
|
|
|
|
|
_duplicated = '' |
|
|
|
|
|
signature_str = get_func_signature(model_func) |
|
|
|
|
|
model_name = model.__class__.__name__ |
|
|
|
|
|
model_func_name = model_func.__name__ |
|
|
|
|
|
func_signature = "{}.{}(self, {})".format(model_name, model_func_name, signature_str[1:-1]) |
|
|
|
|
|
if hasattr(model, 'predict'): |
|
|
|
|
|
previous_func = model.predict |
|
|
|
|
|
previous_func_name = 'predict' |
|
|
|
|
|
else: |
|
|
|
|
|
previous_func = model.forward |
|
|
|
|
|
previous_func_name = 'forward' |
|
|
|
|
|
previous_signature_str = get_func_signature(previous_func) |
|
|
|
|
|
previous_func_signature = "{}.{}(self, {})".format(model_name, previous_func_name, previous_signature_str[1:-1]) |
|
|
|
|
|
if len(check_res.missing)>0: |
|
|
|
|
|
_missing = "Function {} misses argument {}, only provided with {}(from {}) and " \ |
|
|
|
|
|
"{}." \ |
|
|
|
|
|
.format(func_signature, check_res.missing, |
|
|
|
|
|
list(output.keys()), previous_func_signature, |
|
|
|
|
|
list(batch_y.keys())) |
|
|
|
|
|
if len(check_res.unused)>0: |
|
|
|
|
|
if len(check_res.unused) > 1: |
|
|
|
|
|
_unused = "{} are not used ".format(check_res.unused) |
|
|
|
|
|
else: |
|
|
|
|
|
_unused = "{} is not used ".format(check_res.unused) |
|
|
|
|
|
_unused += "in function {}.\n".format(func_signature) |
|
|
|
|
|
if len(check_res.duplicated)>0: |
|
|
|
|
|
if len(check_res.duplicated) > 1: |
|
|
|
|
|
_duplicated = "Duplicated keys {} are detected when calling function {}. \nDon't set {} as target and output " \ |
|
|
|
|
|
"them in {} at the same time.\n".format(check_res.duplicated, |
|
|
|
|
|
func_signature, |
|
|
|
|
|
check_res.duplicated, |
|
|
|
|
|
previous_func_signature) |
|
|
|
|
|
else: |
|
|
|
|
|
_duplicated = "Duplicated key {} is detected when calling function {}. \nDon't set {} as target and output " \ |
|
|
|
|
|
"it in {} at the same time.\n".format(check_res.duplicated, |
|
|
|
|
|
func_signature, |
|
|
|
|
|
check_res.duplicated, |
|
|
|
|
|
previous_func_signature) |
|
|
|
|
|
_number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) |
|
|
|
|
|
if _number_errs > 0: |
|
|
|
|
|
_error_str = '' |
|
|
|
|
|
if _number_errs > 1: |
|
|
|
|
|
count = 1 |
|
|
|
|
|
if _missing: |
|
|
|
|
|
_error_str += '({}).{}'.format(count, _missing) |
|
|
|
|
|
count += 1 |
|
|
|
|
|
if _duplicated: |
|
|
|
|
|
_error_str += '({}).{}'.format(count, _duplicated) |
|
|
|
|
|
count += 1 |
|
|
|
|
|
if _unused and check_level == STRICT_CHECK_LEVEL: |
|
|
|
|
|
_error_str += '({}).{}'.format(count, _unused) |
|
|
|
|
|
else: |
|
|
|
|
|
if _unused: |
|
|
|
|
|
if check_level == STRICT_CHECK_LEVEL: |
|
|
|
|
|
# TODO 这里可能需要自定义一些Error类型 |
|
|
|
|
|
_error_str = _unused |
|
|
|
|
|
elif check_level == WARNING_CHECK_LEVEL: |
|
|
|
|
|
_unused = _unused.strip() |
|
|
|
|
|
warnings.warn(_unused) |
|
|
|
|
|
else: |
|
|
|
|
|
_error_str = _missing + _duplicated |
|
|
|
|
|
if _error_str: |
|
|
|
|
|
raise ValueError(_error_str) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
if __name__ == '__main__': |
|
|
import torch |
|
|
import torch |
|
@@ -430,5 +498,13 @@ if __name__ == '__main__': |
|
|
# import inspect |
|
|
# import inspect |
|
|
# print(inspect.getfullargspec(model.forward)) |
|
|
# print(inspect.getfullargspec(model.forward)) |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
a = [1, 3] |
|
|
|
|
|
np.asarray(a) |
|
|
|
|
|
|
|
|
|
|
|
import pandas |
|
|
|
|
|
df = pandas.DataFrame(fake_data_dict) |
|
|
|
|
|
df.infer_objects() |
|
|
|
|
|
|
|
|
|
|
|
|