diff --git a/modelscope/models/audio/kws/generic_key_word_spotting.py b/modelscope/models/audio/kws/generic_key_word_spotting.py index e9c4ebb9..861cca20 100644 --- a/modelscope/models/audio/kws/generic_key_word_spotting.py +++ b/modelscope/models/audio/kws/generic_key_word_spotting.py @@ -19,7 +19,7 @@ class GenericKeyWordSpotting(Model): Args: model_dir (str): the model path. """ - super().__init__(model_dir) + super().__init__(model_dir, *args, **kwargs) self.model_cfg = { 'model_workspace': model_dir, 'config_path': os.path.join(model_dir, 'config.yaml') diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index acc27015..a6cc4d55 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -1,5 +1,4 @@ import os -import subprocess from typing import Any, Dict, List, Union import json @@ -10,6 +9,9 @@ from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import WavToLists from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() __all__ = ['KeyWordSpottingKwsbpPipeline'] @@ -21,41 +23,24 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): """ def __init__(self, - config_file: str = None, model: Union[Model, str] = None, preprocessor: WavToLists = None, **kwargs): + """use `model` and `preprocessor` to create a kws pipeline for prediction """ - use `model` and `preprocessor` to create a kws pipeline for prediction - Args: - model: model id on modelscope hub. - """ - super().__init__( - config_file=config_file, - model=model, - preprocessor=preprocessor, - **kwargs) - - assert model is not None, 'kws model should be provided' - - self._preprocessor = preprocessor - self._keywords = None + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + def __call__(self, wav_path: Union[List[str], str], + **kwargs) -> Dict[str, Any]: if 'keywords' in kwargs.keys(): - self._keywords = kwargs['keywords'] - - def __call__(self, - kws_type: str, - wav_path: List[str], - workspace: str = None) -> Dict[str, Any]: - assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', - 'roc'], f'kws_type {kws_type} is invalid' + self.keywords = kwargs['keywords'] + else: + self.keywords = None - if self._preprocessor is None: - self._preprocessor = WavToLists(workspace=workspace) + if self.preprocessor is None: + self.preprocessor = WavToLists() - output = self._preprocessor.forward(self.model.forward(), kws_type, - wav_path) + output = self.preprocessor.forward(self.model.forward(), wav_path) output = self.forward(output) rst = self.postprocess(output) return rst @@ -64,433 +49,92 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): """Decoding """ - # will generate kws result into dump/dump.JOB.log - out = self._run_with_kwsbp(inputs) + logger.info(f"Decoding with {inputs['kws_set']} mode ...") + + # will generate kws result + out = self.run_with_kwsbp(inputs) return out def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """process the kws results - """ - pos_result_json = {} - neg_result_json = {} - - if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: - self._parse_dump_log(pos_result_json, inputs['pos_dump_path']) - if inputs['kws_set'] in ['neg_testsets', 'roc']: - self._parse_dump_log(neg_result_json, inputs['neg_dump_path']) - """ - result_json format example: - { - "wav_count": 450, - "keywords": ["小云小云"], - "wav_time": 3560.999999, 
- "detected": [ - { - "xxx.wav": { - "confidence": "0.990368", - "keyword": "小云小云" - } - }, - { - "yyy.wav": { - "confidence": "0.990368", - "keyword": "小云小云" - } - }, - ...... - ], - "detected_count": 429, - "rejected_count": 21, - "rejected": [ - "yyy.wav", - "zzz.wav", - ...... - ] - } + Args: + inputs['pos_kws_list'] or inputs['neg_kws_list']: + result_dict format example: + [{ + 'confidence': 0.9903678297996521, + 'filename': 'data/test/audios/kws_xiaoyunxiaoyun.wav', + 'keyword': '小云小云', + 'offset': 5.760000228881836, # second + 'rtf_time': 66, # millisecond + 'threshold': 0, + 'wav_time': 9.1329375 # second + }] """ - rst_dict = {'kws_set': inputs['kws_set']} - - # parsing the result of wav - if inputs['kws_set'] == 'wav': - rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ - 'pos_wav_count'] - rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) - if pos_result_json['detected_count'] == 1: - rst_dict['keywords'] = pos_result_json['keywords'] - rst_dict['detected'] = True - wav_file_name = os.path.basename(inputs['pos_wav_path']) - rst_dict['confidence'] = float(pos_result_json['detected'][0] - [wav_file_name]['confidence']) - else: - rst_dict['detected'] = False - - # parsing the result of pos_tests - elif inputs['kws_set'] == 'pos_testsets': - rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ - 'pos_wav_count'] - rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) - if pos_result_json.__contains__('keywords'): - rst_dict['keywords'] = pos_result_json['keywords'] - - rst_dict['recall'] = round( - pos_result_json['detected_count'] / rst_dict['wav_count'], 6) - - if pos_result_json.__contains__('detected_count'): - rst_dict['detected_count'] = pos_result_json['detected_count'] - if pos_result_json.__contains__('rejected_count'): - rst_dict['rejected_count'] = pos_result_json['rejected_count'] - if pos_result_json.__contains__('rejected'): - rst_dict['rejected'] = pos_result_json['rejected'] - - # parsing the result of neg_tests - elif inputs['kws_set'] == 'neg_testsets': - rst_dict['wav_count'] = neg_result_json['wav_count'] = inputs[ - 'neg_wav_count'] - rst_dict['wav_time'] = round(neg_result_json['wav_time'], 6) - if neg_result_json.__contains__('keywords'): - rst_dict['keywords'] = neg_result_json['keywords'] - - rst_dict['fa_rate'] = 0.0 - rst_dict['fa_per_hour'] = 0.0 - - if neg_result_json.__contains__('detected_count'): - rst_dict['detected_count'] = neg_result_json['detected_count'] - rst_dict['fa_rate'] = round( - neg_result_json['detected_count'] / rst_dict['wav_count'], - 6) - if neg_result_json.__contains__('wav_time'): - rst_dict['fa_per_hour'] = round( - neg_result_json['detected_count'] - / float(neg_result_json['wav_time'] / 3600), 6) - - if neg_result_json.__contains__('rejected_count'): - rst_dict['rejected_count'] = neg_result_json['rejected_count'] - - if neg_result_json.__contains__('detected'): - rst_dict['detected'] = neg_result_json['detected'] - - # parsing the result of roc - elif inputs['kws_set'] == 'roc': - threshold_start = 0.000 - threshold_step = 0.001 - threshold_end = 1.000 - - pos_keywords_list = [] - neg_keywords_list = [] - if pos_result_json.__contains__('keywords'): - pos_keywords_list = pos_result_json['keywords'] - if neg_result_json.__contains__('keywords'): - neg_keywords_list = neg_result_json['keywords'] - - keywords_list = list(set(pos_keywords_list + neg_keywords_list)) - - pos_result_json['wav_count'] = inputs['pos_wav_count'] - neg_result_json['wav_count'] = inputs['neg_wav_count'] - - if 
len(keywords_list) > 0: - rst_dict['keywords'] = keywords_list - - for index in range(len(rst_dict['keywords'])): - cur_keyword = rst_dict['keywords'][index] - output_list = self._generate_roc_list( - start=threshold_start, - step=threshold_step, - end=threshold_end, - keyword=cur_keyword, - pos_inputs=pos_result_json, - neg_inputs=neg_result_json) - - rst_dict[cur_keyword] = output_list + import kws_util.common + neg_kws_list = None + pos_kws_list = None + if 'pos_kws_list' in inputs: + pos_kws_list = inputs['pos_kws_list'] + if 'neg_kws_list' in inputs: + neg_kws_list = inputs['neg_kws_list'] + rst_dict = kws_util.common.parsing_kws_result( + kws_type=inputs['kws_set'], + pos_list=pos_kws_list, + neg_list=neg_kws_list) return rst_dict - def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - opts: str = '' + def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + cmd = { + 'sys_dir': inputs['model_workspace'], + 'cfg_file': inputs['cfg_file_path'], + 'sample_rate': inputs['sample_rate'], + 'keyword_custom': '' + } + + import kwsbp + import kws_util.common + kws_inference = kwsbp.KwsbpEngine() # setting customized keywords - keywords_json = self._set_customized_keywords() - if len(keywords_json) > 0: - keywords_json_file = os.path.join(inputs['workspace'], - 'keyword_custom.json') - with open(keywords_json_file, 'w') as f: - json.dump(keywords_json, f) - opts = '--keyword-custom ' + keywords_json_file + cmd['customized_keywords'] = kws_util.common.generate_customized_keywords( + self.keywords) if inputs['kws_set'] == 'roc': inputs['keyword_grammar_path'] = os.path.join( inputs['model_workspace'], 'keywords_roc.json') - if inputs['kws_set'] == 'wav': - dump_log_path: str = os.path.join(inputs['pos_dump_path'], - 'dump.log') - kws_cmd: str = inputs['kws_tool_path'] + \ - ' --sys-dir=' + inputs['model_workspace'] + \ - ' --cfg-file=' + inputs['cfg_file_path'] + \ - ' --sample-rate=' + inputs['sample_rate'] + \ - ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ - ' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ - ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' - os.system(kws_cmd) - - if inputs['kws_set'] in ['pos_testsets', 'roc']: - data_dir: str = os.listdir(inputs['pos_data_path']) - wav_list = [] - for i in data_dir: - suffix = os.path.splitext(os.path.basename(i))[1] - if suffix == '.list': - wav_list.append(os.path.join(inputs['pos_data_path'], i)) - - j: int = 0 - process = [] - while j < inputs['pos_num_thread']: - wav_list_path: str = inputs['pos_data_path'] + '/wave.' + str( - j) + '.list' - dump_log_path: str = inputs['pos_dump_path'] + '/dump.' 
+ str( - j) + '.log' - - kws_cmd: str = inputs['kws_tool_path'] + \ - ' --sys-dir=' + inputs['model_workspace'] + \ - ' --cfg-file=' + inputs['cfg_file_path'] + \ - ' --sample-rate=' + inputs['sample_rate'] + \ - ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ - ' --wave-scp=' + wav_list_path + \ - ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' - p = subprocess.Popen(kws_cmd, shell=True) - process.append(p) - j += 1 - - k: int = 0 - while k < len(process): - process[k].wait() - k += 1 + if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: + cmd['wave_scp'] = inputs['pos_wav_list'] + cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] + cmd['num_thread'] = inputs['pos_num_thread'] + + # run and get inference result + result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'], + cmd['keyword_grammar_path'], + str(json.dumps(cmd['wave_scp'])), + str(cmd['customized_keywords']), + cmd['sample_rate'], + cmd['num_thread']) + pos_result = json.loads(result) + inputs['pos_kws_list'] = pos_result['kws_list'] if inputs['kws_set'] in ['neg_testsets', 'roc']: - data_dir: str = os.listdir(inputs['neg_data_path']) - wav_list = [] - for i in data_dir: - suffix = os.path.splitext(os.path.basename(i))[1] - if suffix == '.list': - wav_list.append(os.path.join(inputs['neg_data_path'], i)) - - j: int = 0 - process = [] - while j < inputs['neg_num_thread']: - wav_list_path: str = inputs['neg_data_path'] + '/wave.' + str( - j) + '.list' - dump_log_path: str = inputs['neg_dump_path'] + '/dump.' + str( - j) + '.log' - - kws_cmd: str = inputs['kws_tool_path'] + \ - ' --sys-dir=' + inputs['model_workspace'] + \ - ' --cfg-file=' + inputs['cfg_file_path'] + \ - ' --sample-rate=' + inputs['sample_rate'] + \ - ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ - ' --wave-scp=' + wav_list_path + \ - ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' - p = subprocess.Popen(kws_cmd, shell=True) - process.append(p) - j += 1 - - k: int = 0 - while k < len(process): - process[k].wait() - k += 1 + cmd['wave_scp'] = inputs['neg_wav_list'] + cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] + cmd['num_thread'] = inputs['neg_num_thread'] + + # run and get inference result + result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'], + cmd['keyword_grammar_path'], + str(json.dumps(cmd['wave_scp'])), + str(cmd['customized_keywords']), + cmd['sample_rate'], + cmd['num_thread']) + neg_result = json.loads(result) + inputs['neg_kws_list'] = neg_result['kws_list'] return inputs - - def _parse_dump_log(self, result_json: Dict[str, Any], - dump_path: str) -> Dict[str, Any]: - dump_dir = os.listdir(dump_path) - for i in dump_dir: - basename = os.path.splitext(os.path.basename(i))[0] - # find dump.JOB.log - if 'dump' in basename: - with open( - os.path.join(dump_path, i), mode='r', - encoding='utf-8') as file: - while 1: - line = file.readline() - if not line: - break - else: - result_json = self._parse_result_log( - line, result_json) - - def _parse_result_log(self, line: str, - result_json: Dict[str, Any]) -> Dict[str, Any]: - # valid info - if '[rejected]' in line or '[detected]' in line: - detected_count = 0 - rejected_count = 0 - - if result_json.__contains__('detected_count'): - detected_count = result_json['detected_count'] - if result_json.__contains__('rejected_count'): - rejected_count = result_json['rejected_count'] - - if '[detected]' in line: - # [detected], fname:/xxx/.tmp_pos_testsets/pos_testsets/33.wav, - # kw:小云小云, confidence:0.965155, 
time:[4.62-5.10], threshold:0.00, - detected_count += 1 - content_list = line.split(', ') - file_name = os.path.basename(content_list[1].split(':')[1]) - keyword = content_list[2].split(':')[1] - confidence = content_list[3].split(':')[1] - - keywords_list = [] - if result_json.__contains__('keywords'): - keywords_list = result_json['keywords'] - - if keyword not in keywords_list: - keywords_list.append(keyword) - result_json['keywords'] = keywords_list - - keyword_item = {} - keyword_item['confidence'] = confidence - keyword_item['keyword'] = keyword - item = {} - item[file_name] = keyword_item - - detected_list = [] - if result_json.__contains__('detected'): - detected_list = result_json['detected'] - - detected_list.append(item) - result_json['detected'] = detected_list - - elif '[rejected]' in line: - # [rejected], fname:/xxx/.tmp_pos_testsets/pos_testsets/28.wav - rejected_count += 1 - content_list = line.split(', ') - file_name = os.path.basename(content_list[1].split(':')[1]) - file_name = file_name.strip().replace('\n', - '').replace('\r', '') - - rejected_list = [] - if result_json.__contains__('rejected'): - rejected_list = result_json['rejected'] - - rejected_list.append(file_name) - result_json['rejected'] = rejected_list - - result_json['detected_count'] = detected_count - result_json['rejected_count'] = rejected_count - - elif 'total_proc_time=' in line and 'wav_time=' in line: - # eg: total_proc_time=0.289000(s), wav_time=20.944125(s), kwsbp_rtf=0.013799 - wav_total_time = 0 - content_list = line.split('), ') - if result_json.__contains__('wav_time'): - wav_total_time = result_json['wav_time'] - - wav_time_str = content_list[1].split('=')[1] - wav_time_str = wav_time_str.split('(')[0] - wav_time = float(wav_time_str) - wav_time = round(wav_time, 6) - - if isinstance(wav_time, float): - wav_total_time += wav_time - - result_json['wav_time'] = wav_total_time - - return result_json - - def _generate_roc_list(self, start: float, step: float, end: float, - keyword: str, pos_inputs: Dict[str, Any], - neg_inputs: Dict[str, Any]) -> Dict[str, Any]: - pos_wav_count = pos_inputs['wav_count'] - neg_wav_time = neg_inputs['wav_time'] - det_lists = pos_inputs['detected'] - fa_lists = neg_inputs['detected'] - threshold_cur = start - """ - input det_lists dict - [ - { - "xxx.wav": { - "confidence": "0.990368", - "keyword": "小云小云" - } - }, - { - "yyy.wav": { - "confidence": "0.990368", - "keyword": "小云小云" - } - }, - ] - - output dict - [ - { - "threshold": 0.000, - "recall": 0.999888, - "fa_per_hour": 1.999999 - }, - { - "threshold": 0.001, - "recall": 0.999888, - "fa_per_hour": 1.999999 - }, - ] - """ - - output = [] - while threshold_cur <= end: - det_count = 0 - fa_count = 0 - for index in range(len(det_lists)): - det_item = det_lists[index] - det_wav_item = det_item.get(next(iter(det_item))) - if det_wav_item['keyword'] == keyword: - confidence = float(det_wav_item['confidence']) - if confidence >= threshold_cur: - det_count += 1 - - for index in range(len(fa_lists)): - fa_item = fa_lists[index] - fa_wav_item = fa_item.get(next(iter(fa_item))) - if fa_wav_item['keyword'] == keyword: - confidence = float(fa_wav_item['confidence']) - if confidence >= threshold_cur: - fa_count += 1 - - output_item = { - 'threshold': round(threshold_cur, 3), - 'recall': round(float(det_count / pos_wav_count), 6), - 'fa_per_hour': round(fa_count / float(neg_wav_time / 3600), 6) - } - output.append(output_item) - - threshold_cur += step - - return output - - def _set_customized_keywords(self) -> Dict[str, 
Any]: - if self._keywords is not None: - word_list_inputs = self._keywords - word_list = [] - for i in range(len(word_list_inputs)): - key = word_list_inputs[i] - new_item = {} - if key.__contains__('keyword'): - name = key['keyword'] - new_name: str = '' - for n in range(0, len(name), 1): - new_name += name[n] - new_name += ' ' - new_name = new_name.strip() - new_item['name'] = new_name - - if key.__contains__('threshold'): - threshold1: float = key['threshold'] - new_item['threshold1'] = threshold1 - - word_list.append(new_item) - out = {'word_list': word_list} - return out - else: - return '' diff --git a/modelscope/preprocessors/kws.py b/modelscope/preprocessors/kws.py index d69e8283..b406465a 100644 --- a/modelscope/preprocessors/kws.py +++ b/modelscope/preprocessors/kws.py @@ -1,8 +1,5 @@ import os -import shutil -import stat -from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import yaml @@ -19,48 +16,28 @@ __all__ = ['WavToLists'] Fields.audio, module_name=Preprocessors.wav_to_lists) class WavToLists(Preprocessor): """generate audio lists file from wav - - Args: - workspace (str): store temporarily kws intermedium and result """ - def __init__(self, workspace: str = None): - # the workspace path - if len(workspace) == 0: - self._workspace = os.path.join(os.getcwd(), '.tmp') - else: - self._workspace = workspace - - if not os.path.exists(self._workspace): - os.mkdir(self._workspace) + def __init__(self): + pass - def __call__(self, - model: Model = None, - kws_type: str = None, - wav_path: List[str] = None) -> Dict[str, Any]: + def __call__(self, model: Model, wav_path: Union[List[str], + str]) -> Dict[str, Any]: """Call functions to load model and wav. Args: model (Model): model should be provided - kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc - wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path + wav_path (Union[List[str], str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path Returns: Dict[str, Any]: the kws result """ - assert model is not None, 'preprocess kws model should be provided' - assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc' - ], f'preprocess kws_type {kws_type} is invalid' - assert wav_path[0] is not None or wav_path[ - 1] is not None, 'preprocess wav_path is invalid' - - self._model = model - out = self.forward(self._model.forward(), kws_type, wav_path) + self.model = model + out = self.forward(self.model.forward(), wav_path) return out - def forward(self, model: Dict[str, Any], kws_type: str, - wav_path: List[str]) -> Dict[str, Any]: - assert len(kws_type) > 0, 'preprocess kws_type is empty' + def forward(self, model: Dict[str, Any], + wav_path: Union[List[str], str]) -> Dict[str, Any]: assert len( model['config_path']) > 0, 'preprocess model[config_path] is empty' assert os.path.exists( @@ -68,19 +45,29 @@ class WavToLists(Preprocessor): inputs = model.copy() + wav_list = [None, None] + if isinstance(wav_path, str): + wav_list[0] = wav_path + else: + wav_list = wav_path + + import kws_util.common + kws_type = kws_util.common.type_checking(wav_list) + assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc' + ], f'preprocess kws_type {kws_type} is invalid' + inputs['kws_set'] = kws_type - inputs['workspace'] = self._workspace - if wav_path[0] is not None: - inputs['pos_wav_path'] = wav_path[0] - if wav_path[1] is not None: - inputs['neg_wav_path'] = wav_path[1] + if wav_list[0] is not None: + 
inputs['pos_wav_path'] = wav_list[0] + if wav_list[1] is not None: + inputs['neg_wav_path'] = wav_list[1] - out = self._read_config(inputs) - out = self._generate_wav_lists(out) + out = self.read_config(inputs) + out = self.generate_wav_lists(out) return out - def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """read and parse config.yaml to get all model files """ @@ -97,157 +84,51 @@ class WavToLists(Preprocessor): inputs['keyword_grammar'] = root['keyword_grammar'] inputs['keyword_grammar_path'] = os.path.join( inputs['model_workspace'], root['keyword_grammar']) - inputs['sample_rate'] = str(root['sample_rate']) - inputs['kws_tool'] = root['kws_tool'] - - if os.path.exists( - os.path.join(inputs['workspace'], inputs['kws_tool'])): - inputs['kws_tool_path'] = os.path.join(inputs['workspace'], - inputs['kws_tool']) - elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])): - inputs['kws_tool_path'] = os.path.join('/usr/bin', - inputs['kws_tool']) - elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])): - inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool']) - - assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp' - os.chmod(inputs['kws_tool_path'], - stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH) - - self._config_checking(inputs) + inputs['sample_rate'] = root['sample_rate'] + return inputs - def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """assemble wav lists """ + import kws_util.common if inputs['kws_set'] == 'wav': - inputs['pos_num_thread'] = 1 - wave_scp_content: str = inputs['pos_wav_path'] + '\n' - - with open(os.path.join(inputs['pos_data_path'], 'wave.list'), - 'a') as f: - f.write(wave_scp_content) - + wav_list = [] + wave_scp_content: str = inputs['pos_wav_path'] + wav_list.append(wave_scp_content) + inputs['pos_wav_list'] = wav_list inputs['pos_wav_count'] = 1 + inputs['pos_num_thread'] = 1 if inputs['kws_set'] in ['pos_testsets', 'roc']: # find all positive wave wav_list = [] wav_dir = inputs['pos_wav_path'] - wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) + wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir) + inputs['pos_wav_list'] = wav_list list_count: int = len(wav_list) inputs['pos_wav_count'] = list_count if list_count <= 128: inputs['pos_num_thread'] = list_count - j: int = 0 - while j < list_count: - wave_scp_content: str = wav_list[j] + '\n' - wav_list_path = inputs['pos_data_path'] + '/wave.' + str( - j) + '.list' - with open(wav_list_path, 'a') as f: - f.write(wave_scp_content) - j += 1 - else: inputs['pos_num_thread'] = 128 - j: int = 0 - k: int = 0 - while j < list_count: - wave_scp_content: str = wav_list[j] + '\n' - wav_list_path = inputs['pos_data_path'] + '/wave.' 
+ str( - k) + '.list' - with open(wav_list_path, 'a') as f: - f.write(wave_scp_content) - j += 1 - k += 1 - if k >= 128: - k = 0 if inputs['kws_set'] in ['neg_testsets', 'roc']: # find all negative wave wav_list = [] wav_dir = inputs['neg_wav_path'] - wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) + wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir) + inputs['neg_wav_list'] = wav_list list_count: int = len(wav_list) inputs['neg_wav_count'] = list_count if list_count <= 128: inputs['neg_num_thread'] = list_count - j: int = 0 - while j < list_count: - wave_scp_content: str = wav_list[j] + '\n' - wav_list_path = inputs['neg_data_path'] + '/wave.' + str( - j) + '.list' - with open(wav_list_path, 'a') as f: - f.write(wave_scp_content) - j += 1 - else: inputs['neg_num_thread'] = 128 - j: int = 0 - k: int = 0 - while j < list_count: - wave_scp_content: str = wav_list[j] + '\n' - wav_list_path = inputs['neg_data_path'] + '/wave.' + str( - k) + '.list' - with open(wav_list_path, 'a') as f: - f.write(wave_scp_content) - j += 1 - k += 1 - if k >= 128: - k = 0 return inputs - - def _recursion_dir_all_wave(self, wav_list, - dir_path: str) -> Dict[str, Any]: - dir_files = os.listdir(dir_path) - for file in dir_files: - file_path = os.path.join(dir_path, file) - if os.path.isfile(file_path): - if file_path.endswith('.wav') or file_path.endswith('.WAV'): - wav_list.append(file_path) - elif os.path.isdir(file_path): - self._recursion_dir_all_wave(wav_list, file_path) - - return wav_list - - def _config_checking(self, inputs: Dict[str, Any]): - - if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: - inputs['pos_data_path'] = os.path.join(inputs['workspace'], - 'pos_data') - if not os.path.exists(inputs['pos_data_path']): - os.mkdir(inputs['pos_data_path']) - else: - shutil.rmtree(inputs['pos_data_path']) - os.mkdir(inputs['pos_data_path']) - - inputs['pos_dump_path'] = os.path.join(inputs['workspace'], - 'pos_dump') - if not os.path.exists(inputs['pos_dump_path']): - os.mkdir(inputs['pos_dump_path']) - else: - shutil.rmtree(inputs['pos_dump_path']) - os.mkdir(inputs['pos_dump_path']) - - if inputs['kws_set'] in ['neg_testsets', 'roc']: - inputs['neg_data_path'] = os.path.join(inputs['workspace'], - 'neg_data') - if not os.path.exists(inputs['neg_data_path']): - os.mkdir(inputs['neg_data_path']) - else: - shutil.rmtree(inputs['neg_data_path']) - os.mkdir(inputs['neg_data_path']) - - inputs['neg_dump_path'] = os.path.join(inputs['workspace'], - 'neg_dump') - if not os.path.exists(inputs['neg_dump_path']): - os.mkdir(inputs['neg_dump_path']) - else: - shutil.rmtree(inputs['neg_dump_path']) - os.mkdir(inputs['neg_dump_path']) diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 4adb48f0..893db798 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -218,3 +218,11 @@ class TrainerStages: after_val_iter = 'after_val_iter' after_val_epoch = 'after_val_epoch' after_run = 'after_run' + + +class ColorCodes: + MAGENTA = '\033[95m' + YELLOW = '\033[93m' + GREEN = '\033[92m' + RED = '\033[91m' + END = '\033[0m' diff --git a/modelscope/utils/test_utils.py b/modelscope/utils/test_utils.py index 0ca58c4e..ad7c0238 100644 --- a/modelscope/utils/test_utils.py +++ b/modelscope/utils/test_utils.py @@ -2,9 +2,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os +import tarfile import unittest import numpy as np +import requests from datasets import Dataset from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE @@ -42,3 +44,21 @@ def set_test_level(level: int): def create_dummy_test_dataset(feat, label, num): return MsDataset.from_hf_dataset( Dataset.from_dict(dict(feat=[feat] * num, label=[label] * num))) + + +def download_and_untar(fpath, furl, dst) -> str: + if not os.path.exists(fpath): + r = requests.get(furl) + with open(fpath, 'wb') as f: + f.write(r.content) + + file_name = os.path.basename(fpath) + root_dir = os.path.dirname(fpath) + target_dir_name = os.path.splitext(os.path.splitext(file_name)[0])[0] + target_dir_path = os.path.join(root_dir, target_dir_name) + + # untar the file + t = tarfile.open(fpath) + t.extractall(path=dst) + + return target_dir_path diff --git a/requirements/audio.txt b/requirements/audio.txt index a0085772..e3d50b57 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -3,6 +3,7 @@ espnet>=202204 h5py inflect keras +kwsbp librosa lxml matplotlib diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py index 7999b421..8b0e37e6 100644 --- a/tests/pipelines/test_key_word_spotting.py +++ b/tests/pipelines/test_key_word_spotting.py @@ -3,14 +3,16 @@ import os import shutil import tarfile import unittest +from typing import Any, Dict, List, Union import requests from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.test_utils import test_level +from modelscope.utils.constant import ColorCodes, Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import download_and_untar, test_level -KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' +logger = get_logger() POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' @@ -22,12 +24,102 @@ NEG_TESTSETS_FILE = 'neg_testsets.tar.gz' NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz' -def un_tar_gz(fname, dirs): - t = tarfile.open(fname) - t.extractall(path=dirs) - - class KeyWordSpottingTest(unittest.TestCase): + action_info = { + 'test_run_with_wav': { + 'checking_item': 'kws_list', + 'checking_value': '小云小云', + 'example': { + 'wav_count': + 1, + 'kws_set': + 'wav', + 'kws_list': [{ + 'keyword': '小云小云', + 'offset': 5.76, + 'length': 9.132938, + 'confidence': 0.990368 + }] + } + }, + 'test_run_with_wav_by_customized_keywords': { + 'checking_item': 'kws_list', + 'checking_value': '播放音乐', + 'example': { + 'wav_count': + 1, + 'kws_set': + 'wav', + 'kws_list': [{ + 'keyword': '播放音乐', + 'offset': 0.87, + 'length': 2.158313, + 'confidence': 0.646237 + }] + } + }, + 'test_run_with_pos_testsets': { + 'checking_item': 'recall', + 'example': { + 'wav_count': 450, + 'kws_set': 'pos_testsets', + 'wav_time': 3013.75925, + 'keywords': ['小云小云'], + 'recall': 0.953333, + 'detected_count': 429, + 'rejected_count': 21, + 'rejected': ['yyy.wav', 'zzz.wav'] + } + }, + 'test_run_with_neg_testsets': { + 'checking_item': 'fa_rate', + 'example': { + 'wav_count': + 751, + 'kws_set': + 'neg_testsets', + 'wav_time': + 3572.180813, + 'keywords': ['小云小云'], + 'fa_rate': + 0.001332, + 'fa_per_hour': + 1.007788, + 'detected_count': + 1, + 'rejected_count': + 750, + 'detected': [{ + '6.wav': { + 'confidence': '0.321170', + 'keyword': '小云小云' + } + }] + } + }, + 'test_run_with_roc': { + 'checking_item': 
'keywords', + 'checking_value': '小云小云', + 'example': { + 'kws_set': + 'roc', + 'keywords': ['小云小云'], + '小云小云': [{ + 'threshold': 0.0, + 'recall': 0.953333, + 'fa_per_hour': 1.007788 + }, { + 'threshold': 0.001, + 'recall': 0.953333, + 'fa_per_hour': 1.007788 + }, { + 'threshold': 0.999, + 'recall': 0.004444, + 'fa_per_hour': 0.0 + }] + } + } + } def setUp(self) -> None: self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun' @@ -36,315 +128,121 @@ class KeyWordSpottingTest(unittest.TestCase): os.mkdir(self.workspace) def tearDown(self) -> None: + # remove workspace dir (.tmp) if os.path.exists(self.workspace): - shutil.rmtree(os.path.join(self.workspace), ignore_errors=True) - - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_with_wav(self): - # wav, neg_testsets, pos_testsets, roc - kws_set = 'wav' - - # get wav file - wav_file_path = POS_WAV_FILE - - # downloading kwsbp - kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') - if not os.path.exists(kwsbp_file_path): - r = requests.get(KWSBP_URL) - with open(kwsbp_file_path, 'wb') as f: - f.write(r.content) + shutil.rmtree(self.workspace, ignore_errors=True) + def run_pipeline(self, + model_id: str, + wav_path: Union[List[str], str], + keywords: List[str] = None) -> Dict[str, Any]: kwsbp_16k_pipline = pipeline( - task=Tasks.auto_speech_recognition, model=self.model_id) - self.assertTrue(kwsbp_16k_pipline is not None) + task=Tasks.auto_speech_recognition, model=model_id) + + kws_result = kwsbp_16k_pipline(wav_path=wav_path, keywords=keywords) + + return kws_result + + def print_error(self, functions: str, result: Dict[str, Any]) -> None: + logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' + + ColorCodes.END) + logger.error(ColorCodes.MAGENTA + functions + + ' correct result example: ' + ColorCodes.YELLOW + + str(self.action_info[functions]['example']) + + ColorCodes.END) + + raise ValueError('kws result is mismatched') + + def check_and_print_result(self, functions: str, + result: Dict[str, Any]) -> None: + if result.__contains__(self.action_info[functions]['checking_item']): + checking_item = result[self.action_info[functions] + ['checking_item']] + if functions == 'test_run_with_roc': + if checking_item[0] != self.action_info[functions][ + 'checking_value']: + self.print_error(functions, result) + + elif functions == 'test_run_with_wav': + if checking_item[0]['keyword'] != self.action_info[functions][ + 'checking_value']: + self.print_error(functions, result) + + elif functions == 'test_run_with_wav_by_customized_keywords': + if checking_item[0]['keyword'] != self.action_info[functions][ + 'checking_value']: + self.print_error(functions, result) + + logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' 
+ + ColorCodes.END) + if functions == 'test_run_with_roc': + find_keyword = result['keywords'][0] + keyword_list = result[find_keyword] + for item in iter(keyword_list): + threshold: float = item['threshold'] + recall: float = item['recall'] + fa_per_hour: float = item['fa_per_hour'] + logger.info(ColorCodes.YELLOW + ' threshold:' + + str(threshold) + ' recall:' + str(recall) + + ' fa_per_hour:' + str(fa_per_hour) + + ColorCodes.END) + else: + logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END) + else: + self.print_error(functions, result) - kws_result = kwsbp_16k_pipline( - kws_type=kws_set, - wav_path=[wav_file_path, None], - workspace=self.workspace) - self.assertTrue(kws_result.__contains__('detected')) - """ - kws result json format example: - { - 'wav_count': 1, - 'kws_set': 'wav', - 'wav_time': 9.132938, - 'keywords': ['小云小云'], - 'detected': True, - 'confidence': 0.990368 - } - """ - if kws_result.__contains__('keywords'): - print('test_run_with_wav keywords: ', kws_result['keywords']) - print('test_run_with_wav confidence: ', kws_result['confidence']) - print('test_run_with_wav detected result: ', kws_result['detected']) - print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_wav(self): + kws_result = self.run_pipeline( + model_id=self.model_id, wav_path=POS_WAV_FILE) + self.check_and_print_result('test_run_with_wav', kws_result) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_wav_by_customized_keywords(self): - # wav, neg_testsets, pos_testsets, roc - kws_set = 'wav' - - # get wav file - wav_file_path = BOFANGYINYUE_WAV_FILE - - # downloading kwsbp - kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') - if not os.path.exists(kwsbp_file_path): - r = requests.get(KWSBP_URL) - with open(kwsbp_file_path, 'wb') as f: - f.write(r.content) - - # customized keyword if you need. - # full settings eg. 
- # keywords = [ - # {'keyword':'你好电视', 'threshold': 0.008}, - # {'keyword':'播放音乐', 'threshold': 0.008} - # ] keywords = [{'keyword': '播放音乐'}] - kwsbp_16k_pipline = pipeline( - task=Tasks.auto_speech_recognition, - model=self.model_id, + kws_result = self.run_pipeline( + model_id=self.model_id, + wav_path=BOFANGYINYUE_WAV_FILE, keywords=keywords) - self.assertTrue(kwsbp_16k_pipline is not None) - - kws_result = kwsbp_16k_pipline( - kws_type=kws_set, - wav_path=[wav_file_path, None], - workspace=self.workspace) - self.assertTrue(kws_result.__contains__('detected')) - """ - kws result json format example: - { - 'wav_count': 1, - 'kws_set': 'wav', - 'wav_time': 9.132938, - 'keywords': ['播放音乐'], - 'detected': True, - 'confidence': 0.660368 - } - """ - if kws_result.__contains__('keywords'): - print('test_run_with_wav_by_customized_keywords keywords: ', - kws_result['keywords']) - print('test_run_with_wav_by_customized_keywords confidence: ', - kws_result['confidence']) - print('test_run_with_wav_by_customized_keywords detected result: ', - kws_result['detected']) - print('test_run_with_wav_by_customized_keywords wave time(seconds): ', - kws_result['wav_time']) + self.check_and_print_result('test_run_with_wav_by_customized_keywords', + kws_result) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_pos_testsets(self): - # wav, neg_testsets, pos_testsets, roc - kws_set = 'pos_testsets' - - # downloading pos_testsets file - testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE) - if not os.path.exists(testsets_file_path): - r = requests.get(POS_TESTSETS_URL) - with open(testsets_file_path, 'wb') as f: - f.write(r.content) + wav_file_path = download_and_untar( + os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL, + self.workspace) + wav_path = [wav_file_path, None] - testsets_dir_name = os.path.splitext( - os.path.basename(POS_TESTSETS_FILE))[0] - testsets_dir_name = os.path.splitext( - os.path.basename(testsets_dir_name))[0] - # wav_file_path = /.tmp_pos_testsets/pos_testsets/ - wav_file_path = os.path.join(self.workspace, testsets_dir_name) - - # untar the pos_testsets file - if not os.path.exists(wav_file_path): - un_tar_gz(testsets_file_path, self.workspace) - - # downloading kwsbp -- a kws batch processing tool - kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') - if not os.path.exists(kwsbp_file_path): - r = requests.get(KWSBP_URL) - with open(kwsbp_file_path, 'wb') as f: - f.write(r.content) - - kwsbp_16k_pipline = pipeline( - task=Tasks.auto_speech_recognition, model=self.model_id) - self.assertTrue(kwsbp_16k_pipline is not None) - - kws_result = kwsbp_16k_pipline( - kws_type=kws_set, - wav_path=[wav_file_path, None], - workspace=self.workspace) - self.assertTrue(kws_result.__contains__('recall')) - """ - kws result json format example: - { - 'wav_count': 450, - 'kws_set': 'pos_testsets', - 'wav_time': 3013.759254, - 'keywords': ["小云小云"], - 'recall': 0.953333, - 'detected_count': 429, - 'rejected_count': 21, - 'rejected': [ - 'yyy.wav', - 'zzz.wav', - ...... 
- ] - } - """ - if kws_result.__contains__('keywords'): - print('test_run_with_pos_testsets keywords: ', - kws_result['keywords']) - print('test_run_with_pos_testsets recall: ', kws_result['recall']) - print('test_run_with_pos_testsets wave time(seconds): ', - kws_result['wav_time']) + kws_result = self.run_pipeline( + model_id=self.model_id, wav_path=wav_path) + self.check_and_print_result('test_run_with_pos_testsets', kws_result) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_neg_testsets(self): - # wav, neg_testsets, pos_testsets, roc - kws_set = 'neg_testsets' - - # downloading neg_testsets file - testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE) - if not os.path.exists(testsets_file_path): - r = requests.get(NEG_TESTSETS_URL) - with open(testsets_file_path, 'wb') as f: - f.write(r.content) - - testsets_dir_name = os.path.splitext( - os.path.basename(NEG_TESTSETS_FILE))[0] - testsets_dir_name = os.path.splitext( - os.path.basename(testsets_dir_name))[0] - # wav_file_path = /.tmp_neg_testsets/neg_testsets/ - wav_file_path = os.path.join(self.workspace, testsets_dir_name) - - # untar the neg_testsets file - if not os.path.exists(wav_file_path): - un_tar_gz(testsets_file_path, self.workspace) - - # downloading kwsbp -- a kws batch processing tool - kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') - if not os.path.exists(kwsbp_file_path): - r = requests.get(KWSBP_URL) - with open(kwsbp_file_path, 'wb') as f: - f.write(r.content) - - kwsbp_16k_pipline = pipeline( - task=Tasks.auto_speech_recognition, model=self.model_id) - self.assertTrue(kwsbp_16k_pipline is not None) + wav_file_path = download_and_untar( + os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, + self.workspace) + wav_path = [None, wav_file_path] - kws_result = kwsbp_16k_pipline( - kws_type=kws_set, - wav_path=[None, wav_file_path], - workspace=self.workspace) - self.assertTrue(kws_result.__contains__('fa_rate')) - """ - kws result json format example: - { - 'wav_count': 751, - 'kws_set': 'neg_testsets', - 'wav_time': 3572.180812, - 'keywords': ['小云小云'], - 'fa_rate': 0.001332, - 'fa_per_hour': 1.007788, - 'detected_count': 1, - 'rejected_count': 750, - 'detected': [ - { - '6.wav': { - 'confidence': '0.321170' - } - } - ] - } - """ - if kws_result.__contains__('keywords'): - print('test_run_with_neg_testsets keywords: ', - kws_result['keywords']) - print('test_run_with_neg_testsets fa rate: ', kws_result['fa_rate']) - print('test_run_with_neg_testsets fa per hour: ', - kws_result['fa_per_hour']) - print('test_run_with_neg_testsets wave time(seconds): ', - kws_result['wav_time']) + kws_result = self.run_pipeline( + model_id=self.model_id, wav_path=wav_path) + self.check_and_print_result('test_run_with_neg_testsets', kws_result) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_roc(self): - # wav, neg_testsets, pos_testsets, roc - kws_set = 'roc' - - # downloading neg_testsets file - testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE) - if not os.path.exists(testsets_file_path): - r = requests.get(NEG_TESTSETS_URL) - with open(testsets_file_path, 'wb') as f: - f.write(r.content) - - testsets_dir_name = os.path.splitext( - os.path.basename(NEG_TESTSETS_FILE))[0] - testsets_dir_name = os.path.splitext( - os.path.basename(testsets_dir_name))[0] - # neg_file_path = /.tmp_roc/neg_testsets/ - neg_file_path = os.path.join(self.workspace, testsets_dir_name) - - # untar the neg_testsets file - 
if not os.path.exists(neg_file_path): - un_tar_gz(testsets_file_path, self.workspace) - - # downloading pos_testsets file - testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE) - if not os.path.exists(testsets_file_path): - r = requests.get(POS_TESTSETS_URL) - with open(testsets_file_path, 'wb') as f: - f.write(r.content) - - testsets_dir_name = os.path.splitext( - os.path.basename(POS_TESTSETS_FILE))[0] - testsets_dir_name = os.path.splitext( - os.path.basename(testsets_dir_name))[0] - # pos_file_path = /.tmp_roc/pos_testsets/ - pos_file_path = os.path.join(self.workspace, testsets_dir_name) - - # untar the pos_testsets file - if not os.path.exists(pos_file_path): - un_tar_gz(testsets_file_path, self.workspace) - - # downloading kwsbp -- a kws batch processing tool - kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') - if not os.path.exists(kwsbp_file_path): - r = requests.get(KWSBP_URL) - with open(kwsbp_file_path, 'wb') as f: - f.write(r.content) - - kwsbp_16k_pipline = pipeline( - task=Tasks.auto_speech_recognition, model=self.model_id) - self.assertTrue(kwsbp_16k_pipline is not None) - - kws_result = kwsbp_16k_pipline( - kws_type=kws_set, - wav_path=[pos_file_path, neg_file_path], - workspace=self.workspace) - """ - kws result json format example: - { - 'kws_set': 'roc', - 'keywords': ['小云小云'], - '小云小云': [ - {'threshold': 0.0, 'recall': 0.953333, 'fa_per_hour': 1.007788}, - {'threshold': 0.001, 'recall': 0.953333, 'fa_per_hour': 1.007788}, - ...... - {'threshold': 0.999, 'recall': 0.004444, 'fa_per_hour': 0.0} - ] - } - """ - if kws_result.__contains__('keywords'): - find_keyword = kws_result['keywords'][0] - print('test_run_with_roc keywords: ', find_keyword) - keyword_list = kws_result[find_keyword] - for item in iter(keyword_list): - threshold: float = item['threshold'] - recall: float = item['recall'] - fa_per_hour: float = item['fa_per_hour'] - print(' threshold:', threshold, ' recall:', recall, - ' fa_per_hour:', fa_per_hour) + pos_file_path = download_and_untar( + os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL, + self.workspace) + neg_file_path = download_and_untar( + os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, + self.workspace) + wav_path = [pos_file_path, neg_file_path] + + kws_result = self.run_pipeline( + model_id=self.model_id, wav_path=wav_path) + self.check_and_print_result('test_run_with_roc', kws_result) if __name__ == '__main__':
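
# Usage note (not part of the patch): a minimal sketch of the refactored pipeline API that this
# diff introduces, based on the calls made in tests/pipelines/test_key_word_spotting.py above.
# The model id, wav paths, and keyword dict come from that test file; result fields follow the
# examples recorded in its `action_info` table. Directory paths in the last call are placeholders.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the kws pipeline; the old `kws_type`, `workspace`, and external kwsbp binary are no longer needed.
kws = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_charctc_kws_phone-xiaoyunxiaoyun')

# Single wav file: pass the path directly; the preprocessor infers kws_type='wav'.
result = kws(wav_path='data/test/audios/kws_xiaoyunxiaoyun.wav')
print(result['kws_list'])  # e.g. [{'keyword': '小云小云', 'confidence': 0.99, ...}]

# Customized keywords are forwarded through **kwargs.
result = kws(
    wav_path='data/test/audios/kws_bofangyinyue.wav',
    keywords=[{'keyword': '播放音乐'}])

# Test-set evaluation: pass [positive_dir, negative_dir]; leave the unused side as None.
# Both dirs set -> 'roc', only positive -> 'pos_testsets', only negative -> 'neg_testsets'.
result = kws(wav_path=['/path/to/pos_testsets', '/path/to/neg_testsets'])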