Skip to content

Commit

Permalink
more selective warning about dropout being stochastic
Browse files Browse the repository at this point in the history
  • Loading branch information
goodfeli committed Nov 14, 2014
1 parent 07ee218 commit 53bf27d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
19 changes: 19 additions & 0 deletions pylearn2/costs/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,25 @@ def get_data_specs(self, model):
raise NotImplementedError(str(type(self)) + " does not implement " +
"get_data_specs.")

def is_stochastic(self):
"""
Returns True if the cost is stochastic.
Stochastic costs are incompatible with some optimization algorithms
that make multiple updates per minibatch, such as algorithms that
use line searches. These optimizations should raise a TypeError if
given a stochastic Cost, or issue a warning if given a Cost whose
`is_stochastic` method raises NotImplementedError.
Returns
-------
is_stochastic : bool
Whether the cost is stochastic. For example, dropout is
stochastic.
"""

raise NotImplementedError()


class SumOfCosts(Cost):
"""
Expand Down
11 changes: 11 additions & 0 deletions pylearn2/training_algorithms/bgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
__email__ = "pylearn-dev@googlegroups"

import logging
import warnings

import numpy as np
from theano import config
from theano.compat.python2x import OrderedDict
Expand Down Expand Up @@ -129,6 +131,15 @@ def setup(self, model, dataset):
if self.cost is None:
self.cost = model.get_default_cost()

try:
if self.cost.is_stochastic():
raise TypeError("BGD is not compatible with stochastic "
"costs.")
except NotImplementedError:
warnings.warn("BGD is not compatible with stochastic costs "
"and cannot determine whether the current cost is "
"stochastic.")

if self.batch_size is None:
self.batch_size = model.force_batch_size
else:
Expand Down

0 comments on commit 53bf27d

Please sign in to comment.