You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_wenet_automatic_speech_recognition.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest
from typing import Any, Dict, Optional, Union

import numpy as np
import soundfile

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ColorCodes, Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import download_and_untar, test_level

  14. logger = get_logger()
  15. WAV_FILE = 'data/test/audios/asr_example.wav'
  16. URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'
  17. class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase,
  18. DemoCompatibilityCheck):
  19. action_info = {
  20. 'test_run_with_pcm': {
  21. 'checking_item': OutputKeys.TEXT,
  22. 'example': 'wav_example'
  23. },
  24. 'test_run_with_url': {
  25. 'checking_item': OutputKeys.TEXT,
  26. 'example': 'wav_example'
  27. },
  28. 'test_run_with_wav': {
  29. 'checking_item': OutputKeys.TEXT,
  30. 'example': 'wav_example'
  31. },
  32. 'wav_example': {
  33. 'text': '每一天都要快乐喔'
  34. }
  35. }
  36. def setUp(self) -> None:
  37. self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online'
  38. # this temporary workspace dir will store waveform files
  39. self.workspace = os.path.join(os.getcwd(), '.tmp')
  40. self.task = Tasks.auto_speech_recognition
  41. if not os.path.exists(self.workspace):
  42. os.mkdir(self.workspace)
  43. def tearDown(self) -> None:
  44. # remove workspace dir (.tmp)
  45. shutil.rmtree(self.workspace, ignore_errors=True)
  46. def run_pipeline(self,
  47. model_id: str,
  48. audio_in: Union[str, bytes],
  49. sr: int = None) -> Dict[str, Any]:
  50. inference_16k_pipline = pipeline(
  51. task=Tasks.auto_speech_recognition, model=model_id)
  52. rec_result = inference_16k_pipline(audio_in, audio_fs=sr)
  53. return rec_result
  54. def log_error(self, functions: str, result: Dict[str, Any]) -> None:
  55. logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
  56. + ColorCodes.END)
  57. logger.error(
  58. ColorCodes.MAGENTA + functions + ' correct result example:'
  59. + ColorCodes.YELLOW
  60. + str(self.action_info[self.action_info[functions]['example']])
  61. + ColorCodes.END)
  62. raise ValueError('asr result is mismatched')
  63. def check_result(self, functions: str, result: Dict[str, Any]) -> None:
  64. if result.__contains__(self.action_info[functions]['checking_item']):
  65. logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
  66. + ColorCodes.END)
  67. logger.info(
  68. ColorCodes.YELLOW
  69. + str(result[self.action_info[functions]['checking_item']])
  70. + ColorCodes.END)
  71. else:
  72. self.log_error(functions, result)
  73. def wav2bytes(self, wav_file):
  74. audio, fs = soundfile.read(wav_file)
  75. # float32 -> int16
  76. audio = np.asarray(audio)
  77. dtype = np.dtype('int16')
  78. i = np.iinfo(dtype)
  79. abs_max = 2**(i.bits - 1)
  80. offset = i.min + abs_max
  81. audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)
  82. # int16(PCM_16) -> byte
  83. audio = audio.tobytes()
  84. return audio, fs
  85. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  86. def test_run_with_pcm(self):
  87. """run with wav data
  88. """
  89. logger.info('Run ASR test with wav data (wenet)...')
  90. audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
  91. rec_result = self.run_pipeline(
  92. model_id=self.am_model_id, audio_in=audio, sr=sr)
  93. self.check_result('test_run_with_pcm', rec_result)
  94. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  95. def test_run_with_wav(self):
  96. """run with single waveform file
  97. """
  98. logger.info('Run ASR test with waveform file (wenet)...')
  99. wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
  100. rec_result = self.run_pipeline(
  101. model_id=self.am_model_id, audio_in=wav_file_path)
  102. self.check_result('test_run_with_wav', rec_result)
  103. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  104. def test_run_with_url(self):
  105. """run with single url file
  106. """
  107. logger.info('Run ASR test with url file (wenet)...')
  108. rec_result = self.run_pipeline(
  109. model_id=self.am_model_id, audio_in=URL_FILE)
  110. self.check_result('test_run_with_url', rec_result)
  111. if __name__ == '__main__':
  112. unittest.main()