Browse Source

Merge remote-tracking branch 'origin/master' into cv/fix_3d_body_keypoints

master
hanyuan.chy 2 years ago
parent
commit
292a75f5cd
2 changed files with 14 additions and 4 deletions
  1. +9
    -2
      modelscope/models/audio/kws/farfield/model.py
  2. +5
    -2
      modelscope/trainers/audio/kws_farfield_trainer.py

+ 9
- 2
modelscope/models/audio/kws/farfield/model.py View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import tempfile
from typing import Dict, Optional

from modelscope.metainfo import Models
@@ -36,12 +37,15 @@ class FSMNSeleNetV2Decorator(TorchModel):
else:
sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
self.tmp_dir = tempfile.TemporaryDirectory()
new_config_file = os.path.join(self.tmp_dir.name, self.SC_CONFIG)

self._sc = None
if os.path.exists(model_txt_file):
conf_dict = dict(mode=56542, kws_model=model_txt_file)
update_conf(sc_config_file, sc_config_file, conf_dict)
update_conf(sc_config_file, new_config_file, conf_dict)
import py_sound_connect
self._sc = py_sound_connect.SoundConnect(sc_config_file)
self._sc = py_sound_connect.SoundConnect(new_config_file)
self.size_in = self._sc.bytesPerBlockIn()
self.size_out = self._sc.bytesPerBlockOut()
else:
@@ -49,6 +53,9 @@ class FSMNSeleNetV2Decorator(TorchModel):
f'Invalid model directory! Failed to load model file: {model_txt_file}.'
)

def __del__(self):
self.tmp_dir.cleanup()

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
return self.model.forward(input)



+ 5
- 2
modelscope/trainers/audio/kws_farfield_trainer.py View File

@@ -69,11 +69,14 @@ class KWSFarfieldTrainer(BaseTrainer):

super().__init__(cfg_file, arg_parse_fn)

self.model = self.build_model()
self.work_dir = work_dir
# the number of model output dimension
# should update config outside the trainer, if user need more wake word
num_syn = kwargs.get('num_syn', None)
if num_syn:
self.cfg.model.num_syn = num_syn
self._num_classes = self.cfg.model.num_syn
self.model = self.build_model()
self.work_dir = work_dir

if kwargs.get('launcher', None) is not None:
init_dist(kwargs['launcher'])


Loading…
Cancel
Save