From fcfd8c267ed7f82eb4e1ee8826e8c7fe7567abbb Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 11:03:19 +0000 Subject: [PATCH 01/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9callback=5Fmanager?= =?UTF-8?q?=E7=9A=84=E6=8A=A5=E9=94=99=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback_manager.py | 9 ++++++--- tests/helpers/callbacks/helper_callbacks.py | 4 +--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index f63c6088..2b8fff60 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -11,6 +11,7 @@ from .callback import Callback from fastNLP.core.log import logger from .progress_callback import ProgressCallback, choose_progress_callback from fastNLP.envs import rank_zero_call +from fastNLP.core.utils.utils import _get_fun_msg def _transfer(func): @@ -21,10 +22,12 @@ def _transfer(func): def wrapper(manager, *arg, **kwargs): manager.callback_counter[func.__name__] += 1 # 给实际被调用的 callback_fn 的计数加 1; - returns = [] for callback_fn in manager.callback_fns[func.__name__]: - returns.append(callback_fn(*arg, **kwargs)) - return returns + try: + callback_fn(*arg, **kwargs) + except BaseException as e: + logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.") + raise e return wrapper diff --git a/tests/helpers/callbacks/helper_callbacks.py b/tests/helpers/callbacks/helper_callbacks.py index 4fd5b654..1e0d0e11 100644 --- a/tests/helpers/callbacks/helper_callbacks.py +++ b/tests/helpers/callbacks/helper_callbacks.py @@ -36,12 +36,10 @@ class RecordMetricCallback(Callback): self.larger_better = larger_better self.metric = None self.metric_threshold = metric_threshold - self.metric_begin_value = None + self.metric_begin_value = float('-inf') if larger_better else float('inf') def on_evaluate_end(self, trainer, results): self.metric = results[self.monitor] - if self.metric_begin_value is None: - self.metric_begin_value = self.metric def on_train_end(self, trainer): if self.larger_better: From 28b3f10a4670729ac1f1afb8ad7c74234b1df051 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 11:04:03 +0000 Subject: [PATCH 02/16] =?UTF-8?q?=E5=88=A0=E9=99=A4=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E5=90=8E=E7=9A=84barrier=EF=BC=8C=E9=81=BF=E5=85=8D=E6=8A=A5?= =?UTF-8?q?=E9=94=99=E5=90=8E=E5=8D=A1=E4=BD=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 5223c9d8..1e9f907c 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -363,7 +363,6 @@ class Trainer(TrainerEventTrigger): raise e finally: self.on_train_end() - self.driver.barrier() def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None: From 51a3e901f005333b04ec2b0aad9f5e2c2e9e0a0f Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 11:05:29 +0000 Subject: [PATCH 03/16] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E9=83=A8=E5=88=86?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/collators/padders/test_raw_padder.py | 3 +-- tests/core/utils/test_cache_results.py | 1 + tests/envs/test_set_backend.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/core/collators/padders/test_raw_padder.py b/tests/core/collators/padders/test_raw_padder.py index 9742bc9a..9cb38766 100644 --- a/tests/core/collators/padders/test_raw_padder.py +++ b/tests/core/collators/padders/test_raw_padder.py @@ -23,7 +23,6 @@ class TestRawSequencePadder: assert (a == b).sum().item() == shape[0]*shape[1] def test_dtype_check(self): - with pytest.raises(DtypeError): - padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) + padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) with pytest.raises(DtypeError): padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) \ No newline at end of file diff --git a/tests/core/utils/test_cache_results.py b/tests/core/utils/test_cache_results.py index 5657ae81..77c618bb 100644 --- a/tests/core/utils/test_cache_results.py +++ b/tests/core/utils/test_cache_results.py @@ -3,6 +3,7 @@ import pytest import subprocess from io import StringIO import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) from fastNLP.core.utils.cache_results import cache_results from fastNLP.core import rank_zero_rm diff --git a/tests/envs/test_set_backend.py b/tests/envs/test_set_backend.py index 395c854d..170110ce 100644 --- a/tests/envs/test_set_backend.py +++ b/tests/envs/test_set_backend.py @@ -1,4 +1,5 @@ import os +import pytest from fastNLP.envs.set_backend import dump_fastnlp_backend from tests.helpers.utils import Capturing @@ -9,7 +10,7 @@ def test_dump_fastnlp_envs(): filepath = None try: with Capturing() as output: - dump_fastnlp_backend() + dump_fastnlp_backend(backend="torch") filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') assert filepath in output[0] assert os.path.exists(filepath) From 96de608ef9dfca82bcb456b9c7a3439cb36c4227 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 11:21:14 +0000 Subject: [PATCH 04/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9Test=5Ftrainer=5Fother?= =?UTF-8?q?=5Fthings=E4=B8=ADTrainer.on=E7=9A=84=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/controllers/test_trainer_other_things.py | 8 ++++---- tests/core/controllers/test_trainer_wo_evaluator_torch.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/core/controllers/test_trainer_other_things.py b/tests/core/controllers/test_trainer_other_things.py index 9cdec2dd..b010058b 100644 --- a/tests/core/controllers/test_trainer_other_things.py +++ b/tests/core/controllers/test_trainer_other_things.py @@ -7,16 +7,16 @@ from tests.helpers.utils import magic_argv_env_context @magic_argv_env_context def test_trainer_torch_without_evaluator(): - @Trainer.on(Events.on_train_epoch_begin(every=10)) + @Trainer.on(Events.on_train_epoch_begin(every=10), marker="test_trainer_other_things") def fn1(trainer): pass - @Trainer.on(Events.on_train_batch_begin(every=10)) + @Trainer.on(Events.on_train_batch_begin(every=10), marker="test_trainer_other_things") def fn2(trainer, batch, indices): pass - with pytest.raises(AssertionError): - @Trainer.on(Events.on_train_batch_begin(every=10)) + with pytest.raises(BaseException): + @Trainer.on(Events.on_train_batch_begin(every=10), marker="test_trainer_other_things") def fn3(trainer, batch): pass diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 825bd425..aa86ef92 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -286,6 +286,7 @@ def test_trainer_on_exception( dist.destroy_process_group() +@pytest.mark.torch @pytest.mark.parametrize("version", [0, 1, 2, 3]) @magic_argv_env_context def test_torch_distributed_launch_1(version): From f422f34ccaa8cd92b02745ab2475e06332786d42 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 11:21:44 +0000 Subject: [PATCH 05/16] small --- tests/core/controllers/utils/test_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/core/controllers/utils/test_utils.py b/tests/core/controllers/utils/test_utils.py index 0cf7a252..9c9b763b 100644 --- a/tests/core/controllers/utils/test_utils.py +++ b/tests/core/controllers/utils/test_utils.py @@ -11,7 +11,7 @@ class Test_WrapDataLoader: for sanity_batches in all_sanity_batches: data = NormalIterator(num_of_data=1000) wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches) - dataloader = iter(wrapper(dataloader=data)) + dataloader = iter(wrapper) mark = 0 while True: try: @@ -32,8 +32,7 @@ class Test_WrapDataLoader: dataset = TorchNormalDataset(num_of_data=1000) dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) - dataloader = wrapper(dataloader) - dataloader = iter(dataloader) + dataloader = iter(wrapper) all_supposed_running_data_num = 0 while True: try: @@ -55,6 +54,5 @@ class Test_WrapDataLoader: dataset = TorchNormalDataset(num_of_data=1000) dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) - dataloader = wrapper(dataloader) - length.append(len(dataloader)) + length.append(len(wrapper)) assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) \ No newline at end of file From c1c8f102459d9b015f6c954e9fb30579e54c1410 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 17:13:45 +0000 Subject: [PATCH 06/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9padder=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=9A=84import=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../collators/padders/test_paddle_padder.py | 36 +++++++++---------- tests/core/collators/test_new_collator.py | 2 +- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/core/collators/padders/test_paddle_padder.py b/tests/core/collators/padders/test_paddle_padder.py index 80abf30a..bea10de0 100644 --- a/tests/core/collators/padders/test_paddle_padder.py +++ b/tests/core/collators/padders/test_paddle_padder.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder +from fastNLP.core.collators.padders.paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder from fastNLP.core.collators.padders.exceptions import DtypeError from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -10,9 +10,9 @@ if _NEED_IMPORT_PADDLE: @pytest.mark.paddle -class TestpaddleNumberPadder: +class TestPaddleNumberPadder: def test_run(self): - padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = PaddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) a = [1, 2, 3] t_a = padder(a) assert isinstance(t_a, paddle.Tensor) @@ -20,9 +20,9 @@ class TestpaddleNumberPadder: @pytest.mark.paddle -class TestpaddleSequencePadder: +class TestPaddleSequencePadder: def test_run(self): - padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) a = [[1, 2, 3], [3]] a = padder(a) shape = a.shape @@ -32,20 +32,20 @@ class TestpaddleSequencePadder: assert (a == b).sum().item() == shape[0]*shape[1] def test_dtype_check(self): - padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) with pytest.raises(DtypeError): - padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) - padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) - padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) a = padder([[1], [2, 322]]) # assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 - padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) @pytest.mark.paddle -class TestpaddleTensorPadder: +class TestPaddleTensorPadder: def test_run(self): - padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) a = [paddle.zeros((3,)), paddle.zeros((2,))] a = padder(a) shape = a.shape @@ -74,7 +74,7 @@ class TestpaddleTensorPadder: [[0, -1], [-1, -1], [-1, -1]]]) assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] - padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))] a = padder(a) shape = a.shape @@ -85,7 +85,7 @@ class TestpaddleTensorPadder: ]) assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] - padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)] a = padder(a) shape = a.shape @@ -96,11 +96,11 @@ class TestpaddleTensorPadder: assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] def test_dtype_check(self): - padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) with pytest.raises(DtypeError): - padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) - padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) - padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) def test_v1(self): print(paddle.zeros((3, )).dtype) diff --git a/tests/core/collators/test_new_collator.py b/tests/core/collators/test_new_collator.py index 87762c16..ba1e7e08 100644 --- a/tests/core/collators/test_new_collator.py +++ b/tests/core/collators/test_new_collator.py @@ -4,7 +4,7 @@ import pytest from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR -from fastNLP.core.collators.new_collator import Collator +from fastNLP.core.collators.collator import Collator def _assert_equal(d1, d2): From 5520f597eaedee089a40428fa2b480c7418a3903 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 07:13:58 +0000 Subject: [PATCH 07/16] =?UTF-8?q?pytest=E7=9A=84=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/pytest.ini | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 tests/pytest.ini diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000..d6a33a94 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +markers = + torch + paddle + jittor + torchpaddle \ No newline at end of file From ebfd0e966c7c96b057b5a8f562c5bf6477c73492 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 07:14:36 +0000 Subject: [PATCH 08/16] =?UTF-8?q?=E4=B8=BAtorch=20driver=E7=9A=84=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E4=BE=8B=E6=B7=BB=E5=8A=A0=E9=94=80=E6=AF=81=E9=80=9A?= =?UTF-8?q?=E4=BF=A1=E8=BF=9B=E7=A8=8B=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/drivers/torch_driver/test_ddp.py | 269 ++++++++++-------- .../test_initialize_torch_driver.py | 16 +- 2 files changed, 161 insertions(+), 124 deletions(-) diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index 48299bf4..11799515 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.utils import magic_argv_env_context from fastNLP.core import rank_zero_rm +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed as dist + from torch.utils.data import DataLoader, BatchSampler -import torch -import torch.distributed as dist -from torch.utils.data import DataLoader, BatchSampler - -def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): +def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"): torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) device = [torch.device(i) for i in device] @@ -73,107 +74,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed= ############################################################################ @pytest.mark.torch +@magic_argv_env_context +def test_multi_drivers(): + """ + 测试使用了多个 TorchDDPDriver 的情况。 + """ + generate_driver(10, 10) + generate_driver(20, 10) + + with pytest.raises(RuntimeError): + # 设备设置不同,应该报错 + generate_driver(20, 3, device=[0,1,2]) + assert False + dist.barrier() + + if dist.is_initialized(): + dist.destroy_process_group() + +@pytest.mark.torch +@pytest.mark.torchtemp class TestDDPDriverFunction: """ 测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 """ - @classmethod - def setup_class(cls): - cls.driver = generate_driver(10, 10) - @magic_argv_env_context - def test_multi_drivers(self): + def test_simple_functions(self): """ - 测试使用了多个 TorchDDPDriver 的情况。 + 简单测试多个函数 """ - - driver2 = generate_driver(20, 10) - - with pytest.raises(RuntimeError): - # 设备设置不同,应该报错 - driver3 = generate_driver(20, 3, device=[0,1,2]) - assert False - dist.barrier() + driver = generate_driver(10, 10) - @magic_argv_env_context - def test_move_data_to_device(self): """ - 这个函数仅调用了torch_move_data_to_device,测试例在tests/core/utils/test_torch_utils.py中 - 就不重复测试了 + 测试 move_data_to_device 函数。这个函数仅调用了 torch_move_data_to_device ,测试例在 + tests/core/utils/test_torch_utils.py中,就不重复测试了 """ - self.driver.move_data_to_device(torch.rand((32, 64))) - + driver.move_data_to_device(torch.rand((32, 64))) dist.barrier() - @magic_argv_env_context - def test_is_distributed(self): """ 测试 is_distributed 函数 """ - assert self.driver.is_distributed() == True + assert driver.is_distributed() == True dist.barrier() - @magic_argv_env_context - def test_get_no_sync_context(self): """ 测试 get_no_sync_context 函数 """ - res = self.driver.get_model_no_sync_context() + res = driver.get_model_no_sync_context() dist.barrier() - @magic_argv_env_context - def test_is_global_zero(self): """ 测试 is_global_zero 函数 """ - self.driver.is_global_zero() + driver.is_global_zero() dist.barrier() - @magic_argv_env_context - def test_unwrap_model(self): """ 测试 unwrap_model 函数 """ - self.driver.unwrap_model() + driver.unwrap_model() dist.barrier() - @magic_argv_env_context - def test_get_local_rank(self): """ 测试 get_local_rank 函数 """ - self.driver.get_local_rank() + driver.get_local_rank() dist.barrier() - @magic_argv_env_context - def test_all_gather(self): """ 测试 all_gather 函数 详细的测试在 test_dist_utils.py 中完成 """ obj = { - "rank": self.driver.global_rank + "rank": driver.global_rank } - obj_list = self.driver.all_gather(obj, group=None) + obj_list = driver.all_gather(obj, group=None) for i, res in enumerate(obj_list): assert res["rank"] == i - @magic_argv_env_context - @pytest.mark.parametrize("src_rank", ([0, 1])) - def test_broadcast_object(self, src_rank): """ 测试 broadcast_object 函数 详细的函数在 test_dist_utils.py 中完成 """ - if self.driver.global_rank == src_rank: + if driver.global_rank == 0: obj = { - "rank": self.driver.global_rank + "rank": driver.global_rank } else: obj = None - res = self.driver.broadcast_object(obj, src=src_rank) - assert res["rank"] == src_rank + res = driver.broadcast_object(obj, src=0) + assert res["rank"] == 0 + + if dist.is_initialized(): + dist.destroy_process_group() ############################################################################ # @@ -182,12 +176,12 @@ class TestDDPDriverFunction: ############################################################################ @pytest.mark.torch +@pytest.mark.torchtemp class TestSetDistReproDataloader: @classmethod def setup_class(cls): cls.device = [0, 1] - cls.driver = generate_driver(10, 10, device=cls.device) def setup_method(self): self.dataset = TorchNormalDataset(40) @@ -204,17 +198,20 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert replaced_loader.batch_sampler is batch_sampler self.check_distributed_sampler(replaced_loader.batch_sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -223,9 +220,10 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) sampler = RandomSampler(self.dataset, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -234,9 +232,11 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler is sampler assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() """ 传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` @@ -251,15 +251,17 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) with pytest.raises(RuntimeError): # 应当抛出 RuntimeError - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context - # @pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False])) def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): """ @@ -268,21 +270,24 @@ class TestSetDistReproDataloader: 此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler 和原 dataloader 相同 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) dataloader.batch_sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank, + num_replicas=driver.world_size, + rank=driver.global_rank, pad=True ) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert replaced_loader.batch_sampler.batch_size == 4 self.check_distributed_sampler(dataloader.batch_sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -292,12 +297,13 @@ class TestSetDistReproDataloader: 此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank + num_replicas=driver.world_size, + rank=driver.global_rank ) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -307,9 +313,11 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.drop_last == False self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -318,11 +326,14 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 此时直接返回原来的 dataloader,不做任何处理。 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) assert replaced_loader is dataloader dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() """ 传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 @@ -337,12 +348,13 @@ class TestSetDistReproDataloader: 的表现 此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader( dataset=self.dataset, batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) ) dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) @@ -351,6 +363,8 @@ class TestSetDistReproDataloader: assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -361,8 +375,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) @@ -372,6 +387,8 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -381,8 +398,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -392,6 +410,8 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() """ 传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 @@ -407,8 +427,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -418,6 +439,8 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -427,8 +450,9 @@ class TestSetDistReproDataloader: 的表现 此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -439,6 +463,8 @@ class TestSetDistReproDataloader: assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -448,8 +474,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -459,6 +486,8 @@ class TestSetDistReproDataloader: assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() def check_distributed_sampler(self, sampler): """ @@ -469,7 +498,7 @@ class TestSetDistReproDataloader: if not isinstance(sampler, UnrepeatedSampler): assert sampler.pad == True - def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): + def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle): """ 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 """ @@ -501,8 +530,8 @@ class TestSetDistReproDataloader: drop_last=False, ) new_loader.batch_sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank, + num_replicas=driver.world_size, + rank=driver.global_rank, pad=True ) new_loader.batch_sampler.load_state_dict(sampler_states) @@ -512,8 +541,8 @@ class TestSetDistReproDataloader: # 重新构造 dataloader new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) new_loader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank + num_replicas=driver.world_size, + rank=driver.global_rank ) new_loader.batch_sampler.sampler.load_state_dict(sampler_states) for idx, batch in enumerate(new_loader): @@ -534,11 +563,6 @@ class TestSaveLoad: 测试多卡情况下 save 和 load 相关函数的表现 """ - @classmethod - def setup_class(cls): - # 不在这里 setup 的话会报错 - cls.driver = generate_driver(10, 10) - def setup_method(self): self.dataset = TorchArgMaxDataset(10, 20) @@ -552,26 +576,26 @@ class TestSaveLoad: path = "model" dataloader = DataLoader(self.dataset, batch_size=2) - self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) + driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10) - self.driver1.save_model(path, only_state_dict) + driver1.save_model(path, only_state_dict) # 同步 dist.barrier() - self.driver2.load_model(path, only_state_dict) + driver2.load_model(path, only_state_dict) for idx, batch in enumerate(dataloader): - batch = self.driver1.move_data_to_device(batch) - res1 = self.driver1.model( + batch = driver1.move_data_to_device(batch) + res1 = driver1.model( batch, - fastnlp_fn=self.driver1.model.module.model.evaluate_step, + fastnlp_fn=driver1.model.module.model.evaluate_step, # Driver.model -> DataParallel.module -> _FleetWrappingModel.model fastnlp_signature_fn=None, wo_auto_param_call=False, ) - res2 = self.driver2.model( + res2 = driver2.model( batch, - fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_fn=driver2.model.module.model.evaluate_step, fastnlp_signature_fn=None, wo_auto_param_call=False, ) @@ -580,6 +604,9 @@ class TestSaveLoad: finally: rank_zero_rm(path) + if dist.is_initialized(): + dist.destroy_process_group() + @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False])) @@ -593,7 +620,7 @@ class TestSaveLoad: path = "model.ckp" num_replicas = len(device) - self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ + driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ generate_driver(10, 10, device=device, fp16=False) dataloader = dataloader_with_bucketedbatchsampler( self.dataset, @@ -603,8 +630,8 @@ class TestSaveLoad: drop_last=False ) dataloader.batch_sampler.set_distributed( - num_replicas=self.driver1.world_size, - rank=self.driver1.global_rank, + num_replicas=driver1.world_size, + rank=driver1.global_rank, pad=True ) num_consumed_batches = 2 @@ -623,7 +650,7 @@ class TestSaveLoad: # 保存状态 sampler_states = dataloader.batch_sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} - self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) # 加载 # 更改 batch_size dataloader = dataloader_with_bucketedbatchsampler( @@ -634,11 +661,11 @@ class TestSaveLoad: drop_last=False ) dataloader.batch_sampler.set_distributed( - num_replicas=self.driver2.world_size, - rank=self.driver2.global_rank, + num_replicas=driver2.world_size, + rank=driver2.global_rank, pad=True ) - load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 # TODO optimizer 的 state_dict 总是为空 @@ -652,7 +679,7 @@ class TestSaveLoad: # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) + assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx @@ -664,16 +691,16 @@ class TestSaveLoad: left_x_batches.update(batch["x"]) left_y_batches.update(batch["y"]) - res1 = self.driver1.model( + res1 = driver1.model( batch, - fastnlp_fn=self.driver1.model.module.model.evaluate_step, + fastnlp_fn=driver1.model.module.model.evaluate_step, # Driver.model -> DataParallel.module -> _FleetWrappingModel.model fastnlp_signature_fn=None, wo_auto_param_call=False, ) - res2 = self.driver2.model( + res2 = driver2.model( batch, - fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_fn=driver2.model.module.model.evaluate_step, fastnlp_signature_fn=None, wo_auto_param_call=False, ) @@ -686,6 +713,9 @@ class TestSaveLoad: finally: rank_zero_rm(path) + if dist.is_initialized(): + dist.destroy_process_group() + @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False])) @@ -700,13 +730,13 @@ class TestSaveLoad: num_replicas = len(device) - self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) - self.driver2 = generate_driver(10, 10, device=device, fp16=False) + driver1 = generate_driver(10, 10, device=device, fp16=fp16) + driver2 = generate_driver(10, 10, device=device, fp16=False) dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver1.world_size, - rank=self.driver1.global_rank, + num_replicas=driver1.world_size, + rank=driver1.global_rank, pad=True ) num_consumed_batches = 2 @@ -726,18 +756,18 @@ class TestSaveLoad: sampler_states = dataloader.batch_sampler.sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} if only_state_dict: - self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) else: - self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) # 加载 # 更改 batch_size dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver2.world_size, - rank=self.driver2.global_rank, + num_replicas=driver2.world_size, + rank=driver2.global_rank, pad=True ) - load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 @@ -753,7 +783,7 @@ class TestSaveLoad: assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) + assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx @@ -765,16 +795,16 @@ class TestSaveLoad: left_x_batches.update(batch["x"]) left_y_batches.update(batch["y"]) - res1 = self.driver1.model( + res1 = driver1.model( batch, - fastnlp_fn=self.driver1.model.module.model.evaluate_step, + fastnlp_fn=driver1.model.module.model.evaluate_step, # Driver.model -> DataParallel.module -> _FleetWrappingModel.model fastnlp_signature_fn=None, wo_auto_param_call=False, ) - res2 = self.driver2.model( + res2 = driver2.model( batch, - fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_fn=driver2.model.module.model.evaluate_step, fastnlp_signature_fn=None, wo_auto_param_call=False, ) @@ -786,4 +816,7 @@ class TestSaveLoad: assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas finally: - rank_zero_rm(path) \ No newline at end of file + rank_zero_rm(path) + + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py index f62ccd0c..8992867e 100644 --- a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py +++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py @@ -2,12 +2,12 @@ import pytest from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver -from fastNLP.envs import get_gpu_count from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.utils import magic_argv_env_context - -import torch - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed as dist @pytest.mark.torch def test_incorrect_driver(): @@ -55,6 +55,9 @@ def test_get_ddp_2(driver, device): driver = initialize_torch_driver(driver, device, model) assert isinstance(driver, TorchDDPDriver) + dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @pytest.mark.torch @@ -76,6 +79,9 @@ def test_get_ddp(driver, device): driver = initialize_torch_driver(driver, device, model) assert isinstance(driver, TorchDDPDriver) + dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @pytest.mark.torch @@ -83,7 +89,6 @@ def test_get_ddp(driver, device): ("driver", "device"), [("torch_ddp", "cpu")] ) -@magic_argv_env_context def test_get_ddp_cpu(driver, device): """ 测试试图在 cpu 上初始化分布式训练的情况 @@ -102,7 +107,6 @@ def test_get_ddp_cpu(driver, device): "driver", ["torch", "torch_ddp"] ) -@magic_argv_env_context def test_device_out_of_range(driver, device): """ 测试传入的device超过范围的情况 From f79ee049566bc5914950d82838f200eb57dfe015 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 07:57:32 +0000 Subject: [PATCH 09/16] =?UTF-8?q?=E4=B8=BAtorch=E6=B5=8B=E8=AF=95=E4=BE=8B?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=5FNEED=5FIMPORT=5FTORCH?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/utils/dummy_class.py | 4 ++-- .../test_checkpoint_callback_torch.py | 18 +++++++++++------- .../callbacks/test_more_evaluate_callback.py | 10 ++++++---- .../controllers/test_trainer_event_trigger.py | 11 +++++++---- .../test_trainer_w_evaluator_torch.py | 10 ++++++---- .../test_trainer_wo_evaluator_torch.py | 9 ++++++--- .../jittor_driver/test_single_device.py | 3 ++- tests/core/drivers/torch_driver/test_ddp.py | 2 -- .../test_initialize_torch_driver.py | 7 +++++-- tests/core/metrics/test_accuracy_torch.py | 13 +++++++++---- .../test_classify_f1_pre_rec_metric_torch.py | 11 ++++++++--- .../core/metrics/test_span_f1_rec_acc_torch.py | 13 +++++++++---- tests/core/metrics/utils.py | 6 ++++-- .../callbacks/helper_callbacks_torch.py | 4 +++- tests/helpers/datasets/torch_data.py | 8 ++++++-- tests/helpers/models/torch_model.py | 15 ++++++++++----- 16 files changed, 94 insertions(+), 50 deletions(-) diff --git a/fastNLP/core/utils/dummy_class.py b/fastNLP/core/utils/dummy_class.py index 2856b656..42200cbb 100644 --- a/fastNLP/core/utils/dummy_class.py +++ b/fastNLP/core/utils/dummy_class.py @@ -1,5 +1,5 @@ import functools class DummyClass: - def __call__(self, *args, **kwargs): - return + def __init__(self, *args, **kwargs): + pass diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 976b68ba..a8ce451b 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -2,9 +2,6 @@ import os import pytest from typing import Any from dataclasses import dataclass -from torch.utils.data import DataLoader -from torch.optim import SGD -import torch.distributed as dist from pathlib import Path import re import time @@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset from torchmetrics import Accuracy from fastNLP.core.log import logger +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader + from torch.optim import SGD + import torch.distributed as dist @dataclass class ArgMaxDatasetConfig: @@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2( if version == 0: callbacks = [ - TrainerCheckpointCallback( + CheckpointCallback( monitor="acc", folder=path, every_n_epochs=None, @@ -558,12 +560,13 @@ def test_trainer_checkpoint_callback_2( topk=None, last=False, on_exception=None, - model_save_fn=model_save_fn + model_save_fn=model_save_fn, + save_object="trainer" ) ] elif version == 1: callbacks = [ - TrainerCheckpointCallback( + CheckpointCallback( monitor="acc", folder=path, every_n_epochs=None, @@ -571,7 +574,8 @@ def test_trainer_checkpoint_callback_2( topk=1, last=True, on_exception=None, - model_save_fn=model_save_fn + model_save_fn=model_save_fn, + save_object="trainer" ) ] diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py index 2b59ccd5..08c6f8e2 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -12,9 +12,7 @@ import os import pytest from typing import Any from dataclasses import dataclass -from torch.utils.data import DataLoader -from torch.optim import SGD -import torch.distributed as dist + from pathlib import Path import re @@ -29,7 +27,11 @@ from torchmetrics import Accuracy from fastNLP.core.metrics import Metric from fastNLP.core.log import logger from fastNLP.core.callbacks import MoreEvaluateCallback - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader + from torch.optim import SGD + import torch.distributed as dist @dataclass class ArgMaxDatasetConfig: diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index bcd89614..0d484ac7 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -1,10 +1,7 @@ import pytest from typing import Any from dataclasses import dataclass -from torch.optim import SGD -from torch.utils.data import DataLoader -from torchmetrics import Accuracy -import torch.distributed as dist + from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.callbacks.callback_events import Events @@ -12,6 +9,12 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback from tests.helpers.utils import magic_argv_env_context, Capturing +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.optim import SGD + from torch.utils.data import DataLoader + from torchmetrics import Accuracy + import torch.distributed as dist @dataclass diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 891626b5..f44bd735 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -2,9 +2,7 @@ 注意这一文件中的测试函数都应当是在 `test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成; """ import pytest -from torch.optim import SGD -from torch.utils.data import DataLoader -import torch.distributed as dist + from dataclasses import dataclass from typing import Any from torchmetrics import Accuracy @@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.utils import magic_argv_env_context - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.optim import SGD + from torch.utils.data import DataLoader + import torch.distributed as dist @dataclass class NormalClassificationTrainTorchConfig: diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index aa86ef92..74e5058e 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -2,9 +2,7 @@ import os.path import subprocess import sys import pytest -import torch.distributed as dist -from torch.optim import SGD -from torch.utils.data import DataLoader + from dataclasses import dataclass from typing import Any from pathlib import Path @@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch from tests.helpers.utils import magic_argv_env_context, Capturing from fastNLP.core import rank_zero_rm +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch.distributed as dist + from torch.optim import SGD + from torch.utils.data import DataLoader @dataclass diff --git a/tests/core/drivers/jittor_driver/test_single_device.py b/tests/core/drivers/jittor_driver/test_single_device.py index 8bbceed9..2e220974 100644 --- a/tests/core/drivers/jittor_driver/test_single_device.py +++ b/tests/core/drivers/jittor_driver/test_single_device.py @@ -15,7 +15,7 @@ else: -class Model (Module): +class Model(Module): def __init__ (self): super (Model, self).__init__() self.conv1 = nn.Conv (3, 32, 3, 1) # no padding @@ -45,6 +45,7 @@ class Model (Module): return x @pytest.mark.jittor +@pytest.mark.skip("Skip jittor tests now.") class TestSingleDevice: def test_on_gpu_without_fp16(self): diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index 11799515..d6f0ee77 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -92,7 +92,6 @@ def test_multi_drivers(): dist.destroy_process_group() @pytest.mark.torch -@pytest.mark.torchtemp class TestDDPDriverFunction: """ 测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 @@ -176,7 +175,6 @@ class TestDDPDriverFunction: ############################################################################ @pytest.mark.torch -@pytest.mark.torchtemp class TestSetDistReproDataloader: @classmethod diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py index 8992867e..9c3bd8f9 100644 --- a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py +++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py @@ -8,6 +8,9 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch import torch.distributed as dist + from torch import device as torchdevice +else: + from fastNLP.core.utils.dummy_class import DummyClass as torchdevice @pytest.mark.torch def test_incorrect_driver(): @@ -20,7 +23,7 @@ def test_incorrect_driver(): @pytest.mark.torch @pytest.mark.parametrize( "device", - ["cpu", "cuda:0", 0, torch.device("cuda:0")] + ["cpu", "cuda:0", 0, torchdevice("cuda:0")] ) @pytest.mark.parametrize( "driver", @@ -101,7 +104,7 @@ def test_get_ddp_cpu(driver, device): @pytest.mark.torch @pytest.mark.parametrize( "device", - [-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1] + [-2, [0, 20, 3], [-2], 20] ) @pytest.mark.parametrize( "driver", diff --git a/tests/core/metrics/test_accuracy_torch.py b/tests/core/metrics/test_accuracy_torch.py index b89d15db..cadf4e0e 100644 --- a/tests/core/metrics/test_accuracy_torch.py +++ b/tests/core/metrics/test_accuracy_torch.py @@ -7,15 +7,20 @@ import copy import socket import pytest import numpy as np -import torch -import torch.distributed -from torch.multiprocessing import Pool, set_start_method + from sklearn.metrics import accuracy_score as sklearn_accuracy from fastNLP.core.dataset import DataSet from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.metrics.metric import Metric from .utils import find_free_network_port, setup_ddp, _assert_allclose +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) @@ -26,7 +31,7 @@ pool = None def _test(local_rank: int, world_size: int, - device: torch.device, + device: "torch.device", dataset: DataSet, metric_class: Type[Metric], metric_kwargs: Dict[str, Any], diff --git a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py index bc006cb1..75203a3e 100644 --- a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py +++ b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py @@ -2,18 +2,23 @@ from functools import partial import copy import pytest -import torch + import numpy as np -from torch.multiprocessing import Pool, set_start_method from fastNLP.core.metrics import ClassifyFPreRecMetric from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH from .utils import find_free_network_port, setup_ddp +if _NEED_IMPORT_TORCH: + import torch + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) -def _test(local_rank: int, world_size: int, device: torch.device, +def _test(local_rank: int, world_size: int, device: "torch.device", dataset: DataSet, metric_class, metric_kwargs, metric_result): metric = metric_class(**metric_kwargs) # dataset 也类似(每个进程有自己的一个) diff --git a/tests/core/metrics/test_span_f1_rec_acc_torch.py b/tests/core/metrics/test_span_f1_rec_acc_torch.py index 72db05fc..0ebb9bdd 100644 --- a/tests/core/metrics/test_span_f1_rec_acc_torch.py +++ b/tests/core/metrics/test_span_f1_rec_acc_torch.py @@ -5,16 +5,21 @@ import os, sys import copy from functools import partial -import torch -import torch.distributed import numpy as np import socket -from torch.multiprocessing import Pool, set_start_method + # from multiprocessing import Pool, set_start_method from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH from .utils import find_free_network_port, setup_ddp +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) @@ -44,7 +49,7 @@ pool = None def _test(local_rank: int, world_size: int, - device: torch.device, + device: "torch.device", dataset: DataSet, metric_class, metric_kwargs, diff --git a/tests/core/metrics/utils.py b/tests/core/metrics/utils.py index 10157438..4126dc97 100644 --- a/tests/core/metrics/utils.py +++ b/tests/core/metrics/utils.py @@ -2,9 +2,11 @@ import os, sys import socket from typing import Union -import torch -from torch import distributed import numpy as np +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch import distributed def setup_ddp(rank: int, world_size: int, master_port: int) -> None: diff --git a/tests/helpers/callbacks/helper_callbacks_torch.py b/tests/helpers/callbacks/helper_callbacks_torch.py index a197bb33..4b9730da 100644 --- a/tests/helpers/callbacks/helper_callbacks_torch.py +++ b/tests/helpers/callbacks/helper_callbacks_torch.py @@ -1,7 +1,9 @@ -import torch from copy import deepcopy from fastNLP.core.callbacks.callback import Callback +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch class RecordAccumulationStepsCallback_Torch(Callback): diff --git a/tests/helpers/datasets/torch_data.py b/tests/helpers/datasets/torch_data.py index 9a0af019..7c9056cd 100644 --- a/tests/helpers/datasets/torch_data.py +++ b/tests/helpers/datasets/torch_data.py @@ -1,7 +1,11 @@ import torch from functools import reduce -from torch.utils.data import Dataset, DataLoader, DistributedSampler -from torch.utils.data.sampler import SequentialSampler, BatchSampler +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import Dataset, DataLoader, DistributedSampler + from torch.utils.data.sampler import SequentialSampler, BatchSampler +else: + from fastNLP.core.utils.dummy_class import DummyClass as Dataset class TorchNormalDataset(Dataset): diff --git a/tests/helpers/models/torch_model.py b/tests/helpers/models/torch_model.py index 236ffda5..afb441ce 100644 --- a/tests/helpers/models/torch_model.py +++ b/tests/helpers/models/torch_model.py @@ -1,9 +1,14 @@ -import torch -import torch.nn as nn +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch.nn import Module + import torch.nn as nn +else: + from fastNLP.core.utils.dummy_class import DummyClass as Module # 1. 最为基础的分类模型 -class TorchNormalModel_Classification_1(nn.Module): +class TorchNormalModel_Classification_1(Module): """ 单独实现 train_step 和 evaluate_step; """ @@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module): return {"preds": x, "target": y} -class TorchNormalModel_Classification_2(nn.Module): +class TorchNormalModel_Classification_2(Module): """ 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; """ @@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module): return {"loss": loss, "preds": x, "target": y} -class TorchNormalModel_Classification_3(nn.Module): +class TorchNormalModel_Classification_3(Module): """ 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; 关闭 auto_param_call,forward 只有一个 batch 参数; From 175ced39059a5b70fcf5588138ed136880ed7306 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 08:25:59 +0000 Subject: [PATCH 10/16] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=20initialize=5Ftorch?= =?UTF-8?q?=5Fdriver=20=E7=9A=84=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/torch_driver/test_initialize_torch_driver.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py index 9c3bd8f9..8ec70de1 100644 --- a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py +++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py @@ -7,7 +7,6 @@ from tests.helpers.utils import magic_argv_env_context from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch - import torch.distributed as dist from torch import device as torchdevice else: from fastNLP.core.utils.dummy_class import DummyClass as torchdevice @@ -58,9 +57,6 @@ def test_get_ddp_2(driver, device): driver = initialize_torch_driver(driver, device, model) assert isinstance(driver, TorchDDPDriver) - dist.barrier() - if dist.is_initialized(): - dist.destroy_process_group() @pytest.mark.torch @@ -82,9 +78,6 @@ def test_get_ddp(driver, device): driver = initialize_torch_driver(driver, device, model) assert isinstance(driver, TorchDDPDriver) - dist.barrier() - if dist.is_initialized(): - dist.destroy_process_group() @pytest.mark.torch From aff84e5955de927ab40ac25cac5eb5d656455468 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 08:41:53 +0000 Subject: [PATCH 11/16] =?UTF-8?q?1=E3=80=81=E4=B8=BATorchDataLoader?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0get=5Fbatch=5Findices=E5=87=BD=E6=95=B0=202?= =?UTF-8?q?=E3=80=81=E5=9C=A8=E8=AE=BE=E7=BD=AEsampler=E5=90=8E=E5=B0=86sh?= =?UTF-8?q?uffle=E8=AE=BE=E7=BD=AE=E4=B8=BAFalse?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataloaders/torch_dataloader/fdl.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 12356074..d008d4ad 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -3,7 +3,7 @@ __all__ = [ 'prepare_torch_dataloader' ] -from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping +from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List from fastNLP.core.dataset import DataSet from fastNLP.core.collators import Collator @@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader): if sampler is None and batch_sampler is None: sampler = RandomSampler(dataset, shuffle=shuffle) + shuffle=False super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, @@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader): else: raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") + def get_batch_indices(self) -> List[int]: + """ + 获取当前 batch 的 idx + + :return: + """ + return self.cur_batch_indices + def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], batch_size: int = 1, From dc90c0cd2689eb9467bb9794edd08482cc106606 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 08:44:25 +0000 Subject: [PATCH 12/16] =?UTF-8?q?=E4=BD=BF=E7=94=A8=E6=96=B0=E7=9A=84colla?= =?UTF-8?q?tor=E6=B5=8B=E8=AF=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/collators/test_collator.py | 362 +++++++++++++++++----- tests/core/collators/test_new_collator.py | 293 ----------------- 2 files changed, 287 insertions(+), 368 deletions(-) delete mode 100644 tests/core/collators/test_new_collator.py diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py index 2b56624a..ba1e7e08 100644 --- a/tests/core/collators/test_collator.py +++ b/tests/core/collators/test_collator.py @@ -1,81 +1,293 @@ + +import numpy as np import pytest -from fastNLP.core.collators import AutoCollator -from fastNLP.core.collators.collator import _MultiCollator -from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR + +from fastNLP.core.collators.collator import Collator + + +def _assert_equal(d1, d2): + try: + if 'torch' in str(type(d1)): + if 'float64' in str(d2.dtype): + print(d2.dtype) + assert (d1 == d2).all().item() + else: + assert all(d1 == d2) + except TypeError: + assert d1 == d2 + except ValueError: + assert (d1 == d2).all() + + +def findDictDiff(d1, d2, path=""): + for k in d1: + if k in d2: + if isinstance(d1[k], dict): + findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) + else: + _assert_equal(d1[k], d2[k]) + else: + raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) + + +def findListDiff(d1, d2): + assert len(d1)==len(d2) + for _d1, _d2 in zip(d1, d2): + if isinstance(_d1, list): + findListDiff(_d1, _d2) + else: + _assert_equal(_d1, _d2) class TestCollator: - @pytest.mark.parametrize('as_numpy', [True, False]) - def test_auto_collator(self, as_numpy): - """ - 测试auto_collator的auto_pad功能 - - :param as_numpy: - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=as_numpy) - collator.set_input('x', 'y') - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - results = [] - for bucket in bucket_data: - res = collator(bucket) - assert res['x'].shape == (40, 5) - assert res['y'].shape == (40,) - results.append(res) - - def test_auto_collator_v1(self): - """ - 测试auto_collator的set_pad_val和set_pad_val功能 - - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=False) - collator.set_input('x') - collator.set_pad_val('x', val=-1) - collator.set_as_numpy(True) - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - for bucket in bucket_data: - res = collator(bucket) - print(res) - - def test_multicollator(self): - """ - 测试multicollator功能 - - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=False) - multi_collator = _MultiCollator(collator) - multi_collator.set_as_numpy(as_numpy=True) - multi_collator.set_pad_val('x', val=-1) - multi_collator.set_input('x') - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - for bucket in bucket_data: - res = multi_collator(bucket) - print(res) + @pytest.mark.torch + def test_run(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + collator = Collator(backend='raw') + assert raw_pad_batch == collator(dict_batch) + collator = Collator(backend='raw') + raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='numpy') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), + 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), + 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), + 'b': np.array([[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='numpy') + numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), + np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), + np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(numpy_pad_lst, collator(list_batch)) + + if _NEED_IMPORT_TORCH: + import torch + collator = Collator(backend='torch') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), + 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), + 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + 'float': torch.FloatTensor([1.1, 2.1]), + 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), + 'numpy': torch.FloatTensor([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), + 'b': torch.LongTensor( + [[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='torch') + torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), + torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), + torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(torch_pad_lst, collator(list_batch)) + + def test_pad(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + + # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 测试 set_pad + collator = Collator(backend='raw') + collator.set_pad('str', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + # 测试设置 pad 值 + collator = Collator(backend='raw') + collator.set_pad('nest_lst_int', pad_val=100) + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 设置 backend 和 type + collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + + # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + # [{'1'}, {'2'}]] + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_4', pad_val=None) + raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='raw') + collator.set_pad('_0', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_2', backend='numpy') + collator.set_pad('_4', backend='numpy', pad_val=100) + raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + # _single + collator = Collator() + collator.set_pad('_single') + findListDiff(list_batch, collator(list_batch)) + + def test_nest_ignore(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} + } + ] + # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': {'int':[1, 1]}}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + collator = Collator(backend='raw') + collator.set_pad(('nested_dict', 'c'), pad_val=None) + collator.set_ignore('str', 'int', 'lst_int') + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': [{'int':1}, {'int':1}]}} + pad_batch = collator(dict_batch) + findDictDiff(raw_pad_batch, pad_batch) + + collator = Collator(backend='raw') + collator.set_pad(('nested_dict', 'c'), pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int') + collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) + pad_batch = collator(dict_batch) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': [1, 1]}} + findDictDiff(raw_pad_batch, pad_batch) + + + + + + diff --git a/tests/core/collators/test_new_collator.py b/tests/core/collators/test_new_collator.py deleted file mode 100644 index ba1e7e08..00000000 --- a/tests/core/collators/test_new_collator.py +++ /dev/null @@ -1,293 +0,0 @@ - -import numpy as np -import pytest - -from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR - -from fastNLP.core.collators.collator import Collator - - -def _assert_equal(d1, d2): - try: - if 'torch' in str(type(d1)): - if 'float64' in str(d2.dtype): - print(d2.dtype) - assert (d1 == d2).all().item() - else: - assert all(d1 == d2) - except TypeError: - assert d1 == d2 - except ValueError: - assert (d1 == d2).all() - - -def findDictDiff(d1, d2, path=""): - for k in d1: - if k in d2: - if isinstance(d1[k], dict): - findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) - else: - _assert_equal(d1[k], d2[k]) - else: - raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) - - -def findListDiff(d1, d2): - assert len(d1)==len(d2) - for _d1, _d2 in zip(d1, d2): - if isinstance(_d1, list): - findListDiff(_d1, _d2) - else: - _assert_equal(_d1, _d2) - - -class TestCollator: - - @pytest.mark.torch - def test_run(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'a': 1, 'b':[1, 2]} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'a': 2, 'b': [1, 2]} - } - ] - - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - - raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} - collator = Collator(backend='raw') - assert raw_pad_batch == collator(dict_batch) - collator = Collator(backend='raw') - raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - collator = Collator(backend='numpy') - numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), - 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), - 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), - 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), - 'b': np.array([[1, 2], [1, 2]])}} - - findDictDiff(numpy_pad_batch, collator(dict_batch)) - collator = Collator(backend='numpy') - numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), - np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), - np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(numpy_pad_lst, collator(list_batch)) - - if _NEED_IMPORT_TORCH: - import torch - collator = Collator(backend='torch') - numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), - 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), - 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - 'float': torch.FloatTensor([1.1, 2.1]), - 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), - 'numpy': torch.FloatTensor([[1], [0]]), - 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), - 'b': torch.LongTensor( - [[1, 2], [1, 2]])}} - - findDictDiff(numpy_pad_batch, collator(dict_batch)) - collator = Collator(backend='torch') - torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), - torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), - torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(torch_pad_lst, collator(list_batch)) - - def test_pad(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'a': 1, 'b':[1, 2]} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'a': 2, 'b': [1, 2]} - } - ] - - raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} - - # 测试 ignore - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - # 测试 set_pad - collator = Collator(backend='raw') - collator.set_pad('str', pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - # 测试设置 pad 值 - collator = Collator(backend='raw') - collator.set_pad('nest_lst_int', pad_val=100) - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - # 设置 backend 和 type - collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], - 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - - # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - # [{'1'}, {'2'}]] - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - collator = Collator(backend='raw') - collator.set_ignore('_0', '_3', '_1') - collator.set_pad('_4', pad_val=None) - raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - collator = Collator(backend='raw') - collator.set_pad('_0', pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - collator = Collator(backend='raw') - collator.set_ignore('_0', '_3', '_1') - collator.set_pad('_2', backend='numpy') - collator.set_pad('_4', backend='numpy', pad_val=100) - raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - # _single - collator = Collator() - collator.set_pad('_single') - findListDiff(list_batch, collator(list_batch)) - - def test_nest_ignore(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} - } - ] - # 测试 ignore - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': {'int':[1, 1]}}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - collator = Collator(backend='raw') - collator.set_pad(('nested_dict', 'c'), pad_val=None) - collator.set_ignore('str', 'int', 'lst_int') - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': [{'int':1}, {'int':1}]}} - pad_batch = collator(dict_batch) - findDictDiff(raw_pad_batch, pad_batch) - - collator = Collator(backend='raw') - collator.set_pad(('nested_dict', 'c'), pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int') - collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) - pad_batch = collator(dict_batch) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': [1, 1]}} - findDictDiff(raw_pad_batch, pad_batch) - - - - - - From dd83e0b2c7908bdf466ea9c05ae9e013b26aa803 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 09:35:57 +0000 Subject: [PATCH 13/16] =?UTF-8?q?test=5Fget=5Fpadder=E6=B7=BB=E5=8A=A0jitt?= =?UTF-8?q?or=E6=A0=87=E7=AD=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/collators/padders/test_get_padder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/collators/padders/test_get_padder.py b/tests/core/collators/padders/test_get_padder.py index 4aa3d4de..3725243e 100644 --- a/tests/core/collators/padders/test_get_padder.py +++ b/tests/core/collators/padders/test_get_padder.py @@ -17,6 +17,7 @@ def test_get_element_shape_dtype(): @pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) @pytest.mark.torch @pytest.mark.paddle +@pytest.mark.jittor def test_get_padder_run(backend): if not _NEED_IMPORT_TORCH and backend == 'torch': pytest.skip("No torch") From 5329614018a52d851b2e8cfced7a18df88a5dab7 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 09:36:17 +0000 Subject: [PATCH 14/16] small --- fastNLP/core/dataloaders/jittor_dataloader/fdl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 9b67629e..507073a4 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -91,6 +91,7 @@ class JittorDataLoader: self.dataset.dataset.set_attrs(batch_size=1) # 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数 # self._collate_fn = _collate_fn + self.cur_batch_indices = None def __iter__(self): # TODO 第一次迭代后不能设置collate_fn,设置是无效的 From 3d4c318f0e38add9ee8ec79a1e136fb20aa3eaf0 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Tue, 3 May 2022 17:37:28 +0800 Subject: [PATCH 15/16] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=20test=5Ftrai?= =?UTF-8?q?ner=5Fevent=5Ftrigger=5F3=20=E7=9A=84=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controllers/test_trainer_event_trigger.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index 1a90a96d..01f28a0b 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -219,6 +219,85 @@ def test_trainer_event_trigger_2( assert member.value in output[0] +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)]) +@pytest.mark.torch +@magic_argv_env_context +def test_trainer_event_trigger_3( + model_and_optimizers: TrainerParameters, + driver, + device, + n_epochs=2, +): + import re + + once_message_1 = "This message should be typed 1 times." + once_message_2 = "test_filter_fn" + once_message_3 = "once message 3" + twice_message = "twice message hei hei" + + @Trainer.on(Events.on_train_epoch_begin(every=2)) + def train_epoch_begin_1(trainer): + print(once_message_1) + + @Trainer.on(Events.on_train_epoch_begin()) + def train_epoch_begin_2(trainer): + print(twice_message) + + @Trainer.on(Events.on_train_epoch_begin(once=2)) + def train_epoch_begin_3(trainer): + print(once_message_3) + + def filter_fn(filter, trainer): + if trainer.cur_epoch_idx == 1: + return True + else: + return False + + @Trainer.on(Events.on_train_epoch_end(filter_fn=filter_fn)) + def test_filter_fn(trainer): + print(once_message_2) + + + with Capturing() as output: + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=n_epochs, + ) + + trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + + + once_pattern_1 = re.compile(once_message_1) + once_pattern_2 = re.compile(once_message_2) + once_pattern_3 = re.compile(once_message_3) + twice_pattern = re.compile(twice_message) + + once_res_1 = once_pattern_1.findall(output[0]) + assert len(once_res_1) == 1 + once_res_2 = once_pattern_2.findall(output[0]) + assert len(once_res_2) == 1 + once_res_3 = once_pattern_3.findall(output[0]) + assert len(once_res_3) == 1 + twice_res = twice_pattern.findall(output[0]) + assert len(twice_res) == 2 + + + + + + From c8e8ff4a8cd80672422ce6463ae575b2aa56d17d Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 09:47:08 +0000 Subject: [PATCH 16/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BE=8B=E4=B8=AD=E7=9A=84Events=E4=B8=BAEvent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/callbacks/test_callback_event.py | 8 ++++---- tests/core/controllers/test_trainer_other_things.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/core/callbacks/test_callback_event.py b/tests/core/callbacks/test_callback_event.py index 8a38670a..219ccafd 100644 --- a/tests/core/callbacks/test_callback_event.py +++ b/tests/core/callbacks/test_callback_event.py @@ -162,7 +162,7 @@ class TestCallbackEvents: def test_every(self): # 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; - event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1; + event_state = Event.on_train_begin() # 什么都不输入是应当默认 every=1; @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) def _fn(data): return data @@ -174,7 +174,7 @@ class TestCallbackEvents: _res.append(cu_res) assert _res == list(range(100)) - event_state = Events.on_train_begin(every=10) + event_state = Event.on_train_begin(every=10) @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) def _fn(data): return data @@ -187,7 +187,7 @@ class TestCallbackEvents: assert _res == [w - 1 for w in range(10, 101, 10)] def test_once(self): - event_state = Events.on_train_begin(once=10) + event_state = Event.on_train_begin(once=10) @Filter(once=event_state.once) def _fn(data): @@ -220,7 +220,7 @@ def test_callback_events_torch(): return True return False - event_state = Events.on_train_begin(filter_fn=filter_fn) + event_state = Event.on_train_begin(filter_fn=filter_fn) @Filter(filter_fn=event_state.filter_fn) def _fn(trainer, data): diff --git a/tests/core/controllers/test_trainer_other_things.py b/tests/core/controllers/test_trainer_other_things.py index b010058b..3d9a5037 100644 --- a/tests/core/controllers/test_trainer_other_things.py +++ b/tests/core/controllers/test_trainer_other_things.py @@ -1,22 +1,22 @@ import pytest from fastNLP.core.controllers.trainer import Trainer -from fastNLP.core.callbacks import Events +from fastNLP.core.callbacks import Event from tests.helpers.utils import magic_argv_env_context @magic_argv_env_context def test_trainer_torch_without_evaluator(): - @Trainer.on(Events.on_train_epoch_begin(every=10), marker="test_trainer_other_things") + @Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things") def fn1(trainer): pass - @Trainer.on(Events.on_train_batch_begin(every=10), marker="test_trainer_other_things") + @Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") def fn2(trainer, batch, indices): pass with pytest.raises(BaseException): - @Trainer.on(Events.on_train_batch_begin(every=10), marker="test_trainer_other_things") + @Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") def fn3(trainer, batch): pass