|
|
@@ -70,6 +70,12 @@ class LossBase(object): |
|
|
|
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " |
|
|
|
f"positional argument.).") |
|
|
|
|
|
|
|
def _fast_param_map(self, pred_dict, target_dict): |
|
|
|
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: |
|
|
|
return pred_dict.values[0], target_dict.values[0] |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, pred_dict, target_dict, check=False): |
|
|
|
""" |
|
|
|
:param pred_dict: A dict from forward function of the network. |
|
|
@@ -77,6 +83,11 @@ class LossBase(object): |
|
|
|
:param check: Boolean. Force to check the mapping functions when it is running. |
|
|
|
:return: |
|
|
|
""" |
|
|
|
fast_param = self._fast_param_map(pred_dict, target_dict) |
|
|
|
if fast_param is not None: |
|
|
|
loss = self.get_loss(*fast_param) |
|
|
|
return loss |
|
|
|
|
|
|
|
args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss) |
|
|
|
if varargs is not None: |
|
|
|
raise RuntimeError( |
|
|
@@ -132,7 +143,6 @@ class LossBase(object): |
|
|
|
|
|
|
|
param_map_val = _map_args(reversed_param_map, **param_val_dict) |
|
|
|
param_value = _build_args(self.get_loss, **param_map_val) |
|
|
|
|
|
|
|
loss = self.get_loss(**param_value) |
|
|
|
|
|
|
|
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): |
|
|
|