Skip to content

Commit

Permalink
Merge pull request karpathy#37 from Kaixhin/master
Browse files Browse the repository at this point in the history
Add Adam trainer
  • Loading branch information
karpathy committed Jun 21, 2015
2 parents 9b8a20a + e14012e commit 7b71356
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
3 changes: 2 additions & 1 deletion demo/js/trainers.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ nets = [];\n\
trainer_defs = [];\n\
trainer_defs.push({learning_rate:LR, method: 'sgd', momentum: 0.0, batch_size:BS, l2_decay:L2});\n\
trainer_defs.push({learning_rate:LR, method: 'sgd', momentum: 0.9, batch_size:BS, l2_decay:L2});\n\
trainer_defs.push({learning_rate:LR, method: 'adam', eps: 1e-8, beta1: 0.9, beta2: 0.99, lambda: 1-1e-8, batch_size:BS, l2_decay:L2});\n\
trainer_defs.push({learning_rate:LR, method: 'adagrad', eps: 1e-6, batch_size:BS, l2_decay:L2});\n\
trainer_defs.push({learning_rate:LR, method: 'windowgrad', eps: 1e-6, ro: 0.95, batch_size:BS, l2_decay:L2});\n\
trainer_defs.push({learning_rate:1.0, method: 'adadelta', eps: 1e-6, ro:0.95, batch_size:BS, l2_decay:L2});\n\
trainer_defs.push({learning_rate:LR, method: 'nesterov', momentum: 0.9, batch_size:BS, l2_decay:L2});\n\
\n\
// names for all trainers above\n\
legend = ['sgd', 'sgd+momentum', 'adagrad', 'windowgrad', 'adadelta', 'nesterov'];\n\
legend = ['sgd', 'sgd+momentum', 'adam', 'adagrad', 'windowgrad', 'adadelta', 'nesterov'];\n\
"

// ------------------------
Expand Down
2 changes: 1 addition & 1 deletion demo/trainers.html
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<h2 style="text-align: center;"><a href="http://cs.stanford.edu/people/karpathy/convnetjs/">ConvNetJS</a> Trainer demo on MNIST</h2>
<h1>Description</h1>
<p>
This demo lets you evaluate multiple trainers against each other on MNIST. By default I've set up a little benchmark that puts SGD/SGD with momentum/Adagrad/Adadelta/Nesterov against each other. For reference math and explanations on these refer to Matthew Zeiler's <a href="http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf">Adadelta paper</a> (Windowgrad is Idea #1 in the paper). In my own experience, Adagrad/Adadelta are "safer" because they don't depend so strongly on setting of learning rates (with Adadelta being slightly better), but well-tuned SGD+Momentum almost always converges faster and at better final values.
This demo lets you evaluate multiple trainers against each other on MNIST. By default I've set up a little benchmark that puts SGD/SGD with momentum/Adam/Adagrad/Adadelta/Nesterov against each other. For reference math and explanations on these refer to Matthew Zeiler's <a href="http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf">Adadelta paper</a> (Windowgrad is Idea #1 in the paper). In my own experience, Adagrad/Adadelta are "safer" because they don't depend so strongly on setting of learning rates (with Adadelta being slightly better), but well-tuned SGD+Momentum almost always converges faster and at better final values.
</p>
<p>Report questions/bugs/suggestions to <a href="https://twitter.com/karpathy">@karpathy</a>.</p>

Expand Down
27 changes: 20 additions & 7 deletions src/convnet_trainers.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@
this.l1_decay = typeof options.l1_decay !== 'undefined' ? options.l1_decay : 0.0;
this.l2_decay = typeof options.l2_decay !== 'undefined' ? options.l2_decay : 0.0;
this.batch_size = typeof options.batch_size !== 'undefined' ? options.batch_size : 1;
this.method = typeof options.method !== 'undefined' ? options.method : 'sgd'; // sgd/adagrad/adadelta/windowgrad/netsterov
this.method = typeof options.method !== 'undefined' ? options.method : 'sgd'; // sgd/adam/adagrad/adadelta/windowgrad/netsterov

this.momentum = typeof options.momentum !== 'undefined' ? options.momentum : 0.9;
this.ro = typeof options.ro !== 'undefined' ? options.ro : 0.95; // used in adadelta
this.eps = typeof options.eps !== 'undefined' ? options.eps : 1e-6; // used in adadelta
this.eps = typeof options.eps !== 'undefined' ? options.eps : 1e-8; // used in adam or adadelta
this.beta1 = typeof options.beta1 !== 'undefined' ? options.beta1 : 0.9; // used in adam
this.beta2 = typeof options.beta2 !== 'undefined' ? options.beta2 : 0.999; // used in adam
this.lambda = typeof options.lambda !== 'undefined' ? options.lambda : 1-1e-8; // used in adam

this.k = 0; // iteration counter
this.gsum = []; // last iteration gradients (used for momentum calculations)
this.xsum = []; // used in adadelta
this.xsum = []; // used in adam or adadelta
}

Trainer.prototype = {
Expand Down Expand Up @@ -47,10 +50,10 @@
// only vanilla sgd doesnt need either lists
// momentum needs gsum
// adagrad needs gsum
// adadelta needs gsum and xsum
// adam and adadelta needs gsum and xsum
for(var i=0;i<pglist.length;i++) {
this.gsum.push(global.zeros(pglist[i].params.length));
if(this.method === 'adadelta') {
if(this.method === 'adam' || this.method === 'adadelta') {
this.xsum.push(global.zeros(pglist[i].params.length));
} else {
this.xsum.push([]); // conserve memory
Expand Down Expand Up @@ -81,7 +84,18 @@

var gsumi = this.gsum[i];
var xsumi = this.xsum[i];
if(this.method === 'adagrad') {
if(this.method === 'adam') {
// adam update
var bt1 = this.beta1 * Math.pow(this.lambda, this.k-1); // decay first moment running average coefficient
gsumi[j] = gsumi[j] * bt1 + (1-bt1) * gij; // update biased first moment estimate
xsumi[j] = xsumi[j] * this.beta2 + (1-this.beta2) * gij * gij; // update biased second moment estimate
var denom = Math.sqrt(xsumi[j]) + this.eps;
var biasCorr1 = 1 - Math.pow(this.beta1, this.k); // correct bias
var biasCorr2 = 1 - Math.pow(this.beta2, this.k); // correct bias
var stepSize = this.learning_rate * Math.sqrt(biasCorr2) / biasCorr1;
var dx = stepSize * gsumi[j] / denom;
p[j] += dx;
} else if(this.method === 'adagrad') {
// adagrad update
gsumi[j] = gsumi[j] + gij * gij;
var dx = - this.learning_rate / Math.sqrt(gsumi[j] + this.eps) * gij;
Expand All @@ -94,7 +108,6 @@
var dx = - this.learning_rate / Math.sqrt(gsumi[j] + this.eps) * gij; // eps added for better conditioning
p[j] += dx;
} else if(this.method === 'adadelta') {
// assume adadelta if not sgd or adagrad
gsumi[j] = this.ro * gsumi[j] + (1-this.ro) * gij * gij;
var dx = - Math.sqrt((xsumi[j] + this.eps)/(gsumi[j] + this.eps)) * gij;
xsumi[j] = this.ro * xsumi[j] + (1-this.ro) * dx * dx; // yes, xsum lags behind gsum by 1.
Expand Down

0 comments on commit 7b71356

Please sign in to comment.