From 7661470350f529556f2b63f383af4e204476df56 Mon Sep 17 00:00:00 2001 From: "shiyi.zxh" Date: Fri, 25 Nov 2022 12:16:33 +0800 Subject: [PATCH] =?UTF-8?q?ofa=E5=A2=9E=E5=8A=A0asr=E4=BB=BB=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ofa增加asr任务infer Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10761019 --- data/test/audios/asr_example_ofa.wav | 3 + modelscope/metainfo.py | 1 + modelscope/models/multi_modal/ofa/__init__.py | 1 + .../multi_modal/ofa/configuration_mmspeech.py | 260 ++++ .../ofa/generate/sequence_generator.py | 3 + .../multi_modal/ofa/modeling_mmspeech.py | 1075 +++++++++++++++++ .../models/multi_modal/ofa/utils/constant.py | 1 + .../models/multi_modal/ofa_for_all_tasks.py | 30 +- modelscope/pipeline_inputs.py | 5 +- modelscope/pipelines/multi_modal/__init__.py | 4 +- .../pipelines/multi_modal/asr_pipeline.py | 54 + modelscope/preprocessors/multi_modal.py | 3 +- modelscope/preprocessors/ofa/__init__.py | 1 + modelscope/preprocessors/ofa/asr.py | 121 ++ modelscope/preprocessors/ofa/base.py | 39 + .../preprocessors/ofa/utils/audio_helper.py | 91 ++ modelscope/preprocessors/ofa/utils/collate.py | 40 +- .../preprocessors/ofa/utils/constant.py | 3 +- .../preprocessors/ofa/utils/text2phone.py | 192 +++ .../multi_modal/ofa/ofa_trainer_utils.py | 30 + modelscope/utils/chinese_utils.py | 33 + requirements/multi-modal.txt | 1 + tests/pipelines/test_ofa_tasks.py | 8 + 23 files changed, 1983 insertions(+), 16 deletions(-) create mode 100644 data/test/audios/asr_example_ofa.wav create mode 100644 modelscope/models/multi_modal/ofa/configuration_mmspeech.py create mode 100644 modelscope/models/multi_modal/ofa/modeling_mmspeech.py create mode 100644 modelscope/pipelines/multi_modal/asr_pipeline.py create mode 100644 modelscope/preprocessors/ofa/asr.py create mode 100644 modelscope/preprocessors/ofa/utils/audio_helper.py create mode 100644 modelscope/preprocessors/ofa/utils/text2phone.py diff --git a/data/test/audios/asr_example_ofa.wav b/data/test/audios/asr_example_ofa.wav new file mode 100644 index 00000000..4e35a2c9 --- /dev/null +++ b/data/test/audios/asr_example_ofa.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46dbc998c9d1d48111267c40741dd3200f2e5bcf4075f8c4c97f4451160dce50 +size 134570 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 5b56e09a..a5cafdb7 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -284,6 +284,7 @@ class Pipelines(object): video_multi_modal_embedding = 'video-multi-modal-embedding' image_text_retrieval = 'image-text-retrieval' ofa_ocr_recognition = 'ofa-ocr-recognition' + ofa_asr = 'ofa-asr' # science tasks protein_structure = 'unifold-protein-structure' diff --git a/modelscope/models/multi_modal/ofa/__init__.py b/modelscope/models/multi_modal/ofa/__init__.py index 3e8e59f4..da2d09fb 100644 --- a/modelscope/models/multi_modal/ofa/__init__.py +++ b/modelscope/models/multi_modal/ofa/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+from .modeling_mmspeech import MMSpeechModel from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel from .tokenization_ofa import OFATokenizer, OFATokenizerZH from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast diff --git a/modelscope/models/multi_modal/ofa/configuration_mmspeech.py b/modelscope/models/multi_modal/ofa/configuration_mmspeech.py new file mode 100644 index 00000000..37be12e9 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/configuration_mmspeech.py @@ -0,0 +1,260 @@ +# Copyright 2022 Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MMSpeech model configuration""" +import warnings + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class MMSpeechConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~OFAModel`]. It is used to instantiate an OFA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the OFA [ofa-base](https://huggingface.co/ofa-base) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the OFA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~OFAModel`] or [`~TFOFAModel`]. + d_model (`int`, *optional*, defaults to 1024): + Dimension of the layers and the pooler layer. + encoder_layers (`int`, *optional*, defaults to 12): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 12): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier. + max_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop: (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + decoder_layerdrop: (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + + model_type = 'ofa' + keys_to_ignore_at_inference = ['past_key_values'] + + attribute_map = { + 'num_attention_heads': 'encoder_attention_heads', + 'hidden_size': 'd_model' + } + + def __init__(self, + vocab_size=59457, + max_position_embeddings=1024, + encoder_layers=4, + encoder_ffn_dim=512 * 4, + encoder_attention_heads=8, + decoder_layers=4, + decoder_ffn_dim=512 * 4, + decoder_attention_heads=8, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function='gelu', + d_model=512, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + pad_token_id=1, + bos_token_id=0, + decoder_start_token_id=0, + eos_token_id=2, + forced_eos_token_id=2, + encoder_normalize_before=True, + decoder_normalize_before=True, + normformer=True, + encoder_drop_path_rate=0.0, + decoder_drop_path_rate=0.0, + layernorm_embedding=True, + patch_layernorm_embedding=True, + resnet_type='resnet101', + resnet_model_path=None, + resnet_drop_path_rate=0.0, + token_bucket_size=256, + image_bucket_size=42, + add_type_embedding=True, + share_decoder_input_output_embed=True, + attn_scale_factor=2., + code_layernorm_embedding=False, + code_image_size=128, + entangle_position_embedding=False, + interpolate_position=False, + orig_patch_image_size=224, + share_attn_bias=False, + use_image_feature=True, + disable_entangle=False, + use_ofasys=False, + vit_type='vit_base', + vit_drop_path_rate=0.0, + required_seq_len_multiple=2, + encoder_pos_conv_depth=5, + encoder_conv_pos=95, + encoder_conv_pos_groups=16, + encoder_max_positions=100000, + phone_vocab_size=141, + audio_mask_prob=0.65, + audio_mask_selection='static', + audio_mask_other=0, + audio_mask_length=10, + audio_no_mask_overlap=False, + audio_mask_min_space=1, + audio_mask_channel_prob=0.0, + audio_mask_channel_before=False, + audio_mask_channel_selection='static', + audio_mask_channel_other=0, + audio_mask_channel_length=10, + audio_no_mask_channel_overlap=False, + audio_mask_channel_min_space=1, + encoder_dropout_input=0.0, + encoder_dropout_features=0.0, + phone_dict_size=124, + **kwargs): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = 
encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.encoder_normalize_before = encoder_normalize_before + self.decoder_normalize_before = decoder_normalize_before + self.normformer = normformer + self.encoder_drop_path_rate = encoder_drop_path_rate + self.decoder_drop_path_rate = decoder_drop_path_rate + self.layernorm_embedding = layernorm_embedding + self.patch_layernorm_embedding = patch_layernorm_embedding + self.resnet_type = resnet_type + self.resnet_model_path = resnet_model_path + self.resnet_drop_path_rate = resnet_drop_path_rate + self.token_bucket_size = token_bucket_size + self.image_bucket_size = image_bucket_size + self.add_type_embedding = add_type_embedding + self.share_decoder_input_output_embed = share_decoder_input_output_embed + self.attn_scale_factor = attn_scale_factor + self.code_layernorm_embedding = code_layernorm_embedding + self.code_image_size = code_image_size + self.entangle_position_embedding = entangle_position_embedding + self.interpolate_position = interpolate_position + self.orig_patch_image_size = orig_patch_image_size + + self.share_attn_bias = share_attn_bias + self.use_image_feature = use_image_feature + self.disable_entangle = disable_entangle + self.use_ofasys = use_ofasys + self.vit_type = vit_type + self.vit_drop_path_rate = vit_drop_path_rate + + # FP16 optimization + self.required_seq_len_multiple = required_seq_len_multiple + + # encoder_pos_conv + self.encoder_pos_conv_depth = encoder_pos_conv_depth + self.encoder_conv_pos = encoder_conv_pos + self.encoder_conv_pos_groups = encoder_conv_pos_groups + self.encoder_max_positions = encoder_max_positions + + # phone + self.phone_vocab_size = phone_vocab_size + + # audio_mask + self.audio_mask_prob = audio_mask_prob + self.audio_mask_selection = audio_mask_selection + self.audio_mask_other = audio_mask_other + self.audio_mask_length = audio_mask_length + self.audio_no_mask_overlap = audio_no_mask_overlap + self.audio_mask_min_space = audio_mask_min_space + + self.audio_mask_channel_prob = audio_mask_channel_prob + self.audio_mask_channel_before = audio_mask_channel_before + self.audio_mask_channel_selection = audio_mask_channel_selection + self.audio_mask_channel_other = audio_mask_channel_other + self.audio_mask_channel_length = audio_mask_channel_length + self.audio_no_mask_channel_overlap = audio_no_mask_channel_overlap + self.audio_mask_channel_min_space = audio_mask_channel_min_space + + # audio encoder + self.encoder_dropout_input = encoder_dropout_input + self.encoder_dropout_features = encoder_dropout_features + + self.phone_dict_size = phone_dict_size + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + # ensure backward compatibility for BART CNN models + if self.forced_bos_token_id is None 
and kwargs.get( + 'force_bos_token_to_be_generated', False): + self.forced_bos_token_id = self.bos_token_id + warnings.warn( + f'Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. ' + 'The config can simply be saved and uploaded again to be fixed.' + ) diff --git a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py index e42d3c8e..c86f171e 100644 --- a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py +++ b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py @@ -227,6 +227,9 @@ class SequenceGenerator(nn.Module): - net_input['padding_mask'].sum(-1) if net_input['padding_mask'] is not None else torch.tensor( src_tokens.size(-1)).to(src_tokens)) + elif 'fbank' in net_input: + src_tokens = net_input['fbank'] + src_lengths = net_input['fbank_length'] else: raise Exception( 'expected src_tokens or source in net input. input keys: ' diff --git a/modelscope/models/multi_modal/ofa/modeling_mmspeech.py b/modelscope/models/multi_modal/ofa/modeling_mmspeech.py new file mode 100644 index 00000000..07d5b7e8 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/modeling_mmspeech.py @@ -0,0 +1,1075 @@ +# Copyright 2022 OFA-Sys Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
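The model defined in this new file consumes 80-dim filterbank features instead of token ids. A minimal illustrative sketch of a forward pass follows (random features, default config values, `torch`/`fairseq`/`transformers` assumed installed); this is an editorial example, not code shipped by this patch — in practice the model is loaded with `MMSpeechModel.from_pretrained(ckpt_dir)` as in the generation example further below:

```python
import torch

from modelscope.models.multi_modal.ofa import MMSpeechModel
from modelscope.models.multi_modal.ofa.configuration_mmspeech import MMSpeechConfig

# Build the model from the default MMSpeech configuration defined above.
cfg = MMSpeechConfig()
model = MMSpeechModel(cfg).eval()

# 80-dim log-mel filterbank features; shapes and values are illustrative.
fbank = torch.randn(1, 200, 80)            # (bsz, frames, mel bins)
fbank_length = torch.tensor([200])         # valid frames per sample
fbank_masks = torch.tensor([True])         # this sample does carry fbank features
decoder_input_ids = torch.tensor([[cfg.bos_token_id]])

with torch.no_grad():
    out = model(
        fbank=fbank,
        fbank_length=fbank_length,
        fbank_masks=fbank_masks,
        decoder_input_ids=decoder_input_ids)
print(out.logits.shape)                    # decoder output for the single BOS step
```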
+""" PyTorch OFA model.""" + +import math +import random +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from fairseq.data.data_utils import compute_mask_indices +from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer +from fairseq.modules import LayerNorm, SamePad, TransposeLast +from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.utils import index_put +from packaging import version +from torch import Tensor, nn +from torch.nn import functional as F +from transformers.activations import ACT2FN +from transformers.file_utils import (ModelOutput, add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, + Seq2SeqModelOutput) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from .configuration_mmspeech import MMSpeechConfig +from .generate import utils +from .modeling_ofa import (Embedding, OFADecoder, OFAModel, OFAPreTrainedModel, + _expand_mask, shift_tokens_right) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = 'mmspeech-base' +_CONFIG_FOR_DOC = 'MMSpeechConfig' +_TOKENIZER_FOR_DOC = 'OFATokenizer' +TORCH_VERSION = version.parse(torch.__version__) +TORCH_MESH_GRID_WARNING_VERSION = version.parse('1.9.1') + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 + +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + +OFA_PRETRAINED_MODEL_ARCHIVE_LIST = ['mmspeech-base', 'mmspeech-large'] + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + has_fused_layernorm = True + + class FusedLayerNorm(_FusedLayerNorm): + + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + +except ImportError: + has_fused_layernorm = False + + +class MMSpeechPreTrainedModel(OFAPreTrainedModel): + r""" + Base class OFA + """ + + config_class = MMSpeechConfig + + def _set_gradient_checkpointing(self, module, value=False): + r""" + Turn on the switch of gradient checkpointing. + """ + if isinstance(module, (OFADecoder, MMSpeechEncoder)): + module.gradient_checkpointing = value + + +@dataclass +class MMSpeechEncoderOutput(ModelOutput): + r""" + Base class for OFA's outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): + Sequence of hidden-states at the output of the last layer of the model. + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed + or when `config.output_hidden_states=True`): + + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(bsz, seq_len, hidden)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed + or when `config.output_attentions=True`): + + Tuple of `torch.FloatTensor` (one for each layer) of shape `(bsz, num_heads, seq_len, seq_len)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): + postional embeddings of the inputs. 
+ """ + + phone_distribution: torch.Tensor = None + last_hidden_state: torch.Tensor = None + padding_mask: torch.Tensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + position_embedding: Optional[torch.FloatTensor] = None + kl_loss: Optional[torch.Tensor] = None + + +@dataclass +class MMSpeechModelOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, + returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights of the decoder, after the attention softmax, + used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights of the decoder's cross-attention layer, + after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, + *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed + or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. 
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_padding_mask: Optional[torch.Tensor] = None + phone_distribution: Optional[torch.Tensor] = None + kl_loss: Optional[torch.Tensor] = None + + +MMSPEECH_START_DOCSTRING = r""" + This model inherits from [`OFAModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`~MMSpeechConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MMSPEECH_GENERATION_EXAMPLE = r""" + Image captioning example: + + ```python + >>> import soundfile as sf + >>> import torchaudio + >>> import torchaudio.compliance.kaldi as ta_kaldi + >>> wav, sr = sf.read(data[self.column_map['wav']]) + >>> wav = torchaudio.sox_effects.apply_effects_tensor( + >>> wav, sr, + >>> [['speed', '1.0'], ['rate', '16000'], ['gain', '-n'], ['channels', '1']])) + >>> wav = wav * (2**15) + >>> wav = torch.from_numpy(wav.numpy()) + >>> fbank = ta_kaldi.fbank( + waveform, num_mel_bins=n_bins, sample_frequency=sample_rate) + >>> fbank_mask = torch.tensor([True]) + >>> model = MMSpeechModel.from_pretrained(ckpt_dir) + >>> tokenizer = OFATokenizerZH.from_pretrained(ckpt_dir) + + >>> gen = model.generate(fbank=fbank, fbank_mask=fbank_mask, num_beams=4) + >>> print(tokenizer.decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + +MMSPEECH_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): + indices of input sequence tokens in the vocabular, and padding will be ignored by default; + + indices can be obtained using [`~OFATokenizer`]. + + patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the resized image, which are transformed by the default operations. + patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the second (if it exists) image. + patch_masks (`torch.BoolTensor`): the patches to be masked. + token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. + sample_patch_num (`int`): the number of patches to sample. 
+ fbank (`torch.Tensor`): fbank feature of audio. + fbank_length (`torch.Tensor`): fbank length of audio. + fbank_masks (`torch.BoolTensor`): whether to have fbank feature. + phone_items (`torch.Tensor`): phoneme sequence. + phone_masks (`torch.BoolTensor`): whether to have phoneme feature. + features_only (`torch.BoolTensor`): whether to return encoder features only. + mask (`torch.BoolTensor`): whether to mask fbank feature. + mask_prob (`torch.Tensor`): the prob of mask fbank feature. + layer (`int`): the number of layer to cache hidden state. + decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. + code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. + attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding. + encoder_outputs (`OFAEncoderOutput`): + encoder outputs with hidden states, positional embeddings, and padding masks. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of + shape `(bsz, num_heads, src_len, head_size)`. + use_cache (`bool`): whether to use cache for faster inference. + output_attentions (`bool`): whether to output attention weights. + output_hidden_states (`bool`): whether to output hidden states. + return_dict (`bool`): unused. Keep it for generation only. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. +""" + + +class Conv2dSubsampling4(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int): + """Construct an Conv2dSubsampling4 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def get_out_seq_lens_tensor(self, in_seq_lens_tensor): + out = in_seq_lens_tensor.clone() + for _ in range(2): + out = ((out.float() - 1) // 2 + 1).floor().long() + return out + + def forward(self, x: torch.Tensor, + x_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. 
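+
+        Example (illustrative, assuming `torch` is imported):
+            >>> layer = Conv2dSubsampling4(80, 512)
+            >>> x, x_len = layer(torch.randn(2, 100, 80), torch.tensor([100, 100]))
+            >>> x.shape      # time reduced to roughly 1/4: torch.Size([2, 24, 512])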
+ + """ + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + return x, self.get_out_seq_lens_tensor(x_length) + + +class TransformerEncoder(nn.Module): + + def build_encoder_layer(self, args: MMSpeechConfig): + layer = TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_function, + layer_norm_first=args.encoder_normalize_before, + ) + return layer + + def __init__(self, args: MMSpeechConfig): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.d_model + self.required_seq_len_multiple = args.required_seq_len_multiple + + pos_conv_depth = args.encoder_pos_conv_depth + if pos_conv_depth > 1: + num_layers = args.encoder_pos_conv_depth + k = max(3, args.encoder_conv_pos // num_layers) + + def make_conv_block(e, k, g, la): + return nn.Sequential(*[ + nn.Sequential( + nn.Conv1d( + e, + e, + kernel_size=k, + padding=k // 2, + groups=g, + ), + SamePad(k), + TransposeLast(), + LayerNorm(e, elementwise_affine=False), + TransposeLast(), + nn.GELU(), + ) for _ in range(la) + ]) + + self.pos_conv = make_conv_block(self.embedding_dim, k, + args.encoder_conv_pos_groups, + num_layers) + self.phone_pos_conv = make_conv_block(self.embedding_dim, k, + args.encoder_conv_pos_groups, + num_layers) + + else: + + def make_conv_pos(e, k, g): + pos_conv = nn.Conv1d( + e, + e, + kernel_size=k, + padding=k // 2, + groups=g, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (k * e)) + nn.init.normal_(pos_conv.weight, mean=0, std=std) + nn.init.constant_(pos_conv.bias, 0) + + pos_conv = nn.utils.weight_norm(pos_conv, name='weight', dim=2) + pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU()) + + return pos_conv + + self.pos_conv = make_conv_pos( + self.embedding_dim, + args.encoder_conv_pos, + args.encoder_conv_pos_groups, + ) + self.phone_pos_conv = make_conv_pos( + self.embedding_dim, + args.encoder_conv_pos, + args.encoder_conv_pos_groups, + ) + + self.layers = nn.ModuleList([ + self.build_encoder_layer(args) for _ in range(args.encoder_layers) + ]) + self.layer_norm_first = args.encoder_normalize_before + + self.layer_norm = LayerNorm(self.embedding_dim) + self.phone_layer_norm = LayerNorm(self.embedding_dim) + + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, + x, + padding_mask=None, + phone_x=None, + phone_padding_mask=None, + layer=None, + context_layer=None): + x, layer_results, x_conv, pre_padding_mask = self.extract_features( + x, + padding_mask, + phone_x, + phone_padding_mask, + layer, + context_layer=context_layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results, x_conv, pre_padding_mask + + def extract_features( + self, + x, + padding_mask=None, + phone_x=None, + phone_padding_mask=None, + tgt_layer=None, + min_layer=0, + context_layer=None, + ): + + if padding_mask is not None: + x = index_put(x, padding_mask, 0) + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + if phone_x is not None: + if phone_padding_mask is not None: + phone_x = index_put(phone_x, phone_padding_mask, 0) + + phone_x_conv = 
self.phone_pos_conv(phone_x.transpose(1, 2)) + phone_x_conv = phone_x_conv.transpose(1, 2) + phone_x = phone_x + phone_x_conv + + if not self.layer_norm_first: + # to fix + phone_x = self.layer_norm(phone_x) + + pre_padding_mask = padding_mask.clone() + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + r = None + for i, layer in enumerate(self.layers): + + if i < context_layer and (~padding_mask).any() is False: + continue + + if i == context_layer and phone_x is not None and phone_x_conv is not None: + x = x.transpose(0, 1) + x = torch.cat([x, phone_x], dim=1) + padding_mask = torch.cat([padding_mask, phone_padding_mask], + dim=1) + pre_padding_mask = padding_mask.clone() + x_conv = torch.cat([x_conv, phone_x_conv], dim=1) + x = x.transpose(0, 1) + + dropout_probability = np.random.random( + ) if self.layerdrop > 0 else 1 + if not self.training or (dropout_probability > self.layerdrop): + x, (z, lr) = layer( + x, self_attn_padding_mask=padding_mask, need_weights=False) + if i >= min_layer: + layer_results.append((x, z, lr)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results, x_conv, pre_padding_mask + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.args.encoder_max_positions + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + +class MMSpeechEncoder(MMSpeechPreTrainedModel): + + def __init__(self, + cfg: MMSpeechConfig, + embed_tokens: Optional[nn.Embedding] = None): + + super().__init__(cfg) + + self.cfg = cfg + + self.embed = cfg.d_model + + # fbank encoder + self.subsample = Conv2dSubsampling4(80 * 1, cfg.d_model) + self.post_subsample_proj = nn.Linear(cfg.d_model, cfg.d_model) + + # phone and text encoder + self.padding_idx = embed_tokens.padding_idx + self.phone_padding_idx = self.padding_idx + self.phone_item_embedding = Embedding(cfg.phone_vocab_size, self.embed, + self.phone_padding_idx) + + # mask + self.mask_prob = cfg.audio_mask_prob + self.mask_selection = cfg.audio_mask_selection + self.mask_other = cfg.audio_mask_other + self.mask_length = cfg.audio_mask_length + self.no_mask_overlap = cfg.audio_no_mask_overlap + self.mask_min_space = cfg.audio_mask_min_space + + self.mask_channel_prob = cfg.audio_mask_channel_prob + self.mask_channel_before = cfg.audio_mask_channel_before + self.mask_channel_selection = cfg.audio_mask_channel_selection + self.mask_channel_other = cfg.audio_mask_channel_other + self.mask_channel_length = cfg.audio_mask_channel_length + self.no_mask_channel_overlap = cfg.audio_no_mask_channel_overlap + self.mask_channel_min_space = cfg.audio_mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.encoder_dropout_input) + self.dropout_features = nn.Dropout(cfg.encoder_dropout_features) + + self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.d_model).uniform_()) + + self.encoder = TransformerEncoder(cfg) + + self.final_proj = nn.Linear(self.embed, self.embed) + + self.num_updates = 0 + + def get_input_embeddings(self): + r""" + Get the embedding weight. + """ + return self.embed_tokens + + def set_input_embeddings(self, value): + r""" + Set the weight of embedding with the given tensor. 
+ """ + self.embed_tokens = value + + def apply_mask(self, + x, + padding_mask, + mask_indices=None, + mask_channel_indices=None, + mask_prob=None): + B, T, C = x.shape + + if self.mask_channel_prob > 0 and self.mask_channel_before: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices).to( + x.device).unsqueeze(1).expand(-1, T, -1)) + x[mask_channel_indices] = 0 + + if self.mask_prob > 0 or mask_prob is not None: + if mask_indices is None: + if mask_prob is None: + mask_prob = self.mask_prob + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=1, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + require_same_masks=self.cfg.require_same_masks, + mask_dropout=self.cfg.mask_dropout, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) + else: + mask_indices = None + + if self.mask_channel_prob > 0 and not self.mask_channel_before: + if mask_channel_indices is None: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices).to( + x.device).unsqueeze(1).expand(-1, T, -1)) + x = index_put(x, mask_channel_indices, 0) + + return x, mask_indices + + def _get_feat_extract_output_lengths(self, + input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + conv_cfg_list = eval(self.cfg.conv_feature_layers) + + for i in range(len(conv_cfg_list)): + input_lengths = _conv_out_length(input_lengths, + conv_cfg_list[i][1], + conv_cfg_list[i][2]) + + return input_lengths.to(torch.long) + + def forward(self, + fbank: Optional[torch.Tensor] = None, + fbank_length: Optional[torch.Tensor] = None, + fbank_masks: Optional[torch.Tensor] = None, + phone_items: Optional[torch.Tensor] = None, + phone_masks: Optional[torch.Tensor] = None, + features_only: Optional[torch.Tensor] = True, + mask: Optional[torch.Tensor] = False, + mask_prob: Optional[torch.Tensor] = None, + layer=None, + output_hidden_states=False): + + features, fbank_feature_length = self.subsample(fbank, fbank_length) + + if self.post_subsample_proj is not None: + features = self.post_subsample_proj(features) + + padding_mask = ( + torch.BoolTensor(features.shape[:2]).fill_(False) + # if self.pad_audio else None + ).to(features.device) + for i, l in enumerate(fbank_feature_length): + diff = l - padding_mask.shape[-1] + if diff < 0: + padding_mask[i, diff:] = True + + pre_encoder_features = features.clone() + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask, mask_prob=mask_prob) + else: + x = features + mask_indices = None + + padding_mask[~fbank_masks] = True + + phone_x = None + phone_padding_mask = None + if phone_items is not None: + phone_x = self.phone_item_embedding(phone_items) + 
phone_padding_mask = phone_items.eq(self.phone_padding_idx) + phone_padding_mask[~phone_masks] = True + if mask_indices is not None: + phone_mask_indices = phone_padding_mask.new_zeros( + phone_padding_mask.size()).bool() + mask_indices = torch.cat([mask_indices, phone_mask_indices], + dim=1) + + pre_padding_mask = padding_mask.clone() + x, layer_results, pos_embed, padding_mask = self.encoder( + x, + padding_mask=padding_mask, + phone_x=phone_x, + phone_padding_mask=phone_padding_mask, + layer=layer, + context_layer=6) + + emb_weight = self.phone_item_embedding.weight[ + 3:self.cfg.phone_dict_size, :] + if features_only is False: # no gradient for embedding here + emb_weight = emb_weight.detach() + + phone_distribution = F.linear(x, emb_weight, None) + + if features_only: + return MMSpeechEncoderOutput( + phone_distribution=phone_distribution.transpose(0, 1), + last_hidden_state=x, + padding_mask=padding_mask, + position_embedding=pos_embed) + + result = { + 'losses': {}, + } + + with torch.no_grad(): + self.encoder.eval() + y, y_layer_results, _, _ = self.encoder.extract_features( + pre_encoder_features, + padding_mask=pre_padding_mask, + phone_x=phone_x, + phone_padding_mask=phone_padding_mask, + min_layer= + 0, # self.cfg.encoder_layers - self.average_top_k_layers, + context_layer=6) + y = { + 'x': y, + 'padding_mask': padding_mask, + 'layer_results': y_layer_results, + } + + emb_weight = self.phone_item_embedding.weight[ + 3:self.cfg.phone_dict_size, :] + + y = F.linear(y['x'], emb_weight, None) + y = y[mask_indices] + self.encoder.train() + + y_student = phone_distribution[mask_indices] + + def _kl_loss(p, q): + loss = F.kl_div( + utils.log_softmax(p, dim=-1), + utils.softmax(q, dim=-1), + reduction='sum') + return loss + + y = y + kl_loss = _kl_loss(y_student.float(), y.float()) + + with torch.no_grad(): + result['target_var'] = self.compute_var(y) + result['pred_var'] = self.compute_var(y_student.float()) + + if self.num_updates > 5000 and result[ + 'target_var'] < self.cfg.min_target_var: + logger.error( + f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting" + ) + raise Exception( + f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting" + ) + if self.num_updates > 5000 and result[ + 'pred_var'] < self.cfg.min_pred_var: + logger.error( + f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting" + ) + raise Exception( + f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting" + ) + + return MMSpeechEncoderOutput( + phone_distribution=phone_distribution.transpose(0, 1), + last_hidden_state=x, + padding_mask=padding_mask, + position_embedding=pos_embed, + kl_loss=kl_loss) + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to *new_order*. 
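+        This is the hook used during beam search so that cached encoder states stay aligned
+        with the reordered beam hypotheses.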
+ + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + # if encoder_out["last_hidden_state"] is None: + if 'last_hidden_state' not in encoder_out: + new_encoder_out = None + else: + new_encoder_out = encoder_out['last_hidden_state'].index_select( + 0, new_order) + # if encoder_out["padding_mask"] is None: + if 'padding_mask' not in encoder_out: + new_encoder_padding_mask = None + else: + new_encoder_padding_mask = encoder_out[ + 'padding_mask'].index_select(0, new_order) + + # if encoder_out["position_embedding"] is None: + if 'position_embedding' not in encoder_out: + new_position_embeddings = None + else: + new_position_embeddings = encoder_out[ + 'position_embedding'].index_select(0, new_order) + + if 'hidden_states' not in encoder_out: + new_encoer_states = None + else: + encoder_states = encoder_out['hidden_states'] + new_encoer_states = () + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + new_encoer_states += (state.index_select(0, new_order), ) + + if 'attentions' not in encoder_out: + attentions = None + else: + attentions = encoder_out['attentions'] + + new_kl_loss = None + if 'kl_loss' in encoder_out: + new_kl_loss = encoder_out['kl_loss'] + + if len(encoder_out['phone_distribution']) == 0: + new_phone_distribution = None + else: + new_phone_distribution = encoder_out[ + 'phone_distribution'].index_select(1, new_order) + + return MMSpeechEncoderOutput( + phone_distribution=new_phone_distribution, + last_hidden_state=new_encoder_out, # B x T x C + padding_mask=new_encoder_padding_mask, # B x T + hidden_states=new_encoer_states, # List[T x B x C] + attentions=attentions, + position_embedding=new_position_embeddings, # B x T x C + kl_loss=new_kl_loss) + + @staticmethod + def compute_var(y): + y = y.view(-1, y.size(-1)) + if dist.is_initialized(): + zc = torch.tensor(y.size(0)).cuda() + zs = y.sum(dim=0) + zss = (y**2).sum(dim=0) + + dist.all_reduce(zc) + dist.all_reduce(zs) + dist.all_reduce(zss) + + var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1)) + return torch.sqrt(var + 1e-6).mean() + else: + return torch.sqrt(y.var(dim=0) + 1e-6).mean() + + +@add_start_docstrings( + 'The bare OFA Model outputting raw hidden-states without any specific head on top.', + MMSPEECH_START_DOCSTRING, +) +class MMSpeechModel(OFAModel): + r""" + The OFA model built with an encoder and a decoder only, without any classification head. + + Args: + config (MMSpeechConfig): OFA configuration. 
+ """ + + config_class = MMSpeechConfig + + def __init__(self, config: MMSpeechConfig, **kwargs): + super().__init__(config) + self.disable_entangle = getattr(kwargs, 'disable_entangle', False) + + self.padding_idx, vocab_size = config.pad_token_id, config.vocab_size + shared = nn.Embedding(vocab_size, config.d_model, self.padding_idx) + + self.encoder = MMSpeechEncoder(config, shared) + self.decoder = OFADecoder(config, shared) + self.use_ofasys = config.use_ofasys + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MMSPEECH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MMSpeechModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def get_encoder_normalized_probs(self, net_output, log_probs, **kwargs): + """Get normalized probabilities (or log probs) from a net's output.""" + logits = net_output['phone_distribution'] + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def forward(self, + input_ids=None, + patch_images=None, + patch_images_2=None, + patch_masks=None, + token_embeddings=None, + sample_patch_num=None, + fbank=None, + fbank_length=None, + fbank_masks=None, + phone_items=None, + phone_masks=None, + features_only=True, + mask=False, + mask_prob=None, + layer=None, + decoder_input_ids=None, + code_masks=None, + attention_mask=None, + encoder_outputs=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=False): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): + indices of input sequence tokens in the vocabular, and padding will be ignored by default; + + indices can be obtained using [`~OFATokenizer`]. + + patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the resized image, which are transformed by the default operations. + patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the second (if it exists) image. + patch_masks (`torch.BoolTensor`): the patches to be masked. + token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. + sample_patch_num (`int`): the number of patches to sample. + decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. + code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. + attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding. + encoder_outputs (`OFAEncoderOutput`): + encoder outputs with hidden states, positional embeddings, and padding masks. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of + shape `(bsz, num_heads, src_len, head_size)`. + use_cache (`bool`): whether to use cache for faster inference. + output_attentions (`bool`): whether to output attention weights. + output_hidden_states (`bool`): whether to output hidden states. + return_dict (`bool`): unused. Keep it for generation only. + + Returns: + OFASpeechOutput: + last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last decoder hidden states. 
+ past_key_values (`tuple(tuple(torch.FloatTensor)): past keys and values for faster inference. + decoder_hidden_states (`tuple(torch.FloatTensor)`): the decoder hidden states of all layers. + decoder_attentions (`tuple(torch.FloatTensor)): the decoder self attention weights of all layers. + cross_attentions (`tuple(torch.FloatTensor)): cross attention weights of all layers. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + the encoder last hidden state. + encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + the encoder states of all layers including the embeddings. + encoder_attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`): + the encoder attention weights of all layers. + """ # noqa + + output_attentions = output_attentions if output_attentions else self.config.output_attentions + output_hidden_states = ( + output_hidden_states + if output_hidden_states else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if encoder_outputs is None: + encoder_outputs = self.encoder( + fbank=fbank, + fbank_length=fbank_length, + fbank_masks=fbank_masks, + phone_items=phone_items, + phone_masks=phone_masks, + features_only=features_only, + mask=mask, + mask_prob=mask_prob, + layer=layer) + + if decoder_input_ids.eq(self.config.pad_token_id).any(): + attention_mask = decoder_input_ids.eq(self.padding_idx) + + encoder_hidden_states = encoder_outputs.last_hidden_state + encoder_attention_mask = _expand_mask(encoder_outputs.padding_mask, + encoder_hidden_states.dtype, + decoder_input_ids.shape[-1]) + src_pos_embed = encoder_outputs.position_embedding + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + code_masks=code_masks, + src_pos_embed=src_pos_embed, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + return MMSpeechModelOutput( + logits=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + encoder_padding_mask=encoder_outputs.padding_mask, + phone_distribution=encoder_outputs.phone_distribution, + kl_loss=encoder_outputs.kl_loss) + + def _set_gradient_checkpointing(self, module, value=False): + r""" + Turn on the switch of gradient checkpointing. 
+ """ + if isinstance(module, (OFADecoder, MMSpeechEncoder)): + module.gradient_checkpointing = value diff --git a/modelscope/models/multi_modal/ofa/utils/constant.py b/modelscope/models/multi_modal/ofa/utils/constant.py index b3776f8f..48e90336 100644 --- a/modelscope/models/multi_modal/ofa/utils/constant.py +++ b/modelscope/models/multi_modal/ofa/utils/constant.py @@ -11,4 +11,5 @@ OFA_TASK_KEY_MAPPING = { Tasks.text_classification: OutputKeys.LABELS, Tasks.image_classification: OutputKeys.LABELS, Tasks.visual_entailment: OutputKeys.LABELS, + Tasks.auto_speech_recognition: OutputKeys.TEXT } diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py index 77dff54a..1ae746b7 100644 --- a/modelscope/models/multi_modal/ofa_for_all_tasks.py +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -19,7 +19,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_tokens from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile from modelscope.utils.trie import Trie -from .ofa import OFAModel, OFATokenizer, OFATokenizerZH +from .ofa import MMSpeechModel, OFAModel, OFATokenizer, OFATokenizerZH from .ofa.generate import sequence_generator as sg from .ofa.generate.utils import move_to_device from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks @@ -37,13 +37,20 @@ __all__ = ['OfaForAllTasks'] @MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) @MODELS.register_module(Tasks.text_summarization, module_name=Models.ofa) @MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) +@MODELS.register_module(Tasks.auto_speech_recognition, module_name=Models.ofa) class OfaForAllTasks(TorchModel): def __init__(self, model_dir, *args, **kwargs): super().__init__(model_dir=model_dir, *args, **kwargs) - model = OFAModel.from_pretrained(model_dir) self.cfg = Config.from_file( osp.join(model_dir, ModelFile.CONFIGURATION)) + multimodal_type = self.cfg.model.get('multimodal_type', 'default') + if multimodal_type == 'default': + model = OFAModel.from_pretrained(model_dir) + elif multimodal_type == 'mmspeech': + model = MMSpeechModel.from_pretrained(model_dir) + else: + raise NotImplementedError self.model = model.module if hasattr(model, 'module') else model self.language = self.cfg.model.get('language', 'en') if self.language == 'en': @@ -54,12 +61,20 @@ class OfaForAllTasks(TorchModel): raise NotImplementedError # there is some diff between here and our ofa code, # there will be no need to use param: use_bpe + if not model.use_ofasys: - self.tokenizer.add_tokens( - [''.format(i) for i in range(8192)]) - self.tokenizer.add_tokens( - [''.format(i) for i in range(1000)]) - self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) + if multimodal_type == 'default': + self.tokenizer.add_tokens( + [''.format(i) for i in range(8192)]) + self.tokenizer.add_tokens( + [''.format(i) for i in range(1000)]) + self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) + elif multimodal_type == 'mmspeech': + self.tokenizer.add_tokens('') + self.tokenizer.add_tokens( + [''.format(i) for i in range(30000)]) + self.cfg.update({'num_bins': 0, 'num_codes': 30000}) + self.batch_size = self.cfg.model.get('batch_size', 1) self.patch_image_size = self.cfg.model.get('patch_image_size', 480) self.max_image_size = self.cfg.model.get('max_image_size', 512) @@ -110,6 +125,7 @@ class OfaForAllTasks(TorchModel): Tasks.visual_question_answering: inference_d[self.gen_type], Tasks.text_classification: 
inference_d[self.gen_type], Tasks.image_classification: inference_d[self.gen_type], + Tasks.auto_speech_recognition: self._text_gen_inference, } pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' self.pattern = re.compile(pattern_str) diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index 13560229..060049ef 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -186,7 +186,10 @@ TASK_INPUTS = { # ============ audio tasks =================== Tasks.auto_speech_recognition: - InputType.AUDIO, + [InputType.AUDIO, { + 'wav': InputType.AUDIO, + 'text': InputType.TEXT + }], Tasks.speech_signal_process: InputType.AUDIO, Tasks.acoustic_echo_cancellation: { diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index 55906e43..d5c171a3 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from .video_multi_modal_embedding_pipeline import \ VideoMultiModalEmbeddingPipeline from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline + from .asr_pipeline import AutomaticSpeechRecognitionPipeline else: _import_structure = { @@ -26,7 +27,8 @@ else: 'video_multi_modal_embedding_pipeline': ['VideoMultiModalEmbeddingPipeline'], 'generative_multi_modal_embedding_pipeline': - ['GEMMMultiModalEmbeddingPipeline'] + ['GEMMMultiModalEmbeddingPipeline'], + 'asr_pipeline': ['AutomaticSpeechRecognitionPipeline'], } import sys diff --git a/modelscope/pipelines/multi_modal/asr_pipeline.py b/modelscope/pipelines/multi_modal/asr_pipeline.py new file mode 100644 index 00000000..3cb7439c --- /dev/null +++ b/modelscope/pipelines/multi_modal/asr_pipeline.py @@ -0,0 +1,54 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor, + Preprocessor) +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.auto_speech_recognition, module_name=Pipelines.ofa_asr) +class AutomaticSpeechRecognitionPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create an automatic speech recognition pipeline for prediction + Args: + model: model id on modelscope hub. 
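+            preprocessor: optional preprocessor instance; if not provided, an `OfaPreprocessor`
+                (or `MPlugPreprocessor` for MPlug models) is built from the model directory.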
+ """ + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None: + if isinstance(pipe_model, OfaForAllTasks): + preprocessor = OfaPreprocessor(pipe_model.model_dir) + elif isinstance(pipe_model, MPlugForAllTasks): + preprocessor = MPlugPreprocessor(pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return super().forward(inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 52cde61c..7ebedce1 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -53,7 +53,8 @@ class OfaPreprocessor(Preprocessor): Tasks.image_classification: OfaImageClassificationPreprocessor, Tasks.text_classification: OfaTextClassificationPreprocessor, Tasks.text_summarization: OfaSummarizationPreprocessor, - Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor + Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor, + Tasks.auto_speech_recognition: OfaASRPreprocessor } model_dir = model_dir if osp.exists(model_dir) else snapshot_download( model_dir) diff --git a/modelscope/preprocessors/ofa/__init__.py b/modelscope/preprocessors/ofa/__init__.py index 59b94b2b..ad6c3c48 100644 --- a/modelscope/preprocessors/ofa/__init__.py +++ b/modelscope/preprocessors/ofa/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from .asr import OfaASRPreprocessor from .image_captioning import OfaImageCaptioningPreprocessor from .image_classification import OfaImageClassificationPreprocessor from .ocr_recognition import OfaOcrRecognitionPreprocessor diff --git a/modelscope/preprocessors/ofa/asr.py b/modelscope/preprocessors/ofa/asr.py new file mode 100644 index 00000000..928698c6 --- /dev/null +++ b/modelscope/preprocessors/ofa/asr.py @@ -0,0 +1,121 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import os +import random +from pathlib import Path +from typing import Any, Dict + +import soundfile as sf +import torch +from fairseq.data.audio.feature_transforms import \ + CompositeAudioFeatureTransform +from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig + +from modelscope.utils.chinese_utils import pre_chinese +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor +from .utils.text2phone import Text2Phone + + +class OfaASRPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaASRPreprocessor, self).__init__(cfg, model_dir, mode, *args, + **kwargs) + # Initialize transform + self.data_cfg = S2TDataConfig( + Path(os.path.join(model_dir, 'fbank_config.yaml'))) + self.train_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict( + self.data_cfg.get_feature_transforms('train', True)) + self.test_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict( + self.data_cfg.get_feature_transforms('test', False)) + self.text2phone_tokenizer = Text2Phone( + os.path.join(model_dir, 'text2phone_dict.txt')) + self.phone_to_id, self.id_to_phone = self.build_phone_dict( + os.path.join(model_dir, 'phone_dict.txt')) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + speed = random.choice([0.9, 1.0, 1.1]) + wav, sr = sf.read(self.column_map['wav']) + fbank = self.prepare_fbank( + torch.tensor([wav], dtype=torch.float32), sr, speed, is_train=True) + fbank_mask = torch.tensor([True]) + sample = { + 'fbank': fbank, + 'fbank_mask': fbank_mask, + 'label': data[self.column_map['text']] + } + + target = sample['label'] + if self.language == 'zh': + target = pre_chinese(target, self.max_tgt_length) + sample['target'] = self.tokenize_text(target, add_bos=False) + else: + target = target.translate(self.transtab).strip() + target_token_list = target.strip().split() + target = ' '.join(target_token_list[:self.max_tgt_length]) + sample['target'] = self.tokenize_text(target, add_bos=False) + + phone_item = self.to_phone(target) - 3 + phone_mask = torch.tensor([False]) + + sample['phone_item'] = phone_item + sample['phone_mask'] = phone_mask + + sample['prev_output_tokens'] = torch.cat( + [self.bos_item, sample['target'][:-1]]) + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + speed = 1.0 + wav, sr = sf.read(data[self.column_map['wav']]) + fbank = self.prepare_fbank( + torch.tensor([wav], dtype=torch.float32), + sr, + speed, + is_train=False) + fbank_mask = torch.tensor([True]) + + sample = {'fbank': fbank, 'fbank_mask': fbank_mask} + + if 'text' in self.column_map and self.column_map['text'] in data: + sample['label'] = data[self.column_map['text']] + + # mock + sample['phone_item'] = torch.tensor([6, 6, 6]) + sample['phone_mask'] = torch.tensor([False]) + + return sample + + def to_phone(self, text): + phones = self.text2phone_tokenizer.trans(text) + ids = torch.tensor([self.phone_to_id[x] for x in phones.split(' ')]) + return ids + + def build_phone_dict(self, phone_dict_path): + phone_to_id = dict() + id_to_phone = dict() + with 
open(phone_dict_path, 'r') as phone_dict_file: + for i, line in enumerate(phone_dict_file): + phone = line.strip().split(' ')[0] + phone_to_id[phone] = i + id_to_phone[i] = phone_to_id + return phone_to_id, id_to_phone diff --git a/modelscope/preprocessors/ofa/base.py b/modelscope/preprocessors/ofa/base.py index e5c30ff8..64bec9c9 100644 --- a/modelscope/preprocessors/ofa/base.py +++ b/modelscope/preprocessors/ofa/base.py @@ -6,11 +6,14 @@ from os import path as osp import json import numpy as np import torch +import torchaudio from PIL import Image from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH from modelscope.preprocessors.image import load_image from modelscope.utils.trie import Trie +from .utils.audio_helper import (_get_kaldi_fbank, _get_torchaudio_fbank, + convert_waveform) from .utils.constant import OFA_TASK_KEY_MAPPING from .utils.random_help import set_torch_seed @@ -88,6 +91,9 @@ class OfaBasePreprocessor: + answer_item.tolist() + [tokenizer.eos_token_id]) + self.train_audio_feature_transforms = None + self.test_audio_feature_transforms = None + def tokenize_text(self, text, add_bos=True, add_eos=True): if text is None: return None @@ -163,3 +169,36 @@ class OfaBasePreprocessor: image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \ else load_image(path_or_url_or_pil) return image + + def prepare_fbank(self, waveform, sample_rate, speed, is_train): + waveform, _ = torchaudio.sox_effects.apply_effects_tensor( + waveform, sample_rate, + [['speed', str(speed)], ['rate', str(sample_rate)]]) + _waveform, _ = convert_waveform( + waveform, sample_rate, to_mono=True, normalize_volume=True) + # Kaldi compliance: 16-bit signed integers + _waveform = _waveform * (2**15) + _waveform = _waveform.numpy() + fbank = _get_kaldi_fbank(_waveform, sample_rate, 80) + if fbank is None: + fbank = _get_torchaudio_fbank(_waveform, sample_rate, 80) + if fbank is None: + raise ImportError( + 'Please install pyKaldi or torchaudio to enable fbank feature extraction' + ) + if is_train and self.train_audio_feature_transforms is not None: + fbank = self.train_audio_feature_transforms(fbank) + elif ~is_train and self.test_audio_feature_transforms( + fbank) is not None: + fbank = self.test_audio_feature_transforms(fbank) + + fbank = torch.from_numpy(fbank).float() + fbank = self.pack_frames(fbank) + return fbank + + def pack_frames(self, feature: torch.Tensor): + if self.cfg.n_frames_per_step == 1: + return feature + n_packed_frames = feature.shape[0] // self.cfg.n_frames_per_step + feature = feature[:self.cfg.n_frames_per_step * n_packed_frames] + return feature.reshape(n_packed_frames, -1) diff --git a/modelscope/preprocessors/ofa/utils/audio_helper.py b/modelscope/preprocessors/ofa/utils/audio_helper.py new file mode 100644 index 00000000..40cb2241 --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/audio_helper.py @@ -0,0 +1,91 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
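A self-contained sketch of the phone-dictionary handling that OfaASRPreprocessor relies on: each line of phone_dict.txt contributes one phone symbol (its first whitespace-separated column), the line index is the integer id, and the inverse map takes an id back to the phone string so that to_phone output can be decoded. The helper names below are illustrative, not part of the patch.

    import torch


    def build_phone_maps(phone_dict_path: str):
        """Build phone<->id maps from a one-phone-per-line dictionary file."""
        phone_to_id, id_to_phone = {}, {}
        with open(phone_dict_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                phone = line.strip().split(' ')[0]
                phone_to_id[phone] = i
                id_to_phone[i] = phone  # inverse map: id -> phone symbol
        return phone_to_id, id_to_phone


    def phones_to_ids(phones: str, phone_to_id: dict) -> torch.Tensor:
        """Mirror of OfaASRPreprocessor.to_phone: space-separated phones -> id tensor."""
        return torch.tensor([phone_to_id[p] for p in phones.split(' ')])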
+ +from typing import Optional, Tuple, Union + +import numpy as np +import torch + + +def convert_waveform( + waveform: Union[np.ndarray, torch.Tensor], + sample_rate: int, + normalize_volume: bool = False, + to_mono: bool = False, + to_sample_rate: Optional[int] = None, +) -> Tuple[Union[np.ndarray, torch.Tensor], int]: + """convert a waveform: + - to a target sample rate + - from multi-channel to mono channel + - volume normalization + + Args: + waveform (numpy.ndarray or torch.Tensor): 2D original waveform + (channels x length) + sample_rate (int): original sample rate + normalize_volume (bool): perform volume normalization + to_mono (bool): convert to mono channel if having multiple channels + to_sample_rate (Optional[int]): target sample rate + Returns: + waveform (numpy.ndarray): converted 2D waveform (channels x length) + sample_rate (float): target sample rate + """ + try: + import torchaudio.sox_effects as ta_sox + except ImportError: + raise ImportError('Please install torchaudio: pip install torchaudio') + + effects = [] + if normalize_volume: + effects.append(['gain', '-n']) + if to_sample_rate is not None and to_sample_rate != sample_rate: + effects.append(['rate', f'{to_sample_rate}']) + if to_mono and waveform.shape[0] > 1: + effects.append(['channels', '1']) + if len(effects) > 0: + is_np_input = isinstance(waveform, np.ndarray) + _waveform = torch.from_numpy(waveform) if is_np_input else waveform + converted, converted_sample_rate = ta_sox.apply_effects_tensor( + _waveform, sample_rate, effects) + if is_np_input: + converted = converted.numpy() + return converted, converted_sample_rate + return waveform, sample_rate + + +def _get_kaldi_fbank(waveform: np.ndarray, + sample_rate: int, + n_bins=80) -> Optional[np.ndarray]: + """Get mel-filter bank features via PyKaldi.""" + try: + from kaldi.feat.fbank import Fbank, FbankOptions + from kaldi.feat.mel import MelBanksOptions + from kaldi.feat.window import FrameExtractionOptions + from kaldi.matrix import Vector + + mel_opts = MelBanksOptions() + mel_opts.num_bins = n_bins + frame_opts = FrameExtractionOptions() + frame_opts.samp_freq = sample_rate + opts = FbankOptions() + opts.mel_opts = mel_opts + opts.frame_opts = frame_opts + fbank = Fbank(opts=opts) + features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy() + return features + except ImportError: + return None + + +def _get_torchaudio_fbank(waveform: np.ndarray, + sample_rate, + n_bins=80) -> Optional[np.ndarray]: + """Get mel-filter bank features via TorchAudio.""" + try: + import torchaudio.compliance.kaldi as ta_kaldi + + waveform = torch.from_numpy(waveform) + features = ta_kaldi.fbank( + waveform, num_mel_bins=n_bins, sample_frequency=sample_rate) + return features.numpy() + except ImportError: + return None diff --git a/modelscope/preprocessors/ofa/utils/collate.py b/modelscope/preprocessors/ofa/utils/collate.py index f7775680..440ea9a0 100644 --- a/modelscope/preprocessors/ofa/utils/collate.py +++ b/modelscope/preprocessors/ofa/utils/collate.py @@ -1,5 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
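The helpers above give an 80-bin log mel filter-bank front end with a PyKaldi-first, TorchAudio-fallback strategy; both paths expect the waveform scaled to the 16-bit integer range, as prepare_fbank does. A minimal sketch of the TorchAudio path for a mono wav file (function name and default bin count are illustrative):

    import soundfile as sf
    import torch
    import torchaudio.compliance.kaldi as ta_kaldi


    def wav_to_fbank(wav_path: str, n_bins: int = 80) -> torch.Tensor:
        """Read a mono wav and return (frames, n_bins) log mel filter-bank features."""
        wav, sample_rate = sf.read(wav_path)              # float waveform in [-1, 1]
        waveform = torch.tensor([wav], dtype=torch.float32)
        waveform = waveform * (2 ** 15)                   # kaldi compliance: 16-bit range
        return ta_kaldi.fbank(
            waveform, num_mel_bins=n_bins, sample_frequency=sample_rate)


    # e.g. features = wav_to_fbank('data/test/audios/asr_example_ofa.wav')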
+from typing import List + import numpy as np import torch @@ -13,14 +15,12 @@ def collate_fn(samples, pad_idx, eos_idx): pad_idx, eos_idx=eos_idx) - src_tokens = merge('source') - batch = { 'nsentences': len(samples), - 'net_input': { - 'input_ids': src_tokens, - }, + 'net_input': {}, } + if samples[0].get('source', None) is not None: + batch['net_input']['input_ids'] = merge('source') if samples[0].get('id', None) is not None: batch['id'] = np.array([s.get['id'] for s in samples]) if samples[0].get('target', None) is not None: @@ -70,6 +70,20 @@ def collate_fn(samples, pad_idx, eos_idx): [s['region_coord'] for s in samples], dim=0) if samples[0].get('sample', None) is not None: batch['samples'] = [s['sample'] for s in samples] + # For asr + if samples[0].get('fbank', None) is not None: + batch['net_input']['fbank'] = _collate_frames( + [s['fbank'] for s in samples]) + batch['net_input']['fbank_length'] = torch.tensor( + [s['fbank'].size(0) for s in samples], dtype=torch.long) + if samples[0].get('fbank_mask', None) is not None: + batch['net_input']['fbank_masks'] = torch.cat( + [s['fbank_mask'] for s in samples]) + if samples[0].get('phone_item', None) is not None: + batch['net_input']['phone_items'] = merge('phone_item') + batch['net_input']['phone_masks'] = torch.cat( + [s['phone_mask'] for s in samples]) + return batch @@ -113,3 +127,19 @@ def collate_tokens( for i, v in enumerate(values): copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) return res + + +def _collate_frames(frames: List[torch.Tensor]): + """ + Convert a list of 2D frames into a padded 3D tensor + Args: + frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is + length of i-th frame and f_dim is static dimension of features + Returns: + 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] + """ + max_len = max(frame.size(0) for frame in frames) + out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) + for i, v in enumerate(frames): + out[i, :v.size(0)] = v + return out diff --git a/modelscope/preprocessors/ofa/utils/constant.py b/modelscope/preprocessors/ofa/utils/constant.py index 102d27c0..8a33092e 100644 --- a/modelscope/preprocessors/ofa/utils/constant.py +++ b/modelscope/preprocessors/ofa/utils/constant.py @@ -9,5 +9,6 @@ OFA_TASK_KEY_MAPPING = { Tasks.visual_grounding: ['image', 'text'], Tasks.visual_question_answering: ['image', 'text'], Tasks.visual_entailment: ['image', 'text', 'text2'], - Tasks.text_to_image_synthesis: ['text'] + Tasks.text_to_image_synthesis: ['text'], + Tasks.auto_speech_recognition: ['wav', 'text'], } diff --git a/modelscope/preprocessors/ofa/utils/text2phone.py b/modelscope/preprocessors/ofa/utils/text2phone.py new file mode 100644 index 00000000..20773c85 --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/text2phone.py @@ -0,0 +1,192 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.utils.chinese_utils import normalize_chinese_number + + +class TrieNode(object): + + def __init__(self): + """ + Initialize your data structure here. + """ + self.data = {} + self.is_word = False + + +class Trie(object): + """ + trie-tree + """ + + def __init__(self): + """ + Initialize your data structure here. + """ + self.root = TrieNode() + + def insert(self, word): + """ + Inserts a word into the trie. 
+ :type word: str + :rtype: void + """ + node = self.root + for chars in word: + child = node.data.get(chars) + if not child: + node.data[chars] = TrieNode() + node = node.data[chars] + node.is_word = True + + def search(self, word): + """ + Returns if the word is in the trie. + :type word: str + :rtype: bool + """ + node = self.root + for chars in word: + node = node.data.get(chars) + if not node: + return False + return node.is_word + + def startsWith(self, prefix): + """ + Returns if there is any word in the trie that starts with the given prefix. + :type prefix: str + :rtype: bool + """ + node = self.root + for chars in prefix: + node = node.data.get(chars) + if not node: + return False + return True + + def get_start(self, prefix): + """ + Returns words started with prefix + :param prefix: + :return: words (list) + """ + + def get_key(pre, pre_node): + word_list = [] + if pre_node.is_word: + word_list.append(pre) + for x in pre_node.data.keys(): + word_list.extend(get_key(pre + str(x), pre_node.data.get(x))) + return word_list + + words = [] + if not self.startsWith(prefix): + return words + if self.search(prefix): + words.append(prefix) + return words + node = self.root + for chars in prefix: + node = node.data.get(chars) + return get_key(prefix, node) + + +class TrieTokenizer(Trie): + """ + word_split based on trie-tree + """ + + def __init__(self, dict_path): + super(TrieTokenizer, self).__init__() + self.dict_path = dict_path + self.create_trie_tree() + + def load_dict(self): + words = [] + with open(self.dict_path, mode='r', encoding='utf-8') as file: + for line in file: + words.append(line.strip().split('\t')[0].encode( + 'utf-8').decode('utf-8-sig')) + return words + + def create_trie_tree(self): + words = self.load_dict() + for word in words: + self.insert(word) + + def mine_tree(self, tree, sentence, trace_index): + if trace_index <= (len(sentence) - 1): + if sentence[trace_index] in tree.data: + trace_index = trace_index + 1 + trace_index = self.mine_tree( + tree.data[sentence[trace_index - 1]], sentence, + trace_index) + return trace_index + + def tokenize(self, sentence): + tokens = [] + sentence_len = len(sentence) + while sentence_len != 0: + trace_index = 0 + trace_index = self.mine_tree(self.root, sentence, trace_index) + + if trace_index == 0: + tokens.append(sentence[0:1]) + sentence = sentence[1:len(sentence)] + sentence_len = len(sentence) + else: + tokens.append(sentence[0:trace_index]) + sentence = sentence[trace_index:len(sentence)] + sentence_len = len(sentence) + + return tokens + + def combine(self, token_list): + flag = 0 + output = [] + temp = [] + for i in token_list: + if len(i) != 1: + if flag == 0: + output.append(i[::]) + else: + output.append(''.join(temp)) + output.append(i[::]) + temp = [] + flag = 0 + else: + if flag == 0: + temp.append(i) + flag = 1 + else: + temp.append(i) + return output + + +class Text2Phone: + + def __init__(self, phone_dict_path): + self.trie_cws = TrieTokenizer(phone_dict_path) + self.phone_map = self.get_phone_map(phone_dict_path) + + def get_phone_map(self, phone_dict_path): + phone_map = dict() + with open(phone_dict_path, 'r') as phone_map_file_reader: + for line in phone_map_file_reader: + key, phone_series = line.strip().split('\t') + if key not in phone_map: + phone_map[key] = phone_series + return phone_map + + def trans(self, text): + text = normalize_chinese_number(text) + tokens = self.trie_cws.tokenize(text) + phones = [] + for word in tokens: + if word in self.phone_map: + phones.append(self.phone_map[word]) + 
elif len(word) > 1: + for char in word: + if char in self.phone_map: + phones.append(self.phone_map[char]) + return ' '.join(phones) diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py index 3930febb..c8cf6db5 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py @@ -113,6 +113,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): self.use_rdrop = args.get('use_rdrop', False) self.reg_alpha = args.get('reg_alpha', 1.0) self.sample_patch_num = args.get('sample_patch_num', 196) + self.ctc_weight = args.get('ctc_weight', 0.0) self.constraint_start = None self.constraint_end = None @@ -141,6 +142,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): output = model.model(**sample['net_input']) loss, nll_loss, ntokens = self.compute_loss( output.logits, sample, update_num, reduce=reduce) + if self.ctc_weight > 0: + ctc_loss = self.compute_ctc_loss(model, output, sample) + loss = nll_loss + ctc_loss sample_size = ( sample['target'].size(0) if self.sentence_avg else ntokens) logging_output = { @@ -206,6 +210,32 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): constraint_end=self.constraint_end) return loss, nll_loss, ntokens + def compute_ctc_loss(self, model, output, sample): + lprobs = model.get_encoder_normalized_probs( + output, log_probs=True).contiguous() # (T, B, C) from the encoder + + non_padding_mask = ~output.encoder_padding_mask + input_lengths = non_padding_mask.long().sum(-1) + + target_lengths = sample['ctc_output_lengths'] + pad_mask = torch.arange(target_lengths.max()).expand([ + target_lengths.shape[0], -1 + ]).to(target_lengths) < target_lengths.unsqueeze(1) + targets_flat = sample['ctc_outputs'].masked_select(pad_mask) + + with torch.backends.cudnn.flags(enabled=False): + loss = F.ctc_loss( + lprobs, + targets_flat, + input_lengths, + target_lengths, + blank=self.blank_idx, + reduction='sum', + zero_infinity=True, + ) + + return loss + def get_schedule(scheduler): diff --git a/modelscope/utils/chinese_utils.py b/modelscope/utils/chinese_utils.py index e5fe7aa8..793c2050 100644 --- a/modelscope/utils/chinese_utils.py +++ b/modelscope/utils/chinese_utils.py @@ -1,5 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
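To make the new CTC term concrete, here is a small self-contained example of how padded per-sample targets are flattened into the 1-D layout F.ctc_loss expects, using the same masking trick as compute_ctc_loss. The log-probabilities are random stand-ins for the encoder output, and the tensor values and blank index are illustrative.

    import torch
    import torch.nn.functional as F

    # Padded CTC targets for a batch of 2, with true lengths 3 and 2.
    ctc_outputs = torch.tensor([[7, 8, 9, 0],
                                [5, 6, 0, 0]])
    target_lengths = torch.tensor([3, 2])

    # Keep only the first target_lengths[i] tokens of every row,
    # concatenated row-major into a single 1-D tensor.
    pad_mask = torch.arange(target_lengths.max()).expand(
        target_lengths.shape[0], -1).to(target_lengths) < target_lengths.unsqueeze(1)
    targets_flat = ctc_outputs.masked_select(pad_mask)    # tensor([7, 8, 9, 5, 6])

    # F.ctc_loss then takes (T, B, C) log-probs, the flat targets,
    # per-sample input lengths and target lengths.
    T, B, C = 10, 2, 32
    lprobs = torch.randn(T, B, C).log_softmax(-1)
    input_lengths = torch.full((B,), T, dtype=torch.long)
    loss = F.ctc_loss(lprobs, targets_flat, input_lengths, target_lengths,
                      blank=0, reduction='sum', zero_infinity=True)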
+import re +import string + +from zhconv import convert + +CHINESE_PUNCTUATION = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·!?。。' +ENGLISH_PUNCTUATION = string.punctuation + def is_chinese_char(word: str): chinese_punctuations = { @@ -33,3 +41,28 @@ def rebuild_chinese_str(string: str): return ' '.join(''.join([ f' {char} ' if is_chinese_char(char) else char for char in string ]).split()) + + +def normalize_chinese_number(text): + chinese_number = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九'] + new_text = '' + for x in text: + if x in '0123456789': + x = chinese_number[0] + new_text += x + new_text = convert(new_text, 'zh-hans') + return new_text + + +def pre_chinese(text, max_words): + + text = text.lower().replace(CHINESE_PUNCTUATION, + ' ').replace(ENGLISH_PUNCTUATION, ' ') + text = re.sub( + r'\s{2,}', + ' ', + text, + ) + text = text.rstrip('\n') + text = text.strip(' ')[:max_words] + return text diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt index 31e9601d..54049c56 100644 --- a/requirements/multi-modal.txt +++ b/requirements/multi-modal.txt @@ -8,6 +8,7 @@ pytorch_lightning<=1.7.7 # which introduced compatability issues that are being investigated rouge_score<=0.0.4 sacrebleu +soundfile taming-transformers-rom1504 timm tokenizers diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py index bd8a8d48..9e1b47a1 100644 --- a/tests/pipelines/test_ofa_tasks.py +++ b/tests/pipelines/test_ofa_tasks.py @@ -273,6 +273,14 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): result[OutputKeys.OUTPUT_IMG].save('result.png') print(f'Output written to {osp.abspath("result.png")}') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_asr_with_name(self): + model = 'damo/ofa_asr_pretrain_base_zh' + ofa_pipe = pipeline(Tasks.auto_speech_recognition, model=model) + example = {'wav': 'data/test/audios/asr_example_ofa.wav'} + result = ofa_pipe(example) + print(result[OutputKeys.TEXT]) + @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): self.compatibility_check()
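pre_chinese lower-cases the target text, maps punctuation to spaces, squeezes repeated whitespace and truncates to max_words characters. Because str.replace substitutes whole substrings rather than character sets, a per-character variant can be written with a translation table; a sketch under that assumption, with the Chinese punctuation set abbreviated:

    import re
    import string

    ENGLISH_PUNCTUATION = string.punctuation
    CHINESE_PUNCTUATION = '、。,!?;:「」『』《》【】()…—·'  # abbreviated

    _PUNCT_TABLE = str.maketrans(
        {ch: ' ' for ch in CHINESE_PUNCTUATION + ENGLISH_PUNCTUATION})


    def pre_chinese_charwise(text: str, max_words: int) -> str:
        """Lower-case, map each punctuation character to a space, squeeze spaces, truncate."""
        text = text.lower().translate(_PUNCT_TABLE)
        text = re.sub(r'\s{2,}', ' ', text).strip()
        return text[:max_words]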