Skip to content

Commit

Permalink
Improve predicate operations
Browse files Browse the repository at this point in the history
  • Loading branch information
Adda0 committed Nov 9, 2022
1 parent cdf761a commit 7db2382
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 36 deletions.
166 changes: 132 additions & 34 deletions src/smt/theory_str_noodler/inclusion_graph_node.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
/**
* @brief Create basic representation of inclusion graph node
* @brief Create basic representation of an inclusion graph node.
*
* The inclusion graph node is represented as a formula consisting of equations and ineqautions.
* Each equation or inequation consists of a left and right side of the equation which hold a vector of basic equations
* The inclusion graph node is represented as a predicate, represention an equation, inequation or another predicate
* such as contains, etc.
* Each equation or inequation consists of a left and right side of the equation which hold a vector of basic equation
* terms.
* Each term is of one of the following types:
* - Literal,
* - Variable, or
* - CallPredicate.
* - operation such as IndexOf, Length, etc.
*/

#ifndef Z3_INCLUSION_GRAPH_NODE_H
#define Z3_INCLUSION_GRAPH_NODE_H

Expand All @@ -19,30 +21,36 @@
#include <cassert>
#include <unordered_map>

namespace smt {
namespace noodler {
namespace smt::noodler {
enum struct PredicateType {
Equation,
Inequation,
Substring,
Contains,
Length,
IndexOf,
// TODO: Add additional predicate types.
};

enum struct BasicTermType {
Literal,
Variable
Variable,
Length,
Substring,
IndexOf,
// TODO: Add additional basic term types.
};

class BasicTerm {
public:
explicit BasicTerm(BasicTermType type): type(type) {}
BasicTerm(BasicTermType type, std::string_view name): type(type), name(name) {}

[[nodiscard]] BasicTermType get_type() const { return type; }
[[nodiscard]] bool is_variable() const { return type == BasicTermType::Variable; }
[[nodiscard]] bool is_literal() const { return type == BasicTermType::Literal; }
[[nodiscard]] bool is(BasicTermType term_type) const { return type == term_type; }

[[nodiscard]] std::string get_name() const { return name; }
void set_name(std::string_view new_name) { name = new_name; }

[[nodiscard]] bool equals(const BasicTerm& other) const {
return type == other.get_type() && name == other.get_name();
}
Expand All @@ -69,26 +77,50 @@ namespace noodler {
}
}

PredicateType get_type() { return type; }
[[nodiscard]] bool is_equation() const { return type == PredicateType::Equation; }
[[nodiscard]] bool is_inequation() const { return type == PredicateType::Inequation; }
[[nodiscard]] bool is_predicate() const { return !is_equation() && !is_inequation(); }
[[nodiscard]] bool is_eq_or_ineq() const { return is_equation() || is_inequation(); }
[[nodiscard]] bool is_predicate() const { return !is_eq_or_ineq(); }
[[nodiscard]] bool is(const PredicateType predicate_type) const { return predicate_type == this->type; }

PredicateType get_type() { return type; }
void set_type(PredicateType new_type) { type = new_type; }
std::vector<BasicTerm>& get_left_side() {
assert(is_eq_or_ineq());
return params[0];
}

std::vector<BasicTerm>& get_left() {
assert(is_equation() || is_inequation());
[[nodiscard]] const std::vector<BasicTerm>& get_left_side() const {
assert(is_eq_or_ineq());
return params[0];
}
std::vector<BasicTerm>& get_right() {
assert(is_equation() || is_inequation());

std::vector<BasicTerm>& get_right_side() {
assert(is_eq_or_ineq());
return params[1];
}

// TODO: Should we implement get_vars() and get_side_vars()?
[[nodiscard]] const std::vector<BasicTerm>& get_right_side() const {
assert(is_eq_or_ineq());
return params[1];
}

std::vector<BasicTerm>& get_side(const EquationSideType side) {
assert(is_equation() || is_inequation());
assert(is_eq_or_ineq());
switch (side) {
case EquationSideType::Left:
return params[0];
break;
case EquationSideType::Right:
return params[1];
break;
default:
throw std::runtime_error("unhandled equation side type");
break;
}
}

[[nodiscard]] const std::vector<BasicTerm>& get_side(const EquationSideType side) const {
assert(is_eq_or_ineq());
switch (side) {
case EquationSideType::Left:
return params[0];
Expand All @@ -102,15 +134,83 @@ namespace noodler {
}
}

bool multiple_occurrence_of_term_on_side(const EquationSideType side) {
assert(is_equation() || is_inequation());
/**
* Get unique variables on both sides of an (in)equation.
* @return Variables in the (in)equation.
*/
[[nodiscard]] std::vector<BasicTerm> get_vars() const {
assert(is_eq_or_ineq());
std::vector<BasicTerm> vars;
for (const auto& side: params) {
for (const auto &term: side) {
if (term.is_variable()) {
bool found{false};
for (const auto &var: vars) {
if (var == term) {
found = true;
break;
}
}
if (!found) { vars.push_back(term); }
}
}
}
return vars;
}

/**
* Get unique variables on a single @p side of an (in)equation.
* @param[in] side (In)Equation side to get variables from.
* @return Variables in the (in)equation on specified @p side.
*/
[[nodiscard]] std::vector<BasicTerm> get_side_vars(const EquationSideType side) const {
assert(is_eq_or_ineq());
std::vector<BasicTerm> vars;
std::vector<BasicTerm> side_terms;
switch (side) {
case EquationSideType::Left:
side_terms = get_left_side();
break;
case EquationSideType::Right:
side_terms = get_right_side();
break;
default:
throw std::runtime_error("unhandled equation side_terms type");
break;
}

for (const auto &term: side_terms) {
if (term.is_variable()) {
bool found{false};
for (const auto &var: vars) {
if (var == term) {
found = true;
break;
}
}
if (!found) { vars.push_back(term); }
}
}
return vars;
}

/**
* Decide whether the @p side contains multiple occurrences of a single variable (with a same name).
* @param side Side to check.
* @return True if there are multiple occurrences of a single variable. False otherwise.
*/
[[nodiscard]] bool mult_occurr_var_side(const EquationSideType side) const {
assert(is_eq_or_ineq());
const auto terms_begin{ get_side(side).cbegin() };
const auto terms_end{ get_side(side).cend() };
for (auto term_iter{ terms_begin }; term_iter < terms_end; ++term_iter) {
for (auto term_iter_following{ term_iter + 1}; term_iter_following < terms_end; ++term_iter_following) {
if (*term_iter == *term_iter_following) {
return true;
// TODO: How to handle calls of predicates?
if (term_iter->is_variable()) {
for (auto term_iter_following{ term_iter + 1}; term_iter_following < terms_end;
++term_iter_following) {
if (*term_iter == *term_iter_following) {
return true;
// TODO: How to handle calls of predicates?
}
}
}
}
Expand All @@ -127,10 +227,10 @@ namespace noodler {
throw std::runtime_error("unimplemented");
}

friend bool operator==(const Predicate& lhs, const Predicate& rhs) {
if (lhs.type == rhs.type) {
if (lhs.is_equation() || lhs.is_inequation()) {
return lhs.params[0] == rhs.params[0] && lhs.params[1] == rhs.params[1];
[[nodiscard]] bool equals(const Predicate& other) const {
if (type == other.type) {
if (is_eq_or_ineq()) {
return params[0] == other.params[0] && params[1] == other.params[1];
}
return true;
}
Expand All @@ -142,23 +242,21 @@ namespace noodler {
private:
PredicateType type;
std::vector<std::vector<BasicTerm>> params;

}; // Class Predicate.

using FormulaPredicates = std::vector<Predicate>;
bool operator==(const Predicate& lhs, const Predicate& rhs) { return lhs.equals(rhs); }

class Formula {
Formula(): predicates() {}

FormulaPredicates& get_predicates() { return predicates; }
std::vector<Predicate>& get_predicates() { return predicates; }

// TODO: Use std::move for both add functions?
void add_predicate(const Predicate& predicate) { predicates.push_back(predicate); }

private:
FormulaPredicates predicates;
std::vector<Predicate> predicates;
}; // Class Formula.
} // Namespace noodler.
} // Namespace smt.

#endif //Z3_INCLUSION_GRAPH_NODE_H
9 changes: 7 additions & 2 deletions src/test/noodler/inclusion-graph-node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ TEST_CASE( "Inclusion graph node", "[noodler]" ) {
auto term{ BasicTerm(BasicTermType::Variable, term_name) };
CHECK(term.get_name() == term_name);

auto& left{ predicate.get_left() };
auto& left{predicate.get_left_side() };
left.emplace_back(term);
left.emplace_back( BasicTermType::Literal, "lit" );
left.emplace_back(term);
CHECK(predicate.multiple_occurrence_of_term_on_side(Predicate::EquationSideType::Left));
CHECK(predicate.mult_occurr_var_side(Predicate::EquationSideType::Left));

CHECK(predicate.is_eq_or_ineq());
CHECK(predicate.get_side_vars(Predicate::EquationSideType::Left) == std::vector<BasicTerm>{ term });
CHECK(predicate.get_side_vars(Predicate::EquationSideType::Right).empty());
}

0 comments on commit 7db2382

Please sign in to comment.