diff --git a/rosette/solver/smt/enc-lit.rkt b/rosette/solver/smt/enc-lit.rkt new file mode 100644 index 00000000..fa444c58 --- /dev/null +++ b/rosette/solver/smt/enc-lit.rkt @@ -0,0 +1,35 @@ +#lang racket/base + +(provide current-literal-encoder + enc-real + enc-integer) + +(require racket/match + racket/format + (prefix-in $ "smtlib2.rkt") + (only-in "../../base/core/bitvector.rkt" + bitvector? bv bitvector-size + @bveq @bvslt @bvsle @bvult @bvule + @bvnot @bvor @bvand @bvxor @bvshl @bvlshr @bvashr + @bvneg @bvadd @bvmul @bvudiv @bvsdiv @bvurem @bvsrem @bvsmod + @concat @extract @zero-extend @sign-extend + @integer->bitvector @bitvector->integer @bitvector->natural)) + +(define (default-literal-encoder v) + (match v + [#t $true] + [#f $false] + [(? integer?) (enc-integer v)] + [(? real?) (enc-real v)] + [(bv lit t) ($bv lit (bitvector-size t))] + [_ (error 'enc "expected a boolean?, integer?, real?, or bitvector?, given ~a" v)])) + +(define current-literal-encoder + (make-parameter default-literal-encoder)) + +(define-syntax-rule (enc-real v) + (if (exact? v) ($/ (numerator v) (denominator v)) (string->symbol (~r v)))) + +(define-syntax-rule (enc-integer v) + (let ([v* (inexact->exact v)]) + (if (< v* 0) ($- (abs v*)) v*))) diff --git a/rosette/solver/smt/enc.rkt b/rosette/solver/smt/enc.rkt index 7fe55c12..bdcf72cd 100644 --- a/rosette/solver/smt/enc.rkt +++ b/rosette/solver/smt/enc.rkt @@ -1,7 +1,8 @@ #lang racket -(require "env.rkt" - (prefix-in $ "smtlib2.rkt") +(require "env.rkt" + "enc-lit.rkt" + (prefix-in $ "smtlib2.rkt") (only-in "../../base/core/term.rkt" expression expression? constant? term? get-type @app) (only-in "../../base/core/polymorphic.rkt" ite ite* =? guarded-test guarded-value) (only-in "../../base/core/distinct.rkt" @distinct?) @@ -74,19 +75,7 @@ [_ (error 'enc "cannot encode ~a to SMT" v)])) (define (enc-lit v env quantified) - (match v - [#t $true] - [#f $false] - [(? integer?) (enc-integer v)] - [(? real?) (enc-real v)] - [(bv lit t) ($bv lit (bitvector-size t))] - [_ (error 'enc "expected a boolean?, integer?, real?, or bitvector?, given ~a" v)])) - -(define-syntax-rule (enc-real v) - (if (exact? v) ($/ (numerator v) (denominator v)) (string->symbol (~r v)))) -(define-syntax-rule (enc-integer v) - (let ([v* (inexact->exact v)]) - (if (< v* 0) ($- (abs v*)) v*))) + ((current-literal-encoder) v)) (define-syntax define-encoder (syntax-rules () diff --git a/test/base/real.rkt b/test/base/real.rkt index 6df534e9..51837f2d 100644 --- a/test/base/real.rkt +++ b/test/base/real.rkt @@ -1,6 +1,7 @@ #lang racket -(require rackunit rackunit/text-ui "common.rkt" "solver.rkt" +(require rackunit rackunit/text-ui "common.rkt" "solver.rkt" + rosette/solver/smt/enc-lit rosette/solver/smt/z3 rosette/solver/solution rosette/lib/roseunit @@ -256,7 +257,19 @@ (parameterize ([solver (z3 #:options (hash ':pp.decimal 'true ':pp.decimal-precision 15))]) (check-= (abs ((solve (@= (@* xr xr) 2)) xr)) (abs (sqrt 2)) 1e-15) (solver-shutdown (solver)))) - + +(define (check-qf-nra) + (parameterize ([solver (z3 #:logic 'QF_NRA)] + [current-literal-encoder + (let ([encode (current-literal-encoder)]) + (λ (x) + (cond + [(real? x) (enc-real x)] + [else (encode x)])))]) + (define-symbolic x @real?) + (check-pred sat? (solve (@= -2.0 x))) + (solver-shutdown (solver)))) + (define (check-division-simplifications div x y z [epsilon 0]) (check-valid? (div 0 x) 0) (check-valid? (div x 1) x) @@ -563,6 +576,7 @@ (check-*-simplifications xr yr zr) (check-*-real-simplifications) (check-irrationals) + (check-qf-nra) (check-semantics @* xi yi zi) (check-semantics @* xr yr zr)))