You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from abc import ABC, abstractmethod
  3. from typing import Callable, Dict, List, Optional, Tuple, Union
  4. from maas_lib.trainers.builder import TRAINERS
  5. from maas_lib.utils.config import Config
  6. class BaseTrainer(ABC):
  7. """ Base class for trainer which can not be instantiated.
  8. BaseTrainer defines necessary interface
  9. and provide default implementation for basic initialization
  10. such as parsing config file and parsing commandline args.
  11. """
  12. def __init__(self, cfg_file: str, arg_parse_fn: Optional[Callable] = None):
  13. """ Trainer basic init, should be called in derived class
  14. Args:
  15. cfg_file: Path to configuration file.
  16. arg_parse_fn: Same as ``parse_fn`` in :obj:`Config.to_args`.
  17. """
  18. self.cfg = Config.from_file(cfg_file)
  19. if arg_parse_fn:
  20. self.args = self.cfg.to_args(arg_parse_fn)
  21. else:
  22. self.args = None
  23. @abstractmethod
  24. def train(self, *args, **kwargs):
  25. """ Train (and evaluate) process
  26. Train process should be implemented for specific task or
  27. model, releated paramters have been intialized in
  28. ``BaseTrainer.__init__`` and should be used in this function
  29. """
  30. pass
  31. @abstractmethod
  32. def evaluate(self, checkpoint_path: str, *args,
  33. **kwargs) -> Dict[str, float]:
  34. """ Evaluation process
  35. Evaluation process should be implemented for specific task or
  36. model, releated paramters have been intialized in
  37. ``BaseTrainer.__init__`` and should be used in this function
  38. """
  39. pass
  40. @TRAINERS.register_module(module_name='dummy')
  41. class DummyTrainer(BaseTrainer):
  42. def __init__(self, cfg_file: str, *args, **kwargs):
  43. """ Dummy Trainer.
  44. Args:
  45. cfg_file: Path to configuration file.
  46. """
  47. super().__init__(cfg_file)
  48. def train(self, *args, **kwargs):
  49. """ Train (and evaluate) process
  50. Train process should be implemented for specific task or
  51. model, releated paramters have been intialized in
  52. ``BaseTrainer.__init__`` and should be used in this function
  53. """
  54. cfg = self.cfg.train
  55. print(f'train cfg {cfg}')
  56. def evaluate(self,
  57. checkpoint_path: str = None,
  58. *args,
  59. **kwargs) -> Dict[str, float]:
  60. """ Evaluation process
  61. Evaluation process should be implemented for specific task or
  62. model, releated paramters have been intialized in
  63. ``BaseTrainer.__init__`` and should be used in this function
  64. """
  65. cfg = self.cfg.evaluation
  66. print(f'eval cfg {cfg}')
  67. print(f'checkpoint_path {checkpoint_path}')

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展