@@ -26,6 +26,7 @@ class Models(object):
    space = 'space'
    tcrf = 'transformer-crf'
    bart = 'bart'
    gpt3 = 'gpt3'

    # audio models
    sambert_hifigan = 'sambert-hifigan'
@@ -160,7 +161,7 @@ class Preprocessors(object):
    # nlp preprocessor
    sen_sim_tokenizer = 'sen-sim-tokenizer'
    bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
    palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
    text_gen_tokenizer = 'text-gen-tokenizer'
    token_cls_tokenizer = 'token-cls-tokenizer'
    ner_tokenizer = 'ner-tokenizer'
    nli_tokenizer = 'nli-tokenizer'
@@ -4,7 +4,8 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .backbones import (SbertModel, SpaceGenerator, SpaceModelBase)
    from .backbones import (SbertModel, SpaceGenerator, SpaceModelBase,
                            GPT3Model)
    from .heads import SequenceClassificationHead
    from .bert_for_sequence_classification import BertForSequenceClassification
    from .csanmt_for_translation import CsanmtForTranslation
@@ -23,10 +24,12 @@ if TYPE_CHECKING:
    from .space_for_dialog_state_tracking import SpaceForDialogStateTracking
    from .task_model import SingleBackboneTaskModelBase
    from .bart_for_text_error_correction import BartForTextErrorCorrection
    from .gpt3_for_text_generation import GPT3ForTextGeneration

else:
    _import_structure = {
        'backbones': ['SbertModel', 'SpaceGenerator', 'SpaceModelBase'],
        'backbones':
        ['SbertModel', 'SpaceGenerator', 'SpaceModelBase', 'GPT3Model'],
        'heads': ['SequenceClassificationHead'],
        'csanmt_for_translation': ['CsanmtForTranslation'],
        'bert_for_sequence_classification': ['BertForSequenceClassification'],
@@ -48,6 +51,7 @@ else:
        'space_for_dialog_state_tracking': ['SpaceForDialogStateTracking'],
        'task_model': ['SingleBackboneTaskModelBase'],
        'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
        'gpt3_for_text_generation': ['GPT3ForTextGeneration'],
    }

    import sys
@@ -6,10 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .space import SpaceGenerator, SpaceModelBase
    from .structbert import SbertModel
    from .gpt3 import GPT3Model
else:
    _import_structure = {
        'space': ['SpaceGenerator', 'SpaceModelBase'],
        'structbert': ['SbertModel']
        'structbert': ['SbertModel'],
        'gpt3': ['GPT3Model']
    }

    import sys
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .configuration_gpt3 import GPT3Config
    from .modeling_gpt3 import GPT3Model
else:
    _import_structure = {
        'configuration_gpt3': ['GPT3Config'],
        'modeling_gpt3': ['GPT3Model']
    }

    import sys
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,51 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class GPT3Config(PretrainedConfig):

    model_type = 'gpt'

    def __init__(self,
                 vocab_size=25600,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act='gelu',
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=2048,
                 type_vocab_size=2,
                 layernorm_epsilon=1e-12,
                 **kwargs):
        super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.layernorm_epsilon = layernorm_epsilon
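# Illustrative sketch (not part of the patch); the import path is the one
# implied by the backbones package __init__ above.
from modelscope.models.nlp.backbones.gpt3 import GPT3Config

config = GPT3Config(vocab_size=25600, hidden_size=768)
# layernorm_epsilon is also forwarded to PretrainedConfig as layer_norm_eps.
print(config.vocab_size, config.layernorm_epsilon)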
@@ -0,0 +1,337 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from typing import Optional, Union

import torch
from addict import Dict
from torch.nn import Dropout, Embedding, LayerNorm, Linear, Module, Softmax
from torch.nn import functional as F
from transformers.modeling_utils import PreTrainedModel

from modelscope.utils.constant import ModelFile
from .configuration_gpt3 import GPT3Config
class GPT3SelfAttention(Module):
    """Self-attention layer.

    Takes input with size [b, s, h] and returns output of the same size.
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        # Per attention head
        self.hidden_size_per_attention_head = \
            self.hidden_size // self.num_attention_heads
        self.query_key_value = Linear(self.hidden_size, 3 * self.hidden_size)
        self.softmax = Softmax(dim=-1)
        self.attention_dropout = Dropout(config.attention_probs_dropout_prob)
        # Output.
        self.dense = Linear(self.hidden_size, self.hidden_size)
        self.output_dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    def _transpose_for_scores(self, tensor):
        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
        size [b, np, s, hn].
        """
        new_tensor_shape = tensor.size()[:-1] + (
            self.num_attention_heads, self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)

    def _split_tensor_along_last_dim(self,
                                     tensor,
                                     num_partitions,
                                     contiguous_split_chunks=False):
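        """Split a tensor along its last dimension into `num_partitions`
        equally sized chunks, e.g. the fused [b, s, 3*h] QKV projection
        into three [b, s, h] tensors.
        """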
        # Get the size and dimension.
        last_dim = tensor.dim() - 1
        last_dim_size = tensor.size()[last_dim] // num_partitions
        # Split.
        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if contiguous_split_chunks:
            return tuple(chunk.contiguous() for chunk in tensor_list)
        return tensor_list

    def forward(self, hidden_states, ltor_mask, is_infer=False):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]
        # Attention heads. [b, s, hp]
        tgt_len = hidden_states.size(1)
        ltor_mask = torch.reshape(ltor_mask, [1, 1, tgt_len, tgt_len])
        mixed_x_layer = self.query_key_value(hidden_states)
        (mixed_query_layer, mixed_key_layer, mixed_value_layer) = \
            self._split_tensor_along_last_dim(mixed_x_layer, 3)
        # Reshape and transpose [b, np, s, hn]
        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)
        previous_type = value_layer.type()
        # Raw attention scores. [b, np, s, s]
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(
            self.hidden_size_per_attention_head)
        # Apply the left to right attention mask.
        if is_infer:
            src_len = key_layer.size(2)
            ltor_mask = torch.tril(
                torch.ones((1, tgt_len, src_len),
                           device=hidden_states.device)).view(
                               1, 1, tgt_len, src_len).type(previous_type)
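        # Convert the {0, 1} mask into an additive penalty: allowed positions
        # keep their scores, masked positions are pushed to a large negative
        # value before the softmax.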
        converted_mask = 10000.0 * (1.0 - ltor_mask)
        attention_scores = (torch.mul(attention_scores, ltor_mask)
                            - converted_mask).type(previous_type)
        # Attention probabilities. [b, np, s, s]
        attention_probs = self.softmax(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)
        # Context layer.
        # [b, np, s, hn]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [b, s, np, hn]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size, )
        # [b, s, hp]
        context_layer = context_layer.view(*new_context_layer_shape)
        # Output. [b, s, h]
        output = self.dense(context_layer)
        output = self.output_dropout(output)
        return output
class GPT3MLP(Module):
    """MLP.

    Takes the input with h hidden size, projects it to 4*h hidden
    dimension, performs a nonlinear transformation, and projects the
    state back into h hidden dimension.
    """

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        # Project to 4h.
        self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size)
        self.activation_func = F.gelu
        # Project back to h.
        self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size)
        self.dropout = Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states):
        # [b, s, 4h]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [b, s, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        output = self.dropout(output)
        return output


class GPT3TransformerLayer(Module):
    """A single transformer layer.

    Takes input with size [b, s, h] and returns an output of the same size.
    """
    def __init__(self, config):
        super().__init__()
        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        # Self attention.
        self.attention = GPT3SelfAttention(config)
        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        # MLP
        self.mlp = GPT3MLP(config)

    def forward(self, hidden_states, ltor_mask):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.attention(layernorm_output, ltor_mask)
        # Residual connection.
        layernorm_input = hidden_states + attention_output
        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        # MLP.
        mlp_output = self.mlp(layernorm_output)
        # Second residual connection.
        output = layernorm_input + mlp_output
        return output
class GPT3Transformer(Module):
    """Transformer class."""

    def __init__(self, config):
        super().__init__()
        self.input_tensor = None
        # Number of layers.
        self.num_layers = config.num_hidden_layers
        self.layers = torch.nn.ModuleList(
            [GPT3TransformerLayer(config) for _ in range(self.num_layers)])
        # Final layer norm before output.
        self.final_layernorm = LayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [b, s, h]
        for index in range(self.num_layers):
            layer = self._get_layer(index)
            hidden_states = layer(hidden_states, attention_mask)
        # Final layer norm.
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states
class GPT3TransformerLanguageModel(Module):
    """Transformer language model.

    Embeds input ids and positions, runs them through the GPT3Transformer
    stack, and projects the output back to vocabulary logits with the
    (tied) word-embedding weights.

    Arguments:
        config: a GPT3Config holding the vocabulary size, hidden size,
            maximum position embeddings and dropout probabilities used below.
    """
    def __init__(self, config):
        super().__init__()
        # Embeddings.
        self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = Embedding(config.max_position_embeddings,
                                             config.hidden_size)
        self.embedding_dropout = Dropout(config.hidden_dropout_prob)
        # Transformer.
        self.transformer = GPT3Transformer(config)

    def forward(self, input_ids, attention_mask, position_ids):
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        transformer_input = self.embedding_dropout(embeddings)
        transformer_output = self.transformer(transformer_input,
                                              attention_mask)
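        # Project back to the vocabulary with the word-embedding matrix
        # (weight tying); logits have shape [b, s, vocab_size].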
        logits = F.linear(transformer_output, self.word_embeddings.weight)
        return logits


class GPT3Model(PreTrainedModel):

    config_class = GPT3Config

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, Embedding):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def __init__(self, config):
        super().__init__(config)
        self.language_model = GPT3TransformerLanguageModel(config)

    def forward(self,
                input_ids,
                attention_mask=None,
                position_ids=None,
                **kwargs):
        seq_length = input_ids.size(1)
        if attention_mask is None:
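            # No mask supplied: default to a causal (lower-triangular) mask
            # over the full sequence.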
            attention_mask = torch.tril(
                torch.ones((1, seq_length, seq_length),
                           dtype=torch.long,
                           device=input_ids.device))
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        logits = self.language_model(input_ids, attention_mask, position_ids)
        return Dict(logits=logits)

    @classmethod
    def from_pretrained(
            cls, pretrained_model_name_or_path: Optional[Union[str,
                                                                os.PathLike]]):
        config = cls.config_class.from_pretrained(
            pretrained_model_name_or_path)
        model = cls(config)
        state_dict_file = os.path.join(pretrained_model_name_or_path,
                                       ModelFile.TORCH_MODEL_BIN_FILE)
        state_dict = torch.load(state_dict_file)
        model.load_state_dict(state_dict)
        return model
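# Illustrative forward-pass sketch (not part of the patch; the import paths are
# the ones implied by the package __init__ files above). Shapes follow the code
# comments: input_ids is [b, s]; logits come back as [b, s, vocab_size].
import torch
from modelscope.models.nlp import GPT3Model
from modelscope.models.nlp.backbones.gpt3 import GPT3Config

config = GPT3Config(num_hidden_layers=2)  # small stack, just for the demo
model = GPT3Model(config).eval()          # randomly initialized weights
input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    out = model(input_ids)                # causal mask and position ids are built internally
print(out['logits'].shape)                # torch.Size([1, 16, 25600])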
@@ -1 +1,19 @@
from .modeling_sbert import SbertModel
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .modeling_sbert import SbertModel
else:
    _import_structure = {'modeling_sbert': ['SbertModel']}

    import sys
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,56 @@
from typing import Dict

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks

__all__ = ['GPT3ForTextGeneration']


@MODELS.register_module(Tasks.text_generation, module_name=Models.gpt3)
class GPT3ForTextGeneration(TorchModel):
    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the text generation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        from modelscope.models.nlp import GPT3Model
        from transformers import BertTokenizer
        self.model = GPT3Model.from_pretrained(model_dir)
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Return the results computed by the model.

        Args:
            input (Dict[str, Tensor]): the preprocessed data

        Returns:
            Dict[str, Tensor]: results
                Example:
                    {
                        'logits': Tensor([[0.54, 0.32, ...]]),  # token logits
                    }
        """
        return self.model(**input)
    def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
        assert 'input_ids' in input, "generate requires 'input_ids' in its input"
        gen_params = dict()
        gen_params['inputs'] = input['input_ids']
        gen_params['do_sample'] = input.pop('do_sample', True)
        gen_params['max_length'] = input.pop('max_length', 128)
        gen_params['top_k'] = input.pop('top_k', 10)
        gen_params['top_p'] = input.pop('top_p', None)
        sample_output = self.model.generate(**gen_params)
        return {
            OutputKeys.TEXT:
            self.tokenizer.decode(sample_output[0], skip_special_tokens=True)
        }
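# Illustrative usage sketch (not part of the patch); the local model directory
# is a placeholder, and the prompt is the one used in the tests below.
from modelscope.models.nlp import GPT3ForTextGeneration

model = GPT3ForTextGeneration('/path/to/nlp_gpt3_text-generation_chinese-base')
input_ids = model.tokenizer('我很好奇', return_tensors='pt')['input_ids']
print(model.generate({'input_ids': input_ids, 'max_length': 32}))
# -> {'text': '...'} (keyed by OutputKeys.TEXT)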
@@ -1,8 +1,9 @@
from typing import Dict
from typing import Dict, List

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks

__all__ = ['PalmForTextGeneration']
@@ -27,8 +28,7 @@ class PalmForTextGeneration(TorchModel):
        self.tokenizer = self.model.tokenizer
        self.generator = Translator(self.model)

    def _evaluate_postprocess(self, src: Tensor, tgt: Tensor,
                              mask_src: Tensor) -> Dict[str, str]:
    def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]:
        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
@@ -36,29 +36,14 @@ class PalmForTextGeneration(TorchModel):
                                                                    ''),
                                  ('<s>', ''), ('</s>', ''), ('<unk>', ' '))
        inputs = self.generator(src, mask_src)
        pred_list = inputs['predictions']
        pred_id_list = [
            pred_batch[0].cpu().numpy().tolist() for pred_batch in pred_list
        ]
        tgt_id_list = tgt.cpu().numpy().tolist()
        pred_strings = [
            self.tokenizer.decode(pred_ids) for pred_ids in pred_id_list
        ]
        tgt_strings = [
            self.tokenizer.decode(tgt_ids) for tgt_ids in tgt_id_list
        ]
        strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list]
        for _old, _new in replace_tokens_bert:
            pred_strings = [s.replace(_old, _new) for s in pred_strings]
            tgt_strings = [s.replace(_old, _new) for s in tgt_strings]
            strings = [s.replace(_old, _new) for s in strings]
        for _old, _new in replace_tokens_roberta:
            pred_strings = [s.replace(_old, _new) for s in pred_strings]
            tgt_strings = [s.replace(_old, _new) for s in tgt_strings]
        for s in pred_strings:
            strings = [s.replace(_old, _new) for s in strings]
        for s in strings:
            s.strip()
        for s in tgt_strings:
            s.strip()
        return {'preds': pred_strings, 'tgts': tgt_strings}
        return strings
    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Return the results computed by the model.
@@ -70,12 +55,30 @@ class PalmForTextGeneration(TorchModel):
            Dict[str, Tensor]: results
                Example:
                    {
                        'predictions': Tensor([[1377, 4959, 2785, 6392, ...]]),  # tokens to be decoded by the tokenizer
                        'loss': Tensor([12.34]),  # loss for backward
                    }
                or
                    {
                        'preds': ['hello word', ...],  # the predicted strings
                        'tgts': ['hello world', ...]  # the target strings
                    }
        """
        if self.training:
            return {'loss': self.model(**input)}
        elif 'tgt' in input:
            return self._evaluate_postprocess(**input)
        else:
            return self.generator(**input)
            outputs = self.generator(input['src'], input['mask_src'])
            preds = outputs['predictions']
            pred_ids_list = [
                pred_batch[0].cpu().numpy().tolist() for pred_batch in preds
            ]
            tgt_ids_list = input['tgt'].cpu().numpy().tolist()
            return {
                'preds': self._evaluate_postprocess(pred_ids_list),
                'tgts': self._evaluate_postprocess(tgt_ids_list)
            }

    def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
        outputs = self.generator(**input)
        preds = outputs['predictions']
        pred_ids_list = [preds[0][0].cpu().numpy().tolist()]
        return {OutputKeys.TEXT: self._evaluate_postprocess(pred_ids_list)[0]}
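# Sketch (not part of the patch): generate() gives PALM the same contract the
# pipeline below relies on for GPT3 above -- a preprocessed dict in,
# {OutputKeys.TEXT: generated_string} out. The key names follow the forward
# path above.
#   out = palm_model.generate({'src': src_ids, 'mask_src': src_mask})
#   print(out[OutputKeys.TEXT])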
@@ -3,9 +3,7 @@ from typing import Any, Dict, Optional, Union
import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.outputs import OutputKeys
from modelscope.models.base import TorchModel
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TextGenerationPreprocessor
@@ -19,7 +17,7 @@ __all__ = ['TextGenerationPipeline']
class TextGenerationPipeline(Pipeline):

    def __init__(self,
                 model: Union[PalmForTextGeneration, str],
                 model: Union[TorchModel, str],
                 preprocessor: Optional[TextGenerationPreprocessor] = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create an NLP text generation pipeline for prediction.
@@ -29,21 +27,19 @@ class TextGenerationPipeline(Pipeline):
            preprocessor (TextGenerationPreprocessor): a preprocessor instance
        """
        model = model if isinstance(
            model, PalmForTextGeneration) else Model.from_pretrained(model)
            model, TorchModel) else TorchModel.from_pretrained(model)
        if preprocessor is None:
            preprocessor = TextGenerationPreprocessor(
                model.model_dir,
                model.tokenizer,
                first_sequence='sentence',
                second_sequence=None)
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.tokenizer = model.tokenizer

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        with torch.no_grad():
            return super().forward(inputs, **forward_params)
            return self.model.generate(inputs)
    def postprocess(self, inputs: Dict[str, Tensor],
                    **postprocess_params) -> Dict[str, str]:
@@ -55,20 +51,4 @@ class TextGenerationPipeline(Pipeline):
        Returns:
            Dict[str, str]: the prediction results
        """
        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
        replace_tokens_roberta = ((r' +', ' '), ('<mask>', '<q>'), ('<pad>',
                                                                    ''),
                                  ('<s>', ''), ('</s>', ''), ('<unk>', ' '))
        pred_list = inputs['predictions']
        pred_ids = pred_list[0][0].cpu().numpy().tolist()
        pred_string = self.tokenizer.decode(pred_ids)
        for _old, _new in replace_tokens_bert:
            pred_string = pred_string.replace(_old, _new)
        pred_string.strip()
        for _old, _new in replace_tokens_roberta:
            pred_string = pred_string.replace(_old, _new)
        pred_string.strip()
        return {OutputKeys.TEXT: pred_string}
        return inputs
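# Illustrative pipeline usage sketch (not part of the patch); the model id and
# prompt are the ones referenced in the tests below.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.text_generation,
                model='damo/nlp_gpt3_text-generation_chinese-base')
print(pipe('我很好奇'))  # -> {'text': '...'}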
@@ -2,7 +2,7 @@
import os.path as osp
import uuid
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

from transformers import AutoTokenizer
@@ -211,36 +211,34 @@ class SentenceSimilarityFinetunePreprocessor(SentenceSimilarityPreprocessor):
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
    Fields.nlp, module_name=Preprocessors.text_gen_tokenizer)
class TextGenerationPreprocessor(NLPPreprocessorBase):

    def __init__(self, model_dir: str, tokenizer=None, *args, **kwargs):
        self.tokenizer = self.build_tokenizer(
            model_dir) if tokenizer is None else tokenizer
        kwargs['truncation'] = True
        kwargs['padding'] = 'max_length'
        kwargs['padding'] = True
        kwargs['return_tensors'] = 'pt'
        kwargs['return_token_type_ids'] = False
        kwargs['max_length'] = kwargs.pop('sequence_length', 128)
        super().__init__(model_dir, *args, **kwargs)
    def build_tokenizer(self, model_dir: str):

    @staticmethod
    def get_roberta_tokenizer_dir(model_dir: str) -> Optional[str]:
        import os
        from sofa.models.palm_v2 import PalmConfig
        for name in os.listdir(model_dir):
            full_name = os.path.join(model_dir, name)
            if 'roberta' in name and os.path.isdir(full_name):
                return full_name
        config_file = os.path.join(model_dir, 'config.json')
        config = PalmConfig.from_json_file(config_file) if os.path.isfile(
            config_file) else PalmConfig()
        config.encoder_pth = os.path.join(model_dir, config.encoder_pth)
        if config.encoder == 'roberta':

    def build_tokenizer(self, model_dir: str):
        roberta_tokenizer_dir = self.get_roberta_tokenizer_dir(model_dir)
        if roberta_tokenizer_dir:
            from transformers import RobertaTokenizer
            tokenizer = RobertaTokenizer.from_pretrained(
                config.encoder_pth, do_lower_case=False)
        elif config.encoder == 'bert' or config.encoder == 'zh_bert':
            from transformers import BertTokenizer
            tokenizer = BertTokenizer.from_pretrained(
                config.encoder_pth, do_lower_case=True)
        return tokenizer
            return RobertaTokenizer.from_pretrained(
                roberta_tokenizer_dir, do_lower_case=False)
        return super().build_tokenizer(model_dir)
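# Resolution order sketched above: if model_dir contains a sub-directory whose
# name includes 'roberta' (the layout below is a made-up example), a
# RobertaTokenizer is built from it; otherwise the base class default tokenizer
# for model_dir is used.
#
#   model_dir/
#   ├── config.json
#   ├── pytorch_model.bin
#   └── roberta_tokenizer/   -> RobertaTokenizer.from_pretrained(this dir)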
@PREPROCESSORS.register_module(
@@ -3,7 +3,7 @@ import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.models.nlp import GPT3ForTextGeneration, PalmForTextGeneration
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TextGenerationPipeline
from modelscope.preprocessors import TextGenerationPreprocessor
@@ -12,26 +12,32 @@ from modelscope.utils.test_utils import test_level
class TextGenerationTest(unittest.TestCase):
    model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
    model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
    input_zh = """
    本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
    1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
    """
    input_en = """
    The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
    her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
    54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges
    of paedophilia against nine children because he has dementia . Today , newly-released documents
    revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years .
    And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her
    pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 .
    """
    def setUp(self) -> None:
        self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
        self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
        self.palm_input_zh = """
        本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
        1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
        """
        self.palm_input_en = """
        The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
        her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
        54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges
        of paedophilia against nine children because he has dementia . Today , newly-released documents
        revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years .
        And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her
        pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 .
        """
        self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base'
        self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large'
        self.gpt3_input = '我很好奇'
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run(self):
        for model_id, input in ((self.model_id_zh, self.input_zh),
                                (self.model_id_en, self.input_en)):
    def test_run_palm(self):
        for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
                                (self.palm_model_id_en, self.palm_input_en)):
            cache_path = snapshot_download(model_id)
            model = PalmForTextGeneration(cache_path)
            preprocessor = TextGenerationPreprocessor(
@@ -46,10 +52,28 @@ class TextGenerationTest(unittest.TestCase):
                f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
            )
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_gpt3(self):
        cache_path = snapshot_download(self.gpt3_base_model_id)
        model = GPT3ForTextGeneration(cache_path)
        preprocessor = TextGenerationPreprocessor(
            cache_path,
            model.tokenizer,
            first_sequence='sentence',
            second_sequence=None)
        pipeline1 = TextGenerationPipeline(model, preprocessor)
        pipeline2 = pipeline(
            Tasks.text_generation, model=model, preprocessor=preprocessor)
        print(
            f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}'
        )
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        for model_id, input in ((self.model_id_zh, self.input_zh),
                                (self.model_id_en, self.input_en)):
        for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
                                (self.palm_model_id_en, self.palm_input_en),
                                (self.gpt3_base_model_id, self.gpt3_input),
                                (self.gpt3_large_model_id, self.gpt3_input)):
            model = Model.from_pretrained(model_id)
            preprocessor = TextGenerationPreprocessor(
                model.model_dir,
@@ -62,17 +86,19 @@ class TextGenerationTest(unittest.TestCase):
                preprocessor=preprocessor)
            print(pipeline_ins(input))
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        for model_id, input in ((self.model_id_zh, self.input_zh),
                                (self.model_id_en, self.input_en)):
        for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
                                (self.palm_model_id_en, self.palm_input_en),
                                (self.gpt3_base_model_id, self.gpt3_input),
                                (self.gpt3_large_model_id, self.gpt3_input)):
            pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id)
            print(pipeline_ins(input))
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.text_generation)
        print(pipeline_ins(self.input_zh))
        print(pipeline_ins(self.palm_input_zh))


if __name__ == '__main__':