diff --git a/modelscope/models/audio/tts/voice.py b/modelscope/models/audio/tts/voice.py index dc830db5..b7240088 100644 --- a/modelscope/models/audio/tts/voice.py +++ b/modelscope/models/audio/tts/voice.py @@ -2,6 +2,7 @@ import os import pickle as pkl +from threading import Lock import json import numpy as np @@ -27,6 +28,7 @@ class Voice: self.__am_config = AttrDict(**am_config) self.__voc_config = AttrDict(**voc_config) self.__model_loaded = False + self.__lock = Lock() if 'am' not in self.__am_config: raise TtsModelConfigurationException( 'modelscope error: am configuration invalid') @@ -71,34 +73,35 @@ class Voice: self.__generator.remove_weight_norm() def __am_forward(self, symbol_seq): - with torch.no_grad(): - inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( - symbol_seq) - inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( - self.__device) - inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( - self.__device) - inputs_syllable = torch.from_numpy(inputs_feat_lst[2]).long().to( - self.__device) - inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( - self.__device) - inputs_ling = torch.stack( - [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], - dim=-1).unsqueeze(0) - inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( - self.__device).unsqueeze(0) - inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( - self.__device).unsqueeze(0) - inputs_len = torch.zeros(1).to(self.__device).long( - ) + inputs_emo.size(1) - 1 # minus 1 for "~" - res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], - inputs_spk[:, :-1], inputs_len) - postnet_outputs = res['postnet_outputs'] - LR_length_rounded = res['LR_length_rounded'] - valid_length = int(LR_length_rounded[0].item()) - postnet_outputs = postnet_outputs[ - 0, :valid_length, :].cpu().numpy() - return postnet_outputs + with self.__lock: + with torch.no_grad(): + inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( + symbol_seq) + inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( + self.__device) + inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( + self.__device) + inputs_syllable = torch.from_numpy( + inputs_feat_lst[2]).long().to(self.__device) + inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( + self.__device) + inputs_ling = torch.stack( + [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], + dim=-1).unsqueeze(0) + inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( + self.__device).unsqueeze(0) + inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( + self.__device).unsqueeze(0) + inputs_len = torch.zeros(1).to(self.__device).long( + ) + inputs_emo.size(1) - 1 # minus 1 for "~" + res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], + inputs_spk[:, :-1], inputs_len) + postnet_outputs = res['postnet_outputs'] + LR_length_rounded = res['LR_length_rounded'] + valid_length = int(LR_length_rounded[0].item()) + postnet_outputs = postnet_outputs[ + 0, :valid_length, :].cpu().numpy() + return postnet_outputs def __vocoder_forward(self, melspec): dim0 = list(melspec.shape)[-1] @@ -118,14 +121,15 @@ class Voice: return audio def forward(self, symbol_seq): - if not self.__model_loaded: - torch.manual_seed(self.__am_config.seed) - if torch.cuda.is_available(): + with self.__lock: + if not self.__model_loaded: torch.manual_seed(self.__am_config.seed) - self.__device = torch.device('cuda') - else: - self.__device = torch.device('cpu') - self.__load_am() - self.__load_vocoder() - self.__model_loaded = True + if torch.cuda.is_available(): + torch.manual_seed(self.__am_config.seed) + self.__device = torch.device('cuda') + else: + self.__device = torch.device('cpu') + self.__load_am() + self.__load_vocoder() + self.__model_loaded = True return self.__vocoder_forward(self.__am_forward(symbol_seq))