
[to #43850241] adapt to torch IterableDataset

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9761129

    * adapt to torch IterableDataset
master · jiangnana.jnn, 3 years ago · commit 4819fd569d
3 changed files with 38 additions and 3 deletions
  1. modelscope/trainers/hooks/hook.py (+1, -1)
  2. modelscope/trainers/hooks/logger/text_logger_hook.py (+1, -1)
  3. modelscope/trainers/trainer.py (+36, -1)
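
For context on the motivation: a torch DataLoader built on an IterableDataset that does not implement __len__ cannot report a length, so every len(trainer.data_loader) call in the trainer breaks. A minimal, self-contained sketch of the failure (StreamingDataset is a hypothetical example, not part of this repo):

import torch
from torch.utils.data import DataLoader, IterableDataset

class StreamingDataset(IterableDataset):
    # A stream-style dataset with no __len__, e.g. reading from a socket or file.
    def __iter__(self):
        for i in range(10):
            yield torch.tensor([i])

loader = DataLoader(StreamingDataset(), batch_size=2)

try:
    len(loader)  # DataLoader delegates to len(dataset), which is undefined here
except TypeError as err:
    print(f'len(data_loader) fails: {err}')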

modelscope/trainers/hooks/hook.py (+1, -1)

@@ -192,7 +192,7 @@ class Hook:
         Whether to reach the end of every epoch
         Returns: bool
         """
-        return trainer.inner_iter + 1 == len(trainer.data_loader)
+        return trainer.inner_iter + 1 == trainer.iters_per_epoch

     def is_last_epoch(self, trainer):
         """


modelscope/trainers/hooks/logger/text_logger_hook.py (+1, -1)

@@ -93,7 +93,7 @@ class TextLoggerHook(LoggerHook):
             lr_str = f'{lr_key}: {log_dict[lr_key]:.3e}'

         if self.by_epoch:
-            log_str = f'{epoch_key} [{log_dict[epoch_key]}][{log_dict[iter_key]}/{len(trainer.data_loader)}]\t'
+            log_str = f'{epoch_key} [{log_dict[epoch_key]}][{log_dict[iter_key]}/{trainer.iters_per_epoch}]\t'
         else:
             log_str = f'{iter_key} [{log_dict[iter_key]}/{trainer.max_iters}]\t'
         log_str += f'{lr_str}, '
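
The same substitution keeps the [iter/total] progress fraction printable when the loader has no length. A rough sketch of the line the logger produces (all values illustrative):

log_dict = {'epoch': 2, 'iter': 150, 'lr': 1e-04}
iters_per_epoch = 500  # now read from trainer.iters_per_epoch, not len(trainer.data_loader)

log_str = f"epoch [{log_dict['epoch']}][{log_dict['iter']}/{iters_per_epoch}]\t"
log_str += f"lr: {log_dict['lr']:.3e}, "
print(log_str)  # epoch [2][150/500]    lr: 1.000e-04,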


modelscope/trainers/trainer.py (+36, -1)

@@ -180,6 +180,16 @@ class EpochBasedTrainer(BaseTrainer):
         else:
             self._max_epochs = kwargs['max_epochs']

+        self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None)
+        self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None)
+        if self._train_iters_per_epoch is None and hasattr(
+                self.cfg.train, 'train_iters_per_epoch'):
+            self._train_iters_per_epoch = self.cfg.train.train_iters_per_epoch
+        if self._eval_iters_per_epoch is None and hasattr(
+                self.cfg, 'evaluation') and hasattr(self.cfg.evaluation,
+                                                    'val_iters_per_epoch'):
+            self._eval_iters_per_epoch = self.cfg.evaluation.val_iters_per_epoch
+
         self.use_fp16 = kwargs.get('use_fp16', False)

         # TODO @wenmeng.zwm add seed init fn
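
The iteration counts can come from two places, with kwargs taking precedence over the configuration file. A minimal sketch of that resolution order (resolve_train_iters and the SimpleNamespace config are illustrative, not the trainer's real API):

from types import SimpleNamespace

def resolve_train_iters(kwargs, cfg):
    # kwargs win; fall back to cfg.train.train_iters_per_epoch when present
    value = kwargs.get('train_iters_per_epoch')
    if value is None and hasattr(cfg.train, 'train_iters_per_epoch'):
        value = cfg.train.train_iters_per_epoch
    return value

cfg = SimpleNamespace(train=SimpleNamespace(train_iters_per_epoch=500))
print(resolve_train_iters({}, cfg))                              # 500
print(resolve_train_iters({'train_iters_per_epoch': 100}, cfg))  # 100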
@@ -236,7 +246,32 @@ class EpochBasedTrainer(BaseTrainer):
     @property
     def max_iters(self):
         """int: Maximum training iterations."""
-        return self._max_epochs * len(self.data_loader)
+        return self._max_epochs * self.iters_per_epoch
+
+    @property
+    def iters_per_epoch(self):
+        """int: Total iterations of one epoch"""
+
+        def _get_data_len(data_loader):
+            try:
+                return len(data_loader)
+            except Exception as e:
+                self.logger.error(e)
+                raise ValueError(
+                    'Please implement ``__len__`` method for your dataset, '
+                    'or add `train_iters_per_epoch` and `val_iters_per_epoch` '
+                    'to your configuration file or kwargs')
+
+        if self.mode == ModeKeys.TRAIN:
+            if self._train_iters_per_epoch is not None:
+                return self._train_iters_per_epoch
+            else:
+                return _get_data_len(self.data_loader)
+        elif self.mode == ModeKeys.EVAL:
+            if self._eval_iters_per_epoch is not None:
+                return self._eval_iters_per_epoch
+            else:
+                return _get_data_len(self.data_loader)

     def to_task_dataset(self,
                         datasets: Union[Dataset, List[Dataset]],
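
Putting it together: when the loader has no __len__, the configured count is the only source of truth, and max_iters derives from it. A hedged, self-contained sketch mirroring the diff's logic (TinyTrainer is a hypothetical reduction of EpochBasedTrainer, training mode only):

class TinyTrainer:
    # Hypothetical reduction of EpochBasedTrainer; no modes, no hooks.
    def __init__(self, data_loader, max_epochs, train_iters_per_epoch=None):
        self.data_loader = data_loader
        self._max_epochs = max_epochs
        self._train_iters_per_epoch = train_iters_per_epoch

    @property
    def iters_per_epoch(self):
        if self._train_iters_per_epoch is not None:
            return self._train_iters_per_epoch
        return len(self.data_loader)  # raises TypeError for iterable-style loaders

    @property
    def max_iters(self):
        return self._max_epochs * self.iters_per_epoch

trainer = TinyTrainer(data_loader=None, max_epochs=3, train_iters_per_epoch=1000)
print(trainer.max_iters)  # 3000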

