bin.xue, yingda.chen · 3 years ago
commit d0933a2374
18 changed files with 1096 additions and 38 deletions
  1. +3   -0   data/test/audios/3ch_nihaomiya.wav
  2. +3   -0   data/test/audios/farend_speech.wav
  3. +3   -0   data/test/audios/nearend_mic.wav
  4. +3   -0   data/test/audios/speech_with_noise.wav
  5. +2   -0   modelscope/metainfo.py
  6. +2   -0   modelscope/models/audio/kws/__init__.py
  7. +0   -0   modelscope/models/audio/kws/farfield/__init__.py
  8. +495 -0   modelscope/models/audio/kws/farfield/fsmn.py
  9. +236 -0   modelscope/models/audio/kws/farfield/fsmn_sele_v2.py
  10. +74  -0   modelscope/models/audio/kws/farfield/model.py
  11. +121 -0   modelscope/models/audio/kws/farfield/model_def.py
  12. +14  -1   modelscope/outputs.py
  13. +2   -0   modelscope/pipelines/audio/__init__.py
  14. +81  -0   modelscope/pipelines/audio/kws_farfield_pipeline.py
  15. +1   -1   modelscope/pipelines/base.py
  16. +1   -0   requirements/audio.txt
  17. +43  -0   tests/pipelines/test_key_word_spotting_farfield.py
  18. +12  -36  tests/pipelines/test_speech_signal_process.py

+ 3
- 0
data/test/audios/3ch_nihaomiya.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ad1a268c614076614a2ae6528abc29cc85ae35826d172079d7d9b26a0299559
size 4325096

+ 3
- 0
data/test/audios/farend_speech.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3637ee0628d0953f77d5a32327980af542c43230c4127d2a72b4df1ea2ffb0be
size 320042

+ 3
- 0
data/test/audios/nearend_mic.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc116af609a66f431f94df6b385ff2aa362f8a2d437c2279f5401e47f9178469
size 320042

+ 3
- 0
data/test/audios/speech_with_noise.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9354345a6297f4522e690d337546aa9a686a7e61eefcd935478a2141b924db8f
size 76770

+ 2
- 0
modelscope/metainfo.py View File

@@ -38,6 +38,7 @@ class Models(object):
# audio models
sambert_hifigan = 'sambert-hifigan'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
kws_kwsbp = 'kws-kwsbp'
generic_asr = 'generic-asr'

@@ -133,6 +134,7 @@ class Pipelines(object):
sambert_hifigan_tts = 'sambert-hifigan-tts'
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
kws_kwsbp = 'kws-kwsbp'
asr_inference = 'asr-inference'



+ 2
- 0
modelscope/models/audio/kws/__init__.py View File

@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .generic_key_word_spotting import GenericKeyWordSpotting
from .farfield.model import FSMNSeleNetV2Decorator

else:
_import_structure = {
'generic_key_word_spotting': ['GenericKeyWordSpotting'],
'farfield.model': ['FSMNSeleNetV2Decorator'],
}

import sys


+ 0
- 0
modelscope/models/audio/kws/farfield/__init__.py View File


+ 495
- 0
modelscope/models/audio/kws/farfield/fsmn.py View File

@@ -0,0 +1,495 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .model_def import (HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32,
printNeonMatrix, printNeonVector)

DEBUG = False


def to_kaldi_matrix(np_mat):
""" function that transform as str numpy mat to standard kaldi str matrix

Args:
np_mat: numpy mat

Returns: str
"""
np.set_printoptions(threshold=np.inf, linewidth=np.nan)
out_str = str(np_mat)
out_str = out_str.replace('[', '')
out_str = out_str.replace(']', '')
return '[ %s ]\n' % out_str


def print_tensor(torch_tensor):
""" print torch tensor for debug

Args:
torch_tensor: a tensor
"""
re_str = ''
x = torch_tensor.detach().squeeze().numpy()
re_str += to_kaldi_matrix(x)
re_str += '<!EndOfComponent>\n'
print(re_str)


class LinearTransform(nn.Module):

def __init__(self, input_dim, output_dim):
super(LinearTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim, bias=False)

self.debug = False
self.dataout = None

def forward(self, input):
output = self.linear(input)

if self.debug:
self.dataout = output

return output

def print_model(self):
printNeonMatrix(self.linear.weight)

def to_kaldi_nnet(self):
re_str = ''
re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
self.input_dim)
re_str += '<LearnRateCoef> 1\n'

linear_weights = self.state_dict()['linear.weight']
x = linear_weights.squeeze().numpy()
re_str += to_kaldi_matrix(x)
re_str += '<!EndOfComponent>\n'

return re_str


class AffineTransform(nn.Module):

def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim

self.linear = nn.Linear(input_dim, output_dim)

self.debug = False
self.dataout = None

def forward(self, input):
output = self.linear(input)

if self.debug:
self.dataout = output

return output

def print_model(self):
printNeonMatrix(self.linear.weight)
printNeonVector(self.linear.bias)

def to_kaldi_nnet(self):
re_str = ''
re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
self.input_dim)
re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'

linear_weights = self.state_dict()['linear.weight']
x = linear_weights.squeeze().numpy()
re_str += to_kaldi_matrix(x)

linear_bias = self.state_dict()['linear.bias']
x = linear_bias.squeeze().numpy()
re_str += to_kaldi_matrix(x)
re_str += '<!EndOfComponent>\n'

return re_str


class Fsmn(nn.Module):
"""
FSMN implementation.
"""

def __init__(self,
input_dim,
output_dim,
lorder=None,
rorder=None,
lstride=None,
rstride=None):
super(Fsmn, self).__init__()

self.dim = input_dim

if lorder is None:
return

self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride

self.conv_left = nn.Conv2d(
self.dim,
self.dim, (lorder, 1),
dilation=(lstride, 1),
groups=self.dim,
bias=False)

if rorder > 0:
self.conv_right = nn.Conv2d(
self.dim,
self.dim, (rorder, 1),
dilation=(rstride, 1),
groups=self.dim,
bias=False)
else:
self.conv_right = None

self.debug = False
self.dataout = None

def forward(self, input):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1)

y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])

if self.conv_right is not None:
y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
else:
out = x_per + self.conv_left(y_left)

out1 = out.permute(0, 3, 2, 1)
output = out1.squeeze(1)

if self.debug:
self.dataout = output

return output

def print_model(self):
tmpw = self.conv_left.weight
tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
for j in range(tmpw.shape[0]):
tmpwm[:, j] = tmpw[j, 0, :, 0]

printNeonMatrix(tmpwm)

if self.conv_right is not None:
tmpw = self.conv_right.weight
tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
for j in range(tmpw.shape[0]):
tmpwm[:, j] = tmpw[j, 0, :, 0]

printNeonMatrix(tmpwm)

def to_kaldi_nnet(self):
re_str = ''
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
1, self.lorder, self.rorder, self.lstride, self.rstride)

lfiters = self.state_dict()['conv_left.weight']
x = np.flipud(lfiters.squeeze().numpy().T)
re_str += to_kaldi_matrix(x)

if self.conv_right is not None:
rfiters = self.state_dict()['conv_right.weight']
x = (rfiters.squeeze().numpy().T)
re_str += to_kaldi_matrix(x)
re_str += '<!EndOfComponent>\n'

return re_str


class RectifiedLinear(nn.Module):

def __init__(self, input_dim, output_dim):
super(RectifiedLinear, self).__init__()
self.dim = input_dim
self.relu = nn.ReLU()

def forward(self, input):
return self.relu(input)

def to_kaldi_nnet(self):
re_str = ''
re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
re_str += '<!EndOfComponent>\n'
return re_str


class FSMNNet(nn.Module):
"""
FSMN net for keyword spotting
"""

def __init__(self,
input_dim=200,
linear_dim=128,
proj_dim=128,
lorder=10,
rorder=1,
num_syn=5,
fsmn_layers=4):
"""
Args:
input_dim: input dimension
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
"""
super(FSMNNet, self).__init__()

self.input_dim = input_dim
self.linear_dim = linear_dim
self.proj_dim = proj_dim
self.lorder = lorder
self.rorder = rorder
self.num_syn = num_syn
self.fsmn_layers = fsmn_layers

self.linear1 = AffineTransform(input_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)

self.fsmn = self._build_repeats(linear_dim, proj_dim, lorder, rorder,
fsmn_layers)

self.linear2 = AffineTransform(linear_dim, num_syn)

@staticmethod
def _build_repeats(linear_dim=136,
proj_dim=68,
lorder=3,
rorder=2,
fsmn_layers=5):
repeats = [
nn.Sequential(
LinearTransform(linear_dim, proj_dim),
Fsmn(proj_dim, proj_dim, lorder, rorder, 1, 1),
AffineTransform(proj_dim, linear_dim),
RectifiedLinear(linear_dim, linear_dim))
for i in range(fsmn_layers)
]

return nn.Sequential(*repeats)

def forward(self, input):
x1 = self.linear1(input)
x2 = self.relu(x1)
x3 = self.fsmn(x2)
x4 = self.linear2(x3)
return x4

def print_model(self):
self.linear1.print_model()

for layer in self.fsmn:
layer[0].print_model()
layer[1].print_model()
layer[2].print_model()

self.linear2.print_model()

def print_header(self):
#
# write total header
#
header = [0.0] * HEADER_BLOCK_SIZE * 4
# numins
header[0] = 0.0
# numouts
header[1] = 0.0
# dimins
header[2] = self.input_dim
# dimouts
header[3] = self.num_syn
# numlayers
header[4] = 3

#
# write each layer's header
#
hidx = 1

header[HEADER_BLOCK_SIZE * hidx + 0] = float(
LayerType.LAYER_DENSE.value)
header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
header[HEADER_BLOCK_SIZE * hidx + 2] = self.input_dim
header[HEADER_BLOCK_SIZE * hidx + 3] = self.linear_dim
header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
header[HEADER_BLOCK_SIZE * hidx + 5] = float(
ActivationType.ACTIVATION_RELU.value)
hidx += 1

header[HEADER_BLOCK_SIZE * hidx + 0] = float(
LayerType.LAYER_SEQUENTIAL_FSMN.value)
header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim
header[HEADER_BLOCK_SIZE * hidx + 3] = self.proj_dim
header[HEADER_BLOCK_SIZE * hidx + 4] = self.lorder
header[HEADER_BLOCK_SIZE * hidx + 5] = self.rorder
header[HEADER_BLOCK_SIZE * hidx + 6] = self.fsmn_layers
header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0
hidx += 1

header[HEADER_BLOCK_SIZE * hidx + 0] = float(
LayerType.LAYER_DENSE.value)
header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim
header[HEADER_BLOCK_SIZE * hidx + 3] = self.num_syn
header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
header[HEADER_BLOCK_SIZE * hidx + 5] = float(
ActivationType.ACTIVATION_SOFTMAX.value)

for h in header:
print(f32ToI32(h))

def to_kaldi_nnet(self):
re_str = ''
re_str += '<Nnet>\n'
re_str += self.linear1.to_kaldi_nnet()
re_str += self.relu.to_kaldi_nnet()

for fsmn in self.fsmn:
re_str += fsmn[0].to_kaldi_nnet()
re_str += fsmn[1].to_kaldi_nnet()
re_str += fsmn[2].to_kaldi_nnet()
re_str += fsmn[3].to_kaldi_nnet()

re_str += self.linear2.to_kaldi_nnet()
re_str += '<Softmax> %d %d\n' % (self.num_syn, self.num_syn)
re_str += '<!EndOfComponent>\n'
re_str += '</Nnet>\n'

return re_str


class DFSMN(nn.Module):
"""
One deep fsmn layer
"""

def __init__(self,
dimproj=64,
dimlinear=128,
lorder=20,
rorder=1,
lstride=1,
rstride=1):
"""
Args:
dimproj: projection dimension, input and output dimension of memory blocks
dimlinear: dimension of mapping layer
lorder: left order
rorder: right order
lstride: left stride
rstride: right stride
"""
super(DFSMN, self).__init__()

self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride

self.expand = AffineTransform(dimproj, dimlinear)
self.shrink = LinearTransform(dimlinear, dimproj)

self.conv_left = nn.Conv2d(
dimproj,
dimproj, (lorder, 1),
dilation=(lstride, 1),
groups=dimproj,
bias=False)

if rorder > 0:
self.conv_right = nn.Conv2d(
dimproj,
dimproj, (rorder, 1),
dilation=(rstride, 1),
groups=dimproj,
bias=False)
else:
self.conv_right = None

def forward(self, input):
f1 = F.relu(self.expand(input))
p1 = self.shrink(f1)

x = torch.unsqueeze(p1, 1)
x_per = x.permute(0, 3, 2, 1)

y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])

if self.conv_right is not None:
y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
else:
out = x_per + self.conv_left(y_left)

out1 = out.permute(0, 3, 2, 1)
output = input + out1.squeeze(1)

return output

def print_model(self):
self.expand.print_model()
self.shrink.print_model()

tmpw = self.conv_left.weight
tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
for j in range(tmpw.shape[0]):
tmpwm[:, j] = tmpw[j, 0, :, 0]

printNeonMatrix(tmpwm)

if self.conv_right is not None:
tmpw = self.conv_right.weight
tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
for j in range(tmpw.shape[0]):
tmpwm[:, j] = tmpw[j, 0, :, 0]

printNeonMatrix(tmpwm)


def build_dfsmn_repeats(linear_dim=128,
proj_dim=64,
lorder=20,
rorder=1,
fsmn_layers=6):
"""
build stacked dfsmn layers
Args:
linear_dim:
proj_dim:
lorder:
rorder:
fsmn_layers:

Returns:

"""
repeats = [
nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
for i in range(fsmn_layers)
]

return nn.Sequential(*repeats)
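A minimal usage sketch for the FSMNNet defined above (not part of this commit), assuming the module path matches this diff; shapes follow the constructor defaults:

import torch
from modelscope.models.audio.kws.farfield.fsmn import FSMNNet

net = FSMNNet()                   # input_dim=200, num_syn=5 by default
feats = torch.randn(4, 50, 200)   # (batch, time, feature)
logits = net(feats)               # -> torch.Size([4, 50, 5])
kaldi_text = net.to_kaldi_nnet()  # serialize to Kaldi nnet1 text format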

+ 236
- 0
modelscope/models/audio/kws/farfield/fsmn_sele_v2.py View File

@@ -0,0 +1,236 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .fsmn import AffineTransform, Fsmn, LinearTransform, RectifiedLinear
from .model_def import HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32


class FSMNUnit(nn.Module):
""" A multi-channel fsmn unit

"""

def __init__(self, dimlinear=128, dimproj=64, lorder=20, rorder=1):
"""
Args:
dimlinear: input / output dimension
dimproj: fsmn input / output dimension
lorder: left order
rorder: right order
"""
super(FSMNUnit, self).__init__()

self.shrink = LinearTransform(dimlinear, dimproj)
self.fsmn = Fsmn(dimproj, dimproj, lorder, rorder, 1, 1)
self.expand = AffineTransform(dimproj, dimlinear)

self.debug = False
self.dataout = None

# forward input layout: (batch, time, channel, feature)

def forward(self, input):
if torch.cuda.is_available():
out = torch.zeros(input.shape).cuda()
else:
out = torch.zeros(input.shape)

for n in range(input.shape[2]):
out1 = self.shrink(input[:, :, n, :])
out2 = self.fsmn(out1)
out[:, :, n, :] = F.relu(self.expand(out2))

if self.debug:
self.dataout = out

return out

def print_model(self):
self.shrink.print_model()
self.fsmn.print_model()
self.expand.print_model()

def to_kaldi_nnet(self):
re_str = self.shrink.to_kaldi_nnet()
re_str += self.fsmn.to_kaldi_nnet()
re_str += self.expand.to_kaldi_nnet()

relu = RectifiedLinear(self.expand.linear.out_features,
self.expand.linear.out_features)
re_str += relu.to_kaldi_nnet()

return re_str


class FSMNSeleNetV2(nn.Module):
""" FSMN model with channel selection.
"""

def __init__(self,
input_dim=120,
linear_dim=128,
proj_dim=64,
lorder=20,
rorder=1,
num_syn=5,
fsmn_layers=5,
sele_layer=0):
"""
Args:
input_dim: input dimension
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: no. of fsmn units
sele_layer: channel selection layer index
"""
super(FSMNSeleNetV2, self).__init__()

self.sele_layer = sele_layer

self.featmap = AffineTransform(input_dim, linear_dim)

self.mem = []
for i in range(fsmn_layers):
unit = FSMNUnit(linear_dim, proj_dim, lorder, rorder)
self.mem.append(unit)
self.add_module('mem_{:d}'.format(i), unit)

self.decision = AffineTransform(linear_dim, num_syn)

def forward(self, input):
# multi-channel feature mapping
if torch.cuda.is_available():
x = torch.zeros(input.shape[0], input.shape[1], input.shape[2],
self.featmap.linear.out_features).cuda()
else:
x = torch.zeros(input.shape[0], input.shape[1], input.shape[2],
self.featmap.linear.out_features)

for n in range(input.shape[2]):
x[:, :, n, :] = F.relu(self.featmap(input[:, :, n, :]))

for i, unit in enumerate(self.mem):
y = unit(x)

# perform channel selection
if i == self.sele_layer:
pool = nn.MaxPool2d((y.shape[2], 1), stride=(y.shape[2], 1))
y = pool(y)

x = y

# remove channel dimension
y = torch.squeeze(y, -2)
z = self.decision(y)

return z

def print_model(self):
self.featmap.print_model()

for unit in self.mem:
unit.print_model()

self.decision.print_model()

def print_header(self):
'''
print the model header; first gather the FSMN params
'''
input_dim = self.featmap.linear.in_features
linear_dim = self.featmap.linear.out_features
proj_dim = self.mem[0].shrink.linear.out_features
lorder = self.mem[0].fsmn.conv_left.kernel_size[0]
rorder = 0
if self.mem[0].fsmn.conv_right is not None:
rorder = self.mem[0].fsmn.conv_right.kernel_size[0]

num_syn = self.decision.linear.out_features
fsmn_layers = len(self.mem)

# no. of output channels, 0.0 means the same as numins
# numouts = 0.0
numouts = 1.0

#
# write total header
#
header = [0.0] * HEADER_BLOCK_SIZE * 4
# numins
header[0] = 0.0
# numouts
header[1] = numouts
# dimins
header[2] = input_dim
# dimouts
header[3] = num_syn
# numlayers
header[4] = 3

#
# write each layer's header
#
hidx = 1

header[HEADER_BLOCK_SIZE * hidx + 0] = float(
LayerType.LAYER_DENSE.value)
header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
header[HEADER_BLOCK_SIZE * hidx + 2] = input_dim
header[HEADER_BLOCK_SIZE * hidx + 3] = linear_dim
header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
header[HEADER_BLOCK_SIZE * hidx + 5] = float(
ActivationType.ACTIVATION_RELU.value)
hidx += 1

header[HEADER_BLOCK_SIZE * hidx + 0] = float(
LayerType.LAYER_SEQUENTIAL_FSMN.value)
header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim
header[HEADER_BLOCK_SIZE * hidx + 3] = proj_dim
header[HEADER_BLOCK_SIZE * hidx + 4] = lorder
header[HEADER_BLOCK_SIZE * hidx + 5] = rorder
header[HEADER_BLOCK_SIZE * hidx + 6] = fsmn_layers
if numouts == 1.0:
header[HEADER_BLOCK_SIZE * hidx + 7] = float(self.sele_layer)
else:
header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0
hidx += 1

header[HEADER_BLOCK_SIZE * hidx + 0] = float(
LayerType.LAYER_DENSE.value)
header[HEADER_BLOCK_SIZE * hidx + 1] = numouts
header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim
header[HEADER_BLOCK_SIZE * hidx + 3] = num_syn
header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
header[HEADER_BLOCK_SIZE * hidx + 5] = float(
ActivationType.ACTIVATION_SOFTMAX.value)

for h in header:
print(f32ToI32(h))

def to_kaldi_nnet(self):
re_str = '<Nnet>\n'

re_str += self.featmap.to_kaldi_nnet()

relu = RectifiedLinear(self.featmap.linear.out_features,
self.featmap.linear.out_features)
re_str += relu.to_kaldi_nnet()

for unit in self.mem:
re_str += unit.to_kaldi_nnet()

re_str += self.decision.to_kaldi_nnet()

re_str += '<Softmax> %d %d\n' % (self.decision.linear.out_features,
self.decision.linear.out_features)
re_str += '<!EndOfComponent>\n'
re_str += '</Nnet>\n'

return re_str
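A hedged sketch (not in the commit) of running the multi-channel model end to end; the input layout (batch, time, channel, feature) follows the FSMNUnit comment above:

import torch
from modelscope.models.audio.kws.farfield.fsmn_sele_v2 import FSMNSeleNetV2

net = FSMNSeleNetV2()                  # input_dim=120, fsmn_layers=5, sele_layer=0
feats = torch.randn(2, 100, 3, 120)    # (batch, time, channel, feature)
logits = net(feats)                    # channel dim is max-pooled away at sele_layer
print(logits.shape)                    # torch.Size([2, 100, 5])
# note: FSMNUnit allocates work buffers on GPU whenever CUDA is available,
# so on a CUDA machine move both net and feats to .cuda() first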

+ 74
- 0
modelscope/models/audio/kws/farfield/model.py View File

@@ -0,0 +1,74 @@
import os
from typing import Dict

import torch

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .fsmn_sele_v2 import FSMNSeleNetV2


@MODELS.register_module(
Tasks.keyword_spotting, module_name=Models.speech_dfsmn_kws_char_farfield)
class FSMNSeleNetV2Decorator(TorchModel):
r""" A decorator of FSMNSeleNetV2 for integrating into modelscope framework """

MODEL_TXT = 'model.txt'
SC_CONFIG = 'sound_connect.conf'
SC_CONF_ITEM_KWS_MODEL = '${kws_model}'

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the dfsmn model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
model_bin_file = os.path.join(model_dir,
ModelFile.TORCH_MODEL_BIN_FILE)
self._model = None
if os.path.exists(model_bin_file):
self._model = FSMNSeleNetV2(*args, **kwargs)
checkpoint = torch.load(model_bin_file)
self._model.load_state_dict(checkpoint, strict=False)

self._sc = None
if os.path.exists(model_txt_file):
with open(sc_config_file) as f:
lines = f.readlines()
with open(sc_config_file, 'w') as f:
for line in lines:
if self.SC_CONF_ITEM_KWS_MODEL in line:
line = line.replace(self.SC_CONF_ITEM_KWS_MODEL,
model_txt_file)
f.write(line)
import py_sound_connect
self._sc = py_sound_connect.SoundConnect(sc_config_file)
self.size_in = self._sc.bytesPerBlockIn()
self.size_out = self._sc.bytesPerBlockOut()

if self._model is None and self._sc is None:
raise Exception(
f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.'
)

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
...

def forward_decode(self, data: bytes):
result = {'pcm': self._sc.process(data, self.size_out)}
state = self._sc.kwsState()
if state == 2:
result['kws'] = {
'keyword':
self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()),
'offset': self._sc.kwsKeywordOffset(),
'length': self._sc.kwsKeywordLength(),
'confidence': self._sc.kwsConfidence()
}
return result
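How the decorator's streaming interface is driven (a sketch mirroring the pipeline added later in this commit, not part of this file); it assumes the model directory holds model.txt and sound_connect.conf so the py_sound_connect branch is taken:

import wave
from modelscope.models.audio.kws import FSMNSeleNetV2Decorator

model = FSMNSeleNetV2Decorator('/path/to/model_dir')  # hypothetical local dir; normally built by the framework
with wave.open('data/test/audios/3ch_nihaomiya.wav', 'rb') as fin:
    nframe = model.size_in // (3 * 2)        # 3 channels x 2 bytes per sample
    data = fin.readframes(nframe)
    while len(data) >= model.size_in:
        result = model.forward_decode(data)  # {'pcm': ..., optionally 'kws': {...}}
        if 'kws' in result:
            print(result['kws'])             # keyword, offset, length, confidence
        data = fin.readframes(nframe)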

+ 121
- 0
modelscope/models/audio/kws/farfield/model_def.py View File

@@ -0,0 +1,121 @@
import math
import struct
from enum import Enum

HEADER_BLOCK_SIZE = 10


class LayerType(Enum):
LAYER_DENSE = 1
LAYER_GRU = 2
LAYER_ATTENTION = 3
LAYER_FSMN = 4
LAYER_SEQUENTIAL_FSMN = 5
LAYER_FSMN_SELE = 6
LAYER_GRU_ATTENTION = 7
LAYER_DFSMN = 8


class ActivationType(Enum):
ACTIVATION_NONE = 0
ACTIVATION_RELU = 1
ACTIVATION_TANH = 2
ACTIVATION_SIGMOID = 3
ACTIVATION_SOFTMAX = 4
ACTIVATION_LOGSOFTMAX = 5


def f32ToI32(f):
"""
print layer
"""
bs = struct.pack('f', f)

ba = bytearray()
ba.append(bs[0])
ba.append(bs[1])
ba.append(bs[2])
ba.append(bs[3])

return struct.unpack('i', ba)[0]


def printNeonMatrix(w):
"""
print matrix with neon padding
"""
numrows, numcols = w.shape
numnecols = math.ceil(numcols / 4)

for i in range(numrows):
for j in range(numcols):
print(f32ToI32(w[i, j]))

for j in range(numnecols * 4 - numcols):
print(0)


def printNeonVector(b):
"""
print vector with neon padding
"""
size = b.shape[0]
nesize = math.ceil(size / 4)

for i in range(size):
print(f32ToI32(b[i]))

for i in range(nesize * 4 - size):
print(0)


def printDense(layer):
"""
save dense layer
"""
statedict = layer.state_dict()
printNeonMatrix(statedict['weight'])
printNeonVector(statedict['bias'])


def printGRU(layer):
"""
save gru layer
"""
statedict = layer.state_dict()
weight = [statedict['weight_ih_l0'], statedict['weight_hh_l0']]
bias = [statedict['bias_ih_l0'], statedict['bias_hh_l0']]
numins, numouts = weight[0].shape
numins = numins // 3

# output input weights
w_rx = weight[0][:numins, :]
w_zx = weight[0][numins:numins * 2, :]
w_x = weight[0][numins * 2:, :]
printNeonMatrix(w_zx)
printNeonMatrix(w_rx)
printNeonMatrix(w_x)

# output recurrent weights
w_rh = weight[1][:numins, :]
w_zh = weight[1][numins:numins * 2, :]
w_h = weight[1][numins * 2:, :]
printNeonMatrix(w_zh)
printNeonMatrix(w_rh)
printNeonMatrix(w_h)

# output input bias
b_rx = bias[0][:numins]
b_zx = bias[0][numins:numins * 2]
b_x = bias[0][numins * 2:]
printNeonVector(b_zx)
printNeonVector(b_rx)
printNeonVector(b_x)

# output recurrent bias
b_rh = bias[1][:numins]
b_zh = bias[1][numins:numins * 2]
b_h = bias[1][numins * 2:]
printNeonVector(b_zh)
printNeonVector(b_rh)
printNeonVector(b_h)
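A quick check (not part of the commit) that f32ToI32 is a pure bit reinterpretation, so the printed integers can be unpacked back into the original float32 weights:

import struct
from modelscope.models.audio.kws.farfield.model_def import f32ToI32

i = f32ToI32(1.0)
print(i)                                           # 1065353216, the IEEE-754 bit pattern of 1.0
print(struct.unpack('f', struct.pack('i', i))[0])  # 1.0 recovered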

+ 14
- 1
modelscope/outputs.py View File

@@ -405,7 +405,7 @@ TASK_OUTPUTS = {

# audio processed for single file in PCM format
# {
# "output_pcm": np.array with shape(samples,) and dtype float32
# "output_pcm": pcm encoded audio bytes
# }
Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM],
Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM],
@@ -417,6 +417,19 @@ TASK_OUTPUTS = {
# }
Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM],

# {
# "kws_list": [
# {
# 'keyword': '', # the keyword spotted
# 'offset': 19.4, # the keyword start time in seconds
# 'length': 0.68, # the keyword duration in seconds
# 'confidence': 0.85 # the confidence that it is the keyword
# },
# ...
# ]
# }
Tasks.keyword_spotting: [OutputKeys.KWS_LIST],

# ============ multi-modal tasks ===================

# image caption result for single sample


+ 2
- 0
modelscope/pipelines/audio/__init__.py View File

@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .ans_pipeline import ANSPipeline
from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline
from .kws_farfield_pipeline import KWSFarfieldPipeline
from .kws_kwsbp_pipeline import KeyWordSpottingKwsbpPipeline
from .linear_aec_pipeline import LinearAECPipeline
from .text_to_speech_pipeline import TextToSpeechSambertHifiganPipeline
@@ -14,6 +15,7 @@ else:
_import_structure = {
'ans_pipeline': ['ANSPipeline'],
'asr_inference_pipeline': ['AutomaticSpeechRecognitionPipeline'],
'kws_farfield_pipeline': ['KWSFarfieldPipeline'],
'kws_kwsbp_pipeline': ['KeyWordSpottingKwsbpPipeline'],
'linear_aec_pipeline': ['LinearAECPipeline'],
'text_to_speech_pipeline': ['TextToSpeechSambertHifiganPipeline'],


+ 81
- 0
modelscope/pipelines/audio/kws_farfield_pipeline.py View File

@@ -0,0 +1,81 @@
import io
import wave
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
Tasks.keyword_spotting,
module_name=Pipelines.speech_dfsmn_kws_char_farfield)
class KWSFarfieldPipeline(Pipeline):
r"""A Keyword Spotting Inference Pipeline .

When invoke the class with pipeline.__call__(), it accept only one parameter:
inputs(str): the path of wav file
"""
SAMPLE_RATE = 16000
SAMPLE_WIDTH = 2
INPUT_CHANNELS = 3
OUTPUT_CHANNELS = 2

def __init__(self, model, **kwargs):
"""
use `model` to create a far-field keyword spotting pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
self.model = self.model.to(self.device)
self.model.eval()
frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH
self._nframe = self.model.size_in // frame_size
self.frame_count = 0

def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
if isinstance(inputs, bytes):
return dict(input_file=inputs)
elif isinstance(inputs, Dict):
return inputs
else:
raise ValueError(f'Not supported input type: {type(inputs)}')

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
input_file = inputs['input_file']
if isinstance(input_file, bytes):
input_file = io.BytesIO(input_file)
self.frame_count = 0
kws_list = []
with wave.open(input_file, 'rb') as fin:
if 'output_file' in inputs:
with wave.open(inputs['output_file'], 'wb') as fout:
fout.setframerate(self.SAMPLE_RATE)
fout.setnchannels(self.OUTPUT_CHANNELS)
fout.setsampwidth(self.SAMPLE_WIDTH)
self._process(fin, kws_list, fout)
else:
self._process(fin, kws_list)
return {OutputKeys.KWS_LIST: kws_list}

def _process(self,
fin: wave.Wave_read,
kws_list,
fout: wave.Wave_write = None):
data = fin.readframes(self._nframe)
while len(data) >= self.model.size_in:
self.frame_count += self._nframe
result = self.model.forward_decode(data)
if fout:
fout.writeframes(result['pcm'])
if 'kws' in result:
result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE
kws_list.append(result['kws'])
data = fin.readframes(self._nframe)

def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return inputs
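Typical invocation, mirroring the test added later in this commit; the input is a 3-channel 16 kHz wav and the hits come back under OutputKeys.KWS_LIST:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

kws = pipeline(
    Tasks.keyword_spotting,
    model='damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya')
result = kws({'input_file': 'data/test/audios/3ch_nihaomiya.wav'})
for hit in result['kws_list']:
    print(hit['keyword'], hit['offset'], hit['confidence'])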

+ 1
- 1
modelscope/pipelines/base.py View File

@@ -255,7 +255,7 @@ class Pipeline(ABC):
return self._collate_fn(torch.from_numpy(data))
elif isinstance(data, torch.Tensor):
return data.to(self.device)
elif isinstance(data, (str, int, float, bool, type(None))):
elif isinstance(data, (bytes, str, int, float, bool, type(None))):
return data
elif isinstance(data, InputFeatures):
return data


+ 1
- 0
requirements/audio.txt View File

@@ -16,6 +16,7 @@ numpy<=1.18
# protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged.
protobuf>3,<3.21.0
ptflops
py_sound_connect
pytorch_wavelets
PyWavelets>=1.0.0
scikit-learn


+ 43
- 0
tests/pipelines/test_key_word_spotting_farfield.py View File

@@ -0,0 +1,43 @@
import os.path
import unittest

from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'


class KWSFarfieldTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_normal(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)}
result = kws(inputs)
self.assertEqual(len(result['kws_list']), 5)
print(result['kws_list'][-1])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_output(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
inputs = {
'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE),
'output_file': 'output.wav'
}
result = kws(inputs)
self.assertEqual(len(result['kws_list']), 5)
print(result['kws_list'][-1])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_input_bytes(self):
with open(os.path.join(os.getcwd(), TEST_SPEECH_FILE), 'rb') as f:
data = f.read()
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
result = kws(data)
self.assertEqual(len(result['kws_list']), 5)
print(result['kws_list'][-1])

+ 12
- 36
tests/pipelines/test_speech_signal_process.py View File

@@ -8,22 +8,10 @@ from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav'
FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav'
NEAREND_MIC_FILE = 'nearend_mic.wav'
FAREND_SPEECH_FILE = 'farend_speech.wav'
NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav'
FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav'

NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav'
NOISE_SPEECH_FILE = 'speech_with_noise.wav'


def download(remote_path, local_path):
local_dir = os.path.dirname(local_path)
if len(local_dir) > 0:
if not os.path.exists(local_dir):
os.makedirs(local_dir)
with open(local_path, 'wb') as ofile:
ofile.write(File.read(remote_path))
NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav'


class SpeechSignalProcessTest(unittest.TestCase):
@@ -33,13 +21,10 @@ class SpeechSignalProcessTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_aec(self):
# Download audio files
download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
model_id = 'damo/speech_dfsmn_aec_psm_16k'
input = {
'nearend_mic': NEAREND_MIC_FILE,
'farend_speech': FAREND_SPEECH_FILE
'nearend_mic': os.path.join(os.getcwd(), NEAREND_MIC_FILE),
'farend_speech': os.path.join(os.getcwd(), FAREND_SPEECH_FILE)
}
aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id)
output_path = os.path.abspath('output.wav')
@@ -48,14 +33,11 @@ class SpeechSignalProcessTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_aec_bytes(self):
# Download audio files
download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
model_id = 'damo/speech_dfsmn_aec_psm_16k'
input = {}
with open(NEAREND_MIC_FILE, 'rb') as f:
with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f:
input['nearend_mic'] = f.read()
with open(FAREND_SPEECH_FILE, 'rb') as f:
with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f:
input['farend_speech'] = f.read()
aec = pipeline(
Tasks.acoustic_echo_cancellation,
@@ -67,13 +49,10 @@ class SpeechSignalProcessTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_aec_tuple_bytes(self):
# Download audio files
download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
model_id = 'damo/speech_dfsmn_aec_psm_16k'
with open(NEAREND_MIC_FILE, 'rb') as f:
with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f:
nearend_bytes = f.read()
with open(FAREND_SPEECH_FILE, 'rb') as f:
with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f:
farend_bytes = f.read()
inputs = (nearend_bytes, farend_bytes)
aec = pipeline(
@@ -86,25 +65,22 @@ class SpeechSignalProcessTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_ans(self):
# Download audio files
download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
model_id = 'damo/speech_frcrn_ans_cirm_16k'
ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
output_path = os.path.abspath('output.wav')
ans(NOISE_SPEECH_FILE, output_path=output_path)
ans(os.path.join(os.getcwd(), NOISE_SPEECH_FILE),
output_path=output_path)
print(f'Processed audio saved to {output_path}')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ans_bytes(self):
# Download audio files
download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
model_id = 'damo/speech_frcrn_ans_cirm_16k'
ans = pipeline(
Tasks.acoustic_noise_suppression,
model=model_id,
pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
output_path = os.path.abspath('output.wav')
with open(NOISE_SPEECH_FILE, 'rb') as f:
with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), 'rb') as f:
data = f.read()
ans(data, output_path=output_path)
print(f'Processed audio saved to {output_path}')

