Skip to content

Commit

Permalink
Began adding Griffin-Lim vocoder
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah committed Aug 9, 2019
1 parent d2fc0a5 commit d0cc6b4
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 54 deletions.
130 changes: 84 additions & 46 deletions gen_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,58 @@
import argparse
from utils.text import text_to_sequence
from utils.display import save_attention, simple_table
from utils.dsp import reconstruct_waveform, save_wav

if __name__ == "__main__":

# Parse Arguments
parser = argparse.ArgumentParser(description='TTS Generator')
parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
parser.add_argument('--weights_path', '-w', type=str, help='[string/path] Load in different Tacotron Weights')
parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')

parser.set_defaults(batched=None)
parser.set_defaults(input_text=None)
parser.set_defaults(weights_path=None)
parser.set_defaults(save_attention=False)

# name of subcommand goes to args.vocoder
subparsers = parser.add_subparsers(required=True, dest='vocoder')

wr_parser = subparsers.add_parser('wavernn', aliases=['wr'])
wr_parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
wr_parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
wr_parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
wr_parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
wr_parser.set_defaults(batched=None)

gl_parser = subparsers.add_parser('griffinlim', aliases=['gl'])
gl_parser.add_argument('--iters', type=int, default=32, help='[int] number of griffinlim iterations')

args = parser.parse_args()

if args.vocoder in ['griffinlim', 'gl']:
args.vocoder = 'griffinlim'
elif args.vocoder in ['wavernn', 'wr']:
args.vocoder = 'wavernn'
else:
raise argparse.ArgumentError('Must provide a valid vocoder type!')

hp.configure(args.hp_file) # Load hparams from file
# set defaults for any arguments that depend on hparams
if args.target is None:
args.target = hp.voc_target
if args.overlap is None:
args.overlap = hp.voc_overlap
if args.batched is None:
args.batched = hp.voc_gen_batched

batched = args.batched
target = args.target
overlap = args.overlap
if args.vocoder == 'wavernn':
if args.target is None:
args.target = hp.voc_target
if args.overlap is None:
args.overlap = hp.voc_overlap
if args.batched is None:
args.batched = hp.voc_gen_batched

batched = args.batched
target = args.target
overlap = args.overlap

input_text = args.input_text
weights_path = args.weights_path
save_attn = args.save_attention
Expand All @@ -53,23 +71,22 @@
device = torch.device('cpu')
print('Using device:', device)

print('\nInitialising WaveRNN Model...\n')

# Instantiate WaveRNN Model
voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
fc_dims=hp.voc_fc_dims,
bits=hp.bits,
pad=hp.voc_pad,
upsample_factors=hp.voc_upsample_factors,
feat_dims=hp.num_mels,
compute_dims=hp.voc_compute_dims,
res_out_dims=hp.voc_res_out_dims,
res_blocks=hp.voc_res_blocks,
hop_length=hp.hop_length,
sample_rate=hp.sample_rate,
mode=hp.voc_mode).to(device)

voc_model.load(paths.voc_latest_weights)
if args.vocoder == 'wavernn':
print('\nInitialising WaveRNN Model...\n')
# Instantiate WaveRNN Model
voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
fc_dims=hp.voc_fc_dims,
bits=hp.bits,
pad=hp.voc_pad,
upsample_factors=hp.voc_upsample_factors,
feat_dims=hp.num_mels,
compute_dims=hp.voc_compute_dims,
res_out_dims=hp.voc_res_out_dims,
res_blocks=hp.voc_res_blocks,
hop_length=hp.hop_length,
sample_rate=hp.sample_rate,
mode=hp.voc_mode).to(device)
voc_model.load(paths.voc_latest_weights)

print('\nInitialising Tacotron Model...\n')

Expand All @@ -96,31 +113,52 @@
with open('sentences.txt') as f:
inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f]

voc_k = voc_model.get_step() // 1000
tts_k = tts_model.get_step() // 1000

simple_table([('WaveRNN', str(voc_k) + 'k'),
('Tacotron', str(tts_k) + 'k'),
('r', tts_model.r),
('Generation Mode', 'Batched' if batched else 'Unbatched'),
('Target Samples', target if batched else 'N/A'),
('Overlap Samples', overlap if batched else 'N/A')])
if args.vocoder == 'wavernn':
voc_k = voc_model.get_step() // 1000
tts_k = tts_model.get_step() // 1000

simple_table([('Tacotron', str(tts_k) + 'k'),
('r', tts_model.r),
('Vocoder Type', 'WaveRNN'),
('WaveRNN', str(voc_k) + 'k'),
('Generation Mode', 'Batched' if batched else 'Unbatched'),
('Target Samples', target if batched else 'N/A'),
('Overlap Samples', overlap if batched else 'N/A')])

elif args.vocoder == 'griffinlim':
tts_k = tts_model.get_step() // 1000
simple_table([('Tacotron', str(tts_k) + 'k'),
('r', tts_model.r),
('Vocoder Type', 'Griffin-Lim'),
('GL Iters', args.iters)])

for i, x in enumerate(inputs, 1):

print(f'\n| Generating {i}/{len(inputs)}')
_, m, attention = tts_model.generate(x)

if args.vocoder == 'griffinlim':
v_type = args.vocoder
elif args.vocoder == 'wavernn' and args.batched:
v_type = 'wavernn_batched'
else:
v_type = 'wavernn_unbatched'

if input_text:
save_path = paths.tts_output/f'__input_{input_text[:10]}_{tts_k}k.wav'
save_path = paths.tts_output/f'__input_{input_text[:10]}_{v_type}_{tts_k}k.wav'
else:
save_path = paths.tts_output/f'{i}_batched{str(batched)}_{tts_k}k.wav'
save_path = paths.tts_output/f'{i}_{v_type}_{tts_k}k.wav'

if save_attn: save_attention(attention, save_path)

m = torch.tensor(m).unsqueeze(0)
m = (m + 4) / 8
if args.vocoder == 'wavernn':
m = torch.tensor(m).unsqueeze(0)
m = (m + 4) / 8

voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
elif args.vocoder == 'griffinlim':
wav = reconstruct_waveform(m, n_iter=args.iters)
save_wav(wav, save_path)

voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law)

print('\n\nDone.\n')
28 changes: 20 additions & 8 deletions utils/dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,14 @@ def encode_16bits(x):
return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)


mel_basis = None
def linear_to_mel(spectrogram):
global mel_basis
if mel_basis is None:
mel_basis = build_mel_basis()
return np.dot(mel_basis, spectrogram)

librosa.feature.melspectrogram(
S=spectrogram, sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin)

'''
def build_mel_basis():
return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin)

'''

def normalize(S):
return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1)
Expand Down Expand Up @@ -79,7 +76,9 @@ def melspectrogram(y):


def stft(y):
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
return librosa.stft(
y=y, sr=hp.sample_rate,
n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)


def pre_emphasis(x):
Expand All @@ -103,3 +102,16 @@ def decode_mu_law(y, mu, from_labels=True):
x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
return x

def reconstruct_waveform(mel, n_iter=32):
"""Uses Griffin-Lim phase reconstruction to convert from a normalized
mel spectrogram back into a waveform."""
denormalized = denormalize(mel)
amp_mel = db_to_amp(denormalized)
S = librosa.feature.inverse.mel_to_stft(
amp_mel, power=1, sr=hp.sample_rate,
n_fft=hp.n_fft, fmin=hp.fmin)
wav = librosa.core.griffinlim(
S, n_iter=n_iter,
hop_length=hp.hop_length, win_length=hp.win_length)
return wav

0 comments on commit d0cc6b4

Please sign in to comment.