shichen.fsc 3 years ago
parent
commit
4ec20761ce
2 changed files with 34 additions and 71 deletions
  1. +13
    -6
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  2. +21
    -65
      tests/pipelines/test_key_word_spotting.py

+ 13
- 6
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -3,7 +3,7 @@ import os
import shutil
import stat
import subprocess
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

import json

@@ -25,19 +25,21 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):

def __init__(self,
config_file: str = None,
model: Model = None,
model: Union[Model, str] = None,
preprocessor: WavToLists = None,
**kwargs):
"""Use `model` and `preprocessor` to create a KWS pipeline for prediction.
"""

model = model if isinstance(model,
Model) else Model.from_pretrained(model)

super().__init__(
config_file=config_file,
model=model,
preprocessor=preprocessor,
**kwargs)
assert model is not None, 'kws model should be provided'
assert preprocessor is not None, 'preprocessor is none'

self._preprocessor = preprocessor
self._model = model
@@ -45,12 +47,17 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):

if 'keywords' in kwargs.keys():
self._keywords = kwargs['keywords']
print('self._keywords len: ', len(self._keywords))
print('self._keywords: ', self._keywords)

def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]:
def __call__(self,
kws_type: str,
wav_path: List[str],
workspace: str = None) -> Dict[str, Any]:
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets',
'roc'], f'kws_type {kws_type} is invalid'

if self._preprocessor is None:
self._preprocessor = WavToLists(workspace=workspace)

output = self._preprocessor.forward(self._model.forward(), kws_type,
wav_path)
output = self.forward(output)


+ 21
- 65
tests/pipelines/test_key_word_spotting.py View File

@@ -40,7 +40,7 @@ class KeyWordSpottingTest(unittest.TestCase):

def tearDown(self) -> None:
if os.path.exists(self.workspace):
shutil.rmtree(self.workspace)
shutil.rmtree(os.path.join(self.workspace), ignore_errors=True)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
@@ -57,23 +57,14 @@ class KeyWordSpottingTest(unittest.TestCase):
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
task=Tasks.key_word_spotting, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[wav_file_path, None])
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('detected'))
"""
kws result json format example:
@@ -107,14 +98,6 @@ class KeyWordSpottingTest(unittest.TestCase):
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

# Customize the keywords if needed.
# Full settings, e.g.:
# keywords = [
@@ -125,14 +108,14 @@ class KeyWordSpottingTest(unittest.TestCase):

kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor,
model=self.model_id,
keywords=keywords)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[wav_file_path, None])
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('detected'))
"""
kws result json format example:
@@ -185,23 +168,14 @@ class KeyWordSpottingTest(unittest.TestCase):
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
task=Tasks.key_word_spotting, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[wav_file_path, None])
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('recall'))
"""
kws result json format example:
@@ -257,23 +231,14 @@ class KeyWordSpottingTest(unittest.TestCase):
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
task=Tasks.key_word_spotting, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[None, wav_file_path])
kws_type=kws_set,
wav_path=[None, wav_file_path],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('fa_rate'))
"""
kws result json format example:
@@ -352,23 +317,14 @@ class KeyWordSpottingTest(unittest.TestCase):
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
task=Tasks.key_word_spotting, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[pos_file_path, neg_file_path])
kws_type=kws_set,
wav_path=[pos_file_path, neg_file_path],
workspace=self.workspace)
"""
kws result json format example:
{


Loading…
Cancel
Save