CodeGeex code translation and generation ut failed due to a known run.py environment setup issue that is being fixed. nothing to do with the change itself.master^2
| @@ -84,6 +84,7 @@ class Models(object): | |||||
| ponet = 'ponet' | ponet = 'ponet' | ||||
| T5 = 'T5' | T5 = 'T5' | ||||
| mglm = 'mglm' | mglm = 'mglm' | ||||
| codegeex = 'codegeex' | |||||
| bloom = 'bloom' | bloom = 'bloom' | ||||
| # audio models | # audio models | ||||
| @@ -256,6 +257,8 @@ class Pipelines(object): | |||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| mglm_text_summarization = 'mglm-text-summarization' | mglm_text_summarization = 'mglm-text-summarization' | ||||
| codegeex_code_translation = 'codegeex-code-translation' | |||||
| codegeex_code_generation = 'codegeex-code-generation' | |||||
| translation_en_to_de = 'translation_en_to_de' # keep it underscore | translation_en_to_de = 'translation_en_to_de' # keep it underscore | ||||
| translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | ||||
| translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | ||||
| @@ -36,6 +36,7 @@ if TYPE_CHECKING: | |||||
| ) | ) | ||||
| from .T5 import T5ForConditionalGeneration | from .T5 import T5ForConditionalGeneration | ||||
| from .mglm import MGLMForTextSummarization | from .mglm import MGLMForTextSummarization | ||||
| from .codegeex import CodeGeeXForCodeTranslation, CodeGeeXForCodeGeneration | |||||
| from .task_models import ( | from .task_models import ( | ||||
| FeatureExtractionModel, | FeatureExtractionModel, | ||||
| InformationExtractionModel, | InformationExtractionModel, | ||||
| @@ -108,6 +109,8 @@ else: | |||||
| 'sentence_embedding': ['SentenceEmbedding'], | 'sentence_embedding': ['SentenceEmbedding'], | ||||
| 'T5': ['T5ForConditionalGeneration'], | 'T5': ['T5ForConditionalGeneration'], | ||||
| 'mglm': ['MGLMForTextSummarization'], | 'mglm': ['MGLMForTextSummarization'], | ||||
| 'codegeex': | |||||
| ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'], | |||||
| 'gpt_neo': ['GPTNeoModel'], | 'gpt_neo': ['GPTNeoModel'], | ||||
| 'bloom': ['BloomModel'], | 'bloom': ['BloomModel'], | ||||
| } | } | ||||
| @@ -0,0 +1,24 @@ | |||||
| # Modified by Zhipu.AI | |||||
| # Original Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING, Union | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .codegeex_for_code_translation import CodeGeeXForCodeTranslation | |||||
| from .codegeex_for_code_generation import CodeGeeXForCodeGeneration | |||||
| else: | |||||
| _import_structure = { | |||||
| 'codegeex_for_code_translation': ['CodeGeeXForCodeTranslation'], | |||||
| 'codegeex_for_code_generation': ['CodeGeeXForCodeGeneration'], | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,110 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import copy | |||||
| from typing import Any, Dict | |||||
| import torch | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .codegeex import CodeGeeXModel | |||||
| from .inference import get_token_stream | |||||
| from .tokenizer import CodeGeeXTokenizer | |||||
| def model_provider(): | |||||
| """Build the model.""" | |||||
| hidden_size = 5120 | |||||
| num_attention_heads = 40 | |||||
| num_layers = 39 | |||||
| padded_vocab_size = 52224 | |||||
| max_position_embeddings = 2048 | |||||
| model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads, | |||||
| padded_vocab_size, max_position_embeddings) | |||||
| return model | |||||
| @MODELS.register_module(Tasks.code_generation, module_name=Models.codegeex) | |||||
| class CodeGeeXForCodeGeneration(TorchModel): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the fast poem model from the `model_dir` path. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| logger = get_logger() | |||||
| # loading tokenizer | |||||
| logger.info('Loading tokenizer ...') | |||||
| self.tokenizer = CodeGeeXTokenizer( | |||||
| tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b') | |||||
| # loading model | |||||
| state_dict_path = model_dir + '/ckpt_ms_213000_fp32_52224.pt' | |||||
| logger.info('Loading state dict ...') | |||||
| state_dict = torch.load(state_dict_path, map_location='cpu') | |||||
| state_dict = state_dict['module'] | |||||
| logger.info('Building CodeGeeX model ...') | |||||
| self.model = model_provider() | |||||
| self.model.load_state_dict(state_dict) | |||||
| self.model.eval() | |||||
| self.model.half() | |||||
| self.model.cuda() | |||||
| def forward(self, input: Dict[str, str]) -> Dict[str, str]: | |||||
| micro_batch_size = 1 | |||||
| seq_length = 2048 | |||||
| out_seq_length = 256 | |||||
| bad_ids = None | |||||
| lang = input['language'] | |||||
| prompt = input['prompt'] | |||||
| prompt = f'# language: {lang}\n{prompt}' | |||||
| logger = get_logger() | |||||
| tokenizer = self.tokenizer | |||||
| model = self.model | |||||
| for prompt in [prompt]: | |||||
| tokens = tokenizer.encode_code(prompt) | |||||
| n_token_prompt = len(tokens) | |||||
| token_stream = get_token_stream( | |||||
| model, | |||||
| tokenizer, | |||||
| seq_length, | |||||
| out_seq_length, | |||||
| [copy.deepcopy(tokens) for _ in range(micro_batch_size)], | |||||
| micro_batch_size=micro_batch_size, | |||||
| bad_ids=bad_ids, | |||||
| topk=1, | |||||
| topp=0.9, | |||||
| temperature=0.9, | |||||
| greedy=True) | |||||
| is_finished = [False for _ in range(micro_batch_size)] | |||||
| for i, generated in enumerate(token_stream): | |||||
| generated_tokens = generated[0] | |||||
| for j in range(micro_batch_size): | |||||
| if is_finished[j]: | |||||
| continue | |||||
| if generated_tokens[j].cpu().numpy( | |||||
| )[-1] == tokenizer.eos_token_id or len( | |||||
| generated_tokens[j]) >= out_seq_length: | |||||
| is_finished[j] = True | |||||
| generated_tokens_ = generated_tokens[j].cpu().numpy( | |||||
| ).tolist() | |||||
| generated_code = tokenizer.decode_code( | |||||
| generated_tokens_[n_token_prompt:]) | |||||
| generated_code = ''.join(generated_code) | |||||
| logger.info( | |||||
| '================================= Generated code:' | |||||
| ) | |||||
| logger.info(generated_code) | |||||
| if all(is_finished): | |||||
| break | |||||
| logger.info('Generation finished.') | |||||
| return {OutputKeys.TEXT: generated_code} | |||||
| @@ -0,0 +1,109 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import copy | |||||
| from typing import Any, Dict | |||||
| import torch | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .codegeex import CodeGeeXModel | |||||
| from .inference import get_token_stream | |||||
| from .tokenizer import CodeGeeXTokenizer | |||||
| def model_provider(): | |||||
| """Build the model.""" | |||||
| hidden_size = 5120 | |||||
| num_attention_heads = 40 | |||||
| num_layers = 39 | |||||
| padded_vocab_size = 52224 | |||||
| max_position_embeddings = 2048 | |||||
| model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads, | |||||
| padded_vocab_size, max_position_embeddings) | |||||
| return model | |||||
| @MODELS.register_module(Tasks.code_translation, module_name=Models.codegeex) | |||||
| class CodeGeeXForCodeTranslation(TorchModel): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the fast poem model from the `model_dir` path. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| logger = get_logger() | |||||
| # loading tokenizer | |||||
| logger.info('Loading tokenizer ...') | |||||
| self.tokenizer = CodeGeeXTokenizer( | |||||
| tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b') | |||||
| # loading model | |||||
| state_dict_path = model_dir + '/ckpt_ms_translation_0817.pt' | |||||
| logger.info('Loading state dict ...') | |||||
| state_dict = torch.load(state_dict_path, map_location='cpu') | |||||
| state_dict = state_dict['module'] | |||||
| logger.info('Building CodeGeeX model ...') | |||||
| self.model = model_provider() | |||||
| self.model.load_state_dict(state_dict) | |||||
| self.model.eval() | |||||
| self.model.half() | |||||
| self.model.cuda() | |||||
| def forward(self, input: Dict[str, str]) -> Dict[str, str]: | |||||
| micro_batch_size = 1 | |||||
| seq_length = 2048 | |||||
| out_seq_length = 256 | |||||
| bad_ids = None | |||||
| src_lang = input['source language'] | |||||
| dst_lang = input['target language'] | |||||
| prompt = input['prompt'] | |||||
| prompt = f'code translation\n{src_lang}:\n{prompt}\n{dst_lang}:\n' | |||||
| logger = get_logger() | |||||
| tokenizer = self.tokenizer | |||||
| model = self.model | |||||
| for prompt in [prompt]: | |||||
| tokens = tokenizer.encode_code(prompt) | |||||
| n_token_prompt = len(tokens) | |||||
| token_stream = get_token_stream( | |||||
| model, | |||||
| tokenizer, | |||||
| seq_length, | |||||
| out_seq_length, | |||||
| [copy.deepcopy(tokens) for _ in range(micro_batch_size)], | |||||
| micro_batch_size=micro_batch_size, | |||||
| bad_ids=bad_ids, | |||||
| greedy=True, | |||||
| ) | |||||
| is_finished = [False for _ in range(micro_batch_size)] | |||||
| for i, generated in enumerate(token_stream): | |||||
| generated_tokens = generated[0] | |||||
| for j in range(micro_batch_size): | |||||
| if is_finished[j]: | |||||
| continue | |||||
| if generated_tokens[j].cpu().numpy( | |||||
| )[-1] == tokenizer.eos_token_id or len( | |||||
| generated_tokens[j]) >= out_seq_length: | |||||
| is_finished[j] = True | |||||
| generated_tokens_ = generated_tokens[j].cpu().numpy( | |||||
| ).tolist() | |||||
| generated_code = tokenizer.decode_code( | |||||
| generated_tokens_[n_token_prompt:]) | |||||
| generated_code = ''.join(generated_code) | |||||
| logger.info( | |||||
| '================================= Generated code:' | |||||
| ) | |||||
| logger.info(generated_code) | |||||
| if all(is_finished): | |||||
| break | |||||
| logger.info('Generation finished.') | |||||
| return {OutputKeys.TEXT: generated_code} | |||||
| @@ -0,0 +1,301 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| from typing import List | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| def get_ltor_masks_and_position_ids( | |||||
| data, | |||||
| eod_token, | |||||
| reset_position_ids, | |||||
| reset_attention_mask, | |||||
| ): | |||||
| """Build masks and position id for left to right model.""" | |||||
| # Extract batch size and sequence length. | |||||
| micro_batch_size, seq_length = data.size() | |||||
| # Attention mask (lower triangular). | |||||
| if reset_attention_mask: | |||||
| att_mask_batch = micro_batch_size | |||||
| else: | |||||
| att_mask_batch = 1 | |||||
| attention_mask = torch.tril( | |||||
| torch.ones((att_mask_batch, seq_length, seq_length), | |||||
| device=data.device)).view(att_mask_batch, 1, seq_length, | |||||
| seq_length) | |||||
| # Position ids. | |||||
| position_ids = torch.arange( | |||||
| seq_length, dtype=torch.long, device=data.device) | |||||
| position_ids = position_ids.unsqueeze(0).expand_as(data) | |||||
| # We need to clone as the ids will be modifed based on batch index. | |||||
| if reset_position_ids: | |||||
| position_ids = position_ids.clone() | |||||
| if reset_position_ids or reset_attention_mask: | |||||
| # Loop through the batches: | |||||
| for b in range(micro_batch_size): | |||||
| # Find indecies where EOD token is. | |||||
| eod_index = position_ids[b, data[b] == eod_token] | |||||
| # Detach indecies from positions if going to modify positions. | |||||
| if reset_position_ids: | |||||
| eod_index = eod_index.clone() | |||||
| # Loop through EOD indecies: | |||||
| prev_index = 0 | |||||
| for j in range(eod_index.size()[0]): | |||||
| i = eod_index[j] | |||||
| # Mask attention loss. | |||||
| if reset_attention_mask: | |||||
| attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 | |||||
| # Reset positions. | |||||
| if reset_position_ids: | |||||
| position_ids[b, (i + 1):] -= i + 1 - prev_index | |||||
| prev_index = i + 1 | |||||
| # Convert attention mask to binary: | |||||
| attention_mask = attention_mask < 0.5 | |||||
| return attention_mask, position_ids | |||||
| def get_batch( | |||||
| context_tokens, | |||||
| micro_batch_size, | |||||
| eod_token, | |||||
| reset_position_ids=False, | |||||
| reset_attention_mask=False, | |||||
| ): | |||||
| """Generate batch from context tokens.""" | |||||
| tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda() | |||||
| # Get the attention mask and postition ids. | |||||
| attention_mask, position_ids = get_ltor_masks_and_position_ids( | |||||
| tokens, | |||||
| eod_token, | |||||
| reset_position_ids, | |||||
| reset_attention_mask, | |||||
| ) | |||||
| return tokens, attention_mask, position_ids | |||||
| def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): | |||||
| """This function has been mostly taken from huggingface conversational | |||||
| ai code at | |||||
| https://medium.com/huggingface/how-to-build-a-state-of-the-art- | |||||
| conversational-ai-with-transfer-learning-2d818ac26313""" | |||||
| if top_k > 0: | |||||
| # Remove all tokens with a probability less than the | |||||
| # last token of the top-k | |||||
| indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, | |||||
| None] | |||||
| logits[indices_to_remove] = filter_value | |||||
| if top_p > 0.0: | |||||
| # Cconvert to 1D | |||||
| sorted_logits, sorted_indices = torch.sort( | |||||
| logits, descending=True, dim=-1) | |||||
| cumulative_probs = torch.cumsum( | |||||
| F.softmax(sorted_logits, dim=-1), dim=-1) | |||||
| # Remove tokens with cumulative probability above the threshold | |||||
| sorted_indices_to_remove = cumulative_probs > top_p | |||||
| # Shift the indices to the right to keep also the first token | |||||
| # above the threshold | |||||
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ | |||||
| ..., :-1].clone() | |||||
| sorted_indices_to_remove[..., 0] = 0 | |||||
| for i in range(sorted_indices.size(0)): | |||||
| indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] | |||||
| logits[i][indices_to_remove] = filter_value | |||||
| return logits | |||||
| def pad_batch(batch, pad_id, seq_length): | |||||
| context_lengths = [] | |||||
| for tokens in batch: | |||||
| context_length = len(tokens) | |||||
| if context_length < seq_length: | |||||
| tokens.extend([pad_id] * (seq_length - context_length)) | |||||
| context_lengths.append(context_length) | |||||
| return batch, context_lengths | |||||
| def get_token_stream( | |||||
| model, | |||||
| tokenizer, | |||||
| seq_length, | |||||
| out_seq_length, | |||||
| context_tokens, | |||||
| return_scores: bool = False, | |||||
| prompt_length: int = None, | |||||
| micro_batch_size: int = None, | |||||
| bad_ids: List = None, | |||||
| temperature: float = 1.0, | |||||
| topp: float = 1.0, | |||||
| topk: int = 0.0, | |||||
| greedy: bool = False, | |||||
| ): | |||||
| context_tokens, context_lengths = pad_batch(context_tokens, | |||||
| tokenizer.eos_token_id, | |||||
| seq_length) | |||||
| context_tokens_tensor = torch.cuda.LongTensor(context_tokens) | |||||
| context_length_tensor = torch.cuda.LongTensor(context_lengths) | |||||
| context_length = context_length_tensor.min().item() | |||||
| tokens, attention_mask, position_ids = get_batch( | |||||
| context_tokens_tensor, | |||||
| micro_batch_size, | |||||
| tokenizer.eos_token_id, | |||||
| ) | |||||
| batch_token_iterator = sample_sequence_batch( | |||||
| model, | |||||
| tokenizer, | |||||
| context_tokens_tensor, | |||||
| context_length_tensor, | |||||
| attention_mask, | |||||
| position_ids, | |||||
| seq_length=seq_length, | |||||
| out_seq_length=out_seq_length, | |||||
| return_scores=return_scores, | |||||
| prompt_length=prompt_length, | |||||
| bad_ids=bad_ids, | |||||
| temperature=temperature, | |||||
| topp=topp, | |||||
| topk=topk, | |||||
| greedy=greedy, | |||||
| ) | |||||
| for tokens, lengths in batch_token_iterator: | |||||
| context_length += 1 | |||||
| if tokens is not None: | |||||
| yield tokens[:, :context_length], lengths | |||||
| else: | |||||
| yield None, None | |||||
| def switch(val1, val2, boolean): | |||||
| boolean = boolean.type_as(val1) | |||||
| return (1 - boolean) * val1 + boolean * val2 | |||||
| def sample_sequence_batch( | |||||
| model, | |||||
| tokenizer, | |||||
| context_tokens, | |||||
| context_lengths, | |||||
| attention_mask, | |||||
| position_ids, | |||||
| seq_length, | |||||
| out_seq_length, | |||||
| maxlen=None, | |||||
| return_scores: bool = False, | |||||
| prompt_length: int = None, | |||||
| bad_ids: List = None, | |||||
| temperature: float = 1.0, | |||||
| topp: float = 1.0, | |||||
| topk: int = 0.0, | |||||
| recompute: bool = False, | |||||
| greedy: bool = False, | |||||
| ): | |||||
| model.eval() | |||||
| with torch.no_grad(): | |||||
| context_length = context_lengths.min().item() | |||||
| eos_id = tokenizer.eos_token_id | |||||
| counter = 0 | |||||
| org_context_length = context_length | |||||
| layer_past = None | |||||
| batch_size = context_tokens.size(0) | |||||
| is_done = torch.zeros([batch_size]).byte().cuda() | |||||
| tokens = context_tokens | |||||
| if maxlen is None: | |||||
| maxlen = seq_length - 1 | |||||
| if maxlen > (org_context_length + out_seq_length): | |||||
| maxlen = org_context_length + out_seq_length | |||||
| lengths = torch.ones([batch_size]).long().cuda() * maxlen | |||||
| if return_scores: | |||||
| scores = torch.zeros([batch_size]).float().cuda() | |||||
| while context_length <= (maxlen): | |||||
| if recompute: | |||||
| logits = model( | |||||
| tokens, | |||||
| position_ids, | |||||
| attention_mask, | |||||
| prompt_length=prompt_length, | |||||
| context_length=context_length, | |||||
| ) | |||||
| logits = logits[:, context_length - 1, :] | |||||
| else: | |||||
| if counter == 0: | |||||
| tokens2use = tokens[:, :context_length] | |||||
| positions2use = position_ids[:, :context_length] | |||||
| else: | |||||
| tokens2use = tokens[:, context_length - 1].view( | |||||
| batch_size, -1) | |||||
| positions2use = position_ids[:, context_length - 1].view( | |||||
| batch_size, -1) | |||||
| logits, layer_past = model( | |||||
| tokens2use, | |||||
| positions2use, | |||||
| attention_mask, | |||||
| layer_past=layer_past, | |||||
| get_key_value=True, | |||||
| prompt_length=prompt_length, | |||||
| context_length=context_length, | |||||
| ) | |||||
| logits = logits[:, -1].view(batch_size, -1).contiguous() | |||||
| if bad_ids is not None: | |||||
| for bad_id in bad_ids: | |||||
| logits[:, bad_id] = -10000 | |||||
| if greedy: | |||||
| prev = torch.argmax(logits, dim=-1).view(-1) | |||||
| else: | |||||
| logits = logits.float() | |||||
| if return_scores: | |||||
| orig_log_probs = torch.log_softmax(logits, dim=-1) | |||||
| logits /= temperature | |||||
| logits = top_k_logits(logits, top_k=topk, top_p=topp) | |||||
| log_probs = F.softmax(logits, dim=-1) | |||||
| prev = torch.multinomial(log_probs, num_samples=1).view(-1) | |||||
| started = context_lengths <= context_length | |||||
| new_tokens = switch(tokens[:, context_length].view(-1), prev, | |||||
| started) | |||||
| if not greedy and return_scores: | |||||
| indices = prev.view(-1, 1) | |||||
| new_scores = orig_log_probs.gather(1, indices).view(-1) | |||||
| new_scores = new_scores * started | |||||
| new_scores = new_scores * is_done.bool().logical_not() | |||||
| scores += new_scores | |||||
| tokens[:, context_length] = new_tokens | |||||
| done_token = (prev == eos_id).byte() & started.byte() | |||||
| just_finished = (done_token & ~is_done).bool() | |||||
| lengths[just_finished.view(-1)] = context_length | |||||
| is_done = is_done | done_token | |||||
| done = torch.all(is_done) | |||||
| if return_scores: | |||||
| yield tokens, (lengths, scores) | |||||
| else: | |||||
| yield tokens, lengths | |||||
| context_length += 1 | |||||
| counter += 1 | |||||
| if done: | |||||
| break | |||||
| @@ -0,0 +1,187 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| from typing import List, Union | |||||
| import torch | |||||
| from transformers import AutoTokenizer | |||||
| from transformers.models.gpt2 import GPT2TokenizerFast | |||||
| def encode_whitespaces(text, start_extra_id: int, max_len: int): | |||||
| """ Encode whitespaces to extra tokens in GPT-J. | |||||
| >>> encode_whitespaces('a\\n b\\n c', 10, 10) | |||||
| 'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c' | |||||
| """ | |||||
| def push_acc_space(acc_len: int, text: str): | |||||
| if acc_len == 0: | |||||
| return text | |||||
| if acc_len == 1: | |||||
| return text + ' ' | |||||
| assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}' | |||||
| extra_id = start_extra_id - 2 + acc_len | |||||
| extra_token = f'<|extratoken_{extra_id}|>' | |||||
| return text + extra_token | |||||
| acc_len = 0 | |||||
| res = '' | |||||
| for ch in text: | |||||
| if ch == ' ': | |||||
| acc_len += 1 | |||||
| if acc_len == max_len: | |||||
| res = push_acc_space(acc_len, res) | |||||
| acc_len = 0 | |||||
| else: | |||||
| res = push_acc_space(acc_len, res) | |||||
| acc_len = 0 | |||||
| res = res + ch | |||||
| res = push_acc_space(acc_len, res) | |||||
| return res | |||||
| def decode_whitespaces(text: str, start_extra_id: int, max_len: int): | |||||
| """ Decode the whitespace-encoded strings produced by encode_whitespace. | |||||
| >>> text = 'a\\n b\\n c' | |||||
| >>> s, l = 10, 10 | |||||
| >>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l) | |||||
| True | |||||
| """ | |||||
| for l in range(2, max_len + 1): # noqa | |||||
| token_id = start_extra_id - 2 + l | |||||
| token = f'<|extratoken_{token_id}|>' | |||||
| text = text.replace(token, ' ' * l) | |||||
| return text | |||||
| class Code13BDictionary(object): | |||||
| def __init__( | |||||
| self, | |||||
| dict_file: str, | |||||
| extra_token_ids: List[str] = None, | |||||
| pad_to_vocab_size: int = -1, | |||||
| ): | |||||
| self._idx = dict() | |||||
| self._count = dict() | |||||
| self._num_symbols = 0 | |||||
| self._symbols = [] | |||||
| self._add_symbol('<s>', 0) | |||||
| self._add_symbol('<pad>', 0) | |||||
| self._add_symbol('</s>', 0) | |||||
| self._add_symbol('<unk>', 0) | |||||
| self._load_dict(dict_file) | |||||
| if extra_token_ids is None: | |||||
| extra_token_ids = [str(x) for x in range(50257, 50400) | |||||
| ] # follows GPT-J settings | |||||
| for token_id in extra_token_ids: | |||||
| self._add_symbol(token_id, 0) | |||||
| if pad_to_vocab_size > 0: | |||||
| self._pad_to_vocab_size(pad_to_vocab_size) | |||||
| def _pad_to_vocab_size(self, vocab_size: int): | |||||
| num_pad = vocab_size - len(self) | |||||
| if num_pad <= 0: | |||||
| return | |||||
| for i in range(1, num_pad + 1): | |||||
| self._add_symbol('vocab_pad_token{}'.format(i), 0) | |||||
| def _load_dict(self, dict_file: str): | |||||
| with open(dict_file, 'r') as f: | |||||
| for line in f: | |||||
| line = line.strip() | |||||
| if line == '' or line.startswith('#'): | |||||
| continue | |||||
| sym, count = line.split() | |||||
| self._add_symbol(sym, int(count)) | |||||
| def _add_symbol(self, sym: str, count: int): | |||||
| self._idx[sym] = self._num_symbols | |||||
| self._count[sym] = count | |||||
| self._symbols.append(sym) | |||||
| self._num_symbols += 1 | |||||
| def __len__(self): | |||||
| return self._num_symbols | |||||
| def index(self, sym: str): | |||||
| return self._idx[sym] | |||||
| def string(self, idx: int): | |||||
| return self._symbols[idx] | |||||
| def map_token(self, token: Union[int, str]): | |||||
| if isinstance(token, int): | |||||
| token = str(token) | |||||
| return self.index(token) | |||||
| def map_tokens(self, tokens): | |||||
| return [self.map_token(token) for token in tokens] | |||||
| def decode_tokens(self, tokens): | |||||
| decoded = [ | |||||
| '50256' if token == 50256 else self.string(token) | |||||
| for token in tokens | |||||
| ] | |||||
| return [int(x) for x in decoded if not x.startswith('vocab_pad_token')] | |||||
| class CodeGeeXTokenizer(object): | |||||
| def __init__( | |||||
| self, | |||||
| tokenizer: GPT2TokenizerFast = None, | |||||
| tokenizer_path: str = 'EleutherAI/gpt-j-6B', | |||||
| start_extra_id: int = 10, | |||||
| max_len: int = 10, | |||||
| mode='codegeex-13b', | |||||
| dict_file: str = None, | |||||
| ): | |||||
| self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained( | |||||
| tokenizer_path) | |||||
| if mode not in ['codegeex-13b', 'codegeex-python-13b']: | |||||
| raise ValueError( | |||||
| f"Invalid mode {mode}, choose from ['codegeex-13b', 'codegeex-python-13b']" | |||||
| ) | |||||
| self.start_extra_id = start_extra_id | |||||
| self.max_len = max_len | |||||
| self.mode = mode | |||||
| if dict_file is not None: | |||||
| self.code_dict = Code13BDictionary( | |||||
| dict_file, pad_to_vocab_size=51200 | |||||
| ) if self.mode == 'codegeex-python-13b' else None | |||||
| else: | |||||
| self.code_dict = None | |||||
| self.eos_token_id = self.tokenizer.eos_token_id | |||||
| def encode_code(self, code: str): | |||||
| if self.mode == 'codegeex-13b': | |||||
| code = encode_whitespaces(code, self.start_extra_id, self.max_len) | |||||
| input_ids = self.tokenizer( | |||||
| code, is_split_into_words=False).input_ids | |||||
| elif self.mode == 'codegeex-python-13b': | |||||
| code = encode_whitespaces(code, self.start_extra_id, self.max_len) | |||||
| input_ids = self.code_dict.map_tokens(self.tokenizer.encode(code)) | |||||
| input_ids = torch.LongTensor(input_ids).reshape(1, -1) | |||||
| return input_ids | |||||
| def decode_code(self, input_ids): | |||||
| if self.mode == 'codegeex-13b': | |||||
| text = self.tokenizer.decode(input_ids, skip_special_tokens=False) | |||||
| output_code = decode_whitespaces(text, self.start_extra_id, | |||||
| self.max_len) | |||||
| elif self.mode == 'codegeex-python-13b': | |||||
| input_ids = [self.code_dict.decode_tokens(input_ids.tolist()[0])] | |||||
| text = self.tokenizer.decode(input_ids, skip_special_tokens=False) | |||||
| output_code = decode_whitespaces(text, self.start_extra_id, | |||||
| self.max_len) | |||||
| return output_code | |||||
| @@ -32,6 +32,8 @@ if TYPE_CHECKING: | |||||
| from .word_segmentation_pipeline import WordSegmentationPipeline | from .word_segmentation_pipeline import WordSegmentationPipeline | ||||
| from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | ||||
| from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline | from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline | ||||
| from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline | |||||
| from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline | |||||
| from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ | from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ | ||||
| WordSegmentationThaiPipeline | WordSegmentationThaiPipeline | ||||
| @@ -73,6 +75,10 @@ else: | |||||
| 'zero_shot_classification_pipeline': | 'zero_shot_classification_pipeline': | ||||
| ['ZeroShotClassificationPipeline'], | ['ZeroShotClassificationPipeline'], | ||||
| 'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'], | 'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'], | ||||
| 'codegeex_code_translation_pipeline': | |||||
| ['CodeGeeXCodeTranslationPipeline'], | |||||
| 'codegeex_code_generation_pipeline': | |||||
| ['CodeGeeXCodeGenerationPipeline'], | |||||
| 'multilingual_word_segmentation_pipeline': [ | 'multilingual_word_segmentation_pipeline': [ | ||||
| 'MultilingualWordSegmentationPipeline', | 'MultilingualWordSegmentationPipeline', | ||||
| 'WordSegmentationThaiPipeline' | 'WordSegmentationThaiPipeline' | ||||
| @@ -0,0 +1,55 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| from typing import Any, Dict, Union | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.nlp import CodeGeeXForCodeGeneration | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import Preprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| @PIPELINES.register_module( | |||||
| group_key=Tasks.code_generation, | |||||
| module_name=Pipelines.codegeex_code_generation) | |||||
| class CodeGeeXCodeGenerationPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[CodeGeeXForCodeGeneration, str], | |||||
| preprocessor: [Preprocessor] = None, | |||||
| *args, | |||||
| **kwargs): | |||||
| model = CodeGeeXForCodeGeneration(model) if isinstance(model, | |||||
| str) else model | |||||
| self.model = model | |||||
| self.model.eval() | |||||
| self.model.half() | |||||
| self.model.cuda() | |||||
| super().__init__(model=model, **kwargs) | |||||
| def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: | |||||
| return inputs | |||||
| # define the forward pass | |||||
| def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]: | |||||
| # check input format | |||||
| for para in ['prompt', 'language']: | |||||
| if para not in inputs: | |||||
| raise Exception('Please check your input format.') | |||||
| if inputs['language'] not in [ | |||||
| 'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', | |||||
| 'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', | |||||
| 'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', | |||||
| 'Pascal', 'R', 'Fortran', 'Lean' | |||||
| ]: # noqa | |||||
| raise Exception( | |||||
| 'Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa | |||||
| ) # noqa | |||||
| return self.model(inputs) | |||||
| # format the outputs from pipeline | |||||
| def postprocess(self, input, **kwargs) -> Dict[str, Any]: | |||||
| return input | |||||
| @@ -0,0 +1,65 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| from typing import Any, Dict, Union | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.nlp import CodeGeeXForCodeTranslation | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import Preprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| @PIPELINES.register_module( | |||||
| group_key=Tasks.code_translation, | |||||
| module_name=Pipelines.codegeex_code_translation) | |||||
| class CodeGeeXCodeTranslationPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[CodeGeeXForCodeTranslation, str], | |||||
| preprocessor: [Preprocessor] = None, | |||||
| *args, | |||||
| **kwargs): | |||||
| model = CodeGeeXForCodeTranslation(model) if isinstance(model, | |||||
| str) else model | |||||
| self.model = model | |||||
| self.model.eval() | |||||
| self.model.half() | |||||
| self.model.cuda() | |||||
| super().__init__(model=model, **kwargs) | |||||
| def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: | |||||
| return inputs | |||||
| # define the forward pass | |||||
| def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]: | |||||
| # check input format | |||||
| for para in ['prompt', 'source language', 'target language']: | |||||
| if para not in inputs: | |||||
| raise Exception('please check your input format.') | |||||
| if inputs['source language'] not in [ | |||||
| 'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', | |||||
| 'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', | |||||
| 'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', | |||||
| 'Pascal', 'R', 'Fortran', 'Lean' | |||||
| ]: | |||||
| raise Exception( | |||||
| 'Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa | |||||
| ) # noqa | |||||
| if inputs['target language'] not in [ | |||||
| 'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', | |||||
| 'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', | |||||
| 'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', | |||||
| 'Pascal', 'R', 'Fortran', 'Lean' | |||||
| ]: | |||||
| raise Exception( | |||||
| 'Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa | |||||
| ) # noqa | |||||
| return self.model(inputs) | |||||
| # format the outputs from pipeline | |||||
| def postprocess(self, input, **kwargs) -> Dict[str, Any]: | |||||
| return input | |||||
| @@ -23,7 +23,7 @@ if TYPE_CHECKING: | |||||
| SentenceEmbeddingPreprocessor, SequenceClassificationPreprocessor, | SentenceEmbeddingPreprocessor, SequenceClassificationPreprocessor, | ||||
| TokenClassificationPreprocessor, TextErrorCorrectionPreprocessor, | TokenClassificationPreprocessor, TextErrorCorrectionPreprocessor, | ||||
| TextGenerationPreprocessor, Text2TextGenerationPreprocessor, Tokenize, | TextGenerationPreprocessor, Text2TextGenerationPreprocessor, Tokenize, | ||||
| WordSegmentationBlankSetToLabelPreprocessor, | |||||
| WordSegmentationBlankSetToLabelPreprocessor, CodeGeeXPreprocessor, | |||||
| MGLMSummarizationPreprocessor, ZeroShotClassificationPreprocessor, | MGLMSummarizationPreprocessor, ZeroShotClassificationPreprocessor, | ||||
| TextGenerationJiebaPreprocessor, SentencePiecePreprocessor, | TextGenerationJiebaPreprocessor, SentencePiecePreprocessor, | ||||
| DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, | DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, | ||||
| @@ -57,7 +57,7 @@ else: | |||||
| 'TextErrorCorrectionPreprocessor', 'TextGenerationPreprocessor', | 'TextErrorCorrectionPreprocessor', 'TextGenerationPreprocessor', | ||||
| 'Tokenize', 'Text2TextGenerationPreprocessor', | 'Tokenize', 'Text2TextGenerationPreprocessor', | ||||
| 'WordSegmentationBlankSetToLabelPreprocessor', | 'WordSegmentationBlankSetToLabelPreprocessor', | ||||
| 'MGLMSummarizationPreprocessor', | |||||
| 'MGLMSummarizationPreprocessor', 'CodeGeeXPreprocessor', | |||||
| 'ZeroShotClassificationPreprocessor', | 'ZeroShotClassificationPreprocessor', | ||||
| 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', | 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', | ||||
| 'NERPreprocessorViet', 'NERPreprocessorThai', | 'NERPreprocessorViet', 'NERPreprocessorThai', | ||||
| @@ -120,6 +120,8 @@ class NLPTasks(object): | |||||
| fill_mask = 'fill-mask' | fill_mask = 'fill-mask' | ||||
| text_summarization = 'text-summarization' | text_summarization = 'text-summarization' | ||||
| question_answering = 'question-answering' | question_answering = 'question-answering' | ||||
| code_translation = 'code-translation' | |||||
| code_generation = 'code-generation' | |||||
| zero_shot_classification = 'zero-shot-classification' | zero_shot_classification = 'zero-shot-classification' | ||||
| backbone = 'backbone' | backbone = 'backbone' | ||||
| text_error_correction = 'text-error-correction' | text_error_correction = 'text-error-correction' | ||||