|
|
@@ -6,16 +6,26 @@ from typing import Union |
|
|
|
import json |
|
|
|
from maas_hub.file_download import model_file_download |
|
|
|
|
|
|
|
from maas_lib.models.base import Model |
|
|
|
from maas_lib.utils.config import Config, ConfigDict |
|
|
|
from maas_lib.utils.constant import CONFIGFILE, Tasks |
|
|
|
from maas_lib.utils.registry import Registry, build_from_cfg |
|
|
|
from modelscope.models.base import Model |
|
|
|
from modelscope.utils.config import Config, ConfigDict |
|
|
|
from modelscope.utils.constant import CONFIGFILE, Tasks |
|
|
|
from modelscope.utils.registry import Registry, build_from_cfg |
|
|
|
from .base import Pipeline |
|
|
|
from .default import DEFAULT_MODEL_FOR_PIPELINE, get_default_pipeline_info |
|
|
|
from .util import is_model_name |
|
|
|
|
|
|
|
PIPELINES = Registry('pipelines') |
|
|
|
|
|
|
|
DEFAULT_MODEL_FOR_PIPELINE = { |
|
|
|
# TaskName: (pipeline_module_name, model_repo) |
|
|
|
Tasks.image_matting: ('image-matting', 'damo/image-matting-person'), |
|
|
|
Tasks.text_classification: |
|
|
|
('bert-sentiment-analysis', 'damo/bert-base-sst2'), |
|
|
|
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), |
|
|
|
Tasks.image_captioning: ('ofa', None), |
|
|
|
Tasks.image_generation: |
|
|
|
('cv_unet_person-image-cartoon', 'damo/cv_unet_image-matting_damo'), |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def build_pipeline(cfg: ConfigDict, |
|
|
|
task_name: str = None, |
|
|
@@ -84,3 +94,42 @@ def pipeline(task: str = None, |
|
|
|
cfg.preprocessor = preprocessor |
|
|
|
|
|
|
|
return build_pipeline(cfg, task_name=task) |
|
|
|
|
|
|
|
|
|
|
|
def add_default_pipeline_info(task: str, |
|
|
|
model_name: str, |
|
|
|
modelhub_name: str = None, |
|
|
|
overwrite: bool = False): |
|
|
|
""" Add default model for a task. |
|
|
|
|
|
|
|
Args: |
|
|
|
task (str): task name. |
|
|
|
model_name (str): model_name. |
|
|
|
modelhub_name (str): name for default modelhub. |
|
|
|
overwrite (bool): overwrite default info. |
|
|
|
""" |
|
|
|
if not overwrite: |
|
|
|
assert task not in DEFAULT_MODEL_FOR_PIPELINE, \ |
|
|
|
f'task {task} already has default model.' |
|
|
|
|
|
|
|
DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name) |
|
|
|
|
|
|
|
|
|
|
|
def get_default_pipeline_info(task): |
|
|
|
""" Get default info for certain task. |
|
|
|
|
|
|
|
Args: |
|
|
|
task (str): task name. |
|
|
|
|
|
|
|
Return: |
|
|
|
A tuple: first element is pipeline name(model_name), second element |
|
|
|
is modelhub name. |
|
|
|
""" |
|
|
|
|
|
|
|
if task not in DEFAULT_MODEL_FOR_PIPELINE: |
|
|
|
# support pipeline which does not register default model |
|
|
|
pipeline_name = list(PIPELINES.modules[task].keys())[0] |
|
|
|
default_model = None |
|
|
|
else: |
|
|
|
pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] |
|
|
|
return pipeline_name, default_model |