From 31498c1d6a81611f5a58db5e3d6e983e8f386ed1 Mon Sep 17 00:00:00 2001
From: "bin.xue"
Date: Fri, 17 Jun 2022 19:56:11 +0800
Subject: [PATCH] [to #41669377] add speech AEC pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8973072

* [to #41669377] docs and tools refinement and release

  1. add build_doc linter script
  2. add sphinx-docs support
  3. add development doc and api doc
  4. change version to 0.1.0 for the first internal release version

  Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307

* [to #41669377] add pipeline tutorial and fix bugs

  1. add pipeline tutorial
  2. fix bugs when using pipeline with certain model and preprocessor

  Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8814301

* refine doc
* feat: add audio aec pipeline and preprocessor
* feat: add audio aec model classes
* feat: add audio aec loss functions
* refactor: delete no longer used loss function
* [to #42281043] support kwargs in pipeline

  Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8949062

* support kwargs in pipeline
* update develop doc with CR instruction
* Merge branch 'release/0.1' into dev/aec
* style: reformat code by pre-commit tools
* feat: support maas_lib pipeline auto downloading model
* test: add aec test case as sample code
* feat: aec pipeline use config from maashub
* feat: aec pipeline use feature parameters from maashub
* update setup.cfg to disable PEP8 rule W503 in flake8 and yapf
* format: fix double quoted strings, indent issues and optimize import
* refactor: extract some constant in aec pipeline
* refactor: delete no longer used __main__ statement
* chore: change all Chinese comments to English
* fix: change file name style to lower case
* refactor: rename model name
* feat: load C++ .so from LD_LIBRARY_PATH
* feat: register PREPROCESSOR for LinearAECAndFbank
* refactor: move aec process from postprocess() to forward() and update comments
* refactor: add more readable error message when audio sample rate is not 16000
* fix: package maas_lib renamed to modelscope in import statement
* feat: optimize the error message of audio layer classes
* format: delete empty lines
* refactor: rename audio preprocessor and optimize error message
* refactor: change aec model id to damo/speech_dfsmn_aec_psm_16k
* refactor: change sample audio file url to public oss
* Merge branch 'master' into dev/aec
* feat: add output info for aec pipeline
* fix: normalize output audio data to [-1.0, 1.0]
* refactor: use constant from ModelFile
* feat: AEC pipeline can use c++ lib in current working directory and the test will download it
* fix: c++ downloading should work wherever test is triggered
---
 modelscope/models/audio/__init__.py           |   0
 modelscope/models/audio/layers/__init__.py    |   0
 modelscope/models/audio/layers/activations.py |  60 +++
 .../models/audio/layers/affine_transform.py   |  78 +++
 modelscope/models/audio/layers/deep_fsmn.py   | 178 +++++++
 modelscope/models/audio/layers/layer_base.py  |  50 ++
 .../models/audio/layers/uni_deep_fsmn.py      | 482 +++++++++++++++++
 modelscope/models/audio/network/__init__.py   |   0
 modelscope/models/audio/network/loss.py       | 394 ++++++++++++++
 .../models/audio/network/modulation_loss.py   | 248 +++++++++
 modelscope/models/audio/network/se_net.py     | 483 ++++++++++++++++++
 modelscope/pipelines/__init__.py              |   2 +-
 modelscope/pipelines/audio/__init__.py        |   1 +
 .../pipelines/audio/linear_aec_pipeline.py    | 160 ++++++
 modelscope/pipelines/outputs.py               |   6 +
 modelscope/preprocessors/__init__.py          |   1 +
 modelscope/preprocessors/audio.py             | 230
+++++++++ requirements/runtime.txt | 1 + setup.cfg | 3 +- tests/pipelines/test_speech_signal_process.py | 56 ++ 20 files changed, 2431 insertions(+), 2 deletions(-) create mode 100644 modelscope/models/audio/__init__.py create mode 100644 modelscope/models/audio/layers/__init__.py create mode 100644 modelscope/models/audio/layers/activations.py create mode 100644 modelscope/models/audio/layers/affine_transform.py create mode 100644 modelscope/models/audio/layers/deep_fsmn.py create mode 100644 modelscope/models/audio/layers/layer_base.py create mode 100644 modelscope/models/audio/layers/uni_deep_fsmn.py create mode 100644 modelscope/models/audio/network/__init__.py create mode 100644 modelscope/models/audio/network/loss.py create mode 100644 modelscope/models/audio/network/modulation_loss.py create mode 100644 modelscope/models/audio/network/se_net.py create mode 100644 modelscope/pipelines/audio/linear_aec_pipeline.py create mode 100644 modelscope/preprocessors/audio.py create mode 100644 tests/pipelines/test_speech_signal_process.py diff --git a/modelscope/models/audio/__init__.py b/modelscope/models/audio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/layers/__init__.py b/modelscope/models/audio/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/layers/activations.py b/modelscope/models/audio/layers/activations.py new file mode 100644 index 00000000..b0215bcc --- /dev/null +++ b/modelscope/models/audio/layers/activations.py @@ -0,0 +1,60 @@ +import torch.nn as nn + +from .layer_base import LayerBase + + +class RectifiedLinear(LayerBase): + + def __init__(self, input_dim, output_dim): + super(RectifiedLinear, self).__init__() + self.dim = input_dim + self.relu = nn.ReLU() + + def forward(self, input): + return self.relu(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + return re_str + + def load_kaldi_nnet(self, instr): + return instr + + +class LogSoftmax(LayerBase): + + def __init__(self, input_dim, output_dim): + super(LogSoftmax, self).__init__() + self.dim = input_dim + self.ls = nn.LogSoftmax() + + def forward(self, input): + return self.ls(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + return re_str + + def load_kaldi_nnet(self, instr): + return instr + + +class Sigmoid(LayerBase): + + def __init__(self, input_dim, output_dim): + super(Sigmoid, self).__init__() + self.dim = input_dim + self.sig = nn.Sigmoid() + + def forward(self, input): + return self.sig(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + return re_str + + def load_kaldi_nnet(self, instr): + return instr diff --git a/modelscope/models/audio/layers/affine_transform.py b/modelscope/models/audio/layers/affine_transform.py new file mode 100644 index 00000000..33479505 --- /dev/null +++ b/modelscope/models/audio/layers/affine_transform.py @@ -0,0 +1,78 @@ +import numpy as np +import torch as th +import torch.nn as nn + +from .layer_base import (LayerBase, expect_kaldi_matrix, expect_token_number, + to_kaldi_matrix) + + +class AffineTransform(LayerBase): + + def __init__(self, input_dim, output_dim): + super(AffineTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.linear = nn.Linear(input_dim, output_dim) + + def forward(self, input): + return self.linear(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % 
(self.output_dim, + self.input_dim) + re_str += ' 1 1 0\n' + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += to_kaldi_matrix(x) + return re_str + + def to_raw_nnet(self, fid): + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + x.tofile(fid) + + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + x.tofile(fid) + + def load_kaldi_nnet(self, instr): + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('AffineTransform format error for ') + instr, lr = output + + output = expect_token_number(instr, '') + if output is None: + raise Exception( + 'AffineTransform format error for ') + instr, lr = output + + output = expect_token_number(instr, '') + if output is None: + raise Exception('AffineTransform format error for ') + instr, lr = output + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('AffineTransform format error for parsing matrix') + instr, mat = output + + print(mat.shape) + self.linear.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('AffineTransform format error for parsing matrix') + instr, mat = output + mat = np.squeeze(mat) + self.linear.bias = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + return instr diff --git a/modelscope/models/audio/layers/deep_fsmn.py b/modelscope/models/audio/layers/deep_fsmn.py new file mode 100644 index 00000000..72ba07dc --- /dev/null +++ b/modelscope/models/audio/layers/deep_fsmn.py @@ -0,0 +1,178 @@ +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .layer_base import (LayerBase, expect_kaldi_matrix, expect_token_number, + to_kaldi_matrix) + + +class DeepFsmn(LayerBase): + + def __init__(self, + input_dim, + output_dim, + lorder=None, + rorder=None, + hidden_size=None, + layer_norm=False, + dropout=0): + super(DeepFsmn, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + if lorder is None: + return + + self.lorder = lorder + self.rorder = rorder + self.hidden_size = hidden_size + self.layer_norm = layer_norm + + self.linear = nn.Linear(input_dim, hidden_size) + self.norm = nn.LayerNorm(hidden_size) + self.drop1 = nn.Dropout(p=dropout) + self.drop2 = nn.Dropout(p=dropout) + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.conv1 = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], [1, 1], + groups=output_dim, + bias=False) + self.conv2 = nn.Conv2d( + output_dim, + output_dim, [rorder, 1], [1, 1], + groups=output_dim, + bias=False) + + def forward(self, input): + + f1 = F.relu(self.linear(input)) + + f1 = self.drop1(f1) + if self.layer_norm: + f1 = self.norm(f1) + + p1 = self.project(f1) + + x = th.unsqueeze(p1, 1) + + x_per = x.permute(0, 3, 2, 1) + + y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) + yr = F.pad(x_per, [0, 0, 0, self.rorder]) + yr = yr[:, :, 1:, :] + + out = x_per + self.conv1(y) + self.conv2(yr) + out = self.drop2(out) + + out1 = out.permute(0, 3, 2, 1) + + return input + out1.squeeze() + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n'\ + % (self.output_dim, self.input_dim) + re_str += ' %d %d %d %d 0\n'\ + % (1, self.hidden_size, self.lorder, 1) + lfiters = self.state_dict()['conv1.weight'] + x = 
np.flipud(lfiters.squeeze().numpy().T) + re_str += to_kaldi_matrix(x) + proj_weights = self.state_dict()['project.weight'] + x = proj_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += to_kaldi_matrix(x) + return re_str + + def load_kaldi_nnet(self, instr): + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lr = output + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, hiddensize = output + self.hidden_size = int(hiddensize) + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lorder = output + self.lorder = int(lorder) + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lstride = output + self.lstride = lstride + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + mat1 = np.fliplr(mat.T).copy() + self.conv1 = nn.Conv2d( + self.output_dim, + self.output_dim, [self.lorder, 1], [1, 1], + groups=self.output_dim, + bias=False) + mat_th = th.from_numpy(mat1).type(th.FloatTensor) + mat_th = mat_th.unsqueeze(1) + mat_th = mat_th.unsqueeze(3) + self.conv1.weight = th.nn.Parameter(mat_th) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + + self.project = nn.Linear(self.hidden_size, self.output_dim, bias=False) + self.linear = nn.Linear(self.input_dim, self.hidden_size) + + self.project.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + self.linear.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + self.linear.bias = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + return instr diff --git a/modelscope/models/audio/layers/layer_base.py b/modelscope/models/audio/layers/layer_base.py new file mode 100644 index 00000000..e56c4bc0 --- /dev/null +++ b/modelscope/models/audio/layers/layer_base.py @@ -0,0 +1,50 @@ +import abc +import re + +import numpy as np +import torch.nn as nn + + +def expect_token_number(instr, token): + first_token = re.match(r'^\s*' + token, instr) + if first_token is None: + return None + instr = instr[first_token.end():] + lr = re.match(r'^\s*(-?\d+\.?\d*e?-?\d*?)', instr) + if lr is None: + return None + return instr[lr.end():], lr.groups()[0] + + +def expect_kaldi_matrix(instr): + pos2 = instr.find('[', 0) + pos3 = instr.find(']', pos2) + mat = [] + for stt in instr[pos2 + 1:pos3].split('\n'): + tmp_mat = np.fromstring(stt, dtype=np.float32, sep=' ') + if tmp_mat.size > 0: + mat.append(tmp_mat) + return instr[pos3 + 1:], np.array(mat) + + +def to_kaldi_matrix(np_mat): + """ + 
function that transform as str numpy mat to standard kaldi str matrix + :param np_mat: numpy mat + :return: str + """ + np.set_printoptions(threshold=np.inf, linewidth=np.nan, suppress=True) + out_str = str(np_mat) + out_str = out_str.replace('[', '') + out_str = out_str.replace(']', '') + return '[ %s ]\n' % out_str + + +class LayerBase(nn.Module, metaclass=abc.ABCMeta): + + def __init__(self): + super(LayerBase, self).__init__() + + @abc.abstractmethod + def to_kaldi_nnet(self): + pass diff --git a/modelscope/models/audio/layers/uni_deep_fsmn.py b/modelscope/models/audio/layers/uni_deep_fsmn.py new file mode 100644 index 00000000..c22460c4 --- /dev/null +++ b/modelscope/models/audio/layers/uni_deep_fsmn.py @@ -0,0 +1,482 @@ +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .layer_base import (LayerBase, expect_kaldi_matrix, expect_token_number, + to_kaldi_matrix) + + +class SepConv(nn.Module): + + def __init__(self, + in_channels, + filters, + out_channels, + kernel_size=(5, 2), + dilation=(1, 1)): + """ :param kernel_size (time, frequency) + + """ + super(SepConv, self).__init__() + # depthwise + pointwise + self.dconv = nn.Conv2d( + in_channels, + in_channels * filters, + kernel_size, + dilation=dilation, + groups=in_channels) + self.pconv = nn.Conv2d( + in_channels * filters, out_channels, kernel_size=1) + self.padding = dilation[0] * (kernel_size[0] - 1) + + def forward(self, input): + ''' input: [B, C, T, F] + ''' + x = F.pad(input, [0, 0, self.padding, 0]) + x = self.dconv(x) + x = self.pconv(x) + return x + + +class Conv2d(nn.Module): + + def __init__(self, + input_dim, + output_dim, + lorder=20, + rorder=0, + groups=1, + bias=False, + skip_connect=True): + super(Conv2d, self).__init__() + self.lorder = lorder + self.conv = nn.Conv2d( + input_dim, output_dim, [lorder, 1], groups=groups, bias=bias) + self.rorder = rorder + if self.rorder: + self.conv2 = nn.Conv2d( + input_dim, output_dim, [rorder, 1], groups=groups, bias=bias) + self.skip_connect = skip_connect + + def forward(self, input): + # [B, 1, T, F] + x = th.unsqueeze(input, 1) + # [B, F, T, 1] + x_per = x.permute(0, 3, 2, 1) + y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) + out = self.conv(y) + if self.rorder: + yr = F.pad(x_per, [0, 0, 0, self.rorder]) + yr = yr[:, :, 1:, :] + out += self.conv2(yr) + out = out.permute(0, 3, 2, 1).squeeze(1) + if self.skip_connect: + out = out + input + return out + + +class SelfAttLayer(nn.Module): + + def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None): + super(SelfAttLayer, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + if lorder is None: + return + + self.lorder = lorder + self.hidden_size = hidden_size + + self.linear = nn.Linear(input_dim, hidden_size) + + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.att = nn.Linear(input_dim, lorder, bias=False) + + def forward(self, input): + + f1 = F.relu(self.linear(input)) + + p1 = self.project(f1) + + x = th.unsqueeze(p1, 1) + + x_per = x.permute(0, 3, 2, 1) + + y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) + + # z [B, F, T, lorder] + z = x_per + for i in range(1, self.lorder): + z = th.cat([z, y[:, :, self.lorder - 1 - i:-i, :]], axis=-1) + + # [B, T, lorder] + att = F.softmax(self.att(input), dim=-1) + att = th.unsqueeze(att, 1) + z = th.sum(z * att, axis=-1) + + out1 = z.permute(0, 2, 1) + + return input + out1 + + +class TFFsmn(nn.Module): + + def __init__(self, + input_dim, + output_dim, + lorder=None, + 
hidden_size=None, + dilation=1, + layer_norm=False, + dropout=0, + skip_connect=True): + super(TFFsmn, self).__init__() + + self.skip_connect = skip_connect + + self.linear = nn.Linear(input_dim, hidden_size) + self.norm = nn.Identity() + if layer_norm: + self.norm = nn.LayerNorm(input_dim) + self.act = nn.ReLU() + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.conv1 = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], + dilation=[dilation, 1], + groups=output_dim, + bias=False) + self.padding_left = dilation * (lorder - 1) + dorder = 5 + self.conv2 = nn.Conv2d(1, 1, [dorder, 1], bias=False) + self.padding_freq = dorder - 1 + + def forward(self, input): + return self.compute1(input) + + def compute1(self, input): + ''' linear-dconv-relu(norm)-linear-dconv + ''' + x = self.linear(input) + # [B, 1, F, T] + x = th.unsqueeze(x, 1).permute(0, 1, 3, 2) + z = F.pad(x, [0, 0, self.padding_freq, 0]) + z = self.conv2(z) + x + x = z.permute(0, 3, 2, 1).squeeze(-1) + x = self.act(x) + x = self.norm(x) + x = self.project(x) + x = th.unsqueeze(x, 1).permute(0, 3, 2, 1) + # [B, F, T+lorder-1, 1] + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.conv1(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + + return input + out + + +class CNNFsmn(nn.Module): + ''' use cnn to reduce parameters + ''' + + def __init__(self, + input_dim, + output_dim, + lorder=None, + hidden_size=None, + dilation=1, + layer_norm=False, + dropout=0, + skip_connect=True): + super(CNNFsmn, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.skip_connect = skip_connect + + if lorder is None: + return + + self.lorder = lorder + self.hidden_size = hidden_size + + self.linear = nn.Linear(input_dim, hidden_size) + self.act = nn.ReLU() + kernel_size = (3, 8) + stride = (1, 4) + self.conv = nn.Sequential( + nn.ConstantPad2d((stride[1], 0, kernel_size[0] - 1, 0), 0), + nn.Conv2d(1, stride[1], kernel_size=kernel_size, stride=stride)) + + self.dconv = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], + dilation=[dilation, 1], + groups=output_dim, + bias=False) + self.padding_left = dilation * (lorder - 1) + + def forward(self, input): + return self.compute2(input) + + def compute1(self, input): + ''' linear-relu(norm)-conv2d-relu?-dconv + ''' + # [B, T, F] + x = self.linear(input) + x = self.act(x) + x = th.unsqueeze(x, 1) + x = self.conv(x) + # [B, C, T, F] -> [B, 1, T, F] + b, c, t, f = x.shape + x = x.view([b, 1, t, -1]) + x = x.permute(0, 3, 2, 1) + # [B, F, T+lorder-1, 1] + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.dconv(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + return input + out + + def compute2(self, input): + ''' conv2d-relu-linear-relu?-dconv + ''' + x = th.unsqueeze(input, 1) + x = self.conv(x) + x = self.act(x) + # [B, C, T, F] -> [B, T, F] + b, c, t, f = x.shape + x = x.view([b, t, -1]) + x = self.linear(x) + x = th.unsqueeze(x, 1).permute(0, 3, 2, 1) + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.dconv(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + return input + out + + +class UniDeepFsmn(LayerBase): + + def __init__(self, + input_dim, + output_dim, + lorder=None, + hidden_size=None, + dilation=1, + layer_norm=False, + dropout=0, + skip_connect=True): + super(UniDeepFsmn, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.skip_connect = skip_connect + + if lorder is None: + return + + 
self.lorder = lorder + self.hidden_size = hidden_size + + self.linear = nn.Linear(input_dim, hidden_size) + self.norm = nn.Identity() + if layer_norm: + self.norm = nn.LayerNorm(input_dim) + self.act = nn.ReLU() + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.conv1 = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], + dilation=[dilation, 1], + groups=output_dim, + bias=False) + self.padding_left = dilation * (lorder - 1) + + def forward(self, input): + return self.compute1(input) + + def compute1(self, input): + ''' linear-relu(norm)-linear-dconv + ''' + # [B, T, F] + x = self.linear(input) + x = self.act(x) + x = self.norm(x) + x = self.project(x) + x = th.unsqueeze(x, 1).permute(0, 3, 2, 1) + # [B, F, T+lorder-1, 1] + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.conv1(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + + return input + out + + def compute2(self, input): + ''' linear-dconv-linear-relu(norm) + ''' + x = self.project(input) + x = th.unsqueeze(x, 1).permute(0, 3, 2, 1) + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.conv1(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + x = self.linear(out) + x = self.act(x) + x = self.norm(x) + + return input + x + + def compute3(self, input): + ''' dconv-linear-relu(norm)-linear + ''' + x = th.unsqueeze(input, 1).permute(0, 3, 2, 1) + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.conv1(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + x = self.linear(out) + x = self.act(x) + x = self.norm(x) + x = self.project(x) + + return input + x + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' \ + % (self.output_dim, self.input_dim) + re_str += ' %d %d %d %d 0\n' \ + % (1, self.hidden_size, self.lorder, 1) + lfiters = self.state_dict()['conv1.weight'] + x = np.flipud(lfiters.squeeze().numpy().T) + re_str += to_kaldi_matrix(x) + proj_weights = self.state_dict()['project.weight'] + x = proj_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += to_kaldi_matrix(x) + return re_str + + def to_raw_nnet(self, fid): + lfiters = self.state_dict()['conv1.weight'] + x = np.flipud(lfiters.squeeze().numpy().T) + x.tofile(fid) + + proj_weights = self.state_dict()['project.weight'] + x = proj_weights.squeeze().numpy() + x.tofile(fid) + + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + x.tofile(fid) + + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + x.tofile(fid) + + def load_kaldi_nnet(self, instr): + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lr = output + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, hiddensize = output + self.hidden_size = int(hiddensize) + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lorder = output + self.lorder = int(lorder) + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lstride = output + self.lstride = lstride + + 
output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + mat1 = np.fliplr(mat.T).copy() + + self.conv1 = nn.Conv2d( + self.output_dim, + self.output_dim, [self.lorder, 1], [1, 1], + groups=self.output_dim, + bias=False) + + mat_th = th.from_numpy(mat1).type(th.FloatTensor) + mat_th = mat_th.unsqueeze(1) + mat_th = mat_th.unsqueeze(3) + self.conv1.weight = th.nn.Parameter(mat_th) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + + self.project = nn.Linear(self.hidden_size, self.output_dim, bias=False) + self.linear = nn.Linear(self.input_dim, self.hidden_size) + + self.project.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + self.linear.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + mat = np.squeeze(mat) + self.linear.bias = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + return instr diff --git a/modelscope/models/audio/network/__init__.py b/modelscope/models/audio/network/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/network/loss.py b/modelscope/models/audio/network/loss.py new file mode 100644 index 00000000..743661b3 --- /dev/null +++ b/modelscope/models/audio/network/loss.py @@ -0,0 +1,394 @@ +import torch +import torch.nn.functional as F + +from .modulation_loss import (GaborSTRFConv, MelScale, + ModulationDomainLossModule) + +EPS = 1e-8 + + +def compute_mask(mixed_spec, clean_spec, mask_type='psmiam', clip=1): + ''' + stft: (batch, ..., 2) or complex(batch, ...) 
+ y = x + n + ''' + if torch.is_complex(mixed_spec): + yr, yi = mixed_spec.real, mixed_spec.imag + else: + yr, yi = mixed_spec[..., 0], mixed_spec[..., 1] + if torch.is_complex(clean_spec): + xr, xi = clean_spec.real, clean_spec.imag + else: + xr, xi = clean_spec[..., 0], clean_spec[..., 1] + + if mask_type == 'iam': + ymag = torch.sqrt(yr**2 + yi**2) + xmag = torch.sqrt(xr**2 + xi**2) + iam = xmag / (ymag + EPS) + return torch.clamp(iam, 0, 1) + + elif mask_type == 'psm': + ypow = yr**2 + yi**2 + psm = (xr * yr + xi * yi) / (ypow + EPS) + return torch.clamp(psm, 0, 1) + + elif mask_type == 'psmiam': + ypow = yr**2 + yi**2 + psm = (xr * yr + xi * yi) / (ypow + EPS) + ymag = torch.sqrt(yr**2 + yi**2) + xmag = torch.sqrt(xr**2 + xi**2) + iam = xmag / (ymag + EPS) + psmiam = psm * iam + return torch.clamp(psmiam, 0, 1) + + elif mask_type == 'crm': + ypow = yr**2 + yi**2 + mr = (xr * yr + xi * yi) / (ypow + EPS) + mi = (xi * yr - xr * yi) / (ypow + EPS) + mr = torch.clamp(mr, -clip, clip) + mi = torch.clamp(mi, -clip, clip) + return mr, mi + + +def energy_vad(spec, + thdhigh=320 * 600 * 600 * 2, + thdlow=320 * 300 * 300 * 2, + int16=True): + ''' + energy based vad should be accurate enough + spec: (batch, bins, frames, 2) + returns (batch, frames) + ''' + energy = torch.sum(spec[..., 0]**2 + spec[..., 1]**2, dim=1) + vad = energy > thdhigh + idx = torch.logical_and(vad == 0, energy > thdlow) + vad[idx] = 0.5 + return vad + + +def modulation_loss_init(n_fft): + gabor_strf_parameters = torch.load( + './network/gabor_strf_parameters.pt')['state_dict'] + gabor_modulation_kernels = GaborSTRFConv(supn=30, supk=30, nkern=60) + gabor_modulation_kernels.load_state_dict(gabor_strf_parameters) + + modulation_loss_module = ModulationDomainLossModule( + gabor_modulation_kernels.eval()) + for param in modulation_loss_module.parameters(): + param.requires_grad = False + + stft2mel = MelScale( + n_mels=80, sample_rate=16000, n_stft=n_fft // 2 + 1).cuda() + + return modulation_loss_module, stft2mel + + +def mask_loss_function( + loss_func='psm_loss', + loss_type='mse', # ['mse', 'mae', 'comb'] + mask_type='psmiam', + use_mod_loss=False, + use_wav2vec_loss=False, + n_fft=640, + hop_length=320, + EPS=1e-8, + weight=None): + if weight is not None: + print(f'Use loss weight: {weight}') + winlen = n_fft + window = torch.hamming_window(winlen, periodic=False) + + def stft(x, return_complex=False): + # returns [batch, bins, frames, 2] + return torch.stft( + x, + n_fft, + hop_length, + winlen, + window=window.to(x.device), + center=False, + return_complex=return_complex) + + def istft(x, slen): + return torch.istft( + x, + n_fft, + hop_length, + winlen, + window=window.to(x.device), + center=False, + length=slen) + + def mask_loss(targets, masks, nframes): + ''' [Batch, Time, Frequency] + ''' + with torch.no_grad(): + mask_for_loss = torch.ones_like(targets) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + masks = masks * mask_for_loss + targets = targets * mask_for_loss + + if weight is None: + alpha = 1 + else: # for aec ST + alpha = weight - targets + + if loss_type == 'mse': + loss = 0.5 * torch.sum(alpha * torch.pow(targets - masks, 2)) + elif loss_type == 'mae': + loss = torch.sum(alpha * torch.abs(targets - masks)) + else: # mse(mask), mae(mask) approx 1:2 + loss = 0.5 * torch.sum(alpha * torch.pow(targets - masks, 2) + + 0.1 * alpha * torch.abs(targets - masks)) + loss /= torch.sum(nframes) + return loss + + def spectrum_loss(targets, spec, nframes): + ''' [Batch, Time, Frequency, 2] + 
''' + with torch.no_grad(): + mask_for_loss = torch.ones_like(targets[..., 0]) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + xr = spec[..., 0] * mask_for_loss + xi = spec[..., 1] * mask_for_loss + yr = targets[..., 0] * mask_for_loss + yi = targets[..., 1] * mask_for_loss + xmag = torch.sqrt(spec[..., 0]**2 + spec[..., 1]**2) * mask_for_loss + ymag = torch.sqrt(targets[..., 0]**2 + + targets[..., 1]**2) * mask_for_loss + + loss1 = torch.sum(torch.pow(xr - yr, 2) + torch.pow(xi - yi, 2)) + loss2 = torch.sum(torch.pow(xmag - ymag, 2)) + + loss = (loss1 + loss2) / torch.sum(nframes) + return loss + + def sa_loss_dlen(mixed, clean, masks, nframes): + yspec = stft(mixed).permute([0, 2, 1, 3]) / 32768 + xspec = stft(clean).permute([0, 2, 1, 3]) / 32768 + with torch.no_grad(): + mask_for_loss = torch.ones_like(xspec[..., 0]) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + emag = ((yspec[..., 0]**2 + yspec[..., 1]**2)**0.15) * (masks**0.3) + xmag = (xspec[..., 0]**2 + xspec[..., 1]**2)**0.15 + emag = emag * mask_for_loss + xmag = xmag * mask_for_loss + + loss = torch.sum(torch.pow(emag - xmag, 2)) / torch.sum(nframes) + return loss + + def psm_vad_loss_dlen(mixed, clean, masks, nframes, subtask=None): + mixed_spec = stft(mixed) + clean_spec = stft(clean) + targets = compute_mask(mixed_spec, clean_spec, mask_type) + # [B, T, F] + targets = targets.permute(0, 2, 1) + + loss = mask_loss(targets, masks, nframes) + + if subtask is not None: + vadtargets = energy_vad(clean_spec) + with torch.no_grad(): + mask_for_loss = torch.ones_like(targets[:, :, 0]) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:] = 0 + subtask = subtask[:, :, 0] * mask_for_loss + vadtargets = vadtargets * mask_for_loss + + loss_vad = F.binary_cross_entropy(subtask, vadtargets) + return loss + loss_vad + return loss + + def modulation_loss(mixed, clean, masks, nframes, subtask=None): + mixed_spec = stft(mixed, True) + clean_spec = stft(clean, True) + enhanced_mag = torch.abs(mixed_spec) + clean_mag = torch.abs(clean_spec) + with torch.no_grad(): + mask_for_loss = torch.ones_like(clean_mag) + for idx, num in enumerate(nframes): + mask_for_loss[idx, :, num:] = 0 + clean_mag = clean_mag * mask_for_loss + enhanced_mag = enhanced_mag * mask_for_loss * masks.permute([0, 2, 1]) + + # Covert to log-mel representation + # (B,T,#mel_channels) + clean_log_mel = torch.log( + torch.transpose(stft2mel(clean_mag**2), 2, 1) + 1e-8) + enhanced_log_mel = torch.log( + torch.transpose(stft2mel(enhanced_mag**2), 2, 1) + 1e-8) + + alpha = compute_mask(mixed_spec, clean_spec, mask_type) + alpha = alpha.permute(0, 2, 1) + loss = 0.05 * modulation_loss_module(enhanced_log_mel, clean_log_mel, + alpha) + loss2 = psm_vad_loss_dlen(mixed, clean, masks, nframes, subtask) + # print(loss.item(), loss2.item()) #approx 1:4 + loss = loss + loss2 + return loss + + def wav2vec_loss(mixed, clean, masks, nframes, subtask=None): + mixed /= 32768 + clean /= 32768 + mixed_spec = stft(mixed) + with torch.no_grad(): + mask_for_loss = torch.ones_like(masks) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + masks_est = masks * mask_for_loss + + estimate = mixed_spec * masks_est.permute([0, 2, 1]).unsqueeze(3) + est_clean = istft(estimate, clean.shape[1]) + loss = wav2vec_loss_module(est_clean, clean) + return loss + + def sisdr_loss_dlen(mixed, + clean, + masks, + nframes, + subtask=None, + zero_mean=True): + mixed_spec = stft(mixed) + with torch.no_grad(): + mask_for_loss = 
torch.ones_like(masks) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + masks_est = masks * mask_for_loss + + estimate = mixed_spec * masks_est.permute([0, 2, 1]).unsqueeze(3) + est_clean = istft(estimate, clean.shape[1]) + flen = min(clean.shape[1], est_clean.shape[1]) + clean = clean[:, :flen] + est_clean = est_clean[:, :flen] + + # follow asteroid/losses/sdr.py + if zero_mean: + clean = clean - torch.mean(clean, dim=1, keepdim=True) + est_clean = est_clean - torch.mean(est_clean, dim=1, keepdim=True) + + dot = torch.sum(est_clean * clean, dim=1, keepdim=True) + s_clean_energy = torch.sum(clean**2, dim=1, keepdim=True) + EPS + scaled_clean = dot * clean / s_clean_energy + e_noise = est_clean - scaled_clean + + # [batch] + sisdr = torch.sum( + scaled_clean**2, dim=1) / ( + torch.sum(e_noise**2, dim=1) + EPS) + sisdr = -10 * torch.log10(sisdr + EPS) + loss = sisdr.mean() + return loss + + def sisdr_freq_loss_dlen(mixed, clean, masks, nframes, subtask=None): + mixed_spec = stft(mixed) + clean_spec = stft(clean) + with torch.no_grad(): + mask_for_loss = torch.ones_like(masks) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + masks_est = masks * mask_for_loss + + estimate = mixed_spec * masks_est.permute([0, 2, 1]).unsqueeze(3) + + dot_real = estimate[..., 0] * clean_spec[..., 0] + \ + estimate[..., 1] * clean_spec[..., 1] + dot_imag = estimate[..., 0] * clean_spec[..., 1] - \ + estimate[..., 1] * clean_spec[..., 0] + dot = torch.cat([dot_real.unsqueeze(3), dot_imag.unsqueeze(3)], dim=-1) + s_clean_energy = clean_spec[..., 0] ** 2 + \ + clean_spec[..., 1] ** 2 + EPS + scaled_clean = dot * clean_spec / s_clean_energy.unsqueeze(3) + e_noise = estimate - scaled_clean + + # [batch] + scaled_clean_energy = torch.sum( + scaled_clean[..., 0]**2 + scaled_clean[..., 1]**2, dim=1) + e_noise_energy = torch.sum( + e_noise[..., 0]**2 + e_noise[..., 1]**2, dim=1) + sisdr = torch.sum( + scaled_clean_energy, dim=1) / ( + torch.sum(e_noise_energy, dim=1) + EPS) + sisdr = -10 * torch.log10(sisdr + EPS) + loss = sisdr.mean() + return loss + + def crm_loss_dlen(mixed, clean, masks, nframes, subtask=None): + mixed_spec = stft(mixed).permute([0, 2, 1, 3]) + clean_spec = stft(clean).permute([0, 2, 1, 3]) + mixed_spec = mixed_spec / 32768 + clean_spec = clean_spec / 32768 + tgt_mr, tgt_mi = compute_mask(mixed_spec, clean_spec, mask_type='crm') + + D = int(masks.shape[2] / 2) + with torch.no_grad(): + mask_for_loss = torch.ones_like(clean_spec[..., 0]) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + mr = masks[..., :D] * mask_for_loss + mi = masks[..., D:] * mask_for_loss + tgt_mr = tgt_mr * mask_for_loss + tgt_mi = tgt_mi * mask_for_loss + + if weight is None: + alpha = 1 + else: + alpha = weight - tgt_mr + # signal approximation + yr = mixed_spec[..., 0] + yi = mixed_spec[..., 1] + loss1 = torch.sum(alpha * torch.pow((mr * yr - mi * yi) - clean_spec[..., 0], 2)) \ + + torch.sum(alpha * torch.pow((mr * yi + mi * yr) - clean_spec[..., 1], 2)) + # mask approximation + loss2 = torch.sum(alpha * torch.pow(mr - tgt_mr, 2)) \ + + torch.sum(alpha * torch.pow(mi - tgt_mi, 2)) + loss = 0.5 * (loss1 + loss2) / torch.sum(nframes) + return loss + + def crm_miso_loss_dlen(mixed, clean, masks, nframes): + return crm_loss_dlen(mixed[..., 0], clean[..., 0], masks, nframes) + + def mimo_loss_dlen(mixed, clean, masks, nframes): + chs = mixed.shape[-1] + D = masks.shape[2] // chs + loss = psm_vad_loss_dlen(mixed[..., 0], clean[..., 0], masks[..., :D], + 
nframes) + for ch in range(1, chs): + loss1 = psm_vad_loss_dlen(mixed[..., ch], clean[..., ch], + masks[..., ch * D:ch * D + D], nframes) + loss = loss + loss1 + return loss / chs + + def spec_loss_dlen(mixed, clean, spec, nframes): + clean_spec = stft(clean).permute([0, 2, 1, 3]) + clean_spec = clean_spec / 32768 + + D = spec.shape[2] // 2 + spec_est = torch.cat([spec[..., :D, None], spec[..., D:, None]], + dim=-1) + loss = spectrum_loss(clean_spec, spec_est, nframes) + return loss + + if loss_func == 'psm_vad_loss_dlen': + return psm_vad_loss_dlen + elif loss_func == 'sisdr_loss_dlen': + return sisdr_loss_dlen + elif loss_func == 'sisdr_freq_loss_dlen': + return sisdr_freq_loss_dlen + elif loss_func == 'crm_loss_dlen': + return crm_loss_dlen + elif loss_func == 'modulation_loss': + return modulation_loss + elif loss_func == 'wav2vec_loss': + return wav2vec_loss + elif loss_func == 'mimo_loss_dlen': + return mimo_loss_dlen + elif loss_func == 'spec_loss_dlen': + return spec_loss_dlen + elif loss_func == 'sa_loss_dlen': + return sa_loss_dlen + else: + print('error loss func') + return None diff --git a/modelscope/models/audio/network/modulation_loss.py b/modelscope/models/audio/network/modulation_loss.py new file mode 100644 index 00000000..a45ddead --- /dev/null +++ b/modelscope/models/audio/network/modulation_loss.py @@ -0,0 +1,248 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchaudio.transforms import MelScale + + +class ModulationDomainLossModule(torch.nn.Module): + """Modulation-domain loss function developed in [1] for supervised speech enhancement + + In our paper, we used the gabor-based STRF kernels as the modulation kernels and used the log-mel spectrogram + as the input spectrogram representation. + Specific parameter details are in the paper and in the example below + + Parameters + ---------- + modulation_kernels: nn.Module + Differentiable module that transforms a spectrogram representation to the modulation domain + + modulation_domain = modulation_kernels(input_tf_representation) + Input Spectrogram representation (B, T, F) ---> |(M) modulation_kernels|--->Modulation Domain(B, M, T', F') + + norm: boolean + Normalizes the modulation domain representation to be 0 mean across time + + [1] T. Vuong, Y. Xia, and R. M. Stern, “A modulation-domain lossfor neural-network-based real-time + speech enhancement” + Accepted ICASSP 2021, https://arxiv.org/abs/2102.07330 + + + """ + + def __init__(self, modulation_kernels, norm=True): + super(ModulationDomainLossModule, self).__init__() + + self.modulation_kernels = modulation_kernels + self.mse = nn.MSELoss(reduce=False) + self.norm = norm + + def forward(self, enhanced_spect, clean_spect, weight=None): + """Calculate modulation-domain loss + Args: + enhanced_spect (Tensor): spectrogram representation of enhanced signal (B, #frames, #freq_channels). + clean_spect (Tensor): spectrogram representation of clean ground-truth signal (B, #frames, #freq_channels). + Returns: + Tensor: Modulation-domain loss value. 
+ """ + + clean_mod = self.modulation_kernels(clean_spect) + enhanced_mod = self.modulation_kernels(enhanced_spect) + + if self.norm: + mean_clean_mod = torch.mean(clean_mod, dim=2) + mean_enhanced_mod = torch.mean(enhanced_mod, dim=2) + + clean_mod = clean_mod - mean_clean_mod.unsqueeze(2) + enhanced_mod = enhanced_mod - mean_enhanced_mod.unsqueeze(2) + + if weight is None: + alpha = 1 + else: # TF-mask weight + alpha = 1 + torch.sum(weight, dim=-1, keepdim=True).unsqueeze(1) + mod_mse_loss = self.mse(enhanced_mod, clean_mod) * alpha + mod_mse_loss = torch.mean( + torch.sum(mod_mse_loss, dim=(1, 2, 3)) + / torch.sum(clean_mod**2, dim=(1, 2, 3))) + + return mod_mse_loss + + +class ModulationDomainNCCLossModule(torch.nn.Module): + """Modulation-domain loss function developed in [1] for supervised speech enhancement + + # Speech Intelligibility Prediction Using Spectro-Temporal Modulation Analysis - based off of this + + In our paper, we used the gabor-based STRF kernels as the modulation kernels and used the log-mel spectrogram + as the input spectrogram representation. + Specific parameter details are in the paper and in the example below + + Parameters + ---------- + modulation_kernels: nn.Module + Differentiable module that transforms a spectrogram representation to the modulation domain + + modulation_domain = modulation_kernels(input_tf_representation) + Input Spectrogram representation(B, T, F) --- (M) modulation_kernels---> Modulation Domain(B, M, T', F') + + [1] + + """ + + def __init__(self, modulation_kernels): + super(ModulationDomainNCCLossModule, self).__init__() + + self.modulation_kernels = modulation_kernels + self.mse = nn.MSELoss(reduce=False) + + def forward(self, enhanced_spect, clean_spect): + """Calculate modulation-domain loss + Args: + enhanced_spect (Tensor): spectrogram representation of enhanced signal (B, #frames, #freq_channels). + clean_spect (Tensor): spectrogram representation of clean ground-truth signal (B, #frames, #freq_channels). + Returns: + Tensor: Modulation-domain loss value. + """ + + clean_mod = self.modulation_kernels(clean_spect) + enhanced_mod = self.modulation_kernels(enhanced_spect) + mean_clean_mod = torch.mean(clean_mod, dim=2) + mean_enhanced_mod = torch.mean(enhanced_mod, dim=2) + + normalized_clean = clean_mod - mean_clean_mod.unsqueeze(2) + normalized_enhanced = enhanced_mod - mean_enhanced_mod.unsqueeze(2) + + inner_product = torch.sum( + normalized_clean * normalized_enhanced, dim=2) + normalized_denom = (torch.sum( + normalized_clean * normalized_clean, dim=2))**.5 * (torch.sum( + normalized_enhanced * normalized_enhanced, dim=2))**.5 + + ncc = inner_product / normalized_denom + mod_mse_loss = torch.mean((ncc - 1.0)**2) + + return mod_mse_loss + + +class GaborSTRFConv(nn.Module): + """Gabor-STRF-based cross-correlation kernel.""" + + def __init__(self, + supn, + supk, + nkern, + rates=None, + scales=None, + norm_strf=True, + real_only=False): + """Instantiate a Gabor-based STRF convolution layer. + Parameters + ---------- + supn: int + Time support in number of frames. Also the window length. + supk: int + Frequency support in number of channels. Also the window length. + nkern: int + Number of kernels, each with a learnable rate and scale. + rates: list of float, None + Initial values for temporal modulation. + scales: list of float, None + Initial values for spectral modulation. 
+ norm_strf: Boolean + Normalize STRF kernels to be unit length + real_only: Boolean + If True, nkern REAL gabor-STRF kernels + If False, nkern//2 REAL and nkern//2 IMAGINARY gabor-STRF kernels + """ + super(GaborSTRFConv, self).__init__() + self.numN = supn + self.numK = supk + self.numKern = nkern + self.real_only = real_only + self.norm_strf = norm_strf + + if not real_only: + nkern = nkern // 2 + + if supk % 2 == 0: # force odd number + supk += 1 + self.supk = torch.arange(supk, dtype=torch.float32) + if supn % 2 == 0: # force odd number + supn += 1 + self.supn = torch.arange(supn, dtype=self.supk.dtype) + self.padding = (supn // 2, supk // 2) + # Set up learnable parameters + # for param in (rates, scales): + # assert (not param) or len(param) == nkern + if not rates: + + rates = torch.rand(nkern) * math.pi / 2.0 + + if not scales: + + scales = (torch.rand(nkern) * 2.0 - 1.0) * math.pi / 2.0 + + self.rates_ = nn.Parameter(torch.Tensor(rates)) + self.scales_ = nn.Parameter(torch.Tensor(scales)) + + def strfs(self): + """Make STRFs using the current parameters.""" + + if self.supn.device != self.rates_.device: # for first run + self.supn = self.supn.to(self.rates_.device) + self.supk = self.supk.to(self.rates_.device) + n0, k0 = self.padding + + nwind = .5 - .5 * \ + torch.cos(2 * math.pi * (self.supn + 1) / (len(self.supn) + 1)) + kwind = .5 - .5 * \ + torch.cos(2 * math.pi * (self.supk + 1) / (len(self.supk) + 1)) + + new_wind = torch.matmul((nwind).unsqueeze(-1), (kwind).unsqueeze(0)) + + n_n_0 = self.supn - n0 + k_k_0 = self.supk - k0 + n_mult = torch.matmul( + n_n_0.unsqueeze(1), + torch.ones((1, len(self.supk))).type(torch.FloatTensor).to( + self.rates_.device)) + k_mult = torch.matmul( + torch.ones((len(self.supn), + 1)).type(torch.FloatTensor).to(self.rates_.device), + k_k_0.unsqueeze(0)) + + inside = self.rates_.unsqueeze(1).unsqueeze( + 1) * n_mult + self.scales_.unsqueeze(1).unsqueeze(1) * k_mult + real_strf = torch.cos(inside) * new_wind.unsqueeze(0) + + if self.real_only: + final_strf = real_strf + + else: + imag_strf = torch.sin(inside) * new_wind.unsqueeze(0) + final_strf = torch.cat([real_strf, imag_strf], dim=0) + + if self.norm_strf: + final_strf = final_strf / (torch.sum( + final_strf**2, dim=(1, 2)).unsqueeze(1).unsqueeze(2))**.5 + + return final_strf + + def forward(self, sigspec): + """Forward pass a batch of (real) spectra [Batch x Time x Frequency].""" + if len(sigspec.shape) == 2: # expand batch dimension if single eg + sigspec = sigspec.unsqueeze(0) + strfs = self.strfs().unsqueeze(1).type_as(sigspec) + out = F.conv2d(sigspec.unsqueeze(1), strfs, padding=self.padding) + return out + + def __repr__(self): + """Gabor filter""" + report = """ + +++++ Gabor Filter Kernels [{}], supn[{}], supk[{}] real only [{}] norm strf [{}] +++++ + + """.format(self.numKern, self.numN, self.numK, self.real_only, + self.norm_strf) + + return report diff --git a/modelscope/models/audio/network/se_net.py b/modelscope/models/audio/network/se_net.py new file mode 100644 index 00000000..54808043 --- /dev/null +++ b/modelscope/models/audio/network/se_net.py @@ -0,0 +1,483 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..layers.activations import RectifiedLinear, Sigmoid +from ..layers.affine_transform import AffineTransform +from ..layers.deep_fsmn import DeepFsmn +from ..layers.uni_deep_fsmn import Conv2d, UniDeepFsmn + + +class MaskNet(nn.Module): + + def __init__(self, + indim, + outdim, + layers=9, + hidden_dim=128, + hidden_dim2=None, + lorder=20, 
+ rorder=0, + dilation=1, + layer_norm=False, + dropout=0, + crm=False, + vad=False, + linearout=False): + super(MaskNet, self).__init__() + + self.linear1 = AffineTransform(indim, hidden_dim) + self.relu = RectifiedLinear(hidden_dim, hidden_dim) + if hidden_dim2 is None: + hidden_dim2 = hidden_dim + + if rorder == 0: + repeats = [ + UniDeepFsmn( + hidden_dim, + hidden_dim, + lorder, + hidden_dim2, + dilation=dilation, + layer_norm=layer_norm, + dropout=dropout) for i in range(layers) + ] + else: + repeats = [ + DeepFsmn( + hidden_dim, + hidden_dim, + lorder, + rorder, + hidden_dim2, + layer_norm=layer_norm, + dropout=dropout) for i in range(layers) + ] + self.deepfsmn = nn.Sequential(*repeats) + + self.linear2 = AffineTransform(hidden_dim, outdim) + + self.crm = crm + if self.crm: + self.sig = nn.Tanh() + else: + self.sig = Sigmoid(outdim, outdim) + + self.vad = vad + if self.vad: + self.linear3 = AffineTransform(hidden_dim, 1) + + self.layers = layers + self.linearout = linearout + if self.linearout and self.vad: + print('Warning: not supported nnet') + + def forward(self, feat, ctl=None): + x1 = self.linear1(feat) + x2 = self.relu(x1) + if ctl is not None: + ctl = min(ctl, self.layers - 1) + for i in range(ctl): + x2 = self.deepfsmn[i](x2) + mask = self.sig(self.linear2(x2)) + if self.vad: + vad = torch.sigmoid(self.linear3(x2)) + return mask, vad + else: + return mask + x3 = self.deepfsmn(x2) + if self.linearout: + return self.linear2(x3) + mask = self.sig(self.linear2(x3)) + if self.vad: + vad = torch.sigmoid(self.linear3(x3)) + return mask, vad + else: + return mask + + def to_kaldi_nnet(self): + re_str = '' + re_str += '\n' + re_str += self.linear1.to_kaldi_nnet() + re_str += self.relu.to_kaldi_nnet() + for dfsmn in self.deepfsmn: + re_str += dfsmn.to_kaldi_nnet() + re_str += self.linear2.to_kaldi_nnet() + re_str += self.sig.to_kaldi_nnet() + re_str += '\n' + + return re_str + + def to_raw_nnet(self, fid): + self.linear1.to_raw_nnet(fid) + for dfsmn in self.deepfsmn: + dfsmn.to_raw_nnet(fid) + self.linear2.to_raw_nnet(fid) + + +class StageNet(nn.Module): + + def __init__(self, + indim, + outdim, + layers=9, + layers2=6, + hidden_dim=128, + lorder=20, + rorder=0, + layer_norm=False, + dropout=0, + crm=False, + vad=False, + linearout=False): + super(StageNet, self).__init__() + + self.stage1 = nn.ModuleList() + self.stage2 = nn.ModuleList() + layer = nn.Sequential(nn.Linear(indim, hidden_dim), nn.ReLU()) + self.stage1.append(layer) + for i in range(layers): + layer = UniDeepFsmn( + hidden_dim, + hidden_dim, + lorder, + hidden_dim, + layer_norm=layer_norm, + dropout=dropout) + self.stage1.append(layer) + layer = nn.Sequential(nn.Linear(hidden_dim, 321), nn.Sigmoid()) + self.stage1.append(layer) + # stage2 + layer = nn.Sequential(nn.Linear(321 + indim, hidden_dim), nn.ReLU()) + self.stage2.append(layer) + for i in range(layers2): + layer = UniDeepFsmn( + hidden_dim, + hidden_dim, + lorder, + hidden_dim, + layer_norm=layer_norm, + dropout=dropout) + self.stage2.append(layer) + layer = nn.Sequential( + nn.Linear(hidden_dim, outdim), + nn.Sigmoid() if not crm else nn.Tanh()) + self.stage2.append(layer) + self.crm = crm + self.vad = vad + self.linearout = linearout + self.window = torch.hamming_window(640, periodic=False).cuda() + self.freezed = False + + def freeze(self): + if not self.freezed: + for param in self.stage1.parameters(): + param.requires_grad = False + self.freezed = True + print('freezed stage1') + + def forward(self, feat, mixture, ctl=None): + if ctl == 'off': + x = feat + 
for i in range(len(self.stage1)): + x = self.stage1[i](x) + return x + else: + self.freeze() + x = feat + for i in range(len(self.stage1)): + x = self.stage1[i](x) + + spec = torch.stft( + mixture / 32768, + 640, + 320, + 640, + self.window, + center=False, + return_complex=True) + spec = torch.view_as_real(spec).permute([0, 2, 1, 3]) + specmag = torch.sqrt(spec[..., 0]**2 + spec[..., 1]**2) + est = x * specmag + y = torch.cat([est, feat], dim=-1) + for i in range(len(self.stage2)): + y = self.stage2[i](y) + return y + + +class Unet(nn.Module): + + def __init__(self, + indim, + outdim, + layers=9, + dims=[256] * 4, + lorder=20, + rorder=0, + dilation=1, + layer_norm=False, + dropout=0, + crm=False, + vad=False, + linearout=False): + super(Unet, self).__init__() + + self.linear1 = AffineTransform(indim, dims[0]) + self.relu = RectifiedLinear(dims[0], dims[0]) + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + for i in range(len(dims) - 1): + layer = nn.Sequential( + nn.Linear(dims[i], dims[i + 1]), nn.ReLU(), + nn.Linear(dims[i + 1], dims[i + 1], bias=False), + Conv2d( + dims[i + 1], + dims[i + 1], + lorder, + groups=dims[i + 1], + skip_connect=True)) + self.encoder.append(layer) + for i in range(len(dims) - 1, 0, -1): + layer = nn.Sequential( + nn.Linear(dims[i] * 2, dims[i - 1]), nn.ReLU(), + nn.Linear(dims[i - 1], dims[i - 1], bias=False), + Conv2d( + dims[i - 1], + dims[i - 1], + lorder, + groups=dims[i - 1], + skip_connect=True)) + self.decoder.append(layer) + self.tf = nn.ModuleList() + for i in range(layers - 2 * (len(dims) - 1)): + layer = nn.Sequential( + nn.Linear(dims[-1], dims[-1]), nn.ReLU(), + nn.Linear(dims[-1], dims[-1], bias=False), + Conv2d( + dims[-1], + dims[-1], + lorder, + groups=dims[-1], + skip_connect=True)) + self.tf.append(layer) + + self.linear2 = AffineTransform(dims[0], outdim) + self.crm = crm + self.act = nn.Tanh() if self.crm else nn.Sigmoid() + self.vad = False + self.layers = layers + self.linearout = linearout + + def forward(self, x, ctl=None): + x = self.linear1(x) + x = self.relu(x) + + encoder_out = [] + for i in range(len(self.encoder)): + x = self.encoder[i](x) + encoder_out.append(x) + for i in range(len(self.tf)): + x = self.tf[i](x) + for i in range(len(self.decoder)): + x = torch.cat([x, encoder_out[-1 - i]], dim=-1) + x = self.decoder[i](x) + + x = self.linear2(x) + if self.linearout: + return x + return self.act(x) + + +class BranchNet(nn.Module): + + def __init__(self, + indim, + outdim, + layers=9, + hidden_dim=256, + lorder=20, + rorder=0, + dilation=1, + layer_norm=False, + dropout=0, + crm=False, + vad=False, + linearout=False): + super(BranchNet, self).__init__() + + self.linear1 = AffineTransform(indim, hidden_dim) + self.relu = RectifiedLinear(hidden_dim, hidden_dim) + + self.convs = nn.ModuleList() + self.deepfsmn = nn.ModuleList() + self.FREQ = nn.ModuleList() + self.TIME = nn.ModuleList() + self.br1 = nn.ModuleList() + self.br2 = nn.ModuleList() + for i in range(layers): + ''' + layer = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim, bias=False), + Conv2d(hidden_dim, hidden_dim, lorder, + groups=hidden_dim, skip_connect=True) + ) + self.deepfsmn.append(layer) + ''' + layer = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) + self.FREQ.append(layer) + ''' + layer = nn.GRU(hidden_dim, hidden_dim, + batch_first=True, + bidirectional=False) + self.TIME.append(layer) + + layer = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim//2, bias=False), + 
Conv2d(hidden_dim//2, hidden_dim//2, lorder, + groups=hidden_dim//2, skip_connect=True) + ) + self.br1.append(layer) + layer = nn.GRU(hidden_dim, hidden_dim//2, + batch_first=True, + bidirectional=False) + self.br2.append(layer) + ''' + + self.linear2 = AffineTransform(hidden_dim, outdim) + self.crm = crm + self.act = nn.Tanh() if self.crm else nn.Sigmoid() + self.vad = False + self.layers = layers + self.linearout = linearout + + def forward(self, x, ctl=None): + return self.forward_branch(x) + + def forward_sepconv(self, x): + x = torch.unsqueeze(x, 1) + for i in range(len(self.convs)): + x = self.convs[i](x) + x = F.relu(x) + B, C, H, W = x.shape + x = x.permute(0, 2, 1, 3) + x = torch.reshape(x, [B, H, C * W]) + x = self.linear1(x) + x = self.relu(x) + for i in range(self.layers): + x = self.deepfsmn[i](x) + x + x = self.linear2(x) + return self.act(x) + + def forward_branch(self, x): + x = self.linear1(x) + x = self.relu(x) + for i in range(self.layers): + z = self.FREQ[i](x) + x = z + x + x = self.linear2(x) + if self.linearout: + return x + return self.act(x) + + +class TACNet(nn.Module): + ''' transform average concatenate for ad hoc dr + ''' + + def __init__(self, + indim, + outdim, + layers=9, + hidden_dim=128, + lorder=20, + rorder=0, + crm=False, + vad=False, + linearout=False): + super(TACNet, self).__init__() + + self.linear1 = AffineTransform(indim, hidden_dim) + self.relu = RectifiedLinear(hidden_dim, hidden_dim) + + if rorder == 0: + repeats = [ + UniDeepFsmn(hidden_dim, hidden_dim, lorder, hidden_dim) + for i in range(layers) + ] + else: + repeats = [ + DeepFsmn(hidden_dim, hidden_dim, lorder, rorder, hidden_dim) + for i in range(layers) + ] + self.deepfsmn = nn.Sequential(*repeats) + + self.ch_transform = nn.ModuleList([]) + self.ch_average = nn.ModuleList([]) + self.ch_concat = nn.ModuleList([]) + for i in range(layers): + self.ch_transform.append( + nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.PReLU())) + self.ch_average.append( + nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.PReLU())) + self.ch_concat.append( + nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), nn.PReLU())) + + self.linear2 = AffineTransform(hidden_dim, outdim) + + self.crm = crm + if self.crm: + self.sig = nn.Tanh() + else: + self.sig = Sigmoid(outdim, outdim) + + self.vad = vad + if self.vad: + self.linear3 = AffineTransform(hidden_dim, 1) + + self.layers = layers + self.linearout = linearout + if self.linearout and self.vad: + print('Warning: not supported nnet') + + def forward(self, feat, ctl=None): + B, T, F = feat.shape + # assume 4ch + ch = 4 + zlist = [] + for c in range(ch): + z = self.linear1(feat[..., c * (F // 4):(c + 1) * (F // 4)]) + z = self.relu(z) + zlist.append(z) + for i in range(self.layers): + # forward + for c in range(ch): + zlist[c] = self.deepfsmn[i](zlist[c]) + + # transform + olist = [] + for c in range(ch): + z = self.ch_transform[i](zlist[c]) + olist.append(z) + # average + avg = 0 + for c in range(ch): + avg = avg + olist[c] + avg = avg / ch + avg = self.ch_average[i](avg) + # concate + for c in range(ch): + tac = torch.cat([olist[c], avg], dim=-1) + tac = self.ch_concat[i](tac) + zlist[c] = zlist[c] + tac + + for c in range(ch): + zlist[c] = self.sig(self.linear2(zlist[c])) + mask = torch.cat(zlist, dim=-1) + return mask + + def to_kaldi_nnet(self): + pass diff --git a/modelscope/pipelines/__init__.py b/modelscope/pipelines/__init__.py index d47ce8cf..14865872 100644 --- a/modelscope/pipelines/__init__.py +++ b/modelscope/pipelines/__init__.py @@ -1,4 
+1,4 @@ -from .audio import * # noqa F403 +from .audio import LinearAECPipeline from .base import Pipeline from .builder import pipeline from .cv import * # noqa F403 diff --git a/modelscope/pipelines/audio/__init__.py b/modelscope/pipelines/audio/__init__.py index e69de29b..eaa31c7c 100644 --- a/modelscope/pipelines/audio/__init__.py +++ b/modelscope/pipelines/audio/__init__.py @@ -0,0 +1 @@ +from .linear_aec_pipeline import LinearAECPipeline diff --git a/modelscope/pipelines/audio/linear_aec_pipeline.py b/modelscope/pipelines/audio/linear_aec_pipeline.py new file mode 100644 index 00000000..528d8d47 --- /dev/null +++ b/modelscope/pipelines/audio/linear_aec_pipeline.py @@ -0,0 +1,160 @@ +import importlib +import os +from typing import Any, Dict + +import numpy as np +import scipy.io.wavfile as wav +import torch +import yaml + +from modelscope.preprocessors.audio import LinearAECAndFbank +from modelscope.utils.constant import ModelFile, Tasks +from ..base import Pipeline +from ..builder import PIPELINES + +FEATURE_MVN = 'feature.DEY.mvn.txt' + +CONFIG_YAML = 'dey_mini.yaml' + + +def initialize_config(module_cfg): + r"""According to the config items, dynamically load the specified module with its params. + 1. Load the module named by the "module" param. + 2. Call the function (or instantiate the class) named by the "main" param. + 3. Pass the params in "args" to the function (or class) when calling (or instantiating). + + Args: + module_cfg (dict): config items, e.g.: + { + "module": "models.model", + "main": "Model", + "args": {...} + } + + Returns: + the object constructed from the config. + """ + module = importlib.import_module(module_cfg['module']) + return getattr(module, module_cfg['main'])(**module_cfg['args']) + + +@PIPELINES.register_module( + Tasks.speech_signal_process, module_name=r'speech_dfsmn_aec_psm_16k') +class LinearAECPipeline(Pipeline): + r"""AEC inference pipeline. Only a 16000 Hz sample rate is supported. + + When invoking the pipeline via pipeline.__call__(), provide two params: + Dict[str, Any] + the paths of the wav files, e.g. { + "nearend_mic": "/your/data/near_end_mic_audio.wav", + "farend_speech": "/your/data/far_end_speech_audio.wav"} + output_path (str, optional): "/your/output/audio_after_aec.wav" + the file path to write the generated audio to. + """ + + def __init__(self, model): + r""" + Args: + model: model id on modelscope hub.
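+
+        Example (mirrors tests/pipelines/test_speech_signal_process.py;
+        the wav paths below are placeholders):
+            >>> from modelscope.pipelines import pipeline
+            >>> from modelscope.utils.constant import Tasks
+            >>> aec = pipeline(
+            ...     Tasks.speech_signal_process,
+            ...     model='damo/speech_dfsmn_aec_psm_16k',
+            ...     pipeline_name='speech_dfsmn_aec_psm_16k')
+            >>> result = aec(
+            ...     {'nearend_mic': 'nearend_mic.wav',
+            ...      'farend_speech': 'farend_speech.wav'},
+            ...     output_path='output.wav')
+            >>> result['output_pcm']  # float32 samples normalized to [-1.0, 1.0]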
+ """ + super().__init__(model=model) + self.use_cuda = torch.cuda.is_available() + with open( + os.path.join(self.model, CONFIG_YAML), encoding='utf-8') as f: + self.config = yaml.full_load(f.read()) + self.config['io']['mvn'] = os.path.join(self.model, FEATURE_MVN) + self._init_model() + self.preprocessor = LinearAECAndFbank(self.config['io']) + + n_fft = self.config['loss']['args']['n_fft'] + hop_length = self.config['loss']['args']['hop_length'] + winlen = n_fft + window = torch.hamming_window(winlen, periodic=False) + + def stft(x): + return torch.stft( + x, + n_fft, + hop_length, + winlen, + center=False, + window=window.to(x.device), + return_complex=False) + + def istft(x, slen): + return torch.istft( + x, + n_fft, + hop_length, + winlen, + window=window.to(x.device), + center=False, + length=slen) + + self.stft = stft + self.istft = istft + + def _init_model(self): + checkpoint = torch.load( + os.path.join(self.model, ModelFile.TORCH_MODEL_BIN_FILE), + map_location='cpu') + self.model = initialize_config(self.config['nnet']) + if self.use_cuda: + self.model = self.model.cuda() + self.model.load_state_dict(checkpoint) + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + r"""The AEC process. + + Args: + inputs: dict={'feature': Tensor, 'base': Tensor} + 'feature' feature of input audio. + 'base' the base audio to mask. + + Returns: + dict: + { + 'output_pcm': generated audio array + } + """ + output_data = self._process(inputs['feature'], inputs['base']) + return {'output_pcm': output_data} + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + r"""The post process. Will save audio to file, if the output_path is given. + + Args: + inputs: dict: + { + 'output_pcm': generated audio array + } + kwargs: accept 'output_path' which is the path to write generated audio + + Returns: + dict: + { + 'output_pcm': generated audio array + } + """ + if 'output_path' in kwargs.keys(): + wav.write(kwargs['output_path'], self.preprocessor.SAMPLE_RATE, + inputs['output_pcm'].astype(np.int16)) + inputs['output_pcm'] = inputs['output_pcm'] / 32768.0 + return inputs + + def _process(self, fbanks, mixture): + if self.use_cuda: + fbanks = fbanks.cuda() + mixture = mixture.cuda() + if self.model.vad: + with torch.no_grad(): + masks, vad = self.model(fbanks.unsqueeze(0)) + masks = masks.permute([2, 1, 0]) + else: + with torch.no_grad(): + masks = self.model(fbanks.unsqueeze(0)) + masks = masks.permute([2, 1, 0]) + spectrum = self.stft(mixture) + masked_spec = spectrum * masks + masked_sig = self.istft(masked_spec, len(mixture)).cpu().numpy() + return masked_sig diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index c88e358c..15d8a995 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -84,6 +84,12 @@ TASK_OUTPUTS = { # ============ audio tasks =================== + # audio processed for single file in PCM format + # { + # "output_pcm": np.array with shape(samples,) and dtype float32 + # } + Tasks.speech_signal_process: ['output_pcm'], + # ============ multi-modal tasks =================== # image caption result for single sample diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 81ca1007..5db5b407 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+from .audio import LinearAECAndFbank from .base import Preprocessor from .builder import PREPROCESSORS, build_preprocessor from .common import Compose diff --git a/modelscope/preprocessors/audio.py b/modelscope/preprocessors/audio.py new file mode 100644 index 00000000..a2c15714 --- /dev/null +++ b/modelscope/preprocessors/audio.py @@ -0,0 +1,230 @@ +import ctypes +import os +from typing import Any, Dict + +import numpy as np +import scipy.io.wavfile as wav +import torch +import torchaudio.compliance.kaldi as kaldi +from numpy.ctypeslib import ndpointer + +from modelscope.utils.constant import Fields +from .builder import PREPROCESSORS + + +def load_wav(path): + samp_rate, data = wav.read(path) + return np.float32(data), samp_rate + + +def load_library(libaec): + libaec_in_cwd = os.path.join('.', libaec) + if os.path.exists(libaec_in_cwd): + libaec = libaec_in_cwd + mitaec = ctypes.cdll.LoadLibrary(libaec) + fe_process = mitaec.fe_process_inst + fe_process.argtypes = [ + ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'), + ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'), ctypes.c_int, + ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'), + ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'), + ndpointer(ctypes.c_float, flags='C_CONTIGUOUS') + ] + return fe_process + + +def do_linear_aec(fe_process, mic, ref, int16range=True): + mic = np.float32(mic) + ref = np.float32(ref) + if len(mic) > len(ref): + mic = mic[:len(ref)] + out_mic = np.zeros_like(mic) + out_linear = np.zeros_like(mic) + out_echo = np.zeros_like(mic) + out_ref = np.zeros_like(mic) + if int16range: + mic /= 32768 + ref /= 32768 + fe_process(mic, ref, len(mic), out_mic, out_linear, out_echo) + # out_ref not in use here + if int16range: + out_mic *= 32768 + out_linear *= 32768 + out_echo *= 32768 + return out_mic, out_ref, out_linear, out_echo + + +def load_kaldi_feature_transform(filename): + fp = open(filename, 'r') + all_str = fp.read() + pos1 = all_str.find('AddShift') + pos2 = all_str.find('[', pos1) + pos3 = all_str.find(']', pos2) + mean = np.fromstring(all_str[pos2 + 1:pos3], dtype=np.float32, sep=' ') + pos1 = all_str.find('Rescale') + pos2 = all_str.find('[', pos1) + pos3 = all_str.find(']', pos2) + scale = np.fromstring(all_str[pos2 + 1:pos3], dtype=np.float32, sep=' ') + fp.close() + return mean, scale + + +class Feature: + r"""Extract feat from one utterance. 
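+
+    A minimal standalone sketch (the fbank settings below are illustrative
+    assumptions; the pipeline reads the real values from dey_mini.yaml):
+        >>> feature = Feature({'frame_length': 25, 'frame_shift': 10,
+        ...                    'sample_frequency': 16000, 'num_mel_bins': 80},
+        ...                   feat_type='fbank')
+        >>> pcm = torch.from_numpy(np.float32(samples))  # 'samples' is a placeholder int16-range waveform
+        >>> fbank = feature.compute(pcm)  # shape [..., T, F]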
+ """ + + def __init__(self, + fbank_config, + feat_type='spec', + mvn_file=None, + cuda=False): + r""" + + Args: + fbank_config (dict): + feat_type (str): + raw: do nothing + fbank: use kaldi.fbank + spec: Real/Imag + logpow: log(1+|x|^2) + mvn_file (str): the path of data file for mean variance normalization + cuda: + """ + self.fbank_config = fbank_config + self.feat_type = feat_type + self.n_fft = fbank_config['frame_length'] * fbank_config[ + 'sample_frequency'] // 1000 + self.hop_length = fbank_config['frame_shift'] * fbank_config[ + 'sample_frequency'] // 1000 + self.window = torch.hamming_window(self.n_fft, periodic=False) + + self.mvn = False + if mvn_file is not None and os.path.exists(mvn_file): + print(f'loading mvn file: {mvn_file}') + shift, scale = load_kaldi_feature_transform(mvn_file) + self.shift = torch.from_numpy(shift) + self.scale = torch.from_numpy(scale) + self.mvn = True + if cuda: + self.window = self.window.cuda() + if self.mvn: + self.shift = self.shift.cuda() + self.scale = self.scale.cuda() + + def compute(self, utt): + r""" + + Args: + utt: in [-32768, 32767] range + + Returns: + [..., T, F] + """ + if self.feat_type == 'raw': + return utt + elif self.feat_type == 'fbank': + if len(utt.shape) == 1: + utt = utt.unsqueeze(0) + feat = kaldi.fbank(utt, **self.fbank_config) + elif self.feat_type == 'spec': + spec = torch.stft( + utt / 32768, + self.n_fft, + self.hop_length, + self.n_fft, + self.window, + center=False, + return_complex=True) + feat = torch.cat([spec.real, spec.imag], dim=-2).permute(-1, -2) + elif self.feat_type == 'logpow': + spec = torch.stft( + utt, + self.n_fft, + self.hop_length, + self.n_fft, + self.window, + center=False, + return_complex=True) + abspow = torch.abs(spec)**2 + feat = torch.log(1 + abspow).permute(-1, -2) + return feat + + def normalize(self, feat): + if self.mvn: + feat = feat + self.shift + feat = feat * self.scale + return feat + + +@PREPROCESSORS.register_module(Fields.audio) +class LinearAECAndFbank: + SAMPLE_RATE = 16000 + + def __init__(self, io_config): + self.trunc_length = 7200 * self.SAMPLE_RATE + self.linear_aec_delay = io_config['linear_aec_delay'] + self.feature = Feature(io_config['fbank_config'], + io_config['feat_type'], io_config['mvn']) + self.mitaec = load_library(io_config['mitaec_library']) + self.mask_on_mic = io_config['mask_on'] == 'nearend_mic' + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ linear filtering the near end mic and far end audio, then extract the feature + :param data: dict with two keys and correspond audios: "nearend_mic" and "farend_speech" + :return: dict with two keys and Tensor values: "base" linear filtered audio,and "feature" + """ + # read files + nearend_mic, fs = load_wav(data['nearend_mic']) + assert fs == self.SAMPLE_RATE, f'The sample rate should be {self.SAMPLE_RATE}' + farend_speech, fs = load_wav(data['farend_speech']) + assert fs == self.SAMPLE_RATE, f'The sample rate should be {self.SAMPLE_RATE}' + if 'nearend_speech' in data: + nearend_speech, fs = load_wav(data['nearend_speech']) + assert fs == self.SAMPLE_RATE, f'The sample rate should be {self.SAMPLE_RATE}' + else: + nearend_speech = np.zeros_like(nearend_mic) + + out_mic, out_ref, out_linear, out_echo = do_linear_aec( + self.mitaec, nearend_mic, farend_speech) + # fix 20ms linear aec delay by delaying the target speech + extra_zeros = np.zeros([int(self.linear_aec_delay * fs)]) + nearend_speech = np.concatenate([extra_zeros, nearend_speech]) + # truncate files to the same length + flen = min( 
+ len(out_mic), len(out_ref), len(out_linear), len(out_echo), + len(nearend_speech)) + fstart = 0 + flen = min(flen, self.trunc_length) + nearend_mic, out_ref, out_linear, out_echo, nearend_speech = ( + out_mic[fstart:flen], out_ref[fstart:flen], + out_linear[fstart:flen], out_echo[fstart:flen], + nearend_speech[fstart:flen]) + + # extract features (frames, [mic, linear, ref, aes?]) + feat = torch.FloatTensor() + + nearend_mic = torch.from_numpy(np.float32(nearend_mic)) + fbank_nearend_mic = self.feature.compute(nearend_mic) + feat = torch.cat([feat, fbank_nearend_mic], dim=1) + + out_linear = torch.from_numpy(np.float32(out_linear)) + fbank_out_linear = self.feature.compute(out_linear) + feat = torch.cat([feat, fbank_out_linear], dim=1) + + out_echo = torch.from_numpy(np.float32(out_echo)) + fbank_out_echo = self.feature.compute(out_echo) + feat = torch.cat([feat, fbank_out_echo], dim=1) + + # feature transform + feat = self.feature.normalize(feat) + + # prepare target + if nearend_speech is not None: + nearend_speech = torch.from_numpy(np.float32(nearend_speech)) + + if self.mask_on_mic: + base = nearend_mic + else: + base = out_linear + out_data = {'base': base, 'target': nearend_speech, 'feature': feat} + return out_data diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 43684a06..dd5616a2 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -7,6 +7,7 @@ opencv-python-headless Pillow>=6.2.0 pyyaml requests +scipy tokenizers<=0.10.3 transformers<=4.16.2 yapf diff --git a/setup.cfg b/setup.cfg index 0b929b04..16c10cae 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,6 +11,7 @@ default_section = THIRDPARTY BASED_ON_STYLE = pep8 BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true +SPLIT_BEFORE_ARITHMETIC_OPERATOR = true [codespell] skip = *.ipynb @@ -20,5 +21,5 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids [flake8] select = B,C,E,F,P,T4,W,B9 max-line-length = 120 -ignore = F401,F821 +ignore = F401,F821,W503 exclude = docs/src,*.pyi,.git diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py new file mode 100644 index 00000000..8b5c9468 --- /dev/null +++ b/tests/pipelines/test_speech_signal_process.py @@ -0,0 +1,56 @@ +import os.path +import shutil +import unittest + +from modelscope.fileio import File +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import get_model_cache_dir + +NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav' +FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav' +NEAREND_MIC_FILE = 'nearend_mic.wav' +FAREND_SPEECH_FILE = 'farend_speech.wav' + +AEC_LIB_URL = 'http://isv-data.oss-cn-hangzhou.aliyuncs.com/ics%2FMaaS%2FAEC%2Flib%2Flibmitaec_pyio.so' \ + '?Expires=1664085465&OSSAccessKeyId=LTAIxjQyZNde90zh&Signature=Y7gelmGEsQAJRK4yyHSYMrdWizk%3D' +AEC_LIB_FILE = 'libmitaec_pyio.so' + + +def download(remote_path, local_path): + local_dir = os.path.dirname(local_path) + if len(local_dir) > 0: + if not os.path.exists(local_dir): + os.makedirs(local_dir) + with open(local_path, 'wb') as ofile: + ofile.write(File.read(remote_path)) + + +class SpeechSignalProcessTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/speech_dfsmn_aec_psm_16k' + # switch to False if downloading everytime is not desired + 
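+        # purging the model cache forces the test to exercise the pipeline's model auto-download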
purge_cache = True + if purge_cache: + shutil.rmtree( + get_model_cache_dir(self.model_id), ignore_errors=True) + # A temporary hack to provide c++ lib. Download it first. + download(AEC_LIB_URL, AEC_LIB_FILE) + + def test_run(self): + download(NEAREND_MIC_URL, NEAREND_MIC_FILE) + download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE) + input = { + 'nearend_mic': NEAREND_MIC_FILE, + 'farend_speech': FAREND_SPEECH_FILE + } + aec = pipeline( + Tasks.speech_signal_process, + model=self.model_id, + pipeline_name=r'speech_dfsmn_aec_psm_16k') + aec(input, output_path='output.wav') + + +if __name__ == '__main__': + unittest.main()
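For reference, a minimal sketch of driving the new preprocessor on its own (all file paths here are placeholders: in the pipeline, dey_mini.yaml and feature.DEY.mvn.txt come from the downloaded model directory, and libmitaec_pyio.so must be reachable via LD_LIBRARY_PATH or the current working directory):

    import yaml

    from modelscope.preprocessors import LinearAECAndFbank

    # the 'io' section carries fbank_config, feat_type, mvn, mitaec_library,
    # mask_on and linear_aec_delay, which is what LinearAECAndFbank expects
    with open('dey_mini.yaml', encoding='utf-8') as f:
        io_config = yaml.full_load(f.read())['io']
    io_config['mvn'] = 'feature.DEY.mvn.txt'

    preprocessor = LinearAECAndFbank(io_config)
    out = preprocessor({
        'nearend_mic': 'nearend_mic.wav',
        'farend_speech': 'farend_speech.wav'
    })
    # 'feature': concatenated frame-level features of the mic, linear-AEC output
    # and estimated echo; 'base': the 16 kHz waveform the predicted mask is applied to
    print(out['feature'].shape, out['base'].shape)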