1. Change structbert dev revision to master revision
2. Fix bug: Sample code failed because the updating of model configuration
3. Fix bug: Continue training regression failed
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10519992
master
| @@ -2,13 +2,19 @@ | |||
| from modelscope.utils.config import ConfigDict | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule | |||
| from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg | |||
| MODELS = Registry('models') | |||
| BACKBONES = Registry('backbones') | |||
| BACKBONES._modules = MODELS._modules | |||
| BACKBONES = MODELS | |||
| HEADS = Registry('heads') | |||
| modules = LazyImportModule.AST_INDEX[INDEX_KEY] | |||
| for module_index in list(modules.keys()): | |||
| if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES': | |||
| modules[(MODELS.name.upper(), module_index[1], | |||
| module_index[2])] = modules[module_index] | |||
| def build_model(cfg: ConfigDict, | |||
| task_name: str = None, | |||
| @@ -19,6 +19,7 @@ if TYPE_CHECKING: | |||
| SbertForSequenceClassification, | |||
| SbertForTokenClassification, | |||
| SbertTokenizer, | |||
| SbertModel, | |||
| SbertTokenizerFast, | |||
| ) | |||
| from .bert import ( | |||
| @@ -61,6 +62,7 @@ else: | |||
| 'SbertForTokenClassification', | |||
| 'SbertTokenizer', | |||
| 'SbertTokenizerFast', | |||
| 'SbertModel', | |||
| ], | |||
| 'veco': [ | |||
| 'VecoModel', 'VecoConfig', 'VecoForTokenClassification', | |||
| @@ -98,7 +98,7 @@ def compute_adv_loss(embedding, | |||
| if is_nan: | |||
| logger.warning('Nan occured when calculating adv loss.') | |||
| return ori_loss | |||
| emb_grad = emb_grad / emb_grad_norm | |||
| emb_grad = emb_grad / (emb_grad_norm + 1e-6) | |||
| embedding_2 = embedding_1 + adv_grad_factor * emb_grad | |||
| embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2) | |||
| embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2) | |||
| @@ -69,7 +69,7 @@ class CheckpointHook(Hook): | |||
| self.rng_state = meta.get('rng_state') | |||
| self.need_load_rng_state = True | |||
| def before_train_epoch(self, trainer): | |||
| def before_train_iter(self, trainer): | |||
| if self.need_load_rng_state: | |||
| if self.rng_state is not None: | |||
| random.setstate(self.rng_state['random']) | |||
| @@ -84,13 +84,6 @@ class CheckpointHook(Hook): | |||
| 'this may cause a random data order or model initialization.' | |||
| ) | |||
| self.rng_state = { | |||
| 'random': random.getstate(), | |||
| 'numpy': np.random.get_state(), | |||
| 'cpu': torch.random.get_rng_state(), | |||
| 'cuda': torch.cuda.get_rng_state_all(), | |||
| } | |||
| def after_train_epoch(self, trainer): | |||
| if not self.by_epoch: | |||
| return | |||
| @@ -142,6 +135,12 @@ class CheckpointHook(Hook): | |||
| cur_save_name = os.path.join( | |||
| self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') | |||
| self.rng_state = { | |||
| 'random': random.getstate(), | |||
| 'numpy': np.random.get_state(), | |||
| 'cpu': torch.random.get_rng_state(), | |||
| 'cuda': torch.cuda.get_rng_state_all(), | |||
| } | |||
| meta = { | |||
| 'epoch': trainer.epoch, | |||
| 'iter': trainer.iter + 1, | |||
| @@ -354,6 +354,9 @@ class EpochBasedTrainer(BaseTrainer): | |||
| task_dataset.trainer = self | |||
| return task_dataset | |||
| else: | |||
| if task_data_config is None: | |||
| # adapt to some special models | |||
| task_data_config = {} | |||
| # avoid add no str value datasets, preprocessors in cfg | |||
| task_data_build_config = ConfigDict( | |||
| type=self.cfg.model.type, | |||
| @@ -419,13 +422,17 @@ class EpochBasedTrainer(BaseTrainer): | |||
| return metrics | |||
| def set_checkpoint_file_to_hook(self, checkpoint_path): | |||
| if checkpoint_path is not None and os.path.isfile(checkpoint_path): | |||
| from modelscope.trainers.hooks import CheckpointHook | |||
| checkpoint_hooks = list( | |||
| filter(lambda hook: isinstance(hook, CheckpointHook), | |||
| self.hooks)) | |||
| for hook in checkpoint_hooks: | |||
| hook.checkpoint_file = checkpoint_path | |||
| if checkpoint_path is not None: | |||
| if os.path.isfile(checkpoint_path): | |||
| from modelscope.trainers.hooks import CheckpointHook | |||
| checkpoint_hooks = list( | |||
| filter(lambda hook: isinstance(hook, CheckpointHook), | |||
| self.hooks)) | |||
| for hook in checkpoint_hooks: | |||
| hook.checkpoint_file = checkpoint_path | |||
| else: | |||
| self.logger.error( | |||
| f'No {checkpoint_path} found in local file system.') | |||
| def train(self, checkpoint_path=None, *args, **kwargs): | |||
| self._mode = ModeKeys.TRAIN | |||
| @@ -83,7 +83,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| # bert | |||
| language = 'zh' | |||
| model_dir = snapshot_download(self.model_id_bert, revision='beta') | |||
| model_dir = snapshot_download(self.model_id_bert) | |||
| preprocessor = NLPPreprocessor( | |||
| model_dir, first_sequence='sentence', second_sequence=None) | |||
| model = Model.from_pretrained(model_dir) | |||
| @@ -149,10 +149,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| # Bert | |||
| language = 'zh' | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.fill_mask, | |||
| model=self.model_id_bert, | |||
| model_revision='beta') | |||
| pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_bert) | |||
| print( | |||
| f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' | |||
| f'{pipeline_ins(self.test_inputs[language])}\n') | |||
| @@ -24,10 +24,10 @@ class SentimentClassificationTaskModelTest(unittest.TestCase, | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_direct_file_download(self): | |||
| cache_path = snapshot_download(self.model_id, revision='beta') | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = SequenceClassificationPreprocessor(cache_path) | |||
| model = SequenceClassificationModel.from_pretrained( | |||
| self.model_id, num_labels=2, revision='beta') | |||
| self.model_id, num_labels=2) | |||
| pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.text_classification, model=model, preprocessor=tokenizer) | |||
| @@ -38,7 +38,7 @@ class SentimentClassificationTaskModelTest(unittest.TestCase, | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id, revision='beta') | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = SequenceClassificationPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.text_classification, | |||
| @@ -51,17 +51,14 @@ class SentimentClassificationTaskModelTest(unittest.TestCase, | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.text_classification, | |||
| model=self.model_id, | |||
| model_revision='beta') | |||
| task=Tasks.text_classification, model=self.model_id) | |||
| print(pipeline_ins(input=self.sentence1)) | |||
| self.assertTrue( | |||
| isinstance(pipeline_ins.model, SequenceClassificationModel)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.text_classification, model_revision='beta') | |||
| pipeline_ins = pipeline(task=Tasks.text_classification) | |||
| print(pipeline_ins(input=self.sentence1)) | |||
| self.assertTrue( | |||
| isinstance(pipeline_ins.model, SequenceClassificationModel)) | |||
| @@ -37,13 +37,12 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer(self): | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' | |||
| kwargs = dict( | |||
| model=model_id, | |||
| train_dataset=self.dataset, | |||
| eval_dataset=self.dataset, | |||
| work_dir=self.tmp_dir, | |||
| model_revision='beta') | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| trainer.train() | |||
| @@ -80,8 +79,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| model=model_id, | |||
| train_dataset=self.dataset, | |||
| eval_dataset=self.dataset, | |||
| work_dir=self.tmp_dir, | |||
| model_revision='beta') | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| trainer.train() | |||
| @@ -97,7 +95,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer_with_user_defined_config(self): | |||
| model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||
| cfg = read_config(model_id, revision='beta') | |||
| cfg = read_config(model_id) | |||
| cfg.train.max_epochs = 20 | |||
| cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||
| cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||
| @@ -108,8 +106,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| model=model_id, | |||
| train_dataset=self.dataset, | |||
| eval_dataset=self.dataset, | |||
| cfg_file=cfg_file, | |||
| model_revision='beta') | |||
| cfg_file=cfg_file) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| trainer.train() | |||
| @@ -233,7 +230,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| os.makedirs(tmp_dir) | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
| cache_path = snapshot_download(model_id, revision='beta') | |||
| cache_path = snapshot_download(model_id) | |||
| model = SbertForSequenceClassification.from_pretrained(cache_path) | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||