@@ -244,7 +244,6 @@ class PaddleFleetDriver(PaddleDriver):
         """
         if self.local_rank == 0:
             # On rank 0, launch the other child processes.
-            print("in launcher")
             launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
             launcher.launch()
         # Set the parameters and initialize the distributed environment.
@@ -326,7 +325,6 @@ class PaddleFleetDriver(PaddleDriver):
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
         # If `dist` is a ReproducibleBatchSampler / ReproducibleSampler, this is the call made by driver.load when resuming from a checkpoint;
-        # Note that there is no need to call dist_sampler.set_distributed here; if the user is using TorchDDPDriver, that function has already been called during Trainer initialization;
         if isinstance(dist, ReproducibleBatchSampler):
             dist.set_distributed(
                 num_replicas=self.world_size,
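For readers unfamiliar with the sampler contract used here, the sketch below illustrates what `set_distributed(num_replicas, rank, pad)` is generally expected to do for a reproducible sampler: record the shard layout so each rank iterates a disjoint, optionally padded slice of the indices. It is a simplified stand-in (the class name `ShardedIndexSampler` is invented), not fastNLP's actual ReproducibleSampler.

# A minimal sketch of the set_distributed contract assumed above; fastNLP's
# real ReproducibleSampler / ReproducibleBatchSampler differ in detail.
class ShardedIndexSampler:
    def __init__(self, dataset_len: int):
        self.dataset_len = dataset_len
        self.num_replicas = 1
        self.rank = 0
        self.pad = True

    def set_distributed(self, num_replicas: int, rank: int, pad: bool = True):
        # Record the shard layout once, before iteration starts.
        assert 0 <= rank < num_replicas
        self.num_replicas = num_replicas
        self.rank = rank
        self.pad = pad
        return self

    def __iter__(self):
        indices = list(range(self.dataset_len))
        remainder = len(indices) % self.num_replicas
        if self.pad and remainder != 0:
            # Pad with the leading indices so every rank yields the same count.
            indices += indices[: self.num_replicas - remainder]
        # Each rank walks a strided, disjoint slice of the (padded) index list.
        return iter(indices[self.rank :: self.num_replicas])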
@@ -346,15 +344,16 @@ class PaddleFleetDriver(PaddleDriver):
         # trainer, evaluator
         if dist is None:
             if reproducible:
-                raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
+                raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
                                    "control.")
             else:
-                if isinstance(dist, ReproducibleBatchSampler):
-                    dist = re_instantiate_sampler(dist)
-                    return replace_batch_sampler(dataloader, dist)
-                if isinstance(dist, ReproducibleSampler):
-                    dist = re_instantiate_sampler(dist)
-                    return replace_sampler(dataloader, dist)
+                args = self.get_dataloader_args(dataloader)
+                if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+                    batch_sampler = re_instantiate_sampler(args.batch_sampler)
+                    return replace_batch_sampler(dataloader, batch_sampler)
+                if isinstance(args.sampler, ReproducibleSampler):
+                    sampler = re_instantiate_sampler(args.sampler)
+                    return replace_sampler(dataloader, sampler)
                 return dataloader
         # trainer
         elif dist == "dist":
@@ -66,8 +66,8 @@ class PaddleDriver(Driver):
         :param set_to_none: whether the gradients should be set directly to None; this parameter has no effect in Paddle.
         """
-        # if set_to_none:
-        #     log.warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.")
+        if set_to_none:
+            logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.")
         for optimizer in self.optimizers:
             optimizer.clear_grad()
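`logger.warning_once` is assumed here to emit a given message only the first time it is seen, so a warning triggered on every training step does not flood the log. A stand-alone sketch of that behaviour (not fastNLP's logger implementation):

import logging

_seen_warnings = set()
_log = logging.getLogger("paddle_driver_sketch")

def warning_once(msg: str) -> None:
    # Emit each distinct warning message only once per process.
    if msg not in _seen_warnings:
        _seen_warnings.add(msg)
        _log.warning(msg)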
@@ -254,8 +254,21 @@ class PaddleDriver(Driver):
             else:
                 raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
+            num_consumed_batches = states.pop('num_consumed_batches')
             if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
-                states['sampler_states'] = sampler.state_dict()
+                sampler_states = sampler.state_dict()
+                # If present, `num_consumed_samples` needs special handling: because the DataLoader prefetches,
+                # using the sampler's own num_consumed_samples directly would over-count what has actually been consumed.
+                num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
+                if num_consumed_samples_array is not None:
+                    if isinstance(sampler, ReproducibleSampler):  # for a sample-level sampler, batch_size has to be taken into account
+                        try:
+                            num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
+                        except:  # batch_size may be None, in which case some precision is lost
+                            num_consumed_batches = sampler_states['num_consumed_samples']
+                    sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
+                    assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+                states['sampler_states'] = sampler_states
             else:
                 raise RuntimeError(
                     'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
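The comment in the added block is the key point: the DataLoader prefetches, so the sampler's own num_consumed_samples runs ahead of what the training loop has actually finished. The saved cumulative array maps "items the sampler has yielded" to "samples consumed at that point", and indexing it by the trainer's real progress (number of finished batches, or batches times batch_size for a sample-level sampler) recovers the true count. A toy illustration with made-up numbers, not the array a real sampler would produce:

# Toy numbers only; the real array comes from the sampler's state_dict().
batch_size = 4
num_consumed_batches = 3                      # batches the trainer has truly finished
num_consumed_samples_array = list(range(21))  # entry k: cumulative samples after the k-th yielded item

# Sample-level sampler: convert finished batches into finished samples first.
index = num_consumed_batches * batch_size     # 12
true_consumed = num_consumed_samples_array[index]
print(true_consumed)  # 12, even though prefetching means the sampler already yielded 20 samples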
@@ -471,12 +471,11 @@ class TorchDDPDriver(TorchDriver):
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
                                    "control.")
             else:
-                if isinstance(dist, ReproducibleBatchSampler):
-                    dist = re_instantiate_sampler(dist)
-                    return replace_batch_sampler(dataloader, dist)
-                if isinstance(dist, ReproducibleSampler):
-                    dist = re_instantiate_sampler(dist)
-                    return replace_sampler(dataloader, dist)
+                args = self.get_dataloader_args(dataloader)
+                if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+                    return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
+                if isinstance(args.sampler, ReproducibleSampler):
+                    return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))
                 return dataloader
         # trainer
         elif dist == "dist":
@@ -151,7 +151,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
            self.need_reinitialize = False

    def set_distributed(self, num_replicas, rank, pad=True):
-        raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")
+        raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")

    def set_epoch(self, epoch):
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
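`set_epoch` here simply forwards to the wrapped sampler when that sampler exposes the method, which is the usual way shuffling gets re-seeded per epoch. A minimal sketch of the delegation pattern; the class names are illustrative, not fastNLP's:

class EpochSeededSampler:
    def __init__(self):
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Typically used to derive a fresh shuffle seed for the new epoch.
        self.epoch = epoch


class DelegatingBatchSampler:
    def __init__(self, sampler):
        self.sampler = sampler

    def set_epoch(self, epoch: int) -> None:
        # Forward only if the wrapped sampler actually supports set_epoch.
        if hasattr(self.sampler, "set_epoch") and callable(self.sampler.set_epoch):
            self.sampler.set_epoch(epoch)


DelegatingBatchSampler(EpochSeededSampler()).set_epoch(3)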