Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10275823master
@@ -0,0 +1,3 @@ | |||||
version https://git-lfs.github.com/spec/v1 | |||||
oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d | |||||
size 2327764 |
@@ -0,0 +1,3 @@ | |||||
version https://git-lfs.github.com/spec/v1 | |||||
oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150 | |||||
size 68524 |
@@ -285,6 +285,7 @@ class Trainers(object): | |||||
# audio trainers | # audio trainers | ||||
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | ||||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||||
class Preprocessors(object): | class Preprocessors(object): | ||||
@@ -1,15 +1,14 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import os | import os | ||||
from typing import Dict | |||||
import torch | |||||
from typing import Dict, Optional | |||||
from modelscope.metainfo import Models | from modelscope.metainfo import Models | ||||
from modelscope.models import TorchModel | from modelscope.models import TorchModel | ||||
from modelscope.models.base import Tensor | from modelscope.models.base import Tensor | ||||
from modelscope.models.builder import MODELS | from modelscope.models.builder import MODELS | ||||
from modelscope.utils.constant import ModelFile, Tasks | |||||
from modelscope.utils.audio.audio_utils import update_conf | |||||
from modelscope.utils.constant import Tasks | |||||
from .fsmn_sele_v2 import FSMNSeleNetV2 | from .fsmn_sele_v2 import FSMNSeleNetV2 | ||||
@@ -20,48 +19,38 @@ class FSMNSeleNetV2Decorator(TorchModel): | |||||
MODEL_TXT = 'model.txt' | MODEL_TXT = 'model.txt' | ||||
SC_CONFIG = 'sound_connect.conf' | SC_CONFIG = 'sound_connect.conf' | ||||
SC_CONF_ITEM_KWS_MODEL = '${kws_model}' | |||||
    def __init__(self,
                 model_dir: str,
                 training: Optional[bool] = False,
                 *args,
                 **kwargs):
        """Initialize the dfsmn model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
            training (bool, optional): when True, build the raw
                FSMNSeleNetV2 network for training; otherwise load the
                sound-connect inference engine from the exported model.txt.
        """
        super().__init__(model_dir, *args, **kwargs)
        if training:
            self.model = FSMNSeleNetV2(*args, **kwargs)
        else:
            sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
            model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
            self._sc = None
            if os.path.exists(model_txt_file):
                # fill the ${kws_model} (and ${mode}) placeholders in the
                # sound_connect config, rewriting the config file in place
                conf_dict = dict(mode=56542, kws_model=model_txt_file)
                update_conf(sc_config_file, sc_config_file, conf_dict)
                # deferred import: py_sound_connect is only required for
                # inference, so training environments need not install it
                import py_sound_connect
                self._sc = py_sound_connect.SoundConnect(sc_config_file)
                # bytes consumed / produced per processing block
                self.size_in = self._sc.bytesPerBlockIn()
                self.size_out = self._sc.bytesPerBlockOut()
            else:
                raise Exception(
                    f'Invalid model directory! Failed to load model file: {model_txt_file}.'
                )
    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        # NOTE(review): `self.model` is only assigned when the decorator was
        # constructed with training=True; calling forward() on an
        # inference-mode instance raises AttributeError — confirm callers
        # only invoke this path during training.
        return self.model.forward(input)
def forward_decode(self, data: bytes): | def forward_decode(self, data: bytes): | ||||
result = {'pcm': self._sc.process(data, self.size_out)} | result = {'pcm': self._sc.process(data, self.size_out)} | ||||
@@ -0,0 +1,21 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Real imports for static type checkers only; at runtime the module is
    # replaced by a lazy proxy below to avoid importing heavy dependencies.
    from .kws_farfield_dataset import KWSDataset, KWSDataLoader
else:
    # Map submodule name -> public symbols it provides, resolved on first use.
    _import_structure = {
        'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'],
    }

    import sys

    # Replace this module object with a LazyImportModule so attribute access
    # triggers the actual submodule import on demand.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,280 @@ | |||||
""" | |||||
Used to prepare simulated data. | |||||
""" | |||||
import math | |||||
import os.path | |||||
import queue | |||||
import threading | |||||
import time | |||||
import numpy as np | |||||
import torch | |||||
from modelscope.utils.logger import get_logger | |||||
logger = get_logger() | |||||
# Decimation factor along the time axis (input frames per decimated step).
BLOCK_DEC = 2
# Number of consecutive decimated blocks concatenated into one output frame.
BLOCK_CAT = 3
# Size of a single-channel fbank feature vector.
FBANK_SIZE = 40
# Label vector size per frame.
LABEL_SIZE = 1
# Scale used to recover integer kws class ids from float simulator labels.
LABEL_GAIN = 100.0


class KWSDataset:
    """Dataset producing simulated far-field minibatches for KWS / VAD.

    Args:
        conf_basetrain: basetrain configure file path
        conf_finetune: finetune configure file path
        numworkers: total number of simulation workers
        basetrainratio: fraction of workers running basetrain simulation
        numclasses: number of nn output classes; 2 classes generates a VAD
            label instead of a kws label
        blockdec: block decimation factor along time
        blockcat: number of blocks concatenated per output frame
    """

    def __init__(self,
                 conf_basetrain,
                 conf_finetune,
                 numworkers,
                 basetrainratio,
                 numclasses,
                 blockdec=BLOCK_DEC,
                 blockcat=BLOCK_CAT):
        # FIX: the decimation default was BLOCK_CAT (3); BLOCK_DEC (2) is the
        # decimation constant and was otherwise unused — a copy-paste typo.
        super().__init__()
        self.numclasses = numclasses
        self.blockdec = blockdec
        self.blockcat = blockcat
        self.sims_base = []
        self.sims_senior = []
        self.setup_sims(conf_basetrain, conf_finetune, numworkers,
                        basetrainratio)

    def release(self):
        """Drop references to all simulators and config objects."""
        # FIX: `del sim` inside a for-loop only unbinds the loop variable and
        # never releases anything; clear the containers instead so the
        # simulator objects can actually be garbage collected.
        self.sims_base.clear()
        self.sims_senior.clear()
        del self.base_conf
        del self.senior_conf
        logger.info('KWSDataset: Released.')

    def setup_sims(self, conf_basetrain, conf_finetune, numworkers,
                   basetrainratio):
        """Create the basetrain / finetune simulator pools.

        Raises:
            ValueError: if either config file does not exist.
        """
        if not os.path.exists(conf_basetrain):
            raise ValueError(f'{conf_basetrain} does not exist!')
        if not os.path.exists(conf_finetune):
            raise ValueError(f'{conf_finetune} does not exist!')
        import py_sound_connect
        logger.info('KWSDataset init SoundConnect...')
        num_base = math.ceil(numworkers * basetrainratio)
        num_senior = numworkers - num_base
        # hold by fields to avoid python releasing conf object
        self.base_conf = py_sound_connect.ConfigFile(conf_basetrain)
        self.senior_conf = py_sound_connect.ConfigFile(conf_finetune)
        for i in range(num_base):
            self.sims_base.append(
                py_sound_connect.FeatSimuKWS(self.base_conf.params))
        for i in range(num_senior):
            self.sims_senior.append(
                py_sound_connect.FeatSimuKWS(self.senior_conf.params))
        logger.info('KWSDataset init SoundConnect finished.')

    def getBatch(self, id):
        """Generate one simulated data batch.

        Args:
            id: worker id selecting the simulator.

        Returns:
            tuple: (time x channel x feature float tensor, long label tensor)
        """
        fs = self.get_sim(id)
        fs.processBatch()
        # multi-channel feature vector size
        featsize = fs.featSize()
        # label vector size
        labelsize = fs.labelSize()
        # number of front-end output channels
        numchs = featsize // FBANK_SIZE
        # raw simulator output: rows of [feat_ch0 ... feat_chN, label]
        data = np.frombuffer(fs.feat(), dtype='float32')
        data = data.reshape((-1, featsize + labelsize))
        label = data[:, FBANK_SIZE * numchs:]
        if self.numclasses == 2:
            # VAD label: any positive activity becomes 1.0.
            # FIX: build a new array instead of writing into the frombuffer
            # view — frombuffer over an immutable buffer is read-only, and an
            # in-place assignment would raise ValueError.
            label = np.where(label > 0.0, np.float32(1.0), label)
        else:
            # kws label: recover integer class ids; out-of-range -> filler 0
            label = np.round(label * LABEL_GAIN)
            label[label > self.numclasses - 1] = 0.0
        # output length after decimation and concatenation
        size1 = int(np.ceil(
            label.shape[0] / self.blockdec)) - self.blockcat + 1
        # label decimation: take the label of the center block
        label1 = np.zeros((size1, LABEL_SIZE), dtype='float32')
        for tau in range(size1):
            label1[tau, :] = label[(tau + self.blockcat // 2)
                                   * self.blockdec, :]
        # feature decimation and concatenation -> time x channel x feature
        featall = np.zeros((size1, numchs, FBANK_SIZE * self.blockcat),
                           dtype='float32')
        for n in range(numchs):
            feat = data[:, FBANK_SIZE * n:FBANK_SIZE * (n + 1)]
            for tau in range(size1):
                for i in range(self.blockcat):
                    featall[tau, n, FBANK_SIZE * i:FBANK_SIZE * (i + 1)] = \
                        feat[(tau + i) * self.blockdec, :]
        return torch.from_numpy(featall), torch.from_numpy(label1).long()

    def get_sim(self, id):
        """Return the simulator for worker `id` (basetrain pool first)."""
        num_base = len(self.sims_base)
        if id < num_base:
            return self.sims_base[id]
        return self.sims_senior[id - num_base]
class Worker(threading.Thread):
    """Thread that repeatedly pulls simulated minibatches into a queue.

    Args:
        id: worker id
        dataset: the dataset
        pool: queue as the global data buffer
    """

    def __init__(self, id, dataset, pool):
        super().__init__()
        self.id = id
        self.dataset = dataset
        self.pool = pool
        # iteration counter, for debug logging only
        self.nn = 0
        # cleared by stopWorker() to terminate the loop
        self.isrun = True

    def run(self):
        while self.isrun:
            self.nn += 1
            logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:1')
            # get simulated minibatch
            if self.isrun:
                batch = self.dataset.getBatch(self.id)
                logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:2')
            # put data into buffer
            if self.isrun:
                self.pool.put(batch)
                logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:3')
        logger.info('KWSDataLoader: Worker {:02d} stopped.'.format(self.id))

    def stopWorker(self):
        """
        stop the worker thread
        """
        self.isrun = False
class KWSDataLoader:
    """Multi-threaded loader that groups simulated minibatches into batches.

    dataset: the dataset reference
    batchsize: data batch size
    numworkers: no. of workers
    prefetch: prefetch factor (queue capacity = batchsize * prefetch)
    """

    def __init__(self, dataset, batchsize, numworkers, prefetch=2):
        self.dataset = dataset
        self.batchsize = batchsize
        # minibatches grouped by tensor shape, waiting to fill a full batch
        self.datamap = {}
        self.isrun = True
        # data queue shared by all workers; bounded to limit prefetching
        self.pool = queue.Queue(batchsize * prefetch)
        # initialize workers
        self.workerlist = []
        for id in range(numworkers):
            w = Worker(id, dataset, self.pool)
            self.workerlist.append(w)

    def __iter__(self):
        return self

    def __next__(self):
        # NOTE(review): after stop() this yields (None, None) instead of
        # raising StopIteration — callers must handle that sentinel.
        while self.isrun:
            # get data from common data pool
            data = self.pool.get()
            self.pool.task_done()
            # group minibatches with the same shape
            key = str(data[0].shape)
            batchl = self.datamap.get(key)
            if batchl is None:
                batchl = []
                self.datamap.update({key: batchl})
            batchl.append(data)
            # a full data batch collected
            if len(batchl) >= self.batchsize:
                featbatch = []
                labelbatch = []
                for feat, label in batchl:
                    featbatch.append(feat)
                    labelbatch.append(label)
                batchl.clear()
                feattensor = torch.stack(featbatch, dim=0)
                labeltensor = torch.stack(labelbatch, dim=0)
                # single-channel feature => basetrain simulation output
                if feattensor.shape[-2] == 1:
                    logger.debug('KWSDataLoader: Basetrain batch.')
                else:
                    logger.debug('KWSDataLoader: Finetune batch.')
                return feattensor, labeltensor
        return None, None

    def start(self):
        """
        start multi-thread data loader
        """
        for w in self.workerlist:
            w.start()

    def stop(self):
        """
        stop data loader
        """
        logger.info('KWSDataLoader: Stopping...')
        self.isrun = False
        for w in self.workerlist:
            w.stopWorker()
        # drain the queue so workers blocked on pool.put() can observe the
        # stop flag and exit their loop
        while not self.pool.empty():
            self.pool.get(block=True, timeout=0.001)
        # wait workers terminated
        for w in self.workerlist:
            # keep draining before each join: a worker may have re-filled the
            # queue between the drain above and its stop-flag check.
            # NOTE(review): a worker blocking on pool.put() after this drain
            # could still stall join() — confirm in practice.
            while not self.pool.empty():
                self.pool.get(block=True, timeout=0.001)
            w.join()
        logger.info('KWSDataLoader: All worker stopped.')
@@ -0,0 +1,279 @@ | |||||
import datetime | |||||
import math | |||||
import os | |||||
from typing import Callable, Dict, Optional | |||||
import numpy as np | |||||
import torch | |||||
from torch import nn as nn | |||||
from torch import optim as optim | |||||
from modelscope.hub.snapshot_download import snapshot_download | |||||
from modelscope.metainfo import Trainers | |||||
from modelscope.models import Model, TorchModel | |||||
from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset | |||||
from modelscope.trainers.base import BaseTrainer | |||||
from modelscope.trainers.builder import TRAINERS | |||||
from modelscope.utils.audio.audio_utils import update_conf | |||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||||
from modelscope.utils.data_utils import to_device | |||||
from modelscope.utils.device import create_device | |||||
from modelscope.utils.logger import get_logger | |||||
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, | |||||
init_dist, is_master) | |||||
logger = get_logger() | |||||
# Config-template keys; each names a simulation config template shipped in
# the model directory (basetrain / finetune, per difficulty level).
BASETRAIN_CONF_EASY = 'basetrain_easy'
BASETRAIN_CONF_NORMAL = 'basetrain_normal'
BASETRAIN_CONF_HARD = 'basetrain_hard'
FINETUNE_CONF_EASY = 'finetune_easy'
FINETUNE_CONF_NORMAL = 'finetune_normal'
FINETUNE_CONF_HARD = 'finetune_hard'

# Fraction of max_epochs spent in each curriculum stage (easy/normal/hard).
EASY_RATIO = 0.1
NORMAL_RATIO = 0.6
HARD_RATIO = 0.3
# Default fraction of dataset workers running basetrain simulation.
BASETRAIN_RATIO = 0.5
@TRAINERS.register_module(module_name=Trainers.speech_dfsmn_kws_char_farfield)
class KWSFarfieldTrainer(BaseTrainer):
    """Trainer for the far-field DFSMN keyword-spotting model.

    Training proceeds through three curriculum stages (easy / normal / hard);
    each stage simulates data from its own (basetrain, finetune) config pair.
    """

    DEFAULT_WORK_DIR = './work_dir'
    # (basetrain, finetune) template keys, ordered by stage difficulty;
    # conf_files built in __init__ follows this order.
    conf_keys = (BASETRAIN_CONF_EASY, FINETUNE_CONF_EASY,
                 BASETRAIN_CONF_NORMAL, FINETUNE_CONF_NORMAL,
                 BASETRAIN_CONF_HARD, FINETUNE_CONF_HARD)

    def __init__(self,
                 model: str,
                 work_dir: str,
                 cfg_file: Optional[str] = None,
                 arg_parse_fn: Optional[Callable] = None,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 custom_conf: Optional[dict] = None,
                 **kwargs):
        """Initialize the trainer.

        Args:
            model: model id on the hub, or a local model directory/file path.
            work_dir: directory receiving checkpoints and the json log.
            cfg_file: configuration file; defaults to the one in model dir.
            arg_parse_fn: argument parse function passed to BaseTrainer.
            model_revision: hub revision used when downloading `model`.
            custom_conf: mapping conf-template key -> placeholder values used
                to instantiate the simulation configs.
                NOTE(review): indexed unconditionally below, so passing None
                raises TypeError — confirm all callers supply it.
            **kwargs: optional overrides — launcher, device, max_epochs,
                train_iters_per_epoch, val_iters_per_epoch, workers,
                single_rate, model_bin.
        """
        if isinstance(model, str):
            if os.path.exists(model):
                self.model_dir = model if os.path.isdir(
                    model) else os.path.dirname(model)
            else:
                self.model_dir = snapshot_download(
                    model, revision=model_revision)
            if cfg_file is None:
                cfg_file = os.path.join(self.model_dir,
                                        ModelFile.CONFIGURATION)
        else:
            assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
            self.model_dir = os.path.dirname(cfg_file)
        super().__init__(cfg_file, arg_parse_fn)
        self.model = self.build_model()
        self.work_dir = work_dir
        # the number of model output dimension
        # should update config outside the trainer, if user need more wake word
        self._num_classes = self.cfg.model.num_syn
        if kwargs.get('launcher', None) is not None:
            init_dist(kwargs['launcher'])
        _, world_size = get_dist_info()
        self._dist = world_size > 1
        device_name = kwargs.get('device', 'gpu')
        if self._dist:
            local_rank = get_local_rank()
            device_name = f'cuda:{local_rank}'
        self.device = create_device(device_name)
        # model placement
        if self.device.type == 'cuda':
            self.model.to(self.device)
        if 'max_epochs' not in kwargs:
            assert hasattr(
                self.cfg.train, 'max_epochs'
            ), 'max_epochs is missing from the configuration file'
            self._max_epochs = self.cfg.train.max_epochs
        else:
            self._max_epochs = kwargs['max_epochs']
        # iteration counts: kwargs override the configuration file
        self._train_iters = kwargs.get('train_iters_per_epoch', None)
        self._val_iters = kwargs.get('val_iters_per_epoch', None)
        if self._train_iters is None:
            self._train_iters = self.cfg.train.train_iters_per_epoch
        if self._val_iters is None:
            self._val_iters = self.cfg.evaluation.val_iters_per_epoch
        dataloader_config = self.cfg.train.dataloader
        self._threads = kwargs.get('workers', None)
        if self._threads is None:
            self._threads = dataloader_config.workers_per_gpu
        self._single_rate = BASETRAIN_RATIO
        if 'single_rate' in kwargs:
            self._single_rate = kwargs['single_rate']
        self._batch_size = dataloader_config.batch_size_per_gpu
        # optionally warm-start from a checkpoint inside the model dir
        if 'model_bin' in kwargs:
            model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
            checkpoint = torch.load(model_bin_file)
            self.model.load_state_dict(checkpoint)
        # build corresponding optimizer and loss function
        lr = self.cfg.train.optimizer.lr
        self.optimizer = optim.Adam(self.model.parameters(), lr)
        self.loss_fn = nn.CrossEntropyLoss()
        self.data_val = None
        # self.timestamp is presumably set by BaseTrainer — TODO confirm
        self.json_log_path = os.path.join(self.work_dir,
                                          '{}.log.json'.format(self.timestamp))
        # instantiate the six simulation configs from their templates
        self.conf_files = []
        for conf_key in self.conf_keys:
            template_file = os.path.join(self.model_dir, conf_key)
            conf_file = os.path.join(self.model_dir, f'{conf_key}.conf')
            update_conf(template_file, conf_file, custom_conf[conf_key])
            self.conf_files.append(conf_file)
        self._current_epoch = 0
        # epochs per stage.  NOTE(review): floor() of each ratio may not sum
        # to max_epochs, so a few epochs can be silently dropped — confirm.
        self.stages = (math.floor(self._max_epochs * EASY_RATIO),
                       math.floor(self._max_epochs * NORMAL_RATIO),
                       math.floor(self._max_epochs * HARD_RATIO))

    def build_model(self) -> nn.Module:
        """ Instantiate a pytorch model and return.

        By default, we will create a model using config from configuration file. You can
        override this method in a subclass.
        """
        model = Model.from_pretrained(
            self.model_dir, cfg_dict=self.cfg, training=True)
        if isinstance(model, TorchModel) and hasattr(model, 'model'):
            return model.model
        elif isinstance(model, nn.Module):
            return model

    def train(self, *args, **kwargs):
        """Generate the validation set if needed, then run all stages."""
        if not self.data_val:
            self.gen_val()
        logger.info('Start training...')
        totaltime = datetime.datetime.now()
        for stage, num_epoch in enumerate(self.stages):
            self.run_stage(stage, num_epoch)
        # total time spent
        totaltime = datetime.datetime.now() - totaltime
        logger.info('Total time spent: {:.2f} hours\n'.format(
            totaltime.total_seconds() / 3600.0))

    def run_stage(self, stage, num_epoch):
        """
        Run training stages with correspond data

        Args:
            stage: id of stage (0=easy, 1=normal, 2=hard)
            num_epoch: the number of epoch to run in this stage
        """
        if num_epoch <= 0:
            logger.warning(f'Invalid epoch number, stage {stage} exit!')
            return
        logger.info(f'Starting stage {stage}...')
        # each stage owns a (basetrain, finetune) config pair in conf_files
        dataset, dataloader = self.create_dataloader(
            self.conf_files[stage * 2], self.conf_files[stage * 2 + 1])
        it = iter(dataloader)
        for _ in range(num_epoch):
            self._current_epoch += 1
            epochtime = datetime.datetime.now()
            logger.info('Start epoch %d...', self._current_epoch)
            loss_train_epoch = 0.0
            validbatchs = 0
            for bi in range(self._train_iters):
                # prepare data
                feat, label = next(it)
                label = torch.reshape(label, (-1, ))
                feat = to_device(feat, self.device)
                label = to_device(label, self.device)
                # apply model
                self.optimizer.zero_grad()
                predict = self.model(feat)
                # calculate loss
                loss = self.loss_fn(
                    torch.reshape(predict, (-1, self._num_classes)), label)
                # skip the parameter update on a NaN loss, but still log it
                if not np.isnan(loss.item()):
                    loss.backward()
                    self.optimizer.step()
                    loss_train_epoch += loss.item()
                    validbatchs += 1
                train_result = 'Epoch: {:04d}/{:04d}, batch: {:04d}/{:04d}, loss: {:.4f}'.format(
                    self._current_epoch, self._max_epochs, bi + 1,
                    self._train_iters, loss.item())
                logger.info(train_result)
                self._dump_log(train_result)
            # average training loss in one epoch
            # NOTE(review): divides by zero if every batch in the epoch
            # produced a NaN loss — confirm acceptable.
            loss_train_epoch /= validbatchs
            loss_val_epoch = self.evaluate('')
            val_result = 'Evaluate epoch: {:04d}, loss_train: {:.4f}, loss_val: {:.4f}'.format(
                self._current_epoch, loss_train_epoch, loss_val_epoch)
            logger.info(val_result)
            self._dump_log(val_result)
            # check point
            ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
                self._current_epoch, loss_train_epoch, loss_val_epoch)
            torch.save(self.model, os.path.join(self.work_dir, ckpt_name))
            # time spent per epoch
            epochtime = datetime.datetime.now() - epochtime
            logger.info('Epoch {:04d} time spent: {:.2f} hours'.format(
                self._current_epoch,
                epochtime.total_seconds() / 3600.0))
        dataloader.stop()
        dataset.release()
        logger.info(f'Stage {stage} is finished.')

    def gen_val(self):
        """
        generate validation set

        Uses the normal-stage configs (conf_files[2] / conf_files[3]).
        """
        logger.info('Start generating validation set...')
        dataset, dataloader = self.create_dataloader(self.conf_files[2],
                                                     self.conf_files[3])
        it = iter(dataloader)
        self.data_val = []
        for bi in range(self._val_iters):
            logger.info('Iterating validation data %d', bi)
            feat, label = next(it)
            label = torch.reshape(label, (-1, ))
            self.data_val.append([feat, label])
        dataloader.stop()
        dataset.release()
        logger.info('Finish generating validation set!')

    def create_dataloader(self, base_path, finetune_path):
        """Build a KWSDataset and a started KWSDataLoader for the configs.

        Returns:
            tuple: (dataset, dataloader) — caller must stop()/release() them.
        """
        dataset = KWSDataset(base_path, finetune_path, self._threads,
                             self._single_rate, self._num_classes)
        dataloader = KWSDataLoader(
            dataset, batchsize=self._batch_size, numworkers=self._threads)
        dataloader.start()
        return dataset, dataloader

    def evaluate(self, checkpoint_path: str, *args, **kwargs) -> float:
        """Compute the average validation loss over self.data_val.

        Note:
            `checkpoint_path` is ignored; evaluation always uses the current
            in-memory model.  (Return annotation corrected from
            Dict[str, float] — a plain float is returned.)
        """
        logger.info('Start validation...')
        loss_val_epoch = 0.0
        with torch.no_grad():
            for feat, label in self.data_val:
                feat = to_device(feat, self.device)
                label = to_device(label, self.device)
                # apply model
                predict = self.model(feat)
                # calculate loss
                loss = self.loss_fn(
                    torch.reshape(predict, (-1, self._num_classes)), label)
                loss_val_epoch += loss.item()
        logger.info('Finish validation.')
        return loss_val_epoch / self._val_iters

    def _dump_log(self, msg):
        """Append one line to the json log (rank-0 / master process only)."""
        if is_master():
            with open(self.json_log_path, 'a+') as f:
                f.write(msg)
                f.write('\n')
@@ -1,4 +1,5 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import re
import struct
from typing import Dict, Union
from urllib.parse import urlparse
@@ -37,6 +38,23 @@ def audio_norm(x): | |||||
return x | return x | ||||
def update_conf(origin_config_file: str, new_config_file: str,
                conf_item: Dict[str, str]) -> None:
    """Fill ``${key}`` placeholders in a config template.

    Reads ``origin_config_file`` line by line, substitutes every ``${key}``
    placeholder whose key is present in ``conf_item``, and writes the result
    to ``new_config_file`` (the two paths may be the same file).

    Args:
        origin_config_file: path of the template file to read.
        new_config_file: path of the config file to write.
        conf_item: mapping from placeholder name to replacement value;
            non-string values are converted with str().
    """

    def repl(matched):
        key = matched.group(1)
        # FIX: leave unknown placeholders untouched — the previous version
        # returned None here, which makes re.sub raise TypeError.  str()
        # also accepts non-string values (e.g. integer mode codes).
        return str(conf_item.get(key, matched.group(0)))

    with open(origin_config_file) as f:
        lines = f.readlines()
    with open(new_config_file, 'w') as f:
        for line in lines:
            # FIX: match up to the next '}' only, so two placeholders on one
            # line are substituted independently (the old greedy '(.*)'
            # swallowed everything between the first '${' and the last '}').
            f.write(re.sub(r'\$\{([^}]*)\}', repl, line))
def extract_pcm_from_wav(wav: bytes) -> bytes: | def extract_pcm_from_wav(wav: bytes) -> bytes: | ||||
data = wav | data = wav | ||||
if len(data) > 44: | if len(data) > 44: | ||||
@@ -14,7 +14,11 @@ nltk | |||||
numpy<=1.18 | numpy<=1.18 | ||||
# protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. | # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. | ||||
protobuf>3,<3.21.0 | protobuf>3,<3.21.0 | ||||
py_sound_connect | |||||
ptflops | |||||
py_sound_connect>=0.1 | |||||
pytorch_wavelets | |||||
PyWavelets>=1.0.0 | |||||
scikit-learn | |||||
SoundFile>0.10 | SoundFile>0.10 | ||||
sox | sox | ||||
torchaudio | torchaudio | ||||
@@ -0,0 +1,85 @@ | |||||
import os | |||||
import shutil | |||||
import tempfile | |||||
import unittest | |||||
from modelscope.metainfo import Trainers | |||||
from modelscope.trainers import build_trainer | |||||
from modelscope.utils.test_utils import test_level | |||||
# Fixture audio clips (relative to the repo root).  The same noisy-speech
# clip deliberately doubles as negative, noise, and interference source.
POS_FILE = 'data/test/audios/wake_word_with_label_xyxy.wav'
NEG_FILE = 'data/test/audios/speech_with_noise.wav'
NOISE_FILE = 'data/test/audios/speech_with_noise.wav'
INTERF_FILE = 'data/test/audios/speech_with_noise.wav'
REF_FILE = 'data/test/audios/farend_speech.wav'
NOISE_2CH_FILE = 'data/test/audios/noise_2ch.wav'
class TestKwsFarfieldTrainer(unittest.TestCase):
    """Smoke test for the far-field KWS trainer."""

    REVISION = 'beta'

    def setUp(self):
        # FIX: the previous `tempfile.TemporaryDirectory().name` dropped the
        # TemporaryDirectory object immediately, so its finalizer could
        # delete the directory out from under the test; mkdtemp() creates a
        # directory that persists until tearDown removes it.
        self.tmp_dir = tempfile.mkdtemp()
        print(f'tmp dir: {self.tmp_dir}')
        self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'
        # file lists consumed by the simulation configs
        train_pos_list = self.create_list('pos.list', POS_FILE)
        train_neg_list = self.create_list('neg.list', NEG_FILE)
        train_noise1_list = self.create_list('noise.list', NOISE_FILE)
        train_noise2_list = self.create_list('noise_2ch.list', NOISE_2CH_FILE)
        train_interf_list = self.create_list('interf.list', INTERF_FILE)
        train_ref_list = self.create_list('ref.list', REF_FILE)
        base_dict = dict(
            train_pos_list=train_pos_list,
            train_neg_list=train_neg_list,
            train_noise1_list=train_noise1_list)
        fintune_dict = dict(
            train_pos_list=train_pos_list,
            train_neg_list=train_neg_list,
            train_noise1_list=train_noise1_list,
            train_noise2_type='1',
            train_noise1_ratio='0.2',
            train_noise2_list=train_noise2_list,
            train_interf_list=train_interf_list,
            train_ref_list=train_ref_list)
        # placeholder values for all six config templates
        self.custom_conf = dict(
            basetrain_easy=base_dict,
            basetrain_normal=base_dict,
            basetrain_hard=base_dict,
            finetune_easy=fintune_dict,
            finetune_normal=fintune_dict,
            finetune_hard=fintune_dict)

    def create_list(self, list_name, audio_file):
        """Write a 10-line audio file list into tmp_dir.

        Returns:
            str: '<list path>, 1.0' — the path/weight form the configs expect.
        """
        list_file = os.path.join(self.tmp_dir, list_name)
        # hoist the absolute-path construction out of the loop
        line = f'{os.path.join(os.getcwd(), audio_file)}\n'
        with open(list_file, 'w') as f:
            f.writelines([line] * 10)
        return f'{list_file}, 1.0'

    def tearDown(self) -> None:
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_normal(self):
        """Run two tiny epochs end-to-end and check the json log appears."""
        kwargs = dict(
            model=self.model_id,
            work_dir=self.tmp_dir,
            model_revision=self.REVISION,
            workers=2,
            max_epochs=2,
            train_iters_per_epoch=2,
            val_iters_per_epoch=1,
            custom_conf=self.custom_conf)
        trainer = build_trainer(
            Trainers.speech_dfsmn_kws_char_farfield, default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files,
                      f'work_dir:{self.tmp_dir}')