- # Copyright (c) Alibaba, Inc. and its affiliates.
-
- from abc import ABC, abstractmethod
- from typing import Dict, List, Tuple, Union
-
- Tensor = Union['torch.Tensor', 'tf.Tensor']
-
-
- class Model(ABC):
-
- def __init__(self, *args, **kwargs):
- pass
-
- def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
- return self.post_process(self.forward(input))
-
- @abstractmethod
- def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
- pass
-
- def post_process(self, input: Dict[str, Tensor],
- **kwargs) -> Dict[str, Tensor]:
- # model specific postprocess, implementation is optional
- # will be called in Pipeline and evaluation loop(in the future)
- return input
-
- @classmethod
- def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
- raise NotImplementedError('from_pretrained has not been implemented')
|