
Split TorchDriver's sampler loading and saving into separate functions

tags/v1.0.0alpha
x54-729 · 2 years ago
commit 9903d2eec1
1 changed file with 65 additions and 54 deletions:
  1. fastNLP/core/drivers/torch_driver/torch_driver.py (+65, -54)

fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -190,7 +190,30 @@ class TorchDriver(Driver):
# The `dataloader` argument passed in is the trainer's `dataloader` attribute, because the driver never modifies any of its own
# dataloaders; instead we change trainer.dataloader to switch the dataloader's state, adapting it to the training or evaluation setting;

# 1. The sampler's state, because we support resume training, i.e. resuming exactly at a specific batch;
# 1. The sampler's state;
num_consumed_batches = states.pop('num_consumed_batches')
states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches)

# 2. Save the model's state;
if should_save_model:
if not os.path.exists(folder):
os.mkdir(folder)
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
self.save_model(model_path, only_state_dict=only_state_dict)

# 3. Save the optimizers' state;
states["optimizers_state_dict"] = self.get_optimizer_state()
logger.debug("Save optimizer state dict.")

# 4. Save the fp16 state
if not isinstance(self.grad_scaler, DummyGradScaler):
grad_scaler_state_dict = self.grad_scaler.state_dict()
states['grad_scaler_state_dict'] = grad_scaler_state_dict

torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))

def get_sampler_state(self, dataloader, num_consumed_batches):
# Because we support resume training, i.e. resuming exactly at a specific batch;
# first, a pytorch DataLoader always has a sampler; moreover, when resuming from a checkpoint the dataloader's
# sampler will have been replaced with a `ReproducibleSampler` in `set_`; otherwise, in the single-card case, its batch_sampler is replaced with a `ReproducibleBatchSampler`;
dataloader_args = self.get_dataloader_args(dataloader)
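Taken together, the refactored `save` body in this hunk just assembles one flat `states` dict: sampler bookkeeping comes from `get_sampler_state`, optimizer state from `get_optimizer_state`, and the fp16 scaler state is added only when a real `GradScaler` is in use. Below is a minimal standalone sketch of that checkpoint layout; it is illustrative only, and `CHECKPOINT_FILENAME` and `save_checkpoint` are stand-ins rather than fastNLP names.

    import os
    import tempfile
    from pathlib import Path

    import torch

    CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar"  # stand-in for FASTNLP_CHECKPOINT_FILENAME

    def save_checkpoint(folder, states, sampler_states, optimizers_state_dict, grad_scaler_state=None):
        # Mirror of the refactored flow: collect the pieces separately, merge them into
        # one dict, then call torch.save exactly once.
        states = dict(states)
        states["sampler_states"] = sampler_states              # would come from get_sampler_state(...)
        states["optimizers_state_dict"] = optimizers_state_dict
        if grad_scaler_state is not None:                       # only when fp16 / a real GradScaler is used
            states["grad_scaler_state_dict"] = grad_scaler_state
        torch.save(states, Path(folder).joinpath(CHECKPOINT_FILENAME))

    with tempfile.TemporaryDirectory() as folder:
        save_checkpoint(folder,
                        states={"epoch": 3},
                        sampler_states={"num_consumed_samples": 1280},
                        optimizers_state_dict=[{"state": {}, "param_groups": []}])
        print(os.listdir(folder))  # ['fastnlp_checkpoint.pkl.tar']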
@@ -200,7 +223,7 @@ class TorchDriver(Driver):
sampler = dataloader_args.sampler
else:
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
num_consumed_batches = states.pop('num_consumed_batches')
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
sampler_states = sampler.state_dict()
if dataloader_args.batch_size is not None:
@@ -209,30 +232,49 @@ class TorchDriver(Driver):
else:
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's "
"`num_consumed_samples`, it may cause missing some samples when reload.")

states['sampler_states'] = sampler_states
else:
raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training '
'state.')

# 2. Save the model's state;
if should_save_model:
if not os.path.exists(folder):
os.mkdir(folder)
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
self.save_model(model_path, only_state_dict=only_state_dict)
return sampler_states

# 3. Save the optimizers' state;
optimizers_state_dict = self.get_optimizer_state()
def load_sampler_state(self, dataloader, sampler_states):
states = {}
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif isinstance(dataloader_args.sampler, TorchRandomSampler):
sampler = RandomSampler(dataloader_args.sampler.data_source)
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our"
"`ReproducibleSampler`.")
else:
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(sampler_states)
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

# 4. Save the fp16 state
if not isinstance(self.grad_scaler, DummyGradScaler):
grad_scaler_state_dict = self.grad_scaler.state_dict()
states['grad_scaler_state_dict'] = grad_scaler_state_dict
# Adjust trainer_state.batch_idx_in_epoch
# the sampler is a RandomSampler-like sampler, not a batch_sampler;
if not isinstance(sampler, ReproducibleBatchSampler):
if dataloader_args.drop_last:
batch_idx_in_epoch = len(
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
else:
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
# the sampler is a batch_sampler;
else:
batch_idx_in_epoch = sampler.batch_idx_in_epoch

logger.debug("Save optimizer state dict")
states["optimizers_state_dict"] = optimizers_state_dict
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
states["batch_idx_in_epoch"] = batch_idx_in_epoch
return states

def get_optimizer_state(self):
optimizers_state_dict = {}
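The `batch_idx_in_epoch` arithmetic near the end of `load_sampler_state` is easier to check with concrete numbers. Here is a small sketch of just that computation, in plain Python; `length` and `num_left_samples` are made-up values rather than a real fastNLP sampler.

    def batch_idx_in_epoch(length, num_left_samples, batch_size, drop_last):
        # How many batches have already been consumed in this epoch, reconstructed
        # from how many samples the sampler still has left.
        if drop_last:
            # the incomplete final batch never runs, so floor division on both sides
            return length // batch_size - num_left_samples // batch_size
        # ceil division: the last, possibly smaller, batch still counts
        total_batches = (length + batch_size - 1) // batch_size
        left_batches = (num_left_samples + batch_size - 1) // batch_size
        return total_batches - left_batches

    # 100 samples, batch_size 8, 36 samples not yet consumed (64 consumed = 8 full batches):
    print(batch_idx_in_epoch(100, 36, 8, drop_last=True))   # 12 - 4 = 8
    print(batch_idx_in_epoch(100, 36, 8, drop_last=False))  # 13 - 5 = 8

With `drop_last=True` both terms use floor division because the trailing partial batch is dropped; without it, both totals are rounded up so the partial batch still counts.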
@@ -262,7 +304,7 @@ class TorchDriver(Driver):
if should_load_model:
self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)

# 3. Load the fp16 state
if "grad_scaler_state_dict" in states:
grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
if not isinstance(self.grad_scaler, DummyGradScaler):
@@ -273,40 +315,9 @@ class TorchDriver(Driver):
f"the training process may be unstable.")

# 4. Restore the sampler's state;
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif isinstance(dataloader_args.sampler, TorchRandomSampler):
sampler = RandomSampler(dataloader_args.sampler.data_source)
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our"
"`ReproducibleSampler`.")
else:
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states.pop('sampler_states'))
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

# 4. Adjust trainer_state.batch_idx_in_epoch
# the sampler is a RandomSampler-like sampler, not a batch_sampler;
if not isinstance(sampler, ReproducibleBatchSampler):
if dataloader_args.drop_last:
batch_idx_in_epoch = len(
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
else:
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
# the sampler is a batch_sampler;
else:
batch_idx_in_epoch = sampler.batch_idx_in_epoch

states["batch_idx_in_epoch"] = batch_idx_in_epoch
sampler_states = states.pop('sampler_states')
states_ret = self.load_sampler_state(dataloader, sampler_states)
states.update(states_ret)

return states
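After this refactor, `load` only pops `sampler_states` from the checkpoint and merges whatever `load_sampler_state` returns, i.e. a dict carrying the rebuilt `dataloader` and the recomputed `batch_idx_in_epoch`. The following is a hedged sketch of that calling pattern; `DummyDriver` and its hard-coded values are illustrative stand-ins, not fastNLP APIs.

    # Minimal sketch of the new division of labour between load() and load_sampler_state().
    class DummyDriver:
        def load_sampler_state(self, dataloader, sampler_states):
            # The real method rebuilds a Reproducible(Batch)Sampler, calls
            # load_state_dict on it, and re-wraps the dataloader around it.
            return {"dataloader": dataloader, "batch_idx_in_epoch": 8}

        def load(self, dataloader, checkpoint):
            states = dict(checkpoint)
            sampler_states = states.pop("sampler_states")
            states.update(self.load_sampler_state(dataloader, sampler_states))
            return states

    driver = DummyDriver()
    restored = driver.load(dataloader=["fake", "loader"],
                           checkpoint={"sampler_states": {"num_consumed_samples": 64},
                                       "optimizers_state_dict": {}})
    print(sorted(restored))  # ['batch_idx_in_epoch', 'dataloader', 'optimizers_state_dict']

Keeping the merge in `load` and the sampler reconstruction in `load_sampler_state` means the same sampler logic can be reused by drivers that do not share `load`'s checkpoint layout.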


