
[to #42322933] Add gpt_neo model

1. Add the gpt_neo model. Because the checkpoint belongs to Langboat and has not yet been uploaded to the model hub, testing was completed offline.
2. Add text-generation task models and a head; text generation models that are already online, such as gpt3 and palm, will later be unified under the backbone + head task-model structure (see the sketch after the changed-file list below).
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10404249
Branch: master
Authors: hemu.zp, yingda.chen (3 years ago)
Parent commit: 271e2a2a99
12 changed files with 207 additions and 11 deletions
  1. modelscope/metainfo.py (+5, -0)
  2. modelscope/models/nlp/__init__.py (+3, -1)
  3. modelscope/models/nlp/backbones/gpt_neo.py (+15, -0)
  4. modelscope/models/nlp/heads/text_generation_head.py (+35, -0)
  5. modelscope/models/nlp/task_models/__init__.py (+2, -0)
  6. modelscope/models/nlp/task_models/text_generation.py (+79, -0)
  7. modelscope/pipelines/nlp/text_generation_pipeline.py (+29, -9)
  8. modelscope/preprocessors/__init__.py (+2, -0)
  9. modelscope/preprocessors/nlp/__init__.py (+2, -0)
  10. modelscope/preprocessors/nlp/nlp_base.py (+21, -0)
  11. tests/pipelines/test_text_generation.py (+13, -0)
  12. tests/utils/test_ast.py (+1, -1)
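
Before the per-file diffs, a rough sketch of the backbone + head composition described in point 2 of the commit message. This is a toy illustration, not the actual SingleBackboneTaskModelBase (shown in the diff below); the stand-in modules are hypothetical.

# Hypothetical composition: a backbone encodes tokens into hidden states,
# and a head projects hidden states into task outputs (here: vocab logits).
import torch
from torch import nn


class ToyTaskModelForTextGeneration(nn.Module):

    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.backbone = nn.Embedding(vocab_size, hidden_size)  # stand-in for GPTNeoModel
        self.head = nn.Linear(hidden_size, vocab_size, bias=False)  # stand-in for TextGenerationHead

    def forward(self, input_ids):
        hidden_states = self.backbone(input_ids)  # [batch, seq, hidden]
        return self.head(hidden_states)  # [batch, seq, vocab]


print(ToyTaskModelForTextGeneration()(torch.randint(100, (2, 8))).shape)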

modelscope/metainfo.py (+5, -0)

@@ -71,6 +71,7 @@ class Models(object):
    gcnncrf = 'gcnn-crf'
    bart = 'bart'
    gpt3 = 'gpt3'
    gpt_neo = 'gpt-neo'
    plug = 'plug'
    bert_for_ds = 'bert-for-document-segmentation'
    ponet = 'ponet'
@@ -101,6 +102,7 @@ class TaskModels(object):
    information_extraction = 'information-extraction'
    fill_mask = 'fill-mask'
    feature_extraction = 'feature-extraction'
    text_generation = 'text-generation'


class Heads(object):
@@ -116,6 +118,8 @@ class Heads(object):
    token_classification = 'token-classification'
    # extraction
    information_extraction = 'information-extraction'
    # text gen
    text_generation = 'text-generation'


class Pipelines(object):
@@ -341,6 +345,7 @@ class Preprocessors(object):
    re_tokenizer = 're-tokenizer'
    document_segmentation = 'document-segmentation'
    feature_extraction = 'feature-extraction'
    sentence_piece = 'sentence-piece'

    # audio preprocessor
    linear_aec_fbank = 'linear-aec-fbank'
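
The constants above are plain strings used as registry keys: a configuration file names a component by string, and the matching class is looked up in a registry populated by register_module decorators. A minimal sketch of that pattern (simplified; the real logic lives in modelscope's Registry and build_from_cfg machinery):

# Toy registry illustrating how a metainfo string such as 'sentence-piece'
# resolves to a class at build time. Names here are illustrative only.
REGISTRY = {}


def register_module(name):
    def decorator(cls):
        REGISTRY[name] = cls
        return cls
    return decorator


@register_module('sentence-piece')  # i.e. Preprocessors.sentence_piece
class DummySentencePiecePreprocessor:
    pass


def build(cfg):
    return REGISTRY[cfg.pop('type')](**cfg)


preprocessor = build({'type': 'sentence-piece'})  # class found by its string key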


modelscope/models/nlp/__init__.py (+3, -1)

@@ -30,7 +30,8 @@ if TYPE_CHECKING:
                               InformationExtractionModel,
                               SequenceClassificationModel,
                               SingleBackboneTaskModelBase,
-                              TokenClassificationModel)
+                              TokenClassificationModel,
+                              TaskModelForTextGeneration)
     from .token_classification import SbertForTokenClassification
     from .sentence_embedding import SentenceEmbedding
     from .passage_ranking import PassageRanking
@@ -69,6 +70,7 @@ else:
             'SequenceClassificationModel',
             'SingleBackboneTaskModelBase',
             'TokenClassificationModel',
+            'TaskModelForTextGeneration',
         ],
         'token_classification': ['SbertForTokenClassification'],
         'table_question_answering': ['TableQuestionAnswering'],


modelscope/models/nlp/backbones/gpt_neo.py (+15, -0)

@@ -0,0 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from transformers import GPTNeoConfig
from transformers import GPTNeoModel as GPTNeoModelTransform

from modelscope.metainfo import Models
from modelscope.models.builder import BACKBONES
from modelscope.utils.constant import Fields


@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.gpt_neo)
class GPTNeoModel(GPTNeoModelTransform):

    def __init__(self, **kwargs):
        config = GPTNeoConfig(**kwargs)
        super().__init__(config)
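
Since the wrapper forwards its keyword arguments into a GPTNeoConfig, a registered backbone can be built directly from config values. A hedged usage sketch with toy-sized, illustrative hyperparameters (not Langboat's actual configuration):

# Assumes the import path added by this commit; all sizes are illustrative.
from modelscope.models.nlp.backbones.gpt_neo import GPTNeoModel

backbone = GPTNeoModel(
    vocab_size=1000,
    hidden_size=64,
    num_layers=2,
    num_heads=4,
    # one ['global', 'local'] pair repeated once covers the 2 layers
    attention_types=[[['global', 'local'], 1]],
    max_position_embeddings=128)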

modelscope/models/nlp/heads/text_generation_head.py (+35, -0)

@@ -0,0 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn

from modelscope.metainfo import Heads
from modelscope.models.base import TorchHead
from modelscope.models.builder import HEADS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks


@HEADS.register_module(
    Tasks.text_generation, module_name=Heads.text_generation)
class TextGenerationHead(TorchHead):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        config = self.config
        self.linear = nn.Linear(
            config['hidden_size'], config['vocab_size'], bias=False)

    def get_output_embeddings(self):
        return self.linear

    def forward(self, inputs=None):
        logits = self.linear(inputs)
        return {OutputKeys.LOGITS: logits}

    def compute_loss(self, outputs: Dict[str, torch.Tensor],
                     labels) -> Dict[str, torch.Tensor]:
        logits = outputs[OutputKeys.LOGITS]
        return {OutputKeys.LOSS: F.cross_entropy(logits, labels)}
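
A quick shape check of the head's flow in isolation (dummy sizes; note that F.cross_entropy expects the class dimension second, so this sketch flattens sequence logits to [batch * seq, vocab] before computing the loss):

# Hedged sketch: a linear LM head like the one above, with toy dimensions.
import torch
import torch.nn.functional as F
from torch import nn

hidden_size, vocab_size = 64, 1000
linear = nn.Linear(hidden_size, vocab_size, bias=False)

hidden_states = torch.randn(2, 8, hidden_size)  # [batch, seq, hidden]
logits = linear(hidden_states)  # [batch, seq, vocab]

labels = torch.randint(vocab_size, (2, 8))
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
print(logits.shape, loss.item())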

modelscope/models/nlp/task_models/__init__.py (+2, -0)

@@ -10,6 +10,7 @@ if TYPE_CHECKING:
    from .sequence_classification import SequenceClassificationModel
    from .task_model import SingleBackboneTaskModelBase
    from .token_classification import TokenClassificationModel
    from .text_generation import TaskModelForTextGeneration

else:
    _import_structure = {
@@ -19,6 +20,7 @@ else:
        'sequence_classification': ['SequenceClassificationModel'],
        'task_model': ['SingleBackboneTaskModelBase'],
        'token_classification': ['TokenClassificationModel'],
        'text_generation': ['TaskModelForTextGeneration'],
    }

    import sys


modelscope/models/nlp/task_models/text_generation.py (+79, -0)

@@ -0,0 +1,79 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import addict
import numpy as np
from transformers.modeling_utils import PreTrainedModel

from modelscope.metainfo import TaskModels
from modelscope.models.builder import MODELS
from modelscope.models.nlp.task_models.task_model import \
    SingleBackboneTaskModelBase
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks

__all__ = ['TaskModelForTextGeneration']


@MODELS.register_module(
    Tasks.text_generation, module_name=TaskModels.text_generation)
class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the text generation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        if 'base_model_prefix' in kwargs:
            self._base_model_prefix = kwargs['base_model_prefix']

        self.build_backbone(self.backbone_cfg)
        self.build_head(self.head_cfg)
        if self.config.get('shared_embedding', False):
            input_embeddings = self.backbone.get_input_embeddings()
            output_embeddings = self.head.get_output_embeddings()
            output_embeddings.weight = input_embeddings.weight

    def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        # the backbone does not need labels; only the head needs them to compute the loss
        labels = input.pop(OutputKeys.LABELS, None)

        backbone_outputs = super().forward(input)
        hidden_states = backbone_outputs[0]

        outputs = self.head.forward(hidden_states)
        if labels is not None:
            input[OutputKeys.LABELS] = labels
            loss = self.compute_loss(outputs, labels)
            outputs.update(loss)
        return addict.Dict(outputs)

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get('token_type_ids', None)
        # keep only the last token of input_ids when past is provided
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get('attention_mask', None)
        position_ids = kwargs.get('position_ids', None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            'input_ids': input_ids,
            'past_key_values': past,
            'use_cache': kwargs.get('use_cache'),
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
        }
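
The shared_embedding branch in __init__ is the standard weight-tying trick: the head's output projection and the backbone's input embedding share one parameter. A standalone sketch of the effect (toy modules, not the actual classes):

# Assigning the same nn.Parameter makes both modules share storage,
# exactly as output_embeddings.weight = input_embeddings.weight does above.
from torch import nn

vocab_size, hidden_size = 100, 16
input_embeddings = nn.Embedding(vocab_size, hidden_size)  # weight: [vocab, hidden]
output_embeddings = nn.Linear(hidden_size, vocab_size, bias=False)  # weight: [vocab, hidden]

output_embeddings.weight = input_embeddings.weight
assert output_embeddings.weight.data_ptr() == input_embeddings.weight.data_ptr()
# A gradient step through either module now updates the single shared matrix.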

modelscope/pipelines/nlp/text_generation_pipeline.py (+29, -9)

@@ -6,10 +6,12 @@ import torch
 
 from modelscope.metainfo import Pipelines
 from modelscope.models.base import Model
+from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline, Tensor
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import TextGenerationPreprocessor
-from modelscope.utils.constant import Tasks
+from modelscope.preprocessors import Preprocessor, build_preprocessor
+from modelscope.utils.constant import Fields, Tasks
+from modelscope.utils.hub import read_config
 
 __all__ = ['TextGenerationPipeline']

@@ -20,7 +22,7 @@ class TextGenerationPipeline(Pipeline):
 
     def __init__(self,
                  model: Union[Model, str],
-                 preprocessor: Optional[TextGenerationPreprocessor] = None,
+                 preprocessor: Optional[Preprocessor] = None,
                  first_sequence='sentence',
                  **kwargs):
         """Use `model` and `preprocessor` to create a generation pipeline for prediction.
@@ -50,19 +52,34 @@
         """
         model = model if isinstance(model,
                                     Model) else Model.from_pretrained(model)
+        cfg = read_config(model.model_dir)
+        self.postprocessor = cfg.pop('postprocessor', None)
         if preprocessor is None:
-            preprocessor = TextGenerationPreprocessor(
+            preprocessor_cfg = cfg.preprocessor
+            preprocessor_cfg.update({
+                'model_dir':
                 model.model_dir,
-                first_sequence=first_sequence,
-                second_sequence=None,
-                sequence_length=kwargs.pop('sequence_length', 128))
+                'first_sequence':
+                first_sequence,
+                'second_sequence':
+                None,
+                'sequence_length':
+                kwargs.pop('sequence_length', 128)
+            })
+            preprocessor = build_preprocessor(preprocessor_cfg, Fields.nlp)
         model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
 
+    def _sanitize_parameters(self, **pipeline_parameters):
+        return {}, pipeline_parameters, {}
+
     def forward(self, inputs: Dict[str, Any],
                 **forward_params) -> Dict[str, Any]:
         with torch.no_grad():
-            return self.model.generate(inputs)
+            return self.model.generate(inputs, **forward_params)
+
+    def sentence_piece(self, inputs) -> Dict[str, Tensor]:
+        return self.preprocessor.tokenizer.decode(inputs.tolist())[0]
 
     def postprocess(self, inputs: Dict[str, Tensor],
                     **postprocess_params) -> Dict[str, str]:
@@ -74,4 +91,7 @@
         Returns:
             Dict[str, str]: the prediction results
         """
-        return inputs
+        return inputs if self.postprocessor is None else {
+            OutputKeys.TEXT:
+            getattr(self, self.postprocessor.replace('-', '_'))(inputs)
+        }
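
postprocess dispatches on the postprocessor name read from the model's configuration: 'sentence-piece' becomes a call to the pipeline's sentence_piece method via getattr. A minimal stand-in for that dispatch (class and values are illustrative):

# Simplified illustration of the getattr-based dispatch in postprocess above.
class DummyPipeline:

    postprocessor = 'sentence-piece'  # would come from the model's configuration

    def sentence_piece(self, inputs):
        return f'decoded({inputs})'

    def postprocess(self, inputs):
        if self.postprocessor is None:
            return inputs
        return {'text': getattr(self, self.postprocessor.replace('-', '_'))(inputs)}


print(DummyPipeline().postprocess([1, 2, 3]))  # {'text': 'decoded([1, 2, 3])'}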

modelscope/preprocessors/__init__.py (+2, -0)

@@ -32,6 +32,7 @@ if TYPE_CHECKING:
        Tokenize,
        WordSegmentationBlankSetToLabelPreprocessor,
        ZeroShotClassificationPreprocessor,
        SentencePiecePreprocessor,
    )
    from .space import (DialogIntentPredictionPreprocessor,
                        DialogModelingPreprocessor,
@@ -71,6 +72,7 @@ else:
            'Text2TextGenerationPreprocessor',
            'WordSegmentationBlankSetToLabelPreprocessor',
            'ZeroShotClassificationPreprocessor',
            'SentencePiecePreprocessor',
        ],
        'space': [
            'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',


modelscope/preprocessors/nlp/__init__.py (+2, -0)

@@ -21,6 +21,7 @@ if TYPE_CHECKING:
        Tokenize,
        WordSegmentationBlankSetToLabelPreprocessor,
        ZeroShotClassificationPreprocessor,
        SentencePiecePreprocessor,
    )

else:
@@ -41,6 +42,7 @@ else:
            'Text2TextGenerationPreprocessor',
            'WordSegmentationBlankSetToLabelPreprocessor',
            'ZeroShotClassificationPreprocessor',
            'SentencePiecePreprocessor',
        ],
        'text_error_correction': [
            'TextErrorCorrectionPreprocessor',


modelscope/preprocessors/nlp/nlp_base.py (+21, -0)

@@ -5,6 +5,7 @@ import re
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import numpy as np
import sentencepiece as spm
import torch
from transformers import AutoTokenizer

@@ -1160,3 +1161,23 @@ class FillMaskPoNetPreprocessor(NLPTokenizerPreprocessorBase):

        self.labels_to_id(labels, output)
        return output


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.sentence_piece)
class SentencePiecePreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        import os

        super().__init__(*args, **kwargs)
        self.tokenizer = None
        for file_name in os.listdir(model_dir):
            if file_name.endswith('.model'):
                m_file = osp.join(model_dir, file_name)
                self.tokenizer = spm.SentencePieceProcessor(model_file=m_file)
                break
        assert self.tokenizer is not None, 'Can not find .model file'

    def __call__(self, data: str) -> Dict[str, Any]:
        return torch.tensor(self.tokenizer.encode([data]), dtype=torch.long)
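
A usage sketch for the new preprocessor, assuming a local directory that contains a trained SentencePiece .model file (the path is hypothetical):

# Encode a prompt with the tokenizer discovered in model_dir; spm encode on a
# list of strings returns a list of id lists, so the tensor has shape [1, seq].
from modelscope.preprocessors.nlp import SentencePiecePreprocessor

preprocessor = SentencePiecePreprocessor('/path/to/model_dir')  # hypothetical path
input_ids = preprocessor('我是')
print(input_ids.shape)  # torch.Size([1, seq_len])
# Decoding goes back through the same tokenizer, as the pipeline's
# sentence_piece postprocessor does: preprocessor.tokenizer.decode(input_ids.tolist())[0]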

tests/pipelines/test_text_generation.py (+13, -0)

@@ -133,6 +133,19 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
    def test_demo_compatibility(self):
        self.compatibility_check()

    @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub")
    def test_gpt_neo(self):
        pipe = pipeline(
            task=Tasks.text_generation, model='Langboat/mengzi-gpt-neo-base')
        print(
            pipe(
                '我是',
                do_sample=True,
                top_k=5,
                top_p=1,
                max_length=20,
                repetition_penalty=0.5))


if __name__ == '__main__':
    unittest.main()

tests/utils/test_ast.py (+1, -1)

@@ -41,7 +41,7 @@ class AstScaningTest(unittest.TestCase):
         self.assertIsInstance(from_imports, dict)
         self.assertIsInstance(decorators, list)
         self.assertListEqual(list(set(imports.keys()) - set(['torch'])), [])
-        self.assertEqual(len(from_imports.keys()), 7)
+        self.assertEqual(len(from_imports.keys()), 9)
         self.assertTrue(from_imports['modelscope.metainfo'] is not None)
         self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines'])
         self.assertEqual(decorators,

