diff --git a/modelscope/models/nlp/gpt3/backbone.py b/modelscope/models/nlp/gpt3/backbone.py index 587c7a9d..4647428e 100644 --- a/modelscope/models/nlp/gpt3/backbone.py +++ b/modelscope/models/nlp/gpt3/backbone.py @@ -342,6 +342,8 @@ class GPT3Model(PreTrainedModel): state_dict_file = os.path.join(pretrained_model_name_or_path, ModelFile.TORCH_MODEL_BIN_FILE) state_dict = torch.load(state_dict_file) + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] state_dict = { k.replace('model.language_model', 'language_model'): v for k, v in state_dict.items() diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index ffb30090..c97f347d 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -38,7 +38,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base' self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large' + self.gpt3_poetry_large_model_id = 'damo/nlp_gpt3_poetry-generation_chinese-large' self.gpt3_input = '《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,' + self.gpt3_poetry_input = '天生我材必有用,' def run_pipeline_with_model_instance(self, model_id, input): model = Model.from_pretrained(model_id) @@ -115,6 +117,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): self.run_pipeline_with_model_instance(self.palm_model_id_en, self.palm_input_en) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt_poetry_large_with_model_name(self): + self.run_pipeline_with_model_id(self.gpt3_poetry_large_model_id, + self.gpt3_poetry_input) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_gpt_base_with_model_instance(self): self.run_pipeline_with_model_instance(self.gpt3_base_model_id, @@ -125,6 +132,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): self.run_pipeline_with_model_instance(self.gpt3_large_model_id, self.gpt3_input) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt_poetry_large_with_model_instance(self): + self.run_pipeline_with_model_instance(self.gpt3_poetry_large_model_id, + self.gpt3_poetry_input) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_palm(self): for model_id, input in ((self.palm_model_id_zh_base,