@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 from typing import Dict
 
@@ -14,54 +15,10 @@ from .conv_stft import ConviSTFT, ConvSTFT
 from .unet import UNet
 
 
-class FTB(nn.Module):
-
-    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
-        super(FTB, self).__init__()
-        self.in_channel = in_channel
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
-            nn.BatchNorm2d(r_channel), nn.ReLU())
-        self.conv1d = nn.Sequential(
-            nn.Conv1d(
-                r_channel * input_dim, in_channel, kernel_size=9, padding=4),
-            nn.BatchNorm1d(in_channel), nn.ReLU())
-        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)
-
-        self.conv2 = nn.Sequential(
-            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
-            nn.BatchNorm2d(in_channel), nn.ReLU())
-
-    def forward(self, inputs):
-        '''
-        inputs should be [Batch, Ca, Dim, Time]
-        '''
-        # T-F attention
-        conv1_out = self.conv1(inputs)
-        B, C, D, T = conv1_out.size()
-        reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
-        conv1d_out = self.conv1d(reshape1_out)
-        conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])
-
-        # now it is also [B,C,D,T]
-        att_out = conv1d_out * inputs
-
-        # transpose to [B,C,T,D]
-        att_out = torch.transpose(att_out, 2, 3)
-        freqfc_out = self.freq_fc(att_out)
-        att_out = torch.transpose(freqfc_out, 2, 3)
-
-        cat_out = torch.cat([att_out, inputs], 1)
-        outputs = self.conv2(cat_out)
-        return outputs
-
-
 @MODELS.register_module(
     Tasks.acoustic_noise_suppression,
     module_name=Models.speech_frcrn_ans_cirm_16k)
-class FRCRNModel(TorchModel):
+class FRCRNDecorator(TorchModel):
     r""" A decorator of FRCRN for integrating into the ModelScope framework """
 
     def __init__(self, model_dir: str, *args, **kwargs):
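
The FTB block removed above implements time-frequency attention: a 1x1 Conv2d compresses the channels, a Conv1d over the flattened channel-frequency axis produces a per-time attention vector that scales the input, and a bias-free Linear layer then mixes information across frequencies. A minimal shape check, assuming the deleted class and its torch/nn imports are still available (a sketch, not part of this change):

    import torch

    ftb = FTB(input_dim=257, in_channel=9)  # defaults from the class above
    x = torch.randn(2, 9, 257, 100)         # [Batch, Channel, Freq, Time]
    y = ftb(x)
    print(y.shape)  # torch.Size([2, 9, 257, 100]); FTB preserves the input shape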

@@ -78,13 +35,14 @@ class FRCRNModel(TorchModel):
         checkpoint = torch.load(
             model_bin_file, map_location=torch.device('cpu'))
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-            self.model.load_state_dict(
-                checkpoint['state_dict'], strict=False)
+            # a model newly trained by the user is based on FRCRNDecorator
+            self.load_state_dict(checkpoint['state_dict'])
         else:
+            # the model released on ModelScope is based on FRCRN
             self.model.load_state_dict(checkpoint, strict=False)
 
-    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        result_list = self.model.forward(input['noisy'])
+    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        result_list = self.model.forward(inputs['noisy'])
         output = {
             'spec_l1': result_list[0],
             'wav_l1': result_list[1],
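
The revised loading logic distinguishes two checkpoint layouts: a checkpoint saved from FRCRNDecorator nests the weights under a 'state_dict' entry whose keys carry the decorator's 'model.' prefix, so it is loaded into the decorator itself, while the released FRCRN checkpoint stores bare keys that load directly into self.model. A rough sketch of the distinction (the key names are hypothetical, for illustration only):

    import torch

    # saved from FRCRNDecorator: nested under 'state_dict', keys prefixed 'model.'
    user_ckpt = {'state_dict': {'model.stft.weight': torch.zeros(1)}}
    # released FRCRN checkpoint: bare keys, loaded with strict=False
    released_ckpt = {'stft.weight': torch.zeros(1)}

    print(isinstance(user_ckpt, dict) and 'state_dict' in user_ckpt)  # True
    print('state_dict' in released_ckpt)                              # False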

@@ -93,12 +51,12 @@ class FRCRNModel(TorchModel):
             'wav_l2': result_list[4],
             'mask_l2': result_list[5]
         }
-        if 'clean' in input:
+        if 'clean' in inputs:
             mix_result = self.model.loss(
-                input['noisy'], input['clean'], result_list, mode='Mix')
+                inputs['noisy'], inputs['clean'], result_list, mode='Mix')
             output.update(mix_result)
             sisnr_result = self.model.loss(
-                input['noisy'], input['clean'], result_list, mode='SiSNR')
+                inputs['noisy'], inputs['clean'], result_list, mode='SiSNR')
             output.update(sisnr_result)
             # the logger hook will use items under 'log_vars'
             output['log_vars'] = {k: mix_result[k].item() for k in mix_result}
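
For reference, a rough usage sketch of the renamed decorator during training; the model directory and tensor shapes below are assumptions for illustration, not part of this change:

    import torch

    model = FRCRNDecorator('path/to/model_dir')  # hypothetical local checkpoint dir
    batch = {
        'noisy': torch.randn(4, 16000),  # one second of 16 kHz audio per item
        'clean': torch.randn(4, 16000),  # targets enable the loss computation
    }
    out = model.forward(batch)
    # with 'clean' present, Mix and SiSNR losses are attached and
    # out['log_vars'] carries scalar values for the logger hook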