Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yzGuu830 committed Sep 30, 2024
1 parent 5a5cd12 commit 12a78c0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ For more details, see the `example.ipynb` notebook.
We provide developmental training and evaluation datasets available on [Hugging Face](https://huggingface.co/datasets/Tracygu/dnscustom/tree/main). For custom training, set the `train_data_path` in `exp.yaml` to the parent directory containing `.wav` audio segments. Run the following to start training:

```ruby
WANDB_API_KEY='your_API_key'
WANDB_API_KEY=your_API_key
accelerate launch main.py --exp_name esc9kbps --config_path ./configs/9kbps_esc_base.yaml --wandb_project efficient-speech-codec --lr 1.0e-4 --num_epochs 80 --num_pretraining_epochs 15 --num_devices 4 --dropout_rate 0.75 --save_path /path/to/output --seed 53
```

Expand Down
9 changes: 6 additions & 3 deletions scripts/trainer_adv.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def train_step(self, x):
# Backward Pass (Generator)
self.opt_g.zero_grad()
self.accel.backward(outputs["loss"].mean())
self.accel.clip_grad_norm_(self.model.parameters(), 0.5)
self.accel.clip_grad_norm_(self.model.parameters(), 1e3)
self.opt_g.step()
self.scheduler.step()

Expand All @@ -101,7 +101,7 @@ def train_step(self, x):
# Backward Pass (Discriminator)
self.opt_d.zero_grad()
self.accel.backward(outputs["disc_loss"].mean())
self.accel.clip_grad_norm_(self.model_disc.parameters(), .05)
self.accel.clip_grad_norm_(self.model_disc.parameters(), 10.0)
self.opt_d.step()
else:
outputs["disc_loss"] = torch.zeros(x.size(0), device=x.device)
Expand Down Expand Up @@ -130,7 +130,10 @@ def train(self, ):
self.model, self.model_disc, self.opt_g, self.opt_d, self.scheduler = self.accel.prepare(g, d, opt_g, opt_d, scheduler)
self.loss_funcs["adv_loss"] = GANLoss(self.model_disc).to(self.accel.device)

if self.pretrain_ckp is not None: self.evaluate() # pre-eval epoch
        if self.args.pretrain_ckp is not None and self.accel.is_main_process:
self.evaluate() # pre-eval epoch
self.accel.wait_for_everyone()

while True:
for _, x in enumerate(self.train_dl):

Expand Down

0 comments on commit 12a78c0

Please sign in to comment.