From be432c3b393195b99b2e846fb5758b18dc380a3c Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Wed, 23 Oct 2019 17:36:53 +0800 Subject: [PATCH] fix a bug in ClassifyFPreRecMetric --- fastNLP/core/metrics.py | 11 ++++++----- test/core/test_metrics.py | 6 +++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index e06c5650..f1f97b17 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -465,6 +465,7 @@ class ClassifyFPreRecMetric(MetricBase): masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) else: masks = torch.ones_like(target).long().to(target.device) + masks = masks.eq(0) if pred.dim() == target.dim(): pass @@ -477,12 +478,12 @@ class ClassifyFPreRecMetric(MetricBase): f"size:{pred.size()}, target should have size: {pred.size()} or " f"{pred.size()[:-1]}, got {target.size()}.") - target_list = target.tolist() + target_idxes = set(target.reshape(-1).tolist()) target = target.to(pred) - for target_num in target_list: - self._tp[target_num] += torch.sum((pred == target_num).long().masked_fill(target != target_num, 0).masked_fill(masks.eq(0), 0)).item() - self._fp[target_num] += torch.sum((pred != target_num).long().masked_fill(target != target_num, 0).masked_fill(masks.eq(0), 0)).item() - self._fn[target_num] += torch.sum((pred == target_num).long().masked_fill(target == target_num, 0).masked_fill(masks.eq(0), 0)).item() + for target_idx in target_idxes: + self._tp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks, 0)).item() + self._fp[target_idx] += torch.sum((pred != target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks, 0)).item() + self._fn[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0).masked_fill(masks, 0)).item() def get_metric(self, reset=True): """ diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index a11bd90b..d45eac79 100644 ---
a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -498,7 +498,7 @@ class TestClassfiyFPreRecMetric(unittest.TestCase): metric = ClassifyFPreRecMetric(f_type='micro') metric.evaluate(pred, target) result_dict = metric.get_metric(reset=True) - ground_truth = {'f': 0.85022, 'pre': 0.853982, 'rec': 0.846491} + ground_truth = {'f': 0.84375, 'pre': 0.84375, 'rec': 0.84375} for keys in ['f', 'pre', 'rec']: self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001) @@ -507,8 +507,8 @@ class TestClassfiyFPreRecMetric(unittest.TestCase): result_dict = metric.get_metric(reset=True) ground_truth = {'f-0': 0.857143, 'pre-0': 0.75, 'rec-0': 1.0, 'f-1': 0.875, 'pre-1': 0.777778, 'rec-1': 1.0, 'f-2': 0.75, 'pre-2': 0.75, 'rec-2': 0.75, 'f-3': 0.857143, 'pre-3': 0.857143, - 'rec-3': 0.857143, 'f-4': 0.842105, 'pre-4': 1.0, 'rec-4': 0.727273, 'f': 0.85022, - 'pre': 0.853982, 'rec': 0.846491} + 'rec-3': 0.857143, 'f-4': 0.842105, 'pre-4': 1.0, 'rec-4': 0.727273, 'f': 0.84375, + 'pre': 0.84375, 'rec': 0.84375} for keys in ground_truth.keys(): self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001)