
[to #42322933] Fix random seed for trainer

1. Fix the random seed for the trainer and initialize it on the first line of __init__ (see the sketch after this message)
2. Add a regression test for fixed-seed training
3. Change the dataset 'dureader_robust_qg' to 'DuReader_robust-QG'
4. Change some dataset loading from Hugging Face datasets.load_dataset to MsDataset.load
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10029509
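
A minimal usage sketch of points 1 and 2 (not part of the diff): the trainer now takes a `seed` kwarg and applies it before anything else in __init__, so two trainers built with identical arguments replay the same run. The model id and builder call mirror the regression test below; the Trainers import path is an assumption.

    # Hedged sketch, assuming `Trainers.nlp_base_trainer` resolves as in the new test.
    from modelscope.metainfo import Trainers  # import path assumed
    from modelscope.trainers import build_trainer

    kwargs = dict(
        model='damo/nlp_structbert_backbone_base_std',  # model id used in the new test
        work_dir='./work_dir',
        seed=42,  # new kwarg introduced by this commit
    )
    trainer = build_trainer(name=Trainers.nlp_base_trainer, default_args=kwargs)
    # The seed is applied on the first line of __init__, before model and dataset
    # construction, so a second trainer built from the same kwargs reproduces it.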
master · yuze.zyz · 3 years ago · parent commit 01e768503c
6 changed files with 95 additions and 16 deletions
  1. modelscope/trainers/trainer.py (+4, -3)
  2. modelscope/utils/regress_test_utils.py (+14, -0)
  3. modelscope/utils/torch_utils.py (+1, -0)
  4. tests/trainers/data/test/regression/sbert-base-tnews.bin (+3, -0)
  5. tests/trainers/test_finetune_sequence_classification.py (+72, -12)
  6. tests/trainers/test_finetune_text_generation.py (+1, -1)

modelscope/trainers/trainer.py (+4, -3)

@@ -75,6 +75,7 @@ class EpochBasedTrainer(BaseTrainer):
this preprocessing action will be executed every time the dataset's __getitem__ is called.
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple
containing the optimizer and the scheduler to use.
seed (int): The optional random seed for torch, cuda, numpy and random.
max_epochs (int, optional): Total training epochs.
"""

@@ -93,8 +94,11 @@ class EpochBasedTrainer(BaseTrainer):
torch.optim.lr_scheduler._LRScheduler] = (None,
None),
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
seed: int = 42,
**kwargs):

self._seed = seed
set_random_seed(self._seed)
if isinstance(model, str):
if os.path.exists(model):
self.model_dir = model if os.path.isdir(
@@ -213,9 +217,6 @@ class EpochBasedTrainer(BaseTrainer):

self.use_fp16 = kwargs.get('use_fp16', False)

# TODO @wenmeng.zwm add seed init fn
self._seed = 0

if kwargs.get('launcher', None) is not None:
init_dist(kwargs['launcher'])



modelscope/utils/regress_test_utils.py (+14, -0)

@@ -133,6 +133,7 @@ class RegressTool:
compare_fn=None,
ignore_keys=None,
compare_random=True,
reset_dropout=True,
lazy_stop_callback=None):
"""Monitor a pytorch module's backward data and cfg data within a step of the optimizer.

@@ -151,6 +152,7 @@ class RegressTool:
@param compare_fn: A custom fn used to compare the results manually.
@param ignore_keys: The keys to ignore of the named_parameters.
@param compare_random: Whether to compare random settings, default True.
@param reset_dropout: Reset all dropout modules to p=0.0.
@param lazy_stop_callback: A callback passed in; when the monitoring is over, this callback will be called.

>>> def compare_fn(v1, v2, key, type):
@@ -202,6 +204,18 @@ class RegressTool:
trainer,
'_seed') else trainer.seed if hasattr(trainer, 'seed') else None

if reset_dropout:
with torch.no_grad():

def reinit_dropout(_module):
for name, submodule in _module.named_children():
if isinstance(submodule, torch.nn.Dropout):
setattr(_module, name, torch.nn.Dropout(0.))
else:
reinit_dropout(submodule)

reinit_dropout(module)

if level == 'strict':
hack_forward(module, file_name, io_json)
intercept_module(module, io_json)
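
A self-contained illustration (not part of the diff) of what the new reset_dropout flag does: every torch.nn.Dropout in the module tree is replaced with Dropout(p=0.), so the baseline run and the comparison run see identical activations.

    import torch

    def reinit_dropout(module: torch.nn.Module) -> None:
        # Recursively swap each Dropout child for a no-op Dropout(0.),
        # mirroring the helper added inside monitor() above.
        for name, submodule in module.named_children():
            if isinstance(submodule, torch.nn.Dropout):
                setattr(module, name, torch.nn.Dropout(0.))
            else:
                reinit_dropout(submodule)

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.1))
    with torch.no_grad():
        reinit_dropout(model)
    assert model[1].p == 0.0  # dropout is now disabled for the comparison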


modelscope/utils/torch_utils.py (+1, -0)

@@ -186,6 +186,7 @@ def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
else:
raise ValueError(
f'Random seed should be positive, current seed is {seed}')
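
For context, the full helper presumably reads as below once this hunk is applied; the guard condition is inferred from the ValueError branch shown above and is an assumption, not part of the diff.

    import random

    import numpy as np
    import torch

    def set_random_seed(seed):
        # Guard inferred from the else branch in the diff above.
        if seed is not None and seed >= 0:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # the line added by this commit
        else:
            raise ValueError(
                f'Random seed should be positive, current seed is {seed}')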


tests/trainers/data/test/regression/sbert-base-tnews.bin (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2df2a5f3cdfc6dded52d31a8e97d9a9c41a803cb6d46dee709c51872eda37b21
size 151830

tests/trainers/test_finetune_sequence_classification.py (+72, -12)

@@ -10,11 +10,14 @@ from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.trainers.hooks import Hook
from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
from modelscope.trainers.nlp_trainer import (EpochBasedTrainer,
NlpEpochBasedTrainer)
from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \
calculate_fisher
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.data_utils import to_device
from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.test_utils import test_level


class TestFinetuneSequenceClassification(unittest.TestCase):
@@ -28,11 +31,76 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
self.regress_tool = MsRegressTool(baseline=False)

def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_trainer_repeatable(self):
import torch # noqa

def cfg_modify_fn(cfg):
cfg.task = 'nli'
cfg['preprocessor'] = {'type': 'nli-tokenizer'}
cfg.train.optimizer.lr = 2e-5
cfg['dataset'] = {
'train': {
'labels': [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10',
'11', '12', '13', '14'
],
'first_sequence':
'sentence',
'label':
'label',
}
}
cfg.train.max_epochs = 5
cfg.train.lr_scheduler = {
'type': 'LinearLR',
'start_factor': 1.0,
'end_factor': 0.0,
'total_iters':
int(len(dataset['train']) / 32) * cfg.train.max_epochs,
'options': {
'by_epoch': False
}
}
cfg.train.hooks = [{
'type': 'CheckpointHook',
'interval': 1
}, {
'type': 'TextLoggerHook',
'interval': 1
}, {
'type': 'IterTimerHook'
}, {
'type': 'EvaluationHook',
'by_epoch': False,
'interval': 100
}]
return cfg

dataset = MsDataset.load('clue', subset_name='tnews')

kwargs = dict(
model='damo/nlp_structbert_backbone_base_std',
train_dataset=dataset['train'],
eval_dataset=dataset['validation'],
work_dir=self.tmp_dir,
seed=42,
cfg_modify_fn=cfg_modify_fn)

os.environ['LOCAL_RANK'] = '0'
trainer: EpochBasedTrainer = build_trainer(
name=Trainers.nlp_base_trainer, default_args=kwargs)

with self.regress_tool.monitor_ms_train(
trainer, 'sbert-base-tnews', level='strict'):
trainer.train()

def finetune(self,
model_id,
train_dataset,
@@ -54,7 +122,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(self.epoch_num):
self.assertIn(f'epoch_{i+1}.pth', results_files)
self.assertIn(f'epoch_{i + 1}.pth', results_files)

output_files = os.listdir(
os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR))
@@ -118,11 +186,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
}]
return cfg

from datasets import load_dataset
from datasets import DownloadConfig
dc = DownloadConfig()
dc.local_files_only = True
dataset = load_dataset('clue', 'afqmc', download_config=dc)
dataset = MsDataset.load('clue', subset_name='afqmc')
self.finetune(
model_id='damo/nlp_structbert_backbone_base_std',
train_dataset=dataset['train'],
@@ -182,11 +246,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
}]
return cfg

from datasets import load_dataset
from datasets import DownloadConfig
dc = DownloadConfig()
dc.local_files_only = True
dataset = load_dataset('clue', 'tnews', download_config=dc)
dataset = MsDataset.load('clue', subset_name='tnews')

self.finetune(
model_id='damo/nlp_structbert_backbone_base_std',
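
A hedged note on how the new test ties to the binary file above: MsRegressTool(baseline=False) compares the monitored run against the checked-in sbert-base-tnews.bin, and baseline=True is presumably how that file was recorded (inferred, not shown in the diff). The pattern, wrapped as a small helper for clarity:

    from modelscope.utils.regress_test_utils import MsRegressTool

    def run_regression(trainer, baseline=False):
        # baseline=False: compare this run against the stored baseline file.
        # baseline=True (assumed): record a fresh baseline instead of comparing.
        tool = MsRegressTool(baseline=baseline)
        with tool.monitor_ms_train(trainer, 'sbert-base-tnews', level='strict'):
            trainer.train()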


tests/trainers/test_finetune_text_generation.py (+1, -1)

@@ -129,7 +129,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
@unittest.skip
def test_finetune_cnndm(self):
from modelscope.msdatasets import MsDataset
dataset_dict = MsDataset.load('dureader_robust_qg')
dataset_dict = MsDataset.load('DuReader_robust-QG')
train_dataset = dataset_dict['train'].to_hf_dataset() \
.rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
eval_dataset = dataset_dict['validation'].to_hf_dataset() \

