diff --git a/modelscope/models/audio/tts/sambert_hifi.py b/modelscope/models/audio/tts/sambert_hifi.py index 72c5b80c..401e32c9 100644 --- a/modelscope/models/audio/tts/sambert_hifi.py +++ b/modelscope/models/audio/tts/sambert_hifi.py @@ -8,9 +8,7 @@ from typing import Any, Dict, Optional, Union import json import numpy as np -import tensorflow as tf import torch -from sklearn.preprocessing import MultiLabelBinarizer from modelscope.metainfo import Models from modelscope.models.base import Model @@ -18,49 +16,20 @@ from modelscope.models.builder import MODELS from modelscope.utils.audio.tts_exceptions import ( TtsFrontendInitializeFailedException, TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationExcetion, - TtsVocoderMelspecShapeMismatchException) + TtsVocoderMelspecShapeMismatchException, TtsVoiceNotExistsException) from modelscope.utils.constant import ModelFile, Tasks -from .models import Generator, create_am_model -from .text.symbols import load_symbols -from .text.symbols_dict import SymbolsDict +from .voice import Voice -__all__ = ['SambertHifigan'] -MAX_WAV_VALUE = 32768.0 - - -def multi_label_symbol_to_sequence(my_classes, my_symbol): - one_hot = MultiLabelBinarizer(classes=my_classes) - tokens = my_symbol.strip().split(' ') - sequences = [] - for token in tokens: - sequences.append(tuple(token.split('&'))) - return one_hot.fit_transform(sequences) - - -def load_checkpoint(filepath, device): - assert os.path.isfile(filepath) - checkpoint_dict = torch.load(filepath, map_location=device) - return checkpoint_dict +import tensorflow as tf # isort:skip - -class AttrDict(dict): - - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self +__all__ = ['SambertHifigan'] @MODELS.register_module( Tasks.text_to_speech, module_name=Models.sambert_hifigan) class SambertHifigan(Model): - def __init__(self, - model_dir, - pitch_control_str='', - duration_control_str='', - energy_control_str='', - *args, - **kwargs): + def __init__(self, model_dir, *args, **kwargs): super().__init__(model_dir, *args, **kwargs) if 'am' not in kwargs: raise TtsModelConfigurationExcetion( @@ -71,284 +40,56 @@ class SambertHifigan(Model): if 'lang_type' not in kwargs: raise TtsModelConfigurationExcetion( 'configuration model field missing lang_type!') + am_cfg = kwargs['am'] + voc_cfg = kwargs['vocoder'] # initialize frontend import ttsfrd frontend = ttsfrd.TtsFrontendEngine() zip_file = os.path.join(model_dir, 'resource.zip') - self._res_path = os.path.join(model_dir, 'resource') + self.__res_path = os.path.join(model_dir, 'resource') with zipfile.ZipFile(zip_file, 'r') as zip_ref: zip_ref.extractall(model_dir) - if not frontend.initialize(self._res_path): + if not frontend.initialize(self.__res_path): raise TtsFrontendInitializeFailedException( - 'resource invalid: {}'.format(self._res_path)) + 'resource invalid: {}'.format(self.__res_path)) if not frontend.set_lang_type(kwargs['lang_type']): raise TtsFrontendLanguageTypeInvalidException( 'language type invalid: {}'.format(kwargs['lang_type'])) - self._frontend = frontend - - # initialize am - tf.reset_default_graph() - local_am_ckpt_path = os.path.join(ModelFile.TF_CHECKPOINT_FOLDER, - 'ckpt') - self._am_ckpt_path = os.path.join(model_dir, local_am_ckpt_path) - self._dict_path = os.path.join(model_dir, 'dicts') - self._am_hparams = tf.contrib.training.HParams(**kwargs['am']) - has_mask = True - if self._am_hparams.get('has_mask') is not None: - has_mask = self._am_hparams.has_mask - print('set 
has_mask to {}'.format(has_mask)) - values = self._am_hparams.values() - hp = [' {}:{}'.format(name, values[name]) for name in sorted(values)] - print('Hyperparameters:\n' + '\n'.join(hp)) - model_name = 'robutrans' - self._lfeat_type_list = self._am_hparams.lfeat_type_list.strip().split( - ',') - sy, tone, syllable_flag, word_segment, emo_category, speaker = load_symbols( - self._dict_path, has_mask) - self._sy = sy - self._tone = tone - self._syllable_flag = syllable_flag - self._word_segment = word_segment - self._emo_category = emo_category - self._speaker = speaker - self._inputs_dim = dict() - for lfeat_type in self._lfeat_type_list: - if lfeat_type == 'sy': - self._inputs_dim[lfeat_type] = len(sy) - elif lfeat_type == 'tone': - self._inputs_dim[lfeat_type] = len(tone) - elif lfeat_type == 'syllable_flag': - self._inputs_dim[lfeat_type] = len(syllable_flag) - elif lfeat_type == 'word_segment': - self._inputs_dim[lfeat_type] = len(word_segment) - elif lfeat_type == 'emo_category': - self._inputs_dim[lfeat_type] = len(emo_category) - elif lfeat_type == 'speaker': - self._inputs_dim[lfeat_type] = len(speaker) - - self._symbols_dict = SymbolsDict(sy, tone, syllable_flag, word_segment, - emo_category, speaker, - self._inputs_dim, - self._lfeat_type_list) - dim_inputs = sum(self._inputs_dim.values( - )) - self._inputs_dim['speaker'] - self._inputs_dim['emo_category'] - inputs = tf.placeholder(tf.float32, [1, None, dim_inputs], 'inputs') - inputs_emotion = tf.placeholder( - tf.float32, [1, None, self._inputs_dim['emo_category']], - 'inputs_emotion') - inputs_speaker = tf.placeholder(tf.float32, - [1, None, self._inputs_dim['speaker']], - 'inputs_speaker') - input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths') - pitch_contours_scale = tf.placeholder(tf.float32, [1, None], - 'pitch_contours_scale') - energy_contours_scale = tf.placeholder(tf.float32, [1, None], - 'energy_contours_scale') - duration_scale = tf.placeholder(tf.float32, [1, None], - 'duration_scale') - with tf.variable_scope('model') as _: - self._model = create_am_model(model_name, self._am_hparams) - self._model.initialize( - inputs, - inputs_emotion, - inputs_speaker, - input_lengths, - duration_scales=duration_scale, - pitch_scales=pitch_contours_scale, - energy_scales=energy_contours_scale) - self._mel_spec = self._model.mel_outputs[0] - self._duration_outputs = self._model.duration_outputs[0] - self._duration_outputs_ = self._model.duration_outputs_[0] - self._pitch_contour_outputs = self._model.pitch_contour_outputs[0] - self._energy_contour_outputs = self._model.energy_contour_outputs[ - 0] - self._embedded_inputs_emotion = self._model.embedded_inputs_emotion[ - 0] - self._embedding_fsmn_outputs = self._model.embedding_fsmn_outputs[ - 0] - self._encoder_outputs = self._model.encoder_outputs[0] - self._pitch_embeddings = self._model.pitch_embeddings[0] - self._energy_embeddings = self._model.energy_embeddings[0] - self._LR_outputs = self._model.LR_outputs[0] - self._postnet_fsmn_outputs = self._model.postnet_fsmn_outputs[0] - self._attention_h = self._model.attention_h - self._attention_x = self._model.attention_x - - print('Loading checkpoint: %s' % self._am_ckpt_path) - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - self._session = tf.Session(config=config) - self._session.run(tf.global_variables_initializer()) - - saver = tf.train.Saver() - saver.restore(self._session, self._am_ckpt_path) - - duration_cfg_lst = [] - if len(duration_control_str) != 0: - for item in 
duration_control_str.strip().split('|'): - percent, scale = item.lstrip('(').rstrip(')').split(',') - duration_cfg_lst.append((float(percent), float(scale))) - - self._duration_cfg_lst = duration_cfg_lst - - pitch_contours_cfg_lst = [] - if len(pitch_control_str) != 0: - for item in pitch_control_str.strip().split('|'): - percent, scale = item.lstrip('(').rstrip(')').split(',') - pitch_contours_cfg_lst.append( - (float(percent), float(scale))) - - self._pitch_contours_cfg_lst = pitch_contours_cfg_lst - - energy_contours_cfg_lst = [] - if len(energy_control_str) != 0: - for item in energy_control_str.strip().split('|'): - percent, scale = item.lstrip('(').rstrip(')').split(',') - energy_contours_cfg_lst.append( - (float(percent), float(scale))) - - self._energy_contours_cfg_lst = energy_contours_cfg_lst - - # initialize vocoder - self._voc_ckpt_path = os.path.join(model_dir, - ModelFile.TORCH_MODEL_BIN_FILE) - self._voc_config = AttrDict(**kwargs['vocoder']) - print(self._voc_config) - if torch.cuda.is_available(): - torch.manual_seed(self._voc_config.seed) - self._device = torch.device('cuda') + self.__frontend = frontend + zip_file = os.path.join(model_dir, 'voices.zip') + self.__voice_path = os.path.join(model_dir, 'voices') + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(model_dir) + voice_cfg_path = os.path.join(self.__voice_path, 'voices.json') + with open(voice_cfg_path, 'r') as f: + voice_cfg = json.load(f) + if 'voices' not in voice_cfg: + raise TtsModelConfigurationExcetion('voices invalid') + self.__voice = {} + for name in voice_cfg['voices']: + voice_path = os.path.join(self.__voice_path, name) + if not os.path.exists(voice_path): + continue + self.__voice[name] = Voice(name, voice_path, am_cfg, voc_cfg) + if voice_cfg['voices']: + self.__default_voice_name = voice_cfg['voices'][0] else: - self._device = torch.device('cpu') - self._generator = Generator(self._voc_config).to(self._device) - state_dict_g = load_checkpoint(self._voc_ckpt_path, self._device) - self._generator.load_state_dict(state_dict_g['generator']) - self._generator.eval() - self._generator.remove_weight_norm() - - def am_synthesis_one_sentences(self, text): - cleaner_names = [ - x.strip() for x in self._am_hparams.cleaners.split(',') - ] - - lfeat_symbol = text.strip().split(' ') - lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list)) - for this_lfeat_symbol in lfeat_symbol: - this_lfeat_symbol = this_lfeat_symbol.strip('{').strip('}').split( - '$') - if len(this_lfeat_symbol) != len(self._lfeat_type_list): - raise Exception( - 'Length of this_lfeat_symbol in training data' - + ' is not equal to the length of lfeat_type_list, ' - + str(len(this_lfeat_symbol)) + ' VS. 
' - + str(len(self._lfeat_type_list))) - index = 0 - while index < len(lfeat_symbol_separate): - lfeat_symbol_separate[index] = lfeat_symbol_separate[ - index] + this_lfeat_symbol[index] + ' ' - index = index + 1 - - index = 0 - lfeat_type = self._lfeat_type_list[index] - sequence = self._symbols_dict.symbol_to_sequence( - lfeat_symbol_separate[index].strip(), lfeat_type, cleaner_names) - sequence_array = np.asarray( - sequence[:-1], - dtype=np.int32) # sequence length minus 1 to ignore EOS ~ - inputs = np.eye( - self._inputs_dim[lfeat_type], dtype=np.float32)[sequence_array] - index = index + 1 - while index < len(self._lfeat_type_list) - 2: - lfeat_type = self._lfeat_type_list[index] - sequence = self._symbols_dict.symbol_to_sequence( - lfeat_symbol_separate[index].strip(), lfeat_type, - cleaner_names) - sequence_array = np.asarray( - sequence[:-1], - dtype=np.int32) # sequence length minus 1 to ignore EOS ~ - inputs_temp = np.eye( - self._inputs_dim[lfeat_type], dtype=np.float32)[sequence_array] - inputs = np.concatenate((inputs, inputs_temp), axis=1) - index = index + 1 - seq = inputs - - lfeat_type = 'emo_category' - inputs_emotion = multi_label_symbol_to_sequence( - self._emo_category, lfeat_symbol_separate[index].strip()) - # inputs_emotion = inputs_emotion * 1.5 - index = index + 1 - - lfeat_type = 'speaker' - inputs_speaker = multi_label_symbol_to_sequence( - self._speaker, lfeat_symbol_separate[index].strip()) - - duration_scale = np.ones((len(seq), ), dtype=np.float32) - start_idx = 0 - for (percent, scale) in self._duration_cfg_lst: - duration_scale[start_idx:start_idx - + int(percent * len(seq))] = scale - start_idx += int(percent * len(seq)) - - pitch_contours_scale = np.ones((len(seq), ), dtype=np.float32) - start_idx = 0 - for (percent, scale) in self._pitch_contours_cfg_lst: - pitch_contours_scale[start_idx:start_idx - + int(percent * len(seq))] = scale - start_idx += int(percent * len(seq)) - - energy_contours_scale = np.ones((len(seq), ), dtype=np.float32) - start_idx = 0 - for (percent, scale) in self._energy_contours_cfg_lst: - energy_contours_scale[start_idx:start_idx - + int(percent * len(seq))] = scale - start_idx += int(percent * len(seq)) - - feed_dict = { - self._model.inputs: [np.asarray(seq, dtype=np.float32)], - self._model.inputs_emotion: - [np.asarray(inputs_emotion, dtype=np.float32)], - self._model.inputs_speaker: - [np.asarray(inputs_speaker, dtype=np.float32)], - self._model.input_lengths: - np.asarray([len(seq)], dtype=np.int32), - self._model.duration_scales: [duration_scale], - self._model.pitch_scales: [pitch_contours_scale], - self._model.energy_scales: [energy_contours_scale] - } - - result = self._session.run([ - self._mel_spec, self._duration_outputs, self._duration_outputs_, - self._pitch_contour_outputs, self._embedded_inputs_emotion, - self._embedding_fsmn_outputs, self._encoder_outputs, - self._pitch_embeddings, self._LR_outputs, - self._postnet_fsmn_outputs, self._energy_contour_outputs, - self._energy_embeddings, self._attention_x, self._attention_h - ], feed_dict=feed_dict) # yapf:disable - return result[0] - - def vocoder_process(self, melspec): - dim0 = list(melspec.shape)[-1] - if dim0 != self._voc_config.num_mels: - raise TtsVocoderMelspecShapeMismatchException( - 'input melspec mismatch require {} but {}'.format( - self._voc_config.num_mels, dim0)) - with torch.no_grad(): - x = melspec.T - x = torch.FloatTensor(x).to(self._device) - if len(x.shape) == 2: - x = x.unsqueeze(0) - y_g_hat = self._generator(x) - audio = y_g_hat.squeeze() - 
audio = audio * MAX_WAV_VALUE - audio = audio.cpu().numpy().astype('int16') - return audio - - def forward(self, text): - result = self._frontend.gen_tacotron_symbols(text) + raise TtsVoiceNotExistsException('voices is empty in voices.json') + + def __synthesis_one_sentences(self, voice_name, text): + if voice_name not in self.__voice: + raise TtsVoiceNotExistsException(f'Voice {voice_name} not exists') + return self.__voice[voice_name].forward(text) + + def forward(self, text: str, voice_name: str = None): + voice = self.__default_voice_name + if voice_name is not None: + voice = voice_name + result = self.__frontend.gen_tacotron_symbols(text) texts = [s for s in result.splitlines() if s != ''] audio_total = np.empty((0), dtype='int16') for line in texts: line = line.strip().split('\t') - audio = self.vocoder_process( - self.am_synthesis_one_sentences(line[1])) + audio = self.__synthesis_one_sentences(voice, line[1]) audio_total = np.append(audio_total, audio, axis=0) return audio_total diff --git a/modelscope/models/audio/tts/voice.py b/modelscope/models/audio/tts/voice.py new file mode 100644 index 00000000..deaebf11 --- /dev/null +++ b/modelscope/models/audio/tts/voice.py @@ -0,0 +1,300 @@ +import os + +import json +import numpy as np +import torch +from sklearn.preprocessing import MultiLabelBinarizer + +from modelscope.utils.constant import ModelFile, Tasks +from .models import Generator, create_am_model +from .text.symbols import load_symbols +from .text.symbols_dict import SymbolsDict + +import tensorflow as tf # isort:skip + +MAX_WAV_VALUE = 32768.0 + + +def multi_label_symbol_to_sequence(my_classes, my_symbol): + one_hot = MultiLabelBinarizer(classes=my_classes) + tokens = my_symbol.strip().split(' ') + sequences = [] + for token in tokens: + sequences.append(tuple(token.split('&'))) + return one_hot.fit_transform(sequences) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + checkpoint_dict = torch.load(filepath, map_location=device) + return checkpoint_dict + + +class AttrDict(dict): + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +class Voice: + + def __init__(self, voice_name, voice_path, am_hparams, voc_config): + self.__voice_name = voice_name + self.__voice_path = voice_path + self.__am_hparams = tf.contrib.training.HParams(**am_hparams) + self.__voc_config = AttrDict(**voc_config) + self.__model_loaded = False + + def __load_am(self): + local_am_ckpt_path = os.path.join(self.__voice_path, + ModelFile.TF_CHECKPOINT_FOLDER) + self.__am_ckpt_path = os.path.join(local_am_ckpt_path, 'ckpt') + self.__dict_path = os.path.join(self.__voice_path, 'dicts') + has_mask = True + if self.__am_hparams.get('has_mask') is not None: + has_mask = self.__am_hparams.has_mask + model_name = 'robutrans' + self.__lfeat_type_list = self.__am_hparams.lfeat_type_list.strip( + ).split(',') + sy, tone, syllable_flag, word_segment, emo_category, speaker = load_symbols( + self.__dict_path, has_mask) + self.__sy = sy + self.__tone = tone + self.__syllable_flag = syllable_flag + self.__word_segment = word_segment + self.__emo_category = emo_category + self.__speaker = speaker + self.__inputs_dim = dict() + for lfeat_type in self.__lfeat_type_list: + if lfeat_type == 'sy': + self.__inputs_dim[lfeat_type] = len(sy) + elif lfeat_type == 'tone': + self.__inputs_dim[lfeat_type] = len(tone) + elif lfeat_type == 'syllable_flag': + self.__inputs_dim[lfeat_type] = len(syllable_flag) + elif lfeat_type == 
'word_segment': + self.__inputs_dim[lfeat_type] = len(word_segment) + elif lfeat_type == 'emo_category': + self.__inputs_dim[lfeat_type] = len(emo_category) + elif lfeat_type == 'speaker': + self.__inputs_dim[lfeat_type] = len(speaker) + + self.__symbols_dict = SymbolsDict(sy, tone, syllable_flag, + word_segment, emo_category, speaker, + self.__inputs_dim, + self.__lfeat_type_list) + dim_inputs = sum(self.__inputs_dim.values( + )) - self.__inputs_dim['speaker'] - self.__inputs_dim['emo_category'] + self.__graph = tf.Graph() + with self.__graph.as_default(): + inputs = tf.placeholder(tf.float32, [1, None, dim_inputs], + 'inputs') + inputs_emotion = tf.placeholder( + tf.float32, [1, None, self.__inputs_dim['emo_category']], + 'inputs_emotion') + inputs_speaker = tf.placeholder( + tf.float32, [1, None, self.__inputs_dim['speaker']], + 'inputs_speaker') + input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths') + pitch_contours_scale = tf.placeholder(tf.float32, [1, None], + 'pitch_contours_scale') + energy_contours_scale = tf.placeholder(tf.float32, [1, None], + 'energy_contours_scale') + duration_scale = tf.placeholder(tf.float32, [1, None], + 'duration_scale') + with tf.variable_scope('model') as _: + self.__model = create_am_model(model_name, self.__am_hparams) + self.__model.initialize( + inputs, + inputs_emotion, + inputs_speaker, + input_lengths, + duration_scales=duration_scale, + pitch_scales=pitch_contours_scale, + energy_scales=energy_contours_scale) + self.__mel_spec = self.__model.mel_outputs[0] + self.__duration_outputs = self.__model.duration_outputs[0] + self.__duration_outputs_ = self.__model.duration_outputs_[0] + self.__pitch_contour_outputs = self.__model.pitch_contour_outputs[ + 0] + self.__energy_contour_outputs = self.__model.energy_contour_outputs[ + 0] + self.__embedded_inputs_emotion = self.__model.embedded_inputs_emotion[ + 0] + self.__embedding_fsmn_outputs = self.__model.embedding_fsmn_outputs[ + 0] + self.__encoder_outputs = self.__model.encoder_outputs[0] + self.__pitch_embeddings = self.__model.pitch_embeddings[0] + self.__energy_embeddings = self.__model.energy_embeddings[0] + self.__LR_outputs = self.__model.LR_outputs[0] + self.__postnet_fsmn_outputs = self.__model.postnet_fsmn_outputs[ + 0] + self.__attention_h = self.__model.attention_h + self.__attention_x = self.__model.attention_x + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + self.__session = tf.Session(config=config) + self.__session.run(tf.global_variables_initializer()) + + saver = tf.train.Saver() + saver.restore(self.__session, self.__am_ckpt_path) + + def __load_vocoder(self): + self.__voc_ckpt_path = os.path.join(self.__voice_path, + ModelFile.TORCH_MODEL_BIN_FILE) + if torch.cuda.is_available(): + torch.manual_seed(self.__voc_config.seed) + self.__device = torch.device('cuda') + else: + self.__device = torch.device('cpu') + self.__generator = Generator(self.__voc_config).to(self.__device) + state_dict_g = load_checkpoint(self.__voc_ckpt_path, self.__device) + self.__generator.load_state_dict(state_dict_g['generator']) + self.__generator.eval() + self.__generator.remove_weight_norm() + + def __am_forward(self, + text, + pitch_control_str='', + duration_control_str='', + energy_control_str=''): + duration_cfg_lst = [] + if len(duration_control_str) != 0: + for item in duration_control_str.strip().split('|'): + percent, scale = item.lstrip('(').rstrip(')').split(',') + duration_cfg_lst.append((float(percent), float(scale))) + pitch_contours_cfg_lst = [] + if 
len(pitch_control_str) != 0: + for item in pitch_control_str.strip().split('|'): + percent, scale = item.lstrip('(').rstrip(')').split(',') + pitch_contours_cfg_lst.append((float(percent), float(scale))) + energy_contours_cfg_lst = [] + if len(energy_control_str) != 0: + for item in energy_control_str.strip().split('|'): + percent, scale = item.lstrip('(').rstrip(')').split(',') + energy_contours_cfg_lst.append((float(percent), float(scale))) + cleaner_names = [ + x.strip() for x in self.__am_hparams.cleaners.split(',') + ] + + lfeat_symbol = text.strip().split(' ') + lfeat_symbol_separate = [''] * int(len(self.__lfeat_type_list)) + for this_lfeat_symbol in lfeat_symbol: + this_lfeat_symbol = this_lfeat_symbol.strip('{').strip('}').split( + '$') + if len(this_lfeat_symbol) != len(self.__lfeat_type_list): + raise Exception( + 'Length of this_lfeat_symbol in training data' + + ' is not equal to the length of lfeat_type_list, ' + + str(len(this_lfeat_symbol)) + ' VS. ' + + str(len(self.__lfeat_type_list))) + index = 0 + while index < len(lfeat_symbol_separate): + lfeat_symbol_separate[index] = lfeat_symbol_separate[ + index] + this_lfeat_symbol[index] + ' ' + index = index + 1 + + index = 0 + lfeat_type = self.__lfeat_type_list[index] + sequence = self.__symbols_dict.symbol_to_sequence( + lfeat_symbol_separate[index].strip(), lfeat_type, cleaner_names) + sequence_array = np.asarray( + sequence[:-1], + dtype=np.int32) # sequence length minus 1 to ignore EOS ~ + inputs = np.eye( + self.__inputs_dim[lfeat_type], dtype=np.float32)[sequence_array] + index = index + 1 + while index < len(self.__lfeat_type_list) - 2: + lfeat_type = self.__lfeat_type_list[index] + sequence = self.__symbols_dict.symbol_to_sequence( + lfeat_symbol_separate[index].strip(), lfeat_type, + cleaner_names) + sequence_array = np.asarray( + sequence[:-1], + dtype=np.int32) # sequence length minus 1 to ignore EOS ~ + inputs_temp = np.eye( + self.__inputs_dim[lfeat_type], + dtype=np.float32)[sequence_array] + inputs = np.concatenate((inputs, inputs_temp), axis=1) + index = index + 1 + seq = inputs + + lfeat_type = 'emo_category' + inputs_emotion = multi_label_symbol_to_sequence( + self.__emo_category, lfeat_symbol_separate[index].strip()) + # inputs_emotion = inputs_emotion * 1.5 + index = index + 1 + + lfeat_type = 'speaker' + inputs_speaker = multi_label_symbol_to_sequence( + self.__speaker, lfeat_symbol_separate[index].strip()) + + duration_scale = np.ones((len(seq), ), dtype=np.float32) + start_idx = 0 + for (percent, scale) in duration_cfg_lst: + duration_scale[start_idx:start_idx + + int(percent * len(seq))] = scale + start_idx += int(percent * len(seq)) + + pitch_contours_scale = np.ones((len(seq), ), dtype=np.float32) + start_idx = 0 + for (percent, scale) in pitch_contours_cfg_lst: + pitch_contours_scale[start_idx:start_idx + + int(percent * len(seq))] = scale + start_idx += int(percent * len(seq)) + + energy_contours_scale = np.ones((len(seq), ), dtype=np.float32) + start_idx = 0 + for (percent, scale) in energy_contours_cfg_lst: + energy_contours_scale[start_idx:start_idx + + int(percent * len(seq))] = scale + start_idx += int(percent * len(seq)) + + feed_dict = { + self.__model.inputs: [np.asarray(seq, dtype=np.float32)], + self.__model.inputs_emotion: + [np.asarray(inputs_emotion, dtype=np.float32)], + self.__model.inputs_speaker: + [np.asarray(inputs_speaker, dtype=np.float32)], + self.__model.input_lengths: + np.asarray([len(seq)], dtype=np.int32), + self.__model.duration_scales: [duration_scale], + 
self.__model.pitch_scales: [pitch_contours_scale], + self.__model.energy_scales: [energy_contours_scale] + } + + result = self.__session.run([ + self.__mel_spec, self.__duration_outputs, self.__duration_outputs_, + self.__pitch_contour_outputs, self.__embedded_inputs_emotion, + self.__embedding_fsmn_outputs, self.__encoder_outputs, + self.__pitch_embeddings, self.__LR_outputs, + self.__postnet_fsmn_outputs, self.__energy_contour_outputs, + self.__energy_embeddings, self.__attention_x, self.__attention_h + ], feed_dict=feed_dict) # yapf:disable + return result[0] + + def __vocoder_forward(self, melspec): + dim0 = list(melspec.shape)[-1] + if dim0 != self.__voc_config.num_mels: + raise TtsVocoderMelspecShapeMismatchException( + 'input melspec mismatch require {} but {}'.format( + self.__voc_config.num_mels, dim0)) + with torch.no_grad(): + x = melspec.T + x = torch.FloatTensor(x).to(self.__device) + if len(x.shape) == 2: + x = x.unsqueeze(0) + y_g_hat = self.__generator(x) + audio = y_g_hat.squeeze() + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype('int16') + return audio + + def forward(self, text): + if not self.__model_loaded: + self.__load_am() + self.__load_vocoder() + self.__model_loaded = True + return self.__vocoder_forward(self.__am_forward(text)) diff --git a/modelscope/pipelines/audio/text_to_speech_pipeline.py b/modelscope/pipelines/audio/text_to_speech_pipeline.py index d8d7ca02..23380e25 100644 --- a/modelscope/pipelines/audio/text_to_speech_pipeline.py +++ b/modelscope/pipelines/audio/text_to_speech_pipeline.py @@ -37,7 +37,7 @@ class TextToSpeechSambertHifiganPipeline(Pipeline): """ output_wav = {} for label, text in inputs.items(): - output_wav[label] = self.model.forward(text) + output_wav[label] = self.model.forward(text, inputs.get('voice')) return {OutputKeys.OUTPUT_PCM: output_wav} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/utils/audio/tts_exceptions.py b/modelscope/utils/audio/tts_exceptions.py index 6204582d..8c73b603 100644 --- a/modelscope/utils/audio/tts_exceptions.py +++ b/modelscope/utils/audio/tts_exceptions.py @@ -17,6 +17,13 @@ class TtsModelConfigurationExcetion(TtsException): pass +class TtsVoiceNotExistsException(TtsException): + """ + TTS voice not exists exception. + """ + pass + + class TtsFrontendException(TtsException): """ TTS frontend module level exceptions. diff --git a/output.wav b/output.wav new file mode 100644 index 00000000..fcbc39ad --- /dev/null +++ b/output.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4153d9ffc0b72eeaf162b5c9f4426f95dcea2bb0da9e7b5e1b72fd2643b1915 +size 50444 diff --git a/tests/pipelines/test_text_to_speech.py b/tests/pipelines/test_text_to_speech.py index bd9ddb20..37fe07e5 100644 --- a/tests/pipelines/test_text_to_speech.py +++ b/tests/pipelines/test_text_to_speech.py @@ -26,13 +26,14 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): def test_pipeline(self): single_test_case_label = 'test_case_label_0' text = '今天北京天气怎么样?' 
- model_id = 'damo/speech_sambert-hifigan_tts_zhitian_emo_zhcn_16k' + model_id = 'damo/speech_sambert-hifigan_tts_zhcn_16k' + voice = 'zhitian_emo' sambert_hifigan_tts = pipeline( task=Tasks.text_to_speech, model=model_id) self.assertTrue(sambert_hifigan_tts is not None) - test_cases = {single_test_case_label: text} - output = sambert_hifigan_tts(test_cases) + inputs = {single_test_case_label: text, 'voice': voice} + output = sambert_hifigan_tts(inputs) self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM]) pcm = output[OutputKeys.OUTPUT_PCM][single_test_case_label] write('output.wav', 16000, pcm)
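
Usage sketch: the listing below illustrates how the multi-voice model introduced in this patch is expected to be driven, inferred from the loading code in SambertHifigan.__init__ and the updated test case above. The voices.json layout, the model id, and the import paths are assumptions based on this diff and the existing test file, not a documented contract.

    # Assumed on-disk layout, inferred from SambertHifigan.__init__ in this diff:
    #   <model_dir>/voices.zip  ->  <model_dir>/voices/
    #       voices.json          e.g. {"voices": ["zhitian_emo", ...]}
    #       <voice_name>/        per-voice AM checkpoint and vocoder weights
    # The first entry in "voices" becomes the default when no 'voice' key is given.

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    tts = pipeline(
        task=Tasks.text_to_speech,
        model='damo/speech_sambert-hifigan_tts_zhcn_16k')

    # The pipeline forwards inputs.get('voice') to SambertHifigan.forward();
    # omitting the 'voice' key falls back to the default voice from voices.json.
    output = tts({'case_0': '今天北京天气怎么样?', 'voice': 'zhitian_emo'})
    pcm = output[OutputKeys.OUTPUT_PCM]['case_0']  # int16 PCM at 16 kHz

Because TextToSpeechSambertHifiganPipeline.forward iterates over every key of the input dict, including 'voice', only read back the labels you care about from OUTPUT_PCM.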