Browse Source

[to #42322933] remove dev model inference and fix some bugs

1. Change structbert dev revision to master revision
2. Fix bug: Sample code failed because of the update to the model configuration
3. Fix bug: Continue training regression failed
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10519992
master
yuze.zyz yingda.chen 2 years ago
parent
commit
c2da44b371
8 changed files with 45 additions and 40 deletions
  1. +8
    -2
      modelscope/models/builder.py
  2. +2
    -0
      modelscope/models/nlp/__init__.py
  3. +1
    -1
      modelscope/models/nlp/structbert/adv_utils.py
  4. +7
    -8
      modelscope/trainers/hooks/checkpoint_hook.py
  5. +14
    -7
      modelscope/trainers/trainer.py
  6. +2
    -5
      tests/pipelines/test_fill_mask.py
  7. +5
    -8
      tests/pipelines/test_sentiment_classification.py
  8. +6
    -9
      tests/trainers/test_trainer_with_nlp.py

+ 8
- 2
modelscope/models/builder.py View File

@@ -2,13 +2,19 @@

from modelscope.utils.config import ConfigDict
from modelscope.utils.constant import Tasks
from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg

MODELS = Registry('models')
BACKBONES = Registry('backbones')
BACKBONES._modules = MODELS._modules
BACKBONES = MODELS
HEADS = Registry('heads')

modules = LazyImportModule.AST_INDEX[INDEX_KEY]
for module_index in list(modules.keys()):
if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES':
modules[(MODELS.name.upper(), module_index[1],
module_index[2])] = modules[module_index]


def build_model(cfg: ConfigDict,
task_name: str = None,


+ 2
- 0
modelscope/models/nlp/__init__.py View File

@@ -19,6 +19,7 @@ if TYPE_CHECKING:
SbertForSequenceClassification,
SbertForTokenClassification,
SbertTokenizer,
SbertModel,
SbertTokenizerFast,
)
from .bert import (
@@ -61,6 +62,7 @@ else:
'SbertForTokenClassification',
'SbertTokenizer',
'SbertTokenizerFast',
'SbertModel',
],
'veco': [
'VecoModel', 'VecoConfig', 'VecoForTokenClassification',


+ 1
- 1
modelscope/models/nlp/structbert/adv_utils.py View File

@@ -98,7 +98,7 @@ def compute_adv_loss(embedding,
if is_nan:
logger.warning('Nan occured when calculating adv loss.')
return ori_loss
emb_grad = emb_grad / emb_grad_norm
emb_grad = emb_grad / (emb_grad_norm + 1e-6)
embedding_2 = embedding_1 + adv_grad_factor * emb_grad
embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)


+ 7
- 8
modelscope/trainers/hooks/checkpoint_hook.py View File

@@ -69,7 +69,7 @@ class CheckpointHook(Hook):
self.rng_state = meta.get('rng_state')
self.need_load_rng_state = True

def before_train_epoch(self, trainer):
def before_train_iter(self, trainer):
if self.need_load_rng_state:
if self.rng_state is not None:
random.setstate(self.rng_state['random'])
@@ -84,13 +84,6 @@ class CheckpointHook(Hook):
'this may cause a random data order or model initialization.'
)

self.rng_state = {
'random': random.getstate(),
'numpy': np.random.get_state(),
'cpu': torch.random.get_rng_state(),
'cuda': torch.cuda.get_rng_state_all(),
}

def after_train_epoch(self, trainer):
if not self.by_epoch:
return
@@ -142,6 +135,12 @@ class CheckpointHook(Hook):
cur_save_name = os.path.join(
self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')

self.rng_state = {
'random': random.getstate(),
'numpy': np.random.get_state(),
'cpu': torch.random.get_rng_state(),
'cuda': torch.cuda.get_rng_state_all(),
}
meta = {
'epoch': trainer.epoch,
'iter': trainer.iter + 1,


+ 14
- 7
modelscope/trainers/trainer.py View File

@@ -354,6 +354,9 @@ class EpochBasedTrainer(BaseTrainer):
task_dataset.trainer = self
return task_dataset
else:
if task_data_config is None:
# adapt to some special models
task_data_config = {}
# avoid add no str value datasets, preprocessors in cfg
task_data_build_config = ConfigDict(
type=self.cfg.model.type,
@@ -419,13 +422,17 @@ class EpochBasedTrainer(BaseTrainer):
return metrics

def set_checkpoint_file_to_hook(self, checkpoint_path):
if checkpoint_path is not None and os.path.isfile(checkpoint_path):
from modelscope.trainers.hooks import CheckpointHook
checkpoint_hooks = list(
filter(lambda hook: isinstance(hook, CheckpointHook),
self.hooks))
for hook in checkpoint_hooks:
hook.checkpoint_file = checkpoint_path
if checkpoint_path is not None:
if os.path.isfile(checkpoint_path):
from modelscope.trainers.hooks import CheckpointHook
checkpoint_hooks = list(
filter(lambda hook: isinstance(hook, CheckpointHook),
self.hooks))
for hook in checkpoint_hooks:
hook.checkpoint_file = checkpoint_path
else:
self.logger.error(
f'No {checkpoint_path} found in local file system.')

def train(self, checkpoint_path=None, *args, **kwargs):
self._mode = ModeKeys.TRAIN


+ 2
- 5
tests/pipelines/test_fill_mask.py View File

@@ -83,7 +83,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):

# bert
language = 'zh'
model_dir = snapshot_download(self.model_id_bert, revision='beta')
model_dir = snapshot_download(self.model_id_bert)
preprocessor = NLPPreprocessor(
model_dir, first_sequence='sentence', second_sequence=None)
model = Model.from_pretrained(model_dir)
@@ -149,10 +149,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):

# Bert
language = 'zh'
pipeline_ins = pipeline(
task=Tasks.fill_mask,
model=self.model_id_bert,
model_revision='beta')
pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_bert)
print(
f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
f'{pipeline_ins(self.test_inputs[language])}\n')


+ 5
- 8
tests/pipelines/test_sentiment_classification.py View File

@@ -24,10 +24,10 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_direct_file_download(self):
cache_path = snapshot_download(self.model_id, revision='beta')
cache_path = snapshot_download(self.model_id)
tokenizer = SequenceClassificationPreprocessor(cache_path)
model = SequenceClassificationModel.from_pretrained(
self.model_id, num_labels=2, revision='beta')
self.model_id, num_labels=2)
pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.text_classification, model=model, preprocessor=tokenizer)
@@ -38,7 +38,7 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id, revision='beta')
model = Model.from_pretrained(self.model_id)
tokenizer = SequenceClassificationPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.text_classification,
@@ -51,17 +51,14 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.text_classification,
model=self.model_id,
model_revision='beta')
task=Tasks.text_classification, model=self.model_id)
print(pipeline_ins(input=self.sentence1))
self.assertTrue(
isinstance(pipeline_ins.model, SequenceClassificationModel))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(
task=Tasks.text_classification, model_revision='beta')
pipeline_ins = pipeline(task=Tasks.text_classification)
print(pipeline_ins(input=self.sentence1))
self.assertTrue(
isinstance(pipeline_ins.model, SequenceClassificationModel))


+ 6
- 9
tests/trainers/test_trainer_with_nlp.py View File

@@ -37,13 +37,12 @@ class TestTrainerWithNlp(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer(self):
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
kwargs = dict(
model=model_id,
train_dataset=self.dataset,
eval_dataset=self.dataset,
work_dir=self.tmp_dir,
model_revision='beta')
work_dir=self.tmp_dir)

trainer = build_trainer(default_args=kwargs)
trainer.train()
@@ -80,8 +79,7 @@ class TestTrainerWithNlp(unittest.TestCase):
model=model_id,
train_dataset=self.dataset,
eval_dataset=self.dataset,
work_dir=self.tmp_dir,
model_revision='beta')
work_dir=self.tmp_dir)

trainer = build_trainer(default_args=kwargs)
trainer.train()
@@ -97,7 +95,7 @@ class TestTrainerWithNlp(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_trainer_with_user_defined_config(self):
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
cfg = read_config(model_id, revision='beta')
cfg = read_config(model_id)
cfg.train.max_epochs = 20
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
@@ -108,8 +106,7 @@ class TestTrainerWithNlp(unittest.TestCase):
model=model_id,
train_dataset=self.dataset,
eval_dataset=self.dataset,
cfg_file=cfg_file,
model_revision='beta')
cfg_file=cfg_file)

trainer = build_trainer(default_args=kwargs)
trainer.train()
@@ -233,7 +230,7 @@ class TestTrainerWithNlp(unittest.TestCase):
os.makedirs(tmp_dir)

model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
cache_path = snapshot_download(model_id, revision='beta')
cache_path = snapshot_download(model_id)
model = SbertForSequenceClassification.from_pretrained(cache_path)
kwargs = dict(
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),


Loading…
Cancel
Save