diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 38da83da..0aca6055 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 
 
 class Batch(object):
@@ -45,7 +46,7 @@ class Batch(object):
             if field.is_target or field.is_input:
                 batch = field.get(indices)
                 if not self.as_numpy:
-                    batch = torch.from_numpy(batch)
+                    batch = to_tensor(batch, field.dtype)
                 if field.is_target:
                     batch_y[field_name] = batch
                 if field.is_input:
@@ -54,3 +55,10 @@ class Batch(object):
         self.curidx = endidx
 
         return batch_x, batch_y
+
+def to_tensor(batch, dtype):
+    if dtype in (np.int8, np.int16, np.int32, np.int64):
+        batch = torch.LongTensor(batch)
+    if dtype in (np.float32, np.float64):
+        batch = torch.FloatTensor(batch)
+    return batch
\ No newline at end of file
diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py
index f93fbf2e..714fa169 100644
--- a/fastNLP/core/fieldarray.py
+++ b/fastNLP/core/fieldarray.py
@@ -39,7 +39,7 @@ class FieldArray(object):
 
     @staticmethod
     def _map_to_np_type(basic_type):
-        type_mapping = {int: np.int64, float: np.double, str: np.str}
+        type_mapping = {int: np.int64, float: np.float64, str: np.str}
         return type_mapping[basic_type]
 
     def __repr__(self):
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 85b16e64..564eb7ce 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -126,15 +126,30 @@ class NLLLoss(LossBase):
 class LossInForward(LossBase):
     def __init__(self, loss_key='loss'):
         super().__init__()
+        if not isinstance(loss_key, str):
+            raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
         self.loss_key = loss_key
 
     def get_loss(self, **kwargs):
         if self.loss_key not in kwargs:
-            pass
+            check_res = CheckRes(missing=[self.loss_key],
+                                 unused=[],
+                                 duplicated=[],
+                                 required=[],
+                                 all_needed=[],
+                                 varargs=[])
+            raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
 
-    def __call__(self, output_dict, predict_dict):
+    def __call__(self, output_dict, predict_dict, force_check=False):
 
-        return self.get_loss(**output_dict)
+        loss = self.get_loss(**output_dict)
+
+        if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
+            if not isinstance(loss, torch.Tensor):
+                raise TypeError(f"loss ERROR: loss except a torch.Tensor but got {type(loss)}")
+            raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}")
+
+        return loss
 
 
 def _prepare_losser(losser):
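A minimal sketch of the dtype-based conversion the new to_tensor helper in batch.py performs; the helper body is taken from the hunk above, while the sample array is made up for illustration:

    import numpy as np
    import torch

    def to_tensor(batch, dtype):
        # integer numpy data becomes a LongTensor, float data a FloatTensor
        if dtype in (np.int8, np.int16, np.int32, np.int64):
            batch = torch.LongTensor(batch)
        if dtype in (np.float32, np.float64):
            batch = torch.FloatTensor(batch)
        return batch

    words = np.array([[1, 2, 3], [4, 5, 0]], dtype=np.int64)
    print(to_tensor(words, words.dtype).dtype)  # torch.int64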
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 69bb540d..f8fc1d49 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -10,7 +10,7 @@ from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import CheckError
-
+from fastNLP.core.utils import _check_function_or_method
 
 class MetricBase(object):
     def __init__(self):
@@ -20,19 +20,32 @@ class MetricBase(object):
     def evaluate(self, *args, **kwargs):
         raise NotImplementedError
 
-    def _init_param_map(self, key_map, **kwargs):
-        self.param_map = {}
-        value_counter = defaultdict(0)
-        for key, value in key_map.items():
-            if isinstance(key, str):
-                raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
-            if isinstance(value, str):
-                raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
-            self.param_map[key] = value
+    def _init_param_map(self, key_map=None, **kwargs):
+        value_counter = defaultdict(set)
+        if key_map is not None:
+            if not isinstance(key_map, dict):
+                raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
+            for key, value in key_map.items():
+                if value is None:
+                    self.param_map[key] = key
+                    continue
+                if isinstance(key, str):
+                    raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
+                if isinstance(value, str):
+                    raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
+                self.param_map[key] = value
+                value_counter[value].add(key)
         for key, value in kwargs.items():
+            if value is None:
+                self.param_map[key] = key
+                continue
             if isinstance(value, str):
                 raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
             self.param_map[key] = value
+            value_counter[value].add(key)
+        for value, key_set in value_counter.items():
+            if len(key_set)>1:
+                raise ValueError(f"Several params:{key_set} are provided with one output {value}.")
 
     def __call__(self, output_dict, target_dict, check=False):
         """
@@ -45,8 +58,6 @@ class MetricBase(object):
             raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
 
         if not self._checked:
-            # 0. check param_map does not have same value
-            # 1. check consistence between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
             func_args = func_spect.args
 
@@ -58,26 +69,32 @@ class MetricBase(object):
                 if arg not in self.param_map:
                     self.param_map[arg] = arg #This param does not need mapping.
             self._evaluate_args = func_args
+            self._reverse_param_map = {value: key for key, value in self.param_map.items()}
 
         # need to wrap inputs in dict.
         mapped_output_dict = {}
         mapped_target_dict = {}
         for func_arg in self._evaluate_args:
             input_arg = self.param_map[func_arg]
+            if input_arg in self._reverse_param_map:
+                mapped_arg = func_arg
+            else:
+                mapped_arg = input_arg
             if input_arg in output_dict:
-                mapped_output_dict[func_arg] = output_dict[input_arg]
+                mapped_output_dict[mapped_arg] = output_dict[input_arg]
             if input_arg in target_dict:
-                mapped_target_dict[func_arg] = target_dict[input_arg]
+                mapped_target_dict[mapped_arg] = target_dict[input_arg]
 
         # check duplicated, unused, missing
         if check or not self._checked:
             check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict])
-            self._reverse_param_map = {value:key for key, value in check_res.items()}
             for key, value in check_res.items():
                 new_value = list(value)
                 for idx, func_param in enumerate(value):
                     if func_param in self._reverse_param_map:
-                        new_value[idx] = self._reverse_param_map[func_param]
+                        new_value[idx] = self._reverse_param_map[func_param] + f'(assign to {func_param})'
+                    else:
+                        new_value[idx] = func_param
             if check_res.missing or check_res.duplicated or check_res.varargs:
                 raise CheckError(check_res=check_res, func_signature=get_func_signature(self.evaluate))
@@ -93,11 +110,55 @@ class MetricBase(object):
 
         return metrics
 
-class Metric(MetricBase):
+class FuncMetric(MetricBase):
     def __init__(self, func, key_map, **kwargs):
         super().__init__()
+
+        _check_function_or_method(func=func)
+        self._init_param_map(key_map=key_map, **kwargs)
+
+        self.evaluate = func
+
+
+class AccuracyMetric(MetricBase):
+    def __init__(self, predictions=None, targets=None, masks=None, seq_lens=None):
+        super().__init__()
+
+        self._init_param_map(predictions=predictions, targets=targets,
+                             masks=masks, seq_lens=seq_lens)
+
+    def evaluate(self, predictions, targets, masks=None, seq_lens=None):
+        """
+
+        :param predictions: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
+                torch.Size([]), torch.Size([n_classes,]), torch.Size([max_len,]), torch.Size([max_len, n_classes])
+        :param targets: List of (torch.Tensor, or numpy.ndarray). Element's can be:
+                torch.Size([]), torch.Size([]), torch.Size([max_len,]), torch.Size([max_len, ])
+        :param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be:
+                None, None, torch.Size([max_len,], torch.Size([max_len, ])
+        :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be:
+                None, None, torch.Size([1], torch.Size([1])
+        :return: dict({'acc': float})
+        """
         pass
 
+    def _check_evaluate_param(self, predictions, targets, masks=None, seq_lens=None):
+        # check the validity of self.evaluate param
+        prediction = predictions[0]
+        target = targets[0]
+
+        if len(np.shape(prediction))==len(target):
+            pass
+
+        if masks is not None:
+            mask = masks[0]
+        if seq_lens is not None:
+            seq_len = seq_lens[0]
+
+
+
+
 
 def _prepare_metrics(metrics):
     """
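A standalone illustration (not fastNLP code) of the argument-mapping rule the reworked _init_param_map sets up for MetricBase.__call__; the field names 'pred' and 'label' are invented for the example:

    # Each evaluate() argument maps to the dict key it should be read from;
    # a value of None (or an omitted argument) means "use the argument's own name".
    param_map = {'predictions': 'pred', 'targets': 'label', 'seq_lens': 'seq_lens'}
    output_dict = {'pred': [0, 1, 1]}
    target_dict = {'label': [0, 1, 0], 'seq_lens': [3]}

    mapped = {}
    for func_arg, input_arg in param_map.items():
        if input_arg in output_dict:
            mapped[func_arg] = output_dict[input_arg]
        if input_arg in target_dict:
            mapped[func_arg] = target_dict[input_arg]
    print(mapped)  # {'predictions': [0, 1, 1], 'targets': [0, 1, 0], 'seq_lens': [3]}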
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 39efb454..e809cd06 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -7,11 +7,11 @@ from torch import nn
 from fastNLP.core.batch import Batch
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.dataset import DataSet
+from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.metrics import _prepare_metrics
-from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _check_loss_evaluate
 
 
 class Tester(object):
@@ -57,7 +57,7 @@ class Tester(object):
 
         with torch.no_grad():
             for batch_x, batch_y in data_iterator:
-                _move_dict_value_to_device(self._model_device, batch_x, batch_y)
+                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                 prediction = self._data_forward(self._predict_func, batch_x)
                 assert isinstance(prediction, dict)
                 for k, v in prediction.items():
@@ -77,7 +77,7 @@ class Tester(object):
                 except CheckError as e:
                     prev_func_signature = get_func_signature(self._predict_func)
                     _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
-                                         check_res=e.check_res, output=output, batch_y=truths)
+                                         check_res=e.check_res, output=output, batch_y=truths, check_level=0)
 
 
         if self.verbose >= 0:
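The call sites above change because device becomes keyword-only in _move_dict_value_to_device. A minimal sketch of the new calling convention, assuming fastNLP at this revision is importable; the tensors are placeholders:

    import torch
    from fastNLP.core.utils import _move_dict_value_to_device

    batch_x = {'words': torch.LongTensor([[1, 2, 3]])}
    batch_y = {'label': torch.LongTensor([0])}
    # device must now be passed by keyword; the old positional form is no longer accepted
    _move_dict_value_to_device(batch_x, batch_y, device=torch.device('cpu'))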
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 39d76521..6d31e390 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -1,6 +1,5 @@
 import os
 import time
-import warnings
 from datetime import datetime
 from datetime import timedelta
 
@@ -9,24 +8,19 @@ from tensorboardX import SummaryWriter
 from torch import nn
 
 from fastNLP.core.batch import Batch
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import _build_args
-from fastNLP.core.utils import _check_arg_dict_list
-from fastNLP.core.utils import _move_dict_value_to_device
-from fastNLP.core.utils import get_func_signature
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.losses import _prepare_losser
 from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _check_forward_error
+from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _move_dict_value_to_device
+from fastNLP.core.utils import get_func_signature
 
 
 class Trainer(object):
     """Main Training Loop
@@ -52,6 +46,9 @@ class Trainer(object):
 
         if metrics and (dev_data is None):
             raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")
+        # check save_path
+        if not (save_path is None or isinstance(save_path, str)):
+            raise ValueError("save_path can only be None or `str`.")
 
         # prepare evaluate
         metrics = _prepare_metrics(metrics)
@@ -156,7 +153,7 @@ class Trainer(object):
         """
        for batch_x, batch_y in data_iterator:
            # TODO: this may become a problem if the user changes the device of `prediction` inside the model
-            _move_dict_value_to_device(self._model_device, batch_x, batch_y)
+            _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
             prediction = self._data_forward(model, batch_x)
             loss = self._compute_loss(prediction, batch_y)
             self._grad_backward(loss)
@@ -232,11 +229,12 @@ class Trainer(object):
         return self.losser(predict, truth)
 
     def _save_model(self, model, model_name, only_param=False):
-        model_name = os.path.join(self.save_path, model_name)
-        if only_param:
-            torch.save(model.state_dict(), model_name)
-        else:
-            torch.save(model, model_name)
+        if self.save_path is not None:
+            model_name = os.path.join(self.save_path, model_name)
+            if only_param:
+                torch.save(model.state_dict(), model_name)
+            else:
+                torch.save(model, model_name)
 
     def _better_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.
@@ -297,7 +295,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
 
     for batch_count, (batch_x, batch_y) in enumerate(batch):
-        _move_dict_value_to_device(model_devcie, batch_x, batch_y)
+        _move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
         # forward check
         if batch_count==0:
             _check_forward_error(forward_func=model.forward, check_level=check_level,
@@ -335,6 +333,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     if dev_data is not None:
         tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                         batch_size=batch_size, verbose=-1)
-        tester.test()
+        evaluate_results = tester.test()
+        # TODO: need to check whether the returned values are reasonable
 
 
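A rough sketch of the behaviour the _save_model change above aims for: a save_path of None simply skips checkpointing. The standalone function and file name below are illustrative, not fastNLP API:

    import os
    import torch

    def save_model(model, model_name, save_path, only_param=False):
        # mirrors Trainer._save_model after this change: do nothing when save_path is None
        if save_path is not None:
            model_name = os.path.join(save_path, model_name)
            if only_param:
                torch.save(model.state_dict(), model_name)
            else:
                torch.save(model, model_name)

    save_model(torch.nn.Linear(2, 2), 'best_model.pkl', save_path=None)  # no-op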
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index d237c190..cfc77f46 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -122,13 +122,13 @@ def _check_arg_dict_list(func, args):
     input_args = set(input_arg_count.keys())
     missing = list(require_args - input_args)
     unused = list(input_args - all_args)
-
+    varargs = [] if spect.varargs else [arg for arg in spect.varargs]
     return CheckRes(missing=missing,
                     unused=unused,
                     duplicated=duplicated,
                     required=list(require_args),
                     all_needed=list(all_args),
-                    varargs=[arg for arg in spect.varargs])
+                    varargs=varargs)
 
 def get_func_signature(func):
     """
@@ -165,6 +165,7 @@ def get_func_signature(func):
         signature_str = func.__name__ + signature_str
         return signature_str
 
+
 def _is_function_or_method(func):
     """
 
@@ -179,26 +180,8 @@ def _check_function_or_method(func):
     if not _is_function_or_method(func):
         raise TypeError(f"{type(func)} is not a method or function.")
 
-def _syn_model_data(model, *args):
-    """
-
-    move data to model's device, element in *args should be dict. This is a inplace change.
-    :param model:
-    :param args:
-    :return:
-    """
-    if len(model.state_dict())==0:
-        raise ValueError("model has no parameter.")
-    device = model.parameters().__next__().device
-    for arg in args:
-        if isinstance(arg, dict):
-            for key, value in arg.items():
-                if isinstance(value, torch.Tensor):
-                    arg[key] = value.to(device)
-        else:
-            raise TypeError("Only support `dict` type right now.")
 
-def _move_dict_value_to_device(device, *args):
+def _move_dict_value_to_device(*args, device:torch.device):
     """
 
     move data to model's device, element in *args should be dict. This is a inplace change.
@@ -240,6 +223,7 @@ class CheckError(Exception):
         self.check_res = check_res
         self.func_signature = func_signature
 
+
 IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2
@@ -252,8 +236,8 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
         errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, "
                     f"please delete it.)")
     if check_res.missing:
-        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(output.keys())}"
-                    f"(from {prev_func_signature}) and {list(batch_y.keys())}(from targets in Dataset).")
+        errs.append(f"\tmissing param: `{check_res.missing}`, provided with `{list(output.keys())}`"
+                    f"(from output of `{prev_func_signature}`) and `{list(batch_y.keys())}`(from targets in Dataset).")
     if check_res.duplicated:
         errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
                     f"{check_res.duplicated} or do not set {check_res.duplicated} as targets. ")
@@ -281,7 +265,7 @@ def _check_forward_error(forward_func, batch_x, check_level):
     if check_res.varargs:
         errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
     if check_res.missing:
-        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(batch_x.keys())}.")
+        errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}.")
     if check_res.unused:
         _unused = [f"\tunused param: {check_res.unused}"]
         if check_level == STRICT_CHECK_LEVEL:
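A standalone illustration of the kind of signature bookkeeping _check_arg_dict_list performs and that CheckRes/CheckError report back to the user; the evaluate signature and the provided dict are made up:

    import inspect

    def evaluate(predictions, targets, seq_lens=None):
        return {}

    spect = inspect.getfullargspec(evaluate)
    defaults_count = len(spect.defaults) if spect.defaults else 0
    required = set(spect.args[:len(spect.args) - defaults_count])  # {'predictions', 'targets'}
    provided = {'predictions': [0, 1], 'seq_lens': [2]}
    print(sorted(required - set(provided)))  # ['targets'] -> reported as CheckRes.missing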