Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491681master
@@ -19,7 +19,7 @@ class GenericKeyWordSpotting(Model): | |||||
Args: | Args: | ||||
model_dir (str): the model path. | model_dir (str): the model path. | ||||
""" | """ | ||||
super().__init__(model_dir) | |||||
super().__init__(model_dir, *args, **kwargs) | |||||
self.model_cfg = { | self.model_cfg = { | ||||
'model_workspace': model_dir, | 'model_workspace': model_dir, | ||||
'config_path': os.path.join(model_dir, 'config.yaml') | 'config_path': os.path.join(model_dir, 'config.yaml') | ||||
@@ -1,5 +1,4 @@ | |||||
import os | import os | ||||
import subprocess | |||||
from typing import Any, Dict, List, Union | from typing import Any, Dict, List, Union | ||||
import json | import json | ||||
@@ -10,6 +9,9 @@ from modelscope.pipelines.base import Pipeline | |||||
from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
from modelscope.preprocessors import WavToLists | from modelscope.preprocessors import WavToLists | ||||
from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
from modelscope.utils.logger import get_logger | |||||
logger = get_logger() | |||||
__all__ = ['KeyWordSpottingKwsbpPipeline'] | __all__ = ['KeyWordSpottingKwsbpPipeline'] | ||||
@@ -21,41 +23,24 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
""" | """ | ||||
def __init__(self, | def __init__(self, | ||||
config_file: str = None, | |||||
model: Union[Model, str] = None, | model: Union[Model, str] = None, | ||||
preprocessor: WavToLists = None, | preprocessor: WavToLists = None, | ||||
**kwargs): | **kwargs): | ||||
"""use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
""" | """ | ||||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
Args: | |||||
model: model id on modelscope hub. | |||||
""" | |||||
super().__init__( | |||||
config_file=config_file, | |||||
model=model, | |||||
preprocessor=preprocessor, | |||||
**kwargs) | |||||
assert model is not None, 'kws model should be provided' | |||||
self._preprocessor = preprocessor | |||||
self._keywords = None | |||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
def __call__(self, wav_path: Union[List[str], str], | |||||
**kwargs) -> Dict[str, Any]: | |||||
if 'keywords' in kwargs.keys(): | if 'keywords' in kwargs.keys(): | ||||
self._keywords = kwargs['keywords'] | |||||
def __call__(self, | |||||
kws_type: str, | |||||
wav_path: List[str], | |||||
workspace: str = None) -> Dict[str, Any]: | |||||
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', | |||||
'roc'], f'kws_type {kws_type} is invalid' | |||||
self.keywords = kwargs['keywords'] | |||||
else: | |||||
self.keywords = None | |||||
if self._preprocessor is None: | |||||
self._preprocessor = WavToLists(workspace=workspace) | |||||
if self.preprocessor is None: | |||||
self.preprocessor = WavToLists() | |||||
output = self._preprocessor.forward(self.model.forward(), kws_type, | |||||
wav_path) | |||||
output = self.preprocessor.forward(self.model.forward(), wav_path) | |||||
output = self.forward(output) | output = self.forward(output) | ||||
rst = self.postprocess(output) | rst = self.postprocess(output) | ||||
return rst | return rst | ||||
@@ -64,433 +49,92 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
"""Decoding | """Decoding | ||||
""" | """ | ||||
# will generate kws result into dump/dump.JOB.log | |||||
out = self._run_with_kwsbp(inputs) | |||||
logger.info(f"Decoding with {inputs['kws_set']} mode ...") | |||||
# will generate kws result | |||||
out = self.run_with_kwsbp(inputs) | |||||
return out | return out | ||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
"""process the kws results | """process the kws results | ||||
""" | |||||
pos_result_json = {} | |||||
neg_result_json = {} | |||||
if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: | |||||
self._parse_dump_log(pos_result_json, inputs['pos_dump_path']) | |||||
if inputs['kws_set'] in ['neg_testsets', 'roc']: | |||||
self._parse_dump_log(neg_result_json, inputs['neg_dump_path']) | |||||
""" | |||||
result_json format example: | |||||
{ | |||||
"wav_count": 450, | |||||
"keywords": ["小云小云"], | |||||
"wav_time": 3560.999999, | |||||
"detected": [ | |||||
{ | |||||
"xxx.wav": { | |||||
"confidence": "0.990368", | |||||
"keyword": "小云小云" | |||||
} | |||||
}, | |||||
{ | |||||
"yyy.wav": { | |||||
"confidence": "0.990368", | |||||
"keyword": "小云小云" | |||||
} | |||||
}, | |||||
...... | |||||
], | |||||
"detected_count": 429, | |||||
"rejected_count": 21, | |||||
"rejected": [ | |||||
"yyy.wav", | |||||
"zzz.wav", | |||||
...... | |||||
] | |||||
} | |||||
Args: | |||||
inputs['pos_kws_list'] or inputs['neg_kws_list']: | |||||
result_dict format example: | |||||
[{ | |||||
'confidence': 0.9903678297996521, | |||||
'filename': 'data/test/audios/kws_xiaoyunxiaoyun.wav', | |||||
'keyword': '小云小云', | |||||
'offset': 5.760000228881836, # second | |||||
'rtf_time': 66, # millisecond | |||||
'threshold': 0, | |||||
'wav_time': 9.1329375 # second | |||||
}] | |||||
""" | """ | ||||
rst_dict = {'kws_set': inputs['kws_set']} | |||||
# parsing the result of wav | |||||
if inputs['kws_set'] == 'wav': | |||||
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ | |||||
'pos_wav_count'] | |||||
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) | |||||
if pos_result_json['detected_count'] == 1: | |||||
rst_dict['keywords'] = pos_result_json['keywords'] | |||||
rst_dict['detected'] = True | |||||
wav_file_name = os.path.basename(inputs['pos_wav_path']) | |||||
rst_dict['confidence'] = float(pos_result_json['detected'][0] | |||||
[wav_file_name]['confidence']) | |||||
else: | |||||
rst_dict['detected'] = False | |||||
# parsing the result of pos_tests | |||||
elif inputs['kws_set'] == 'pos_testsets': | |||||
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ | |||||
'pos_wav_count'] | |||||
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) | |||||
if pos_result_json.__contains__('keywords'): | |||||
rst_dict['keywords'] = pos_result_json['keywords'] | |||||
rst_dict['recall'] = round( | |||||
pos_result_json['detected_count'] / rst_dict['wav_count'], 6) | |||||
if pos_result_json.__contains__('detected_count'): | |||||
rst_dict['detected_count'] = pos_result_json['detected_count'] | |||||
if pos_result_json.__contains__('rejected_count'): | |||||
rst_dict['rejected_count'] = pos_result_json['rejected_count'] | |||||
if pos_result_json.__contains__('rejected'): | |||||
rst_dict['rejected'] = pos_result_json['rejected'] | |||||
# parsing the result of neg_tests | |||||
elif inputs['kws_set'] == 'neg_testsets': | |||||
rst_dict['wav_count'] = neg_result_json['wav_count'] = inputs[ | |||||
'neg_wav_count'] | |||||
rst_dict['wav_time'] = round(neg_result_json['wav_time'], 6) | |||||
if neg_result_json.__contains__('keywords'): | |||||
rst_dict['keywords'] = neg_result_json['keywords'] | |||||
rst_dict['fa_rate'] = 0.0 | |||||
rst_dict['fa_per_hour'] = 0.0 | |||||
if neg_result_json.__contains__('detected_count'): | |||||
rst_dict['detected_count'] = neg_result_json['detected_count'] | |||||
rst_dict['fa_rate'] = round( | |||||
neg_result_json['detected_count'] / rst_dict['wav_count'], | |||||
6) | |||||
if neg_result_json.__contains__('wav_time'): | |||||
rst_dict['fa_per_hour'] = round( | |||||
neg_result_json['detected_count'] | |||||
/ float(neg_result_json['wav_time'] / 3600), 6) | |||||
if neg_result_json.__contains__('rejected_count'): | |||||
rst_dict['rejected_count'] = neg_result_json['rejected_count'] | |||||
if neg_result_json.__contains__('detected'): | |||||
rst_dict['detected'] = neg_result_json['detected'] | |||||
# parsing the result of roc | |||||
elif inputs['kws_set'] == 'roc': | |||||
threshold_start = 0.000 | |||||
threshold_step = 0.001 | |||||
threshold_end = 1.000 | |||||
pos_keywords_list = [] | |||||
neg_keywords_list = [] | |||||
if pos_result_json.__contains__('keywords'): | |||||
pos_keywords_list = pos_result_json['keywords'] | |||||
if neg_result_json.__contains__('keywords'): | |||||
neg_keywords_list = neg_result_json['keywords'] | |||||
keywords_list = list(set(pos_keywords_list + neg_keywords_list)) | |||||
pos_result_json['wav_count'] = inputs['pos_wav_count'] | |||||
neg_result_json['wav_count'] = inputs['neg_wav_count'] | |||||
if len(keywords_list) > 0: | |||||
rst_dict['keywords'] = keywords_list | |||||
for index in range(len(rst_dict['keywords'])): | |||||
cur_keyword = rst_dict['keywords'][index] | |||||
output_list = self._generate_roc_list( | |||||
start=threshold_start, | |||||
step=threshold_step, | |||||
end=threshold_end, | |||||
keyword=cur_keyword, | |||||
pos_inputs=pos_result_json, | |||||
neg_inputs=neg_result_json) | |||||
rst_dict[cur_keyword] = output_list | |||||
import kws_util.common | |||||
neg_kws_list = None | |||||
pos_kws_list = None | |||||
if 'pos_kws_list' in inputs: | |||||
pos_kws_list = inputs['pos_kws_list'] | |||||
if 'neg_kws_list' in inputs: | |||||
neg_kws_list = inputs['neg_kws_list'] | |||||
rst_dict = kws_util.common.parsing_kws_result( | |||||
kws_type=inputs['kws_set'], | |||||
pos_list=pos_kws_list, | |||||
neg_list=neg_kws_list) | |||||
return rst_dict | return rst_dict | ||||
def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
opts: str = '' | |||||
def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
cmd = { | |||||
'sys_dir': inputs['model_workspace'], | |||||
'cfg_file': inputs['cfg_file_path'], | |||||
'sample_rate': inputs['sample_rate'], | |||||
'keyword_custom': '' | |||||
} | |||||
import kwsbp | |||||
import kws_util.common | |||||
kws_inference = kwsbp.KwsbpEngine() | |||||
# setting customized keywords | # setting customized keywords | ||||
keywords_json = self._set_customized_keywords() | |||||
if len(keywords_json) > 0: | |||||
keywords_json_file = os.path.join(inputs['workspace'], | |||||
'keyword_custom.json') | |||||
with open(keywords_json_file, 'w') as f: | |||||
json.dump(keywords_json, f) | |||||
opts = '--keyword-custom ' + keywords_json_file | |||||
cmd['customized_keywords'] = kws_util.common.generate_customized_keywords( | |||||
self.keywords) | |||||
if inputs['kws_set'] == 'roc': | if inputs['kws_set'] == 'roc': | ||||
inputs['keyword_grammar_path'] = os.path.join( | inputs['keyword_grammar_path'] = os.path.join( | ||||
inputs['model_workspace'], 'keywords_roc.json') | inputs['model_workspace'], 'keywords_roc.json') | ||||
if inputs['kws_set'] == 'wav': | |||||
dump_log_path: str = os.path.join(inputs['pos_dump_path'], | |||||
'dump.log') | |||||
kws_cmd: str = inputs['kws_tool_path'] + \ | |||||
' --sys-dir=' + inputs['model_workspace'] + \ | |||||
' --cfg-file=' + inputs['cfg_file_path'] + \ | |||||
' --sample-rate=' + inputs['sample_rate'] + \ | |||||
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||||
' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ | |||||
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||||
os.system(kws_cmd) | |||||
if inputs['kws_set'] in ['pos_testsets', 'roc']: | |||||
data_dir: str = os.listdir(inputs['pos_data_path']) | |||||
wav_list = [] | |||||
for i in data_dir: | |||||
suffix = os.path.splitext(os.path.basename(i))[1] | |||||
if suffix == '.list': | |||||
wav_list.append(os.path.join(inputs['pos_data_path'], i)) | |||||
j: int = 0 | |||||
process = [] | |||||
while j < inputs['pos_num_thread']: | |||||
wav_list_path: str = inputs['pos_data_path'] + '/wave.' + str( | |||||
j) + '.list' | |||||
dump_log_path: str = inputs['pos_dump_path'] + '/dump.' + str( | |||||
j) + '.log' | |||||
kws_cmd: str = inputs['kws_tool_path'] + \ | |||||
' --sys-dir=' + inputs['model_workspace'] + \ | |||||
' --cfg-file=' + inputs['cfg_file_path'] + \ | |||||
' --sample-rate=' + inputs['sample_rate'] + \ | |||||
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||||
' --wave-scp=' + wav_list_path + \ | |||||
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||||
p = subprocess.Popen(kws_cmd, shell=True) | |||||
process.append(p) | |||||
j += 1 | |||||
k: int = 0 | |||||
while k < len(process): | |||||
process[k].wait() | |||||
k += 1 | |||||
if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: | |||||
cmd['wave_scp'] = inputs['pos_wav_list'] | |||||
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] | |||||
cmd['num_thread'] = inputs['pos_num_thread'] | |||||
# run and get inference result | |||||
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'], | |||||
cmd['keyword_grammar_path'], | |||||
str(json.dumps(cmd['wave_scp'])), | |||||
str(cmd['customized_keywords']), | |||||
cmd['sample_rate'], | |||||
cmd['num_thread']) | |||||
pos_result = json.loads(result) | |||||
inputs['pos_kws_list'] = pos_result['kws_list'] | |||||
if inputs['kws_set'] in ['neg_testsets', 'roc']: | if inputs['kws_set'] in ['neg_testsets', 'roc']: | ||||
data_dir: str = os.listdir(inputs['neg_data_path']) | |||||
wav_list = [] | |||||
for i in data_dir: | |||||
suffix = os.path.splitext(os.path.basename(i))[1] | |||||
if suffix == '.list': | |||||
wav_list.append(os.path.join(inputs['neg_data_path'], i)) | |||||
j: int = 0 | |||||
process = [] | |||||
while j < inputs['neg_num_thread']: | |||||
wav_list_path: str = inputs['neg_data_path'] + '/wave.' + str( | |||||
j) + '.list' | |||||
dump_log_path: str = inputs['neg_dump_path'] + '/dump.' + str( | |||||
j) + '.log' | |||||
kws_cmd: str = inputs['kws_tool_path'] + \ | |||||
' --sys-dir=' + inputs['model_workspace'] + \ | |||||
' --cfg-file=' + inputs['cfg_file_path'] + \ | |||||
' --sample-rate=' + inputs['sample_rate'] + \ | |||||
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||||
' --wave-scp=' + wav_list_path + \ | |||||
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||||
p = subprocess.Popen(kws_cmd, shell=True) | |||||
process.append(p) | |||||
j += 1 | |||||
k: int = 0 | |||||
while k < len(process): | |||||
process[k].wait() | |||||
k += 1 | |||||
cmd['wave_scp'] = inputs['neg_wav_list'] | |||||
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] | |||||
cmd['num_thread'] = inputs['neg_num_thread'] | |||||
# run and get inference result | |||||
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'], | |||||
cmd['keyword_grammar_path'], | |||||
str(json.dumps(cmd['wave_scp'])), | |||||
str(cmd['customized_keywords']), | |||||
cmd['sample_rate'], | |||||
cmd['num_thread']) | |||||
neg_result = json.loads(result) | |||||
inputs['neg_kws_list'] = neg_result['kws_list'] | |||||
return inputs | return inputs | ||||
def _parse_dump_log(self, result_json: Dict[str, Any],
                    dump_path: str) -> Dict[str, Any]:
    """Fold every dump log found in ``dump_path`` into ``result_json``.

    A directory entry is treated as a dump log when its name, with the last
    extension removed, contains ``dump`` (e.g. ``dump.log``, ``dump.3.log``).
    Each line of such a file is handed to ``self._parse_result_log``.

    Args:
        result_json (Dict[str, Any]): accumulator, updated by the line parser.
        dump_path (str): directory that holds the dump logs.

    Returns:
        Dict[str, Any]: the accumulator after all logs were parsed. The
            original code returned ``None`` despite its annotation; callers
            that ignore the return value are unaffected.
    """
    # sorted() makes the order of the parsed entries reproducible;
    # os.listdir() returns entries in arbitrary order.
    for entry in sorted(os.listdir(dump_path)):
        stem = os.path.splitext(os.path.basename(entry))[0]
        if 'dump' not in stem:
            continue
        log_path = os.path.join(dump_path, entry)
        with open(log_path, mode='r', encoding='utf-8') as log_file:
            for line in log_file:
                result_json = self._parse_result_log(line, result_json)
    return result_json
def _parse_result_log(self, line: str,
                      result_json: Dict[str, Any]) -> Dict[str, Any]:
    """Parse one line of a kws dump log and merge it into ``result_json``.

    Three kinds of lines are recognised, everything else is ignored:

    * ``[detected], fname:<path>, kw:<keyword>, confidence:<c>, ...``
      appends ``{<file name>: {'confidence': <c>, 'keyword': <keyword>}}``
      to ``result_json['detected']`` and records the keyword in
      ``result_json['keywords']``. The confidence is kept as a string.
    * ``[rejected], fname:<path>`` appends the file name to
      ``result_json['rejected']``.
    * ``total_proc_time=..(s), wav_time=<t>(s), ...`` adds ``<t>`` to
      ``result_json['wav_time']``.

    Args:
        line (str): one raw log line.
        result_json (Dict[str, Any]): accumulator, modified in place.

    Returns:
        Dict[str, Any]: the same accumulator object.
    """
    if '[detected]' in line or '[rejected]' in line:
        detected_count = result_json.get('detected_count', 0)
        rejected_count = result_json.get('rejected_count', 0)
        fields = line.split(', ')
        # Split on the first ':' only, so a colon inside the path or the
        # keyword is not mistaken for the "name:value" separator.
        file_name = os.path.basename(fields[1].split(':', 1)[1]).strip()

        if '[detected]' in line:
            detected_count += 1
            keyword = fields[2].split(':', 1)[1]
            confidence = fields[3].split(':', 1)[1]

            keywords = result_json.setdefault('keywords', [])
            if keyword not in keywords:
                keywords.append(keyword)

            result_json.setdefault('detected', []).append(
                {file_name: {
                    'confidence': confidence,
                    'keyword': keyword
                }})
        else:
            rejected_count += 1
            result_json.setdefault('rejected', []).append(file_name)

        # Both counters are written on every detected/rejected line.
        result_json['detected_count'] = detected_count
        result_json['rejected_count'] = rejected_count

    elif 'total_proc_time=' in line and 'wav_time=' in line:
        # "total_proc_time=0.289000(s), wav_time=20.944125(s), kwsbp_rtf=.."
        wav_field = line.split('), ')[1]
        wav_time = round(float(wav_field.split('=')[1].split('(')[0]), 6)
        result_json['wav_time'] = result_json.get('wav_time', 0) + wav_time

    return result_json
def _generate_roc_list(self, start: float, step: float, end: float,
                       keyword: str, pos_inputs: Dict[str, Any],
                       neg_inputs: Dict[str, Any]) -> List[Dict[str, float]]:
    """Build the ROC points of one keyword over a range of thresholds.

    ``pos_inputs['detected']`` and ``neg_inputs['detected']`` are lists of
    single-entry dicts::

        [{"xxx.wav": {"confidence": "0.990368", "keyword": "..."}}, ...]

    The result holds one entry per threshold::

        [{"threshold": 0.000, "recall": 0.999888, "fa_per_hour": 1.999999},
         {"threshold": 0.001, "recall": 0.999888, "fa_per_hour": 1.999999},
         ...]

    Args:
        start (float): first threshold.
        step (float): distance between two thresholds, must be positive.
        end (float): last threshold (inclusive).
        keyword (str): only detections of this keyword are counted.
        pos_inputs (Dict[str, Any]): parsed positive set; reads
            ``wav_count`` and, when present, ``detected``.
        neg_inputs (Dict[str, Any]): parsed negative set; reads, when
            present, ``wav_time`` (seconds) and ``detected``.

    Returns:
        List[Dict[str, float]]: ROC points ordered by ascending threshold;
            empty when ``end < start``.

    Raises:
        ValueError: if ``step`` is not positive (the original looped
            forever in that case).
    """
    if step <= 0:
        raise ValueError(f'step must be positive, got {step}')
    if end < start:
        return []

    def _scores(parsed: Dict[str, Any]) -> List[float]:
        # 'detected' is absent when the log contained no detection at all.
        scores = []
        for item in parsed.get('detected', []):
            wav_info = next(iter(item.values()))
            if wav_info['keyword'] == keyword:
                scores.append(float(wav_info['confidence']))
        return scores

    pos_scores = _scores(pos_inputs)
    neg_scores = _scores(neg_inputs)
    pos_wav_count = pos_inputs['wav_count']
    neg_hours = neg_inputs.get('wav_time', 0) / 3600

    # Derive each threshold from its index instead of accumulating `step`:
    # repeated float addition drifts, which dropped the final threshold
    # (1000 * 0.001 accumulates to slightly more than 1.0). The small
    # epsilon absorbs rounding error in the division.
    total_steps = int((end - start) / step + 1e-9)

    output = []
    for index in range(total_steps + 1):
        # Compare against the same rounded value that is reported.
        threshold = round(start + index * step, 3)
        det_count = sum(1 for score in pos_scores if score >= threshold)
        fa_count = sum(1 for score in neg_scores if score >= threshold)
        output.append({
            'threshold': threshold,
            # An empty positive/negative set yields 0.0, not a division error.
            'recall': round(det_count / pos_wav_count, 6)
            if pos_wav_count else 0.0,
            'fa_per_hour': round(fa_count / neg_hours, 6)
            if neg_hours else 0.0,
        })
    return output
def _set_customized_keywords(self) -> Dict[str, Any]:
    """Convert ``self._keywords`` into the engine's ``word_list`` layout.

    Every entry of ``self._keywords`` is a dict that may carry ``keyword``
    and/or ``threshold``. ``keyword`` is emitted as ``name`` with its
    characters separated by single spaces, ``threshold`` is emitted as
    ``threshold1``.

    Returns:
        Dict[str, Any]: ``{'word_list': [...]}``, or an empty dict when no
            keywords were set. The original returned ``''`` here, which
            contradicted the annotation; both are empty for ``len()``.
    """
    if self._keywords is None:
        return {}

    word_list = []
    for entry in self._keywords:
        item = {}
        if 'keyword' in entry:
            # "abcd" -> "a b c d"
            item['name'] = ' '.join(entry['keyword']).strip()
        if 'threshold' in entry:
            item['threshold1'] = entry['threshold']
        word_list.append(item)
    return {'word_list': word_list}
@@ -1,8 +1,5 @@ | |||||
import os | import os | ||||
import shutil | |||||
import stat | |||||
from pathlib import Path | |||||
from typing import Any, Dict, List | |||||
from typing import Any, Dict, List, Union | |||||
import yaml | import yaml | ||||
@@ -19,48 +16,28 @@ __all__ = ['WavToLists'] | |||||
Fields.audio, module_name=Preprocessors.wav_to_lists) | Fields.audio, module_name=Preprocessors.wav_to_lists) | ||||
class WavToLists(Preprocessor): | class WavToLists(Preprocessor): | ||||
"""generate audio lists file from wav | """generate audio lists file from wav | ||||
Args: | |||||
workspace (str): store temporarily kws intermedium and result | |||||
""" | """ | ||||
def __init__(self, workspace: str = None): | |||||
# the workspace path | |||||
if len(workspace) == 0: | |||||
self._workspace = os.path.join(os.getcwd(), '.tmp') | |||||
else: | |||||
self._workspace = workspace | |||||
if not os.path.exists(self._workspace): | |||||
os.mkdir(self._workspace) | |||||
def __init__(self): | |||||
pass | |||||
def __call__(self, | |||||
model: Model = None, | |||||
kws_type: str = None, | |||||
wav_path: List[str] = None) -> Dict[str, Any]: | |||||
def __call__(self, model: Model, wav_path: Union[List[str], | |||||
str]) -> Dict[str, Any]: | |||||
"""Call functions to load model and wav. | """Call functions to load model and wav. | ||||
Args: | Args: | ||||
model (Model): model should be provided | model (Model): model should be provided | ||||
kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc | |||||
wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path | |||||
wav_path (Union[List[str], str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path | |||||
Returns: | Returns: | ||||
Dict[str, Any]: the kws result | Dict[str, Any]: the kws result | ||||
""" | """ | ||||
assert model is not None, 'preprocess kws model should be provided' | |||||
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc' | |||||
], f'preprocess kws_type {kws_type} is invalid' | |||||
assert wav_path[0] is not None or wav_path[ | |||||
1] is not None, 'preprocess wav_path is invalid' | |||||
self._model = model | |||||
out = self.forward(self._model.forward(), kws_type, wav_path) | |||||
self.model = model | |||||
out = self.forward(self.model.forward(), wav_path) | |||||
return out | return out | ||||
def forward(self, model: Dict[str, Any], kws_type: str, | |||||
wav_path: List[str]) -> Dict[str, Any]: | |||||
assert len(kws_type) > 0, 'preprocess kws_type is empty' | |||||
def forward(self, model: Dict[str, Any], | |||||
wav_path: Union[List[str], str]) -> Dict[str, Any]: | |||||
assert len( | assert len( | ||||
model['config_path']) > 0, 'preprocess model[config_path] is empty' | model['config_path']) > 0, 'preprocess model[config_path] is empty' | ||||
assert os.path.exists( | assert os.path.exists( | ||||
@@ -68,19 +45,29 @@ class WavToLists(Preprocessor): | |||||
inputs = model.copy() | inputs = model.copy() | ||||
wav_list = [None, None] | |||||
if isinstance(wav_path, str): | |||||
wav_list[0] = wav_path | |||||
else: | |||||
wav_list = wav_path | |||||
import kws_util.common | |||||
kws_type = kws_util.common.type_checking(wav_list) | |||||
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc' | |||||
], f'preprocess kws_type {kws_type} is invalid' | |||||
inputs['kws_set'] = kws_type | inputs['kws_set'] = kws_type | ||||
inputs['workspace'] = self._workspace | |||||
if wav_path[0] is not None: | |||||
inputs['pos_wav_path'] = wav_path[0] | |||||
if wav_path[1] is not None: | |||||
inputs['neg_wav_path'] = wav_path[1] | |||||
if wav_list[0] is not None: | |||||
inputs['pos_wav_path'] = wav_list[0] | |||||
if wav_list[1] is not None: | |||||
inputs['neg_wav_path'] = wav_list[1] | |||||
out = self._read_config(inputs) | |||||
out = self._generate_wav_lists(out) | |||||
out = self.read_config(inputs) | |||||
out = self.generate_wav_lists(out) | |||||
return out | return out | ||||
def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
def read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
"""read and parse config.yaml to get all model files | """read and parse config.yaml to get all model files | ||||
""" | """ | ||||
@@ -97,157 +84,51 @@ class WavToLists(Preprocessor): | |||||
inputs['keyword_grammar'] = root['keyword_grammar'] | inputs['keyword_grammar'] = root['keyword_grammar'] | ||||
inputs['keyword_grammar_path'] = os.path.join( | inputs['keyword_grammar_path'] = os.path.join( | ||||
inputs['model_workspace'], root['keyword_grammar']) | inputs['model_workspace'], root['keyword_grammar']) | ||||
inputs['sample_rate'] = str(root['sample_rate']) | |||||
inputs['kws_tool'] = root['kws_tool'] | |||||
if os.path.exists( | |||||
os.path.join(inputs['workspace'], inputs['kws_tool'])): | |||||
inputs['kws_tool_path'] = os.path.join(inputs['workspace'], | |||||
inputs['kws_tool']) | |||||
elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])): | |||||
inputs['kws_tool_path'] = os.path.join('/usr/bin', | |||||
inputs['kws_tool']) | |||||
elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])): | |||||
inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool']) | |||||
assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp' | |||||
os.chmod(inputs['kws_tool_path'], | |||||
stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH) | |||||
self._config_checking(inputs) | |||||
inputs['sample_rate'] = root['sample_rate'] | |||||
return inputs | return inputs | ||||
def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
def generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
"""assemble wav lists | """assemble wav lists | ||||
""" | """ | ||||
import kws_util.common | |||||
if inputs['kws_set'] == 'wav': | if inputs['kws_set'] == 'wav': | ||||
inputs['pos_num_thread'] = 1 | |||||
wave_scp_content: str = inputs['pos_wav_path'] + '\n' | |||||
with open(os.path.join(inputs['pos_data_path'], 'wave.list'), | |||||
'a') as f: | |||||
f.write(wave_scp_content) | |||||
wav_list = [] | |||||
wave_scp_content: str = inputs['pos_wav_path'] | |||||
wav_list.append(wave_scp_content) | |||||
inputs['pos_wav_list'] = wav_list | |||||
inputs['pos_wav_count'] = 1 | inputs['pos_wav_count'] = 1 | ||||
inputs['pos_num_thread'] = 1 | |||||
if inputs['kws_set'] in ['pos_testsets', 'roc']: | if inputs['kws_set'] in ['pos_testsets', 'roc']: | ||||
# find all positive wave | # find all positive wave | ||||
wav_list = [] | wav_list = [] | ||||
wav_dir = inputs['pos_wav_path'] | wav_dir = inputs['pos_wav_path'] | ||||
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) | |||||
wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir) | |||||
inputs['pos_wav_list'] = wav_list | |||||
list_count: int = len(wav_list) | list_count: int = len(wav_list) | ||||
inputs['pos_wav_count'] = list_count | inputs['pos_wav_count'] = list_count | ||||
if list_count <= 128: | if list_count <= 128: | ||||
inputs['pos_num_thread'] = list_count | inputs['pos_num_thread'] = list_count | ||||
j: int = 0 | |||||
while j < list_count: | |||||
wave_scp_content: str = wav_list[j] + '\n' | |||||
wav_list_path = inputs['pos_data_path'] + '/wave.' + str( | |||||
j) + '.list' | |||||
with open(wav_list_path, 'a') as f: | |||||
f.write(wave_scp_content) | |||||
j += 1 | |||||
else: | else: | ||||
inputs['pos_num_thread'] = 128 | inputs['pos_num_thread'] = 128 | ||||
j: int = 0 | |||||
k: int = 0 | |||||
while j < list_count: | |||||
wave_scp_content: str = wav_list[j] + '\n' | |||||
wav_list_path = inputs['pos_data_path'] + '/wave.' + str( | |||||
k) + '.list' | |||||
with open(wav_list_path, 'a') as f: | |||||
f.write(wave_scp_content) | |||||
j += 1 | |||||
k += 1 | |||||
if k >= 128: | |||||
k = 0 | |||||
if inputs['kws_set'] in ['neg_testsets', 'roc']: | if inputs['kws_set'] in ['neg_testsets', 'roc']: | ||||
# find all negative wave | # find all negative wave | ||||
wav_list = [] | wav_list = [] | ||||
wav_dir = inputs['neg_wav_path'] | wav_dir = inputs['neg_wav_path'] | ||||
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) | |||||
wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir) | |||||
inputs['neg_wav_list'] = wav_list | |||||
list_count: int = len(wav_list) | list_count: int = len(wav_list) | ||||
inputs['neg_wav_count'] = list_count | inputs['neg_wav_count'] = list_count | ||||
if list_count <= 128: | if list_count <= 128: | ||||
inputs['neg_num_thread'] = list_count | inputs['neg_num_thread'] = list_count | ||||
j: int = 0 | |||||
while j < list_count: | |||||
wave_scp_content: str = wav_list[j] + '\n' | |||||
wav_list_path = inputs['neg_data_path'] + '/wave.' + str( | |||||
j) + '.list' | |||||
with open(wav_list_path, 'a') as f: | |||||
f.write(wave_scp_content) | |||||
j += 1 | |||||
else: | else: | ||||
inputs['neg_num_thread'] = 128 | inputs['neg_num_thread'] = 128 | ||||
j: int = 0 | |||||
k: int = 0 | |||||
while j < list_count: | |||||
wave_scp_content: str = wav_list[j] + '\n' | |||||
wav_list_path = inputs['neg_data_path'] + '/wave.' + str( | |||||
k) + '.list' | |||||
with open(wav_list_path, 'a') as f: | |||||
f.write(wave_scp_content) | |||||
j += 1 | |||||
k += 1 | |||||
if k >= 128: | |||||
k = 0 | |||||
return inputs | return inputs | ||||
def _recursion_dir_all_wave(self, wav_list, dir_path: str) -> List[str]:
    """Collect every ``.wav`` / ``.WAV`` file below ``dir_path``.

    Args:
        wav_list: list the found paths are appended to (modified in place).
        dir_path (str): directory to search, including all sub-directories.

    Returns:
        List[str]: ``wav_list`` itself. Within each directory the files come
            first, in sorted order, followed by the sub-directories in
            sorted order, so the result no longer depends on the arbitrary
            order of ``os.listdir``.

    Raises:
        OSError: if ``dir_path`` or one of its sub-directories cannot be
            listed.
    """

    def _reraise(error: OSError) -> None:
        # os.walk swallows listing errors unless told otherwise.
        raise error

    # followlinks=True keeps descending into symlinked directories, as the
    # os.path.isdir() based recursion did.
    for root, dirs, files in os.walk(
            dir_path, onerror=_reraise, followlinks=True):
        dirs.sort()
        for name in sorted(files):
            file_path = os.path.join(root, name)
            if name.endswith(('.wav', '.WAV')) and os.path.isfile(file_path):
                wav_list.append(file_path)
    return wav_list
def _config_checking(self, inputs: Dict[str, Any]) -> None:
    """Create empty data/dump directories for the requested kws set.

    For ``wav``, ``pos_testsets`` and ``roc`` the directories ``pos_data``
    and ``pos_dump`` are (re)created below ``inputs['workspace']``; for
    ``neg_testsets`` and ``roc`` the same happens for ``neg_data`` and
    ``neg_dump``. A directory that already exists is removed first, so each
    run starts with empty directories.

    Args:
        inputs (Dict[str, Any]): reads ``kws_set`` and ``workspace``; the
            keys ``<pos|neg>_<data|dump>_path`` are written in place.
    """
    kws_set = inputs['kws_set']
    prefixes = []
    if kws_set in ('wav', 'pos_testsets', 'roc'):
        prefixes.append('pos')
    if kws_set in ('neg_testsets', 'roc'):
        prefixes.append('neg')

    for prefix in prefixes:
        for kind in ('data', 'dump'):
            path = os.path.join(inputs['workspace'], f'{prefix}_{kind}')
            inputs[f'{prefix}_{kind}_path'] = path
            if os.path.exists(path):
                shutil.rmtree(path)
            # makedirs also creates a workspace directory that does not
            # exist yet, where os.mkdir raised FileNotFoundError.
            os.makedirs(path)
@@ -218,3 +218,11 @@ class TrainerStages: | |||||
after_val_iter = 'after_val_iter' | after_val_iter = 'after_val_iter' | ||||
after_val_epoch = 'after_val_epoch' | after_val_epoch = 'after_val_epoch' | ||||
after_run = 'after_run' | after_run = 'after_run' | ||||
class ColorCodes:
    """ANSI escape sequences used to colorize console/log output."""
    # Each color constant switches the foreground color; END resets it.
    MAGENTA = '\033[95m'
    YELLOW = '\033[93m'
    GREEN = '\033[92m'
    RED = '\033[91m'
    END = '\033[0m'
@@ -2,9 +2,11 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import os | import os | ||||
import tarfile | |||||
import unittest | import unittest | ||||
import numpy as np | import numpy as np | ||||
import requests | |||||
from datasets import Dataset | from datasets import Dataset | ||||
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE | from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE | ||||
@@ -42,3 +44,21 @@ def set_test_level(level: int): | |||||
def create_dummy_test_dataset(feat, label, num):
    """Build an ``MsDataset`` holding ``num`` copies of one sample.

    Every row has the same ``feat`` value and the same ``label`` value.
    """
    columns = dict(feat=[feat] * num, label=[label] * num)
    hf_dataset = Dataset.from_dict(columns)
    return MsDataset.from_hf_dataset(hf_dataset)
def download_and_untar(fpath, furl, dst) -> str:
    """Download a tar archive if it is not cached yet, then extract it.

    Args:
        fpath: local path of the archive. When the file already exists the
            download is skipped and the existing file is extracted.
        furl: URL the archive is fetched from when ``fpath`` is missing.
        dst: directory the archive is extracted into.

    Returns:
        ``<directory of fpath>/<archive name minus its last two
        extensions>`` (``x.tar.gz`` -> ``x``). The path is derived from
        ``fpath``, not from ``dst``, so it only points at the extracted
        data when ``dst`` is the archive's own directory and the archive
        contains a top-level folder of that name.

    Raises:
        requests.HTTPError: the server answered with an error status.
    """
    if not os.path.exists(fpath):
        response = requests.get(furl)
        # Fail loudly instead of caching an HTML error page as the archive.
        response.raise_for_status()
        with open(fpath, 'wb') as f:
            f.write(response.content)

    archive_name = os.path.basename(fpath)
    root_dir = os.path.dirname(fpath)
    # Strip two extensions so that 'name.tar.gz' becomes 'name'.
    target_dir_name = os.path.splitext(os.path.splitext(archive_name)[0])[0]
    target_dir_path = os.path.join(root_dir, target_dir_name)

    # NOTE(review): extractall trusts member paths inside the archive; only
    # use this helper with archives from a trusted source.
    with tarfile.open(fpath) as archive:
        archive.extractall(path=dst)

    return target_dir_path
@@ -3,6 +3,7 @@ espnet>=202204 | |||||
h5py | h5py | ||||
inflect | inflect | ||||
keras | keras | ||||
kwsbp | |||||
librosa | librosa | ||||
lxml | lxml | ||||
matplotlib | matplotlib | ||||
@@ -3,14 +3,16 @@ import os | |||||
import shutil | import shutil | ||||
import tarfile | import tarfile | ||||
import unittest | import unittest | ||||
from typing import Any, Dict, List, Union | |||||
import requests | import requests | ||||
from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
from modelscope.utils.constant import Tasks | |||||
from modelscope.utils.test_utils import test_level | |||||
from modelscope.utils.constant import ColorCodes, Tasks | |||||
from modelscope.utils.logger import get_logger | |||||
from modelscope.utils.test_utils import download_and_untar, test_level | |||||
KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' | |||||
logger = get_logger() | |||||
POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' | POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' | ||||
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' | BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' | ||||
@@ -22,12 +24,102 @@ NEG_TESTSETS_FILE = 'neg_testsets.tar.gz' | |||||
NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz' | NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz' | ||||
def un_tar_gz(fname, dirs):
    """Extract the tar archive ``fname`` into the directory ``dirs``.

    The archive is opened in a ``with`` block so its file handle is closed
    even when extraction fails.
    """
    with tarfile.open(fname) as archive:
        archive.extractall(path=dirs)
class KeyWordSpottingTest(unittest.TestCase): | class KeyWordSpottingTest(unittest.TestCase): | ||||
# Expectations for each test method, keyed by the method name.
#   checking_item:  key that must be present in the pipeline result.
#   checking_value: optional value the checked item is compared against
#                   (see check_and_print_result for how it is compared).
#   example:        sample of a correct result, logged when a check fails.
action_info = {
    'test_run_with_wav': {
        'checking_item': 'kws_list',
        'checking_value': '小云小云',
        'example': {
            'wav_count': 1,
            'kws_set': 'wav',
            'kws_list': [{
                'keyword': '小云小云',
                'offset': 5.76,
                'length': 9.132938,
                'confidence': 0.990368,
            }],
        },
    },
    'test_run_with_wav_by_customized_keywords': {
        'checking_item': 'kws_list',
        'checking_value': '播放音乐',
        'example': {
            'wav_count': 1,
            'kws_set': 'wav',
            'kws_list': [{
                'keyword': '播放音乐',
                'offset': 0.87,
                'length': 2.158313,
                'confidence': 0.646237,
            }],
        },
    },
    'test_run_with_pos_testsets': {
        'checking_item': 'recall',
        'example': {
            'wav_count': 450,
            'kws_set': 'pos_testsets',
            'wav_time': 3013.75925,
            'keywords': ['小云小云'],
            'recall': 0.953333,
            'detected_count': 429,
            'rejected_count': 21,
            'rejected': ['yyy.wav', 'zzz.wav'],
        },
    },
    'test_run_with_neg_testsets': {
        'checking_item': 'fa_rate',
        'example': {
            'wav_count': 751,
            'kws_set': 'neg_testsets',
            'wav_time': 3572.180813,
            'keywords': ['小云小云'],
            'fa_rate': 0.001332,
            'fa_per_hour': 1.007788,
            'detected_count': 1,
            'rejected_count': 750,
            'detected': [{
                '6.wav': {
                    'confidence': '0.321170',
                    'keyword': '小云小云',
                },
            }],
        },
    },
    'test_run_with_roc': {
        'checking_item': 'keywords',
        'checking_value': '小云小云',
        'example': {
            'kws_set': 'roc',
            'keywords': ['小云小云'],
            '小云小云': [{
                'threshold': 0.0,
                'recall': 0.953333,
                'fa_per_hour': 1.007788,
            }, {
                'threshold': 0.001,
                'recall': 0.953333,
                'fa_per_hour': 1.007788,
            }, {
                'threshold': 0.999,
                'recall': 0.004444,
                'fa_per_hour': 0.0,
            }],
        },
    },
}
def setUp(self) -> None: | def setUp(self) -> None: | ||||
self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun' | self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun' | ||||
@@ -36,315 +128,121 @@ class KeyWordSpottingTest(unittest.TestCase): | |||||
os.mkdir(self.workspace) | os.mkdir(self.workspace) | ||||
def tearDown(self) -> None:
    """Delete the per-test workspace directory after each test.

    ``ignore_errors=True`` already makes a missing directory a no-op, so
    no separate existence check is needed.
    """
    shutil.rmtree(self.workspace, ignore_errors=True)
def run_pipeline(self,
                 model_id: str,
                 wav_path: Union[List[str], str],
                 keywords: List[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Build the keyword-spotting pipeline for ``model_id`` and run it once.

    Args:
        model_id: model identifier handed to ``pipeline``.
        wav_path: a single wav path, or a two-element list as built by the
            test methods in this class.
        keywords: optional customized keywords, e.g.
            ``[{'keyword': '播放音乐'}]``; ``None`` is passed through
            unchanged.

    Returns:
        The result returned by the pipeline call.
    """
    kws_pipeline = pipeline(
        task=Tasks.auto_speech_recognition, model=model_id)
    return kws_pipeline(wav_path=wav_path, keywords=keywords)
def print_error(self, functions: str, result: Dict[str, Any]) -> None:
    """Log a failed check for test ``functions`` and abort the test.

    Logs the failure, the example of a correct result stored in
    ``action_info`` and the result that was actually produced.

    Args:
        functions: name of the test method, used as key into
            ``self.action_info``.
        result: the pipeline result that failed the check.

    Raises:
        ValueError: always, after logging.
    """
    logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                 + ColorCodes.END)
    logger.error(ColorCodes.MAGENTA + functions
                 + ' correct result example: ' + ColorCodes.YELLOW
                 + str(self.action_info[functions]['example'])
                 + ColorCodes.END)
    # Show what was actually returned so the mismatch can be diagnosed.
    logger.error(ColorCodes.MAGENTA + functions + ' actual result: '
                 + ColorCodes.YELLOW + str(result) + ColorCodes.END)
    raise ValueError('kws result is mismatched')
def check_and_print_result(self, functions: str,
                           result: Dict[str, Any]) -> None:
    """Validate a pipeline result against ``action_info`` and log it.

    The result must contain the ``checking_item`` key configured for
    ``functions``. In addition:

    * ``test_run_with_roc``: the first entry of the checked item must equal
      ``checking_value``.
    * ``test_run_with_wav`` / ``test_run_with_wav_by_customized_keywords``:
      the ``'keyword'`` of the first entry must equal ``checking_value``.

    An empty checked item counts as a mismatch. Any failure is reported
    through ``print_error``, which raises ``ValueError``.

    Args:
        functions: name of the test method, used as key into
            ``self.action_info``.
        result: the pipeline result to validate.
    """
    expectation = self.action_info[functions]
    item_key = expectation['checking_item']
    if item_key not in result:
        self.print_error(functions, result)
        return

    checking_item = result[item_key]
    is_roc = functions == 'test_run_with_roc'
    checks_keyword = functions in (
        'test_run_with_wav', 'test_run_with_wav_by_customized_keywords')
    if is_roc or checks_keyword:
        first = checking_item[0] if checking_item else None
        if first is not None and checks_keyword:
            first = first['keyword']
        if first != expectation['checking_value']:
            self.print_error(functions, result)
            return

    logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                + ColorCodes.END)
    if is_roc:
        # The ROC result stores its operating points under the keyword.
        find_keyword = result['keywords'][0]
        for item in result[find_keyword]:
            threshold: float = item['threshold']
            recall: float = item['recall']
            fa_per_hour: float = item['fa_per_hour']
            logger.info(ColorCodes.YELLOW + '  threshold:'
                        + str(threshold) + ' recall:' + str(recall)
                        + ' fa_per_hour:' + str(fa_per_hour)
                        + ColorCodes.END)
    else:
        logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
    """Run the pipeline on ``POS_WAV_FILE`` and validate the result."""
    result = self.run_pipeline(
        wav_path=POS_WAV_FILE, model_id=self.model_id)
    self.check_and_print_result('test_run_with_wav', result)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self):
    """Run the pipeline on ``BOFANGYINYUE_WAV_FILE`` with a customized
    keyword list and validate the result."""
    custom_keywords = [{'keyword': '播放音乐'}]
    result = self.run_pipeline(
        model_id=self.model_id,
        wav_path=BOFANGYINYUE_WAV_FILE,
        keywords=custom_keywords)
    self.check_and_print_result('test_run_with_wav_by_customized_keywords',
                                result)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self):
    """Download the positive test set into the workspace, run the pipeline
    on it and validate the result."""
    archive_path = os.path.join(self.workspace, POS_TESTSETS_FILE)
    pos_dir = download_and_untar(archive_path, POS_TESTSETS_URL,
                                 self.workspace)
    # Only the first slot is filled; the second is left as None.
    result = self.run_pipeline(
        model_id=self.model_id, wav_path=[pos_dir, None])
    self.check_and_print_result('test_run_with_pos_testsets', result)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_neg_testsets(self):
    """Download the negative test set into the workspace, run the pipeline
    on it and validate the result."""
    archive_path = os.path.join(self.workspace, NEG_TESTSETS_FILE)
    neg_dir = download_and_untar(archive_path, NEG_TESTSETS_URL,
                                 self.workspace)
    # Only the second slot is filled; the first is left as None.
    result = self.run_pipeline(
        model_id=self.model_id, wav_path=[None, neg_dir])
    self.check_and_print_result('test_run_with_neg_testsets', result)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_roc(self):
    """Download both test sets into the workspace, run the pipeline on the
    pair and validate the result."""
    pos_archive = os.path.join(self.workspace, POS_TESTSETS_FILE)
    pos_dir = download_and_untar(pos_archive, POS_TESTSETS_URL,
                                 self.workspace)
    neg_archive = os.path.join(self.workspace, NEG_TESTSETS_FILE)
    neg_dir = download_and_untar(neg_archive, NEG_TESTSETS_URL,
                                 self.workspace)
    # Positive set goes in the first slot, negative set in the second.
    result = self.run_pipeline(
        model_id=self.model_id, wav_path=[pos_dir, neg_dir])
    self.check_and_print_result('test_run_with_roc', result)
if __name__ == '__main__': | if __name__ == '__main__': | ||||