Skip to content

Commit

Permalink
Make XGBoost.train enable to set parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
titsuki committed Aug 10, 2021
1 parent bcf78c0 commit 8ce596a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
6 changes: 5 additions & 1 deletion lib/Algorithm/XGBoost.rakumod
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ method version(--> Version) {

my sub XGBoosterCreate(Algorithm::XGBoost::DMatrix is rw, ulong, Algorithm::XGBoost::Booster is rw --> int32) is native($library) { * }
my sub XGBoosterUpdateOneIter(Algorithm::XGBoost::Booster, int32, Algorithm::XGBoost::DMatrix --> int32) is native($library) { * }
my sub XGBoosterSetParam(Algorithm::XGBoost::Booster, Str, Str --> int32) is native($library) { * }

method train(Algorithm::XGBoost::DMatrix $dmat, Int $num-iteration --> Algorithm::XGBoost::Model) {
method train(Algorithm::XGBoost::DMatrix $dmat, Int $num-iteration, %param? --> Algorithm::XGBoost::Model) {
my $h = Pointer.new;
XGBoosterCreate($dmat, 1, $h);
my $booster = nativecast(Algorithm::XGBoost::Booster, $h);
for %param {
XGBoosterSetParam($booster, .key.Str, .value.Str);
}

for ^$num-iteration -> $iter {
XGBoosterUpdateOneIter($booster, $iter, $dmat);
Expand Down
12 changes: 12 additions & 0 deletions t/01-basic.rakutest
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,16 @@ subtest {
is $actual, $expected;
}, "When a model is given, Then .save/.load should retain the model";

subtest {
Algorithm::XGBoost.global-config(q[{"verbosity": 4}]);
my @train[3;2] = [[0e0,0e0],[0e0,1e0],[1e0,0e0]];
my @y = [1e0, 0e0, 1e0];
my $dmat = Algorithm::XGBoost::DMatrix.from-matrix(@train, @y);
# TODO: Get param from the model
lives-ok {
my %param = (:booster("dart"));
my $model = Algorithm::XGBoost.train($dmat, 10, %param);
}
}, "When a %param is given, Then .train should use it";

done-testing;

0 comments on commit 8ce596a

Please sign in to comment.