
Follow up on the paddle and jittor changes to the set_dist_repro_dataloader function

tags/v1.0.0alpha
x54-729 2 years ago
parent commit cd0957fb5b
7 changed files with 136 additions and 57 deletions

1. fastNLP/core/drivers/jittor_driver/jittor_driver.py (+5 -19)
2. fastNLP/core/drivers/jittor_driver/single_device.py (+2 -2)
3. fastNLP/core/drivers/paddle_driver/fleet.py (+2 -0)
4. fastNLP/core/drivers/paddle_driver/paddle_driver.py (+5 -19)
5. fastNLP/core/drivers/paddle_driver/single_device.py (+19 -12)
6. fastNLP/core/drivers/paddle_driver/utils.py (+12 -1)
7. tests/core/drivers/paddle_driver/test_fleet.py (+91 -4)

fastNLP/core/drivers/jittor_driver/jittor_driver.py (+5 -19)

@@ -138,26 +138,12 @@ class JittorDriver(Driver):
         num_consumed_batches = states.pop('num_consumed_batches')
         if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
             sampler_states = sampler.state_dict()
-            # num_consumed_samples needs special handling here: the DataLoader prefetches, so reading
-            # num_consumed_samples straight from the sampler would over-count what was actually consumed.
-            num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
-            if num_consumed_samples_array is not None:
-                if isinstance(sampler, ReproducibleSampler):  # for a plain sampler, batch_size must be factored in
-                    if dataloader_args.batch_size is not None:
-                        num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                    else:  # batch_size may be None, in which case some precision is lost
-                        logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                                 "it may cause missing some samples when reload.")
-                        num_consumed_batches = sampler_states['num_consumed_samples']
-                sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
-                assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+            if dataloader_args.batch_size is not None:
+                sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                         * num_consumed_batches
             else:
-                if dataloader_args.batch_size is not None:
-                    sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
-                                                             * num_consumed_batches
-                else:
-                    logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                             "it may cause missing some samples when reload.")
+                logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                         "it may cause missing some samples when reload.")

             states['sampler_states'] = sampler_states
         else:

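For reference, here is the arithmetic the simplified save path now performs, as a standalone hedged sketch; only the formula comes from the diff above, and the numeric values are hypothetical:

    # Because the DataLoader prefetches, the sampler's own counter over-reports;
    # the resume point is therefore reconstructed from whole batches instead.
    num_replicas = 2          # hypothetical: number of distributed workers
    batch_size = 8            # hypothetical: dataloader batch size
    num_consumed_batches = 5  # batches actually consumed by the training loop

    num_consumed_samples = num_replicas * batch_size * num_consumed_batches
    assert num_consumed_samples == 80  # 2 workers * 8 samples * 5 batches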

fastNLP/core/drivers/jittor_driver/single_device.py (+2 -2)

@@ -118,14 +118,14 @@ class JittorSingleDriver(JittorDriver):
             if args.sampler is None:
                 sampler = RandomSampler(args.dataset, args.shuffle)
                 return replace_sampler(dataloader, sampler)
-            elif isinstance(args.sampler, JittorRandomSampler):
+            elif type(args.sampler) is JittorRandomSampler:
                 if getattr(args.sampler, '_num_samples', None) is None \
                         and getattr(args.sampler, 'rep', False) is False:
                     # It is already random and not customized, so replace it directly.
                     sampler = RandomSampler(args.sampler.dataset, shuffle=True)
                     logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.")
                     return replace_sampler(dataloader, sampler)
-            elif isinstance(args.sampler, JittorSequentialSampler):
+            elif type(args.sampler) is JittorSequentialSampler:
                 # Replace it with a sampler that does not shuffle.
                 sampler = RandomSampler(args.sampler.dataset, shuffle=False)
                 logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.")

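Both drivers switch from isinstance checks to exact type checks so that user subclasses of the framework samplers are no longer silently swapped out. A minimal illustration of the difference, using stand-in classes (not fastNLP or jittor code):

    class StockSampler:                  # stands in for a framework sampler
        pass

    class CustomSampler(StockSampler):   # a user subclass with custom behaviour
        pass

    s = CustomSampler()
    print(isinstance(s, StockSampler))   # True  -> would wrongly allow replacement
    print(type(s) is StockSampler)       # False -> the subclass is left untouched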

fastNLP/core/drivers/paddle_driver/fleet.py (+2 -0)

@@ -73,6 +73,7 @@ from .utils import (
     _FleetWrappingModel,
     replace_sampler,
     replace_batch_sampler,
+    _check_dataloader_args_for_distributed
 )
 from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object

@@ -453,6 +454,7 @@ class PaddleFleetDriver(PaddleDriver):
                 )
                 return replace_sampler(dataloader, sampler)
             else:
+                _check_dataloader_args_for_distributed(args, controller='Trainer')
                 sampler = RandomSampler(
                     dataset=args.dataset,
                     shuffle=args.shuffle,


fastNLP/core/drivers/paddle_driver/paddle_driver.py (+5 -19)

@@ -222,26 +222,12 @@ class PaddleDriver(Driver):
         num_consumed_batches = states.pop("num_consumed_batches")
         if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
             sampler_states = sampler.state_dict()
-            # If present, num_consumed_samples needs special handling: the DataLoader prefetches, so reading
-            # num_consumed_samples straight from the sampler would over-count what was actually consumed.
-            num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
-            if num_consumed_samples_array is not None:
-                if isinstance(sampler, ReproducibleSampler):  # for a plain sampler, batch_size must be factored in
-                    if dataloader_args.batch_size is not None:
-                        num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                    else:  # batch_size may be None, in which case some precision is lost
-                        logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                                 "it may cause missing some samples when reload.")
-                        num_consumed_batches = sampler_states['num_consumed_samples']
-                sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
-                assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+            if dataloader_args.batch_size is not None:
+                sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                         * num_consumed_batches
             else:
-                if dataloader_args.batch_size is not None:
-                    sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
-                                                             * num_consumed_batches
-                else:
-                    logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                             "it may cause missing some samples when reload.")
+                logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                         "it may cause missing some samples when reload.")
         else:
             raise RuntimeError(
                 "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")


fastNLP/core/drivers/paddle_driver/single_device.py (+19 -12)

@@ -26,6 +26,11 @@ if _NEED_IMPORT_PADDLE:
     import paddle
     from paddle import DataParallel
     from paddle.fluid.reader import _DatasetKind
+    from paddle.io import (
+        RandomSampler as PaddleRandomSampler,
+        SequenceSampler as PaddleSequenialSampler,
+        BatchSampler as PaddleBatchSampler,
+    )

 __all__ = [
     "PaddleSingleDriver",

@@ -122,19 +127,21 @@ class PaddleSingleDriver(PaddleDriver):
             return replace_sampler(dataloader, sampler)

         if reproducible:
-            if isinstance(args.sampler, paddle.io.RandomSampler):
-                if getattr(args.sampler, '_num_samples', None) is None \
-                        and getattr(args.sampler, 'replacements', False) is False \
-                        and getattr(args.sampler, 'generator', None) is None:
-                    # It is already random and not customized, so replace it directly.
-                    sampler = RandomSampler(args.sampler.data_source, shuffle=True)
-                    logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
+            if type(args.batch_sampler) is PaddleBatchSampler:
+                if type(args.sampler) is PaddleRandomSampler:
+                    if isinstance(args.sampler, PaddleRandomSampler):
+                        if getattr(args.sampler, '_num_samples', None) is None \
+                                and getattr(args.sampler, 'replacements', False) is False \
+                                and getattr(args.sampler, 'generator', None) is None:
+                            # It is already random and not customized, so replace it directly.
+                            sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+                            logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
                     return replace_sampler(dataloader, sampler)
-            elif isinstance(args.sampler, paddle.io.SequenceSampler):
-                # Replace it with a sampler that does not shuffle.
-                sampler = RandomSampler(args.sampler.data_source, shuffle=False)
-                logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.")
-                return replace_sampler(dataloader, sampler)
+                elif type(args.sampler) is PaddleSequenialSampler:
+                    # Replace it with a sampler that does not shuffle.
+                    sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+                    logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.")
+                    return replace_sampler(dataloader, sampler)
             batch_sampler = ReproduceBatchSampler(
                 batch_sampler=args.batch_sampler,
                 batch_size=args.batch_size,


fastNLP/core/drivers/paddle_driver/utils.py (+12 -1)

@@ -23,7 +23,7 @@ if _NEED_IMPORT_PADDLE:
     import paddle
     from paddle import nn
     from paddle.nn import Layer
-    from paddle.io import DataLoader, BatchSampler
+    from paddle.io import DataLoader, BatchSampler, RandomSampler, SequenceSampler
     from paddle.amp import auto_cast, GradScaler
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Layer

@@ -249,3 +249,14 @@ def optimizer_state_to_device(state, device):
         else:
             new_state[name] = param
     return new_state
+
+def _check_dataloader_args_for_distributed(args, controller='Trainer'):
+    if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {RandomSampler,
+                                                                                   SequenceSampler}):
+        mode = 'training' if controller == 'Trainer' else 'evaluation'
+        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
+        raise TypeError(f"Using a customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
+                        f"unpredictable problems, because fastNLP will substitute the dataloader's sampler with "
+                        f"``{substitution}``. A customized sampler should be set up for distributed running "
+                        f"before initializing ``{controller}``, and the parameter ``use_dist_sampler`` of "
+                        f"``{controller}`` should then be set to ``False``.")

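A hedged sketch of how the new guard behaves, assuming paddle is installed and this commit is applied; TinyDataset and the SimpleNamespace stand in for a real dataset and the parsed dataloader arguments:

    from types import SimpleNamespace

    from paddle.io import BatchSampler, Dataset, RandomSampler
    from fastNLP.core.drivers.paddle_driver.utils import _check_dataloader_args_for_distributed

    class TinyDataset(Dataset):
        def __getitem__(self, idx):
            return idx

        def __len__(self):
            return 10

    dataset = TinyDataset()

    # Stock paddle classes pass the exact type checks silently.
    ok = SimpleNamespace(batch_sampler=BatchSampler(dataset=dataset, batch_size=2),
                         sampler=RandomSampler(dataset))
    _check_dataloader_args_for_distributed(ok, controller='Trainer')

    # A subclass (or any unrelated class) fails ``type(...) is`` and raises.
    class MyBatchSampler(BatchSampler):
        pass

    bad = SimpleNamespace(batch_sampler=MyBatchSampler(dataset=dataset, batch_size=2),
                          sampler=RandomSampler(dataset))
    try:
        _check_dataloader_args_for_distributed(bad, controller='Trainer')
    except TypeError as e:
        print(e)  # the message points to the use_dist_sampler=False escape hatch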
tests/core/drivers/paddle_driver/test_fleet.py (+91 -4)

@@ -11,11 +11,12 @@ from fastNLP.core.samplers import (
 )
 from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset
-from tests.helpers.utils import magic_argv_env_context
+from tests.helpers.utils import magic_argv_env_context, recover_logger
 from fastNLP.envs.distributed import rank_zero_rm
 from fastNLP import prepare_paddle_dataloader
 from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP import logger
 if _NEED_IMPORT_PADDLE:
     import paddle
     import paddle.distributed as dist

@@ -532,7 +533,6 @@ class TestSetDistReproDataloader:
         num_samples = 200
         dataset = PaddleNormalXYDataset(num_samples)
         dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last)
-        model = PaddleNormalModel_Classification_1(10, 32)
         self.driver.setup()
         dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)

@@ -581,8 +581,6 @@ class TestSetDistReproDataloader:
         sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last,
                                        shuffle=shuffle, num_batch_per_bucket=2)
         dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler)
-        model = PaddleNormalModel_Classification_1(10, 32)
-        device = [0, 1]
         self.driver.setup()
         dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)

@@ -619,6 +617,95 @@ class TestSetDistReproDataloader:
         finally:
             dist.barrier()

+    @magic_argv_env_context
+    @recover_logger
+    @pytest.mark.parametrize("inherit", ([True, False]))
+    def test_customized_batch_sampler_dataloader(self, inherit):
+        try:
+            logger.set_stdout('raw', level='info')
+            # Check that set_dist_repro_dataloader handles a customized batch_sampler correctly.
+            num_samples = 10
+            dataset = PaddleNormalXYDataset(num_samples)
+            if inherit:
+                class BatchSampler(paddle.io.BatchSampler):
+                    def __init__(self, dataset, batch_size):
+                        self.dataset = dataset
+                        self.batch_size = batch_size
+
+                    def __iter__(self):
+                        indices = list(range(len(dataset)))
+                        for i in range(len(self)):
+                            start = i * self.batch_size
+                            end = (i + 1) * self.batch_size
+                            yield indices[start:end]
+
+                    def __len__(self):
+                        return (len(self.dataset)+self.batch_size-1)//self.batch_size
+            else:
+                class BatchSampler:
+                    def __init__(self, dataset, batch_size):
+                        self.dataset = dataset
+                        self.batch_size = batch_size
+
+                    def __iter__(self):
+                        indices = list(range(len(dataset)))
+                        for i in range(len(self)):
+                            start = i * self.batch_size
+                            end = (i + 1) * self.batch_size
+                            yield indices[start:end]
+
+                    def __len__(self):
+                        return (len(self.dataset)+self.batch_size-1)//self.batch_size
+
+            dl = prepare_paddle_dataloader(dataset, batch_sampler=BatchSampler(dataset, batch_size=4))
+            self.driver.setup()
+            with pytest.raises(TypeError):
+                dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
+        finally:
+            pass
+
+    @magic_argv_env_context
+    @recover_logger
+    @pytest.mark.parametrize("inherit", ([True, False]))
+    def test_customized_sampler_dataloader(self, inherit):
+        try:
+            logger.set_stdout('raw', level='info')
+            # Check that set_dist_repro_dataloader handles a customized sampler correctly.
+            num_samples = 10
+            dataset = PaddleNormalXYDataset(num_samples)
+            if inherit:
+                class Sampler(paddle.io.RandomSampler):
+                    def __init__(self, dataset, batch_size):
+                        self.dataset = dataset
+                        self.batch_size = batch_size
+
+                    def __iter__(self):
+                        indices = list(range(len(dataset)))
+                        return iter(indices)
+
+                    def __len__(self):
+                        return len(self.dataset)
+            else:
+                class Sampler:
+                    def __init__(self, dataset, batch_size):
+                        self.dataset = dataset
+                        self.batch_size = batch_size
+
+                    def __iter__(self):
+                        indices = list(range(len(dataset)))
+                        return iter(indices)
+
+                    def __len__(self):
+                        return len(self.dataset)
+
+            dl = prepare_paddle_dataloader(dataset, sampler=Sampler(dataset, batch_size=4))
+            self.driver.setup()
+            # TODO: this call is expected to raise
+            with pytest.raises(TypeError):
+                dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
+        finally:
+            pass


 ############################################################################
 #

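The error message added above points users to a specific workflow: make the sampler distributed-aware before the Trainer is created, then opt out of fastNLP's substitution. A hedged end-to-end sketch; RankShardedBatchSampler is hypothetical, the dataset/model helpers are the ones these tests already import, and the Trainer keyword set is abbreviated:

    import paddle
    import paddle.distributed as dist
    from fastNLP import Trainer, prepare_paddle_dataloader
    from tests.helpers.datasets.paddle_data import PaddleNormalXYDataset
    from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1

    class RankShardedBatchSampler:
        """Hypothetical user sampler that already shards indices across ranks."""
        def __init__(self, dataset, batch_size, rank, world_size):
            self.indices = list(range(rank, len(dataset), world_size))
            self.batch_size = batch_size

        def __iter__(self):
            for i in range(0, len(self.indices), self.batch_size):
                yield self.indices[i:i + self.batch_size]

        def __len__(self):
            return (len(self.indices) + self.batch_size - 1) // self.batch_size

    dataset = PaddleNormalXYDataset(40)
    model = PaddleNormalModel_Classification_1(10, 32)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01)

    dl = prepare_paddle_dataloader(
        dataset,
        batch_sampler=RankShardedBatchSampler(dataset, 4, dist.get_rank(), dist.get_world_size()),
    )
    trainer = Trainer(model=model, driver="paddle", device=[0, 1],
                      train_dataloader=dl, optimizers=optimizer,
                      use_dist_sampler=False)  # keep fastNLP from substituting the sampler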
