@@ -609,9 +609,15 @@ class TorchDDPDriver(TorchDriver):
         # evaluator
         elif dist == "unrepeatdist":
             args = self.get_dataloader_args(dataloader)
+            if type(args.batch_sampler) != BatchSampler:
+                # TODO: this is meant to detect a user-customized ``batch_sampler``; the check may need refinement
+                logger.warning("Note that you are using a customized ``batch_sampler`` in the evaluate dataloader "
+                               "or train dataloader while testing ``overfit_batches``, which may cause the data "
+                               "for distributed evaluation to be repeated.")
             if isinstance(args.sampler, ReproducibleSampler):
                 sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
             elif not isinstance(args.sampler, UnrepeatedSampler):
+                # TODO: sidestep the ``batch_sampler`` case here
                 _check_dataloader_args_for_distributed(args, controller='Evaluator')
                 sampler = UnrepeatedSequentialSampler(
                     dataset=args.dataset
@@ -622,6 +628,7 @@ class TorchDDPDriver(TorchDriver):
                 num_replicas=self.world_size,
                 rank=self.global_rank
             )
+            # TODO: for now, uniformly replace the batch sampler with ``BatchSampler`` here
             batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
             return replace_batch_sampler(dataloader, batch_sampler)
         else:
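
For orientation: the ``unrepeatdist`` branch exists because distributed evaluation must see every example exactly once. Unlike the padded sharding used for training, no rank may repeat indices, or metrics would be computed on duplicated samples. Below is a minimal sketch of that sharding idea, assuming ``UnrepeatedSequentialSampler`` gives each rank a disjoint contiguous slice; this is an illustration, not fastNLP's actual implementation.

    # Sketch: shard num_samples indices over num_replicas ranks with no padding
    # and no repetition; earlier ranks absorb the remainder (assumption).
    def unrepeated_sequential_indices(num_samples: int, num_replicas: int, rank: int) -> list:
        base, extra = divmod(num_samples, num_replicas)
        sizes = [base + (1 if r < extra else 0) for r in range(num_replicas)]
        start = sum(sizes[:rank])
        return list(range(start, start + sizes[rank]))

    # 10 samples over 3 ranks -> slice sizes 4, 3, 3; every index appears exactly once.
    shards = [unrepeated_sequential_indices(10, 3, r) for r in range(3)]
    assert sorted(i for shard in shards for i in shard) == list(range(10))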
@@ -14,7 +14,7 @@ from fastNLP.envs import (
     FASTNLP_BACKEND_LAUNCH,
     FASTNLP_GLOBAL_SEED,
 )
-from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler
+from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler, ReproducibleSampler
 from fastNLP.core.utils import auto_param_call, apply_to_collection
 from fastNLP.core.log import logger
@@ -308,15 +308,26 @@ def optimizer_state_to_device(state, device):
 def _check_dataloader_args_for_distributed(args, controller='Trainer'):
-    if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler,
-                                                                                        TorchSequentialSampler}):
-        mode = 'training' if controller == 'Trainer' else 'evaluation'
-        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
+    """
+    Check the dataloader's sampler. If the user has swapped in a customized
+    sampler, raise an error to prevent failures during distributed training.
+    """
+    error_flag = (type(args.sampler) not in {TorchRandomSampler, TorchSequentialSampler})
+    if controller == 'Trainer':
+        mode = 'training'
+        substitution = 'fastNLP.RandomSampler'
+        error_flag = (type(args.batch_sampler) != TorchBatchSampler) or error_flag
+    else:  # Evaluator
+        mode = 'evaluation'
+        substitution = 'fastNLP.UnrepeatedSequentialSampler'
+    if error_flag:
         raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
                         f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
                         f"``{substitution}``. The customized sampler should be set up for distributed running "
                         f"before initializing ``{controller}``, and then set the "
-                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")
+                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
+                        f"\n Current batch_sampler: {type(args.batch_sampler)}"
+                        f"\n Current sampler: {type(args.sampler)}")
 
 
 def _create_default_config(
         zero_optimization: bool = True,
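
To make the new control flow concrete, here is a self-contained sketch of the decision logic above; ``DataLoaderArgs`` is a hypothetical stand-in for the object returned by ``get_dataloader_args``, and the torch aliases mirror this file's imports. The point: the Trainer rejects a customized ``batch_sampler`` or ``sampler``, while the Evaluator only checks ``sampler``.

    from dataclasses import dataclass
    from torch.utils.data import (
        BatchSampler as TorchBatchSampler,
        RandomSampler as TorchRandomSampler,
        SequentialSampler as TorchSequentialSampler,
    )

    @dataclass
    class DataLoaderArgs:  # hypothetical stand-in for get_dataloader_args(...)
        batch_sampler: object
        sampler: object

    def is_customized(args: DataLoaderArgs, controller: str = 'Trainer') -> bool:
        flag = type(args.sampler) not in {TorchRandomSampler, TorchSequentialSampler}
        if controller == 'Trainer':
            # Only the Trainer additionally rejects a customized batch_sampler.
            flag = (type(args.batch_sampler) is not TorchBatchSampler) or flag
        return flag

    data = list(range(8))
    sampler = TorchSequentialSampler(data)

    class MyBatchSampler(TorchBatchSampler):  # a "customized" batch sampler
        pass

    args = DataLoaderArgs(MyBatchSampler(sampler, 2, False), sampler)
    assert is_customized(args, 'Trainer')        # rejected for training
    assert not is_customized(args, 'Evaluator')  # batch_sampler ignored for evaluation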
@@ -7,6 +7,7 @@ from dataclasses import dataclass
 from typing import Any
 
 from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.samplers import BucketedBatchSampler, RandomSampler
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
 from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
@@ -378,4 +379,96 @@ def test_trainer_w_evaluator_overfit_torch(
     trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
 
     if dist.is_initialized():
-        dist.destroy_process_group()
+        dist.destroy_process_group()
+
+
+@pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
+@pytest.mark.parametrize("train_sampler", ["batch_sampler", "sampler"])
+@pytest.mark.parametrize("eval_sampler", ["batch_sampler", "sampler"])
+@pytest.mark.parametrize("overfit_batches", [-1, 0])
+@magic_argv_env_context
+def test_trainer_w_evaluator_w_samplers(
+        driver,
+        device,
+        train_sampler,
+        eval_sampler,
+        overfit_batches,
+):
+    """
+    Test the legitimate cases where the dataloader is given a customized
+    ``batch_sampler`` or ``sampler``.
+    """
+    model = TorchNormalModel_Classification_1(
+        num_labels=ArgMaxDatasetConfig.num_labels,
+        feature_dimension=ArgMaxDatasetConfig.feature_dimension
+    )
+    optimizers = SGD(model.parameters(), lr=0.001)
+    metrics = {"acc": Accuracy()}
+    dataset = TorchArgMaxDataset(
+        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
+        data_num=ArgMaxDatasetConfig.data_num,
+        seed=ArgMaxDatasetConfig.seed
+    )
+    if train_sampler == "batch_sampler":
+        train_dataloader = DataLoader(
+            dataset=dataset,
+            batch_sampler=BucketedBatchSampler(
+                dataset, [3] * len(dataset), ArgMaxDatasetConfig.batch_size
+            )
+        )
+    elif train_sampler == "sampler":
+        train_dataloader = DataLoader(
+            dataset=dataset,
+            batch_size=ArgMaxDatasetConfig.batch_size,
+            sampler=RandomSampler(dataset)
+        )
+    else:
+        train_dataloader = DataLoader(
+            dataset=dataset,
+            batch_size=ArgMaxDatasetConfig.batch_size,
+            shuffle=True,
+        )
+    if eval_sampler == "batch_sampler":
+        eval_dataloader = DataLoader(
+            dataset=dataset,
+            batch_sampler=BucketedBatchSampler(
+                dataset, [3] * len(dataset), ArgMaxDatasetConfig.batch_size
+            )
+        )
+    elif eval_sampler == "sampler":
+        eval_dataloader = DataLoader(
+            dataset=dataset,
+            sampler=RandomSampler(dataset)
+        )
+    else:
+        eval_dataloader = DataLoader(
+            dataset=dataset,
+            batch_size=ArgMaxDatasetConfig.batch_size,
+            shuffle=True,
+        )
+    trainer = Trainer(
+        model=model,
+        driver=driver,
+        device=device,
+        overfit_batches=overfit_batches,
+        optimizers=optimizers,
+        train_dataloader=train_dataloader,
+        evaluate_dataloaders={"dl": eval_dataloader},
+        metrics=metrics,
+        n_epochs=2,
+        output_from_new_proc="all",
+        evaluate_every=-1,
+        torch_kwargs={
+            "non_blocking": False,
+            "set_grad_to_none": True
+        }
+    )
+
+    trainer.run()
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
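
A note on the sampler used above: ``BucketedBatchSampler`` groups samples of similar length into the same batch, and because the test passes the constant length list ``[3] * len(dataset)``, bucketing degenerates into plain shuffled batching, keeping the test focused on the sampler-replacement logic rather than on bucketing itself. A rough sketch of the bucketing idea follows; it is an illustration, not fastNLP's implementation.

    import random

    def bucketed_batches(lengths, batch_size, num_batch_per_bucket=10):
        # Sort indices by length, cut the order into buckets of
        # batch_size * num_batch_per_bucket, shuffle inside each bucket,
        # then emit batches and shuffle the batch order as well.
        order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        bucket_size = batch_size * num_batch_per_bucket
        batches = []
        for start in range(0, len(order), bucket_size):
            chunk = order[start:start + bucket_size]
            random.shuffle(chunk)
            batches.extend(chunk[i:i + batch_size] for i in range(0, len(chunk), batch_size))
        random.shuffle(batches)
        return batches

    # With constant lengths, as in the test above, this is just shuffled batching.
    print(bucketed_batches([3] * 10, batch_size=4))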