
[to #42322933] Fix bug for bloom and gpt_neo

1. Fix the post-processing error in the bloom and gpt_neo models after the transformers 4.23 update (a minimal repro sketch follows the commit metadata below)
2. Use ModelOutput uniformly as the model output type
3. The gpt_neo checkpoint is now online; move its unit test to level 2

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10553103
Branch: master
hemu.zp, yingda.chen · 2 years ago
Parent commit: fa415d8720
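
For context, the failing user-facing path is a plain pipeline call; this repro sketch mirrors the updated bloom test in this commit (the prompt translates to "The capital of China is"):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Before this fix, with transformers >= 4.23 the call failed during
    # postprocess, presumably because generate() started handing back a
    # 2-D tensor of token ids instead of a list (see the pipeline diff below).
    pipe = pipeline(task=Tasks.text_generation, model='langboat/bloom-1b4-zh')
    print(pipe('中国的首都是'))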
5 changed files with 71 additions and 14 deletions:

  1. modelscope/models/nlp/heads/text_generation_head.py   (+3, -5)
  2. modelscope/models/nlp/task_models/text_generation.py  (+15, -7)
  3. modelscope/outputs/nlp/model_outputs.py               (+47, -0)
  4. modelscope/pipelines/nlp/text_generation_pipeline.py  (+5, -1)
  5. tests/pipelines/test_text_generation.py               (+1, -1)

modelscope/models/nlp/heads/text_generation_head.py (+3, -5)

@@ -8,7 +8,6 @@ from torch import nn
 from modelscope.metainfo import Heads
 from modelscope.models.base import TorchHead
 from modelscope.models.builder import HEADS
-from modelscope.outputs import OutputKeys
 from modelscope.utils.constant import Tasks


@@ -27,9 +26,8 @@ class TextGenerationHead(TorchHead):
 
     def forward(self, inputs=None):
         logits = self.linear(inputs)
-        return {OutputKeys.LOGITS: logits}
+        return logits
 
-    def compute_loss(self, outputs: Dict[str, torch.Tensor],
+    def compute_loss(self, logits: torch.Tensor,
                      labels) -> Dict[str, torch.Tensor]:
-        logits = outputs[OutputKeys.LOGITS]
-        return {OutputKeys.LOSS: F.cross_entropy(logits, labels)}
+        return F.cross_entropy(logits, labels)
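
A minimal standalone sketch of the reworked head contract, written in plain torch rather than against the modelscope classes (all sizes are made up): forward now returns a bare logits tensor, and compute_loss consumes that tensor directly instead of unpacking a dict.

    import torch
    import torch.nn.functional as F
    from torch import nn

    hidden_size, vocab_size = 16, 100              # hypothetical sizes
    linear = nn.Linear(hidden_size, vocab_size)    # stand-in for the head's linear layer

    hidden_states = torch.randn(8, hidden_size)    # stand-in for backbone output
    logits = linear(hidden_states)                 # forward() now returns the tensor itself
    labels = torch.randint(0, vocab_size, (8,))
    loss = F.cross_entropy(logits, labels)         # compute_loss() now returns a bare tensor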

modelscope/models/nlp/task_models/text_generation.py (+15, -7)

@@ -1,7 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from typing import Any, Dict
 
-import addict
 import numpy as np
 from transformers.modeling_utils import PreTrainedModel
 
@@ -9,7 +8,8 @@ from modelscope.metainfo import TaskModels
 from modelscope.models.builder import MODELS
 from modelscope.models.nlp.task_models.task_model import \
     SingleBackboneTaskModelBase
-from modelscope.outputs import OutputKeys
+from modelscope.outputs import (OutputKeys, TextGenerationModelOutput,
+                                TokenGeneratorOutput)
 from modelscope.utils.constant import Tasks
 
 __all__ = ['TaskModelForTextGeneration']
@@ -43,12 +43,12 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
         backbone_outputs = super().forward(input)
         hidden_states = backbone_outputs[0]
 
-        outputs = self.head.forward(hidden_states)
+        logits = self.head.forward(hidden_states)
         loss = None
         if labels is not None:
             input[OutputKeys.LABELS] = labels
-            loss = self.compute_loss(outputs, labels)
-            outputs.update(loss)
-        return addict.Dict(outputs)
+            loss = self.compute_loss(logits, labels)
+        return TextGenerationModelOutput(logits=logits, loss=loss)
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
@@ -76,4 +76,12 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
 
     def generate(self, inputs, *args, **kwargs):
         input_ids = inputs['input_ids'] if isinstance(inputs, Dict) else inputs
-        return super().generate(input_ids, *args, **kwargs)
+        generate_output = super().generate(input_ids, *args, **kwargs)
+        if isinstance(generate_output, Dict):
+            return TokenGeneratorOutput(
+                sequences=generate_output.sequences,
+                scores=generate_output.scores,
+                attentions=generate_output.attentions,
+                hidden_states=generate_output.hidden_states)
+        else:
+            return TokenGeneratorOutput(sequences=generate_output)
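
The new branch exists because transformers' generate returns either a bare LongTensor of token ids or, when return_dict_in_generate=True is passed, a ModelOutput that subclasses OrderedDict (which is why the isinstance(..., Dict) check works). A hedged sketch of the two shapes, using gpt2 purely as an illustrative checkpoint:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained('gpt2')            # illustrative model only
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    ids = tok('hello', return_tensors='pt').input_ids

    plain = model.generate(ids, max_new_tokens=5)          # a torch.LongTensor
    rich = model.generate(ids, max_new_tokens=5,
                          return_dict_in_generate=True,
                          output_scores=True)              # a ModelOutput (dict subclass)
    assert isinstance(plain, torch.Tensor)
    assert isinstance(rich, dict)   # hence the wrapper re-packs rich.sequences etc.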

modelscope/outputs/nlp/model_outputs.py (+47, -0)

@@ -541,3 +541,50 @@ class Seq2SeqLMOutput(ModelOutputBase):
     encoder_last_hidden_state: Optional[Tensor] = None
     encoder_hidden_states: Optional[Tuple[Tensor]] = None
     encoder_attentions: Optional[Tuple[Tensor]] = None
+
+
+@dataclass
+class TextGenerationModelOutput(ModelOutputBase):
+    """The output class for text generation models.
+
+    Args:
+        logits (`Tensor`): The logits output of the model.
+        loss (`Tensor`, *optional*): The loss of the model, available when training.
+        hidden_states (`Tensor`, *optional*): Hidden states of the model at the
+            output of each layer, plus the optional initial embedding outputs.
+    """
+
+    logits: Tensor = None
+    loss: Tensor = None
+
+
+@dataclass
+class TokenGeneratorOutput(ModelOutputBase):
+    """The output class for the generate method of text generation models.
+
+    Args:
+        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
+            The generated sequences. The second dimension (sequence_length) is either equal to
+            `max_length` or shorter if all batches finished early due to the `eos_token_id`.
+        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`
+            is passed or when `config.output_scores=True`):
+            Processed prediction scores of the language modeling head (scores for each vocabulary
+            token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to
+            `max_new_tokens` elements (one element for each generated token), with each tensor of
+            shape `(batch_size*num_return_sequences, config.vocab_size)`.
+        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when
+            `output_attentions=True` is passed or `config.output_attentions=True`):
+            Tuple (one element for each generated token) of tuples (one element for each layer of
+            the decoder) of `torch.FloatTensor` of shape `(num_return_sequences*batch_size,
+            num_heads, generated_length, sequence_length)`.
+        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when
+            `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple (one element for each generated token) of tuples (one element for each layer of
+            the decoder) of `torch.FloatTensor` of shape `(num_return_sequences*batch_size,
+            generated_length, hidden_size)`.
+    """
+
+    sequences: Tensor = None
+    scores: Optional[Tuple[Tensor]] = None
+    attentions: Optional[Tuple[Tuple[Tensor]]] = None
+    hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
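
A small usage sketch of the two new dataclasses with toy tensors (the import path matches the one added in text_generation.py above):

    import torch
    from modelscope.outputs import (TextGenerationModelOutput,
                                    TokenGeneratorOutput)

    fwd = TextGenerationModelOutput(logits=torch.randn(1, 4, 100))
    print(fwd.logits.shape, fwd.loss)        # loss stays None outside of training

    gen = TokenGeneratorOutput(sequences=torch.tensor([[101, 7, 8, 102]]))
    print(gen.sequences.shape, gen.scores)   # optional fields default to None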

modelscope/pipelines/nlp/text_generation_pipeline.py (+5, -1)

@@ -104,6 +104,10 @@ class TextGenerationPipeline(Pipeline):
         tokenizer = self.preprocessor.tokenizer
         return tokenizer.decode(inputs.tolist(), skip_special_tokens=True)
 
+    def sentence_piece(self, inputs) -> str:
+        tokenizer = self.preprocessor.tokenizer
+        return tokenizer.decode(inputs.tolist())
+
     def roberta(self, inputs) -> str:
         tokenizer = self.preprocessor.tokenizer
         decoded = tokenizer.decode(inputs.tolist())
 
@@ -121,7 +125,7 @@ class TextGenerationPipeline(Pipeline):
             Dict[str, str]: the prediction results
         """
         inputs = inputs['sequences']
-        if isinstance(inputs, list):
+        if isinstance(inputs, list) or len(inputs.shape) > 1:
             inputs = inputs[0]
         decoded = getattr(self, self.postprocessor)(inputs)
         text = self._remove_space_between_chinese_chars(decoded)
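
The widened condition is the core of the bloom/gpt_neo fix: with the ModelOutput unification, inputs['sequences'] can now arrive as a 2-D tensor of shape (num_return_sequences, sequence_length) rather than a list, so the old isinstance(inputs, list) test no longer reduced it to a single row before decoding. A minimal sketch of the branch:

    import torch

    sequences = torch.tensor([[2, 3, 5, 7]])   # 2-D tensor: a batch holding one sequence
    if isinstance(sequences, list) or len(sequences.shape) > 1:
        sequences = sequences[0]               # keep the first generated row
    print(sequences.tolist())                  # [2, 3, 5, 7], ready for tokenizer.decode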


tests/pipelines/test_text_generation.py (+1, -1)

@@ -183,7 +183,7 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
             task=Tasks.text_generation, model='langboat/bloom-1b4-zh')
         print(pipe('中国的首都是'))
 
-    @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub")
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_gpt_neo(self):
         pipe = pipeline(
             task=Tasks.text_generation, model='langboat/mengzi-gpt-neo-base')

