Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491681master
@@ -19,7 +19,7 @@ class GenericKeyWordSpotting(Model): | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir) | |||
super().__init__(model_dir, *args, **kwargs) | |||
self.model_cfg = { | |||
'model_workspace': model_dir, | |||
'config_path': os.path.join(model_dir, 'config.yaml') | |||
@@ -1,5 +1,4 @@ | |||
import os | |||
import subprocess | |||
from typing import Any, Dict, List, Union | |||
import json | |||
@@ -10,6 +9,9 @@ from modelscope.pipelines.base import Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import WavToLists | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
__all__ = ['KeyWordSpottingKwsbpPipeline'] | |||
@@ -21,41 +23,24 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
""" | |||
def __init__(self, | |||
config_file: str = None, | |||
model: Union[Model, str] = None, | |||
preprocessor: WavToLists = None, | |||
**kwargs): | |||
"""use `model` and `preprocessor` to create a kws pipeline for prediction | |||
""" | |||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__( | |||
config_file=config_file, | |||
model=model, | |||
preprocessor=preprocessor, | |||
**kwargs) | |||
assert model is not None, 'kws model should be provided' | |||
self._preprocessor = preprocessor | |||
self._keywords = None | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
def __call__(self, wav_path: Union[List[str], str], | |||
**kwargs) -> Dict[str, Any]: | |||
if 'keywords' in kwargs.keys(): | |||
self._keywords = kwargs['keywords'] | |||
def __call__(self, | |||
kws_type: str, | |||
wav_path: List[str], | |||
workspace: str = None) -> Dict[str, Any]: | |||
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', | |||
'roc'], f'kws_type {kws_type} is invalid' | |||
self.keywords = kwargs['keywords'] | |||
else: | |||
self.keywords = None | |||
if self._preprocessor is None: | |||
self._preprocessor = WavToLists(workspace=workspace) | |||
if self.preprocessor is None: | |||
self.preprocessor = WavToLists() | |||
output = self._preprocessor.forward(self.model.forward(), kws_type, | |||
wav_path) | |||
output = self.preprocessor.forward(self.model.forward(), wav_path) | |||
output = self.forward(output) | |||
rst = self.postprocess(output) | |||
return rst | |||
@@ -64,433 +49,92 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
"""Decoding | |||
""" | |||
# will generate kws result into dump/dump.JOB.log | |||
out = self._run_with_kwsbp(inputs) | |||
logger.info(f"Decoding with {inputs['kws_set']} mode ...") | |||
# will generate kws result | |||
out = self.run_with_kwsbp(inputs) | |||
return out | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
"""process the kws results | |||
""" | |||
pos_result_json = {} | |||
neg_result_json = {} | |||
if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: | |||
self._parse_dump_log(pos_result_json, inputs['pos_dump_path']) | |||
if inputs['kws_set'] in ['neg_testsets', 'roc']: | |||
self._parse_dump_log(neg_result_json, inputs['neg_dump_path']) | |||
""" | |||
result_json format example: | |||
{ | |||
"wav_count": 450, | |||
"keywords": ["小云小云"], | |||
"wav_time": 3560.999999, | |||
"detected": [ | |||
{ | |||
"xxx.wav": { | |||
"confidence": "0.990368", | |||
"keyword": "小云小云" | |||
} | |||
}, | |||
{ | |||
"yyy.wav": { | |||
"confidence": "0.990368", | |||
"keyword": "小云小云" | |||
} | |||
}, | |||
...... | |||
], | |||
"detected_count": 429, | |||
"rejected_count": 21, | |||
"rejected": [ | |||
"yyy.wav", | |||
"zzz.wav", | |||
...... | |||
] | |||
} | |||
Args: | |||
inputs['pos_kws_list'] or inputs['neg_kws_list']: | |||
result_dict format example: | |||
[{ | |||
'confidence': 0.9903678297996521, | |||
'filename': 'data/test/audios/kws_xiaoyunxiaoyun.wav', | |||
'keyword': '小云小云', | |||
'offset': 5.760000228881836, # second | |||
'rtf_time': 66, # millisecond | |||
'threshold': 0, | |||
'wav_time': 9.1329375 # second | |||
}] | |||
""" | |||
rst_dict = {'kws_set': inputs['kws_set']} | |||
# parsing the result of wav | |||
if inputs['kws_set'] == 'wav': | |||
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ | |||
'pos_wav_count'] | |||
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) | |||
if pos_result_json['detected_count'] == 1: | |||
rst_dict['keywords'] = pos_result_json['keywords'] | |||
rst_dict['detected'] = True | |||
wav_file_name = os.path.basename(inputs['pos_wav_path']) | |||
rst_dict['confidence'] = float(pos_result_json['detected'][0] | |||
[wav_file_name]['confidence']) | |||
else: | |||
rst_dict['detected'] = False | |||
# parsing the result of pos_tests | |||
elif inputs['kws_set'] == 'pos_testsets': | |||
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ | |||
'pos_wav_count'] | |||
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) | |||
if pos_result_json.__contains__('keywords'): | |||
rst_dict['keywords'] = pos_result_json['keywords'] | |||
rst_dict['recall'] = round( | |||
pos_result_json['detected_count'] / rst_dict['wav_count'], 6) | |||
if pos_result_json.__contains__('detected_count'): | |||
rst_dict['detected_count'] = pos_result_json['detected_count'] | |||
if pos_result_json.__contains__('rejected_count'): | |||
rst_dict['rejected_count'] = pos_result_json['rejected_count'] | |||
if pos_result_json.__contains__('rejected'): | |||
rst_dict['rejected'] = pos_result_json['rejected'] | |||
# parsing the result of neg_tests | |||
elif inputs['kws_set'] == 'neg_testsets': | |||
rst_dict['wav_count'] = neg_result_json['wav_count'] = inputs[ | |||
'neg_wav_count'] | |||
rst_dict['wav_time'] = round(neg_result_json['wav_time'], 6) | |||
if neg_result_json.__contains__('keywords'): | |||
rst_dict['keywords'] = neg_result_json['keywords'] | |||
rst_dict['fa_rate'] = 0.0 | |||
rst_dict['fa_per_hour'] = 0.0 | |||
if neg_result_json.__contains__('detected_count'): | |||
rst_dict['detected_count'] = neg_result_json['detected_count'] | |||
rst_dict['fa_rate'] = round( | |||
neg_result_json['detected_count'] / rst_dict['wav_count'], | |||
6) | |||
if neg_result_json.__contains__('wav_time'): | |||
rst_dict['fa_per_hour'] = round( | |||
neg_result_json['detected_count'] | |||
/ float(neg_result_json['wav_time'] / 3600), 6) | |||
if neg_result_json.__contains__('rejected_count'): | |||
rst_dict['rejected_count'] = neg_result_json['rejected_count'] | |||
if neg_result_json.__contains__('detected'): | |||
rst_dict['detected'] = neg_result_json['detected'] | |||
# parsing the result of roc | |||
elif inputs['kws_set'] == 'roc': | |||
threshold_start = 0.000 | |||
threshold_step = 0.001 | |||
threshold_end = 1.000 | |||
pos_keywords_list = [] | |||
neg_keywords_list = [] | |||
if pos_result_json.__contains__('keywords'): | |||
pos_keywords_list = pos_result_json['keywords'] | |||
if neg_result_json.__contains__('keywords'): | |||
neg_keywords_list = neg_result_json['keywords'] | |||
keywords_list = list(set(pos_keywords_list + neg_keywords_list)) | |||
pos_result_json['wav_count'] = inputs['pos_wav_count'] | |||
neg_result_json['wav_count'] = inputs['neg_wav_count'] | |||
if len(keywords_list) > 0: | |||
rst_dict['keywords'] = keywords_list | |||
for index in range(len(rst_dict['keywords'])): | |||
cur_keyword = rst_dict['keywords'][index] | |||
output_list = self._generate_roc_list( | |||
start=threshold_start, | |||
step=threshold_step, | |||
end=threshold_end, | |||
keyword=cur_keyword, | |||
pos_inputs=pos_result_json, | |||
neg_inputs=neg_result_json) | |||
rst_dict[cur_keyword] = output_list | |||
import kws_util.common | |||
neg_kws_list = None | |||
pos_kws_list = None | |||
if 'pos_kws_list' in inputs: | |||
pos_kws_list = inputs['pos_kws_list'] | |||
if 'neg_kws_list' in inputs: | |||
neg_kws_list = inputs['neg_kws_list'] | |||
rst_dict = kws_util.common.parsing_kws_result( | |||
kws_type=inputs['kws_set'], | |||
pos_list=pos_kws_list, | |||
neg_list=neg_kws_list) | |||
return rst_dict | |||
def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
opts: str = '' | |||
def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
cmd = { | |||
'sys_dir': inputs['model_workspace'], | |||
'cfg_file': inputs['cfg_file_path'], | |||
'sample_rate': inputs['sample_rate'], | |||
'keyword_custom': '' | |||
} | |||
import kwsbp | |||
import kws_util.common | |||
kws_inference = kwsbp.KwsbpEngine() | |||
# setting customized keywords | |||
keywords_json = self._set_customized_keywords() | |||
if len(keywords_json) > 0: | |||
keywords_json_file = os.path.join(inputs['workspace'], | |||
'keyword_custom.json') | |||
with open(keywords_json_file, 'w') as f: | |||
json.dump(keywords_json, f) | |||
opts = '--keyword-custom ' + keywords_json_file | |||
cmd['customized_keywords'] = kws_util.common.generate_customized_keywords( | |||
self.keywords) | |||
if inputs['kws_set'] == 'roc': | |||
inputs['keyword_grammar_path'] = os.path.join( | |||
inputs['model_workspace'], 'keywords_roc.json') | |||
if inputs['kws_set'] == 'wav': | |||
dump_log_path: str = os.path.join(inputs['pos_dump_path'], | |||
'dump.log') | |||
kws_cmd: str = inputs['kws_tool_path'] + \ | |||
' --sys-dir=' + inputs['model_workspace'] + \ | |||
' --cfg-file=' + inputs['cfg_file_path'] + \ | |||
' --sample-rate=' + inputs['sample_rate'] + \ | |||
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||
' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ | |||
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||
os.system(kws_cmd) | |||
if inputs['kws_set'] in ['pos_testsets', 'roc']: | |||
data_dir: str = os.listdir(inputs['pos_data_path']) | |||
wav_list = [] | |||
for i in data_dir: | |||
suffix = os.path.splitext(os.path.basename(i))[1] | |||
if suffix == '.list': | |||
wav_list.append(os.path.join(inputs['pos_data_path'], i)) | |||
j: int = 0 | |||
process = [] | |||
while j < inputs['pos_num_thread']: | |||
wav_list_path: str = inputs['pos_data_path'] + '/wave.' + str( | |||
j) + '.list' | |||
dump_log_path: str = inputs['pos_dump_path'] + '/dump.' + str( | |||
j) + '.log' | |||
kws_cmd: str = inputs['kws_tool_path'] + \ | |||
' --sys-dir=' + inputs['model_workspace'] + \ | |||
' --cfg-file=' + inputs['cfg_file_path'] + \ | |||
' --sample-rate=' + inputs['sample_rate'] + \ | |||
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||
' --wave-scp=' + wav_list_path + \ | |||
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||
p = subprocess.Popen(kws_cmd, shell=True) | |||
process.append(p) | |||
j += 1 | |||
k: int = 0 | |||
while k < len(process): | |||
process[k].wait() | |||
k += 1 | |||
if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: | |||
cmd['wave_scp'] = inputs['pos_wav_list'] | |||
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] | |||
cmd['num_thread'] = inputs['pos_num_thread'] | |||
# run and get inference result | |||
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'], | |||
cmd['keyword_grammar_path'], | |||
str(json.dumps(cmd['wave_scp'])), | |||
str(cmd['customized_keywords']), | |||
cmd['sample_rate'], | |||
cmd['num_thread']) | |||
pos_result = json.loads(result) | |||
inputs['pos_kws_list'] = pos_result['kws_list'] | |||
if inputs['kws_set'] in ['neg_testsets', 'roc']: | |||
data_dir: str = os.listdir(inputs['neg_data_path']) | |||
wav_list = [] | |||
for i in data_dir: | |||
suffix = os.path.splitext(os.path.basename(i))[1] | |||
if suffix == '.list': | |||
wav_list.append(os.path.join(inputs['neg_data_path'], i)) | |||
j: int = 0 | |||
process = [] | |||
while j < inputs['neg_num_thread']: | |||
wav_list_path: str = inputs['neg_data_path'] + '/wave.' + str( | |||
j) + '.list' | |||
dump_log_path: str = inputs['neg_dump_path'] + '/dump.' + str( | |||
j) + '.log' | |||
kws_cmd: str = inputs['kws_tool_path'] + \ | |||
' --sys-dir=' + inputs['model_workspace'] + \ | |||
' --cfg-file=' + inputs['cfg_file_path'] + \ | |||
' --sample-rate=' + inputs['sample_rate'] + \ | |||
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||
' --wave-scp=' + wav_list_path + \ | |||
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||
p = subprocess.Popen(kws_cmd, shell=True) | |||
process.append(p) | |||
j += 1 | |||
k: int = 0 | |||
while k < len(process): | |||
process[k].wait() | |||
k += 1 | |||
cmd['wave_scp'] = inputs['neg_wav_list'] | |||
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] | |||
cmd['num_thread'] = inputs['neg_num_thread'] | |||
# run and get inference result | |||
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'], | |||
cmd['keyword_grammar_path'], | |||
str(json.dumps(cmd['wave_scp'])), | |||
str(cmd['customized_keywords']), | |||
cmd['sample_rate'], | |||
cmd['num_thread']) | |||
neg_result = json.loads(result) | |||
inputs['neg_kws_list'] = neg_result['kws_list'] | |||
return inputs | |||
def _parse_dump_log(self, result_json: Dict[str, Any],
                    dump_path: str) -> Dict[str, Any]:
    """Fold every dump log found in ``dump_path`` into ``result_json``.

    A directory entry is treated as a dump log when its name, minus the
    final extension, contains ``'dump'`` (e.g. ``dump.log``, ``dump.3.log``).
    Each line of such a file is handed to ``self._parse_result_log``.

    Args:
        result_json: accumulator that the line parser updates.
        dump_path: directory holding the dump logs.

    Returns:
        The accumulator after all lines were processed. (Previously the
        method fell off the end and returned ``None`` despite its
        annotation.)
    """
    # Sorted so the order of accumulated entries does not depend on the
    # arbitrary order os.listdir() happens to return.
    for entry in sorted(os.listdir(dump_path)):
        stem = os.path.splitext(os.path.basename(entry))[0]
        if 'dump' not in stem:
            continue
        with open(
                os.path.join(dump_path, entry), mode='r',
                encoding='utf-8') as file:
            for line in file:
                result_json = self._parse_result_log(line, result_json)
    return result_json
def _parse_result_log(self, line: str,
                      result_json: Dict[str, Any]) -> Dict[str, Any]:
    """Update ``result_json`` from one line of a kwsbp dump log.

    Three line shapes are recognised; anything else is ignored:

    * ``[detected], fname:<path>, kw:<keyword>, confidence:<c>, ...``
      bumps ``detected_count``, records the keyword in ``keywords`` (once)
      and appends ``{<file name>: {'confidence', 'keyword'}}`` to
      ``detected``. The confidence is kept as the string found in the log.
    * ``[rejected], fname:<path>`` bumps ``rejected_count`` and appends the
      file name to ``rejected``.
    * ``total_proc_time=...(s), wav_time=<t>(s), ...`` adds ``<t>`` (rounded
      to 6 decimals) to ``wav_time``.

    Args:
        line: one raw log line, possibly ending with a newline.
        result_json: accumulator; it is modified in place.

    Returns:
        The same ``result_json`` object.
    """
    if '[rejected]' in line or '[detected]' in line:
        detected_count = result_json.get('detected_count', 0)
        rejected_count = result_json.get('rejected_count', 0)
        fields = line.split(', ')
        # Split only on the first ':' so that a path which itself contains
        # a colon (drive letter, "rec:01.wav") is not cut short.
        file_name = os.path.basename(fields[1].split(':', 1)[1])

        if '[detected]' in line:
            # [detected], fname:/xxx/.tmp_pos_testsets/pos_testsets/33.wav,
            # kw:小云小云, confidence:0.965155, time:[4.62-5.10], threshold:0.00,
            detected_count += 1
            keyword = fields[2].split(':', 1)[1]
            confidence = fields[3].split(':', 1)[1]

            keywords_list = result_json.setdefault('keywords', [])
            if keyword not in keywords_list:
                keywords_list.append(keyword)

            result_json.setdefault('detected', []).append(
                {file_name: {
                    'confidence': confidence,
                    'keyword': keyword
                }})
        else:
            # [rejected], fname:/xxx/.tmp_pos_testsets/pos_testsets/28.wav
            rejected_count += 1
            # The file name is the last field here, so it still carries the
            # line terminator.
            file_name = file_name.strip().replace('\n', '').replace('\r', '')
            result_json.setdefault('rejected', []).append(file_name)

        result_json['detected_count'] = detected_count
        result_json['rejected_count'] = rejected_count

    elif 'total_proc_time=' in line and 'wav_time=' in line:
        # eg: total_proc_time=0.289000(s), wav_time=20.944125(s), kwsbp_rtf=0.013799
        fields = line.split('), ')
        wav_time_str = fields[1].split('=')[1].split('(')[0]
        wav_time = round(float(wav_time_str), 6)
        result_json['wav_time'] = result_json.get('wav_time', 0) + wav_time

    return result_json
def _generate_roc_list(self, start: float, step: float, end: float,
                       keyword: str, pos_inputs: Dict[str, Any],
                       neg_inputs: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Build the ROC points of one keyword over a threshold sweep.

    For every threshold ``start, start + step, ...`` up to and including
    ``end`` the result holds one dict::

        {'threshold': 0.001, 'recall': 0.999888, 'fa_per_hour': 1.999999}

    ``recall`` is the share of positive wavs whose detection of ``keyword``
    has confidence >= threshold; ``fa_per_hour`` is the number of such
    detections in the negative set divided by its duration in hours.

    Args:
        start: first threshold.
        step: threshold increment, must be positive.
        end: last threshold (inclusive).
        keyword: only detections of this keyword are counted.
        pos_inputs: positive-set result; reads ``wav_count`` and, when
            present, ``detected`` — a list of
            ``{file name: {'confidence': str, 'keyword': str}}``.
        neg_inputs: negative-set result; reads ``wav_time`` (seconds) and
            ``detected`` when present.

    Returns:
        List of ROC points ordered by ascending threshold; empty when
        ``end < start``.

    Raises:
        ValueError: if ``step`` is not positive (the sweep would never end).
    """
    if step <= 0:
        raise ValueError(f'step must be positive, got {step}')
    if end < start:
        return []

    def _confidences(result: Dict[str, Any]) -> List[float]:
        # 'detected' is absent when nothing was detected at all, which is
        # the normal case for a clean negative set.
        hits = (next(iter(item.values()))
                for item in result.get('detected', []))
        return [float(hit['confidence']) for hit in hits
                if hit['keyword'] == keyword]

    det_confidences = _confidences(pos_inputs)
    fa_confidences = _confidences(neg_inputs)
    pos_wav_count = pos_inputs['wav_count']
    neg_hours = neg_inputs.get('wav_time', 0) / 3600

    # Derive each threshold from its index instead of repeatedly adding
    # `step`: accumulated float error otherwise shifts the thresholds and
    # can drop the inclusive end point.
    point_count = int((end - start) / step + 1e-9) + 1

    output = []
    for index in range(point_count):
        threshold_cur = start + index * step
        det_count = sum(1 for c in det_confidences if c >= threshold_cur)
        fa_count = sum(1 for c in fa_confidences if c >= threshold_cur)
        output.append({
            'threshold': round(threshold_cur, 3),
            # An empty set yields 0.0 rather than a ZeroDivisionError.
            'recall': round(det_count / pos_wav_count, 6)
            if pos_wav_count else 0.0,
            'fa_per_hour': round(fa_count / neg_hours, 6)
            if neg_hours else 0.0,
        })
    return output
def _set_customized_keywords(self) -> Dict[str, Any]:
    """Convert ``self._keywords`` into the ``word_list`` structure.

    Every entry of ``self._keywords`` yields one item:

    * ``'keyword'`` becomes ``'name'`` with a single space inserted between
      its characters (``'ab'`` -> ``'a b'``);
    * ``'threshold'`` is copied to ``'threshold1'``.

    Keys an entry does not have are simply left out of its item.

    Returns:
        ``{'word_list': [...]}``, or an empty dict when no keywords are
        set. An empty dict replaces the former ``''`` so the result always
        matches the annotation; ``len()`` of either is 0.
    """
    if self._keywords is None:
        return {}

    word_list = []
    for entry in self._keywords:
        new_item = {}
        if 'keyword' in entry:
            new_item['name'] = ' '.join(entry['keyword']).strip()
        if 'threshold' in entry:
            new_item['threshold1'] = entry['threshold']
        word_list.append(new_item)
    return {'word_list': word_list}
@@ -1,8 +1,5 @@ | |||
import os | |||
import shutil | |||
import stat | |||
from pathlib import Path | |||
from typing import Any, Dict, List | |||
from typing import Any, Dict, List, Union | |||
import yaml | |||
@@ -19,48 +16,28 @@ __all__ = ['WavToLists'] | |||
Fields.audio, module_name=Preprocessors.wav_to_lists) | |||
class WavToLists(Preprocessor): | |||
"""generate audio lists file from wav | |||
Args: | |||
workspace (str): store temporarily kws intermedium and result | |||
""" | |||
def __init__(self, workspace: str = None): | |||
# the workspace path | |||
if len(workspace) == 0: | |||
self._workspace = os.path.join(os.getcwd(), '.tmp') | |||
else: | |||
self._workspace = workspace | |||
if not os.path.exists(self._workspace): | |||
os.mkdir(self._workspace) | |||
def __init__(self): | |||
pass | |||
def __call__(self, | |||
model: Model = None, | |||
kws_type: str = None, | |||
wav_path: List[str] = None) -> Dict[str, Any]: | |||
def __call__(self, model: Model, wav_path: Union[List[str], | |||
str]) -> Dict[str, Any]: | |||
"""Call functions to load model and wav. | |||
Args: | |||
model (Model): model should be provided | |||
kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc | |||
wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path | |||
wav_path (Union[List[str], str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path | |||
Returns: | |||
Dict[str, Any]: the kws result | |||
""" | |||
assert model is not None, 'preprocess kws model should be provided' | |||
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc' | |||
], f'preprocess kws_type {kws_type} is invalid' | |||
assert wav_path[0] is not None or wav_path[ | |||
1] is not None, 'preprocess wav_path is invalid' | |||
self._model = model | |||
out = self.forward(self._model.forward(), kws_type, wav_path) | |||
self.model = model | |||
out = self.forward(self.model.forward(), wav_path) | |||
return out | |||
def forward(self, model: Dict[str, Any], kws_type: str, | |||
wav_path: List[str]) -> Dict[str, Any]: | |||
assert len(kws_type) > 0, 'preprocess kws_type is empty' | |||
def forward(self, model: Dict[str, Any], | |||
wav_path: Union[List[str], str]) -> Dict[str, Any]: | |||
assert len( | |||
model['config_path']) > 0, 'preprocess model[config_path] is empty' | |||
assert os.path.exists( | |||
@@ -68,19 +45,29 @@ class WavToLists(Preprocessor): | |||
inputs = model.copy() | |||
wav_list = [None, None] | |||
if isinstance(wav_path, str): | |||
wav_list[0] = wav_path | |||
else: | |||
wav_list = wav_path | |||
import kws_util.common | |||
kws_type = kws_util.common.type_checking(wav_list) | |||
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc' | |||
], f'preprocess kws_type {kws_type} is invalid' | |||
inputs['kws_set'] = kws_type | |||
inputs['workspace'] = self._workspace | |||
if wav_path[0] is not None: | |||
inputs['pos_wav_path'] = wav_path[0] | |||
if wav_path[1] is not None: | |||
inputs['neg_wav_path'] = wav_path[1] | |||
if wav_list[0] is not None: | |||
inputs['pos_wav_path'] = wav_list[0] | |||
if wav_list[1] is not None: | |||
inputs['neg_wav_path'] = wav_list[1] | |||
out = self._read_config(inputs) | |||
out = self._generate_wav_lists(out) | |||
out = self.read_config(inputs) | |||
out = self.generate_wav_lists(out) | |||
return out | |||
def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
def read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
"""read and parse config.yaml to get all model files | |||
""" | |||
@@ -97,157 +84,51 @@ class WavToLists(Preprocessor): | |||
inputs['keyword_grammar'] = root['keyword_grammar'] | |||
inputs['keyword_grammar_path'] = os.path.join( | |||
inputs['model_workspace'], root['keyword_grammar']) | |||
inputs['sample_rate'] = str(root['sample_rate']) | |||
inputs['kws_tool'] = root['kws_tool'] | |||
if os.path.exists( | |||
os.path.join(inputs['workspace'], inputs['kws_tool'])): | |||
inputs['kws_tool_path'] = os.path.join(inputs['workspace'], | |||
inputs['kws_tool']) | |||
elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])): | |||
inputs['kws_tool_path'] = os.path.join('/usr/bin', | |||
inputs['kws_tool']) | |||
elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])): | |||
inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool']) | |||
assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp' | |||
os.chmod(inputs['kws_tool_path'], | |||
stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH) | |||
self._config_checking(inputs) | |||
inputs['sample_rate'] = root['sample_rate'] | |||
return inputs | |||
def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
def generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
"""assemble wav lists | |||
""" | |||
import kws_util.common | |||
if inputs['kws_set'] == 'wav': | |||
inputs['pos_num_thread'] = 1 | |||
wave_scp_content: str = inputs['pos_wav_path'] + '\n' | |||
with open(os.path.join(inputs['pos_data_path'], 'wave.list'), | |||
'a') as f: | |||
f.write(wave_scp_content) | |||
wav_list = [] | |||
wave_scp_content: str = inputs['pos_wav_path'] | |||
wav_list.append(wave_scp_content) | |||
inputs['pos_wav_list'] = wav_list | |||
inputs['pos_wav_count'] = 1 | |||
inputs['pos_num_thread'] = 1 | |||
if inputs['kws_set'] in ['pos_testsets', 'roc']: | |||
# find all positive wave | |||
wav_list = [] | |||
wav_dir = inputs['pos_wav_path'] | |||
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) | |||
wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir) | |||
inputs['pos_wav_list'] = wav_list | |||
list_count: int = len(wav_list) | |||
inputs['pos_wav_count'] = list_count | |||
if list_count <= 128: | |||
inputs['pos_num_thread'] = list_count | |||
j: int = 0 | |||
while j < list_count: | |||
wave_scp_content: str = wav_list[j] + '\n' | |||
wav_list_path = inputs['pos_data_path'] + '/wave.' + str( | |||
j) + '.list' | |||
with open(wav_list_path, 'a') as f: | |||
f.write(wave_scp_content) | |||
j += 1 | |||
else: | |||
inputs['pos_num_thread'] = 128 | |||
j: int = 0 | |||
k: int = 0 | |||
while j < list_count: | |||
wave_scp_content: str = wav_list[j] + '\n' | |||
wav_list_path = inputs['pos_data_path'] + '/wave.' + str( | |||
k) + '.list' | |||
with open(wav_list_path, 'a') as f: | |||
f.write(wave_scp_content) | |||
j += 1 | |||
k += 1 | |||
if k >= 128: | |||
k = 0 | |||
if inputs['kws_set'] in ['neg_testsets', 'roc']: | |||
# find all negative wave | |||
wav_list = [] | |||
wav_dir = inputs['neg_wav_path'] | |||
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) | |||
wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir) | |||
inputs['neg_wav_list'] = wav_list | |||
list_count: int = len(wav_list) | |||
inputs['neg_wav_count'] = list_count | |||
if list_count <= 128: | |||
inputs['neg_num_thread'] = list_count | |||
j: int = 0 | |||
while j < list_count: | |||
wave_scp_content: str = wav_list[j] + '\n' | |||
wav_list_path = inputs['neg_data_path'] + '/wave.' + str( | |||
j) + '.list' | |||
with open(wav_list_path, 'a') as f: | |||
f.write(wave_scp_content) | |||
j += 1 | |||
else: | |||
inputs['neg_num_thread'] = 128 | |||
j: int = 0 | |||
k: int = 0 | |||
while j < list_count: | |||
wave_scp_content: str = wav_list[j] + '\n' | |||
wav_list_path = inputs['neg_data_path'] + '/wave.' + str( | |||
k) + '.list' | |||
with open(wav_list_path, 'a') as f: | |||
f.write(wave_scp_content) | |||
j += 1 | |||
k += 1 | |||
if k >= 128: | |||
k = 0 | |||
return inputs | |||
def _recursion_dir_all_wave(self, wav_list, dir_path: str) -> List[str]:
    """Collect the paths of all wav files below ``dir_path``.

    Files whose path ends in ``.wav`` or ``.WAV`` are appended to
    ``wav_list``; sub-directories are searched recursively.

    Args:
        wav_list: list that receives the paths; it is extended in place.
        dir_path: directory to search.

    Returns:
        The same ``wav_list`` object.
    """
    # Sorted so the resulting order is reproducible instead of following
    # whatever order os.listdir() returns.
    for name in sorted(os.listdir(dir_path)):
        file_path = os.path.join(dir_path, name)
        if os.path.isfile(file_path):
            if file_path.endswith(('.wav', '.WAV')):
                wav_list.append(file_path)
        elif os.path.isdir(file_path):
            self._recursion_dir_all_wave(wav_list, file_path)
    return wav_list
def _config_checking(self, inputs: Dict[str, Any]):
    """Create empty data and dump directories for the selected kws mode.

    Depending on ``inputs['kws_set']`` the following keys are written to
    ``inputs`` and the matching directories under ``inputs['workspace']``
    are (re)created empty:

    * ``wav`` / ``pos_testsets`` / ``roc``: ``pos_data_path``,
      ``pos_dump_path``
    * ``neg_testsets`` / ``roc``: ``neg_data_path``, ``neg_dump_path``

    A directory that already exists is removed first, so nothing from an
    earlier run survives.

    Args:
        inputs: pipeline state; reads ``kws_set`` and ``workspace``, and is
            modified in place.
    """
    kws_set = inputs['kws_set']
    prefixes = []
    if kws_set in ['wav', 'pos_testsets', 'roc']:
        prefixes.append('pos')
    if kws_set in ['neg_testsets', 'roc']:
        prefixes.append('neg')

    for prefix in prefixes:
        for kind in ('data', 'dump'):
            path = os.path.join(inputs['workspace'], f'{prefix}_{kind}')
            inputs[f'{prefix}_{kind}_path'] = path
            if os.path.exists(path):
                shutil.rmtree(path)
            # makedirs also copes with a workspace that does not exist yet.
            os.makedirs(path)
@@ -218,3 +218,11 @@ class TrainerStages: | |||
after_val_iter = 'after_val_iter' | |||
after_val_epoch = 'after_val_epoch' | |||
after_run = 'after_run' | |||
class ColorCodes:
    """ANSI escape sequences for colouring terminal output.

    Prefix a message with one of the colour codes and terminate it with
    ``END`` to reset the terminal attributes again.
    """
    MAGENTA = '\033[95m'
    YELLOW = '\033[93m'
    GREEN = '\033[92m'
    RED = '\033[91m'
    # Resets all attributes set by the codes above.
    END = '\033[0m'
@@ -2,9 +2,11 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import tarfile | |||
import unittest | |||
import numpy as np | |||
import requests | |||
from datasets import Dataset | |||
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE | |||
@@ -42,3 +44,21 @@ def set_test_level(level: int): | |||
def create_dummy_test_dataset(feat, label, num): | |||
return MsDataset.from_hf_dataset( | |||
Dataset.from_dict(dict(feat=[feat] * num, label=[label] * num))) | |||
def download_and_untar(fpath, furl, dst) -> str:
    """Fetch a tar archive if it is not cached yet and extract it.

    Args:
        fpath: local path of the archive; when it does not exist the
            archive is downloaded from ``furl`` to this path first.
        furl: URL of the archive.
        dst: directory the archive is extracted into.

    Returns:
        ``fpath`` with up to two extensions removed, e.g.
        ``/tmp/pos_testsets`` for ``/tmp/pos_testsets.tar.gz``.
        NOTE(review): this is derived from ``fpath`` only, so it points at
        the extracted data only when ``dst`` is the directory containing
        ``fpath`` — confirm against callers.

    Raises:
        requests.HTTPError: if the download answers with an error status.
    """
    if not os.path.exists(fpath):
        r = requests.get(furl)
        # Fail before anything is written: otherwise an error page would be
        # cached at `fpath` and every later run would skip the download and
        # choke on the bogus archive.
        r.raise_for_status()
        with open(fpath, 'wb') as f:
            f.write(r.content)

    file_name = os.path.basename(fpath)
    root_dir = os.path.dirname(fpath)
    target_dir_name = os.path.splitext(os.path.splitext(file_name)[0])[0]
    target_dir_path = os.path.join(root_dir, target_dir_name)

    # untar the file; the context manager closes the archive handle.
    # NOTE(review): extractall() trusts member paths — only use with
    # archives from a trusted source.
    with tarfile.open(fpath) as t:
        t.extractall(path=dst)

    return target_dir_path
@@ -3,6 +3,7 @@ espnet>=202204 | |||
h5py | |||
inflect | |||
keras | |||
kwsbp | |||
librosa | |||
lxml | |||
matplotlib | |||
@@ -3,14 +3,16 @@ import os | |||
import shutil | |||
import tarfile | |||
import unittest | |||
from typing import Any, Dict, List, Union | |||
import requests | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
from modelscope.utils.constant import ColorCodes, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import download_and_untar, test_level | |||
KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' | |||
logger = get_logger() | |||
POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' | |||
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' | |||
@@ -22,12 +24,102 @@ NEG_TESTSETS_FILE = 'neg_testsets.tar.gz' | |||
NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz' | |||
def un_tar_gz(fname, dirs):
    """Extract the tar archive ``fname`` into the directory ``dirs``."""
    # The context manager closes the archive handle, which the bare
    # open()/extractall() pair never did.
    with tarfile.open(fname) as t:
        t.extractall(path=dirs)
class KeyWordSpottingTest(unittest.TestCase): | |||
action_info = { | |||
'test_run_with_wav': { | |||
'checking_item': 'kws_list', | |||
'checking_value': '小云小云', | |||
'example': { | |||
'wav_count': | |||
1, | |||
'kws_set': | |||
'wav', | |||
'kws_list': [{ | |||
'keyword': '小云小云', | |||
'offset': 5.76, | |||
'length': 9.132938, | |||
'confidence': 0.990368 | |||
}] | |||
} | |||
}, | |||
'test_run_with_wav_by_customized_keywords': { | |||
'checking_item': 'kws_list', | |||
'checking_value': '播放音乐', | |||
'example': { | |||
'wav_count': | |||
1, | |||
'kws_set': | |||
'wav', | |||
'kws_list': [{ | |||
'keyword': '播放音乐', | |||
'offset': 0.87, | |||
'length': 2.158313, | |||
'confidence': 0.646237 | |||
}] | |||
} | |||
}, | |||
'test_run_with_pos_testsets': { | |||
'checking_item': 'recall', | |||
'example': { | |||
'wav_count': 450, | |||
'kws_set': 'pos_testsets', | |||
'wav_time': 3013.75925, | |||
'keywords': ['小云小云'], | |||
'recall': 0.953333, | |||
'detected_count': 429, | |||
'rejected_count': 21, | |||
'rejected': ['yyy.wav', 'zzz.wav'] | |||
} | |||
}, | |||
'test_run_with_neg_testsets': { | |||
'checking_item': 'fa_rate', | |||
'example': { | |||
'wav_count': | |||
751, | |||
'kws_set': | |||
'neg_testsets', | |||
'wav_time': | |||
3572.180813, | |||
'keywords': ['小云小云'], | |||
'fa_rate': | |||
0.001332, | |||
'fa_per_hour': | |||
1.007788, | |||
'detected_count': | |||
1, | |||
'rejected_count': | |||
750, | |||
'detected': [{ | |||
'6.wav': { | |||
'confidence': '0.321170', | |||
'keyword': '小云小云' | |||
} | |||
}] | |||
} | |||
}, | |||
'test_run_with_roc': { | |||
'checking_item': 'keywords', | |||
'checking_value': '小云小云', | |||
'example': { | |||
'kws_set': | |||
'roc', | |||
'keywords': ['小云小云'], | |||
'小云小云': [{ | |||
'threshold': 0.0, | |||
'recall': 0.953333, | |||
'fa_per_hour': 1.007788 | |||
}, { | |||
'threshold': 0.001, | |||
'recall': 0.953333, | |||
'fa_per_hour': 1.007788 | |||
}, { | |||
'threshold': 0.999, | |||
'recall': 0.004444, | |||
'fa_per_hour': 0.0 | |||
}] | |||
} | |||
} | |||
} | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun' | |||
@@ -36,315 +128,121 @@ class KeyWordSpottingTest(unittest.TestCase): | |||
os.mkdir(self.workspace) | |||
def tearDown(self) -> None: | |||
# remove workspace dir (.tmp) | |||
if os.path.exists(self.workspace): | |||
shutil.rmtree(os.path.join(self.workspace), ignore_errors=True) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_wav(self): | |||
# wav, neg_testsets, pos_testsets, roc | |||
kws_set = 'wav' | |||
# get wav file | |||
wav_file_path = POS_WAV_FILE | |||
# downloading kwsbp | |||
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||
if not os.path.exists(kwsbp_file_path): | |||
r = requests.get(KWSBP_URL) | |||
with open(kwsbp_file_path, 'wb') as f: | |||
f.write(r.content) | |||
shutil.rmtree(self.workspace, ignore_errors=True) | |||
def run_pipeline(self, | |||
model_id: str, | |||
wav_path: Union[List[str], str], | |||
keywords: List[str] = None) -> Dict[str, Any]: | |||
kwsbp_16k_pipline = pipeline( | |||
task=Tasks.auto_speech_recognition, model=self.model_id) | |||
self.assertTrue(kwsbp_16k_pipline is not None) | |||
task=Tasks.auto_speech_recognition, model=model_id) | |||
kws_result = kwsbp_16k_pipline(wav_path=wav_path, keywords=keywords) | |||
return kws_result | |||
def print_error(self, functions: str, result: Dict[str, Any]) -> None:
    """Log a failed result check and abort the calling test.

    Args:
        functions: Name of the test case being checked; it is also the key
            used to look up the expected example in ``self.action_info``.
        result: The result dict that failed the check.

    Raises:
        ValueError: Always, once the diagnostics have been logged.
    """
    logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                 + ColorCodes.END)
    # Log what was actually received so the mismatch can be diagnosed from
    # the log alone, without re-running the pipeline.
    logger.error(ColorCodes.MAGENTA + functions + ' actual result: '
                 + ColorCodes.YELLOW + str(result) + ColorCodes.END)
    logger.error(ColorCodes.MAGENTA + functions
                 + ' correct result example: ' + ColorCodes.YELLOW
                 + str(self.action_info[functions]['example'])
                 + ColorCodes.END)
    raise ValueError('kws result is mismatched')
def check_and_print_result(self, functions: str,
                           result: Dict[str, Any]) -> None:
    """Validate a pipeline result against ``self.action_info`` and log it.

    The entry for ``functions`` names a ``checking_item`` key that must be
    present in ``result``. For the wav and roc test cases the first element
    of that item is also compared with the expected ``checking_value``.

    Args:
        functions: Name of the test case; key into ``self.action_info``.
        result: The dict returned by the pipeline.

    Raises:
        ValueError: Raised through ``print_error`` when the required item
            is missing, is empty, or does not match the expected keyword.
    """
    expected = self.action_info[functions]
    item_key = expected['checking_item']
    if item_key not in result:
        self.print_error(functions, result)
        return

    checking_item = result[item_key]
    if functions == 'test_run_with_roc':
        # checking_item is the list of keyword strings; an empty list is
        # reported as a mismatch instead of raising IndexError.
        found = checking_item[0] if checking_item else None
        if found != expected['checking_value']:
            self.print_error(functions, result)
    elif functions in ('test_run_with_wav',
                       'test_run_with_wav_by_customized_keywords'):
        # checking_item is the list of detections; an empty list (nothing
        # detected) is reported as a mismatch instead of raising IndexError.
        found = checking_item[0].get('keyword') if checking_item else None
        if found != expected['checking_value']:
            self.print_error(functions, result)

    logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                + ColorCodes.END)
    if functions == 'test_run_with_roc':
        # The per-threshold list is stored under the keyword itself.
        find_keyword = result['keywords'][0]
        for item in result[find_keyword]:
            logger.info(ColorCodes.YELLOW + ' threshold:'
                        + str(item['threshold']) + ' recall:'
                        + str(item['recall']) + ' fa_per_hour:'
                        + str(item['fa_per_hour']) + ColorCodes.END)
    else:
        logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)
kws_result = kwsbp_16k_pipline( | |||
kws_type=kws_set, | |||
wav_path=[wav_file_path, None], | |||
workspace=self.workspace) | |||
self.assertTrue(kws_result.__contains__('detected')) | |||
""" | |||
kws result json format example: | |||
{ | |||
'wav_count': 1, | |||
'kws_set': 'wav', | |||
'wav_time': 9.132938, | |||
'keywords': ['小云小云'], | |||
'detected': True, | |||
'confidence': 0.990368 | |||
} | |||
""" | |||
if kws_result.__contains__('keywords'): | |||
print('test_run_with_wav keywords: ', kws_result['keywords']) | |||
print('test_run_with_wav confidence: ', kws_result['confidence']) | |||
print('test_run_with_wav detected result: ', kws_result['detected']) | |||
print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
    """Run the pipeline on POS_WAV_FILE with no custom keywords and check it."""
    result = self.run_pipeline(wav_path=POS_WAV_FILE, model_id=self.model_id)
    self.check_and_print_result('test_run_with_wav', result)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_wav_by_customized_keywords(self): | |||
# wav, neg_testsets, pos_testsets, roc | |||
kws_set = 'wav' | |||
# get wav file | |||
wav_file_path = BOFANGYINYUE_WAV_FILE | |||
# downloading kwsbp | |||
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||
if not os.path.exists(kwsbp_file_path): | |||
r = requests.get(KWSBP_URL) | |||
with open(kwsbp_file_path, 'wb') as f: | |||
f.write(r.content) | |||
# customized keyword if you need. | |||
# full settings eg. | |||
# keywords = [ | |||
# {'keyword':'你好电视', 'threshold': 0.008}, | |||
# {'keyword':'播放音乐', 'threshold': 0.008} | |||
# ] | |||
keywords = [{'keyword': '播放音乐'}] | |||
kwsbp_16k_pipline = pipeline( | |||
task=Tasks.auto_speech_recognition, | |||
model=self.model_id, | |||
kws_result = self.run_pipeline( | |||
model_id=self.model_id, | |||
wav_path=BOFANGYINYUE_WAV_FILE, | |||
keywords=keywords) | |||
self.assertTrue(kwsbp_16k_pipline is not None) | |||
kws_result = kwsbp_16k_pipline( | |||
kws_type=kws_set, | |||
wav_path=[wav_file_path, None], | |||
workspace=self.workspace) | |||
self.assertTrue(kws_result.__contains__('detected')) | |||
""" | |||
kws result json format example: | |||
{ | |||
'wav_count': 1, | |||
'kws_set': 'wav', | |||
'wav_time': 9.132938, | |||
'keywords': ['播放音乐'], | |||
'detected': True, | |||
'confidence': 0.660368 | |||
} | |||
""" | |||
if kws_result.__contains__('keywords'): | |||
print('test_run_with_wav_by_customized_keywords keywords: ', | |||
kws_result['keywords']) | |||
print('test_run_with_wav_by_customized_keywords confidence: ', | |||
kws_result['confidence']) | |||
print('test_run_with_wav_by_customized_keywords detected result: ', | |||
kws_result['detected']) | |||
print('test_run_with_wav_by_customized_keywords wave time(seconds): ', | |||
kws_result['wav_time']) | |||
self.check_and_print_result('test_run_with_wav_by_customized_keywords', | |||
kws_result) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_pos_testsets(self): | |||
# wav, neg_testsets, pos_testsets, roc | |||
kws_set = 'pos_testsets' | |||
# downloading pos_testsets file | |||
testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE) | |||
if not os.path.exists(testsets_file_path): | |||
r = requests.get(POS_TESTSETS_URL) | |||
with open(testsets_file_path, 'wb') as f: | |||
f.write(r.content) | |||
wav_file_path = download_and_untar( | |||
os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL, | |||
self.workspace) | |||
wav_path = [wav_file_path, None] | |||
testsets_dir_name = os.path.splitext( | |||
os.path.basename(POS_TESTSETS_FILE))[0] | |||
testsets_dir_name = os.path.splitext( | |||
os.path.basename(testsets_dir_name))[0] | |||
# wav_file_path = <cwd>/.tmp_pos_testsets/pos_testsets/ | |||
wav_file_path = os.path.join(self.workspace, testsets_dir_name) | |||
# untar the pos_testsets file | |||
if not os.path.exists(wav_file_path): | |||
un_tar_gz(testsets_file_path, self.workspace) | |||
# downloading kwsbp -- a kws batch processing tool | |||
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||
if not os.path.exists(kwsbp_file_path): | |||
r = requests.get(KWSBP_URL) | |||
with open(kwsbp_file_path, 'wb') as f: | |||
f.write(r.content) | |||
kwsbp_16k_pipline = pipeline( | |||
task=Tasks.auto_speech_recognition, model=self.model_id) | |||
self.assertTrue(kwsbp_16k_pipline is not None) | |||
kws_result = kwsbp_16k_pipline( | |||
kws_type=kws_set, | |||
wav_path=[wav_file_path, None], | |||
workspace=self.workspace) | |||
self.assertTrue(kws_result.__contains__('recall')) | |||
""" | |||
kws result json format example: | |||
{ | |||
'wav_count': 450, | |||
'kws_set': 'pos_testsets', | |||
'wav_time': 3013.759254, | |||
'keywords': ["小云小云"], | |||
'recall': 0.953333, | |||
'detected_count': 429, | |||
'rejected_count': 21, | |||
'rejected': [ | |||
'yyy.wav', | |||
'zzz.wav', | |||
...... | |||
] | |||
} | |||
""" | |||
if kws_result.__contains__('keywords'): | |||
print('test_run_with_pos_testsets keywords: ', | |||
kws_result['keywords']) | |||
print('test_run_with_pos_testsets recall: ', kws_result['recall']) | |||
print('test_run_with_pos_testsets wave time(seconds): ', | |||
kws_result['wav_time']) | |||
kws_result = self.run_pipeline( | |||
model_id=self.model_id, wav_path=wav_path) | |||
self.check_and_print_result('test_run_with_pos_testsets', kws_result) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_neg_testsets(self): | |||
# wav, neg_testsets, pos_testsets, roc | |||
kws_set = 'neg_testsets' | |||
# downloading neg_testsets file | |||
testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE) | |||
if not os.path.exists(testsets_file_path): | |||
r = requests.get(NEG_TESTSETS_URL) | |||
with open(testsets_file_path, 'wb') as f: | |||
f.write(r.content) | |||
testsets_dir_name = os.path.splitext( | |||
os.path.basename(NEG_TESTSETS_FILE))[0] | |||
testsets_dir_name = os.path.splitext( | |||
os.path.basename(testsets_dir_name))[0] | |||
# wav_file_path = <cwd>/.tmp_neg_testsets/neg_testsets/ | |||
wav_file_path = os.path.join(self.workspace, testsets_dir_name) | |||
# untar the neg_testsets file | |||
if not os.path.exists(wav_file_path): | |||
un_tar_gz(testsets_file_path, self.workspace) | |||
# downloading kwsbp -- a kws batch processing tool | |||
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||
if not os.path.exists(kwsbp_file_path): | |||
r = requests.get(KWSBP_URL) | |||
with open(kwsbp_file_path, 'wb') as f: | |||
f.write(r.content) | |||
kwsbp_16k_pipline = pipeline( | |||
task=Tasks.auto_speech_recognition, model=self.model_id) | |||
self.assertTrue(kwsbp_16k_pipline is not None) | |||
wav_file_path = download_and_untar( | |||
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, | |||
self.workspace) | |||
wav_path = [None, wav_file_path] | |||
kws_result = kwsbp_16k_pipline( | |||
kws_type=kws_set, | |||
wav_path=[None, wav_file_path], | |||
workspace=self.workspace) | |||
self.assertTrue(kws_result.__contains__('fa_rate')) | |||
""" | |||
kws result json format example: | |||
{ | |||
'wav_count': 751, | |||
'kws_set': 'neg_testsets', | |||
'wav_time': 3572.180812, | |||
'keywords': ['小云小云'], | |||
'fa_rate': 0.001332, | |||
'fa_per_hour': 1.007788, | |||
'detected_count': 1, | |||
'rejected_count': 750, | |||
'detected': [ | |||
{ | |||
'6.wav': { | |||
'confidence': '0.321170' | |||
} | |||
} | |||
] | |||
} | |||
""" | |||
if kws_result.__contains__('keywords'): | |||
print('test_run_with_neg_testsets keywords: ', | |||
kws_result['keywords']) | |||
print('test_run_with_neg_testsets fa rate: ', kws_result['fa_rate']) | |||
print('test_run_with_neg_testsets fa per hour: ', | |||
kws_result['fa_per_hour']) | |||
print('test_run_with_neg_testsets wave time(seconds): ', | |||
kws_result['wav_time']) | |||
kws_result = self.run_pipeline( | |||
model_id=self.model_id, wav_path=wav_path) | |||
self.check_and_print_result('test_run_with_neg_testsets', kws_result) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_roc(self): | |||
# wav, neg_testsets, pos_testsets, roc | |||
kws_set = 'roc' | |||
# downloading neg_testsets file | |||
testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE) | |||
if not os.path.exists(testsets_file_path): | |||
r = requests.get(NEG_TESTSETS_URL) | |||
with open(testsets_file_path, 'wb') as f: | |||
f.write(r.content) | |||
testsets_dir_name = os.path.splitext( | |||
os.path.basename(NEG_TESTSETS_FILE))[0] | |||
testsets_dir_name = os.path.splitext( | |||
os.path.basename(testsets_dir_name))[0] | |||
# neg_file_path = <workspace>/.tmp_roc/neg_testsets/ | |||
neg_file_path = os.path.join(self.workspace, testsets_dir_name) | |||
# untar the neg_testsets file | |||
if not os.path.exists(neg_file_path): | |||
un_tar_gz(testsets_file_path, self.workspace) | |||
# downloading pos_testsets file | |||
testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE) | |||
if not os.path.exists(testsets_file_path): | |||
r = requests.get(POS_TESTSETS_URL) | |||
with open(testsets_file_path, 'wb') as f: | |||
f.write(r.content) | |||
testsets_dir_name = os.path.splitext( | |||
os.path.basename(POS_TESTSETS_FILE))[0] | |||
testsets_dir_name = os.path.splitext( | |||
os.path.basename(testsets_dir_name))[0] | |||
# pos_file_path = <workspace>/.tmp_roc/pos_testsets/ | |||
pos_file_path = os.path.join(self.workspace, testsets_dir_name) | |||
# untar the pos_testsets file | |||
if not os.path.exists(pos_file_path): | |||
un_tar_gz(testsets_file_path, self.workspace) | |||
# downloading kwsbp -- a kws batch processing tool | |||
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||
if not os.path.exists(kwsbp_file_path): | |||
r = requests.get(KWSBP_URL) | |||
with open(kwsbp_file_path, 'wb') as f: | |||
f.write(r.content) | |||
kwsbp_16k_pipline = pipeline( | |||
task=Tasks.auto_speech_recognition, model=self.model_id) | |||
self.assertTrue(kwsbp_16k_pipline is not None) | |||
kws_result = kwsbp_16k_pipline( | |||
kws_type=kws_set, | |||
wav_path=[pos_file_path, neg_file_path], | |||
workspace=self.workspace) | |||
""" | |||
kws result json format example: | |||
{ | |||
'kws_set': 'roc', | |||
'keywords': ['小云小云'], | |||
'小云小云': [ | |||
{'threshold': 0.0, 'recall': 0.953333, 'fa_per_hour': 1.007788}, | |||
{'threshold': 0.001, 'recall': 0.953333, 'fa_per_hour': 1.007788}, | |||
...... | |||
{'threshold': 0.999, 'recall': 0.004444, 'fa_per_hour': 0.0} | |||
] | |||
} | |||
""" | |||
if kws_result.__contains__('keywords'): | |||
find_keyword = kws_result['keywords'][0] | |||
print('test_run_with_roc keywords: ', find_keyword) | |||
keyword_list = kws_result[find_keyword] | |||
for item in iter(keyword_list): | |||
threshold: float = item['threshold'] | |||
recall: float = item['recall'] | |||
fa_per_hour: float = item['fa_per_hour'] | |||
print(' threshold:', threshold, ' recall:', recall, | |||
' fa_per_hour:', fa_per_hour) | |||
pos_file_path = download_and_untar( | |||
os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL, | |||
self.workspace) | |||
neg_file_path = download_and_untar( | |||
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, | |||
self.workspace) | |||
wav_path = [pos_file_path, neg_file_path] | |||
kws_result = self.run_pipeline( | |||
model_id=self.model_id, wav_path=wav_path) | |||
self.check_and_print_result('test_run_with_roc', kws_result) | |||
if __name__ == '__main__': | |||