Skip to content

Commit

Permalink
add unit test for incremental equation edit distance with repair
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Dec 15, 2024
1 parent 31ee56c commit b529a58
Show file tree
Hide file tree
Showing 5 changed files with 469 additions and 109 deletions.
209 changes: 101 additions & 108 deletions src/ast/sls/sls_seq_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,67 +662,6 @@ namespace sls {
return d[n][m];
}

/**
 * \brief edit distance with update calculation
 *
 * Fills two (n+1) x (m+1) tables, where n = |a| and m = |b|:
 *  - d holds the plain edit distance between the prefixes a[0..i) and b[0..j).
 *  - u starts from the same recurrence, but a cell may be lowered to the
 *    d-value of a neighbouring cell when the character involved is not
 *    flagged in a_is_value / b_is_value.
 *
 * Candidate edits are collected in m_string_updates; the list is cleared on
 * entry and again every time a strictly smaller u-value is found.
 *
 * \param a           left-hand string
 * \param a_is_value  per-character flags for a, indexed like a
 * \param b           right-hand string
 * \param b_is_value  per-character flags for b, indexed like b
 * \return u[n][m]
 */
unsigned seq_plugin::edit_distance_with_updates(zstring const& a, bool_vector const& a_is_value, zstring const& b, bool_vector const& b_is_value) {
    unsigned n = a.length();
    unsigned m = b.length();
    vector<unsigned_vector> d(n + 1); // edit distance
    vector<unsigned_vector> u(n + 1); // edit distance with updates.
    m_string_updates.reset();
    for (unsigned i = 0; i <= n; ++i) {
        d[i].resize(m + 1, 0);
        u[i].resize(m + 1, 0);
    }
    // Base cases: turning a prefix into the empty string costs its length.
    for (unsigned i = 0; i <= n; ++i)
        d[i][0] = i, u[i][0] = i;
    for (unsigned j = 0; j <= m; ++j)
        d[0][j] = j, u[0][j] = j;
    for (unsigned j = 1; j <= m; ++j) {
        for (unsigned i = 1; i <= n; ++i) {
            if (a[i - 1] == b[j - 1]) {
                d[i][j] = d[i - 1][j - 1];
                u[i][j] = u[i - 1][j - 1];
            }
            else {
                u[i][j] = 1 + std::min(u[i - 1][j], std::min(u[i][j - 1], u[i - 1][j - 1]));
                d[i][j] = 1 + std::min(d[i - 1][j], std::min(d[i][j - 1], d[i - 1][j - 1]));

                // TODO: take into account for a_is_value[i - 1] and b_is_value[j - 1]
                // and whether index i-1, j-1 is at the boundary of an empty string variable.

                if (d[i - 1][j] < u[i][j] && !a_is_value[i - 1]) {
                    m_string_updates.reset();
                    u[i][j] = d[i - 1][j];
                }
                // b_is_value is indexed by the column j - 1; i - 1 is a position in a
                // and may lie outside b_is_value when a is longer than b.
                if (d[i][j - 1] < u[i][j] && !b_is_value[j - 1]) {
                    m_string_updates.reset();
                    u[i][j] = d[i][j - 1];
                }
                if (d[i - 1][j - 1] < u[i][j] && (!a_is_value[i - 1] || !b_is_value[j - 1])) {
                    m_string_updates.reset();
                    u[i][j] = d[i - 1][j - 1];
                }
                if (d[i - 1][j] == u[i][j] && !a_is_value[i - 1]) {
                    add_string_update(side_t::left, op_t::del, i - 1, 0);
                    add_string_update(side_t::left, op_t::add, j - 1, i - 1);
                }
                if (d[i][j - 1] == u[i][j] && !b_is_value[j - 1]) {
                    add_string_update(side_t::right, op_t::del, j - 1, 0);
                    add_string_update(side_t::right, op_t::add, i - 1, j - 1);
                }
                if (d[i - 1][j - 1] == u[i][j] && !a_is_value[i - 1])
                    add_string_update(side_t::left, op_t::copy, j - 1, i - 1);

                if (d[i - 1][j - 1] == u[i][j] && !b_is_value[j - 1])
                    add_string_update(side_t::right, op_t::copy, i - 1, j - 1);

            }
        }
    }
    return u[n][m];
}

void seq_plugin::add_edit_updates(ptr_vector<expr> const& w, zstring const& val, zstring const& val_other, uint_set const& chars) {
for (auto x : w) {
Expand Down Expand Up @@ -793,67 +732,124 @@ namespace sls {
#endif
}

/**
 * \brief populate a string_instance from a sequence of expressions
 *
 * Appends the current string value (strval0) of every expression in es to
 * a.s and extends the three per-character flag vectors by one entry per
 * appended character:
 *  - is_value[k]    : is_value() held for the expression that produced character k.
 *  - prev_is_var[k] : set on the first character of a value expression whose
 *                     predecessor in es was not a value.
 *  - next_is_var[k] : set on the last character recorded so far when a
 *                     non-value expression directly follows a value expression.
 */
void seq_plugin::init_string_instance(ptr_vector<expr> const& es, string_instance& a) {
    bool after_var = false;   // was the previously visited expression a non-value?
    for (auto e : es) {
        auto const& chunk = strval0(e);
        auto count = chunk.length();
        bool const fixed = is_value(e);
        a.s += chunk;

        // A non-value that follows a value marks the last character seen so far.
        bool const starts_var_run = !fixed && !after_var;
        if (starts_var_run && !a.next_is_var.empty())
            a.next_is_var.back() = true;

        for (auto remaining = count; remaining > 0; --remaining) {
            a.is_value.push_back(fixed);
            a.prev_is_var.push_back(false);
            a.next_is_var.push_back(false);
        }

        // First character of a non-empty value chunk that follows a non-value.
        if (fixed && after_var && count > 0) {
            auto first = a.prev_is_var.size() - count;
            a.prev_is_var[first] = true;
        }
        after_var = !fixed;
    }
}


/**
 * \brief edit distance with update calculation
 *
 * Fills two (n+1) x (m+1) tables, where n = |a.s| and m = |b.s|:
 *  - d holds the plain edit distance between the prefixes a.s[0..i) and b.s[0..j).
 *  - u starts from the same recurrence, but a cell may be lowered to the
 *    d-value of a neighbouring cell when the corresponding position accepts
 *    an insertion (string_instance::can_add).
 *
 * Candidate edits are collected in m_string_updates; the list is cleared on
 * entry and again every time a strictly smaller u-value is found.
 *
 * \param a  left-hand string with per-character flags
 * \param b  right-hand string with per-character flags
 * \return u[n][m]
 */
unsigned seq_plugin::edit_distance_with_updates(string_instance const& a, string_instance const& b) {
    unsigned n = a.s.length();
    unsigned m = b.s.length();
    vector<unsigned_vector> d(n + 1); // edit distance
    vector<unsigned_vector> u(n + 1); // edit distance with updates.
    m_string_updates.reset();
    for (unsigned i = 0; i <= n; ++i) {
        d[i].resize(m + 1, 0);
        u[i].resize(m + 1, 0);
    }
    // Base cases: turning a prefix into the empty string costs its length.
    for (unsigned i = 0; i <= n; ++i)
        d[i][0] = i, u[i][0] = i;
    for (unsigned j = 0; j <= m; ++j)
        d[0][j] = j, u[0][j] = j;
    for (unsigned j = 1; j <= m; ++j) {
        for (unsigned i = 1; i <= n; ++i) {
            if (a.s[i - 1] == b.s[j - 1]) {
                d[i][j] = d[i - 1][j - 1];
                u[i][j] = u[i - 1][j - 1];
            }
            else {
                u[i][j] = 1 + std::min(u[i - 1][j], std::min(u[i][j - 1], u[i - 1][j - 1]));
                d[i][j] = 1 + std::min(d[i - 1][j], std::min(d[i][j - 1], d[i - 1][j - 1]));

                if (d[i - 1][j] < u[i][j] && a.can_add(i - 1)) {
                    m_string_updates.reset();
                    u[i][j] = d[i - 1][j];
                }
                // b is indexed by the column j - 1; i - 1 is a position in a.s and
                // may lie outside b's flag vectors when a.s is longer than b.s.
                if (d[i][j - 1] < u[i][j] && b.can_add(j - 1)) {
                    m_string_updates.reset();
                    u[i][j] = d[i][j - 1];
                }
                if (d[i - 1][j - 1] < u[i][j] && (a.can_add(i - 1) || b.can_add(j - 1))) {
                    m_string_updates.reset();
                    u[i][j] = d[i - 1][j - 1];
                }
                if (d[i - 1][j] == u[i][j] && a.can_add(i - 1))
                    add_string_update(side_t::left, op_t::add, j - 1, i - 1);

                if (d[i][j - 1] == u[i][j] && b.can_add(j - 1))
                    add_string_update(side_t::right, op_t::add, i - 1, j - 1);

                if (d[i - 1][j] == u[i][j] && !a.is_value[i - 1])
                    add_string_update(side_t::left, op_t::del, i - 1, 0);

                if (d[i][j - 1] == u[i][j] && !b.is_value[j - 1])
                    add_string_update(side_t::right, op_t::del, j - 1, 0);

                if (d[i - 1][j - 1] == u[i][j] && !a.is_value[i - 1])
                    add_string_update(side_t::left, op_t::copy, j - 1, i - 1);

                if (d[i - 1][j - 1] == u[i][j] && !b.is_value[j - 1])
                    add_string_update(side_t::right, op_t::copy, i - 1, j - 1);
            }
        }
    }
    return u[n][m];
}


bool seq_plugin::repair_down_str_eq_edit_distance_incremental(app* eq) {
auto const& L = lhs(eq);
auto const& R = rhs(eq);
zstring a, b;
bool_vector a_is_value, b_is_value;
string_instance a, b;
init_string_instance(L, a);
init_string_instance(R, b);

for (auto x : L) {
auto const& val = strval0(x);
auto len = val.length();
auto is_val = is_value(x);
a += val;
for (unsigned i = 0; i < len; ++i)
a_is_value.push_back(is_val);
}

for (auto y : R) {
auto const& val = strval0(y);
auto len = val.length();
auto is_val = is_value(y);
b += val;
for (unsigned i = 0; i < len; ++i)
b_is_value.push_back(is_val);
}

if (a == b)
return update(eq->get_arg(0), a) && update(eq->get_arg(1), b);
if (a.s == b.s)
return update(eq->get_arg(0), a.s) && update(eq->get_arg(1), b.s);

unsigned diff = edit_distance_with_updates(a, b);

unsigned diff = edit_distance_with_updates(a, a_is_value, b, b_is_value);
if (a.length() == 0) {
m_str_updates.push_back({ eq->get_arg(1), zstring(), 1 });
m_str_updates.push_back({ eq->get_arg(0), zstring(b[0]), 1});
m_str_updates.push_back({ eq->get_arg(0), zstring(b[b.length() - 1]), 1});
}
if (b.length() == 0) {
m_str_updates.push_back({ eq->get_arg(0), zstring(), 1 });
m_str_updates.push_back({ eq->get_arg(1), zstring(a[0]), 1 });
m_str_updates.push_back({ eq->get_arg(1), zstring(a[a.length() - 1]), 1 });
}

verbose_stream() << "diff \"" << a << "\" \"" << b << "\" diff " << diff << " updates " << m_string_updates.size() << "\n";
verbose_stream() << "diff \"" << a.s << "\" \"" << b.s << "\" diff " << diff << " updates " << m_string_updates.size() << "\n";
#if 1
for (auto const& [side, op, i, j] : m_string_updates) {
switch (op) {
case op_t::del:
if (side == side_t::left)
verbose_stream() << "del " << a[i] << " @ " << i << " left\n";
verbose_stream() << "del " << a.s[i] << " @ " << i << " left\n";
else
verbose_stream() << "del " << b[i] << " @ " << i << " right\n";
verbose_stream() << "del " << b.s[i] << " @ " << i << " right\n";
break;
case op_t::add:
if (side == side_t::left)
verbose_stream() << "add " << b[i] << " @ " << j << " left\n";
verbose_stream() << "add " << b.s[i] << " @ " << j << " left\n";
else
verbose_stream() << "add " << a[i] << " @ " << j << " right\n";
verbose_stream() << "add " << a.s[i] << " @ " << j << " right\n";
break;
case op_t::copy:
if (side == side_t::left)
verbose_stream() << "copy " << b[i] << " @ " << j << " left\n";
verbose_stream() << "copy " << b.s[i] << " @ " << j << " left\n";
else
verbose_stream() << "copy " << a[i] << " @ " << j << " right\n";
verbose_stream() << "copy " << a.s[i] << " @ " << j << " right\n";
break;
}
}
Expand Down Expand Up @@ -905,13 +901,13 @@ namespace sls {
else if (op == op_t::del && side == side_t::right)
delete_char(R, i);
else if (op == op_t::add && side == side_t::left)
add_char(L, j, b[i]);
add_char(L, j, b.s[i]);
else if (op == op_t::add && side == side_t::right)
add_char(R, j, a[i]);
add_char(R, j, a.s[i]);
else if (op == op_t::copy && side == side_t::left)
copy_char(L, j, b[i]);
copy_char(L, j, b.s[i]);
else if (op == op_t::copy && side == side_t::right)
copy_char(R, j, a[i]);
copy_char(R, j, a.s[i]);
}
verbose_stream() << "num updates " << m_str_updates.size() << "\n";
bool r = apply_update();
Expand Down Expand Up @@ -939,9 +935,6 @@ namespace sls {
if (a == b)
return update(eq->get_arg(0), a) && update(eq->get_arg(1), b);




unsigned diff = edit_distance(a, b);

//verbose_stream() << "solve: " << diff << " " << a << " " << b << "\n";
Expand Down
13 changes: 12 additions & 1 deletion src/ast/sls/sls_seq_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,20 @@ namespace sls {
op_t op;
unsigned i, j;
};
// A string together with per-character flags; the three flag vectors are
// indexed by character position in s.
struct string_instance {
    zstring s;                 // the string itself
    bool_vector is_value;      // per-character "belongs to a value" flag
    bool_vector prev_is_var;   // per-character "preceded by a variable" flag
    bool_vector next_is_var;   // per-character "followed by a variable" flag

    // A character may be added at position i when that position is not part
    // of a value, or when prev_is_var is set for it.
    bool can_add(unsigned i) const {
        if (!is_value[i])
            return true;
        return prev_is_var[i];
    }
};
svector<string_update> m_string_updates;
// Record one candidate edit (side, operation and the index pair i, j) at the
// end of m_string_updates.
void add_string_update(side_t side, op_t op, unsigned i, unsigned j) {
    string_update upd = { side, op, i, j };
    m_string_updates.push_back(upd);
}
unsigned edit_distance_with_updates(zstring const& a, bool_vector const& a_is_value, zstring const& b, bool_vector const& b_is_value);
void init_string_instance(ptr_vector<expr> const& es, string_instance& a);
unsigned edit_distance_with_updates(string_instance const& a, string_instance const& b);
unsigned edit_distance(zstring const& a, zstring const& b);
void add_edit_updates(ptr_vector<expr> const& w, zstring const& val, zstring const& val_other, uint_set const& chars);

Expand Down
1 change: 1 addition & 0 deletions src/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ add_executable(test-z3
simplex.cpp
simplifier.cpp
sls_test.cpp
sls_seq_plugin.cpp
small_object_allocator.cpp
smt2print_parse.cpp
smt_context.cpp
Expand Down
1 change: 1 addition & 0 deletions src/test/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,5 @@ int main(int argc, char ** argv) {
TST(euf_arith_plugin);
TST(sls_test);
TST(scoped_vector);
TST(sls_seq_plugin);
}
Loading

0 comments on commit b529a58

Please sign in to comment.