dropout experimentally added
ivan-vasilev committed Apr 19, 2014
1 parent 4aff5f8 commit 9b27b56
Showing 12 changed files with 152 additions and 37 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -15,7 +15,7 @@ I'm using the [git-flow](https://github.com/nvie/gitflow) model. The most stable
* Convolutional networks with max pooling, average pooling and [stochastic pooling](http://techtalks.tv/talks/stochastic-pooling-for-regularization-of-deep-convolutional-neural-networks/58106/).

##Training algorithms
* Backpropagation - supports multilayer perceptrons and convolutional networks.
* Backpropagation - supports multilayer perceptrons, convolutional networks and [dropout](http://arxiv.org/pdf/1207.0580.pdf).
* Contrastive divergence and persistent contrastive divergence implemented using [these](http://www.iro.umontreal.ca/~lisa/publications2/index.php/publications/show/239) and [these](http://www.cs.toronto.edu/~hinton/absps/guideTR.pdf) guidelines.
* Greedy layer-wise training for deep networks - works for stacked autoencoders and DBNs, but supports any kind of training.
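A minimal usage sketch of the new dropout support, assuming the `TrainerFactory.backPropagation` signature introduced by this commit; the tiny network and input data below mirror the project's tests and are placeholders only:

```java
// Illustrative sketch: a 3-2-1 sigmoid MLP trained with backpropagation and a 50% dropout rate.
NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[] { 3, 2, 1 }, true);
TrainingInputProvider trainingSet = new SimpleInputProvider(new float[][] { { 1, 0, 1 } }, new float[][] { { 1 } });
TrainingInputProvider testingSet = new SimpleInputProvider(new float[][] { { 1, 0, 1 } }, new float[][] { { 1 } });
BackPropagationTrainer<?> trainer = TrainerFactory.backPropagation(
        mlp, trainingSet, testingSet, null, null, // output error and weight initializer omitted, as in the tests
        0.9f,   // learning rate
        0f,     // momentum
        0f, 0f, // L1 / L2 weight decay
        0.5f,   // dropout rate (the new parameter)
        1, 1, 1 // training batch size, test batch size, epochs
);
trainer.train();
```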

@@ -13,12 +13,14 @@ public class AparapiNoise extends XORShiftKernel implements TensorFunction {
private final float corruptionLevel;
private final int startIndex;
private float[] inputOutput;
private final float corruptedValue;

public AparapiNoise(Tensor inputOutput, int maximumRange, float corruptionLevel) {
public AparapiNoise(Tensor inputOutput, int maximumRange, float corruptionLevel, float corruptedValue) {
super(maximumRange);
this.inputOutput = inputOutput.getElements();
this.startIndex = inputOutput.getStartIndex();
this.corruptionLevel = corruptionLevel;
this.corruptedValue = corruptedValue;
}

@Override
@@ -38,7 +40,7 @@ public void value(Tensor inputOutput) {
public void run() {
int id = getGlobalId();
if (random01() < corruptionLevel) {
inputOutput[startIndex + id] = 0;
inputOutput[startIndex + id] = corruptedValue;
}
}
}
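With the corrupted value now a constructor parameter, the same kernel can mask activations for dropout as well as corrupt autoencoder inputs. A brief sketch under that assumption; the tensor size and rates are placeholders, and both call sites in this commit pass 0 as the corrupted value:

```java
// Illustrative sketch: one noise kernel, two uses.
Tensor activations = TensorFactory.tensor(100);                                              // hypothetical 100-element tensor
TensorFunction dropout = new AparapiNoise(activations, activations.getSize(), 0.5f, 0f);     // drop ~50% of units to 0
TensorFunction inputNoise = new AparapiNoise(activations, activations.getSize(), 0.2f, 0f);  // corrupt ~20% of inputs to 0
dropout.value(activations); // each element is overwritten with the corrupted value where random01() < corruptionLevel
```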
@@ -50,6 +50,12 @@ public class ConnectionCalculatorFullyConnected implements ConnectionCalculator,
*/
protected List<TensorFunction> activationFunctions;

/**
* Dropout properties
*/
protected float dropoutRate;
protected TensorFunction dropoutFunction;

public ConnectionCalculatorFullyConnected() {
super();
inputFunctions = new HashSet<>();
@@ -82,6 +88,14 @@ public void calculate(List<Connections> connections, ValuesProvider valuesProvid
if (activationFunctions != null) {
activationFunctions.forEach(f -> f.value(TensorFactory.tensor(targetLayer, notBias, valuesProvider)));
}

if (dropoutRate > 0) {
if (dropoutFunction == null) {
dropoutFunction = createDropoutFunction(notBias, valuesProvider, targetLayer);
}

dropoutFunction.value(TensorFactory.tensor(targetLayer, notBias, valuesProvider));
}
}
}
}
@@ -125,6 +139,14 @@ public void removeActivationFunction(TensorFunction activationFunction) {
}
}

public float getDropoutRate() {
return dropoutRate;
}

public void setDropoutRate(float dropoutRate) {
this.dropoutRate = dropoutRate;
}

protected void calculateBias(Connections bias, ValuesProvider valuesProvider) {
if (bias != null) {
Tensor biasValue = TensorFactory.tensor(bias.getInputLayer(), bias, valuesProvider);
@@ -155,6 +177,11 @@ protected ConnectionCalculator createInputFunction(List<Connections> inputConnec
return new AparapiWeightedSum(inputConnections, valuesProvider, targetLayer);
}

protected TensorFunction createDropoutFunction(List<Connections> inputConnections, ValuesProvider valuesProvider, Layer targetLayer) {
Tensor t = TensorFactory.tensor(targetLayer, inputConnections, valuesProvider);
return new AparapiNoise(t, t.getSize(), dropoutRate, 0);
}

private ConnectionCalculator getConnectionCalculator(List<Connections> connections, ValuesProvider valuesProvider, Layer targetLayer) {
ConnectionCalculator result = inputFunctions.stream().filter(c -> {
return !(c instanceof AparapiWeightedSum) || ((AparapiWeightedSum) c).accept(connections, valuesProvider, targetLayer);
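In the forward pass above, dropout runs after the weighted sums and activation functions, zeroing a random subset of the target layer's activations. A hedged sketch of switching it on for one layer; `layerCalculator` and `hiddenLayer` are placeholder names, and it assumes the layer's calculator is a `ConnectionCalculatorFullyConnected` subclass such as `AparapiSigmoid`, as the trainer below also assumes:

```java
// Illustrative sketch: enable dropout on the calculator that produces one layer's activations.
ConnectionCalculatorFullyConnected cc = (ConnectionCalculatorFullyConnected) layerCalculator.getConnectionCalculator(hiddenLayer);
cc.setDropoutRate(0.5f); // calculate() will now zero ~50% of this layer's activations after the activation functions run
cc.setDropoutRate(0f);   // a rate of 0 disables the dropout branch again (as the trainer does when training finishes)
```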
@@ -80,13 +80,12 @@ public class TrainerFactory {
* @param l1weightDecay
* @return
*/
public static BackPropagationTrainer<?> backPropagation(NeuralNetworkImpl nn, TrainingInputProvider trainingSet, TrainingInputProvider testingSet, OutputError error, NNRandomInitializer rand, float learningRate, float momentum, float l1weightDecay, float l2weightDecay, int trainingBatchSize, int testBatchSize, int epochs) {
BackPropagationTrainer<?> t = new BackPropagationTrainer<NeuralNetwork>(backpropProperties(nn, trainingSet, testingSet, error, rand, learningRate, momentum, l1weightDecay, l2weightDecay, trainingBatchSize, testBatchSize, epochs));

BackPropagationLayerCalculatorImpl bplc = bplc(nn, t.getProperties());
t.getProperties().setParameter(Constants.BACKPROPAGATION, bplc);
public static BackPropagationTrainer<?> backPropagation(NeuralNetworkImpl nn, TrainingInputProvider trainingSet, TrainingInputProvider testingSet, OutputError error, NNRandomInitializer rand, float learningRate, float momentum, float l1weightDecay, float l2weightDecay, float dropoutRate, int trainingBatchSize, int testBatchSize, int epochs) {
Properties p = backpropProperties(nn, trainingSet, testingSet, error, rand, learningRate, momentum, l1weightDecay, l2weightDecay, trainingBatchSize, testBatchSize, epochs);
p.setParameter(Constants.BACKPROPAGATION, bplc(nn, p));
p.setParameter(Constants.DROPOUT_RATE, dropoutRate);

return t;
return new BackPropagationTrainer<NeuralNetwork>(p);
}

private static BackPropagationLayerCalculatorImpl bplc(NeuralNetworkImpl nn, Properties p) {
@@ -174,14 +173,9 @@ private static BackPropagationLayerCalculatorImpl bplc(NeuralNetworkImpl nn, Pro
public static BackPropagationAutoencoder backPropagationAutoencoder(NeuralNetworkImpl nn, TrainingInputProvider trainingSet, TrainingInputProvider testingSet, OutputError error, NNRandomInitializer rand, float learningRate, float momentum, float l1weightDecay, float l2weightDecay, float inputCorruptionRate, int trainingBatchSize, int testBatchSize, int epochs) {
Properties p = backpropProperties(nn, trainingSet, testingSet, error, rand, learningRate, momentum, l1weightDecay, l2weightDecay, trainingBatchSize, testBatchSize, epochs);
p.setParameter(Constants.CORRUPTION_LEVEL, inputCorruptionRate);
p.setParameter(Constants.BACKPROPAGATION, bplc(nn, p));

BackPropagationAutoencoder t = new BackPropagationAutoencoder(p);

BackPropagationLayerCalculatorImpl bplc = bplc(nn, p);

t.getProperties().setParameter(Constants.BACKPROPAGATION, bplc);

return t;
return new BackPropagationAutoencoder(p);
}

protected static Properties backpropProperties(NeuralNetworkImpl nn, TrainingInputProvider trainingSet, TrainingInputProvider testingSet, OutputError error, NNRandomInitializer rand, float learningRate, float momentum, float l1weightDecay, float l2weightDecay, int trainingBatchSize, int testBatchSize, int epochs) {
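Because `backPropagation` gained a `dropoutRate` parameter, every existing caller needs one extra argument, as the test changes further down show. Passing 0 keeps the previous behaviour, since `BackPropagationTrainer` only wires up dropout for rates greater than 0. A migration sketch with placeholder arguments:

```java
// Before this commit (illustrative):
// TrainerFactory.backPropagation(nn, trainingSet, testingSet, error, rand, 0.01f, 0.5f, 0f, 0f, 1, 1, 1);
// After this commit: the extra 0f before the batch sizes disables dropout.
BackPropagationTrainer<?> t = TrainerFactory.backPropagation(nn, trainingSet, testingSet, error, rand, 0.01f, 0.5f, 0f, 0f, 0f, 1, 1, 1);
```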
@@ -62,7 +62,7 @@ public float[] getNextTarget() {
if (corruptionRate != null && corruptionRate > 0) {
if (noise == null) {
noiseTensor = TensorFactory.tensor(base.getNextInput().length);
noise = new AparapiNoise(noiseTensor, base.getNextInput().length, corruptionRate);
noise = new AparapiNoise(noiseTensor, base.getNextInput().length, corruptionRate, 0);
}

System.arraycopy(result, 0, noiseTensor.getElements(), 0, result.length);
@@ -2,17 +2,24 @@

import java.util.Set;

import com.github.neuralnetworks.architecture.FullyConnected;
import com.github.neuralnetworks.architecture.Layer;
import com.github.neuralnetworks.architecture.NeuralNetwork;
import com.github.neuralnetworks.calculation.LayerCalculatorImpl;
import com.github.neuralnetworks.calculation.memory.ValuesProvider;
import com.github.neuralnetworks.calculation.neuronfunctions.ConnectionCalculatorFullyConnected;
import com.github.neuralnetworks.events.TrainingEvent;
import com.github.neuralnetworks.events.TrainingEventListener;
import com.github.neuralnetworks.training.OneStepTrainer;
import com.github.neuralnetworks.training.TrainingInputData;
import com.github.neuralnetworks.training.TrainingInputDataImpl;
import com.github.neuralnetworks.training.events.TrainingFinishedEvent;
import com.github.neuralnetworks.util.Constants;
import com.github.neuralnetworks.util.Environment;
import com.github.neuralnetworks.util.Properties;
import com.github.neuralnetworks.util.TensorFactory;
import com.github.neuralnetworks.util.UniqueList;
import com.github.neuralnetworks.util.Util;

/**
* Base backpropagation one step trainer
@@ -21,7 +28,7 @@
* OutputErrorDerivative for calculating the derivative of the output error
* This allows for various implementations of these calculators to be used (for example via GPU or other)
*/
public class BackPropagationTrainer<N extends NeuralNetwork> extends OneStepTrainer<N> {
public class BackPropagationTrainer<N extends NeuralNetwork> extends OneStepTrainer<N> implements TrainingEventListener {

private static final long serialVersionUID = 1L;

@@ -31,9 +38,22 @@ public class BackPropagationTrainer<N extends NeuralNetwork> extends OneStepTrai

public BackPropagationTrainer(Properties properties) {
super(properties);
activations = TensorFactory.tensorProvider(getNeuralNetwork(), getTrainingBatchSize(), Environment.getInstance().getUseDataSharedMemory());
NeuralNetwork nn = getNeuralNetwork();
activations = TensorFactory.tensorProvider(nn, getTrainingBatchSize(), Environment.getInstance().getUseDataSharedMemory());
activations.add(getProperties().getParameter(Constants.OUTPUT_ERROR_DERIVATIVE), activations.get(getNeuralNetwork().getOutputLayer()).getDimensions());
backpropagation = TensorFactory.tensorProvider(getNeuralNetwork(), getTrainingBatchSize(), Environment.getInstance().getUseDataSharedMemory());
backpropagation = TensorFactory.tensorProvider(nn, getTrainingBatchSize(), Environment.getInstance().getUseDataSharedMemory());

float dropoutRate = properties.getParameter(Constants.DROPOUT_RATE);

if (dropoutRate > 0) {
LayerCalculatorImpl lc = (LayerCalculatorImpl) nn.getLayerCalculator();
nn.getConnections().stream().filter(c -> c instanceof FullyConnected && c.getInputLayer() != nn.getInputLayer() && !Util.isBias(c.getInputLayer())).forEach(c -> {
ConnectionCalculatorFullyConnected cc = (ConnectionCalculatorFullyConnected) lc.getConnectionCalculator(c.getOutputLayer());
cc.setDropoutRate(dropoutRate);
});

addEventListener(this);
}
}

/* (non-Javadoc)
@@ -67,6 +87,25 @@ protected TrainingInputData getInput() {
return input;
}

@Override
public void handleEvent(TrainingEvent event) {
if (event instanceof TrainingFinishedEvent) {
float dropoutRate = properties.getParameter(Constants.DROPOUT_RATE);

if (dropoutRate > 0) {
NeuralNetwork nn = getNeuralNetwork();

LayerCalculatorImpl lc = (LayerCalculatorImpl) nn.getLayerCalculator();
nn.getConnections().stream().filter(c -> c instanceof FullyConnected && c.getInputLayer() != nn.getInputLayer() && !Util.isBias(c.getInputLayer())).forEach(c -> {
ConnectionCalculatorFullyConnected cc = (ConnectionCalculatorFullyConnected) lc.getConnectionCalculator(c.getOutputLayer());
cc.setDropoutRate(0);
FullyConnected fc = (FullyConnected) c;
fc.getWeights().forEach(i -> fc.getWeights().getElements()[i] = fc.getWeights().getElements()[i] * (1 - dropoutRate));
});
}
}
}

public BackPropagationLayerCalculator getBPLayerCalculator() {
return getProperties().getParameter(Constants.BACKPROPAGATION);
}
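When a `TrainingFinishedEvent` fires, the trainer above disables dropout and rescales the weights of the same fully connected connections by the keep probability `(1 - dropoutRate)`, so inference-time activations approximate their training-time expectation. A standalone sketch of that rescaling; the numbers are the hidden-to-output weights asserted (times 0.99) in `testSigmoidBPDropout` further down:

```java
// Illustrative sketch: after training with drop probability p, scale the affected weights by (1 - p).
float dropoutRate = 0.01f;
float[] weights = { -0.261f, -0.138f };  // hidden-to-output weights from the FFNN dropout test
for (int i = 0; i < weights.length; i++) {
    weights[i] *= 1 - dropoutRate;       // same factor applied in handleEvent above
}
```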
@@ -36,4 +36,5 @@ public class Constants {
public static final String TEST_BATCH_SIZE = "testBatchSize";
public static final String EPOCHS = "epochs";
public static final String CONNECTION_FACTORY = "connectionFactory";
public static final String DROPOUT_RATE = "dropoutRate";
}
@@ -203,7 +203,7 @@ public void testCNNLayerCalculatorConstruction() {
assertTrue(lc.getConnectionCalculator(l) instanceof ConnectionCalculatorFullyConnected);

// backpropagation cc
BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(nn, null, null, null, null, 0.01f, 0.5f, 0f, 0f, 1, 1, 1);
BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(nn, null, null, null, null, 0.01f, 0.5f, 0f, 0f, 0f, 1, 1, 1);
BackPropagationLayerCalculatorImpl bplc = (BackPropagationLayerCalculatorImpl) bpt.getBPLayerCalculator();

l = nn.getInputLayer();
@@ -246,7 +246,7 @@ public void testCNNLayerCalculatorConstruction() {
l = l.getConnections().get(0).getOutputLayer();
assertTrue(lc.getConnectionCalculator(l) instanceof AparapiSigmoid);

bpt = TrainerFactory.backPropagation(nn, null, null, new MultipleNeuronsOutputError(), null, 0.02f, 0.5f, 0f, 0f, 1, 1, 1);
bpt = TrainerFactory.backPropagation(nn, null, null, new MultipleNeuronsOutputError(), null, 0.02f, 0.5f, 0f, 0f, 0f, 1, 1, 1);
bplc = (BackPropagationLayerCalculatorImpl) bpt.getBPLayerCalculator();

l = nn.getInputLayer();
@@ -531,7 +531,7 @@ public void testCNNBackpropagation() {
b.getWeights().getElements()[b.getWeights().getStartIndex()] = -3f;

SimpleInputProvider ts = new SimpleInputProvider(new float[][] { { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f } }, new float[][] { { 1, 1, 1, 1 } });
BackPropagationTrainer<?> t = TrainerFactory.backPropagation(nn, ts, null, null, null, 0.5f, 0f, 0f, 0f, 1, 1, 1);
BackPropagationTrainer<?> t = TrainerFactory.backPropagation(nn, ts, null, null, null, 0.5f, 0f, 0f, 0f, 0f, 1, 1, 1);
t.train();

it = c.getWeights().iterator();
@@ -565,7 +565,7 @@ public void testCNNBackpropagation2() {
cg2.set(0.3f, 0, 0);
cg2.set(0.9f, 0, 1);

BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(nn, new SimpleInputProvider(new float[][] { { 0.35f, 0.9f } }, new float[][] { { 0.5f } }), new SimpleInputProvider(new float[][] { { 0.35f, 0.9f } }, new float[][] { { 0.5f } }), null, null, 1f, 0f, 0f, 0f, 1, 1, 1);
BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(nn, new SimpleInputProvider(new float[][] { { 0.35f, 0.9f } }, new float[][] { { 0.5f } }), new SimpleInputProvider(new float[][] { { 0.35f, 0.9f } }, new float[][] { { 0.5f } }), null, null, 1f, 0f, 0f, 0f, 0f, 1, 1, 1);
bpt.train();

assertEquals(0.09916, cg1.get(0, 0), 0.001);
@@ -596,7 +596,7 @@ public void testCNNBackpropagation3() {
b.getWeights().getElements()[b.getWeights().getStartIndex()] = -3f;

SimpleInputProvider ts = new SimpleInputProvider(new float[][] { { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f }, { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f } }, new float[][] { { 1, 1, 1, 1 }, { 1, 1, 1, 1 } });
BackPropagationTrainer<?> t = TrainerFactory.backPropagation(nn, ts, null, null, null, 0.5f, 0f, 0f, 0f, 1, 1, 1);
BackPropagationTrainer<?> t = TrainerFactory.backPropagation(nn, ts, null, null, null, 0.5f, 0f, 0f, 0f, 0f, 1, 1, 1);
t.train();

it = c.getWeights().iterator();
58 changes: 55 additions & 3 deletions nn-core/src/test/java/com/github/neuralnetworks/test/FFNNTest.java
@@ -294,7 +294,7 @@ public void testSigmoidBP() {
cg2.set(0.3f, 0, 0);
cg2.set(0.9f, 0, 1);

BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, new SimpleInputProvider(new float[][] { { 0.35f, 0.9f } }, new float[][] { { 0.5f } }), new SimpleInputProvider(new float[][] { { 0.35f, 0.9f } }, new float[][] { { 0.5f } }), null, null, 1f, 0f, 0f, 0f, 1, 1, 1);
BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, new SimpleInputProvider(new float[][] { { 0.35f, 0.9f } }, new float[][] { { 0.5f } }), new SimpleInputProvider(new float[][] { { 0.35f, 0.9f } }, new float[][] { { 0.5f } }), null, null, 1f, 0f, 0f, 0f, 0f, 1, 1, 1);
bpt.train();

assertEquals(0.09916, cg1.get(0, 0), 0.01);
@@ -338,7 +338,7 @@ public void testSigmoidBP2() {
Matrix cgb2 = cb2.getWeights();
cgb2.set(0.1f, 0, 0);

BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, new SimpleInputProvider(new float[][] { { 1, 0, 1 } }, new float[][] { { 1 } }), new SimpleInputProvider(new float[][] { { 1, 0, 1 } }, new float[][] { { 1 } }), null, null, 0.9f, 0f, 0f, 0f, 1, 1, 1);
BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, new SimpleInputProvider(new float[][] { { 1, 0, 1 } }, new float[][] { { 1 } }), new SimpleInputProvider(new float[][] { { 1, 0, 1 } }, new float[][] { { 1 } }), null, null, 0.9f, 0f, 0f, 0f, 0f, 1, 1, 1);
bpt.train();

assertEquals(0.192, cg1.get(0, 0), 0.001);
@@ -357,6 +357,58 @@ public void testSigmoidBP2() {
assertEquals(0.218, cgb2.get(0, 0), 0.001);
}

/**
* BP with dropout
*/
@Test
public void testSigmoidBPDropout() {
//Environment.getInstance().setExecutionMode(EXECUTION_MODE.SEQ);
Environment.getInstance().setUseWeightsSharedMemory(true);
NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[] { 3, 2, 1 }, true);

List<Connections> c = mlp.getConnections();
FullyConnected c1 = (FullyConnected) c.get(0);
Matrix cg1 = c1.getWeights();
cg1.set(0.2f, 0, 0);
cg1.set(0.4f, 0, 1);
cg1.set(-0.5f, 0, 2);
cg1.set(-0.3f, 1, 0);
cg1.set(0.1f, 1, 1);
cg1.set(0.2f, 1, 2);

FullyConnected cb1 = (FullyConnected) c.get(1);
Matrix cgb1 = cb1.getWeights();
cgb1.set(-0.4f, 0, 0);
cgb1.set(0.2f, 1, 0);

FullyConnected c2 = (FullyConnected) c.get(2);
Matrix cg2 = c2.getWeights();
cg2.set(-0.3f, 0, 0);
cg2.set(-0.2f, 0, 1);

FullyConnected cb2 = (FullyConnected) c.get(3);
Matrix cgb2 = cb2.getWeights();
cgb2.set(0.1f, 0, 0);

BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, new SimpleInputProvider(new float[][] { { 1, 0, 1 } }, new float[][] { { 1 } }), new SimpleInputProvider(new float[][] { { 1, 0, 1 } }, new float[][] { { 1 } }), null, null, 0.9f, 0f, 0f, 0f, 0.01f, 1, 1, 1);
bpt.train();

assertEquals(0.192, cg1.get(0, 0), 0.001);
assertEquals(0.4, cg1.get(0, 1), 0.001);
assertEquals(-0.508, cg1.get(0, 2), 0.001);
assertEquals(-0.306, cg1.get(1, 0), 0.001);
assertEquals(0.1, cg1.get(1, 1), 0.001);
assertEquals(0.194, cg1.get(1, 2), 0.001);

assertEquals(-0.261 * 0.99, cg2.get(0, 0), 0.001);
assertEquals(-0.138 * 0.99, cg2.get(0, 1), 0.001);

assertEquals(-0.408, cgb1.get(0, 0), 0.001);
assertEquals(0.194, cgb1.get(1, 0), 0.001);

assertEquals(0.218, cgb2.get(0, 0), 0.001);
}

@Test
public void testSigmoidBP3() {
Environment.getInstance().setUseDataSharedMemory(true);
@@ -388,7 +440,7 @@ public void testSigmoidBP3() {
Matrix cgb2 = cb2.getWeights();
cgb2.set(0.1f, 0, 0);

BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, new SimpleInputProvider(new float[][] { { 1, 0, 1 }, { 1, 1, 0 } }, new float[][] { { 1 }, { 1 } }), null, null, null, 0.9f, 0f, 0f, 0f, 1, 1, 1);
BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, new SimpleInputProvider(new float[][] { { 1, 0, 1 }, { 1, 1, 0 } }, new float[][] { { 1 }, { 1 } }), null, null, null, 0.9f, 0f, 0f, 0f, 0f, 1, 1, 1);
bpt.train();

assertEquals(0.1849, cg1.get(0, 0), 0.0001);
