diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index 1f31766a..866b8d0b 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -8,6 +8,8 @@ from modelscope.models import Model from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import WavToLists +from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, + load_bytes_from_url) from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -40,6 +42,13 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): if self.preprocessor is None: self.preprocessor = WavToLists() + if isinstance(audio_in, str): + # load pcm data from url if audio_in is url str + audio_in = load_bytes_from_url(audio_in) + elif isinstance(audio_in, bytes): + # load pcm data from wav data if audio_in is wave format + audio_in = extract_pcm_from_wav(audio_in) + output = self.preprocessor.forward(self.model.forward(), audio_in) output = self.forward(output) rst = self.postprocess(output) diff --git a/modelscope/utils/audio/audio_utils.py b/modelscope/utils/audio/audio_utils.py index c93e0102..4c2c45cc 100644 --- a/modelscope/utils/audio/audio_utils.py +++ b/modelscope/utils/audio/audio_utils.py @@ -42,23 +42,28 @@ def extract_pcm_from_wav(wav: bytes) -> bytes: if len(data) > 44: frame_len = 44 file_len = len(data) - header_fields = {} - header_fields['ChunkID'] = str(data[0:4], 'UTF-8') - header_fields['Format'] = str(data[8:12], 'UTF-8') - header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') - if header_fields['ChunkID'] == 'RIFF' and header_fields[ - 'Format'] == 'WAVE' and header_fields['Subchunk1ID'] == 'fmt ': - header_fields['SubChunk1Size'] = struct.unpack('= 0, 'skip test in current test level') + def test_run_with_url(self): + kws_result = self.run_pipeline( + model_id=self.model_id, audio_in=URL_FILE) + self.check_result('test_run_with_url', kws_result) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_pos_testsets(self): wav_file_path = download_and_untar(