diff --git a/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py b/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py
index 7db11190..1947629f 100644
--- a/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py
+++ b/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py
@@ -8,6 +8,7 @@
 from modelscope.models.base import Model
 from modelscope.models.builder import MODELS
 from modelscope.utils.constant import Tasks
+import json
 import wenetruntime as wenet
 
 __all__ = ['WeNetAutomaticSpeechRecognition']
@@ -23,23 +24,15 @@ class WeNetAutomaticSpeechRecognition(Model):
         Args:
             model_dir (str): the model path.
-            am_model_name (str): the am model name from configuration.json
-            model_config (Dict[str, Any]): the detail config about model from
-                configuration.json
         """
         super().__init__(model_dir, am_model_name, model_config, *args,
                          **kwargs)
-        self.model_cfg = {
-            # the recognition model dir path
-            'model_dir': model_dir,
-            # the recognition model config dict
-            'model_config': model_config
-        }
-        self.decoder = None
-
-    def forward(self) -> Dict[str, Any]:
-        """preload model and return the info of the model
-        """
-        model_dir = self.model_cfg['model_dir']
         self.decoder = wenet.Decoder(model_dir, lang='chs')
-        return self.model_cfg
 
+    def forward(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        if inputs['audio_format'] == 'wav':
+            rst = self.decoder.decode_wav(inputs['audio'])
+        else:
+            rst = self.decoder.decode(inputs['audio'])
+        text = json.loads(rst)['nbest'][0]['sentence']
+        return {'text': text}
diff --git a/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py b/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py
index 33e8c617..6df47bcb 100644
--- a/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py
+++ b/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py
@@ -29,8 +29,6 @@ class WeNetAutomaticSpeechRecognitionPipeline(Pipeline):
         """use `model` and `preprocessor` to create an asr pipeline for prediction
         """
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
-        self.model_cfg = self.model.forward()
-        self.decoder = self.model.decoder
 
     def __call__(self,
                  audio_in: Union[str, bytes],
@@ -68,17 +66,19 @@ class WeNetAutomaticSpeechRecognitionPipeline(Pipeline):
         if checking_audio_fs is not None:
             self.audio_fs = checking_audio_fs
 
-        self.model_cfg['audio'] = self.audio_in
-        self.model_cfg['audio_fs'] = self.audio_fs
-
-        output = self.forward(self.model_cfg)
+        inputs = {
+            'audio': self.audio_in,
+            'audio_format': self.audio_format,
+            'audio_fs': self.audio_fs
+        }
+        output = self.forward(inputs)
         rst = self.postprocess(output['asr_result'])
         return rst
 
     def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         """Decoding
         """
-        inputs['asr_result'] = self.decoder.decode(inputs['audio'])
+        inputs['asr_result'] = self.model(inputs)
         return inputs
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/tests/pipelines/test_wenet_automatic_speech_recognition.py b/tests/pipelines/test_wenet_automatic_speech_recognition.py
new file mode 100644
index 00000000..4adf8119
--- /dev/null
+++ b/tests/pipelines/test_wenet_automatic_speech_recognition.py
@@ -0,0 +1,131 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import unittest
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import soundfile
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import ColorCodes, Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+WAV_FILE = 'data/test/audios/asr_example.wav'
+URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'
+
+
+class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase,
+                                          DemoCompatibilityCheck):
+    action_info = {
+        'test_run_with_pcm': {
+            'checking_item': OutputKeys.TEXT,
+            'example': 'wav_example'
+        },
+        'test_run_with_url': {
+            'checking_item': OutputKeys.TEXT,
+            'example': 'wav_example'
+        },
+        'test_run_with_wav': {
+            'checking_item': OutputKeys.TEXT,
+            'example': 'wav_example'
+        },
+        'wav_example': {
+            'text': '每一天都要快乐喔'
+        }
+    }
+
+    def setUp(self) -> None:
+        self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online'
+        # this temporary workspace dir will store waveform files
+        self.workspace = os.path.join(os.getcwd(), '.tmp')
+        self.task = Tasks.auto_speech_recognition
+        if not os.path.exists(self.workspace):
+            os.mkdir(self.workspace)
+
+    def tearDown(self) -> None:
+        # remove workspace dir (.tmp)
+        shutil.rmtree(self.workspace, ignore_errors=True)
+
+    def run_pipeline(self,
+                     model_id: str,
+                     audio_in: Union[str, bytes],
+                     sr: Optional[int] = None) -> Dict[str, Any]:
+        inference_16k_pipeline = pipeline(
+            task=Tasks.auto_speech_recognition, model=model_id)
+        rec_result = inference_16k_pipeline(audio_in, audio_fs=sr)
+        return rec_result
+
+    def log_error(self, functions: str, result: Dict[str, Any]) -> None:
+        logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
+                     + ColorCodes.END)
+        logger.error(
+            ColorCodes.MAGENTA + functions + ' correct result example:'
+            + ColorCodes.YELLOW
+            + str(self.action_info[self.action_info[functions]['example']])
+            + ColorCodes.END)
+        raise ValueError('asr result is mismatched')
+
+    def check_result(self, functions: str, result: Dict[str, Any]) -> None:
+        if self.action_info[functions]['checking_item'] in result:
+            logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
+                        + ColorCodes.END)
+            logger.info(
+                ColorCodes.YELLOW
+                + str(result[self.action_info[functions]['checking_item']])
+                + ColorCodes.END)
+        else:
+            self.log_error(functions, result)
+
+    def wav2bytes(self, wav_file):
+        audio, fs = soundfile.read(wav_file)
+
+        # float32 -> int16
+        audio = np.asarray(audio)
+        dtype = np.dtype('int16')
+        i = np.iinfo(dtype)
+        abs_max = 2**(i.bits - 1)
+        offset = i.min + abs_max
+        audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)
+
+        # int16(PCM_16) -> byte
+        audio = audio.tobytes()
+        return audio, fs
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_pcm(self):
+        """run with raw PCM data (bytes)
+        """
+        logger.info('Run ASR test with PCM data (wenet)...')
+        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
+        rec_result = self.run_pipeline(
+            model_id=self.am_model_id, audio_in=audio, sr=sr)
+        self.check_result('test_run_with_pcm', rec_result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_wav(self):
+        """run with a single waveform file
+        """
+        logger.info('Run ASR test with waveform file (wenet)...')
+        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
+        rec_result = self.run_pipeline(
+            model_id=self.am_model_id, audio_in=wav_file_path)
+        self.check_result('test_run_with_wav', rec_result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_url(self):
+        """run with a single url file
+        """
+        logger.info('Run ASR test with url file (wenet)...')
+        rec_result = self.run_pipeline(
+            model_id=self.am_model_id, audio_in=URL_FILE)
+        self.check_result('test_run_with_url', rec_result)
+
+
+if __name__ == '__main__':
+    unittest.main()
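
Usage note: below is a minimal, untested invocation sketch for the refactored
pipeline; it mirrors run_pipeline() in the test above and reuses the same model
id and sample wav path, which may need adjusting for a local checkout.

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # model id and wav path are copied from the test file above
    asr = pipeline(
        task=Tasks.auto_speech_recognition,
        model='wenet/u2pp_conformer-asr-cn-16k-online')
    rec_result = asr('data/test/audios/asr_example.wav')
    print(rec_result)  # expected: {'text': '每一天都要快乐喔'}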