diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index a21f2ded..d83e3936 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -1,7 +1,11 @@ import time -from datetime import timedelta, datetime +from datetime import timedelta +from datetime import datetime +import warnings +from collections import defaultdict import os -import torch +import itertools + from tensorboardX import SummaryWriter from fastNLP.core.batch import Batch @@ -221,30 +225,20 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) for batch_count, (batch_x, batch_y) in enumerate(batch): - if batch_count == 0: - check_res = _check_arg_dict_list(model.forward, batch_x) - _info_str = '' - if len(check_res.missing) > 0: - if check_level == WARNING_CHECK_LEVEL: - for field_name in check_res.missing: - if hasattr(dataset, field_name): - _info_str += "{} " - _info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n" - _info_str += "" - print("") - if len(check_res.unused) > 0: - if check_level == WARNING_CHECK_LEVEL: - _info_str += "" + _syn_model_data(model, batch_x, batch_y) + # forward check + if batch_count==0: + _check_forward_error(model_func=model.forward, check_level=check_level, + batch_x=batch_x) refined_batch_x = _build_args(model.forward, **batch_x) output = model(**refined_batch_x) - signature_str = get_func_signature(model.forward) - func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1]) + func_signature = get_func_signature(model.forward) assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature) # loss check if batch_count == 0: - _check_loss(model=model, model_func=model.get_loss, check_level=check_level, + _check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level, output=output, batch_y=batch_y) 
loss_input = _build_args(model.get_loss, **output, **batch_y) loss = model.get_loss(**loss_input) @@ -276,32 +270,42 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No for batch_count, (batch_x, batch_y) in enumerate(dev_batch): _syn_model_data(model, batch_x, batch_y) - refined_batch_x = _build_args(model.forward, **batch_x) - output = model(**refined_batch_x) + if hasattr(model, 'predict'): + refined_batch_x = _build_args(model.predict, **batch_x) + prev_func = model.predict + output = prev_func(**refined_batch_x) + func_signature = get_func_signature(model.predict) + assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature) + else: + refined_batch_x = _build_args(model.forward, **batch_x) + prev_func = model.forward + output = prev_func(**refined_batch_x) for k, v in output.items(): outputs[k].append(v) for k, v in batch_y.items(): truths[k].append(v) if batch_count+1>DEFAULT_CHECK_NUM_BATCH: break - _check_loss(model=model, model_func=model.evaluate, check_level=check_level, + for k, v in outputs.items(): + outputs[k] = itertools.chain(*v) + for k, v in truths.items(): + truths[k] = itertools.chain(*v) + _check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level, output=outputs, batch_y=truths) refined_input = _build_args(model.evaluate, **outputs, **truths) metrics = model.evaluate(**refined_input) - signature_str = get_func_signature(model.evaluate) - func_signature = '{}.evaluate(self, {})'.format(model.__class__.__name__, signature_str[1:-1]) + func_signature = get_func_signature(model.evaluate) assert isinstance(metrics, dict), "The return value of {} should be dict.". 
\ format(func_signature) if check_level > IGNORE_CHECK_LEVEL: print("Finish checking evaluate process.", flush=True) -def _check_forward_error(model, model_func, check_level, batch_x): +def _check_forward_error(model_func, check_level, batch_x): check_res = _check_arg_dict_list(model_func, batch_x) _missing = '' _unused = '' - signature_str = get_func_signature(model_func) - func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1]) + func_signature = get_func_signature(model_func) if len(check_res.missing)!=0: _missing = "Function {} misses {}, only provided with {}, " \ ".\n".format(func_signature, check_res.missing, @@ -313,8 +317,8 @@ def _check_forward_error(model, model_func, check_level, batch_x): _unused = "{} is not used ".format(check_res.unused) _unused += "in function {}.\n".format(func_signature) if _missing: - if not _unused and STRICT_CHECK_LEVEL: - _error_str = "(1).{} (2).{}".format(_missing, _unused) + if len(_unused)>0 and STRICT_CHECK_LEVEL: + _error_str = "(1).{}\n(2).{}".format(_missing, _unused) else: _error_str = _missing # TODO 这里可能需要自定义一些Error类型 @@ -326,91 +330,19 @@ def _check_forward_error(model, model_func, check_level, batch_x): elif check_level == WARNING_CHECK_LEVEL: warnings.warn(message=_unused) -def _check_loss(model, model_func, check_level, output, batch_y): - check_res = _check_arg_dict_list(model_func, [output, batch_y]) - _missing = '' - _unused = '' - _duplicated = '' - signature_str = get_func_signature(model_func) - model_name = model.__class__.__name__ - model_func_name = model_func.__name__ - func_signature = "{}.{}(self, {})".format(model_name, model_func_name, signature_str[1:-1]) - forward_signature_str = get_func_signature(model.forward) - forward_func_signature = "{}.forward(self, {})".format(model_name, forward_signature_str[1:-1]) - if len(check_res.missing)>0: - _missing = "Function {} misses argument {}, only provided with {}(from {}) and " \ - "{}." 
\ - .format(func_signature, check_res.missing, - list(output.keys()), model_name, - list(batch_y.keys())) - if len(check_res.unused)>0: - if len(check_res.unused) > 1: - _unused = "{} are not used ".format(check_res.unused) - else: - _unused = "{} is not used ".format(check_res.unused) - _unused += "in function {}.\n".format(func_signature) - if len(check_res.duplicated)>0: - if len(check_res.duplicated) > 1: - _duplicated = "Duplicated keys {} are detected when calling function {}. \nDon't set {} as target and output " \ - "them in {} at the same time.\n".format(check_res.duplicated, - func_signature, - check_res.duplicated, - forward_func_signature) - else: - _duplicated = "Duplicated key {} is detected when calling function {}. \nDon't set {} as target and output " \ - "it in {} at the same time.\n".format(check_res.duplicated, - func_signature, - check_res.duplicated, - forward_func_signature) - _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) - if _number_errs > 0: - _error_str = '' - if _number_errs > 1: - count = 1 - if _missing: - _error_str += '({}).{}'.format(count, _missing) - count += 1 - if _duplicated: - _error_str += '({}).{}'.format(count, _duplicated) - count += 1 - if _unused and check_level == STRICT_CHECK_LEVEL: - _error_str += '({}).{}'.format(count, _unused) - else: - if _unused: - if check_level == STRICT_CHECK_LEVEL: - # TODO 这里可能需要自定义一些Error类型 - _error_str = _unused - elif check_level == WARNING_CHECK_LEVEL: - _unused = _unused.strip() - warnings.warn(_unused) - else: - _error_str = _missing + _duplicated - if _error_str: - raise ValueError(_error_str) - -def _check_evaluate(model, model_func, check_level, output, batch_y): +def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): - check_res = _check_arg_dict_list(model_func, [output, batch_y]) + check_res = _check_arg_dict_list(func, [output, batch_y]) _missing = '' _unused = '' _duplicated = '' - signature_str = 
get_func_signature(model_func) - model_name = model.__class__.__name__ - model_func_name = model_func.__name__ - func_signature = "{}.{}(self, {})".format(model_name, model_func_name, signature_str[1:-1]) - if hasattr(model, 'predict'): - previous_func = model.predict - previous_func_name = 'predict' - else: - previous_func = model.forward - previous_func_name = 'forward' - previous_signature_str = get_func_signature(previous_func) - previous_func_signature = "{}.{}(self, {})".format(model_name, previous_func_name, previous_signature_str[1:-1]) + func_signature = get_func_signature(func) + prev_func_signature = get_func_signature(prev_func) if len(check_res.missing)>0: - _missing = "Function {} misses argument {}, only provided with {}(from {}) and " \ - "{}." \ + _missing = "Function {} misses argument {}, \n only provided with {}(from {}) and " \ + "{}(from target in Dataset)." \ .format(func_signature, check_res.missing, - list(output.keys()), previous_func_signature, + list(output.keys()), prev_func_signature, list(batch_y.keys())) if len(check_res.unused)>0: if len(check_res.unused) > 1: @@ -424,40 +356,38 @@ def _check_evaluate(model, model_func, check_level, output, batch_y): "them in {} at the same time.\n".format(check_res.duplicated, func_signature, check_res.duplicated, - previous_func_signature) + prev_func_signature) else: _duplicated = "Duplicated key {} is detected when calling function {}. 
\nDon't set {} as target and output " \ "it in {} at the same time.\n".format(check_res.duplicated, func_signature, check_res.duplicated, - previous_func_signature) + prev_func_signature) _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) if _number_errs > 0: - _error_str = '' + _error_strs = [] if _number_errs > 1: count = 1 if _missing: - _error_str += '({}).{}'.format(count, _missing) + _error_strs.append('({}).{}'.format(count, _missing)) count += 1 if _duplicated: - _error_str += '({}).{}'.format(count, _duplicated) + _error_strs.append('({}).{}'.format(count, _duplicated)) count += 1 if _unused and check_level == STRICT_CHECK_LEVEL: - _error_str += '({}).{}'.format(count, _unused) + _error_strs.append('({}).{}'.format(count, _unused)) else: if _unused: if check_level == STRICT_CHECK_LEVEL: # TODO 这里可能需要自定义一些Error类型 - _error_str = _unused + _error_strs.append(_unused) elif check_level == WARNING_CHECK_LEVEL: _unused = _unused.strip() warnings.warn(_unused) else: - _error_str = _missing + _duplicated - if _error_str: - raise ValueError(_error_str) - - + _error_strs = [_missing, _duplicated] + if _error_strs: + raise ValueError('\n'.join(_error_strs)) if __name__ == '__main__': @@ -478,11 +408,12 @@ if __name__ == '__main__': output['words'] = words return output - def get_loss(self, prediction, labels, words): + def get_loss(self, prediction, labels, words, seq_lens): return torch.mean(self.fc1.weight) def evaluate(self, prediction, labels, demo=2): - return 0 + return {} + model = Model() @@ -493,7 +424,7 @@ if __name__ == '__main__': dataset = DataSet(fake_data_dict) dataset.set_input(words=True, chars=True) - dataset.set_target(labels=True) + dataset.set_target(labels=True, words=True) # trainer = Trainer(dataset, model) @@ -505,13 +436,5 @@ if __name__ == '__main__': # import inspect # print(inspect.getfullargspec(model.forward)) - import numpy as np - - a = [1, 3] - np.asarray(a) - - import pandas - df = 
pandas.DataFrame(fake_data_dict) - df.infer_objects() diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index d816136e..84faaece 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -95,10 +95,22 @@ def _check_arg_dict_list(func, args): all_needed=list(all_args)) def get_func_signature(func): - # function signature, does not include self. - signature = inspect.signature(func) - signature_str = str(signature) - return signature_str + # can only be used in function or class method + if inspect.ismethod(func): + class_name = func.__self__.__class__.__name__ + signature = inspect.signature(func) + signature_str = str(signature) + if len(signature_str)>2: + _self = '(self, ' + else: + _self = '(self' + signature_str = class_name + '.' + func.__name__ + _self + signature_str[1:] + return signature_str + elif inspect.isfunction(func): + signature = inspect.signature(func) + signature_str = str(signature) + signature_str = func.__name__ + signature_str + return signature_str # move data to model's device