
base.py 12 kB

# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from abc import ABC, abstractmethod
from threading import Lock
from typing import Any, Dict, Generator, List, Mapping, Union

import numpy as np

from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
from modelscope.outputs import TASK_OUTPUTS
from modelscope.preprocessors import Preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import Frameworks, ModelFile
from modelscope.utils.device import (create_device, device_placement,
                                     verify_device)
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.logger import get_logger
from .util import is_model, is_official_hub_path

if is_torch_available():
    import torch

if is_tf_available():
    import tensorflow as tf

Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[str, tuple, MsDataset, 'Image.Image', 'numpy.ndarray']
InputModel = Union[str, Model]

logger = get_logger()
class Pipeline(ABC):

    def initiate_single_model(self, model):
        if isinstance(model, str):
            logger.info(f'initiate model from {model}')
        if isinstance(model, str) and is_official_hub_path(model):
            logger.info(f'initiate model from location {model}.')
            # expecting model has been prefetched to local cache beforehand
            return Model.from_pretrained(
                model, model_prefetched=True,
                device=self.device_name) if is_model(model) else model
        elif isinstance(model, Model):
            return model
        else:
            if model and not isinstance(model, str):
                raise ValueError(
                    f'model type for single model is either str or Model, but got type {type(model)}'
                )
            return model

    def initiate_multiple_models(self, input_models: List[InputModel]):
        models = []
        for model in input_models:
            models.append(self.initiate_single_model(model))
        return models
    def __init__(self,
                 config_file: str = None,
                 model: Union[InputModel, List[InputModel]] = None,
                 preprocessor: Union[Preprocessor, List[Preprocessor]] = None,
                 device: str = 'gpu',
                 auto_collate=True,
                 **kwargs):
        """ Base class for pipeline.

        If config_file is provided, the model and preprocessor will be
        instantiated from the corresponding config. Otherwise, the model
        and preprocessor will be constructed separately.

        Args:
            config_file (str, optional): Filepath to the configuration file.
            model: (list of) Model name or model object.
            preprocessor: (list of) Preprocessor object.
            device (str): device string, should be either cpu, cuda, gpu, gpu:X or cuda:X.
            auto_collate (bool): whether to automatically convert data to tensors.
        """
        if config_file is not None:
            self.cfg = Config.from_file(config_file)
        verify_device(device)
        self.device_name = device

        if not isinstance(model, List):
            self.model = self.initiate_single_model(model)
            self.models = [self.model]
        else:
            self.model = None
            self.models = self.initiate_multiple_models(model)

        self.has_multiple_models = len(self.models) > 1
        self.preprocessor = preprocessor

        if self.model or (self.has_multiple_models and self.models[0]):
            self.framework = self._get_framework()
        else:
            self.framework = None

        if self.framework == Frameworks.torch:
            self.device = create_device(self.device_name)
        self._model_prepare = False
        self._model_prepare_lock = Lock()
        self._auto_collate = auto_collate
    def prepare_model(self):
        """ Place the model on the proper device before the first inference
        (pytorch models only).
        """
        self._model_prepare_lock.acquire(timeout=600)

        def _prepare_single(model):
            if isinstance(model, torch.nn.Module):
                model.to(self.device)
                model.eval()
            elif hasattr(model, 'model') and isinstance(
                    model.model, torch.nn.Module):
                model.model.to(self.device)
                model.model.eval()

        if not self._model_prepare:
            # prepare model for pytorch
            if self.framework == Frameworks.torch:
                if self.has_multiple_models:
                    for m in self.models:
                        _prepare_single(m)
                else:
                    _prepare_single(self.model)
            self._model_prepare = True
        self._model_prepare_lock.release()
    def _get_framework(self) -> str:
        frameworks = []
        for m in self.models:
            if isinstance(m, Model):
                model_dir = m.model_dir
            else:
                assert isinstance(m,
                                  str), 'model should be either str or Model.'
                model_dir = m
            cfg_file = osp.join(model_dir, ModelFile.CONFIGURATION)
            cfg = Config.from_file(cfg_file)
            frameworks.append(cfg.framework)
        if not all(x == frameworks[0] for x in frameworks):
            raise ValueError(
                f'got multiple models, but they are in different frameworks {frameworks}'
            )

        return frameworks[0]
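
    # For reference (illustrative, abbreviated sketch; key names other than
    # `framework` are assumptions): _get_framework reads the `framework`
    # field from each model directory's configuration.json, e.g.
    #
    #     {
    #         "framework": "pytorch",
    #         "task": "text-classification",
    #         ...
    #     }
    #
    # and requires all models in the pipeline to agree on that value.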
    def __call__(self, input: Union[Input, List[Input]], *args,
                 **kwargs) -> Union[Dict[str, Any], Generator]:
        # model provider should leave it as it is
        # modelscope library developer will handle this function

        # place model to cpu or gpu
        if (self.model or (self.has_multiple_models and self.models[0])):
            if not self._model_prepare:
                self.prepare_model()

        # simple showcase, need to support iterator type for both tensorflow and pytorch
        # input_dict = self._handle_input(input)

        # sanitize the parameters
        preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
            **kwargs)
        kwargs['preprocess_params'] = preprocess_params
        kwargs['forward_params'] = forward_params
        kwargs['postprocess_params'] = postprocess_params

        if isinstance(input, list):
            output = []
            for ele in input:
                output.append(self._process_single(ele, *args, **kwargs))
        elif isinstance(input, MsDataset):
            return self._process_iterator(input, *args, **kwargs)
        else:
            output = self._process_single(input, *args, **kwargs)
        return output
    def _sanitize_parameters(self, **pipeline_parameters):
        """
        This method should sanitize the keyword args into preprocess params,
        forward params and postprocess params for the '__call__' or
        '_process_single' method. It is a normal instance method with a
        default implementation / output.

        Default Returns:
            Dict[str, str]: preprocess_params = {}
            Dict[str, str]: forward_params = {}
            Dict[str, str]: postprocess_params = pipeline_parameters
        """
        return {}, {}, pipeline_parameters
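
    # A minimal sketch (hypothetical subclass method, not part of this file)
    # of how a concrete pipeline would typically override _sanitize_parameters
    # to route keyword arguments passed to __call__ into the three phases;
    # the `max_length` and `top_k` parameter names are made up for
    # illustration:
    #
    #     def _sanitize_parameters(self, **pipeline_parameters):
    #         preprocess_params = {}
    #         if 'max_length' in pipeline_parameters:
    #             preprocess_params['max_length'] = \
    #                 pipeline_parameters.pop('max_length')
    #         # everything left over is forwarded to postprocess()
    #         return preprocess_params, {}, pipeline_parameters
    #
    # With this override, pipe(text, max_length=128, top_k=5) sends
    # max_length to preprocess() and top_k to postprocess().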
    def _process_iterator(self, input: Input, *args, **kwargs):
        for ele in input:
            yield self._process_single(ele, *args, **kwargs)
    def _collate_fn(self, data):
        """Prepare the input just before the forward function.
        This method will move the tensors to the right device.
        Usually this method does not need to be overridden.

        Args:
            data: The data out of the dataloader.

        Returns: The processed data.
        """
        from torch.utils.data.dataloader import default_collate
        from modelscope.preprocessors import InputFeatures
        if isinstance(data, dict) or isinstance(data, Mapping):
            return type(data)(
                {k: self._collate_fn(v)
                 for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            if isinstance(data[0], (int, float)):
                return default_collate(data).to(self.device)
            else:
                return type(data)(self._collate_fn(v) for v in data)
        elif isinstance(data, np.ndarray):
            if data.dtype.type is np.str_:
                return data
            else:
                return self._collate_fn(torch.from_numpy(data))
        elif isinstance(data, torch.Tensor):
            return data.to(self.device)
        elif isinstance(data, (bytes, str, int, float, bool, type(None))):
            return data
        elif isinstance(data, InputFeatures):
            return data
        else:
            import mmcv
            if isinstance(data, mmcv.parallel.data_container.DataContainer):
                return data
            else:
                raise ValueError(f'Unsupported data type {type(data)}')
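
    # Illustrative example of what _collate_fn does (assuming a torch device
    # has been created for this pipeline): nested containers are walked
    # recursively, tensor-like leaves are moved to self.device, and plain
    # Python values such as strings pass through untouched.
    #
    #     batch = {'input_ids': np.array([[1, 2, 3]]), 'text': 'hello'}
    #     out = self._collate_fn(batch)
    #     # out['input_ids'] is now a torch.Tensor on self.device,
    #     # out['text'] is still the str 'hello'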
    def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
        preprocess_params = kwargs.get('preprocess_params', {})
        forward_params = kwargs.get('forward_params', {})
        postprocess_params = kwargs.get('postprocess_params', {})

        out = self.preprocess(input, **preprocess_params)
        with device_placement(self.framework, self.device_name):
            if self.framework == Frameworks.torch:
                with torch.no_grad():
                    if self._auto_collate:
                        out = self._collate_fn(out)
                    out = self.forward(out, **forward_params)
            else:
                out = self.forward(out, **forward_params)

        out = self.postprocess(out, **postprocess_params)
        self._check_output(out)
        return out
    def _check_output(self, input):
        # this attribute is dynamically attached by the registry
        # when cls is registered in the registry using the task name
        task_name = self.group_key
        if task_name not in TASK_OUTPUTS:
            logger.warning(f'task {task_name} output keys are missing')
            return
        output_keys = TASK_OUTPUTS[task_name]
        missing_keys = []
        for k in output_keys:
            if k not in input:
                missing_keys.append(k)
        if len(missing_keys) > 0:
            raise ValueError(f'expected output keys are {output_keys}, '
                             f'those {missing_keys} are missing')
    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
        """ Provide a default implementation based on the configured
        preprocessor; users can reimplement it.
        """
        assert self.preprocessor is not None, 'preprocess method should be implemented'
        assert not isinstance(self.preprocessor, List), \
            'default implementation does not support using multiple preprocessors.'
        return self.preprocessor(inputs, **preprocess_params)

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """ Provide a default implementation using self.model; users can
        reimplement it.
        """
        assert self.model is not None, 'forward method should be implemented'
        assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
        return self.model(inputs, **forward_params)
    @abstractmethod
    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """ If the current pipeline supports model reuse, common postprocess
        code should be written here.

        Args:
            inputs: input data

        Returns:
            dict of results: a dict containing outputs of the model, each
            output should have the standard output name.
        """
        raise NotImplementedError('postprocess')
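
# A minimal usage sketch (hypothetical subclass and model id, shown only for
# illustration, not part of this file): a concrete pipeline implements
# postprocess() and relies on the default preprocess()/forward() above.
#
#     class MyTaskPipeline(Pipeline):
#
#         def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
#             # map raw model outputs to the standard TASK_OUTPUTS keys
#             return {'scores': inputs['logits'].softmax(-1).tolist()}
#
#     pipe = MyTaskPipeline(model='damo/some-model', device='gpu')
#     result = pipe('a single input')  # dispatches to _process_single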