Train method fails in MultipleNeuronsOutputError.getTotalErrorSamples #40
The code below (at the end of this issue) creates, trains and tests a neural network, but testing fails in `MultipleNeuronsOutputError.getTotalErrorSamples` at this loop:
```java
for (OutputTargetTuple t : tuples) {
    if (!outputToTarget.get(t.outputPos).equals(t.targetPos)) {
        errorSamples++;
    }
}
```
If I add `outputToTarget.get(t.outputPos) != null &&` to the `if` statement, it finishes successfully, but with zero error samples and thus no error value.
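A minimal, self-contained sketch of why the loop fails and why the null check hides the problem, assuming `outputToTarget` is a `Map<Integer, Integer>` whose `get` returns `null` for an output position that has no entry. `OutputTargetTuple` below is a hypothetical stand-in, not the library's class. Calling `.equals` on that `null` throws a `NullPointerException`, while guarding with `!= null &&` silently skips the sample, which would explain a count of zero if no output position is mapped. `java.util.Objects.equals` compares null-safely and counts a missing mapping as an error instead:

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class ErrorCountSketch {

    // Hypothetical stand-in for the library's tuple: position of the winning
    // output neuron and position of the target class.
    static final class OutputTargetTuple {
        final int outputPos;
        final int targetPos;

        OutputTargetTuple(int outputPos, int targetPos) {
            this.outputPos = outputPos;
            this.targetPos = targetPos;
        }
    }

    // Same pattern as in the issue: throws NullPointerException as soon as
    // outputToTarget has no entry for t.outputPos.
    static int countErrorsUnsafe(List<OutputTargetTuple> tuples, Map<Integer, Integer> outputToTarget) {
        int errorSamples = 0;
        for (OutputTargetTuple t : tuples) {
            if (!outputToTarget.get(t.outputPos).equals(t.targetPos)) {
                errorSamples++;
            }
        }
        return errorSamples;
    }

    // Null-safe variant: a missing mapping is counted as a misclassified
    // sample instead of being skipped.
    static int countErrorsNullSafe(List<OutputTargetTuple> tuples, Map<Integer, Integer> outputToTarget) {
        int errorSamples = 0;
        for (OutputTargetTuple t : tuples) {
            if (!Objects.equals(outputToTarget.get(t.outputPos), t.targetPos)) {
                errorSamples++;
            }
        }
        return errorSamples;
    }

    public static void main(String[] args) {
        Map<Integer, Integer> outputToTarget = new HashMap<>();
        outputToTarget.put(0, 0);
        outputToTarget.put(1, 1);
        // outputPos 2 has no entry, so outputToTarget.get(2) returns null
        List<OutputTargetTuple> tuples = Arrays.asList(
                new OutputTargetTuple(0, 0),
                new OutputTargetTuple(1, 0),
                new OutputTargetTuple(2, 2));

        System.out.println(countErrorsNullSafe(tuples, outputToTarget)); // prints 2
        try {
            countErrorsUnsafe(tuples, outputToTarget);
        } catch (NullPointerException e) {
            System.out.println("unsafe version threw NullPointerException");
        }
    }
}
```

Whether a missing mapping should count as an error is a judgment call; the underlying question is probably why `outputToTarget` has no entry for those positions during testing in the first place.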
I've checked that the data is read in correctly, and it seems fine. Training completes without problems; the failure happens only during testing.
Also, switching to GPU execution makes a single training epoch extremely slow; I've never actually seen it finish.
```java
Environment.getInstance().setExecutionMode(EXECUTION_MODE.SEQ);
Environment.getInstance().setUseWeightsSharedMemory(false);
Environment.getInstance().setUseDataSharedMemory(false);

// create multi layer perceptron with three hidden layers and bias
NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[]{40, 75, 75, 75, 10}, true);

// read the data; the same input provider is used for training and testing
FileReader reader;
System.out.println("Try read data");
List<float[][]> data = new ArrayList<float[][]>();
try {
    reader = new FileReader("C:\\Users\\jself\\Data\\training_data.data2");
    data = GetDataFromFile(reader);
} catch (FileNotFoundException e) {
    // don't swallow the error: with no data, data.get(0) below would fail
    e.printStackTrace();
}

System.out.println("Create input provider and trainer");
SimpleInputProvider input = new SimpleInputProvider(data.get(0), data.get(1));

// create backpropagation trainer for the network
BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, input, input, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.1f, 0.7f, 0f, 0f, 0f, 1, 1, 1);

// add logging
bpt.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName()));

// early stopping
//bpt.addEventListener(new EarlyStoppingListener(testingInput, 10, 0.1f));

System.out.println("Start training");
bpt.train();

System.out.println("Start testing");
bpt.test();
```
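Since `{40, 75, 75, 75, 10}` describes a network with 40 inputs and 10 outputs (assuming the first and last entries are the input and output layer sizes), one quick sanity check is that every row returned by `GetDataFromFile` matches those widths and, assuming one-hot targets, that each target row has exactly one non-zero value. `checkShapes` is a hypothetical helper for illustration, not part of the library:

```java
import java.util.ArrayList;
import java.util.List;

public class DataShapeCheck {

    // Hypothetical helper: reports rows whose widths don't match the network's
    // input/output layer sizes, and target rows that are not one-hot.
    static List<String> checkShapes(float[][] inputs, float[][] targets, int inputSize, int outputSize) {
        List<String> problems = new ArrayList<>();
        if (inputs.length != targets.length) {
            problems.add("sample count mismatch: " + inputs.length + " inputs vs " + targets.length + " targets");
        }
        for (int i = 0; i < inputs.length; i++) {
            if (inputs[i].length != inputSize) {
                problems.add("input row " + i + " has " + inputs[i].length + " values, expected " + inputSize);
            }
        }
        for (int i = 0; i < targets.length; i++) {
            if (targets[i].length != outputSize) {
                problems.add("target row " + i + " has " + targets[i].length + " values, expected " + outputSize);
            }
            // count non-zero entries; a one-hot row has exactly one
            int hot = 0;
            for (float v : targets[i]) {
                if (v != 0f) {
                    hot++;
                }
            }
            if (hot != 1) {
                problems.add("target row " + i + " is not one-hot (" + hot + " non-zero values)");
            }
        }
        return problems;
    }

    public static void main(String[] args) {
        float[][] inputs = new float[3][40];
        float[][] targets = new float[3][];
        targets[0] = new float[10];
        targets[0][3] = 1f;          // valid one-hot row
        targets[1] = new float[10];  // all zeros: not one-hot
        targets[2] = new float[9];   // wrong width
        targets[2][0] = 1f;
        for (String p : checkShapes(inputs, targets, 40, 10)) {
            System.out.println(p);
        }
    }
}
```

Calling it as `checkShapes(data.get(0), data.get(1), 40, 10)` right after reading the file would show whether the data matches the network layout before training starts.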