
[to #42322933] feat: far field KWS accept mono audio for online demo

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10211100
master
bin.xue yingda.chen 3 years ago
parent commit 470a1989bc
3 changed files with 36 additions and 19 deletions
  1. data/test/audios/1ch_nihaomiya.wav (+3, -0)
  2. modelscope/pipelines/audio/kws_farfield_pipeline.py (+22, -19)
  3. tests/pipelines/test_key_word_spotting_farfield.py (+11, -0)

data/test/audios/1ch_nihaomiya.wav (+3, -0)

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f7f5a0a4efca1e83463cb44460c66b56fb7cd673eb6da37924637bc05ef758d
+size 1440044
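
The new fixture is stored via Git LFS, so only the pointer appears in the diff. Not part of the commit, but for context: a mono fixture like this could plausibly be derived from the existing 3-channel test clip with soundfile, e.g. by keeping a single microphone channel (the channel index and output path below are assumptions):

# Hypothetical fixture-generation sketch (not in this commit): keep one
# microphone channel from the 3-channel test clip and save it as a mono
# 16-bit PCM WAV. Using channel 0 as the mic channel is an assumption.
import soundfile as sf

frames, samplerate = sf.read('data/test/audios/3ch_nihaomiya.wav', dtype='int16')
sf.write('data/test/audios/1ch_nihaomiya.wav', frames[:, 0], samplerate,
         subtype='PCM_16')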

modelscope/pipelines/audio/kws_farfield_pipeline.py (+22, -19)

@@ -4,6 +4,9 @@ import io
 import wave
 from typing import Any, Dict
 
+import numpy
+import soundfile as sf
+
 from modelscope.fileio import File
 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
@@ -37,7 +40,6 @@ class KWSFarfieldPipeline(Pipeline):
         self.model.eval()
         frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH
         self._nframe = self.model.size_in // frame_size
-        self.frame_count = 0
 
     def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
         if isinstance(inputs, bytes):
@@ -54,35 +56,36 @@ class KWSFarfieldPipeline(Pipeline):
             input_file = inputs['input_file']
             if isinstance(input_file, str):
                 input_file = File.read(input_file)
-        if isinstance(input_file, bytes):
-            input_file = io.BytesIO(input_file)
-        self.frame_count = 0
+        frames, samplerate = sf.read(io.BytesIO(input_file), dtype='int16')
+        if len(frames.shape) == 1:
+            frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1)
+
         kws_list = []
-        with wave.open(input_file, 'rb') as fin:
-            if 'output_file' in inputs:
-                with wave.open(inputs['output_file'], 'wb') as fout:
-                    fout.setframerate(self.SAMPLE_RATE)
-                    fout.setnchannels(self.OUTPUT_CHANNELS)
-                    fout.setsampwidth(self.SAMPLE_WIDTH)
-                    self._process(fin, kws_list, fout)
-            else:
-                self._process(fin, kws_list)
+        if 'output_file' in inputs:
+            with wave.open(inputs['output_file'], 'wb') as fout:
+                fout.setframerate(self.SAMPLE_RATE)
+                fout.setnchannels(self.OUTPUT_CHANNELS)
+                fout.setsampwidth(self.SAMPLE_WIDTH)
+                self._process(frames, kws_list, fout)
+        else:
+            self._process(frames, kws_list)
         return {OutputKeys.KWS_LIST: kws_list}
 
     def _process(self,
-                 fin: wave.Wave_read,
+                 frames: numpy.ndarray,
                  kws_list,
                  fout: wave.Wave_write = None):
-        data = fin.readframes(self._nframe)
-        while len(data) >= self.model.size_in:
-            self.frame_count += self._nframe
+        for start_index in range(0, frames.shape[0], self._nframe):
+            end_index = start_index + self._nframe
+            if end_index > frames.shape[0]:
+                end_index = frames.shape[0]
+            data = frames[start_index:end_index, :].tobytes()
             result = self.model.forward_decode(data)
            if fout:
                 fout.writeframes(result['pcm'])
             if 'kws' in result:
-                result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE
+                result['kws']['offset'] += start_index / self.SAMPLE_RATE
                 kws_list.append(result['kws'])
-            data = fin.readframes(self._nframe)
 
     def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         return inputs
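
The core of the change: preprocess now decodes the whole input with soundfile instead of streaming it through wave, and a mono signal is expanded to the three-channel layout the far-field model expects by duplicating the microphone signal and zero-filling the third channel. A tiny illustration of that expansion (not part of the commit; sample values are made up):

# Illustration only: mono -> (n, 3) expansion as done in preprocess above.
import numpy

mono = numpy.array([100, -200, 300], dtype='int16')        # shape (3,)
tri = numpy.stack((mono, mono, numpy.zeros_like(mono)), 1)  # shape (3, 3)
print(tri)
# [[ 100  100    0]
#  [-200 -200    0]
#  [ 300  300    0]]
print(tri.tobytes())  # interleaved int16 frames, as fed to forward_decode

_process then walks the (n_samples, 3) array in steps of self._nframe and converts each slice to interleaved bytes for model.forward_decode, which lets the keyword offset be computed directly from start_index / self.SAMPLE_RATE instead of a running frame counter.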

tests/pipelines/test_key_word_spotting_farfield.py (+11, -0)

@@ -8,6 +8,7 @@ from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level
 
 TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
+TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav'
 TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
                   'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \
                   '?Revision=master&FilePath=examples/3ch_nihaomiya.wav'
@@ -26,6 +27,16 @@ class KWSFarfieldTest(unittest.TestCase):
         self.assertEqual(len(result['kws_list']), 5)
         print(result['kws_list'][-1])
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_mono(self):
+        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
+        inputs = {
+            'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE_MONO)
+        }
+        result = kws(inputs)
+        self.assertEqual(len(result['kws_list']), 5)
+        print(result['kws_list'][-1])
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_url(self):
         kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
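
For reference, a hedged usage sketch of the updated pipeline on a mono recording (not part of the commit; the model id is inferred from the test's download URL, and the output path is made up). Per the pipeline code above, 'output_file' is optional and receives the processed PCM stream:

# Usage sketch, assuming the model id implied by TEST_SPEECH_URL.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

kws = pipeline(
    Tasks.keyword_spotting,
    model='damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya')
result = kws({
    'input_file': 'data/test/audios/1ch_nihaomiya.wav',
    'output_file': 'enhanced_nihaomiya.wav',  # hypothetical output path
})
print(result['kws_list'])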

