From fd78994ddd33251bbb8a41a3d046352e94bbe3c6 Mon Sep 17 00:00:00 2001 From: Dillon Date: Mon, 2 Dec 2024 14:37:40 -0800 Subject: [PATCH 1/2] Catch overflow when applying rewrite rules --- builder/rewrite.h | 78 ++++++++++++++++++----------- builder/test/simplify/rule_tester.h | 16 ++++-- 2 files changed, 60 insertions(+), 34 deletions(-) diff --git a/builder/rewrite.h b/builder/rewrite.h index d07cfc21..a4e4452d 100644 --- a/builder/rewrite.h +++ b/builder/rewrite.h @@ -42,7 +42,7 @@ template SLINKY_UNIQUE bool match(index_t p, expr_ref x, match_context&) { return is_constant(x, p); } -SLINKY_UNIQUE index_t substitute(index_t p, const match_context&) { return p; } +SLINKY_UNIQUE index_t substitute(index_t p, const match_context&, bool&) { return p; } template struct pattern_info { @@ -117,7 +117,7 @@ SLINKY_UNIQUE bool match(const pattern_wildcard& p, expr_ref x, match_context } template -SLINKY_UNIQUE expr_ref substitute(const pattern_wildcard& p, const match_context& ctx) { +SLINKY_UNIQUE expr_ref substitute(const pattern_wildcard&, const match_context& ctx, bool&) { return ctx.vars[N]; } @@ -150,7 +150,7 @@ SLINKY_UNIQUE bool match(const pattern_constant& p, expr_ref x, match_context } template -SLINKY_UNIQUE index_t substitute(const pattern_constant& p, const match_context& ctx) { +SLINKY_UNIQUE index_t substitute(const pattern_constant&, const match_context& ctx, bool&) { return ctx.constants[N]; } @@ -211,9 +211,24 @@ SLINKY_UNIQUE bool match( return match_binary(p, x.a.e, x.b.e, ctx); } +template +SLINKY_UNIQUE expr substitute_binary(expr a, expr b, bool&) { + return make_binary(std::move(a), std::move(b)); +} + +template +SLINKY_UNIQUE index_t substitute_binary(index_t a, index_t b, bool& overflowed) { + if (binary_overflows(a, b)) { + overflowed = true; + return 0; + } else { + return make_binary(a, b); + } +} + template -SLINKY_UNIQUE auto substitute(const pattern_binary& p, const match_context& ctx) { - return make_binary(substitute(p.a, ctx), substitute(p.b, ctx)); +SLINKY_UNIQUE auto substitute(const pattern_binary& p, const match_context& ctx, bool& overflowed) { + return substitute_binary(substitute(p.a, ctx, overflowed), substitute(p.b, ctx, overflowed), overflowed); } template @@ -278,8 +293,8 @@ template <> inline index_t make_unary(index_t a) { return a == 0 ? // clang-format on template -SLINKY_UNIQUE auto substitute(const pattern_unary& p, const match_context& ctx) { - return make_unary(substitute(p.a, ctx)); +SLINKY_UNIQUE auto substitute(const pattern_unary& p, const match_context& ctx, bool& overflowed) { + return make_unary(substitute(p.a, ctx, overflowed)); } template @@ -314,8 +329,9 @@ SLINKY_UNIQUE bool match(const pattern_select& p, } template -SLINKY_UNIQUE expr substitute(const pattern_select& p, const match_context& ctx) { - return select::make(substitute(p.c, ctx), substitute(p.t, ctx), substitute(p.f, ctx)); +SLINKY_UNIQUE expr substitute(const pattern_select& p, const match_context& ctx, bool& overflowed) { + return select::make( + substitute(p.c, ctx, overflowed), substitute(p.t, ctx, overflowed), substitute(p.f, ctx, overflowed)); } template @@ -359,14 +375,15 @@ SLINKY_UNIQUE bool match(const pattern_call& p, expr_ref x, match_conte return false; } -SLINKY_UNIQUE expr substitute(const pattern_call<>& p, const match_context& ctx) { return call::make(p.fn, {}); } +SLINKY_UNIQUE expr substitute(const pattern_call<>& p, const match_context& ctx, bool&) { return call::make(p.fn, {}); } template -SLINKY_UNIQUE expr substitute(const pattern_call& p, const match_context& ctx) { - return call::make(p.fn, {substitute(std::get<0>(p.args), ctx)}); +SLINKY_UNIQUE expr substitute(const pattern_call& p, const match_context& ctx, bool& overflowed) { + return call::make(p.fn, {substitute(std::get<0>(p.args), ctx, overflowed)}); } template -SLINKY_UNIQUE expr substitute(const pattern_call& p, const match_context& ctx) { - return call::make(p.fn, {substitute(std::get<0>(p.args), ctx), substitute(std::get<1>(p.args), ctx)}); +SLINKY_UNIQUE expr substitute(const pattern_call& p, const match_context& ctx, bool& overflowed) { + return call::make( + p.fn, {substitute(std::get<0>(p.args), ctx, overflowed), substitute(std::get<1>(p.args), ctx, overflowed)}); } SLINKY_UNIQUE std::ostream& operator<<(std::ostream& os, const pattern_call<>& p) { return os << p.fn << "()"; } @@ -389,8 +406,8 @@ class replacement_predicate { }; template -SLINKY_UNIQUE bool substitute(const replacement_predicate& r, const match_context& ctx) { - return r.fn(substitute(r.a, ctx)); +SLINKY_UNIQUE bool substitute(const replacement_predicate& r, const match_context& ctx, bool& overflowed) { + return r.fn(substitute(r.a, ctx, overflowed)); } template @@ -413,8 +430,8 @@ class replacement_eval { }; template -SLINKY_UNIQUE index_t substitute(const replacement_eval& r, const match_context& ctx) { - return substitute(r.a, ctx); +SLINKY_UNIQUE index_t substitute(const replacement_eval& r, const match_context& ctx, bool& overflowed) { + return substitute(r.a, ctx, overflowed); } template @@ -432,8 +449,8 @@ class replacement_boolean { }; template -SLINKY_UNIQUE auto substitute(const replacement_boolean& r, const match_context& ctx) { - return boolean(substitute(r.a, ctx)); +SLINKY_UNIQUE auto substitute(const replacement_boolean& r, const match_context& ctx, bool& overflowed) { + return boolean(substitute(r.a, ctx, overflowed)); } template @@ -563,7 +580,9 @@ SLINKY_UNIQUE bool match_any_variant(Pattern p, const Target& x, match_context& template SLINKY_UNIQUE bool match(match_context& ctx, Pattern p, const Target& x, Predicate pr) { - return match_any_variant(p, x, ctx) && substitute(pr, ctx); + if (!match_any_variant(p, x, ctx)) return false; + bool overflowed = false; + return substitute(pr, ctx, overflowed) && !overflowed; } template @@ -584,8 +603,9 @@ class base_rewriter { SLINKY_ALWAYS_INLINE bool find_replacement(const match_context& ctx, Replacement r) { static_assert(pattern_info::is_canonical); static_assert(!pattern_info::is_boolean || pattern_info::is_boolean); - result = substitute(r, ctx); - return true; + bool overflowed = false; + result = substitute(r, ctx, overflowed); + return !overflowed; } template @@ -594,13 +614,13 @@ class base_rewriter { static_assert(pattern_info::is_canonical); static_assert(!pattern_info::is_boolean || pattern_info::is_boolean); - if (substitute(pr, ctx)) { - result = substitute(r, ctx); - return true; - } else { - // Try the next replacement - return find_replacement(ctx, r_pr...); + bool overflowed = false; + if (substitute(pr, ctx, overflowed) && !overflowed) { + result = substitute(r, ctx, overflowed); + if (!overflowed) return true; } + // Try the next replacement + return find_replacement(ctx, r_pr...); } public: diff --git a/builder/test/simplify/rule_tester.h b/builder/test/simplify/rule_tester.h index d08d87aa..a3d98c30 100644 --- a/builder/test/simplify/rule_tester.h +++ b/builder/test/simplify/rule_tester.h @@ -81,8 +81,11 @@ class rule_tester { std::stringstream rule_str; rule_str << p << " -> " << r; - expr pattern = expr(substitute(p, m)); - expr replacement = expr(substitute(r, m)); + bool overflowed = false; + expr pattern = expr(substitute(p, m, overflowed)); + assert(!overflowed); + expr replacement = expr(substitute(r, m, overflowed)); + assert(!overflowed); // Make sure the expressions have the same value when evaluated. test_expr(pattern, replacement, rule_str.str()); @@ -101,9 +104,12 @@ class rule_tester { // expression that the rule applies to. for (int test = 0; test < 100000; ++test) { init_match_context(); - if (substitute(pr, m)) { - expr pattern = expr(substitute(p, m)); - expr replacement = expr(substitute(r, m)); + bool overflowed = false; + if (substitute(pr, m, overflowed) && !overflowed) { + expr pattern = expr(substitute(p, m, overflowed)); + assert(!overflowed); + expr replacement = expr(substitute(r, m, overflowed)); + assert(!overflowed); // Make sure the expressions have the same value when evaluated. test_expr(pattern, replacement, rule_str.str()); From 332b91e44926f8c5d26dd5149c057bcd21556f33 Mon Sep 17 00:00:00 2001 From: Dillon Date: Mon, 2 Dec 2024 14:41:27 -0800 Subject: [PATCH 2/2] Don't constant evaluate if it overflows --- runtime/evaluate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/evaluate.cc b/runtime/evaluate.cc index 7baeee5b..107d658b 100644 --- a/runtime/evaluate.cc +++ b/runtime/evaluate.cc @@ -775,7 +775,7 @@ class constant_evaluator : public expr_visitor { void visit_binary(const T* op) { std::optional a = eval(op->a); std::optional b = eval(op->b); - if (a && b) { + if (a && b && !binary_overflows(*a, *b)) { result = make_binary(*a, *b); } else { result = std::nullopt;