"""
Functions to generate Theano update dictionaries for training.

The update functions implement different methods to control the learning
rate for use with stochastic gradient descent.

Update functions take a loss expression or a list of gradient expressions and
a list of parameters as input and return an ordered dictionary of updates:

.. autosummary::
    :nosignatures:

    sgd
    momentum
    nesterov_momentum
    adagrad
    rmsprop
    adadelta
    adam
    adamax
    amsgrad

Two functions can be used to further modify the updates to include momentum:

.. autosummary::
    :nosignatures:

    apply_momentum
    apply_nesterov_momentum

Finally, we provide two helper functions to constrain the norm of tensors:

.. autosummary::
    :nosignatures:

    norm_constraint
    total_norm_constraint

:func:`norm_constraint()` can be used to constrain the norm of parameters
(as an alternative to weight decay), or for a form of gradient clipping.
:func:`total_norm_constraint()` constrains the total norm of a list of tensors.
This is often used when training recurrent neural networks.

Examples
--------
Using :func:`nesterov_momentum` to define an update dictionary for a toy
example network:

>>> import lasagne
>>> import theano.tensor as T
>>> import theano
>>> from lasagne.nonlinearities import softmax
>>> from lasagne.layers import InputLayer, DenseLayer, get_output
>>> from lasagne.updates import nesterov_momentum
>>> l_in = InputLayer((100, 20))
>>> l1 = DenseLayer(l_in, num_units=3, nonlinearity=softmax)
>>> x = T.matrix('x')  # shp: num_batch x num_features
>>> y = T.ivector('y') # shp: num_batch
>>> l_out = get_output(l1, x)
>>> params = lasagne.layers.get_all_params(l1)
>>> loss = T.mean(T.nnet.categorical_crossentropy(l_out, y))
>>> updates = nesterov_momentum(loss, params, learning_rate=1e-4, momentum=.9)
>>> train_fn = theano.function([x, y], updates=updates)

With :func:`apply_momentum` and :func:`apply_nesterov_momentum`, we can add
momentum to optimization schemes that do not usually support this:

>>> updates = lasagne.updates.rmsprop(loss, params, learning_rate=0.0001)
>>> updates = lasagne.updates.apply_momentum(updates, params, momentum=0.9)

All optimization schemes support symbolic variables for their hyperparameters,
such as shared variables. This allows varying hyperparameters during training
without recompiling the training function. Note that the dtypes must match the
dtypes of the network parameters, which follow Theano's ``floatX`` setting.
In the following example, we use :func:`lasagne.utils.floatX` to ensure this:

>>> eta = theano.shared(lasagne.utils.floatX(0.001))
>>> updates = lasagne.updates.adam(loss, params, learning_rate=eta)
>>> train_fn = theano.function([x, y], updates=updates)
>>> # we can now modify the learning rate at any time during training:
>>> eta.set_value(lasagne.utils.floatX(eta.get_value() * 0.1))
"""

from collections import OrderedDict

import numpy as np

import theano
import theano.tensor as T
from . import utils

__all__ = [
    "sgd",
    "apply_momentum",
    "momentum",
    "apply_nesterov_momentum",
    "nesterov_momentum",
    "adagrad",
    "rmsprop",
    "adadelta",
    "adam",
    "adamax",
    "amsgrad",
    "norm_constraint",
    "total_norm_constraint"
]


def get_or_compute_grads(loss_or_grads, params):
    """Helper function returning a list of gradients

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to return the gradients for

    Returns
    -------
    list of expressions
        If `loss_or_grads` is a list, it is assumed to be a list of
        gradients and returned as is, unless it does not match the length
        of `params`, in which case a `ValueError` is raised.
        Otherwise, `loss_or_grads` is assumed to be a cost expression and
        the function returns `theano.grad(loss_or_grads, params)`.

    Raises
    ------
    ValueError
        If `loss_or_grads` is a list of a different length than `params`, or if
        any element of `params` is not a shared variable (while we could still
        compute its gradient, we can never update it and want to fail early).
    """
    if any(not isinstance(p, theano.compile.SharedVariable) for p in params):
        raise ValueError("params must contain shared variables only. If it "
                         "contains arbitrary parameter expressions, then "
                         "lasagne.utils.collect_shared_vars() may help you.")
    if isinstance(loss_or_grads, list):
        if not len(loss_or_grads) == len(params):
            raise ValueError("Got %d gradient expressions for %d parameters" %
                             (len(loss_or_grads), len(params)))
        return loss_or_grads
    else:
        return theano.grad(loss_or_grads, params)


def sgd(loss_or_grads, params, learning_rate):
    """Stochastic Gradient Descent (SGD) updates

    Generates update expressions of the form:

    * ``param := param - learning_rate * gradient``
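
    As an illustration only, the plain NumPy sketch below (not part of
    Lasagne; the helper ``sgd_step`` is hypothetical) applies the same rule
    to the quadratic loss ``0.5 * sum(w ** 2)``, whose gradient is ``w``
    itself:

    ```python
    import numpy as np

    def sgd_step(param, grad, learning_rate):
        # Same rule as above: param := param - learning_rate * gradient
        return param - learning_rate * grad

    w = np.array([1.0, -2.0])
    for _ in range(3):
        # For loss = 0.5 * sum(w ** 2), the gradient is w itself
        w = sgd_step(w, w, learning_rate=0.5)
    ```

    With ``learning_rate=0.5`` every step halves ``w``, so after three steps
    it equals the starting value divided by 8.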

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        The learning rate controlling the size of update steps

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression
    """
    grads = get_or_compute_grads(loss_or_grads, params)
    updates = OrderedDict()

    for param, grad in zip(params, grads):
        updates[param] = param - learning_rate * grad

    return updates


def apply_momentum(updates, params=None, momentum=0.9):
    """Returns a modified update dictionary including momentum

    Generates update expressions of the form:

    * ``velocity := momentum * velocity + updates[param] - param``
    * ``param := param + velocity``

    Parameters
    ----------
    updates : OrderedDict
        A dictionary mapping parameters to update expressions
    params : iterable of shared variables, optional
        The variables to apply momentum to. If omitted, will apply
        momentum to all `updates.keys()`.
    momentum : float or symbolic scalar, optional
        The amount of momentum to apply. Higher momentum results in
        smoothing over more update steps. Defaults to 0.9.

    Returns
    -------
    OrderedDict
        A copy of `updates` with momentum updates for all `params`.

    Notes
    -----
    Higher momentum also results in larger update steps. To counter that,
    you can optionally scale your learning rate by `1 - momentum`.
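
    A plain NumPy sketch of the two expressions above, for illustration only
    (not part of Lasagne; ``apply_momentum_step`` is a hypothetical helper).
    It applies them on top of a plain SGD step for the loss ``0.5 * w ** 2``.
    Since ``updates[param] - param`` then equals ``-learning_rate * gradient``,
    the trajectory should match the classical momentum rule used by
    :func:`momentum`:

    ```python
    import numpy as np

    def apply_momentum_step(param, velocity, updated_param, momentum=0.9):
        # velocity := momentum * velocity + updates[param] - param
        new_velocity = momentum * velocity + updated_param - param
        # param := param + velocity (with the new velocity)
        return param + new_velocity, new_velocity

    w, v = np.array([1.0]), np.zeros(1)
    history = []
    for _ in range(3):
        sgd_target = w - 0.1 * w  # SGD update, learning_rate=0.1, gradient=w
        w, v = apply_momentum_step(w, v, sgd_target, momentum=0.9)
        history.append(float(w[0]))

    # Classical momentum for comparison:
    # velocity := momentum * velocity - learning_rate * gradient
    wc, vc = 1.0, 0.0
    for _ in range(3):
        vc = 0.9 * vc - 0.1 * wc
        wc = wc + vc
    ```

    Both loops visit ``0.9``, ``0.72`` and ``0.486``.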

    See Also
    --------
    momentum : Shortcut applying momentum to SGD updates
    """
    if params is None:
        params = updates.keys()
    updates = OrderedDict(updates)

    for param in params:
        value = param.get_value(borrow=True)
        velocity = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                                 broadcastable=param.broadcastable)
        x = momentum * velocity + updates[param]
        updates[velocity] = x - param
        updates[param] = x

    return updates


def momentum(loss_or_grads, params, learning_rate, momentum=0.9):
    """Stochastic Gradient Descent (SGD) updates with momentum

    Generates update expressions of the form:

    * ``velocity := momentum * velocity - learning_rate * gradient``
    * ``param := param + velocity``

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        The learning rate controlling the size of update steps
    momentum : float or symbolic scalar, optional
        The amount of momentum to apply. Higher momentum results in
        smoothing over more update steps. Defaults to 0.9.

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression

    Notes
    -----
    Higher momentum also results in larger update steps. To counter that,
    you can optionally scale your learning rate by `1 - momentum`.

    See Also
    --------
    apply_momentum : Generic function applying momentum to updates
    nesterov_momentum : Nesterov's variant of SGD with momentum
    """
    updates = sgd(loss_or_grads, params, learning_rate)
    return apply_momentum(updates, momentum=momentum)


def apply_nesterov_momentum(updates, params=None, momentum=0.9):
    """Returns a modified update dictionary including Nesterov momentum

    Generates update expressions of the form:

    * ``velocity := momentum * velocity + updates[param] - param``
    * ``param := param + momentum * velocity + updates[param] - param``

    Parameters
    ----------
    updates : OrderedDict
        A dictionary mapping parameters to update expressions
    params : iterable of shared variables, optional
        The variables to apply momentum to. If omitted, will apply
        momentum to all `updates.keys()`.
    momentum : float or symbolic scalar, optional
        The amount of momentum to apply. Higher momentum results in
        smoothing over more update steps. Defaults to 0.9.

    Returns
    -------
    OrderedDict
        A copy of `updates` with momentum updates for all `params`.

    Notes
    -----
    Higher momentum also results in larger update steps. To counter that,
    you can optionally scale your learning rate by `1 - momentum`.

    The classic formulation of Nesterov momentum (or Nesterov accelerated
    gradient) requires the gradient to be evaluated at the predicted next
    position in parameter space. Here, we use the formulation described at
    https://github.com/lisa-lab/pylearn2/pull/136#issuecomment-10381617,
    which allows the gradient to be evaluated at the current parameters.
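
    A plain NumPy sketch of these expressions, for illustration only (not
    part of Lasagne; ``apply_nesterov_step`` is a hypothetical helper). As in
    the implementation below, the parameter update uses the freshly computed
    velocity, so ``param + momentum * velocity + updates[param] - param``
    reduces to ``momentum * velocity + updates[param]``:

    ```python
    import numpy as np

    def apply_nesterov_step(param, velocity, updated_param, momentum=0.9):
        # velocity := momentum * velocity + updates[param] - param
        new_velocity = momentum * velocity + updated_param - param
        # param := momentum * velocity + updates[param]  (new velocity)
        new_param = momentum * new_velocity + updated_param
        return new_param, new_velocity

    w, v = np.array([1.0]), np.zeros(1)
    history = []
    for _ in range(2):
        sgd_target = w - 0.1 * w  # SGD update, learning_rate=0.1, gradient=w
        w, v = apply_nesterov_step(w, v, sgd_target, momentum=0.9)
        history.append(float(w[0]))
    ```

    The two steps give ``0.81`` and ``0.5751``, the same values as the
    :func:`nesterov_momentum` formulas with gradient ``w``.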

    See Also
    --------
    nesterov_momentum : Shortcut applying Nesterov momentum to SGD updates
    """
    if params is None:
        params = updates.keys()
    updates = OrderedDict(updates)

    for param in params:
        value = param.get_value(borrow=True)
        velocity = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                                 broadcastable=param.broadcastable)
        x = momentum * velocity + updates[param] - param
        updates[velocity] = x
        updates[param] = momentum * x + updates[param]

    return updates


def nesterov_momentum(loss_or_grads, params, learning_rate, momentum=0.9):
    """Stochastic Gradient Descent (SGD) updates with Nesterov momentum

    Generates update expressions of the form:

    * ``velocity := momentum * velocity - learning_rate * gradient``
    * ``param := param + momentum * velocity - learning_rate * gradient``

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        The learning rate controlling the size of update steps
    momentum : float or symbolic scalar, optional
        The amount of momentum to apply. Higher momentum results in
        smoothing over more update steps. Defaults to 0.9.

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression

    Notes
    -----
    Higher momentum also results in larger update steps. To counter that,
    you can optionally scale your learning rate by `1 - momentum`.

    The classic formulation of Nesterov momentum (or Nesterov accelerated
    gradient) requires the gradient to be evaluated at the predicted next
    position in parameter space. Here, we use the formulation described at
    https://github.com/lisa-lab/pylearn2/pull/136#issuecomment-10381617,
    which allows the gradient to be evaluated at the current parameters.

    See Also
    --------
    apply_nesterov_momentum : Function applying momentum to updates
    """
    updates = sgd(loss_or_grads, params, learning_rate)
    return apply_nesterov_momentum(updates, momentum=momentum)


def adagrad(loss_or_grads, params, learning_rate=1.0, epsilon=1e-6):
    """Adagrad updates

    Scale learning rates by dividing by the square root of accumulated
    squared gradients. See [1]_ for further description.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        The learning rate controlling the size of update steps
    epsilon : float or symbolic scalar
        Small value added for numerical stability

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression

    Notes
    -----
    Using step size :math:`\\eta`, Adagrad calculates the learning rate for
    feature :math:`i` at time step :math:`t` as:

    .. math:: \\eta_{t,i} = \\frac{\\eta}
       {\\sqrt{\\sum^t_{t^\\prime} g^2_{t^\\prime,i}+\\epsilon}}

    and updates the feature by :math:`-\\eta_{t,i} g_{t,i}`. Since the sum of
    squared gradients can only grow, the learning rate is monotonically
    decreasing.

    Epsilon is not included in the typical formula, see [2]_.
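
    A plain NumPy sketch of this rule, for illustration only (not part of
    Lasagne; ``adagrad_step`` is a hypothetical helper, and ``epsilon=0`` is
    used only to keep the numbers exact). With a constant gradient of 1, the
    step sizes shrink as ``1``, ``1/sqrt(2)``, ``1/sqrt(3)``:

    ```python
    import numpy as np

    def adagrad_step(param, grad, accu, learning_rate=1.0, epsilon=1e-6):
        # Accumulate squared gradients over all steps so far
        accu_new = accu + grad ** 2
        # Divide the step by the root of the accumulated squared gradients
        new_param = param - learning_rate * grad / np.sqrt(accu_new + epsilon)
        return new_param, accu_new

    w, accu = np.array([3.0]), np.zeros(1)
    steps = []
    for _ in range(3):
        grad = np.ones_like(w)  # constant gradient
        new_w, accu = adagrad_step(w, grad, accu, epsilon=0.0)
        steps.append(float(w[0] - new_w[0]))
        w = new_w
    ```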

    References
    ----------
    .. [1] Duchi, J., Hazan, E., & Singer, Y. (2011):
           Adaptive subgradient methods for online learning and stochastic
           optimization. JMLR, 12:2121-2159.

    .. [2] Chris Dyer:
           Notes on AdaGrad. http://www.ark.cs.cmu.edu/cdyer/adagrad.pdf
    """
    grads = get_or_compute_grads(loss_or_grads, params)
    updates = OrderedDict()

    for param, grad in zip(params, grads):
        value = param.get_value(borrow=True)
        accu = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                             broadcastable=param.broadcastable)
        accu_new = accu + grad ** 2
        updates[accu] = accu_new
        updates[param] = param - (learning_rate * grad /
                                  T.sqrt(accu_new + epsilon))

    return updates


def rmsprop(loss_or_grads, params, learning_rate=1.0, rho=0.9, epsilon=1e-6):
    """RMSProp updates

    Scale learning rates by dividing by the square root of a moving average
    of squared gradients (a running root mean square, RMS, of the gradients).
    See [1]_ for further description.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        The learning rate controlling the size of update steps
    rho : float or symbolic scalar
        Gradient moving average decay factor
    epsilon : float or symbolic scalar
        Small value added for numerical stability

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression

    Notes
    -----
    `rho` should be between 0 and 1. A value of `rho` close to 1 will decay the
    moving average slowly and a value close to 0 will decay the moving average
    fast.

    Using the step size :math:`\\eta` and a decay factor :math:`\\rho` the
    learning rate :math:`\\eta_t` is calculated as:

    .. math::
       r_t &= \\rho r_{t-1} + (1-\\rho)*g^2\\\\
       \\eta_t &= \\frac{\\eta}{\\sqrt{r_t + \\epsilon}}
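
    A plain NumPy sketch of one step, for illustration only (not part of
    Lasagne; ``rmsprop_step`` is a hypothetical helper, and ``epsilon=0``
    keeps the numbers exact). Because the moving average starts at zero, the
    first step has size ``learning_rate / sqrt(1 - rho)`` in every
    coordinate, whatever the scale of the gradient:

    ```python
    import numpy as np

    def rmsprop_step(param, grad, accu, learning_rate=1.0, rho=0.9,
                     epsilon=1e-6):
        # r_t: exponential moving average of squared gradients
        accu_new = rho * accu + (1 - rho) * grad ** 2
        new_param = param - learning_rate * grad / np.sqrt(accu_new + epsilon)
        return new_param, accu_new

    w, accu = np.zeros(2), np.zeros(2)
    grad = np.array([2.0, 200.0])  # gradients of very different scale
    w, accu = rmsprop_step(w, grad, accu, learning_rate=0.01, rho=0.9,
                           epsilon=0.0)
    ```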

    References
    ----------
    .. [1] Tieleman, T. and Hinton, G. (2012):
           Neural Networks for Machine Learning, Lecture 6.5 - rmsprop.
           Coursera. http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20)
    """
    grads = get_or_compute_grads(loss_or_grads, params)
    updates = OrderedDict()

    # Using theano constant to prevent upcasting of float32
    one = T.constant(1)

    for param, grad in zip(params, grads):
        value = param.get_value(borrow=True)
        accu = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                             broadcastable=param.broadcastable)
        accu_new = rho * accu + (one - rho) * grad ** 2
        updates[accu] = accu_new
        updates[param] = param - (learning_rate * grad /
                                  T.sqrt(accu_new + epsilon))

    return updates


def adadelta(loss_or_grads, params, learning_rate=1.0, rho=0.95, epsilon=1e-6):
    """Adadelta updates

    Scale learning rates by the ratio of the root of accumulated squared
    updates to the root of accumulated squared gradients, see [1]_ and notes
    for further description.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        The learning rate controlling the size of update steps
    rho : float or symbolic scalar
        Squared gradient moving average decay factor
    epsilon : float or symbolic scalar
        Small value added for numerical stability

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression

    Notes
    -----
    `rho` should be between 0 and 1. A value of `rho` close to 1 will decay the
    moving average slowly and a value close to 0 will decay the moving average
    fast.

    `rho` = 0.95 and `epsilon` = 1e-6 are suggested in the paper and
    reported to work for multiple datasets (MNIST, speech).

    In the paper, no learning rate is considered (so `learning_rate` = 1.0).
    It is probably best to keep it at this value.

    `epsilon` is important for the very first update (so the numerator does
    not become 0).

    Using the step size :math:`\\eta` and a decay factor :math:`\\rho` the
    learning rate is calculated as:

    .. math::
       r_t &= \\rho r_{t-1} + (1-\\rho)*g^2\\\\
       \\eta_t &= \\eta \\frac{\\sqrt{s_{t-1} + \\epsilon}}
                             {\\sqrt{r_t + \\epsilon}}\\\\
       s_t &= \\rho s_{t-1} + (1-\\rho)*(\\eta_t*g)^2
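
    A plain NumPy sketch of one step, for illustration only (not part of
    Lasagne; ``adadelta_step`` is a hypothetical helper). It mirrors the
    implementation below, which accumulates the squared update *without*
    ``learning_rate`` (identical to the formula above when
    ``learning_rate=1.0``). On the first step ``s_{t-1} = 0``, so only
    ``epsilon`` keeps the numerator from being 0:

    ```python
    import numpy as np

    def adadelta_step(param, grad, accu, delta_accu, learning_rate=1.0,
                      rho=0.95, epsilon=1e-6):
        # r_t: moving average of squared gradients (as in rmsprop)
        accu_new = rho * accu + (1 - rho) * grad ** 2
        # The step uses the previous average of squared updates, s_{t-1}
        update = (grad * np.sqrt(delta_accu + epsilon) /
                  np.sqrt(accu_new + epsilon))
        new_param = param - learning_rate * update
        # s_t: moving average of squared updates
        delta_accu_new = rho * delta_accu + (1 - rho) * update ** 2
        return new_param, accu_new, delta_accu_new

    w, accu, delta_accu = np.array([1.0]), np.zeros(1), np.zeros(1)
    w, accu, delta_accu = adadelta_step(w, np.array([1.0]), accu, delta_accu)
    ```

    The first step is tiny, about ``0.0045``.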

    References
    ----------
    .. [1] Zeiler, M. D. (2012):
           ADADELTA: An Adaptive Learning Rate Method.
           arXiv Preprint arXiv:1212.5701.
    """
    grads = get_or_compute_grads(loss_or_grads, params)
    updates = OrderedDict()

    # Using theano constant to prevent upcasting of float32
    one = T.constant(1)

    for param, grad in zip(params, grads):
        value = param.get_value(borrow=True)
        # accu: accumulate gradient magnitudes
        accu = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                             broadcastable=param.broadcastable)
        # delta_accu: accumulate update magnitudes (recursively!)
        delta_accu = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                                   broadcastable=param.broadcastable)

        # update accu (as in rmsprop)
        accu_new = rho * accu + (one - rho) * grad ** 2
        updates[accu] = accu_new

        # compute parameter update, using the 'old' delta_accu
        update = (grad * T.sqrt(delta_accu + epsilon) /
                  T.sqrt(accu_new + epsilon))
        updates[param] = param - learning_rate * update

        # update delta_accu (as accu, but accumulating updates)
        delta_accu_new = rho * delta_accu + (one - rho) * update ** 2
        updates[delta_accu] = delta_accu_new

    return updates


def adam(loss_or_grads, params, learning_rate=0.001, beta1=0.9,
         beta2=0.999, epsilon=1e-8):
    """Adam updates

    Adam updates implemented as in [1]_.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        Learning rate
    beta1 : float or symbolic scalar
        Exponential decay rate for the first moment estimates.
    beta2 : float or symbolic scalar
        Exponential decay rate for the second moment estimates.
    epsilon : float or symbolic scalar
        Constant for numerical stability.

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression

    Notes
    -----
    The paper [1]_ includes an additional hyperparameter lambda. This is only
    needed to prove convergence of the algorithm and has no practical use
    (personal communication with the authors); it is therefore omitted here.
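
    A plain NumPy sketch of one step, for illustration only (not part of
    Lasagne; ``adam_step`` is a hypothetical helper). Like the code below, it
    folds the bias correction of both moment estimates into the step size
    ``a_t``. On the first step this makes the update about ``learning_rate``
    in each coordinate, regardless of the gradient's magnitude:

    ```python
    import numpy as np

    def adam_step(param, grad, m, v, t, learning_rate=0.001, beta1=0.9,
                  beta2=0.999, epsilon=1e-8):
        t = t + 1
        # Bias-corrected step size
        a_t = learning_rate * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
        m = beta1 * m + (1 - beta1) * grad       # first moment estimate
        v = beta2 * v + (1 - beta2) * grad ** 2  # second moment estimate
        param = param - a_t * m / (np.sqrt(v) + epsilon)
        return param, m, v, t

    w = np.array([0.5, -0.5])
    m, v, t = np.zeros(2), np.zeros(2), 0
    w, m, v, t = adam_step(w, np.array([10.0, -0.1]), m, v, t)
    ```

    After this step ``w`` is approximately ``[0.499, -0.499]``.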

    References
    ----------
    .. [1] Kingma, Diederik, and Jimmy Ba (2014):
           Adam: A Method for Stochastic Optimization.
           arXiv preprint arXiv:1412.6980.
    """
    all_grads = get_or_compute_grads(loss_or_grads, params)
    t_prev = theano.shared(utils.floatX(0.))
    updates = OrderedDict()

    # Using theano constant to prevent upcasting of float32
    one = T.constant(1)

    t = t_prev + 1
    a_t = learning_rate*T.sqrt(one-beta2**t)/(one-beta1**t)

    for param, g_t in zip(params, all_grads):
        value = param.get_value(borrow=True)
        m_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)
        v_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)

        m_t = beta1*m_prev + (one-beta1)*g_t
        v_t = beta2*v_prev + (one-beta2)*g_t**2
        step = a_t*m_t/(T.sqrt(v_t) + epsilon)

        updates[m_prev] = m_t
        updates[v_prev] = v_t
        updates[param] = param - step

    updates[t_prev] = t
    return updates


def adamax(loss_or_grads, params, learning_rate=0.002, beta1=0.9,
           beta2=0.999, epsilon=1e-8):
    """Adamax updates

    Adamax updates implemented as in [1]_. This is a variant of the Adam
    algorithm based on the infinity norm.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        Learning rate
    beta1 : float or symbolic scalar
        Exponential decay rate for the first moment estimates.
    beta2 : float or symbolic scalar
        Exponential decay rate for the weighted infinity norm estimates.
    epsilon : float or symbolic scalar
        Constant for numerical stability.

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression
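
    Notes
    -----
    A plain NumPy sketch, for illustration only (not part of Lasagne;
    ``adamax_step`` is a hypothetical helper), mirroring the code below. The
    second moment is replaced by ``u``, an exponentially weighted infinity
    norm: a decaying running maximum of ``|g|``, so a single small gradient
    does not immediately shrink the denominator:

    ```python
    import numpy as np

    def adamax_step(param, grad, m, u, t, learning_rate=0.002, beta1=0.9,
                    beta2=0.999, epsilon=1e-8):
        t = t + 1
        a_t = learning_rate / (1 - beta1 ** t)   # bias-corrected step size
        m = beta1 * m + (1 - beta1) * grad       # first moment estimate
        u = np.maximum(beta2 * u, np.abs(grad))  # weighted infinity norm
        param = param - a_t * m / (u + epsilon)
        return param, m, u, t

    w, m, u, t = np.array([1.0]), np.zeros(1), np.zeros(1), 0
    for g in [4.0, 1.0]:
        w, m, u, t = adamax_step(w, np.array([g]), m, u, t)
    ```

    After the second step ``u`` is ``0.999 * 4 = 3.996`` rather than ``1``.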

    References
    ----------
    .. [1] Kingma, Diederik, and Jimmy Ba (2014):
           Adam: A Method for Stochastic Optimization.
           arXiv preprint arXiv:1412.6980.
    """
    all_grads = get_or_compute_grads(loss_or_grads, params)
    t_prev = theano.shared(utils.floatX(0.))
    updates = OrderedDict()

    # Using theano constant to prevent upcasting of float32
    one = T.constant(1)

    t = t_prev + 1
    a_t = learning_rate/(one-beta1**t)

    for param, g_t in zip(params, all_grads):
        value = param.get_value(borrow=True)
        m_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)
        u_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)

        m_t = beta1*m_prev + (one-beta1)*g_t
        u_t = T.maximum(beta2*u_prev, abs(g_t))
        step = a_t*m_t/(u_t + epsilon)

        updates[m_prev] = m_t
        updates[u_prev] = u_t
        updates[param] = param - step

    updates[t_prev] = t
    return updates


def amsgrad(loss_or_grads, params, learning_rate=0.001, beta1=0.9,
            beta2=0.999, epsilon=1e-8):
    """AMSGrad updates

    AMSGrad updates implemented as in [1]_.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        Learning rate
    beta1 : float or symbolic scalar
        Exponential decay rate for the first moment estimates.
    beta2 : float or symbolic scalar
        Exponential decay rate for the second moment estimates.
    epsilon : float or symbolic scalar
        Constant for numerical stability.

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression
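
    Notes
    -----
    A plain NumPy sketch, for illustration only (not part of Lasagne;
    ``amsgrad_step`` is a hypothetical helper), mirroring the code below. It
    differs from Adam only in dividing by the running maximum ``v_hat`` of the
    second moment estimate, so the denominator never decreases even when
    later gradients are 0:

    ```python
    import numpy as np

    def amsgrad_step(param, grad, m, v, v_hat, t, learning_rate=0.001,
                     beta1=0.9, beta2=0.999, epsilon=1e-8):
        t = t + 1
        a_t = learning_rate * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad ** 2
        # Keep the running maximum of v
        v_hat = np.maximum(v_hat, v)
        param = param - a_t * m / (np.sqrt(v_hat) + epsilon)
        return param, m, v, v_hat, t

    w = np.zeros(1)
    m, v, v_hat, t = np.zeros(1), np.zeros(1), np.zeros(1), 0
    v_hats = []
    for g in [1.0, 0.0, 0.0]:
        w, m, v, v_hat, t = amsgrad_step(w, np.array([g]), m, v, v_hat, t)
        v_hats.append(float(v_hat[0]))
    ```

    ``v`` decays after the first step, while ``v_hat`` stays at ``0.001``.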

    References
    ----------
    .. [1] Reddi, S. J., Kale, S., & Kumar, S. (2018):
           On the Convergence of Adam and Beyond. ICLR.
           https://openreview.net/forum?id=ryQu7f-RZ
    """
    all_grads = get_or_compute_grads(loss_or_grads, params)
    t_prev = theano.shared(utils.floatX(0.))
    updates = OrderedDict()

    # Using theano constant to prevent upcasting of float32
    one = T.constant(1)

    t = t_prev + 1
    a_t = learning_rate*T.sqrt(one-beta2**t)/(one-beta1**t)

    for param, g_t in zip(params, all_grads):
        value = param.get_value(borrow=True)
        m_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)
        v_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)
        v_hat_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                                   broadcastable=param.broadcastable)

        m_t = beta1*m_prev + (one-beta1)*g_t
        v_t = beta2*v_prev + (one-beta2)*g_t**2
        v_hat_t = T.maximum(v_hat_prev, v_t)
        step = a_t*m_t/(T.sqrt(v_hat_t) + epsilon)

        updates[m_prev] = m_t
        updates[v_prev] = v_t
        updates[v_hat_prev] = v_hat_t
        updates[param] = param - step

    updates[t_prev] = t
    return updates


def norm_constraint(tensor_var, max_norm, norm_axes=None, epsilon=1e-7):
    """Max weight norm constraints and gradient clipping

    This takes a TensorVariable and rescales it so that incoming weight
    norms do not exceed a specified constraint value. Vectors violating the
    constraint are rescaled so that they are within the allowed range.

    Parameters
    ----------
    tensor_var : TensorVariable
        Theano expression for update, gradient, or other quantity.
    max_norm : scalar
        This value sets the maximum allowed value of any norm in
        `tensor_var`.
    norm_axes : sequence (list or tuple), optional
        The axes over which to compute the norm. This overrides the
        default norm axes defined for the number of dimensions
        in `tensor_var`. When this is not specified and `tensor_var` is a
        matrix (2D), this is set to `(0,)`. If `tensor_var` is a 3D, 4D or
        5D tensor, it is set to a tuple listing all axes but axis 0. The
        former default is useful for working with dense layers, the latter
        is useful for 1D, 2D and 3D convolutional layers.
    epsilon : scalar, optional
        Value used to prevent numerical instability when dividing by
        very small or zero norms.

    Returns
    -------
    TensorVariable
        Input `tensor_var` with rescaling applied to weight vectors
        that violate the specified constraints.

    Examples
    --------
    >>> param = theano.shared(
    ...     np.random.randn(100, 200).astype(theano.config.floatX))
    >>> update = param + 100
    >>> update = norm_constraint(update, 10)
    >>> func = theano.function([], [], updates=[(param, update)])
    >>> # Apply constrained update
    >>> _ = func()
    >>> from lasagne.utils import compute_norms
    >>> norms = compute_norms(param.get_value())
    >>> np.isclose(np.max(norms), 10)
    True
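
    A plain NumPy sketch of the same computation, for illustration only (not
    part of Lasagne; ``norm_constraint_np`` is a hypothetical helper). For a
    dense-layer weight matrix of shape ``(num_inputs, num_units)``, the
    default ``norm_axes=(0,)`` constrains each column, i.e. the incoming
    weights of each unit; columns already within ``max_norm`` are left
    (almost exactly) unchanged:

    ```python
    import numpy as np

    def norm_constraint_np(w, max_norm, norm_axes=(0,), epsilon=1e-7):
        # Norm of each weight vector, kept broadcastable against w
        norms = np.sqrt(np.sum(w ** 2, axis=tuple(norm_axes),
                               keepdims=True))
        target_norms = np.clip(norms, 0, max_norm)
        return w * (target_norms / (epsilon + norms))

    w = np.array([[3.0, 0.3],
                  [4.0, 0.4]])  # column norms: 5.0 and 0.5
    constrained = norm_constraint_np(w, max_norm=1.0)
    col_norms = np.sqrt(np.sum(constrained ** 2, axis=0))
    ```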

    Notes
    -----
    When `norm_axes` is not specified, the axes over which the norm is
    computed depend on the dimensionality of the input variable. If it is
    2D, it is assumed to come from a dense layer, and the norm is computed
    over axis 0. If it is 3D, 4D or 5D, it is assumed to come from a
    convolutional layer and the norm is computed over all trailing axes
    beyond axis 0. For other uses, you should explicitly specify the axes
    over which to compute the norm using `norm_axes`.
    """
    ndim = tensor_var.ndim

    if norm_axes is not None:
        sum_over = tuple(norm_axes)
    elif ndim == 2:  # DenseLayer
        sum_over = (0,)
    elif ndim in [3, 4, 5]:  # Conv{1,2,3}DLayer
        sum_over = tuple(range(1, ndim))
    else:
        raise ValueError(
            "Unsupported tensor dimensionality {}. "
            "Must specify `norm_axes`".format(ndim)
        )

    dtype = np.dtype(theano.config.floatX).type
    norms = T.sqrt(T.sum(T.sqr(tensor_var), axis=sum_over, keepdims=True))
    target_norms = T.clip(norms, 0, dtype(max_norm))
    constrained_output = \
        (tensor_var * (target_norms / (dtype(epsilon) + norms)))

    return constrained_output


def total_norm_constraint(tensor_vars, max_norm, epsilon=1e-7,
                          return_norm=False):
    """Rescales a list of tensors based on their combined norm

    If the combined norm of the input tensors exceeds the threshold then all
    tensors are rescaled such that the combined norm is equal to the threshold.

    Scaling the norms of the gradients is often used when training recurrent
    neural networks [1]_.

    Parameters
    ----------
    tensor_vars : list of TensorVariables
        Tensors to be rescaled.
    max_norm : float
        Threshold value for total norm.
    epsilon : scalar, optional
        Value used to prevent numerical instability when dividing by
        very small or zero norms.
    return_norm : bool, optional
        If True, the total norm is also returned. Defaults to False.

    Returns
    -------
    tensor_vars_scaled : list of TensorVariables
        The scaled tensor variables.
    norm : Theano scalar
        The combined norm of the input variables prior to rescaling,
        only returned if ``return_norm=True``.

    Examples
    --------
    >>> from lasagne.layers import InputLayer, DenseLayer
    >>> import lasagne
    >>> from lasagne.updates import sgd, total_norm_constraint
    >>> x = T.matrix()
    >>> y = T.ivector()
    >>> l_in = InputLayer((5, 10))
    >>> l1 = DenseLayer(l_in, num_units=7, nonlinearity=T.nnet.softmax)
    >>> output = lasagne.layers.get_output(l1, x)
    >>> cost = T.mean(T.nnet.categorical_crossentropy(output, y))
    >>> all_params = lasagne.layers.get_all_params(l1)
    >>> all_grads = T.grad(cost, all_params)
    >>> scaled_grads = total_norm_constraint(all_grads, 5)
    >>> updates = sgd(scaled_grads, all_params, learning_rate=0.1)

    Notes
    -----
    The total norm can be used to monitor training.
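
    A plain NumPy sketch of the rescaling, for illustration only (not part of
    Lasagne; ``total_norm_constraint_np`` is a hypothetical helper). All
    arrays share one multiplier, so their relative directions are preserved.
    The returned pre-rescaling ``norm`` is the value the note above suggests
    monitoring:

    ```python
    import numpy as np

    def total_norm_constraint_np(arrays, max_norm, epsilon=1e-7):
        # One combined L2 norm over all arrays
        norm = np.sqrt(sum(np.sum(a ** 2) for a in arrays))
        target_norm = np.clip(norm, 0, max_norm)
        multiplier = target_norm / (epsilon + norm)
        return [a * multiplier for a in arrays], norm

    grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]  # total norm 5
    scaled, norm = total_norm_constraint_np(grads, max_norm=1.0)
    new_norm = np.sqrt(sum(np.sum(a ** 2) for a in scaled))
    ```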

    References
    ----------
    .. [1] Sutskever, I., Vinyals, O., & Le, Q. V. (2014): Sequence to sequence
           learning with neural networks. In Advances in Neural Information
           Processing Systems (pp. 3104-3112).
    """
    norm = T.sqrt(sum(T.sum(tensor**2) for tensor in tensor_vars))
    dtype = np.dtype(theano.config.floatX).type
    target_norm = T.clip(norm, 0, dtype(max_norm))
    multiplier = target_norm / (dtype(epsilon) + norm)
    tensor_vars_scaled = [step*multiplier for step in tensor_vars]

    if return_norm:
        return tensor_vars_scaled, norm
    else:
        return tensor_vars_scaled