Browse Source

[to #42322933] refine some comments

Refine some comments.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9674650
master
yuze.zyz yingda.chen 3 years ago
parent
commit
1640df567b
25 changed files with 385 additions and 46 deletions
  1. +3
    -1
      modelscope/metrics/sequence_classification_metric.py
  2. +2
    -0
      modelscope/metrics/text_generation_metric.py
  3. +4
    -2
      modelscope/metrics/token_classification_metric.py
  4. +0
    -2
      modelscope/models/nlp/backbones/structbert.py
  5. +12
    -0
      modelscope/models/nlp/masked_language.py
  6. +6
    -0
      modelscope/models/nlp/nncrf_for_named_entity_recognition.py
  7. +32
    -0
      modelscope/models/nlp/sequence_classification.py
  8. +24
    -1
      modelscope/models/nlp/token_classification.py
  9. +4
    -3
      modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py
  10. +3
    -2
      modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py
  11. +21
    -3
      modelscope/pipelines/nlp/fill_mask_pipeline.py
  12. +18
    -0
      modelscope/pipelines/nlp/named_entity_recognition_pipeline.py
  13. +22
    -3
      modelscope/pipelines/nlp/pair_sentence_classification_pipeline.py
  14. +3
    -3
      modelscope/pipelines/nlp/sequence_classification_pipeline_base.py
  15. +21
    -3
      modelscope/pipelines/nlp/single_sentence_classification_pipeline.py
  16. +5
    -3
      modelscope/pipelines/nlp/summarization_pipeline.py
  17. +4
    -3
      modelscope/pipelines/nlp/task_oriented_conversation_pipeline.py
  18. +13
    -3
      modelscope/pipelines/nlp/text_error_correction_pipeline.py
  19. +24
    -4
      modelscope/pipelines/nlp/text_generation_pipeline.py
  20. +6
    -1
      modelscope/pipelines/nlp/translation_pipeline.py
  21. +19
    -4
      modelscope/pipelines/nlp/word_segmentation_pipeline.py
  22. +29
    -3
      modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
  23. +64
    -2
      modelscope/preprocessors/nlp.py
  24. +28
    -0
      modelscope/trainers/nlp_trainer.py
  25. +18
    -0
      modelscope/utils/hub.py

+ 3
- 1
modelscope/metrics/sequence_classification_metric.py View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Union
from typing import Dict

import numpy as np

@@ -15,6 +15,8 @@ from .builder import METRICS, MetricKeys
group_key=default_group, module_name=Metrics.seq_cls_metric)
class SequenceClassificationMetric(Metric):
"""The metric computation class for sequence classification classes.

This metric class calculates accuracy for the whole input batches.
"""

def __init__(self, *args, **kwargs):


+ 2
- 0
modelscope/metrics/text_generation_metric.py View File

@@ -10,6 +10,8 @@ from .builder import METRICS, MetricKeys
group_key=default_group, module_name=Metrics.text_gen_metric)
class TextGenerationMetric(Metric):
"""The metric computation class for text generation classes.

This metric class calculates F1 of the rouge scores for the whole evaluation dataset.
"""

def __init__(self):


+ 4
- 2
modelscope/metrics/token_classification_metric.py View File

@@ -14,8 +14,10 @@ from .builder import METRICS, MetricKeys
@METRICS.register_module(
group_key=default_group, module_name=Metrics.token_cls_metric)
class TokenClassificationMetric(Metric):
"""
The metric computation class for token-classification task.
"""The metric computation class for token-classification task.

This metric class uses seqeval to calculate the scores.

Args:
return_entity_level_metrics (bool, *optional*):
Whether to return every label's detail metrics, default False.


+ 0
- 2
modelscope/models/nlp/backbones/structbert.py View File

@@ -1,5 +1,3 @@
from transformers import PreTrainedModel

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import BACKBONES


+ 12
- 0
modelscope/models/nlp/masked_language.py View File

@@ -17,6 +17,10 @@ __all__ = ['BertForMaskedLM', 'StructBertForMaskedLM', 'VecoForMaskedLM']

@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)
class StructBertForMaskedLM(TorchModel, SbertForMaskedLM):
"""Structbert for MLM model.

Inherited from structbert.SbertForMaskedLM and TorchModel, so this class can be registered into Model sets.
"""

def __init__(self, config, model_dir):
super(TorchModel, self).__init__(model_dir)
@@ -49,6 +53,10 @@ class StructBertForMaskedLM(TorchModel, SbertForMaskedLM):

@MODELS.register_module(Tasks.fill_mask, module_name=Models.bert)
class BertForMaskedLM(TorchModel, BertForMaskedLMTransformer):
"""Bert for MLM model.

Inherited from transformers.BertForMaskedLM and TorchModel, so this class can be registered into Model sets.
"""

def __init__(self, config, model_dir):
super(TorchModel, self).__init__(model_dir)
@@ -83,6 +91,10 @@ class BertForMaskedLM(TorchModel, BertForMaskedLMTransformer):

@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco)
class VecoForMaskedLM(TorchModel, VecoForMaskedLMTransformer):
"""Veco for MLM model.

Inherited from veco.VecoForMaskedLM and TorchModel, so this class can be registered into Model sets.
"""

def __init__(self, config, model_dir):
super(TorchModel, self).__init__(model_dir)


+ 6
- 0
modelscope/models/nlp/nncrf_for_named_entity_recognition.py View File

@@ -16,6 +16,8 @@ __all__ = ['TransformerCRFForNamedEntityRecognition']
@MODELS.register_module(
Tasks.named_entity_recognition, module_name=Models.tcrf)
class TransformerCRFForNamedEntityRecognition(TorchModel):
"""This model wraps the TransformerCRF model to register into model sets.
"""

def __init__(self, model_dir, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
@@ -63,6 +65,10 @@ class TransformerCRFForNamedEntityRecognition(TorchModel):


class TransformerCRF(nn.Module):
"""A transformer based model to NER tasks.

This model will use transformers' backbones as its backbone.
"""

def __init__(self, model_dir, num_labels, **kwargs):
super(TransformerCRF, self).__init__()


+ 32
- 0
modelscope/models/nlp/sequence_classification.py View File

@@ -18,6 +18,8 @@ __all__ = ['SbertForSequenceClassification', 'VecoForSequenceClassification']


class SequenceClassificationBase(TorchModel):
"""A sequence classification base class for all the fitted sequence classification models.
"""
base_model_prefix: str = 'bert'

def __init__(self, config, model_dir):
@@ -81,6 +83,10 @@ class SequenceClassificationBase(TorchModel):
Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForSequenceClassification(SequenceClassificationBase,
SbertPreTrainedModel):
"""Sbert sequence classification model.

Inherited from SequenceClassificationBase.
"""
base_model_prefix: str = 'bert'
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r'position_ids']
@@ -108,6 +114,16 @@ class SbertForSequenceClassification(SequenceClassificationBase,

@classmethod
def _instantiate(cls, **kwargs):
"""Instantiate the model.

@param kwargs: Input args.
model_dir: The model dir used to load the checkpoint and the label information.
num_labels: An optional arg to tell the model how many classes to initialize.
Method will call utils.parse_label_mapping if num_labels not supplied.
If num_labels is not found, the model will use the default setting (2 classes).
@return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
"""

model_dir = kwargs.get('model_dir')
num_labels = kwargs.get('num_labels')
if num_labels is None:
@@ -129,6 +145,12 @@ class SbertForSequenceClassification(SequenceClassificationBase,
@MODELS.register_module(Tasks.nli, module_name=Models.veco)
class VecoForSequenceClassification(TorchModel,
VecoForSequenceClassificationTransform):
"""Veco sequence classification model.

Inherited from VecoForSequenceClassification and TorchModel, so this class can be registered into the model set.
This model cannot be inherited from SequenceClassificationBase, because Veco/XlmRoberta's classification structure
is different.
"""

def __init__(self, config, model_dir):
super().__init__(model_dir)
@@ -159,6 +181,16 @@ class VecoForSequenceClassification(TorchModel,

@classmethod
def _instantiate(cls, **kwargs):
"""Instantiate the model.

@param kwargs: Input args.
model_dir: The model dir used to load the checkpoint and the label information.
num_labels: An optional arg to tell the model how many classes to initialize.
Method will call utils.parse_label_mapping if num_labels not supplied.
If num_labels is not found, the model will use the default setting (2 classes).
@return: The loaded model, which is initialized by veco.VecoForSequenceClassification.from_pretrained
"""

model_dir = kwargs.get('model_dir')
num_labels = kwargs.get('num_labels')
if num_labels is None:


+ 24
- 1
modelscope/models/nlp/token_classification.py View File

@@ -19,6 +19,8 @@ __all__ = ['SbertForTokenClassification']


class TokenClassification(TorchModel):
"""A token classification base class for all the fitted token classification models.
"""

base_model_prefix: str = 'bert'

@@ -56,7 +58,7 @@ class TokenClassification(TorchModel):
labels: The labels
**kwargs: Other input params.

Returns: Loss.
Returns: The loss.

"""
pass
@@ -92,6 +94,10 @@ class TokenClassification(TorchModel):
@MODELS.register_module(
Tasks.token_classification, module_name=Models.structbert)
class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):
"""Sbert token classification model.

Inherited from TokenClassification.
"""

supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r'pooler']
@@ -118,6 +124,14 @@ class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):
labels=labels)

def compute_loss(self, logits, labels, attention_mask=None, **kwargs):
"""Compute the loss with an attention mask.

@param logits: The logits output from the classifier.
@param labels: The labels.
@param attention_mask: The attention_mask.
@param kwargs: Unused input args.
@return: The loss
"""
loss_fct = nn.CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
@@ -132,6 +146,15 @@ class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):

@classmethod
def _instantiate(cls, **kwargs):
"""Instantiate the model.

@param kwargs: Input args.
model_dir: The model dir used to load the checkpoint and the label information.
num_labels: An optional arg to tell the model how many classes to initialize.
Method will call utils.parse_label_mapping if num_labels not supplied.
If num_labels is not found, the model will use the default setting (2 classes).
@return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
"""
model_dir = kwargs.get('model_dir')
num_labels = kwargs.get('num_labels')
if num_labels is None:


+ 4
- 3
modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py View File

@@ -23,11 +23,12 @@ class DialogIntentPredictionPipeline(Pipeline):
model: Union[SpaceForDialogIntent, str],
preprocessor: DialogIntentPredictionPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a dialog intent prediction pipeline
"""Use `model` and `preprocessor` to create a dialog intent prediction pipeline

Args:
model (SpaceForDialogIntent): a model instance
preprocessor (DialogIntentPredictionPreprocessor): a preprocessor instance
model (str or SpaceForDialogIntent): Supply either a local model dir or a model id from the model hub,
or a SpaceForDialogIntent instance.
preprocessor (DialogIntentPredictionPreprocessor): An optional preprocessor instance.
"""
model = model if isinstance(
model, SpaceForDialogIntent) else Model.from_pretrained(model)


+ 3
- 2
modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py View File

@@ -24,8 +24,9 @@ class DialogStateTrackingPipeline(Pipeline):
observation of dialog states tracking after many turns of open domain dialogue

Args:
model (SpaceForDialogStateTracking): a model instance
preprocessor (DialogStateTrackingPreprocessor): a preprocessor instance
model (str or SpaceForDialogStateTracking): Supply either a local model dir or a model id
from the model hub, or a SpaceForDialogStateTracking instance.
preprocessor (DialogStateTrackingPreprocessor): An optional preprocessor instance.
"""

model = model if isinstance(


+ 21
- 3
modelscope/pipelines/nlp/fill_mask_pipeline.py View File

@@ -24,11 +24,29 @@ class FillMaskPipeline(Pipeline):
preprocessor: Optional[Preprocessor] = None,
first_sequence='sentence',
**kwargs):
"""use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction
"""Use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction

Args:
model (Model): a model instance
preprocessor (Preprocessor): a preprocessor instance
model (str or Model): Supply either a local model dir which supported mlm task, or a
mlm model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
first_sequence: The key to read the sentence in.
sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.

NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence'
param will have no effect.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline('fill-mask', model='damo/nlp_structbert_fill-mask_english-large')
>>> input = 'Everything in [MASK] you call reality is really [MASK] a reflection of your [MASK].'
>>> print(pipeline_ins(input))

NOTE2: Please pay attention to the model's special tokens.
If bert based model(bert, structbert, etc.) is used, the mask token is '[MASK]'.
If the xlm-roberta(xlm-roberta, veco, etc.) based model is used, the mask token is '<mask>'.
To view other examples plese check the tests/pipelines/test_fill_mask.py.
"""
fill_mask_model = model if isinstance(
model, Model) else Model.from_pretrained(model)


+ 18
- 0
modelscope/pipelines/nlp/named_entity_recognition_pipeline.py View File

@@ -22,6 +22,24 @@ class NamedEntityRecognitionPipeline(Pipeline):
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
**kwargs):
"""Use `model` and `preprocessor` to create a nlp NER pipeline for prediction

Args:
model (str or Model): Supply either a local model dir which supported NER task, or a
model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(task='named-entity-recognition',
>>> model='damo/nlp_raner_named-entity-recognition_chinese-base-news')
>>> input = '这与温岭市新河镇的一个神秘的传说有关。'
>>> print(pipeline_ins(input))

To view other examples plese check the tests/pipelines/test_named_entity_recognition.py.
"""

model = model if isinstance(model,
Model) else Model.from_pretrained(model)


+ 22
- 3
modelscope/pipelines/nlp/pair_sentence_classification_pipeline.py View File

@@ -23,11 +23,30 @@ class PairSentenceClassificationPipeline(SequenceClassificationPipelineBase):
first_sequence='first_sequence',
second_sequence='second_sequence',
**kwargs):
"""use `model` and `preprocessor` to create a nlp pair sentence classification pipeline for prediction
"""Use `model` and `preprocessor` to create a nlp pair sequence classification pipeline for prediction.

Args:
model (Model): a model instance
preprocessor (Preprocessor): a preprocessor instance
model (str or Model): Supply either a local model dir which supported the sequence classification task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
first_sequence: The key to read the first sentence in.
second_sequence: The key to read the second sentence in.
sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value.

NOTE: Inputs of type 'tuple' or 'list' are also supported. In this scenario, the 'first_sequence' and
'second_sequence' param will have no effect.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(task='nli', model='damo/nlp_structbert_nli_chinese-base')
>>> sentence1 = '四川商务职业学院和四川财经职业学院哪个好?'
>>> sentence2 = '四川商务职业学院商务管理在哪个校区?'
>>> print(pipeline_ins((sentence1, sentence2)))
>>> # Or use the dict input:
>>> print(pipeline_ins({'first_sequence': sentence1, 'second_sequence': sentence2}))

To view other examples plese check the tests/pipelines/test_nli.py.
"""
if preprocessor is None:
preprocessor = PairSentenceClassificationPreprocessor(


+ 3
- 3
modelscope/pipelines/nlp/sequence_classification_pipeline_base.py View File

@@ -13,11 +13,11 @@ class SequenceClassificationPipelineBase(Pipeline):

def __init__(self, model: Union[Model, str], preprocessor: Preprocessor,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
"""This is the base class for all the sequence classification sub-tasks.

Args:
model (str or Model): a model instance
preprocessor (Preprocessor): a preprocessor instance
model (str or Model): A model instance or a model local dir or a model id in the model hub.
preprocessor (Preprocessor): a preprocessor instance, must not be None.
"""
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or Model'


+ 21
- 3
modelscope/pipelines/nlp/single_sentence_classification_pipeline.py View File

@@ -22,11 +22,29 @@ class SingleSentenceClassificationPipeline(SequenceClassificationPipelineBase):
preprocessor: Preprocessor = None,
first_sequence='first_sequence',
**kwargs):
"""use `model` and `preprocessor` to create a nlp single sentence classification pipeline for prediction
"""Use `model` and `preprocessor` to create a nlp single sequence classification pipeline for prediction.

Args:
model (Model): a model instance
preprocessor (Preprocessor): a preprocessor instance
model (str or Model): Supply either a local model dir which supported the sequence classification task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
first_sequence: The key to read the first sentence in.
sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value.

NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence'
param will have no effect.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(task='sentiment-classification',
>>> model='damo/nlp_structbert_sentiment-classification_chinese-base')
>>> sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音'
>>> print(pipeline_ins(sentence1))
>>> # Or use the dict input:
>>> print(pipeline_ins({'first_sequence': sentence1}))

To view other examples plese check the tests/pipelines/test_sentiment-classification.py.
"""
if preprocessor is None:
preprocessor = SingleSentenceClassificationPreprocessor(


+ 5
- 3
modelscope/pipelines/nlp/summarization_pipeline.py View File

@@ -20,10 +20,12 @@ class SummarizationPipeline(Pipeline):
model: Union[Model, str],
preprocessor: [Preprocessor] = None,
**kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
"""Use `model` and `preprocessor` to create a Summarization pipeline for prediction.
Args:
model: model id on modelscope hub.
model (str or Model): Supply either a local model dir which supported the summarization task,
or a model id from the model hub, or a model instance.
preprocessor (Preprocessor): An optional preprocessor instance.
"""
super().__init__(model=model)
assert isinstance(model, str) or isinstance(model, Model), \


+ 4
- 3
modelscope/pipelines/nlp/task_oriented_conversation_pipeline.py View File

@@ -23,11 +23,12 @@ class TaskOrientedConversationPipeline(Pipeline):
model: Union[SpaceForDialogModeling, str],
preprocessor: DialogModelingPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a dialog modeling pipeline for dialog response generation
"""Use `model` and `preprocessor` to create a dialog modeling pipeline for dialog response generation

Args:
model (SpaceForDialogModeling): a model instance
preprocessor (DialogModelingPreprocessor): a preprocessor instance
model (str or SpaceForDialogModeling): Supply either a local model dir or a model id from the model hub,
or a SpaceForDialogModeling instance.
preprocessor (DialogModelingPreprocessor): An optional preprocessor instance.
"""
model = model if isinstance(
model, SpaceForDialogModeling) else Model.from_pretrained(model)


+ 13
- 3
modelscope/pipelines/nlp/text_error_correction_pipeline.py View File

@@ -23,12 +23,22 @@ class TextErrorCorrectionPipeline(Pipeline):
model: Union[BartForTextErrorCorrection, str],
preprocessor: Optional[TextErrorCorrectionPreprocessor] = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text generation pipeline for prediction
"""use `model` and `preprocessor` to create a nlp text correction pipeline.

Args:
model (BartForTextErrorCorrection): a model instance
preprocessor (TextErrorCorrectionPreprocessor): a preprocessor instance
model (BartForTextErrorCorrection): A model instance, or a model local dir, or a model id in the model hub.
preprocessor (TextErrorCorrectionPreprocessor): An optional preprocessor instance.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(
>>> task='text-error-correction', model='damo/nlp_bart_text-error-correction_chinese')
>>> sentence1 = '随着中国经济突飞猛近,建造工业与日俱增'
>>> print(pipeline_ins(sentence1))

To view other examples plese check the tests/pipelines/test_text_error_correction.py.
"""

model = model if isinstance(
model,
BartForTextErrorCorrection) else Model.from_pretrained(model)


+ 24
- 4
modelscope/pipelines/nlp/text_generation_pipeline.py View File

@@ -19,19 +19,39 @@ class TextGenerationPipeline(Pipeline):
def __init__(self,
model: Union[Model, str],
preprocessor: Optional[TextGenerationPreprocessor] = None,
first_sequence='sentence',
**kwargs):
"""use `model` and `preprocessor` to create a nlp text generation pipeline for prediction
"""Use `model` and `preprocessor` to create a generation pipeline for prediction.

Args:
model (PalmForTextGeneration): a model instance
preprocessor (TextGenerationPreprocessor): a preprocessor instance
model (str or Model): Supply either a local model dir which supported the text generation task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
first_sequence: The key to read the first sentence in.
sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.

NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence'
param will have no effect.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(task='text-generation',
>>> model='damo/nlp_palm2.0_text-generation_chinese-base')
>>> sentence1 = '本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:'
>>> '1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代'
>>> print(pipeline_ins(sentence1))
>>> # Or use the dict input:
>>> print(pipeline_ins({'sentence': sentence1}))

To view other examples plese check the tests/pipelines/test_text_generation.py.
"""
model = model if isinstance(model,
Model) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = TextGenerationPreprocessor(
model.model_dir,
first_sequence='sentence',
first_sequence=first_sequence,
second_sequence=None,
sequence_length=kwargs.pop('sequence_length', 128))
model.eval()


+ 6
- 1
modelscope/pipelines/nlp/translation_pipeline.py View File

@@ -5,6 +5,7 @@ import numpy as np
import tensorflow as tf

from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
@@ -25,7 +26,11 @@ __all__ = ['TranslationPipeline']
Tasks.translation, module_name=Pipelines.csanmt_translation)
class TranslationPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
def __init__(self, model: Model, **kwargs):
"""Build a translation pipeline with a model dir or a model id in the model hub.

@param model: A Model instance.
"""
super().__init__(model=model)
model = self.model.model_dir
tf.reset_default_graph()


+ 19
- 4
modelscope/pipelines/nlp/word_segmentation_pipeline.py View File

@@ -5,7 +5,7 @@ import torch
from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import (Preprocessor,
TokenClassificationPreprocessor)
@@ -22,11 +22,26 @@ class WordSegmentationPipeline(Pipeline):
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction
"""Use `model` and `preprocessor` to create a nlp word segment pipeline for prediction.

Args:
model (Model): a model instance
preprocessor (Preprocessor): a preprocessor instance
model (str or Model): Supply either a local model dir which supported the WS task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.

NOTE: The preprocessor will first split the sentence into single characters,
then feed them into the tokenizer with the parameter is_split_into_words=True.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(task='word-segmentation',
>>> model='damo/nlp_structbert_word-segmentation_chinese-base')
>>> sentence1 = '今天天气不错,适合出去游玩'
>>> print(pipeline_ins(sentence1))

To view other examples plese check the tests/pipelines/test_word_segmentation.py.
"""
model = model if isinstance(model,
Model) else Model.from_pretrained(model)


+ 29
- 3
modelscope/pipelines/nlp/zero_shot_classification_pipeline.py View File

@@ -24,10 +24,36 @@ class ZeroShotClassificationPipeline(Pipeline):
model: Union[Model, str],
preprocessor: Preprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp zero-shot text classification pipeline for prediction
"""Use `model` and `preprocessor` to create a nlp zero shot classifiction for prediction.

A zero-shot classification task is used to classify texts by prompts.
In a normal classification task, model may produce a positive label by the input text
like 'The ice cream is made of the high quality milk, it is so delicious'
In a zero-shot task, the sentence is converted to:
['The ice cream is made of the high quality milk, it is so delicious', 'This means it is good']
And:
['The ice cream is made of the high quality milk, it is so delicious', 'This means it is bad']
Then feed these sentences into the model and turn the task to a NLI task(entailment, contradiction),
and compare the output logits to give the original classification label.


Args:
model (Model): a model instance
preprocessor (Preprocessor): a preprocessor instance
model (str or Model): Supply either a local model dir which supported the task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(task='zero-shot-classification',
>>> model='damo/nlp_structbert_zero-shot-classification_chinese-base')
>>> sentence1 = '全新突破 解放军运20版空中加油机曝光'
>>> labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
>>> template = '这篇文章的标题是{}'
>>> print(pipeline_ins(sentence1, candidate_labels=labels, hypothesis_template=template))

To view other examples plese check the tests/pipelines/test_zero_shot_classification.py.
"""
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or Model'


+ 64
- 2
modelscope/preprocessors/nlp.py View File

@@ -107,10 +107,20 @@ class SequenceClassificationPreprocessor(Preprocessor):
class NLPTokenizerPreprocessorBase(Preprocessor):

def __init__(self, model_dir: str, pair: bool, mode: str, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path
"""The NLP tokenizer preprocessor base class.

Any nlp preprocessor which uses the hf tokenizer can inherit from this class.

Args:
model_dir (str): model path
model_dir (str): The local model path
first_sequence: The key for the first sequence
second_sequence: The key for the second sequence
label: The label key
label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping
if this mapping is not supplied.
pair (bool): Pair sentence input or single sentence input.
mode: Run this preprocessor in either 'train'/'eval'/'inference' mode
kwargs: These kwargs will be directly fed into the tokenizer.
"""

super().__init__(**kwargs)
@@ -132,11 +142,24 @@ class NLPTokenizerPreprocessorBase(Preprocessor):

@property
def id2label(self):
"""Return the id2label mapping according to the label2id mapping.

@return: The id2label mapping if exists.
"""
if self.label2id is not None:
return {id: label for label, id in self.label2id.items()}
return None

def build_tokenizer(self, model_dir):
"""Build a tokenizer by the model type.

NOTE: This default implementation only returns slow tokenizer, because the fast tokenizers have a
multi-thread problem.

@param model_dir: The local model dir.
@return: The initialized tokenizer.
"""

model_type = get_model_type(model_dir)
if model_type in (Models.structbert, Models.gpt3, Models.palm):
from modelscope.models.nlp.structbert import SbertTokenizer
@@ -172,6 +195,15 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
return output

def parse_text_and_label(self, data):
"""Parse the input and return the sentences and labels.

When input type is tuple or list and its size is 2:
If the pair param is False, data will be parsed as the first_sentence and the label,
else it will be parsed as the first_sentence and the second_sentence.

@param data: The input data.
@return: The sentences and labels tuple.
"""
text_a, text_b, labels = None, None, None
if isinstance(data, str):
text_a = data
@@ -191,6 +223,16 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
return text_a, text_b, labels

def labels_to_id(self, labels, output):
"""Turn the labels to id with the type int or float.

If the original label's type is str or int, the label2id mapping will try to convert it to the final label.
If the original label's type is float, or the label2id mapping does not exist,
the original label will be returned.

@param labels: The input labels.
@param output: The label id.
@return: The final labels.
"""

def label_can_be_mapped(label):
return isinstance(label, str) or isinstance(label, int)
@@ -212,6 +254,8 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sen_sim_tokenizer)
class PairSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in pair sentence classification.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
kwargs['truncation'] = kwargs.get('truncation', True)
@@ -224,6 +268,8 @@ class PairSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sen_cls_tokenizer)
class SingleSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in single sentence classification.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
kwargs['truncation'] = kwargs.get('truncation', True)
@@ -236,6 +282,8 @@ class SingleSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer)
class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in zero shot classification.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path
@@ -277,6 +325,8 @@ class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase):
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.text_gen_tokenizer)
class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in text generation.
"""

def __init__(self,
model_dir: str,
@@ -325,6 +375,8 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):

@PREPROCESSORS.register_module(Fields.nlp, module_name=Preprocessors.fill_mask)
class FillMaskPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in MLM task.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
kwargs['truncation'] = kwargs.get('truncation', True)
@@ -339,6 +391,8 @@ class FillMaskPreprocessor(NLPTokenizerPreprocessorBase):
Fields.nlp,
module_name=Preprocessors.word_segment_text_to_label_preprocessor)
class WordSegmentationBlankSetToLabelPreprocessor(Preprocessor):
"""The preprocessor used to turn a single sentence to a labeled token-classification dict.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -373,6 +427,8 @@ class WordSegmentationBlankSetToLabelPreprocessor(Preprocessor):
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.token_cls_tokenizer)
class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in normal token classification task.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
kwargs['truncation'] = kwargs.get('truncation', True)
@@ -455,6 +511,10 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.ner_tokenizer)
class NERPreprocessor(Preprocessor):
"""The tokenizer preprocessor used in normal NER task.

NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition.
"""

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path
@@ -544,6 +604,8 @@ class NERPreprocessor(Preprocessor):
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.text_error_correction)
class TextErrorCorrectionPreprocessor(Preprocessor):
"""The preprocessor used in text correction task.
"""

def __init__(self, model_dir: str, *args, **kwargs):
from fairseq.data import Dictionary


+ 28
- 0
modelscope/trainers/nlp_trainer.py View File

@@ -38,8 +38,36 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
**kwargs):
"""Add code to adapt with nlp models.

This trainer will accept the information of labels&text keys in the cfg, and then initialize
the nlp models/preprocessors with this information.

Labels&text key information may be carried in the cfg like this:

>>> cfg = {
>>> ...
>>> "dataset": {
>>> "train": {
>>> "first_sequence": "text1",
>>> "second_sequence": "text2",
>>> "label": "label",
>>> "labels": [1, 2, 3, 4]
>>> }
>>> }
>>> }


Args:
cfg_modify_fn: An input fn which is used to modify the cfg read out of the file.

Example:
>>> def cfg_modify_fn(cfg):
>>> cfg.preprocessor.first_sequence= 'text1'
>>> cfg.preprocessor.second_sequence='text2'
>>> return cfg

To view some actual finetune examples, please check the test files listed below:
tests/trainers/test_finetune_sequence_classification.py
tests/trainers/test_finetune_token_classification.py
"""

if isinstance(model, str):


+ 18
- 0
modelscope/utils/hub.py View File

@@ -74,6 +74,14 @@ def auto_load(model: Union[str, List[str]]):


def get_model_type(model_dir):
"""Get the model type from the configuration.

This method will try to get the 'model.type' or 'model.model_type' field from the configuration.json file.
If this file does not exist, the method will try to get the 'model_type' field from the config.json.

@param model_dir: The local model dir to use.
@return: The model type string, returns None if nothing is found.
"""
try:
configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION)
config_file = osp.join(model_dir, 'config.json')
@@ -89,6 +97,16 @@ def get_model_type(model_dir):


def parse_label_mapping(model_dir):
"""Get the label mapping from the model dir.

This method will do:
1. Try to read label-id mapping from the label_mapping.json
2. Try to read label-id mapping from the configuration.json
3. Try to read label-id mapping from the config.json

@param model_dir: The local model dir to use.
@return: The label2id mapping if found.
"""
import json
import os
label2id = None


Loading…
Cancel
Save