@@ -138,26 +138,12 @@ class JittorDriver(Driver): | |||
num_consumed_batches = states.pop('num_consumed_batches') | |||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
sampler_states = sampler.state_dict() | |||
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
# 会造成多余实际消耗的问题。因为 | |||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
if num_consumed_samples_array is not None: | |||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
if dataloader_args.batch_size is not None: | |||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
if dataloader_args.batch_size is not None: | |||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
* num_consumed_batches | |||
else: | |||
if dataloader_args.batch_size is not None: | |||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
* num_consumed_batches | |||
else: | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
states['sampler_states'] = sampler_states | |||
else: | |||
@@ -118,14 +118,14 @@ class JittorSingleDriver(JittorDriver): | |||
if args.sampler is None: | |||
sampler = RandomSampler(args.dataset, args.shuffle) | |||
return replace_sampler(dataloader, sampler) | |||
elif isinstance(args.sampler, JittorRandomSampler): | |||
elif type(args.sampler) is JittorRandomSampler: | |||
if getattr(args.sampler, '_num_samples', None) is None \ | |||
and getattr(args.sampler, 'rep', False) is False: | |||
# 如果本来就是随机的,并且没有定制,直接替换掉吧。 | |||
sampler = RandomSampler(args.sampler.dataset, shuffle=True) | |||
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | |||
return replace_sampler(dataloader, sampler) | |||
elif isinstance(args.sampler, JittorSequentialSampler): | |||
elif type(args.sampler) is JittorSequentialSampler: | |||
# 需要替换为不要 shuffle 的。 | |||
sampler = RandomSampler(args.sampler.dataset, shuffle=False) | |||
logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | |||
@@ -73,6 +73,7 @@ from .utils import ( | |||
_FleetWrappingModel, | |||
replace_sampler, | |||
replace_batch_sampler, | |||
_check_dataloader_args_for_distributed | |||
) | |||
from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object | |||
@@ -453,6 +454,7 @@ class PaddleFleetDriver(PaddleDriver): | |||
) | |||
return replace_sampler(dataloader, sampler) | |||
else: | |||
_check_dataloader_args_for_distributed(args, controller='Trainer') | |||
sampler = RandomSampler( | |||
dataset=args.dataset, | |||
shuffle=args.shuffle, | |||
@@ -222,26 +222,12 @@ class PaddleDriver(Driver): | |||
num_consumed_batches = states.pop("num_consumed_batches") | |||
if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | |||
sampler_states = sampler.state_dict() | |||
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
# 会造成多余实际消耗的问题。 | |||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
if num_consumed_samples_array is not None: | |||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
if dataloader_args.batch_size is not None: | |||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
if dataloader_args.batch_size is not None: | |||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
* num_consumed_batches | |||
else: | |||
if dataloader_args.batch_size is not None: | |||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
* num_consumed_batches | |||
else: | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
else: | |||
raise RuntimeError( | |||
"The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | |||
@@ -26,6 +26,11 @@ if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle import DataParallel | |||
from paddle.fluid.reader import _DatasetKind | |||
from paddle.io import ( | |||
RandomSampler as PaddleRandomSampler, | |||
SequenceSampler as PaddleSequenialSampler, | |||
BatchSampler as PaddleBatchSampler, | |||
) | |||
__all__ = [ | |||
"PaddleSingleDriver", | |||
@@ -122,19 +127,21 @@ class PaddleSingleDriver(PaddleDriver): | |||
return replace_sampler(dataloader, sampler) | |||
if reproducible: | |||
if isinstance(args.sampler, paddle.io.RandomSampler): | |||
if getattr(args.sampler, '_num_samples', None) is None \ | |||
and getattr(args.sampler, 'replacements', False) is False \ | |||
and getattr(args.sampler, 'generator', None) is None: | |||
# 如果本来就是随机的,并且没有定制,直接替换掉。 | |||
sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||
if type(args.batch_sampler) is PaddleBatchSampler: | |||
if type(args.sampler) is PaddleRandomSampler: | |||
if isinstance(args.sampler, PaddleRandomSampler): | |||
if getattr(args.sampler, '_num_samples', None) is None \ | |||
and getattr(args.sampler, 'replacements', False) is False \ | |||
and getattr(args.sampler, 'generator', None) is None: | |||
# 如果本来就是随机的,并且没有定制,直接替换掉。 | |||
sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||
return replace_sampler(dataloader, sampler) | |||
elif type(args.sampler) is PaddleSequenialSampler: | |||
# 需要替换为不要 shuffle 的。 | |||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||
logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||
return replace_sampler(dataloader, sampler) | |||
elif isinstance(args.sampler, paddle.io.SequenceSampler): | |||
# 需要替换为不要 shuffle 的。 | |||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||
logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||
return replace_sampler(dataloader, sampler) | |||
batch_sampler = ReproduceBatchSampler( | |||
batch_sampler=args.batch_sampler, | |||
batch_size=args.batch_size, | |||
@@ -23,7 +23,7 @@ if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle import nn | |||
from paddle.nn import Layer | |||
from paddle.io import DataLoader, BatchSampler | |||
from paddle.io import DataLoader, BatchSampler, RandomSampler, SequenceSampler | |||
from paddle.amp import auto_cast, GradScaler | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as Layer | |||
@@ -249,3 +249,14 @@ def optimizer_state_to_device(state, device): | |||
else: | |||
new_state[name] = param | |||
return new_state | |||
def _check_dataloader_args_for_distributed(args, controller='Trainer'): | |||
if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {RandomSampler, | |||
SequenceSampler}): | |||
mode = 'training' if controller == 'Trainer' else 'evaluation' | |||
substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||
raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | |||
f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | |||
f"``{substitution}``. The customized sampler should set for distributed running " | |||
f"before initializing ``{controller}`` , and then set the " | |||
f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") |
@@ -11,11 +11,12 @@ from fastNLP.core.samplers import ( | |||
) | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | |||
from tests.helpers.utils import magic_argv_env_context | |||
from tests.helpers.utils import magic_argv_env_context, recover_logger | |||
from fastNLP.envs.distributed import rank_zero_rm | |||
from fastNLP import prepare_paddle_dataloader | |||
from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP import logger | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
import paddle.distributed as dist | |||
@@ -532,7 +533,6 @@ class TestSetDistReproDataloader: | |||
num_samples = 200 | |||
dataset = PaddleNormalXYDataset(num_samples) | |||
dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) | |||
model = PaddleNormalModel_Classification_1(10, 32) | |||
self.driver.setup() | |||
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | |||
@@ -581,8 +581,6 @@ class TestSetDistReproDataloader: | |||
sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, | |||
shuffle=shuffle, num_batch_per_bucket=2) | |||
dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler) | |||
model = PaddleNormalModel_Classification_1(10, 32) | |||
device = [0, 1] | |||
self.driver.setup() | |||
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | |||
@@ -619,6 +617,95 @@ class TestSetDistReproDataloader: | |||
finally: | |||
dist.barrier() | |||
@magic_argv_env_context | |||
@recover_logger | |||
@pytest.mark.parametrize("inherit", ([True, False])) | |||
def test_customized_batch_sampler_dataloader(self, inherit): | |||
try: | |||
logger.set_stdout('raw', level='info') | |||
# 需要检验一下 set_dist_repro_dataloader 是否可以在定制 batch_sampler 的情况下正确运行 | |||
num_samples = 10 | |||
dataset = PaddleNormalXYDataset(num_samples) | |||
if inherit: | |||
class BatchSampler(paddle.io.BatchSampler): | |||
def __init__(self, dataset, batch_size): | |||
self.dataset = dataset | |||
self.batch_size = batch_size | |||
def __iter__(self): | |||
indices = list(range(len(dataset))) | |||
for i in range(len(self)): | |||
start = i * self.batch_size | |||
end = (i + 1) * self.batch_size | |||
return indices[start:end] | |||
def __len__(self): | |||
return (len(self.dataset)+self.batch_size-1)//self.batch_size | |||
else: | |||
class BatchSampler: | |||
def __init__(self, dataset, batch_size): | |||
self.dataset = dataset | |||
self.batch_size = batch_size | |||
def __iter__(self): | |||
indices = list(range(len(dataset))) | |||
for i in range(len(self)): | |||
start = i * self.batch_size | |||
end = (i + 1) * self.batch_size | |||
return indices[start:end] | |||
def __len__(self): | |||
return (len(self.dataset)+self.batch_size-1)//self.batch_size | |||
dl = prepare_paddle_dataloader(dataset, batch_sampler=BatchSampler(dataset, batch_size=4)) | |||
self.driver.setup() | |||
with pytest.raises(TypeError): | |||
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False) | |||
finally: | |||
pass | |||
@magic_argv_env_context | |||
@recover_logger | |||
@pytest.mark.parametrize("inherit", ([True, False])) | |||
def test_customized_sampler_dataloader(self, inherit): | |||
try: | |||
logger.set_stdout('raw', level='info') | |||
# 需要检验一下 set_dist_repro_dataloader 是否可以在定制 batch_sampler 的情况下正确运行 | |||
num_samples = 10 | |||
dataset = PaddleNormalXYDataset(num_samples) | |||
if inherit: | |||
class Sampler(paddle.io.RandomSampler): | |||
def __init__(self, dataset, batch_size): | |||
self.dataset = dataset | |||
self.batch_size = batch_size | |||
def __iter__(self): | |||
indices = list(range(len(dataset))) | |||
return iter(indices) | |||
def __len__(self): | |||
return len(self.dataset) | |||
else: | |||
class Sampler: | |||
def __init__(self, dataset, batch_size): | |||
self.dataset = dataset | |||
self.batch_size = batch_size | |||
def __iter__(self): | |||
indices = list(range(len(dataset))) | |||
return iter(indices) | |||
def __len__(self): | |||
return len(self.dataset) | |||
dl = prepare_paddle_dataloader(dataset, sampler=Sampler(dataset, batch_size=4)) | |||
self.driver.setup() | |||
# TODO 这里需要raise | |||
with pytest.raises(TypeError): | |||
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False) | |||
finally: | |||
pass | |||
############################################################################ | |||
# | |||