You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_automatic_speech_recognition.py 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import unittest
  5. from typing import Any, Dict, Union
  6. import numpy as np
  7. import soundfile
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines import pipeline
  10. from modelscope.utils.constant import ColorCodes, Tasks
  11. from modelscope.utils.demo_utils import DemoCompatibilityCheck
  12. from modelscope.utils.logger import get_logger
  13. from modelscope.utils.test_utils import download_and_untar, test_level
  14. logger = get_logger()
  15. WAV_FILE = 'data/test/audios/asr_example.wav'
  16. URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'
  17. LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
  18. LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz'
  19. TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz'
  20. TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz'
  21. class AutomaticSpeechRecognitionTest(unittest.TestCase,
  22. DemoCompatibilityCheck):
  23. action_info = {
  24. 'test_run_with_wav_pytorch': {
  25. 'checking_item': OutputKeys.TEXT,
  26. 'example': 'wav_example'
  27. },
  28. 'test_run_with_pcm_pytorch': {
  29. 'checking_item': OutputKeys.TEXT,
  30. 'example': 'wav_example'
  31. },
  32. 'test_run_with_wav_tf': {
  33. 'checking_item': OutputKeys.TEXT,
  34. 'example': 'wav_example'
  35. },
  36. 'test_run_with_pcm_tf': {
  37. 'checking_item': OutputKeys.TEXT,
  38. 'example': 'wav_example'
  39. },
  40. 'test_run_with_url_tf': {
  41. 'checking_item': OutputKeys.TEXT,
  42. 'example': 'wav_example'
  43. },
  44. 'test_run_with_wav_dataset_pytorch': {
  45. 'checking_item': OutputKeys.TEXT,
  46. 'example': 'dataset_example'
  47. },
  48. 'test_run_with_wav_dataset_tf': {
  49. 'checking_item': OutputKeys.TEXT,
  50. 'example': 'dataset_example'
  51. },
  52. 'dataset_example': {
  53. 'Wrd': 49532, # the number of words
  54. 'Snt': 5000, # the number of sentences
  55. 'Corr': 47276, # the number of correct words
  56. 'Ins': 49, # the number of insert words
  57. 'Del': 152, # the number of delete words
  58. 'Sub': 2207, # the number of substitution words
  59. 'wrong_words': 2408, # the number of wrong words
  60. 'wrong_sentences': 1598, # the number of wrong sentences
  61. 'Err': 4.86, # WER/CER
  62. 'S.Err': 31.96 # SER
  63. },
  64. 'wav_example': {
  65. 'text': '每一天都要快乐喔'
  66. }
  67. }
  68. def setUp(self) -> None:
  69. self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch'
  70. self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1'
  71. # this temporary workspace dir will store waveform files
  72. self.workspace = os.path.join(os.getcwd(), '.tmp')
  73. self.task = Tasks.auto_speech_recognition
  74. if not os.path.exists(self.workspace):
  75. os.mkdir(self.workspace)
  76. def tearDown(self) -> None:
  77. # remove workspace dir (.tmp)
  78. shutil.rmtree(self.workspace, ignore_errors=True)
  79. def run_pipeline(self,
  80. model_id: str,
  81. audio_in: Union[str, bytes],
  82. sr: int = 16000) -> Dict[str, Any]:
  83. inference_16k_pipline = pipeline(
  84. task=Tasks.auto_speech_recognition, model=model_id)
  85. rec_result = inference_16k_pipline(audio_in, audio_fs=sr)
  86. return rec_result
  87. def log_error(self, functions: str, result: Dict[str, Any]) -> None:
  88. logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
  89. + ColorCodes.END)
  90. logger.error(
  91. ColorCodes.MAGENTA + functions + ' correct result example:'
  92. + ColorCodes.YELLOW
  93. + str(self.action_info[self.action_info[functions]['example']])
  94. + ColorCodes.END)
  95. raise ValueError('asr result is mismatched')
  96. def check_result(self, functions: str, result: Dict[str, Any]) -> None:
  97. if result.__contains__(self.action_info[functions]['checking_item']):
  98. logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
  99. + ColorCodes.END)
  100. logger.info(
  101. ColorCodes.YELLOW
  102. + str(result[self.action_info[functions]['checking_item']])
  103. + ColorCodes.END)
  104. else:
  105. self.log_error(functions, result)
  106. def wav2bytes(self, wav_file):
  107. audio, fs = soundfile.read(wav_file)
  108. # float32 -> int16
  109. audio = np.asarray(audio)
  110. dtype = np.dtype('int16')
  111. i = np.iinfo(dtype)
  112. abs_max = 2**(i.bits - 1)
  113. offset = i.min + abs_max
  114. audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)
  115. # int16(PCM_16) -> byte
  116. audio = audio.tobytes()
  117. return audio, fs
  118. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  119. def test_run_with_wav_pytorch(self):
  120. """run with single waveform file
  121. """
  122. logger.info('Run ASR test with waveform file (pytorch)...')
  123. wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
  124. rec_result = self.run_pipeline(
  125. model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
  126. self.check_result('test_run_with_wav_pytorch', rec_result)
  127. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  128. def test_run_with_pcm_pytorch(self):
  129. """run with wav data
  130. """
  131. logger.info('Run ASR test with wav data (pytorch)...')
  132. audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
  133. rec_result = self.run_pipeline(
  134. model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
  135. self.check_result('test_run_with_pcm_pytorch', rec_result)
  136. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  137. def test_run_with_wav_tf(self):
  138. """run with single waveform file
  139. """
  140. logger.info('Run ASR test with waveform file (tensorflow)...')
  141. wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
  142. rec_result = self.run_pipeline(
  143. model_id=self.am_tf_model_id, audio_in=wav_file_path)
  144. self.check_result('test_run_with_wav_tf', rec_result)
  145. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  146. def test_run_with_pcm_tf(self):
  147. """run with wav data
  148. """
  149. logger.info('Run ASR test with wav data (tensorflow)...')
  150. audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
  151. rec_result = self.run_pipeline(
  152. model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
  153. self.check_result('test_run_with_pcm_tf', rec_result)
  154. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  155. def test_run_with_url_tf(self):
  156. """run with single url file
  157. """
  158. logger.info('Run ASR test with url file (tensorflow)...')
  159. rec_result = self.run_pipeline(
  160. model_id=self.am_tf_model_id, audio_in=URL_FILE)
  161. self.check_result('test_run_with_url_tf', rec_result)
  162. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  163. def test_run_with_wav_dataset_pytorch(self):
  164. """run with datasets, and audio format is waveform
  165. datasets directory:
  166. <dataset_path>
  167. wav
  168. test # testsets
  169. xx.wav
  170. ...
  171. dev # devsets
  172. yy.wav
  173. ...
  174. train # trainsets
  175. zz.wav
  176. ...
  177. transcript
  178. data.text # hypothesis text
  179. """
  180. logger.info('Run ASR test with waveform dataset (pytorch)...')
  181. logger.info('Downloading waveform testsets file ...')
  182. dataset_path = download_and_untar(
  183. os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
  184. LITTLE_TESTSETS_URL, self.workspace)
  185. dataset_path = os.path.join(dataset_path, 'wav', 'test')
  186. rec_result = self.run_pipeline(
  187. model_id=self.am_pytorch_model_id, audio_in=dataset_path)
  188. self.check_result('test_run_with_wav_dataset_pytorch', rec_result)
  189. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  190. def test_run_with_wav_dataset_tf(self):
  191. """run with datasets, and audio format is waveform
  192. datasets directory:
  193. <dataset_path>
  194. wav
  195. test # testsets
  196. xx.wav
  197. ...
  198. dev # devsets
  199. yy.wav
  200. ...
  201. train # trainsets
  202. zz.wav
  203. ...
  204. transcript
  205. data.text # hypothesis text
  206. """
  207. logger.info('Run ASR test with waveform dataset (tensorflow)...')
  208. logger.info('Downloading waveform testsets file ...')
  209. dataset_path = download_and_untar(
  210. os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
  211. LITTLE_TESTSETS_URL, self.workspace)
  212. dataset_path = os.path.join(dataset_path, 'wav', 'test')
  213. rec_result = self.run_pipeline(
  214. model_id=self.am_tf_model_id, audio_in=dataset_path)
  215. self.check_result('test_run_with_wav_dataset_tf', rec_result)
  216. @unittest.skip('demo compatibility test is only enabled on a needed-basis')
  217. def test_demo_compatibility(self):
  218. self.compatibility_check()
  219. if __name__ == '__main__':
  220. unittest.main()