You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

builder.py 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Union
  3. from maas_lib.models.base import Model
  4. from maas_lib.utils.config import ConfigDict
  5. from maas_lib.utils.constant import Tasks
  6. from maas_lib.utils.registry import Registry, build_from_cfg
  7. from .base import Pipeline
  8. PIPELINES = Registry('pipelines')
  9. def build_pipeline(cfg: ConfigDict,
  10. task_name: str = None,
  11. default_args: dict = None):
  12. """ build pipeline given model config dict.
  13. Args:
  14. cfg (:obj:`ConfigDict`): config dict for model object.
  15. task_name (str, optional): task name, refer to
  16. :obj:`Tasks` for more details.
  17. default_args (dict, optional): Default initialization arguments.
  18. """
  19. return build_from_cfg(
  20. cfg, PIPELINES, group_key=task_name, default_args=default_args)
  21. def pipeline(task: str = None,
  22. model: Union[str, Model] = None,
  23. config_file: str = None,
  24. pipeline_name: str = None,
  25. framework: str = None,
  26. device: int = -1,
  27. **kwargs) -> Pipeline:
  28. """ Factory method to build a obj:`Pipeline`.
  29. Args:
  30. task (str): Task name defining which pipeline will be returned.
  31. model (str or obj:`Model`): model name or model object.
  32. config_file (str, optional): path to config file.
  33. pipeline_name (str, optional): pipeline class name or alias name.
  34. framework (str, optional): framework type.
  35. device (int, optional): which device is used to do inference.
  36. Return:
  37. pipeline (obj:`Pipeline`): pipeline object for certain task.
  38. Examples:
  39. ```python
  40. >>> p = pipeline('image-classification')
  41. >>> p = pipeline('text-classification', model='distilbert-base-uncased')
  42. >>> # Using model object
  43. >>> resnet = Model.from_pretrained('Resnet')
  44. >>> p = pipeline('image-classification', model=resnet)
  45. """
  46. if task is not None and model is None and pipeline_name is None:
  47. # get default pipeline for this task
  48. assert task in PIPELINES.modules, f'No pipeline is registerd for Task {task}'
  49. pipeline_name = list(PIPELINES.modules[task].keys())[0]
  50. if pipeline_name is not None:
  51. cfg = dict(type=pipeline_name, **kwargs)
  52. return build_pipeline(cfg, task_name=task)

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展