@@ -609,9 +609,15 @@ class TorchDDPDriver(TorchDriver):
        # evaluator
        elif dist == "unrepeatdist":
            args = self.get_dataloader_args(dataloader)
            if type(args.batch_sampler) is not BatchSampler:
                # TODO: this is meant to detect a user-customized ``batch_sampler``; the check may need refinement
                logger.warning("Note that you are using a customized ``batch_sampler`` in the evaluate or "
                               "train dataloader while testing ``overfit_batches``, which may cause the "
                               "data for distributed evaluation to be repeated across ranks.")
            if isinstance(args.sampler, ReproducibleSampler):
                sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
            elif not isinstance(args.sampler, UnrepeatedSampler):
                # TODO: sidestep the ``batch_sampler`` case
                _check_dataloader_args_for_distributed(args, controller='Evaluator')
                sampler = UnrepeatedSequentialSampler(
                    dataset=args.dataset
@@ -622,6 +628,7 @@ class TorchDDPDriver(TorchDriver):
                    num_replicas=self.world_size,
                    rank=self.global_rank
                )
            # TODO: for now we uniformly substitute a ``BatchSampler`` here
            batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
            return replace_batch_sampler(dataloader, batch_sampler)
        else:
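For context on the `unrepeatdist` branch above: an "unrepeated" sampler shards the dataset so that each rank evaluates a disjoint slice, with no padding or repetition. A minimal sketch of that idea follows; it is not fastNLP's actual `UnrepeatedSequentialSampler` implementation, and the strided partitioning scheme is an assumption (the real sampler may shard differently):

```python
# Minimal sketch of unrepeated sharding for distributed evaluation.
# Each rank takes a strided, disjoint slice of the indices; ranks may
# receive slightly different counts, which is acceptable for evaluation
# since no gradient synchronization is involved.
def unrepeated_indices(dataset_len: int, num_replicas: int, rank: int) -> list:
    return list(range(rank, dataset_len, num_replicas))

# dataset_len=10, num_replicas=3:
#   rank 0 -> [0, 3, 6, 9]; rank 1 -> [1, 4, 7]; rank 2 -> [2, 5, 8]
# Every index appears exactly once across all ranks.
```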
@@ -14,7 +14,7 @@ from fastNLP.envs import (
    FASTNLP_BACKEND_LAUNCH,
    FASTNLP_GLOBAL_SEED,
)
from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler
from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.utils import auto_param_call, apply_to_collection
from fastNLP.core.log import logger

@@ -308,15 +308,26 @@ def optimizer_state_to_device(state, device):
def _check_dataloader_args_for_distributed(args, controller='Trainer'):
    if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler,
                                                                                        TorchSequentialSampler}):
        mode = 'training' if controller == 'Trainer' else 'evaluation'
        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
""" | |||
检查 dataloader 的 sampler 情况,如果用户替换了自己定制的 sampler ,为了防止 | |||
在分布式训练中出现错误会报错。 | |||
""" | |||
error_flag = (type(args.sampler) not in {TorchRandomSampler, TorchSequentialSampler}) | |||
if controller == 'Trainer': | |||
mode = 'training' | |||
substitution = 'fastNLP.RandomSampler' | |||
error_flag = (type(args.batch_sampler) != TorchBatchSampler) or error_flag | |||
else: # Evaluator | |||
mode = 'evaluation' | |||
substitution = 'fastNLP.UnrepeatedSequentialSampler' | |||
if error_flag: | |||
raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | |||
f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | |||
f"``{substitution}``. The customized sampler should set for distributed running " | |||
f"before initializing ``{controller}`` , and then set the " | |||
f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") | |||
f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``." | |||
f"\n Current batch_sampler: {type(args.batch_sampler)}" | |||
f"\n Current sampler: {type(args.sampler)}") | |||
def _create_default_config(
    zero_optimization: bool = True,

@@ -7,6 +7,7 @@ from dataclasses import dataclass
from typing import Any
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.samplers import BucketedBatchSampler, RandomSampler
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback

@@ -378,4 +379,96 @@ def test_trainer_w_evaluator_overfit_torch(
    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)

    if dist.is_initialized():
        dist.destroy_process_group()
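A note on the teardown pattern repeated across these tests: guarding `destroy_process_group()` with `is_initialized()` keeps single-device runs, which never initialize a process group, from raising. The same pattern in isolation:

```python
import torch.distributed as dist

def teardown_dist():
    # Safe teardown: only destroy the default process group if one exists.
    if dist.is_initialized():
        dist.destroy_process_group()
```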
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
@pytest.mark.parametrize("train_sampler", ["batch_sampler", "sampler"])
@pytest.mark.parametrize("eval_sampler", ["batch_sampler", "sampler"])
@pytest.mark.parametrize("overfit_batches", [-1, 0])
@magic_argv_env_context
def test_trainer_w_evaluator_w_samplers(
    driver,
    device,
    train_sampler,
    eval_sampler,
    overfit_batches,
):
    """
    Test the legal cases where the dataloader uses a customized ``batch_sampler`` or ``sampler``.
    """
    model = TorchNormalModel_Classification_1(
        num_labels=ArgMaxDatasetConfig.num_labels,
        feature_dimension=ArgMaxDatasetConfig.feature_dimension
    )
    optimizers = SGD(model.parameters(), lr=0.001)
    metrics = {"acc": Accuracy()}
    dataset = TorchArgMaxDataset(
        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
        data_num=ArgMaxDatasetConfig.data_num,
        seed=ArgMaxDatasetConfig.seed
    )
    if train_sampler == "batch_sampler":
        train_dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=BucketedBatchSampler(
                dataset, [3] * len(dataset), ArgMaxDatasetConfig.batch_size
            )
        )
    elif train_sampler == "sampler":
        train_dataloader = DataLoader(
            dataset=dataset,
            batch_size=ArgMaxDatasetConfig.batch_size,
            sampler=RandomSampler(dataset)
        )
    else:
        train_dataloader = DataLoader(
            dataset=dataset,
            batch_size=ArgMaxDatasetConfig.batch_size,
            shuffle=True,
        )
    if eval_sampler == "batch_sampler":
        eval_dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=BucketedBatchSampler(
                dataset, [3] * len(dataset), ArgMaxDatasetConfig.batch_size
            )
        )
    elif eval_sampler == "sampler":
        eval_dataloader = DataLoader(
            dataset=dataset,
            sampler=RandomSampler(dataset)
        )
    else:
        eval_dataloader = DataLoader(
            dataset=dataset,
            batch_size=ArgMaxDatasetConfig.batch_size,
            shuffle=True,
        )
    trainer = Trainer(
        model=model,
        driver=driver,
        device=device,
        overfit_batches=overfit_batches,
        optimizers=optimizers,
        train_dataloader=train_dataloader,
        evaluate_dataloaders={"dl": eval_dataloader},
        metrics=metrics,
        n_epochs=2,
        output_from_new_proc="all",
        evaluate_every=-1,
        torch_kwargs={
            "non_blocking": False,
            "set_grad_to_none": True
        }
    )

    trainer.run()

    if dist.is_initialized():
        dist.destroy_process_group()
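For reference on the `BucketedBatchSampler` used in the test: bucketed batching groups samples of similar length into the same batch to reduce padding waste. A minimal sketch of the idea (not fastNLP's actual implementation, which may also shuffle within and across buckets):

```python
# Minimal sketch of length-bucketed batching: sort indices by sample
# length, then cut consecutive batches so each batch holds samples of
# similar length.
def bucketed_batches(lengths, batch_size):
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

# In the test above every sample is given length 3 ([3] * len(dataset)),
# so bucketing degenerates to a single bucket of equal-length samples.
print(bucketed_batches([3, 3, 3, 3, 3], 2))  # -> [[0, 1], [2, 3], [4]]
```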