Skip to content

Commit

Permalink
[egs] Fix model saving path in DeepClustering recipes(asteroid-team#398)
Browse files Browse the repository at this point in the history
* Fixes final checkpoint save for wsj0-mix DeepClustering 

Co-authored-by: Manuel Pariente <pariente.mnl@gmail.com>
ilyakava and mpariente authored Feb 6, 2021
1 parent 47e5237 commit 6ca575b
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion egs/kinect-wsj/DeepClustering/train.py
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ def main(conf):
with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
json.dump(checkpoint.best_k_models, f, indent=0)
# Save last model for convenience
torch.save(system.model.state_dict(), os.path.join(exp_dir, "checkpoints/final.pth"))
torch.save(system.model.state_dict(), os.path.join(exp_dir, "final_model.pth"))


# TODO:Should ideally be inherited from wsj0-mix
2 changes: 1 addition & 1 deletion egs/wsj0-mix/DeepClustering/train.py
Original file line number Diff line number Diff line change
@@ -80,7 +80,7 @@ def main(conf):
with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
json.dump(best_k, f, indent=0)
# Save last model for convenience
torch.save(system.model.state_dict(), os.path.join(exp_dir, "checkpoints/final.pth"))
torch.save(system.model.state_dict(), os.path.join(exp_dir, "final_model.pth"))


class ChimeraSystem(System):

0 comments on commit 6ca575b

Please sign in to comment.