diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 84e4aa70..1594a903 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -190,7 +190,30 @@ class TorchDriver(Driver):
         # The dataloader argument passed in is the trainer's dataloader attribute, because we never modify the driver's own
         # dataloaders; instead we change the dataloader's state through trainer.dataloader to fit the training or evaluation setting;
-        # 1. The sampler's state; because we support resume training, i.e. restoring precisely to a specific batch;
+        # 1. The sampler's state;
+        num_consumed_batches = states.pop('num_consumed_batches')
+        states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches)
+
+        # 2. Save the model's state;
+        if should_save_model:
+            if not os.path.exists(folder):
+                os.mkdir(folder)
+            model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
+            self.save_model(model_path, only_state_dict=only_state_dict)
+
+        # 3. Save the optimizers' state;
+        states["optimizers_state_dict"] = self.get_optimizer_state()
+        logger.debug("Save optimizer state dict.")
+
+        # 4. Save the fp16 state
+        if not isinstance(self.grad_scaler, DummyGradScaler):
+            grad_scaler_state_dict = self.grad_scaler.state_dict()
+            states['grad_scaler_state_dict'] = grad_scaler_state_dict
+
+        torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
+
+    def get_sampler_state(self, dataloader, num_consumed_batches):
+        # Because we support resume training, i.e. restoring precisely to a specific batch;
         # first, a pytorch DataLoader always has a sampler; on the other hand, when resuming from a checkpoint we always replace
         # the dataloader's sampler with a `ReproducibleSampler` in `set_`; otherwise, in the single-card case, the batch_sampler is replaced with a `ReproducibleBatchSampler`;
         dataloader_args = self.get_dataloader_args(dataloader)
@@ -200,7 +223,7 @@ class TorchDriver(Driver):
             sampler = dataloader_args.sampler
         else:
             raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
-        num_consumed_batches = states.pop('num_consumed_batches')
+
         if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
             sampler_states = sampler.state_dict()
             if dataloader_args.batch_size is not None:
@@ -209,30 +232,49 @@ class TorchDriver(Driver):
             else:
                 logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's "
                                          "`num_consumed_samples`, it may cause missing some samples when reload.")
-
-            states['sampler_states'] = sampler_states
         else:
             raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training '
                                'state.')
-        # 2. Save the model's state;
-        if should_save_model:
-            if not os.path.exists(folder):
-                os.mkdir(folder)
-            model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
-            self.save_model(model_path, only_state_dict=only_state_dict)
+        return sampler_states
 
-        # 3. Save the optimizers' state;
-        optimizers_state_dict = self.get_optimizer_state()
+    def load_sampler_state(self, dataloader, sampler_states):
+        states = {}
+        dataloader_args = self.get_dataloader_args(dataloader)
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+            sampler = dataloader_args.batch_sampler
+        elif isinstance(dataloader_args.sampler, ReproducibleSampler):
+            sampler = dataloader_args.sampler
+        elif isinstance(dataloader_args.sampler, TorchRandomSampler):
+            sampler = RandomSampler(dataloader_args.sampler.data_source)
+            logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+        elif self.is_distributed():
+            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our"
+                               "`ReproducibleSampler`.")
+        else:
+            sampler = ReproduceBatchSampler(
+                batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
+                batch_size=dataloader_args.batch_size,
+                drop_last=dataloader_args.drop_last
+            )
+        sampler.load_state_dict(sampler_states)
+        states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
 
-        # 4. Save the fp16 state
-        if not isinstance(self.grad_scaler, DummyGradScaler):
-            grad_scaler_state_dict = self.grad_scaler.state_dict()
-            states['grad_scaler_state_dict'] = grad_scaler_state_dict
+        # Adjust trainer_state.batch_idx_in_epoch
+        # The sampler is a RandomSampler-like sampler, not a batch_sampler;
+        if not isinstance(sampler, ReproducibleBatchSampler):
+            if dataloader_args.drop_last:
+                batch_idx_in_epoch = len(
+                    sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
+            else:
+                batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
+                                     (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
+        # The sampler is a batch_sampler;
+        else:
+            batch_idx_in_epoch = sampler.batch_idx_in_epoch
 
-        logger.debug("Save optimizer state dict")
-        states["optimizers_state_dict"] = optimizers_state_dict
-        torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
+        states["batch_idx_in_epoch"] = batch_idx_in_epoch
+        return states
 
     def get_optimizer_state(self):
         optimizers_state_dict = {}
@@ -262,7 +304,7 @@ class TorchDriver(Driver):
         if should_load_model:
             self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)
 
-        # 3. Load fp16 state
+        # 3. Load the fp16 state
        if "grad_scaler_state_dict" in states:
             grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
             if not isinstance(self.grad_scaler, DummyGradScaler):
@@ -273,40 +315,9 @@ class TorchDriver(Driver):
                 f"the training process may be unstable.")
 
         # 4. Restore the sampler's state;
-        dataloader_args = self.get_dataloader_args(dataloader)
-        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
-            sampler = dataloader_args.batch_sampler
-        elif isinstance(dataloader_args.sampler, ReproducibleSampler):
-            sampler = dataloader_args.sampler
-        elif isinstance(dataloader_args.sampler, TorchRandomSampler):
-            sampler = RandomSampler(dataloader_args.sampler.data_source)
-            logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
-        elif self.is_distributed():
-            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our"
-                               "`ReproducibleSampler`.")
-        else:
-            sampler = ReproduceBatchSampler(
-                batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
-                batch_size=dataloader_args.batch_size,
-                drop_last=dataloader_args.drop_last
-            )
-        sampler.load_state_dict(states.pop('sampler_states'))
-        states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
-
-        # 4. Adjust trainer_state.batch_idx_in_epoch
-        # The sampler is a RandomSampler-like sampler, not a batch_sampler;
-        if not isinstance(sampler, ReproducibleBatchSampler):
-            if dataloader_args.drop_last:
-                batch_idx_in_epoch = len(
-                    sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
-            else:
-                batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
-                                     (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
-        # The sampler is a batch_sampler;
-        else:
-            batch_idx_in_epoch = sampler.batch_idx_in_epoch
-
-        states["batch_idx_in_epoch"] = batch_idx_in_epoch
+        sampler_states = states.pop('sampler_states')
+        states_ret = self.load_sampler_state(dataloader, sampler_states)
+        states.update(states_ret)
 
         return states
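For reference, the batch_idx_in_epoch arithmetic that the new load_sampler_state applies when the restored sampler is a sample-level sampler (not a ReproducibleBatchSampler) can be checked in isolation. The sketch below is not part of the patch; compute_batch_idx_in_epoch is a hypothetical helper name used only to illustrate the two branches: with drop_last the trailing partial batch is ignored (floor division), otherwise batches are counted with ceiling division, and the batches still left are subtracted from the total.

    # Standalone sketch (not part of the patch) of the batch_idx_in_epoch formula used above.
    def compute_batch_idx_in_epoch(num_samples, num_left_samples, batch_size, drop_last):
        # batches already consumed = total batches in the epoch - batches still left
        if drop_last:
            return num_samples // batch_size - num_left_samples // batch_size
        return (num_samples + batch_size - 1) // batch_size \
            - (num_left_samples + batch_size - 1) // batch_size

    # 100 samples, batch_size=8, 36 samples not yet consumed -> 8 batches already consumed
    assert compute_batch_idx_in_epoch(100, 36, 8, drop_last=True) == 8    # 12 - 4
    assert compute_batch_idx_in_epoch(100, 36, 8, drop_last=False) == 8   # 13 - 5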