From f3146996e8264efb51ce07a1c5c1c8a05e4d5268 Mon Sep 17 00:00:00 2001 From: ly119399 Date: Thu, 30 Jun 2022 18:58:54 +0800 Subject: [PATCH] merge feat nlp --- modelscope/pipelines/nlp/__init__.py | 1 + .../nlp/zero_shot_classification_pipeline.py | 88 +++++++++++++++++++ modelscope/preprocessors/__init__.py | 7 +- 3 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 modelscope/pipelines/nlp/zero_shot_classification_pipeline.py diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 08a6f825..f600dec0 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -8,3 +8,4 @@ from .sentiment_classification_pipeline import * # noqa F403 from .sequence_classification_pipeline import * # noqa F403 from .text_generation_pipeline import * # noqa F403 from .word_segmentation_pipeline import * # noqa F403 +from .zero_shot_classification_pipeline import * # noqa F403 diff --git a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py new file mode 100644 index 00000000..375e9093 --- /dev/null +++ b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py @@ -0,0 +1,88 @@ +import os +import uuid +from typing import Any, Dict, Union + +import json +import numpy as np +import torch +from scipy.special import softmax + +from ...metainfo import Pipelines +from ...models import Model +from ...models.nlp import SbertForZeroShotClassification +from ...preprocessors import ZeroShotClassificationPreprocessor +from ...utils.constant import Tasks +from ..base import Input, Pipeline +from ..builder import PIPELINES + +__all__ = ['ZeroShotClassificationPipeline'] + + +@PIPELINES.register_module( + Tasks.zero_shot_classification, + module_name=Pipelines.zero_shot_classification) +class ZeroShotClassificationPipeline(Pipeline): + + def __init__(self, + model: Union[SbertForZeroShotClassification, str], + preprocessor: 
ZeroShotClassificationPreprocessor = None, + **kwargs): + """use `model` and `preprocessor` to create an nlp zero-shot classification pipeline for prediction + Args: + model (SbertForZeroShotClassification or str): a model instance, or a str that is passed to `Model.from_pretrained` + preprocessor (ZeroShotClassificationPreprocessor): a preprocessor instance; when None, one is built from `model.model_dir` + """ + assert isinstance(model, str) or isinstance(model, SbertForZeroShotClassification), \ + 'model must be a single str or SbertForZeroShotClassification' + model = model if isinstance( + model, + SbertForZeroShotClassification) else Model.from_pretrained(model) + self.entailment_id = 0 + self.contradiction_id = 2 + if preprocessor is None: + preprocessor = ZeroShotClassificationPreprocessor(model.model_dir) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def _sanitize_parameters(self, **kwargs): + preprocess_params = {} + postprocess_params = {} + if 'candidate_labels' in kwargs: + candidate_labels = kwargs.pop('candidate_labels') + preprocess_params['candidate_labels'] = candidate_labels + postprocess_params['candidate_labels'] = candidate_labels + else: + raise ValueError('You must include at least one label.') + preprocess_params['hypothesis_template'] = kwargs.pop( + 'hypothesis_template', '{}') + postprocess_params['multi_label'] = kwargs.pop('multi_label', False) + return preprocess_params, {}, postprocess_params + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return super().forward(inputs, **forward_params) + + def postprocess(self, + inputs: Dict[str, Any], + candidate_labels, + multi_label=False) -> Dict[str, Any]: + """process the prediction results + Args: + inputs (Dict[str, Any]): the model outputs; the `logits` entry is read to score `candidate_labels` + Returns: + Dict[str, Any]: a dict with `labels` and `scores`, both ordered by descending score + """ + logits = inputs['logits'] + if multi_label or len(candidate_labels) == 1: + logits = logits[..., [self.contradiction_id, self.entailment_id]] + scores = softmax(logits, axis=-1)[..., 1] + else: + logits = 
logits[..., self.entailment_id] + scores = softmax(logits, axis=-1) + reversed_index = list(reversed(scores.argsort())) + result = { + 'labels': [candidate_labels[i] for i in reversed_index], + 'scores': [scores[i].item() for i in reversed_index] + } + return result diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 4b4932c5..742a6152 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# from .audio import LinearAECAndFbank +from .audio import LinearAECAndFbank from .base import Preprocessor -# from .builder import PREPROCESSORS, build_preprocessor +from .builder import PREPROCESSORS, build_preprocessor from .common import Compose from .image import LoadImage, load_image from .kws import WavToLists @@ -11,5 +11,4 @@ from .nlp import * # noqa F403 from .space.dialog_intent_prediction_preprocessor import * # noqa F403 from .space.dialog_modeling_preprocessor import * # noqa F403 from .space.dialog_state_tracking_preprocessor import * # noqa F403 - -# from .text_to_speech import * # noqa F403 +from .text_to_speech import * # noqa F403