Browse Source

fix a bug in ClassifyFPRMetric

tags/v0.5.5
Yige Xu 5 years ago
parent
commit
be432c3b39
2 changed files with 9 additions and 8 deletions
  1. +6
    -5
      fastNLP/core/metrics.py
  2. +3
    -3
      test/core/test_metrics.py

+ 6
- 5
fastNLP/core/metrics.py View File

@@ -465,6 +465,7 @@ class ClassifyFPreRecMetric(MetricBase):
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
else:
masks = torch.ones_like(target).long().to(target.device)
masks = masks.eq(0)

if pred.dim() == target.dim():
pass
@@ -477,12 +478,12 @@ class ClassifyFPreRecMetric(MetricBase):
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")

target_list = target.tolist()
target_idxes = set(target.reshape(-1).tolist())
target = target.to(pred)
for target_num in target_list:
self._tp[target_num] += torch.sum((pred == target_num).long().masked_fill(target != target_num, 0).masked_fill(masks.eq(0), 0)).item()
self._fp[target_num] += torch.sum((pred != target_num).long().masked_fill(target != target_num, 0).masked_fill(masks.eq(0), 0)).item()
self._fn[target_num] += torch.sum((pred == target_num).long().masked_fill(target == target_num, 0).masked_fill(masks.eq(0), 0)).item()
for target_idx in target_idxes:
self._tp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks, 0)).item()
self._fp[target_idx] += torch.sum((pred != target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks, 0)).item()
self._fn[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0).masked_fill(masks, 0)).item()

def get_metric(self, reset=True):
"""


+ 3
- 3
test/core/test_metrics.py View File

@@ -498,7 +498,7 @@ class TestClassfiyFPreRecMetric(unittest.TestCase):
metric = ClassifyFPreRecMetric(f_type='micro')
metric.evaluate(pred, target)
result_dict = metric.get_metric(reset=True)
ground_truth = {'f': 0.85022, 'pre': 0.853982, 'rec': 0.846491}
ground_truth = {'f': 0.84375, 'pre': 0.84375, 'rec': 0.84375}
for keys in ['f', 'pre', 'rec']:
self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001)

@@ -507,8 +507,8 @@ class TestClassfiyFPreRecMetric(unittest.TestCase):
result_dict = metric.get_metric(reset=True)
ground_truth = {'f-0': 0.857143, 'pre-0': 0.75, 'rec-0': 1.0, 'f-1': 0.875, 'pre-1': 0.777778, 'rec-1': 1.0,
'f-2': 0.75, 'pre-2': 0.75, 'rec-2': 0.75, 'f-3': 0.857143, 'pre-3': 0.857143,
'rec-3': 0.857143, 'f-4': 0.842105, 'pre-4': 1.0, 'rec-4': 0.727273, 'f': 0.85022,
'pre': 0.853982, 'rec': 0.846491}
'rec-3': 0.857143, 'f-4': 0.842105, 'pre-4': 1.0, 'rec-4': 0.727273, 'f': 0.84375,
'pre': 0.84375, 'rec': 0.84375}
for keys in ground_truth.keys():
self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001)


Loading…
Cancel
Save