From d0f8547e7ebbcd8108ee1fe83aa85230459b12de Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Wed, 26 Oct 2022 20:58:00 +0800 Subject: [PATCH] [to #42322933] Fix gpt3 loading checkpoint after finetuning. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 修复GPT-3模型无法加载finetune保存的checkpoint的问题 (Fix the GPT-3 model failing to load checkpoints saved by finetuning) 2. 为GPT-3诗词生成模型添加 ut (Add unit tests for the GPT-3 poetry generation model) Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10537209 --- modelscope/models/nlp/gpt3/backbone.py | 2 ++ tests/pipelines/test_text_generation.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/modelscope/models/nlp/gpt3/backbone.py b/modelscope/models/nlp/gpt3/backbone.py index 587c7a9d..4647428e 100644 --- a/modelscope/models/nlp/gpt3/backbone.py +++ b/modelscope/models/nlp/gpt3/backbone.py @@ -342,6 +342,8 @@ class GPT3Model(PreTrainedModel): state_dict_file = os.path.join(pretrained_model_name_or_path, ModelFile.TORCH_MODEL_BIN_FILE) state_dict = torch.load(state_dict_file) + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] state_dict = { k.replace('model.language_model', 'language_model'): v for k, v in state_dict.items() diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index ffb30090..c97f347d 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -38,7 +38,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base' self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large' + self.gpt3_poetry_large_model_id = 'damo/nlp_gpt3_poetry-generation_chinese-large' self.gpt3_input = '《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,' + self.gpt3_poetry_input = '天生我材必有用,' def run_pipeline_with_model_instance(self, model_id, input): model = Model.from_pretrained(model_id) @@ -115,6 +117,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.run_pipeline_with_model_instance(self.palm_model_id_en, self.palm_input_en) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt_poetry_large_with_model_name(self): + self.run_pipeline_with_model_id(self.gpt3_poetry_large_model_id, + self.gpt3_poetry_input) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_gpt_base_with_model_instance(self): self.run_pipeline_with_model_instance(self.gpt3_base_model_id, @@ -125,6 +132,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): self.run_pipeline_with_model_instance(self.gpt3_large_model_id, self.gpt3_input) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt_poetry_large_with_model_instance(self): + self.run_pipeline_with_model_instance(self.gpt3_poetry_large_model_id, + self.gpt3_poetry_input) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_palm(self): for model_id, input in ((self.palm_model_id_zh_base,