1. 修复 palm,gpt3,mplug 模型存在的 finetune 后保存 checkpoint 与原有 checkpoint key 字段存在区别无法使用 from_pretrained 导入的问题 2. 调整 test_finetune_mplug.py 为只保存训练结束时的 checkpoint,减少 ci 耗时 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10016517master
@@ -1867,11 +1867,13 @@ class MPlug(PreTrainedModel): | |||
ModelFile.TORCH_MODEL_BIN_FILE) | |||
checkpoint = torch.load(checkpoint_path, map_location='cpu') | |||
if 'model' in checkpoint: | |||
state_dict = checkpoint['model'] | |||
else: | |||
state_dict = checkpoint['module'] | |||
checkpoint = checkpoint['model'] | |||
checkpoint = { | |||
k.replace('model.', ''): v | |||
for k, v in checkpoint.items() | |||
} | |||
msg = model.load_state_dict(state_dict, strict=False) | |||
msg = model.load_state_dict(checkpoint, strict=False) | |||
print('load checkpoint from %s' % checkpoint_path) | |||
print(msg) | |||
return model | |||
@@ -339,5 +339,9 @@ class GPT3Model(PreTrainedModel): | |||
state_dict_file = os.path.join(pretrained_model_name_or_path, | |||
ModelFile.TORCH_MODEL_BIN_FILE) | |||
state_dict = torch.load(state_dict_file) | |||
state_dict = { | |||
k.replace('model.language_model', 'language_model'): v | |||
for k, v in state_dict.items() | |||
} | |||
model.load_state_dict(state_dict) | |||
return model |
@@ -592,11 +592,11 @@ class AbsSummarizer(PalmPreTrainedModel): # Model | |||
self.generator.dense.weight = self.decoder.embeddings.weight | |||
if checkpoint is not None: | |||
for key in list(checkpoint['model'].keys()): | |||
checkpoint['model'][key.replace('module.', | |||
'')] = checkpoint['model'][key] | |||
msg = self.load_state_dict(checkpoint['model'], strict=False) | |||
print(msg) | |||
if 'model' in checkpoint: | |||
checkpoint = checkpoint['model'] | |||
for key in list(checkpoint.keys()): | |||
checkpoint[key.replace('model.palm.', '')] = checkpoint[key] | |||
self.load_state_dict(checkpoint, strict=False) | |||
else: | |||
for module in self.decoder.modules(): | |||
if isinstance(module, (nn.Linear, nn.Embedding)): | |||
@@ -734,7 +734,7 @@ class PalmForConditionalGeneration(PalmPreTrainedModel): | |||
return addict.Dict(loss=loss) | |||
class Translator(nn.Module): | |||
class Translator(object): | |||
""" | |||
Uses a model to translate a batch of sentences. | |||
""" | |||
@@ -1298,8 +1298,8 @@ class Translator(nn.Module): | |||
return results | |||
def forward(self, input_ids: torch.Tensor, | |||
attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
def __call__(self, input_ids: torch.Tensor, | |||
attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
batch = self.Batch( | |||
batch_size=input_ids.size()[0], | |||
src=input_ids, | |||
@@ -41,6 +41,18 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
shutil.rmtree(self.tmp_dir) | |||
super().tearDown() | |||
def _cfg_modify_fn(self, cfg): | |||
cfg.train.hooks = [{ | |||
'type': 'CheckpointHook', | |||
'interval': self.max_epochs | |||
}, { | |||
'type': 'TextLoggerHook', | |||
'interval': 1 | |||
}, { | |||
'type': 'IterTimerHook' | |||
}] | |||
return cfg | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_trainer_with_caption(self): | |||
kwargs = dict( | |||
@@ -48,15 +60,12 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
train_dataset=self.train_dataset, | |||
eval_dataset=self.test_dataset, | |||
max_epochs=self.max_epochs, | |||
work_dir=self.tmp_dir) | |||
work_dir=self.tmp_dir, | |||
cfg_modify_fn=self._cfg_modify_fn) | |||
trainer: EpochBasedTrainer = build_trainer( | |||
name=Trainers.nlp_base_trainer, default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(self.tmp_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
for i in range(self.max_epochs): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_trainer_with_caption_with_model_and_args(self): | |||
@@ -86,15 +95,12 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
train_dataset=self.train_dataset, | |||
eval_dataset=self.test_dataset, | |||
max_epochs=self.max_epochs, | |||
work_dir=self.tmp_dir) | |||
work_dir=self.tmp_dir, | |||
cfg_modify_fn=self._cfg_modify_fn) | |||
trainer: EpochBasedTrainer = build_trainer( | |||
name=Trainers.nlp_base_trainer, default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(self.tmp_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
for i in range(self.max_epochs): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_trainer_with_vqa_with_model_and_args(self): | |||
@@ -124,15 +130,12 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
train_dataset=self.train_dataset, | |||
eval_dataset=self.test_dataset, | |||
max_epochs=self.max_epochs, | |||
work_dir=self.tmp_dir) | |||
work_dir=self.tmp_dir, | |||
cfg_modify_fn=self._cfg_modify_fn) | |||
trainer: EpochBasedTrainer = build_trainer( | |||
name=Trainers.nlp_base_trainer, default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(self.tmp_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
for i in range(self.max_epochs): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_trainer_with_retrieval_with_model_and_args(self): | |||