diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 64ad8e23..c3459964 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -70,6 +70,12 @@ class LossBase(object): raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " f"positional argument.).") + def _fast_param_map(self, pred_dict, target_dict): + if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: + return pred_dict.values[0], target_dict.values[0] + return None + + def __call__(self, pred_dict, target_dict, check=False): """ :param pred_dict: A dict from forward function of the network. @@ -77,6 +83,11 @@ class LossBase(object): :param check: Boolean. Force to check the mapping functions when it is running. :return: """ + fast_param = self._fast_param_map(pred_dict, target_dict) + if fast_param is not None: + loss = self.get_loss(*fast_param) + return loss + args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss) if varargs is not None: raise RuntimeError( @@ -132,7 +143,6 @@ class LossBase(object): param_map_val = _map_args(reversed_param_map, **param_val_dict) param_value = _build_args(self.get_loss, **param_map_val) - loss = self.get_loss(**param_value) if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index b1fc110b..6216b16d 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -71,7 +71,7 @@ class MetricBase(object): def get_metric(self, reset=True): raise NotImplemented - def _fast_call_evaluate(self, pred_dict, target_dict): + def _fast_param_map(self, pred_dict, target_dict): """ Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. @@ -80,7 +80,9 @@ class MetricBase(object): :param target_dict: :return: boolean, whether to go on codes in self.__call__(). When False, don't go on. """ - return False + if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: + return pred_dict.values[0] and target_dict.values[0] + return None def __call__(self, pred_dict, target_dict, check=False): """ @@ -103,7 +105,9 @@ class MetricBase(object): raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") if not check: - if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict): + fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict) + if fast_param is not None: + self.evaluate(*fast_param) return if not self._checked: