forked from davisking/dlib
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added initial python bindings for dlib
- Loading branch information
Showing
6 changed files
with
415 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
|
||
CMAKE_MINIMUM_REQUIRED(VERSION 2.6) | ||
|
||
include(../../dlib/add_python_module) | ||
|
||
add_python_module(dlib | ||
src/dlib.cpp | ||
src/matrix.cpp | ||
src/vector.cpp | ||
src/svm_c_trainer.cpp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#include <boost/python.hpp> | ||
#include <dlib/matrix.h> | ||
#include <sstream> | ||
#include <string> | ||
#include <boost/python/suite/indexing/vector_indexing_suite.hpp> | ||
#include <boost/python/suite/indexing/map_indexing_suite.hpp> | ||
#include <boost/python/suite/indexing/indexing_suite.hpp> | ||
#include <boost/shared_ptr.hpp> | ||
|
||
#include <dlib/string.h> | ||
#include "serialize_pickle.h" | ||
|
||
using namespace std; | ||
using namespace dlib; | ||
using namespace boost::python; | ||
|
||
|
||
void bind_matrix(); | ||
void bind_vector(); | ||
void bind_svm_c_trainer(); | ||
|
||
BOOST_PYTHON_MODULE(dlib) | ||
{ | ||
bind_matrix(); | ||
bind_vector(); | ||
bind_svm_c_trainer(); | ||
|
||
class_<std::vector<double> >("array") | ||
.def(vector_indexing_suite<std::vector<double> >()) | ||
.def_pickle(serialize_pickle<std::vector<double> >()); | ||
|
||
class_<std::vector<matrix<double,0,1> > >("vectors") | ||
.def(vector_indexing_suite<std::vector<matrix<double,0,1> > >()) | ||
.def_pickle(serialize_pickle<std::vector<matrix<double,0,1> > >()); | ||
|
||
typedef pair<unsigned long,double> pair_type; | ||
class_<pair_type>("pair", "help message", init<>() ) | ||
.def(init<unsigned long,double>()) | ||
.def_readwrite("first",&pair_type::first, "THE FIRST, LOVE IT!") | ||
.def_readwrite("second",&pair_type::second) | ||
.def_pickle(serialize_pickle<pair_type>()); | ||
|
||
class_<std::vector<pair_type> >("sparse_vector") | ||
.def(vector_indexing_suite<std::vector<pair_type> >()) | ||
.def_pickle(serialize_pickle<std::vector<pair_type> >()); | ||
|
||
class_<std::vector<std::vector<pair_type> > >("sparse_vectors") | ||
.def(vector_indexing_suite<std::vector<std::vector<pair_type> > >()) | ||
.def_pickle(serialize_pickle<std::vector<std::vector<pair_type> > >()); | ||
|
||
/* | ||
def("tomat",tomat); | ||
def("add_to_map", add_to_map); | ||
def("getpair", getpair); | ||
def("getmatrix", getmatrix); | ||
def("yay", yay); | ||
def("sum", sum_mat); | ||
def("getmap", getmap); | ||
def("go", go); | ||
def("append_to_vector", append_to_vector); | ||
*/ | ||
|
||
|
||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
|
||
#include <boost/python.hpp> | ||
#include <boost/shared_ptr.hpp> | ||
#include <dlib/matrix.h> | ||
#include "serialize_pickle.h" | ||
|
||
|
||
using namespace dlib; | ||
using namespace std; | ||
using namespace boost::python; | ||
|
||
|
||
void matrix_set_size(matrix<double>& m, long nr, long nc) | ||
{ | ||
m.set_size(nr,nc); | ||
m = 0; | ||
} | ||
|
||
string matrix_double__str__(matrix<double>& c) | ||
{ | ||
ostringstream sout; | ||
sout << c; | ||
return sout.str(); | ||
} | ||
|
||
boost::shared_ptr<matrix<double> > make_matrix_from_size(long nr, long nc) | ||
{ | ||
boost::shared_ptr<matrix<double> > temp(new matrix<double>(nr,nc)); | ||
*temp = 0; | ||
return temp; | ||
} | ||
|
||
|
||
boost::shared_ptr<matrix<double> > from_object(object obj) | ||
{ | ||
tuple s = extract<tuple>(obj.attr("shape")); | ||
if (len(s) != 2) | ||
{ | ||
PyErr_SetString( PyExc_IndexError, "Input must be a matrix or some kind of 2D array." | ||
); | ||
boost::python::throw_error_already_set(); | ||
} | ||
|
||
const long nr = extract<long>(s[0]); | ||
const long nc = extract<long>(s[1]); | ||
boost::shared_ptr<matrix<double> > temp(new matrix<double>(nr,nc)); | ||
for ( long r = 0; r < nr; ++r) | ||
{ | ||
for (long c = 0; c < nc; ++c) | ||
{ | ||
(*temp)(r,c) = extract<double>(obj[make_tuple(r,c)]); | ||
} | ||
} | ||
return temp; | ||
} | ||
|
||
long matrix_double__len__(matrix<double>& c) | ||
{ | ||
return c.nr(); | ||
} | ||
|
||
|
||
struct mat_row | ||
{ | ||
mat_row() : data(0),size(0) {} | ||
mat_row(double* data_, long size_) : data(data_),size(size_) {} | ||
double* data; | ||
long size; | ||
}; | ||
|
||
void mat_row__setitem__(mat_row& c, long p, double val) | ||
{ | ||
if (p < 0) { | ||
p = c.size + p; // negative index | ||
} | ||
if (p > c.size-1) { | ||
PyErr_SetString( PyExc_IndexError, "3 index out of range" | ||
); | ||
boost::python::throw_error_already_set(); | ||
} | ||
c.data[p] = val; | ||
} | ||
|
||
|
||
string mat_row__str__(mat_row& c) | ||
{ | ||
ostringstream sout; | ||
sout << mat(c.data,1, c.size); | ||
return sout.str(); | ||
} | ||
|
||
long mat_row__len__(mat_row& m) | ||
{ | ||
return m.size; | ||
} | ||
|
||
double mat_row__getitem__(mat_row& m, long r) | ||
{ | ||
if (r < 0) { | ||
r = m.size + r; // negative index | ||
} | ||
if (r > m.size-1 || r < 0) { | ||
PyErr_SetString( PyExc_IndexError, "1 index out of range" | ||
); | ||
boost::python::throw_error_already_set(); | ||
} | ||
return m.data[r]; | ||
} | ||
|
||
mat_row matrix_double__getitem__(matrix<double>& m, long r) | ||
{ | ||
if (r < 0) { | ||
r = m.nr() + r; // negative index | ||
} | ||
if (r > m.nr()-1 || r < 0) { | ||
PyErr_SetString( PyExc_IndexError, (string("2 index out of range, got ") + cast_to_string(r)).c_str() | ||
); | ||
boost::python::throw_error_already_set(); | ||
} | ||
return mat_row(&m(r,0),m.nc()); | ||
} | ||
|
||
|
||
tuple get_matrix_size(matrix<double>& m) | ||
{ | ||
return make_tuple(m.nr(), m.nc()); | ||
} | ||
|
||
void bind_matrix() | ||
{ | ||
class_<mat_row>("_row") | ||
.def("__len__", &mat_row__len__) | ||
.def("__repr__", &mat_row__str__) | ||
.def("__str__", &mat_row__str__) | ||
.def("__setitem__", &mat_row__setitem__) | ||
.def("__getitem__", &mat_row__getitem__); | ||
|
||
class_<matrix<double> >("matrix", init<>()) | ||
.def("__init__", make_constructor(&make_matrix_from_size)) | ||
.def("set_size", &matrix_set_size) | ||
.def("__init__", make_constructor(&from_object)) | ||
.def("__repr__", &matrix_double__str__) | ||
.def("__str__", &matrix_double__str__) | ||
.def("__len__", &matrix_double__len__) | ||
.def("__getitem__", &matrix_double__getitem__, with_custodian_and_ward_postcall<0,1>()) | ||
.add_property("shape", &get_matrix_size) | ||
.def_pickle(serialize_pickle<matrix<double> >()); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#ifndef DLIB_SERIALIZE_PiCKLE_H__ | ||
#define DLIB_SERIALIZE_PiCKLE_H__ | ||
|
||
#include <dlib/serialize.h> | ||
#include <boost/python.hpp> | ||
#include <sstream> | ||
|
||
template <typename T> | ||
struct serialize_pickle : boost::python::pickle_suite | ||
{ | ||
static boost::python::tuple getstate( | ||
const T& item | ||
) | ||
{ | ||
using namespace dlib; | ||
std::ostringstream sout; | ||
serialize(item, sout); | ||
return boost::python::make_tuple(sout.str()); | ||
} | ||
|
||
static void setstate( | ||
T& item, | ||
boost::python::tuple state | ||
) | ||
{ | ||
using namespace dlib; | ||
using namespace boost::python; | ||
if (len(state) != 1) | ||
{ | ||
PyErr_SetObject(PyExc_ValueError, | ||
("expected 1-item tuple in call to __setstate__; got %s" | ||
% state).ptr() | ||
); | ||
throw_error_already_set(); | ||
} | ||
|
||
std::string& data = extract<std::string&>(state[0]); | ||
std::istringstream sin(data); | ||
deserialize(item, sin); | ||
} | ||
}; | ||
|
||
#endif // DLIB_SERIALIZE_PiCKLE_H__ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
|
||
#include <boost/python.hpp> | ||
#include <boost/shared_ptr.hpp> | ||
#include <dlib/matrix.h> | ||
#include "serialize_pickle.h" | ||
#include <dlib/svm.h> | ||
|
||
using namespace dlib; | ||
using namespace std; | ||
using namespace boost::python; | ||
|
||
typedef matrix<double,0,1> sample_type; | ||
|
||
|
||
|
||
template <typename kernel_type> | ||
void bind_kernel( | ||
) | ||
{ | ||
typedef svm_c_trainer<kernel_type> trainer; | ||
class_<trainer>("svm_c_trainer") | ||
.def("train", &trainer::template train<std::vector<sample_type>,std::vector<double> >); | ||
|
||
typedef decision_function<kernel_type> df; | ||
class_<df>("df") | ||
.def("predict", &df::operator()); | ||
} | ||
|
||
|
||
void bind_svm_c_trainer() | ||
{ | ||
bind_kernel<linear_kernel<sample_type> >(); | ||
|
||
/* | ||
class_<cv>("vector", init<>()) | ||
.def("set_size", &cv_set_size) | ||
.def("__init__", make_constructor(&cv_from_object)) | ||
.def("__repr__", &cv__str__) | ||
.def("__str__", &cv__str__) | ||
.def("__len__", &cv__len__) | ||
.def("__getitem__", &cv__getitem__) | ||
.add_property("shape", &cv_get_matrix_size) | ||
.def_pickle(serialize_pickle<cv>()); | ||
*/ | ||
} | ||
|
||
|
Oops, something went wrong.