|
|
@@ -2,6 +2,7 @@ |
|
|
|
|
|
|
|
import os |
|
|
|
import pickle as pkl |
|
|
|
from threading import Lock |
|
|
|
|
|
|
|
import json |
|
|
|
import numpy as np |
|
|
@@ -27,6 +28,7 @@ class Voice: |
|
|
|
self.__am_config = AttrDict(**am_config) |
|
|
|
self.__voc_config = AttrDict(**voc_config) |
|
|
|
self.__model_loaded = False |
|
|
|
self.__lock = Lock() |
|
|
|
if 'am' not in self.__am_config: |
|
|
|
raise TtsModelConfigurationException( |
|
|
|
'modelscope error: am configuration invalid') |
|
|
@@ -71,34 +73,35 @@ class Voice: |
|
|
|
self.__generator.remove_weight_norm() |
|
|
|
|
|
|
|
def __am_forward(self, symbol_seq):
    """Encode ``symbol_seq`` and run ``self.__am_net`` on it under the lock.

    The whole call (encoding, tensor construction and the network forward
    pass) runs while holding ``self.__lock`` and with autograd disabled, so
    only one thread at a time executes this method on a given instance.

    Args:
        symbol_seq: Symbol sequence handed unchanged to
            ``self.__ling_unit.encode_symbol_sequence``.

    Returns:
        NumPy array holding the first ``valid_length`` rows of batch item 0
        of ``res['postnet_outputs']``, where ``valid_length`` is taken from
        ``res['LR_length_rounded'][0]``.
    """
    with self.__lock, torch.no_grad():
        device = self.__device
        feats = self.__ling_unit.encode_symbol_sequence(symbol_seq)

        # Features 0..3 (sy, tone, syllable, ws) are stacked along a new
        # last axis and given a leading batch axis of size 1.
        ling_parts = [
            torch.from_numpy(feat).long().to(device) for feat in feats[:4]
        ]
        inputs_ling = torch.stack(ling_parts, dim=-1).unsqueeze(0)

        inputs_emo = torch.from_numpy(feats[4]).long().to(device).unsqueeze(0)
        inputs_spk = torch.from_numpy(feats[5]).long().to(device).unsqueeze(0)

        # minus 1 for "~": the last position is also sliced off every input
        # below (presumably the trailing "~" symbol -- confirm against the
        # encoder).
        inputs_len = torch.tensor(
            [inputs_emo.size(1) - 1], dtype=torch.long, device=device)

        res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
                            inputs_spk[:, :-1], inputs_len)

        # Keep only the frames the network reports as valid for item 0.
        valid_length = int(res['LR_length_rounded'][0].item())
        return res['postnet_outputs'][0, :valid_length, :].cpu().numpy()
|
|
|
|
|
|
|
def __vocoder_forward(self, melspec): |
|
|
|
dim0 = list(melspec.shape)[-1] |
|
|
@@ -118,14 +121,15 @@ class Voice: |
|
|
|
return audio |
|
|
|
|
|
|
|
def forward(self, symbol_seq):
    """Run ``symbol_seq`` through ``__am_forward`` then ``__vocoder_forward``.

    On the first call the models are loaded lazily: the torch RNG is seeded
    from ``self.__am_config.seed``, ``self.__device`` is set to CUDA when
    available (CPU otherwise), and ``__load_am`` / ``__load_vocoder`` are
    invoked. The check and the load both happen while holding
    ``self.__lock`` so concurrent first calls load the models only once.

    Args:
        symbol_seq: Symbol sequence passed on to ``self.__am_forward``.

    Returns:
        Whatever ``self.__vocoder_forward`` returns for the output of
        ``self.__am_forward(symbol_seq)``.
    """
    with self.__lock:
        if not self.__model_loaded:
            # One call is enough: torch.manual_seed seeds every device, so
            # no separate CUDA-specific seeding is needed.
            torch.manual_seed(self.__am_config.seed)
            self.__device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            self.__load_am()
            self.__load_vocoder()
            self.__model_loaded = True
    # Must stay outside the ``with`` block: __am_forward acquires the same
    # non-reentrant lock, so calling it while still holding the lock would
    # deadlock.
    return self.__vocoder_forward(self.__am_forward(symbol_seq))