Browse Source

[to #42322933] Fix gpt3 loading checkpoint after finetuning.

1. 修复GPT-3模型无法加载finetune保存的checkpoint的问题
2. 为GPT-3诗词生成模型添加 ut
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10537209
master
hemu.zp yingda.chen 2 years ago
parent
commit
d0f8547e7e
2 changed files with 14 additions and 0 deletions
  1. +2
    -0
      modelscope/models/nlp/gpt3/backbone.py
  2. +12
    -0
      tests/pipelines/test_text_generation.py

+ 2
- 0
modelscope/models/nlp/gpt3/backbone.py View File

@@ -342,6 +342,8 @@ class GPT3Model(PreTrainedModel):
state_dict_file = os.path.join(pretrained_model_name_or_path,
ModelFile.TORCH_MODEL_BIN_FILE)
state_dict = torch.load(state_dict_file)
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
state_dict = {
k.replace('model.language_model', 'language_model'): v
for k, v in state_dict.items()


+ 12
- 0
tests/pipelines/test_text_generation.py View File

@@ -38,7 +38,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):

self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base'
self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large'
self.gpt3_poetry_large_model_id = 'damo/nlp_gpt3_poetry-generation_chinese-large'
self.gpt3_input = '《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,'
self.gpt3_poetry_input = '天生我材必有用,'

def run_pipeline_with_model_instance(self, model_id, input):
model = Model.from_pretrained(model_id)
@@ -115,6 +117,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.run_pipeline_with_model_instance(self.palm_model_id_en,
self.palm_input_en)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_gpt_poetry_large_with_model_name(self):
self.run_pipeline_with_model_id(self.gpt3_poetry_large_model_id,
self.gpt3_poetry_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_gpt_base_with_model_instance(self):
self.run_pipeline_with_model_instance(self.gpt3_base_model_id,
@@ -125,6 +132,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.run_pipeline_with_model_instance(self.gpt3_large_model_id,
self.gpt3_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_gpt_poetry_large_with_model_instance(self):
self.run_pipeline_with_model_instance(self.gpt3_poetry_large_model_id,
self.gpt3_poetry_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_palm(self):
for model_id, input in ((self.palm_model_id_zh_base,


Loading…
Cancel
Save