|
|
@@ -38,7 +38,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): |
|
|
|
|
|
|
|
self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base' |
|
|
|
self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large' |
|
|
|
self.gpt3_poetry_large_model_id = 'damo/nlp_gpt3_poetry-generation_chinese-large' |
|
|
|
self.gpt3_input = '《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,' |
|
|
|
self.gpt3_poetry_input = '天生我材必有用,' |
|
|
|
|
|
|
|
def run_pipeline_with_model_instance(self, model_id, input): |
|
|
|
model = Model.from_pretrained(model_id) |
|
|
@@ -115,6 +117,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): |
|
|
|
self.run_pipeline_with_model_instance(self.palm_model_id_en, |
|
|
|
self.palm_input_en) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_gpt_poetry_large_with_model_name(self): |
|
|
|
self.run_pipeline_with_model_id(self.gpt3_poetry_large_model_id, |
|
|
|
self.gpt3_poetry_input) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_gpt_base_with_model_instance(self): |
|
|
|
self.run_pipeline_with_model_instance(self.gpt3_base_model_id, |
|
|
@@ -125,6 +132,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): |
|
|
|
self.run_pipeline_with_model_instance(self.gpt3_large_model_id, |
|
|
|
self.gpt3_input) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_gpt_poetry_large_with_model_instance(self): |
|
|
|
self.run_pipeline_with_model_instance(self.gpt3_poetry_large_model_id, |
|
|
|
self.gpt3_poetry_input) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_run_palm(self): |
|
|
|
for model_id, input in ((self.palm_model_id_zh_base, |
|
|
|