Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9410174master
@@ -20,8 +20,10 @@ class GenericAutomaticSpeechRecognition(Model): | |||||
Args: | Args: | ||||
model_dir (str): the model path. | model_dir (str): the model path. | ||||
am_model_name (str): the am model name from configuration.json | am_model_name (str): the am model name from configuration.json | ||||
model_config (Dict[str, Any]): the detail config about model from configuration.json | |||||
""" | """ | ||||
super().__init__(model_dir, am_model_name, model_config, *args, | |||||
**kwargs) | |||||
self.model_cfg = { | self.model_cfg = { | ||||
# the recognition model dir path | # the recognition model dir path | ||||
'model_workspace': model_dir, | 'model_workspace': model_dir, | ||||
@@ -312,5 +312,11 @@ TASK_OUTPUTS = { | |||||
# { | # { | ||||
# "text": "this is the text generated by a model." | # "text": "this is the text generated by a model." | ||||
# } | # } | ||||
Tasks.visual_question_answering: [OutputKeys.TEXT] | |||||
Tasks.visual_question_answering: [OutputKeys.TEXT], | |||||
# auto_speech_recognition result for a single sample | |||||
# { | |||||
# "text": "每天都要快乐喔" | |||||
# } | |||||
Tasks.auto_speech_recognition: [OutputKeys.TEXT] | |||||
} | } |
@@ -3,7 +3,7 @@ | |||||
from modelscope.utils.error import TENSORFLOW_IMPORT_ERROR | from modelscope.utils.error import TENSORFLOW_IMPORT_ERROR | ||||
try: | try: | ||||
from .asr.asr_inference_pipeline import AutomaticSpeechRecognitionPipeline | |||||
from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline | |||||
from .kws_kwsbp_pipeline import * # noqa F403 | from .kws_kwsbp_pipeline import * # noqa F403 | ||||
from .linear_aec_pipeline import LinearAECPipeline | from .linear_aec_pipeline import LinearAECPipeline | ||||
except ModuleNotFoundError as e: | except ModuleNotFoundError as e: | ||||
@@ -1,21 +0,0 @@ | |||||
import ssl | |||||
import nltk | |||||
try: | |||||
_create_unverified_https_context = ssl._create_unverified_context | |||||
except AttributeError: | |||||
pass | |||||
else: | |||||
ssl._create_default_https_context = _create_unverified_https_context | |||||
try: | |||||
nltk.data.find('taggers/averaged_perceptron_tagger') | |||||
except LookupError: | |||||
nltk.download( | |||||
'averaged_perceptron_tagger', halt_on_error=False, raise_on_error=True) | |||||
try: | |||||
nltk.data.find('corpora/cmudict') | |||||
except LookupError: | |||||
nltk.download('cmudict', halt_on_error=False, raise_on_error=True) |
@@ -1,690 +0,0 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
# Part of the implementation is borrowed from espnet/espnet. | |||||
import argparse | |||||
import logging | |||||
import sys | |||||
import time | |||||
from pathlib import Path | |||||
from typing import Any, Optional, Sequence, Tuple, Union | |||||
import numpy as np | |||||
import torch | |||||
from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer | |||||
from espnet2.asr.transducer.beam_search_transducer import \ | |||||
ExtendedHypothesis as ExtTransHypothesis # noqa: H301 | |||||
from espnet2.asr.transducer.beam_search_transducer import \ | |||||
Hypothesis as TransHypothesis | |||||
from espnet2.fileio.datadir_writer import DatadirWriter | |||||
from espnet2.tasks.lm import LMTask | |||||
from espnet2.text.build_tokenizer import build_tokenizer | |||||
from espnet2.text.token_id_converter import TokenIDConverter | |||||
from espnet2.torch_utils.device_funcs import to_device | |||||
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed | |||||
from espnet2.utils import config_argparse | |||||
from espnet2.utils.types import str2bool, str2triple_str, str_or_none | |||||
from espnet.nets.batch_beam_search import BatchBeamSearch | |||||
from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim | |||||
from espnet.nets.beam_search import BeamSearch, Hypothesis | |||||
from espnet.nets.pytorch_backend.transformer.subsampling import \ | |||||
TooShortUttError | |||||
from espnet.nets.scorer_interface import BatchScorerInterface | |||||
from espnet.nets.scorers.ctc import CTCPrefixScorer | |||||
from espnet.nets.scorers.length_bonus import LengthBonus | |||||
from espnet.utils.cli_utils import get_commandline_args | |||||
from typeguard import check_argument_types | |||||
from .espnet.asr.frontend.wav_frontend import WavFrontend | |||||
from .espnet.tasks.asr import ASRTaskNAR as ASRTask | |||||
class Speech2Text: | |||||
def __init__(self, | |||||
asr_train_config: Union[Path, str] = None, | |||||
asr_model_file: Union[Path, str] = None, | |||||
transducer_conf: dict = None, | |||||
lm_train_config: Union[Path, str] = None, | |||||
lm_file: Union[Path, str] = None, | |||||
ngram_scorer: str = 'full', | |||||
ngram_file: Union[Path, str] = None, | |||||
token_type: str = None, | |||||
bpemodel: str = None, | |||||
device: str = 'cpu', | |||||
maxlenratio: float = 0.0, | |||||
minlenratio: float = 0.0, | |||||
batch_size: int = 1, | |||||
dtype: str = 'float32', | |||||
beam_size: int = 20, | |||||
ctc_weight: float = 0.5, | |||||
lm_weight: float = 1.0, | |||||
ngram_weight: float = 0.9, | |||||
penalty: float = 0.0, | |||||
nbest: int = 1, | |||||
streaming: bool = False, | |||||
frontend_conf: dict = None): | |||||
assert check_argument_types() | |||||
# 1. Build ASR model | |||||
scorers = {} | |||||
asr_model, asr_train_args = ASRTask.build_model_from_file( | |||||
asr_train_config, asr_model_file, device) | |||||
if asr_model.frontend is None and frontend_conf is not None: | |||||
frontend = WavFrontend(**frontend_conf) | |||||
asr_model.frontend = frontend | |||||
asr_model.to(dtype=getattr(torch, dtype)).eval() | |||||
decoder = asr_model.decoder | |||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) | |||||
token_list = asr_model.token_list | |||||
scorers.update( | |||||
decoder=decoder, | |||||
ctc=ctc, | |||||
length_bonus=LengthBonus(len(token_list)), | |||||
) | |||||
# 2. Build Language model | |||||
if lm_train_config is not None: | |||||
lm, lm_train_args = LMTask.build_model_from_file( | |||||
lm_train_config, lm_file, device) | |||||
scorers['lm'] = lm.lm | |||||
# 3. Build ngram model | |||||
if ngram_file is not None: | |||||
if ngram_scorer == 'full': | |||||
from espnet.nets.scorers.ngram import NgramFullScorer | |||||
ngram = NgramFullScorer(ngram_file, token_list) | |||||
else: | |||||
from espnet.nets.scorers.ngram import NgramPartScorer | |||||
ngram = NgramPartScorer(ngram_file, token_list) | |||||
else: | |||||
ngram = None | |||||
scorers['ngram'] = ngram | |||||
# 4. Build BeamSearch object | |||||
if asr_model.use_transducer_decoder: | |||||
beam_search_transducer = BeamSearchTransducer( | |||||
decoder=asr_model.decoder, | |||||
joint_network=asr_model.joint_network, | |||||
beam_size=beam_size, | |||||
lm=scorers['lm'] if 'lm' in scorers else None, | |||||
lm_weight=lm_weight, | |||||
**transducer_conf, | |||||
) | |||||
beam_search = None | |||||
else: | |||||
beam_search_transducer = None | |||||
weights = dict( | |||||
decoder=1.0 - ctc_weight, | |||||
ctc=ctc_weight, | |||||
lm=lm_weight, | |||||
ngram=ngram_weight, | |||||
length_bonus=penalty, | |||||
) | |||||
beam_search = BeamSearch( | |||||
beam_size=beam_size, | |||||
weights=weights, | |||||
scorers=scorers, | |||||
sos=asr_model.sos, | |||||
eos=asr_model.eos, | |||||
vocab_size=len(token_list), | |||||
token_list=token_list, | |||||
pre_beam_score_key=None if ctc_weight == 1.0 else 'full', | |||||
) | |||||
# TODO(karita): make all scorers batchfied | |||||
if batch_size == 1: | |||||
non_batch = [ | |||||
k for k, v in beam_search.full_scorers.items() | |||||
if not isinstance(v, BatchScorerInterface) | |||||
] | |||||
if len(non_batch) == 0: | |||||
if streaming: | |||||
beam_search.__class__ = BatchBeamSearchOnlineSim | |||||
beam_search.set_streaming_config(asr_train_config) | |||||
logging.info( | |||||
'BatchBeamSearchOnlineSim implementation is selected.' | |||||
) | |||||
else: | |||||
beam_search.__class__ = BatchBeamSearch | |||||
else: | |||||
logging.warning( | |||||
f'As non-batch scorers {non_batch} are found, ' | |||||
f'fall back to non-batch implementation.') | |||||
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() | |||||
for scorer in scorers.values(): | |||||
if isinstance(scorer, torch.nn.Module): | |||||
scorer.to( | |||||
device=device, dtype=getattr(torch, dtype)).eval() | |||||
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text | |||||
if token_type is None: | |||||
token_type = asr_train_args.token_type | |||||
if bpemodel is None: | |||||
bpemodel = asr_train_args.bpemodel | |||||
if token_type is None: | |||||
tokenizer = None | |||||
elif token_type == 'bpe': | |||||
if bpemodel is not None: | |||||
tokenizer = build_tokenizer( | |||||
token_type=token_type, bpemodel=bpemodel) | |||||
else: | |||||
tokenizer = None | |||||
else: | |||||
tokenizer = build_tokenizer(token_type=token_type) | |||||
converter = TokenIDConverter(token_list=token_list) | |||||
self.asr_model = asr_model | |||||
self.asr_train_args = asr_train_args | |||||
self.converter = converter | |||||
self.tokenizer = tokenizer | |||||
self.beam_search = beam_search | |||||
self.beam_search_transducer = beam_search_transducer | |||||
self.maxlenratio = maxlenratio | |||||
self.minlenratio = minlenratio | |||||
self.device = device | |||||
self.dtype = dtype | |||||
self.nbest = nbest | |||||
@torch.no_grad() | |||||
def __call__(self, speech: Union[torch.Tensor, np.ndarray]): | |||||
"""Inference | |||||
Args: | |||||
data: Input speech data | |||||
Returns: | |||||
text, token, token_int, hyp | |||||
""" | |||||
assert check_argument_types() | |||||
# Input as audio signal | |||||
if isinstance(speech, np.ndarray): | |||||
speech = torch.tensor(speech) | |||||
# data: (Nsamples,) -> (1, Nsamples) | |||||
speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) | |||||
# lengths: (1,) | |||||
lengths = speech.new_full([1], | |||||
dtype=torch.long, | |||||
fill_value=speech.size(1)) | |||||
batch = {'speech': speech, 'speech_lengths': lengths} | |||||
# a. To device | |||||
batch = to_device(batch, device=self.device) | |||||
# b. Forward Encoder | |||||
enc, enc_len = self.asr_model.encode(**batch) | |||||
if isinstance(enc, tuple): | |||||
enc = enc[0] | |||||
assert len(enc) == 1, len(enc) | |||||
predictor_outs = self.asr_model.calc_predictor(enc, enc_len) | |||||
pre_acoustic_embeds, pre_token_length = predictor_outs[ | |||||
0], predictor_outs[1] | |||||
pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], | |||||
device=pre_acoustic_embeds.device) | |||||
decoder_outs = self.asr_model.cal_decoder_with_predictor( | |||||
enc, enc_len, pre_acoustic_embeds, pre_token_length) | |||||
decoder_out = decoder_outs[0] | |||||
yseq = decoder_out.argmax(dim=-1) | |||||
score = decoder_out.max(dim=-1)[0] | |||||
score = torch.sum(score, dim=-1) | |||||
# pad with mask tokens to ensure compatibility with sos/eos tokens | |||||
yseq = torch.tensor( | |||||
[self.asr_model.sos] + yseq.tolist()[0] + [self.asr_model.eos], | |||||
device=yseq.device) | |||||
nbest_hyps = [Hypothesis(yseq=yseq, score=score)] | |||||
results = [] | |||||
for hyp in nbest_hyps: | |||||
assert isinstance(hyp, (Hypothesis, TransHypothesis)), type(hyp) | |||||
# remove sos/eos and get results | |||||
last_pos = None if self.asr_model.use_transducer_decoder else -1 | |||||
if isinstance(hyp.yseq, list): | |||||
token_int = hyp.yseq[1:last_pos] | |||||
else: | |||||
token_int = hyp.yseq[1:last_pos].tolist() | |||||
# remove blank symbol id, which is assumed to be 0 | |||||
token_int = list(filter(lambda x: x != 0, token_int)) | |||||
# Change integer-ids to tokens | |||||
token = self.converter.ids2tokens(token_int) | |||||
if self.tokenizer is not None: | |||||
text = self.tokenizer.tokens2text(token) | |||||
else: | |||||
text = None | |||||
results.append((text, token, token_int, hyp, speech.size(1))) | |||||
return results | |||||
@staticmethod | |||||
def from_pretrained( | |||||
model_tag: Optional[str] = None, | |||||
**kwargs: Optional[Any], | |||||
): | |||||
"""Build Speech2Text instance from the pretrained model. | |||||
Args: | |||||
model_tag (Optional[str]): Model tag of the pretrained models. | |||||
Currently, the tags of espnet_model_zoo are supported. | |||||
Returns: | |||||
Speech2Text: Speech2Text instance. | |||||
""" | |||||
if model_tag is not None: | |||||
try: | |||||
from espnet_model_zoo.downloader import ModelDownloader | |||||
except ImportError: | |||||
logging.error( | |||||
'`espnet_model_zoo` is not installed. ' | |||||
'Please install via `pip install -U espnet_model_zoo`.') | |||||
raise | |||||
d = ModelDownloader() | |||||
kwargs.update(**d.download_and_unpack(model_tag)) | |||||
return Speech2Text(**kwargs) | |||||
def inference( | |||||
output_dir: str, | |||||
maxlenratio: float, | |||||
minlenratio: float, | |||||
batch_size: int, | |||||
dtype: str, | |||||
beam_size: int, | |||||
ngpu: int, | |||||
seed: int, | |||||
ctc_weight: float, | |||||
lm_weight: float, | |||||
ngram_weight: float, | |||||
penalty: float, | |||||
nbest: int, | |||||
num_workers: int, | |||||
log_level: Union[int, str], | |||||
data_path_and_name_and_type: Sequence[Tuple[str, str, str]], | |||||
key_file: Optional[str], | |||||
asr_train_config: Optional[str], | |||||
asr_model_file: Optional[str], | |||||
lm_train_config: Optional[str], | |||||
lm_file: Optional[str], | |||||
word_lm_train_config: Optional[str], | |||||
word_lm_file: Optional[str], | |||||
ngram_file: Optional[str], | |||||
model_tag: Optional[str], | |||||
token_type: Optional[str], | |||||
bpemodel: Optional[str], | |||||
allow_variable_data_keys: bool, | |||||
transducer_conf: Optional[dict], | |||||
streaming: bool, | |||||
frontend_conf: dict = None, | |||||
): | |||||
assert check_argument_types() | |||||
if batch_size > 1: | |||||
raise NotImplementedError('batch decoding is not implemented') | |||||
if word_lm_train_config is not None: | |||||
raise NotImplementedError('Word LM is not implemented') | |||||
if ngpu > 1: | |||||
raise NotImplementedError('only single GPU decoding is supported') | |||||
logging.basicConfig( | |||||
level=log_level, | |||||
format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', | |||||
) | |||||
if ngpu >= 1: | |||||
device = 'cuda' | |||||
else: | |||||
device = 'cpu' | |||||
# 1. Set random-seed | |||||
set_all_random_seed(seed) | |||||
# 2. Build speech2text | |||||
speech2text_kwargs = dict( | |||||
asr_train_config=asr_train_config, | |||||
asr_model_file=asr_model_file, | |||||
transducer_conf=transducer_conf, | |||||
lm_train_config=lm_train_config, | |||||
lm_file=lm_file, | |||||
ngram_file=ngram_file, | |||||
token_type=token_type, | |||||
bpemodel=bpemodel, | |||||
device=device, | |||||
maxlenratio=maxlenratio, | |||||
minlenratio=minlenratio, | |||||
dtype=dtype, | |||||
beam_size=beam_size, | |||||
ctc_weight=ctc_weight, | |||||
lm_weight=lm_weight, | |||||
ngram_weight=ngram_weight, | |||||
penalty=penalty, | |||||
nbest=nbest, | |||||
streaming=streaming, | |||||
frontend_conf=frontend_conf, | |||||
) | |||||
speech2text = Speech2Text.from_pretrained( | |||||
model_tag=model_tag, | |||||
**speech2text_kwargs, | |||||
) | |||||
# 3. Build data-iterator | |||||
loader = ASRTask.build_streaming_iterator( | |||||
data_path_and_name_and_type, | |||||
dtype=dtype, | |||||
batch_size=batch_size, | |||||
key_file=key_file, | |||||
num_workers=num_workers, | |||||
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, | |||||
False), | |||||
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), | |||||
allow_variable_data_keys=allow_variable_data_keys, | |||||
inference=True, | |||||
) | |||||
forward_time_total = 0.0 | |||||
length_total = 0.0 | |||||
# 7 .Start for-loop | |||||
# FIXME(kamo): The output format should be discussed about | |||||
with DatadirWriter(output_dir) as writer: | |||||
for keys, batch in loader: | |||||
assert isinstance(batch, dict), type(batch) | |||||
assert all(isinstance(s, str) for s in keys), keys | |||||
_bs = len(next(iter(batch.values()))) | |||||
assert len(keys) == _bs, f'{len(keys)} != {_bs}' | |||||
batch = { | |||||
k: v[0] | |||||
for k, v in batch.items() if not k.endswith('_lengths') | |||||
} | |||||
# N-best list of (text, token, token_int, hyp_object) | |||||
try: | |||||
time_beg = time.time() | |||||
results = speech2text(**batch) | |||||
time_end = time.time() | |||||
forward_time = time_end - time_beg | |||||
length = results[0][-1] | |||||
results = [results[0][:-1]] | |||||
forward_time_total += forward_time | |||||
length_total += length | |||||
except TooShortUttError as e: | |||||
logging.warning(f'Utterance {keys} {e}') | |||||
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) | |||||
results = [[' ', ['<space>'], [2], hyp]] * nbest | |||||
# Only supporting batch_size==1 | |||||
key = keys[0] | |||||
for n, (text, token, token_int, | |||||
hyp) in zip(range(1, nbest + 1), results): | |||||
# Create a directory: outdir/{n}best_recog | |||||
ibest_writer = writer[f'{n}best_recog'] | |||||
# Write the result to each file | |||||
ibest_writer['token'][key] = ' '.join(token) | |||||
ibest_writer['token_int'][key] = ' '.join(map(str, token_int)) | |||||
ibest_writer['score'][key] = str(hyp.score) | |||||
if text is not None: | |||||
ibest_writer['text'][key] = text | |||||
logging.info( | |||||
'decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}' | |||||
.format(length_total, forward_time_total, | |||||
100 * forward_time_total / length_total)) | |||||
def get_parser(): | |||||
parser = config_argparse.ArgumentParser( | |||||
description='ASR Decoding', | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
) | |||||
# Note(kamo): Use '_' instead of '-' as separator. | |||||
# '-' is confusing if written in yaml. | |||||
parser.add_argument( | |||||
'--log_level', | |||||
type=lambda x: x.upper(), | |||||
default='INFO', | |||||
choices=('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'), | |||||
help='The verbose level of logging', | |||||
) | |||||
parser.add_argument('--output_dir', type=str, required=True) | |||||
parser.add_argument( | |||||
'--ngpu', | |||||
type=int, | |||||
default=0, | |||||
help='The number of gpus. 0 indicates CPU mode', | |||||
) | |||||
parser.add_argument('--seed', type=int, default=0, help='Random seed') | |||||
parser.add_argument( | |||||
'--dtype', | |||||
default='float32', | |||||
choices=['float16', 'float32', 'float64'], | |||||
help='Data type', | |||||
) | |||||
parser.add_argument( | |||||
'--num_workers', | |||||
type=int, | |||||
default=1, | |||||
help='The number of workers used for DataLoader', | |||||
) | |||||
group = parser.add_argument_group('Input data related') | |||||
group.add_argument( | |||||
'--data_path_and_name_and_type', | |||||
type=str2triple_str, | |||||
required=True, | |||||
action='append', | |||||
) | |||||
group.add_argument('--key_file', type=str_or_none) | |||||
group.add_argument( | |||||
'--allow_variable_data_keys', type=str2bool, default=False) | |||||
group = parser.add_argument_group('The model configuration related') | |||||
group.add_argument( | |||||
'--asr_train_config', | |||||
type=str, | |||||
help='ASR training configuration', | |||||
) | |||||
group.add_argument( | |||||
'--asr_model_file', | |||||
type=str, | |||||
help='ASR model parameter file', | |||||
) | |||||
group.add_argument( | |||||
'--lm_train_config', | |||||
type=str, | |||||
help='LM training configuration', | |||||
) | |||||
group.add_argument( | |||||
'--lm_file', | |||||
type=str, | |||||
help='LM parameter file', | |||||
) | |||||
group.add_argument( | |||||
'--word_lm_train_config', | |||||
type=str, | |||||
help='Word LM training configuration', | |||||
) | |||||
group.add_argument( | |||||
'--word_lm_file', | |||||
type=str, | |||||
help='Word LM parameter file', | |||||
) | |||||
group.add_argument( | |||||
'--ngram_file', | |||||
type=str, | |||||
help='N-gram parameter file', | |||||
) | |||||
group.add_argument( | |||||
'--model_tag', | |||||
type=str, | |||||
help='Pretrained model tag. If specify this option, *_train_config and ' | |||||
'*_file will be overwritten', | |||||
) | |||||
group = parser.add_argument_group('Beam-search related') | |||||
group.add_argument( | |||||
'--batch_size', | |||||
type=int, | |||||
default=1, | |||||
help='The batch size for inference', | |||||
) | |||||
group.add_argument( | |||||
'--nbest', type=int, default=1, help='Output N-best hypotheses') | |||||
group.add_argument('--beam_size', type=int, default=20, help='Beam size') | |||||
group.add_argument( | |||||
'--penalty', type=float, default=0.0, help='Insertion penalty') | |||||
group.add_argument( | |||||
'--maxlenratio', | |||||
type=float, | |||||
default=0.0, | |||||
help='Input length ratio to obtain max output length. ' | |||||
'If maxlenratio=0.0 (default), it uses a end-detect ' | |||||
'function ' | |||||
'to automatically find maximum hypothesis lengths.' | |||||
'If maxlenratio<0.0, its absolute value is interpreted' | |||||
'as a constant max output length', | |||||
) | |||||
group.add_argument( | |||||
'--minlenratio', | |||||
type=float, | |||||
default=0.0, | |||||
help='Input length ratio to obtain min output length', | |||||
) | |||||
group.add_argument( | |||||
'--ctc_weight', | |||||
type=float, | |||||
default=0.5, | |||||
help='CTC weight in joint decoding', | |||||
) | |||||
group.add_argument( | |||||
'--lm_weight', type=float, default=1.0, help='RNNLM weight') | |||||
group.add_argument( | |||||
'--ngram_weight', type=float, default=0.9, help='ngram weight') | |||||
group.add_argument('--streaming', type=str2bool, default=False) | |||||
group.add_argument( | |||||
'--frontend_conf', | |||||
default=None, | |||||
help='', | |||||
) | |||||
group = parser.add_argument_group('Text converter related') | |||||
group.add_argument( | |||||
'--token_type', | |||||
type=str_or_none, | |||||
default=None, | |||||
choices=['char', 'bpe', None], | |||||
help='The token type for ASR model. ' | |||||
'If not given, refers from the training args', | |||||
) | |||||
group.add_argument( | |||||
'--bpemodel', | |||||
type=str_or_none, | |||||
default=None, | |||||
help='The model path of sentencepiece. ' | |||||
'If not given, refers from the training args', | |||||
) | |||||
group.add_argument( | |||||
'--transducer_conf', | |||||
default=None, | |||||
help='The keyword arguments for transducer beam search.', | |||||
) | |||||
return parser | |||||
def asr_inference( | |||||
output_dir: str, | |||||
maxlenratio: float, | |||||
minlenratio: float, | |||||
beam_size: int, | |||||
ngpu: int, | |||||
ctc_weight: float, | |||||
lm_weight: float, | |||||
penalty: float, | |||||
data_path_and_name_and_type: Sequence[Tuple[str, str, str]], | |||||
asr_train_config: Optional[str], | |||||
asr_model_file: Optional[str], | |||||
nbest: int = 1, | |||||
num_workers: int = 1, | |||||
log_level: Union[int, str] = 'INFO', | |||||
batch_size: int = 1, | |||||
dtype: str = 'float32', | |||||
seed: int = 0, | |||||
key_file: Optional[str] = None, | |||||
lm_train_config: Optional[str] = None, | |||||
lm_file: Optional[str] = None, | |||||
word_lm_train_config: Optional[str] = None, | |||||
word_lm_file: Optional[str] = None, | |||||
ngram_file: Optional[str] = None, | |||||
ngram_weight: float = 0.9, | |||||
model_tag: Optional[str] = None, | |||||
token_type: Optional[str] = None, | |||||
bpemodel: Optional[str] = None, | |||||
allow_variable_data_keys: bool = False, | |||||
transducer_conf: Optional[dict] = None, | |||||
streaming: bool = False, | |||||
frontend_conf: dict = None, | |||||
): | |||||
inference( | |||||
output_dir=output_dir, | |||||
maxlenratio=maxlenratio, | |||||
minlenratio=minlenratio, | |||||
batch_size=batch_size, | |||||
dtype=dtype, | |||||
beam_size=beam_size, | |||||
ngpu=ngpu, | |||||
seed=seed, | |||||
ctc_weight=ctc_weight, | |||||
lm_weight=lm_weight, | |||||
ngram_weight=ngram_weight, | |||||
penalty=penalty, | |||||
nbest=nbest, | |||||
num_workers=num_workers, | |||||
log_level=log_level, | |||||
data_path_and_name_and_type=data_path_and_name_and_type, | |||||
key_file=key_file, | |||||
asr_train_config=asr_train_config, | |||||
asr_model_file=asr_model_file, | |||||
lm_train_config=lm_train_config, | |||||
lm_file=lm_file, | |||||
word_lm_train_config=word_lm_train_config, | |||||
word_lm_file=word_lm_file, | |||||
ngram_file=ngram_file, | |||||
model_tag=model_tag, | |||||
token_type=token_type, | |||||
bpemodel=bpemodel, | |||||
allow_variable_data_keys=allow_variable_data_keys, | |||||
transducer_conf=transducer_conf, | |||||
streaming=streaming, | |||||
frontend_conf=frontend_conf) | |||||
def main(cmd=None): | |||||
print(get_commandline_args(), file=sys.stderr) | |||||
parser = get_parser() | |||||
args = parser.parse_args(cmd) | |||||
kwargs = vars(args) | |||||
kwargs.pop('config', None) | |||||
inference(**kwargs) | |||||
if __name__ == '__main__': | |||||
main() |
@@ -1,193 +0,0 @@ | |||||
import os | |||||
from typing import Any, Dict, List | |||||
import numpy as np | |||||
def type_checking(wav_path: str, | |||||
recog_type: str = None, | |||||
audio_format: str = None, | |||||
workspace: str = None): | |||||
assert os.path.exists(wav_path), f'wav_path:{wav_path} does not exist' | |||||
r_recog_type = recog_type | |||||
r_audio_format = audio_format | |||||
r_workspace = workspace | |||||
r_wav_path = wav_path | |||||
if r_workspace is None or len(r_workspace) == 0: | |||||
r_workspace = os.path.join(os.getcwd(), '.tmp') | |||||
if r_recog_type is None: | |||||
if os.path.isfile(wav_path): | |||||
if wav_path.endswith('.wav') or wav_path.endswith('.WAV'): | |||||
r_recog_type = 'wav' | |||||
r_audio_format = 'wav' | |||||
elif os.path.isdir(wav_path): | |||||
dir_name = os.path.basename(wav_path) | |||||
if 'test' in dir_name: | |||||
r_recog_type = 'test' | |||||
elif 'dev' in dir_name: | |||||
r_recog_type = 'dev' | |||||
elif 'train' in dir_name: | |||||
r_recog_type = 'train' | |||||
if r_audio_format is None: | |||||
if find_file_by_ends(wav_path, '.ark'): | |||||
r_audio_format = 'kaldi_ark' | |||||
elif find_file_by_ends(wav_path, '.wav') or find_file_by_ends( | |||||
wav_path, '.WAV'): | |||||
r_audio_format = 'wav' | |||||
if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav': | |||||
# datasets with kaldi_ark file | |||||
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../')) | |||||
elif r_audio_format == 'wav' and r_recog_type != 'wav': | |||||
# datasets with waveform files | |||||
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../')) | |||||
return r_recog_type, r_audio_format, r_workspace, r_wav_path | |||||
def find_file_by_ends(dir_path: str, ends: str): | |||||
dir_files = os.listdir(dir_path) | |||||
for file in dir_files: | |||||
file_path = os.path.join(dir_path, file) | |||||
if os.path.isfile(file_path): | |||||
if file_path.endswith(ends): | |||||
return True | |||||
elif os.path.isdir(file_path): | |||||
if find_file_by_ends(file_path, ends): | |||||
return True | |||||
return False | |||||
def compute_wer(hyp_text_path: str, ref_text_path: str) -> Dict[str, Any]: | |||||
assert os.path.exists(hyp_text_path), 'hyp_text does not exist' | |||||
assert os.path.exists(ref_text_path), 'ref_text does not exist' | |||||
rst = { | |||||
'Wrd': 0, | |||||
'Corr': 0, | |||||
'Ins': 0, | |||||
'Del': 0, | |||||
'Sub': 0, | |||||
'Snt': 0, | |||||
'Err': 0.0, | |||||
'S.Err': 0.0, | |||||
'wrong_words': 0, | |||||
'wrong_sentences': 0 | |||||
} | |||||
with open(ref_text_path, 'r', encoding='utf-8') as r: | |||||
r_lines = r.readlines() | |||||
with open(hyp_text_path, 'r', encoding='utf-8') as h: | |||||
h_lines = h.readlines() | |||||
for r_line in r_lines: | |||||
r_line_item = r_line.split() | |||||
r_key = r_line_item[0] | |||||
r_sentence = r_line_item[1] | |||||
for h_line in h_lines: | |||||
# find sentence from hyp text | |||||
if r_key in h_line: | |||||
h_line_item = h_line.split() | |||||
h_sentence = h_line_item[1] | |||||
out_item = compute_wer_by_line(h_sentence, r_sentence) | |||||
rst['Wrd'] += out_item['nwords'] | |||||
rst['Corr'] += out_item['cor'] | |||||
rst['wrong_words'] += out_item['wrong'] | |||||
rst['Ins'] += out_item['ins'] | |||||
rst['Del'] += out_item['del'] | |||||
rst['Sub'] += out_item['sub'] | |||||
rst['Snt'] += 1 | |||||
if out_item['wrong'] > 0: | |||||
rst['wrong_sentences'] += 1 | |||||
break | |||||
if rst['Wrd'] > 0: | |||||
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) | |||||
if rst['Snt'] > 0: | |||||
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) | |||||
return rst | |||||
def compute_wer_by_line(hyp: list, ref: list) -> Dict[str, Any]: | |||||
len_hyp = len(hyp) | |||||
len_ref = len(ref) | |||||
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16) | |||||
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8) | |||||
for i in range(len_hyp + 1): | |||||
cost_matrix[i][0] = i | |||||
for j in range(len_ref + 1): | |||||
cost_matrix[0][j] = j | |||||
for i in range(1, len_hyp + 1): | |||||
for j in range(1, len_ref + 1): | |||||
if hyp[i - 1] == ref[j - 1]: | |||||
cost_matrix[i][j] = cost_matrix[i - 1][j - 1] | |||||
else: | |||||
substitution = cost_matrix[i - 1][j - 1] + 1 | |||||
insertion = cost_matrix[i - 1][j] + 1 | |||||
deletion = cost_matrix[i][j - 1] + 1 | |||||
compare_val = [substitution, insertion, deletion] | |||||
min_val = min(compare_val) | |||||
operation_idx = compare_val.index(min_val) + 1 | |||||
cost_matrix[i][j] = min_val | |||||
ops_matrix[i][j] = operation_idx | |||||
match_idx = [] | |||||
i = len_hyp | |||||
j = len_ref | |||||
rst = { | |||||
'nwords': len_hyp, | |||||
'cor': 0, | |||||
'wrong': 0, | |||||
'ins': 0, | |||||
'del': 0, | |||||
'sub': 0 | |||||
} | |||||
while i >= 0 or j >= 0: | |||||
i_idx = max(0, i) | |||||
j_idx = max(0, j) | |||||
if ops_matrix[i_idx][j_idx] == 0: # correct | |||||
if i - 1 >= 0 and j - 1 >= 0: | |||||
match_idx.append((j - 1, i - 1)) | |||||
rst['cor'] += 1 | |||||
i -= 1 | |||||
j -= 1 | |||||
elif ops_matrix[i_idx][j_idx] == 2: # insert | |||||
i -= 1 | |||||
rst['ins'] += 1 | |||||
elif ops_matrix[i_idx][j_idx] == 3: # delete | |||||
j -= 1 | |||||
rst['del'] += 1 | |||||
elif ops_matrix[i_idx][j_idx] == 1: # substitute | |||||
i -= 1 | |||||
j -= 1 | |||||
rst['sub'] += 1 | |||||
if i < 0 and j >= 0: | |||||
rst['del'] += 1 | |||||
elif j < 0 and i >= 0: | |||||
rst['ins'] += 1 | |||||
match_idx.reverse() | |||||
wrong_cnt = cost_matrix[len_hyp][len_ref] | |||||
rst['wrong'] = wrong_cnt | |||||
return rst |
@@ -1,757 +0,0 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
# Part of the implementation is borrowed from espnet/espnet. | |||||
"""Decoder definition.""" | |||||
from typing import Any, List, Sequence, Tuple | |||||
import torch | |||||
from espnet2.asr.decoder.abs_decoder import AbsDecoder | |||||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||||
from espnet.nets.pytorch_backend.transformer.attention import \ | |||||
MultiHeadedAttention | |||||
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer | |||||
from espnet.nets.pytorch_backend.transformer.dynamic_conv import \ | |||||
DynamicConvolution | |||||
from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \ | |||||
DynamicConvolution2D | |||||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
PositionalEncoding | |||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||||
from espnet.nets.pytorch_backend.transformer.lightconv import \ | |||||
LightweightConvolution | |||||
from espnet.nets.pytorch_backend.transformer.lightconv2d import \ | |||||
LightweightConvolution2D | |||||
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask | |||||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||||
PositionwiseFeedForward # noqa: H301 | |||||
from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||||
from espnet.nets.scorer_interface import BatchScorerInterface | |||||
from typeguard import check_argument_types | |||||
class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface): | |||||
"""Base class of Transfomer decoder module. | |||||
Args: | |||||
vocab_size: output dim | |||||
encoder_output_size: dimension of attention | |||||
attention_heads: the number of heads of multi head attention | |||||
linear_units: the number of units of position-wise feed forward | |||||
num_blocks: the number of decoder blocks | |||||
dropout_rate: dropout rate | |||||
self_attention_dropout_rate: dropout rate for attention | |||||
input_layer: input layer type | |||||
use_output_layer: whether to use output layer | |||||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||||
normalize_before: whether to use layer_norm before the first block | |||||
concat_after: whether to concat attention layer's input and output | |||||
if True, additional linear will be applied. | |||||
i.e. x -> x + linear(concat(x, att(x))) | |||||
if False, no additional linear will be applied. | |||||
i.e. x -> x + att(x) | |||||
""" | |||||
def __init__( | |||||
self, | |||||
vocab_size: int, | |||||
encoder_output_size: int, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
input_layer: str = 'embed', | |||||
use_output_layer: bool = True, | |||||
pos_enc_class=PositionalEncoding, | |||||
normalize_before: bool = True, | |||||
): | |||||
assert check_argument_types() | |||||
super().__init__() | |||||
attention_dim = encoder_output_size | |||||
if input_layer == 'embed': | |||||
self.embed = torch.nn.Sequential( | |||||
torch.nn.Embedding(vocab_size, attention_dim), | |||||
pos_enc_class(attention_dim, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'linear': | |||||
self.embed = torch.nn.Sequential( | |||||
torch.nn.Linear(vocab_size, attention_dim), | |||||
torch.nn.LayerNorm(attention_dim), | |||||
torch.nn.Dropout(dropout_rate), | |||||
torch.nn.ReLU(), | |||||
pos_enc_class(attention_dim, positional_dropout_rate), | |||||
) | |||||
else: | |||||
raise ValueError( | |||||
f"only 'embed' or 'linear' is supported: {input_layer}") | |||||
self.normalize_before = normalize_before | |||||
if self.normalize_before: | |||||
self.after_norm = LayerNorm(attention_dim) | |||||
if use_output_layer: | |||||
self.output_layer = torch.nn.Linear(attention_dim, vocab_size) | |||||
else: | |||||
self.output_layer = None | |||||
# Must set by the inheritance | |||||
self.decoders = None | |||||
def forward( | |||||
self, | |||||
hs_pad: torch.Tensor, | |||||
hlens: torch.Tensor, | |||||
ys_in_pad: torch.Tensor, | |||||
ys_in_lens: torch.Tensor, | |||||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
"""Forward decoder. | |||||
Args: | |||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat) | |||||
hlens: (batch) | |||||
ys_in_pad: | |||||
input token ids, int64 (batch, maxlen_out) | |||||
if input_layer == "embed" | |||||
input tensor (batch, maxlen_out, #mels) in the other cases | |||||
ys_in_lens: (batch) | |||||
Returns: | |||||
(tuple): tuple containing: | |||||
x: decoded token score before softmax (batch, maxlen_out, token) | |||||
if use_output_layer is True, | |||||
olens: (batch, ) | |||||
""" | |||||
tgt = ys_in_pad | |||||
# tgt_mask: (B, 1, L) | |||||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) | |||||
# m: (1, L, L) | |||||
m = subsequent_mask( | |||||
tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) | |||||
# tgt_mask: (B, L, L) | |||||
tgt_mask = tgt_mask & m | |||||
memory = hs_pad | |||||
memory_mask = ( | |||||
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( | |||||
memory.device) | |||||
# Padding for Longformer | |||||
if memory_mask.shape[-1] != memory.shape[1]: | |||||
padlen = memory.shape[1] - memory_mask.shape[-1] | |||||
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), | |||||
'constant', False) | |||||
x = self.embed(tgt) | |||||
x, tgt_mask, memory, memory_mask = self.decoders( | |||||
x, tgt_mask, memory, memory_mask) | |||||
if self.normalize_before: | |||||
x = self.after_norm(x) | |||||
if self.output_layer is not None: | |||||
x = self.output_layer(x) | |||||
olens = tgt_mask.sum(1) | |||||
return x, olens | |||||
def forward_one_step( | |||||
self, | |||||
tgt: torch.Tensor, | |||||
tgt_mask: torch.Tensor, | |||||
memory: torch.Tensor, | |||||
cache: List[torch.Tensor] = None, | |||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]: | |||||
"""Forward one step. | |||||
Args: | |||||
tgt: input token ids, int64 (batch, maxlen_out) | |||||
tgt_mask: input token mask, (batch, maxlen_out) | |||||
dtype=torch.uint8 in PyTorch 1.2- | |||||
dtype=torch.bool in PyTorch 1.2+ (include 1.2) | |||||
memory: encoded memory, float32 (batch, maxlen_in, feat) | |||||
cache: cached output list of (batch, max_time_out-1, size) | |||||
Returns: | |||||
y, cache: NN output value and cache per `self.decoders`. | |||||
y.shape` is (batch, maxlen_out, token) | |||||
""" | |||||
x = self.embed(tgt) | |||||
if cache is None: | |||||
cache = [None] * len(self.decoders) | |||||
new_cache = [] | |||||
for c, decoder in zip(cache, self.decoders): | |||||
x, tgt_mask, memory, memory_mask = decoder( | |||||
x, tgt_mask, memory, None, cache=c) | |||||
new_cache.append(x) | |||||
if self.normalize_before: | |||||
y = self.after_norm(x[:, -1]) | |||||
else: | |||||
y = x[:, -1] | |||||
if self.output_layer is not None: | |||||
y = torch.log_softmax(self.output_layer(y), dim=-1) | |||||
return y, new_cache | |||||
def score(self, ys, state, x): | |||||
"""Score.""" | |||||
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) | |||||
logp, state = self.forward_one_step( | |||||
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state) | |||||
return logp.squeeze(0), state | |||||
def batch_score(self, ys: torch.Tensor, states: List[Any], | |||||
xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: | |||||
"""Score new token batch. | |||||
Args: | |||||
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). | |||||
states (List[Any]): Scorer states for prefix tokens. | |||||
xs (torch.Tensor): | |||||
The encoder feature that generates ys (n_batch, xlen, n_feat). | |||||
Returns: | |||||
tuple[torch.Tensor, List[Any]]: Tuple of | |||||
batchfied scores for next token with shape of `(n_batch, n_vocab)` | |||||
and next state list for ys. | |||||
""" | |||||
# merge states | |||||
n_batch = len(ys) | |||||
n_layers = len(self.decoders) | |||||
if states[0] is None: | |||||
batch_state = None | |||||
else: | |||||
# transpose state of [batch, layer] into [layer, batch] | |||||
batch_state = [ | |||||
torch.stack([states[b][i] for b in range(n_batch)]) | |||||
for i in range(n_layers) | |||||
] | |||||
# batch decoding | |||||
ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0) | |||||
logp, states = self.forward_one_step( | |||||
ys, ys_mask, xs, cache=batch_state) | |||||
# transpose state of [layer, batch] into [batch, layer] | |||||
state_list = [[states[i][b] for i in range(n_layers)] | |||||
for b in range(n_batch)] | |||||
return logp, state_list | |||||
class TransformerDecoder(BaseTransformerDecoder): | |||||
def __init__( | |||||
self, | |||||
vocab_size: int, | |||||
encoder_output_size: int, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
self_attention_dropout_rate: float = 0.0, | |||||
src_attention_dropout_rate: float = 0.0, | |||||
input_layer: str = 'embed', | |||||
use_output_layer: bool = True, | |||||
pos_enc_class=PositionalEncoding, | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
): | |||||
assert check_argument_types() | |||||
super().__init__( | |||||
vocab_size=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
dropout_rate=dropout_rate, | |||||
positional_dropout_rate=positional_dropout_rate, | |||||
input_layer=input_layer, | |||||
use_output_layer=use_output_layer, | |||||
pos_enc_class=pos_enc_class, | |||||
normalize_before=normalize_before, | |||||
) | |||||
attention_dim = encoder_output_size | |||||
self.decoders = repeat( | |||||
num_blocks, | |||||
lambda lnum: DecoderLayer( | |||||
attention_dim, | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
self_attention_dropout_rate), | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
src_attention_dropout_rate), | |||||
PositionwiseFeedForward(attention_dim, linear_units, | |||||
dropout_rate), | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
), | |||||
) | |||||
class ParaformerDecoder(TransformerDecoder): | |||||
def __init__( | |||||
self, | |||||
vocab_size: int, | |||||
encoder_output_size: int, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
self_attention_dropout_rate: float = 0.0, | |||||
src_attention_dropout_rate: float = 0.0, | |||||
input_layer: str = 'embed', | |||||
use_output_layer: bool = True, | |||||
pos_enc_class=PositionalEncoding, | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
): | |||||
assert check_argument_types() | |||||
super().__init__( | |||||
vocab_size=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
dropout_rate=dropout_rate, | |||||
positional_dropout_rate=positional_dropout_rate, | |||||
input_layer=input_layer, | |||||
use_output_layer=use_output_layer, | |||||
pos_enc_class=pos_enc_class, | |||||
normalize_before=normalize_before, | |||||
) | |||||
attention_dim = encoder_output_size | |||||
self.decoders = repeat( | |||||
num_blocks, | |||||
lambda lnum: DecoderLayer( | |||||
attention_dim, | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
self_attention_dropout_rate), | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
src_attention_dropout_rate), | |||||
PositionwiseFeedForward(attention_dim, linear_units, | |||||
dropout_rate), | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
), | |||||
) | |||||
def forward( | |||||
self, | |||||
hs_pad: torch.Tensor, | |||||
hlens: torch.Tensor, | |||||
ys_in_pad: torch.Tensor, | |||||
ys_in_lens: torch.Tensor, | |||||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
"""Forward decoder. | |||||
Args: | |||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat) | |||||
hlens: (batch) | |||||
ys_in_pad: | |||||
input token ids, int64 (batch, maxlen_out) | |||||
if input_layer == "embed" | |||||
input tensor (batch, maxlen_out, #mels) in the other cases | |||||
ys_in_lens: (batch) | |||||
Returns: | |||||
(tuple): tuple containing: | |||||
x: decoded token score before softmax (batch, maxlen_out, token) | |||||
if use_output_layer is True, | |||||
olens: (batch, ) | |||||
""" | |||||
tgt = ys_in_pad | |||||
# tgt_mask: (B, 1, L) | |||||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) | |||||
# m: (1, L, L) | |||||
# m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) | |||||
# tgt_mask: (B, L, L) | |||||
# tgt_mask = tgt_mask & m | |||||
memory = hs_pad | |||||
memory_mask = ( | |||||
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( | |||||
memory.device) | |||||
# Padding for Longformer | |||||
if memory_mask.shape[-1] != memory.shape[1]: | |||||
padlen = memory.shape[1] - memory_mask.shape[-1] | |||||
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), | |||||
'constant', False) | |||||
# x = self.embed(tgt) | |||||
x = tgt | |||||
x, tgt_mask, memory, memory_mask = self.decoders( | |||||
x, tgt_mask, memory, memory_mask) | |||||
if self.normalize_before: | |||||
x = self.after_norm(x) | |||||
if self.output_layer is not None: | |||||
x = self.output_layer(x) | |||||
olens = tgt_mask.sum(1) | |||||
return x, olens | |||||
class ParaformerDecoderBertEmbed(TransformerDecoder): | |||||
def __init__( | |||||
self, | |||||
vocab_size: int, | |||||
encoder_output_size: int, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
self_attention_dropout_rate: float = 0.0, | |||||
src_attention_dropout_rate: float = 0.0, | |||||
input_layer: str = 'embed', | |||||
use_output_layer: bool = True, | |||||
pos_enc_class=PositionalEncoding, | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
embeds_id: int = 2, | |||||
): | |||||
assert check_argument_types() | |||||
super().__init__( | |||||
vocab_size=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
dropout_rate=dropout_rate, | |||||
positional_dropout_rate=positional_dropout_rate, | |||||
input_layer=input_layer, | |||||
use_output_layer=use_output_layer, | |||||
pos_enc_class=pos_enc_class, | |||||
normalize_before=normalize_before, | |||||
) | |||||
attention_dim = encoder_output_size | |||||
self.decoders = repeat( | |||||
embeds_id, | |||||
lambda lnum: DecoderLayer( | |||||
attention_dim, | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
self_attention_dropout_rate), | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
src_attention_dropout_rate), | |||||
PositionwiseFeedForward(attention_dim, linear_units, | |||||
dropout_rate), | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
), | |||||
) | |||||
if embeds_id == num_blocks: | |||||
self.decoders2 = None | |||||
else: | |||||
self.decoders2 = repeat( | |||||
num_blocks - embeds_id, | |||||
lambda lnum: DecoderLayer( | |||||
attention_dim, | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
self_attention_dropout_rate), | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
src_attention_dropout_rate), | |||||
PositionwiseFeedForward(attention_dim, linear_units, | |||||
dropout_rate), | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
), | |||||
) | |||||
def forward( | |||||
self, | |||||
hs_pad: torch.Tensor, | |||||
hlens: torch.Tensor, | |||||
ys_in_pad: torch.Tensor, | |||||
ys_in_lens: torch.Tensor, | |||||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
"""Forward decoder. | |||||
Args: | |||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat) | |||||
hlens: (batch) | |||||
ys_in_pad: | |||||
input token ids, int64 (batch, maxlen_out) | |||||
if input_layer == "embed" | |||||
input tensor (batch, maxlen_out, #mels) in the other cases | |||||
ys_in_lens: (batch) | |||||
Returns: | |||||
(tuple): tuple containing: | |||||
x: decoded token score before softmax (batch, maxlen_out, token) | |||||
if use_output_layer is True, | |||||
olens: (batch, ) | |||||
""" | |||||
tgt = ys_in_pad | |||||
# tgt_mask: (B, 1, L) | |||||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) | |||||
# m: (1, L, L) | |||||
# m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) | |||||
# tgt_mask: (B, L, L) | |||||
# tgt_mask = tgt_mask & m | |||||
memory = hs_pad | |||||
memory_mask = ( | |||||
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( | |||||
memory.device) | |||||
# Padding for Longformer | |||||
if memory_mask.shape[-1] != memory.shape[1]: | |||||
padlen = memory.shape[1] - memory_mask.shape[-1] | |||||
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), | |||||
'constant', False) | |||||
# x = self.embed(tgt) | |||||
x = tgt | |||||
x, tgt_mask, memory, memory_mask = self.decoders( | |||||
x, tgt_mask, memory, memory_mask) | |||||
embeds_outputs = x | |||||
if self.decoders2 is not None: | |||||
x, tgt_mask, memory, memory_mask = self.decoders2( | |||||
x, tgt_mask, memory, memory_mask) | |||||
if self.normalize_before: | |||||
x = self.after_norm(x) | |||||
if self.output_layer is not None: | |||||
x = self.output_layer(x) | |||||
olens = tgt_mask.sum(1) | |||||
return x, olens, embeds_outputs | |||||
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): | |||||
def __init__( | |||||
self, | |||||
vocab_size: int, | |||||
encoder_output_size: int, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
self_attention_dropout_rate: float = 0.0, | |||||
src_attention_dropout_rate: float = 0.0, | |||||
input_layer: str = 'embed', | |||||
use_output_layer: bool = True, | |||||
pos_enc_class=PositionalEncoding, | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
conv_wshare: int = 4, | |||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||||
conv_usebias: int = False, | |||||
): | |||||
assert check_argument_types() | |||||
if len(conv_kernel_length) != num_blocks: | |||||
raise ValueError( | |||||
'conv_kernel_length must have equal number of values to num_blocks: ' | |||||
f'{len(conv_kernel_length)} != {num_blocks}') | |||||
super().__init__( | |||||
vocab_size=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
dropout_rate=dropout_rate, | |||||
positional_dropout_rate=positional_dropout_rate, | |||||
input_layer=input_layer, | |||||
use_output_layer=use_output_layer, | |||||
pos_enc_class=pos_enc_class, | |||||
normalize_before=normalize_before, | |||||
) | |||||
attention_dim = encoder_output_size | |||||
self.decoders = repeat( | |||||
num_blocks, | |||||
lambda lnum: DecoderLayer( | |||||
attention_dim, | |||||
LightweightConvolution( | |||||
wshare=conv_wshare, | |||||
n_feat=attention_dim, | |||||
dropout_rate=self_attention_dropout_rate, | |||||
kernel_size=conv_kernel_length[lnum], | |||||
use_kernel_mask=True, | |||||
use_bias=conv_usebias, | |||||
), | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
src_attention_dropout_rate), | |||||
PositionwiseFeedForward(attention_dim, linear_units, | |||||
dropout_rate), | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
), | |||||
) | |||||
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): | |||||
def __init__( | |||||
self, | |||||
vocab_size: int, | |||||
encoder_output_size: int, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
self_attention_dropout_rate: float = 0.0, | |||||
src_attention_dropout_rate: float = 0.0, | |||||
input_layer: str = 'embed', | |||||
use_output_layer: bool = True, | |||||
pos_enc_class=PositionalEncoding, | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
conv_wshare: int = 4, | |||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||||
conv_usebias: int = False, | |||||
): | |||||
assert check_argument_types() | |||||
if len(conv_kernel_length) != num_blocks: | |||||
raise ValueError( | |||||
'conv_kernel_length must have equal number of values to num_blocks: ' | |||||
f'{len(conv_kernel_length)} != {num_blocks}') | |||||
super().__init__( | |||||
vocab_size=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
dropout_rate=dropout_rate, | |||||
positional_dropout_rate=positional_dropout_rate, | |||||
input_layer=input_layer, | |||||
use_output_layer=use_output_layer, | |||||
pos_enc_class=pos_enc_class, | |||||
normalize_before=normalize_before, | |||||
) | |||||
attention_dim = encoder_output_size | |||||
self.decoders = repeat( | |||||
num_blocks, | |||||
lambda lnum: DecoderLayer( | |||||
attention_dim, | |||||
LightweightConvolution2D( | |||||
wshare=conv_wshare, | |||||
n_feat=attention_dim, | |||||
dropout_rate=self_attention_dropout_rate, | |||||
kernel_size=conv_kernel_length[lnum], | |||||
use_kernel_mask=True, | |||||
use_bias=conv_usebias, | |||||
), | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
src_attention_dropout_rate), | |||||
PositionwiseFeedForward(attention_dim, linear_units, | |||||
dropout_rate), | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
), | |||||
) | |||||
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): | |||||
def __init__( | |||||
self, | |||||
vocab_size: int, | |||||
encoder_output_size: int, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
self_attention_dropout_rate: float = 0.0, | |||||
src_attention_dropout_rate: float = 0.0, | |||||
input_layer: str = 'embed', | |||||
use_output_layer: bool = True, | |||||
pos_enc_class=PositionalEncoding, | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
conv_wshare: int = 4, | |||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||||
conv_usebias: int = False, | |||||
): | |||||
assert check_argument_types() | |||||
if len(conv_kernel_length) != num_blocks: | |||||
raise ValueError( | |||||
'conv_kernel_length must have equal number of values to num_blocks: ' | |||||
f'{len(conv_kernel_length)} != {num_blocks}') | |||||
super().__init__( | |||||
vocab_size=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
dropout_rate=dropout_rate, | |||||
positional_dropout_rate=positional_dropout_rate, | |||||
input_layer=input_layer, | |||||
use_output_layer=use_output_layer, | |||||
pos_enc_class=pos_enc_class, | |||||
normalize_before=normalize_before, | |||||
) | |||||
attention_dim = encoder_output_size | |||||
self.decoders = repeat( | |||||
num_blocks, | |||||
lambda lnum: DecoderLayer( | |||||
attention_dim, | |||||
DynamicConvolution( | |||||
wshare=conv_wshare, | |||||
n_feat=attention_dim, | |||||
dropout_rate=self_attention_dropout_rate, | |||||
kernel_size=conv_kernel_length[lnum], | |||||
use_kernel_mask=True, | |||||
use_bias=conv_usebias, | |||||
), | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
src_attention_dropout_rate), | |||||
PositionwiseFeedForward(attention_dim, linear_units, | |||||
dropout_rate), | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
), | |||||
) | |||||
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): | |||||
def __init__( | |||||
self, | |||||
vocab_size: int, | |||||
encoder_output_size: int, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
self_attention_dropout_rate: float = 0.0, | |||||
src_attention_dropout_rate: float = 0.0, | |||||
input_layer: str = 'embed', | |||||
use_output_layer: bool = True, | |||||
pos_enc_class=PositionalEncoding, | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
conv_wshare: int = 4, | |||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||||
conv_usebias: int = False, | |||||
): | |||||
assert check_argument_types() | |||||
if len(conv_kernel_length) != num_blocks: | |||||
raise ValueError( | |||||
'conv_kernel_length must have equal number of values to num_blocks: ' | |||||
f'{len(conv_kernel_length)} != {num_blocks}') | |||||
super().__init__( | |||||
vocab_size=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
dropout_rate=dropout_rate, | |||||
positional_dropout_rate=positional_dropout_rate, | |||||
input_layer=input_layer, | |||||
use_output_layer=use_output_layer, | |||||
pos_enc_class=pos_enc_class, | |||||
normalize_before=normalize_before, | |||||
) | |||||
attention_dim = encoder_output_size | |||||
self.decoders = repeat( | |||||
num_blocks, | |||||
lambda lnum: DecoderLayer( | |||||
attention_dim, | |||||
DynamicConvolution2D( | |||||
wshare=conv_wshare, | |||||
n_feat=attention_dim, | |||||
dropout_rate=self_attention_dropout_rate, | |||||
kernel_size=conv_kernel_length[lnum], | |||||
use_kernel_mask=True, | |||||
use_bias=conv_usebias, | |||||
), | |||||
MultiHeadedAttention(attention_heads, attention_dim, | |||||
src_attention_dropout_rate), | |||||
PositionwiseFeedForward(attention_dim, linear_units, | |||||
dropout_rate), | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
), | |||||
) |
@@ -1,710 +0,0 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
# Part of the implementation is borrowed from espnet/espnet. | |||||
"""Conformer encoder definition.""" | |||||
import logging | |||||
from typing import List, Optional, Tuple, Union | |||||
import torch | |||||
from espnet2.asr.ctc import CTC | |||||
from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||||
from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule | |||||
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer | |||||
from espnet.nets.pytorch_backend.nets_utils import (get_activation, | |||||
make_pad_mask) | |||||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
LegacyRelPositionalEncoding # noqa: H301 | |||||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
PositionalEncoding # noqa: H301 | |||||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
RelPositionalEncoding # noqa: H301 | |||||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
ScaledPositionalEncoding # noqa: H301 | |||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( | |||||
Conv1dLinear, MultiLayeredConv1d) | |||||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||||
PositionwiseFeedForward # noqa: H301 | |||||
from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||||
from espnet.nets.pytorch_backend.transformer.subsampling import ( | |||||
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, | |||||
Conv2dSubsampling8, TooShortUttError, check_short_utt) | |||||
from typeguard import check_argument_types | |||||
from ...nets.pytorch_backend.transformer.attention import \ | |||||
LegacyRelPositionMultiHeadedAttention # noqa: H301 | |||||
from ...nets.pytorch_backend.transformer.attention import \ | |||||
MultiHeadedAttention # noqa: H301 | |||||
from ...nets.pytorch_backend.transformer.attention import \ | |||||
RelPositionMultiHeadedAttention # noqa: H301 | |||||
from ...nets.pytorch_backend.transformer.attention import ( | |||||
LegacyRelPositionMultiHeadedAttentionSANM, | |||||
RelPositionMultiHeadedAttentionSANM) | |||||
class ConformerEncoder(AbsEncoder): | |||||
"""Conformer encoder module. | |||||
Args: | |||||
input_size (int): Input dimension. | |||||
output_size (int): Dimension of attention. | |||||
attention_heads (int): The number of heads of multi head attention. | |||||
linear_units (int): The number of units of position-wise feed forward. | |||||
num_blocks (int): The number of decoder blocks. | |||||
dropout_rate (float): Dropout rate. | |||||
attention_dropout_rate (float): Dropout rate in attention. | |||||
positional_dropout_rate (float): Dropout rate after adding positional encoding. | |||||
input_layer (Union[str, torch.nn.Module]): Input layer type. | |||||
normalize_before (bool): Whether to use layer_norm before the first block. | |||||
concat_after (bool): Whether to concat attention layer's input and output. | |||||
If True, additional linear will be applied. | |||||
i.e. x -> x + linear(concat(x, att(x))) | |||||
If False, no additional linear will be applied. i.e. x -> x + att(x) | |||||
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". | |||||
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. | |||||
rel_pos_type (str): Whether to use the latest relative positional encoding or | |||||
the legacy one. The legacy relative positional encoding will be deprecated | |||||
in the future. More Details can be found in | |||||
https://github.com/espnet/espnet/pull/2816. | |||||
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type. | |||||
encoder_attn_layer_type (str): Encoder attention layer type. | |||||
activation_type (str): Encoder activation function type. | |||||
macaron_style (bool): Whether to use macaron style for positionwise layer. | |||||
use_cnn_module (bool): Whether to use convolution module. | |||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
cnn_module_kernel (int): Kernerl size of convolution module. | |||||
padding_idx (int): Padding idx for input_layer=embed. | |||||
""" | |||||
def __init__( | |||||
self, | |||||
input_size: int, | |||||
output_size: int = 256, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
attention_dropout_rate: float = 0.0, | |||||
input_layer: str = 'conv2d', | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
positionwise_layer_type: str = 'linear', | |||||
positionwise_conv_kernel_size: int = 3, | |||||
macaron_style: bool = False, | |||||
rel_pos_type: str = 'legacy', | |||||
pos_enc_layer_type: str = 'rel_pos', | |||||
selfattention_layer_type: str = 'rel_selfattn', | |||||
activation_type: str = 'swish', | |||||
use_cnn_module: bool = True, | |||||
zero_triu: bool = False, | |||||
cnn_module_kernel: int = 31, | |||||
padding_idx: int = -1, | |||||
interctc_layer_idx: List[int] = [], | |||||
interctc_use_conditioning: bool = False, | |||||
stochastic_depth_rate: Union[float, List[float]] = 0.0, | |||||
): | |||||
assert check_argument_types() | |||||
super().__init__() | |||||
self._output_size = output_size | |||||
if rel_pos_type == 'legacy': | |||||
if pos_enc_layer_type == 'rel_pos': | |||||
pos_enc_layer_type = 'legacy_rel_pos' | |||||
if selfattention_layer_type == 'rel_selfattn': | |||||
selfattention_layer_type = 'legacy_rel_selfattn' | |||||
elif rel_pos_type == 'latest': | |||||
assert selfattention_layer_type != 'legacy_rel_selfattn' | |||||
assert pos_enc_layer_type != 'legacy_rel_pos' | |||||
else: | |||||
raise ValueError('unknown rel_pos_type: ' + rel_pos_type) | |||||
activation = get_activation(activation_type) | |||||
if pos_enc_layer_type == 'abs_pos': | |||||
pos_enc_class = PositionalEncoding | |||||
elif pos_enc_layer_type == 'scaled_abs_pos': | |||||
pos_enc_class = ScaledPositionalEncoding | |||||
elif pos_enc_layer_type == 'rel_pos': | |||||
assert selfattention_layer_type == 'rel_selfattn' | |||||
pos_enc_class = RelPositionalEncoding | |||||
elif pos_enc_layer_type == 'legacy_rel_pos': | |||||
assert selfattention_layer_type == 'legacy_rel_selfattn' | |||||
pos_enc_class = LegacyRelPositionalEncoding | |||||
else: | |||||
raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type) | |||||
if input_layer == 'linear': | |||||
self.embed = torch.nn.Sequential( | |||||
torch.nn.Linear(input_size, output_size), | |||||
torch.nn.LayerNorm(output_size), | |||||
torch.nn.Dropout(dropout_rate), | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'conv2d': | |||||
self.embed = Conv2dSubsampling( | |||||
input_size, | |||||
output_size, | |||||
dropout_rate, | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'conv2d2': | |||||
self.embed = Conv2dSubsampling2( | |||||
input_size, | |||||
output_size, | |||||
dropout_rate, | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'conv2d6': | |||||
self.embed = Conv2dSubsampling6( | |||||
input_size, | |||||
output_size, | |||||
dropout_rate, | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'conv2d8': | |||||
self.embed = Conv2dSubsampling8( | |||||
input_size, | |||||
output_size, | |||||
dropout_rate, | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'embed': | |||||
self.embed = torch.nn.Sequential( | |||||
torch.nn.Embedding( | |||||
input_size, output_size, padding_idx=padding_idx), | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif isinstance(input_layer, torch.nn.Module): | |||||
self.embed = torch.nn.Sequential( | |||||
input_layer, | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer is None: | |||||
self.embed = torch.nn.Sequential( | |||||
pos_enc_class(output_size, positional_dropout_rate)) | |||||
else: | |||||
raise ValueError('unknown input_layer: ' + input_layer) | |||||
self.normalize_before = normalize_before | |||||
if positionwise_layer_type == 'linear': | |||||
positionwise_layer = PositionwiseFeedForward | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
dropout_rate, | |||||
activation, | |||||
) | |||||
elif positionwise_layer_type == 'conv1d': | |||||
positionwise_layer = MultiLayeredConv1d | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
positionwise_conv_kernel_size, | |||||
dropout_rate, | |||||
) | |||||
elif positionwise_layer_type == 'conv1d-linear': | |||||
positionwise_layer = Conv1dLinear | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
positionwise_conv_kernel_size, | |||||
dropout_rate, | |||||
) | |||||
else: | |||||
raise NotImplementedError('Support only linear or conv1d.') | |||||
if selfattention_layer_type == 'selfattn': | |||||
encoder_selfattn_layer = MultiHeadedAttention | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
) | |||||
elif selfattention_layer_type == 'legacy_rel_selfattn': | |||||
assert pos_enc_layer_type == 'legacy_rel_pos' | |||||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
) | |||||
elif selfattention_layer_type == 'rel_selfattn': | |||||
assert pos_enc_layer_type == 'rel_pos' | |||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
zero_triu, | |||||
) | |||||
else: | |||||
raise ValueError('unknown encoder_attn_layer: ' | |||||
+ selfattention_layer_type) | |||||
convolution_layer = ConvolutionModule | |||||
convolution_layer_args = (output_size, cnn_module_kernel, activation) | |||||
if isinstance(stochastic_depth_rate, float): | |||||
stochastic_depth_rate = [stochastic_depth_rate] * num_blocks | |||||
if len(stochastic_depth_rate) != num_blocks: | |||||
raise ValueError( | |||||
f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) ' | |||||
f'should be equal to num_blocks ({num_blocks})') | |||||
self.encoders = repeat( | |||||
num_blocks, | |||||
lambda lnum: EncoderLayer( | |||||
output_size, | |||||
encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||||
positionwise_layer(*positionwise_layer_args), | |||||
positionwise_layer(*positionwise_layer_args) | |||||
if macaron_style else None, | |||||
convolution_layer(*convolution_layer_args) | |||||
if use_cnn_module else None, | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
stochastic_depth_rate[lnum], | |||||
), | |||||
) | |||||
if self.normalize_before: | |||||
self.after_norm = LayerNorm(output_size) | |||||
self.interctc_layer_idx = interctc_layer_idx | |||||
if len(interctc_layer_idx) > 0: | |||||
assert 0 < min(interctc_layer_idx) and max( | |||||
interctc_layer_idx) < num_blocks | |||||
self.interctc_use_conditioning = interctc_use_conditioning | |||||
self.conditioning_layer = None | |||||
def output_size(self) -> int: | |||||
return self._output_size | |||||
def forward( | |||||
self, | |||||
xs_pad: torch.Tensor, | |||||
ilens: torch.Tensor, | |||||
prev_states: torch.Tensor = None, | |||||
ctc: CTC = None, | |||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||||
"""Calculate forward propagation. | |||||
Args: | |||||
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). | |||||
ilens (torch.Tensor): Input length (#batch). | |||||
prev_states (torch.Tensor): Not to be used now. | |||||
Returns: | |||||
torch.Tensor: Output tensor (#batch, L, output_size). | |||||
torch.Tensor: Output length (#batch). | |||||
torch.Tensor: Not to be used now. | |||||
""" | |||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||||
if (isinstance(self.embed, Conv2dSubsampling) | |||||
or isinstance(self.embed, Conv2dSubsampling2) | |||||
or isinstance(self.embed, Conv2dSubsampling6) | |||||
or isinstance(self.embed, Conv2dSubsampling8)): | |||||
short_status, limit_size = check_short_utt(self.embed, | |||||
xs_pad.size(1)) | |||||
if short_status: | |||||
raise TooShortUttError( | |||||
f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||||
+ # noqa: * | |||||
f'(it needs more than {limit_size} frames), return empty results', # noqa: * | |||||
xs_pad.size(1), | |||||
limit_size) # noqa: * | |||||
xs_pad, masks = self.embed(xs_pad, masks) | |||||
else: | |||||
xs_pad = self.embed(xs_pad) | |||||
intermediate_outs = [] | |||||
if len(self.interctc_layer_idx) == 0: | |||||
xs_pad, masks = self.encoders(xs_pad, masks) | |||||
else: | |||||
for layer_idx, encoder_layer in enumerate(self.encoders): | |||||
xs_pad, masks = encoder_layer(xs_pad, masks) | |||||
if layer_idx + 1 in self.interctc_layer_idx: | |||||
encoder_out = xs_pad | |||||
if isinstance(encoder_out, tuple): | |||||
encoder_out = encoder_out[0] | |||||
# intermediate outputs are also normalized | |||||
if self.normalize_before: | |||||
encoder_out = self.after_norm(encoder_out) | |||||
intermediate_outs.append((layer_idx + 1, encoder_out)) | |||||
if self.interctc_use_conditioning: | |||||
ctc_out = ctc.softmax(encoder_out) | |||||
if isinstance(xs_pad, tuple): | |||||
x, pos_emb = xs_pad | |||||
x = x + self.conditioning_layer(ctc_out) | |||||
xs_pad = (x, pos_emb) | |||||
else: | |||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||||
if isinstance(xs_pad, tuple): | |||||
xs_pad = xs_pad[0] | |||||
if self.normalize_before: | |||||
xs_pad = self.after_norm(xs_pad) | |||||
olens = masks.squeeze(1).sum(1) | |||||
if len(intermediate_outs) > 0: | |||||
return (xs_pad, intermediate_outs), olens, None | |||||
return xs_pad, olens, None | |||||
class SANMEncoder_v2(AbsEncoder): | |||||
"""Conformer encoder module. | |||||
Args: | |||||
input_size (int): Input dimension. | |||||
output_size (int): Dimension of attention. | |||||
attention_heads (int): The number of heads of multi head attention. | |||||
linear_units (int): The number of units of position-wise feed forward. | |||||
num_blocks (int): The number of decoder blocks. | |||||
dropout_rate (float): Dropout rate. | |||||
attention_dropout_rate (float): Dropout rate in attention. | |||||
positional_dropout_rate (float): Dropout rate after adding positional encoding. | |||||
input_layer (Union[str, torch.nn.Module]): Input layer type. | |||||
normalize_before (bool): Whether to use layer_norm before the first block. | |||||
concat_after (bool): Whether to concat attention layer's input and output. | |||||
If True, additional linear will be applied. | |||||
i.e. x -> x + linear(concat(x, att(x))) | |||||
If False, no additional linear will be applied. i.e. x -> x + att(x) | |||||
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". | |||||
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. | |||||
rel_pos_type (str): Whether to use the latest relative positional encoding or | |||||
the legacy one. The legacy relative positional encoding will be deprecated | |||||
in the future. More Details can be found in | |||||
https://github.com/espnet/espnet/pull/2816. | |||||
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type. | |||||
encoder_attn_layer_type (str): Encoder attention layer type. | |||||
activation_type (str): Encoder activation function type. | |||||
macaron_style (bool): Whether to use macaron style for positionwise layer. | |||||
use_cnn_module (bool): Whether to use convolution module. | |||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
cnn_module_kernel (int): Kernerl size of convolution module. | |||||
padding_idx (int): Padding idx for input_layer=embed. | |||||
""" | |||||
def __init__( | |||||
self, | |||||
input_size: int, | |||||
output_size: int = 256, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
attention_dropout_rate: float = 0.0, | |||||
input_layer: str = 'conv2d', | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
positionwise_layer_type: str = 'linear', | |||||
positionwise_conv_kernel_size: int = 3, | |||||
macaron_style: bool = False, | |||||
rel_pos_type: str = 'legacy', | |||||
pos_enc_layer_type: str = 'rel_pos', | |||||
selfattention_layer_type: str = 'rel_selfattn', | |||||
activation_type: str = 'swish', | |||||
use_cnn_module: bool = False, | |||||
sanm_shfit: int = 0, | |||||
zero_triu: bool = False, | |||||
cnn_module_kernel: int = 31, | |||||
padding_idx: int = -1, | |||||
interctc_layer_idx: List[int] = [], | |||||
interctc_use_conditioning: bool = False, | |||||
stochastic_depth_rate: Union[float, List[float]] = 0.0, | |||||
): | |||||
assert check_argument_types() | |||||
super().__init__() | |||||
self._output_size = output_size | |||||
if rel_pos_type == 'legacy': | |||||
if pos_enc_layer_type == 'rel_pos': | |||||
pos_enc_layer_type = 'legacy_rel_pos' | |||||
if selfattention_layer_type == 'rel_selfattn': | |||||
selfattention_layer_type = 'legacy_rel_selfattn' | |||||
if selfattention_layer_type == 'rel_selfattnsanm': | |||||
selfattention_layer_type = 'legacy_rel_selfattnsanm' | |||||
elif rel_pos_type == 'latest': | |||||
assert selfattention_layer_type != 'legacy_rel_selfattn' | |||||
assert pos_enc_layer_type != 'legacy_rel_pos' | |||||
else: | |||||
raise ValueError('unknown rel_pos_type: ' + rel_pos_type) | |||||
activation = get_activation(activation_type) | |||||
if pos_enc_layer_type == 'abs_pos': | |||||
pos_enc_class = PositionalEncoding | |||||
elif pos_enc_layer_type == 'scaled_abs_pos': | |||||
pos_enc_class = ScaledPositionalEncoding | |||||
elif pos_enc_layer_type == 'rel_pos': | |||||
# assert selfattention_layer_type == "rel_selfattn" | |||||
pos_enc_class = RelPositionalEncoding | |||||
elif pos_enc_layer_type == 'legacy_rel_pos': | |||||
# assert selfattention_layer_type == "legacy_rel_selfattn" | |||||
pos_enc_class = LegacyRelPositionalEncoding | |||||
logging.warning( | |||||
'Using legacy_rel_pos and it will be deprecated in the future.' | |||||
) | |||||
else: | |||||
raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type) | |||||
if input_layer == 'linear': | |||||
self.embed = torch.nn.Sequential( | |||||
torch.nn.Linear(input_size, output_size), | |||||
torch.nn.LayerNorm(output_size), | |||||
torch.nn.Dropout(dropout_rate), | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'conv2d': | |||||
self.embed = Conv2dSubsampling( | |||||
input_size, | |||||
output_size, | |||||
dropout_rate, | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'conv2d2': | |||||
self.embed = Conv2dSubsampling2( | |||||
input_size, | |||||
output_size, | |||||
dropout_rate, | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'conv2d6': | |||||
self.embed = Conv2dSubsampling6( | |||||
input_size, | |||||
output_size, | |||||
dropout_rate, | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'conv2d8': | |||||
self.embed = Conv2dSubsampling8( | |||||
input_size, | |||||
output_size, | |||||
dropout_rate, | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'embed': | |||||
self.embed = torch.nn.Sequential( | |||||
torch.nn.Embedding( | |||||
input_size, output_size, padding_idx=padding_idx), | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif isinstance(input_layer, torch.nn.Module): | |||||
self.embed = torch.nn.Sequential( | |||||
input_layer, | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer is None: | |||||
self.embed = torch.nn.Sequential( | |||||
pos_enc_class(output_size, positional_dropout_rate)) | |||||
else: | |||||
raise ValueError('unknown input_layer: ' + input_layer) | |||||
self.normalize_before = normalize_before | |||||
if positionwise_layer_type == 'linear': | |||||
positionwise_layer = PositionwiseFeedForward | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
dropout_rate, | |||||
activation, | |||||
) | |||||
elif positionwise_layer_type == 'conv1d': | |||||
positionwise_layer = MultiLayeredConv1d | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
positionwise_conv_kernel_size, | |||||
dropout_rate, | |||||
) | |||||
elif positionwise_layer_type == 'conv1d-linear': | |||||
positionwise_layer = Conv1dLinear | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
positionwise_conv_kernel_size, | |||||
dropout_rate, | |||||
) | |||||
else: | |||||
raise NotImplementedError('Support only linear or conv1d.') | |||||
if selfattention_layer_type == 'selfattn': | |||||
encoder_selfattn_layer = MultiHeadedAttention | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
) | |||||
elif selfattention_layer_type == 'legacy_rel_selfattn': | |||||
assert pos_enc_layer_type == 'legacy_rel_pos' | |||||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
) | |||||
logging.warning( | |||||
'Using legacy_rel_selfattn and it will be deprecated in the future.' | |||||
) | |||||
elif selfattention_layer_type == 'legacy_rel_selfattnsanm': | |||||
assert pos_enc_layer_type == 'legacy_rel_pos' | |||||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttentionSANM | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
) | |||||
logging.warning( | |||||
'Using legacy_rel_selfattn and it will be deprecated in the future.' | |||||
) | |||||
elif selfattention_layer_type == 'rel_selfattn': | |||||
assert pos_enc_layer_type == 'rel_pos' | |||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
zero_triu, | |||||
) | |||||
elif selfattention_layer_type == 'rel_selfattnsanm': | |||||
assert pos_enc_layer_type == 'rel_pos' | |||||
encoder_selfattn_layer = RelPositionMultiHeadedAttentionSANM | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
zero_triu, | |||||
cnn_module_kernel, | |||||
sanm_shfit, | |||||
) | |||||
else: | |||||
raise ValueError('unknown encoder_attn_layer: ' | |||||
+ selfattention_layer_type) | |||||
convolution_layer = ConvolutionModule | |||||
convolution_layer_args = (output_size, cnn_module_kernel, activation) | |||||
if isinstance(stochastic_depth_rate, float): | |||||
stochastic_depth_rate = [stochastic_depth_rate] * num_blocks | |||||
if len(stochastic_depth_rate) != num_blocks: | |||||
raise ValueError( | |||||
f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) ' | |||||
f'should be equal to num_blocks ({num_blocks})') | |||||
self.encoders = repeat( | |||||
num_blocks, | |||||
lambda lnum: EncoderLayer( | |||||
output_size, | |||||
encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||||
positionwise_layer(*positionwise_layer_args), | |||||
positionwise_layer(*positionwise_layer_args) | |||||
if macaron_style else None, | |||||
convolution_layer(*convolution_layer_args) | |||||
if use_cnn_module else None, | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
stochastic_depth_rate[lnum], | |||||
), | |||||
) | |||||
if self.normalize_before: | |||||
self.after_norm = LayerNorm(output_size) | |||||
self.interctc_layer_idx = interctc_layer_idx | |||||
if len(interctc_layer_idx) > 0: | |||||
assert 0 < min(interctc_layer_idx) and max( | |||||
interctc_layer_idx) < num_blocks | |||||
self.interctc_use_conditioning = interctc_use_conditioning | |||||
self.conditioning_layer = None | |||||
def output_size(self) -> int: | |||||
return self._output_size | |||||
def forward( | |||||
self, | |||||
xs_pad: torch.Tensor, | |||||
ilens: torch.Tensor, | |||||
prev_states: torch.Tensor = None, | |||||
ctc: CTC = None, | |||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||||
"""Calculate forward propagation. | |||||
Args: | |||||
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). | |||||
ilens (torch.Tensor): Input length (#batch). | |||||
prev_states (torch.Tensor): Not to be used now. | |||||
Returns: | |||||
torch.Tensor: Output tensor (#batch, L, output_size). | |||||
torch.Tensor: Output length (#batch). | |||||
torch.Tensor: Not to be used now. | |||||
""" | |||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||||
if (isinstance(self.embed, Conv2dSubsampling) | |||||
or isinstance(self.embed, Conv2dSubsampling2) | |||||
or isinstance(self.embed, Conv2dSubsampling6) | |||||
or isinstance(self.embed, Conv2dSubsampling8)): | |||||
short_status, limit_size = check_short_utt(self.embed, | |||||
xs_pad.size(1)) | |||||
if short_status: | |||||
raise TooShortUttError( | |||||
f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||||
+ # noqa: * | |||||
f'(it needs more than {limit_size} frames), return empty results', | |||||
xs_pad.size(1), | |||||
limit_size) # noqa: * | |||||
xs_pad, masks = self.embed(xs_pad, masks) | |||||
else: | |||||
xs_pad = self.embed(xs_pad) | |||||
intermediate_outs = [] | |||||
if len(self.interctc_layer_idx) == 0: | |||||
xs_pad, masks = self.encoders(xs_pad, masks) | |||||
else: | |||||
for layer_idx, encoder_layer in enumerate(self.encoders): | |||||
xs_pad, masks = encoder_layer(xs_pad, masks) | |||||
if layer_idx + 1 in self.interctc_layer_idx: | |||||
encoder_out = xs_pad | |||||
if isinstance(encoder_out, tuple): | |||||
encoder_out = encoder_out[0] | |||||
# intermediate outputs are also normalized | |||||
if self.normalize_before: | |||||
encoder_out = self.after_norm(encoder_out) | |||||
intermediate_outs.append((layer_idx + 1, encoder_out)) | |||||
if self.interctc_use_conditioning: | |||||
ctc_out = ctc.softmax(encoder_out) | |||||
if isinstance(xs_pad, tuple): | |||||
x, pos_emb = xs_pad | |||||
x = x + self.conditioning_layer(ctc_out) | |||||
xs_pad = (x, pos_emb) | |||||
else: | |||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||||
if isinstance(xs_pad, tuple): | |||||
xs_pad = xs_pad[0] | |||||
if self.normalize_before: | |||||
xs_pad = self.after_norm(xs_pad) | |||||
olens = masks.squeeze(1).sum(1) | |||||
if len(intermediate_outs) > 0: | |||||
return (xs_pad, intermediate_outs), olens, None | |||||
return xs_pad, olens, None |
@@ -1,500 +0,0 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
# Part of the implementation is borrowed from espnet/espnet. | |||||
"""Transformer encoder definition.""" | |||||
import logging | |||||
from typing import List, Optional, Sequence, Tuple, Union | |||||
import torch | |||||
from espnet2.asr.ctc import CTC | |||||
from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||||
from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
PositionalEncoding | |||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( | |||||
Conv1dLinear, MultiLayeredConv1d) | |||||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||||
PositionwiseFeedForward # noqa: H301 | |||||
from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||||
from espnet.nets.pytorch_backend.transformer.subsampling import ( | |||||
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, | |||||
Conv2dSubsampling8, TooShortUttError, check_short_utt) | |||||
from typeguard import check_argument_types | |||||
from ...asr.streaming_utilis.chunk_utilis import overlap_chunk | |||||
from ...nets.pytorch_backend.transformer.attention import ( | |||||
MultiHeadedAttention, MultiHeadedAttentionSANM) | |||||
from ...nets.pytorch_backend.transformer.encoder_layer import ( | |||||
EncoderLayer, EncoderLayerChunk) | |||||
class SANMEncoder(AbsEncoder): | |||||
"""Transformer encoder module. | |||||
Args: | |||||
input_size: input dim | |||||
output_size: dimension of attention | |||||
attention_heads: the number of heads of multi head attention | |||||
linear_units: the number of units of position-wise feed forward | |||||
num_blocks: the number of decoder blocks | |||||
dropout_rate: dropout rate | |||||
attention_dropout_rate: dropout rate in attention | |||||
positional_dropout_rate: dropout rate after adding positional encoding | |||||
input_layer: input layer type | |||||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||||
normalize_before: whether to use layer_norm before the first block | |||||
concat_after: whether to concat attention layer's input and output | |||||
if True, additional linear will be applied. | |||||
i.e. x -> x + linear(concat(x, att(x))) | |||||
if False, no additional linear will be applied. | |||||
i.e. x -> x + att(x) | |||||
positionwise_layer_type: linear of conv1d | |||||
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer | |||||
padding_idx: padding_idx for input_layer=embed | |||||
""" | |||||
def __init__( | |||||
self, | |||||
input_size: int, | |||||
output_size: int = 256, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
attention_dropout_rate: float = 0.0, | |||||
input_layer: Optional[str] = 'conv2d', | |||||
pos_enc_class=PositionalEncoding, | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
positionwise_layer_type: str = 'linear', | |||||
positionwise_conv_kernel_size: int = 1, | |||||
padding_idx: int = -1, | |||||
interctc_layer_idx: List[int] = [], | |||||
interctc_use_conditioning: bool = False, | |||||
kernel_size: int = 11, | |||||
sanm_shfit: int = 0, | |||||
selfattention_layer_type: str = 'sanm', | |||||
): | |||||
assert check_argument_types() | |||||
super().__init__() | |||||
self._output_size = output_size | |||||
if input_layer == 'linear': | |||||
self.embed = torch.nn.Sequential( | |||||
torch.nn.Linear(input_size, output_size), | |||||
torch.nn.LayerNorm(output_size), | |||||
torch.nn.Dropout(dropout_rate), | |||||
torch.nn.ReLU(), | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'conv2d': | |||||
self.embed = Conv2dSubsampling(input_size, output_size, | |||||
dropout_rate) | |||||
elif input_layer == 'conv2d2': | |||||
self.embed = Conv2dSubsampling2(input_size, output_size, | |||||
dropout_rate) | |||||
elif input_layer == 'conv2d6': | |||||
self.embed = Conv2dSubsampling6(input_size, output_size, | |||||
dropout_rate) | |||||
elif input_layer == 'conv2d8': | |||||
self.embed = Conv2dSubsampling8(input_size, output_size, | |||||
dropout_rate) | |||||
elif input_layer == 'embed': | |||||
self.embed = torch.nn.Sequential( | |||||
torch.nn.Embedding( | |||||
input_size, output_size, padding_idx=padding_idx), | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer is None: | |||||
if input_size == output_size: | |||||
self.embed = None | |||||
else: | |||||
self.embed = torch.nn.Linear(input_size, output_size) | |||||
else: | |||||
raise ValueError('unknown input_layer: ' + input_layer) | |||||
self.normalize_before = normalize_before | |||||
if positionwise_layer_type == 'linear': | |||||
positionwise_layer = PositionwiseFeedForward | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
dropout_rate, | |||||
) | |||||
elif positionwise_layer_type == 'conv1d': | |||||
positionwise_layer = MultiLayeredConv1d | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
positionwise_conv_kernel_size, | |||||
dropout_rate, | |||||
) | |||||
elif positionwise_layer_type == 'conv1d-linear': | |||||
positionwise_layer = Conv1dLinear | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
positionwise_conv_kernel_size, | |||||
dropout_rate, | |||||
) | |||||
else: | |||||
raise NotImplementedError('Support only linear or conv1d.') | |||||
if selfattention_layer_type == 'selfattn': | |||||
encoder_selfattn_layer = MultiHeadedAttention | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
) | |||||
elif selfattention_layer_type == 'sanm': | |||||
encoder_selfattn_layer = MultiHeadedAttentionSANM | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
kernel_size, | |||||
sanm_shfit, | |||||
) | |||||
self.encoders = repeat( | |||||
num_blocks, | |||||
lambda lnum: EncoderLayer( | |||||
output_size, | |||||
encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||||
positionwise_layer(*positionwise_layer_args), | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
), | |||||
) | |||||
if self.normalize_before: | |||||
self.after_norm = LayerNorm(output_size) | |||||
self.interctc_layer_idx = interctc_layer_idx | |||||
if len(interctc_layer_idx) > 0: | |||||
assert 0 < min(interctc_layer_idx) and max( | |||||
interctc_layer_idx) < num_blocks | |||||
self.interctc_use_conditioning = interctc_use_conditioning | |||||
self.conditioning_layer = None | |||||
def output_size(self) -> int: | |||||
return self._output_size | |||||
def forward( | |||||
self, | |||||
xs_pad: torch.Tensor, | |||||
ilens: torch.Tensor, | |||||
prev_states: torch.Tensor = None, | |||||
ctc: CTC = None, | |||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||||
"""Embed positions in tensor. | |||||
Args: | |||||
xs_pad: input tensor (B, L, D) | |||||
ilens: input length (B) | |||||
prev_states: Not to be used now. | |||||
Returns: | |||||
position embedded tensor and mask | |||||
""" | |||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||||
if self.embed is None: | |||||
xs_pad = xs_pad | |||||
elif (isinstance(self.embed, Conv2dSubsampling) | |||||
or isinstance(self.embed, Conv2dSubsampling2) | |||||
or isinstance(self.embed, Conv2dSubsampling6) | |||||
or isinstance(self.embed, Conv2dSubsampling8)): | |||||
short_status, limit_size = check_short_utt(self.embed, | |||||
xs_pad.size(1)) | |||||
if short_status: | |||||
raise TooShortUttError( | |||||
f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||||
+ # noqa: * | |||||
f'(it needs more than {limit_size} frames), return empty results', | |||||
xs_pad.size(1), | |||||
limit_size, | |||||
) | |||||
xs_pad, masks = self.embed(xs_pad, masks) | |||||
else: | |||||
xs_pad = self.embed(xs_pad) | |||||
intermediate_outs = [] | |||||
if len(self.interctc_layer_idx) == 0: | |||||
xs_pad, masks = self.encoders(xs_pad, masks) | |||||
else: | |||||
for layer_idx, encoder_layer in enumerate(self.encoders): | |||||
xs_pad, masks = encoder_layer(xs_pad, masks) | |||||
if layer_idx + 1 in self.interctc_layer_idx: | |||||
encoder_out = xs_pad | |||||
# intermediate outputs are also normalized | |||||
if self.normalize_before: | |||||
encoder_out = self.after_norm(encoder_out) | |||||
intermediate_outs.append((layer_idx + 1, encoder_out)) | |||||
if self.interctc_use_conditioning: | |||||
ctc_out = ctc.softmax(encoder_out) | |||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||||
if self.normalize_before: | |||||
xs_pad = self.after_norm(xs_pad) | |||||
olens = masks.squeeze(1).sum(1) | |||||
if len(intermediate_outs) > 0: | |||||
return (xs_pad, intermediate_outs), olens, None | |||||
return xs_pad, olens, None | |||||
class SANMEncoderChunk(AbsEncoder): | |||||
"""Transformer encoder module. | |||||
Args: | |||||
input_size: input dim | |||||
output_size: dimension of attention | |||||
attention_heads: the number of heads of multi head attention | |||||
linear_units: the number of units of position-wise feed forward | |||||
num_blocks: the number of decoder blocks | |||||
dropout_rate: dropout rate | |||||
attention_dropout_rate: dropout rate in attention | |||||
positional_dropout_rate: dropout rate after adding positional encoding | |||||
input_layer: input layer type | |||||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||||
normalize_before: whether to use layer_norm before the first block | |||||
concat_after: whether to concat attention layer's input and output | |||||
if True, additional linear will be applied. | |||||
i.e. x -> x + linear(concat(x, att(x))) | |||||
if False, no additional linear will be applied. | |||||
i.e. x -> x + att(x) | |||||
positionwise_layer_type: linear of conv1d | |||||
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer | |||||
padding_idx: padding_idx for input_layer=embed | |||||
""" | |||||
def __init__( | |||||
self, | |||||
input_size: int, | |||||
output_size: int = 256, | |||||
attention_heads: int = 4, | |||||
linear_units: int = 2048, | |||||
num_blocks: int = 6, | |||||
dropout_rate: float = 0.1, | |||||
positional_dropout_rate: float = 0.1, | |||||
attention_dropout_rate: float = 0.0, | |||||
input_layer: Optional[str] = 'conv2d', | |||||
pos_enc_class=PositionalEncoding, | |||||
normalize_before: bool = True, | |||||
concat_after: bool = False, | |||||
positionwise_layer_type: str = 'linear', | |||||
positionwise_conv_kernel_size: int = 1, | |||||
padding_idx: int = -1, | |||||
interctc_layer_idx: List[int] = [], | |||||
interctc_use_conditioning: bool = False, | |||||
kernel_size: int = 11, | |||||
sanm_shfit: int = 0, | |||||
selfattention_layer_type: str = 'sanm', | |||||
chunk_size: Union[int, Sequence[int]] = (16, ), | |||||
stride: Union[int, Sequence[int]] = (10, ), | |||||
pad_left: Union[int, Sequence[int]] = (0, ), | |||||
encoder_att_look_back_factor: Union[int, Sequence[int]] = (1, ), | |||||
): | |||||
assert check_argument_types() | |||||
super().__init__() | |||||
self._output_size = output_size | |||||
if input_layer == 'linear': | |||||
self.embed = torch.nn.Sequential( | |||||
torch.nn.Linear(input_size, output_size), | |||||
torch.nn.LayerNorm(output_size), | |||||
torch.nn.Dropout(dropout_rate), | |||||
torch.nn.ReLU(), | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer == 'conv2d': | |||||
self.embed = Conv2dSubsampling(input_size, output_size, | |||||
dropout_rate) | |||||
elif input_layer == 'conv2d2': | |||||
self.embed = Conv2dSubsampling2(input_size, output_size, | |||||
dropout_rate) | |||||
elif input_layer == 'conv2d6': | |||||
self.embed = Conv2dSubsampling6(input_size, output_size, | |||||
dropout_rate) | |||||
elif input_layer == 'conv2d8': | |||||
self.embed = Conv2dSubsampling8(input_size, output_size, | |||||
dropout_rate) | |||||
elif input_layer == 'embed': | |||||
self.embed = torch.nn.Sequential( | |||||
torch.nn.Embedding( | |||||
input_size, output_size, padding_idx=padding_idx), | |||||
pos_enc_class(output_size, positional_dropout_rate), | |||||
) | |||||
elif input_layer is None: | |||||
if input_size == output_size: | |||||
self.embed = None | |||||
else: | |||||
self.embed = torch.nn.Linear(input_size, output_size) | |||||
else: | |||||
raise ValueError('unknown input_layer: ' + input_layer) | |||||
self.normalize_before = normalize_before | |||||
if positionwise_layer_type == 'linear': | |||||
positionwise_layer = PositionwiseFeedForward | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
dropout_rate, | |||||
) | |||||
elif positionwise_layer_type == 'conv1d': | |||||
positionwise_layer = MultiLayeredConv1d | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
positionwise_conv_kernel_size, | |||||
dropout_rate, | |||||
) | |||||
elif positionwise_layer_type == 'conv1d-linear': | |||||
positionwise_layer = Conv1dLinear | |||||
positionwise_layer_args = ( | |||||
output_size, | |||||
linear_units, | |||||
positionwise_conv_kernel_size, | |||||
dropout_rate, | |||||
) | |||||
else: | |||||
raise NotImplementedError('Support only linear or conv1d.') | |||||
if selfattention_layer_type == 'selfattn': | |||||
encoder_selfattn_layer = MultiHeadedAttention | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
) | |||||
elif selfattention_layer_type == 'sanm': | |||||
encoder_selfattn_layer = MultiHeadedAttentionSANM | |||||
encoder_selfattn_layer_args = ( | |||||
attention_heads, | |||||
output_size, | |||||
attention_dropout_rate, | |||||
kernel_size, | |||||
sanm_shfit, | |||||
) | |||||
self.encoders = repeat( | |||||
num_blocks, | |||||
lambda lnum: EncoderLayerChunk( | |||||
output_size, | |||||
encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||||
positionwise_layer(*positionwise_layer_args), | |||||
dropout_rate, | |||||
normalize_before, | |||||
concat_after, | |||||
), | |||||
) | |||||
if self.normalize_before: | |||||
self.after_norm = LayerNorm(output_size) | |||||
self.interctc_layer_idx = interctc_layer_idx | |||||
if len(interctc_layer_idx) > 0: | |||||
assert 0 < min(interctc_layer_idx) and max( | |||||
interctc_layer_idx) < num_blocks | |||||
self.interctc_use_conditioning = interctc_use_conditioning | |||||
self.conditioning_layer = None | |||||
shfit_fsmn = (kernel_size - 1) // 2 | |||||
self.overlap_chunk_cls = overlap_chunk( | |||||
chunk_size=chunk_size, | |||||
stride=stride, | |||||
pad_left=pad_left, | |||||
shfit_fsmn=shfit_fsmn, | |||||
encoder_att_look_back_factor=encoder_att_look_back_factor, | |||||
) | |||||
def output_size(self) -> int: | |||||
return self._output_size | |||||
def forward( | |||||
self, | |||||
xs_pad: torch.Tensor, | |||||
ilens: torch.Tensor, | |||||
prev_states: torch.Tensor = None, | |||||
ctc: CTC = None, | |||||
ind: int = 0, | |||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||||
"""Embed positions in tensor. | |||||
Args: | |||||
xs_pad: input tensor (B, L, D) | |||||
ilens: input length (B) | |||||
prev_states: Not to be used now. | |||||
Returns: | |||||
position embedded tensor and mask | |||||
""" | |||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||||
if self.embed is None: | |||||
xs_pad = xs_pad | |||||
elif (isinstance(self.embed, Conv2dSubsampling) | |||||
or isinstance(self.embed, Conv2dSubsampling2) | |||||
or isinstance(self.embed, Conv2dSubsampling6) | |||||
or isinstance(self.embed, Conv2dSubsampling8)): | |||||
short_status, limit_size = check_short_utt(self.embed, | |||||
xs_pad.size(1)) | |||||
if short_status: | |||||
raise TooShortUttError( | |||||
f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||||
+ # noqa: * | |||||
f'(it needs more than {limit_size} frames), return empty results', | |||||
xs_pad.size(1), | |||||
limit_size, | |||||
) | |||||
xs_pad, masks = self.embed(xs_pad, masks) | |||||
else: | |||||
xs_pad = self.embed(xs_pad) | |||||
mask_shfit_chunk, mask_att_chunk_encoder = None, None | |||||
if self.overlap_chunk_cls is not None: | |||||
ilens = masks.squeeze(1).sum(1) | |||||
chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind) | |||||
xs_pad, ilens = self.overlap_chunk_cls.split_chunk( | |||||
xs_pad, ilens, chunk_outs=chunk_outs) | |||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||||
mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk( | |||||
chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype) | |||||
mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder( | |||||
chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype) | |||||
intermediate_outs = [] | |||||
if len(self.interctc_layer_idx) == 0: | |||||
xs_pad, masks, _, _, _ = self.encoders(xs_pad, masks, None, | |||||
mask_shfit_chunk, | |||||
mask_att_chunk_encoder) | |||||
else: | |||||
for layer_idx, encoder_layer in enumerate(self.encoders): | |||||
xs_pad, masks, _, _, _ = encoder_layer(xs_pad, masks, None, | |||||
mask_shfit_chunk, | |||||
mask_att_chunk_encoder) | |||||
if layer_idx + 1 in self.interctc_layer_idx: | |||||
encoder_out = xs_pad | |||||
# intermediate outputs are also normalized | |||||
if self.normalize_before: | |||||
encoder_out = self.after_norm(encoder_out) | |||||
intermediate_outs.append((layer_idx + 1, encoder_out)) | |||||
if self.interctc_use_conditioning: | |||||
ctc_out = ctc.softmax(encoder_out) | |||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||||
if self.normalize_before: | |||||
xs_pad = self.after_norm(xs_pad) | |||||
if self.overlap_chunk_cls is not None: | |||||
xs_pad, olens = self.overlap_chunk_cls.remove_chunk( | |||||
xs_pad, ilens, chunk_outs) | |||||
if len(intermediate_outs) > 0: | |||||
return (xs_pad, intermediate_outs), olens, None | |||||
return xs_pad, olens, None |
@@ -1,113 +0,0 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
# Part of the implementation is borrowed from espnet/espnet. | |||||
import copy | |||||
from typing import Optional, Tuple, Union | |||||
import humanfriendly | |||||
import numpy as np | |||||
import torch | |||||
import torchaudio | |||||
import torchaudio.compliance.kaldi as kaldi | |||||
from espnet2.asr.frontend.abs_frontend import AbsFrontend | |||||
from espnet2.layers.log_mel import LogMel | |||||
from espnet2.layers.stft import Stft | |||||
from espnet2.utils.get_default_kwargs import get_default_kwargs | |||||
from espnet.nets.pytorch_backend.frontends.frontend import Frontend | |||||
from typeguard import check_argument_types | |||||
class WavFrontend(AbsFrontend): | |||||
"""Conventional frontend structure for ASR. | |||||
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN | |||||
""" | |||||
def __init__( | |||||
self, | |||||
fs: Union[int, str] = 16000, | |||||
n_fft: int = 512, | |||||
win_length: int = 400, | |||||
hop_length: int = 160, | |||||
window: Optional[str] = 'hamming', | |||||
center: bool = True, | |||||
normalized: bool = False, | |||||
onesided: bool = True, | |||||
n_mels: int = 80, | |||||
fmin: int = None, | |||||
fmax: int = None, | |||||
htk: bool = False, | |||||
frontend_conf: Optional[dict] = get_default_kwargs(Frontend), | |||||
apply_stft: bool = True, | |||||
): | |||||
assert check_argument_types() | |||||
super().__init__() | |||||
if isinstance(fs, str): | |||||
fs = humanfriendly.parse_size(fs) | |||||
# Deepcopy (In general, dict shouldn't be used as default arg) | |||||
frontend_conf = copy.deepcopy(frontend_conf) | |||||
self.hop_length = hop_length | |||||
self.win_length = win_length | |||||
self.window = window | |||||
self.fs = fs | |||||
if apply_stft: | |||||
self.stft = Stft( | |||||
n_fft=n_fft, | |||||
win_length=win_length, | |||||
hop_length=hop_length, | |||||
center=center, | |||||
window=window, | |||||
normalized=normalized, | |||||
onesided=onesided, | |||||
) | |||||
else: | |||||
self.stft = None | |||||
self.apply_stft = apply_stft | |||||
if frontend_conf is not None: | |||||
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf) | |||||
else: | |||||
self.frontend = None | |||||
self.logmel = LogMel( | |||||
fs=fs, | |||||
n_fft=n_fft, | |||||
n_mels=n_mels, | |||||
fmin=fmin, | |||||
fmax=fmax, | |||||
htk=htk, | |||||
) | |||||
self.n_mels = n_mels | |||||
self.frontend_type = 'default' | |||||
def output_size(self) -> int: | |||||
return self.n_mels | |||||
def forward( | |||||
self, input: torch.Tensor, | |||||
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
sample_frequency = self.fs | |||||
num_mel_bins = self.n_mels | |||||
frame_length = self.win_length * 1000 / sample_frequency | |||||
frame_shift = self.hop_length * 1000 / sample_frequency | |||||
waveform = input * (1 << 15) | |||||
mat = kaldi.fbank( | |||||
waveform, | |||||
num_mel_bins=num_mel_bins, | |||||
frame_length=frame_length, | |||||
frame_shift=frame_shift, | |||||
dither=1.0, | |||||
energy_floor=0.0, | |||||
window_type=self.window, | |||||
sample_frequency=sample_frequency) | |||||
input_feats = mat[None, :] | |||||
feats_lens = torch.randn(1) | |||||
feats_lens.fill_(input_feats.shape[1]) | |||||
return input_feats, feats_lens |
@@ -1,321 +0,0 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
# Part of the implementation is borrowed from espnet/espnet. | |||||
import logging | |||||
import math | |||||
import numpy as np | |||||
import torch | |||||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||||
from ...nets.pytorch_backend.cif_utils.cif import \ | |||||
cif_predictor as cif_predictor | |||||
np.set_printoptions(threshold=np.inf) | |||||
torch.set_printoptions(profile='full', precision=100000, linewidth=None) | |||||
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device='cpu'): | |||||
if maxlen is None: | |||||
maxlen = lengths.max() | |||||
row_vector = torch.arange(0, maxlen, 1) | |||||
matrix = torch.unsqueeze(lengths, dim=-1) | |||||
mask = row_vector < matrix | |||||
return mask.type(dtype).to(device) | |||||
class overlap_chunk(): | |||||
def __init__( | |||||
self, | |||||
chunk_size: tuple = (16, ), | |||||
stride: tuple = (10, ), | |||||
pad_left: tuple = (0, ), | |||||
encoder_att_look_back_factor: tuple = (1, ), | |||||
shfit_fsmn: int = 0, | |||||
): | |||||
self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor \ | |||||
= chunk_size, stride, pad_left, encoder_att_look_back_factor | |||||
self.shfit_fsmn = shfit_fsmn | |||||
self.x_add_mask = None | |||||
self.x_rm_mask = None | |||||
self.x_len = None | |||||
self.mask_shfit_chunk = None | |||||
self.mask_chunk_predictor = None | |||||
self.mask_att_chunk_encoder = None | |||||
self.mask_shift_att_chunk_decoder = None | |||||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur \ | |||||
= None, None, None, None | |||||
def get_chunk_size(self, ind: int = 0): | |||||
# with torch.no_grad: | |||||
chunk_size, stride, pad_left, encoder_att_look_back_factor = self.chunk_size[ | |||||
ind], self.stride[ind], self.pad_left[ | |||||
ind], self.encoder_att_look_back_factor[ind] | |||||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, | |||||
self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \ | |||||
= chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn | |||||
return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur | |||||
def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1): | |||||
with torch.no_grad(): | |||||
x_len = x_len.cpu().numpy() | |||||
x_len_max = x_len.max() | |||||
chunk_size, stride, pad_left, encoder_att_look_back_factor = self.get_chunk_size( | |||||
ind) | |||||
shfit_fsmn = self.shfit_fsmn | |||||
chunk_size_pad_shift = chunk_size + shfit_fsmn | |||||
chunk_num_batch = np.ceil(x_len / stride).astype(np.int32) | |||||
x_len_chunk = ( | |||||
chunk_num_batch - 1 | |||||
) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - ( | |||||
chunk_num_batch - 1) * stride | |||||
x_len_chunk = x_len_chunk.astype(x_len.dtype) | |||||
x_len_chunk_max = x_len_chunk.max() | |||||
chunk_num = int(math.ceil(x_len_max / stride)) | |||||
dtype = np.int32 | |||||
max_len_for_x_mask_tmp = max(chunk_size, x_len_max) | |||||
x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype) | |||||
x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype) | |||||
mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype) | |||||
mask_chunk_predictor = np.zeros([0, num_units_predictor], | |||||
dtype=dtype) | |||||
mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype) | |||||
mask_att_chunk_encoder = np.zeros( | |||||
[0, chunk_num * chunk_size_pad_shift], dtype=dtype) | |||||
for chunk_ids in range(chunk_num): | |||||
# x_mask add | |||||
fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), | |||||
dtype=dtype) | |||||
x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32)) | |||||
x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), | |||||
dtype=dtype) | |||||
x_mask_pad_right = np.zeros( | |||||
(chunk_size, max_len_for_x_mask_tmp), dtype=dtype) | |||||
x_cur_pad = np.concatenate( | |||||
[x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1) | |||||
x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp] | |||||
x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], | |||||
axis=0) | |||||
x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], | |||||
axis=0) | |||||
# x_mask rm | |||||
fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), | |||||
dtype=dtype) | |||||
x_mask_cur = np.diag(np.ones(stride, dtype=dtype)) | |||||
x_mask_right = np.zeros((stride, chunk_size - stride), | |||||
dtype=dtype) | |||||
x_mask_cur = np.concatenate([x_mask_cur, x_mask_right], axis=1) | |||||
x_mask_cur_pad_top = np.zeros((chunk_ids * stride, chunk_size), | |||||
dtype=dtype) | |||||
x_mask_cur_pad_bottom = np.zeros( | |||||
(max_len_for_x_mask_tmp, chunk_size), dtype=dtype) | |||||
x_rm_mask_cur = np.concatenate( | |||||
[x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], | |||||
axis=0) | |||||
x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, : | |||||
chunk_size] | |||||
x_rm_mask_cur_fsmn = np.concatenate( | |||||
[fsmn_padding, x_rm_mask_cur], axis=1) | |||||
x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], | |||||
axis=1) | |||||
# fsmn_padding_mask | |||||
pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype) | |||||
ones_1 = np.ones([chunk_size, num_units], dtype=dtype) | |||||
mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], | |||||
axis=0) | |||||
mask_shfit_chunk = np.concatenate( | |||||
[mask_shfit_chunk, mask_shfit_chunk_cur], axis=0) | |||||
# predictor mask | |||||
zeros_1 = np.zeros( | |||||
[shfit_fsmn + pad_left, num_units_predictor], dtype=dtype) | |||||
ones_2 = np.ones([stride, num_units_predictor], dtype=dtype) | |||||
zeros_3 = np.zeros( | |||||
[chunk_size - stride - pad_left, num_units_predictor], | |||||
dtype=dtype) | |||||
ones_zeros = np.concatenate([ones_2, zeros_3], axis=0) | |||||
mask_chunk_predictor_cur = np.concatenate( | |||||
[zeros_1, ones_zeros], axis=0) | |||||
mask_chunk_predictor = np.concatenate( | |||||
[mask_chunk_predictor, mask_chunk_predictor_cur], axis=0) | |||||
# encoder att mask | |||||
zeros_1_top = np.zeros( | |||||
[shfit_fsmn, chunk_num * chunk_size_pad_shift], | |||||
dtype=dtype) | |||||
zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0) | |||||
zeros_2 = np.zeros( | |||||
[chunk_size, zeros_2_num * chunk_size_pad_shift], | |||||
dtype=dtype) | |||||
encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0) | |||||
zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) | |||||
ones_2_mid = np.ones([stride, stride], dtype=dtype) | |||||
zeros_2_bottom = np.zeros([chunk_size - stride, stride], | |||||
dtype=dtype) | |||||
zeros_2_right = np.zeros([chunk_size, chunk_size - stride], | |||||
dtype=dtype) | |||||
ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0) | |||||
ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], | |||||
axis=1) | |||||
ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num]) | |||||
zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) | |||||
ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype) | |||||
ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1) | |||||
zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0) | |||||
zeros_remain = np.zeros( | |||||
[chunk_size, zeros_remain_num * chunk_size_pad_shift], | |||||
dtype=dtype) | |||||
ones2_bottom = np.concatenate( | |||||
[zeros_2, ones_2, ones_3, zeros_remain], axis=1) | |||||
mask_att_chunk_encoder_cur = np.concatenate( | |||||
[zeros_1_top, ones2_bottom], axis=0) | |||||
mask_att_chunk_encoder = np.concatenate( | |||||
[mask_att_chunk_encoder, mask_att_chunk_encoder_cur], | |||||
axis=0) | |||||
# decoder fsmn_shift_att_mask | |||||
zeros_1 = np.zeros([shfit_fsmn, 1]) | |||||
ones_1 = np.ones([chunk_size, 1]) | |||||
mask_shift_att_chunk_decoder_cur = np.concatenate( | |||||
[zeros_1, ones_1], axis=0) | |||||
mask_shift_att_chunk_decoder = np.concatenate( | |||||
[ | |||||
mask_shift_att_chunk_decoder, | |||||
mask_shift_att_chunk_decoder_cur | |||||
], | |||||
vaxis=0) # noqa: * | |||||
self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max] | |||||
self.x_len_chunk = x_len_chunk | |||||
self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max] | |||||
self.x_len = x_len | |||||
self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :] | |||||
self.mask_chunk_predictor = mask_chunk_predictor[: | |||||
x_len_chunk_max, :] | |||||
self.mask_att_chunk_encoder = mask_att_chunk_encoder[: | |||||
x_len_chunk_max, : | |||||
x_len_chunk_max] | |||||
self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[: | |||||
x_len_chunk_max, :] | |||||
return (self.x_add_mask, self.x_len_chunk, self.x_rm_mask, self.x_len, | |||||
self.mask_shfit_chunk, self.mask_chunk_predictor, | |||||
self.mask_att_chunk_encoder, self.mask_shift_att_chunk_decoder) | |||||
def split_chunk(self, x, x_len, chunk_outs): | |||||
""" | |||||
:param x: (b, t, d) | |||||
:param x_length: (b) | |||||
:param ind: int | |||||
:return: | |||||
""" | |||||
x = x[:, :x_len.max(), :] | |||||
b, t, d = x.size() | |||||
x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(x.device) | |||||
x *= x_len_mask[:, :, None] | |||||
x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype) | |||||
x_len_chunk = self.get_x_len_chunk( | |||||
chunk_outs, x_len.device, dtype=x_len.dtype) | |||||
x = torch.transpose(x, 1, 0) | |||||
x = torch.reshape(x, [t, -1]) | |||||
x_chunk = torch.mm(x_add_mask, x) | |||||
x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0) | |||||
return x_chunk, x_len_chunk | |||||
def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs): | |||||
x_chunk = x_chunk[:, :x_len_chunk.max(), :] | |||||
b, t, d = x_chunk.size() | |||||
x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to( | |||||
x_chunk.device) | |||||
x_chunk *= x_len_chunk_mask[:, :, None] | |||||
x_rm_mask = self.get_x_rm_mask( | |||||
chunk_outs, x_chunk.device, dtype=x_chunk.dtype) | |||||
x_len = self.get_x_len( | |||||
chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype) | |||||
x_chunk = torch.transpose(x_chunk, 1, 0) | |||||
x_chunk = torch.reshape(x_chunk, [t, -1]) | |||||
x = torch.mm(x_rm_mask, x_chunk) | |||||
x = torch.reshape(x, [-1, b, d]).transpose(1, 0) | |||||
return x, x_len | |||||
def get_x_add_mask(self, chunk_outs, device, idx=0, dtype=torch.float32): | |||||
x = chunk_outs[idx] | |||||
x = torch.from_numpy(x).type(dtype).to(device) | |||||
return x.detach() | |||||
def get_x_len_chunk(self, chunk_outs, device, idx=1, dtype=torch.float32): | |||||
x = chunk_outs[idx] | |||||
x = torch.from_numpy(x).type(dtype).to(device) | |||||
return x.detach() | |||||
def get_x_rm_mask(self, chunk_outs, device, idx=2, dtype=torch.float32): | |||||
x = chunk_outs[idx] | |||||
x = torch.from_numpy(x).type(dtype).to(device) | |||||
return x.detach() | |||||
def get_x_len(self, chunk_outs, device, idx=3, dtype=torch.float32): | |||||
x = chunk_outs[idx] | |||||
x = torch.from_numpy(x).type(dtype).to(device) | |||||
return x.detach() | |||||
def get_mask_shfit_chunk(self, | |||||
chunk_outs, | |||||
device, | |||||
batch_size=1, | |||||
num_units=1, | |||||
idx=4, | |||||
dtype=torch.float32): | |||||
x = chunk_outs[idx] | |||||
x = np.tile(x[None, :, :, ], [batch_size, 1, num_units]) | |||||
x = torch.from_numpy(x).type(dtype).to(device) | |||||
return x.detach() | |||||
def get_mask_chunk_predictor(self, | |||||
chunk_outs, | |||||
device, | |||||
batch_size=1, | |||||
num_units=1, | |||||
idx=5, | |||||
dtype=torch.float32): | |||||
x = chunk_outs[idx] | |||||
x = np.tile(x[None, :, :, ], [batch_size, 1, num_units]) | |||||
x = torch.from_numpy(x).type(dtype).to(device) | |||||
return x.detach() | |||||
def get_mask_att_chunk_encoder(self, | |||||
chunk_outs, | |||||
device, | |||||
batch_size=1, | |||||
idx=6, | |||||
dtype=torch.float32): | |||||
x = chunk_outs[idx] | |||||
x = np.tile(x[None, :, :, ], [batch_size, 1, 1]) | |||||
x = torch.from_numpy(x).type(dtype).to(device) | |||||
return x.detach() | |||||
def get_mask_shift_att_chunk_decoder(self, | |||||
chunk_outs, | |||||
device, | |||||
batch_size=1, | |||||
idx=7, | |||||
dtype=torch.float32): | |||||
x = chunk_outs[idx] | |||||
x = np.tile(x[None, None, :, 0], [batch_size, 1, 1]) | |||||
x = torch.from_numpy(x).type(dtype).to(device) | |||||
return x.detach() |
@@ -1,250 +0,0 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
# Part of the implementation is borrowed from espnet/espnet. | |||||
import logging | |||||
import numpy as np | |||||
import torch | |||||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||||
from torch import nn | |||||
class CIF_Model(nn.Module): | |||||
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1): | |||||
super(CIF_Model, self).__init__() | |||||
self.pad = nn.ConstantPad1d((l_order, r_order), 0) | |||||
self.cif_conv1d = nn.Conv1d( | |||||
idim, idim, l_order + r_order + 1, groups=idim) | |||||
self.cif_output = nn.Linear(idim, 1) | |||||
self.dropout = torch.nn.Dropout(p=dropout) | |||||
self.threshold = threshold | |||||
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1): | |||||
h = hidden | |||||
context = h.transpose(1, 2) | |||||
queries = self.pad(context) | |||||
memory = self.cif_conv1d(queries) | |||||
output = memory + context | |||||
output = self.dropout(output) | |||||
output = output.transpose(1, 2) | |||||
output = torch.relu(output) | |||||
output = self.cif_output(output) | |||||
alphas = torch.sigmoid(output) | |||||
if mask is not None: | |||||
alphas = alphas * mask.transpose(-1, -2).float() | |||||
alphas = alphas.squeeze(-1) | |||||
if target_label is not None: | |||||
target_length = (target_label != ignore_id).float().sum(-1) | |||||
else: | |||||
target_length = None | |||||
cif_length = alphas.sum(-1) | |||||
if target_label is not None: | |||||
alphas *= (target_length / cif_length)[:, None].repeat( | |||||
1, alphas.size(1)) | |||||
cif_output, cif_peak = cif(hidden, alphas, self.threshold) | |||||
return cif_output, cif_length, target_length, cif_peak | |||||
def gen_frame_alignments(self, | |||||
alphas: torch.Tensor = None, | |||||
memory_sequence_length: torch.Tensor = None, | |||||
is_training: bool = True, | |||||
dtype: torch.dtype = torch.float32): | |||||
batch_size, maximum_length = alphas.size() | |||||
int_type = torch.int32 | |||||
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) | |||||
max_token_num = torch.max(token_num).item() | |||||
alphas_cumsum = torch.cumsum(alphas, dim=1) | |||||
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) | |||||
alphas_cumsum = torch.tile(alphas_cumsum[:, None, :], | |||||
[1, max_token_num, 1]) | |||||
index = torch.ones([batch_size, max_token_num], dtype=int_type) | |||||
index = torch.cumsum(index, dim=1) | |||||
index = torch.tile(index[:, :, None], [1, 1, maximum_length]) | |||||
index_div = torch.floor(torch.divide(alphas_cumsum, | |||||
index)).type(int_type) | |||||
index_div_bool_zeros = index_div.eq(0) | |||||
index_div_bool_zeros_count = torch.sum( | |||||
index_div_bool_zeros, dim=-1) + 1 | |||||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0, | |||||
memory_sequence_length.max()) | |||||
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to( | |||||
token_num.device) | |||||
index_div_bool_zeros_count *= token_num_mask | |||||
index_div_bool_zeros_count_tile = torch.tile( | |||||
index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length]) | |||||
ones = torch.ones_like(index_div_bool_zeros_count_tile) | |||||
zeros = torch.zeros_like(index_div_bool_zeros_count_tile) | |||||
ones = torch.cumsum(ones, dim=2) | |||||
cond = index_div_bool_zeros_count_tile == ones | |||||
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) | |||||
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type( | |||||
torch.bool) | |||||
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type( | |||||
int_type) | |||||
index_div_bool_zeros_count_tile_out = torch.sum( | |||||
index_div_bool_zeros_count_tile, dim=1) | |||||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type( | |||||
int_type) | |||||
predictor_mask = (~make_pad_mask( | |||||
memory_sequence_length, | |||||
maxlen=memory_sequence_length.max())).type(int_type).to( | |||||
memory_sequence_length.device) # noqa: * | |||||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask | |||||
return index_div_bool_zeros_count_tile_out.detach( | |||||
), index_div_bool_zeros_count.detach() | |||||
class cif_predictor(nn.Module): | |||||
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1): | |||||
super(cif_predictor, self).__init__() | |||||
self.pad = nn.ConstantPad1d((l_order, r_order), 0) | |||||
self.cif_conv1d = nn.Conv1d( | |||||
idim, idim, l_order + r_order + 1, groups=idim) | |||||
self.cif_output = nn.Linear(idim, 1) | |||||
self.dropout = torch.nn.Dropout(p=dropout) | |||||
self.threshold = threshold | |||||
def forward(self, | |||||
hidden, | |||||
target_label=None, | |||||
mask=None, | |||||
ignore_id=-1, | |||||
mask_chunk_predictor=None, | |||||
target_label_length=None): | |||||
h = hidden | |||||
context = h.transpose(1, 2) | |||||
queries = self.pad(context) | |||||
memory = self.cif_conv1d(queries) | |||||
output = memory + context | |||||
output = self.dropout(output) | |||||
output = output.transpose(1, 2) | |||||
output = torch.relu(output) | |||||
output = self.cif_output(output) | |||||
alphas = torch.sigmoid(output) | |||||
if mask is not None: | |||||
alphas = alphas * mask.transpose(-1, -2).float() | |||||
if mask_chunk_predictor is not None: | |||||
alphas = alphas * mask_chunk_predictor | |||||
alphas = alphas.squeeze(-1) | |||||
if target_label_length is not None: | |||||
target_length = target_label_length | |||||
elif target_label is not None: | |||||
target_length = (target_label != ignore_id).float().sum(-1) | |||||
else: | |||||
target_length = None | |||||
token_num = alphas.sum(-1) | |||||
if target_length is not None: | |||||
alphas *= (target_length / token_num)[:, None].repeat( | |||||
1, alphas.size(1)) | |||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) | |||||
return acoustic_embeds, token_num, alphas, cif_peak | |||||
def gen_frame_alignments(self, | |||||
alphas: torch.Tensor = None, | |||||
memory_sequence_length: torch.Tensor = None, | |||||
is_training: bool = True, | |||||
dtype: torch.dtype = torch.float32): | |||||
batch_size, maximum_length = alphas.size() | |||||
int_type = torch.int32 | |||||
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) | |||||
max_token_num = torch.max(token_num).item() | |||||
alphas_cumsum = torch.cumsum(alphas, dim=1) | |||||
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) | |||||
alphas_cumsum = torch.tile(alphas_cumsum[:, None, :], | |||||
[1, max_token_num, 1]) | |||||
index = torch.ones([batch_size, max_token_num], dtype=int_type) | |||||
index = torch.cumsum(index, dim=1) | |||||
index = torch.tile(index[:, :, None], [1, 1, maximum_length]) | |||||
index_div = torch.floor(torch.divide(alphas_cumsum, | |||||
index)).type(int_type) | |||||
index_div_bool_zeros = index_div.eq(0) | |||||
index_div_bool_zeros_count = torch.sum( | |||||
index_div_bool_zeros, dim=-1) + 1 | |||||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0, | |||||
memory_sequence_length.max()) | |||||
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to( | |||||
token_num.device) | |||||
index_div_bool_zeros_count *= token_num_mask | |||||
index_div_bool_zeros_count_tile = torch.tile( | |||||
index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length]) | |||||
ones = torch.ones_like(index_div_bool_zeros_count_tile) | |||||
zeros = torch.zeros_like(index_div_bool_zeros_count_tile) | |||||
ones = torch.cumsum(ones, dim=2) | |||||
cond = index_div_bool_zeros_count_tile == ones | |||||
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) | |||||
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type( | |||||
torch.bool) | |||||
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type( | |||||
int_type) | |||||
index_div_bool_zeros_count_tile_out = torch.sum( | |||||
index_div_bool_zeros_count_tile, dim=1) | |||||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type( | |||||
int_type) | |||||
predictor_mask = (~make_pad_mask( | |||||
memory_sequence_length, | |||||
maxlen=memory_sequence_length.max())).type(int_type).to( | |||||
memory_sequence_length.device) # noqa: * | |||||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask | |||||
return index_div_bool_zeros_count_tile_out.detach( | |||||
), index_div_bool_zeros_count.detach() | |||||
def cif(hidden, alphas, threshold): | |||||
batch_size, len_time, hidden_size = hidden.size() | |||||
# loop varss | |||||
integrate = torch.zeros([batch_size], device=hidden.device) | |||||
frame = torch.zeros([batch_size, hidden_size], device=hidden.device) | |||||
# intermediate vars along time | |||||
list_fires = [] | |||||
list_frames = [] | |||||
for t in range(len_time): | |||||
alpha = alphas[:, t] | |||||
distribution_completion = torch.ones([batch_size], | |||||
device=hidden.device) - integrate | |||||
integrate += alpha | |||||
list_fires.append(integrate) | |||||
fire_place = integrate >= threshold | |||||
integrate = torch.where( | |||||
fire_place, | |||||
integrate - torch.ones([batch_size], device=hidden.device), | |||||
integrate) | |||||
cur = torch.where(fire_place, distribution_completion, alpha) | |||||
remainds = alpha - cur | |||||
frame += cur[:, None] * hidden[:, t, :] | |||||
list_frames.append(frame) | |||||
frame = torch.where(fire_place[:, None].repeat(1, hidden_size), | |||||
remainds[:, None] * hidden[:, t, :], frame) | |||||
fires = torch.stack(list_fires, 1) | |||||
frames = torch.stack(list_frames, 1) | |||||
list_ls = [] | |||||
len_labels = torch.round(alphas.sum(-1)).int() | |||||
max_label_len = len_labels.max() | |||||
for b in range(batch_size): | |||||
fire = fires[b, :] | |||||
ls = torch.index_select(frames[b, :, :], 0, | |||||
torch.nonzero(fire >= threshold).squeeze()) | |||||
pad_l = torch.zeros([max_label_len - ls.size(0), hidden_size], | |||||
device=hidden.device) | |||||
list_ls.append(torch.cat([ls, pad_l], 0)) | |||||
return torch.stack(list_ls, 0), fires |
@@ -1,680 +0,0 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
# Part of the implementation is borrowed from espnet/espnet. | |||||
"""Multi-Head Attention layer definition.""" | |||||
import logging | |||||
import math | |||||
import numpy | |||||
import torch | |||||
from torch import nn | |||||
torch.set_printoptions(profile='full', precision=1) | |||||
class MultiHeadedAttention(nn.Module): | |||||
"""Multi-Head Attention layer. | |||||
Args: | |||||
n_head (int): The number of heads. | |||||
n_feat (int): The number of features. | |||||
dropout_rate (float): Dropout rate. | |||||
""" | |||||
def __init__(self, n_head, n_feat, dropout_rate): | |||||
"""Construct an MultiHeadedAttention object.""" | |||||
super(MultiHeadedAttention, self).__init__() | |||||
assert n_feat % n_head == 0 | |||||
# We assume d_v always equals d_k | |||||
self.d_k = n_feat // n_head | |||||
self.h = n_head | |||||
self.linear_q = nn.Linear(n_feat, n_feat) | |||||
self.linear_k = nn.Linear(n_feat, n_feat) | |||||
self.linear_v = nn.Linear(n_feat, n_feat) | |||||
self.linear_out = nn.Linear(n_feat, n_feat) | |||||
self.attn = None | |||||
self.dropout = nn.Dropout(p=dropout_rate) | |||||
def forward_qkv(self, query, key, value): | |||||
"""Transform query, key and value. | |||||
Args: | |||||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
Returns: | |||||
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). | |||||
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). | |||||
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). | |||||
""" | |||||
n_batch = query.size(0) | |||||
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) | |||||
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) | |||||
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) | |||||
q = q.transpose(1, 2) # (batch, head, time1, d_k) | |||||
k = k.transpose(1, 2) # (batch, head, time2, d_k) | |||||
v = v.transpose(1, 2) # (batch, head, time2, d_k) | |||||
return q, k, v | |||||
def forward_attention(self, value, scores, mask): | |||||
"""Compute attention context vector. | |||||
Args: | |||||
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). | |||||
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). | |||||
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). | |||||
Returns: | |||||
torch.Tensor: Transformed value (#batch, time1, d_model) | |||||
weighted by the attention score (#batch, time1, time2). | |||||
""" | |||||
n_batch = value.size(0) | |||||
if mask is not None: | |||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) | |||||
min_value = float( | |||||
numpy.finfo(torch.tensor( | |||||
0, dtype=scores.dtype).numpy().dtype).min) | |||||
scores = scores.masked_fill(mask, min_value) | |||||
self.attn = torch.softmax( | |||||
scores, dim=-1).masked_fill(mask, | |||||
0.0) # (batch, head, time1, time2) | |||||
else: | |||||
self.attn = torch.softmax( | |||||
scores, dim=-1) # (batch, head, time1, time2) | |||||
p_attn = self.dropout(self.attn) | |||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) | |||||
x = (x.transpose(1, 2).contiguous().view(n_batch, -1, | |||||
self.h * self.d_k) | |||||
) # (batch, time1, d_model) | |||||
return self.linear_out(x) # (batch, time1, d_model) | |||||
def forward(self, query, key, value, mask): | |||||
"""Compute scaled dot product attention. | |||||
Args: | |||||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
(#batch, time1, time2). | |||||
Returns: | |||||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
""" | |||||
q, k, v = self.forward_qkv(query, key, value) | |||||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) | |||||
return self.forward_attention(v, scores, mask) | |||||
class MultiHeadedAttentionSANM(nn.Module): | |||||
"""Multi-Head Attention layer. | |||||
Args: | |||||
n_head (int): The number of heads. | |||||
n_feat (int): The number of features. | |||||
dropout_rate (float): Dropout rate. | |||||
""" | |||||
def __init__(self, | |||||
n_head, | |||||
n_feat, | |||||
dropout_rate, | |||||
kernel_size, | |||||
sanm_shfit=0): | |||||
"""Construct an MultiHeadedAttention object.""" | |||||
super(MultiHeadedAttentionSANM, self).__init__() | |||||
assert n_feat % n_head == 0 | |||||
# We assume d_v always equals d_k | |||||
self.d_k = n_feat // n_head | |||||
self.h = n_head | |||||
self.linear_q = nn.Linear(n_feat, n_feat) | |||||
self.linear_k = nn.Linear(n_feat, n_feat) | |||||
self.linear_v = nn.Linear(n_feat, n_feat) | |||||
self.linear_out = nn.Linear(n_feat, n_feat) | |||||
self.attn = None | |||||
self.dropout = nn.Dropout(p=dropout_rate) | |||||
self.fsmn_block = nn.Conv1d( | |||||
n_feat, | |||||
n_feat, | |||||
kernel_size, | |||||
stride=1, | |||||
padding=0, | |||||
groups=n_feat, | |||||
bias=False) | |||||
# padding | |||||
left_padding = (kernel_size - 1) // 2 | |||||
if sanm_shfit > 0: | |||||
left_padding = left_padding + sanm_shfit | |||||
right_padding = kernel_size - 1 - left_padding | |||||
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) | |||||
def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None): | |||||
''' | |||||
:param x: (#batch, time1, size). | |||||
:param mask: Mask tensor (#batch, 1, time) | |||||
:return: | |||||
''' | |||||
# b, t, d = inputs.size() | |||||
mask = mask[:, 0, :, None] | |||||
if mask_shfit_chunk is not None: | |||||
mask = mask * mask_shfit_chunk | |||||
inputs *= mask | |||||
x = inputs.transpose(1, 2) | |||||
x = self.pad_fn(x) | |||||
x = self.fsmn_block(x) | |||||
x = x.transpose(1, 2) | |||||
x += inputs | |||||
x = self.dropout(x) | |||||
return x * mask | |||||
def forward_qkv(self, query, key, value): | |||||
"""Transform query, key and value. | |||||
Args: | |||||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
Returns: | |||||
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). | |||||
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). | |||||
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). | |||||
""" | |||||
n_batch = query.size(0) | |||||
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) | |||||
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) | |||||
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) | |||||
q = q.transpose(1, 2) # (batch, head, time1, d_k) | |||||
k = k.transpose(1, 2) # (batch, head, time2, d_k) | |||||
v = v.transpose(1, 2) # (batch, head, time2, d_k) | |||||
return q, k, v | |||||
def forward_attention(self, | |||||
value, | |||||
scores, | |||||
mask, | |||||
mask_att_chunk_encoder=None): | |||||
"""Compute attention context vector. | |||||
Args: | |||||
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). | |||||
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). | |||||
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). | |||||
Returns: | |||||
torch.Tensor: Transformed value (#batch, time1, d_model) | |||||
weighted by the attention score (#batch, time1, time2). | |||||
""" | |||||
n_batch = value.size(0) | |||||
if mask is not None: | |||||
if mask_att_chunk_encoder is not None: | |||||
mask = mask * mask_att_chunk_encoder | |||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) | |||||
min_value = float( | |||||
numpy.finfo(torch.tensor( | |||||
0, dtype=scores.dtype).numpy().dtype).min) | |||||
scores = scores.masked_fill(mask, min_value) | |||||
self.attn = torch.softmax( | |||||
scores, dim=-1).masked_fill(mask, | |||||
0.0) # (batch, head, time1, time2) | |||||
else: | |||||
self.attn = torch.softmax( | |||||
scores, dim=-1) # (batch, head, time1, time2) | |||||
p_attn = self.dropout(self.attn) | |||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) | |||||
x = (x.transpose(1, 2).contiguous().view(n_batch, -1, | |||||
self.h * self.d_k) | |||||
) # (batch, time1, d_model) | |||||
return self.linear_out(x) # (batch, time1, d_model) | |||||
def forward(self, | |||||
query, | |||||
key, | |||||
value, | |||||
mask, | |||||
mask_shfit_chunk=None, | |||||
mask_att_chunk_encoder=None): | |||||
"""Compute scaled dot product attention. | |||||
Args: | |||||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
(#batch, time1, time2). | |||||
Returns: | |||||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
""" | |||||
fsmn_memory = self.forward_fsmn(value, mask, mask_shfit_chunk) | |||||
q, k, v = self.forward_qkv(query, key, value) | |||||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) | |||||
att_outs = self.forward_attention(v, scores, mask, | |||||
mask_att_chunk_encoder) | |||||
return att_outs + fsmn_memory | |||||
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): | |||||
"""Multi-Head Attention layer with relative position encoding (old version). | |||||
Details can be found in https://github.com/espnet/espnet/pull/2816. | |||||
Paper: https://arxiv.org/abs/1901.02860 | |||||
Args: | |||||
n_head (int): The number of heads. | |||||
n_feat (int): The number of features. | |||||
dropout_rate (float): Dropout rate. | |||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
""" | |||||
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): | |||||
"""Construct an RelPositionMultiHeadedAttention object.""" | |||||
super().__init__(n_head, n_feat, dropout_rate) | |||||
self.zero_triu = zero_triu | |||||
# linear transformation for positional encoding | |||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||||
# these two learnable bias are used in matrix c and matrix d | |||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||||
torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||||
def rel_shift(self, x): | |||||
"""Compute relative positional encoding. | |||||
Args: | |||||
x (torch.Tensor): Input tensor (batch, head, time1, time2). | |||||
Returns: | |||||
torch.Tensor: Output tensor. | |||||
""" | |||||
zero_pad = torch.zeros((*x.size()[:3], 1), | |||||
device=x.device, | |||||
dtype=x.dtype) | |||||
x_padded = torch.cat([zero_pad, x], dim=-1) | |||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||||
x = x_padded[:, :, 1:].view_as(x) | |||||
if self.zero_triu: | |||||
ones = torch.ones((x.size(2), x.size(3))) | |||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||||
return x | |||||
def forward(self, query, key, value, pos_emb, mask): | |||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||||
Args: | |||||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). | |||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
(#batch, time1, time2). | |||||
Returns: | |||||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
""" | |||||
q, k, v = self.forward_qkv(query, key, value) | |||||
q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||||
n_batch_pos = pos_emb.size(0) | |||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||||
p = p.transpose(1, 2) # (batch, head, time1, d_k) | |||||
# (batch, head, time1, d_k) | |||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||||
# (batch, head, time1, d_k) | |||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||||
# compute attention score | |||||
# first compute matrix a and matrix c | |||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
# (batch, head, time1, time2) | |||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||||
# compute matrix b and matrix d | |||||
# (batch, head, time1, time1) | |||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||||
matrix_bd = self.rel_shift(matrix_bd) | |||||
scores = (matrix_ac + matrix_bd) / math.sqrt( | |||||
self.d_k) # (batch, head, time1, time2) | |||||
return self.forward_attention(v, scores, mask) | |||||
class LegacyRelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM): | |||||
"""Multi-Head Attention layer with relative position encoding (old version). | |||||
Details can be found in https://github.com/espnet/espnet/pull/2816. | |||||
Paper: https://arxiv.org/abs/1901.02860 | |||||
Args: | |||||
n_head (int): The number of heads. | |||||
n_feat (int): The number of features. | |||||
dropout_rate (float): Dropout rate. | |||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
""" | |||||
def __init__(self, | |||||
n_head, | |||||
n_feat, | |||||
dropout_rate, | |||||
zero_triu=False, | |||||
kernel_size=15, | |||||
sanm_shfit=0): | |||||
"""Construct an RelPositionMultiHeadedAttention object.""" | |||||
super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit) | |||||
self.zero_triu = zero_triu | |||||
# linear transformation for positional encoding | |||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||||
# these two learnable bias are used in matrix c and matrix d | |||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||||
torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||||
def rel_shift(self, x): | |||||
"""Compute relative positional encoding. | |||||
Args: | |||||
x (torch.Tensor): Input tensor (batch, head, time1, time2). | |||||
Returns: | |||||
torch.Tensor: Output tensor. | |||||
""" | |||||
zero_pad = torch.zeros((*x.size()[:3], 1), | |||||
device=x.device, | |||||
dtype=x.dtype) | |||||
x_padded = torch.cat([zero_pad, x], dim=-1) | |||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||||
x = x_padded[:, :, 1:].view_as(x) | |||||
if self.zero_triu: | |||||
ones = torch.ones((x.size(2), x.size(3))) | |||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||||
return x | |||||
def forward(self, query, key, value, pos_emb, mask): | |||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||||
Args: | |||||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). | |||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
(#batch, time1, time2). | |||||
Returns: | |||||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
""" | |||||
fsmn_memory = self.forward_fsmn(value, mask) | |||||
q, k, v = self.forward_qkv(query, key, value) | |||||
q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||||
n_batch_pos = pos_emb.size(0) | |||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||||
p = p.transpose(1, 2) # (batch, head, time1, d_k) | |||||
# (batch, head, time1, d_k) | |||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||||
# (batch, head, time1, d_k) | |||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||||
# compute attention score | |||||
# first compute matrix a and matrix c | |||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
# (batch, head, time1, time2) | |||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||||
# compute matrix b and matrix d | |||||
# (batch, head, time1, time1) | |||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||||
matrix_bd = self.rel_shift(matrix_bd) | |||||
scores = (matrix_ac + matrix_bd) / math.sqrt( | |||||
self.d_k) # (batch, head, time1, time2) | |||||
att_outs = self.forward_attention(v, scores, mask) | |||||
return att_outs + fsmn_memory | |||||
class RelPositionMultiHeadedAttention(MultiHeadedAttention): | |||||
"""Multi-Head Attention layer with relative position encoding (new implementation). | |||||
Details can be found in https://github.com/espnet/espnet/pull/2816. | |||||
Paper: https://arxiv.org/abs/1901.02860 | |||||
Args: | |||||
n_head (int): The number of heads. | |||||
n_feat (int): The number of features. | |||||
dropout_rate (float): Dropout rate. | |||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
""" | |||||
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): | |||||
"""Construct an RelPositionMultiHeadedAttention object.""" | |||||
super().__init__(n_head, n_feat, dropout_rate) | |||||
self.zero_triu = zero_triu | |||||
# linear transformation for positional encoding | |||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||||
# these two learnable bias are used in matrix c and matrix d | |||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||||
torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||||
def rel_shift(self, x): | |||||
"""Compute relative positional encoding. | |||||
Args: | |||||
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). | |||||
time1 means the length of query vector. | |||||
Returns: | |||||
torch.Tensor: Output tensor. | |||||
""" | |||||
zero_pad = torch.zeros((*x.size()[:3], 1), | |||||
device=x.device, | |||||
dtype=x.dtype) | |||||
x_padded = torch.cat([zero_pad, x], dim=-1) | |||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||||
x = x_padded[:, :, 1:].view_as( | |||||
x)[:, :, :, :x.size(-1) // 2 | |||||
+ 1] # only keep the positions from 0 to time2 | |||||
if self.zero_triu: | |||||
ones = torch.ones((x.size(2), x.size(3)), device=x.device) | |||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||||
return x | |||||
def forward(self, query, key, value, pos_emb, mask): | |||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||||
Args: | |||||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
pos_emb (torch.Tensor): Positional embedding tensor | |||||
(#batch, 2*time1-1, size). | |||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
(#batch, time1, time2). | |||||
Returns: | |||||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
""" | |||||
q, k, v = self.forward_qkv(query, key, value) | |||||
q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||||
n_batch_pos = pos_emb.size(0) | |||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) | |||||
# (batch, head, time1, d_k) | |||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||||
# (batch, head, time1, d_k) | |||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||||
# compute attention score | |||||
# first compute matrix a and matrix c | |||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
# (batch, head, time1, time2) | |||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||||
# compute matrix b and matrix d | |||||
# (batch, head, time1, 2*time1-1) | |||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||||
matrix_bd = self.rel_shift(matrix_bd) | |||||
scores = (matrix_ac + matrix_bd) / math.sqrt( | |||||
self.d_k) # (batch, head, time1, time2) | |||||
return self.forward_attention(v, scores, mask) | |||||
class RelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM): | |||||
"""Multi-Head Attention layer with relative position encoding (new implementation). | |||||
Details can be found in https://github.com/espnet/espnet/pull/2816. | |||||
Paper: https://arxiv.org/abs/1901.02860 | |||||
Args: | |||||
n_head (int): The number of heads. | |||||
n_feat (int): The number of features. | |||||
dropout_rate (float): Dropout rate. | |||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
""" | |||||
def __init__(self, | |||||
n_head, | |||||
n_feat, | |||||
dropout_rate, | |||||
zero_triu=False, | |||||
kernel_size=15, | |||||
sanm_shfit=0): | |||||
"""Construct an RelPositionMultiHeadedAttention object.""" | |||||
super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit) | |||||
self.zero_triu = zero_triu | |||||
# linear transformation for positional encoding | |||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||||
# these two learnable bias are used in matrix c and matrix d | |||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||||
torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||||
def rel_shift(self, x): | |||||
"""Compute relative positional encoding. | |||||
Args: | |||||
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). | |||||
time1 means the length of query vector. | |||||
Returns: | |||||
torch.Tensor: Output tensor. | |||||
""" | |||||
zero_pad = torch.zeros((*x.size()[:3], 1), | |||||
device=x.device, | |||||
dtype=x.dtype) | |||||
x_padded = torch.cat([zero_pad, x], dim=-1) | |||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||||
x = x_padded[:, :, 1:].view_as( | |||||
x)[:, :, :, :x.size(-1) // 2 | |||||
+ 1] # only keep the positions from 0 to time2 | |||||
if self.zero_triu: | |||||
ones = torch.ones((x.size(2), x.size(3)), device=x.device) | |||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||||
return x | |||||
def forward(self, query, key, value, pos_emb, mask): | |||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||||
Args: | |||||
query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
pos_emb (torch.Tensor): Positional embedding tensor | |||||
(#batch, 2*time1-1, size). | |||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
(#batch, time1, time2). | |||||
Returns: | |||||
torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
""" | |||||
fsmn_memory = self.forward_fsmn(value, mask) | |||||
q, k, v = self.forward_qkv(query, key, value) | |||||
q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||||
n_batch_pos = pos_emb.size(0) | |||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) | |||||
# (batch, head, time1, d_k) | |||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||||
# (batch, head, time1, d_k) | |||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||||
# compute attention score | |||||
# first compute matrix a and matrix c | |||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
# (batch, head, time1, time2) | |||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||||
# compute matrix b and matrix d | |||||
# (batch, head, time1, 2*time1-1) | |||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||||
matrix_bd = self.rel_shift(matrix_bd) | |||||
scores = (matrix_ac + matrix_bd) / math.sqrt( | |||||
self.d_k) # (batch, head, time1, time2) | |||||
att_outs = self.forward_attention(v, scores, mask) | |||||
return att_outs + fsmn_memory |
@@ -1,239 +0,0 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
# Part of the implementation is borrowed from espnet/espnet. | |||||
"""Encoder self-attention layer definition.""" | |||||
import torch | |||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||||
from torch import nn | |||||
class EncoderLayer(nn.Module): | |||||
"""Encoder layer module. | |||||
Args: | |||||
size (int): Input dimension. | |||||
self_attn (torch.nn.Module): Self-attention module instance. | |||||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance | |||||
can be used as the argument. | |||||
feed_forward (torch.nn.Module): Feed-forward module instance. | |||||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance | |||||
can be used as the argument. | |||||
dropout_rate (float): Dropout rate. | |||||
normalize_before (bool): Whether to use layer_norm before the first block. | |||||
concat_after (bool): Whether to concat attention layer's input and output. | |||||
if True, additional linear will be applied. | |||||
i.e. x -> x + linear(concat(x, att(x))) | |||||
if False, no additional linear will be applied. i.e. x -> x + att(x) | |||||
stochastic_depth_rate (float): Proability to skip this layer. | |||||
During training, the layer may skip residual computation and return input | |||||
as-is with given probability. | |||||
""" | |||||
def __init__( | |||||
self, | |||||
size, | |||||
self_attn, | |||||
feed_forward, | |||||
dropout_rate, | |||||
normalize_before=True, | |||||
concat_after=False, | |||||
stochastic_depth_rate=0.0, | |||||
): | |||||
"""Construct an EncoderLayer object.""" | |||||
super(EncoderLayer, self).__init__() | |||||
self.self_attn = self_attn | |||||
self.feed_forward = feed_forward | |||||
self.norm1 = LayerNorm(size) | |||||
self.norm2 = LayerNorm(size) | |||||
self.dropout = nn.Dropout(dropout_rate) | |||||
self.size = size | |||||
self.normalize_before = normalize_before | |||||
self.concat_after = concat_after | |||||
if self.concat_after: | |||||
self.concat_linear = nn.Linear(size + size, size) | |||||
self.stochastic_depth_rate = stochastic_depth_rate | |||||
def forward(self, x, mask, cache=None): | |||||
"""Compute encoded features. | |||||
Args: | |||||
x_input (torch.Tensor): Input tensor (#batch, time, size). | |||||
mask (torch.Tensor): Mask tensor for the input (#batch, time). | |||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). | |||||
Returns: | |||||
torch.Tensor: Output tensor (#batch, time, size). | |||||
torch.Tensor: Mask tensor (#batch, time). | |||||
""" | |||||
skip_layer = False | |||||
# with stochastic depth, residual connection `x + f(x)` becomes | |||||
# `x <- x + 1 / (1 - p) * f(x)` at training time. | |||||
stoch_layer_coeff = 1.0 | |||||
if self.training and self.stochastic_depth_rate > 0: | |||||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate | |||||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) | |||||
if skip_layer: | |||||
if cache is not None: | |||||
x = torch.cat([cache, x], dim=1) | |||||
return x, mask | |||||
residual = x | |||||
if self.normalize_before: | |||||
x = self.norm1(x) | |||||
if cache is None: | |||||
x_q = x | |||||
else: | |||||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) | |||||
x_q = x[:, -1:, :] | |||||
residual = residual[:, -1:, :] | |||||
mask = None if mask is None else mask[:, -1:, :] | |||||
if self.concat_after: | |||||
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) | |||||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat) | |||||
else: | |||||
x = residual + stoch_layer_coeff * self.dropout( | |||||
self.self_attn(x_q, x, x, mask)) | |||||
if not self.normalize_before: | |||||
x = self.norm1(x) | |||||
residual = x | |||||
if self.normalize_before: | |||||
x = self.norm2(x) | |||||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) | |||||
if not self.normalize_before: | |||||
x = self.norm2(x) | |||||
if cache is not None: | |||||
x = torch.cat([cache, x], dim=1) | |||||
return x, mask | |||||
class EncoderLayerChunk(nn.Module): | |||||
"""Encoder layer module. | |||||
Args: | |||||
size (int): Input dimension. | |||||
self_attn (torch.nn.Module): Self-attention module instance. | |||||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance | |||||
can be used as the argument. | |||||
feed_forward (torch.nn.Module): Feed-forward module instance. | |||||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance | |||||
can be used as the argument. | |||||
dropout_rate (float): Dropout rate. | |||||
normalize_before (bool): Whether to use layer_norm before the first block. | |||||
concat_after (bool): Whether to concat attention layer's input and output. | |||||
if True, additional linear will be applied. | |||||
i.e. x -> x + linear(concat(x, att(x))) | |||||
if False, no additional linear will be applied. i.e. x -> x + att(x) | |||||
stochastic_depth_rate (float): Proability to skip this layer. | |||||
During training, the layer may skip residual computation and return input | |||||
as-is with given probability. | |||||
""" | |||||
def __init__( | |||||
self, | |||||
size, | |||||
self_attn, | |||||
feed_forward, | |||||
dropout_rate, | |||||
normalize_before=True, | |||||
concat_after=False, | |||||
stochastic_depth_rate=0.0, | |||||
): | |||||
"""Construct an EncoderLayer object.""" | |||||
super(EncoderLayerChunk, self).__init__() | |||||
self.self_attn = self_attn | |||||
self.feed_forward = feed_forward | |||||
self.norm1 = LayerNorm(size) | |||||
self.norm2 = LayerNorm(size) | |||||
self.dropout = nn.Dropout(dropout_rate) | |||||
self.size = size | |||||
self.normalize_before = normalize_before | |||||
self.concat_after = concat_after | |||||
if self.concat_after: | |||||
self.concat_linear = nn.Linear(size + size, size) | |||||
self.stochastic_depth_rate = stochastic_depth_rate | |||||
def forward(self, | |||||
x, | |||||
mask, | |||||
cache=None, | |||||
mask_shfit_chunk=None, | |||||
mask_att_chunk_encoder=None): | |||||
"""Compute encoded features. | |||||
Args: | |||||
x_input (torch.Tensor): Input tensor (#batch, time, size). | |||||
mask (torch.Tensor): Mask tensor for the input (#batch, time). | |||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). | |||||
Returns: | |||||
torch.Tensor: Output tensor (#batch, time, size). | |||||
torch.Tensor: Mask tensor (#batch, time). | |||||
""" | |||||
skip_layer = False | |||||
# with stochastic depth, residual connection `x + f(x)` becomes | |||||
# `x <- x + 1 / (1 - p) * f(x)` at training time. | |||||
stoch_layer_coeff = 1.0 | |||||
if self.training and self.stochastic_depth_rate > 0: | |||||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate | |||||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) | |||||
if skip_layer: | |||||
if cache is not None: | |||||
x = torch.cat([cache, x], dim=1) | |||||
return x, mask | |||||
residual = x | |||||
if self.normalize_before: | |||||
x = self.norm1(x) | |||||
if cache is None: | |||||
x_q = x | |||||
else: | |||||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) | |||||
x_q = x[:, -1:, :] | |||||
residual = residual[:, -1:, :] | |||||
mask = None if mask is None else mask[:, -1:, :] | |||||
if self.concat_after: | |||||
x_concat = torch.cat( | |||||
(x, | |||||
self.self_attn( | |||||
x_q, | |||||
x, | |||||
x, | |||||
mask, | |||||
mask_shfit_chunk=mask_shfit_chunk, | |||||
mask_att_chunk_encoder=mask_att_chunk_encoder)), | |||||
dim=-1) | |||||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat) | |||||
else: | |||||
x = residual + stoch_layer_coeff * self.dropout( | |||||
self.self_attn( | |||||
x_q, | |||||
x, | |||||
x, | |||||
mask, | |||||
mask_shfit_chunk=mask_shfit_chunk, | |||||
mask_att_chunk_encoder=mask_att_chunk_encoder)) | |||||
if not self.normalize_before: | |||||
x = self.norm1(x) | |||||
residual = x | |||||
if self.normalize_before: | |||||
x = self.norm2(x) | |||||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) | |||||
if not self.normalize_before: | |||||
x = self.norm2(x) | |||||
if cache is not None: | |||||
x = torch.cat([cache, x], dim=1) | |||||
return x, mask, None, mask_shfit_chunk, mask_att_chunk_encoder |
@@ -1,890 +0,0 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
# Part of the implementation is borrowed from espnet/espnet. | |||||
import argparse | |||||
import logging | |||||
import os | |||||
from pathlib import Path | |||||
from typing import Callable, Collection, Dict, List, Optional, Tuple, Union | |||||
import numpy as np | |||||
import torch | |||||
import yaml | |||||
from espnet2.asr.ctc import CTC | |||||
from espnet2.asr.decoder.abs_decoder import AbsDecoder | |||||
from espnet2.asr.decoder.mlm_decoder import MLMDecoder | |||||
from espnet2.asr.decoder.rnn_decoder import RNNDecoder | |||||
from espnet2.asr.decoder.transformer_decoder import \ | |||||
DynamicConvolution2DTransformerDecoder # noqa: H301 | |||||
from espnet2.asr.decoder.transformer_decoder import \ | |||||
LightweightConvolution2DTransformerDecoder # noqa: H301 | |||||
from espnet2.asr.decoder.transformer_decoder import \ | |||||
LightweightConvolutionTransformerDecoder # noqa: H301 | |||||
from espnet2.asr.decoder.transformer_decoder import ( | |||||
DynamicConvolutionTransformerDecoder, TransformerDecoder) | |||||
from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||||
from espnet2.asr.encoder.contextual_block_conformer_encoder import \ | |||||
ContextualBlockConformerEncoder # noqa: H301 | |||||
from espnet2.asr.encoder.contextual_block_transformer_encoder import \ | |||||
ContextualBlockTransformerEncoder # noqa: H301 | |||||
from espnet2.asr.encoder.hubert_encoder import (FairseqHubertEncoder, | |||||
FairseqHubertPretrainEncoder) | |||||
from espnet2.asr.encoder.longformer_encoder import LongformerEncoder | |||||
from espnet2.asr.encoder.rnn_encoder import RNNEncoder | |||||
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder | |||||
from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder | |||||
from espnet2.asr.encoder.wav2vec2_encoder import FairSeqWav2Vec2Encoder | |||||
from espnet2.asr.espnet_model import ESPnetASRModel | |||||
from espnet2.asr.frontend.abs_frontend import AbsFrontend | |||||
from espnet2.asr.frontend.default import DefaultFrontend | |||||
from espnet2.asr.frontend.fused import FusedFrontends | |||||
from espnet2.asr.frontend.s3prl import S3prlFrontend | |||||
from espnet2.asr.frontend.windowing import SlidingWindow | |||||
from espnet2.asr.maskctc_model import MaskCTCModel | |||||
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder | |||||
from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \ | |||||
HuggingFaceTransformersPostEncoder # noqa: H301 | |||||
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder | |||||
from espnet2.asr.preencoder.linear import LinearProjection | |||||
from espnet2.asr.preencoder.sinc import LightweightSincConvs | |||||
from espnet2.asr.specaug.abs_specaug import AbsSpecAug | |||||
from espnet2.asr.specaug.specaug import SpecAug | |||||
from espnet2.asr.transducer.joint_network import JointNetwork | |||||
from espnet2.asr.transducer.transducer_decoder import TransducerDecoder | |||||
from espnet2.layers.abs_normalize import AbsNormalize | |||||
from espnet2.layers.global_mvn import GlobalMVN | |||||
from espnet2.layers.utterance_mvn import UtteranceMVN | |||||
from espnet2.tasks.abs_task import AbsTask | |||||
from espnet2.text.phoneme_tokenizer import g2p_choices | |||||
from espnet2.torch_utils.initialize import initialize | |||||
from espnet2.train.abs_espnet_model import AbsESPnetModel | |||||
from espnet2.train.class_choices import ClassChoices | |||||
from espnet2.train.collate_fn import CommonCollateFn | |||||
from espnet2.train.preprocessor import CommonPreprocessor | |||||
from espnet2.train.trainer import Trainer | |||||
from espnet2.utils.get_default_kwargs import get_default_kwargs | |||||
from espnet2.utils.nested_dict_action import NestedDictAction | |||||
from espnet2.utils.types import (float_or_none, int_or_none, str2bool, | |||||
str_or_none) | |||||
from typeguard import check_argument_types, check_return_type | |||||
from ..asr.decoder.transformer_decoder import (ParaformerDecoder, | |||||
ParaformerDecoderBertEmbed) | |||||
from ..asr.encoder.conformer_encoder import ConformerEncoder, SANMEncoder_v2 | |||||
from ..asr.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunk | |||||
from ..asr.espnet_model import AEDStreaming | |||||
from ..asr.espnet_model_paraformer import Paraformer, ParaformerBertEmbed | |||||
from ..nets.pytorch_backend.cif_utils.cif import cif_predictor | |||||
# FIXME(wjm): suggested by fairseq, We need to setup root logger before importing any fairseq libraries. | |||||
logging.basicConfig( | |||||
level='INFO', | |||||
format=f"[{os.uname()[1].split('.')[0]}]" | |||||
f' %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', | |||||
) | |||||
# FIXME(wjm): create logger to set level, unset __name__ for different files to share the same logger | |||||
logger = logging.getLogger() | |||||
frontend_choices = ClassChoices( | |||||
name='frontend', | |||||
classes=dict( | |||||
default=DefaultFrontend, | |||||
sliding_window=SlidingWindow, | |||||
s3prl=S3prlFrontend, | |||||
fused=FusedFrontends, | |||||
), | |||||
type_check=AbsFrontend, | |||||
default='default', | |||||
) | |||||
specaug_choices = ClassChoices( | |||||
name='specaug', | |||||
classes=dict(specaug=SpecAug, ), | |||||
type_check=AbsSpecAug, | |||||
default=None, | |||||
optional=True, | |||||
) | |||||
normalize_choices = ClassChoices( | |||||
'normalize', | |||||
classes=dict( | |||||
global_mvn=GlobalMVN, | |||||
utterance_mvn=UtteranceMVN, | |||||
), | |||||
type_check=AbsNormalize, | |||||
default='utterance_mvn', | |||||
optional=True, | |||||
) | |||||
model_choices = ClassChoices( | |||||
'model', | |||||
classes=dict( | |||||
espnet=ESPnetASRModel, | |||||
maskctc=MaskCTCModel, | |||||
paraformer=Paraformer, | |||||
paraformer_bert_embed=ParaformerBertEmbed, | |||||
aedstreaming=AEDStreaming, | |||||
), | |||||
type_check=AbsESPnetModel, | |||||
default='espnet', | |||||
) | |||||
preencoder_choices = ClassChoices( | |||||
name='preencoder', | |||||
classes=dict( | |||||
sinc=LightweightSincConvs, | |||||
linear=LinearProjection, | |||||
), | |||||
type_check=AbsPreEncoder, | |||||
default=None, | |||||
optional=True, | |||||
) | |||||
encoder_choices = ClassChoices( | |||||
'encoder', | |||||
classes=dict( | |||||
conformer=ConformerEncoder, | |||||
transformer=TransformerEncoder, | |||||
contextual_block_transformer=ContextualBlockTransformerEncoder, | |||||
contextual_block_conformer=ContextualBlockConformerEncoder, | |||||
vgg_rnn=VGGRNNEncoder, | |||||
rnn=RNNEncoder, | |||||
wav2vec2=FairSeqWav2Vec2Encoder, | |||||
hubert=FairseqHubertEncoder, | |||||
hubert_pretrain=FairseqHubertPretrainEncoder, | |||||
longformer=LongformerEncoder, | |||||
sanm=SANMEncoder, | |||||
sanm_v2=SANMEncoder_v2, | |||||
sanm_chunk=SANMEncoderChunk, | |||||
), | |||||
type_check=AbsEncoder, | |||||
default='rnn', | |||||
) | |||||
postencoder_choices = ClassChoices( | |||||
name='postencoder', | |||||
classes=dict( | |||||
hugging_face_transformers=HuggingFaceTransformersPostEncoder, ), | |||||
type_check=AbsPostEncoder, | |||||
default=None, | |||||
optional=True, | |||||
) | |||||
decoder_choices = ClassChoices( | |||||
'decoder', | |||||
classes=dict( | |||||
transformer=TransformerDecoder, | |||||
lightweight_conv=LightweightConvolutionTransformerDecoder, | |||||
lightweight_conv2d=LightweightConvolution2DTransformerDecoder, | |||||
dynamic_conv=DynamicConvolutionTransformerDecoder, | |||||
dynamic_conv2d=DynamicConvolution2DTransformerDecoder, | |||||
rnn=RNNDecoder, | |||||
transducer=TransducerDecoder, | |||||
mlm=MLMDecoder, | |||||
paraformer_decoder=ParaformerDecoder, | |||||
paraformer_decoder_bert_embed=ParaformerDecoderBertEmbed, | |||||
), | |||||
type_check=AbsDecoder, | |||||
default='rnn', | |||||
) | |||||
predictor_choices = ClassChoices( | |||||
name='predictor', | |||||
classes=dict( | |||||
cif_predictor=cif_predictor, | |||||
ctc_predictor=None, | |||||
), | |||||
type_check=None, | |||||
default='cif_predictor', | |||||
optional=True, | |||||
) | |||||
class ASRTask(AbsTask): | |||||
# If you need more than one optimizers, change this value | |||||
num_optimizers: int = 1 | |||||
# Add variable objects configurations | |||||
class_choices_list = [ | |||||
# --frontend and --frontend_conf | |||||
frontend_choices, | |||||
# --specaug and --specaug_conf | |||||
specaug_choices, | |||||
# --normalize and --normalize_conf | |||||
normalize_choices, | |||||
# --model and --model_conf | |||||
model_choices, | |||||
# --preencoder and --preencoder_conf | |||||
preencoder_choices, | |||||
# --encoder and --encoder_conf | |||||
encoder_choices, | |||||
# --postencoder and --postencoder_conf | |||||
postencoder_choices, | |||||
# --decoder and --decoder_conf | |||||
decoder_choices, | |||||
] | |||||
# If you need to modify train() or eval() procedures, change Trainer class here | |||||
trainer = Trainer | |||||
@classmethod | |||||
def add_task_arguments(cls, parser: argparse.ArgumentParser): | |||||
group = parser.add_argument_group(description='Task related') | |||||
# NOTE(kamo): add_arguments(..., required=True) can't be used | |||||
# to provide --print_config mode. Instead of it, do as | |||||
required = parser.get_default('required') | |||||
required += ['token_list'] | |||||
group.add_argument( | |||||
'--token_list', | |||||
type=str_or_none, | |||||
default=None, | |||||
help='A text mapping int-id to token', | |||||
) | |||||
group.add_argument( | |||||
'--init', | |||||
type=lambda x: str_or_none(x.lower()), | |||||
default=None, | |||||
help='The initialization method', | |||||
choices=[ | |||||
'chainer', | |||||
'xavier_uniform', | |||||
'xavier_normal', | |||||
'kaiming_uniform', | |||||
'kaiming_normal', | |||||
None, | |||||
], | |||||
) | |||||
group.add_argument( | |||||
'--input_size', | |||||
type=int_or_none, | |||||
default=None, | |||||
help='The number of input dimension of the feature', | |||||
) | |||||
group.add_argument( | |||||
'--ctc_conf', | |||||
action=NestedDictAction, | |||||
default=get_default_kwargs(CTC), | |||||
help='The keyword arguments for CTC class.', | |||||
) | |||||
group.add_argument( | |||||
'--joint_net_conf', | |||||
action=NestedDictAction, | |||||
default=None, | |||||
help='The keyword arguments for joint network class.', | |||||
) | |||||
group = parser.add_argument_group(description='Preprocess related') | |||||
group.add_argument( | |||||
'--use_preprocessor', | |||||
type=str2bool, | |||||
default=True, | |||||
help='Apply preprocessing to data or not', | |||||
) | |||||
group.add_argument( | |||||
'--token_type', | |||||
type=str, | |||||
default='bpe', | |||||
choices=['bpe', 'char', 'word', 'phn'], | |||||
help='The text will be tokenized ' | |||||
'in the specified level token', | |||||
) | |||||
group.add_argument( | |||||
'--bpemodel', | |||||
type=str_or_none, | |||||
default=None, | |||||
help='The model file of sentencepiece', | |||||
) | |||||
parser.add_argument( | |||||
'--non_linguistic_symbols', | |||||
type=str_or_none, | |||||
help='non_linguistic_symbols file path', | |||||
) | |||||
parser.add_argument( | |||||
'--cleaner', | |||||
type=str_or_none, | |||||
choices=[None, 'tacotron', 'jaconv', 'vietnamese'], | |||||
default=None, | |||||
help='Apply text cleaning', | |||||
) | |||||
parser.add_argument( | |||||
'--g2p', | |||||
type=str_or_none, | |||||
choices=g2p_choices, | |||||
default=None, | |||||
help='Specify g2p method if --token_type=phn', | |||||
) | |||||
parser.add_argument( | |||||
'--speech_volume_normalize', | |||||
type=float_or_none, | |||||
default=None, | |||||
help='Scale the maximum amplitude to the given value.', | |||||
) | |||||
parser.add_argument( | |||||
'--rir_scp', | |||||
type=str_or_none, | |||||
default=None, | |||||
help='The file path of rir scp file.', | |||||
) | |||||
parser.add_argument( | |||||
'--rir_apply_prob', | |||||
type=float, | |||||
default=1.0, | |||||
help='THe probability for applying RIR convolution.', | |||||
) | |||||
parser.add_argument( | |||||
'--noise_scp', | |||||
type=str_or_none, | |||||
default=None, | |||||
help='The file path of noise scp file.', | |||||
) | |||||
parser.add_argument( | |||||
'--noise_apply_prob', | |||||
type=float, | |||||
default=1.0, | |||||
help='The probability applying Noise adding.', | |||||
) | |||||
parser.add_argument( | |||||
'--noise_db_range', | |||||
type=str, | |||||
default='13_15', | |||||
help='The range of noise decibel level.', | |||||
) | |||||
for class_choices in cls.class_choices_list: | |||||
# Append --<name> and --<name>_conf. | |||||
# e.g. --encoder and --encoder_conf | |||||
class_choices.add_arguments(group) | |||||
@classmethod | |||||
def build_collate_fn( | |||||
cls, args: argparse.Namespace, train: bool | |||||
) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[ | |||||
List[str], Dict[str, torch.Tensor]], ]: | |||||
assert check_argument_types() | |||||
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol | |||||
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) | |||||
@classmethod | |||||
def build_preprocess_fn( | |||||
cls, args: argparse.Namespace, train: bool | |||||
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: | |||||
assert check_argument_types() | |||||
if args.use_preprocessor: | |||||
retval = CommonPreprocessor( | |||||
train=train, | |||||
token_type=args.token_type, | |||||
token_list=args.token_list, | |||||
bpemodel=args.bpemodel, | |||||
non_linguistic_symbols=args.non_linguistic_symbols, | |||||
text_cleaner=args.cleaner, | |||||
g2p_type=args.g2p, | |||||
# NOTE(kamo): Check attribute existence for backward compatibility | |||||
rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None, | |||||
rir_apply_prob=args.rir_apply_prob if hasattr( | |||||
args, 'rir_apply_prob') else 1.0, | |||||
noise_scp=args.noise_scp | |||||
if hasattr(args, 'noise_scp') else None, | |||||
noise_apply_prob=args.noise_apply_prob if hasattr( | |||||
args, 'noise_apply_prob') else 1.0, | |||||
noise_db_range=args.noise_db_range if hasattr( | |||||
args, 'noise_db_range') else '13_15', | |||||
speech_volume_normalize=args.speech_volume_normalize | |||||
if hasattr(args, 'rir_scp') else None, | |||||
) | |||||
else: | |||||
retval = None | |||||
assert check_return_type(retval) | |||||
return retval | |||||
@classmethod | |||||
def required_data_names(cls, | |||||
train: bool = True, | |||||
inference: bool = False) -> Tuple[str, ...]: | |||||
if not inference: | |||||
retval = ('speech', 'text') | |||||
else: | |||||
# Recognition mode | |||||
retval = ('speech', ) | |||||
return retval | |||||
@classmethod | |||||
def optional_data_names(cls, | |||||
train: bool = True, | |||||
inference: bool = False) -> Tuple[str, ...]: | |||||
retval = () | |||||
assert check_return_type(retval) | |||||
return retval | |||||
@classmethod | |||||
def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel: | |||||
assert check_argument_types() | |||||
if isinstance(args.token_list, str): | |||||
with open(args.token_list, encoding='utf-8') as f: | |||||
token_list = [line.rstrip() for line in f] | |||||
# Overwriting token_list to keep it as "portable". | |||||
args.token_list = list(token_list) | |||||
elif isinstance(args.token_list, (tuple, list)): | |||||
token_list = list(args.token_list) | |||||
else: | |||||
raise RuntimeError('token_list must be str or list') | |||||
vocab_size = len(token_list) | |||||
logger.info(f'Vocabulary size: {vocab_size }') | |||||
# 1. frontend | |||||
if args.input_size is None: | |||||
# Extract features in the model | |||||
frontend_class = frontend_choices.get_class(args.frontend) | |||||
frontend = frontend_class(**args.frontend_conf) | |||||
input_size = frontend.output_size() | |||||
else: | |||||
# Give features from data-loader | |||||
args.frontend = None | |||||
args.frontend_conf = {} | |||||
frontend = None | |||||
input_size = args.input_size | |||||
# 2. Data augmentation for spectrogram | |||||
if args.specaug is not None: | |||||
specaug_class = specaug_choices.get_class(args.specaug) | |||||
specaug = specaug_class(**args.specaug_conf) | |||||
else: | |||||
specaug = None | |||||
# 3. Normalization layer | |||||
if args.normalize is not None: | |||||
normalize_class = normalize_choices.get_class(args.normalize) | |||||
normalize = normalize_class(**args.normalize_conf) | |||||
else: | |||||
normalize = None | |||||
# 4. Pre-encoder input block | |||||
# NOTE(kan-bayashi): Use getattr to keep the compatibility | |||||
if getattr(args, 'preencoder', None) is not None: | |||||
preencoder_class = preencoder_choices.get_class(args.preencoder) | |||||
preencoder = preencoder_class(**args.preencoder_conf) | |||||
input_size = preencoder.output_size() | |||||
else: | |||||
preencoder = None | |||||
# 4. Encoder | |||||
encoder_class = encoder_choices.get_class(args.encoder) | |||||
encoder = encoder_class(input_size=input_size, **args.encoder_conf) | |||||
# 5. Post-encoder block | |||||
# NOTE(kan-bayashi): Use getattr to keep the compatibility | |||||
encoder_output_size = encoder.output_size() | |||||
if getattr(args, 'postencoder', None) is not None: | |||||
postencoder_class = postencoder_choices.get_class(args.postencoder) | |||||
postencoder = postencoder_class( | |||||
input_size=encoder_output_size, **args.postencoder_conf) | |||||
encoder_output_size = postencoder.output_size() | |||||
else: | |||||
postencoder = None | |||||
# 5. Decoder | |||||
decoder_class = decoder_choices.get_class(args.decoder) | |||||
if args.decoder == 'transducer': | |||||
decoder = decoder_class( | |||||
vocab_size, | |||||
embed_pad=0, | |||||
**args.decoder_conf, | |||||
) | |||||
joint_network = JointNetwork( | |||||
vocab_size, | |||||
encoder.output_size(), | |||||
decoder.dunits, | |||||
**args.joint_net_conf, | |||||
) | |||||
else: | |||||
decoder = decoder_class( | |||||
vocab_size=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
**args.decoder_conf, | |||||
) | |||||
joint_network = None | |||||
# 6. CTC | |||||
ctc = CTC( | |||||
odim=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
**args.ctc_conf) | |||||
# 7. Build model | |||||
try: | |||||
model_class = model_choices.get_class(args.model) | |||||
except AttributeError: | |||||
model_class = model_choices.get_class('espnet') | |||||
model = model_class( | |||||
vocab_size=vocab_size, | |||||
frontend=frontend, | |||||
specaug=specaug, | |||||
normalize=normalize, | |||||
preencoder=preencoder, | |||||
encoder=encoder, | |||||
postencoder=postencoder, | |||||
decoder=decoder, | |||||
ctc=ctc, | |||||
joint_network=joint_network, | |||||
token_list=token_list, | |||||
**args.model_conf, | |||||
) | |||||
# FIXME(kamo): Should be done in model? | |||||
# 8. Initialize | |||||
if args.init is not None: | |||||
initialize(model, args.init) | |||||
assert check_return_type(model) | |||||
return model | |||||
class ASRTaskNAR(AbsTask): | |||||
# If you need more than one optimizers, change this value | |||||
num_optimizers: int = 1 | |||||
# Add variable objects configurations | |||||
class_choices_list = [ | |||||
# --frontend and --frontend_conf | |||||
frontend_choices, | |||||
# --specaug and --specaug_conf | |||||
specaug_choices, | |||||
# --normalize and --normalize_conf | |||||
normalize_choices, | |||||
# --model and --model_conf | |||||
model_choices, | |||||
# --preencoder and --preencoder_conf | |||||
preencoder_choices, | |||||
# --encoder and --encoder_conf | |||||
encoder_choices, | |||||
# --postencoder and --postencoder_conf | |||||
postencoder_choices, | |||||
# --decoder and --decoder_conf | |||||
decoder_choices, | |||||
# --predictor and --predictor_conf | |||||
predictor_choices, | |||||
] | |||||
# If you need to modify train() or eval() procedures, change Trainer class here | |||||
trainer = Trainer | |||||
@classmethod | |||||
def add_task_arguments(cls, parser: argparse.ArgumentParser): | |||||
group = parser.add_argument_group(description='Task related') | |||||
# NOTE(kamo): add_arguments(..., required=True) can't be used | |||||
# to provide --print_config mode. Instead of it, do as | |||||
required = parser.get_default('required') | |||||
required += ['token_list'] | |||||
group.add_argument( | |||||
'--token_list', | |||||
type=str_or_none, | |||||
default=None, | |||||
help='A text mapping int-id to token', | |||||
) | |||||
group.add_argument( | |||||
'--init', | |||||
type=lambda x: str_or_none(x.lower()), | |||||
default=None, | |||||
help='The initialization method', | |||||
choices=[ | |||||
'chainer', | |||||
'xavier_uniform', | |||||
'xavier_normal', | |||||
'kaiming_uniform', | |||||
'kaiming_normal', | |||||
None, | |||||
], | |||||
) | |||||
group.add_argument( | |||||
'--input_size', | |||||
type=int_or_none, | |||||
default=None, | |||||
help='The number of input dimension of the feature', | |||||
) | |||||
group.add_argument( | |||||
'--ctc_conf', | |||||
action=NestedDictAction, | |||||
default=get_default_kwargs(CTC), | |||||
help='The keyword arguments for CTC class.', | |||||
) | |||||
group.add_argument( | |||||
'--joint_net_conf', | |||||
action=NestedDictAction, | |||||
default=None, | |||||
help='The keyword arguments for joint network class.', | |||||
) | |||||
group = parser.add_argument_group(description='Preprocess related') | |||||
group.add_argument( | |||||
'--use_preprocessor', | |||||
type=str2bool, | |||||
default=True, | |||||
help='Apply preprocessing to data or not', | |||||
) | |||||
group.add_argument( | |||||
'--token_type', | |||||
type=str, | |||||
default='bpe', | |||||
choices=['bpe', 'char', 'word', 'phn'], | |||||
help='The text will be tokenized ' | |||||
'in the specified level token', | |||||
) | |||||
group.add_argument( | |||||
'--bpemodel', | |||||
type=str_or_none, | |||||
default=None, | |||||
help='The model file of sentencepiece', | |||||
) | |||||
parser.add_argument( | |||||
'--non_linguistic_symbols', | |||||
type=str_or_none, | |||||
help='non_linguistic_symbols file path', | |||||
) | |||||
parser.add_argument( | |||||
'--cleaner', | |||||
type=str_or_none, | |||||
choices=[None, 'tacotron', 'jaconv', 'vietnamese'], | |||||
default=None, | |||||
help='Apply text cleaning', | |||||
) | |||||
parser.add_argument( | |||||
'--g2p', | |||||
type=str_or_none, | |||||
choices=g2p_choices, | |||||
default=None, | |||||
help='Specify g2p method if --token_type=phn', | |||||
) | |||||
parser.add_argument( | |||||
'--speech_volume_normalize', | |||||
type=float_or_none, | |||||
default=None, | |||||
help='Scale the maximum amplitude to the given value.', | |||||
) | |||||
parser.add_argument( | |||||
'--rir_scp', | |||||
type=str_or_none, | |||||
default=None, | |||||
help='The file path of rir scp file.', | |||||
) | |||||
parser.add_argument( | |||||
'--rir_apply_prob', | |||||
type=float, | |||||
default=1.0, | |||||
help='THe probability for applying RIR convolution.', | |||||
) | |||||
parser.add_argument( | |||||
'--noise_scp', | |||||
type=str_or_none, | |||||
default=None, | |||||
help='The file path of noise scp file.', | |||||
) | |||||
parser.add_argument( | |||||
'--noise_apply_prob', | |||||
type=float, | |||||
default=1.0, | |||||
help='The probability applying Noise adding.', | |||||
) | |||||
parser.add_argument( | |||||
'--noise_db_range', | |||||
type=str, | |||||
default='13_15', | |||||
help='The range of noise decibel level.', | |||||
) | |||||
for class_choices in cls.class_choices_list: | |||||
# Append --<name> and --<name>_conf. | |||||
# e.g. --encoder and --encoder_conf | |||||
class_choices.add_arguments(group) | |||||
@classmethod | |||||
def build_collate_fn( | |||||
cls, args: argparse.Namespace, train: bool | |||||
) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[ | |||||
List[str], Dict[str, torch.Tensor]], ]: | |||||
assert check_argument_types() | |||||
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol | |||||
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) | |||||
@classmethod | |||||
def build_preprocess_fn( | |||||
cls, args: argparse.Namespace, train: bool | |||||
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: | |||||
assert check_argument_types() | |||||
if args.use_preprocessor: | |||||
retval = CommonPreprocessor( | |||||
train=train, | |||||
token_type=args.token_type, | |||||
token_list=args.token_list, | |||||
bpemodel=args.bpemodel, | |||||
non_linguistic_symbols=args.non_linguistic_symbols, | |||||
text_cleaner=args.cleaner, | |||||
g2p_type=args.g2p, | |||||
# NOTE(kamo): Check attribute existence for backward compatibility | |||||
rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None, | |||||
rir_apply_prob=args.rir_apply_prob if hasattr( | |||||
args, 'rir_apply_prob') else 1.0, | |||||
noise_scp=args.noise_scp | |||||
if hasattr(args, 'noise_scp') else None, | |||||
noise_apply_prob=args.noise_apply_prob if hasattr( | |||||
args, 'noise_apply_prob') else 1.0, | |||||
noise_db_range=args.noise_db_range if hasattr( | |||||
args, 'noise_db_range') else '13_15', | |||||
speech_volume_normalize=args.speech_volume_normalize | |||||
if hasattr(args, 'rir_scp') else None, | |||||
) | |||||
else: | |||||
retval = None | |||||
assert check_return_type(retval) | |||||
return retval | |||||
@classmethod | |||||
def required_data_names(cls, | |||||
train: bool = True, | |||||
inference: bool = False) -> Tuple[str, ...]: | |||||
if not inference: | |||||
retval = ('speech', 'text') | |||||
else: | |||||
# Recognition mode | |||||
retval = ('speech', ) | |||||
return retval | |||||
@classmethod | |||||
def optional_data_names(cls, | |||||
train: bool = True, | |||||
inference: bool = False) -> Tuple[str, ...]: | |||||
retval = () | |||||
assert check_return_type(retval) | |||||
return retval | |||||
@classmethod | |||||
def build_model(cls, args: argparse.Namespace): | |||||
assert check_argument_types() | |||||
if isinstance(args.token_list, str): | |||||
with open(args.token_list, encoding='utf-8') as f: | |||||
token_list = [line.rstrip() for line in f] | |||||
# Overwriting token_list to keep it as "portable". | |||||
args.token_list = list(token_list) | |||||
elif isinstance(args.token_list, (tuple, list)): | |||||
token_list = list(args.token_list) | |||||
else: | |||||
raise RuntimeError('token_list must be str or list') | |||||
vocab_size = len(token_list) | |||||
# logger.info(f'Vocabulary size: {vocab_size }') | |||||
# 1. frontend | |||||
if args.input_size is None: | |||||
# Extract features in the model | |||||
frontend_class = frontend_choices.get_class(args.frontend) | |||||
frontend = frontend_class(**args.frontend_conf) | |||||
input_size = frontend.output_size() | |||||
else: | |||||
# Give features from data-loader | |||||
args.frontend = None | |||||
args.frontend_conf = {} | |||||
frontend = None | |||||
input_size = args.input_size | |||||
# 2. Data augmentation for spectrogram | |||||
if args.specaug is not None: | |||||
specaug_class = specaug_choices.get_class(args.specaug) | |||||
specaug = specaug_class(**args.specaug_conf) | |||||
else: | |||||
specaug = None | |||||
# 3. Normalization layer | |||||
if args.normalize is not None: | |||||
normalize_class = normalize_choices.get_class(args.normalize) | |||||
normalize = normalize_class(**args.normalize_conf) | |||||
else: | |||||
normalize = None | |||||
# 4. Pre-encoder input block | |||||
# NOTE(kan-bayashi): Use getattr to keep the compatibility | |||||
if getattr(args, 'preencoder', None) is not None: | |||||
preencoder_class = preencoder_choices.get_class(args.preencoder) | |||||
preencoder = preencoder_class(**args.preencoder_conf) | |||||
input_size = preencoder.output_size() | |||||
else: | |||||
preencoder = None | |||||
# 4. Encoder | |||||
encoder_class = encoder_choices.get_class(args.encoder) | |||||
encoder = encoder_class(input_size=input_size, **args.encoder_conf) | |||||
# 5. Post-encoder block | |||||
# NOTE(kan-bayashi): Use getattr to keep the compatibility | |||||
encoder_output_size = encoder.output_size() | |||||
if getattr(args, 'postencoder', None) is not None: | |||||
postencoder_class = postencoder_choices.get_class(args.postencoder) | |||||
postencoder = postencoder_class( | |||||
input_size=encoder_output_size, **args.postencoder_conf) | |||||
encoder_output_size = postencoder.output_size() | |||||
else: | |||||
postencoder = None | |||||
# 5. Decoder | |||||
decoder_class = decoder_choices.get_class(args.decoder) | |||||
if args.decoder == 'transducer': | |||||
decoder = decoder_class( | |||||
vocab_size, | |||||
embed_pad=0, | |||||
**args.decoder_conf, | |||||
) | |||||
joint_network = JointNetwork( | |||||
vocab_size, | |||||
encoder.output_size(), | |||||
decoder.dunits, | |||||
**args.joint_net_conf, | |||||
) | |||||
else: | |||||
decoder = decoder_class( | |||||
vocab_size=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
**args.decoder_conf, | |||||
) | |||||
joint_network = None | |||||
# 6. CTC | |||||
ctc = CTC( | |||||
odim=vocab_size, | |||||
encoder_output_size=encoder_output_size, | |||||
**args.ctc_conf) | |||||
predictor_class = predictor_choices.get_class(args.predictor) | |||||
predictor = predictor_class(**args.predictor_conf) | |||||
# 7. Build model | |||||
try: | |||||
model_class = model_choices.get_class(args.model) | |||||
except AttributeError: | |||||
model_class = model_choices.get_class('espnet') | |||||
model = model_class( | |||||
vocab_size=vocab_size, | |||||
frontend=frontend, | |||||
specaug=specaug, | |||||
normalize=normalize, | |||||
preencoder=preencoder, | |||||
encoder=encoder, | |||||
postencoder=postencoder, | |||||
decoder=decoder, | |||||
ctc=ctc, | |||||
joint_network=joint_network, | |||||
token_list=token_list, | |||||
predictor=predictor, | |||||
**args.model_conf, | |||||
) | |||||
# FIXME(kamo): Should be done in model? | |||||
# 8. Initialize | |||||
if args.init is not None: | |||||
initialize(model, args.init) | |||||
assert check_return_type(model) | |||||
return model |
@@ -1,223 +0,0 @@ | |||||
import os | |||||
import shutil | |||||
import threading | |||||
from typing import Any, Dict, List, Sequence, Tuple, Union | |||||
import yaml | |||||
from modelscope.metainfo import Pipelines | |||||
from modelscope.models import Model | |||||
from modelscope.pipelines.base import Pipeline | |||||
from modelscope.pipelines.builder import PIPELINES | |||||
from modelscope.preprocessors import WavToScp | |||||
from modelscope.utils.constant import Tasks | |||||
from modelscope.utils.logger import get_logger | |||||
from .asr_engine.common import asr_utils | |||||
logger = get_logger() | |||||
__all__ = ['AutomaticSpeechRecognitionPipeline'] | |||||
@PIPELINES.register_module( | |||||
Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) | |||||
class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
"""ASR Pipeline | |||||
""" | |||||
def __init__(self, | |||||
model: Union[List[Model], List[str]] = None, | |||||
preprocessor: WavToScp = None, | |||||
**kwargs): | |||||
"""use `model` and `preprocessor` to create an asr pipeline for prediction | |||||
""" | |||||
from .asr_engine import asr_env_checking | |||||
assert model is not None, 'asr model should be provided' | |||||
model_list: List = [] | |||||
if isinstance(model[0], Model): | |||||
model_list = model | |||||
else: | |||||
model_list.append(Model.from_pretrained(model[0])) | |||||
if len(model) == 2 and model[1] is not None: | |||||
model_list.append(Model.from_pretrained(model[1])) | |||||
super().__init__(model=model_list, preprocessor=preprocessor, **kwargs) | |||||
self._preprocessor = preprocessor | |||||
self._am_model = model_list[0] | |||||
if len(model_list) == 2 and model_list[1] is not None: | |||||
self._lm_model = model_list[1] | |||||
def __call__(self, | |||||
wav_path: str, | |||||
recog_type: str = None, | |||||
audio_format: str = None, | |||||
workspace: str = None) -> Dict[str, Any]: | |||||
assert len(wav_path) > 0, 'wav_path should be provided' | |||||
self._recog_type = recog_type | |||||
self._audio_format = audio_format | |||||
self._workspace = workspace | |||||
self._wav_path = wav_path | |||||
if recog_type is None or audio_format is None or workspace is None: | |||||
self._recog_type, self._audio_format, self._workspace, self._wav_path = asr_utils.type_checking( | |||||
wav_path, recog_type, audio_format, workspace) | |||||
if self._preprocessor is None: | |||||
self._preprocessor = WavToScp(workspace=self._workspace) | |||||
output = self._preprocessor.forward(self._am_model.forward(), | |||||
self._recog_type, | |||||
self._audio_format, self._wav_path) | |||||
output = self.forward(output) | |||||
rst = self.postprocess(output) | |||||
return rst | |||||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
"""Decoding | |||||
""" | |||||
logger.info(f"Decoding with {inputs['audio_format']} files ...") | |||||
j: int = 0 | |||||
process = [] | |||||
while j < inputs['thread_count']: | |||||
data_cmd: Sequence[Tuple[str, str, str]] | |||||
if inputs['audio_format'] == 'wav': | |||||
data_cmd = [(os.path.join(inputs['workspace'], | |||||
'data.' + str(j) + '.scp'), 'speech', | |||||
'sound')] | |||||
elif inputs['audio_format'] == 'kaldi_ark': | |||||
data_cmd = [(os.path.join(inputs['workspace'], | |||||
'data.' + str(j) + '.scp'), 'speech', | |||||
'kaldi_ark')] | |||||
output_dir: str = os.path.join(inputs['output'], | |||||
'output.' + str(j)) | |||||
if not os.path.exists(output_dir): | |||||
os.mkdir(output_dir) | |||||
config_file = open(inputs['asr_model_config']) | |||||
root = yaml.full_load(config_file) | |||||
config_file.close() | |||||
frontend_conf = None | |||||
if 'frontend_conf' in root: | |||||
frontend_conf = root['frontend_conf'] | |||||
cmd = { | |||||
'model_type': inputs['model_type'], | |||||
'beam_size': root['beam_size'], | |||||
'penalty': root['penalty'], | |||||
'maxlenratio': root['maxlenratio'], | |||||
'minlenratio': root['minlenratio'], | |||||
'ctc_weight': root['ctc_weight'], | |||||
'lm_weight': root['lm_weight'], | |||||
'output_dir': output_dir, | |||||
'ngpu': 0, | |||||
'log_level': 'ERROR', | |||||
'data_path_and_name_and_type': data_cmd, | |||||
'asr_train_config': inputs['am_model_config'], | |||||
'asr_model_file': inputs['am_model_path'], | |||||
'batch_size': inputs['model_config']['batch_size'], | |||||
'frontend_conf': frontend_conf | |||||
} | |||||
thread = AsrInferenceThread(j, cmd) | |||||
thread.start() | |||||
j += 1 | |||||
process.append(thread) | |||||
for p in process: | |||||
p.join() | |||||
return inputs | |||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
"""process the asr results | |||||
""" | |||||
logger.info('Computing the result of ASR ...') | |||||
rst = {'rec_result': 'None'} | |||||
# single wav task | |||||
if inputs['recog_type'] == 'wav' and inputs['audio_format'] == 'wav': | |||||
text_file: str = os.path.join(inputs['output'], 'output.0', | |||||
'1best_recog', 'text') | |||||
if os.path.exists(text_file): | |||||
f = open(text_file, 'r') | |||||
result_str: str = f.readline() | |||||
f.close() | |||||
if len(result_str) > 0: | |||||
result_list = result_str.split() | |||||
if len(result_list) >= 2: | |||||
rst['rec_result'] = result_list[1] | |||||
# run with datasets, and audio format is waveform or kaldi_ark | |||||
elif inputs['recog_type'] != 'wav': | |||||
inputs['reference_text'] = self._ref_text_tidy(inputs) | |||||
inputs['datasets_result'] = asr_utils.compute_wer( | |||||
inputs['hypothesis_text'], inputs['reference_text']) | |||||
else: | |||||
raise ValueError('recog_type and audio_format are mismatching') | |||||
if 'datasets_result' in inputs: | |||||
rst['datasets_result'] = inputs['datasets_result'] | |||||
# remove workspace dir (.tmp) | |||||
if os.path.exists(self._workspace): | |||||
shutil.rmtree(self._workspace) | |||||
return rst | |||||
def _ref_text_tidy(self, inputs: Dict[str, Any]) -> str: | |||||
ref_text: str = os.path.join(inputs['output'], 'text.ref') | |||||
k: int = 0 | |||||
while k < inputs['thread_count']: | |||||
output_text = os.path.join(inputs['output'], 'output.' + str(k), | |||||
'1best_recog', 'text') | |||||
if os.path.exists(output_text): | |||||
with open(output_text, 'r', encoding='utf-8') as i: | |||||
lines = i.readlines() | |||||
with open(ref_text, 'a', encoding='utf-8') as o: | |||||
for line in lines: | |||||
o.write(line) | |||||
k += 1 | |||||
return ref_text | |||||
class AsrInferenceThread(threading.Thread): | |||||
def __init__(self, threadID, cmd): | |||||
threading.Thread.__init__(self) | |||||
self._threadID = threadID | |||||
self._cmd = cmd | |||||
def run(self): | |||||
if self._cmd['model_type'] == 'pytorch': | |||||
from .asr_engine import asr_inference_paraformer_espnet | |||||
asr_inference_paraformer_espnet.asr_inference( | |||||
batch_size=self._cmd['batch_size'], | |||||
output_dir=self._cmd['output_dir'], | |||||
maxlenratio=self._cmd['maxlenratio'], | |||||
minlenratio=self._cmd['minlenratio'], | |||||
beam_size=self._cmd['beam_size'], | |||||
ngpu=self._cmd['ngpu'], | |||||
ctc_weight=self._cmd['ctc_weight'], | |||||
lm_weight=self._cmd['lm_weight'], | |||||
penalty=self._cmd['penalty'], | |||||
log_level=self._cmd['log_level'], | |||||
data_path_and_name_and_type=self. | |||||
_cmd['data_path_and_name_and_type'], | |||||
asr_train_config=self._cmd['asr_train_config'], | |||||
asr_model_file=self._cmd['asr_model_file'], | |||||
frontend_conf=self._cmd['frontend_conf']) |
@@ -0,0 +1,213 @@ | |||||
import os | |||||
from typing import Any, Dict, List, Sequence, Tuple, Union | |||||
import yaml | |||||
from modelscope.metainfo import Pipelines | |||||
from modelscope.models import Model | |||||
from modelscope.outputs import OutputKeys | |||||
from modelscope.pipelines.base import Pipeline | |||||
from modelscope.pipelines.builder import PIPELINES | |||||
from modelscope.preprocessors import WavToScp | |||||
from modelscope.utils.constant import Frameworks, Tasks | |||||
from modelscope.utils.logger import get_logger | |||||
logger = get_logger() | |||||
__all__ = ['AutomaticSpeechRecognitionPipeline'] | |||||
@PIPELINES.register_module( | |||||
Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) | |||||
class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
"""ASR Inference Pipeline | |||||
""" | |||||
def __init__(self, | |||||
model: Union[Model, str] = None, | |||||
preprocessor: WavToScp = None, | |||||
**kwargs): | |||||
"""use `model` and `preprocessor` to create an asr pipeline for prediction | |||||
""" | |||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
def __call__(self, | |||||
audio_in: Union[str, bytes], | |||||
recog_type: str = None, | |||||
audio_format: str = None) -> Dict[str, Any]: | |||||
from easyasr.common import asr_utils | |||||
self.recog_type = recog_type | |||||
self.audio_format = audio_format | |||||
self.audio_in = audio_in | |||||
if recog_type is None or audio_format is None: | |||||
self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | |||||
audio_in, recog_type, audio_format) | |||||
if self.preprocessor is None: | |||||
self.preprocessor = WavToScp() | |||||
output = self.preprocessor.forward(self.model.forward(), | |||||
self.recog_type, self.audio_format, | |||||
self.audio_in) | |||||
output = self.forward(output) | |||||
rst = self.postprocess(output) | |||||
return rst | |||||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
"""Decoding | |||||
""" | |||||
logger.info(f"Decoding with {inputs['audio_format']} files ...") | |||||
data_cmd: Sequence[Tuple[str, str]] | |||||
if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm': | |||||
data_cmd = ['speech', 'sound'] | |||||
elif inputs['audio_format'] == 'kaldi_ark': | |||||
data_cmd = ['speech', 'kaldi_ark'] | |||||
elif inputs['audio_format'] == 'tfrecord': | |||||
data_cmd = ['speech', 'tfrecord'] | |||||
# generate asr inference command | |||||
cmd = { | |||||
'model_type': inputs['model_type'], | |||||
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available | |||||
'log_level': 'ERROR', | |||||
'audio_in': inputs['audio_lists'], | |||||
'name_and_type': data_cmd, | |||||
'asr_model_file': inputs['am_model_path'], | |||||
'idx_text': '' | |||||
} | |||||
if self.framework == Frameworks.torch: | |||||
config_file = open(inputs['asr_model_config']) | |||||
root = yaml.full_load(config_file) | |||||
config_file.close() | |||||
frontend_conf = None | |||||
if 'frontend_conf' in root: | |||||
frontend_conf = root['frontend_conf'] | |||||
cmd['beam_size'] = root['beam_size'] | |||||
cmd['penalty'] = root['penalty'] | |||||
cmd['maxlenratio'] = root['maxlenratio'] | |||||
cmd['minlenratio'] = root['minlenratio'] | |||||
cmd['ctc_weight'] = root['ctc_weight'] | |||||
cmd['lm_weight'] = root['lm_weight'] | |||||
cmd['asr_train_config'] = inputs['am_model_config'] | |||||
cmd['batch_size'] = inputs['model_config']['batch_size'] | |||||
cmd['frontend_conf'] = frontend_conf | |||||
elif self.framework == Frameworks.tf: | |||||
cmd['fs'] = inputs['model_config']['fs'] | |||||
cmd['hop_length'] = inputs['model_config']['hop_length'] | |||||
cmd['feature_dims'] = inputs['model_config']['feature_dims'] | |||||
cmd['predictions_file'] = 'text' | |||||
cmd['mvn_file'] = inputs['am_mvn_file'] | |||||
cmd['vocab_file'] = inputs['vocab_file'] | |||||
if 'idx_text' in inputs: | |||||
cmd['idx_text'] = inputs['idx_text'] | |||||
else: | |||||
raise ValueError('model type is mismatching') | |||||
inputs['asr_result'] = self.run_inference(cmd) | |||||
return inputs | |||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
"""process the asr results | |||||
""" | |||||
from easyasr.common import asr_utils | |||||
logger.info('Computing the result of ASR ...') | |||||
rst = {} | |||||
# single wav or pcm task | |||||
if inputs['recog_type'] == 'wav': | |||||
if 'asr_result' in inputs and len(inputs['asr_result']) > 0: | |||||
text = inputs['asr_result'][0]['value'] | |||||
if len(text) > 0: | |||||
rst[OutputKeys.TEXT] = text | |||||
# run with datasets, and audio format is waveform or kaldi_ark or tfrecord | |||||
elif inputs['recog_type'] != 'wav': | |||||
inputs['reference_list'] = self.ref_list_tidy(inputs) | |||||
inputs['datasets_result'] = asr_utils.compute_wer( | |||||
inputs['asr_result'], inputs['reference_list']) | |||||
else: | |||||
raise ValueError('recog_type and audio_format are mismatching') | |||||
if 'datasets_result' in inputs: | |||||
rst[OutputKeys.TEXT] = inputs['datasets_result'] | |||||
return rst | |||||
def ref_list_tidy(self, inputs: Dict[str, Any]) -> List[Any]: | |||||
ref_list = [] | |||||
if inputs['audio_format'] == 'tfrecord': | |||||
# should assemble idx + txt | |||||
with open(inputs['reference_text'], 'r', encoding='utf-8') as r: | |||||
text_lines = r.readlines() | |||||
with open(inputs['idx_text'], 'r', encoding='utf-8') as i: | |||||
idx_lines = i.readlines() | |||||
j: int = 0 | |||||
while j < min(len(text_lines), len(idx_lines)): | |||||
idx_str = idx_lines[j].strip() | |||||
text_str = text_lines[j].strip().replace(' ', '') | |||||
item = {'key': idx_str, 'value': text_str} | |||||
ref_list.append(item) | |||||
j += 1 | |||||
else: | |||||
# text contain idx + sentence | |||||
with open(inputs['reference_text'], 'r', encoding='utf-8') as f: | |||||
lines = f.readlines() | |||||
for line in lines: | |||||
line_item = line.split() | |||||
item = {'key': line_item[0], 'value': line_item[1]} | |||||
ref_list.append(item) | |||||
return ref_list | |||||
def run_inference(self, cmd): | |||||
asr_result = [] | |||||
if self.framework == Frameworks.torch: | |||||
from easyasr import asr_inference_paraformer_espnet | |||||
asr_result = asr_inference_paraformer_espnet.asr_inference( | |||||
batch_size=cmd['batch_size'], | |||||
maxlenratio=cmd['maxlenratio'], | |||||
minlenratio=cmd['minlenratio'], | |||||
beam_size=cmd['beam_size'], | |||||
ngpu=cmd['ngpu'], | |||||
ctc_weight=cmd['ctc_weight'], | |||||
lm_weight=cmd['lm_weight'], | |||||
penalty=cmd['penalty'], | |||||
log_level=cmd['log_level'], | |||||
name_and_type=cmd['name_and_type'], | |||||
audio_lists=cmd['audio_in'], | |||||
asr_train_config=cmd['asr_train_config'], | |||||
asr_model_file=cmd['asr_model_file'], | |||||
frontend_conf=cmd['frontend_conf']) | |||||
elif self.framework == Frameworks.tf: | |||||
from easyasr import asr_inference_paraformer_tf | |||||
asr_result = asr_inference_paraformer_tf.asr_inference( | |||||
ngpu=cmd['ngpu'], | |||||
name_and_type=cmd['name_and_type'], | |||||
audio_lists=cmd['audio_in'], | |||||
idx_text_file=cmd['idx_text'], | |||||
asr_model_file=cmd['asr_model_file'], | |||||
vocab_file=cmd['vocab_file'], | |||||
am_mvn_file=cmd['mvn_file'], | |||||
predictions_file=cmd['predictions_file'], | |||||
fs=cmd['fs'], | |||||
hop_length=cmd['hop_length'], | |||||
feature_dims=cmd['feature_dims']) | |||||
return asr_result |
@@ -1,14 +1,9 @@ | |||||
import io | |||||
import os | import os | ||||
import shutil | |||||
from pathlib import Path | |||||
from typing import Any, Dict, List | |||||
import yaml | |||||
from typing import Any, Dict, List, Union | |||||
from modelscope.metainfo import Preprocessors | from modelscope.metainfo import Preprocessors | ||||
from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
from modelscope.utils.constant import Fields | |||||
from modelscope.utils.constant import Fields, Frameworks | |||||
from .base import Preprocessor | from .base import Preprocessor | ||||
from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
@@ -19,44 +14,32 @@ __all__ = ['WavToScp'] | |||||
Fields.audio, module_name=Preprocessors.wav_to_scp) | Fields.audio, module_name=Preprocessors.wav_to_scp) | ||||
class WavToScp(Preprocessor): | class WavToScp(Preprocessor): | ||||
"""generate audio scp from wave or ark | """generate audio scp from wave or ark | ||||
Args: | |||||
workspace (str): | |||||
""" | """ | ||||
def __init__(self, workspace: str = None): | |||||
# the workspace path | |||||
if workspace is None or len(workspace) == 0: | |||||
self._workspace = os.path.join(os.getcwd(), '.tmp') | |||||
else: | |||||
self._workspace = workspace | |||||
if not os.path.exists(self._workspace): | |||||
os.mkdir(self._workspace) | |||||
def __init__(self): | |||||
pass | |||||
def __call__(self, | def __call__(self, | ||||
model: List[Model] = None, | |||||
model: Model = None, | |||||
recog_type: str = None, | recog_type: str = None, | ||||
audio_format: str = None, | audio_format: str = None, | ||||
wav_path: str = None) -> Dict[str, Any]: | |||||
assert len(model) > 0, 'preprocess model is invalid' | |||||
assert len(recog_type) > 0, 'preprocess recog_type is empty' | |||||
assert len(audio_format) > 0, 'preprocess audio_format is empty' | |||||
assert len(wav_path) > 0, 'preprocess wav_path is empty' | |||||
self._am_model = model[0] | |||||
if len(model) == 2 and model[1] is not None: | |||||
self._lm_model = model[1] | |||||
out = self.forward(self._am_model.forward(), recog_type, audio_format, | |||||
wav_path) | |||||
audio_in: Union[str, bytes] = None) -> Dict[str, Any]: | |||||
assert model is not None, 'preprocess model is empty' | |||||
assert recog_type is not None and len( | |||||
recog_type) > 0, 'preprocess recog_type is empty' | |||||
assert audio_format is not None, 'preprocess audio_format is empty' | |||||
assert audio_in is not None, 'preprocess audio_in is empty' | |||||
self.am_model = model | |||||
out = self.forward(self.am_model.forward(), recog_type, audio_format, | |||||
audio_in) | |||||
return out | return out | ||||
def forward(self, model: Dict[str, Any], recog_type: str, | def forward(self, model: Dict[str, Any], recog_type: str, | ||||
audio_format: str, wav_path: str) -> Dict[str, Any]: | |||||
audio_format: str, audio_in: Union[str, | |||||
bytes]) -> Dict[str, Any]: | |||||
assert len(recog_type) > 0, 'preprocess recog_type is empty' | assert len(recog_type) > 0, 'preprocess recog_type is empty' | ||||
assert len(audio_format) > 0, 'preprocess audio_format is empty' | assert len(audio_format) > 0, 'preprocess audio_format is empty' | ||||
assert len(wav_path) > 0, 'preprocess wav_path is empty' | |||||
assert os.path.exists(wav_path), 'preprocess wav_path does not exist' | |||||
assert len( | assert len( | ||||
model['am_model']) > 0, 'preprocess model[am_model] is empty' | model['am_model']) > 0, 'preprocess model[am_model] is empty' | ||||
assert len(model['am_model_path'] | assert len(model['am_model_path'] | ||||
@@ -70,90 +53,104 @@ class WavToScp(Preprocessor): | |||||
assert len(model['model_config'] | assert len(model['model_config'] | ||||
) > 0, 'preprocess model[model_config] is empty' | ) > 0, 'preprocess model[model_config] is empty' | ||||
# the am model name | |||||
am_model: str = model['am_model'] | |||||
# the am model file path | |||||
am_model_path: str = model['am_model_path'] | |||||
# the recognition model dir path | |||||
model_workspace: str = model['model_workspace'] | |||||
# the recognition model config dict | |||||
global_model_config_dict: str = model['model_config'] | |||||
rst = { | rst = { | ||||
'workspace': os.path.join(self._workspace, recog_type), | |||||
'am_model': am_model, | |||||
'am_model_path': am_model_path, | |||||
'model_workspace': model_workspace, | |||||
# the recognition model dir path | |||||
'model_workspace': model['model_workspace'], | |||||
# the am model name | |||||
'am_model': model['am_model'], | |||||
# the am model file path | |||||
'am_model_path': model['am_model_path'], | |||||
# the asr type setting, eg: test dev train wav | # the asr type setting, eg: test dev train wav | ||||
'recog_type': recog_type, | 'recog_type': recog_type, | ||||
# the asr audio format setting, eg: wav, kaldi_ark | |||||
# the asr audio format setting, eg: wav, pcm, kaldi_ark, tfrecord | |||||
'audio_format': audio_format, | 'audio_format': audio_format, | ||||
# the test wav file path or the dataset path | |||||
'wav_path': wav_path, | |||||
'model_config': global_model_config_dict | |||||
# the recognition model config dict | |||||
'model_config': model['model_config'] | |||||
} | } | ||||
out = self._config_checking(rst) | |||||
out = self._env_setting(out) | |||||
if isinstance(audio_in, str): | |||||
# wav file path or the dataset path | |||||
rst['wav_path'] = audio_in | |||||
out = self.config_checking(rst) | |||||
out = self.env_setting(out) | |||||
if audio_format == 'wav': | if audio_format == 'wav': | ||||
out = self._scp_generation_from_wav(out) | |||||
out['audio_lists'] = self.scp_generation_from_wav(out) | |||||
elif audio_format == 'kaldi_ark': | elif audio_format == 'kaldi_ark': | ||||
out = self._scp_generation_from_ark(out) | |||||
out['audio_lists'] = self.scp_generation_from_ark(out) | |||||
elif audio_format == 'tfrecord': | |||||
out['audio_lists'] = os.path.join(out['wav_path'], 'data.records') | |||||
elif audio_format == 'pcm': | |||||
out['audio_lists'] = audio_in | |||||
return out | return out | ||||
def _config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
def config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
"""config checking | """config checking | ||||
""" | """ | ||||
assert inputs['model_config'].__contains__( | assert inputs['model_config'].__contains__( | ||||
'type'), 'model type does not exist' | 'type'), 'model type does not exist' | ||||
assert inputs['model_config'].__contains__( | |||||
'batch_size'), 'batch_size does not exist' | |||||
assert inputs['model_config'].__contains__( | |||||
'am_model_config'), 'am_model_config does not exist' | |||||
assert inputs['model_config'].__contains__( | |||||
'asr_model_config'), 'asr_model_config does not exist' | |||||
assert inputs['model_config'].__contains__( | |||||
'asr_model_wav_config'), 'asr_model_wav_config does not exist' | |||||
am_model_config: str = os.path.join( | |||||
inputs['model_workspace'], | |||||
inputs['model_config']['am_model_config']) | |||||
assert os.path.exists( | |||||
am_model_config), 'am_model_config does not exist' | |||||
inputs['am_model_config'] = am_model_config | |||||
asr_model_config: str = os.path.join( | |||||
inputs['model_workspace'], | |||||
inputs['model_config']['asr_model_config']) | |||||
assert os.path.exists( | |||||
asr_model_config), 'asr_model_config does not exist' | |||||
asr_model_wav_config: str = os.path.join( | |||||
inputs['model_workspace'], | |||||
inputs['model_config']['asr_model_wav_config']) | |||||
assert os.path.exists( | |||||
asr_model_wav_config), 'asr_model_wav_config does not exist' | |||||
inputs['model_type'] = inputs['model_config']['type'] | inputs['model_type'] = inputs['model_config']['type'] | ||||
if inputs['audio_format'] == 'wav': | |||||
inputs['asr_model_config'] = asr_model_wav_config | |||||
if inputs['model_type'] == Frameworks.torch: | |||||
assert inputs['model_config'].__contains__( | |||||
'batch_size'), 'batch_size does not exist' | |||||
assert inputs['model_config'].__contains__( | |||||
'am_model_config'), 'am_model_config does not exist' | |||||
assert inputs['model_config'].__contains__( | |||||
'asr_model_config'), 'asr_model_config does not exist' | |||||
assert inputs['model_config'].__contains__( | |||||
'asr_model_wav_config'), 'asr_model_wav_config does not exist' | |||||
am_model_config: str = os.path.join( | |||||
inputs['model_workspace'], | |||||
inputs['model_config']['am_model_config']) | |||||
assert os.path.exists( | |||||
am_model_config), 'am_model_config does not exist' | |||||
inputs['am_model_config'] = am_model_config | |||||
asr_model_config: str = os.path.join( | |||||
inputs['model_workspace'], | |||||
inputs['model_config']['asr_model_config']) | |||||
assert os.path.exists( | |||||
asr_model_config), 'asr_model_config does not exist' | |||||
asr_model_wav_config: str = os.path.join( | |||||
inputs['model_workspace'], | |||||
inputs['model_config']['asr_model_wav_config']) | |||||
assert os.path.exists( | |||||
asr_model_wav_config), 'asr_model_wav_config does not exist' | |||||
if inputs['audio_format'] == 'wav' or inputs[ | |||||
'audio_format'] == 'pcm': | |||||
inputs['asr_model_config'] = asr_model_wav_config | |||||
else: | |||||
inputs['asr_model_config'] = asr_model_config | |||||
elif inputs['model_type'] == Frameworks.tf: | |||||
assert inputs['model_config'].__contains__( | |||||
'vocab_file'), 'vocab_file does not exist' | |||||
vocab_file: str = os.path.join( | |||||
inputs['model_workspace'], | |||||
inputs['model_config']['vocab_file']) | |||||
assert os.path.exists(vocab_file), 'vocab file does not exist' | |||||
inputs['vocab_file'] = vocab_file | |||||
assert inputs['model_config'].__contains__( | |||||
'am_mvn_file'), 'am_mvn_file does not exist' | |||||
am_mvn_file: str = os.path.join( | |||||
inputs['model_workspace'], | |||||
inputs['model_config']['am_mvn_file']) | |||||
assert os.path.exists(am_mvn_file), 'am mvn file does not exist' | |||||
inputs['am_mvn_file'] = am_mvn_file | |||||
else: | else: | ||||
inputs['asr_model_config'] = asr_model_config | |||||
raise ValueError('model type is mismatched') | |||||
return inputs | return inputs | ||||
def _env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
if not os.path.exists(inputs['workspace']): | |||||
os.mkdir(inputs['workspace']) | |||||
inputs['output'] = os.path.join(inputs['workspace'], 'logdir') | |||||
if not os.path.exists(inputs['output']): | |||||
os.mkdir(inputs['output']) | |||||
def env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
# run with datasets, should set datasets_path and text_path | # run with datasets, should set datasets_path and text_path | ||||
if inputs['recog_type'] != 'wav': | if inputs['recog_type'] != 'wav': | ||||
inputs['datasets_path'] = inputs['wav_path'] | inputs['datasets_path'] = inputs['wav_path'] | ||||
@@ -162,25 +159,39 @@ class WavToScp(Preprocessor): | |||||
if inputs['audio_format'] == 'wav': | if inputs['audio_format'] == 'wav': | ||||
inputs['wav_path'] = os.path.join(inputs['datasets_path'], | inputs['wav_path'] = os.path.join(inputs['datasets_path'], | ||||
'wav', inputs['recog_type']) | 'wav', inputs['recog_type']) | ||||
inputs['hypothesis_text'] = os.path.join( | |||||
inputs['reference_text'] = os.path.join( | |||||
inputs['datasets_path'], 'transcript', 'data.text') | inputs['datasets_path'], 'transcript', 'data.text') | ||||
assert os.path.exists(inputs['hypothesis_text'] | |||||
), 'hypothesis text does not exist' | |||||
assert os.path.exists( | |||||
inputs['reference_text']), 'reference text does not exist' | |||||
# run with datasets, and audio format is kaldi_ark | |||||
elif inputs['audio_format'] == 'kaldi_ark': | elif inputs['audio_format'] == 'kaldi_ark': | ||||
inputs['wav_path'] = os.path.join(inputs['datasets_path'], | inputs['wav_path'] = os.path.join(inputs['datasets_path'], | ||||
inputs['recog_type']) | inputs['recog_type']) | ||||
inputs['hypothesis_text'] = os.path.join( | |||||
inputs['reference_text'] = os.path.join( | |||||
inputs['wav_path'], 'data.text') | inputs['wav_path'], 'data.text') | ||||
assert os.path.exists(inputs['hypothesis_text'] | |||||
), 'hypothesis text does not exist' | |||||
assert os.path.exists( | |||||
inputs['reference_text']), 'reference text does not exist' | |||||
# run with datasets, and audio format is tfrecord | |||||
elif inputs['audio_format'] == 'tfrecord': | |||||
inputs['wav_path'] = os.path.join(inputs['datasets_path'], | |||||
inputs['recog_type']) | |||||
inputs['reference_text'] = os.path.join( | |||||
inputs['wav_path'], 'data.txt') | |||||
assert os.path.exists( | |||||
inputs['reference_text']), 'reference text does not exist' | |||||
inputs['idx_text'] = os.path.join(inputs['wav_path'], | |||||
'data.idx') | |||||
assert os.path.exists( | |||||
inputs['idx_text']), 'idx text does not exist' | |||||
return inputs | return inputs | ||||
def _scp_generation_from_wav(self, inputs: Dict[str, | |||||
Any]) -> Dict[str, Any]: | |||||
def scp_generation_from_wav(self, inputs: Dict[str, Any]) -> List[Any]: | |||||
"""scp generation from waveform files | """scp generation from waveform files | ||||
""" | """ | ||||
from easyasr.common import asr_utils | |||||
# find all waveform files | # find all waveform files | ||||
wav_list = [] | wav_list = [] | ||||
@@ -191,64 +202,46 @@ class WavToScp(Preprocessor): | |||||
wav_list.append(file_path) | wav_list.append(file_path) | ||||
else: | else: | ||||
wav_dir: str = inputs['wav_path'] | wav_dir: str = inputs['wav_path'] | ||||
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) | |||||
wav_list = asr_utils.recursion_dir_all_wav(wav_list, wav_dir) | |||||
list_count: int = len(wav_list) | list_count: int = len(wav_list) | ||||
inputs['wav_count'] = list_count | inputs['wav_count'] = list_count | ||||
# store all wav into data.0.scp | |||||
inputs['thread_count'] = 1 | |||||
# store all wav into audio list | |||||
audio_lists = [] | |||||
j: int = 0 | j: int = 0 | ||||
wav_list_path = os.path.join(inputs['workspace'], 'data.0.scp') | |||||
with open(wav_list_path, 'a') as f: | |||||
while j < list_count: | |||||
wav_file = wav_list[j] | |||||
wave_scp_content: str = os.path.splitext( | |||||
os.path.basename(wav_file))[0] | |||||
wave_scp_content += ' ' + wav_file + '\n' | |||||
f.write(wave_scp_content) | |||||
j += 1 | |||||
while j < list_count: | |||||
wav_file = wav_list[j] | |||||
wave_key: str = os.path.splitext(os.path.basename(wav_file))[0] | |||||
item = {'key': wave_key, 'file': wav_file} | |||||
audio_lists.append(item) | |||||
j += 1 | |||||
return inputs | |||||
return audio_lists | |||||
def _scp_generation_from_ark(self, inputs: Dict[str, | |||||
Any]) -> Dict[str, Any]: | |||||
def scp_generation_from_ark(self, inputs: Dict[str, Any]) -> List[Any]: | |||||
"""scp generation from kaldi ark file | """scp generation from kaldi ark file | ||||
""" | """ | ||||
inputs['thread_count'] = 1 | |||||
ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp') | ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp') | ||||
ark_file_path = os.path.join(inputs['wav_path'], 'data.ark') | ark_file_path = os.path.join(inputs['wav_path'], 'data.ark') | ||||
assert os.path.exists(ark_scp_path), 'data.scp does not exist' | assert os.path.exists(ark_scp_path), 'data.scp does not exist' | ||||
assert os.path.exists(ark_file_path), 'data.ark does not exist' | assert os.path.exists(ark_file_path), 'data.ark does not exist' | ||||
new_ark_scp_path = os.path.join(inputs['workspace'], 'data.0.scp') | |||||
with open(ark_scp_path, 'r', encoding='utf-8') as f: | with open(ark_scp_path, 'r', encoding='utf-8') as f: | ||||
lines = f.readlines() | lines = f.readlines() | ||||
with open(new_ark_scp_path, 'w', encoding='utf-8') as n: | |||||
for line in lines: | |||||
outs = line.strip().split(' ') | |||||
if len(outs) == 2: | |||||
key = outs[0] | |||||
sub = outs[1].split(':') | |||||
if len(sub) == 2: | |||||
nums = sub[1] | |||||
content = key + ' ' + ark_file_path + ':' + nums + '\n' | |||||
n.write(content) | |||||
return inputs | |||||
def _recursion_dir_all_wave(self, wav_list, | |||||
dir_path: str) -> Dict[str, Any]: | |||||
dir_files = os.listdir(dir_path) | |||||
for file in dir_files: | |||||
file_path = os.path.join(dir_path, file) | |||||
if os.path.isfile(file_path): | |||||
if file_path.endswith('.wav') or file_path.endswith('.WAV'): | |||||
wav_list.append(file_path) | |||||
elif os.path.isdir(file_path): | |||||
self._recursion_dir_all_wave(wav_list, file_path) | |||||
return wav_list | |||||
# store all ark item into audio list | |||||
audio_lists = [] | |||||
for line in lines: | |||||
outs = line.strip().split(' ') | |||||
if len(outs) == 2: | |||||
key = outs[0] | |||||
sub = outs[1].split(':') | |||||
if len(sub) == 2: | |||||
nums = sub[1] | |||||
content = ark_file_path + ':' + nums | |||||
item = {'key': key, 'file': content} | |||||
audio_lists.append(item) | |||||
return audio_lists |
@@ -1,3 +1,4 @@ | |||||
easyasr>=0.0.2 | |||||
espnet>=202204 | espnet>=202204 | ||||
#tts | #tts | ||||
h5py | h5py | ||||
@@ -1,15 +1,20 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import os | import os | ||||
import shutil | import shutil | ||||
import sys | |||||
import tarfile | import tarfile | ||||
import unittest | import unittest | ||||
from typing import Any, Dict, Union | |||||
import numpy as np | |||||
import requests | import requests | ||||
import soundfile | |||||
from modelscope.outputs import OutputKeys | |||||
from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
from modelscope.utils.constant import Tasks | |||||
from modelscope.utils.constant import ColorCodes, Tasks | |||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from modelscope.utils.test_utils import test_level | |||||
from modelscope.utils.test_utils import download_and_untar, test_level | |||||
logger = get_logger() | logger = get_logger() | ||||
@@ -21,6 +26,9 @@ LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AS | |||||
AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz' | AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz' | ||||
AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz' | AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz' | ||||
TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz' | |||||
TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz' | |||||
def un_tar_gz(fname, dirs): | def un_tar_gz(fname, dirs): | ||||
t = tarfile.open(fname) | t = tarfile.open(fname) | ||||
@@ -28,45 +36,168 @@ def un_tar_gz(fname, dirs): | |||||
class AutomaticSpeechRecognitionTest(unittest.TestCase): | class AutomaticSpeechRecognitionTest(unittest.TestCase): | ||||
action_info = { | |||||
'test_run_with_wav_pytorch': { | |||||
'checking_item': OutputKeys.TEXT, | |||||
'example': 'wav_example' | |||||
}, | |||||
'test_run_with_pcm_pytorch': { | |||||
'checking_item': OutputKeys.TEXT, | |||||
'example': 'wav_example' | |||||
}, | |||||
'test_run_with_wav_tf': { | |||||
'checking_item': OutputKeys.TEXT, | |||||
'example': 'wav_example' | |||||
}, | |||||
'test_run_with_pcm_tf': { | |||||
'checking_item': OutputKeys.TEXT, | |||||
'example': 'wav_example' | |||||
}, | |||||
'test_run_with_wav_dataset_pytorch': { | |||||
'checking_item': OutputKeys.TEXT, | |||||
'example': 'dataset_example' | |||||
}, | |||||
'test_run_with_wav_dataset_tf': { | |||||
'checking_item': OutputKeys.TEXT, | |||||
'example': 'dataset_example' | |||||
}, | |||||
'test_run_with_ark_dataset': { | |||||
'checking_item': OutputKeys.TEXT, | |||||
'example': 'dataset_example' | |||||
}, | |||||
'test_run_with_tfrecord_dataset': { | |||||
'checking_item': OutputKeys.TEXT, | |||||
'example': 'dataset_example' | |||||
}, | |||||
'dataset_example': { | |||||
'Wrd': 49532, # the number of words | |||||
'Snt': 5000, # the number of sentences | |||||
'Corr': 47276, # the number of correct words | |||||
'Ins': 49, # the number of insert words | |||||
'Del': 152, # the number of delete words | |||||
'Sub': 2207, # the number of substitution words | |||||
'wrong_words': 2408, # the number of wrong words | |||||
'wrong_sentences': 1598, # the number of wrong sentences | |||||
'Err': 4.86, # WER/CER | |||||
'S.Err': 31.96 # SER | |||||
}, | |||||
'wav_example': { | |||||
'text': '每一天都要快乐喔' | |||||
} | |||||
} | |||||
def setUp(self) -> None: | def setUp(self) -> None: | ||||
self._am_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' | |||||
self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' | |||||
self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1' | |||||
# this temporary workspace dir will store waveform files | # this temporary workspace dir will store waveform files | ||||
self._workspace = os.path.join(os.getcwd(), '.tmp') | |||||
if not os.path.exists(self._workspace): | |||||
os.mkdir(self._workspace) | |||||
self.workspace = os.path.join(os.getcwd(), '.tmp') | |||||
if not os.path.exists(self.workspace): | |||||
os.mkdir(self.workspace) | |||||
def tearDown(self) -> None: | |||||
# remove workspace dir (.tmp) | |||||
shutil.rmtree(self.workspace, ignore_errors=True) | |||||
def run_pipeline(self, model_id: str, | |||||
audio_in: Union[str, bytes]) -> Dict[str, Any]: | |||||
inference_16k_pipline = pipeline( | |||||
task=Tasks.auto_speech_recognition, model=model_id) | |||||
rec_result = inference_16k_pipline(audio_in) | |||||
return rec_result | |||||
def log_error(self, functions: str, result: Dict[str, Any]) -> None: | |||||
logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' | |||||
+ ColorCodes.END) | |||||
logger.error( | |||||
ColorCodes.MAGENTA + functions + ' correct result example:' | |||||
+ ColorCodes.YELLOW | |||||
+ str(self.action_info[self.action_info[functions]['example']]) | |||||
+ ColorCodes.END) | |||||
raise ValueError('asr result is mismatched') | |||||
def check_result(self, functions: str, result: Dict[str, Any]) -> None: | |||||
if result.__contains__(self.action_info[functions]['checking_item']): | |||||
logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' | |||||
+ ColorCodes.END) | |||||
logger.info( | |||||
ColorCodes.YELLOW | |||||
+ str(result[self.action_info[functions]['checking_item']]) | |||||
+ ColorCodes.END) | |||||
else: | |||||
self.log_error(functions, result) | |||||
def wav2bytes(self, wav_file) -> bytes: | |||||
audio, fs = soundfile.read(wav_file) | |||||
# float32 -> int16 | |||||
audio = np.asarray(audio) | |||||
dtype = np.dtype('int16') | |||||
i = np.iinfo(dtype) | |||||
abs_max = 2**(i.bits - 1) | |||||
offset = i.min + abs_max | |||||
audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) | |||||
# int16(PCM_16) -> byte | |||||
audio = audio.tobytes() | |||||
return audio | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
def test_run_with_wav(self): | |||||
def test_run_with_wav_pytorch(self): | |||||
'''run with single waveform file | '''run with single waveform file | ||||
''' | ''' | ||||
logger.info('Run ASR test with waveform file ...') | |||||
logger.info('Run ASR test with waveform file (pytorch)...') | |||||
wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | ||||
inference_16k_pipline = pipeline( | |||||
task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||||
self.assertTrue(inference_16k_pipline is not None) | |||||
rec_result = self.run_pipeline( | |||||
model_id=self.am_pytorch_model_id, audio_in=wav_file_path) | |||||
self.check_result('test_run_with_wav_pytorch', rec_result) | |||||
rec_result = inference_16k_pipline(wav_file_path) | |||||
self.assertTrue(len(rec_result['rec_result']) > 0) | |||||
self.assertTrue(rec_result['rec_result'] != 'None') | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_pcm_pytorch(self): | |||||
'''run with wav data | |||||
''' | ''' | ||||
result structure: | |||||
{ | |||||
'rec_result': '每一天都要快乐喔' | |||||
} | |||||
or | |||||
{ | |||||
'rec_result': 'None' | |||||
} | |||||
logger.info('Run ASR test with wav data (pytorch)...') | |||||
audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||||
rec_result = self.run_pipeline( | |||||
model_id=self.am_pytorch_model_id, audio_in=audio) | |||||
self.check_result('test_run_with_pcm_pytorch', rec_result) | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_wav_tf(self): | |||||
'''run with single waveform file | |||||
''' | ''' | ||||
logger.info('test_run_with_wav rec result: ' | |||||
+ rec_result['rec_result']) | |||||
logger.info('Run ASR test with waveform file (tensorflow)...') | |||||
wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||||
rec_result = self.run_pipeline( | |||||
model_id=self.am_tf_model_id, audio_in=wav_file_path) | |||||
self.check_result('test_run_with_wav_tf', rec_result) | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_pcm_tf(self): | |||||
'''run with wav data | |||||
''' | |||||
logger.info('Run ASR test with wav data (tensorflow)...') | |||||
audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||||
rec_result = self.run_pipeline( | |||||
model_id=self.am_tf_model_id, audio_in=audio) | |||||
self.check_result('test_run_with_pcm_tf', rec_result) | |||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
def test_run_with_wav_dataset(self): | |||||
def test_run_with_wav_dataset_pytorch(self): | |||||
'''run with datasets, and audio format is waveform | '''run with datasets, and audio format is waveform | ||||
datasets directory: | datasets directory: | ||||
<dataset_path> | <dataset_path> | ||||
@@ -84,57 +215,48 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||||
data.text # hypothesis text | data.text # hypothesis text | ||||
''' | ''' | ||||
logger.info('Run ASR test with waveform dataset ...') | |||||
logger.info('Run ASR test with waveform dataset (pytorch)...') | |||||
logger.info('Downloading waveform testsets file ...') | logger.info('Downloading waveform testsets file ...') | ||||
# downloading pos_testsets file | |||||
testsets_file_path = os.path.join(self._workspace, | |||||
LITTLE_TESTSETS_FILE) | |||||
if not os.path.exists(testsets_file_path): | |||||
r = requests.get(LITTLE_TESTSETS_URL) | |||||
with open(testsets_file_path, 'wb') as f: | |||||
f.write(r.content) | |||||
testsets_dir_name = os.path.splitext( | |||||
os.path.basename( | |||||
os.path.splitext( | |||||
os.path.basename(LITTLE_TESTSETS_FILE))[0]))[0] | |||||
# dataset_path = <cwd>/.tmp/data_aishell/wav/test | |||||
dataset_path = os.path.join(self._workspace, testsets_dir_name, 'wav', | |||||
'test') | |||||
# untar the dataset_path file | |||||
if not os.path.exists(dataset_path): | |||||
un_tar_gz(testsets_file_path, self._workspace) | |||||
dataset_path = download_and_untar( | |||||
os.path.join(self.workspace, LITTLE_TESTSETS_FILE), | |||||
LITTLE_TESTSETS_URL, self.workspace) | |||||
dataset_path = os.path.join(dataset_path, 'wav', 'test') | |||||
inference_16k_pipline = pipeline( | |||||
task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||||
self.assertTrue(inference_16k_pipline is not None) | |||||
rec_result = self.run_pipeline( | |||||
model_id=self.am_pytorch_model_id, audio_in=dataset_path) | |||||
self.check_result('test_run_with_wav_dataset_pytorch', rec_result) | |||||
rec_result = inference_16k_pipline(wav_path=dataset_path) | |||||
self.assertTrue(len(rec_result['datasets_result']) > 0) | |||||
self.assertTrue(rec_result['datasets_result']['Wrd'] > 0) | |||||
''' | |||||
result structure: | |||||
{ | |||||
'rec_result': 'None', | |||||
'datasets_result': | |||||
{ | |||||
'Wrd': 1654, # the number of words | |||||
'Snt': 128, # the number of sentences | |||||
'Corr': 1573, # the number of correct words | |||||
'Ins': 1, # the number of insert words | |||||
'Del': 1, # the number of delete words | |||||
'Sub': 80, # the number of substitution words | |||||
'wrong_words': 82, # the number of wrong words | |||||
'wrong_sentences': 47, # the number of wrong sentences | |||||
'Err': 4.96, # WER/CER | |||||
'S.Err': 36.72 # SER | |||||
} | |||||
} | |||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
def test_run_with_wav_dataset_tf(self): | |||||
'''run with datasets, and audio format is waveform | |||||
datasets directory: | |||||
<dataset_path> | |||||
wav | |||||
test # testsets | |||||
xx.wav | |||||
... | |||||
dev # devsets | |||||
yy.wav | |||||
... | |||||
train # trainsets | |||||
zz.wav | |||||
... | |||||
transcript | |||||
data.text # hypothesis text | |||||
''' | ''' | ||||
logger.info('test_run_with_wav_dataset datasets result: ') | |||||
logger.info(rec_result['datasets_result']) | |||||
logger.info('Run ASR test with waveform dataset (tensorflow)...') | |||||
logger.info('Downloading waveform testsets file ...') | |||||
dataset_path = download_and_untar( | |||||
os.path.join(self.workspace, LITTLE_TESTSETS_FILE), | |||||
LITTLE_TESTSETS_URL, self.workspace) | |||||
dataset_path = os.path.join(dataset_path, 'wav', 'test') | |||||
rec_result = self.run_pipeline( | |||||
model_id=self.am_tf_model_id, audio_in=dataset_path) | |||||
self.check_result('test_run_with_wav_dataset_tf', rec_result) | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
def test_run_with_ark_dataset(self): | def test_run_with_ark_dataset(self): | ||||
@@ -155,56 +277,40 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||||
data.text | data.text | ||||
''' | ''' | ||||
logger.info('Run ASR test with ark dataset ...') | |||||
logger.info('Run ASR test with ark dataset (pytorch)...') | |||||
logger.info('Downloading ark testsets file ...') | logger.info('Downloading ark testsets file ...') | ||||
# downloading pos_testsets file | |||||
testsets_file_path = os.path.join(self._workspace, | |||||
AISHELL1_TESTSETS_FILE) | |||||
if not os.path.exists(testsets_file_path): | |||||
r = requests.get(AISHELL1_TESTSETS_URL) | |||||
with open(testsets_file_path, 'wb') as f: | |||||
f.write(r.content) | |||||
testsets_dir_name = os.path.splitext( | |||||
os.path.basename( | |||||
os.path.splitext( | |||||
os.path.basename(AISHELL1_TESTSETS_FILE))[0]))[0] | |||||
# dataset_path = <cwd>/.tmp/aishell1/test | |||||
dataset_path = os.path.join(self._workspace, testsets_dir_name, 'test') | |||||
# untar the dataset_path file | |||||
if not os.path.exists(dataset_path): | |||||
un_tar_gz(testsets_file_path, self._workspace) | |||||
dataset_path = download_and_untar( | |||||
os.path.join(self.workspace, AISHELL1_TESTSETS_FILE), | |||||
AISHELL1_TESTSETS_URL, self.workspace) | |||||
dataset_path = os.path.join(dataset_path, 'test') | |||||
inference_16k_pipline = pipeline( | |||||
task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||||
self.assertTrue(inference_16k_pipline is not None) | |||||
rec_result = self.run_pipeline( | |||||
model_id=self.am_pytorch_model_id, audio_in=dataset_path) | |||||
self.check_result('test_run_with_ark_dataset', rec_result) | |||||
rec_result = inference_16k_pipline(wav_path=dataset_path) | |||||
self.assertTrue(len(rec_result['datasets_result']) > 0) | |||||
self.assertTrue(rec_result['datasets_result']['Wrd'] > 0) | |||||
''' | |||||
result structure: | |||||
{ | |||||
'rec_result': 'None', | |||||
'datasets_result': | |||||
{ | |||||
'Wrd': 104816, # the number of words | |||||
'Snt': 7176, # the number of sentences | |||||
'Corr': 99327, # the number of correct words | |||||
'Ins': 104, # the number of insert words | |||||
'Del': 155, # the number of delete words | |||||
'Sub': 5334, # the number of substitution words | |||||
'wrong_words': 5593, # the number of wrong words | |||||
'wrong_sentences': 2898, # the number of wrong sentences | |||||
'Err': 5.34, # WER/CER | |||||
'S.Err': 40.38 # SER | |||||
} | |||||
} | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_with_tfrecord_dataset(self): | |||||
'''run with datasets, and audio format is tfrecord | |||||
datasets directory: | |||||
<dataset_path> | |||||
test # testsets | |||||
data.records | |||||
data.idx | |||||
data.text | |||||
''' | ''' | ||||
logger.info('test_run_with_ark_dataset datasets result: ') | |||||
logger.info(rec_result['datasets_result']) | |||||
logger.info('Run ASR test with tfrecord dataset (tensorflow)...') | |||||
logger.info('Downloading tfrecord testsets file ...') | |||||
dataset_path = download_and_untar( | |||||
os.path.join(self.workspace, TFRECORD_TESTSETS_FILE), | |||||
TFRECORD_TESTSETS_URL, self.workspace) | |||||
dataset_path = os.path.join(dataset_path, 'test') | |||||
rec_result = self.run_pipeline( | |||||
model_id=self.am_tf_model_id, audio_in=dataset_path) | |||||
self.check_result('test_run_with_tfrecord_dataset', rec_result) | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||