From 0d4720b1d91648fa61683d9dde13d9e183b9c003 Mon Sep 17 00:00:00 2001
From: yh
Date: Sat, 1 Dec 2018 20:14:43 +0800
Subject: [PATCH] Add CheckError functionality

---
 fastNLP/core/metrics.py | 28 ++++---------
 fastNLP/core/tester.py  | 30 ++++++++------
 fastNLP/core/trainer.py | 87 +++++++++++------------------------------
 fastNLP/core/utils.py   | 17 ++++----
 4 files changed, 57 insertions(+), 105 deletions(-)

diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index d4d81212..60e0d82f 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -8,6 +8,8 @@ import torch
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import CheckError
+
 
 class MetricBase(object):
     def __init__(self):
@@ -29,7 +31,7 @@ class MetricBase(object):
             if isinstance(value, str):
                 raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
             self.param_map[key] = value
-
+    
     def __call__(self, output_dict, target_dict, force_check=False):
         """
         :param output_dict:
@@ -67,7 +69,7 @@ class MetricBase(object):
         check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
         self._reverse_param_map = {value:key for key, value in check_res.items()}
         for key, value in check_res.items():
-            new_value = value.copy()
+            new_value = list(value)
             for idx, func_param in enumerate(value):
                 if func_param in self._reverse_param_map:
                     new_value[idx] = self._reverse_param_map[func_param]
@@ -85,28 +87,12 @@ class MetricBase(object):
 
         return metrics
 
-
-
-
-class CheckError(Exception):
-    def __init__(self, check_res):
-
-        err = ''
-        if check_res.missing:
-            err += f'Missing: {check_res.missing}\n'
-        if check_res.duplicated:
-            err += f'Duplicated: {check_res.duplicated}\n'
-        self.check_res = check_res
-
-    def __str__(self):
-        pass
-
-
 class Metric(MetricBase):
     def __init__(self, func, key_map, **kwargs):
         super().__init__()
         pass
 
+
 def _prepare_metrics(metrics):
     """
 
@@ -127,8 +113,8 @@ def _prepare_metrics(metrics):
     elif isinstance(metrics, MetricBase):
         _metrics = [metrics]
     else:
-        raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}."
-                        .format(type(metrics)))
+        raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, "
+                        f"got {type(metrics)}.")
     return _metrics
 
 
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index a66ce234..33d8cc81 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -5,12 +5,13 @@ import torch
 from torch import nn
 
 from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import RandomSampler
+from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.metrics import _prepare_metrics
+from fastNLP.core.utils import CheckError
 
 
 class Tester(object):
     """A collection of model inference and evaluation of performance, used over the validation/dev set and test set. 
""" @@ -33,7 +34,7 @@ class Tester(object): raise TypeError(f"`{_model_name}.predict` must be callable to be used " f"for evaluation, not `{type(self._predict_func)}`.") else: - self._predict_func = self._model + self._predict_func = self._model.forward self.data = data if torch.cuda.is_available() and self.use_cuda: @@ -50,14 +51,14 @@ class Tester(object): def test(self): # turn on the testing mode; clean up the history network = self._model - self.mode(network, is_test=True) + self._mode(network, is_test=True) output, truths = defaultdict(list), defaultdict(list) - data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False) + data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False) with torch.no_grad(): for batch_x, batch_y in data_iterator: _move_dict_value_to_device(self._model_device, batch_x, batch_y) - prediction = self.data_forward(network, batch_x) + prediction = self._data_forward(self._predict_func, batch_x) assert isinstance(prediction, dict) for k, v in prediction.items(): output[k].append(v) @@ -68,16 +69,21 @@ class Tester(object): for k, v in truths.items(): truths[k] = itertools.chain(*v) eval_results = {} + try: for metric in self.metrics: eval_result = metric(output, truths) metric_name = metric.__class__.__name__ eval_results[metric_name] = eval_result + except CheckError as e: + pass + + if self.verbose >= 0: - print("[tester] \n{}".format(self.format_eval_results(eval_results))) - self.mode(network, is_test=False) + print("[tester] \n{}".format(self._format_eval_results(eval_results))) + self._mode(network, is_test=False) return eval_results - def mode(self, model, is_test=False): + def _mode(self, model, is_test=False): """Train mode or Test mode. This is for PyTorch currently. :param model: a PyTorch model @@ -89,13 +95,13 @@ class Tester(object): else: model.train() - def data_forward(self, network, x): + def _data_forward(self, func, x): """A forward pass of the model. """ - x = _build_args(network.forward, **x) - y = self._predict_func(**x) + x = _build_args(func, **x) + y = func(**x) return y - def format_eval_results(self, results): + def _format_eval_results(self, results): """Override this method to support more print formats. :param results: dict, (str: float) is (metrics name: value) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 97b420c5..da8e54f9 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -25,7 +25,7 @@ from fastNLP.core.losses import LossBase from fastNLP.core.metrics import MetricBase from fastNLP.core.losses import _prepare_losser from fastNLP.core.metrics import _prepare_metrics - +from fastNLP.core.utils import CheckError class Trainer(object): """Main Training Loop @@ -211,13 +211,11 @@ class Trainer(object): def get_loss(self, predict, truth): """Compute loss given prediction and ground truth. 
-        :param predict: prediction label vector
-        :param truth: ground truth label vector
+        :param predict: prediction dict, produced by model.forward
+        :param truth: ground truth dict, produced by batch_y
         :return: a scalar
         """
-        assert isinstance(predict, dict) and isinstance(truth, dict)
-        args = _build_args(self.loss_func, **predict, **truth)
-        return self.loss_func(**args)
+        return self.losser(predict, truth)
 
     def save_model(self, model, model_name, only_param=False):
         model_name = os.path.join(self.save_path, model_name)
@@ -260,11 +258,11 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
                 dev_data=None, check_level=WARNING_CHECK_LEVEL):
     # check the get_loss method
-    model_name = model.__class__.__name__
+    model_device = model.parameters().__next__().device
 
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
 
     for batch_count, (batch_x, batch_y) in enumerate(batch):
-        _syn_model_data(model, batch_x, batch_y)
+        _move_dict_value_to_device(model_device, batch_x, batch_y)
         # forward check
         if batch_count==0:
             _check_forward_error(model_func=model.forward, check_level=check_level,
@@ -277,68 +275,29 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
             raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")
 
         # loss check
-        if isinstance(losser, type):  # in this case, the user passed an uninitialized loss such as losser.CE
-            # need to make sure output and batch_y are unambiguous?
-            #   (1) output and batch_y have length 1
-            #   (2) the keys of output and batch_y exactly match what the losser accepts
-            pass
-
-        loss = losser(output, batch_y)
-
+        try:
+            loss = losser(output, batch_y)
+        except CheckError as e:
+            _check_loss_evaluate(prev_func=model.forward, func=e.func_signature,
+                                 check_res=e.check_res, output=output, batch_y=batch_y,
+                                 check_level=check_level)
         # check loss output
         if batch_count == 0:
             if not isinstance(loss, torch.Tensor):
-                raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.".
-                                 format(type(losser), type(loss)))
+                raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
+                                f"but got `{type(loss)}`.")
             if len(loss.size())!=0:
-                raise ValueError("The size of return value of {} is {}, should be torch.size([])".format(
-                    type(losser), loss.size()
-                ))
+                raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
+                                 f"should be torch.Size([])")
         loss.backward()
         model.zero_grad()
         if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
             break
 
     if dev_data is not None:
-        outputs, truths = defaultdict(list), defaultdict(list)
-        dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
-        # TODO change this to use the tester
-        tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size, )
-
-        with torch.no_grad():
-            for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
-                _syn_model_data(model, batch_x, batch_y)
-
-                if hasattr(model, 'predict'):
-                    if not callable(model.predict):
-                        raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
-                                        f"for evaluation.")
-                    refined_batch_x = _build_args(model.predict, **batch_x)
-                    prev_func = model.predict
-                    output = prev_func(**refined_batch_x)
-                else:
-                    refined_batch_x = _build_args(model.forward, **batch_x)
-                    prev_func = model.forward
-                    output = prev_func(**refined_batch_x)
-                func_signature = get_func_signature(prev_func)
-                if not isinstance(output, dict):
-                    raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`")
-                for k, v in output.items():
-                    outputs[k].append(v)
-                for k, v in batch_y.items():
-                    truths[k].append(v)
-                if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
-                    break
-            for k, v in outputs.items():
-                outputs[k] = tuple(itertools.chain(*v))
-            for k, v in truths.items():
-                truths[k] = tuple(itertools.chain(*v))
-        #TODO update this for the new metrics; also catch errors raised by the metric here, since they are needed to guide the user's debugging
-
-
-
-
-
+        tester = Tester(data=dev_data[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
+                        batch_size=batch_size, verbose=-1)
+        tester.test()
 
 
 def _check_forward_error(model_func, check_level, batch_x):
@@ -346,11 +305,11 @@ def _check_forward_error(model_func, check_level, batch_x):
     _missing = ''
     _unused = ''
     func_signature = get_func_signature(model_func)
-    if len(check_res['missing'])!=0:
+    if len(check_res.missing)!=0:
         _missing = "Function {} misses {}, only provided with {}, " \
                    ".\n".format(func_signature, check_res.missing, list(batch_x.keys()))
 
-    if len(check_res['unused'])!=0:
+    if len(check_res.unused)!=0:
         if len(check_res.unused) > 1:
             _unused = "{} are not used ".format(check_res.unused)
         else:
@@ -370,9 +329,7 @@ def _check_forward_error(model_func, check_level, batch_x):
     elif check_level == WARNING_CHECK_LEVEL:
         warnings.warn(message=_unused)
 
-def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
-
-    check_res = _check_arg_dict_list(func, [output, batch_y])
+def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
     _missing = ''
     _unused = ''
     _duplicated = ''
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index efc2ef7e..61c5bc5c 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -220,13 +220,16 @@ class CheckError(Exception):
     CheckError. Used in losses.LossBase, metrics.MetricBase. 
""" - def __init__(self, check_res): + def __init__(self, check_res:CheckRes, func_signature:str): err = '' - if check_res['missing']: - err += f"Missing: {check_res['missing']}\n" - if check_res['duplicated']: - err += f"Duplicated: {check_res['duplicated']}\n" - if check_res['unused']: - err += f"Unused: {check_res['unused']}\n" + if check_res.missing: + err += f"Missing: {check_res.missing}\n" + if check_res.duplicated: + err += f"Duplicated: {check_res.duplicated}\n" + if check_res.unused: + err += f"Unused: {check_res.unused}\n" + Exception.__init__(self, err) + self.check_res = check_res + self.func_signature = func_signature