Jeremy Bernstein   ·   Arash Vahdat   ·   Yisong Yue   ·   Ming‑Yu Liu
To get started with Fromage in your PyTorch code, copy the file fromage.py
into your project directory, then write:
from fromage import Fromage
optimizer = Fromage(net.parameters(), lr=0.01, p_bound=None)
An initial learning rate of 0.01 has worked well in all our experiments except model fine-tuning, where 0.001 worked well instead. We also found it helpful to decay the learning rate when the loss plateaus.
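As a rough illustration of decaying the learning rate when the loss plateaus, here is a pure-Python sketch of the logic. `PlateauDecay` is a hypothetical helper (it is not part of fromage.py), and its `factor` and `patience` defaults are illustrative assumptions, not settings taken from the paper's experiments. It is a simplified version of the rule used by PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`: after more than `patience` epochs without improvement, multiply the learning rate by `factor`.

```python
# Pure-Python sketch of "decay the learning rate when the loss plateaus".
# PlateauDecay is a hypothetical helper, not part of fromage.py.


class PlateauDecay:
    """Multiply the learning rate by `factor` after more than `patience` epochs without improvement."""

    def __init__(self, lr=0.01, factor=0.1, patience=3, min_delta=0.0):
        self.lr = lr                # start from the recommended initial value
        self.factor = factor        # decay multiplier applied on a plateau
        self.patience = patience    # tolerated epochs without improvement
        self.min_delta = min_delta  # absolute improvement needed to count as progress
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        """Record one epoch's loss and return the (possibly decayed) learning rate."""
        if loss < self.best - self.min_delta:
            # The loss improved: remember it and reset the plateau counter.
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                # Plateau detected: decay the learning rate and start counting again.
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr
```

In an actual PyTorch training script, the built-in `torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=5)` should be able to wrap the Fromage optimizer directly, since it adjusts the `lr` entry of each parameter group; call `scheduler.step(val_loss)` once per epoch.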
On some benchmarks, Fromage heavily overfit the training set. We were able to control this behaviour by setting the p_bound regularisation argument, which constrains the norm of each layer's weights to lie within a factor of p_bound of its initial value.
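For intuition, here is a minimal pure-Python sketch of how a bound like p_bound could be enforced on one layer. It assumes (an interpretation of the description above, not a copy of fromage.py) that whenever an update pushes the weight norm above p_bound times the norm recorded at initialisation, the weights are rescaled back to that radius, so only the upper side of the bound is enforced here. Weights are flattened into a Python list, and `norm` and `apply_p_bound` are hypothetical helper names.

```python
import math


def norm(w):
    """Euclidean (Frobenius) norm of a layer's weights, stored as a flat list of floats."""
    return math.sqrt(sum(x * x for x in w))


def apply_p_bound(w, init_norm, p_bound):
    """Return w rescaled so its norm is at most p_bound * init_norm.

    Assumed behaviour for illustration: weights inside the bound are left
    unchanged, and p_bound=None disables the constraint entirely.
    """
    if p_bound is None:
        return list(w)
    max_norm = p_bound * init_norm
    current = norm(w)
    if current > max_norm:
        # Shrink the whole layer radially: its direction is unchanged.
        return [x * max_norm / current for x in w]
    return list(w)
```

In a training loop, a check like this would run right after every optimiser step, with each layer's initial norm recorded once before training starts.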
We've written an academic paper that proposes an optimisation algorithm based on a new geometric characterisation of deep neural networks. The paper is called:
On the distance between two neural networks and the stability of learning.
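For readers who want to see the shape of the algorithm before opening fromage.py, here is a minimal pure-Python sketch of a single per-layer step, assuming the update rule described in the paper: W <- (W - lr * (||W|| / ||g||) * g) / sqrt(1 + lr^2), where g is the layer's gradient and the norms are Frobenius norms. Matrices are plain lists of rows, `fromage_step` and `frobenius_norm` are hypothetical names, and the fallback for zero norms is an assumption; fromage.py, which works on PyTorch tensors, is the authoritative version.

```python
import math


def frobenius_norm(m):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(x * x for row in m for x in row))


def fromage_step(w, g, lr=0.01):
    """One Fromage-style update of a single layer (sketch).

    The gradient is rescaled so its norm matches the weight norm, a step of
    relative size lr is taken, and the result is divided by sqrt(1 + lr**2).
    """
    w_norm = frobenius_norm(w)
    g_norm = frobenius_norm(g)
    if w_norm > 0.0 and g_norm > 0.0:
        # Step size proportional to the layer's own norm.
        scale = lr * w_norm / g_norm
    else:
        # Assumed fallback (e.g. zero-initialised biases): plain gradient step.
        scale = lr
    shrink = 1.0 / math.sqrt(1.0 + lr * lr)
    return [[(wi - scale * gi) * shrink for wi, gi in zip(w_row, g_row)]
            for w_row, g_row in zip(w, g)]
```

With a learning rate of 0.01, each layer therefore moves by roughly 1% of its own norm per step, regardless of the raw gradient scale. When the gradient is orthogonal to the weights, the division by sqrt(1 + lr^2) leaves the weight norm exactly unchanged.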
You can also check out a blog post with some interactive demos of the main idea:
We're putting this code here so that you can test out our optimisation algorithm in your own applications, and also so that you can attempt to reproduce the experiments in our paper.
If something isn't clear or isn't working, let us know in the Issues section or contact bernstein@caltech.edu.
Here is the structure of this repository.
.
├── classify-cifar/          # CIFAR-10 classification experiments.
├── classify-imagenet/       # ImageNet classification experiments.
├── classify-mnist/          # MNIST classification experiments.
├── transformer-wikitext2/   # Transformer training experiments.
├── generate-cifar/          # CIFAR-10 class-conditional GAN experiments.
├── make-plots/              # Code to reproduce the figures in the paper.
├── LICENSE                  # The license on our algorithm.
├── README.md                # The very page you're reading now.
└── fromage.py               # PyTorch code for the Fromage optimiser.
- This research was supported by Caltech and NVIDIA.
- Our GAN implementation is based on a codebase by Jiahui Yu.
- Our Transformer code is from the PyTorch example.
- Our CIFAR-10 classification code was originally written by kuangliu.
- Our MNIST code was originally forked from the PyTorch example.
- See here and here for closely related work by Yang You and coauthors.
If you adore le fromage as much as we do, feel free to cite the paper:
@inproceedings{fromage,
  title={On the distance between two neural networks and the stability of learning},
  author={Jeremy Bernstein and Arash Vahdat and Yisong Yue and Ming-Yu Liu},
  booktitle={Neural Information Processing Systems},
  year={2020}
}
We are making our algorithm available under a CC BY-NC-SA 4.0 license. Other code used in this repository is subject to its own license restrictions, as indicated in the relevant subfolders.