
trainer modifications

tags/v0.2.0
yh yunfan committed 6 years ago
commit 4a4b001047
1 changed file with 89 additions and 13 deletions:

fastNLP/core/trainer.py  (+89 / -13)

@@ -237,14 +237,10 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
 
         # loss check
         if batch_count == 0:
-            _dict = _check_arg_dict_list(model.loss, [output, batch_y])
-            if len(_dict) != 0:
-                pass
-        loss_input = _build_args(model.loss, **output, **batch_y)
-        loss = model.loss(**loss_input)
-        if batch_count == 0:
-            if isinstance(loss, torch.Tensor):
-                pass
+            _check_loss(model=model, model_func=model.get_loss, check_level=check_level,
+                        output=output, batch_y=batch_y)
+        loss_input = _build_args(model.get_loss, **output, **batch_y)
+        loss = model.get_loss(**loss_input)
 
         # check loss output
         if batch_count == 0:
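
Note: `_build_args` is what makes the `**output, **batch_y` call pattern above safe. A minimal sketch of the idea, assuming it filters keyword arguments down to what the callee's signature accepts (the name `build_args_sketch` and all details below are hypothetical; fastNLP's actual helper may differ):

import inspect

def build_args_sketch(func, **kwargs):
    # Hypothetical stand-in for fastNLP's _build_args: keep only the keyword
    # arguments that func's signature accepts, so a call such as
    # model.get_loss(**loss_input) cannot fail on unexpected keywords.
    params = inspect.signature(func).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)  # func takes **kwargs: pass everything through
    return {name: value for name, value in kwargs.items() if name in params}

def get_loss(pred, truth):
    return abs(pred - truth)

# 'seq_len' is silently dropped because get_loss does not accept it
loss_input = build_args_sketch(get_loss, pred=1.0, truth=2.0, seq_len=5)
assert loss_input == {'pred': 1.0, 'truth': 2.0}
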
@@ -281,7 +277,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
                 truths[k].append(v)
         if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
             break
-    _check_loss_evaluate(model=model, model_func=model.evaluate, check_level=check_level,
+    _check_loss(model=model, model_func=model.evaluate, check_level=check_level,
                 output=outputs, batch_y=truths)
    refined_input = _build_args(model.evaluate, **outputs, **truths)
    metrics = model.evaluate(**refined_input)
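
Note: the `outputs` and `truths` dicts handed to the checker here are the two "arg dicts" that `_check_arg_dict_list` compares against the checked function's parameters, yielding the `check_res.missing` / `check_res.unused` / `check_res.duplicated` fields used below. A hypothetical sketch of that contract (fastNLP's real helper and return type may differ):

import inspect
from collections import Counter, namedtuple

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'])

def check_arg_dict_list_sketch(func, arg_dicts):
    # Hypothetical stand-in for fastNLP's _check_arg_dict_list: compare the
    # keys supplied by a list of dicts against func's parameter names.
    params = inspect.signature(func).parameters
    accepted = set(params) - {'self'}
    required = {name for name, p in params.items()
                if p.default is inspect.Parameter.empty and name != 'self'}
    key_counts = Counter(k for d in arg_dicts for k in d)
    return CheckRes(
        missing=sorted(required - set(key_counts)),   # needed but not provided
        unused=sorted(set(key_counts) - accepted),    # provided but never consumed
        duplicated=sorted(k for k, c in key_counts.items() if c > 1),  # same key in several dicts
    )

def evaluate(pred, truth):
    return {'acc': pred == truth}

res = check_arg_dict_list_sketch(evaluate, [{'pred': 1, 'truth': 1}, {'truth': 1}])
assert res.duplicated == ['truth'] and not res.missing
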
@@ -323,16 +319,17 @@ def _check_forward_error(model, model_func, check_level, batch_x):
     elif check_level == WARNING_CHECK_LEVEL:
         warnings.warn(message=_unused)
 
-def _check_loss_evaluate(model, model_func, check_level, output, batch_y):
+def _check_loss(model, model_func, check_level, output, batch_y):
     check_res = _check_arg_dict_list(model_func, [output, batch_y])
     _missing = ''
     _unused = ''
     _duplicated = ''
     signature_str = get_func_signature(model_func)
-    func_signature = "{}.{}(self, {})".format(model.__class__.__name__, model_func.__name__, signature_str[1:-1])
-    forward_signature_str = get_func_signature(model.forward)
-    forward_func_signature = "{}.forward(self, {})".format(model.__class__.__name__, forward_signature_str[1:-1])
     model_name = model.__class__.__name__
+    model_func_name = model_func.__name__
+    func_signature = "{}.{}(self, {})".format(model_name, model_func_name, signature_str[1:-1])
+    forward_signature_str = get_func_signature(model.forward)
+    forward_func_signature = "{}.forward(self, {})".format(model_name, forward_signature_str[1:-1])
     if len(check_res.missing)>0:
         _missing = "Function {} misses argument {}, only provided with {}(from {}) and " \
                    "{}." \
@@ -384,6 +381,77 @@ def _check_loss_evaluate(model, model_func, check_level, output, batch_y):
     if _error_str:
         raise ValueError(_error_str)
 
+def _check_evaluate(model, model_func, check_level, output, batch_y):
+
+    check_res = _check_arg_dict_list(model_func, [output, batch_y])
+    _missing = ''
+    _unused = ''
+    _duplicated = ''
+    signature_str = get_func_signature(model_func)
+    model_name = model.__class__.__name__
+    model_func_name = model_func.__name__
+    func_signature = "{}.{}(self, {})".format(model_name, model_func_name, signature_str[1:-1])
+    if hasattr(model, 'predict'):
+        previous_func = model.predict
+        previous_func_name = 'predict'
+    else:
+        previous_func = model.forward
+        previous_func_name = 'forward'
+    previous_signature_str = get_func_signature(previous_func)
+    previous_func_signature = "{}.{}(self, {})".format(model_name, previous_func_name, previous_signature_str[1:-1])
+    if len(check_res.missing)>0:
+        _missing = "Function {} misses argument {}, only provided with {}(from {}) and " \
+                   "{}." \
+            .format(func_signature, check_res.missing,
+                    list(output.keys()), previous_func_signature,
+                    list(batch_y.keys()))
+    if len(check_res.unused)>0:
+        if len(check_res.unused) > 1:
+            _unused = "{} are not used ".format(check_res.unused)
+        else:
+            _unused = "{} is not used ".format(check_res.unused)
+        _unused += "in function {}.\n".format(func_signature)
+    if len(check_res.duplicated)>0:
+        if len(check_res.duplicated) > 1:
+            _duplicated = "Duplicated keys {} are detected when calling function {}. \nDon't set {} as target and output " \
+                          "them in {} at the same time.\n".format(check_res.duplicated,
+                                                                  func_signature,
+                                                                  check_res.duplicated,
+                                                                  previous_func_signature)
+        else:
+            _duplicated = "Duplicated key {} is detected when calling function {}. \nDon't set {} as target and output " \
+                          "it in {} at the same time.\n".format(check_res.duplicated,
+                                                                func_signature,
+                                                                check_res.duplicated,
+                                                                previous_func_signature)
+    _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
+    if _number_errs > 0:
+        _error_str = ''
+        if _number_errs > 1:
+            count = 1
+            if _missing:
+                _error_str += '({}).{}'.format(count, _missing)
+                count += 1
+            if _duplicated:
+                _error_str += '({}).{}'.format(count, _duplicated)
+                count += 1
+            if _unused and check_level == STRICT_CHECK_LEVEL:
+                _error_str += '({}).{}'.format(count, _unused)
+        else:
+            if _unused:
+                if check_level == STRICT_CHECK_LEVEL:
+                    # TODO we may need to define some custom Error types here
+                    _error_str = _unused
+                elif check_level == WARNING_CHECK_LEVEL:
+                    _unused = _unused.strip()
+                    warnings.warn(_unused)
+            else:
+                _error_str = _missing + _duplicated
+        if _error_str:
+            raise ValueError(_error_str)
+
+
 
 if __name__ == '__main__':
     import torch
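
Note: the behavioral difference between `_check_evaluate` and `_check_loss` is the upstream function named in the error messages: when the model defines `predict`, mismatches in `evaluate`'s inputs are reported against `predict`'s signature rather than `forward`'s. A toy illustration of that dispatch (the model class and all names are hypothetical):

class ToyModel:
    def forward(self, word_seq):
        return {'pred': word_seq}

    def predict(self, word_seq):
        # inference entry point; used as the "previous" function when present
        return {'pred': word_seq}

    def evaluate(self, pred, truth):
        return {'acc': float(pred == truth)}

model = ToyModel()
# mirrors the hasattr(model, 'predict') branch in _check_evaluate
previous_func = model.predict if hasattr(model, 'predict') else model.forward
assert previous_func.__name__ == 'predict'
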
@@ -430,5 +498,13 @@ if __name__ == '__main__':
     # import inspect
     # print(inspect.getfullargspec(model.forward))
 
+    import numpy as np
+
+    a = [1, 3]
+    np.asarray(a)
+
+    import pandas
+    df = pandas.DataFrame(fake_data_dict)
+    df.infer_objects()
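
Note: the scratch code appended to the `__main__` block exercises numpy/pandas dtype inference; `fake_data_dict` is defined earlier in that block and is not shown in this hunk. For reference, `DataFrame.infer_objects()` soft-converts object-dtype columns to more specific dtypes where the remaining values allow it (example adapted from the pandas documentation):

import pandas as pd

df = pd.DataFrame({"A": ["a", 1, 2, 3]})
df = df.iloc[1:]                    # column "A" still has dtype object
print(df.dtypes)                    # A    object
print(df.infer_objects().dtypes)    # A    int64
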