|
|
@@ -8,8 +8,7 @@ from fastNLP.core.utils import CheckError |
|
|
|
from fastNLP.core.utils import CheckRes |
|
|
|
from fastNLP.core.utils import _build_args |
|
|
|
from fastNLP.core.utils import _check_function_or_method |
|
|
|
from fastNLP.core.utils import _get_arg_list |
|
|
|
from fastNLP.core.utils import _map_args |
|
|
|
from fastNLP.core.utils import _check_arg_dict_list |
|
|
|
from fastNLP.core.utils import get_func_signature |
|
|
|
|
|
|
|
|
|
|
@@ -62,8 +61,7 @@ class LossBase(object): |
|
|
|
if func_param not in func_args: |
|
|
|
raise NameError( |
|
|
|
f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the " |
|
|
|
f"initialization parameters, or change the signature of" |
|
|
|
f" {get_func_signature(self.get_loss)}.") |
|
|
|
f"initialization parameters, or change its signature.") |
|
|
|
|
|
|
|
# get_loss should not have varargs.
|
|
|
if func_spect.varargs: |
|
|
@@ -87,71 +85,68 @@ class LossBase(object): |
|
|
|
loss = self.get_loss(*fast_param) |
|
|
|
return loss |
|
|
|
|
|
|
|
args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss) |
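# args: positional argument names of get_loss; defaults/defaults_val: the args that have defaults and their values; varargs/kwargs: names of *args/**kwargs, or None when absent.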
|
|
|
if varargs is not None: |
|
|
|
raise RuntimeError( |
|
|
|
f"The function {get_func_signature(self.get_loss)} should not use Positional Argument." |
|
|
|
) |
|
|
|
|
|
|
|
param_map = self.param_map |
|
|
|
if args is None: |
|
|
|
raise RuntimeError( |
|
|
|
f"There is not any param in function{get_func_signature(self.get_loss)}" |
|
|
|
) |
|
|
|
|
|
|
|
self._checked = self._checked and not check |
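# passing check=True forces the full argument check below, even if it already succeeded once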
|
|
|
if not self._checked: |
|
|
|
for keys in args: |
|
|
|
if keys not in param_map: |
|
|
|
param_map.update({keys: keys}) |
|
|
|
if defaults is not None: |
|
|
|
for keys in defaults: |
|
|
|
if keys not in param_map: |
|
|
|
param_map.update({keys: keys}) |
|
|
|
self.param_map = param_map |
|
|
|
# param_map: key = name in the get_loss function, value = name in the param dict
|
|
|
reversed_param_map = {val: key for key, val in param_map.items()} |
|
|
|
# reversed_param_map: key = name in the param dict, value = name in the get_loss function
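# e.g. (hypothetical names) param_map = {'pred': 'output'} means the value stored under 'output' in the
# inputs is fed to get_loss as its `pred` argument; reversed_param_map then maps 'output' back to 'pred'.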
|
|
|
|
|
|
|
# 1. check consistency between the signature and param_map
|
|
|
func_spect = inspect.getfullargspec(self.get_loss) |
|
|
|
func_args = set([arg for arg in func_spect.args if arg != 'self']) |
|
|
|
for func_arg, input_arg in self.param_map.items(): |
|
|
|
if func_arg not in func_args: |
|
|
|
raise NameError(f"`{func_arg}` not in {get_func_signature(self.get_loss)}.") |
|
|
|
|
|
|
|
# 2. only some of the args may be listed in param_map; map the rest to themselves
|
|
|
for arg in func_args: |
|
|
|
if arg not in self.param_map: |
|
|
|
self.param_map[arg] = arg # This param does not need mapping. |
|
|
|
self._evaluate_args = func_args |
|
|
|
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} |
|
|
|
|
|
|
|
# need to wrap inputs in dict. |
|
|
|
mapped_pred_dict = {} |
|
|
|
mapped_target_dict = {} |
|
|
|
duplicated = [] |
|
|
|
missing = [] |
|
|
|
if not self._checked: |
|
|
|
for keys, val in pred_dict.items(): |
|
|
|
if keys in target_dict.keys(): |
|
|
|
duplicated.append(param_map[keys]) |
|
|
|
|
|
|
|
param_val_dict = {} |
|
|
|
for keys, val in pred_dict.items(): |
|
|
|
param_val_dict.update({keys: val}) |
|
|
|
for keys, val in target_dict.items(): |
|
|
|
param_val_dict.update({keys: val}) |
|
|
|
|
|
|
|
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): |
|
|
|
not_duplicate_flag = 0 |
|
|
|
if input_arg in self._reverse_param_map: |
|
|
|
mapped_arg = self._reverse_param_map[input_arg] |
|
|
|
not_duplicate_flag += 1 |
|
|
|
else: |
|
|
|
mapped_arg = input_arg |
|
|
|
if input_arg in pred_dict: |
|
|
|
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] |
|
|
|
not_duplicate_flag += 1 |
|
|
|
if input_arg in target_dict: |
|
|
|
mapped_target_dict[mapped_arg] = target_dict[input_arg] |
|
|
|
not_duplicate_flag += 1 |
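# a count of 3 means this key has an explicit mapping and appears in both pred_dict and target_dict,
# so the same get_loss argument would be filled twice; report it as duplicated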
|
|
|
if not_duplicate_flag == 3: |
|
|
|
duplicated.append(input_arg) |
|
|
|
|
|
|
|
# missing |
|
|
|
if not self._checked: |
|
|
|
for keys in args: |
|
|
|
if param_map[keys] not in param_val_dict.keys(): |
|
|
|
missing.append(param_map[keys]) |
|
|
|
|
|
|
|
if len(duplicated) > 0 or len(missing) > 0: |
|
|
|
raise CheckError( |
|
|
|
CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[], |
|
|
|
varargs=varargs), |
|
|
|
func_signature=get_func_signature(self.get_loss) |
|
|
|
) |
|
|
|
|
|
|
|
check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict]) |
|
|
|
# only check missing. |
|
|
|
missing = check_res.missing |
|
|
|
replaced_missing = list(missing) |
|
|
|
for idx, func_arg in enumerate(missing): |
|
|
|
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ |
|
|
|
f"in `{self.__class__.__name__}`)" |
|
|
|
|
|
|
|
check_res = CheckRes(missing=replaced_missing, |
|
|
|
unused=check_res.unused, |
|
|
|
duplicated=duplicated, |
|
|
|
required=check_res.required, |
|
|
|
all_needed=check_res.all_needed, |
|
|
|
varargs=check_res.varargs) |
|
|
|
|
|
|
|
if check_res.missing or check_res.duplicated or check_res.varargs: |
|
|
|
raise CheckError(check_res=check_res, |
|
|
|
func_signature=get_func_signature(self.get_loss)) |
|
|
|
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) |
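# _build_args keeps only the keyword arguments that get_loss actually accepts, so extra fields in the
# mapped dicts are dropped instead of raising a TypeError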
|
|
|
|
|
|
|
loss = self.get_loss(**refined_args) |
|
|
|
self._checked = True |
|
|
|
|
|
|
|
param_map_val = _map_args(reversed_param_map, **param_val_dict) |
|
|
|
param_value = _build_args(self.get_loss, **param_map_val) |
|
|
|
loss = self.get_loss(**param_value) |
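# whichever branch computed it, get_loss must return a scalar (0-dim) tensor; the check below rejects anything else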
|
|
|
|
|
|
|
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): |
|
|
|
if not isinstance(loss, torch.Tensor): |
|
|
|
raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}") |
|
|
|
raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size()}") |
|
|
|
|
|
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
class LossFunc(LossBase): |
|
|
|
def __init__(self, func, key_map=None, **kwargs): |
|
|
|
super(LossFunc, self).__init__() |
|
|
@@ -168,34 +163,42 @@ class LossFunc(LossBase): |
|
|
|
|
|
|
|
|
|
|
|
class CrossEntropyLoss(LossBase): |
|
|
|
def __init__(self, pred=None, target=None): |
|
|
|
def __init__(self, pred=None, target=None, padding_idx=-100): |
|
|
|
# TODO Some checks are still needed: when F.cross_entropy computes the loss, if pred has shape (16, 10, 4), target should in principle have shape (16, 10), but in fact it needs to be
|
|
|
# TODO (16, 4) |
|
|
|
super(CrossEntropyLoss, self).__init__() |
|
|
|
self.get_loss = F.cross_entropy |
|
|
|
self._init_param_map(input=pred, target=target) |
|
|
|
self._init_param_map(pred=pred, target=target) |
|
|
|
self.padding_idx = padding_idx |
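# e.g. (hypothetical field names) CrossEntropyLoss(pred='output', target='label') maps the model's 'output'
# and the dataset's 'label' fields onto this loss's pred/target arguments; target entries equal to
# padding_idx are ignored via ignore_index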
|
|
|
|
|
|
|
def get_loss(self, pred, target): |
|
|
|
return F.cross_entropy(input=pred, target=target, |
|
|
|
ignore_index=self.padding_idx) |
|
|
|
|
|
|
|
class L1Loss(LossBase): |
|
|
|
def __init__(self, pred=None, target=None): |
|
|
|
super(L1Loss, self).__init__() |
|
|
|
self.get_loss = F.l1_loss |
|
|
|
self._init_param_map(input=pred, target=target) |
|
|
|
|
|
|
|
def get_loss(self, pred, target): |
|
|
|
return F.l1_loss(input=pred, target=target) |
|
|
|
|
|
|
|
|
|
|
|
class BCELoss(LossBase): |
|
|
|
def __init__(self, pred=None, target=None): |
|
|
|
super(BCELoss, self).__init__() |
|
|
|
self.get_loss = F.binary_cross_entropy |
|
|
|
self._init_param_map(input=pred, target=target) |
|
|
|
|
|
|
|
def get_loss(self, pred, target): |
|
|
|
return F.binary_cross_entropy(input=pred, target=target) |
|
|
|
|
|
|
|
class NLLLoss(LossBase): |
|
|
|
def __init__(self, pred=None, target=None): |
|
|
|
super(NLLLoss, self).__init__() |
|
|
|
self.get_loss = F.nll_loss |
|
|
|
self._init_param_map(input=pred, target=target) |
|
|
|
|
|
|
|
def get_loss(self, pred, target): |
|
|
|
return F.nll_loss(input=pred, target=target) |
|
|
|
|
|
|
|
|
|
|
|
class LossInForward(LossBase): |
|
|
|
def __init__(self, loss_key='loss'): |
|
|
|