Skip to content

Commit

Permalink
expose fast predict methods (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
shuttie authored Jun 27, 2024
1 parent 340ac45 commit 215aa34
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 3 deletions.
46 changes: 44 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,48 @@ Note the following change in the LightGBM4 behavior:

* you need to set `objective=none metric=<eval metric>` parameters to signal that we're going to use custom objective. Otherwise the LightGBM will complain on incorrect objective.

### Low-latency predictions

Raw LGBM API exposes multiple low-level ways to make predictions with lower latency:
* Instead of `predictForMat`, you can use a single-row optimized `predictForMatSingleRow` method
* LGBM my default still uses paralellism for single-row predictions, which still affects final latency. Opt for including `threads=1` parameter for your prediction method calls.
* LightGBM4J also exposes a low-level `predictForMatSingleRowFast` method, which pre-allocates internal structures once, and reuses them on each next call.

#### Single-row prediction

```java
LGBMDataset dataset = LGBMDataset.createFromFile("src/test/resources/cancer.csv", "header=true label=name:Classification", null);
LGBMBooster booster = LGBMBooster.create(dataset, "objective=binary label=name:Classification");
booster.updateOneIter();
booster.updateOneIter();
booster.updateOneIter();
for (int i = 0; i < 10; i++) {
double pred1 = booster.predictForMatSingleRow(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, PredictionType.C_API_PREDICT_NORMAL);
assertTrue(pred1 > 0);
double pred2 = booster.predictForMatSingleRow(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, PredictionType.C_API_PREDICT_NORMAL);
assertTrue(pred2 > 0);
}
dataset.close();
booster.close();
```

#### Single-row fast prediction

```java
LGBMDataset dataset = LGBMDataset.createFromFile("src/test/resources/cancer.csv", "header=true label=name:Classification", null);
LGBMBooster booster = LGBMBooster.create(dataset, "objective=binary label=name:Classification");
booster.updateOneIter();
booster.updateOneIter();
booster.updateOneIter();
LGBMBooster.FastConfig config = booster.predictForMatSingleRowFastInit(PredictionType.C_API_PREDICT_NORMAL, C_API_DTYPE_FLOAT32,9, "");
double pred = booster.predictForMatSingleRowFast(config, new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, PredictionType.C_API_PREDICT_NORMAL);
assertTrue(Double.isFinite(pred));
config.close();
dataset.close();
booster.close();

```

## Supported platforms

This code is tested to work well with Linux (Ubuntu 20.04), Windows (Server 2019) and MacOS 10.15/11. Mac M1 is also supported.
Expand All @@ -280,6 +322,8 @@ Supported methods:
* [LGBM_BoosterLoadModelFromString](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterLoadModelFromString)
* [LGBM_BoosterPredictForMat](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMat)
* [LGBM_BoosterPredictForMatSingleRow](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRow)
* [LGBM_BoosterPredictForMatSingleRowFast](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRowFast)
* [LGBM_BoosterPredictForMatSingleRowFastInit](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRowFastInit)
* [LGBM_BoosterSaveModel](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterSaveModel)
* [LGBM_BoosterSaveModelToString](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterSaveModelToString)
* [LGBM_BoosterUpdateOneIter](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterUpdateOneIter)
Expand Down Expand Up @@ -314,8 +358,6 @@ Not yet supported:
* [LGBM_BoosterPredictForCSRSingleRowFastInit](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForCSRSingleRowFastInit)
* [LGBM_BoosterPredictForFile](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForFile)
* [LGBM_BoosterPredictForMats](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMats)
* [LGBM_BoosterPredictForMatSingleRowFast](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRowFast)
* [LGBM_BoosterPredictForMatSingleRowFastInit](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRowFastInit)
* [LGBM_BoosterPredictSparseOutput](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictSparseOutput)
* [LGBM_BoosterRefit](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterRefit)
* [LGBM_BoosterResetParameter](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterResetParameter)
Expand Down
112 changes: 111 additions & 1 deletion src/main/java/io/github/metarank/lightgbm4j/LGBMBooster.java
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,117 @@ public double predictForMatSingleRow(float[] data, PredictionType predictionType
}
}

public FastConfig predictForMatSingleRowFastInit(PredictionType predictionType, int dtype, int ncols, String parameter) throws LGBMException {
if (!isClosed) {
SWIGTYPE_p_p_void out = voidpp_handle();

int result = LGBM_BoosterPredictForMatSingleRowFastInit(
voidpp_value(handle),
predictionType.getType(),
0,
iterations,
dtype,
ncols,
parameter,
out
);
if (result < 0) {
delete_voidpp(out);
throw new LGBMException(LGBM_GetLastError());
} else {
return new FastConfig(out);
}
} else {
throw new LGBMException("Booster was already closed");
}
}

public static class FastConfig implements AutoCloseable {
public SWIGTYPE_p_p_void handle;
public FastConfig(SWIGTYPE_p_p_void handle) {
this.handle = handle;
}

@Override
public void close() throws Exception {
delete_voidpp(handle);
}
}
public double predictForMatSingleRowFast(FastConfig config, float[] data, PredictionType predictionType) throws LGBMException {
if (!isClosed) {
SWIGTYPE_p_float dataBuffer = new_floatArray(data.length);
for (int i = 0; i < data.length; i++) {
floatArray_setitem(dataBuffer, i, data[i]);
}
SWIGTYPE_p_long_long outLength = new_int64_tp();
long outBufferSize = outBufferSize(1, data.length, predictionType);
SWIGTYPE_p_double outBuffer = new_doubleArray(outBufferSize);

int result = LGBM_BoosterPredictForMatSingleRowFast(
voidpp_value(config.handle),
float_to_voidp_ptr(dataBuffer),
outLength,
outBuffer
);
if (result < 0) {
delete_floatArray(dataBuffer);
delete_doubleArray(outBuffer);
delete_int64_tp(outLength);
throw new LGBMException(LGBM_GetLastError());
} else {
long length = int64_tp_value(outLength);
double[] values = new double[(int) length];
for (int i = 0; i < length; i++) {
values[i] = doubleArray_getitem(outBuffer, i);
}
delete_floatArray(dataBuffer);
delete_int64_tp(outLength);
delete_doubleArray(outBuffer);
return values[0];
}
} else {
throw new LGBMException("Booster was already closed");
}
}

public double predictForMatSingleRowFast(FastConfig config, double[] data, PredictionType predictionType) throws LGBMException {
if (!isClosed) {
SWIGTYPE_p_double dataBuffer = new_doubleArray(data.length);
for (int i = 0; i < data.length; i++) {
doubleArray_setitem(dataBuffer, i, data[i]);
}
SWIGTYPE_p_long_long outLength = new_int64_tp();
long outBufferSize = outBufferSize(1, data.length, predictionType);
SWIGTYPE_p_double outBuffer = new_doubleArray(outBufferSize);

int result = LGBM_BoosterPredictForMatSingleRowFast(
voidpp_value(config.handle),
double_to_voidp_ptr(dataBuffer),
outLength,
outBuffer
);
if (result < 0) {
delete_doubleArray(dataBuffer);
delete_doubleArray(outBuffer);
delete_int64_tp(outLength);
throw new LGBMException(LGBM_GetLastError());
} else {
long length = int64_tp_value(outLength);
double[] values = new double[(int) length];
for (int i = 0; i < length; i++) {
values[i] = doubleArray_getitem(outBuffer, i);
}
delete_doubleArray(dataBuffer);
delete_int64_tp(outLength);
delete_doubleArray(outBuffer);
return values[0];
}
} else {
throw new LGBMException("Booster was already closed");
}
}


private int importanceType(FeatureImportanceType tpe) {
int importanceType = C_API_FEATURE_IMPORTANCE_GAIN;
switch (tpe) {
Expand Down Expand Up @@ -828,7 +939,6 @@ public boolean updateOneIterCustom(float[] grad, float[] hess) throws LGBMExcept
} else {
throw new LGBMException("Booster was already closed");
}

}


Expand Down
30 changes: 30 additions & 0 deletions src/test/java/io/github/metarank/lightgbm4j/LGBMBoosterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import java.util.Random;

import static com.microsoft.ml.lightgbm.lightgbmlibConstants.C_API_DTYPE_FLOAT32;
import static com.microsoft.ml.lightgbm.lightgbmlibConstants.C_API_DTYPE_FLOAT64;
import static org.junit.jupiter.api.Assertions.*;

public class LGBMBoosterTest {
Expand Down Expand Up @@ -354,6 +356,34 @@ void testCreateByReference() throws LGBMException {
assertThrows(LGBMException.class, () -> booster.getPredict(0));
}

@Test void testPredictFastFloat() throws LGBMException, Exception {
LGBMDataset dataset = LGBMDataset.createFromFile("src/test/resources/cancer.csv", "header=true label=name:Classification", null);
LGBMBooster booster = LGBMBooster.create(dataset, "objective=binary label=name:Classification");
booster.updateOneIter();
booster.updateOneIter();
booster.updateOneIter();
LGBMBooster.FastConfig config = booster.predictForMatSingleRowFastInit(PredictionType.C_API_PREDICT_NORMAL, C_API_DTYPE_FLOAT32,9, "");
double pred = booster.predictForMatSingleRowFast(config, new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, PredictionType.C_API_PREDICT_NORMAL);
assertTrue(Double.isFinite(pred));
config.close();
dataset.close();
booster.close();
}

@Test void testPredictFastDouble() throws LGBMException, Exception {
LGBMDataset dataset = LGBMDataset.createFromFile("src/test/resources/cancer.csv", "header=true label=name:Classification", null);
LGBMBooster booster = LGBMBooster.create(dataset, "objective=binary label=name:Classification");
booster.updateOneIter();
booster.updateOneIter();
booster.updateOneIter();
LGBMBooster.FastConfig config = booster.predictForMatSingleRowFastInit(PredictionType.C_API_PREDICT_NORMAL, C_API_DTYPE_FLOAT64,9, "");
double pred = booster.predictForMatSingleRowFast(config, new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, PredictionType.C_API_PREDICT_NORMAL);
assertTrue(Double.isFinite(pred));
config.close();
dataset.close();
booster.close();
}

private float[] randomArray(int size) {
float[] result = new float[size];
Random rnd = new Random();
Expand Down

0 comments on commit 215aa34

Please sign in to comment.