Browse Source

add _fast_param_map

tags/v0.2.0^2
yunfan 6 years ago
parent
commit
131e1ccd3b
2 changed files with 18 additions and 4 deletions
  1. +11
    -1
      fastNLP/core/losses.py
  2. +7
    -3
      fastNLP/core/metrics.py

+ 11
- 1
fastNLP/core/losses.py View File

@@ -70,6 +70,12 @@ class LossBase(object):
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
f"positional argument.).")

def _fast_param_map(self, pred_dict, target_dict):
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
return pred_dict.values[0], target_dict.values[0]
return None


def __call__(self, pred_dict, target_dict, check=False):
"""
:param pred_dict: A dict from forward function of the network.
@@ -77,6 +83,11 @@ class LossBase(object):
:param check: Boolean. Force to check the mapping functions when it is running.
:return:
"""
fast_param = self._fast_param_map(pred_dict, target_dict)
if fast_param is not None:
loss = self.get_loss(*fast_param)
return loss

args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
if varargs is not None:
raise RuntimeError(
@@ -132,7 +143,6 @@ class LossBase(object):

param_map_val = _map_args(reversed_param_map, **param_val_dict)
param_value = _build_args(self.get_loss, **param_map_val)

loss = self.get_loss(**param_value)

if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):


+ 7
- 3
fastNLP/core/metrics.py View File

@@ -71,7 +71,7 @@ class MetricBase(object):
def get_metric(self, reset=True):
raise NotImplemented

def _fast_call_evaluate(self, pred_dict, target_dict):
def _fast_param_map(self, pred_dict, target_dict):
"""

Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
@@ -80,7 +80,9 @@ class MetricBase(object):
:param target_dict:
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
"""
return False
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
return pred_dict.values[0] and target_dict.values[0]
return None

def __call__(self, pred_dict, target_dict, check=False):
"""
@@ -103,7 +105,9 @@ class MetricBase(object):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

if not check:
if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict):
fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
if fast_param is not None:
self.evaluate(*fast_param)
return

if not self._checked:


Loading…
Cancel
Save