
Fix a bug in the ClassifyFPreRecMetric code

tags/v1.0.0alpha
yh 3 years ago
parent
commit
4d3a93964f
3 changed files with 35 additions and 22 deletions
  1. fastNLP/core/metrics/classify_f1_pre_rec_metric.py  (+7 / -9)
  2. tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py  (+27 / -12)
  3. tests/core/metrics/test_span_f1_rec_acc_torch.py  (+1 / -1)

fastNLP/core/metrics/classify_f1_pre_rec_metric.py  (+7 / -9)

@@ -5,6 +5,7 @@ __all__ = [
 from typing import Union, List
 from collections import Counter
 import warnings
+import numpy as np

 from .metric import Metric
 from .backend import Backend
@@ -132,10 +133,10 @@ class ClassifyFPreRecMetric(Metric):
             seq_len = self.tensor2numpy(seq_len)

         if seq_len is not None and target.ndim > 1:
-            max_len = target.ndim[-1]
+            max_len = target.shape[-1]
             masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
         else:
-            masks = None
+            masks = np.ones_like(target)

         if pred.ndim == target.ndim:
             if len(pred.flatten()) != len(target.flatten()):
@@ -143,7 +144,6 @@ class ClassifyFPreRecMetric(Metric):
                                    f" while target have element numbers:{len(pred.flatten())}, "
                                    f"pred have element numbers: {len(target.flatten())}")

-            pass
         elif pred.ndim == target.ndim + 1:
             pred = pred.argmax(axis=-1)
             if seq_len is None and target.ndim > 1:
@@ -152,11 +152,9 @@ class ClassifyFPreRecMetric(Metric):
             raise RuntimeError(f"when pred have "
                                f"size:{pred.shape}, target should have size: {pred.shape} or "
                                f"{pred.shape[:-1]}, got {target.shape}.")
-        if masks is not None:
-            target = target * masks
-            pred = pred * masks

         target_idxes = set(target.reshape(-1).tolist())
         for target_idx in target_idxes:
-            self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item()
-            self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item()
-            self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item()
+            self._tp[target_idx] += ((pred == target_idx) * (target == target_idx) * masks).sum().item()
+            self._fp[target_idx] += ((pred == target_idx) * (target != target_idx) * masks).sum().item()
+            self._fn[target_idx] += ((pred != target_idx) * (target == target_idx) * masks).sum().item()
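
The substance of this change: the old code compared against target with the polarity inverted when counting true/false positives and false negatives, and it applied the sequence mask by multiplying pred and target directly, which turns padded positions into class 0 instead of excluding them. A minimal standalone numpy sketch of the corrected counting (illustration only, not the fastNLP API; masked_class_counts is a hypothetical helper):

    import numpy as np

    def masked_class_counts(pred, target, masks):
        # Per-class TP/FP/FN; `masks` is 1 for real tokens and 0 for padding,
        # so padded positions simply do not contribute to any count.
        tp, fp, fn = {}, {}, {}
        for idx in set(target.reshape(-1).tolist()):
            tp[idx] = int(((pred == idx) * (target == idx) * masks).sum())
            fp[idx] = int(((pred == idx) * (target != idx) * masks).sum())
            fn[idx] = int(((pred != idx) * (target == idx) * masks).sum())
        return tp, fp, fn

    pred   = np.array([[1, 0, 2], [0, 1, 0]])
    target = np.array([[1, 0, 2], [0, 1, 2]])
    masks  = np.array([[1, 1, 1], [1, 1, 0]])  # second sequence has length 2
    print(masked_class_counts(pred, target, masks))  # the mismatch at the padded position is ignored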

tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py  (+27 / -12)

@@ -31,7 +31,7 @@ def _test(local_rank: int, world_size: int, device: "torch.device",

     my_result = metric.get_metric()
     for keys in ['f', 'pre', 'rec']:
-        np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
+        assert np.allclose(my_result[keys], metric_result[keys], atol=0.000001)


 @pytest.mark.torch
@@ -69,7 +69,6 @@ class TestClassfiyFPreRecMetric:
                              [-0.8088, -0.6648, -0.5018, -0.0230, -0.8207],
                              [-0.7753, -0.3508, 1.6163, 0.7158, 1.5207],
                              [0.8692, 0.7718, -0.6734, 0.6515, 0.0641]])
-        arg_max_pred = torch.argmax(pred, dim=-1)
         target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3,
                                0, 3, 0, 0, 0, 1, 3, 1])

@@ -79,10 +78,9 @@ class TestClassfiyFPreRecMetric:
         f1_score = 0.1882051282051282
         recall = 0.1619047619047619
         pre = 0.23928571428571427
-
         ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
         for keys in ['f', 'pre', 'rec']:
-            np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
+            assert np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)

         metric = ClassifyFPreRecMetric(f_type='micro')
         metric.update(pred, target)
@@ -93,7 +91,7 @@ class TestClassfiyFPreRecMetric:

         ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
         for keys in ['f', 'pre', 'rec']:
-            np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
+            assert np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)

         metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro')
         metric.update(pred, target)
@@ -103,19 +101,35 @@ class TestClassfiyFPreRecMetric:
             '1': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 5},
             '2': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 2},
             '3': {'f1-score': 0.30769230769230765, 'precision': 0.5, 'recall': 0.2222222222222222, 'support': 9},
-            '4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9},
-            'macro avg': {'f1-score': 0.1882051282051282, 'precision': 0.23928571428571427,
-                          'recall': 0.1619047619047619, 'support': 32},
-            'micro avg': {'f1-score': 0.21875, 'precision': 0.21875, 'recall': 0.21875, 'support': 32},
-            'weighted avg': {'f1-score': 0.2563301282051282, 'precision': 0.3286830357142857, 'recall': 0.21875,
-                             'support': 32}}
+            '4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9}}
         for keys in result_dict.keys():
             if keys == "f" or "pre" or "rec":
                 continue
             gl = str(keys[-1])
             tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"}
             gk = tmp_d[keys[0]]
-            np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)
+            assert np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)
+
+    def test_seq_len(self):
+        pred = torch.tensor([[[0.3, 0.7, 0.1], [0.4, 0.1, 0.1], [0.3, 0.1, 0.7]],
+                             [[0.7, 0.1, 0.1], [0.5, 0.9, 0.1], [0.3, 0.1, 0.7]]])
+        seq_len = torch.LongTensor([3, 2])
+        target = torch.LongTensor([[1, 0, 2], [0, 1, 0]])
+
+        # without taking seq_len into account
+        metric = ClassifyFPreRecMetric(only_gross=True, f_type='macro')
+        metric.update(pred, target)
+        result_dict = metric.get_metric()
+        for keys in ['f', 'pre', 'rec']:
+            assert result_dict[keys] != 1
+
+        # taking seq_len into account
+        metric = ClassifyFPreRecMetric(only_gross=True, f_type='macro')
+        metric.update(pred, target, seq_len=seq_len)
+        result_dict = metric.get_metric()
+        for keys in ['f', 'pre', 'rec']:
+            assert result_dict[keys] == 1


     @pytest.mark.parametrize("f_type, f1_score,recall,pre",
                              [('macro', 0.1882051282051282, 0.1619047619047619, 0.23928571428571427),
@@ -180,3 +194,4 @@ class TestClassfiyFPreRecMetric:
                              [(rank, NUM_PROCESSES, torch.device(f'cuda:{rank}')) for rank in range(NUM_PROCESSES)])
         pool.close()
         pool.join()
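
The new test_seq_len case checks that padded positions no longer pollute the counts: with seq_len = [3, 2] the last position of the second sequence is masked out, and only then do precision, recall, and F reach 1. A rough numpy re-implementation of the mask the metric builds via seq_len_to_mask (a sketch assuming the usual boolean-mask semantics, not the fastNLP function itself):

    import numpy as np

    def seq_len_to_mask_np(seq_len, max_len):
        # 1 for real positions, 0 for padding; the metric multiplies its
        # per-class comparisons by this mask.
        return (np.arange(max_len)[None, :] < np.asarray(seq_len)[:, None]).astype(int)

    print(seq_len_to_mask_np([3, 2], max_len=3))
    # [[1 1 1]
    #  [1 1 0]]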


tests/core/metrics/test_span_f1_rec_acc_torch.py  (+1 / -1)

@@ -226,7 +226,7 @@ class TestSpanFPreRecMetric:
         # print(expected_metric)
         metric_value = metric.get_metric()
         for key, value in expected_metric.items():
-            np.allclose(value, metric_value[key])
+            assert np.allclose(value, metric_value[key])

     def test_auto_encoding_type_infer(self):
         # check whether the encoding type can be inferred automatically
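
The one-line test changes in both test files matter more than they look: np.allclose only returns a bool, so calling it without assert can never fail a test. A quick illustration (plain numpy, nothing fastNLP-specific):

    import numpy as np

    np.allclose(0.5, 0.1)          # returns False, but pytest still reports the test as passed
    assert np.allclose(0.5, 0.5)   # only an assert turns a mismatch into a test failure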

