From 0e4766f41d79a852e573c0251f71a7276380e77d Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Thu, 1 Dec 2022 21:16:55 +0800
Subject: [PATCH] Fix bugs in test level 1 & 2

1. Fix: ws regression test failed.
2. Fix: label2id missing in text_classification_pipeline when the preprocessor is passed in through args.
3. Fix: remove obsolete imports.
4. Fix: incomplete modification.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10936431
---
 data/test/regression/sbert_ws_zh.bin                |  4 ++--
 .../pipelines/nlp/text_classification_pipeline.py   | 14 +++++++-------
 modelscope/preprocessors/nlp/__init__.py            |  4 ----
 .../nlp/faq_question_answering_preprocessor.py      |  2 +-
 tests/run.py                                        |  2 +-
 5 files changed, 11 insertions(+), 15 deletions(-)

diff --git a/data/test/regression/sbert_ws_zh.bin b/data/test/regression/sbert_ws_zh.bin
index ed753e50..469a13f9 100644
--- a/data/test/regression/sbert_ws_zh.bin
+++ b/data/test/regression/sbert_ws_zh.bin
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3b38bfb5a851d35d5fba4d59eda926557666dbd62c70e3e3b24c22605e7d9c4a
-size 40771
+oid sha256:dc16ad72e753f751360dab82878ec0a31190fb5125632d8f4698f6537fae79cb
+size 40819
diff --git a/modelscope/pipelines/nlp/text_classification_pipeline.py b/modelscope/pipelines/nlp/text_classification_pipeline.py
index 24c07d69..845e8315 100644
--- a/modelscope/pipelines/nlp/text_classification_pipeline.py
+++ b/modelscope/pipelines/nlp/text_classification_pipeline.py
@@ -79,12 +79,9 @@ class TextClassificationPipeline(Pipeline):
                     'sequence_length': sequence_length,
                     **kwargs
                 })
-        assert hasattr(self.preprocessor, 'id2label')
-        self.id2label = self.preprocessor.id2label
-        if self.id2label is None:
-            logger.warn(
-                'The id2label mapping is None, will return original ids.'
-            )
+
+        if hasattr(self.preprocessor, 'id2label'):
+            self.id2label = self.preprocessor.id2label
 
     def forward(self, inputs: Dict[str, Any],
                 **forward_params) -> Dict[str, Any]:
@@ -111,6 +108,9 @@ class TextClassificationPipeline(Pipeline):
         if self.model.__class__.__name__ == 'OfaForAllTasks':
             return inputs
         else:
+            if getattr(self, 'id2label', None) is None:
+                logger.warn(
+                    'The id2label mapping is None, will return original ids.')
             logits = inputs[OutputKeys.LOGITS].cpu().numpy()
             if logits.shape[0] == 1:
                 logits = logits[0]
@@ -126,7 +126,7 @@ class TextClassificationPipeline(Pipeline):
             probs = np.take_along_axis(probs, top_indices, axis=-1).tolist()
 
             def map_to_label(id):
-                if self.id2label is not None:
+                if getattr(self, 'id2label', None) is not None:
                     if id in self.id2label:
                         return self.id2label[id]
                     elif str(id) in self.id2label:
diff --git a/modelscope/preprocessors/nlp/__init__.py b/modelscope/preprocessors/nlp/__init__.py
index 5f23fb27..8ee9a80c 100644
--- a/modelscope/preprocessors/nlp/__init__.py
+++ b/modelscope/preprocessors/nlp/__init__.py
@@ -30,10 +30,6 @@ if TYPE_CHECKING:
     from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor
 else:
     _import_structure = {
-        'nlp_base': [
-            'NLPTokenizerPreprocessorBase',
-            'NLPBasePreprocessor',
-        ],
         'sentence_piece_preprocessor': ['SentencePiecePreprocessor'],
         'bert_seq_cls_tokenizer': ['Tokenize'],
         'document_segmentation_preprocessor':
diff --git a/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py b/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py
index bfff3885..bdf8b30f 100644
--- a/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py
+++ b/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py
@@ -119,6 +119,6 @@ class FaqQuestionAnsweringTransformersPreprocessor(Preprocessor):
 
     def batch_encode(self, sentence_list: list, max_length=None):
         if not max_length:
-            max_length = self.MAX_LEN
+            max_length = self.max_len
         return self.tokenizer.batch_encode_plus(
             sentence_list, padding=True, max_length=max_length)
diff --git a/tests/run.py b/tests/run.py
index e7fae5a2..1b252756 100644
--- a/tests/run.py
+++ b/tests/run.py
@@ -555,7 +555,7 @@ if __name__ == '__main__':
         nargs='*',
         help='Run specified test suites(test suite files list split by space)')
     args = parser.parse_args()
-    set_test_level(2)
+    set_test_level(args.level)
     os.environ['REGRESSION_BASELINE'] = '1'
     logger.info(f'TEST LEVEL: {test_level()}')
     if not args.disable_profile:
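
Note on the pipeline hunks above: the patch replaces a hard assert on preprocessor.id2label with a lazy getattr lookup, so a preprocessor passed in through args no longer has to expose the mapping; the warning and the fallback to raw ids are deferred to postprocess time. Below is a minimal standalone sketch of that pattern, assuming illustrative names (DummyTextClsPipeline, the sample mapping) that are not taken from the repository.

# Minimal sketch of the id2label fallback pattern; names here are
# illustrative assumptions, not code copied from the patched files.
import logging

logger = logging.getLogger(__name__)


class DummyTextClsPipeline:

    def __init__(self, preprocessor):
        # Keep id2label only when the preprocessor actually provides it,
        # instead of asserting its presence at construction time.
        if hasattr(preprocessor, 'id2label'):
            self.id2label = preprocessor.id2label

    def map_to_label(self, id):
        # Warn lazily at postprocess time and fall back to the raw id.
        if getattr(self, 'id2label', None) is None:
            logger.warning(
                'The id2label mapping is None, will return original ids.')
            return id
        if id in self.id2label:
            return self.id2label[id]
        elif str(id) in self.id2label:
            return self.id2label[str(id)]
        return id

With a preprocessor that lacks id2label, map_to_label(1) logs the warning and returns 1; with one exposing id2label={1: 'positive'}, it returns 'positive'.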