
[to #42322933] fix checkpoint format

1. Fix an issue in the palm, gpt3, and mplug models where the checkpoint saved after finetuning uses key names that differ from the original checkpoint, so it cannot be loaded with from_pretrained.
2. Change test_finetune_mplug.py to save only the checkpoint at the end of training, to reduce CI time.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10016517
master
hemu.zp, yingda.chen · 3 years ago
parent commit 3d3f9b4537
4 changed files with 36 additions and 27 deletions:

1. modelscope/models/multi_modal/mplug/modeling_mplug.py (+6, -4)
2. modelscope/models/nlp/gpt3/modeling_gpt3.py (+4, -0)
3. modelscope/models/nlp/palm_v2/modeling_palm.py (+8, -8)
4. tests/trainers/test_finetune_mplug.py (+18, -15)
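
All three model fixes below follow the same pattern: the finetuned checkpoint is written from a trainer-side wrapper, so its keys carry an extra prefix (e.g. 'model.') that the bare model's state_dict does not have, and from_pretrained fails to match them. Stripping the prefix before load_state_dict restores compatibility. A minimal standalone sketch of the idea in plain PyTorch; the file name and helper function are hypothetical, not taken from this commit:

import torch

def strip_prefix(state_dict, prefix='model.'):
    # Drop a wrapper prefix from checkpoint keys so the bare model can load them.
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Hypothetical usage; 'finetuned.bin' stands in for the saved checkpoint file.
checkpoint = torch.load('finetuned.bin', map_location='cpu')
if 'model' in checkpoint:  # some trainers nest the weights under a 'model' key
    checkpoint = checkpoint['model']
state_dict = strip_prefix(checkpoint)
# model.load_state_dict(state_dict, strict=False)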

modelscope/models/multi_modal/mplug/modeling_mplug.py (+6, -4)

@@ -1867,11 +1867,13 @@ class MPlug(PreTrainedModel):
                                            ModelFile.TORCH_MODEL_BIN_FILE)
             checkpoint = torch.load(checkpoint_path, map_location='cpu')
             if 'model' in checkpoint:
-                state_dict = checkpoint['model']
-            else:
-                state_dict = checkpoint['module']
+                checkpoint = checkpoint['model']
+            checkpoint = {
+                k.replace('model.', ''): v
+                for k, v in checkpoint.items()
+            }
 
-            msg = model.load_state_dict(state_dict, strict=False)
+            msg = model.load_state_dict(checkpoint, strict=False)
             print('load checkpoint from %s' % checkpoint_path)
             print(msg)
         return model
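
A side note on the msg that gets printed above: in PyTorch, load_state_dict(strict=False) returns a named tuple of missing_keys and unexpected_keys, so the printout shows exactly which fields still fail to line up after the key rewrite. A toy illustration (standalone, not from this repository):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# 'extra.bias' simulates a stale key left over from a wrapper module.
state_dict = {'weight': torch.zeros(2, 4), 'extra.bias': torch.zeros(2)}
msg = model.load_state_dict(state_dict, strict=False)
print(msg.missing_keys)     # ['bias']
print(msg.unexpected_keys)  # ['extra.bias']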


modelscope/models/nlp/gpt3/modeling_gpt3.py (+4, -0)

@@ -339,5 +339,9 @@ class GPT3Model(PreTrainedModel):
         state_dict_file = os.path.join(pretrained_model_name_or_path,
                                        ModelFile.TORCH_MODEL_BIN_FILE)
         state_dict = torch.load(state_dict_file)
+        state_dict = {
+            k.replace('model.language_model', 'language_model'): v
+            for k, v in state_dict.items()
+        }
         model.load_state_dict(state_dict)
         return model

modelscope/models/nlp/palm_v2/modeling_palm.py (+8, -8)

@@ -592,11 +592,11 @@ class AbsSummarizer(PalmPreTrainedModel):  # Model
             self.generator.dense.weight = self.decoder.embeddings.weight
 
         if checkpoint is not None:
-            for key in list(checkpoint['model'].keys()):
-                checkpoint['model'][key.replace('module.',
-                                                '')] = checkpoint['model'][key]
-            msg = self.load_state_dict(checkpoint['model'], strict=False)
-            print(msg)
+            if 'model' in checkpoint:
+                checkpoint = checkpoint['model']
+            for key in list(checkpoint.keys()):
+                checkpoint[key.replace('model.palm.', '')] = checkpoint[key]
+            self.load_state_dict(checkpoint, strict=False)
         else:
             for module in self.decoder.modules():
                 if isinstance(module, (nn.Linear, nn.Embedding)):
@@ -734,7 +734,7 @@ class PalmForConditionalGeneration(PalmPreTrainedModel):
         return addict.Dict(loss=loss)
 
 
-class Translator(nn.Module):
+class Translator(object):
     """
     Uses a model to translate a batch of sentences.
     """
@@ -1298,8 +1298,8 @@ class Translator(nn.Module):
 
         return results
 
-    def forward(self, input_ids: torch.Tensor,
-                attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def __call__(self, input_ids: torch.Tensor,
+                 attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
         batch = self.Batch(
             batch_size=input_ids.size()[0],
             src=input_ids,

tests/trainers/test_finetune_mplug.py (+18, -15)

@@ -41,6 +41,18 @@ class TestFinetuneMPlug(unittest.TestCase):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()
 
+    def _cfg_modify_fn(self, cfg):
+        cfg.train.hooks = [{
+            'type': 'CheckpointHook',
+            'interval': self.max_epochs
+        }, {
+            'type': 'TextLoggerHook',
+            'interval': 1
+        }, {
+            'type': 'IterTimerHook'
+        }]
+        return cfg
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_with_caption(self):
         kwargs = dict(
@@ -48,15 +60,12 @@ class TestFinetuneMPlug(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir)
+            work_dir=self.tmp_dir,
+            cfg_modify_fn=self._cfg_modify_fn)
 
         trainer: EpochBasedTrainer = build_trainer(
             name=Trainers.nlp_base_trainer, default_args=kwargs)
         trainer.train()
-        results_files = os.listdir(self.tmp_dir)
-        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(self.max_epochs):
-            self.assertIn(f'epoch_{i+1}.pth', results_files)
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_caption_with_model_and_args(self):
@@ -86,15 +95,12 @@ class TestFinetuneMPlug(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir)
+            work_dir=self.tmp_dir,
+            cfg_modify_fn=self._cfg_modify_fn)
 
         trainer: EpochBasedTrainer = build_trainer(
             name=Trainers.nlp_base_trainer, default_args=kwargs)
         trainer.train()
-        results_files = os.listdir(self.tmp_dir)
-        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(self.max_epochs):
-            self.assertIn(f'epoch_{i+1}.pth', results_files)
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_vqa_with_model_and_args(self):
@@ -124,15 +130,12 @@ class TestFinetuneMPlug(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir)
+            work_dir=self.tmp_dir,
+            cfg_modify_fn=self._cfg_modify_fn)
 
         trainer: EpochBasedTrainer = build_trainer(
             name=Trainers.nlp_base_trainer, default_args=kwargs)
         trainer.train()
-        results_files = os.listdir(self.tmp_dir)
-        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(self.max_epochs):
-            self.assertIn(f'epoch_{i+1}.pth', results_files)
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_retrieval_with_model_and_args(self):
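
With _cfg_modify_fn in place, CheckpointHook fires only once, at interval == max_epochs, so a single checkpoint is written per run instead of one per epoch, which is why the per-epoch epoch_{i+1}.pth assertions were dropped. If the new behaviour were to be asserted explicitly (this commit does not), a sketch along these lines would do it, assuming CheckpointHook keeps the epoch_{n}.pth naming implied by the removed assertions:

# Hypothetical extra check, not part of this commit.
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'epoch_{self.max_epochs}.pth', results_files)
for i in range(self.max_epochs - 1):
    self.assertNotIn(f'epoch_{i + 1}.pth', results_files)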

