Skip to content

Commit

Permalink
I needed to make a few minor changes to make this code work with the new
Browse files Browse the repository at this point in the history
version of dlib.
  • Loading branch information
davisking committed Sep 11, 2011
1 parent e3310d4 commit 07e4885
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions tools/mltool/src/regression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,15 @@ krr_rbk_test (
gamma = gamma_range.get_next_value (gamma))
{
// LOO cross validation
double loo_error;
std::vector<double> loo_values;

if (parser.option("verbose")) {
trainer.set_search_lambdas(logspace(-9, 4, 100));
trainer.be_verbose();
}
trainer.set_kernel (kernel_type (gamma));
trainer.train (dense_samples, labels, loo_error);
trainer.train (dense_samples, labels, loo_values);
const double loo_error = mean_squared_error(loo_values, labels);
if (loo_error < best_loo) {
best_loo = loo_error;
best_gamma = gamma;
Expand Down Expand Up @@ -237,9 +238,12 @@ krr_lin_test (
krr_trainer<kernel_type> trainer;

// LOO cross validation
double loo_error;
trainer.train(dense_samples, labels, loo_error);
std::vector<double> loo_values;
trainer.train(dense_samples, labels, loo_values);
const double loo_error = mean_squared_error(loo_values, labels);
const double rs = r_squared(loo_values, labels);
std::cout << "mean squared LOO error: " << loo_error << std::endl;
std::cout << "R-Squared LOO: " << rs << std::endl;
}

// ----------------------------------------------------------------------------------------
Expand Down Expand Up @@ -343,11 +347,11 @@ svr_test (
gamma = gamma_range.get_next_value (gamma))
{
cout << "test with svr-C: " << svr_c << " gamma: "<< gamma << flush;
double cv_error;
matrix<double,1,2> cv;
trainer.set_kernel (kernel_type (gamma));
cv_error = cross_validate_regression_trainer (trainer,
dense_samples, labels, 10);
cout << " 10-fold-MSE: "<< cv_error << endl;
cv = cross_validate_regression_trainer (trainer, dense_samples, labels, 10);
cout << " 10-fold-MSE: "<< cv(0) << endl;
cout << " 10-fold-R-Squared: "<< cv(1) << endl;
}
}
}
Expand Down

0 comments on commit 07e4885

Please sign in to comment.