|
|
@@ -4,6 +4,9 @@ import io |
|
|
|
import wave |
|
|
|
from typing import Any, Dict |
|
|
|
|
|
|
|
import numpy |
|
|
|
import soundfile as sf |
|
|
|
|
|
|
|
from modelscope.fileio import File |
|
|
|
from modelscope.metainfo import Pipelines |
|
|
|
from modelscope.outputs import OutputKeys |
|
|
@@ -37,7 +40,6 @@ class KWSFarfieldPipeline(Pipeline): |
|
|
|
self.model.eval() |
|
|
|
frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH |
|
|
|
self._nframe = self.model.size_in // frame_size |
|
|
|
self.frame_count = 0 |
|
|
|
|
|
|
|
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: |
|
|
|
if isinstance(inputs, bytes): |
|
|
@@ -54,35 +56,36 @@ class KWSFarfieldPipeline(Pipeline): |
|
|
|
input_file = inputs['input_file'] |
|
|
|
if isinstance(input_file, str): |
|
|
|
input_file = File.read(input_file) |
|
|
|
if isinstance(input_file, bytes): |
|
|
|
input_file = io.BytesIO(input_file) |
|
|
|
self.frame_count = 0 |
|
|
|
frames, samplerate = sf.read(io.BytesIO(input_file), dtype='int16') |
|
|
|
if len(frames.shape) == 1: |
|
|
|
frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1) |
|
|
|
|
|
|
|
kws_list = [] |
|
|
|
with wave.open(input_file, 'rb') as fin: |
|
|
|
if 'output_file' in inputs: |
|
|
|
with wave.open(inputs['output_file'], 'wb') as fout: |
|
|
|
fout.setframerate(self.SAMPLE_RATE) |
|
|
|
fout.setnchannels(self.OUTPUT_CHANNELS) |
|
|
|
fout.setsampwidth(self.SAMPLE_WIDTH) |
|
|
|
self._process(fin, kws_list, fout) |
|
|
|
else: |
|
|
|
self._process(fin, kws_list) |
|
|
|
if 'output_file' in inputs: |
|
|
|
with wave.open(inputs['output_file'], 'wb') as fout: |
|
|
|
fout.setframerate(self.SAMPLE_RATE) |
|
|
|
fout.setnchannels(self.OUTPUT_CHANNELS) |
|
|
|
fout.setsampwidth(self.SAMPLE_WIDTH) |
|
|
|
self._process(frames, kws_list, fout) |
|
|
|
else: |
|
|
|
self._process(frames, kws_list) |
|
|
|
return {OutputKeys.KWS_LIST: kws_list} |
|
|
|
|
|
|
|
def _process(self, |
|
|
|
fin: wave.Wave_read, |
|
|
|
frames: numpy.ndarray, |
|
|
|
kws_list, |
|
|
|
fout: wave.Wave_write = None): |
|
|
|
data = fin.readframes(self._nframe) |
|
|
|
while len(data) >= self.model.size_in: |
|
|
|
self.frame_count += self._nframe |
|
|
|
for start_index in range(0, frames.shape[0], self._nframe): |
|
|
|
end_index = start_index + self._nframe |
|
|
|
if end_index > frames.shape[0]: |
|
|
|
end_index = frames.shape[0] |
|
|
|
data = frames[start_index:end_index, :].tobytes() |
|
|
|
result = self.model.forward_decode(data) |
|
|
|
if fout: |
|
|
|
fout.writeframes(result['pcm']) |
|
|
|
if 'kws' in result: |
|
|
|
result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE |
|
|
|
result['kws']['offset'] += start_index / self.SAMPLE_RATE |
|
|
|
kws_list.append(result['kws']) |
|
|
|
data = fin.readframes(self._nframe) |
|
|
|
|
|
|
|
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: |
|
|
|
return inputs |