|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import glob
- import os
- import os.path as osp
- from typing import Any
-
- from easycv.utils.ms_utils import EasyCVMeta
-
- from modelscope.hub.snapshot_download import snapshot_download
- from modelscope.pipelines.util import is_official_hub_path
- from modelscope.utils.config import Config
- from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
-
-
class EasyCVPipeline(object):
    """Base pipeline for EasyCV models.

    A ModelScope-style configuration file is loaded from the model directory,
    but prediction itself is delegated to an EasyCV predictor. This class
    adapts the ModelScope configuration into the form the EasyCV predictor
    API expects and forwards ``__call__`` to the built predictor.
    """

    def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs):
        """Locate the model files and build the EasyCV predictor.

        Args:
            model (str): Model id on the ModelScope hub, or a local model
                path. A path that exists locally is used as-is; otherwise it
                must be an official hub path and is fetched with
                ``snapshot_download`` at ``DEFAULT_MODEL_REVISION``.
            model_file_pattern (str): Glob pattern, relative to the model
                directory, that must match exactly one model file.
            *args: Accepted but not used by this base class.
            **kwargs: Accepted but not used by this base class.

        Raises:
            AssertionError: If ``model`` is not a string, is neither an
                existing local path nor an official hub path, does not
                resolve to a directory, if the pattern does not match exactly
                one file, or if the configuration file is missing.
        """
        self.model_file_pattern = model_file_pattern

        assert isinstance(model, str)
        model_dir = self._resolve_model_dir(model)

        self.model_path = self._find_model_file(model_dir)

        # The ModelScope configuration file is expected next to the weights.
        self.config_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
        assert os.path.exists(
            self.config_file
        ), f'Not find "{ModelFile.CONFIGURATION}" in model directory!'

        self.cfg = Config.from_file(self.config_file)
        self.predict_op = self._build_predict_op()

    def _resolve_model_dir(self, model: str) -> str:
        """Return a local directory for ``model``, downloading it if needed.

        Args:
            model (str): Existing local path, or an official hub model id.

        Returns:
            str: Path of the model directory.

        Raises:
            AssertionError: If ``model`` is not local and not an official hub
                path, or if the resolved path is not a directory.
        """
        if osp.exists(model):
            model_dir = model
        else:
            assert is_official_hub_path(
                model), 'Only support local model path and official hub path!'
            model_dir = snapshot_download(
                model_id=model, revision=DEFAULT_MODEL_REVISION)

        assert osp.isdir(model_dir)
        return model_dir

    def _find_model_file(self, model_dir: str) -> str:
        """Return the single file in ``model_dir`` matching the model pattern.

        Only ``self.model_file_pattern`` is interpreted as a glob pattern.
        ``model_dir`` is escaped first so that directory names containing
        glob metacharacters (``[``, ``]``, ``*``, ``?``) are matched
        literally instead of silently matching nothing or the wrong path.

        Args:
            model_dir (str): Directory to search.

        Returns:
            str: Path of the matched model file.

        Raises:
            AssertionError: If the pattern matches zero or several files.
        """
        search_pattern = os.path.join(
            glob.escape(model_dir), self.model_file_pattern)
        # Sorted so the error message below lists files deterministically.
        model_files = sorted(glob.glob(search_pattern))
        assert len(
            model_files
        ) == 1, f'Need one model file, but find {len(model_files)}: {model_files}'
        return model_files[0]

    def _build_predict_op(self):
        """Build the EasyCV predictor.

        ``self.cfg.pipeline.predictor_config`` is passed to EasyCV's
        ``build_predictor`` along with a dict holding the model path and the
        adapted EasyCV config (presumably used as default constructor
        arguments by ``build_predictor`` -- confirm against EasyCV).

        Returns:
            The predictor object returned by ``build_predictor``.
        """
        from easycv.predictors.builder import build_predictor

        easycv_config = self._to_easycv_config()
        pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, {
            'model_path': self.model_path,
            'config_file': easycv_config
        })
        return pipeline_op

    def _to_easycv_config(self):
        """Adapt the ModelScope configuration to an EasyCV predictor config.

        The resulting ``Config`` contains:

        * ``model``: ``self.cfg.model`` with the entries stored under
          ``EasyCVMeta.ARCH`` (if any) merged back in.
        * every top-level key listed in the EasyCV meta section's
          ``EasyCVMeta.RESERVED_KEYS`` entry, copied from ``self.cfg``.
        * ``test_pipeline``: unless it is a reserved key, taken from
          ``self.cfg.dataset.val.pipeline`` (empty list when absent).

        Note: the arch entry is popped off ``self.cfg.model`` and the model
        config is updated through the ``model_cfg`` reference, which
        presumably mutates ``self.cfg`` in place.

        Returns:
            Config: Configuration in the layout EasyCV predictors expect.
        """
        # TODO: refine config compatibility problems

        easycv_arch = self.cfg.model.pop(EasyCVMeta.ARCH, None)
        model_cfg = self.cfg.model
        # Revert to the configuration layout used by EasyCV.
        if easycv_arch is not None:
            model_cfg.update(easycv_arch)

        easycv_config = Config(dict(model=model_cfg))

        reserved_keys = []
        if hasattr(self.cfg, EasyCVMeta.META):
            easycv_meta_cfg = getattr(self.cfg, EasyCVMeta.META)
            reserved_keys = easycv_meta_cfg.get(EasyCVMeta.RESERVED_KEYS, [])
            for key in reserved_keys:
                easycv_config.merge_from_dict({key: getattr(self.cfg, key)})

        # Fall back to the validation dataset pipeline for inference.
        if 'test_pipeline' not in reserved_keys:
            easycv_config.merge_from_dict(
                {'test_pipeline': self.cfg.dataset.val.get('pipeline', [])})

        return easycv_config

    def __call__(self, inputs) -> Any:
        """Run the EasyCV predictor on ``inputs`` and return its output."""
        # TODO: support image url
        return self.predict_op(inputs)
|