1. Fix: ws regression failed. 2. Fix: label2id missing in text_classification_pipeline when preprocessor is passed in through args. 3. Fix: remove obsolete imports. 4. Fix: incomplete modification. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10936431 (master^2)
@@ -1,3 +1,3 @@ | |||||
version https://git-lfs.github.com/spec/v1 | version https://git-lfs.github.com/spec/v1 | ||||
oid sha256:3b38bfb5a851d35d5fba4d59eda926557666dbd62c70e3e3b24c22605e7d9c4a | |||||
size 40771 | |||||
oid sha256:dc16ad72e753f751360dab82878ec0a31190fb5125632d8f4698f6537fae79cb | |||||
size 40819 |
@@ -79,12 +79,9 @@ class TextClassificationPipeline(Pipeline): | |||||
'sequence_length': sequence_length, | 'sequence_length': sequence_length, | ||||
**kwargs | **kwargs | ||||
}) | }) | ||||
assert hasattr(self.preprocessor, 'id2label') | |||||
self.id2label = self.preprocessor.id2label | |||||
if self.id2label is None: | |||||
logger.warn( | |||||
'The id2label mapping is None, will return original ids.' | |||||
) | |||||
if hasattr(self.preprocessor, 'id2label'): | |||||
self.id2label = self.preprocessor.id2label | |||||
def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
**forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
@@ -111,6 +108,9 @@ class TextClassificationPipeline(Pipeline): | |||||
if self.model.__class__.__name__ == 'OfaForAllTasks': | if self.model.__class__.__name__ == 'OfaForAllTasks': | ||||
return inputs | return inputs | ||||
else: | else: | ||||
if getattr(self, 'id2label', None) is None: | |||||
logger.warn( | |||||
'The id2label mapping is None, will return original ids.') | |||||
logits = inputs[OutputKeys.LOGITS].cpu().numpy() | logits = inputs[OutputKeys.LOGITS].cpu().numpy() | ||||
if logits.shape[0] == 1: | if logits.shape[0] == 1: | ||||
logits = logits[0] | logits = logits[0] | ||||
@@ -126,7 +126,7 @@ class TextClassificationPipeline(Pipeline): | |||||
probs = np.take_along_axis(probs, top_indices, axis=-1).tolist() | probs = np.take_along_axis(probs, top_indices, axis=-1).tolist() | ||||
def map_to_label(id): | def map_to_label(id): | ||||
if self.id2label is not None: | |||||
if getattr(self, 'id2label', None) is not None: | |||||
if id in self.id2label: | if id in self.id2label: | ||||
return self.id2label[id] | return self.id2label[id] | ||||
elif str(id) in self.id2label: | elif str(id) in self.id2label: | ||||
@@ -30,10 +30,6 @@ if TYPE_CHECKING: | |||||
from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor | from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor | ||||
else: | else: | ||||
_import_structure = { | _import_structure = { | ||||
'nlp_base': [ | |||||
'NLPTokenizerPreprocessorBase', | |||||
'NLPBasePreprocessor', | |||||
], | |||||
'sentence_piece_preprocessor': ['SentencePiecePreprocessor'], | 'sentence_piece_preprocessor': ['SentencePiecePreprocessor'], | ||||
'bert_seq_cls_tokenizer': ['Tokenize'], | 'bert_seq_cls_tokenizer': ['Tokenize'], | ||||
'document_segmentation_preprocessor': | 'document_segmentation_preprocessor': | ||||
@@ -119,6 +119,6 @@ class FaqQuestionAnsweringTransformersPreprocessor(Preprocessor): | |||||
def batch_encode(self, sentence_list: list, max_length=None): | def batch_encode(self, sentence_list: list, max_length=None): | ||||
if not max_length: | if not max_length: | ||||
max_length = self.MAX_LEN | |||||
max_length = self.max_len | |||||
return self.tokenizer.batch_encode_plus( | return self.tokenizer.batch_encode_plus( | ||||
sentence_list, padding=True, max_length=max_length) | sentence_list, padding=True, max_length=max_length) |
@@ -555,7 +555,7 @@ if __name__ == '__main__': | |||||
nargs='*', | nargs='*', | ||||
help='Run specified test suites(test suite files list split by space)') | help='Run specified test suites(test suite files list split by space)') | ||||
args = parser.parse_args() | args = parser.parse_args() | ||||
set_test_level(2) | |||||
set_test_level(args.level) | |||||
os.environ['REGRESSION_BASELINE'] = '1' | os.environ['REGRESSION_BASELINE'] = '1' | ||||
logger.info(f'TEST LEVEL: {test_level()}') | logger.info(f'TEST LEVEL: {test_level()}') | ||||
if not args.disable_profile: | if not args.disable_profile: | ||||