1. Switch structbert models from the dev revision ('beta') to the master revision
2. Fix bug: sample code failed because of a model configuration update
3. Fix bug: continue-training (resuming from a checkpoint) regression failed
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10519992
Target branch: master
@@ -2,13 +2,19 @@
 from modelscope.utils.config import ConfigDict
 from modelscope.utils.constant import Tasks
+from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
 from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg

 MODELS = Registry('models')
-BACKBONES = Registry('backbones')
-BACKBONES._modules = MODELS._modules
+BACKBONES = MODELS
 HEADS = Registry('heads')

+modules = LazyImportModule.AST_INDEX[INDEX_KEY]
+for module_index in list(modules.keys()):
+    if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES':
+        modules[(MODELS.name.upper(), module_index[1],
+                 module_index[2])] = modules[module_index]
+

 def build_model(cfg: ConfigDict,
                 task_name: str = None,
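Note on the backbone registry change above: the old code created a second Registry that merely shared MODELS' module table, but lazy-import index entries recorded under the 'BACKBONES' name stayed invisible to lookups routed through MODELS. Rebinding BACKBONES to MODELS, and mirroring the 'BACKBONES' index entries under the MODELS key, closes that gap. A minimal sketch of the aliasing idea, using a toy Registry class (a stand-in, not ModelScope's real implementation):

# Toy registry, a stand-in for modelscope.utils.registry.Registry.
class Registry:
    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register(self, key):
        def deco(cls):
            self._modules[key] = cls
            return cls
        return deco

MODELS = Registry('models')
BACKBONES = MODELS  # alias: both names share one registry object

@BACKBONES.register('my-backbone')
class MyBackbone:
    pass

# With the alias, lookups through MODELS see backbone registrations too.
assert MODELS._modules['my-backbone'] is MyBackbone

Sharing only `_modules` kept the registration table in sync, but the two registries still had different names; the loop over `LazyImportModule.AST_INDEX` in the hunk above exists for the same reason, remapping entries keyed ('BACKBONES', Tasks.backbone, ...) onto the MODELS key so lazy imports resolve.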
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
         SbertForSequenceClassification,
         SbertForTokenClassification,
         SbertTokenizer,
+        SbertModel,
         SbertTokenizerFast,
     )
     from .bert import (
@@ -61,6 +62,7 @@ else:
             'SbertForTokenClassification',
             'SbertTokenizer',
             'SbertTokenizerFast',
+            'SbertModel',
         ],
         'veco': [
             'VecoModel', 'VecoConfig', 'VecoForTokenClassification',
@@ -98,7 +98,7 @@ def compute_adv_loss(embedding,
     if is_nan:
         logger.warning('Nan occured when calculating adv loss.')
         return ori_loss
-    emb_grad = emb_grad / emb_grad_norm
+    emb_grad = emb_grad / (emb_grad_norm + 1e-6)
     embedding_2 = embedding_1 + adv_grad_factor * emb_grad
     embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
     embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)
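The one-line fix above is the standard epsilon guard for normalizing by a vector norm: when the gradient is all zeros, emb_grad / emb_grad_norm is 0/0 and yields NaN, which then propagates through the adversarial perturbation. A self-contained sketch of the failure and the guard (tensor shapes are illustrative):

import torch

emb_grad = torch.zeros(4, 8)  # worst case: an all-zero gradient
emb_grad_norm = emb_grad.norm(dim=-1, keepdim=True)

unsafe = emb_grad / emb_grad_norm         # 0/0 -> NaN
safe = emb_grad / (emb_grad_norm + 1e-6)  # epsilon keeps the division finite

assert torch.isnan(unsafe).any()
assert not torch.isnan(safe).any()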
@@ -69,7 +69,7 @@ class CheckpointHook(Hook):
             self.rng_state = meta.get('rng_state')
             self.need_load_rng_state = True

-    def before_train_epoch(self, trainer):
+    def before_train_iter(self, trainer):
         if self.need_load_rng_state:
             if self.rng_state is not None:
                 random.setstate(self.rng_state['random'])
@@ -84,13 +84,6 @@ class CheckpointHook(Hook):
                     'this may cause a random data order or model initialization.'
                 )

-        self.rng_state = {
-            'random': random.getstate(),
-            'numpy': np.random.get_state(),
-            'cpu': torch.random.get_rng_state(),
-            'cuda': torch.cuda.get_rng_state_all(),
-        }
-
     def after_train_epoch(self, trainer):
         if not self.by_epoch:
             return
@@ -142,6 +135,12 @@ class CheckpointHook(Hook):
             cur_save_name = os.path.join(
                 self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')

+        self.rng_state = {
+            'random': random.getstate(),
+            'numpy': np.random.get_state(),
+            'cpu': torch.random.get_rng_state(),
+            'cuda': torch.cuda.get_rng_state_all(),
+        }
         meta = {
             'epoch': trainer.epoch,
             'iter': trainer.iter + 1,
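Moving the rng_state snapshot out of before_train_epoch and into the save path (paired with restoring it in before_train_iter) is the core of the continue-training fix: the state written to the checkpoint is now exactly the state at save time, so a resumed run replays the same data order and dropout masks. A minimal sketch of the save/restore round-trip outside the hook machinery:

import random

import numpy as np
import torch

# Save: snapshot every RNG stream, as the hook does at checkpoint time.
rng_state = {
    'random': random.getstate(),
    'numpy': np.random.get_state(),
    'cpu': torch.random.get_rng_state(),
    'cuda': torch.cuda.get_rng_state_all(),  # empty list on CPU-only machines
}

before = torch.rand(3)

# Restore: rewind all streams, then the same draws come out again.
random.setstate(rng_state['random'])
np.random.set_state(rng_state['numpy'])
torch.random.set_rng_state(rng_state['cpu'])
torch.cuda.set_rng_state_all(rng_state['cuda'])

assert torch.equal(before, torch.rand(3))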
@@ -354,6 +354,9 @@ class EpochBasedTrainer(BaseTrainer):
                 task_dataset.trainer = self
             return task_dataset
         else:
+            if task_data_config is None:
+                # adapt to some special models
+                task_data_config = {}
             # avoid add no str value datasets, preprocessors in cfg
             task_data_build_config = ConfigDict(
                 type=self.cfg.model.type,
@@ -419,13 +422,17 @@ class EpochBasedTrainer(BaseTrainer):
         return metrics

     def set_checkpoint_file_to_hook(self, checkpoint_path):
-        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
-            from modelscope.trainers.hooks import CheckpointHook
-            checkpoint_hooks = list(
-                filter(lambda hook: isinstance(hook, CheckpointHook),
-                       self.hooks))
-            for hook in checkpoint_hooks:
-                hook.checkpoint_file = checkpoint_path
+        if checkpoint_path is not None:
+            if os.path.isfile(checkpoint_path):
+                from modelscope.trainers.hooks import CheckpointHook
+                checkpoint_hooks = list(
+                    filter(lambda hook: isinstance(hook, CheckpointHook),
+                           self.hooks))
+                for hook in checkpoint_hooks:
+                    hook.checkpoint_file = checkpoint_path
+            else:
+                self.logger.error(
+                    f'No {checkpoint_path} found in local file system.')

     def train(self, checkpoint_path=None, *args, **kwargs):
         self._mode = ModeKeys.TRAIN
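Behaviorally, the restructured check means a bad resume path is now reported instead of silently ignored. A hedged usage sketch (trainer construction is abbreviated, and iter_100.pth is a hypothetical file following the {LogKeys.ITER}_{n}.pth naming saved by the hook above):

# Hypothetical resume call; `kwargs` as built in the tests below.
trainer = build_trainer(default_args=kwargs)
# Existing file: CheckpointHook.checkpoint_file is set and the saved
# epoch, iter and RNG state are restored on resume.
trainer.train(checkpoint_path=os.path.join(tmp_dir, 'iter_100.pth'))
# Missing file: 'No <path> found in local file system.' is now logged,
# where previously the path was dropped without any message.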
@@ -83,7 +83,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
         # bert
         language = 'zh'
-        model_dir = snapshot_download(self.model_id_bert, revision='beta')
+        model_dir = snapshot_download(self.model_id_bert)
         preprocessor = NLPPreprocessor(
             model_dir, first_sequence='sentence', second_sequence=None)
         model = Model.from_pretrained(model_dir)
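This and all remaining test hunks follow one pattern, matching item 1 of the description: drop the explicit revision='beta' pin now that the structbert models are published on the master revision, letting snapshot_download, Model.from_pretrained, read_config and pipeline resolve the default revision. In isolation:

from modelscope.hub.snapshot_download import snapshot_download

model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'

# Before: pinned to the dev revision.
# model_dir = snapshot_download(model_id, revision='beta')

# After: omitting `revision` resolves the default (master) revision.
model_dir = snapshot_download(model_id)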
@@ -149,10 +149,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
         # Bert
         language = 'zh'
-        pipeline_ins = pipeline(
-            task=Tasks.fill_mask,
-            model=self.model_id_bert,
-            model_revision='beta')
+        pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_bert)
         print(
             f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
             f'{pipeline_ins(self.test_inputs[language])}\n')
@@ -24,10 +24,10 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_direct_file_download(self):
-        cache_path = snapshot_download(self.model_id, revision='beta')
+        cache_path = snapshot_download(self.model_id)
         tokenizer = SequenceClassificationPreprocessor(cache_path)
         model = SequenceClassificationModel.from_pretrained(
-            self.model_id, num_labels=2, revision='beta')
+            self.model_id, num_labels=2)
         pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer)
         pipeline2 = pipeline(
             Tasks.text_classification, model=model, preprocessor=tokenizer)
@@ -38,7 +38,7 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
-        model = Model.from_pretrained(self.model_id, revision='beta')
+        model = Model.from_pretrained(self.model_id)
         tokenizer = SequenceClassificationPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.text_classification,
@@ -51,17 +51,14 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_name(self):
         pipeline_ins = pipeline(
-            task=Tasks.text_classification,
-            model=self.model_id,
-            model_revision='beta')
+            task=Tasks.text_classification, model=self.model_id)
         print(pipeline_ins(input=self.sentence1))
         self.assertTrue(
             isinstance(pipeline_ins.model, SequenceClassificationModel))

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_default_model(self):
-        pipeline_ins = pipeline(
-            task=Tasks.text_classification, model_revision='beta')
+        pipeline_ins = pipeline(task=Tasks.text_classification)
         print(pipeline_ins(input=self.sentence1))
         self.assertTrue(
             isinstance(pipeline_ins.model, SequenceClassificationModel))
@@ -37,13 +37,12 @@ class TestTrainerWithNlp(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer(self):
-        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
         kwargs = dict(
             model=model_id,
             train_dataset=self.dataset,
             eval_dataset=self.dataset,
-            work_dir=self.tmp_dir,
-            model_revision='beta')
+            work_dir=self.tmp_dir)

         trainer = build_trainer(default_args=kwargs)
         trainer.train()
@@ -80,8 +79,7 @@ class TestTrainerWithNlp(unittest.TestCase):
             model=model_id,
             train_dataset=self.dataset,
             eval_dataset=self.dataset,
-            work_dir=self.tmp_dir,
-            model_revision='beta')
+            work_dir=self.tmp_dir)

         trainer = build_trainer(default_args=kwargs)
         trainer.train()
@@ -97,7 +95,7 @@ class TestTrainerWithNlp(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_trainer_with_user_defined_config(self):
         model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
-        cfg = read_config(model_id, revision='beta')
+        cfg = read_config(model_id)
         cfg.train.max_epochs = 20
         cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
         cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
@@ -108,8 +106,7 @@ class TestTrainerWithNlp(unittest.TestCase):
             model=model_id,
             train_dataset=self.dataset,
             eval_dataset=self.dataset,
-            cfg_file=cfg_file,
-            model_revision='beta')
+            cfg_file=cfg_file)

         trainer = build_trainer(default_args=kwargs)
         trainer.train()
@@ -233,7 +230,7 @@ class TestTrainerWithNlp(unittest.TestCase):
         os.makedirs(tmp_dir)
         model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
-        cache_path = snapshot_download(model_id, revision='beta')
+        cache_path = snapshot_download(model_id)
         model = SbertForSequenceClassification.from_pretrained(cache_path)
         kwargs = dict(
             cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),