|
- # Copyright (c) Alibaba, Inc. and its affiliates.
-
- from abc import ABC, abstractmethod
- from typing import Callable, Dict, List, Optional, Tuple, Union
-
- from maas_lib.trainers.builder import TRAINERS
- from maas_lib.utils.config import Config
-
-
class BaseTrainer(ABC):
    """Abstract base class for trainers; it cannot be instantiated directly.

    BaseTrainer defines the interface every trainer must implement
    (``train`` and ``evaluate``) and provides a default implementation of
    basic initialization: loading the configuration file and, optionally,
    converting the configuration into parsed args.
    """

    def __init__(self, cfg_file: str, arg_parse_fn: Optional[Callable] = None):
        """Basic trainer initialization; derived classes should call it.

        Args:
            cfg_file: Path to configuration file.
            arg_parse_fn: Same as ``parse_fn`` in :obj:`Config.to_args`.
                When it is ``None``, ``self.args`` is set to ``None``.
        """
        self.cfg = Config.from_file(cfg_file)
        # Compare against None explicitly: a callable object may be falsy
        # (e.g. it defines ``__len__`` or ``__bool__``) and must still be
        # honored when the caller supplies it.
        if arg_parse_fn is not None:
            self.args = self.cfg.to_args(arg_parse_fn)
        else:
            self.args = None

    @abstractmethod
    def train(self, *args, **kwargs):
        """Run the training (and evaluation) process.

        The training process must be implemented for the specific task or
        model. Related parameters have been initialized in
        ``BaseTrainer.__init__`` (``self.cfg`` and ``self.args``) and
        should be used in this function.
        """

    @abstractmethod
    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        """Run the evaluation process.

        The evaluation process must be implemented for the specific task
        or model. Related parameters have been initialized in
        ``BaseTrainer.__init__`` (``self.cfg`` and ``self.args``) and
        should be used in this function.

        Args:
            checkpoint_path: Path of the checkpoint to evaluate.

        Returns:
            A ``Dict[str, float]``; presumably metric name -> metric
            value (the exact keys are up to the implementation).
        """
-
-
@TRAINERS.register_module(module_name='dummy')
class DummyTrainer(BaseTrainer):
    """Placeholder trainer registered under the module name ``'dummy'``.

    It performs no real training or evaluation; both entry points only
    print the corresponding section of the loaded configuration.
    """

    def __init__(self, cfg_file: str, *args, **kwargs):
        """Initialize the dummy trainer.

        Args:
            cfg_file: Path to configuration file.
            *args: Accepted for signature compatibility and ignored.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        # No arg-parse function is forwarded, so ``self.args`` stays None.
        super().__init__(cfg_file)

    def train(self, *args, **kwargs):
        """Print the ``train`` section of the configuration.

        No training is performed; extra positional and keyword arguments
        are ignored.
        """
        cfg = self.cfg.train
        print(f'train cfg {cfg}')

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Print the ``evaluation`` config section and the checkpoint path.

        No evaluation is performed; extra positional and keyword
        arguments are ignored.

        Args:
            checkpoint_path: Checkpoint path to report; may be ``None``.

        Returns:
            An empty dict, since the dummy trainer computes no metrics.
        """
        cfg = self.cfg.evaluation
        print(f'eval cfg {cfg}')
        print(f'checkpoint_path {checkpoint_path}')
        # Honor the declared ``Dict[str, float]`` return type instead of
        # implicitly returning None.
        return {}
|