wenmeng.zwm 3 years ago
parent
commit
76bc51eadd
4 changed files with 153 additions and 28 deletions
  1. +8
    -2
      modelscope/pipelines/base.py
  2. +43
    -8
      modelscope/pipelines/builder.py
  3. +34
    -18
      modelscope/pipelines/util.py
  4. +68
    -0
      tests/pipelines/test_builder.py

+ 8
- 2
modelscope/pipelines/base.py View File

@@ -11,6 +11,7 @@ from modelscope.preprocessors import Preprocessor
from modelscope.pydatasets import PyDataset
from modelscope.utils.config import Config
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.logger import get_logger
from .util import is_model_name

Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -20,11 +21,15 @@ InputModel = Union[str, Model]
output_keys = [
] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key

logger = get_logger()


class Pipeline(ABC):

def initiate_single_model(self, model):
if isinstance(model, str):
logger.info(f'initiate model from {model}')
# TODO @wenmeng.zwm replace model.startswith('damo/') with get_model
if isinstance(model, str) and model.startswith('damo/'):
if not osp.exists(model):
cache_path = get_model_cache_dir(model)
model = cache_path if osp.exists(
@@ -34,10 +39,11 @@ class Pipeline(ABC):
elif isinstance(model, Model):
return model
else:
if model:
if model and not isinstance(model, str):
raise ValueError(
f'model type for single model is either str or Model, but got type {type(model)}'
)
return model

def initiate_multiple_models(self, input_models: List[InputModel]):
models = []


+ 43
- 8
modelscope/pipelines/builder.py View File

@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path as osp
from typing import Union
from typing import List, Union

import json
from maas_hub.file_download import model_file_download
@@ -44,7 +44,7 @@ def build_pipeline(cfg: ConfigDict,


def pipeline(task: str = None,
model: Union[str, Model] = None,
model: Union[str, List[str], Model, List[Model]] = None,
preprocessor=None,
config_file: str = None,
pipeline_name: str = None,
@@ -56,7 +56,7 @@ def pipeline(task: str = None,

Args:
task (str): Task name defining which pipeline will be returned.
model (str or obj:`Model`): model name or model object.
model (str or List[str] or obj:`Model` or obj:list[`Model`]): (list of) model name or model object.
preprocessor: preprocessor object.
config_file (str, optional): path to config file.
pipeline_name (str, optional): pipeline class name or alias name.
@@ -68,23 +68,42 @@ def pipeline(task: str = None,

Examples:
```python
>>> # Using default model for a task
>>> p = pipeline('image-classification')
>>> p = pipeline('text-classification', model='distilbert-base-uncased')
>>> # Using model object
>>> # Using pipeline with a model name
>>> p = pipeline('text-classification', model='damo/distilbert-base-uncased')
>>> # Using pipeline with a model object
>>> resnet = Model.from_pretrained('Resnet')
>>> p = pipeline('image-classification', model=resnet)
>>> # Using pipeline with a list of model names
>>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/auto-tts2'])
"""
if task is None and pipeline_name is None:
raise ValueError('task or pipeline_name is required')

if pipeline_name is None:
# get default pipeline for this task
pipeline_name, default_model_repo = get_default_pipeline_info(task)
if isinstance(model, str) \
or (isinstance(model, list) and isinstance(model[0], str)):

# if is_model_name(model):
if (isinstance(model, str) and model.startswith('damo/')) \
or (isinstance(model, list) and model[0].startswith('damo/')) \
or (isinstance(model, str) and osp.exists(model)):
# TODO @wenmeng.zwm add support when model is a str of modelhub address
# read pipeline info from modelhub configuration file.
pipeline_name, default_model_repo = get_default_pipeline_info(
task)
else:
pipeline_name = get_pipeline_by_model_name(task, model)
else:
pipeline_name, default_model_repo = get_default_pipeline_info(task)

if model is None:
model = default_model_repo

assert isinstance(model, (type(None), str, Model)), \
f'model should be either None, str or Model, but got {type(model)}'
assert isinstance(model, (type(None), str, Model, list)), \
f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'

cfg = ConfigDict(type=pipeline_name, model=model)

@@ -134,3 +153,19 @@ def get_default_pipeline_info(task):
else:
pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
return pipeline_name, default_model


def get_pipeline_by_model_name(task: str, model: Union[str, List[str]]):
""" Get pipeline name by task name and model name

Args:
task (str): task name.
model (str| list[str]): model names
"""
if isinstance(model, str):
model_key = model
else:
model_key = '_'.join(model)
assert model_key in PIPELINES.modules[task], \
f'pipeline for task {task} model {model_key} not found.'
return model_key

+ 34
- 18
modelscope/pipelines/util.py View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
from typing import List, Union

import json
from maas_hub.file_download import model_file_download
@@ -8,23 +9,38 @@ from maas_hub.file_download import model_file_download
from modelscope.utils.constant import CONFIGFILE


def is_model_name(model):
if osp.exists(model):
if osp.exists(osp.join(model, CONFIGFILE)):
return True
def is_model_name(model: Union[str, List]):
""" whether model is a valid modelhub path
"""

def is_model_name_impl(model):
if osp.exists(model):
if osp.exists(osp.join(model, CONFIGFILE)):
return True
else:
return False
else:
return False
# try:
# cfg_file = model_file_download(model, CONFIGFILE)
# except Exception:
# cfg_file = None
# TODO @wenmeng.zwm use exception instead of
# following tricky logic
cfg_file = model_file_download(model, CONFIGFILE)
with open(cfg_file, 'r') as infile:
cfg = json.load(infile)
if 'Code' in cfg:
return False
else:
return True

if isinstance(model, str):
return is_model_name_impl(model)
else:
# try:
# cfg_file = model_file_download(model, CONFIGFILE)
# except Exception:
# cfg_file = None
# TODO @wenmeng.zwm use exception instead of
# following tricky logic
cfg_file = model_file_download(model, CONFIGFILE)
with open(cfg_file, 'r') as infile:
cfg = json.load(infile)
if 'Code' in cfg:
return False
else:
return True
results = [is_model_name_impl(m) for m in model]
all_true = all(results)
any_true = any(results)
if any_true and not all_true:
raise ValueError('some model are hub address, some are not')

return all_true

+ 68
- 0
tests/pipelines/test_builder.py View File

@@ -0,0 +1,68 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest
from asyncio import Task
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import PIL

from modelscope.models.base import Model
from modelscope.pipelines import Pipeline, pipeline
from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import default_group

logger = get_logger()


@PIPELINES.register_module(
group_key=Tasks.image_tagging, module_name='custom_single_model')
class CustomSingleModelPipeline(Pipeline):

def __init__(self,
config_file: str = None,
model: List[Union[str, Model]] = None,
preprocessor=None,
**kwargs):
super().__init__(config_file, model, preprocessor, **kwargs)
assert isinstance(model, str), 'model is not str'
print(model)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return super().postprocess(inputs)


@PIPELINES.register_module(
group_key=Tasks.image_tagging, module_name='model1_model2')
class CustomMultiModelPipeline(Pipeline):

def __init__(self,
config_file: str = None,
model: List[Union[str, Model]] = None,
preprocessor=None,
**kwargs):
super().__init__(config_file, model, preprocessor, **kwargs)
assert isinstance(model, list), 'model is not list'
for m in model:
assert isinstance(m, str), 'submodel is not str'
print(m)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return super().postprocess(inputs)


class PipelineInterfaceTest(unittest.TestCase):

def test_single_model(self):
pipe = pipeline(Tasks.image_tagging, model='custom_single_model')
assert isinstance(pipe, CustomSingleModelPipeline)

def test_multi_model(self):
pipe = pipeline(Tasks.image_tagging, model=['model1', 'model2'])
assert isinstance(pipe, CustomMultiModelPipeline)


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save