From 76bc51eadd7d3e257f2e439c836c2c48155271d8 Mon Sep 17 00:00:00 2001 From: "wenmeng.zwm" Date: Fri, 10 Jun 2022 16:22:28 +0800 Subject: [PATCH] [to #42362853] support multimodel in pipeline Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8997599 --- modelscope/pipelines/base.py | 10 ++++- modelscope/pipelines/builder.py | 51 +++++++++++++++++++++---- modelscope/pipelines/util.py | 52 ++++++++++++++++--------- tests/pipelines/test_builder.py | 68 +++++++++++++++++++++++++++++++++ 4 files changed, 153 insertions(+), 28 deletions(-) create mode 100644 tests/pipelines/test_builder.py diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 2e88801a..f4d4d1b7 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -11,6 +11,7 @@ from modelscope.preprocessors import Preprocessor from modelscope.pydatasets import PyDataset from modelscope.utils.config import Config from modelscope.utils.hub import get_model_cache_dir +from modelscope.utils.logger import get_logger from .util import is_model_name Tensor = Union['torch.Tensor', 'tf.Tensor'] @@ -20,11 +21,15 @@ InputModel = Union[str, Model] output_keys = [ ] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key +logger = get_logger() + class Pipeline(ABC): def initiate_single_model(self, model): - if isinstance(model, str): + logger.info(f'initiate model from {model}') + # TODO @wenmeng.zwm replace model.startswith('damo/') with get_model + if isinstance(model, str) and model.startswith('damo/'): if not osp.exists(model): cache_path = get_model_cache_dir(model) model = cache_path if osp.exists( @@ -34,10 +39,11 @@ class Pipeline(ABC): elif isinstance(model, Model): return model else: - if model: + if model and not isinstance(model, str): raise ValueError( f'model type for single model is either str or Model, but got type {type(model)}' ) + return model def initiate_multiple_models(self, input_models: List[InputModel]): models = [] diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 6d0ec729..6495a5db 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp -from typing import Union +from typing import List, Union import json from maas_hub.file_download import model_file_download @@ -44,7 +44,7 @@ def build_pipeline(cfg: ConfigDict, def pipeline(task: str = None, - model: Union[str, Model] = None, + model: Union[str, List[str], Model, List[Model]] = None, preprocessor=None, config_file: str = None, pipeline_name: str = None, @@ -56,7 +56,7 @@ def pipeline(task: str = None, Args: task (str): Task name defining which pipeline will be returned. - model (str or obj:`Model`): model name or model object. + model (str or List[str] or obj:`Model` or obj:list[`Model`]): (list of) model name or model object. preprocessor: preprocessor object. config_file (str, optional): path to config file. pipeline_name (str, optional): pipeline class name or alias name. @@ -68,23 +68,42 @@ def pipeline(task: str = None, Examples: ```python + >>> # Using default model for a task >>> p = pipeline('image-classification') - >>> p = pipeline('text-classification', model='distilbert-base-uncased') - >>> # Using model object + >>> # Using pipeline with a model name + >>> p = pipeline('text-classification', model='damo/distilbert-base-uncased') + >>> # Using pipeline with a model object >>> resnet = Model.from_pretrained('Resnet') >>> p = pipeline('image-classification', model=resnet) + >>> # Using pipeline with a list of model names + >>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/auto-tts2']) """ if task is None and pipeline_name is None: raise ValueError('task or pipeline_name is required') if pipeline_name is None: # get default pipeline for this task - pipeline_name, default_model_repo = get_default_pipeline_info(task) + if isinstance(model, str) \ + or (isinstance(model, list) and isinstance(model[0], str)): + + # if is_model_name(model): + if (isinstance(model, str) and model.startswith('damo/')) \ + or (isinstance(model, list) and model[0].startswith('damo/')) \ + or (isinstance(model, str) and osp.exists(model)): + # TODO @wenmeng.zwm add support when model is a str of modelhub address + # read pipeline info from modelhub configuration file. + pipeline_name, default_model_repo = get_default_pipeline_info( + task) + else: + pipeline_name = get_pipeline_by_model_name(task, model) + else: + pipeline_name, default_model_repo = get_default_pipeline_info(task) + if model is None: model = default_model_repo - assert isinstance(model, (type(None), str, Model)), \ - f'model should be either None, str or Model, but got {type(model)}' + assert isinstance(model, (type(None), str, Model, list)), \ + f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' cfg = ConfigDict(type=pipeline_name, model=model) @@ -134,3 +153,19 @@ def get_default_pipeline_info(task): else: pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] return pipeline_name, default_model + + +def get_pipeline_by_model_name(task: str, model: Union[str, List[str]]): + """ Get pipeline name by task name and model name + + Args: + task (str): task name. + model (str| list[str]): model names + """ + if isinstance(model, str): + model_key = model + else: + model_key = '_'.join(model) + assert model_key in PIPELINES.modules[task], \ + f'pipeline for task {task} model {model_key} not found.' + return model_key diff --git a/modelscope/pipelines/util.py b/modelscope/pipelines/util.py index 92ad6af4..caef6b22 100644 --- a/modelscope/pipelines/util.py +++ b/modelscope/pipelines/util.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os import os.path as osp +from typing import List, Union import json from maas_hub.file_download import model_file_download @@ -8,23 +9,38 @@ from maas_hub.file_download import model_file_download from modelscope.utils.constant import CONFIGFILE -def is_model_name(model): - if osp.exists(model): - if osp.exists(osp.join(model, CONFIGFILE)): - return True +def is_model_name(model: Union[str, List]): + """ whether model is a valid modelhub path + """ + + def is_model_name_impl(model): + if osp.exists(model): + if osp.exists(osp.join(model, CONFIGFILE)): + return True + else: + return False else: - return False + # try: + # cfg_file = model_file_download(model, CONFIGFILE) + # except Exception: + # cfg_file = None + # TODO @wenmeng.zwm use exception instead of + # following tricky logic + cfg_file = model_file_download(model, CONFIGFILE) + with open(cfg_file, 'r') as infile: + cfg = json.load(infile) + if 'Code' in cfg: + return False + else: + return True + + if isinstance(model, str): + return is_model_name_impl(model) else: - # try: - # cfg_file = model_file_download(model, CONFIGFILE) - # except Exception: - # cfg_file = None - # TODO @wenmeng.zwm use exception instead of - # following tricky logic - cfg_file = model_file_download(model, CONFIGFILE) - with open(cfg_file, 'r') as infile: - cfg = json.load(infile) - if 'Code' in cfg: - return False - else: - return True + results = [is_model_name_impl(m) for m in model] + all_true = all(results) + any_true = any(results) + if any_true and not all_true: + raise ValueError('some model are hub address, some are not') + + return all_true diff --git a/tests/pipelines/test_builder.py b/tests/pipelines/test_builder.py new file mode 100644 index 00000000..a0b15a32 --- /dev/null +++ b/tests/pipelines/test_builder.py @@ -0,0 +1,68 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest +from asyncio import Task +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import PIL + +from modelscope.models.base import Model +from modelscope.pipelines import Pipeline, pipeline +from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.registry import default_group + +logger = get_logger() + + +@PIPELINES.register_module( + group_key=Tasks.image_tagging, module_name='custom_single_model') +class CustomSingleModelPipeline(Pipeline): + + def __init__(self, + config_file: str = None, + model: List[Union[str, Model]] = None, + preprocessor=None, + **kwargs): + super().__init__(config_file, model, preprocessor, **kwargs) + assert isinstance(model, str), 'model is not str' + print(model) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return super().postprocess(inputs) + + +@PIPELINES.register_module( + group_key=Tasks.image_tagging, module_name='model1_model2') +class CustomMultiModelPipeline(Pipeline): + + def __init__(self, + config_file: str = None, + model: List[Union[str, Model]] = None, + preprocessor=None, + **kwargs): + super().__init__(config_file, model, preprocessor, **kwargs) + assert isinstance(model, list), 'model is not list' + for m in model: + assert isinstance(m, str), 'submodel is not str' + print(m) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return super().postprocess(inputs) + + +class PipelineInterfaceTest(unittest.TestCase): + + def test_single_model(self): + pipe = pipeline(Tasks.image_tagging, model='custom_single_model') + assert isinstance(pipe, CustomSingleModelPipeline) + + def test_multi_model(self): + pipe = pipeline(Tasks.image_tagging, model=['model1', 'model2']) + assert isinstance(pipe, CustomMultiModelPipeline) + + +if __name__ == '__main__': + unittest.main()