OFA: add ASR task inference. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10761019 (master^2)
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:46dbc998c9d1d48111267c40741dd3200f2e5bcf4075f8c4c97f4451160dce50 | |||
size 134570 |
@@ -284,6 +284,7 @@ class Pipelines(object): | |||
video_multi_modal_embedding = 'video-multi-modal-embedding' | |||
image_text_retrieval = 'image-text-retrieval' | |||
ofa_ocr_recognition = 'ofa-ocr-recognition' | |||
ofa_asr = 'ofa-asr' | |||
# science tasks | |||
protein_structure = 'unifold-protein-structure' | |||
@@ -1,5 +1,6 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .modeling_mmspeech import MMSpeechModel | |||
from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel | |||
from .tokenization_ofa import OFATokenizer, OFATokenizerZH | |||
from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast |
@@ -0,0 +1,260 @@ | |||
# Copyright 2022 Alibaba Group and The HuggingFace Inc. team. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" MMSpeech model configuration""" | |||
import warnings | |||
from transformers import PretrainedConfig | |||
from transformers.utils import logging | |||
logger = logging.get_logger(__name__) | |||
class MMSpeechConfig(PretrainedConfig): | |||
r""" | |||
This is the configuration class to store the configuration of a [`~OFAModel`]. It is used to instantiate an OFA | |||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the | |||
defaults will yield a similar configuration to that of the OFA [ofa-base](https://huggingface.co/ofa-base) | |||
architecture. | |||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | |||
documentation from [`PretrainedConfig`] for more information. | |||
Args: | |||
vocab_size (`int`, *optional*, defaults to 50265): | |||
Vocabulary size of the OFA model. Defines the number of different tokens that can be represented by the | |||
`inputs_ids` passed when calling [`~OFAModel`] or [`~TFOFAModel`]. | |||
d_model (`int`, *optional*, defaults to 1024): | |||
Dimension of the layers and the pooler layer. | |||
encoder_layers (`int`, *optional*, defaults to 12): | |||
Number of encoder layers. | |||
decoder_layers (`int`, *optional*, defaults to 12): | |||
Number of decoder layers. | |||
encoder_attention_heads (`int`, *optional*, defaults to 16): | |||
Number of attention heads for each attention layer in the Transformer encoder. | |||
decoder_attention_heads (`int`, *optional*, defaults to 16): | |||
Number of attention heads for each attention layer in the Transformer decoder. | |||
decoder_ffn_dim (`int`, *optional*, defaults to 4096): | |||
Dimension of the "intermediate" (often named feed-forward) layer in decoder. | |||
encoder_ffn_dim (`int`, *optional*, defaults to 4096): | |||
Dimension of the "intermediate" (often named feed-forward) layer in encoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): | |||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, | |||
`"relu"`, `"silu"` and `"gelu_new"` are supported. | |||
dropout (`float`, *optional*, defaults to 0.1): | |||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | |||
attention_dropout (`float`, *optional*, defaults to 0.0): | |||
The dropout ratio for the attention probabilities. | |||
activation_dropout (`float`, *optional*, defaults to 0.0): | |||
The dropout ratio for activations inside the fully connected layer. | |||
classifier_dropout (`float`, *optional*, defaults to 0.0): | |||
The dropout ratio for classifier. | |||
max_position_embeddings (`int`, *optional*, defaults to 1024): | |||
The maximum sequence length that this model might ever be used with. Typically set this to something large | |||
just in case (e.g., 512 or 1024 or 2048). | |||
init_std (`float`, *optional*, defaults to 0.02): | |||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | |||
encoder_layerdrop: (`float`, *optional*, defaults to 0.0): | |||
The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
for more details. | |||
decoder_layerdrop: (`float`, *optional*, defaults to 0.0): | |||
The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
for more details. | |||
use_cache (`bool`, *optional*, defaults to `True`): | |||
Whether or not the model should return the last key/values attentions (not used by all models). | |||
""" | |||
model_type = 'ofa' | |||
keys_to_ignore_at_inference = ['past_key_values'] | |||
attribute_map = { | |||
'num_attention_heads': 'encoder_attention_heads', | |||
'hidden_size': 'd_model' | |||
} | |||
def __init__(self, | |||
vocab_size=59457, | |||
max_position_embeddings=1024, | |||
encoder_layers=4, | |||
encoder_ffn_dim=512 * 4, | |||
encoder_attention_heads=8, | |||
decoder_layers=4, | |||
decoder_ffn_dim=512 * 4, | |||
decoder_attention_heads=8, | |||
encoder_layerdrop=0.0, | |||
decoder_layerdrop=0.0, | |||
use_cache=True, | |||
is_encoder_decoder=True, | |||
activation_function='gelu', | |||
d_model=512, | |||
dropout=0.1, | |||
attention_dropout=0.0, | |||
activation_dropout=0.0, | |||
init_std=0.02, | |||
classifier_dropout=0.0, | |||
scale_embedding=False, | |||
pad_token_id=1, | |||
bos_token_id=0, | |||
decoder_start_token_id=0, | |||
eos_token_id=2, | |||
forced_eos_token_id=2, | |||
encoder_normalize_before=True, | |||
decoder_normalize_before=True, | |||
normformer=True, | |||
encoder_drop_path_rate=0.0, | |||
decoder_drop_path_rate=0.0, | |||
layernorm_embedding=True, | |||
patch_layernorm_embedding=True, | |||
resnet_type='resnet101', | |||
resnet_model_path=None, | |||
resnet_drop_path_rate=0.0, | |||
token_bucket_size=256, | |||
image_bucket_size=42, | |||
add_type_embedding=True, | |||
share_decoder_input_output_embed=True, | |||
attn_scale_factor=2., | |||
code_layernorm_embedding=False, | |||
code_image_size=128, | |||
entangle_position_embedding=False, | |||
interpolate_position=False, | |||
orig_patch_image_size=224, | |||
share_attn_bias=False, | |||
use_image_feature=True, | |||
disable_entangle=False, | |||
use_ofasys=False, | |||
vit_type='vit_base', | |||
vit_drop_path_rate=0.0, | |||
required_seq_len_multiple=2, | |||
encoder_pos_conv_depth=5, | |||
encoder_conv_pos=95, | |||
encoder_conv_pos_groups=16, | |||
encoder_max_positions=100000, | |||
phone_vocab_size=141, | |||
audio_mask_prob=0.65, | |||
audio_mask_selection='static', | |||
audio_mask_other=0, | |||
audio_mask_length=10, | |||
audio_no_mask_overlap=False, | |||
audio_mask_min_space=1, | |||
audio_mask_channel_prob=0.0, | |||
audio_mask_channel_before=False, | |||
audio_mask_channel_selection='static', | |||
audio_mask_channel_other=0, | |||
audio_mask_channel_length=10, | |||
audio_no_mask_channel_overlap=False, | |||
audio_mask_channel_min_space=1, | |||
encoder_dropout_input=0.0, | |||
encoder_dropout_features=0.0, | |||
phone_dict_size=124, | |||
**kwargs): | |||
self.vocab_size = vocab_size | |||
self.max_position_embeddings = max_position_embeddings | |||
self.d_model = d_model | |||
self.encoder_ffn_dim = encoder_ffn_dim | |||
self.encoder_layers = encoder_layers | |||
self.encoder_attention_heads = encoder_attention_heads | |||
self.decoder_ffn_dim = decoder_ffn_dim | |||
self.decoder_layers = decoder_layers | |||
self.decoder_attention_heads = decoder_attention_heads | |||
self.dropout = dropout | |||
self.attention_dropout = attention_dropout | |||
self.activation_dropout = activation_dropout | |||
self.activation_function = activation_function | |||
self.init_std = init_std | |||
self.encoder_layerdrop = encoder_layerdrop | |||
self.decoder_layerdrop = decoder_layerdrop | |||
self.classifier_dropout = classifier_dropout | |||
self.use_cache = use_cache | |||
self.num_hidden_layers = encoder_layers | |||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True | |||
self.encoder_normalize_before = encoder_normalize_before | |||
self.decoder_normalize_before = decoder_normalize_before | |||
self.normformer = normformer | |||
self.encoder_drop_path_rate = encoder_drop_path_rate | |||
self.decoder_drop_path_rate = decoder_drop_path_rate | |||
self.layernorm_embedding = layernorm_embedding | |||
self.patch_layernorm_embedding = patch_layernorm_embedding | |||
self.resnet_type = resnet_type | |||
self.resnet_model_path = resnet_model_path | |||
self.resnet_drop_path_rate = resnet_drop_path_rate | |||
self.token_bucket_size = token_bucket_size | |||
self.image_bucket_size = image_bucket_size | |||
self.add_type_embedding = add_type_embedding | |||
self.share_decoder_input_output_embed = share_decoder_input_output_embed | |||
self.attn_scale_factor = attn_scale_factor | |||
self.code_layernorm_embedding = code_layernorm_embedding | |||
self.code_image_size = code_image_size | |||
self.entangle_position_embedding = entangle_position_embedding | |||
self.interpolate_position = interpolate_position | |||
self.orig_patch_image_size = orig_patch_image_size | |||
self.share_attn_bias = share_attn_bias | |||
self.use_image_feature = use_image_feature | |||
self.disable_entangle = disable_entangle | |||
self.use_ofasys = use_ofasys | |||
self.vit_type = vit_type | |||
self.vit_drop_path_rate = vit_drop_path_rate | |||
# FP16 optimization | |||
self.required_seq_len_multiple = required_seq_len_multiple | |||
# encoder_pos_conv | |||
self.encoder_pos_conv_depth = encoder_pos_conv_depth | |||
self.encoder_conv_pos = encoder_conv_pos | |||
self.encoder_conv_pos_groups = encoder_conv_pos_groups | |||
self.encoder_max_positions = encoder_max_positions | |||
# phone | |||
self.phone_vocab_size = phone_vocab_size | |||
# audio_mask | |||
self.audio_mask_prob = audio_mask_prob | |||
self.audio_mask_selection = audio_mask_selection | |||
self.audio_mask_other = audio_mask_other | |||
self.audio_mask_length = audio_mask_length | |||
self.audio_no_mask_overlap = audio_no_mask_overlap | |||
self.audio_mask_min_space = audio_mask_min_space | |||
self.audio_mask_channel_prob = audio_mask_channel_prob | |||
self.audio_mask_channel_before = audio_mask_channel_before | |||
self.audio_mask_channel_selection = audio_mask_channel_selection | |||
self.audio_mask_channel_other = audio_mask_channel_other | |||
self.audio_mask_channel_length = audio_mask_channel_length | |||
self.audio_no_mask_channel_overlap = audio_no_mask_channel_overlap | |||
self.audio_mask_channel_min_space = audio_mask_channel_min_space | |||
# audio encoder | |||
self.encoder_dropout_input = encoder_dropout_input | |||
self.encoder_dropout_features = encoder_dropout_features | |||
self.phone_dict_size = phone_dict_size | |||
super().__init__( | |||
pad_token_id=pad_token_id, | |||
bos_token_id=bos_token_id, | |||
eos_token_id=eos_token_id, | |||
is_encoder_decoder=is_encoder_decoder, | |||
decoder_start_token_id=decoder_start_token_id, | |||
forced_eos_token_id=forced_eos_token_id, | |||
**kwargs, | |||
) | |||
# ensure backward compatibility for BART CNN models | |||
if self.forced_bos_token_id is None and kwargs.get( | |||
'force_bos_token_to_be_generated', False): | |||
self.forced_bos_token_id = self.bos_token_id | |||
warnings.warn( | |||
f'Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. ' | |||
'The config can simply be saved and uploaded again to be fixed.' | |||
) |
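
A minimal usage sketch of the new config, assuming the file lands as modelscope/models/multi_modal/ofa/configuration_mmspeech.py (the import path is an assumption; everything else follows the defaults defined above):

```python
# Hedged sketch: instantiate MMSpeechConfig with its declared defaults.
from modelscope.models.multi_modal.ofa.configuration_mmspeech import MMSpeechConfig

config = MMSpeechConfig()
print(config.d_model)              # 512
print(config.hidden_size)          # 512, resolved through attribute_map -> 'd_model'
print(config.num_attention_heads)  # 8, resolved through attribute_map -> 'encoder_attention_heads'
print(config.phone_vocab_size)     # 141
print(config.is_encoder_decoder)   # True
```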
@@ -227,6 +227,9 @@ class SequenceGenerator(nn.Module): | |||
- net_input['padding_mask'].sum(-1) | |||
if net_input['padding_mask'] is not None else torch.tensor( | |||
src_tokens.size(-1)).to(src_tokens)) | |||
elif 'fbank' in net_input: | |||
src_tokens = net_input['fbank'] | |||
src_lengths = net_input['fbank_length'] | |||
else: | |||
raise Exception( | |||
'expected src_tokens or source in net input. input keys: ' | |||
@@ -11,4 +11,5 @@ OFA_TASK_KEY_MAPPING = { | |||
Tasks.text_classification: OutputKeys.LABELS, | |||
Tasks.image_classification: OutputKeys.LABELS, | |||
Tasks.visual_entailment: OutputKeys.LABELS, | |||
Tasks.auto_speech_recognition: OutputKeys.TEXT | |||
} |
@@ -19,7 +19,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_tokens | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.trie import Trie | |||
from .ofa import OFAModel, OFATokenizer, OFATokenizerZH | |||
from .ofa import MMSpeechModel, OFAModel, OFATokenizer, OFATokenizerZH | |||
from .ofa.generate import sequence_generator as sg | |||
from .ofa.generate.utils import move_to_device | |||
from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks | |||
@@ -37,13 +37,20 @@ __all__ = ['OfaForAllTasks'] | |||
@MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) | |||
@MODELS.register_module(Tasks.text_summarization, module_name=Models.ofa) | |||
@MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) | |||
@MODELS.register_module(Tasks.auto_speech_recognition, module_name=Models.ofa) | |||
class OfaForAllTasks(TorchModel): | |||
def __init__(self, model_dir, *args, **kwargs): | |||
super().__init__(model_dir=model_dir, *args, **kwargs) | |||
model = OFAModel.from_pretrained(model_dir) | |||
self.cfg = Config.from_file( | |||
osp.join(model_dir, ModelFile.CONFIGURATION)) | |||
multimodal_type = self.cfg.model.get('multimodal_type', 'default') | |||
if multimodal_type == 'default': | |||
model = OFAModel.from_pretrained(model_dir) | |||
elif multimodal_type == 'mmspeech': | |||
model = MMSpeechModel.from_pretrained(model_dir) | |||
else: | |||
raise NotImplementedError | |||
self.model = model.module if hasattr(model, 'module') else model | |||
self.language = self.cfg.model.get('language', 'en') | |||
if self.language == 'en': | |||
@@ -54,12 +61,20 @@ class OfaForAllTasks(TorchModel): | |||
raise NotImplementedError | |||
# note: this differs slightly from our OFA code;
# the use_bpe param is no longer needed here
if not model.use_ofasys: | |||
self.tokenizer.add_tokens( | |||
['<code_{}>'.format(i) for i in range(8192)]) | |||
self.tokenizer.add_tokens( | |||
['<bin_{}>'.format(i) for i in range(1000)]) | |||
self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) | |||
if multimodal_type == 'default': | |||
self.tokenizer.add_tokens( | |||
['<code_{}>'.format(i) for i in range(8192)]) | |||
self.tokenizer.add_tokens( | |||
['<bin_{}>'.format(i) for i in range(1000)]) | |||
self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) | |||
elif multimodal_type == 'mmspeech': | |||
self.tokenizer.add_tokens('<blank>') | |||
self.tokenizer.add_tokens( | |||
['<audio_{}>'.format(i) for i in range(30000)]) | |||
self.cfg.update({'num_bins': 0, 'num_codes': 30000}) | |||
self.batch_size = self.cfg.model.get('batch_size', 1) | |||
self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | |||
self.max_image_size = self.cfg.model.get('max_image_size', 512) | |||
@@ -110,6 +125,7 @@ class OfaForAllTasks(TorchModel): | |||
Tasks.visual_question_answering: inference_d[self.gen_type], | |||
Tasks.text_classification: inference_d[self.gen_type], | |||
Tasks.image_classification: inference_d[self.gen_type], | |||
Tasks.auto_speech_recognition: self._text_gen_inference, | |||
} | |||
pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' | |||
self.pattern = re.compile(pattern_str) | |||
@@ -186,7 +186,10 @@ TASK_INPUTS = { | |||
# ============ audio tasks =================== | |||
Tasks.auto_speech_recognition: | |||
InputType.AUDIO, | |||
[InputType.AUDIO, { | |||
'wav': InputType.AUDIO, | |||
'text': InputType.TEXT | |||
}], | |||
Tasks.speech_signal_process: | |||
InputType.AUDIO, | |||
Tasks.acoustic_echo_cancellation: { | |||
@@ -13,6 +13,7 @@ if TYPE_CHECKING: | |||
from .video_multi_modal_embedding_pipeline import \ | |||
VideoMultiModalEmbeddingPipeline | |||
from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline | |||
from .asr_pipeline import AutomaticSpeechRecognitionPipeline | |||
else: | |||
_import_structure = { | |||
@@ -26,7 +27,8 @@ else: | |||
'video_multi_modal_embedding_pipeline': | |||
['VideoMultiModalEmbeddingPipeline'], | |||
'generative_multi_modal_embedding_pipeline': | |||
['GEMMMultiModalEmbeddingPipeline'] | |||
['GEMMMultiModalEmbeddingPipeline'], | |||
'asr_pipeline': ['AutomaticSpeechRecognitionPipeline'], | |||
} | |||
import sys | |||
@@ -0,0 +1,54 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict, Optional, Union | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Model, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor, | |||
Preprocessor) | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.auto_speech_recognition, module_name=Pipelines.ofa_asr) | |||
class AutomaticSpeechRecognitionPipeline(Pipeline): | |||
def __init__(self, | |||
model: Union[Model, str], | |||
preprocessor: Optional[Preprocessor] = None, | |||
**kwargs): | |||
""" | |||
use `model` and `preprocessor` to create an automatic speech recognition pipeline for prediction | |||
Args: | |||
model: model id on the ModelScope hub, or a loaded Model instance.
""" | |||
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a str or a Model instance'
if isinstance(model, str): | |||
pipe_model = Model.from_pretrained(model) | |||
elif isinstance(model, Model): | |||
pipe_model = model | |||
else: | |||
raise NotImplementedError | |||
pipe_model.model.eval() | |||
if preprocessor is None: | |||
if isinstance(pipe_model, OfaForAllTasks): | |||
preprocessor = OfaPreprocessor(pipe_model.model_dir) | |||
elif isinstance(pipe_model, MPlugForAllTasks): | |||
preprocessor = MPlugPreprocessor(pipe_model.model_dir) | |||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
return super().forward(inputs, **forward_params) | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
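
A quick usage sketch of the registered pipeline; the model id and wav path are the ones used in the unit test added further down in this change:

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Model id and wav path come from test_run_with_asr_with_name below.
ofa_pipe = pipeline(
    Tasks.auto_speech_recognition, model='damo/ofa_asr_pretrain_base_zh')
result = ofa_pipe({'wav': 'data/test/audios/asr_example_ofa.wav'})
print(result[OutputKeys.TEXT])
```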
@@ -53,7 +53,8 @@ class OfaPreprocessor(Preprocessor): | |||
Tasks.image_classification: OfaImageClassificationPreprocessor, | |||
Tasks.text_classification: OfaTextClassificationPreprocessor, | |||
Tasks.text_summarization: OfaSummarizationPreprocessor, | |||
Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor | |||
Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor, | |||
Tasks.auto_speech_recognition: OfaASRPreprocessor | |||
} | |||
model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | |||
model_dir) | |||
@@ -1,4 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .asr import OfaASRPreprocessor | |||
from .image_captioning import OfaImageCaptioningPreprocessor | |||
from .image_classification import OfaImageClassificationPreprocessor | |||
from .ocr_recognition import OfaOcrRecognitionPreprocessor | |||
@@ -0,0 +1,121 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import random | |||
from pathlib import Path | |||
from typing import Any, Dict | |||
import soundfile as sf | |||
import torch | |||
from fairseq.data.audio.feature_transforms import \ | |||
CompositeAudioFeatureTransform | |||
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig | |||
from modelscope.utils.chinese_utils import pre_chinese | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
from .utils.text2phone import Text2Phone | |||
class OfaASRPreprocessor(OfaBasePreprocessor): | |||
def __init__(self, | |||
cfg, | |||
model_dir, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path, | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super(OfaASRPreprocessor, self).__init__(cfg, model_dir, mode, *args, | |||
**kwargs) | |||
# Initialize transform | |||
self.data_cfg = S2TDataConfig( | |||
Path(os.path.join(model_dir, 'fbank_config.yaml'))) | |||
self.train_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict( | |||
self.data_cfg.get_feature_transforms('train', True)) | |||
self.test_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict( | |||
self.data_cfg.get_feature_transforms('test', False)) | |||
self.text2phone_tokenizer = Text2Phone( | |||
os.path.join(model_dir, 'text2phone_dict.txt')) | |||
self.phone_to_id, self.id_to_phone = self.build_phone_dict( | |||
os.path.join(model_dir, 'phone_dict.txt')) | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
if self.mode == ModeKeys.TRAIN: | |||
return self._build_train_sample(data) | |||
else: | |||
return self._build_infer_sample(data) | |||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
speed = random.choice([0.9, 1.0, 1.1]) | |||
wav, sr = sf.read(data[self.column_map['wav']])
fbank = self.prepare_fbank( | |||
torch.tensor([wav], dtype=torch.float32), sr, speed, is_train=True) | |||
fbank_mask = torch.tensor([True]) | |||
sample = { | |||
'fbank': fbank, | |||
'fbank_mask': fbank_mask, | |||
'label': data[self.column_map['text']] | |||
} | |||
target = sample['label'] | |||
if self.language == 'zh': | |||
target = pre_chinese(target, self.max_tgt_length) | |||
sample['target'] = self.tokenize_text(target, add_bos=False) | |||
else: | |||
target = target.translate(self.transtab).strip() | |||
target_token_list = target.strip().split() | |||
target = ' '.join(target_token_list[:self.max_tgt_length]) | |||
sample['target'] = self.tokenize_text(target, add_bos=False) | |||
phone_item = self.to_phone(target) - 3 | |||
phone_mask = torch.tensor([False]) | |||
sample['phone_item'] = phone_item | |||
sample['phone_mask'] = phone_mask | |||
sample['prev_output_tokens'] = torch.cat( | |||
[self.bos_item, sample['target'][:-1]]) | |||
return sample | |||
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
speed = 1.0 | |||
wav, sr = sf.read(data[self.column_map['wav']]) | |||
fbank = self.prepare_fbank( | |||
torch.tensor([wav], dtype=torch.float32), | |||
sr, | |||
speed, | |||
is_train=False) | |||
fbank_mask = torch.tensor([True]) | |||
sample = {'fbank': fbank, 'fbank_mask': fbank_mask} | |||
if 'text' in self.column_map and self.column_map['text'] in data: | |||
sample['label'] = data[self.column_map['text']] | |||
# mock | |||
sample['phone_item'] = torch.tensor([6, 6, 6]) | |||
sample['phone_mask'] = torch.tensor([False]) | |||
return sample | |||
def to_phone(self, text): | |||
phones = self.text2phone_tokenizer.trans(text) | |||
ids = torch.tensor([self.phone_to_id[x] for x in phones.split(' ')]) | |||
return ids | |||
def build_phone_dict(self, phone_dict_path): | |||
phone_to_id = dict() | |||
id_to_phone = dict() | |||
with open(phone_dict_path, 'r') as phone_dict_file: | |||
for i, line in enumerate(phone_dict_file): | |||
phone = line.strip().split(' ')[0] | |||
phone_to_id[phone] = i | |||
id_to_phone[i] = phone
return phone_to_id, id_to_phone |
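
A toy illustration of the phone handling: to_phone() is a per-phone dictionary lookup, and _build_train_sample() shifts the resulting ids by 3. The mapping and phone string below are made up; the real ones come from phone_dict.txt and Text2Phone:

```python
import torch

# Toy stand-in for the phone_to_id dict built by build_phone_dict().
phone_to_id = {'n': 3, 'i3': 4, 'h': 5, 'ao3': 6}
# Toy stand-in for the phone string produced by Text2Phone.trans().
phones = 'n i3 h ao3'

ids = torch.tensor([phone_to_id[p] for p in phones.split(' ')])
print(ids)       # tensor([3, 4, 5, 6])
print(ids - 3)   # tensor([0, 1, 2, 3]) -- the offset applied to phone_item above
```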
@@ -6,11 +6,14 @@ from os import path as osp | |||
import json | |||
import numpy as np | |||
import torch | |||
import torchaudio | |||
from PIL import Image | |||
from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH | |||
from modelscope.preprocessors.image import load_image | |||
from modelscope.utils.trie import Trie | |||
from .utils.audio_helper import (_get_kaldi_fbank, _get_torchaudio_fbank, | |||
convert_waveform) | |||
from .utils.constant import OFA_TASK_KEY_MAPPING | |||
from .utils.random_help import set_torch_seed | |||
@@ -88,6 +91,9 @@ class OfaBasePreprocessor: | |||
+ answer_item.tolist() | |||
+ [tokenizer.eos_token_id]) | |||
self.train_audio_feature_transforms = None | |||
self.test_audio_feature_transforms = None | |||
def tokenize_text(self, text, add_bos=True, add_eos=True): | |||
if text is None: | |||
return None | |||
@@ -163,3 +169,36 @@ class OfaBasePreprocessor: | |||
image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \ | |||
else load_image(path_or_url_or_pil) | |||
return image | |||
def prepare_fbank(self, waveform, sample_rate, speed, is_train): | |||
waveform, _ = torchaudio.sox_effects.apply_effects_tensor( | |||
waveform, sample_rate, | |||
[['speed', str(speed)], ['rate', str(sample_rate)]]) | |||
_waveform, _ = convert_waveform( | |||
waveform, sample_rate, to_mono=True, normalize_volume=True) | |||
# Kaldi compliance: 16-bit signed integers | |||
_waveform = _waveform * (2**15) | |||
_waveform = _waveform.numpy() | |||
fbank = _get_kaldi_fbank(_waveform, sample_rate, 80) | |||
if fbank is None: | |||
fbank = _get_torchaudio_fbank(_waveform, sample_rate, 80) | |||
if fbank is None: | |||
raise ImportError( | |||
'Please install pyKaldi or torchaudio to enable fbank feature extraction' | |||
) | |||
if is_train and self.train_audio_feature_transforms is not None: | |||
fbank = self.train_audio_feature_transforms(fbank) | |||
elif not is_train and self.test_audio_feature_transforms is not None:
fbank = self.test_audio_feature_transforms(fbank) | |||
fbank = torch.from_numpy(fbank).float() | |||
fbank = self.pack_frames(fbank) | |||
return fbank | |||
def pack_frames(self, feature: torch.Tensor): | |||
if self.cfg.n_frames_per_step == 1: | |||
return feature | |||
n_packed_frames = feature.shape[0] // self.cfg.n_frames_per_step | |||
feature = feature[:self.cfg.n_frames_per_step * n_packed_frames] | |||
return feature.reshape(n_packed_frames, -1) |
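
A self-contained sketch of the frame packing done in pack_frames(); n_frames_per_step=2 is an assumed value for illustration (with the value 1 the input is returned unchanged):

```python
import torch

n_frames_per_step = 2          # assumed value for illustration
fbank = torch.randn(11, 80)    # 11 frames of 80-dim filterbank features

n_packed = fbank.shape[0] // n_frames_per_step
packed = fbank[:n_frames_per_step * n_packed].reshape(n_packed, -1)
print(packed.shape)            # torch.Size([5, 160]): consecutive frames concatenated
```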
@@ -0,0 +1,91 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Optional, Tuple, Union | |||
import numpy as np | |||
import torch | |||
def convert_waveform( | |||
waveform: Union[np.ndarray, torch.Tensor], | |||
sample_rate: int, | |||
normalize_volume: bool = False, | |||
to_mono: bool = False, | |||
to_sample_rate: Optional[int] = None, | |||
) -> Tuple[Union[np.ndarray, torch.Tensor], int]: | |||
"""convert a waveform: | |||
- to a target sample rate | |||
- from multi-channel to mono channel | |||
- volume normalization | |||
Args: | |||
waveform (numpy.ndarray or torch.Tensor): 2D original waveform | |||
(channels x length) | |||
sample_rate (int): original sample rate | |||
normalize_volume (bool): perform volume normalization | |||
to_mono (bool): convert to mono channel if having multiple channels | |||
to_sample_rate (Optional[int]): target sample rate | |||
Returns: | |||
waveform (numpy.ndarray): converted 2D waveform (channels x length) | |||
sample_rate (float): target sample rate | |||
""" | |||
try: | |||
import torchaudio.sox_effects as ta_sox | |||
except ImportError: | |||
raise ImportError('Please install torchaudio: pip install torchaudio') | |||
effects = [] | |||
if normalize_volume: | |||
effects.append(['gain', '-n']) | |||
if to_sample_rate is not None and to_sample_rate != sample_rate: | |||
effects.append(['rate', f'{to_sample_rate}']) | |||
if to_mono and waveform.shape[0] > 1: | |||
effects.append(['channels', '1']) | |||
if len(effects) > 0: | |||
is_np_input = isinstance(waveform, np.ndarray) | |||
_waveform = torch.from_numpy(waveform) if is_np_input else waveform | |||
converted, converted_sample_rate = ta_sox.apply_effects_tensor( | |||
_waveform, sample_rate, effects) | |||
if is_np_input: | |||
converted = converted.numpy() | |||
return converted, converted_sample_rate | |||
return waveform, sample_rate | |||
def _get_kaldi_fbank(waveform: np.ndarray, | |||
sample_rate: int, | |||
n_bins=80) -> Optional[np.ndarray]: | |||
"""Get mel-filter bank features via PyKaldi.""" | |||
try: | |||
from kaldi.feat.fbank import Fbank, FbankOptions | |||
from kaldi.feat.mel import MelBanksOptions | |||
from kaldi.feat.window import FrameExtractionOptions | |||
from kaldi.matrix import Vector | |||
mel_opts = MelBanksOptions() | |||
mel_opts.num_bins = n_bins | |||
frame_opts = FrameExtractionOptions() | |||
frame_opts.samp_freq = sample_rate | |||
opts = FbankOptions() | |||
opts.mel_opts = mel_opts | |||
opts.frame_opts = frame_opts | |||
fbank = Fbank(opts=opts) | |||
features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy() | |||
return features | |||
except ImportError: | |||
return None | |||
def _get_torchaudio_fbank(waveform: np.ndarray, | |||
sample_rate, | |||
n_bins=80) -> Optional[np.ndarray]: | |||
"""Get mel-filter bank features via TorchAudio.""" | |||
try: | |||
import torchaudio.compliance.kaldi as ta_kaldi | |||
waveform = torch.from_numpy(waveform) | |||
features = ta_kaldi.fbank( | |||
waveform, num_mel_bins=n_bins, sample_frequency=sample_rate) | |||
return features.numpy() | |||
except ImportError: | |||
return None |
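
A hedged example of these helpers on synthetic audio. The import path mirrors the relative import used in base.py and is an assumption; torchaudio is required, PyKaldi is optional:

```python
import numpy as np

from modelscope.preprocessors.ofa.utils.audio_helper import (
    _get_kaldi_fbank, _get_torchaudio_fbank, convert_waveform)

sample_rate = 16000
# One second of fake stereo audio, channels x length.
waveform = np.random.randn(2, sample_rate).astype(np.float32)

mono, sr = convert_waveform(
    waveform, sample_rate, to_mono=True, normalize_volume=True)
mono = mono * (2**15)  # scale to 16-bit range, as prepare_fbank() does
fbank = _get_kaldi_fbank(mono, sr, 80)
if fbank is None:
    fbank = _get_torchaudio_fbank(mono, sr, 80)
print(fbank.shape)     # (num_frames, 80)
```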
@@ -1,5 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import List | |||
import numpy as np | |||
import torch | |||
@@ -13,14 +15,12 @@ def collate_fn(samples, pad_idx, eos_idx): | |||
pad_idx, | |||
eos_idx=eos_idx) | |||
src_tokens = merge('source') | |||
batch = { | |||
'nsentences': len(samples), | |||
'net_input': { | |||
'input_ids': src_tokens, | |||
}, | |||
'net_input': {}, | |||
} | |||
if samples[0].get('source', None) is not None: | |||
batch['net_input']['input_ids'] = merge('source') | |||
if samples[0].get('id', None) is not None: | |||
batch['id'] = np.array([s.get('id') for s in samples])
if samples[0].get('target', None) is not None: | |||
@@ -70,6 +70,20 @@ def collate_fn(samples, pad_idx, eos_idx): | |||
[s['region_coord'] for s in samples], dim=0) | |||
if samples[0].get('sample', None) is not None: | |||
batch['samples'] = [s['sample'] for s in samples] | |||
# For asr | |||
if samples[0].get('fbank', None) is not None: | |||
batch['net_input']['fbank'] = _collate_frames( | |||
[s['fbank'] for s in samples]) | |||
batch['net_input']['fbank_length'] = torch.tensor( | |||
[s['fbank'].size(0) for s in samples], dtype=torch.long) | |||
if samples[0].get('fbank_mask', None) is not None: | |||
batch['net_input']['fbank_masks'] = torch.cat( | |||
[s['fbank_mask'] for s in samples]) | |||
if samples[0].get('phone_item', None) is not None: | |||
batch['net_input']['phone_items'] = merge('phone_item') | |||
batch['net_input']['phone_masks'] = torch.cat( | |||
[s['phone_mask'] for s in samples]) | |||
return batch | |||
@@ -113,3 +127,19 @@ def collate_tokens( | |||
for i, v in enumerate(values): | |||
copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) | |||
return res | |||
def _collate_frames(frames: List[torch.Tensor]): | |||
""" | |||
Convert a list of 2D frames into a padded 3D tensor | |||
Args: | |||
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is | |||
length of i-th frame and f_dim is static dimension of features | |||
Returns: | |||
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] | |||
""" | |||
max_len = max(frame.size(0) for frame in frames) | |||
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) | |||
for i, v in enumerate(frames): | |||
out[i, :v.size(0)] = v | |||
return out |
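
A small check of _collate_frames() on two utterances of different lengths, showing the zero padding applied before the batch reaches the model:

```python
import torch

from modelscope.preprocessors.ofa.utils.collate import _collate_frames

frames = [torch.ones(3, 80), torch.ones(5, 80)]   # 3-frame and 5-frame utterances
out = _collate_frames(frames)
print(out.shape)          # torch.Size([2, 5, 80])
print(out[0, 3:].sum())   # tensor(0.) -- the shorter utterance is zero-padded
```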
@@ -9,5 +9,6 @@ OFA_TASK_KEY_MAPPING = { | |||
Tasks.visual_grounding: ['image', 'text'], | |||
Tasks.visual_question_answering: ['image', 'text'], | |||
Tasks.visual_entailment: ['image', 'text', 'text2'], | |||
Tasks.text_to_image_synthesis: ['text'] | |||
Tasks.text_to_image_synthesis: ['text'], | |||
Tasks.auto_speech_recognition: ['wav', 'text'], | |||
} |
@@ -0,0 +1,192 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.utils.chinese_utils import normalize_chinese_number | |||
class TrieNode(object): | |||
def __init__(self): | |||
""" | |||
Initialize your data structure here. | |||
""" | |||
self.data = {} | |||
self.is_word = False | |||
class Trie(object): | |||
""" | |||
trie-tree | |||
""" | |||
def __init__(self): | |||
""" | |||
Initialize your data structure here. | |||
""" | |||
self.root = TrieNode() | |||
def insert(self, word): | |||
""" | |||
Inserts a word into the trie. | |||
:type word: str | |||
:rtype: void | |||
""" | |||
node = self.root | |||
for chars in word: | |||
child = node.data.get(chars) | |||
if not child: | |||
node.data[chars] = TrieNode() | |||
node = node.data[chars] | |||
node.is_word = True | |||
def search(self, word): | |||
""" | |||
Returns if the word is in the trie. | |||
:type word: str | |||
:rtype: bool | |||
""" | |||
node = self.root | |||
for chars in word: | |||
node = node.data.get(chars) | |||
if not node: | |||
return False | |||
return node.is_word | |||
def startsWith(self, prefix): | |||
""" | |||
Returns if there is any word in the trie that starts with the given prefix. | |||
:type prefix: str | |||
:rtype: bool | |||
""" | |||
node = self.root | |||
for chars in prefix: | |||
node = node.data.get(chars) | |||
if not node: | |||
return False | |||
return True | |||
def get_start(self, prefix): | |||
""" | |||
Returns words started with prefix | |||
:param prefix: | |||
:return: words (list) | |||
""" | |||
def get_key(pre, pre_node): | |||
word_list = [] | |||
if pre_node.is_word: | |||
word_list.append(pre) | |||
for x in pre_node.data.keys(): | |||
word_list.extend(get_key(pre + str(x), pre_node.data.get(x))) | |||
return word_list | |||
words = [] | |||
if not self.startsWith(prefix): | |||
return words | |||
if self.search(prefix): | |||
words.append(prefix) | |||
return words | |||
node = self.root | |||
for chars in prefix: | |||
node = node.data.get(chars) | |||
return get_key(prefix, node) | |||
class TrieTokenizer(Trie): | |||
""" | |||
word_split based on trie-tree | |||
""" | |||
def __init__(self, dict_path): | |||
super(TrieTokenizer, self).__init__() | |||
self.dict_path = dict_path | |||
self.create_trie_tree() | |||
def load_dict(self): | |||
words = [] | |||
with open(self.dict_path, mode='r', encoding='utf-8') as file: | |||
for line in file: | |||
words.append(line.strip().split('\t')[0].encode( | |||
'utf-8').decode('utf-8-sig')) | |||
return words | |||
def create_trie_tree(self): | |||
words = self.load_dict() | |||
for word in words: | |||
self.insert(word) | |||
def mine_tree(self, tree, sentence, trace_index): | |||
if trace_index <= (len(sentence) - 1): | |||
if sentence[trace_index] in tree.data: | |||
trace_index = trace_index + 1 | |||
trace_index = self.mine_tree( | |||
tree.data[sentence[trace_index - 1]], sentence, | |||
trace_index) | |||
return trace_index | |||
def tokenize(self, sentence): | |||
tokens = [] | |||
sentence_len = len(sentence) | |||
while sentence_len != 0: | |||
trace_index = 0 | |||
trace_index = self.mine_tree(self.root, sentence, trace_index) | |||
if trace_index == 0: | |||
tokens.append(sentence[0:1]) | |||
sentence = sentence[1:len(sentence)] | |||
sentence_len = len(sentence) | |||
else: | |||
tokens.append(sentence[0:trace_index]) | |||
sentence = sentence[trace_index:len(sentence)] | |||
sentence_len = len(sentence) | |||
return tokens | |||
def combine(self, token_list): | |||
flag = 0 | |||
output = [] | |||
temp = [] | |||
for i in token_list: | |||
if len(i) != 1: | |||
if flag == 0: | |||
output.append(i[::]) | |||
else: | |||
output.append(''.join(temp)) | |||
output.append(i[::]) | |||
temp = [] | |||
flag = 0 | |||
else: | |||
if flag == 0: | |||
temp.append(i) | |||
flag = 1 | |||
else: | |||
temp.append(i) | |||
return output | |||
class Text2Phone: | |||
def __init__(self, phone_dict_path): | |||
self.trie_cws = TrieTokenizer(phone_dict_path) | |||
self.phone_map = self.get_phone_map(phone_dict_path) | |||
def get_phone_map(self, phone_dict_path): | |||
phone_map = dict() | |||
with open(phone_dict_path, 'r') as phone_map_file_reader: | |||
for line in phone_map_file_reader: | |||
key, phone_series = line.strip().split('\t') | |||
if key not in phone_map: | |||
phone_map[key] = phone_series | |||
return phone_map | |||
def trans(self, text): | |||
text = normalize_chinese_number(text) | |||
tokens = self.trie_cws.tokenize(text) | |||
phones = [] | |||
for word in tokens: | |||
if word in self.phone_map: | |||
phones.append(self.phone_map[word]) | |||
elif len(word) > 1: | |||
for char in word: | |||
if char in self.phone_map: | |||
phones.append(self.phone_map[char]) | |||
return ' '.join(phones) |
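
A usage sketch of Text2Phone with a throwaway dictionary file. The real text2phone_dict.txt ships with the model, the import path follows the relative import in asr.py and is an assumption, and zhconv must be installed:

```python
import tempfile

from modelscope.preprocessors.ofa.utils.text2phone import Text2Phone

# Toy dictionary: word<TAB>space-separated phones, one entry per line.
with tempfile.NamedTemporaryFile(
        'w', suffix='.txt', delete=False, encoding='utf-8') as f:
    f.write('你好\tn i3 h ao3\n')
    f.write('你\tn i3\n')
    f.write('好\th ao3\n')
    dict_path = f.name

t2p = Text2Phone(dict_path)
print(t2p.trans('你好'))   # 'n i3 h ao3', via longest-prefix matching in the trie
```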
@@ -113,6 +113,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
self.use_rdrop = args.get('use_rdrop', False) | |||
self.reg_alpha = args.get('reg_alpha', 1.0) | |||
self.sample_patch_num = args.get('sample_patch_num', 196) | |||
self.ctc_weight = args.get('ctc_weight', 0.0) | |||
self.constraint_start = None | |||
self.constraint_end = None | |||
@@ -141,6 +142,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
output = model.model(**sample['net_input']) | |||
loss, nll_loss, ntokens = self.compute_loss( | |||
output.logits, sample, update_num, reduce=reduce) | |||
if self.ctc_weight > 0: | |||
ctc_loss = self.compute_ctc_loss(model, output, sample) | |||
loss = nll_loss + ctc_loss | |||
sample_size = ( | |||
sample['target'].size(0) if self.sentence_avg else ntokens) | |||
logging_output = { | |||
@@ -206,6 +210,32 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
constraint_end=self.constraint_end) | |||
return loss, nll_loss, ntokens | |||
def compute_ctc_loss(self, model, output, sample): | |||
lprobs = model.get_encoder_normalized_probs( | |||
output, log_probs=True).contiguous() # (T, B, C) from the encoder | |||
non_padding_mask = ~output.encoder_padding_mask | |||
input_lengths = non_padding_mask.long().sum(-1) | |||
target_lengths = sample['ctc_output_lengths'] | |||
pad_mask = torch.arange(target_lengths.max()).expand([ | |||
target_lengths.shape[0], -1 | |||
]).to(target_lengths) < target_lengths.unsqueeze(1) | |||
targets_flat = sample['ctc_outputs'].masked_select(pad_mask) | |||
with torch.backends.cudnn.flags(enabled=False): | |||
loss = F.ctc_loss( | |||
lprobs, | |||
targets_flat, | |||
input_lengths, | |||
target_lengths, | |||
blank=self.blank_idx, | |||
reduction='sum', | |||
zero_infinity=True, | |||
) | |||
return loss | |||
def get_schedule(scheduler): | |||
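
A shape sanity check for the CTC term: lprobs comes out of the encoder as (T, B, C) and the targets are flattened with a padding mask, matching what compute_ctc_loss() passes to F.ctc_loss. The sizes are toy values and the blank index 0 is an assumption (the real code uses self.blank_idx):

```python
import torch
import torch.nn.functional as F

T, B, C = 50, 2, 30                 # toy encoder length, batch size, vocab size
lprobs = torch.randn(T, B, C).log_softmax(dim=-1)
input_lengths = torch.tensor([50, 42])
target_lengths = torch.tensor([7, 5])
# Flattened targets of length sum(target_lengths); values avoid the blank index 0.
targets_flat = torch.randint(1, C, (int(target_lengths.sum()),))

loss = F.ctc_loss(
    lprobs, targets_flat, input_lengths, target_lengths,
    blank=0, reduction='sum', zero_infinity=True)
print(loss)
```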
@@ -1,5 +1,13 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import re | |||
import string | |||
from zhconv import convert | |||
CHINESE_PUNCTUATION = '＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·！？｡。'
ENGLISH_PUNCTUATION = string.punctuation | |||
def is_chinese_char(word: str): | |||
chinese_punctuations = { | |||
@@ -33,3 +41,28 @@ def rebuild_chinese_str(string: str): | |||
return ' '.join(''.join([ | |||
f' {char} ' if is_chinese_char(char) else char for char in string | |||
]).split()) | |||
def normalize_chinese_number(text): | |||
chinese_number = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九'] | |||
new_text = '' | |||
for x in text: | |||
if x in '0123456789': | |||
x = chinese_number[int(x)]
new_text += x | |||
new_text = convert(new_text, 'zh-hans') | |||
return new_text | |||
def pre_chinese(text, max_words): | |||
text = re.sub(
'[{}]'.format(re.escape(CHINESE_PUNCTUATION + ENGLISH_PUNCTUATION)),
' ', text.lower())
text = re.sub( | |||
r'\s{2,}', | |||
' ', | |||
text, | |||
) | |||
text = text.rstrip('\n') | |||
text = text.strip(' ')[:max_words] | |||
return text |
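
A quick check of the two helpers added here (the import path is the one already used by asr.py and text2phone.py above):

```python
from modelscope.utils.chinese_utils import normalize_chinese_number, pre_chinese

print(normalize_chinese_number('2023年'))   # '二零二三年': digits mapped to Chinese numerals
# Punctuation stripped, whitespace runs collapsed, result truncated to max_words characters.
print(pre_chinese('你好,世界!  hello', max_words=8))
```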
@@ -8,6 +8,7 @@ pytorch_lightning<=1.7.7 | |||
# which introduced compatibility issues that are being investigated
rouge_score<=0.0.4 | |||
sacrebleu | |||
soundfile | |||
taming-transformers-rom1504 | |||
timm | |||
tokenizers | |||
@@ -273,6 +273,14 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||
result[OutputKeys.OUTPUT_IMG].save('result.png') | |||
print(f'Output written to {osp.abspath("result.png")}') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_asr_with_name(self): | |||
model = 'damo/ofa_asr_pretrain_base_zh' | |||
ofa_pipe = pipeline(Tasks.auto_speech_recognition, model=model) | |||
example = {'wav': 'data/test/audios/asr_example_ofa.wav'} | |||
result = ofa_pipe(example) | |||
print(result[OutputKeys.TEXT]) | |||
@unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
def test_demo_compatibility(self): | |||
self.compatibility_check() | |||