From cd0957fb5b7f2646ed1a2e930ea76346cf6e4356 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Tue, 14 Jun 2022 13:19:56 +0000
Subject: [PATCH] Follow up on the paddle and jittor changes to the
 set_dist_repro_dataloader function
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../drivers/jittor_driver/jittor_driver.py    | 24 +----
 .../drivers/jittor_driver/single_device.py    |  4 +-
 fastNLP/core/drivers/paddle_driver/fleet.py   |  2 +
 .../drivers/paddle_driver/paddle_driver.py    | 24 +----
 .../drivers/paddle_driver/single_device.py    | 31 +++---
 fastNLP/core/drivers/paddle_driver/utils.py   | 13 ++-
 .../core/drivers/paddle_driver/test_fleet.py  | 95 ++++++++++++++++++-
 7 files changed, 136 insertions(+), 57 deletions(-)

diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
index 63ac6ec4..c2e338bb 100644
--- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py
+++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
@@ -138,26 +138,12 @@ class JittorDriver(Driver):
             num_consumed_batches = states.pop('num_consumed_batches')
             if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
                 sampler_states = sampler.state_dict()
-                # num_consumed_samples needs special handling here: because the DataLoader prefetches data,
-                # using the sampler's num_consumed_samples directly would overcount what was actually consumed.
-                num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
-                if num_consumed_samples_array is not None:
-                    if isinstance(sampler, ReproducibleSampler):  # for a sampler, batch_size has to be taken into account.
-                        if dataloader_args.batch_size is not None:
-                            num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                        else:  # batch_size may be None; then we can only lose some precision.
-                            logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                                     "it may cause missing some samples when reload.")
-                            num_consumed_batches = sampler_states['num_consumed_samples']
-                    sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
-                    assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+                if dataloader_args.batch_size is not None:
+                    sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                                * num_consumed_batches
                 else:
-                    if dataloader_args.batch_size is not None:
-                        sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
-                                                                    * num_consumed_batches
-                    else:
-                        logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                                 "it may cause missing some samples when reload.")
+                    logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                             "it may cause missing some samples when reload.")

                 states['sampler_states'] = sampler_states
             else:
diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py
index 4e9b3447..386f8694 100644
--- a/fastNLP/core/drivers/jittor_driver/single_device.py
+++ b/fastNLP/core/drivers/jittor_driver/single_device.py
@@ -118,14 +118,14 @@ class JittorSingleDriver(JittorDriver):
             if args.sampler is None:
                 sampler = RandomSampler(args.dataset, args.shuffle)
                 return replace_sampler(dataloader, sampler)
-            elif isinstance(args.sampler, JittorRandomSampler):
+            elif type(args.sampler) is JittorRandomSampler:
                 if getattr(args.sampler, '_num_samples', None) is None \
                         and getattr(args.sampler, 'rep', False) is False:
                     # already random and not customized, so simply replace it.
                     sampler = RandomSampler(args.sampler.dataset, shuffle=True)
                     logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.")
                     return replace_sampler(dataloader, sampler)
-            elif isinstance(args.sampler, JittorSequentialSampler):
+            elif type(args.sampler) is JittorSequentialSampler:
                 # needs to be replaced with a non-shuffling sampler.
                 sampler = RandomSampler(args.sampler.dataset, shuffle=False)
                 logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.")
diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index d19da9fe..9344f515 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -73,6 +73,7 @@ from .utils import (
     _FleetWrappingModel,
     replace_sampler,
     replace_batch_sampler,
+    _check_dataloader_args_for_distributed
 )

 from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object
@@ -453,6 +454,7 @@ class PaddleFleetDriver(PaddleDriver):
                 )
                 return replace_sampler(dataloader, sampler)
             else:
+                _check_dataloader_args_for_distributed(args, controller='Trainer')
                 sampler = RandomSampler(
                     dataset=args.dataset,
                     shuffle=args.shuffle,
diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
index 4527f1ed..6ef0aaae 100644
--- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py
+++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
@@ -222,26 +222,12 @@ class PaddleDriver(Driver):
             num_consumed_batches = states.pop("num_consumed_batches")
             if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
                 sampler_states = sampler.state_dict()
-                # If present, num_consumed_samples needs special handling: because the DataLoader prefetches data,
-                # using the sampler's num_consumed_samples directly would overcount what was actually consumed.
-                num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
-                if num_consumed_samples_array is not None:
-                    if isinstance(sampler, ReproducibleSampler):  # for a sampler, batch_size has to be taken into account.
-                        if dataloader_args.batch_size is not None:
-                            num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                        else:  # batch_size may be None; then we can only lose some precision.
-                            logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                                     "it may cause missing some samples when reload.")
-                            num_consumed_batches = sampler_states['num_consumed_samples']
-                    sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
-                    assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+                if dataloader_args.batch_size is not None:
+                    sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                                * num_consumed_batches
                 else:
-                    if dataloader_args.batch_size is not None:
-                        sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
-                                                                    * num_consumed_batches
-                    else:
-                        logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                                 "it may cause missing some samples when reload.")
+                    logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                             "it may cause missing some samples when reload.")
             else:
                 raise RuntimeError(
                     "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py
index ba404814..4105bf20 100644
--- a/fastNLP/core/drivers/paddle_driver/single_device.py
+++ b/fastNLP/core/drivers/paddle_driver/single_device.py
@@ -26,6 +26,11 @@ if _NEED_IMPORT_PADDLE:
     import paddle
     from paddle import DataParallel
     from paddle.fluid.reader import _DatasetKind
+    from paddle.io import (
+        RandomSampler as PaddleRandomSampler,
+        SequenceSampler as PaddleSequentialSampler,
+        BatchSampler as PaddleBatchSampler,
+    )

 __all__ = [
     "PaddleSingleDriver",
@@ -122,19 +127,21 @@ class PaddleSingleDriver(PaddleDriver):
                     return replace_sampler(dataloader, sampler)

         if reproducible:
-            if isinstance(args.sampler, paddle.io.RandomSampler):
-                if getattr(args.sampler, '_num_samples', None) is None \
-                    and getattr(args.sampler, 'replacements', False) is False \
-                    and getattr(args.sampler, 'generator', None) is None:
-                    # already random and not customized, so simply replace it.
-                    sampler = RandomSampler(args.sampler.data_source, shuffle=True)
-                    logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
+            if type(args.batch_sampler) is PaddleBatchSampler:
+                if type(args.sampler) is PaddleRandomSampler:
+                    if isinstance(args.sampler, PaddleRandomSampler):
+                        if getattr(args.sampler, '_num_samples', None) is None \
+                            and getattr(args.sampler, 'replacements', False) is False \
+                            and getattr(args.sampler, 'generator', None) is None:
+                            # already random and not customized, so simply replace it.
+                            sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+                            logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
+                            return replace_sampler(dataloader, sampler)
+                elif type(args.sampler) is PaddleSequentialSampler:
+                    # needs to be replaced with a non-shuffling sampler.
+                    sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+                    logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.")
                     return replace_sampler(dataloader, sampler)
-            elif isinstance(args.sampler, paddle.io.SequenceSampler):
-                # needs to be replaced with a non-shuffling sampler.
-                sampler = RandomSampler(args.sampler.data_source, shuffle=False)
-                logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.")
-                return replace_sampler(dataloader, sampler)
             batch_sampler = ReproduceBatchSampler(
                 batch_sampler=args.batch_sampler,
                 batch_size=args.batch_size,
diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py
index 9f35cf2a..1191b60c 100644
--- a/fastNLP/core/drivers/paddle_driver/utils.py
+++ b/fastNLP/core/drivers/paddle_driver/utils.py
@@ -23,7 +23,7 @@ if _NEED_IMPORT_PADDLE:
     import paddle
     from paddle import nn
     from paddle.nn import Layer
-    from paddle.io import DataLoader, BatchSampler
+    from paddle.io import DataLoader, BatchSampler, RandomSampler, SequenceSampler
     from paddle.amp import auto_cast, GradScaler
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Layer
@@ -249,3 +249,14 @@ def optimizer_state_to_device(state, device):
         else:
             new_state[name] = param
     return new_state
+
+def _check_dataloader_args_for_distributed(args, controller='Trainer'):
+    if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {RandomSampler,
+                                                                                   SequenceSampler}):
+        mode = 'training' if controller == 'Trainer' else 'evaluation'
+        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
+        raise TypeError(f"Using a customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
+                        f"unpredictable problems, because fastNLP will substitute the dataloader's sampler with "
+                        f"``{substitution}``. A customized sampler should be set up for distributed running "
+                        f"before ``{controller}`` is initialized, and the parameter "
+                        f"``use_dist_sampler`` of ``{controller}`` should then be set to ``False``.")
diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py
index b303249c..80d494da 100644
--- a/tests/core/drivers/paddle_driver/test_fleet.py
+++ b/tests/core/drivers/paddle_driver/test_fleet.py
@@ -11,11 +11,12 @@ from fastNLP.core.samplers import (
 )
 from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset
-from tests.helpers.utils import magic_argv_env_context
+from tests.helpers.utils import magic_argv_env_context, recover_logger
 from fastNLP.envs.distributed import rank_zero_rm
 from fastNLP import prepare_paddle_dataloader
 from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP import logger
 if _NEED_IMPORT_PADDLE:
     import paddle
     import paddle.distributed as dist
@@ -532,7 +533,6 @@ class TestSetDistReproDataloader:
         num_samples = 200
         dataset = PaddleNormalXYDataset(num_samples)
         dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last)
-        model = PaddleNormalModel_Classification_1(10, 32)

         self.driver.setup()
         dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)
@@ -581,8 +581,6 @@ class TestSetDistReproDataloader:
         sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last,
                                        shuffle=shuffle, num_batch_per_bucket=2)
         dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler)
-        model = PaddleNormalModel_Classification_1(10, 32)
-        device = [0, 1]

         self.driver.setup()
         dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)
@@ -619,6 +617,95 @@ class TestSetDistReproDataloader:
         finally:
             dist.barrier()

+    @magic_argv_env_context
+    @recover_logger
+    @pytest.mark.parametrize("inherit", ([True, False]))
+    def test_customized_batch_sampler_dataloader(self, inherit):
+        try:
+            logger.set_stdout('raw', level='info')
+            # check that set_dist_repro_dataloader behaves correctly with a customized batch_sampler
+            num_samples = 10
+            dataset = PaddleNormalXYDataset(num_samples)
+            if inherit:
+                class BatchSampler(paddle.io.BatchSampler):
+                    def __init__(self, dataset, batch_size):
+                        self.dataset = dataset
+                        self.batch_size = batch_size
+
+                    def __iter__(self):
+                        indices = list(range(len(dataset)))
+                        for i in range(len(self)):
+                            start = i * self.batch_size
+                            end = (i + 1) * self.batch_size
+                            yield indices[start:end]
+
+                    def __len__(self):
+                        return (len(self.dataset)+self.batch_size-1)//self.batch_size
+            else:
+                class BatchSampler:
+                    def __init__(self, dataset, batch_size):
+                        self.dataset = dataset
+                        self.batch_size = batch_size
+
+                    def __iter__(self):
+                        indices = list(range(len(dataset)))
+                        for i in range(len(self)):
+                            start = i * self.batch_size
+                            end = (i + 1) * self.batch_size
+                            yield indices[start:end]
+
+                    def __len__(self):
+                        return (len(self.dataset)+self.batch_size-1)//self.batch_size
+
+            dl = prepare_paddle_dataloader(dataset, batch_sampler=BatchSampler(dataset, batch_size=4))
+            self.driver.setup()
+            with pytest.raises(TypeError):
+                dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
+        finally:
+            pass
+
+    @magic_argv_env_context
+    @recover_logger
+    @pytest.mark.parametrize("inherit", ([True, False]))
+    def test_customized_sampler_dataloader(self, inherit):
+        try:
+            logger.set_stdout('raw', level='info')
+            # check that set_dist_repro_dataloader behaves correctly with a customized sampler
+            num_samples = 10
+            dataset = PaddleNormalXYDataset(num_samples)
+            if inherit:
+                class Sampler(paddle.io.RandomSampler):
+                    def __init__(self, dataset, batch_size):
+                        self.dataset = dataset
+                        self.batch_size = batch_size
+
+                    def __iter__(self):
+                        indices = list(range(len(dataset)))
+                        return iter(indices)
+
+                    def __len__(self):
+                        return len(self.dataset)
+            else:
+                class Sampler:
+                    def __init__(self, dataset, batch_size):
+                        self.dataset = dataset
+                        self.batch_size = batch_size
+
+                    def __iter__(self):
+                        indices = list(range(len(dataset)))
+                        return iter(indices)
+
+                    def __len__(self):
+                        return len(self.dataset)
+
+            dl = prepare_paddle_dataloader(dataset, sampler=Sampler(dataset, batch_size=4))
+            self.driver.setup()
+            # TODO: this should raise
+            with pytest.raises(TypeError):
+                dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
+        finally:
+            pass
+
 ############################################################################
 #