Commit

some fixes
ivan-vasilev committed Apr 18, 2014
1 parent 8f2eef5 commit 48c00b5
Showing 9 changed files with 312 additions and 300 deletions.
@@ -35,7 +35,7 @@ public DeepTrainerTrainingInputProvider(TrainingInputProvider inputProvider, DNN
     }
 
     @Override
-    public void after(TrainingInputData ti) {
+    public void afterBatch(TrainingInputData ti) {
         if (dnn.getFirstNeuralNetwork() != currentNN) {
             inputProvider.populateNext(inputDataBase);
             calculatedLayers.clear();
@@ -13,46 +13,84 @@
 public interface TrainingInputProvider extends Serializable {
 
     public int getInputSize();
 
     public void reset();
 
     float[] getNextInput();
 
     float[] getNextTarget();
 
     List<TensorFunction> getInputModifiers();
-    void after(TrainingInputData ti);
-    void before(TrainingInputData ti);
 
+    void afterBatch(TrainingInputData ti);
+
+    void beforeBatch(TrainingInputData ti);
+
+    void afterSample();
+
+    void beforeSample();
+
     public default void populateNext(TrainingInputData ti) {
-        before(ti);
+        beforeBatch(ti);
+
+        // batch size
+        int batchSize = 0;
+        if (ti.getInput() != null && ti.getTarget() != null && ti.getInput().getDimensions()[ti.getInput().getDimensions().length - 1] != ti.getTarget().getDimensions()[ti.getTarget().getDimensions().length - 1]) {
+            throw new IllegalArgumentException("Input and target batch size don't match");
+        }
 
-        // input
         if (ti.getInput() != null) {
-            int[] inputDims = ti.getInput().getDimensions();
-            int[][] limits = new int[2][inputDims.length];
-            IntStream.range(0, inputDims.length - 1).forEach(i -> limits[1][i] = inputDims[i] - 1);
-            IntStream.range(0, inputDims[inputDims.length - 1]).forEach(i -> {
-                limits[0][inputDims.length - 1] = limits[1][inputDims.length - 1] = i;
-                TensorIterator it = ti.getInput().iterator(limits);
-                float[] inputEl = getNextInput();
-                IntStream.range(0, inputEl.length).forEach(j -> ti.getInput().getElements()[it.next()] = inputEl[j]);
-            });
+            batchSize = ti.getInput().getDimensions()[ti.getInput().getDimensions().length - 1];
+        } else if (ti.getTarget() != null) {
+            batchSize = ti.getTarget().getDimensions()[ti.getTarget().getDimensions().length - 1];
+        }
+
+        int[] inputDims = null;
+        int[][] inputLimits = null;
+        int[] targetDims = null;
+        int[][] targetLimits = null;
 
-            if (getInputModifiers() != null) {
-                getInputModifiers().forEach(im -> im.value(ti.getInput()));
+        if (ti.getInput() != null) {
+            inputDims = ti.getInput().getDimensions();
+            inputLimits = new int[2][inputDims.length];
+            for (int i = 0; i < inputDims.length - 1; i++) {
+                inputLimits[1][i] = inputDims[i] - 1;
             }
         }
 
-        // target
         if (ti.getTarget() != null) {
-            int[] targetDims = ti.getTarget().getDimensions();
-            int[][] limits = new int[2][targetDims.length];
-            IntStream.range(0, targetDims.length - 1).forEach(i -> limits[1][i] = targetDims[i] - 1);
-            IntStream.range(0, targetDims[targetDims.length - 1]).forEach(i -> {
-                limits[0][targetDims.length - 1] = limits[1][targetDims.length - 1] = i;
-                TensorIterator it = ti.getTarget().iterator(limits);
+            targetDims = ti.getTarget().getDimensions();
+            targetLimits = new int[2][targetDims.length];
+            for (int i = 0; i < targetDims.length - 1; i++) {
+                targetLimits[1][i] = targetDims[i] - 1;
+            }
+        }
+
+        // data population
+        for (int i = 0; i < batchSize; i++) {
+            beforeSample();
+
+            if (ti.getInput() != null) {
+                inputLimits[0][inputDims.length - 1] = inputLimits[1][inputDims.length - 1] = i;
+                TensorIterator inputIt = ti.getInput().iterator(inputLimits);
+                float[] inputEl = getNextInput();
+                IntStream.range(0, inputEl.length).forEach(j -> ti.getInput().getElements()[inputIt.next()] = inputEl[j]);
+            }
+
+            if (ti.getTarget() != null) {
+                targetLimits[0][targetDims.length - 1] = targetLimits[1][targetDims.length - 1] = i;
+                TensorIterator targetIt = ti.getTarget().iterator(targetLimits);
                 float[] targetEl = getNextTarget();
-                IntStream.range(0, targetEl.length).forEach(j -> ti.getTarget().getElements()[it.next()] = targetEl[j]);
-            });
+                IntStream.range(0, targetEl.length).forEach(j -> ti.getTarget().getElements()[targetIt.next()] = targetEl[j]);
+            }
+
+            afterSample();
         }
 
-        after(ti);
+        if (ti.getInput() != null && getInputModifiers() != null) {
+            getInputModifiers().forEach(im -> im.value(ti.getInput()));
+        }
+
+        afterBatch(ti);
     }
 }
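
The reworked default populateNext now drives a fixed per-batch lifecycle: beforeBatch, then for every sample in the batch beforeSample / getNextInput / getNextTarget / afterSample, then the input modifiers over the whole batch, then afterBatch. The standalone sketch below only mirrors that call ordering with plain arrays; the class name, the 2-element inputs and the println calls are made up for illustration, and none of the project's Tensor types are used.

// BatchLifecycleSketch.java - hypothetical illustration of the populateNext() call order
import java.util.Arrays;

public class BatchLifecycleSketch {

    // stand-ins for getNextInput()/getNextTarget() of a provider serving 2-element inputs
    static float[] nextInput(int sample) {
        return new float[] { sample, sample + 0.5f };
    }

    static float[] nextTarget(int sample) {
        return new float[] { sample % 2 };
    }

    public static void main(String[] args) {
        int batchSize = 3;
        float[][] input = new float[batchSize][2];
        float[][] target = new float[batchSize][1];

        System.out.println("beforeBatch");
        for (int i = 0; i < batchSize; i++) {
            System.out.println("  beforeSample " + i);
            input[i] = nextInput(i);   // populateNext copies getNextInput() into sample column i
            target[i] = nextTarget(i); // and getNextTarget() into the matching target column
            System.out.println("  afterSample " + i);
        }
        // input modifiers (e.g. scaling) would run here, once over the whole batch
        System.out.println("afterBatch");
        System.out.println(Arrays.deepToString(input) + " " + Arrays.deepToString(target));
    }
}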
@@ -62,16 +62,20 @@ public void setInputConverter(InputConverter inputConverter) {
     }
 
     @Override
-    public void before(TrainingInputData ti) {
+    public void beforeBatch(TrainingInputData ti) {
     }
 
     @Override
-    public void after(TrainingInputData ti) {
-        if (ti.getInput() != null) {
-            currentInput += ti.getInput().getDimensions()[ti.getInput().getDimensions().length - 1];
-        } else if (ti.getTarget() != null) {
-            currentInput += ti.getTarget().getDimensions()[ti.getTarget().getDimensions().length - 1];
-        }
+    public void afterBatch(TrainingInputData ti) {
     }
 
     @Override
+    public void afterSample() {
+        currentInput++;
+    }
+
+    @Override
+    public void beforeSample() {
+    }
+
+    @Override
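With the per-sample hook in place, this provider advances its currentInput cursor once per afterSample() call instead of once per batch in afterBatch(TrainingInputData). Over a full batch of N samples the cursor ends up in the same place, since populateNext invokes afterSample() exactly N times. A trivial standalone sketch of that equivalence (the counter names and the batch size are hypothetical):

public class SampleCounterSketch {
    public static void main(String[] args) {
        int batchSize = 32;

        // old behaviour: afterBatch added the whole batch size at once
        int oldCursor = 0;
        oldCursor += batchSize;

        // new behaviour: afterSample advances the cursor once per populated sample
        int newCursor = 0;
        for (int i = 0; i < batchSize; i++) {
            newCursor++;
        }

        System.out.println(oldCursor == newCursor); // prints true
    }
}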
@@ -74,15 +74,27 @@ public float[] getNextTarget() {
     }
 
     @Override
-    public void before(TrainingInputData ti) {
-        super.before(ti);
-        base.before(ti);
+    public void beforeBatch(TrainingInputData ti) {
+        super.beforeBatch(ti);
+        base.beforeBatch(ti);
     }
 
     @Override
-    public void after(TrainingInputData ti) {
-        super.after(ti);
-        base.after(ti);
+    public void afterBatch(TrainingInputData ti) {
+        super.afterBatch(ti);
+        base.afterBatch(ti);
     }
 
     @Override
+    public void beforeSample() {
+        super.beforeSample();
+        base.beforeSample();
+    }
+
+    @Override
+    public void afterSample() {
+        super.afterSample();
+        base.afterSample();
+    }
+
+    @Override
@@ -69,7 +69,7 @@ public void handleEvent(TrainingEvent event) {
             vp = TensorFactory.tensorProvider(n, 1, Environment.getInstance().getUseDataSharedMemory());
         }
         if (vp.get(outputError) == null) {
-            vp.add(outputError, vp.get(n.getInputLayer()).getDimensions());
+            vp.add(outputError, vp.get(n.getOutputLayer()).getDimensions());
         }
         TrainingInputData input = new TrainingInputDataImpl(vp.get(n.getInputLayer()), vp.get(outputError));
 
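The last shown hunk allocates the output-error tensor with the dimensions of the network's output layer rather than its input layer; the error is computed against the targets, so it has to have the same shape as the network's output. A small standalone illustration with made-up layer sizes (784 inputs, 10 outputs) and plain arrays in place of the project's tensors:

public class OutputErrorSizeSketch {
    public static void main(String[] args) {
        int inputLayerSize = 784;  // hypothetical input layer
        int outputLayerSize = 10;  // hypothetical output layer

        float[] networkOutput = new float[outputLayerSize];
        float[] target = new float[outputLayerSize];

        // sized after the input layer, the buffer would have 784 slots for 10 error values;
        // sized after the output layer, it matches what is actually stored in it
        float[] outputError = new float[outputLayerSize];
        for (int i = 0; i < outputLayerSize; i++) {
            outputError[i] = target[i] - networkOutput[i];
        }
        System.out.println("output error length = " + outputError.length);
    }
}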
(Diffs for the remaining 4 changed files are not shown.)

0 comments on commit 48c00b5
