| @@ -14,7 +14,7 @@ if _NEED_IMPORT_PADDLE: | |||||
| import paddle | import paddle | ||||
| def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[int]]], | def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[int]]], | ||||
| model: paddle.nn.Layer, **kwargs) -> PaddleDriver: | |||||
| model: "paddle.nn.Layer", **kwargs) -> PaddleDriver: | |||||
| r""" | r""" | ||||
| 用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | 用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | ||||
| 1、如果检测到当前进程为用户通过 `python -m paddle.distributed.launch xxx.py` 方式拉起的,则将 | 1、如果检测到当前进程为用户通过 `python -m paddle.distributed.launch xxx.py` 方式拉起的,则将 | ||||
| @@ -11,8 +11,8 @@ from fastNLP.core.log import logger | |||||
| from fastNLP.envs import FASTNLP_BACKEND_LAUNCH | from fastNLP.envs import FASTNLP_BACKEND_LAUNCH | ||||
| def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.device, int, List[int]]], | |||||
| model: torch.nn.Module, **kwargs) -> TorchDriver: | |||||
| def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.device", int, List[int]]], | |||||
| model: "torch.nn.Module", **kwargs) -> TorchDriver: | |||||
| r""" | r""" | ||||
| 用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | 用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | ||||
| 注意如果输入的 `device` 如果和 `driver` 对应不上就直接报错; | 注意如果输入的 `device` 如果和 `driver` 对应不上就直接报错; | ||||
| @@ -11,9 +11,8 @@ _IS_ALLENNLP_AVAILABLE = _module_available('allennlp') | |||||
| if _IS_ALLENNLP_AVAILABLE: | if _IS_ALLENNLP_AVAILABLE: | ||||
| from allennlp.training.metrics import Metric as allennlp_Metric | from allennlp.training.metrics import Metric as allennlp_Metric | ||||
| if _NEED_IMPORT_TORCH and _IS_TORCHMETRICS_AVAILABLE: | |||||
| if _IS_TORCHMETRICS_AVAILABLE: | |||||
| from torchmetrics import Metric as torchmetrics_Metric | |||||
| if _IS_TORCHMETRICS_AVAILABLE: | |||||
| from torchmetrics import Metric as torchmetrics_Metric | |||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| from paddle.metric import Metric as paddle_Metric | from paddle.metric import Metric as paddle_Metric | ||||
| @@ -16,7 +16,7 @@ from fastNLP.core.controllers.trainer import Trainer | |||||
| from fastNLP.core.metrics.accuracy import Accuracy | from fastNLP.core.metrics.accuracy import Accuracy | ||||
| from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback | from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback | ||||
| from fastNLP.core import Evaluator | from fastNLP.core import Evaluator | ||||
| from fastNLP.core.utils.utils import safe_rm | |||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.core.drivers.torch_driver import TorchSingleDriver | from fastNLP.core.drivers.torch_driver import TorchSingleDriver | ||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
| @@ -112,7 +112,7 @@ def test_load_best_model_callback( | |||||
| results = evaluator.run() | results = evaluator.run() | ||||
| assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) | assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) | ||||
| if save_folder: | if save_folder: | ||||
| safe_rm(save_folder) | |||||
| rank_zero_rm(save_folder) | |||||
| if dist.is_initialized(): | if dist.is_initialized(): | ||||
| dist.destroy_process_group() | dist.destroy_process_group() | ||||
| @@ -4,7 +4,6 @@ | |||||
| python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py | python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py | ||||
| """ | """ | ||||
| import os | import os | ||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
| import sys | import sys | ||||
| sys.path.append("../../../") | sys.path.append("../../../") | ||||
| @@ -4,7 +4,6 @@ | |||||
| python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet_outside.py | python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet_outside.py | ||||
| """ | """ | ||||
| import os | import os | ||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
| import sys | import sys | ||||
| sys.path.append("../../../") | sys.path.append("../../../") | ||||
| @@ -1,6 +1,4 @@ | |||||
| import pytest | import pytest | ||||
| import os | |||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
| @@ -25,7 +23,7 @@ class TrainPaddleConfig: | |||||
| shuffle: bool = True | shuffle: bool = True | ||||
| evaluate_every = 2 | evaluate_every = 2 | ||||
| @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)]) | |||||
| @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])]) | |||||
| # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | ||||
| @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | ||||
| RichCallback(5)]]) | RichCallback(5)]]) | ||||
| @@ -3,7 +3,6 @@ import sys | |||||
| import signal | import signal | ||||
| import pytest | import pytest | ||||
| import traceback | import traceback | ||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
| import numpy as np | import numpy as np | ||||
| @@ -1,8 +1,6 @@ | |||||
| import pytest | import pytest | ||||
| import os | |||||
| from pathlib import Path | from pathlib import Path | ||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
| from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | ||||
| from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
| RandomSampler, | RandomSampler, | ||||
| @@ -1,8 +1,5 @@ | |||||
| import os | |||||
| import pytest | import pytest | ||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
| from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver | from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver | ||||
| from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | ||||
| from fastNLP.envs import get_gpu_count | from fastNLP.envs import get_gpu_count | ||||
| @@ -1,6 +1,3 @@ | |||||
| import os | |||||
| from re import S | |||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
| import pytest | import pytest | ||||
| from pathlib import Path | from pathlib import Path | ||||
| @@ -1,6 +1,4 @@ | |||||
| import os | |||||
| import pytest | import pytest | ||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
| from fastNLP.core.drivers.paddle_driver.utils import ( | from fastNLP.core.drivers.paddle_driver.utils import ( | ||||
| get_device_from_visible, | get_device_from_visible, | ||||
| @@ -0,0 +1,31 @@ | |||||
| import sys | |||||
| sys.path.append("../../../../") | |||||
| from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
| import torch | |||||
| device = [0, 1] | |||||
| torch_model = TorchNormalModel_Classification_1(10, 10) | |||||
| torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | |||||
| device = [torch.device(i) for i in device] | |||||
| driver = TorchDDPDriver( | |||||
| model=torch_model, | |||||
| parallel_device=device, | |||||
| fp16=False | |||||
| ) | |||||
| driver.set_optimizers(torch_opt) | |||||
| driver.setup() | |||||
| print("-----------first--------------") | |||||
| device = [0, 2] | |||||
| torch_model = TorchNormalModel_Classification_1(10, 10) | |||||
| torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | |||||
| device = [torch.device(i) for i in device] | |||||
| driver = TorchDDPDriver( | |||||
| model=torch_model, | |||||
| parallel_device=device, | |||||
| fp16=False | |||||
| ) | |||||
| driver.set_optimizers(torch_opt) | |||||
| driver.setup() | |||||
| @@ -1,8 +1,6 @@ | |||||
| import pytest | import pytest | ||||
| import os | |||||
| from pathlib import Path | from pathlib import Path | ||||
| os.environ["FASTNLP_BACKEND"] = "torch" | |||||
| from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | ||||
| from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
| RandomSampler, | RandomSampler, | ||||
| @@ -1,8 +1,5 @@ | |||||
| import os | |||||
| import pytest | import pytest | ||||
| os.environ["FASTNLP_BACKEND"] = "torch" | |||||
| from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver | from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver | ||||
| from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | ||||
| from fastNLP.envs import get_gpu_count | from fastNLP.envs import get_gpu_count | ||||
| @@ -1,5 +1,3 @@ | |||||
| import os | |||||
| os.environ["FASTNLP_BACKEND"] = "torch" | |||||
| import pytest | import pytest | ||||
| from pathlib import Path | from pathlib import Path | ||||
| @@ -1,6 +1,4 @@ | |||||
| import os | |||||
| import pytest | import pytest | ||||
| os.environ["FASTNLP_BACKEND"] = "torch" | |||||
| from fastNLP.core.drivers.torch_driver.utils import ( | from fastNLP.core.drivers.torch_driver.utils import ( | ||||
| replace_batch_sampler, | replace_batch_sampler, | ||||
| @@ -9,153 +9,153 @@ from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | |||||
| from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | ||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
| class TestReproducibleBatchSampler: | |||||
| # TODO 拆分测试,在这里只测试一个东西 | |||||
| def test_torch_dataloader_1(self): | |||||
| import torch | |||||
| from torch.utils.data import DataLoader | |||||
| # no shuffle | |||||
| before_batch_size = 7 | |||||
| dataset = TorchNormalDataset(num_of_data=100) | |||||
| dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
| re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
| dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
| forward_steps = 3 | |||||
| iter_dataloader = iter(dataloader) | |||||
| for _ in range(forward_steps): | |||||
| next(iter_dataloader) | |||||
| # 1. 保存状态 | |||||
| _get_re_batchsampler = dataloader.batch_sampler | |||||
| assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||||
| state = _get_re_batchsampler.state_dict() | |||||
| assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, | |||||
| "sampler_type": "RandomBatchSampler"} | |||||
| # 2. 断点重训,重新生成一个 dataloader; | |||||
| # 不改变 batch_size; | |||||
| dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
| re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
| re_batchsampler.load_state_dict(state) | |||||
| dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
| real_res = [] | |||||
| supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) | |||||
| forward_steps = 2 | |||||
| iter_dataloader = iter(dataloader) | |||||
| for _ in range(forward_steps): | |||||
| real_res.append(next(iter_dataloader)) | |||||
| for i in range(forward_steps): | |||||
| assert all(real_res[i] == supposed_res[i]) | |||||
| # 改变 batch_size; | |||||
| after_batch_size = 3 | |||||
| dataloader = DataLoader(dataset, batch_size=after_batch_size) | |||||
| re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
| re_batchsampler.load_state_dict(state) | |||||
| dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
| real_res = [] | |||||
| supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) | |||||
| forward_steps = 2 | |||||
| iter_dataloader = iter(dataloader) | |||||
| for _ in range(forward_steps): | |||||
| real_res.append(next(iter_dataloader)) | |||||
| for i in range(forward_steps): | |||||
| assert all(real_res[i] == supposed_res[i]) | |||||
| # 断点重训的第二轮是否是一个完整的 dataloader; | |||||
| # 先把断点重训所在的那一个 epoch 跑完; | |||||
| begin_idx = 27 | |||||
| while True: | |||||
| try: | |||||
| data = next(iter_dataloader) | |||||
| _batch_size = len(data) | |||||
| assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
| begin_idx += _batch_size | |||||
| except StopIteration: | |||||
| break | |||||
| # 开始新的一轮; | |||||
| begin_idx = 0 | |||||
| iter_dataloader = iter(dataloader) | |||||
| while True: | |||||
| try: | |||||
| data = next(iter_dataloader) | |||||
| _batch_size = len(data) | |||||
| assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
| begin_idx += _batch_size | |||||
| except StopIteration: | |||||
| break | |||||
| def test_torch_dataloader_2(self): | |||||
| # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; | |||||
| from torch.utils.data import DataLoader | |||||
| # no shuffle | |||||
| before_batch_size = 7 | |||||
| dataset = TorchNormalDataset(num_of_data=100) | |||||
| # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | |||||
| dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
| re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
| dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
| # 将一轮的所有数据保存下来,看是否恢复的是正确的; | |||||
| all_supposed_data = [] | |||||
| forward_steps = 3 | |||||
| iter_dataloader = iter(dataloader) | |||||
| for _ in range(forward_steps): | |||||
| all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
| # 1. 保存状态 | |||||
| _get_re_batchsampler = dataloader.batch_sampler | |||||
| assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||||
| state = _get_re_batchsampler.state_dict() | |||||
| # 2. 断点重训,重新生成一个 dataloader; | |||||
| # 不改变 batch_size; | |||||
| dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
| re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
| re_batchsampler.load_state_dict(state) | |||||
| dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
| # 先把这一轮的数据过完; | |||||
| pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] | |||||
| while True: | |||||
| try: | |||||
| all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
| except StopIteration: | |||||
| break | |||||
| assert all_supposed_data == list(pre_index_list) | |||||
| # 重新开启新的一轮; | |||||
| for _ in range(3): | |||||
| iter_dataloader = iter(dataloader) | |||||
| res = [] | |||||
| while True: | |||||
| try: | |||||
| res.append(next(iter_dataloader)) | |||||
| except StopIteration: | |||||
| break | |||||
| def test_3(self): | |||||
| import torch | |||||
| from torch.utils.data import DataLoader | |||||
| before_batch_size = 7 | |||||
| dataset = TorchNormalDataset(num_of_data=100) | |||||
| # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | |||||
| dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
| for idx, data in enumerate(dataloader): | |||||
| if idx > 3: | |||||
| break | |||||
| iterator = iter(dataloader) | |||||
| for each in iterator: | |||||
| pass | |||||
| # | |||||
| # class TestReproducibleBatchSampler: | |||||
| # # TODO 拆分测试,在这里只测试一个东西 | |||||
| # def test_torch_dataloader_1(self): | |||||
| # import torch | |||||
| # from torch.utils.data import DataLoader | |||||
| # # no shuffle | |||||
| # before_batch_size = 7 | |||||
| # dataset = TorchNormalDataset(num_of_data=100) | |||||
| # dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
| # re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
| # dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
| # | |||||
| # forward_steps = 3 | |||||
| # iter_dataloader = iter(dataloader) | |||||
| # for _ in range(forward_steps): | |||||
| # next(iter_dataloader) | |||||
| # | |||||
| # # 1. 保存状态 | |||||
| # _get_re_batchsampler = dataloader.batch_sampler | |||||
| # assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||||
| # state = _get_re_batchsampler.state_dict() | |||||
| # assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, | |||||
| # "sampler_type": "RandomBatchSampler"} | |||||
| # | |||||
| # # 2. 断点重训,重新生成一个 dataloader; | |||||
| # # 不改变 batch_size; | |||||
| # dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
| # re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
| # re_batchsampler.load_state_dict(state) | |||||
| # dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
| # | |||||
| # real_res = [] | |||||
| # supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) | |||||
| # forward_steps = 2 | |||||
| # iter_dataloader = iter(dataloader) | |||||
| # for _ in range(forward_steps): | |||||
| # real_res.append(next(iter_dataloader)) | |||||
| # | |||||
| # for i in range(forward_steps): | |||||
| # assert all(real_res[i] == supposed_res[i]) | |||||
| # | |||||
| # # 改变 batch_size; | |||||
| # after_batch_size = 3 | |||||
| # dataloader = DataLoader(dataset, batch_size=after_batch_size) | |||||
| # re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
| # re_batchsampler.load_state_dict(state) | |||||
| # dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
| # | |||||
| # real_res = [] | |||||
| # supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) | |||||
| # forward_steps = 2 | |||||
| # iter_dataloader = iter(dataloader) | |||||
| # for _ in range(forward_steps): | |||||
| # real_res.append(next(iter_dataloader)) | |||||
| # | |||||
| # for i in range(forward_steps): | |||||
| # assert all(real_res[i] == supposed_res[i]) | |||||
| # | |||||
| # # 断点重训的第二轮是否是一个完整的 dataloader; | |||||
| # # 先把断点重训所在的那一个 epoch 跑完; | |||||
| # begin_idx = 27 | |||||
| # while True: | |||||
| # try: | |||||
| # data = next(iter_dataloader) | |||||
| # _batch_size = len(data) | |||||
| # assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
| # begin_idx += _batch_size | |||||
| # except StopIteration: | |||||
| # break | |||||
| # | |||||
| # # 开始新的一轮; | |||||
| # begin_idx = 0 | |||||
| # iter_dataloader = iter(dataloader) | |||||
| # while True: | |||||
| # try: | |||||
| # data = next(iter_dataloader) | |||||
| # _batch_size = len(data) | |||||
| # assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
| # begin_idx += _batch_size | |||||
| # except StopIteration: | |||||
| # break | |||||
| # | |||||
| # def test_torch_dataloader_2(self): | |||||
| # # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; | |||||
| # from torch.utils.data import DataLoader | |||||
| # # no shuffle | |||||
| # before_batch_size = 7 | |||||
| # dataset = TorchNormalDataset(num_of_data=100) | |||||
| # # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | |||||
| # dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
| # re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
| # dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
| # | |||||
| # # 将一轮的所有数据保存下来,看是否恢复的是正确的; | |||||
| # all_supposed_data = [] | |||||
| # forward_steps = 3 | |||||
| # iter_dataloader = iter(dataloader) | |||||
| # for _ in range(forward_steps): | |||||
| # all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
| # | |||||
| # # 1. 保存状态 | |||||
| # _get_re_batchsampler = dataloader.batch_sampler | |||||
| # assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||||
| # state = _get_re_batchsampler.state_dict() | |||||
| # | |||||
| # # 2. 断点重训,重新生成一个 dataloader; | |||||
| # # 不改变 batch_size; | |||||
| # dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
| # re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
| # re_batchsampler.load_state_dict(state) | |||||
| # dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
| # | |||||
| # # 先把这一轮的数据过完; | |||||
| # pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] | |||||
| # while True: | |||||
| # try: | |||||
| # all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
| # except StopIteration: | |||||
| # break | |||||
| # assert all_supposed_data == list(pre_index_list) | |||||
| # | |||||
| # # 重新开启新的一轮; | |||||
| # for _ in range(3): | |||||
| # iter_dataloader = iter(dataloader) | |||||
| # res = [] | |||||
| # while True: | |||||
| # try: | |||||
| # res.append(next(iter_dataloader)) | |||||
| # except StopIteration: | |||||
| # break | |||||
| # | |||||
| # def test_3(self): | |||||
| # import torch | |||||
| # from torch.utils.data import DataLoader | |||||
| # before_batch_size = 7 | |||||
| # dataset = TorchNormalDataset(num_of_data=100) | |||||
| # # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | |||||
| # dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
| # | |||||
| # for idx, data in enumerate(dataloader): | |||||
| # if idx > 3: | |||||
| # break | |||||
| # | |||||
| # iterator = iter(dataloader) | |||||
| # for each in iterator: | |||||
| # pass | |||||
| class DatasetWithVaryLength: | class DatasetWithVaryLength: | ||||
| @@ -28,12 +28,12 @@ class TestUnrepeatedSampler: | |||||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | @pytest.mark.parametrize('num_replicas', [2, 3]) | ||||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
| @pytest.mark.parametrize('shuffle', [False, True]) | @pytest.mark.parametrize('shuffle', [False, True]) | ||||
| def test_multi(self, num_replica, num_of_data, shuffle): | |||||
| def test_multi(self, num_replicas, num_of_data, shuffle): | |||||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
| samplers = [] | samplers = [] | ||||
| for i in range(num_replica): | |||||
| for i in range(num_replicas): | |||||
| sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle) | sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle) | ||||
| sampler.set_distributed(num_replica, rank=i) | |||||
| sampler.set_distributed(num_replicas, rank=i) | |||||
| samplers.append(sampler) | samplers.append(sampler) | ||||
| indexes = list(chain(*samplers)) | indexes = list(chain(*samplers)) | ||||
| @@ -52,12 +52,12 @@ class TestUnrepeatedSortedSampler: | |||||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | @pytest.mark.parametrize('num_replicas', [2, 3]) | ||||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
| def test_multi(self, num_replica, num_of_data): | |||||
| def test_multi(self, num_replicas, num_of_data): | |||||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
| samplers = [] | samplers = [] | ||||
| for i in range(num_replica): | |||||
| for i in range(num_replicas): | |||||
| sampler = UnrepeatedSortedSampler(dataset=data, length=data.data) | sampler = UnrepeatedSortedSampler(dataset=data, length=data.data) | ||||
| sampler.set_distributed(num_replica, rank=i) | |||||
| sampler.set_distributed(num_replicas, rank=i) | |||||
| samplers.append(sampler) | samplers.append(sampler) | ||||
| # 保证顺序是没乱的 | # 保证顺序是没乱的 | ||||
| @@ -83,12 +83,12 @@ class TestUnrepeatedSequentialSampler: | |||||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | @pytest.mark.parametrize('num_replicas', [2, 3]) | ||||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
| def test_multi(self, num_replica, num_of_data): | |||||
| def test_multi(self, num_replicas, num_of_data): | |||||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
| samplers = [] | samplers = [] | ||||
| for i in range(num_replica): | |||||
| for i in range(num_replicas): | |||||
| sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data) | sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data) | ||||
| sampler.set_distributed(num_replica, rank=i) | |||||
| sampler.set_distributed(num_replicas, rank=i) | |||||
| samplers.append(sampler) | samplers.append(sampler) | ||||
| # 保证顺序是没乱的 | # 保证顺序是没乱的 | ||||