Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Add RMSprop, AdaDelta and Adam solvers
Browse files Browse the repository at this point in the history
Disable solvers which aren't supported by the selected framework
  • Loading branch information
lukeyeager committed Feb 2, 2016
1 parent ea2a941 commit fa12f24
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 10 deletions.
9 changes: 8 additions & 1 deletion digits/frameworks/caffe_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from .errors import BadNetworkError
from .framework import Framework
import digits
from digits.config import config_value
from digits.model.tasks import CaffeTrainTask
from digits.utils import subclass, override
from digits.utils import subclass, override, parse_version

@subclass
class CaffeFramework(Framework):
Expand All @@ -33,6 +34,12 @@ class CaffeFramework(Framework):
# whether this framework can shuffle data during training
CAN_SHUFFLE_DATA = False

if config_value('caffe_root')['version'] > parse_version('0.14.0-alpha'):
SUPPORTED_SOLVER_TYPES = ['SGD', 'NESTEROV', 'ADAGRAD',
'RMSPROP', 'ADADELTA', 'ADAM']
else:
SUPPORTED_SOLVER_TYPES = ['SGD', 'NESTEROV', 'ADAGRAD']

@override
def __init__(self):
super(CaffeFramework, self).__init__()
Expand Down
9 changes: 9 additions & 0 deletions digits/frameworks/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ def can_shuffle_data(self):
"""
return self.CAN_SHUFFLE_DATA

def supports_solver_type(self, solver_type):
    """
    tell the caller whether solver_type is one of the solvers listed
    in this framework's SUPPORTED_SOLVER_TYPES

    raises NotImplementedError when the framework does not declare
    SUPPORTED_SOLVER_TYPES at all
    """
    try:
        supported = self.SUPPORTED_SOLVER_TYPES
    except AttributeError:
        # the framework never declared which solvers it can run
        raise NotImplementedError
    # the declaration is required to be a list
    assert isinstance(supported, list)
    return solver_type in supported

def validate_network(self, data):
"""
validate a network (must be implemented in child class)
Expand Down
2 changes: 2 additions & 0 deletions digits/frameworks/torch_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class TorchFramework(Framework):
# whether this framework can shuffle data during training
CAN_SHUFFLE_DATA = True

SUPPORTED_SOLVER_TYPES = ['SGD']

def __init__(self):
super(TorchFramework, self).__init__()
# id must be unique
Expand Down
27 changes: 19 additions & 8 deletions digits/model/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,26 @@ def validate_py_ext(form, field):

### Solver types

solver_type = utils.forms.SelectField('Solver type',
solver_type = utils.forms.SelectField(
'Solver type',
choices = [
('SGD', 'Stochastic gradient descent (SGD)'),
('ADAGRAD', 'Adaptive gradient (AdaGrad)'),
('NESTEROV', "Nesterov's accelerated gradient (NAG)"),
],
default = 'SGD',
tooltip = "What type of solver will be used??"
)
('SGD', 'Stochastic gradient descent (SGD)'),
('NESTEROV', "Nesterov's accelerated gradient (NAG)"),
('ADAGRAD', 'Adaptive gradient (AdaGrad)'),
('RMSPROP', 'RMSprop'),
('ADADELTA', 'AdaDelta'),
('ADAM', 'Adam'),
],
default = 'SGD',
tooltip = "What type of solver will be used?",
)

def validate_solver_type(form, field):
    """
    reject a solver type that the selected framework does not support

    inline validator for the solver_type field; raises
    validators.ValidationError when the framework reports that it does
    not support the submitted solver type. When no framework is found
    for the submitted id, nothing is validated here.
    """
    # form.framework is expected to be a form field whose submitted value
    # lives in .data, so look the framework up by that id rather than by
    # the field object; fall back to the attribute itself in case a plain
    # id is stored there
    framework_id = getattr(form.framework, 'data', form.framework)
    fw = frameworks.get_framework_by_id(framework_id)
    if fw is not None and not fw.supports_solver_type(field.data):
        raise validators.ValidationError(
            'Solver type not supported by this framework')

### Learning rate

Expand Down
2 changes: 1 addition & 1 deletion digits/model/tasks/caffe_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def save_files_classification(self):
raise Exception('Unknown lr_policy: "%s"' % solver.lr_policy)

# go with the suggested defaults
if solver.solver_type != solver.ADAGRAD:
if solver.solver_type not in [solver.ADAGRAD, solver.RMSPROP]:
solver.momentum = 0.9
solver.weight_decay = 0.0005

Expand Down
16 changes: 16 additions & 0 deletions digits/templates/models/images/classification/new.html
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ <h4>Solver Options</h4>
</label>
</p>
<script>
{# Tag every solver_type <option> with the id of each framework that supports it, #}
{# so the framework switcher can enable/disable options by CSS class. #}
{% for choice in form.solver_type.choices %}
    {% for fw in frameworks %}
        {% if fw.supports_solver_type(choice[0]) %}
            {# the attribute selector must be closed with "]"; the value is quoted for safety #}
            $("select[name=solver_type] > option[value='{{choice[0]}}']").addClass("{{fw.get_id()}}");
        {% endif %}
    {% endfor %}
{% endfor %}

function showHideAdvancedLROptions() {
if ($("#show-advanced-lr-options").prop("checked")) {
$("#advanced-lr-options").show();
Expand Down Expand Up @@ -399,13 +407,21 @@ <h4>Solver Options</h4>
$("#shuffle-data").show();
else
$("#shuffle-data").hide();

$("select[name=solver_type] > option." + fwid).prop('disabled', false);
$("select[name=solver_type] > option").not("." + fwid).prop('disabled', true);
if (! $("select[name=solver_type] > option:selected").hasClass(fwid)) {
$("select[name=solver_type] > option:selected").prop("selected", false);
}

if (fwid == 'torch')
$("#torch-warning").show();
else
$("#torch-warning").hide();
$('#stdnetRole a[href="'+"#"+fwid+"_standard"+'"]').tab('show');
$('#customFramework a[href="'+"#"+fwid+"_custom"+'"]').tab('show');
}
setFramework("{{form.framework.data}}");
</script>

<div id="torch-warning" class="alert alert-warning" style="display:none;">
Expand Down
15 changes: 15 additions & 0 deletions digits/templates/models/images/generic/new.html
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ <h4>Solver Options</h4>
</label>
</p>
<script>
{# Tag every solver_type <option> with the id of each framework that supports it, #}
{# so the framework switcher can enable/disable options by CSS class. #}
{% for choice in form.solver_type.choices %}
    {% for fw in frameworks %}
        {% if fw.supports_solver_type(choice[0]) %}
            {# the attribute selector must be closed with "]"; the value is quoted for safety #}
            $("select[name=solver_type] > option[value='{{choice[0]}}']").addClass("{{fw.get_id()}}");
        {% endif %}
    {% endfor %}
{% endfor %}

function showHideAdvancedLROptions() {
if ($("#show-advanced-lr-options").prop("checked")) {
$("#advanced-lr-options").show();
Expand Down Expand Up @@ -395,7 +403,14 @@ <h4>Solver Options</h4>
$("#shuffle-data").show();
else
$("#shuffle-data").hide();

$("select[name=solver_type] > option." + fwid).prop('disabled', false);
$("select[name=solver_type] > option").not("." + fwid).prop('disabled', true);
if (! $("select[name=solver_type] > option:selected").hasClass(fwid)) {
$("select[name=solver_type] > option:selected").prop("selected", false);
}
}
setFramework("{{form.framework.data}}");
</script>

<ul class="nav nav-tabs" id="network-tabs">
Expand Down

0 comments on commit fa12f24

Please sign in to comment.