Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10412829master
@@ -1,3 +1,5 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | import os | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
@@ -1,3 +1,5 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | import os | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
@@ -37,6 +37,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
**kwargs) -> Dict[str, Any]: | **kwargs) -> Dict[str, Any]: | ||||
if 'keywords' in kwargs.keys(): | if 'keywords' in kwargs.keys(): | ||||
self.keywords = kwargs['keywords'] | self.keywords = kwargs['keywords'] | ||||
if isinstance(self.keywords, str): | |||||
word_list = [] | |||||
word = {} | |||||
word['keyword'] = self.keywords | |||||
word_list.append(word) | |||||
self.keywords = word_list | |||||
else: | else: | ||||
self.keywords = None | self.keywords = None | ||||
@@ -96,6 +102,9 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
pos_list=pos_kws_list, | pos_list=pos_kws_list, | ||||
neg_list=neg_kws_list) | neg_list=neg_kws_list) | ||||
if 'kws_list' not in rst_dict: | |||||
rst_dict['kws_list'] = [] | |||||
return rst_dict | return rst_dict | ||||
def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
@@ -1,3 +1,5 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | import os | ||||
from typing import Any, Dict, List, Union | from typing import Any, Dict, List, Union | ||||
@@ -1,3 +1,5 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | import os | ||||
from typing import Any, Dict, List, Union | from typing import Any, Dict, List, Union | ||||
@@ -245,7 +245,7 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
def test_run_with_wav_by_customized_keywords(self): | def test_run_with_wav_by_customized_keywords(self): | ||||
keywords = [{'keyword': '播放音乐'}] | |||||
keywords = '播放音乐' | |||||
kws_result = self.run_pipeline( | kws_result = self.run_pipeline( | ||||
model_id=self.model_id, | model_id=self.model_id, | ||||