# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest
from typing import Any, Dict, Optional, Union

import numpy as np
import soundfile

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ColorCodes, Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import download_and_untar, test_level

# Module-level logger used by the test helpers and test methods below.
logger = get_logger()

# Sample recording stored in the repository; the tests join this relative
# path with the current working directory before using it.
WAV_FILE = 'data/test/audios/asr_example.wav'
# Remotely hosted sample recording passed to the pipeline by the URL test.
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'


class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase,
                                          DemoCompatibilityCheck):
    """Smoke tests for the WeNet automatic-speech-recognition pipeline.

    Each test feeds a sample recording to the pipeline in a different form
    (raw PCM bytes, local wav path, remote URL) and checks that the result
    dictionary contains the expected output key.
    """

    # Per-test expectations: 'checking_item' is the key that must be present
    # in the pipeline result, and 'example' names the entry of this same
    # dict that holds a reference result (only logged when a check fails).
    action_info = {
        'test_run_with_pcm': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_url': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_wav': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'wav_example': {
            'text': '每一天都要快乐喔'
        }
    }

    def setUp(self) -> None:
        """Select the model under test and create the scratch directory."""
        self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online'
        # this temporary workspace dir will store waveform files
        self.workspace = os.path.join(os.getcwd(), '.tmp')
        self.task = Tasks.auto_speech_recognition
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(self.workspace, exist_ok=True)

    def tearDown(self) -> None:
        """Delete the scratch directory created by ``setUp``."""
        # remove workspace dir (.tmp); best effort, errors are ignored
        shutil.rmtree(self.workspace, ignore_errors=True)

    def run_pipeline(self,
                     model_id: str,
                     audio_in: Union[str, bytes],
                     sr: Optional[int] = None) -> Dict[str, Any]:
        """Build an ASR pipeline for ``model_id`` and run it once.

        Args:
            model_id: Identifier of the model handed to ``pipeline``.
            audio_in: Audio input; the tests in this class pass a local wav
                path, a URL, or raw PCM bytes.
            sr: Value forwarded to the pipeline as ``audio_fs``; ``None``
                when the caller does not supply one.

        Returns:
            Whatever the pipeline call returns for the given input.
        """
        asr_pipeline = pipeline(
            task=Tasks.auto_speech_recognition, model=model_id)
        return asr_pipeline(audio_in, audio_fs=sr)

    def log_error(self, functions: str, result: Dict[str, Any]) -> None:
        """Log a failed check and abort the test.

        Args:
            functions: Name of the test, used as a key into ``action_info``.
            result: The pipeline output that failed the check.

        Raises:
            ValueError: Always, after the diagnostics have been logged.
        """
        example_key = self.action_info[functions]['example']
        logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                     + ColorCodes.END)
        logger.error(
            ColorCodes.MAGENTA + functions + ' correct result example:'
            + ColorCodes.YELLOW + str(self.action_info[example_key])
            + ColorCodes.END)
        # Also report what was actually produced so the mismatch can be
        # diagnosed from the log alone.
        logger.error(
            ColorCodes.MAGENTA + functions + ' actual result:'
            + ColorCodes.YELLOW + str(result) + ColorCodes.END)
        raise ValueError('asr result is mismatched')

    def check_result(self, functions: str, result: Dict[str, Any]) -> None:
        """Verify that ``result`` contains the key expected for a test.

        Only the presence of the key is checked; its value is logged but not
        compared with the reference example.

        Args:
            functions: Name of the test, used as a key into ``action_info``.
            result: The pipeline output to inspect.

        Raises:
            ValueError: If the expected key is missing (via ``log_error``).
        """
        checking_item = self.action_info[functions]['checking_item']
        if checking_item in result:
            logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                        + ColorCodes.END)
            logger.info(ColorCodes.YELLOW + str(result[checking_item])
                        + ColorCodes.END)
        else:
            self.log_error(functions, result)

    def wav2bytes(self, wav_file):
        """Read an audio file and return its samples as 16-bit PCM bytes.

        Args:
            wav_file: Path (or any other input accepted by
                ``soundfile.read``) of the audio file to load.

        Returns:
            A ``(pcm_bytes, sample_rate)`` tuple.
        """
        audio, fs = soundfile.read(wav_file)

        # float -> int16 (soundfile.read yields floating-point samples when
        # no dtype is requested)
        audio = np.asarray(audio)
        dtype = np.dtype('int16')
        i = np.iinfo(dtype)
        abs_max = 2**(i.bits - 1)
        # Evaluates to zero for a signed type such as int16.
        offset = i.min + abs_max
        # Scaling by 32768 sends +1.0 just past the int16 maximum, so clip
        # before the cast to saturate instead of wrapping around.
        audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)

        # int16(PCM_16) -> byte
        audio = audio.tobytes()
        return audio, fs

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_pcm(self):
        """Run recognition on raw PCM bytes decoded from the local wav file."""
        logger.info('Run ASR test with wav data (wenet)...')
        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=audio, sr=sr)
        self.check_result('test_run_with_pcm', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        """Run recognition on a single local waveform file path."""
        logger.info('Run ASR test with waveform file (wenet)...')
        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=wav_file_path)
        self.check_result('test_run_with_wav', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_url(self):
        """Run recognition on a single audio file referenced by URL."""
        logger.info('Run ASR test with url file (wenet)...')
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=URL_FILE)
        self.check_result('test_run_with_url', rec_result)


if __name__ == '__main__':
    # Run every test case in this module when the file is executed directly.
    unittest.main()