
merge feat nlp

master
ly119399 committed 3 years ago
commit f3146996e8
3 changed files with 92 additions and 4 deletions
  1. modelscope/pipelines/nlp/__init__.py (+1, -0)
  2. modelscope/pipelines/nlp/zero_shot_classification_pipeline.py (+88, -0)
  3. modelscope/preprocessors/__init__.py (+3, -4)

modelscope/pipelines/nlp/__init__.py (+1, -0)

@@ -8,3 +8,4 @@ from .sentiment_classification_pipeline import * # noqa F403
from .sequence_classification_pipeline import * # noqa F403
from .text_generation_pipeline import * # noqa F403
from .word_segmentation_pipeline import * # noqa F403
from .zero_shot_classification_pipeline import * # noqa F403
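
With the star-import in place, the new pipeline class can be imported straight from the subpackage; a minimal sketch, assuming the package is installed:

    from modelscope.pipelines.nlp import ZeroShotClassificationPipeline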

modelscope/pipelines/nlp/zero_shot_classification_pipeline.py (+88, -0)

@@ -0,0 +1,88 @@
import os
import uuid
from typing import Any, Dict, Union

import json
import numpy as np
import torch
from scipy.special import softmax

from ...metainfo import Pipelines
from ...models import Model
from ...models.nlp import SbertForZeroShotClassification
from ...preprocessors import ZeroShotClassificationPreprocessor
from ...utils.constant import Tasks
from ..base import Input, Pipeline
from ..builder import PIPELINES

__all__ = ['ZeroShotClassificationPipeline']


@PIPELINES.register_module(
    Tasks.zero_shot_classification,
    module_name=Pipelines.zero_shot_classification)
class ZeroShotClassificationPipeline(Pipeline):

    def __init__(self,
                 model: Union[SbertForZeroShotClassification, str],
                 preprocessor: ZeroShotClassificationPreprocessor = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create an NLP zero-shot classification pipeline for prediction.

        Args:
            model (SbertForZeroShotClassification): a model instance or a model id/path
            preprocessor (ZeroShotClassificationPreprocessor): a preprocessor instance
        """
        assert isinstance(model, str) or isinstance(model, SbertForZeroShotClassification), \
            'model must be a single str or SbertForZeroShotClassification'
        model = model if isinstance(
            model,
            SbertForZeroShotClassification) else Model.from_pretrained(model)
        # NLI label indices used by postprocess to score candidate labels.
        self.entailment_id = 0
        self.contradiction_id = 2
        if preprocessor is None:
            preprocessor = ZeroShotClassificationPreprocessor(model.model_dir)
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        preprocess_params = {}
        postprocess_params = {}
        if 'candidate_labels' in kwargs:
            candidate_labels = kwargs.pop('candidate_labels')
            preprocess_params['candidate_labels'] = candidate_labels
            postprocess_params['candidate_labels'] = candidate_labels
        else:
            raise ValueError('You must include at least one label.')
        preprocess_params['hypothesis_template'] = kwargs.pop(
            'hypothesis_template', '{}')
        postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
        return preprocess_params, {}, postprocess_params

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        with torch.no_grad():
            return super().forward(inputs, **forward_params)

    def postprocess(self,
                    inputs: Dict[str, Any],
                    candidate_labels,
                    multi_label=False) -> Dict[str, Any]:
        """Process the prediction results.

        Args:
            inputs (Dict[str, Any]): the model outputs, containing the NLI `logits`
            candidate_labels: the labels passed through `_sanitize_parameters`
            multi_label (bool): score each label independently when True

        Returns:
            Dict[str, Any]: the prediction results
        """
        logits = inputs['logits']
        if multi_label or len(candidate_labels) == 1:
            # Score each label on its own: softmax over (contradiction, entailment)
            # and keep the entailment probability.
            logits = logits[..., [self.contradiction_id, self.entailment_id]]
            scores = softmax(logits, axis=-1)[..., 1]
        else:
            # Single-label: softmax the entailment logits across all candidate labels.
            logits = logits[..., self.entailment_id]
            scores = softmax(logits, axis=-1)
        reversed_index = list(reversed(scores.argsort()))
        result = {
            'labels': [candidate_labels[i] for i in reversed_index],
            'scores': [scores[i].item() for i in reversed_index]
        }
        return result
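
A minimal usage sketch (not part of this commit), assuming a suitable zero-shot classification model id or local checkpoint is substituted for the placeholder; `candidate_labels`, `hypothesis_template` and `multi_label` are routed through `_sanitize_parameters` above:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # '<zero-shot-model-id-or-path>' is a placeholder, not a real model id.
    classifier = pipeline(
        Tasks.zero_shot_classification, model='<zero-shot-model-id-or-path>')

    result = classifier(
        'The new phone ships with a much better camera.',
        candidate_labels=['technology', 'sports', 'politics'],
        multi_label=False)
    print(result['labels'], result['scores'])

With `multi_label=False` the entailment logits are softmaxed across all candidate labels, so the scores sum to 1; with `multi_label=True` each label gets an independent entailment probability.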

modelscope/preprocessors/__init__.py (+3, -4)

@@ -1,8 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

# from .audio import LinearAECAndFbank
from .audio import LinearAECAndFbank
from .base import Preprocessor
# from .builder import PREPROCESSORS, build_preprocessor
from .builder import PREPROCESSORS, build_preprocessor
from .common import Compose
from .image import LoadImage, load_image
from .kws import WavToLists
@@ -11,5 +11,4 @@ from .nlp import * # noqa F403
from .space.dialog_intent_prediction_preprocessor import * # noqa F403
from .space.dialog_modeling_preprocessor import * # noqa F403
from .space.dialog_state_tracking_preprocessor import * # noqa F403

# from .text_to_speech import * # noqa F403
from .text_to_speech import * # noqa F403
