@@ -465,6 +465,7 @@ class ClassifyFPreRecMetric(MetricBase): | |||||
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) | masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) | ||||
else: | else: | ||||
masks = torch.ones_like(target).long().to(target.device) | masks = torch.ones_like(target).long().to(target.device) | ||||
masks = masks.eq(0) | |||||
if pred.dim() == target.dim(): | if pred.dim() == target.dim(): | ||||
pass | pass | ||||
@@ -477,12 +478,12 @@ class ClassifyFPreRecMetric(MetricBase): | |||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | f"size:{pred.size()}, target should have size: {pred.size()} or " | ||||
f"{pred.size()[:-1]}, got {target.size()}.") | f"{pred.size()[:-1]}, got {target.size()}.") | ||||
target_list = target.tolist() | |||||
target_idxes = set(target.reshape(-1).tolist()) | |||||
target = target.to(pred) | target = target.to(pred) | ||||
for target_num in target_list: | |||||
self._tp[target_num] += torch.sum((pred == target_num).long().masked_fill(target != target_num, 0).masked_fill(masks.eq(0), 0)).item() | |||||
self._fp[target_num] += torch.sum((pred != target_num).long().masked_fill(target != target_num, 0).masked_fill(masks.eq(0), 0)).item() | |||||
self._fn[target_num] += torch.sum((pred == target_num).long().masked_fill(target == target_num, 0).masked_fill(masks.eq(0), 0)).item() | |||||
for target_idx in target_idxes: | |||||
self._tp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks, 0)).item() | |||||
self._fp[target_idx] += torch.sum((pred != target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks, 0)).item() | |||||
self._fn[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0).masked_fill(masks, 0)).item() | |||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
""" | """ | ||||
@@ -498,7 +498,7 @@ class TestClassfiyFPreRecMetric(unittest.TestCase): | |||||
metric = ClassifyFPreRecMetric(f_type='micro') | metric = ClassifyFPreRecMetric(f_type='micro') | ||||
metric.evaluate(pred, target) | metric.evaluate(pred, target) | ||||
result_dict = metric.get_metric(reset=True) | result_dict = metric.get_metric(reset=True) | ||||
ground_truth = {'f': 0.85022, 'pre': 0.853982, 'rec': 0.846491} | |||||
ground_truth = {'f': 0.84375, 'pre': 0.84375, 'rec': 0.84375} | |||||
for keys in ['f', 'pre', 'rec']: | for keys in ['f', 'pre', 'rec']: | ||||
self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001) | self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001) | ||||
@@ -507,8 +507,8 @@ class TestClassfiyFPreRecMetric(unittest.TestCase): | |||||
result_dict = metric.get_metric(reset=True) | result_dict = metric.get_metric(reset=True) | ||||
ground_truth = {'f-0': 0.857143, 'pre-0': 0.75, 'rec-0': 1.0, 'f-1': 0.875, 'pre-1': 0.777778, 'rec-1': 1.0, | ground_truth = {'f-0': 0.857143, 'pre-0': 0.75, 'rec-0': 1.0, 'f-1': 0.875, 'pre-1': 0.777778, 'rec-1': 1.0, | ||||
'f-2': 0.75, 'pre-2': 0.75, 'rec-2': 0.75, 'f-3': 0.857143, 'pre-3': 0.857143, | 'f-2': 0.75, 'pre-2': 0.75, 'rec-2': 0.75, 'f-3': 0.857143, 'pre-3': 0.857143, | ||||
'rec-3': 0.857143, 'f-4': 0.842105, 'pre-4': 1.0, 'rec-4': 0.727273, 'f': 0.85022, | |||||
'pre': 0.853982, 'rec': 0.846491} | |||||
'rec-3': 0.857143, 'f-4': 0.842105, 'pre-4': 1.0, 'rec-4': 0.727273, 'f': 0.84375, | |||||
'pre': 0.84375, 'rec': 0.84375} | |||||
for keys in ground_truth.keys(): | for keys in ground_truth.keys(): | ||||
self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001) | self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001) | ||||