|
|
@@ -1,6 +1,7 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
|
|
|
|
import os |
|
|
|
import tempfile |
|
|
|
from typing import Dict, Optional |
|
|
|
|
|
|
|
from modelscope.metainfo import Models |
|
|
@@ -36,12 +37,15 @@ class FSMNSeleNetV2Decorator(TorchModel): |
|
|
|
else: |
|
|
|
sc_config_file = os.path.join(model_dir, self.SC_CONFIG) |
|
|
|
model_txt_file = os.path.join(model_dir, self.MODEL_TXT) |
|
|
|
self.tmp_dir = tempfile.TemporaryDirectory() |
|
|
|
new_config_file = os.path.join(self.tmp_dir.name, self.SC_CONFIG) |
|
|
|
|
|
|
|
self._sc = None |
|
|
|
if os.path.exists(model_txt_file): |
|
|
|
conf_dict = dict(mode=56542, kws_model=model_txt_file) |
|
|
|
update_conf(sc_config_file, sc_config_file, conf_dict) |
|
|
|
update_conf(sc_config_file, new_config_file, conf_dict) |
|
|
|
import py_sound_connect |
|
|
|
self._sc = py_sound_connect.SoundConnect(sc_config_file) |
|
|
|
self._sc = py_sound_connect.SoundConnect(new_config_file) |
|
|
|
self.size_in = self._sc.bytesPerBlockIn() |
|
|
|
self.size_out = self._sc.bytesPerBlockOut() |
|
|
|
else: |
|
|
@@ -49,6 +53,9 @@ class FSMNSeleNetV2Decorator(TorchModel): |
|
|
|
f'Invalid model directory! Failed to load model file: {model_txt_file}.' |
|
|
|
) |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
self.tmp_dir.cleanup() |
|
|
|
|
|
|
|
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: |
|
|
|
return self.model.forward(input) |
|
|
|
|
|
|
|