
[to #42322933] bug fix: deadlock when the number of worker threads is set to 90

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10743508

* fix: load model directly from .pth
master
bin.xue committed 2 years ago
commit 3798677395
4 changed files with 20 additions and 13 deletions
  1. +2  -1   modelscope/models/audio/kws/farfield/model.py
  2. +10 -8   modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py
  3. +4  -3   modelscope/trainers/audio/kws_farfield_trainer.py
  4. +4  -1   modelscope/utils/audio/audio_utils.py

+2 -1  modelscope/models/audio/kws/farfield/model.py

@@ -54,7 +54,8 @@ class FSMNSeleNetV2Decorator(TorchModel):
         ))

     def __del__(self):
-        self.tmp_dir.cleanup()
+        if hasattr(self, 'tmp_dir'):
+            self.tmp_dir.cleanup()

     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         return self.model.forward(input)
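
The hasattr guard matters because __del__ still runs on a half-constructed object: if __init__ raises before self.tmp_dir is assigned, garbage collection would otherwise hit an AttributeError inside __del__. A minimal illustration of that behaviour, using a hypothetical Holder class rather than the ModelScope decorator:

import tempfile

class Holder:
    def __init__(self, fail=False):
        if fail:
            raise RuntimeError('init failed before tmp_dir was assigned')
        self.tmp_dir = tempfile.TemporaryDirectory()

    def __del__(self):
        # Without the hasattr check, a half-initialized instance triggers an
        # AttributeError here during garbage collection (reported by Python
        # as "Exception ignored in ... __del__").
        if hasattr(self, 'tmp_dir'):
            self.tmp_dir.cleanup()

try:
    Holder(fail=True)   # __init__ raises, but __del__ is still invoked later
except RuntimeError:
    pass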


+10 -8  modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py

@@ -188,11 +188,13 @@ class Worker(threading.Thread):


 class KWSDataLoader:
-    """
-    dataset: the dataset reference
-    batchsize: data batch size
-    numworkers: no. of workers
-    prefetch: prefetch factor
-    """
+    """Load and organize audio data with multiple threads
+
+    Args:
+        dataset: the dataset reference
+        batchsize: data batch size
+        numworkers: no. of workers
+        prefetch: prefetch factor
+    """

     def __init__(self, dataset, batchsize, numworkers, prefetch=2):
@@ -202,7 +204,7 @@ class KWSDataLoader:
         self.isrun = True

         # data queue
-        self.pool = queue.Queue(batchsize * prefetch)
+        self.pool = queue.Queue(numworkers * prefetch)

         # initialize workers
         self.workerlist = []
@@ -270,11 +272,11 @@ class KWSDataLoader:
             w.stopWorker()

         while not self.pool.empty():
-            self.pool.get(block=True, timeout=0.001)
+            self.pool.get(block=True, timeout=0.01)

         # wait workers terminated
         for w in self.workerlist:
             while not self.pool.empty():
-                self.pool.get(block=True, timeout=0.001)
+                self.pool.get(block=True, timeout=0.01)
             w.join()
         logger.info('KWSDataLoader: All worker stopped.')
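
A plausible reading of the deadlock, which the sketch below illustrates without claiming to be the KWSDataLoader implementation: with the queue sized as batchsize * prefetch, 90 worker threads can all end up parked inside a blocking pool.put() once the queue is full, so none of them gets back to checking its stop flag and join() never returns. Sizing the queue by numworkers * prefetch and draining it while joining gives every blocked producer a free slot, after which it re-checks the flag and exits. The names below (stop, worker, NUM_WORKERS) are illustrative:

import queue
import threading
import time

NUM_WORKERS = 90                            # mirrors the thread count in the commit message
PREFETCH = 2
pool = queue.Queue(NUM_WORKERS * PREFETCH)  # capacity scales with the number of producers
stop = threading.Event()

def worker():
    while not stop.is_set():
        pool.put(object())                  # blocking put: with a full queue and no
                                            # consumer, this is where threads hang

workers = [threading.Thread(target=worker) for _ in range(NUM_WORKERS)]
for w in workers:
    w.start()

time.sleep(0.2)                             # let the queue fill up
stop.set()                                  # ask every worker to stop

for w in workers:
    # Free slots so any producer parked inside put() can complete its call,
    # re-check the stop flag and return; only then is join() guaranteed to finish.
    while not pool.empty():
        try:
            pool.get(block=True, timeout=0.01)
        except queue.Empty:
            break
    w.join()
print('all workers stopped')

The bumped timeout (0.001 s to 0.01 s) only gives each drain call a little more slack; the structural fix is the larger queue plus draining before every join().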

+4 -3  modelscope/trainers/audio/kws_farfield_trainer.py

@@ -117,8 +117,7 @@ class KWSFarfieldTrainer(BaseTrainer):
         self._batch_size = dataloader_config.batch_size_per_gpu
         if 'model_bin' in kwargs:
             model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
-            checkpoint = torch.load(model_bin_file)
-            self.model.load_state_dict(checkpoint)
+            self.model = torch.load(model_bin_file)
         # build corresponding optimizer and loss function
         lr = self.cfg.train.optimizer.lr
         self.optimizer = optim.Adam(self.model.parameters(), lr)
@@ -219,7 +218,9 @@ class KWSFarfieldTrainer(BaseTrainer):
             # check point
             ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
                 self._current_epoch, loss_train_epoch, loss_val_epoch)
-            torch.save(self.model, os.path.join(self.work_dir, ckpt_name))
+            save_path = os.path.join(self.work_dir, ckpt_name)
+            logger.info(f'Save model to {save_path}')
+            torch.save(self.model, save_path)
             # time spent per epoch
             epochtime = datetime.datetime.now() - epochtime
             logger.info('Epoch {:04d} time spent: {:.2f} hours'.format(
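
Both hunks hang together: the trainer writes whole modules with torch.save(self.model, save_path), so a checkpoint produced this way is a pickled nn.Module rather than a bare state dict, and feeding it to load_state_dict() does not work; assigning torch.load(model_bin_file) directly is what the commit's "load model directly from .pth" refers to. A short sketch of the two checkpoint styles, with illustrative model and file names:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Style A: the .pth holds only a state dict -> restore it into an existing instance.
torch.save(model.state_dict(), 'weights_only.pth')
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load('weights_only.pth'))

# Style B: the .pth holds the whole pickled module (what this trainer writes) ->
# torch.load returns the module itself. Recent PyTorch releases may require
# torch.load('full_model.pth', weights_only=False) to unpickle a full module.
torch.save(model, 'full_model.pth')
restored = torch.load('full_model.pth')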


+4 -1  modelscope/utils/audio/audio_utils.py

@@ -43,7 +43,10 @@ def update_conf(origin_config_file, new_config_file, conf_item: [str, str]):
     def repl(matched):
         key = matched.group(1)
         if key in conf_item:
-            return conf_item[key]
+            value = conf_item[key]
+            if not isinstance(value, str):
+                value = str(value)
+            return value
         else:
             return None
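
The str() conversion exists because re.sub accepts a callable as the replacement, but that callable must return a str; returning an int or float straight from conf_item raises a TypeError during substitution. A minimal sketch with an illustrative pattern and config, not the actual update_conf implementation:

import re

conf_item = {'num_threads': 90, 'model_path': 'model.pth'}

def repl(matched):
    key = matched.group(1)
    value = conf_item[key]
    # re.sub requires the replacement callable to return a str; a bare int
    # here would raise a TypeError during substitution.
    return value if isinstance(value, str) else str(value)

template = 'threads=${num_threads}; model=${model_path}'
print(re.sub(r'\$\{(\w+)\}', repl, template))
# -> threads=90; model=model.pth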



