Skip to content

Commit

Permalink
Merge pull request #15 from titsuki/save-load-model
Browse files Browse the repository at this point in the history
Add .save/.load method for Model
  • Loading branch information
titsuki authored Aug 8, 2021
2 parents deb3b3d + 12a19cc commit 608ecd3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
14 changes: 14 additions & 0 deletions lib/Algorithm/XGBoost/Model.rakumod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ unit class Algorithm::XGBoost::Model:ver<0.0.3>:auth<cpan:TITSUKI> is repr('CPoi
my constant $library = %?RESOURCES<libraries/xgboost>.Str;
my sub XGBoosterGetNumFeature(Algorithm::XGBoost::Model, ulong is rw --> int32) is native($library) { * }
my sub XGBoosterPredict(Algorithm::XGBoost::Model, Algorithm::XGBoost::DMatrix, int32, uint32, int32, ulong is rw, Pointer[num32] is rw --> int32) is native($library) { * }
my sub XGBoosterLoadModel(Algorithm::XGBoost::Model, Str --> int32) is native($library) { * }
my sub XGBoosterSaveModel(Algorithm::XGBoost::Model, Str --> int32) is native($library) { * }

method new {!!!}

Expand All @@ -27,3 +29,15 @@ method predict(Algorithm::XGBoost::DMatrix $dmat, Int $option-mask = 0, Int $ntr
nativecast(CArray[$ret.of], $ret)[^$size]
}

method save(Str $fname) {
XGBoosterSaveModel(self, $fname);
}

my sub XGBoosterCreate(Algorithm::XGBoost::DMatrix is rw, ulong, Algorithm::XGBoost::Booster is rw --> int32) is native($library) { * }

method load(::?CLASS:U $this: Str $fname --> ::?CLASS) {
my $h = Pointer.new;
XGBoosterCreate(Pointer[void].new, 0, $h);
XGBoosterLoadModel($h, $fname);
nativecast(Algorithm::XGBoost::Model, $h);
}
16 changes: 16 additions & 0 deletions t/01-basic.rakutest
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,20 @@ subtest {
is from-json($json)<verbosity>, 2;
}, ".global-config should set/get verbosity=2";

subtest {
my @train[3;2] = [[0e0,0e0],[0e0,1e0],[1e0,0e0]];
my @y = [1e0, 0e0, 1e0];
my $dmat = Algorithm::XGBoost::DMatrix.from-matrix(@train, @y);
my $model = Algorithm::XGBoost.train($dmat, 10);
my @test[2;2] = [[0e0,0e0],[0e0,1e0]];
my $test = Algorithm::XGBoost::DMatrix.from-matrix(@test);
my $expected = $model.predict($test);
my $path = "$*TMPDIR/.raku-xgboost.model";
$model.save($path);
my $model2 = Algorithm::XGBoost::Model.load($path);
my $actual = $model2.predict($test);
shell("rm $path");
is $actual, $expected;
}, "When a model is given, Then .save/.load should retain the model";

done-testing;

0 comments on commit 608ecd3

Please sign in to comment.