You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 920 B

1234567891011121314151617181920212223242526272829
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from abc import ABC, abstractmethod
  3. from typing import Dict, List, Tuple, Union
  4. Tensor = Union['torch.Tensor', 'tf.Tensor']
  5. class Model(ABC):
  6. def __init__(self, *args, **kwargs):
  7. pass
  8. def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
  9. return self.post_process(self.forward(input))
  10. @abstractmethod
  11. def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
  12. pass
  13. def post_process(self, input: Dict[str, Tensor],
  14. **kwargs) -> Dict[str, Tensor]:
  15. # model specific postprocess, implementation is optional
  16. # will be called in Pipeline and evaluation loop(in the future)
  17. return input
  18. @classmethod
  19. def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
  20. raise NotImplementedError('from_pretrained has not been implemented')

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展。