Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10275823master
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d | |||
size 2327764 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150 | |||
size 68524 |
@@ -285,6 +285,7 @@ class Trainers(object): | |||
# audio trainers | |||
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | |||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||
class Preprocessors(object): | |||
@@ -1,15 +1,14 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from typing import Dict | |||
import torch | |||
from typing import Dict, Optional | |||
from modelscope.metainfo import Models | |||
from modelscope.models import TorchModel | |||
from modelscope.models.base import Tensor | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.audio.audio_utils import update_conf | |||
from modelscope.utils.constant import Tasks | |||
from .fsmn_sele_v2 import FSMNSeleNetV2 | |||
@@ -20,48 +19,38 @@ class FSMNSeleNetV2Decorator(TorchModel): | |||
MODEL_TXT = 'model.txt' | |||
SC_CONFIG = 'sound_connect.conf' | |||
SC_CONF_ITEM_KWS_MODEL = '${kws_model}' | |||
def __init__(self,
             model_dir: str,
             training: Optional[bool] = False,
             *args,
             **kwargs):
    """Initialize the dfsmn model from the `model_dir` path.

    In training mode only the trainable ``FSMNSeleNetV2`` network is built
    (stored in ``self.model``). Otherwise the SoundConnect runtime is
    created from ``sound_connect.conf`` and ``model.txt`` found in
    `model_dir` (stored in ``self._sc``).

    Args:
        model_dir (str): the model path.
        training (bool): build the trainable network instead of the
            SoundConnect inference engine.

    Raises:
        Exception: in inference mode, when ``model.txt`` is missing.
    """
    super().__init__(model_dir, *args, **kwargs)
    # Always defined, so callers can test it in either mode.
    self._sc = None
    if training:
        self.model = FSMNSeleNetV2(*args, **kwargs)
        return

    sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
    model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
    if not os.path.exists(model_txt_file):
        raise Exception(
            f'Invalid model directory! Failed to load model file: {model_txt_file}.'
        )
    # Values are substituted into the config through re.sub, which only
    # accepts str replacements, so the mode is passed as text.
    conf_dict = dict(mode='56542', kws_model=model_txt_file)
    # The config file is rewritten in place with placeholders resolved.
    update_conf(sc_config_file, sc_config_file, conf_dict)
    import py_sound_connect
    self._sc = py_sound_connect.SoundConnect(sc_config_file)
    self.size_in = self._sc.bytesPerBlockIn()
    self.size_out = self._sc.bytesPerBlockOut()
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
... | |||
return self.model.forward(input) | |||
def forward_decode(self, data: bytes): | |||
result = {'pcm': self._sc.process(data, self.size_out)} | |||
@@ -0,0 +1,21 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if not TYPE_CHECKING:
    import sys

    # Submodule name -> public names it provides; handed to LazyImportModule
    # together with this package's name, file and spec.
    _import_structure = {
        'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'],
    }

    # Replace this package's entry in sys.modules with the lazy module.
    _lazy_module = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
    sys.modules[__name__] = _lazy_module
else:
    # Real imports, evaluated only when TYPE_CHECKING is true.
    from .kws_farfield_dataset import KWSDataset, KWSDataLoader
@@ -0,0 +1,280 @@ | |||
""" | |||
Used to prepare simulated data. | |||
""" | |||
import math | |||
import os.path | |||
import queue | |||
import threading | |||
import time | |||
import numpy as np | |||
import torch | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
BLOCK_DEC = 2 | |||
BLOCK_CAT = 3 | |||
FBANK_SIZE = 40 | |||
LABEL_SIZE = 1 | |||
LABEL_GAIN = 100.0 | |||
class KWSDataset:
    """Dataset for keyword spotting and vad, backed by SoundConnect simulators.

    One ``py_sound_connect.FeatSimuKWS`` simulator is created per worker;
    worker ids below the number of basetrain simulators use the basetrain
    configuration, the remaining ids use the finetune configuration.

    Args:
        conf_basetrain: basetrain configure file path
        conf_finetune: finetune configure file path, None allowed (then all
            workers use the basetrain configuration)
        numworkers: no. of workers
        basetrainratio: basetrain workers ratio
        numclasses: no. of nn output classes, 2 classes to generate vad label
        blockdec: block decimation
        blockcat: block concatenation
    """

    # NOTE(review): blockdec defaults to BLOCK_CAT although a BLOCK_DEC
    # constant exists in this module - confirm which default is intended.
    def __init__(self,
                 conf_basetrain,
                 conf_finetune,
                 numworkers,
                 basetrainratio,
                 numclasses,
                 blockdec=BLOCK_CAT,
                 blockcat=BLOCK_CAT):
        super().__init__()
        self.numclasses = numclasses
        self.blockdec = blockdec
        self.blockcat = blockcat
        self.sims_base = []
        self.sims_senior = []
        self.base_conf = None
        self.senior_conf = None
        self.setup_sims(conf_basetrain, conf_finetune, numworkers,
                        basetrainratio)

    def release(self):
        """Drop all simulators and the configuration objects they use.

        Safe to call more than once.
        """
        # Simulators are built from the conf params, so drop them first.
        self.sims_base.clear()
        self.sims_senior.clear()
        self.base_conf = None
        self.senior_conf = None
        logger.info('KWSDataset: Released.')

    def setup_sims(self, conf_basetrain, conf_finetune, numworkers,
                   basetrainratio):
        """Create the per-worker simulators.

        Args:
            conf_basetrain: basetrain configure file path
            conf_finetune: finetune configure file path, or None
            numworkers: total number of simulators to create
            basetrainratio: share of simulators using the basetrain conf

        Raises:
            ValueError: if a given configuration file does not exist.
        """
        if not os.path.exists(conf_basetrain):
            raise ValueError(f'{conf_basetrain} does not exist!')
        if conf_finetune is not None and not os.path.exists(conf_finetune):
            raise ValueError(f'{conf_finetune} does not exist!')

        import py_sound_connect
        logger.info('KWSDataset init SoundConnect...')
        if conf_finetune is None:
            # Without a finetune configuration every worker is basetrain.
            num_base = numworkers
        else:
            num_base = math.ceil(numworkers * basetrainratio)
        num_senior = numworkers - num_base
        # hold by fields to avoid python releasing conf object
        self.base_conf = py_sound_connect.ConfigFile(conf_basetrain)
        if conf_finetune is not None:
            self.senior_conf = py_sound_connect.ConfigFile(conf_finetune)
        for _ in range(num_base):
            self.sims_base.append(
                py_sound_connect.FeatSimuKWS(self.base_conf.params))
        for _ in range(num_senior):
            self.sims_senior.append(
                py_sound_connect.FeatSimuKWS(self.senior_conf.params))
        logger.info('KWSDataset init SoundConnect finished.')

    def getBatch(self, id):
        """Generate a data batch.

        Args:
            id: worker id

        Return: time x channel x feature, label
        """
        fs = self.get_sim(id)
        fs.processBatch()
        # get multi-channel feature vector size
        featsize = fs.featSize()
        # get label vector size
        labelsize = fs.labelSize()
        # no. of fe output channels
        numchs = featsize // FBANK_SIZE
        # get raw data: one row per frame, features followed by label
        data = np.frombuffer(fs.feat(), dtype='float32')
        data = data.reshape((-1, featsize + labelsize))

        # Work on a copy: the frombuffer view may be read-only, and writing
        # through it would otherwise modify the simulator's buffer.
        label = data[:, FBANK_SIZE * numchs:].copy()
        if self.numclasses == 2:
            # generate vad label
            label[label > 0.0] = 1.0
        else:
            # generate kws label
            label = np.round(label * LABEL_GAIN)
            label[label > self.numclasses - 1] = 0.0

        # decimated size
        size1 = int(np.ceil(
            label.shape[0] / self.blockdec)) - self.blockcat + 1

        # label decimation: output frame tau takes the label of input frame
        # (tau + blockcat // 2) * blockdec
        label1 = np.zeros((size1, LABEL_SIZE), dtype='float32')
        label_start = (self.blockcat // 2) * self.blockdec
        label1[:] = label[label_start::self.blockdec][:size1]

        # feature decimation and concatenation
        # time x channel x feature
        feats = data[:, :FBANK_SIZE * numchs].reshape(
            data.shape[0], numchs, FBANK_SIZE)
        featall = np.zeros((size1, numchs, FBANK_SIZE * self.blockcat),
                           dtype='float32')
        for i in range(self.blockcat):
            # i-th concatenated block of frame tau is input frame
            # (tau + i) * blockdec
            featall[:, :, FBANK_SIZE * i:FBANK_SIZE * (i + 1)] = \
                feats[i * self.blockdec::self.blockdec][:size1]

        return torch.from_numpy(featall), torch.from_numpy(label1).long()

    def get_sim(self, id):
        """Return the simulator assigned to worker `id`."""
        num_base = len(self.sims_base)
        if id < num_base:
            return self.sims_base[id]
        return self.sims_senior[id - num_base]
class Worker(threading.Thread):
    """Thread that keeps fetching batches from a dataset into a shared pool.

    Args:
        id: worker id
        dataset: the dataset
        pool: queue as the global data buffer
    """

    def __init__(self, id, dataset, pool):
        super().__init__()
        self.id, self.dataset, self.pool = id, dataset, pool
        # cleared by stopWorker() to end the run loop
        self.isrun = True
        # number of loop iterations started so far
        self.nn = 0

    def run(self):
        """Produce batches until the worker is asked to stop."""
        while True:
            if not self.isrun:
                break
            self.nn += 1
            logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:1')
            # fetch one simulated minibatch
            if self.isrun:
                data = self.dataset.getBatch(self.id)
            logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:2')
            # hand the minibatch over to the shared buffer
            if self.isrun:
                self.pool.put(data)
            logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:3')
        logger.info('KWSDataLoader: Worker {:02d} stopped.'.format(self.id))

    def stopWorker(self):
        """Ask the worker thread to stop."""
        self.isrun = False
class KWSDataLoader:
    """Multi-threaded loader grouping worker minibatches into batches.

    Args:
        dataset: the dataset reference
        batchsize: data batch size
        numworkers: no. of workers
        prefetch: prefetch factor
    """

    def __init__(self, dataset, batchsize, numworkers, prefetch=2):
        self.dataset = dataset
        self.batchsize = batchsize
        # shape key -> minibatches collected so far for that shape
        self.datamap = {}
        self.isrun = True
        # data queue
        self.pool = queue.Queue(batchsize * prefetch)
        # initialize workers
        self.workerlist = [
            Worker(id, dataset, self.pool) for id in range(numworkers)
        ]

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next (feature, label) batch.

        Minibatches are grouped by feature shape; a batch is returned once
        `batchsize` minibatches of one shape are collected. Returns
        (None, None) after the loader has been stopped.
        """
        while self.isrun:
            # get data from common data pool
            data = self.pool.get()
            self.pool.task_done()
            # group minibatches with the same shape
            key = str(data[0].shape)
            batchl = self.datamap.setdefault(key, [])
            batchl.append(data)
            # a full data batch collected
            if len(batchl) >= self.batchsize:
                featbatch, labelbatch = zip(*batchl)
                batchl.clear()
                feattensor = torch.stack(featbatch, dim=0)
                labeltensor = torch.stack(labelbatch, dim=0)
                if feattensor.shape[-2] == 1:
                    logger.debug('KWSDataLoader: Basetrain batch.')
                else:
                    logger.debug('KWSDataLoader: Finetune batch.')
                return feattensor, labeltensor
        return None, None

    def start(self):
        """
        start multi-thread data loader
        """
        for w in self.workerlist:
            w.start()

    def _drain(self):
        """Discard everything currently queued without ever blocking."""
        while True:
            try:
                self.pool.get_nowait()
            except queue.Empty:
                return

    def stop(self):
        """
        stop data loader
        """
        logger.info('KWSDataLoader: Stopping...')
        self.isrun = False
        for w in self.workerlist:
            w.stopWorker()
        # wait workers terminated; keep draining so a worker blocked on a
        # full queue can finish its put. Workers that were never started
        # are skipped (joining them would raise RuntimeError).
        for w in self.workerlist:
            while w.is_alive():
                self._drain()
                w.join(timeout=0.01)
        self._drain()
        logger.info('KWSDataLoader: All worker stopped.')
@@ -0,0 +1,279 @@ | |||
import datetime | |||
import math | |||
import os | |||
from typing import Callable, Dict, Optional | |||
import numpy as np | |||
import torch | |||
from torch import nn as nn | |||
from torch import optim as optim | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metainfo import Trainers | |||
from modelscope.models import Model, TorchModel | |||
from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset | |||
from modelscope.trainers.base import BaseTrainer | |||
from modelscope.trainers.builder import TRAINERS | |||
from modelscope.utils.audio.audio_utils import update_conf | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
from modelscope.utils.data_utils import to_device | |||
from modelscope.utils.device import create_device | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, | |||
init_dist, is_master) | |||
logger = get_logger() | |||
BASETRAIN_CONF_EASY = 'basetrain_easy' | |||
BASETRAIN_CONF_NORMAL = 'basetrain_normal' | |||
BASETRAIN_CONF_HARD = 'basetrain_hard' | |||
FINETUNE_CONF_EASY = 'finetune_easy' | |||
FINETUNE_CONF_NORMAL = 'finetune_normal' | |||
FINETUNE_CONF_HARD = 'finetune_hard' | |||
EASY_RATIO = 0.1 | |||
NORMAL_RATIO = 0.6 | |||
HARD_RATIO = 0.3 | |||
BASETRAIN_RATIO = 0.5 | |||
@TRAINERS.register_module(module_name=Trainers.speech_dfsmn_kws_char_farfield)
class KWSFarfieldTrainer(BaseTrainer):
    """Trainer for the far-field DFSMN keyword spotting model.

    Training runs in three stages (easy, normal, hard). Each stage uses a
    basetrain and a finetune simulation configuration, generated from the
    templates in the model directory with the values of `custom_conf`.

    Args:
        model: model id, or a local model directory / file path.
        work_dir: directory receiving the log file and checkpoints.
        cfg_file: configuration file; defaults to the one in the model dir.
        arg_parse_fn: forwarded to ``BaseTrainer``.
        model_revision: revision used when the model is downloaded.
        custom_conf: mapping from each name in ``conf_keys`` to the
            placeholder values for that configuration template.
        **kwargs: optional overrides read here: ``launcher``, ``device``,
            ``max_epochs``, ``train_iters_per_epoch``,
            ``val_iters_per_epoch``, ``workers``, ``single_rate``,
            ``model_bin``.
    """
    DEFAULT_WORK_DIR = './work_dir'
    conf_keys = (BASETRAIN_CONF_EASY, FINETUNE_CONF_EASY,
                 BASETRAIN_CONF_NORMAL, FINETUNE_CONF_NORMAL,
                 BASETRAIN_CONF_HARD, FINETUNE_CONF_HARD)

    def __init__(self,
                 model: str,
                 work_dir: str,
                 cfg_file: Optional[str] = None,
                 arg_parse_fn: Optional[Callable] = None,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 custom_conf: Optional[dict] = None,
                 **kwargs):
        if isinstance(model, str):
            if os.path.exists(model):
                self.model_dir = model if os.path.isdir(
                    model) else os.path.dirname(model)
            else:
                self.model_dir = snapshot_download(
                    model, revision=model_revision)
            if cfg_file is None:
                cfg_file = os.path.join(self.model_dir,
                                        ModelFile.CONFIGURATION)
        else:
            assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
            self.model_dir = os.path.dirname(cfg_file)

        super().__init__(cfg_file, arg_parse_fn)

        self.model = self.build_model()
        self.work_dir = work_dir
        # The log file and checkpoints are written here during training.
        os.makedirs(self.work_dir, exist_ok=True)
        # the number of model output dimension
        # should update config outside the trainer, if user need more wake word
        self._num_classes = self.cfg.model.num_syn

        if kwargs.get('launcher', None) is not None:
            init_dist(kwargs['launcher'])

        _, world_size = get_dist_info()
        self._dist = world_size > 1

        device_name = kwargs.get('device', 'gpu')
        if self._dist:
            local_rank = get_local_rank()
            device_name = f'cuda:{local_rank}'

        self.device = create_device(device_name)
        # model placement
        if self.device.type == 'cuda':
            self.model.to(self.device)

        if 'max_epochs' not in kwargs:
            assert hasattr(
                self.cfg.train, 'max_epochs'
            ), 'max_epochs is missing from the configuration file'
            self._max_epochs = self.cfg.train.max_epochs
        else:
            self._max_epochs = kwargs['max_epochs']
        self._train_iters = kwargs.get('train_iters_per_epoch', None)
        self._val_iters = kwargs.get('val_iters_per_epoch', None)
        if self._train_iters is None:
            self._train_iters = self.cfg.train.train_iters_per_epoch
        if self._val_iters is None:
            self._val_iters = self.cfg.evaluation.val_iters_per_epoch
        dataloader_config = self.cfg.train.dataloader
        self._threads = kwargs.get('workers', None)
        if self._threads is None:
            self._threads = dataloader_config.workers_per_gpu
        self._single_rate = kwargs.get('single_rate', BASETRAIN_RATIO)
        self._batch_size = dataloader_config.batch_size_per_gpu
        if 'model_bin' in kwargs:
            model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
            # Load onto CPU first; load_state_dict copies the values into
            # the parameters wherever the model lives.
            checkpoint = torch.load(model_bin_file, map_location='cpu')
            self.model.load_state_dict(checkpoint)
        # build corresponding optimizer and loss function
        lr = self.cfg.train.optimizer.lr
        self.optimizer = optim.Adam(self.model.parameters(), lr)
        self.loss_fn = nn.CrossEntropyLoss()
        self.data_val = None
        self.json_log_path = os.path.join(self.work_dir,
                                          '{}.log.json'.format(self.timestamp))
        # custom_conf defaults to None; treat that (and a missing entry)
        # as "no placeholder values" instead of failing on the lookup.
        custom_conf = custom_conf or {}
        self.conf_files = []
        for conf_key in self.conf_keys:
            template_file = os.path.join(self.model_dir, conf_key)
            conf_file = os.path.join(self.model_dir, f'{conf_key}.conf')
            update_conf(template_file, conf_file,
                        custom_conf.get(conf_key) or {})
            self.conf_files.append(conf_file)
        self._current_epoch = 0
        # Easy and normal stages get their floored share; the hard stage
        # takes the remainder so the stages always add up to max_epochs.
        easy_epochs = math.floor(self._max_epochs * EASY_RATIO)
        normal_epochs = math.floor(self._max_epochs * NORMAL_RATIO)
        hard_epochs = self._max_epochs - easy_epochs - normal_epochs
        self.stages = (easy_epochs, normal_epochs, hard_epochs)

    def build_model(self) -> nn.Module:
        """ Instantiate a pytorch model and return.

        By default, we will create a model using config from configuration file. You can
        override this method in a subclass.

        Raises:
            TypeError: if the loaded object is neither a TorchModel that
                exposes a ``model`` attribute nor an ``nn.Module``.
        """
        model = Model.from_pretrained(
            self.model_dir, cfg_dict=self.cfg, training=True)
        if isinstance(model, TorchModel) and hasattr(model, 'model'):
            return model.model
        if isinstance(model, nn.Module):
            return model
        raise TypeError(
            f'Unsupported model type for training: {type(model).__name__}')

    def train(self, *args, **kwargs):
        """Run all training stages, generating the validation set first."""
        if not self.data_val:
            self.gen_val()
        logger.info('Start training...')
        totaltime = datetime.datetime.now()

        for stage, num_epoch in enumerate(self.stages):
            self.run_stage(stage, num_epoch)

        # total time spent
        totaltime = datetime.datetime.now() - totaltime
        logger.info('Total time spent: {:.2f} hours\n'.format(
            totaltime.total_seconds() / 3600.0))

    def run_stage(self, stage, num_epoch):
        """
        Run training stages with correspond data

        Args:
            stage: id of stage
            num_epoch: the number of epoch to run in this stage
        """
        if num_epoch <= 0:
            logger.warning(f'Invalid epoch number, stage {stage} exit!')
            return
        logger.info(f'Starting stage {stage}...')
        dataset, dataloader = self.create_dataloader(
            self.conf_files[stage * 2], self.conf_files[stage * 2 + 1])
        try:
            it = iter(dataloader)
            for _ in range(num_epoch):
                self._run_epoch(it)
        finally:
            # Always stop the worker threads, even when training fails.
            dataloader.stop()
            dataset.release()
        logger.info(f'Stage {stage} is finished.')

    def _run_epoch(self, it):
        """Train one epoch from iterator `it`, validate, and checkpoint."""
        self._current_epoch += 1
        epochtime = datetime.datetime.now()
        logger.info('Start epoch %d...', self._current_epoch)
        loss_train_epoch = 0.0
        validbatchs = 0
        for bi in range(self._train_iters):
            # prepare data
            feat, label = next(it)
            label = torch.reshape(label, (-1, ))
            feat = to_device(feat, self.device)
            label = to_device(label, self.device)
            # apply model
            self.optimizer.zero_grad()
            predict = self.model(feat)
            # calculate loss
            loss = self.loss_fn(
                torch.reshape(predict, (-1, self._num_classes)), label)
            loss_value = loss.item()
            # batches with a NaN loss are logged but not used for the update
            if not math.isnan(loss_value):
                loss.backward()
                self.optimizer.step()
                loss_train_epoch += loss_value
                validbatchs += 1
            train_result = 'Epoch: {:04d}/{:04d}, batch: {:04d}/{:04d}, loss: {:.4f}'.format(
                self._current_epoch, self._max_epochs, bi + 1,
                self._train_iters, loss_value)
            logger.info(train_result)
            self._dump_log(train_result)

        # average training loss in one epoch; NaN when no batch was usable
        if validbatchs > 0:
            loss_train_epoch /= validbatchs
        else:
            loss_train_epoch = float('nan')
        loss_val_epoch = self.evaluate('')
        val_result = 'Evaluate epoch: {:04d}, loss_train: {:.4f}, loss_val: {:.4f}'.format(
            self._current_epoch, loss_train_epoch, loss_val_epoch)
        logger.info(val_result)
        self._dump_log(val_result)
        # check point
        # NOTE(review): the whole module object is saved here, while the
        # 'model_bin' option of __init__ loads a state dict - confirm.
        ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
            self._current_epoch, loss_train_epoch, loss_val_epoch)
        torch.save(self.model, os.path.join(self.work_dir, ckpt_name))
        # time spent per epoch
        epochtime = datetime.datetime.now() - epochtime
        logger.info('Epoch {:04d} time spent: {:.2f} hours'.format(
            self._current_epoch,
            epochtime.total_seconds() / 3600.0))

    def gen_val(self):
        """
        generate validation set
        """
        logger.info('Start generating validation set...')
        # conf_files[2] / conf_files[3] are the normal-stage configurations
        dataset, dataloader = self.create_dataloader(self.conf_files[2],
                                                     self.conf_files[3])
        try:
            it = iter(dataloader)
            self.data_val = []
            for bi in range(self._val_iters):
                logger.info('Iterating validation data %d', bi)
                feat, label = next(it)
                label = torch.reshape(label, (-1, ))
                self.data_val.append([feat, label])
        finally:
            dataloader.stop()
            dataset.release()
        logger.info('Finish generating validation set!')

    def create_dataloader(self, base_path, finetune_path):
        """Build a dataset and a started loader for the given conf files.

        The caller must call ``stop()`` on the loader and ``release()`` on
        the dataset when done.
        """
        dataset = KWSDataset(base_path, finetune_path, self._threads,
                             self._single_rate, self._num_classes)
        dataloader = KWSDataLoader(
            dataset, batchsize=self._batch_size, numworkers=self._threads)
        dataloader.start()
        return dataset, dataloader

    def evaluate(self, checkpoint_path: str, *args, **kwargs) -> float:
        """Return the mean loss over the cached validation batches.

        `checkpoint_path` is not used; the current model is evaluated. The
        validation set is generated on first use. Returns NaN when there
        are no validation batches.
        """
        logger.info('Start validation...')
        if self.data_val is None:
            self.gen_val()
        if not self.data_val:
            logger.info('Finish validation.')
            return float('nan')
        loss_val_epoch = 0.0

        # NOTE(review): the model is not switched to eval mode here -
        # confirm this is intended.
        with torch.no_grad():
            for feat, label in self.data_val:
                feat = to_device(feat, self.device)
                label = to_device(label, self.device)
                # apply model
                predict = self.model(feat)
                # calculate loss
                loss = self.loss_fn(
                    torch.reshape(predict, (-1, self._num_classes)), label)
                loss_val_epoch += loss.item()
        logger.info('Finish validation.')
        return loss_val_epoch / len(self.data_val)

    def _dump_log(self, msg):
        """Append `msg` as one line to the log file (master process only)."""
        if is_master():
            with open(self.json_log_path, 'a+') as f:
                f.write(msg)
                f.write('\n')
@@ -1,4 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import re | |||
import struct | |||
from typing import Union | |||
from urllib.parse import urlparse | |||
@@ -37,6 +38,23 @@ def audio_norm(x): | |||
return x | |||
def update_conf(origin_config_file, new_config_file, conf_item: dict):
    """Fill the ``${key}`` placeholders of a config file.

    Every ``${key}`` in `origin_config_file` is replaced by the string form
    of ``conf_item[key]``. A placeholder whose key is missing from
    `conf_item`, or whose value is None, is replaced by an empty string.
    The result is written to `new_config_file`, which may be the same path
    as the input: the input is read completely before the output is opened.

    Args:
        origin_config_file: path of the template file to read.
        new_config_file: path of the file to write.
        conf_item: mapping from placeholder name to replacement value;
            None is treated as an empty mapping.
    """
    values = conf_item or {}
    # Non-greedy, so several placeholders on one line are matched one by one.
    placeholder = re.compile(r'\$\{(.*?)\}')

    def repl(matched):
        value = values.get(matched.group(1))
        # re.sub only accepts str replacements, so convert explicitly.
        return '' if value is None else str(value)

    with open(origin_config_file) as f:
        lines = f.readlines()

    with open(new_config_file, 'w') as f:
        f.writelines(placeholder.sub(repl, line) for line in lines)
def extract_pcm_from_wav(wav: bytes) -> bytes: | |||
data = wav | |||
if len(data) > 44: | |||
@@ -14,7 +14,11 @@ nltk | |||
numpy<=1.18 | |||
# protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. | |||
protobuf>3,<3.21.0 | |||
py_sound_connect | |||
ptflops | |||
py_sound_connect>=0.1 | |||
pytorch_wavelets | |||
PyWavelets>=1.0.0 | |||
scikit-learn | |||
SoundFile>0.10 | |||
sox | |||
torchaudio | |||
@@ -0,0 +1,85 @@ | |||
import os | |||
import shutil | |||
import tempfile | |||
import unittest | |||
from modelscope.metainfo import Trainers | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.test_utils import test_level | |||
POS_FILE = 'data/test/audios/wake_word_with_label_xyxy.wav' | |||
NEG_FILE = 'data/test/audios/speech_with_noise.wav' | |||
NOISE_FILE = 'data/test/audios/speech_with_noise.wav' | |||
INTERF_FILE = 'data/test/audios/speech_with_noise.wav' | |||
REF_FILE = 'data/test/audios/farend_speech.wav' | |||
NOISE_2CH_FILE = 'data/test/audios/noise_2ch.wav' | |||
class TestKwsFarfieldTrainer(unittest.TestCase):
    """Smoke test: train the far-field KWS model for a few iterations."""
    REVISION = 'beta'

    def setUp(self):
        # mkdtemp creates the directory and leaves cleanup to tearDown.
        self.tmp_dir = tempfile.mkdtemp()
        print(f'tmp dir: {self.tmp_dir}')
        self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'

        train_pos_list = self.create_list('pos.list', POS_FILE)
        train_neg_list = self.create_list('neg.list', NEG_FILE)
        train_noise1_list = self.create_list('noise.list', NOISE_FILE)
        train_noise2_list = self.create_list('noise_2ch.list', NOISE_2CH_FILE)
        train_interf_list = self.create_list('interf.list', INTERF_FILE)
        train_ref_list = self.create_list('ref.list', REF_FILE)

        base_dict = dict(
            train_pos_list=train_pos_list,
            train_neg_list=train_neg_list,
            train_noise1_list=train_noise1_list)
        finetune_dict = dict(
            train_pos_list=train_pos_list,
            train_neg_list=train_neg_list,
            train_noise1_list=train_noise1_list,
            train_noise2_type='1',
            train_noise1_ratio='0.2',
            train_noise2_list=train_noise2_list,
            train_interf_list=train_interf_list,
            train_ref_list=train_ref_list)
        # one entry per configuration template used by the trainer
        self.custom_conf = dict(
            basetrain_easy=base_dict,
            basetrain_normal=base_dict,
            basetrain_hard=base_dict,
            finetune_easy=finetune_dict,
            finetune_normal=finetune_dict,
            finetune_hard=finetune_dict)

    def create_list(self, list_name, audio_file):
        """Write a list file repeating `audio_file` ten times.

        Returns the config value '<list file path>, 1.0'.
        """
        list_file = os.path.join(self.tmp_dir, list_name)
        audio_path = os.path.join(os.getcwd(), audio_file)
        with open(list_file, 'w') as f:
            f.writelines(f'{audio_path}\n' for _ in range(10))
        return f'{list_file}, 1.0'

    def tearDown(self) -> None:
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_normal(self):
        kwargs = dict(
            model=self.model_id,
            work_dir=self.tmp_dir,
            model_revision=self.REVISION,
            workers=2,
            max_epochs=2,
            train_iters_per_epoch=2,
            val_iters_per_epoch=1,
            custom_conf=self.custom_conf)

        trainer = build_trainer(
            Trainers.speech_dfsmn_kws_char_farfield, default_args=kwargs)
        trainer.train()
        # the trainer writes its log file into work_dir
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files,
                      f'work_dir:{self.tmp_dir}')