
Train method fails in MultipleNeuronsOutputError.getTotalErrorSamples #40

Open
@joelself

Description

I have the following code (at the end of this issue) to create, train, and test a neural network, but it fails in MultipleNeuronsOutputError.getTotalErrorSamples, in this loop:

    for (OutputTargetTuple t : tuples) {
        if (!outputToTarget.get(t.outputPos).equals(t.targetPos)) {
            errorSamples++;
        }
    }

If I add outputToTarget.get(t.outputPos) != null && to the if condition, the test finishes, but it counts zero error samples and therefore reports no error value.
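For comparison, here is a hypothetical, self-contained sketch (not the library's actual code) of that loop written null-safely. It assumes outputToTarget is a Map<Integer, Integer> and models OutputTargetTuple as a small class with two int fields; both are stand-ins named after the snippet above. With Objects.equals, an output position that has no entry in the map counts as an error instead of being skipped, so unmapped samples no longer vanish from the count the way they do with the != null guard.

```java
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class NullSafeErrorCount {

    // Hypothetical stand-in for the library's OutputTargetTuple:
    // the index of the strongest output neuron and of the target neuron.
    static final class OutputTargetTuple {
        final int outputPos;
        final int targetPos;

        OutputTargetTuple(int outputPos, int targetPos) {
            this.outputPos = outputPos;
            this.targetPos = targetPos;
        }
    }

    // Counts misclassified samples. An output position with no entry in
    // outputToTarget is treated as an error rather than skipped.
    static int countErrorSamples(List<OutputTargetTuple> tuples, Map<Integer, Integer> outputToTarget) {
        int errorSamples = 0;
        for (OutputTargetTuple t : tuples) {
            if (!Objects.equals(outputToTarget.get(t.outputPos), t.targetPos)) {
                errorSamples++;
            }
        }
        return errorSamples;
    }

    public static void main(String[] args) {
        Map<Integer, Integer> outputToTarget = Map.of(0, 0, 1, 1);
        List<OutputTargetTuple> tuples = List.of(
                new OutputTargetTuple(0, 0),  // correct
                new OutputTargetTuple(1, 2),  // mapped to the wrong target
                new OutputTargetTuple(5, 5)); // no mapping for output 5
        System.out.println(countErrorSamples(tuples, outputToTarget)); // prints 2
    }
}
```

This only changes how missing mappings are counted; it does not explain why outputToTarget has no entry for those output positions in the first place.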

I've checked that the data is read in correctly, and it seems fine. Training runs without problems; the failure only happens during testing.
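One way to make "seems fine" more concrete is a shape check on the two arrays passed to SimpleInputProvider. The helper below is hypothetical (not part of the library) and assumes one-hot targets, i.e. exactly one 1 per 10-value target row. That is an assumption about what MultipleNeuronsOutputError expects, based on its name and the outputPos/targetPos fields, not something confirmed in this issue.

```java
public class DataSanityCheck {

    // Hypothetical helper: checks that inputs and targets line up with a
    // network of inputSize inputs and outputSize outputs, and that every
    // target row is one-hot. Returns null if all rows pass, otherwise a
    // description of the first problem found.
    static String check(float[][] inputs, float[][] targets, int inputSize, int outputSize) {
        if (inputs.length != targets.length) {
            return "row count mismatch: " + inputs.length + " inputs vs " + targets.length + " targets";
        }
        for (int i = 0; i < inputs.length; i++) {
            if (inputs[i].length != inputSize) {
                return "input row " + i + " has " + inputs[i].length + " values, expected " + inputSize;
            }
            if (targets[i].length != outputSize) {
                return "target row " + i + " has " + targets[i].length + " values, expected " + outputSize;
            }
            int ones = 0;
            for (float v : targets[i]) {
                if (v == 1f) {
                    ones++;
                } else if (v != 0f) {
                    return "target row " + i + " contains " + v + ", expected only 0 or 1";
                }
            }
            if (ones != 1) {
                return "target row " + i + " has " + ones + " ones, expected exactly one";
            }
        }
        return null;
    }

    public static void main(String[] args) {
        float[][] inputs = {new float[40], new float[40]};
        float[][] targets = {new float[10], new float[10]};
        targets[0][3] = 1f;
        // targets[1] is left all zero, so the check reports it
        System.out.println(check(inputs, targets, 40, 10));
        // prints: target row 1 has 0 ones, expected exactly one
    }
}
```

Called as DataSanityCheck.check(data.get(0), data.get(1), 40, 10) before the input provider is built, it returns null when every row has the expected shape.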

Also, switching to the GPU makes a single training epoch take extremely long; I have never seen one finish.
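To turn "never seen it finish" into a number, each execution mode could be run under a time limit. Below is a minimal sketch using only java.util.concurrent (the class and method names are hypothetical): the task runs on a daemon thread, and the method returns its duration in milliseconds, or -1 if it has not finished within the limit.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class EpochTimer {

    // Runs task on a daemon thread and returns its duration in milliseconds,
    // or -1 if it has not finished within timeoutSeconds.
    static long timeWithTimeout(Runnable task, long timeoutSeconds) throws Exception {
        ExecutorService executor = Executors.newSingleThreadExecutor(r -> {
            Thread t = new Thread(r, "timed-task");
            t.setDaemon(true); // an unfinished task will not keep the JVM alive
            return t;
        });
        long start = System.nanoTime();
        try {
            Future<?> future = executor.submit(task);
            future.get(timeoutSeconds, TimeUnit.SECONDS); // task failures surface as ExecutionException
            return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
        } catch (TimeoutException e) {
            return -1;
        } finally {
            executor.shutdownNow(); // interrupts the worker; non-interruptible code keeps running
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(timeWithTimeout(() -> { }, 5) >= 0); // prints true
        long slow = timeWithTimeout(() -> {
            // simulates an epoch that never finishes: spins until interrupted
            while (!Thread.currentThread().isInterrupted()) {
                Thread.onSpinWait();
            }
        }, 1);
        System.out.println(slow); // prints -1
    }
}
```

For example, timeWithTimeout(bpt::train, 600) run once with EXECUTION_MODE.SEQ and once in GPU mode would show how far apart the two are. Note that shutdownNow only interrupts the worker; training code that ignores interrupts keeps running until the JVM exits, and the daemon flag just keeps it from blocking that exit.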

        Environment.getInstance().setExecutionMode(EXECUTION_MODE.SEQ);

        // create a multilayer perceptron (40 inputs, three hidden layers of 75, 10 outputs) with bias
        Environment.getInstance().setUseWeightsSharedMemory(false);
        Environment.getInstance().setUseDataSharedMemory(false);
        NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[]{40, 75, 75, 75, 10}, true);

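        // read the data; a single input provider is used for both training and testing
        System.out.println("Try read data");
        List<float[][]> data;
        try (FileReader reader = new FileReader("C:\\Users\\jself\\Data\\training_data.data2")) {
            data = GetDataFromFile(reader);
        } catch (IOException e) {
            // fail loudly: an empty list here would only surface later
            // as an IndexOutOfBoundsException on data.get(0)
            throw new IllegalStateException("Could not read training data", e);
        }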
        System.out.println("Create input provider and trainer");
        SimpleInputProvider input = new SimpleInputProvider(data.get(0), data.get(1));
        // create backpropagation trainer for the network
        BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, input, input, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.1f, 0.7f, 0f, 0f, 0f, 1, 1, 1);

        // add logging
        bpt.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName()));

        // early stopping
        //bpt.addEventListener(new EarlyStoppingListener(testingInput, 10, 0.1f));

        System.out.println("Start training");
        // train
        bpt.train();

        System.out.println("Start testing");
        // test
        bpt.test();
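GetDataFromFile is not shown in the issue. Purely as a hypothetical sketch, assuming each line of the file holds 40 input values followed by 10 target values separated by commas (the real format may differ), a reader producing the {inputs, targets} list that is passed to SimpleInputProvider could look like this. It rejects lines of the wrong length rather than silently producing short rows.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

public class DataFileReader {

    static final int INPUTS = 40;   // assumed: matches the 40-neuron input layer
    static final int OUTPUTS = 10;  // assumed: matches the 10-neuron output layer

    // Hypothetical format: one sample per line, INPUTS + OUTPUTS comma-separated floats.
    // Returns a list whose element 0 is the input matrix and element 1 the target matrix.
    // The caller owns (and closes) the Reader.
    static List<float[][]> getDataFromFile(Reader source) throws IOException {
        List<float[]> inputs = new ArrayList<>();
        List<float[]> targets = new ArrayList<>();
        BufferedReader reader = new BufferedReader(source);
        String line;
        while ((line = reader.readLine()) != null) {
            line = line.trim();
            if (line.isEmpty()) {
                continue; // skip blank lines
            }
            String[] fields = line.split(",");
            if (fields.length != INPUTS + OUTPUTS) {
                throw new IOException("expected " + (INPUTS + OUTPUTS) + " values but got " + fields.length + ": " + line);
            }
            float[] in = new float[INPUTS];
            float[] out = new float[OUTPUTS];
            for (int i = 0; i < INPUTS; i++) {
                in[i] = Float.parseFloat(fields[i].trim());
            }
            for (int i = 0; i < OUTPUTS; i++) {
                out[i] = Float.parseFloat(fields[INPUTS + i].trim());
            }
            inputs.add(in);
            targets.add(out);
        }
        List<float[][]> data = new ArrayList<>();
        data.add(inputs.toArray(new float[0][]));
        data.add(targets.toArray(new float[0][]));
        return data;
    }

    public static void main(String[] args) throws IOException {
        // one sample: 40 zero inputs, target one-hot at index 3
        StringBuilder sample = new StringBuilder();
        for (int i = 0; i < INPUTS + OUTPUTS; i++) {
            sample.append(i == INPUTS + 3 ? "1" : "0");
            sample.append(i < INPUTS + OUTPUTS - 1 ? "," : "\n");
        }
        List<float[][]> data = getDataFromFile(new StringReader(sample.toString()));
        System.out.println(data.get(0).length + " sample, " + data.get(0)[0].length + " inputs, " + data.get(1)[0].length + " targets");
        // prints: 1 sample, 40 inputs, 10 targets
    }
}
```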
