Skip to content

Commit

Permalink
Reacquire GIL before constructing tuples containing numpy arrays.
Browse files Browse the repository at this point in the history
... as explicitly required by pybind11.  Compile with `#define
PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF` to observe the GIL assertion
failures that occur in the absence of this patch.
  • Loading branch information
anntzer committed Oct 7, 2024
1 parent 0532247 commit cd5af66
Showing 1 changed file with 58 additions and 53 deletions.
111 changes: 58 additions & 53 deletions ext/_hmmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,37 +63,39 @@ std::tuple<double, py::array_t<double>, py::array_t<double>> forward_scaling(
auto scaling_ = py::array_t<double>{{ns}};
auto scaling = scaling_.mutable_unchecked<1>();
auto log_prob = 0.;
py::gil_scoped_release nogil;
std::fill_n(fwd.mutable_data(0, 0), fwd.size(), 0);
for (auto i = 0; i < nc; ++i) {
fwd(0, i) = startprob(i) * frameprob(0, i);
}
auto sum = std::accumulate(&fwd(0, 0), &fwd(0, nc), 0.);
if (sum < min_sum) {
throw std::range_error{"forward pass failed with underflow; "
"consider using implementation='log' instead"};
}
auto scale = scaling(0) = 1. / sum;
log_prob -= std::log(scale);
for (auto i = 0; i < nc; ++i) {
fwd(0, i) *= scale;
}
for (auto t = 1; t < ns; ++t) {
for (auto j = 0; j < nc; ++j) {
for (auto i = 0; i < nc; ++i) {
fwd(t, j) += fwd(t - 1, i) * transmat(i, j);
}
fwd(t, j) *= frameprob(t, j);
{
py::gil_scoped_release nogil;
std::fill_n(fwd.mutable_data(0, 0), fwd.size(), 0);
for (auto i = 0; i < nc; ++i) {
fwd(0, i) = startprob(i) * frameprob(0, i);
}
auto sum = std::accumulate(&fwd(t, 0), &fwd(t, nc), 0.);
auto sum = std::accumulate(&fwd(0, 0), &fwd(0, nc), 0.);
if (sum < min_sum) {
throw std::range_error{"forward pass failed with underflow; "
"consider using implementation='log' instead"};
}
auto scale = scaling(t) = 1. / sum;
auto scale = scaling(0) = 1. / sum;
log_prob -= std::log(scale);
for (auto j = 0; j < nc; ++j) {
fwd(t, j) *= scale;
for (auto i = 0; i < nc; ++i) {
fwd(0, i) *= scale;
}
for (auto t = 1; t < ns; ++t) {
for (auto j = 0; j < nc; ++j) {
for (auto i = 0; i < nc; ++i) {
fwd(t, j) += fwd(t - 1, i) * transmat(i, j);
}
fwd(t, j) *= frameprob(t, j);
}
auto sum = std::accumulate(&fwd(t, 0), &fwd(t, nc), 0.);
if (sum < min_sum) {
throw std::range_error{"forward pass failed with underflow; "
"consider using implementation='log' instead"};
}
auto scale = scaling(t) = 1. / sum;
log_prob -= std::log(scale);
for (auto j = 0; j < nc; ++j) {
fwd(t, j) *= scale;
}
}
}
return {log_prob, fwdlattice_, scaling_};
Expand All @@ -117,16 +119,18 @@ std::tuple<double, py::array_t<double>> forward_log(
auto buf = std::vector<double>(nc);
auto fwdlattice_ = py::array_t<double>{{ns, nc}};
auto fwd = fwdlattice_.mutable_unchecked<2>();
py::gil_scoped_release nogil;
for (auto i = 0; i < nc; ++i) {
fwd(0, i) = log_startprob(i) + log_frameprob(0, i);
}
for (auto t = 1; t < ns; ++t) {
for (auto j = 0; j < nc; ++j) {
for (auto i = 0; i < nc; ++i) {
buf[i] = fwd(t - 1, i) + log_transmat(i, j);
{
py::gil_scoped_release nogil;
for (auto i = 0; i < nc; ++i) {
fwd(0, i) = log_startprob(i) + log_frameprob(0, i);
}
for (auto t = 1; t < ns; ++t) {
for (auto j = 0; j < nc; ++j) {
for (auto i = 0; i < nc; ++i) {
buf[i] = fwd(t - 1, i) + log_transmat(i, j);
}
fwd(t, j) = logsumexp(buf.data(), nc) + log_frameprob(t, j);
}
fwd(t, j) = logsumexp(buf.data(), nc) + log_frameprob(t, j);
}
}
auto log_prob = logsumexp(&fwd(ns - 1, 0), nc);
Expand Down Expand Up @@ -290,30 +294,31 @@ std::tuple<double, py::array_t<ssize_t>> viterbi(
auto viterbi_lattice_ = py::array_t<double>{{ns, nc}};
auto state_sequence = state_sequence_.mutable_unchecked<1>();
auto viterbi_lattice = viterbi_lattice_.mutable_unchecked<2>();
py::gil_scoped_release nogil;
for (auto i = 0; i < nc; ++i) {
viterbi_lattice(0, i) = log_startprob(i) + log_frameprob(0, i);
}
for (auto t = 1; t < ns; ++t) {
{
py::gil_scoped_release nogil;
for (auto i = 0; i < nc; ++i) {
auto max = -std::numeric_limits<double>::infinity();
for (auto j = 0; j < nc; ++j) {
max = std::max(max, viterbi_lattice(t - 1, j) + log_transmat(j, i));
viterbi_lattice(0, i) = log_startprob(i) + log_frameprob(0, i);
}
for (auto t = 1; t < ns; ++t) {
for (auto i = 0; i < nc; ++i) {
auto max = -std::numeric_limits<double>::infinity();
for (auto j = 0; j < nc; ++j) {
max = std::max(max, viterbi_lattice(t - 1, j) + log_transmat(j, i));
}
viterbi_lattice(t, i) = max + log_frameprob(t, i);
}
viterbi_lattice(t, i) = max + log_frameprob(t, i);
}
}
auto row = &viterbi_lattice(ns - 1, 0);
auto prev = state_sequence(ns - 1) = std::max_element(row, row + nc) - row;
auto log_prob = row[prev];
for (auto t = ns - 2; t >= 0; --t) {
auto max = std::make_pair(-std::numeric_limits<double>::infinity(), 0);
for (auto i = 0; i < nc; ++i) {
max = std::max(max, {viterbi_lattice(t, i) + log_transmat(i, prev), i});
auto row = &viterbi_lattice(ns - 1, 0);
auto prev = state_sequence(ns - 1) = std::max_element(row, row + nc) - row;
for (auto t = ns - 2; t >= 0; --t) {
auto max = std::make_pair(-std::numeric_limits<double>::infinity(), 0);
for (auto i = 0; i < nc; ++i) {
max = std::max(max, {viterbi_lattice(t, i) + log_transmat(i, prev), i});
}
state_sequence(t) = prev = max.second;
}
state_sequence(t) = prev = max.second;
}
return {log_prob, state_sequence_};
return {viterbi_lattice(ns - 1, state_sequence(ns - 1)), state_sequence_};
}

PYBIND11_MODULE(_hmmc, m) {
Expand Down

0 comments on commit cd5af66

Please sign in to comment.