@@ -1,20 +1,76 @@
-import torch
+import torch
+
+from fastNLP.core.utils import _get_arg_list
+from fastNLP.core.utils import _map_args
+from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _build_args
+
+
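+# LossBase is the common interface: subclasses implement get_loss(), and
+# __call__ below matches its arguments against the output/target dicts by name.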
-class LossBase(object):
-    def __init__(self):
-        self.param_map = {}
-
-    def get_loss(self, *args, **kwargs):
-        raise NotImplementedError
-
+class LossBase(object):
+    def __init__(self):
+        # key: name in target function; value: name in output function
+        self.param_map = {}
+
+    def get_loss(self, *args, **kwargs):
+        raise NotImplementedError
+
-    def __call__(self, output_dict, predict_dict):
-        pass
+    def __call__(self, output_dict, target_dict):
+        """
+        :param output_dict: A dict from the forward function of the network.
+        :param target_dict: A dict from DataSet.batch_y.
+        :return: a scalar loss tensor
+        """
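+        # get_loss parameters are filled by name, so variable positional
+        # arguments (*args) cannot be matched and are rejected up front.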
+        args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
+        if varargs is not None:
+            raise RuntimeError(
+                f"The function {get_func_signature(self.get_loss)} should not accept variable positional arguments (*args)."
+            )
+
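+        # Any get_loss parameter the user did not remap in param_map keeps
+        # its own name (identity mapping).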
+        param_map = self.param_map
+        for keys in args:
+            if keys not in param_map:
+                param_map.update({keys: keys})
+        for keys in defaults:
+            if keys not in param_map:
+                param_map.update({keys: keys})
+        # param map: key = name in get_loss function, value = name in param dict
+        reversed_param_map = {val: key for key, val in param_map.items()}
+        # reversed param map: key = name in param dict, value = name in get_loss function
+
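+        # Pool network outputs and targets into one dict; a key present in
+        # both is ambiguous, so it raises a conflict error instead of
+        # silently preferring one side.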
+        param_val_dict = {}
+        for keys, val in output_dict.items():
+            if keys not in target_dict.keys():
+                param_val_dict.update({keys: val})
+            else:
+                raise RuntimeError("name conflict: {} is in both the output dict and the target dict".format(keys))
+        for keys, val in target_dict.items():
+            if keys not in output_dict.keys():
+                param_val_dict.update({keys: val})
+            else:
+                raise RuntimeError("name conflict: {} is in both the output dict and the target dict".format(keys))
+
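+        # Every parameter that get_loss requires must now be present in the
+        # pooled dict, under its (possibly remapped) name.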
+        for keys in args:
+            if param_map[keys] not in param_val_dict.keys():
+                raise RuntimeError("missing param {} in function {}".format(keys, get_func_signature(self.get_loss)))
+
-
-
-class Loss(LossBase):
-    def __init__(self):
-        pass
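+        # Rename the pooled values to get_loss's own parameter names, then
+        # build the final kwargs (assuming _build_args takes the target
+        # function as its first argument, matching its use elsewhere in fastNLP).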
+        param_map_val = _map_args(reversed_param_map, **param_val_dict)
+        param_value = _build_args(self.get_loss, **param_map_val)
+
+        loss = self.get_loss(**param_value)
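+
+        # get_loss must return a 0-dimensional (scalar) tensor, i.e.
+        # loss.size() == torch.Size([]); anything else is rejected below.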
+        if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
+            if not isinstance(loss, torch.Tensor):
+                raise RuntimeError("loss ERROR: expected a torch.Tensor but got {}".format(type(loss)))
+            raise RuntimeError("loss ERROR: expected len(loss.size()) == 0 but got {}".format(len(loss.size())))
+
+        return loss
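+
+
+# NewLoss wraps a plain callable as a loss class, so user-supplied functions
+# go through the same name-matching machinery as LossBase subclasses.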
+class NewLoss(LossBase):
+    def __init__(self, func, key_map=None, **kwargs):
+        super(NewLoss, self).__init__()
+        if not callable(func):
+            raise RuntimeError("`func` must be callable, but got {}".format(type(func)))
+
+
-
-
-def squash(predict, truth, **kwargs):
+def squash(predict, truth, **kwargs):