You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

from maas_lib.models import Model
from maas_lib.preprocessors import Preprocessor
  6. Tensor = Union['torch.Tensor', 'tf.Tensor']
  7. Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray']
  8. output_keys = [
  9. ] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key
  10. class Pipeline(ABC):
  11. def __init__(self,
  12. config_file: str = None,
  13. model: Model = None,
  14. preprocessor: Preprocessor = None,
  15. **kwargs):
  16. self.model = model
  17. self.preprocessor = preprocessor
  18. def __call__(self, input: Union[Input, List[Input]], *args,
  19. **post_kwargs) -> Dict[str, Any]:
  20. # moodel provider should leave it as it is
  21. # maas library developer will handle this function
  22. # simple show case, need to support iterator type for both tensorflow and pytorch
  23. # input_dict = self._handle_input(input)
  24. if isinstance(input, list):
  25. output = []
  26. for ele in input:
  27. output.append(self._process_single(ele, *args, **post_kwargs))
  28. else:
  29. output = self._process_single(input, *args, **post_kwargs)
  30. return output
  31. def _process_single(self, input: Input, *args,
  32. **post_kwargs) -> Dict[str, Any]:
  33. out = self.preprocess(input)
  34. out = self.forward(out)
  35. out = self.postprocess(out, **post_kwargs)
  36. return out
  37. def preprocess(self, inputs: Input) -> Dict[str, Any]:
  38. """ Provide default implementation based on preprocess_cfg and user can reimplement it
  39. """
  40. assert self.preprocessor is not None, 'preprocess method should be implemented'
  41. return self.preprocessor(inputs)
  42. def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  43. """ Provide default implementation using self.model and user can reimplement it
  44. """
  45. assert self.model is not None, 'forward method should be implemented'
  46. return self.model(inputs)
  47. @abstractmethod
  48. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  49. raise NotImplementedError('postprocess')

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展