You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they may include dashes ('-') and can be up to 35 characters long.

test_automatic_speech_recognition.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import sys
  5. import tarfile
  6. import unittest
  7. from typing import Any, Dict, Union
  8. import numpy as np
  9. import requests
  10. import soundfile
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines import pipeline
  13. from modelscope.utils.constant import ColorCodes, Tasks
  14. from modelscope.utils.logger import get_logger
  15. from modelscope.utils.test_utils import download_and_untar, test_level
logger = get_logger()

# Single local waveform sample used by the wav/pcm smoke tests.
WAV_FILE = 'data/test/audios/asr_example.wav'

# Small waveform testset archive (local filename + download URL).
LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz'

# AISHELL-1 testset archive in kaldi ark format (local filename + download URL).
AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz'
AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz'

# Testset archive in tfrecord format (local filename + download URL).
TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz'
TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz'
  24. def un_tar_gz(fname, dirs):
  25. t = tarfile.open(fname)
  26. t.extractall(path=dirs)
  27. class AutomaticSpeechRecognitionTest(unittest.TestCase):
  28. action_info = {
  29. 'test_run_with_wav_pytorch': {
  30. 'checking_item': OutputKeys.TEXT,
  31. 'example': 'wav_example'
  32. },
  33. 'test_run_with_pcm_pytorch': {
  34. 'checking_item': OutputKeys.TEXT,
  35. 'example': 'wav_example'
  36. },
  37. 'test_run_with_wav_tf': {
  38. 'checking_item': OutputKeys.TEXT,
  39. 'example': 'wav_example'
  40. },
  41. 'test_run_with_pcm_tf': {
  42. 'checking_item': OutputKeys.TEXT,
  43. 'example': 'wav_example'
  44. },
  45. 'test_run_with_wav_dataset_pytorch': {
  46. 'checking_item': OutputKeys.TEXT,
  47. 'example': 'dataset_example'
  48. },
  49. 'test_run_with_wav_dataset_tf': {
  50. 'checking_item': OutputKeys.TEXT,
  51. 'example': 'dataset_example'
  52. },
  53. 'test_run_with_ark_dataset': {
  54. 'checking_item': OutputKeys.TEXT,
  55. 'example': 'dataset_example'
  56. },
  57. 'test_run_with_tfrecord_dataset': {
  58. 'checking_item': OutputKeys.TEXT,
  59. 'example': 'dataset_example'
  60. },
  61. 'dataset_example': {
  62. 'Wrd': 49532, # the number of words
  63. 'Snt': 5000, # the number of sentences
  64. 'Corr': 47276, # the number of correct words
  65. 'Ins': 49, # the number of insert words
  66. 'Del': 152, # the number of delete words
  67. 'Sub': 2207, # the number of substitution words
  68. 'wrong_words': 2408, # the number of wrong words
  69. 'wrong_sentences': 1598, # the number of wrong sentences
  70. 'Err': 4.86, # WER/CER
  71. 'S.Err': 31.96 # SER
  72. },
  73. 'wav_example': {
  74. 'text': '每一天都要快乐喔'
  75. }
  76. }
  77. def setUp(self) -> None:
  78. self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch'
  79. self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1'
  80. # this temporary workspace dir will store waveform files
  81. self.workspace = os.path.join(os.getcwd(), '.tmp')
  82. if not os.path.exists(self.workspace):
  83. os.mkdir(self.workspace)
  84. def tearDown(self) -> None:
  85. # remove workspace dir (.tmp)
  86. shutil.rmtree(self.workspace, ignore_errors=True)
  87. def run_pipeline(self, model_id: str,
  88. audio_in: Union[str, bytes]) -> Dict[str, Any]:
  89. inference_16k_pipline = pipeline(
  90. task=Tasks.auto_speech_recognition, model=model_id)
  91. rec_result = inference_16k_pipline(audio_in)
  92. return rec_result
  93. def log_error(self, functions: str, result: Dict[str, Any]) -> None:
  94. logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
  95. + ColorCodes.END)
  96. logger.error(
  97. ColorCodes.MAGENTA + functions + ' correct result example:'
  98. + ColorCodes.YELLOW
  99. + str(self.action_info[self.action_info[functions]['example']])
  100. + ColorCodes.END)
  101. raise ValueError('asr result is mismatched')
  102. def check_result(self, functions: str, result: Dict[str, Any]) -> None:
  103. if result.__contains__(self.action_info[functions]['checking_item']):
  104. logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
  105. + ColorCodes.END)
  106. logger.info(
  107. ColorCodes.YELLOW
  108. + str(result[self.action_info[functions]['checking_item']])
  109. + ColorCodes.END)
  110. else:
  111. self.log_error(functions, result)
  112. def wav2bytes(self, wav_file) -> bytes:
  113. audio, fs = soundfile.read(wav_file)
  114. # float32 -> int16
  115. audio = np.asarray(audio)
  116. dtype = np.dtype('int16')
  117. i = np.iinfo(dtype)
  118. abs_max = 2**(i.bits - 1)
  119. offset = i.min + abs_max
  120. audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)
  121. # int16(PCM_16) -> byte
  122. audio = audio.tobytes()
  123. return audio
  124. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  125. def test_run_with_wav_pytorch(self):
  126. '''run with single waveform file
  127. '''
  128. logger.info('Run ASR test with waveform file (pytorch)...')
  129. wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
  130. rec_result = self.run_pipeline(
  131. model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
  132. self.check_result('test_run_with_wav_pytorch', rec_result)
  133. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  134. def test_run_with_pcm_pytorch(self):
  135. '''run with wav data
  136. '''
  137. logger.info('Run ASR test with wav data (pytorch)...')
  138. audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
  139. rec_result = self.run_pipeline(
  140. model_id=self.am_pytorch_model_id, audio_in=audio)
  141. self.check_result('test_run_with_pcm_pytorch', rec_result)
  142. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  143. def test_run_with_wav_tf(self):
  144. '''run with single waveform file
  145. '''
  146. logger.info('Run ASR test with waveform file (tensorflow)...')
  147. wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
  148. rec_result = self.run_pipeline(
  149. model_id=self.am_tf_model_id, audio_in=wav_file_path)
  150. self.check_result('test_run_with_wav_tf', rec_result)
  151. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  152. def test_run_with_pcm_tf(self):
  153. '''run with wav data
  154. '''
  155. logger.info('Run ASR test with wav data (tensorflow)...')
  156. audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
  157. rec_result = self.run_pipeline(
  158. model_id=self.am_tf_model_id, audio_in=audio)
  159. self.check_result('test_run_with_pcm_tf', rec_result)
  160. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  161. def test_run_with_wav_dataset_pytorch(self):
  162. '''run with datasets, and audio format is waveform
  163. datasets directory:
  164. <dataset_path>
  165. wav
  166. test # testsets
  167. xx.wav
  168. ...
  169. dev # devsets
  170. yy.wav
  171. ...
  172. train # trainsets
  173. zz.wav
  174. ...
  175. transcript
  176. data.text # hypothesis text
  177. '''
  178. logger.info('Run ASR test with waveform dataset (pytorch)...')
  179. logger.info('Downloading waveform testsets file ...')
  180. dataset_path = download_and_untar(
  181. os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
  182. LITTLE_TESTSETS_URL, self.workspace)
  183. dataset_path = os.path.join(dataset_path, 'wav', 'test')
  184. rec_result = self.run_pipeline(
  185. model_id=self.am_pytorch_model_id, audio_in=dataset_path)
  186. self.check_result('test_run_with_wav_dataset_pytorch', rec_result)
  187. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  188. def test_run_with_wav_dataset_tf(self):
  189. '''run with datasets, and audio format is waveform
  190. datasets directory:
  191. <dataset_path>
  192. wav
  193. test # testsets
  194. xx.wav
  195. ...
  196. dev # devsets
  197. yy.wav
  198. ...
  199. train # trainsets
  200. zz.wav
  201. ...
  202. transcript
  203. data.text # hypothesis text
  204. '''
  205. logger.info('Run ASR test with waveform dataset (tensorflow)...')
  206. logger.info('Downloading waveform testsets file ...')
  207. dataset_path = download_and_untar(
  208. os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
  209. LITTLE_TESTSETS_URL, self.workspace)
  210. dataset_path = os.path.join(dataset_path, 'wav', 'test')
  211. rec_result = self.run_pipeline(
  212. model_id=self.am_tf_model_id, audio_in=dataset_path)
  213. self.check_result('test_run_with_wav_dataset_tf', rec_result)
  214. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  215. def test_run_with_ark_dataset(self):
  216. '''run with datasets, and audio format is kaldi_ark
  217. datasets directory:
  218. <dataset_path>
  219. test # testsets
  220. data.ark
  221. data.scp
  222. data.text
  223. dev # devsets
  224. data.ark
  225. data.scp
  226. data.text
  227. train # trainsets
  228. data.ark
  229. data.scp
  230. data.text
  231. '''
  232. logger.info('Run ASR test with ark dataset (pytorch)...')
  233. logger.info('Downloading ark testsets file ...')
  234. dataset_path = download_and_untar(
  235. os.path.join(self.workspace, AISHELL1_TESTSETS_FILE),
  236. AISHELL1_TESTSETS_URL, self.workspace)
  237. dataset_path = os.path.join(dataset_path, 'test')
  238. rec_result = self.run_pipeline(
  239. model_id=self.am_pytorch_model_id, audio_in=dataset_path)
  240. self.check_result('test_run_with_ark_dataset', rec_result)
  241. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  242. def test_run_with_tfrecord_dataset(self):
  243. '''run with datasets, and audio format is tfrecord
  244. datasets directory:
  245. <dataset_path>
  246. test # testsets
  247. data.records
  248. data.idx
  249. data.text
  250. '''
  251. logger.info('Run ASR test with tfrecord dataset (tensorflow)...')
  252. logger.info('Downloading tfrecord testsets file ...')
  253. dataset_path = download_and_untar(
  254. os.path.join(self.workspace, TFRECORD_TESTSETS_FILE),
  255. TFRECORD_TESTSETS_URL, self.workspace)
  256. dataset_path = os.path.join(dataset_path, 'test')
  257. rec_result = self.run_pipeline(
  258. model_id=self.am_tf_model_id, audio_in=dataset_path)
  259. self.check_result('test_run_with_tfrecord_dataset', rec_result)
if __name__ == '__main__':
    # Allow running this test module directly from the command line.
    unittest.main()