You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_key_word_spotting.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import unittest
  5. from typing import Any, Dict, List, Union
  6. import numpy as np
  7. import soundfile
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines import pipeline
  10. from modelscope.utils.constant import ColorCodes, Tasks
  11. from modelscope.utils.demo_utils import DemoCompatibilityCheck
  12. from modelscope.utils.logger import get_logger
  13. from modelscope.utils.test_utils import download_and_untar, test_level
  # Module-level logger shared by every test helper in this file.
  logger = get_logger()
  # Local single-utterance fixtures (relative paths).
  POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
  BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'

  # Remote single-utterance fixture, passed to the pipeline as a URL.
  URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/20200707_xiaoyun.wav'

  # Archives of positive / negative test sets: local file name + download URL.
  POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
  POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz'
  NEG_TESTSETS_FILE = 'neg_testsets.tar.gz'
  NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz'
  class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
      """Pipeline tests for the keyword-spotting task.

      Each test feeds one kind of input (wav path, raw PCM bytes, URL, or
      test-set directories) to the keyword-spotting pipeline and validates
      the returned dict against the matching entry of ``action_info``.
      """

      # Per-test expectations, keyed by test-method name:
      #   checking_item  - path of keys/indices to walk inside the result
      #   checking_value - optional value the walked-to item must equal
      #   example        - sample of a correct result, logged on failure
      action_info = {
          'test_run_with_wav': {
              'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
              'checking_value': '小云小云',
              'example': {
                  'wav_count': 1,
                  'kws_type': 'wav',
                  'kws_list': [{
                      'keyword': '小云小云',
                      'offset': 5.76,
                      'length': 9.132938,
                      'confidence': 0.990368
                  }]
              }
          },
          'test_run_with_pcm': {
              'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
              'checking_value': '小云小云',
              'example': {
                  'wav_count': 1,
                  'kws_type': 'pcm',
                  'kws_list': [{
                      'keyword': '小云小云',
                      'offset': 5.76,
                      'length': 9.132938,
                      'confidence': 0.990368
                  }]
              }
          },
          'test_run_with_wav_by_customized_keywords': {
              'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
              'checking_value': '播放音乐',
              'example': {
                  'wav_count': 1,
                  'kws_type': 'wav',
                  'kws_list': [{
                      'keyword': '播放音乐',
                      'offset': 0.87,
                      'length': 2.158313,
                      'confidence': 0.646237
                  }]
              }
          },
          'test_run_with_url': {
              'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
              'checking_value': '小云小云',
              'example': {
                  'wav_count': 1,
                  'kws_type': 'pcm',
                  'kws_list': [{
                      'keyword': '小云小云',
                      'offset': 0.69,
                      'length': 1.67,
                      'confidence': 0.996023
                  }]
              }
          },
          'test_run_with_pos_testsets': {
              'checking_item': ['recall'],
              'example': {
                  'wav_count': 450,
                  'kws_type': 'pos_testsets',
                  'wav_time': 3013.75925,
                  'keywords': ['小云小云'],
                  'recall': 0.953333,
                  'detected_count': 429,
                  'rejected_count': 21,
                  'rejected': ['yyy.wav', 'zzz.wav']
              }
          },
          'test_run_with_neg_testsets': {
              'checking_item': ['fa_rate'],
              'example': {
                  'wav_count': 751,
                  'kws_type': 'neg_testsets',
                  'wav_time': 3572.180813,
                  'keywords': ['小云小云'],
                  'fa_rate': 0.001332,
                  'fa_per_hour': 1.007788,
                  'detected_count': 1,
                  'rejected_count': 750,
                  'detected': [{
                      '6.wav': {
                          'confidence': '0.321170',
                          'keyword': '小云小云'
                      }
                  }]
              }
          },
          'test_run_with_roc': {
              'checking_item': ['keywords', 0],
              'checking_value': '小云小云',
              'example': {
                  'kws_type': 'roc',
                  'keywords': ['小云小云'],
                  '小云小云': [{
                      'threshold': 0.0,
                      'recall': 0.953333,
                      'fa_per_hour': 1.007788
                  }, {
                      'threshold': 0.001,
                      'recall': 0.953333,
                      'fa_per_hour': 1.007788
                  }, {
                      'threshold': 0.999,
                      'recall': 0.004444,
                      'fa_per_hour': 0.0
                  }]
              }
          }
      }

      def setUp(self) -> None:
          """Select the model under test and create the scratch workspace."""
          self.model_id = 'damo/speech_charctc_kws_phone-xiaoyun'
          self.workspace = os.path.join(os.getcwd(), '.tmp')
          # exist_ok avoids the check-then-create race of exists() + mkdir().
          os.makedirs(self.workspace, exist_ok=True)

      def tearDown(self) -> None:
          """Remove the scratch workspace (.tmp), ignoring removal errors."""
          shutil.rmtree(self.workspace, ignore_errors=True)

      def run_pipeline(
              self,
              model_id: str,
              audio_in: Union[List[str], str, bytes],
              keywords: Union[List[str], str, None] = None) -> Dict[str, Any]:
          """Build a keyword-spotting pipeline and run it on ``audio_in``.

          Args:
              model_id: Model identifier handed to ``pipeline``.
              audio_in: Input forwarded unchanged to the pipeline. The tests
                  in this class pass a wav path, a URL, raw PCM bytes, or a
                  two-element list of test-set directories.
              keywords: Optional keyword override forwarded unchanged to the
                  pipeline; ``None`` leaves the choice to the pipeline.

          Returns:
              Whatever the pipeline call returns (treated as a dict here).
          """
          kws_pipeline = pipeline(task=Tasks.keyword_spotting, model=model_id)
          return kws_pipeline(audio_in=audio_in, keywords=keywords)

      def log_error(self, functions: str, result: Dict[str, Any]) -> None:
          """Log a failure for test ``functions`` and raise.

          Logs the expected example from ``action_info`` alongside the actual
          ``result`` so the mismatch can be diagnosed from the log alone.

          Raises:
              ValueError: Always.
          """
          logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                       + ColorCodes.END)
          logger.error(ColorCodes.MAGENTA + functions
                       + ' correct result example: ' + ColorCodes.YELLOW
                       + str(self.action_info[functions]['example'])
                       + ColorCodes.END)
          logger.error(ColorCodes.MAGENTA + functions + ' actual result: '
                       + ColorCodes.YELLOW + str(result) + ColorCodes.END)
          raise ValueError('kws result is mismatched')

      def check_result(self, functions: str, result: Dict[str, Any]) -> None:
          """Validate ``result`` against ``action_info[functions]``.

          Walks ``checking_item`` into ``result``; the walked-to item must
          exist, must not be ``None``/``'None'``, and must equal
          ``checking_value`` when one is configured. On success the result
          is logged (one line per threshold entry for the ROC test).

          Raises:
              ValueError: Via ``log_error`` when any check fails.
          """
          expectation = self.action_info[functions]

          result_item = result
          for check_item in expectation['checking_item']:
              try:
                  result_item = result_item[check_item]
              except (KeyError, IndexError, TypeError):
                  # A missing key or an empty list (e.g. nothing detected) is
                  # a mismatch; report it like one instead of leaking a bare
                  # lookup error without the expected example.
                  self.log_error(functions, result)
              if result_item is None or result_item == 'None':
                  self.log_error(functions, result)

          if ('checking_value' in expectation
                  and result_item != expectation['checking_value']):
              self.log_error(functions, result)

          logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                      + ColorCodes.END)

          if functions == 'test_run_with_roc':
              # The ROC result stores its per-threshold entries under the
              # keyword itself (see the 'test_run_with_roc' example above).
              find_keyword = result['keywords'][0]
              for item in result[find_keyword]:
                  logger.info(ColorCodes.YELLOW
                              + ' threshold:' + str(item['threshold'])
                              + ' recall:' + str(item['recall'])
                              + ' fa_per_hour:' + str(item['fa_per_hour'])
                              + ColorCodes.END)
          else:
              logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)

      def wav2bytes(self, wav_file) -> bytes:
          """Read ``wav_file`` and return its samples as raw int16 bytes.

          The sample rate returned by ``soundfile.read`` is discarded.
          """
          # NOTE(review): assumes soundfile.read yields floating-point
          # samples (its documented default) - confirm against the library.
          audio, _ = soundfile.read(wav_file)
          audio = np.asarray(audio)

          # float -> int16: scale by 2**15, then clip to the int16 range.
          dtype = np.dtype('int16')
          info = np.iinfo(dtype)
          abs_max = 2**(info.bits - 1)
          offset = info.min + abs_max  # 0 for a signed target dtype
          pcm = (audio * abs_max + offset).clip(info.min, info.max)

          # int16 (PCM_16) -> bytes
          return pcm.astype(dtype).tobytes()

      @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
      def test_run_with_wav(self):
          """Spot the default keyword in a local wav file."""
          kws_result = self.run_pipeline(
              model_id=self.model_id, audio_in=POS_WAV_FILE)
          self.check_result('test_run_with_wav', kws_result)

      @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
      def test_run_with_pcm(self):
          """Spot the default keyword in raw PCM bytes."""
          audio = self.wav2bytes(os.path.join(os.getcwd(), POS_WAV_FILE))
          kws_result = self.run_pipeline(model_id=self.model_id, audio_in=audio)
          self.check_result('test_run_with_pcm', kws_result)

      @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
      def test_run_with_wav_by_customized_keywords(self):
          """Spot a caller-supplied keyword in a local wav file."""
          keywords = '播放音乐'
          kws_result = self.run_pipeline(
              model_id=self.model_id,
              audio_in=BOFANGYINYUE_WAV_FILE,
              keywords=keywords)
          self.check_result('test_run_with_wav_by_customized_keywords',
                            kws_result)

      @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
      def test_run_with_url(self):
          """Spot the default keyword in audio referenced by URL."""
          kws_result = self.run_pipeline(
              model_id=self.model_id, audio_in=URL_FILE)
          self.check_result('test_run_with_url', kws_result)

      @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
      def test_run_with_pos_testsets(self):
          """Run on the downloaded positive test set; 'recall' must be present."""
          wav_file_path = download_and_untar(
              os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
              self.workspace)
          # First slot carries the positive set; the second slot is unused.
          audio_list = [wav_file_path, None]
          kws_result = self.run_pipeline(
              model_id=self.model_id, audio_in=audio_list)
          self.check_result('test_run_with_pos_testsets', kws_result)

      @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
      def test_run_with_neg_testsets(self):
          """Run on the downloaded negative test set; 'fa_rate' must be present."""
          wav_file_path = download_and_untar(
              os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
              self.workspace)
          # Second slot carries the negative set; the first slot is unused.
          audio_list = [None, wav_file_path]
          kws_result = self.run_pipeline(
              model_id=self.model_id, audio_in=audio_list)
          self.check_result('test_run_with_neg_testsets', kws_result)

      @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
      def test_run_with_roc(self):
          """Run on both test sets at once and check the ROC-style result."""
          pos_file_path = download_and_untar(
              os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
              self.workspace)
          neg_file_path = download_and_untar(
              os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
              self.workspace)
          audio_list = [pos_file_path, neg_file_path]
          kws_result = self.run_pipeline(
              model_id=self.model_id, audio_in=audio_list)
          self.check_result('test_run_with_roc', kws_result)

      @unittest.skip('demo compatibility test is only enabled on a needed-basis')
      def test_demo_compatibility(self):
          """Delegate to the DemoCompatibilityCheck mixin (skipped by default)."""
          self.compatibility_check()
  # Run this module's tests when it is executed directly as a script.
  if __name__ == '__main__':
      unittest.main()