@@ -14,7 +14,7 @@ if _NEED_IMPORT_PADDLE:
     import paddle

 def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[int]]],
-                             model: paddle.nn.Layer, **kwargs) -> PaddleDriver:
+                             model: "paddle.nn.Layer", **kwargs) -> PaddleDriver:
     r"""
     Determine and initialize a concrete `Driver` instance according to the arguments `driver` and `device`, then return it;
     1. If it is detected that the current process was launched by the user via `python -m paddle.distributed.launch xxx.py`, then
@@ -11,8 +11,8 @@ from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_BACKEND_LAUNCH

-def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.device, int, List[int]]],
-                            model: torch.nn.Module, **kwargs) -> TorchDriver:
+def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.device", int, List[int]]],
+                            model: "torch.nn.Module", **kwargs) -> TorchDriver:
     r"""
     Determine and initialize a concrete `Driver` instance according to the arguments `driver` and `device`, then return it;
     Note that an error is raised immediately if the given `device` does not match the `driver`;
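Both hunks above quote the backend types in the annotations (`"paddle.nn.Layer"`, `"torch.device"`, `"torch.nn.Module"`), turning them into forward references: the strings are only resolved if something explicitly inspects the annotations, so the signatures no longer require the backend package at import time. A minimal sketch of the pattern; the `TYPE_CHECKING` guard is illustrative, not code from this diff:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:  # seen by static type checkers, never executed at runtime
        import torch

    def initialize(model: "torch.nn.Module") -> None:
        # The string annotation is not evaluated at call time, so this
        # module imports cleanly even when torch is not installed.
        ...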
@@ -11,9 +11,8 @@ _IS_ALLENNLP_AVAILABLE = _module_available('allennlp')
 if _IS_ALLENNLP_AVAILABLE:
     from allennlp.training.metrics import Metric as allennlp_Metric

-if _NEED_IMPORT_TORCH and _IS_TORCHMETRICS_AVAILABLE:
-    if _IS_TORCHMETRICS_AVAILABLE:
-        from torchmetrics import Metric as torchmetrics_Metric
+if _IS_TORCHMETRICS_AVAILABLE:
+    from torchmetrics import Metric as torchmetrics_Metric

 if _NEED_IMPORT_PADDLE:
     from paddle.metric import Metric as paddle_Metric
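This hunk also removes a redundant guard: importing torchmetrics depends only on torchmetrics being installed, not on `_NEED_IMPORT_TORCH`, so the nested check collapses to one condition. For reference, an availability probe such as `_module_available('torchmetrics')` is typically implemented along these lines (a sketch, not necessarily fastNLP's exact helper):

    import importlib.util

    def _module_available(module_path: str) -> bool:
        # True when the module can be located on sys.path without importing it.
        try:
            return importlib.util.find_spec(module_path) is not None
        except ModuleNotFoundError:
            return False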
@@ -16,7 +16,7 @@ from fastNLP.core.controllers.trainer import Trainer
 from fastNLP.core.metrics.accuracy import Accuracy
 from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback
 from fastNLP.core import Evaluator
-from fastNLP.core.utils.utils import safe_rm
+from fastNLP.core import rank_zero_rm
 from fastNLP.core.drivers.torch_driver import TorchSingleDriver
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchArgMaxDataset
@@ -112,7 +112,7 @@ def test_load_best_model_callback(
     results = evaluator.run()
     assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1'])
     if save_folder:
-        safe_rm(save_folder)
+        rank_zero_rm(save_folder)
     if dist.is_initialized():
         dist.destroy_process_group()
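`safe_rm` is replaced by `rank_zero_rm` here and in the import hunk above: in a distributed test every rank would otherwise race to delete the same folder. A rough sketch of what a rank-zero-only removal helper does (illustrative, not fastNLP's exact implementation):

    import os
    import shutil

    def rank_zero_rm(path):
        # Only the process with global rank 0 deletes, so N ranks
        # don't race on the same filesystem path.
        if int(os.environ.get("RANK", "0")) == 0 and os.path.exists(path):
            shutil.rmtree(path) if os.path.isdir(path) else os.remove(path)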
@@ -4,7 +4,6 @@
 python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py
 """
 import os
-os.environ["FASTNLP_BACKEND"] = "paddle"
 import sys
 sys.path.append("../../../")

@@ -4,7 +4,6 @@
 python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet_outside.py
 """
 import os
-os.environ["FASTNLP_BACKEND"] = "paddle"
 import sys
 sys.path.append("../../../")

@@ -1,6 +1,4 @@
 import pytest
-import os
-os.environ["FASTNLP_BACKEND"] = "paddle"
 from dataclasses import dataclass

 from fastNLP.core.controllers.trainer import Trainer
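This and the following hunks drop the per-file `os.environ["FASTNLP_BACKEND"] = ...` assignments. The variable only takes effect if it is set before fastNLP is first imported, and mutating it inside one test module leaks into every other module collected by the same pytest process, so the in-file hack was both import-order-dependent and cross-contaminating. If a fixed backend is still wanted for a whole test session, a session-wide setting is the safer shape; a minimal sketch, assuming a conftest.py is acceptable (not part of this diff):

    # conftest.py -- illustrative only, not part of this change
    import os

    # Must run before fastNLP is imported anywhere in the session.
    os.environ.setdefault("FASTNLP_BACKEND", "paddle")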
@@ -25,7 +23,7 @@ class TrainPaddleConfig:
     shuffle: bool = True
     evaluate_every = 2

-@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
+@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])])
 # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
 @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
                                         RichCallback(5)]])
@@ -3,7 +3,6 @@ import sys
 import signal
 import pytest
 import traceback
-os.environ["FASTNLP_BACKEND"] = "paddle"

 import numpy as np
@@ -1,8 +1,6 @@
 import pytest
-import os
 from pathlib import Path
-os.environ["FASTNLP_BACKEND"] = "paddle"

 from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
 from fastNLP.core.samplers import (
     RandomSampler,
@@ -1,8 +1,5 @@
-import os
 import pytest
-os.environ["FASTNLP_BACKEND"] = "paddle"
-

 from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver
 from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
 from fastNLP.envs import get_gpu_count
@@ -1,6 +1,3 @@
-import os
-from re import S
-os.environ["FASTNLP_BACKEND"] = "paddle"
 import pytest
 from pathlib import Path

@@ -1,6 +1,4 @@
-import os
 import pytest
-os.environ["FASTNLP_BACKEND"] = "paddle"

 from fastNLP.core.drivers.paddle_driver.utils import (
     get_device_from_visible,
@@ -0,0 +1,31 @@
+import sys
+sys.path.append("../../../../")
+
+from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
+from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
+import torch
+
+device = [0, 1]
+torch_model = TorchNormalModel_Classification_1(10, 10)
+torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
+device = [torch.device(i) for i in device]
+driver = TorchDDPDriver(
+    model=torch_model,
+    parallel_device=device,
+    fp16=False
+)
+driver.set_optimizers(torch_opt)
+driver.setup()
+print("-----------first--------------")
+
+device = [0, 2]
+torch_model = TorchNormalModel_Classification_1(10, 10)
+torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
+device = [torch.device(i) for i in device]
+driver = TorchDDPDriver(
+    model=torch_model,
+    parallel_device=device,
+    fp16=False
+)
+driver.set_optimizers(torch_opt)
+driver.setup()
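The new throwaway script above builds a `TorchDDPDriver` twice in one process with different device lists, apparently to probe how repeated `setup()` calls behave. For orientation, the bare-PyTorch steps any DDP setup has to perform look roughly like this (a generic sketch, not fastNLP's implementation):

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def setup_ddp(model, rank, world_size):
        # One process per GPU: join the process group, pin the device,
        # then wrap the model so gradients are all-reduced across ranks.
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
        return DDP(model.cuda(rank), device_ids=[rank])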
@@ -1,8 +1,6 @@
 import pytest
-import os
 from pathlib import Path
-os.environ["FASTNLP_BACKEND"] = "torch"

 from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
 from fastNLP.core.samplers import (
     RandomSampler,
@@ -1,8 +1,5 @@
-import os
 import pytest
-os.environ["FASTNLP_BACKEND"] = "torch"
-

 from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver
 from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
 from fastNLP.envs import get_gpu_count
@@ -1,5 +1,3 @@
-import os
-os.environ["FASTNLP_BACKEND"] = "torch"
 import pytest
 from pathlib import Path

@@ -1,6 +1,4 @@
-import os
 import pytest
-os.environ["FASTNLP_BACKEND"] = "torch"

 from fastNLP.core.drivers.torch_driver.utils import (
     replace_batch_sampler,
@@ -9,153 +9,153 @@ from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
 from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
 from tests.helpers.datasets.torch_data import TorchNormalDataset

-class TestReproducibleBatchSampler:
-    # TODO: split this test; only one thing should be tested here
-    def test_torch_dataloader_1(self):
-        import torch
-        from torch.utils.data import DataLoader
-        # no shuffle
-        before_batch_size = 7
-        dataset = TorchNormalDataset(num_of_data=100)
-        dataloader = DataLoader(dataset, batch_size=before_batch_size)
-        re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-
-        forward_steps = 3
-        iter_dataloader = iter(dataloader)
-        for _ in range(forward_steps):
-            next(iter_dataloader)
-
-        # 1. save the state
-        _get_re_batchsampler = dataloader.batch_sampler
-        assert isinstance(_get_re_batchsampler, RandomBatchSampler)
-        state = _get_re_batchsampler.state_dict()
-        assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
-                         "sampler_type": "RandomBatchSampler"}
-
-        # 2. resume from the checkpoint and rebuild a dataloader;
-        # without changing batch_size;
-        dataloader = DataLoader(dataset, batch_size=before_batch_size)
-        re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-        re_batchsampler.load_state_dict(state)
-        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-
-        real_res = []
-        supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
-        forward_steps = 2
-        iter_dataloader = iter(dataloader)
-        for _ in range(forward_steps):
-            real_res.append(next(iter_dataloader))
-
-        for i in range(forward_steps):
-            assert all(real_res[i] == supposed_res[i])
-
-        # with a changed batch_size;
-        after_batch_size = 3
-        dataloader = DataLoader(dataset, batch_size=after_batch_size)
-        re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-        re_batchsampler.load_state_dict(state)
-        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-
-        real_res = []
-        supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
-        forward_steps = 2
-        iter_dataloader = iter(dataloader)
-        for _ in range(forward_steps):
-            real_res.append(next(iter_dataloader))
-
-        for i in range(forward_steps):
-            assert all(real_res[i] == supposed_res[i])
-
-        # check that the second epoch after resuming is a complete dataloader;
-        # first finish the epoch in which training resumed;
-        begin_idx = 27
-        while True:
-            try:
-                data = next(iter_dataloader)
-                _batch_size = len(data)
-                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
-                begin_idx += _batch_size
-            except StopIteration:
-                break
-
-        # start a new epoch;
-        begin_idx = 0
-        iter_dataloader = iter(dataloader)
-        while True:
-            try:
-                data = next(iter_dataloader)
-                _batch_size = len(data)
-                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
-                begin_idx += _batch_size
-            except StopIteration:
-                break
-
-    def test_torch_dataloader_2(self):
-        # test that a new epoch's index list is regenerated instead of reusing the previous epoch's;
-        from torch.utils.data import DataLoader
-        # no shuffle
-        before_batch_size = 7
-        dataset = TorchNormalDataset(num_of_data=100)
-        # enable shuffle to check whether the second epoch's index list is regenerated after resuming;
-        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
-        re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-
-        # record all data of one epoch to check that the restored data is correct;
-        all_supposed_data = []
-        forward_steps = 3
-        iter_dataloader = iter(dataloader)
-        for _ in range(forward_steps):
-            all_supposed_data.extend(next(iter_dataloader).tolist())
-
-        # 1. save the state
-        _get_re_batchsampler = dataloader.batch_sampler
-        assert isinstance(_get_re_batchsampler, RandomBatchSampler)
-        state = _get_re_batchsampler.state_dict()
-
-        # 2. resume from the checkpoint and rebuild a dataloader;
-        # without changing batch_size;
-        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
-        re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-        re_batchsampler.load_state_dict(state)
-        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-
-        # first run through the rest of this epoch;
-        pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
-        while True:
-            try:
-                all_supposed_data.extend(next(iter_dataloader).tolist())
-            except StopIteration:
-                break
-        assert all_supposed_data == list(pre_index_list)
-
-        # start new epochs from scratch;
-        for _ in range(3):
-            iter_dataloader = iter(dataloader)
-            res = []
-            while True:
-                try:
-                    res.append(next(iter_dataloader))
-                except StopIteration:
-                    break
-
-    def test_3(self):
-        import torch
-        from torch.utils.data import DataLoader
-        before_batch_size = 7
-        dataset = TorchNormalDataset(num_of_data=100)
-        # enable shuffle to check whether the second epoch's index list is regenerated after resuming;
-        dataloader = DataLoader(dataset, batch_size=before_batch_size)
-
-        for idx, data in enumerate(dataloader):
-            if idx > 3:
-                break
-
-        iterator = iter(dataloader)
-        for each in iterator:
-            pass
+#
+# class TestReproducibleBatchSampler:
+#     # TODO: split this test; only one thing should be tested here
+#     def test_torch_dataloader_1(self):
+#         import torch
+#         from torch.utils.data import DataLoader
+#         # no shuffle
+#         before_batch_size = 7
+#         dataset = TorchNormalDataset(num_of_data=100)
+#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
+#         re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+#
+#         forward_steps = 3
+#         iter_dataloader = iter(dataloader)
+#         for _ in range(forward_steps):
+#             next(iter_dataloader)
+#
+#         # 1. save the state
+#         _get_re_batchsampler = dataloader.batch_sampler
+#         assert isinstance(_get_re_batchsampler, RandomBatchSampler)
+#         state = _get_re_batchsampler.state_dict()
+#         assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
+#                          "sampler_type": "RandomBatchSampler"}
+#
+#         # 2. resume from the checkpoint and rebuild a dataloader;
+#         # without changing batch_size;
+#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
+#         re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+#         re_batchsampler.load_state_dict(state)
+#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+#
+#         real_res = []
+#         supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
+#         forward_steps = 2
+#         iter_dataloader = iter(dataloader)
+#         for _ in range(forward_steps):
+#             real_res.append(next(iter_dataloader))
+#
+#         for i in range(forward_steps):
+#             assert all(real_res[i] == supposed_res[i])
+#
+#         # with a changed batch_size;
+#         after_batch_size = 3
+#         dataloader = DataLoader(dataset, batch_size=after_batch_size)
+#         re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+#         re_batchsampler.load_state_dict(state)
+#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+#
+#         real_res = []
+#         supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
+#         forward_steps = 2
+#         iter_dataloader = iter(dataloader)
+#         for _ in range(forward_steps):
+#             real_res.append(next(iter_dataloader))
+#
+#         for i in range(forward_steps):
+#             assert all(real_res[i] == supposed_res[i])
+#
+#         # check that the second epoch after resuming is a complete dataloader;
+#         # first finish the epoch in which training resumed;
+#         begin_idx = 27
+#         while True:
+#             try:
+#                 data = next(iter_dataloader)
+#                 _batch_size = len(data)
+#                 assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
+#                 begin_idx += _batch_size
+#             except StopIteration:
+#                 break
+#
+#         # start a new epoch;
+#         begin_idx = 0
+#         iter_dataloader = iter(dataloader)
+#         while True:
+#             try:
+#                 data = next(iter_dataloader)
+#                 _batch_size = len(data)
+#                 assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
+#                 begin_idx += _batch_size
+#             except StopIteration:
+#                 break
+#
+#     def test_torch_dataloader_2(self):
+#         # test that a new epoch's index list is regenerated instead of reusing the previous epoch's;
+#         from torch.utils.data import DataLoader
+#         # no shuffle
+#         before_batch_size = 7
+#         dataset = TorchNormalDataset(num_of_data=100)
+#         # enable shuffle to check whether the second epoch's index list is regenerated after resuming;
+#         dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
+#         re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+#
+#         # record all data of one epoch to check that the restored data is correct;
+#         all_supposed_data = []
+#         forward_steps = 3
+#         iter_dataloader = iter(dataloader)
+#         for _ in range(forward_steps):
+#             all_supposed_data.extend(next(iter_dataloader).tolist())
+#
+#         # 1. save the state
+#         _get_re_batchsampler = dataloader.batch_sampler
+#         assert isinstance(_get_re_batchsampler, RandomBatchSampler)
+#         state = _get_re_batchsampler.state_dict()
+#
+#         # 2. resume from the checkpoint and rebuild a dataloader;
+#         # without changing batch_size;
+#         dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
+#         re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+#         re_batchsampler.load_state_dict(state)
+#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+#
+#         # first run through the rest of this epoch;
+#         pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
+#         while True:
+#             try:
+#                 all_supposed_data.extend(next(iter_dataloader).tolist())
+#             except StopIteration:
+#                 break
+#         assert all_supposed_data == list(pre_index_list)
+#
+#         # start new epochs from scratch;
+#         for _ in range(3):
+#             iter_dataloader = iter(dataloader)
+#             res = []
+#             while True:
+#                 try:
+#                     res.append(next(iter_dataloader))
+#                 except StopIteration:
+#                     break
+#
+#     def test_3(self):
+#         import torch
+#         from torch.utils.data import DataLoader
+#         before_batch_size = 7
+#         dataset = TorchNormalDataset(num_of_data=100)
+#         # enable shuffle to check whether the second epoch's index list is regenerated after resuming;
+#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
+#
+#         for idx, data in enumerate(dataloader):
+#             if idx > 3:
+#                 break
+#
+#         iterator = iter(dataloader)
+#         for each in iterator:
+#             pass

 class DatasetWithVaryLength:
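The hunk above comments out the whole `TestReproducibleBatchSampler` class. What those tests pinned down is the resume contract of `RandomBatchSampler`: `state_dict()` captures the epoch's `index_list` and `num_consumed_samples` (3 steps × batch 7 = 21), and after `load_state_dict()` a rebuilt dataloader continues from sample 21 even under a different batch size. A condensed sketch of that round-trip, using only the helpers already imported in that file:

    from torch.utils.data import DataLoader

    def resume_dataloader(dataset, state, batch_size):
        # Rebuild a dataloader that skips the samples consumed before the
        # checkpoint, regardless of the batch size chosen after resuming.
        dl = DataLoader(dataset, batch_size=batch_size)
        sampler = RandomBatchSampler(dl.batch_sampler, dl.batch_size, drop_last=False)
        sampler.load_state_dict(state)  # restores index_list + num_consumed_samples
        return replace_batch_sampler(dl, sampler)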
@@ -28,12 +28,12 @@ class TestUnrepeatedSampler:
     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
     @pytest.mark.parametrize('shuffle', [False, True])
-    def test_multi(self, num_replica, num_of_data, shuffle):
+    def test_multi(self, num_replicas, num_of_data, shuffle):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle)
-            sampler.set_distributed(num_replica, rank=i)
+            sampler.set_distributed(num_replicas, rank=i)
             samplers.append(sampler)

         indexes = list(chain(*samplers))
@@ -52,12 +52,12 @@ class TestUnrepeatedSortedSampler:
     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
-    def test_multi(self, num_replica, num_of_data):
+    def test_multi(self, num_replicas, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = UnrepeatedSortedSampler(dataset=data, length=data.data)
-            sampler.set_distributed(num_replica, rank=i)
+            sampler.set_distributed(num_replicas, rank=i)
             samplers.append(sampler)

         # make sure the order is not scrambled
@@ -83,12 +83,12 @@ class TestUnrepeatedSequentialSampler:
     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
-    def test_multi(self, num_replica, num_of_data):
+    def test_multi(self, num_replicas, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data)
-            sampler.set_distributed(num_replica, rank=i)
+            sampler.set_distributed(num_replicas, rank=i)
             samplers.append(sampler)

         # make sure the order is not scrambled
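The `num_replica` → `num_replicas` renames in the three hunks above fix a mismatch with `@pytest.mark.parametrize('num_replicas', ...)`: pytest requires the parametrized name to appear verbatim in the test signature, otherwise collection fails. A minimal illustration (hypothetical test, not from this diff):

    import pytest

    @pytest.mark.parametrize('num_replicas', [2, 3])
    def test_matches(num_replicas):
        # Runs once per value: the argument name matches the declaration.
        assert num_replicas in (2, 3)

    # With `def test_mismatch(num_replica):` instead, pytest errors out at
    # collection: the declared name 'num_replicas' is not a function argument,
    # and 'num_replica' would be looked up as a (missing) fixture.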