Skip to content

Commit

Permalink
Fix quick_start.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fatchord committed Sep 5, 2019
1 parent 72d28b0 commit fb8fcb4
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,21 @@
parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
parser.add_argument('--hp_file', metavar='FILE', default='hparams.py',
help='The file to use for the hyperparameters')
args = parser.parse_args()

hp.configure(args.hp_file) # Load hparams from file

parser.set_defaults(batched=hp.voc_gen_batched)
parser.set_defaults(target=hp.voc_target)
parser.set_defaults(overlap=hp.voc_overlap)
parser.set_defaults(input_text=None)
parser.set_defaults(weights_path=None)
args = parser.parse_args()

batched = args.batched
target = args.target
overlap = args.overlap
input_text = args.input_text
weights_path = args.weights_path

if not args.force_cpu and torch.cuda.is_available():
device = torch.device('cuda')
Expand All @@ -66,7 +69,7 @@
sample_rate=hp.sample_rate,
mode='MOL').to(device)

voc_model.restore('quick_start/voc_weights/latest_weights.pyt')
voc_model.load('quick_start/voc_weights/latest_weights.pyt')

print('\nInitialising Tacotron Model...\n')

Expand All @@ -86,7 +89,7 @@
stop_threshold=hp.tts_stop_threshold).to(device)


tts_model.restore('quick_start/tts_weights/latest_weights.pyt')
tts_model.load('quick_start/tts_weights/latest_weights.pyt')

if input_text:
inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
Expand All @@ -97,7 +100,8 @@
voc_k = voc_model.get_step() // 1000
tts_k = tts_model.get_step() // 1000

r = tts_model.get_r()
# TODO: get rid of this hardcoding
r = 2

simple_table([('WaveRNN', str(voc_k) + 'k'),
(f'Tacotron(r={r})', str(tts_k) + 'k'),
Expand Down

0 comments on commit fb8fcb4

Please sign in to comment.