diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 58e6c09d..f392dd33 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -47,7 +47,7 @@ class FieldArray(object): assert self.is_input is True or self.is_target is True batch_size = len(indices) # TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 - if isinstance(self.content[0], int) or isinstance(self.content[0], float): + if not isiterable(self.content[0]): if self.dtype is None: self.dtype = np.int64 if isinstance(self.content[0], int) else np.double array = np.array([self.content[i] for i in indices], dtype=self.dtype) @@ -63,3 +63,10 @@ class FieldArray(object): def __len__(self): return len(self.content) + +def isiterable(content): + try: + _ = (e for e in content) + except TypeError: + return False + return True \ No newline at end of file diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 9538d3fc..eb727317 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -1,5 +1,9 @@ import time -from datetime import timedelta, datetime +from datetime import timedelta +from datetime import datetime + +import warnings +from collections import defaultdict import torch from tensorboardX import SummaryWriter @@ -12,13 +16,17 @@ from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import SequentialSampler from fastNLP.core.tester import Tester +from fastNLP.core.utils import _check_arg_dict_list +from fastNLP.core.utils import _build_args +from fastNLP.core.utils import _syn_model_data +from fastNLP.core.utils import get_func_signature class Trainer(object): """Main Training Loop """ - def __init__(self, train_data, model, n_epochs, batch_size, n_print, + def __init__(self, train_data, model, n_epochs=1, batch_size=32, print_every=-1, dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save", optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), evaluator=Evaluator(), @@ -32,7 +40,7 @@ class Trainer(object): self.batch_size = int(batch_size) self.use_cuda = bool(use_cuda) self.save_path = str(save_path) - self.n_print = int(n_print) + self.print_every = int(print_every) self.loss_func = self.model.loss if hasattr(self.model, "loss") else loss.get() self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) @@ -51,7 +59,7 @@ class Trainer(object): self.step = 0 self.start_time = None # start timestamp - print(self.__dict__) + # print(self.__dict__) def train(self): """Start Training. @@ -70,17 +78,16 @@ class Trainer(object): epoch = 1 while epoch <= self.n_epochs: - data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), - use_cuda=self.use_cuda) + data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler()) - self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start, self.n_print) + self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) if self.dev_data: self.do_validation() self.save_model(self.model, 'training_model_' + self.start_time) epoch += 1 - def _train_epoch(self, data_iterator, model, epoch, dev_data, start, n_print, **kwargs): + def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs): """Training process in one epoch. kwargs should contain: @@ -103,7 +110,7 @@ class Trainer(object): self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) - if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0: + if self.print_every > 0 and self.step % self.print_every == 0: end = time.time() diff = timedelta(seconds=round(end - kwargs["start"])) print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( @@ -197,9 +204,6 @@ def best_eval_result(self, metrics): return False -from fastNLP.core.utils import _check_arg_dict_list -from fastNLP.core.utils import _build_args - DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_NUM_BATCH = 2 @@ -207,64 +211,209 @@ IGNORE_CHECK_LEVEL=0 WARNING_CHECK_LEVEL=1 STRICT_CHECK_LEVEL=2 - -def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL): - # check loss 方法 +def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1): + # check get_loss 方法 + model_name = model.__class__.__name__ if not hasattr(model, 'get_loss'): - raise AttributeError("{} has to have a 'get_loss' function.".format(type(model))) + raise AttributeError("{} has to have a 'get_loss' function.".format(model_name)) batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size) batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) for batch_count, (batch_x, batch_y) in enumerate(batch): + _syn_model_data(model, batch_x, batch_y) + # forward check if batch_count==0: - check_res = _check_arg_dict_list(model.forward, batch_x) - _info_str = '' - if len(check_res.missing)>0: - if check_level == WARNING_CHECK_LEVEL: - for field_name in check_res.missing: - if hasattr(dataset, field_name): - _info_str += "{} " - _info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n" - _info_str += "" - print("") - if len(check_res.unused)>0: - if check_level == WARNING_CHECK_LEVEL: - _info_str += "" + _check_forward_error(model=model, model_func=model.forward, check_level=check_level, + batch_x=batch_x) refined_batch_x = _build_args(model.forward, **batch_x) output = model(**refined_batch_x) + + assert isinstance(output, dict), "The return value of {}.forward() should be dict.".format(model_name) + + # loss check if batch_count == 0: - _dict = _check_arg_dict_list(model.loss, [output, batch_y]) - if len(_dict)!=0: - pass - loss_input = _build_args(model.loss, **output, **batch_y) - loss = model.loss(**loss_input) - if batch_count == 0: - if isinstance(loss, torch.Tensor): - pass + _check_loss_evaluate(model=model, model_func=model.get_loss, check_level=check_level, + output=output, batch_y=batch_y) + loss_input = _build_args(model.get_loss, **output, **batch_y) + loss = model.get_loss(**loss_input) + # check loss output + if batch_count == 0: + if not isinstance(loss, torch.Tensor): + raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.". + format(model_name, type(loss))) + if len(loss.size())!=0: + raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format( + model_name, loss.size() + )) loss.backward() - - if batch_count+1>=DEFAULT_CHECK_BATCH_SIZE: + model.zero_grad() + if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: break + if check_level > IGNORE_CHECK_LEVEL: + print('Finish checking training process.', flush=True) + - dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) if dev_data is not None: if not hasattr(model, 'evaluate'): - raise AttributeError("If {} wants to do evaluation, {} has to have a 'evaluate' function. Or you can set" + raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set" "dev_data to 'None'." - .format(type(model), type(model))) + .format(model_name)) + outputs, truths = defaultdict(list), defaultdict(list) + dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) + with torch.no_grad(): + for batch_count, (batch_x, batch_y) in enumerate(dev_batch): + _syn_model_data(model, batch_x, batch_y) + + refined_batch_x = _build_args(model.forward, **batch_x) + output = model(**refined_batch_x) + for k, v in output.items(): + outputs[k].append(v) + for k, v in batch_y.items(): + truths[k].append(v) + if batch_count+1>DEFAULT_CHECK_NUM_BATCH: + break + _check_loss_evaluate(model=model, model_func=model.evaluate, check_level=check_level, + output=outputs, batch_y=truths) + print("Finish checking evaluate process.", flush=True) + + +def _check_forward_error(model, model_func, check_level, batch_x): + check_res = _check_arg_dict_list(model_func, batch_x) + _missing = '' + _unused = '' + signature_str = get_func_signature(model_func) + func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1]) + if len(check_res.missing)!=0: + _missing = "Function {} misses {}, only provided with {}, " \ + ".\n".format(func_signature, check_res.missing, + list(batch_x.keys())) + if len(check_res.unused)!=0: + if len(check_res.unused) > 1: + _unused = "{} are not used ".format(check_res.unused) + else: + _unused = "{} is not used ".format(check_res.unused) + _unused += "in function {}.\n".format(func_signature) + if _missing: + if not _unused and STRICT_CHECK_LEVEL: + _error_str = "(1).{} (2).{}".format(_missing, _unused) + else: + _error_str = _missing + # TODO 这里可能需要自定义一些Error类型 + raise TypeError(_error_str) + if _unused: + if check_level == STRICT_CHECK_LEVEL: + # TODO 这里可能需要自定义一些Error类型 + raise ValueError(_unused) + elif check_level == WARNING_CHECK_LEVEL: + warnings.warn(message=_unused, ) + +def _check_loss_evaluate(model, model_func, check_level, output, batch_y): + check_res = _check_arg_dict_list(model_func, [output, batch_y]) + _missing = '' + _unused = '' + _duplicated = '' + signature_str = get_func_signature(model_func) + func_signature = "{}.{}(self, {})".format(model.__class__.__name__, model_func.__name__, signature_str[1:-1]) + forward_func_signature = "{}.forward(self, {})".format(model.__class__.__name__, signature_str[1:-1]) + model_name = model.__class__.__name__ + if len(check_res.missing)>0: + _missing = "Function {} misses argument {}, only provided with {}(from {}) and " \ + "{}." \ + .format(func_signature, check_res.missing, + list(output.keys()), model_name, + list(batch_y.keys())) + if len(check_res.unused)>0: + if len(check_res.unused) > 1: + _unused = "{} are not used ".format(check_res.unused) + else: + _unused = "{} is not used ".format(check_res.unused) + _unused += "in function {}.\n".format(func_signature) + if len(check_res.duplicated)>0: + if len(check_res.duplicated) > 1: + _duplicated = "Duplicated keys: {} are detected in function {}. Don't set {} as target and output " \ + "them in {} at the same time.\n".format(check_res.duplicated, + func_signature, + check_res.duplicated, + forward_func_signature) + else: + _duplicated = "Duplicated key: {} is detected in function {}. Don't set {} as target and output " \ + "it in {} at the same time.\n".format(check_res.duplicated, + func_signature, + check_res.duplicated, + forward_func_signature) + _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) + if _number_errs > 0: + _error_str = '' + if _number_errs > 1: + count = 1 + if _missing: + _error_str += '({}).{}'.format(count, _missing) + count += 1 + if _duplicated: + _error_str += '({}).{}'.format(count, _duplicated) + count += 1 + if _unused and check_level == STRICT_CHECK_LEVEL: + _error_str += '({}).{}'.format(count, _unused) + else: + if _unused: + if check_level == STRICT_CHECK_LEVEL: + # TODO 这里可能需要自定义一些Error类型 + _error_str = _unused + elif check_level == WARNING_CHECK_LEVEL: + _unused = _unused.strip() + warnings.warn(_unused) + else: + _error_str = _missing + _duplicated + if _error_str: + raise ValueError(_error_str) + + +if __name__ == '__main__': + import torch + from torch import nn + from fastNLP.core.dataset import DataSet + import numpy as np + + class Model(nn.Module): + def __init__(self): + super().__init__() + + self. fc1 = nn.Linear(10, 2) + + def forward(self, words, chars): + output = {} + output['prediction'] = torch.randn(3, 4) + output['words'] = words + return output + + def get_loss(self, prediction, labels, words): + return torch.mean(self.fc1.weight) + + def evaluate(self, prediction, labels, demo=2): + return 0 + + model = Model() + + num_samples = 4 + fake_data_dict = {'words': np.random.randint(num_samples, size=(4, 3)), 'chars': np.random.randn(num_samples, 6), + 'labels': np.random.randint(2, size=(num_samples,))} + + + dataset = DataSet(fake_data_dict) + dataset.set_input(words=True, chars=True) + dataset.set_target(labels=True) - for batch_count, (batch_x, batch_y) in enumerate(dev_batch): - if batch_count == 0: - _dict = _check_arg_dict_list(model.evaluate, [output, batch_y]) + # trainer = Trainer(dataset, model) - if len(_dict)!=0: - pass - refined_batch_x = _build_args(model.forward, **batch_x) - output = model(**refined_batch_x) + _check_code(dataset=dataset, model=model, dev_data=dataset, check_level=2) + # _check_forward_error(model=model, model_func=model.forward, check_level=1, + # batch_x=fake_data_dict) + # import inspect + # print(inspect.getfullargspec(model.forward)) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index ca38e45e..84ed11e6 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -4,7 +4,7 @@ import inspect from collections import namedtuple from collections import Counter -CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=True) +CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) def save_pickle(obj, pickle_path, file_name): @@ -55,8 +55,11 @@ def _build_args(func, **kwargs): if spect.varkw is not None: return kwargs needed_args = set(spect.args) - start_idx = len(spect.args) - len(spect.defaults) - output = {name: default for name, default in zip(spect.args[start_idx:], spect.defaults)} + defaults = [] + if spect.defaults is not None: + defaults = [arg for arg in spect.defaults] + start_idx = len(spect.args) - len(defaults) + output = {name: default for name, default in zip(spect.args[start_idx:], defaults)} output.update({name: val for name, val in kwargs.items() if name in needed_args}) return output @@ -71,8 +74,11 @@ def _check_arg_dict_list(func, args): assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict) spect = inspect.getfullargspec(func) assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs) - all_args = set(spect.args) - start_idx = len(spect.args) - len(spect.defaults) + all_args = set([arg for arg in spect.args if arg!='self']) + defaults = [] + if spect.defaults is not None: + defaults = [arg for arg in spect.defaults] + start_idx = len(spect.args) - len(defaults) default_args = set(spect.args[start_idx:]) require_args = all_args - default_args input_arg_count = Counter() @@ -87,3 +93,23 @@ def _check_arg_dict_list(func, args): duplicated=duplicated, required=list(require_args), all_needed=list(all_args)) + +def get_func_signature(func): + # function signature, does not include self. + signature = inspect.signature(func) + signature_str = str(signature) + return signature_str + + +# move data to model's device +import torch +def _syn_model_data(model, *args): + assert len(model.state_dict())!=0, "This model has no parameter." + device = model.parameters().__next__().device + for arg in args: + if isinstance(arg, dict): + for key, value in arg.items(): + if isinstance(value, torch.Tensor): + arg[key] = value.to(device) + else: + raise ValueError("Only support dict type right now.") \ No newline at end of file