|
|
@@ -52,15 +52,16 @@ class MetricBase(object): |
|
|
|
value_counter[value].add(key) |
|
|
|
for value, key_set in value_counter.items(): |
|
|
|
if len(key_set)>1: |
|
|
|
raise ValueError(f"Several params:{key_set} are provided with one output {value}.") |
|
|
|
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") |
|
|
|
|
|
|
|
# check consistence between signature and param_map |
|
|
|
func_spect = inspect.getfullargspec(self.evaluate) |
|
|
|
func_args = func_spect.args |
|
|
|
for func_param, input_param in self.param_map.items(): |
|
|
|
if func_param not in func_args: |
|
|
|
raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}. Please check the " |
|
|
|
f"initialization params, or change {get_func_signature(self.evaluate)} signature.") |
|
|
|
raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " |
|
|
|
f"initialization parameters, or change the signature of" |
|
|
|
f" {get_func_signature(self.evaluate)}.") |
|
|
|
|
|
|
|
def get_metric(self, reset=True):
    """Return the metric value(s) accumulated by previous ``evaluate`` calls.

    Abstract hook: concrete metrics (e.g. ``AccuracyMetric``) override this
    to return a dict of results such as ``{'acc': float}``.

    :param reset: if True, the subclass should clear its accumulated
        statistics after producing the result.
    :return: dict mapping metric names to values (subclass responsibility).
    :raises NotImplementedError: always, in this base implementation.
    """
    # Bug fix: the original `raise NotImplemented` raises the NotImplemented
    # *sentinel value* (used for binary-op dispatch), which is not an
    # exception — at runtime it produces
    # "TypeError: exceptions must derive from BaseException"
    # instead of the intended NotImplementedError.
    raise NotImplementedError
|
|
@@ -134,19 +135,19 @@ class MetricBase(object): |
|
|
|
|
|
|
|
|
|
|
|
class AccuracyMetric(MetricBase): |
|
|
|
def __init__(self, input=None, target=None, masks=None, seq_lens=None): |
|
|
|
def __init__(self, pred=None, target=None, masks=None, seq_lens=None): |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
self._init_param_map(input=input, target=target, |
|
|
|
self._init_param_map(pred=pred, target=target, |
|
|
|
masks=masks, seq_lens=seq_lens) |
|
|
|
|
|
|
|
self.total = 0 |
|
|
|
self.acc_count = 0 |
|
|
|
|
|
|
|
def evaluate(self, input, target, masks=None, seq_lens=None): |
|
|
|
def evaluate(self, pred, target, masks=None, seq_lens=None): |
|
|
|
""" |
|
|
|
|
|
|
|
:param input: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: |
|
|
|
:param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: |
|
|
|
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) |
|
|
|
:param target: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
|
torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len]) |
|
|
@@ -156,41 +157,41 @@ class AccuracyMetric(MetricBase): |
|
|
|
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. |
|
|
|
:return: dict({'acc': float}) |
|
|
|
""" |
|
|
|
if not isinstance(input, torch.Tensor): |
|
|
|
raise NameError(f"`input` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
|
f"got {type(input)}.") |
|
|
|
if not isinstance(pred, torch.Tensor): |
|
|
|
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(pred)}.") |
|
|
|
if not isinstance(target, torch.Tensor): |
|
|
|
raise NameError(f"`target` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
|
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(target)}.") |
|
|
|
|
|
|
|
if masks is not None and not isinstance(masks, torch.Tensor): |
|
|
|
raise NameError(f"`masks` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
|
raise TypeError(f"`masks` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(masks)}.") |
|
|
|
elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor): |
|
|
|
raise NameError(f"`seq_lens` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
|
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(seq_lens)}.") |
|
|
|
|
|
|
|
if masks is None and seq_lens is not None: |
|
|
|
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) |
|
|
|
|
|
|
|
if input.size()==target.size(): |
|
|
|
if pred.size()==target.size(): |
|
|
|
pass |
|
|
|
elif len(input.size())==len(target.size())+1: |
|
|
|
input = input.argmax(dim=-1) |
|
|
|
elif len(pred.size())==len(target.size())+1: |
|
|
|
pred = pred.argmax(dim=-1) |
|
|
|
else: |
|
|
|
raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when input with " |
|
|
|
f"size:{input.size()}, target should with size: {input.size()} or " |
|
|
|
f"{input.size()[:-1]}, got {target.size()}.") |
|
|
|
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " |
|
|
|
f"size:{pred.size()}, target should have size: {pred.size()} or " |
|
|
|
f"{pred.size()[:-1]}, got {target.size()}.") |
|
|
|
|
|
|
|
input = input.float() |
|
|
|
pred = pred.float() |
|
|
|
target = target.float() |
|
|
|
|
|
|
|
if masks is not None: |
|
|
|
self.acc_count += torch.sum(torch.eq(input, target).float() * masks.float()).item() |
|
|
|
self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item() |
|
|
|
self.total += torch.sum(masks.float()).item() |
|
|
|
else: |
|
|
|
self.acc_count += torch.sum(torch.eq(input, target).float()).item() |
|
|
|
self.total += np.prod(list(input.size())) |
|
|
|
self.acc_count += torch.sum(torch.eq(pred, target).float()).item() |
|
|
|
self.total += np.prod(list(pred.size())) |
|
|
|
|
|
|
|
def get_metric(self, reset=True): |
|
|
|
evaluate_result = {'acc': self.acc_count/self.total} |
|
|
|