|
|
|
@@ -1,6 +1,7 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
|
|
|
|
import os |
|
|
|
from os import path as osp |
|
|
|
from typing import Any, Dict |
|
|
|
|
|
|
|
import json |
|
|
|
@@ -23,7 +24,8 @@ from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer |
|
|
|
from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg |
|
|
|
from modelscope.models.multi_modal.ofa.generate.search import Sampling |
|
|
|
from modelscope.models.multi_modal.ofa.generate.utils import move_to_device |
|
|
|
from modelscope.utils.constant import Tasks |
|
|
|
from modelscope.utils.config import Config |
|
|
|
from modelscope.utils.constant import ModelFile, Tasks |
|
|
|
|
|
|
|
try: |
|
|
|
from torchvision.transforms import InterpolationMode |
|
|
|
@@ -133,6 +135,8 @@ class OfaForTextToImageSynthesis(Model): |
|
|
|
super().__init__(model_dir=model_dir, *args, **kwargs) |
|
|
|
# Initialize ofa |
|
|
|
model = OFAModel.from_pretrained(model_dir) |
|
|
|
self.cfg = Config.from_file( |
|
|
|
osp.join(model_dir, ModelFile.CONFIGURATION)) |
|
|
|
self.model = model.module if hasattr(model, 'module') else model |
|
|
|
self.tokenizer = OFATokenizer.from_pretrained(model_dir) |
|
|
|
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) |
|
|
|
@@ -171,6 +175,8 @@ class OfaForTextToImageSynthesis(Model): |
|
|
|
'gen_code': True, |
|
|
|
'constraint_range': '50265,58457' |
|
|
|
} |
|
|
|
if hasattr(self.cfg.model, 'beam_search'): |
|
|
|
sg_args.update(self.cfg.model.beam_search) |
|
|
|
self.generator = sg.SequenceGenerator(**sg_args) |
|
|
|
|
|
|
|
def clip_tokenize(self, texts, context_length=77, truncate=False): |
|
|
|
|