From a6103f634253458b909f3e1d8113f94e2f34921c Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 30 Apr 2022 10:51:55 +0000 Subject: [PATCH 1/4] =?UTF-8?q?=E9=87=8D=E5=91=BD=E5=90=8D=E4=B8=8D?= =?UTF-8?q?=E9=9C=80=E8=A6=81pytest=E7=9A=84=E6=B5=8B=E8=AF=95=E6=96=87?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controllers/{test_trainer_fleet.py => _test_trainer_fleet.py} | 0 ...st_trainer_fleet_outside.py => _test_trainer_fleet_outside.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/core/controllers/{test_trainer_fleet.py => _test_trainer_fleet.py} (100%) rename tests/core/controllers/{test_trainer_fleet_outside.py => _test_trainer_fleet_outside.py} (100%) diff --git a/tests/core/controllers/test_trainer_fleet.py b/tests/core/controllers/_test_trainer_fleet.py similarity index 100% rename from tests/core/controllers/test_trainer_fleet.py rename to tests/core/controllers/_test_trainer_fleet.py diff --git a/tests/core/controllers/test_trainer_fleet_outside.py b/tests/core/controllers/_test_trainer_fleet_outside.py similarity index 100% rename from tests/core/controllers/test_trainer_fleet_outside.py rename to tests/core/controllers/_test_trainer_fleet_outside.py From b3c9819fb84c93b674af71bee60f50aed3179fab Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 30 Apr 2022 12:55:57 +0000 Subject: [PATCH 2/4] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20=5F=5Finit=5F=5F.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/dataloaders/jittor_dataloader/__init__.py | 0 tests/core/dataloaders/paddle_dataloader/__init__.py | 0 tests/core/dataloaders/torch_dataloader/__init__.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/core/dataloaders/jittor_dataloader/__init__.py create mode 100644 tests/core/dataloaders/paddle_dataloader/__init__.py create mode 100644 tests/core/dataloaders/torch_dataloader/__init__.py diff --git a/tests/core/dataloaders/jittor_dataloader/__init__.py b/tests/core/dataloaders/jittor_dataloader/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/dataloaders/paddle_dataloader/__init__.py b/tests/core/dataloaders/paddle_dataloader/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/dataloaders/torch_dataloader/__init__.py b/tests/core/dataloaders/torch_dataloader/__init__.py new file mode 100644 index 00000000..e69de29b From cf2ef2ecd79a43f9ecf4054f067231fc421e0dd9 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 30 Apr 2022 13:04:55 +0000 Subject: [PATCH 3/4] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BE=8B=E7=9A=84backend=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../paddle_driver/initialize_paddle_driver.py | 2 +- .../torch_driver/initialize_torch_driver.py | 4 +-- fastNLP/core/metrics/utils.py | 5 ++- .../test_load_best_model_callback_torch.py | 4 +-- tests/core/controllers/_test_trainer_fleet.py | 1 - .../_test_trainer_fleet_outside.py | 1 - tests/core/controllers/test_trainer_paddle.py | 4 +-- .../drivers/paddle_driver/test_dist_utils.py | 1 - .../core/drivers/paddle_driver/test_fleet.py | 2 -- .../test_initialize_paddle_driver.py | 3 -- .../paddle_driver/test_single_device.py | 3 -- .../core/drivers/paddle_driver/test_utils.py | 2 -- tests/core/drivers/torch_driver/test.py | 31 +++++++++++++++++++ tests/core/drivers/torch_driver/test_ddp.py | 2 -- .../test_initialize_torch_driver.py | 3 -- .../torch_driver/test_single_device.py | 2 -- tests/core/drivers/torch_driver/test_utils.py | 2 -- .../core/samplers/test_unrepeated_sampler.py | 18 +++++------ 18 files changed, 48 insertions(+), 42 deletions(-) create mode 100644 tests/core/drivers/torch_driver/test.py diff --git a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py index 9a9d4198..c0489e6e 100644 --- a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py @@ -14,7 +14,7 @@ if _NEED_IMPORT_PADDLE: import paddle def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[int]]], - model: paddle.nn.Layer, **kwargs) -> PaddleDriver: + model: "paddle.nn.Layer", **kwargs) -> PaddleDriver: r""" 用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; 1、如果检测到当前进程为用户通过 `python -m paddle.distributed.launch xxx.py` 方式拉起的,则将 diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index 5ee946c4..7cef7316 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -11,8 +11,8 @@ from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_BACKEND_LAUNCH -def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.device, int, List[int]]], - model: torch.nn.Module, **kwargs) -> TorchDriver: +def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.device", int, List[int]]], + model: "torch.nn.Module", **kwargs) -> TorchDriver: r""" 用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; 注意如果输入的 `device` 如果和 `driver` 对应不上就直接报错; diff --git a/fastNLP/core/metrics/utils.py b/fastNLP/core/metrics/utils.py index ce6f618b..6d3fd74a 100644 --- a/fastNLP/core/metrics/utils.py +++ b/fastNLP/core/metrics/utils.py @@ -11,9 +11,8 @@ _IS_ALLENNLP_AVAILABLE = _module_available('allennlp') if _IS_ALLENNLP_AVAILABLE: from allennlp.training.metrics import Metric as allennlp_Metric -if _NEED_IMPORT_TORCH and _IS_TORCHMETRICS_AVAILABLE: - if _IS_TORCHMETRICS_AVAILABLE: - from torchmetrics import Metric as torchmetrics_Metric +if _IS_TORCHMETRICS_AVAILABLE: + from torchmetrics import Metric as torchmetrics_Metric if _NEED_IMPORT_PADDLE: from paddle.metric import Metric as paddle_Metric diff --git a/tests/core/callbacks/test_load_best_model_callback_torch.py b/tests/core/callbacks/test_load_best_model_callback_torch.py index 0bc63bd5..b042ae0f 100644 --- a/tests/core/callbacks/test_load_best_model_callback_torch.py +++ b/tests/core/callbacks/test_load_best_model_callback_torch.py @@ -16,7 +16,7 @@ from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback from fastNLP.core import Evaluator -from fastNLP.core.utils.utils import safe_rm +from fastNLP.core import rank_zero_rm from fastNLP.core.drivers.torch_driver import TorchSingleDriver from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchArgMaxDataset @@ -112,7 +112,7 @@ def test_load_best_model_callback( results = evaluator.run() assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) if save_folder: - safe_rm(save_folder) + rank_zero_rm(save_folder) if dist.is_initialized(): dist.destroy_process_group() diff --git a/tests/core/controllers/_test_trainer_fleet.py b/tests/core/controllers/_test_trainer_fleet.py index 46201c67..f438b6de 100644 --- a/tests/core/controllers/_test_trainer_fleet.py +++ b/tests/core/controllers/_test_trainer_fleet.py @@ -4,7 +4,6 @@ python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py """ import os -os.environ["FASTNLP_BACKEND"] = "paddle" import sys sys.path.append("../../../") diff --git a/tests/core/controllers/_test_trainer_fleet_outside.py b/tests/core/controllers/_test_trainer_fleet_outside.py index a48434fa..e8c9a244 100644 --- a/tests/core/controllers/_test_trainer_fleet_outside.py +++ b/tests/core/controllers/_test_trainer_fleet_outside.py @@ -4,7 +4,6 @@ python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet_outside.py """ import os -os.environ["FASTNLP_BACKEND"] = "paddle" import sys sys.path.append("../../../") diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py index 8a3ab2ce..aaf20105 100644 --- a/tests/core/controllers/test_trainer_paddle.py +++ b/tests/core/controllers/test_trainer_paddle.py @@ -1,6 +1,4 @@ import pytest -import os -os.environ["FASTNLP_BACKEND"] = "paddle" from dataclasses import dataclass from fastNLP.core.controllers.trainer import Trainer @@ -25,7 +23,7 @@ class TrainPaddleConfig: shuffle: bool = True evaluate_every = 2 -@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)]) +@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])]) # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), RichCallback(5)]]) diff --git a/tests/core/drivers/paddle_driver/test_dist_utils.py b/tests/core/drivers/paddle_driver/test_dist_utils.py index 9b81c38d..bd43378e 100644 --- a/tests/core/drivers/paddle_driver/test_dist_utils.py +++ b/tests/core/drivers/paddle_driver/test_dist_utils.py @@ -3,7 +3,6 @@ import sys import signal import pytest import traceback -os.environ["FASTNLP_BACKEND"] = "paddle" import numpy as np diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 34c80888..6190dd8c 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -1,8 +1,6 @@ import pytest -import os from pathlib import Path -os.environ["FASTNLP_BACKEND"] = "paddle" from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver from fastNLP.core.samplers import ( RandomSampler, diff --git a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py index df96d746..c8b5bfff 100644 --- a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py +++ b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py @@ -1,8 +1,5 @@ -import os import pytest -os.environ["FASTNLP_BACKEND"] = "paddle" - from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver from fastNLP.envs import get_gpu_count diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 2aa4e0e6..ec40e9f3 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -1,6 +1,3 @@ -import os -from re import S -os.environ["FASTNLP_BACKEND"] = "paddle" import pytest from pathlib import Path diff --git a/tests/core/drivers/paddle_driver/test_utils.py b/tests/core/drivers/paddle_driver/test_utils.py index 690d0fb8..69be8055 100644 --- a/tests/core/drivers/paddle_driver/test_utils.py +++ b/tests/core/drivers/paddle_driver/test_utils.py @@ -1,6 +1,4 @@ -import os import pytest -os.environ["FASTNLP_BACKEND"] = "paddle" from fastNLP.core.drivers.paddle_driver.utils import ( get_device_from_visible, diff --git a/tests/core/drivers/torch_driver/test.py b/tests/core/drivers/torch_driver/test.py new file mode 100644 index 00000000..3a1a280d --- /dev/null +++ b/tests/core/drivers/torch_driver/test.py @@ -0,0 +1,31 @@ +import sys +sys.path.append("../../../../") +from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 + +import torch + +device = [0, 1] +torch_model = TorchNormalModel_Classification_1(10, 10) +torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) +device = [torch.device(i) for i in device] +driver = TorchDDPDriver( + model=torch_model, + parallel_device=device, + fp16=False +) +driver.set_optimizers(torch_opt) +driver.setup() +print("-----------first--------------") + +device = [0, 2] +torch_model = TorchNormalModel_Classification_1(10, 10) +torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) +device = [torch.device(i) for i in device] +driver = TorchDDPDriver( + model=torch_model, + parallel_device=device, + fp16=False +) +driver.set_optimizers(torch_opt) +driver.setup() \ No newline at end of file diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index 0e91fe77..87787fbc 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -1,8 +1,6 @@ import pytest -import os from pathlib import Path -os.environ["FASTNLP_BACKEND"] = "torch" from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver from fastNLP.core.samplers import ( RandomSampler, diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py index 6c47e30e..3e612964 100644 --- a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py +++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py @@ -1,8 +1,5 @@ -import os import pytest -os.environ["FASTNLP_BACKEND"] = "torch" - from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver from fastNLP.envs import get_gpu_count diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index b8a8def9..f46f69c0 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -1,5 +1,3 @@ -import os -os.environ["FASTNLP_BACKEND"] = "torch" import pytest from pathlib import Path diff --git a/tests/core/drivers/torch_driver/test_utils.py b/tests/core/drivers/torch_driver/test_utils.py index 8f0172e0..4df767b5 100644 --- a/tests/core/drivers/torch_driver/test_utils.py +++ b/tests/core/drivers/torch_driver/test_utils.py @@ -1,6 +1,4 @@ -import os import pytest -os.environ["FASTNLP_BACKEND"] = "torch" from fastNLP.core.drivers.torch_driver.utils import ( replace_batch_sampler, diff --git a/tests/core/samplers/test_unrepeated_sampler.py b/tests/core/samplers/test_unrepeated_sampler.py index 4a271f41..39d4e34f 100644 --- a/tests/core/samplers/test_unrepeated_sampler.py +++ b/tests/core/samplers/test_unrepeated_sampler.py @@ -28,12 +28,12 @@ class TestUnrepeatedSampler: @pytest.mark.parametrize('num_replicas', [2, 3]) @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) @pytest.mark.parametrize('shuffle', [False, True]) - def test_multi(self, num_replica, num_of_data, shuffle): + def test_multi(self, num_replicas, num_of_data, shuffle): data = DatasetWithVaryLength(num_of_data=num_of_data) samplers = [] - for i in range(num_replica): + for i in range(num_replicas): sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle) - sampler.set_distributed(num_replica, rank=i) + sampler.set_distributed(num_replicas, rank=i) samplers.append(sampler) indexes = list(chain(*samplers)) @@ -52,12 +52,12 @@ class TestUnrepeatedSortedSampler: @pytest.mark.parametrize('num_replicas', [2, 3]) @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) - def test_multi(self, num_replica, num_of_data): + def test_multi(self, num_replicas, num_of_data): data = DatasetWithVaryLength(num_of_data=num_of_data) samplers = [] - for i in range(num_replica): + for i in range(num_replicas): sampler = UnrepeatedSortedSampler(dataset=data, length=data.data) - sampler.set_distributed(num_replica, rank=i) + sampler.set_distributed(num_replicas, rank=i) samplers.append(sampler) # 保证顺序是没乱的 @@ -83,12 +83,12 @@ class TestUnrepeatedSequentialSampler: @pytest.mark.parametrize('num_replicas', [2, 3]) @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) - def test_multi(self, num_replica, num_of_data): + def test_multi(self, num_replicas, num_of_data): data = DatasetWithVaryLength(num_of_data=num_of_data) samplers = [] - for i in range(num_replica): + for i in range(num_replicas): sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data) - sampler.set_distributed(num_replica, rank=i) + sampler.set_distributed(num_replicas, rank=i) samplers.append(sampler) # 保证顺序是没乱的 From 35f05932687ddf93229d5d26987e9030b744acd9 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Sat, 30 Apr 2022 21:39:20 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=E4=B8=80?= =?UTF-8?q?=E4=BA=9B=E6=B5=8B=E8=AF=95=E6=96=87=E4=BB=B6=E7=9A=84=E5=90=8D?= =?UTF-8?q?=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../{test_logger.py => test_logger_torch.py} | 0 .../test_reproducible_batch_sampler.py | 294 +++++++++--------- 2 files changed, 147 insertions(+), 147 deletions(-) rename tests/core/log/{test_logger.py => test_logger_torch.py} (100%) diff --git a/tests/core/log/test_logger.py b/tests/core/log/test_logger_torch.py similarity index 100% rename from tests/core/log/test_logger.py rename to tests/core/log/test_logger_torch.py diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py index 3514c331..6cf4b7d4 100644 --- a/tests/core/samplers/test_reproducible_batch_sampler.py +++ b/tests/core/samplers/test_reproducible_batch_sampler.py @@ -9,153 +9,153 @@ from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler from tests.helpers.datasets.torch_data import TorchNormalDataset - -class TestReproducibleBatchSampler: - # TODO 拆分测试,在这里只测试一个东西 - def test_torch_dataloader_1(self): - import torch - from torch.utils.data import DataLoader - # no shuffle - before_batch_size = 7 - dataset = TorchNormalDataset(num_of_data=100) - dataloader = DataLoader(dataset, batch_size=before_batch_size) - re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) - dataloader = replace_batch_sampler(dataloader, re_batchsampler) - - forward_steps = 3 - iter_dataloader = iter(dataloader) - for _ in range(forward_steps): - next(iter_dataloader) - - # 1. 保存状态 - _get_re_batchsampler = dataloader.batch_sampler - assert isinstance(_get_re_batchsampler, RandomBatchSampler) - state = _get_re_batchsampler.state_dict() - assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, - "sampler_type": "RandomBatchSampler"} - - # 2. 断点重训,重新生成一个 dataloader; - # 不改变 batch_size; - dataloader = DataLoader(dataset, batch_size=before_batch_size) - re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) - re_batchsampler.load_state_dict(state) - dataloader = replace_batch_sampler(dataloader, re_batchsampler) - - real_res = [] - supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) - forward_steps = 2 - iter_dataloader = iter(dataloader) - for _ in range(forward_steps): - real_res.append(next(iter_dataloader)) - - for i in range(forward_steps): - assert all(real_res[i] == supposed_res[i]) - - # 改变 batch_size; - after_batch_size = 3 - dataloader = DataLoader(dataset, batch_size=after_batch_size) - re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) - re_batchsampler.load_state_dict(state) - dataloader = replace_batch_sampler(dataloader, re_batchsampler) - - real_res = [] - supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) - forward_steps = 2 - iter_dataloader = iter(dataloader) - for _ in range(forward_steps): - real_res.append(next(iter_dataloader)) - - for i in range(forward_steps): - assert all(real_res[i] == supposed_res[i]) - - # 断点重训的第二轮是否是一个完整的 dataloader; - # 先把断点重训所在的那一个 epoch 跑完; - begin_idx = 27 - while True: - try: - data = next(iter_dataloader) - _batch_size = len(data) - assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) - begin_idx += _batch_size - except StopIteration: - break - - # 开始新的一轮; - begin_idx = 0 - iter_dataloader = iter(dataloader) - while True: - try: - data = next(iter_dataloader) - _batch_size = len(data) - assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) - begin_idx += _batch_size - except StopIteration: - break - - def test_torch_dataloader_2(self): - # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; - from torch.utils.data import DataLoader - # no shuffle - before_batch_size = 7 - dataset = TorchNormalDataset(num_of_data=100) - # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; - dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) - re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) - dataloader = replace_batch_sampler(dataloader, re_batchsampler) - - # 将一轮的所有数据保存下来,看是否恢复的是正确的; - all_supposed_data = [] - forward_steps = 3 - iter_dataloader = iter(dataloader) - for _ in range(forward_steps): - all_supposed_data.extend(next(iter_dataloader).tolist()) - - # 1. 保存状态 - _get_re_batchsampler = dataloader.batch_sampler - assert isinstance(_get_re_batchsampler, RandomBatchSampler) - state = _get_re_batchsampler.state_dict() - - # 2. 断点重训,重新生成一个 dataloader; - # 不改变 batch_size; - dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) - re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) - re_batchsampler.load_state_dict(state) - dataloader = replace_batch_sampler(dataloader, re_batchsampler) - - # 先把这一轮的数据过完; - pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] - while True: - try: - all_supposed_data.extend(next(iter_dataloader).tolist()) - except StopIteration: - break - assert all_supposed_data == list(pre_index_list) - - # 重新开启新的一轮; - for _ in range(3): - iter_dataloader = iter(dataloader) - res = [] - while True: - try: - res.append(next(iter_dataloader)) - except StopIteration: - break - - def test_3(self): - import torch - from torch.utils.data import DataLoader - before_batch_size = 7 - dataset = TorchNormalDataset(num_of_data=100) - # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; - dataloader = DataLoader(dataset, batch_size=before_batch_size) - - for idx, data in enumerate(dataloader): - if idx > 3: - break - - iterator = iter(dataloader) - for each in iterator: - pass +# +# class TestReproducibleBatchSampler: +# # TODO 拆分测试,在这里只测试一个东西 +# def test_torch_dataloader_1(self): +# import torch +# from torch.utils.data import DataLoader +# # no shuffle +# before_batch_size = 7 +# dataset = TorchNormalDataset(num_of_data=100) +# dataloader = DataLoader(dataset, batch_size=before_batch_size) +# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) +# dataloader = replace_batch_sampler(dataloader, re_batchsampler) +# +# forward_steps = 3 +# iter_dataloader = iter(dataloader) +# for _ in range(forward_steps): +# next(iter_dataloader) +# +# # 1. 保存状态 +# _get_re_batchsampler = dataloader.batch_sampler +# assert isinstance(_get_re_batchsampler, RandomBatchSampler) +# state = _get_re_batchsampler.state_dict() +# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, +# "sampler_type": "RandomBatchSampler"} +# +# # 2. 断点重训,重新生成一个 dataloader; +# # 不改变 batch_size; +# dataloader = DataLoader(dataset, batch_size=before_batch_size) +# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) +# re_batchsampler.load_state_dict(state) +# dataloader = replace_batch_sampler(dataloader, re_batchsampler) +# +# real_res = [] +# supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) +# forward_steps = 2 +# iter_dataloader = iter(dataloader) +# for _ in range(forward_steps): +# real_res.append(next(iter_dataloader)) +# +# for i in range(forward_steps): +# assert all(real_res[i] == supposed_res[i]) +# +# # 改变 batch_size; +# after_batch_size = 3 +# dataloader = DataLoader(dataset, batch_size=after_batch_size) +# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) +# re_batchsampler.load_state_dict(state) +# dataloader = replace_batch_sampler(dataloader, re_batchsampler) +# +# real_res = [] +# supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) +# forward_steps = 2 +# iter_dataloader = iter(dataloader) +# for _ in range(forward_steps): +# real_res.append(next(iter_dataloader)) +# +# for i in range(forward_steps): +# assert all(real_res[i] == supposed_res[i]) +# +# # 断点重训的第二轮是否是一个完整的 dataloader; +# # 先把断点重训所在的那一个 epoch 跑完; +# begin_idx = 27 +# while True: +# try: +# data = next(iter_dataloader) +# _batch_size = len(data) +# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) +# begin_idx += _batch_size +# except StopIteration: +# break +# +# # 开始新的一轮; +# begin_idx = 0 +# iter_dataloader = iter(dataloader) +# while True: +# try: +# data = next(iter_dataloader) +# _batch_size = len(data) +# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) +# begin_idx += _batch_size +# except StopIteration: +# break +# +# def test_torch_dataloader_2(self): +# # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; +# from torch.utils.data import DataLoader +# # no shuffle +# before_batch_size = 7 +# dataset = TorchNormalDataset(num_of_data=100) +# # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; +# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) +# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) +# dataloader = replace_batch_sampler(dataloader, re_batchsampler) +# +# # 将一轮的所有数据保存下来,看是否恢复的是正确的; +# all_supposed_data = [] +# forward_steps = 3 +# iter_dataloader = iter(dataloader) +# for _ in range(forward_steps): +# all_supposed_data.extend(next(iter_dataloader).tolist()) +# +# # 1. 保存状态 +# _get_re_batchsampler = dataloader.batch_sampler +# assert isinstance(_get_re_batchsampler, RandomBatchSampler) +# state = _get_re_batchsampler.state_dict() +# +# # 2. 断点重训,重新生成一个 dataloader; +# # 不改变 batch_size; +# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) +# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) +# re_batchsampler.load_state_dict(state) +# dataloader = replace_batch_sampler(dataloader, re_batchsampler) +# +# # 先把这一轮的数据过完; +# pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] +# while True: +# try: +# all_supposed_data.extend(next(iter_dataloader).tolist()) +# except StopIteration: +# break +# assert all_supposed_data == list(pre_index_list) +# +# # 重新开启新的一轮; +# for _ in range(3): +# iter_dataloader = iter(dataloader) +# res = [] +# while True: +# try: +# res.append(next(iter_dataloader)) +# except StopIteration: +# break +# +# def test_3(self): +# import torch +# from torch.utils.data import DataLoader +# before_batch_size = 7 +# dataset = TorchNormalDataset(num_of_data=100) +# # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; +# dataloader = DataLoader(dataset, batch_size=before_batch_size) +# +# for idx, data in enumerate(dataloader): +# if idx > 3: +# break +# +# iterator = iter(dataloader) +# for each in iterator: +# pass class DatasetWithVaryLength: