bin.xue yingda.chen 3 years ago
parent
commit
3863efc14d
10 changed files with 721 additions and 38 deletions
  1. +3
    -0
      data/test/audios/noise_2ch.wav
  2. +3
    -0
      data/test/audios/wake_word_with_label_xyxy.wav
  3. +1
    -0
      modelscope/metainfo.py
  4. +26
    -37
      modelscope/models/audio/kws/farfield/model.py
  5. +21
    -0
      modelscope/msdatasets/task_datasets/audio/__init__.py
  6. +280
    -0
      modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py
  7. +279
    -0
      modelscope/trainers/audio/kws_farfield_trainer.py
  8. +18
    -0
      modelscope/utils/audio/audio_utils.py
  9. +5
    -1
      requirements/audio.txt
  10. +85
    -0
      tests/trainers/audio/test_kws_farfield_trainer.py

+ 3
- 0
data/test/audios/noise_2ch.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d
size 2327764

+ 3
- 0
data/test/audios/wake_word_with_label_xyxy.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150
size 68524

+ 1
- 0
modelscope/metainfo.py View File

@@ -285,6 +285,7 @@ class Trainers(object):


# audio trainers # audio trainers
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'




class Preprocessors(object): class Preprocessors(object):


+ 26
- 37
modelscope/models/audio/kws/farfield/model.py View File

@@ -1,15 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.


import os import os
from typing import Dict

import torch
from typing import Dict, Optional


from modelscope.metainfo import Models from modelscope.metainfo import Models
from modelscope.models import TorchModel from modelscope.models import TorchModel
from modelscope.models.base import Tensor from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.audio.audio_utils import update_conf
from modelscope.utils.constant import Tasks
from .fsmn_sele_v2 import FSMNSeleNetV2 from .fsmn_sele_v2 import FSMNSeleNetV2




@@ -20,48 +19,38 @@ class FSMNSeleNetV2Decorator(TorchModel):


MODEL_TXT = 'model.txt' MODEL_TXT = 'model.txt'
SC_CONFIG = 'sound_connect.conf' SC_CONFIG = 'sound_connect.conf'
SC_CONF_ITEM_KWS_MODEL = '${kws_model}'


def __init__(self, model_dir: str, *args, **kwargs):
def __init__(self,
model_dir: str,
training: Optional[bool] = False,
*args,
**kwargs):
"""initialize the dfsmn model from the `model_dir` path. """initialize the dfsmn model from the `model_dir` path.


Args: Args:
model_dir (str): the model path. model_dir (str): the model path.
""" """
super().__init__(model_dir, *args, **kwargs) super().__init__(model_dir, *args, **kwargs)
sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
model_bin_file = os.path.join(model_dir,
ModelFile.TORCH_MODEL_BIN_FILE)
self._model = None
if os.path.exists(model_bin_file):
kwargs.pop('device')
self._model = FSMNSeleNetV2(*args, **kwargs)
checkpoint = torch.load(model_bin_file)
self._model.load_state_dict(checkpoint, strict=False)

self._sc = None
if os.path.exists(model_txt_file):
with open(sc_config_file) as f:
lines = f.readlines()
with open(sc_config_file, 'w') as f:
for line in lines:
if self.SC_CONF_ITEM_KWS_MODEL in line:
line = line.replace(self.SC_CONF_ITEM_KWS_MODEL,
model_txt_file)
f.write(line)
import py_sound_connect
self._sc = py_sound_connect.SoundConnect(sc_config_file)
self.size_in = self._sc.bytesPerBlockIn()
self.size_out = self._sc.bytesPerBlockOut()

if self._model is None and self._sc is None:
raise Exception(
f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.'
)
if training:
self.model = FSMNSeleNetV2(*args, **kwargs)
else:
sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
self._sc = None
if os.path.exists(model_txt_file):
conf_dict = dict(mode=56542, kws_model=model_txt_file)
update_conf(sc_config_file, sc_config_file, conf_dict)
import py_sound_connect
self._sc = py_sound_connect.SoundConnect(sc_config_file)
self.size_in = self._sc.bytesPerBlockIn()
self.size_out = self._sc.bytesPerBlockOut()
else:
raise Exception(
f'Invalid model directory! Failed to load model file: {model_txt_file}.'
)


def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
...
return self.model.forward(input)


def forward_decode(self, data: bytes): def forward_decode(self, data: bytes):
result = {'pcm': self._sc.process(data, self.size_out)} result = {'pcm': self._sc.process(data, self.size_out)}


+ 21
- 0
modelscope/msdatasets/task_datasets/audio/__init__.py View File

@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Real imports for static analyzers / IDEs only; at runtime the
    # symbols are resolved lazily through LazyImportModule below.
    from .kws_farfield_dataset import KWSDataset, KWSDataLoader

else:
    # Maps submodule name -> public symbols exported from it; used by
    # LazyImportModule to defer the actual import until first attribute
    # access (avoids importing heavy deps like py_sound_connect eagerly).
    _import_structure = {
        'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'],
    }
    import sys

    # Replace this module object in sys.modules with a lazy proxy.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

+ 280
- 0
modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py View File

@@ -0,0 +1,280 @@
"""
Used to prepare simulated data.
"""
import math
import os.path
import queue
import threading
import time

import numpy as np
import torch

from modelscope.utils.logger import get_logger

logger = get_logger()

# Default time-axis decimation factor (see KWSDataset ``blockdec``).
BLOCK_DEC = 2
# Default no. of consecutive blocks concatenated per feature frame
# (see KWSDataset ``blockcat``).
BLOCK_CAT = 3
# Size of a single-channel fbank feature vector.
FBANK_SIZE = 40
# Width of the per-frame label vector.
LABEL_SIZE = 1
# Float labels are multiplied by this gain before rounding to class ids.
LABEL_GAIN = 100.0


class KWSDataset:
    """Dataset producing simulated far-field data for keyword spotting and VAD.

    Args:
        conf_basetrain: basetrain configure file path
        conf_finetune: finetune configure file path
        numworkers: no. of simulation workers
        basetrainratio: ratio of workers dedicated to basetrain data
        numclasses: no. of nn output classes; 2 classes generates vad labels
        blockdec: block decimation factor along the time axis
        blockcat: no. of blocks concatenated into one feature frame
    """

    def __init__(self,
                 conf_basetrain,
                 conf_finetune,
                 numworkers,
                 basetrainratio,
                 numclasses,
                 blockdec=BLOCK_DEC,
                 blockcat=BLOCK_CAT):
        # Bugfix: ``blockdec`` previously defaulted to BLOCK_CAT (3), which
        # left the BLOCK_DEC constant unused; it now defaults to the
        # decimation constant as intended.
        super().__init__()
        self.numclasses = numclasses
        self.blockdec = blockdec
        self.blockcat = blockcat
        self.sims_base = []
        self.sims_senior = []
        self.setup_sims(conf_basetrain, conf_finetune, numworkers,
                        basetrainratio)

    def release(self):
        """Drop all simulator and config references held by the dataset.

        Bugfix: ``del sim`` inside a loop only unbound the loop variable
        while the lists kept every simulator alive; the lists are cleared
        instead so the native objects can actually be reclaimed.
        """
        self.sims_base.clear()
        self.sims_senior.clear()
        del self.base_conf
        del self.senior_conf
        logger.info('KWSDataset: Released.')

    def setup_sims(self, conf_basetrain, conf_finetune, numworkers,
                   basetrainratio):
        """Create the py_sound_connect feature simulators for all workers.

        Raises:
            ValueError: if either configure file does not exist.
        """
        if not os.path.exists(conf_basetrain):
            raise ValueError(f'{conf_basetrain} does not exist!')
        if not os.path.exists(conf_finetune):
            raise ValueError(f'{conf_finetune} does not exist!')
        import py_sound_connect
        logger.info('KWSDataset init SoundConnect...')
        num_base = math.ceil(numworkers * basetrainratio)
        num_senior = numworkers - num_base
        # hold by fields to avoid python releasing conf object
        self.base_conf = py_sound_connect.ConfigFile(conf_basetrain)
        self.senior_conf = py_sound_connect.ConfigFile(conf_finetune)
        for _ in range(num_base):
            self.sims_base.append(
                py_sound_connect.FeatSimuKWS(self.base_conf.params))
        for _ in range(num_senior):
            self.sims_senior.append(
                py_sound_connect.FeatSimuKWS(self.senior_conf.params))
        logger.info('KWSDataset init SoundConnect finished.')

    def getBatch(self, id):
        """Generate a data batch from worker ``id``'s simulator.

        Args:
            id: worker id

        Return: (time x channel x feature) float tensor, long label tensor
        """
        fs = self.get_sim(id)
        fs.processBatch()
        # get multi-channel feature vector size
        featsize = fs.featSize()
        # get label vector size
        labelsize = fs.labelSize()
        # no. of fe output channels
        numchs = featsize // FBANK_SIZE
        # get raw data: each row is [features..., label]
        # NOTE(review): the in-place label edits below require fs.feat()
        # to expose a writable buffer -- confirm with py_sound_connect.
        data = np.frombuffer(fs.feat(), dtype='float32')
        data = data.reshape((-1, featsize + labelsize))

        # convert float label to int
        label = data[:, FBANK_SIZE * numchs:]

        if self.numclasses == 2:
            # generate vad label
            label[label > 0.0] = 1.0
        else:
            # generate kws label; out-of-range ids fall back to class 0
            label = np.round(label * LABEL_GAIN)
            label[label > self.numclasses - 1] = 0.0

        # decimated output length along the time axis
        size1 = int(np.ceil(
            label.shape[0] / self.blockdec)) - self.blockcat + 1

        # label decimation: take the label at the center of each block group
        label1 = np.zeros((size1, LABEL_SIZE), dtype='float32')
        for tau in range(size1):
            label1[tau, :] = label[(tau + self.blockcat // 2)
                                   * self.blockdec, :]

        # feature decimation and concatenation
        # time x channel x feature
        featall = np.zeros((size1, numchs, FBANK_SIZE * self.blockcat),
                           dtype='float32')
        for n in range(numchs):
            feat = data[:, FBANK_SIZE * n:FBANK_SIZE * (n + 1)]

            for tau in range(size1):
                for i in range(self.blockcat):
                    featall[tau, n, FBANK_SIZE * i:FBANK_SIZE * (i + 1)] = \
                        feat[(tau + i) * self.blockdec, :]

        return torch.from_numpy(featall), torch.from_numpy(label1).long()

    def get_sim(self, id):
        """Return the simulator for worker ``id`` (basetrain sims first)."""
        num_base = len(self.sims_base)
        if id < num_base:
            return self.sims_base[id]
        return self.sims_senior[id - num_base]


class Worker(threading.Thread):
    """Producer thread that keeps pushing simulated minibatches into a queue.

    Args:
        id: worker id
        dataset: dataset to draw minibatches from
        pool: queue used as the shared data buffer
    """

    def __init__(self, id, dataset, pool):
        super().__init__()

        self.id = id
        self.dataset = dataset
        self.pool = pool
        self.isrun = True
        self.nn = 0

    def run(self):
        while self.isrun:
            self.nn += 1
            logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:1')
            if not self.isrun:
                continue
            # get simulated minibatch
            data = self.dataset.getBatch(self.id)
            logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:2')
            if not self.isrun:
                continue
            # put data into buffer (may block until a consumer drains it)
            self.pool.put(data)
            logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:3')

        logger.info('KWSDataLoader: Worker {:02d} stopped.'.format(self.id))

    def stopWorker(self):
        """
        stop the worker thread
        """
        self.isrun = False


class KWSDataLoader:
    """Multi-threaded loader that groups same-shape minibatches into batches.

    dataset: the dataset reference
    batchsize: data batch size
    numworkers: no. of workers
    prefetch: prefetch factor
    """

    def __init__(self, dataset, batchsize, numworkers, prefetch=2):
        self.dataset = dataset
        self.batchsize = batchsize
        # pending minibatches keyed by feature shape, so only minibatches
        # with identical shapes are stacked together
        self.datamap = {}
        self.isrun = True

        # data queue; bounded so workers block instead of running ahead
        self.pool = queue.Queue(batchsize * prefetch)

        # initialize workers
        self.workerlist = []
        for id in range(numworkers):
            w = Worker(id, dataset, self.pool)
            self.workerlist.append(w)

    def __iter__(self):
        return self

    def __next__(self):
        # NOTE(review): after stop() this returns (None, None) rather than
        # raising StopIteration -- callers must handle the None case.
        while self.isrun:
            # get data from common data pool
            data = self.pool.get()
            self.pool.task_done()

            # group minibatches with the same shape
            key = str(data[0].shape)

            batchl = self.datamap.get(key)
            if batchl is None:
                batchl = []
                self.datamap.update({key: batchl})

            batchl.append(data)

            # a full data batch collected
            if len(batchl) >= self.batchsize:
                featbatch = []
                labelbatch = []

                for feat, label in batchl:
                    featbatch.append(feat)
                    labelbatch.append(label)

                batchl.clear()

                feattensor = torch.stack(featbatch, dim=0)
                labeltensor = torch.stack(labelbatch, dim=0)

                # single-channel batches come from basetrain simulation
                if feattensor.shape[-2] == 1:
                    logger.debug('KWSDataLoader: Basetrain batch.')
                else:
                    logger.debug('KWSDataLoader: Finetune batch.')

                return feattensor, labeltensor

        return None, None

    def start(self):
        """
        start multi-thread data loader
        """
        for w in self.workerlist:
            w.start()

    def stop(self):
        """
        stop data loader
        """
        logger.info('KWSDataLoader: Stopping...')
        self.isrun = False

        for w in self.workerlist:
            w.stopWorker()

        # drain the queue so workers blocked in pool.put() can proceed
        # NOTE(review): empty() + get(timeout) can race with a concurrent
        # consumer and raise queue.Empty -- confirm this is acceptable here.
        while not self.pool.empty():
            self.pool.get(block=True, timeout=0.001)

        # wait workers terminated
        for w in self.workerlist:
            # keep draining while joining: a worker may refill the queue
            # and block on put() before it observes isrun == False
            while not self.pool.empty():
                self.pool.get(block=True, timeout=0.001)
            w.join()
        logger.info('KWSDataLoader: All worker stopped.')

+ 279
- 0
modelscope/trainers/audio/kws_farfield_trainer.py View File

@@ -0,0 +1,279 @@
import datetime
import math
import os
from typing import Callable, Dict, Optional

import numpy as np
import torch
from torch import nn as nn
from torch import optim as optim

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models import Model, TorchModel
from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.audio.audio_utils import update_conf
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.data_utils import to_device
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
init_dist, is_master)

logger = get_logger()

# Keys of the simulation config templates bundled with the model; also the
# keys expected in the trainer's ``custom_conf`` argument.
BASETRAIN_CONF_EASY = 'basetrain_easy'
BASETRAIN_CONF_NORMAL = 'basetrain_normal'
BASETRAIN_CONF_HARD = 'basetrain_hard'
FINETUNE_CONF_EASY = 'finetune_easy'
FINETUNE_CONF_NORMAL = 'finetune_normal'
FINETUNE_CONF_HARD = 'finetune_hard'

# Fraction of total epochs spent in each difficulty stage.
EASY_RATIO = 0.1
NORMAL_RATIO = 0.6
HARD_RATIO = 0.3
# Default ratio of basetrain workers in the simulation dataset.
BASETRAIN_RATIO = 0.5


@TRAINERS.register_module(module_name=Trainers.speech_dfsmn_kws_char_farfield)
class KWSFarfieldTrainer(BaseTrainer):
    """Trainer for the far-field DFSMN keyword spotting model.

    Training runs in three stages (easy, normal, hard); each stage draws
    simulated data configured by a (basetrain, finetune) pair of config
    files generated from the templates listed in ``conf_keys``.
    """
    DEFAULT_WORK_DIR = './work_dir'
    # Template keys ordered as (basetrain, finetune) pairs, one pair per
    # stage; run_stage() indexes self.conf_files with stage*2 / stage*2+1.
    conf_keys = (BASETRAIN_CONF_EASY, FINETUNE_CONF_EASY,
                 BASETRAIN_CONF_NORMAL, FINETUNE_CONF_NORMAL,
                 BASETRAIN_CONF_HARD, FINETUNE_CONF_HARD)

    def __init__(self,
                 model: str,
                 work_dir: str,
                 cfg_file: Optional[str] = None,
                 arg_parse_fn: Optional[Callable] = None,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 custom_conf: Optional[dict] = None,
                 **kwargs):
        """Initialize the trainer.

        Args:
            model: model id on the hub, or a local model directory/file.
            work_dir: directory for checkpoints and the json log.
            cfg_file: configuration file path; defaults to the one shipped
                in the model directory.
            arg_parse_fn: argument parsing hook, passed to BaseTrainer.
            model_revision: hub revision used when downloading the model.
            custom_conf: mapping from each key in ``conf_keys`` to the
                placeholder values of the corresponding config template.
                Required in practice despite the ``None`` default.
            **kwargs: optional overrides -- ``launcher``, ``device``,
                ``max_epochs``, ``train_iters_per_epoch``,
                ``val_iters_per_epoch``, ``workers``, ``single_rate``,
                ``model_bin``.

        Raises:
            ValueError: if ``custom_conf`` is not provided.
        """
        if isinstance(model, str):
            if os.path.exists(model):
                self.model_dir = model if os.path.isdir(
                    model) else os.path.dirname(model)
            else:
                self.model_dir = snapshot_download(
                    model, revision=model_revision)
            if cfg_file is None:
                cfg_file = os.path.join(self.model_dir,
                                        ModelFile.CONFIGURATION)
        else:
            assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
            self.model_dir = os.path.dirname(cfg_file)

        super().__init__(cfg_file, arg_parse_fn)

        self.model = self.build_model()
        self.work_dir = work_dir
        # the number of model output dimension
        # should update config outside the trainer, if user need more wake word
        self._num_classes = self.cfg.model.num_syn

        if kwargs.get('launcher', None) is not None:
            init_dist(kwargs['launcher'])

        _, world_size = get_dist_info()
        self._dist = world_size > 1

        device_name = kwargs.get('device', 'gpu')
        if self._dist:
            local_rank = get_local_rank()
            device_name = f'cuda:{local_rank}'

        self.device = create_device(device_name)
        # model placement
        if self.device.type == 'cuda':
            self.model.to(self.device)

        if 'max_epochs' not in kwargs:
            assert hasattr(
                self.cfg.train, 'max_epochs'
            ), 'max_epochs is missing from the configuration file'
            self._max_epochs = self.cfg.train.max_epochs
        else:
            self._max_epochs = kwargs['max_epochs']
        self._train_iters = kwargs.get('train_iters_per_epoch', None)
        self._val_iters = kwargs.get('val_iters_per_epoch', None)
        if self._train_iters is None:
            self._train_iters = self.cfg.train.train_iters_per_epoch
        if self._val_iters is None:
            self._val_iters = self.cfg.evaluation.val_iters_per_epoch
        dataloader_config = self.cfg.train.dataloader
        self._threads = kwargs.get('workers', None)
        if self._threads is None:
            self._threads = dataloader_config.workers_per_gpu
        self._single_rate = BASETRAIN_RATIO
        if 'single_rate' in kwargs:
            self._single_rate = kwargs['single_rate']
        self._batch_size = dataloader_config.batch_size_per_gpu
        if 'model_bin' in kwargs:
            model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
            checkpoint = torch.load(model_bin_file)
            self.model.load_state_dict(checkpoint)
        # build corresponding optimizer and loss function
        lr = self.cfg.train.optimizer.lr
        self.optimizer = optim.Adam(self.model.parameters(), lr)
        self.loss_fn = nn.CrossEntropyLoss()
        self.data_val = None
        # NOTE(review): ``self.timestamp`` is expected to be set by
        # BaseTrainer -- confirm it exists before __init__ runs.
        self.json_log_path = os.path.join(self.work_dir,
                                          '{}.log.json'.format(self.timestamp))
        # Bugfix: previously a missing custom_conf crashed later with an
        # opaque TypeError inside update_conf; fail fast with a clear error.
        if custom_conf is None:
            raise ValueError(
                'custom_conf is required to fill the simulation config '
                'templates')
        self.conf_files = []
        for conf_key in self.conf_keys:
            template_file = os.path.join(self.model_dir, conf_key)
            conf_file = os.path.join(self.model_dir, f'{conf_key}.conf')
            update_conf(template_file, conf_file, custom_conf[conf_key])
            self.conf_files.append(conf_file)
        self._current_epoch = 0
        # epochs per difficulty stage
        self.stages = (math.floor(self._max_epochs * EASY_RATIO),
                       math.floor(self._max_epochs * NORMAL_RATIO),
                       math.floor(self._max_epochs * HARD_RATIO))

    def build_model(self) -> nn.Module:
        """ Instantiate a pytorch model and return.

        By default, we will create a model using config from configuration file. You can
        override this method in a subclass.

        """
        model = Model.from_pretrained(
            self.model_dir, cfg_dict=self.cfg, training=True)
        if isinstance(model, TorchModel) and hasattr(model, 'model'):
            return model.model
        elif isinstance(model, nn.Module):
            return model
        # Bugfix: previously fell through and returned None silently.
        raise ValueError(f'Unsupported model type: {type(model)}')

    def train(self, *args, **kwargs):
        """Run all training stages, generating the validation set first."""
        if not self.data_val:
            self.gen_val()
        logger.info('Start training...')
        totaltime = datetime.datetime.now()

        for stage, num_epoch in enumerate(self.stages):
            self.run_stage(stage, num_epoch)

        # total time spent
        totaltime = datetime.datetime.now() - totaltime
        logger.info('Total time spent: {:.2f} hours\n'.format(
            totaltime.total_seconds() / 3600.0))

    def run_stage(self, stage, num_epoch):
        """
        Run training stages with correspond data

        Args:
            stage: id of stage
            num_epoch: the number of epoch to run in this stage
        """
        if num_epoch <= 0:
            logger.warning(f'Invalid epoch number, stage {stage} exit!')
            return
        logger.info(f'Starting stage {stage}...')
        dataset, dataloader = self.create_dataloader(
            self.conf_files[stage * 2], self.conf_files[stage * 2 + 1])
        it = iter(dataloader)
        for _ in range(num_epoch):
            self._current_epoch += 1
            epochtime = datetime.datetime.now()
            logger.info('Start epoch %d...', self._current_epoch)
            loss_train_epoch = 0.0
            validbatchs = 0
            for bi in range(self._train_iters):
                # prepare data
                feat, label = next(it)
                label = torch.reshape(label, (-1, ))
                feat = to_device(feat, self.device)
                label = to_device(label, self.device)
                # apply model
                self.optimizer.zero_grad()
                predict = self.model(feat)
                # calculate loss; skip the update on NaN losses
                loss = self.loss_fn(
                    torch.reshape(predict, (-1, self._num_classes)), label)
                if not np.isnan(loss.item()):
                    loss.backward()
                    self.optimizer.step()
                    loss_train_epoch += loss.item()
                    validbatchs += 1
                train_result = 'Epoch: {:04d}/{:04d}, batch: {:04d}/{:04d}, loss: {:.4f}'.format(
                    self._current_epoch, self._max_epochs, bi + 1,
                    self._train_iters, loss.item())
                logger.info(train_result)
                self._dump_log(train_result)

            # average training loss in one epoch
            # NOTE(review): divides by zero if every batch had NaN loss.
            loss_train_epoch /= validbatchs
            loss_val_epoch = self.evaluate('')
            val_result = 'Evaluate epoch: {:04d}, loss_train: {:.4f}, loss_val: {:.4f}'.format(
                self._current_epoch, loss_train_epoch, loss_val_epoch)
            logger.info(val_result)
            self._dump_log(val_result)
            # check point
            ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
                self._current_epoch, loss_train_epoch, loss_val_epoch)
            torch.save(self.model, os.path.join(self.work_dir, ckpt_name))
            # time spent per epoch
            epochtime = datetime.datetime.now() - epochtime
            logger.info('Epoch {:04d} time spent: {:.2f} hours'.format(
                self._current_epoch,
                epochtime.total_seconds() / 3600.0))
        dataloader.stop()
        dataset.release()
        logger.info(f'Stage {stage} is finished.')

    def gen_val(self):
        """
        generate validation set
        """
        logger.info('Start generating validation set...')
        # conf_files[2]/[3] are the normal-difficulty config pair
        dataset, dataloader = self.create_dataloader(self.conf_files[2],
                                                     self.conf_files[3])
        it = iter(dataloader)

        self.data_val = []
        for bi in range(self._val_iters):
            logger.info('Iterating validation data %d', bi)
            feat, label = next(it)
            label = torch.reshape(label, (-1, ))
            self.data_val.append([feat, label])

        dataloader.stop()
        dataset.release()
        logger.info('Finish generating validation set!')

    def create_dataloader(self, base_path, finetune_path):
        """Build a KWSDataset and a started KWSDataLoader for the config pair."""
        dataset = KWSDataset(base_path, finetune_path, self._threads,
                             self._single_rate, self._num_classes)
        dataloader = KWSDataLoader(
            dataset, batchsize=self._batch_size, numworkers=self._threads)
        dataloader.start()
        return dataset, dataloader

    def evaluate(self, checkpoint_path: str, *args, **kwargs) -> float:
        """Compute the average validation loss over the cached validation set.

        Bugfix: the return annotation used to claim ``Dict[str, float]``
        although a plain float is returned (and consumed as such by
        run_stage()).

        Args:
            checkpoint_path: unused; kept for interface compatibility.
        """
        logger.info('Start validation...')
        loss_val_epoch = 0.0

        with torch.no_grad():
            for feat, label in self.data_val:
                feat = to_device(feat, self.device)
                label = to_device(label, self.device)
                # apply model
                predict = self.model(feat)
                # calculate loss
                loss = self.loss_fn(
                    torch.reshape(predict, (-1, self._num_classes)), label)
                loss_val_epoch += loss.item()
        logger.info('Finish validation.')
        return loss_val_epoch / self._val_iters

    def _dump_log(self, msg):
        """Append one line to the json log file (master process only)."""
        if is_master():
            with open(self.json_log_path, 'a+') as f:
                f.write(msg)
                f.write('\n')

+ 18
- 0
modelscope/utils/audio/audio_utils.py View File

@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import re
import struct import struct
from typing import Union from typing import Union
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -37,6 +38,23 @@ def audio_norm(x):
return x return x




def update_conf(origin_config_file, new_config_file, conf_item: dict):
    """Fill ``${key}`` placeholders in a config template.

    Reads ``origin_config_file`` line by line, replaces each ``${key}``
    placeholder whose key appears in ``conf_item`` with its value, and
    writes the result to ``new_config_file`` (which may be the same path;
    the template is fully read before the output is opened).

    Args:
        origin_config_file: path of the template config file.
        new_config_file: path of the config file to write.
        conf_item: mapping from placeholder name to replacement string.
    """

    def repl(matched):
        key = matched.group(1)
        if key in conf_item:
            return conf_item[key]
        # Bugfix: returning None here made re.sub raise a TypeError for
        # unknown keys; leave unmatched placeholders untouched instead.
        return matched.group(0)

    with open(origin_config_file) as f:
        lines = f.readlines()
    with open(new_config_file, 'w') as f:
        for line in lines:
            f.write(re.sub(r'\$\{(.*)\}', repl, line))


def extract_pcm_from_wav(wav: bytes) -> bytes: def extract_pcm_from_wav(wav: bytes) -> bytes:
data = wav data = wav
if len(data) > 44: if len(data) > 44:


+ 5
- 1
requirements/audio.txt View File

@@ -14,7 +14,11 @@ nltk
numpy<=1.18 numpy<=1.18
# protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged.
protobuf>3,<3.21.0 protobuf>3,<3.21.0
py_sound_connect
ptflops
py_sound_connect>=0.1
pytorch_wavelets
PyWavelets>=1.0.0
scikit-learn
SoundFile>0.10 SoundFile>0.10
sox sox
torchaudio torchaudio


+ 85
- 0
tests/trainers/audio/test_kws_farfield_trainer.py View File

@@ -0,0 +1,85 @@
import os
import shutil
import tempfile
import unittest

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.test_utils import test_level

# Bundled test clips; the same noisy speech clip doubles as the negative,
# noise and interference source.
POS_FILE = 'data/test/audios/wake_word_with_label_xyxy.wav'
NEG_FILE = 'data/test/audios/speech_with_noise.wav'
NOISE_FILE = 'data/test/audios/speech_with_noise.wav'
INTERF_FILE = 'data/test/audios/speech_with_noise.wav'
REF_FILE = 'data/test/audios/farend_speech.wav'
NOISE_2CH_FILE = 'data/test/audios/noise_2ch.wav'


class TestKwsFarfieldTrainer(unittest.TestCase):
    """Smoke test for KWSFarfieldTrainer using the tiny bundled audio clips."""
    REVISION = 'beta'

    def setUp(self):
        # Bugfix: tempfile.TemporaryDirectory().name handed out a path owned
        # by an unreferenced object, so the directory could disappear as soon
        # as that object was garbage collected. mkdtemp() creates a directory
        # that lives until tearDown removes it.
        self.tmp_dir = tempfile.mkdtemp()
        print(f'tmp dir: {self.tmp_dir}')
        self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'

        train_pos_list = self.create_list('pos.list', POS_FILE)
        train_neg_list = self.create_list('neg.list', NEG_FILE)
        train_noise1_list = self.create_list('noise.list', NOISE_FILE)
        train_noise2_list = self.create_list('noise_2ch.list', NOISE_2CH_FILE)
        train_interf_list = self.create_list('interf.list', INTERF_FILE)
        train_ref_list = self.create_list('ref.list', REF_FILE)

        base_dict = dict(
            train_pos_list=train_pos_list,
            train_neg_list=train_neg_list,
            train_noise1_list=train_noise1_list)
        finetune_dict = dict(
            train_pos_list=train_pos_list,
            train_neg_list=train_neg_list,
            train_noise1_list=train_noise1_list,
            train_noise2_type='1',
            train_noise1_ratio='0.2',
            train_noise2_list=train_noise2_list,
            train_interf_list=train_interf_list,
            train_ref_list=train_ref_list)
        # Placeholder values for every config template the trainer fills in.
        self.custom_conf = dict(
            basetrain_easy=base_dict,
            basetrain_normal=base_dict,
            basetrain_hard=base_dict,
            finetune_easy=finetune_dict,
            finetune_normal=finetune_dict,
            finetune_hard=finetune_dict)

    def create_list(self, list_name, audio_file):
        """Write a list file repeating ``audio_file`` ten times and return a
        '<list path>, <weight>' entry pointing at it."""
        list_file = os.path.join(self.tmp_dir, list_name)
        with open(list_file, 'w') as f:
            for _ in range(10):
                f.write(f'{os.path.join(os.getcwd(), audio_file)}\n')
        return f'{list_file}, 1.0'

    def tearDown(self) -> None:
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_normal(self):
        kwargs = dict(
            model=self.model_id,
            work_dir=self.tmp_dir,
            model_revision=self.REVISION,
            workers=2,
            max_epochs=2,
            train_iters_per_epoch=2,
            val_iters_per_epoch=1,
            custom_conf=self.custom_conf)

        trainer = build_trainer(
            Trainers.speech_dfsmn_kws_char_farfield, default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files,
                      f'work_dir:{self.tmp_dir}')

Loading…
Cancel
Save