Browse Source

[to #42322933] fix a bug when loading a newly trained model and update docstrings

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9987197
master
bin.xue 3 years ago
parent
commit
f5fb8cf531
8 changed files with 32 additions and 71 deletions
  1. +2
    -2
      modelscope/models/audio/ans/__init__.py
  2. +6
    -0
      modelscope/models/audio/ans/complex_nn.py
  3. +1
    -0
      modelscope/models/audio/ans/conv_stft.py
  4. +10
    -52
      modelscope/models/audio/ans/frcrn.py
  5. +1
    -0
      modelscope/models/audio/ans/se_module_complex.py
  6. +4
    -0
      modelscope/models/audio/ans/unet.py
  7. +1
    -6
      modelscope/trainers/audio/ans_trainer.py
  8. +7
    -11
      modelscope/utils/audio/audio_utils.py

+ 2
- 2
modelscope/models/audio/ans/__init__.py View File

@@ -4,11 +4,11 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule from modelscope.utils.import_utils import LazyImportModule


if TYPE_CHECKING: if TYPE_CHECKING:
from .frcrn import FRCRNModel
from .frcrn import FRCRNDecorator


else: else:
_import_structure = { _import_structure = {
'frcrn': ['FRCRNModel'],
'frcrn': ['FRCRNDecorator'],
} }


import sys import sys


+ 6
- 0
modelscope/models/audio/ans/complex_nn.py View File

@@ -1,3 +1,9 @@
"""
class ComplexConv2d, ComplexConvTranspose2d and ComplexBatchNorm2d are the work of
Jongho Choi (sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft).
from https://github.com/sweetcocoa/DeepComplexUNetPyTorch

"""
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F


+ 1
- 0
modelscope/models/audio/ans/conv_stft.py View File

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn


+ 10
- 52
modelscope/models/audio/ans/frcrn.py View File

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os import os
from typing import Dict from typing import Dict


@@ -14,54 +15,10 @@ from .conv_stft import ConviSTFT, ConvSTFT
from .unet import UNet from .unet import UNet




class FTB(nn.Module):
    """Time-Frequency attention block.

    A 1x1 conv squeezes channels, a wide 1-D conv over the flattened
    frequency axis produces a per-channel/per-frame attention map, a shared
    linear layer mixes frequency bins, and a final 1x1 conv fuses the
    attended features with the raw input.
    """

    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
        """
        Args:
            input_dim: number of frequency bins (Dim axis).
            in_channel: channels of the incoming feature map.
            r_channel: reduced channel count used inside the block.
        """
        super(FTB, self).__init__()
        self.in_channel = in_channel
        # NOTE: submodules are created in the same order as the original
        # so parameter initialization under a fixed seed is unchanged.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(r_channel), nn.ReLU())
        self.conv1d = nn.Sequential(
            nn.Conv1d(
                r_channel * input_dim, in_channel, kernel_size=9, padding=4),
            nn.BatchNorm1d(in_channel), nn.ReLU())
        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(in_channel), nn.ReLU())

    def forward(self, inputs):
        """Apply T-F attention.

        Args:
            inputs: feature map laid out as [Batch, Channel, Dim, Time].

        Returns:
            Tensor with the same [Batch, Channel, Dim, Time] shape.
        """
        # Squeeze channels, then collapse (channel, freq) for the 1-D conv.
        squeezed = self.conv1(inputs)
        batch, chan, dim, frames = squeezed.size()
        attention = self.conv1d(squeezed.reshape(batch, chan * dim, frames))
        attention = attention.reshape(batch, self.in_channel, 1, frames)

        # Per-channel attention, broadcast across the frequency axis.
        attended = attention * inputs

        # Mix frequency bins: put Dim last, apply the shared linear, restore.
        mixed = self.freq_fc(attended.transpose(2, 3)).transpose(2, 3)

        # Fuse attended features with the raw input via a 1x1 conv.
        return self.conv2(torch.cat([mixed, inputs], 1))


@MODELS.register_module( @MODELS.register_module(
Tasks.acoustic_noise_suppression, Tasks.acoustic_noise_suppression,
module_name=Models.speech_frcrn_ans_cirm_16k) module_name=Models.speech_frcrn_ans_cirm_16k)
class FRCRNModel(TorchModel):
class FRCRNDecorator(TorchModel):
r""" A decorator of FRCRN for integrating into modelscope framework """ r""" A decorator of FRCRN for integrating into modelscope framework """


def __init__(self, model_dir: str, *args, **kwargs): def __init__(self, model_dir: str, *args, **kwargs):
@@ -78,13 +35,14 @@ class FRCRNModel(TorchModel):
checkpoint = torch.load( checkpoint = torch.load(
model_bin_file, map_location=torch.device('cpu')) model_bin_file, map_location=torch.device('cpu'))
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
self.model.load_state_dict(
checkpoint['state_dict'], strict=False)
# the new trained model by user is based on FRCRNDecorator
self.load_state_dict(checkpoint['state_dict'])
else: else:
# The released model on Modelscope is based on FRCRN
self.model.load_state_dict(checkpoint, strict=False) self.model.load_state_dict(checkpoint, strict=False)


def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
result_list = self.model.forward(input['noisy'])
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
result_list = self.model.forward(inputs['noisy'])
output = { output = {
'spec_l1': result_list[0], 'spec_l1': result_list[0],
'wav_l1': result_list[1], 'wav_l1': result_list[1],
@@ -93,12 +51,12 @@ class FRCRNModel(TorchModel):
'wav_l2': result_list[4], 'wav_l2': result_list[4],
'mask_l2': result_list[5] 'mask_l2': result_list[5]
} }
if 'clean' in input:
if 'clean' in inputs:
mix_result = self.model.loss( mix_result = self.model.loss(
input['noisy'], input['clean'], result_list, mode='Mix')
inputs['noisy'], inputs['clean'], result_list, mode='Mix')
output.update(mix_result) output.update(mix_result)
sisnr_result = self.model.loss( sisnr_result = self.model.loss(
input['noisy'], input['clean'], result_list, mode='SiSNR')
inputs['noisy'], inputs['clean'], result_list, mode='SiSNR')
output.update(sisnr_result) output.update(sisnr_result)
# logger hooker will use items under 'log_vars' # logger hooker will use items under 'log_vars'
output['log_vars'] = {k: mix_result[k].item() for k in mix_result} output['log_vars'] = {k: mix_result[k].item() for k in mix_result}


+ 1
- 0
modelscope/models/audio/ans/se_module_complex.py View File

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch import torch
from torch import nn from torch import nn




+ 4
- 0
modelscope/models/audio/ans/unet.py View File

@@ -1,3 +1,7 @@
"""
Based on the work of Jongho Choi (sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft).
from https://github.com/sweetcocoa/DeepComplexUNetPyTorch
"""
import torch import torch
import torch.nn as nn import torch.nn as nn




+ 1
- 6
modelscope/trainers/audio/ans_trainer.py View File

@@ -1,10 +1,5 @@
import time
from typing import List, Optional, Union

from datasets import Dataset

# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.metainfo import Trainers from modelscope.metainfo import Trainers
from modelscope.preprocessors import Preprocessor
from modelscope.trainers import EpochBasedTrainer from modelscope.trainers import EpochBasedTrainer
from modelscope.trainers.builder import TRAINERS from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import TrainerStages from modelscope.utils.constant import TrainerStages


+ 7
- 11
modelscope/utils/audio/audio_utils.py View File

@@ -1,5 +1,4 @@
import numpy as np

# Copyright (c) Alibaba, Inc. and its affiliates.
SEGMENT_LENGTH_TRAIN = 16000 SEGMENT_LENGTH_TRAIN = 16000




@@ -9,16 +8,13 @@ def to_segment(batch, segment_length=SEGMENT_LENGTH_TRAIN):
It only works in batch mode. It only works in batch mode.
""" """
noisy_arrays = [] noisy_arrays = []
for x in batch['noisy']:
length = len(x['array'])
noisy = np.array(x['array'])
for offset in range(segment_length, length, segment_length):
noisy_arrays.append(noisy[offset - segment_length:offset])
clean_arrays = [] clean_arrays = []
for x in batch['clean']:
length = len(x['array'])
clean = np.array(x['array'])
for offset in range(segment_length, length, segment_length):
for x, y in zip(batch['noisy'], batch['clean']):
length = min(len(x['array']), len(y['array']))
noisy = x['array']
clean = y['array']
for offset in range(segment_length, length + 1, segment_length):
noisy_arrays.append(noisy[offset - segment_length:offset])
clean_arrays.append(clean[offset - segment_length:offset]) clean_arrays.append(clean[offset - segment_length:offset])
return {'noisy': noisy_arrays, 'clean': clean_arrays} return {'noisy': noisy_arrays, 'clean': clean_arrays}




Loading…
Cancel
Save