diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 9b67629e..507073a4 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -91,6 +91,7 @@ class JittorDataLoader: self.dataset.dataset.set_attrs(batch_size=1) # 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数 # self._collate_fn = _collate_fn + self.cur_batch_indices = None def __iter__(self): # TODO 第一次迭代后不能设置collate_fn,设置是无效的 diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 12356074..d008d4ad 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -3,7 +3,7 @@ __all__ = [ 'prepare_torch_dataloader' ] -from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping +from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List from fastNLP.core.dataset import DataSet from fastNLP.core.collators import Collator @@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader): if sampler is None and batch_sampler is None: sampler = RandomSampler(dataset, shuffle=shuffle) + shuffle=False super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, @@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader): else: raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") + def get_batch_indices(self) -> List[int]: + """ + 获取当前 batch 的 idx + + :return: + """ + return self.cur_batch_indices + def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], batch_size: int = 1, diff --git a/fastNLP/core/utils/dummy_class.py b/fastNLP/core/utils/dummy_class.py index 2856b656..42200cbb 100644 --- a/fastNLP/core/utils/dummy_class.py +++ b/fastNLP/core/utils/dummy_class.py @@ -1,5 +1,5 @@ import functools class DummyClass: - def __call__(self, *args, **kwargs): - return + def __init__(self, *args, **kwargs): + pass diff --git a/tests/core/callbacks/test_callback_event.py b/tests/core/callbacks/test_callback_event.py index 8a38670a..219ccafd 100644 --- a/tests/core/callbacks/test_callback_event.py +++ b/tests/core/callbacks/test_callback_event.py @@ -162,7 +162,7 @@ class TestCallbackEvents: def test_every(self): # 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; - event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1; + event_state = Event.on_train_begin() # 什么都不输入是应当默认 every=1; @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) def _fn(data): return data @@ -174,7 +174,7 @@ class TestCallbackEvents: _res.append(cu_res) assert _res == list(range(100)) - event_state = Events.on_train_begin(every=10) + event_state = Event.on_train_begin(every=10) @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) def _fn(data): return data @@ -187,7 +187,7 @@ class TestCallbackEvents: assert _res == [w - 1 for w in range(10, 101, 10)] def test_once(self): - event_state = Events.on_train_begin(once=10) + event_state = Event.on_train_begin(once=10) @Filter(once=event_state.once) def _fn(data): @@ -220,7 +220,7 @@ def test_callback_events_torch(): return True return False - event_state = Events.on_train_begin(filter_fn=filter_fn) + event_state = Event.on_train_begin(filter_fn=filter_fn) @Filter(filter_fn=event_state.filter_fn) def _fn(trainer, data): diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index d4b49b89..2de21825 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -2,9 +2,6 @@ import os import pytest from typing import Any from dataclasses import dataclass -from torch.utils.data import DataLoader -from torch.optim import SGD -import torch.distributed as dist from pathlib import Path import re import time @@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset from torchmetrics import Accuracy from fastNLP.core.log import logger +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader + from torch.optim import SGD + import torch.distributed as dist @dataclass class ArgMaxDatasetConfig: @@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2( if version == 0: callbacks = [ - TrainerCheckpointCallback( + CheckpointCallback( monitor="acc", folder=path, every_n_epochs=None, @@ -558,12 +560,13 @@ def test_trainer_checkpoint_callback_2( topk=None, last=False, on_exception=None, - model_save_fn=model_save_fn + model_save_fn=model_save_fn, + save_object="trainer" ) ] elif version == 1: callbacks = [ - TrainerCheckpointCallback( + CheckpointCallback( monitor="acc", folder=path, every_n_epochs=None, @@ -571,7 +574,8 @@ def test_trainer_checkpoint_callback_2( topk=1, last=True, on_exception=None, - model_save_fn=model_save_fn + model_save_fn=model_save_fn, + save_object="trainer" ) ] diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py index 2b59ccd5..08c6f8e2 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -12,9 +12,7 @@ import os import pytest from typing import Any from dataclasses import dataclass -from torch.utils.data import DataLoader -from torch.optim import SGD -import torch.distributed as dist + from pathlib import Path import re @@ -29,7 +27,11 @@ from torchmetrics import Accuracy from fastNLP.core.metrics import Metric from fastNLP.core.log import logger from fastNLP.core.callbacks import MoreEvaluateCallback - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader + from torch.optim import SGD + import torch.distributed as dist @dataclass class ArgMaxDatasetConfig: diff --git a/tests/core/collators/padders/test_get_padder.py b/tests/core/collators/padders/test_get_padder.py index 4aa3d4de..3725243e 100644 --- a/tests/core/collators/padders/test_get_padder.py +++ b/tests/core/collators/padders/test_get_padder.py @@ -17,6 +17,7 @@ def test_get_element_shape_dtype(): @pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) @pytest.mark.torch @pytest.mark.paddle +@pytest.mark.jittor def test_get_padder_run(backend): if not _NEED_IMPORT_TORCH and backend == 'torch': pytest.skip("No torch") diff --git a/tests/core/collators/padders/test_paddle_padder.py b/tests/core/collators/padders/test_paddle_padder.py index 80abf30a..bea10de0 100644 --- a/tests/core/collators/padders/test_paddle_padder.py +++ b/tests/core/collators/padders/test_paddle_padder.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder +from fastNLP.core.collators.padders.paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder from fastNLP.core.collators.padders.exceptions import DtypeError from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -10,9 +10,9 @@ if _NEED_IMPORT_PADDLE: @pytest.mark.paddle -class TestpaddleNumberPadder: +class TestPaddleNumberPadder: def test_run(self): - padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = PaddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) a = [1, 2, 3] t_a = padder(a) assert isinstance(t_a, paddle.Tensor) @@ -20,9 +20,9 @@ class TestpaddleNumberPadder: @pytest.mark.paddle -class TestpaddleSequencePadder: +class TestPaddleSequencePadder: def test_run(self): - padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) a = [[1, 2, 3], [3]] a = padder(a) shape = a.shape @@ -32,20 +32,20 @@ class TestpaddleSequencePadder: assert (a == b).sum().item() == shape[0]*shape[1] def test_dtype_check(self): - padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) with pytest.raises(DtypeError): - padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) - padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) - padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) a = padder([[1], [2, 322]]) # assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 - padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) @pytest.mark.paddle -class TestpaddleTensorPadder: +class TestPaddleTensorPadder: def test_run(self): - padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) a = [paddle.zeros((3,)), paddle.zeros((2,))] a = padder(a) shape = a.shape @@ -74,7 +74,7 @@ class TestpaddleTensorPadder: [[0, -1], [-1, -1], [-1, -1]]]) assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] - padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))] a = padder(a) shape = a.shape @@ -85,7 +85,7 @@ class TestpaddleTensorPadder: ]) assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] - padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)] a = padder(a) shape = a.shape @@ -96,11 +96,11 @@ class TestpaddleTensorPadder: assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] def test_dtype_check(self): - padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) with pytest.raises(DtypeError): - padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) - padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) - padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) def test_v1(self): print(paddle.zeros((3, )).dtype) diff --git a/tests/core/collators/padders/test_raw_padder.py b/tests/core/collators/padders/test_raw_padder.py index 9742bc9a..9cb38766 100644 --- a/tests/core/collators/padders/test_raw_padder.py +++ b/tests/core/collators/padders/test_raw_padder.py @@ -23,7 +23,6 @@ class TestRawSequencePadder: assert (a == b).sum().item() == shape[0]*shape[1] def test_dtype_check(self): - with pytest.raises(DtypeError): - padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) + padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) with pytest.raises(DtypeError): padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) \ No newline at end of file diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py index 87762c16..ba1e7e08 100644 --- a/tests/core/collators/test_collator.py +++ b/tests/core/collators/test_collator.py @@ -4,7 +4,7 @@ import pytest from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR -from fastNLP.core.collators.new_collator import Collator +from fastNLP.core.collators.collator import Collator def _assert_equal(d1, d2): diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index 403d75c2..73eb0d6d 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -1,10 +1,7 @@ import pytest from typing import Any from dataclasses import dataclass -from torch.optim import SGD -from torch.utils.data import DataLoader -from torchmetrics import Accuracy -import torch.distributed as dist + from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.callbacks.callback_event import Event @@ -12,6 +9,12 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback from tests.helpers.utils import magic_argv_env_context, Capturing +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.optim import SGD + from torch.utils.data import DataLoader + from torchmetrics import Accuracy + import torch.distributed as dist @dataclass @@ -96,10 +99,10 @@ def test_trainer_event_trigger_1( if dist.is_initialized(): dist.destroy_process_group() - Event_attrs = Event.__dict__ - for k, v in Event_attrs.items(): - if isinstance(v, staticmethod): - assert k in output[0] + Event_attrs = Event.__dict__ + for k, v in Event_attrs.items(): + if isinstance(v, staticmethod): + assert k in output[0] @pytest.mark.torch @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) @@ -211,7 +214,101 @@ def test_trainer_event_trigger_2( ) trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + Event_attrs = Event.__dict__ for k, v in Event_attrs.items(): if isinstance(v, staticmethod): assert k in output[0] + + +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)]) +@pytest.mark.torch +@magic_argv_env_context +def test_trainer_event_trigger_3( + model_and_optimizers: TrainerParameters, + driver, + device, + n_epochs=2, +): + import re + + once_message_1 = "This message should be typed 1 times." + once_message_2 = "test_filter_fn" + once_message_3 = "once message 3" + twice_message = "twice message hei hei" + + @Trainer.on(Event.on_train_epoch_begin(every=2)) + def train_epoch_begin_1(trainer): + print(once_message_1) + + @Trainer.on(Event.on_train_epoch_begin()) + def train_epoch_begin_2(trainer): + print(twice_message) + + @Trainer.on(Event.on_train_epoch_begin(once=2)) + def train_epoch_begin_3(trainer): + print(once_message_3) + + def filter_fn(filter, trainer): + if trainer.cur_epoch_idx == 1: + return True + else: + return False + + @Trainer.on(Event.on_train_epoch_end(filter_fn=filter_fn)) + def test_filter_fn(trainer): + print(once_message_2) + + with Capturing() as output: + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=n_epochs, + ) + + trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + + + once_pattern_1 = re.compile(once_message_1) + once_pattern_2 = re.compile(once_message_2) + once_pattern_3 = re.compile(once_message_3) + twice_pattern = re.compile(twice_message) + + once_res_1 = once_pattern_1.findall(output[0]) + assert len(once_res_1) == 1 + once_res_2 = once_pattern_2.findall(output[0]) + assert len(once_res_2) == 1 + once_res_3 = once_pattern_3.findall(output[0]) + assert len(once_res_3) == 1 + twice_res = twice_pattern.findall(output[0]) + assert len(twice_res) == 2 + + + + + + + + + + + + + + + + diff --git a/tests/core/controllers/test_trainer_other_things.py b/tests/core/controllers/test_trainer_other_things.py index 9cdec2dd..3d9a5037 100644 --- a/tests/core/controllers/test_trainer_other_things.py +++ b/tests/core/controllers/test_trainer_other_things.py @@ -1,22 +1,22 @@ import pytest from fastNLP.core.controllers.trainer import Trainer -from fastNLP.core.callbacks import Events +from fastNLP.core.callbacks import Event from tests.helpers.utils import magic_argv_env_context @magic_argv_env_context def test_trainer_torch_without_evaluator(): - @Trainer.on(Events.on_train_epoch_begin(every=10)) + @Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things") def fn1(trainer): pass - @Trainer.on(Events.on_train_batch_begin(every=10)) + @Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") def fn2(trainer, batch, indices): pass - with pytest.raises(AssertionError): - @Trainer.on(Events.on_train_batch_begin(every=10)) + with pytest.raises(BaseException): + @Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") def fn3(trainer, batch): pass diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 891626b5..f44bd735 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -2,9 +2,7 @@ 注意这一文件中的测试函数都应当是在 `test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成; """ import pytest -from torch.optim import SGD -from torch.utils.data import DataLoader -import torch.distributed as dist + from dataclasses import dataclass from typing import Any from torchmetrics import Accuracy @@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.utils import magic_argv_env_context - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.optim import SGD + from torch.utils.data import DataLoader + import torch.distributed as dist @dataclass class NormalClassificationTrainTorchConfig: diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 624f80fb..102ab310 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -2,9 +2,7 @@ import os.path import subprocess import sys import pytest -import torch.distributed as dist -from torch.optim import SGD -from torch.utils.data import DataLoader + from dataclasses import dataclass from typing import Any from pathlib import Path @@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch from tests.helpers.utils import magic_argv_env_context, Capturing from fastNLP.core import rank_zero_rm +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch.distributed as dist + from torch.optim import SGD + from torch.utils.data import DataLoader @dataclass @@ -286,6 +289,7 @@ def test_trainer_on_exception( dist.destroy_process_group() +@pytest.mark.torch @pytest.mark.parametrize("version", [0, 1, 2, 3]) @magic_argv_env_context def test_torch_distributed_launch_1(version): diff --git a/tests/core/controllers/utils/test_utils.py b/tests/core/controllers/utils/test_utils.py index 860d84d5..39c1987a 100644 --- a/tests/core/controllers/utils/test_utils.py +++ b/tests/core/controllers/utils/test_utils.py @@ -11,7 +11,7 @@ class Test_WrapDataLoader: for sanity_batches in all_sanity_batches: data = NormalSampler(num_of_data=1000) wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches) - dataloader = iter(wrapper(dataloader=data)) + dataloader = iter(wrapper) mark = 0 while True: try: @@ -32,8 +32,7 @@ class Test_WrapDataLoader: dataset = TorchNormalDataset(num_of_data=1000) dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) - dataloader = wrapper(dataloader) - dataloader = iter(dataloader) + dataloader = iter(wrapper) all_supposed_running_data_num = 0 while True: try: @@ -55,6 +54,5 @@ class Test_WrapDataLoader: dataset = TorchNormalDataset(num_of_data=1000) dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) - dataloader = wrapper(dataloader) - length.append(len(dataloader)) + length.append(len(wrapper)) assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) \ No newline at end of file diff --git a/tests/core/drivers/jittor_driver/test_single_device.py b/tests/core/drivers/jittor_driver/test_single_device.py index 8bbceed9..2e220974 100644 --- a/tests/core/drivers/jittor_driver/test_single_device.py +++ b/tests/core/drivers/jittor_driver/test_single_device.py @@ -15,7 +15,7 @@ else: -class Model (Module): +class Model(Module): def __init__ (self): super (Model, self).__init__() self.conv1 = nn.Conv (3, 32, 3, 1) # no padding @@ -45,6 +45,7 @@ class Model (Module): return x @pytest.mark.jittor +@pytest.mark.skip("Skip jittor tests now.") class TestSingleDevice: def test_on_gpu_without_fp16(self): diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index 48299bf4..d6f0ee77 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.utils import magic_argv_env_context from fastNLP.core import rank_zero_rm +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed as dist + from torch.utils.data import DataLoader, BatchSampler -import torch -import torch.distributed as dist -from torch.utils.data import DataLoader, BatchSampler - -def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): +def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"): torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) device = [torch.device(i) for i in device] @@ -72,108 +73,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed= # ############################################################################ +@pytest.mark.torch +@magic_argv_env_context +def test_multi_drivers(): + """ + 测试使用了多个 TorchDDPDriver 的情况。 + """ + generate_driver(10, 10) + generate_driver(20, 10) + + with pytest.raises(RuntimeError): + # 设备设置不同,应该报错 + generate_driver(20, 3, device=[0,1,2]) + assert False + dist.barrier() + + if dist.is_initialized(): + dist.destroy_process_group() + @pytest.mark.torch class TestDDPDriverFunction: """ 测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 """ - @classmethod - def setup_class(cls): - cls.driver = generate_driver(10, 10) - @magic_argv_env_context - def test_multi_drivers(self): + def test_simple_functions(self): """ - 测试使用了多个 TorchDDPDriver 的情况。 + 简单测试多个函数 """ - - driver2 = generate_driver(20, 10) - - with pytest.raises(RuntimeError): - # 设备设置不同,应该报错 - driver3 = generate_driver(20, 3, device=[0,1,2]) - assert False - dist.barrier() + driver = generate_driver(10, 10) - @magic_argv_env_context - def test_move_data_to_device(self): """ - 这个函数仅调用了torch_move_data_to_device,测试例在tests/core/utils/test_torch_utils.py中 - 就不重复测试了 + 测试 move_data_to_device 函数。这个函数仅调用了 torch_move_data_to_device ,测试例在 + tests/core/utils/test_torch_utils.py中,就不重复测试了 """ - self.driver.move_data_to_device(torch.rand((32, 64))) - + driver.move_data_to_device(torch.rand((32, 64))) dist.barrier() - @magic_argv_env_context - def test_is_distributed(self): """ 测试 is_distributed 函数 """ - assert self.driver.is_distributed() == True + assert driver.is_distributed() == True dist.barrier() - @magic_argv_env_context - def test_get_no_sync_context(self): """ 测试 get_no_sync_context 函数 """ - res = self.driver.get_model_no_sync_context() + res = driver.get_model_no_sync_context() dist.barrier() - @magic_argv_env_context - def test_is_global_zero(self): """ 测试 is_global_zero 函数 """ - self.driver.is_global_zero() + driver.is_global_zero() dist.barrier() - @magic_argv_env_context - def test_unwrap_model(self): """ 测试 unwrap_model 函数 """ - self.driver.unwrap_model() + driver.unwrap_model() dist.barrier() - @magic_argv_env_context - def test_get_local_rank(self): """ 测试 get_local_rank 函数 """ - self.driver.get_local_rank() + driver.get_local_rank() dist.barrier() - @magic_argv_env_context - def test_all_gather(self): """ 测试 all_gather 函数 详细的测试在 test_dist_utils.py 中完成 """ obj = { - "rank": self.driver.global_rank + "rank": driver.global_rank } - obj_list = self.driver.all_gather(obj, group=None) + obj_list = driver.all_gather(obj, group=None) for i, res in enumerate(obj_list): assert res["rank"] == i - @magic_argv_env_context - @pytest.mark.parametrize("src_rank", ([0, 1])) - def test_broadcast_object(self, src_rank): """ 测试 broadcast_object 函数 详细的函数在 test_dist_utils.py 中完成 """ - if self.driver.global_rank == src_rank: + if driver.global_rank == 0: obj = { - "rank": self.driver.global_rank + "rank": driver.global_rank } else: obj = None - res = self.driver.broadcast_object(obj, src=src_rank) - assert res["rank"] == src_rank + res = driver.broadcast_object(obj, src=0) + assert res["rank"] == 0 + + if dist.is_initialized(): + dist.destroy_process_group() ############################################################################ # @@ -187,7 +180,6 @@ class TestSetDistReproDataloader: @classmethod def setup_class(cls): cls.device = [0, 1] - cls.driver = generate_driver(10, 10, device=cls.device) def setup_method(self): self.dataset = TorchNormalDataset(40) @@ -204,17 +196,20 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert replaced_loader.batch_sampler is batch_sampler self.check_distributed_sampler(replaced_loader.batch_sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -223,9 +218,10 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) sampler = RandomSampler(self.dataset, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -234,9 +230,11 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler is sampler assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() """ 传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` @@ -251,15 +249,17 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) with pytest.raises(RuntimeError): # 应当抛出 RuntimeError - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context - # @pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False])) def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): """ @@ -268,21 +268,24 @@ class TestSetDistReproDataloader: 此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler 和原 dataloader 相同 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) dataloader.batch_sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank, + num_replicas=driver.world_size, + rank=driver.global_rank, pad=True ) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert replaced_loader.batch_sampler.batch_size == 4 self.check_distributed_sampler(dataloader.batch_sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -292,12 +295,13 @@ class TestSetDistReproDataloader: 此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank + num_replicas=driver.world_size, + rank=driver.global_rank ) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -307,9 +311,11 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.drop_last == False self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -318,11 +324,14 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 此时直接返回原来的 dataloader,不做任何处理。 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) assert replaced_loader is dataloader dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() """ 传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 @@ -337,12 +346,13 @@ class TestSetDistReproDataloader: 的表现 此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader( dataset=self.dataset, batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) ) dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) @@ -351,6 +361,8 @@ class TestSetDistReproDataloader: assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -361,8 +373,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) @@ -372,6 +385,8 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -381,8 +396,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -392,6 +408,8 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() """ 传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 @@ -407,8 +425,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -418,6 +437,8 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -427,8 +448,9 @@ class TestSetDistReproDataloader: 的表现 此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -439,6 +461,8 @@ class TestSetDistReproDataloader: assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -448,8 +472,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -459,6 +484,8 @@ class TestSetDistReproDataloader: assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() def check_distributed_sampler(self, sampler): """ @@ -469,7 +496,7 @@ class TestSetDistReproDataloader: if not isinstance(sampler, UnrepeatedSampler): assert sampler.pad == True - def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): + def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle): """ 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 """ @@ -501,8 +528,8 @@ class TestSetDistReproDataloader: drop_last=False, ) new_loader.batch_sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank, + num_replicas=driver.world_size, + rank=driver.global_rank, pad=True ) new_loader.batch_sampler.load_state_dict(sampler_states) @@ -512,8 +539,8 @@ class TestSetDistReproDataloader: # 重新构造 dataloader new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) new_loader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank + num_replicas=driver.world_size, + rank=driver.global_rank ) new_loader.batch_sampler.sampler.load_state_dict(sampler_states) for idx, batch in enumerate(new_loader): @@ -534,11 +561,6 @@ class TestSaveLoad: 测试多卡情况下 save 和 load 相关函数的表现 """ - @classmethod - def setup_class(cls): - # 不在这里 setup 的话会报错 - cls.driver = generate_driver(10, 10) - def setup_method(self): self.dataset = TorchArgMaxDataset(10, 20) @@ -552,26 +574,26 @@ class TestSaveLoad: path = "model" dataloader = DataLoader(self.dataset, batch_size=2) - self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) + driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10) - self.driver1.save_model(path, only_state_dict) + driver1.save_model(path, only_state_dict) # 同步 dist.barrier() - self.driver2.load_model(path, only_state_dict) + driver2.load_model(path, only_state_dict) for idx, batch in enumerate(dataloader): - batch = self.driver1.move_data_to_device(batch) - res1 = self.driver1.model( + batch = driver1.move_data_to_device(batch) + res1 = driver1.model( batch, - fastnlp_fn=self.driver1.model.module.model.evaluate_step, + fastnlp_fn=driver1.model.module.model.evaluate_step, # Driver.model -> DataParallel.module -> _FleetWrappingModel.model fastnlp_signature_fn=None, wo_auto_param_call=False, ) - res2 = self.driver2.model( + res2 = driver2.model( batch, - fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_fn=driver2.model.module.model.evaluate_step, fastnlp_signature_fn=None, wo_auto_param_call=False, ) @@ -580,6 +602,9 @@ class TestSaveLoad: finally: rank_zero_rm(path) + if dist.is_initialized(): + dist.destroy_process_group() + @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False])) @@ -593,7 +618,7 @@ class TestSaveLoad: path = "model.ckp" num_replicas = len(device) - self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ + driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ generate_driver(10, 10, device=device, fp16=False) dataloader = dataloader_with_bucketedbatchsampler( self.dataset, @@ -603,8 +628,8 @@ class TestSaveLoad: drop_last=False ) dataloader.batch_sampler.set_distributed( - num_replicas=self.driver1.world_size, - rank=self.driver1.global_rank, + num_replicas=driver1.world_size, + rank=driver1.global_rank, pad=True ) num_consumed_batches = 2 @@ -623,7 +648,7 @@ class TestSaveLoad: # 保存状态 sampler_states = dataloader.batch_sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} - self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) # 加载 # 更改 batch_size dataloader = dataloader_with_bucketedbatchsampler( @@ -634,11 +659,11 @@ class TestSaveLoad: drop_last=False ) dataloader.batch_sampler.set_distributed( - num_replicas=self.driver2.world_size, - rank=self.driver2.global_rank, + num_replicas=driver2.world_size, + rank=driver2.global_rank, pad=True ) - load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 # TODO optimizer 的 state_dict 总是为空 @@ -652,7 +677,7 @@ class TestSaveLoad: # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) + assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx @@ -664,16 +689,16 @@ class TestSaveLoad: left_x_batches.update(batch["x"]) left_y_batches.update(batch["y"]) - res1 = self.driver1.model( + res1 = driver1.model( batch, - fastnlp_fn=self.driver1.model.module.model.evaluate_step, + fastnlp_fn=driver1.model.module.model.evaluate_step, # Driver.model -> DataParallel.module -> _FleetWrappingModel.model fastnlp_signature_fn=None, wo_auto_param_call=False, ) - res2 = self.driver2.model( + res2 = driver2.model( batch, - fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_fn=driver2.model.module.model.evaluate_step, fastnlp_signature_fn=None, wo_auto_param_call=False, ) @@ -686,6 +711,9 @@ class TestSaveLoad: finally: rank_zero_rm(path) + if dist.is_initialized(): + dist.destroy_process_group() + @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False])) @@ -700,13 +728,13 @@ class TestSaveLoad: num_replicas = len(device) - self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) - self.driver2 = generate_driver(10, 10, device=device, fp16=False) + driver1 = generate_driver(10, 10, device=device, fp16=fp16) + driver2 = generate_driver(10, 10, device=device, fp16=False) dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver1.world_size, - rank=self.driver1.global_rank, + num_replicas=driver1.world_size, + rank=driver1.global_rank, pad=True ) num_consumed_batches = 2 @@ -726,18 +754,18 @@ class TestSaveLoad: sampler_states = dataloader.batch_sampler.sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} if only_state_dict: - self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) else: - self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) # 加载 # 更改 batch_size dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver2.world_size, - rank=self.driver2.global_rank, + num_replicas=driver2.world_size, + rank=driver2.global_rank, pad=True ) - load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 @@ -753,7 +781,7 @@ class TestSaveLoad: assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) + assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx @@ -765,16 +793,16 @@ class TestSaveLoad: left_x_batches.update(batch["x"]) left_y_batches.update(batch["y"]) - res1 = self.driver1.model( + res1 = driver1.model( batch, - fastnlp_fn=self.driver1.model.module.model.evaluate_step, + fastnlp_fn=driver1.model.module.model.evaluate_step, # Driver.model -> DataParallel.module -> _FleetWrappingModel.model fastnlp_signature_fn=None, wo_auto_param_call=False, ) - res2 = self.driver2.model( + res2 = driver2.model( batch, - fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_fn=driver2.model.module.model.evaluate_step, fastnlp_signature_fn=None, wo_auto_param_call=False, ) @@ -786,4 +814,7 @@ class TestSaveLoad: assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas finally: - rank_zero_rm(path) \ No newline at end of file + rank_zero_rm(path) + + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py index f62ccd0c..8ec70de1 100644 --- a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py +++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py @@ -2,12 +2,14 @@ import pytest from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver -from fastNLP.envs import get_gpu_count from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.utils import magic_argv_env_context - -import torch - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch import device as torchdevice +else: + from fastNLP.core.utils.dummy_class import DummyClass as torchdevice @pytest.mark.torch def test_incorrect_driver(): @@ -20,7 +22,7 @@ def test_incorrect_driver(): @pytest.mark.torch @pytest.mark.parametrize( "device", - ["cpu", "cuda:0", 0, torch.device("cuda:0")] + ["cpu", "cuda:0", 0, torchdevice("cuda:0")] ) @pytest.mark.parametrize( "driver", @@ -83,7 +85,6 @@ def test_get_ddp(driver, device): ("driver", "device"), [("torch_ddp", "cpu")] ) -@magic_argv_env_context def test_get_ddp_cpu(driver, device): """ 测试试图在 cpu 上初始化分布式训练的情况 @@ -96,13 +97,12 @@ def test_get_ddp_cpu(driver, device): @pytest.mark.torch @pytest.mark.parametrize( "device", - [-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1] + [-2, [0, 20, 3], [-2], 20] ) @pytest.mark.parametrize( "driver", ["torch", "torch_ddp"] ) -@magic_argv_env_context def test_device_out_of_range(driver, device): """ 测试传入的device超过范围的情况 diff --git a/tests/core/metrics/test_accuracy_torch.py b/tests/core/metrics/test_accuracy_torch.py index b89d15db..cadf4e0e 100644 --- a/tests/core/metrics/test_accuracy_torch.py +++ b/tests/core/metrics/test_accuracy_torch.py @@ -7,15 +7,20 @@ import copy import socket import pytest import numpy as np -import torch -import torch.distributed -from torch.multiprocessing import Pool, set_start_method + from sklearn.metrics import accuracy_score as sklearn_accuracy from fastNLP.core.dataset import DataSet from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.metrics.metric import Metric from .utils import find_free_network_port, setup_ddp, _assert_allclose +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) @@ -26,7 +31,7 @@ pool = None def _test(local_rank: int, world_size: int, - device: torch.device, + device: "torch.device", dataset: DataSet, metric_class: Type[Metric], metric_kwargs: Dict[str, Any], diff --git a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py index bc006cb1..75203a3e 100644 --- a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py +++ b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py @@ -2,18 +2,23 @@ from functools import partial import copy import pytest -import torch + import numpy as np -from torch.multiprocessing import Pool, set_start_method from fastNLP.core.metrics import ClassifyFPreRecMetric from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH from .utils import find_free_network_port, setup_ddp +if _NEED_IMPORT_TORCH: + import torch + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) -def _test(local_rank: int, world_size: int, device: torch.device, +def _test(local_rank: int, world_size: int, device: "torch.device", dataset: DataSet, metric_class, metric_kwargs, metric_result): metric = metric_class(**metric_kwargs) # dataset 也类似(每个进程有自己的一个) diff --git a/tests/core/metrics/test_span_f1_rec_acc_torch.py b/tests/core/metrics/test_span_f1_rec_acc_torch.py index 72db05fc..0ebb9bdd 100644 --- a/tests/core/metrics/test_span_f1_rec_acc_torch.py +++ b/tests/core/metrics/test_span_f1_rec_acc_torch.py @@ -5,16 +5,21 @@ import os, sys import copy from functools import partial -import torch -import torch.distributed import numpy as np import socket -from torch.multiprocessing import Pool, set_start_method + # from multiprocessing import Pool, set_start_method from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH from .utils import find_free_network_port, setup_ddp +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) @@ -44,7 +49,7 @@ pool = None def _test(local_rank: int, world_size: int, - device: torch.device, + device: "torch.device", dataset: DataSet, metric_class, metric_kwargs, diff --git a/tests/core/metrics/utils.py b/tests/core/metrics/utils.py index 10157438..4126dc97 100644 --- a/tests/core/metrics/utils.py +++ b/tests/core/metrics/utils.py @@ -2,9 +2,11 @@ import os, sys import socket from typing import Union -import torch -from torch import distributed import numpy as np +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch import distributed def setup_ddp(rank: int, world_size: int, master_port: int) -> None: diff --git a/tests/core/utils/test_cache_results.py b/tests/core/utils/test_cache_results.py index 5657ae81..77c618bb 100644 --- a/tests/core/utils/test_cache_results.py +++ b/tests/core/utils/test_cache_results.py @@ -3,6 +3,7 @@ import pytest import subprocess from io import StringIO import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) from fastNLP.core.utils.cache_results import cache_results from fastNLP.core import rank_zero_rm diff --git a/tests/envs/test_set_backend.py b/tests/envs/test_set_backend.py index 395c854d..170110ce 100644 --- a/tests/envs/test_set_backend.py +++ b/tests/envs/test_set_backend.py @@ -1,4 +1,5 @@ import os +import pytest from fastNLP.envs.set_backend import dump_fastnlp_backend from tests.helpers.utils import Capturing @@ -9,7 +10,7 @@ def test_dump_fastnlp_envs(): filepath = None try: with Capturing() as output: - dump_fastnlp_backend() + dump_fastnlp_backend(backend="torch") filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') assert filepath in output[0] assert os.path.exists(filepath) diff --git a/tests/helpers/callbacks/helper_callbacks_torch.py b/tests/helpers/callbacks/helper_callbacks_torch.py index a197bb33..4b9730da 100644 --- a/tests/helpers/callbacks/helper_callbacks_torch.py +++ b/tests/helpers/callbacks/helper_callbacks_torch.py @@ -1,7 +1,9 @@ -import torch from copy import deepcopy from fastNLP.core.callbacks.callback import Callback +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch class RecordAccumulationStepsCallback_Torch(Callback): diff --git a/tests/helpers/datasets/torch_data.py b/tests/helpers/datasets/torch_data.py index 9a0af019..7c9056cd 100644 --- a/tests/helpers/datasets/torch_data.py +++ b/tests/helpers/datasets/torch_data.py @@ -1,7 +1,11 @@ import torch from functools import reduce -from torch.utils.data import Dataset, DataLoader, DistributedSampler -from torch.utils.data.sampler import SequentialSampler, BatchSampler +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import Dataset, DataLoader, DistributedSampler + from torch.utils.data.sampler import SequentialSampler, BatchSampler +else: + from fastNLP.core.utils.dummy_class import DummyClass as Dataset class TorchNormalDataset(Dataset): diff --git a/tests/helpers/models/torch_model.py b/tests/helpers/models/torch_model.py index 236ffda5..afb441ce 100644 --- a/tests/helpers/models/torch_model.py +++ b/tests/helpers/models/torch_model.py @@ -1,9 +1,14 @@ -import torch -import torch.nn as nn +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch.nn import Module + import torch.nn as nn +else: + from fastNLP.core.utils.dummy_class import DummyClass as Module # 1. 最为基础的分类模型 -class TorchNormalModel_Classification_1(nn.Module): +class TorchNormalModel_Classification_1(Module): """ 单独实现 train_step 和 evaluate_step; """ @@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module): return {"preds": x, "target": y} -class TorchNormalModel_Classification_2(nn.Module): +class TorchNormalModel_Classification_2(Module): """ 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; """ @@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module): return {"loss": loss, "preds": x, "target": y} -class TorchNormalModel_Classification_3(nn.Module): +class TorchNormalModel_Classification_3(Module): """ 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; 关闭 auto_param_call,forward 只有一个 batch 参数; diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000..d6a33a94 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +markers = + torch + paddle + jittor + torchpaddle \ No newline at end of file