|
|
@@ -180,6 +180,16 @@ class EpochBasedTrainer(BaseTrainer): |
|
|
|
else: |
|
|
|
self._max_epochs = kwargs['max_epochs'] |
|
|
|
|
|
|
|
self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None) |
|
|
|
self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None) |
|
|
|
if self._train_iters_per_epoch is None and hasattr( |
|
|
|
self.cfg.train, 'train_iters_per_epoch'): |
|
|
|
self._train_iters_per_epoch = self.cfg.train.train_iters_per_epoch |
|
|
|
if self._eval_iters_per_epoch is None and hasattr( |
|
|
|
self.cfg, 'evaluation') and hasattr(self.cfg.evaluation, |
|
|
|
'val_iters_per_epoch'): |
|
|
|
self._eval_iters_per_epoch = self.cfg.evaluation.val_iters_per_epoch |
|
|
|
|
|
|
|
self.use_fp16 = kwargs.get('use_fp16', False) |
|
|
|
|
|
|
|
# TODO @wenmeng.zwm add seed init fn |
|
|
@@ -236,7 +246,32 @@ class EpochBasedTrainer(BaseTrainer): |
|
|
|
@property |
|
|
|
def max_iters(self): |
|
|
|
"""int: Maximum training iterations.""" |
|
|
|
return self._max_epochs * len(self.data_loader) |
|
|
|
return self._max_epochs * self.iters_per_epoch |
|
|
|
|
|
|
|
@property |
|
|
|
def iters_per_epoch(self): |
|
|
|
"""int: Total iterations of one epoch""" |
|
|
|
|
|
|
|
def _get_data_len(data_loader): |
|
|
|
try: |
|
|
|
return len(self.data_loader) |
|
|
|
except Exception as e: |
|
|
|
self.logger.error(e) |
|
|
|
raise ValueError( |
|
|
|
'Please implement ``__len__`` method for your dataset, ' |
|
|
|
'or add `train_iters_per_epoch` and `train_iters_per_epoch` ' |
|
|
|
'to your configuration file or kwargs') |
|
|
|
|
|
|
|
if self.mode == ModeKeys.TRAIN: |
|
|
|
if self._train_iters_per_epoch is not None: |
|
|
|
return self._train_iters_per_epoch |
|
|
|
else: |
|
|
|
return _get_data_len(self.data_loader) |
|
|
|
elif self.mode == ModeKeys.EVAL: |
|
|
|
if self._eval_iters_per_epoch is not None: |
|
|
|
return self._eval_iters_per_epoch |
|
|
|
else: |
|
|
|
return _get_data_len(self.data_loader) |
|
|
|
|
|
|
|
def to_task_dataset(self, |
|
|
|
datasets: Union[Dataset, List[Dataset]], |
|
|
|