Skip to content

Commit

Permalink
fixed pfxt root nodes update bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Randy1005 committed Dec 7, 2023
1 parent e909fb9 commit 022f356
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 30 deletions.
41 changes: 29 additions & 12 deletions ink/ink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1249,10 +1249,6 @@ void Ink::_spur_incsfxt(
break;
}

if (_all_paths.size() >= K || _all_path_costs.size() >= K) {
break;
}


// expand search space
_spur(pfxt, *node);
Expand All @@ -1267,6 +1263,11 @@ void Ink::_spur_incsfxt(
else {
_all_path_costs.push_back(node->cost);
}


if (_all_paths.size() >= K || _all_path_costs.size() >= K) {
break;
}
}

auto end = std::chrono::steady_clock::now();
Expand Down Expand Up @@ -1347,14 +1348,28 @@ void Ink::_spur_incremental(
}

void Ink::_mark_pfxt_nodes(Pfxt& pfxt) {
auto sfxt = *_global_sfxt;
std::queue<PfxtNode*> q;
for (auto s : pfxt.srcs) {
auto v = s->to;
// NOTE: we assume weight from S to primary inputs are 0
// but in real circuits this might change
auto w = 0.0f;
// update root pfxt node cost
if (v == *sfxt.successors[sfxt.S]) {
s->detour_cost = 0.0f;
s->cost = *sfxt.dist();
}
else {
s->detour_cost = *sfxt.dists[v] + w - *sfxt.dist();
s->cost = s->detour_cost + *sfxt.dist();
}

for (auto c : s->children) {
q.push(c);
}
}

auto sfxt = *_global_sfxt;
int spur_begin, to_spurs, total_spurs, updates_per_parent;
PfxtNode* curr_parent{nullptr};

Expand Down Expand Up @@ -1403,7 +1418,8 @@ void Ink::_mark_pfxt_nodes(Pfxt& pfxt) {
auto w = *eptr->weights[w_sel];
auto dc = *sfxt.dists[v] + w - *sfxt.dists[u];
node->detour_cost = dc + p->detour_cost;

node->cost = node->detour_cost + *sfxt.dist();

// finished updating a non-sfxt edge, decrement to_spurs
// if to_spurs == 0, spur_begin can advance to its successor
to_spurs--;
Expand Down Expand Up @@ -1542,12 +1558,7 @@ void Ink::_spur_incremental_v2(
valid = true;
}

if (valid &&
(_all_paths.size() >= K || _all_path_costs.size() >= K)) {
break;
}

if (node->removed) {
if (node->removed) {
continue;
}

Expand All @@ -1565,6 +1576,12 @@ void Ink::_spur_incremental_v2(
else {
_all_path_costs.push_back(node->cost);
}

if (valid &&
(_all_paths.size() >= K || _all_path_costs.size() >= K)) {
break;
}

}
end = std::chrono::steady_clock::now();
_rt_spur =
Expand Down
55 changes: 37 additions & 18 deletions main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ bool check_results(
const std::vector<float>& costs_inc,
const std::vector<float>& costs_full) {
for (size_t i = 0; i < costs_full.size(); i++) {
if (!equal(costs_inc[i], costs_full[i], 0.00001f)) {
if (!equal(costs_inc[i], costs_full[i], 0.001f)) {
std::cout << "mismatch at " << i << "-th path.\n";
std::cout << "cost_inc=" << costs_inc[i] << '\n';
std::cout << "cost_full=" << costs_full[i] << '\n';
Expand Down Expand Up @@ -47,7 +47,8 @@ int main(int argc, char* argv[]) {
auto es_inc = ink_inc.find_chain_edges();
std::vector<float> costs_full, costs_inc, cache_full, cache_inc;

std::ofstream ofs_prof("out-prof");
std::ofstream ofs_pts_full("out-path-traces-full");
std::ofstream ofs_pts_inc("out-path-traces-inc");
std::ofstream ofs_inc("out-inc");
std::ofstream ofs_full("out-full");
std::ofstream ofs_rt_full("out-runtime-distr-full");
Expand All @@ -57,13 +58,15 @@ int main(int argc, char* argv[]) {
auto paths_full = ink_full.report_incsfxt(num_paths, false, recover_path);
auto paths_inc = ink_inc.report_incsfxt(num_paths, true, recover_path);
if (recover_path) {
for (const auto& p : paths_full) {
costs_full.emplace_back(p.weight);
}
//for (const auto& p : paths_full) {
// costs_full.emplace_back(p.weight);
//}

for (const auto& p : paths_inc) {
costs_inc.emplace_back(p.weight);
}
//for (const auto& p : paths_inc) {
// costs_inc.emplace_back(p.weight);
//}
costs_full = ink_full.get_path_costs();
costs_inc = ink_inc.get_path_costs();
}
else {
costs_full = ink_full.get_path_costs();
Expand All @@ -76,6 +79,8 @@ int main(int argc, char* argv[]) {
std::exit(EXIT_FAILURE);
}

cache_full = std::move(costs_full);
cache_inc = std::move(costs_inc);
costs_full.clear();
costs_inc.clear();

Expand Down Expand Up @@ -128,21 +133,25 @@ int main(int argc, char* argv[]) {
paths_inc = ink_inc.report_incremental_v2(num_paths, true, recover_path);

if (recover_path) {
for (const auto& p : paths_full) {
costs_full.emplace_back(p.weight);
}
//for (const auto& p : paths_full) {
// costs_full.emplace_back(p.weight);
//}

for (const auto& p : paths_inc) {
costs_inc.emplace_back(p.weight);
}
//for (const auto& p : paths_inc) {
// costs_inc.emplace_back(p.weight);
//}
costs_full = ink_full.get_path_costs();
costs_inc = ink_inc.get_path_costs();
}
else {
costs_full = ink_full.get_path_costs();
costs_inc = ink_inc.get_path_costs();
}

// check results
bool check = check_results(costs_inc, costs_full);
cache_full = std::move(costs_full);
cache_inc = std::move(costs_inc);
bool check = check_results(cache_inc, cache_full);
if (!check) {
std::cout << "mismatch at iteration " << i << "!\n";
break;
Expand All @@ -151,8 +160,6 @@ int main(int argc, char* argv[]) {
ofs_rt_full << ink_full.pfxt_expansion_time * 1e-6 << '\n';
ofs_rt_inc << ink_inc.pfxt_expansion_time * 1e-6 << '\n';

cache_full = std::move(costs_full);
cache_inc = std::move(costs_inc);
costs_full.clear();
costs_inc.clear();
}
Expand All @@ -164,7 +171,19 @@ int main(int argc, char* argv[]) {
for (auto c : cache_inc) {
ofs_inc << c << '\n';
}


if (paths_full.size() != 0) {
for (size_t i = 0; i < paths_full.size(); i++) {
ofs_pts_full << "---- Path " << i << " ----\n";
paths_full[i].dump(ofs_pts_full);
}

for (size_t i = 0; i < paths_inc.size(); i++) {
ofs_pts_inc << "---- Path " << i << " ----\n";
paths_inc[i].dump(ofs_pts_inc);
}
}


return 0;
}
Expand Down
5 changes: 5 additions & 0 deletions unittests/inc_pfxt_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ TEST_CASE("Insert / Remove / Update Edges (two inc iterations)" * doctest::time
inc_full.report_incsfxt(10, false, false);
costs_full = inc_full.get_path_costs();
REQUIRE(costs_full.size() == 10);

for (auto c : costs_inc) {
std::cout << c << '\n';
}


for (size_t i = 0; i < costs_full.size(); i++) {
REQUIRE(float_equal(costs_inc[i], costs_full[i]));
Expand Down

0 comments on commit 022f356

Please sign in to comment.