|
|
@@ -30,6 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device |
|
|
|
from fastNLP.envs import rank_zero_call |
|
|
|
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME |
|
|
|
from fastNLP.core.log import logger |
|
|
|
from fastNLP.core.samplers import ReproducibleBatchSampler |
|
|
|
|
|
|
|
|
|
|
|
class TorchDriver(Driver): |
|
|
@@ -178,8 +179,28 @@ class TorchDriver(Driver): |
|
|
|
model.load_state_dict(res.state_dict()) |
|
|
|
|
|
|
|
@rank_zero_call |
|
|
|
def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): |
|
|
|
        # 1. Save the model state;
|
|
|
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): |
|
|
|
        # The dataloader passed in is the trainer's dataloader attribute, because we never modify the dataloaders held by the driver itself; instead we change
|
|
|
        # the dataloader's state through trainer.dataloader, so that it fits the training or evaluation setting;
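        # A hypothetical call site, for illustration only (the `trainer.driver` / `trainer.dataloader` names
        # and the folder path below are assumptions, not part of this change):
        #     trainer.driver.save(folder=Path("checkpoints/step_100"), states=trainer_states,
        #                         dataloader=trainer.dataloader, only_state_dict=True)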
|
|
|
|
|
|
|
        # 1. The sampler state, because we support resume training, i.e. restoring training precisely to a specific batch;
|
|
|
        # First, a pytorch DataLoader is guaranteed to have a sampler; on the other hand, during checkpoint resume we always replace the dataloader's
|
|
|
        # sampler with a `ReproducibleIterator` inside `replace_sampler`; otherwise, in the single-device case, the batch_sampler is replaced with a `ReproducibleBatchSampler`;
|
|
|
dataloader_args = self.get_dataloader_args(dataloader) |
|
|
|
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): |
|
|
|
sampler = dataloader_args.batch_sampler |
|
|
|
elif dataloader_args.sampler: |
|
|
|
sampler = dataloader_args.sampler |
|
|
|
else: |
|
|
|
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") |
|
|
|
|
|
|
|
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): |
|
|
|
states['sampler_states'] = sampler.state_dict() |
|
|
|
else: |
|
|
|
raise RuntimeError( |
|
|
|
                'The sampler has no `state_dict()` method, so training cannot be resumed from the exact batch.')
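        # Note: a plain torch sampler such as torch.utils.data.RandomSampler has no `state_dict()`, so an
        # unmodified DataLoader would hit the branch above; the `replace_sampler` step described earlier is
        # what guarantees a sampler that can serialize its current position.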
|
|
|
|
|
|
|
        # 2. Save the model state;
|
|
|
if should_save_model: |
|
|
|
model = self.unwrap_model() |
|
|
|
if only_state_dict: |
|
|
@@ -191,7 +212,7 @@ class TorchDriver(Driver): |
|
|
|
torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME)) |
|
|
|
logger.debug("Save model") |
|
|
|
|
|
|
|
        # 2. Save the optimizers' state;
|
|
|
        # 3. Save the optimizers' state;
|
|
|
optimizers_state_dict = {} |
|
|
|
for i in range(len(self.optimizers)): |
|
|
|
optimizer: torch.optim.Optimizer = self.optimizers[i] |
|
|
@@ -203,7 +224,7 @@ class TorchDriver(Driver): |
|
|
|
states["optimizers_state_dict"] = optimizers_state_dict |
|
|
|
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) |
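        # Resulting on-disk layout, as implied by the code above (filenames are the FASTNLP_* constants):
        #     folder/
        #         FASTNLP_MODEL_FILENAME       - model state_dict or pickled module, if should_save_model
        #         FASTNLP_CHECKPOINT_FILENAME  - `states` extended with 'sampler_states' and 'optimizers_state_dict'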
|
|
|
|
|
|
|
def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: |
|
|
|
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: |
|
|
|
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) |
|
|
|
|
|
|
|
        # 1. Load the optimizers' state;
|
|
@@ -224,6 +245,39 @@ class TorchDriver(Driver): |
|
|
|
model.load_state_dict(res.state_dict()) |
|
|
|
logger.debug("Load model.") |
|
|
|
|
|
|
|
        # 3. Restore the sampler state;
|
|
|
dataloader_args = self.get_dataloader_args(dataloader) |
|
|
|
|
|
|
|
sampler = dataloader_args.sampler |
|
|
|
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): |
|
|
|
            # this means we need to fall back to a ReproducibleBatchSampler here;
|
|
|
if self.is_distributed(): |
|
|
|
raise RuntimeError( |
|
|
|
"It is not allowed to use single device checkpoint retraining before but ddp now.") |
|
|
|
sampler = ReproducibleBatchSampler( |
|
|
|
batch_sampler=sampler, |
|
|
|
batch_size=dataloader_args.batch_size, |
|
|
|
drop_last=dataloader_args.drop_last |
|
|
|
) |
|
|
|
sampler.load_state_dict(states['sampler_states']) |
|
|
|
|
|
|
|
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) |
|
|
|
|
|
|
|
        # 4. Update trainer_state.batch_idx_in_epoch;
|
|
|
        # the sampler is a sampler like RandomSampler, not a batch_sampler;
|
|
|
if not isinstance(sampler, ReproducibleBatchSampler): |
|
|
|
if dataloader_args.drop_last: |
|
|
|
batch_idx_in_epoch = len( |
|
|
|
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size |
|
|
|
else: |
|
|
|
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ |
|
|
|
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size |
|
|
|
        # the sampler is a batch_sampler;
|
|
|
else: |
|
|
|
batch_idx_in_epoch = sampler.batch_idx_in_epoch |
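        # Worked example for the arithmetic above (illustrative numbers only): with len(sampler) == 10,
        # batch_size == 3, drop_last == False and sampler.num_left_samples == 4:
        #     total batches      = (10 + 3 - 1) // 3 = 4
        #     batches still left = (4 + 3 - 1) // 3 = 2
        #     batch_idx_in_epoch = 4 - 2 = 2, i.e. two batches were already consumed before the checkpoint.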
|
|
|
|
|
|
|
states["batch_idx_in_epoch"] = batch_idx_in_epoch |
|
|
|
|
|
|
|
return states |
|
|
|
|
|
|
|
def get_evaluate_context(self): |
|
|
|