@@ -138,26 +138,12 @@ class JittorDriver(Driver): | |||||
num_consumed_batches = states.pop('num_consumed_batches') | num_consumed_batches = states.pop('num_consumed_batches') | ||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | ||||
sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会造成多余实际消耗的问题。因为 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
if num_consumed_samples_array is not None: | |||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
if dataloader_args.batch_size is not None: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | else: | ||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
states['sampler_states'] = sampler_states | states['sampler_states'] = sampler_states | ||||
else: | else: | ||||
@@ -118,14 +118,14 @@ class JittorSingleDriver(JittorDriver): | |||||
if args.sampler is None: | if args.sampler is None: | ||||
sampler = RandomSampler(args.dataset, args.shuffle) | sampler = RandomSampler(args.dataset, args.shuffle) | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
elif isinstance(args.sampler, JittorRandomSampler): | |||||
elif type(args.sampler) is JittorRandomSampler: | |||||
if getattr(args.sampler, '_num_samples', None) is None \ | if getattr(args.sampler, '_num_samples', None) is None \ | ||||
and getattr(args.sampler, 'rep', False) is False: | and getattr(args.sampler, 'rep', False) is False: | ||||
# 如果本来就是随机的,并且没有定制,直接替换掉吧。 | # 如果本来就是随机的,并且没有定制,直接替换掉吧。 | ||||
sampler = RandomSampler(args.sampler.dataset, shuffle=True) | sampler = RandomSampler(args.sampler.dataset, shuffle=True) | ||||
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
elif isinstance(args.sampler, JittorSequentialSampler): | |||||
elif type(args.sampler) is JittorSequentialSampler: | |||||
# 需要替换为不要 shuffle 的。 | # 需要替换为不要 shuffle 的。 | ||||
sampler = RandomSampler(args.sampler.dataset, shuffle=False) | sampler = RandomSampler(args.sampler.dataset, shuffle=False) | ||||
logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | ||||
@@ -73,6 +73,7 @@ from .utils import ( | |||||
_FleetWrappingModel, | _FleetWrappingModel, | ||||
replace_sampler, | replace_sampler, | ||||
replace_batch_sampler, | replace_batch_sampler, | ||||
_check_dataloader_args_for_distributed | |||||
) | ) | ||||
from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object | from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object | ||||
@@ -453,6 +454,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
) | ) | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
else: | else: | ||||
_check_dataloader_args_for_distributed(args, controller='Trainer') | |||||
sampler = RandomSampler( | sampler = RandomSampler( | ||||
dataset=args.dataset, | dataset=args.dataset, | ||||
shuffle=args.shuffle, | shuffle=args.shuffle, | ||||
@@ -222,26 +222,12 @@ class PaddleDriver(Driver): | |||||
num_consumed_batches = states.pop("num_consumed_batches") | num_consumed_batches = states.pop("num_consumed_batches") | ||||
if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | ||||
sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会造成多余实际消耗的问题。 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
if num_consumed_samples_array is not None: | |||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
if dataloader_args.batch_size is not None: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | else: | ||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
"The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | ||||
@@ -26,6 +26,11 @@ if _NEED_IMPORT_PADDLE: | |||||
import paddle | import paddle | ||||
from paddle import DataParallel | from paddle import DataParallel | ||||
from paddle.fluid.reader import _DatasetKind | from paddle.fluid.reader import _DatasetKind | ||||
from paddle.io import ( | |||||
RandomSampler as PaddleRandomSampler, | |||||
SequenceSampler as PaddleSequenialSampler, | |||||
BatchSampler as PaddleBatchSampler, | |||||
) | |||||
__all__ = [ | __all__ = [ | ||||
"PaddleSingleDriver", | "PaddleSingleDriver", | ||||
@@ -122,19 +127,21 @@ class PaddleSingleDriver(PaddleDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
if reproducible: | if reproducible: | ||||
if isinstance(args.sampler, paddle.io.RandomSampler): | |||||
if getattr(args.sampler, '_num_samples', None) is None \ | |||||
and getattr(args.sampler, 'replacements', False) is False \ | |||||
and getattr(args.sampler, 'generator', None) is None: | |||||
# 如果本来就是随机的,并且没有定制,直接替换掉。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
if type(args.batch_sampler) is PaddleBatchSampler: | |||||
if type(args.sampler) is PaddleRandomSampler: | |||||
if isinstance(args.sampler, PaddleRandomSampler): | |||||
if getattr(args.sampler, '_num_samples', None) is None \ | |||||
and getattr(args.sampler, 'replacements', False) is False \ | |||||
and getattr(args.sampler, 'generator', None) is None: | |||||
# 如果本来就是随机的,并且没有定制,直接替换掉。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
elif type(args.sampler) is PaddleSequenialSampler: | |||||
# 需要替换为不要 shuffle 的。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
elif isinstance(args.sampler, paddle.io.SequenceSampler): | |||||
# 需要替换为不要 shuffle 的。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
batch_sampler = ReproduceBatchSampler( | batch_sampler = ReproduceBatchSampler( | ||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
@@ -23,7 +23,7 @@ if _NEED_IMPORT_PADDLE: | |||||
import paddle | import paddle | ||||
from paddle import nn | from paddle import nn | ||||
from paddle.nn import Layer | from paddle.nn import Layer | ||||
from paddle.io import DataLoader, BatchSampler | |||||
from paddle.io import DataLoader, BatchSampler, RandomSampler, SequenceSampler | |||||
from paddle.amp import auto_cast, GradScaler | from paddle.amp import auto_cast, GradScaler | ||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Layer | from fastNLP.core.utils.dummy_class import DummyClass as Layer | ||||
@@ -249,3 +249,14 @@ def optimizer_state_to_device(state, device): | |||||
else: | else: | ||||
new_state[name] = param | new_state[name] = param | ||||
return new_state | return new_state | ||||
def _check_dataloader_args_for_distributed(args, controller='Trainer'): | |||||
if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {RandomSampler, | |||||
SequenceSampler}): | |||||
mode = 'training' if controller == 'Trainer' else 'evaluation' | |||||
substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||||
raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | |||||
f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | |||||
f"``{substitution}``. The customized sampler should set for distributed running " | |||||
f"before initializing ``{controller}`` , and then set the " | |||||
f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") |
@@ -11,11 +11,12 @@ from fastNLP.core.samplers import ( | |||||
) | ) | ||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | ||||
from tests.helpers.utils import magic_argv_env_context | |||||
from tests.helpers.utils import magic_argv_env_context, recover_logger | |||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from fastNLP import prepare_paddle_dataloader | from fastNLP import prepare_paddle_dataloader | ||||
from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather | from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
from fastNLP import logger | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
import paddle.distributed as dist | import paddle.distributed as dist | ||||
@@ -532,7 +533,6 @@ class TestSetDistReproDataloader: | |||||
num_samples = 200 | num_samples = 200 | ||||
dataset = PaddleNormalXYDataset(num_samples) | dataset = PaddleNormalXYDataset(num_samples) | ||||
dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) | dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) | ||||
model = PaddleNormalModel_Classification_1(10, 32) | |||||
self.driver.setup() | self.driver.setup() | ||||
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | ||||
@@ -581,8 +581,6 @@ class TestSetDistReproDataloader: | |||||
sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, | sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, | ||||
shuffle=shuffle, num_batch_per_bucket=2) | shuffle=shuffle, num_batch_per_bucket=2) | ||||
dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler) | dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler) | ||||
model = PaddleNormalModel_Classification_1(10, 32) | |||||
device = [0, 1] | |||||
self.driver.setup() | self.driver.setup() | ||||
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | ||||
@@ -619,6 +617,95 @@ class TestSetDistReproDataloader: | |||||
finally: | finally: | ||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
@recover_logger | |||||
@pytest.mark.parametrize("inherit", ([True, False])) | |||||
def test_customized_batch_sampler_dataloader(self, inherit): | |||||
try: | |||||
logger.set_stdout('raw', level='info') | |||||
# 需要检验一下 set_dist_repro_dataloader 是否可以在定制 batch_sampler 的情况下正确运行 | |||||
num_samples = 10 | |||||
dataset = PaddleNormalXYDataset(num_samples) | |||||
if inherit: | |||||
class BatchSampler(paddle.io.BatchSampler): | |||||
def __init__(self, dataset, batch_size): | |||||
self.dataset = dataset | |||||
self.batch_size = batch_size | |||||
def __iter__(self): | |||||
indices = list(range(len(dataset))) | |||||
for i in range(len(self)): | |||||
start = i * self.batch_size | |||||
end = (i + 1) * self.batch_size | |||||
return indices[start:end] | |||||
def __len__(self): | |||||
return (len(self.dataset)+self.batch_size-1)//self.batch_size | |||||
else: | |||||
class BatchSampler: | |||||
def __init__(self, dataset, batch_size): | |||||
self.dataset = dataset | |||||
self.batch_size = batch_size | |||||
def __iter__(self): | |||||
indices = list(range(len(dataset))) | |||||
for i in range(len(self)): | |||||
start = i * self.batch_size | |||||
end = (i + 1) * self.batch_size | |||||
return indices[start:end] | |||||
def __len__(self): | |||||
return (len(self.dataset)+self.batch_size-1)//self.batch_size | |||||
dl = prepare_paddle_dataloader(dataset, batch_sampler=BatchSampler(dataset, batch_size=4)) | |||||
self.driver.setup() | |||||
with pytest.raises(TypeError): | |||||
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False) | |||||
finally: | |||||
pass | |||||
@magic_argv_env_context | |||||
@recover_logger | |||||
@pytest.mark.parametrize("inherit", ([True, False])) | |||||
def test_customized_sampler_dataloader(self, inherit): | |||||
try: | |||||
logger.set_stdout('raw', level='info') | |||||
# 需要检验一下 set_dist_repro_dataloader 是否可以在定制 batch_sampler 的情况下正确运行 | |||||
num_samples = 10 | |||||
dataset = PaddleNormalXYDataset(num_samples) | |||||
if inherit: | |||||
class Sampler(paddle.io.RandomSampler): | |||||
def __init__(self, dataset, batch_size): | |||||
self.dataset = dataset | |||||
self.batch_size = batch_size | |||||
def __iter__(self): | |||||
indices = list(range(len(dataset))) | |||||
return iter(indices) | |||||
def __len__(self): | |||||
return len(self.dataset) | |||||
else: | |||||
class Sampler: | |||||
def __init__(self, dataset, batch_size): | |||||
self.dataset = dataset | |||||
self.batch_size = batch_size | |||||
def __iter__(self): | |||||
indices = list(range(len(dataset))) | |||||
return iter(indices) | |||||
def __len__(self): | |||||
return len(self.dataset) | |||||
dl = prepare_paddle_dataloader(dataset, sampler=Sampler(dataset, batch_size=4)) | |||||
self.driver.setup() | |||||
# TODO 这里需要raise | |||||
with pytest.raises(TypeError): | |||||
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False) | |||||
finally: | |||||
pass | |||||
############################################################################ | ############################################################################ | ||||
# | # | ||||