You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from abc import ABC, abstractmethod
  4. from typing import Dict, List, Tuple, Union
  5. from maas_hub.file_download import model_file_download
  6. from maas_hub.snapshot_download import snapshot_download
  7. from maas_lib.models.builder import build_model
  8. from maas_lib.pipelines import util
  9. from maas_lib.utils.config import Config
  10. from maas_lib.utils.constant import CONFIGFILE
  11. Tensor = Union['torch.Tensor', 'tf.Tensor']
  12. class Model(ABC):
  13. def __init__(self, model_dir, *args, **kwargs):
  14. self.model_dir = model_dir
  15. def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
  16. return self.post_process(self.forward(input))
  17. @abstractmethod
  18. def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
  19. pass
  20. def post_process(self, input: Dict[str, Tensor],
  21. **kwargs) -> Dict[str, Tensor]:
  22. # model specific postprocess, implementation is optional
  23. # will be called in Pipeline and evaluation loop(in the future)
  24. return input
  25. @classmethod
  26. def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
  27. """ Instantiate a model from local directory or remote model repo
  28. """
  29. if osp.exists(model_name_or_path):
  30. local_model_dir = model_name_or_path
  31. else:
  32. cache_path = util.get_model_cache_dir(model_name_or_path)
  33. local_model_dir = cache_path if osp.exists(
  34. cache_path) else snapshot_download(model_name_or_path)
  35. # else:
  36. # raise ValueError(
  37. # 'Remote model repo {model_name_or_path} does not exists')
  38. cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE))
  39. task_name = cfg.task
  40. model_cfg = cfg.model
  41. # TODO @wenmeng.zwm may should mannually initialize model after model building
  42. if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
  43. model_cfg.type = model_cfg.model_type
  44. model_cfg.model_dir = local_model_dir
  45. return build_model(model_cfg, task_name)

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展