@@ -1,17 +1,23 @@ | |||
if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | |||
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple | |||
pip install -r requirements/tests.txt | |||
git config --global --add safe.directory /Maas-lib | |||
# run linter test first | |||
git config --global user.email tmp | |||
git config --global user.name tmp.com | |||
# linter test | |||
# use the internal pre-commit config because of network restrictions
pre-commit run -c .pre-commit-config_local.yaml --all-files | |||
if [ $? -ne 0 ]; then | |||
echo "linter test failed" | |||
echo "From the repository folder" | |||
echo "Run 'pip install -r requirements/tests.txt' install test dependencies." | |||
echo "Run 'pre-commit install' install pre-commit hooks." | |||
echo "Finally run linter with command: 'pre-commit run --all-files' to check." | |||
echo "Ensure there is no failure!!!!!!!!" | |||
exit -1 | |||
if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then | |||
pre-commit run -c .pre-commit-config_local.yaml --all-files | |||
if [ $? -ne 0 ]; then | |||
echo "linter test failed, please run 'pre-commit run --all-files' to check" | |||
echo "From the repository folder" | |||
echo "Run 'pip install -r requirements/tests.txt' install test dependencies." | |||
echo "Run 'pre-commit install' install pre-commit hooks." | |||
echo "Finally run linter with command: 'pre-commit run --all-files' to check." | |||
echo "Ensure there is no failure!!!!!!!!" | |||
exit -1 | |||
fi | |||
fi | |||
awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
@@ -0,0 +1,64 @@ | |||
name: citest | |||
on: | |||
push: | |||
branches: | |||
- master | |||
- "release/**" | |||
paths-ignore: | |||
- "setup.*" | |||
- "requirements.txt" | |||
- "requirements/**" | |||
- "docs/**" | |||
- "tools/**" | |||
- ".dev_scripts/**" | |||
- "README.md" | |||
- "README_zh-CN.md" | |||
- "NOTICE" | |||
- ".github/workflows/lint.yaml" | |||
- ".github/workflows/publish.yaml" | |||
pull_request: | |||
paths-ignore: | |||
- "setup.*" | |||
- "requirements.txt" | |||
- "requirements/**" | |||
- "docs/**" | |||
- "tools/**" | |||
- ".dev_scripts/**" | |||
- "README.md" | |||
- "README_zh-CN.md" | |||
- "NOTICE" | |||
- ".github/workflows/lint.yaml" | |||
- ".github/workflows/publish.yaml" | |||
concurrency: | |||
group: ${{ github.workflow }}-${{ github.ref }} | |||
cancel-in-progress: true | |||
jobs: | |||
unittest: | |||
# The type of runner that the job will run on | |||
runs-on: [modelscope-self-hosted] | |||
steps: | |||
- name: ResetFileMode | |||
shell: bash | |||
run: | | |||
# reset filemode to allow action runner to delete files | |||
# generated by root in docker | |||
set -e | |||
source ~/.bashrc | |||
sudo chown -R $USER:$USER $ACTION_RUNNER_DIR | |||
- name: Checkout | |||
uses: actions/checkout@v2 | |||
with: | |||
lfs: 'true' | |||
- name: Checkout LFS objects | |||
run: git lfs checkout | |||
- name: Run unittest | |||
shell: bash | |||
run: | | |||
set -e | |||
source /mnt/modelscope/ci_env.sh | |||
bash .dev_scripts/dockerci.sh |
@@ -0,0 +1,22 @@ | |||
name: Lint test | |||
on: [push, pull_request] | |||
concurrency: | |||
group: ${{ github.workflow }}-${{ github.ref }} | |||
cancel-in-progress: true | |||
jobs: | |||
lint: | |||
runs-on: ubuntu-latest | |||
steps: | |||
- uses: actions/checkout@v2 | |||
- name: Set up Python 3.7 | |||
uses: actions/setup-python@v2 | |||
with: | |||
python-version: 3.7 | |||
- name: Install pre-commit hook | |||
run: | | |||
pip install pre-commit | |||
- name: Linting | |||
run: pre-commit run --all-files |
@@ -44,7 +44,7 @@ There are mainly three test levels: | |||
* level 2: scenario tests for all the implemented modules, such as models and pipelines in different algorithm fields.
The default test level is 0, which runs only level-0 cases. You can set the test level
via environment variable `TEST_LEVEL`. For more details, you can refer to [test-doc](https://alidocs.dingtalk.com/i/nodes/mdvQnONayjBJKLXy1Bp38PY2MeXzp5o0?dontjump=true&nav=spaces&navQuery=spaceId%3Dnb9XJNlZxbgrOXyA) | |||
via environment variable `TEST_LEVEL`. | |||
```bash | |||
@@ -159,9 +159,7 @@ git pull origin branch_name | |||
git push --set-upstream origin dev/my-dev-branch | |||
``` | |||
Note that you can push to the same branch multiple times later with plain 'git push'.
5. Open the remote URL `https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/new` to create a new merge request that merges your development branch ("dev/my-dev-branch" in this example) into the master branch. Follow the instructions on the Aone page to submit the merge request for code review.
5. Create a pull request on GitHub to merge your code into the master branch.
## Build pip package | |||
```bash | |||
@@ -74,7 +74,7 @@ pip install "modelscope[multi-modal]" -f https://modelscope.oss-cn-beijing.aliyu | |||
The ModelScope source code can be cloned directly to your local machine:
```shell | |||
git clone git@gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib.git modelscope | |||
git clone git@github.com:modelscope/modelscope.git | |||
cd modelscope | |||
git fetch origin master | |||
git checkout master | |||
@@ -86,6 +86,7 @@ class Models(object): | |||
ponet = 'ponet' | |||
T5 = 'T5' | |||
mglm = 'mglm' | |||
codegeex = 'codegeex' | |||
bloom = 'bloom' | |||
# audio models | |||
@@ -94,6 +95,7 @@ class Models(object): | |||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||
kws_kwsbp = 'kws-kwsbp' | |||
generic_asr = 'generic-asr' | |||
wenet_asr = 'wenet-asr' | |||
# multi-modal models | |||
ofa = 'ofa' | |||
@@ -261,6 +263,8 @@ class Pipelines(object): | |||
extractive_summarization = 'extractive-summarization' | |||
feature_extraction = 'feature-extraction' | |||
mglm_text_summarization = 'mglm-text-summarization' | |||
codegeex_code_translation = 'codegeex-code-translation' | |||
codegeex_code_generation = 'codegeex-code-generation' | |||
translation_en_to_de = 'translation_en_to_de' # keep it underscore | |||
translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | |||
translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | |||
@@ -273,6 +277,7 @@ class Pipelines(object): | |||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||
kws_kwsbp = 'kws-kwsbp' | |||
asr_inference = 'asr-inference' | |||
asr_wenet_inference = 'asr-wenet-inference' | |||
# multi-modal tasks | |||
image_captioning = 'image-captioning' | |||
@@ -0,0 +1,38 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from typing import Any, Dict | |||
import json | |||
import wenetruntime as wenet | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import Model | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import Tasks | |||
__all__ = ['WeNetAutomaticSpeechRecognition'] | |||
@MODELS.register_module( | |||
Tasks.auto_speech_recognition, module_name=Models.wenet_asr) | |||
class WeNetAutomaticSpeechRecognition(Model): | |||
def __init__(self, model_dir: str, am_model_name: str, | |||
model_config: Dict[str, Any], *args, **kwargs): | |||
"""initialize the info of model. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, am_model_name, model_config, *args, | |||
**kwargs) | |||
self.decoder = wenet.Decoder(model_dir, lang='chs') | |||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, str]:
if inputs['audio_format'] == 'wav': | |||
rst = self.decoder.decode_wav(inputs['audio']) | |||
else: | |||
rst = self.decoder.decode(inputs['audio']) | |||
text = json.loads(rst)['nbest'][0]['sentence'] | |||
return {'text': text} |
@@ -36,6 +36,7 @@ if TYPE_CHECKING: | |||
) | |||
from .T5 import T5ForConditionalGeneration | |||
from .mglm import MGLMForTextSummarization | |||
from .codegeex import CodeGeeXForCodeTranslation, CodeGeeXForCodeGeneration | |||
from .task_models import ( | |||
FeatureExtractionModel, | |||
InformationExtractionModel, | |||
@@ -110,6 +111,8 @@ else: | |||
'sentence_embedding': ['SentenceEmbedding'], | |||
'T5': ['T5ForConditionalGeneration'], | |||
'mglm': ['MGLMForTextSummarization'], | |||
'codegeex': | |||
['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'], | |||
'gpt_neo': ['GPTNeoModel'], | |||
'bloom': ['BloomModel'], | |||
} | |||
@@ -0,0 +1,24 @@ | |||
# Modified by Zhipu.AI | |||
# Original Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING, Union | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .codegeex_for_code_translation import CodeGeeXForCodeTranslation | |||
from .codegeex_for_code_generation import CodeGeeXForCodeGeneration | |||
else: | |||
_import_structure = { | |||
'codegeex_for_code_translation': ['CodeGeeXForCodeTranslation'], | |||
'codegeex_for_code_generation': ['CodeGeeXForCodeGeneration'], | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,110 @@ | |||
# Copyright (c) 2022 Zhipu.AI | |||
import copy | |||
from typing import Any, Dict | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .codegeex import CodeGeeXModel | |||
from .inference import get_token_stream | |||
from .tokenizer import CodeGeeXTokenizer | |||
def model_provider(): | |||
"""Build the model.""" | |||
hidden_size = 5120 | |||
num_attention_heads = 40 | |||
num_layers = 39 | |||
padded_vocab_size = 52224 | |||
max_position_embeddings = 2048 | |||
model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads, | |||
padded_vocab_size, max_position_embeddings) | |||
return model | |||
@MODELS.register_module(Tasks.code_generation, module_name=Models.codegeex) | |||
class CodeGeeXForCodeGeneration(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the fast poem model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
logger = get_logger() | |||
# loading tokenizer | |||
logger.info('Loading tokenizer ...') | |||
self.tokenizer = CodeGeeXTokenizer( | |||
tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b') | |||
# loading model | |||
state_dict_path = model_dir + '/ckpt_ms_213000_fp32_52224.pt' | |||
logger.info('Loading state dict ...') | |||
state_dict = torch.load(state_dict_path, map_location='cpu') | |||
state_dict = state_dict['module'] | |||
logger.info('Building CodeGeeX model ...') | |||
self.model = model_provider() | |||
self.model.load_state_dict(state_dict) | |||
self.model.eval() | |||
self.model.half() | |||
self.model.cuda() | |||
def forward(self, input: Dict[str, str]) -> Dict[str, str]: | |||
micro_batch_size = 1 | |||
seq_length = 2048 | |||
out_seq_length = 256 | |||
bad_ids = None | |||
lang = input['language'] | |||
prompt = input['prompt'] | |||
prompt = f'# language: {lang}\n{prompt}' | |||
logger = get_logger() | |||
tokenizer = self.tokenizer | |||
model = self.model | |||
for prompt in [prompt]: | |||
tokens = tokenizer.encode_code(prompt) | |||
n_token_prompt = len(tokens) | |||
token_stream = get_token_stream( | |||
model, | |||
tokenizer, | |||
seq_length, | |||
out_seq_length, | |||
[copy.deepcopy(tokens) for _ in range(micro_batch_size)], | |||
micro_batch_size=micro_batch_size, | |||
bad_ids=bad_ids, | |||
topk=1, | |||
topp=0.9, | |||
temperature=0.9, | |||
greedy=True) | |||
is_finished = [False for _ in range(micro_batch_size)] | |||
for i, generated in enumerate(token_stream): | |||
generated_tokens = generated[0] | |||
for j in range(micro_batch_size): | |||
if is_finished[j]: | |||
continue | |||
if generated_tokens[j].cpu().numpy( | |||
)[-1] == tokenizer.eos_token_id or len( | |||
generated_tokens[j]) >= out_seq_length: | |||
is_finished[j] = True | |||
generated_tokens_ = generated_tokens[j].cpu().numpy( | |||
).tolist() | |||
generated_code = tokenizer.decode_code( | |||
generated_tokens_[n_token_prompt:]) | |||
generated_code = ''.join(generated_code) | |||
logger.info( | |||
'================================= Generated code:' | |||
) | |||
logger.info(generated_code) | |||
if all(is_finished): | |||
break | |||
logger.info('Generation finished.') | |||
return {OutputKeys.TEXT: generated_code} |
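For reference, a small illustration (values are made up, not taken from this change) of the prompt string that `forward` above assembles before tokenization:

```python
# Illustration only: forward() prepends a language tag to the user prompt.
lang, prompt = 'Python', 'def quick_sort(arr):'
model_input = f'# language: {lang}\n{prompt}'
print(model_input)
# # language: Python
# def quick_sort(arr):
```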
@@ -0,0 +1,109 @@ | |||
# Copyright (c) 2022 Zhipu.AI | |||
import copy | |||
from typing import Any, Dict | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .codegeex import CodeGeeXModel | |||
from .inference import get_token_stream | |||
from .tokenizer import CodeGeeXTokenizer | |||
def model_provider(): | |||
"""Build the model.""" | |||
hidden_size = 5120 | |||
num_attention_heads = 40 | |||
num_layers = 39 | |||
padded_vocab_size = 52224 | |||
max_position_embeddings = 2048 | |||
model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads, | |||
padded_vocab_size, max_position_embeddings) | |||
return model | |||
@MODELS.register_module(Tasks.code_translation, module_name=Models.codegeex) | |||
class CodeGeeXForCodeTranslation(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the fast poem model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
logger = get_logger() | |||
# loading tokenizer | |||
logger.info('Loading tokenizer ...') | |||
self.tokenizer = CodeGeeXTokenizer( | |||
tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b') | |||
# loading model | |||
state_dict_path = model_dir + '/ckpt_ms_translation_0817.pt' | |||
logger.info('Loading state dict ...') | |||
state_dict = torch.load(state_dict_path, map_location='cpu') | |||
state_dict = state_dict['module'] | |||
logger.info('Building CodeGeeX model ...') | |||
self.model = model_provider() | |||
self.model.load_state_dict(state_dict) | |||
self.model.eval() | |||
self.model.half() | |||
self.model.cuda() | |||
def forward(self, input: Dict[str, str]) -> Dict[str, str]: | |||
micro_batch_size = 1 | |||
seq_length = 2048 | |||
out_seq_length = 256 | |||
bad_ids = None | |||
src_lang = input['source language'] | |||
dst_lang = input['target language'] | |||
prompt = input['prompt'] | |||
prompt = f'code translation\n{src_lang}:\n{prompt}\n{dst_lang}:\n' | |||
logger = get_logger() | |||
tokenizer = self.tokenizer | |||
model = self.model | |||
for prompt in [prompt]: | |||
tokens = tokenizer.encode_code(prompt) | |||
n_token_prompt = len(tokens) | |||
token_stream = get_token_stream( | |||
model, | |||
tokenizer, | |||
seq_length, | |||
out_seq_length, | |||
[copy.deepcopy(tokens) for _ in range(micro_batch_size)], | |||
micro_batch_size=micro_batch_size, | |||
bad_ids=bad_ids, | |||
greedy=True, | |||
) | |||
is_finished = [False for _ in range(micro_batch_size)] | |||
for i, generated in enumerate(token_stream): | |||
generated_tokens = generated[0] | |||
for j in range(micro_batch_size): | |||
if is_finished[j]: | |||
continue | |||
if generated_tokens[j].cpu().numpy( | |||
)[-1] == tokenizer.eos_token_id or len( | |||
generated_tokens[j]) >= out_seq_length: | |||
is_finished[j] = True | |||
generated_tokens_ = generated_tokens[j].cpu().numpy( | |||
).tolist() | |||
generated_code = tokenizer.decode_code( | |||
generated_tokens_[n_token_prompt:]) | |||
generated_code = ''.join(generated_code) | |||
logger.info( | |||
'================================= Generated code:' | |||
) | |||
logger.info(generated_code) | |||
if all(is_finished): | |||
break | |||
logger.info('Generation finished.') | |||
return {OutputKeys.TEXT: generated_code} |
@@ -0,0 +1,301 @@ | |||
# Copyright (c) 2022 Zhipu.AI | |||
from typing import List | |||
import torch | |||
import torch.nn.functional as F | |||
def get_ltor_masks_and_position_ids( | |||
data, | |||
eod_token, | |||
reset_position_ids, | |||
reset_attention_mask, | |||
): | |||
"""Build masks and position id for left to right model.""" | |||
# Extract batch size and sequence length. | |||
micro_batch_size, seq_length = data.size() | |||
# Attention mask (lower triangular). | |||
if reset_attention_mask: | |||
att_mask_batch = micro_batch_size | |||
else: | |||
att_mask_batch = 1 | |||
attention_mask = torch.tril( | |||
torch.ones((att_mask_batch, seq_length, seq_length), | |||
device=data.device)).view(att_mask_batch, 1, seq_length, | |||
seq_length) | |||
# Position ids. | |||
position_ids = torch.arange( | |||
seq_length, dtype=torch.long, device=data.device) | |||
position_ids = position_ids.unsqueeze(0).expand_as(data) | |||
# We need to clone as the ids will be modified based on batch index.
if reset_position_ids: | |||
position_ids = position_ids.clone() | |||
if reset_position_ids or reset_attention_mask: | |||
# Loop through the batches: | |||
for b in range(micro_batch_size): | |||
# Find indices where the EOD token is.
eod_index = position_ids[b, data[b] == eod_token] | |||
# Detach indices from positions if going to modify positions.
if reset_position_ids: | |||
eod_index = eod_index.clone() | |||
# Loop through EOD indices:
prev_index = 0 | |||
for j in range(eod_index.size()[0]): | |||
i = eod_index[j] | |||
# Mask attention loss. | |||
if reset_attention_mask: | |||
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 | |||
# Reset positions. | |||
if reset_position_ids: | |||
position_ids[b, (i + 1):] -= i + 1 - prev_index | |||
prev_index = i + 1 | |||
# Convert attention mask to binary: | |||
attention_mask = attention_mask < 0.5 | |||
return attention_mask, position_ids | |||
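A minimal sketch (toy tensor, not from the original change) of the shapes this helper produces when both reset flags are off and the `get_ltor_masks_and_position_ids` defined above is importable:

```python
import torch

# batch of 1, sequence length 5; token id 0 stands in for the EOD token here
data = torch.tensor([[5, 7, 0, 3, 9]])
mask, pos = get_ltor_masks_and_position_ids(
    data, eod_token=0, reset_position_ids=False, reset_attention_mask=False)
print(mask.shape)  # torch.Size([1, 1, 5, 5]); True marks masked (future) positions
print(pos)         # tensor([[0, 1, 2, 3, 4]])
```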
def get_batch( | |||
context_tokens, | |||
micro_batch_size, | |||
eod_token, | |||
reset_position_ids=False, | |||
reset_attention_mask=False, | |||
): | |||
"""Generate batch from context tokens.""" | |||
tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda() | |||
# Get the attention mask and position ids.
attention_mask, position_ids = get_ltor_masks_and_position_ids( | |||
tokens, | |||
eod_token, | |||
reset_position_ids, | |||
reset_attention_mask, | |||
) | |||
return tokens, attention_mask, position_ids | |||
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): | |||
"""This function has been mostly taken from huggingface conversational | |||
ai code at | |||
https://medium.com/huggingface/how-to-build-a-state-of-the-art- | |||
conversational-ai-with-transfer-learning-2d818ac26313""" | |||
if top_k > 0: | |||
# Remove all tokens with a probability less than the | |||
# last token of the top-k | |||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, | |||
None] | |||
logits[indices_to_remove] = filter_value | |||
if top_p > 0.0: | |||
# Convert to 1D
sorted_logits, sorted_indices = torch.sort( | |||
logits, descending=True, dim=-1) | |||
cumulative_probs = torch.cumsum( | |||
F.softmax(sorted_logits, dim=-1), dim=-1) | |||
# Remove tokens with cumulative probability above the threshold | |||
sorted_indices_to_remove = cumulative_probs > top_p | |||
# Shift the indices to the right to keep also the first token | |||
# above the threshold | |||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ | |||
..., :-1].clone() | |||
sorted_indices_to_remove[..., 0] = 0 | |||
for i in range(sorted_indices.size(0)): | |||
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] | |||
logits[i][indices_to_remove] = filter_value | |||
return logits | |||
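A quick sanity check (toy logits, assuming `top_k_logits` above is importable) of the top-k path: with `top_k=2` only the two largest logits survive and everything else becomes `-inf` before sampling:

```python
import torch

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
# top_p left at 0.0, so only the top-k filtering branch runs
filtered = top_k_logits(logits.clone(), top_k=2)
print(filtered)  # tensor([[2., -inf, 1., -inf]])
```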
def pad_batch(batch, pad_id, seq_length): | |||
context_lengths = [] | |||
for tokens in batch: | |||
context_length = len(tokens) | |||
if context_length < seq_length: | |||
tokens.extend([pad_id] * (seq_length - context_length)) | |||
context_lengths.append(context_length) | |||
return batch, context_lengths | |||
def get_token_stream( | |||
model, | |||
tokenizer, | |||
seq_length, | |||
out_seq_length, | |||
context_tokens, | |||
return_scores: bool = False, | |||
prompt_length: int = None, | |||
micro_batch_size: int = None, | |||
bad_ids: List = None, | |||
temperature: float = 1.0, | |||
topp: float = 1.0, | |||
topk: int = 0,
greedy: bool = False, | |||
): | |||
context_tokens, context_lengths = pad_batch(context_tokens, | |||
tokenizer.eos_token_id, | |||
seq_length) | |||
context_tokens_tensor = torch.cuda.LongTensor(context_tokens) | |||
context_length_tensor = torch.cuda.LongTensor(context_lengths) | |||
context_length = context_length_tensor.min().item() | |||
tokens, attention_mask, position_ids = get_batch( | |||
context_tokens_tensor, | |||
micro_batch_size, | |||
tokenizer.eos_token_id, | |||
) | |||
batch_token_iterator = sample_sequence_batch( | |||
model, | |||
tokenizer, | |||
context_tokens_tensor, | |||
context_length_tensor, | |||
attention_mask, | |||
position_ids, | |||
seq_length=seq_length, | |||
out_seq_length=out_seq_length, | |||
return_scores=return_scores, | |||
prompt_length=prompt_length, | |||
bad_ids=bad_ids, | |||
temperature=temperature, | |||
topp=topp, | |||
topk=topk, | |||
greedy=greedy, | |||
) | |||
for tokens, lengths in batch_token_iterator: | |||
context_length += 1 | |||
if tokens is not None: | |||
yield tokens[:, :context_length], lengths | |||
else: | |||
yield None, None | |||
def switch(val1, val2, boolean): | |||
boolean = boolean.type_as(val1) | |||
return (1 - boolean) * val1 + boolean * val2 | |||
def sample_sequence_batch( | |||
model, | |||
tokenizer, | |||
context_tokens, | |||
context_lengths, | |||
attention_mask, | |||
position_ids, | |||
seq_length, | |||
out_seq_length, | |||
maxlen=None, | |||
return_scores: bool = False, | |||
prompt_length: int = None, | |||
bad_ids: List = None, | |||
temperature: float = 1.0, | |||
topp: float = 1.0, | |||
topk: int = 0,
recompute: bool = False, | |||
greedy: bool = False, | |||
): | |||
model.eval() | |||
with torch.no_grad(): | |||
context_length = context_lengths.min().item() | |||
eos_id = tokenizer.eos_token_id | |||
counter = 0 | |||
org_context_length = context_length | |||
layer_past = None | |||
batch_size = context_tokens.size(0) | |||
is_done = torch.zeros([batch_size]).byte().cuda() | |||
tokens = context_tokens | |||
if maxlen is None: | |||
maxlen = seq_length - 1 | |||
if maxlen > (org_context_length + out_seq_length): | |||
maxlen = org_context_length + out_seq_length | |||
lengths = torch.ones([batch_size]).long().cuda() * maxlen | |||
if return_scores: | |||
scores = torch.zeros([batch_size]).float().cuda() | |||
while context_length <= (maxlen): | |||
if recompute: | |||
logits = model( | |||
tokens, | |||
position_ids, | |||
attention_mask, | |||
prompt_length=prompt_length, | |||
context_length=context_length, | |||
) | |||
logits = logits[:, context_length - 1, :] | |||
else: | |||
if counter == 0: | |||
tokens2use = tokens[:, :context_length] | |||
positions2use = position_ids[:, :context_length] | |||
else: | |||
tokens2use = tokens[:, context_length - 1].view( | |||
batch_size, -1) | |||
positions2use = position_ids[:, context_length - 1].view( | |||
batch_size, -1) | |||
logits, layer_past = model( | |||
tokens2use, | |||
positions2use, | |||
attention_mask, | |||
layer_past=layer_past, | |||
get_key_value=True, | |||
prompt_length=prompt_length, | |||
context_length=context_length, | |||
) | |||
logits = logits[:, -1].view(batch_size, -1).contiguous() | |||
if bad_ids is not None: | |||
for bad_id in bad_ids: | |||
logits[:, bad_id] = -10000 | |||
if greedy: | |||
prev = torch.argmax(logits, dim=-1).view(-1) | |||
else: | |||
logits = logits.float() | |||
if return_scores: | |||
orig_log_probs = torch.log_softmax(logits, dim=-1) | |||
logits /= temperature | |||
logits = top_k_logits(logits, top_k=topk, top_p=topp) | |||
log_probs = F.softmax(logits, dim=-1) | |||
prev = torch.multinomial(log_probs, num_samples=1).view(-1) | |||
started = context_lengths <= context_length | |||
new_tokens = switch(tokens[:, context_length].view(-1), prev, | |||
started) | |||
if not greedy and return_scores: | |||
indices = prev.view(-1, 1) | |||
new_scores = orig_log_probs.gather(1, indices).view(-1) | |||
new_scores = new_scores * started | |||
new_scores = new_scores * is_done.bool().logical_not() | |||
scores += new_scores | |||
tokens[:, context_length] = new_tokens | |||
done_token = (prev == eos_id).byte() & started.byte() | |||
just_finished = (done_token & ~is_done).bool() | |||
lengths[just_finished.view(-1)] = context_length | |||
is_done = is_done | done_token | |||
done = torch.all(is_done) | |||
if return_scores: | |||
yield tokens, (lengths, scores) | |||
else: | |||
yield tokens, lengths | |||
context_length += 1 | |||
counter += 1 | |||
if done: | |||
break |
@@ -0,0 +1,187 @@ | |||
# Copyright (c) 2022 Zhipu.AI | |||
from typing import List, Union | |||
import torch | |||
from transformers import AutoTokenizer | |||
from transformers.models.gpt2 import GPT2TokenizerFast | |||
def encode_whitespaces(text, start_extra_id: int, max_len: int): | |||
""" Encode whitespaces to extra tokens in GPT-J. | |||
>>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c' | |||
""" | |||
def push_acc_space(acc_len: int, text: str): | |||
if acc_len == 0: | |||
return text | |||
if acc_len == 1: | |||
return text + ' ' | |||
assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
extra_id = start_extra_id - 2 + acc_len | |||
extra_token = f'<|extratoken_{extra_id}|>' | |||
return text + extra_token | |||
acc_len = 0 | |||
res = '' | |||
for ch in text: | |||
if ch == ' ': | |||
acc_len += 1 | |||
if acc_len == max_len: | |||
res = push_acc_space(acc_len, res) | |||
acc_len = 0 | |||
else: | |||
res = push_acc_space(acc_len, res) | |||
acc_len = 0 | |||
res = res + ch | |||
res = push_acc_space(acc_len, res) | |||
return res | |||
def decode_whitespaces(text: str, start_extra_id: int, max_len: int): | |||
""" Decode the whitespace-encoded strings produced by encode_whitespace. | |||
>>> text = 'a\\n b\\n c' | |||
>>> s, l = 10, 10 | |||
>>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l) | |||
True | |||
""" | |||
for l in range(2, max_len + 1): # noqa | |||
token_id = start_extra_id - 2 + l | |||
token = f'<|extratoken_{token_id}|>' | |||
text = text.replace(token, ' ' * l) | |||
return text | |||
class Code13BDictionary(object): | |||
def __init__( | |||
self, | |||
dict_file: str, | |||
extra_token_ids: List[str] = None, | |||
pad_to_vocab_size: int = -1, | |||
): | |||
self._idx = dict() | |||
self._count = dict() | |||
self._num_symbols = 0 | |||
self._symbols = [] | |||
self._add_symbol('<s>', 0) | |||
self._add_symbol('<pad>', 0) | |||
self._add_symbol('</s>', 0) | |||
self._add_symbol('<unk>', 0) | |||
self._load_dict(dict_file) | |||
if extra_token_ids is None: | |||
extra_token_ids = [str(x) for x in range(50257, 50400) | |||
] # follows GPT-J settings | |||
for token_id in extra_token_ids: | |||
self._add_symbol(token_id, 0) | |||
if pad_to_vocab_size > 0: | |||
self._pad_to_vocab_size(pad_to_vocab_size) | |||
def _pad_to_vocab_size(self, vocab_size: int): | |||
num_pad = vocab_size - len(self) | |||
if num_pad <= 0: | |||
return | |||
for i in range(1, num_pad + 1): | |||
self._add_symbol('vocab_pad_token{}'.format(i), 0) | |||
def _load_dict(self, dict_file: str): | |||
with open(dict_file, 'r') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line == '' or line.startswith('#'): | |||
continue | |||
sym, count = line.split() | |||
self._add_symbol(sym, int(count)) | |||
def _add_symbol(self, sym: str, count: int): | |||
self._idx[sym] = self._num_symbols | |||
self._count[sym] = count | |||
self._symbols.append(sym) | |||
self._num_symbols += 1 | |||
def __len__(self): | |||
return self._num_symbols | |||
def index(self, sym: str): | |||
return self._idx[sym] | |||
def string(self, idx: int): | |||
return self._symbols[idx] | |||
def map_token(self, token: Union[int, str]): | |||
if isinstance(token, int): | |||
token = str(token) | |||
return self.index(token) | |||
def map_tokens(self, tokens): | |||
return [self.map_token(token) for token in tokens] | |||
def decode_tokens(self, tokens): | |||
decoded = [ | |||
'50256' if token == 50256 else self.string(token) | |||
for token in tokens | |||
] | |||
return [int(x) for x in decoded if not x.startswith('vocab_pad_token')] | |||
class CodeGeeXTokenizer(object): | |||
def __init__( | |||
self, | |||
tokenizer: GPT2TokenizerFast = None, | |||
tokenizer_path: str = 'EleutherAI/gpt-j-6B', | |||
start_extra_id: int = 10, | |||
max_len: int = 10, | |||
mode='codegeex-13b', | |||
dict_file: str = None, | |||
): | |||
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained( | |||
tokenizer_path) | |||
if mode not in ['codegeex-13b', 'codegeex-python-13b']: | |||
raise ValueError( | |||
f"Invalid mode {mode}, choose from ['codegeex-13b', 'codegeex-python-13b']" | |||
) | |||
self.start_extra_id = start_extra_id | |||
self.max_len = max_len | |||
self.mode = mode | |||
if dict_file is not None: | |||
self.code_dict = Code13BDictionary( | |||
dict_file, pad_to_vocab_size=51200 | |||
) if self.mode == 'codegeex-python-13b' else None | |||
else: | |||
self.code_dict = None | |||
self.eos_token_id = self.tokenizer.eos_token_id | |||
def encode_code(self, code: str): | |||
if self.mode == 'codegeex-13b': | |||
code = encode_whitespaces(code, self.start_extra_id, self.max_len) | |||
input_ids = self.tokenizer( | |||
code, is_split_into_words=False).input_ids | |||
elif self.mode == 'codegeex-python-13b': | |||
code = encode_whitespaces(code, self.start_extra_id, self.max_len) | |||
input_ids = self.code_dict.map_tokens(self.tokenizer.encode(code)) | |||
input_ids = torch.LongTensor(input_ids).reshape(1, -1) | |||
return input_ids | |||
def decode_code(self, input_ids): | |||
if self.mode == 'codegeex-13b': | |||
text = self.tokenizer.decode(input_ids, skip_special_tokens=False) | |||
output_code = decode_whitespaces(text, self.start_extra_id, | |||
self.max_len) | |||
elif self.mode == 'codegeex-python-13b': | |||
input_ids = [self.code_dict.decode_tokens(input_ids.tolist()[0])] | |||
text = self.tokenizer.decode(input_ids, skip_special_tokens=False) | |||
output_code = decode_whitespaces(text, self.start_extra_id, | |||
self.max_len) | |||
return output_code |
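A minimal round-trip sketch for the default `codegeex-13b` mode, assuming the `CodeGeeXTokenizer` defined above is importable; the tokenizer path is the class default and is fetched from the Hugging Face hub:

```python
tokenizer = CodeGeeXTokenizer(
    tokenizer_path='EleutherAI/gpt-j-6B',  # class default
    mode='codegeex-13b')

ids = tokenizer.encode_code('def add(a, b):\n    return a + b')
print(ids.shape)  # torch.Size([1, N])

# decode_code maps the extra whitespace tokens back to literal spaces
restored = tokenizer.decode_code(ids[0].tolist())
print(restored)   # 'def add(a, b):\n    return a + b'
```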
@@ -0,0 +1,87 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict, Union | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models import Model | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import WavToScp | |||
from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, | |||
load_bytes_from_url) | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
__all__ = ['WeNetAutomaticSpeechRecognitionPipeline'] | |||
@PIPELINES.register_module( | |||
Tasks.auto_speech_recognition, module_name=Pipelines.asr_wenet_inference) | |||
class WeNetAutomaticSpeechRecognitionPipeline(Pipeline): | |||
"""ASR Inference Pipeline | |||
""" | |||
def __init__(self, | |||
model: Union[Model, str] = None, | |||
preprocessor: WavToScp = None, | |||
**kwargs): | |||
"""use `model` and `preprocessor` to create an asr pipeline for prediction | |||
""" | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
def __call__(self, | |||
audio_in: Union[str, bytes], | |||
audio_fs: int = None, | |||
recog_type: str = None, | |||
audio_format: str = None) -> Dict[str, Any]: | |||
from easyasr.common import asr_utils | |||
self.recog_type = recog_type | |||
self.audio_format = audio_format | |||
self.audio_fs = audio_fs | |||
# checking_audio_fs stays None unless one of the branches below detects a rate
checking_audio_fs = None
if isinstance(audio_in, str):
# load pcm data from url if audio_in is url str | |||
self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in) | |||
elif isinstance(audio_in, bytes): | |||
# load pcm data from wav data if audio_in is wave format | |||
self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in) | |||
else: | |||
self.audio_in = audio_in | |||
# set the sample_rate of audio_in if checking_audio_fs is valid | |||
if checking_audio_fs is not None: | |||
self.audio_fs = checking_audio_fs | |||
if recog_type is None or audio_format is None: | |||
self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | |||
audio_in=self.audio_in, | |||
recog_type=recog_type, | |||
audio_format=audio_format) | |||
if hasattr(asr_utils, 'sample_rate_checking'): | |||
checking_audio_fs = asr_utils.sample_rate_checking( | |||
self.audio_in, self.audio_format) | |||
if checking_audio_fs is not None: | |||
self.audio_fs = checking_audio_fs | |||
inputs = { | |||
'audio': self.audio_in, | |||
'audio_format': self.audio_format, | |||
'audio_fs': self.audio_fs | |||
} | |||
output = self.forward(inputs) | |||
rst = self.postprocess(output['asr_result']) | |||
return rst | |||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
"""Decoding | |||
""" | |||
inputs['asr_result'] = self.model(inputs) | |||
return inputs | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
"""process the asr results | |||
""" | |||
return inputs |
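A minimal usage sketch of the new pipeline; the model id and example audio path mirror the ones used in the unit test added later in this change:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model='wenet/u2pp_conformer-asr-cn-16k-online')

# accepts a local wav path, a URL, or raw PCM bytes (see __call__ above)
print(asr('data/test/audios/asr_example.wav'))  # e.g. {'text': '...'}
```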
@@ -33,6 +33,8 @@ if TYPE_CHECKING: | |||
from .word_segmentation_pipeline import WordSegmentationPipeline, WordSegmentationThaiPipeline | |||
from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | |||
from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline | |||
from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline | |||
from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline | |||
else: | |||
_import_structure = { | |||
@@ -75,6 +77,10 @@ else: | |||
'zero_shot_classification_pipeline': | |||
['ZeroShotClassificationPipeline'], | |||
'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'], | |||
'codegeex_code_translation_pipeline': | |||
['CodeGeeXCodeTranslationPipeline'], | |||
'codegeex_code_generation_pipeline': | |||
['CodeGeeXCodeGenerationPipeline'], | |||
} | |||
import sys | |||
@@ -0,0 +1,55 @@ | |||
# Copyright (c) 2022 Zhipu.AI | |||
from typing import Any, Dict, Union | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.nlp import CodeGeeXForCodeGeneration | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import Preprocessor | |||
from modelscope.utils.constant import Tasks | |||
@PIPELINES.register_module( | |||
group_key=Tasks.code_generation, | |||
module_name=Pipelines.codegeex_code_generation) | |||
class CodeGeeXCodeGenerationPipeline(Pipeline): | |||
def __init__(self, | |||
model: Union[CodeGeeXForCodeGeneration, str], | |||
preprocessor: Preprocessor = None,
*args, | |||
**kwargs): | |||
model = CodeGeeXForCodeGeneration(model) if isinstance(model, | |||
str) else model | |||
self.model = model | |||
self.model.eval() | |||
self.model.half() | |||
self.model.cuda() | |||
super().__init__(model=model, **kwargs) | |||
def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: | |||
return inputs | |||
# define the forward pass | |||
def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]:
# check input format | |||
for para in ['prompt', 'language']: | |||
if para not in inputs:
raise Exception(f"Missing required input '{para}', please check your input format.")
if inputs['language'] not in [ | |||
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', | |||
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', | |||
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', | |||
'Pascal', 'R', 'Fortran', 'Lean' | |||
]: # noqa | |||
raise Exception( | |||
'Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa | |||
) # noqa | |||
return self.model(inputs) | |||
# format the outputs from pipeline | |||
def postprocess(self, input, **kwargs) -> Dict[str, Any]: | |||
return input |
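A hedged usage sketch for the generation pipeline: the model id below is a placeholder (not taken from this change), while the input keys match the checks in `forward` above.

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# NOTE: placeholder model id; substitute the actual CodeGeeX model id on the hub
codegen = pipeline(
    task=Tasks.code_generation,
    model='ZhipuAI/CodeGeeX-Code-Generation-13B')

result = codegen({
    'prompt': '# write a bubble sort function\n',
    'language': 'Python',
})
print(result['text'])  # generated code, returned under OutputKeys.TEXT
```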
@@ -0,0 +1,65 @@ | |||
# Copyright (c) 2022 Zhipu.AI | |||
from typing import Any, Dict, Union | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.nlp import CodeGeeXForCodeTranslation | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import Preprocessor | |||
from modelscope.utils.constant import Tasks | |||
@PIPELINES.register_module( | |||
group_key=Tasks.code_translation, | |||
module_name=Pipelines.codegeex_code_translation) | |||
class CodeGeeXCodeTranslationPipeline(Pipeline): | |||
def __init__(self, | |||
model: Union[CodeGeeXForCodeTranslation, str], | |||
preprocessor: Preprocessor = None,
*args, | |||
**kwargs): | |||
model = CodeGeeXForCodeTranslation(model) if isinstance(model, | |||
str) else model | |||
self.model = model | |||
self.model.eval() | |||
self.model.half() | |||
self.model.cuda() | |||
super().__init__(model=model, **kwargs) | |||
def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: | |||
return inputs | |||
# define the forward pass | |||
def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]:
# check input format | |||
for para in ['prompt', 'source language', 'target language']: | |||
if para not in inputs:
raise Exception(f"Missing required input '{para}', please check your input format.")
if inputs['source language'] not in [ | |||
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', | |||
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', | |||
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', | |||
'Pascal', 'R', 'Fortran', 'Lean' | |||
]: | |||
raise Exception( | |||
'Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa | |||
) # noqa | |||
if inputs['target language'] not in [ | |||
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', | |||
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', | |||
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', | |||
'Pascal', 'R', 'Fortran', 'Lean' | |||
]: | |||
raise Exception( | |||
'Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa | |||
) # noqa | |||
return self.model(inputs) | |||
# format the outputs from pipeline | |||
def postprocess(self, input, **kwargs) -> Dict[str, Any]: | |||
return input |
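A hedged usage sketch for the translation pipeline (placeholder model id, not taken from this change); note the three input keys checked in `forward` above:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# NOTE: placeholder model id; substitute the actual CodeGeeX translation model id
translator = pipeline(
    task=Tasks.code_translation,
    model='ZhipuAI/CodeGeeX-Code-Translation-13B')

result = translator({
    'prompt': 'print("hello world")',
    'source language': 'Python',
    'target language': 'C++',
})
print(result['text'])  # translated code, returned under OutputKeys.TEXT
```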
@@ -23,7 +23,7 @@ if TYPE_CHECKING: | |||
SentenceEmbeddingPreprocessor, SequenceClassificationPreprocessor, | |||
TokenClassificationPreprocessor, TextErrorCorrectionPreprocessor, | |||
TextGenerationPreprocessor, Text2TextGenerationPreprocessor, Tokenize, | |||
WordSegmentationBlankSetToLabelPreprocessor, | |||
WordSegmentationBlankSetToLabelPreprocessor, CodeGeeXPreprocessor, | |||
MGLMSummarizationPreprocessor, ZeroShotClassificationPreprocessor, | |||
TextGenerationJiebaPreprocessor, SentencePiecePreprocessor, | |||
DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, | |||
@@ -57,7 +57,7 @@ else: | |||
'TextErrorCorrectionPreprocessor', 'TextGenerationPreprocessor', | |||
'Tokenize', 'Text2TextGenerationPreprocessor', | |||
'WordSegmentationBlankSetToLabelPreprocessor', | |||
'MGLMSummarizationPreprocessor', | |||
'MGLMSummarizationPreprocessor', 'CodeGeeXPreprocessor', | |||
'ZeroShotClassificationPreprocessor', | |||
'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', | |||
'NERPreprocessorViet', 'NERPreprocessorThai', | |||
@@ -121,6 +121,8 @@ class NLPTasks(object): | |||
fill_mask = 'fill-mask' | |||
text_summarization = 'text-summarization' | |||
question_answering = 'question-answering' | |||
code_translation = 'code-translation' | |||
code_generation = 'code-generation' | |||
zero_shot_classification = 'zero-shot-classification' | |||
backbone = 'backbone' | |||
text_error_correction = 'text-error-correction' | |||
@@ -70,6 +70,11 @@ PYTORCH_IMPORT_ERROR = """ | |||
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. | |||
""" | |||
WENETRUNTIME_IMPORT_ERROR = """ | |||
{0} requires the wenetruntime library but it was not found in your environment. You can install it with pip: | |||
`pip install wenetruntime==TORCH_VER` | |||
""" | |||
# docstyle-ignore | |||
SCIPY_IMPORT_ERROR = """ | |||
{0} requires the scipy library but it was not found in your environment. You can install it with pip: | |||
@@ -245,6 +245,10 @@ def is_torch_cuda_available(): | |||
return False | |||
def is_wenetruntime_available(): | |||
return importlib.util.find_spec('wenetruntime') is not None | |||
def is_tf_available(): | |||
return _tf_available | |||
@@ -280,6 +284,9 @@ REQUIREMENTS_MAAPING = OrderedDict([ | |||
('timm', (is_timm_available, TIMM_IMPORT_ERROR)), | |||
('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)), | |||
('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)), | |||
('wenetruntime', | |||
(is_wenetruntime_available, | |||
WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))), | |||
('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)), | |||
('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)), | |||
('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)), | |||
@@ -0,0 +1,131 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import shutil | |||
import unittest | |||
from typing import Any, Dict, Union | |||
import numpy as np | |||
import soundfile | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import ColorCodes, Tasks | |||
from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import download_and_untar, test_level | |||
logger = get_logger() | |||
WAV_FILE = 'data/test/audios/asr_example.wav' | |||
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' | |||
class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase, | |||
DemoCompatibilityCheck): | |||
action_info = { | |||
'test_run_with_pcm': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'wav_example' | |||
}, | |||
'test_run_with_url': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'wav_example' | |||
}, | |||
'test_run_with_wav': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'wav_example' | |||
}, | |||
'wav_example': { | |||
'text': '每一天都要快乐喔' | |||
} | |||
} | |||
def setUp(self) -> None: | |||
self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online' | |||
# this temporary workspace dir will store waveform files | |||
self.workspace = os.path.join(os.getcwd(), '.tmp') | |||
self.task = Tasks.auto_speech_recognition | |||
if not os.path.exists(self.workspace): | |||
os.mkdir(self.workspace) | |||
def tearDown(self) -> None: | |||
# remove workspace dir (.tmp) | |||
shutil.rmtree(self.workspace, ignore_errors=True) | |||
def run_pipeline(self, | |||
model_id: str, | |||
audio_in: Union[str, bytes], | |||
sr: int = None) -> Dict[str, Any]: | |||
inference_16k_pipeline = pipeline(
task=Tasks.auto_speech_recognition, model=model_id)
rec_result = inference_16k_pipeline(audio_in, audio_fs=sr)
return rec_result | |||
def log_error(self, functions: str, result: Dict[str, Any]) -> None: | |||
logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' | |||
+ ColorCodes.END) | |||
logger.error( | |||
ColorCodes.MAGENTA + functions + ' correct result example:' | |||
+ ColorCodes.YELLOW | |||
+ str(self.action_info[self.action_info[functions]['example']]) | |||
+ ColorCodes.END) | |||
raise ValueError('asr result is mismatched') | |||
def check_result(self, functions: str, result: Dict[str, Any]) -> None: | |||
if self.action_info[functions]['checking_item'] in result:
logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' | |||
+ ColorCodes.END) | |||
logger.info( | |||
ColorCodes.YELLOW | |||
+ str(result[self.action_info[functions]['checking_item']]) | |||
+ ColorCodes.END) | |||
else: | |||
self.log_error(functions, result) | |||
def wav2bytes(self, wav_file): | |||
audio, fs = soundfile.read(wav_file) | |||
# float32 -> int16 | |||
audio = np.asarray(audio) | |||
dtype = np.dtype('int16') | |||
i = np.iinfo(dtype) | |||
abs_max = 2**(i.bits - 1) | |||
offset = i.min + abs_max | |||
audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) | |||
# int16(PCM_16) -> byte | |||
audio = audio.tobytes() | |||
return audio, fs | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_pcm(self): | |||
"""run with wav data | |||
""" | |||
logger.info('Run ASR test with wav data (wenet)...') | |||
audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_model_id, audio_in=audio, sr=sr) | |||
self.check_result('test_run_with_pcm', rec_result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_wav(self): | |||
"""run with single waveform file | |||
""" | |||
logger.info('Run ASR test with waveform file (wenet)...') | |||
wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_model_id, audio_in=wav_file_path) | |||
self.check_result('test_run_with_wav', rec_result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_url(self): | |||
"""run with single url file | |||
""" | |||
logger.info('Run ASR test with url file (wenet)...') | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_model_id, audio_in=URL_FILE) | |||
self.check_result('test_run_with_url', rec_result) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -10,6 +10,7 @@ isolated: # test cases that may require an excessive amount of GPU memory or run for a long time
- test_easycv_trainer.py | |||
- test_segformer.py | |||
- test_segmentation_pipeline.py | |||
- test_movie_scene_segmentation.py | |||
- test_image_inpainting.py | |||
- test_mglm_text_summarization.py | |||
- test_team_transfer_trainer.py | |||