Browse Source

[to #42322933] feat: far field KWS accept mono audio for online demo

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10211100
master
bin.xue yingda.chen 3 years ago
parent
commit
470a1989bc
3 changed files with 36 additions and 19 deletions
  1. +3
    -0
      data/test/audios/1ch_nihaomiya.wav
  2. +22
    -19
      modelscope/pipelines/audio/kws_farfield_pipeline.py
  3. +11
    -0
      tests/pipelines/test_key_word_spotting_farfield.py

+ 3
- 0
data/test/audios/1ch_nihaomiya.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4f7f5a0a4efca1e83463cb44460c66b56fb7cd673eb6da37924637bc05ef758d
size 1440044

+ 22
- 19
modelscope/pipelines/audio/kws_farfield_pipeline.py View File

@@ -4,6 +4,9 @@ import io
import wave
from typing import Any, Dict

import numpy
import soundfile as sf

from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
@@ -37,7 +40,6 @@ class KWSFarfieldPipeline(Pipeline):
self.model.eval()
frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH
self._nframe = self.model.size_in // frame_size
self.frame_count = 0

def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
if isinstance(inputs, bytes):
@@ -54,35 +56,36 @@ class KWSFarfieldPipeline(Pipeline):
input_file = inputs['input_file']
if isinstance(input_file, str):
input_file = File.read(input_file)
if isinstance(input_file, bytes):
input_file = io.BytesIO(input_file)
self.frame_count = 0
frames, samplerate = sf.read(io.BytesIO(input_file), dtype='int16')
if len(frames.shape) == 1:
frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1)

kws_list = []
with wave.open(input_file, 'rb') as fin:
if 'output_file' in inputs:
with wave.open(inputs['output_file'], 'wb') as fout:
fout.setframerate(self.SAMPLE_RATE)
fout.setnchannels(self.OUTPUT_CHANNELS)
fout.setsampwidth(self.SAMPLE_WIDTH)
self._process(fin, kws_list, fout)
else:
self._process(fin, kws_list)
if 'output_file' in inputs:
with wave.open(inputs['output_file'], 'wb') as fout:
fout.setframerate(self.SAMPLE_RATE)
fout.setnchannels(self.OUTPUT_CHANNELS)
fout.setsampwidth(self.SAMPLE_WIDTH)
self._process(frames, kws_list, fout)
else:
self._process(frames, kws_list)
return {OutputKeys.KWS_LIST: kws_list}

def _process(self,
fin: wave.Wave_read,
frames: numpy.ndarray,
kws_list,
fout: wave.Wave_write = None):
data = fin.readframes(self._nframe)
while len(data) >= self.model.size_in:
self.frame_count += self._nframe
for start_index in range(0, frames.shape[0], self._nframe):
end_index = start_index + self._nframe
if end_index > frames.shape[0]:
end_index = frames.shape[0]
data = frames[start_index:end_index, :].tobytes()
result = self.model.forward_decode(data)
if fout:
fout.writeframes(result['pcm'])
if 'kws' in result:
result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE
result['kws']['offset'] += start_index / self.SAMPLE_RATE
kws_list.append(result['kws'])
data = fin.readframes(self._nframe)

def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return inputs

+ 11
- 0
tests/pipelines/test_key_word_spotting_farfield.py View File

@@ -8,6 +8,7 @@ from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav'
TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \
'?Revision=master&FilePath=examples/3ch_nihaomiya.wav'
@@ -26,6 +27,16 @@ class KWSFarfieldTest(unittest.TestCase):
self.assertEqual(len(result['kws_list']), 5)
print(result['kws_list'][-1])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_mono(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
inputs = {
'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE_MONO)
}
result = kws(inputs)
self.assertEqual(len(result['kws_list']), 5)
print(result['kws_list'][-1])

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_url(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)


Loading…
Cancel
Save