|
- # Copyright (c) Alibaba, Inc. and its affiliates.
-
- import os.path as osp
- from abc import ABC, abstractmethod
- from typing import Any, Dict, Generator, List, Tuple, Union
-
- from ali_maas_datasets import PyDataset
- from maas_hub.snapshot_download import snapshot_download
-
- from maas_lib.models import Model
- from maas_lib.pipelines import util
- from maas_lib.preprocessors import Preprocessor
- from maas_lib.utils.config import Config
- from .util import is_model_name
-
Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']

# Standardized output keys for each task's pipeline. They define the contract
# consumed by postprocess and are also used to normalize the keys of the
# dict that postprocess produces.
output_keys = [
]
-
-
class Pipeline(ABC):
    """Abstract base class for pipelines.

    A pipeline chains three stages: ``preprocess`` -> ``forward`` ->
    ``postprocess``. Subclasses must implement :meth:`postprocess`;
    :meth:`preprocess` and :meth:`forward` have default implementations
    that delegate to ``self.preprocessor`` and ``self.model``.
    """

    def __init__(self,
                 config_file: str = None,
                 model: Union[Model, str] = None,
                 preprocessor: Preprocessor = None,
                 **kwargs):
        """ Base class for pipeline.

        If config_file is provided, model and preprocessor will be
        instantiated from corresponding config. Otherwise, model
        and preprocessor will be constructed separately.

        Args:
            config_file(str, optional): Filepath to configuration file.
            model: Model name or model object
            preprocessor: Preprocessor object

        Raises:
            ValueError: If ``model`` is neither None, a str, nor a Model.
        """
        if config_file is not None:
            self.cfg = Config.from_file(config_file)

        if isinstance(model, str):
            # Resolve a model name to a local directory: prefer an existing
            # local path, then the local cache, then a hub snapshot download.
            if not osp.exists(model):
                cache_path = util.get_model_cache_dir(model)
                if osp.exists(cache_path):
                    model = cache_path
                else:
                    model = snapshot_download(model)

            if is_model_name(model):
                self.model = Model.from_pretrained(model)
            else:
                # Keep the resolved path; loading is deferred to the subclass.
                self.model = model
        elif isinstance(model, Model):
            self.model = model
        elif model is not None:
            # Reject anything that is not None, str, or Model. The previous
            # truthiness check (`if model:`) let falsy values slip through.
            raise ValueError(
                f'model type is either str or Model, but got type {type(model)}'
            )
        else:
            # BUGFIX: always define the attribute so that forward() fails
            # with its intended assertion message rather than an
            # AttributeError on a missing `self.model`.
            self.model = None
        self.preprocessor = preprocessor

    def __call__(self, input: Union[Input, List[Input]], *args,
                 **post_kwargs) -> Union[Dict[str, Any], Generator]:
        """Run the pipeline on a single input, a list of inputs, or a dataset.

        A list is processed eagerly and returns a list of result dicts; a
        PyDataset is processed lazily and returns a generator; any other
        input returns a single result dict.
        """
        # model provider should leave it as it is
        # maas library developer will handle this function

        # simple showcase, need to support iterator type for both tensorflow and pytorch
        # input_dict = self._handle_input(input)
        if isinstance(input, list):
            return [
                self._process_single(ele, *args, **post_kwargs)
                for ele in input
            ]
        if isinstance(input, PyDataset):
            # Lazy evaluation: yield one processed element at a time.
            return self._process_iterator(input, *args, **post_kwargs)
        return self._process_single(input, *args, **post_kwargs)

    def _process_iterator(self, input: Input, *args, **post_kwargs):
        """Yield the pipeline result for each element of an iterable input."""
        for ele in input:
            yield self._process_single(ele, *args, **post_kwargs)

    def _process_single(self, input: Input, *args,
                        **post_kwargs) -> Dict[str, Any]:
        """Run one input through preprocess -> forward -> postprocess."""
        out = self.preprocess(input)
        out = self.forward(out)
        out = self.postprocess(out, **post_kwargs)
        return out

    def preprocess(self, inputs: Input) -> Dict[str, Any]:
        """ Provide default implementation based on preprocess_cfg and user can reimplement it

        """
        assert self.preprocessor is not None, 'preprocess method should be implemented'
        return self.preprocessor(inputs)

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """ Provide default implementation using self.model and user can reimplement it
        """
        assert self.model is not None, 'forward method should be implemented'
        return self.model(inputs)

    @abstractmethod
    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Convert raw model outputs to the task's standardized output dict.

        Must be implemented by every concrete pipeline.
        """
        raise NotImplementedError('postprocess')
|