You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from abc import ABC, abstractmethod
  4. from typing import Dict, List, Tuple, Union
  5. from maas_hub.file_download import model_file_download
  6. from maas_hub.snapshot_download import snapshot_download
  7. from maas_lib.models.builder import build_model
  8. from maas_lib.utils.config import Config
  9. from maas_lib.utils.constant import CONFIGFILE
  10. Tensor = Union['torch.Tensor', 'tf.Tensor']
  11. class Model(ABC):
  12. def __init__(self, model_dir, *args, **kwargs):
  13. self.model_dir = model_dir
  14. def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
  15. return self.post_process(self.forward(input))
  16. @abstractmethod
  17. def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
  18. pass
  19. def post_process(self, input: Dict[str, Tensor],
  20. **kwargs) -> Dict[str, Tensor]:
  21. # model specific postprocess, implementation is optional
  22. # will be called in Pipeline and evaluation loop(in the future)
  23. return input
  24. @classmethod
  25. def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
  26. """ Instantiate a model from local directory or remote model repo
  27. """
  28. if osp.exists(model_name_or_path):
  29. local_model_dir = model_name_or_path
  30. else:
  31. local_model_dir = snapshot_download(model_name_or_path)
  32. # else:
  33. # raise ValueError(
  34. # 'Remote model repo {model_name_or_path} does not exists')
  35. cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE))
  36. task_name = cfg.task
  37. model_cfg = cfg.model
  38. # TODO @wenmeng.zwm may should mannually initialize model after model building
  39. if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
  40. model_cfg.type = model_cfg.model_type
  41. model_cfg.model_dir = local_model_dir
  42. return build_model(model_cfg, task_name)

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展