# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os.path as osp
from typing import Optional, Union

from maas_hub.file_download import model_file_download

from maas_lib.models.base import Model
from maas_lib.utils.config import Config, ConfigDict
from maas_lib.utils.constant import CONFIGFILE, Tasks
from maas_lib.utils.registry import Registry, build_from_cfg
from .base import Pipeline
from .util import is_model_name

# Registry of pipeline classes; entries are grouped by task name
# (see `group_key=task_name` in `build_pipeline` and
# `PIPELINES.modules[task]` in `get_default_pipeline`).
PIPELINES = Registry('pipelines')


def build_pipeline(cfg: ConfigDict,
                   task_name: Optional[str] = None,
                   default_args: Optional[dict] = None):
    """Build a pipeline object from a config dict.

    Args:
        cfg (:obj:`ConfigDict`): config dict for the pipeline; its ``type``
            field names the registered pipeline class or alias.
        task_name (str, optional): task name used as the registry group key,
            refer to :obj:`Tasks` for more details.
        default_args (dict, optional): default initialization arguments.

    Returns:
        The object produced by ``build_from_cfg`` for ``cfg``.
    """
    return build_from_cfg(
        cfg, PIPELINES, group_key=task_name, default_args=default_args)


def pipeline(task: Optional[str] = None,
             model: Optional[Union[str, Model]] = None,
             preprocessor=None,
             config_file: Optional[str] = None,
             pipeline_name: Optional[str] = None,
             framework: Optional[str] = None,
             device: int = -1,
             **kwargs) -> Pipeline:
    """Factory method to build a :obj:`Pipeline`.

    Args:
        task (str, optional): task name defining which pipeline will be
            returned. Required when ``pipeline_name`` is not given.
        model (str or :obj:`Model`, optional): model name or model object.
        preprocessor (optional): preprocessor object.
        config_file (str, optional): path to config file. Currently not
            used by this function.
        pipeline_name (str, optional): pipeline class name or alias name.
            When omitted, the default pipeline registered for ``task`` is
            used.
        framework (str, optional): framework type. Currently not used by
            this function.
        device (int, optional): which device is used to do inference.
            Currently not used by this function.
        **kwargs: currently ignored.

    Returns:
        pipeline (:obj:`Pipeline`): pipeline object for certain task.

    Raises:
        ValueError: if neither ``task`` nor ``pipeline_name`` is given.
        AssertionError: if ``pipeline_name`` is omitted and no pipeline is
            registered for ``task``, or if ``model`` is neither a ``str``
            nor a :obj:`Model`.

    Examples:
        ```python
        >>> p = pipeline('image-classification')
        >>> p = pipeline('text-classification', model='distilbert-base-uncased')
        >>> # Using model object
        >>> resnet = Model.from_pretrained('Resnet')
        >>> p = pipeline('image-classification', model=resnet)
        ```
    """
    if task is None and pipeline_name is None:
        raise ValueError('task or pipeline_name is required')

    if pipeline_name is None:
        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; AssertionError is kept for backward compatibility.
        if task not in PIPELINES.modules:
            raise AssertionError(
                f'No pipeline is registered for Task {task}')
        pipeline_name = get_default_pipeline(task)

    cfg = ConfigDict(type=pipeline_name)

    if model:
        # Explicit raise for the same reason as above (survives `-O`).
        if not isinstance(model, (str, Model)):
            raise AssertionError(
                f'model should be either str or Model, but got {type(model)}')
        cfg.model = model

    if preprocessor is not None:
        cfg.preprocessor = preprocessor

    return build_pipeline(cfg, task_name=task)


def get_default_pipeline(task):
    """Return the default pipeline name registered for ``task``.

    The default is the first key of ``PIPELINES.modules[task]``. This is
    presumably the first pipeline registered for the task, assuming the
    registry keeps an insertion-ordered mapping. TODO: confirm against
    ``Registry``.

    Raises:
        IndexError: if no pipeline is registered under ``task``.
    """
    registered = PIPELINES.modules[task]
    # Only the first key is needed; avoid materializing the whole key list.
    for name in registered:
        return name
    raise IndexError(f'No pipeline is registered for Task {task}')