shichen.fsc 3 years ago
parent
commit
4ec20761ce
2 changed files with 34 additions and 71 deletions
  1. +13
    -6
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  2. +21
    -65
      tests/pipelines/test_key_word_spotting.py

+ 13
- 6
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -3,7 +3,7 @@ import os
import shutil import shutil
import stat import stat
import subprocess import subprocess
from typing import Any, Dict, List
from typing import Any, Dict, List, Union


import json import json


@@ -25,19 +25,21 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):


def __init__(self, def __init__(self,
config_file: str = None, config_file: str = None,
model: Model = None,
model: Union[Model, str] = None,
preprocessor: WavToLists = None, preprocessor: WavToLists = None,
**kwargs): **kwargs):
"""use `model` and `preprocessor` to create a kws pipeline for prediction """use `model` and `preprocessor` to create a kws pipeline for prediction
""" """


model = model if isinstance(model,
Model) else Model.from_pretrained(model)

super().__init__( super().__init__(
config_file=config_file, config_file=config_file,
model=model, model=model,
preprocessor=preprocessor, preprocessor=preprocessor,
**kwargs) **kwargs)
assert model is not None, 'kws model should be provided' assert model is not None, 'kws model should be provided'
assert preprocessor is not None, 'preprocessor is none'


self._preprocessor = preprocessor self._preprocessor = preprocessor
self._model = model self._model = model
@@ -45,12 +47,17 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):


if 'keywords' in kwargs.keys(): if 'keywords' in kwargs.keys():
self._keywords = kwargs['keywords'] self._keywords = kwargs['keywords']
print('self._keywords len: ', len(self._keywords))
print('self._keywords: ', self._keywords)


def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]:
def __call__(self,
kws_type: str,
wav_path: List[str],
workspace: str = None) -> Dict[str, Any]:
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', assert kws_type in ['wav', 'pos_testsets', 'neg_testsets',
'roc'], f'kws_type {kws_type} is invalid' 'roc'], f'kws_type {kws_type} is invalid'

if self._preprocessor is None:
self._preprocessor = WavToLists(workspace=workspace)

output = self._preprocessor.forward(self._model.forward(), kws_type, output = self._preprocessor.forward(self._model.forward(), kws_type,
wav_path) wav_path)
output = self.forward(output) output = self.forward(output)


+ 21
- 65
tests/pipelines/test_key_word_spotting.py View File

@@ -40,7 +40,7 @@ class KeyWordSpottingTest(unittest.TestCase):


def tearDown(self) -> None: def tearDown(self) -> None:
if os.path.exists(self.workspace): if os.path.exists(self.workspace):
shutil.rmtree(self.workspace)
shutil.rmtree(os.path.join(self.workspace), ignore_errors=True)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self): def test_run_with_wav(self):
@@ -57,23 +57,14 @@ class KeyWordSpottingTest(unittest.TestCase):
with open(kwsbp_file_path, 'wb') as f: with open(kwsbp_file_path, 'wb') as f:
f.write(r.content) f.write(r.content)


model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline( kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
task=Tasks.key_word_spotting, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None) self.assertTrue(kwsbp_16k_pipline is not None)


kws_result = kwsbp_16k_pipline( kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[wav_file_path, None])
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('detected')) self.assertTrue(kws_result.__contains__('detected'))
""" """
kws result json format example: kws result json format example:
@@ -107,14 +98,6 @@ class KeyWordSpottingTest(unittest.TestCase):
with open(kwsbp_file_path, 'wb') as f: with open(kwsbp_file_path, 'wb') as f:
f.write(r.content) f.write(r.content)


model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

# customized keyword if you need. # customized keyword if you need.
# full settings eg. # full settings eg.
# keywords = [ # keywords = [
@@ -125,14 +108,14 @@ class KeyWordSpottingTest(unittest.TestCase):


kwsbp_16k_pipline = pipeline( kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting, task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor,
model=self.model_id,
keywords=keywords) keywords=keywords)
self.assertTrue(kwsbp_16k_pipline is not None) self.assertTrue(kwsbp_16k_pipline is not None)


kws_result = kwsbp_16k_pipline( kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[wav_file_path, None])
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('detected')) self.assertTrue(kws_result.__contains__('detected'))
""" """
kws result json format example: kws result json format example:
@@ -185,23 +168,14 @@ class KeyWordSpottingTest(unittest.TestCase):
with open(kwsbp_file_path, 'wb') as f: with open(kwsbp_file_path, 'wb') as f:
f.write(r.content) f.write(r.content)


model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline( kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
task=Tasks.key_word_spotting, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None) self.assertTrue(kwsbp_16k_pipline is not None)


kws_result = kwsbp_16k_pipline( kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[wav_file_path, None])
kws_type=kws_set,
wav_path=[wav_file_path, None],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('recall')) self.assertTrue(kws_result.__contains__('recall'))
""" """
kws result json format example: kws result json format example:
@@ -257,23 +231,14 @@ class KeyWordSpottingTest(unittest.TestCase):
with open(kwsbp_file_path, 'wb') as f: with open(kwsbp_file_path, 'wb') as f:
f.write(r.content) f.write(r.content)


model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline( kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
task=Tasks.key_word_spotting, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None) self.assertTrue(kwsbp_16k_pipline is not None)


kws_result = kwsbp_16k_pipline( kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[None, wav_file_path])
kws_type=kws_set,
wav_path=[None, wav_file_path],
workspace=self.workspace)
self.assertTrue(kws_result.__contains__('fa_rate')) self.assertTrue(kws_result.__contains__('fa_rate'))
""" """
kws result json format example: kws result json format example:
@@ -352,23 +317,14 @@ class KeyWordSpottingTest(unittest.TestCase):
with open(kwsbp_file_path, 'wb') as f: with open(kwsbp_file_path, 'wb') as f:
f.write(r.content) f.write(r.content)


model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline( kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
task=Tasks.key_word_spotting, model=self.model_id)
self.assertTrue(kwsbp_16k_pipline is not None) self.assertTrue(kwsbp_16k_pipline is not None)


kws_result = kwsbp_16k_pipline( kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[pos_file_path, neg_file_path])
kws_type=kws_set,
wav_path=[pos_file_path, neg_file_path],
workspace=self.workspace)
""" """
kws result json format example: kws result json format example:
{ {


Loading…
Cancel
Save