Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8973072

* [to #41669377] docs and tools refinement and release
  1. add build_doc linter script
  2. add sphinx-docs support
  3. add development doc and api doc
  4. change version to 0.1.0 for the first internal release version

Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307

* [to #41669377] add pipeline tutorial and fix bugs
  1. add pipeline tutorial
  2. fix bugs when using pipeline with certain models and preprocessors

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8814301

* refine doc
* feat: add audio aec pipeline and preprocessor
* feat: add audio aec model classes
* feat: add audio aec loss functions
* refactor: delete no longer used loss function
* [to #42281043] support kwargs in pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8949062

* support kwargs in pipeline
* update development doc with CR instructions
* Merge branch 'release/0.1' into dev/aec
* style: reformat code with pre-commit tools
* feat: support maas_lib pipeline auto-downloading models
* test: add aec test case as sample code
* feat: aec pipeline uses config from maashub
* feat: aec pipeline uses feature parameters from maashub
* update setup.cfg to disable PEP8 rule W503 in flake8 and yapf
* format: fix double-quoted strings and indent issues; optimize imports
* refactor: extract some constants in aec pipeline
* refactor: delete no longer used __main__ statement
* chore: change all Chinese comments to English
* fix: change file name style to lower case
* refactor: rename model name
* feat: load C++ .so from LD_LIBRARY_PATH
* feat: register PREPROCESSOR for LinearAECAndFbank
* refactor: move aec processing from postprocess() to forward() and update comments
* refactor: add a more readable error message when the audio sample rate is not 16000
* fix: package maas_lib renamed to modelscope in import statements
* feat: optimize the error messages of audio layer classes
* format: delete empty lines
* refactor: rename audio preprocessor and optimize error message
* refactor: change aec model id to damo/speech_dfsmn_aec_psm_16k
* refactor: change sample audio file url to public oss
* Merge branch 'master' into dev/aec
* feat: add output info for aec pipeline
* fix: normalize output audio data to [-1.0, 1.0]
* refactor: use constants from ModelFile
* feat: AEC pipeline can use the c++ lib in the current working directory and the test will download it
* fix: c++ downloading should work wherever the test is triggered
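The commits above add kwargs support and model auto-download to the pipeline API. A minimal usage sketch follows; the task constant, input keys, and the output_path kwarg are assumptions inferred from these commit messages, not verified API:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Model id taken from the commit log; input keys and output_path are assumed.
inputs = {
    'nearend_mic': 'nearend_mic.wav',
    'farend_speech': 'farend_speech.wav',
}
aec = pipeline(
    Tasks.acoustic_echo_cancellation, model='damo/speech_dfsmn_aec_psm_16k')
result = aec(inputs, output_path='output.wav')  # extra kwargs forwarded by the pipeline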
@@ -0,0 +1,60 @@
import torch.nn as nn

from .layer_base import LayerBase


class RectifiedLinear(LayerBase):
    """ReLU activation with Kaldi nnet1 text-format serialization."""

    def __init__(self, input_dim, output_dim):
        super(RectifiedLinear, self).__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()

    def forward(self, input):
        return self.relu(input)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
        return re_str

    def load_kaldi_nnet(self, instr):
        # ReLU has no parameters, so there is nothing to parse.
        return instr


class LogSoftmax(LayerBase):
    """LogSoftmax activation, serialized with the <Softmax> tag in Kaldi nnet1 format."""

    def __init__(self, input_dim, output_dim):
        super(LogSoftmax, self).__init__()
        self.dim = input_dim
        self.ls = nn.LogSoftmax()

    def forward(self, input):
        return self.ls(input)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<Softmax> %d %d\n' % (self.dim, self.dim)
        return re_str

    def load_kaldi_nnet(self, instr):
        return instr


class Sigmoid(LayerBase):
    """Sigmoid activation with Kaldi nnet1 text-format serialization."""

    def __init__(self, input_dim, output_dim):
        super(Sigmoid, self).__init__()
        self.dim = input_dim
        self.sig = nn.Sigmoid()

    def forward(self, input):
        return self.sig(input)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<Sigmoid> %d %d\n' % (self.dim, self.dim)
        return re_str

    def load_kaldi_nnet(self, instr):
        return instr
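A quick sketch of the serialization contract above, using the classes just defined (the 128-dim size is illustrative):

relu = RectifiedLinear(128, 128)
print(relu.to_kaldi_nnet(), end='')  # <RectifiedLinear> 128 128
sig = Sigmoid(128, 128)
print(sig.to_kaldi_nnet(), end='')   # <Sigmoid> 128 128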
@@ -0,0 +1,78 @@
import numpy as np
import torch as th
import torch.nn as nn

from .layer_base import (LayerBase, expect_kaldi_matrix, expect_token_number,
                         to_kaldi_matrix)


class AffineTransform(LayerBase):
    """Fully connected layer with Kaldi nnet1 <AffineTransform> (de)serialization."""

    def __init__(self, input_dim, output_dim):
        super(AffineTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, input):
        return self.linear(input)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
                                                 self.input_dim)
        re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        linear_bias = self.state_dict()['linear.bias']
        x = linear_bias.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        return re_str

    def to_raw_nnet(self, fid):
        # Dump weight and bias as raw float buffers, in that order.
        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        x.tofile(fid)
        linear_bias = self.state_dict()['linear.bias']
        x = linear_bias.squeeze().numpy()
        x.tofile(fid)

    def load_kaldi_nnet(self, instr):
        output = expect_token_number(instr, '<LearnRateCoef>')
        if output is None:
            raise Exception('AffineTransform format error for <LearnRateCoef>')
        instr, lr = output

        output = expect_token_number(instr, '<BiasLearnRateCoef>')
        if output is None:
            raise Exception(
                'AffineTransform format error for <BiasLearnRateCoef>')
        instr, lr = output

        output = expect_token_number(instr, '<MaxNorm>')
        if output is None:
            raise Exception('AffineTransform format error for <MaxNorm>')
        instr, lr = output

        output = expect_kaldi_matrix(instr)
        if output is None:
            raise Exception('AffineTransform format error for parsing matrix')
        instr, mat = output
        self.linear.weight = th.nn.Parameter(
            th.from_numpy(mat).type(th.FloatTensor))

        output = expect_kaldi_matrix(instr)
        if output is None:
            raise Exception('AffineTransform format error for parsing matrix')
        instr, mat = output
        # the bias is stored as a 1xN matrix; squeeze it to a vector
        mat = np.squeeze(mat)
        self.linear.bias = th.nn.Parameter(
            th.from_numpy(mat).type(th.FloatTensor))
        return instr
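A small round-trip example of the format above (dimensions are illustrative; the component header line is consumed by the caller before load_kaldi_nnet is invoked):

layer = AffineTransform(4, 2)
text = layer.to_kaldi_nnet()
# text:
#   <AffineTransform> 2 4
#   <LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0
#   [ ...2x4 weights... ]
#   [ ...2 biases... ]
body = text.split('\n', 1)[1]  # strip the component header, as a caller would
rest = layer.load_kaldi_nnet(body)
assert rest.strip() == ''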
@@ -0,0 +1,178 @@
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .layer_base import (LayerBase, expect_kaldi_matrix, expect_token_number,
                         to_kaldi_matrix)


class DeepFsmn(LayerBase):
    """Bidirectional FSMN block: memory over lorder past and rorder future frames."""

    def __init__(self,
                 input_dim,
                 output_dim,
                 lorder=None,
                 rorder=None,
                 hidden_size=None,
                 layer_norm=False,
                 dropout=0):
        super(DeepFsmn, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        if lorder is None:
            # parameters are created later by load_kaldi_nnet()
            return
        self.lorder = lorder
        self.rorder = rorder
        self.hidden_size = hidden_size
        self.layer_norm = layer_norm
        self.linear = nn.Linear(input_dim, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)
        self.drop1 = nn.Dropout(p=dropout)
        self.drop2 = nn.Dropout(p=dropout)
        self.project = nn.Linear(hidden_size, output_dim, bias=False)
        # depthwise convolutions over time implement the FSMN memory
        self.conv1 = nn.Conv2d(
            output_dim,
            output_dim, [lorder, 1], [1, 1],
            groups=output_dim,
            bias=False)
        self.conv2 = nn.Conv2d(
            output_dim,
            output_dim, [rorder, 1], [1, 1],
            groups=output_dim,
            bias=False)

    def forward(self, input):
        # input: [B, T, F]
        f1 = F.relu(self.linear(input))
        f1 = self.drop1(f1)
        if self.layer_norm:
            f1 = self.norm(f1)
        p1 = self.project(f1)
        x = th.unsqueeze(p1, 1)
        # [B, F, T, 1]
        x_per = x.permute(0, 3, 2, 1)
        # left pad for the causal filter, right pad for the look-ahead filter
        y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
        yr = F.pad(x_per, [0, 0, 0, self.rorder])
        yr = yr[:, :, 1:, :]
        out = x_per + self.conv1(y) + self.conv2(yr)
        out = self.drop2(out)
        out1 = out.permute(0, 3, 2, 1)
        return input + out1.squeeze()

    def to_kaldi_nnet(self):
        # NOTE: serialized with the <UniDeepFsmn> tag; only the left
        # (causal) filter conv1 is written, conv2 is not exported.
        re_str = ''
        re_str += '<UniDeepFsmn> %d %d\n'\
                  % (self.output_dim, self.input_dim)
        re_str += '<LearnRateCoef> %d <HidSize> %d <LOrder> %d <LStride> %d <MaxNorm> 0\n'\
                  % (1, self.hidden_size, self.lorder, 1)
        lfiters = self.state_dict()['conv1.weight']
        x = np.flipud(lfiters.squeeze().numpy().T)
        re_str += to_kaldi_matrix(x)
        proj_weights = self.state_dict()['project.weight']
        x = proj_weights.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        linear_bias = self.state_dict()['linear.bias']
        x = linear_bias.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        return re_str

    def load_kaldi_nnet(self, instr):
        output = expect_token_number(instr, '<LearnRateCoef>')
        if output is None:
            raise Exception('UniDeepFsmn format error for <LearnRateCoef>')
        instr, lr = output

        output = expect_token_number(instr, '<HidSize>')
        if output is None:
            raise Exception('UniDeepFsmn format error for <HidSize>')
        instr, hiddensize = output
        self.hidden_size = int(hiddensize)

        output = expect_token_number(instr, '<LOrder>')
        if output is None:
            raise Exception('UniDeepFsmn format error for <LOrder>')
        instr, lorder = output
        self.lorder = int(lorder)

        output = expect_token_number(instr, '<LStride>')
        if output is None:
            raise Exception('UniDeepFsmn format error for <LStride>')
        instr, lstride = output
        self.lstride = int(lstride)

        output = expect_token_number(instr, '<MaxNorm>')
        if output is None:
            raise Exception('UniDeepFsmn format error for <MaxNorm>')
        instr, _ = output  # the <MaxNorm> value itself is unused

        # NOTE: only the unidirectional part (conv1, project, linear) is
        # present in the Kaldi format; conv2 is not restored.
        output = expect_kaldi_matrix(instr)
        if output is None:
            raise Exception('UniDeepFsmn format error for parsing matrix')
        instr, mat = output
        # Kaldi stores the filter time-reversed as [lorder, F]; flip it back
        mat1 = np.fliplr(mat.T).copy()
        self.conv1 = nn.Conv2d(
            self.output_dim,
            self.output_dim, [self.lorder, 1], [1, 1],
            groups=self.output_dim,
            bias=False)
        mat_th = th.from_numpy(mat1).type(th.FloatTensor)
        mat_th = mat_th.unsqueeze(1)
        mat_th = mat_th.unsqueeze(3)
        self.conv1.weight = th.nn.Parameter(mat_th)

        output = expect_kaldi_matrix(instr)
        if output is None:
            raise Exception('UniDeepFsmn format error for parsing matrix')
        instr, mat = output
        self.project = nn.Linear(self.hidden_size, self.output_dim, bias=False)
        self.linear = nn.Linear(self.input_dim, self.hidden_size)
        self.project.weight = th.nn.Parameter(
            th.from_numpy(mat).type(th.FloatTensor))

        output = expect_kaldi_matrix(instr)
        if output is None:
            raise Exception('UniDeepFsmn format error for parsing matrix')
        instr, mat = output
        self.linear.weight = th.nn.Parameter(
            th.from_numpy(mat).type(th.FloatTensor))

        output = expect_kaldi_matrix(instr)
        if output is None:
            raise Exception('UniDeepFsmn format error for parsing matrix')
        instr, mat = output
        # the bias is stored as a 1xN matrix; squeeze it to a vector
        mat = np.squeeze(mat)
        self.linear.bias = th.nn.Parameter(
            th.from_numpy(mat).type(th.FloatTensor))
        return instr
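A shape sketch for DeepFsmn.forward above (dimensions are illustrative; note the trailing squeeze() would also drop a batch dimension of one):

import torch as th

fsmn = DeepFsmn(140, 140, lorder=20, rorder=5, hidden_size=256)
x = th.randn(4, 100, 140)  # [B, T, F]
y = fsmn(x)                # [4, 100, 140]: input + past + look-ahead memory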
@@ -0,0 +1,50 @@
import abc
import re

import numpy as np
import torch.nn as nn


def expect_token_number(instr, token):
    """Consume `token` followed by a number from `instr`.

    Returns (remaining_string, number_string), or None on mismatch.
    """
    first_token = re.match(r'^\s*' + token, instr)
    if first_token is None:
        return None
    instr = instr[first_token.end():]
    lr = re.match(r'^\s*(-?\d+\.?\d*e?-?\d*?)', instr)
    if lr is None:
        return None
    return instr[lr.end():], lr.groups()[0]


def expect_kaldi_matrix(instr):
    """Parse the next bracketed Kaldi text matrix from `instr`.

    Returns (remaining_string, matrix), where matrix is a 2-D numpy array.
    """
    pos2 = instr.find('[', 0)
    pos3 = instr.find(']', pos2)
    mat = []
    for stt in instr[pos2 + 1:pos3].split('\n'):
        tmp_mat = np.fromstring(stt, dtype=np.float32, sep=' ')
        if tmp_mat.size > 0:
            mat.append(tmp_mat)
    return instr[pos3 + 1:], np.array(mat)


def to_kaldi_matrix(np_mat):
    """Format a numpy matrix as a standard Kaldi text matrix.

    :param np_mat: numpy matrix
    :return: str
    """
    np.set_printoptions(threshold=np.inf, linewidth=np.nan, suppress=True)
    out_str = str(np_mat)
    out_str = out_str.replace('[', '')
    out_str = out_str.replace(']', '')
    return '[ %s ]\n' % out_str


class LayerBase(nn.Module, metaclass=abc.ABCMeta):
    """Base class for layers that round-trip the Kaldi nnet1 text format."""

    def __init__(self):
        super(LayerBase, self).__init__()

    @abc.abstractmethod
    def to_kaldi_nnet(self):
        pass
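How the two parsers above cooperate on a Kaldi-style string (the input here is made up):

instr = '<LearnRateCoef> 1 [ 1 2\n 3 4 ]\nrest'
instr, lr = expect_token_number(instr, '<LearnRateCoef>')  # lr == '1'
instr, mat = expect_kaldi_matrix(instr)  # mat == [[1. 2.] [3. 4.]], shape (2, 2)
print(repr(instr))  # '\nrest' is left for the next parser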
@@ -0,0 +1,482 @@
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .layer_base import (LayerBase, expect_kaldi_matrix, expect_token_number,
                         to_kaldi_matrix)


class SepConv(nn.Module):
    """Depthwise-separable causal convolution.

    :param kernel_size: (time, frequency)
    """

    def __init__(self,
                 in_channels,
                 filters,
                 out_channels,
                 kernel_size=(5, 2),
                 dilation=(1, 1)):
        super(SepConv, self).__init__()
        # depthwise + pointwise
        self.dconv = nn.Conv2d(
            in_channels,
            in_channels * filters,
            kernel_size,
            dilation=dilation,
            groups=in_channels)
        self.pconv = nn.Conv2d(
            in_channels * filters, out_channels, kernel_size=1)
        self.padding = dilation[0] * (kernel_size[0] - 1)

    def forward(self, input):
        ''' input: [B, C, T, F]
        '''
        # pad only on the left of the time axis to stay causal
        x = F.pad(input, [0, 0, self.padding, 0])
        x = self.dconv(x)
        x = self.pconv(x)
        return x


class Conv2d(nn.Module):
    """Depthwise FSMN memory block with optional look-ahead and skip connection."""

    def __init__(self,
                 input_dim,
                 output_dim,
                 lorder=20,
                 rorder=0,
                 groups=1,
                 bias=False,
                 skip_connect=True):
        super(Conv2d, self).__init__()
        self.lorder = lorder
        self.conv = nn.Conv2d(
            input_dim, output_dim, [lorder, 1], groups=groups, bias=bias)
        self.rorder = rorder
        if self.rorder:
            self.conv2 = nn.Conv2d(
                input_dim, output_dim, [rorder, 1], groups=groups, bias=bias)
        self.skip_connect = skip_connect

    def forward(self, input):
        # [B, 1, T, F]
        x = th.unsqueeze(input, 1)
        # [B, F, T, 1]
        x_per = x.permute(0, 3, 2, 1)
        y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
        out = self.conv(y)
        if self.rorder:
            yr = F.pad(x_per, [0, 0, 0, self.rorder])
            yr = yr[:, :, 1:, :]
            out += self.conv2(yr)
        out = out.permute(0, 3, 2, 1).squeeze(1)
        if self.skip_connect:
            out = out + input
        return out


class SelfAttLayer(nn.Module):
    """FSMN variant whose memory taps are weighted by learned self-attention."""

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(SelfAttLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size
        self.linear = nn.Linear(input_dim, hidden_size)
        self.project = nn.Linear(hidden_size, output_dim, bias=False)
        self.att = nn.Linear(input_dim, lorder, bias=False)

    def forward(self, input):
        f1 = F.relu(self.linear(input))
        p1 = self.project(f1)
        x = th.unsqueeze(p1, 1)
        x_per = x.permute(0, 3, 2, 1)
        y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
        # z: [B, F, T, lorder], stacking the current and lorder-1 past frames
        z = x_per
        for i in range(1, self.lorder):
            z = th.cat([z, y[:, :, self.lorder - 1 - i:-i, :]], dim=-1)
        # [B, T, lorder]
        att = F.softmax(self.att(input), dim=-1)
        att = th.unsqueeze(att, 1)
        z = th.sum(z * att, dim=-1)
        out1 = z.permute(0, 2, 1)
        return input + out1


class TFFsmn(nn.Module):
    """FSMN with an extra depthwise convolution along the frequency axis."""

    def __init__(self,
                 input_dim,
                 output_dim,
                 lorder=None,
                 hidden_size=None,
                 dilation=1,
                 layer_norm=False,
                 dropout=0,
                 skip_connect=True):
        super(TFFsmn, self).__init__()
        self.skip_connect = skip_connect
        self.linear = nn.Linear(input_dim, hidden_size)
        self.norm = nn.Identity()
        if layer_norm:
            # NOTE: sized to input_dim but applied to a hidden_size-wide
            # tensor; this assumes hidden_size == input_dim
            self.norm = nn.LayerNorm(input_dim)
        self.act = nn.ReLU()
        self.project = nn.Linear(hidden_size, output_dim, bias=False)
        self.conv1 = nn.Conv2d(
            output_dim,
            output_dim, [lorder, 1],
            dilation=[dilation, 1],
            groups=output_dim,
            bias=False)
        self.padding_left = dilation * (lorder - 1)
        dorder = 5
        self.conv2 = nn.Conv2d(1, 1, [dorder, 1], bias=False)
        self.padding_freq = dorder - 1

    def forward(self, input):
        return self.compute1(input)

    def compute1(self, input):
        ''' linear-dconv-relu(norm)-linear-dconv
        '''
        x = self.linear(input)
        # [B, 1, F, T]
        x = th.unsqueeze(x, 1).permute(0, 1, 3, 2)
        z = F.pad(x, [0, 0, self.padding_freq, 0])
        z = self.conv2(z) + x
        x = z.permute(0, 3, 2, 1).squeeze(-1)
        x = self.act(x)
        x = self.norm(x)
        x = self.project(x)
        x = th.unsqueeze(x, 1).permute(0, 3, 2, 1)
        # [B, F, T+lorder-1, 1]
        y = F.pad(x, [0, 0, self.padding_left, 0])
        out = self.conv1(y)
        if self.skip_connect:
            out = out + x
        out = out.permute(0, 3, 2, 1).squeeze()
        return input + out


class CNNFsmn(nn.Module):
    ''' FSMN variant that uses a 2-D convolution to reduce parameters.
    '''

    def __init__(self,
                 input_dim,
                 output_dim,
                 lorder=None,
                 hidden_size=None,
                 dilation=1,
                 layer_norm=False,
                 dropout=0,
                 skip_connect=True):
        super(CNNFsmn, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.skip_connect = skip_connect
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size
        self.linear = nn.Linear(input_dim, hidden_size)
        self.act = nn.ReLU()
        kernel_size = (3, 8)
        stride = (1, 4)
        self.conv = nn.Sequential(
            nn.ConstantPad2d((stride[1], 0, kernel_size[0] - 1, 0), 0),
            nn.Conv2d(1, stride[1], kernel_size=kernel_size, stride=stride))
        self.dconv = nn.Conv2d(
            output_dim,
            output_dim, [lorder, 1],
            dilation=[dilation, 1],
            groups=output_dim,
            bias=False)
        self.padding_left = dilation * (lorder - 1)

    def forward(self, input):
        return self.compute2(input)

    def compute1(self, input):
        ''' linear-relu(norm)-conv2d-relu?-dconv
        '''
        # [B, T, F]
        x = self.linear(input)
        x = self.act(x)
        x = th.unsqueeze(x, 1)
        x = self.conv(x)
        # [B, C, T, F] -> [B, 1, T, F]
        b, c, t, f = x.shape
        x = x.view([b, 1, t, -1])
        x = x.permute(0, 3, 2, 1)
        # [B, F, T+lorder-1, 1]
        y = F.pad(x, [0, 0, self.padding_left, 0])
        out = self.dconv(y)
        if self.skip_connect:
            out = out + x
        out = out.permute(0, 3, 2, 1).squeeze()
        return input + out

    def compute2(self, input):
        ''' conv2d-relu-linear-relu?-dconv
        '''
        x = th.unsqueeze(input, 1)
        x = self.conv(x)
        x = self.act(x)
        # [B, C, T, F] -> [B, T, F]
        b, c, t, f = x.shape
        x = x.view([b, t, -1])
        x = self.linear(x)
        x = th.unsqueeze(x, 1).permute(0, 3, 2, 1)
        y = F.pad(x, [0, 0, self.padding_left, 0])
        out = self.dconv(y)
        if self.skip_connect:
            out = out + x
        out = out.permute(0, 3, 2, 1).squeeze()
        return input + out


class UniDeepFsmn(LayerBase):
    """Unidirectional (causal) deep FSMN block with Kaldi (de)serialization."""

    def __init__(self,
                 input_dim,
                 output_dim,
                 lorder=None,
                 hidden_size=None,
                 dilation=1,
                 layer_norm=False,
                 dropout=0,
                 skip_connect=True):
        super(UniDeepFsmn, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.skip_connect = skip_connect
        if lorder is None:
            # parameters are created later by load_kaldi_nnet()
            return
        self.lorder = lorder
        self.hidden_size = hidden_size
        self.linear = nn.Linear(input_dim, hidden_size)
        self.norm = nn.Identity()
        if layer_norm:
            # NOTE: sized to input_dim; compute1 applies it to a
            # hidden_size-wide tensor, so this assumes hidden_size == input_dim
            self.norm = nn.LayerNorm(input_dim)
        self.act = nn.ReLU()
        self.project = nn.Linear(hidden_size, output_dim, bias=False)
        self.conv1 = nn.Conv2d(
            output_dim,
            output_dim, [lorder, 1],
            dilation=[dilation, 1],
            groups=output_dim,
            bias=False)
        self.padding_left = dilation * (lorder - 1)

    def forward(self, input):
        return self.compute1(input)

    def compute1(self, input):
        ''' linear-relu(norm)-linear-dconv
        '''
        # [B, T, F]
        x = self.linear(input)
        x = self.act(x)
        x = self.norm(x)
        x = self.project(x)
        x = th.unsqueeze(x, 1).permute(0, 3, 2, 1)
        # [B, F, T+lorder-1, 1]
        y = F.pad(x, [0, 0, self.padding_left, 0])
        out = self.conv1(y)
        if self.skip_connect:
            out = out + x
        out = out.permute(0, 3, 2, 1).squeeze()
        return input + out

    def compute2(self, input):
        ''' linear-dconv-linear-relu(norm)
        '''
        x = self.project(input)
        x = th.unsqueeze(x, 1).permute(0, 3, 2, 1)
        y = F.pad(x, [0, 0, self.padding_left, 0])
        out = self.conv1(y)
        if self.skip_connect:
            out = out + x
        out = out.permute(0, 3, 2, 1).squeeze()
        x = self.linear(out)
        x = self.act(x)
        x = self.norm(x)
        return input + x

    def compute3(self, input):
        ''' dconv-linear-relu(norm)-linear
        '''
        x = th.unsqueeze(input, 1).permute(0, 3, 2, 1)
        y = F.pad(x, [0, 0, self.padding_left, 0])
        out = self.conv1(y)
        if self.skip_connect:
            out = out + x
        out = out.permute(0, 3, 2, 1).squeeze()
        x = self.linear(out)
        x = self.act(x)
        x = self.norm(x)
        x = self.project(x)
        return input + x

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<UniDeepFsmn> %d %d\n' \
                  % (self.output_dim, self.input_dim)
        re_str += '<LearnRateCoef> %d <HidSize> %d <LOrder> %d <LStride> %d <MaxNorm> 0\n' \
                  % (1, self.hidden_size, self.lorder, 1)
        lfiters = self.state_dict()['conv1.weight']
        x = np.flipud(lfiters.squeeze().numpy().T)
        re_str += to_kaldi_matrix(x)
        proj_weights = self.state_dict()['project.weight']
        x = proj_weights.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        linear_bias = self.state_dict()['linear.bias']
        x = linear_bias.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        return re_str

    def to_raw_nnet(self, fid):
        # dump conv, projection and linear parameters as raw float buffers
        lfiters = self.state_dict()['conv1.weight']
        x = np.flipud(lfiters.squeeze().numpy().T)
        x.tofile(fid)
        proj_weights = self.state_dict()['project.weight']
        x = proj_weights.squeeze().numpy()
        x.tofile(fid)
        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        x.tofile(fid)
        linear_bias = self.state_dict()['linear.bias']
        x = linear_bias.squeeze().numpy()
        x.tofile(fid)

    def load_kaldi_nnet(self, instr):
        output = expect_token_number(instr, '<LearnRateCoef>')
        if output is None:
            raise Exception('UniDeepFsmn format error for <LearnRateCoef>')
        instr, lr = output

        output = expect_token_number(instr, '<HidSize>')
        if output is None:
            raise Exception('UniDeepFsmn format error for <HidSize>')
        instr, hiddensize = output
        self.hidden_size = int(hiddensize)

        output = expect_token_number(instr, '<LOrder>')
        if output is None:
            raise Exception('UniDeepFsmn format error for <LOrder>')
        instr, lorder = output
        self.lorder = int(lorder)

        output = expect_token_number(instr, '<LStride>')
        if output is None:
            raise Exception('UniDeepFsmn format error for <LStride>')
        instr, lstride = output
        self.lstride = int(lstride)

        output = expect_token_number(instr, '<MaxNorm>')
        if output is None:
            raise Exception('UniDeepFsmn format error for <MaxNorm>')
        instr, _ = output  # the <MaxNorm> value itself is unused

        output = expect_kaldi_matrix(instr)
        if output is None:
            raise Exception('UniDeepFsmn format error for parsing matrix')
        instr, mat = output
        # Kaldi stores the filter time-reversed as [lorder, F]; flip it back
        mat1 = np.fliplr(mat.T).copy()
        self.conv1 = nn.Conv2d(
            self.output_dim,
            self.output_dim, [self.lorder, 1], [1, 1],
            groups=self.output_dim,
            bias=False)
        mat_th = th.from_numpy(mat1).type(th.FloatTensor)
        mat_th = mat_th.unsqueeze(1)
        mat_th = mat_th.unsqueeze(3)
        self.conv1.weight = th.nn.Parameter(mat_th)

        output = expect_kaldi_matrix(instr)
        if output is None:
            raise Exception('UniDeepFsmn format error for parsing matrix')
        instr, mat = output
        self.project = nn.Linear(self.hidden_size, self.output_dim, bias=False)
        self.linear = nn.Linear(self.input_dim, self.hidden_size)
        self.project.weight = th.nn.Parameter(
            th.from_numpy(mat).type(th.FloatTensor))

        output = expect_kaldi_matrix(instr)
        if output is None:
            raise Exception('UniDeepFsmn format error for parsing matrix')
        instr, mat = output
        self.linear.weight = th.nn.Parameter(
            th.from_numpy(mat).type(th.FloatTensor))

        output = expect_kaldi_matrix(instr)
        if output is None:
            raise Exception('UniDeepFsmn format error for parsing matrix')
        instr, mat = output
        # the bias is stored as a 1xN matrix; squeeze it to a vector
        mat = np.squeeze(mat)
        self.linear.bias = th.nn.Parameter(
            th.from_numpy(mat).type(th.FloatTensor))
        return instr
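A shape sketch for UniDeepFsmn.compute1 above (all dimensions are illustrative; as with DeepFsmn, the final squeeze() would also remove a batch dimension of one):

import torch as th

layer = UniDeepFsmn(120, 120, lorder=20, hidden_size=256)
x = th.randn(8, 50, 120)  # [B, T, F]
y = layer(x)              # [8, 50, 120]: input + causal FSMN memory term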
@@ -0,0 +1,394 @@
import torch
import torch.nn.functional as F

from .modulation_loss import (GaborSTRFConv, MelScale,
                              ModulationDomainLossModule)

EPS = 1e-8


def compute_mask(mixed_spec, clean_spec, mask_type='psmiam', clip=1):
    '''Compute a training-target mask from mixed and clean spectra.

    Spectra are real pairs (batch, ..., 2) or complex (batch, ...),
    under the additive model y = x + n.
    '''
    if torch.is_complex(mixed_spec):
        yr, yi = mixed_spec.real, mixed_spec.imag
    else:
        yr, yi = mixed_spec[..., 0], mixed_spec[..., 1]
    if torch.is_complex(clean_spec):
        xr, xi = clean_spec.real, clean_spec.imag
    else:
        xr, xi = clean_spec[..., 0], clean_spec[..., 1]
    if mask_type == 'iam':
        # ideal amplitude mask
        ymag = torch.sqrt(yr**2 + yi**2)
        xmag = torch.sqrt(xr**2 + xi**2)
        iam = xmag / (ymag + EPS)
        return torch.clamp(iam, 0, 1)
    elif mask_type == 'psm':
        # phase-sensitive mask
        ypow = yr**2 + yi**2
        psm = (xr * yr + xi * yi) / (ypow + EPS)
        return torch.clamp(psm, 0, 1)
    elif mask_type == 'psmiam':
        ypow = yr**2 + yi**2
        psm = (xr * yr + xi * yi) / (ypow + EPS)
        ymag = torch.sqrt(yr**2 + yi**2)
        xmag = torch.sqrt(xr**2 + xi**2)
        iam = xmag / (ymag + EPS)
        psmiam = psm * iam
        return torch.clamp(psmiam, 0, 1)
    elif mask_type == 'crm':
        # complex ratio mask, returned as (real, imag)
        ypow = yr**2 + yi**2
        mr = (xr * yr + xi * yi) / (ypow + EPS)
        mi = (xi * yr - xr * yi) / (ypow + EPS)
        mr = torch.clamp(mr, -clip, clip)
        mi = torch.clamp(mi, -clip, clip)
        return mr, mi


def energy_vad(spec,
               thdhigh=320 * 600 * 600 * 2,
               thdlow=320 * 300 * 300 * 2,
               int16=True):
    '''Energy-based VAD, which should be accurate enough here.

    spec: (batch, bins, frames, 2)
    returns: (batch, frames), values in {0, 0.5, 1}
    '''
    energy = torch.sum(spec[..., 0]**2 + spec[..., 1]**2, dim=1)
    # Use a float tensor: assigning 0.5 into a boolean mask would
    # silently round it to True.
    vad = (energy > thdhigh).float()
    idx = torch.logical_and(vad == 0, energy > thdlow)
    vad[idx] = 0.5
    return vad


def modulation_loss_init(n_fft):
    # NOTE: the checkpoint path is hardcoded, CUDA is assumed, and the
    # returned modules are expected to be bound to the module-level names
    # used by modulation_loss() below.
    gabor_strf_parameters = torch.load(
        './network/gabor_strf_parameters.pt')['state_dict']
    gabor_modulation_kernels = GaborSTRFConv(supn=30, supk=30, nkern=60)
    gabor_modulation_kernels.load_state_dict(gabor_strf_parameters)
    modulation_loss_module = ModulationDomainLossModule(
        gabor_modulation_kernels.eval())
    for param in modulation_loss_module.parameters():
        param.requires_grad = False
    stft2mel = MelScale(
        n_mels=80, sample_rate=16000, n_stft=n_fft // 2 + 1).cuda()
    return modulation_loss_module, stft2mel


def mask_loss_function(
        loss_func='psm_loss',
        loss_type='mse',  # ['mse', 'mae', 'comb']
        mask_type='psmiam',
        use_mod_loss=False,
        use_wav2vec_loss=False,
        n_fft=640,
        hop_length=320,
        EPS=1e-8,
        weight=None):
    if weight is not None:
        print(f'Use loss weight: {weight}')
    winlen = n_fft
    window = torch.hamming_window(winlen, periodic=False)

    def stft(x, return_complex=False):
        # returns [batch, bins, frames, 2] (or complex [batch, bins, frames])
        return torch.stft(
            x,
            n_fft,
            hop_length,
            winlen,
            window=window.to(x.device),
            center=False,
            return_complex=return_complex)

    def istft(x, slen):
        return torch.istft(
            x,
            n_fft,
            hop_length,
            winlen,
            window=window.to(x.device),
            center=False,
            length=slen)

    def mask_loss(targets, masks, nframes):
        ''' [Batch, Time, Frequency]
        '''
        with torch.no_grad():
            mask_for_loss = torch.ones_like(targets)
            for idx, num in enumerate(nframes):
                mask_for_loss[idx, num:, :] = 0
        masks = masks * mask_for_loss
        targets = targets * mask_for_loss
        if weight is None:
            alpha = 1
        else:  # for aec ST
            alpha = weight - targets
        if loss_type == 'mse':
            loss = 0.5 * torch.sum(alpha * torch.pow(targets - masks, 2))
        elif loss_type == 'mae':
            loss = torch.sum(alpha * torch.abs(targets - masks))
        else:  # mse(mask) and mae(mask) are roughly 1:2 in scale
            loss = 0.5 * torch.sum(alpha * torch.pow(targets - masks, 2)
                                   + 0.1 * alpha * torch.abs(targets - masks))
        loss /= torch.sum(nframes)
        return loss

    def spectrum_loss(targets, spec, nframes):
        ''' [Batch, Time, Frequency, 2]
        '''
        with torch.no_grad():
            mask_for_loss = torch.ones_like(targets[..., 0])
            for idx, num in enumerate(nframes):
                mask_for_loss[idx, num:, :] = 0
        xr = spec[..., 0] * mask_for_loss
        xi = spec[..., 1] * mask_for_loss
        yr = targets[..., 0] * mask_for_loss
        yi = targets[..., 1] * mask_for_loss
        xmag = torch.sqrt(spec[..., 0]**2 + spec[..., 1]**2) * mask_for_loss
        ymag = torch.sqrt(targets[..., 0]**2
                          + targets[..., 1]**2) * mask_for_loss
        loss1 = torch.sum(torch.pow(xr - yr, 2) + torch.pow(xi - yi, 2))
        loss2 = torch.sum(torch.pow(xmag - ymag, 2))
        loss = (loss1 + loss2) / torch.sum(nframes)
        return loss

    def sa_loss_dlen(mixed, clean, masks, nframes):
        yspec = stft(mixed).permute([0, 2, 1, 3]) / 32768
        xspec = stft(clean).permute([0, 2, 1, 3]) / 32768
        with torch.no_grad():
            mask_for_loss = torch.ones_like(xspec[..., 0])
            for idx, num in enumerate(nframes):
                mask_for_loss[idx, num:, :] = 0
        # compressed-magnitude spectral approximation
        emag = ((yspec[..., 0]**2 + yspec[..., 1]**2)**0.15) * (masks**0.3)
        xmag = (xspec[..., 0]**2 + xspec[..., 1]**2)**0.15
        emag = emag * mask_for_loss
        xmag = xmag * mask_for_loss
        loss = torch.sum(torch.pow(emag - xmag, 2)) / torch.sum(nframes)
        return loss

    def psm_vad_loss_dlen(mixed, clean, masks, nframes, subtask=None):
        mixed_spec = stft(mixed)
        clean_spec = stft(clean)
        targets = compute_mask(mixed_spec, clean_spec, mask_type)
        # [B, T, F]
        targets = targets.permute(0, 2, 1)
        loss = mask_loss(targets, masks, nframes)
        if subtask is not None:
            vadtargets = energy_vad(clean_spec)
            with torch.no_grad():
                mask_for_loss = torch.ones_like(targets[:, :, 0])
                for idx, num in enumerate(nframes):
                    mask_for_loss[idx, num:] = 0
            subtask = subtask[:, :, 0] * mask_for_loss
            vadtargets = vadtargets * mask_for_loss
            loss_vad = F.binary_cross_entropy(subtask, vadtargets)
            return loss + loss_vad
        return loss

    def modulation_loss(mixed, clean, masks, nframes, subtask=None):
        # NOTE: relies on modulation_loss_module and stft2mel, e.g. as
        # produced by modulation_loss_init(); they are not defined in
        # this file.
        mixed_spec = stft(mixed, True)
        clean_spec = stft(clean, True)
        enhanced_mag = torch.abs(mixed_spec)
        clean_mag = torch.abs(clean_spec)
        with torch.no_grad():
            mask_for_loss = torch.ones_like(clean_mag)
            for idx, num in enumerate(nframes):
                mask_for_loss[idx, :, num:] = 0
        clean_mag = clean_mag * mask_for_loss
        enhanced_mag = enhanced_mag * mask_for_loss * masks.permute([0, 2, 1])
        # Convert to log-mel representation
        # (B, T, #mel_channels)
        clean_log_mel = torch.log(
            torch.transpose(stft2mel(clean_mag**2), 2, 1) + 1e-8)
        enhanced_log_mel = torch.log(
            torch.transpose(stft2mel(enhanced_mag**2), 2, 1) + 1e-8)
        alpha = compute_mask(mixed_spec, clean_spec, mask_type)
        alpha = alpha.permute(0, 2, 1)
        loss = 0.05 * modulation_loss_module(enhanced_log_mel, clean_log_mel,
                                             alpha)
        loss2 = psm_vad_loss_dlen(mixed, clean, masks, nframes, subtask)
        # the two terms are roughly 1:4 in scale
        loss = loss + loss2
        return loss

    def wav2vec_loss(mixed, clean, masks, nframes, subtask=None):
        # NOTE: relies on a wav2vec_loss_module defined elsewhere; it is
        # not part of this diff. The scaling below is in-place.
        mixed /= 32768
        clean /= 32768
        mixed_spec = stft(mixed)
        with torch.no_grad():
            mask_for_loss = torch.ones_like(masks)
            for idx, num in enumerate(nframes):
                mask_for_loss[idx, num:, :] = 0
        masks_est = masks * mask_for_loss
        estimate = mixed_spec * masks_est.permute([0, 2, 1]).unsqueeze(3)
        est_clean = istft(estimate, clean.shape[1])
        loss = wav2vec_loss_module(est_clean, clean)
        return loss

    def sisdr_loss_dlen(mixed,
                        clean,
                        masks,
                        nframes,
                        subtask=None,
                        zero_mean=True):
        mixed_spec = stft(mixed)
        with torch.no_grad():
            mask_for_loss = torch.ones_like(masks)
            for idx, num in enumerate(nframes):
                mask_for_loss[idx, num:, :] = 0
        masks_est = masks * mask_for_loss
        estimate = mixed_spec * masks_est.permute([0, 2, 1]).unsqueeze(3)
        est_clean = istft(estimate, clean.shape[1])
        flen = min(clean.shape[1], est_clean.shape[1])
        clean = clean[:, :flen]
        est_clean = est_clean[:, :flen]
        # follow asteroid/losses/sdr.py
        if zero_mean:
            clean = clean - torch.mean(clean, dim=1, keepdim=True)
            est_clean = est_clean - torch.mean(est_clean, dim=1, keepdim=True)
        dot = torch.sum(est_clean * clean, dim=1, keepdim=True)
        s_clean_energy = torch.sum(clean**2, dim=1, keepdim=True) + EPS
        scaled_clean = dot * clean / s_clean_energy
        e_noise = est_clean - scaled_clean
        # [batch]
        sisdr = torch.sum(
            scaled_clean**2, dim=1) / (
                torch.sum(e_noise**2, dim=1) + EPS)
        sisdr = -10 * torch.log10(sisdr + EPS)
        loss = sisdr.mean()
        return loss

    def sisdr_freq_loss_dlen(mixed, clean, masks, nframes, subtask=None):
        mixed_spec = stft(mixed)
        clean_spec = stft(clean)
        with torch.no_grad():
            mask_for_loss = torch.ones_like(masks)
            for idx, num in enumerate(nframes):
                mask_for_loss[idx, num:, :] = 0
        masks_est = masks * mask_for_loss
        estimate = mixed_spec * masks_est.permute([0, 2, 1]).unsqueeze(3)
        dot_real = estimate[..., 0] * clean_spec[..., 0] + \
            estimate[..., 1] * clean_spec[..., 1]
        dot_imag = estimate[..., 0] * clean_spec[..., 1] - \
            estimate[..., 1] * clean_spec[..., 0]
        dot = torch.cat([dot_real.unsqueeze(3), dot_imag.unsqueeze(3)], dim=-1)
        s_clean_energy = clean_spec[..., 0] ** 2 + \
            clean_spec[..., 1] ** 2 + EPS
        scaled_clean = dot * clean_spec / s_clean_energy.unsqueeze(3)
        e_noise = estimate - scaled_clean
        # [batch]
        scaled_clean_energy = torch.sum(
            scaled_clean[..., 0]**2 + scaled_clean[..., 1]**2, dim=1)
        e_noise_energy = torch.sum(
            e_noise[..., 0]**2 + e_noise[..., 1]**2, dim=1)
        sisdr = torch.sum(
            scaled_clean_energy, dim=1) / (
                torch.sum(e_noise_energy, dim=1) + EPS)
        sisdr = -10 * torch.log10(sisdr + EPS)
        loss = sisdr.mean()
        return loss

    def crm_loss_dlen(mixed, clean, masks, nframes, subtask=None):
        mixed_spec = stft(mixed).permute([0, 2, 1, 3])
        clean_spec = stft(clean).permute([0, 2, 1, 3])
        mixed_spec = mixed_spec / 32768
        clean_spec = clean_spec / 32768
        tgt_mr, tgt_mi = compute_mask(mixed_spec, clean_spec, mask_type='crm')
        # masks holds the real and imaginary halves side by side
        D = int(masks.shape[2] / 2)
        with torch.no_grad():
            mask_for_loss = torch.ones_like(clean_spec[..., 0])
            for idx, num in enumerate(nframes):
                mask_for_loss[idx, num:, :] = 0
        mr = masks[..., :D] * mask_for_loss
        mi = masks[..., D:] * mask_for_loss
        tgt_mr = tgt_mr * mask_for_loss
        tgt_mi = tgt_mi * mask_for_loss
        if weight is None:
            alpha = 1
        else:
            alpha = weight - tgt_mr
        # signal approximation
        yr = mixed_spec[..., 0]
        yi = mixed_spec[..., 1]
        loss1 = torch.sum(alpha * torch.pow((mr * yr - mi * yi) - clean_spec[..., 0], 2)) \
            + torch.sum(alpha * torch.pow((mr * yi + mi * yr) - clean_spec[..., 1], 2))
        # mask approximation
        loss2 = torch.sum(alpha * torch.pow(mr - tgt_mr, 2)) \
            + torch.sum(alpha * torch.pow(mi - tgt_mi, 2))
        loss = 0.5 * (loss1 + loss2) / torch.sum(nframes)
        return loss

    def crm_miso_loss_dlen(mixed, clean, masks, nframes):
        return crm_loss_dlen(mixed[..., 0], clean[..., 0], masks, nframes)

    def mimo_loss_dlen(mixed, clean, masks, nframes):
        chs = mixed.shape[-1]
        D = masks.shape[2] // chs
        loss = psm_vad_loss_dlen(mixed[..., 0], clean[..., 0], masks[..., :D],
                                 nframes)
        for ch in range(1, chs):
            loss1 = psm_vad_loss_dlen(mixed[..., ch], clean[..., ch],
                                      masks[..., ch * D:ch * D + D], nframes)
            loss = loss + loss1
        return loss / chs

    def spec_loss_dlen(mixed, clean, spec, nframes):
        clean_spec = stft(clean).permute([0, 2, 1, 3])
        clean_spec = clean_spec / 32768
        D = spec.shape[2] // 2
        spec_est = torch.cat([spec[..., :D, None], spec[..., D:, None]],
                             dim=-1)
        loss = spectrum_loss(clean_spec, spec_est, nframes)
        return loss

    if loss_func == 'psm_vad_loss_dlen':
        return psm_vad_loss_dlen
    elif loss_func == 'sisdr_loss_dlen':
        return sisdr_loss_dlen
    elif loss_func == 'sisdr_freq_loss_dlen':
        return sisdr_freq_loss_dlen
    elif loss_func == 'crm_loss_dlen':
        return crm_loss_dlen
    elif loss_func == 'modulation_loss':
        return modulation_loss
    elif loss_func == 'wav2vec_loss':
        return wav2vec_loss
    elif loss_func == 'mimo_loss_dlen':
        return mimo_loss_dlen
    elif loss_func == 'spec_loss_dlen':
        return spec_loss_dlen
    elif loss_func == 'sa_loss_dlen':
        return sa_loss_dlen
    else:
        # fail loudly instead of silently returning None
        raise ValueError('unknown loss_func: %s' % loss_func)
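A hedged training-side sketch of the factory above (all tensors are random stand-ins; the 49x321 time-frequency shape follows the defaults n_fft=640, hop_length=320, center=False applied to 16000 samples):

import torch

loss_fn = mask_loss_function(
    loss_func='psm_vad_loss_dlen', loss_type='mse', mask_type='psmiam')
B, samples = 2, 16000
mixed = torch.randn(B, samples) * 1000  # int16-scaled waveforms
clean = 0.5 * mixed
masks = torch.rand(B, 49, 321)          # network output, [B, T, F]
nframes = torch.tensor([49, 49])        # valid frames per utterance
loss = loss_fn(mixed, clean, masks, nframes)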
@@ -0,0 +1,248 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import MelScale


class ModulationDomainLossModule(torch.nn.Module):
    """Modulation-domain loss function developed in [1] for supervised speech enhancement.

    In our paper, we used the gabor-based STRF kernels as the modulation kernels
    and the log-mel spectrogram as the input spectrogram representation.
    Specific parameter details are in the paper and in the example below.

    Parameters
    ----------
    modulation_kernels: nn.Module
        Differentiable module that transforms a spectrogram representation
        to the modulation domain:
            modulation_domain = modulation_kernels(input_tf_representation)
        Input spectrogram representation (B, T, F) ---(M) modulation_kernels---> Modulation domain (B, M, T', F')
    norm: boolean
        Normalizes the modulation-domain representation to be zero mean across time.

    [1] T. Vuong, Y. Xia, and R. M. Stern, "A modulation-domain loss for
        neural-network-based real-time speech enhancement", accepted at
        ICASSP 2021, https://arxiv.org/abs/2102.07330
    """

    def __init__(self, modulation_kernels, norm=True):
        super(ModulationDomainLossModule, self).__init__()
        self.modulation_kernels = modulation_kernels
        # per-element losses; reduction happens below
        # (reduce=False is the deprecated spelling of reduction='none')
        self.mse = nn.MSELoss(reduction='none')
        self.norm = norm

    def forward(self, enhanced_spect, clean_spect, weight=None):
        """Calculate modulation-domain loss.

        Args:
            enhanced_spect (Tensor): spectrogram representation of the enhanced
                signal (B, #frames, #freq_channels).
            clean_spect (Tensor): spectrogram representation of the clean
                ground-truth signal (B, #frames, #freq_channels).

        Returns:
            Tensor: Modulation-domain loss value.
        """
        clean_mod = self.modulation_kernels(clean_spect)
        enhanced_mod = self.modulation_kernels(enhanced_spect)
        if self.norm:
            mean_clean_mod = torch.mean(clean_mod, dim=2)
            mean_enhanced_mod = torch.mean(enhanced_mod, dim=2)
            clean_mod = clean_mod - mean_clean_mod.unsqueeze(2)
            enhanced_mod = enhanced_mod - mean_enhanced_mod.unsqueeze(2)
        if weight is None:
            alpha = 1
        else:  # TF-mask weight
            alpha = 1 + torch.sum(weight, dim=-1, keepdim=True).unsqueeze(1)
        mod_mse_loss = self.mse(enhanced_mod, clean_mod) * alpha
        mod_mse_loss = torch.mean(
            torch.sum(mod_mse_loss, dim=(1, 2, 3))
            / torch.sum(clean_mod**2, dim=(1, 2, 3)))
        return mod_mse_loss


class ModulationDomainNCCLossModule(torch.nn.Module):
    """Modulation-domain loss based on normalized cross-correlation (NCC).

    Based on "Speech Intelligibility Prediction Using Spectro-Temporal
    Modulation Analysis". As above, the gabor-based STRF kernels serve as
    the modulation kernels and the log-mel spectrogram as the input
    spectrogram representation.

    Parameters
    ----------
    modulation_kernels: nn.Module
        Differentiable module that transforms a spectrogram representation
        to the modulation domain:
            modulation_domain = modulation_kernels(input_tf_representation)
        Input spectrogram representation (B, T, F) ---(M) modulation_kernels---> Modulation domain (B, M, T', F')
    """

    def __init__(self, modulation_kernels):
        super(ModulationDomainNCCLossModule, self).__init__()
        self.modulation_kernels = modulation_kernels
        self.mse = nn.MSELoss(reduction='none')

    def forward(self, enhanced_spect, clean_spect):
        """Calculate modulation-domain NCC loss.

        Args:
            enhanced_spect (Tensor): spectrogram representation of the enhanced
                signal (B, #frames, #freq_channels).
            clean_spect (Tensor): spectrogram representation of the clean
                ground-truth signal (B, #frames, #freq_channels).

        Returns:
            Tensor: Modulation-domain loss value.
        """
        clean_mod = self.modulation_kernels(clean_spect)
        enhanced_mod = self.modulation_kernels(enhanced_spect)
        mean_clean_mod = torch.mean(clean_mod, dim=2)
        mean_enhanced_mod = torch.mean(enhanced_mod, dim=2)
        normalized_clean = clean_mod - mean_clean_mod.unsqueeze(2)
        normalized_enhanced = enhanced_mod - mean_enhanced_mod.unsqueeze(2)
        inner_product = torch.sum(
            normalized_clean * normalized_enhanced, dim=2)
        normalized_denom = (torch.sum(
            normalized_clean * normalized_clean, dim=2))**.5 * (torch.sum(
                normalized_enhanced * normalized_enhanced, dim=2))**.5
        ncc = inner_product / normalized_denom
        mod_mse_loss = torch.mean((ncc - 1.0)**2)
        return mod_mse_loss


class GaborSTRFConv(nn.Module):
    """Gabor-STRF-based cross-correlation kernel."""

    def __init__(self,
                 supn,
                 supk,
                 nkern,
                 rates=None,
                 scales=None,
                 norm_strf=True,
                 real_only=False):
        """Instantiate a Gabor-based STRF convolution layer.

        Parameters
        ----------
        supn: int
            Time support in number of frames. Also the window length.
        supk: int
            Frequency support in number of channels. Also the window length.
        nkern: int
            Number of kernels, each with a learnable rate and scale.
        rates: list of float, None
            Initial values for temporal modulation.
        scales: list of float, None
            Initial values for spectral modulation.
        norm_strf: boolean
            Normalize STRF kernels to unit length.
        real_only: boolean
            If True, nkern REAL gabor-STRF kernels;
            if False, nkern//2 REAL and nkern//2 IMAGINARY gabor-STRF kernels.
        """
        super(GaborSTRFConv, self).__init__()
        self.numN = supn
        self.numK = supk
        self.numKern = nkern
        self.real_only = real_only
        self.norm_strf = norm_strf
        if not real_only:
            # half of the kernels are real parts, half imaginary parts
            nkern = nkern // 2
        if supk % 2 == 0:  # force odd number
            supk += 1
        self.supk = torch.arange(supk, dtype=torch.float32)
        if supn % 2 == 0:  # force odd number
            supn += 1
        self.supn = torch.arange(supn, dtype=self.supk.dtype)
        self.padding = (supn // 2, supk // 2)
        # Set up learnable parameters
        # for param in (rates, scales):
        #     assert (not param) or len(param) == nkern
        if not rates:
            rates = torch.rand(nkern) * math.pi / 2.0
        if not scales:
            scales = (torch.rand(nkern) * 2.0 - 1.0) * math.pi / 2.0
        self.rates_ = nn.Parameter(torch.Tensor(rates))
        self.scales_ = nn.Parameter(torch.Tensor(scales))

    def strfs(self):
        """Make STRFs using the current parameters."""
        if self.supn.device != self.rates_.device:  # for first run
            self.supn = self.supn.to(self.rates_.device)
            self.supk = self.supk.to(self.rates_.device)
        n0, k0 = self.padding
        # raised-cosine windows over the time and frequency supports
        nwind = .5 - .5 * \
            torch.cos(2 * math.pi * (self.supn + 1) / (len(self.supn) + 1))
        kwind = .5 - .5 * \
            torch.cos(2 * math.pi * (self.supk + 1) / (len(self.supk) + 1))
        new_wind = torch.matmul((nwind).unsqueeze(-1), (kwind).unsqueeze(0))
        n_n_0 = self.supn - n0
        k_k_0 = self.supk - k0
        n_mult = torch.matmul(
            n_n_0.unsqueeze(1),
            torch.ones((1, len(self.supk))).type(torch.FloatTensor).to(
                self.rates_.device))
        k_mult = torch.matmul(
            torch.ones((len(self.supn),
                        1)).type(torch.FloatTensor).to(self.rates_.device),
            k_k_0.unsqueeze(0))
        inside = self.rates_.unsqueeze(1).unsqueeze(
            1) * n_mult + self.scales_.unsqueeze(1).unsqueeze(1) * k_mult
        real_strf = torch.cos(inside) * new_wind.unsqueeze(0)
        if self.real_only:
            final_strf = real_strf
        else:
            imag_strf = torch.sin(inside) * new_wind.unsqueeze(0)
            final_strf = torch.cat([real_strf, imag_strf], dim=0)
        if self.norm_strf:
            final_strf = final_strf / (torch.sum(
                final_strf**2, dim=(1, 2)).unsqueeze(1).unsqueeze(2))**.5
        return final_strf

    def forward(self, sigspec):
        """Forward pass a batch of (real) spectra [Batch x Time x Frequency]."""
        if len(sigspec.shape) == 2:  # expand batch dimension if single example
            sigspec = sigspec.unsqueeze(0)
        strfs = self.strfs().unsqueeze(1).type_as(sigspec)
        out = F.conv2d(sigspec.unsqueeze(1), strfs, padding=self.padding)
        return out

    def __repr__(self):
        """Gabor filter."""
        report = """
+++++ Gabor Filter Kernels [{}], supn[{}], supk[{}] real only [{}] norm strf [{}] +++++
""".format(self.numKern, self.numN, self.numK, self.real_only,
           self.norm_strf)
        return report
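A small sketch wiring the pieces above together (random log-mel stand-ins; 80 mel channels as used elsewhere in this PR):

import torch

kernels = GaborSTRFConv(supn=30, supk=30, nkern=60)
mod_loss = ModulationDomainLossModule(kernels.eval())
clean = torch.randn(2, 100, 80)    # (B, #frames, #mel_channels)
enhanced = torch.randn(2, 100, 80)
loss = mod_loss(enhanced, clean)   # scalar, energy-normalized modulation MSE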
@@ -0,0 +1,483 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..layers.activations import RectifiedLinear, Sigmoid
from ..layers.affine_transform import AffineTransform
from ..layers.deep_fsmn import DeepFsmn
from ..layers.uni_deep_fsmn import Conv2d, UniDeepFsmn


class MaskNet(nn.Module):
    """Stack of FSMN blocks that predicts a time-frequency mask."""

    def __init__(self,
                 indim,
                 outdim,
                 layers=9,
                 hidden_dim=128,
                 hidden_dim2=None,
                 lorder=20,
                 rorder=0,
                 dilation=1,
                 layer_norm=False,
                 dropout=0,
                 crm=False,
                 vad=False,
                 linearout=False):
        super(MaskNet, self).__init__()
        self.linear1 = AffineTransform(indim, hidden_dim)
        self.relu = RectifiedLinear(hidden_dim, hidden_dim)
        if hidden_dim2 is None:
            hidden_dim2 = hidden_dim
        if rorder == 0:
            # causal: unidirectional FSMN blocks
            repeats = [
                UniDeepFsmn(
                    hidden_dim,
                    hidden_dim,
                    lorder,
                    hidden_dim2,
                    dilation=dilation,
                    layer_norm=layer_norm,
                    dropout=dropout) for i in range(layers)
            ]
        else:
            repeats = [
                DeepFsmn(
                    hidden_dim,
                    hidden_dim,
                    lorder,
                    rorder,
                    hidden_dim2,
                    layer_norm=layer_norm,
                    dropout=dropout) for i in range(layers)
            ]
        self.deepfsmn = nn.Sequential(*repeats)
        self.linear2 = AffineTransform(hidden_dim, outdim)
        self.crm = crm
        if self.crm:
            # complex ratio masks are bounded in [-1, 1]
            self.sig = nn.Tanh()
        else:
            self.sig = Sigmoid(outdim, outdim)
        self.vad = vad
        if self.vad:
            self.linear3 = AffineTransform(hidden_dim, 1)
        self.layers = layers
        self.linearout = linearout
        if self.linearout and self.vad:
            print('Warning: linearout together with vad is not supported')

    def forward(self, feat, ctl=None):
        x1 = self.linear1(feat)
        x2 = self.relu(x1)
        if ctl is not None:
            # run only the first ctl FSMN blocks
            ctl = min(ctl, self.layers - 1)
            for i in range(ctl):
                x2 = self.deepfsmn[i](x2)
            mask = self.sig(self.linear2(x2))
            if self.vad:
                vad = torch.sigmoid(self.linear3(x2))
                return mask, vad
            else:
                return mask
        x3 = self.deepfsmn(x2)
        if self.linearout:
            return self.linear2(x3)
        mask = self.sig(self.linear2(x3))
        if self.vad:
            vad = torch.sigmoid(self.linear3(x3))
            return mask, vad
        else:
            return mask

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<Nnet>\n'
        re_str += self.linear1.to_kaldi_nnet()
        re_str += self.relu.to_kaldi_nnet()
        for dfsmn in self.deepfsmn:
            re_str += dfsmn.to_kaldi_nnet()
        re_str += self.linear2.to_kaldi_nnet()
        re_str += self.sig.to_kaldi_nnet()
        re_str += '</Nnet>\n'
        return re_str

    def to_raw_nnet(self, fid):
        self.linear1.to_raw_nnet(fid)
        for dfsmn in self.deepfsmn:
            dfsmn.to_raw_nnet(fid)
        self.linear2.to_raw_nnet(fid)
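def _masknet_usage_example():
    # Hypothetical helper, not part of the PR: a shape sketch of MaskNet
    # with illustrative dimensions (321 STFT bins, 9 causal FSMN blocks).
    net = MaskNet(indim=321, outdim=321, layers=9, hidden_dim=128, lorder=20)
    feat = torch.randn(2, 100, 321)  # [B, T, F]
    mask = net(feat)                 # [2, 100, 321], sigmoid-bounded
    return mask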
class StageNet(nn.Module): | |||||
def __init__(self, | |||||
indim, | |||||
outdim, | |||||
layers=9, | |||||
layers2=6, | |||||
hidden_dim=128, | |||||
lorder=20, | |||||
rorder=0, | |||||
layer_norm=False, | |||||
dropout=0, | |||||
crm=False, | |||||
vad=False, | |||||
linearout=False): | |||||
super(StageNet, self).__init__() | |||||
self.stage1 = nn.ModuleList() | |||||
self.stage2 = nn.ModuleList() | |||||
layer = nn.Sequential(nn.Linear(indim, hidden_dim), nn.ReLU()) | |||||
self.stage1.append(layer) | |||||
for i in range(layers): | |||||
layer = UniDeepFsmn( | |||||
hidden_dim, | |||||
hidden_dim, | |||||
lorder, | |||||
hidden_dim, | |||||
layer_norm=layer_norm, | |||||
dropout=dropout) | |||||
self.stage1.append(layer) | |||||
layer = nn.Sequential(nn.Linear(hidden_dim, 321), nn.Sigmoid()) | |||||
self.stage1.append(layer) | |||||
# stage2 | |||||
layer = nn.Sequential(nn.Linear(321 + indim, hidden_dim), nn.ReLU()) | |||||
self.stage2.append(layer) | |||||
for i in range(layers2): | |||||
layer = UniDeepFsmn( | |||||
hidden_dim, | |||||
hidden_dim, | |||||
lorder, | |||||
hidden_dim, | |||||
layer_norm=layer_norm, | |||||
dropout=dropout) | |||||
self.stage2.append(layer) | |||||
layer = nn.Sequential( | |||||
nn.Linear(hidden_dim, outdim), | |||||
nn.Sigmoid() if not crm else nn.Tanh()) | |||||
self.stage2.append(layer) | |||||
self.crm = crm | |||||
self.vad = vad | |||||
self.linearout = linearout | |||||
self.window = torch.hamming_window(640, periodic=False).cuda() | |||||
self.freezed = False | |||||
def freeze(self): | |||||
if not self.freezed: | |||||
for param in self.stage1.parameters(): | |||||
param.requires_grad = False | |||||
self.freezed = True | |||||
print('freezed stage1') | |||||
def forward(self, feat, mixture, ctl=None): | |||||
if ctl == 'off': | |||||
x = feat | |||||
for i in range(len(self.stage1)): | |||||
x = self.stage1[i](x) | |||||
return x | |||||
else: | |||||
self.freeze() | |||||
x = feat | |||||
for i in range(len(self.stage1)): | |||||
x = self.stage1[i](x) | |||||
spec = torch.stft( | |||||
mixture / 32768, | |||||
640, | |||||
320, | |||||
640, | |||||
self.window, | |||||
center=False, | |||||
return_complex=True) | |||||
spec = torch.view_as_real(spec).permute([0, 2, 1, 3]) | |||||
specmag = torch.sqrt(spec[..., 0]**2 + spec[..., 1]**2) | |||||
est = x * specmag | |||||
y = torch.cat([est, feat], dim=-1) | |||||
for i in range(len(self.stage2)): | |||||
y = self.stage2[i](y) | |||||
return y | |||||
class Unet(nn.Module):

    def __init__(self,
                 indim,
                 outdim,
                 layers=9,
                 dims=[256] * 4,
                 lorder=20,
                 rorder=0,
                 dilation=1,
                 layer_norm=False,
                 dropout=0,
                 crm=False,
                 vad=False,
                 linearout=False):
        super(Unet, self).__init__()
        self.linear1 = AffineTransform(indim, dims[0])
        self.relu = RectifiedLinear(dims[0], dims[0])
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for i in range(len(dims) - 1):
            layer = nn.Sequential(
                nn.Linear(dims[i], dims[i + 1]), nn.ReLU(),
                nn.Linear(dims[i + 1], dims[i + 1], bias=False),
                Conv2d(
                    dims[i + 1],
                    dims[i + 1],
                    lorder,
                    groups=dims[i + 1],
                    skip_connect=True))
            self.encoder.append(layer)
        for i in range(len(dims) - 1, 0, -1):
            # decoder blocks consume the concatenation of the previous output
            # and the mirrored encoder output, hence dims[i] * 2 inputs
            layer = nn.Sequential(
                nn.Linear(dims[i] * 2, dims[i - 1]), nn.ReLU(),
                nn.Linear(dims[i - 1], dims[i - 1], bias=False),
                Conv2d(
                    dims[i - 1],
                    dims[i - 1],
                    lorder,
                    groups=dims[i - 1],
                    skip_connect=True))
            self.decoder.append(layer)
        self.tf = nn.ModuleList()
        for i in range(layers - 2 * (len(dims) - 1)):
            layer = nn.Sequential(
                nn.Linear(dims[-1], dims[-1]), nn.ReLU(),
                nn.Linear(dims[-1], dims[-1], bias=False),
                Conv2d(
                    dims[-1],
                    dims[-1],
                    lorder,
                    groups=dims[-1],
                    skip_connect=True))
            self.tf.append(layer)
        self.linear2 = AffineTransform(dims[0], outdim)
        self.crm = crm
        self.act = nn.Tanh() if self.crm else nn.Sigmoid()
        self.vad = False
        self.layers = layers
        self.linearout = linearout

    def forward(self, x, ctl=None):
        x = self.linear1(x)
        x = self.relu(x)
        encoder_out = []
        for i in range(len(self.encoder)):
            x = self.encoder[i](x)
            encoder_out.append(x)
        for i in range(len(self.tf)):
            x = self.tf[i](x)
        for i in range(len(self.decoder)):
            # skip connection: concatenate with the mirrored encoder output
            x = torch.cat([x, encoder_out[-1 - i]], dim=-1)
            x = self.decoder[i](x)
        x = self.linear2(x)
        if self.linearout:
            return x
        return self.act(x)
class BranchNet(nn.Module):

    def __init__(self,
                 indim,
                 outdim,
                 layers=9,
                 hidden_dim=256,
                 lorder=20,
                 rorder=0,
                 dilation=1,
                 layer_norm=False,
                 dropout=0,
                 crm=False,
                 vad=False,
                 linearout=False):
        super(BranchNet, self).__init__()
        self.linear1 = AffineTransform(indim, hidden_dim)
        self.relu = RectifiedLinear(hidden_dim, hidden_dim)
        self.convs = nn.ModuleList()
        self.deepfsmn = nn.ModuleList()
        self.FREQ = nn.ModuleList()
        self.TIME = nn.ModuleList()
        self.br1 = nn.ModuleList()
        self.br2 = nn.ModuleList()
        for i in range(layers):
            '''
            layer = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim, bias=False),
                Conv2d(hidden_dim, hidden_dim, lorder,
                       groups=hidden_dim, skip_connect=True)
            )
            self.deepfsmn.append(layer)
            '''
            layer = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
            self.FREQ.append(layer)
            '''
            layer = nn.GRU(hidden_dim, hidden_dim,
                           batch_first=True,
                           bidirectional=False)
            self.TIME.append(layer)
            layer = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim//2, bias=False),
                Conv2d(hidden_dim//2, hidden_dim//2, lorder,
                       groups=hidden_dim//2, skip_connect=True)
            )
            self.br1.append(layer)
            layer = nn.GRU(hidden_dim, hidden_dim//2,
                           batch_first=True,
                           bidirectional=False)
            self.br2.append(layer)
            '''
        self.linear2 = AffineTransform(hidden_dim, outdim)
        self.crm = crm
        self.act = nn.Tanh() if self.crm else nn.Sigmoid()
        self.vad = False
        self.layers = layers
        self.linearout = linearout

    def forward(self, x, ctl=None):
        return self.forward_branch(x)

    def forward_sepconv(self, x):
        x = torch.unsqueeze(x, 1)
        for i in range(len(self.convs)):
            x = self.convs[i](x)
            x = F.relu(x)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 1, 3)
        x = torch.reshape(x, [B, H, C * W])
        x = self.linear1(x)
        x = self.relu(x)
        for i in range(self.layers):
            x = self.deepfsmn[i](x) + x
        x = self.linear2(x)
        return self.act(x)

    def forward_branch(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        for i in range(self.layers):
            z = self.FREQ[i](x)
            x = z + x
        x = self.linear2(x)
        if self.linearout:
            return x
        return self.act(x)
class TACNet(nn.Module):
    '''Transform-average-concatenate (TAC) for ad hoc distributed recording.
    '''

    def __init__(self,
                 indim,
                 outdim,
                 layers=9,
                 hidden_dim=128,
                 lorder=20,
                 rorder=0,
                 crm=False,
                 vad=False,
                 linearout=False):
        super(TACNet, self).__init__()
        self.linear1 = AffineTransform(indim, hidden_dim)
        self.relu = RectifiedLinear(hidden_dim, hidden_dim)
        if rorder == 0:
            repeats = [
                UniDeepFsmn(hidden_dim, hidden_dim, lorder, hidden_dim)
                for i in range(layers)
            ]
        else:
            repeats = [
                DeepFsmn(hidden_dim, hidden_dim, lorder, rorder, hidden_dim)
                for i in range(layers)
            ]
        self.deepfsmn = nn.Sequential(*repeats)
        self.ch_transform = nn.ModuleList([])
        self.ch_average = nn.ModuleList([])
        self.ch_concat = nn.ModuleList([])
        for i in range(layers):
            self.ch_transform.append(
                nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.PReLU()))
            self.ch_average.append(
                nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.PReLU()))
            self.ch_concat.append(
                nn.Sequential(
                    nn.Linear(hidden_dim * 2, hidden_dim), nn.PReLU()))
        self.linear2 = AffineTransform(hidden_dim, outdim)
        self.crm = crm
        if self.crm:
            self.sig = nn.Tanh()
        else:
            self.sig = Sigmoid(outdim, outdim)
        self.vad = vad
        if self.vad:
            self.linear3 = AffineTransform(hidden_dim, 1)
        self.layers = layers
        self.linearout = linearout
        if self.linearout and self.vad:
            print('Warning: linearout together with vad is not supported')

    def forward(self, feat, ctl=None):
        B, T, F = feat.shape
        # the input feature is assumed to stack 4 channels along the last dim
        ch = 4
        zlist = []
        for c in range(ch):
            z = self.linear1(feat[..., c * (F // 4):(c + 1) * (F // 4)])
            z = self.relu(z)
            zlist.append(z)
        for i in range(self.layers):
            # per-channel forward
            for c in range(ch):
                zlist[c] = self.deepfsmn[i](zlist[c])
            # transform
            olist = []
            for c in range(ch):
                z = self.ch_transform[i](zlist[c])
                olist.append(z)
            # average across channels
            avg = 0
            for c in range(ch):
                avg = avg + olist[c]
            avg = avg / ch
            avg = self.ch_average[i](avg)
            # concatenate and add back as a residual
            for c in range(ch):
                tac = torch.cat([olist[c], avg], dim=-1)
                tac = self.ch_concat[i](tac)
                zlist[c] = zlist[c] + tac
        for c in range(ch):
            zlist[c] = self.sig(self.linear2(zlist[c]))
        mask = torch.cat(zlist, dim=-1)
        return mask

    def to_kaldi_nnet(self):
        pass
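For orientation, a minimal sketch of driving one of these mask estimators end to end; the input dimensions and shapes below are assumptions for illustration, not values taken from this change:

import torch

# assumes the class definitions above are importable
net = BranchNet(indim=1803, outdim=601, layers=9, hidden_dim=256)
feat = torch.randn(2, 100, 1803)  # (batch, frames, feature_dim)
mask = net(feat)                  # (2, 100, 601), sigmoid-bounded
                                  # unless crm or linearout is set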
@@ -1,4 +1,4 @@
from .audio import * # noqa F403
from .audio import LinearAECPipeline
from .base import Pipeline
from .builder import pipeline
from .cv import * # noqa F403

@@ -0,0 +1 @@
from .linear_aec_pipeline import LinearAECPipeline
@@ -0,0 +1,160 @@
import importlib
import os
from typing import Any, Dict

import numpy as np
import scipy.io.wavfile as wav
import torch
import yaml

from modelscope.preprocessors.audio import LinearAECAndFbank
from modelscope.utils.constant import ModelFile, Tasks
from ..base import Pipeline
from ..builder import PIPELINES

FEATURE_MVN = 'feature.DEY.mvn.txt'
CONFIG_YAML = 'dey_mini.yaml'


def initialize_config(module_cfg):
    r"""Load a specific module dynamically, with params, according to config items.

    1. Load the module named by the "module" param.
    2. Call the function (or instantiate the class) named by the "main" param.
    3. Pass the params in "args" into the function (or class) when calling
       (or instantiating) it.

    Args:
        module_cfg (dict): config items, e.g.:
            {
                "module": "models.model",
                "main": "Model",
                "args": {...}
            }

    Returns:
        The object returned by the call.
    """
    module = importlib.import_module(module_cfg['module'])
    return getattr(module, module_cfg['main'])(**module_cfg['args'])
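# For example (an illustrative sketch using a stock torch class rather than a
# project model):
#   cfg = {'module': 'torch.nn', 'main': 'Linear',
#          'args': {'in_features': 4, 'out_features': 2}}
#   layer = initialize_config(cfg)  # -> torch.nn.Linear(4, 2)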
@PIPELINES.register_module(
    Tasks.speech_signal_process, module_name=r'speech_dfsmn_aec_psm_16k')
class LinearAECPipeline(Pipeline):
    r"""AEC inference pipeline; only a 16000 Hz sample rate is supported.

    When invoking the pipeline via __call__(), provide two parameters:
        inputs (Dict[str, Any]): the paths of the wav files, e.g.:
            {
                "nearend_mic": "/your/data/near_end_mic_audio.wav",
                "farend_speech": "/your/data/far_end_speech_audio.wav"
            }
        output_path (str, optional): the file path to write the generated
            audio to, e.g. "/your/output/audio_after_aec.wav".
    """

    def __init__(self, model):
        r"""
        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model)
        self.use_cuda = torch.cuda.is_available()
        with open(
                os.path.join(self.model, CONFIG_YAML), encoding='utf-8') as f:
            self.config = yaml.full_load(f.read())
        self.config['io']['mvn'] = os.path.join(self.model, FEATURE_MVN)
        self._init_model()
        self.preprocessor = LinearAECAndFbank(self.config['io'])

        n_fft = self.config['loss']['args']['n_fft']
        hop_length = self.config['loss']['args']['hop_length']
        winlen = n_fft
        window = torch.hamming_window(winlen, periodic=False)

        def stft(x):
            return torch.stft(
                x,
                n_fft,
                hop_length,
                winlen,
                center=False,
                window=window.to(x.device),
                return_complex=False)

        def istft(x, slen):
            return torch.istft(
                x,
                n_fft,
                hop_length,
                winlen,
                window=window.to(x.device),
                center=False,
                length=slen)

        self.stft = stft
        self.istft = istft

    def _init_model(self):
        # at this point self.model is the model directory; it is replaced
        # by the network instance below
        checkpoint = torch.load(
            os.path.join(self.model, ModelFile.TORCH_MODEL_BIN_FILE),
            map_location='cpu')
        self.model = initialize_config(self.config['nnet'])
        if self.use_cuda:
            self.model = self.model.cuda()
        self.model.load_state_dict(checkpoint)

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        r"""The AEC process.

        Args:
            inputs: dict={'feature': Tensor, 'base': Tensor}
                'feature': the feature of the input audio.
                'base': the base audio to apply the mask to.

        Returns:
            dict:
                {
                    'output_pcm': the generated audio array
                }
        """
        output_data = self._process(inputs['feature'], inputs['base'])
        return {'output_pcm': output_data}

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        r"""The post process. Saves the audio to file if output_path is given.

        Args:
            inputs: dict:
                {
                    'output_pcm': the generated audio array
                }
            kwargs: accepts 'output_path', the path to write the generated
                audio to.

        Returns:
            dict:
                {
                    'output_pcm': the generated audio array
                }
        """
        if 'output_path' in kwargs:
            wav.write(kwargs['output_path'], self.preprocessor.SAMPLE_RATE,
                      inputs['output_pcm'].astype(np.int16))
        inputs['output_pcm'] = inputs['output_pcm'] / 32768.0
        return inputs

    def _process(self, fbanks, mixture):
        if self.use_cuda:
            fbanks = fbanks.cuda()
            mixture = mixture.cuda()
        if self.model.vad:
            with torch.no_grad():
                masks, vad = self.model(fbanks.unsqueeze(0))
                masks = masks.permute([2, 1, 0])
        else:
            with torch.no_grad():
                masks = self.model(fbanks.unsqueeze(0))
                masks = masks.permute([2, 1, 0])
        spectrum = self.stft(mixture)
        masked_spec = spectrum * masks
        masked_sig = self.istft(masked_spec, len(mixture)).cpu().numpy()
        return masked_sig
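A minimal usage sketch of this pipeline, mirroring the unit test at the end of this change; the wav file names and output path are placeholders:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

aec = pipeline(
    Tasks.speech_signal_process,
    model='damo/speech_dfsmn_aec_psm_16k',
    pipeline_name='speech_dfsmn_aec_psm_16k')
result = aec(
    {'nearend_mic': 'nearend_mic.wav', 'farend_speech': 'farend_speech.wav'},
    output_path='output.wav')
# result['output_pcm'] holds the processed audio, normalized to [-1.0, 1.0]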
@@ -84,6 +84,12 @@ TASK_OUTPUTS = {
    # ============ audio tasks ===================

    # audio processed for a single file, in PCM format
    # {
    #     "output_pcm": np.array with shape (samples,) and dtype float32
    # }
    Tasks.speech_signal_process: ['output_pcm'],

    # ============ multi-modal tasks ===================
    # image caption result for single sample
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .audio import LinearAECAndFbank
from .base import Preprocessor
from .builder import PREPROCESSORS, build_preprocessor
from .common import Compose
@@ -0,0 +1,230 @@
import ctypes
import os
from typing import Any, Dict

import numpy as np
import scipy.io.wavfile as wav
import torch
import torchaudio.compliance.kaldi as kaldi
from numpy.ctypeslib import ndpointer

from modelscope.utils.constant import Fields
from .builder import PREPROCESSORS


def load_wav(path):
    samp_rate, data = wav.read(path)
    return np.float32(data), samp_rate


def load_library(libaec):
    # prefer a copy of the library in the current working directory,
    # otherwise fall back to normal resolution (e.g. LD_LIBRARY_PATH)
    libaec_in_cwd = os.path.join('.', libaec)
    if os.path.exists(libaec_in_cwd):
        libaec = libaec_in_cwd
    mitaec = ctypes.cdll.LoadLibrary(libaec)
    fe_process = mitaec.fe_process_inst
    fe_process.argtypes = [
        ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'),
        ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'), ctypes.c_int,
        ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'),
        ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'),
        ndpointer(ctypes.c_float, flags='C_CONTIGUOUS')
    ]
    return fe_process
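# The argtypes above imply a C entry point of roughly this shape (the return
# type is an assumption; ctypes defaults to int):
#   int fe_process_inst(float* mic, float* ref, int num_samples,
#                       float* out_mic, float* out_linear, float* out_echo);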
def do_linear_aec(fe_process, mic, ref, int16range=True):
    mic = np.float32(mic)
    ref = np.float32(ref)
    # truncate mic so the reference covers every mic sample
    if len(mic) > len(ref):
        mic = mic[:len(ref)]
    out_mic = np.zeros_like(mic)
    out_linear = np.zeros_like(mic)
    out_echo = np.zeros_like(mic)
    out_ref = np.zeros_like(mic)
    if int16range:
        mic /= 32768
        ref /= 32768
    fe_process(mic, ref, len(mic), out_mic, out_linear, out_echo)
    # out_ref is not in use here
    if int16range:
        out_mic *= 32768
        out_linear *= 32768
        out_echo *= 32768
    return out_mic, out_ref, out_linear, out_echo
def load_kaldi_feature_transform(filename):
    # parse the <AddShift> and <Rescale> vectors from a Kaldi nnet1
    # feature-transform text file
    with open(filename, 'r') as fp:
        all_str = fp.read()
    pos1 = all_str.find('AddShift')
    pos2 = all_str.find('[', pos1)
    pos3 = all_str.find(']', pos2)
    mean = np.array(all_str[pos2 + 1:pos3].split(), dtype=np.float32)
    pos1 = all_str.find('Rescale')
    pos2 = all_str.find('[', pos1)
    pos3 = all_str.find(']', pos2)
    scale = np.array(all_str[pos2 + 1:pos3].split(), dtype=np.float32)
    return mean, scale
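# The expected input is a Kaldi nnet1 feature transform in text form, roughly
# like the sketch below (one value per feature dimension; dims are assumed):
#   <AddShift> 120 120
#   [ -10.1 -9.8 ... ]
#   <Rescale> 120 120
#   [ 0.25 0.31 ... ]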
class Feature:
    r"""Extract features from one utterance.
    """

    def __init__(self,
                 fbank_config,
                 feat_type='spec',
                 mvn_file=None,
                 cuda=False):
        r"""
        Args:
            fbank_config (dict): config for kaldi.fbank; also provides the
                frame length and shift used by the other feature types
            feat_type (str):
                raw: do nothing
                fbank: use kaldi.fbank
                spec: Real/Imag
                logpow: log(1+|x|^2)
            mvn_file (str): the path of the data file for mean variance
                normalization
            cuda (bool): move the window and mvn tensors to GPU
        """
        self.fbank_config = fbank_config
        self.feat_type = feat_type
        self.n_fft = fbank_config['frame_length'] * fbank_config[
            'sample_frequency'] // 1000
        self.hop_length = fbank_config['frame_shift'] * fbank_config[
            'sample_frequency'] // 1000
        self.window = torch.hamming_window(self.n_fft, periodic=False)
        self.mvn = False
        if mvn_file is not None and os.path.exists(mvn_file):
            print(f'loading mvn file: {mvn_file}')
            shift, scale = load_kaldi_feature_transform(mvn_file)
            self.shift = torch.from_numpy(shift)
            self.scale = torch.from_numpy(scale)
            self.mvn = True
        if cuda:
            self.window = self.window.cuda()
            if self.mvn:
                self.shift = self.shift.cuda()
                self.scale = self.scale.cuda()

    def compute(self, utt):
        r"""
        Args:
            utt: in [-32768, 32767] range

        Returns:
            [..., T, F]
        """
        if self.feat_type == 'raw':
            return utt
        elif self.feat_type == 'fbank':
            if len(utt.shape) == 1:
                utt = utt.unsqueeze(0)
            feat = kaldi.fbank(utt, **self.fbank_config)
        elif self.feat_type == 'spec':
            spec = torch.stft(
                utt / 32768,
                self.n_fft,
                self.hop_length,
                self.n_fft,
                self.window,
                center=False,
                return_complex=True)
            feat = torch.cat([spec.real, spec.imag], dim=-2).permute(-1, -2)
        elif self.feat_type == 'logpow':
            spec = torch.stft(
                utt,
                self.n_fft,
                self.hop_length,
                self.n_fft,
                self.window,
                center=False,
                return_complex=True)
            abspow = torch.abs(spec)**2
            feat = torch.log(1 + abspow).permute(-1, -2)
        else:
            raise ValueError(f'Unknown feat_type: {self.feat_type}')
        return feat

    def normalize(self, feat):
        if self.mvn:
            feat = feat + self.shift
            feat = feat * self.scale
        return feat
@PREPROCESSORS.register_module(Fields.audio)
class LinearAECAndFbank:
    SAMPLE_RATE = 16000

    def __init__(self, io_config):
        self.trunc_length = 7200 * self.SAMPLE_RATE
        self.linear_aec_delay = io_config['linear_aec_delay']
        self.feature = Feature(io_config['fbank_config'],
                               io_config['feat_type'], io_config['mvn'])
        self.mitaec = load_library(io_config['mitaec_library'])
        self.mask_on_mic = io_config['mask_on'] == 'nearend_mic'

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Linearly filter the near-end mic and far-end audio, then extract
        the feature.

        :param data: dict with two keys and corresponding audio paths:
            "nearend_mic" and "farend_speech"
        :return: dict with Tensor values: "base", the linearly filtered
            audio, and "feature"
        """
        # read files
        nearend_mic, fs = load_wav(data['nearend_mic'])
        assert fs == self.SAMPLE_RATE, f'The sample rate should be {self.SAMPLE_RATE}'
        farend_speech, fs = load_wav(data['farend_speech'])
        assert fs == self.SAMPLE_RATE, f'The sample rate should be {self.SAMPLE_RATE}'
        if 'nearend_speech' in data:
            nearend_speech, fs = load_wav(data['nearend_speech'])
            assert fs == self.SAMPLE_RATE, f'The sample rate should be {self.SAMPLE_RATE}'
        else:
            nearend_speech = np.zeros_like(nearend_mic)
        out_mic, out_ref, out_linear, out_echo = do_linear_aec(
            self.mitaec, nearend_mic, farend_speech)
        # fix the 20 ms linear AEC delay by delaying the target speech
        extra_zeros = np.zeros([int(self.linear_aec_delay * fs)])
        nearend_speech = np.concatenate([extra_zeros, nearend_speech])
        # truncate all signals to the same length
        flen = min(
            len(out_mic), len(out_ref), len(out_linear), len(out_echo),
            len(nearend_speech))
        fstart = 0
        flen = min(flen, self.trunc_length)
        nearend_mic, out_ref, out_linear, out_echo, nearend_speech = (
            out_mic[fstart:flen], out_ref[fstart:flen],
            out_linear[fstart:flen], out_echo[fstart:flen],
            nearend_speech[fstart:flen])
        # extract features (frames, [mic, linear, echo])
        feat = torch.FloatTensor()
        nearend_mic = torch.from_numpy(np.float32(nearend_mic))
        fbank_nearend_mic = self.feature.compute(nearend_mic)
        feat = torch.cat([feat, fbank_nearend_mic], dim=1)
        out_linear = torch.from_numpy(np.float32(out_linear))
        fbank_out_linear = self.feature.compute(out_linear)
        feat = torch.cat([feat, fbank_out_linear], dim=1)
        out_echo = torch.from_numpy(np.float32(out_echo))
        fbank_out_echo = self.feature.compute(out_echo)
        feat = torch.cat([feat, fbank_out_echo], dim=1)
        # feature transform
        feat = self.feature.normalize(feat)
        # prepare target
        if nearend_speech is not None:
            nearend_speech = torch.from_numpy(np.float32(nearend_speech))
        if self.mask_on_mic:
            base = nearend_mic
        else:
            base = out_linear
        out_data = {'base': base, 'target': nearend_speech, 'feature': feat}
        return out_data
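A minimal sketch of driving the preprocessor directly. The io_config keys come from the model's yaml; the values below are assumptions for illustration, and the C++ library must be resolvable (e.g. in the working directory):

io_config = {
    'linear_aec_delay': 0.02,  # seconds; assumed value
    'fbank_config': {'frame_length': 25, 'frame_shift': 10,
                     'sample_frequency': 16000, 'num_mel_bins': 80},
    'feat_type': 'fbank',
    'mvn': 'feature.DEY.mvn.txt',
    'mitaec_library': 'libmitaec_pyio.so',
    'mask_on': 'nearend_mic',
}
pre = LinearAECAndFbank(io_config)
out = pre({'nearend_mic': 'nearend_mic.wav',
           'farend_speech': 'farend_speech.wav'})
# out['feature']: (frames, feat_dim) Tensor stacking mic/linear/echo features
# out['base']: the waveform Tensor that the estimated mask is applied to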
@@ -7,6 +7,7 @@ opencv-python-headless
Pillow>=6.2.0
pyyaml
requests
scipy
tokenizers<=0.10.3
transformers<=4.16.2
yapf
@@ -11,6 +11,7 @@ default_section = THIRDPARTY
BASED_ON_STYLE = pep8
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
SPLIT_BEFORE_ARITHMETIC_OPERATOR = true

[codespell]
skip = *.ipynb

@@ -20,5 +21,5 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids
[flake8]
select = B,C,E,F,P,T4,W,B9
max-line-length = 120
ignore = F401,F821
ignore = F401,F821,W503
exclude = docs/src,*.pyi,.git
@@ -0,0 +1,56 @@
import os.path
import shutil
import unittest

from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir

NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav'
FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav'
NEAREND_MIC_FILE = 'nearend_mic.wav'
FAREND_SPEECH_FILE = 'farend_speech.wav'
AEC_LIB_URL = 'http://isv-data.oss-cn-hangzhou.aliyuncs.com/ics%2FMaaS%2FAEC%2Flib%2Flibmitaec_pyio.so' \
              '?Expires=1664085465&OSSAccessKeyId=LTAIxjQyZNde90zh&Signature=Y7gelmGEsQAJRK4yyHSYMrdWizk%3D'
AEC_LIB_FILE = 'libmitaec_pyio.so'


def download(remote_path, local_path):
    local_dir = os.path.dirname(local_path)
    if len(local_dir) > 0 and not os.path.exists(local_dir):
        os.makedirs(local_dir)
    with open(local_path, 'wb') as ofile:
        ofile.write(File.read(remote_path))


class SpeechSignalProcessTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/speech_dfsmn_aec_psm_16k'
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)
        # A temporary hack to provide the C++ lib. Download it first.
        download(AEC_LIB_URL, AEC_LIB_FILE)

    def test_run(self):
        download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
        download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
        input = {
            'nearend_mic': NEAREND_MIC_FILE,
            'farend_speech': FAREND_SPEECH_FILE
        }
        aec = pipeline(
            Tasks.speech_signal_process,
            model=self.model_id,
            pipeline_name=r'speech_dfsmn_aec_psm_16k')
        aec(input, output_path='output.wav')


if __name__ == '__main__':
    unittest.main()