1. Fix a bug in the trainer's progress bar
2. Fix a bug where the trainer did not support datasets declared in the config file
3. Add feature: support resuming training from a checkpoint file
4. Add feature: support a fixed filename when saving the best checkpoint
5. Fix a bug where id2label was missing from the config file after fine-tuning NLP models
6. Fix some other bugs

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10138906
Branch: master
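For context, the resume feature (item 3) is driven by a new `checkpoint_path` argument on `EpochBasedTrainer.train`, shown in the diff below. A minimal usage sketch, assuming a trainer built the usual way (model id and paths are placeholders):

```python
from modelscope.trainers import build_trainer

# Placeholders: any model id / work_dir that builds an EpochBasedTrainer.
kwargs = dict(
    model='damo/nlp_structbert_sentence-similarity_chinese-base',
    work_dir='./work_dir')
trainer = build_trainer(default_args=kwargs)

# New in this change: pass a checkpoint file to resume training.
# CheckpointHook restores the model/optimizer/lr_scheduler states, the
# epoch and iteration counters, the RNG states and per-hook states.
trainer.train(checkpoint_path='./work_dir/epoch_3.pth')
```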
@@ -14,9 +14,9 @@ from .builder import METRICS, MetricKeys
@METRICS.register_module(
    group_key=default_group, module_name=Metrics.seq_cls_metric)
class SequenceClassificationMetric(Metric):
    """The metric computation class for sequence classification classes.
    """The metric computation class for sequence classification tasks.
    This metric class calculates accuracy for the whole input batches.
    This metric class calculates accuracy of the whole input batches.
    """

    def __init__(self, *args, **kwargs):
@@ -1,14 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import json
import numpy as np
import torch
from modelscope import __version__
from modelscope.metainfo import Hooks
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
from modelscope.utils.constant import LogKeys, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import is_master
from modelscope.utils.torch_utils import get_dist_info, is_master
from .builder import HOOKS
from .hook import Hook
from .priority import Priority
@@ -25,6 +27,7 @@ class CheckpointHook(Hook):
        save_optimizer (bool): Whether to save optimizer state dict. Default: True.
        save_dir (str): The directory to save checkpoints. If it is None, use `trainer.work_dir`.
        save_last (bool): Whether to save the last checkpoint. Default: True.
        checkpoint_file (str): A checkpoint file to load and resume training from. Default: None.
    """
    PRIORITY = Priority.LOW
@@ -34,12 +37,16 @@ class CheckpointHook(Hook):
                 by_epoch=True,
                 save_optimizer=True,
                 save_dir=None,
                 save_last=True):
                 save_last=True,
                 checkpoint_file=None):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.save_dir = save_dir
        self.checkpoint_file = checkpoint_file
        self.save_last = save_last
        self.rng_state = None
        self.need_load_rng_state = False
    def before_run(self, trainer):
        if not self.save_dir:
@@ -56,6 +63,34 @@ class CheckpointHook(Hook):
        if is_master():
            self.logger.info(f'Checkpoints will be saved to {self.save_dir}')
        if self.checkpoint_file is not None and os.path.isfile(
                self.checkpoint_file):
            meta = self.load_checkpoint(self.checkpoint_file, trainer)
            self.rng_state = meta.get('rng_state')
            self.need_load_rng_state = True

    def before_train_epoch(self, trainer):
        if self.need_load_rng_state:
            if self.rng_state is not None:
                random.setstate(self.rng_state['random'])
                np.random.set_state(self.rng_state['numpy'])
                torch.random.set_rng_state(self.rng_state['cpu'])
                if torch.cuda.is_available():
                    torch.cuda.random.set_rng_state_all(self.rng_state['cuda'])
                self.need_load_rng_state = False
            else:
                self.logger.warn(
                    'Random state not found in the checkpoint file; '
                    'the data order and model initialization may not be reproducible.')
        self.rng_state = {
            'random': random.getstate(),
            'numpy': np.random.get_state(),
            'cpu': torch.random.get_rng_state(),
            'cuda': torch.cuda.get_rng_state_all(),
        }

    def after_train_epoch(self, trainer):
        if not self.by_epoch:
            return
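The RNG bookkeeping above is the standard save/restore pattern; a standalone sketch of the same round trip (independent of the trainer, for illustration only):

```python
import random
import numpy as np
import torch

# Capture RNG states (mirrors what CheckpointHook stores in `rng_state`).
state = {
    'random': random.getstate(),
    'numpy': np.random.get_state(),
    'cpu': torch.random.get_rng_state(),
    'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [],
}

# ... training is interrupted here and resumed later ...

# Restore the states so data shuffling and dropout replay identically.
random.setstate(state['random'])
np.random.set_state(state['numpy'])
torch.random.set_rng_state(state['cpu'])
if torch.cuda.is_available():
    torch.cuda.random.set_rng_state_all(state['cuda'])
```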
@@ -66,6 +101,39 @@ class CheckpointHook(Hook):
                f'Saving checkpoint at {trainer.epoch + 1} epoch')
        self._save_checkpoint(trainer)

    @classmethod
    def load_checkpoint(cls, filename, trainer):
        from modelscope.trainers.parallel.utils import is_parallel
        if is_parallel(trainer.model):
            model = trainer.model.module
        else:
            model = trainer.model
        meta = load_checkpoint(filename, model, trainer.optimizer,
                               trainer.lr_scheduler)
        trainer._epoch = meta.get('epoch', trainer._epoch)
        trainer._iter = meta.get('iter', trainer._iter)
        trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter)
        for i, hook in enumerate(trainer.hooks):
            # hook: Hook
            key = f'{hook.__class__}-{i}'
            if key in meta:
                hook.load_state_dict(meta[key])
            else:
                trainer.logger.warning(
                    f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.'
                )
        version = meta.get('modelscope')
        if version != __version__:
            trainer.logger.warning(
                f'The modelscope version of the loaded checkpoint does not match the runtime version: '
                f'saved version {version}, runtime version {__version__}.')
        trainer.logger.info(
            f'Checkpoint {filename} saving time: {meta.get("time")}')
        return meta

    def _save_checkpoint(self, trainer):
        if self.by_epoch:
            cur_save_name = os.path.join(
@@ -74,7 +142,21 @@ class CheckpointHook(Hook):
            cur_save_name = os.path.join(
                self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')
        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
        meta = {
            'epoch': trainer.epoch,
            'iter': trainer.iter + 1,
            'inner_iter': trainer.inner_iter + 1,
            'rng_state': self.rng_state,
        }
        for i, hook in enumerate(trainer.hooks):
            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
        save_checkpoint(
            trainer.model,
            cur_save_name,
            trainer.optimizer,
            trainer.lr_scheduler,
            meta=meta)
        if (self.is_last_epoch(trainer)
                and self.by_epoch) or (self.is_last_iter(trainer)
                                       and not self.by_epoch):
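With the meta dict above, the checkpoint written by `save_checkpoint` (see the utility change further down) carries enough state to resume exactly. A sketch of the resulting file layout, with illustrative values:

```python
# Sketch of a checkpoint written with with_meta=True (values illustrative;
# hook-state keys are rendered from f'{hook.__class__}-{i}'):
checkpoint = {
    'meta': {
        'epoch': 2,
        'iter': 1200,
        'inner_iter': 0,
        'rng_state': {...},  # 'random'/'numpy'/'cpu'/'cuda' states
        "<class '...CheckpointHook'>-0": {},
        "<class '...BestCkptSaverHook'>-1": {'best_metric': 0.93},
        'modelscope': '0.4.x',  # library version (illustrative)
        'time': 'Mon Oct 10 00:00:00 2022',
    },
    'state_dict': {...},    # model weights, moved to CPU
    'optimizer': {...},     # optimizer.state_dict()
    'lr_scheduler': {...},  # lr_scheduler.state_dict()
}
```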
@@ -144,6 +226,7 @@ class BestCkptSaverHook(CheckpointHook):
                 by_epoch=True,
                 save_optimizer=True,
                 save_dir=None,
                 save_file_name=None,
                 interval=0):
        assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
        super().__init__(
@@ -179,16 +262,44 @@ class BestCkptSaverHook(CheckpointHook):
        return False

    def _save_checkpoint(self, trainer):
        if self.by_epoch:
            cur_save_name = os.path.join(
                self.save_dir,
                f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth'
            )
        else:
            cur_save_name = os.path.join(
                self.save_dir,
                f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
            )
        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
        cur_save_name = self.save_file_name
        if cur_save_name is None:
            if self.by_epoch:
                cur_save_name = os.path.join(
                    self.save_dir,
                    f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth'
                )
            else:
                cur_save_name = os.path.join(
                    self.save_dir,
                    f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
                )
        meta = {
            'epoch': trainer.epoch,
            'iter': trainer.iter + 1,
            'inner_iter': trainer.inner_iter + 1,
            'rng_state': self.rng_state,
        }
        for i, hook in enumerate(trainer.hooks):
            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
        if os.path.isfile(cur_save_name):
            os.remove(cur_save_name)
        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer,
                        trainer.lr_scheduler, meta)
        self._best_ckpt_file = cur_save_name
        self._save_pretrained(trainer)

    def state_dict(self):
        return {
            'best_metric': self._best_metric,
        }

    def load_state_dict(self, state_dict):
        if state_dict is not None and len(state_dict) > 0:
            self._best_metric = state_dict.get('best_metric')
        else:
            self.logger.warn(
                'The state_dict is not available; the best metric value cannot be restored and may be inaccurate.')
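This is the fixed-filename feature from the change list: when the new `save_file_name` argument is set, the best checkpoint always overwrites the same file instead of accumulating `best_epochN_...` files. A hypothetical configuration sketch (the hook config keys are assumed from the usual ModelScope hook registration; the path is a placeholder):

```python
# Hypothetical cfg.train.hooks entry; 'save_file_name' is the new argument
# and is used verbatim as the target path of the best checkpoint.
cfg.train.hooks.append({
    'type': 'BestCkptSaverHook',
    'metric_key': 'accuracy',
    'rule': 'max',
    'save_file_name': './work_dir/best_model.pth',
})
```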
@@ -215,3 +215,9 @@ class Hook:
            trigger_stages.add(stage)
        return [stage for stage in Hook.stages if stage in trigger_stages]

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass
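These no-op defaults let every hook participate in checkpointing: a hook only has to override both methods. A minimal sketch of a hypothetical hook whose state survives a resume (`CounterHook` is not part of this PR):

```python
from modelscope.trainers.hooks.hook import Hook

class CounterHook(Hook):
    """Hypothetical hook whose state survives a training resume."""

    def __init__(self):
        self.counter = 0

    def after_train_iter(self, trainer):
        self.counter += 1

    def state_dict(self):
        # Stored in the checkpoint meta under the f'{hook.__class__}-{i}' key.
        return {'counter': self.counter}

    def load_state_dict(self, state_dict):
        self.counter = state_dict.get('counter', 0)
```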
@@ -4,6 +4,7 @@ import logging
from torch.nn.utils import clip_grad
from modelscope.metainfo import Hooks
from modelscope.outputs import OutputKeys
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.hook import Hook
from modelscope.trainers.hooks.priority import Priority
@@ -27,7 +28,7 @@ class OptimizerHook(Hook):
    def __init__(self,
                 cumulative_iters=1,
                 grad_clip=None,
                 loss_keys='loss') -> None:
                 loss_keys=OutputKeys.LOSS) -> None:
        if isinstance(loss_keys, str):
            loss_keys = [loss_keys]
        assert isinstance(loss_keys, (tuple, list))
@@ -28,10 +28,10 @@ class BaseWarmup(_LRScheduler):
        return self.base_scheduler.get_lr()

    def state_dict(self):
        self.base_scheduler.state_dict()
        return self.base_scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.base_scheduler.load_state_dict(state_dict)
        return self.base_scheduler.load_state_dict(state_dict)

    def scale(self):
        """Scale the learning rates.
@@ -1,6 +1,7 @@
import os
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
@@ -11,9 +12,10 @@ from modelscope.metrics.builder import build_metric
from modelscope.models.base import Model, TorchModel
from modelscope.msdatasets import MsDataset
from modelscope.preprocessors import Preprocessor, build_preprocessor
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.config import Config
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModeKeys,
                                       ModelFile, Tasks)
from modelscope.utils.hub import parse_label_mapping
from .base import TRAINERS
from .trainer import EpochBasedTrainer
@@ -81,19 +83,32 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
            assert cfg_file is not None, 'Config file should not be None if model is an nn.Module class'
            model_dir = os.path.dirname(cfg_file)
        self.label2id = None
        self.id2label = None
        self.num_labels = None
        self.cfg_modify_fn = cfg_modify_fn
        self.cfg = self.rebuild_config(Config.from_file(cfg_file))
        try:
            labels = self.cfg.dataset.train.labels
        except AttributeError:
            labels = None
        self.label2id = None
        self.num_labels = None
        if labels is not None and len(labels) > 0:
            self.label2id = {label: idx for idx, label in enumerate(labels)}
            self.id2label = {idx: label for idx, label in enumerate(labels)}
            self.num_labels = len(labels)
        label2id = parse_label_mapping(model_dir)
        if label2id is not None:
            self.label2id = label2id
            self.id2label = {id: label for label, id in label2id.items()}
            self.num_labels = len(label2id)
        else:
            try:
                labels = self.cfg.dataset.train.labels
                if labels is not None and len(labels) > 0:
                    self.label2id = {
                        label: idx
                        for idx, label in enumerate(labels)
                    }
                    self.id2label = {
                        idx: label
                        for idx, label in enumerate(labels)
                    }
                    self.num_labels = len(labels)
            except AttributeError:
                pass

        def build_dataset_keys(cfg):
            if cfg is not None:
@@ -130,7 +145,13 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
    def rebuild_config(self, cfg: Config):
        if self.cfg_modify_fn is not None:
            return self.cfg_modify_fn(cfg)
            cfg = self.cfg_modify_fn(cfg)
        if not hasattr(cfg.model, 'label2id') and not hasattr(
                cfg.model, 'id2label'):
            if self.id2label is not None:
                cfg.model['id2label'] = self.id2label
            if self.label2id is not None:
                cfg.model['label2id'] = self.label2id
        return cfg

    def build_model(self) -> Union[nn.Module, TorchModel]:
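The net effect of the two changes above (preferring `parse_label_mapping(model_dir)` over the dataset labels, and writing the mappings back in `rebuild_config`) is that the configuration saved with a fine-tuned NLP model keeps its label mapping. A sketch of reading it back (path and label names are illustrative):

```python
from modelscope.utils.config import Config

# Hypothetical path to the configuration dumped next to a fine-tuned model.
cfg = Config.from_file('./work_dir/output/configuration.json')

# Fixed by this change: id2label/label2id now survive fine-tuning, so
# inference pipelines can map class indices back to label names.
print(cfg.model.id2label)   # e.g. {0: 'not_similar', 1: 'similar'}
print(cfg.model.label2id)   # e.g. {'not_similar': 0, 'similar': 1}
```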
@@ -203,6 +224,9 @@ class VecoTrainer(NlpEpochBasedTrainer):
        """
        from modelscope.msdatasets.task_datasets import VecoDataset
        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
            from modelscope.trainers.hooks import CheckpointHook
            CheckpointHook.load_checkpoint(checkpoint_path, self)
        self.model.eval()
        self._mode = ModeKeys.EVAL
        metric_values = {}
@@ -223,12 +247,10 @@ class VecoTrainer(NlpEpochBasedTrainer):
                self.eval_dataset, **self.cfg.evaluation.get('dataloader', {}))
            self.data_loader = self.eval_dataloader
            metric_classes = [
                build_metric(metric, default_args={'trainer': self})
                for metric in self.metrics
            ]
            self.evaluation_loop(self.eval_dataloader, checkpoint_path,
                                 metric_classes)
            metric_classes = [build_metric(metric) for metric in self.metrics]
            for m in metric_classes:
                m.trainer = self
            self.evaluation_loop(self.eval_dataloader, metric_classes)
            for m_idx, metric_cls in enumerate(metric_classes):
                if f'eval_dataset[{idx}]' not in metric_values:
@@ -242,4 +264,8 @@ class VecoTrainer(NlpEpochBasedTrainer):
            else:
                break

        for metric_name in self.metrics:
            metric_values[metric_name] = np.average(
                [m[metric_name] for m in metric_values.values()])
        return metric_values
@@ -1,6 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import time
from collections.abc import Mapping
from distutils.version import LooseVersion
@@ -8,7 +7,6 @@ from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import json
import numpy as np
import torch
from torch import distributed as dist
from torch import nn
@@ -425,8 +423,16 @@ class EpochBasedTrainer(BaseTrainer):
            metrics = [metrics]
        return metrics

    def train(self, *args, **kwargs):
        self.model.train()
    def set_checkpoint_file_to_hook(self, checkpoint_path):
        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
            from modelscope.trainers.hooks import CheckpointHook
            checkpoint_hooks = list(
                filter(lambda hook: isinstance(hook, CheckpointHook),
                       self.hooks))
            for hook in checkpoint_hooks:
                hook.checkpoint_file = checkpoint_path

    def train(self, checkpoint_path=None, *args, **kwargs):
        self._mode = ModeKeys.TRAIN
        if self.train_dataset is None:
@@ -442,13 +448,17 @@ class EpochBasedTrainer(BaseTrainer):
        self.register_optimizers_hook()
        self.register_hook_from_cfg(self.cfg.train.hooks)
        self.set_checkpoint_file_to_hook(checkpoint_path)
        self.model.train()
        self.train_loop(self.train_dataloader)

    def evaluate(self, checkpoint_path=None):
        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
            from modelscope.trainers.hooks import CheckpointHook
            CheckpointHook.load_checkpoint(checkpoint_path, self)
        self.model.eval()
        self._mode = ModeKeys.EVAL
        if self.eval_dataset is None:
            self.eval_dataloader = self.get_eval_data_loader()
        else:
@@ -462,8 +472,9 @@ class EpochBasedTrainer(BaseTrainer):
        metric_classes = [build_metric(metric) for metric in self.metrics]
        for m in metric_classes:
            m.trainer = self
        metric_values = self.evaluation_loop(self.eval_dataloader,
                                             checkpoint_path, metric_classes)
                                             metric_classes)
        self._metric_values = metric_values
        return metric_values
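Evaluation gets the same checkpoint entry point; a short sketch (the path is a placeholder):

```python
# Restores the checkpoint into the trainer (via CheckpointHook.load_checkpoint)
# and evaluates it; returns a dict of metric values.
eval_results = trainer.evaluate(checkpoint_path='./work_dir/epoch_10.pth')
print(eval_results)  # e.g. {'accuracy': 0.91}
```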
@@ -631,18 +642,13 @@ class EpochBasedTrainer(BaseTrainer):
        if hasattr(data_cfg, 'name'):
            dataset = MsDataset.load(
                dataset_name=data_cfg.name,
                split=data_cfg.split,
                subset_name=data_cfg.subset_name if hasattr(
                    data_cfg, 'subset_name') else None,
                hub=data_cfg.hub
                if hasattr(data_cfg, 'hub') else Hubs.modelscope,
                **data_cfg,
            )
            cfg = ConfigDict(type=self.cfg.model.type, mode=mode)
            torch_dataset = dataset.to_torch_dataset(
                task_data_config=cfg,
                task_name=self.cfg.task,
                preprocessors=self.preprocessor)
                preprocessors=preprocessor)
        else:
            torch_dataset = build_task_dataset(data_cfg, self.cfg.task)
        dataset = self.to_task_dataset(torch_dataset, mode)
@@ -802,19 +808,22 @@ class EpochBasedTrainer(BaseTrainer):
        """ Training loop used by `EpochBasedTrainer.train()`
        """
        self.invoke_hook(TrainerStages.before_run)
        self._epoch = 0
        kwargs = {}
        self.model.train()
        for _ in range(self._epoch, self._max_epochs):
            self.invoke_hook(TrainerStages.before_train_epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            for i, data_batch in enumerate(data_loader):
                if i < self.inner_iter:
                    # inner_iter may be restored from the checkpoint file, so skip the iterations already trained in this epoch.
                    continue
                data_batch = to_device(data_batch, self.device)
                self.data_batch = data_batch
                self._inner_iter = i
                self.invoke_hook(TrainerStages.before_train_iter)
                self.train_step(self.model, data_batch, **kwargs)
                self.invoke_hook(TrainerStages.after_train_iter)
                # These values change after the hooks are invoked; do not move them above the invoke_hook calls.
                del self.data_batch
                self._iter += 1
                self._mode = ModeKeys.TRAIN
@@ -823,12 +832,14 @@ class EpochBasedTrainer(BaseTrainer):
                break
            self.invoke_hook(TrainerStages.after_train_epoch)
            # These values change after the hooks are invoked; do not move them above the invoke_hook calls.
            self._inner_iter = 0
            self._epoch += 1
            time.sleep(1)  # wait for some hooks like loggers to finish
        self.invoke_hook(TrainerStages.after_run)

    def evaluation_loop(self, data_loader, checkpoint_path, metric_classes):
    def evaluation_loop(self, data_loader, metric_classes):
        """ Evaluation loop used by `EpochBasedTrainer.evaluate()`.
        """
@@ -841,7 +852,7 @@ class EpochBasedTrainer(BaseTrainer):
                tmpdir=None,
                gpu_collect=False,
                metric_classes=metric_classes,
                data_loader_iters_per_gpu=self.iters_per_epoch)
                data_loader_iters_per_gpu=self._eval_iters_per_epoch)
        else:
            from modelscope.trainers.utils.inference import single_gpu_test
            metric_values = single_gpu_test(
@@ -849,7 +860,7 @@ class EpochBasedTrainer(BaseTrainer):
                data_loader,
                device=self.device,
                metric_classes=metric_classes,
                data_loader_iters=self.iters_per_epoch)
                data_loader_iters=self._eval_iters_per_epoch)

        self._inner_iter = self.iters_per_epoch - 1  # start from index 0
@@ -8,14 +8,17 @@ from shutil import copytree, ignore_patterns, rmtree
from typing import Callable, List, Optional, Union
import json
import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from modelscope import __version__
from modelscope.fileio import File, LocalStorage
from modelscope.utils.config import JSONIteratorEncoder
from modelscope.utils.constant import ConfigFields, ModelFile
from modelscope.utils.logger import get_logger

logger = get_logger(__name__)
storage = LocalStorage()
@@ -40,24 +43,27 @@ def weights_to_cpu(state_dict):
def save_checkpoint(model: torch.nn.Module,
                    filename: str,
                    optimizer: Optional[Optimizer] = None,
                    lr_scheduler: Optional[_LRScheduler] = None,
                    meta: Optional[dict] = None,
                    with_meta: bool = True) -> None:
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.
    ``optimizer``. By default, ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        lr_scheduler (:obj:`_LRScheduler`, optional): LRScheduler to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
        with_meta (bool, optional): Whether to save the meta (together with the optimizer and lr_scheduler states) into the checkpoint. Default: True.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(modescope=__version__, time=time.asctime())
    meta.update(modelscope=__version__, time=time.asctime())

    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
@@ -71,22 +77,69 @@ def save_checkpoint(model: torch.nn.Module,
            'meta': meta,
            'state_dict': weights_to_cpu(model.state_dict())
        }
        # save optimizer state dict in the checkpoint
        if isinstance(optimizer, Optimizer):
            checkpoint['optimizer'] = optimizer.state_dict()
        elif isinstance(optimizer, dict):
            checkpoint['optimizer'] = {}
            for name, optim in optimizer.items():
                checkpoint['optimizer'][name] = optim.state_dict()
        # save lr_scheduler state dict in the checkpoint
        assert isinstance(lr_scheduler, _LRScheduler), \
            f'The lr_scheduler to be saved should be a subclass of _LRScheduler, but got: {lr_scheduler.__class__}'
        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
    else:
        checkpoint = weights_to_cpu(model.state_dict())
        # save optimizer state dict in the checkpoint
        if isinstance(optimizer, Optimizer):
            checkpoint['optimizer'] = optimizer.state_dict()
        elif isinstance(optimizer, dict):
            checkpoint['optimizer'] = {}
            for name, optim in optimizer.items():
                checkpoint['optimizer'][name] = optim.state_dict()

    with io.BytesIO() as f:
        torch.save(checkpoint, f)
        File.write(f.getvalue(), filename)
def load_checkpoint(filename,
                    model,
                    optimizer: Optimizer = None,
                    lr_scheduler: _LRScheduler = None):
    if not os.path.exists(filename):
        raise ValueError(f'Checkpoint file {filename} does not exist!')
    checkpoint = torch.load(filename, map_location='cpu')

    if optimizer is not None:
        if 'optimizer' in checkpoint:
            if isinstance(optimizer, Optimizer):
                optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(optimizer, dict):
                optimizer_dict = checkpoint['optimizer']
                for key, optimizer_ins in optimizer.items():
                    if key in optimizer_dict:
                        optimizer_ins.load_state_dict(optimizer_dict[key])
                    else:
                        logger.warn(
                            f'The state dict of optimizer {key} cannot be found in checkpoint file: {filename}')
        else:
            logger.warn(
                f'The state dict of optimizer cannot be found in checkpoint file: {filename}')
    if lr_scheduler is not None:
        if 'lr_scheduler' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        else:
            logger.warn(
                f'The state dict of lr_scheduler cannot be found in checkpoint file: {filename}')
    state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[
        'state_dict']
    model.load_state_dict(state_dict)
    # Always return a dict so callers can safely call meta.get(...).
    return checkpoint.get('meta', {})
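A minimal round trip through the two utilities, assuming a toy model (all objects and paths below are illustrative):

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint

# Toy objects; any nn.Module / Optimizer / _LRScheduler subclass works.
model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
lr_scheduler = StepLR(optimizer, step_size=1)

# Saves weights plus optimizer/lr_scheduler states and the given meta
# (save_checkpoint also stamps the modelscope version and save time).
save_checkpoint(model, '/tmp/toy.pth', optimizer, lr_scheduler,
                meta={'epoch': 1, 'iter': 10, 'inner_iter': 0})

# Restores everything in place and returns the meta dict.
meta = load_checkpoint('/tmp/toy.pth', model, optimizer, lr_scheduler)
print(meta['epoch'], meta['modelscope'])
```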
def save_pretrained(model,
                    target_folder: Union[str, os.PathLike],
                    save_checkpoint_name: str = None,
@@ -299,19 +299,23 @@ class MsRegressTool(RegressTool):
                         file_name,
                         level='config',
                         compare_fn=None,
                         ignore_keys=None):
                         ignore_keys=None,
                         compare_random=True,
                         lazy_stop_callback=None):

        def lazy_stop_callback():
            from modelscope.trainers.hooks.hook import Hook, Priority

            class EarlyStopHook(Hook):
                PRIORITY = Priority.VERY_LOW

                def after_iter(self, trainer):
                    raise MsRegressTool.EarlyStopError('Test finished.')

            trainer.register_hook(EarlyStopHook())

        if lazy_stop_callback is None:

            def lazy_stop_callback():
                from modelscope.trainers.hooks.hook import Hook, Priority

                class EarlyStopHook(Hook):
                    PRIORITY = Priority.VERY_LOW

                    def after_iter(self, trainer):
                        raise MsRegressTool.EarlyStopError('Test finished.')

                trainer.register_hook(EarlyStopHook())

        def _train_loop(trainer, *args, **kwargs):
            with self.monitor_module_train(
@@ -320,6 +324,7 @@ class MsRegressTool(RegressTool):
                    level,
                    compare_fn=compare_fn,
                    ignore_keys=ignore_keys,
                    compare_random=compare_random,
                    lazy_stop_callback=lazy_stop_callback):
                try:
                    return trainer.train_loop_origin(*args, **kwargs)
@@ -1,8 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
from collections.abc import Mapping
import numpy as np

def torch_nested_numpify(tensors):
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2df2a5f3cdfc6dded52d31a8e97d9a9c41a803cb6d46dee709c51872eda37b21
size 151830
@@ -11,7 +11,8 @@ from modelscope.models.nlp.sequence_classification import \
    SbertForSequenceClassification
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.hub import read_config
from modelscope.utils.test_utils import test_level
@@ -119,6 +120,90 @@ class TestTrainerWithNlp(unittest.TestCase):
            checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth'))
        self.assertTrue(Metrics.accuracy in eval_results)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_configured_datasets(self):
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
        cfg: Config = read_config(model_id)
        cfg.train.max_epochs = 20
        cfg.train.work_dir = self.tmp_dir
        cfg.dataset = {
            'train': {
                'name': 'afqmc_small',
                'split': 'train',
                'namespace': 'userxiaoming'
            },
            'val': {
                'name': 'afqmc_small',
                'split': 'train',
                'namespace': 'userxiaoming'
            },
        }
        cfg_file = os.path.join(self.tmp_dir, 'config.json')
        cfg.dump(cfg_file)
        kwargs = dict(model=model_id, cfg_file=cfg_file)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(cfg.train.max_epochs):
            self.assertIn(f'epoch_{i+1}.pth', results_files)
        eval_results = trainer.evaluate(
            checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth'))
        self.assertTrue(Metrics.accuracy in eval_results)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_trainer_with_continue_train(self):
        from modelscope.utils.regress_test_utils import MsRegressTool
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
        cfg: Config = read_config(model_id)
        cfg.train.max_epochs = 3
        cfg.train.work_dir = self.tmp_dir
        cfg_file = os.path.join(self.tmp_dir, 'config.json')
        cfg.dump(cfg_file)
        dataset = MsDataset.load('clue', subset_name='afqmc', split='train')
        dataset = dataset.to_hf_dataset().select(range(128))
        kwargs = dict(
            model=model_id,
            train_dataset=dataset,
            eval_dataset=dataset,
            cfg_file=cfg_file)
        regress_tool = MsRegressTool(baseline=True)
        trainer: EpochBasedTrainer = build_trainer(default_args=kwargs)

        def lazy_stop_callback():
            from modelscope.trainers.hooks.hook import Hook, Priority

            class EarlyStopHook(Hook):
                PRIORITY = Priority.VERY_LOW

                def after_iter(self, trainer):
                    if trainer.iter == 12:
                        raise MsRegressTool.EarlyStopError('Test finished.')

            if 'EarlyStopHook' not in [
                    hook.__class__.__name__ for hook in trainer.hooks
            ]:
                trainer.register_hook(EarlyStopHook())

        with regress_tool.monitor_ms_train(
                trainer,
                'trainer_continue_train',
                level='strict',
                lazy_stop_callback=lazy_stop_callback):
            trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)

        trainer = build_trainer(default_args=kwargs)
        regress_tool = MsRegressTool(baseline=False)
        with regress_tool.monitor_ms_train(
                trainer, 'trainer_continue_train', level='strict'):
            trainer.train(os.path.join(self.tmp_dir, 'iter_12.pth'))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        tmp_dir = tempfile.TemporaryDirectory().name