import torch
import torch.nn.functional as F

from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _get_arg_list
from fastNLP.core.utils import _map_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_function_or_method


class LossBase(object):
    def __init__(self):
        # key: name of an argument of self.get_loss; value: name in the output/target dict
        self.param_map = {}
        self._checked = False

    def get_loss(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, output_dict, target_dict, force_check=False):
        """
        :param output_dict: A dict returned by the forward function of the network.
        :param target_dict: A dict from DataSet.batch_y.
        :param force_check: Boolean. Re-run the argument-mapping checks even if they have already passed once.
        :return: a scalar loss, as a 0-dim torch.Tensor
        """
        args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
        # *args parameters cannot be mapped by name, so reject them
        if varargs is not None:
            raise RuntimeError(
                f"The function {get_func_signature(self.get_loss)} should not use *args"
            )

        param_map = self.param_map
        if args is None:
            raise RuntimeError(
                f"There is no parameter in function {get_func_signature(self.get_loss)}"
            )

        self._checked = self._checked and not force_check
        if not self._checked:
            # map every get_loss argument to the same name in the param dict by default
            for keys in args:
                if keys not in param_map:
                    param_map.update({keys: keys})
            if defaults is not None:
                for keys in defaults:
                    if keys not in param_map:
                        param_map.update({keys: keys})
            self.param_map = param_map
        # param_map: key = name in the get_loss function, value = name in the param dict
        reversed_param_map = {val: key for key, val in param_map.items()}
        # reversed_param_map: key = name in the param dict, value = name in the get_loss function

        duplicated = []
        missing = []
        if not self._checked:
            # a key present in both dicts is ambiguous for the mapping
            for keys, val in output_dict.items():
                if keys in target_dict.keys():
                    duplicated.append(keys)

        # merge the two dicts into one name -> value mapping
        param_val_dict = {}
        for keys, val in output_dict.items():
            param_val_dict.update({keys: val})
        for keys, val in target_dict.items():
            param_val_dict.update({keys: val})

        if not self._checked:
            # every required get_loss argument must have a mapped value
            for keys in args:
                if param_map[keys] not in param_val_dict.keys():
                    missing.append(keys)

        if len(duplicated) > 0 or len(missing) > 0:
            raise CheckError(
                CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[]),
                func_signature=get_func_signature(self.get_loss)
            )

        self._checked = True

        # rename dict keys to get_loss argument names, then keep only the
        # arguments that get_loss actually accepts
        param_map_val = _map_args(reversed_param_map, **param_val_dict)
        param_value = _build_args(self.get_loss, **param_map_val)

        loss = self.get_loss(**param_value)

        if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
            if not isinstance(loss, torch.Tensor):
                raise RuntimeError(f"loss ERROR: expected a torch.Tensor but got {type(loss)}")
            raise RuntimeError(f"loss ERROR: expected the size of loss to be torch.Size([]) but got {loss.size()}")

        return loss

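
# A minimal usage sketch (not part of the original module): a LossBase
# subclass binds a scalar-returning function to `self.get_loss`, and
# `__call__` matches that function's argument names against the keys of
# `output_dict` / `target_dict`. The tensors and dict keys below are
# illustrative assumptions; `input` and `target` are F.mse_loss's own
# argument names.
def _example_loss_base():
    class MSELoss(LossBase):
        def __init__(self):
            super(MSELoss, self).__init__()
            self.get_loss = F.mse_loss

    loss_fn = MSELoss()
    loss = loss_fn(output_dict={"input": torch.randn(4, 3)},
                   target_dict={"target": torch.randn(4, 3)})
    return loss  # a 0-dim tensor
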
class NewLoss(LossBase):
    def __init__(self, func, key_map=None, **kwargs):
        super(NewLoss, self).__init__()
        _check_function_or_method(func)
        if key_map is not None:
            if not isinstance(key_map, dict):
                raise RuntimeError(f"Loss error: key_map expected a dict but got a {type(key_map)}")
            self.param_map = key_map
        if len(kwargs) > 0:
            # extra keyword arguments extend (and may override) the key map
            for key, val in kwargs.items():
                self.param_map.update({key: val})

        self.get_loss = func

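
# A hedged sketch (not in the original file) of wrapping an arbitrary callable
# with NewLoss. `key_map` maps get_loss argument names (keys) to the names
# used in the output/target dicts (values); the names below are hypothetical.
def _example_new_loss():
    def my_loss(pred, truth):
        return F.mse_loss(pred, truth)

    loss_fn = NewLoss(my_loss, key_map={"pred": "output", "truth": "label"})
    return loss_fn(output_dict={"output": torch.randn(4, 3)},
                   target_dict={"label": torch.randn(4, 3)})
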
class L1Loss(LossBase):
    def __init__(self):
        super(L1Loss, self).__init__()
        self.get_loss = F.l1_loss


class BCELoss(LossBase):
    def __init__(self):
        super(BCELoss, self).__init__()
        self.get_loss = F.binary_cross_entropy


class NLLLoss(LossBase):
    def __init__(self):
        super(NLLLoss, self).__init__()
        self.get_loss = F.nll_loss

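
# A small usage sketch (assumed dict keys): with no key_map set, the dict keys
# must match the wrapped function's own argument names, e.g. F.nll_loss takes
# `input` (log-probabilities) and `target` (class indices).
def _example_nll_loss():
    loss_fn = NLLLoss()
    log_probs = F.log_softmax(torch.randn(4, 5), dim=1)
    return loss_fn(output_dict={"input": log_probs},
                   target_dict={"target": torch.tensor([0, 1, 2, 3])})
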
class LossInForward(LossBase): |
|
|
|
|