From 0cab4dc526b8cbcb980d1d0a595f273417b34a8f Mon Sep 17 00:00:00 2001
From: MorningForest <2297662686@qq.com>
Date: Sun, 10 Apr 2022 23:10:35 +0800
Subject: [PATCH] Modify F1RecPreMetric test cases
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../metrics/backend/torch_backend/backend.py  |   1 +
 fastNLP/core/metrics/element.py               |   1 +
 .../core/metrics/span_f1_pre_rec_metric.py    |  34 +++--
 .../dataloaders/torch_dataloader/test_fdl.py  |  38 ++---
 tests/core/metrics/test_accuracy_torch.py     |   1 -
 tests/core/metrics/test_f1_rec_acc_torch.py   | 134 ++++++++++--------
 6 files changed, 118 insertions(+), 91 deletions(-)

diff --git a/fastNLP/core/metrics/backend/torch_backend/backend.py b/fastNLP/core/metrics/backend/torch_backend/backend.py
index 06304a98..f1db0151 100644
--- a/fastNLP/core/metrics/backend/torch_backend/backend.py
+++ b/fastNLP/core/metrics/backend/torch_backend/backend.py
@@ -34,6 +34,7 @@ class TorchBackend(Backend):
             if method is None:
                 raise AggregateMethodError(should_have_aggregate_method=True)
             tensor = self._gather_all(tensor)
+            # tensor = self.all_gather_object(tensor)
             if isinstance(tensor[0], torch.Tensor):
                 tensor = torch.stack(tensor)
             # Step 1: aggregate the results
diff --git a/fastNLP/core/metrics/element.py b/fastNLP/core/metrics/element.py
index b3a496bf..483e9a49 100644
--- a/fastNLP/core/metrics/element.py
+++ b/fastNLP/core/metrics/element.py
@@ -34,6 +34,7 @@ class Element:
         Automatically aggregate the corresponding element.

         """
+        self._check_value_initialized()
         try:
             self._value = self.backend.aggregate(self._value, self.aggregate_method)
         except AggregateMethodError as e:
diff --git a/fastNLP/core/metrics/span_f1_pre_rec_metric.py b/fastNLP/core/metrics/span_f1_pre_rec_metric.py
index 45b412c8..b92c24dc 100644
--- a/fastNLP/core/metrics/span_f1_pre_rec_metric.py
+++ b/fastNLP/core/metrics/span_f1_pre_rec_metric.py
@@ -216,9 +216,9 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp):
 
 
 class SpanFPreRecMetric(Metric):
-    def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None,
-                 encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro',
-                 beta=1, aggregate_when_get_metric: bool = True,) -> None:
+    def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None,
+                 only_gross: bool = True, f_type='micro',
+                 beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None:
         super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
         if f_type not in ('micro', 'macro'):
             raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
@@ -249,9 +249,18 @@ class SpanFPreRecMetric(Metric):
         self.only_gross = only_gross
         self.tag_vocab = tag_vocab
 
-        self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
-        self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
-        self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
+        self._true_positives = {}
+        self._false_positives = {}
+        self._false_negatives = {}
+        for word, _ in tag_vocab:
+            word = word.lower()
+            if word != 'o':
+                word = word.split('-')[1]
+            if word in self._true_positives:
+                continue
+            self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend)
+            self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend)
+            self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend)
 
     def get_metric(self) -> dict:
         evaluate_result = {}
@@ -284,10 +293,17 @@ class SpanFPreRecMetric(Metric):
                     evaluate_result['rec'] = rec_sum / len(tags)
 
         if self.f_type == 'micro':
+            tp, fn, fp = [], [], []
+            for val in self._true_positives.values():
+                tp.append(val.get_scalar())
+            for val in self._false_negatives.values():
+                fn.append(val.get_scalar())
+            for val in self._false_positives.values():
+                fp.append(val.get_scalar())
             f, pre, rec = _compute_f_pre_rec(self.beta_square,
-                                             sum(val.get_scalar() for val in self._true_positives.values()),
-                                             sum(val.get_scalar() for val in self._false_negatives.values()),
-                                             sum(val.get_scalar() for val in self._false_positives.values()))
+                                             sum(tp),
+                                             sum(fn),
+                                             sum(fp))
             evaluate_result['f'] = f
             evaluate_result['pre'] = pre
             evaluate_result['rec'] = rec
diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py
index 0cd17ddd..2b1dd8a9 100644
--- a/tests/core/dataloaders/torch_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py
@@ -1,6 +1,6 @@
 import unittest
 
-from fastNLP.core.dataloaders.torch_dataloader import FDataLoader, prepare_dataloader
+from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader
 from fastNLP.core.dataset import DataSet
 from fastNLP.io.data_bundle import DataBundle
 
@@ -9,17 +9,17 @@ class TestFdl(unittest.TestCase):
 
     def test_init_v1(self):
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
-        fdl = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True)
+        fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True)
         # for batch in fdl:
         #     print(batch)
-        fdl1 = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True)
+        fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True)
         # for batch in fdl1:
         #     print(batch)
 
     def test_set_padding(self):
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
         ds.set_pad_val("x", val=-1)
-        fdl = FDataLoader(ds, batch_size=3)
+        fdl = TorchDataLoader(ds, batch_size=3)
         fdl.set_input("x", "y")
         for batch in fdl:
             print(batch)
@@ -36,7 +36,7 @@ class TestFdl(unittest.TestCase):
             _dict["Y"].append(ins['y'])
             return _dict
 
-        fdl = FDataLoader(ds, batch_size=3, as_numpy=True)
+        fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True)
         fdl.set_input("x", "y")
         fdl.add_collator(collate_fn)
         for batch in fdl:
@@ -44,7 +44,7 @@ class TestFdl(unittest.TestCase):
 
     def test_get_batch_indices(self):
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
-        fdl = FDataLoader(ds, batch_size=3, shuffle=True)
+        fdl = TorchDataLoader(ds, batch_size=3, shuffle=True)
         fdl.set_input("y", "x")
         for batch in fdl:
             print(fdl.get_batch_indices())
@@ -67,30 +67,30 @@ class TestFdl(unittest.TestCase):
             return object.__getattribute__(self, item)
 
         dataset = _DataSet()
-        dl = FDataLoader(dataset, batch_size=2, shuffle=True)
+        dl = TorchDataLoader(dataset, batch_size=2, shuffle=True)
         # dl.set_inputs('data', 'labels')
         # dl.set_pad_val('labels', val=None)
         for batch in dl:
             print(batch)
             print(dl.get_batch_indices())
 
-    def test_prepare_dataloader(self):
+    def test_prepare_torch_dataloader(self):
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
-        dl = prepare_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
-        assert isinstance(dl, FDataLoader)
+        dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
+        assert isinstance(dl, TorchDataLoader)
 
         ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
         dbl = DataBundle(datasets={'train': ds, 'val': ds1})
-        dl_bundle = prepare_dataloader(dbl)
-        assert isinstance(dl_bundle['train'], FDataLoader)
-        assert isinstance(dl_bundle['val'], FDataLoader)
+        dl_bundle = prepare_torch_dataloader(dbl)
+        assert isinstance(dl_bundle['train'], TorchDataLoader)
+        assert isinstance(dl_bundle['val'], TorchDataLoader)
 
         ds_dict = {'train_1': ds, 'val': ds1}
-        dl_dict = prepare_dataloader(ds_dict)
-        assert isinstance(dl_dict['train_1'], FDataLoader)
-        assert isinstance(dl_dict['val'], FDataLoader)
+        dl_dict = prepare_torch_dataloader(ds_dict)
+        assert isinstance(dl_dict['train_1'], TorchDataLoader)
+        assert isinstance(dl_dict['val'], TorchDataLoader)
 
         sequence = [ds, ds1]
-        seq_ds = prepare_dataloader(sequence)
-        assert isinstance(seq_ds[0], FDataLoader)
-        assert isinstance(seq_ds[1], FDataLoader)
+        seq_ds = prepare_torch_dataloader(sequence)
+        assert isinstance(seq_ds[0], TorchDataLoader)
+        assert isinstance(seq_ds[1], TorchDataLoader)
diff --git a/tests/core/metrics/test_accuracy_torch.py b/tests/core/metrics/test_accuracy_torch.py
index 33fc791a..b62200db 100644
--- a/tests/core/metrics/test_accuracy_torch.py
+++ b/tests/core/metrics/test_accuracy_torch.py
@@ -118,7 +118,6 @@ class TestAccuracy:
     def test_v1(self, is_ddp: bool, dataset: DataSet, metric_class: Type['Metric'],
                 metric_kwargs: Dict[str, Any]) -> None:
         global pool
-        print(pool)
         if is_ddp:
             if sys.platform == "win32":
                 pytest.skip("DDP not supported on windows")
diff --git a/tests/core/metrics/test_f1_rec_acc_torch.py b/tests/core/metrics/test_f1_rec_acc_torch.py
index 121f9530..34067080 100644
--- a/tests/core/metrics/test_f1_rec_acc_torch.py
+++ b/tests/core/metrics/test_f1_rec_acc_torch.py
@@ -14,6 +14,7 @@ from torch.multiprocessing import Pool, set_start_method
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.core.metrics import SpanFPreRecMetric
 from fastNLP.core.dataset import DataSet
+
 set_start_method("spawn", force=True)
 
 
@@ -45,7 +46,6 @@ def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(master_port)
 
-    print(torch.cuda.device_count())
     if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
         torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
 
@@ -64,15 +64,15 @@ def find_free_network_port() -> int:
     return port
 
 
-@pytest.fixture(scope='class', autouse=True)
-def pre_process():
-    global pool
-    pool = Pool(processes=NUM_PROCESSES)
-    master_port = find_free_network_port()
-    pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
-    yield
-    pool.close()
-    pool.join()
+# @pytest.fixture(scope='class', autouse=True)
+# def pre_process():
+#     global pool
+#     pool = Pool(processes=NUM_PROCESSES)
+#     master_port = find_free_network_port()
+#     pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
+#     yield
+#     pool.close()
+#     pool.join()
 
 
 def _test(local_rank: int,
@@ -87,18 +87,19 @@ def _test(local_rank: int,
     # the dataset is handled similarly (each process has its own copy)
     dataset = copy.deepcopy(dataset)
     metric.to(device)
-    print(os.environ.get("MASTER_PORT", "xx"))
 
     # Split the data across the GPUs, somewhat like DistributedSampler, except that the unit here is a batch (i.e. each i takes one batch onto its own GPU)
     for i in range(local_rank, len(dataset), world_size):
         pred, tg, seq_len = dataset[i]['pred'].to(device), dataset[i]['tg'].to(device), dataset[i]['seq_len']
+        print(tg, seq_len)
         metric.update(pred, tg, seq_len)
     my_result = metric.get_metric()
+    print(my_result)
+    print(sklearn_metric)
     assert my_result == sklearn_metric
 
 
 class SpanFPreRecMetricTest(unittest.TestCase):
-    global pool
 
     def test_case1(self):
         from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans
@@ -147,26 +148,26 @@ class SpanFPreRecMetricTest(unittest.TestCase):
                                               -1.3508, -0.9513],
                                              [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
                                               -0.0842, -0.4294]],
-
-            [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
-              -1.4138, -0.8853],
-             [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
-              -1.0726, 0.0364],
-             [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
-              -0.8836, -0.9320],
-             [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
-              -1.6857, 1.1571],
-             [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
-              -0.5837, 1.0184],
-             [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
-              -0.9025, 0.0864]]])
-        bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4],
-                                       [4, 1, 7, 0, 4, 7]])
+            [
+                [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
+                  -1.4138, -0.8853],
+                 [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
+                  -1.0726, 0.0364],
+                 [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
+                  -0.8836, -0.9320],
+                 [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
+                  -1.6857, 1.1571],
+                 [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
+                  -0.5837, 1.0184],
+                 [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
+                  -0.9025, 0.0864]]
+            ]
+        ])
+        bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]])
         fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6])
         expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5,
                           'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0,
                           'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2}
-
         assert expect_bio_res == fastnlp_bio_metric.get_metric()
         # print(fastnlp_bio_metric.get_metric())
 
@@ -325,44 +326,52 @@ class SpanFPreRecMetricTest(unittest.TestCase):
             metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso')
 
     def test_case5(self):
-        global pool
-        # pool = Pool(NUM_PROCESSES)
-        # master_port = find_free_network_port()
-        # pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
+        # global pool
+        pool = Pool(NUM_PROCESSES)
+        master_port = find_free_network_port()
+        pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
         number_labels = 4
         # bio tag
         fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
        fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
         # fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
-        dataset = DataSet({'pred': [torch.FloatTensor(
-            [[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
-               -0.3782, 0.8240],
-              [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
-               -0.3562, -1.4116],
-              [1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
-               2.0023, 0.7075],
-              [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186,
-               0.3832, -0.1540],
-              [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
-               -1.3508, -0.9513],
-              [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
-               -0.0842, -0.4294]],
-
-             [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
-               -1.4138, -0.8853],
-              [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
-               -1.0726, 0.0364],
-              [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
-               -0.8836, -0.9320],
-              [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
-               -1.6857, 1.1571],
-              [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
-               -0.5837, 1.0184],
-              [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
-               -0.9025, 0.0864]]])] * 100,
-                           'tg': [torch.LongTensor([[3, 6, 0, 8, 2, 4],
-                                                    [4, 1, 7, 0, 4, 7]])] * 100,
-                           'seq_len': [[6, 6]] * 100})
+        dataset = DataSet({'pred': [
+            torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
+                                 -0.3782, 0.8240],
+                                [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
+                                 -0.3562, -1.4116],
+                                [1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
+                                 2.0023, 0.7075],
+                                [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186,
+                                 0.3832, -0.1540],
+                                [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
+                                 -1.3508, -0.9513],
+                                [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
+                                 -0.0842, -0.4294]]
+
+            ]),
+            torch.FloatTensor([
+                [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
+                  -1.4138, -0.8853],
+                 [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
+                  -1.0726, 0.0364],
+                 [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
+                  -0.8836, -0.9320],
+                 [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
+                  -1.6857, 1.1571],
+                 [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
+                  -0.5837, 1.0184],
+                 [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
+                  -0.9025, 0.0864]]
+            ])
+        ],
+            'tg': [
+                torch.LongTensor([[3, 6, 0, 8, 2, 4]]),
+                torch.LongTensor([[4, 1, 7, 0, 4, 7]])
+            ],
+            'seq_len': [
+                [6], [6]
+            ]})
         metric_kwargs = {
             'tag_vocab': fastnlp_bio_vocab,
             'only_gross': False,
@@ -372,7 +381,6 @@ class SpanFPreRecMetricTest(unittest.TestCase):
                                   'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0,
                                   'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2}
         processes = NUM_PROCESSES
 
-        print(torch.cuda.device_count())
         pool.starmap(
             partial(
@@ -384,3 +392,5 @@ class SpanFPreRecMetricTest(unittest.TestCase):
             ),
             [(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)]
         )
+        pool.close()
+        pool.join()
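
A minimal usage sketch of the reordered SpanFPreRecMetric constructor that this patch exercises (tag_vocab is now the first positional argument, backend moved behind the metric-specific options). The tag names, tensor values, and shapes below are illustrative assumptions, not values taken from the tests.

    # Illustrative only: the PER/LOC tag set and random inputs are made up for this sketch.
    from collections import Counter

    import torch

    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.core.metrics import SpanFPreRecMetric

    vocab = Vocabulary(unknown=None, padding=None)
    # Same trick as the tests: populate word_count directly with BIO tags.
    vocab.word_count = Counter({'O': 10, 'B-PER': 5, 'I-PER': 5, 'B-LOC': 5, 'I-LOC': 5})
    num_tags = 5  # O, B-PER, I-PER, B-LOC, I-LOC

    metric = SpanFPreRecMetric(tag_vocab=vocab, only_gross=False)

    pred = torch.randn(2, 6, num_tags)          # (batch, seq_len, num_tags) logits
    target = torch.randint(num_tags, (2, 6))    # gold tag ids
    metric.update(pred, target, [6, 6])         # seq_len passed as a list, as in the tests

    # With only_gross=False, get_metric() returns overall 'f'/'pre'/'rec'
    # plus per-tag entries (e.g. 'f-per', 'rec-loc'), backed by the per-tag
    # tp/fn/fp elements the constructor now registers up front.
    print(metric.get_metric())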