Browse Source

跟进paddle jittor 关于 set_dist_repro_dataloader函数中的修改

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
cd0957fb5b
7 changed files with 136 additions and 57 deletions
  1. +5
    -19
      fastNLP/core/drivers/jittor_driver/jittor_driver.py
  2. +2
    -2
      fastNLP/core/drivers/jittor_driver/single_device.py
  3. +2
    -0
      fastNLP/core/drivers/paddle_driver/fleet.py
  4. +5
    -19
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  5. +19
    -12
      fastNLP/core/drivers/paddle_driver/single_device.py
  6. +12
    -1
      fastNLP/core/drivers/paddle_driver/utils.py
  7. +91
    -4
      tests/core/drivers/paddle_driver/test_fleet.py

+ 5
- 19
fastNLP/core/drivers/jittor_driver/jittor_driver.py View File

@@ -138,26 +138,12 @@ class JittorDriver(Driver):
num_consumed_batches = states.pop('num_consumed_batches') num_consumed_batches = states.pop('num_consumed_batches')
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
sampler_states = sampler.state_dict() sampler_states = sampler.state_dict()
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples
# 会造成多余实际消耗的问题。因为
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
if num_consumed_samples_array is not None:
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。
if dataloader_args.batch_size is not None:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
else: # 有可能 batch_size 为 None,就只有损失精度了
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
if dataloader_args.batch_size is not None:
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
* num_consumed_batches
else: else:
if dataloader_args.batch_size is not None:
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
* num_consumed_batches
else:
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")


states['sampler_states'] = sampler_states states['sampler_states'] = sampler_states
else: else:


+ 2
- 2
fastNLP/core/drivers/jittor_driver/single_device.py View File

@@ -118,14 +118,14 @@ class JittorSingleDriver(JittorDriver):
if args.sampler is None: if args.sampler is None:
sampler = RandomSampler(args.dataset, args.shuffle) sampler = RandomSampler(args.dataset, args.shuffle)
return replace_sampler(dataloader, sampler) return replace_sampler(dataloader, sampler)
elif isinstance(args.sampler, JittorRandomSampler):
elif type(args.sampler) is JittorRandomSampler:
if getattr(args.sampler, '_num_samples', None) is None \ if getattr(args.sampler, '_num_samples', None) is None \
and getattr(args.sampler, 'rep', False) is False: and getattr(args.sampler, 'rep', False) is False:
# 如果本来就是随机的,并且没有定制,直接替换掉吧。 # 如果本来就是随机的,并且没有定制,直接替换掉吧。
sampler = RandomSampler(args.sampler.dataset, shuffle=True) sampler = RandomSampler(args.sampler.dataset, shuffle=True)
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler) return replace_sampler(dataloader, sampler)
elif isinstance(args.sampler, JittorSequentialSampler):
elif type(args.sampler) is JittorSequentialSampler:
# 需要替换为不要 shuffle 的。 # 需要替换为不要 shuffle 的。
sampler = RandomSampler(args.sampler.dataset, shuffle=False) sampler = RandomSampler(args.sampler.dataset, shuffle=False)
logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.")


+ 2
- 0
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -73,6 +73,7 @@ from .utils import (
_FleetWrappingModel, _FleetWrappingModel,
replace_sampler, replace_sampler,
replace_batch_sampler, replace_batch_sampler,
_check_dataloader_args_for_distributed
) )
from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object


@@ -453,6 +454,7 @@ class PaddleFleetDriver(PaddleDriver):
) )
return replace_sampler(dataloader, sampler) return replace_sampler(dataloader, sampler)
else: else:
_check_dataloader_args_for_distributed(args, controller='Trainer')
sampler = RandomSampler( sampler = RandomSampler(
dataset=args.dataset, dataset=args.dataset,
shuffle=args.shuffle, shuffle=args.shuffle,


+ 5
- 19
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -222,26 +222,12 @@ class PaddleDriver(Driver):
num_consumed_batches = states.pop("num_consumed_batches") num_consumed_batches = states.pop("num_consumed_batches")
if hasattr(sampler, "state_dict") and callable(sampler.state_dict): if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
sampler_states = sampler.state_dict() sampler_states = sampler.state_dict()
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples
# 会造成多余实际消耗的问题。
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
if num_consumed_samples_array is not None:
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。
if dataloader_args.batch_size is not None:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
else: # 有可能 batch_size 为 None,就只有损失精度了
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
if dataloader_args.batch_size is not None:
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
* num_consumed_batches
else: else:
if dataloader_args.batch_size is not None:
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
* num_consumed_batches
else:
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
else: else:
raise RuntimeError( raise RuntimeError(
"The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")


+ 19
- 12
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -26,6 +26,11 @@ if _NEED_IMPORT_PADDLE:
import paddle import paddle
from paddle import DataParallel from paddle import DataParallel
from paddle.fluid.reader import _DatasetKind from paddle.fluid.reader import _DatasetKind
from paddle.io import (
RandomSampler as PaddleRandomSampler,
SequenceSampler as PaddleSequenialSampler,
BatchSampler as PaddleBatchSampler,
)


__all__ = [ __all__ = [
"PaddleSingleDriver", "PaddleSingleDriver",
@@ -122,19 +127,21 @@ class PaddleSingleDriver(PaddleDriver):
return replace_sampler(dataloader, sampler) return replace_sampler(dataloader, sampler)


if reproducible: if reproducible:
if isinstance(args.sampler, paddle.io.RandomSampler):
if getattr(args.sampler, '_num_samples', None) is None \
and getattr(args.sampler, 'replacements', False) is False \
and getattr(args.sampler, 'generator', None) is None:
# 如果本来就是随机的,并且没有定制,直接替换掉。
sampler = RandomSampler(args.sampler.data_source, shuffle=True)
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
if type(args.batch_sampler) is PaddleBatchSampler:
if type(args.sampler) is PaddleRandomSampler:
if isinstance(args.sampler, PaddleRandomSampler):
if getattr(args.sampler, '_num_samples', None) is None \
and getattr(args.sampler, 'replacements', False) is False \
and getattr(args.sampler, 'generator', None) is None:
# 如果本来就是随机的,并且没有定制,直接替换掉。
sampler = RandomSampler(args.sampler.data_source, shuffle=True)
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
elif type(args.sampler) is PaddleSequenialSampler:
# 需要替换为不要 shuffle 的。
sampler = RandomSampler(args.sampler.data_source, shuffle=False)
logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler) return replace_sampler(dataloader, sampler)
elif isinstance(args.sampler, paddle.io.SequenceSampler):
# 需要替换为不要 shuffle 的。
sampler = RandomSampler(args.sampler.data_source, shuffle=False)
logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
batch_sampler = ReproduceBatchSampler( batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler, batch_sampler=args.batch_sampler,
batch_size=args.batch_size, batch_size=args.batch_size,


+ 12
- 1
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -23,7 +23,7 @@ if _NEED_IMPORT_PADDLE:
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import Layer from paddle.nn import Layer
from paddle.io import DataLoader, BatchSampler
from paddle.io import DataLoader, BatchSampler, RandomSampler, SequenceSampler
from paddle.amp import auto_cast, GradScaler from paddle.amp import auto_cast, GradScaler
else: else:
from fastNLP.core.utils.dummy_class import DummyClass as Layer from fastNLP.core.utils.dummy_class import DummyClass as Layer
@@ -249,3 +249,14 @@ def optimizer_state_to_device(state, device):
else: else:
new_state[name] = param new_state[name] = param
return new_state return new_state

def _check_dataloader_args_for_distributed(args, controller='Trainer'):
    """
    Reject dataloaders that carry a customized ``batch_sampler`` or ``sampler``
    before fastNLP substitutes its own distributed sampler.

    The ``type(...) is`` / ``type(...) in`` checks are intentional exact-type
    tests: a user-defined subclass counts as a customized sampler and must be
    rejected, because silently replacing it would discard the user's logic.

    :param args: namespace exposing the dataloader's ``batch_sampler`` and ``sampler``
    :param controller: ``'Trainer'`` or ``'Evaluator'`` — only changes the wording
        of the error message
    :raises TypeError: when either sampler is not paddle's stock implementation
    """
    stock_batch_sampler = type(args.batch_sampler) is BatchSampler
    stock_sampler = type(args.sampler) in {RandomSampler, SequenceSampler}
    if stock_batch_sampler and stock_sampler:
        return

    if controller == 'Trainer':
        mode, substitution = 'training', 'fastNLP.RandomSampler'
    else:
        mode, substitution = 'evaluation', 'fastNLP.UnrepeatedSequentialSampler'
    raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
                    f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
                    f"``{substitution}``. The customized sampler should set for distributed running "
                    f"before initializing ``{controller}`` , and then set the "
                    f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")

+ 91
- 4
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -11,11 +11,12 @@ from fastNLP.core.samplers import (
) )
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset
from tests.helpers.utils import magic_argv_env_context
from tests.helpers.utils import magic_argv_env_context, recover_logger
from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.distributed import rank_zero_rm
from fastNLP import prepare_paddle_dataloader from fastNLP import prepare_paddle_dataloader
from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP import logger
if _NEED_IMPORT_PADDLE: if _NEED_IMPORT_PADDLE:
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
@@ -532,7 +533,6 @@ class TestSetDistReproDataloader:
num_samples = 200 num_samples = 200
dataset = PaddleNormalXYDataset(num_samples) dataset = PaddleNormalXYDataset(num_samples)
dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last)
model = PaddleNormalModel_Classification_1(10, 32)
self.driver.setup() self.driver.setup()
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)


@@ -581,8 +581,6 @@ class TestSetDistReproDataloader:
sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last,
shuffle=shuffle, num_batch_per_bucket=2) shuffle=shuffle, num_batch_per_bucket=2)
dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler) dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler)
model = PaddleNormalModel_Classification_1(10, 32)
device = [0, 1]
self.driver.setup() self.driver.setup()
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)


@@ -619,6 +617,95 @@ class TestSetDistReproDataloader:
finally: finally:
dist.barrier() dist.barrier()


@magic_argv_env_context
@recover_logger
@pytest.mark.parametrize("inherit", ([True, False]))
def test_customized_batch_sampler_dataloader(self, inherit):
    """
    ``set_dist_repro_dataloader`` must raise ``TypeError`` when the dataloader
    carries a customized ``batch_sampler`` — both when it subclasses paddle's
    ``BatchSampler`` (``inherit=True``) and when it is an unrelated class —
    instead of silently substituting the user's sampler.
    """
    try:
        logger.set_stdout('raw', level='info')
        num_samples = 10
        dataset = PaddleNormalXYDataset(num_samples)
        if inherit:
            class BatchSampler(paddle.io.BatchSampler):
                def __init__(self, dataset, batch_size):
                    self.dataset = dataset
                    self.batch_size = batch_size

                def __iter__(self):
                    indices = list(range(len(dataset)))
                    for i in range(len(self)):
                        start = i * self.batch_size
                        end = (i + 1) * self.batch_size
                        # yield (not return): a return here would stop iteration
                        # after the first batch
                        yield indices[start:end]

                def __len__(self):
                    return (len(self.dataset) + self.batch_size - 1) // self.batch_size
        else:
            class BatchSampler:
                def __init__(self, dataset, batch_size):
                    self.dataset = dataset
                    self.batch_size = batch_size

                def __iter__(self):
                    indices = list(range(len(dataset)))
                    for i in range(len(self)):
                        start = i * self.batch_size
                        end = (i + 1) * self.batch_size
                        # yield (not return): a return here would stop iteration
                        # after the first batch
                        yield indices[start:end]

                def __len__(self):
                    return (len(self.dataset) + self.batch_size - 1) // self.batch_size

        dl = prepare_paddle_dataloader(dataset, batch_sampler=BatchSampler(dataset, batch_size=4))
        self.driver.setup()
        # customized batch_sampler + distributed replacement is unsupported
        with pytest.raises(TypeError):
            dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
    finally:
        pass

@magic_argv_env_context
@recover_logger
@pytest.mark.parametrize("inherit", ([True, False]))
def test_customized_sampler_dataloader(self, inherit):
    # Verify that set_dist_repro_dataloader rejects a customized ``sampler`` —
    # both a subclass of paddle's RandomSampler (inherit=True) and a completely
    # unrelated class — instead of silently substituting it for distributed runs.
    try:
        logger.set_stdout('raw', level='info')
        num_samples = 10
        dataset = PaddleNormalXYDataset(num_samples)
        if inherit:
            class Sampler(paddle.io.RandomSampler):
                def __init__(self, dataset, batch_size):
                    self.dataset = dataset
                    self.batch_size = batch_size

                def __iter__(self):
                    indices = list(range(len(dataset)))
                    return iter(indices)

                def __len__(self):
                    return len(self.dataset)
        else:
            class Sampler:
                def __init__(self, dataset, batch_size):
                    self.dataset = dataset
                    self.batch_size = batch_size

                def __iter__(self):
                    indices = list(range(len(dataset)))
                    return iter(indices)

                def __len__(self):
                    return len(self.dataset)

        dl = prepare_paddle_dataloader(dataset, sampler=Sampler(dataset, batch_size=4))
        self.driver.setup()
        # a customized sampler cannot be safely replaced, so the driver is
        # expected to raise TypeError here
        with pytest.raises(TypeError):
            dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
    finally:
        pass



############################################################################ ############################################################################
# #


Loading…
Cancel
Save