diff --git a/data/test/audios/1ch_nihaomiya.wav b/data/test/audios/1ch_nihaomiya.wav new file mode 100644 index 00000000..4618d412 --- /dev/null +++ b/data/test/audios/1ch_nihaomiya.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f7f5a0a4efca1e83463cb44460c66b56fb7cd673eb6da37924637bc05ef758d +size 1440044 diff --git a/modelscope/pipelines/audio/kws_farfield_pipeline.py b/modelscope/pipelines/audio/kws_farfield_pipeline.py index 62f58fee..e2f618fa 100644 --- a/modelscope/pipelines/audio/kws_farfield_pipeline.py +++ b/modelscope/pipelines/audio/kws_farfield_pipeline.py @@ -4,6 +4,9 @@ import io import wave from typing import Any, Dict +import numpy +import soundfile as sf + from modelscope.fileio import File from modelscope.metainfo import Pipelines from modelscope.outputs import OutputKeys @@ -37,7 +40,6 @@ class KWSFarfieldPipeline(Pipeline): self.model.eval() frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH self._nframe = self.model.size_in // frame_size - self.frame_count = 0 def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: if isinstance(inputs, bytes): @@ -54,35 +56,36 @@ class KWSFarfieldPipeline(Pipeline): input_file = inputs['input_file'] if isinstance(input_file, str): input_file = File.read(input_file) - if isinstance(input_file, bytes): - input_file = io.BytesIO(input_file) - self.frame_count = 0 + frames, samplerate = sf.read(io.BytesIO(input_file), dtype='int16') + if len(frames.shape) == 1: + frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1) + kws_list = [] - with wave.open(input_file, 'rb') as fin: - if 'output_file' in inputs: - with wave.open(inputs['output_file'], 'wb') as fout: - fout.setframerate(self.SAMPLE_RATE) - fout.setnchannels(self.OUTPUT_CHANNELS) - fout.setsampwidth(self.SAMPLE_WIDTH) - self._process(fin, kws_list, fout) - else: - self._process(fin, kws_list) + if 'output_file' in inputs: + with wave.open(inputs['output_file'], 'wb') as fout: + fout.setframerate(self.SAMPLE_RATE) + fout.setnchannels(self.OUTPUT_CHANNELS) + fout.setsampwidth(self.SAMPLE_WIDTH) + self._process(frames, kws_list, fout) + else: + self._process(frames, kws_list) return {OutputKeys.KWS_LIST: kws_list} def _process(self, - fin: wave.Wave_read, + frames: numpy.ndarray, kws_list, fout: wave.Wave_write = None): - data = fin.readframes(self._nframe) - while len(data) >= self.model.size_in: - self.frame_count += self._nframe + for start_index in range(0, frames.shape[0], self._nframe): + end_index = start_index + self._nframe + if end_index > frames.shape[0]: + end_index = frames.shape[0] + data = frames[start_index:end_index, :].tobytes() result = self.model.forward_decode(data) if fout: fout.writeframes(result['pcm']) if 'kws' in result: - result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE + result['kws']['offset'] += start_index / self.SAMPLE_RATE kws_list.append(result['kws']) - data = fin.readframes(self._nframe) def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: return inputs diff --git a/tests/pipelines/test_key_word_spotting_farfield.py b/tests/pipelines/test_key_word_spotting_farfield.py index fea7afd7..f8c167de 100644 --- a/tests/pipelines/test_key_word_spotting_farfield.py +++ b/tests/pipelines/test_key_word_spotting_farfield.py @@ -8,6 +8,7 @@ from modelscope.utils.constant import Tasks from modelscope.utils.test_utils import test_level TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav' +TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav' TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \ 'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \ '?Revision=master&FilePath=examples/3ch_nihaomiya.wav' @@ -26,6 +27,16 @@ class KWSFarfieldTest(unittest.TestCase): self.assertEqual(len(result['kws_list']), 5) print(result['kws_list'][-1]) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_mono(self): + kws = pipeline(Tasks.keyword_spotting, model=self.model_id) + inputs = { + 'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE_MONO) + } + result = kws(inputs) + self.assertEqual(len(result['kws_list']), 5) + print(result['kws_list'][-1]) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_url(self): kws = pipeline(Tasks.keyword_spotting, model=self.model_id)