From 3d91f2f024207c8bfc0dae62cdaead227f4558c7 Mon Sep 17 00:00:00 2001
From: yh
Date: Sat, 1 Dec 2018 15:00:06 +0800
Subject: [PATCH 1/2] trainer iteration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/tester.py  |  18 ++++---
 fastNLP/core/trainer.py | 117 +++++++++++++++++++++++++++-------------
 fastNLP/core/utils.py   |  63 ++++++++++++++++++++--
 3 files changed, 148 insertions(+), 50 deletions(-)

diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index ee1354fe..5d264b80 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -6,33 +6,34 @@ import torch
 from fastNLP.core.batch import Batch
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import get_func_signature


 class Tester(object):
     """A collection of model inference and evaluation of performance, used over validation/dev set and test set. """

-    def __init__(self, data, model, batch_size=16, use_cuda=False):
+    def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
         super(Tester, self).__init__()
         self.use_cuda = use_cuda
         self.data = data
         self.batch_size = batch_size
+        self.verbose = verbose

         if torch.cuda.is_available() and self.use_cuda:
             self._model = model.cuda()
         else:
             self._model = model
         if hasattr(self._model, 'predict'):
-            assert callable(self._model.predict)
+            if not callable(self._model.predict):
+                raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
+                                f"for evaluation.")
             self._predict_func = self._model.predict
         else:
             self._predict_func = self._model
-        assert hasattr(model, 'evaluate')
-        self._evaluator = model.evaluate
-        self.eval_history = []  # evaluation results of all batches
+
     def test(self):
         # turn on the testing mode; clean up the history
         network = self._model
         self.mode(network, is_test=True)
-        self.eval_history.clear()
         output, truths = defaultdict(list), defaultdict(list)
         data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)

@@ -48,9 +49,10 @@ class Tester(object):
             output[k] = itertools.chain(*v)
         for k, v in truths.items():
             truths[k] = itertools.chain(*v)
-        args = _build_args(self._evaluator, **output, **truths)
+        # args = _build_args(self._evaluator, **output, **truths)
         eval_results = self._evaluator(**args)
-        print("[tester] {}".format(self.print_eval_results(eval_results)))
+        if self.verbose >= 0:
+            print("[tester] {}".format(self.print_eval_results(eval_results)))
         self.mode(network, is_test=False)
         return eval_results

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 6b0398b5..63eb963e 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -9,6 +9,7 @@ import shutil
 from tensorboardX import SummaryWriter
 import torch
+from torch import nn

 from fastNLP.core.batch import Batch
 from fastNLP.core.loss import Loss
@@ -21,12 +22,13 @@ from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _syn_model_data
 from fastNLP.core.utils import get_func_signature
+from fastNLP.core.dataset import DataSet


 class Trainer(object):
     """Main Training Loop

     """
-    def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
+    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
                  dev_data=None, use_cuda=False, save_path="./save", optimizer=Optimizer("Adam", lr=0.01,
                 weight_decay=0), need_check_code=True, **kwargs):
@@ -35,6 +37,8 @@ class Trainer(object):
         self.train_data = train_data
         self.dev_data = dev_data  # If None, No validation.
         self.model = model
+        self.losser = losser
+        self.metrics = metrics
         self.n_epochs = int(n_epochs)
         self.batch_size = int(batch_size)
         self.use_cuda = bool(use_cuda)
@@ -43,23 +47,22 @@ class Trainer(object):
         self.validate_every = int(validate_every)
         self._best_accuracy = 0

-        if need_check_code:
-            _check_code(dataset=train_data, model=model, dev_data=dev_data)
-
-        model_name = model.__class__.__name__
-        assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
-        self.loss_func = self.model.get_loss
+        # TODO check the types of losser and metrics
+
+
+
+        # TODO self._best_accuracy cannot reflect the case of multiple metrics
+
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         else:
             self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

-        assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name)
-        self.evaluator = self.model.evaluate
-
         if self.dev_data is not None:
             self.tester = Tester(model=self.model,
                                  data=self.dev_data,
+                                 metrics=self.metrics,
                                  batch_size=self.batch_size,
                                  use_cuda=self.use_cuda)

@@ -71,6 +74,38 @@ class Trainer(object):

         # print(self.__dict__)

+    def _check_params(self, train_data, model, losser, metrics=[], n_epochs=3, batch_size=32, print_every=-1,
+                      validate_every=-1, dev_data=None, use_cuda=False, save_path="./save",
+                      optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
+                      **kwargs):
+        if not isinstance(train_data, DataSet):
+            raise TypeError("The type of train_data must be fastNLP.DataSet, got {}.".\
+                format(type(train_data)))
+        if not isinstance(model, nn.Module):
+            raise TypeError("The type of model must be torch.nn.Module, got {}.".\
+                format(type(model)))
+        if losser is not None:
+            # TODO change
+            if not isinstance(losser, None):
+                raise TypeError("The type of losser must be xxx, got {}.".\
+                    format(type(losser)))
+
+        # check metrics and dev_data
+        if (not metrics) and dev_data is not None:
+            raise ValueError("No metric for dev_data evaluation.")
+        if metrics and (dev_data is None):
+            raise ValueError("No dev_data for evaluation, pass dev_data or set metrics to None.")
+
+        # check loss
+        if isinstance(losser, type):
+            self.losser = losser()
+        if not isinstance(self.losser, None):
+            raise TypeError(f"The type of losser must be `xxx`, got {type(self.losser)}.")
+
+        if need_check_code:
+            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)
+
     def train(self):
         """Start Training.
@@ -171,6 +206,9 @@ class Trainer(object):
     def data_forward(self, network, x):
         x = _build_args(network.forward, **x)
         y = network(**x)
+        if not isinstance(y, dict):
+
+            raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
         return y

     def grad_backward(self, loss):
@@ -231,11 +269,11 @@ IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2

-def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
+def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
+                dev_data=None,
+                check_level=WARNING_CHECK_LEVEL):
     # check the get_loss method
     model_name = model.__class__.__name__
-    if not hasattr(model, 'get_loss'):
-        raise AttributeError("{} has to have a 'get_loss' function.".format(model_name))

     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     for batch_count, (batch_x, batch_y) in enumerate(batch):
@@ -248,23 +286,26 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
         refined_batch_x = _build_args(model.forward, **batch_x)
         output = model(**refined_batch_x)
         func_signature = get_func_signature(model.forward)
-        assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
+        if not isinstance(output, dict):
+            raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")

         # loss check
-        if batch_count == 0:
-            _check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level,
-                                 output=output, batch_y=batch_y)
-        loss_input = _build_args(model.get_loss, **output, **batch_y)
-        loss = model.get_loss(**loss_input)
+        if isinstance(losser, type):  # in this case the user passed an uninstantiated loss, e.g. losser.CE
+            # need to guarantee that output and batch_y are unambiguous?
+            #  (1) output and batch_y both have length 1
+            #  (2) the keys of output and batch_y exactly match what losser accepts
+            pass
+
+        loss = losser(output, batch_y)

         # check loss output
         if batch_count == 0:
             if not isinstance(loss, torch.Tensor):
-                raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.".
-                                 format(model_name, type(loss)))
+                raise ValueError("The return value of {} should be torch.Tensor, but got {}.".
+                                 format(type(losser), type(loss)))
             if len(loss.size()) != 0:
-                raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format(
-                    model_name, loss.size()
+                raise ValueError("The size of return value of {} is {}, should be torch.Size([])".format(
+                    type(losser), loss.size()
                 ))
         loss.backward()
         model.zero_grad()
@@ -272,26 +313,29 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
             break

     if dev_data is not None:
-        if not hasattr(model, 'evaluate'):
-            raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set"
-                                 "dev_data to 'None'."
-                                 .format(model_name))
         outputs, truths = defaultdict(list), defaultdict(list)
         dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
+        # TODO change this to use the tester
+
+
         with torch.no_grad():
             for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
                 _syn_model_data(model, batch_x, batch_y)

                 if hasattr(model, 'predict'):
+                    if not callable(model.predict):
+                        raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
+                                        f"for evaluation.")
                     refined_batch_x = _build_args(model.predict, **batch_x)
                     prev_func = model.predict
                     output = prev_func(**refined_batch_x)
-                    func_signature = get_func_signature(model.predict)
-                    assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
                 else:
                     refined_batch_x = _build_args(model.forward, **batch_x)
                     prev_func = model.forward
                     output = prev_func(**refined_batch_x)
+                func_signature = get_func_signature(prev_func)
+                if not isinstance(output, dict):
+                    raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`")
                 for k, v in output.items():
                     outputs[k].append(v)
                 for k, v in batch_y.items():
@@ -299,16 +343,15 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
                 if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
                     break
         for k, v in outputs.items():
-            outputs[k] = itertools.chain(*v)
+            outputs[k] = tuple(itertools.chain(*v))
         for k, v in truths.items():
-            truths[k] = itertools.chain(*v)
-        _check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level,
-                             output=outputs, batch_y=truths)
-        refined_input = _build_args(model.evaluate, **outputs, **truths)
-        metrics = model.evaluate(**refined_input)
-        func_signature = get_func_signature(model.evaluate)
-        assert isinstance(metrics, dict), "The return value of {} should be dict.". \
-            format(func_signature)
+            truths[k] = tuple(itertools.chain(*v))
+        # TODO this needs to be updated for the new-style metrics; errors raised by a metric also need to be caught here, since the user needs guidance to debug
+
+
+
+


 def _check_forward_error(model_func, check_level, batch_x):

diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 84faaece..8ffcc7bb 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -3,6 +3,7 @@ import inspect
 import os
 from collections import Counter
 from collections import namedtuple
+import torch

 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)

@@ -95,7 +96,24 @@ def _check_arg_dict_list(func, args):
                     all_needed=list(all_args))

 def get_func_signature(func):
-    # can only be used in function or class method
+    """
+
+    Given a function or method, return its signature.
+    For example:
+    (1) function
+        def func(a, b='a', *args):
+            xxxx
+        get_func_signature(func)  # 'func(a, b='a', *args)'
+    (2) method
+        class Demo:
+            def __init__(self):
+                xxx
+            def forward(self, a, b='a', **args):
+                xxx
+        demo = Demo()
+        get_func_signature(demo.forward)  # 'Demo.forward(self, a, b='a', **args)'
+    :param func: a function or a method
+    :return: str or None
+    """
     if inspect.ismethod(func):
         class_name = func.__self__.__class__.__name__
         signature = inspect.signature(func)
@@ -113,10 +131,16 @@ def get_func_signature(func):
         return signature_str


-# move data to model's device
-import torch
 def _syn_model_data(model, *args):
-    assert len(model.state_dict())!=0, "This model has no parameter."
+    """
+
+    Move data to the model's device; each element in *args should be a dict. This is an in-place change.
+    :param model:
+    :param args:
+    :return:
+    """
+    if len(model.state_dict()) == 0:
+        raise ValueError("model has no parameter.")
     device = model.parameters().__next__().device
     for arg in args:
         if isinstance(arg, dict):
@@ -124,4 +148,33 @@ def _syn_model_data(model, *args):
                 if isinstance(value, torch.Tensor):
                     arg[key] = value.to(device)
         else:
-            raise ValueError("Only support dict type right now.")
\ No newline at end of file
+            raise TypeError("Only support `dict` type right now.")
+
+def _prepare_metrics(metrics):
+    """
+
+    Prepare a list of Metric based on input
+    :param metrics:
+    :return:
+    """
+    _metrics = []
+    if metrics:
+        if isinstance(metrics, list):
+            for metric in metrics:
+                if isinstance(metric, type):
+                    metric = metric()
+                if isinstance(metric, None):
+                    _metrics.append(metric)
+                else:
+                    raise TypeError("The type of metric in metrics must be xxxx, not {}.".format(
+                        type(metric)
+                    ))
+        elif isinstance(metrics, None):
+            _metrics = [metrics]
+        else:
+            raise TypeError("The type of metrics should be `list[xxx]` or `xxx`, got {}.".format(
+                type(metrics)
+            ))
+
+    return _metrics
+

From ad0a8c177554ee1a5c4656ea2c8a06aa369f0ca5 Mon Sep 17 00:00:00 2001
From: yh
Date: Sat, 1 Dec 2018 18:27:07 +0800
Subject: [PATCH 2/2] add metric
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/losses.py  |  23 +++++++
 fastNLP/core/metrics.py | 129 +++++++++++++++++++++++++++++++++++++++-
 fastNLP/core/tester.py  |  56 ++++++++++++-----
 fastNLP/core/trainer.py |  71 ++++++++++------------
 fastNLP/core/utils.py   |  53 +++++++----------
 5 files changed, 245 insertions(+), 87 deletions(-)

diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 1e5a4914..d818c613 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -17,6 +17,29 @@ class Loss(LossBase):
         pass


+class LossInForward(LossBase):
+    def __init__(self, loss_key='loss'):
+        super().__init__()
+
+        self.loss_key = loss_key
+
+    def get_loss(self, *args, **kwargs):
+        pass
+
+    def __call__(self, output_dict, predict_dict):
+        pass
+
+
+def _prepare_losser(losser):
+    if losser is None:
+        losser = LossInForward()
+        return losser
+    elif isinstance(losser, LossBase):
+        return losser
+    else:
+        raise TypeError(f"Type of losser should be `fastNLP.LossBase`, got {type(losser)}")
+
+
 def squash(predict, truth, **kwargs):
     '''To reshape tensors in order to fit Loss functions in pytorch

diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 94893324..d4d81212 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -1,8 +1,136 @@
+
 import warnings
+import inspect

 import numpy as np
 import torch

+from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _check_arg_dict_list
+from fastNLP.core.utils import _build_args
+
+class MetricBase(object):
+    def __init__(self):
+        self.param_map = {}  # key is param in function, value is input param.
+        self._checked = False
+
+    def evaluate(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def _init_param_map(self, key_map, **kwargs):
+        self.param_map = {}
+        for key, value in key_map.items():
+            if not isinstance(key, str):
+                raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
+            if not isinstance(value, str):
+                raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
+            self.param_map[key] = value
+        for key, value in kwargs.items():
+            if not isinstance(value, str):
+                raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
+            self.param_map[key] = value
+
+    def __call__(self, output_dict, target_dict, force_check=False):
+        """
+        :param output_dict:
+        :param target_dict:
+        :return:
+        """
+        if not callable(self.evaluate):
+            raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
+
+        if not self._checked:
+            # 1. check consistency between the signature and param_map
+            func_spect = inspect.getfullargspec(self.evaluate)
+            func_args = func_spect.args
+            for func_param, input_param in self.param_map.items():
+                if func_param not in func_args:
+                    raise NameError(f"{func_param} not in {get_func_signature(self.evaluate)}.")
+            # 2. only part of the param_map is passed; map the rest to themselves
+            for arg in func_args:
+                if arg not in self.param_map:
+                    self.param_map[arg] = arg  # this param does not need mapping
+            self._evaluate_args = func_args
+
+        # need to wrap inputs in dict.
+        mapped_output_dict = {}
+        mapped_target_dict = {}
+        for func_arg in self._evaluate_args:
+            input_arg = self.param_map[func_arg]
+            if input_arg in output_dict:
+                mapped_output_dict[func_arg] = output_dict[input_arg]
+            if input_arg in target_dict:
+                mapped_target_dict[func_arg] = target_dict[input_arg]
+
+        # check duplicated, unused, missing
+        if force_check or not self._checked:
+            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
+            self._reverse_param_map = {value: key for key, value in self.param_map.items()}
+            for key, value in check_res.items():
+                new_value = value.copy()
+                for idx, func_param in enumerate(value):
+                    if func_param in self._reverse_param_map:
+                        new_value[idx] = self._reverse_param_map[func_param]
+            if check_res['missing'] or check_res['duplicated']:
+                raise CheckError(check_res=check_res)
+        refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)
+
+        metrics = self.evaluate(**refined_args)
+
+        if not isinstance(metrics, dict):
+            raise TypeError(f"The return value of {get_func_signature(self.evaluate)} must be `dict`, "
+                            f"got {type(metrics)}.")
+        self._checked = True
+
+        return metrics
+
+
+class CheckError(Exception):
+    def __init__(self, check_res):
+
+        err = ''
+        if check_res['missing']:
+            err += f"Missing: {check_res['missing']}\n"
+        if check_res['duplicated']:
+            err += f"Duplicated: {check_res['duplicated']}\n"
+        super().__init__(err)
+        self.check_res = check_res
+
+    def __str__(self):
+        return self.args[0]
+
+
+class Metric(MetricBase):
+    def __init__(self, func, key_map, **kwargs):
+        super().__init__()
+        pass
+
+
+def _prepare_metrics(metrics):
+    """
+
+    Prepare a list of Metric based on input
+    :param metrics:
+    :return: List[fastNLP.MetricBase]
+    """
+    _metrics = []
+    if metrics:
+        if isinstance(metrics, list):
+            for metric in metrics:
+                if isinstance(metric, type):
+                    metric = metric()
+                if isinstance(metric, MetricBase):
+                    _metrics.append(metric)
+                else:
+                    raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
+        elif isinstance(metrics, MetricBase):
+            _metrics = [metrics]
+        else:
+            raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}."
+                            .format(type(metrics)))
+    return _metrics
+

 class Evaluator(object):
     def __init__(self):
@@ -17,7 +145,6 @@ class Evaluator(object):
         """
         raise NotImplementedError

-
 class ClassifyEvaluator(Evaluator):
     def __init__(self):
         super(ClassifyEvaluator, self).__init__()

diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 5d264b80..a66ce234 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -2,32 +2,49 @@ import itertools
 from collections import defaultdict

 import torch
+from torch import nn

 from fastNLP.core.batch import Batch
 from fastNLP.core.sampler import RandomSampler
+from fastNLP.core.dataset import DataSet
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _move_dict_value_to_device
+from fastNLP.core.metrics import _prepare_metrics


 class Tester(object):
     """A collection of model inference and evaluation of performance, used over validation/dev set and test set. """

     def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
         super(Tester, self).__init__()
-        self.use_cuda = use_cuda
+
+        if not isinstance(data, DataSet):
+            raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
+        if not isinstance(model, nn.Module):
+            raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")
+
+        self.metrics = _prepare_metrics(metrics)
+
         self.data = data
-        self.batch_size = batch_size
-        self.verbose = verbose
+        self.use_cuda = use_cuda
         if torch.cuda.is_available() and self.use_cuda:
             self._model = model.cuda()
         else:
             self._model = model
-        if hasattr(self._model, 'predict'):
-            if not callable(self._model.predict):
-                raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
-                                f"for evaluation.")
-            self._predict_func = self._model.predict
-        else:
-            self._predict_func = self._model
+        self.batch_size = batch_size
+        self.verbose = verbose
+
+        # check predict; self._model must already be set at this point
+        if hasattr(self._model, 'predict'):
+            self._predict_func = self._model.predict
+            if not callable(self._predict_func):
+                _model_name = model.__class__.__name__
+                raise TypeError(f"`{_model_name}.predict` must be callable to be used "
+                                f"for evaluation, not `{type(self._predict_func)}`.")
+        else:
+            self._predict_func = self._model
+
+        self._model_device = model.parameters().__next__().device

     def test(self):
@@ -39,6 +56,7 @@ class Tester(object):

         with torch.no_grad():
             for batch_x, batch_y in data_iterator:
+                _move_dict_value_to_device(self._model_device, batch_x, batch_y)
                 prediction = self.data_forward(network, batch_x)
                 assert isinstance(prediction, dict)
                 for k, v in prediction.items():
@@ -49,10 +67,13 @@ class Tester(object):
             output[k] = itertools.chain(*v)
         for k, v in truths.items():
             truths[k] = itertools.chain(*v)
-        # args = _build_args(self._evaluator, **output, **truths)
-        eval_results = self._evaluator(**args)
+        eval_results = {}
+        for metric in self.metrics:
+            eval_result = metric(output, truths)
+            metric_name = metric.__class__.__name__
+            eval_results[metric_name] = eval_result
         if self.verbose >= 0:
-            print("[tester] {}".format(self.print_eval_results(eval_results)))
+            print("[tester] \n{}".format(self.format_eval_results(eval_results)))
         self.mode(network, is_test=False)
         return eval_results

@@ -74,10 +95,15 @@ class Tester(object):
             y = self._predict_func(**x)
         return y

-    def print_eval_results(self, results):
+    def format_eval_results(self, results):
         """Override this method to support more print formats.

         :param results: dict, (str: float) is (metrics name: value)

         """
-        return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
+        _str = ''
+        for metric_name, metric_result in results.items():
+            _str += metric_name + '\n\t'
+            _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
+            _str += '\n'
+        return _str

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 4febdfce..97b420c5 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -17,10 +17,15 @@ from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
-from fastNLP.core.utils import _syn_model_data
+from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.dataset import DataSet

+from fastNLP.core.losses import LossBase
+from fastNLP.core.metrics import MetricBase
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
+

 class Trainer(object):
     """Main Training Loop
@@ -32,6 +37,25 @@ class Trainer(object):
                  **kwargs):
         super(Trainer, self).__init__()

+        if not isinstance(train_data, DataSet):
+            raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
+        if not isinstance(model, nn.Module):
+            raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
+
+        # check metrics and dev_data
+        if (not metrics) and dev_data is not None:
+            raise ValueError("No metric for dev_data evaluation.")
+        if metrics and (dev_data is None):
+            raise ValueError("No dev_data for evaluation, pass dev_data or set metrics to None.")
+
+        # prepare evaluate
+        metrics = _prepare_metrics(metrics)
+        # prepare loss
+        losser = _prepare_losser(losser)
+
+        if need_check_code:
+            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)
+
         self.train_data = train_data
         self.dev_data = dev_data  # If None, No validation.
         self.model = model
@@ -45,10 +69,7 @@ class Trainer(object):
         self.validate_every = int(validate_every)
         self._best_accuracy = 0

-        # TODO check the types of losser and metrics
-
-
-
+        self._model_device = model.parameters().__next__().device

         # TODO self._best_accuracy cannot reflect the case of multiple metrics

@@ -72,38 +93,6 @@ class Trainer(object):

         # print(self.__dict__)

-    def _check_params(self, train_data, model, losser, metrics=[], n_epochs=3, batch_size=32, print_every=-1,
-                      validate_every=-1, dev_data=None, use_cuda=False, save_path="./save",
-                      optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
-                      **kwargs):
-        if not isinstance(train_data, DataSet):
-            raise TypeError("The type of train_data must be fastNLP.DataSet, got {}.".\
-                format(type(train_data)))
-        if not isinstance(model, nn.Module):
-            raise TypeError("The type of model must be torch.nn.Module, got {}.".\
-                format(type(model)))
-        if losser is not None:
-            # TODO change
-            if not isinstance(losser, None):
-                raise TypeError("The type of losser must be xxx, got {}.".\
-                    format(type(losser)))
-
-        # check metrics and dev_data
-        if (not metrics) and dev_data is not None:
-            raise ValueError("No metric for dev_data evaluation.")
-        if metrics and (dev_data is None):
-            raise ValueError("No dev_data for evaluation, pass dev_data or set metrics to None.")
-
-        # check loss
-        if isinstance(losser, type):
-            self.losser = losser()
-        if not isinstance(self.losser, None):
-            raise TypeError(f"The type of losser must be `xxx`, got {type(self.losser)}.")
-
-        if need_check_code:
-            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)
-
     def train(self):
         """Start Training.

@@ -153,8 +142,9 @@ class Trainer(object):
             - epoch: int,
         """
         for batch_x, batch_y in data_iterator:
+            # TODO this may run into problems if the user changes the device of prediction inside the model
+            _move_dict_value_to_device(self._model_device, batch_x, batch_y)
             prediction = self.data_forward(model, batch_x)
-
             loss = self.get_loss(prediction, batch_y)
             self.grad_backward(loss)
             self.update()
@@ -205,7 +195,6 @@ class Trainer(object):
         x = _build_args(network.forward, **x)
         y = network(**x)
         if not isinstance(y, dict):
-
             raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
         return y

@@ -299,7 +288,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         # check loss output
         if batch_count == 0:
             if not isinstance(loss, torch.Tensor):
-                raise ValueError("The return value of {} should be torch.Tensor, but got {}.".
+                raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.".
                                  format(type(losser), type(loss)))
             if len(loss.size()) != 0:
                 raise ValueError("The size of return value of {} is {}, should be torch.Size([])".format(
                     type(losser), loss.size()
                 ))
@@ -314,7 +303,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         outputs, truths = defaultdict(list), defaultdict(list)
         dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
         # TODO change this to use the tester
-
+        tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size, )

         with torch.no_grad():
             for batch_count, (batch_x, batch_y) in enumerate(dev_batch):

diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 8ffcc7bb..97ed83d9 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -3,11 +3,9 @@ import inspect
 import os
 from collections import Counter
 from collections import namedtuple
+from collections import defaultdict
 import torch

-CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)
-
-
 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.
@@ -89,11 +87,15 @@ def _check_arg_dict_list(func, args):
     input_args = set(input_arg_count.keys())
     missing = list(require_args - input_args)
     unused = list(input_args - all_args)
-    return CheckRes(missing=missing,
-                    unused=unused,
-                    duplicated=duplicated,
-                    required=list(require_args),
-                    all_needed=list(all_args))
+
+    check_res = {}
+    check_res['missing'] = missing
+    check_res['unused'] = unused
+    check_res['duplicated'] = duplicated
+    check_res['required'] = list(require_args)
+    check_res['all_needed'] = list(all_args)
+
+    return check_res

 def get_func_signature(func):
     """
@@ -150,31 +152,22 @@ def _syn_model_data(model, *args):
         else:
             raise TypeError("Only support `dict` type right now.")

-def _prepare_metrics(metrics):
+def _move_dict_value_to_device(device, *args):
     """

-    Prepare a list of Metric based on input
-    :param metrics:
+    Move data to the given device; each element in *args should be a dict. This is an in-place change.
+    :param device: torch.device
+    :param args:
     :return:
     """
-    _metrics = []
-    if metrics:
-        if isinstance(metrics, list):
-            for metric in metrics:
-                if isinstance(metric, type):
-                    metric = metric()
-                if isinstance(metric, None):
-                    _metrics.append(metric)
-                else:
-                    raise TypeError("The type of metric in metrics must be xxxx, not {}.".format(
-                        type(metric)
-                    ))
-        elif isinstance(metrics, None):
-            _metrics = [metrics]
-        else:
-            raise TypeError("The type of metrics should be `list[xxx]` or `xxx`, got {}.".format(
-                type(metrics)
-            ))
-
-    return _metrics
+    if not isinstance(device, torch.device):
+        raise TypeError(f"device must be `torch.device`, got `{type(device)}`")
+
+    for arg in args:
+        if isinstance(arg, dict):
+            for key, value in arg.items():
+                if isinstance(value, torch.Tensor):
+                    arg[key] = value.to(device)
+        else:
+            raise TypeError("Only support `dict` type right now.")
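
Note on the new metric API: a concrete metric subclasses MetricBase, implements evaluate(), and relies on param_map to translate between the parameter names of evaluate() and the keys produced by the model and the dataset. The following is a minimal sketch of the intended calling convention only; `AccMetric` and the keys 'output'/'label' are hypothetical, and it assumes the TODOs still left in MetricBase.__call__ are resolved.

    import torch
    from fastNLP.core.metrics import MetricBase

    class AccMetric(MetricBase):
        def __init__(self, pred='pred', target='target'):
            super().__init__()
            # map evaluate()'s parameters onto the keys used by the model
            # output and the dataset, e.g. pred -> 'output', target -> 'label'
            self._init_param_map({}, pred=pred, target=target)

        def evaluate(self, pred, target):
            # pred and target arrive here already renamed through param_map
            return {'acc': (pred == target).float().mean().item()}

    metric = AccMetric(pred='output', target='label')
    # Tester.test calls each metric with the accumulated output/target dicts:
    print(metric({'output': torch.tensor([1, 0, 1])},
                 {'label': torch.tensor([1, 1, 1])}))  # {'acc': 0.666...}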
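
Note on the loss fallback: _prepare_losser substitutes LossInForward when no losser is passed, which means the Trainer then expects the model's forward() to return its own loss under loss_key (default 'loss') alongside the prediction fields. Below is a hedged sketch of a model written against that contract; `ToyModel` and its field names are illustrative only, and LossInForward.__call__ is still a stub in this patch.

    import torch
    from torch import nn
    import torch.nn.functional as F

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x, y):
            logits = self.fc(x)
            # Trainer.data_forward requires a dict return value; with the
            # default LossInForward, the 'loss' entry is consumed directly.
            return {'pred': logits, 'loss': F.cross_entropy(logits, y)}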
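
Note on device handling: Trainer and Tester now cache self._model_device at construction time and call _move_dict_value_to_device on every batch, so a dataset can stay on the CPU. A standalone usage sketch of the helper as added above:

    import torch
    from fastNLP.core.utils import _move_dict_value_to_device

    batch_x = {'x': torch.randn(4, 3)}
    batch_y = {'y': torch.tensor([0, 1, 0, 1])}
    # in-place: every tensor value inside the dicts is moved to the target device
    _move_dict_value_to_device(torch.device('cpu'), batch_x, batch_y)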