Browse Source

[to #42322933] simplify kws code, and remove disk-write behavior

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491681
master
shichen.fsc 3 years ago
parent
commit
b3b950e616
7 changed files with 345 additions and 893 deletions
  1. +1
    -1
      modelscope/models/audio/kws/generic_key_word_spotting.py
  2. +80
    -436
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  3. +41
    -160
      modelscope/preprocessors/kws.py
  4. +8
    -0
      modelscope/utils/constant.py
  5. +20
    -0
      modelscope/utils/test_utils.py
  6. +1
    -0
      requirements/audio.txt
  7. +194
    -296
      tests/pipelines/test_key_word_spotting.py

+ 1
- 1
modelscope/models/audio/kws/generic_key_word_spotting.py View File

@@ -19,7 +19,7 @@ class GenericKeyWordSpotting(Model):
Args:
model_dir (str): the model path.
"""
super().__init__(model_dir)
super().__init__(model_dir, *args, **kwargs)
self.model_cfg = {
'model_workspace': model_dir,
'config_path': os.path.join(model_dir, 'config.yaml')


+ 80
- 436
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -1,5 +1,4 @@
import os
import subprocess
from typing import Any, Dict, List, Union

import json
@@ -10,6 +9,9 @@ from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToLists
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['KeyWordSpottingKwsbpPipeline']

@@ -21,41 +23,24 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
"""

def __init__(self,
config_file: str = None,
model: Union[Model, str] = None,
preprocessor: WavToLists = None,
**kwargs):
"""use `model` and `preprocessor` to create a kws pipeline for prediction
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(
config_file=config_file,
model=model,
preprocessor=preprocessor,
**kwargs)

assert model is not None, 'kws model should be provided'

self._preprocessor = preprocessor
self._keywords = None
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

def __call__(self, wav_path: Union[List[str], str],
**kwargs) -> Dict[str, Any]:
if 'keywords' in kwargs.keys():
self._keywords = kwargs['keywords']

def __call__(self,
kws_type: str,
wav_path: List[str],
workspace: str = None) -> Dict[str, Any]:
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets',
'roc'], f'kws_type {kws_type} is invalid'
self.keywords = kwargs['keywords']
else:
self.keywords = None

if self._preprocessor is None:
self._preprocessor = WavToLists(workspace=workspace)
if self.preprocessor is None:
self.preprocessor = WavToLists()

output = self._preprocessor.forward(self.model.forward(), kws_type,
wav_path)
output = self.preprocessor.forward(self.model.forward(), wav_path)
output = self.forward(output)
rst = self.postprocess(output)
return rst
@@ -64,433 +49,92 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
"""Decoding
"""

# will generate kws result into dump/dump.JOB.log
out = self._run_with_kwsbp(inputs)
logger.info(f"Decoding with {inputs['kws_set']} mode ...")

# will generate kws result
out = self.run_with_kwsbp(inputs)

return out

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""process the kws results
"""

pos_result_json = {}
neg_result_json = {}

if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
self._parse_dump_log(pos_result_json, inputs['pos_dump_path'])
if inputs['kws_set'] in ['neg_testsets', 'roc']:
self._parse_dump_log(neg_result_json, inputs['neg_dump_path'])
"""
result_json format example:
{
"wav_count": 450,
"keywords": ["小云小云"],
"wav_time": 3560.999999,
"detected": [
{
"xxx.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
{
"yyy.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
......
],
"detected_count": 429,
"rejected_count": 21,
"rejected": [
"yyy.wav",
"zzz.wav",
......
]
}
Args:
inputs['pos_kws_list'] or inputs['neg_kws_list']:
result_dict format example:
[{
'confidence': 0.9903678297996521,
'filename': 'data/test/audios/kws_xiaoyunxiaoyun.wav',
'keyword': '小云小云',
'offset': 5.760000228881836, # second
'rtf_time': 66, # millisecond
'threshold': 0,
'wav_time': 9.1329375 # second
}]
"""

rst_dict = {'kws_set': inputs['kws_set']}

# parsing the result of wav
if inputs['kws_set'] == 'wav':
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[
'pos_wav_count']
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6)
if pos_result_json['detected_count'] == 1:
rst_dict['keywords'] = pos_result_json['keywords']
rst_dict['detected'] = True
wav_file_name = os.path.basename(inputs['pos_wav_path'])
rst_dict['confidence'] = float(pos_result_json['detected'][0]
[wav_file_name]['confidence'])
else:
rst_dict['detected'] = False

# parsing the result of pos_tests
elif inputs['kws_set'] == 'pos_testsets':
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[
'pos_wav_count']
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6)
if pos_result_json.__contains__('keywords'):
rst_dict['keywords'] = pos_result_json['keywords']

rst_dict['recall'] = round(
pos_result_json['detected_count'] / rst_dict['wav_count'], 6)

if pos_result_json.__contains__('detected_count'):
rst_dict['detected_count'] = pos_result_json['detected_count']
if pos_result_json.__contains__('rejected_count'):
rst_dict['rejected_count'] = pos_result_json['rejected_count']
if pos_result_json.__contains__('rejected'):
rst_dict['rejected'] = pos_result_json['rejected']

# parsing the result of neg_tests
elif inputs['kws_set'] == 'neg_testsets':
rst_dict['wav_count'] = neg_result_json['wav_count'] = inputs[
'neg_wav_count']
rst_dict['wav_time'] = round(neg_result_json['wav_time'], 6)
if neg_result_json.__contains__('keywords'):
rst_dict['keywords'] = neg_result_json['keywords']

rst_dict['fa_rate'] = 0.0
rst_dict['fa_per_hour'] = 0.0

if neg_result_json.__contains__('detected_count'):
rst_dict['detected_count'] = neg_result_json['detected_count']
rst_dict['fa_rate'] = round(
neg_result_json['detected_count'] / rst_dict['wav_count'],
6)
if neg_result_json.__contains__('wav_time'):
rst_dict['fa_per_hour'] = round(
neg_result_json['detected_count']
/ float(neg_result_json['wav_time'] / 3600), 6)

if neg_result_json.__contains__('rejected_count'):
rst_dict['rejected_count'] = neg_result_json['rejected_count']

if neg_result_json.__contains__('detected'):
rst_dict['detected'] = neg_result_json['detected']

# parsing the result of roc
elif inputs['kws_set'] == 'roc':
threshold_start = 0.000
threshold_step = 0.001
threshold_end = 1.000

pos_keywords_list = []
neg_keywords_list = []
if pos_result_json.__contains__('keywords'):
pos_keywords_list = pos_result_json['keywords']
if neg_result_json.__contains__('keywords'):
neg_keywords_list = neg_result_json['keywords']

keywords_list = list(set(pos_keywords_list + neg_keywords_list))

pos_result_json['wav_count'] = inputs['pos_wav_count']
neg_result_json['wav_count'] = inputs['neg_wav_count']

if len(keywords_list) > 0:
rst_dict['keywords'] = keywords_list

for index in range(len(rst_dict['keywords'])):
cur_keyword = rst_dict['keywords'][index]
output_list = self._generate_roc_list(
start=threshold_start,
step=threshold_step,
end=threshold_end,
keyword=cur_keyword,
pos_inputs=pos_result_json,
neg_inputs=neg_result_json)

rst_dict[cur_keyword] = output_list
import kws_util.common
neg_kws_list = None
pos_kws_list = None
if 'pos_kws_list' in inputs:
pos_kws_list = inputs['pos_kws_list']
if 'neg_kws_list' in inputs:
neg_kws_list = inputs['neg_kws_list']
rst_dict = kws_util.common.parsing_kws_result(
kws_type=inputs['kws_set'],
pos_list=pos_kws_list,
neg_list=neg_kws_list)

return rst_dict

def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
opts: str = ''
def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
cmd = {
'sys_dir': inputs['model_workspace'],
'cfg_file': inputs['cfg_file_path'],
'sample_rate': inputs['sample_rate'],
'keyword_custom': ''
}

import kwsbp
import kws_util.common
kws_inference = kwsbp.KwsbpEngine()

# setting customized keywords
keywords_json = self._set_customized_keywords()
if len(keywords_json) > 0:
keywords_json_file = os.path.join(inputs['workspace'],
'keyword_custom.json')
with open(keywords_json_file, 'w') as f:
json.dump(keywords_json, f)
opts = '--keyword-custom ' + keywords_json_file
cmd['customized_keywords'] = kws_util.common.generate_customized_keywords(
self.keywords)

if inputs['kws_set'] == 'roc':
inputs['keyword_grammar_path'] = os.path.join(
inputs['model_workspace'], 'keywords_roc.json')

if inputs['kws_set'] == 'wav':
dump_log_path: str = os.path.join(inputs['pos_dump_path'],
'dump.log')
kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
os.system(kws_cmd)

if inputs['kws_set'] in ['pos_testsets', 'roc']:
data_dir: str = os.listdir(inputs['pos_data_path'])
wav_list = []
for i in data_dir:
suffix = os.path.splitext(os.path.basename(i))[1]
if suffix == '.list':
wav_list.append(os.path.join(inputs['pos_data_path'], i))

j: int = 0
process = []
while j < inputs['pos_num_thread']:
wav_list_path: str = inputs['pos_data_path'] + '/wave.' + str(
j) + '.list'
dump_log_path: str = inputs['pos_dump_path'] + '/dump.' + str(
j) + '.log'

kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1

k: int = 0
while k < len(process):
process[k].wait()
k += 1
if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
cmd['wave_scp'] = inputs['pos_wav_list']
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path']
cmd['num_thread'] = inputs['pos_num_thread']

# run and get inference result
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']),
cmd['sample_rate'],
cmd['num_thread'])
pos_result = json.loads(result)
inputs['pos_kws_list'] = pos_result['kws_list']

if inputs['kws_set'] in ['neg_testsets', 'roc']:
data_dir: str = os.listdir(inputs['neg_data_path'])
wav_list = []
for i in data_dir:
suffix = os.path.splitext(os.path.basename(i))[1]
if suffix == '.list':
wav_list.append(os.path.join(inputs['neg_data_path'], i))

j: int = 0
process = []
while j < inputs['neg_num_thread']:
wav_list_path: str = inputs['neg_data_path'] + '/wave.' + str(
j) + '.list'
dump_log_path: str = inputs['neg_dump_path'] + '/dump.' + str(
j) + '.log'

kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1

k: int = 0
while k < len(process):
process[k].wait()
k += 1
cmd['wave_scp'] = inputs['neg_wav_list']
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path']
cmd['num_thread'] = inputs['neg_num_thread']

# run and get inference result
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']),
cmd['sample_rate'],
cmd['num_thread'])
neg_result = json.loads(result)
inputs['neg_kws_list'] = neg_result['kws_list']

return inputs

def _parse_dump_log(self, result_json: Dict[str, Any],
dump_path: str) -> Dict[str, Any]:
dump_dir = os.listdir(dump_path)
for i in dump_dir:
basename = os.path.splitext(os.path.basename(i))[0]
# find dump.JOB.log
if 'dump' in basename:
with open(
os.path.join(dump_path, i), mode='r',
encoding='utf-8') as file:
while 1:
line = file.readline()
if not line:
break
else:
result_json = self._parse_result_log(
line, result_json)

def _parse_result_log(self, line: str,
result_json: Dict[str, Any]) -> Dict[str, Any]:
# valid info
if '[rejected]' in line or '[detected]' in line:
detected_count = 0
rejected_count = 0

if result_json.__contains__('detected_count'):
detected_count = result_json['detected_count']
if result_json.__contains__('rejected_count'):
rejected_count = result_json['rejected_count']

if '[detected]' in line:
# [detected], fname:/xxx/.tmp_pos_testsets/pos_testsets/33.wav,
# kw:小云小云, confidence:0.965155, time:[4.62-5.10], threshold:0.00,
detected_count += 1
content_list = line.split(', ')
file_name = os.path.basename(content_list[1].split(':')[1])
keyword = content_list[2].split(':')[1]
confidence = content_list[3].split(':')[1]

keywords_list = []
if result_json.__contains__('keywords'):
keywords_list = result_json['keywords']

if keyword not in keywords_list:
keywords_list.append(keyword)
result_json['keywords'] = keywords_list

keyword_item = {}
keyword_item['confidence'] = confidence
keyword_item['keyword'] = keyword
item = {}
item[file_name] = keyword_item

detected_list = []
if result_json.__contains__('detected'):
detected_list = result_json['detected']

detected_list.append(item)
result_json['detected'] = detected_list

elif '[rejected]' in line:
# [rejected], fname:/xxx/.tmp_pos_testsets/pos_testsets/28.wav
rejected_count += 1
content_list = line.split(', ')
file_name = os.path.basename(content_list[1].split(':')[1])
file_name = file_name.strip().replace('\n',
'').replace('\r', '')

rejected_list = []
if result_json.__contains__('rejected'):
rejected_list = result_json['rejected']

rejected_list.append(file_name)
result_json['rejected'] = rejected_list

result_json['detected_count'] = detected_count
result_json['rejected_count'] = rejected_count

elif 'total_proc_time=' in line and 'wav_time=' in line:
# eg: total_proc_time=0.289000(s), wav_time=20.944125(s), kwsbp_rtf=0.013799
wav_total_time = 0
content_list = line.split('), ')
if result_json.__contains__('wav_time'):
wav_total_time = result_json['wav_time']

wav_time_str = content_list[1].split('=')[1]
wav_time_str = wav_time_str.split('(')[0]
wav_time = float(wav_time_str)
wav_time = round(wav_time, 6)

if isinstance(wav_time, float):
wav_total_time += wav_time

result_json['wav_time'] = wav_total_time

return result_json

def _generate_roc_list(self, start: float, step: float, end: float,
keyword: str, pos_inputs: Dict[str, Any],
neg_inputs: Dict[str, Any]) -> Dict[str, Any]:
pos_wav_count = pos_inputs['wav_count']
neg_wav_time = neg_inputs['wav_time']
det_lists = pos_inputs['detected']
fa_lists = neg_inputs['detected']
threshold_cur = start
"""
input det_lists dict
[
{
"xxx.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
{
"yyy.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
]

output dict
[
{
"threshold": 0.000,
"recall": 0.999888,
"fa_per_hour": 1.999999
},
{
"threshold": 0.001,
"recall": 0.999888,
"fa_per_hour": 1.999999
},
]
"""

output = []
while threshold_cur <= end:
det_count = 0
fa_count = 0
for index in range(len(det_lists)):
det_item = det_lists[index]
det_wav_item = det_item.get(next(iter(det_item)))
if det_wav_item['keyword'] == keyword:
confidence = float(det_wav_item['confidence'])
if confidence >= threshold_cur:
det_count += 1

for index in range(len(fa_lists)):
fa_item = fa_lists[index]
fa_wav_item = fa_item.get(next(iter(fa_item)))
if fa_wav_item['keyword'] == keyword:
confidence = float(fa_wav_item['confidence'])
if confidence >= threshold_cur:
fa_count += 1

output_item = {
'threshold': round(threshold_cur, 3),
'recall': round(float(det_count / pos_wav_count), 6),
'fa_per_hour': round(fa_count / float(neg_wav_time / 3600), 6)
}
output.append(output_item)

threshold_cur += step

return output

def _set_customized_keywords(self) -> Dict[str, Any]:
if self._keywords is not None:
word_list_inputs = self._keywords
word_list = []
for i in range(len(word_list_inputs)):
key = word_list_inputs[i]
new_item = {}
if key.__contains__('keyword'):
name = key['keyword']
new_name: str = ''
for n in range(0, len(name), 1):
new_name += name[n]
new_name += ' '
new_name = new_name.strip()
new_item['name'] = new_name

if key.__contains__('threshold'):
threshold1: float = key['threshold']
new_item['threshold1'] = threshold1

word_list.append(new_item)
out = {'word_list': word_list}
return out
else:
return ''

+ 41
- 160
modelscope/preprocessors/kws.py View File

@@ -1,8 +1,5 @@
import os
import shutil
import stat
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

import yaml

@@ -19,48 +16,28 @@ __all__ = ['WavToLists']
Fields.audio, module_name=Preprocessors.wav_to_lists)
class WavToLists(Preprocessor):
"""generate audio lists file from wav

Args:
workspace (str): store temporarily kws intermedium and result
"""

def __init__(self, workspace: str = None):
# the workspace path
if len(workspace) == 0:
self._workspace = os.path.join(os.getcwd(), '.tmp')
else:
self._workspace = workspace

if not os.path.exists(self._workspace):
os.mkdir(self._workspace)
def __init__(self):
pass

def __call__(self,
model: Model = None,
kws_type: str = None,
wav_path: List[str] = None) -> Dict[str, Any]:
def __call__(self, model: Model, wav_path: Union[List[str],
str]) -> Dict[str, Any]:
"""Call functions to load model and wav.

Args:
model (Model): model should be provided
kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc
wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path
wav_path (Union[List[str], str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path
Returns:
Dict[str, Any]: the kws result
"""

assert model is not None, 'preprocess kws model should be provided'
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc'
], f'preprocess kws_type {kws_type} is invalid'
assert wav_path[0] is not None or wav_path[
1] is not None, 'preprocess wav_path is invalid'

self._model = model
out = self.forward(self._model.forward(), kws_type, wav_path)
self.model = model
out = self.forward(self.model.forward(), wav_path)
return out

def forward(self, model: Dict[str, Any], kws_type: str,
wav_path: List[str]) -> Dict[str, Any]:
assert len(kws_type) > 0, 'preprocess kws_type is empty'
def forward(self, model: Dict[str, Any],
wav_path: Union[List[str], str]) -> Dict[str, Any]:
assert len(
model['config_path']) > 0, 'preprocess model[config_path] is empty'
assert os.path.exists(
@@ -68,19 +45,29 @@ class WavToLists(Preprocessor):

inputs = model.copy()

wav_list = [None, None]
if isinstance(wav_path, str):
wav_list[0] = wav_path
else:
wav_list = wav_path

import kws_util.common
kws_type = kws_util.common.type_checking(wav_list)
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc'
], f'preprocess kws_type {kws_type} is invalid'

inputs['kws_set'] = kws_type
inputs['workspace'] = self._workspace
if wav_path[0] is not None:
inputs['pos_wav_path'] = wav_path[0]
if wav_path[1] is not None:
inputs['neg_wav_path'] = wav_path[1]
if wav_list[0] is not None:
inputs['pos_wav_path'] = wav_list[0]
if wav_list[1] is not None:
inputs['neg_wav_path'] = wav_list[1]

out = self._read_config(inputs)
out = self._generate_wav_lists(out)
out = self.read_config(inputs)
out = self.generate_wav_lists(out)

return out

def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""read and parse config.yaml to get all model files
"""

@@ -97,157 +84,51 @@ class WavToLists(Preprocessor):
inputs['keyword_grammar'] = root['keyword_grammar']
inputs['keyword_grammar_path'] = os.path.join(
inputs['model_workspace'], root['keyword_grammar'])
inputs['sample_rate'] = str(root['sample_rate'])
inputs['kws_tool'] = root['kws_tool']

if os.path.exists(
os.path.join(inputs['workspace'], inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join(inputs['workspace'],
inputs['kws_tool'])
elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join('/usr/bin',
inputs['kws_tool'])
elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool'])

assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp'
os.chmod(inputs['kws_tool_path'],
stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH)

self._config_checking(inputs)
inputs['sample_rate'] = root['sample_rate']

return inputs

def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""assemble wav lists
"""
import kws_util.common

if inputs['kws_set'] == 'wav':
inputs['pos_num_thread'] = 1
wave_scp_content: str = inputs['pos_wav_path'] + '\n'

with open(os.path.join(inputs['pos_data_path'], 'wave.list'),
'a') as f:
f.write(wave_scp_content)

wav_list = []
wave_scp_content: str = inputs['pos_wav_path']
wav_list.append(wave_scp_content)
inputs['pos_wav_list'] = wav_list
inputs['pos_wav_count'] = 1
inputs['pos_num_thread'] = 1

if inputs['kws_set'] in ['pos_testsets', 'roc']:
# find all positive wave
wav_list = []
wav_dir = inputs['pos_wav_path']
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)
wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir)
inputs['pos_wav_list'] = wav_list

list_count: int = len(wav_list)
inputs['pos_wav_count'] = list_count

if list_count <= 128:
inputs['pos_num_thread'] = list_count
j: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['pos_data_path'] + '/wave.' + str(
j) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1

else:
inputs['pos_num_thread'] = 128
j: int = 0
k: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['pos_data_path'] + '/wave.' + str(
k) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1
k += 1
if k >= 128:
k = 0

if inputs['kws_set'] in ['neg_testsets', 'roc']:
# find all negative wave
wav_list = []
wav_dir = inputs['neg_wav_path']
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)
wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir)
inputs['neg_wav_list'] = wav_list

list_count: int = len(wav_list)
inputs['neg_wav_count'] = list_count

if list_count <= 128:
inputs['neg_num_thread'] = list_count
j: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['neg_data_path'] + '/wave.' + str(
j) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1

else:
inputs['neg_num_thread'] = 128
j: int = 0
k: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['neg_data_path'] + '/wave.' + str(
k) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1
k += 1
if k >= 128:
k = 0

return inputs

def _recursion_dir_all_wave(self, wav_list,
dir_path: str) -> Dict[str, Any]:
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if file_path.endswith('.wav') or file_path.endswith('.WAV'):
wav_list.append(file_path)
elif os.path.isdir(file_path):
self._recursion_dir_all_wave(wav_list, file_path)

return wav_list

def _config_checking(self, inputs: Dict[str, Any]):

if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
inputs['pos_data_path'] = os.path.join(inputs['workspace'],
'pos_data')
if not os.path.exists(inputs['pos_data_path']):
os.mkdir(inputs['pos_data_path'])
else:
shutil.rmtree(inputs['pos_data_path'])
os.mkdir(inputs['pos_data_path'])

inputs['pos_dump_path'] = os.path.join(inputs['workspace'],
'pos_dump')
if not os.path.exists(inputs['pos_dump_path']):
os.mkdir(inputs['pos_dump_path'])
else:
shutil.rmtree(inputs['pos_dump_path'])
os.mkdir(inputs['pos_dump_path'])

if inputs['kws_set'] in ['neg_testsets', 'roc']:
inputs['neg_data_path'] = os.path.join(inputs['workspace'],
'neg_data')
if not os.path.exists(inputs['neg_data_path']):
os.mkdir(inputs['neg_data_path'])
else:
shutil.rmtree(inputs['neg_data_path'])
os.mkdir(inputs['neg_data_path'])

inputs['neg_dump_path'] = os.path.join(inputs['workspace'],
'neg_dump')
if not os.path.exists(inputs['neg_dump_path']):
os.mkdir(inputs['neg_dump_path'])
else:
shutil.rmtree(inputs['neg_dump_path'])
os.mkdir(inputs['neg_dump_path'])

+ 8
- 0
modelscope/utils/constant.py View File

@@ -218,3 +218,11 @@ class TrainerStages:
after_val_iter = 'after_val_iter'
after_val_epoch = 'after_val_epoch'
after_run = 'after_run'


class ColorCodes:
MAGENTA = '\033[95m'
YELLOW = '\033[93m'
GREEN = '\033[92m'
RED = '\033[91m'
END = '\033[0m'

+ 20
- 0
modelscope/utils/test_utils.py View File

@@ -2,9 +2,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import tarfile
import unittest

import numpy as np
import requests
from datasets import Dataset
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE

@@ -42,3 +44,21 @@ def set_test_level(level: int):
def create_dummy_test_dataset(feat, label, num):
return MsDataset.from_hf_dataset(
Dataset.from_dict(dict(feat=[feat] * num, label=[label] * num)))


def download_and_untar(fpath, furl, dst) -> str:
if not os.path.exists(fpath):
r = requests.get(furl)
with open(fpath, 'wb') as f:
f.write(r.content)

file_name = os.path.basename(fpath)
root_dir = os.path.dirname(fpath)
target_dir_name = os.path.splitext(os.path.splitext(file_name)[0])[0]
target_dir_path = os.path.join(root_dir, target_dir_name)

# untar the file
t = tarfile.open(fpath)
t.extractall(path=dst)

return target_dir_path

+ 1
- 0
requirements/audio.txt View File

@@ -3,6 +3,7 @@ espnet>=202204
h5py
inflect
keras
kwsbp
librosa
lxml
matplotlib


+ 194
- 296
tests/pipelines/test_key_word_spotting.py View File

@@ -3,14 +3,16 @@ import os
import shutil
import tarfile
import unittest
from typing import Any, Dict, List, Union

import requests

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
from modelscope.utils.constant import ColorCodes, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import download_and_untar, test_level

KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp'
logger = get_logger()

POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'
@@ -22,12 +24,102 @@ NEG_TESTSETS_FILE = 'neg_testsets.tar.gz'
NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz'


def un_tar_gz(fname, dirs):
t = tarfile.open(fname)
t.extractall(path=dirs)


class KeyWordSpottingTest(unittest.TestCase):
action_info = {
'test_run_with_wav': {
'checking_item': 'kws_list',
'checking_value': '小云小云',
'example': {
'wav_count':
1,
'kws_set':
'wav',
'kws_list': [{
'keyword': '小云小云',
'offset': 5.76,
'length': 9.132938,
'confidence': 0.990368
}]
}
},
'test_run_with_wav_by_customized_keywords': {
'checking_item': 'kws_list',
'checking_value': '播放音乐',
'example': {
'wav_count':
1,
'kws_set':
'wav',
'kws_list': [{
'keyword': '播放音乐',
'offset': 0.87,
'length': 2.158313,
'confidence': 0.646237
}]
}
},
'test_run_with_pos_testsets': {
'checking_item': 'recall',
'example': {
'wav_count': 450,
'kws_set': 'pos_testsets',
'wav_time': 3013.75925,
'keywords': ['小云小云'],
'recall': 0.953333,
'detected_count': 429,
'rejected_count': 21,
'rejected': ['yyy.wav', 'zzz.wav']
}
},
'test_run_with_neg_testsets': {
'checking_item': 'fa_rate',
'example': {
'wav_count':
751,
'kws_set':
'neg_testsets',
'wav_time':
3572.180813,
'keywords': ['小云小云'],
'fa_rate':
0.001332,
'fa_per_hour':
1.007788,
'detected_count':
1,
'rejected_count':
750,
'detected': [{
'6.wav': {
'confidence': '0.321170',
'keyword': '小云小云'
}
}]
}
},
'test_run_with_roc': {
'checking_item': 'keywords',
'checking_value': '小云小云',
'example': {
'kws_set':
'roc',
'keywords': ['小云小云'],
'小云小云': [{
'threshold': 0.0,
'recall': 0.953333,
'fa_per_hour': 1.007788
}, {
'threshold': 0.001,
'recall': 0.953333,
'fa_per_hour': 1.007788
}, {
'threshold': 0.999,
'recall': 0.004444,
'fa_per_hour': 0.0
}]
}
}
}

def setUp(self) -> None:
self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun'
@@ -36,315 +128,121 @@ class KeyWordSpottingTest(unittest.TestCase):
os.mkdir(self.workspace)

def tearDown(self) -> None:
# remove workspace dir (.tmp)
if os.path.exists(self.workspace):
shutil.rmtree(os.path.join(self.workspace), ignore_errors=True)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'wav'

# get wav file
wav_file_path = POS_WAV_FILE

# downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)
shutil.rmtree(self.workspace, ignore_errors=True)

def run_pipeline(self,
model_id: str,
wav_path: Union[List[str], str],
keywords: List[str] = None) -> Dict[str, Any]:
kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)
task=Tasks.auto_speech_recognition, model=model_id)

kws_result = kwsbp_16k_pipline(wav_path=wav_path, keywords=keywords)

return kws_result

def print_error(self, functions: str, result: Dict[str, Any]) -> None:
logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
+ ColorCodes.END)
logger.error(ColorCodes.MAGENTA + functions
+ ' correct result example: ' + ColorCodes.YELLOW
+ str(self.action_info[functions]['example'])
+ ColorCodes.END)

raise ValueError('kws result is mismatched')

def check_and_print_result(self, functions: str,
result: Dict[str, Any]) -> None:
if result.__contains__(self.action_info[functions]['checking_item']):
checking_item = result[self.action_info[functions]
['checking_item']]
if functions == 'test_run_with_roc':
if checking_item[0] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

elif functions == 'test_run_with_wav':
if checking_item[0]['keyword'] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

elif functions == 'test_run_with_wav_by_customized_keywords':
if checking_item[0]['keyword'] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
+ ColorCodes.END)
if functions == 'test_run_with_roc':
find_keyword = result['keywords'][0]
keyword_list = result[find_keyword]
for item in iter(keyword_list):
threshold: float = item['threshold']
recall: float = item['recall']
fa_per_hour: float = item['fa_per_hour']
logger.info(ColorCodes.YELLOW + ' threshold:'
+ str(threshold) + ' recall:' + str(recall)
+ ' fa_per_hour:' + str(fa_per_hour)
+ ColorCodes.END)
else:
logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)
else:
self.print_error(functions, result)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('detected'))
"""
kws result json format example:
{
'wav_count': 1,
'kws_set': 'wav',
'wav_time': 9.132938,
'keywords': ['小云小云'],
'detected': True,
'confidence': 0.990368
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_wav keywords: ', kws_result['keywords'])
print('test_run_with_wav confidence: ', kws_result['confidence'])
print('test_run_with_wav detected result: ', kws_result['detected'])
print('test_run_with_wav wave time(seconds): ', kws_result['wav_time'])
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=POS_WAV_FILE)
self.check_and_print_result('test_run_with_wav', kws_result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'wav'

# get wav file
wav_file_path = BOFANGYINYUE_WAV_FILE

# downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

# customized keyword if you need.
# full settings eg.
# keywords = [
# {'keyword':'你好电视', 'threshold': 0.008},
# {'keyword':'播放音乐', 'threshold': 0.008}
# ]
keywords = [{'keyword': '播放音乐'}]

kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition,
model=self.model_id,
kws_result = self.run_pipeline(
model_id=self.model_id,
wav_path=BOFANGYINYUE_WAV_FILE,
keywords=keywords)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('detected'))
"""
kws result json format example:
{
'wav_count': 1,
'kws_set': 'wav',
'wav_time': 9.132938,
'keywords': ['播放音乐'],
'detected': True,
'confidence': 0.660368
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_wav_by_customized_keywords keywords: ',
kws_result['keywords'])
print('test_run_with_wav_by_customized_keywords confidence: ',
kws_result['confidence'])
print('test_run_with_wav_by_customized_keywords detected result: ',
kws_result['detected'])
print('test_run_with_wav_by_customized_keywords wave time(seconds): ',
kws_result['wav_time'])
self.check_and_print_result('test_run_with_wav_by_customized_keywords',
kws_result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'pos_testsets'

# downloading pos_testsets file
testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(POS_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)
wav_file_path = download_and_untar(
os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
self.workspace)
wav_path = [wav_file_path, None]

testsets_dir_name = os.path.splitext(
os.path.basename(POS_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# wav_file_path = <cwd>/.tmp_pos_testsets/pos_testsets/
wav_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the pos_testsets file
if not os.path.exists(wav_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('recall'))
"""
kws result json format example:
{
'wav_count': 450,
'kws_set': 'pos_testsets',
'wav_time': 3013.759254,
'keywords': ["小云小云"],
'recall': 0.953333,
'detected_count': 429,
'rejected_count': 21,
'rejected': [
'yyy.wav',
'zzz.wav',
......
]
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_pos_testsets keywords: ',
kws_result['keywords'])
print('test_run_with_pos_testsets recall: ', kws_result['recall'])
print('test_run_with_pos_testsets wave time(seconds): ',
kws_result['wav_time'])
kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_pos_testsets', kws_result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_neg_testsets(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'neg_testsets'

# downloading neg_testsets file
testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(NEG_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(NEG_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# wav_file_path = <cwd>/.tmp_neg_testsets/neg_testsets/
wav_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the neg_testsets file
if not os.path.exists(wav_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)
wav_file_path = download_and_untar(
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
self.workspace)
wav_path = [None, wav_file_path]

kws_result = kwsbp_16k_pipline(
kws_type=kws_set,
wav_path=[None, wav_file_path],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('fa_rate'))
"""
kws result json format example:
{
'wav_count': 751,
'kws_set': 'neg_testsets',
'wav_time': 3572.180812,
'keywords': ['小云小云'],
'fa_rate': 0.001332,
'fa_per_hour': 1.007788,
'detected_count': 1,
'rejected_count': 750,
'detected': [
{
'6.wav': {
'confidence': '0.321170'
}
}
]
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_neg_testsets keywords: ',
kws_result['keywords'])
print('test_run_with_neg_testsets fa rate: ', kws_result['fa_rate'])
print('test_run_with_neg_testsets fa per hour: ',
kws_result['fa_per_hour'])
print('test_run_with_neg_testsets wave time(seconds): ',
kws_result['wav_time'])
kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_neg_testsets', kws_result)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_roc(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'roc'

# downloading neg_testsets file
testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(NEG_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(NEG_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# neg_file_path = <workspace>/.tmp_roc/neg_testsets/
neg_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the neg_testsets file
if not os.path.exists(neg_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading pos_testsets file
testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(POS_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(POS_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# pos_file_path = <workspace>/.tmp_roc/pos_testsets/
pos_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the pos_testsets file
if not os.path.exists(pos_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set,
wav_path=[pos_file_path, neg_file_path],
workspace=self.workspace)
"""
kws result json format example:
{
'kws_set': 'roc',
'keywords': ['小云小云'],
'小云小云': [
{'threshold': 0.0, 'recall': 0.953333, 'fa_per_hour': 1.007788},
{'threshold': 0.001, 'recall': 0.953333, 'fa_per_hour': 1.007788},
......
{'threshold': 0.999, 'recall': 0.004444, 'fa_per_hour': 0.0}
]
}
"""
if kws_result.__contains__('keywords'):
find_keyword = kws_result['keywords'][0]
print('test_run_with_roc keywords: ', find_keyword)
keyword_list = kws_result[find_keyword]
for item in iter(keyword_list):
threshold: float = item['threshold']
recall: float = item['recall']
fa_per_hour: float = item['fa_per_hour']
print(' threshold:', threshold, ' recall:', recall,
' fa_per_hour:', fa_per_hour)
pos_file_path = download_and_untar(
os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
self.workspace)
neg_file_path = download_and_untar(
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
self.workspace)
wav_path = [pos_file_path, neg_file_path]

kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_roc', kws_result)


if __name__ == '__main__':


Loading…
Cancel
Save