You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_automatic_speech_recognition.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import unittest
  5. from typing import Any, Dict, Union
  6. import numpy as np
  7. import soundfile
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines import pipeline
  10. from modelscope.utils.constant import ColorCodes, Tasks
  11. from modelscope.utils.demo_utils import DemoCompatibilityCheck
  12. from modelscope.utils.logger import get_logger
  13. from modelscope.utils.test_utils import download_and_untar, test_level
  14. logger = get_logger()
  15. WAV_FILE = 'data/test/audios/asr_example.wav'
  16. URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'
  17. LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
  18. LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz'
  19. TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz'
  20. TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz'
  21. class AutomaticSpeechRecognitionTest(unittest.TestCase,
  22. DemoCompatibilityCheck):
  23. action_info = {
  24. 'test_run_with_wav_pytorch': {
  25. 'checking_item': OutputKeys.TEXT,
  26. 'example': 'wav_example'
  27. },
  28. 'test_run_with_pcm_pytorch': {
  29. 'checking_item': OutputKeys.TEXT,
  30. 'example': 'wav_example'
  31. },
  32. 'test_run_with_wav_tf': {
  33. 'checking_item': OutputKeys.TEXT,
  34. 'example': 'wav_example'
  35. },
  36. 'test_run_with_pcm_tf': {
  37. 'checking_item': OutputKeys.TEXT,
  38. 'example': 'wav_example'
  39. },
  40. 'test_run_with_url_pytorch': {
  41. 'checking_item': OutputKeys.TEXT,
  42. 'example': 'wav_example'
  43. },
  44. 'test_run_with_url_tf': {
  45. 'checking_item': OutputKeys.TEXT,
  46. 'example': 'wav_example'
  47. },
  48. 'test_run_with_wav_dataset_pytorch': {
  49. 'checking_item': OutputKeys.TEXT,
  50. 'example': 'dataset_example'
  51. },
  52. 'test_run_with_wav_dataset_tf': {
  53. 'checking_item': OutputKeys.TEXT,
  54. 'example': 'dataset_example'
  55. },
  56. 'dataset_example': {
  57. 'Wrd': 49532, # the number of words
  58. 'Snt': 5000, # the number of sentences
  59. 'Corr': 47276, # the number of correct words
  60. 'Ins': 49, # the number of insert words
  61. 'Del': 152, # the number of delete words
  62. 'Sub': 2207, # the number of substitution words
  63. 'wrong_words': 2408, # the number of wrong words
  64. 'wrong_sentences': 1598, # the number of wrong sentences
  65. 'Err': 4.86, # WER/CER
  66. 'S.Err': 31.96 # SER
  67. },
  68. 'wav_example': {
  69. 'text': '每一天都要快乐喔'
  70. }
  71. }
  72. all_models_info = [
  73. {
  74. 'model_group': 'damo',
  75. 'model_id':
  76. 'speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1',
  77. 'wav_path': 'data/test/audios/asr_example.wav'
  78. },
  79. {
  80. 'model_group': 'damo',
  81. 'model_id': 'speech_paraformer_asr_nat-aishell1-pytorch',
  82. 'wav_path': 'data/test/audios/asr_example.wav'
  83. },
  84. {
  85. 'model_group': 'damo',
  86. 'model_id':
  87. 'speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1',
  88. 'wav_path': 'data/test/audios/asr_example.wav'
  89. },
  90. {
  91. 'model_group': 'damo',
  92. 'model_id':
  93. 'speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1',
  94. 'wav_path': 'data/test/audios/asr_example_8K.wav'
  95. },
  96. {
  97. 'model_group': 'damo',
  98. 'model_id':
  99. 'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online',
  100. 'wav_path': 'data/test/audios/asr_example.wav'
  101. },
  102. {
  103. 'model_group': 'damo',
  104. 'model_id':
  105. 'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline',
  106. 'wav_path': 'data/test/audios/asr_example.wav'
  107. },
  108. {
  109. 'model_group': 'damo',
  110. 'model_id':
  111. 'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online',
  112. 'wav_path': 'data/test/audios/asr_example_8K.wav'
  113. },
  114. {
  115. 'model_group': 'damo',
  116. 'model_id':
  117. 'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline',
  118. 'wav_path': 'data/test/audios/asr_example_8K.wav'
  119. },
  120. {
  121. 'model_group': 'damo',
  122. 'model_id':
  123. 'speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline',
  124. 'wav_path': 'data/test/audios/asr_example.wav'
  125. },
  126. {
  127. 'model_group': 'damo',
  128. 'model_id':
  129. 'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online',
  130. 'wav_path': 'data/test/audios/asr_example_cn_en.wav'
  131. },
  132. {
  133. 'model_group': 'damo',
  134. 'model_id':
  135. 'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline',
  136. 'wav_path': 'data/test/audios/asr_example_cn_en.wav'
  137. },
  138. {
  139. 'model_group': 'damo',
  140. 'model_id':
  141. 'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online',
  142. 'wav_path': 'data/test/audios/asr_example_cn_dialect.wav'
  143. },
  144. {
  145. 'model_group': 'damo',
  146. 'model_id':
  147. 'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline',
  148. 'wav_path': 'data/test/audios/asr_example_cn_dialect.wav'
  149. },
  150. {
  151. 'model_group': 'damo',
  152. 'model_id':
  153. 'speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online',
  154. 'wav_path': 'data/test/audios/asr_example.wav'
  155. },
  156. {
  157. 'model_group': 'damo',
  158. 'model_id':
  159. 'speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online',
  160. 'wav_path': 'data/test/audios/asr_example_8K.wav'
  161. },
  162. {
  163. 'model_group': 'damo',
  164. 'model_id':
  165. 'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline',
  166. 'wav_path': 'data/test/audios/asr_example_en.wav'
  167. },
  168. {
  169. 'model_group': 'damo',
  170. 'model_id':
  171. 'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online',
  172. 'wav_path': 'data/test/audios/asr_example_en.wav'
  173. },
  174. {
  175. 'model_group': 'damo',
  176. 'model_id':
  177. 'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline',
  178. 'wav_path': 'data/test/audios/asr_example_ru.wav'
  179. },
  180. {
  181. 'model_group': 'damo',
  182. 'model_id':
  183. 'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online',
  184. 'wav_path': 'data/test/audios/asr_example_ru.wav'
  185. },
  186. {
  187. 'model_group': 'damo',
  188. 'model_id':
  189. 'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline',
  190. 'wav_path': 'data/test/audios/asr_example_es.wav'
  191. },
  192. {
  193. 'model_group': 'damo',
  194. 'model_id':
  195. 'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online',
  196. 'wav_path': 'data/test/audios/asr_example_es.wav'
  197. },
  198. {
  199. 'model_group': 'damo',
  200. 'model_id':
  201. 'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline',
  202. 'wav_path': 'data/test/audios/asr_example_ko.wav'
  203. },
  204. {
  205. 'model_group': 'damo',
  206. 'model_id':
  207. 'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online',
  208. 'wav_path': 'data/test/audios/asr_example_ko.wav'
  209. },
  210. {
  211. 'model_group': 'damo',
  212. 'model_id':
  213. 'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online',
  214. 'wav_path': 'data/test/audios/asr_example_ja.wav'
  215. },
  216. {
  217. 'model_group': 'damo',
  218. 'model_id':
  219. 'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline',
  220. 'wav_path': 'data/test/audios/asr_example_ja.wav'
  221. },
  222. {
  223. 'model_group': 'damo',
  224. 'model_id':
  225. 'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online',
  226. 'wav_path': 'data/test/audios/asr_example_id.wav'
  227. },
  228. {
  229. 'model_group': 'damo',
  230. 'model_id':
  231. 'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline',
  232. 'wav_path': 'data/test/audios/asr_example_id.wav'
  233. },
  234. ]
  235. def setUp(self) -> None:
  236. self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch'
  237. self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1'
  238. # this temporary workspace dir will store waveform files
  239. self.workspace = os.path.join(os.getcwd(), '.tmp')
  240. self.task = Tasks.auto_speech_recognition
  241. if not os.path.exists(self.workspace):
  242. os.mkdir(self.workspace)
  243. def tearDown(self) -> None:
  244. # remove workspace dir (.tmp)
  245. shutil.rmtree(self.workspace, ignore_errors=True)
  246. def run_pipeline(self,
  247. model_id: str,
  248. audio_in: Union[str, bytes],
  249. sr: int = None) -> Dict[str, Any]:
  250. inference_16k_pipline = pipeline(
  251. task=Tasks.auto_speech_recognition, model=model_id)
  252. rec_result = inference_16k_pipline(audio_in, audio_fs=sr)
  253. return rec_result
  254. def log_error(self, functions: str, result: Dict[str, Any]) -> None:
  255. logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
  256. + ColorCodes.END)
  257. logger.error(
  258. ColorCodes.MAGENTA + functions + ' correct result example:'
  259. + ColorCodes.YELLOW
  260. + str(self.action_info[self.action_info[functions]['example']])
  261. + ColorCodes.END)
  262. raise ValueError('asr result is mismatched')
  263. def check_result(self, functions: str, result: Dict[str, Any]) -> None:
  264. if result.__contains__(self.action_info[functions]['checking_item']):
  265. logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
  266. + ColorCodes.END)
  267. logger.info(
  268. ColorCodes.YELLOW
  269. + str(result[self.action_info[functions]['checking_item']])
  270. + ColorCodes.END)
  271. else:
  272. self.log_error(functions, result)
  273. def wav2bytes(self, wav_file):
  274. audio, fs = soundfile.read(wav_file)
  275. # float32 -> int16
  276. audio = np.asarray(audio)
  277. dtype = np.dtype('int16')
  278. i = np.iinfo(dtype)
  279. abs_max = 2**(i.bits - 1)
  280. offset = i.min + abs_max
  281. audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)
  282. # int16(PCM_16) -> byte
  283. audio = audio.tobytes()
  284. return audio, fs
  285. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  286. def test_run_with_pcm(self):
  287. """run with wav data
  288. """
  289. logger.info('Run ASR test with wav data (tensorflow)...')
  290. audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
  291. rec_result = self.run_pipeline(
  292. model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
  293. self.check_result('test_run_with_pcm_tf', rec_result)
  294. logger.info('Run ASR test with wav data (pytorch)...')
  295. rec_result = self.run_pipeline(
  296. model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
  297. self.check_result('test_run_with_pcm_pytorch', rec_result)
  298. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  299. def test_run_with_wav(self):
  300. """run with single waveform file
  301. """
  302. logger.info('Run ASR test with waveform file (tensorflow)...')
  303. wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
  304. rec_result = self.run_pipeline(
  305. model_id=self.am_tf_model_id, audio_in=wav_file_path)
  306. self.check_result('test_run_with_wav_tf', rec_result)
  307. logger.info('Run ASR test with waveform file (pytorch)...')
  308. rec_result = self.run_pipeline(
  309. model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
  310. self.check_result('test_run_with_wav_pytorch', rec_result)
  311. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  312. def test_run_with_url(self):
  313. """run with single url file
  314. """
  315. logger.info('Run ASR test with url file (tensorflow)...')
  316. rec_result = self.run_pipeline(
  317. model_id=self.am_tf_model_id, audio_in=URL_FILE)
  318. self.check_result('test_run_with_url_tf', rec_result)
  319. logger.info('Run ASR test with url file (pytorch)...')
  320. rec_result = self.run_pipeline(
  321. model_id=self.am_pytorch_model_id, audio_in=URL_FILE)
  322. self.check_result('test_run_with_url_pytorch', rec_result)
  323. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  324. def test_run_with_wav_dataset_pytorch(self):
  325. """run with datasets, and audio format is waveform
  326. datasets directory:
  327. <dataset_path>
  328. wav
  329. test # testsets
  330. xx.wav
  331. ...
  332. dev # devsets
  333. yy.wav
  334. ...
  335. train # trainsets
  336. zz.wav
  337. ...
  338. transcript
  339. data.text # hypothesis text
  340. """
  341. logger.info('Downloading waveform testsets file ...')
  342. dataset_path = download_and_untar(
  343. os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
  344. LITTLE_TESTSETS_URL, self.workspace)
  345. dataset_path = os.path.join(dataset_path, 'wav', 'test')
  346. logger.info('Run ASR test with waveform dataset (tensorflow)...')
  347. rec_result = self.run_pipeline(
  348. model_id=self.am_tf_model_id, audio_in=dataset_path)
  349. self.check_result('test_run_with_wav_dataset_tf', rec_result)
  350. logger.info('Run ASR test with waveform dataset (pytorch)...')
  351. rec_result = self.run_pipeline(
  352. model_id=self.am_pytorch_model_id, audio_in=dataset_path)
  353. self.check_result('test_run_with_wav_dataset_pytorch', rec_result)
  354. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  355. def test_run_with_all_models(self):
  356. """run with all models
  357. """
  358. logger.info('Run ASR test with all models')
  359. for item in self.all_models_info:
  360. model_id = item['model_group'] + '/' + item['model_id']
  361. wav_path = item['wav_path']
  362. rec_result = self.run_pipeline(
  363. model_id=model_id, audio_in=wav_path)
  364. if rec_result.__contains__(OutputKeys.TEXT):
  365. logger.info(ColorCodes.MAGENTA + str(item['model_id']) + ' '
  366. + ColorCodes.YELLOW
  367. + str(rec_result[OutputKeys.TEXT])
  368. + ColorCodes.END)
  369. else:
  370. logger.info(ColorCodes.MAGENTA + str(rec_result)
  371. + ColorCodes.END)
  372. @unittest.skip('demo compatibility test is only enabled on a needed-basis')
  373. def test_demo_compatibility(self):
  374. self.compatibility_check()
  375. if __name__ == '__main__':
  376. unittest.main()