diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 924595c2..5dd8125a 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -609,9 +609,15 @@ class TorchDDPDriver(TorchDriver):
         # evaluator
         elif dist == "unrepeatdist":
             args = self.get_dataloader_args(dataloader)
+            if type(args.batch_sampler) is not BatchSampler:
+                # TODO the purpose here is to detect a user-customized ``batch_sampler``; this check may need refinement
+                logger.warning("Note that you are using a customized ``batch_sampler`` in the evaluate dataloader "
+                               "or the train dataloader while testing ``overfit_batches``, which may cause the data "
+                               "for distributed evaluation to be repeated across processes.")
             if isinstance(args.sampler, ReproducibleSampler):
                 sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
             elif not isinstance(args.sampler, UnrepeatedSampler):
+                # TODO work around the case where a customized ``batch_sampler`` is used
                 _check_dataloader_args_for_distributed(args, controller='Evaluator')
                 sampler = UnrepeatedSequentialSampler(
                     dataset=args.dataset
@@ -622,6 +628,7 @@ class TorchDDPDriver(TorchDriver):
                 num_replicas=self.world_size,
                 rank=self.global_rank
             )
+            # TODO for now the batch sampler is uniformly replaced with ``BatchSampler`` here
             batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
             return replace_batch_sampler(dataloader, batch_sampler)
         else:
diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py
index f2e34d23..8306d33c 100644
--- a/fastNLP/core/drivers/torch_driver/utils.py
+++ b/fastNLP/core/drivers/torch_driver/utils.py
@@ -14,7 +14,7 @@ from fastNLP.envs import (
     FASTNLP_BACKEND_LAUNCH,
     FASTNLP_GLOBAL_SEED,
 )
-from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler
+from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler, ReproducibleSampler
 from fastNLP.core.utils import auto_param_call, apply_to_collection
 from fastNLP.core.log import logger
 
@@ -308,15 +308,26 @@ def optimizer_state_to_device(state, device):
 
 
 def _check_dataloader_args_for_distributed(args, controller='Trainer'):
-    if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler,
-                                                                                         TorchSequentialSampler}):
-        mode = 'training' if controller == 'Trainer' else 'evaluation'
-        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
+    """
+    Check the sampler settings of the dataloader. If the user has substituted a customized
+    sampler, raise an error to prevent failures during distributed training.
+    """
+    error_flag = (type(args.sampler) not in {TorchRandomSampler, TorchSequentialSampler})
+    if controller == 'Trainer':
+        mode = 'training'
+        substitution = 'fastNLP.RandomSampler'
+        error_flag = (type(args.batch_sampler) is not TorchBatchSampler) or error_flag
+    else:  # Evaluator
+        mode = 'evaluation'
+        substitution = 'fastNLP.UnrepeatedSequentialSampler'
+    if error_flag:
         raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
                         f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
                         f"``{substitution}``. The customized sampler should set for distributed running "
                         f"before initializing ``{controller}`` , and then set the "
-                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")
+                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
+                        f"\n Current batch_sampler: {type(args.batch_sampler)}"
+                        f"\n Current sampler: {type(args.sampler)}")
 
 def _create_default_config(
     zero_optimization: bool = True,
diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py
index 87ccb613..31e1f5ab 100644
--- a/tests/core/controllers/test_trainer_w_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py
@@ -7,6 +7,7 @@ from dataclasses import dataclass
 from typing import Any
 
 from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.samplers import BucketedBatchSampler, RandomSampler
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
 from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
@@ -378,4 +379,96 @@ def test_trainer_w_evaluator_overfit_torch(
     trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
 
     if dist.is_initialized():
-        dist.destroy_process_group()
\ No newline at end of file
+        dist.destroy_process_group()
+
+@pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("train_sampler", ["batch_sampler", "sampler"])
+@pytest.mark.parametrize("eval_sampler", ["batch_sampler", "sampler"])
+@pytest.mark.parametrize("overfit_batches", [-1, 0])
+@magic_argv_env_context
+def test_trainer_w_evaluator_w_samplers(
+    driver,
+    device,
+    train_sampler,
+    eval_sampler,
+    overfit_batches,
+):
+    """
+    Test the legal cases where the dataloaders use a customized ``batch_sampler`` or ``sampler``.
+    """
+    model = TorchNormalModel_Classification_1(
+        num_labels=ArgMaxDatasetConfig.num_labels,
+        feature_dimension=ArgMaxDatasetConfig.feature_dimension
+    )
+    optimizers = SGD(model.parameters(), lr=0.001)
+    metrics = {"acc": Accuracy()}
+
+    dataset = TorchArgMaxDataset(
+        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
+        data_num=ArgMaxDatasetConfig.data_num,
+        seed=ArgMaxDatasetConfig.seed
+    )
+    if train_sampler == "batch_sampler":
+        train_dataloader = DataLoader(
+            dataset=dataset,
+            batch_sampler=BucketedBatchSampler(
+                dataset, [3] * len(dataset), ArgMaxDatasetConfig.batch_size
+            )
+        )
+    elif train_sampler == "sampler":
+        train_dataloader = DataLoader(
+            dataset=dataset,
+            batch_size=ArgMaxDatasetConfig.batch_size,
+            sampler=RandomSampler(dataset)
+        )
+    else:
+        train_dataloader = DataLoader(
+            dataset=dataset,
+            batch_size=ArgMaxDatasetConfig.batch_size,
+            shuffle=True,
+        )
+    if eval_sampler == "batch_sampler":
+        eval_dataloader = DataLoader(
+            dataset=dataset,
+            batch_sampler=BucketedBatchSampler(
+                dataset, [3] * len(dataset), ArgMaxDatasetConfig.batch_size
+            )
+        )
+    elif eval_sampler == "sampler":
+        eval_dataloader = DataLoader(
+            dataset=dataset,
+            sampler=RandomSampler(dataset)
+        )
+    else:
+        eval_dataloader = DataLoader(
+            dataset=dataset,
+            batch_size=ArgMaxDatasetConfig.batch_size,
+            shuffle=True,
+        )
+
+    trainer = Trainer(
+        model=model,
+        driver=driver,
+        device=device,
+        overfit_batches=overfit_batches,
+        optimizers=optimizers,
+        train_dataloader=train_dataloader,
+        evaluate_dataloaders={"dl": eval_dataloader},
+        metrics=metrics,
+        n_epochs=2,
+        output_from_new_proc="all",
+        evaluate_every=-1,
+
+        torch_kwargs={
+            "non_blocking": False,
+            "set_grad_to_none": True
+        }
+
+    )
+
+    trainer.run()
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+