You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from abc import ABC, abstractmethod
  4. from typing import Any, Dict, Generator, List, Tuple, Union
  5. from ali_maas_datasets import PyDataset
  6. from maas_hub.snapshot_download import snapshot_download
  7. from maas_lib.models import Model
  8. from maas_lib.pipelines import util
  9. from maas_lib.preprocessors import Preprocessor
  10. from maas_lib.utils.config import Config
  11. from .util import is_model_name
  12. Tensor = Union['torch.Tensor', 'tf.Tensor']
  13. Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
  14. output_keys = [
  15. ] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key
  16. class Pipeline(ABC):
  17. def __init__(self,
  18. config_file: str = None,
  19. model: Union[Model, str] = None,
  20. preprocessor: Preprocessor = None,
  21. **kwargs):
  22. """ Base class for pipeline.
  23. If config_file is provided, model and preprocessor will be
  24. instantiated from corresponding config. Otherwise, model
  25. and preprocessor will be constructed separately.
  26. Args:
  27. config_file(str, optional): Filepath to configuration file.
  28. model: Model name or model object
  29. preprocessor: Preprocessor object
  30. """
  31. if config_file is not None:
  32. self.cfg = Config.from_file(config_file)
  33. if isinstance(model, str):
  34. if not osp.exists(model):
  35. cache_path = util.get_model_cache_dir(model)
  36. if osp.exists(cache_path):
  37. model = cache_path
  38. else:
  39. model = snapshot_download(model)
  40. if is_model_name(model):
  41. self.model = Model.from_pretrained(model)
  42. else:
  43. self.model = model
  44. elif isinstance(model, Model):
  45. self.model = model
  46. else:
  47. if model:
  48. raise ValueError(
  49. f'model type is either str or Model, but got type {type(model)}'
  50. )
  51. self.preprocessor = preprocessor
  52. def __call__(self, input: Union[Input, List[Input]], *args,
  53. **post_kwargs) -> Union[Dict[str, Any], Generator]:
  54. # model provider should leave it as it is
  55. # maas library developer will handle this function
  56. # simple showcase, need to support iterator type for both tensorflow and pytorch
  57. # input_dict = self._handle_input(input)
  58. if isinstance(input, list):
  59. output = []
  60. for ele in input:
  61. output.append(self._process_single(ele, *args, **post_kwargs))
  62. elif isinstance(input, PyDataset):
  63. return self._process_iterator(input, *args, **post_kwargs)
  64. else:
  65. output = self._process_single(input, *args, **post_kwargs)
  66. return output
  67. def _process_iterator(self, input: Input, *args, **post_kwargs):
  68. for ele in input:
  69. yield self._process_single(ele, *args, **post_kwargs)
  70. def _process_single(self, input: Input, *args,
  71. **post_kwargs) -> Dict[str, Any]:
  72. out = self.preprocess(input)
  73. out = self.forward(out)
  74. out = self.postprocess(out, **post_kwargs)
  75. return out
  76. def preprocess(self, inputs: Input) -> Dict[str, Any]:
  77. """ Provide default implementation based on preprocess_cfg and user can reimplement it
  78. """
  79. assert self.preprocessor is not None, 'preprocess method should be implemented'
  80. return self.preprocessor(inputs)
  81. def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  82. """ Provide default implementation using self.model and user can reimplement it
  83. """
  84. assert self.model is not None, 'forward method should be implemented'
  85. return self.model(inputs)
  86. @abstractmethod
  87. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  88. raise NotImplementedError('postprocess')

致力于通过开放的社区合作，开源 AI 模型以及相关创新技术，推动基于“模型即服务”的生态繁荣发展。