Skip to content

Commit

Permalink
Made set_prior() work with sparse vectors.
Browse files Browse the repository at this point in the history
  • Loading branch information
davisking committed May 24, 2014
1 parent 05c0b37 commit 721597f
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 23 deletions.
35 changes: 25 additions & 10 deletions dlib/svm/svm_c_linear_trainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ namespace dlib
const in_scalar_vector_type& labels_,
const bool be_verbose_,
const scalar_type eps_,
const unsigned long max_iter
const unsigned long max_iter,
const unsigned long dims_
) :
samples(samples_),
labels(labels_),
Expand All @@ -53,7 +54,8 @@ namespace dlib
Cneg(C_neg/C),
be_verbose(be_verbose_),
eps(eps_),
max_iterations(max_iter)
max_iterations(max_iter),
dims(dims_)
{
dot_prods.resize(samples.size());
is_first_call = true;
Expand All @@ -69,7 +71,7 @@ namespace dlib
) const
{
// plus 1 for the bias term
return max_index_plus_one(samples) + 1;
return dims + 1;
}

virtual bool optimization_status (
Expand Down Expand Up @@ -300,6 +302,7 @@ namespace dlib
const bool be_verbose;
const scalar_type eps;
const unsigned long max_iterations;
const unsigned long dims;
};

// ----------------------------------------------------------------------------------------
Expand All @@ -317,11 +320,12 @@ namespace dlib
const in_scalar_vector_type& labels,
const bool be_verbose,
const scalar_type eps,
const unsigned long max_iterations
const unsigned long max_iterations,
const unsigned long dims
)
{
return oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type>(
C_pos, C_neg, samples, labels, be_verbose, eps, max_iterations);
C_pos, C_neg, samples, labels, be_verbose, eps, max_iterations, dims);
}

// ----------------------------------------------------------------------------------------
Expand Down Expand Up @@ -478,7 +482,8 @@ namespace dlib
<< "\n\t this: " << this
);

prior = join_cols(prior_.basis_vectors(0), mat((scalar_type)prior_.b));
prior = sparse_to_dense(prior_.basis_vectors(0));
prior_b = prior_.b;
learn_nonnegative_weights = false;
last_weight_1 = false;
}
Expand Down Expand Up @@ -631,23 +636,32 @@ namespace dlib
if (is_matrix<sample_type>::value)
{
// make sure requires clause is not broken
DLIB_CASSERT(num_dims+1 == (unsigned long)prior.size(),
DLIB_CASSERT(num_dims == (unsigned long)prior.size(),
"\t decision_function svm_c_linear_trainer::train(x,y)"
<< "\n\t The dimension of the training vectors must match the dimension of\n"
<< "\n\t those used to create the prior."
<< "\n\t num_dims: " << num_dims
<< "\n\t prior.size(): " << prior.size()
);
}
const unsigned long dims = std::max(num_dims, (unsigned long)prior.size());
// In the case of sparse sample vectors, it is possible that the input
// vector dimensionality is larger than the prior vector dimensionality.
// We need to check for this case and pad prior with zeros if it is the
// case.
matrix<scalar_type,0,1> prior_temp = join_cols(join_cols(prior,
zeros_matrix<scalar_type>(dims-prior.size(),1)),
mat(prior_b));

svm_objective = solver(
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations),
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations, dims),
w,
prior);
prior_temp);
}
else
{
svm_objective = solver(
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations),
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations, num_dims),
w,
num_nonnegative,
force_weight_1_idx);
Expand Down Expand Up @@ -678,6 +692,7 @@ namespace dlib
bool learn_nonnegative_weights;
bool last_weight_1;
matrix<scalar_type,0,1> prior;
scalar_type prior_b;
};

// ----------------------------------------------------------------------------------------
Expand Down
26 changes: 13 additions & 13 deletions dlib/test/oca.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,42 +66,42 @@ namespace
oca solver;

// test the version without a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 0);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 0);
dlog << LINFO << trans(w);
true_w = -0.5, 0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);

w_type prior = true_w;
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40), w, prior);
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
dlog << LINFO << trans(w);
true_w = -0.5, 0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);

prior = 0,0,0;
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40), w, prior);
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
dlog << LINFO << trans(w);
true_w = -0.5, 0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);

prior = -1,1,0;
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40), w, prior);
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
dlog << LINFO << trans(w);
true_w = -1.0, 1.0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);

prior = -0.2,0.2,0;
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40), w, prior);
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
dlog << LINFO << trans(w);
true_w = -0.5, 0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);

prior = -10.2,-1,0;
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40), w, prior);
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
dlog << LINFO << trans(w);
true_w = -10.2, -1.0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
Expand All @@ -110,7 +110,7 @@ namespace
print_spinner();

// test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 9999);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 9999);
dlog << LINFO << trans(w);
true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
Expand All @@ -126,7 +126,7 @@ namespace
print_spinner();

// test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 2);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 2);
dlog << LINFO << trans(w);
true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
Expand All @@ -136,7 +136,7 @@ namespace


// test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 1);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 1);
dlog << LINFO << trans(w);
true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
Expand All @@ -151,31 +151,31 @@ namespace
y.push_back(+1);


solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 0);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 0);
dlog << LINFO << trans(w);
true_w = 0.5, -0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);

print_spinner();

solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 1);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 1);
dlog << LINFO << trans(w);
true_w = 0.5, -0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);

print_spinner();

solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 2);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 2);
dlog << LINFO << trans(w);
true_w = 1, 0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);

print_spinner();

solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 5);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 5);
dlog << LINFO << trans(w);
true_w = 1, 0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
Expand Down
39 changes: 39 additions & 0 deletions dlib/test/svm_c_linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,44 @@ namespace
DLIB_TEST(df.basis_vectors(0)(2) > 0);
}

void run_prior_sparse_test()
{
typedef std::map<unsigned long,double> sample_type;
typedef sparse_linear_kernel<sample_type> kernel_type;

svm_c_linear_trainer<kernel_type> trainer;

std::vector<sample_type> samples;
std::vector<double> labels;

sample_type samp;
samp[0] = 1; samples.push_back(samp); labels.push_back(+1); samp.clear();
samp[1] = 1; samples.push_back(samp); labels.push_back(-1); samp.clear();

trainer.set_c(10);
decision_function<kernel_type> df = trainer.train(samples, labels);

trainer.set_prior(df);

samples.clear();
labels.clear();
samp[2] = 1; samples.push_back(samp); labels.push_back(+1); samp.clear();
samp[1] = 1; samples.push_back(samp); labels.push_back(-1); samp.clear();

df = trainer.train(samples, labels);

matrix<double,1,2> rs = test_binary_decision_function(df, samples, labels);
dlog << LINFO << rs;
DLIB_TEST(rs(0) == 1);
DLIB_TEST(rs(1) == 1);

matrix<double,0,1> w = sparse_to_dense(df.basis_vectors(0));
dlog << LINFO << trans(w);
DLIB_TEST(w(0) > 0.1);
DLIB_TEST(w(1) < -0.1);
DLIB_TEST(w(2) > 0.1);
}

void get_simple_points (
std::vector<sample_type>& samples,
std::vector<double>& labels
Expand Down Expand Up @@ -255,6 +293,7 @@ namespace
test_dense();
test_sparse();
run_prior_test();
run_prior_sparse_test();

// test mixed sparse and dense dot products
{
Expand Down

0 comments on commit 721597f

Please sign in to comment.