Browse Source

重新修改了断点重训的逻辑,主要修改了 trainer.save/load 和 driver.save 和 load 函数

tags/v1.0.0alpha
YWMditto 3 years ago
parent
commit
a376eea776
3 changed files with 61 additions and 7 deletions
  1. +1
    -1
      fastNLP/core/callbacks/load_best_model_callback.py
  2. +58
    -4
      fastNLP/core/drivers/torch_driver/torch_driver.py
  3. +2
    -2
      tests/core/callbacks/test_checkpoint_callback_torch.py

+ 1
- 1
fastNLP/core/callbacks/load_best_model_callback.py View File

@@ -81,7 +81,7 @@ class LoadBestModelCallback(Callback):
real_monitor=self._real_monitor,
res=results)
if (monitor_value < self.monitor_value and self.larger_better is False) or \
(monitor_value > self.monitor_value and self.larger_better):
(monitor_value > self.monitor_value and self.larger_better):
self.monitor_value = monitor_value
if self.real_save_folder:
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,


+ 58
- 4
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -30,6 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler


class TorchDriver(Driver):
@@ -178,8 +179,28 @@ class TorchDriver(Driver):
model.load_state_dict(res.state_dict())

@rank_zero_call
def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
# 1. 保存模型的状态;
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境;

# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch;
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的
# sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`;
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif dataloader_args.sampler:
sampler = dataloader_args.sampler
else:
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")

if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
states['sampler_states'] = sampler.state_dict()
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')

# 2. 保存模型的状态;
if should_save_model:
model = self.unwrap_model()
if only_state_dict:
@@ -191,7 +212,7 @@ class TorchDriver(Driver):
torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME))
logger.debug("Save model")

# 2. 保存 optimizers 的状态;
# 3. 保存 optimizers 的状态;
optimizers_state_dict = {}
for i in range(len(self.optimizers)):
optimizer: torch.optim.Optimizer = self.optimizers[i]
@@ -203,7 +224,7 @@ class TorchDriver(Driver):
states["optimizers_state_dict"] = optimizers_state_dict
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))

def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))

# 1. 加载 optimizers 的状态;
@@ -224,6 +245,39 @@ class TorchDriver(Driver):
model.load_state_dict(res.state_dict())
logger.debug("Load model.")

# 3. 恢复 sampler 的状态;
dataloader_args = self.get_dataloader_args(dataloader)

sampler = dataloader_args.sampler
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
# 说明这里需要使用 ReproduceSampler 来弄一下了
if self.is_distributed():
raise RuntimeError(
"It is not allowed to use single device checkpoint retraining before but ddp now.")
sampler = ReproducibleBatchSampler(
batch_sampler=sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])

states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

# 4. 修改 trainer_state.batch_idx_in_epoch
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler;
if not isinstance(sampler, ReproducibleBatchSampler):
if dataloader_args.drop_last:
batch_idx_in_epoch = len(
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
else:
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
# sampler 是 batch_sampler;
else:
batch_idx_in_epoch = sampler.batch_idx_in_epoch

states["batch_idx_in_epoch"] = batch_idx_in_epoch

return states

def get_evaluate_context(self):


+ 2
- 2
tests/core/callbacks/test_checkpoint_callback_torch.py View File

@@ -316,7 +316,7 @@ def test_model_checkpoint_callback_2(
dist.destroy_process_group()


@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
@@ -466,7 +466,7 @@ def test_trainer_checkpoint_callback_1(


# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载;
@pytest.mark.parametrize("driver,device", [("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@magic_argv_env_context
def test_trainer_checkpoint_callback_2(


Loading…
Cancel
Save