Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10743508 * fix: load model directly from .pthmaster
@@ -54,7 +54,8 @@ class FSMNSeleNetV2Decorator(TorchModel): | |||||
) | ) | ||||
def __del__(self): | def __del__(self): | ||||
self.tmp_dir.cleanup() | |||||
if hasattr(self, 'tmp_dir'): | |||||
self.tmp_dir.cleanup() | |||||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||||
return self.model.forward(input) | return self.model.forward(input) | ||||
@@ -188,11 +188,13 @@ class Worker(threading.Thread): | |||||
class KWSDataLoader: | class KWSDataLoader: | ||||
""" | |||||
dataset: the dataset reference | |||||
batchsize: data batch size | |||||
numworkers: no. of workers | |||||
prefetch: prefetch factor | |||||
""" Load and organize audio data with multiple threads | |||||
Args: | |||||
dataset: the dataset reference | |||||
batchsize: data batch size | |||||
numworkers: no. of workers | |||||
prefetch: prefetch factor | |||||
""" | """ | ||||
def __init__(self, dataset, batchsize, numworkers, prefetch=2): | def __init__(self, dataset, batchsize, numworkers, prefetch=2): | ||||
@@ -202,7 +204,7 @@ class KWSDataLoader: | |||||
self.isrun = True | self.isrun = True | ||||
# data queue | # data queue | ||||
self.pool = queue.Queue(batchsize * prefetch) | |||||
self.pool = queue.Queue(numworkers * prefetch) | |||||
# initialize workers | # initialize workers | ||||
self.workerlist = [] | self.workerlist = [] | ||||
@@ -270,11 +272,11 @@ class KWSDataLoader: | |||||
w.stopWorker() | w.stopWorker() | ||||
while not self.pool.empty(): | while not self.pool.empty(): | ||||
self.pool.get(block=True, timeout=0.001) | |||||
self.pool.get(block=True, timeout=0.01) | |||||
# wait workers terminated | # wait workers terminated | ||||
for w in self.workerlist: | for w in self.workerlist: | ||||
while not self.pool.empty(): | while not self.pool.empty(): | ||||
self.pool.get(block=True, timeout=0.001) | |||||
self.pool.get(block=True, timeout=0.01) | |||||
w.join() | w.join() | ||||
logger.info('KWSDataLoader: All worker stopped.') | logger.info('KWSDataLoader: All worker stopped.') |
@@ -117,8 +117,7 @@ class KWSFarfieldTrainer(BaseTrainer): | |||||
self._batch_size = dataloader_config.batch_size_per_gpu | self._batch_size = dataloader_config.batch_size_per_gpu | ||||
if 'model_bin' in kwargs: | if 'model_bin' in kwargs: | ||||
model_bin_file = os.path.join(self.model_dir, kwargs['model_bin']) | model_bin_file = os.path.join(self.model_dir, kwargs['model_bin']) | ||||
checkpoint = torch.load(model_bin_file) | |||||
self.model.load_state_dict(checkpoint) | |||||
self.model = torch.load(model_bin_file) | |||||
# build corresponding optimizer and loss function | # build corresponding optimizer and loss function | ||||
lr = self.cfg.train.optimizer.lr | lr = self.cfg.train.optimizer.lr | ||||
self.optimizer = optim.Adam(self.model.parameters(), lr) | self.optimizer = optim.Adam(self.model.parameters(), lr) | ||||
@@ -219,7 +218,9 @@ class KWSFarfieldTrainer(BaseTrainer): | |||||
# check point | # check point | ||||
ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format( | ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format( | ||||
self._current_epoch, loss_train_epoch, loss_val_epoch) | self._current_epoch, loss_train_epoch, loss_val_epoch) | ||||
torch.save(self.model, os.path.join(self.work_dir, ckpt_name)) | |||||
save_path = os.path.join(self.work_dir, ckpt_name) | |||||
logger.info(f'Save model to {save_path}') | |||||
torch.save(self.model, save_path) | |||||
# time spent per epoch | # time spent per epoch | ||||
epochtime = datetime.datetime.now() - epochtime | epochtime = datetime.datetime.now() - epochtime | ||||
logger.info('Epoch {:04d} time spent: {:.2f} hours'.format( | logger.info('Epoch {:04d} time spent: {:.2f} hours'.format( | ||||
@@ -43,7 +43,10 @@ def update_conf(origin_config_file, new_config_file, conf_item: [str, str]): | |||||
def repl(matched): | def repl(matched): | ||||
key = matched.group(1) | key = matched.group(1) | ||||
if key in conf_item: | if key in conf_item: | ||||
return conf_item[key] | |||||
value = conf_item[key] | |||||
if not isinstance(value, str): | |||||
value = str(value) | |||||
return value | |||||
else: | else: | ||||
return None | return None | ||||