Browse Source

[to #42322933] fix bug about loading new trained model and update doc string

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9987197
master
bin.xue 3 years ago
parent
commit
f5fb8cf531
8 changed files with 32 additions and 71 deletions
  1. +2
    -2
      modelscope/models/audio/ans/__init__.py
  2. +6
    -0
      modelscope/models/audio/ans/complex_nn.py
  3. +1
    -0
      modelscope/models/audio/ans/conv_stft.py
  4. +10
    -52
      modelscope/models/audio/ans/frcrn.py
  5. +1
    -0
      modelscope/models/audio/ans/se_module_complex.py
  6. +4
    -0
      modelscope/models/audio/ans/unet.py
  7. +1
    -6
      modelscope/trainers/audio/ans_trainer.py
  8. +7
    -11
      modelscope/utils/audio/audio_utils.py

+ 2
- 2
modelscope/models/audio/ans/__init__.py View File

@@ -4,11 +4,11 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .frcrn import FRCRNModel
from .frcrn import FRCRNDecorator

else:
_import_structure = {
'frcrn': ['FRCRNModel'],
'frcrn': ['FRCRNDecorator'],
}

import sys


+ 6
- 0
modelscope/models/audio/ans/complex_nn.py View File

@@ -1,3 +1,9 @@
"""
The classes ComplexConv2d, ComplexConvTranspose2d and ComplexBatchNorm2d are the work of
Jongho Choi (sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft),
from https://github.com/sweetcocoa/DeepComplexUNetPyTorch

"""
import torch
import torch.nn as nn
import torch.nn.functional as F


+ 1
- 0
modelscope/models/audio/ans/conv_stft.py View File

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn


+ 10
- 52
modelscope/models/audio/ans/frcrn.py View File

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Dict

@@ -14,54 +15,10 @@ from .conv_stft import ConviSTFT, ConvSTFT
from .unet import UNet


class FTB(nn.Module):
    """Attention-style block for a ``[Batch, Ca, Dim, Time]`` feature map.

    The input is reduced with a 1x1 convolution, its channel and ``Dim`` axes
    are folded together and run through a 1-D convolution along the last axis
    to produce one gate value per channel and frame. The gated input then
    passes through a bias-free linear layer over ``Dim`` and is finally
    concatenated with the untouched input and projected back to
    ``in_channel`` channels.

    Args:
        input_dim: size of the ``Dim`` axis of the input.
        in_channel: number of input channels (also the number of output
            channels).
        r_channel: number of channels produced by the first 1x1 convolution.
    """

    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
        super().__init__()
        self.in_channel = in_channel

        # 1x1 convolution: in_channel -> r_channel.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, r_channel, kernel_size=(1, 1)),
            nn.BatchNorm2d(r_channel),
            nn.ReLU(),
        )

        # 1-D convolution over the folded (r_channel * input_dim) axis;
        # padding=4 with kernel_size=9 keeps the last axis length unchanged.
        self.conv1d = nn.Sequential(
            nn.Conv1d(
                r_channel * input_dim,
                in_channel,
                kernel_size=9,
                padding=4,
            ),
            nn.BatchNorm1d(in_channel),
            nn.ReLU(),
        )

        # Bias-free linear map applied along the Dim axis.
        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)

        # 1x1 convolution that fuses [attended, original] back to in_channel.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel, kernel_size=(1, 1)),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
        )

    def forward(self, inputs):
        """Apply the block.

        Args:
            inputs: tensor laid out as [Batch, Ca, Dim, Time].

        Returns:
            The output of the final 1x1 convolution stack.
        """
        # T-F attention
        reduced = self.conv1(inputs)
        batch, channels, dim, frames = reduced.size()
        folded = reduced.reshape(batch, channels * dim, frames)
        gate = self.conv1d(folded).reshape(batch, self.in_channel, 1, frames)

        # Broadcasting the size-1 axis gives [B, C, D, T] again.
        gated = gate * inputs

        # The linear layer works on the last axis, so transpose to
        # [B, C, T, D] for it and transpose back afterwards.
        mixed = self.freq_fc(gated.transpose(2, 3)).transpose(2, 3)

        return self.conv2(torch.cat((mixed, inputs), dim=1))


@MODELS.register_module(
Tasks.acoustic_noise_suppression,
module_name=Models.speech_frcrn_ans_cirm_16k)
class FRCRNModel(TorchModel):
class FRCRNDecorator(TorchModel):
r""" A decorator of FRCRN for integrating into modelscope framework """

def __init__(self, model_dir: str, *args, **kwargs):
@@ -78,13 +35,14 @@ class FRCRNModel(TorchModel):
checkpoint = torch.load(
model_bin_file, map_location=torch.device('cpu'))
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
self.model.load_state_dict(
checkpoint['state_dict'], strict=False)
# a model newly trained by the user is based on FRCRNDecorator
self.load_state_dict(checkpoint['state_dict'])
else:
# The released model on Modelscope is based on FRCRN
self.model.load_state_dict(checkpoint, strict=False)

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
result_list = self.model.forward(input['noisy'])
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
result_list = self.model.forward(inputs['noisy'])
output = {
'spec_l1': result_list[0],
'wav_l1': result_list[1],
@@ -93,12 +51,12 @@ class FRCRNModel(TorchModel):
'wav_l2': result_list[4],
'mask_l2': result_list[5]
}
if 'clean' in input:
if 'clean' in inputs:
mix_result = self.model.loss(
input['noisy'], input['clean'], result_list, mode='Mix')
inputs['noisy'], inputs['clean'], result_list, mode='Mix')
output.update(mix_result)
sisnr_result = self.model.loss(
input['noisy'], input['clean'], result_list, mode='SiSNR')
inputs['noisy'], inputs['clean'], result_list, mode='SiSNR')
output.update(sisnr_result)
# logger hooker will use items under 'log_vars'
output['log_vars'] = {k: mix_result[k].item() for k in mix_result}


+ 1
- 0
modelscope/models/audio/ans/se_module_complex.py View File

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from torch import nn



+ 4
- 0
modelscope/models/audio/ans/unet.py View File

@@ -1,3 +1,7 @@
"""
Based on the work of Jongho Choi (sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft),
from https://github.com/sweetcocoa/DeepComplexUNetPyTorch
"""
import torch
import torch.nn as nn



+ 1
- 6
modelscope/trainers/audio/ans_trainer.py View File

@@ -1,10 +1,5 @@
import time
from typing import List, Optional, Union

from datasets import Dataset

# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.metainfo import Trainers
from modelscope.preprocessors import Preprocessor
from modelscope.trainers import EpochBasedTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import TrainerStages


+ 7
- 11
modelscope/utils/audio/audio_utils.py View File

@@ -1,5 +1,4 @@
import numpy as np

# Copyright (c) Alibaba, Inc. and its affiliates.
SEGMENT_LENGTH_TRAIN = 16000


@@ -9,16 +8,13 @@ def to_segment(batch, segment_length=SEGMENT_LENGTH_TRAIN):
It only works in batch mode.
"""
noisy_arrays = []
for x in batch['noisy']:
length = len(x['array'])
noisy = np.array(x['array'])
for offset in range(segment_length, length, segment_length):
noisy_arrays.append(noisy[offset - segment_length:offset])
clean_arrays = []
for x in batch['clean']:
length = len(x['array'])
clean = np.array(x['array'])
for offset in range(segment_length, length, segment_length):
for x, y in zip(batch['noisy'], batch['clean']):
length = min(len(x['array']), len(y['array']))
noisy = x['array']
clean = y['array']
for offset in range(segment_length, length + 1, segment_length):
noisy_arrays.append(noisy[offset - segment_length:offset])
clean_arrays.append(clean[offset - segment_length:offset])
return {'noisy': noisy_arrays, 'clean': clean_arrays}



Loading…
Cancel
Save