clean up, support mixed precision
NTT123 committed Aug 9, 2023
1 parent 2518300 commit ca0ff9c
Showing 5 changed files with 137 additions and 182 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
LJSpeech-1.1
data
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
1 change: 0 additions & 1 deletion models.py
@@ -479,7 +479,6 @@ def forward(self, x, x_lengths, attn, y, y_lengths, sid=None):
# logw_ = torch.log(w + 1e-6) * x_mask
# logw = self.dp(x, x_mask, g=g)
# l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
print("m_p", m_p.shape)

# expand prior
m_p = torch.matmul(attn, m_p.transpose(1, 2)).transpose(1, 2)
40 changes: 40 additions & 0 deletions tfloader.py
@@ -0,0 +1,40 @@
import torch # isort:skip
import tensorflow as tf


def load_tfdata(root, split):
files = tf.data.Dataset.list_files(f"{root}/{split}/part_*.tfrecords")
files = files.repeat().shuffle(len(files))

feature_description = {
"phone_idx": tf.io.FixedLenFeature([], tf.string),
"phone_duration": tf.io.FixedLenFeature([], tf.string),
"wav": tf.io.FixedLenFeature([], tf.string),
"spec": tf.io.FixedLenFeature([], tf.string),
}

def parse_tfrecord(r):
r = tf.io.parse_example(r, feature_description)
wav = tf.reshape(tf.io.parse_tensor(r["wav"], out_type=tf.float16), [-1])
spec = tf.io.parse_tensor(r["spec"], out_type=tf.float16)
spec = tf.reshape(spec, [-1, tf.shape(spec)[-1]])
phone_idx = tf.reshape(
tf.io.parse_tensor(r["phone_idx"], out_type=tf.int32), [-1]
)
phone_duration = tf.reshape(
tf.io.parse_tensor(r["phone_duration"], out_type=tf.float32), [-1]
)
return {
"phone_idx": phone_idx,
"phone_duration": phone_duration,
"phone_length": tf.shape(phone_duration)[0],
"wav": wav,
"wav_length": tf.shape(wav)[0],
"spec": spec,
"spec_length": tf.shape(spec)[0],
}

ds = tf.data.TFRecordDataset(files, num_parallel_reads=4).map(
parse_tfrecord, num_parallel_calls=4
)
return ds
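
The reader above assumes each record stores its tensors as serialized bytes strings under the keys phone_idx, phone_duration, wav and spec, with the waveform and spectrogram kept in float16. A hypothetical writer matching that schema (shown only to document the format, not part of this commit) could look like this:

import tensorflow as tf

def write_example(writer, phone_idx, phone_duration, wav, spec):
    # Serialize every tensor to a bytes string, mirroring the tf.string
    # FixedLenFeature fields that parse_tfrecord reads back.
    def bytes_feature(t):
        return tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(t).numpy()])
        )

    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "phone_idx": bytes_feature(tf.cast(phone_idx, tf.int32)),
                "phone_duration": bytes_feature(tf.cast(phone_duration, tf.float32)),
                "wav": bytes_feature(tf.cast(wav, tf.float16)),
                "spec": bytes_feature(tf.cast(spec, tf.float16)),
            }
        )
    )
    # e.g. writer = tf.io.TFRecordWriter(f"{root}/train/part_000.tfrecords")
    writer.write(example.SerializeToString())
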
189 changes: 96 additions & 93 deletions train.py
@@ -1,93 +1,77 @@
import tensorflow as tf
import torch

tf.config.set_visible_devices([], "GPU")


import torch # isort:skip
import json
from argparse import ArgumentParser
from contextlib import nullcontext
from types import SimpleNamespace

import tensorflow as tf
import torch
from torch.nn import functional as F
from tqdm.auto import tqdm

import commons
from losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from models import MultiPeriodDiscriminator, SynthesizerTrn
from tfloader import load_tfdata

with open("config.json", "rb") as f:
hps = json.load(f)


def load_tfdata(root, split):
files = tf.data.Dataset.list_files(f"{root}/{split}/part_*.tfrecords")
files = files.repeat().shuffle(len(files))

feature_description = {
"phone_idx": tf.io.FixedLenFeature([], tf.string),
"phone_duration": tf.io.FixedLenFeature([], tf.string),
"phone_mask": tf.io.FixedLenFeature([], tf.string),
"wav": tf.io.FixedLenFeature([], tf.string),
"spec": tf.io.FixedLenFeature([], tf.string),
}

def parse_tfrecord(r):
r = tf.io.parse_example(r, feature_description)
wav = tf.reshape(tf.io.parse_tensor(r["wav"], out_type=tf.float16), [-1])
phone_mask = tf.reshape(
tf.io.parse_tensor(r["phone_mask"], out_type=tf.bool), [-1]
)
spec = tf.io.parse_tensor(r["spec"], out_type=tf.float16)
spec = tf.reshape(spec, [-1, tf.shape(spec)[-1]])
return {
"phone_idx": tf.reshape(
tf.io.parse_tensor(r["phone_idx"], out_type=tf.int32), [-1]
),
"phone_duration": tf.reshape(
tf.io.parse_tensor(r["phone_duration"], out_type=tf.float32), [-1]
),
"phone_mask": phone_mask,
"phone_length": tf.shape(phone_mask)[0],
"wav": wav,
"wav_length": tf.shape(wav)[0],
"spec": spec,
"spec_length": tf.shape(spec)[0],
}

ds = tf.data.TFRecordDataset(files, num_parallel_reads=4).map(
parse_tfrecord, num_parallel_calls=4
)
return ds
tf.config.set_visible_devices([], "GPU")

parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.json")
parser.add_argument("--tfdata", type=str, default="data/tfdata")
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--compile", action="store_true", default=False)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--seed", type=int, default=42)
FLAGS = parser.parse_args()
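
With the new flags in place, a run might be launched along these lines (illustrative invocation; paths are the argparse defaults above):

    python train.py --config config.json --tfdata data/tfdata --batch-size 16 --compile
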

# credit: https://github.com/karpathy/nanoGPT/blob/master/train.py#L72-L112
torch.backends.cudnn.benchmark = True
torch.cuda.manual_seed(FLAGS.seed)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device = FLAGS.device
dtype = (
"bfloat16"
if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
else "float16"
)
compile = FLAGS.compile
device_type = "cuda" if "cuda" in device else "cpu"
ptdtype = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}[dtype]
ctx = (
nullcontext()
if device_type == "cpu"
else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
)
# initialize a GradScaler. If enabled=False scaler is a no-op
print(dtype, ptdtype, ctx)
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
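
This is the standard torch.amp recipe: forward passes run under the autocast context, the loss is scaled before backward, gradients are unscaled before clipping, and the scaler steps the optimizer and updates its loss scale. A minimal self-contained sketch of the step pattern used further down (toy model and data; the real loop uses net_g/net_d, commons.clip_grad_value_, and calls scaler.update() once per iteration after the last optimizer step):

toy_model = torch.nn.Linear(8, 1).to(device)
toy_optim = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)
toy_x = torch.randn(4, 8, device=device)

toy_optim.zero_grad()
with ctx:                                   # autocast (bf16/fp16) on CUDA, no-op on CPU
    toy_loss = toy_model(toy_x).pow(2).mean()
scaler.scale(toy_loss).backward()           # backward on the scaled loss
scaler.unscale_(toy_optim)                  # unscale so clipping sees true gradient magnitudes
torch.nn.utils.clip_grad_norm_(toy_model.parameters(), 1.0)
scaler.step(toy_optim)                      # the step is skipped if inf/nan gradients are found
scaler.update()                             # adjust the loss scale for the next iteration
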


train_ds = load_tfdata("tfdata", "train")
bs = 16
train_ds = load_tfdata(FLAGS.tfdata, "train")
ds = train_ds.bucket_by_sequence_length(
lambda x: tf.shape(x["spec"])[0],
bucket_boundaries=(32, 300, 400, 500, 600, 700, 800, 900, 1000),
bucket_batch_sizes=[bs] * 10,
bucket_batch_sizes=[FLAGS.batch_size] * 10,
pad_to_bucket_boundary=False,
)
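
bucket_by_sequence_length groups examples by spectrogram length: nine boundaries define ten buckets, each bucket gets its own batch size (hence [FLAGS.batch_size] * 10), and every batch is padded to its longest element because pad_to_bucket_boundary is False. A toy illustration of the mechanics (assumed example, unrelated to the repo's data):

toy = tf.data.Dataset.from_generator(
    lambda: (tf.range(n) for n in [5, 7, 40, 42, 300]),
    output_signature=tf.TensorSpec([None], tf.int32),
)
toy_batched = toy.bucket_by_sequence_length(
    element_length_func=lambda x: tf.shape(x)[0],
    bucket_boundaries=(32, 300),     # two boundaries -> three buckets
    bucket_batch_sizes=[2, 2, 2],    # one batch size per bucket
    pad_to_bucket_boundary=False,    # pad each batch to its longest element
)
for b in toy_batched:
    print(b.shape)                   # e.g. (2, 7), (2, 42), (1, 300)
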


from utils import HParams

config_save_path = "config.json"
with open(config_save_path, "r") as f:
data = f.read()
config = json.loads(data)

hparams = HParams(**config)
hparams.model_dir = "./"
hps = hparams
with open(FLAGS.config, "rb") as f:
hps = json.load(f, object_hook=lambda x: SimpleNamespace(**x))
torch.manual_seed(hps.train.seed)
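
Loading the config with object_hook=SimpleNamespace turns every JSON object into an attribute-accessible namespace, which is why nested keys read as hps.train.seed or hps.data.hop_length and why the model kwargs are expanded with **vars(hps.model) below. A tiny illustration (toy config, not the real config.json):

toy_hps = json.loads(
    '{"train": {"seed": 42}, "model": {"use_spectral_norm": false}}',
    object_hook=lambda x: SimpleNamespace(**x),
)
print(toy_hps.train.seed)    # 42
print(vars(toy_hps.model))   # {'use_spectral_norm': False}, ready for ** expansion
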

from tqdm.auto import tqdm

net_g = SynthesizerTrn(
256,
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model,
**vars(hps.model),
)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
optim_g = torch.optim.AdamW(
@@ -103,6 +87,9 @@ def parse_tfrecord(r):
eps=hps.train.eps,
)

if compile:
net_g = torch.compile(net_g)
net_d = torch.compile(net_d)

epoch_str = 1
global_step = 0
@@ -118,16 +105,29 @@ def parse_tfrecord(r):
net_g.train()
net_d.train()


for batch in tqdm(ds.prefetch(1).as_numpy_iterator()):
x = torch.from_numpy(batch["phone_idx"]).long()
x_lengths = torch.from_numpy(batch["phone_mask"]).sum(-1).long()
spec = torch.from_numpy(batch["spec"]).swapaxes(-1, -2).float()
x = torch.from_numpy(batch["phone_idx"]).long().to(device, non_blocking=True)
x_lengths = (
torch.from_numpy(batch["phone_length"]).long().to(device, non_blocking=True)
)
spec = (
torch.from_numpy(batch["spec"])
.swapaxes(-1, -2)
.float()
.to(device, non_blocking=True)
)
spec = torch.log(1e-3 + spec)
spec_lengths = torch.from_numpy(batch["spec_length"]).long()
y = torch.from_numpy(batch["wav"]).float()[:, None, :]
y_lengths = torch.from_numpy(batch["wav_length"]).long()

duration = torch.from_numpy(batch["phone_duration"]).float()
spec_lengths = (
torch.from_numpy(batch["spec_length"]).long().to(device, non_blocking=True)
)
y = torch.from_numpy(batch["wav"]).float()[:, None, :].to(device, non_blocking=True)
y_lengths = (
torch.from_numpy(batch["wav_length"]).long().to(device, non_blocking=True)
)
duration = (
torch.from_numpy(batch["phone_duration"]).float().to(device, non_blocking=True)
)
end_time = torch.cumsum(duration, dim=-1)
start_time = end_time - duration
start_frame = (
@@ -140,15 +140,16 @@ def parse_tfrecord(r):
pos[None, :, None] < end_frame[:, None, :],
)
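
The duration tensor is turned into a hard monotonic alignment: cumulative durations give each phone a [start, end) interval of spectrogram frames, and a frame attends to a phone exactly when its index falls inside that interval (the conversion from times to frame indices sits in the collapsed part of this hunk). A small standalone sketch of the idea, with toy durations already expressed in frames:

toy_duration = torch.tensor([2.0, 3.0, 1.0])     # frames per phone
toy_end = torch.cumsum(toy_duration, dim=-1)     # [2., 5., 6.]
toy_start = toy_end - toy_duration               # [0., 2., 5.]
toy_pos = torch.arange(int(toy_end[-1].item()))  # frame indices 0..5
toy_attn = (toy_pos[:, None] >= toy_start[None, :]) & (toy_pos[:, None] < toy_end[None, :])
print(toy_attn.int())
# tensor([[1, 0, 0],
#         [1, 0, 0],
#         [0, 1, 0],
#         [0, 1, 0],
#         [0, 1, 0],
#         [0, 0, 1]])
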

(
y_hat,
l_length,
attn,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
) = net_g(x, x_lengths, attn.float(), spec, spec_lengths)
with ctx:
(
y_hat,
l_length,
attn,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
) = net_g(x, x_lengths, attn.float(), spec, spec_lengths)

mel = spec_to_mel_torch(
spec,
@@ -176,19 +177,20 @@ def parse_tfrecord(r):
y, ids_slice * hps.data.hop_length, hps.train.segment_size
) # slice

# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())

with ctx:
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
loss_disc_all = loss_disc
optim_d.zero_grad()
loss_disc_all.backward()
scaler.scale(loss_disc_all).backward()
scaler.unscale_(optim_d)
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
optim_d.step()
scaler.step(optim_d)

y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
with ctx:
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)

# loss_dur = torch.sum(l_length.float())
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

Expand All @@ -197,9 +199,10 @@ def parse_tfrecord(r):
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl

optim_g.zero_grad()

loss_gen_all.backward()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
optim_g.step()
scaler.step(optim_g)
scaler.update()

loss_disc_all, loss_gen_all
print(loss_disc_all.item(), loss_gen_all.item())