diff --git a/modelscope/models/audio/ans/__init__.py b/modelscope/models/audio/ans/__init__.py
index b602ad01..afcdf314 100644
--- a/modelscope/models/audio/ans/__init__.py
+++ b/modelscope/models/audio/ans/__init__.py
@@ -4,11 +4,11 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
-    from .frcrn import FRCRNModel
+    from .frcrn import FRCRNDecorator
 
 else:
     _import_structure = {
-        'frcrn': ['FRCRNModel'],
+        'frcrn': ['FRCRNDecorator'],
     }
 
     import sys
diff --git a/modelscope/models/audio/ans/complex_nn.py b/modelscope/models/audio/ans/complex_nn.py
index 69dec41e..c61446c2 100644
--- a/modelscope/models/audio/ans/complex_nn.py
+++ b/modelscope/models/audio/ans/complex_nn.py
@@ -1,3 +1,9 @@
+"""
+class ComplexConv2d, ComplexConvTranspose2d and ComplexBatchNorm2d are the work of
+Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft ).
+from https://github.com/sweetcocoa/DeepComplexUNetPyTorch
+
+"""
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
diff --git a/modelscope/models/audio/ans/conv_stft.py b/modelscope/models/audio/ans/conv_stft.py
index a47d7817..4b393a4c 100644
--- a/modelscope/models/audio/ans/conv_stft.py
+++ b/modelscope/models/audio/ans/conv_stft.py
@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import numpy as np
 import torch
 import torch.nn as nn
diff --git a/modelscope/models/audio/ans/frcrn.py b/modelscope/models/audio/ans/frcrn.py
index 59411fbe..b74fc273 100644
--- a/modelscope/models/audio/ans/frcrn.py
+++ b/modelscope/models/audio/ans/frcrn.py
@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 from typing import Dict
 
@@ -14,54 +15,10 @@ from .conv_stft import ConviSTFT, ConvSTFT
 from .unet import UNet
 
 
-class FTB(nn.Module):
-
-    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
-
-        super(FTB, self).__init__()
-        self.in_channel = in_channel
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
-            nn.BatchNorm2d(r_channel), nn.ReLU())
-
-        self.conv1d = nn.Sequential(
-            nn.Conv1d(
-                r_channel * input_dim, in_channel, kernel_size=9, padding=4),
-            nn.BatchNorm1d(in_channel), nn.ReLU())
-        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)
-
-        self.conv2 = nn.Sequential(
-            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
-            nn.BatchNorm2d(in_channel), nn.ReLU())
-
-    def forward(self, inputs):
-        '''
-        inputs should be [Batch, Ca, Dim, Time]
-        '''
-        # T-F attention
-        conv1_out = self.conv1(inputs)
-        B, C, D, T = conv1_out.size()
-        reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
-        conv1d_out = self.conv1d(reshape1_out)
-        conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])
-
-        # now is also [B,C,D,T]
-        att_out = conv1d_out * inputs
-
-        # tranpose to [B,C,T,D]
-        att_out = torch.transpose(att_out, 2, 3)
-        freqfc_out = self.freq_fc(att_out)
-        att_out = torch.transpose(freqfc_out, 2, 3)
-
-        cat_out = torch.cat([att_out, inputs], 1)
-        outputs = self.conv2(cat_out)
-        return outputs
-
-
 @MODELS.register_module(
     Tasks.acoustic_noise_suppression,
     module_name=Models.speech_frcrn_ans_cirm_16k)
-class FRCRNModel(TorchModel):
+class FRCRNDecorator(TorchModel):
     r""" A decorator of FRCRN for integrating into modelscope framework """
 
     def __init__(self, model_dir: str, *args, **kwargs):
@@ -78,13 +35,14 @@ class FRCRNModel(TorchModel):
         checkpoint = torch.load(
             model_bin_file, map_location=torch.device('cpu'))
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-            self.model.load_state_dict(
-                checkpoint['state_dict'], strict=False)
+            # the new trained model by user is based on FRCRNDecorator
+            self.load_state_dict(checkpoint['state_dict'])
         else:
+            # The released model on Modelscope is based on FRCRN
             self.model.load_state_dict(checkpoint, strict=False)
 
-    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        result_list = self.model.forward(input['noisy'])
+    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        result_list = self.model.forward(inputs['noisy'])
         output = {
             'spec_l1': result_list[0],
             'wav_l1': result_list[1],
@@ -93,12 +51,12 @@ class FRCRNModel(TorchModel):
             'wav_l2': result_list[4],
             'mask_l2': result_list[5]
         }
-        if 'clean' in input:
+        if 'clean' in inputs:
             mix_result = self.model.loss(
-                input['noisy'], input['clean'], result_list, mode='Mix')
+                inputs['noisy'], inputs['clean'], result_list, mode='Mix')
             output.update(mix_result)
             sisnr_result = self.model.loss(
-                input['noisy'], input['clean'], result_list, mode='SiSNR')
+                inputs['noisy'], inputs['clean'], result_list, mode='SiSNR')
             output.update(sisnr_result)
             # logger hooker will use items under 'log_vars'
             output['log_vars'] = {k: mix_result[k].item() for k in mix_result}
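A note on the frcrn.py loading branch above: it now distinguishes two checkpoint layouts. A checkpoint saved from a trained `FRCRNDecorator` wraps the weights in a `'state_dict'` entry whose keys carry the wrapper's `model.` prefix, so it loads onto the decorator itself; the weights released on ModelScope are a bare FRCRN state dict and load into the inner network. A minimal sketch of that branch, where `decorator` is an already-constructed `FRCRNDecorator` and the file path is a placeholder:

```python
import torch


def load_frcrn_checkpoint(decorator, model_bin_file='pytorch_model.bin'):
    """Sketch of the branch above; `decorator.model` holds the core FRCRN."""
    checkpoint = torch.load(model_bin_file, map_location=torch.device('cpu'))
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        # Saved from FRCRNDecorator: keys are prefixed with 'model.',
        # so they match the wrapper's own state dict.
        decorator.load_state_dict(checkpoint['state_dict'])
    else:
        # Released FRCRN weights: bare keys, loaded into the inner network
        # non-strictly, matching the diff above.
        decorator.model.load_state_dict(checkpoint, strict=False)
```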
diff --git a/modelscope/models/audio/ans/se_module_complex.py b/modelscope/models/audio/ans/se_module_complex.py
index f62fe523..b58eb6ba 100644
--- a/modelscope/models/audio/ans/se_module_complex.py
+++ b/modelscope/models/audio/ans/se_module_complex.py
@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import torch
 from torch import nn
 
diff --git a/modelscope/models/audio/ans/unet.py b/modelscope/models/audio/ans/unet.py
index aa5a4254..ae66eb69 100644
--- a/modelscope/models/audio/ans/unet.py
+++ b/modelscope/models/audio/ans/unet.py
@@ -1,3 +1,7 @@
+"""
+Based on the work of Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft ).
+from https://github.com/sweetcocoa/DeepComplexUNetPyTorch
+"""
 import torch
 import torch.nn as nn
 
diff --git a/modelscope/trainers/audio/ans_trainer.py b/modelscope/trainers/audio/ans_trainer.py
index f782b836..37b201ce 100644
--- a/modelscope/trainers/audio/ans_trainer.py
+++ b/modelscope/trainers/audio/ans_trainer.py
@@ -1,10 +1,5 @@
-import time
-from typing import List, Optional, Union
-
-from datasets import Dataset
-
+# Copyright (c) Alibaba, Inc. and its affiliates.
 from modelscope.metainfo import Trainers
-from modelscope.preprocessors import Preprocessor
 from modelscope.trainers import EpochBasedTrainer
 from modelscope.trainers.builder import TRAINERS
 from modelscope.utils.constant import TrainerStages
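The ans_trainer.py change is only a copyright header and removal of unused imports, but for orientation: the ANS trainer is registered in the `TRAINERS` registry and is normally built through `build_trainer` rather than instantiated directly. A sketch of that flow, assuming the released model id; the dataset variables and work dir are placeholders, and the datasets would already be mapped through `to_segment` from the next file:

```python
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

# Hypothetical datasets: MsDataset instances already split into
# fixed-length segments (see to_segment in audio_utils.py below).
train_dataset, eval_dataset = None, None

kwargs = dict(
    model='damo/speech_frcrn_ans_cirm_16k',
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    work_dir='/tmp/ans_work_dir')
trainer = build_trainer(Trainers.speech_frcrn_ans_cirm_16k, default_args=kwargs)
trainer.train()
```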
""" noisy_arrays = [] - for x in batch['noisy']: - length = len(x['array']) - noisy = np.array(x['array']) - for offset in range(segment_length, length, segment_length): - noisy_arrays.append(noisy[offset - segment_length:offset]) clean_arrays = [] - for x in batch['clean']: - length = len(x['array']) - clean = np.array(x['array']) - for offset in range(segment_length, length, segment_length): + for x, y in zip(batch['noisy'], batch['clean']): + length = min(len(x['array']), len(y['array'])) + noisy = x['array'] + clean = y['array'] + for offset in range(segment_length, length + 1, segment_length): + noisy_arrays.append(noisy[offset - segment_length:offset]) clean_arrays.append(clean[offset - segment_length:offset]) return {'noisy': noisy_arrays, 'clean': clean_arrays}