|
|
|
@@ -54,14 +54,32 @@ class MetricBase(object): |
|
|
|
if len(key_set)>1: |
|
|
|
raise ValueError(f"Several params:{key_set} are provided with one output {value}.") |
|
|
|
|
|
|
|
# check consistence between signature and param_map |
|
|
|
func_spect = inspect.getfullargspec(self.evaluate) |
|
|
|
func_args = func_spect.args |
|
|
|
for func_param, input_param in self.param_map.items(): |
|
|
|
if func_param not in func_args: |
|
|
|
raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}. Please check the " |
|
|
|
f"initialization params, or change {get_func_signature(self.evaluate)} signature.") |
|
|
|
|
|
|
|
def get_metric(self, reset=True): |
|
|
|
raise NotImplemented |
|
|
|
|
|
|
|
def __call__(self, output_dict, target_dict, check=False): |
|
|
|
""" |
|
|
|
:param output_dict: |
|
|
|
:param target_dict: |
|
|
|
:param check: boolean, |
|
|
|
|
|
|
|
This method will call self.evaluate method. |
|
|
|
Before calling self.evaluate, it will first check the validity ofoutput_dict, target_dict |
|
|
|
(1) whether self.evaluate has varargs, which is not supported. |
|
|
|
(2) whether params needed by self.evaluate is not included in output_dict,target_dict. |
|
|
|
(3) whether params needed by self.evaluate duplicate in output_dict, target_dict |
|
|
|
(4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) |
|
|
|
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and |
|
|
|
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering |
|
|
|
will be conducted) |
|
|
|
:param output_dict: usually the output of forward or prediction function |
|
|
|
:param target_dict: usually features set as target.. |
|
|
|
:param check: boolean, if check is True, it will force check `varargs, missing, unsed, duplicated`. |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if not callable(self.evaluate): |
|
|
|
@@ -73,7 +91,7 @@ class MetricBase(object): |
|
|
|
func_args = func_spect.args |
|
|
|
for func_param, input_param in self.param_map.items(): |
|
|
|
if func_param not in func_args: |
|
|
|
raise NameError(f"{func_param} not in {get_func_signature(self.evaluate)}.") |
|
|
|
raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}.") |
|
|
|
# 2. only part of the param_map are passed, left are not |
|
|
|
for arg in func_args: |
|
|
|
if arg not in self.param_map: |
|
|
|
@@ -97,8 +115,9 @@ class MetricBase(object): |
|
|
|
|
|
|
|
# check duplicated, unused, missing |
|
|
|
if check or not self._checked: |
|
|
|
check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict]) |
|
|
|
for key, value in check_res.items(): |
|
|
|
check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict]) |
|
|
|
for key in check_res._fields: |
|
|
|
value = getattr(check_res, key) |
|
|
|
new_value = list(value) |
|
|
|
for idx, func_param in enumerate(value): |
|
|
|
if func_param in self._reverse_param_map: |
|
|
|
@@ -115,21 +134,21 @@ class MetricBase(object): |
|
|
|
|
|
|
|
|
|
|
|
class AccuracyMetric(MetricBase): |
|
|
|
def __init__(self, input=None, targets=None, masks=None, seq_lens=None): |
|
|
|
def __init__(self, input=None, target=None, masks=None, seq_lens=None): |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
self._init_param_map(input=input, targets=targets, |
|
|
|
self._init_param_map(input=input, target=target, |
|
|
|
masks=masks, seq_lens=seq_lens) |
|
|
|
|
|
|
|
self.total = 0 |
|
|
|
self.acc_count = 0 |
|
|
|
|
|
|
|
def evaluate(self, input, targets, masks=None, seq_lens=None): |
|
|
|
def evaluate(self, input, target, masks=None, seq_lens=None): |
|
|
|
""" |
|
|
|
|
|
|
|
:param input: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: |
|
|
|
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) |
|
|
|
:param targets: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
|
:param target: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
|
torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len]) |
|
|
|
:param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
|
None, None, torch.Size([B, max_len], torch.Size([B, max_len]) |
|
|
|
@@ -140,9 +159,9 @@ class AccuracyMetric(MetricBase): |
|
|
|
if not isinstance(input, torch.Tensor): |
|
|
|
raise NameError(f"`input` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
|
f"got {type(input)}.") |
|
|
|
if not isinstance(targets, torch.Tensor): |
|
|
|
raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
|
f"got {type(targets)}.") |
|
|
|
if not isinstance(target, torch.Tensor): |
|
|
|
raise NameError(f"`target` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
|
f"got {type(target)}.") |
|
|
|
|
|
|
|
if masks is not None and not isinstance(masks, torch.Tensor): |
|
|
|
raise NameError(f"`masks` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
|
@@ -154,20 +173,23 @@ class AccuracyMetric(MetricBase): |
|
|
|
if masks is None and seq_lens is not None: |
|
|
|
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) |
|
|
|
|
|
|
|
if input.size()==targets.size(): |
|
|
|
if input.size()==target.size(): |
|
|
|
pass |
|
|
|
elif len(input.size())==len(targets.size())+1: |
|
|
|
elif len(input.size())==len(target.size())+1: |
|
|
|
input = input.argmax(dim=-1) |
|
|
|
else: |
|
|
|
raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when input with " |
|
|
|
f"size:{input.size()}, targets should with size: {input.size()} or " |
|
|
|
f"{input.size()[:-1]}, got {targets.size()}.") |
|
|
|
f"size:{input.size()}, target should with size: {input.size()} or " |
|
|
|
f"{input.size()[:-1]}, got {target.size()}.") |
|
|
|
|
|
|
|
input = input.float() |
|
|
|
target = target.float() |
|
|
|
|
|
|
|
if masks is not None: |
|
|
|
self.acc_count += torch.sum(torch.eq(input, targets).float() * masks.float()).item() |
|
|
|
self.acc_count += torch.sum(torch.eq(input, target).float() * masks.float()).item() |
|
|
|
self.total += torch.sum(masks.float()).item() |
|
|
|
else: |
|
|
|
self.acc_count += torch.sum(torch.eq(input, targets).float()).item() |
|
|
|
self.acc_count += torch.sum(torch.eq(input, target).float()).item() |
|
|
|
self.total += np.prod(list(input.size())) |
|
|
|
|
|
|
|
def get_metric(self, reset=True): |
|
|
|
|