Skip to content

Commit

Permalink
Minor simplify improvements and fixes (#484)
Browse files Browse the repository at this point in the history
Fixes a few possible future bugs, and speeds up the simplifier on random
expressions by ~10%.
  • Loading branch information
dsharlet authored Nov 5, 2024
1 parent 854ae8c commit c49d090
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 45 deletions.
16 changes: 9 additions & 7 deletions builder/node_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,22 @@ class node_mutator : public expr_visitor, public stmt_visitor {
const stmt& mutated_stmt() const { return s_; }

virtual expr mutate(const expr& e) {
assert(!e_.defined());
if (e.defined()) {
e.accept(this);
return std::move(e_);
} else {
return expr();
switch (e.type()) {
case expr_node_type::variable: visit(static_cast<const variable*>(e.get())); break;
case expr_node_type::constant: visit(static_cast<const constant*>(e.get())); break;
default: e.accept(this);
}
}
return std::move(e_);
}
virtual stmt mutate(const stmt& s) {
assert(!s_.defined());
if (s.defined()) {
s.accept(this);
return std::move(s_);
} else {
return stmt();
}
return std::move(s_);
}

virtual interval_expr mutate(const interval_expr& x) {
Expand Down
75 changes: 40 additions & 35 deletions builder/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ expr strip_boolean(expr x) {
if (const not_equal* ne = x.as<not_equal>()) {
if (is_zero(ne->b)) {
return strip_boolean(ne->a);
} else if (is_zero(ne->a)) {
return strip_boolean(ne->b);
}
}
// This should be canonicalized to the RHS.
assert(!is_zero(ne->a));
}
return x;
}
Expand Down Expand Up @@ -258,23 +258,27 @@ class constant_adder : public node_mutator {
template <typename T>
void visit_min_max(const T* op) {
expr a = mutate(op->a);
expr b = mutate(op->b);
if (a.defined() && b.defined()) {
set_result(T::make(std::move(a), std::move(b)));
} else {
set_result(expr());
if (a.defined()) {
expr b = mutate(op->b);
if (b.defined()) {
set_result(T::make(std::move(a), std::move(b)));
return;
}
}
set_result(expr());
}
void visit(const class min* op) override { visit_min_max(op); }
void visit(const class max* op) override { visit_min_max(op); }
void visit(const class select* op) override {
expr t = mutate(op->true_value);
expr f = mutate(op->false_value);
if (t.defined() && f.defined()) {
set_result(select::make(op->condition, std::move(t), std::move(f)));
} else {
set_result(expr());
if (t.defined()) {
expr f = mutate(op->false_value);
if (f.defined()) {
set_result(select::make(op->condition, std::move(t), std::move(f)));
return;
}
}
set_result(expr());
}

void visit(const mul* op) override {
Expand Down Expand Up @@ -385,14 +389,12 @@ class simplifier : public node_mutator {
set_result(expr(e), std::move(info));
}
void set_result(stmt s) {
assert(!result_info.bounds.min.defined() && !result_info.bounds.max.defined());
result_info = {interval_expr(), alignment_type()};
node_mutator::set_result(std::move(s));
}
void set_result(const base_stmt_node* s) { set_result(stmt(s)); }
// Dummy for template code.
void set_result(stmt s, expr_info) { set_result(std::move(s)); }
void set_result(const base_stmt_node* s, expr_info) { set_result(stmt(s)); }
void set_result(stmt s, const expr_info&) { set_result(std::move(s)); }
void set_result(const base_stmt_node* s, const expr_info&) { set_result(stmt(s)); }

public:
simplifier() {}
Expand Down Expand Up @@ -436,7 +438,9 @@ class simplifier : public node_mutator {
// this.
expr mutate_boolean(const expr& e, expr_info* info) {
expr result = strip_boolean(mutate(e, info));
if (info) info->bounds = bounds_of(static_cast<const not_equal*>(nullptr), std::move(info->bounds), point(0));
if (info && !is_boolean(result)) {
info->bounds = bounds_of(static_cast<const not_equal*>(nullptr), std::move(info->bounds), point(0));
}
return result;
}

Expand Down Expand Up @@ -569,10 +573,11 @@ class simplifier : public node_mutator {
std::optional<index_t> ec = evaluate_constant(e);
if (ec) return *ec != 0;

// e is constant true if we know it has a bounds that don't include zero.
expr predicate = logical_or::make(less::make(0, constant_lower_bound(e)), less::make(constant_upper_bound(e), 0));
std::optional<index_t> result = evaluate_constant(predicate);
return result && *result != 0;
// e is constant true if we know it has bounds that don't include zero.
std::optional<index_t> a = evaluate_constant(constant_lower_bound(e));
if (a && *a > 0) return true;
std::optional<index_t> b = evaluate_constant(constant_upper_bound(e));
return b && *b < 0;
}

static bool prove_constant_false(const expr& e) {
Expand All @@ -582,9 +587,11 @@ class simplifier : public node_mutator {
if (ec) return *ec == 0;

// e is constant false if we know its bounds are [0, 0].
expr predicate = logical_or::make(constant_lower_bound(e), constant_upper_bound(e));
std::optional<index_t> result = evaluate_constant(predicate);
return result && *result == 0;
std::optional<index_t> a = evaluate_constant(constant_lower_bound(e));
if (!a) return false;
std::optional<index_t> b = evaluate_constant(constant_upper_bound(e));
if (!b) return false;
return *a == 0 && *b == 0;
}

std::optional<bool> attempt_to_prove(const expr& e) {
Expand Down Expand Up @@ -1308,7 +1315,8 @@ class simplifier : public node_mutator {
}

template <typename T>
bool buffer_changed(const T* op, const buffer_info& info) {
static bool buffer_changed(const T* op, const buffer_info& info) {
if (op->dims.size() != info.dims.size()) return true;
if (!info.elem_size.same_as(op->elem_size)) return true;
for (std::size_t d = 0; d < op->dims.size(); ++d) {
if (!info.dims[d].same_as(op->dims[d])) return true;
Expand Down Expand Up @@ -2120,16 +2128,15 @@ class constant_bound : public node_mutator {

template <typename T>
void visit_less(const T* op) {
expr a, b;
// This is a constant version of that found in bounds_of_less:
// - For a lower bound, we want to know if this can ever be false, so we want the upper bound of the lhs and the
// lower bound of the rhs.
// - For an upper bound, we want to know if this can ever be true, so we want the lower bound of the lhs and the
// upper bound of the rhs.
sign = -sign;
a = mutate(op->a);
expr a = mutate(op->a);
sign = -sign;
b = mutate(op->b);
expr b = mutate(op->b);

const index_t* ca = as_constant(a);
const index_t* cb = as_constant(b);
Expand All @@ -2149,15 +2156,14 @@ class constant_bound : public node_mutator {
// We can recursively mutate if:
// - We're looking for the upper bound of &&, because if either operand is definitely false, the result is false.
// - We're looking for the lower bound of ||, because if either operand is definitely true, the result is true.
// Whenever we mutate an expression implicitly converted to bool, we need to force it to have the value 0 or 1.
expr a = recurse ? mutate(boolean(op->a)) : op->a;
expr b = recurse ? mutate(boolean(op->b)) : op->b;
expr a = recurse ? mutate(op->a) : op->a;
expr b = recurse ? mutate(op->b) : op->b;

const index_t* ca = as_constant(a);
const index_t* cb = as_constant(b);

if (ca && cb) {
set_result(make_binary<T>(*ca, *cb));
set_result(make_binary<T>(*ca != 0, *cb != 0));
} else if (sign < 0) {
set_result(expr(0));
} else {
Expand All @@ -2169,8 +2175,7 @@ class constant_bound : public node_mutator {

void visit(const logical_not* op) override {
sign = -sign;
// Whenever we mutate an expression implicitly converted to bool, we need to force it to have the value 0 or 1.
expr a = mutate(boolean(op->a));
expr a = mutate(op->a);
sign = -sign;
const index_t* ca = as_constant(a);
if (ca) {
Expand Down
2 changes: 1 addition & 1 deletion builder/simplify_exprs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ expr simplify(const call* op, intrinsic fn, std::vector<expr> args) {
e = call::make(fn, std::move(args));
}

if (can_evaluate(fn) && constant) {
if (constant && can_evaluate(fn)) {
return evaluate(e);
}

Expand Down
2 changes: 1 addition & 1 deletion builder/substitute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class matcher : public expr_visitor, public stmt_visitor {
void visit(const base_expr_node* op) {
switch (op->type) {
case expr_node_type::variable: visit(reinterpret_cast<const variable*>(op)); return;
case expr_node_type::add: visit(reinterpret_cast<const add*>(op)); return;
case expr_node_type::constant: visit(reinterpret_cast<const constant*>(op)); return;
case expr_node_type::min: visit(reinterpret_cast<const class min*>(op)); return;
case expr_node_type::max: visit(reinterpret_cast<const class max*>(op)); return;
default: op->accept(this);
Expand Down
5 changes: 4 additions & 1 deletion runtime/evaluate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ namespace slinky {

bool can_evaluate(intrinsic fn) {
switch (fn) {
case intrinsic::abs: return true;
case intrinsic::abs:
case intrinsic::and_then:
case intrinsic::or_else:
return true;
default: return false;
}
}
Expand Down

0 comments on commit c49d090

Please sign in to comment.