
init

master^2
shuaigezhu, 2 years ago
parent
commit db0f25a594
14 changed files with 1819 additions and 2 deletions

  1. modelscope/metainfo.py (+3, -0)
  2. modelscope/models/nlp/__init__.py (+2, -0)
  3. modelscope/models/nlp/codegeex/__init__.py (+22, -0)
  4. modelscope/models/nlp/codegeex/codegeex.py (+1030, -0)
  5. modelscope/models/nlp/codegeex/codegeex_for_code_translation.py (+126, -0)
  6. modelscope/models/nlp/codegeex/inference.py (+335, -0)
  7. modelscope/models/nlp/codegeex/tokenizer.py (+186, -0)
  8. modelscope/pipelines/nlp/__init__.py (+3, -0)
  9. modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py (+44, -0)
  10. modelscope/preprocessors/__init__.py (+2, -2)
  11. modelscope/preprocessors/nlp/__init__.py (+2, -0)
  12. modelscope/preprocessors/nlp/codegeex_preprocessor.py (+25, -0)
  13. modelscope/utils/constant.py (+1, -0)
  14. tests/pipelines/test_CodeGeeX_code_translation.py (+38, -0)

modelscope/metainfo.py (+3, -0)

@@ -84,6 +84,7 @@ class Models(object):
ponet = 'ponet'
T5 = 'T5'
mglm = 'mglm'
codegeex = 'codegeex'
bloom = 'bloom'

# audio models
@@ -255,6 +256,7 @@ class Pipelines(object):
document_segmentation = 'document-segmentation'
feature_extraction = 'feature-extraction'
mglm_text_summarization = 'mglm-text-summarization'
codegeex_code_translation = 'codegeex-code-translation'
translation_en_to_de = 'translation_en_to_de' # keep it underscore
translation_en_to_ro = 'translation_en_to_ro' # keep it underscore
translation_en_to_fr = 'translation_en_to_fr' # keep it underscore
@@ -382,6 +384,7 @@ class Preprocessors(object):
document_segmentation = 'document-segmentation'
feature_extraction = 'feature-extraction'
mglm_summarization = 'mglm-summarization'
codegeex = 'codegeex'
sentence_piece = 'sentence-piece'

# audio preprocessor


modelscope/models/nlp/__init__.py (+2, -0)

@@ -36,6 +36,7 @@ if TYPE_CHECKING:
)
from .T5 import T5ForConditionalGeneration
from .mglm import MGLMForTextSummarization
from .codegeex import CodeGeeXForCodeTranslation
from .task_models import (
FeatureExtractionModel,
InformationExtractionModel,
@@ -108,6 +109,7 @@ else:
'sentence_embedding': ['SentenceEmbedding'],
'T5': ['T5ForConditionalGeneration'],
'mglm': ['MGLMForTextSummarization'],
'codegeex': ['CodeGeeXForCodeTranslation'],
'gpt_neo': ['GPTNeoModel'],
'bloom': ['BloomModel'],
}


modelscope/models/nlp/codegeex/__init__.py (+22, -0)

@@ -0,0 +1,22 @@
# Modified by Zhipu.AI
# Original Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .codegeex_for_code_translation import CodeGeeXForCodeTranslation
else:
_import_structure = {
'codegeex_for_code_translation': ['CodeGeeXForCodeTranslation'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

modelscope/models/nlp/codegeex/codegeex.py (+1030, -0)
(file diff suppressed because it is too large)
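For orientation only (the suppressed file defines the actual model), a minimal sketch of how the rest of this commit drives CodeGeeXModel. The constructor argument order and values are taken from model_provider() and the forward signature from forward_step() in the files below; the tensor shapes here are illustrative assumptions.

import torch

from modelscope.models.nlp.codegeex.codegeex import CodeGeeXModel

# Argument order as in model_provider(): hidden_size, num_layers,
# num_attention_heads, padded_vocab_size, max_position_embeddings.
model = CodeGeeXModel(5120, 39, 40, 52224, 2048).eval().half().cuda()

seq = 16  # illustrative prompt length
tokens = torch.randint(0, 52224, (1, seq)).cuda()
position_ids = torch.arange(seq).unsqueeze(0).cuda()
# Boolean mask where True marks masked positions, matching
# get_ltor_masks_and_position_ids() in inference.py.
attention_mask = (torch.tril(torch.ones(1, 1, seq, seq)) < 0.5).cuda()

with torch.no_grad():
    # With get_key_value=True the model returns (logits, layer_past),
    # as unpacked in forward_step() in inference.py.
    logits, layer_past = model(
        tokens,
        position_ids,
        attention_mask,
        layer_past=None,
        get_key_value=True,
        prompt_length=None,
        context_length=seq,
    )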


modelscope/models/nlp/codegeex/codegeex_for_code_translation.py (+126, -0)

@@ -0,0 +1,126 @@
# Copyright (c) 2022 Zhipu.AI

import copy
import os
import random
import time
from typing import Dict

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from .codegeex import CodeGeeXModel
from .inference import get_token_stream
from .tokenizer import CodeGeeXTokenizer


def model_provider():
"""Build the model."""

hidden_size = 5120
num_attention_heads = 40
num_layers = 39
padded_vocab_size = 52224
max_position_embeddings = 2048

model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads,
padded_vocab_size, max_position_embeddings)

return model


@MODELS.register_module(Tasks.code_translation, module_name=Models.codegeex)
class CodeGeeXForCodeTranslation(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the fast poem model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)

# loading tokenizer
print('Loading tokenizer ...')
self.tokenizer = CodeGeeXTokenizer(
tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b')
# loading model
state_dict_path = model_dir + '/ckpt_ms_translation_0817.pt'
print('Loading state dict ...')
state_dict = torch.load(state_dict_path, map_location='cpu')
state_dict = state_dict['module']

print('Building CodeGeeX model ...')
self.model = model_provider()
self.model.load_state_dict(state_dict)
self.model.eval()
self.model.half()
self.model.cuda()

def forward(self, input: Dict[str, str]) -> Dict[str, str]:
micro_batch_size = 1
seq_length = 2048
out_seq_length = 256
bad_ids = None
print('Generating ...')
src_lang = input['source language']
dst_lang = input['target language']
prompt = input['prompt']
prompt = f'code translation\n{src_lang}:\n{prompt}\n{dst_lang}:\n'
t0 = time.perf_counter()
tokenizer = self.tokenizer
model = self.model
for prompt in [prompt]:
tokens = tokenizer.encode_code(prompt)
print(tokens)
print('Current prompt:')
print(prompt)
n_token_prompt = len(tokens)
print('N_token_prompt:', n_token_prompt)
token_stream = get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
[copy.deepcopy(tokens) for _ in range(micro_batch_size)],
micro_batch_size=micro_batch_size,
bad_ids=bad_ids,
greedy=True,
)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(micro_batch_size):
if is_finished[j]:
continue
if generated_tokens[j].cpu().numpy(
)[-1] == tokenizer.eos_token_id or len(
generated_tokens[j]) >= out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_tokens[j].cpu().numpy(
).tolist()
generated_code = tokenizer.decode_code(
generated_tokens_[n_token_prompt:])
generated_code = ''.join(generated_code)
t1 = time.perf_counter()
print('Total generation time:', t1 - t0, '# Tokens:',
len(generated_tokens_) - n_token_prompt)
print(
f'{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token'
)
print(
'================================= Generated code:'
)
print(generated_code)
t0 = time.perf_counter()
if all(is_finished):
break

print('Generation finished.')
return {OutputKeys.TEXT: generated_code}

modelscope/models/nlp/codegeex/inference.py (+335, -0)

@@ -0,0 +1,335 @@
import copy
import os
import time
from typing import List
from dataclasses import dataclass

import json
import torch
import torch.nn.functional as F


def get_ltor_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
):
"""Build masks and position id for left to right model."""

# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()

# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length),
device=data.device)).view(att_mask_batch, 1, seq_length,
seq_length)

# Position ids.
position_ids = torch.arange(
seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()

if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):

# Find indices where the EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()

# Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1):] -= i + 1 - prev_index
prev_index = i + 1

# Convert attention mask to binary:
attention_mask = attention_mask < 0.5

return attention_mask, position_ids


def get_batch(
context_tokens,
micro_batch_size,
eod_token,
reset_position_ids=False,
reset_attention_mask=False,
):
"""Generate batch from context tokens."""
tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids.
attention_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
eod_token,
reset_position_ids,
reset_attention_mask,
)

return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
"""This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313"""

if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1,
None]
logits[indices_to_remove] = filter_value

if top_p > 0.0:
# Convert to 1D
sorted_logits, sorted_indices = torch.sort(
logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value

return logits


def pad_batch(batch, pad_id, seq_length):
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < seq_length:
tokens.extend([pad_id] * (seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths


def forward_step(
model,
tokens,
seq_length,
position_ids,
attention_mask,
layer_past=None,
get_key_value=None,
prompt_length=None,
context_length=None,
):
# Forward pass through the model.
output_tensor = model(
tokens,
position_ids,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length,
)

if get_key_value:
output_tensor, layer_past = output_tensor

if get_key_value:
return output_tensor, layer_past

return output_tensor


def get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
context_tokens,
return_scores: bool = False,
prompt_length: int = None,
micro_batch_size: int = None,
bad_ids: List = None,
temperature: float = 1.0,
topp: float = 1.0,
topk: int = 0,
greedy: bool = False,
):
context_tokens, context_lengths = pad_batch(context_tokens,
tokenizer.eos_token_id,
seq_length)

context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(
context_tokens_tensor,
micro_batch_size,
tokenizer.eos_token_id,
)

batch_token_iterator = sample_sequence_batch(
model,
tokenizer,
context_tokens_tensor,
context_length_tensor,
attention_mask,
position_ids,
seq_length=seq_length,
out_seq_length=out_seq_length,
return_scores=return_scores,
prompt_length=prompt_length,
bad_ids=bad_ids,
temperature=temperature,
topp=topp,
topk=topk,
greedy=greedy,
)

for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None


def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2


def sample_sequence_batch(
model,
tokenizer,
context_tokens,
context_lengths,
attention_mask,
position_ids,
seq_length,
out_seq_length,
maxlen=None,
return_scores: bool = False,
prompt_length: int = None,
bad_ids: List = None,
temperature: float = 1.0,
topp: float = 1.0,
topk: int = 0,
recompute: bool = False,
greedy: bool = False,
):
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
eos_id = tokenizer.eos_token_id

counter = 0
org_context_length = context_length

layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = seq_length - 1
if maxlen > (org_context_length + out_seq_length):
maxlen = org_context_length + out_seq_length

lengths = torch.ones([batch_size]).long().cuda() * maxlen
if return_scores:
scores = torch.zeros([batch_size]).float().cuda()

while context_length <= (maxlen):

if recompute:
logits = model(
tokens,
position_ids,
attention_mask,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, context_length - 1, :]
else:
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
logits, layer_past = model(
tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
prompt_length=prompt_length,
context_length=context_length,
)
logits = logits[:, -1].view(batch_size, -1).contiguous()

if bad_ids is not None:
for bad_id in bad_ids:
logits[:, bad_id] = -10000
if greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
if return_scores:
orig_log_probs = torch.log_softmax(logits, dim=-1)
logits /= temperature
logits = top_k_logits(logits, top_k=topk, top_p=topp)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)

started = context_lengths <= context_length

new_tokens = switch(tokens[:, context_length].view(-1), prev,
started)

if not greedy and return_scores:
indices = prev.view(-1, 1)
new_scores = orig_log_probs.gather(1, indices).view(-1)
new_scores = new_scores * started
new_scores = new_scores * is_done.bool().logical_not()
scores += new_scores

tokens[:, context_length] = new_tokens
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)

if return_scores:
yield tokens, (lengths, scores)
else:
yield tokens, lengths

context_length += 1
counter += 1
if done:
break

modelscope/models/nlp/codegeex/tokenizer.py (+186, -0)

@@ -0,0 +1,186 @@
from typing import List, Union

import torch
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast


def encode_whitespaces(text, start_extra_id: int, max_len: int):
""" Encode whitespaces to extra tokens in GPT-J.

>>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
"""

def push_acc_space(acc_len: int, text: str):
if acc_len == 0:
return text
if acc_len == 1:
return text + ' '
assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
extra_id = start_extra_id - 2 + acc_len
extra_token = f'<|extratoken_{extra_id}|>'
return text + extra_token

acc_len = 0
res = ''
for ch in text:
if ch == ' ':
acc_len += 1
if acc_len == max_len:
res = push_acc_space(acc_len, res)
acc_len = 0
else:
res = push_acc_space(acc_len, res)
acc_len = 0
res = res + ch

res = push_acc_space(acc_len, res)

return res


def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
""" Decode the whitespace-encoded strings produced by encode_whitespace.

>>> text = 'a\\n  b\\n   c'
>>> s, l = 10, 10
>>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
True
"""
for l in range(2, max_len + 1): # noqa
token_id = start_extra_id - 2 + l
token = f'<|extratoken_{token_id}|>'
text = text.replace(token, ' ' * l)
return text


class Code13BDictionary(object):

def __init__(
self,
dict_file: str,
extra_token_ids: List[str] = None,
pad_to_vocab_size: int = -1,
):
self._idx = dict()
self._count = dict()
self._num_symbols = 0
self._symbols = []

self._add_symbol('<s>', 0)
self._add_symbol('<pad>', 0)
self._add_symbol('</s>', 0)
self._add_symbol('<unk>', 0)
self._load_dict(dict_file)

if extra_token_ids is None:
extra_token_ids = [str(x) for x in range(50257, 50400)
] # follows GPT-J settings

for token_id in extra_token_ids:
self._add_symbol(token_id, 0)

if pad_to_vocab_size > 0:
self._pad_to_vocab_size(pad_to_vocab_size)

def _pad_to_vocab_size(self, vocab_size: int):
num_pad = vocab_size - len(self)
if num_pad <= 0:
return
for i in range(1, num_pad + 1):
self._add_symbol('vocab_pad_token{}'.format(i), 0)

def _load_dict(self, dict_file: str):
with open(dict_file, 'r') as f:
for line in f:
line = line.strip()
if line == '' or line.startswith('#'):
continue
sym, count = line.split()
self._add_symbol(sym, int(count))

def _add_symbol(self, sym: str, count: int):
self._idx[sym] = self._num_symbols
self._count[sym] = count
self._symbols.append(sym)
self._num_symbols += 1

def __len__(self):
return self._num_symbols

def index(self, sym: str):
return self._idx[sym]

def string(self, idx: int):
return self._symbols[idx]

def map_token(self, token: Union[int, str]):
if isinstance(token, int):
token = str(token)
return self.index(token)

def map_tokens(self, tokens):
return [self.map_token(token) for token in tokens]

def decode_tokens(self, tokens):
decoded = [
'50256' if token == 50256 else self.string(token)
for token in tokens
]
return [int(x) for x in decoded if not x.startswith('vocab_pad_token')]


class CodeGeeXTokenizer(object):

def __init__(
self,
tokenizer: GPT2TokenizerFast = None,
tokenizer_path: str = 'EleutherAI/gpt-j-6B',
start_extra_id: int = 10,
max_len: int = 10,
mode='codegeex-13b',
dict_file: str = None,
):
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
tokenizer_path)
if mode not in ['codegeex-13b', 'codegeex-python-13b']:
raise ValueError(
f"Invalid mode {mode}, choose from ['codegeex-13b', 'codegeex-python-13b']"
)
self.start_extra_id = start_extra_id
self.max_len = max_len
self.mode = mode
if dict_file is not None:
self.code_dict = Code13BDictionary(
dict_file, pad_to_vocab_size=51200
) if self.mode == 'codegeex-python-13b' else None
else:
self.code_dict = None
self.eos_token_id = self.tokenizer.eos_token_id

def encode_code(self, code: str):
if self.mode == 'codegeex-13b':
code = encode_whitespaces(code, self.start_extra_id, self.max_len)
input_ids = self.tokenizer(
code, is_split_into_words=False).input_ids

elif self.mode == 'codegeex-python-13b':
code = encode_whitespaces(code, self.start_extra_id, self.max_len)
input_ids = self.code_dict.map_tokens(self.tokenizer.encode(code))
input_ids = torch.LongTensor(input_ids).reshape(1, -1)

return input_ids

def decode_code(self, input_ids):
if self.mode == 'codegeex-13b':
text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
output_code = decode_whitespaces(text, self.start_extra_id,
self.max_len)
elif self.mode == 'codegeex-python-13b':
input_ids = [self.code_dict.decode_tokens(input_ids.tolist()[0])]
text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
output_code = decode_whitespaces(text, self.start_extra_id,
self.max_len)

return output_code
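As a quick illustration of the whitespace handling above, a round-trip sketch in 'codegeex-13b' mode. The tokenizer directory is a placeholder; it is assumed to hold the GPT-J style tokenizer that CodeGeeXForCodeTranslation loads from '<model_dir>/tokenizer'.

from modelscope.models.nlp.codegeex.tokenizer import CodeGeeXTokenizer

# '<model_dir>' is a placeholder for the downloaded model directory.
tokenizer = CodeGeeXTokenizer(
    tokenizer_path='<model_dir>/tokenizer', mode='codegeex-13b')

code = 'def add(a, b):\n    return a + b\n'
ids = tokenizer.encode_code(code)      # runs of spaces become <|extratoken_*|> ids
restored = tokenizer.decode_code(ids)  # decode_whitespaces() expands them back
print(restored == code)                # expected: True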

modelscope/pipelines/nlp/__init__.py (+3, -0)

@@ -32,6 +32,7 @@ if TYPE_CHECKING:
from .word_segmentation_pipeline import WordSegmentationPipeline
from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline
from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline
from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \
WordSegmentationThaiPipeline

@@ -73,6 +74,8 @@ else:
'zero_shot_classification_pipeline':
['ZeroShotClassificationPipeline'],
'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'],
'codegeex_code_translation_pipeline':
['CodeGeeXCodeTranslationPipeline'],
'multilingual_word_segmentation_pipeline': [
'MultilingualWordSegmentationPipeline',
'WordSegmentationThaiPipeline'


modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py (+44, -0)

@@ -0,0 +1,44 @@
# Copyright (c) 2022 Zhipu.AI

from typing import Any, Dict, Optional, Union

from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.models.nlp import CodeGeeXForCodeTranslation
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import CodeGeeXPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
group_key=Tasks.code_translation,
module_name=Pipelines.codegeex_code_translation)
class CodeGeeXCodeTranslationPipeline(Pipeline):

def __init__(self,
model: Union[CodeGeeXForCodeTranslation, str],
preprocessor: Optional[Preprocessor] = None,
*args,
**kwargs):
model = CodeGeeXForCodeTranslation(model) if isinstance(model, str) else model
self.model = model
self.model.eval()
self.model.half()
self.model.cuda()
if preprocessor is None:
preprocessor = CodeGeeXPreprocessor()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

# define the forward pass
def forward(self, inputs: Dict[str, str], **forward_params) -> Dict[str, Any]:
# check that the required input fields are present
for para in ['prompt', 'source language', 'target language']:
if para not in inputs:
raise ValueError(
f"missing input field '{para}', please check your input format.")
return self.model(inputs)

# format the outputs from pipeline
def postprocess(self, input, **kwargs) -> Dict[str, Any]:
return input

modelscope/preprocessors/__init__.py (+2, -2)

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
SentenceEmbeddingPreprocessor, SequenceClassificationPreprocessor,
TokenClassificationPreprocessor, TextErrorCorrectionPreprocessor,
TextGenerationPreprocessor, Text2TextGenerationPreprocessor, Tokenize,
WordSegmentationBlankSetToLabelPreprocessor,
WordSegmentationBlankSetToLabelPreprocessor, CodeGeeXPreprocessor,
MGLMSummarizationPreprocessor, ZeroShotClassificationPreprocessor,
TextGenerationJiebaPreprocessor, SentencePiecePreprocessor,
DialogIntentPredictionPreprocessor, DialogModelingPreprocessor,
@@ -57,7 +57,7 @@ else:
'TextErrorCorrectionPreprocessor', 'TextGenerationPreprocessor',
'Tokenize', 'Text2TextGenerationPreprocessor',
'WordSegmentationBlankSetToLabelPreprocessor',
'MGLMSummarizationPreprocessor',
'MGLMSummarizationPreprocessor', 'CodeGeeXPreprocessor',
'ZeroShotClassificationPreprocessor',
'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor',
'NERPreprocessorViet', 'NERPreprocessorThai',


modelscope/preprocessors/nlp/__init__.py (+2, -0)

@@ -30,6 +30,7 @@ if TYPE_CHECKING:
from .space_T_en import ConversationalTextToSqlPreprocessor
from .space_T_cn import TableQuestionAnsweringPreprocessor
from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor
from .codegeex_preprocessor import CodeGeeXPreprocessor
else:
_import_structure = {
'nlp_base': [
@@ -64,6 +65,7 @@ else:
'TextErrorCorrectionPreprocessor',
],
'mglm_summarization_preprocessor': ['MGLMSummarizationPreprocessor'],
'codegeex_preprocessor': ['CodeGeeXPreprocessor'],
'token_classification_thai_preprocessor': [
'NERPreprocessorThai',
'WordSegmentationPreprocessorThai',


modelscope/preprocessors/nlp/codegeex_preprocessor.py (+25, -0)

@@ -0,0 +1,25 @@
# Copyright (c) 2022 Zhipu.AI

import re
from typing import Any, Dict, Iterable, Optional, Tuple, Union

from modelscope.metainfo import Models, Preprocessors
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile
from modelscope.utils.type_assert import type_assert


@PREPROCESSORS.register_module(Fields.nlp, module_name=Preprocessors.codegeex)
class CodeGeeXPreprocessor(Preprocessor):

def __init__(self, *args, **kwargs):
"""preprocess the data
Args:
model_dir (str): model path
"""
super().__init__(*args, **kwargs)

@type_assert(object, (str, tuple, Dict))
def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]:
return data

modelscope/utils/constant.py (+1, -0)

@@ -120,6 +120,7 @@ class NLPTasks(object):
fill_mask = 'fill-mask'
text_summarization = 'text-summarization'
question_answering = 'question-answering'
code_translation = 'code-translation'
zero_shot_classification = 'zero-shot-classification'
backbone = 'backbone'
text_error_correction = 'text-error-correction'


tests/pipelines/test_CodeGeeX_code_translation.py (+38, -0)

@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import CodeGeeXPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class CodeGeeXCodeTranslationTest(unittest.TestCase, DemoCompatibilityCheck):

def setUp(self) -> None:
self.output_dir = 'unittest_output'
os.makedirs(self.output_dir, exist_ok=True)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_CodeGeeX_with_name(self):
model = 'ZhipuAI/CodeGeeX-Code-Translation-13B'
preprocessor = CodeGeeXPreprocessor()
pipe = pipeline(
task=Tasks.code_translation,
model=model,
preprocessor=preprocessor,
)
inputs = {
'prompt': 'for i in range(10):\n\tprint(i)\n',
'source language': 'Python',
'target language': 'C++'
}
result = pipe(inputs)
print(result)


if __name__ == '__main__':
unittest.main()
