Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9767151 (master)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ad1a268c614076614a2ae6528abc29cc85ae35826d172079d7d9b26a0299559
size 4325096
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3637ee0628d0953f77d5a32327980af542c43230c4127d2a72b4df1ea2ffb0be
size 320042
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc116af609a66f431f94df6b385ff2aa362f8a2d437c2279f5401e47f9178469
size 320042
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9354345a6297f4522e690d337546aa9a686a7e61eefcd935478a2141b924db8f
size 76770
@@ -38,6 +38,7 @@ class Models(object):
     # audio models
     sambert_hifigan = 'sambert-hifigan'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
+    speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
     kws_kwsbp = 'kws-kwsbp'
     generic_asr = 'generic-asr'
@@ -133,6 +134,7 @@ class Pipelines(object):
     sambert_hifigan_tts = 'sambert-hifigan-tts'
     speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
+    speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
     kws_kwsbp = 'kws-kwsbp'
     asr_inference = 'asr-inference'
@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .generic_key_word_spotting import GenericKeyWordSpotting
+    from .farfield.model import FSMNSeleNetV2Decorator
 else:
     _import_structure = {
         'generic_key_word_spotting': ['GenericKeyWordSpotting'],
+        'farfield.model': ['FSMNSeleNetV2Decorator'],
     }

     import sys
@@ -0,0 +1,495 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .model_def import (HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32,
                        printNeonMatrix, printNeonVector)

DEBUG = False


def to_kaldi_matrix(np_mat):
    """ Transform a numpy matrix into a standard Kaldi string matrix.

    Args:
        np_mat: numpy matrix

    Returns: str
    """
    np.set_printoptions(threshold=np.inf, linewidth=np.nan)
    out_str = str(np_mat)
    out_str = out_str.replace('[', '')
    out_str = out_str.replace(']', '')
    return '[ %s ]\n' % out_str


def print_tensor(torch_tensor):
    """ Print a torch tensor for debugging.

    Args:
        torch_tensor: a tensor
    """
    re_str = ''
    x = torch_tensor.detach().squeeze().numpy()
    re_str += to_kaldi_matrix(x)
    re_str += '<!EndOfComponent>\n'
    print(re_str)


class LinearTransform(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(LinearTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim, bias=False)
        self.debug = False
        self.dataout = None

    def forward(self, input):
        output = self.linear(input)
        if self.debug:
            self.dataout = output
        return output

    def print_model(self):
        printNeonMatrix(self.linear.weight)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
                                                 self.input_dim)
        re_str += '<LearnRateCoef> 1\n'
        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        re_str += '<!EndOfComponent>\n'
        return re_str


class AffineTransform(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(AffineTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim)
        self.debug = False
        self.dataout = None

    def forward(self, input):
        output = self.linear(input)
        if self.debug:
            self.dataout = output
        return output

    def print_model(self):
        printNeonMatrix(self.linear.weight)
        printNeonVector(self.linear.bias)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
                                                 self.input_dim)
        re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        linear_bias = self.state_dict()['linear.bias']
        x = linear_bias.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        re_str += '<!EndOfComponent>\n'
        return re_str


class Fsmn(nn.Module):
    """
    FSMN implementation.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 lorder=None,
                 rorder=None,
                 lstride=None,
                 rstride=None):
        super(Fsmn, self).__init__()
        self.dim = input_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        self.conv_left = nn.Conv2d(
            self.dim,
            self.dim, (lorder, 1),
            dilation=(lstride, 1),
            groups=self.dim,
            bias=False)
        if rorder > 0:
            self.conv_right = nn.Conv2d(
                self.dim,
                self.dim, (rorder, 1),
                dilation=(rstride, 1),
                groups=self.dim,
                bias=False)
        else:
            self.conv_right = None
        self.debug = False
        self.dataout = None

    def forward(self, input):
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)
        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
        if self.conv_right is not None:
            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
            y_right = y_right[:, :, self.rstride:, :]
            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
        else:
            out = x_per + self.conv_left(y_left)
        out1 = out.permute(0, 3, 2, 1)
        output = out1.squeeze(1)
        if self.debug:
            self.dataout = output
        return output

    def print_model(self):
        tmpw = self.conv_left.weight
        tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
        for j in range(tmpw.shape[0]):
            tmpwm[:, j] = tmpw[j, 0, :, 0]
        printNeonMatrix(tmpwm)
        if self.conv_right is not None:
            tmpw = self.conv_right.weight
            tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
            for j in range(tmpw.shape[0]):
                tmpwm[:, j] = tmpw[j, 0, :, 0]
            printNeonMatrix(tmpwm)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
        re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
            1, self.lorder, self.rorder, self.lstride, self.rstride)
        lfiters = self.state_dict()['conv_left.weight']
        x = np.flipud(lfiters.squeeze().numpy().T)
        re_str += to_kaldi_matrix(x)
        if self.conv_right is not None:
            rfiters = self.state_dict()['conv_right.weight']
            x = (rfiters.squeeze().numpy().T)
            re_str += to_kaldi_matrix(x)
        re_str += '<!EndOfComponent>\n'
        return re_str


class RectifiedLinear(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(RectifiedLinear, self).__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()

    def forward(self, input):
        return self.relu(input)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
        re_str += '<!EndOfComponent>\n'
        return re_str


class FSMNNet(nn.Module):
    """
    FSMN net for keyword spotting
    """

    def __init__(self,
                 input_dim=200,
                 linear_dim=128,
                 proj_dim=128,
                 lorder=10,
                 rorder=1,
                 num_syn=5,
                 fsmn_layers=4):
        """
        Args:
            input_dim: input dimension
            linear_dim: fsmn input dimension
            proj_dim: fsmn projection dimension
            lorder: fsmn left order
            rorder: fsmn right order
            num_syn: output dimension
            fsmn_layers: no. of sequential fsmn layers
        """
        super(FSMNNet, self).__init__()
        self.input_dim = input_dim
        self.linear_dim = linear_dim
        self.proj_dim = proj_dim
        self.lorder = lorder
        self.rorder = rorder
        self.num_syn = num_syn
        self.fsmn_layers = fsmn_layers
        self.linear1 = AffineTransform(input_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)
        self.fsmn = self._build_repeats(linear_dim, proj_dim, lorder, rorder,
                                        fsmn_layers)
        self.linear2 = AffineTransform(linear_dim, num_syn)

    @staticmethod
    def _build_repeats(linear_dim=136,
                       proj_dim=68,
                       lorder=3,
                       rorder=2,
                       fsmn_layers=5):
        repeats = [
            nn.Sequential(
                LinearTransform(linear_dim, proj_dim),
                Fsmn(proj_dim, proj_dim, lorder, rorder, 1, 1),
                AffineTransform(proj_dim, linear_dim),
                RectifiedLinear(linear_dim, linear_dim))
            for i in range(fsmn_layers)
        ]
        return nn.Sequential(*repeats)

    def forward(self, input):
        x1 = self.linear1(input)
        x2 = self.relu(x1)
        x3 = self.fsmn(x2)
        x4 = self.linear2(x3)
        return x4

    def print_model(self):
        self.linear1.print_model()
        for layer in self.fsmn:
            layer[0].print_model()
            layer[1].print_model()
            layer[2].print_model()
        self.linear2.print_model()

    def print_header(self):
        #
        # write total header
        #
        header = [0.0] * HEADER_BLOCK_SIZE * 4
        # numins
        header[0] = 0.0
        # numouts
        header[1] = 0.0
        # dimins
        header[2] = self.input_dim
        # dimouts
        header[3] = self.num_syn
        # numlayers
        header[4] = 3
        #
        # write each layer's header
        #
        hidx = 1
        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_DENSE.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
        header[HEADER_BLOCK_SIZE * hidx + 2] = self.input_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = self.linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
        header[HEADER_BLOCK_SIZE * hidx + 5] = float(
            ActivationType.ACTIVATION_RELU.value)
        hidx += 1
        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_SEQUENTIAL_FSMN.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
        header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = self.proj_dim
        header[HEADER_BLOCK_SIZE * hidx + 4] = self.lorder
        header[HEADER_BLOCK_SIZE * hidx + 5] = self.rorder
        header[HEADER_BLOCK_SIZE * hidx + 6] = self.fsmn_layers
        header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0
        hidx += 1
        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_DENSE.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
        header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = self.num_syn
        header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
        header[HEADER_BLOCK_SIZE * hidx + 5] = float(
            ActivationType.ACTIVATION_SOFTMAX.value)
        for h in header:
            print(f32ToI32(h))

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<Nnet>\n'
        re_str += self.linear1.to_kaldi_nnet()
        re_str += self.relu.to_kaldi_nnet()
        for fsmn in self.fsmn:
            re_str += fsmn[0].to_kaldi_nnet()
            re_str += fsmn[1].to_kaldi_nnet()
            re_str += fsmn[2].to_kaldi_nnet()
            re_str += fsmn[3].to_kaldi_nnet()
        re_str += self.linear2.to_kaldi_nnet()
        re_str += '<Softmax> %d %d\n' % (self.num_syn, self.num_syn)
        re_str += '<!EndOfComponent>\n'
        re_str += '</Nnet>\n'
        return re_str


class DFSMN(nn.Module):
    """
    One deep fsmn layer
    """

    def __init__(self,
                 dimproj=64,
                 dimlinear=128,
                 lorder=20,
                 rorder=1,
                 lstride=1,
                 rstride=1):
        """
        Args:
            dimproj: projection dimension, input and output dimension of memory blocks
            dimlinear: dimension of mapping layer
            lorder: left order
            rorder: right order
            lstride: left stride
            rstride: right stride
        """
        super(DFSMN, self).__init__()
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        self.expand = AffineTransform(dimproj, dimlinear)
        self.shrink = LinearTransform(dimlinear, dimproj)
        self.conv_left = nn.Conv2d(
            dimproj,
            dimproj, (lorder, 1),
            dilation=(lstride, 1),
            groups=dimproj,
            bias=False)
        if rorder > 0:
            self.conv_right = nn.Conv2d(
                dimproj,
                dimproj, (rorder, 1),
                dilation=(rstride, 1),
                groups=dimproj,
                bias=False)
        else:
            self.conv_right = None

    def forward(self, input):
        f1 = F.relu(self.expand(input))
        p1 = self.shrink(f1)
        x = torch.unsqueeze(p1, 1)
        x_per = x.permute(0, 3, 2, 1)
        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
        if self.conv_right is not None:
            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
            y_right = y_right[:, :, self.rstride:, :]
            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
        else:
            out = x_per + self.conv_left(y_left)
        out1 = out.permute(0, 3, 2, 1)
        output = input + out1.squeeze(1)
        return output

    def print_model(self):
        self.expand.print_model()
        self.shrink.print_model()
        tmpw = self.conv_left.weight
        tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
        for j in range(tmpw.shape[0]):
            tmpwm[:, j] = tmpw[j, 0, :, 0]
        printNeonMatrix(tmpwm)
        if self.conv_right is not None:
            tmpw = self.conv_right.weight
            tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
            for j in range(tmpw.shape[0]):
                tmpwm[:, j] = tmpw[j, 0, :, 0]
            printNeonMatrix(tmpwm)


def build_dfsmn_repeats(linear_dim=128,
                        proj_dim=64,
                        lorder=20,
                        rorder=1,
                        fsmn_layers=6):
    """
    Build stacked dfsmn layers.

    Args:
        linear_dim: dimension of the mapping layer
        proj_dim: projection dimension, input and output dimension of memory blocks
        lorder: left order
        rorder: right order
        fsmn_layers: number of stacked dfsmn layers

    Returns: nn.Sequential of stacked DFSMN layers
    """
    repeats = [
        nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
        for i in range(fsmn_layers)
    ]
    return nn.Sequential(*repeats)
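
Reviewer note: a quick smoke test for the stacked DFSMN helper. This is an illustrative sketch, not part of the change; the import path assumes the file lands at modelscope/models/audio/kws/farfield/fsmn.py, as the relative imports elsewhere in this PR suggest. Because each DFSMN layer adds its memory-block output back onto its input, the stack preserves the input shape.

import torch
from modelscope.models.audio.kws.farfield.fsmn import build_dfsmn_repeats

stack = build_dfsmn_repeats(
    linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6)
x = torch.randn(2, 100, 64)  # (batch, time, proj_dim)
y = stack(x)
assert y.shape == x.shape  # each DFSMN layer is residual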
@@ -0,0 +1,236 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .fsmn import AffineTransform, Fsmn, LinearTransform, RectifiedLinear
from .model_def import HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32


class FSMNUnit(nn.Module):
    """ A multi-channel fsmn unit
    """

    def __init__(self, dimlinear=128, dimproj=64, lorder=20, rorder=1):
        """
        Args:
            dimlinear: input / output dimension
            dimproj: fsmn input / output dimension
            lorder: left order
            rorder: right order
        """
        super(FSMNUnit, self).__init__()
        self.shrink = LinearTransform(dimlinear, dimproj)
        self.fsmn = Fsmn(dimproj, dimproj, lorder, rorder, 1, 1)
        self.expand = AffineTransform(dimproj, dimlinear)
        self.debug = False
        self.dataout = None

    def forward(self, input):
        # input layout: (batch, time, channel, feature)
        if torch.cuda.is_available():
            out = torch.zeros(input.shape).cuda()
        else:
            out = torch.zeros(input.shape)
        for n in range(input.shape[2]):
            out1 = self.shrink(input[:, :, n, :])
            out2 = self.fsmn(out1)
            out[:, :, n, :] = F.relu(self.expand(out2))
        if self.debug:
            self.dataout = out
        return out

    def print_model(self):
        self.shrink.print_model()
        self.fsmn.print_model()
        self.expand.print_model()

    def to_kaldi_nnet(self):
        re_str = self.shrink.to_kaldi_nnet()
        re_str += self.fsmn.to_kaldi_nnet()
        re_str += self.expand.to_kaldi_nnet()
        relu = RectifiedLinear(self.expand.linear.out_features,
                               self.expand.linear.out_features)
        re_str += relu.to_kaldi_nnet()
        return re_str


class FSMNSeleNetV2(nn.Module):
    """ FSMN model with channel selection.
    """

    def __init__(self,
                 input_dim=120,
                 linear_dim=128,
                 proj_dim=64,
                 lorder=20,
                 rorder=1,
                 num_syn=5,
                 fsmn_layers=5,
                 sele_layer=0):
        """
        Args:
            input_dim: input dimension
            linear_dim: fsmn input dimension
            proj_dim: fsmn projection dimension
            lorder: fsmn left order
            rorder: fsmn right order
            num_syn: output dimension
            fsmn_layers: no. of fsmn units
            sele_layer: channel selection layer index
        """
        super(FSMNSeleNetV2, self).__init__()
        self.sele_layer = sele_layer
        self.featmap = AffineTransform(input_dim, linear_dim)
        self.mem = []
        for i in range(fsmn_layers):
            unit = FSMNUnit(linear_dim, proj_dim, lorder, rorder)
            self.mem.append(unit)
            self.add_module('mem_{:d}'.format(i), unit)
        self.decision = AffineTransform(linear_dim, num_syn)

    def forward(self, input):
        # multi-channel feature mapping
        if torch.cuda.is_available():
            x = torch.zeros(input.shape[0], input.shape[1], input.shape[2],
                            self.featmap.linear.out_features).cuda()
        else:
            x = torch.zeros(input.shape[0], input.shape[1], input.shape[2],
                            self.featmap.linear.out_features)
        for n in range(input.shape[2]):
            x[:, :, n, :] = F.relu(self.featmap(input[:, :, n, :]))
        for i, unit in enumerate(self.mem):
            y = unit(x)
            # perform channel selection
            if i == self.sele_layer:
                pool = nn.MaxPool2d((y.shape[2], 1), stride=(y.shape[2], 1))
                y = pool(y)
            x = y
        # remove channel dimension
        y = torch.squeeze(y, -2)
        z = self.decision(y)
        return z

    def print_model(self):
        self.featmap.print_model()
        for unit in self.mem:
            unit.print_model()
        self.decision.print_model()

    def print_header(self):
        '''
        get FSMN params
        '''
        input_dim = self.featmap.linear.in_features
        linear_dim = self.featmap.linear.out_features
        proj_dim = self.mem[0].shrink.linear.out_features
        lorder = self.mem[0].fsmn.conv_left.kernel_size[0]
        rorder = 0
        if self.mem[0].fsmn.conv_right is not None:
            rorder = self.mem[0].fsmn.conv_right.kernel_size[0]
        num_syn = self.decision.linear.out_features
        fsmn_layers = len(self.mem)
        # no. of output channels, 0.0 means the same as numins
        # numouts = 0.0
        numouts = 1.0
        #
        # write total header
        #
        header = [0.0] * HEADER_BLOCK_SIZE * 4
        # numins
        header[0] = 0.0
        # numouts
        header[1] = numouts
        # dimins
        header[2] = input_dim
        # dimouts
        header[3] = num_syn
        # numlayers
        header[4] = 3
        #
        # write each layer's header
        #
        hidx = 1
        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_DENSE.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
        header[HEADER_BLOCK_SIZE * hidx + 2] = input_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
        header[HEADER_BLOCK_SIZE * hidx + 5] = float(
            ActivationType.ACTIVATION_RELU.value)
        hidx += 1
        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_SEQUENTIAL_FSMN.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
        header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = proj_dim
        header[HEADER_BLOCK_SIZE * hidx + 4] = lorder
        header[HEADER_BLOCK_SIZE * hidx + 5] = rorder
        header[HEADER_BLOCK_SIZE * hidx + 6] = fsmn_layers
        if numouts == 1.0:
            header[HEADER_BLOCK_SIZE * hidx + 7] = float(self.sele_layer)
        else:
            header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0
        hidx += 1
        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_DENSE.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = numouts
        header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = num_syn
        header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
        header[HEADER_BLOCK_SIZE * hidx + 5] = float(
            ActivationType.ACTIVATION_SOFTMAX.value)
        for h in header:
            print(f32ToI32(h))

    def to_kaldi_nnet(self):
        re_str = '<Nnet>\n'
        # append here; assigning would silently drop the '<Nnet>' header
        re_str += self.featmap.to_kaldi_nnet()
        relu = RectifiedLinear(self.featmap.linear.out_features,
                               self.featmap.linear.out_features)
        re_str += relu.to_kaldi_nnet()
        for unit in self.mem:
            re_str += unit.to_kaldi_nnet()
        re_str += self.decision.to_kaldi_nnet()
        re_str += '<Softmax> %d %d\n' % (self.decision.linear.out_features,
                                         self.decision.linear.out_features)
        re_str += '<!EndOfComponent>\n'
        re_str += '</Nnet>\n'
        return re_str
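
Reviewer note: a shape-check sketch for the selection network (illustrative only; the module path is assumed from this PR's relative imports). Note that forward() allocates its buffers on CUDA whenever torch.cuda.is_available(), so run this on a CPU-only machine or move both model and input to the GPU.

import torch
from modelscope.models.audio.kws.farfield.fsmn_sele_v2 import FSMNSeleNetV2

net = FSMNSeleNetV2(input_dim=120, linear_dim=128, proj_dim=64, lorder=20,
                    rorder=1, num_syn=5, fsmn_layers=5, sele_layer=0)
x = torch.randn(2, 50, 3, 120)  # (batch, time, channels, feature)
z = net(x)
print(z.shape)  # torch.Size([2, 50, 5]); max-pooling collapsed the 3 channels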
@@ -0,0 +1,74 @@
import os
from typing import Dict

import torch

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .fsmn_sele_v2 import FSMNSeleNetV2


@MODELS.register_module(
    Tasks.keyword_spotting, module_name=Models.speech_dfsmn_kws_char_farfield)
class FSMNSeleNetV2Decorator(TorchModel):
    r""" A decorator of FSMNSeleNetV2 for integration into the ModelScope framework """
    MODEL_TXT = 'model.txt'
    SC_CONFIG = 'sound_connect.conf'
    SC_CONF_ITEM_KWS_MODEL = '${kws_model}'

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the dfsmn model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
        model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
        model_bin_file = os.path.join(model_dir,
                                      ModelFile.TORCH_MODEL_BIN_FILE)
        self._model = None
        if os.path.exists(model_bin_file):
            self._model = FSMNSeleNetV2(*args, **kwargs)
            checkpoint = torch.load(model_bin_file)
            self._model.load_state_dict(checkpoint, strict=False)
        self._sc = None
        if os.path.exists(model_txt_file):
            with open(sc_config_file) as f:
                lines = f.readlines()
            with open(sc_config_file, 'w') as f:
                for line in lines:
                    if self.SC_CONF_ITEM_KWS_MODEL in line:
                        line = line.replace(self.SC_CONF_ITEM_KWS_MODEL,
                                            model_txt_file)
                    f.write(line)
            import py_sound_connect
            self._sc = py_sound_connect.SoundConnect(sc_config_file)
            self.size_in = self._sc.bytesPerBlockIn()
            self.size_out = self._sc.bytesPerBlockOut()
        if self._model is None and self._sc is None:
            raise Exception(
                f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.'
            )

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        ...

    def forward_decode(self, data: bytes):
        result = {'pcm': self._sc.process(data, self.size_out)}
        state = self._sc.kwsState()
        if state == 2:
            result['kws'] = {
                'keyword':
                self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()),
                'offset': self._sc.kwsKeywordOffset(),
                'length': self._sc.kwsKeywordLength(),
                'confidence': self._sc.kwsConfidence()
            }
        return result
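
Reviewer note: the decorator exposes two entry points, forward() (the TorchModel interface, left unimplemented here) and forward_decode(), which streams fixed-size byte blocks through py_sound_connect. A hypothetical streaming loop, with placeholder file and directory names:

model = FSMNSeleNetV2Decorator('/path/to/model_dir')
with open('mic_3ch.pcm', 'rb') as f:  # interleaved 16-bit multi-channel PCM
    while True:
        block = f.read(model.size_in)  # bytesPerBlockIn() from py_sound_connect
        if len(block) < model.size_in:
            break
        result = model.forward_decode(block)
        if 'kws' in result:  # present only when kwsState() == 2 (keyword spotted)
            print(result['kws'])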
@@ -0,0 +1,121 @@
import math
import struct
from enum import Enum

HEADER_BLOCK_SIZE = 10


class LayerType(Enum):
    LAYER_DENSE = 1
    LAYER_GRU = 2
    LAYER_ATTENTION = 3
    LAYER_FSMN = 4
    LAYER_SEQUENTIAL_FSMN = 5
    LAYER_FSMN_SELE = 6
    LAYER_GRU_ATTENTION = 7
    LAYER_DFSMN = 8


class ActivationType(Enum):
    ACTIVATION_NONE = 0
    ACTIVATION_RELU = 1
    ACTIVATION_TANH = 2
    ACTIVATION_SIGMOID = 3
    ACTIVATION_SOFTMAX = 4
    ACTIVATION_LOGSOFTMAX = 5


def f32ToI32(f):
    """
    Reinterpret the bits of a float32 as an int32.
    """
    # pack as 4 float32 bytes, then unpack the same bytes as an int32
    bs = struct.pack('f', f)
    ba = bytearray()
    ba.append(bs[0])
    ba.append(bs[1])
    ba.append(bs[2])
    ba.append(bs[3])
    return struct.unpack('i', ba)[0]


def printNeonMatrix(w):
    """
    print matrix with neon padding (columns padded to a multiple of 4)
    """
    numrows, numcols = w.shape
    numnecols = math.ceil(numcols / 4)
    for i in range(numrows):
        for j in range(numcols):
            print(f32ToI32(w[i, j]))
        for j in range(numnecols * 4 - numcols):
            print(0)


def printNeonVector(b):
    """
    print vector with neon padding (length padded to a multiple of 4)
    """
    size = b.shape[0]
    nesize = math.ceil(size / 4)
    for i in range(size):
        print(f32ToI32(b[i]))
    for i in range(nesize * 4 - size):
        print(0)


def printDense(layer):
    """
    save dense layer
    """
    statedict = layer.state_dict()
    printNeonMatrix(statedict['weight'])
    printNeonVector(statedict['bias'])


def printGRU(layer):
    """
    save gru layer
    """
    statedict = layer.state_dict()
    weight = [statedict['weight_ih_l0'], statedict['weight_hh_l0']]
    bias = [statedict['bias_ih_l0'], statedict['bias_hh_l0']]
    numins, numouts = weight[0].shape
    numins = numins // 3
    # output input weights
    w_rx = weight[0][:numins, :]
    w_zx = weight[0][numins:numins * 2, :]
    w_x = weight[0][numins * 2:, :]
    printNeonMatrix(w_zx)
    printNeonMatrix(w_rx)
    printNeonMatrix(w_x)
    # output recurrent weights
    w_rh = weight[1][:numins, :]
    w_zh = weight[1][numins:numins * 2, :]
    w_h = weight[1][numins * 2:, :]
    printNeonMatrix(w_zh)
    printNeonMatrix(w_rh)
    printNeonMatrix(w_h)
    # output input bias
    b_rx = bias[0][:numins]
    b_zx = bias[0][numins:numins * 2]
    b_x = bias[0][numins * 2:]
    printNeonVector(b_zx)
    printNeonVector(b_rx)
    printNeonVector(b_x)
    # output recurrent bias
    b_rh = bias[1][:numins]
    b_zh = bias[1][numins:numins * 2]
    b_h = bias[1][numins * 2:]
    printNeonVector(b_zh)
    printNeonVector(b_rh)
    printNeonVector(b_h)
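
Reviewer note: a worked example of the export primitives above (illustrative, standard-library behavior only). f32ToI32 emits the raw IEEE-754 bit pattern of each float32 as a signed int32, and the printNeon* helpers pad every row or vector to a multiple of 4 values so a NEON runtime can load them in 128-bit chunks.

import struct

# 1.0f has the bit pattern 0x3F800000:
print(struct.unpack('i', struct.pack('f', 1.0))[0])  # 1065353216
# printNeonMatrix on a matrix with 3 columns emits ceil(3 / 4) * 4 = 4
# integers per row, i.e. one trailing zero of padding.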
@@ -405,7 +405,7 @@ TASK_OUTPUTS = {
     # audio processed for single file in PCM format
     # {
-    #     "output_pcm": np.array with shape(samples,) and dtype float32
+    #     "output_pcm": pcm encoded audio bytes
     # }
     Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM],
     Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM],
@@ -417,6 +417,19 @@ TASK_OUTPUTS = {
     # }
     Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM],

+    # {
+    #     "kws_list": [
+    #         {
+    #             'keyword': '',       # the keyword spotted
+    #             'offset': 19.4,      # keyword start time in seconds
+    #             'length': 0.68,      # keyword duration in seconds
+    #             'confidence': 0.85   # confidence that it is the keyword
+    #         },
+    #         ...
+    #     ]
+    # }
+    Tasks.keyword_spotting: [OutputKeys.KWS_LIST],

     # ============ multi-modal tasks ===================
     # image caption result for single sample
@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .ans_pipeline import ANSPipeline
     from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline
+    from .kws_farfield_pipeline import KWSFarfieldPipeline
     from .kws_kwsbp_pipeline import KeyWordSpottingKwsbpPipeline
     from .linear_aec_pipeline import LinearAECPipeline
     from .text_to_speech_pipeline import TextToSpeechSambertHifiganPipeline
@@ -14,6 +15,7 @@ else:
     _import_structure = {
         'ans_pipeline': ['ANSPipeline'],
         'asr_inference_pipeline': ['AutomaticSpeechRecognitionPipeline'],
+        'kws_farfield_pipeline': ['KWSFarfieldPipeline'],
         'kws_kwsbp_pipeline': ['KeyWordSpottingKwsbpPipeline'],
         'linear_aec_pipeline': ['LinearAECPipeline'],
         'text_to_speech_pipeline': ['TextToSpeechSambertHifiganPipeline'],
@@ -0,0 +1,81 @@
import io
import wave
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
    Tasks.keyword_spotting,
    module_name=Pipelines.speech_dfsmn_kws_char_farfield)
class KWSFarfieldPipeline(Pipeline):
    r"""A keyword spotting inference pipeline.

    When invoked via pipeline.__call__(), it accepts a single parameter:
        inputs: wav data as bytes, or a dict with key 'input_file'
            (a wav path or wav bytes) and an optional 'output_file' path.
    """
    SAMPLE_RATE = 16000
    SAMPLE_WIDTH = 2
    INPUT_CHANNELS = 3
    OUTPUT_CHANNELS = 2

    def __init__(self, model, **kwargs):
        """
        Use `model` to create a far-field KWS pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        self.model = self.model.to(self.device)
        self.model.eval()
        frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH
        self._nframe = self.model.size_in // frame_size
        self.frame_count = 0

    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
        if isinstance(inputs, bytes):
            return dict(input_file=inputs)
        elif isinstance(inputs, Dict):
            return inputs
        else:
            raise ValueError(f'Not supported input type: {type(inputs)}')

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        input_file = inputs['input_file']
        if isinstance(input_file, bytes):
            input_file = io.BytesIO(input_file)
        self.frame_count = 0
        kws_list = []
        with wave.open(input_file, 'rb') as fin:
            if 'output_file' in inputs:
                with wave.open(inputs['output_file'], 'wb') as fout:
                    fout.setframerate(self.SAMPLE_RATE)
                    fout.setnchannels(self.OUTPUT_CHANNELS)
                    fout.setsampwidth(self.SAMPLE_WIDTH)
                    self._process(fin, kws_list, fout)
            else:
                self._process(fin, kws_list)
        return {OutputKeys.KWS_LIST: kws_list}

    def _process(self,
                 fin: wave.Wave_read,
                 kws_list,
                 fout: wave.Wave_write = None):
        data = fin.readframes(self._nframe)
        while len(data) >= self.model.size_in:
            self.frame_count += self._nframe
            result = self.model.forward_decode(data)
            if fout:
                fout.writeframes(result['pcm'])
            if 'kws' in result:
                result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE
                kws_list.append(result['kws'])
            data = fin.readframes(self._nframe)

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return inputs
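
Reviewer note: an end-to-end invocation sketch mirroring the new unit tests (the model id and wav path are taken from the test file below):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

kws = pipeline(
    Tasks.keyword_spotting,
    model='damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya')
result = kws({'input_file': 'data/test/audios/3ch_nihaomiya.wav'})
for hit in result['kws_list']:
    print(hit['keyword'], hit['offset'], hit['length'], hit['confidence'])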
@@ -255,7 +255,7 @@ class Pipeline(ABC):
             return self._collate_fn(torch.from_numpy(data))
         elif isinstance(data, torch.Tensor):
             return data.to(self.device)
-        elif isinstance(data, (str, int, float, bool, type(None))):
+        elif isinstance(data, (bytes, str, int, float, bool, type(None))):
             return data
         elif isinstance(data, InputFeatures):
             return data
@@ -16,6 +16,7 @@ numpy<=1.18
 # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged.
 protobuf>3,<3.21.0
 ptflops
+py_sound_connect
 pytorch_wavelets
 PyWavelets>=1.0.0
 scikit-learn
@@ -0,0 +1,43 @@
import os.path
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'


class KWSFarfieldTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_normal(self):
        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
        inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)}
        result = kws(inputs)
        self.assertEqual(len(result['kws_list']), 5)
        print(result['kws_list'][-1])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_output(self):
        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
        inputs = {
            'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE),
            'output_file': 'output.wav'
        }
        result = kws(inputs)
        self.assertEqual(len(result['kws_list']), 5)
        print(result['kws_list'][-1])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_input_bytes(self):
        with open(os.path.join(os.getcwd(), TEST_SPEECH_FILE), 'rb') as f:
            data = f.read()
        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
        result = kws(data)
        self.assertEqual(len(result['kws_list']), 5)
        print(result['kws_list'][-1])
@@ -8,22 +8,10 @@ from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level

-NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav'
-FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav'
-NEAREND_MIC_FILE = 'nearend_mic.wav'
-FAREND_SPEECH_FILE = 'farend_speech.wav'
+NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav'
+FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav'
-NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav'
-NOISE_SPEECH_FILE = 'speech_with_noise.wav'
-
-def download(remote_path, local_path):
-    local_dir = os.path.dirname(local_path)
-    if len(local_dir) > 0:
-        if not os.path.exists(local_dir):
-            os.makedirs(local_dir)
-    with open(local_path, 'wb') as ofile:
-        ofile.write(File.read(remote_path))
+NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav'


 class SpeechSignalProcessTest(unittest.TestCase):
@@ -33,13 +21,10 @@ class SpeechSignalProcessTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_aec(self):
-        # Download audio files
-        download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
-        download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
         model_id = 'damo/speech_dfsmn_aec_psm_16k'
         input = {
-            'nearend_mic': NEAREND_MIC_FILE,
-            'farend_speech': FAREND_SPEECH_FILE
+            'nearend_mic': os.path.join(os.getcwd(), NEAREND_MIC_FILE),
+            'farend_speech': os.path.join(os.getcwd(), FAREND_SPEECH_FILE)
         }
         aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id)
         output_path = os.path.abspath('output.wav')
@@ -48,14 +33,11 @@ class SpeechSignalProcessTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_aec_bytes(self):
-        # Download audio files
-        download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
-        download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
         model_id = 'damo/speech_dfsmn_aec_psm_16k'
         input = {}
-        with open(NEAREND_MIC_FILE, 'rb') as f:
+        with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f:
             input['nearend_mic'] = f.read()
-        with open(FAREND_SPEECH_FILE, 'rb') as f:
+        with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f:
             input['farend_speech'] = f.read()
         aec = pipeline(
             Tasks.acoustic_echo_cancellation,
@@ -67,13 +49,10 @@ class SpeechSignalProcessTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_aec_tuple_bytes(self):
-        # Download audio files
-        download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
-        download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
         model_id = 'damo/speech_dfsmn_aec_psm_16k'
-        with open(NEAREND_MIC_FILE, 'rb') as f:
+        with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f:
             nearend_bytes = f.read()
-        with open(FAREND_SPEECH_FILE, 'rb') as f:
+        with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f:
             farend_bytes = f.read()
         inputs = (nearend_bytes, farend_bytes)
         aec = pipeline(
@@ -86,25 +65,22 @@ class SpeechSignalProcessTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_ans(self):
-        # Download audio files
-        download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
         model_id = 'damo/speech_frcrn_ans_cirm_16k'
         ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
         output_path = os.path.abspath('output.wav')
-        ans(NOISE_SPEECH_FILE, output_path=output_path)
+        ans(os.path.join(os.getcwd(), NOISE_SPEECH_FILE),
+            output_path=output_path)
         print(f'Processed audio saved to {output_path}')

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_ans_bytes(self):
-        # Download audio files
-        download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
         model_id = 'damo/speech_frcrn_ans_cirm_16k'
         ans = pipeline(
             Tasks.acoustic_noise_suppression,
             model=model_id,
             pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
         output_path = os.path.abspath('output.wav')
-        with open(NOISE_SPEECH_FILE, 'rb') as f:
+        with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), 'rb') as f:
             data = f.read()
         ans(data, output_path=output_path)
         print(f'Processed audio saved to {output_path}')