From c7b078704968c4e4ebd621ea48acabcea69e8411 Mon Sep 17 00:00:00 2001 From: "menrui.mr" Date: Thu, 27 Oct 2022 23:29:08 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=88=9D=E5=A7=8B=E5=8C=96?= =?UTF-8?q?=E8=BF=87=E7=A8=8B=E5=8F=82=E6=95=B0=E6=9C=AA=E7=94=9F=E6=95=88?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 此前文生图模型没有加载configuration.json中的参数 影响默认配置 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10558026 --- .../multi_modal/ofa_for_text_to_image_synthesis_model.py | 8 +++++++- tests/pipelines/test_ofa_tasks.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py index 8110a0f7..655d36d2 100644 --- a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py +++ b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +from os import path as osp from typing import Any, Dict import json @@ -23,7 +24,8 @@ from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg from modelscope.models.multi_modal.ofa.generate.search import Sampling from modelscope.models.multi_modal.ofa.generate.utils import move_to_device -from modelscope.utils.constant import Tasks +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks try: from torchvision.transforms import InterpolationMode @@ -133,6 +135,8 @@ class OfaForTextToImageSynthesis(Model): super().__init__(model_dir=model_dir, *args, **kwargs) # Initialize ofa model = OFAModel.from_pretrained(model_dir) + self.cfg = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) self.model = model.module if hasattr(model, 'module') else model self.tokenizer = OFATokenizer.from_pretrained(model_dir) self.tokenizer.add_tokens([''.format(i) for i in range(8192)]) @@ -171,6 +175,8 @@ class OfaForTextToImageSynthesis(Model): 'gen_code': True, 'constraint_range': '50265,58457' } + if hasattr(self.cfg.model, 'beam_search'): + sg_args.update(self.cfg.model.beam_search) self.generator = sg.SequenceGenerator(**sg_args) def clip_tokenize(self, texts, context_length=77, truncate=False): diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py index 57dcb0c3..6be70468 100644 --- a/tests/pipelines/test_ofa_tasks.py +++ b/tests/pipelines/test_ofa_tasks.py @@ -243,6 +243,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): def test_run_with_text_to_image_synthesis_with_name(self): model = 'damo/ofa_text-to-image-synthesis_coco_large_en' ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) + ofa_pipe.model.generator.beam_size = 2 example = {'text': 'a bear in the water.'} result = ofa_pipe(example) result[OutputKeys.OUTPUT_IMG].save('result.png') @@ -253,6 +254,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): model = Model.from_pretrained( 'damo/ofa_text-to-image-synthesis_coco_large_en') ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) + ofa_pipe.model.generator.beam_size = 2 example = {'text': 'a bear in the water.'} result = ofa_pipe(example) result[OutputKeys.OUTPUT_IMG].save('result.png')