Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8997599master
@@ -11,6 +11,7 @@ from modelscope.preprocessors import Preprocessor | |||
from modelscope.pydatasets import PyDataset | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.hub import get_model_cache_dir | |||
from modelscope.utils.logger import get_logger | |||
from .util import is_model_name | |||
Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
@@ -20,11 +21,15 @@ InputModel = Union[str, Model] | |||
output_keys = [ | |||
] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key | |||
logger = get_logger() | |||
class Pipeline(ABC): | |||
def initiate_single_model(self, model): | |||
if isinstance(model, str): | |||
logger.info(f'initiate model from {model}') | |||
# TODO @wenmeng.zwm replace model.startswith('damo/') with get_model | |||
if isinstance(model, str) and model.startswith('damo/'): | |||
if not osp.exists(model): | |||
cache_path = get_model_cache_dir(model) | |||
model = cache_path if osp.exists( | |||
@@ -34,10 +39,11 @@ class Pipeline(ABC): | |||
elif isinstance(model, Model): | |||
return model | |||
else: | |||
if model: | |||
if model and not isinstance(model, str): | |||
raise ValueError( | |||
f'model type for single model is either str or Model, but got type {type(model)}' | |||
) | |||
return model | |||
def initiate_multiple_models(self, input_models: List[InputModel]): | |||
models = [] | |||
@@ -1,7 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
from typing import Union | |||
from typing import List, Union | |||
import json | |||
from maas_hub.file_download import model_file_download | |||
@@ -44,7 +44,7 @@ def build_pipeline(cfg: ConfigDict, | |||
def pipeline(task: str = None, | |||
model: Union[str, Model] = None, | |||
model: Union[str, List[str], Model, List[Model]] = None, | |||
preprocessor=None, | |||
config_file: str = None, | |||
pipeline_name: str = None, | |||
@@ -56,7 +56,7 @@ def pipeline(task: str = None, | |||
Args: | |||
task (str): Task name defining which pipeline will be returned. | |||
model (str or obj:`Model`): model name or model object. | |||
model (str or List[str] or obj:`Model` or obj:list[`Model`]): (list of) model name or model object. | |||
preprocessor: preprocessor object. | |||
config_file (str, optional): path to config file. | |||
pipeline_name (str, optional): pipeline class name or alias name. | |||
@@ -68,23 +68,42 @@ def pipeline(task: str = None, | |||
Examples: | |||
```python | |||
>>> # Using default model for a task | |||
>>> p = pipeline('image-classification') | |||
>>> p = pipeline('text-classification', model='distilbert-base-uncased') | |||
>>> # Using model object | |||
>>> # Using pipeline with a model name | |||
>>> p = pipeline('text-classification', model='damo/distilbert-base-uncased') | |||
>>> # Using pipeline with a model object | |||
>>> resnet = Model.from_pretrained('Resnet') | |||
>>> p = pipeline('image-classification', model=resnet) | |||
>>> # Using pipeline with a list of model names | |||
>>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/auto-tts2']) | |||
""" | |||
if task is None and pipeline_name is None: | |||
raise ValueError('task or pipeline_name is required') | |||
if pipeline_name is None: | |||
# get default pipeline for this task | |||
pipeline_name, default_model_repo = get_default_pipeline_info(task) | |||
if isinstance(model, str) \ | |||
or (isinstance(model, list) and isinstance(model[0], str)): | |||
# if is_model_name(model): | |||
if (isinstance(model, str) and model.startswith('damo/')) \ | |||
or (isinstance(model, list) and model[0].startswith('damo/')) \ | |||
or (isinstance(model, str) and osp.exists(model)): | |||
# TODO @wenmeng.zwm add support when model is a str of modelhub address | |||
# read pipeline info from modelhub configuration file. | |||
pipeline_name, default_model_repo = get_default_pipeline_info( | |||
task) | |||
else: | |||
pipeline_name = get_pipeline_by_model_name(task, model) | |||
else: | |||
pipeline_name, default_model_repo = get_default_pipeline_info(task) | |||
if model is None: | |||
model = default_model_repo | |||
assert isinstance(model, (type(None), str, Model)), \ | |||
f'model should be either None, str or Model, but got {type(model)}' | |||
assert isinstance(model, (type(None), str, Model, list)), \ | |||
f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' | |||
cfg = ConfigDict(type=pipeline_name, model=model) | |||
@@ -134,3 +153,19 @@ def get_default_pipeline_info(task): | |||
else: | |||
pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] | |||
return pipeline_name, default_model | |||
def get_pipeline_by_model_name(task: str, model: Union[str, List[str]]):
    """Get the pipeline name registered for a task and model name(s).

    The lookup key is the model name itself when ``model`` is a ``str``;
    otherwise the model names joined with ``'_'`` (e.g. ``['model1',
    'model2']`` -> ``'model1_model2'``). The key must be present in
    ``PIPELINES.modules[task]``.

    Args:
        task (str): task name, used to index ``PIPELINES.modules``.
        model (str | list[str]): model name, or list of model names.

    Returns:
        str: the lookup key, which is returned as the pipeline name.

    Raises:
        AssertionError: if ``task`` has no registered pipelines or the key
            is not registered for ``task``.
    """
    model_key = model if isinstance(model, str) else '_'.join(model)

    # An unknown task is reported with the same "not found" message as an
    # unknown model instead of leaking a bare KeyError.
    try:
        registered = PIPELINES.modules[task]
    except KeyError:
        registered = ()

    # Raised explicitly rather than via `assert` so the check is not stripped
    # under `python -O`; AssertionError is kept for backward compatibility.
    if model_key not in registered:
        raise AssertionError(
            f'pipeline for task {task} model {model_key} not found.')
    return model_key
@@ -1,6 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
from typing import List, Union | |||
import json | |||
from maas_hub.file_download import model_file_download | |||
@@ -8,23 +9,38 @@ from maas_hub.file_download import model_file_download | |||
from modelscope.utils.constant import CONFIGFILE | |||
def is_model_name(model): | |||
if osp.exists(model): | |||
if osp.exists(osp.join(model, CONFIGFILE)): | |||
return True | |||
def is_model_name(model: Union[str, List]): | |||
""" whether model is a valid modelhub path | |||
""" | |||
def is_model_name_impl(model): | |||
if osp.exists(model): | |||
if osp.exists(osp.join(model, CONFIGFILE)): | |||
return True | |||
else: | |||
return False | |||
else: | |||
return False | |||
# try: | |||
# cfg_file = model_file_download(model, CONFIGFILE) | |||
# except Exception: | |||
# cfg_file = None | |||
# TODO @wenmeng.zwm use exception instead of | |||
# following tricky logic | |||
cfg_file = model_file_download(model, CONFIGFILE) | |||
with open(cfg_file, 'r') as infile: | |||
cfg = json.load(infile) | |||
if 'Code' in cfg: | |||
return False | |||
else: | |||
return True | |||
if isinstance(model, str): | |||
return is_model_name_impl(model) | |||
else: | |||
# try: | |||
# cfg_file = model_file_download(model, CONFIGFILE) | |||
# except Exception: | |||
# cfg_file = None | |||
# TODO @wenmeng.zwm use exception instead of | |||
# following tricky logic | |||
cfg_file = model_file_download(model, CONFIGFILE) | |||
with open(cfg_file, 'r') as infile: | |||
cfg = json.load(infile) | |||
if 'Code' in cfg: | |||
return False | |||
else: | |||
return True | |||
results = [is_model_name_impl(m) for m in model] | |||
all_true = all(results) | |||
any_true = any(results) | |||
if any_true and not all_true: | |||
raise ValueError('some model are hub address, some are not') | |||
return all_true |
@@ -0,0 +1,68 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from asyncio import Task | |||
from typing import Any, Dict, List, Tuple, Union | |||
import numpy as np | |||
import PIL | |||
from modelscope.models.base import Model | |||
from modelscope.pipelines import Pipeline, pipeline | |||
from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.registry import default_group | |||
logger = get_logger() | |||
@PIPELINES.register_module(
    group_key=Tasks.image_tagging, module_name='custom_single_model')
class CustomSingleModelPipeline(Pipeline):
    """Test pipeline registered as 'custom_single_model' for image tagging.

    The constructor checks that the ``model`` argument it receives is a plain
    ``str``.
    """

    def __init__(self,
                 config_file: str = None,
                 model: str = None,
                 preprocessor=None,
                 **kwargs):
        """Forward all arguments to ``Pipeline`` and verify ``model`` is a str.

        Args:
            config_file (str, optional): forwarded to the base class.
            model (str): model name; anything that is not a ``str`` fails
                the assertion below.
            preprocessor: forwarded to the base class.
            **kwargs: forwarded to the base class.
        """
        super().__init__(config_file, model, preprocessor, **kwargs)
        assert isinstance(model, str), 'model is not str'
        # Use the module logger rather than print so output follows the
        # project's logging configuration.
        logger.info(model)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Delegate to the base class implementation unchanged."""
        return super().postprocess(inputs)
@PIPELINES.register_module(
    group_key=Tasks.image_tagging, module_name='model1_model2')
class CustomMultiModelPipeline(Pipeline):
    """Test pipeline registered as 'model1_model2' for image tagging.

    The constructor checks that the ``model`` argument it receives is a list
    whose entries are all ``str``.
    """

    def __init__(self,
                 config_file: str = None,
                 model: List[str] = None,
                 preprocessor=None,
                 **kwargs):
        """Forward all arguments to ``Pipeline`` and verify ``model`` is a list of str.

        Args:
            config_file (str, optional): forwarded to the base class.
            model (list[str]): model names; a non-list value or a non-str
                entry fails the assertions below.
            preprocessor: forwarded to the base class.
            **kwargs: forwarded to the base class.
        """
        super().__init__(config_file, model, preprocessor, **kwargs)
        assert isinstance(model, list), 'model is not list'
        for m in model:
            assert isinstance(m, str), 'submodel is not str'
            # Use the module logger rather than print so output follows the
            # project's logging configuration.
            logger.info(m)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Delegate to the base class implementation unchanged."""
        return super().postprocess(inputs)
class PipelineInterfaceTest(unittest.TestCase):
    """Checks which pipeline class ``pipeline()`` returns for given model name(s)."""

    def test_single_model(self):
        """A single model name returns the pipeline registered under that name."""
        pipe = pipeline(Tasks.image_tagging, model='custom_single_model')
        # assertIsInstance is not stripped under `python -O` and reports the
        # actual type on failure, unlike a bare assert.
        self.assertIsInstance(pipe, CustomSingleModelPipeline)

    def test_multi_model(self):
        """A list of model names returns the pipeline registered as 'model1_model2'."""
        pipe = pipeline(Tasks.image_tagging, model=['model1', 'model2'])
        self.assertIsInstance(pipe, CustomMultiModelPipeline)
if __name__ == '__main__':
    # Run the test cases in this module when it is executed as a script.
    unittest.main()