diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 58847c31..3bbbf9e2 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -8,8 +8,7 @@ from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import CheckRes
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_function_or_method
-from fastNLP.core.utils import _get_arg_list
-from fastNLP.core.utils import _map_args
+from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import get_func_signature
 
 
@@ -62,8 +61,7 @@ class LossBase(object):
             if func_param not in func_args:
                 raise NameError(
                     f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the "
-                    f"initialization parameters, or change the signature of"
-                    f" {get_func_signature(self.get_loss)}.")
+                    f"initialization parameters, or change its signature.")
 
         # evaluate should not have varargs.
         if func_spect.varargs:
@@ -87,71 +85,68 @@ class LossBase(object):
             loss = self.get_loss(*fast_param)
             return loss
 
-        args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
-        if varargs is not None:
-            raise RuntimeError(
-                f"The function {get_func_signature(self.get_loss)} should not use Positional Argument."
-            )
-
-        param_map = self.param_map
-        if args is None:
-            raise RuntimeError(
-                f"There is not any param in function{get_func_signature(self.get_loss)}"
-            )
-
-        self._checked = self._checked and not check
         if not self._checked:
-            for keys in args:
-                if keys not in param_map:
-                    param_map.update({keys: keys})
-            if defaults is not None:
-                for keys in defaults:
-                    if keys not in param_map:
-                        param_map.update({keys: keys})
-            self.param_map = param_map
-        # param map: key= name in get_loss function, value= name in param dict
-        reversed_param_map = {val: key for key, val in param_map.items()}
-        # reversed param map: key= name in param dict, value= name in get_loss function
-
+            # 1. check consistency between signature and param_map
+            func_spect = inspect.getfullargspec(self.get_loss)
+            func_args = set([arg for arg in func_spect.args if arg != 'self'])
+            for func_arg, input_arg in self.param_map.items():
+                if func_arg not in func_args:
+                    raise NameError(f"`{func_arg}` not in {get_func_signature(self.get_loss)}.")
+
+            # 2. only part of the param_map may be passed; map the remaining args to themselves
+            for arg in func_args:
+                if arg not in self.param_map:
+                    self.param_map[arg] = arg  # This param does not need mapping.
+            self._evaluate_args = func_args
+            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
+
+        # need to wrap inputs in dict.
+        mapped_pred_dict = {}
+        mapped_target_dict = {}
         duplicated = []
-        missing = []
-        if not self._checked:
-            for keys, val in pred_dict.items():
-                if keys in target_dict.keys():
-                    duplicated.append(param_map[keys])
-
-        param_val_dict = {}
-        for keys, val in pred_dict.items():
-            param_val_dict.update({keys: val})
-        for keys, val in target_dict.items():
-            param_val_dict.update({keys: val})
-
+        for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
+            not_duplicate_flag = 0
+            if input_arg in self._reverse_param_map:
+                mapped_arg = self._reverse_param_map[input_arg]
+                not_duplicate_flag += 1
+            else:
+                mapped_arg = input_arg
+            if input_arg in pred_dict:
+                mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
+                not_duplicate_flag += 1
+            if input_arg in target_dict:
+                mapped_target_dict[mapped_arg] = target_dict[input_arg]
+                not_duplicate_flag += 1
+            if not_duplicate_flag == 3:
+                duplicated.append(input_arg)
+
+        # missing
         if not self._checked:
-            for keys in args:
-                if param_map[keys] not in param_val_dict.keys():
-                    missing.append(param_map[keys])
-
-        if len(duplicated) > 0 or len(missing) > 0:
-            raise CheckError(
-                CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[],
-                         varargs=varargs),
-                func_signature=get_func_signature(self.get_loss)
-            )
-
+            check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict])
+            # only check missing.
+            missing = check_res.missing
+            replaced_missing = list(missing)
+            for idx, func_arg in enumerate(missing):
+                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
+                    f"in `{self.__class__.__name__}`)"
+
+            check_res = CheckRes(missing=replaced_missing,
+                                 unused=check_res.unused,
+                                 duplicated=duplicated,
+                                 required=check_res.required,
+                                 all_needed=check_res.all_needed,
+                                 varargs=check_res.varargs)
+
+            if check_res.missing or check_res.duplicated or check_res.varargs:
+                raise CheckError(check_res=check_res,
+                                 func_signature=get_func_signature(self.get_loss))
+        refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)
+
+        loss = self.get_loss(**refined_args)
         self._checked = True
-        param_map_val = _map_args(reversed_param_map, **param_val_dict)
-        param_value = _build_args(self.get_loss, **param_map_val)
-        loss = self.get_loss(**param_value)
-
-        if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
-            if not isinstance(loss, torch.Tensor):
-                raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}")
-            raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size()}")
-
         return loss
 
-
 class LossFunc(LossBase):
     def __init__(self, func, key_map=None, **kwargs):
         super(LossFunc, self).__init__()
@@ -168,34 +163,42 @@ class LossFunc(LossBase):
 
 
 class CrossEntropyLoss(LossBase):
-    def __init__(self, pred=None, target=None):
+    def __init__(self, pred=None, target=None, padding_idx=-100):
         # TODO some checks are still needed: when F.cross_entropy is computed with a pred of shape (16, 10, 4),
         # TODO target should by rights have shape (16, 10), but it actually has to be (16, 4)
         super(CrossEntropyLoss, self).__init__()
-        self.get_loss = F.cross_entropy
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)
+        self.padding_idx = padding_idx
 
+    def get_loss(self, pred, target):
+        return F.cross_entropy(input=pred, target=target,
+                               ignore_index=self.padding_idx)
 
 
 class L1Loss(LossBase):
     def __init__(self, pred=None, target=None):
         super(L1Loss, self).__init__()
-        self.get_loss = F.l1_loss
         self._init_param_map(input=pred, target=target)
 
+    def get_loss(self, pred, target):
+        return F.l1_loss(input=pred, target=target)
+
 
 class BCELoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(BCELoss, self).__init__()
-        self.get_loss = F.binary_cross_entropy
         self._init_param_map(input=pred, target=target)
 
+    def get_loss(self, pred, target):
+        return F.binary_cross_entropy(input=pred, target=target)
 
 
 class NLLLoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(NLLLoss, self).__init__()
-        self.get_loss = F.nll_loss
         self._init_param_map(input=pred, target=target)
 
+    def get_loss(self, pred, target):
+        return F.nll_loss(input=pred, target=target)
+
 
 class LossInForward(LossBase):
     def __init__(self, loss_key='loss'):
diff --git a/test/core/test_loss.py b/test/core/test_loss.py
index 270b4d3b..22f11234 100644
--- a/test/core/test_loss.py
+++ b/test/core/test_loss.py
@@ -322,7 +322,7 @@ class TestLosserError(unittest.TestCase):
     def test_losser3(self):
         # (2) with corrupted size
         pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param':0}
-        target_dict = {'target': torch.zeros(16, 3).long()}
+        target_dict = {'target': torch.zeros(16).long()}
         los = loss.CrossEntropyLoss()
 
         print(los(pred_dict=pred_dict, target_dict=target_dict))
diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py
index 1b578eae..e74ec4b5 100644
--- a/test/core/test_trainer.py
+++ b/test/core/test_trainer.py
@@ -8,7 +8,7 @@ from fastNLP.core.utils import CheckError
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss
-from fastNLP.core.losses import LossInForward
+from fastNLP.core.losses import CrossEntropyLoss
 from fastNLP.core.metrics import AccuracyMetric
 from fastNLP.core.optimizer import SGD
 from fastNLP.core.trainer import Trainer
@@ -222,7 +222,7 @@ class TrainerTestGround(unittest.TestCase):
                 x1 = self.fc(x1)
                 x2 = self.fc(x2)
                 x = x1 + x2
-                loss = F.cross_entropy(x, y)
+                # loss = F.cross_entropy(x, y)
                 return {'pred': x}
 
         model = Model()
@@ -231,10 +231,10 @@ class TrainerTestGround(unittest.TestCase):
                           train_data=dataset,
                           model=model,
                           dev_data=dataset,
+                          losser=CrossEntropyLoss(),
                           metrics=AccuracyMetric(),
                           use_tqdm=False,
-                          print_every=2
-                          )
+                          print_every=2)
 
     def test_case2(self):
         # check metrics Wrong
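Usage note (not part of the patch): a minimal sketch of how the reworked CrossEntropyLoss might be exercised after this change, assuming fastNLP and PyTorch are importable. The dict keys 'output' and 'label' are hypothetical example names, not names taken from the diff; the extra 'stop_fast_param' entry mirrors the trick used in test_losser3 above to keep the call on the generic param-mapping path.

    import torch
    from fastNLP.core.losses import CrossEntropyLoss

    # Map the model's 'output' field to get_loss's `pred` argument and the
    # dataset's 'label' field to `target`; padding_idx is forwarded to
    # F.cross_entropy as ignore_index (default -100, i.e. nothing ignored here).
    loss_fn = CrossEntropyLoss(pred='output', target='label')

    pred_dict = {'output': torch.randn(16, 4),    # logits from the model's forward()
                 'stop_fast_param': 0}            # hypothetical extra key, as in the test
    target_dict = {'label': torch.randint(0, 4, (16,))}  # class indices

    # __call__ maps the input keys onto get_loss's arguments and returns a scalar tensor.
    loss = loss_fn(pred_dict=pred_dict, target_dict=target_dict)
    print(loss)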