Browse Source

Merge branch feat/fix_lazy_import_bug_in_text_cls_pipeline into master

Title: [to #42322933] Fix lazy importing problem in text classification pipeline 

1.  Fix lazy importing problem in text classification pipeline
2.  Ignore some regression-test discrepancies caused by differing versions of transformers
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10614286
master
yingda.chen 2 years ago
parent
commit
69d41ee413
8 changed files with 63 additions and 29 deletions
  1. +10
    -8
      modelscope/pipelines/nlp/text_classification_pipeline.py
  2. +9
    -8
      modelscope/preprocessors/base.py
  3. +18
    -0
      modelscope/utils/regress_test_utils.py
  4. +7
    -3
      tests/pipelines/test_fill_mask.py
  5. +4
    -3
      tests/pipelines/test_nli.py
  6. +4
    -2
      tests/pipelines/test_sentence_similarity.py
  7. +7
    -3
      tests/pipelines/test_word_segmentation.py
  8. +4
    -2
      tests/pipelines/test_zero_shot_classification.py

+ 10
- 8
modelscope/pipelines/nlp/text_classification_pipeline.py View File

@@ -3,14 +3,13 @@ from typing import Any, Dict, Union

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.metainfo import Pipelines, Preprocessors
from modelscope.models.base import Model
from modelscope.models.multi_modal import OfaForAllTasks
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import OfaPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Fields, Tasks


@PIPELINES.register_module(
@@ -58,8 +57,11 @@ class TextClassificationPipeline(Pipeline):
str) else model

if preprocessor is None:
if isinstance(model, OfaForAllTasks):
preprocessor = OfaPreprocessor(model_dir=model.model_dir)
if model.__class__.__name__ == 'OfaForAllTasks':
preprocessor = Preprocessor.from_pretrained(
model_name_or_path=model.model_dir,
type=Preprocessors.ofa_tasks_preprocessor,
field=Fields.multi_modal)
else:
first_sequence = kwargs.pop('first_sequence', 'first_sequence')
second_sequence = kwargs.pop('second_sequence', None)
@@ -76,7 +78,7 @@ class TextClassificationPipeline(Pipeline):

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
if isinstance(self.model, OfaForAllTasks):
if self.model.__class__.__name__ == 'OfaForAllTasks':
return super().forward(inputs, **forward_params)
return self.model(**inputs, **forward_params)

@@ -95,7 +97,7 @@ class TextClassificationPipeline(Pipeline):
labels: The real labels.
Label at index 0 is the smallest probability.
"""
if isinstance(self.model, OfaForAllTasks):
if self.model.__class__.__name__ == 'OfaForAllTasks':
return inputs
else:
assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \


+ 9
- 8
modelscope/preprocessors/base.py View File

@@ -205,10 +205,12 @@ class Preprocessor(ABC):
if 'task' in kwargs:
task = kwargs.pop('task')
field_name = Tasks.find_field_by_task(task)
if 'field' in kwargs:
field_name = kwargs.pop('field')
sub_key = 'train' if preprocessor_mode == ModeKeys.TRAIN else 'val'

if not hasattr(cfg, 'preprocessor'):
logger.error('No preprocessor field found in cfg.')
if not hasattr(cfg, 'preprocessor') or len(cfg.preprocessor) == 0:
logger.warn('No preprocessor field found in cfg.')
preprocessor_cfg = ConfigDict()
else:
preprocessor_cfg = cfg.preprocessor
@@ -217,9 +219,8 @@ class Preprocessor(ABC):
if sub_key in preprocessor_cfg:
sub_cfg = getattr(preprocessor_cfg, sub_key)
else:
logger.error(
f'No {sub_key} key and type key found in '
f'preprocessor domain of configuration.json file.')
logger.warn(f'No {sub_key} key and type key found in '
f'preprocessor domain of configuration.json file.')
sub_cfg = preprocessor_cfg
else:
sub_cfg = preprocessor_cfg
@@ -235,7 +236,7 @@ class Preprocessor(ABC):

preprocessor = build_preprocessor(sub_cfg, field_name)
else:
logger.error(
logger.warn(
f'Cannot find available config to build preprocessor at mode {preprocessor_mode}, '
f'current config: {sub_cfg}. trying to build by task and model information.'
)
@@ -243,13 +244,13 @@ class Preprocessor(ABC):
model_type = model_cfg.type if hasattr(
model_cfg, 'type') else getattr(model_cfg, 'model_type', None)
if task is None or model_type is None:
logger.error(
logger.warn(
f'Find task: {task}, model type: {model_type}. '
f'Insufficient information to build preprocessor, skip building preprocessor'
)
return None
if (model_type, task) not in PREPROCESSOR_MAP:
logger.error(
logger.warn(
f'No preprocessor key {(model_type, task)} found in PREPROCESSOR_MAP, '
f'skip building preprocessor.')
return None


+ 18
- 0
modelscope/utils/regress_test_utils.py View File

@@ -5,6 +5,7 @@ import hashlib
import os
import pickle
import random
import re
import shutil
import tempfile
from collections import OrderedDict
@@ -759,3 +760,20 @@ def compare_cfg_and_optimizers(baseline_json,
state2, **kwargs) and match

return match


class IgnoreKeyFn:
    """Comparison callback that flags keys matching any of the given regexes.

    Instances are passed as ``compare_fn`` to the regression tool. When the
    compared ``key`` fully matches one of the configured patterns the call
    returns ``True``; otherwise it returns ``None``.
    NOTE(review): presumably ``True`` tells the regression tool to treat the
    entry as matching (i.e. ignore it) and ``None`` defers to the default
    comparison -- confirm against the regression tool's handling of
    ``compare_fn`` results.

    Args:
        keys: A single regex string, or a list/tuple of regex strings. Any
            other value results in an empty pattern list (nothing ignored).
    """

    def __init__(self, keys):
        if isinstance(keys, str):
            keys = [keys]
        # Tuples are accepted as well as lists; anything else means
        # "no patterns", matching the original lenient behavior.
        self.keys = list(keys) if isinstance(keys, (list, tuple)) else []
        # Compile once here instead of on every comparison call.
        self._patterns = [re.compile(_key) for _key in self.keys]

    def __call__(self, v1output, v2output, key, type):
        """Return ``True`` if ``key`` fully matches a pattern, else ``None``.

        Args:
            v1output: First compared value (unused).
            v2output: Second compared value (unused).
            key: Name of the compared entry; may be ``None``.
            type: Kind of comparison (unused). The name shadows the builtin
                but is kept so keyword callers continue to work.
        """
        if key is None:
            return None
        for pattern in self._patterns:
            if pattern.fullmatch(key):
                return True
        return None

+ 7
- 3
tests/pipelines/test_fill_mask.py View File

@@ -11,7 +11,7 @@ from modelscope.pipelines.nlp import FillMaskPipeline
from modelscope.preprocessors import NLPPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level


@@ -109,7 +109,9 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(
task=Tasks.fill_mask, model=model, preprocessor=preprocessor)
with self.regress_tool.monitor_module_single_forward(
pipeline_ins.model, f'fill_mask_sbert_{language}'):
pipeline_ins.model,
f'fill_mask_sbert_{language}',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(
f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
f'{pipeline_ins(self.test_inputs[language])}\n')
@@ -124,7 +126,9 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
with self.regress_tool.monitor_module_single_forward(
pipeline_ins.model, f'fill_mask_veco_{language}'):
pipeline_ins.model,
f'fill_mask_veco_{language}',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(
f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
f'{pipeline_ins(test_input)}\n')


+ 4
- 3
tests/pipelines/test_nli.py View File

@@ -3,13 +3,12 @@ import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForSequenceClassification
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TextClassificationPipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level


@@ -48,7 +47,9 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck):
def test_run_with_model_name(self):
pipeline_ins = pipeline(task=Tasks.nli, model=self.model_id)
with self.regress_tool.monitor_module_single_forward(
pipeline_ins.model, 'sbert_nli'):
pipeline_ins.model,
'sbert_nli',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(pipeline_ins(input=(self.sentence1, self.sentence2)))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')


+ 4
- 2
tests/pipelines/test_sentence_similarity.py View File

@@ -9,7 +9,7 @@ from modelscope.pipelines.nlp import TextClassificationPipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level


@@ -54,7 +54,9 @@ class SentenceSimilarityTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(
task=Tasks.sentence_similarity, model=self.model_id)
with self.regress_tool.monitor_module_single_forward(
pipeline_ins.model, 'sbert_sen_sim'):
pipeline_ins.model,
'sbert_sen_sim',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(pipeline_ins(input=(self.sentence1, self.sentence2)))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')


+ 7
- 3
tests/pipelines/test_word_segmentation.py View File

@@ -9,7 +9,7 @@ from modelscope.pipelines.nlp import WordSegmentationPipeline
from modelscope.preprocessors import TokenClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level


@@ -48,10 +48,14 @@ class WordSegmentationTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(
task=Tasks.word_segmentation, model=self.model_id)
with self.regress_tool.monitor_module_single_forward(
pipeline_ins.model, 'sbert_ws_zh'):
pipeline_ins.model,
'sbert_ws_zh',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(pipeline_ins(input=self.sentence))
with self.regress_tool.monitor_module_single_forward(
pipeline_ins.model, 'sbert_ws_en'):
pipeline_ins.model,
'sbert_ws_en',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(pipeline_ins(input=self.sentence_eng))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')


+ 4
- 2
tests/pipelines/test_zero_shot_classification.py View File

@@ -9,7 +9,7 @@ from modelscope.pipelines.nlp import ZeroShotClassificationPipeline
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.regress_test_utils import MsRegressTool
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level


@@ -65,7 +65,9 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(
task=Tasks.zero_shot_classification, model=self.model_id)
with self.regress_tool.monitor_module_single_forward(
pipeline_ins.model, 'sbert_zero_shot'):
pipeline_ins.model,
'sbert_zero_shot',
compare_fn=IgnoreKeyFn('.*intermediate_act_fn')):
print(
pipeline_ins(
input=self.sentence, candidate_labels=self.labels))


Loading…
Cancel
Save