Browse Source

[to #42322933] simplify kws code, and remove disk-write behavior

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491681
master
shichen.fsc 3 years ago
parent
commit
b3b950e616
7 changed files with 345 additions and 893 deletions
  1. +1
    -1
      modelscope/models/audio/kws/generic_key_word_spotting.py
  2. +80
    -436
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  3. +41
    -160
      modelscope/preprocessors/kws.py
  4. +8
    -0
      modelscope/utils/constant.py
  5. +20
    -0
      modelscope/utils/test_utils.py
  6. +1
    -0
      requirements/audio.txt
  7. +194
    -296
      tests/pipelines/test_key_word_spotting.py

+ 1
- 1
modelscope/models/audio/kws/generic_key_word_spotting.py View File

@@ -19,7 +19,7 @@ class GenericKeyWordSpotting(Model):
Args: Args:
model_dir (str): the model path. model_dir (str): the model path.
""" """
super().__init__(model_dir)
super().__init__(model_dir, *args, **kwargs)
self.model_cfg = { self.model_cfg = {
'model_workspace': model_dir, 'model_workspace': model_dir,
'config_path': os.path.join(model_dir, 'config.yaml') 'config_path': os.path.join(model_dir, 'config.yaml')


+ 80
- 436
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -1,5 +1,4 @@
import os import os
import subprocess
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union


import json import json
@@ -10,6 +9,9 @@ from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToLists from modelscope.preprocessors import WavToLists
from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


__all__ = ['KeyWordSpottingKwsbpPipeline'] __all__ = ['KeyWordSpottingKwsbpPipeline']


@@ -21,41 +23,24 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
""" """


def __init__(self, def __init__(self,
config_file: str = None,
model: Union[Model, str] = None, model: Union[Model, str] = None,
preprocessor: WavToLists = None, preprocessor: WavToLists = None,
**kwargs): **kwargs):
"""use `model` and `preprocessor` to create a kws pipeline for prediction
""" """
use `model` and `preprocessor` to create a kws pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(
config_file=config_file,
model=model,
preprocessor=preprocessor,
**kwargs)

assert model is not None, 'kws model should be provided'

self._preprocessor = preprocessor
self._keywords = None
super().__init__(model=model, preprocessor=preprocessor, **kwargs)


def __call__(self, wav_path: Union[List[str], str],
**kwargs) -> Dict[str, Any]:
if 'keywords' in kwargs.keys(): if 'keywords' in kwargs.keys():
self._keywords = kwargs['keywords']

def __call__(self,
kws_type: str,
wav_path: List[str],
workspace: str = None) -> Dict[str, Any]:
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets',
'roc'], f'kws_type {kws_type} is invalid'
self.keywords = kwargs['keywords']
else:
self.keywords = None


if self._preprocessor is None:
self._preprocessor = WavToLists(workspace=workspace)
if self.preprocessor is None:
self.preprocessor = WavToLists()


output = self._preprocessor.forward(self.model.forward(), kws_type,
wav_path)
output = self.preprocessor.forward(self.model.forward(), wav_path)
output = self.forward(output) output = self.forward(output)
rst = self.postprocess(output) rst = self.postprocess(output)
return rst return rst
@@ -64,433 +49,92 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
"""Decoding """Decoding
""" """


# will generate kws result into dump/dump.JOB.log
out = self._run_with_kwsbp(inputs)
logger.info(f"Decoding with {inputs['kws_set']} mode ...")

# will generate kws result
out = self.run_with_kwsbp(inputs)


return out return out


def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""process the kws results """process the kws results
"""


pos_result_json = {}
neg_result_json = {}

if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
self._parse_dump_log(pos_result_json, inputs['pos_dump_path'])
if inputs['kws_set'] in ['neg_testsets', 'roc']:
self._parse_dump_log(neg_result_json, inputs['neg_dump_path'])
"""
result_json format example:
{
"wav_count": 450,
"keywords": ["小云小云"],
"wav_time": 3560.999999,
"detected": [
{
"xxx.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
{
"yyy.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
......
],
"detected_count": 429,
"rejected_count": 21,
"rejected": [
"yyy.wav",
"zzz.wav",
......
]
}
Args:
inputs['pos_kws_list'] or inputs['neg_kws_list']:
result_dict format example:
[{
'confidence': 0.9903678297996521,
'filename': 'data/test/audios/kws_xiaoyunxiaoyun.wav',
'keyword': '小云小云',
'offset': 5.760000228881836, # second
'rtf_time': 66, # millisecond
'threshold': 0,
'wav_time': 9.1329375 # second
}]
""" """


rst_dict = {'kws_set': inputs['kws_set']}

# parsing the result of wav
if inputs['kws_set'] == 'wav':
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[
'pos_wav_count']
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6)
if pos_result_json['detected_count'] == 1:
rst_dict['keywords'] = pos_result_json['keywords']
rst_dict['detected'] = True
wav_file_name = os.path.basename(inputs['pos_wav_path'])
rst_dict['confidence'] = float(pos_result_json['detected'][0]
[wav_file_name]['confidence'])
else:
rst_dict['detected'] = False

# parsing the result of pos_tests
elif inputs['kws_set'] == 'pos_testsets':
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[
'pos_wav_count']
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6)
if pos_result_json.__contains__('keywords'):
rst_dict['keywords'] = pos_result_json['keywords']

rst_dict['recall'] = round(
pos_result_json['detected_count'] / rst_dict['wav_count'], 6)

if pos_result_json.__contains__('detected_count'):
rst_dict['detected_count'] = pos_result_json['detected_count']
if pos_result_json.__contains__('rejected_count'):
rst_dict['rejected_count'] = pos_result_json['rejected_count']
if pos_result_json.__contains__('rejected'):
rst_dict['rejected'] = pos_result_json['rejected']

# parsing the result of neg_tests
elif inputs['kws_set'] == 'neg_testsets':
rst_dict['wav_count'] = neg_result_json['wav_count'] = inputs[
'neg_wav_count']
rst_dict['wav_time'] = round(neg_result_json['wav_time'], 6)
if neg_result_json.__contains__('keywords'):
rst_dict['keywords'] = neg_result_json['keywords']

rst_dict['fa_rate'] = 0.0
rst_dict['fa_per_hour'] = 0.0

if neg_result_json.__contains__('detected_count'):
rst_dict['detected_count'] = neg_result_json['detected_count']
rst_dict['fa_rate'] = round(
neg_result_json['detected_count'] / rst_dict['wav_count'],
6)
if neg_result_json.__contains__('wav_time'):
rst_dict['fa_per_hour'] = round(
neg_result_json['detected_count']
/ float(neg_result_json['wav_time'] / 3600), 6)

if neg_result_json.__contains__('rejected_count'):
rst_dict['rejected_count'] = neg_result_json['rejected_count']

if neg_result_json.__contains__('detected'):
rst_dict['detected'] = neg_result_json['detected']

# parsing the result of roc
elif inputs['kws_set'] == 'roc':
threshold_start = 0.000
threshold_step = 0.001
threshold_end = 1.000

pos_keywords_list = []
neg_keywords_list = []
if pos_result_json.__contains__('keywords'):
pos_keywords_list = pos_result_json['keywords']
if neg_result_json.__contains__('keywords'):
neg_keywords_list = neg_result_json['keywords']

keywords_list = list(set(pos_keywords_list + neg_keywords_list))

pos_result_json['wav_count'] = inputs['pos_wav_count']
neg_result_json['wav_count'] = inputs['neg_wav_count']

if len(keywords_list) > 0:
rst_dict['keywords'] = keywords_list

for index in range(len(rst_dict['keywords'])):
cur_keyword = rst_dict['keywords'][index]
output_list = self._generate_roc_list(
start=threshold_start,
step=threshold_step,
end=threshold_end,
keyword=cur_keyword,
pos_inputs=pos_result_json,
neg_inputs=neg_result_json)

rst_dict[cur_keyword] = output_list
import kws_util.common
neg_kws_list = None
pos_kws_list = None
if 'pos_kws_list' in inputs:
pos_kws_list = inputs['pos_kws_list']
if 'neg_kws_list' in inputs:
neg_kws_list = inputs['neg_kws_list']
rst_dict = kws_util.common.parsing_kws_result(
kws_type=inputs['kws_set'],
pos_list=pos_kws_list,
neg_list=neg_kws_list)


return rst_dict return rst_dict


def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
opts: str = ''
def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
cmd = {
'sys_dir': inputs['model_workspace'],
'cfg_file': inputs['cfg_file_path'],
'sample_rate': inputs['sample_rate'],
'keyword_custom': ''
}

import kwsbp
import kws_util.common
kws_inference = kwsbp.KwsbpEngine()


# setting customized keywords # setting customized keywords
keywords_json = self._set_customized_keywords()
if len(keywords_json) > 0:
keywords_json_file = os.path.join(inputs['workspace'],
'keyword_custom.json')
with open(keywords_json_file, 'w') as f:
json.dump(keywords_json, f)
opts = '--keyword-custom ' + keywords_json_file
cmd['customized_keywords'] = kws_util.common.generate_customized_keywords(
self.keywords)


if inputs['kws_set'] == 'roc': if inputs['kws_set'] == 'roc':
inputs['keyword_grammar_path'] = os.path.join( inputs['keyword_grammar_path'] = os.path.join(
inputs['model_workspace'], 'keywords_roc.json') inputs['model_workspace'], 'keywords_roc.json')


if inputs['kws_set'] == 'wav':
dump_log_path: str = os.path.join(inputs['pos_dump_path'],
'dump.log')
kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
os.system(kws_cmd)

if inputs['kws_set'] in ['pos_testsets', 'roc']:
data_dir: str = os.listdir(inputs['pos_data_path'])
wav_list = []
for i in data_dir:
suffix = os.path.splitext(os.path.basename(i))[1]
if suffix == '.list':
wav_list.append(os.path.join(inputs['pos_data_path'], i))

j: int = 0
process = []
while j < inputs['pos_num_thread']:
wav_list_path: str = inputs['pos_data_path'] + '/wave.' + str(
j) + '.list'
dump_log_path: str = inputs['pos_dump_path'] + '/dump.' + str(
j) + '.log'

kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1

k: int = 0
while k < len(process):
process[k].wait()
k += 1
if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
cmd['wave_scp'] = inputs['pos_wav_list']
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path']
cmd['num_thread'] = inputs['pos_num_thread']

# run and get inference result
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']),
cmd['sample_rate'],
cmd['num_thread'])
pos_result = json.loads(result)
inputs['pos_kws_list'] = pos_result['kws_list']


if inputs['kws_set'] in ['neg_testsets', 'roc']: if inputs['kws_set'] in ['neg_testsets', 'roc']:
data_dir: str = os.listdir(inputs['neg_data_path'])
wav_list = []
for i in data_dir:
suffix = os.path.splitext(os.path.basename(i))[1]
if suffix == '.list':
wav_list.append(os.path.join(inputs['neg_data_path'], i))

j: int = 0
process = []
while j < inputs['neg_num_thread']:
wav_list_path: str = inputs['neg_data_path'] + '/wave.' + str(
j) + '.list'
dump_log_path: str = inputs['neg_dump_path'] + '/dump.' + str(
j) + '.log'

kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1

k: int = 0
while k < len(process):
process[k].wait()
k += 1
cmd['wave_scp'] = inputs['neg_wav_list']
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path']
cmd['num_thread'] = inputs['neg_num_thread']

# run and get inference result
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']),
cmd['sample_rate'],
cmd['num_thread'])
neg_result = json.loads(result)
inputs['neg_kws_list'] = neg_result['kws_list']


return inputs return inputs

def _parse_dump_log(self, result_json: Dict[str, Any],
dump_path: str) -> Dict[str, Any]:
dump_dir = os.listdir(dump_path)
for i in dump_dir:
basename = os.path.splitext(os.path.basename(i))[0]
# find dump.JOB.log
if 'dump' in basename:
with open(
os.path.join(dump_path, i), mode='r',
encoding='utf-8') as file:
while 1:
line = file.readline()
if not line:
break
else:
result_json = self._parse_result_log(
line, result_json)

def _parse_result_log(self, line: str,
result_json: Dict[str, Any]) -> Dict[str, Any]:
# valid info
if '[rejected]' in line or '[detected]' in line:
detected_count = 0
rejected_count = 0

if result_json.__contains__('detected_count'):
detected_count = result_json['detected_count']
if result_json.__contains__('rejected_count'):
rejected_count = result_json['rejected_count']

if '[detected]' in line:
# [detected], fname:/xxx/.tmp_pos_testsets/pos_testsets/33.wav,
# kw:小云小云, confidence:0.965155, time:[4.62-5.10], threshold:0.00,
detected_count += 1
content_list = line.split(', ')
file_name = os.path.basename(content_list[1].split(':')[1])
keyword = content_list[2].split(':')[1]
confidence = content_list[3].split(':')[1]

keywords_list = []
if result_json.__contains__('keywords'):
keywords_list = result_json['keywords']

if keyword not in keywords_list:
keywords_list.append(keyword)
result_json['keywords'] = keywords_list

keyword_item = {}
keyword_item['confidence'] = confidence
keyword_item['keyword'] = keyword
item = {}
item[file_name] = keyword_item

detected_list = []
if result_json.__contains__('detected'):
detected_list = result_json['detected']

detected_list.append(item)
result_json['detected'] = detected_list

elif '[rejected]' in line:
# [rejected], fname:/xxx/.tmp_pos_testsets/pos_testsets/28.wav
rejected_count += 1
content_list = line.split(', ')
file_name = os.path.basename(content_list[1].split(':')[1])
file_name = file_name.strip().replace('\n',
'').replace('\r', '')

rejected_list = []
if result_json.__contains__('rejected'):
rejected_list = result_json['rejected']

rejected_list.append(file_name)
result_json['rejected'] = rejected_list

result_json['detected_count'] = detected_count
result_json['rejected_count'] = rejected_count

elif 'total_proc_time=' in line and 'wav_time=' in line:
# eg: total_proc_time=0.289000(s), wav_time=20.944125(s), kwsbp_rtf=0.013799
wav_total_time = 0
content_list = line.split('), ')
if result_json.__contains__('wav_time'):
wav_total_time = result_json['wav_time']

wav_time_str = content_list[1].split('=')[1]
wav_time_str = wav_time_str.split('(')[0]
wav_time = float(wav_time_str)
wav_time = round(wav_time, 6)

if isinstance(wav_time, float):
wav_total_time += wav_time

result_json['wav_time'] = wav_total_time

return result_json

def _generate_roc_list(self, start: float, step: float, end: float,
keyword: str, pos_inputs: Dict[str, Any],
neg_inputs: Dict[str, Any]) -> Dict[str, Any]:
pos_wav_count = pos_inputs['wav_count']
neg_wav_time = neg_inputs['wav_time']
det_lists = pos_inputs['detected']
fa_lists = neg_inputs['detected']
threshold_cur = start
"""
input det_lists dict
[
{
"xxx.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
{
"yyy.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
]

output dict
[
{
"threshold": 0.000,
"recall": 0.999888,
"fa_per_hour": 1.999999
},
{
"threshold": 0.001,
"recall": 0.999888,
"fa_per_hour": 1.999999
},
]
"""

output = []
while threshold_cur <= end:
det_count = 0
fa_count = 0
for index in range(len(det_lists)):
det_item = det_lists[index]
det_wav_item = det_item.get(next(iter(det_item)))
if det_wav_item['keyword'] == keyword:
confidence = float(det_wav_item['confidence'])
if confidence >= threshold_cur:
det_count += 1

for index in range(len(fa_lists)):
fa_item = fa_lists[index]
fa_wav_item = fa_item.get(next(iter(fa_item)))
if fa_wav_item['keyword'] == keyword:
confidence = float(fa_wav_item['confidence'])
if confidence >= threshold_cur:
fa_count += 1

output_item = {
'threshold': round(threshold_cur, 3),
'recall': round(float(det_count / pos_wav_count), 6),
'fa_per_hour': round(fa_count / float(neg_wav_time / 3600), 6)
}
output.append(output_item)

threshold_cur += step

return output

def _set_customized_keywords(self) -> Dict[str, Any]:
if self._keywords is not None:
word_list_inputs = self._keywords
word_list = []
for i in range(len(word_list_inputs)):
key = word_list_inputs[i]
new_item = {}
if key.__contains__('keyword'):
name = key['keyword']
new_name: str = ''
for n in range(0, len(name), 1):
new_name += name[n]
new_name += ' '
new_name = new_name.strip()
new_item['name'] = new_name

if key.__contains__('threshold'):
threshold1: float = key['threshold']
new_item['threshold1'] = threshold1

word_list.append(new_item)
out = {'word_list': word_list}
return out
else:
return ''

+ 41
- 160
modelscope/preprocessors/kws.py View File

@@ -1,8 +1,5 @@
import os import os
import shutil
import stat
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Union


import yaml import yaml


@@ -19,48 +16,28 @@ __all__ = ['WavToLists']
Fields.audio, module_name=Preprocessors.wav_to_lists) Fields.audio, module_name=Preprocessors.wav_to_lists)
class WavToLists(Preprocessor): class WavToLists(Preprocessor):
"""generate audio lists file from wav """generate audio lists file from wav

Args:
workspace (str): store temporarily kws intermedium and result
""" """


def __init__(self, workspace: str = None):
# the workspace path
if len(workspace) == 0:
self._workspace = os.path.join(os.getcwd(), '.tmp')
else:
self._workspace = workspace

if not os.path.exists(self._workspace):
os.mkdir(self._workspace)
def __init__(self):
pass


def __call__(self,
model: Model = None,
kws_type: str = None,
wav_path: List[str] = None) -> Dict[str, Any]:
def __call__(self, model: Model, wav_path: Union[List[str],
str]) -> Dict[str, Any]:
"""Call functions to load model and wav. """Call functions to load model and wav.


Args: Args:
model (Model): model should be provided model (Model): model should be provided
kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc
wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path
wav_path (Union[List[str], str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path
Returns: Returns:
Dict[str, Any]: the kws result Dict[str, Any]: the kws result
""" """


assert model is not None, 'preprocess kws model should be provided'
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc'
], f'preprocess kws_type {kws_type} is invalid'
assert wav_path[0] is not None or wav_path[
1] is not None, 'preprocess wav_path is invalid'

self._model = model
out = self.forward(self._model.forward(), kws_type, wav_path)
self.model = model
out = self.forward(self.model.forward(), wav_path)
return out return out


def forward(self, model: Dict[str, Any], kws_type: str,
wav_path: List[str]) -> Dict[str, Any]:
assert len(kws_type) > 0, 'preprocess kws_type is empty'
def forward(self, model: Dict[str, Any],
wav_path: Union[List[str], str]) -> Dict[str, Any]:
assert len( assert len(
model['config_path']) > 0, 'preprocess model[config_path] is empty' model['config_path']) > 0, 'preprocess model[config_path] is empty'
assert os.path.exists( assert os.path.exists(
@@ -68,19 +45,29 @@ class WavToLists(Preprocessor):


inputs = model.copy() inputs = model.copy()


wav_list = [None, None]
if isinstance(wav_path, str):
wav_list[0] = wav_path
else:
wav_list = wav_path

import kws_util.common
kws_type = kws_util.common.type_checking(wav_list)
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc'
], f'preprocess kws_type {kws_type} is invalid'

inputs['kws_set'] = kws_type inputs['kws_set'] = kws_type
inputs['workspace'] = self._workspace
if wav_path[0] is not None:
inputs['pos_wav_path'] = wav_path[0]
if wav_path[1] is not None:
inputs['neg_wav_path'] = wav_path[1]
if wav_list[0] is not None:
inputs['pos_wav_path'] = wav_list[0]
if wav_list[1] is not None:
inputs['neg_wav_path'] = wav_list[1]


out = self._read_config(inputs)
out = self._generate_wav_lists(out)
out = self.read_config(inputs)
out = self.generate_wav_lists(out)


return out return out


def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""read and parse config.yaml to get all model files """read and parse config.yaml to get all model files
""" """


@@ -97,157 +84,51 @@ class WavToLists(Preprocessor):
inputs['keyword_grammar'] = root['keyword_grammar'] inputs['keyword_grammar'] = root['keyword_grammar']
inputs['keyword_grammar_path'] = os.path.join( inputs['keyword_grammar_path'] = os.path.join(
inputs['model_workspace'], root['keyword_grammar']) inputs['model_workspace'], root['keyword_grammar'])
inputs['sample_rate'] = str(root['sample_rate'])
inputs['kws_tool'] = root['kws_tool']

if os.path.exists(
os.path.join(inputs['workspace'], inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join(inputs['workspace'],
inputs['kws_tool'])
elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join('/usr/bin',
inputs['kws_tool'])
elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool'])

assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp'
os.chmod(inputs['kws_tool_path'],
stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH)

self._config_checking(inputs)
inputs['sample_rate'] = root['sample_rate']

return inputs return inputs


def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""assemble wav lists """assemble wav lists
""" """
import kws_util.common


if inputs['kws_set'] == 'wav': if inputs['kws_set'] == 'wav':
inputs['pos_num_thread'] = 1
wave_scp_content: str = inputs['pos_wav_path'] + '\n'

with open(os.path.join(inputs['pos_data_path'], 'wave.list'),
'a') as f:
f.write(wave_scp_content)

wav_list = []
wave_scp_content: str = inputs['pos_wav_path']
wav_list.append(wave_scp_content)
inputs['pos_wav_list'] = wav_list
inputs['pos_wav_count'] = 1 inputs['pos_wav_count'] = 1
inputs['pos_num_thread'] = 1


if inputs['kws_set'] in ['pos_testsets', 'roc']: if inputs['kws_set'] in ['pos_testsets', 'roc']:
# find all positive wave # find all positive wave
wav_list = [] wav_list = []
wav_dir = inputs['pos_wav_path'] wav_dir = inputs['pos_wav_path']
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)
wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir)
inputs['pos_wav_list'] = wav_list


list_count: int = len(wav_list) list_count: int = len(wav_list)
inputs['pos_wav_count'] = list_count inputs['pos_wav_count'] = list_count


if list_count <= 128: if list_count <= 128:
inputs['pos_num_thread'] = list_count inputs['pos_num_thread'] = list_count
j: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['pos_data_path'] + '/wave.' + str(
j) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1

else: else:
inputs['pos_num_thread'] = 128 inputs['pos_num_thread'] = 128
j: int = 0
k: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['pos_data_path'] + '/wave.' + str(
k) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1
k += 1
if k >= 128:
k = 0


if inputs['kws_set'] in ['neg_testsets', 'roc']: if inputs['kws_set'] in ['neg_testsets', 'roc']:
# find all negative wave # find all negative wave
wav_list = [] wav_list = []
wav_dir = inputs['neg_wav_path'] wav_dir = inputs['neg_wav_path']
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)
wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir)
inputs['neg_wav_list'] = wav_list


list_count: int = len(wav_list) list_count: int = len(wav_list)
inputs['neg_wav_count'] = list_count inputs['neg_wav_count'] = list_count


if list_count <= 128: if list_count <= 128:
inputs['neg_num_thread'] = list_count inputs['neg_num_thread'] = list_count
j: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['neg_data_path'] + '/wave.' + str(
j) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1

else: else:
inputs['neg_num_thread'] = 128 inputs['neg_num_thread'] = 128
j: int = 0
k: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['neg_data_path'] + '/wave.' + str(
k) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1
k += 1
if k >= 128:
k = 0


return inputs return inputs

def _recursion_dir_all_wave(self, wav_list,
dir_path: str) -> Dict[str, Any]:
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if file_path.endswith('.wav') or file_path.endswith('.WAV'):
wav_list.append(file_path)
elif os.path.isdir(file_path):
self._recursion_dir_all_wave(wav_list, file_path)

return wav_list

def _config_checking(self, inputs: Dict[str, Any]):

if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
inputs['pos_data_path'] = os.path.join(inputs['workspace'],
'pos_data')
if not os.path.exists(inputs['pos_data_path']):
os.mkdir(inputs['pos_data_path'])
else:
shutil.rmtree(inputs['pos_data_path'])
os.mkdir(inputs['pos_data_path'])

inputs['pos_dump_path'] = os.path.join(inputs['workspace'],
'pos_dump')
if not os.path.exists(inputs['pos_dump_path']):
os.mkdir(inputs['pos_dump_path'])
else:
shutil.rmtree(inputs['pos_dump_path'])
os.mkdir(inputs['pos_dump_path'])

if inputs['kws_set'] in ['neg_testsets', 'roc']:
inputs['neg_data_path'] = os.path.join(inputs['workspace'],
'neg_data')
if not os.path.exists(inputs['neg_data_path']):
os.mkdir(inputs['neg_data_path'])
else:
shutil.rmtree(inputs['neg_data_path'])
os.mkdir(inputs['neg_data_path'])

inputs['neg_dump_path'] = os.path.join(inputs['workspace'],
'neg_dump')
if not os.path.exists(inputs['neg_dump_path']):
os.mkdir(inputs['neg_dump_path'])
else:
shutil.rmtree(inputs['neg_dump_path'])
os.mkdir(inputs['neg_dump_path'])

+ 8
- 0
modelscope/utils/constant.py View File

@@ -218,3 +218,11 @@ class TrainerStages:
after_val_iter = 'after_val_iter' after_val_iter = 'after_val_iter'
after_val_epoch = 'after_val_epoch' after_val_epoch = 'after_val_epoch'
after_run = 'after_run' after_run = 'after_run'


class ColorCodes:
MAGENTA = '\033[95m'
YELLOW = '\033[93m'
GREEN = '\033[92m'
RED = '\033[91m'
END = '\033[0m'

+ 20
- 0
modelscope/utils/test_utils.py View File

@@ -2,9 +2,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.


import os import os
import tarfile
import unittest import unittest


import numpy as np import numpy as np
import requests
from datasets import Dataset from datasets import Dataset
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE


@@ -42,3 +44,21 @@ def set_test_level(level: int):
def create_dummy_test_dataset(feat, label, num): def create_dummy_test_dataset(feat, label, num):
return MsDataset.from_hf_dataset( return MsDataset.from_hf_dataset(
Dataset.from_dict(dict(feat=[feat] * num, label=[label] * num))) Dataset.from_dict(dict(feat=[feat] * num, label=[label] * num)))


def download_and_untar(fpath, furl, dst) -> str:
if not os.path.exists(fpath):
r = requests.get(furl)
with open(fpath, 'wb') as f:
f.write(r.content)

file_name = os.path.basename(fpath)
root_dir = os.path.dirname(fpath)
target_dir_name = os.path.splitext(os.path.splitext(file_name)[0])[0]
target_dir_path = os.path.join(root_dir, target_dir_name)

# untar the file
t = tarfile.open(fpath)
t.extractall(path=dst)

return target_dir_path

+ 1
- 0
requirements/audio.txt View File

@@ -3,6 +3,7 @@ espnet>=202204
h5py h5py
inflect inflect
keras keras
kwsbp
librosa librosa
lxml lxml
matplotlib matplotlib


+ 194
- 296
tests/pipelines/test_key_word_spotting.py View File

@@ -3,14 +3,16 @@ import os
import shutil import shutil
import tarfile import tarfile
import unittest import unittest
from typing import Any, Dict, List, Union


import requests import requests


from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
from modelscope.utils.constant import ColorCodes, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import download_and_untar, test_level


KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp'
logger = get_logger()


POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'
@@ -22,12 +24,102 @@ NEG_TESTSETS_FILE = 'neg_testsets.tar.gz'
NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz' NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz'




def un_tar_gz(fname, dirs):
t = tarfile.open(fname)
t.extractall(path=dirs)


class KeyWordSpottingTest(unittest.TestCase): class KeyWordSpottingTest(unittest.TestCase):
action_info = {
'test_run_with_wav': {
'checking_item': 'kws_list',
'checking_value': '小云小云',
'example': {
'wav_count':
1,
'kws_set':
'wav',
'kws_list': [{
'keyword': '小云小云',
'offset': 5.76,
'length': 9.132938,
'confidence': 0.990368
}]
}
},
'test_run_with_wav_by_customized_keywords': {
'checking_item': 'kws_list',
'checking_value': '播放音乐',
'example': {
'wav_count':
1,
'kws_set':
'wav',
'kws_list': [{
'keyword': '播放音乐',
'offset': 0.87,
'length': 2.158313,
'confidence': 0.646237
}]
}
},
'test_run_with_pos_testsets': {
'checking_item': 'recall',
'example': {
'wav_count': 450,
'kws_set': 'pos_testsets',
'wav_time': 3013.75925,
'keywords': ['小云小云'],
'recall': 0.953333,
'detected_count': 429,
'rejected_count': 21,
'rejected': ['yyy.wav', 'zzz.wav']
}
},
'test_run_with_neg_testsets': {
'checking_item': 'fa_rate',
'example': {
'wav_count':
751,
'kws_set':
'neg_testsets',
'wav_time':
3572.180813,
'keywords': ['小云小云'],
'fa_rate':
0.001332,
'fa_per_hour':
1.007788,
'detected_count':
1,
'rejected_count':
750,
'detected': [{
'6.wav': {
'confidence': '0.321170',
'keyword': '小云小云'
}
}]
}
},
'test_run_with_roc': {
'checking_item': 'keywords',
'checking_value': '小云小云',
'example': {
'kws_set':
'roc',
'keywords': ['小云小云'],
'小云小云': [{
'threshold': 0.0,
'recall': 0.953333,
'fa_per_hour': 1.007788
}, {
'threshold': 0.001,
'recall': 0.953333,
'fa_per_hour': 1.007788
}, {
'threshold': 0.999,
'recall': 0.004444,
'fa_per_hour': 0.0
}]
}
}
}


def setUp(self) -> None: def setUp(self) -> None:
self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun' self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun'
@@ -36,315 +128,121 @@ class KeyWordSpottingTest(unittest.TestCase):
os.mkdir(self.workspace) os.mkdir(self.workspace)


def tearDown(self) -> None: def tearDown(self) -> None:
# remove workspace dir (.tmp)
if os.path.exists(self.workspace): if os.path.exists(self.workspace):
shutil.rmtree(os.path.join(self.workspace), ignore_errors=True)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'wav'

# get wav file
wav_file_path = POS_WAV_FILE

# downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)
shutil.rmtree(self.workspace, ignore_errors=True)


def run_pipeline(self,
model_id: str,
wav_path: Union[List[str], str],
keywords: List[str] = None) -> Dict[str, Any]:
kwsbp_16k_pipline = pipeline( kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)
task=Tasks.auto_speech_recognition, model=model_id)

kws_result = kwsbp_16k_pipline(wav_path=wav_path, keywords=keywords)

return kws_result

def print_error(self, functions: str, result: Dict[str, Any]) -> None:
logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
+ ColorCodes.END)
logger.error(ColorCodes.MAGENTA + functions
+ ' correct result example: ' + ColorCodes.YELLOW
+ str(self.action_info[functions]['example'])
+ ColorCodes.END)

raise ValueError('kws result is mismatched')

def check_and_print_result(self, functions: str,
result: Dict[str, Any]) -> None:
if result.__contains__(self.action_info[functions]['checking_item']):
checking_item = result[self.action_info[functions]
['checking_item']]
if functions == 'test_run_with_roc':
if checking_item[0] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

elif functions == 'test_run_with_wav':
if checking_item[0]['keyword'] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

elif functions == 'test_run_with_wav_by_customized_keywords':
if checking_item[0]['keyword'] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
+ ColorCodes.END)
if functions == 'test_run_with_roc':
find_keyword = result['keywords'][0]
keyword_list = result[find_keyword]
for item in iter(keyword_list):
threshold: float = item['threshold']
recall: float = item['recall']
fa_per_hour: float = item['fa_per_hour']
logger.info(ColorCodes.YELLOW + ' threshold:'
+ str(threshold) + ' recall:' + str(recall)
+ ' fa_per_hour:' + str(fa_per_hour)
+ ColorCodes.END)
else:
logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)
else:
self.print_error(functions, result)


kws_result = kwsbp_16k_pipline(
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('detected'))
"""
kws result json format example:
{
'wav_count': 1,
'kws_set': 'wav',
'wav_time': 9.132938,
'keywords': ['小云小云'],
'detected': True,
'confidence': 0.990368
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_wav keywords: ', kws_result['keywords'])
print('test_run_with_wav confidence: ', kws_result['confidence'])
print('test_run_with_wav detected result: ', kws_result['detected'])
print('test_run_with_wav wave time(seconds): ', kws_result['wav_time'])
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=POS_WAV_FILE)
self.check_and_print_result('test_run_with_wav', kws_result)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self): def test_run_with_wav_by_customized_keywords(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'wav'

# get wav file
wav_file_path = BOFANGYINYUE_WAV_FILE

# downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

# customized keyword if you need.
# full settings eg.
# keywords = [
# {'keyword':'你好电视', 'threshold': 0.008},
# {'keyword':'播放音乐', 'threshold': 0.008}
# ]
keywords = [{'keyword': '播放音乐'}] keywords = [{'keyword': '播放音乐'}]


kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition,
model=self.model_id,
kws_result = self.run_pipeline(
model_id=self.model_id,
wav_path=BOFANGYINYUE_WAV_FILE,
keywords=keywords) keywords=keywords)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('detected'))
"""
kws result json format example:
{
'wav_count': 1,
'kws_set': 'wav',
'wav_time': 9.132938,
'keywords': ['播放音乐'],
'detected': True,
'confidence': 0.660368
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_wav_by_customized_keywords keywords: ',
kws_result['keywords'])
print('test_run_with_wav_by_customized_keywords confidence: ',
kws_result['confidence'])
print('test_run_with_wav_by_customized_keywords detected result: ',
kws_result['detected'])
print('test_run_with_wav_by_customized_keywords wave time(seconds): ',
kws_result['wav_time'])
self.check_and_print_result('test_run_with_wav_by_customized_keywords',
kws_result)


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self): def test_run_with_pos_testsets(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'pos_testsets'

# downloading pos_testsets file
testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(POS_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)
wav_file_path = download_and_untar(
os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
self.workspace)
wav_path = [wav_file_path, None]


testsets_dir_name = os.path.splitext(
os.path.basename(POS_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# wav_file_path = <cwd>/.tmp_pos_testsets/pos_testsets/
wav_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the pos_testsets file
if not os.path.exists(wav_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('recall'))
"""
kws result json format example:
{
'wav_count': 450,
'kws_set': 'pos_testsets',
'wav_time': 3013.759254,
'keywords': ["小云小云"],
'recall': 0.953333,
'detected_count': 429,
'rejected_count': 21,
'rejected': [
'yyy.wav',
'zzz.wav',
......
]
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_pos_testsets keywords: ',
kws_result['keywords'])
print('test_run_with_pos_testsets recall: ', kws_result['recall'])
print('test_run_with_pos_testsets wave time(seconds): ',
kws_result['wav_time'])
kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_pos_testsets', kws_result)


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_neg_testsets(self): def test_run_with_neg_testsets(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'neg_testsets'

# downloading neg_testsets file
testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(NEG_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(NEG_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# wav_file_path = <cwd>/.tmp_neg_testsets/neg_testsets/
wav_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the neg_testsets file
if not os.path.exists(wav_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)
wav_file_path = download_and_untar(
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
self.workspace)
wav_path = [None, wav_file_path]


kws_result = kwsbp_16k_pipline(
kws_type=kws_set,
wav_path=[None, wav_file_path],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('fa_rate'))
"""
kws result json format example:
{
'wav_count': 751,
'kws_set': 'neg_testsets',
'wav_time': 3572.180812,
'keywords': ['小云小云'],
'fa_rate': 0.001332,
'fa_per_hour': 1.007788,
'detected_count': 1,
'rejected_count': 750,
'detected': [
{
'6.wav': {
'confidence': '0.321170'
}
}
]
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_neg_testsets keywords: ',
kws_result['keywords'])
print('test_run_with_neg_testsets fa rate: ', kws_result['fa_rate'])
print('test_run_with_neg_testsets fa per hour: ',
kws_result['fa_per_hour'])
print('test_run_with_neg_testsets wave time(seconds): ',
kws_result['wav_time'])
kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_neg_testsets', kws_result)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_roc(self): def test_run_with_roc(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'roc'

# downloading neg_testsets file
testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(NEG_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(NEG_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# neg_file_path = <workspace>/.tmp_roc/neg_testsets/
neg_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the neg_testsets file
if not os.path.exists(neg_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading pos_testsets file
testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(POS_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(POS_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# pos_file_path = <workspace>/.tmp_roc/pos_testsets/
pos_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the pos_testsets file
if not os.path.exists(pos_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set,
wav_path=[pos_file_path, neg_file_path],
workspace=self.workspace)
"""
kws result json format example:
{
'kws_set': 'roc',
'keywords': ['小云小云'],
'小云小云': [
{'threshold': 0.0, 'recall': 0.953333, 'fa_per_hour': 1.007788},
{'threshold': 0.001, 'recall': 0.953333, 'fa_per_hour': 1.007788},
......
{'threshold': 0.999, 'recall': 0.004444, 'fa_per_hour': 0.0}
]
}
"""
if kws_result.__contains__('keywords'):
find_keyword = kws_result['keywords'][0]
print('test_run_with_roc keywords: ', find_keyword)
keyword_list = kws_result[find_keyword]
for item in iter(keyword_list):
threshold: float = item['threshold']
recall: float = item['recall']
fa_per_hour: float = item['fa_per_hour']
print(' threshold:', threshold, ' recall:', recall,
' fa_per_hour:', fa_per_hour)
pos_file_path = download_and_untar(
os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
self.workspace)
neg_file_path = download_and_untar(
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
self.workspace)
wav_path = [pos_file_path, neg_file_path]

kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_roc', kws_result)




if __name__ == '__main__': if __name__ == '__main__':


Loading…
Cancel
Save