Skip to content

Commit

Permalink
Added more python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
davisking committed Apr 27, 2013
1 parent affd197 commit e0c9bb6
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 22 deletions.
59 changes: 55 additions & 4 deletions tools/python/src/decision_funcions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ double predict (
else if (df.basis_vectors(0).size() != samp.size())
{
std::ostringstream sout;
sout << "Input vector should have " << df.basis_vectors(0).size() << " dimensions, not " << samp.size() << ".";
PyErr_SetString( PyExc_IndexError, sout.str().c_str() );
sout << "Input vector should have " << df.basis_vectors(0).size()
<< " dimensions, not " << samp.size() << ".";
PyErr_SetString( PyExc_ValueError, sout.str().c_str() );
boost::python::throw_error_already_set();
}
return df(samp);
Expand All @@ -43,12 +44,62 @@ void add_df (
.def_pickle(serialize_pickle<df_type>());
}

template <typename df_type>
typename df_type::sample_type get_weights(
const df_type& df
)
{
if (df.basis_vectors.size() == 0)
{
PyErr_SetString( PyExc_ValueError, "Decision function is empty." );
boost::python::throw_error_already_set();
}
df_type temp = simplify_linear_decision_function(df);
return temp.basis_vectors(0);
}

template <typename df_type>
typename df_type::scalar_type get_bias(
const df_type& df
)
{
if (df.basis_vectors.size() == 0)
{
PyErr_SetString( PyExc_ValueError, "Decision function is empty." );
boost::python::throw_error_already_set();
}
return df.b;
}

template <typename kernel_type>
void add_linear_df (
const std::string name
)
{
typedef decision_function<kernel_type> df_type;
class_<df_type>(name.c_str())
.def("predict", predict<df_type>)
.def("get_weights", get_weights<df_type>)
.def("get_bias", get_bias<df_type>)
.def_pickle(serialize_pickle<df_type>());
}

void bind_decision_functions()
{
add_df<linear_kernel<sample_type> >("_decision_function_linear");
add_df<sparse_linear_kernel<sparse_vect> >("_decision_function_sparse_linear");
add_linear_df<linear_kernel<sample_type> >("_decision_function_linear");
add_linear_df<sparse_linear_kernel<sparse_vect> >("_decision_function_sparse_linear");

add_df<histogram_intersection_kernel<sample_type> >("_decision_function_histogram_intersection");
add_df<sparse_histogram_intersection_kernel<sparse_vect> >("_decision_function_sparse_histogram_intersection");

add_df<polynomial_kernel<sample_type> >("_decision_function_polynomial");
add_df<sparse_polynomial_kernel<sparse_vect> >("_decision_function_sparse_polynomial");

add_df<radial_basis_kernel<sample_type> >("_decision_function_radial_basis");
add_df<sparse_radial_basis_kernel<sparse_vect> >("_decision_function_sparse_radial_basis");

add_df<sigmoid_kernel<sample_type> >("_decision_function_sigmoid");
add_df<sparse_sigmoid_kernel<sparse_vect> >("_decision_function_sparse_sigmoid");
}


Expand Down
17 changes: 17 additions & 0 deletions tools/python/src/dlib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,22 @@ string sparse_vector__repr__ (const std::vector<std::pair<unsigned long,double>
return sout.str();
}

tuple get_training_data()
{
typedef matrix<double,0,1> sample_type;
std::vector<sample_type> samples;
std::vector<double> labels;

sample_type samp(3);
samp = 1,2,3;
samples.push_back(samp);
labels.push_back(+1);
samp = -1,-2,-3;
samples.push_back(samp);
labels.push_back(-1);

return make_tuple(samples, labels);
}

BOOST_PYTHON_MODULE(dlib)
{
Expand Down Expand Up @@ -138,6 +154,7 @@ BOOST_PYTHON_MODULE(dlib)
.def(vector_indexing_suite<std::vector<std::vector<pair_type> > >())
.def_pickle(serialize_pickle<std::vector<std::vector<pair_type> > >());

def("get_training_data",get_training_data);
/*
def("tomat",tomat);
def("add_to_map", add_to_map);
Expand Down
8 changes: 8 additions & 0 deletions tools/python/src/pyassert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@


#define pyassert(_exp,_message) \
{if ( !(_exp) ) \
{ \
PyErr_SetString( PyExc_ValueError, _message ); \
boost::python::throw_error_already_set(); \
}}
135 changes: 117 additions & 18 deletions tools/python/src/svm_c_trainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,140 @@
#include <dlib/matrix.h>
#include "serialize_pickle.h"
#include <dlib/svm.h>
#include "pyassert.h"

using namespace dlib;
using namespace std;
using namespace boost::python;

typedef matrix<double,0,1> sample_type;
typedef std::vector<std::pair<unsigned long,double> > sparse_vect;

template <typename trainer_type>
typename trainer_type::trained_function_type train (
const trainer_type& trainer,
const std::vector<typename trainer_type::sample_type>& samples,
const std::vector<double>& labels
)
{
pyassert(is_binary_classification_problem(samples,labels), "Invalid inputs");
return trainer.train(samples, labels);
}

template <typename trainer_type>
void set_epsilon ( trainer_type& trainer, double eps)
{
pyassert(eps > 0, "epsilon must be > 0");
trainer.set_epsilon(eps);
}

template <typename trainer_type>
double get_epsilon ( const trainer_type& trainer) { return trainer.get_epsilon(); }


template <typename trainer_type>
void set_cache_size ( trainer_type& trainer, long cache_size)
{
pyassert(cache_size > 0, "cache size must be > 0");
trainer.set_cache_size(cache_size);
}

template <typename trainer_type>
long get_cache_size ( const trainer_type& trainer) { return trainer.get_cache_size(); }


template <typename trainer_type>
void set_c ( trainer_type& trainer, double C)
{
pyassert(C > 0, "C must be > 0");
trainer.set_c(C);
}

template <typename trainer_type>
void set_c_class1 ( trainer_type& trainer, double C)
{
pyassert(C > 0, "C must be > 0");
trainer.set_c_class1(C);
}

template <typename trainer_type>
void set_c_class2 ( trainer_type& trainer, double C)
{
pyassert(C > 0, "C must be > 0");
trainer.set_c_class2(C);
}

template <typename kernel_type>
void bind_kernel(
template <typename trainer_type>
double get_c_class1 ( const trainer_type& trainer) { return trainer.get_c_class1(); }
template <typename trainer_type>
double get_c_class2 ( const trainer_type& trainer) { return trainer.get_c_class2(); }

template <typename trainer_type>
class_<trainer_type> setup_trainer (
const std::string& name
)
{
return class_<trainer_type>(name.c_str())
.def("train", train<trainer_type>)
.def("set_c", set_c<trainer_type>)
.def("set_c_class1", set_c_class1<trainer_type>)
.def("set_c_class2", set_c_class2<trainer_type>)
.def("get_c_class1", get_c_class1<trainer_type>)
.def("get_c_class2", get_c_class2<trainer_type>)
.def("get_epsilon", get_epsilon<trainer_type>)
.def("set_epsilon", set_epsilon<trainer_type>)
.def("get_cache_size", get_cache_size<trainer_type>)
.def("set_cache_size", set_cache_size<trainer_type>);
}

void set_gamma (
svm_c_trainer<radial_basis_kernel<sample_type> >& trainer,
double gamma
)
{
pyassert(gamma > 0, "gamma must be > 0");
trainer.set_kernel(radial_basis_kernel<sample_type>(gamma));
}

double get_gamma (
const svm_c_trainer<radial_basis_kernel<sample_type> >& trainer
)
{
return trainer.get_kernel().gamma;
}

void set_gamma_sparse (
svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> >& trainer,
double gamma
)
{
typedef svm_c_trainer<kernel_type> trainer;
class_<trainer>("svm_c_trainer")
.def("train", &trainer::template train<std::vector<sample_type>,std::vector<double> >);
pyassert(gamma > 0, "gamma must be > 0");
trainer.set_kernel(sparse_radial_basis_kernel<sparse_vect>(gamma));
}

double get_gamma_sparse (
const svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> >& trainer
)
{
return trainer.get_kernel().gamma;
}


// ----------------------------------------------------------------------------------------

void bind_svm_c_trainer()
{
bind_kernel<linear_kernel<sample_type> >();

/*
class_<cv>("vector", init<>())
.def("set_size", &cv_set_size)
.def("__init__", make_constructor(&cv_from_object))
.def("__repr__", &cv__str__)
.def("__str__", &cv__str__)
.def("__len__", &cv__len__)
.def("__getitem__", &cv__getitem__)
.add_property("shape", &cv_get_matrix_size)
.def_pickle(serialize_pickle<cv>());
*/
setup_trainer<svm_c_trainer<radial_basis_kernel<sample_type> > >("svm_c_trainer_radial_basis")
.def("set_gamma", set_gamma)
.def("get_gamma", get_gamma);

setup_trainer<svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > >("svm_c_trainer_sparse_radial_basis")
.def("set_gamma", set_gamma_sparse)
.def("get_gamma", get_gamma_sparse);

setup_trainer<svm_c_trainer<histogram_intersection_kernel<sample_type> > >("svm_c_trainer_histogram_intersection");

setup_trainer<svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > >("svm_c_trainer_sparse_histogram_intersection");
}


0 comments on commit e0c9bb6

Please sign in to comment.