diff --git a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py index 11accf0a..aebc6751 100644 --- a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py +++ b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict diff --git a/modelscope/models/audio/kws/generic_key_word_spotting.py b/modelscope/models/audio/kws/generic_key_word_spotting.py index c1b7a0e4..2f70327d 100644 --- a/modelscope/models/audio/kws/generic_key_word_spotting.py +++ b/modelscope/models/audio/kws/generic_key_word_spotting.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index 450a12bb..5555c9e6 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -37,6 +37,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): **kwargs) -> Dict[str, Any]: if 'keywords' in kwargs.keys(): self.keywords = kwargs['keywords'] + if isinstance(self.keywords, str): + word_list = [] + word = {} + word['keyword'] = self.keywords + word_list.append(word) + self.keywords = word_list else: self.keywords = None @@ -96,6 +102,9 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): pos_list=pos_kws_list, neg_list=neg_kws_list) + if 'kws_list' not in rst_dict: + rst_dict['kws_list'] = [] + return rst_dict def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/preprocessors/asr.py b/modelscope/preprocessors/asr.py index d58383d7..facaa132 100644 --- a/modelscope/preprocessors/asr.py +++ b/modelscope/preprocessors/asr.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict, List, Union diff --git a/modelscope/preprocessors/kws.py b/modelscope/preprocessors/kws.py index 9c370ed5..6f09d545 100644 --- a/modelscope/preprocessors/kws.py +++ b/modelscope/preprocessors/kws.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict, List, Union diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py index 91f9f566..f31d212b 100644 --- a/tests/pipelines/test_key_word_spotting.py +++ b/tests/pipelines/test_key_word_spotting.py @@ -245,7 +245,7 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_wav_by_customized_keywords(self): - keywords = [{'keyword': '播放音乐'}] + keywords = '播放音乐' kws_result = self.run_pipeline( model_id=self.model_id,