|
|
@@ -0,0 +1,131 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
import os |
|
|
|
import shutil |
|
|
|
import unittest |
|
|
|
from typing import Any, Dict, Union |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import soundfile |
|
|
|
|
|
|
|
from modelscope.outputs import OutputKeys |
|
|
|
from modelscope.pipelines import pipeline |
|
|
|
from modelscope.utils.constant import ColorCodes, Tasks |
|
|
|
from modelscope.utils.demo_utils import DemoCompatibilityCheck |
|
|
|
from modelscope.utils.logger import get_logger |
|
|
|
from modelscope.utils.test_utils import download_and_untar, test_level |
|
|
|
|
|
|
|
# Module-level logger used by the test helpers below.
logger = get_logger()

# Example audio as a local path; the tests join it onto the current working
# directory before use.
WAV_FILE = 'data/test/audios/asr_example.wav'

# Example audio as a remote URL (presumably the same recording as WAV_FILE,
# judging by the file name -- confirm before relying on it).
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'  # noqa: E501
|
|
|
|
|
|
|
|
|
|
|
class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase,
                                          DemoCompatibilityCheck):
    """Smoke tests for the WeNet automatic-speech-recognition pipeline.

    Each test feeds the example utterance to the pipeline in a different
    form (raw PCM bytes, local wav path, remote URL) and then checks that
    the returned mapping contains the expected result key.
    """

    # For every test name: 'checking_item' is the key that must be present
    # in the pipeline result, and 'example' names the entry of this same
    # dict that holds a reference result (only used in failure logs here).
    action_info = {
        'test_run_with_pcm': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_url': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_wav': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'wav_example': {
            'text': '每一天都要快乐喔'
        }
    }

    def setUp(self) -> None:
        """Select the model and task, and create the scratch directory."""
        self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online'
        self.task = Tasks.auto_speech_recognition
        # this temporary workspace dir will store waveform files
        self.workspace = os.path.join(os.getcwd(), '.tmp')
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(self.workspace, exist_ok=True)

    def tearDown(self) -> None:
        """Remove the scratch directory, ignoring any removal errors."""
        # remove workspace dir (.tmp)
        shutil.rmtree(self.workspace, ignore_errors=True)

    def run_pipeline(self,
                     model_id: str,
                     audio_in: Union[str, bytes],
                     sr: Union[int, None] = None) -> Dict[str, Any]:
        """Build a speech-recognition pipeline for ``model_id`` and run it.

        Args:
            model_id: Model identifier handed to ``pipeline``.
            audio_in: The audio, either as a string (file path or URL) or
                as raw bytes.
            sr: Sample rate forwarded to the pipeline as ``audio_fs``;
                ``None`` is forwarded when the caller does not supply one.

        Returns:
            Whatever the pipeline call returns.
        """
        inference_pipeline = pipeline(
            task=Tasks.auto_speech_recognition, model=model_id)
        return inference_pipeline(audio_in, audio_fs=sr)

    def log_error(self, functions: str, result: Dict[str, Any]) -> None:
        """Log a failure for the test named ``functions`` and raise.

        Args:
            functions: Name of the failing test; must be a key of
                ``action_info``.
            result: The pipeline output. Kept for interface compatibility;
                it is not inspected.

        Raises:
            ValueError: Always, after the failure has been logged.
        """
        logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                     + ColorCodes.END)
        # Two-step lookup: test name -> example key -> reference result.
        example_key = self.action_info[functions]['example']
        expected = str(self.action_info[example_key])
        logger.error(ColorCodes.MAGENTA + functions
                     + ' correct result example:' + ColorCodes.YELLOW
                     + expected + ColorCodes.END)
        raise ValueError('asr result is mismatched')

    def check_result(self, functions: str, result: Dict[str, Any]) -> None:
        """Check that ``result`` contains the key registered for the test.

        Only the presence of the key is checked; its value is logged but
        not compared with the reference example.

        Args:
            functions: Name of the calling test; must be a key of
                ``action_info``.
            result: The pipeline output to inspect.

        Raises:
            ValueError: Via ``log_error`` when the key is missing.
        """
        checking_item = self.action_info[functions]['checking_item']
        if checking_item not in result:
            self.log_error(functions, result)
            return

        logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                    + ColorCodes.END)
        logger.info(ColorCodes.YELLOW + str(result[checking_item])
                    + ColorCodes.END)

    def wav2bytes(self, wav_file):
        """Read ``wav_file`` and return ``(pcm_bytes, sample_rate)``.

        The samples are rescaled to 16-bit integers and serialised with
        ``ndarray.tobytes``.
        """
        audio, fs = soundfile.read(wav_file)

        # float -> int16 (assumes soundfile.read yields float samples in
        # [-1.0, 1.0] -- TODO confirm for non-default files)
        audio = np.asarray(audio)
        dtype = np.dtype('int16')
        info = np.iinfo(dtype)
        abs_max = 2**(info.bits - 1)
        # Evaluates to 0 for a signed type such as int16.
        offset = info.min + abs_max
        scaled = (audio * abs_max + offset).clip(info.min, info.max)

        # int16(PCM_16) -> byte
        return scaled.astype(dtype).tobytes(), fs

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_pcm(self):
        """run with wav data
        """
        logger.info('Run ASR test with wav data (wenet)...')
        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=audio, sr=sr)
        self.check_result('test_run_with_pcm', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        """run with single waveform file
        """
        logger.info('Run ASR test with waveform file (wenet)...')
        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=wav_file_path)
        self.check_result('test_run_with_wav', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_url(self):
        """run with single url file
        """
        logger.info('Run ASR test with url file (wenet)...')
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=URL_FILE)
        self.check_result('test_run_with_url', rec_result)
|
|
|
|
|
|
|
|
|
|
|
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()