From d0933a23749951d8d9838d522483d8db051c83fe Mon Sep 17 00:00:00 2001 From: "bin.xue" Date: Tue, 16 Aug 2022 20:23:55 +0800 Subject: [PATCH] [to #42322933] add far field kws model pipeline Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9767151 --- data/test/audios/3ch_nihaomiya.wav | 3 + data/test/audios/farend_speech.wav | 3 + data/test/audios/nearend_mic.wav | 3 + data/test/audios/speech_with_noise.wav | 3 + modelscope/metainfo.py | 2 + modelscope/models/audio/kws/__init__.py | 2 + .../models/audio/kws/farfield/__init__.py | 0 modelscope/models/audio/kws/farfield/fsmn.py | 495 ++++++++++++++++++ .../models/audio/kws/farfield/fsmn_sele_v2.py | 236 +++++++++ modelscope/models/audio/kws/farfield/model.py | 74 +++ .../models/audio/kws/farfield/model_def.py | 121 +++++ modelscope/outputs.py | 15 +- modelscope/pipelines/audio/__init__.py | 2 + .../pipelines/audio/kws_farfield_pipeline.py | 81 +++ modelscope/pipelines/base.py | 2 +- requirements/audio.txt | 1 + .../test_key_word_spotting_farfield.py | 43 ++ tests/pipelines/test_speech_signal_process.py | 48 +- 18 files changed, 1096 insertions(+), 38 deletions(-) create mode 100644 data/test/audios/3ch_nihaomiya.wav create mode 100644 data/test/audios/farend_speech.wav create mode 100644 data/test/audios/nearend_mic.wav create mode 100644 data/test/audios/speech_with_noise.wav create mode 100644 modelscope/models/audio/kws/farfield/__init__.py create mode 100644 modelscope/models/audio/kws/farfield/fsmn.py create mode 100644 modelscope/models/audio/kws/farfield/fsmn_sele_v2.py create mode 100644 modelscope/models/audio/kws/farfield/model.py create mode 100644 modelscope/models/audio/kws/farfield/model_def.py create mode 100644 modelscope/pipelines/audio/kws_farfield_pipeline.py create mode 100644 tests/pipelines/test_key_word_spotting_farfield.py diff --git a/data/test/audios/3ch_nihaomiya.wav b/data/test/audios/3ch_nihaomiya.wav new file mode 100644 index 00000000..57d9f061 --- /dev/null +++ b/data/test/audios/3ch_nihaomiya.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ad1a268c614076614a2ae6528abc29cc85ae35826d172079d7d9b26a0299559 +size 4325096 diff --git a/data/test/audios/farend_speech.wav b/data/test/audios/farend_speech.wav new file mode 100644 index 00000000..4e96d842 --- /dev/null +++ b/data/test/audios/farend_speech.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3637ee0628d0953f77d5a32327980af542c43230c4127d2a72b4df1ea2ffb0be +size 320042 diff --git a/data/test/audios/nearend_mic.wav b/data/test/audios/nearend_mic.wav new file mode 100644 index 00000000..e055c2e0 --- /dev/null +++ b/data/test/audios/nearend_mic.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc116af609a66f431f94df6b385ff2aa362f8a2d437c2279f5401e47f9178469 +size 320042 diff --git a/data/test/audios/speech_with_noise.wav b/data/test/audios/speech_with_noise.wav new file mode 100644 index 00000000..d57488c9 --- /dev/null +++ b/data/test/audios/speech_with_noise.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9354345a6297f4522e690d337546aa9a686a7e61eefcd935478a2141b924db8f +size 76770 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 54109571..ca2e21d6 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -38,6 +38,7 @@ class Models(object): # audio models sambert_hifigan = 'sambert-hifigan' speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' + speech_dfsmn_kws_char_farfield = 
'speech_dfsmn_kws_char_farfield' kws_kwsbp = 'kws-kwsbp' generic_asr = 'generic-asr' @@ -133,6 +134,7 @@ class Pipelines(object): sambert_hifigan_tts = 'sambert-hifigan-tts' speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' + speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' kws_kwsbp = 'kws-kwsbp' asr_inference = 'asr-inference' diff --git a/modelscope/models/audio/kws/__init__.py b/modelscope/models/audio/kws/__init__.py index f3db5e08..dd183fe5 100644 --- a/modelscope/models/audio/kws/__init__.py +++ b/modelscope/models/audio/kws/__init__.py @@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .generic_key_word_spotting import GenericKeyWordSpotting + from .farfield.model import FSMNSeleNetV2Decorator else: _import_structure = { 'generic_key_word_spotting': ['GenericKeyWordSpotting'], + 'farfield.model': ['FSMNSeleNetV2Decorator'], } import sys diff --git a/modelscope/models/audio/kws/farfield/__init__.py b/modelscope/models/audio/kws/farfield/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/kws/farfield/fsmn.py b/modelscope/models/audio/kws/farfield/fsmn.py new file mode 100644 index 00000000..e88d3976 --- /dev/null +++ b/modelscope/models/audio/kws/farfield/fsmn.py @@ -0,0 +1,495 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .model_def import (HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32, + printNeonMatrix, printNeonVector) + +DEBUG = False + + +def to_kaldi_matrix(np_mat): + """ function that transform as str numpy mat to standard kaldi str matrix + + Args: + np_mat: numpy mat + + Returns: str + """ + np.set_printoptions(threshold=np.inf, linewidth=np.nan) + out_str = str(np_mat) + out_str = out_str.replace('[', '') + out_str = out_str.replace(']', '') + return '[ %s ]\n' % out_str + + +def print_tensor(torch_tensor): + """ print torch tensor for debug + + Args: + torch_tensor: a tensor + """ + re_str = '' + x = torch_tensor.detach().squeeze().numpy() + re_str += to_kaldi_matrix(x) + re_str += '\n' + print(re_str) + + +class LinearTransform(nn.Module): + + def __init__(self, input_dim, output_dim): + super(LinearTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.linear = nn.Linear(input_dim, output_dim, bias=False) + + self.debug = False + self.dataout = None + + def forward(self, input): + output = self.linear(input) + + if self.debug: + self.dataout = output + + return output + + def print_model(self): + printNeonMatrix(self.linear.weight) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.output_dim, + self.input_dim) + re_str += ' 1\n' + + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + re_str += '\n' + + return re_str + + +class AffineTransform(nn.Module): + + def __init__(self, input_dim, output_dim): + super(AffineTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + + self.linear = nn.Linear(input_dim, output_dim) + + self.debug = False + self.dataout = None + + def forward(self, input): + output = self.linear(input) + + if self.debug: + self.dataout = output + + return output + + def print_model(self): + printNeonMatrix(self.linear.weight) + printNeonVector(self.linear.bias) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.output_dim, + 
self.input_dim) + re_str += ' 1 1 0\n' + + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += to_kaldi_matrix(x) + re_str += '\n' + + return re_str + + +class Fsmn(nn.Module): + """ + FSMN implementation. + """ + + def __init__(self, + input_dim, + output_dim, + lorder=None, + rorder=None, + lstride=None, + rstride=None): + super(Fsmn, self).__init__() + + self.dim = input_dim + + if lorder is None: + return + + self.lorder = lorder + self.rorder = rorder + self.lstride = lstride + self.rstride = rstride + + self.conv_left = nn.Conv2d( + self.dim, + self.dim, (lorder, 1), + dilation=(lstride, 1), + groups=self.dim, + bias=False) + + if rorder > 0: + self.conv_right = nn.Conv2d( + self.dim, + self.dim, (rorder, 1), + dilation=(rstride, 1), + groups=self.dim, + bias=False) + else: + self.conv_right = None + + self.debug = False + self.dataout = None + + def forward(self, input): + x = torch.unsqueeze(input, 1) + x_per = x.permute(0, 3, 2, 1) + + y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) + + if self.conv_right is not None: + y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) + y_right = y_right[:, :, self.rstride:, :] + out = x_per + self.conv_left(y_left) + self.conv_right(y_right) + else: + out = x_per + self.conv_left(y_left) + + out1 = out.permute(0, 3, 2, 1) + output = out1.squeeze(1) + + if self.debug: + self.dataout = output + + return output + + def print_model(self): + tmpw = self.conv_left.weight + tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) + for j in range(tmpw.shape[0]): + tmpwm[:, j] = tmpw[j, 0, :, 0] + + printNeonMatrix(tmpwm) + + if self.conv_right is not None: + tmpw = self.conv_right.weight + tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) + for j in range(tmpw.shape[0]): + tmpwm[:, j] = tmpw[j, 0, :, 0] + + printNeonMatrix(tmpwm) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + re_str += ' %d %d %d %d %d 0\n' % ( + 1, self.lorder, self.rorder, self.lstride, self.rstride) + + lfiters = self.state_dict()['conv_left.weight'] + x = np.flipud(lfiters.squeeze().numpy().T) + re_str += to_kaldi_matrix(x) + + if self.conv_right is not None: + rfiters = self.state_dict()['conv_right.weight'] + x = (rfiters.squeeze().numpy().T) + re_str += to_kaldi_matrix(x) + re_str += '\n' + + return re_str + + +class RectifiedLinear(nn.Module): + + def __init__(self, input_dim, output_dim): + super(RectifiedLinear, self).__init__() + self.dim = input_dim + self.relu = nn.ReLU() + + def forward(self, input): + return self.relu(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + re_str += '\n' + return re_str + + +class FSMNNet(nn.Module): + """ + FSMN net for keyword spotting + """ + + def __init__(self, + input_dim=200, + linear_dim=128, + proj_dim=128, + lorder=10, + rorder=1, + num_syn=5, + fsmn_layers=4): + """ + Args: + input_dim: input dimension + linear_dim: fsmn input dimension + proj_dim: fsmn projection dimension + lorder: fsmn left order + rorder: fsmn right order + num_syn: output dimension + fsmn_layers: no. 
of sequential fsmn layers + """ + super(FSMNNet, self).__init__() + + self.input_dim = input_dim + self.linear_dim = linear_dim + self.proj_dim = proj_dim + self.lorder = lorder + self.rorder = rorder + self.num_syn = num_syn + self.fsmn_layers = fsmn_layers + + self.linear1 = AffineTransform(input_dim, linear_dim) + self.relu = RectifiedLinear(linear_dim, linear_dim) + + self.fsmn = self._build_repeats(linear_dim, proj_dim, lorder, rorder, + fsmn_layers) + + self.linear2 = AffineTransform(linear_dim, num_syn) + + @staticmethod + def _build_repeats(linear_dim=136, + proj_dim=68, + lorder=3, + rorder=2, + fsmn_layers=5): + repeats = [ + nn.Sequential( + LinearTransform(linear_dim, proj_dim), + Fsmn(proj_dim, proj_dim, lorder, rorder, 1, 1), + AffineTransform(proj_dim, linear_dim), + RectifiedLinear(linear_dim, linear_dim)) + for i in range(fsmn_layers) + ] + + return nn.Sequential(*repeats) + + def forward(self, input): + x1 = self.linear1(input) + x2 = self.relu(x1) + x3 = self.fsmn(x2) + x4 = self.linear2(x3) + return x4 + + def print_model(self): + self.linear1.print_model() + + for layer in self.fsmn: + layer[0].print_model() + layer[1].print_model() + layer[2].print_model() + + self.linear2.print_model() + + def print_header(self): + # + # write total header + # + header = [0.0] * HEADER_BLOCK_SIZE * 4 + # numins + header[0] = 0.0 + # numouts + header[1] = 0.0 + # dimins + header[2] = self.input_dim + # dimouts + header[3] = self.num_syn + # numlayers + header[4] = 3 + + # + # write each layer's header + # + hidx = 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_DENSE.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 + header[HEADER_BLOCK_SIZE * hidx + 2] = self.input_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = self.linear_dim + header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 + header[HEADER_BLOCK_SIZE * hidx + 5] = float( + ActivationType.ACTIVATION_RELU.value) + hidx += 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_SEQUENTIAL_FSMN.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 + header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = self.proj_dim + header[HEADER_BLOCK_SIZE * hidx + 4] = self.lorder + header[HEADER_BLOCK_SIZE * hidx + 5] = self.rorder + header[HEADER_BLOCK_SIZE * hidx + 6] = self.fsmn_layers + header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0 + hidx += 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_DENSE.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 + header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = self.num_syn + header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 + header[HEADER_BLOCK_SIZE * hidx + 5] = float( + ActivationType.ACTIVATION_SOFTMAX.value) + + for h in header: + print(f32ToI32(h)) + + def to_kaldi_nnet(self): + re_str = '' + re_str += '\n' + re_str += self.linear1.to_kaldi_nnet() + re_str += self.relu.to_kaldi_nnet() + + for fsmn in self.fsmn: + re_str += fsmn[0].to_kaldi_nnet() + re_str += fsmn[1].to_kaldi_nnet() + re_str += fsmn[2].to_kaldi_nnet() + re_str += fsmn[3].to_kaldi_nnet() + + re_str += self.linear2.to_kaldi_nnet() + re_str += ' %d %d\n' % (self.num_syn, self.num_syn) + re_str += '\n' + re_str += '\n' + + return re_str + + +class DFSMN(nn.Module): + """ + One deep fsmn layer + """ + + def __init__(self, + dimproj=64, + dimlinear=128, + lorder=20, + rorder=1, + lstride=1, + rstride=1): + """ + Args: + dimproj: projection dimension, input and output dimension of memory blocks + dimlinear: 
dimension of mapping layer + lorder: left order + rorder: right order + lstride: left stride + rstride: right stride + """ + super(DFSMN, self).__init__() + + self.lorder = lorder + self.rorder = rorder + self.lstride = lstride + self.rstride = rstride + + self.expand = AffineTransform(dimproj, dimlinear) + self.shrink = LinearTransform(dimlinear, dimproj) + + self.conv_left = nn.Conv2d( + dimproj, + dimproj, (lorder, 1), + dilation=(lstride, 1), + groups=dimproj, + bias=False) + + if rorder > 0: + self.conv_right = nn.Conv2d( + dimproj, + dimproj, (rorder, 1), + dilation=(rstride, 1), + groups=dimproj, + bias=False) + else: + self.conv_right = None + + def forward(self, input): + f1 = F.relu(self.expand(input)) + p1 = self.shrink(f1) + + x = torch.unsqueeze(p1, 1) + x_per = x.permute(0, 3, 2, 1) + + y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) + + if self.conv_right is not None: + y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) + y_right = y_right[:, :, self.rstride:, :] + out = x_per + self.conv_left(y_left) + self.conv_right(y_right) + else: + out = x_per + self.conv_left(y_left) + + out1 = out.permute(0, 3, 2, 1) + output = input + out1.squeeze(1) + + return output + + def print_model(self): + self.expand.print_model() + self.shrink.print_model() + + tmpw = self.conv_left.weight + tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) + for j in range(tmpw.shape[0]): + tmpwm[:, j] = tmpw[j, 0, :, 0] + + printNeonMatrix(tmpwm) + + if self.conv_right is not None: + tmpw = self.conv_right.weight + tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) + for j in range(tmpw.shape[0]): + tmpwm[:, j] = tmpw[j, 0, :, 0] + + printNeonMatrix(tmpwm) + + +def build_dfsmn_repeats(linear_dim=128, + proj_dim=64, + lorder=20, + rorder=1, + fsmn_layers=6): + """ + build stacked dfsmn layers + Args: + linear_dim: + proj_dim: + lorder: + rorder: + fsmn_layers: + + Returns: + + """ + repeats = [ + nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) + for i in range(fsmn_layers) + ] + + return nn.Sequential(*repeats) diff --git a/modelscope/models/audio/kws/farfield/fsmn_sele_v2.py b/modelscope/models/audio/kws/farfield/fsmn_sele_v2.py new file mode 100644 index 00000000..1884e533 --- /dev/null +++ b/modelscope/models/audio/kws/farfield/fsmn_sele_v2.py @@ -0,0 +1,236 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .fsmn import AffineTransform, Fsmn, LinearTransform, RectifiedLinear +from .model_def import HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32 + + +class FSMNUnit(nn.Module): + """ A multi-channel fsmn unit + + """ + + def __init__(self, dimlinear=128, dimproj=64, lorder=20, rorder=1): + """ + Args: + dimlinear: input / output dimension + dimproj: fsmn input / output dimension + lorder: left ofder + rorder: right order + """ + super(FSMNUnit, self).__init__() + + self.shrink = LinearTransform(dimlinear, dimproj) + self.fsmn = Fsmn(dimproj, dimproj, lorder, rorder, 1, 1) + self.expand = AffineTransform(dimproj, dimlinear) + + self.debug = False + self.dataout = None + + ''' + batch, time, channel, feature + ''' + + def forward(self, input): + if torch.cuda.is_available(): + out = torch.zeros(input.shape).cuda() + else: + out = torch.zeros(input.shape) + + for n in range(input.shape[2]): + out1 = self.shrink(input[:, :, n, :]) + out2 = self.fsmn(out1) + out[:, :, n, :] = F.relu(self.expand(out2)) + + if self.debug: + self.dataout = out + + return out + + def print_model(self): + self.shrink.print_model() + 
self.fsmn.print_model() + self.expand.print_model() + + def to_kaldi_nnet(self): + re_str = self.shrink.to_kaldi_nnet() + re_str += self.fsmn.to_kaldi_nnet() + re_str += self.expand.to_kaldi_nnet() + + relu = RectifiedLinear(self.expand.linear.out_features, + self.expand.linear.out_features) + re_str += relu.to_kaldi_nnet() + + return re_str + + +class FSMNSeleNetV2(nn.Module): + """ FSMN model with channel selection. + """ + + def __init__(self, + input_dim=120, + linear_dim=128, + proj_dim=64, + lorder=20, + rorder=1, + num_syn=5, + fsmn_layers=5, + sele_layer=0): + """ + Args: + input_dim: input dimension + linear_dim: fsmn input dimension + proj_dim: fsmn projection dimension + lorder: fsmn left order + rorder: fsmn right order + num_syn: output dimension + fsmn_layers: no. of fsmn units + sele_layer: channel selection layer index + """ + super(FSMNSeleNetV2, self).__init__() + + self.sele_layer = sele_layer + + self.featmap = AffineTransform(input_dim, linear_dim) + + self.mem = [] + for i in range(fsmn_layers): + unit = FSMNUnit(linear_dim, proj_dim, lorder, rorder) + self.mem.append(unit) + self.add_module('mem_{:d}'.format(i), unit) + + self.decision = AffineTransform(linear_dim, num_syn) + + def forward(self, input): + # multi-channel feature mapping + if torch.cuda.is_available(): + x = torch.zeros(input.shape[0], input.shape[1], input.shape[2], + self.featmap.linear.out_features).cuda() + else: + x = torch.zeros(input.shape[0], input.shape[1], input.shape[2], + self.featmap.linear.out_features) + + for n in range(input.shape[2]): + x[:, :, n, :] = F.relu(self.featmap(input[:, :, n, :])) + + for i, unit in enumerate(self.mem): + y = unit(x) + + # perform channel selection + if i == self.sele_layer: + pool = nn.MaxPool2d((y.shape[2], 1), stride=(y.shape[2], 1)) + y = pool(y) + + x = y + + # remove channel dimension + y = torch.squeeze(y, -2) + z = self.decision(y) + + return z + + def print_model(self): + self.featmap.print_model() + + for unit in self.mem: + unit.print_model() + + self.decision.print_model() + + def print_header(self): + ''' + get FSMN params + ''' + input_dim = self.featmap.linear.in_features + linear_dim = self.featmap.linear.out_features + proj_dim = self.mem[0].shrink.linear.out_features + lorder = self.mem[0].fsmn.conv_left.kernel_size[0] + rorder = 0 + if self.mem[0].fsmn.conv_right is not None: + rorder = self.mem[0].fsmn.conv_right.kernel_size[0] + + num_syn = self.decision.linear.out_features + fsmn_layers = len(self.mem) + + # no. 
of output channels, 0.0 means the same as numins + # numouts = 0.0 + numouts = 1.0 + + # + # write total header + # + header = [0.0] * HEADER_BLOCK_SIZE * 4 + # numins + header[0] = 0.0 + # numouts + header[1] = numouts + # dimins + header[2] = input_dim + # dimouts + header[3] = num_syn + # numlayers + header[4] = 3 + + # + # write each layer's header + # + hidx = 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_DENSE.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 + header[HEADER_BLOCK_SIZE * hidx + 2] = input_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = linear_dim + header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 + header[HEADER_BLOCK_SIZE * hidx + 5] = float( + ActivationType.ACTIVATION_RELU.value) + hidx += 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_SEQUENTIAL_FSMN.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 + header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = proj_dim + header[HEADER_BLOCK_SIZE * hidx + 4] = lorder + header[HEADER_BLOCK_SIZE * hidx + 5] = rorder + header[HEADER_BLOCK_SIZE * hidx + 6] = fsmn_layers + if numouts == 1.0: + header[HEADER_BLOCK_SIZE * hidx + 7] = float(self.sele_layer) + else: + header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0 + hidx += 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_DENSE.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = numouts + header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = num_syn + header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 + header[HEADER_BLOCK_SIZE * hidx + 5] = float( + ActivationType.ACTIVATION_SOFTMAX.value) + + for h in header: + print(f32ToI32(h)) + + def to_kaldi_nnet(self): + re_str = '\n' + + re_str = self.featmap.to_kaldi_nnet() + + relu = RectifiedLinear(self.featmap.linear.out_features, + self.featmap.linear.out_features) + re_str += relu.to_kaldi_nnet() + + for unit in self.mem: + re_str += unit.to_kaldi_nnet() + + re_str += self.decision.to_kaldi_nnet() + + re_str += ' %d %d\n' % (self.decision.linear.out_features, + self.decision.linear.out_features) + re_str += '\n' + re_str += '\n' + + return re_str diff --git a/modelscope/models/audio/kws/farfield/model.py b/modelscope/models/audio/kws/farfield/model.py new file mode 100644 index 00000000..81e47350 --- /dev/null +++ b/modelscope/models/audio/kws/farfield/model.py @@ -0,0 +1,74 @@ +import os +from typing import Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from .fsmn_sele_v2 import FSMNSeleNetV2 + + +@MODELS.register_module( + Tasks.keyword_spotting, module_name=Models.speech_dfsmn_kws_char_farfield) +class FSMNSeleNetV2Decorator(TorchModel): + r""" A decorator of FSMNSeleNetV2 for integrating into modelscope framework """ + + MODEL_TXT = 'model.txt' + SC_CONFIG = 'sound_connect.conf' + SC_CONF_ITEM_KWS_MODEL = '${kws_model}' + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the dfsmn model from the `model_dir` path. + + Args: + model_dir (str): the model path. 
+ """ + super().__init__(model_dir, *args, **kwargs) + sc_config_file = os.path.join(model_dir, self.SC_CONFIG) + model_txt_file = os.path.join(model_dir, self.MODEL_TXT) + model_bin_file = os.path.join(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + self._model = None + if os.path.exists(model_bin_file): + self._model = FSMNSeleNetV2(*args, **kwargs) + checkpoint = torch.load(model_bin_file) + self._model.load_state_dict(checkpoint, strict=False) + + self._sc = None + if os.path.exists(model_txt_file): + with open(sc_config_file) as f: + lines = f.readlines() + with open(sc_config_file, 'w') as f: + for line in lines: + if self.SC_CONF_ITEM_KWS_MODEL in line: + line = line.replace(self.SC_CONF_ITEM_KWS_MODEL, + model_txt_file) + f.write(line) + import py_sound_connect + self._sc = py_sound_connect.SoundConnect(sc_config_file) + self.size_in = self._sc.bytesPerBlockIn() + self.size_out = self._sc.bytesPerBlockOut() + + if self._model is None and self._sc is None: + raise Exception( + f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.' + ) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + ... + + def forward_decode(self, data: bytes): + result = {'pcm': self._sc.process(data, self.size_out)} + state = self._sc.kwsState() + if state == 2: + result['kws'] = { + 'keyword': + self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()), + 'offset': self._sc.kwsKeywordOffset(), + 'length': self._sc.kwsKeywordLength(), + 'confidence': self._sc.kwsConfidence() + } + return result diff --git a/modelscope/models/audio/kws/farfield/model_def.py b/modelscope/models/audio/kws/farfield/model_def.py new file mode 100644 index 00000000..3f5ba7d7 --- /dev/null +++ b/modelscope/models/audio/kws/farfield/model_def.py @@ -0,0 +1,121 @@ +import math +import struct +from enum import Enum + +HEADER_BLOCK_SIZE = 10 + + +class LayerType(Enum): + LAYER_DENSE = 1 + LAYER_GRU = 2 + LAYER_ATTENTION = 3 + LAYER_FSMN = 4 + LAYER_SEQUENTIAL_FSMN = 5 + LAYER_FSMN_SELE = 6 + LAYER_GRU_ATTENTION = 7 + LAYER_DFSMN = 8 + + +class ActivationType(Enum): + ACTIVATION_NONE = 0 + ACTIVATION_RELU = 1 + ACTIVATION_TANH = 2 + ACTIVATION_SIGMOID = 3 + ACTIVATION_SOFTMAX = 4 + ACTIVATION_LOGSOFTMAX = 5 + + +def f32ToI32(f): + """ + print layer + """ + bs = struct.pack('f', f) + + ba = bytearray() + ba.append(bs[0]) + ba.append(bs[1]) + ba.append(bs[2]) + ba.append(bs[3]) + + return struct.unpack('i', ba)[0] + + +def printNeonMatrix(w): + """ + print matrix with neon padding + """ + numrows, numcols = w.shape + numnecols = math.ceil(numcols / 4) + + for i in range(numrows): + for j in range(numcols): + print(f32ToI32(w[i, j])) + + for j in range(numnecols * 4 - numcols): + print(0) + + +def printNeonVector(b): + """ + print vector with neon padding + """ + size = b.shape[0] + nesize = math.ceil(size / 4) + + for i in range(size): + print(f32ToI32(b[i])) + + for i in range(nesize * 4 - size): + print(0) + + +def printDense(layer): + """ + save dense layer + """ + statedict = layer.state_dict() + printNeonMatrix(statedict['weight']) + printNeonVector(statedict['bias']) + + +def printGRU(layer): + """ + save gru layer + """ + statedict = layer.state_dict() + weight = [statedict['weight_ih_l0'], statedict['weight_hh_l0']] + bias = [statedict['bias_ih_l0'], statedict['bias_hh_l0']] + numins, numouts = weight[0].shape + numins = numins // 3 + + # output input weights + w_rx = weight[0][:numins, :] + w_zx = weight[0][numins:numins * 2, :] + w_x = weight[0][numins * 2:, :] + printNeonMatrix(w_zx) + 
printNeonMatrix(w_rx) + printNeonMatrix(w_x) + + # output recurrent weights + w_rh = weight[1][:numins, :] + w_zh = weight[1][numins:numins * 2, :] + w_h = weight[1][numins * 2:, :] + printNeonMatrix(w_zh) + printNeonMatrix(w_rh) + printNeonMatrix(w_h) + + # output input bias + b_rx = bias[0][:numins] + b_zx = bias[0][numins:numins * 2] + b_x = bias[0][numins * 2:] + printNeonVector(b_zx) + printNeonVector(b_rx) + printNeonVector(b_x) + + # output recurrent bias + b_rh = bias[1][:numins] + b_zh = bias[1][numins:numins * 2] + b_h = bias[1][numins * 2:] + printNeonVector(b_zh) + printNeonVector(b_rh) + printNeonVector(b_h) diff --git a/modelscope/outputs.py b/modelscope/outputs.py index f279f311..0c644219 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -405,7 +405,7 @@ TASK_OUTPUTS = { # audio processed for single file in PCM format # { - # "output_pcm": np.array with shape(samples,) and dtype float32 + # "output_pcm": pcm encoded audio bytes # } Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM], Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM], @@ -417,6 +417,19 @@ TASK_OUTPUTS = { # } Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM], + # { + # "kws_list": [ + # { + # 'keyword': '', # the keyword spotted + # 'offset': 19.4, # the keyword start time in second + # 'length': 0.68, # the keyword length in second + # 'confidence': 0.85 # the possibility if it is the keyword + # }, + # ... + # ] + # } + Tasks.keyword_spotting: [OutputKeys.KWS_LIST], + # ============ multi-modal tasks =================== # image caption result for single sample diff --git a/modelscope/pipelines/audio/__init__.py b/modelscope/pipelines/audio/__init__.py index 562125b4..b46ca87e 100644 --- a/modelscope/pipelines/audio/__init__.py +++ b/modelscope/pipelines/audio/__init__.py @@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .ans_pipeline import ANSPipeline from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline + from .kws_farfield_pipeline import KWSFarfieldPipeline from .kws_kwsbp_pipeline import KeyWordSpottingKwsbpPipeline from .linear_aec_pipeline import LinearAECPipeline from .text_to_speech_pipeline import TextToSpeechSambertHifiganPipeline @@ -14,6 +15,7 @@ else: _import_structure = { 'ans_pipeline': ['ANSPipeline'], 'asr_inference_pipeline': ['AutomaticSpeechRecognitionPipeline'], + 'kws_farfield_pipeline': ['KWSFarfieldPipeline'], 'kws_kwsbp_pipeline': ['KeyWordSpottingKwsbpPipeline'], 'linear_aec_pipeline': ['LinearAECPipeline'], 'text_to_speech_pipeline': ['TextToSpeechSambertHifiganPipeline'], diff --git a/modelscope/pipelines/audio/kws_farfield_pipeline.py b/modelscope/pipelines/audio/kws_farfield_pipeline.py new file mode 100644 index 00000000..a114e7fb --- /dev/null +++ b/modelscope/pipelines/audio/kws_farfield_pipeline.py @@ -0,0 +1,81 @@ +import io +import wave +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.keyword_spotting, + module_name=Pipelines.speech_dfsmn_kws_char_farfield) +class KWSFarfieldPipeline(Pipeline): + r"""A Keyword Spotting Inference Pipeline . 
+ + When invoke the class with pipeline.__call__(), it accept only one parameter: + inputs(str): the path of wav file + """ + SAMPLE_RATE = 16000 + SAMPLE_WIDTH = 2 + INPUT_CHANNELS = 3 + OUTPUT_CHANNELS = 2 + + def __init__(self, model, **kwargs): + """ + use `model` to create a kws far field pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + self.model = self.model.to(self.device) + self.model.eval() + frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH + self._nframe = self.model.size_in // frame_size + self.frame_count = 0 + + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + if isinstance(inputs, bytes): + return dict(input_file=inputs) + elif isinstance(inputs, Dict): + return inputs + else: + raise ValueError(f'Not supported input type: {type(inputs)}') + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + input_file = inputs['input_file'] + if isinstance(input_file, bytes): + input_file = io.BytesIO(input_file) + self.frame_count = 0 + kws_list = [] + with wave.open(input_file, 'rb') as fin: + if 'output_file' in inputs: + with wave.open(inputs['output_file'], 'wb') as fout: + fout.setframerate(self.SAMPLE_RATE) + fout.setnchannels(self.OUTPUT_CHANNELS) + fout.setsampwidth(self.SAMPLE_WIDTH) + self._process(fin, kws_list, fout) + else: + self._process(fin, kws_list) + return {OutputKeys.KWS_LIST: kws_list} + + def _process(self, + fin: wave.Wave_read, + kws_list, + fout: wave.Wave_write = None): + data = fin.readframes(self._nframe) + while len(data) >= self.model.size_in: + self.frame_count += self._nframe + result = self.model.forward_decode(data) + if fout: + fout.writeframes(result['pcm']) + if 'kws' in result: + result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE + kws_list.append(result['kws']) + data = fin.readframes(self._nframe) + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index b1d82557..041dfb34 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -255,7 +255,7 @@ class Pipeline(ABC): return self._collate_fn(torch.from_numpy(data)) elif isinstance(data, torch.Tensor): return data.to(self.device) - elif isinstance(data, (str, int, float, bool, type(None))): + elif isinstance(data, (bytes, str, int, float, bool, type(None))): return data elif isinstance(data, InputFeatures): return data diff --git a/requirements/audio.txt b/requirements/audio.txt index 81d288bd..5e4bc104 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -16,6 +16,7 @@ numpy<=1.18 # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. 
protobuf>3,<3.21.0 ptflops +py_sound_connect pytorch_wavelets PyWavelets>=1.0.0 scikit-learn diff --git a/tests/pipelines/test_key_word_spotting_farfield.py b/tests/pipelines/test_key_word_spotting_farfield.py new file mode 100644 index 00000000..e7967edc --- /dev/null +++ b/tests/pipelines/test_key_word_spotting_farfield.py @@ -0,0 +1,43 @@ +import os.path +import unittest + +from modelscope.fileio import File +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + +TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav' + + +class KWSFarfieldTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_normal(self): + kws = pipeline(Tasks.keyword_spotting, model=self.model_id) + inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)} + result = kws(inputs) + self.assertEqual(len(result['kws_list']), 5) + print(result['kws_list'][-1]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_output(self): + kws = pipeline(Tasks.keyword_spotting, model=self.model_id) + inputs = { + 'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE), + 'output_file': 'output.wav' + } + result = kws(inputs) + self.assertEqual(len(result['kws_list']), 5) + print(result['kws_list'][-1]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_input_bytes(self): + with open(os.path.join(os.getcwd(), TEST_SPEECH_FILE), 'rb') as f: + data = f.read() + kws = pipeline(Tasks.keyword_spotting, model=self.model_id) + result = kws(data) + self.assertEqual(len(result['kws_list']), 5) + print(result['kws_list'][-1]) diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py index e8b4a551..007e6c73 100644 --- a/tests/pipelines/test_speech_signal_process.py +++ b/tests/pipelines/test_speech_signal_process.py @@ -8,22 +8,10 @@ from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks from modelscope.utils.test_utils import test_level -NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav' -FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav' -NEAREND_MIC_FILE = 'nearend_mic.wav' -FAREND_SPEECH_FILE = 'farend_speech.wav' +NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav' +FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav' -NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav' -NOISE_SPEECH_FILE = 'speech_with_noise.wav' - - -def download(remote_path, local_path): - local_dir = os.path.dirname(local_path) - if len(local_dir) > 0: - if not os.path.exists(local_dir): - os.makedirs(local_dir) - with open(local_path, 'wb') as ofile: - ofile.write(File.read(remote_path)) +NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav' class SpeechSignalProcessTest(unittest.TestCase): @@ -33,13 +21,10 @@ class SpeechSignalProcessTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_aec(self): - # Download audio files - download(NEAREND_MIC_URL, NEAREND_MIC_FILE) - download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE) model_id = 'damo/speech_dfsmn_aec_psm_16k' input = { - 'nearend_mic': NEAREND_MIC_FILE, - 
'farend_speech': FAREND_SPEECH_FILE + 'nearend_mic': os.path.join(os.getcwd(), NEAREND_MIC_FILE), + 'farend_speech': os.path.join(os.getcwd(), FAREND_SPEECH_FILE) } aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id) output_path = os.path.abspath('output.wav') @@ -48,14 +33,11 @@ class SpeechSignalProcessTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_aec_bytes(self): - # Download audio files - download(NEAREND_MIC_URL, NEAREND_MIC_FILE) - download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE) model_id = 'damo/speech_dfsmn_aec_psm_16k' input = {} - with open(NEAREND_MIC_FILE, 'rb') as f: + with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f: input['nearend_mic'] = f.read() - with open(FAREND_SPEECH_FILE, 'rb') as f: + with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f: input['farend_speech'] = f.read() aec = pipeline( Tasks.acoustic_echo_cancellation, @@ -67,13 +49,10 @@ class SpeechSignalProcessTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_aec_tuple_bytes(self): - # Download audio files - download(NEAREND_MIC_URL, NEAREND_MIC_FILE) - download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE) model_id = 'damo/speech_dfsmn_aec_psm_16k' - with open(NEAREND_MIC_FILE, 'rb') as f: + with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f: nearend_bytes = f.read() - with open(FAREND_SPEECH_FILE, 'rb') as f: + with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f: farend_bytes = f.read() inputs = (nearend_bytes, farend_bytes) aec = pipeline( @@ -86,25 +65,22 @@ class SpeechSignalProcessTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_ans(self): - # Download audio files - download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE) model_id = 'damo/speech_frcrn_ans_cirm_16k' ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) output_path = os.path.abspath('output.wav') - ans(NOISE_SPEECH_FILE, output_path=output_path) + ans(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), + output_path=output_path) print(f'Processed audio saved to {output_path}') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_ans_bytes(self): - # Download audio files - download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE) model_id = 'damo/speech_frcrn_ans_cirm_16k' ans = pipeline( Tasks.acoustic_noise_suppression, model=model_id, pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k) output_path = os.path.abspath('output.wav') - with open(NOISE_SPEECH_FILE, 'rb') as f: + with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), 'rb') as f: data = f.read() ans(data, output_path=output_path) print(f'Processed audio saved to {output_path}')
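
Usage sketch (not part of the patch): a minimal invocation of the new far-field KWS pipeline, mirroring tests/pipelines/test_key_word_spotting_farfield.py above. The model id, the input/output keys, and the fields of each kws_list entry are taken from this patch (the test file, KWSFarfieldPipeline, and the TASK_OUTPUTS comment in modelscope/outputs.py); it assumes modelscope with the audio extras (including py_sound_connect) is installed and the model is available on the hub.

    import os

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Model id used by the new test; assumed to be downloadable from the ModelScope hub.
    model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'
    kws = pipeline(Tasks.keyword_spotting, model=model_id)

    # 'input_file' is a 3-channel 16 kHz wav (the LFS test asset added in this patch);
    # 'output_file' is optional and receives the 2-channel processed audio.
    inputs = {
        'input_file': os.path.join(os.getcwd(), 'data/test/audios/3ch_nihaomiya.wav'),
        'output_file': 'output.wav',
    }
    result = kws(inputs)

    # Each entry has keyword, offset (seconds), length (seconds) and confidence,
    # as documented for Tasks.keyword_spotting in modelscope/outputs.py.
    for hit in result['kws_list']:
        print(hit['keyword'], hit['offset'], hit['length'], hit['confidence'])

As in the test_input_bytes case above, raw wav bytes can also be passed directly in place of the inputs dict.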