Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9410174 (master)
@@ -20,8 +20,10 @@ class GenericAutomaticSpeechRecognition(Model): | |||
Args: | |||
model_dir (str): the model path. | |||
am_model_name (str): the acoustic model (AM) name from configuration.json | |||
model_config (Dict[str, Any]): the detailed model config from configuration.json | |||
""" | |||
super().__init__(model_dir, am_model_name, model_config, *args, | |||
**kwargs) | |||
self.model_cfg = { | |||
# the recognition model dir path | |||
'model_workspace': model_dir, | |||
@@ -312,5 +312,11 @@ TASK_OUTPUTS = { | |||
# { | |||
# "text": "this is the text generated by a model." | |||
# } | |||
Tasks.visual_question_answering: [OutputKeys.TEXT] | |||
Tasks.visual_question_answering: [OutputKeys.TEXT], | |||
# auto_speech_recognition result for a single sample | |||
# { | |||
# "text": "每天都要快乐喔" | |||
# } | |||
Tasks.auto_speech_recognition: [OutputKeys.TEXT] | |||
} |
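# Usage sketch for the new ASR entry in TASK_OUTPUTS above. The model id and the wav
# path are placeholders, and the call is assumed to go through the generic pipeline
# factory; the returned dict exposes the transcript under OutputKeys.TEXT.
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

asr = pipeline(task=Tasks.auto_speech_recognition, model='<asr-model-id>')  # placeholder id
result = asr('example.wav')        # placeholder local wav file
print(result[OutputKeys.TEXT])     # e.g. "每天都要快乐喔"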
@@ -3,7 +3,7 @@ | |||
from modelscope.utils.error import TENSORFLOW_IMPORT_ERROR | |||
try: | |||
from .asr.asr_inference_pipeline import AutomaticSpeechRecognitionPipeline | |||
from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline | |||
from .kws_kwsbp_pipeline import * # noqa F403 | |||
from .linear_aec_pipeline import LinearAECPipeline | |||
except ModuleNotFoundError as e: | |||
@@ -1,21 +0,0 @@ | |||
import ssl | |||
import nltk | |||
try: | |||
_create_unverified_https_context = ssl._create_unverified_context | |||
except AttributeError: | |||
pass | |||
else: | |||
ssl._create_default_https_context = _create_unverified_https_context | |||
try: | |||
nltk.data.find('taggers/averaged_perceptron_tagger') | |||
except LookupError: | |||
nltk.download( | |||
'averaged_perceptron_tagger', halt_on_error=False, raise_on_error=True) | |||
try: | |||
nltk.data.find('corpora/cmudict') | |||
except LookupError: | |||
nltk.download('cmudict', halt_on_error=False, raise_on_error=True) |
@@ -1,690 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed from espnet/espnet. | |||
import argparse | |||
import logging | |||
import sys | |||
import time | |||
from pathlib import Path | |||
from typing import Any, Optional, Sequence, Tuple, Union | |||
import numpy as np | |||
import torch | |||
from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer | |||
from espnet2.asr.transducer.beam_search_transducer import \ | |||
ExtendedHypothesis as ExtTransHypothesis # noqa: H301 | |||
from espnet2.asr.transducer.beam_search_transducer import \ | |||
Hypothesis as TransHypothesis | |||
from espnet2.fileio.datadir_writer import DatadirWriter | |||
from espnet2.tasks.lm import LMTask | |||
from espnet2.text.build_tokenizer import build_tokenizer | |||
from espnet2.text.token_id_converter import TokenIDConverter | |||
from espnet2.torch_utils.device_funcs import to_device | |||
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed | |||
from espnet2.utils import config_argparse | |||
from espnet2.utils.types import str2bool, str2triple_str, str_or_none | |||
from espnet.nets.batch_beam_search import BatchBeamSearch | |||
from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim | |||
from espnet.nets.beam_search import BeamSearch, Hypothesis | |||
from espnet.nets.pytorch_backend.transformer.subsampling import \ | |||
TooShortUttError | |||
from espnet.nets.scorer_interface import BatchScorerInterface | |||
from espnet.nets.scorers.ctc import CTCPrefixScorer | |||
from espnet.nets.scorers.length_bonus import LengthBonus | |||
from espnet.utils.cli_utils import get_commandline_args | |||
from typeguard import check_argument_types | |||
from .espnet.asr.frontend.wav_frontend import WavFrontend | |||
from .espnet.tasks.asr import ASRTaskNAR as ASRTask | |||
class Speech2Text: | |||
def __init__(self, | |||
asr_train_config: Union[Path, str] = None, | |||
asr_model_file: Union[Path, str] = None, | |||
transducer_conf: dict = None, | |||
lm_train_config: Union[Path, str] = None, | |||
lm_file: Union[Path, str] = None, | |||
ngram_scorer: str = 'full', | |||
ngram_file: Union[Path, str] = None, | |||
token_type: str = None, | |||
bpemodel: str = None, | |||
device: str = 'cpu', | |||
maxlenratio: float = 0.0, | |||
minlenratio: float = 0.0, | |||
batch_size: int = 1, | |||
dtype: str = 'float32', | |||
beam_size: int = 20, | |||
ctc_weight: float = 0.5, | |||
lm_weight: float = 1.0, | |||
ngram_weight: float = 0.9, | |||
penalty: float = 0.0, | |||
nbest: int = 1, | |||
streaming: bool = False, | |||
frontend_conf: dict = None): | |||
assert check_argument_types() | |||
# 1. Build ASR model | |||
scorers = {} | |||
asr_model, asr_train_args = ASRTask.build_model_from_file( | |||
asr_train_config, asr_model_file, device) | |||
if asr_model.frontend is None and frontend_conf is not None: | |||
frontend = WavFrontend(**frontend_conf) | |||
asr_model.frontend = frontend | |||
asr_model.to(dtype=getattr(torch, dtype)).eval() | |||
decoder = asr_model.decoder | |||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) | |||
token_list = asr_model.token_list | |||
scorers.update( | |||
decoder=decoder, | |||
ctc=ctc, | |||
length_bonus=LengthBonus(len(token_list)), | |||
) | |||
# 2. Build Language model | |||
if lm_train_config is not None: | |||
lm, lm_train_args = LMTask.build_model_from_file( | |||
lm_train_config, lm_file, device) | |||
scorers['lm'] = lm.lm | |||
# 3. Build ngram model | |||
if ngram_file is not None: | |||
if ngram_scorer == 'full': | |||
from espnet.nets.scorers.ngram import NgramFullScorer | |||
ngram = NgramFullScorer(ngram_file, token_list) | |||
else: | |||
from espnet.nets.scorers.ngram import NgramPartScorer | |||
ngram = NgramPartScorer(ngram_file, token_list) | |||
else: | |||
ngram = None | |||
scorers['ngram'] = ngram | |||
# 4. Build BeamSearch object | |||
if asr_model.use_transducer_decoder: | |||
beam_search_transducer = BeamSearchTransducer( | |||
decoder=asr_model.decoder, | |||
joint_network=asr_model.joint_network, | |||
beam_size=beam_size, | |||
lm=scorers['lm'] if 'lm' in scorers else None, | |||
lm_weight=lm_weight, | |||
**transducer_conf, | |||
) | |||
beam_search = None | |||
else: | |||
beam_search_transducer = None | |||
weights = dict( | |||
decoder=1.0 - ctc_weight, | |||
ctc=ctc_weight, | |||
lm=lm_weight, | |||
ngram=ngram_weight, | |||
length_bonus=penalty, | |||
) | |||
beam_search = BeamSearch( | |||
beam_size=beam_size, | |||
weights=weights, | |||
scorers=scorers, | |||
sos=asr_model.sos, | |||
eos=asr_model.eos, | |||
vocab_size=len(token_list), | |||
token_list=token_list, | |||
pre_beam_score_key=None if ctc_weight == 1.0 else 'full', | |||
) | |||
# TODO(karita): make all scorers batchfied | |||
if batch_size == 1: | |||
non_batch = [ | |||
k for k, v in beam_search.full_scorers.items() | |||
if not isinstance(v, BatchScorerInterface) | |||
] | |||
if len(non_batch) == 0: | |||
if streaming: | |||
beam_search.__class__ = BatchBeamSearchOnlineSim | |||
beam_search.set_streaming_config(asr_train_config) | |||
logging.info( | |||
'BatchBeamSearchOnlineSim implementation is selected.' | |||
) | |||
else: | |||
beam_search.__class__ = BatchBeamSearch | |||
else: | |||
logging.warning( | |||
f'As non-batch scorers {non_batch} are found, ' | |||
f'fall back to non-batch implementation.') | |||
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() | |||
for scorer in scorers.values(): | |||
if isinstance(scorer, torch.nn.Module): | |||
scorer.to( | |||
device=device, dtype=getattr(torch, dtype)).eval() | |||
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text | |||
if token_type is None: | |||
token_type = asr_train_args.token_type | |||
if bpemodel is None: | |||
bpemodel = asr_train_args.bpemodel | |||
if token_type is None: | |||
tokenizer = None | |||
elif token_type == 'bpe': | |||
if bpemodel is not None: | |||
tokenizer = build_tokenizer( | |||
token_type=token_type, bpemodel=bpemodel) | |||
else: | |||
tokenizer = None | |||
else: | |||
tokenizer = build_tokenizer(token_type=token_type) | |||
converter = TokenIDConverter(token_list=token_list) | |||
self.asr_model = asr_model | |||
self.asr_train_args = asr_train_args | |||
self.converter = converter | |||
self.tokenizer = tokenizer | |||
self.beam_search = beam_search | |||
self.beam_search_transducer = beam_search_transducer | |||
self.maxlenratio = maxlenratio | |||
self.minlenratio = minlenratio | |||
self.device = device | |||
self.dtype = dtype | |||
self.nbest = nbest | |||
@torch.no_grad() | |||
def __call__(self, speech: Union[torch.Tensor, np.ndarray]): | |||
"""Inference | |||
Args: | |||
data: Input speech data | |||
Returns: | |||
text, token, token_int, hyp | |||
""" | |||
assert check_argument_types() | |||
# Input as audio signal | |||
if isinstance(speech, np.ndarray): | |||
speech = torch.tensor(speech) | |||
# data: (Nsamples,) -> (1, Nsamples) | |||
speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) | |||
# lengths: (1,) | |||
lengths = speech.new_full([1], | |||
dtype=torch.long, | |||
fill_value=speech.size(1)) | |||
batch = {'speech': speech, 'speech_lengths': lengths} | |||
# a. To device | |||
batch = to_device(batch, device=self.device) | |||
# b. Forward Encoder | |||
enc, enc_len = self.asr_model.encode(**batch) | |||
if isinstance(enc, tuple): | |||
enc = enc[0] | |||
assert len(enc) == 1, len(enc) | |||
predictor_outs = self.asr_model.calc_predictor(enc, enc_len) | |||
pre_acoustic_embeds, pre_token_length = predictor_outs[ | |||
0], predictor_outs[1] | |||
pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], | |||
device=pre_acoustic_embeds.device) | |||
decoder_outs = self.asr_model.cal_decoder_with_predictor( | |||
enc, enc_len, pre_acoustic_embeds, pre_token_length) | |||
decoder_out = decoder_outs[0] | |||
yseq = decoder_out.argmax(dim=-1) | |||
score = decoder_out.max(dim=-1)[0] | |||
score = torch.sum(score, dim=-1) | |||
# prepend sos and append eos so the hypothesis matches the format expected downstream | |||
yseq = torch.tensor( | |||
[self.asr_model.sos] + yseq.tolist()[0] + [self.asr_model.eos], | |||
device=yseq.device) | |||
nbest_hyps = [Hypothesis(yseq=yseq, score=score)] | |||
results = [] | |||
for hyp in nbest_hyps: | |||
assert isinstance(hyp, (Hypothesis, TransHypothesis)), type(hyp) | |||
# remove sos/eos and get results | |||
last_pos = None if self.asr_model.use_transducer_decoder else -1 | |||
if isinstance(hyp.yseq, list): | |||
token_int = hyp.yseq[1:last_pos] | |||
else: | |||
token_int = hyp.yseq[1:last_pos].tolist() | |||
# remove blank symbol id, which is assumed to be 0 | |||
token_int = list(filter(lambda x: x != 0, token_int)) | |||
# Change integer-ids to tokens | |||
token = self.converter.ids2tokens(token_int) | |||
if self.tokenizer is not None: | |||
text = self.tokenizer.tokens2text(token) | |||
else: | |||
text = None | |||
results.append((text, token, token_int, hyp, speech.size(1))) | |||
return results | |||
@staticmethod | |||
def from_pretrained( | |||
model_tag: Optional[str] = None, | |||
**kwargs: Optional[Any], | |||
): | |||
"""Build Speech2Text instance from the pretrained model. | |||
Args: | |||
model_tag (Optional[str]): Model tag of the pretrained models. | |||
Currently, the tags of espnet_model_zoo are supported. | |||
Returns: | |||
Speech2Text: Speech2Text instance. | |||
""" | |||
if model_tag is not None: | |||
try: | |||
from espnet_model_zoo.downloader import ModelDownloader | |||
except ImportError: | |||
logging.error( | |||
'`espnet_model_zoo` is not installed. ' | |||
'Please install via `pip install -U espnet_model_zoo`.') | |||
raise | |||
d = ModelDownloader() | |||
kwargs.update(**d.download_and_unpack(model_tag)) | |||
return Speech2Text(**kwargs) | |||
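# Minimal usage sketch for the Speech2Text class above; the config/checkpoint paths are
# hypothetical placeholders and one second of silence stands in for real 16 kHz audio.
import numpy as np

speech2text = Speech2Text(
    asr_train_config='exp/asr_train/config.yaml',       # hypothetical path
    asr_model_file='exp/asr_train/valid.acc.best.pth',   # hypothetical path
    device='cpu',
)
waveform = np.zeros(16000, dtype=np.float32)  # (Nsamples,) mono waveform
for text, token, token_int, hyp, n_samples in speech2text(waveform):
    print(text, token_int)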
def inference( | |||
output_dir: str, | |||
maxlenratio: float, | |||
minlenratio: float, | |||
batch_size: int, | |||
dtype: str, | |||
beam_size: int, | |||
ngpu: int, | |||
seed: int, | |||
ctc_weight: float, | |||
lm_weight: float, | |||
ngram_weight: float, | |||
penalty: float, | |||
nbest: int, | |||
num_workers: int, | |||
log_level: Union[int, str], | |||
data_path_and_name_and_type: Sequence[Tuple[str, str, str]], | |||
key_file: Optional[str], | |||
asr_train_config: Optional[str], | |||
asr_model_file: Optional[str], | |||
lm_train_config: Optional[str], | |||
lm_file: Optional[str], | |||
word_lm_train_config: Optional[str], | |||
word_lm_file: Optional[str], | |||
ngram_file: Optional[str], | |||
model_tag: Optional[str], | |||
token_type: Optional[str], | |||
bpemodel: Optional[str], | |||
allow_variable_data_keys: bool, | |||
transducer_conf: Optional[dict], | |||
streaming: bool, | |||
frontend_conf: dict = None, | |||
): | |||
assert check_argument_types() | |||
if batch_size > 1: | |||
raise NotImplementedError('batch decoding is not implemented') | |||
if word_lm_train_config is not None: | |||
raise NotImplementedError('Word LM is not implemented') | |||
if ngpu > 1: | |||
raise NotImplementedError('only single GPU decoding is supported') | |||
logging.basicConfig( | |||
level=log_level, | |||
format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', | |||
) | |||
if ngpu >= 1: | |||
device = 'cuda' | |||
else: | |||
device = 'cpu' | |||
# 1. Set random-seed | |||
set_all_random_seed(seed) | |||
# 2. Build speech2text | |||
speech2text_kwargs = dict( | |||
asr_train_config=asr_train_config, | |||
asr_model_file=asr_model_file, | |||
transducer_conf=transducer_conf, | |||
lm_train_config=lm_train_config, | |||
lm_file=lm_file, | |||
ngram_file=ngram_file, | |||
token_type=token_type, | |||
bpemodel=bpemodel, | |||
device=device, | |||
maxlenratio=maxlenratio, | |||
minlenratio=minlenratio, | |||
dtype=dtype, | |||
beam_size=beam_size, | |||
ctc_weight=ctc_weight, | |||
lm_weight=lm_weight, | |||
ngram_weight=ngram_weight, | |||
penalty=penalty, | |||
nbest=nbest, | |||
streaming=streaming, | |||
frontend_conf=frontend_conf, | |||
) | |||
speech2text = Speech2Text.from_pretrained( | |||
model_tag=model_tag, | |||
**speech2text_kwargs, | |||
) | |||
# 3. Build data-iterator | |||
loader = ASRTask.build_streaming_iterator( | |||
data_path_and_name_and_type, | |||
dtype=dtype, | |||
batch_size=batch_size, | |||
key_file=key_file, | |||
num_workers=num_workers, | |||
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, | |||
False), | |||
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), | |||
allow_variable_data_keys=allow_variable_data_keys, | |||
inference=True, | |||
) | |||
forward_time_total = 0.0 | |||
length_total = 0.0 | |||
# 7. Start for-loop | |||
# FIXME(kamo): The output format should be discussed | |||
with DatadirWriter(output_dir) as writer: | |||
for keys, batch in loader: | |||
assert isinstance(batch, dict), type(batch) | |||
assert all(isinstance(s, str) for s in keys), keys | |||
_bs = len(next(iter(batch.values()))) | |||
assert len(keys) == _bs, f'{len(keys)} != {_bs}' | |||
batch = { | |||
k: v[0] | |||
for k, v in batch.items() if not k.endswith('_lengths') | |||
} | |||
# N-best list of (text, token, token_int, hyp_object) | |||
try: | |||
time_beg = time.time() | |||
results = speech2text(**batch) | |||
time_end = time.time() | |||
forward_time = time_end - time_beg | |||
length = results[0][-1] | |||
results = [results[0][:-1]] | |||
forward_time_total += forward_time | |||
length_total += length | |||
except TooShortUttError as e: | |||
logging.warning(f'Utterance {keys} {e}') | |||
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) | |||
results = [[' ', ['<space>'], [2], hyp]] * nbest | |||
# Only supporting batch_size==1 | |||
key = keys[0] | |||
for n, (text, token, token_int, | |||
hyp) in zip(range(1, nbest + 1), results): | |||
# Create a directory: outdir/{n}best_recog | |||
ibest_writer = writer[f'{n}best_recog'] | |||
# Write the result to each file | |||
ibest_writer['token'][key] = ' '.join(token) | |||
ibest_writer['token_int'][key] = ' '.join(map(str, token_int)) | |||
ibest_writer['score'][key] = str(hyp.score) | |||
if text is not None: | |||
ibest_writer['text'][key] = text | |||
logging.info( | |||
'decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}' | |||
.format(length_total, forward_time_total, | |||
100 * forward_time_total / length_total)) | |||
def get_parser(): | |||
parser = config_argparse.ArgumentParser( | |||
description='ASR Decoding', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
) | |||
# Note(kamo): Use '_' instead of '-' as separator. | |||
# '-' is confusing if written in yaml. | |||
parser.add_argument( | |||
'--log_level', | |||
type=lambda x: x.upper(), | |||
default='INFO', | |||
choices=('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'), | |||
help='The verbose level of logging', | |||
) | |||
parser.add_argument('--output_dir', type=str, required=True) | |||
parser.add_argument( | |||
'--ngpu', | |||
type=int, | |||
default=0, | |||
help='The number of gpus. 0 indicates CPU mode', | |||
) | |||
parser.add_argument('--seed', type=int, default=0, help='Random seed') | |||
parser.add_argument( | |||
'--dtype', | |||
default='float32', | |||
choices=['float16', 'float32', 'float64'], | |||
help='Data type', | |||
) | |||
parser.add_argument( | |||
'--num_workers', | |||
type=int, | |||
default=1, | |||
help='The number of workers used for DataLoader', | |||
) | |||
group = parser.add_argument_group('Input data related') | |||
group.add_argument( | |||
'--data_path_and_name_and_type', | |||
type=str2triple_str, | |||
required=True, | |||
action='append', | |||
) | |||
group.add_argument('--key_file', type=str_or_none) | |||
group.add_argument( | |||
'--allow_variable_data_keys', type=str2bool, default=False) | |||
group = parser.add_argument_group('The model configuration related') | |||
group.add_argument( | |||
'--asr_train_config', | |||
type=str, | |||
help='ASR training configuration', | |||
) | |||
group.add_argument( | |||
'--asr_model_file', | |||
type=str, | |||
help='ASR model parameter file', | |||
) | |||
group.add_argument( | |||
'--lm_train_config', | |||
type=str, | |||
help='LM training configuration', | |||
) | |||
group.add_argument( | |||
'--lm_file', | |||
type=str, | |||
help='LM parameter file', | |||
) | |||
group.add_argument( | |||
'--word_lm_train_config', | |||
type=str, | |||
help='Word LM training configuration', | |||
) | |||
group.add_argument( | |||
'--word_lm_file', | |||
type=str, | |||
help='Word LM parameter file', | |||
) | |||
group.add_argument( | |||
'--ngram_file', | |||
type=str, | |||
help='N-gram parameter file', | |||
) | |||
group.add_argument( | |||
'--model_tag', | |||
type=str, | |||
help='Pretrained model tag. If this option is specified, *_train_config and ' | |||
'*_file will be overwritten', | |||
) | |||
group = parser.add_argument_group('Beam-search related') | |||
group.add_argument( | |||
'--batch_size', | |||
type=int, | |||
default=1, | |||
help='The batch size for inference', | |||
) | |||
group.add_argument( | |||
'--nbest', type=int, default=1, help='Output N-best hypotheses') | |||
group.add_argument('--beam_size', type=int, default=20, help='Beam size') | |||
group.add_argument( | |||
'--penalty', type=float, default=0.0, help='Insertion penalty') | |||
group.add_argument( | |||
'--maxlenratio', | |||
type=float, | |||
default=0.0, | |||
help='Input length ratio to obtain max output length. ' | |||
'If maxlenratio=0.0 (default), it uses an end-detect ' | |||
'function ' | |||
'to automatically find maximum hypothesis lengths. ' | |||
'If maxlenratio<0.0, its absolute value is interpreted ' | |||
'as a constant max output length', | |||
) | |||
group.add_argument( | |||
'--minlenratio', | |||
type=float, | |||
default=0.0, | |||
help='Input length ratio to obtain min output length', | |||
) | |||
group.add_argument( | |||
'--ctc_weight', | |||
type=float, | |||
default=0.5, | |||
help='CTC weight in joint decoding', | |||
) | |||
group.add_argument( | |||
'--lm_weight', type=float, default=1.0, help='RNNLM weight') | |||
group.add_argument( | |||
'--ngram_weight', type=float, default=0.9, help='ngram weight') | |||
group.add_argument('--streaming', type=str2bool, default=False) | |||
group.add_argument( | |||
'--frontend_conf', | |||
default=None, | |||
help='', | |||
) | |||
group = parser.add_argument_group('Text converter related') | |||
group.add_argument( | |||
'--token_type', | |||
type=str_or_none, | |||
default=None, | |||
choices=['char', 'bpe', None], | |||
help='The token type for ASR model. ' | |||
'If not given, it is inferred from the training args', | |||
) | |||
group.add_argument( | |||
'--bpemodel', | |||
type=str_or_none, | |||
default=None, | |||
help='The model path of sentencepiece. ' | |||
'If not given, it is inferred from the training args', | |||
) | |||
group.add_argument( | |||
'--transducer_conf', | |||
default=None, | |||
help='The keyword arguments for transducer beam search.', | |||
) | |||
return parser | |||
def asr_inference( | |||
output_dir: str, | |||
maxlenratio: float, | |||
minlenratio: float, | |||
beam_size: int, | |||
ngpu: int, | |||
ctc_weight: float, | |||
lm_weight: float, | |||
penalty: float, | |||
data_path_and_name_and_type: Sequence[Tuple[str, str, str]], | |||
asr_train_config: Optional[str], | |||
asr_model_file: Optional[str], | |||
nbest: int = 1, | |||
num_workers: int = 1, | |||
log_level: Union[int, str] = 'INFO', | |||
batch_size: int = 1, | |||
dtype: str = 'float32', | |||
seed: int = 0, | |||
key_file: Optional[str] = None, | |||
lm_train_config: Optional[str] = None, | |||
lm_file: Optional[str] = None, | |||
word_lm_train_config: Optional[str] = None, | |||
word_lm_file: Optional[str] = None, | |||
ngram_file: Optional[str] = None, | |||
ngram_weight: float = 0.9, | |||
model_tag: Optional[str] = None, | |||
token_type: Optional[str] = None, | |||
bpemodel: Optional[str] = None, | |||
allow_variable_data_keys: bool = False, | |||
transducer_conf: Optional[dict] = None, | |||
streaming: bool = False, | |||
frontend_conf: dict = None, | |||
): | |||
inference( | |||
output_dir=output_dir, | |||
maxlenratio=maxlenratio, | |||
minlenratio=minlenratio, | |||
batch_size=batch_size, | |||
dtype=dtype, | |||
beam_size=beam_size, | |||
ngpu=ngpu, | |||
seed=seed, | |||
ctc_weight=ctc_weight, | |||
lm_weight=lm_weight, | |||
ngram_weight=ngram_weight, | |||
penalty=penalty, | |||
nbest=nbest, | |||
num_workers=num_workers, | |||
log_level=log_level, | |||
data_path_and_name_and_type=data_path_and_name_and_type, | |||
key_file=key_file, | |||
asr_train_config=asr_train_config, | |||
asr_model_file=asr_model_file, | |||
lm_train_config=lm_train_config, | |||
lm_file=lm_file, | |||
word_lm_train_config=word_lm_train_config, | |||
word_lm_file=word_lm_file, | |||
ngram_file=ngram_file, | |||
model_tag=model_tag, | |||
token_type=token_type, | |||
bpemodel=bpemodel, | |||
allow_variable_data_keys=allow_variable_data_keys, | |||
transducer_conf=transducer_conf, | |||
streaming=streaming, | |||
frontend_conf=frontend_conf) | |||
def main(cmd=None): | |||
print(get_commandline_args(), file=sys.stderr) | |||
parser = get_parser() | |||
args = parser.parse_args(cmd) | |||
kwargs = vars(args) | |||
kwargs.pop('config', None) | |||
inference(**kwargs) | |||
if __name__ == '__main__': | |||
main() |
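# Programmatic usage sketch for asr_inference(); every path is a hypothetical placeholder
# and the 'sound' entry type follows the espnet2 wav.scp reader convention.
asr_inference(
    output_dir='exp/decode_test',
    maxlenratio=0.0,
    minlenratio=0.0,
    beam_size=20,
    ngpu=0,
    ctc_weight=0.5,
    lm_weight=1.0,
    penalty=0.0,
    data_path_and_name_and_type=[('data/test/wav.scp', 'speech', 'sound')],
    asr_train_config='exp/asr_train/config.yaml',
    asr_model_file='exp/asr_train/valid.acc.best.pth',
)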
@@ -1,193 +0,0 @@ | |||
import os | |||
from typing import Any, Dict, List | |||
import numpy as np | |||
def type_checking(wav_path: str, | |||
recog_type: str = None, | |||
audio_format: str = None, | |||
workspace: str = None): | |||
assert os.path.exists(wav_path), f'wav_path:{wav_path} does not exist' | |||
r_recog_type = recog_type | |||
r_audio_format = audio_format | |||
r_workspace = workspace | |||
r_wav_path = wav_path | |||
if r_workspace is None or len(r_workspace) == 0: | |||
r_workspace = os.path.join(os.getcwd(), '.tmp') | |||
if r_recog_type is None: | |||
if os.path.isfile(wav_path): | |||
if wav_path.endswith('.wav') or wav_path.endswith('.WAV'): | |||
r_recog_type = 'wav' | |||
r_audio_format = 'wav' | |||
elif os.path.isdir(wav_path): | |||
dir_name = os.path.basename(wav_path) | |||
if 'test' in dir_name: | |||
r_recog_type = 'test' | |||
elif 'dev' in dir_name: | |||
r_recog_type = 'dev' | |||
elif 'train' in dir_name: | |||
r_recog_type = 'train' | |||
if r_audio_format is None: | |||
if find_file_by_ends(wav_path, '.ark'): | |||
r_audio_format = 'kaldi_ark' | |||
elif find_file_by_ends(wav_path, '.wav') or find_file_by_ends( | |||
wav_path, '.WAV'): | |||
r_audio_format = 'wav' | |||
if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav': | |||
# datasets with kaldi_ark file | |||
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../')) | |||
elif r_audio_format == 'wav' and r_recog_type != 'wav': | |||
# datasets with waveform files | |||
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../')) | |||
return r_recog_type, r_audio_format, r_workspace, r_wav_path | |||
def find_file_by_ends(dir_path: str, ends: str): | |||
dir_files = os.listdir(dir_path) | |||
for file in dir_files: | |||
file_path = os.path.join(dir_path, file) | |||
if os.path.isfile(file_path): | |||
if file_path.endswith(ends): | |||
return True | |||
elif os.path.isdir(file_path): | |||
if find_file_by_ends(file_path, ends): | |||
return True | |||
return False | |||
def compute_wer(hyp_text_path: str, ref_text_path: str) -> Dict[str, Any]: | |||
assert os.path.exists(hyp_text_path), 'hyp_text does not exist' | |||
assert os.path.exists(ref_text_path), 'ref_text does not exist' | |||
rst = { | |||
'Wrd': 0, | |||
'Corr': 0, | |||
'Ins': 0, | |||
'Del': 0, | |||
'Sub': 0, | |||
'Snt': 0, | |||
'Err': 0.0, | |||
'S.Err': 0.0, | |||
'wrong_words': 0, | |||
'wrong_sentences': 0 | |||
} | |||
with open(ref_text_path, 'r', encoding='utf-8') as r: | |||
r_lines = r.readlines() | |||
with open(hyp_text_path, 'r', encoding='utf-8') as h: | |||
h_lines = h.readlines() | |||
for r_line in r_lines: | |||
r_line_item = r_line.split() | |||
r_key = r_line_item[0] | |||
r_sentence = r_line_item[1] | |||
for h_line in h_lines: | |||
# find sentence from hyp text | |||
if r_key in h_line: | |||
h_line_item = h_line.split() | |||
h_sentence = h_line_item[1] | |||
out_item = compute_wer_by_line(h_sentence, r_sentence) | |||
rst['Wrd'] += out_item['nwords'] | |||
rst['Corr'] += out_item['cor'] | |||
rst['wrong_words'] += out_item['wrong'] | |||
rst['Ins'] += out_item['ins'] | |||
rst['Del'] += out_item['del'] | |||
rst['Sub'] += out_item['sub'] | |||
rst['Snt'] += 1 | |||
if out_item['wrong'] > 0: | |||
rst['wrong_sentences'] += 1 | |||
break | |||
if rst['Wrd'] > 0: | |||
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) | |||
if rst['Snt'] > 0: | |||
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) | |||
return rst | |||
def compute_wer_by_line(hyp: list, ref: list) -> Dict[str, Any]: | |||
len_hyp = len(hyp) | |||
len_ref = len(ref) | |||
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16) | |||
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8) | |||
for i in range(len_hyp + 1): | |||
cost_matrix[i][0] = i | |||
for j in range(len_ref + 1): | |||
cost_matrix[0][j] = j | |||
for i in range(1, len_hyp + 1): | |||
for j in range(1, len_ref + 1): | |||
if hyp[i - 1] == ref[j - 1]: | |||
cost_matrix[i][j] = cost_matrix[i - 1][j - 1] | |||
else: | |||
substitution = cost_matrix[i - 1][j - 1] + 1 | |||
insertion = cost_matrix[i - 1][j] + 1 | |||
deletion = cost_matrix[i][j - 1] + 1 | |||
compare_val = [substitution, insertion, deletion] | |||
min_val = min(compare_val) | |||
operation_idx = compare_val.index(min_val) + 1 | |||
cost_matrix[i][j] = min_val | |||
ops_matrix[i][j] = operation_idx | |||
match_idx = [] | |||
i = len_hyp | |||
j = len_ref | |||
rst = { | |||
'nwords': len_hyp, | |||
'cor': 0, | |||
'wrong': 0, | |||
'ins': 0, | |||
'del': 0, | |||
'sub': 0 | |||
} | |||
while i >= 0 or j >= 0: | |||
i_idx = max(0, i) | |||
j_idx = max(0, j) | |||
if ops_matrix[i_idx][j_idx] == 0: # correct | |||
if i - 1 >= 0 and j - 1 >= 0: | |||
match_idx.append((j - 1, i - 1)) | |||
rst['cor'] += 1 | |||
i -= 1 | |||
j -= 1 | |||
elif ops_matrix[i_idx][j_idx] == 2: # insert | |||
i -= 1 | |||
rst['ins'] += 1 | |||
elif ops_matrix[i_idx][j_idx] == 3: # delete | |||
j -= 1 | |||
rst['del'] += 1 | |||
elif ops_matrix[i_idx][j_idx] == 1: # substitute | |||
i -= 1 | |||
j -= 1 | |||
rst['sub'] += 1 | |||
if i < 0 and j >= 0: | |||
rst['del'] += 1 | |||
elif j < 0 and i >= 0: | |||
rst['ins'] += 1 | |||
match_idx.reverse() | |||
wrong_cnt = cost_matrix[len_hyp][len_ref] | |||
rst['wrong'] = wrong_cnt | |||
return rst |
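# Sanity-check sketch for compute_wer_by_line(); the hyp/ref strings are illustrative
# and are scored per character, since any indexable sequence works.
out = compute_wer_by_line(hyp='今天天气很好', ref='今天天气真好')
assert out['nwords'] == 6 and out['cor'] == 5 and out['sub'] == 1
assert out['wrong'] == 1 and out['ins'] == 0 and out['del'] == 0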
@@ -1,757 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed from espnet/espnet. | |||
"""Decoder definition.""" | |||
from typing import Any, List, Sequence, Tuple | |||
import torch | |||
from espnet2.asr.decoder.abs_decoder import AbsDecoder | |||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||
from espnet.nets.pytorch_backend.transformer.attention import \ | |||
MultiHeadedAttention | |||
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer | |||
from espnet.nets.pytorch_backend.transformer.dynamic_conv import \ | |||
DynamicConvolution | |||
from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \ | |||
DynamicConvolution2D | |||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
PositionalEncoding | |||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||
from espnet.nets.pytorch_backend.transformer.lightconv import \ | |||
LightweightConvolution | |||
from espnet.nets.pytorch_backend.transformer.lightconv2d import \ | |||
LightweightConvolution2D | |||
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask | |||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||
PositionwiseFeedForward # noqa: H301 | |||
from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||
from espnet.nets.scorer_interface import BatchScorerInterface | |||
from typeguard import check_argument_types | |||
class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface): | |||
"""Base class of Transfomer decoder module. | |||
Args: | |||
vocab_size: output dim | |||
encoder_output_size: dimension of attention | |||
attention_heads: the number of heads of multi head attention | |||
linear_units: the number of units of position-wise feed forward | |||
num_blocks: the number of decoder blocks | |||
dropout_rate: dropout rate | |||
self_attention_dropout_rate: dropout rate for attention | |||
input_layer: input layer type | |||
use_output_layer: whether to use output layer | |||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||
normalize_before: whether to use layer_norm before the first block | |||
concat_after: whether to concat attention layer's input and output | |||
if True, additional linear will be applied. | |||
i.e. x -> x + linear(concat(x, att(x))) | |||
if False, no additional linear will be applied. | |||
i.e. x -> x + att(x) | |||
""" | |||
def __init__( | |||
self, | |||
vocab_size: int, | |||
encoder_output_size: int, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
input_layer: str = 'embed', | |||
use_output_layer: bool = True, | |||
pos_enc_class=PositionalEncoding, | |||
normalize_before: bool = True, | |||
): | |||
assert check_argument_types() | |||
super().__init__() | |||
attention_dim = encoder_output_size | |||
if input_layer == 'embed': | |||
self.embed = torch.nn.Sequential( | |||
torch.nn.Embedding(vocab_size, attention_dim), | |||
pos_enc_class(attention_dim, positional_dropout_rate), | |||
) | |||
elif input_layer == 'linear': | |||
self.embed = torch.nn.Sequential( | |||
torch.nn.Linear(vocab_size, attention_dim), | |||
torch.nn.LayerNorm(attention_dim), | |||
torch.nn.Dropout(dropout_rate), | |||
torch.nn.ReLU(), | |||
pos_enc_class(attention_dim, positional_dropout_rate), | |||
) | |||
else: | |||
raise ValueError( | |||
f"only 'embed' or 'linear' is supported: {input_layer}") | |||
self.normalize_before = normalize_before | |||
if self.normalize_before: | |||
self.after_norm = LayerNorm(attention_dim) | |||
if use_output_layer: | |||
self.output_layer = torch.nn.Linear(attention_dim, vocab_size) | |||
else: | |||
self.output_layer = None | |||
# Must set by the inheritance | |||
self.decoders = None | |||
def forward( | |||
self, | |||
hs_pad: torch.Tensor, | |||
hlens: torch.Tensor, | |||
ys_in_pad: torch.Tensor, | |||
ys_in_lens: torch.Tensor, | |||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||
"""Forward decoder. | |||
Args: | |||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat) | |||
hlens: (batch) | |||
ys_in_pad: | |||
input token ids, int64 (batch, maxlen_out) | |||
if input_layer == "embed" | |||
input tensor (batch, maxlen_out, #mels) in the other cases | |||
ys_in_lens: (batch) | |||
Returns: | |||
(tuple): tuple containing: | |||
x: decoded token score before softmax (batch, maxlen_out, token) | |||
if use_output_layer is True, | |||
olens: (batch, ) | |||
""" | |||
tgt = ys_in_pad | |||
# tgt_mask: (B, 1, L) | |||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) | |||
# m: (1, L, L) | |||
m = subsequent_mask( | |||
tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) | |||
# tgt_mask: (B, L, L) | |||
tgt_mask = tgt_mask & m | |||
memory = hs_pad | |||
memory_mask = ( | |||
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( | |||
memory.device) | |||
# Padding for Longformer | |||
if memory_mask.shape[-1] != memory.shape[1]: | |||
padlen = memory.shape[1] - memory_mask.shape[-1] | |||
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), | |||
'constant', False) | |||
x = self.embed(tgt) | |||
x, tgt_mask, memory, memory_mask = self.decoders( | |||
x, tgt_mask, memory, memory_mask) | |||
if self.normalize_before: | |||
x = self.after_norm(x) | |||
if self.output_layer is not None: | |||
x = self.output_layer(x) | |||
olens = tgt_mask.sum(1) | |||
return x, olens | |||
def forward_one_step( | |||
self, | |||
tgt: torch.Tensor, | |||
tgt_mask: torch.Tensor, | |||
memory: torch.Tensor, | |||
cache: List[torch.Tensor] = None, | |||
) -> Tuple[torch.Tensor, List[torch.Tensor]]: | |||
"""Forward one step. | |||
Args: | |||
tgt: input token ids, int64 (batch, maxlen_out) | |||
tgt_mask: input token mask, (batch, maxlen_out) | |||
dtype=torch.uint8 in PyTorch 1.2- | |||
dtype=torch.bool in PyTorch 1.2+ (include 1.2) | |||
memory: encoded memory, float32 (batch, maxlen_in, feat) | |||
cache: cached output list of (batch, max_time_out-1, size) | |||
Returns: | |||
y, cache: NN output value and cache per `self.decoders`. | |||
`y.shape` is (batch, maxlen_out, token) | |||
""" | |||
x = self.embed(tgt) | |||
if cache is None: | |||
cache = [None] * len(self.decoders) | |||
new_cache = [] | |||
for c, decoder in zip(cache, self.decoders): | |||
x, tgt_mask, memory, memory_mask = decoder( | |||
x, tgt_mask, memory, None, cache=c) | |||
new_cache.append(x) | |||
if self.normalize_before: | |||
y = self.after_norm(x[:, -1]) | |||
else: | |||
y = x[:, -1] | |||
if self.output_layer is not None: | |||
y = torch.log_softmax(self.output_layer(y), dim=-1) | |||
return y, new_cache | |||
def score(self, ys, state, x): | |||
"""Score.""" | |||
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) | |||
logp, state = self.forward_one_step( | |||
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state) | |||
return logp.squeeze(0), state | |||
def batch_score(self, ys: torch.Tensor, states: List[Any], | |||
xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: | |||
"""Score new token batch. | |||
Args: | |||
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). | |||
states (List[Any]): Scorer states for prefix tokens. | |||
xs (torch.Tensor): | |||
The encoder feature that generates ys (n_batch, xlen, n_feat). | |||
Returns: | |||
tuple[torch.Tensor, List[Any]]: Tuple of | |||
batchfied scores for next token with shape of `(n_batch, n_vocab)` | |||
and next state list for ys. | |||
""" | |||
# merge states | |||
n_batch = len(ys) | |||
n_layers = len(self.decoders) | |||
if states[0] is None: | |||
batch_state = None | |||
else: | |||
# transpose state of [batch, layer] into [layer, batch] | |||
batch_state = [ | |||
torch.stack([states[b][i] for b in range(n_batch)]) | |||
for i in range(n_layers) | |||
] | |||
# batch decoding | |||
ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0) | |||
logp, states = self.forward_one_step( | |||
ys, ys_mask, xs, cache=batch_state) | |||
# transpose state of [layer, batch] into [batch, layer] | |||
state_list = [[states[i][b] for i in range(n_layers)] | |||
for b in range(n_batch)] | |||
return logp, state_list | |||
class TransformerDecoder(BaseTransformerDecoder): | |||
def __init__( | |||
self, | |||
vocab_size: int, | |||
encoder_output_size: int, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
self_attention_dropout_rate: float = 0.0, | |||
src_attention_dropout_rate: float = 0.0, | |||
input_layer: str = 'embed', | |||
use_output_layer: bool = True, | |||
pos_enc_class=PositionalEncoding, | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
): | |||
assert check_argument_types() | |||
super().__init__( | |||
vocab_size=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
dropout_rate=dropout_rate, | |||
positional_dropout_rate=positional_dropout_rate, | |||
input_layer=input_layer, | |||
use_output_layer=use_output_layer, | |||
pos_enc_class=pos_enc_class, | |||
normalize_before=normalize_before, | |||
) | |||
attention_dim = encoder_output_size | |||
self.decoders = repeat( | |||
num_blocks, | |||
lambda lnum: DecoderLayer( | |||
attention_dim, | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
self_attention_dropout_rate), | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
src_attention_dropout_rate), | |||
PositionwiseFeedForward(attention_dim, linear_units, | |||
dropout_rate), | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
), | |||
) | |||
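# Shape sketch for the decoder interface defined above; all dimensions are arbitrary
# illustrative values (batch of 2, 50 encoder frames, 10 target tokens).
import torch

decoder = TransformerDecoder(vocab_size=100, encoder_output_size=256)
hs_pad = torch.randn(2, 50, 256)            # encoder memory (B, T_in, D)
hlens = torch.tensor([50, 40])              # valid encoder lengths
ys_in_pad = torch.randint(0, 100, (2, 10))  # int64 token ids (B, T_out)
ys_in_lens = torch.tensor([10, 8])
scores, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
print(scores.shape)                         # torch.Size([2, 10, 100]): pre-softmax token scores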
class ParaformerDecoder(TransformerDecoder): | |||
def __init__( | |||
self, | |||
vocab_size: int, | |||
encoder_output_size: int, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
self_attention_dropout_rate: float = 0.0, | |||
src_attention_dropout_rate: float = 0.0, | |||
input_layer: str = 'embed', | |||
use_output_layer: bool = True, | |||
pos_enc_class=PositionalEncoding, | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
): | |||
assert check_argument_types() | |||
super().__init__( | |||
vocab_size=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
dropout_rate=dropout_rate, | |||
positional_dropout_rate=positional_dropout_rate, | |||
input_layer=input_layer, | |||
use_output_layer=use_output_layer, | |||
pos_enc_class=pos_enc_class, | |||
normalize_before=normalize_before, | |||
) | |||
attention_dim = encoder_output_size | |||
self.decoders = repeat( | |||
num_blocks, | |||
lambda lnum: DecoderLayer( | |||
attention_dim, | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
self_attention_dropout_rate), | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
src_attention_dropout_rate), | |||
PositionwiseFeedForward(attention_dim, linear_units, | |||
dropout_rate), | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
), | |||
) | |||
def forward( | |||
self, | |||
hs_pad: torch.Tensor, | |||
hlens: torch.Tensor, | |||
ys_in_pad: torch.Tensor, | |||
ys_in_lens: torch.Tensor, | |||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||
"""Forward decoder. | |||
Args: | |||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat) | |||
hlens: (batch) | |||
ys_in_pad: | |||
input token ids, int64 (batch, maxlen_out) | |||
if input_layer == "embed" | |||
input tensor (batch, maxlen_out, #mels) in the other cases | |||
ys_in_lens: (batch) | |||
Returns: | |||
(tuple): tuple containing: | |||
x: decoded token score before softmax (batch, maxlen_out, token) | |||
if use_output_layer is True, | |||
olens: (batch, ) | |||
""" | |||
tgt = ys_in_pad | |||
# tgt_mask: (B, 1, L) | |||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) | |||
# m: (1, L, L) | |||
# m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) | |||
# tgt_mask: (B, L, L) | |||
# tgt_mask = tgt_mask & m | |||
memory = hs_pad | |||
memory_mask = ( | |||
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( | |||
memory.device) | |||
# Padding for Longformer | |||
if memory_mask.shape[-1] != memory.shape[1]: | |||
padlen = memory.shape[1] - memory_mask.shape[-1] | |||
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), | |||
'constant', False) | |||
# x = self.embed(tgt) | |||
x = tgt | |||
x, tgt_mask, memory, memory_mask = self.decoders( | |||
x, tgt_mask, memory, memory_mask) | |||
if self.normalize_before: | |||
x = self.after_norm(x) | |||
if self.output_layer is not None: | |||
x = self.output_layer(x) | |||
olens = tgt_mask.sum(1) | |||
return x, olens | |||
class ParaformerDecoderBertEmbed(TransformerDecoder): | |||
def __init__( | |||
self, | |||
vocab_size: int, | |||
encoder_output_size: int, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
self_attention_dropout_rate: float = 0.0, | |||
src_attention_dropout_rate: float = 0.0, | |||
input_layer: str = 'embed', | |||
use_output_layer: bool = True, | |||
pos_enc_class=PositionalEncoding, | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
embeds_id: int = 2, | |||
): | |||
assert check_argument_types() | |||
super().__init__( | |||
vocab_size=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
dropout_rate=dropout_rate, | |||
positional_dropout_rate=positional_dropout_rate, | |||
input_layer=input_layer, | |||
use_output_layer=use_output_layer, | |||
pos_enc_class=pos_enc_class, | |||
normalize_before=normalize_before, | |||
) | |||
attention_dim = encoder_output_size | |||
self.decoders = repeat( | |||
embeds_id, | |||
lambda lnum: DecoderLayer( | |||
attention_dim, | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
self_attention_dropout_rate), | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
src_attention_dropout_rate), | |||
PositionwiseFeedForward(attention_dim, linear_units, | |||
dropout_rate), | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
), | |||
) | |||
if embeds_id == num_blocks: | |||
self.decoders2 = None | |||
else: | |||
self.decoders2 = repeat( | |||
num_blocks - embeds_id, | |||
lambda lnum: DecoderLayer( | |||
attention_dim, | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
self_attention_dropout_rate), | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
src_attention_dropout_rate), | |||
PositionwiseFeedForward(attention_dim, linear_units, | |||
dropout_rate), | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
), | |||
) | |||
def forward( | |||
self, | |||
hs_pad: torch.Tensor, | |||
hlens: torch.Tensor, | |||
ys_in_pad: torch.Tensor, | |||
ys_in_lens: torch.Tensor, | |||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||
"""Forward decoder. | |||
Args: | |||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat) | |||
hlens: (batch) | |||
ys_in_pad: | |||
input token ids, int64 (batch, maxlen_out) | |||
if input_layer == "embed" | |||
input tensor (batch, maxlen_out, #mels) in the other cases | |||
ys_in_lens: (batch) | |||
Returns: | |||
(tuple): tuple containing: | |||
x: decoded token score before softmax (batch, maxlen_out, token) | |||
if use_output_layer is True, | |||
olens: (batch, ) | |||
""" | |||
tgt = ys_in_pad | |||
# tgt_mask: (B, 1, L) | |||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) | |||
# m: (1, L, L) | |||
# m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) | |||
# tgt_mask: (B, L, L) | |||
# tgt_mask = tgt_mask & m | |||
memory = hs_pad | |||
memory_mask = ( | |||
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( | |||
memory.device) | |||
# Padding for Longformer | |||
if memory_mask.shape[-1] != memory.shape[1]: | |||
padlen = memory.shape[1] - memory_mask.shape[-1] | |||
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), | |||
'constant', False) | |||
# x = self.embed(tgt) | |||
x = tgt | |||
x, tgt_mask, memory, memory_mask = self.decoders( | |||
x, tgt_mask, memory, memory_mask) | |||
embeds_outputs = x | |||
if self.decoders2 is not None: | |||
x, tgt_mask, memory, memory_mask = self.decoders2( | |||
x, tgt_mask, memory, memory_mask) | |||
if self.normalize_before: | |||
x = self.after_norm(x) | |||
if self.output_layer is not None: | |||
x = self.output_layer(x) | |||
olens = tgt_mask.sum(1) | |||
return x, olens, embeds_outputs | |||
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): | |||
def __init__( | |||
self, | |||
vocab_size: int, | |||
encoder_output_size: int, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
self_attention_dropout_rate: float = 0.0, | |||
src_attention_dropout_rate: float = 0.0, | |||
input_layer: str = 'embed', | |||
use_output_layer: bool = True, | |||
pos_enc_class=PositionalEncoding, | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
conv_wshare: int = 4, | |||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||
conv_usebias: int = False, | |||
): | |||
assert check_argument_types() | |||
if len(conv_kernel_length) != num_blocks: | |||
raise ValueError( | |||
'conv_kernel_length must have equal number of values to num_blocks: ' | |||
f'{len(conv_kernel_length)} != {num_blocks}') | |||
super().__init__( | |||
vocab_size=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
dropout_rate=dropout_rate, | |||
positional_dropout_rate=positional_dropout_rate, | |||
input_layer=input_layer, | |||
use_output_layer=use_output_layer, | |||
pos_enc_class=pos_enc_class, | |||
normalize_before=normalize_before, | |||
) | |||
attention_dim = encoder_output_size | |||
self.decoders = repeat( | |||
num_blocks, | |||
lambda lnum: DecoderLayer( | |||
attention_dim, | |||
LightweightConvolution( | |||
wshare=conv_wshare, | |||
n_feat=attention_dim, | |||
dropout_rate=self_attention_dropout_rate, | |||
kernel_size=conv_kernel_length[lnum], | |||
use_kernel_mask=True, | |||
use_bias=conv_usebias, | |||
), | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
src_attention_dropout_rate), | |||
PositionwiseFeedForward(attention_dim, linear_units, | |||
dropout_rate), | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
), | |||
) | |||
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): | |||
def __init__( | |||
self, | |||
vocab_size: int, | |||
encoder_output_size: int, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
self_attention_dropout_rate: float = 0.0, | |||
src_attention_dropout_rate: float = 0.0, | |||
input_layer: str = 'embed', | |||
use_output_layer: bool = True, | |||
pos_enc_class=PositionalEncoding, | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
conv_wshare: int = 4, | |||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||
conv_usebias: int = False, | |||
): | |||
assert check_argument_types() | |||
if len(conv_kernel_length) != num_blocks: | |||
raise ValueError( | |||
'conv_kernel_length must have equal number of values to num_blocks: ' | |||
f'{len(conv_kernel_length)} != {num_blocks}') | |||
super().__init__( | |||
vocab_size=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
dropout_rate=dropout_rate, | |||
positional_dropout_rate=positional_dropout_rate, | |||
input_layer=input_layer, | |||
use_output_layer=use_output_layer, | |||
pos_enc_class=pos_enc_class, | |||
normalize_before=normalize_before, | |||
) | |||
attention_dim = encoder_output_size | |||
self.decoders = repeat( | |||
num_blocks, | |||
lambda lnum: DecoderLayer( | |||
attention_dim, | |||
LightweightConvolution2D( | |||
wshare=conv_wshare, | |||
n_feat=attention_dim, | |||
dropout_rate=self_attention_dropout_rate, | |||
kernel_size=conv_kernel_length[lnum], | |||
use_kernel_mask=True, | |||
use_bias=conv_usebias, | |||
), | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
src_attention_dropout_rate), | |||
PositionwiseFeedForward(attention_dim, linear_units, | |||
dropout_rate), | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
), | |||
) | |||
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): | |||
def __init__( | |||
self, | |||
vocab_size: int, | |||
encoder_output_size: int, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
self_attention_dropout_rate: float = 0.0, | |||
src_attention_dropout_rate: float = 0.0, | |||
input_layer: str = 'embed', | |||
use_output_layer: bool = True, | |||
pos_enc_class=PositionalEncoding, | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
conv_wshare: int = 4, | |||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||
conv_usebias: int = False, | |||
): | |||
assert check_argument_types() | |||
if len(conv_kernel_length) != num_blocks: | |||
raise ValueError( | |||
'conv_kernel_length must have equal number of values to num_blocks: ' | |||
f'{len(conv_kernel_length)} != {num_blocks}') | |||
super().__init__( | |||
vocab_size=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
dropout_rate=dropout_rate, | |||
positional_dropout_rate=positional_dropout_rate, | |||
input_layer=input_layer, | |||
use_output_layer=use_output_layer, | |||
pos_enc_class=pos_enc_class, | |||
normalize_before=normalize_before, | |||
) | |||
attention_dim = encoder_output_size | |||
self.decoders = repeat( | |||
num_blocks, | |||
lambda lnum: DecoderLayer( | |||
attention_dim, | |||
DynamicConvolution( | |||
wshare=conv_wshare, | |||
n_feat=attention_dim, | |||
dropout_rate=self_attention_dropout_rate, | |||
kernel_size=conv_kernel_length[lnum], | |||
use_kernel_mask=True, | |||
use_bias=conv_usebias, | |||
), | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
src_attention_dropout_rate), | |||
PositionwiseFeedForward(attention_dim, linear_units, | |||
dropout_rate), | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
), | |||
) | |||
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): | |||
def __init__( | |||
self, | |||
vocab_size: int, | |||
encoder_output_size: int, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
self_attention_dropout_rate: float = 0.0, | |||
src_attention_dropout_rate: float = 0.0, | |||
input_layer: str = 'embed', | |||
use_output_layer: bool = True, | |||
pos_enc_class=PositionalEncoding, | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
conv_wshare: int = 4, | |||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||
conv_usebias: int = False, | |||
): | |||
assert check_argument_types() | |||
if len(conv_kernel_length) != num_blocks: | |||
raise ValueError( | |||
'conv_kernel_length must have equal number of values to num_blocks: ' | |||
f'{len(conv_kernel_length)} != {num_blocks}') | |||
super().__init__( | |||
vocab_size=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
dropout_rate=dropout_rate, | |||
positional_dropout_rate=positional_dropout_rate, | |||
input_layer=input_layer, | |||
use_output_layer=use_output_layer, | |||
pos_enc_class=pos_enc_class, | |||
normalize_before=normalize_before, | |||
) | |||
attention_dim = encoder_output_size | |||
self.decoders = repeat( | |||
num_blocks, | |||
lambda lnum: DecoderLayer( | |||
attention_dim, | |||
DynamicConvolution2D( | |||
wshare=conv_wshare, | |||
n_feat=attention_dim, | |||
dropout_rate=self_attention_dropout_rate, | |||
kernel_size=conv_kernel_length[lnum], | |||
use_kernel_mask=True, | |||
use_bias=conv_usebias, | |||
), | |||
MultiHeadedAttention(attention_heads, attention_dim, | |||
src_attention_dropout_rate), | |||
PositionwiseFeedForward(attention_dim, linear_units, | |||
dropout_rate), | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
), | |||
) |
@@ -1,710 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed from espnet/espnet. | |||
"""Conformer encoder definition.""" | |||
import logging | |||
from typing import List, Optional, Tuple, Union | |||
import torch | |||
from espnet2.asr.ctc import CTC | |||
from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||
from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule | |||
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer | |||
from espnet.nets.pytorch_backend.nets_utils import (get_activation, | |||
make_pad_mask) | |||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
LegacyRelPositionalEncoding # noqa: H301 | |||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
PositionalEncoding # noqa: H301 | |||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
RelPositionalEncoding # noqa: H301 | |||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
ScaledPositionalEncoding # noqa: H301 | |||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( | |||
Conv1dLinear, MultiLayeredConv1d) | |||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||
PositionwiseFeedForward # noqa: H301 | |||
from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||
from espnet.nets.pytorch_backend.transformer.subsampling import ( | |||
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, | |||
Conv2dSubsampling8, TooShortUttError, check_short_utt) | |||
from typeguard import check_argument_types | |||
from ...nets.pytorch_backend.transformer.attention import \ | |||
LegacyRelPositionMultiHeadedAttention # noqa: H301 | |||
from ...nets.pytorch_backend.transformer.attention import \ | |||
MultiHeadedAttention # noqa: H301 | |||
from ...nets.pytorch_backend.transformer.attention import \ | |||
RelPositionMultiHeadedAttention # noqa: H301 | |||
from ...nets.pytorch_backend.transformer.attention import ( | |||
LegacyRelPositionMultiHeadedAttentionSANM, | |||
RelPositionMultiHeadedAttentionSANM) | |||
class ConformerEncoder(AbsEncoder): | |||
"""Conformer encoder module. | |||
Args: | |||
input_size (int): Input dimension. | |||
output_size (int): Dimension of attention. | |||
attention_heads (int): The number of heads of multi head attention. | |||
linear_units (int): The number of units of position-wise feed forward. | |||
num_blocks (int): The number of decoder blocks. | |||
dropout_rate (float): Dropout rate. | |||
attention_dropout_rate (float): Dropout rate in attention. | |||
positional_dropout_rate (float): Dropout rate after adding positional encoding. | |||
input_layer (Union[str, torch.nn.Module]): Input layer type. | |||
normalize_before (bool): Whether to use layer_norm before the first block. | |||
concat_after (bool): Whether to concat attention layer's input and output. | |||
If True, additional linear will be applied. | |||
i.e. x -> x + linear(concat(x, att(x))) | |||
If False, no additional linear will be applied. i.e. x -> x + att(x) | |||
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". | |||
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. | |||
rel_pos_type (str): Whether to use the latest relative positional encoding or | |||
the legacy one. The legacy relative positional encoding will be deprecated | |||
in the future. More details can be found in | |||
https://github.com/espnet/espnet/pull/2816. | |||
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type. | |||
encoder_attn_layer_type (str): Encoder attention layer type. | |||
activation_type (str): Encoder activation function type. | |||
macaron_style (bool): Whether to use macaron style for positionwise layer. | |||
use_cnn_module (bool): Whether to use convolution module. | |||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
cnn_module_kernel (int): Kernel size of convolution module. | |||
padding_idx (int): Padding idx for input_layer=embed. | |||
""" | |||
def __init__( | |||
self, | |||
input_size: int, | |||
output_size: int = 256, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
attention_dropout_rate: float = 0.0, | |||
input_layer: str = 'conv2d', | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
positionwise_layer_type: str = 'linear', | |||
positionwise_conv_kernel_size: int = 3, | |||
macaron_style: bool = False, | |||
rel_pos_type: str = 'legacy', | |||
pos_enc_layer_type: str = 'rel_pos', | |||
selfattention_layer_type: str = 'rel_selfattn', | |||
activation_type: str = 'swish', | |||
use_cnn_module: bool = True, | |||
zero_triu: bool = False, | |||
cnn_module_kernel: int = 31, | |||
padding_idx: int = -1, | |||
interctc_layer_idx: List[int] = [], | |||
interctc_use_conditioning: bool = False, | |||
stochastic_depth_rate: Union[float, List[float]] = 0.0, | |||
): | |||
assert check_argument_types() | |||
super().__init__() | |||
self._output_size = output_size | |||
if rel_pos_type == 'legacy': | |||
if pos_enc_layer_type == 'rel_pos': | |||
pos_enc_layer_type = 'legacy_rel_pos' | |||
if selfattention_layer_type == 'rel_selfattn': | |||
selfattention_layer_type = 'legacy_rel_selfattn' | |||
elif rel_pos_type == 'latest': | |||
assert selfattention_layer_type != 'legacy_rel_selfattn' | |||
assert pos_enc_layer_type != 'legacy_rel_pos' | |||
else: | |||
raise ValueError('unknown rel_pos_type: ' + rel_pos_type) | |||
activation = get_activation(activation_type) | |||
if pos_enc_layer_type == 'abs_pos': | |||
pos_enc_class = PositionalEncoding | |||
elif pos_enc_layer_type == 'scaled_abs_pos': | |||
pos_enc_class = ScaledPositionalEncoding | |||
elif pos_enc_layer_type == 'rel_pos': | |||
assert selfattention_layer_type == 'rel_selfattn' | |||
pos_enc_class = RelPositionalEncoding | |||
elif pos_enc_layer_type == 'legacy_rel_pos': | |||
assert selfattention_layer_type == 'legacy_rel_selfattn' | |||
pos_enc_class = LegacyRelPositionalEncoding | |||
else: | |||
raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type) | |||
if input_layer == 'linear': | |||
self.embed = torch.nn.Sequential( | |||
torch.nn.Linear(input_size, output_size), | |||
torch.nn.LayerNorm(output_size), | |||
torch.nn.Dropout(dropout_rate), | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'conv2d': | |||
self.embed = Conv2dSubsampling( | |||
input_size, | |||
output_size, | |||
dropout_rate, | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'conv2d2': | |||
self.embed = Conv2dSubsampling2( | |||
input_size, | |||
output_size, | |||
dropout_rate, | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'conv2d6': | |||
self.embed = Conv2dSubsampling6( | |||
input_size, | |||
output_size, | |||
dropout_rate, | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'conv2d8': | |||
self.embed = Conv2dSubsampling8( | |||
input_size, | |||
output_size, | |||
dropout_rate, | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'embed': | |||
self.embed = torch.nn.Sequential( | |||
torch.nn.Embedding( | |||
input_size, output_size, padding_idx=padding_idx), | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif isinstance(input_layer, torch.nn.Module): | |||
self.embed = torch.nn.Sequential( | |||
input_layer, | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer is None: | |||
self.embed = torch.nn.Sequential( | |||
pos_enc_class(output_size, positional_dropout_rate)) | |||
else: | |||
raise ValueError('unknown input_layer: ' + input_layer) | |||
self.normalize_before = normalize_before | |||
if positionwise_layer_type == 'linear': | |||
positionwise_layer = PositionwiseFeedForward | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
dropout_rate, | |||
activation, | |||
) | |||
elif positionwise_layer_type == 'conv1d': | |||
positionwise_layer = MultiLayeredConv1d | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
positionwise_conv_kernel_size, | |||
dropout_rate, | |||
) | |||
elif positionwise_layer_type == 'conv1d-linear': | |||
positionwise_layer = Conv1dLinear | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
positionwise_conv_kernel_size, | |||
dropout_rate, | |||
) | |||
else: | |||
raise NotImplementedError('Support only linear or conv1d.') | |||
if selfattention_layer_type == 'selfattn': | |||
encoder_selfattn_layer = MultiHeadedAttention | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
) | |||
elif selfattention_layer_type == 'legacy_rel_selfattn': | |||
assert pos_enc_layer_type == 'legacy_rel_pos' | |||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
) | |||
elif selfattention_layer_type == 'rel_selfattn': | |||
assert pos_enc_layer_type == 'rel_pos' | |||
encoder_selfattn_layer = RelPositionMultiHeadedAttention | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
zero_triu, | |||
) | |||
else: | |||
raise ValueError('unknown encoder_attn_layer: ' | |||
+ selfattention_layer_type) | |||
convolution_layer = ConvolutionModule | |||
convolution_layer_args = (output_size, cnn_module_kernel, activation) | |||
if isinstance(stochastic_depth_rate, float): | |||
stochastic_depth_rate = [stochastic_depth_rate] * num_blocks | |||
if len(stochastic_depth_rate) != num_blocks: | |||
raise ValueError( | |||
f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) ' | |||
f'should be equal to num_blocks ({num_blocks})') | |||
self.encoders = repeat( | |||
num_blocks, | |||
lambda lnum: EncoderLayer( | |||
output_size, | |||
encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||
positionwise_layer(*positionwise_layer_args), | |||
positionwise_layer(*positionwise_layer_args) | |||
if macaron_style else None, | |||
convolution_layer(*convolution_layer_args) | |||
if use_cnn_module else None, | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
stochastic_depth_rate[lnum], | |||
), | |||
) | |||
if self.normalize_before: | |||
self.after_norm = LayerNorm(output_size) | |||
self.interctc_layer_idx = interctc_layer_idx | |||
if len(interctc_layer_idx) > 0: | |||
assert 0 < min(interctc_layer_idx) and max( | |||
interctc_layer_idx) < num_blocks | |||
self.interctc_use_conditioning = interctc_use_conditioning | |||
self.conditioning_layer = None | |||
def output_size(self) -> int: | |||
return self._output_size | |||
def forward( | |||
self, | |||
xs_pad: torch.Tensor, | |||
ilens: torch.Tensor, | |||
prev_states: torch.Tensor = None, | |||
ctc: CTC = None, | |||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||
"""Calculate forward propagation. | |||
Args: | |||
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). | |||
ilens (torch.Tensor): Input length (#batch). | |||
prev_states (torch.Tensor): Not to be used now. | |||
Returns: | |||
torch.Tensor: Output tensor (#batch, L, output_size). | |||
torch.Tensor: Output length (#batch). | |||
torch.Tensor: Not to be used now. | |||
""" | |||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||
if (isinstance(self.embed, Conv2dSubsampling) | |||
or isinstance(self.embed, Conv2dSubsampling2) | |||
or isinstance(self.embed, Conv2dSubsampling6) | |||
or isinstance(self.embed, Conv2dSubsampling8)): | |||
short_status, limit_size = check_short_utt(self.embed, | |||
xs_pad.size(1)) | |||
if short_status: | |||
raise TooShortUttError( | |||
f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||
+ # noqa: * | |||
f'(it needs more than {limit_size} frames), return empty results', # noqa: * | |||
xs_pad.size(1), | |||
limit_size) # noqa: * | |||
xs_pad, masks = self.embed(xs_pad, masks) | |||
else: | |||
xs_pad = self.embed(xs_pad) | |||
intermediate_outs = [] | |||
if len(self.interctc_layer_idx) == 0: | |||
xs_pad, masks = self.encoders(xs_pad, masks) | |||
else: | |||
for layer_idx, encoder_layer in enumerate(self.encoders): | |||
xs_pad, masks = encoder_layer(xs_pad, masks) | |||
if layer_idx + 1 in self.interctc_layer_idx: | |||
encoder_out = xs_pad | |||
if isinstance(encoder_out, tuple): | |||
encoder_out = encoder_out[0] | |||
# intermediate outputs are also normalized | |||
if self.normalize_before: | |||
encoder_out = self.after_norm(encoder_out) | |||
intermediate_outs.append((layer_idx + 1, encoder_out)) | |||
if self.interctc_use_conditioning: | |||
ctc_out = ctc.softmax(encoder_out) | |||
if isinstance(xs_pad, tuple): | |||
x, pos_emb = xs_pad | |||
x = x + self.conditioning_layer(ctc_out) | |||
xs_pad = (x, pos_emb) | |||
else: | |||
xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||
if isinstance(xs_pad, tuple): | |||
xs_pad = xs_pad[0] | |||
if self.normalize_before: | |||
xs_pad = self.after_norm(xs_pad) | |||
olens = masks.squeeze(1).sum(1) | |||
if len(intermediate_outs) > 0: | |||
return (xs_pad, intermediate_outs), olens, None | |||
return xs_pad, olens, None | |||
class SANMEncoder_v2(AbsEncoder): | |||
"""Conformer encoder module. | |||
Args: | |||
input_size (int): Input dimension. | |||
output_size (int): Dimension of attention. | |||
attention_heads (int): The number of heads of multi head attention. | |||
linear_units (int): The number of units of position-wise feed forward. | |||
num_blocks (int): The number of encoder blocks. | |||
dropout_rate (float): Dropout rate. | |||
attention_dropout_rate (float): Dropout rate in attention. | |||
positional_dropout_rate (float): Dropout rate after adding positional encoding. | |||
input_layer (Union[str, torch.nn.Module]): Input layer type. | |||
normalize_before (bool): Whether to use layer_norm before the first block. | |||
concat_after (bool): Whether to concat attention layer's input and output. | |||
If True, additional linear will be applied. | |||
i.e. x -> x + linear(concat(x, att(x))) | |||
If False, no additional linear will be applied. i.e. x -> x + att(x) | |||
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". | |||
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. | |||
rel_pos_type (str): Whether to use the latest relative positional encoding or | |||
the legacy one. The legacy relative positional encoding will be deprecated | |||
in the future. More details can be found in | |||
https://github.com/espnet/espnet/pull/2816. | |||
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type. | |||
encoder_attn_layer_type (str): Encoder attention layer type. | |||
activation_type (str): Encoder activation function type. | |||
macaron_style (bool): Whether to use macaron style for positionwise layer. | |||
use_cnn_module (bool): Whether to use convolution module. | |||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
cnn_module_kernel (int): Kernel size of convolution module. | |||
padding_idx (int): Padding idx for input_layer=embed. | |||
""" | |||
def __init__( | |||
self, | |||
input_size: int, | |||
output_size: int = 256, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
attention_dropout_rate: float = 0.0, | |||
input_layer: str = 'conv2d', | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
positionwise_layer_type: str = 'linear', | |||
positionwise_conv_kernel_size: int = 3, | |||
macaron_style: bool = False, | |||
rel_pos_type: str = 'legacy', | |||
pos_enc_layer_type: str = 'rel_pos', | |||
selfattention_layer_type: str = 'rel_selfattn', | |||
activation_type: str = 'swish', | |||
use_cnn_module: bool = False, | |||
sanm_shfit: int = 0, | |||
zero_triu: bool = False, | |||
cnn_module_kernel: int = 31, | |||
padding_idx: int = -1, | |||
interctc_layer_idx: List[int] = [], | |||
interctc_use_conditioning: bool = False, | |||
stochastic_depth_rate: Union[float, List[float]] = 0.0, | |||
): | |||
assert check_argument_types() | |||
super().__init__() | |||
self._output_size = output_size | |||
if rel_pos_type == 'legacy': | |||
if pos_enc_layer_type == 'rel_pos': | |||
pos_enc_layer_type = 'legacy_rel_pos' | |||
if selfattention_layer_type == 'rel_selfattn': | |||
selfattention_layer_type = 'legacy_rel_selfattn' | |||
if selfattention_layer_type == 'rel_selfattnsanm': | |||
selfattention_layer_type = 'legacy_rel_selfattnsanm' | |||
elif rel_pos_type == 'latest': | |||
assert selfattention_layer_type != 'legacy_rel_selfattn' | |||
assert pos_enc_layer_type != 'legacy_rel_pos' | |||
else: | |||
raise ValueError('unknown rel_pos_type: ' + rel_pos_type) | |||
activation = get_activation(activation_type) | |||
if pos_enc_layer_type == 'abs_pos': | |||
pos_enc_class = PositionalEncoding | |||
elif pos_enc_layer_type == 'scaled_abs_pos': | |||
pos_enc_class = ScaledPositionalEncoding | |||
elif pos_enc_layer_type == 'rel_pos': | |||
# assert selfattention_layer_type == "rel_selfattn" | |||
pos_enc_class = RelPositionalEncoding | |||
elif pos_enc_layer_type == 'legacy_rel_pos': | |||
# assert selfattention_layer_type == "legacy_rel_selfattn" | |||
pos_enc_class = LegacyRelPositionalEncoding | |||
logging.warning( | |||
'Using legacy_rel_pos and it will be deprecated in the future.' | |||
) | |||
else: | |||
raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type) | |||
if input_layer == 'linear': | |||
self.embed = torch.nn.Sequential( | |||
torch.nn.Linear(input_size, output_size), | |||
torch.nn.LayerNorm(output_size), | |||
torch.nn.Dropout(dropout_rate), | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'conv2d': | |||
self.embed = Conv2dSubsampling( | |||
input_size, | |||
output_size, | |||
dropout_rate, | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'conv2d2': | |||
self.embed = Conv2dSubsampling2( | |||
input_size, | |||
output_size, | |||
dropout_rate, | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'conv2d6': | |||
self.embed = Conv2dSubsampling6( | |||
input_size, | |||
output_size, | |||
dropout_rate, | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'conv2d8': | |||
self.embed = Conv2dSubsampling8( | |||
input_size, | |||
output_size, | |||
dropout_rate, | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'embed': | |||
self.embed = torch.nn.Sequential( | |||
torch.nn.Embedding( | |||
input_size, output_size, padding_idx=padding_idx), | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif isinstance(input_layer, torch.nn.Module): | |||
self.embed = torch.nn.Sequential( | |||
input_layer, | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer is None: | |||
self.embed = torch.nn.Sequential( | |||
pos_enc_class(output_size, positional_dropout_rate)) | |||
else: | |||
raise ValueError('unknown input_layer: ' + input_layer) | |||
self.normalize_before = normalize_before | |||
if positionwise_layer_type == 'linear': | |||
positionwise_layer = PositionwiseFeedForward | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
dropout_rate, | |||
activation, | |||
) | |||
elif positionwise_layer_type == 'conv1d': | |||
positionwise_layer = MultiLayeredConv1d | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
positionwise_conv_kernel_size, | |||
dropout_rate, | |||
) | |||
elif positionwise_layer_type == 'conv1d-linear': | |||
positionwise_layer = Conv1dLinear | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
positionwise_conv_kernel_size, | |||
dropout_rate, | |||
) | |||
else: | |||
raise NotImplementedError('Support only linear or conv1d.') | |||
if selfattention_layer_type == 'selfattn': | |||
encoder_selfattn_layer = MultiHeadedAttention | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
) | |||
elif selfattention_layer_type == 'legacy_rel_selfattn': | |||
assert pos_enc_layer_type == 'legacy_rel_pos' | |||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
) | |||
logging.warning( | |||
'Using legacy_rel_selfattn and it will be deprecated in the future.' | |||
) | |||
elif selfattention_layer_type == 'legacy_rel_selfattnsanm': | |||
assert pos_enc_layer_type == 'legacy_rel_pos' | |||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttentionSANM | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
) | |||
logging.warning( | |||
'Using legacy_rel_selfattnsanm and it will be deprecated in the future.' | |||
) | |||
elif selfattention_layer_type == 'rel_selfattn': | |||
assert pos_enc_layer_type == 'rel_pos' | |||
encoder_selfattn_layer = RelPositionMultiHeadedAttention | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
zero_triu, | |||
) | |||
elif selfattention_layer_type == 'rel_selfattnsanm': | |||
assert pos_enc_layer_type == 'rel_pos' | |||
encoder_selfattn_layer = RelPositionMultiHeadedAttentionSANM | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
zero_triu, | |||
cnn_module_kernel, | |||
sanm_shfit, | |||
) | |||
else: | |||
raise ValueError('unknown encoder_attn_layer: ' | |||
+ selfattention_layer_type) | |||
convolution_layer = ConvolutionModule | |||
convolution_layer_args = (output_size, cnn_module_kernel, activation) | |||
if isinstance(stochastic_depth_rate, float): | |||
stochastic_depth_rate = [stochastic_depth_rate] * num_blocks | |||
if len(stochastic_depth_rate) != num_blocks: | |||
raise ValueError( | |||
f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) ' | |||
f'should be equal to num_blocks ({num_blocks})') | |||
self.encoders = repeat( | |||
num_blocks, | |||
lambda lnum: EncoderLayer( | |||
output_size, | |||
encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||
positionwise_layer(*positionwise_layer_args), | |||
positionwise_layer(*positionwise_layer_args) | |||
if macaron_style else None, | |||
convolution_layer(*convolution_layer_args) | |||
if use_cnn_module else None, | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
stochastic_depth_rate[lnum], | |||
), | |||
) | |||
if self.normalize_before: | |||
self.after_norm = LayerNorm(output_size) | |||
self.interctc_layer_idx = interctc_layer_idx | |||
if len(interctc_layer_idx) > 0: | |||
assert 0 < min(interctc_layer_idx) and max( | |||
interctc_layer_idx) < num_blocks | |||
self.interctc_use_conditioning = interctc_use_conditioning | |||
self.conditioning_layer = None | |||
def output_size(self) -> int: | |||
return self._output_size | |||
def forward( | |||
self, | |||
xs_pad: torch.Tensor, | |||
ilens: torch.Tensor, | |||
prev_states: torch.Tensor = None, | |||
ctc: CTC = None, | |||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||
"""Calculate forward propagation. | |||
Args: | |||
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). | |||
ilens (torch.Tensor): Input length (#batch). | |||
prev_states (torch.Tensor): Not to be used now. | |||
Returns: | |||
torch.Tensor: Output tensor (#batch, L, output_size). | |||
torch.Tensor: Output length (#batch). | |||
torch.Tensor: Not to be used now. | |||
""" | |||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||
if (isinstance(self.embed, Conv2dSubsampling) | |||
or isinstance(self.embed, Conv2dSubsampling2) | |||
or isinstance(self.embed, Conv2dSubsampling6) | |||
or isinstance(self.embed, Conv2dSubsampling8)): | |||
short_status, limit_size = check_short_utt(self.embed, | |||
xs_pad.size(1)) | |||
if short_status: | |||
raise TooShortUttError( | |||
f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||
+ # noqa: * | |||
f'(it needs more than {limit_size} frames), return empty results', | |||
xs_pad.size(1), | |||
limit_size) # noqa: * | |||
xs_pad, masks = self.embed(xs_pad, masks) | |||
else: | |||
xs_pad = self.embed(xs_pad) | |||
intermediate_outs = [] | |||
if len(self.interctc_layer_idx) == 0: | |||
xs_pad, masks = self.encoders(xs_pad, masks) | |||
else: | |||
for layer_idx, encoder_layer in enumerate(self.encoders): | |||
xs_pad, masks = encoder_layer(xs_pad, masks) | |||
if layer_idx + 1 in self.interctc_layer_idx: | |||
encoder_out = xs_pad | |||
if isinstance(encoder_out, tuple): | |||
encoder_out = encoder_out[0] | |||
# intermediate outputs are also normalized | |||
if self.normalize_before: | |||
encoder_out = self.after_norm(encoder_out) | |||
intermediate_outs.append((layer_idx + 1, encoder_out)) | |||
if self.interctc_use_conditioning: | |||
ctc_out = ctc.softmax(encoder_out) | |||
if isinstance(xs_pad, tuple): | |||
x, pos_emb = xs_pad | |||
x = x + self.conditioning_layer(ctc_out) | |||
xs_pad = (x, pos_emb) | |||
else: | |||
xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||
if isinstance(xs_pad, tuple): | |||
xs_pad = xs_pad[0] | |||
if self.normalize_before: | |||
xs_pad = self.after_norm(xs_pad) | |||
olens = masks.squeeze(1).sum(1) | |||
if len(intermediate_outs) > 0: | |||
return (xs_pad, intermediate_outs), olens, None | |||
return xs_pad, olens, None |
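A minimal sketch of how an encoder like the ConformerEncoder defined above can be exercised, assuming espnet and typeguard are installed and the class is importable from this package; the feature dimension, lengths, and layer sizes are illustrative assumptions, not values taken from the repository's configurations:

import torch

# Illustrative instantiation; the defaults select legacy relative positional
# encoding and a convolution module in every block.
encoder = ConformerEncoder(
    input_size=80,           # e.g. 80-dim log-mel filterbank features
    output_size=256,
    attention_heads=4,
    num_blocks=6,
    input_layer='conv2d',    # Conv2dSubsampling: roughly 4x reduction along the time axis
)
xs_pad = torch.randn(2, 200, 80)        # (batch, frames, feature_dim), zero-padded
ilens = torch.tensor([200, 160])        # valid frame counts per utterance
out, olens, _ = encoder(xs_pad, ilens)  # out: (2, T', 256); olens: subsampled lengths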
@@ -1,500 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed from espnet/espnet. | |||
"""Transformer encoder definition.""" | |||
import logging | |||
from typing import List, Optional, Sequence, Tuple, Union | |||
import torch | |||
from espnet2.asr.ctc import CTC | |||
from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
PositionalEncoding | |||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( | |||
Conv1dLinear, MultiLayeredConv1d) | |||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||
PositionwiseFeedForward # noqa: H301 | |||
from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||
from espnet.nets.pytorch_backend.transformer.subsampling import ( | |||
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, | |||
Conv2dSubsampling8, TooShortUttError, check_short_utt) | |||
from typeguard import check_argument_types | |||
from ...asr.streaming_utilis.chunk_utilis import overlap_chunk | |||
from ...nets.pytorch_backend.transformer.attention import ( | |||
MultiHeadedAttention, MultiHeadedAttentionSANM) | |||
from ...nets.pytorch_backend.transformer.encoder_layer import ( | |||
EncoderLayer, EncoderLayerChunk) | |||
class SANMEncoder(AbsEncoder): | |||
"""Transformer encoder module. | |||
Args: | |||
input_size: input dim | |||
output_size: dimension of attention | |||
attention_heads: the number of heads of multi head attention | |||
linear_units: the number of units of position-wise feed forward | |||
num_blocks: the number of encoder blocks | |||
dropout_rate: dropout rate | |||
attention_dropout_rate: dropout rate in attention | |||
positional_dropout_rate: dropout rate after adding positional encoding | |||
input_layer: input layer type | |||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||
normalize_before: whether to use layer_norm before the first block | |||
concat_after: whether to concat attention layer's input and output | |||
if True, additional linear will be applied. | |||
i.e. x -> x + linear(concat(x, att(x))) | |||
if False, no additional linear will be applied. | |||
i.e. x -> x + att(x) | |||
positionwise_layer_type: linear, conv1d, or conv1d-linear | |||
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer | |||
padding_idx: padding_idx for input_layer=embed | |||
""" | |||
def __init__( | |||
self, | |||
input_size: int, | |||
output_size: int = 256, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
attention_dropout_rate: float = 0.0, | |||
input_layer: Optional[str] = 'conv2d', | |||
pos_enc_class=PositionalEncoding, | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
positionwise_layer_type: str = 'linear', | |||
positionwise_conv_kernel_size: int = 1, | |||
padding_idx: int = -1, | |||
interctc_layer_idx: List[int] = [], | |||
interctc_use_conditioning: bool = False, | |||
kernel_size: int = 11, | |||
sanm_shfit: int = 0, | |||
selfattention_layer_type: str = 'sanm', | |||
): | |||
assert check_argument_types() | |||
super().__init__() | |||
self._output_size = output_size | |||
if input_layer == 'linear': | |||
self.embed = torch.nn.Sequential( | |||
torch.nn.Linear(input_size, output_size), | |||
torch.nn.LayerNorm(output_size), | |||
torch.nn.Dropout(dropout_rate), | |||
torch.nn.ReLU(), | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'conv2d': | |||
self.embed = Conv2dSubsampling(input_size, output_size, | |||
dropout_rate) | |||
elif input_layer == 'conv2d2': | |||
self.embed = Conv2dSubsampling2(input_size, output_size, | |||
dropout_rate) | |||
elif input_layer == 'conv2d6': | |||
self.embed = Conv2dSubsampling6(input_size, output_size, | |||
dropout_rate) | |||
elif input_layer == 'conv2d8': | |||
self.embed = Conv2dSubsampling8(input_size, output_size, | |||
dropout_rate) | |||
elif input_layer == 'embed': | |||
self.embed = torch.nn.Sequential( | |||
torch.nn.Embedding( | |||
input_size, output_size, padding_idx=padding_idx), | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer is None: | |||
if input_size == output_size: | |||
self.embed = None | |||
else: | |||
self.embed = torch.nn.Linear(input_size, output_size) | |||
else: | |||
raise ValueError('unknown input_layer: ' + input_layer) | |||
self.normalize_before = normalize_before | |||
if positionwise_layer_type == 'linear': | |||
positionwise_layer = PositionwiseFeedForward | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
dropout_rate, | |||
) | |||
elif positionwise_layer_type == 'conv1d': | |||
positionwise_layer = MultiLayeredConv1d | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
positionwise_conv_kernel_size, | |||
dropout_rate, | |||
) | |||
elif positionwise_layer_type == 'conv1d-linear': | |||
positionwise_layer = Conv1dLinear | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
positionwise_conv_kernel_size, | |||
dropout_rate, | |||
) | |||
else: | |||
raise NotImplementedError('Support only linear or conv1d.') | |||
if selfattention_layer_type == 'selfattn': | |||
encoder_selfattn_layer = MultiHeadedAttention | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
) | |||
elif selfattention_layer_type == 'sanm': | |||
encoder_selfattn_layer = MultiHeadedAttentionSANM | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
kernel_size, | |||
sanm_shfit, | |||
) | |||
self.encoders = repeat( | |||
num_blocks, | |||
lambda lnum: EncoderLayer( | |||
output_size, | |||
encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||
positionwise_layer(*positionwise_layer_args), | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
), | |||
) | |||
if self.normalize_before: | |||
self.after_norm = LayerNorm(output_size) | |||
self.interctc_layer_idx = interctc_layer_idx | |||
if len(interctc_layer_idx) > 0: | |||
assert 0 < min(interctc_layer_idx) and max( | |||
interctc_layer_idx) < num_blocks | |||
self.interctc_use_conditioning = interctc_use_conditioning | |||
self.conditioning_layer = None | |||
def output_size(self) -> int: | |||
return self._output_size | |||
def forward( | |||
self, | |||
xs_pad: torch.Tensor, | |||
ilens: torch.Tensor, | |||
prev_states: torch.Tensor = None, | |||
ctc: CTC = None, | |||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||
"""Embed positions in tensor. | |||
Args: | |||
xs_pad: input tensor (B, L, D) | |||
ilens: input length (B) | |||
prev_states: Not to be used now. | |||
Returns: | |||
position embedded tensor and mask | |||
""" | |||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||
if self.embed is None: | |||
xs_pad = xs_pad | |||
elif (isinstance(self.embed, Conv2dSubsampling) | |||
or isinstance(self.embed, Conv2dSubsampling2) | |||
or isinstance(self.embed, Conv2dSubsampling6) | |||
or isinstance(self.embed, Conv2dSubsampling8)): | |||
short_status, limit_size = check_short_utt(self.embed, | |||
xs_pad.size(1)) | |||
if short_status: | |||
raise TooShortUttError( | |||
f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||
+ # noqa: * | |||
f'(it needs more than {limit_size} frames), return empty results', | |||
xs_pad.size(1), | |||
limit_size, | |||
) | |||
xs_pad, masks = self.embed(xs_pad, masks) | |||
else: | |||
xs_pad = self.embed(xs_pad) | |||
intermediate_outs = [] | |||
if len(self.interctc_layer_idx) == 0: | |||
xs_pad, masks = self.encoders(xs_pad, masks) | |||
else: | |||
for layer_idx, encoder_layer in enumerate(self.encoders): | |||
xs_pad, masks = encoder_layer(xs_pad, masks) | |||
if layer_idx + 1 in self.interctc_layer_idx: | |||
encoder_out = xs_pad | |||
# intermediate outputs are also normalized | |||
if self.normalize_before: | |||
encoder_out = self.after_norm(encoder_out) | |||
intermediate_outs.append((layer_idx + 1, encoder_out)) | |||
if self.interctc_use_conditioning: | |||
ctc_out = ctc.softmax(encoder_out) | |||
xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||
if self.normalize_before: | |||
xs_pad = self.after_norm(xs_pad) | |||
olens = masks.squeeze(1).sum(1) | |||
if len(intermediate_outs) > 0: | |||
return (xs_pad, intermediate_outs), olens, None | |||
return xs_pad, olens, None | |||
class SANMEncoderChunk(AbsEncoder): | |||
"""Transformer encoder module. | |||
Args: | |||
input_size: input dim | |||
output_size: dimension of attention | |||
attention_heads: the number of heads of multi head attention | |||
linear_units: the number of units of position-wise feed forward | |||
num_blocks: the number of encoder blocks | |||
dropout_rate: dropout rate | |||
attention_dropout_rate: dropout rate in attention | |||
positional_dropout_rate: dropout rate after adding positional encoding | |||
input_layer: input layer type | |||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||
normalize_before: whether to use layer_norm before the first block | |||
concat_after: whether to concat attention layer's input and output | |||
if True, additional linear will be applied. | |||
i.e. x -> x + linear(concat(x, att(x))) | |||
if False, no additional linear will be applied. | |||
i.e. x -> x + att(x) | |||
positionwise_layer_type: linear, conv1d, or conv1d-linear | |||
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer | |||
padding_idx: padding_idx for input_layer=embed | |||
""" | |||
def __init__( | |||
self, | |||
input_size: int, | |||
output_size: int = 256, | |||
attention_heads: int = 4, | |||
linear_units: int = 2048, | |||
num_blocks: int = 6, | |||
dropout_rate: float = 0.1, | |||
positional_dropout_rate: float = 0.1, | |||
attention_dropout_rate: float = 0.0, | |||
input_layer: Optional[str] = 'conv2d', | |||
pos_enc_class=PositionalEncoding, | |||
normalize_before: bool = True, | |||
concat_after: bool = False, | |||
positionwise_layer_type: str = 'linear', | |||
positionwise_conv_kernel_size: int = 1, | |||
padding_idx: int = -1, | |||
interctc_layer_idx: List[int] = [], | |||
interctc_use_conditioning: bool = False, | |||
kernel_size: int = 11, | |||
sanm_shfit: int = 0, | |||
selfattention_layer_type: str = 'sanm', | |||
chunk_size: Union[int, Sequence[int]] = (16, ), | |||
stride: Union[int, Sequence[int]] = (10, ), | |||
pad_left: Union[int, Sequence[int]] = (0, ), | |||
encoder_att_look_back_factor: Union[int, Sequence[int]] = (1, ), | |||
): | |||
assert check_argument_types() | |||
super().__init__() | |||
self._output_size = output_size | |||
if input_layer == 'linear': | |||
self.embed = torch.nn.Sequential( | |||
torch.nn.Linear(input_size, output_size), | |||
torch.nn.LayerNorm(output_size), | |||
torch.nn.Dropout(dropout_rate), | |||
torch.nn.ReLU(), | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer == 'conv2d': | |||
self.embed = Conv2dSubsampling(input_size, output_size, | |||
dropout_rate) | |||
elif input_layer == 'conv2d2': | |||
self.embed = Conv2dSubsampling2(input_size, output_size, | |||
dropout_rate) | |||
elif input_layer == 'conv2d6': | |||
self.embed = Conv2dSubsampling6(input_size, output_size, | |||
dropout_rate) | |||
elif input_layer == 'conv2d8': | |||
self.embed = Conv2dSubsampling8(input_size, output_size, | |||
dropout_rate) | |||
elif input_layer == 'embed': | |||
self.embed = torch.nn.Sequential( | |||
torch.nn.Embedding( | |||
input_size, output_size, padding_idx=padding_idx), | |||
pos_enc_class(output_size, positional_dropout_rate), | |||
) | |||
elif input_layer is None: | |||
if input_size == output_size: | |||
self.embed = None | |||
else: | |||
self.embed = torch.nn.Linear(input_size, output_size) | |||
else: | |||
raise ValueError('unknown input_layer: ' + input_layer) | |||
self.normalize_before = normalize_before | |||
if positionwise_layer_type == 'linear': | |||
positionwise_layer = PositionwiseFeedForward | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
dropout_rate, | |||
) | |||
elif positionwise_layer_type == 'conv1d': | |||
positionwise_layer = MultiLayeredConv1d | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
positionwise_conv_kernel_size, | |||
dropout_rate, | |||
) | |||
elif positionwise_layer_type == 'conv1d-linear': | |||
positionwise_layer = Conv1dLinear | |||
positionwise_layer_args = ( | |||
output_size, | |||
linear_units, | |||
positionwise_conv_kernel_size, | |||
dropout_rate, | |||
) | |||
else: | |||
raise NotImplementedError('Support only linear or conv1d.') | |||
if selfattention_layer_type == 'selfattn': | |||
encoder_selfattn_layer = MultiHeadedAttention | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
) | |||
elif selfattention_layer_type == 'sanm': | |||
encoder_selfattn_layer = MultiHeadedAttentionSANM | |||
encoder_selfattn_layer_args = ( | |||
attention_heads, | |||
output_size, | |||
attention_dropout_rate, | |||
kernel_size, | |||
sanm_shfit, | |||
) | |||
self.encoders = repeat( | |||
num_blocks, | |||
lambda lnum: EncoderLayerChunk( | |||
output_size, | |||
encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||
positionwise_layer(*positionwise_layer_args), | |||
dropout_rate, | |||
normalize_before, | |||
concat_after, | |||
), | |||
) | |||
if self.normalize_before: | |||
self.after_norm = LayerNorm(output_size) | |||
self.interctc_layer_idx = interctc_layer_idx | |||
if len(interctc_layer_idx) > 0: | |||
assert 0 < min(interctc_layer_idx) and max( | |||
interctc_layer_idx) < num_blocks | |||
self.interctc_use_conditioning = interctc_use_conditioning | |||
self.conditioning_layer = None | |||
shfit_fsmn = (kernel_size - 1) // 2 | |||
self.overlap_chunk_cls = overlap_chunk( | |||
chunk_size=chunk_size, | |||
stride=stride, | |||
pad_left=pad_left, | |||
shfit_fsmn=shfit_fsmn, | |||
encoder_att_look_back_factor=encoder_att_look_back_factor, | |||
) | |||
def output_size(self) -> int: | |||
return self._output_size | |||
def forward( | |||
self, | |||
xs_pad: torch.Tensor, | |||
ilens: torch.Tensor, | |||
prev_states: torch.Tensor = None, | |||
ctc: CTC = None, | |||
ind: int = 0, | |||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||
"""Embed positions in tensor. | |||
Args: | |||
xs_pad: input tensor (B, L, D) | |||
ilens: input length (B) | |||
prev_states: Not to be used now. | |||
Returns: | |||
position embedded tensor and mask | |||
""" | |||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||
if self.embed is None: | |||
xs_pad = xs_pad | |||
elif (isinstance(self.embed, Conv2dSubsampling) | |||
or isinstance(self.embed, Conv2dSubsampling2) | |||
or isinstance(self.embed, Conv2dSubsampling6) | |||
or isinstance(self.embed, Conv2dSubsampling8)): | |||
short_status, limit_size = check_short_utt(self.embed, | |||
xs_pad.size(1)) | |||
if short_status: | |||
raise TooShortUttError( | |||
f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||
+ # noqa: * | |||
f'(it needs more than {limit_size} frames), return empty results', | |||
xs_pad.size(1), | |||
limit_size, | |||
) | |||
xs_pad, masks = self.embed(xs_pad, masks) | |||
else: | |||
xs_pad = self.embed(xs_pad) | |||
mask_shfit_chunk, mask_att_chunk_encoder = None, None | |||
if self.overlap_chunk_cls is not None: | |||
ilens = masks.squeeze(1).sum(1) | |||
chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind) | |||
xs_pad, ilens = self.overlap_chunk_cls.split_chunk( | |||
xs_pad, ilens, chunk_outs=chunk_outs) | |||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||
mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk( | |||
chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype) | |||
mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder( | |||
chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype) | |||
intermediate_outs = [] | |||
if len(self.interctc_layer_idx) == 0: | |||
xs_pad, masks, _, _, _ = self.encoders(xs_pad, masks, None, | |||
mask_shfit_chunk, | |||
mask_att_chunk_encoder) | |||
else: | |||
for layer_idx, encoder_layer in enumerate(self.encoders): | |||
xs_pad, masks, _, _, _ = encoder_layer(xs_pad, masks, None, | |||
mask_shfit_chunk, | |||
mask_att_chunk_encoder) | |||
if layer_idx + 1 in self.interctc_layer_idx: | |||
encoder_out = xs_pad | |||
# intermediate outputs are also normalized | |||
if self.normalize_before: | |||
encoder_out = self.after_norm(encoder_out) | |||
intermediate_outs.append((layer_idx + 1, encoder_out)) | |||
if self.interctc_use_conditioning: | |||
ctc_out = ctc.softmax(encoder_out) | |||
xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||
if self.normalize_before: | |||
xs_pad = self.after_norm(xs_pad) | |||
if self.overlap_chunk_cls is not None: | |||
xs_pad, olens = self.overlap_chunk_cls.remove_chunk( | |||
xs_pad, ilens, chunk_outs) | |||
if len(intermediate_outs) > 0: | |||
return (xs_pad, intermediate_outs), olens, None | |||
return xs_pad, olens, None |
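A similar sketch for the offline SANMEncoder above, which swaps plain self-attention for the SANM variant (a memory-block-augmented attention, judging from its kernel_size and sanm_shfit arguments); the class and its in-repo and espnet dependencies are assumed importable, and all sizes are illustrative:

import torch

encoder = SANMEncoder(
    input_size=80,
    output_size=256,
    attention_heads=4,
    num_blocks=6,
    input_layer='conv2d',              # Conv2dSubsampling without an explicit pos_enc argument
    selfattention_layer_type='sanm',   # MultiHeadedAttentionSANM
    kernel_size=11,                    # memory-block kernel passed to the SANM attention
)
xs_pad = torch.randn(2, 200, 80)        # (batch, frames, feature_dim), zero-padded
ilens = torch.tensor([200, 180])
out, olens, _ = encoder(xs_pad, ilens)  # out: (2, T', 256) after conv2d subsampling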
@@ -1,113 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed from espnet/espnet. | |||
import copy | |||
from typing import Optional, Tuple, Union | |||
import humanfriendly | |||
import numpy as np | |||
import torch | |||
import torchaudio | |||
import torchaudio.compliance.kaldi as kaldi | |||
from espnet2.asr.frontend.abs_frontend import AbsFrontend | |||
from espnet2.layers.log_mel import LogMel | |||
from espnet2.layers.stft import Stft | |||
from espnet2.utils.get_default_kwargs import get_default_kwargs | |||
from espnet.nets.pytorch_backend.frontends.frontend import Frontend | |||
from typeguard import check_argument_types | |||
class WavFrontend(AbsFrontend): | |||
"""Conventional frontend structure for ASR. | |||
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN | |||
""" | |||
def __init__( | |||
self, | |||
fs: Union[int, str] = 16000, | |||
n_fft: int = 512, | |||
win_length: int = 400, | |||
hop_length: int = 160, | |||
window: Optional[str] = 'hamming', | |||
center: bool = True, | |||
normalized: bool = False, | |||
onesided: bool = True, | |||
n_mels: int = 80, | |||
fmin: int = None, | |||
fmax: int = None, | |||
htk: bool = False, | |||
frontend_conf: Optional[dict] = get_default_kwargs(Frontend), | |||
apply_stft: bool = True, | |||
): | |||
assert check_argument_types() | |||
super().__init__() | |||
if isinstance(fs, str): | |||
fs = humanfriendly.parse_size(fs) | |||
# Deepcopy (In general, dict shouldn't be used as default arg) | |||
frontend_conf = copy.deepcopy(frontend_conf) | |||
self.hop_length = hop_length | |||
self.win_length = win_length | |||
self.window = window | |||
self.fs = fs | |||
if apply_stft: | |||
self.stft = Stft( | |||
n_fft=n_fft, | |||
win_length=win_length, | |||
hop_length=hop_length, | |||
center=center, | |||
window=window, | |||
normalized=normalized, | |||
onesided=onesided, | |||
) | |||
else: | |||
self.stft = None | |||
self.apply_stft = apply_stft | |||
if frontend_conf is not None: | |||
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf) | |||
else: | |||
self.frontend = None | |||
self.logmel = LogMel( | |||
fs=fs, | |||
n_fft=n_fft, | |||
n_mels=n_mels, | |||
fmin=fmin, | |||
fmax=fmax, | |||
htk=htk, | |||
) | |||
self.n_mels = n_mels | |||
self.frontend_type = 'default' | |||
def output_size(self) -> int: | |||
return self.n_mels | |||
def forward( | |||
self, input: torch.Tensor, | |||
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |||
sample_frequency = self.fs | |||
num_mel_bins = self.n_mels | |||
frame_length = self.win_length * 1000 / sample_frequency | |||
frame_shift = self.hop_length * 1000 / sample_frequency | |||
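# kaldi.fbank assumes samples in the 16-bit PCM range, so rescale the [-1, 1] float waveform | |||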
waveform = input * (1 << 15) | |||
mat = kaldi.fbank( | |||
waveform, | |||
num_mel_bins=num_mel_bins, | |||
frame_length=frame_length, | |||
frame_shift=frame_shift, | |||
dither=1.0, | |||
energy_floor=0.0, | |||
window_type=self.window, | |||
sample_frequency=sample_frequency) | |||
# wrap the single utterance as a batch of one; feats_lens records its frame count | |||
input_feats = mat[None, :] | |||
feats_lens = torch.tensor([input_feats.shape[1]], dtype=torch.float32) | |||
return input_feats, feats_lens |
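The forward pass above boils down to one Kaldi-compatible fbank call per utterance; a standalone sketch of that computation, assuming only torch and torchaudio, with illustrative parameter values matching the class defaults:

import torch
import torchaudio.compliance.kaldi as kaldi

fs, n_mels, win_length, hop_length = 16000, 80, 400, 160
waveform = torch.randn(1, fs)                    # 1 s of float audio in [-1, 1], shape (channels, samples)
feats = kaldi.fbank(
    waveform * (1 << 15),                        # rescale to the 16-bit PCM range Kaldi assumes
    num_mel_bins=n_mels,
    frame_length=win_length * 1000 / fs,         # 25 ms window
    frame_shift=hop_length * 1000 / fs,          # 10 ms hop
    dither=1.0,
    energy_floor=0.0,
    window_type='hamming',
    sample_frequency=fs)
print(feats.shape)                               # (num_frames, 80)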
@@ -1,321 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed from espnet/espnet. | |||
import logging | |||
import math | |||
import numpy as np | |||
import torch | |||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||
from ...nets.pytorch_backend.cif_utils.cif import \ | |||
cif_predictor as cif_predictor | |||
np.set_printoptions(threshold=np.inf) | |||
torch.set_printoptions(profile='full', precision=100000, linewidth=None) | |||
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device='cpu'): | |||
if maxlen is None: | |||
maxlen = lengths.max() | |||
row_vector = torch.arange(0, maxlen, 1) | |||
matrix = torch.unsqueeze(lengths, dim=-1) | |||
mask = row_vector < matrix | |||
return mask.type(dtype).to(device) | |||
class overlap_chunk(): | |||
def __init__( | |||
self, | |||
chunk_size: tuple = (16, ), | |||
stride: tuple = (10, ), | |||
pad_left: tuple = (0, ), | |||
encoder_att_look_back_factor: tuple = (1, ), | |||
shfit_fsmn: int = 0, | |||
): | |||
self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor \ | |||
= chunk_size, stride, pad_left, encoder_att_look_back_factor | |||
self.shfit_fsmn = shfit_fsmn | |||
self.x_add_mask = None | |||
self.x_rm_mask = None | |||
self.x_len = None | |||
self.mask_shfit_chunk = None | |||
self.mask_chunk_predictor = None | |||
self.mask_att_chunk_encoder = None | |||
self.mask_shift_att_chunk_decoder = None | |||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur \ | |||
= None, None, None, None | |||
def get_chunk_size(self, ind: int = 0): | |||
# with torch.no_grad: | |||
chunk_size, stride, pad_left, encoder_att_look_back_factor = self.chunk_size[ | |||
ind], self.stride[ind], self.pad_left[ | |||
ind], self.encoder_att_look_back_factor[ind] | |||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, \ | |||
self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \ | |||
= chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn | |||
return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur | |||
def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1): | |||
with torch.no_grad(): | |||
x_len = x_len.cpu().numpy() | |||
x_len_max = x_len.max() | |||
chunk_size, stride, pad_left, encoder_att_look_back_factor = self.get_chunk_size( | |||
ind) | |||
shfit_fsmn = self.shfit_fsmn | |||
chunk_size_pad_shift = chunk_size + shfit_fsmn | |||
chunk_num_batch = np.ceil(x_len / stride).astype(np.int32) | |||
x_len_chunk = ( | |||
chunk_num_batch - 1 | |||
) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - ( | |||
chunk_num_batch - 1) * stride | |||
x_len_chunk = x_len_chunk.astype(x_len.dtype) | |||
x_len_chunk_max = x_len_chunk.max() | |||
chunk_num = int(math.ceil(x_len_max / stride)) | |||
dtype = np.int32 | |||
max_len_for_x_mask_tmp = max(chunk_size, x_len_max) | |||
x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype) | |||
x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype) | |||
mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype) | |||
mask_chunk_predictor = np.zeros([0, num_units_predictor], | |||
dtype=dtype) | |||
mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype) | |||
mask_att_chunk_encoder = np.zeros( | |||
[0, chunk_num * chunk_size_pad_shift], dtype=dtype) | |||
for chunk_ids in range(chunk_num): | |||
# x_mask add | |||
fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), | |||
dtype=dtype) | |||
x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32)) | |||
x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), | |||
dtype=dtype) | |||
x_mask_pad_right = np.zeros( | |||
(chunk_size, max_len_for_x_mask_tmp), dtype=dtype) | |||
x_cur_pad = np.concatenate( | |||
[x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1) | |||
x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp] | |||
x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], | |||
axis=0) | |||
x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], | |||
axis=0) | |||
# x_mask rm | |||
fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), | |||
dtype=dtype) | |||
x_mask_cur = np.diag(np.ones(stride, dtype=dtype)) | |||
x_mask_right = np.zeros((stride, chunk_size - stride), | |||
dtype=dtype) | |||
x_mask_cur = np.concatenate([x_mask_cur, x_mask_right], axis=1) | |||
x_mask_cur_pad_top = np.zeros((chunk_ids * stride, chunk_size), | |||
dtype=dtype) | |||
x_mask_cur_pad_bottom = np.zeros( | |||
(max_len_for_x_mask_tmp, chunk_size), dtype=dtype) | |||
x_rm_mask_cur = np.concatenate( | |||
[x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], | |||
axis=0) | |||
x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, : | |||
chunk_size] | |||
x_rm_mask_cur_fsmn = np.concatenate( | |||
[fsmn_padding, x_rm_mask_cur], axis=1) | |||
x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], | |||
axis=1) | |||
# fsmn_padding_mask | |||
pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype) | |||
ones_1 = np.ones([chunk_size, num_units], dtype=dtype) | |||
mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], | |||
axis=0) | |||
mask_shfit_chunk = np.concatenate( | |||
[mask_shfit_chunk, mask_shfit_chunk_cur], axis=0) | |||
# predictor mask | |||
zeros_1 = np.zeros( | |||
[shfit_fsmn + pad_left, num_units_predictor], dtype=dtype) | |||
ones_2 = np.ones([stride, num_units_predictor], dtype=dtype) | |||
zeros_3 = np.zeros( | |||
[chunk_size - stride - pad_left, num_units_predictor], | |||
dtype=dtype) | |||
ones_zeros = np.concatenate([ones_2, zeros_3], axis=0) | |||
mask_chunk_predictor_cur = np.concatenate( | |||
[zeros_1, ones_zeros], axis=0) | |||
mask_chunk_predictor = np.concatenate( | |||
[mask_chunk_predictor, mask_chunk_predictor_cur], axis=0) | |||
# encoder att mask | |||
zeros_1_top = np.zeros( | |||
[shfit_fsmn, chunk_num * chunk_size_pad_shift], | |||
dtype=dtype) | |||
zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0) | |||
zeros_2 = np.zeros( | |||
[chunk_size, zeros_2_num * chunk_size_pad_shift], | |||
dtype=dtype) | |||
encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0) | |||
zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) | |||
ones_2_mid = np.ones([stride, stride], dtype=dtype) | |||
zeros_2_bottom = np.zeros([chunk_size - stride, stride], | |||
dtype=dtype) | |||
zeros_2_right = np.zeros([chunk_size, chunk_size - stride], | |||
dtype=dtype) | |||
ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0) | |||
ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], | |||
axis=1) | |||
ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num]) | |||
zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) | |||
ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype) | |||
ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1) | |||
zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0) | |||
zeros_remain = np.zeros( | |||
[chunk_size, zeros_remain_num * chunk_size_pad_shift], | |||
dtype=dtype) | |||
ones2_bottom = np.concatenate( | |||
[zeros_2, ones_2, ones_3, zeros_remain], axis=1) | |||
mask_att_chunk_encoder_cur = np.concatenate( | |||
[zeros_1_top, ones2_bottom], axis=0) | |||
mask_att_chunk_encoder = np.concatenate( | |||
[mask_att_chunk_encoder, mask_att_chunk_encoder_cur], | |||
axis=0) | |||
# decoder fsmn_shift_att_mask | |||
zeros_1 = np.zeros([shfit_fsmn, 1]) | |||
ones_1 = np.ones([chunk_size, 1]) | |||
mask_shift_att_chunk_decoder_cur = np.concatenate( | |||
[zeros_1, ones_1], axis=0) | |||
mask_shift_att_chunk_decoder = np.concatenate( | |||
[ | |||
mask_shift_att_chunk_decoder, | |||
mask_shift_att_chunk_decoder_cur | |||
], | |||
axis=0)  # noqa: * | |||
self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max] | |||
self.x_len_chunk = x_len_chunk | |||
self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max] | |||
self.x_len = x_len | |||
self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :] | |||
self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :] | |||
self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max] | |||
self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :] | |||
return (self.x_add_mask, self.x_len_chunk, self.x_rm_mask, self.x_len, | |||
self.mask_shfit_chunk, self.mask_chunk_predictor, | |||
self.mask_att_chunk_encoder, self.mask_shift_att_chunk_decoder) | |||
def split_chunk(self, x, x_len, chunk_outs): | |||
""" | |||
:param x: (b, t, d) | |||
:param x_length: (b) | |||
:param ind: int | |||
:return: | |||
""" | |||
x = x[:, :x_len.max(), :] | |||
b, t, d = x.size() | |||
x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(x.device) | |||
x *= x_len_mask[:, :, None] | |||
x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype) | |||
x_len_chunk = self.get_x_len_chunk( | |||
chunk_outs, x_len.device, dtype=x_len.dtype) | |||
x = torch.transpose(x, 1, 0) | |||
x = torch.reshape(x, [t, -1]) | |||
x_chunk = torch.mm(x_add_mask, x) | |||
x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0) | |||
return x_chunk, x_len_chunk | |||
def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs): | |||
x_chunk = x_chunk[:, :x_len_chunk.max(), :] | |||
b, t, d = x_chunk.size() | |||
x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to( | |||
x_chunk.device) | |||
x_chunk *= x_len_chunk_mask[:, :, None] | |||
x_rm_mask = self.get_x_rm_mask( | |||
chunk_outs, x_chunk.device, dtype=x_chunk.dtype) | |||
x_len = self.get_x_len( | |||
chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype) | |||
x_chunk = torch.transpose(x_chunk, 1, 0) | |||
x_chunk = torch.reshape(x_chunk, [t, -1]) | |||
x = torch.mm(x_rm_mask, x_chunk) | |||
x = torch.reshape(x, [-1, b, d]).transpose(1, 0) | |||
return x, x_len | |||
def get_x_add_mask(self, chunk_outs, device, idx=0, dtype=torch.float32): | |||
x = chunk_outs[idx] | |||
x = torch.from_numpy(x).type(dtype).to(device) | |||
return x.detach() | |||
def get_x_len_chunk(self, chunk_outs, device, idx=1, dtype=torch.float32): | |||
x = chunk_outs[idx] | |||
x = torch.from_numpy(x).type(dtype).to(device) | |||
return x.detach() | |||
def get_x_rm_mask(self, chunk_outs, device, idx=2, dtype=torch.float32): | |||
x = chunk_outs[idx] | |||
x = torch.from_numpy(x).type(dtype).to(device) | |||
return x.detach() | |||
def get_x_len(self, chunk_outs, device, idx=3, dtype=torch.float32): | |||
x = chunk_outs[idx] | |||
x = torch.from_numpy(x).type(dtype).to(device) | |||
return x.detach() | |||
def get_mask_shfit_chunk(self, | |||
chunk_outs, | |||
device, | |||
batch_size=1, | |||
num_units=1, | |||
idx=4, | |||
dtype=torch.float32): | |||
x = chunk_outs[idx] | |||
x = np.tile(x[None, :, :, ], [batch_size, 1, num_units]) | |||
x = torch.from_numpy(x).type(dtype).to(device) | |||
return x.detach() | |||
def get_mask_chunk_predictor(self, | |||
chunk_outs, | |||
device, | |||
batch_size=1, | |||
num_units=1, | |||
idx=5, | |||
dtype=torch.float32): | |||
x = chunk_outs[idx] | |||
x = np.tile(x[None, :, :, ], [batch_size, 1, num_units]) | |||
x = torch.from_numpy(x).type(dtype).to(device) | |||
return x.detach() | |||
def get_mask_att_chunk_encoder(self, | |||
chunk_outs, | |||
device, | |||
batch_size=1, | |||
idx=6, | |||
dtype=torch.float32): | |||
x = chunk_outs[idx] | |||
x = np.tile(x[None, :, :, ], [batch_size, 1, 1]) | |||
x = torch.from_numpy(x).type(dtype).to(device) | |||
return x.detach() | |||
def get_mask_shift_att_chunk_decoder(self, | |||
chunk_outs, | |||
device, | |||
batch_size=1, | |||
idx=7, | |||
dtype=torch.float32): | |||
x = chunk_outs[idx] | |||
x = np.tile(x[None, None, :, 0], [batch_size, 1, 1]) | |||
x = torch.from_numpy(x).type(dtype).to(device) | |||
return x.detach() |
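A minimal round-trip sketch of the overlap_chunk helper above: generate the chunk masks for a batch, split the padded features into overlapping chunks, then map them back; the class is assumed importable and the sizes are illustrative:

import torch

chunker = overlap_chunk(chunk_size=(16, ), stride=(10, ), pad_left=(0, ),
                        encoder_att_look_back_factor=(1, ), shfit_fsmn=5)
x = torch.randn(2, 50, 256)                   # (batch, frames, dim), zero-padded
x_len = torch.tensor([50, 42])
chunk_outs = chunker.gen_chunk_mask(x_len, ind=0, num_units=256)
x_chunk, x_len_chunk = chunker.split_chunk(x, x_len, chunk_outs=chunk_outs)
x_back, x_len_back = chunker.remove_chunk(x_chunk, x_len_chunk, chunk_outs)
# x_back reproduces x on its valid frames; the extra rows in x_chunk are the
# shfit_fsmn padding inserted in front of every chunk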
@@ -1,250 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed from espnet/espnet. | |||
import logging | |||
import numpy as np | |||
import torch | |||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||
from torch import nn | |||
class CIF_Model(nn.Module): | |||
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1): | |||
super(CIF_Model, self).__init__() | |||
self.pad = nn.ConstantPad1d((l_order, r_order), 0) | |||
self.cif_conv1d = nn.Conv1d( | |||
idim, idim, l_order + r_order + 1, groups=idim) | |||
self.cif_output = nn.Linear(idim, 1) | |||
self.dropout = torch.nn.Dropout(p=dropout) | |||
self.threshold = threshold | |||
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1): | |||
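# Continuous integrate-and-fire (CIF): a depthwise conv plus a sigmoid head predicts a | |||
# per-frame weight (alpha); the weights are integrated over time to emit one embedding | |||
# per output token, optionally rescaled so they sum to the target label length. | |||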
h = hidden | |||
context = h.transpose(1, 2) | |||
queries = self.pad(context) | |||
memory = self.cif_conv1d(queries) | |||
output = memory + context | |||
output = self.dropout(output) | |||
output = output.transpose(1, 2) | |||
output = torch.relu(output) | |||
output = self.cif_output(output) | |||
alphas = torch.sigmoid(output) | |||
if mask is not None: | |||
alphas = alphas * mask.transpose(-1, -2).float() | |||
alphas = alphas.squeeze(-1) | |||
if target_label is not None: | |||
target_length = (target_label != ignore_id).float().sum(-1) | |||
else: | |||
target_length = None | |||
cif_length = alphas.sum(-1) | |||
if target_label is not None: | |||
alphas *= (target_length / cif_length)[:, None].repeat( | |||
1, alphas.size(1)) | |||
cif_output, cif_peak = cif(hidden, alphas, self.threshold) | |||
return cif_output, cif_length, target_length, cif_peak | |||
def gen_frame_alignments(self, | |||
alphas: torch.Tensor = None, | |||
memory_sequence_length: torch.Tensor = None, | |||
is_training: bool = True, | |||
dtype: torch.dtype = torch.float32): | |||
batch_size, maximum_length = alphas.size() | |||
int_type = torch.int32 | |||
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) | |||
max_token_num = torch.max(token_num).item() | |||
alphas_cumsum = torch.cumsum(alphas, dim=1) | |||
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) | |||
alphas_cumsum = torch.tile(alphas_cumsum[:, None, :], | |||
[1, max_token_num, 1]) | |||
index = torch.ones([batch_size, max_token_num], dtype=int_type) | |||
index = torch.cumsum(index, dim=1) | |||
index = torch.tile(index[:, :, None], [1, 1, maximum_length]) | |||
index_div = torch.floor(torch.divide(alphas_cumsum, | |||
index)).type(int_type) | |||
index_div_bool_zeros = index_div.eq(0) | |||
index_div_bool_zeros_count = torch.sum( | |||
index_div_bool_zeros, dim=-1) + 1 | |||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0, | |||
memory_sequence_length.max()) | |||
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to( | |||
token_num.device) | |||
index_div_bool_zeros_count *= token_num_mask | |||
index_div_bool_zeros_count_tile = torch.tile( | |||
index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length]) | |||
ones = torch.ones_like(index_div_bool_zeros_count_tile) | |||
zeros = torch.zeros_like(index_div_bool_zeros_count_tile) | |||
ones = torch.cumsum(ones, dim=2) | |||
cond = index_div_bool_zeros_count_tile == ones | |||
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) | |||
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type( | |||
torch.bool) | |||
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type( | |||
int_type) | |||
index_div_bool_zeros_count_tile_out = torch.sum( | |||
index_div_bool_zeros_count_tile, dim=1) | |||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type( | |||
int_type) | |||
predictor_mask = (~make_pad_mask( | |||
memory_sequence_length, | |||
maxlen=memory_sequence_length.max())).type(int_type).to( | |||
memory_sequence_length.device) # noqa: * | |||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask | |||
return index_div_bool_zeros_count_tile_out.detach( | |||
), index_div_bool_zeros_count.detach() | |||
class cif_predictor(nn.Module): | |||
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1): | |||
super(cif_predictor, self).__init__() | |||
self.pad = nn.ConstantPad1d((l_order, r_order), 0) | |||
self.cif_conv1d = nn.Conv1d( | |||
idim, idim, l_order + r_order + 1, groups=idim) | |||
self.cif_output = nn.Linear(idim, 1) | |||
self.dropout = torch.nn.Dropout(p=dropout) | |||
self.threshold = threshold | |||
def forward(self, | |||
hidden, | |||
target_label=None, | |||
mask=None, | |||
ignore_id=-1, | |||
mask_chunk_predictor=None, | |||
target_label_length=None): | |||
h = hidden | |||
context = h.transpose(1, 2) | |||
queries = self.pad(context) | |||
memory = self.cif_conv1d(queries) | |||
output = memory + context | |||
output = self.dropout(output) | |||
output = output.transpose(1, 2) | |||
output = torch.relu(output) | |||
output = self.cif_output(output) | |||
alphas = torch.sigmoid(output) | |||
if mask is not None: | |||
alphas = alphas * mask.transpose(-1, -2).float() | |||
if mask_chunk_predictor is not None: | |||
alphas = alphas * mask_chunk_predictor | |||
alphas = alphas.squeeze(-1) | |||
if target_label_length is not None: | |||
target_length = target_label_length | |||
elif target_label is not None: | |||
target_length = (target_label != ignore_id).float().sum(-1) | |||
else: | |||
target_length = None | |||
token_num = alphas.sum(-1) | |||
if target_length is not None: | |||
alphas *= (target_length / token_num)[:, None].repeat( | |||
1, alphas.size(1)) | |||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) | |||
return acoustic_embeds, token_num, alphas, cif_peak | |||
def gen_frame_alignments(self, | |||
alphas: torch.Tensor = None, | |||
memory_sequence_length: torch.Tensor = None, | |||
is_training: bool = True, | |||
dtype: torch.dtype = torch.float32): | |||
batch_size, maximum_length = alphas.size() | |||
int_type = torch.int32 | |||
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) | |||
max_token_num = torch.max(token_num).item() | |||
alphas_cumsum = torch.cumsum(alphas, dim=1) | |||
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) | |||
alphas_cumsum = torch.tile(alphas_cumsum[:, None, :], | |||
[1, max_token_num, 1]) | |||
index = torch.ones([batch_size, max_token_num], dtype=int_type) | |||
index = torch.cumsum(index, dim=1) | |||
index = torch.tile(index[:, :, None], [1, 1, maximum_length]) | |||
index_div = torch.floor(torch.divide(alphas_cumsum, | |||
index)).type(int_type) | |||
index_div_bool_zeros = index_div.eq(0) | |||
index_div_bool_zeros_count = torch.sum( | |||
index_div_bool_zeros, dim=-1) + 1 | |||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0, | |||
memory_sequence_length.max()) | |||
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to( | |||
token_num.device) | |||
index_div_bool_zeros_count *= token_num_mask | |||
index_div_bool_zeros_count_tile = torch.tile( | |||
index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length]) | |||
ones = torch.ones_like(index_div_bool_zeros_count_tile) | |||
zeros = torch.zeros_like(index_div_bool_zeros_count_tile) | |||
ones = torch.cumsum(ones, dim=2) | |||
cond = index_div_bool_zeros_count_tile == ones | |||
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) | |||
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type( | |||
torch.bool) | |||
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type( | |||
int_type) | |||
index_div_bool_zeros_count_tile_out = torch.sum( | |||
index_div_bool_zeros_count_tile, dim=1) | |||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type( | |||
int_type) | |||
predictor_mask = (~make_pad_mask( | |||
memory_sequence_length, | |||
maxlen=memory_sequence_length.max())).type(int_type).to( | |||
memory_sequence_length.device) # noqa: * | |||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask | |||
return index_div_bool_zeros_count_tile_out.detach( | |||
), index_div_bool_zeros_count.detach() | |||
def cif(hidden, alphas, threshold): | |||
batch_size, len_time, hidden_size = hidden.size() | |||
    # loop vars | |||
integrate = torch.zeros([batch_size], device=hidden.device) | |||
frame = torch.zeros([batch_size, hidden_size], device=hidden.device) | |||
# intermediate vars along time | |||
list_fires = [] | |||
list_frames = [] | |||
for t in range(len_time): | |||
alpha = alphas[:, t] | |||
distribution_completion = torch.ones([batch_size], | |||
device=hidden.device) - integrate | |||
integrate += alpha | |||
list_fires.append(integrate) | |||
fire_place = integrate >= threshold | |||
integrate = torch.where( | |||
fire_place, | |||
integrate - torch.ones([batch_size], device=hidden.device), | |||
integrate) | |||
cur = torch.where(fire_place, distribution_completion, alpha) | |||
remainds = alpha - cur | |||
frame += cur[:, None] * hidden[:, t, :] | |||
list_frames.append(frame) | |||
frame = torch.where(fire_place[:, None].repeat(1, hidden_size), | |||
remainds[:, None] * hidden[:, t, :], frame) | |||
fires = torch.stack(list_fires, 1) | |||
frames = torch.stack(list_frames, 1) | |||
list_ls = [] | |||
len_labels = torch.round(alphas.sum(-1)).int() | |||
max_label_len = len_labels.max() | |||
for b in range(batch_size): | |||
fire = fires[b, :] | |||
ls = torch.index_select(frames[b, :, :], 0, | |||
torch.nonzero(fire >= threshold).squeeze()) | |||
pad_l = torch.zeros([max_label_len - ls.size(0), hidden_size], | |||
device=hidden.device) | |||
list_ls.append(torch.cat([ls, pad_l], 0)) | |||
return torch.stack(list_ls, 0), fires |
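A minimal usage sketch of the continuous integrate-and-fire (CIF) routine above, assuming `cif` (and the predictor classes above) can be imported from this module; the shapes are illustrative only:

import torch

batch, time, dim = 2, 8, 4
hidden = torch.randn(batch, time, dim)      # encoder outputs
alphas = torch.full((batch, time), 0.5)     # per-frame weights; each row sums to 4.0
acoustic_embeds, fires = cif(hidden, alphas, threshold=1.0)
print(acoustic_embeds.shape)                # torch.Size([2, 4, 4]) -> 4 fired tokens per sample
print(fires.shape)                          # torch.Size([2, 8])    -> per-frame integrate values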
@@ -1,680 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed from espnet/espnet. | |||
"""Multi-Head Attention layer definition.""" | |||
import logging | |||
import math | |||
import numpy | |||
import torch | |||
from torch import nn | |||
torch.set_printoptions(profile='full', precision=1) | |||
class MultiHeadedAttention(nn.Module): | |||
"""Multi-Head Attention layer. | |||
Args: | |||
n_head (int): The number of heads. | |||
n_feat (int): The number of features. | |||
dropout_rate (float): Dropout rate. | |||
""" | |||
def __init__(self, n_head, n_feat, dropout_rate): | |||
"""Construct an MultiHeadedAttention object.""" | |||
super(MultiHeadedAttention, self).__init__() | |||
assert n_feat % n_head == 0 | |||
# We assume d_v always equals d_k | |||
self.d_k = n_feat // n_head | |||
self.h = n_head | |||
self.linear_q = nn.Linear(n_feat, n_feat) | |||
self.linear_k = nn.Linear(n_feat, n_feat) | |||
self.linear_v = nn.Linear(n_feat, n_feat) | |||
self.linear_out = nn.Linear(n_feat, n_feat) | |||
self.attn = None | |||
self.dropout = nn.Dropout(p=dropout_rate) | |||
def forward_qkv(self, query, key, value): | |||
"""Transform query, key and value. | |||
Args: | |||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||
Returns: | |||
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). | |||
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). | |||
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). | |||
""" | |||
n_batch = query.size(0) | |||
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) | |||
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) | |||
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) | |||
q = q.transpose(1, 2) # (batch, head, time1, d_k) | |||
k = k.transpose(1, 2) # (batch, head, time2, d_k) | |||
v = v.transpose(1, 2) # (batch, head, time2, d_k) | |||
return q, k, v | |||
def forward_attention(self, value, scores, mask): | |||
"""Compute attention context vector. | |||
Args: | |||
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). | |||
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). | |||
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). | |||
Returns: | |||
torch.Tensor: Transformed value (#batch, time1, d_model) | |||
weighted by the attention score (#batch, time1, time2). | |||
""" | |||
n_batch = value.size(0) | |||
if mask is not None: | |||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) | |||
min_value = float( | |||
numpy.finfo(torch.tensor( | |||
0, dtype=scores.dtype).numpy().dtype).min) | |||
scores = scores.masked_fill(mask, min_value) | |||
self.attn = torch.softmax( | |||
scores, dim=-1).masked_fill(mask, | |||
0.0) # (batch, head, time1, time2) | |||
else: | |||
self.attn = torch.softmax( | |||
scores, dim=-1) # (batch, head, time1, time2) | |||
p_attn = self.dropout(self.attn) | |||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) | |||
x = (x.transpose(1, 2).contiguous().view(n_batch, -1, | |||
self.h * self.d_k) | |||
) # (batch, time1, d_model) | |||
return self.linear_out(x) # (batch, time1, d_model) | |||
def forward(self, query, key, value, mask): | |||
"""Compute scaled dot product attention. | |||
Args: | |||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
(#batch, time1, time2). | |||
Returns: | |||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||
""" | |||
q, k, v = self.forward_qkv(query, key, value) | |||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) | |||
return self.forward_attention(v, scores, mask) | |||
class MultiHeadedAttentionSANM(nn.Module): | |||
"""Multi-Head Attention layer. | |||
Args: | |||
n_head (int): The number of heads. | |||
n_feat (int): The number of features. | |||
dropout_rate (float): Dropout rate. | |||
""" | |||
def __init__(self, | |||
n_head, | |||
n_feat, | |||
dropout_rate, | |||
kernel_size, | |||
sanm_shfit=0): | |||
"""Construct an MultiHeadedAttention object.""" | |||
super(MultiHeadedAttentionSANM, self).__init__() | |||
assert n_feat % n_head == 0 | |||
# We assume d_v always equals d_k | |||
self.d_k = n_feat // n_head | |||
self.h = n_head | |||
self.linear_q = nn.Linear(n_feat, n_feat) | |||
self.linear_k = nn.Linear(n_feat, n_feat) | |||
self.linear_v = nn.Linear(n_feat, n_feat) | |||
self.linear_out = nn.Linear(n_feat, n_feat) | |||
self.attn = None | |||
self.dropout = nn.Dropout(p=dropout_rate) | |||
self.fsmn_block = nn.Conv1d( | |||
n_feat, | |||
n_feat, | |||
kernel_size, | |||
stride=1, | |||
padding=0, | |||
groups=n_feat, | |||
bias=False) | |||
# padding | |||
left_padding = (kernel_size - 1) // 2 | |||
if sanm_shfit > 0: | |||
left_padding = left_padding + sanm_shfit | |||
right_padding = kernel_size - 1 - left_padding | |||
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) | |||
def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None): | |||
        '''Apply the FSMN memory block. | |||
        :param inputs: Input tensor (#batch, time1, size). | |||
        :param mask: Mask tensor (#batch, 1, time). | |||
        :return: FSMN memory output (#batch, time1, size). | |||
        ''' | |||
# b, t, d = inputs.size() | |||
mask = mask[:, 0, :, None] | |||
if mask_shfit_chunk is not None: | |||
mask = mask * mask_shfit_chunk | |||
inputs *= mask | |||
x = inputs.transpose(1, 2) | |||
x = self.pad_fn(x) | |||
x = self.fsmn_block(x) | |||
x = x.transpose(1, 2) | |||
x += inputs | |||
x = self.dropout(x) | |||
return x * mask | |||
def forward_qkv(self, query, key, value): | |||
"""Transform query, key and value. | |||
Args: | |||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||
Returns: | |||
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). | |||
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). | |||
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). | |||
""" | |||
n_batch = query.size(0) | |||
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) | |||
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) | |||
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) | |||
q = q.transpose(1, 2) # (batch, head, time1, d_k) | |||
k = k.transpose(1, 2) # (batch, head, time2, d_k) | |||
v = v.transpose(1, 2) # (batch, head, time2, d_k) | |||
return q, k, v | |||
def forward_attention(self, | |||
value, | |||
scores, | |||
mask, | |||
mask_att_chunk_encoder=None): | |||
"""Compute attention context vector. | |||
Args: | |||
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). | |||
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). | |||
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). | |||
Returns: | |||
torch.Tensor: Transformed value (#batch, time1, d_model) | |||
weighted by the attention score (#batch, time1, time2). | |||
""" | |||
n_batch = value.size(0) | |||
if mask is not None: | |||
if mask_att_chunk_encoder is not None: | |||
mask = mask * mask_att_chunk_encoder | |||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) | |||
min_value = float( | |||
numpy.finfo(torch.tensor( | |||
0, dtype=scores.dtype).numpy().dtype).min) | |||
scores = scores.masked_fill(mask, min_value) | |||
self.attn = torch.softmax( | |||
scores, dim=-1).masked_fill(mask, | |||
0.0) # (batch, head, time1, time2) | |||
else: | |||
self.attn = torch.softmax( | |||
scores, dim=-1) # (batch, head, time1, time2) | |||
p_attn = self.dropout(self.attn) | |||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) | |||
x = (x.transpose(1, 2).contiguous().view(n_batch, -1, | |||
self.h * self.d_k) | |||
) # (batch, time1, d_model) | |||
return self.linear_out(x) # (batch, time1, d_model) | |||
def forward(self, | |||
query, | |||
key, | |||
value, | |||
mask, | |||
mask_shfit_chunk=None, | |||
mask_att_chunk_encoder=None): | |||
"""Compute scaled dot product attention. | |||
Args: | |||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
(#batch, time1, time2). | |||
Returns: | |||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||
""" | |||
fsmn_memory = self.forward_fsmn(value, mask, mask_shfit_chunk) | |||
q, k, v = self.forward_qkv(query, key, value) | |||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) | |||
att_outs = self.forward_attention(v, scores, mask, | |||
mask_att_chunk_encoder) | |||
return att_outs + fsmn_memory | |||
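A minimal usage sketch of the SANM attention block above (assuming the class is importable from this module); it combines standard scaled dot-product attention with the depthwise-convolution FSMN memory:

import torch

attn = MultiHeadedAttentionSANM(n_head=4, n_feat=256, dropout_rate=0.1, kernel_size=15)
x = torch.randn(2, 50, 256)                 # (batch, time, n_feat)
mask = torch.ones(2, 1, 50)                 # (batch, 1, time), 1.0 = keep
out = attn(x, x, x, mask)                   # attention output + FSMN memory
print(out.shape)                            # torch.Size([2, 50, 256])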
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): | |||
"""Multi-Head Attention layer with relative position encoding (old version). | |||
Details can be found in https://github.com/espnet/espnet/pull/2816. | |||
Paper: https://arxiv.org/abs/1901.02860 | |||
Args: | |||
n_head (int): The number of heads. | |||
n_feat (int): The number of features. | |||
dropout_rate (float): Dropout rate. | |||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
""" | |||
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): | |||
"""Construct an RelPositionMultiHeadedAttention object.""" | |||
super().__init__(n_head, n_feat, dropout_rate) | |||
self.zero_triu = zero_triu | |||
# linear transformation for positional encoding | |||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||
# these two learnable bias are used in matrix c and matrix d | |||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||
torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||
def rel_shift(self, x): | |||
"""Compute relative positional encoding. | |||
Args: | |||
x (torch.Tensor): Input tensor (batch, head, time1, time2). | |||
Returns: | |||
torch.Tensor: Output tensor. | |||
""" | |||
zero_pad = torch.zeros((*x.size()[:3], 1), | |||
device=x.device, | |||
dtype=x.dtype) | |||
x_padded = torch.cat([zero_pad, x], dim=-1) | |||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||
x = x_padded[:, :, 1:].view_as(x) | |||
if self.zero_triu: | |||
ones = torch.ones((x.size(2), x.size(3))) | |||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||
return x | |||
def forward(self, query, key, value, pos_emb, mask): | |||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||
Args: | |||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). | |||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
(#batch, time1, time2). | |||
Returns: | |||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||
""" | |||
q, k, v = self.forward_qkv(query, key, value) | |||
q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||
n_batch_pos = pos_emb.size(0) | |||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||
p = p.transpose(1, 2) # (batch, head, time1, d_k) | |||
# (batch, head, time1, d_k) | |||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||
# (batch, head, time1, d_k) | |||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||
# compute attention score | |||
# first compute matrix a and matrix c | |||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
# (batch, head, time1, time2) | |||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||
# compute matrix b and matrix d | |||
# (batch, head, time1, time1) | |||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||
matrix_bd = self.rel_shift(matrix_bd) | |||
scores = (matrix_ac + matrix_bd) / math.sqrt( | |||
self.d_k) # (batch, head, time1, time2) | |||
return self.forward_attention(v, scores, mask) | |||
class LegacyRelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM): | |||
"""Multi-Head Attention layer with relative position encoding (old version). | |||
Details can be found in https://github.com/espnet/espnet/pull/2816. | |||
Paper: https://arxiv.org/abs/1901.02860 | |||
Args: | |||
n_head (int): The number of heads. | |||
n_feat (int): The number of features. | |||
dropout_rate (float): Dropout rate. | |||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
""" | |||
def __init__(self, | |||
n_head, | |||
n_feat, | |||
dropout_rate, | |||
zero_triu=False, | |||
kernel_size=15, | |||
sanm_shfit=0): | |||
"""Construct an RelPositionMultiHeadedAttention object.""" | |||
super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit) | |||
self.zero_triu = zero_triu | |||
# linear transformation for positional encoding | |||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||
# these two learnable bias are used in matrix c and matrix d | |||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||
torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||
def rel_shift(self, x): | |||
"""Compute relative positional encoding. | |||
Args: | |||
x (torch.Tensor): Input tensor (batch, head, time1, time2). | |||
Returns: | |||
torch.Tensor: Output tensor. | |||
""" | |||
zero_pad = torch.zeros((*x.size()[:3], 1), | |||
device=x.device, | |||
dtype=x.dtype) | |||
x_padded = torch.cat([zero_pad, x], dim=-1) | |||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||
x = x_padded[:, :, 1:].view_as(x) | |||
if self.zero_triu: | |||
ones = torch.ones((x.size(2), x.size(3))) | |||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||
return x | |||
def forward(self, query, key, value, pos_emb, mask): | |||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||
Args: | |||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). | |||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
(#batch, time1, time2). | |||
Returns: | |||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||
""" | |||
fsmn_memory = self.forward_fsmn(value, mask) | |||
q, k, v = self.forward_qkv(query, key, value) | |||
q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||
n_batch_pos = pos_emb.size(0) | |||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||
p = p.transpose(1, 2) # (batch, head, time1, d_k) | |||
# (batch, head, time1, d_k) | |||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||
# (batch, head, time1, d_k) | |||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||
# compute attention score | |||
# first compute matrix a and matrix c | |||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
# (batch, head, time1, time2) | |||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||
# compute matrix b and matrix d | |||
# (batch, head, time1, time1) | |||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||
matrix_bd = self.rel_shift(matrix_bd) | |||
scores = (matrix_ac + matrix_bd) / math.sqrt( | |||
self.d_k) # (batch, head, time1, time2) | |||
att_outs = self.forward_attention(v, scores, mask) | |||
return att_outs + fsmn_memory | |||
class RelPositionMultiHeadedAttention(MultiHeadedAttention): | |||
"""Multi-Head Attention layer with relative position encoding (new implementation). | |||
Details can be found in https://github.com/espnet/espnet/pull/2816. | |||
Paper: https://arxiv.org/abs/1901.02860 | |||
Args: | |||
n_head (int): The number of heads. | |||
n_feat (int): The number of features. | |||
dropout_rate (float): Dropout rate. | |||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
""" | |||
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): | |||
"""Construct an RelPositionMultiHeadedAttention object.""" | |||
super().__init__(n_head, n_feat, dropout_rate) | |||
self.zero_triu = zero_triu | |||
# linear transformation for positional encoding | |||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||
# these two learnable bias are used in matrix c and matrix d | |||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||
torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||
def rel_shift(self, x): | |||
"""Compute relative positional encoding. | |||
Args: | |||
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). | |||
time1 means the length of query vector. | |||
Returns: | |||
torch.Tensor: Output tensor. | |||
""" | |||
zero_pad = torch.zeros((*x.size()[:3], 1), | |||
device=x.device, | |||
dtype=x.dtype) | |||
x_padded = torch.cat([zero_pad, x], dim=-1) | |||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||
x = x_padded[:, :, 1:].view_as( | |||
x)[:, :, :, :x.size(-1) // 2 | |||
+ 1] # only keep the positions from 0 to time2 | |||
if self.zero_triu: | |||
ones = torch.ones((x.size(2), x.size(3)), device=x.device) | |||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||
return x | |||
def forward(self, query, key, value, pos_emb, mask): | |||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||
Args: | |||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||
pos_emb (torch.Tensor): Positional embedding tensor | |||
(#batch, 2*time1-1, size). | |||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
(#batch, time1, time2). | |||
Returns: | |||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||
""" | |||
q, k, v = self.forward_qkv(query, key, value) | |||
q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||
n_batch_pos = pos_emb.size(0) | |||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) | |||
# (batch, head, time1, d_k) | |||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||
# (batch, head, time1, d_k) | |||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||
# compute attention score | |||
# first compute matrix a and matrix c | |||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
# (batch, head, time1, time2) | |||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||
# compute matrix b and matrix d | |||
# (batch, head, time1, 2*time1-1) | |||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||
matrix_bd = self.rel_shift(matrix_bd) | |||
scores = (matrix_ac + matrix_bd) / math.sqrt( | |||
self.d_k) # (batch, head, time1, time2) | |||
return self.forward_attention(v, scores, mask) | |||
class RelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM): | |||
"""Multi-Head Attention layer with relative position encoding (new implementation). | |||
Details can be found in https://github.com/espnet/espnet/pull/2816. | |||
Paper: https://arxiv.org/abs/1901.02860 | |||
Args: | |||
n_head (int): The number of heads. | |||
n_feat (int): The number of features. | |||
dropout_rate (float): Dropout rate. | |||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
""" | |||
def __init__(self, | |||
n_head, | |||
n_feat, | |||
dropout_rate, | |||
zero_triu=False, | |||
kernel_size=15, | |||
sanm_shfit=0): | |||
"""Construct an RelPositionMultiHeadedAttention object.""" | |||
super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit) | |||
self.zero_triu = zero_triu | |||
# linear transformation for positional encoding | |||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||
# these two learnable bias are used in matrix c and matrix d | |||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||
torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||
def rel_shift(self, x): | |||
"""Compute relative positional encoding. | |||
Args: | |||
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). | |||
time1 means the length of query vector. | |||
Returns: | |||
torch.Tensor: Output tensor. | |||
""" | |||
zero_pad = torch.zeros((*x.size()[:3], 1), | |||
device=x.device, | |||
dtype=x.dtype) | |||
x_padded = torch.cat([zero_pad, x], dim=-1) | |||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||
x = x_padded[:, :, 1:].view_as( | |||
x)[:, :, :, :x.size(-1) // 2 | |||
+ 1] # only keep the positions from 0 to time2 | |||
if self.zero_triu: | |||
ones = torch.ones((x.size(2), x.size(3)), device=x.device) | |||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||
return x | |||
def forward(self, query, key, value, pos_emb, mask): | |||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||
Args: | |||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||
pos_emb (torch.Tensor): Positional embedding tensor | |||
(#batch, 2*time1-1, size). | |||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
(#batch, time1, time2). | |||
Returns: | |||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||
""" | |||
fsmn_memory = self.forward_fsmn(value, mask) | |||
q, k, v = self.forward_qkv(query, key, value) | |||
q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||
n_batch_pos = pos_emb.size(0) | |||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) | |||
# (batch, head, time1, d_k) | |||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||
# (batch, head, time1, d_k) | |||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||
# compute attention score | |||
# first compute matrix a and matrix c | |||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
# (batch, head, time1, time2) | |||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||
# compute matrix b and matrix d | |||
# (batch, head, time1, 2*time1-1) | |||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||
matrix_bd = self.rel_shift(matrix_bd) | |||
scores = (matrix_ac + matrix_bd) / math.sqrt( | |||
self.d_k) # (batch, head, time1, time2) | |||
att_outs = self.forward_attention(v, scores, mask) | |||
return att_outs + fsmn_memory |
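The relative-position variants take an extra positional-embedding tensor. A minimal sketch (assuming the classes are importable from this module; `pos_emb` uses the (#batch, 2*time1-1, size) layout described in the docstring):

import torch

attn = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 50, 256)
pos_emb = torch.randn(1, 2 * 50 - 1, 256)   # relative positional embeddings
mask = torch.ones(2, 1, 50)
out = attn(x, x, x, pos_emb, mask)
print(out.shape)                            # torch.Size([2, 50, 256])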
@@ -1,239 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed from espnet/espnet. | |||
"""Encoder self-attention layer definition.""" | |||
import torch | |||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||
from torch import nn | |||
class EncoderLayer(nn.Module): | |||
"""Encoder layer module. | |||
Args: | |||
size (int): Input dimension. | |||
self_attn (torch.nn.Module): Self-attention module instance. | |||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance | |||
can be used as the argument. | |||
feed_forward (torch.nn.Module): Feed-forward module instance. | |||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance | |||
can be used as the argument. | |||
dropout_rate (float): Dropout rate. | |||
normalize_before (bool): Whether to use layer_norm before the first block. | |||
concat_after (bool): Whether to concat attention layer's input and output. | |||
if True, additional linear will be applied. | |||
i.e. x -> x + linear(concat(x, att(x))) | |||
if False, no additional linear will be applied. i.e. x -> x + att(x) | |||
        stochastic_depth_rate (float): Probability of skipping this layer. | |||
            During training, the layer may skip the residual computation and return | |||
            its input as-is with the given probability. | |||
""" | |||
def __init__( | |||
self, | |||
size, | |||
self_attn, | |||
feed_forward, | |||
dropout_rate, | |||
normalize_before=True, | |||
concat_after=False, | |||
stochastic_depth_rate=0.0, | |||
): | |||
"""Construct an EncoderLayer object.""" | |||
super(EncoderLayer, self).__init__() | |||
self.self_attn = self_attn | |||
self.feed_forward = feed_forward | |||
self.norm1 = LayerNorm(size) | |||
self.norm2 = LayerNorm(size) | |||
self.dropout = nn.Dropout(dropout_rate) | |||
self.size = size | |||
self.normalize_before = normalize_before | |||
self.concat_after = concat_after | |||
if self.concat_after: | |||
self.concat_linear = nn.Linear(size + size, size) | |||
self.stochastic_depth_rate = stochastic_depth_rate | |||
def forward(self, x, mask, cache=None): | |||
"""Compute encoded features. | |||
Args: | |||
            x (torch.Tensor): Input tensor (#batch, time, size). | |||
mask (torch.Tensor): Mask tensor for the input (#batch, time). | |||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). | |||
Returns: | |||
torch.Tensor: Output tensor (#batch, time, size). | |||
torch.Tensor: Mask tensor (#batch, time). | |||
""" | |||
skip_layer = False | |||
# with stochastic depth, residual connection `x + f(x)` becomes | |||
# `x <- x + 1 / (1 - p) * f(x)` at training time. | |||
stoch_layer_coeff = 1.0 | |||
if self.training and self.stochastic_depth_rate > 0: | |||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate | |||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) | |||
if skip_layer: | |||
if cache is not None: | |||
x = torch.cat([cache, x], dim=1) | |||
return x, mask | |||
residual = x | |||
if self.normalize_before: | |||
x = self.norm1(x) | |||
if cache is None: | |||
x_q = x | |||
else: | |||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) | |||
x_q = x[:, -1:, :] | |||
residual = residual[:, -1:, :] | |||
mask = None if mask is None else mask[:, -1:, :] | |||
if self.concat_after: | |||
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) | |||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat) | |||
else: | |||
x = residual + stoch_layer_coeff * self.dropout( | |||
self.self_attn(x_q, x, x, mask)) | |||
if not self.normalize_before: | |||
x = self.norm1(x) | |||
residual = x | |||
if self.normalize_before: | |||
x = self.norm2(x) | |||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) | |||
if not self.normalize_before: | |||
x = self.norm2(x) | |||
if cache is not None: | |||
x = torch.cat([cache, x], dim=1) | |||
return x, mask | |||
class EncoderLayerChunk(nn.Module): | |||
"""Encoder layer module. | |||
Args: | |||
size (int): Input dimension. | |||
self_attn (torch.nn.Module): Self-attention module instance. | |||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance | |||
can be used as the argument. | |||
feed_forward (torch.nn.Module): Feed-forward module instance. | |||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance | |||
can be used as the argument. | |||
dropout_rate (float): Dropout rate. | |||
normalize_before (bool): Whether to use layer_norm before the first block. | |||
concat_after (bool): Whether to concat attention layer's input and output. | |||
if True, additional linear will be applied. | |||
i.e. x -> x + linear(concat(x, att(x))) | |||
if False, no additional linear will be applied. i.e. x -> x + att(x) | |||
        stochastic_depth_rate (float): Probability of skipping this layer. | |||
            During training, the layer may skip the residual computation and return | |||
            its input as-is with the given probability. | |||
""" | |||
def __init__( | |||
self, | |||
size, | |||
self_attn, | |||
feed_forward, | |||
dropout_rate, | |||
normalize_before=True, | |||
concat_after=False, | |||
stochastic_depth_rate=0.0, | |||
): | |||
"""Construct an EncoderLayer object.""" | |||
super(EncoderLayerChunk, self).__init__() | |||
self.self_attn = self_attn | |||
self.feed_forward = feed_forward | |||
self.norm1 = LayerNorm(size) | |||
self.norm2 = LayerNorm(size) | |||
self.dropout = nn.Dropout(dropout_rate) | |||
self.size = size | |||
self.normalize_before = normalize_before | |||
self.concat_after = concat_after | |||
if self.concat_after: | |||
self.concat_linear = nn.Linear(size + size, size) | |||
self.stochastic_depth_rate = stochastic_depth_rate | |||
def forward(self, | |||
x, | |||
mask, | |||
cache=None, | |||
mask_shfit_chunk=None, | |||
mask_att_chunk_encoder=None): | |||
"""Compute encoded features. | |||
Args: | |||
            x (torch.Tensor): Input tensor (#batch, time, size). | |||
mask (torch.Tensor): Mask tensor for the input (#batch, time). | |||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). | |||
Returns: | |||
torch.Tensor: Output tensor (#batch, time, size). | |||
torch.Tensor: Mask tensor (#batch, time). | |||
""" | |||
skip_layer = False | |||
# with stochastic depth, residual connection `x + f(x)` becomes | |||
# `x <- x + 1 / (1 - p) * f(x)` at training time. | |||
stoch_layer_coeff = 1.0 | |||
if self.training and self.stochastic_depth_rate > 0: | |||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate | |||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) | |||
if skip_layer: | |||
if cache is not None: | |||
x = torch.cat([cache, x], dim=1) | |||
return x, mask | |||
residual = x | |||
if self.normalize_before: | |||
x = self.norm1(x) | |||
if cache is None: | |||
x_q = x | |||
else: | |||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) | |||
x_q = x[:, -1:, :] | |||
residual = residual[:, -1:, :] | |||
mask = None if mask is None else mask[:, -1:, :] | |||
if self.concat_after: | |||
x_concat = torch.cat( | |||
(x, | |||
self.self_attn( | |||
x_q, | |||
x, | |||
x, | |||
mask, | |||
mask_shfit_chunk=mask_shfit_chunk, | |||
mask_att_chunk_encoder=mask_att_chunk_encoder)), | |||
dim=-1) | |||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat) | |||
else: | |||
x = residual + stoch_layer_coeff * self.dropout( | |||
self.self_attn( | |||
x_q, | |||
x, | |||
x, | |||
mask, | |||
mask_shfit_chunk=mask_shfit_chunk, | |||
mask_att_chunk_encoder=mask_att_chunk_encoder)) | |||
if not self.normalize_before: | |||
x = self.norm1(x) | |||
residual = x | |||
if self.normalize_before: | |||
x = self.norm2(x) | |||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) | |||
if not self.normalize_before: | |||
x = self.norm2(x) | |||
if cache is not None: | |||
x = torch.cat([cache, x], dim=1) | |||
return x, mask, None, mask_shfit_chunk, mask_att_chunk_encoder |
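A minimal sketch wiring the encoder layer above to a self-attention module and a feed-forward stand-in (assumptions: `EncoderLayer` and `MultiHeadedAttention` are importable from the modules in this change; the `nn.Sequential` feed-forward is a hypothetical placeholder, not the original positionwise module):

import torch
from torch import nn

size = 256
self_attn = MultiHeadedAttention(n_head=4, n_feat=size, dropout_rate=0.1)
feed_forward = nn.Sequential(nn.Linear(size, 1024), nn.ReLU(), nn.Linear(1024, size))
layer = EncoderLayer(size, self_attn, feed_forward, dropout_rate=0.1)

x = torch.randn(2, 50, size)
mask = torch.ones(2, 1, 50)
y, y_mask = layer(x, mask)
print(y.shape)                              # torch.Size([2, 50, 256])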
@@ -1,890 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed from espnet/espnet. | |||
import argparse | |||
import logging | |||
import os | |||
from pathlib import Path | |||
from typing import Callable, Collection, Dict, List, Optional, Tuple, Union | |||
import numpy as np | |||
import torch | |||
import yaml | |||
from espnet2.asr.ctc import CTC | |||
from espnet2.asr.decoder.abs_decoder import AbsDecoder | |||
from espnet2.asr.decoder.mlm_decoder import MLMDecoder | |||
from espnet2.asr.decoder.rnn_decoder import RNNDecoder | |||
from espnet2.asr.decoder.transformer_decoder import \ | |||
DynamicConvolution2DTransformerDecoder # noqa: H301 | |||
from espnet2.asr.decoder.transformer_decoder import \ | |||
LightweightConvolution2DTransformerDecoder # noqa: H301 | |||
from espnet2.asr.decoder.transformer_decoder import \ | |||
LightweightConvolutionTransformerDecoder # noqa: H301 | |||
from espnet2.asr.decoder.transformer_decoder import ( | |||
DynamicConvolutionTransformerDecoder, TransformerDecoder) | |||
from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||
from espnet2.asr.encoder.contextual_block_conformer_encoder import \ | |||
ContextualBlockConformerEncoder # noqa: H301 | |||
from espnet2.asr.encoder.contextual_block_transformer_encoder import \ | |||
ContextualBlockTransformerEncoder # noqa: H301 | |||
from espnet2.asr.encoder.hubert_encoder import (FairseqHubertEncoder, | |||
FairseqHubertPretrainEncoder) | |||
from espnet2.asr.encoder.longformer_encoder import LongformerEncoder | |||
from espnet2.asr.encoder.rnn_encoder import RNNEncoder | |||
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder | |||
from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder | |||
from espnet2.asr.encoder.wav2vec2_encoder import FairSeqWav2Vec2Encoder | |||
from espnet2.asr.espnet_model import ESPnetASRModel | |||
from espnet2.asr.frontend.abs_frontend import AbsFrontend | |||
from espnet2.asr.frontend.default import DefaultFrontend | |||
from espnet2.asr.frontend.fused import FusedFrontends | |||
from espnet2.asr.frontend.s3prl import S3prlFrontend | |||
from espnet2.asr.frontend.windowing import SlidingWindow | |||
from espnet2.asr.maskctc_model import MaskCTCModel | |||
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder | |||
from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \ | |||
HuggingFaceTransformersPostEncoder # noqa: H301 | |||
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder | |||
from espnet2.asr.preencoder.linear import LinearProjection | |||
from espnet2.asr.preencoder.sinc import LightweightSincConvs | |||
from espnet2.asr.specaug.abs_specaug import AbsSpecAug | |||
from espnet2.asr.specaug.specaug import SpecAug | |||
from espnet2.asr.transducer.joint_network import JointNetwork | |||
from espnet2.asr.transducer.transducer_decoder import TransducerDecoder | |||
from espnet2.layers.abs_normalize import AbsNormalize | |||
from espnet2.layers.global_mvn import GlobalMVN | |||
from espnet2.layers.utterance_mvn import UtteranceMVN | |||
from espnet2.tasks.abs_task import AbsTask | |||
from espnet2.text.phoneme_tokenizer import g2p_choices | |||
from espnet2.torch_utils.initialize import initialize | |||
from espnet2.train.abs_espnet_model import AbsESPnetModel | |||
from espnet2.train.class_choices import ClassChoices | |||
from espnet2.train.collate_fn import CommonCollateFn | |||
from espnet2.train.preprocessor import CommonPreprocessor | |||
from espnet2.train.trainer import Trainer | |||
from espnet2.utils.get_default_kwargs import get_default_kwargs | |||
from espnet2.utils.nested_dict_action import NestedDictAction | |||
from espnet2.utils.types import (float_or_none, int_or_none, str2bool, | |||
str_or_none) | |||
from typeguard import check_argument_types, check_return_type | |||
from ..asr.decoder.transformer_decoder import (ParaformerDecoder, | |||
ParaformerDecoderBertEmbed) | |||
from ..asr.encoder.conformer_encoder import ConformerEncoder, SANMEncoder_v2 | |||
from ..asr.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunk | |||
from ..asr.espnet_model import AEDStreaming | |||
from ..asr.espnet_model_paraformer import Paraformer, ParaformerBertEmbed | |||
from ..nets.pytorch_backend.cif_utils.cif import cif_predictor | |||
# FIXME(wjm): as suggested by fairseq, we need to set up the root logger before importing any fairseq libraries. | |||
logging.basicConfig( | |||
level='INFO', | |||
format=f"[{os.uname()[1].split('.')[0]}]" | |||
f' %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', | |||
) | |||
# FIXME(wjm): create the root logger to set the level; __name__ is left unset so that different files share the same logger | |||
logger = logging.getLogger() | |||
frontend_choices = ClassChoices( | |||
name='frontend', | |||
classes=dict( | |||
default=DefaultFrontend, | |||
sliding_window=SlidingWindow, | |||
s3prl=S3prlFrontend, | |||
fused=FusedFrontends, | |||
), | |||
type_check=AbsFrontend, | |||
default='default', | |||
) | |||
specaug_choices = ClassChoices( | |||
name='specaug', | |||
classes=dict(specaug=SpecAug, ), | |||
type_check=AbsSpecAug, | |||
default=None, | |||
optional=True, | |||
) | |||
normalize_choices = ClassChoices( | |||
'normalize', | |||
classes=dict( | |||
global_mvn=GlobalMVN, | |||
utterance_mvn=UtteranceMVN, | |||
), | |||
type_check=AbsNormalize, | |||
default='utterance_mvn', | |||
optional=True, | |||
) | |||
model_choices = ClassChoices( | |||
'model', | |||
classes=dict( | |||
espnet=ESPnetASRModel, | |||
maskctc=MaskCTCModel, | |||
paraformer=Paraformer, | |||
paraformer_bert_embed=ParaformerBertEmbed, | |||
aedstreaming=AEDStreaming, | |||
), | |||
type_check=AbsESPnetModel, | |||
default='espnet', | |||
) | |||
preencoder_choices = ClassChoices( | |||
name='preencoder', | |||
classes=dict( | |||
sinc=LightweightSincConvs, | |||
linear=LinearProjection, | |||
), | |||
type_check=AbsPreEncoder, | |||
default=None, | |||
optional=True, | |||
) | |||
encoder_choices = ClassChoices( | |||
'encoder', | |||
classes=dict( | |||
conformer=ConformerEncoder, | |||
transformer=TransformerEncoder, | |||
contextual_block_transformer=ContextualBlockTransformerEncoder, | |||
contextual_block_conformer=ContextualBlockConformerEncoder, | |||
vgg_rnn=VGGRNNEncoder, | |||
rnn=RNNEncoder, | |||
wav2vec2=FairSeqWav2Vec2Encoder, | |||
hubert=FairseqHubertEncoder, | |||
hubert_pretrain=FairseqHubertPretrainEncoder, | |||
longformer=LongformerEncoder, | |||
sanm=SANMEncoder, | |||
sanm_v2=SANMEncoder_v2, | |||
sanm_chunk=SANMEncoderChunk, | |||
), | |||
type_check=AbsEncoder, | |||
default='rnn', | |||
) | |||
postencoder_choices = ClassChoices( | |||
name='postencoder', | |||
classes=dict( | |||
hugging_face_transformers=HuggingFaceTransformersPostEncoder, ), | |||
type_check=AbsPostEncoder, | |||
default=None, | |||
optional=True, | |||
) | |||
decoder_choices = ClassChoices( | |||
'decoder', | |||
classes=dict( | |||
transformer=TransformerDecoder, | |||
lightweight_conv=LightweightConvolutionTransformerDecoder, | |||
lightweight_conv2d=LightweightConvolution2DTransformerDecoder, | |||
dynamic_conv=DynamicConvolutionTransformerDecoder, | |||
dynamic_conv2d=DynamicConvolution2DTransformerDecoder, | |||
rnn=RNNDecoder, | |||
transducer=TransducerDecoder, | |||
mlm=MLMDecoder, | |||
paraformer_decoder=ParaformerDecoder, | |||
paraformer_decoder_bert_embed=ParaformerDecoderBertEmbed, | |||
), | |||
type_check=AbsDecoder, | |||
default='rnn', | |||
) | |||
predictor_choices = ClassChoices( | |||
name='predictor', | |||
classes=dict( | |||
cif_predictor=cif_predictor, | |||
ctc_predictor=None, | |||
), | |||
type_check=None, | |||
default='cif_predictor', | |||
optional=True, | |||
) | |||
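The ClassChoices tables above are what turn CLI/config strings into classes; `build_model()` below then instantiates the selected class with the matching `*_conf` dict. A short sketch (the chosen names are examples only):

encoder_cls = encoder_choices.get_class('sanm')                   # -> SANMEncoder
decoder_cls = decoder_choices.get_class('paraformer_decoder')     # -> ParaformerDecoder
# build_model() then does, e.g.:
#   encoder = encoder_cls(input_size=input_size, **args.encoder_conf)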
class ASRTask(AbsTask): | |||
# If you need more than one optimizers, change this value | |||
num_optimizers: int = 1 | |||
# Add variable objects configurations | |||
class_choices_list = [ | |||
# --frontend and --frontend_conf | |||
frontend_choices, | |||
# --specaug and --specaug_conf | |||
specaug_choices, | |||
# --normalize and --normalize_conf | |||
normalize_choices, | |||
# --model and --model_conf | |||
model_choices, | |||
# --preencoder and --preencoder_conf | |||
preencoder_choices, | |||
# --encoder and --encoder_conf | |||
encoder_choices, | |||
# --postencoder and --postencoder_conf | |||
postencoder_choices, | |||
# --decoder and --decoder_conf | |||
decoder_choices, | |||
] | |||
# If you need to modify train() or eval() procedures, change Trainer class here | |||
trainer = Trainer | |||
@classmethod | |||
def add_task_arguments(cls, parser: argparse.ArgumentParser): | |||
group = parser.add_argument_group(description='Task related') | |||
        # NOTE(kamo): add_arguments(..., required=True) can't be used | |||
        # to provide --print_config mode. Instead, it is handled as follows. | |||
required = parser.get_default('required') | |||
required += ['token_list'] | |||
group.add_argument( | |||
'--token_list', | |||
type=str_or_none, | |||
default=None, | |||
help='A text mapping int-id to token', | |||
) | |||
group.add_argument( | |||
'--init', | |||
type=lambda x: str_or_none(x.lower()), | |||
default=None, | |||
help='The initialization method', | |||
choices=[ | |||
'chainer', | |||
'xavier_uniform', | |||
'xavier_normal', | |||
'kaiming_uniform', | |||
'kaiming_normal', | |||
None, | |||
], | |||
) | |||
group.add_argument( | |||
'--input_size', | |||
type=int_or_none, | |||
default=None, | |||
help='The number of input dimension of the feature', | |||
) | |||
group.add_argument( | |||
'--ctc_conf', | |||
action=NestedDictAction, | |||
default=get_default_kwargs(CTC), | |||
help='The keyword arguments for CTC class.', | |||
) | |||
group.add_argument( | |||
'--joint_net_conf', | |||
action=NestedDictAction, | |||
default=None, | |||
help='The keyword arguments for joint network class.', | |||
) | |||
group = parser.add_argument_group(description='Preprocess related') | |||
group.add_argument( | |||
'--use_preprocessor', | |||
type=str2bool, | |||
default=True, | |||
help='Apply preprocessing to data or not', | |||
) | |||
group.add_argument( | |||
'--token_type', | |||
type=str, | |||
default='bpe', | |||
choices=['bpe', 'char', 'word', 'phn'], | |||
            help='The text will be tokenized ' | |||
            'at the specified token level', | |||
) | |||
group.add_argument( | |||
'--bpemodel', | |||
type=str_or_none, | |||
default=None, | |||
help='The model file of sentencepiece', | |||
) | |||
parser.add_argument( | |||
'--non_linguistic_symbols', | |||
type=str_or_none, | |||
help='non_linguistic_symbols file path', | |||
) | |||
parser.add_argument( | |||
'--cleaner', | |||
type=str_or_none, | |||
choices=[None, 'tacotron', 'jaconv', 'vietnamese'], | |||
default=None, | |||
help='Apply text cleaning', | |||
) | |||
parser.add_argument( | |||
'--g2p', | |||
type=str_or_none, | |||
choices=g2p_choices, | |||
default=None, | |||
help='Specify g2p method if --token_type=phn', | |||
) | |||
parser.add_argument( | |||
'--speech_volume_normalize', | |||
type=float_or_none, | |||
default=None, | |||
help='Scale the maximum amplitude to the given value.', | |||
) | |||
parser.add_argument( | |||
'--rir_scp', | |||
type=str_or_none, | |||
default=None, | |||
help='The file path of rir scp file.', | |||
) | |||
parser.add_argument( | |||
'--rir_apply_prob', | |||
type=float, | |||
default=1.0, | |||
            help='The probability of applying RIR convolution.', | |||
) | |||
parser.add_argument( | |||
'--noise_scp', | |||
type=str_or_none, | |||
default=None, | |||
help='The file path of noise scp file.', | |||
) | |||
parser.add_argument( | |||
'--noise_apply_prob', | |||
type=float, | |||
default=1.0, | |||
            help='The probability of applying additive noise.', | |||
) | |||
parser.add_argument( | |||
'--noise_db_range', | |||
type=str, | |||
default='13_15', | |||
help='The range of noise decibel level.', | |||
) | |||
for class_choices in cls.class_choices_list: | |||
# Append --<name> and --<name>_conf. | |||
# e.g. --encoder and --encoder_conf | |||
class_choices.add_arguments(group) | |||
@classmethod | |||
def build_collate_fn( | |||
cls, args: argparse.Namespace, train: bool | |||
) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[ | |||
List[str], Dict[str, torch.Tensor]], ]: | |||
assert check_argument_types() | |||
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol | |||
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) | |||
@classmethod | |||
def build_preprocess_fn( | |||
cls, args: argparse.Namespace, train: bool | |||
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: | |||
assert check_argument_types() | |||
if args.use_preprocessor: | |||
retval = CommonPreprocessor( | |||
train=train, | |||
token_type=args.token_type, | |||
token_list=args.token_list, | |||
bpemodel=args.bpemodel, | |||
non_linguistic_symbols=args.non_linguistic_symbols, | |||
text_cleaner=args.cleaner, | |||
g2p_type=args.g2p, | |||
# NOTE(kamo): Check attribute existence for backward compatibility | |||
rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None, | |||
rir_apply_prob=args.rir_apply_prob if hasattr( | |||
args, 'rir_apply_prob') else 1.0, | |||
noise_scp=args.noise_scp | |||
if hasattr(args, 'noise_scp') else None, | |||
noise_apply_prob=args.noise_apply_prob if hasattr( | |||
args, 'noise_apply_prob') else 1.0, | |||
noise_db_range=args.noise_db_range if hasattr( | |||
args, 'noise_db_range') else '13_15', | |||
                speech_volume_normalize=args.speech_volume_normalize | |||
                if hasattr(args, 'speech_volume_normalize') else None, | |||
) | |||
else: | |||
retval = None | |||
assert check_return_type(retval) | |||
return retval | |||
@classmethod | |||
def required_data_names(cls, | |||
train: bool = True, | |||
inference: bool = False) -> Tuple[str, ...]: | |||
if not inference: | |||
retval = ('speech', 'text') | |||
else: | |||
# Recognition mode | |||
retval = ('speech', ) | |||
return retval | |||
@classmethod | |||
def optional_data_names(cls, | |||
train: bool = True, | |||
inference: bool = False) -> Tuple[str, ...]: | |||
retval = () | |||
assert check_return_type(retval) | |||
return retval | |||
@classmethod | |||
def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel: | |||
assert check_argument_types() | |||
if isinstance(args.token_list, str): | |||
with open(args.token_list, encoding='utf-8') as f: | |||
token_list = [line.rstrip() for line in f] | |||
# Overwriting token_list to keep it as "portable". | |||
args.token_list = list(token_list) | |||
elif isinstance(args.token_list, (tuple, list)): | |||
token_list = list(args.token_list) | |||
else: | |||
raise RuntimeError('token_list must be str or list') | |||
vocab_size = len(token_list) | |||
        logger.info(f'Vocabulary size: {vocab_size}') | |||
# 1. frontend | |||
if args.input_size is None: | |||
# Extract features in the model | |||
frontend_class = frontend_choices.get_class(args.frontend) | |||
frontend = frontend_class(**args.frontend_conf) | |||
input_size = frontend.output_size() | |||
else: | |||
# Give features from data-loader | |||
args.frontend = None | |||
args.frontend_conf = {} | |||
frontend = None | |||
input_size = args.input_size | |||
# 2. Data augmentation for spectrogram | |||
if args.specaug is not None: | |||
specaug_class = specaug_choices.get_class(args.specaug) | |||
specaug = specaug_class(**args.specaug_conf) | |||
else: | |||
specaug = None | |||
# 3. Normalization layer | |||
if args.normalize is not None: | |||
normalize_class = normalize_choices.get_class(args.normalize) | |||
normalize = normalize_class(**args.normalize_conf) | |||
else: | |||
normalize = None | |||
# 4. Pre-encoder input block | |||
# NOTE(kan-bayashi): Use getattr to keep the compatibility | |||
if getattr(args, 'preencoder', None) is not None: | |||
preencoder_class = preencoder_choices.get_class(args.preencoder) | |||
preencoder = preencoder_class(**args.preencoder_conf) | |||
input_size = preencoder.output_size() | |||
else: | |||
preencoder = None | |||
# 4. Encoder | |||
encoder_class = encoder_choices.get_class(args.encoder) | |||
encoder = encoder_class(input_size=input_size, **args.encoder_conf) | |||
# 5. Post-encoder block | |||
# NOTE(kan-bayashi): Use getattr to keep the compatibility | |||
encoder_output_size = encoder.output_size() | |||
if getattr(args, 'postencoder', None) is not None: | |||
postencoder_class = postencoder_choices.get_class(args.postencoder) | |||
postencoder = postencoder_class( | |||
input_size=encoder_output_size, **args.postencoder_conf) | |||
encoder_output_size = postencoder.output_size() | |||
else: | |||
postencoder = None | |||
# 5. Decoder | |||
decoder_class = decoder_choices.get_class(args.decoder) | |||
if args.decoder == 'transducer': | |||
decoder = decoder_class( | |||
vocab_size, | |||
embed_pad=0, | |||
**args.decoder_conf, | |||
) | |||
joint_network = JointNetwork( | |||
vocab_size, | |||
encoder.output_size(), | |||
decoder.dunits, | |||
**args.joint_net_conf, | |||
) | |||
else: | |||
decoder = decoder_class( | |||
vocab_size=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
**args.decoder_conf, | |||
) | |||
joint_network = None | |||
# 6. CTC | |||
ctc = CTC( | |||
odim=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
**args.ctc_conf) | |||
# 7. Build model | |||
try: | |||
model_class = model_choices.get_class(args.model) | |||
except AttributeError: | |||
model_class = model_choices.get_class('espnet') | |||
model = model_class( | |||
vocab_size=vocab_size, | |||
frontend=frontend, | |||
specaug=specaug, | |||
normalize=normalize, | |||
preencoder=preencoder, | |||
encoder=encoder, | |||
postencoder=postencoder, | |||
decoder=decoder, | |||
ctc=ctc, | |||
joint_network=joint_network, | |||
token_list=token_list, | |||
**args.model_conf, | |||
) | |||
# FIXME(kamo): Should be done in model? | |||
# 8. Initialize | |||
if args.init is not None: | |||
initialize(model, args.init) | |||
assert check_return_type(model) | |||
return model | |||
class ASRTaskNAR(AbsTask): | |||
# If you need more than one optimizers, change this value | |||
num_optimizers: int = 1 | |||
# Add variable objects configurations | |||
class_choices_list = [ | |||
# --frontend and --frontend_conf | |||
frontend_choices, | |||
# --specaug and --specaug_conf | |||
specaug_choices, | |||
# --normalize and --normalize_conf | |||
normalize_choices, | |||
# --model and --model_conf | |||
model_choices, | |||
# --preencoder and --preencoder_conf | |||
preencoder_choices, | |||
# --encoder and --encoder_conf | |||
encoder_choices, | |||
# --postencoder and --postencoder_conf | |||
postencoder_choices, | |||
# --decoder and --decoder_conf | |||
decoder_choices, | |||
# --predictor and --predictor_conf | |||
predictor_choices, | |||
] | |||
# If you need to modify train() or eval() procedures, change Trainer class here | |||
trainer = Trainer | |||
@classmethod | |||
def add_task_arguments(cls, parser: argparse.ArgumentParser): | |||
group = parser.add_argument_group(description='Task related') | |||
# NOTE(kamo): add_arguments(..., required=True) can't be used | |||
# to provide --print_config mode. Instead of it, do as | |||
required = parser.get_default('required') | |||
required += ['token_list'] | |||
group.add_argument( | |||
'--token_list', | |||
type=str_or_none, | |||
default=None, | |||
help='A text mapping int-id to token', | |||
) | |||
group.add_argument( | |||
'--init', | |||
type=lambda x: str_or_none(x.lower()), | |||
default=None, | |||
help='The initialization method', | |||
choices=[ | |||
'chainer', | |||
'xavier_uniform', | |||
'xavier_normal', | |||
'kaiming_uniform', | |||
'kaiming_normal', | |||
None, | |||
], | |||
) | |||
group.add_argument( | |||
'--input_size', | |||
type=int_or_none, | |||
default=None, | |||
help='The number of input dimension of the feature', | |||
) | |||
group.add_argument( | |||
'--ctc_conf', | |||
action=NestedDictAction, | |||
default=get_default_kwargs(CTC), | |||
help='The keyword arguments for CTC class.', | |||
) | |||
group.add_argument( | |||
'--joint_net_conf', | |||
action=NestedDictAction, | |||
default=None, | |||
help='The keyword arguments for joint network class.', | |||
) | |||
group = parser.add_argument_group(description='Preprocess related') | |||
group.add_argument( | |||
'--use_preprocessor', | |||
type=str2bool, | |||
default=True, | |||
help='Apply preprocessing to data or not', | |||
) | |||
group.add_argument( | |||
'--token_type', | |||
type=str, | |||
default='bpe', | |||
choices=['bpe', 'char', 'word', 'phn'], | |||
help='The text will be tokenized ' | |||
'at the specified token level', | |||
) | |||
group.add_argument( | |||
'--bpemodel', | |||
type=str_or_none, | |||
default=None, | |||
help='The model file of sentencepiece', | |||
) | |||
parser.add_argument( | |||
'--non_linguistic_symbols', | |||
type=str_or_none, | |||
help='non_linguistic_symbols file path', | |||
) | |||
parser.add_argument( | |||
'--cleaner', | |||
type=str_or_none, | |||
choices=[None, 'tacotron', 'jaconv', 'vietnamese'], | |||
default=None, | |||
help='Apply text cleaning', | |||
) | |||
parser.add_argument( | |||
'--g2p', | |||
type=str_or_none, | |||
choices=g2p_choices, | |||
default=None, | |||
help='Specify g2p method if --token_type=phn', | |||
) | |||
parser.add_argument( | |||
'--speech_volume_normalize', | |||
type=float_or_none, | |||
default=None, | |||
help='Scale the maximum amplitude to the given value.', | |||
) | |||
parser.add_argument( | |||
'--rir_scp', | |||
type=str_or_none, | |||
default=None, | |||
help='The file path of rir scp file.', | |||
) | |||
parser.add_argument( | |||
'--rir_apply_prob', | |||
type=float, | |||
default=1.0, | |||
help='The probability of applying RIR convolution.', | |||
) | |||
parser.add_argument( | |||
'--noise_scp', | |||
type=str_or_none, | |||
default=None, | |||
help='The file path of noise scp file.', | |||
) | |||
parser.add_argument( | |||
'--noise_apply_prob', | |||
type=float, | |||
default=1.0, | |||
help='The probability of applying noise addition.', | |||
) | |||
parser.add_argument( | |||
'--noise_db_range', | |||
type=str, | |||
default='13_15', | |||
help='The range of noise decibel level.', | |||
) | |||
for class_choices in cls.class_choices_list: | |||
# Append --<name> and --<name>_conf. | |||
# e.g. --encoder and --encoder_conf | |||
class_choices.add_arguments(group) | |||
@classmethod | |||
def build_collate_fn( | |||
cls, args: argparse.Namespace, train: bool | |||
) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[ | |||
List[str], Dict[str, torch.Tensor]], ]: | |||
assert check_argument_types() | |||
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol | |||
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) | |||
@classmethod | |||
def build_preprocess_fn( | |||
cls, args: argparse.Namespace, train: bool | |||
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: | |||
assert check_argument_types() | |||
if args.use_preprocessor: | |||
retval = CommonPreprocessor( | |||
train=train, | |||
token_type=args.token_type, | |||
token_list=args.token_list, | |||
bpemodel=args.bpemodel, | |||
non_linguistic_symbols=args.non_linguistic_symbols, | |||
text_cleaner=args.cleaner, | |||
g2p_type=args.g2p, | |||
# NOTE(kamo): Check attribute existence for backward compatibility | |||
rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None, | |||
rir_apply_prob=args.rir_apply_prob if hasattr( | |||
args, 'rir_apply_prob') else 1.0, | |||
noise_scp=args.noise_scp | |||
if hasattr(args, 'noise_scp') else None, | |||
noise_apply_prob=args.noise_apply_prob if hasattr( | |||
args, 'noise_apply_prob') else 1.0, | |||
noise_db_range=args.noise_db_range if hasattr( | |||
args, 'noise_db_range') else '13_15', | |||
speech_volume_normalize=args.speech_volume_normalize | |||
if hasattr(args, 'speech_volume_normalize') else None, | |||
) | |||
else: | |||
retval = None | |||
assert check_return_type(retval) | |||
return retval | |||
@classmethod | |||
def required_data_names(cls, | |||
train: bool = True, | |||
inference: bool = False) -> Tuple[str, ...]: | |||
if not inference: | |||
retval = ('speech', 'text') | |||
else: | |||
# Recognition mode | |||
retval = ('speech', ) | |||
return retval | |||
@classmethod | |||
def optional_data_names(cls, | |||
train: bool = True, | |||
inference: bool = False) -> Tuple[str, ...]: | |||
retval = () | |||
assert check_return_type(retval) | |||
return retval | |||
@classmethod | |||
def build_model(cls, args: argparse.Namespace): | |||
assert check_argument_types() | |||
if isinstance(args.token_list, str): | |||
with open(args.token_list, encoding='utf-8') as f: | |||
token_list = [line.rstrip() for line in f] | |||
# Overwriting token_list to keep it as "portable". | |||
args.token_list = list(token_list) | |||
elif isinstance(args.token_list, (tuple, list)): | |||
token_list = list(args.token_list) | |||
else: | |||
raise RuntimeError('token_list must be str or list') | |||
vocab_size = len(token_list) | |||
# logger.info(f'Vocabulary size: {vocab_size }') | |||
# 1. frontend | |||
if args.input_size is None: | |||
# Extract features in the model | |||
frontend_class = frontend_choices.get_class(args.frontend) | |||
frontend = frontend_class(**args.frontend_conf) | |||
input_size = frontend.output_size() | |||
else: | |||
# Give features from data-loader | |||
args.frontend = None | |||
args.frontend_conf = {} | |||
frontend = None | |||
input_size = args.input_size | |||
# 2. Data augmentation for spectrogram | |||
if args.specaug is not None: | |||
specaug_class = specaug_choices.get_class(args.specaug) | |||
specaug = specaug_class(**args.specaug_conf) | |||
else: | |||
specaug = None | |||
# 3. Normalization layer | |||
if args.normalize is not None: | |||
normalize_class = normalize_choices.get_class(args.normalize) | |||
normalize = normalize_class(**args.normalize_conf) | |||
else: | |||
normalize = None | |||
# 4. Pre-encoder input block | |||
# NOTE(kan-bayashi): Use getattr to keep the compatibility | |||
if getattr(args, 'preencoder', None) is not None: | |||
preencoder_class = preencoder_choices.get_class(args.preencoder) | |||
preencoder = preencoder_class(**args.preencoder_conf) | |||
input_size = preencoder.output_size() | |||
else: | |||
preencoder = None | |||
# 4. Encoder | |||
encoder_class = encoder_choices.get_class(args.encoder) | |||
encoder = encoder_class(input_size=input_size, **args.encoder_conf) | |||
# 5. Post-encoder block | |||
# NOTE(kan-bayashi): Use getattr to keep the compatibility | |||
encoder_output_size = encoder.output_size() | |||
if getattr(args, 'postencoder', None) is not None: | |||
postencoder_class = postencoder_choices.get_class(args.postencoder) | |||
postencoder = postencoder_class( | |||
input_size=encoder_output_size, **args.postencoder_conf) | |||
encoder_output_size = postencoder.output_size() | |||
else: | |||
postencoder = None | |||
# 5. Decoder | |||
decoder_class = decoder_choices.get_class(args.decoder) | |||
if args.decoder == 'transducer': | |||
decoder = decoder_class( | |||
vocab_size, | |||
embed_pad=0, | |||
**args.decoder_conf, | |||
) | |||
joint_network = JointNetwork( | |||
vocab_size, | |||
encoder.output_size(), | |||
decoder.dunits, | |||
**args.joint_net_conf, | |||
) | |||
else: | |||
decoder = decoder_class( | |||
vocab_size=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
**args.decoder_conf, | |||
) | |||
joint_network = None | |||
# 6. CTC | |||
ctc = CTC( | |||
odim=vocab_size, | |||
encoder_output_size=encoder_output_size, | |||
**args.ctc_conf) | |||
predictor_class = predictor_choices.get_class(args.predictor) | |||
predictor = predictor_class(**args.predictor_conf) | |||
# 7. Build model | |||
try: | |||
model_class = model_choices.get_class(args.model) | |||
except AttributeError: | |||
model_class = model_choices.get_class('espnet') | |||
model = model_class( | |||
vocab_size=vocab_size, | |||
frontend=frontend, | |||
specaug=specaug, | |||
normalize=normalize, | |||
preencoder=preencoder, | |||
encoder=encoder, | |||
postencoder=postencoder, | |||
decoder=decoder, | |||
ctc=ctc, | |||
joint_network=joint_network, | |||
token_list=token_list, | |||
predictor=predictor, | |||
**args.model_conf, | |||
) | |||
# FIXME(kamo): Should be done in model? | |||
# 8. Initialize | |||
if args.init is not None: | |||
initialize(model, args.init) | |||
assert check_return_type(model) | |||
return model |
@@ -1,223 +0,0 @@ | |||
import os | |||
import shutil | |||
import threading | |||
from typing import Any, Dict, List, Sequence, Tuple, Union | |||
import yaml | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models import Model | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import WavToScp | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .asr_engine.common import asr_utils | |||
logger = get_logger() | |||
__all__ = ['AutomaticSpeechRecognitionPipeline'] | |||
@PIPELINES.register_module( | |||
Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) | |||
class AutomaticSpeechRecognitionPipeline(Pipeline): | |||
"""ASR Pipeline | |||
""" | |||
def __init__(self, | |||
model: Union[List[Model], List[str]] = None, | |||
preprocessor: WavToScp = None, | |||
**kwargs): | |||
"""use `model` and `preprocessor` to create an asr pipeline for prediction | |||
""" | |||
from .asr_engine import asr_env_checking | |||
assert model is not None, 'asr model should be provided' | |||
model_list: List = [] | |||
if isinstance(model[0], Model): | |||
model_list = model | |||
else: | |||
model_list.append(Model.from_pretrained(model[0])) | |||
if len(model) == 2 and model[1] is not None: | |||
model_list.append(Model.from_pretrained(model[1])) | |||
super().__init__(model=model_list, preprocessor=preprocessor, **kwargs) | |||
self._preprocessor = preprocessor | |||
self._am_model = model_list[0] | |||
if len(model_list) == 2 and model_list[1] is not None: | |||
self._lm_model = model_list[1] | |||
def __call__(self, | |||
wav_path: str, | |||
recog_type: str = None, | |||
audio_format: str = None, | |||
workspace: str = None) -> Dict[str, Any]: | |||
assert len(wav_path) > 0, 'wav_path should be provided' | |||
self._recog_type = recog_type | |||
self._audio_format = audio_format | |||
self._workspace = workspace | |||
self._wav_path = wav_path | |||
if recog_type is None or audio_format is None or workspace is None: | |||
self._recog_type, self._audio_format, self._workspace, self._wav_path = asr_utils.type_checking( | |||
wav_path, recog_type, audio_format, workspace) | |||
if self._preprocessor is None: | |||
self._preprocessor = WavToScp(workspace=self._workspace) | |||
output = self._preprocessor.forward(self._am_model.forward(), | |||
self._recog_type, | |||
self._audio_format, self._wav_path) | |||
output = self.forward(output) | |||
rst = self.postprocess(output) | |||
return rst | |||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
"""Decoding | |||
""" | |||
logger.info(f"Decoding with {inputs['audio_format']} files ...") | |||
j: int = 0 | |||
process = [] | |||
while j < inputs['thread_count']: | |||
data_cmd: Sequence[Tuple[str, str, str]] | |||
if inputs['audio_format'] == 'wav': | |||
data_cmd = [(os.path.join(inputs['workspace'], | |||
'data.' + str(j) + '.scp'), 'speech', | |||
'sound')] | |||
elif inputs['audio_format'] == 'kaldi_ark': | |||
data_cmd = [(os.path.join(inputs['workspace'], | |||
'data.' + str(j) + '.scp'), 'speech', | |||
'kaldi_ark')] | |||
output_dir: str = os.path.join(inputs['output'], | |||
'output.' + str(j)) | |||
if not os.path.exists(output_dir): | |||
os.mkdir(output_dir) | |||
config_file = open(inputs['asr_model_config']) | |||
root = yaml.full_load(config_file) | |||
config_file.close() | |||
frontend_conf = None | |||
if 'frontend_conf' in root: | |||
frontend_conf = root['frontend_conf'] | |||
cmd = { | |||
'model_type': inputs['model_type'], | |||
'beam_size': root['beam_size'], | |||
'penalty': root['penalty'], | |||
'maxlenratio': root['maxlenratio'], | |||
'minlenratio': root['minlenratio'], | |||
'ctc_weight': root['ctc_weight'], | |||
'lm_weight': root['lm_weight'], | |||
'output_dir': output_dir, | |||
'ngpu': 0, | |||
'log_level': 'ERROR', | |||
'data_path_and_name_and_type': data_cmd, | |||
'asr_train_config': inputs['am_model_config'], | |||
'asr_model_file': inputs['am_model_path'], | |||
'batch_size': inputs['model_config']['batch_size'], | |||
'frontend_conf': frontend_conf | |||
} | |||
thread = AsrInferenceThread(j, cmd) | |||
thread.start() | |||
j += 1 | |||
process.append(thread) | |||
for p in process: | |||
p.join() | |||
return inputs | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
"""process the asr results | |||
""" | |||
logger.info('Computing the result of ASR ...') | |||
rst = {'rec_result': 'None'} | |||
# single wav task | |||
if inputs['recog_type'] == 'wav' and inputs['audio_format'] == 'wav': | |||
text_file: str = os.path.join(inputs['output'], 'output.0', | |||
'1best_recog', 'text') | |||
if os.path.exists(text_file): | |||
f = open(text_file, 'r') | |||
result_str: str = f.readline() | |||
f.close() | |||
if len(result_str) > 0: | |||
result_list = result_str.split() | |||
if len(result_list) >= 2: | |||
rst['rec_result'] = result_list[1] | |||
# run with datasets, and audio format is waveform or kaldi_ark | |||
elif inputs['recog_type'] != 'wav': | |||
inputs['reference_text'] = self._ref_text_tidy(inputs) | |||
inputs['datasets_result'] = asr_utils.compute_wer( | |||
inputs['hypothesis_text'], inputs['reference_text']) | |||
else: | |||
raise ValueError('recog_type and audio_format are mismatching') | |||
if 'datasets_result' in inputs: | |||
rst['datasets_result'] = inputs['datasets_result'] | |||
# remove workspace dir (.tmp) | |||
if os.path.exists(self._workspace): | |||
shutil.rmtree(self._workspace) | |||
return rst | |||
def _ref_text_tidy(self, inputs: Dict[str, Any]) -> str: | |||
ref_text: str = os.path.join(inputs['output'], 'text.ref') | |||
k: int = 0 | |||
while k < inputs['thread_count']: | |||
output_text = os.path.join(inputs['output'], 'output.' + str(k), | |||
'1best_recog', 'text') | |||
if os.path.exists(output_text): | |||
with open(output_text, 'r', encoding='utf-8') as i: | |||
lines = i.readlines() | |||
with open(ref_text, 'a', encoding='utf-8') as o: | |||
for line in lines: | |||
o.write(line) | |||
k += 1 | |||
return ref_text | |||
class AsrInferenceThread(threading.Thread): | |||
def __init__(self, threadID, cmd): | |||
threading.Thread.__init__(self) | |||
self._threadID = threadID | |||
self._cmd = cmd | |||
def run(self): | |||
if self._cmd['model_type'] == 'pytorch': | |||
from .asr_engine import asr_inference_paraformer_espnet | |||
asr_inference_paraformer_espnet.asr_inference( | |||
batch_size=self._cmd['batch_size'], | |||
output_dir=self._cmd['output_dir'], | |||
maxlenratio=self._cmd['maxlenratio'], | |||
minlenratio=self._cmd['minlenratio'], | |||
beam_size=self._cmd['beam_size'], | |||
ngpu=self._cmd['ngpu'], | |||
ctc_weight=self._cmd['ctc_weight'], | |||
lm_weight=self._cmd['lm_weight'], | |||
penalty=self._cmd['penalty'], | |||
log_level=self._cmd['log_level'], | |||
data_path_and_name_and_type=self. | |||
_cmd['data_path_and_name_and_type'], | |||
asr_train_config=self._cmd['asr_train_config'], | |||
asr_model_file=self._cmd['asr_model_file'], | |||
frontend_conf=self._cmd['frontend_conf']) |
@@ -0,0 +1,213 @@ | |||
import os | |||
from typing import Any, Dict, List, Sequence, Tuple, Union | |||
import yaml | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models import Model | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import WavToScp | |||
from modelscope.utils.constant import Frameworks, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
__all__ = ['AutomaticSpeechRecognitionPipeline'] | |||
@PIPELINES.register_module( | |||
Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) | |||
class AutomaticSpeechRecognitionPipeline(Pipeline): | |||
"""ASR Inference Pipeline | |||
""" | |||
def __init__(self, | |||
model: Union[Model, str] = None, | |||
preprocessor: WavToScp = None, | |||
**kwargs): | |||
"""use `model` and `preprocessor` to create an asr pipeline for prediction | |||
""" | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
def __call__(self, | |||
audio_in: Union[str, bytes], | |||
recog_type: str = None, | |||
audio_format: str = None) -> Dict[str, Any]: | |||
from easyasr.common import asr_utils | |||
self.recog_type = recog_type | |||
self.audio_format = audio_format | |||
self.audio_in = audio_in | |||
if recog_type is None or audio_format is None: | |||
self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | |||
audio_in, recog_type, audio_format) | |||
if self.preprocessor is None: | |||
self.preprocessor = WavToScp() | |||
output = self.preprocessor.forward(self.model.forward(), | |||
self.recog_type, self.audio_format, | |||
self.audio_in) | |||
output = self.forward(output) | |||
rst = self.postprocess(output) | |||
return rst | |||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
"""Decoding | |||
""" | |||
logger.info(f"Decoding with {inputs['audio_format']} files ...") | |||
data_cmd: Sequence[Tuple[str, str]] | |||
if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm': | |||
data_cmd = ['speech', 'sound'] | |||
elif inputs['audio_format'] == 'kaldi_ark': | |||
data_cmd = ['speech', 'kaldi_ark'] | |||
elif inputs['audio_format'] == 'tfrecord': | |||
data_cmd = ['speech', 'tfrecord'] | |||
# generate asr inference command | |||
cmd = { | |||
'model_type': inputs['model_type'], | |||
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available | |||
'log_level': 'ERROR', | |||
'audio_in': inputs['audio_lists'], | |||
'name_and_type': data_cmd, | |||
'asr_model_file': inputs['am_model_path'], | |||
'idx_text': '' | |||
} | |||
if self.framework == Frameworks.torch: | |||
config_file = open(inputs['asr_model_config']) | |||
root = yaml.full_load(config_file) | |||
config_file.close() | |||
frontend_conf = None | |||
if 'frontend_conf' in root: | |||
frontend_conf = root['frontend_conf'] | |||
cmd['beam_size'] = root['beam_size'] | |||
cmd['penalty'] = root['penalty'] | |||
cmd['maxlenratio'] = root['maxlenratio'] | |||
cmd['minlenratio'] = root['minlenratio'] | |||
cmd['ctc_weight'] = root['ctc_weight'] | |||
cmd['lm_weight'] = root['lm_weight'] | |||
cmd['asr_train_config'] = inputs['am_model_config'] | |||
cmd['batch_size'] = inputs['model_config']['batch_size'] | |||
cmd['frontend_conf'] = frontend_conf | |||
elif self.framework == Frameworks.tf: | |||
cmd['fs'] = inputs['model_config']['fs'] | |||
cmd['hop_length'] = inputs['model_config']['hop_length'] | |||
cmd['feature_dims'] = inputs['model_config']['feature_dims'] | |||
cmd['predictions_file'] = 'text' | |||
cmd['mvn_file'] = inputs['am_mvn_file'] | |||
cmd['vocab_file'] = inputs['vocab_file'] | |||
if 'idx_text' in inputs: | |||
cmd['idx_text'] = inputs['idx_text'] | |||
else: | |||
raise ValueError('model type is mismatched') | |||
inputs['asr_result'] = self.run_inference(cmd) | |||
return inputs | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
"""process the asr results | |||
""" | |||
from easyasr.common import asr_utils | |||
logger.info('Computing the result of ASR ...') | |||
rst = {} | |||
# single wav or pcm task | |||
if inputs['recog_type'] == 'wav': | |||
if 'asr_result' in inputs and len(inputs['asr_result']) > 0: | |||
text = inputs['asr_result'][0]['value'] | |||
if len(text) > 0: | |||
rst[OutputKeys.TEXT] = text | |||
# run with datasets, and audio format is waveform or kaldi_ark or tfrecord | |||
elif inputs['recog_type'] != 'wav': | |||
inputs['reference_list'] = self.ref_list_tidy(inputs) | |||
inputs['datasets_result'] = asr_utils.compute_wer( | |||
inputs['asr_result'], inputs['reference_list']) | |||
else: | |||
raise ValueError('recog_type and audio_format are mismatched') | |||
if 'datasets_result' in inputs: | |||
rst[OutputKeys.TEXT] = inputs['datasets_result'] | |||
return rst | |||
def ref_list_tidy(self, inputs: Dict[str, Any]) -> List[Any]: | |||
ref_list = [] | |||
if inputs['audio_format'] == 'tfrecord': | |||
# should assemble idx + txt | |||
with open(inputs['reference_text'], 'r', encoding='utf-8') as r: | |||
text_lines = r.readlines() | |||
with open(inputs['idx_text'], 'r', encoding='utf-8') as i: | |||
idx_lines = i.readlines() | |||
j: int = 0 | |||
while j < min(len(text_lines), len(idx_lines)): | |||
idx_str = idx_lines[j].strip() | |||
text_str = text_lines[j].strip().replace(' ', '') | |||
item = {'key': idx_str, 'value': text_str} | |||
ref_list.append(item) | |||
j += 1 | |||
else: | |||
# each line contains idx + sentence | |||
with open(inputs['reference_text'], 'r', encoding='utf-8') as f: | |||
lines = f.readlines() | |||
for line in lines: | |||
line_item = line.split() | |||
item = {'key': line_item[0], 'value': line_item[1]} | |||
ref_list.append(item) | |||
return ref_list | |||
def run_inference(self, cmd): | |||
asr_result = [] | |||
if self.framework == Frameworks.torch: | |||
from easyasr import asr_inference_paraformer_espnet | |||
asr_result = asr_inference_paraformer_espnet.asr_inference( | |||
batch_size=cmd['batch_size'], | |||
maxlenratio=cmd['maxlenratio'], | |||
minlenratio=cmd['minlenratio'], | |||
beam_size=cmd['beam_size'], | |||
ngpu=cmd['ngpu'], | |||
ctc_weight=cmd['ctc_weight'], | |||
lm_weight=cmd['lm_weight'], | |||
penalty=cmd['penalty'], | |||
log_level=cmd['log_level'], | |||
name_and_type=cmd['name_and_type'], | |||
audio_lists=cmd['audio_in'], | |||
asr_train_config=cmd['asr_train_config'], | |||
asr_model_file=cmd['asr_model_file'], | |||
frontend_conf=cmd['frontend_conf']) | |||
elif self.framework == Frameworks.tf: | |||
from easyasr import asr_inference_paraformer_tf | |||
asr_result = asr_inference_paraformer_tf.asr_inference( | |||
ngpu=cmd['ngpu'], | |||
name_and_type=cmd['name_and_type'], | |||
audio_lists=cmd['audio_in'], | |||
idx_text_file=cmd['idx_text'], | |||
asr_model_file=cmd['asr_model_file'], | |||
vocab_file=cmd['vocab_file'], | |||
am_mvn_file=cmd['mvn_file'], | |||
predictions_file=cmd['predictions_file'], | |||
fs=cmd['fs'], | |||
hop_length=cmd['hop_length'], | |||
feature_dims=cmd['feature_dims']) | |||
return asr_result |
@@ -1,14 +1,9 @@ | |||
import io | |||
import os | |||
import shutil | |||
from pathlib import Path | |||
from typing import Any, Dict, List | |||
import yaml | |||
from typing import Any, Dict, List, Union | |||
from modelscope.metainfo import Preprocessors | |||
from modelscope.models.base import Model | |||
from modelscope.utils.constant import Fields | |||
from modelscope.utils.constant import Fields, Frameworks | |||
from .base import Preprocessor | |||
from .builder import PREPROCESSORS | |||
@@ -19,44 +14,32 @@ __all__ = ['WavToScp'] | |||
Fields.audio, module_name=Preprocessors.wav_to_scp) | |||
class WavToScp(Preprocessor): | |||
"""generate audio scp from wave or ark | |||
Args: | |||
workspace (str): | |||
""" | |||
def __init__(self, workspace: str = None): | |||
# the workspace path | |||
if workspace is None or len(workspace) == 0: | |||
self._workspace = os.path.join(os.getcwd(), '.tmp') | |||
else: | |||
self._workspace = workspace | |||
if not os.path.exists(self._workspace): | |||
os.mkdir(self._workspace) | |||
def __init__(self): | |||
pass | |||
def __call__(self, | |||
model: List[Model] = None, | |||
model: Model = None, | |||
recog_type: str = None, | |||
audio_format: str = None, | |||
wav_path: str = None) -> Dict[str, Any]: | |||
assert len(model) > 0, 'preprocess model is invalid' | |||
assert len(recog_type) > 0, 'preprocess recog_type is empty' | |||
assert len(audio_format) > 0, 'preprocess audio_format is empty' | |||
assert len(wav_path) > 0, 'preprocess wav_path is empty' | |||
self._am_model = model[0] | |||
if len(model) == 2 and model[1] is not None: | |||
self._lm_model = model[1] | |||
out = self.forward(self._am_model.forward(), recog_type, audio_format, | |||
wav_path) | |||
audio_in: Union[str, bytes] = None) -> Dict[str, Any]: | |||
assert model is not None, 'preprocess model is empty' | |||
assert recog_type is not None and len( | |||
recog_type) > 0, 'preprocess recog_type is empty' | |||
assert audio_format is not None, 'preprocess audio_format is empty' | |||
assert audio_in is not None, 'preprocess audio_in is empty' | |||
self.am_model = model | |||
out = self.forward(self.am_model.forward(), recog_type, audio_format, | |||
audio_in) | |||
return out | |||
def forward(self, model: Dict[str, Any], recog_type: str, | |||
audio_format: str, wav_path: str) -> Dict[str, Any]: | |||
audio_format: str, audio_in: Union[str, | |||
bytes]) -> Dict[str, Any]: | |||
assert len(recog_type) > 0, 'preprocess recog_type is empty' | |||
assert len(audio_format) > 0, 'preprocess audio_format is empty' | |||
assert len(wav_path) > 0, 'preprocess wav_path is empty' | |||
assert os.path.exists(wav_path), 'preprocess wav_path does not exist' | |||
assert len( | |||
model['am_model']) > 0, 'preprocess model[am_model] is empty' | |||
assert len(model['am_model_path'] | |||
@@ -70,90 +53,104 @@ class WavToScp(Preprocessor): | |||
assert len(model['model_config'] | |||
) > 0, 'preprocess model[model_config] is empty' | |||
# the am model name | |||
am_model: str = model['am_model'] | |||
# the am model file path | |||
am_model_path: str = model['am_model_path'] | |||
# the recognition model dir path | |||
model_workspace: str = model['model_workspace'] | |||
# the recognition model config dict | |||
global_model_config_dict: str = model['model_config'] | |||
rst = { | |||
'workspace': os.path.join(self._workspace, recog_type), | |||
'am_model': am_model, | |||
'am_model_path': am_model_path, | |||
'model_workspace': model_workspace, | |||
# the recognition model dir path | |||
'model_workspace': model['model_workspace'], | |||
# the am model name | |||
'am_model': model['am_model'], | |||
# the am model file path | |||
'am_model_path': model['am_model_path'], | |||
# the asr type setting, eg: test dev train wav | |||
'recog_type': recog_type, | |||
# the asr audio format setting, eg: wav, kaldi_ark | |||
# the asr audio format setting, eg: wav, pcm, kaldi_ark, tfrecord | |||
'audio_format': audio_format, | |||
# the test wav file path or the dataset path | |||
'wav_path': wav_path, | |||
'model_config': global_model_config_dict | |||
# the recognition model config dict | |||
'model_config': model['model_config'] | |||
} | |||
out = self._config_checking(rst) | |||
out = self._env_setting(out) | |||
if isinstance(audio_in, str): | |||
# wav file path or the dataset path | |||
rst['wav_path'] = audio_in | |||
out = self.config_checking(rst) | |||
out = self.env_setting(out) | |||
if audio_format == 'wav': | |||
out = self._scp_generation_from_wav(out) | |||
out['audio_lists'] = self.scp_generation_from_wav(out) | |||
elif audio_format == 'kaldi_ark': | |||
out = self._scp_generation_from_ark(out) | |||
out['audio_lists'] = self.scp_generation_from_ark(out) | |||
elif audio_format == 'tfrecord': | |||
out['audio_lists'] = os.path.join(out['wav_path'], 'data.records') | |||
elif audio_format == 'pcm': | |||
out['audio_lists'] = audio_in | |||
return out | |||
def _config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
def config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
"""config checking | |||
""" | |||
assert inputs['model_config'].__contains__( | |||
'type'), 'model type does not exist' | |||
assert inputs['model_config'].__contains__( | |||
'batch_size'), 'batch_size does not exist' | |||
assert inputs['model_config'].__contains__( | |||
'am_model_config'), 'am_model_config does not exist' | |||
assert inputs['model_config'].__contains__( | |||
'asr_model_config'), 'asr_model_config does not exist' | |||
assert inputs['model_config'].__contains__( | |||
'asr_model_wav_config'), 'asr_model_wav_config does not exist' | |||
am_model_config: str = os.path.join( | |||
inputs['model_workspace'], | |||
inputs['model_config']['am_model_config']) | |||
assert os.path.exists( | |||
am_model_config), 'am_model_config does not exist' | |||
inputs['am_model_config'] = am_model_config | |||
asr_model_config: str = os.path.join( | |||
inputs['model_workspace'], | |||
inputs['model_config']['asr_model_config']) | |||
assert os.path.exists( | |||
asr_model_config), 'asr_model_config does not exist' | |||
asr_model_wav_config: str = os.path.join( | |||
inputs['model_workspace'], | |||
inputs['model_config']['asr_model_wav_config']) | |||
assert os.path.exists( | |||
asr_model_wav_config), 'asr_model_wav_config does not exist' | |||
inputs['model_type'] = inputs['model_config']['type'] | |||
if inputs['audio_format'] == 'wav': | |||
inputs['asr_model_config'] = asr_model_wav_config | |||
if inputs['model_type'] == Frameworks.torch: | |||
assert inputs['model_config'].__contains__( | |||
'batch_size'), 'batch_size does not exist' | |||
assert inputs['model_config'].__contains__( | |||
'am_model_config'), 'am_model_config does not exist' | |||
assert inputs['model_config'].__contains__( | |||
'asr_model_config'), 'asr_model_config does not exist' | |||
assert inputs['model_config'].__contains__( | |||
'asr_model_wav_config'), 'asr_model_wav_config does not exist' | |||
am_model_config: str = os.path.join( | |||
inputs['model_workspace'], | |||
inputs['model_config']['am_model_config']) | |||
assert os.path.exists( | |||
am_model_config), 'am_model_config does not exist' | |||
inputs['am_model_config'] = am_model_config | |||
asr_model_config: str = os.path.join( | |||
inputs['model_workspace'], | |||
inputs['model_config']['asr_model_config']) | |||
assert os.path.exists( | |||
asr_model_config), 'asr_model_config does not exist' | |||
asr_model_wav_config: str = os.path.join( | |||
inputs['model_workspace'], | |||
inputs['model_config']['asr_model_wav_config']) | |||
assert os.path.exists( | |||
asr_model_wav_config), 'asr_model_wav_config does not exist' | |||
if inputs['audio_format'] == 'wav' or inputs[ | |||
'audio_format'] == 'pcm': | |||
inputs['asr_model_config'] = asr_model_wav_config | |||
else: | |||
inputs['asr_model_config'] = asr_model_config | |||
elif inputs['model_type'] == Frameworks.tf: | |||
assert inputs['model_config'].__contains__( | |||
'vocab_file'), 'vocab_file does not exist' | |||
vocab_file: str = os.path.join( | |||
inputs['model_workspace'], | |||
inputs['model_config']['vocab_file']) | |||
assert os.path.exists(vocab_file), 'vocab file does not exist' | |||
inputs['vocab_file'] = vocab_file | |||
assert inputs['model_config'].__contains__( | |||
'am_mvn_file'), 'am_mvn_file does not exist' | |||
am_mvn_file: str = os.path.join( | |||
inputs['model_workspace'], | |||
inputs['model_config']['am_mvn_file']) | |||
assert os.path.exists(am_mvn_file), 'am mvn file does not exist' | |||
inputs['am_mvn_file'] = am_mvn_file | |||
else: | |||
inputs['asr_model_config'] = asr_model_config | |||
raise ValueError('model type is mismatched') | |||
return inputs | |||
def _env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
if not os.path.exists(inputs['workspace']): | |||
os.mkdir(inputs['workspace']) | |||
inputs['output'] = os.path.join(inputs['workspace'], 'logdir') | |||
if not os.path.exists(inputs['output']): | |||
os.mkdir(inputs['output']) | |||
def env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
# run with datasets, should set datasets_path and text_path | |||
if inputs['recog_type'] != 'wav': | |||
inputs['datasets_path'] = inputs['wav_path'] | |||
@@ -162,25 +159,39 @@ class WavToScp(Preprocessor): | |||
if inputs['audio_format'] == 'wav': | |||
inputs['wav_path'] = os.path.join(inputs['datasets_path'], | |||
'wav', inputs['recog_type']) | |||
inputs['hypothesis_text'] = os.path.join( | |||
inputs['reference_text'] = os.path.join( | |||
inputs['datasets_path'], 'transcript', 'data.text') | |||
assert os.path.exists(inputs['hypothesis_text'] | |||
), 'hypothesis text does not exist' | |||
assert os.path.exists( | |||
inputs['reference_text']), 'reference text does not exist' | |||
# run with datasets, and audio format is kaldi_ark | |||
elif inputs['audio_format'] == 'kaldi_ark': | |||
inputs['wav_path'] = os.path.join(inputs['datasets_path'], | |||
inputs['recog_type']) | |||
inputs['hypothesis_text'] = os.path.join( | |||
inputs['reference_text'] = os.path.join( | |||
inputs['wav_path'], 'data.text') | |||
assert os.path.exists(inputs['hypothesis_text'] | |||
), 'hypothesis text does not exist' | |||
assert os.path.exists( | |||
inputs['reference_text']), 'reference text does not exist' | |||
# run with datasets, and audio format is tfrecord | |||
elif inputs['audio_format'] == 'tfrecord': | |||
inputs['wav_path'] = os.path.join(inputs['datasets_path'], | |||
inputs['recog_type']) | |||
inputs['reference_text'] = os.path.join( | |||
inputs['wav_path'], 'data.txt') | |||
assert os.path.exists( | |||
inputs['reference_text']), 'reference text does not exist' | |||
inputs['idx_text'] = os.path.join(inputs['wav_path'], | |||
'data.idx') | |||
assert os.path.exists( | |||
inputs['idx_text']), 'idx text does not exist' | |||
return inputs | |||
def _scp_generation_from_wav(self, inputs: Dict[str, | |||
Any]) -> Dict[str, Any]: | |||
def scp_generation_from_wav(self, inputs: Dict[str, Any]) -> List[Any]: | |||
"""scp generation from waveform files | |||
""" | |||
from easyasr.common import asr_utils | |||
# find all waveform files | |||
wav_list = [] | |||
@@ -191,64 +202,46 @@ class WavToScp(Preprocessor): | |||
wav_list.append(file_path) | |||
else: | |||
wav_dir: str = inputs['wav_path'] | |||
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) | |||
wav_list = asr_utils.recursion_dir_all_wav(wav_list, wav_dir) | |||
list_count: int = len(wav_list) | |||
inputs['wav_count'] = list_count | |||
# store all wav into data.0.scp | |||
inputs['thread_count'] = 1 | |||
# store all wav files into the audio list | |||
audio_lists = [] | |||
j: int = 0 | |||
wav_list_path = os.path.join(inputs['workspace'], 'data.0.scp') | |||
with open(wav_list_path, 'a') as f: | |||
while j < list_count: | |||
wav_file = wav_list[j] | |||
wave_scp_content: str = os.path.splitext( | |||
os.path.basename(wav_file))[0] | |||
wave_scp_content += ' ' + wav_file + '\n' | |||
f.write(wave_scp_content) | |||
j += 1 | |||
while j < list_count: | |||
wav_file = wav_list[j] | |||
wave_key: str = os.path.splitext(os.path.basename(wav_file))[0] | |||
item = {'key': wave_key, 'file': wav_file} | |||
audio_lists.append(item) | |||
j += 1 | |||
return inputs | |||
return audio_lists | |||
def _scp_generation_from_ark(self, inputs: Dict[str, | |||
Any]) -> Dict[str, Any]: | |||
def scp_generation_from_ark(self, inputs: Dict[str, Any]) -> List[Any]: | |||
"""scp generation from kaldi ark file | |||
""" | |||
inputs['thread_count'] = 1 | |||
ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp') | |||
ark_file_path = os.path.join(inputs['wav_path'], 'data.ark') | |||
assert os.path.exists(ark_scp_path), 'data.scp does not exist' | |||
assert os.path.exists(ark_file_path), 'data.ark does not exist' | |||
new_ark_scp_path = os.path.join(inputs['workspace'], 'data.0.scp') | |||
with open(ark_scp_path, 'r', encoding='utf-8') as f: | |||
lines = f.readlines() | |||
with open(new_ark_scp_path, 'w', encoding='utf-8') as n: | |||
for line in lines: | |||
outs = line.strip().split(' ') | |||
if len(outs) == 2: | |||
key = outs[0] | |||
sub = outs[1].split(':') | |||
if len(sub) == 2: | |||
nums = sub[1] | |||
content = key + ' ' + ark_file_path + ':' + nums + '\n' | |||
n.write(content) | |||
return inputs | |||
def _recursion_dir_all_wave(self, wav_list, | |||
dir_path: str) -> Dict[str, Any]: | |||
dir_files = os.listdir(dir_path) | |||
for file in dir_files: | |||
file_path = os.path.join(dir_path, file) | |||
if os.path.isfile(file_path): | |||
if file_path.endswith('.wav') or file_path.endswith('.WAV'): | |||
wav_list.append(file_path) | |||
elif os.path.isdir(file_path): | |||
self._recursion_dir_all_wave(wav_list, file_path) | |||
return wav_list | |||
# store all ark items into the audio list | |||
audio_lists = [] | |||
for line in lines: | |||
outs = line.strip().split(' ') | |||
if len(outs) == 2: | |||
key = outs[0] | |||
sub = outs[1].split(':') | |||
if len(sub) == 2: | |||
nums = sub[1] | |||
content = ark_file_path + ':' + nums | |||
item = {'key': key, 'file': content} | |||
audio_lists.append(item) | |||
return audio_lists |
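Reviewer note: with the workspace removed, both helpers now return an in-memory list of key/file dicts instead of writing data.N.scp shards to disk. A small sketch of the two shapes (keys and paths are made up for illustration):

    # waveform input: one entry per .wav file found under the dataset path
    audio_lists_wav = [
        {'key': 'asr_example', 'file': '/data/test/asr_example.wav'},
    ]

    # kaldi_ark input: 'file' points into data.ark at the byte offset read from data.scp
    audio_lists_ark = [
        {'key': 'utt_0001', 'file': '/data/test/data.ark:17'},
    ]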
@@ -1,3 +1,4 @@ | |||
easyasr>=0.0.2 | |||
espnet>=202204 | |||
#tts | |||
h5py | |||
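The new easyasr>=0.0.2 dependency supplies the inference entry points and helpers imported by the pipeline and preprocessor above; a quick import sanity check (a sketch that assumes easyasr is installed as pinned here, plus a TensorFlow install for the tf backend):

    # confirm the symbols used by the new pipeline and preprocessor resolve
    from easyasr.common import asr_utils  # type_checking, compute_wer, recursion_dir_all_wav
    from easyasr import asr_inference_paraformer_espnet  # pytorch backend
    from easyasr import asr_inference_paraformer_tf  # tensorflow backend, needs TF installed

    print(callable(asr_utils.compute_wer))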
@@ -1,15 +1,20 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import shutil | |||
import sys | |||
import tarfile | |||
import unittest | |||
from typing import Any, Dict, Union | |||
import numpy as np | |||
import requests | |||
import soundfile | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.constant import ColorCodes, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import test_level | |||
from modelscope.utils.test_utils import download_and_untar, test_level | |||
logger = get_logger() | |||
@@ -21,6 +26,9 @@ LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AS | |||
AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz' | |||
AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz' | |||
TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz' | |||
TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz' | |||
def un_tar_gz(fname, dirs): | |||
t = tarfile.open(fname) | |||
@@ -28,45 +36,168 @@ def un_tar_gz(fname, dirs): | |||
class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||
action_info = { | |||
'test_run_with_wav_pytorch': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'wav_example' | |||
}, | |||
'test_run_with_pcm_pytorch': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'wav_example' | |||
}, | |||
'test_run_with_wav_tf': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'wav_example' | |||
}, | |||
'test_run_with_pcm_tf': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'wav_example' | |||
}, | |||
'test_run_with_wav_dataset_pytorch': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'dataset_example' | |||
}, | |||
'test_run_with_wav_dataset_tf': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'dataset_example' | |||
}, | |||
'test_run_with_ark_dataset': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'dataset_example' | |||
}, | |||
'test_run_with_tfrecord_dataset': { | |||
'checking_item': OutputKeys.TEXT, | |||
'example': 'dataset_example' | |||
}, | |||
'dataset_example': { | |||
'Wrd': 49532, # the number of words | |||
'Snt': 5000, # the number of sentences | |||
'Corr': 47276, # the number of correct words | |||
'Ins': 49, # the number of insert words | |||
'Del': 152, # the number of delete words | |||
'Sub': 2207, # the number of substitution words | |||
'wrong_words': 2408, # the number of wrong words | |||
'wrong_sentences': 1598, # the number of wrong sentences | |||
'Err': 4.86, # WER/CER | |||
'S.Err': 31.96 # SER | |||
}, | |||
'wav_example': { | |||
'text': '每一天都要快乐喔' | |||
} | |||
} | |||
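As a quick cross-check of the 'dataset_example' numbers above (reviewer arithmetic, assuming the usual definitions Err = (Ins + Del + Sub) / Wrd and S.Err = wrong_sentences / Snt; not part of the PR):

    ins, dele, sub = 49, 152, 2207
    wrd, snt, wrong_sentences = 49532, 5000, 1598

    print(ins + dele + sub)                             # 2408  -> matches 'wrong_words'
    print(round(100.0 * (ins + dele + sub) / wrd, 2))   # 4.86  -> matches 'Err'
    print(round(100.0 * wrong_sentences / snt, 2))      # 31.96 -> matches 'S.Err'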
def setUp(self) -> None: | |||
self._am_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' | |||
self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' | |||
self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1' | |||
# this temporary workspace dir will store waveform files | |||
self._workspace = os.path.join(os.getcwd(), '.tmp') | |||
if not os.path.exists(self._workspace): | |||
os.mkdir(self._workspace) | |||
self.workspace = os.path.join(os.getcwd(), '.tmp') | |||
if not os.path.exists(self.workspace): | |||
os.mkdir(self.workspace) | |||
def tearDown(self) -> None: | |||
# remove workspace dir (.tmp) | |||
shutil.rmtree(self.workspace, ignore_errors=True) | |||
def run_pipeline(self, model_id: str, | |||
audio_in: Union[str, bytes]) -> Dict[str, Any]: | |||
inference_16k_pipeline = pipeline( | |||
task=Tasks.auto_speech_recognition, model=model_id) | |||
rec_result = inference_16k_pipeline(audio_in) | |||
return rec_result | |||
def log_error(self, functions: str, result: Dict[str, Any]) -> None: | |||
logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' | |||
+ ColorCodes.END) | |||
logger.error( | |||
ColorCodes.MAGENTA + functions + ' correct result example:' | |||
+ ColorCodes.YELLOW | |||
+ str(self.action_info[self.action_info[functions]['example']]) | |||
+ ColorCodes.END) | |||
raise ValueError('asr result is mismatched') | |||
def check_result(self, functions: str, result: Dict[str, Any]) -> None: | |||
if result.__contains__(self.action_info[functions]['checking_item']): | |||
logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' | |||
+ ColorCodes.END) | |||
logger.info( | |||
ColorCodes.YELLOW | |||
+ str(result[self.action_info[functions]['checking_item']]) | |||
+ ColorCodes.END) | |||
else: | |||
self.log_error(functions, result) | |||
def wav2bytes(self, wav_file) -> bytes: | |||
audio, fs = soundfile.read(wav_file) | |||
# float32 -> int16 | |||
audio = np.asarray(audio) | |||
dtype = np.dtype('int16') | |||
i = np.iinfo(dtype) | |||
abs_max = 2**(i.bits - 1) | |||
offset = i.min + abs_max | |||
audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) | |||
# int16(PCM_16) -> byte | |||
audio = audio.tobytes() | |||
return audio | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_wav(self): | |||
def test_run_with_wav_pytorch(self): | |||
'''run with single waveform file | |||
''' | |||
logger.info('Run ASR test with waveform file ...') | |||
logger.info('Run ASR test with waveform file (pytorch)...') | |||
wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||
inference_16k_pipline = pipeline( | |||
task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||
self.assertTrue(inference_16k_pipline is not None) | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_pytorch_model_id, audio_in=wav_file_path) | |||
self.check_result('test_run_with_wav_pytorch', rec_result) | |||
rec_result = inference_16k_pipline(wav_file_path) | |||
self.assertTrue(len(rec_result['rec_result']) > 0) | |||
self.assertTrue(rec_result['rec_result'] != 'None') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_pcm_pytorch(self): | |||
'''run with wav data | |||
''' | |||
result structure: | |||
{ | |||
'rec_result': '每一天都要快乐喔' | |||
} | |||
or | |||
{ | |||
'rec_result': 'None' | |||
} | |||
logger.info('Run ASR test with wav data (pytorch)...') | |||
audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_pytorch_model_id, audio_in=audio) | |||
self.check_result('test_run_with_pcm_pytorch', rec_result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_wav_tf(self): | |||
'''run with single waveform file | |||
''' | |||
logger.info('test_run_with_wav rec result: ' | |||
+ rec_result['rec_result']) | |||
logger.info('Run ASR test with waveform file (tensorflow)...') | |||
wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_tf_model_id, audio_in=wav_file_path) | |||
self.check_result('test_run_with_wav_tf', rec_result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_pcm_tf(self): | |||
'''run with wav data | |||
''' | |||
logger.info('Run ASR test with wav data (tensorflow)...') | |||
audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_tf_model_id, audio_in=audio) | |||
self.check_result('test_run_with_pcm_tf', rec_result) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_wav_dataset(self): | |||
def test_run_with_wav_dataset_pytorch(self): | |||
'''run with datasets, and audio format is waveform | |||
datasets directory: | |||
<dataset_path> | |||
@@ -84,57 +215,48 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||
data.text # reference text | |||
''' | |||
logger.info('Run ASR test with waveform dataset ...') | |||
logger.info('Run ASR test with waveform dataset (pytorch)...') | |||
logger.info('Downloading waveform testsets file ...') | |||
# downloading pos_testsets file | |||
testsets_file_path = os.path.join(self._workspace, | |||
LITTLE_TESTSETS_FILE) | |||
if not os.path.exists(testsets_file_path): | |||
r = requests.get(LITTLE_TESTSETS_URL) | |||
with open(testsets_file_path, 'wb') as f: | |||
f.write(r.content) | |||
testsets_dir_name = os.path.splitext( | |||
os.path.basename( | |||
os.path.splitext( | |||
os.path.basename(LITTLE_TESTSETS_FILE))[0]))[0] | |||
# dataset_path = <cwd>/.tmp/data_aishell/wav/test | |||
dataset_path = os.path.join(self._workspace, testsets_dir_name, 'wav', | |||
'test') | |||
# untar the dataset_path file | |||
if not os.path.exists(dataset_path): | |||
un_tar_gz(testsets_file_path, self._workspace) | |||
dataset_path = download_and_untar( | |||
os.path.join(self.workspace, LITTLE_TESTSETS_FILE), | |||
LITTLE_TESTSETS_URL, self.workspace) | |||
dataset_path = os.path.join(dataset_path, 'wav', 'test') | |||
inference_16k_pipline = pipeline( | |||
task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||
self.assertTrue(inference_16k_pipline is not None) | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_pytorch_model_id, audio_in=dataset_path) | |||
self.check_result('test_run_with_wav_dataset_pytorch', rec_result) | |||
rec_result = inference_16k_pipline(wav_path=dataset_path) | |||
self.assertTrue(len(rec_result['datasets_result']) > 0) | |||
self.assertTrue(rec_result['datasets_result']['Wrd'] > 0) | |||
''' | |||
result structure: | |||
{ | |||
'rec_result': 'None', | |||
'datasets_result': | |||
{ | |||
'Wrd': 1654, # the number of words | |||
'Snt': 128, # the number of sentences | |||
'Corr': 1573, # the number of correct words | |||
'Ins': 1, # the number of insert words | |||
'Del': 1, # the number of delete words | |||
'Sub': 80, # the number of substitution words | |||
'wrong_words': 82, # the number of wrong words | |||
'wrong_sentences': 47, # the number of wrong sentences | |||
'Err': 4.96, # WER/CER | |||
'S.Err': 36.72 # SER | |||
} | |||
} | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_wav_dataset_tf(self): | |||
'''run with datasets, and audio format is waveform | |||
datasets directory: | |||
<dataset_path> | |||
wav | |||
test # testsets | |||
xx.wav | |||
... | |||
dev # devsets | |||
yy.wav | |||
... | |||
train # trainsets | |||
zz.wav | |||
... | |||
transcript | |||
data.text # reference text | |||
''' | |||
logger.info('test_run_with_wav_dataset datasets result: ') | |||
logger.info(rec_result['datasets_result']) | |||
logger.info('Run ASR test with waveform dataset (tensorflow)...') | |||
logger.info('Downloading waveform testsets file ...') | |||
dataset_path = download_and_untar( | |||
os.path.join(self.workspace, LITTLE_TESTSETS_FILE), | |||
LITTLE_TESTSETS_URL, self.workspace) | |||
dataset_path = os.path.join(dataset_path, 'wav', 'test') | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_tf_model_id, audio_in=dataset_path) | |||
self.check_result('test_run_with_wav_dataset_tf', rec_result) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_ark_dataset(self): | |||
@@ -155,56 +277,40 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||
data.text | |||
''' | |||
logger.info('Run ASR test with ark dataset ...') | |||
logger.info('Run ASR test with ark dataset (pytorch)...') | |||
logger.info('Downloading ark testsets file ...') | |||
# downloading pos_testsets file | |||
testsets_file_path = os.path.join(self._workspace, | |||
AISHELL1_TESTSETS_FILE) | |||
if not os.path.exists(testsets_file_path): | |||
r = requests.get(AISHELL1_TESTSETS_URL) | |||
with open(testsets_file_path, 'wb') as f: | |||
f.write(r.content) | |||
testsets_dir_name = os.path.splitext( | |||
os.path.basename( | |||
os.path.splitext( | |||
os.path.basename(AISHELL1_TESTSETS_FILE))[0]))[0] | |||
# dataset_path = <cwd>/.tmp/aishell1/test | |||
dataset_path = os.path.join(self._workspace, testsets_dir_name, 'test') | |||
# untar the dataset_path file | |||
if not os.path.exists(dataset_path): | |||
un_tar_gz(testsets_file_path, self._workspace) | |||
dataset_path = download_and_untar( | |||
os.path.join(self.workspace, AISHELL1_TESTSETS_FILE), | |||
AISHELL1_TESTSETS_URL, self.workspace) | |||
dataset_path = os.path.join(dataset_path, 'test') | |||
inference_16k_pipline = pipeline( | |||
task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||
self.assertTrue(inference_16k_pipline is not None) | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_pytorch_model_id, audio_in=dataset_path) | |||
self.check_result('test_run_with_ark_dataset', rec_result) | |||
rec_result = inference_16k_pipline(wav_path=dataset_path) | |||
self.assertTrue(len(rec_result['datasets_result']) > 0) | |||
self.assertTrue(rec_result['datasets_result']['Wrd'] > 0) | |||
''' | |||
result structure: | |||
{ | |||
'rec_result': 'None', | |||
'datasets_result': | |||
{ | |||
'Wrd': 104816, # the number of words | |||
'Snt': 7176, # the number of sentences | |||
'Corr': 99327, # the number of correct words | |||
'Ins': 104, # the number of insert words | |||
'Del': 155, # the number of delete words | |||
'Sub': 5334, # the number of substitution words | |||
'wrong_words': 5593, # the number of wrong words | |||
'wrong_sentences': 2898, # the number of wrong sentences | |||
'Err': 5.34, # WER/CER | |||
'S.Err': 40.38 # SER | |||
} | |||
} | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_tfrecord_dataset(self): | |||
'''run with datasets, and audio format is tfrecord | |||
datasets directory: | |||
<dataset_path> | |||
test # testsets | |||
data.records | |||
data.idx | |||
data.text | |||
''' | |||
logger.info('test_run_with_ark_dataset datasets result: ') | |||
logger.info(rec_result['datasets_result']) | |||
logger.info('Run ASR test with tfrecord dataset (tensorflow)...') | |||
logger.info('Downloading tfrecord testsets file ...') | |||
dataset_path = download_and_untar( | |||
os.path.join(self.workspace, TFRECORD_TESTSETS_FILE), | |||
TFRECORD_TESTSETS_URL, self.workspace) | |||
dataset_path = os.path.join(dataset_path, 'test') | |||
rec_result = self.run_pipeline( | |||
model_id=self.am_tf_model_id, audio_in=dataset_path) | |||
self.check_result('test_run_with_tfrecord_dataset', rec_result) | |||
if __name__ == '__main__': | |||