diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 1e5a4914..39ba4012 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -1,20 +1,76 @@ import torch +from fastNLP.core.utils import _get_arg_list +from fastNLP.core.utils import _map_args +from fastNLP.core.utils import get_func_signature +from fastNLP.core.utils import _build_args + class LossBase(object): def __init__(self): + # key: name in target function; value: name in output function self.param_map = {} def get_loss(self, *args, **kwargs): raise NotImplementedError - def __call__(self, output_dict, predict_dict): - pass + def __call__(self, output_dict, target_dict): + """ + :param output_dict: A dict from forward function of the network. + :param target_dict: A dict from DataSet.batch_y. + :return: + """ + args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss) + if varargs is not None: + raise RuntimeError( + f"The function {get_func_signature(self.get_loss)} should not use Positional Argument." + ) + + param_map = self.param_map + for keys in args: + if keys not in param_map: + param_map.update({keys: keys}) + for keys in defaults: + if keys not in param_map: + param_map.update({keys: keys}) + # param map: key= name in get_loss function, value= name in param dict + reversed_param_map = {val: key for key, val in param_map} + # reversed param map: key= name in param dict, value= name in get_loss function + + param_val_dict = {} + for keys, val in output_dict.items(): + if keys not in target_dict.keys(): + param_val_dict.update({keys: val}) + else: + raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys)) + for keys, val in target_dict.items(): + if keys not in output_dict.keys(): + param_val_dict.update({keys: val}) + else: + raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys)) + for keys in args: + if param_map[keys] not in param_val_dict.keys(): + raise RuntimeError("missing param {} in function {}".format(keys, self.get_loss)) -class Loss(LossBase): - def __init__(self): - pass + param_map_val = _map_args(reversed_param_map, **param_val_dict) + param_value = _build_args(**param_map_val) + + loss = self.get_loss(**param_value) + + if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): + if not isinstance(loss, torch.Tensor): + raise RuntimeError("loss ERROR: loss except a torch.Tensor but get {}".format(type(loss))) + raise RuntimeError("loss ERROR: len(loss.size()) except 0 but got {}".format(len(loss.size()))) + + return loss + + +class NewLoss(LossBase): + def __init__(self, func, key_map=None, **kwargs): + super(NewLoss).__init__() + if not callable(func): + raise RuntimeError("") def squash(predict, truth, **kwargs): diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 84faaece..13982e27 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -64,6 +64,39 @@ def _build_args(func, **kwargs): return output +def _map_args(maps: dict, **kwargs): + # maps: key=old name, value= new name + output = {} + for name, val in kwargs.items(): + if name in maps: + assert isinstance(maps[name], str) + output.update({maps[name]: val}) + else: + output.update({name: val}) + for keys in maps.keys(): + if keys not in output.keys(): + # TODO: add UNUSED warning. + pass + return output + + +def _get_arg_list(func): + assert callable(func) + spect = inspect.getfullargspec(func) + if spect.defaults is not None: + args = spect.args[: -len(spect.defaults)] + defaults = spect.args[-len(spect.defaults):] + defaults_val = spect.defaults + else: + args = spect.args + defaults = None + defaults_val = None + varargs = spect.varargs + kwargs = spect.varkw + return args, defaults, defaults_val, varargs, kwargs + + + # check args def _check_arg_dict_list(func, args): if isinstance(args, dict):