|
|
@@ -190,7 +190,30 @@ class TorchDriver(Driver): |
|
|
|
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 |
|
|
|
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; |
|
|
|
|
|
|
|
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; |
|
|
|
# 1. sampler 的状态; |
|
|
|
num_consumed_batches = states.pop('num_consumed_batches') |
|
|
|
states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches) |
|
|
|
|
|
|
|
# 2. 保存模型的状态; |
|
|
|
if should_save_model: |
|
|
|
if not os.path.exists(folder): |
|
|
|
os.mkdir(folder) |
|
|
|
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) |
|
|
|
self.save_model(model_path, only_state_dict=only_state_dict) |
|
|
|
|
|
|
|
# 3. 保存 optimizers 的状态; |
|
|
|
states["optimizers_state_dict"] = self.get_optimizer_state() |
|
|
|
logger.debug("Save optimizer state dict.") |
|
|
|
|
|
|
|
# 4. 保存fp16的状态 |
|
|
|
if not isinstance(self.grad_scaler, DummyGradScaler): |
|
|
|
grad_scaler_state_dict = self.grad_scaler.state_dict() |
|
|
|
states['grad_scaler_state_dict'] = grad_scaler_state_dict |
|
|
|
|
|
|
|
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) |
|
|
|
|
|
|
|
def get_sampler_state(self, dataloader, num_consumed_batches): |
|
|
|
# 因为我们支持 resume training,即精确恢复到具体的一个 batch; |
|
|
|
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 |
|
|
|
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; |
|
|
|
dataloader_args = self.get_dataloader_args(dataloader) |
|
|
@@ -200,7 +223,7 @@ class TorchDriver(Driver): |
|
|
|
sampler = dataloader_args.sampler |
|
|
|
else: |
|
|
|
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") |
|
|
|
num_consumed_batches = states.pop('num_consumed_batches') |
|
|
|
|
|
|
|
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): |
|
|
|
sampler_states = sampler.state_dict() |
|
|
|
if dataloader_args.batch_size is not None: |
|
|
@@ -209,30 +232,49 @@ class TorchDriver(Driver): |
|
|
|
else: |
|
|
|
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's " |
|
|
|
"`num_consumed_samples`, it may cause missing some samples when reload.") |
|
|
|
|
|
|
|
states['sampler_states'] = sampler_states |
|
|
|
else: |
|
|
|
raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' |
|
|
|
'state.') |
|
|
|
|
|
|
|
# 2. 保存模型的状态; |
|
|
|
if should_save_model: |
|
|
|
if not os.path.exists(folder): |
|
|
|
os.mkdir(folder) |
|
|
|
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) |
|
|
|
self.save_model(model_path, only_state_dict=only_state_dict) |
|
|
|
return sampler_states |
|
|
|
|
|
|
|
# 3. 保存 optimizers 的状态; |
|
|
|
optimizers_state_dict = self.get_optimizer_state() |
|
|
|
def load_sampler_state(self, dataloader, sampler_states): |
|
|
|
states = {} |
|
|
|
dataloader_args = self.get_dataloader_args(dataloader) |
|
|
|
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): |
|
|
|
sampler = dataloader_args.batch_sampler |
|
|
|
elif isinstance(dataloader_args.sampler, ReproducibleSampler): |
|
|
|
sampler = dataloader_args.sampler |
|
|
|
elif isinstance(dataloader_args.sampler, TorchRandomSampler): |
|
|
|
sampler = RandomSampler(dataloader_args.sampler.data_source) |
|
|
|
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") |
|
|
|
elif self.is_distributed(): |
|
|
|
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" |
|
|
|
"`ReproducibleSampler`.") |
|
|
|
else: |
|
|
|
sampler = ReproduceBatchSampler( |
|
|
|
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, |
|
|
|
batch_size=dataloader_args.batch_size, |
|
|
|
drop_last=dataloader_args.drop_last |
|
|
|
) |
|
|
|
sampler.load_state_dict(sampler_states) |
|
|
|
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) |
|
|
|
|
|
|
|
# 4. 保存fp16的状态 |
|
|
|
if not isinstance(self.grad_scaler, DummyGradScaler): |
|
|
|
grad_scaler_state_dict = self.grad_scaler.state_dict() |
|
|
|
states['grad_scaler_state_dict'] = grad_scaler_state_dict |
|
|
|
# 修改 trainer_state.batch_idx_in_epoch |
|
|
|
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; |
|
|
|
if not isinstance(sampler, ReproducibleBatchSampler): |
|
|
|
if dataloader_args.drop_last: |
|
|
|
batch_idx_in_epoch = len( |
|
|
|
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size |
|
|
|
else: |
|
|
|
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ |
|
|
|
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size |
|
|
|
# sampler 是 batch_sampler; |
|
|
|
else: |
|
|
|
batch_idx_in_epoch = sampler.batch_idx_in_epoch |
|
|
|
|
|
|
|
logger.debug("Save optimizer state dict") |
|
|
|
states["optimizers_state_dict"] = optimizers_state_dict |
|
|
|
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) |
|
|
|
states["batch_idx_in_epoch"] = batch_idx_in_epoch |
|
|
|
return states |
|
|
|
|
|
|
|
def get_optimizer_state(self): |
|
|
|
optimizers_state_dict = {} |
|
|
@@ -262,7 +304,7 @@ class TorchDriver(Driver): |
|
|
|
if should_load_model: |
|
|
|
self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) |
|
|
|
|
|
|
|
# 3. 加载fp16的状态 |
|
|
|
# 3. 加载 fp16 的状态 |
|
|
|
if "grad_scaler_state_dict" in states: |
|
|
|
grad_scaler_state_dict = states.pop("grad_scaler_state_dict") |
|
|
|
if not isinstance(self.grad_scaler, DummyGradScaler): |
|
|
@@ -273,40 +315,9 @@ class TorchDriver(Driver): |
|
|
|
f"the training process may be unstable.") |
|
|
|
|
|
|
|
# 4. 恢复 sampler 的状态; |
|
|
|
dataloader_args = self.get_dataloader_args(dataloader) |
|
|
|
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): |
|
|
|
sampler = dataloader_args.batch_sampler |
|
|
|
elif isinstance(dataloader_args.sampler, ReproducibleSampler): |
|
|
|
sampler = dataloader_args.sampler |
|
|
|
elif isinstance(dataloader_args.sampler, TorchRandomSampler): |
|
|
|
sampler = RandomSampler(dataloader_args.sampler.data_source) |
|
|
|
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") |
|
|
|
elif self.is_distributed(): |
|
|
|
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" |
|
|
|
"`ReproducibleSampler`.") |
|
|
|
else: |
|
|
|
sampler = ReproduceBatchSampler( |
|
|
|
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, |
|
|
|
batch_size=dataloader_args.batch_size, |
|
|
|
drop_last=dataloader_args.drop_last |
|
|
|
) |
|
|
|
sampler.load_state_dict(states.pop('sampler_states')) |
|
|
|
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) |
|
|
|
|
|
|
|
# 4. 修改 trainer_state.batch_idx_in_epoch |
|
|
|
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; |
|
|
|
if not isinstance(sampler, ReproducibleBatchSampler): |
|
|
|
if dataloader_args.drop_last: |
|
|
|
batch_idx_in_epoch = len( |
|
|
|
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size |
|
|
|
else: |
|
|
|
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ |
|
|
|
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size |
|
|
|
# sampler 是 batch_sampler; |
|
|
|
else: |
|
|
|
batch_idx_in_epoch = sampler.batch_idx_in_epoch |
|
|
|
|
|
|
|
states["batch_idx_in_epoch"] = batch_idx_in_epoch |
|
|
|
sampler_states = states.pop('sampler_states') |
|
|
|
states_ret = self.load_sampler_state(dataloader, sampler_states) |
|
|
|
states.update(states_ret) |
|
|
|
|
|
|
|
return states |
|
|
|
|
|
|
|