Skip to content

Commit

Permalink
Made inference scripts use Pathlib, fixes minor bug
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah committed Aug 7, 2019
1 parent e185a96 commit d30754f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions gen_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@
_, m, attention = tts_model.generate(x)

if input_text:
save_path = f'{paths.tts_output}__input_{input_text[:10]}_{tts_k}k.wav'
save_path = paths.tts_output/f'__input_{input_text[:10]}_{tts_k}k.wav'
else:
save_path = f'{paths.tts_output}{i}_batched{str(batched)}_{tts_k}k.wav'
save_path = paths.tts_output/f'{i}_batched{str(batched)}_{tts_k}k.wav'

if save_attn: save_attention(attention, save_path)

Expand Down
13 changes: 7 additions & 6 deletions gen_wavernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from utils.display import simple_table
import torch
import argparse
from pathlib import Path


def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path):
def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path: Path):

k = model.get_step() // 1000

Expand All @@ -26,27 +27,27 @@ def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, sav
else:
x = label_2_float(x, bits)

save_wav(x, f'{save_path}{k}k_steps_{i}_target.wav')
save_wav(x, save_path/f'{k}k_steps_{i}_target.wav')

batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
save_str = f'{save_path}{k}k_steps_{i}_{batch_str}.wav'
save_str = str(save_path/f'{k}k_steps_{i}_{batch_str}.wav')

_ = model.generate(m, save_str, batched, target, overlap, hp.mu_law)


def gen_from_file(model: WaveRNN, load_path, save_path, batched, target, overlap):
def gen_from_file(model: WaveRNN, load_path, save_path: Path, batched, target, overlap):

k = model.get_step() // 1000
file_name = load_path.split('/')[-1]

wav = load_wav(load_path)
save_wav(wav, f'{save_path}__{file_name}__{k}k_steps_target.wav')
save_wav(wav, save_path/f'__{file_name}__{k}k_steps_target.wav')

mel = melspectrogram(wav)
mel = torch.tensor(mel).unsqueeze(0)

batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
save_str = f'{save_path}__{file_name}__{k}k_steps_{batch_str}.wav'
save_str = save_path/f'__{file_name}__{k}k_steps_{batch_str}.wav'

_ = model.generate(mel, save_str, batched, target, overlap, hp.mu_law)

Expand Down

0 comments on commit d30754f

Please sign in to comment.