diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 22f28743..391bc22a 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -34,10 +34,10 @@ if _NEED_IMPORT_PADDLE: from paddle.optimizer import Optimizer _reduces = { - 'max': paddle.max, - 'min': paddle.min, - 'mean': paddle.mean, - 'sum': paddle.sum + "max": paddle.max, + "min": paddle.min, + "mean": paddle.mean, + "sum": paddle.sum } class PaddleDriver(Driver): @@ -254,24 +254,24 @@ class PaddleDriver(Driver): else: raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") - num_consumed_batches = states.pop('num_consumed_batches') - if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): + num_consumed_batches = states.pop("num_consumed_batches") + if hasattr(sampler, "state_dict") and callable(sampler.state_dict): sampler_states = sampler.state_dict() # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples - # 会造成多余实际消耗的问题。 - num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) + # 会造成多余实际消耗的问题。 + num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None) if num_consumed_samples_array is not None: - if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 - try: - num_consumed_batches = num_consumed_batches * dataloader_args.batch_size - except: # 有可能 batch_size 为 None,就只有损失精度了 - num_consumed_batches = sampler_states['num_consumed_samples'] - sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] - assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." - + sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] + else: + try: + sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size + except: # 有可能 batch_size 为 None,就只有损失精度了 + pass + assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." else: raise RuntimeError( - 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') + "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") + states["sampler_states"] = sampler_states # 2. 保存模型的状态; if should_save_model: @@ -326,7 +326,7 @@ class PaddleDriver(Driver): batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last ) - sampler.load_state_dict(states['sampler_states']) + sampler.load_state_dict(states["sampler_states"]) states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) # 4. 修改 trainer_state.batch_idx_in_epoch @@ -355,7 +355,7 @@ class PaddleDriver(Driver): return paddle.no_grad @staticmethod - def move_model_to_device(model: 'paddle.nn.Layer', device: Union[str, int, 'paddle.CUDAPlace', 'paddle.CPUPlace']): + def move_model_to_device(model: "paddle.nn.Layer", device: Union[str, int, "paddle.CUDAPlace", "paddle.CPUPlace"]): r""" 用来将模型转移到指定的 device 上; 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 @@ -363,7 +363,7 @@ class PaddleDriver(Driver): if device is not None: model.to(device) - def move_data_to_device(self, batch: 'paddle.Tensor'): + def move_data_to_device(self, batch: "paddle.Tensor"): r""" 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 @@ -404,7 +404,7 @@ class PaddleDriver(Driver): if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank) - def set_sampler_epoch(self, dataloader: 'DataLoader', cur_epoch_idx): + def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): r""" 对于分布式的 sampler,dataloader 需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index b9681121..51597210 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -224,7 +224,6 @@ class TestSetDistReproDataloder: """ def setup_method(self): self.dataset = PaddleNormalDataset(20) - self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) model = PaddleNormalModel_Classification_1(10, 32) self.driver = PaddleSingleDriver(model, device="cpu") @@ -233,55 +232,59 @@ class TestSetDistReproDataloder: 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 当dist为字符串时,此时应该返回原来的 dataloader """ - replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False) + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) - assert replaced_loader is self.dataloader + assert replaced_loader is dataloader def test_set_dist_repro_dataloader_with_reproducible_true(self): """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler """ - replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True) + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) - assert not (replaced_loader is self.dataloader) + assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) - assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size - assert replaced_loader.drop_last == self.dataloader.drop_last + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.drop_last == dataloader.drop_last - # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) - replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) - assert not (replaced_loader is self.dataloader) + assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) assert replaced_loader.batch_sampler is dist - # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def test_set_dist_repro_dataloader_with_dist_sampler(self): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) dist = RandomSampler(self.dataset, shuffle=True) - replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) - assert not (replaced_loader is self.dataloader) + assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) - assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert replaced_loader.batch_sampler.sampler is dist - assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size - # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): """ @@ -295,11 +298,12 @@ class TestSetDistReproDataloder: replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last - # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): """ @@ -316,34 +320,52 @@ class TestSetDistReproDataloder: assert not (replaced_loader is dataloader) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.sampler.shuffle == True - # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): """ 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 """ # 迭代两个 batch - # 这里会发生 BatchSampler 里 yield 了多次但 dataloader 只取出一次的情况。 + num_consumed_batches = 2 already_seen_idx = set() - for idx, batch in replaced_loader: - already_seen_idx.update(batch) - if idx >= 1: + for idx, batch in enumerate(replaced_loader): + if idx >= num_consumed_batches: break + already_seen_idx.update(batch) if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() else: sampler_states = replaced_loader.batch_sampler.sampler.state_dict() - print(sampler_states["data_idx"]) + + # 加载 num_consumed_samples_array,设置正确取出的 batch 数目 + num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) + + import time + time.sleep(5) # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range left_idxes = set() if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + batch_size = replaced_loader.batch_sampler.batch_size + if num_consumed_samples_array is not None: + sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] + else: + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size replaced_loader.batch_sampler.load_state_dict(sampler_states) else: + batch_size = replaced_loader.batch_sampler.batch_size + if num_consumed_samples_array is not None: + sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] + else: + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) + replaced_loader.batch_sampler.sampler.set_epoch(0) for idx, batch in enumerate(replaced_loader): left_idxes.update(batch)