Skip to content

Commit

Permalink
Example: Ensembler refactoring code, moving snippet to generate unsee…
Browse files Browse the repository at this point in the history
…n input data to main function
  • Loading branch information
neomatrix369 committed Feb 2, 2021
1 parent 862a0d9 commit 5f991b5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
package org.neomatrix369.ensembler;

import java.util.Arrays;
import deepnetts.examples.util.CsvFile;
import org.neomatrix369.ensembler.RegressionTribuoExample.*;
import org.neomatrix369.ensembler.RegressionDeepNettsExample.*;

public class EnsemblerMachine
{
private static String csvValidationFilename = "datasets/deepnetts-linear-regression-validation.csv";
private static int UNSEEN_DATA_COUNT = 100;

public static void main( String[] args ) throws Exception
{
System.out.println("~ Running Ensembler Machine");
System.out.println("CLI Params: " + Arrays.toString(args));

// plot predictions for some random data
double[][] data = new double[UNSEEN_DATA_COUNT][2];

for(int i=0; i<UNSEEN_DATA_COUNT; i++) {
data[i][0] = 0.5-Math.random();
data[i][1] = 0;
}

CsvFile.write(data, csvValidationFilename, "x,y");

boolean showOutput = true;
new RegressionDeepNettsExample(showOutput).run();
new RegressionTribuoExample(showOutput).run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class RegressionDeepNettsExample
private boolean showOutput;
private static String csvFilename = "datasets/linear-for-deepnetts.csv";
private static String csvValidationFilename = "datasets/deepnetts-linear-regression-validation.csv";
private static int UNSEEN_DATA_COUNT = 100;

public RegressionDeepNettsExample(boolean showOutput) {
this.showOutput = showOutput;
Expand Down Expand Up @@ -90,7 +91,7 @@ public void run() throws Exception


// plot predictions for some random data
double[][] data = new double[100][2];
double[][] data = CsvFile.read(csvValidationFilename, UNSEEN_DATA_COUNT, true);

for(int i=0; i<data.length; i++) {
data[i][0] = 0.5-Math.random();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,26 @@ public static void write(PrintWriter pw, double[][] data) throws FileNotFoundExc
}
}

public static double[][] read(String fileName, int lines) {
public static double[][] read(String fileName, int lines, boolean hasHeader) {
BufferedReader br = null;
try {
double[][] data = new double[lines][2];
br = new BufferedReader(new FileReader(fileName));

int startIndex = 0;
if (hasHeader) {
startIndex = 1;
}
for(int i=0; i<data.length; i++) {
String line = br.readLine();
String[] strVals = line.split(",");
data[i][0] = Double.parseDouble(strVals[0]);
data[i][1] = Double.parseDouble(strVals[1]);
if (i >= startIndex) {
String line = br.readLine();
String[] strVals = line.split(",");
data[startIndex][0] = Double.parseDouble(strVals[0]);
data[startIndex][1] = Double.parseDouble(strVals[1]);
startIndex = startIndex + 1;
} else {
br.readLine();
}
} br.close();
return data;
} catch (FileNotFoundException ex) {
Expand Down

0 comments on commit 5f991b5

Please sign in to comment.