
Merge pull request fatchord#126 from TheButlah/master
Add Multi-GPU training, Griffin-Lim vocoder, safely restore checkpoints, pathlib, multiple hparams files
fatchord authored Sep 8, 2019
2 parents 3f5a96e + 31179a6 commit 12922a7
Showing 20 changed files with 942 additions and 503 deletions.
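Among the changes, the new Griffin-Lim vocoder path is the most self-contained: gen_tacotron.py (below) imports reconstruct_waveform and save_wav from utils.dsp and calls reconstruct_waveform(m, n_iter=args.iters). The utils/dsp.py implementation itself is not part of this diff excerpt; as a rough sketch of what such a routine involves, assuming librosa >= 0.7 and LJSpeech-style hparams (sample_rate=22050, hop_length=275, win_length=1100 are assumptions here, not taken from the diff):

# Editor's sketch (not the PR's code): Griffin-Lim mel inversion, assuming
# librosa >= 0.7; the sample_rate/hop_length/win_length defaults are assumptions.
import librosa

def griffinlim_reconstruct(mel, sample_rate=22050, n_fft=2048,
                           hop_length=275, win_length=1100, n_iter=32):
    # mel: linear-amplitude mel spectrogram, shape (n_mels, frames).
    # Pseudo-invert the mel filterbank to get a linear STFT magnitude.
    stft_mag = librosa.feature.inverse.mel_to_stft(
        mel, sr=sample_rate, n_fft=n_fft, power=1)
    # Iteratively estimate the missing phase from magnitude alone.
    return librosa.griffinlim(stft_mag, n_iter=n_iter,
                              hop_length=hop_length, win_length=win_length)

More iterations trade speed for quality; 32 matches the --iters default added to the griffinlim subparser below.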
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,5 +1,6 @@
-# PyCharm files
+# IDE files
.idea
+.vscode

# Mac files
.DS_Store
18 changes: 7 additions & 11 deletions README.md
@@ -8,7 +8,7 @@ Pytorch implementation of Deepmind's WaveRNN model from [Efficient Neural Audio Synthesis](https://arxiv.org/abs/1802.08435)

# Installation

Ensure you have:

* Python >= 3.6
* [Pytorch 1 with CUDA](https://pytorch.org/)
@@ -37,20 +37,20 @@ You can also use that script to generate custom tts sentences and/or use '-u' to

Download the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) Dataset.

Edit **hparams.py**, point **wav_path** to your dataset and run:

> python preprocess.py
or use preprocess.py --path to point directly to the dataset
___

Here's my recommendation on what order to run things:

1 - Train Tacotron with:

> python train_tacotron.py
2 - You can let that finish training, or at any point you can use:

> python train_tacotron.py --force_gta
@@ -64,11 +64,11 @@ NB: You can always just run train_wavernn.py without --gta if you're not interested

4 - Generate Sentences with both models using:

-> python gen_tacotron.py
+> python gen_tacotron.py wavernn
this will generate default sentences. If you want to generate custom sentences you can use

-> python gen_tacotron.py --input_text "this is whatever you want it to be"
+> python gen_tacotron.py --input_text "this is whatever you want it to be" wavernn
And finally, you can always use --help on any of those scripts to see what options are available :)

@@ -84,7 +84,7 @@ Currently there are two pretrained models available in the /pretrained/ folder:

Both are trained on LJSpeech

* WaveRNN (Mixture of Logistics output) trained to 800k steps
* Tacotron trained to 180k steps

____
@@ -100,7 +100,3 @@ ____
* [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron)
* [https://github.com/r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder)
* Special thanks to github users [G-Wang](https://github.com/G-Wang), [geneing](https://github.com/geneing) & [erogol](https://github.com/erogol)

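As the gen_tacotron.py diff below shows, the vocoder is now chosen by a required subcommand rather than implied. A few illustrative invocations (examples only, assuming trained checkpoints and the default hparams.py):

> python gen_tacotron.py wavernn
> python gen_tacotron.py --input_text "hello world" wavernn --unbatched
> python gen_tacotron.py griffinlim --iters 64
> python gen_tacotron.py --hp_file hparams.py gl

Note that options defined on the main parser (--input_text, --hp_file) go before the subcommand, while per-vocoder options (--unbatched, --iters) go after it.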
152 changes: 102 additions & 50 deletions gen_tacotron.py
@@ -1,39 +1,68 @@
import torch
from models.fatchord_version import WaveRNN
-import hparams as hp
+from utils import hparams as hp
from utils.text.symbols import symbols
from utils.paths import Paths
from models.tacotron import Tacotron
import argparse
from utils.text import text_to_sequence
from utils.display import save_attention, simple_table
+from utils.dsp import reconstruct_waveform, save_wav
+import numpy as np

if __name__ == "__main__":

    # Parse Arguments
    parser = argparse.ArgumentParser(description='TTS Generator')
    parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
-    parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
-    parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
-    parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
-    parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
-    parser.add_argument('--weights_path', '-w', type=str, help='[string/path] Load in different Tacotron Weights')
+    parser.add_argument('--tts_weights', type=str, help='[string/path] Load in different Tacotron weights')
    parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
-    parser.set_defaults(batched=hp.voc_gen_batched)
-    parser.set_defaults(target=hp.voc_target)
-    parser.set_defaults(overlap=hp.voc_overlap)
+    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')

-    parser.set_defaults(input_text=None)
-    parser.set_defaults(weights_path=None)
-    parser.set_defaults(save_attention=False)

+    # name of subcommand goes to args.vocoder
+    subparsers = parser.add_subparsers(required=True, dest='vocoder')

+    wr_parser = subparsers.add_parser('wavernn', aliases=['wr'])
+    wr_parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
+    wr_parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
+    wr_parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
+    wr_parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
+    wr_parser.add_argument('--voc_weights', type=str, help='[string/path] Load in different WaveRNN weights')
+    wr_parser.set_defaults(batched=None)

+    gl_parser = subparsers.add_parser('griffinlim', aliases=['gl'])
+    gl_parser.add_argument('--iters', type=int, default=32, help='[int] number of griffinlim iterations')

    args = parser.parse_args()

-    batched = args.batched
-    target = args.target
-    overlap = args.overlap
+    if args.vocoder in ['griffinlim', 'gl']:
+        args.vocoder = 'griffinlim'
+    elif args.vocoder in ['wavernn', 'wr']:
+        args.vocoder = 'wavernn'
+    else:
+        raise argparse.ArgumentError('Must provide a valid vocoder type!')

+    hp.configure(args.hp_file)  # Load hparams from file
+    # set defaults for any arguments that depend on hparams
+    if args.vocoder == 'wavernn':
+        if args.target is None:
+            args.target = hp.voc_target
+        if args.overlap is None:
+            args.overlap = hp.voc_overlap
+        if args.batched is None:
+            args.batched = hp.voc_gen_batched

+        batched = args.batched
+        target = args.target
+        overlap = args.overlap

    input_text = args.input_text
-    weights_path = args.weights_path
-    save_attn = args.save_attention
+    tts_weights = args.tts_weights
+    save_attn = args.save_attn

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

@@ -43,23 +72,24 @@
        device = torch.device('cpu')
    print('Using device:', device)

-    print('\nInitialising WaveRNN Model...\n')
-
-    # Instantiate WaveRNN Model
-    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
-                        fc_dims=hp.voc_fc_dims,
-                        bits=hp.bits,
-                        pad=hp.voc_pad,
-                        upsample_factors=hp.voc_upsample_factors,
-                        feat_dims=hp.num_mels,
-                        compute_dims=hp.voc_compute_dims,
-                        res_out_dims=hp.voc_res_out_dims,
-                        res_blocks=hp.voc_res_blocks,
-                        hop_length=hp.hop_length,
-                        sample_rate=hp.sample_rate,
-                        mode=hp.voc_mode).to(device)
-
-    voc_model.restore(paths.voc_latest_weights)
+    if args.vocoder == 'wavernn':
+        print('\nInitialising WaveRNN Model...\n')
+        # Instantiate WaveRNN Model
+        voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
+                            fc_dims=hp.voc_fc_dims,
+                            bits=hp.bits,
+                            pad=hp.voc_pad,
+                            upsample_factors=hp.voc_upsample_factors,
+                            feat_dims=hp.num_mels,
+                            compute_dims=hp.voc_compute_dims,
+                            res_out_dims=hp.voc_res_out_dims,
+                            res_blocks=hp.voc_res_blocks,
+                            hop_length=hp.hop_length,
+                            sample_rate=hp.sample_rate,
+                            mode=hp.voc_mode).to(device)

+        voc_load_path = args.voc_weights if args.voc_weights else paths.voc_latest_weights
+        voc_model.load(voc_load_path)

    print('\nInitialising Tacotron Model...\n')

@@ -75,42 +105,64 @@
                         lstm_dims=hp.tts_lstm_dims,
                         postnet_K=hp.tts_postnet_K,
                         num_highways=hp.tts_num_highways,
-                         dropout=hp.tts_dropout).to(device)
+                         dropout=hp.tts_dropout,
+                         stop_threshold=hp.tts_stop_threshold).to(device)

-    tts_restore_path = weights_path if weights_path else paths.tts_latest_weights
-    tts_model.restore(tts_restore_path)
+    tts_load_path = tts_weights if tts_weights else paths.tts_latest_weights
+    tts_model.load(tts_load_path)

    if input_text:
        inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
    else:
        with open('sentences.txt') as f:
            inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f]

-    voc_k = voc_model.get_step() // 1000
-    tts_k = tts_model.get_step() // 1000
-
-    simple_table([('WaveRNN', str(voc_k) + 'k'),
-                  ('Tacotron', str(tts_k) + 'k'),
-                  ('r', tts_model.r.item()),
-                  ('Generation Mode', 'Batched' if batched else 'Unbatched'),
-                  ('Target Samples', target if batched else 'N/A'),
-                  ('Overlap Samples', overlap if batched else 'N/A')])
+    if args.vocoder == 'wavernn':
+        voc_k = voc_model.get_step() // 1000
+        tts_k = tts_model.get_step() // 1000
+
+        simple_table([('Tacotron', str(tts_k) + 'k'),
+                      ('r', tts_model.r),
+                      ('Vocoder Type', 'WaveRNN'),
+                      ('WaveRNN', str(voc_k) + 'k'),
+                      ('Generation Mode', 'Batched' if batched else 'Unbatched'),
+                      ('Target Samples', target if batched else 'N/A'),
+                      ('Overlap Samples', overlap if batched else 'N/A')])
+
+    elif args.vocoder == 'griffinlim':
+        tts_k = tts_model.get_step() // 1000
+        simple_table([('Tacotron', str(tts_k) + 'k'),
+                      ('r', tts_model.r),
+                      ('Vocoder Type', 'Griffin-Lim'),
+                      ('GL Iters', args.iters)])

    for i, x in enumerate(inputs, 1):

        print(f'\n| Generating {i}/{len(inputs)}')
        _, m, attention = tts_model.generate(x)
+        # Fix mel spectrogram scaling to be from 0 to 1
+        m = (m + 4) / 8
+        np.clip(m, 0, 1, out=m)

+        if args.vocoder == 'griffinlim':
+            v_type = args.vocoder
+        elif args.vocoder == 'wavernn' and args.batched:
+            v_type = 'wavernn_batched'
+        else:
+            v_type = 'wavernn_unbatched'

        if input_text:
-            save_path = f'{paths.tts_output}__input_{input_text[:10]}_{tts_k}k.wav'
+            save_path = paths.tts_output/f'__input_{input_text[:10]}_{v_type}_{tts_k}k.wav'
        else:
-            save_path = f'{paths.tts_output}{i}_batched{str(batched)}_{tts_k}k.wav'
+            save_path = paths.tts_output/f'{i}_{v_type}_{tts_k}k.wav'

        if save_attn: save_attention(attention, save_path)

-        m = torch.tensor(m).unsqueeze(0)
-        m = (m + 4) / 8
-
-        voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
+        if args.vocoder == 'wavernn':
+            m = torch.tensor(m).unsqueeze(0)
+            voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
+        elif args.vocoder == 'griffinlim':
+            wav = reconstruct_waveform(m, n_iter=args.iters)
+            save_wav(wav, save_path)

    print('\n\nDone.\n')
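A note on the restore -> load rename above: it corresponds to the "safely restore checkpoints" item in the PR description. The model-side implementation is not part of this excerpt; a hypothetical sketch of the idea (all names here are assumptions, not the repo's API):

# Editor's sketch (hypothetical, not the repo's code): a "safe" load that
# tolerates a missing checkpoint and pins tensors to a known device.
from pathlib import Path
import torch

def safe_load(model: torch.nn.Module, path: Path, device='cpu'):
    if not path.exists():
        # First run: checkpoint the fresh weights so that resumed and
        # fresh runs go through the same code path.
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), path)
    model.load_state_dict(torch.load(path, map_location=device))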