Browse Source

[to #42322933] Fix bug in KWS when setting customized keyword

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10412829
master
shichen.fsc yingda.chen 3 years ago
parent
commit
542c4ce1b3
6 changed files with 18 additions and 1 deletions
  1. +2
    -0
      modelscope/models/audio/asr/generic_automatic_speech_recognition.py
  2. +2
    -0
      modelscope/models/audio/kws/generic_key_word_spotting.py
  3. +9
    -0
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  4. +2
    -0
      modelscope/preprocessors/asr.py
  5. +2
    -0
      modelscope/preprocessors/kws.py
  6. +1
    -1
      tests/pipelines/test_key_word_spotting.py

+ 2
- 0
modelscope/models/audio/asr/generic_automatic_speech_recognition.py View File

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict



+ 2
- 0
modelscope/models/audio/kws/generic_key_word_spotting.py View File

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict



+ 9
- 0
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -37,6 +37,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
**kwargs) -> Dict[str, Any]:
if 'keywords' in kwargs.keys():
self.keywords = kwargs['keywords']
if isinstance(self.keywords, str):
word_list = []
word = {}
word['keyword'] = self.keywords
word_list.append(word)
self.keywords = word_list
else:
self.keywords = None

@@ -96,6 +102,9 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
pos_list=pos_kws_list,
neg_list=neg_kws_list)

if 'kws_list' not in rst_dict:
rst_dict['kws_list'] = []

return rst_dict

def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:


+ 2
- 0
modelscope/preprocessors/asr.py View File

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict, List, Union



+ 2
- 0
modelscope/preprocessors/kws.py View File

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict, List, Union



+ 1
- 1
tests/pipelines/test_key_word_spotting.py View File

@@ -245,7 +245,7 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self):
keywords = [{'keyword': '播放音乐'}]
keywords = '播放音乐'

kws_result = self.run_pipeline(
model_id=self.model_id,


Loading…
Cancel
Save