|
@@ -11,7 +11,7 @@ from fastNLP.core.utils import _build_args |
|
|
from fastNLP.core.utils import _check_arg_dict_list |
|
|
from fastNLP.core.utils import _check_arg_dict_list |
|
|
from fastNLP.core.utils import get_func_signature |
|
|
from fastNLP.core.utils import get_func_signature |
|
|
from fastNLP.core.utils import seq_lens_to_masks |
|
|
from fastNLP.core.utils import seq_lens_to_masks |
|
|
|
|
|
|
|
|
|
|
|
from fastNLP.core.utils import CheckRes |
|
|
|
|
|
|
|
|
class MetricBase(object): |
|
|
class MetricBase(object): |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
@@ -72,6 +72,17 @@ class MetricBase(object): |
|
|
def get_metric(self, reset=True): |
|
|
def get_metric(self, reset=True): |
|
|
raise NotImplemented |
|
|
raise NotImplemented |
|
|
|
|
|
|
|
|
|
|
|
def _fast_call_evaluate(self, pred_dict, target_dict): |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. |
|
|
|
|
|
such as pred_dict has one element, target_dict has one element |
|
|
|
|
|
:param pred_dict: |
|
|
|
|
|
:param target_dict: |
|
|
|
|
|
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on. |
|
|
|
|
|
""" |
|
|
|
|
|
return False |
|
|
|
|
|
|
|
|
def __call__(self, pred_dict, target_dict, check=False): |
|
|
def __call__(self, pred_dict, target_dict, check=False): |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
@@ -79,7 +90,7 @@ class MetricBase(object): |
|
|
Before calling self.evaluate, it will first check the validity ofoutput_dict, target_dict |
|
|
Before calling self.evaluate, it will first check the validity ofoutput_dict, target_dict |
|
|
(1) whether self.evaluate has varargs, which is not supported. |
|
|
(1) whether self.evaluate has varargs, which is not supported. |
|
|
(2) whether params needed by self.evaluate is not included in output_dict,target_dict. |
|
|
(2) whether params needed by self.evaluate is not included in output_dict,target_dict. |
|
|
(3) whether params needed by self.evaluate duplicate in output_dict, target_dict |
|
|
|
|
|
|
|
|
(3) whether params needed by self.evaluate duplicate in pred_dict, target_dict |
|
|
(4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) |
|
|
(4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) |
|
|
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and |
|
|
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and |
|
|
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering |
|
|
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering |
|
@@ -92,6 +103,10 @@ class MetricBase(object): |
|
|
if not callable(self.evaluate): |
|
|
if not callable(self.evaluate): |
|
|
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") |
|
|
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") |
|
|
|
|
|
|
|
|
|
|
|
if not check: |
|
|
|
|
|
if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict): |
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
if not self._checked: |
|
|
if not self._checked: |
|
|
# 1. check consistence between signature and param_map |
|
|
# 1. check consistence between signature and param_map |
|
|
func_spect = inspect.getfullargspec(self.evaluate) |
|
|
func_spect = inspect.getfullargspec(self.evaluate) |
|
@@ -110,28 +125,40 @@ class MetricBase(object): |
|
|
# need to wrap inputs in dict. |
|
|
# need to wrap inputs in dict. |
|
|
mapped_pred_dict = {} |
|
|
mapped_pred_dict = {} |
|
|
mapped_target_dict = {} |
|
|
mapped_target_dict = {} |
|
|
|
|
|
duplicated = [] |
|
|
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): |
|
|
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): |
|
|
|
|
|
not_duplicate_flag = 0 |
|
|
if input_arg in self._reverse_param_map: |
|
|
if input_arg in self._reverse_param_map: |
|
|
mapped_arg = self._reverse_param_map[input_arg] |
|
|
mapped_arg = self._reverse_param_map[input_arg] |
|
|
|
|
|
not_duplicate_flag += 1 |
|
|
else: |
|
|
else: |
|
|
mapped_arg = input_arg |
|
|
mapped_arg = input_arg |
|
|
if input_arg in pred_dict: |
|
|
if input_arg in pred_dict: |
|
|
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] |
|
|
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] |
|
|
|
|
|
not_duplicate_flag += 1 |
|
|
if input_arg in target_dict: |
|
|
if input_arg in target_dict: |
|
|
mapped_target_dict[mapped_arg] = target_dict[input_arg] |
|
|
mapped_target_dict[mapped_arg] = target_dict[input_arg] |
|
|
|
|
|
not_duplicate_flag += 1 |
|
|
|
|
|
if not_duplicate_flag == 3: |
|
|
|
|
|
duplicated.append(input_arg) |
|
|
|
|
|
|
|
|
# check duplicated, unused, missing |
|
|
|
|
|
|
|
|
# missing |
|
|
if check or not self._checked: |
|
|
if check or not self._checked: |
|
|
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) |
|
|
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) |
|
|
for key in check_res._fields: |
|
|
|
|
|
value = getattr(check_res, key) |
|
|
|
|
|
new_value = list(value) |
|
|
|
|
|
# TODO 这里报错的逻辑应该是怎样的? |
|
|
|
|
|
for idx, func_arg in enumerate(value): |
|
|
|
|
|
if func_arg in self.param_map: |
|
|
|
|
|
new_value[idx] = self.param_map[func_arg] + f'(try to get value from {self.param_map[func_arg]})' |
|
|
|
|
|
else: |
|
|
|
|
|
new_value[idx] = func_arg |
|
|
|
|
|
|
|
|
# only check missing. |
|
|
|
|
|
missing = check_res.missing |
|
|
|
|
|
replaced_missing = list(missing) |
|
|
|
|
|
for idx, func_arg in enumerate(missing): |
|
|
|
|
|
replaced_missing[idx] = f"`{self.param_map[func_arg]}`" + f"(assign to `{func_arg}` " \ |
|
|
|
|
|
f"in `{get_func_signature(self.evaluate)}`)" |
|
|
|
|
|
|
|
|
|
|
|
check_res = CheckRes(missing=replaced_missing, |
|
|
|
|
|
unused=check_res.unused, |
|
|
|
|
|
duplicated=duplicated, |
|
|
|
|
|
required=check_res.required, |
|
|
|
|
|
all_needed=check_res.all_needed, |
|
|
|
|
|
varargs=check_res.varargs) |
|
|
|
|
|
|
|
|
if check_res.missing or check_res.duplicated or check_res.varargs: |
|
|
if check_res.missing or check_res.duplicated or check_res.varargs: |
|
|
raise CheckError(check_res=check_res, |
|
|
raise CheckError(check_res=check_res, |
|
|
func_signature=get_func_signature(self.evaluate)) |
|
|
func_signature=get_func_signature(self.evaluate)) |
|
@@ -140,6 +167,7 @@ class MetricBase(object): |
|
|
self.evaluate(**refined_args) |
|
|
self.evaluate(**refined_args) |
|
|
self._checked = True |
|
|
self._checked = True |
|
|
|
|
|
|
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
class AccuracyMetric(MetricBase): |
|
|
class AccuracyMetric(MetricBase): |
|
|
def __init__(self, pred=None, target=None, masks=None, seq_lens=None): |
|
|
def __init__(self, pred=None, target=None, masks=None, seq_lens=None): |
|
@@ -151,6 +179,22 @@ class AccuracyMetric(MetricBase): |
|
|
self.total = 0 |
|
|
self.total = 0 |
|
|
self.acc_count = 0 |
|
|
self.acc_count = 0 |
|
|
|
|
|
|
|
|
|
|
|
def _fast_call_evaluate(self, pred_dict, target_dict): |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. |
|
|
|
|
|
such as pred_dict has one element, target_dict has one element |
|
|
|
|
|
:param pred_dict: |
|
|
|
|
|
:param target_dict: |
|
|
|
|
|
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on. |
|
|
|
|
|
""" |
|
|
|
|
|
if len(pred_dict)==1 and len(target_dict)==1: |
|
|
|
|
|
pred = list(pred_dict.values())[0] |
|
|
|
|
|
target = list(target_dict.values())[0] |
|
|
|
|
|
self.evaluate(pred=pred, target=target) |
|
|
|
|
|
return True |
|
|
|
|
|
return False |
|
|
|
|
|
|
|
|
def evaluate(self, pred, target, masks=None, seq_lens=None): |
|
|
def evaluate(self, pred, target, masks=None, seq_lens=None): |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
@@ -164,6 +208,7 @@ class AccuracyMetric(MetricBase): |
|
|
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. |
|
|
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. |
|
|
:return: dict({'acc': float}) |
|
|
:return: dict({'acc': float}) |
|
|
""" |
|
|
""" |
|
|
|
|
|
#TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value |
|
|
if not isinstance(pred, torch.Tensor): |
|
|
if not isinstance(pred, torch.Tensor): |
|
|
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
f"got {type(pred)}.") |
|
|
f"got {type(pred)}.") |
|
|