|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import shutil
- import unittest
- from typing import Any, Dict, List, Union
-
- import numpy as np
- import soundfile
-
- from modelscope.outputs import OutputKeys
- from modelscope.pipelines import pipeline
- from modelscope.utils.constant import ColorCodes, Tasks
- from modelscope.utils.demo_utils import DemoCompatibilityCheck
- from modelscope.utils.logger import get_logger
- from modelscope.utils.test_utils import download_and_untar, test_level
-
# Module-level logger used by the test helpers to report check results.
logger = get_logger()
-
# Single-recording inputs: two local wav files and one remote wav URL.
POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/20200707_xiaoyun.wav'

# Archive file name and download URL of the "pos" test set.
POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz'

# Archive file name and download URL of the "neg" test set.
NEG_TESTSETS_FILE = 'neg_testsets.tar.gz'
NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz'
-
-
class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
    """Tests for the keyword-spotting pipeline.

    Each test feeds one kind of input (wav path, raw PCM bytes, URL, or
    downloaded test-set directories) to the pipeline and validates the
    returned dict against the matching entry of ``action_info``.
    """

    # Expectations keyed by test method name:
    #   'checking_item'  - sequence of keys/indices walked through the result
    #   'checking_value' - optional; exact value expected at the end of the walk
    #   'example'        - sample of a correct result, logged when a check fails
    action_info = {
        'test_run_with_wav': {
            'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
            'checking_value': '小云小云',
            'example': {
                'wav_count': 1,
                'kws_type': 'wav',
                'kws_list': [{
                    'keyword': '小云小云',
                    'offset': 5.76,
                    'length': 9.132938,
                    'confidence': 0.990368,
                }],
            },
        },
        'test_run_with_pcm': {
            'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
            'checking_value': '小云小云',
            'example': {
                'wav_count': 1,
                'kws_type': 'pcm',
                'kws_list': [{
                    'keyword': '小云小云',
                    'offset': 5.76,
                    'length': 9.132938,
                    'confidence': 0.990368,
                }],
            },
        },
        'test_run_with_wav_by_customized_keywords': {
            'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
            'checking_value': '播放音乐',
            'example': {
                'wav_count': 1,
                'kws_type': 'wav',
                'kws_list': [{
                    'keyword': '播放音乐',
                    'offset': 0.87,
                    'length': 2.158313,
                    'confidence': 0.646237,
                }],
            },
        },
        'test_run_with_url': {
            'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
            'checking_value': '小云小云',
            'example': {
                'wav_count': 1,
                'kws_type': 'pcm',
                'kws_list': [{
                    'keyword': '小云小云',
                    'offset': 0.69,
                    'length': 1.67,
                    'confidence': 0.996023,
                }],
            },
        },
        'test_run_with_pos_testsets': {
            'checking_item': ['recall'],
            'example': {
                'wav_count': 450,
                'kws_type': 'pos_testsets',
                'wav_time': 3013.75925,
                'keywords': ['小云小云'],
                'recall': 0.953333,
                'detected_count': 429,
                'rejected_count': 21,
                'rejected': ['yyy.wav', 'zzz.wav'],
            },
        },
        'test_run_with_neg_testsets': {
            'checking_item': ['fa_rate'],
            'example': {
                'wav_count': 751,
                'kws_type': 'neg_testsets',
                'wav_time': 3572.180813,
                'keywords': ['小云小云'],
                'fa_rate': 0.001332,
                'fa_per_hour': 1.007788,
                'detected_count': 1,
                'rejected_count': 750,
                'detected': [{
                    '6.wav': {
                        'confidence': '0.321170',
                        'keyword': '小云小云',
                    },
                }],
            },
        },
        'test_run_with_roc': {
            'checking_item': ['keywords', 0],
            'checking_value': '小云小云',
            'example': {
                'kws_type': 'roc',
                'keywords': ['小云小云'],
                '小云小云': [{
                    'threshold': 0.0,
                    'recall': 0.953333,
                    'fa_per_hour': 1.007788,
                }, {
                    'threshold': 0.001,
                    'recall': 0.953333,
                    'fa_per_hour': 1.007788,
                }, {
                    'threshold': 0.999,
                    'recall': 0.004444,
                    'fa_per_hour': 0.0,
                }],
            },
        },
    }

    def setUp(self) -> None:
        """Select the model under test and create the scratch workspace."""
        self.model_id = 'damo/speech_charctc_kws_phone-xiaoyun'
        self.workspace = os.path.join(os.getcwd(), '.tmp')
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(self.workspace, exist_ok=True)

    def tearDown(self) -> None:
        """Remove the scratch workspace (.tmp), ignoring removal errors."""
        shutil.rmtree(self.workspace, ignore_errors=True)

    def run_pipeline(
            self,
            model_id: str,
            audio_in: Union[List[str], str, bytes],
            keywords: Union[List[str], str, None] = None) -> Dict[str, Any]:
        """Build a keyword-spotting pipeline for ``model_id`` and run it.

        Args:
            model_id: Model identifier handed to ``pipeline``.
            audio_in: Input forwarded unchanged to the pipeline; the tests
                pass a wav path, a URL, raw PCM bytes, or a two-element list.
            keywords: Optional keyword override forwarded unchanged to the
                pipeline; ``None`` when the caller supplies none.

        Returns:
            The result returned by the pipeline call.
        """
        kws_pipeline = pipeline(task=Tasks.keyword_spotting, model=model_id)
        return kws_pipeline(audio_in=audio_in, keywords=keywords)

    def log_error(self, functions: str, result: Dict[str, Any]) -> None:
        """Log a failure for test ``functions`` and abort it.

        Args:
            functions: Test method name; also the key into ``action_info``.
            result: The mismatching pipeline result (not used here, kept so
                existing callers keep working).

        Raises:
            ValueError: Always, after the failure and the expected example
                have been logged.
        """
        example = self.action_info[functions]['example']
        logger.error(
            f'{ColorCodes.MAGENTA}{functions}: FAILED.{ColorCodes.END}')
        logger.error(
            f'{ColorCodes.MAGENTA}{functions} correct result example: '
            f'{ColorCodes.YELLOW}{example}{ColorCodes.END}')

        raise ValueError('kws result is mismatched')

    def check_result(self, functions: str, result: Dict[str, Any]) -> None:
        """Validate ``result`` against the ``action_info`` entry of a test.

        Walks ``result`` along the entry's 'checking_item' path, rejecting
        a missing step and any ``None`` / ``'None'`` value on the way, then
        compares the final value with 'checking_value' when the entry has
        one. On success the result is logged.

        Args:
            functions: Test method name; also the key into ``action_info``.
            result: Dict returned by the pipeline.

        Raises:
            ValueError: Raised through ``log_error`` when any check fails.
        """
        expectation = self.action_info[functions]

        result_item = result
        for check_item in expectation['checking_item']:
            try:
                result_item = result_item[check_item]
            except (KeyError, IndexError, TypeError):
                # A missing key or empty list is a mismatch too: report it
                # with the expected example instead of a bare lookup error.
                self.log_error(functions, result)
            if result_item is None or result_item == 'None':
                self.log_error(functions, result)

        if 'checking_value' in expectation:
            if result_item != expectation['checking_value']:
                self.log_error(functions, result)

        logger.info(
            f'{ColorCodes.MAGENTA}{functions}: SUCCESS.{ColorCodes.END}')
        if functions == 'test_run_with_roc':
            # The roc result stores its per-threshold entries under the
            # keyword itself; log one line per entry.
            find_keyword = result['keywords'][0]
            for item in result[find_keyword]:
                threshold: float = item['threshold']
                recall: float = item['recall']
                fa_per_hour: float = item['fa_per_hour']
                logger.info(
                    f'{ColorCodes.YELLOW} threshold:{threshold}'
                    f' recall:{recall} fa_per_hour:{fa_per_hour}'
                    f'{ColorCodes.END}')
        else:
            logger.info(f'{ColorCodes.YELLOW}{result}{ColorCodes.END}')

    def wav2bytes(self, wav_file) -> bytes:
        """Read ``wav_file`` and return its samples as int16 PCM bytes.

        Args:
            wav_file: Path handed to ``soundfile.read``.

        Returns:
            The samples scaled to the int16 range, clipped, and serialized
            with ``ndarray.tobytes``.
        """
        # The sample rate is returned as well but is not needed here.
        audio, _ = soundfile.read(wav_file)

        # float -> int16; presumably the samples are floats in [-1.0, 1.0]
        # (soundfile.read is called without a dtype) - TODO confirm.
        audio = np.asarray(audio)
        dtype = np.dtype('int16')
        i = np.iinfo(dtype)
        abs_max = 2**(i.bits - 1)
        # For int16, i.min == -abs_max, so this offset evaluates to 0.
        offset = i.min + abs_max
        audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)

        # int16 (PCM_16) -> bytes
        return audio.tobytes()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        """Run the pipeline on a local wav file path."""
        kws_result = self.run_pipeline(
            model_id=self.model_id, audio_in=POS_WAV_FILE)
        self.check_result('test_run_with_wav', kws_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_pcm(self):
        """Run the pipeline on raw PCM bytes converted from a wav file."""
        audio = self.wav2bytes(os.path.join(os.getcwd(), POS_WAV_FILE))

        kws_result = self.run_pipeline(model_id=self.model_id, audio_in=audio)
        self.check_result('test_run_with_pcm', kws_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav_by_customized_keywords(self):
        """Run the pipeline on a wav file with a caller-supplied keyword."""
        keywords = '播放音乐'

        kws_result = self.run_pipeline(
            model_id=self.model_id,
            audio_in=BOFANGYINYUE_WAV_FILE,
            keywords=keywords)
        self.check_result('test_run_with_wav_by_customized_keywords',
                          kws_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_url(self):
        """Run the pipeline on a remote wav file given by URL."""
        kws_result = self.run_pipeline(
            model_id=self.model_id, audio_in=URL_FILE)
        self.check_result('test_run_with_url', kws_result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_pos_testsets(self):
        """Run the pipeline on the downloaded "pos" test set only."""
        wav_file_path = download_and_untar(
            os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
            self.workspace)
        # Two-slot input: "pos" path first, "neg" path second (None = unused).
        audio_list = [wav_file_path, None]

        kws_result = self.run_pipeline(
            model_id=self.model_id, audio_in=audio_list)
        self.check_result('test_run_with_pos_testsets', kws_result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_neg_testsets(self):
        """Run the pipeline on the downloaded "neg" test set only."""
        wav_file_path = download_and_untar(
            os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
            self.workspace)
        # Two-slot input: "pos" path first (None = unused), "neg" path second.
        audio_list = [None, wav_file_path]

        kws_result = self.run_pipeline(
            model_id=self.model_id, audio_in=audio_list)
        self.check_result('test_run_with_neg_testsets', kws_result)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_roc(self):
        """Run the pipeline on both downloaded test sets at once."""
        pos_file_path = download_and_untar(
            os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
            self.workspace)
        neg_file_path = download_and_untar(
            os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
            self.workspace)
        audio_list = [pos_file_path, neg_file_path]

        kws_result = self.run_pipeline(
            model_id=self.model_id, audio_in=audio_list)
        self.check_result('test_run_with_roc', kws_result)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        """Delegate to the ``compatibility_check`` helper (skipped by default)."""
        self.compatibility_check()
-
-
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|