From a376eea776eed523ff0a9118a2cc97d0a3134db8 Mon Sep 17 00:00:00 2001
From: YWMditto
Date: Sun, 10 Apr 2022 12:56:49 +0800
Subject: [PATCH] Rework the resume-from-checkpoint logic: the main changes are
 to trainer.save/load and the driver save/load functions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../callbacks/load_best_model_callback.py     |  2 +-
 .../core/drivers/torch_driver/torch_driver.py | 62 +++++++++++++++++--
 .../test_checkpoint_callback_torch.py         |  4 +-
 3 files changed, 61 insertions(+), 7 deletions(-)

diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py
index bd6b8e66..b4ef4e62 100644
--- a/fastNLP/core/callbacks/load_best_model_callback.py
+++ b/fastNLP/core/callbacks/load_best_model_callback.py
@@ -81,7 +81,7 @@ class LoadBestModelCallback(Callback):
                                                      real_monitor=self._real_monitor,
                                                      res=results)
         if (monitor_value < self.monitor_value and self.larger_better is False) or \
-                (monitor_value > self.monitor_value and self.larger_better):
+            (monitor_value > self.monitor_value and self.larger_better):
             self.monitor_value = monitor_value
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 0e1a45e0..96d11761 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -30,6 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
+from fastNLP.core.samplers import ReproducibleBatchSampler
 
 
 class TorchDriver(Driver):
@@ -178,8 +179,28 @@ class TorchDriver(Driver):
             model.load_state_dict(res.state_dict())
 
     @rank_zero_call
-    def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
-        # 1. Save the model's state;
+    def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+        # The `dataloader` argument is the trainer's `dataloader` attribute: the driver never mutates its own
+        # dataloaders; the dataloader's state is changed through `trainer.dataloader` instead, to fit training or evaluation.
+
+        # 1. Save the sampler's state, because we support resume training, i.e. restoring exactly to a specific batch.
+        # A PyTorch DataLoader always has a sampler; and when resuming, `replace_sampler` always replaces the dataloader's
+        # sampler with a `ReproducibleIterator`, or, on a single device, its batch_sampler with a `ReproducibleBatchSampler`;
+        dataloader_args = self.get_dataloader_args(dataloader)
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+            sampler = dataloader_args.batch_sampler
+        elif dataloader_args.sampler:
+            sampler = dataloader_args.sampler
+        else:
+            raise RuntimeError("This branch should be unreachable. Please report a bug to us.")
+
+        if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
+            states['sampler_states'] = sampler.state_dict()
+        else:
+            raise RuntimeError(
+                'The sampler has no `state_dict()` method, so training cannot be resumed from the exact batch.')
+
+        # 2. Save the model's state;
         if should_save_model:
             model = self.unwrap_model()
             if only_state_dict:
@@ -191,7 +212,7 @@
                 torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME))
         logger.debug("Save model")
 
-        # 2. Save the optimizers' states;
+        # 3. Save the optimizers' states;
         optimizers_state_dict = {}
         for i in range(len(self.optimizers)):
             optimizer: torch.optim.Optimizer = self.optimizers[i]
@@ -203,7 +224,7 @@
         states["optimizers_state_dict"] = optimizers_state_dict
         torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
 
-    def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
+    def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
         states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
 
         # 1. Load the optimizers' states;
@@ -224,6 +245,39 @@
                 model.load_state_dict(res.state_dict())
             logger.debug("Load model.")
 
+        # 3. Restore the sampler's state;
+        dataloader_args = self.get_dataloader_args(dataloader)
+
+        sampler = dataloader_args.sampler
+        if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
+            # the sampler cannot restore its own state, so wrap it with a ReproducibleBatchSampler here;
+            if self.is_distributed():
+                raise RuntimeError(
+                    "Resuming with a DDP driver from a checkpoint saved by a single-device driver is not allowed.")
+            sampler = ReproducibleBatchSampler(
+                batch_sampler=sampler,
+                batch_size=dataloader_args.batch_size,
+                drop_last=dataloader_args.drop_last
+            )
+        sampler.load_state_dict(states['sampler_states'])
+
+        states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
+
+        # 4. Update trainer_state.batch_idx_in_epoch;
+        # here the sampler is a plain sampler such as RandomSampler, not a batch_sampler;
+        if not isinstance(sampler, ReproducibleBatchSampler):
+            if dataloader_args.drop_last:
+                batch_idx_in_epoch = len(sampler) // dataloader_args.batch_size \
+                    - sampler.num_left_samples // dataloader_args.batch_size
+            else:
+                batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
+                    (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
+        # here the sampler is a batch_sampler;
+        else:
+            batch_idx_in_epoch = sampler.batch_idx_in_epoch
+
+        states["batch_idx_in_epoch"] = batch_idx_in_epoch
+
         return states
 
     def get_evaluate_context(self):
diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py
index 759135f0..f7cc6e5f 100644
--- a/tests/core/callbacks/test_checkpoint_callback_torch.py
+++ b/tests/core/callbacks/test_checkpoint_callback_torch.py
@@ -316,7 +316,7 @@ def test_model_checkpoint_callback_2(
     dist.destroy_process_group()
 
 
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @pytest.mark.parametrize("only_state_dict", [True, False])
 @magic_argv_env_context
@@ -466,7 +466,7 @@ def test_trainer_checkpoint_callback_1(
 
 
 # Test saving and loading of huggingface transformers models by writing custom model_save_fn and model_load_fn;
-@pytest.mark.parametrize("driver,device", [("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
("torch_ddp", [0, 1]), ("torch", 1) @pytest.mark.parametrize("version", [0, 1]) @magic_argv_env_context def test_trainer_checkpoint_callback_2(