
asr unified interface: support conformer and uniasr models

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10964641

* support new asr paraformer model

* support asr conformer model

* add new asr model tests

* fix format

* support new in params

* fix conflict

* type fix

* fix conflict
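
For context on the unified interface described above, a minimal usage sketch through the ModelScope ASR pipeline. The model ID and audio input are placeholders (not taken from this change), and the `audio_in` keyword is assumed from the pipeline's documented usage at the time.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Placeholder model ID: a conformer, uniasr, or paraformer ASR model hosted on
# ModelScope; models whose code_base is 'funasr' route through funasr's
# inference_launch after this change.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/<asr_model_id>')

# Placeholder input: a local wav path or an audio URL.
rec_result = inference_pipeline(audio_in='example.wav')
print(rec_result)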
jiangyu.xzy committed 2 years ago · commit db7c5d1494 · master^2
1 changed file with 23 additions and 10 deletions

modelscope/pipelines/audio/asr_inference_pipeline.py (+23, −10)

@@ -124,6 +124,15 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             frontend_conf = None
             if 'frontend_conf' in root:
                 frontend_conf = root['frontend_conf']
+            token_num_relax = None
+            if 'token_num_relax' in root:
+                token_num_relax = root['token_num_relax']
+            decoding_ind = None
+            if 'decoding_ind' in root:
+                decoding_ind = root['decoding_ind']
+            decoding_mode = None
+            if 'decoding_mode' in root:
+                decoding_mode = root['decoding_mode']

             cmd['beam_size'] = root['beam_size']
             cmd['penalty'] = root['penalty']
@@ -138,6 +147,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             cmd['frontend_conf'] = frontend_conf
             if frontend_conf is not None and 'fs' in frontend_conf:
                 cmd['fs']['model_fs'] = frontend_conf['fs']
+            cmd['token_num_relax'] = token_num_relax
+            cmd['decoding_ind'] = decoding_ind
+            cmd['decoding_mode'] = decoding_mode

         elif self.framework == Frameworks.tf:
             cmd['fs']['model_fs'] = inputs['model_config']['fs']
@@ -234,16 +246,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
     def run_inference(self, cmd):
         asr_result = []
         if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr':
-            if cmd['mode'] == 'asr':
-                from funasr.bin import asr_inference_modelscope as asr_inference
-            else:
-                from funasr.bin import asr_inference_paraformer_modelscope as asr_inference
+            from funasr.bin import asr_inference_launch

-            if hasattr(asr_inference, 'set_parameters'):
-                asr_inference.set_parameters(sample_rate=cmd['fs'])
-                asr_inference.set_parameters(language=cmd['lang'])
+            if hasattr(asr_inference_launch, 'set_parameters'):
+                asr_inference_launch.set_parameters(sample_rate=cmd['fs'])
+                asr_inference_launch.set_parameters(language=cmd['lang'])

-            asr_result = asr_inference.asr_inference(
+            asr_result = asr_inference_launch.inference_launch(
+                mode=cmd['mode'],
                 batch_size=cmd['batch_size'],
                 maxlenratio=cmd['maxlenratio'],
                 minlenratio=cmd['minlenratio'],
@@ -253,13 +263,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
                 lm_weight=cmd['lm_weight'],
                 penalty=cmd['penalty'],
                 log_level=cmd['log_level'],
-                name_and_type=cmd['name_and_type'],
+                data_path_and_name_and_type=cmd['name_and_type'],
                 audio_lists=cmd['audio_in'],
                 asr_train_config=cmd['asr_train_config'],
                 asr_model_file=cmd['asr_model_file'],
                 lm_file=cmd['lm_file'],
                 lm_train_config=cmd['lm_train_config'],
-                frontend_conf=cmd['frontend_conf'])
+                frontend_conf=cmd['frontend_conf'],
+                token_num_relax=cmd['token_num_relax'],
+                decoding_ind=cmd['decoding_ind'],
+                decoding_mode=cmd['decoding_mode'])
         elif self.framework == Frameworks.torch:
             from easyasr import asr_inference_paraformer_espnet
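
Review note, not part of the change: the three new options (token_num_relax, decoding_ind, decoding_mode) are read defensively in the first hunk because not every model config defines them, default to None, and are forwarded unchanged to funasr's inference_launch in the last hunk. An equivalent, more compact way to express the optional lookup, shown only as a sketch:

# Illustration only: the actual change keeps explicit `if key in root` checks.
for key in ('token_num_relax', 'decoding_ind', 'decoding_mode'):
    cmd[key] = root.get(key)  # None when the model config omits the key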


