
[To #42322933] Make all voice models of the same lang_type into one

[To #42322933] Merge voice models of the same language type into a single model card; this makes it much easier for the demo service to handle different voices.
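As a usage sketch of the merged model card: the caller now loads one language-level model and picks a concrete voice via a 'voice' key in the inputs dict. The model id and voice name below come from the updated tests/pipelines/test_text_to_speech.py in this commit; the imports and the label key are illustrative assumptions.

# Minimal usage sketch of the merged model card (values from the updated test below).
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

tts = pipeline(
    task=Tasks.text_to_speech,
    model='damo/speech_sambert-hifigan_tts_zhcn_16k')  # one card per language

inputs = {
    'case_0': '今天北京天气怎么样?',  # label -> text to synthesize
    'voice': 'zhitian_emo',            # a voice packaged inside the card
}
output = tts(inputs)
pcm = output[OutputKeys.OUTPUT_PCM]['case_0']  # int16 numpy waveform, 16 kHz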

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9353173
master · jiaqi.sjq · 3 years ago · commit 1a486b4b28
6 changed files with 356 additions and 304 deletions

  1. modelscope/models/audio/tts/sambert_hifi.py (+41, -300)
  2. modelscope/models/audio/tts/voice.py (+300, -0)
  3. modelscope/pipelines/audio/text_to_speech_pipeline.py (+1, -1)
  4. modelscope/utils/audio/tts_exceptions.py (+7, -0)
  5. output.wav (+3, -0)
  6. tests/pipelines/test_text_to_speech.py (+4, -3)

modelscope/models/audio/tts/sambert_hifi.py (+41, -300)

@@ -8,9 +8,7 @@ from typing import Any, Dict, Optional, Union

import json
import numpy as np
import tensorflow as tf
import torch
from sklearn.preprocessing import MultiLabelBinarizer

from modelscope.metainfo import Models
from modelscope.models.base import Model
@@ -18,49 +16,20 @@ from modelscope.models.builder import MODELS
from modelscope.utils.audio.tts_exceptions import (
TtsFrontendInitializeFailedException,
TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationExcetion,
TtsVocoderMelspecShapeMismatchException)
TtsVocoderMelspecShapeMismatchException, TtsVoiceNotExistsException)
from modelscope.utils.constant import ModelFile, Tasks
from .models import Generator, create_am_model
from .text.symbols import load_symbols
from .text.symbols_dict import SymbolsDict
from .voice import Voice

__all__ = ['SambertHifigan']
MAX_WAV_VALUE = 32768.0


def multi_label_symbol_to_sequence(my_classes, my_symbol):
one_hot = MultiLabelBinarizer(classes=my_classes)
tokens = my_symbol.strip().split(' ')
sequences = []
for token in tokens:
sequences.append(tuple(token.split('&')))
return one_hot.fit_transform(sequences)


def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
checkpoint_dict = torch.load(filepath, map_location=device)
return checkpoint_dict
import tensorflow as tf # isort:skip


class AttrDict(dict):

def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
__all__ = ['SambertHifigan']


@MODELS.register_module(
Tasks.text_to_speech, module_name=Models.sambert_hifigan)
class SambertHifigan(Model):

def __init__(self,
model_dir,
pitch_control_str='',
duration_control_str='',
energy_control_str='',
*args,
**kwargs):
def __init__(self, model_dir, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
if 'am' not in kwargs:
raise TtsModelConfigurationExcetion(
@@ -71,284 +40,56 @@ class SambertHifigan(Model):
if 'lang_type' not in kwargs:
raise TtsModelConfigurationExcetion(
'configuration model field missing lang_type!')
am_cfg = kwargs['am']
voc_cfg = kwargs['vocoder']
# initialize frontend
import ttsfrd
frontend = ttsfrd.TtsFrontendEngine()
zip_file = os.path.join(model_dir, 'resource.zip')
self._res_path = os.path.join(model_dir, 'resource')
self.__res_path = os.path.join(model_dir, 'resource')
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
zip_ref.extractall(model_dir)
if not frontend.initialize(self._res_path):
if not frontend.initialize(self.__res_path):
raise TtsFrontendInitializeFailedException(
'resource invalid: {}'.format(self._res_path))
'resource invalid: {}'.format(self.__res_path))
if not frontend.set_lang_type(kwargs['lang_type']):
raise TtsFrontendLanguageTypeInvalidException(
'language type invalid: {}'.format(kwargs['lang_type']))
self._frontend = frontend

# initialize am
tf.reset_default_graph()
local_am_ckpt_path = os.path.join(ModelFile.TF_CHECKPOINT_FOLDER,
'ckpt')
self._am_ckpt_path = os.path.join(model_dir, local_am_ckpt_path)
self._dict_path = os.path.join(model_dir, 'dicts')
self._am_hparams = tf.contrib.training.HParams(**kwargs['am'])
has_mask = True
if self._am_hparams.get('has_mask') is not None:
has_mask = self._am_hparams.has_mask
print('set has_mask to {}'.format(has_mask))
values = self._am_hparams.values()
hp = [' {}:{}'.format(name, values[name]) for name in sorted(values)]
print('Hyperparameters:\n' + '\n'.join(hp))
model_name = 'robutrans'
self._lfeat_type_list = self._am_hparams.lfeat_type_list.strip().split(
',')
sy, tone, syllable_flag, word_segment, emo_category, speaker = load_symbols(
self._dict_path, has_mask)
self._sy = sy
self._tone = tone
self._syllable_flag = syllable_flag
self._word_segment = word_segment
self._emo_category = emo_category
self._speaker = speaker
self._inputs_dim = dict()
for lfeat_type in self._lfeat_type_list:
if lfeat_type == 'sy':
self._inputs_dim[lfeat_type] = len(sy)
elif lfeat_type == 'tone':
self._inputs_dim[lfeat_type] = len(tone)
elif lfeat_type == 'syllable_flag':
self._inputs_dim[lfeat_type] = len(syllable_flag)
elif lfeat_type == 'word_segment':
self._inputs_dim[lfeat_type] = len(word_segment)
elif lfeat_type == 'emo_category':
self._inputs_dim[lfeat_type] = len(emo_category)
elif lfeat_type == 'speaker':
self._inputs_dim[lfeat_type] = len(speaker)

self._symbols_dict = SymbolsDict(sy, tone, syllable_flag, word_segment,
emo_category, speaker,
self._inputs_dim,
self._lfeat_type_list)
dim_inputs = sum(self._inputs_dim.values(
)) - self._inputs_dim['speaker'] - self._inputs_dim['emo_category']
inputs = tf.placeholder(tf.float32, [1, None, dim_inputs], 'inputs')
inputs_emotion = tf.placeholder(
tf.float32, [1, None, self._inputs_dim['emo_category']],
'inputs_emotion')
inputs_speaker = tf.placeholder(tf.float32,
[1, None, self._inputs_dim['speaker']],
'inputs_speaker')
input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
pitch_contours_scale = tf.placeholder(tf.float32, [1, None],
'pitch_contours_scale')
energy_contours_scale = tf.placeholder(tf.float32, [1, None],
'energy_contours_scale')
duration_scale = tf.placeholder(tf.float32, [1, None],
'duration_scale')
with tf.variable_scope('model') as _:
self._model = create_am_model(model_name, self._am_hparams)
self._model.initialize(
inputs,
inputs_emotion,
inputs_speaker,
input_lengths,
duration_scales=duration_scale,
pitch_scales=pitch_contours_scale,
energy_scales=energy_contours_scale)
self._mel_spec = self._model.mel_outputs[0]
self._duration_outputs = self._model.duration_outputs[0]
self._duration_outputs_ = self._model.duration_outputs_[0]
self._pitch_contour_outputs = self._model.pitch_contour_outputs[0]
self._energy_contour_outputs = self._model.energy_contour_outputs[
0]
self._embedded_inputs_emotion = self._model.embedded_inputs_emotion[
0]
self._embedding_fsmn_outputs = self._model.embedding_fsmn_outputs[
0]
self._encoder_outputs = self._model.encoder_outputs[0]
self._pitch_embeddings = self._model.pitch_embeddings[0]
self._energy_embeddings = self._model.energy_embeddings[0]
self._LR_outputs = self._model.LR_outputs[0]
self._postnet_fsmn_outputs = self._model.postnet_fsmn_outputs[0]
self._attention_h = self._model.attention_h
self._attention_x = self._model.attention_x

print('Loading checkpoint: %s' % self._am_ckpt_path)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
self._session = tf.Session(config=config)
self._session.run(tf.global_variables_initializer())

saver = tf.train.Saver()
saver.restore(self._session, self._am_ckpt_path)

duration_cfg_lst = []
if len(duration_control_str) != 0:
for item in duration_control_str.strip().split('|'):
percent, scale = item.lstrip('(').rstrip(')').split(',')
duration_cfg_lst.append((float(percent), float(scale)))

self._duration_cfg_lst = duration_cfg_lst

pitch_contours_cfg_lst = []
if len(pitch_control_str) != 0:
for item in pitch_control_str.strip().split('|'):
percent, scale = item.lstrip('(').rstrip(')').split(',')
pitch_contours_cfg_lst.append(
(float(percent), float(scale)))

self._pitch_contours_cfg_lst = pitch_contours_cfg_lst

energy_contours_cfg_lst = []
if len(energy_control_str) != 0:
for item in energy_control_str.strip().split('|'):
percent, scale = item.lstrip('(').rstrip(')').split(',')
energy_contours_cfg_lst.append(
(float(percent), float(scale)))

self._energy_contours_cfg_lst = energy_contours_cfg_lst

# initialize vocoder
self._voc_ckpt_path = os.path.join(model_dir,
ModelFile.TORCH_MODEL_BIN_FILE)
self._voc_config = AttrDict(**kwargs['vocoder'])
print(self._voc_config)
if torch.cuda.is_available():
torch.manual_seed(self._voc_config.seed)
self._device = torch.device('cuda')
self.__frontend = frontend
zip_file = os.path.join(model_dir, 'voices.zip')
self.__voice_path = os.path.join(model_dir, 'voices')
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
zip_ref.extractall(model_dir)
voice_cfg_path = os.path.join(self.__voice_path, 'voices.json')
with open(voice_cfg_path, 'r') as f:
voice_cfg = json.load(f)
if 'voices' not in voice_cfg:
raise TtsModelConfigurationExcetion('voices invalid')
self.__voice = {}
for name in voice_cfg['voices']:
voice_path = os.path.join(self.__voice_path, name)
if not os.path.exists(voice_path):
continue
self.__voice[name] = Voice(name, voice_path, am_cfg, voc_cfg)
if voice_cfg['voices']:
self.__default_voice_name = voice_cfg['voices'][0]
else:
self._device = torch.device('cpu')
self._generator = Generator(self._voc_config).to(self._device)
state_dict_g = load_checkpoint(self._voc_ckpt_path, self._device)
self._generator.load_state_dict(state_dict_g['generator'])
self._generator.eval()
self._generator.remove_weight_norm()

def am_synthesis_one_sentences(self, text):
cleaner_names = [
x.strip() for x in self._am_hparams.cleaners.split(',')
]

lfeat_symbol = text.strip().split(' ')
lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list))
for this_lfeat_symbol in lfeat_symbol:
this_lfeat_symbol = this_lfeat_symbol.strip('{').strip('}').split(
'$')
if len(this_lfeat_symbol) != len(self._lfeat_type_list):
raise Exception(
'Length of this_lfeat_symbol in training data'
+ ' is not equal to the length of lfeat_type_list, '
+ str(len(this_lfeat_symbol)) + ' VS. '
+ str(len(self._lfeat_type_list)))
index = 0
while index < len(lfeat_symbol_separate):
lfeat_symbol_separate[index] = lfeat_symbol_separate[
index] + this_lfeat_symbol[index] + ' '
index = index + 1

index = 0
lfeat_type = self._lfeat_type_list[index]
sequence = self._symbols_dict.symbol_to_sequence(
lfeat_symbol_separate[index].strip(), lfeat_type, cleaner_names)
sequence_array = np.asarray(
sequence[:-1],
dtype=np.int32) # sequence length minus 1 to ignore EOS ~
inputs = np.eye(
self._inputs_dim[lfeat_type], dtype=np.float32)[sequence_array]
index = index + 1
while index < len(self._lfeat_type_list) - 2:
lfeat_type = self._lfeat_type_list[index]
sequence = self._symbols_dict.symbol_to_sequence(
lfeat_symbol_separate[index].strip(), lfeat_type,
cleaner_names)
sequence_array = np.asarray(
sequence[:-1],
dtype=np.int32) # sequence length minus 1 to ignore EOS ~
inputs_temp = np.eye(
self._inputs_dim[lfeat_type], dtype=np.float32)[sequence_array]
inputs = np.concatenate((inputs, inputs_temp), axis=1)
index = index + 1
seq = inputs

lfeat_type = 'emo_category'
inputs_emotion = multi_label_symbol_to_sequence(
self._emo_category, lfeat_symbol_separate[index].strip())
# inputs_emotion = inputs_emotion * 1.5
index = index + 1

lfeat_type = 'speaker'
inputs_speaker = multi_label_symbol_to_sequence(
self._speaker, lfeat_symbol_separate[index].strip())

duration_scale = np.ones((len(seq), ), dtype=np.float32)
start_idx = 0
for (percent, scale) in self._duration_cfg_lst:
duration_scale[start_idx:start_idx
+ int(percent * len(seq))] = scale
start_idx += int(percent * len(seq))

pitch_contours_scale = np.ones((len(seq), ), dtype=np.float32)
start_idx = 0
for (percent, scale) in self._pitch_contours_cfg_lst:
pitch_contours_scale[start_idx:start_idx
+ int(percent * len(seq))] = scale
start_idx += int(percent * len(seq))

energy_contours_scale = np.ones((len(seq), ), dtype=np.float32)
start_idx = 0
for (percent, scale) in self._energy_contours_cfg_lst:
energy_contours_scale[start_idx:start_idx
+ int(percent * len(seq))] = scale
start_idx += int(percent * len(seq))

feed_dict = {
self._model.inputs: [np.asarray(seq, dtype=np.float32)],
self._model.inputs_emotion:
[np.asarray(inputs_emotion, dtype=np.float32)],
self._model.inputs_speaker:
[np.asarray(inputs_speaker, dtype=np.float32)],
self._model.input_lengths:
np.asarray([len(seq)], dtype=np.int32),
self._model.duration_scales: [duration_scale],
self._model.pitch_scales: [pitch_contours_scale],
self._model.energy_scales: [energy_contours_scale]
}

result = self._session.run([
self._mel_spec, self._duration_outputs, self._duration_outputs_,
self._pitch_contour_outputs, self._embedded_inputs_emotion,
self._embedding_fsmn_outputs, self._encoder_outputs,
self._pitch_embeddings, self._LR_outputs,
self._postnet_fsmn_outputs, self._energy_contour_outputs,
self._energy_embeddings, self._attention_x, self._attention_h
], feed_dict=feed_dict) # yapf:disable
return result[0]

def vocoder_process(self, melspec):
dim0 = list(melspec.shape)[-1]
if dim0 != self._voc_config.num_mels:
raise TtsVocoderMelspecShapeMismatchException(
'input melspec mismatch require {} but {}'.format(
self._voc_config.num_mels, dim0))
with torch.no_grad():
x = melspec.T
x = torch.FloatTensor(x).to(self._device)
if len(x.shape) == 2:
x = x.unsqueeze(0)
y_g_hat = self._generator(x)
audio = y_g_hat.squeeze()
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype('int16')
return audio

def forward(self, text):
result = self._frontend.gen_tacotron_symbols(text)
raise TtsVoiceNotExistsException('voices is empty in voices.json')

def __synthesis_one_sentences(self, voice_name, text):
if voice_name not in self.__voice:
raise TtsVoiceNotExistsException(f'Voice {voice_name} not exists')
return self.__voice[voice_name].forward(text)

def forward(self, text: str, voice_name: str = None):
voice = self.__default_voice_name
if voice_name is not None:
voice = voice_name
result = self.__frontend.gen_tacotron_symbols(text)
texts = [s for s in result.splitlines() if s != '']
audio_total = np.empty((0), dtype='int16')
for line in texts:
line = line.strip().split('\t')
audio = self.vocoder_process(
self.am_synthesis_one_sentences(line[1]))
audio = self.__synthesis_one_sentences(voice, line[1])
audio_total = np.append(audio_total, audio, axis=0)
return audio_total
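
The on-disk layout that the rewritten SambertHifigan.__init__ expects can be reconstructed from the loading code above; the concrete voice names below are assumptions, and the helper is only a sketch of the same discovery logic (voices declared in voices.json, skipped if the per-voice directory is missing, first declared entry used as the default voice).

# Sketch of the package layout implied by the loading code above (names are assumptions):
#
#   <model_dir>/
#       resource.zip   -> extracted to resource/ for the ttsfrd frontend
#       voices.zip     -> extracted to voices/
#       voices/
#           voices.json                               e.g. {"voices": ["zhitian_emo", ...]}
#           zhitian_emo/
#               <ModelFile.TF_CHECKPOINT_FOLDER>/     Sambert acoustic model checkpoint
#               <ModelFile.TORCH_MODEL_BIN_FILE>      HiFi-GAN vocoder weights
import json
import os

def discover_voices(model_dir):
    """Return (voice names declared and present on disk, default voice name)."""
    voice_dir = os.path.join(model_dir, 'voices')
    with open(os.path.join(voice_dir, 'voices.json')) as f:
        cfg = json.load(f)
    declared = cfg.get('voices', [])  # the real code raises TtsModelConfigurationExcetion if missing
    present = [n for n in declared if os.path.exists(os.path.join(voice_dir, n))]
    default = declared[0] if declared else None  # first declared voice is the default
    return present, default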

modelscope/models/audio/tts/voice.py (+300, -0)

@@ -0,0 +1,300 @@
import os

import json
import numpy as np
import torch
from sklearn.preprocessing import MultiLabelBinarizer

from modelscope.utils.constant import ModelFile, Tasks
from .models import Generator, create_am_model
from .text.symbols import load_symbols
from .text.symbols_dict import SymbolsDict

import tensorflow as tf # isort:skip

MAX_WAV_VALUE = 32768.0


def multi_label_symbol_to_sequence(my_classes, my_symbol):
one_hot = MultiLabelBinarizer(classes=my_classes)
tokens = my_symbol.strip().split(' ')
sequences = []
for token in tokens:
sequences.append(tuple(token.split('&')))
return one_hot.fit_transform(sequences)


def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
checkpoint_dict = torch.load(filepath, map_location=device)
return checkpoint_dict


class AttrDict(dict):

def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self


class Voice:

def __init__(self, voice_name, voice_path, am_hparams, voc_config):
self.__voice_name = voice_name
self.__voice_path = voice_path
self.__am_hparams = tf.contrib.training.HParams(**am_hparams)
self.__voc_config = AttrDict(**voc_config)
self.__model_loaded = False

def __load_am(self):
local_am_ckpt_path = os.path.join(self.__voice_path,
ModelFile.TF_CHECKPOINT_FOLDER)
self.__am_ckpt_path = os.path.join(local_am_ckpt_path, 'ckpt')
self.__dict_path = os.path.join(self.__voice_path, 'dicts')
has_mask = True
if self.__am_hparams.get('has_mask') is not None:
has_mask = self.__am_hparams.has_mask
model_name = 'robutrans'
self.__lfeat_type_list = self.__am_hparams.lfeat_type_list.strip(
).split(',')
sy, tone, syllable_flag, word_segment, emo_category, speaker = load_symbols(
self.__dict_path, has_mask)
self.__sy = sy
self.__tone = tone
self.__syllable_flag = syllable_flag
self.__word_segment = word_segment
self.__emo_category = emo_category
self.__speaker = speaker
self.__inputs_dim = dict()
for lfeat_type in self.__lfeat_type_list:
if lfeat_type == 'sy':
self.__inputs_dim[lfeat_type] = len(sy)
elif lfeat_type == 'tone':
self.__inputs_dim[lfeat_type] = len(tone)
elif lfeat_type == 'syllable_flag':
self.__inputs_dim[lfeat_type] = len(syllable_flag)
elif lfeat_type == 'word_segment':
self.__inputs_dim[lfeat_type] = len(word_segment)
elif lfeat_type == 'emo_category':
self.__inputs_dim[lfeat_type] = len(emo_category)
elif lfeat_type == 'speaker':
self.__inputs_dim[lfeat_type] = len(speaker)

self.__symbols_dict = SymbolsDict(sy, tone, syllable_flag,
word_segment, emo_category, speaker,
self.__inputs_dim,
self.__lfeat_type_list)
dim_inputs = sum(self.__inputs_dim.values(
)) - self.__inputs_dim['speaker'] - self.__inputs_dim['emo_category']
self.__graph = tf.Graph()
with self.__graph.as_default():
inputs = tf.placeholder(tf.float32, [1, None, dim_inputs],
'inputs')
inputs_emotion = tf.placeholder(
tf.float32, [1, None, self.__inputs_dim['emo_category']],
'inputs_emotion')
inputs_speaker = tf.placeholder(
tf.float32, [1, None, self.__inputs_dim['speaker']],
'inputs_speaker')
input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
pitch_contours_scale = tf.placeholder(tf.float32, [1, None],
'pitch_contours_scale')
energy_contours_scale = tf.placeholder(tf.float32, [1, None],
'energy_contours_scale')
duration_scale = tf.placeholder(tf.float32, [1, None],
'duration_scale')
with tf.variable_scope('model') as _:
self.__model = create_am_model(model_name, self.__am_hparams)
self.__model.initialize(
inputs,
inputs_emotion,
inputs_speaker,
input_lengths,
duration_scales=duration_scale,
pitch_scales=pitch_contours_scale,
energy_scales=energy_contours_scale)
self.__mel_spec = self.__model.mel_outputs[0]
self.__duration_outputs = self.__model.duration_outputs[0]
self.__duration_outputs_ = self.__model.duration_outputs_[0]
self.__pitch_contour_outputs = self.__model.pitch_contour_outputs[
0]
self.__energy_contour_outputs = self.__model.energy_contour_outputs[
0]
self.__embedded_inputs_emotion = self.__model.embedded_inputs_emotion[
0]
self.__embedding_fsmn_outputs = self.__model.embedding_fsmn_outputs[
0]
self.__encoder_outputs = self.__model.encoder_outputs[0]
self.__pitch_embeddings = self.__model.pitch_embeddings[0]
self.__energy_embeddings = self.__model.energy_embeddings[0]
self.__LR_outputs = self.__model.LR_outputs[0]
self.__postnet_fsmn_outputs = self.__model.postnet_fsmn_outputs[
0]
self.__attention_h = self.__model.attention_h
self.__attention_x = self.__model.attention_x

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
self.__session = tf.Session(config=config)
self.__session.run(tf.global_variables_initializer())

saver = tf.train.Saver()
saver.restore(self.__session, self.__am_ckpt_path)

def __load_vocoder(self):
self.__voc_ckpt_path = os.path.join(self.__voice_path,
ModelFile.TORCH_MODEL_BIN_FILE)
if torch.cuda.is_available():
torch.manual_seed(self.__voc_config.seed)
self.__device = torch.device('cuda')
else:
self.__device = torch.device('cpu')
self.__generator = Generator(self.__voc_config).to(self.__device)
state_dict_g = load_checkpoint(self.__voc_ckpt_path, self.__device)
self.__generator.load_state_dict(state_dict_g['generator'])
self.__generator.eval()
self.__generator.remove_weight_norm()

def __am_forward(self,
text,
pitch_control_str='',
duration_control_str='',
energy_control_str=''):
duration_cfg_lst = []
if len(duration_control_str) != 0:
for item in duration_control_str.strip().split('|'):
percent, scale = item.lstrip('(').rstrip(')').split(',')
duration_cfg_lst.append((float(percent), float(scale)))
pitch_contours_cfg_lst = []
if len(pitch_control_str) != 0:
for item in pitch_control_str.strip().split('|'):
percent, scale = item.lstrip('(').rstrip(')').split(',')
pitch_contours_cfg_lst.append((float(percent), float(scale)))
energy_contours_cfg_lst = []
if len(energy_control_str) != 0:
for item in energy_control_str.strip().split('|'):
percent, scale = item.lstrip('(').rstrip(')').split(',')
energy_contours_cfg_lst.append((float(percent), float(scale)))
cleaner_names = [
x.strip() for x in self.__am_hparams.cleaners.split(',')
]

lfeat_symbol = text.strip().split(' ')
lfeat_symbol_separate = [''] * int(len(self.__lfeat_type_list))
for this_lfeat_symbol in lfeat_symbol:
this_lfeat_symbol = this_lfeat_symbol.strip('{').strip('}').split(
'$')
if len(this_lfeat_symbol) != len(self.__lfeat_type_list):
raise Exception(
'Length of this_lfeat_symbol in training data'
+ ' is not equal to the length of lfeat_type_list, '
+ str(len(this_lfeat_symbol)) + ' VS. '
+ str(len(self.__lfeat_type_list)))
index = 0
while index < len(lfeat_symbol_separate):
lfeat_symbol_separate[index] = lfeat_symbol_separate[
index] + this_lfeat_symbol[index] + ' '
index = index + 1

index = 0
lfeat_type = self.__lfeat_type_list[index]
sequence = self.__symbols_dict.symbol_to_sequence(
lfeat_symbol_separate[index].strip(), lfeat_type, cleaner_names)
sequence_array = np.asarray(
sequence[:-1],
dtype=np.int32) # sequence length minus 1 to ignore EOS ~
inputs = np.eye(
self.__inputs_dim[lfeat_type], dtype=np.float32)[sequence_array]
index = index + 1
while index < len(self.__lfeat_type_list) - 2:
lfeat_type = self.__lfeat_type_list[index]
sequence = self.__symbols_dict.symbol_to_sequence(
lfeat_symbol_separate[index].strip(), lfeat_type,
cleaner_names)
sequence_array = np.asarray(
sequence[:-1],
dtype=np.int32) # sequence length minus 1 to ignore EOS ~
inputs_temp = np.eye(
self.__inputs_dim[lfeat_type],
dtype=np.float32)[sequence_array]
inputs = np.concatenate((inputs, inputs_temp), axis=1)
index = index + 1
seq = inputs

lfeat_type = 'emo_category'
inputs_emotion = multi_label_symbol_to_sequence(
self.__emo_category, lfeat_symbol_separate[index].strip())
# inputs_emotion = inputs_emotion * 1.5
index = index + 1

lfeat_type = 'speaker'
inputs_speaker = multi_label_symbol_to_sequence(
self.__speaker, lfeat_symbol_separate[index].strip())

duration_scale = np.ones((len(seq), ), dtype=np.float32)
start_idx = 0
for (percent, scale) in duration_cfg_lst:
duration_scale[start_idx:start_idx
+ int(percent * len(seq))] = scale
start_idx += int(percent * len(seq))

pitch_contours_scale = np.ones((len(seq), ), dtype=np.float32)
start_idx = 0
for (percent, scale) in pitch_contours_cfg_lst:
pitch_contours_scale[start_idx:start_idx
+ int(percent * len(seq))] = scale
start_idx += int(percent * len(seq))

energy_contours_scale = np.ones((len(seq), ), dtype=np.float32)
start_idx = 0
for (percent, scale) in energy_contours_cfg_lst:
energy_contours_scale[start_idx:start_idx
+ int(percent * len(seq))] = scale
start_idx += int(percent * len(seq))

feed_dict = {
self.__model.inputs: [np.asarray(seq, dtype=np.float32)],
self.__model.inputs_emotion:
[np.asarray(inputs_emotion, dtype=np.float32)],
self.__model.inputs_speaker:
[np.asarray(inputs_speaker, dtype=np.float32)],
self.__model.input_lengths:
np.asarray([len(seq)], dtype=np.int32),
self.__model.duration_scales: [duration_scale],
self.__model.pitch_scales: [pitch_contours_scale],
self.__model.energy_scales: [energy_contours_scale]
}

result = self.__session.run([
self.__mel_spec, self.__duration_outputs, self.__duration_outputs_,
self.__pitch_contour_outputs, self.__embedded_inputs_emotion,
self.__embedding_fsmn_outputs, self.__encoder_outputs,
self.__pitch_embeddings, self.__LR_outputs,
self.__postnet_fsmn_outputs, self.__energy_contour_outputs,
self.__energy_embeddings, self.__attention_x, self.__attention_h
], feed_dict=feed_dict) # yapf:disable
return result[0]

def __vocoder_forward(self, melspec):
dim0 = list(melspec.shape)[-1]
if dim0 != self.__voc_config.num_mels:
raise TtsVocoderMelspecShapeMismatchException(
'input melspec mismatch require {} but {}'.format(
self.__voc_config.num_mels, dim0))
with torch.no_grad():
x = melspec.T
x = torch.FloatTensor(x).to(self.__device)
if len(x.shape) == 2:
x = x.unsqueeze(0)
y_g_hat = self.__generator(x)
audio = y_g_hat.squeeze()
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype('int16')
return audio

def forward(self, text):
if not self.__model_loaded:
self.__load_am()
self.__load_vocoder()
self.__model_loaded = True
return self.__vocoder_forward(self.__am_forward(text))
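
A Voice instance can in principle be driven directly; the sketch below assumes that am_cfg and voc_cfg are the 'am' and 'vocoder' sections of the model configuration and that the input is one symbol string produced by the ttsfrd frontend (what SambertHifigan.forward passes as line[1]). In practice SambertHifigan builds these objects itself from the extracted voices/ directory.

# Hedged sketch of using Voice directly; paths and config dicts are assumptions.
import os

from modelscope.models.audio.tts.voice import Voice

def synthesize_one_line(voices_dir, voice_name, am_cfg, voc_cfg, symbol_line):
    """symbol_line is one frontend symbol string (column 1 of gen_tacotron_symbols output)."""
    voice = Voice(voice_name, os.path.join(voices_dir, voice_name), am_cfg, voc_cfg)
    # The TensorFlow acoustic model and the PyTorch HiFi-GAN vocoder are loaded
    # lazily inside forward() on first use, so construction stays cheap.
    return voice.forward(symbol_line)  # int16 numpy waveform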

modelscope/pipelines/audio/text_to_speech_pipeline.py (+1, -1)

@@ -37,7 +37,7 @@ class TextToSpeechSambertHifiganPipeline(Pipeline):
"""
output_wav = {}
for label, text in inputs.items():
output_wav[label] = self.model.forward(text)
output_wav[label] = self.model.forward(text, inputs.get('voice'))
return {OutputKeys.OUTPUT_PCM: output_wav}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:


modelscope/utils/audio/tts_exceptions.py (+7, -0)

@@ -17,6 +17,13 @@ class TtsModelConfigurationExcetion(TtsException):
pass


class TtsVoiceNotExistsException(TtsException):
"""
TTS voice not exists exception.
"""
pass


class TtsFrontendException(TtsException):
"""
TTS frontend module level exceptions.


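The new TtsVoiceNotExistsException is raised both when voices.json declares no voices and when a caller requests a voice that is not packaged in the card. A minimal handling sketch on the caller side; the fallback relies on SambertHifigan.forward using the default (first declared) voice when voice_name is None, and 'model' is assumed to be a loaded SambertHifigan instance.

# Sketch of handling the new exception when an unknown voice is requested.
from modelscope.utils.audio.tts_exceptions import TtsVoiceNotExistsException

def forward_with_fallback(model, text, voice_name=None):
    try:
        return model.forward(text, voice_name)
    except TtsVoiceNotExistsException:
        # Fall back to the card's default voice (first entry in voices.json).
        return model.forward(text)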
output.wav (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4153d9ffc0b72eeaf162b5c9f4426f95dcea2bb0da9e7b5e1b72fd2643b1915
size 50444

tests/pipelines/test_text_to_speech.py (+4, -3)

@@ -26,13 +26,14 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
def test_pipeline(self):
single_test_case_label = 'test_case_label_0'
text = '今天北京天气怎么样?'
model_id = 'damo/speech_sambert-hifigan_tts_zhitian_emo_zhcn_16k'
model_id = 'damo/speech_sambert-hifigan_tts_zhcn_16k'
voice = 'zhitian_emo'

sambert_hifigan_tts = pipeline(
task=Tasks.text_to_speech, model=model_id)
self.assertTrue(sambert_hifigan_tts is not None)
test_cases = {single_test_case_label: text}
output = sambert_hifigan_tts(test_cases)
inputs = {single_test_case_label: text, 'voice': voice}
output = sambert_hifigan_tts(inputs)
self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM])
pcm = output[OutputKeys.OUTPUT_PCM][single_test_case_label]
write('output.wav', 16000, pcm)

