1. Change the structbert dev revision to the master revision
2. Fix bug: sample code failed because of an update to the model configuration
3. Fix bug: continue-training (resume) regression failed

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10519992
@@ -2,13 +2,19 @@
from modelscope.utils.config import ConfigDict
from modelscope.utils.constant import Tasks
from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg
MODELS = Registry('models')
BACKBONES = Registry('backbones')
BACKBONES._modules = MODELS._modules
BACKBONES = MODELS
HEADS = Registry('heads')
modules = LazyImportModule.AST_INDEX[INDEX_KEY]
for module_index in list(modules.keys()):
    if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES':
        modules[(MODELS.name.upper(), module_index[1],
                 module_index[2])] = modules[module_index]
def build_model(cfg: ConfigDict,
                task_name: str = None,
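This hunk replaces the separate BACKBONES registry with a direct alias of MODELS and copies backbone entries in the lazy-import index over to the MODELS key, so backbone lookups go through the single models registry. A minimal sketch of why the alias works, using a simplified stand-in rather than the actual modelscope Registry API:

```python
# Simplified illustration only, not modelscope's Registry implementation.
class Registry:
    def __init__(self, name):
        self.name = name
        self._modules = {}  # group -> {module_name: module_cls}

    def register(self, group, name, cls):
        self._modules.setdefault(group, {})[name] = cls

    def get(self, group, name):
        return self._modules.get(group, {}).get(name)


MODELS = Registry('models')
BACKBONES = MODELS  # alias instead of a second registry object

BACKBONES.register('backbone', 'bert', object)
# Anything registered as a backbone is now visible when building a model.
assert MODELS.get('backbone', 'bert') is object
```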
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
        SbertForSequenceClassification,
        SbertForTokenClassification,
        SbertTokenizer,
        SbertModel,
        SbertTokenizerFast,
    )
    from .bert import (
@@ -61,6 +62,7 @@ else:
            'SbertForTokenClassification',
            'SbertTokenizer',
            'SbertTokenizerFast',
            'SbertModel',
        ],
        'veco': [
            'VecoModel', 'VecoConfig', 'VecoForTokenClassification',
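SbertModel is added both to the TYPE_CHECKING imports and to the lazy `_import_structure` mapping; a name missing from either place breaks sample code that imports it directly. A small, self-contained sketch of this lazy-import convention (module layout here is illustrative):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never executed at runtime.
    from .structbert import SbertModel  # noqa: F401
else:
    # At runtime the symbol is resolved lazily from this mapping, so the
    # exported name must be listed here as well.
    _import_structure = {
        'structbert': ['SbertModel'],
    }
```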
@@ -98,7 +98,7 @@ def compute_adv_loss(embedding,
    if is_nan:
        logger.warning('Nan occured when calculating adv loss.')
        return ori_loss
    emb_grad = emb_grad / emb_grad_norm
    emb_grad = emb_grad / (emb_grad_norm + 1e-6)
    embedding_2 = embedding_1 + adv_grad_factor * emb_grad
    embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
    embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)
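The compute_adv_loss hunk adds a small epsilon before dividing by the gradient norm, so a zero-norm gradient no longer produces inf/NaN in the adversarial perturbation. A rough sketch of just the arithmetic (the real function computes emb_grad_norm elsewhere):

```python
import torch


def normalize_grad(emb_grad: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Dividing by (norm + eps) keeps the result finite even when the
    # gradient is identically zero.
    emb_grad_norm = emb_grad.norm()
    return emb_grad / (emb_grad_norm + eps)


zero_grad = torch.zeros(4, 8)
assert torch.isfinite(normalize_grad(zero_grad)).all()  # no inf/NaN
```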
@@ -69,7 +69,7 @@ class CheckpointHook(Hook):
            self.rng_state = meta.get('rng_state')
            self.need_load_rng_state = True
    def before_train_epoch(self, trainer):
    def before_train_iter(self, trainer):
        if self.need_load_rng_state:
            if self.rng_state is not None:
                random.setstate(self.rng_state['random'])
@@ -84,13 +84,6 @@ class CheckpointHook(Hook):
                    'this may cause a random data order or model initialization.'
                )
        self.rng_state = {
            'random': random.getstate(),
            'numpy': np.random.get_state(),
            'cpu': torch.random.get_rng_state(),
            'cuda': torch.cuda.get_rng_state_all(),
        }
    def after_train_epoch(self, trainer):
        if not self.by_epoch:
            return
@@ -142,6 +135,12 @@ class CheckpointHook(Hook):
            cur_save_name = os.path.join(
                self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')
        self.rng_state = {
            'random': random.getstate(),
            'numpy': np.random.get_state(),
            'cpu': torch.random.get_rng_state(),
            'cuda': torch.cuda.get_rng_state_all(),
        }
        meta = {
            'epoch': trainer.epoch,
            'iter': trainer.iter + 1,
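These CheckpointHook hunks address the continue-training regression: the RNG state is now captured right before the checkpoint meta is assembled and restored in before_train_iter (instead of before_train_epoch), so a resumed run reproduces the original data order and random initialization. A rough, self-contained sketch of the capture/restore pair (helper names are mine, not the hook's API):

```python
import random

import numpy as np
import torch


def capture_rng_state():
    # Snapshot every RNG the trainer depends on at checkpoint time.
    return {
        'random': random.getstate(),
        'numpy': np.random.get_state(),
        'cpu': torch.random.get_rng_state(),
        'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [],
    }


def restore_rng_state(state):
    # Put every RNG back before the first iteration of a resumed run.
    random.setstate(state['random'])
    np.random.set_state(state['numpy'])
    torch.random.set_rng_state(state['cpu'])
    if torch.cuda.is_available() and state['cuda']:
        torch.cuda.set_rng_state_all(state['cuda'])


state = capture_rng_state()
a = torch.rand(3)
restore_rng_state(state)
b = torch.rand(3)
assert torch.equal(a, b)  # the same draw repeats after restoring the state
```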
@@ -354,6 +354,9 @@ class EpochBasedTrainer(BaseTrainer):
                task_dataset.trainer = self
                return task_dataset
            else:
                if task_data_config is None:
                    # adapt to some special models
                    task_data_config = {}
                # avoid add no str value datasets, preprocessors in cfg
                task_data_build_config = ConfigDict(
                    type=self.cfg.model.type,
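The trainer hunk guards against models whose configuration carries no dataset section: a missing task_data_config is replaced with an empty dict before the build config is assembled. Roughly, with the fields trimmed to those visible in the hunk (not the trainer's full logic):

```python
from modelscope.utils.config import ConfigDict


def make_task_data_build_config(task_data_config, model_type):
    if task_data_config is None:
        # Some models ship no dataset configuration at all.
        task_data_config = {}
    return ConfigDict(type=model_type, **task_data_config)
```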
@@ -419,13 +422,17 @@ class EpochBasedTrainer(BaseTrainer):
        return metrics
    def set_checkpoint_file_to_hook(self, checkpoint_path):
        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
            from modelscope.trainers.hooks import CheckpointHook
            checkpoint_hooks = list(
                filter(lambda hook: isinstance(hook, CheckpointHook),
                       self.hooks))
            for hook in checkpoint_hooks:
                hook.checkpoint_file = checkpoint_path
        if checkpoint_path is not None:
            if os.path.isfile(checkpoint_path):
                from modelscope.trainers.hooks import CheckpointHook
                checkpoint_hooks = list(
                    filter(lambda hook: isinstance(hook, CheckpointHook),
                           self.hooks))
                for hook in checkpoint_hooks:
                    hook.checkpoint_file = checkpoint_path
            else:
                self.logger.error(
                    f'No {checkpoint_path} found in local file system.')
    def train(self, checkpoint_path=None, *args, **kwargs):
        self._mode = ModeKeys.TRAIN
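set_checkpoint_file_to_hook now distinguishes "no checkpoint given" from "checkpoint path does not exist": the latter is logged as an error instead of being silently ignored. A simplified sketch of the control flow (the isinstance filter on CheckpointHook is dropped here for brevity):

```python
import logging
import os

logger = logging.getLogger(__name__)


def set_checkpoint_file_to_hooks(hooks, checkpoint_path):
    if checkpoint_path is None:
        return  # nothing requested, nothing to report
    if os.path.isfile(checkpoint_path):
        for hook in hooks:  # assumed to be the trainer's CheckpointHook instances
            hook.checkpoint_file = checkpoint_path
    else:
        logger.error(f'No {checkpoint_path} found in local file system.')
```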
@@ -83,7 +83,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
        # bert
        language = 'zh'
        model_dir = snapshot_download(self.model_id_bert, revision='beta')
        model_dir = snapshot_download(self.model_id_bert)
        preprocessor = NLPPreprocessor(
            model_dir, first_sequence='sentence', second_sequence=None)
        model = Model.from_pretrained(model_dir)
@@ -149,10 +149,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
        # Bert
        language = 'zh'
        pipeline_ins = pipeline(
            task=Tasks.fill_mask,
            model=self.model_id_bert,
            model_revision='beta')
        pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_bert)
        print(
            f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
            f'{pipeline_ins(self.test_inputs[language])}\n')
@@ -24,10 +24,10 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_direct_file_download(self):
        cache_path = snapshot_download(self.model_id, revision='beta')
        cache_path = snapshot_download(self.model_id)
        tokenizer = SequenceClassificationPreprocessor(cache_path)
        model = SequenceClassificationModel.from_pretrained(
            self.model_id, num_labels=2, revision='beta')
            self.model_id, num_labels=2)
        pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer)
        pipeline2 = pipeline(
            Tasks.text_classification, model=model, preprocessor=tokenizer)
@@ -38,7 +38,7 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id, revision='beta')
        model = Model.from_pretrained(self.model_id)
        tokenizer = SequenceClassificationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.text_classification,
@@ -51,17 +51,14 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.text_classification,
            model=self.model_id,
            model_revision='beta')
            task=Tasks.text_classification, model=self.model_id)
        print(pipeline_ins(input=self.sentence1))
        self.assertTrue(
            isinstance(pipeline_ins.model, SequenceClassificationModel))
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(
            task=Tasks.text_classification, model_revision='beta')
        pipeline_ins = pipeline(task=Tasks.text_classification)
        print(pipeline_ins(input=self.sentence1))
        self.assertTrue(
            isinstance(pipeline_ins.model, SequenceClassificationModel))
@@ -37,13 +37,12 @@ class TestTrainerWithNlp(unittest.TestCase):
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
        kwargs = dict(
            model=model_id,
            train_dataset=self.dataset,
            eval_dataset=self.dataset,
            work_dir=self.tmp_dir,
            model_revision='beta')
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
@@ -80,8 +79,7 @@ class TestTrainerWithNlp(unittest.TestCase):
            model=model_id,
            train_dataset=self.dataset,
            eval_dataset=self.dataset,
            work_dir=self.tmp_dir,
            model_revision='beta')
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
@@ -97,7 +95,7 @@ class TestTrainerWithNlp(unittest.TestCase):
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_user_defined_config(self):
        model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
        cfg = read_config(model_id, revision='beta')
        cfg = read_config(model_id)
        cfg.train.max_epochs = 20
        cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
        cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
@@ -108,8 +106,7 @@ class TestTrainerWithNlp(unittest.TestCase):
            model=model_id,
            train_dataset=self.dataset,
            eval_dataset=self.dataset,
            cfg_file=cfg_file,
            model_revision='beta')
            cfg_file=cfg_file)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
@@ -233,7 +230,7 @@ class TestTrainerWithNlp(unittest.TestCase):
        os.makedirs(tmp_dir)
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
        cache_path = snapshot_download(model_id, revision='beta')
        cache_path = snapshot_download(model_id)
        model = SbertForSequenceClassification.from_pretrained(cache_path)
        kwargs = dict(
            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),