diff --git a/.github/workflows/Windows.yml b/.github/workflows/Windows.yml new file mode 100644 index 00000000000..74eb04d15b5 --- /dev/null +++ b/.github/workflows/Windows.yml @@ -0,0 +1,64 @@ +name: Windows + +on: + push: + branches: [ master ] + +jobs: + build: + strategy: + matrix: + arch : [x86,x64,amd64_arm64] + include: + - arch : x86 + - arch : amd64_arm64 + - arch : x64 + cmd1 : 'julia -e "using Pkg; Pkg.add(PackageSpec(name=\"libcxxwrap_julia_jll\"))"' + cmd2 : 'julia -e "using libcxxwrap_julia_jll; print(dirname(libcxxwrap_julia_jll.libcxxwrap_julia_path))" > tmp.env' + cmd3 : 'set /P JlCxxDir=> $GITHUB_OUTPUT - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: coverage-${{steps.date.outputs.date}} path: ${{github.workspace}}/coverage.html retention-days: 4 - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: coverage-details-${{steps.date.outputs.date}} path: ${{env.COV_DETAILS_PATH}} diff --git a/.github/workflows/cross-build.yml b/.github/workflows/cross-build.yml index 83fae7a5d14..8745215d243 100644 --- a/.github/workflows/cross-build.yml +++ b/.github/workflows/cross-build.yml @@ -19,7 +19,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install cross build tools run: apt update && apt install -y ninja-build cmake python3 g++-11-${{ matrix.arch }}-linux-gnu diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index d9baccd39a7..5bb7d2cad03 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -15,10 +15,10 @@ jobs: steps: - name: Check out the repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Log in to GitHub Docker registry - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ secrets.DOCKER_USERNAME }} @@ -29,7 +29,7 @@ jobs: # ------- - name: Extract metadata (tags, labels) for Bare Z3 Docker Image id: meta - uses: docker/metadata-action@v4 + uses: docker/metadata-action@v5 with: images: ghcr.io/z3prover/z3 flavor: | @@ -41,7 +41,7 @@ jobs: type=edge type=sha,prefix=ubuntu-20.04-bare-z3-sha- - name: Build and push Bare Z3 Docker Image - uses: docker/build-push-action@v4.0.0 + uses: docker/build-push-action@v5.1.0 with: context: . push: true diff --git a/.github/workflows/msvc-static-build.yml b/.github/workflows/msvc-static-build.yml index 2db2221610b..b329d5abc34 100644 --- a/.github/workflows/msvc-static-build.yml +++ b/.github/workflows/msvc-static-build.yml @@ -14,7 +14,7 @@ jobs: BUILD_TYPE: Release steps: - name: Checkout Repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Build run: | diff --git a/.github/workflows/wasm-release.yml b/.github/workflows/wasm-release.yml index de15a242c1a..ed5dacec810 100644 --- a/.github/workflows/wasm-release.yml +++ b/.github/workflows/wasm-release.yml @@ -21,10 +21,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup node - uses: actions/setup-node@v3 + uses: actions/setup-node@v4 with: node-version: "lts/*" registry-url: "https://registry.npmjs.org" @@ -36,7 +36,7 @@ jobs: cp ../../../LICENSE.txt . - name: Setup emscripten - uses: mymindstorm/setup-emsdk@v12 + uses: mymindstorm/setup-emsdk@v14 with: no-install: true version: ${{env.EM_VERSION}} diff --git a/.github/workflows/wasm.yml b/.github/workflows/wasm.yml index e8ac095e594..8e157f5a465 100644 --- a/.github/workflows/wasm.yml +++ b/.github/workflows/wasm.yml @@ -21,15 +21,15 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup node - uses: actions/setup-node@v3 + uses: actions/setup-node@v4 with: node-version: "lts/*" - name: Setup emscripten - uses: mymindstorm/setup-emsdk@v12 + uses: mymindstorm/setup-emsdk@v14 with: no-install: true version: ${{env.EM_VERSION}} diff --git a/.github/workflows/wip.yml b/.github/workflows/wip.yml index ffea6225cec..5ed29a4571d 100644 --- a/.github/workflows/wip.yml +++ b/.github/workflows/wip.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Configure CMake run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} diff --git a/CMakeLists.txt b/CMakeLists.txt index 8ee39827ef7..f6877048ff8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,8 @@ # Enforce some CMake policies -cmake_minimum_required(VERSION 3.4) +cmake_minimum_required(VERSION 3.16) set(CMAKE_USER_MAKE_RULES_OVERRIDE_CXX "${CMAKE_CURRENT_SOURCE_DIR}/cmake/cxx_compiler_flags_overrides.cmake") -project(Z3 VERSION 4.12.2.0 LANGUAGES CXX) +project(Z3 VERSION 4.13.0.0 LANGUAGES CXX) ################################################################################ # Project version @@ -41,19 +41,22 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") ################################################################################ include(${PROJECT_SOURCE_DIR}/cmake/git_utils.cmake) macro(disable_git_describe) - message(WARNING "Disabling Z3_INCLUDE_GIT_DESCRIBE") - set(Z3_INCLUDE_GIT_DESCRIBE OFF CACHE BOOL "Include git describe output in version output" FORCE) + if(Z3_INCLUDE_GIT_DESCRIBE) + message(WARNING "Disabling Z3_INCLUDE_GIT_DESCRIBE") + set(Z3_INCLUDE_GIT_DESCRIBE OFF CACHE BOOL "Include git describe output in version output" FORCE) + endif() endmacro() macro(disable_git_hash) - message(WARNING "Disabling Z3_INCLUDE_GIT_HASH") - set(Z3_INCLUDE_GIT_HASH OFF CACHE BOOL "Include git hash in version output" FORCE) - unset(Z3GITHASH) # Used in configure_file() + if(Z3_INCLUDE_GIT_HASH) + message(WARNING "Disabling Z3_INCLUDE_GIT_HASH") + set(Z3_INCLUDE_GIT_HASH OFF CACHE BOOL "Include git hash in version output" FORCE) + endif() endmacro() option(Z3_INCLUDE_GIT_HASH "Include git hash in version output" ON) option(Z3_INCLUDE_GIT_DESCRIBE "Include git describe output in version output" ON) set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git") -if (EXISTS "${GIT_DIR}") +if ((Z3_INCLUDE_GIT_HASH OR Z3_INCLUDE_GIT_HASH) AND EXISTS "${GIT_DIR}") # Try to make CMake configure depend on the current git HEAD so that # a re-configure is triggered when the HEAD changes. add_git_dir_dependency("${GIT_DIR}" ADD_GIT_DEP_SUCCESS) @@ -63,13 +66,13 @@ if (EXISTS "${GIT_DIR}") if (NOT Z3GITHASH) message(WARNING "Failed to get Git hash") disable_git_hash() + else() + message(STATUS "Using Git hash in version output: ${Z3GITHASH}") + # This mimics the behaviour of the old build system. + set(Z3_FULL_VERSION_STR "${Z3_FULL_VERSION_STR} ${Z3GITHASH}") endif() - message(STATUS "Using Git hash in version output: ${Z3GITHASH}") - # This mimics the behaviour of the old build system. - set(Z3_FULL_VERSION_STR "${Z3_FULL_VERSION_STR} ${Z3GITHASH}") else() message(STATUS "Not using Git hash in version output") - unset(Z3GITHASH) # Used in configure_file() endif() if (Z3_INCLUDE_GIT_DESCRIBE) get_git_head_describe("${GIT_DIR}" Z3_GIT_DESCRIPTION) @@ -81,7 +84,7 @@ if (EXISTS "${GIT_DIR}") # This mimics the behaviour of the old build system. set(Z3_FULL_VERSION_STR "${Z3_FULL_VERSION_STR} ${Z3_GIT_DESCRIPTION}") else() - message(STATUS "Not including git descrption in version") + message(STATUS "Not including git description in version") endif() else() message(WARNING "Failed to add git dependency.") @@ -93,6 +96,9 @@ else() disable_git_describe() disable_git_hash() endif() +if(NOT Z3_INCLUDE_GIT_HASH) + unset(Z3GITHASH) # Used in configure_file() +endif() ################################################################################ # Useful CMake functions/Macros @@ -153,8 +159,8 @@ list(APPEND Z3_COMPONENT_CXX_DEFINES $<$:_EXTERNAL_RELEAS ################################################################################ # Find Python ################################################################################ -find_package(PythonInterp 3 REQUIRED) -message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}") +find_package(Python3 REQUIRED COMPONENTS Interpreter) +message(STATUS "Python3_EXECUTABLE: ${Python3_EXECUTABLE}") ################################################################################ # Target architecture detection @@ -465,7 +471,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}") # generate files used for Z3's build. Changes to these files will trigger # a rebuild of all the generated files. ################################################################################ -# Note: ``update_api.py`` is deliberately not here because it not used +# Note: ``update_api.py`` is deliberately not here because it is not used # to generate every generated file. The targets that need it list it explicitly. set(Z3_GENERATED_FILE_EXTRA_DEPENDENCIES "${PROJECT_SOURCE_DIR}/scripts/mk_genfile_common.py" diff --git a/README-CMake.md b/README-CMake.md index 5845a52c372..9bfc2d9ef29 100644 --- a/README-CMake.md +++ b/README-CMake.md @@ -125,7 +125,7 @@ Note that this is `libz3` not `z3` (`libz3` refers to the library target from `s [Ninja](https://ninja-build.org/) is a simple build system that is built for speed. It can be significantly faster than "UNIX Makefile"s because it is not a recursive -build system and thus doesn't create a new process everytime it traverses into a directory. +build system and thus doesn't create a new process every time it traverses into a directory. Ninja is particularly appropriate if you want fast incremental building. Basic usage is as follows: @@ -236,7 +236,7 @@ more interactive and allow you to change various options. In both these tools the basic steps to follow are: 1. Configure. -2. Change any options you wish. Everytime you change a set of options +2. Change any options you wish. Every time you change a set of options You should configure again. This may cause new options to appear 3. Generate. @@ -277,7 +277,7 @@ The following useful options can be passed to CMake whilst configuring. * ``CMAKE_INSTALL_PYTHON_PKG_DIR`` - STRING. The path to install the z3 python bindings. This can be relative (to ``CMAKE_INSTALL_PREFIX``) or absolute. * ``CMAKE_INSTALL_Z3_CMAKE_PACKAGE_DIR`` - STRING. The path to install CMake package files (e.g. ``/usr/lib/cmake/z3``). * ``CMAKE_INSTALL_API_BINDINGS_DOC`` - STRING. The path to install documentation for API bindings. -* ``PYTHON_EXECUTABLE`` - STRING. The python executable to use during the build. +* ``Python3_EXECUTABLE`` - STRING. The python executable to use during the build. * ``Z3_ENABLE_TRACING_FOR_NON_DEBUG`` - BOOL. If set to ``TRUE`` enable tracing in non-debug builds, if set to ``FALSE`` disable tracing in non-debug builds. Note in debug builds tracing is always enabled. * ``Z3_BUILD_LIBZ3_SHARED`` - BOOL. If set to ``TRUE`` build libz3 as a shared library otherwise build as a static library. * ``Z3_ENABLE_EXAMPLE_TARGETS`` - BOOL. If set to ``TRUE`` add the build targets for building the API examples. @@ -303,7 +303,7 @@ The following useful options can be passed to CMake whilst configuring. * ``Z3_ENABLE_CFI`` - BOOL. If set to ``TRUE`` will enable Control Flow Integrity security checks. This is only supported by MSVC and Clang and will fail on other compilers. This requires Z3_LINK_TIME_OPTIMIZATION to also be enabled. * ``Z3_API_LOG_SYNC`` - BOOL. If set to ``TRUE`` will enable experimental API log sync feature. -* ``WARNINGS_AS_ERRORS`` - STRING. If set to ``TRUE`` compiler warnings will be treated as errors. If set to ``False`` compiler warnings will not be treated as errors. +* ``WARNINGS_AS_ERRORS`` - STRING. If set to ``ON`` compiler warnings will be treated as errors. If set to ``OFF`` compiler warnings will not be treated as errors. If set to ``SERIOUS_ONLY`` a subset of compiler warnings will be treated as errors. * ``Z3_C_EXAMPLES_FORCE_CXX_LINKER`` - BOOL. If set to ``TRUE`` the C API examples will request that the C++ linker is used rather than the C linker. * ``Z3_BUILD_EXECUTABLE`` - BOOL. If set to ``TRUE`` build the z3 executable. Defaults to ``TRUE`` unless z3 is being built as a submodule in which case it defaults to ``FALSE``. @@ -348,7 +348,7 @@ These notes are help developers and packagers of Z3. ### Install/Uninstall Install and uninstall targets are supported. Use ``CMAKE_INSTALL_PREFIX`` to -set the install prefix. If you also need need to control which directories are +set the install prefix. If you also need to control which directories are used for install set the documented ``CMAKE_INSTALL_*`` options. To install run diff --git a/README-Z3.md b/README-Z3.md index 5988a6c8d26..dfbaea599b7 100644 --- a/README-Z3.md +++ b/README-Z3.md @@ -1,6 +1,6 @@ # Z3 -Z3 is a theorem prover from Microsoft Research. +Z3 is a theorem prover from Microsoft Research. It is licensed under the [MIT license](LICENSE.txt). If you are not familiar with Z3, you can start [here](https://github.com/Z3Prover/z3/wiki#background). @@ -8,15 +8,15 @@ If you are not familiar with Z3, you can start [here](https://github.com/Z3Prove Pre-built binaries for stable and nightly releases are available from [here](https://github.com/Z3Prover/z3/releases). Z3 can be built using [Visual Studio][1], a [Makefile][2] or using [CMake][3]. It provides -[bindings for several programming languages][4]. +[bindings for several programming languages][4]. See the [release notes](RELEASE_NOTES.md) for notes on various stable releases of Z3. ## Build status -| Azure Pipelines | Code Coverage | Open Bugs | Android Build | WASM Build | -| --------------- | --------------|-----------|---------------|------------| -| [![Build Status](https://dev.azure.com/Z3Public/Z3/_apis/build/status/Z3Prover.z3?branchName=master)](https://dev.azure.com/Z3Public/Z3/_build/latest?definitionId=1&branchName=master) | [![CodeCoverage](https://github.com/Z3Prover/z3/actions/workflows/coverage.yml/badge.svg)](https://github.com/Z3Prover/z3/actions/workflows/coverage.yml) | [![Open Issues](https://github.com/Z3Prover/z3/actions/workflows/wip.yml/badge.svg)](https://github.com/Z3Prover/z3/actions/workflows/wip.yml) |[![Android Build](https://github.com/Z3Prover/z3/actions/workflows/android-build.yml/badge.svg)](https://github.com/Z3Prover/z3/actions/workflows/android-build.yml) | [![WASM Build](https://github.com/Z3Prover/z3/actions/workflows/wasm.yml/badge.svg)](https://github.com/Z3Prover/z3/actions/workflows/wasm.yml) | +| Azure Pipelines | Code Coverage | Open Bugs | Android Build | WASM Build | Windows Build | +| --------------- | --------------|-----------|---------------|------------|---------------| +| [![Build Status](https://dev.azure.com/Z3Public/Z3/_apis/build/status/Z3Prover.z3?branchName=master)](https://dev.azure.com/Z3Public/Z3/_build/latest?definitionId=1&branchName=master) | [![CodeCoverage](https://github.com/Z3Prover/z3/actions/workflows/coverage.yml/badge.svg)](https://github.com/Z3Prover/z3/actions/workflows/coverage.yml) | [![Open Issues](https://github.com/Z3Prover/z3/actions/workflows/wip.yml/badge.svg)](https://github.com/Z3Prover/z3/actions/workflows/wip.yml) |[![Android Build](https://github.com/Z3Prover/z3/actions/workflows/android-build.yml/badge.svg)](https://github.com/Z3Prover/z3/actions/workflows/android-build.yml) | [![WASM Build](https://github.com/Z3Prover/z3/actions/workflows/wasm.yml/badge.svg)](https://github.com/Z3Prover/z3/actions/workflows/wasm.yml) | [![Windows](https://github.com/Z3Prover/z3/actions/workflows/Windows.yml/badge.svg)](https://github.com/Z3Prover/z3/actions/workflows/Windows.yml) Docker image. @@ -46,7 +46,7 @@ cd build nmake ``` -Z3 uses C++17. The recommended version of Visual Studio is therefore VS2019. +Z3 uses C++17. The recommended version of Visual Studio is therefore VS2019. ## Building Z3 using make and GCC/Clang @@ -102,13 +102,13 @@ To clean Z3 you can delete the build directory and run the ``mk_make.py`` script ## Building Z3 using CMake Z3 has a build system using CMake. Read the [README-CMake.md](README-CMake.md) -file for details. It is recommended for most build tasks, +file for details. It is recommended for most build tasks, except for building OCaml bindings. ## Building Z3 using vcpkg vcpkg is a full platform package manager, you can easily install libzmq with vcpkg. - + Execute: ```bash @@ -120,8 +120,8 @@ git clone https://github.com/microsoft/vcpkg.git ## Dependencies Z3 itself has few dependencies. It uses C++ runtime libraries, including pthreads for multi-threading. -It is optionally possible to use GMP for multi-precision integers, but Z3 contains its own self-contained -multi-precision functionality. Python is required to build Z3. To build Java, .Net, OCaml, +It is optionally possible to use GMP for multi-precision integers, but Z3 contains its own self-contained +multi-precision functionality. Python is required to build Z3. To build Java, .Net, OCaml, Julia APIs requires installing relevant tool chains. ## Z3 bindings @@ -185,7 +185,7 @@ If you do need to install to a non standard prefix a better approach is to use a [Python virtual environment](https://virtualenv.readthedocs.org/en/latest/) and install Z3 there. Python packages also work for Python3. Under Windows, recall to build inside the Visual C++ native command build environment. -Note that the ``build/python/z3`` directory should be accessible from where python is used with Z3 +Note that the ``build/python/z3`` directory should be accessible from where python is used with Z3 and it depends on ``libz3.dll`` to be in the path. ```bash @@ -230,7 +230,10 @@ to Z3's C API. For more information, see [MachineArithmetic/README.md](https://g * [.NET API](https://z3prover.github.io/api/html/namespace_microsoft_1_1_z3.html) * [Java API](https://z3prover.github.io/api/html/namespacecom_1_1microsoft_1_1z3.html) * [Python API](https://z3prover.github.io/api/html/namespacez3py.html) (also available in [pydoc format](https://z3prover.github.io/api/html/z3.html)) +* [Rust](https://github.com/prove-rs/z3.rs) * C * OCaml * [Julia](https://github.com/ahumenberger/Z3.jl) * [Smalltalk](https://github.com/shingarov/MachineArithmetic/blob/pure-z3/MachineArithmetic/README.md) (supports Pharo and Smalltalk/X) + + diff --git a/README.md b/README.md index 1f754f03c17..c15fb12f876 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Z3-Noodler is an SMT solver for string constraints such as those that occur in symbolic execution and analysis of programs, reasoning about configuration files of cloud services and smart contracts, etc. -Z3-Noodler is based on the SMT solver [Z3 v4.12.2](https://github.com/Z3Prover/z3/releases/tag/z3-4.12.2), in which it replaces the solver for the theory of strings. +Z3-Noodler is based on the SMT solver [Z3 v4.13.0](https://github.com/Z3Prover/z3/releases/tag/z3-4.13.0), in which it replaces the solver for the theory of strings. The core of the string solver implements several decision procedures, but mainly it relies on the equation stabilization algorithm (see [Publications](#publications)). Z3-Noodler utilizes the automata library [Mata](https://github.com/VeriFIT/mata/) for efficient representation of automata and their processing. @@ -68,8 +68,9 @@ str.replace_re_all Furthermore, we do not support string variables as arguments of `str.to_re` and `re.range`. ## Publications -- Y. Chen, D. Chocholatý, V. Havlena, L. Holík, O. Lengál, and J. Síč. [Solving String Constraints with Lengths by Stabilization](https://doi.org/10.1145/3622872). In *Proc. of OOPSLA'23*, Cascais, Portugal, Volume 7, Issue OOPSLA2, pages 2112–2141, 2023. ACM. -- F. Blahoudek, Y. Chen, D. Chocholatý, V. Havlena, L. Holík, O. Lengál, and J. Síč. [Word Equations in Synergy with Regular Constraints](https://doi.org/10.1007/978-3-031-27481-7_23). In *Proc. of FM’23*, Lübeck, Germany, volume 14000 of LNCS, pages 403–423, 2023. Springer. +- Y. Chen, D. Chocholatý, V. Havlena, L. Holík, O. Lengál, and J. Síč. [Z3-Noodler: An Automata-based String Solver](https://doi.org/10.1007/978-3-031-57246-3_2). In *Proc. of TACAS'24*, volume 14570 of LNCS, pages 24-33, 2024. Springer. +- Y. Chen, D. Chocholatý, V. Havlena, L. Holík, O. Lengál, and J. Síč. [Solving String Constraints with Lengths by Stabilization](https://doi.org/10.1145/3622872). In *Proc. of OOPSLA'23*, Volume 7, Issue OOPSLA2, pages 2112–2141, 2023. ACM. +- F. Blahoudek, Y. Chen, D. Chocholatý, V. Havlena, L. Holík, O. Lengál, and J. Síč. [Word Equations in Synergy with Regular Constraints](https://doi.org/10.1007/978-3-031-27481-7_23). In *Proc. of FM’23*, volume 14000 of LNCS, pages 403–423, 2023. Springer. ## Z3-Noodler source files diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 68228df40e7..3830f566ffa 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -10,6 +10,55 @@ Version 4.next - native word level bit-vector solving. - introduction of simple induction lemmas to handle a limited repertoire of induction proofs. +Version 4.13.0 +============== +- add ARM64 wheels for Python, thanks to Steven Moy, smoy + +Version 4.12.6 +============== +- remove expensive rewrite that coalesces adjacent stores +- improved Java use of reference queues thanks to Thomas Haas #7131 +- fixes to conditional import of python library thanks to Cal Jacobson #7116 +- include universe for constants that get removed during pre-processing #7121 +- code improvements, Bruce Mitchener #7119 +- fix nested callback handling for user propagators +- include ARM64 binaries in distribution +- added Julia API, Thanks to Yisu Remy Yang #7108 + +Version 4.12.5 +============== +- Fixes to pypi setup and build for MacOS distributions +- A new theory solver "int-blast" enabled by using: + - sat.smt=true smt.bv.solver=2 + - It solves a few bit-vector problems not handled by bit-blasting, especially if the bit-widths are large. + - It is based on encoding bit-vector constraints to non-linear integer arithmetic. +- Optimizations to the arithmetic solver. Description: https://github.com/Z3Prover/doc/tree/master/arithmetic + +Version 4.12.4 +============== +- Re-release fixing a few issues with 4.12: + - Python dependency on importlib.resources vs importlib_resources break automatic pypi installations. Supposedly fixed by conditioning dependency on Python 3.9 where the feature is built-in. + - Missing release of arm64 for Ubuntu. + - Futile attempt to streamline adding readme.md file as part of Nuget distribution. Nuget.org now requires a readme file. I was able to integrate the readme with the cmake build, but the cross-platform repackage in scripts/mk_nuget_task.py does not ingest a similar readme file with the CI pipelines. + +Version 4.12.3 +============== +- Alpha support for polymorphism. + - SMTLIB3-ish, C, Python + It adds the new command `(declare-type-var A)` that declares a symbol (in this case `A`) globally as a polymorphic type variable. + The C API contains a new function `Z3_mk_type_variable` and a new enumeration case `Z3_TYPE_VAR` as a kind associated with sorts. + All occurrences of `A` are treated as type variables. A function declaration whose signature uses `A` is treated as a shorthand + for declarations of all functions that use instances of `A`. + Assertions that use type variables are shorthands for assertions covering all instantiations. +- Various (ongoing) performance fixes and improvements to smt.arith.solver=6 +- A working version of solver.proof.trim=true option. Proofs logs created when using sat.smt=true may be trimmed by running z3 + on the generated proof log using the option solver.proof.trim=true. +- Optimizations LIA and NIA (linear integer arithmetic and non-linear integer (and real) arithmetic reasoning). + smt.arith.solver=6 is the default for most use cases. It trails smt.arith.solver=2 in some scenarios and the gap has been either removed or reduced. + smt.arith.solver=6 is complete for integrations of non-linear real arithmetic and theories, smt.arith.solver=2 is not. +- qel: Light quantifier elimination based on term graphs (egraphs), and corresponding Model Based Projection for arrays and ADTs. Used by Spacer and QSAT. +- added real-closed fields features to C API, exposed more RCF over OCaml API +- fixes to FP Version 4.12.2 ============== @@ -59,7 +108,7 @@ Version 4.12.0 Clauses that are deduced by theories are marked by default by 'smt', and when more detailed information is available with proof hints or proof objects. - Instantations are considered useful to track so they + Instantiations are considered useful to track so they are logged using terms of the form (inst (not (forall (x) body)) body[t/x] (bind t)), where @@ -85,7 +134,7 @@ Version 4.12.0 checker cannot check. It is mainly a limitation of the arithmetic solver not pulling relevant information. Ensuring a tight coupling with proof hints and the validator - capabilites is open ended future work and good material for theses. + capabilities is open ended future work and good material for theses. - bit-vector inferences - are treated as trusted (there is no validation, it always blindly succeeds) - arrays, datatypes - there is no custom validation for @@ -155,13 +204,13 @@ Version 4.11.2 with SMT format that is extensible. The resulting format is a mild extension of SMTLIB with three extra commands assume, learn, del. They track input clauses, generated clauses and deleted clauses. They are optionally augmented by proof hints. Two proof hints are used in the current version: "rup" and "farkas". - "rup" is used whent the generated clause can be justified by reverse unit propagation. "farkas" is used when + "rup" is used when the generated clause can be justified by reverse unit propagation. "farkas" is used when the clause can be justified by a combination of Farkas cutting planes. There is a built-in proof checker for the format. Quantifier instantiations are also tracked as proof hints. - Other proof hints are to be added as the feature set is tested and developed. The fallback, buit-in, + Other proof hints are to be added as the feature set is tested and developed. The fallback, built-in, self-checker uses z3 to check that the generated clause is a consequence. Note that this is generally insufficient as generated clauses are in principle required to only be satisfiability preserving. - Proof checking and tranformation operations is overall open ended. + Proof checking and transformation operations is overall open ended. The log for the first commit introducing this change contains further information on the format. - fix to re-entrancy bug in user propagator (thanks to Clemens Eisenhofer). - handle _toExpr for quantified formulas in JS bindings @@ -635,7 +684,7 @@ xor88, parno, gario, Bauna, GManNickG, hanwentao, dinu09, fhowar, Cici, chinissa (assert F) (check-sat a) (check-sat) - If 'F' is unstatisfiable independently of the assumption 'a', and + If 'F' is unsatisfiable independently of the assumption 'a', and the inconsistency can be detected by just performing propagation, Then, version <= 4.3.1 may return unsat diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 382c2efc9e5..b3bc0f22652 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -43,6 +43,39 @@ jobs: - ${{if eq(variables['runRegressions'], 'True')}}: - template: scripts/test-regressions.yml +- job: LinuxBuildsArm64 + displayName: "ManyLinux ARM64 build" + variables: + name: ManyLinux + python: "/opt/python/cp37-cp37m/bin/python" + pool: + vmImage: "ubuntu-latest" + container: "quay.io/pypa/manylinux2014_x86_64:latest" + steps: + - script: curl -L -o /tmp/arm-toolchain.tar.xz 'https://developer.arm.com/-/media/Files/downloads/gnu/11.2-2022.02/binrel/gcc-arm-11.2-2022.02-x86_64-aarch64-none-linux-gnu.tar.xz?rev=33c6e30e5ac64e6dba8f0431f2c35f1b&hash=9918A05BF47621B632C7A5C8D2BB438FB80A4480' + - script: mkdir -p /tmp/arm-toolchain/ + - script: tar xf /tmp/arm-toolchain.tar.xz -C /tmp/arm-toolchain/ --strip-components=1 + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/bin' + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/aarch64-none-linux-gnu/libc/usr/bin' + - script: echo $PATH + - script: stat /tmp/arm-toolchain/bin/aarch64-none-linux-gnu-gcc + - task: PythonScript@0 + displayName: Build + inputs: + scriptSource: 'filepath' + scriptPath: scripts/mk_unix_dist.py + arguments: --nodotnet --nojava --arch=arm64 + pythonInterpreter: $(python) + - task: CopyFiles@2 + inputs: + sourceFolder: dist + contents: '*.zip' + targetFolder: $(Build.ArtifactStagingDirectory) + - task: PublishPipelineArtifact@0 + inputs: + artifactName: 'ManyLinuxBuildArm64' + targetPath: $(Build.ArtifactStagingDirectory) + - job: "Ubuntu20OCaml" displayName: "Ubuntu 20 with OCaml" pool: @@ -144,7 +177,7 @@ jobs: buildCmd: 'CC=clang CXX=clang++ cmake -DCMAKE_BUILD_TYPE=Release $(cmakeStdArgs)' runTests: 'True' debugClang: - setupCmd1: 'julia -e "using Pkg; Pkg.add(PackageSpec(name=\"libcxxwrap_julia_jll\", version=\"0.7.0\"))"' + setupCmd1: 'julia -e "using Pkg; Pkg.add(PackageSpec(name=\"libcxxwrap_julia_jll\"))"' setupCmd2: 'JlCxxDir=$(julia -e "using libcxxwrap_julia_jll; print(dirname(libcxxwrap_julia_jll.libcxxwrap_julia_path))")' buildCmd: 'CC=clang CXX=clang++ cmake -DJlCxx_DIR=$JlCxxDir/cmake/JlCxx $(cmakeJulia) $(cmakeStdArgs)' runTests: 'True' @@ -183,63 +216,6 @@ jobs: - template: scripts/test-regressions.yml -- job: "WindowsLatest" - displayName: "Windows" - pool: - vmImage: "windows-latest" - strategy: - matrix: - x86: - arch: 'x86' - setupCmd1: '' - setupCmd2: '' - setupCmd3: '' - bindings: '$(cmakePy)' - runTests: 'False' - x64: - arch: 'x64' - setupCmd1: 'julia -e "using Pkg; Pkg.add(PackageSpec(name=\"libcxxwrap_julia_jll\", version=\"0.7.0\"))"' - setupCmd2: 'julia -e "using libcxxwrap_julia_jll; print(dirname(libcxxwrap_julia_jll.libcxxwrap_julia_path))" > tmp.env' - setupCmd3: 'set /P JlCxxDir= fruit = ctx.mkEnumSort(name, ctx.mkSymbol("apple"), - ctx.mkSymbol("banana"), ctx.mkSymbol("orange")); + EnumSort fruit = ctx.mkEnumSort(name, ctx.mkSymbol("apple2"), + ctx.mkSymbol("banana2"), ctx.mkSymbol("orange2")); System.out.println((fruit.getConsts()[0])); System.out.println((fruit.getConsts()[1])); diff --git a/examples/ml/ml_example.ml b/examples/ml/ml_example.ml index 5b4e6e9ed1d..318c805e8dd 100644 --- a/examples/ml/ml_example.ml +++ b/examples/ml/ml_example.ml @@ -1,4 +1,4 @@ -(* +(* Copyright (C) 2012 Microsoft Corporation Author: CM Wintersteiger (cwinter) 2012-12-17 *) @@ -19,7 +19,6 @@ open Z3.Arithmetic.Integer open Z3.Arithmetic.Real open Z3.BitVector - exception TestFailedException of string (** @@ -31,14 +30,14 @@ let model_converter_test ( ctx : context ) = let yr = (Expr.mk_const ctx (Symbol.mk_string ctx "y") (Real.mk_sort ctx)) in let g4 = (mk_goal ctx true false false ) in (Goal.add g4 [ (mk_gt ctx xr (Real.mk_numeral_nd ctx 10 1)) ]) ; - (Goal.add g4 [ (mk_eq ctx + (Goal.add g4 [ (mk_eq ctx yr (Arithmetic.mk_add ctx [ xr; (Real.mk_numeral_nd ctx 1 1) ])) ]) ; (Goal.add g4 [ (mk_gt ctx yr (Real.mk_numeral_nd ctx 1 1)) ]) ; ( let ar = (Tactic.apply (mk_tactic ctx "simplify") g4 None) in - if ((get_num_subgoals ar) == 1 && - ((is_decided_sat (get_subgoal ar 0)) || + if ((get_num_subgoals ar) == 1 && + ((is_decided_sat (get_subgoal ar 0)) || (is_decided_unsat (get_subgoal ar 0)))) then raise (TestFailedException "") else @@ -46,8 +45,8 @@ let model_converter_test ( ctx : context ) = ); ( let ar = (Tactic.apply (and_then ctx (mk_tactic ctx ("simplify")) (mk_tactic ctx "solve-eqs") []) g4 None) in - if ((get_num_subgoals ar) == 1 && - ((is_decided_sat (get_subgoal ar 0)) || + if ((get_num_subgoals ar) == 1 && + ((is_decided_sat (get_subgoal ar 0)) || (is_decided_unsat (get_subgoal ar 0)))) then raise (TestFailedException "") else @@ -57,15 +56,15 @@ let model_converter_test ( ctx : context ) = let f e = (Solver.add solver [ e ]) in ignore (List.map f (get_formulas (get_subgoal ar 0))) ; let q = (check solver []) in - if q != SATISFIABLE then + if q != SATISFIABLE then raise (TestFailedException "") else - let m = (get_model solver) in - match m with + let m = (get_model solver) in + match m with | None -> raise (TestFailedException "") - | Some (m) -> + | Some (m) -> Printf.printf "Solver says: %s\n" (string_of_status q) ; - Printf.printf "Model: \n%s\n" (Model.to_string m) + Printf.printf "Model: \n%s\n" (Model.to_string m) ) (** @@ -79,7 +78,7 @@ let basic_tests ( ctx : context ) = let bs = (Boolean.mk_sort ctx) in let domain = [ bs; bs ] in let f = (FuncDecl.mk_func_decl ctx fname domain bs) in - let fapp = (mk_app ctx f + let fapp = (mk_app ctx f [ (Expr.mk_const ctx x bs); (Expr.mk_const ctx y bs) ]) in let fargs2 = [ (mk_fresh_const ctx "cp" bs) ] in let domain2 = [ bs ] in @@ -100,8 +99,8 @@ let basic_tests ( ctx : context ) = ); ( let ar = (Tactic.apply (mk_tactic ctx "simplify") g None) in - if ((get_num_subgoals ar) == 1 && - ((is_decided_sat (get_subgoal ar 0)) || + if ((get_num_subgoals ar) == 1 && + ((is_decided_sat (get_subgoal ar 0)) || (is_decided_unsat (get_subgoal ar 0)))) then raise (TestFailedException "") else @@ -109,28 +108,28 @@ let basic_tests ( ctx : context ) = ); ( let ar = (Tactic.apply (mk_tactic ctx "smt") g None) in - if ((get_num_subgoals ar) == 1 && + if ((get_num_subgoals ar) == 1 && (not (is_decided_sat (get_subgoal ar 0)))) then raise (TestFailedException "") else Printf.printf "Test passed.\n" ); - (Goal.add g [ (mk_eq ctx + (Goal.add g [ (mk_eq ctx (mk_numeral_int ctx 1 (BitVector.mk_sort ctx 32)) (mk_numeral_int ctx 2 (BitVector.mk_sort ctx 32))) ] ) ; ( let ar = (Tactic.apply (mk_tactic ctx "smt") g None) in - if ((get_num_subgoals ar) == 1 && + if ((get_num_subgoals ar) == 1 && (not (is_decided_unsat (get_subgoal ar 0)))) then raise (TestFailedException "") - else + else Printf.printf "Test passed.\n" ); ( let g2 = (mk_goal ctx true true false) in let ar = (Tactic.apply (mk_tactic ctx "smt") g2 None) in - if ((get_num_subgoals ar) == 1 && + if ((get_num_subgoals ar) == 1 && (not (is_decided_sat (get_subgoal ar 0)))) then raise (TestFailedException "") else @@ -140,10 +139,10 @@ let basic_tests ( ctx : context ) = let g2 = (mk_goal ctx true true false) in (Goal.add g2 [ (Boolean.mk_false ctx) ]) ; let ar = (Tactic.apply (mk_tactic ctx "smt") g2 None) in - if ((get_num_subgoals ar) == 1 && + if ((get_num_subgoals ar) == 1 && (not (is_decided_unsat (get_subgoal ar 0)))) then raise (TestFailedException "") - else + else Printf.printf "Test passed.\n" ); ( @@ -155,10 +154,10 @@ let basic_tests ( ctx : context ) = let constr = (mk_eq ctx xc yc) in (Goal.add g3 [ constr ] ) ; let ar = (Tactic.apply (mk_tactic ctx "smt") g3 None) in - if ((get_num_subgoals ar) == 1 && + if ((get_num_subgoals ar) == 1 && (not (is_decided_unsat (get_subgoal ar 0)))) then raise (TestFailedException "") - else + else Printf.printf "Test passed.\n" ) ; model_converter_test ctx ; @@ -169,12 +168,12 @@ let basic_tests ( ctx : context ) = Printf.printf "Numerator: %s Denominator: %s\n" (Real.numeral_to_string inum) (Real.numeral_to_string iden) ; if ((Real.numeral_to_string inum) <> "42" || (Real.numeral_to_string iden) <> "43") then raise (TestFailedException "") - else + else Printf.printf "Test passed.\n" ; if ((to_decimal_string rn 3) <> "0.976?") then raise (TestFailedException "") - else + else Printf.printf "Test passed.\n" ; if (to_decimal_string (Real.mk_numeral_s ctx "-1231231232/234234333") 5 <> "-5.25640?") then @@ -193,7 +192,7 @@ let basic_tests ( ctx : context ) = raise (TestFailedException "check") ) with Z3.Error(_) -> ( - Printf.printf "Exception caught, OK.\n" + Printf.printf "Exception caught, OK.\n" ) (** @@ -212,22 +211,22 @@ let quantifier_example1 ( ctx : context ) = let xs = [ (Integer.mk_const ctx (List.nth names 0)); (Integer.mk_const ctx (List.nth names 1)); (Integer.mk_const ctx (List.nth names 2)) ] in - - let body_vars = (Boolean.mk_and ctx - [ (mk_eq ctx - (Arithmetic.mk_add ctx [ (List.nth vars 0) ; (Integer.mk_numeral_i ctx 1)]) + + let body_vars = (Boolean.mk_and ctx + [ (mk_eq ctx + (Arithmetic.mk_add ctx [ (List.nth vars 0) ; (Integer.mk_numeral_i ctx 1)]) (Integer.mk_numeral_i ctx 2)) ; - (mk_eq ctx + (mk_eq ctx (Arithmetic.mk_add ctx [ (List.nth vars 1); (Integer.mk_numeral_i ctx 2)]) (Arithmetic.mk_add ctx [ (List.nth vars 2); (Integer.mk_numeral_i ctx 3)])) ]) in let body_const = (Boolean.mk_and ctx - [ (mk_eq ctx - (Arithmetic.mk_add ctx [ (List.nth xs 0); (Integer.mk_numeral_i ctx 1)]) + [ (mk_eq ctx + (Arithmetic.mk_add ctx [ (List.nth xs 0); (Integer.mk_numeral_i ctx 1)]) (Integer.mk_numeral_i ctx 2)) ; - (mk_eq ctx + (mk_eq ctx (Arithmetic.mk_add ctx [ (List.nth xs 1); (Integer.mk_numeral_i ctx 2)]) (Arithmetic.mk_add ctx [ (List.nth xs 2); (Integer.mk_numeral_i ctx 3)])) ]) in - + let x = (Quantifier.mk_forall ctx types names body_vars (Some 1) [] [] (Some (Symbol.mk_string ctx "Q1")) (Some (Symbol.mk_string ctx "skid1"))) in Printf.printf "Quantifier X: %s\n" (Quantifier.to_string x) ; let y = (Quantifier.mk_forall_const ctx xs body_const (Some 1) [] [] (Some (Symbol.mk_string ctx "Q2")) (Some (Symbol.mk_string ctx "skid2"))) in @@ -242,8 +241,8 @@ let quantifier_example1 ( ctx : context ) = open Z3.FloatingPoint -(** - A basic example of floating point arithmetic +(** + A basic example of floating point arithmetic **) let fpa_example ( ctx : context ) = Printf.printf "FPAExample\n" ; @@ -271,7 +270,7 @@ let fpa_example ( ctx : context ) = (Boolean.mk_not ctx (mk_is_nan ctx y)) ; (Boolean.mk_not ctx (mk_is_infinite ctx y)) ] in let args3 = [ c3 ; (Boolean.mk_and ctx and_args) ] in - let c4 = (Boolean.mk_and ctx args3) in + let c4 = (Boolean.mk_and ctx args3) in (Printf.printf "c4: %s\n" (Expr.to_string c4)) ; ( let solver = (mk_solver ctx None) in @@ -293,7 +292,7 @@ let fpa_example ( ctx : context ) = let c2 = (mk_to_fp_bv ctx (mk_numeral_string ctx "4619567317775286272" (BitVector.mk_sort ctx 64)) (mk_sort ctx 11 53)) in - let c3 = (mk_to_fp_int_real ctx + let c3 = (mk_to_fp_int_real ctx (RoundingMode.mk_rtz ctx) (mk_numeral_string ctx "2" (Integer.mk_sort ctx)) (mk_numeral_string ctx "1.75" (Real.mk_sort ctx)) @@ -304,18 +303,80 @@ let fpa_example ( ctx : context ) = let args3 = [ (mk_eq ctx c1 c2) ; (mk_eq ctx c1 c3) ; (mk_eq ctx c1 c4) ] in - let c5 = (Boolean.mk_and ctx args3) in + let c5 = (Boolean.mk_and ctx args3) in (Printf.printf "c5: %s\n" (Expr.to_string c5)) ; ( let solver = (mk_solver ctx None) in (Solver.add solver [ c5 ]) ; + Printf.printf "Memory in use before `check`: %Lu bytes\n" (Statistics.get_estimated_alloc_size()); if (check solver []) != SATISFIABLE then raise (TestFailedException "") else Printf.printf "Test passed.\n" - ) + ) + +(** + A basic example of RCF usage +**) +let rcf_example ( ctx : context ) = + Printf.printf "RCFExample\n" ; + let pi = RCF.mk_pi ctx in + let e = RCF.mk_e ctx in + let inf0 = RCF.mk_infinitesimal ctx in + let inf1 = RCF.mk_infinitesimal ctx in + let r = RCF.mk_rational ctx "42.001" in + let pi_div_e = RCF.div ctx pi e in + let pi_div_r = RCF.div ctx pi r in + (Printf.printf "e: %s, pi: %s, e==pi: %b, e < pi: %b\n" + (RCF.num_to_string ctx e true false) + (RCF.num_to_string ctx pi true false) + (RCF.eq ctx e pi) + (RCF.lt ctx e pi)) ; + Printf.printf "pi_div_e: %s.\n" (RCF.num_to_string ctx pi_div_e true false); + Printf.printf "pi_div_r: %s.\n" (RCF.num_to_string ctx pi_div_r true false); + Printf.printf "inf0: %s.\n" (RCF.num_to_string ctx inf0 true false); + Printf.printf "(RCF.is_rational ctx pi): %b.\n" (RCF.is_rational ctx pi); + Printf.printf "(RCF.is_algebraic ctx pi): %b.\n" (RCF.is_algebraic ctx pi); + Printf.printf "(RCF.is_transcendental ctx pi): %b.\n" (RCF.is_transcendental ctx pi); + Printf.printf "(RCF.is_rational ctx r): %b.\n" (RCF.is_rational ctx r); + Printf.printf "(RCF.is_algebraic ctx r): %b.\n" (RCF.is_algebraic ctx r); + Printf.printf "(RCF.is_transcendental ctx r): %b.\n" (RCF.is_transcendental ctx r); + Printf.printf "(RCF.is_infinitesimal ctx inf0): %b.\n" (RCF.is_infinitesimal ctx inf0); + Printf.printf "(RCF.extension_index ctx inf0): %d.\n" (RCF.extension_index ctx inf0); + Printf.printf "(RCF.extension_index ctx inf1): %d.\n" (RCF.extension_index ctx inf1); + let poly:RCF.rcf_num list = [ e; pi; inf0 ] in + let rs:RCF.root list = RCF.roots ctx poly in + let print_root (x:RCF.root) = + begin + Printf.printf "root: %s\n%!" (RCF.num_to_string ctx x.obj true false); + if RCF.is_algebraic ctx x.obj then ( + (match x.interval with + | Some ivl -> Printf.printf " interval: (%b, %b, %s, %b, %b, %s)\n" + ivl.lower_is_inf + ivl.lower_is_open + (RCF.num_to_string ctx ivl.lower true false) + ivl.upper_is_inf + ivl.upper_is_open + (RCF.num_to_string ctx ivl.upper true false); + | None -> ()); + Printf.printf " polynomial coefficients:"; + List.iter (fun c -> Printf.printf " %s" (RCF.num_to_string ctx c false false)) x.polynomial; + Printf.printf "\n"; + Printf.printf " sign conditions:"; + List.iter + (fun (poly, sign) -> + List.iter (fun p -> Printf.printf " %s" (RCF.num_to_string ctx p true false)) poly; + Printf.printf " %s" (if sign > 0 then "> 0" else if sign < 0 then "< 0" else "= 0")) + x.sign_conditions; + Printf.printf "\n") + end + in + List.iter print_root rs; + RCF.del_roots ctx rs; + RCF.del_list ctx [pi; e; inf0; inf1; r; pi_div_e; pi_div_r]; + Printf.printf "Test passed.\n" -let _ = +let _ = try ( if not (Log.open_ "z3.log") then raise (TestFailedException "Log couldn't be opened.") @@ -340,6 +401,7 @@ let _ = basic_tests ctx ; quantifier_example1 ctx ; fpa_example ctx ; + rcf_example ctx ; Printf.printf "Disposing...\n"; Gc.full_major () ); @@ -348,5 +410,6 @@ let _ = ) with Error(msg) -> ( Printf.printf "Z3 EXCEPTION: %s\n" msg ; exit 1 - ) + ) ;; + diff --git a/examples/python/bincover.py b/examples/python/bincover.py new file mode 100644 index 00000000000..d8a81c25a84 --- /dev/null +++ b/examples/python/bincover.py @@ -0,0 +1,382 @@ +from z3 import * +import math + +# Rudimentary bin cover solver using the UserPropagator feature. +# It supports the most basic propagation for bin covering. +# - each bin has a propositional variable set to true if the bin is covered +# - each item has a bit-vector recording the assigned bin +# It searches for a locally optimal solution. + +class Bin: + """ + Each bin carries values: + - min_bound - the lower bound required to be added to bin + - weight - the sum of weight of items currently added to bin + - slack - the difference between the maximal possible assignment and the assignments to other bin2bound. + - var - is propagated to true/false if the bin gets filled/cannot be filled. + """ + def __init__(self, min_bound, index): + assert min_bound > 0 + assert index >= 0 + self.index = index + self.min_bound = min_bound + self.weight = 0 + self.slack = 0 + self.added = [] + self.var = Bool(f"bin-{index}") + + def set_slack(self, slack): + self.slack = slack + + def set_fill(self, fill): + self.weight = fill + + def __repr__(self): + return f"{self.var}:bound-{self.min_bound}" + + +class Item: + def __init__(self, weight, index): + self.weight = weight + self.index = index + self.var = None + + def set_var(self, num_bits): + self.var = BitVec(f"binof-{self.index}", num_bits) + + def __repr__(self): + return f"binof-{self.index}:weight-{self.weight}" + +class BranchAndBound: + """Branch and Bound solver. + It keeps track of a current best score and a slack that tracks bins that are set unfilled. + It blocks branches that are worse than the current best score. + In Final check it blocks the current assignment. + """ + def __init__(self, user_propagator): + self.up = user_propagator + + def init(self, soft_literals): + self.value = 0 + self.best = 0 + self.slack = 0 + self.id2weight = {} + self.assigned_to_false = [] + for p, weight in soft_literals: + self.slack += weight + self.id2weight[p.get_id()] = weight + + def fixed(self, p, value): + weight = self.id2weight[p.get_id()] + if is_true(value): + old_value = self.value + self.up.trail += [lambda : self._undo_value(old_value)] + self.value += weight + elif self.best > self.slack - weight: + self.assigned_to_false += [ p ] + self.up.conflict(self.assigned_to_false) + self.assigned_to_false.pop(-1) + else: + old_slack = self.slack + self.up.trail += [lambda : self._undo_slack(old_slack)] + self.slack -= weight + self.assigned_to_false += [p] + + def final(self): + if self.value > self.best: + self.best = self.value + print("Number of bins filled", self.value) + for bin in self.up.bins: + print(bin.var, bin.added) + self.up.conflict(self.assigned_to_false) + + def _undo_value(self, old_value): + self.value = old_value + + def _undo_slack(self, old_slack): + self.slack = old_slack + self.assigned_to_false.pop(-1) + +class BinCoverSolver(UserPropagateBase): + """Represent a bin-covering problem by associating each bin with a variable + For each item i associate a bit-vector + - bin-of-i that carries the bin identifier where an item is assigned. + + """ + + def __init__(self, s=None, ctx=None): + UserPropagateBase.__init__(self, s, ctx) + self.bins = [] + self.items = [] + self.item2index = {} + self.trail = [] # Undo stack + self.lim = [] + self.solver = s + self.initialized = False + self.add_fixed(lambda x, v : self._fixed(x, v)) + self.branch_and_bound = None + + + # Initialize bit-vector variables for items. + # Register the bit-vector variables with the user propagator to get callbacks + # Ensure the bit-vector variables are assigned to a valid bin. + # Initialize the slack of each bin. + def init(self): + print(self.bins, len(self.bins)) + print(self.items) + assert not self.initialized + self.initialized = True + powerof2, num_bits = self._num_bits() + for item in self.items: + item.set_var(num_bits) + self.item2index[item.var.get_id()] = item.index + self.add(item.var) + if not powerof2: + bound = BitVecVal(len(self.bins), num_bits) + ineq = ULT(item.var, bound) + self.solver.add(ineq) + total_weight = sum(item.weight for item in self.items) + for bin in self.bins: + bin.slack = total_weight + + # + # Register optional branch and bound weighted solver. + # If it is registered, it + def init_branch_and_bound(self): + soft = [(bin.var, 1) for bin in self.bins] + self.branch_and_bound = BranchAndBound(self) + self.branch_and_bound.init(soft) + for bin in self.bins: + self.add(bin.var) + self.add_final(lambda : self.branch_and_bound.final()) + + def add_bin(self, min_bound): + assert not self.initialized + index = len(self.bins) + bin = Bin(min_bound, index) + self.bins += [bin] + return bin + + def add_item(self, weight): + assert not self.initialized + assert weight > 0 + index = len(self.items) + item = Item(weight, index) + self.items += [item] + return item + + def num_items(self): + return len(self.items) + + def num_bins(self): + return len(self.bins) + + def _num_bits(self): + log = math.log2(self.num_bins()) + if log.is_integer(): + return True, int(log) + else: + return False, int(log) + 1 + + def _set_slack(self, bin, slack_value): + bin.slack = slack_value + + def _set_fill(self, bin, fill_value): + bin.weight = fill_value + bin.added.pop() + + def _itemvar2item(self, v): + index = self.item2index[v.get_id()] + if index >= len(self.items): + return None + return self.items[index] + + def _value2bin(self, value): + assert isinstance(value, BitVecNumRef) + bin_index = value.as_long() + if bin_index >= len(self.bins): + return NOne + return self.bins[bin_index] + + def _add_item2bin(self, item, bin): + # print("add", item, "to", bin) + old_weight = bin.weight + bin.weight += item.weight + bin.added += [item] + self.trail += [lambda : self._set_fill(bin, old_weight)] + if old_weight < bin.min_bound and old_weight + item.weight >= bin.min_bound: + self._propagate_filled(bin) + + # This item can never go into bin + def _exclude_item2bin(self, item, bin): + # print("exclude", item, "from", bin) + # Check if bin has already been blocked + if bin.slack < bin.min_bound: + return + if bin.weight >= bin.min_bound: + return + old_slack = bin.slack + new_slack = old_slack - item.weight + bin.slack = new_slack + self.trail += [lambda : self._set_slack(bin, old_slack)] + # If the new slack does not permit the bin to be filled, propagate + if new_slack < bin.min_bound: + self._propagate_slack(bin) + + + # Callback from Z3 when an item gets fixed. + def _fixed(self, _item, value): + if self.branch_and_bound and is_bool(value): + self.branch_and_bound.fixed(_item, value) + return + item = self._itemvar2item(_item) + if item is None: + print("no item for ", _item) + return + bin = self._value2bin(value) + if bin is None: + print("no bin for ", value) + return + self._add_item2bin(item, bin) + for idx in range(len(self.bins)): + if idx == bin.index: + continue + other_bin = self.bins[idx] + self._exclude_item2bin(item, other_bin) + + def _propagate_filled(self, bin): + """Propagate that bin_index is filled justified by the set of + items that have been added + """ + justification = [i.var for i in bin.added] + self.propagate(bin.var, justification) + + def _propagate_slack(self, bin): + """Propagate that bin_index cannot be filled""" + justification = [] + for other_bin in self.bins: + if other_bin.index == bin.index: + continue + justification += other_bin.added + justification = [item.var for item in justification] + self.propagate(Not(bin.var), justification) + + def push(self): + self.lim += [len(self.trail)] + + def pop(self, n): + head = self.lim[len(self.lim) - n] + while len(self.trail) > head: + self.trail[-1]() + self.trail.pop(-1) + self.lim = self.lim[0:len(self.lim)-n] + +# Find a first maximally satisfying subset +class MaximalSatisfyingSubset: + def __init__(self, s): + self.s = s + self.model = None + + def tt(self, f): + return is_true(self.model.eval(f)) + + def get_mss(self, ps): + s = self.s + if sat != s.check(): + return [] + self.model = s.model() + mss = { q for q in ps if self.tt(q) } + return self._get_mss(mss, ps) + + def _get_mss(self, mss, ps): + ps = set(ps) - mss + backbones = set([]) + s = self.s + while len(ps) > 0: + p = ps.pop() + if sat == s.check(mss | backbones | { p }): + self.model = s.model() + mss = mss | { p } | { q for q in ps if self.tt(q) } + ps = ps - mss + else: + backbones = backbones | { Not(p) } + return mss + + +class OptimizeBinCoverSolver: + def __init__(self): + self.solver = Solver() + self.bin_solver = BinCoverSolver(self.solver) + self.mss_solver = MaximalSatisfyingSubset(self.solver) + + # + # Facilities to set up solver + # First add items and bins. + # Keep references to the returned objects. + # Then call init + # Then add any other custom constraints to the "solver" object. + # + def init(self): + self.bin_solver.init() + + def add_item(self, weight): + return self.bin_solver.add_item(weight) + + def add_bin(self, min_bound): + return self.bin_solver.add_bin(min_bound) + + def optimize(self): + self.init() + mss = self.mss_solver.get_mss([bin.var for bin in self.bin_solver.bins]) + print(self.mss_solver.model) + print("filled bins", mss) + print("bin contents") + for bin in self.bin_solver.bins: + print(bin, bin.added) + + +def example1(): + s = OptimizeBinCoverSolver() + i1 = s.add_item(2) + i2 = s.add_item(4) + i3 = s.add_item(5) + i4 = s.add_item(2) + b1 = s.add_bin(3) + b2 = s.add_bin(6) + b3 = s.add_bin(1) + s.optimize() + +#example1() + + +class BranchAndBoundCoverSolver: + def __init__(self): + self.solver = Solver() + self.bin_solver = BinCoverSolver(self.solver) + + def init(self): + self.bin_solver.init() + self.bin_solver.init_branch_and_bound() + + def add_item(self, weight): + return self.bin_solver.add_item(weight) + + def add_bin(self, min_bound): + return self.bin_solver.add_bin(min_bound) + + def optimize(self): + self.init() + self.solver.check() + +def example2(): + s = BranchAndBoundCoverSolver() + i1 = s.add_item(2) + i2 = s.add_item(4) + i3 = s.add_item(5) + i4 = s.add_item(2) + b1 = s.add_bin(3) + b2 = s.add_bin(6) + b3 = s.add_bin(1) + s.optimize() + +example2() diff --git a/examples/python/proofreplay.py b/examples/python/proofreplay.py index c8c9ff47eb4..5c82f43a96c 100644 --- a/examples/python/proofreplay.py +++ b/examples/python/proofreplay.py @@ -86,7 +86,7 @@ # The pair -inst 2 indicates that two quantifier instantiations were not self-validated # They were instead validated using a call to SMT solving. A log for an smt invocation # is exemplified in the next line. - # Note that the pair +inst 6 indicates that 6 quantifier instantations were validated + # Note that the pair +inst 6 indicates that 6 quantifier instantiations were validated # using a syntactic (cheap) check. Some quantifier instantiations based on quantifier elimination # are not simple substitutions and therefore a simple syntactic check does not suffice. set_param("solver.proof.check", True) diff --git a/package-lock.json b/package-lock.json deleted file mode 100644 index 4c22d5b852f..00000000000 --- a/package-lock.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "requires": true, - "lockfileVersion": 1, - "dependencies": { - "async-mutex": { - "version": "0.3.2", - "resolved": "https://registry.npmjs.org/async-mutex/-/async-mutex-0.3.2.tgz", - "integrity": "sha512-HuTK7E7MT7jZEh1P9GtRW9+aTWiDWWi9InbZ5hjxrnRa39KS4BW04+xLBhYNS2aXhHUIKZSw3gj4Pn1pj+qGAA==", - "requires": { - "tslib": "^2.3.1" - } - }, - "tslib": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.4.0.tgz", - "integrity": "sha512-d6xOpEDfsi2CZVlPQzGeux8XMwLT9hssAsaPYExaQMuYskwb+x1x7J371tWlbBdWHroy99KnVB6qIkUbs5X3UQ==" - }, - "z3-solver": { - "version": "4.9.0", - "resolved": "https://registry.npmjs.org/z3-solver/-/z3-solver-4.9.0.tgz", - "integrity": "sha512-clSV0uyHsfrO84pSbHxoqvmd5HgSG4CoSJG2f8U65hBVylbV6p/0svctQWee9W2fWo0IsxHYRjxz2Z85GT0LAA==", - "requires": { - "async-mutex": "^0.3.2" - } - } - } -} diff --git a/scripts/build-win-signed-cmake.yml b/scripts/build-win-signed-cmake.yml new file mode 100644 index 00000000000..620379f0c4e --- /dev/null +++ b/scripts/build-win-signed-cmake.yml @@ -0,0 +1,99 @@ +parameters: + ReleaseVersion: '' + BuildArchitecture: '' + VCArchitecture: '' + +jobs: +- job: WindowsBuild${{parameters.BuildArchitecture}} + displayName: "Windows build (${{parameters.BuildArchitecture}})" + pool: + vmImage: "windows-latest" + steps: + - powershell: write-host $(System.TeamProjectId) + displayName: 'System.TeamProjectId' + - powershell: write-host $(System.DefinitionId) + displayName: 'System.DefinitionId' + - powershell: write-host $(Build.BuildId) + displayName: 'Build.BuildId' + - task: CmdLine@2 + displayName: Build + inputs: + script: + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{parameters.VCArchitecture}} & + python scripts\mk_win_dist_cmake.py + --${{parameters.BuildArchitecture}}-only + --assembly-version=${{parameters.ReleaseVersion}} + --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk + - task: CopyFiles@2 + displayName: 'Collect Symbols' + inputs: + sourceFolder: build-dist/${{parameters.BuildArchitecture}}/ + contents: '*.pdb' + targetFolder: '$(Build.ArtifactStagingDirectory)/symbols' + # Publish symbol archive to match nuget package + # Index your source code and publish symbols to a file share or Azure Artifacts symbol server + - task: PublishSymbols@2 + inputs: + symbolsFolder: '$(Build.ArtifactStagingDirectory)/symbols' + searchPattern: '**/*.pdb' + indexSources: false # Github sources not supported + publishSymbols: true + symbolServerType: TeamServices + detailedLog: true + - task: EsrpCodeSigning@2 + displayName: Sign + inputs: + ConnectedServiceName: 'z3-esrp-signing-2' + FolderPath: 'build-dist/${{parameters.BuildArchitecture}}/dist/z3-${{parameters.ReleaseVersion}}-${{parameters.BuildArchitecture}}-win/bin' + Pattern: 'Microsoft.Z3.dll,libz3.dll,libz3java.dll,z3.exe' + signConfigType: 'inlineSignParams' + inlineOperation: | + [ + { + "keyCode": "CP-230012", + "operationSetCode": "SigntoolSign", + "parameters": [ + { + "parameterName": "OpusName", + "parameterValue": "Microsoft" + }, + { + "parameterName": "OpusInfo", + "parameterValue": "http://www.microsoft.com" + }, + { + "parameterName": "PageHash", + "parameterValue": "/NPH" + }, + { + "parameterName": "FileDigest", + "parameterValue": "/fd sha256" + }, + { + "parameterName": "TimeStamp", + "parameterValue": "/tr \"http://rfc3161.gtm.corp.microsoft.com/TSS/HttpTspServer\" /td sha256" + } + ], + "toolName": "signtool.exe", + "toolVersion": "6.2.9304.0" + } + ] + SessionTimeout: '60' + MaxConcurrency: '50' + MaxRetryAttempts: '5' + - task: DeleteFiles@1 + displayName: Cleanup + inputs: + SourceFolder: 'build-dist/${{parameters.BuildArchitecture}}/dist/z3-${{parameters.ReleaseVersion}}-${{parameters.BuildArchitecture}}-win/bin' + Contents: 'CodeSignSummary*' + - task: ArchiveFiles@2 + displayName: Zip + inputs: + rootFolderOrFile: 'build-dist/${{parameters.BuildArchitecture}}/dist/z3-${{parameters.ReleaseVersion}}-${{parameters.BuildArchitecture}}-win' + includeRootFolder: true + archiveType: 'zip' + archiveFile: '$(Build.ArtifactStagingDirectory)/z3-${{parameters.ReleaseVersion}}-${{parameters.BuildArchitecture}}-win.zip' + - task: PublishPipelineArtifact@1 + inputs: + targetPath: '$(Build.ArtifactStagingDirectory)/z3-${{parameters.ReleaseVersion}}-${{parameters.BuildArchitecture}}-win.zip' + artifactName: 'WindowsBuild-${{parameters.BuildArchitecture}}' \ No newline at end of file diff --git a/scripts/mk_genfile_common.py b/scripts/mk_genfile_common.py index bb6d884e6eb..3be314a53b8 100644 --- a/scripts/mk_genfile_common.py +++ b/scripts/mk_genfile_common.py @@ -139,7 +139,7 @@ def mk_z3consts_py_internal(api_files, output_dir): assert False, "Invalid %s, line: %s" % (api_file, linenum) else: assert mode == IN_ENUM - words = re.split('[^\-a-zA-Z0-9_]+', line) + words = re.split('[^-a-zA-Z0-9_]+', line) m = closebrace_pat.match(line) if m: name = words[1] @@ -227,7 +227,7 @@ def mk_z3consts_dotnet_internal(api_files, output_dir): assert False, "Invalid %s, line: %s" % (api_file, linenum) else: assert mode == IN_ENUM - words = re.split('[^\-a-zA-Z0-9_]+', line) + words = re.split('[^-a-zA-Z0-9_]+', line) m = closebrace_pat.match(line) if m: name = words[1] @@ -315,7 +315,7 @@ def mk_z3consts_java_internal(api_files, package_name, output_dir): assert False, "Invalid %s, line: %s" % (api_file, linenum) else: assert mode == IN_ENUM - words = re.split('[^\-a-zA-Z0-9_]+', line) + words = re.split('[^-a-zA-Z0-9_]+', line) m = closebrace_pat.match(line) if m: name = words[1] @@ -441,7 +441,7 @@ def mk_z3consts_ml_internal(api_files, output_dir): assert False, "Invalid %s, line: %s" % (api_file, linenum) else: assert mode == IN_ENUM - words = re.split('[^\-a-zA-Z0-9_]+', line) + words = re.split('[^-a-zA-Z0-9_]+', line) m = closebrace_pat.match(line) if m: name = words[1] @@ -574,7 +574,7 @@ def mk_def_file_internal(defname, dll_name, export_header_files): for line in api: m = pat1.match(line) if m: - words = re.split('\W+', line) + words = re.split(r'\W+', line) i = 0 for w in words: if w == 'Z3_API': @@ -618,9 +618,9 @@ def mk_gparams_register_modules_internal(h_files_full_path, path): fout = open(fullname, 'w') fout.write('// Automatically generated file.\n') fout.write('#include "util/gparams.h"\n') - reg_pat = re.compile('[ \t]*REG_PARAMS\(\'([^\']*)\'\)') - reg_mod_pat = re.compile('[ \t]*REG_MODULE_PARAMS\(\'([^\']*)\', *\'([^\']*)\'\)') - reg_mod_descr_pat = re.compile('[ \t]*REG_MODULE_DESCRIPTION\(\'([^\']*)\', *\'([^\']*)\'\)') + reg_pat = re.compile(r'[ \t]*REG_PARAMS\(\'([^\']*)\'\)') + reg_mod_pat = re.compile(r'[ \t]*REG_MODULE_PARAMS\(\'([^\']*)\', *\'([^\']*)\'\)') + reg_mod_descr_pat = re.compile(r'[ \t]*REG_MODULE_DESCRIPTION\(\'([^\']*)\', *\'([^\']*)\'\)') for h_file in sorted_headers_by_component(h_files_full_path): added_include = False with io.open(h_file, encoding='utf-8', mode='r') as fin: @@ -698,9 +698,9 @@ def ADD_SIMPLIFIER(name, descr, cmd): fout.write('#include "cmd_context/tactic_cmds.h"\n') fout.write('#include "cmd_context/simplifier_cmds.h"\n') fout.write('#include "cmd_context/cmd_context.h"\n') - tactic_pat = re.compile('[ \t]*ADD_TACTIC\(.*\)') - probe_pat = re.compile('[ \t]*ADD_PROBE\(.*\)') - simplifier_pat = re.compile('[ \t]*ADD_SIMPLIFIER\(.*\)') + tactic_pat = re.compile(r'[ \t]*ADD_TACTIC\(.*\)') + probe_pat = re.compile(r'[ \t]*ADD_PROBE\(.*\)') + simplifier_pat = re.compile(r'[ \t]*ADD_SIMPLIFIER\(.*\)') for h_file in sorted_headers_by_component(h_files_full_path): added_include = False try: @@ -780,10 +780,10 @@ def mk_mem_initializer_cpp_internal(h_files_full_path, path): fullname = os.path.join(path, 'mem_initializer.cpp') fout = open(fullname, 'w') fout.write('// Automatically generated file.\n') - initializer_pat = re.compile('[ \t]*ADD_INITIALIZER\(\'([^\']*)\'\)') + initializer_pat = re.compile(r'[ \t]*ADD_INITIALIZER\(\'([^\']*)\'\)') # ADD_INITIALIZER with priority - initializer_prio_pat = re.compile('[ \t]*ADD_INITIALIZER\(\'([^\']*)\',[ \t]*(-?[0-9]*)\)') - finalizer_pat = re.compile('[ \t]*ADD_FINALIZER\(\'([^\']*)\'\)') + initializer_prio_pat = re.compile(r'[ \t]*ADD_INITIALIZER\(\'([^\']*)\',[ \t]*(-?[0-9]*)\)') + finalizer_pat = re.compile(r'[ \t]*ADD_FINALIZER\(\'([^\']*)\'\)') for h_file in sorted_headers_by_component(h_files_full_path): added_include = False with io.open(h_file, encoding='utf-8', mode='r') as fin: @@ -952,7 +952,7 @@ def def_module_params(module_name, export, params, class_name=None, description= 'UINT_MAX' : UINT_MAX, 'max_memory_param' : max_memory_param, 'max_steps_param' : max_steps_param, - # Note that once this function is enterred that function + # Note that once this function is entered that function # executes with respect to the globals of this module and # not the globals defined here 'def_module_params' : def_module_params, diff --git a/scripts/mk_nuget_task.py b/scripts/mk_nuget_task.py index 073b6b99a74..9de592c086b 100644 --- a/scripts/mk_nuget_task.py +++ b/scripts/mk_nuget_task.py @@ -21,17 +21,20 @@ def mk_dir(d): if not os.path.exists(d): os.makedirs(d) -os_info = { 'ubuntu-latest' : ('so', 'linux-x64'), - 'ubuntu-18' : ('so', 'linux-x64'), - 'ubuntu-20' : ('so', 'linux-x64'), - 'glibc' : ('so', 'linux-x64'), - #'glibc-2.35' : ('so', 'linux-x64'), +os_info = { 'x64-ubuntu-latest' : ('so', 'linux-x64'), + 'x64-ubuntu-18' : ('so', 'linux-x64'), + 'x64-ubuntu-20' : ('so', 'linux-x64'), + 'x64-ubuntu-22' : ('so', 'linux-x64'), + 'x64-glibc-2.35' : ('so', 'linux-x64'), 'x64-win' : ('dll', 'win-x64'), 'x86-win' : ('dll', 'win-x86'), 'x64-osx' : ('dylib', 'osx-x64'), - 'arm64-osx' : ('dylib', 'osx-arm64'), 'debian' : ('so', 'linux-x64') } +# Nuget not supported for ARM +#'arm-glibc-2.35' : ('so', 'linux-arm64'), +#'arm64-osx' : ('dylib', 'osx-arm64'), + def classify_package(f, arch): @@ -76,8 +79,15 @@ def unpack(packages, symbols, arch): if symbols: files += ["Microsoft.Z3.pdb", "Microsoft.Z3.xml"] for b in files: - zip_ref.extract(f"{package_dir}/bin/{b}", f"{tmp}") - replace(f"{tmp}/{package_dir}/bin/{b}", f"out/lib/netstandard2.0/{b}") + file = f"{package_dir}/bin/{b}" + if os.path.exists(file): + zip_ref.extract(file, f"{tmp}") + replace(f"{tmp}/{package_dir}/bin/{b}", f"out/lib/netstandard2.0/{b}") + file = os.path.join(file,"netstandard2.0") + if os.path.exists(file): + zip_ref.extract(file, f"{tmp}") + replace(f"{tmp}/{package_dir}/bin/netstandard2.0/{b}", f"out/lib/netstandard2.0/{b}") + def mk_targets(source_root): mk_dir("out/build") @@ -86,6 +96,8 @@ def mk_targets(source_root): def mk_icon(source_root): mk_dir("out/content") shutil.copy(f"{source_root}/resources/icon.jpg", "out/content/icon.jpg") +# shutil.copy(f"{source_root}/src/api/dotnet/README.md", "out/content/README.md") + def create_nuget_spec(version, repo, branch, commit, symbols, arch): diff --git a/scripts/mk_project.py b/scripts/mk_project.py index 44e436daf0a..2805cbaf13d 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -8,7 +8,7 @@ from mk_util import * def init_version(): - set_version(4, 12, 2, 0) # express a default build version or pick up ci build version + set_version(4, 13, 0, 0) # express a default build version or pick up ci build version # Z3 Project definition def init_project_def(): @@ -40,6 +40,7 @@ def init_project_def(): add_lib('model', ['macros']) add_lib('converters', ['model'], 'ast/converters') add_lib('simplifiers', ['euf', 'normal_forms', 'bit_blaster', 'converters', 'substitution'], 'ast/simplifiers') + add_lib('ast_sls', ['ast','normal_forms','converters'], 'ast/sls') add_lib('tactic', ['simplifiers']) add_lib('mbp', ['model', 'simplex'], 'qe/mbp') add_lib('qe_lite', ['tactic', 'mbp'], 'qe/lite') @@ -58,13 +59,13 @@ def init_project_def(): add_lib('proto_model', ['model', 'rewriter', 'smt_params'], 'smt/proto_model') add_lib('smt', ['bit_blaster', 'macros', 'normal_forms', 'cmd_context', 'proto_model', 'solver_assertions', 'substitution', 'grobner', 'simplex', 'proofs', 'pattern', 'parser_util', 'fpa', 'lp']) - add_lib('sat_smt', ['sat', 'euf', 'smt', 'tactic', 'solver', 'smt_params', 'bit_blaster', 'fpa', 'mbp', 'normal_forms', 'lp', 'pattern', 'qe_lite'], 'sat/smt') + add_lib('sat_smt', ['sat', 'ast_sls', 'euf', 'smt', 'tactic', 'solver', 'smt_params', 'bit_blaster', 'fpa', 'mbp', 'normal_forms', 'lp', 'pattern', 'qe_lite'], 'sat/smt') add_lib('sat_tactic', ['tactic', 'sat', 'solver', 'sat_smt'], 'sat/tactic') add_lib('nlsat_tactic', ['nlsat', 'sat_tactic', 'arith_tactics'], 'nlsat/tactic') add_lib('bv_tactics', ['tactic', 'bit_blaster', 'core_tactics'], 'tactic/bv') add_lib('fuzzing', ['ast'], 'test/fuzzing') add_lib('smt_tactic', ['smt'], 'smt/tactic') - add_lib('sls_tactic', ['tactic', 'normal_forms', 'core_tactics', 'bv_tactics'], 'tactic/sls') + add_lib('sls_tactic', ['tactic', 'normal_forms', 'core_tactics', 'bv_tactics', 'ast_sls'], 'tactic/sls') add_lib('qe', ['smt', 'mbp', 'qe_lite', 'nlsat', 'tactic', 'nlsat_tactic'], 'qe') add_lib('sat_solver', ['solver', 'core_tactics', 'aig_tactic', 'bv_tactics', 'arith_tactics', 'sat_tactic'], 'sat/sat_solver') add_lib('fd_solver', ['core_tactics', 'arith_tactics', 'sat_solver', 'smt'], 'tactic/fd_solver') diff --git a/scripts/mk_unix_dist.py b/scripts/mk_unix_dist.py index 3b1e7139143..d967e91098a 100644 --- a/scripts/mk_unix_dist.py +++ b/scripts/mk_unix_dist.py @@ -118,7 +118,9 @@ def check_build_dir(path): # Create a build directory using mk_make.py def mk_build_dir(path): + global LINUX_X64 if not check_build_dir(path) or FORCE_MK: + env = os.environ opts = [sys.executable, os.path.join('scripts', 'mk_make.py'), "-b", path, "--staticlib"] if DOTNET_CORE_ENABLED: opts.append('--dotnet') @@ -133,7 +135,17 @@ def mk_build_dir(path): opts.append('--python') if mk_util.IS_ARCH_ARM64: opts.append('--arm64=true') - if subprocess.call(opts) != 0: + if mk_util.IS_ARCH_ARM64 and LINUX_X64: + # we are machine x64 but build against arm64 + # so we have to do cross compiling + # the cross compiler is download from ARM GNU + # toolchain + myvar = { + "CC": "aarch64-none-linux-gnu-gcc", + "CXX": "aarch64-none-linux-gnu-g++" + } + env.update(myvar) + if subprocess.call(opts, env=env) != 0: raise MKException("Failed to generate build directory at '%s'" % path) # Create build directories @@ -159,12 +171,22 @@ def mk_z3(): return 1 def get_os_name(): + global LINUX_X64 if OS_NAME is not None: return OS_NAME import platform basic = os.uname()[0].lower() if basic == 'linux': - dist = platform.libc_ver() + if mk_util.IS_ARCH_ARM64 and LINUX_X64: + # handle cross compiling + # example: 'ldd (GNU) 2.34' + lines = subprocess.check_output(["ldd", "--version"]).decode('ascii') + first_line = lines.split("\n")[0] + ldd_version = first_line.split()[-1] + # coerce the format to platform.libc_ver() return type + dist = ('glibc', ldd_version) + else: + dist = platform.libc_ver() if len(dist) == 2 and len(dist[0]) > 0 and len(dist[1]) > 0: return '%s-%s' % (dist[0].lower(), dist[1].lower()) else: @@ -187,8 +209,14 @@ def get_os_name(): return basic def get_z3_name(): + import platform as platform_module + # Note that the platform name this function return + # has to work together with setup.py + # It's not the typical output from platform.machine() major, minor, build, revision = get_version() - if mk_util.IS_ARCH_ARM64: + if mk_util.IS_ARCH_ARM64 or platform_module.machine() == "aarch64": + # the second case handle native build on aarch64 + # TODO: we don't handle cross compile on host aarch64 to target x64 platform = "arm64" elif sys.maxsize >= 2**32: platform = "x64" diff --git a/scripts/mk_unix_dist_cmake.py b/scripts/mk_unix_dist_cmake.py new file mode 100644 index 00000000000..3a38ad65b64 --- /dev/null +++ b/scripts/mk_unix_dist_cmake.py @@ -0,0 +1,268 @@ +############################################ +# Copyright (c) 2013 Microsoft Corporation +# +# Scripts for automatically generating +# Linux/OSX/BSD distribution zip files. +# +# Author: Leonardo de Moura (leonardo) +############################################ + +import os +import subprocess +import zipfile +import re +import getopt +import sys +import shutil +from mk_exception import * +from fnmatch import fnmatch + +def getenv(name, default): + try: + return os.environ[name].strip(' "\'') + except: + return default + +BUILD_DIR = 'build-dist' +DIST_DIR = 'dist' +VERBOSE = True +FORCE_MK = False +ASSEMBLY_VERSION = None +DOTNET_CORE_ENABLED = True +DOTNET_KEY_FILE = None +JAVA_ENABLED = True +JULIA_ENABLED = False +GIT_HASH = False +PYTHON_ENABLED = True +ARM64 = False +MAKEJOBS = getenv("MAKEJOBS", "24") + +def set_verbose(flag): + global VERBOSE + VERBOSE = flag + +def is_verbose(): + return VERBOSE + +def mk_dir(d): + if not os.path.exists(d): + if is_verbose(): + print("Make directory", d) + os.makedirs(d) + +def get_z3_name(): + version = "4" + if ASSEMBLY_VERSION: + version = ASSEMBLY_VERSION + print("Assembly version:", version) + if GIT_HASH: + return 'z3-%s.%s' % (version, get_git_hash()) + else: + return 'z3-%s' % (version) + +def get_build_dir(): + return BUILD_DIR + +def get_build_dist(): + return os.path.join(get_build_dir(), DIST_DIR) + +def get_build_dist_path(): + return os.path.join(get_build_dist(), get_z3_name()) + +def set_build_dir(path): + global BUILD_DIR + BUILD_DIR = os.path.expanduser(os.path.normpath(path)) + mk_dir(BUILD_DIR) + +def display_help(): + print("mk_unix_dist_cmake.py: Z3 Unix distribution generator\n") + print("This script generates the zip files containing executables, shared objects, header files for Unix.") + print("It must be executed from the Z3 root directory.") + print("\nOptions:") + print(" -h, --help display this message.") + print(" -s, --silent do not print verbose messages.") + print(" -b , --build= subdirectory where Z3 will be built (default: build-dist).") + print(" -f, --force force script to regenerate Makefiles.") + print(" --version= release version.") + print(" --assembly-version assembly version for dll") + print(" --nodotnet do not include .NET bindings in the binary distribution files.") + print(" --dotnet-key= strongname sign the .NET assembly with the private key in .") + print(" --nojava do not include Java bindings in the binary distribution files.") + print(" --nopython do not include Python bindings in the binary distribution files.") + print(" --julia build Julia bindings.") + print(" --githash include git hash in the Zip file.") + print(" --arm64 build for ARM64 architecture.") + exit(0) + +# Parse configuration option for mk_make script +def parse_options(): + global FORCE_MK, JAVA_ENABLED, JULIA_ENABLED, GIT_HASH, DOTNET_CORE_ENABLED, DOTNET_KEY_FILE, ASSEMBLY_VERSION, PYTHON_ENABLED, ARM64 + path = BUILD_DIR + options, remainder = getopt.gnu_getopt(sys.argv[1:], 'b:hsf', ['build=', + 'help', + 'silent', + 'force', + 'nojava', + 'nodotnet', + 'dotnet-key=', + 'assembly-version=', + 'githash', + 'nopython', + 'julia', + 'arm64' + ]) + for opt, arg in options: + if opt in ('-b', '--build'): + if arg == 'src': + raise MKException('The src directory should not be used to host the Makefile') + path = arg + elif opt in ('-s', '--silent'): + set_verbose(False) + elif opt in ('-h', '--help'): + display_help() + elif opt in ('-f', '--force'): + FORCE_MK = True + elif opt == '--nodotnet': + DOTNET_CORE_ENABLED = False + elif opt == '--assembly-version': + ASSEMBLY_VERSION = arg + elif opt == '--nopython': + PYTHON_ENABLED = False + elif opt == '--dotnet-key': + DOTNET_KEY_FILE = arg + elif opt == '--nojava': + JAVA_ENABLED = False + elif opt == '--julia': + JULIA_ENABLED = True + elif opt == '--githash': + GIT_HASH = True + elif opt == '--arm64': + ARM64 = True + else: + raise MKException("Invalid command line option '%s'" % opt) + set_build_dir(path) + +def check_output(cmd): + out = subprocess.Popen(cmd, stdout=subprocess.PIPE).communicate()[0] + if out != None: + enc = sys.getdefaultencoding() + if enc != None: return out.decode(enc).rstrip('\r\n') + else: return out.rstrip('\r\n') + else: + return "" + +def get_git_hash(): + try: + branch = check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + r = check_output(['git', 'show-ref', '--abbrev=12', 'refs/heads/%s' % branch]) + except: + raise MKException("Failed to retrieve git hash") + ls = r.split(' ') + if len(ls) != 2: + raise MKException("Unexpected git output " + r) + return ls[0] + +# Create a build directory using CMake +def mk_build_dir(): + build_path = get_build_dir() + if not os.path.exists(build_path) or FORCE_MK: + mk_dir(build_path) + cmds = [] + cmds.append(f"cd {build_path}") + cmd = [] + cmd.append("cmake -S .") + if DOTNET_CORE_ENABLED: + cmd.append(' -DZ3_BUILD_DOTNET_BINDINGS=ON') + if JAVA_ENABLED: + cmd.append(' -DZ3_BUILD_JAVA_BINDINGS=ON') + cmd.append(' -DZ3_INSTALL_JAVA_BINDINGS=ON') + cmd.append(' -DZ3_JAVA_JAR_INSTALLDIR=java') + cmd.append(' -DZ3_JAVA_JNI_LIB_INSTALLDIR=bin/java') + if PYTHON_ENABLED: + cmd.append(' -DZ3_BUILD_PYTHON_BINDINGS=ON') + cmd.append(' -DZ3_INSTALL_PYTHON_BINDINGS=ON') + cmd.append(' -DCMAKE_INSTALL_PYTHON_PKG_DIR=bin/python') + if JULIA_ENABLED: + cmd.append(' -DZ3_BUILD_JULIA_BINDINGS=ON') + cmd.append(' -DZ3_INSTALL_JULIA_BINDINGS=ON') + if GIT_HASH: + git_hash = get_git_hash() + cmd.append(' -DGIT_HASH=' + git_hash) + cmd.append(' -DZ3_USE_LIB_GMP=OFF') + cmd.append(' -DZ3_BUILD_LIBZ3_SHARED=ON') + cmd.append(' -DCMAKE_BUILD_TYPE=RelWithDebInfo') + cmd.append(' -DCMAKE_INSTALL_PREFIX=' + get_build_dist_path()) + cmd.append(' -G "Ninja"') + cmd.append(' ..\n') + cmds.append("".join(cmd)) + print("CMAKE commands:", cmds) + sys.stdout.flush() + if exec_cmds(cmds) != 0: + raise MKException("failed to run commands") + +def exec_cmds(cmds): + cmd_file = 'z3_tmp.sh' + f = open(cmd_file, 'w') + for cmd in cmds: + f.write(cmd) + f.write('\n') + f.close() + res = 0 + try: + res = subprocess.call(['sh', cmd_file]) + except: + res = 1 + try: + os.remove(cmd_file) + except: + pass + return res + +def build_z3(): + if is_verbose(): + print("build z3") + build_dir = get_build_dir() + cmds = [] + cmds.append('cd %s' % build_dir) + cmds.append('ninja install') + if exec_cmds(cmds) != 0: + raise MKException("Failed to make z3") + +def mk_zip(): + build_dist = get_build_dist_path() + dist_name = get_z3_name() + old = os.getcwd() + try: + if is_verbose(): + print("dist path", build_dist) + mk_dir(build_dist) + zfname = '%s.zip' % dist_name + zipout = zipfile.ZipFile(zfname, 'w', zipfile.ZIP_DEFLATED) + os.chdir(get_build_dist()) + for root, dirs, files in os.walk("."): + for f in files: + if is_verbose(): + print("adding ", os.path.join(root, f)) + zipout.write(os.path.join(root, f)) + if is_verbose(): + print("Generated '%s'" % zfname) + except: + pass + os.chdir(old) + +def cp_license(): + if is_verbose(): + print("copy licence") + path = get_build_dist_path() + mk_dir(path) + shutil.copy("LICENSE.txt", path) + +# Entry point +def main(): + parse_options() + mk_build_dir() + build_z3() + cp_license() + mk_zip() + +main() diff --git a/scripts/mk_util.py b/scripts/mk_util.py index 805aea19df3..014b0e40f83 100644 --- a/scripts/mk_util.py +++ b/scripts/mk_util.py @@ -395,7 +395,7 @@ def check_java(): else: # Search for jni.h in the library directories... t = open('errout', 'r') - open_pat = re.compile("\[search path for class files: (.*)\]") + open_pat = re.compile(r"\[search path for class files: (.*)\]") cdirs = [] for line in t: m = open_pat.match(line) @@ -812,8 +812,8 @@ def parse_options(): def extract_c_includes(fname): result = {} # We look for well behaved #include directives - std_inc_pat = re.compile("[ \t]*#include[ \t]*\"(.*)\"[ \t]*") - system_inc_pat = re.compile("[ \t]*#include[ \t]*\<.*\>[ \t]*") + std_inc_pat = re.compile(r"[ \t]*#include[ \t]*\"(.*)\"[ \t]*") + system_inc_pat = re.compile(r"[ \t]*#include[ \t]*\<.*\>[ \t]*") # We should generate and error for any occurrence of #include that does not match the previous pattern. non_std_inc_pat = re.compile(".*#include.*") @@ -1720,7 +1720,7 @@ def mk_makefile(self, out): print("Version output to csproj:", version) - core_csproj_str = """ + core_csproj_str = r""" netstandard1.4 @@ -1736,6 +1736,7 @@ def mk_makefile(self, out): true Microsoft Microsoft + README.md false Z3 is a satisfiability modulo theories solver from Microsoft Research. Copyright Microsoft Corporation. All rights reserved. @@ -1745,9 +1746,10 @@ def mk_makefile(self, out): + -""" % (version, key, self.to_src_dir) +""" % (version, key, self.to_src_dir, self.to_src_dir) mk_dir(os.path.join(BUILD_DIR, 'dotnet')) csproj = os.path.join('dotnet', 'z3.csproj') @@ -2246,7 +2248,7 @@ def mk_makefile(self, out): else: platform = 'x86' - dotnet_proj_str = """ + dotnet_proj_str = r""" Exe netcoreapp2.0 @@ -2519,19 +2521,19 @@ def mk_config(): 'SLINK_FLAGS=/nologo /LDd\n' % static_opt) if VS_X64: config.write( - 'CXXFLAGS=/c %s /W3 /WX- /Od /Oy- /D _DEBUG /D Z3DEBUG /D _CONSOLE /D _TRACE /Gm- /RTC1 %s %s\n' % (CXXFLAGS, extra_opt, static_opt)) + 'CXXFLAGS=/c %s /Zi /W3 /WX- /Od /Oy- /D _DEBUG /D Z3DEBUG /D _CONSOLE /D _TRACE /Gm- /RTC1 %s %s\n' % (CXXFLAGS, extra_opt, static_opt)) config.write( - 'LINK_EXTRA_FLAGS=/link /DEBUG /MACHINE:X64 /SUBSYSTEM:CONSOLE /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 /DYNAMICBASE /NXCOMPAT %s\n' - 'SLINK_EXTRA_FLAGS=/link /DEBUG /MACHINE:X64 /SUBSYSTEM:WINDOWS /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 %s %s\n' % (link_extra_opt, maybe_disable_dynamic_base, link_extra_opt)) + 'LINK_EXTRA_FLAGS=/link /PROFILE /DEBUG:full /MACHINE:X64 /SUBSYSTEM:CONSOLE /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 /DYNAMICBASE /NXCOMPAT %s\n' + 'SLINK_EXTRA_FLAGS=/link /PROFILE /DEBUG:full /MACHINE:X64 /SUBSYSTEM:WINDOWS /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 %s %s\n' % (link_extra_opt, maybe_disable_dynamic_base, link_extra_opt)) elif VS_ARM: print("ARM on VS is unsupported") exit(1) else: config.write( - 'CXXFLAGS=/c %s /W3 /WX- /Od /Oy- /D _DEBUG /D Z3DEBUG /D _CONSOLE /D _TRACE /Gm- /RTC1 /arch:SSE2 %s %s\n' % (CXXFLAGS, extra_opt, static_opt)) + 'CXXFLAGS=/c %s /Zi /W3 /WX- /Od /Oy- /D _DEBUG /D Z3DEBUG /D _CONSOLE /D _TRACE /Gm- /RTC1 /arch:SSE2 %s %s\n' % (CXXFLAGS, extra_opt, static_opt)) config.write( - 'LINK_EXTRA_FLAGS=/link /DEBUG /MACHINE:X86 /SUBSYSTEM:CONSOLE /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 /DYNAMICBASE /NXCOMPAT %s\n' - 'SLINK_EXTRA_FLAGS=/link /DEBUG /MACHINE:X86 /SUBSYSTEM:WINDOWS /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 %s %s\n' % (link_extra_opt, maybe_disable_dynamic_base, link_extra_opt)) + 'LINK_EXTRA_FLAGS=/link /PROFILE /DEBUG:full /MACHINE:X86 /SUBSYSTEM:CONSOLE /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 /DYNAMICBASE /NXCOMPAT %s\n' + 'SLINK_EXTRA_FLAGS=/link /PROFILE /DEBUG:full /MACHINE:X86 /SUBSYSTEM:WINDOWS /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 %s %s\n' % (link_extra_opt, maybe_disable_dynamic_base, link_extra_opt)) else: # Windows Release mode LTCG=' /LTCG' if SLOW_OPTIMIZE else '' @@ -2544,19 +2546,19 @@ def mk_config(): extra_opt = '%s /D _TRACE ' % extra_opt if VS_X64: config.write( - 'CXXFLAGS=/c%s %s /W3 /WX- /O2 /D _EXTERNAL_RELEASE /D NDEBUG /D _LIB /D UNICODE /Gm- /GF /Gy /TP %s %s\n' % (GL, CXXFLAGS, extra_opt, static_opt)) + 'CXXFLAGS=/c%s %s /Zi /W3 /WX- /O2 /D _EXTERNAL_RELEASE /D NDEBUG /D _LIB /D UNICODE /Gm- /GF /Gy /TP %s %s\n' % (GL, CXXFLAGS, extra_opt, static_opt)) config.write( - 'LINK_EXTRA_FLAGS=/link%s /profile /MACHINE:X64 /SUBSYSTEM:CONSOLE /STACK:8388608 %s\n' - 'SLINK_EXTRA_FLAGS=/link%s /profile /MACHINE:X64 /SUBSYSTEM:WINDOWS /STACK:8388608 %s\n' % (LTCG, link_extra_opt, LTCG, link_extra_opt)) + 'LINK_EXTRA_FLAGS=/link%s /PROFILE /DEBUG:full /profile /MACHINE:X64 /SUBSYSTEM:CONSOLE /STACK:8388608 %s\n' + 'SLINK_EXTRA_FLAGS=/link%s /PROFILE /DEBUG:full /profile /MACHINE:X64 /SUBSYSTEM:WINDOWS /STACK:8388608 %s\n' % (LTCG, link_extra_opt, LTCG, link_extra_opt)) elif VS_ARM: print("ARM on VS is unsupported") exit(1) else: config.write( - 'CXXFLAGS=/c%s %s /WX- /O2 /Oy- /D _EXTERNAL_RELEASE /D NDEBUG /D _CONSOLE /D ASYNC_COMMANDS /Gm- /arch:SSE2 %s %s\n' % (GL, CXXFLAGS, extra_opt, static_opt)) + 'CXXFLAGS=/c%s %s /Zi /WX- /O2 /Oy- /D _EXTERNAL_RELEASE /D NDEBUG /D _CONSOLE /D ASYNC_COMMANDS /Gm- /arch:SSE2 %s %s\n' % (GL, CXXFLAGS, extra_opt, static_opt)) config.write( - 'LINK_EXTRA_FLAGS=/link%s /DEBUG /MACHINE:X86 /SUBSYSTEM:CONSOLE /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 /DYNAMICBASE /NXCOMPAT %s\n' - 'SLINK_EXTRA_FLAGS=/link%s /DEBUG /MACHINE:X86 /SUBSYSTEM:WINDOWS /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 %s %s\n' % (LTCG, link_extra_opt, LTCG, maybe_disable_dynamic_base, link_extra_opt)) + 'LINK_EXTRA_FLAGS=/link%s /PROFILE /DEBUG:full /MACHINE:X86 /SUBSYSTEM:CONSOLE /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 /DYNAMICBASE /NXCOMPAT %s\n' + 'SLINK_EXTRA_FLAGS=/link%s /PROFILE /DEBUG:full /MACHINE:X86 /SUBSYSTEM:WINDOWS /INCREMENTAL:NO /STACK:8388608 /OPT:REF /OPT:ICF /TLBID:1 %s %s\n' % (LTCG, link_extra_opt, LTCG, maybe_disable_dynamic_base, link_extra_opt)) config.write('CFLAGS=$(CXXFLAGS)\n') @@ -2664,7 +2666,7 @@ def mk_config(): LDFLAGS = '%s -static-libgcc -static-libstdc++' % LDFLAGS if sysname == 'Linux' and machine.startswith('armv7') or machine.startswith('armv8'): CXXFLAGS = '%s -fpic' % CXXFLAGS - if IS_OSX and IS_ARCH_ARM64: + if IS_ARCH_ARM64 and IS_OSX: print("Setting arm64") CXXFLAGS = '%s -arch arm64' % CXXFLAGS LDFLAGS = '%s -arch arm64' % LDFLAGS @@ -2992,9 +2994,19 @@ def cp_z3py_to_build(): for f in files: if f.endswith('.pyc'): rmf(os.path.join(root, f)) + # We do not want a second copy of the compiled files in the system-wide cache, + # so we disable it temporarily. This is an issue with recent versions of MacOS + # where XCode's Python has a cache, but the build scripts don't have access to + # it (e.g. during OPAM package installation). + have_cache = hasattr(sys, 'pycache_prefix') and sys.pycache_prefix is not None + if have_cache: + pycache_prefix_before = sys.pycache_prefix + sys.pycache_prefix = None # Compile Z3Py files if compileall.compile_dir(z3py_src, force=1) != 1: raise MKException("failed to compile Z3Py sources") + if have_cache: + sys.pycache_prefix = pycache_prefix_before if is_verbose: print("Generated python bytecode") # Copy sources to build @@ -3162,7 +3174,7 @@ def mk_vs_proj_property_groups(f, name, target_ext, type): f.write(' Win32Proj\n') f.write(' %s\n' % get_platform_toolset_str()) f.write(' \n') - f.write(' \n') + f.write(' \n') f.write(' \n') f.write(' %s\n' % type) f.write(' Unicode\n') @@ -3173,24 +3185,24 @@ def mk_vs_proj_property_groups(f, name, target_ext, type): f.write(' Unicode\n') f.write(' false\n') f.write(' \n') - f.write(' \n') + f.write(' \n') f.write(' \n') f.write(' \n') - f.write(' \n') + f.write(' \n') f.write(' \n') f.write(' \n') - f.write(' $(SolutionDir)\$(ProjectName)\$(Configuration)\\n') + f.write(' $(SolutionDir)\\$(ProjectName)\\$(Configuration)\\\n') f.write(' %s\n' % name) f.write(' .%s\n' % target_ext) - f.write(' $(SolutionDir)\$(ProjectName)\$(Configuration)\\n') + f.write(' $(SolutionDir)\\$(ProjectName)\\$(Configuration)\\\n') f.write(' %s\n' % name) f.write(' .%s\n' % target_ext) f.write(' \n') f.write(' \n') - f.write(' $(ProjectName)\$(Configuration)\\n') + f.write(' $(ProjectName)\\$(Configuration)\\\n') f.write(' \n') f.write(' \n') - f.write(' $(ProjectName)\$(Configuration)\\n') + f.write(' $(ProjectName)\\$(Configuration)\\\n') f.write(' \n') @@ -3267,7 +3279,7 @@ def mk_vs_proj(name, components): mk_vs_proj_link_exe(f, name, debug=False) f.write(' \n') mk_vs_proj_dep_groups(f, name, components) - f.write(' \n') + f.write(' \n') f.write(' \n') f.write(' \n') f.write('\n') @@ -3308,7 +3320,7 @@ def mk_vs_proj_dll(name, components): mk_vs_proj_link_dll(f, name, debug=False) f.write(' \n') mk_vs_proj_dep_groups(f, name, components) - f.write(' \n') + f.write(' \n') f.write(' \n') f.write(' \n') f.write('\n') diff --git a/scripts/mk_win_dist_cmake.py b/scripts/mk_win_dist_cmake.py new file mode 100644 index 00000000000..f3c83cfb6d3 --- /dev/null +++ b/scripts/mk_win_dist_cmake.py @@ -0,0 +1,423 @@ +############################################ +# Copyright (c) 2012 Microsoft Corporation +# +# Scripts for automatically generating +# Window distribution zip files. +# +# Author: Leonardo de Moura (leonardo) +############################################ + +import os +import subprocess +import zipfile +import re +import getopt +import sys +import shutil +from mk_exception import * +from fnmatch import fnmatch + +def getenv(name, default): + try: + return os.environ[name].strip(' "\'') + except: + return default + +BUILD_DIR = 'build-dist' +DIST_DIR = 'dist' +BUILD_X64_DIR = os.path.join(BUILD_DIR, 'x64') +BUILD_X86_DIR = os.path.join(BUILD_DIR, 'x86') +BUILD_ARM64_DIR = os.path.join(BUILD_DIR, 'arm64') +VERBOSE = True +FORCE_MK = False +ASSEMBLY_VERSION = None +DOTNET_CORE_ENABLED = True +DOTNET_KEY_FILE = None +JAVA_ENABLED = True +ZIP_BUILD_OUTPUTS = False +GIT_HASH = False +PYTHON_ENABLED = True +X86ONLY = False +X64ONLY = False +ARM64ONLY = False +ARCHITECTURES = [] + +def set_verbose(flag): + global VERBOSE + VERBOSE = flag + +def is_verbose(): + return VERBOSE + +def mk_dir(d): + if not os.path.exists(d): + if is_verbose(): + print("Make directory", d) + os.makedirs(d) + +def get_z3_name(arch): + version = "4" + if ASSEMBLY_VERSION: + version = ASSEMBLY_VERSION + print("Assembly version:", version) + if GIT_HASH: + return 'z3-%s.%s-%s-win' % (version, get_git_hash(), arch) + else: + return 'z3-%s-%s-win' % (version, arch) + +def get_build_dir(arch): + return ARCHITECTURES[arch] + +def get_build_dist(arch): + return os.path.join(get_build_dir(arch), DIST_DIR) + +def get_build_dist_path(arch): + return os.path.join(get_build_dir(arch), DIST_DIR, get_z3_name(arch)) + +def get_bin_dist_path(arch): + return os.path.join(get_build_dist_path(arch), "bin") + +def get_lib_dist_path(arch): + return os.path.join(get_build_dist_path(arch), "lib") + +def get_java_dist_path(arch): + return os.path.join(get_build_dist_path(arch), "java") + +def get_dist_path(arch): + return os.path.join(DIST_DIR, arch) + +def set_build_dir(path): + global BUILD_DIR, BUILD_X86_DIR, BUILD_X64_DIR, BUILD_ARM64_DIR, ARCHITECTURES + BUILD_DIR = os.path.expanduser(os.path.normpath(path)) + BUILD_X86_DIR = os.path.join(path, 'x86') + BUILD_X64_DIR = os.path.join(path, 'x64') + BUILD_ARM64_DIR = os.path.join(path, 'arm64') # Set ARM64 build directory + ARCHITECTURES = {'x64': BUILD_X64_DIR, 'x86':BUILD_X86_DIR, 'arm64':BUILD_ARM64_DIR} + +def display_help(): + print("mk_win_dist.py: Z3 Windows distribution generator\n") + print("This script generates the zip files containing executables, dlls, header files for Windows.") + print("It must be executed from the Z3 root directory.") + print("\nOptions:") + print(" -h, --help display this message.") + print(" -s, --silent do not print verbose messages.") + print(" -b , --build= subdirectory where x86 and x64 Z3 versions will be built (default: build-dist).") + print(" -f, --force force script to regenerate Makefiles.") + print(" --version= release version.") + print(" --assembly-version assembly version for dll") + print(" --nodotnet do not include .NET bindings in the binary distribution files.") + print(" --dotnet-key= strongname sign the .NET assembly with the private key in .") + print(" --nojava do not include Java bindings in the binary distribution files.") + print(" --nopython do not include Python bindings in the binary distribution files.") + print(" --zip package build outputs in zip file.") + print(" --githash include git hash in the Zip file.") + print(" --x86-only x86 dist only.") + print(" --x64-only x64 dist only.") + print(" --arm64-only arm64 dist only.") + exit(0) + +# Parse configuration option for mk_make script +def parse_options(): + global FORCE_MK, JAVA_ENABLED, ZIP_BUILD_OUTPUTS, GIT_HASH, DOTNET_CORE_ENABLED, DOTNET_KEY_FILE, ASSEMBLY_VERSION, PYTHON_ENABLED, X86ONLY, X64ONLY, ARM64ONLY + path = BUILD_DIR + options, remainder = getopt.gnu_getopt(sys.argv[1:], 'b:hsf', ['build=', + 'help', + 'silent', + 'force', + 'nojava', + 'nodotnet', + 'dotnet-key=', + 'assembly-version=', + 'zip', + 'githash', + 'nopython', + 'x86-only', + 'x64-only', + 'arm64-only' + ]) + for opt, arg in options: + if opt in ('-b', '--build'): + if arg == 'src': + raise MKException('The src directory should not be used to host the Makefile') + path = arg + elif opt in ('-s', '--silent'): + set_verbose(False) + elif opt in ('-h', '--help'): + display_help() + elif opt in ('-f', '--force'): + FORCE_MK = True + elif opt == '--nodotnet': + DOTNET_CORE_ENABLED = False + elif opt == '--assembly-version': + ASSEMBLY_VERSION = arg + elif opt == '--nopython': + PYTHON_ENABLED = False + elif opt == '--dotnet-key': + DOTNET_KEY_FILE = arg + elif opt == '--nojava': + JAVA_ENABLED = False + elif opt == '--zip': + ZIP_BUILD_OUTPUTS = True + elif opt == '--githash': + GIT_HASH = True + elif opt == '--x86-only' and not X64ONLY: + X86ONLY = True + elif opt == '--arm64-only' and not X86ONLY and not X64ONLY: + ARM64ONLY = True + elif opt == '--x64-only' and not X86ONLY: + X64ONLY = True + else: + raise MKException("Invalid command line option '%s'" % opt) + set_build_dir(path) + +# Check whether build directory already exists or not +def check_build_dir(path): + return os.path.exists(path) and os.path.exists(os.path.join(path, 'Makefile')) + +def check_output(cmd): + out = subprocess.Popen(cmd, stdout=subprocess.PIPE).communicate()[0] + if out != None: + enc = sys.getdefaultencoding() + if enc != None: return out.decode(enc).rstrip('\r\n') + else: return out.rstrip('\r\n') + else: + return "" + +def get_git_hash(): + try: + branch = check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + r = check_output(['git', 'show-ref', '--abbrev=12', 'refs/heads/%s' % branch]) + except: + raise MKException("Failed to retrieve git hash") + ls = r.split(' ') + if len(ls) != 2: + raise MKException("Unexpected git output " + r) + return ls[0] + + + +# Create a build directory using mk_make.py +def mk_build_dir(arch): + build_path = get_build_dir(arch) + if not check_build_dir(build_path) or FORCE_MK: + mk_dir(build_path) + vsarch = arch + if arch == "arm64": + vsarch = "amd64_arm64" + + cmds = [] + cmds.append(f"cd {build_path}") + cmds.append('call "%VCINSTALLDIR%Auxiliary\\build\\vcvarsall.bat" ' + vsarch) + cmd = [] + cmd.append("cmake -S .") + if DOTNET_CORE_ENABLED: + cmd.append(' -DZ3_BUILD_DOTNET_BINDINGS=ON') +# cmd.append(' -DZ3_INSTALL_DOTNET_BINDINGS=ON') + if JAVA_ENABLED: + cmd.append(' -DZ3_BUILD_JAVA_BINDINGS=ON') + cmd.append(' -DZ3_INSTALL_JAVA_BINDINGS=ON') + cmd.append(' -DZ3_JAVA_JAR_INSTALLDIR=java') + cmd.append(' -DZ3_JAVA_JNI_LIB_INSTALLDIR=bin/java') + if PYTHON_ENABLED: + cmd.append(' -DZ3_BUILD_PYTHON_BINDINGS=ON') + cmd.append(' -DZ3_INSTALL_PYTHON_BINDINGS=ON') + cmd.append(' -DCMAKE_INSTALL_PYTHON_PKG_DIR=bin/python') + + if GIT_HASH: + git_hash = get_git_hash() + cmd.append(' -DGIT_HASH=' + git_hash) + cmd.append(' -DZ3_USE_LIB_GMP=OFF') + cmd.append(' -DZ3_BUILD_LIBZ3_SHARED=ON') + cmd.append(' -DCMAKE_BUILD_TYPE=RelWithDebInfo') + cmd.append(' -DCMAKE_INSTALL_PREFIX=' + os.path.join(DIST_DIR, get_z3_name(arch))) + cmd.append(' -G "Ninja"') + cmd.append(' ../..\n') + cmds.append("".join(cmd)) + print("CMAKE commands:", cmds) + sys.stdout.flush() + if exec_cmds(cmds) != 0: + raise MKException("failed to run commands") + + + +# Check if on Visual Studio command prompt +def check_vc_cmd_prompt(): + try: + DEVNULL = open(os.devnull, 'wb') + subprocess.call(['cl'], stdout=DEVNULL, stderr=DEVNULL) + except: + raise MKException("You must execute the mk_win_dist.py script on a Visual Studio Command Prompt") + +def exec_cmds(cmds): + cmd_file = 'z3_tmp.cmd' + f = open(cmd_file, 'w') + for cmd in cmds: + f.write(cmd) + f.write('\n') + f.close() + res = 0 + try: + res = subprocess.call(cmd_file, shell=True) + except: + res = 1 + try: + os.erase(cmd_file) + except: + pass + return res + + + +def build_z3(arch): + if is_verbose(): + print("build z3") + build_dir = get_build_dir(arch) + if arch == "arm64": + arch = "amd64_arm64" + cmds = [] + cmds.append('call "%VCINSTALLDIR%Auxiliary\\build\\vcvarsall.bat" ' + arch) + cmds.append('cd %s' % build_dir) + cmds.append('ninja install') + if exec_cmds(cmds) != 0: + raise MKException("Failed to make z3") + + + +def mk_zip(arch): + if not ZIP_BUILD_OUTPUTS: + return + build_dist = get_build_dist_path(arch) + dist_name = get_z3_name(arch) + dist_path = get_dist_path(arch) + build_dir = get_build_dir(arch) + old = os.getcwd() + try: + if is_verbose(): + print("dist path", dist_path) + mk_dir(dist_path) + zfname = os.path.join(dist_path, '%s.zip' % dist_name) + zipout = zipfile.ZipFile(zfname, 'w', zipfile.ZIP_DEFLATED) + os.chdir(get_build_dist(arch)) + for root, dirs, files in os.walk("."): + for f in files: + if is_verbose(): + print("adding ", os.path.join(root, f)) + zipout.write(os.path.join(root, f)) + if is_verbose(): + print("Generated '%s'" % zfname) + except: + pass + os.chdir(old) + + + +VS_RUNTIME_PATS = [re.compile(r'vcomp.*\.dll'), + re.compile(r'msvcp.*\.dll'), + re.compile(r'msvcr.*\.dll'), + re.compile(r'vcrun.*\.dll')] + +# Copy Visual Studio Runtime libraries +def cp_vs_runtime(arch): + platform = arch + vcdir = os.environ['VCINSTALLDIR'] + path = '%sredist' % vcdir + vs_runtime_files = [] + print("Walking %s" % path) + # Everything changes with every release of VS + # Prior versions of VS had DLLs under "redist\x64" + # There are now several variants of redistributables + # The naming convention defies my understanding so + # we use a "check_root" filter to find some hopefully suitable + # redistributable. + def check_root(root): + return platform in root and ("CRT" in root or "MP" in root) and "onecore" not in root and "debug" not in root + for root, dirs, files in os.walk(path): + for filename in files: + if fnmatch(filename, '*.dll') and check_root(root): + print("Checking %s %s" % (root, filename)) + for pat in VS_RUNTIME_PATS: + if pat.match(filename): + fname = os.path.join(root, filename) + if not os.path.isdir(fname): + vs_runtime_files.append(fname) + if not vs_runtime_files: + raise MKException("Did not find any runtime files to include") + bin_dist_path = get_bin_dist_path(arch) + for f in vs_runtime_files: + shutil.copy(f, bin_dist_path) + if is_verbose(): + print("Copied '%s' to '%s'" % (f, bin_dist_path)) + +def cp_license(arch): + if is_verbose(): + print("copy licence") + path = get_build_dist_path(arch) + mk_dir(path) + shutil.copy("LICENSE.txt", path) + +def cp_dotnet(arch): + if not DOTNET_CORE_ENABLED: + return + if is_verbose(): + print("copy dotnet") + build_dir = get_build_dir(arch) + dist_dir = get_bin_dist_path(arch) + shutil.copytree(os.path.join(build_dir, "Microsoft.Z3"), + dist_dir, + dirs_exist_ok=True) + +def cp_into_bin(arch): + if is_verbose(): + print("copy lib") + lib_dir = get_lib_dist_path(arch) + bin_dir = get_bin_dist_path(arch) + shutil.copyfile(os.path.join(lib_dir, "libz3.lib"), + os.path.join(bin_dir, "libz3.lib")) + shutil.rmtree(lib_dir) + if JAVA_ENABLED: + java_dir = get_java_dist_path(arch) + shutil.copytree(java_dir, + bin_dir, + dirs_exist_ok=True) + shutil.rmtree(java_dir) + +def cp_pdb(arch): + if is_verbose(): + print("copy pdb") + build_dir = get_build_dir(arch) + bin_path = get_bin_dist_path(arch) + mk_dir(bin_path) + for f in os.listdir(build_dir): + if f.endswith("libz3.pdb"): + shutil.copy(os.path.join(build_dir, f), bin_path) + +def build_for_arch(arch): + mk_build_dir(arch) + build_z3(arch) + cp_license(arch) + cp_pdb(arch) + cp_dotnet(arch) + cp_vs_runtime(arch) + cp_into_bin(arch) + mk_zip(arch) + +# Entry point +def main(): + if os.name != 'nt': + raise MKException("This script is for Windows only") + + parse_options() + check_vc_cmd_prompt() + + if X86ONLY: + build_for_arch("x86") + elif X64ONLY: + build_for_arch("x64") + elif ARM64ONLY: + build_for_arch("arm64") + else: + for arch in ARCHITECTURES: + build_for_arch(arch) + +main() + diff --git a/scripts/nightly.yaml b/scripts/nightly.yaml index 4aeb44b352d..51533dd8e2b 100644 --- a/scripts/nightly.yaml +++ b/scripts/nightly.yaml @@ -1,28 +1,44 @@ variables: Major: '4' - Minor: '12' - Patch: '2' + Minor: '13' + Patch: '0' + ReleaseVersion: $(Major).$(Minor).$(Patch) AssemblyVersion: $(Major).$(Minor).$(Patch).$(Build.BuildId) - NightlyVersion: $(AssemblyVersion)-$(Build.DefinitionName) + NightlyVersion: $(AssemblyVersion)-$(Build.buildId) stages: - stage: Build jobs: - - job: Mac - displayName: "Mac Build" + - job: MacBuild + displayName: "macOS Build" pool: vmImage: "macOS-11" steps: - - script: python scripts/mk_unix_dist.py --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk + - task: PythonScript@0 + displayName: Build + inputs: + scriptSource: 'filepath' + scriptPath: scripts/mk_unix_dist.py + arguments: --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk - script: git clone https://github.com/z3prover/z3test z3test - - script: python z3test/scripts/test_benchmarks.py build-dist/z3 z3test/regressions/smt2 - - script: cp dist/*.zip $(Build.ArtifactStagingDirectory)/. - - task: PublishPipelineArtifact@1 + displayName: 'Clone z3test' + - task: PythonScript@0 + displayName: Test inputs: - artifactName: 'Mac' + scriptSource: 'filepath' + scriptPath: z3test/scripts/test_benchmarks.py + arguments: build-dist/z3 z3test/regressions/smt2 + - task: CopyFiles@2 + inputs: + sourceFolder: dist + contents: '*.zip' + targetFolder: $(Build.ArtifactStagingDirectory) + - task: PublishPipelineArtifact@0 + inputs: + artifactName: 'macOSBuild' targetPath: $(Build.ArtifactStagingDirectory) - - job: MacArm64 + - job: MacBuildArm64 displayName: "Mac ARM64 Build" pool: vmImage: "macOS-11" @@ -35,32 +51,76 @@ stages: artifactName: 'MacArm64' targetPath: $(Build.ArtifactStagingDirectory) - - job: Ubuntu20 - displayName: "Ubuntu20 build" + - job: UbuntuBuild + displayName: "Ubuntu build" + pool: + vmImage: "ubuntu-latest" + steps: + - task: PythonScript@0 + displayName: Build + inputs: + scriptSource: 'filepath' + scriptPath: scripts/mk_unix_dist.py + arguments: --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk + - script: git clone https://github.com/z3prover/z3test z3test + displayName: 'Clone z3test' + - task: PythonScript@0 + displayName: Test + inputs: + scriptSource: 'filepath' + scriptPath: z3test/scripts/test_benchmarks.py + arguments: build-dist/z3 z3test/regressions/smt2 + - task: CopyFiles@2 + inputs: + sourceFolder: dist + contents: '*.zip' + targetFolder: $(Build.ArtifactStagingDirectory) + - task: PublishPipelineArtifact@0 + inputs: + artifactName: 'UbuntuBuild' + targetPath: $(Build.ArtifactStagingDirectory) + + - job: UbuntuBuild20 + displayName: "Ubuntu build 20" pool: vmImage: "ubuntu-20.04" steps: - - script: python scripts/mk_unix_dist.py --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk + - task: PythonScript@0 + displayName: Build + inputs: + scriptSource: 'filepath' + scriptPath: scripts/mk_unix_dist.py + arguments: --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk - script: git clone https://github.com/z3prover/z3test z3test - - script: python z3test/scripts/test_benchmarks.py build-dist/z3 z3test/regressions/smt2 - - script: cp dist/*.zip $(Build.ArtifactStagingDirectory)/. + displayName: 'Clone z3test' + - task: PythonScript@0 + displayName: Test + inputs: + scriptSource: 'filepath' + scriptPath: z3test/scripts/test_benchmarks.py + arguments: build-dist/z3 z3test/regressions/smt2 + - task: CopyFiles@2 + inputs: + sourceFolder: dist + contents: '*.zip' + targetFolder: $(Build.ArtifactStagingDirectory) - task: PublishPipelineArtifact@0 inputs: - artifactName: 'Ubuntu-20.04' + artifactName: 'UbuntuBuild20' targetPath: $(Build.ArtifactStagingDirectory) - - job: Ubuntu - displayName: "Ubuntu build" + - job: UbuntuArm64 + displayName: "Ubuntu ARM64 build" pool: vmImage: "ubuntu-latest" steps: - - script: python scripts/mk_unix_dist.py --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk + - script: python scripts/mk_unix_dist.py --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk --arch=arm64 - script: git clone https://github.com/z3prover/z3test z3test - script: python z3test/scripts/test_benchmarks.py build-dist/z3 z3test/regressions/smt2 - script: cp dist/*.zip $(Build.ArtifactStagingDirectory)/. - task: PublishPipelineArtifact@0 inputs: - artifactName: 'Ubuntu' + artifactName: 'UbuntuArm64' targetPath: $(Build.ArtifactStagingDirectory) - job: UbuntuDoc @@ -68,6 +128,7 @@ stages: pool: vmImage: "ubuntu-latest" steps: + - script: pip3 install importlib-resources - script: sudo apt-get install ocaml opam libgmp-dev - script: opam init -y - script: eval `opam config env`; opam install zarith ocamlfind -y @@ -86,8 +147,8 @@ stages: set -e eval `opam config env` cd doc - python mk_api_doc.py --mld --z3py-package-path=../build/python/z3 - python mk_params_doc.py + python3 mk_api_doc.py --mld --z3py-package-path=../build/python/z3 + python3 mk_params_doc.py mkdir api/html/ml ocamldoc -html -d api/html/ml -sort -hide Z3 -I $( ocamlfind query zarith ) -I ../build/api/ml ../build/api/ml/z3enums.mli ../build/api/ml/z3.mli cd .. @@ -97,123 +158,90 @@ stages: inputs: artifactName: 'UbuntuDoc' targetPath: $(Build.ArtifactStagingDirectory) - - - job: ManyLinuxBuild + + - job: LinuxBuilds + displayName: "ManyLinux build" variables: - python: "/opt/python/cp37-cp37m/bin/python" name: ManyLinux - displayName: "ManyLinux build" + python: "/opt/python/cp37-cp37m/bin/python" pool: vmImage: "ubuntu-latest" container: "quay.io/pypa/manylinux2014_x86_64:latest" steps: - - script: $(python) scripts/mk_unix_dist.py --nodotnet --nojava + - task: PythonScript@0 + displayName: Build + inputs: + scriptSource: 'filepath' + scriptPath: scripts/mk_unix_dist.py + arguments: --nodotnet --nojava + pythonInterpreter: $(python) - script: git clone https://github.com/z3prover/z3test z3test - - script: $(python) z3test/scripts/test_benchmarks.py build-dist/z3 z3test/regressions/smt2 - - script: cp dist/*.zip $(Build.ArtifactStagingDirectory)/ - - task: PublishPipelineArtifact@0 + displayName: 'Clone z3test' + - task: PythonScript@0 + displayName: Test inputs: - artifactName: '$(name)Build' - targetPath: $(Build.ArtifactStagingDirectory) - -# - job: MuslLinuxBuild -# condition: eq(0,1) -# variables: -# python: "/opt/python/cp310-cp310/bin/python" -# name: MuslLinux -# displayName: "MuslLinux build" -# pool: -# vmImage: "ubuntu-latest" -# container: "quay.io/pypa/musllinux_1_1_x86_64:latest" -# steps: -# - script: $(python) scripts/mk_unix_dist.py --nodotnet --nojava -# - script: git clone https://github.com/z3prover/z3test z3test -# - script: $(python) z3test/scripts/test_benchmarks.py build-dist/z3 z3test/regressions/smt2 -# - script: cp dist/*.zip $(Build.ArtifactStagingDirectory)/ -# - task: PublishPipelineArtifact@0 -# inputs: -# artifactName: '$(name)Build' -# targetPath: $(Build.ArtifactStagingDirectory) - - - job: Windows32 - displayName: "Windows 32-bit build" - pool: - vmImage: "windows-latest" - steps: - - task: CmdLine@2 - inputs: - script: - call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x86 & - python scripts\mk_win_dist.py - --assembly-version=$(AssemblyVersion) - --x86-only - --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk - --zip + scriptSource: 'filepath' + scriptPath: z3test/scripts/test_benchmarks.py + arguments: build-dist/z3 z3test/regressions/smt2 + pythonInterpreter: $(python) - task: CopyFiles@2 inputs: sourceFolder: dist contents: '*.zip' targetFolder: $(Build.ArtifactStagingDirectory) - - task: PublishPipelineArtifact@1 + - task: PublishPipelineArtifact@0 inputs: + artifactName: 'ManyLinuxBuild' targetPath: $(Build.ArtifactStagingDirectory) - artifactName: 'Windows32' - - task: CopyFiles@2 - displayName: 'Collect Symbols' - inputs: - sourceFolder: dist - contents: '**/*.pdb' - targetFolder: '$(Build.ArtifactStagingDirectory)/symbols' - # Publish symbol archive to match nuget package - # Index your source code and publish symbols to a file share or Azure Artifacts symbol server - - task: PublishSymbols@2 - inputs: - symbolsFolder: '$(Build.ArtifactStagingDirectory)/symbols' - searchPattern: '**/*.pdb' - indexSources: false # Github not supported - publishSymbols: true - symbolServerType: TeamServices - detailedLog: true - - - job: Windows64 - displayName: "Windows 64-bit build" + + - job: LinuxBuildsArm64 + displayName: "ManyLinux ARM64 build" + variables: + name: ManyLinux + python: "/opt/python/cp37-cp37m/bin/python" pool: - vmImage: "windows-latest" + vmImage: "ubuntu-latest" + container: "quay.io/pypa/manylinux2014_x86_64:latest" steps: - - task: CmdLine@2 - inputs: - script: - call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64 & - python scripts\mk_win_dist.py - --assembly-version=$(AssemblyVersion) - --x64-only - --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk - --zip + - script: curl -L -o /tmp/arm-toolchain.tar.xz 'https://developer.arm.com/-/media/Files/downloads/gnu/11.2-2022.02/binrel/gcc-arm-11.2-2022.02-x86_64-aarch64-none-linux-gnu.tar.xz?rev=33c6e30e5ac64e6dba8f0431f2c35f1b&hash=9918A05BF47621B632C7A5C8D2BB438FB80A4480' + - script: mkdir -p /tmp/arm-toolchain/ + - script: tar xf /tmp/arm-toolchain.tar.xz -C /tmp/arm-toolchain/ --strip-components=1 + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/bin' + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/aarch64-none-linux-gnu/libc/usr/bin' + - script: echo $PATH + - script: stat /tmp/arm-toolchain/bin/aarch64-none-linux-gnu-gcc + - task: PythonScript@0 + displayName: Build + inputs: + scriptSource: 'filepath' + scriptPath: scripts/mk_unix_dist.py + arguments: --nodotnet --nojava --arch=arm64 + pythonInterpreter: $(python) - task: CopyFiles@2 inputs: sourceFolder: dist contents: '*.zip' targetFolder: $(Build.ArtifactStagingDirectory) - - task: PublishPipelineArtifact@1 + - task: PublishPipelineArtifact@0 inputs: + artifactName: 'ManyLinuxBuildArm64' targetPath: $(Build.ArtifactStagingDirectory) - artifactName: 'Windows64' - - task: CopyFiles@2 - displayName: 'Collect Symbols' - inputs: - sourceFolder: dist - contents: '**/*.pdb' - targetFolder: '$(Build.ArtifactStagingDirectory)/symbols' - # Publish symbol archive to match nuget package - # Index your source code and publish symbols to a file share or Azure Artifacts symbol server - - task: PublishSymbols@2 - inputs: - symbolsFolder: '$(Build.ArtifactStagingDirectory)/symbols' - searchPattern: '**/*.pdb' - indexSources: false # Github not supported - publishSymbols: true - symbolServerType: TeamServices - detailedLog: true + + - template: build-win-signed.yml + parameters: + ReleaseVersion: $(ReleaseVersion) + BuildArchitecture: 'x64' + + - template: build-win-signed.yml + parameters: + ReleaseVersion: $(ReleaseVersion) + BuildArchitecture: 'x86' + + - template: build-win-signed-cmake.yml + parameters: + ReleaseVersion: $(ReleaseVersion) + BuildArchitecture: 'arm64' + VCArchitecture: 'amd64_arm64' - stage: Package jobs: @@ -231,23 +259,28 @@ stages: - task: DownloadPipelineArtifact@2 displayName: 'Download Win64 Build' inputs: - artifact: 'Windows64' + artifact: 'WindowsBuild-x64' path: $(Agent.TempDirectory)\package - task: DownloadPipelineArtifact@2 - displayName: 'Download Ubuntu Build' + displayName: 'Download Ubuntu 20.04 Build' inputs: - artifact: 'Ubuntu' + artifact: 'UbuntuBuild20' path: $(Agent.TempDirectory)\package + - task: DownloadPipelineArtifact@2 + displayName: 'Download Ubuntu ARM64 Build' + inputs: + artifact: 'UbuntuArm64' + path: $(Agent.TempDirectory)\package - task: DownloadPipelineArtifact@2 displayName: 'Download macOS Build' inputs: - artifact: 'Mac' + artifact: 'macOsBuild' path: $(Agent.TempDirectory)\package - task: DownloadPipelineArtifact@2 displayName: 'Download macOS Arm64 Build' inputs: artifact: 'MacArm64' - path: $(Agent.TempDirectory)\package + path: $(Agent.TempDirectory)\package - task: NuGetToolInstaller@0 inputs: versionSpec: 5.x @@ -350,7 +383,7 @@ stages: - task: DownloadPipelineArtifact@2 displayName: 'Download Win32 Build' inputs: - artifact: 'Windows32' + artifact: 'WindowsBuild-x86' path: $(Agent.TempDirectory)\package - task: NuGetToolInstaller@0 inputs: @@ -449,24 +482,23 @@ stages: steps: - task: DownloadPipelineArtifact@2 inputs: - artifactName: 'Windows32' + artifactName: 'WindowsBuild-x86' targetPath: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 inputs: - artifactName: 'Windows64' + artifactName: 'WindowsBuild-x64' targetPath: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 inputs: artifactName: 'ManyLinuxBuild' targetPath: $(Agent.TempDirectory) -# - task: DownloadPipelineArtifact@2 -# displayName: 'Download MuslLinux Build' -# inputs: -# artifact: 'MuslLinuxBuild' -# path: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 inputs: - artifactName: 'Mac' + artifactName: 'ManyLinuxBuildArm64' + targetPath: $(Agent.TempDirectory) + - task: DownloadPipelineArtifact@2 + inputs: + artifactName: 'macOsBuild' targetPath: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 inputs: @@ -474,14 +506,16 @@ stages: targetPath: $(Agent.TempDirectory) - script: cd $(Agent.TempDirectory); mkdir osx-x64-bin; cd osx-x64-bin; unzip ../*x64-osx*.zip - script: cd $(Agent.TempDirectory); mkdir osx-arm64-bin; cd osx-arm64-bin; unzip ../*arm64-osx*.zip - - script: cd $(Agent.TempDirectory); mkdir libc-bin; cd libc-bin; unzip ../*glibc*.zip + - script: cd $(Agent.TempDirectory); mkdir libc-x64-bin; cd libc-x64-bin; unzip ../*x64-glibc*.zip + - script: cd $(Agent.TempDirectory); mkdir libc-arm64-bin; cd libc-arm64-bin; unzip ../*arm64-glibc*.zip # - script: cd $(Agent.TempDirectory); mkdir musl-bin; cd musl-bin; unzip ../*-linux.zip - script: cd $(Agent.TempDirectory); mkdir win32-bin; cd win32-bin; unzip ../*x86-win*.zip - script: cd $(Agent.TempDirectory); mkdir win64-bin; cd win64-bin; unzip ../*x64-win*.zip - script: python3 -m pip install --user -U setuptools wheel - script: cd src/api/python; python3 setup.py sdist # take a look at this PREMIUM HACK I came up with to get around the fact that the azure variable syntax overloads the bash syntax for subshells - - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel + - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-x64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel + - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-arm64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel # - script: cd src/api/python; echo $(Agent.TempDirectory)/musl-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - script: cd src/api/python; echo $(Agent.TempDirectory)/win32-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - script: cd src/api/python; echo $(Agent.TempDirectory)/win64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel @@ -504,17 +538,22 @@ stages: - task: DownloadPipelineArtifact@2 displayName: "Download windows32" inputs: - artifactName: 'Windows32' + artifactName: 'WindowsBuild-x86' targetPath: tmp - task: DownloadPipelineArtifact@2 displayName: "Download windows64" inputs: - artifactName: 'Windows64' + artifactName: 'WindowsBuild-x64' + targetPath: tmp + - task: DownloadPipelineArtifact@2 + displayName: "Download windowsARM64" + inputs: + artifactName: 'WindowsBuild-arm64' targetPath: tmp - task: DownloadPipelineArtifact@2 displayName: "Download Mac" inputs: - artifactName: 'Mac' + artifactName: 'macOsBuild' targetPath: tmp - task: DownloadPipelineArtifact@2 displayName: "Download MacArm64" @@ -522,14 +561,14 @@ stages: artifactName: 'MacArm64' targetPath: tmp - task: DownloadPipelineArtifact@2 - displayName: "Download Ubuntu" + displayName: "Download Ubuntu Arm64" inputs: - artifactName: 'Ubuntu' + artifactName: 'UbuntuArm64' targetPath: tmp - task: DownloadPipelineArtifact@2 displayName: "Download Ubuntu-20.04" inputs: - artifactName: 'Ubuntu-20.04' + artifactName: 'UbuntuBuild20' targetPath: tmp - task: DownloadPipelineArtifact@2 displayName: "Download Doc" diff --git a/scripts/release.yml b/scripts/release.yml index 7d3ec1085a0..7fcd93b16da 100644 --- a/scripts/release.yml +++ b/scripts/release.yml @@ -6,7 +6,7 @@ trigger: none variables: - ReleaseVersion: '4.12.2' + ReleaseVersion: '4.13.0' stages: @@ -114,11 +114,26 @@ stages: artifactName: 'UbuntuBuild20' targetPath: $(Build.ArtifactStagingDirectory) + - job: UbuntuArm64 + displayName: "Ubuntu ARM64 build" + pool: + vmImage: "ubuntu-latest" + steps: + - script: python scripts/mk_unix_dist.py --dotnet-key=$(Build.SourcesDirectory)/resources/z3.snk --arch=arm64 + - script: git clone https://github.com/z3prover/z3test z3test + - script: python z3test/scripts/test_benchmarks.py build-dist/z3 z3test/regressions/smt2 + - script: cp dist/*.zip $(Build.ArtifactStagingDirectory)/. + - task: PublishPipelineArtifact@0 + inputs: + artifactName: 'UbuntuArm64' + targetPath: $(Build.ArtifactStagingDirectory) + - job: UbuntuDoc displayName: "Ubuntu Doc build" pool: vmImage: "ubuntu-latest" steps: + - script: pip3 install importlib-resources - script: sudo apt-get install ocaml opam libgmp-dev - script: opam init -y - script: eval `opam config env`; opam install zarith ocamlfind -y @@ -137,8 +152,8 @@ stages: set -e eval `opam config env` cd doc - python mk_api_doc.py --mld --z3py-package-path=../build/python/z3 - python mk_params_doc.py + python3 mk_api_doc.py --mld --z3py-package-path=../build/python/z3 + python3 mk_params_doc.py mkdir api/html/ml ocamldoc -html -d api/html/ml -sort -hide Z3 -I $( ocamlfind query zarith ) -I ../build/api/ml ../build/api/ml/z3enums.mli ../build/api/ml/z3.mli cd .. @@ -184,6 +199,39 @@ stages: artifactName: 'ManyLinuxBuild' targetPath: $(Build.ArtifactStagingDirectory) + - job: LinuxBuildsArm64 + displayName: "ManyLinux ARM64 build" + variables: + name: ManyLinux + python: "/opt/python/cp37-cp37m/bin/python" + pool: + vmImage: "ubuntu-latest" + container: "quay.io/pypa/manylinux2014_x86_64:latest" + steps: + - script: curl -L -o /tmp/arm-toolchain.tar.xz 'https://developer.arm.com/-/media/Files/downloads/gnu/11.2-2022.02/binrel/gcc-arm-11.2-2022.02-x86_64-aarch64-none-linux-gnu.tar.xz?rev=33c6e30e5ac64e6dba8f0431f2c35f1b&hash=9918A05BF47621B632C7A5C8D2BB438FB80A4480' + - script: mkdir -p /tmp/arm-toolchain/ + - script: tar xf /tmp/arm-toolchain.tar.xz -C /tmp/arm-toolchain/ --strip-components=1 + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/bin' + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/aarch64-none-linux-gnu/libc/usr/bin' + - script: echo $PATH + - script: stat /tmp/arm-toolchain/bin/aarch64-none-linux-gnu-gcc + - task: PythonScript@0 + displayName: Build + inputs: + scriptSource: 'filepath' + scriptPath: scripts/mk_unix_dist.py + arguments: --nodotnet --nojava --arch=arm64 + pythonInterpreter: $(python) + - task: CopyFiles@2 + inputs: + sourceFolder: dist + contents: '*.zip' + targetFolder: $(Build.ArtifactStagingDirectory) + - task: PublishPipelineArtifact@0 + inputs: + artifactName: 'ManyLinuxBuildArm64' + targetPath: $(Build.ArtifactStagingDirectory) + - template: build-win-signed.yml parameters: ReleaseVersion: $(ReleaseVersion) @@ -194,6 +242,12 @@ stages: ReleaseVersion: $(ReleaseVersion) BuildArchitecture: 'x86' + - template: build-win-signed-cmake.yml + parameters: + ReleaseVersion: $(ReleaseVersion) + BuildArchitecture: 'arm64' + VCArchitecture: 'amd64_arm64' + # Creates Z3 packages in various formats - stage: Package @@ -225,6 +279,11 @@ stages: inputs: artifact: 'UbuntuBuild20' path: $(Agent.TempDirectory)\package + - task: DownloadPipelineArtifact@2 + displayName: 'Download Ubuntu ARM64 Build' + inputs: + artifact: 'UbuntuArm64' + path: $(Agent.TempDirectory)\package - task: DownloadPipelineArtifact@2 displayName: 'Download macOS Build' inputs: @@ -432,6 +491,11 @@ stages: inputs: artifact: 'ManyLinuxBuild' path: $(Agent.TempDirectory) + - task: DownloadPipelineArtifact@2 + displayName: 'Download ManyLinux Arm64 Build' + inputs: + artifact: 'ManyLinuxBuildArm64' + path: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 displayName: 'Download Win32 Build' inputs: @@ -444,7 +508,8 @@ stages: path: $(Agent.TempDirectory) - script: cd $(Agent.TempDirectory); mkdir osx-x64-bin; cd osx-x64-bin; unzip ../*x64-osx*.zip - script: cd $(Agent.TempDirectory); mkdir osx-arm64-bin; cd osx-arm64-bin; unzip ../*arm64-osx*.zip - - script: cd $(Agent.TempDirectory); mkdir libc-bin; cd libc-bin; unzip ../*glibc*.zip + - script: cd $(Agent.TempDirectory); mkdir libc-x64-bin; cd libc-x64-bin; unzip ../*x64-glibc*.zip + - script: cd $(Agent.TempDirectory); mkdir libc-arm64-bin; cd libc-arm64-bin; unzip ../*arm64-glibc*.zip - script: cd $(Agent.TempDirectory); mkdir win32-bin; cd win32-bin; unzip ../*x86-win*.zip - script: cd $(Agent.TempDirectory); mkdir win64-bin; cd win64-bin; unzip ../*x64-win*.zip - script: python3 -m pip install --user -U setuptools wheel @@ -452,7 +517,8 @@ stages: # take a look at this PREMIUM HACK I came up with to get around the fact that the azure variable syntax overloads the bash syntax for subshells - script: cd src/api/python; echo $(Agent.TempDirectory)/osx-x64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - script: cd src/api/python; echo $(Agent.TempDirectory)/osx-arm64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel + - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-x64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel + - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-arm64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - script: cd src/api/python; echo $(Agent.TempDirectory)/win32-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - script: cd src/api/python; echo $(Agent.TempDirectory)/win64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - task: PublishPipelineArtifact@0 @@ -480,6 +546,11 @@ stages: inputs: artifact: 'UbuntuBuild' path: $(Agent.TempDirectory) + - task: DownloadPipelineArtifact@2 + displayName: "Download Ubuntu Arm64" + inputs: + artifactName: 'UbuntuArm64' + path: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 displayName: "Download Doc" inputs: @@ -501,14 +572,19 @@ stages: artifact: 'WindowsBuild-x86' path: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 - displayName: "Download Python" + displayName: 'Download Win64 Build' inputs: - artifactName: 'PythonPackage' + artifact: 'WindowsBuild-x64' path: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 - displayName: 'Download Win64 Build' + displayName: 'Download Arm64 Build' inputs: - artifact: 'WindowsBuild-x64' + artifact: 'WindowsBuild-arm64' + path: $(Agent.TempDirectory) + - task: DownloadPipelineArtifact@2 + displayName: "Download Python" + inputs: + artifactName: 'PythonPackage' path: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 displayName: 'Download NuGet64 Package' diff --git a/scripts/test-java-cmake.yml b/scripts/test-java-cmake.yml index 0f704534464..fa9f767b5c8 100644 --- a/scripts/test-java-cmake.yml +++ b/scripts/test-java-cmake.yml @@ -3,7 +3,7 @@ steps: cd build mkdir -p examples/java cp ../examples/java/JavaExample.java examples/java/ - javac examples/java/Javaexamplejava -classpath com.microsoft.z3.jar + javac examples/java/JavaExample.java -classpath com.microsoft.z3.jar export LD_LIBRARY_PATH=$(pwd):${LD_LIBRARY_PATH} java -cp .:examples/java:com.microsoft.z3.jar JavaExample cd .. diff --git a/scripts/update_api.py b/scripts/update_api.py index 4295b8961c1..79f144142c0 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -1,4 +1,4 @@ -# - !/usr/bin/env python +#!/usr/bin/env python ############################################ # Copyright (c) 2012 Microsoft Corporation # @@ -116,8 +116,8 @@ def def_Type(self, var, c_type, py_type): def def_Types(self, api_files): global Closures - pat1 = re.compile(" *def_Type\(\'(.*)\',[^\']*\'(.*)\',[^\']*\'(.*)\'\)[ \t]*") - pat2 = re.compile("Z3_DECLARE_CLOSURE\((.*),(.*), \((.*)\)\)") + pat1 = re.compile(r" *def_Type\(\'(.*)\',[^\']*\'(.*)\',[^\']*\'(.*)\'\)[ \t]*") + pat2 = re.compile(r"Z3_DECLARE_CLOSURE\((.*),(.*), \((.*)\)\)") for api_file in api_files: with open(api_file, 'r') as api: for line in api: @@ -426,9 +426,10 @@ def mk_dotnet(dotnet): dotnet.write(' {\n\n') for name, ret, sig in Closures: + sig = sig.replace("unsigned const*","uint[]") sig = sig.replace("void*","voidp").replace("unsigned","uint") sig = sig.replace("Z3_ast*","ref IntPtr").replace("uint*","ref uint").replace("Z3_lbool*","ref int") - ret = ret.replace("void*","voidp").replace("unsigned","uint") + ret = ret.replace("void*","voidp").replace("unsigned","uint") if "*" in sig or "*" in ret: continue dotnet.write(' [UnmanagedFunctionPointer(CallingConvention.Cdecl)]\n') @@ -633,7 +634,74 @@ def mk_java(java_src, java_dir, package_name): java_native.write(' }\n') java_native.write(' }\n') java_native.write(' }\n') - + java_native.write(""" + public static native long propagateInit(Object o, long ctx, long solver); + public static native void propagateRegisterCreated(Object o, long ctx, long solver); + public static native void propagateRegisterFixed(Object o, long ctx, long solver); + public static native void propagateRegisterEq(Object o, long ctx, long solver); + public static native void propagateRegisterDecide(Object o, long ctx, long solver); + public static native void propagateRegisterFinal(Object o, long ctx, long solver); + public static native void propagateAdd(Object o, long ctx, long solver, long javainfo, long e); + public static native boolean propagateConsequence(Object o, long ctx, long solver, long javainfo, int num_fixed, long[] fixed, long num_eqs, long[] eq_lhs, long[] eq_rhs, long conseq); + public static native boolean propagateNextSplit(Object o, long ctx, long solver, long javainfo, long e, long idx, int phase); + public static native void propagateDestroy(Object o, long ctx, long solver, long javainfo); + + public static abstract class UserPropagatorBase implements AutoCloseable { + protected long ctx; + protected long solver; + protected long javainfo; + + public UserPropagatorBase(long _ctx, long _solver) { + ctx = _ctx; + solver = _solver; + javainfo = propagateInit(this, ctx, solver); + } + + @Override + public void close() { + Native.propagateDestroy(this, ctx, solver, javainfo); + javainfo = 0; + solver = 0; + ctx = 0; + } + + protected final void registerCreated() { + Native.propagateRegisterCreated(this, ctx, solver); + } + + protected final void registerFixed() { + Native.propagateRegisterFixed(this, ctx, solver); + } + + protected final void registerEq() { + Native.propagateRegisterEq(this, ctx, solver); + } + + protected final void registerDecide() { + Native.propagateRegisterDecide(this, ctx, solver); + } + + protected final void registerFinal() { + Native.propagateRegisterFinal(this, ctx, solver); + } + + protected abstract void pushWrapper(); + + protected abstract void popWrapper(int number); + + protected abstract void finWrapper(); + + protected abstract void eqWrapper(long lx, long ly); + + protected abstract UserPropagatorBase freshWrapper(long lctx); + + protected abstract void createdWrapper(long le); + + protected abstract void fixedWrapper(long lvar, long lvalue); + + protected abstract void decideWrapper(long lvar, int bit, boolean is_pos); + } + """) java_native.write('\n') for name, result, params in _dotnet_decls: java_native.write(' protected static native %s INTERNAL%s(' % (type2java(result), java_method_name(name))) @@ -700,7 +768,7 @@ def mk_java(java_src, java_dir, package_name): java_wrapper.write(line) for name, result, params in _dotnet_decls: java_wrapper.write('DLL_VIS JNIEXPORT %s JNICALL Java_%s_Native_INTERNAL%s(JNIEnv * jenv, jclass cls' % (type2javaw(result), pkg_str, java_method_name(name))) - i = 0 + i = 0 for param in params: java_wrapper.write(', ') java_wrapper.write('%s a%d' % (param2javaw(param), i)) @@ -1761,17 +1829,28 @@ def write_core_py_preamble(core_py): core_py.write( """ # Automatically generated file +import atexit import sys, os +import contextlib import ctypes -import pkg_resources +if sys.version_info >= (3, 9): + import importlib.resources as importlib_resources +else: + import importlib_resources from .z3types import * from .z3consts import * +_file_manager = contextlib.ExitStack() +atexit.register(_file_manager.close) _ext = 'dll' if sys.platform in ('win32', 'cygwin') else 'dylib' if sys.platform == 'darwin' else 'so' _lib = None +_z3_lib_resource = importlib_resources.files('z3').joinpath('lib') +_z3_lib_resource_path = _file_manager.enter_context( + importlib_resources.as_file(_z3_lib_resource) +) _default_dirs = ['.', os.path.dirname(os.path.abspath(__file__)), - pkg_resources.resource_filename('z3', 'lib'), + _z3_lib_resource_path, os.path.join(sys.prefix, 'lib'), None] _all_dirs = [] @@ -1821,10 +1900,10 @@ def write_core_py_preamble(core_py): print(" - to the custom Z3_LIB_DIRS Python-builtin before importing the z3 module, e.g. via") if sys.version < '3': print(" import __builtin__") - print(" __builtin__.Z3_LIB_DIRS = [ '/path/to/libz3.%s' ] " % _ext) + print(" __builtin__.Z3_LIB_DIRS = [ '/path/to/z3/lib/dir' ] # directory containing libz3.%s" % _ext) else: print(" import builtins") - print(" builtins.Z3_LIB_DIRS = [ '/path/to/libz3.%s' ] " % _ext) + print(" builtins.Z3_LIB_DIRS = [ '/path/to/z3/lib/dir' ] # directory containing libz3.%s" % _ext) print(_failures) raise Z3Exception("libz3.%s not found." % _ext) @@ -1854,7 +1933,7 @@ def _to_pystr(s): _lib.Z3_set_error_handler.restype = None _lib.Z3_set_error_handler.argtypes = [ContextObj, _error_handler_type] -Z3_on_clause_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) +Z3_on_clause_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p) Z3_push_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p) Z3_pop_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint) Z3_fresh_eh = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) @@ -1864,7 +1943,7 @@ def _to_pystr(s): Z3_eq_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) Z3_created_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) -Z3_decide_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) +Z3_decide_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_int) _lib.Z3_solver_register_on_clause.restype = None _lib.Z3_solver_propagate_init.restype = None diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 27ff2b0d174..b0971f835b2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -54,6 +54,7 @@ add_subdirectory(ast/euf) add_subdirectory(ast/converters) add_subdirectory(ast/substitution) add_subdirectory(ast/simplifiers) +add_subdirectory(ast/sls) add_subdirectory(tactic) add_subdirectory(qe/mbp) add_subdirectory(qe/lite) @@ -221,7 +222,7 @@ if (MSVC) set(dll_module_exports_file "${CMAKE_CURRENT_BINARY_DIR}/api_dll.def") add_custom_command(OUTPUT "${dll_module_exports_file}" COMMAND - "${PYTHON_EXECUTABLE}" + "${Python3_EXECUTABLE}" "${PROJECT_SOURCE_DIR}/scripts/mk_def_file.py" "${dll_module_exports_file}" "libz3" diff --git a/src/ackermannization/ackermannize_bv_tactic.h b/src/ackermannization/ackermannize_bv_tactic.h index b99ae00e47f..1afac4498c5 100644 --- a/src/ackermannization/ackermannize_bv_tactic.h +++ b/src/ackermannization/ackermannize_bv_tactic.h @@ -22,7 +22,7 @@ A tactic for performing Ackermann reduction for bit-vector formulas ### Long Description The Ackermann reduction replaces uninterpreted functions $f(t_1), f(t_2)$ -by fresh variables $f_1, f_2$ and addes axioms $t_1 \simeq t_2 \implies f_1 \simeq f_2$. +by fresh variables $f_1, f_2$ and adds axioms $t_1 \simeq t_2 \implies f_1 \simeq f_2$. The reduction has the effect of eliminating uninterpreted functions. When the reduction produces a pure bit-vector benchmark, it allows Z3 to use a specialized SAT solver. diff --git a/src/ackermannization/ackr_helper.h b/src/ackermannization/ackr_helper.h index 9abe151e93b..5499e7d3a0b 100644 --- a/src/ackermannization/ackr_helper.h +++ b/src/ackermannization/ackr_helper.h @@ -18,6 +18,7 @@ #include "ast/bv_decl_plugin.h" #include "ast/array_decl_plugin.h" +#include "ast/ast_ll_pp.h" class ackr_helper { public: @@ -40,10 +41,8 @@ class ackr_helper { inline bool is_uninterp_fn(app const * a) const { if (is_uninterp(a)) return true; - else { - decl_plugin * p = m_bvutil.get_manager().get_plugin(a->get_family_id()); - return p->is_considered_uninterpreted(a->get_decl()); - } + decl_plugin * p = m_bvutil.get_manager().get_plugin(a->get_family_id()); + return p->is_considered_uninterpreted(a->get_decl()); } /** @@ -64,9 +63,8 @@ class ackr_helper { } } else { - for (expr* arg : *a) { + for (expr* arg : *a) non_select.mark(arg, true); - } } } @@ -112,7 +110,8 @@ class ackr_helper { } void insert(fun2terms_map& f2t, sel2terms_map& s2t, app* a) { - if (a->get_num_args() == 0) return; + if (a->get_num_args() == 0) + return; ast_manager& m = m_bvutil.get_manager(); app_set* ts = nullptr; bool is_const_args = true; @@ -129,21 +128,18 @@ class ackr_helper { ts = alloc(app_set); f2t.insert(fd, ts); } - is_const_args = m.is_value(a->get_arg(0)); + is_const_args = m.is_unique_value(a->get_arg(0)); } - else { + else return; - } - for (unsigned i = 1; is_const_args && i < a->get_num_args(); ++i) { - is_const_args &= m.is_value(a->get_arg(i)); - } - if (is_const_args) { + for (unsigned i = 1; is_const_args && i < a->get_num_args(); ++i) + is_const_args &= m.is_unique_value(a->get_arg(i)); + + if (is_const_args) ts->const_args.insert(a); - } - else { + else ts->var_args.insert(a); - } } private: diff --git a/src/ackermannization/ackr_info.h b/src/ackermannization/ackr_info.h index fd236106403..67c41bda22d 100644 --- a/src/ackermannization/ackr_info.h +++ b/src/ackermannization/ackr_info.h @@ -23,7 +23,7 @@ Revision History: /** \brief Information about how a formula is being converted into - a formula without uninterpreted function symbols via ackermannization. + a formula without uninterpreted function symbols via ackermannization. The intended use is that new terms are added via set_abstr. Once all terms are abstracted, call seal. diff --git a/src/ackermannization/lackr.cpp b/src/ackermannization/lackr.cpp index 887f1d51941..b02ef8c44a1 100644 --- a/src/ackermannization/lackr.cpp +++ b/src/ackermannization/lackr.cpp @@ -22,6 +22,8 @@ #include "ackermannization/ackr_info.h" #include "ast/for_each_expr.h" #include "ast/ast_util.h" +#include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" #include "model/model_smt2_pp.h" lackr::lackr(ast_manager& m, const params_ref& p, lackr_stats& st, @@ -142,10 +144,10 @@ bool lackr::ackr(app * const t1, app * const t2) { // Introduce the ackermann lemma for each pair of terms. // void lackr::eager_enc() { - TRACE("ackermannize", tout << "#funs: " << m_fun2terms.size() << "#sels: " << m_sel2terms.size() << std::endl;); - for (auto const& kv : m_fun2terms) { + TRACE("ackermannize", tout << "#funs: " << m_fun2terms.size() << " #sels: " << m_sel2terms.size() << std::endl;); + for (auto const& [k,v] : m_fun2terms) { checkpoint(); - ackr(kv.get_value()); + ackr(v); } for (auto const& kv : m_sel2terms) { checkpoint(); @@ -172,14 +174,13 @@ void lackr::ackr(app_set const* ts) { } void lackr::abstract_fun(fun2terms_map const& apps) { - for (auto const& kv : apps) { - func_decl* fd = kv.m_key; - for (app * t : kv.m_value->var_args) { + for (auto const& [fd, v] : apps) { + for (app * t : v->var_args) { app * fc = m.mk_fresh_const(fd->get_name(), t->get_sort()); SASSERT(t->get_decl() == fd); m_info->set_abstr(t, fc); } - for (app * t : kv.m_value->const_args) { + for (app * t : v->const_args) { app * fc = m.mk_fresh_const(fd->get_name(), t->get_sort()); SASSERT(t->get_decl() == fd); m_info->set_abstr(t, fc); diff --git a/src/api/CMakeLists.txt b/src/api/CMakeLists.txt index 5cc604f1539..38a09b63187 100644 --- a/src/api/CMakeLists.txt +++ b/src/api/CMakeLists.txt @@ -18,7 +18,7 @@ foreach (gen_file ${generated_files}) endforeach() add_custom_command(OUTPUT ${generated_files} - COMMAND "${PYTHON_EXECUTABLE}" + COMMAND "${Python3_EXECUTABLE}" "${PROJECT_SOURCE_DIR}/scripts/update_api.py" ${Z3_FULL_PATH_API_HEADER_FILES_TO_SCAN} "--api_output_dir" diff --git a/src/api/api_ast.cpp b/src/api/api_ast.cpp index bc29826ff1a..424b361f3ea 100644 --- a/src/api/api_ast.cpp +++ b/src/api/api_ast.cpp @@ -29,6 +29,7 @@ Revision History: #include "ast/ast_ll_pp.h" #include "ast/ast_smt_pp.h" #include "ast/ast_smt2_pp.h" +#include "ast/polymorphism_util.h" #include "ast/rewriter/th_rewriter.h" #include "ast/rewriter/var_subst.h" #include "ast/rewriter/expr_safe_replace.h" @@ -88,6 +89,16 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + Z3_sort Z3_API Z3_mk_type_variable(Z3_context c, Z3_symbol name) { + Z3_TRY; + LOG_Z3_mk_type_variable(c, name); + RESET_ERROR_CODE(); + sort* ty = mk_c(c)->m().mk_type_var(to_symbol(name)); + mk_c(c)->save_ast_trail(ty); + RETURN_Z3(of_sort(ty)); + Z3_CATCH_RETURN(nullptr); + } + bool Z3_API Z3_is_eq_ast(Z3_context c, Z3_ast s1, Z3_ast s2) { RESET_ERROR_CODE(); return s1 == s2; @@ -135,8 +146,7 @@ extern "C" { ast_manager& m = mk_c(c)->m(); recfun::decl::plugin& p = mk_c(c)->recfun().get_plugin(); if (!p.has_def(d)) { - std::string msg = "function " + mk_pp(d, m) + " needs to be declared using rec_func_decl"; - SET_ERROR_CODE(Z3_INVALID_ARG, msg.c_str()); + SET_ERROR_CODE(Z3_INVALID_ARG, "function " + mk_pp(d, m) + " needs to be declared using rec_func_decl"); return; } expr_ref abs_body(m); @@ -157,8 +167,7 @@ extern "C" { return; } if (!pd.get_def()->get_cases().empty()) { - std::string msg = "function " + mk_pp(d, m) + " has already been given a definition"; - SET_ERROR_CODE(Z3_INVALID_ARG, msg.c_str()); + SET_ERROR_CODE(Z3_INVALID_ARG, "function " + mk_pp(d, m) + " has already been given a definition"); return; } @@ -180,7 +189,20 @@ extern "C" { arg_list.push_back(to_expr(args[i])); } func_decl* _d = reinterpret_cast(d); - app* a = mk_c(c)->m().mk_app(_d, num_args, arg_list.data()); + ast_manager& m = mk_c(c)->m(); + if (_d->is_polymorphic()) { + polymorphism::util u(m); + polymorphism::substitution sub(m); + ptr_buffer domain; + for (unsigned i = 0; i < num_args; ++i) { + if (!sub.match(_d->get_domain(i), arg_list[i]->get_sort())) + SET_ERROR_CODE(Z3_INVALID_ARG, "failed to match argument of polymorphic function"); + domain.push_back(arg_list[i]->get_sort()); + } + sort_ref range = sub(_d->get_range()); + _d = m.instantiate_polymorphic(_d, num_args, domain.data(), range); + } + app* a = m.mk_app(_d, num_args, arg_list.data()); mk_c(c)->save_ast_trail(a); check_sorts(c, a); RETURN_Z3(of_ast(a)); @@ -353,9 +375,7 @@ extern "C" { RESET_ERROR_CODE(); symbol _s = to_symbol(s); if (_s.is_numerical()) { - std::ostringstream buffer; - buffer << _s.get_num(); - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::to_string(_s.get_num())); } else { return mk_c(c)->mk_external_string(_s.str()); @@ -728,6 +748,9 @@ extern "C" { else if (fid == mk_c(c)->get_char_fid() && k == CHAR_SORT) { return Z3_CHAR_SORT; } + else if (fid == poly_family_id) { + return Z3_TYPE_VAR; + } else { return Z3_UNKNOWN_SORT; } @@ -796,7 +819,7 @@ extern "C" { param_descrs descrs; th_rewriter::get_param_descrs(descrs); descrs.display(buffer); - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::move(buffer).str()); Z3_CATCH_RETURN(""); } @@ -1004,7 +1027,7 @@ extern "C" { default: UNREACHABLE(); } - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::move(buffer).str()); Z3_CATCH_RETURN(nullptr); } @@ -1039,7 +1062,7 @@ extern "C" { pp.add_assumption(to_expr(assumptions[i])); } pp.display_smt2(buffer, to_expr(formula)); - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::move(buffer).str()); Z3_CATCH_RETURN(""); } diff --git a/src/api/api_ast_map.cpp b/src/api/api_ast_map.cpp index 5976d0e4181..523ba1f5976 100644 --- a/src/api/api_ast_map.cpp +++ b/src/api/api_ast_map.cpp @@ -160,8 +160,8 @@ extern "C" { for (; it != end; ++it) { buffer << "\n (" << mk_ismt2_pp(it->m_key, mng, 3) << "\n " << mk_ismt2_pp(it->m_value, mng, 3) << ")"; } - buffer << ")"; - return mk_c(c)->mk_external_string(buffer.str()); + buffer << ')'; + return mk_c(c)->mk_external_string(std::move(buffer).str()); Z3_CATCH_RETURN(nullptr); } diff --git a/src/api/api_bv.cpp b/src/api/api_bv.cpp index bb426373004..3ea5ba91810 100644 --- a/src/api/api_bv.cpp +++ b/src/api/api_bv.cpp @@ -227,6 +227,9 @@ Z3_ast Z3_API NAME(Z3_context c, unsigned i, Z3_ast n) { \ Z3_ast Z3_API Z3_mk_bvadd_no_underflow(Z3_context c, Z3_ast t1, Z3_ast t2) { Z3_TRY; RESET_ERROR_CODE(); + // l1 := t1 t1 + t2 y Z3_ast zero = Z3_mk_int(c, 0, Z3_get_sort(c, t1)); Z3_inc_ref(c, zero); Z3_ast minus_t2 = Z3_mk_bvneg(c, t2); diff --git a/src/api/api_config_params.cpp b/src/api/api_config_params.cpp index 93d4e27e1bd..d463b1fb7a6 100644 --- a/src/api/api_config_params.cpp +++ b/src/api/api_config_params.cpp @@ -98,7 +98,10 @@ extern "C" { LOG_Z3_set_param_value(c, param_id, param_value); try { ast_context_params * p = reinterpret_cast(c); - p->set(param_id, param_value); + if (p->is_shell_only_parameter(param_id)) + warning_msg("parameter %s can only be set for the shell, not binary API", param_id); + else + p->set(param_id, param_value); } catch (z3_exception & ex) { // The error handler is only available for contexts @@ -111,7 +114,10 @@ extern "C" { Z3_TRY; LOG_Z3_update_param_value(c, param_id, param_value); RESET_ERROR_CODE(); - mk_c(c)->params().set(param_id, param_value); + if (mk_c(c)->params().is_shell_only_parameter(param_id)) + warning_msg("parameter %s can only be set for the shell, not binary API", param_id); + else + mk_c(c)->params().set(param_id, param_value); Z3_CATCH; } diff --git a/src/api/api_context.cpp b/src/api/api_context.cpp index 2b7a4ce43a5..344224dd386 100644 --- a/src/api/api_context.cpp +++ b/src/api/api_context.cpp @@ -78,6 +78,11 @@ namespace api { m().dec_ref(a); } + // flush_objects can only be called in the main thread. + // This ensures that the calls to m().dec_ref() and dealloc(o) + // only happens in the main thread. + // Calls to dec_ref are allowed in other threads when m_concurrent_dec_ref is + // set to true. void context::flush_objects() { #ifndef SINGLE_THREAD if (!m_concurrent_dec_ref) @@ -157,6 +162,9 @@ namespace api { flush_objects(); for (auto& kv : m_allocated_objects) { api::object* val = kv.m_value; +#ifdef SINGLE_THREAD +# define m_concurrent_dec_ref false +#endif DEBUG_CODE(if (!m_concurrent_dec_ref) warning_msg("Uncollected memory: %d: %s", kv.m_key, typeid(*val).name());); dealloc(val); } @@ -330,12 +338,12 @@ namespace api { std::ostringstream buffer; app * a = to_app(n); buffer << mk_pp(a->get_decl(), m()) << " applied to: "; - if (a->get_num_args() > 1) buffer << "\n"; + if (a->get_num_args() > 1) buffer << '\n'; for (unsigned i = 0; i < a->get_num_args(); ++i) { buffer << mk_bounded_pp(a->get_arg(i), m(), 3) << " of sort "; - buffer << mk_pp(a->get_arg(i)->get_sort(), m()) << "\n"; + buffer << mk_pp(a->get_arg(i)->get_sort(), m()) << '\n'; } - auto str = buffer.str(); + auto str = std::move(buffer).str(); warning_msg("%s", str.c_str()); break; } diff --git a/src/api/api_context.h b/src/api/api_context.h index a3f027dd576..8b049ce1691 100644 --- a/src/api/api_context.h +++ b/src/api/api_context.h @@ -231,7 +231,7 @@ namespace api { void handle_exception(z3_exception & ex); char const * get_exception_msg() const { return m_exception_msg.c_str(); } - // Interrupt the current interruptable object + // Interrupt the current interruptible object void interrupt(); void invoke_error_handler(Z3_error_code c); diff --git a/src/api/api_datatype.cpp b/src/api/api_datatype.cpp index 71d1de21210..1ef4ea62688 100644 --- a/src/api/api_datatype.cpp +++ b/src/api/api_datatype.cpp @@ -241,6 +241,20 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + unsigned Z3_API Z3_constructor_num_fields(Z3_context c, Z3_constructor constr) { + Z3_TRY; + LOG_Z3_constructor_num_fields(c, constr); + RESET_ERROR_CODE(); + mk_c(c)->reset_last_result(); + if (!constr) { + SET_ERROR_CODE(Z3_INVALID_ARG, nullptr); + return 0; + } + constructor* c = reinterpret_cast(constr); + return c->m_field_names.size(); + Z3_CATCH_RETURN(0); + } + void Z3_API Z3_query_constructor(Z3_context c, Z3_constructor constr, diff --git a/src/api/api_fpa.cpp b/src/api/api_fpa.cpp index 2dda84af45a..3c350ed180d 100644 --- a/src/api/api_fpa.cpp +++ b/src/api/api_fpa.cpp @@ -742,7 +742,7 @@ extern "C" { fpa_util & fu = ctx->fpautil(); if (!ctx->bvutil().is_bv(to_expr(bv)) || !fu.is_float(to_sort(s))) { - SET_ERROR_CODE(Z3_INVALID_ARG, "bv sort the flaot sort expected"); + SET_ERROR_CODE(Z3_INVALID_ARG, "bv sort the float sort expected"); return nullptr; } expr * a = fu.mk_to_fp(to_sort(s), to_expr(bv)); @@ -1022,7 +1022,7 @@ extern "C" { if (mpfm.is_inf(val)) mpqm.set(q, 0); std::stringstream ss; mpqm.display_decimal(ss, q, sbits); - return mk_c(c)->mk_external_string(ss.str()); + return mk_c(c)->mk_external_string(std::move(ss).str()); Z3_CATCH_RETURN(""); } @@ -1100,7 +1100,7 @@ extern "C" { } std::stringstream ss; ss << exp; - return mk_c(c)->mk_external_string(ss.str()); + return mk_c(c)->mk_external_string(std::move(ss).str()); Z3_CATCH_RETURN(""); } diff --git a/src/api/api_goal.cpp b/src/api/api_goal.cpp index cfe0974df0a..cbb67f2a269 100644 --- a/src/api/api_goal.cpp +++ b/src/api/api_goal.cpp @@ -185,7 +185,7 @@ extern "C" { std::ostringstream buffer; to_goal_ref(g)->display(buffer); // Hack for removing the trailing '\n' - std::string result = buffer.str(); + std::string result = std::move(buffer).str(); SASSERT(result.size() > 0); result.resize(result.size()-1); return mk_c(c)->mk_external_string(std::move(result)); @@ -203,7 +203,7 @@ extern "C" { } to_goal_ref(g)->display_dimacs(buffer, include_names); // Hack for removing the trailing '\n' - std::string result = buffer.str(); + std::string result = std::move(buffer).str(); SASSERT(result.size() > 0); result.resize(result.size()-1); return mk_c(c)->mk_external_string(std::move(result)); diff --git a/src/api/api_model.cpp b/src/api/api_model.cpp index 3512b4b0502..e449cb0ea2b 100644 --- a/src/api/api_model.cpp +++ b/src/api/api_model.cpp @@ -432,14 +432,14 @@ extern "C" { if (mk_c(c)->get_print_mode() == Z3_PRINT_SMTLIB2_COMPLIANT) { model_smt2_pp(buffer, mk_c(c)->m(), *(to_model_ref(m)), 0); // Hack for removing the trailing '\n' - result = buffer.str(); + result = std::move(buffer).str(); if (!result.empty()) result.resize(result.size()-1); } else { model_params p; model_v2_pp(buffer, *(to_model_ref(m)), p.partial()); - result = buffer.str(); + result = std::move(buffer).str(); } return mk_c(c)->mk_external_string(std::move(result)); Z3_CATCH_RETURN(nullptr); diff --git a/src/api/api_numeral.cpp b/src/api/api_numeral.cpp index 7d8c00fbecf..b90a84bb791 100644 --- a/src/api/api_numeral.cpp +++ b/src/api/api_numeral.cpp @@ -29,13 +29,12 @@ bool is_numeral_sort(Z3_context c, Z3_sort ty) { if (!ty) return false; sort * _ty = to_sort(ty); family_id fid = _ty->get_family_id(); - if (fid != mk_c(c)->get_arith_fid() && - fid != mk_c(c)->get_bv_fid() && - fid != mk_c(c)->get_datalog_fid() && - fid != mk_c(c)->get_fpa_fid()) { - return false; - } - return true; + return + fid == mk_c(c)->get_arith_fid() || + fid == mk_c(c)->get_bv_fid() || + fid == mk_c(c)->get_datalog_fid() || + fid == mk_c(c)->get_fpa_fid(); + } static bool check_numeral_sort(Z3_context c, Z3_sort ty) { @@ -152,7 +151,7 @@ extern "C" { mk_c(c)->bvutil().is_numeral(e) || mk_c(c)->fpautil().is_numeral(e) || mk_c(c)->fpautil().is_rm_numeral(e) || - mk_c(c)->datalog_util().is_numeral_ext(e); + mk_c(c)->datalog_util().is_numeral(e); Z3_CATCH_RETURN(false); } @@ -189,8 +188,8 @@ extern "C" { bool ok = Z3_get_numeral_rational(c, a, r); if (ok && r.is_int() && !r.is_neg()) { std::stringstream strm; - r.display_bin(strm, r.get_num_bits()); - return mk_c(c)->mk_external_string(strm.str()); + strm << r.as_bin(r.get_num_bits()); + return mk_c(c)->mk_external_string(std::move(strm).str()); } else { SET_ERROR_CODE(Z3_INVALID_ARG, nullptr); @@ -238,7 +237,7 @@ extern "C" { else if (mk_c(c)->fpautil().is_numeral(to_expr(a), tmp)) { std::ostringstream buffer; fu.fm().display_smt2(buffer, tmp, false); - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::move(buffer).str()); } else { SET_ERROR_CODE(Z3_INVALID_ARG, nullptr); @@ -289,21 +288,21 @@ extern "C" { if (u.is_numeral(e, r) && !r.is_int()) { std::ostringstream buffer; r.display_decimal(buffer, precision); - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::move(buffer).str()); } if (u.is_irrational_algebraic_numeral(e)) { algebraic_numbers::anum const & n = u.to_irrational_algebraic_numeral(e); algebraic_numbers::manager & am = u.am(); std::ostringstream buffer; am.display_decimal(buffer, n, precision); - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::move(buffer).str()); } else if (mk_c(c)->fpautil().is_rm_numeral(to_expr(a), rm)) return Z3_get_numeral_string(c, a); else if (mk_c(c)->fpautil().is_numeral(to_expr(a), ftmp)) { std::ostringstream buffer; fu.fm().display_decimal(buffer, ftmp, 12); - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::move(buffer).str()); } else if (Z3_get_numeral_rational(c, a, r)) { return mk_c(c)->mk_external_string(r.to_string()); diff --git a/src/api/api_quant.cpp b/src/api/api_quant.cpp index 885cf659872..bb9efa9c30b 100644 --- a/src/api/api_quant.cpp +++ b/src/api/api_quant.cpp @@ -249,7 +249,10 @@ extern "C" { expr_abstract(mk_c(c)->m(), 0, num_bound, bound_asts.data(), pat, result); SASSERT(result.get()->get_kind() == AST_APP); pinned.push_back(result.get()); - SASSERT(mk_c(c)->m().is_pattern(result.get())); + if (!mk_c(c)->m().is_pattern(result.get())) { + SET_ERROR_CODE(Z3_INVALID_ARG, "invalid pattern"); + RETURN_Z3(nullptr); + } _patterns.push_back(of_pattern(result.get())); } svector _no_patterns; @@ -380,6 +383,36 @@ extern "C" { Z3_CATCH_RETURN(0); } + Z3_symbol Z3_API Z3_get_quantifier_skolem_id(Z3_context c, Z3_ast a) { + Z3_TRY; + LOG_Z3_get_quantifier_skolem_id(c, a); + RESET_ERROR_CODE(); + ast * _a = to_ast(a); + if (_a->get_kind() == AST_QUANTIFIER) { + return of_symbol(to_quantifier(_a)->get_skid()); + } + else { + SET_ERROR_CODE(Z3_SORT_ERROR, nullptr); + return of_symbol(symbol::null); + } + Z3_CATCH_RETURN(of_symbol(symbol::null)); + } + + Z3_symbol Z3_API Z3_get_quantifier_id(Z3_context c, Z3_ast a) { + Z3_TRY; + LOG_Z3_get_quantifier_skolem_id(c, a); + RESET_ERROR_CODE(); + ast * _a = to_ast(a); + if (_a->get_kind() == AST_QUANTIFIER) { + return of_symbol(to_quantifier(_a)->get_qid()); + } + else { + SET_ERROR_CODE(Z3_SORT_ERROR, nullptr); + return of_symbol(symbol::null); + } + Z3_CATCH_RETURN(of_symbol(symbol::null)); + } + unsigned Z3_API Z3_get_quantifier_num_patterns(Z3_context c, Z3_ast a) { Z3_TRY; LOG_Z3_get_quantifier_num_patterns(c, a); diff --git a/src/api/api_rcf.cpp b/src/api/api_rcf.cpp index 8a000be1ad2..ccc79bc46bd 100644 --- a/src/api/api_rcf.cpp +++ b/src/api/api_rcf.cpp @@ -17,7 +17,7 @@ Module Name: Leonardo de Moura (leonardo) 2012-01-05 Notes: - + --*/ #include "api/z3.h" #include "api/api_log_macros.h" @@ -32,12 +32,12 @@ static void reset_rcf_cancel(Z3_context c) { // no-op } -static Z3_rcf_num from_rcnumeral(rcnumeral a) { - return reinterpret_cast(a.data()); +static Z3_rcf_num from_rcnumeral(rcnumeral a) { + return reinterpret_cast(a.data()); } -static rcnumeral to_rcnumeral(Z3_rcf_num a) { - return rcnumeral::mk(a); +static rcnumeral to_rcnumeral(Z3_rcf_num a) { + return rcnumeral::mk(a); } extern "C" { @@ -179,7 +179,7 @@ extern "C" { RETURN_Z3(from_rcnumeral(r)); Z3_CATCH_RETURN(nullptr); } - + Z3_rcf_num Z3_API Z3_rcf_neg(Z3_context c, Z3_rcf_num a) { Z3_TRY; LOG_Z3_rcf_neg(c, a); @@ -302,4 +302,139 @@ extern "C" { Z3_CATCH; } + bool Z3_API Z3_rcf_is_rational(Z3_context c, Z3_rcf_num a) { + Z3_TRY; + LOG_Z3_rcf_is_rational(c, a); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return rcfm(c).is_rational(to_rcnumeral(a)); + Z3_CATCH_RETURN(false); + } + + bool Z3_API Z3_rcf_is_algebraic(Z3_context c, Z3_rcf_num a) { + Z3_TRY; + LOG_Z3_rcf_is_algebraic(c, a); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return rcfm(c).is_algebraic(to_rcnumeral(a)); + Z3_CATCH_RETURN(false); + } + + bool Z3_API Z3_rcf_is_infinitesimal(Z3_context c, Z3_rcf_num a) { + Z3_TRY; + LOG_Z3_rcf_is_infinitesimal(c, a); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return rcfm(c).is_infinitesimal(to_rcnumeral(a)); + Z3_CATCH_RETURN(false); + } + + bool Z3_API Z3_rcf_is_transcendental(Z3_context c, Z3_rcf_num a) { + Z3_TRY; + LOG_Z3_rcf_is_transcendental(c, a); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return rcfm(c).is_transcendental(to_rcnumeral(a)); + Z3_CATCH_RETURN(false); + } + + unsigned Z3_API Z3_rcf_extension_index(Z3_context c, Z3_rcf_num a) { + Z3_TRY; + LOG_Z3_rcf_extension_index(c, a); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return rcfm(c).extension_index(to_rcnumeral(a)); + Z3_CATCH_RETURN(false); + } + + Z3_symbol Z3_API Z3_rcf_transcendental_name(Z3_context c, Z3_rcf_num a) { + Z3_TRY; + LOG_Z3_rcf_transcendental_name(c, a); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return of_symbol(rcfm(c).transcendental_name(to_rcnumeral(a))); + Z3_CATCH_RETURN(of_symbol(symbol::null)); + } + + Z3_symbol Z3_API Z3_rcf_infinitesimal_name(Z3_context c, Z3_rcf_num a) { + Z3_TRY; + LOG_Z3_rcf_infinitesimal_name(c, a); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return of_symbol(rcfm(c).infinitesimal_name(to_rcnumeral(a))); + Z3_CATCH_RETURN(of_symbol(symbol::null)); + } + + unsigned Z3_API Z3_rcf_num_coefficients(Z3_context c, Z3_rcf_num a) + { + Z3_TRY; + LOG_Z3_rcf_num_coefficients(c, a); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return rcfm(c).num_coefficients(to_rcnumeral(a)); + Z3_CATCH_RETURN(0); + } + + Z3_rcf_num Z3_API Z3_rcf_coefficient(Z3_context c, Z3_rcf_num a, unsigned i) + { + Z3_TRY; + LOG_Z3_rcf_coefficient(c, a, i); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return from_rcnumeral(rcfm(c).get_coefficient(to_rcnumeral(a), i)); + Z3_CATCH_RETURN(nullptr); + } + + int Z3_API Z3_rcf_interval(Z3_context c, Z3_rcf_num a, int * lower_is_inf, int * lower_is_open, Z3_rcf_num * lower, int * upper_is_inf, int * upper_is_open, Z3_rcf_num * upper) { + Z3_TRY; + LOG_Z3_rcf_interval(c, a, lower_is_inf, lower_is_open, lower, upper_is_inf, upper_is_open, upper); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + rcnumeral r_lower, r_upper; + bool r = rcfm(c).get_interval(to_rcnumeral(a), *lower_is_inf, *lower_is_open, r_lower, *upper_is_inf, *upper_is_open, r_upper); + *lower = from_rcnumeral(r_lower); + *upper = from_rcnumeral(r_upper); + return r; + Z3_CATCH_RETURN(0); + } + + unsigned Z3_API Z3_rcf_num_sign_conditions(Z3_context c, Z3_rcf_num a) + { + Z3_TRY; + LOG_Z3_rcf_num_sign_conditions(c, a); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return rcfm(c).num_sign_conditions(to_rcnumeral(a)); + Z3_CATCH_RETURN(0); + } + + int Z3_API Z3_rcf_sign_condition_sign(Z3_context c, Z3_rcf_num a, unsigned i) + { + Z3_TRY; + LOG_Z3_rcf_sign_condition_sign(c, a, i); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return rcfm(c).get_sign_condition_sign(to_rcnumeral(a), i); + Z3_CATCH_RETURN(0); + } + + unsigned Z3_API Z3_rcf_num_sign_condition_coefficients(Z3_context c, Z3_rcf_num a, unsigned i) + { + Z3_TRY; + LOG_Z3_rcf_num_sign_condition_coefficients(c, a, i); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return rcfm(c).num_sign_condition_coefficients(to_rcnumeral(a), i); + Z3_CATCH_RETURN(0); + } + + Z3_rcf_num Z3_API Z3_rcf_sign_condition_coefficient(Z3_context c, Z3_rcf_num a, unsigned i, unsigned j) + { + Z3_TRY; + LOG_Z3_rcf_sign_condition_coefficient(c, a, i, j); + RESET_ERROR_CODE(); + reset_rcf_cancel(c); + return from_rcnumeral(rcfm(c).get_sign_condition_coefficient(to_rcnumeral(a), i, j)); + Z3_CATCH_RETURN(nullptr); + } }; diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 2c19d0d9eeb..ac100ee3079 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -170,8 +170,8 @@ extern "C" { if (g_is_threaded || g_thread_id != std::this_thread::get_id()) { g_is_threaded = true; std::ostringstream strm; - strm << smt2log << "-" << std::this_thread::get_id(); - smt2log = symbol(strm.str()); + strm << smt2log << '-' << std::this_thread::get_id(); + smt2log = symbol(std::move(strm).str()); } to_solver(s)->m_pp = alloc(solver2smt2_pp, mk_c(c)->m(), smt2log.str()); } @@ -208,7 +208,7 @@ extern "C" { if (!smt_logics::supported_logic(to_symbol(logic))) { std::ostringstream strm; strm << "logic '" << to_symbol(logic) << "' is not recognized"; - SET_ERROR_CODE(Z3_INVALID_ARG, strm.str()); + SET_ERROR_CODE(Z3_INVALID_ARG, std::move(strm).str()); RETURN_Z3(nullptr); } else { @@ -306,7 +306,7 @@ extern "C" { if (!parse_smt2_commands(*ctx.get(), is)) { ctx = nullptr; - SET_ERROR_CODE(Z3_PARSER_ERROR, errstrm.str()); + SET_ERROR_CODE(Z3_PARSER_ERROR, std::move(errstrm).str()); return; } @@ -333,7 +333,7 @@ extern "C" { std::stringstream err; sat::solver solver(to_solver_ref(s)->get_params(), m.limit()); if (!parse_dimacs(is, err, solver)) { - SET_ERROR_CODE(Z3_PARSER_ERROR, err.str()); + SET_ERROR_CODE(Z3_PARSER_ERROR, std::move(err).str()); return; } sat2goal s2g; @@ -400,7 +400,7 @@ extern "C" { if (!initialized) to_solver(s)->m_solver = nullptr; descrs.display(buffer); - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::move(buffer).str()); Z3_CATCH_RETURN(""); } @@ -799,7 +799,7 @@ extern "C" { init_solver(c, s); std::ostringstream buffer; to_solver_ref(s)->display(buffer); - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::move(buffer).str()); Z3_CATCH_RETURN(""); } @@ -810,7 +810,7 @@ extern "C" { init_solver(c, s); std::ostringstream buffer; to_solver_ref(s)->display_dimacs(buffer, include_names); - return mk_c(c)->mk_external_string(buffer.str()); + return mk_c(c)->mk_external_string(std::move(buffer).str()); Z3_CATCH_RETURN(""); } @@ -984,14 +984,14 @@ extern "C" { Z3_TRY; RESET_ERROR_CODE(); init_solver(c, s); - user_propagator::on_clause_eh_t _on_clause = [=](void* user_ctx, expr* proof, unsigned n, expr* const* _literals) { + user_propagator::on_clause_eh_t _on_clause = [=](void* user_ctx, expr* proof, unsigned nd, unsigned const* deps, unsigned n, expr* const* _literals) { Z3_ast_vector_ref * literals = alloc(Z3_ast_vector_ref, *mk_c(c), mk_c(c)->m()); mk_c(c)->save_object(literals); expr_ref pr(proof, mk_c(c)->m()); scoped_ast_vector _sc(literals); for (unsigned i = 0; i < n; ++i) literals->m_ast_vector.push_back(_literals[i]); - on_clause_eh(user_ctx, of_expr(pr.get()), of_ast_vector(literals)); + on_clause_eh(user_ctx, of_expr(pr.get()), nd, deps, of_ast_vector(literals)); }; to_solver_ref(s)->register_on_clause(user_context, _on_clause); auto& solver = *to_solver(s); @@ -1092,15 +1092,15 @@ extern "C" { Z3_CATCH; } - void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, Z3_ast const* fixed_ids, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq) { + bool Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, Z3_ast const* fixed_ids, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq) { Z3_TRY; LOG_Z3_solver_propagate_consequence(c, s, num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, conseq); RESET_ERROR_CODE(); expr* const * _fixed_ids = (expr* const*) fixed_ids; expr* const * _eq_lhs = (expr*const*) eq_lhs; expr* const * _eq_rhs = (expr*const*) eq_rhs; - reinterpret_cast(s)->propagate_cb(num_fixed, _fixed_ids, num_eqs, _eq_lhs, _eq_rhs, to_expr(conseq)); - Z3_CATCH; + return reinterpret_cast(s)->propagate_cb(num_fixed, _fixed_ids, num_eqs, _eq_lhs, _eq_rhs, to_expr(conseq)); + Z3_CATCH_RETURN(false); } void Z3_API Z3_solver_propagate_created(Z3_context c, Z3_solver s, Z3_created_eh created_eh) { @@ -1114,17 +1114,17 @@ extern "C" { void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh) { Z3_TRY; RESET_ERROR_CODE(); - user_propagator::decide_eh_t c = (void(*)(void*, user_propagator::callback*, expr**, unsigned*, lbool*))decide_eh; + user_propagator::decide_eh_t c = (void(*)(void*, user_propagator::callback*, expr*, unsigned, bool))decide_eh; to_solver_ref(s)->user_propagate_register_decide(c); Z3_CATCH; } - void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) { + bool Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) { Z3_TRY; LOG_Z3_solver_next_split(c, cb, t, idx, phase); RESET_ERROR_CODE(); - reinterpret_cast(cb)->next_split_cb(to_expr(t), idx, (lbool)phase); - Z3_CATCH; + return reinterpret_cast(cb)->next_split_cb(to_expr(t), idx, (lbool)phase); + Z3_CATCH_RETURN(false); } Z3_func_decl Z3_API Z3_solver_propagate_declare(Z3_context c, Z3_symbol name, unsigned n, Z3_sort* domain, Z3_sort range) { diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 88b520147d3..cb1446a08c3 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -320,7 +320,7 @@ namespace z3 { /** \brief Create a recursive datatype over a single sort. \c name is the name of the recursive datatype - \c n - the numer of constructors of the datatype + \c n - the number of constructors of the datatype \c cs - the \c n constructors used to define the datatype References to the datatype can be created using \ref datatype_sort. @@ -368,7 +368,7 @@ namespace z3 { func_decl recfun(char const * name, sort const & d1, sort const & d2, sort const & range); /** - * \brief add function definition body to declaration decl. decl needs to be declared using context::. + * \brief add function definition body to declaration decl. decl needs to be declared using context::recfun. * @param decl * @param args * @param body @@ -4214,17 +4214,20 @@ namespace z3 { return expr(ctx(), r); } - typedef std::function on_clause_eh_t; + typedef std::function const& deps, expr_vector const& clause)> on_clause_eh_t; class on_clause { context& c; on_clause_eh_t m_on_clause; - static void _on_clause_eh(void* _ctx, Z3_ast _proof, Z3_ast_vector _literals) { + static void _on_clause_eh(void* _ctx, Z3_ast _proof, unsigned n, unsigned const* dep, Z3_ast_vector _literals) { on_clause* ctx = static_cast(_ctx); expr_vector lits(ctx->c, _literals); expr proof(ctx->c, _proof); - ctx->m_on_clause(proof, lits); + std::vector deps; + for (unsigned i = 0; i < n; ++i) + deps.push_back(dep[i]); + ctx->m_on_clause(proof, deps, lits); } public: on_clause(solver& s, on_clause_eh_t& on_clause_eh): c(s.ctx()) { @@ -4240,7 +4243,7 @@ namespace z3 { typedef std::function final_eh_t; typedef std::function eq_eh_t; typedef std::function created_eh_t; - typedef std::function decide_eh_t; + typedef std::function decide_eh_t; final_eh_t m_final_eh; eq_eh_t m_eq_eh; @@ -4309,13 +4312,11 @@ namespace z3 { p->m_created_eh(e); } - static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast* _val, unsigned* bit, Z3_lbool* is_pos) { + static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast _val, unsigned bit, bool is_pos) { user_propagator_base* p = static_cast(_p); scoped_cb _cb(p, cb); - expr val(p->ctx(), *_val); - p->m_decide_eh(val, *bit, *is_pos); - // TBD: life time of val is within the scope of this callback. - *_val = val; + expr val(p->ctx(), _val); + p->m_decide_eh(val, bit, is_pos); } public: @@ -4435,7 +4436,7 @@ namespace z3 { } void register_decide() { - m_decide_eh = [this](expr& val, unsigned& bit, Z3_lbool& is_pos) { + m_decide_eh = [this](expr val, unsigned bit, bool is_pos) { decide(val, bit, is_pos); }; if (s) { @@ -4451,11 +4452,11 @@ namespace z3 { virtual void created(expr const& /*e*/) {} - virtual void decide(expr& /*val*/, unsigned& /*bit*/, Z3_lbool& /*is_pos*/) {} + virtual void decide(expr const& /*val*/, unsigned /*bit*/, bool /*is_pos*/) {} - void next_split(expr const & e, unsigned idx, Z3_lbool phase) { + bool next_split(expr const& e, unsigned idx, Z3_lbool phase) { assert(cb); - Z3_solver_next_split(ctx(), cb, e, idx, phase); + return Z3_solver_next_split(ctx(), cb, e, idx, phase); } /** @@ -4498,14 +4499,14 @@ namespace z3 { Z3_solver_propagate_consequence(ctx(), cb, fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq); } - void propagate(expr_vector const& fixed, expr const& conseq) { + bool propagate(expr_vector const& fixed, expr const& conseq) { assert(cb); assert((Z3_context)conseq.ctx() == (Z3_context)ctx()); array _fixed(fixed); - Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), 0, nullptr, nullptr, conseq); + return Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), 0, nullptr, nullptr, conseq); } - void propagate(expr_vector const& fixed, + bool propagate(expr_vector const& fixed, expr_vector const& lhs, expr_vector const& rhs, expr const& conseq) { assert(cb); @@ -4515,7 +4516,7 @@ namespace z3 { array _lhs(lhs); array _rhs(rhs); - Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq); + return Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq); } }; diff --git a/src/api/dotnet/CMakeLists.txt b/src/api/dotnet/CMakeLists.txt index fcd7b0d8586..d3cb87bc713 100644 --- a/src/api/dotnet/CMakeLists.txt +++ b/src/api/dotnet/CMakeLists.txt @@ -9,7 +9,7 @@ set(VER_TWEAK "${Z3_VERSION_TWEAK}") # Generate Native.cs set(Z3_DOTNET_NATIVE_FILE "${CMAKE_CURRENT_BINARY_DIR}/Native.cs") add_custom_command(OUTPUT "${Z3_DOTNET_NATIVE_FILE}" - COMMAND "${PYTHON_EXECUTABLE}" + COMMAND "${Python3_EXECUTABLE}" "${PROJECT_SOURCE_DIR}/scripts/update_api.py" ${Z3_FULL_PATH_API_HEADER_FILES_TO_SCAN} "--dotnet-output-dir" @@ -25,7 +25,7 @@ add_custom_command(OUTPUT "${Z3_DOTNET_NATIVE_FILE}" # Generate Enumerations.cs set(Z3_DOTNET_CONST_FILE "${CMAKE_CURRENT_BINARY_DIR}/Enumerations.cs") add_custom_command(OUTPUT "${Z3_DOTNET_CONST_FILE}" - COMMAND "${PYTHON_EXECUTABLE}" + COMMAND "${Python3_EXECUTABLE}" "${PROJECT_SOURCE_DIR}/scripts/mk_consts_files.py" ${Z3_FULL_PATH_API_HEADER_FILES_TO_SCAN} "--dotnet-output-dir" diff --git a/src/api/dotnet/Constructor.cs b/src/api/dotnet/Constructor.cs index f635d78e45b..323301bf901 100644 --- a/src/api/dotnet/Constructor.cs +++ b/src/api/dotnet/Constructor.cs @@ -91,7 +91,13 @@ public FuncDecl[] AccessorDecls /// ~Constructor() { - Native.Z3_del_constructor(Context.nCtx, NativeObject); + if (Context.nCtx != IntPtr.Zero) { + lock (Context) + { + if (Context.nCtx != IntPtr.Zero) + Native.Z3_del_constructor(Context.nCtx, NativeObject); + } + } } #region Internal diff --git a/src/api/dotnet/Context.cs b/src/api/dotnet/Context.cs index 6365852a6d9..629e964843f 100644 --- a/src/api/dotnet/Context.cs +++ b/src/api/dotnet/Context.cs @@ -3770,7 +3770,7 @@ public Simplifier MkSimplifier(string name) } /// - /// Create a simplifie that applies and + /// Create a simplifier that applies and /// then . /// public Simplifier AndThen(Simplifier t1, Simplifier t2, params Simplifier[] ts) diff --git a/src/api/dotnet/Microsoft.Z3.csproj.in b/src/api/dotnet/Microsoft.Z3.csproj.in index 85ab98b3856..ec136809d4a 100644 --- a/src/api/dotnet/Microsoft.Z3.csproj.in +++ b/src/api/dotnet/Microsoft.Z3.csproj.in @@ -7,6 +7,8 @@ Microsoft.Z3 Microsoft.Z3 + README.md + Z3 .NET Interface Z3 .NET Interface @@ -15,8 +17,8 @@ Z3 is a satisfiability modulo theories solver from Microsoft Research. .NET Interface to the Z3 Theorem Prover - Copyright (C) 2006-2019 Microsoft Corporation - Copyright (C) 2006-2019 Microsoft Corporation + Copyright (C) 2006- Microsoft Corporation + Copyright (C) 2006- Microsoft Corporation Microsoft Corporation Microsoft Corporation @@ -65,6 +67,11 @@ ${Z3_DOTNET_COMPILE_ITEMS} + + + + + diff --git a/src/api/dotnet/OnClause.cs b/src/api/dotnet/OnClause.cs index 68631892891..05defb312ab 100644 --- a/src/api/dotnet/OnClause.cs +++ b/src/api/dotnet/OnClause.cs @@ -30,6 +30,7 @@ namespace Microsoft.Z3 using Z3_context = System.IntPtr; using Z3_solver = System.IntPtr; using voidp = System.IntPtr; + using uintp = System.IntPtr; using Z3_ast = System.IntPtr; using Z3_ast_vector = System.IntPtr; @@ -60,7 +61,7 @@ public class OnClause : IDisposable Native.Z3_on_clause_eh on_clause_eh; - static void _on_clause(voidp ctx, Z3_ast _proof_hint, Z3_ast_vector _clause) + static void _on_clause(voidp ctx, Z3_ast _proof_hint, uint n, uint[] deps, Z3_ast_vector _clause) { var onc = (OnClause)GCHandle.FromIntPtr(ctx).Target; using var proof_hint = Expr.Create(onc.ctx, _proof_hint); diff --git a/src/api/dotnet/Optimize.cs b/src/api/dotnet/Optimize.cs index 0694faabe83..891ed4105e1 100644 --- a/src/api/dotnet/Optimize.cs +++ b/src/api/dotnet/Optimize.cs @@ -220,7 +220,7 @@ public Handle AssertSoft(BoolExpr constraint, uint weight, string group) /// /// Check satisfiability of asserted constraints. /// Produce a model that (when the objectives are bounded and - /// don't use strict inequalities) meets the objectives. + /// don't use strict inequalities) is optimal. /// /// public Status Check(params Expr[] assumptions) diff --git a/src/api/dotnet/README.md b/src/api/dotnet/README.md new file mode 100644 index 00000000000..fe614782f52 --- /dev/null +++ b/src/api/dotnet/README.md @@ -0,0 +1,3 @@ +# Z3 Nuget Package + +For more information see [the Z3 github page](https://github.com/z3prover/z3.git) diff --git a/src/api/dotnet/Symbol.cs b/src/api/dotnet/Symbol.cs index c0e1e3e7312..f6756d5f47c 100644 --- a/src/api/dotnet/Symbol.cs +++ b/src/api/dotnet/Symbol.cs @@ -97,7 +97,7 @@ public override bool Equals(object o) } /// - /// The Symbols's hash code. + /// The Symbol's hash code. /// /// A hash code public override int GetHashCode() diff --git a/src/api/dotnet/UserPropagator.cs b/src/api/dotnet/UserPropagator.cs index b9cd4dc392d..b1a2d3df5cf 100644 --- a/src/api/dotnet/UserPropagator.cs +++ b/src/api/dotnet/UserPropagator.cs @@ -58,12 +58,12 @@ public class UserPropagator : IDisposable public delegate void CreatedEh(Expr term); /// - /// Delegate type for callback into solver's branching + /// Delegate type for callback into solver's branching. The values can be overridden by calling . + /// /// A bit-vector or Boolean used for branching /// If the term is a bit-vector, then an index into the bit-vector being branched on - /// Set phase to -1 (false) or 1 (true) to override solver's phase - /// - public delegate void DecideEh(ref Expr term, ref uint idx, ref int phase); + /// The tentative truth-value + public delegate void DecideEh(Expr term, uint idx, bool phase); // access managed objects through a static array. // thread safety is ignored for now. @@ -168,16 +168,11 @@ static void _created(voidp ctx, Z3_solver_callback cb, Z3_ast a) prop.Callback(() => prop.created_eh(t), cb); } - static void _decide(voidp ctx, Z3_solver_callback cb, ref Z3_ast a, ref uint idx, ref int phase) + static void _decide(voidp ctx, Z3_solver_callback cb, Z3_ast a, uint idx, bool phase) { var prop = (UserPropagator)GCHandle.FromIntPtr(ctx).Target; - var t = Expr.Create(prop.ctx, a); - var u = t; - prop.callback = cb; - prop.decide_eh(ref t, ref idx, ref phase); - prop.callback = IntPtr.Zero; - if (u != t) - a = t.NativeObject; + using var t = Expr.Create(prop.ctx, a); + prop.Callback(() => prop.decide_eh(t, idx, phase), cb); } /// @@ -257,11 +252,29 @@ public void Conflict(IEnumerable terms) /// /// Propagate consequence + /// + /// if the propagated expression is new for the solver; + /// if the propagation was ignored + /// + /// + public bool Propagate(IEnumerable terms, Expr conseq) + { + return Propagate(terms, new EqualityPairs(), conseq); + } + + /// + /// Propagate consequence + /// + /// if the propagated expression is new for the solver; + /// if the propagation was ignored + /// /// - public void Propagate(IEnumerable terms, Expr conseq) + public bool Propagate(IEnumerable terms, EqualityPairs equalities, Expr conseq) { var nTerms = Z3Object.ArrayToNative(terms.ToArray()); - Native.Z3_solver_propagate_consequence(ctx.nCtx, this.callback, (uint)nTerms.Length, nTerms, 0u, null, null, conseq.NativeObject); + var nLHS = Z3Object.ArrayToNative(equalities.LHS.ToArray()); + var nRHS = Z3Object.ArrayToNative(equalities.RHS.ToArray()); + return Native.Z3_solver_propagate_consequence(ctx.nCtx, this.callback, (uint)nTerms.Length, nTerms, (uint)equalities.Count, nLHS, nRHS, conseq.NativeObject) != 0; } @@ -352,10 +365,17 @@ public DecideEh Decide /// /// Set the next decision + /// A bit-vector or Boolean used for branching. Use to clear + /// If the term is a bit-vector, then an index into the bit-vector being branched on + /// The tentative truth-value (-1/false, 1/true, 0/let Z3 decide) /// - public void NextSplit(Expr e, uint idx, int phase) + /// + /// in case the value was successfully set; + /// if the next split could not be set + /// + public bool NextSplit(Expr e, uint idx, int phase) { - Native.Z3_solver_next_split(ctx.nCtx, this.callback, e.NativeObject, idx, phase); + return Native.Z3_solver_next_split(ctx.nCtx, this.callback, e?.NativeObject ?? IntPtr.Zero, idx, phase) != 0; } /// @@ -373,4 +393,72 @@ public void Register(Expr term) } } } + + /// + /// A list of equalities used as justifications for propagation + /// + public class EqualityPairs { + + readonly List lhsList = new List(); + readonly List rhsList = new List(); + + /// + /// The left hand sides of the equalities + /// + public Expr[] LHS => lhsList.ToArray(); + + /// + /// The right hand sides of the equalities + /// + public Expr[] RHS => rhsList.ToArray(); + + /// + /// The number of equalities + /// + public int Count => lhsList.Count; + + /// + /// Adds an equality to the list. The sorts of the arguments have to be the same. + /// The left hand side of the equality + /// The right hand side of the equality + /// + public void Add(Expr lhs, Expr rhs) { + lhsList.Add(lhs); + rhsList.Add(rhs); + } + + /// + /// Checks if two equality lists are equal. + /// The function does not take symmetries, shuffling, or duplicates into account. + /// + public override bool Equals(object obj) { + if (ReferenceEquals(this, obj)) + return true; + if (!(obj is EqualityPairs other)) + return false; + if (lhsList.Count != other.lhsList.Count) + return false; + for (int i = 0; i < lhsList.Count; i++) { + if (!lhsList[i].Equals(other.lhsList[i])) + return false; + } + return true; + } + + /// + /// Gets a hash code for the list of equalities + /// + public override int GetHashCode() { + int hash = lhsList.Count; + unchecked { + for (int i = 0; i < lhsList.Count; i++) { + hash ^= lhsList[i].GetHashCode(); + hash *= 17; + hash ^= rhsList[i].GetHashCode(); + hash *= 29; + } + return hash; + } + } + } } diff --git a/src/api/java/AST.java b/src/api/java/AST.java index c28c0cfcba7..99cdde948b1 100644 --- a/src/api/java/AST.java +++ b/src/api/java/AST.java @@ -19,6 +19,8 @@ import com.microsoft.z3.enumerations.Z3_ast_kind; +import java.lang.ref.ReferenceQueue; + /** * The abstract syntax tree (AST) class. **/ @@ -196,7 +198,7 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getASTDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, ASTRef::new); } static AST create(Context ctx, long obj) @@ -217,4 +219,16 @@ static AST create(Context ctx, long obj) throw new Z3Exception("Unknown AST kind"); } } + + private static class ASTRef extends Z3ReferenceQueue.Reference { + + private ASTRef(AST referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.decRef(ctx.nCtx(), z3Obj); + } + } } diff --git a/src/api/java/ASTDecRefQueue.java b/src/api/java/ASTDecRefQueue.java deleted file mode 100644 index b0a6fa217e1..00000000000 --- a/src/api/java/ASTDecRefQueue.java +++ /dev/null @@ -1,31 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - ASTDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class ASTDecRefQueue extends IDecRefQueue -{ - public ASTDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.decRef(ctx.nCtx(), obj); - } -}; diff --git a/src/api/java/ASTMap.java b/src/api/java/ASTMap.java index 916811cec28..23a16a82881 100644 --- a/src/api/java/ASTMap.java +++ b/src/api/java/ASTMap.java @@ -17,6 +17,8 @@ package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + /** * Map from AST to AST **/ @@ -123,6 +125,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getASTMapDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, ASTMapRef::new); + } + + private static class ASTMapRef extends Z3ReferenceQueue.Reference { + + private ASTMapRef(ASTMap referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.astMapDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/ASTVector.java b/src/api/java/ASTVector.java index a6b436a99f4..8a5603fcb08 100644 --- a/src/api/java/ASTVector.java +++ b/src/api/java/ASTVector.java @@ -17,6 +17,8 @@ package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + /** * Vectors of ASTs. **/ @@ -101,16 +103,6 @@ public ASTVector(Context ctx) super(ctx, Native.mkAstVector(ctx.nCtx())); } - @Override - void incRef() { - Native.astVectorIncRef(getContext().nCtx(), getNativeObject()); - } - - @Override - void addToReferenceQueue() { - getContext().getASTVectorDRQ().storeReference(getContext(), this); - } - /** * Translates the AST vector into an AST[] * */ @@ -241,4 +233,26 @@ public RealExpr[] ToRealExprArray() res[i] = (RealExpr)Expr.create(getContext(), get(i).getNativeObject()); return res; } + + @Override + void incRef() { + Native.astVectorIncRef(getContext().nCtx(), getNativeObject()); + } + + @Override + void addToReferenceQueue() { + getContext().getReferenceQueue().storeReference(this, ASTVectorRef::new); + } + + private static class ASTVectorRef extends Z3ReferenceQueue.Reference { + + private ASTVectorRef(ASTVector referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.astVectorDecRef(ctx.nCtx(), z3Obj); + } + } } \ No newline at end of file diff --git a/src/api/java/ApplyResult.java b/src/api/java/ApplyResult.java index 6cfedd40480..b0e035c4fc7 100644 --- a/src/api/java/ApplyResult.java +++ b/src/api/java/ApplyResult.java @@ -17,6 +17,8 @@ package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + /** * ApplyResult objects represent the result of an application of a tactic to a * goal. It contains the subgoals that were produced. @@ -66,6 +68,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getApplyResultDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, ApplyResultRef::new); + } + + private static class ApplyResultRef extends Z3ReferenceQueue.Reference { + + private ApplyResultRef(ApplyResult referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.applyResultDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/ApplyResultDecRefQueue.java b/src/api/java/ApplyResultDecRefQueue.java deleted file mode 100644 index e1a6607818b..00000000000 --- a/src/api/java/ApplyResultDecRefQueue.java +++ /dev/null @@ -1,31 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - ApplyResultDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class ApplyResultDecRefQueue extends IDecRefQueue -{ - public ApplyResultDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.applyResultDecRef(ctx.nCtx(), obj); - } -}; diff --git a/src/api/java/AstMapDecRefQueue.java b/src/api/java/AstMapDecRefQueue.java deleted file mode 100644 index 6c96970b7e7..00000000000 --- a/src/api/java/AstMapDecRefQueue.java +++ /dev/null @@ -1,30 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - AstMapDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class ASTMapDecRefQueue extends IDecRefQueue { - public ASTMapDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.astMapDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/AstVectorDecRefQueue.java b/src/api/java/AstVectorDecRefQueue.java deleted file mode 100644 index e7ce3e33e65..00000000000 --- a/src/api/java/AstVectorDecRefQueue.java +++ /dev/null @@ -1,30 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - AstVectorDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class ASTVectorDecRefQueue extends IDecRefQueue { - public ASTVectorDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.astVectorDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/CMakeLists.txt b/src/api/java/CMakeLists.txt index 4b13a25b1a4..c5221014f28 100644 --- a/src/api/java/CMakeLists.txt +++ b/src/api/java/CMakeLists.txt @@ -16,7 +16,7 @@ set(Z3_JAVA_PACKAGE_NAME "com.microsoft.z3") set(Z3_JAVA_NATIVE_JAVA "${CMAKE_CURRENT_BINARY_DIR}/Native.java") set(Z3_JAVA_NATIVE_CPP "${CMAKE_CURRENT_BINARY_DIR}/Native.cpp") add_custom_command(OUTPUT "${Z3_JAVA_NATIVE_JAVA}" "${Z3_JAVA_NATIVE_CPP}" - COMMAND "${PYTHON_EXECUTABLE}" + COMMAND "${Python3_EXECUTABLE}" "${PROJECT_SOURCE_DIR}/scripts/update_api.py" ${Z3_FULL_PATH_API_HEADER_FILES_TO_SCAN} "--java-input-dir" @@ -74,7 +74,7 @@ foreach (enum_file ${Z3_JAVA_ENUMERATION_PACKAGE_FILES}) ) endforeach() add_custom_command(OUTPUT ${Z3_JAVA_ENUMERATION_PACKAGE_FILES_FULL_PATH} - COMMAND "${PYTHON_EXECUTABLE}" + COMMAND "${Python3_EXECUTABLE}" "${PROJECT_SOURCE_DIR}/scripts/mk_consts_files.py" ${Z3_FULL_PATH_API_HEADER_FILES_TO_SCAN} "--java-output-dir" @@ -91,17 +91,13 @@ add_custom_command(OUTPUT ${Z3_JAVA_ENUMERATION_PACKAGE_FILES_FULL_PATH} set(Z3_JAVA_JAR_SOURCE_FILES AlgebraicNum.java - ApplyResultDecRefQueue.java ApplyResult.java ArithExpr.java ArithSort.java ArrayExpr.java ArraySort.java - ASTDecRefQueue.java AST.java - AstMapDecRefQueue.java ASTMap.java - AstVectorDecRefQueue.java ASTVector.java BitVecExpr.java BitVecNum.java @@ -109,9 +105,7 @@ set(Z3_JAVA_JAR_SOURCE_FILES BoolExpr.java BoolSort.java CharSort.java - ConstructorDecRefQueue.java Constructor.java - ConstructorListDecRefQueue.java ConstructorList.java Context.java DatatypeExpr.java @@ -121,7 +115,6 @@ set(Z3_JAVA_JAR_SOURCE_FILES FiniteDomainExpr.java FiniteDomainNum.java FiniteDomainSort.java - FixedpointDecRefQueue.java Fixedpoint.java FPExpr.java FPNum.java @@ -130,13 +123,9 @@ set(Z3_JAVA_JAR_SOURCE_FILES FPRMSort.java FPSort.java FuncDecl.java - FuncInterpDecRefQueue.java - FuncInterpEntryDecRefQueue.java FuncInterp.java Global.java - GoalDecRefQueue.java Goal.java - IDecRefQueue.java IntExpr.java IntNum.java IntSort.java @@ -144,16 +133,11 @@ set(Z3_JAVA_JAR_SOURCE_FILES Lambda.java ListSort.java Log.java - ModelDecRefQueue.java Model.java - OptimizeDecRefQueue.java Optimize.java - ParamDescrsDecRefQueue.java ParamDescrs.java - ParamsDecRefQueue.java Params.java Pattern.java - ProbeDecRefQueue.java Probe.java Quantifier.java RatNum.java @@ -166,22 +150,20 @@ set(Z3_JAVA_JAR_SOURCE_FILES SeqSort.java SetSort.java Simplifier.java - SimplifierDecRefQueue.java - SolverDecRefQueue.java Solver.java Sort.java - StatisticsDecRefQueue.java Statistics.java Status.java StringSymbol.java Symbol.java - TacticDecRefQueue.java Tactic.java TupleSort.java UninterpretedSort.java + UserPropagatorBase.java Version.java Z3Exception.java Z3Object.java + Z3ReferenceQueue.java ) set(Z3_JAVA_JAR_SOURCE_FILES_FULL_PATH "") foreach (java_src_file ${Z3_JAVA_JAR_SOURCE_FILES}) @@ -203,11 +185,13 @@ add_custom_target(build_z3_java_bindings # Rule to build ``com.microsoft.z3.jar`` # TODO: Should we set ``CMAKE_JNI_TARGET`` to ``TRUE``? +# REMARK: removed VERSION to fix issue with using this to create installations. + add_jar(z3JavaJar SOURCES ${Z3_JAVA_JAR_SOURCE_FILES_FULL_PATH} OUTPUT_NAME ${Z3_JAVA_PACKAGE_NAME} OUTPUT_DIR "${PROJECT_BINARY_DIR}" - VERSION "${Z3_VERSION}" +# VERSION "${Z3_VERSION}" ) ############################################################################### @@ -218,21 +202,22 @@ if (Z3_INSTALL_JAVA_BINDINGS) # Provide cache variables for the install locations that the user can change. # This defaults to ``/usr/local/java`` which seems to be the location for ``.jar`` # files on Linux distributions - set(Z3_JAVA_JAR_INSTALLDIR - "${CMAKE_INSTALL_DATAROOTDIR}/java" - CACHE - PATH - "Directory to install Z3 Java jar file relative to install prefix" - ) - # FIXME: I don't think this the right installation location - set(Z3_JAVA_JNI_LIB_INSTALLDIR - "${CMAKE_INSTALL_LIBDIR}" - CACHE - PATH - "Directory to install Z3 Java JNI bridge library relative to install prefix" - ) + if (NOT Z3_JAVA_JAR_INSTALLDIR) + set(Z3_JAVA_JAR_INSTALLDIR + "${CMAKE_INSTALL_DATAROOTDIR}/java" + CACHE + PATH + "Directory to install Z3 Java jar file relative to install prefix" + ) + endif() + if (NOT Z3_JAVA_JNI_LIB_INSTALLDIR) + set(Z3_JAVA_JNI_LIB_INSTALLDIR + "${CMAKE_INSTALL_LIBDIR}" + CACHE + PATH + "Directory to install Z3 Java JNI bridge library relative to install prefix" + ) + endif() install(TARGETS z3java DESTINATION "${Z3_JAVA_JNI_LIB_INSTALLDIR}") - # Note: Don't use ``DESTINATION`` here as the version of ``UseJava.cmake`` shipped - # with CMake 2.8.12.2 handles that incorrectly. install_jar(z3JavaJar "${Z3_JAVA_JAR_INSTALLDIR}") endif() diff --git a/src/api/java/Constructor.java b/src/api/java/Constructor.java index 59565f56509..8e1766cb02a 100644 --- a/src/api/java/Constructor.java +++ b/src/api/java/Constructor.java @@ -17,6 +17,8 @@ package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + /** * Constructors are used for datatype sorts. **/ @@ -91,7 +93,7 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getConstructorDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, ConstructorRef::new); } static Constructor of(Context ctx, Symbol name, Symbol recognizer, @@ -114,4 +116,16 @@ static Constructor of(Context ctx, Symbol name, Symbol recognizer, return new Constructor<>(ctx, n, nativeObj); } + + private static class ConstructorRef extends Z3ReferenceQueue.Reference> { + + private ConstructorRef(Constructor referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.delConstructor(ctx.nCtx(), z3Obj); + } + } } diff --git a/src/api/java/ConstructorDecRefQueue.java b/src/api/java/ConstructorDecRefQueue.java deleted file mode 100644 index a76b26bb73c..00000000000 --- a/src/api/java/ConstructorDecRefQueue.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.microsoft.z3; - -public class ConstructorDecRefQueue extends IDecRefQueue> { - public ConstructorDecRefQueue() { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.delConstructor(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/ConstructorList.java b/src/api/java/ConstructorList.java index d015c51c0ad..577c802f050 100644 --- a/src/api/java/ConstructorList.java +++ b/src/api/java/ConstructorList.java @@ -17,6 +17,8 @@ package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + /** * Lists of constructors **/ @@ -34,7 +36,7 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getConstructorListDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, ConstructorListRef::new); } ConstructorList(Context ctx, Constructor[] constructors) @@ -43,4 +45,16 @@ void addToReferenceQueue() { constructors.length, Constructor.arrayToNative(constructors))); } + + private static class ConstructorListRef extends Z3ReferenceQueue.Reference> { + + private ConstructorListRef(ConstructorList referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.delConstructorList(ctx.nCtx(), z3Obj); + } + } } diff --git a/src/api/java/ConstructorListDecRefQueue.java b/src/api/java/ConstructorListDecRefQueue.java deleted file mode 100644 index 2f5dfcb3526..00000000000 --- a/src/api/java/ConstructorListDecRefQueue.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.microsoft.z3; - -public class ConstructorListDecRefQueue extends IDecRefQueue> { - public ConstructorListDecRefQueue() { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.delConstructorList(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/Context.java b/src/api/java/Context.java index 7aaef4801ec..f3efa632a15 100644 --- a/src/api/java/Context.java +++ b/src/api/java/Context.java @@ -452,6 +452,21 @@ public final FuncDecl mkFuncDecl(Symbol name, Sort[] domain, return new FuncDecl<>(this, name, domain, range); } + public final FuncDecl mkPropagateFunction(Symbol name, Sort[] domain, R range) + { + checkContextMatch(name); + checkContextMatch(domain); + checkContextMatch(range); + long f = Native.solverPropagateDeclare( + this.nCtx(), + name.getNativeObject(), + AST.arrayLength(domain), + AST.arrayToNative(domain), + range.getNativeObject()); + return new FuncDecl<>(this, f); + } + + /** * Creates a new function declaration. **/ @@ -2018,11 +2033,11 @@ public SeqExpr mkString(String s) { StringBuilder buf = new StringBuilder(); for (int i = 0; i < s.length(); ++i) { - int code = s.codePointAt(i); - if (code <= 32 || 127 < code) - buf.append(String.format("\\u{%x}", code)); - else - buf.append(s.charAt(i)); + int code = s.codePointAt(i); + if (code <= 32 || 127 < code) + buf.append(String.format("\\u{%x}", code)); + else + buf.append(s.charAt(i)); } return (SeqExpr) Expr.create(this, Native.mkString(nCtx(), buf.toString())); } @@ -2288,13 +2303,13 @@ public final ReExpr mkIntersect(Expr>... t) public final ReExpr mkDiff(Expr> a, Expr> b) { checkContextMatch(a, b); - return (ReExpr) Expr.create(this, Native.mkReDiff(nCtx(), a.getNativeObject(), b.getNativeObject())); + return (ReExpr) Expr.create(this, Native.mkReDiff(nCtx(), a.getNativeObject(), b.getNativeObject())); } /** * Create the empty regular expression. - * Coresponds to re.none + * Corresponds to re.none */ public final ReExpr mkEmptyRe(ReSort s) { @@ -4304,119 +4319,9 @@ void checkContextMatch(Z3Object[] arr) checkContextMatch(a); } - private ASTDecRefQueue m_AST_DRQ = new ASTDecRefQueue(); - private ASTMapDecRefQueue m_ASTMap_DRQ = new ASTMapDecRefQueue(); - private ASTVectorDecRefQueue m_ASTVector_DRQ = new ASTVectorDecRefQueue(); - private ApplyResultDecRefQueue m_ApplyResult_DRQ = new ApplyResultDecRefQueue(); - private FuncInterpEntryDecRefQueue m_FuncEntry_DRQ = new FuncInterpEntryDecRefQueue(); - private FuncInterpDecRefQueue m_FuncInterp_DRQ = new FuncInterpDecRefQueue(); - private GoalDecRefQueue m_Goal_DRQ = new GoalDecRefQueue(); - private ModelDecRefQueue m_Model_DRQ = new ModelDecRefQueue(); - private ParamsDecRefQueue m_Params_DRQ = new ParamsDecRefQueue(); - private ParamDescrsDecRefQueue m_ParamDescrs_DRQ = new ParamDescrsDecRefQueue(); - private ProbeDecRefQueue m_Probe_DRQ = new ProbeDecRefQueue(); - private SolverDecRefQueue m_Solver_DRQ = new SolverDecRefQueue(); - private StatisticsDecRefQueue m_Statistics_DRQ = new StatisticsDecRefQueue(); - private TacticDecRefQueue m_Tactic_DRQ = new TacticDecRefQueue(); - private SimplifierDecRefQueue m_Simplifier_DRQ = new SimplifierDecRefQueue(); - private FixedpointDecRefQueue m_Fixedpoint_DRQ = new FixedpointDecRefQueue(); - private OptimizeDecRefQueue m_Optimize_DRQ = new OptimizeDecRefQueue(); - private ConstructorDecRefQueue m_Constructor_DRQ = new ConstructorDecRefQueue(); - private ConstructorListDecRefQueue m_ConstructorList_DRQ = - new ConstructorListDecRefQueue(); - - public IDecRefQueue> getConstructorDRQ() { - return m_Constructor_DRQ; - } - - public IDecRefQueue> getConstructorListDRQ() { - return m_ConstructorList_DRQ; - } - - public IDecRefQueue getASTDRQ() - { - return m_AST_DRQ; - } - - public IDecRefQueue getASTMapDRQ() - { - return m_ASTMap_DRQ; - } - - public IDecRefQueue getASTVectorDRQ() - { - return m_ASTVector_DRQ; - } - - public IDecRefQueue getApplyResultDRQ() - { - return m_ApplyResult_DRQ; - } - - public IDecRefQueue> getFuncEntryDRQ() - { - return m_FuncEntry_DRQ; - } - - public IDecRefQueue> getFuncInterpDRQ() - { - return m_FuncInterp_DRQ; - } + private Z3ReferenceQueue m_RefQueue = new Z3ReferenceQueue(this); - public IDecRefQueue getGoalDRQ() - { - return m_Goal_DRQ; - } - - public IDecRefQueue getModelDRQ() - { - return m_Model_DRQ; - } - - public IDecRefQueue getParamsDRQ() - { - return m_Params_DRQ; - } - - public IDecRefQueue getParamDescrsDRQ() - { - return m_ParamDescrs_DRQ; - } - - public IDecRefQueue getProbeDRQ() - { - return m_Probe_DRQ; - } - - public IDecRefQueue getSolverDRQ() - { - return m_Solver_DRQ; - } - - public IDecRefQueue getStatisticsDRQ() - { - return m_Statistics_DRQ; - } - - public IDecRefQueue getTacticDRQ() - { - return m_Tactic_DRQ; - } - - public IDecRefQueue getSimplifierDRQ() - { - return m_Simplifier_DRQ; - } - - public IDecRefQueue getFixedpointDRQ() - { - return m_Fixedpoint_DRQ; - } - - public IDecRefQueue getOptimizeDRQ() - { - return m_Optimize_DRQ; - } + Z3ReferenceQueue getReferenceQueue() { return m_RefQueue; } /** * Disposes of the context. @@ -4424,27 +4329,16 @@ public IDecRefQueue getOptimizeDRQ() @Override public void close() { - m_AST_DRQ.forceClear(this); - m_ASTMap_DRQ.forceClear(this); - m_ASTVector_DRQ.forceClear(this); - m_ApplyResult_DRQ.forceClear(this); - m_FuncEntry_DRQ.forceClear(this); - m_FuncInterp_DRQ.forceClear(this); - m_Goal_DRQ.forceClear(this); - m_Model_DRQ.forceClear(this); - m_Params_DRQ.forceClear(this); - m_Probe_DRQ.forceClear(this); - m_Solver_DRQ.forceClear(this); - m_Optimize_DRQ.forceClear(this); - m_Statistics_DRQ.forceClear(this); - m_Tactic_DRQ.forceClear(this); - m_Simplifier_DRQ.forceClear(this); - m_Fixedpoint_DRQ.forceClear(this); + if (m_ctx == 0) + return; + + m_RefQueue.forceClear(); m_boolSort = null; m_intSort = null; m_realSort = null; m_stringSort = null; + m_RefQueue = null; synchronized (creation_lock) { Native.delContext(m_ctx); diff --git a/src/api/java/Fixedpoint.java b/src/api/java/Fixedpoint.java index 96e1dd0cb21..c35f15e9667 100644 --- a/src/api/java/Fixedpoint.java +++ b/src/api/java/Fixedpoint.java @@ -19,6 +19,8 @@ import com.microsoft.z3.enumerations.Z3_lbool; +import java.lang.ref.ReferenceQueue; + /** * Object for managing fixedpoints **/ @@ -327,9 +329,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getFixedpointDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, FixedpointRef::new); } - @Override - void checkNativeObject(long obj) { } + private static class FixedpointRef extends Z3ReferenceQueue.Reference { + + private FixedpointRef(Fixedpoint referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.fixedpointDecRef(ctx.nCtx(), z3Obj); + } + } } diff --git a/src/api/java/FixedpointDecRefQueue.java b/src/api/java/FixedpointDecRefQueue.java deleted file mode 100644 index 69ed82092f5..00000000000 --- a/src/api/java/FixedpointDecRefQueue.java +++ /dev/null @@ -1,31 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - FixedpointDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class FixedpointDecRefQueue extends IDecRefQueue { - public FixedpointDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) - { - Native.fixedpointDecRef(ctx.nCtx(), obj); - } -}; diff --git a/src/api/java/FuncInterp.java b/src/api/java/FuncInterp.java index 64f96534b05..3d2affc4a5e 100644 --- a/src/api/java/FuncInterp.java +++ b/src/api/java/FuncInterp.java @@ -17,6 +17,8 @@ package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + /** * A function interpretation is represented as a finite map and an 'else' value. * Each entry in the finite map represents the value of a function given a set @@ -93,7 +95,19 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getFuncEntryDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, FuncEntryRef::new); + } + + private static class FuncEntryRef extends Z3ReferenceQueue.Reference> { + + private FuncEntryRef(Entry referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.funcEntryDecRef(ctx.nCtx(), z3Obj); + } } } @@ -186,6 +200,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getFuncInterpDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, FuncInterpRef::new); + } + + private static class FuncInterpRef extends Z3ReferenceQueue.Reference> { + + private FuncInterpRef(FuncInterp referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.funcInterpDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/FuncInterpDecRefQueue.java b/src/api/java/FuncInterpDecRefQueue.java deleted file mode 100644 index 06a6f2af851..00000000000 --- a/src/api/java/FuncInterpDecRefQueue.java +++ /dev/null @@ -1,31 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - FuncInterpDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class FuncInterpDecRefQueue extends IDecRefQueue> -{ - public FuncInterpDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.funcInterpDecRef(ctx.nCtx(), obj); - } -}; diff --git a/src/api/java/FuncInterpEntryDecRefQueue.java b/src/api/java/FuncInterpEntryDecRefQueue.java deleted file mode 100644 index 77bb78f5b6c..00000000000 --- a/src/api/java/FuncInterpEntryDecRefQueue.java +++ /dev/null @@ -1,30 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - FuncInterpEntryDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class FuncInterpEntryDecRefQueue extends IDecRefQueue> { - public FuncInterpEntryDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.funcEntryDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/Goal.java b/src/api/java/Goal.java index 3326f81fe02..e79ca7b7454 100644 --- a/src/api/java/Goal.java +++ b/src/api/java/Goal.java @@ -19,6 +19,8 @@ import com.microsoft.z3.enumerations.Z3_goal_prec; +import java.lang.ref.ReferenceQueue; + /** * A goal (aka problem). A goal is essentially a set of formulas, that can be * solved and/or transformed using tactics and solvers. @@ -262,6 +264,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getGoalDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, GoalRef::new); + } + + private static class GoalRef extends Z3ReferenceQueue.Reference { + + private GoalRef(Goal referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.goalDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/GoalDecRefQueue.java b/src/api/java/GoalDecRefQueue.java deleted file mode 100644 index 90bad1fb1c1..00000000000 --- a/src/api/java/GoalDecRefQueue.java +++ /dev/null @@ -1,30 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - GoalDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class GoalDecRefQueue extends IDecRefQueue { - public GoalDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.goalDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/IDecRefQueue.java b/src/api/java/IDecRefQueue.java deleted file mode 100644 index 4b515a3b6ca..00000000000 --- a/src/api/java/IDecRefQueue.java +++ /dev/null @@ -1,83 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - IDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -import java.lang.ref.PhantomReference; -import java.lang.ref.Reference; -import java.lang.ref.ReferenceQueue; -import java.util.IdentityHashMap; -import java.util.Map; - -/** - * A queue to handle management of native memory. - * - *

Mechanics: once an object is created, a metadata is stored for it in - * {@code referenceMap}, and a {@link PhantomReference} is created with a - * reference to {@code referenceQueue}. - * Once the object becomes strongly unreachable, the phantom reference gets - * added by JVM to the {@code referenceQueue}. - * After each object creation, we iterate through the available objects in - * {@code referenceQueue} and decrement references for them. - * - * @param Type of object stored in queue. - */ -public abstract class IDecRefQueue { - private final ReferenceQueue referenceQueue = new ReferenceQueue<>(); - private final Map, Long> referenceMap = - new IdentityHashMap<>(); - - protected IDecRefQueue() {} - - /** - * An implementation of this method should decrement the reference on a - * given native object. - * This function should always be called on the {@code ctx} thread. - * - * @param ctx Z3 context. - * @param obj Pointer to a Z3 object. - */ - protected abstract void decRef(Context ctx, long obj); - - public void storeReference(Context ctx, T obj) { - PhantomReference ref = new PhantomReference<>(obj, referenceQueue); - referenceMap.put(ref, obj.getNativeObject()); - clear(ctx); - } - - /** - * Clean all references currently in {@code referenceQueue}. - */ - protected void clear(Context ctx) - { - Reference ref; - while ((ref = referenceQueue.poll()) != null) { - long z3ast = referenceMap.remove(ref); - decRef(ctx, z3ast); - } - } - - /** - * Clean all references stored in {@code referenceMap}, - * regardless of whether they are in {@code referenceMap} or not. - */ - public void forceClear(Context ctx) { - for (long ref : referenceMap.values()) { - decRef(ctx, ref); - } - } -} diff --git a/src/api/java/Model.java b/src/api/java/Model.java index ffc4dd47f9d..cf9ab4b64e7 100644 --- a/src/api/java/Model.java +++ b/src/api/java/Model.java @@ -19,6 +19,8 @@ import com.microsoft.z3.enumerations.Z3_sort_kind; +import java.lang.ref.ReferenceQueue; + /** * A Model contains interpretations (assignments) of constants and functions. **/ @@ -296,6 +298,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getModelDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, ModelRef::new); + } + + private static class ModelRef extends Z3ReferenceQueue.Reference { + + private ModelRef(Model referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.modelDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/ModelDecRefQueue.java b/src/api/java/ModelDecRefQueue.java deleted file mode 100644 index f1b7c3fdd6e..00000000000 --- a/src/api/java/ModelDecRefQueue.java +++ /dev/null @@ -1,30 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - ModelDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class ModelDecRefQueue extends IDecRefQueue { - public ModelDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.modelDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/NativeStatic.txt b/src/api/java/NativeStatic.txt index 4693272d544..21d6ba07541 100644 --- a/src/api/java/NativeStatic.txt +++ b/src/api/java/NativeStatic.txt @@ -77,3 +77,159 @@ DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_setInternalErrorHand Z3_set_error_handler((Z3_context)a0, Z3JavaErrorHandler); } + +#include + +struct JavaInfo { + JNIEnv *jenv = nullptr; + jobject jobj = nullptr; + + jmethodID push = nullptr; + jmethodID pop = nullptr; + jmethodID fresh = nullptr; + jmethodID created = nullptr; + jmethodID fixed = nullptr; + jmethodID eq = nullptr; + jmethodID final = nullptr; + jmethodID decide = nullptr; + + Z3_solver_callback cb = nullptr; +}; + +struct ScopedCB { + JavaInfo *info; + ScopedCB(JavaInfo *_info, Z3_solver_callback cb): info(_info) { + info->cb = cb; + } + ~ScopedCB() { + info->cb = nullptr; + } +}; + +static void push_eh(void* _p, Z3_solver_callback cb) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->push); +} + +static void pop_eh(void* _p, Z3_solver_callback cb, unsigned int number) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->pop, number); +} + +static void* fresh_eh(void* _p, Z3_context new_context) { + JavaInfo *info = static_cast(_p); + return info->jenv->CallObjectMethod(info->jobj, info->fresh, (jlong)new_context); +} + +static void created_eh(void* _p, Z3_solver_callback cb, Z3_ast _e) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->created, (jlong)_e); +} + +static void fixed_eh(void* _p, Z3_solver_callback cb, Z3_ast _var, Z3_ast _value) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->fixed, (jlong)_var, (jlong)_value); +} + +static void eq_eh(void* _p, Z3_solver_callback cb, Z3_ast _x, Z3_ast _y) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->eq, (jlong)_x, (jlong)_y); +} + +static void final_eh(void* _p, Z3_solver_callback cb) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->final); +} + +static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast _val, unsigned bit, bool is_pos) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->decide, (jlong)_val, bit, is_pos); +} + +DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + JavaInfo *info = new JavaInfo; + + info->jenv = jenv; + info->jobj = jenv->NewGlobalRef(jobj); + jclass jcls = jenv->GetObjectClass(info->jobj); + info->push = jenv->GetMethodID(jcls, "pushWrapper", "()V"); + info->pop = jenv->GetMethodID(jcls, "popWrapper", "(I)V"); + info->fresh = jenv->GetMethodID(jcls, "freshWrapper", "(J)Lcom/microsoft/z3/Native$UserPropagatorBase;"); + info->created = jenv->GetMethodID(jcls, "createdWrapper", "(J)V"); + info->fixed = jenv->GetMethodID(jcls, "fixedWrapper", "(JJ)V"); + info->eq = jenv->GetMethodID(jcls, "eqWrapper", "(JJ)V"); + info->final = jenv->GetMethodID(jcls, "finWrapper", "()V"); + info->decide = jenv->GetMethodID(jcls, "decideWrapper", "(JIZ)V"); + + if (!info->push || !info->pop || !info->fresh || !info->created || !info->fixed || !info->eq || !info->final || !info->decide) { + assert(false); + } + + Z3_solver_propagate_init((Z3_context)ctx, (Z3_solver)solver, info, push_eh, pop_eh, fresh_eh); + + return (jlong)info; +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateDestroy(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo) { + JavaInfo *info = (JavaInfo*)javainfo; + info->jenv->DeleteGlobalRef(info->jobj); + delete info; +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterCreated(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + Z3_solver_propagate_created((Z3_context)ctx, (Z3_solver)solver, created_eh); +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterFinal(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + Z3_solver_propagate_final((Z3_context)ctx, (Z3_solver)solver, final_eh); +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterFixed(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + Z3_solver_propagate_fixed((Z3_context)ctx, (Z3_solver)solver, fixed_eh); +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterEq(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + Z3_solver_propagate_eq((Z3_context)ctx, (Z3_solver)solver, eq_eh); +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterDecide(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + Z3_solver_propagate_decide((Z3_context)ctx, (Z3_solver)solver, decide_eh); +} + +DLL_VIS JNIEXPORT jboolean JNICALL Java_com_microsoft_z3_Native_propagateConsequence(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long num_fixed, jlongArray fixed, long num_eqs, jlongArray eq_lhs, jlongArray eq_rhs, jlong conseq) { + JavaInfo *info = (JavaInfo*)javainfo; + GETLONGAELEMS(Z3_ast, fixed, _fixed); + GETLONGAELEMS(Z3_ast, eq_lhs, _eq_lhs); + GETLONGAELEMS(Z3_ast, eq_rhs, _eq_rhs); + bool retval = Z3_solver_propagate_consequence((Z3_context)ctx, info->cb, num_fixed, _fixed, num_eqs, _eq_lhs, _eq_rhs, (Z3_ast)conseq); + RELEASELONGAELEMS(fixed, _fixed); + RELEASELONGAELEMS(eq_lhs, _eq_lhs); + RELEASELONGAELEMS(eq_rhs, _eq_rhs); + return (jboolean) retval; +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateAdd(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong e) { + JavaInfo *info = (JavaInfo*)javainfo; + Z3_solver_callback cb = info->cb; + if (cb) + Z3_solver_propagate_register_cb((Z3_context)ctx, cb, (Z3_ast)e); + else if (solver) + Z3_solver_propagate_register((Z3_context)ctx, (Z3_solver)solver, (Z3_ast)e); + else { + assert(false); + } +} + + +DLL_VIS JNIEXPORT jboolean JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong e, long idx, int phase) { + JavaInfo *info = (JavaInfo*)javainfo; + Z3_solver_callback cb = info->cb; + return (jboolean) Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase)); +} diff --git a/src/api/java/Optimize.java b/src/api/java/Optimize.java index d72a28f0805..9679a96cd76 100644 --- a/src/api/java/Optimize.java +++ b/src/api/java/Optimize.java @@ -20,6 +20,8 @@ Nikolaj Bjorner (nbjorner) 2015-07-16 import com.microsoft.z3.enumerations.Z3_lbool; +import java.lang.ref.ReferenceQueue; + /** * Object for managing optimization context @@ -193,7 +195,7 @@ public Handle AssertSoft(Expr constraint, String weight, String gro /** * Check satisfiability of asserted constraints. * Produce a model that (when the objectives are bounded and - * don't use strict inequalities) meets the objectives. + * don't use strict inequalities) is optimal. **/ public Status Check(Expr... assumptions) { @@ -421,6 +423,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getOptimizeDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, OptimizeRef::new); + } + + private static class OptimizeRef extends Z3ReferenceQueue.Reference { + + private OptimizeRef(Optimize referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.optimizeDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/OptimizeDecRefQueue.java b/src/api/java/OptimizeDecRefQueue.java deleted file mode 100644 index 0acf20068c6..00000000000 --- a/src/api/java/OptimizeDecRefQueue.java +++ /dev/null @@ -1,30 +0,0 @@ -/** -Copyright (c) 2012-2015 Microsoft Corporation - -Module Name: - - OptimizeDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class OptimizeDecRefQueue extends IDecRefQueue { - public OptimizeDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.optimizeDecRef(ctx.nCtx(), obj); - } -}; diff --git a/src/api/java/ParamDescrs.java b/src/api/java/ParamDescrs.java index fdaf29647a2..0695f8fe01e 100644 --- a/src/api/java/ParamDescrs.java +++ b/src/api/java/ParamDescrs.java @@ -19,6 +19,8 @@ import com.microsoft.z3.enumerations.Z3_param_kind; +import java.lang.ref.ReferenceQueue; + /** * A ParamDescrs describes a set of parameters. **/ @@ -97,6 +99,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getParamDescrsDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, ParamDescrsRef::new); + } + + private static class ParamDescrsRef extends Z3ReferenceQueue.Reference { + + private ParamDescrsRef(ParamDescrs referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.paramDescrsDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/ParamDescrsDecRefQueue.java b/src/api/java/ParamDescrsDecRefQueue.java deleted file mode 100644 index ee3257db968..00000000000 --- a/src/api/java/ParamDescrsDecRefQueue.java +++ /dev/null @@ -1,31 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - ParamDescrsDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class ParamDescrsDecRefQueue extends IDecRefQueue { - public ParamDescrsDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) - { - Native.paramDescrsDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/Params.java b/src/api/java/Params.java index a76dd3cab3e..1edfa67ba1d 100644 --- a/src/api/java/Params.java +++ b/src/api/java/Params.java @@ -18,6 +18,8 @@ package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + /** * A ParameterSet represents a configuration in the form of Symbol/value pairs. **/ @@ -130,6 +132,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getParamsDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, ParamsRef::new); + } + + private static class ParamsRef extends Z3ReferenceQueue.Reference { + + private ParamsRef(Params referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.paramsDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/ParamsDecRefQueue.java b/src/api/java/ParamsDecRefQueue.java deleted file mode 100644 index 349713f67b0..00000000000 --- a/src/api/java/ParamsDecRefQueue.java +++ /dev/null @@ -1,30 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - ParamDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class ParamsDecRefQueue extends IDecRefQueue { - public ParamsDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.paramsDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/Probe.java b/src/api/java/Probe.java index a36f3b64b02..cb4b134425a 100644 --- a/src/api/java/Probe.java +++ b/src/api/java/Probe.java @@ -17,6 +17,8 @@ package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + /** * Probes are used to inspect a goal (aka problem) and collect information that * may be used to decide which solver and/or preprocessing step will be used. @@ -56,6 +58,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getProbeDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, ProbeRef::new); + } + + private static class ProbeRef extends Z3ReferenceQueue.Reference { + + private ProbeRef(Probe referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.probeDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/ProbeDecRefQueue.java b/src/api/java/ProbeDecRefQueue.java deleted file mode 100644 index b25446c0cc4..00000000000 --- a/src/api/java/ProbeDecRefQueue.java +++ /dev/null @@ -1,32 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - ProbeDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class ProbeDecRefQueue extends IDecRefQueue -{ - public ProbeDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) - { - Native.probeDecRef(ctx.nCtx(), obj); - } -}; diff --git a/src/api/java/Quantifier.java b/src/api/java/Quantifier.java index efeac9bb573..c44534196b6 100644 --- a/src/api/java/Quantifier.java +++ b/src/api/java/Quantifier.java @@ -166,7 +166,7 @@ public Quantifier translate(Context ctx) * @param sorts Sorts of bound variables. * @param names Names of bound variables * @param body Body of quantifier - * @param weight Weight used to indicate priority for qunatifier instantiation + * @param weight Weight used to indicate priority for quantifier instantiation * @param patterns Nullable patterns * @param noPatterns Nullable noPatterns * @param quantifierID Nullable quantifierID diff --git a/src/api/java/Simplifier.java b/src/api/java/Simplifier.java index b3fc89ccf11..c89241a7df3 100644 --- a/src/api/java/Simplifier.java +++ b/src/api/java/Simplifier.java @@ -18,6 +18,8 @@ Christoph Wintersteiger (cwinter) 2012-03-21 package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + public class Simplifier extends Z3Object { /* * A string containing a description of parameters accepted by the simplifier. @@ -32,7 +34,7 @@ public String getHelp() * Retrieves parameter descriptions for Simplifiers. */ public ParamDescrs getParameterDescriptions() { - return new ParamDescrs(getContext(), Native.simplifierGetParamDescrs(getContext().nCtx(), getNativeObject())); + return new ParamDescrs(getContext(), Native.simplifierGetParamDescrs(getContext().nCtx(), getNativeObject())); } Simplifier(Context ctx, long obj) @@ -53,6 +55,18 @@ void incRef() @Override void addToReferenceQueue() { - getContext().getSimplifierDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, SimplifierRef::new); + } + + private static class SimplifierRef extends Z3ReferenceQueue.Reference { + + private SimplifierRef(Simplifier referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.simplifierDecRef(ctx.nCtx(), z3Obj); + } } -} \ No newline at end of file +} \ No newline at end of file diff --git a/src/api/java/SimplifierDecRefQueue.java b/src/api/java/SimplifierDecRefQueue.java deleted file mode 100644 index ba15dd5be38..00000000000 --- a/src/api/java/SimplifierDecRefQueue.java +++ /dev/null @@ -1,31 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - SimplifierDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class SimplifierDecRefQueue extends IDecRefQueue { - public SimplifierDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) - { - Native.simplifierDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/Solver.java b/src/api/java/Solver.java index ce795d75833..b814a4db693 100644 --- a/src/api/java/Solver.java +++ b/src/api/java/Solver.java @@ -19,6 +19,8 @@ package com.microsoft.z3; import com.microsoft.z3.enumerations.Z3_lbool; + +import java.lang.ref.ReferenceQueue; import java.util.*; /** @@ -403,6 +405,18 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getSolverDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, SolverRef::new); + } + + private static class SolverRef extends Z3ReferenceQueue.Reference { + + private SolverRef(Solver referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.solverDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/SolverDecRefQueue.java b/src/api/java/SolverDecRefQueue.java deleted file mode 100644 index efa15d9398a..00000000000 --- a/src/api/java/SolverDecRefQueue.java +++ /dev/null @@ -1,27 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - SolverDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class SolverDecRefQueue extends IDecRefQueue { - public SolverDecRefQueue() { super(); } - - @Override - protected void decRef(Context ctx, long obj) { - Native.solverDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/Statistics.java b/src/api/java/Statistics.java index d509424ed99..6d42e1af0ab 100644 --- a/src/api/java/Statistics.java +++ b/src/api/java/Statistics.java @@ -17,6 +17,8 @@ package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + /** * Objects of this class track statistical information about solvers. **/ @@ -25,7 +27,7 @@ public class Statistics extends Z3Object { * Statistical data is organized into pairs of [Key, Entry], where every * Entry is either a {@code DoubleEntry} or a {@code UIntEntry} **/ - public class Entry + public static class Entry { /** * The key of the entry. @@ -191,11 +193,23 @@ public Entry get(String key) @Override void incRef() { - getContext().getStatisticsDRQ().storeReference(getContext(), this); + Native.statsIncRef(getContext().nCtx(), getNativeObject()); } @Override void addToReferenceQueue() { - Native.statsIncRef(getContext().nCtx(), getNativeObject()); + getContext().getReferenceQueue().storeReference(this, StatisticsRef::new); + } + + private static class StatisticsRef extends Z3ReferenceQueue.Reference { + + private StatisticsRef(Statistics referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.statsDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/StatisticsDecRefQueue.java b/src/api/java/StatisticsDecRefQueue.java deleted file mode 100644 index ed698e4cadf..00000000000 --- a/src/api/java/StatisticsDecRefQueue.java +++ /dev/null @@ -1,30 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - StatisticsDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class StatisticsDecRefQueue extends IDecRefQueue { - public StatisticsDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) { - Native.statsDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/Tactic.java b/src/api/java/Tactic.java index 11d02ca732f..d70d8a4096e 100644 --- a/src/api/java/Tactic.java +++ b/src/api/java/Tactic.java @@ -17,6 +17,8 @@ package com.microsoft.z3; +import java.lang.ref.ReferenceQueue; + /** * Tactics are the basic building block for creating custom solvers for specific * problem domains. The complete list of tactics may be obtained using @@ -98,6 +100,19 @@ void incRef() { @Override void addToReferenceQueue() { - getContext().getTacticDRQ().storeReference(getContext(), this); + //getContext().getTacticDRQ().storeReference(getContext(), this); + getContext().getReferenceQueue().storeReference(this, TacticRef::new); + } + + private static class TacticRef extends Z3ReferenceQueue.Reference { + + private TacticRef(Tactic referent, ReferenceQueue q) { + super(referent, q); + } + + @Override + void decRef(Context ctx, long z3Obj) { + Native.tacticDecRef(ctx.nCtx(), z3Obj); + } } } diff --git a/src/api/java/TacticDecRefQueue.java b/src/api/java/TacticDecRefQueue.java deleted file mode 100644 index 8f151f25ce6..00000000000 --- a/src/api/java/TacticDecRefQueue.java +++ /dev/null @@ -1,31 +0,0 @@ -/** -Copyright (c) 2012-2014 Microsoft Corporation - -Module Name: - - TacticDecRefQueue.java - -Abstract: - -Author: - - @author Christoph Wintersteiger (cwinter) 2012-03-15 - -Notes: - -**/ - -package com.microsoft.z3; - -class TacticDecRefQueue extends IDecRefQueue { - public TacticDecRefQueue() - { - super(); - } - - @Override - protected void decRef(Context ctx, long obj) - { - Native.tacticDecRef(ctx.nCtx(), obj); - } -} diff --git a/src/api/java/UserPropagatorBase.java b/src/api/java/UserPropagatorBase.java new file mode 100644 index 00000000000..46a61400db2 --- /dev/null +++ b/src/api/java/UserPropagatorBase.java @@ -0,0 +1,108 @@ +package com.microsoft.z3; + +import com.microsoft.z3.Context; +import com.microsoft.z3.enumerations.Z3_lbool; + +public abstract class UserPropagatorBase extends Native.UserPropagatorBase { + private Context ctx; + private Solver solver; + + public UserPropagatorBase(Context _ctx, Solver _solver) { + super(_ctx.nCtx(), _solver.getNativeObject()); + ctx = _ctx; + solver = _solver; + } + + public final Context getCtx() { + return ctx; + } + + public final Solver getSolver() { + return solver; + } + + @Override + protected final void pushWrapper() { + push(); + } + + @Override + protected final void popWrapper(int number) { + pop(number); + } + + @Override + protected final void finWrapper() { + fin(); + } + + @Override + protected final void eqWrapper(long lx, long ly) { + Expr x = new Expr(ctx, lx); + Expr y = new Expr(ctx, ly); + eq(x, y); + } + + @Override + protected final UserPropagatorBase freshWrapper(long lctx) { + return fresh(new Context(lctx)); + } + + @Override + protected final void createdWrapper(long last) { + created(new Expr(ctx, last)); + } + + @Override + protected final void fixedWrapper(long lvar, long lvalue) { + Expr var = new Expr(ctx, lvar); + Expr value = new Expr(ctx, lvalue); + fixed(var, value); + } + + @Override + protected final void decideWrapper(long lvar, int bit, boolean is_pos) { + Expr var = new Expr(ctx, lvar); + decide(var, bit, is_pos); + } + + public abstract void push(); + + public abstract void pop(int number); + + public abstract UserPropagatorBase fresh(Context ctx); + + public void created(Expr ast) {} + + public void fixed(Expr var, Expr value) {} + + public void eq(Expr x, Expr y) {} + + public void decide(Expr var, int bit, boolean is_pos) {} + + public void fin() {} + + public final void add(Expr expr) { + Native.propagateAdd(this, ctx.nCtx(), solver.getNativeObject(), javainfo, expr.getNativeObject()); + } + + public final boolean conflict(Expr[] fixed) { + return conflict(fixed, new Expr[0], new Expr[0]); + } + + public final boolean conflict(Expr[] fixed, Expr[] lhs, Expr[] rhs) { + return consequence(fixed, lhs, rhs, ctx.mkBool(false)); + } + + public final boolean consequence(Expr[] fixed, Expr[] lhs, Expr[] rhs, Expr conseq) { + return Native.propagateConsequence( + this, ctx.nCtx(), solver.getNativeObject(), javainfo, + fixed.length, AST.arrayToNative(fixed), lhs.length, AST.arrayToNative(lhs), AST.arrayToNative(rhs), conseq.getNativeObject()); + } + + public final boolean nextSplit(Expr e, long idx, Z3_lbool phase) { + return Native.propagateNextSplit( + this, ctx.nCtx(), solver.getNativeObject(), javainfo, + e.getNativeObject(), idx, phase.toInt()); + } +} diff --git a/src/api/java/Z3ReferenceQueue.java b/src/api/java/Z3ReferenceQueue.java new file mode 100644 index 00000000000..22435599fd9 --- /dev/null +++ b/src/api/java/Z3ReferenceQueue.java @@ -0,0 +1,144 @@ +/** +Copyright (c) 2012-2014 Microsoft Corporation + +Module Name: + + IDecRefQueue.java + +Abstract: + +Author: + + @author Christoph Wintersteiger (cwinter) 2012-03-15 + +Notes: + +**/ + +package com.microsoft.z3; + +import java.lang.ref.PhantomReference; +import java.lang.ref.ReferenceQueue; + +/** + * A queue to handle management of native memory. + * + *

Mechanics: When an object is created, a so-called {@link PhantomReference} + * is constructed that is associated with the created object and the reference queue {@code referenceQueue}. + * Once the object becomes strongly unreachable, the phantom reference gets + * added by JVM to the {@code referenceQueue}. + * After each object creation, we iterate through the available objects in + * {@code referenceQueue} and decrement references for them. + *

+ * In order for this to work, we need to (i) associate to each phantom reference the underlying + * native object (and its type) that it references and (ii) keep the phantom references themselves alive, so they are not + * garbage collected before the object they reference. + * We use a doubly-linked list of custom phantom references, subclasses of {@link Reference}, to achieve this. + * + */ +class Z3ReferenceQueue { + private final Context ctx; + private final ReferenceQueue referenceQueue = new ReferenceQueue<>(); + private final Reference referenceList = emptyList(); + + Z3ReferenceQueue(Context ctx) { + this.ctx = ctx; + } + + /** + * Create and store a new phantom reference. + */ + void storeReference(T z3Object, ReferenceConstructor refConstructor) { + referenceList.insert(refConstructor.construct(z3Object, referenceQueue)); + clear(); + } + + /** + * Clean all references currently in {@code referenceQueue}. + */ + private void clear() { + Reference ref; + while ((ref = (Reference)referenceQueue.poll()) != null) { + ref.cleanup(ctx); + } + } + + /** + * Clean all references stored in {@code referenceList}, + * regardless of whether they are in {@code referenceQueue} or not. + */ + @SuppressWarnings("StatementWithEmptyBody") + public void forceClear() { + // Decrement all reference counters + Reference cur = referenceList.next; + while (cur.next != null) { + cur.decRef(ctx, cur.nativePtr); + cur = cur.next; + } + + // Bulk-delete the reference list's entries + referenceList.next = cur; + cur.prev = referenceList; + + // Empty the reference queue so that there are no living phantom references anymore. + // This makes sure that all stored phantom references can be GC'd now. + while (referenceQueue.poll() != null) {} + } + + private static Reference emptyList() { + Reference head = new DummyReference(); + Reference tail = new DummyReference(); + head.next = tail; + tail.prev = head; + return head; + } + + // ================================================================================================================ + + @FunctionalInterface + interface ReferenceConstructor { + Reference construct(T reference, ReferenceQueue queue); + } + + abstract static class Reference extends PhantomReference { + + private Reference prev; + private Reference next; + private final long nativePtr; + + Reference(T referent, ReferenceQueue q) { + super(referent, q); + this.nativePtr = referent != null ? referent.getNativeObject() : 0; + } + + private void cleanup(Context ctx) { + decRef(ctx, nativePtr); + assert (prev != null && next != null); + prev.next = next; + next.prev = prev; + } + + private void insert(Reference ref) { + assert next != null; + ref.prev = this; + ref.next = this.next; + ref.next.prev = ref; + next = ref; + } + + abstract void decRef(Context ctx, long z3Obj); + } + + private static class DummyReference extends Reference { + + public DummyReference() { + super(null, null); + } + + @Override + void decRef(Context ctx, long z3Obj) { + // Should never be called. + assert false; + } + } +} diff --git a/src/api/js/README.md b/src/api/js/README.md index 8c446b91020..f53428bbdee 100644 --- a/src/api/js/README.md +++ b/src/api/js/README.md @@ -13,7 +13,7 @@ Then run `npm i` to install dependencies, `npm run build:ts` to build the TypeSc ### Build on your own -Consult the file [build-wasm.ts](https://github.com/Z3Prover/z3/blob/master/src/api/js/scripts/build-wasm.ts) for configurations used for building wasm. +Consult the file [build-wasm.ts](https://github.com/Z3Prover/z3/blob/master/src/api/js/scripts/build-wasm.ts) for configurations used for building wasm. ## Tests diff --git a/src/api/js/package-lock.json b/src/api/js/package-lock.json index a46cae95101..5f1d63bd0dd 100644 --- a/src/api/js/package-lock.json +++ b/src/api/js/package-lock.json @@ -1,7 +1,8 @@ { "name": "z3-solver", - "requires": true, + "version": "0.1.0", "lockfileVersion": 1, + "requires": true, "dependencies": { "@ampproject/remapping": { "version": "2.2.0", @@ -109,25 +110,6 @@ "integrity": "sha512-3r/aACDJ3fhQ/EVgFy0hpj8oHyHpQc+LPtJoY9SzTThAsStm4Ptegq92vqKoE3vD706ZVFWITnMnxucw+S9Ipg==", "dev": true }, - "@babel/helper-function-name": { - "version": "7.19.0", - "resolved": "https://registry.npmjs.org/@babel/helper-function-name/-/helper-function-name-7.19.0.tgz", - "integrity": "sha512-WAwHBINyrpqywkUH0nTnNgI5ina5TFn85HKS0pbPDfxFfhyR/aNQEn4hGi1P1JyT//I0t4OgXUlofzWILRvS5w==", - "dev": true, - "requires": { - "@babel/template": "^7.18.10", - "@babel/types": "^7.19.0" - } - }, - "@babel/helper-hoist-variables": { - "version": "7.18.6", - "resolved": "https://registry.npmjs.org/@babel/helper-hoist-variables/-/helper-hoist-variables-7.18.6.tgz", - "integrity": "sha512-UlJQPkFqFULIcyW5sbzgbkxn2FKRgwWiRexcuaR8RNJRy8+LLveqPjwZV/bwrLZCN0eUHD/x8D0heK1ozuoo6Q==", - "dev": true, - "requires": { - "@babel/types": "^7.18.6" - } - }, "@babel/helper-module-imports": { "version": "7.18.6", "resolved": "https://registry.npmjs.org/@babel/helper-module-imports/-/helper-module-imports-7.18.6.tgz", @@ -361,21 +343,141 @@ } }, "@babel/traverse": { - "version": "7.19.4", - "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.19.4.tgz", - "integrity": "sha512-w3K1i+V5u2aJUOXBFFC5pveFLmtq1s3qcdDNC2qRI6WPBQIDaKFqXxDEqDO/h1dQ3HjsZoZMyIy6jGLq0xtw+g==", - "dev": true, - "requires": { - "@babel/code-frame": "^7.18.6", - "@babel/generator": "^7.19.4", - "@babel/helper-environment-visitor": "^7.18.9", - "@babel/helper-function-name": "^7.19.0", - "@babel/helper-hoist-variables": "^7.18.6", - "@babel/helper-split-export-declaration": "^7.18.6", - "@babel/parser": "^7.19.4", - "@babel/types": "^7.19.4", + "version": "7.23.2", + "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.23.2.tgz", + "integrity": "sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw==", + "dev": true, + "requires": { + "@babel/code-frame": "^7.22.13", + "@babel/generator": "^7.23.0", + "@babel/helper-environment-visitor": "^7.22.20", + "@babel/helper-function-name": "^7.23.0", + "@babel/helper-hoist-variables": "^7.22.5", + "@babel/helper-split-export-declaration": "^7.22.6", + "@babel/parser": "^7.23.0", + "@babel/types": "^7.23.0", "debug": "^4.1.0", "globals": "^11.1.0" + }, + "dependencies": { + "@babel/code-frame": { + "version": "7.22.13", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.22.13.tgz", + "integrity": "sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w==", + "dev": true, + "requires": { + "@babel/highlight": "^7.22.13", + "chalk": "^2.4.2" + } + }, + "@babel/generator": { + "version": "7.23.0", + "resolved": "https://registry.npmjs.org/@babel/generator/-/generator-7.23.0.tgz", + "integrity": "sha512-lN85QRR+5IbYrMWM6Y4pE/noaQtg4pNiqeNGX60eqOfo6gtEj6uw/JagelB8vVztSd7R6M5n1+PQkDbHbBRU4g==", + "dev": true, + "requires": { + "@babel/types": "^7.23.0", + "@jridgewell/gen-mapping": "^0.3.2", + "@jridgewell/trace-mapping": "^0.3.17", + "jsesc": "^2.5.1" + } + }, + "@babel/helper-environment-visitor": { + "version": "7.22.20", + "resolved": "https://registry.npmjs.org/@babel/helper-environment-visitor/-/helper-environment-visitor-7.22.20.tgz", + "integrity": "sha512-zfedSIzFhat/gFhWfHtgWvlec0nqB9YEIVrpuwjruLlXfUSnA8cJB0miHKwqDnQ7d32aKo2xt88/xZptwxbfhA==", + "dev": true + }, + "@babel/helper-function-name": { + "version": "7.23.0", + "resolved": "https://registry.npmjs.org/@babel/helper-function-name/-/helper-function-name-7.23.0.tgz", + "integrity": "sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw==", + "dev": true, + "requires": { + "@babel/template": "^7.22.15", + "@babel/types": "^7.23.0" + } + }, + "@babel/helper-hoist-variables": { + "version": "7.22.5", + "resolved": "https://registry.npmjs.org/@babel/helper-hoist-variables/-/helper-hoist-variables-7.22.5.tgz", + "integrity": "sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw==", + "dev": true, + "requires": { + "@babel/types": "^7.22.5" + } + }, + "@babel/helper-split-export-declaration": { + "version": "7.22.6", + "resolved": "https://registry.npmjs.org/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.22.6.tgz", + "integrity": "sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g==", + "dev": true, + "requires": { + "@babel/types": "^7.22.5" + } + }, + "@babel/helper-string-parser": { + "version": "7.22.5", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.22.5.tgz", + "integrity": "sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw==", + "dev": true + }, + "@babel/helper-validator-identifier": { + "version": "7.22.20", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz", + "integrity": "sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A==", + "dev": true + }, + "@babel/highlight": { + "version": "7.22.20", + "resolved": "https://registry.npmjs.org/@babel/highlight/-/highlight-7.22.20.tgz", + "integrity": "sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg==", + "dev": true, + "requires": { + "@babel/helper-validator-identifier": "^7.22.20", + "chalk": "^2.4.2", + "js-tokens": "^4.0.0" + } + }, + "@babel/parser": { + "version": "7.23.0", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.23.0.tgz", + "integrity": "sha512-vvPKKdMemU85V9WE/l5wZEmImpCtLqbnTvqDS2U1fJ96KrxoW7KrXhNsNCblQlg8Ck4b85yxdTyelsMUgFUXiw==", + "dev": true + }, + "@babel/template": { + "version": "7.22.15", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.22.15.tgz", + "integrity": "sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w==", + "dev": true, + "requires": { + "@babel/code-frame": "^7.22.13", + "@babel/parser": "^7.22.15", + "@babel/types": "^7.22.15" + } + }, + "@babel/types": { + "version": "7.23.0", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.23.0.tgz", + "integrity": "sha512-0oIyUfKoI3mSqMvsxBdclDwxXKXAUA8v/apZbc+iSyARYou1o8ZGDxbUYyLFoW2arqS2jDGqJuZvv1d/io1axg==", + "dev": true, + "requires": { + "@babel/helper-string-parser": "^7.22.5", + "@babel/helper-validator-identifier": "^7.22.20", + "to-fast-properties": "^2.0.0" + } + }, + "@jridgewell/gen-mapping": { + "version": "0.3.3", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz", + "integrity": "sha512-HLhSWOLRi875zjjMG/r+Nv0oCW8umGb0BgEhyX3dDX3egwZtB8PqLnjz3yedt8R5StBrzcg4aBpnh8UA9D1BoQ==", + "dev": true, + "requires": { + "@jridgewell/set-array": "^1.0.1", + "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/trace-mapping": "^0.3.9" + } + } } }, "@babel/types": { diff --git a/src/api/js/scripts/build-wasm.ts b/src/api/js/scripts/build-wasm.ts index 0f51c84d3ee..5bdbdea3b47 100644 --- a/src/api/js/scripts/build-wasm.ts +++ b/src/api/js/scripts/build-wasm.ts @@ -69,7 +69,7 @@ const fns = JSON.stringify(exportedFuncs()); const methods = '["ccall","FS","allocate","UTF8ToString","intArrayFromString","ALLOC_NORMAL"]'; const libz3a = path.normalize('../../../build/libz3.a'); spawnSync( - `emcc build/async-fns.cc ${libz3a} --std=c++20 --pre-js src/low-level/async-wrapper.js -g2 -pthread -fexceptions -s WASM_BIGINT -s USE_PTHREADS=1 -s PTHREAD_POOL_SIZE=0 -s PTHREAD_POOL_SIZE_STRICT=0 -s MODULARIZE=1 -s 'EXPORT_NAME="initZ3"' -s EXPORTED_RUNTIME_METHODS=${methods} -s EXPORTED_FUNCTIONS=${fns} -s DISABLE_EXCEPTION_CATCHING=0 -s SAFE_HEAP=0 -s DEMANGLE_SUPPORT=1 -s TOTAL_MEMORY=1GB -I z3/src/api/ -o build/z3-built.js`, + `emcc build/async-fns.cc ${libz3a} --std=c++20 --pre-js src/low-level/async-wrapper.js -g2 -pthread -fexceptions -s WASM_BIGINT -s USE_PTHREADS=1 -s PTHREAD_POOL_SIZE=0 -s PTHREAD_POOL_SIZE_STRICT=0 -s MODULARIZE=1 -s 'EXPORT_NAME="initZ3"' -s EXPORTED_RUNTIME_METHODS=${methods} -s EXPORTED_FUNCTIONS=${fns} -s DISABLE_EXCEPTION_CATCHING=0 -s SAFE_HEAP=0 -s DEMANGLE_SUPPORT=1 -s TOTAL_MEMORY=1GB -s TOTAL_STACK=20MB -I z3/src/api/ -o build/z3-built.js`, ); fs.rmSync(ccWrapperPath); diff --git a/src/api/js/scripts/parse-api.ts b/src/api/js/scripts/parse-api.ts index a3aa81acdc3..151e2f7bd62 100644 --- a/src/api/js/scripts/parse-api.ts +++ b/src/api/js/scripts/parse-api.ts @@ -350,7 +350,7 @@ for (let fn of functions) { param.sizeIndex = defParams[idx].sizeIndex; if (!param.isArray && param.isPtr) { // not clear why some things are written as `int * x` and others `int x[]` - // but we can jsut cast + // but we can just cast param.isArray = true; param.isPtr = false; } diff --git a/src/api/julia/README.md b/src/api/julia/README.md index ec4b1e1091f..12e8461912f 100644 --- a/src/api/julia/README.md +++ b/src/api/julia/README.md @@ -14,7 +14,10 @@ make ## Julia part -The Z3 binaries are provided to [Z3.jl](https://github.com/ahumenberger/Z3.jl) via [z3_jll.jl](https://github.com/JuliaBinaryWrappers/z3_jll.jl). That is, in order to release a new Z3 version one has to update the corresponding [build script](https://github.com/JuliaPackaging/Yggdrasil/tree/master/Z/z3) which triggers a new version of z3_jll.jl. +The Z3 binaries are provided to [Z3.jl](https://github.com/ahumenberger/Z3.jl) via [z3_jll.jl](https://github.com/JuliaBinaryWrappers/z3_jll.jl). +That is, in order to propagate any C++ changes to the Julia side, one has to: +1. Release a new version of Z3. +2. Update the corresponding [build script](https://github.com/JuliaPackaging/Yggdrasil/tree/master/Z/z3) to use the new Z3 release. ### Using the compiled version of Z3 diff --git a/src/api/julia/z3jl.cpp b/src/api/julia/z3jl.cpp index 5aef2f41db0..336a739763d 100644 --- a/src/api/julia/z3jl.cpp +++ b/src/api/julia/z3jl.cpp @@ -303,6 +303,8 @@ JLCXX_MODULE define_julia_module(jlcxx::Module &m) m.method("xnor", &xnor); m.method("min", &min); m.method("max", &max); + m.method("exists", static_cast(&exists)); + m.method("forall", static_cast(&forall)); m.method("abs", static_cast(&abs)); m.method("sqrt", static_cast(&sqrt)); m.method("fma", static_cast(&fma)); diff --git a/src/api/ml/README.md b/src/api/ml/README.md index 7b9d5b4766e..041f37eb48c 100644 --- a/src/api/ml/README.md +++ b/src/api/ml/README.md @@ -223,7 +223,7 @@ correctly found by gcc. I specifically left the cygwin part of the code intact as I have no idea what the original author meant by this, neither do I use or -tested this patch in the cygwin or mingw environemt. I think that this +tested this patch in the cygwin or mingw environment. I think that this code is rather outdated and shouldn't really work. E.g., in the --staticlib mode adding z3linkdep (which is libz3-static.a) as an argument to `ocamlmklib` will yield the following broken archive diff --git a/src/api/ml/z3.ml b/src/api/ml/z3.ml index 98807bedd6f..0be6e57a0d1 100644 --- a/src/api/ml/z3.ml +++ b/src/api/ml/z3.ml @@ -8,7 +8,7 @@ open Z3enums exception Error of string -let _ = Callback.register_exception "Z3EXCEPTION" (Error "") +let () = Callback.register_exception "Z3EXCEPTION" (Error "") type context = Z3native.context @@ -26,31 +26,18 @@ struct let (major, minor, build, revision) = Z3native.get_version () let full_version : string = Z3native.get_full_version() - - let to_string = - string_of_int major ^ "." ^ - string_of_int minor ^ "." ^ - string_of_int build ^ "." ^ - string_of_int revision -end -let mk_list f n = - let rec mk_list' i accu = - if i >= n then - List.rev accu - else - mk_list' (i + 1) ((f i)::accu) - in - mk_list' 0 [] + let to_string = Printf.sprintf "%d.%d.%d.%d" major minor build revision +end let check_int32 v = v = Int32.to_int (Int32.of_int v) -let mk_int_expr ctx v ty = +let mk_int_expr ctx v ty = if not (check_int32 v) then Z3native.mk_numeral ctx (string_of_int v) ty else Z3native.mk_int ctx v ty - + let mk_context (settings:(string * string) list) = let cfg = Z3native.mk_config () in let f e = Z3native.set_param_value cfg (fst e) (snd e) in @@ -62,10 +49,13 @@ let mk_context (settings:(string * string) list) = Z3native.enable_concurrent_dec_ref res; res +let interrupt (ctx:context) = + Z3native.interrupt ctx + module Symbol = struct type symbol = Z3native.symbol - let gc = Z3native.context_of_symbol + let gc s = Z3native.context_of_symbol s let kind o = symbol_kind_of_int (Z3native.get_symbol_kind (gc o) o) let is_int_symbol o = kind o = INT_SYMBOL @@ -77,8 +67,8 @@ struct | INT_SYMBOL -> string_of_int (Z3native.get_symbol_int (gc o) o) | STRING_SYMBOL -> Z3native.get_symbol_string (gc o) o - let mk_int = Z3native.mk_int_symbol - let mk_string = Z3native.mk_string_symbol + let mk_int ctx = Z3native.mk_int_symbol ctx + let mk_string ctx s = Z3native.mk_string_symbol ctx s let mk_ints ctx names = List.map (mk_int ctx) names let mk_strings ctx names = List.map (mk_string ctx) names @@ -132,12 +122,12 @@ sig val translate : ast -> context -> ast end = struct type ast = Z3native.ast - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a module ASTVector = struct type ast_vector = Z3native.ast_vector - let gc = Z3native.context_of_ast_vector + let gc v = Z3native.context_of_ast_vector v let mk_ast_vector = Z3native.mk_ast_vector let get_size (x:ast_vector) = Z3native.ast_vector_size (gc x) x @@ -150,12 +140,12 @@ end = struct let to_list (x:ast_vector) = let xs = get_size x in let f i = get x i in - mk_list f xs + List.init xs f let to_expr_list (x:ast_vector) = let xs = get_size x in let f i = get x i in - mk_list f xs + List.init xs f let to_string x = Z3native.ast_vector_to_string (gc x) x end @@ -163,7 +153,7 @@ end = struct module ASTMap = struct type ast_map = Z3native.ast_map - let gc = Z3native.context_of_ast_map + let gc m = Z3native.context_of_ast_map m let mk_ast_map = Z3native.mk_ast_map let contains (x:ast_map) (key:ast) = Z3native.ast_map_contains (gc x) x key @@ -228,7 +218,7 @@ sig val mk_uninterpreted_s : context -> string -> sort end = struct type sort = Z3native.sort - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a let equal a b = (a = b) || (gc a = gc b && Z3native.is_eq_sort (gc a) a b) @@ -236,7 +226,7 @@ end = struct let get_sort_kind (x:sort) = sort_kind_of_int (Z3native.get_sort_kind (gc x) x) let get_name (x:sort) = Z3native.get_sort_name (gc x) x let to_string (x:sort) = Z3native.sort_to_string (gc x) x - let mk_uninterpreted = Z3native.mk_uninterpreted_sort + let mk_uninterpreted ctx s = Z3native.mk_uninterpreted_sort ctx s let mk_uninterpreted_s (ctx:context) (s:string) = mk_uninterpreted ctx (Symbol.mk_string ctx s) end @@ -287,7 +277,7 @@ sig val apply : func_decl -> Expr.expr list -> Expr.expr end = struct type func_decl = AST.ast - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a module Parameter = struct @@ -375,7 +365,7 @@ end = struct let get_domain (x:func_decl) = let n = get_domain_size x in let f i = Z3native.get_domain (gc x) x i in - mk_list f n + List.init n f let get_range (x:func_decl) = Z3native.get_range (gc x) x let get_decl_kind (x:func_decl) = decl_kind_of_int (Z3native.get_decl_kind (gc x) x) @@ -394,7 +384,7 @@ end = struct | PARAMETER_FUNC_DECL -> Parameter.P_Fdl (Z3native.get_decl_func_decl_parameter (gc x) x i) | PARAMETER_RATIONAL -> Parameter.P_Rat (Z3native.get_decl_rational_parameter (gc x) x i) in - mk_list f n + List.init n f let apply (x:func_decl) (args:Expr.expr list) = Expr.expr_of_func_app (gc x) x args end @@ -423,12 +413,12 @@ sig val set_print_mode : context -> Z3enums.ast_print_mode -> unit end = struct type params = Z3native.params - let gc = Z3native.context_of_params + let gc p = Z3native.context_of_params p module ParamDescrs = struct type param_descrs = Z3native.param_descrs - let gc = Z3native.context_of_param_descrs + let gc p = Z3native.context_of_param_descrs p let validate (x:param_descrs) (p:params) = Z3native.params_validate (gc x) p x let get_kind (x:param_descrs) (name:Symbol.symbol) = param_kind_of_int (Z3native.param_descrs_get_kind (gc x) x name) @@ -436,7 +426,7 @@ end = struct let get_names (x:param_descrs) = let n = Z3native.param_descrs_size (gc x) x in let f i = Z3native.param_descrs_get_name (gc x) x i in - mk_list f n + List.init n f let get_size (x:param_descrs) = Z3native.param_descrs_size (gc x) x let to_string (x:param_descrs) = Z3native.param_descrs_to_string (gc x) x @@ -488,7 +478,7 @@ sig val compare : expr -> expr -> int end = struct type expr = AST.ast - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a let expr_of_ast a = let q = Z3enums.ast_kind_of_int (Z3native.get_ast_kind (gc a) a) in @@ -514,7 +504,7 @@ end = struct let get_args (x:expr) = let n = get_num_args x in let f i = Z3native.get_app_arg (gc x) x i in - mk_list f n + List.init n f let update (x:expr) (args:expr list) = if AST.is_app x && List.length args <> get_num_args x then @@ -564,11 +554,11 @@ open Expr module Boolean = struct - let mk_sort = Z3native.mk_bool_sort + let mk_sort ctx = Z3native.mk_bool_sort ctx let mk_const (ctx:context) (name:Symbol.symbol) = Expr.mk_const ctx name (mk_sort ctx) let mk_const_s (ctx:context) (name:string) = mk_const ctx (Symbol.mk_string ctx name) - let mk_true = Z3native.mk_true - let mk_false = Z3native.mk_false + let mk_true ctx = Z3native.mk_true ctx + let mk_false ctx = Z3native.mk_false ctx let mk_val (ctx:context) (value:bool) = if value then mk_true ctx else mk_false ctx let mk_not = Z3native.mk_not let mk_ite = Z3native.mk_ite @@ -606,7 +596,7 @@ end module Quantifier = struct type quantifier = AST.ast - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a let expr_of_quantifier q = q @@ -620,14 +610,14 @@ struct module Pattern = struct type pattern = Z3native.pattern - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a let get_num_terms x = Z3native.get_pattern_num_terms (gc x) x let get_terms x = let n = get_num_terms x in let f i = Z3native.get_pattern (gc x) x i in - mk_list f n + List.init n f let to_string x = Z3native.pattern_to_string (gc x) x end @@ -645,26 +635,26 @@ struct let get_patterns x = let n = get_num_patterns x in let f i = Z3native.get_quantifier_pattern_ast (gc x) x i in - mk_list f n + List.init n f let get_num_no_patterns x = Z3native.get_quantifier_num_no_patterns (gc x) x let get_no_patterns x = let n = get_num_patterns x in let f i = Z3native.get_quantifier_no_pattern_ast (gc x) x i in - mk_list f n + List.init n f let get_num_bound x = Z3native.get_quantifier_num_bound (gc x) x let get_bound_variable_names x = let n = get_num_bound x in let f i = Z3native.get_quantifier_bound_name (gc x) x i in - mk_list f n + List.init n f let get_bound_variable_sorts x = let n = get_num_bound x in let f i = Z3native.get_quantifier_bound_sort (gc x) x i in - mk_list f n + List.init n f let get_body x = Z3native.get_quantifier_body (gc x) x let mk_bound = Z3native.mk_bound @@ -721,7 +711,7 @@ struct let mk_exists = _internal_mk_quantifier ~universal:false let mk_exists_const = _internal_mk_quantifier_const ~universal:false let mk_lambda_const ctx bound body = Z3native.mk_lambda_const ctx (List.length bound) bound body - let mk_lambda ctx bound body = + let mk_lambda ctx bound body = let names = List.map (fun (x,_) -> x) bound in let sorts = List.map (fun (_,y) -> y) bound in Z3native.mk_lambda ctx (List.length bound) sorts names body @@ -743,7 +733,7 @@ end module Z3Array = struct - let mk_sort = Z3native.mk_array_sort + let mk_sort ctx domain range = Z3native.mk_array_sort ctx domain range let is_store x = AST.is_app x && FuncDecl.get_decl_kind (Expr.get_func_decl x) = OP_STORE let is_select x = AST.is_app x && FuncDecl.get_decl_kind (Expr.get_func_decl x) = OP_SELECT let is_constant_array x = AST.is_app x && FuncDecl.get_decl_kind (Expr.get_func_decl x) = OP_CONST_ARRAY @@ -803,7 +793,7 @@ end module FiniteDomain = struct - let mk_sort = Z3native.mk_finite_domain_sort + let mk_sort ctx s size = Z3native.mk_finite_domain_sort ctx s size let mk_sort_s ctx name size = mk_sort ctx (Symbol.mk_string ctx name) size let is_finite_domain (x:expr) = @@ -846,7 +836,7 @@ struct let get_column_sorts (x:Sort.sort) = let n = get_arity x in let f i = Z3native.get_relation_column (Sort.gc x) x i in - mk_list f n + List.init n f end @@ -855,15 +845,6 @@ struct module Constructor = struct type constructor = Z3native.constructor - - module FieldNumTable = Hashtbl.Make(struct - type t = AST.ast - let equal x y = AST.compare x y = 0 - let hash = AST.hash - end) - - let _field_nums = FieldNumTable.create 0 - let create (ctx:context) (name:Symbol.symbol) (recognizer:Symbol.symbol) (field_names:Symbol.symbol list) (sorts:Sort.sort option list) (sort_refs:int list) = let n = List.length field_names in if n <> List.length sorts then @@ -879,10 +860,9 @@ struct (let f x = match x with None -> Z3native.mk_null_ast ctx | Some s -> s in List.map f sorts) sort_refs in - FieldNumTable.add _field_nums no n; no - let get_num_fields (x:constructor) = FieldNumTable.find _field_nums x + let get_num_fields (x:constructor) = Z3native.constructor_num_fields (gc x) x let get_constructor_decl (x:constructor) = let (a, _, _) = (Z3native.query_constructor (gc x) x (get_num_fields x)) in @@ -917,10 +897,10 @@ struct let mk_sort_s (ctx:context) (name:string) (constructors:Constructor.constructor list) = mk_sort ctx (Symbol.mk_string ctx name) constructors - + let mk_sort_ref (ctx: context) (name:Symbol.symbol) = Z3native.mk_datatype_sort ctx name - + let mk_sort_ref_s (ctx: context) (name: string) = mk_sort_ref ctx (Symbol.mk_string ctx name) @@ -939,12 +919,12 @@ struct let get_constructors (x:Sort.sort) = let n = get_num_constructors x in let f i = Z3native.get_datatype_sort_constructor (Sort.gc x) x i in - mk_list f n + List.init n f let get_recognizers (x:Sort.sort) = let n = (get_num_constructors x) in let f i = Z3native.get_datatype_sort_recognizer (Sort.gc x) x i in - mk_list f n + List.init n f let get_accessors (x:Sort.sort) = let n = (get_num_constructors x) in @@ -952,8 +932,8 @@ struct let fd = Z3native.get_datatype_sort_constructor (Sort.gc x) x i in let ds = Z3native.get_domain_size (FuncDecl.gc fd) fd in let g j = Z3native.get_datatype_sort_constructor_accessor (Sort.gc x) x i j in - mk_list g ds) in - mk_list f n + List.init ds g) in + List.init n f end @@ -969,21 +949,21 @@ struct let get_const_decls (x:Sort.sort) = let n = Z3native.get_datatype_sort_num_constructors (Sort.gc x) x in let f i = Z3native.get_datatype_sort_constructor (Sort.gc x) x i in - mk_list f n + List.init n f let get_const_decl (x:Sort.sort) (inx:int) = Z3native.get_datatype_sort_constructor (Sort.gc x) x inx let get_consts (x:Sort.sort) = let n = Z3native.get_datatype_sort_num_constructors (Sort.gc x) x in let f i = Expr.mk_const_f (Sort.gc x) (get_const_decl x i) in - mk_list f n + List.init n f let get_const (x:Sort.sort) (inx:int) = Expr.mk_const_f (Sort.gc x) (get_const_decl x inx) let get_tester_decls (x:Sort.sort) = let n = Z3native.get_datatype_sort_num_constructors (Sort.gc x) x in let f i = Z3native.get_datatype_sort_recognizer (Sort.gc x) x i in - mk_list f n + List.init n f let get_tester_decl (x:Sort.sort) (inx:int) = Z3native.get_datatype_sort_recognizer (Sort.gc x) x inx end @@ -1017,8 +997,8 @@ struct let get_field_decls (x:Sort.sort) = let n = get_num_fields x in - let f i =Z3native.get_tuple_sort_field_decl (Sort.gc x) x i in - mk_list f n + let f i = Z3native.get_tuple_sort_field_decl (Sort.gc x) x i in + List.init n f end @@ -1050,7 +1030,7 @@ struct module Integer = struct - let mk_sort = Z3native.mk_int_sort + let mk_sort ctx = Z3native.mk_int_sort ctx let get_int x = match Z3native.get_numeral_int (Expr.gc x) x with @@ -1077,7 +1057,7 @@ struct module Real = struct - let mk_sort = Z3native.mk_real_sort + let mk_sort ctx = Z3native.mk_real_sort ctx let get_numerator x = Z3native.get_numerator (Expr.gc x) x let get_denominator x = Z3native.get_denominator (Expr.gc x) x @@ -1249,7 +1229,7 @@ end module Seq = struct let mk_seq_sort = Z3native.mk_seq_sort - let is_seq_sort = Z3native.is_seq_sort + let is_seq_sort = Z3native.is_seq_sort let mk_re_sort = Z3native.mk_re_sort let is_re_sort = Z3native.is_re_sort let mk_string_sort = Z3native.mk_string_sort @@ -1264,7 +1244,7 @@ struct let mk_seq_concat ctx args = Z3native.mk_seq_concat ctx (List.length args) args let mk_seq_prefix = Z3native.mk_seq_prefix let mk_seq_suffix = Z3native.mk_seq_suffix - let mk_seq_contains = Z3native.mk_seq_contains + let mk_seq_contains = Z3native.mk_seq_contains let mk_seq_extract = Z3native.mk_seq_extract let mk_seq_replace = Z3native.mk_seq_replace let mk_seq_at = Z3native.mk_seq_at @@ -1474,7 +1454,7 @@ end module Goal = struct type goal = Z3native.goal - let gc = Z3native.context_of_goal + let gc g = Z3native.context_of_goal g let get_precision (x:goal) = goal_prec_of_int (Z3native.goal_precision (gc x) x) let is_precise (x:goal) = (get_precision x) = GOAL_PRECISE @@ -1493,7 +1473,7 @@ struct let get_formulas (x:goal) = let n = get_size x in let f i = Z3native.goal_formula (gc x) x i in - mk_list f n + List.init n f let get_num_exprs (x:goal) = Z3native.goal_num_exprs (gc x) x let is_decided_sat (x:goal) = Z3native.goal_is_decided_sat (gc x) x @@ -1509,13 +1489,15 @@ struct in Z3native.apply_result_inc_ref (gc x) arn; let sg = Z3native.apply_result_get_num_subgoals (gc x) arn in - let res = if sg = 0 then - raise (Error "No subgoals") - else - Z3native.apply_result_get_subgoal (gc x) arn 0 in - Z3native.apply_result_dec_ref (gc x) arn; - Z3native.tactic_dec_ref (gc x) tn; - res + if sg = 0 then ( + Z3native.apply_result_dec_ref (gc x) arn; + Z3native.tactic_dec_ref (gc x) tn; + raise (Error "No subgoals")) + else + let res:goal = Z3native.apply_result_get_subgoal (gc x) arn 0 in + Z3native.apply_result_dec_ref (gc x) arn; + Z3native.tactic_dec_ref (gc x) tn; + res let mk_goal = Z3native.mk_goal @@ -1532,17 +1514,17 @@ end module Model = struct type model = Z3native.model - let gc = Z3native.context_of_model + let gc m = Z3native.context_of_model m module FuncInterp = struct type func_interp = Z3native.func_interp - let gc = Z3native.context_of_func_interp + let gc f = Z3native.context_of_func_interp f module FuncEntry = struct type func_entry = Z3native.func_entry - let gc = Z3native.context_of_func_entry + let gc f = Z3native.context_of_func_entry f let get_value (x:func_entry) = Z3native.func_entry_get_value (gc x) x let get_num_args (x:func_entry) = Z3native.func_entry_get_num_args (gc x) x @@ -1550,7 +1532,7 @@ struct let get_args (x:func_entry) = let n = get_num_args x in let f i = Z3native.func_entry_get_arg (gc x) x i in - mk_list f n + List.init n f let to_string (x:func_entry) = let a = get_args x in @@ -1563,7 +1545,7 @@ struct let get_entries (x:func_interp) = let n = get_num_entries x in let f i = Z3native.func_interp_get_entry (gc x) x i in - mk_list f n + List.init n f let get_else (x:func_interp) = Z3native.func_interp_get_else (gc x) x @@ -1619,21 +1601,24 @@ struct let get_const_decls (x:model) = let n = (get_num_consts x) in let f i = Z3native.model_get_const_decl (gc x) x i in - mk_list f n + List.init n f let get_num_funcs (x:model) = Z3native.model_get_num_funcs (gc x) x let get_func_decls (x:model) = let n = (get_num_funcs x) in let f i = Z3native.model_get_func_decl (gc x) x i in - mk_list f n + List.init n f let get_decls (x:model) = let n_funcs = get_num_funcs x in let n_consts = get_num_consts x in let f i = Z3native.model_get_func_decl (gc x) x i in let g i = Z3native.model_get_const_decl (gc x) x i in - (mk_list f n_funcs) @ (mk_list g n_consts) + List.init (n_funcs + n_consts) (fun i -> + if i < n_funcs then f i + else g i + ) let eval (x:model) (t:expr) (completion:bool) = match Z3native.model_eval (gc x) x t completion with @@ -1646,7 +1631,7 @@ struct let get_sorts (x:model) = let n = get_num_sorts x in let f i = Z3native.model_get_sort (gc x) x i in - mk_list f n + List.init n f let sort_universe (x:model) (s:Sort.sort) = let av = Z3native.model_get_sort_universe (gc x) x s in @@ -1661,12 +1646,12 @@ struct type probe = Z3native.probe let apply (x:probe) (g:Goal.goal) = Z3native.probe_apply (gc x) x g - let get_num_probes = Z3native.get_num_probes + let get_num_probes ctx = Z3native.get_num_probes ctx let get_probe_names (ctx:context) = let n = get_num_probes ctx in let f i = Z3native.get_probe_name ctx i in - mk_list f n + List.init n f let get_probe_description = Z3native.probe_get_descr let mk_probe = Z3native.mk_probe @@ -1685,19 +1670,19 @@ end module Tactic = struct type tactic = Z3native.tactic - let gc = Z3native.context_of_tactic + let gc t = Z3native.context_of_tactic t module ApplyResult = struct type apply_result = Z3native.apply_result - let gc = Z3native.context_of_apply_result + let gc a = Z3native.context_of_apply_result a let get_num_subgoals (x:apply_result) = Z3native.apply_result_get_num_subgoals (gc x) x let get_subgoals (x:apply_result) = let n = get_num_subgoals x in let f i = Z3native.apply_result_get_subgoal (gc x) x i in - mk_list f n + List.init n f let get_subgoal (x:apply_result) (i:int) = Z3native.apply_result_get_subgoal (gc x) x i let to_string (x:apply_result) = Z3native.apply_result_to_string (gc x) x @@ -1711,23 +1696,26 @@ struct | None -> Z3native.tactic_apply (gc x) x g | Some pn -> Z3native.tactic_apply_ex (gc x) x g pn - let get_num_tactics = Z3native.get_num_tactics + let get_num_tactics ctx = Z3native.get_num_tactics ctx let get_tactic_names (ctx:context) = let n = get_num_tactics ctx in let f i = Z3native.get_tactic_name ctx i in - mk_list f n + List.init n f let get_tactic_description = Z3native.tactic_get_descr let mk_tactic = Z3native.mk_tactic let and_then (ctx:context) (t1:tactic) (t2:tactic) (ts:tactic list) = - let f p c = (match p with - | None -> Some c - | Some(x) -> Some (Z3native.tactic_and_then ctx c x)) in - match (List.fold_left f None ts) with + let f p c = + match p with + | None -> Some c + | Some x -> Some (Z3native.tactic_and_then ctx c x) + in + match List.fold_left f None ts with | None -> Z3native.tactic_and_then ctx t1 t2 - | Some(x) -> let o = Z3native.tactic_and_then ctx t2 x in + | Some x -> + let o = Z3native.tactic_and_then ctx t2 x in Z3native.tactic_and_then ctx t1 o let or_else = Z3native.tactic_or_else @@ -1749,18 +1737,18 @@ end module Simplifier = struct type simplifier = Z3native.simplifier - let gc = Z3native.context_of_simplifier + let gc s = Z3native.context_of_simplifier s let get_help (x:simplifier) = Z3native.simplifier_get_help (gc x) x let get_param_descrs (x:simplifier) = Z3native.simplifier_get_param_descrs (gc x) x - let get_num_simplifiers = Z3native.get_num_simplifiers + let get_num_simplifiers ctx = Z3native.get_num_simplifiers ctx let get_simplifier_names (ctx:context) = let n = get_num_simplifiers ctx in let f i = Z3native.get_simplifier_name ctx i in - mk_list f n + List.init n f let get_simplifier_description = Z3native.simplifier_get_descr @@ -1783,7 +1771,7 @@ end module Statistics = struct type statistics = Z3native.stats - let gc = Z3native.context_of_stats + let gc s = Z3native.context_of_stats s module Entry = struct @@ -1827,16 +1815,19 @@ struct else Entry.create_sd k (Z3native.stats_get_double_value (gc x) x i) in - mk_list f n + List.init n f let get_keys (x:statistics) = let n = get_size x in let f i = Z3native.stats_get_key (gc x) x i in - mk_list f n + List.init n f let get (x:statistics) (key:string) = try Some(List.find (fun c -> Entry.get_key c = key) (get_entries x)) with | Not_found -> None + + let get_estimated_alloc_size = + Z3native.get_estimated_alloc_size end @@ -1844,7 +1835,7 @@ module Solver = struct type solver = Z3native.solver type status = UNSATISFIABLE | UNKNOWN | SATISFIABLE - let gc = Z3native.context_of_solver + let gc s = Z3native.context_of_solver s let string_of_status (s:status) = match s with | UNSATISFIABLE -> "unsatisfiable" @@ -1889,9 +1880,9 @@ struct | _ -> UNKNOWN let get_model x = - try + try let q = Z3native.solver_get_model (gc x) x in - if Z3native.is_null_model q then None else Some q + if Z3native.is_null_model q then None else Some q with | _ -> None let get_proof x = @@ -1916,13 +1907,16 @@ struct let add_simplifier = Z3native.solver_add_simplifier let translate x = Z3native.solver_translate (gc x) x let to_string x = Z3native.solver_to_string (gc x) x + + let interrupt (ctx:context) (s:solver) = + Z3native.solver_interrupt ctx s end module Fixedpoint = struct type fixedpoint = Z3native.fixedpoint - let gc = Z3native.context_of_fixedpoint + let gc f = Z3native.context_of_fixedpoint f let get_help x = Z3native.fixedpoint_get_help (gc x) x let set_parameters x = Z3native.fixedpoint_set_params (gc x) x @@ -2054,22 +2048,22 @@ struct formula let parse_smtlib2_string (ctx:context) (str:string) (sort_names:Symbol.symbol list) (sorts:Sort.sort list) (decl_names:Symbol.symbol list) (decls:func_decl list) = - let csn = List.length sort_names in let cs = List.length sorts in - let cdn = List.length decl_names in let cd = List.length decls in - if csn <> cs || cdn <> cd then + if List.compare_length_with sort_names cs <> 0 + || List.compare_length_with decl_names cd <> 0 + then raise (Error "Argument size mismatch") else Z3native.parse_smtlib2_string ctx str cs sort_names sorts cd decl_names decls let parse_smtlib2_file (ctx:context) (file_name:string) (sort_names:Symbol.symbol list) (sorts:Sort.sort list) (decl_names:Symbol.symbol list) (decls:func_decl list) = - let csn = List.length sort_names in let cs = List.length sorts in - let cdn = List.length decl_names in let cd = List.length decls in - if csn <> cs || cdn <> cd then + if List.compare_length_with sort_names cs <> 0 + || List.compare_length_with decl_names cd <> 0 + then raise (Error "Argument size mismatch") else Z3native.parse_smtlib2_file ctx file_name @@ -2077,6 +2071,130 @@ struct end +module RCF = +struct + type rcf_num = Z3native.rcf_num + + let del (ctx:context) (a:rcf_num) : unit = Z3native.rcf_del ctx a + let del_list (ctx:context) (ns:rcf_num list) = List.iter (fun a -> Z3native.rcf_del ctx a) ns + let mk_rational (ctx:context) (v:string) = Z3native.rcf_mk_rational ctx v + let mk_small_int (ctx:context) (v:int) = Z3native.rcf_mk_small_int ctx v + + let mk_pi (ctx:context) = Z3native.rcf_mk_pi ctx + let mk_e (ctx:context) = Z3native.rcf_mk_e ctx + let mk_infinitesimal (ctx:context) = Z3native.rcf_mk_infinitesimal ctx + + let mk_roots (ctx:context) (a:rcf_num list) = + let n, r = Z3native.rcf_mk_roots ctx (List.length a) a in + let _i, l = + (* keep only the first `n` elements of the list `r` *) + List.fold_left (fun (i, acc) x -> + if i = 0 then i, acc + else (i - 1, x :: acc) + ) (n, []) r + in + List.rev l + + let add (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_add ctx a b + let sub (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_sub ctx a b + let mul (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_mul ctx a b + let div (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_div ctx a b + + let neg (ctx:context) (a:rcf_num) = Z3native.rcf_neg ctx a + let inv (ctx:context) (a:rcf_num) = Z3native.rcf_neg ctx a + + let power (ctx:context) (a:rcf_num) (k:int) = Z3native.rcf_power ctx a k + + let lt (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_lt ctx a b + let gt (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_gt ctx a b + let le (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_le ctx a b + let ge (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_ge ctx a b + let eq (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_eq ctx a b + let neq (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_neq ctx a b + + let num_to_string (ctx:context) (a:rcf_num) (compact:bool) (html:bool) = Z3native.rcf_num_to_string ctx a compact html + let num_to_decimal_string (ctx:context) (a:rcf_num) (prec:int) = Z3native.rcf_num_to_decimal_string ctx a prec + let get_numerator_denominator (ctx:context) (a:rcf_num) = Z3native.rcf_get_numerator_denominator ctx a + + let is_rational (ctx:context) (a:rcf_num) = Z3native.rcf_is_rational ctx a + let is_algebraic (ctx:context) (a:rcf_num) = Z3native.rcf_is_algebraic ctx a + let is_infinitesimal (ctx:context) (a:rcf_num) = Z3native.rcf_is_infinitesimal ctx a + let is_transcendental (ctx:context) (a:rcf_num) = Z3native.rcf_is_transcendental ctx a + + let extension_index (ctx:context) (a:rcf_num) = Z3native.rcf_extension_index ctx a + let transcendental_name (ctx:context) (a:rcf_num) = Z3native.rcf_transcendental_name ctx a + let infinitesimal_name (ctx:context) (a:rcf_num) = Z3native.rcf_infinitesimal_name ctx a + + let num_coefficients (ctx:context) (a:rcf_num) = Z3native.rcf_num_coefficients ctx a + let get_coefficient (ctx:context) (a:rcf_num) (i:int) = Z3native.rcf_coefficient ctx a i + + let coefficients (ctx:context) (a:rcf_num) = + List.init (num_coefficients ctx a) (fun i -> Z3native.rcf_coefficient ctx a i) + + type interval = { + lower_is_inf : bool; + lower_is_open : bool; + lower : rcf_num; + upper_is_inf : bool; + upper_is_open : bool; + upper : rcf_num; + } + + let root_interval (ctx:context) (a:rcf_num) = + let ok, linf, lopen, l, uinf, uopen, u = Z3native.rcf_interval ctx a in + let i:interval = { + lower_is_inf = linf != 0; + lower_is_open = lopen != 0; + lower = l; + upper_is_inf = uinf != 0; + upper_is_open = uopen != 0; + upper = u } in + if ok != 0 then Some i else None + + let sign_condition_sign (ctx:context) (a:rcf_num) (i:int) = Z3native.rcf_sign_condition_sign ctx a i + + let sign_condition_coefficient (ctx:context) (a:rcf_num) (i:int) (j:int) = Z3native.rcf_sign_condition_coefficient ctx a i j + + let num_sign_condition_coefficients (ctx:context) (a:rcf_num) (i:int) = Z3native.rcf_num_sign_condition_coefficients ctx a i + + let sign_condition_coefficients (ctx:context) (a:rcf_num) (i:int) = + let n = Z3native.rcf_num_sign_condition_coefficients ctx a i in + List.init n (fun j -> Z3native.rcf_sign_condition_coefficient ctx a i j) + + let sign_conditions (ctx:context) (a:rcf_num) = + let n = Z3native.rcf_num_sign_conditions ctx a in + List.init n (fun i -> + (let nc = Z3native.rcf_num_sign_condition_coefficients ctx a i in + List.init nc (fun j -> Z3native.rcf_sign_condition_coefficient ctx a i j)), + Z3native.rcf_sign_condition_sign ctx a i) + + type root = { + obj : rcf_num; + polynomial : rcf_num list; + interval : interval option; + sign_conditions : (rcf_num list * int) list; + } + + let roots (ctx:context) (a:rcf_num list) = + let rs = mk_roots ctx a in + List.map + (fun r -> { + obj = r; + polynomial = coefficients ctx r; + interval = root_interval ctx r; + sign_conditions = sign_conditions ctx r}) + rs + + let del_root (ctx:context) (r:root) = + del ctx r.obj; + List.iter (fun n -> del ctx n) r.polynomial; + List.iter (fun (ns, _) -> del_list ctx ns) r.sign_conditions + + let del_roots (ctx:context) (rs:root list) = + List.iter (fun r -> del_root ctx r) rs +end + + let set_global_param = Z3native.global_param_set let get_global_param id = diff --git a/src/api/ml/z3.mli b/src/api/ml/z3.mli index 53e92b491e4..5320fc38ef8 100644 --- a/src/api/ml/z3.mli +++ b/src/api/ml/z3.mli @@ -48,6 +48,12 @@ type context *) val mk_context : (string * string) list -> context +(** Interrupt the execution of a Z3 procedure. + + This procedure can be used to interrupt: solvers, simplifiers and tactics. + Note: Tactic.interrupt is an alias for this. *) +val interrupt: context -> unit + (** Interaction logging for Z3 Interaction logs are used to record calls into the API into a text file. The text file can be replayed using z3. It has to be the same version of z3 @@ -1068,13 +1074,13 @@ sig if the corresponding sort reference is 0, then the value in sort_refs should be an index referring to one of the recursive datatypes that is declared. *) val mk_constructor_s : context -> string -> Symbol.symbol -> Symbol.symbol list -> Sort.sort option list -> int list -> Constructor.constructor - + (* Create a forward reference to a recursive datatype being declared. The forward reference can be used in a nested occurrence: the range of an array or as element sort of a sequence. The forward reference should only be used when used in an accessor for a recursive datatype that gets declared. *) val mk_sort_ref : context -> Symbol.symbol -> Sort.sort - + (* [mk_sort_ref_s ctx s] is [mk_sort_ref ctx (Symbol.mk_string ctx s)] *) val mk_sort_ref_s : context -> string -> Sort.sort @@ -1653,8 +1659,8 @@ sig - The \c ceiling of [t1/t2] if \c t2 is different from zero, and [t1*t2 < 0]. - If [t2] is zero, then the result is is not uniquely specified. - It can be set to any value that satisfies the constraints + If [t2] is zero, then the result is is not uniquely specified. + It can be set to any value that satisfies the constraints where signed division is used. The arguments must have the same bit-vector sort. *) val mk_sdiv : context -> Expr.expr -> Expr.expr -> Expr.expr @@ -1662,8 +1668,8 @@ sig (** Unsigned remainder. It is defined as [t1 - (t1 /u t2) * t2], where [/u] represents unsigned division. - If [t2] is zero, then the result is not uniquely specified. - It can be set to any value that satisfies the constraints + If [t2] is zero, then the result is not uniquely specified. + It can be set to any value that satisfies the constraints where unsigned remainder is used. The arguments must have the same bit-vector sort. *) val mk_urem : context -> Expr.expr -> Expr.expr -> Expr.expr @@ -1673,16 +1679,16 @@ sig It is defined as [t1 - (t1 /s t2) * t2], where [/s] represents signed division. The most significant bit (sign) of the result is equal to the most significant bit of \c t1. - If [t2] is zero, then the result is not uniquely specified. - It can be set to any value that satisfies the constraints + If [t2] is zero, then the result is not uniquely specified. + It can be set to any value that satisfies the constraints where signed remainder is used. The arguments must have the same bit-vector sort. *) val mk_srem : context -> Expr.expr -> Expr.expr -> Expr.expr (** Two's complement signed remainder (sign follows divisor). - If [t2] is zero, then the result is not uniquely specified. - It can be set to any value that satisfies the constraints + If [t2] is zero, then the result is not uniquely specified. + It can be set to any value that satisfies the constraints where two's complement signed remainder is used. The arguments must have the same bit-vector sort. *) val mk_smod : context -> Expr.expr -> Expr.expr -> Expr.expr @@ -1864,7 +1870,7 @@ sig end (** Sequences, Strings and Regular Expressions **) -module Seq : +module Seq : sig (** create a sequence sort *) val mk_seq_sort : context -> Sort.sort -> Sort.sort @@ -1872,9 +1878,9 @@ sig (** test if sort is a sequence sort *) val is_seq_sort : context -> Sort.sort -> bool - (** create regular expression sorts over sequences of the argument sort *) + (** create regular expression sorts over sequences of the argument sort *) val mk_re_sort : context -> Sort.sort -> Sort.sort - + (** test if sort is a regular expression sort *) val is_re_sort : context -> Sort.sort -> bool @@ -1885,7 +1891,7 @@ sig val mk_char_sort : context -> Sort.sort (** test if sort is a string sort (a sequence of 8-bit bit-vectors) *) - val is_string_sort : context -> Sort.sort -> bool + val is_string_sort : context -> Sort.sort -> bool (** test if sort is a char sort *) val is_char_sort : context -> Sort.sort -> bool @@ -1894,51 +1900,51 @@ sig val mk_string : context -> string -> Expr.expr (** test if expression is a string *) - val is_string : context -> Expr.expr -> bool + val is_string : context -> Expr.expr -> bool (** retrieve string from string Expr.expr *) - val get_string : context -> Expr.expr -> string + val get_string : context -> Expr.expr -> string (** the empty sequence over base sort *) - val mk_seq_empty : context -> Sort.sort -> Expr.expr + val mk_seq_empty : context -> Sort.sort -> Expr.expr (** a unit sequence *) - val mk_seq_unit : context -> Expr.expr -> Expr.expr + val mk_seq_unit : context -> Expr.expr -> Expr.expr (** sequence concatenation *) - val mk_seq_concat : context -> Expr.expr list -> Expr.expr + val mk_seq_concat : context -> Expr.expr list -> Expr.expr (** predicate if the first argument is a prefix of the second *) - val mk_seq_prefix : context -> Expr.expr -> Expr.expr -> Expr.expr + val mk_seq_prefix : context -> Expr.expr -> Expr.expr -> Expr.expr (** predicate if the first argument is a suffix of the second *) - val mk_seq_suffix : context -> Expr.expr -> Expr.expr -> Expr.expr + val mk_seq_suffix : context -> Expr.expr -> Expr.expr -> Expr.expr (** predicate if the first argument contains the second *) - val mk_seq_contains : context -> Expr.expr -> Expr.expr -> Expr.expr + val mk_seq_contains : context -> Expr.expr -> Expr.expr -> Expr.expr (** extract sub-sequence starting at index given by second argument and of length provided by third argument *) - val mk_seq_extract : context -> Expr.expr -> Expr.expr -> Expr.expr -> Expr.expr + val mk_seq_extract : context -> Expr.expr -> Expr.expr -> Expr.expr -> Expr.expr (** replace first occurrence of second argument by third *) - val mk_seq_replace : context -> Expr.expr -> Expr.expr -> Expr.expr -> Expr.expr + val mk_seq_replace : context -> Expr.expr -> Expr.expr -> Expr.expr -> Expr.expr (** a unit sequence at index provided by second argument *) - val mk_seq_at : context -> Expr.expr -> Expr.expr -> Expr.expr + val mk_seq_at : context -> Expr.expr -> Expr.expr -> Expr.expr (** length of a sequence *) - val mk_seq_length : context -> Expr.expr -> Expr.expr - - (** [mk_seq_nth ctx s index] retrieves from [s] the element at position [index]. + val mk_seq_length : context -> Expr.expr -> Expr.expr + + (** [mk_seq_nth ctx s index] retrieves from [s] the element at position [index]. The function is under-specified if the index is out of bounds. *) val mk_seq_nth : context -> Expr.expr -> Expr.expr -> Expr.expr (** index of the first occurrence of the second argument in the first *) - val mk_seq_index : context -> Expr.expr -> Expr.expr -> Expr.expr -> Expr.expr + val mk_seq_index : context -> Expr.expr -> Expr.expr -> Expr.expr -> Expr.expr (** [mk_seq_last_index ctx s substr] occurence of [substr] in the sequence [s] *) val mk_seq_last_index : context -> Expr.expr -> Expr.expr -> Expr.expr - + (** retrieve integer expression encoded in string *) val mk_str_to_int : context -> Expr.expr -> Expr.expr @@ -1950,7 +1956,7 @@ sig val mk_str_lt : context -> Expr.expr -> Expr.expr -> Expr.expr (** convert an integer expression to a string *) - val mk_int_to_str : context -> Expr.expr -> Expr.expr + val mk_int_to_str : context -> Expr.expr -> Expr.expr (** [mk_string_to_code ctx s] convert a unit length string [s] to integer code *) val mk_string_to_code : context -> Expr.expr -> Expr.expr @@ -1965,43 +1971,43 @@ sig val mk_sbv_to_str : context -> Expr.expr -> Expr.expr (** create regular expression that accepts the argument sequence *) - val mk_seq_to_re : context -> Expr.expr -> Expr.expr + val mk_seq_to_re : context -> Expr.expr -> Expr.expr (** regular expression membership predicate *) - val mk_seq_in_re : context -> Expr.expr -> Expr.expr -> Expr.expr + val mk_seq_in_re : context -> Expr.expr -> Expr.expr -> Expr.expr (** regular expression plus *) - val mk_re_plus : context -> Expr.expr -> Expr.expr + val mk_re_plus : context -> Expr.expr -> Expr.expr (** regular expression star *) - val mk_re_star : context -> Expr.expr -> Expr.expr + val mk_re_star : context -> Expr.expr -> Expr.expr (** optional regular expression *) - val mk_re_option : context -> Expr.expr -> Expr.expr + val mk_re_option : context -> Expr.expr -> Expr.expr (** union of regular expressions *) - val mk_re_union : context -> Expr.expr list -> Expr.expr + val mk_re_union : context -> Expr.expr list -> Expr.expr (** concatenation of regular expressions *) - val mk_re_concat : context -> Expr.expr list -> Expr.expr - + val mk_re_concat : context -> Expr.expr list -> Expr.expr + (** regular expression for the range between two characters *) - val mk_re_range : context -> Expr.expr -> Expr.expr -> Expr.expr + val mk_re_range : context -> Expr.expr -> Expr.expr -> Expr.expr (** bounded loop regular expression *) - val mk_re_loop : context -> Expr.expr -> int -> int -> Expr.expr - + val mk_re_loop : context -> Expr.expr -> int -> int -> Expr.expr + (** intersection of regular expressions *) val mk_re_intersect : context -> Expr.expr list -> Expr.expr (** the regular expression complement *) - val mk_re_complement : context -> Expr.expr -> Expr.expr + val mk_re_complement : context -> Expr.expr -> Expr.expr (** the regular expression that accepts no sequences *) - val mk_re_empty : context -> Sort.sort -> Expr.expr + val mk_re_empty : context -> Sort.sort -> Expr.expr (** the regular expression that accepts all sequences *) - val mk_re_full : context -> Sort.sort -> Expr.expr + val mk_re_full : context -> Sort.sort -> Expr.expr (** [mk_char ctx i] converts an integer to a character *) val mk_char : context -> int -> Expr.expr @@ -2339,7 +2345,7 @@ sig (** Retrieves the sign of a floating-point literal. *) val get_numeral_sign : context -> Expr.expr -> bool * int - (** Return the sign of a floating-point numeral as a bit-vector expression. + (** Return the sign of a floating-point numeral as a bit-vector expression. Remark: NaN's do not have a bit-vector sign, so they are invalid arguments. *) val get_numeral_sign_bv : context -> Expr.expr -> Expr.expr @@ -2349,11 +2355,11 @@ sig (** Return the exponent value of a floating-point numeral as a signed integer *) val get_numeral_exponent_int : context -> Expr.expr -> bool -> bool * int64 - (** Return the exponent of a floating-point numeral as a bit-vector expression. + (** Return the exponent of a floating-point numeral as a bit-vector expression. Remark: NaN's do not have a bit-vector exponent, so they are invalid arguments. *) val get_numeral_exponent_bv : context -> Expr.expr -> bool -> Expr.expr - (** Return the significand value of a floating-point numeral as a bit-vector expression. + (** Return the significand value of a floating-point numeral as a bit-vector expression. Remark: NaN's do not have a bit-vector significand, so they are invalid arguments. *) val get_numeral_significand_bv : context -> Expr.expr -> Expr.expr @@ -2386,7 +2392,7 @@ sig (** Indicates whether a floating-point numeral is negative. *) val is_numeral_negative : context -> Expr.expr -> bool - + (** Conversion of a floating-point term into a bit-vector term in IEEE 754-2008 format. *) val mk_to_ieee_bv : context -> Expr.expr -> Expr.expr @@ -3218,6 +3224,9 @@ sig (** The value of a particular statistical counter. *) val get : statistics -> string -> Entry.statistics_entry option + + (** The estimated allocated memory in bytes. *) + val get_estimated_alloc_size : unit -> int64 end (** Solvers *) @@ -3260,7 +3269,7 @@ sig (** Assert multiple constraints (cs) into the solver, and track them (in the unsat) core using the Boolean constants in ps. - + This API is an alternative to {!check} with assumptions for extracting unsat cores. Both APIs can be used in the same solver. The unsat core will contain a combination of the Boolean variables provided using {!assert_and_track} and the Boolean literals @@ -3269,10 +3278,10 @@ sig (** Assert a constraint (c) into the solver, and track it (in the unsat) core using the Boolean constant p. - - This API is an alternative to {!check} with assumptions for extracting unsat cores. - Both APIs can be used in the same solver. The unsat core will contain a combination - of the Boolean variables provided using {!assert_and_track} and the Boolean literals + + This API is an alternative to {!check} with assumptions for extracting unsat cores. + Both APIs can be used in the same solver. The unsat core will contain a combination + of the Boolean variables provided using {!assert_and_track} and the Boolean literals provided using {!check} with assumptions. *) val assert_and_track : solver -> Expr.expr -> Expr.expr -> unit @@ -3342,6 +3351,15 @@ sig (** A string representation of the solver. *) val to_string : solver -> string + + (** Solver local interrupt. + + Normally you should use Z3_interrupt to cancel solvers because only + one solver is enabled concurrently per context. + However, per GitHub issue #1006, there are use cases where + it is more convenient to cancel a specific solver. Solvers + that are not selected for interrupts are left alone.*) + val interrupt: context -> solver -> unit end (** Fixedpoint solving *) @@ -3466,7 +3484,7 @@ sig (** Add minimization objective. *) val minimize : optimize -> Expr.expr -> handle - (** Checks whether the assertions in the context are satisfiable and solves objectives. *) + (** Check consistency and produce optimal values. *) val check : optimize -> Solver.status (** Retrieve model from satisfiable context *) @@ -3496,23 +3514,23 @@ sig val get_statistics : optimize -> Statistics.statistics (** Parse an SMT-LIB2 file with assertions, soft constraints and optimization - objectives. Add the parsed constraints and objectives to the optimization + objectives. Add the parsed constraints and objectives to the optimization context. *) val from_file : optimize -> string -> unit - (** Parse an SMT-LIB2 string with assertions, soft constraints and optimization - objectives. Add the parsed constraints and objectives to the optimization + (** Parse an SMT-LIB2 string with assertions, soft constraints and optimization + objectives. Add the parsed constraints and objectives to the optimization context. *) val from_string : optimize -> string -> unit - - (** Return the set of asserted formulas on the optimization context. *) + + (** Return the set of asserted formulas on the optimization context. *) val get_assertions : optimize -> Expr.expr list - (** Return objectives on the optimization context. If the objective function - is a max-sat objective it is returned as a Pseudo-Boolean (minimization) - sum of the form (+ (if f1 w1 0) (if f2 w2 0) ...). If the objective - function is entered as a maximization objective, then return the - corresponding minimization objective. In this way the resulting + (** Return objectives on the optimization context. If the objective function + is a max-sat objective it is returned as a Pseudo-Boolean (minimization) + sum of the form (+ (if f1 w1 0) (if f2 w2 0) ...). If the objective + function is entered as a maximization objective, then return the + corresponding minimization objective. In this way the resulting objective function is always returned as a minimization objective. *) val get_objectives : optimize -> Expr.expr list end @@ -3535,6 +3553,151 @@ sig val parse_smtlib2_file : context -> string -> Symbol.symbol list -> Sort.sort list -> Symbol.symbol list -> FuncDecl.func_decl list -> AST.ASTVector.ast_vector end +(** Real closed field *) +module RCF : +sig + type rcf_num + + (** Delete a RCF numeral created using the RCF API. *) + val del : context -> rcf_num -> unit + + (** Delete RCF numerals created using the RCF API. *) + val del_list : context -> rcf_num list -> unit + + (** Return a RCF rational using the given string. *) + val mk_rational : context -> string -> rcf_num + + (** Return a RCF small integer. *) + val mk_small_int : context -> int -> rcf_num + + (** Return Pi *) + val mk_pi : context -> rcf_num + + (** Return e (Euler's constant) *) + val mk_e : context -> rcf_num + + (** Return a new infinitesimal that is smaller than all elements in the Z3 field. *) + val mk_infinitesimal : context -> rcf_num + + (** Extract the roots of a polynomial. Precondition: The input polynomial is not the zero polynomial. *) + val mk_roots : context -> rcf_num list -> rcf_num list + + (** Addition *) + val add : context -> rcf_num -> rcf_num -> rcf_num + + (** Subtraction *) + val sub : context -> rcf_num -> rcf_num -> rcf_num + + (** Multiplication *) + val mul : context -> rcf_num -> rcf_num -> rcf_num + + (** Division *) + val div : context -> rcf_num -> rcf_num -> rcf_num + + (** Negation *) + val neg : context -> rcf_num -> rcf_num + + (** Multiplicative Inverse *) + val inv : context -> rcf_num -> rcf_num + + (** Power *) + val power : context -> rcf_num -> int -> rcf_num + + (** less-than *) + val lt : context -> rcf_num -> rcf_num -> bool + + (** greater-than *) + val gt : context -> rcf_num -> rcf_num -> bool + + (** less-than or equal *) + val le : context -> rcf_num -> rcf_num -> bool + + (** greater-than or equal *) + val ge : context -> rcf_num -> rcf_num -> bool + + (** equality *) + val eq : context -> rcf_num -> rcf_num -> bool + + (** not equal *) + val neq : context -> rcf_num -> rcf_num -> bool + + (** Convert the RCF numeral into a string. *) + val num_to_string : context -> rcf_num -> bool -> bool -> string + + (** Convert the RCF numeral into a string in decimal notation. *) + val num_to_decimal_string : context -> rcf_num -> int -> string + + (** Extract the "numerator" and "denominator" of the given RCF numeral. + We have that \ccode{a = n/d}, moreover \c n and \c d are not represented using rational functions. *) + val get_numerator_denominator : context -> rcf_num -> (rcf_num * rcf_num) + + (** Return \c true if \c a represents a rational number. *) + val is_rational : context -> rcf_num -> bool + + (** Return \c true if \c a represents an algebraic number. *) + val is_algebraic : context -> rcf_num -> bool + + (** Return \c true if \c a represents an infinitesimal. *) + val is_infinitesimal : context -> rcf_num -> bool + + (** Return \c true if \c a represents a transcendental number. *) + val is_transcendental : context -> rcf_num -> bool + + (** Return the index of a field extension. *) + val extension_index : context -> rcf_num -> int + + (** Return the name of a transcendental. *) + val transcendental_name : context -> rcf_num -> Symbol.symbol + + (** Return the name of an infinitesimal. *) + val infinitesimal_name : context -> rcf_num -> Symbol.symbol + + (** Return the number of coefficients in an algebraic number. *) + val num_coefficients : context -> rcf_num -> int + + (** Extract a coefficient from an algebraic number. *) + val get_coefficient : context -> rcf_num -> int -> rcf_num + + (** Extract the coefficients from an algebraic number. *) + val coefficients : context -> rcf_num -> rcf_num list + + (** Extract the sign of a sign condition from an algebraic number. *) + val sign_condition_sign : context -> rcf_num -> int -> int + + (** Return the size of a sign condition polynomial. *) + val num_sign_condition_coefficients : context -> rcf_num -> int -> int + + (** Extract a sign condition polynomial coefficient from an algebraic number. *) + val sign_condition_coefficient : context -> rcf_num -> int -> int -> rcf_num + + (** Extract sign conditions from an algebraic number. *) + val sign_conditions : context -> rcf_num -> (rcf_num list * int) list + + (** Extract the interval from an algebraic number. *) + type interval = { + lower_is_inf : bool; + lower_is_open : bool; + lower : rcf_num; + upper_is_inf : bool; + upper_is_open : bool; + upper : rcf_num; + } + + val root_interval : context -> rcf_num -> interval option + + type root = { + obj : rcf_num; + polynomial : rcf_num list; + interval : interval option; + sign_conditions : (rcf_num list * int) list; + } + + val roots : context -> rcf_num list -> root list + + val del_root : context -> root -> unit + + val del_roots : context -> root list -> unit +end (** Set a global (or module) parameter, which is shared by all Z3 contexts. diff --git a/src/api/ml/z3native.ml.pre b/src/api/ml/z3native.ml.pre index 1d75d5d1efa..fe4e8a194d6 100644 --- a/src/api/ml/z3native.ml.pre +++ b/src/api/ml/z3native.ml.pre @@ -4,36 +4,36 @@ open Z3enums (**/**) type ptr -and symbol = ptr -and config = ptr -and context = ptr -and ast = ptr -and app = ast -and sort = ast -and func_decl = ast -and pattern = ast -and model = ptr -and literals = ptr -and constructor = ptr -and constructor_list = ptr -and solver = ptr -and solver_callback = ptr -and goal = ptr -and tactic = ptr -and simplifier = ptr -and params = ptr -and parser_context = ptr -and probe = ptr -and stats = ptr -and ast_vector = ptr -and ast_map = ptr -and apply_result = ptr -and func_interp = ptr -and func_entry = ptr -and fixedpoint = ptr -and optimize = ptr -and param_descrs = ptr -and rcf_num = ptr +type symbol = ptr +type config = ptr +type context = ptr +type ast = ptr +type app = ast +type sort = ast +type func_decl = ast +type pattern = ast +type model = ptr +type literals = ptr +type constructor = ptr +type constructor_list = ptr +type solver = ptr +type solver_callback = ptr +type goal = ptr +type tactic = ptr +type simplifier = ptr +type params = ptr +type parser_context = ptr +type probe = ptr +type stats = ptr +type ast_vector = ptr +type ast_map = ptr +type apply_result = ptr +type func_interp = ptr +type func_entry = ptr +type fixedpoint = ptr +type optimize = ptr +type param_descrs = ptr +type rcf_num = ptr external set_internal_error_handler : ptr -> unit = "n_set_internal_error_handler" diff --git a/src/api/ml/z3native_stubs.c.pre b/src/api/ml/z3native_stubs.c.pre index 038b80725ee..c8afe90b909 100644 --- a/src/api/ml/z3native_stubs.c.pre +++ b/src/api/ml/z3native_stubs.c.pre @@ -2,6 +2,12 @@ #include #include +#ifndef __STDC_NO_ATOMICS__ +#include +#else +#define _Atomic(T) T +#endif + #ifdef __cplusplus extern "C" { #endif @@ -118,7 +124,7 @@ int compare_pointers(void* pt1, void* pt2) { blocks that get copied. */ typedef struct { Z3_context ctx; - unsigned long obj_count; + _Atomic(unsigned long) obj_count; } Z3_context_plus_data; /* A context is wrapped to an OCaml value by storing a pointer diff --git a/src/api/python/CMakeLists.txt b/src/api/python/CMakeLists.txt index f5e449ea84f..5da66dfe4b2 100644 --- a/src/api/python/CMakeLists.txt +++ b/src/api/python/CMakeLists.txt @@ -33,7 +33,7 @@ endforeach() # Generate z3core.py add_custom_command(OUTPUT "${z3py_bindings_build_dest}/z3/z3core.py" - COMMAND "${PYTHON_EXECUTABLE}" + COMMAND "${Python3_EXECUTABLE}" "${PROJECT_SOURCE_DIR}/scripts/update_api.py" ${Z3_FULL_PATH_API_HEADER_FILES_TO_SCAN} "--z3py-output-dir" @@ -49,7 +49,7 @@ list(APPEND build_z3_python_bindings_target_depends "${z3py_bindings_build_dest} # Generate z3consts.py add_custom_command(OUTPUT "${z3py_bindings_build_dest}/z3/z3consts.py" - COMMAND "${PYTHON_EXECUTABLE}" + COMMAND "${Python3_EXECUTABLE}" "${PROJECT_SOURCE_DIR}/scripts/mk_consts_files.py" ${Z3_FULL_PATH_API_HEADER_FILES_TO_SCAN} "--z3py-output-dir" @@ -96,7 +96,7 @@ if (Z3_INSTALL_PYTHON_BINDINGS) if (NOT DEFINED CMAKE_INSTALL_PYTHON_PKG_DIR) message(STATUS "CMAKE_INSTALL_PYTHON_PKG_DIR not set. Trying to guess") execute_process( - COMMAND "${PYTHON_EXECUTABLE}" "-c" + COMMAND "${Python3_EXECUTABLE}" "-c" "import sysconfig; print(sysconfig.get_path('purelib'))" RESULT_VARIABLE exit_code OUTPUT_VARIABLE CMAKE_INSTALL_PYTHON_PKG_DIR diff --git a/src/api/python/setup.py b/src/api/python/setup.py index 81c472232a4..5faf5aad17a 100644 --- a/src/api/python/setup.py +++ b/src/api/python/setup.py @@ -18,7 +18,7 @@ build_env = dict(os.environ) build_env['PYTHON'] = sys.executable -build_env['CXXFLAGS'] = build_env.get('CXXFLAGS', '') + " -std=c++11" +build_env['CXXFLAGS'] = build_env.get('CXXFLAGS', '') + " -std=c++17" # determine where we're building and where sources are ROOT_DIR = os.path.abspath(os.path.dirname(__file__)) @@ -292,11 +292,20 @@ def run(self): distos = RELEASE_METADATA[2] if distos in ('debian', 'ubuntu'): raise Exception( - "Linux binary distributions must be built on centos to conform to PEP 513 or alpine if targetting musl" + "Linux binary distributions must be built on centos to conform to PEP 513 or alpine if targeting musl" ) elif distos == 'glibc': if arch == 'x64': plat_name = 'manylinux2014_x86_64' + elif arch == 'arm64' or arch == 'aarch64': + # context on why are we match on arm64 + # but use aarch64 on the plat_name is + # due to a workaround current python + # legacy build doesn't support aarch64 + # so using the currently supported arm64 + # build and simply rename it to aarch64 + # see full context on #7148 + plat_name = 'manylinux2014_aarch64' else: plat_name = 'manylinux2014_i686' elif distos == 'linux' and os_id == 'alpine': @@ -313,6 +322,8 @@ def run(self): osver = RELEASE_METADATA[3] if osver.count('.') > 1: osver = '.'.join(osver.split('.')[:2]) + if osver.startswith("11"): + osver = "11_0" if arch == 'x64': plat_name ='macosx_%s_x86_64' % osver.replace('.', '_') elif arch == 'arm64': @@ -339,6 +350,7 @@ def run(self): license='MIT License', keywords=['z3', 'smt', 'sat', 'prover', 'theorem'], packages=['z3'], + install_requires = ["importlib-resources; python_version < '3.9'"], include_package_data=True, package_data={ 'z3': [os.path.join('lib', '*'), os.path.join('include', '*.h'), os.path.join('include', 'c++', '*.h')] diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 6b79dd1fe87..16db39afd60 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -683,6 +683,8 @@ def _to_sort_ref(s, ctx): return SeqSortRef(s, ctx) elif k == Z3_CHAR_SORT: return CharSortRef(s, ctx) + elif k == Z3_TYPE_VAR: + return TypeVarRef(s, ctx) return SortRef(s, ctx) @@ -708,6 +710,26 @@ def DeclareSort(name, ctx=None): ctx = _get_ctx(ctx) return SortRef(Z3_mk_uninterpreted_sort(ctx.ref(), to_symbol(name, ctx)), ctx) +class TypeVarRef(SortRef): + """Type variable reference""" + + def subsort(self, other): + return True + + def cast(self, val): + return val + + +def DeclareTypeVar(name, ctx=None): + """Create a new type variable named `name`. + + If `ctx=None`, then the new sort is declared in the global Z3Py context. + + """ + ctx = _get_ctx(ctx) + return TypeVarRef(Z3_mk_type_variable(ctx.ref(), to_symbol(name, ctx)), ctx) + + ######################################### # # Function Declarations @@ -1549,6 +1571,14 @@ class BoolRef(ExprRef): def sort(self): return BoolSortRef(Z3_get_sort(self.ctx_ref(), self.as_ast()), self.ctx) + def __add__(self, other): + if isinstance(other, BoolRef): + other = If(other, 1, 0) + return If(self, 1, 0) + other + + def __radd__(self, other): + return self + other + def __rmul__(self, other): return self * other @@ -1562,6 +1592,20 @@ def __mul__(self, other): if isinstance(other, BoolRef): other = If(other, 1, 0) return If(self, other, 0) + + def __and__(self, other): + return And(self, other) + + def __or__(self, other): + return Or(self, other) + + def __xor__(self, other): + return Xor(self, other) + + def __invert__(self): + return Not(self) + + def is_bool(a): @@ -2059,6 +2103,16 @@ def weight(self): """ return int(Z3_get_quantifier_weight(self.ctx_ref(), self.ast)) + def skolem_id(self): + """Return the skolem id of `self`. + """ + return _symbol2py(self.ctx, Z3_get_quantifier_skolem_id(self.ctx_ref(), self.ast)) + + def qid(self): + """Return the quantifier id of `self`. + """ + return _symbol2py(self.ctx, Z3_get_quantifier_id(self.ctx_ref(), self.ast)) + def num_patterns(self): """Return the number of patterns (i.e., quantifier instantiation hints) in `self`. @@ -5385,7 +5439,7 @@ def EnumSort(name, values, ctx=None): """ if z3_debug(): _z3_assert(isinstance(name, str), "Name must be a string") - _z3_assert(all([isinstance(v, str) for v in values]), "Eumeration sort values must be strings") + _z3_assert(all([isinstance(v, str) for v in values]), "Enumeration sort values must be strings") _z3_assert(len(values) > 0, "At least one value expected") ctx = _get_ctx(ctx) num = len(values) @@ -6909,6 +6963,13 @@ def __del__(self): if self.solver is not None and self.ctx.ref() is not None and Z3_solver_dec_ref is not None: Z3_solver_dec_ref(self.ctx.ref(), self.solver) + def __enter__(self): + self.push() + return self + + def __exit__(self, *exc_info): + self.pop() + def set(self, *args, **keys): """Set a configuration option. The method `help()` return a string containing all available options. @@ -7136,6 +7197,13 @@ def import_model_converter(self, other): """Import model converter from other into the current solver""" Z3_solver_import_model_converter(self.ctx.ref(), other.solver, self.solver) + def interrupt(self): + """Interrupt the execution of the solver object. + Remarks: This ensures that the interrupt applies only + to the given solver object and it applies only if it is running. + """ + Z3_solver_interrupt(self.ctx.ref(), self.solver) + def unsat_core(self): """Return a subset (as an AST vector) of the assumptions provided to the last check(). @@ -7986,7 +8054,7 @@ def pop(self): Z3_optimize_pop(self.ctx.ref(), self.optimize) def check(self, *assumptions): - """Check satisfiability while optimizing objective functions.""" + """Check consistency and produce optimal values.""" assumptions = _get_args(assumptions) num = len(assumptions) _assumptions = (Ast * num)() @@ -8916,7 +8984,7 @@ def substitute_funs(t, *m): m = m1 if z3_debug(): _z3_assert(is_expr(t), "Z3 expression expected") - _z3_assert(all([isinstance(p, tuple) and is_func_decl(p[0]) and is_expr(p[1]) for p in m]), "Z3 invalid substitution, funcion pairs expected.") + _z3_assert(all([isinstance(p, tuple) and is_func_decl(p[0]) and is_expr(p[1]) for p in m]), "Z3 invalid substitution, function pairs expected.") num = len(m) _from = (FuncDecl * num)() _to = (Ast * num)() @@ -9001,7 +9069,7 @@ def AtMost(*args): def AtLeast(*args): - """Create an at-most Pseudo-Boolean k constraint. + """Create an at-least Pseudo-Boolean k constraint. >>> a, b, c = Bools('a b c') >>> f = AtLeast(a, b, c, 2) @@ -10901,10 +10969,10 @@ def CharVal(ch, ctx=None): raise Z3Exception("character value should be an ordinal") return _to_expr_ref(Z3_mk_char(ctx.ref(), ch), ctx) -def CharFromBv(ch, ctx=None): - if not is_expr(ch): - raise Z3Expression("Bit-vector expression needed") - return _to_expr_ref(Z3_mk_char_from_bv(ch.ctx_ref(), ch.as_ast()), ch.ctx) +def CharFromBv(bv): + if not is_expr(bv): + raise Z3Exception("Bit-vector expression needed") + return _to_expr_ref(Z3_mk_char_from_bv(bv.ctx_ref(), bv.as_ast()), bv.ctx) def CharToBv(ch, ctx=None): ch = _coerce_char(ch, ctx) @@ -11279,6 +11347,8 @@ def Plus(re): >>> print(simplify(InRe("", re))) False """ + if z3_debug(): + _z3_assert(is_expr(re), "expression expected") return ReRef(Z3_mk_re_plus(re.ctx_ref(), re.as_ast()), re.ctx) @@ -11292,6 +11362,8 @@ def Option(re): >>> print(simplify(InRe("aa", re))) False """ + if z3_debug(): + _z3_assert(is_expr(re), "expression expected") return ReRef(Z3_mk_re_option(re.ctx_ref(), re.as_ast()), re.ctx) @@ -11310,6 +11382,8 @@ def Star(re): >>> print(simplify(InRe("", re))) True """ + if z3_debug(): + _z3_assert(is_expr(re), "expression expected") return ReRef(Z3_mk_re_star(re.ctx_ref(), re.as_ast()), re.ctx) @@ -11323,6 +11397,8 @@ def Loop(re, lo, hi=0): >>> print(simplify(InRe("", re))) False """ + if z3_debug(): + _z3_assert(is_expr(re), "expression expected") return ReRef(Z3_mk_re_loop(re.ctx_ref(), re.as_ast(), lo, hi), re.ctx) @@ -11336,11 +11412,17 @@ def Range(lo, hi, ctx=None): """ lo = _coerce_seq(lo, ctx) hi = _coerce_seq(hi, ctx) + if z3_debug(): + _z3_assert(is_expr(lo), "expression expected") + _z3_assert(is_expr(hi), "expression expected") return ReRef(Z3_mk_re_range(lo.ctx_ref(), lo.ast, hi.ast), lo.ctx) def Diff(a, b, ctx=None): """Create the difference regular expression """ + if z3_debug(): + _z3_assert(is_expr(a), "expression expected") + _z3_assert(is_expr(b), "expression expected") return ReRef(Z3_mk_re_diff(a.ctx_ref(), a.ast, b.ast), a.ctx) def AllChar(regex_sort, ctx=None): @@ -11395,11 +11477,12 @@ def to_AstVectorObj(ptr,): # for UserPropagator we use a global dictionary, which isn't great code. _my_hacky_class = None -def on_clause_eh(ctx, p, clause): +def on_clause_eh(ctx, p, n, dep, clause): onc = _my_hacky_class p = _to_expr_ref(to_Ast(p), onc.ctx) clause = AstVector(to_AstVectorObj(clause), onc.ctx) - onc.on_clause(p, clause) + deps = [dep[i] for i in range(n)] + onc.on_clause(p, deps, clause) _on_clause_eh = Z3_on_clause_eh(on_clause_eh) @@ -11487,52 +11570,54 @@ def user_prop_fresh(ctx, _new_ctx): def user_prop_fixed(ctx, cb, id, value): prop = _prop_closures.get(ctx) - prop.cb = cb + old_cb = prop.cb + prop.cb = cb id = _to_expr_ref(to_Ast(id), prop.ctx()) value = _to_expr_ref(to_Ast(value), prop.ctx()) prop.fixed(id, value) - prop.cb = None + prop.cb = old_cb def user_prop_created(ctx, cb, id): prop = _prop_closures.get(ctx) + old_cb = prop.cb prop.cb = cb id = _to_expr_ref(to_Ast(id), prop.ctx()) prop.created(id) - prop.cb = None + prop.cb = old_cb + def user_prop_final(ctx, cb): prop = _prop_closures.get(ctx) + old_cb = prop.cb prop.cb = cb prop.final() - prop.cb = None + prop.cb = old_cb def user_prop_eq(ctx, cb, x, y): prop = _prop_closures.get(ctx) + old_cb = prop.cb prop.cb = cb x = _to_expr_ref(to_Ast(x), prop.ctx()) y = _to_expr_ref(to_Ast(y), prop.ctx()) prop.eq(x, y) - prop.cb = None + prop.cb = old_cb def user_prop_diseq(ctx, cb, x, y): prop = _prop_closures.get(ctx) + old_cb = prop.cb prop.cb = cb x = _to_expr_ref(to_Ast(x), prop.ctx()) y = _to_expr_ref(to_Ast(y), prop.ctx()) prop.diseq(x, y) - prop.cb = None + prop.cb = old_cb -# TODO The decision callback is not fully implemented. -# It needs to handle the ast*, unsigned* idx, and Z3_lbool* -def user_prop_decide(ctx, cb, t_ref, idx_ref, phase_ref): +def user_prop_decide(ctx, cb, t, idx, phase): prop = _prop_closures.get(ctx) + old_cb = prop.cb prop.cb = cb t = _to_expr_ref(to_Ast(t_ref), prop.ctx()) - t, idx, phase = prop.decide(t, idx, phase) - t_ref = t - idx_ref = idx - phase_ref = phase - prop.cb = None + prop.decide(t, idx, phase) + prop.cb = old_cb _user_prop_push = Z3_push_eh(user_prop_push) @@ -11573,7 +11658,7 @@ class UserPropagateBase: # # Either solver is set or ctx is set. - # Propagators that are created throuh callbacks + # Propagators that are created through callbacks # to "fresh" inherit the context of that is supplied # as argument to the callback. # This context should not be deleted. It is owned by the solver. @@ -11678,7 +11763,7 @@ def add(self, e): # split on. A phase of true = 1/false = -1/undef = 0 = let solver decide is the last argument. # def next_split(self, t, idx, phase): - Z3_solver_next_split(self.ctx_ref(), ctypes.c_void_p(self.cb), t.ast, idx, phase) + return Z3_solver_next_split(self.ctx_ref(), ctypes.c_void_p(self.cb), t.ast, idx, phase) # # Propagation can only be invoked as during a fixed or final callback. @@ -11688,7 +11773,7 @@ def propagate(self, e, ids, eqs=[]): num_eqs = len(eqs) _lhs, _num_lhs = _to_ast_array([x for x, y in eqs]) _rhs, _num_rhs = _to_ast_array([y for x, y in eqs]) - Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p( + return Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p( self.cb), num_fixed, _ids, num_eqs, _lhs, _rhs, e.ast) def conflict(self, deps = [], eqs = []): diff --git a/src/api/python/z3/z3printer.py b/src/api/python/z3/z3printer.py index 228f212d97b..2da5f89dace 100644 --- a/src/api/python/z3/z3printer.py +++ b/src/api/python/z3/z3printer.py @@ -99,6 +99,7 @@ def _z3_assert(cond, msg): Z3_OP_ARRAY_EXT: "Ext", Z3_OP_PB_AT_MOST: "AtMost", + Z3_OP_PB_AT_LEAST: "AtLeast", Z3_OP_PB_LE: "PbLe", Z3_OP_PB_GE: "PbGe", Z3_OP_PB_EQ: "PbEq", @@ -252,11 +253,11 @@ def _is_html_left_assoc(k): def _is_add(k): - return k == Z3_OP_ADD or k == Z3_OP_BADD + return k == Z3_OP_ADD or k == Z3_OP_BADD or k == Z3_OP_FPA_ADD def _is_sub(k): - return k == Z3_OP_SUB or k == Z3_OP_BSUB + return k == Z3_OP_SUB or k == Z3_OP_BSUB or k == Z3_OP_FPA_SUB if sys.version_info.major < 3: @@ -890,9 +891,21 @@ def pp_fp(self, a, d, xs): if self.is_infix(k) and n >= 3: rm = a.arg(0) if z3.is_fprm_value(rm) and z3.get_default_rounding_mode(a.ctx).eq(rm): - arg1 = to_format(self.pp_expr(a.arg(1), d + 1, xs)) - arg2 = to_format(self.pp_expr(a.arg(2), d + 1, xs)) + p = self.get_precedence(k) r = [] + x = a.arg(1) + y = a.arg(2) + arg1 = to_format(self.pp_expr(x, d + 1, xs)) + arg2 = to_format(self.pp_expr(y, d + 1, xs)) + if z3.is_app(x): + child_k = x.decl().kind() + if child_k != k and self.is_infix(child_k) and self.get_precedence(child_k) > p: + arg1 = self.add_paren(arg1) + if z3.is_app(y): + child_k = y.decl().kind() + if child_k != k and self.is_infix(child_k) and self.get_precedence(child_k) > p: + arg2 = self.add_paren(arg2) + r.append(arg1) r.append(to_format(" ")) r.append(to_format(op)) @@ -1099,6 +1112,10 @@ def pp_atmost(self, a, d, f, xs): k = Z3_get_decl_int_parameter(a.ctx_ref(), a.decl().ast, 0) return seq1(self.pp_name(a), [seq3([self.pp_expr(ch, d + 1, xs) for ch in a.children()]), to_format(k)]) + def pp_atleast(self, a, d, f, xs): + k = Z3_get_decl_int_parameter(a.ctx_ref(), a.decl().ast, 0) + return seq1(self.pp_name(a), [seq3([self.pp_expr(ch, d + 1, xs) for ch in a.children()]), to_format(k)]) + def pp_pbcmp(self, a, d, f, xs): chs = a.children() rchs = range(len(chs)) @@ -1151,6 +1168,8 @@ def pp_app(self, a, d, xs): return self.pp_K(a, d, xs) elif k == Z3_OP_PB_AT_MOST: return self.pp_atmost(a, d, f, xs) + elif k == Z3_OP_PB_AT_LEAST: + return self.pp_atleast(a, d, f, xs) elif k == Z3_OP_PB_LE: return self.pp_pbcmp(a, d, f, xs) elif k == Z3_OP_PB_GE: diff --git a/src/api/python/z3/z3util.py b/src/api/python/z3/z3util.py index 071e2b60e05..b60038f2f67 100644 --- a/src/api/python/z3/z3util.py +++ b/src/api/python/z3/z3util.py @@ -275,7 +275,7 @@ def _f(): def get_models(f, k): """ - Returns the first k models satisfiying f. + Returns the first k models satisfying f. If f is not satisfiable, returns False. If f cannot be solved, returns None If f is satisfiable, returns the first k models @@ -485,7 +485,7 @@ def model_str(m, as_str=True): x = 10, y = 3 EXAMPLES: - see doctest exampels from function prove() + see doctest examples from function prove() """ if z3_debug(): diff --git a/src/api/z3.h b/src/api/z3.h index 2ad00d835ee..178f8f6185f 100644 --- a/src/api/z3.h +++ b/src/api/z3.h @@ -20,7 +20,6 @@ Module Name: #pragma once -#include #include #include #include "z3_macros.h" diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 0582ffa3700..cbf9803dbbf 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -5,7 +5,6 @@ #pragma once DEFINE_TYPE(Z3_symbol); -DEFINE_TYPE(Z3_literals); DEFINE_TYPE(Z3_config); DEFINE_TYPE(Z3_context); DEFINE_TYPE(Z3_sort); @@ -151,6 +150,7 @@ typedef enum Z3_SEQ_SORT, Z3_RE_SORT, Z3_CHAR_SORT, + Z3_TYPE_VAR, Z3_UNKNOWN_SORT = 1000 } Z3_sort_kind; @@ -1363,7 +1363,7 @@ typedef enum { - Z3_NO_PARSER: Parser output is not available, that is, user didn't invoke #Z3_parse_smtlib2_string or #Z3_parse_smtlib2_file. - Z3_INVALID_PATTERN: Invalid pattern was used to build a quantifier. - Z3_MEMOUT_FAIL: A memory allocation failure was encountered. - - Z3_FILE_ACCESS_ERRROR: A file could not be accessed. + - Z3_FILE_ACCESS_ERROR: A file could not be accessed. - Z3_INVALID_USAGE: API call is invalid in the current state. - Z3_INTERNAL_FATAL: An error internal to Z3 occurred. - Z3_DEC_REF_ERROR: Trying to decrement the reference counter of an AST that was deleted or the reference counter was not initialized with #Z3_inc_ref. @@ -1397,7 +1397,6 @@ typedef enum def_Type('FUNC_DECL', 'Z3_func_decl', 'FuncDecl') def_Type('PATTERN', 'Z3_pattern', 'Pattern') def_Type('MODEL', 'Z3_model', 'ModelObj') - def_Type('LITERALS', 'Z3_literals', 'Literals') def_Type('CONSTRUCTOR', 'Z3_constructor', 'Constructor') def_Type('CONSTRUCTOR_LIST', 'Z3_constructor_list', 'ConstructorList') def_Type('SOLVER', 'Z3_solver', 'SolverObj') @@ -1435,8 +1434,8 @@ Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, Z3_as Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast s, Z3_ast t)); Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb)); Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t)); -Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast* t, unsigned* idx, Z3_lbool* phase)); -Z3_DECLARE_CLOSURE(Z3_on_clause_eh, void, (void* ctx, Z3_ast proof_hint, Z3_ast_vector literals)); +Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t, unsigned idx, bool phase)); +Z3_DECLARE_CLOSURE(Z3_on_clause_eh, void, (void* ctx, Z3_ast proof_hint, unsigned n, unsigned const* deps, Z3_ast_vector literals)); /** @@ -1883,6 +1882,17 @@ extern "C" { */ Z3_sort Z3_API Z3_mk_uninterpreted_sort(Z3_context c, Z3_symbol s); + /** + \brief Create a type variable. + + Functions using type variables can be applied to instantiations that match the signature + of the function. Assertions using type variables correspond to assertions over all possible + instantiations. + + def_API('Z3_mk_type_variable', SORT, (_in(CONTEXT), _in(SYMBOL))) + */ + Z3_sort Z3_API Z3_mk_type_variable(Z3_context c, Z3_symbol s); + /** \brief Create the Boolean type. @@ -2072,6 +2082,16 @@ extern "C" { unsigned sort_refs[] ); + /** + \brief Retrieve the number of fields of a constructor + + \param c logical context. + \param constr constructor. + + def_API('Z3_constructor_num_fields', UINT, (_in(CONTEXT), _in(CONSTRUCTOR))) + */ + unsigned Z3_API Z3_constructor_num_fields(Z3_context c, Z3_constructor constr); + /** \brief Reclaim memory allocated to constructor. @@ -2172,7 +2192,7 @@ extern "C" { \brief Query constructor for declared functions. \param c logical context. - \param constr constructor container. The container must have been passed in to a #Z3_mk_datatype call. + \param constr constructor container. The container must have been passed into a #Z3_mk_datatype call. \param num_fields number of accessor fields in the constructor. \param constructor constructor function declaration, allocated by user. \param tester constructor test function declaration, allocated by user. @@ -2317,7 +2337,7 @@ extern "C" { \param args constants that are used as arguments to the recursive function in the definition. \param body body of the recursive function - After declaring a recursive function or a collection of mutually recursive functions, use + After declaring a recursive function or a collection of mutually recursive functions, use this function to provide the definition for the recursive function. \sa Z3_mk_rec_func_decl @@ -3614,7 +3634,7 @@ extern "C" { /** \brief Retrieve the string constant stored in \c s. - Characters outside the basic printiable ASCII range are escaped. + Characters outside the basic printable ASCII range are escaped. \pre Z3_is_string(c, s) @@ -4897,7 +4917,7 @@ extern "C" { /** \brief Return a hash code for the given AST. The hash code is structural but two different AST objects can map to the same hash. - The result of \c Z3_get_ast_id returns an indentifier that is unique over the + The result of \c Z3_get_ast_id returns an identifier that is unique over the set of live AST objects. def_API('Z3_get_ast_hash', UINT, (_in(CONTEXT), _in(AST))) @@ -5186,6 +5206,24 @@ extern "C" { */ unsigned Z3_API Z3_get_quantifier_weight(Z3_context c, Z3_ast a); + /** + \brief Obtain skolem id of quantifier. + + \pre Z3_get_ast_kind(a) == Z3_QUANTIFIER_AST + + def_API('Z3_get_quantifier_skolem_id', SYMBOL, (_in(CONTEXT), _in(AST))) + */ + Z3_symbol Z3_API Z3_get_quantifier_skolem_id(Z3_context c, Z3_ast a); + + /** + \brief Obtain id of quantifier. + + \pre Z3_get_ast_kind(a) == Z3_QUANTIFIER_AST + + def_API('Z3_get_quantifier_id', SYMBOL, (_in(CONTEXT), _in(AST))) + */ + Z3_symbol Z3_API Z3_get_quantifier_id(Z3_context c, Z3_ast a); + /** \brief Return number of patterns used in quantifier. @@ -5346,7 +5384,7 @@ extern "C" { Z3_ast const to[]); /** - \brief Substitute funcions in \c from with new expressions in \c to. + \brief Substitute functions in \c from with new expressions in \c to. The expressions in \c to can have free variables. The free variable in \c to at index 0 refers to the first argument of \c from, the free variable at index 1 corresponds to the second argument. @@ -5584,14 +5622,14 @@ extern "C" { void Z3_API Z3_add_const_interp(Z3_context c, Z3_model m, Z3_func_decl f, Z3_ast a); /** - \brief Increment the reference counter of the given Z3_func_interp object. + \brief Increment the reference counter of the given \c Z3_func_interp object. def_API('Z3_func_interp_inc_ref', VOID, (_in(CONTEXT), _in(FUNC_INTERP))) */ void Z3_API Z3_func_interp_inc_ref(Z3_context c, Z3_func_interp f); /** - \brief Decrement the reference counter of the given Z3_func_interp object. + \brief Decrement the reference counter of the given \c Z3_func_interp object. def_API('Z3_func_interp_dec_ref', VOID, (_in(CONTEXT), _in(FUNC_INTERP))) */ @@ -7026,13 +7064,13 @@ extern "C" { Z3_on_clause_eh on_clause_eh); /** - \brief register a user-properator with the solver. + \brief register a user-propagator with the solver. \param c - context. \param s - solver object. \param user_context - a context used to maintain state for callbacks. \param push_eh - a callback invoked when scopes are pushed - \param pop_eh - a callback invoked when scopes are poped + \param pop_eh - a callback invoked when scopes are popped \param fresh_eh - a solver may spawn new solvers internally. This callback is used to produce a fresh user_context to be associated with fresh solvers. def_API('Z3_solver_propagate_init', VOID, (_in(CONTEXT), _in(SOLVER), _in(VOID_PTR), _fnptr(Z3_push_eh), _fnptr(Z3_pop_eh), _fnptr(Z3_fresh_eh))) @@ -7098,20 +7136,21 @@ extern "C" { /** \brief register a callback when the solver decides to split on a registered expression. - The callback may set the passed expression to another registered expression which will be selected instead. - In case the expression is a bitvector the bit to split on is determined by the bit argument and the - truth-value to try first is given by is_pos. In case the truth value is undefined the solver will decide. + The callback may change the arguments by providing other values by calling \ref Z3_solver_next_split def_API('Z3_solver_propagate_decide', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_decide_eh))) */ void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh); /** - Sets the next expression to split on + Sets the next (registered) expression to split on. + The function returns false and ignores the given expression in case the expression is already assigned internally + (due to relevancy propagation, this assignments might not have been reported yet by the fixed callback). + In case the function is called in the decide callback, it overrides the currently selected variable and phase. - def_API('Z3_solver_next_split', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST), _in(UINT), _in(LBOOL))) + def_API('Z3_solver_next_split', BOOL, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST), _in(UINT), _in(LBOOL))) */ - void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase); + bool Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase); /** Create uninterpreted function declaration for the user propagator. @@ -7145,15 +7184,29 @@ extern "C" { void Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback cb, Z3_ast e); /** - \brief propagate a consequence based on fixed values. - This is a callback a client may invoke during the fixed_eh callback. - The callback adds a propagation consequence based on the fixed values of the - \c ids. + \brief propagate a consequence based on fixed values and equalities. + A client may invoke it during the \c propagate_fixed, \c propagate_eq, \c propagate_diseq, and \c propagate_final callbacks. + The callback adds a propagation consequence based on the fixed values passed \c ids and equalities \c eqs based on parameters \c lhs, \c rhs. - def_API('Z3_solver_propagate_consequence', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, AST), _in(UINT), _in_array(4, AST), _in_array(4, AST), _in(AST))) + The solver might discard the propagation in case it is true in the current state. + The function returns false in this case; otw. the function returns true. + At least one propagation in the final callback has to return true in order to + prevent the solver from finishing. + + Assume the callback has the signature: \c propagate_consequence_eh(context, solver_cb, num_ids, ids, num_eqs, lhs, rhs, consequence). + \param c - context + \param solver_cb - solver callback + \param num_ids - number of fixed terms used as premise to propagation + \param ids - array of length \c num_ids containing terms that are fixed in the current scope + \param num_eqs - number of equalities used as premise to propagation + \param lhs - left side of equalities + \param rhs - right side of equalities + \param consequence - consequence to propagate. It is typically an atomic formula, but it can be an arbitrary formula. + + def_API('Z3_solver_propagate_consequence', BOOL, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, AST), _in(UINT), _in_array(4, AST), _in_array(4, AST), _in(AST))) */ - - void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback cb, unsigned num_fixed, Z3_ast const* fixed, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq); + + bool Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback cb, unsigned num_fixed, Z3_ast const* fixed, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq); /** \brief Check whether the assertions in a given solver are consistent or not. diff --git a/src/api/z3_fpa.h b/src/api/z3_fpa.h index 4ab0d91775b..9c4b22153fe 100644 --- a/src/api/z3_fpa.h +++ b/src/api/z3_fpa.h @@ -32,6 +32,12 @@ extern "C" { \param c logical context + \sa Z3_mk_fpa_round_nearest_ties_to_away or Z3_mk_fpa_rna + \sa Z3_mk_fpa_round_nearest_ties_to_even or Z3_mk_fpa_rne + \sa Z3_mk_fpa_round_toward_negative or Z3_mk_fpa_rtn + \sa Z3_mk_fpa_round_toward_positive or Z3_mk_fpa_rtp + \sa Z3_mk_fpa_round_toward_zero or Z3_mk_fpa_rtz + def_API('Z3_mk_fpa_rounding_mode_sort', SORT, (_in(CONTEXT),)) */ Z3_sort Z3_API Z3_mk_fpa_rounding_mode_sort(Z3_context c); @@ -39,8 +45,16 @@ extern "C" { /** \brief Create a numeral of RoundingMode sort which represents the NearestTiesToEven rounding mode. + This is the same as #Z3_mk_fpa_rne. + \param c logical context + \sa Z3_mk_fpa_rounding_mode_sort + \sa Z3_mk_fpa_round_nearest_ties_to_away + \sa Z3_mk_fpa_round_toward_negative + \sa Z3_mk_fpa_round_toward_positive + \sa Z3_mk_fpa_round_toward_zero + def_API('Z3_mk_fpa_round_nearest_ties_to_even', AST, (_in(CONTEXT),)) */ Z3_ast Z3_API Z3_mk_fpa_round_nearest_ties_to_even(Z3_context c); @@ -48,8 +62,16 @@ extern "C" { /** \brief Create a numeral of RoundingMode sort which represents the NearestTiesToEven rounding mode. + This is the same as #Z3_mk_fpa_round_nearest_ties_to_even. + \param c logical context + \sa Z3_mk_fpa_rounding_mode_sort + \sa Z3_mk_fpa_rna + \sa Z3_mk_fpa_rtn + \sa Z3_mk_fpa_rtp + \sa Z3_mk_fpa_rtz + def_API('Z3_mk_fpa_rne', AST, (_in(CONTEXT),)) */ Z3_ast Z3_API Z3_mk_fpa_rne(Z3_context c); @@ -57,8 +79,16 @@ extern "C" { /** \brief Create a numeral of RoundingMode sort which represents the NearestTiesToAway rounding mode. + This is the same as #Z3_mk_fpa_rna. + \param c logical context + \sa Z3_mk_fpa_rounding_mode_sort + \sa Z3_mk_fpa_round_nearest_ties_to_even + \sa Z3_mk_fpa_round_toward_negative + \sa Z3_mk_fpa_round_toward_positive + \sa Z3_mk_fpa_round_toward_zero + def_API('Z3_mk_fpa_round_nearest_ties_to_away', AST, (_in(CONTEXT),)) */ Z3_ast Z3_API Z3_mk_fpa_round_nearest_ties_to_away(Z3_context c); @@ -66,8 +96,16 @@ extern "C" { /** \brief Create a numeral of RoundingMode sort which represents the NearestTiesToAway rounding mode. + This is the same as #Z3_mk_fpa_round_nearest_ties_to_away. + \param c logical context + \sa Z3_mk_fpa_rounding_mode_sort + \sa Z3_mk_fpa_rne + \sa Z3_mk_fpa_rtn + \sa Z3_mk_fpa_rtp + \sa Z3_mk_fpa_rtz + def_API('Z3_mk_fpa_rna', AST, (_in(CONTEXT),)) */ Z3_ast Z3_API Z3_mk_fpa_rna(Z3_context c); @@ -75,8 +113,16 @@ extern "C" { /** \brief Create a numeral of RoundingMode sort which represents the TowardPositive rounding mode. + This is the same as #Z3_mk_fpa_rtp. + \param c logical context + \sa Z3_mk_fpa_rounding_mode_sort + \sa Z3_mk_fpa_round_nearest_ties_to_away + \sa Z3_mk_fpa_round_nearest_ties_to_even + \sa Z3_mk_fpa_round_toward_negative + \sa Z3_mk_fpa_round_toward_zero + def_API('Z3_mk_fpa_round_toward_positive', AST, (_in(CONTEXT),)) */ Z3_ast Z3_API Z3_mk_fpa_round_toward_positive(Z3_context c); @@ -84,8 +130,16 @@ extern "C" { /** \brief Create a numeral of RoundingMode sort which represents the TowardPositive rounding mode. + This is the same as #Z3_mk_fpa_round_toward_positive. + \param c logical context + \sa Z3_mk_fpa_rounding_mode_sort + \sa Z3_mk_fpa_rna + \sa Z3_mk_fpa_rne + \sa Z3_mk_fpa_rtn + \sa Z3_mk_fpa_rtz + def_API('Z3_mk_fpa_rtp', AST, (_in(CONTEXT),)) */ Z3_ast Z3_API Z3_mk_fpa_rtp(Z3_context c); @@ -93,8 +147,16 @@ extern "C" { /** \brief Create a numeral of RoundingMode sort which represents the TowardNegative rounding mode. + This is the same as #Z3_mk_fpa_rtn. + \param c logical context + \sa Z3_mk_fpa_rounding_mode_sort + \sa Z3_mk_fpa_round_nearest_ties_to_away + \sa Z3_mk_fpa_round_nearest_ties_to_even + \sa Z3_mk_fpa_round_toward_positive + \sa Z3_mk_fpa_round_toward_zero + def_API('Z3_mk_fpa_round_toward_negative', AST, (_in(CONTEXT),)) */ Z3_ast Z3_API Z3_mk_fpa_round_toward_negative(Z3_context c); @@ -102,8 +164,16 @@ extern "C" { /** \brief Create a numeral of RoundingMode sort which represents the TowardNegative rounding mode. + This is the same as #Z3_mk_fpa_round_toward_negative. + \param c logical context + \sa Z3_mk_fpa_rounding_mode_sort + \sa Z3_mk_fpa_rna + \sa Z3_mk_fpa_rne + \sa Z3_mk_fpa_rtp + \sa Z3_mk_fpa_rtz + def_API('Z3_mk_fpa_rtn', AST, (_in(CONTEXT),)) */ Z3_ast Z3_API Z3_mk_fpa_rtn(Z3_context c); @@ -111,8 +181,16 @@ extern "C" { /** \brief Create a numeral of RoundingMode sort which represents the TowardZero rounding mode. + This is the same as #Z3_mk_fpa_rtz. + \param c logical context + \sa Z3_mk_fpa_rounding_mode_sort + \sa Z3_mk_fpa_round_nearest_ties_to_away + \sa Z3_mk_fpa_round_nearest_ties_to_even + \sa Z3_mk_fpa_round_toward_negative + \sa Z3_mk_fpa_round_toward_positive + def_API('Z3_mk_fpa_round_toward_zero', AST, (_in(CONTEXT),)) */ Z3_ast Z3_API Z3_mk_fpa_round_toward_zero(Z3_context c); @@ -120,8 +198,16 @@ extern "C" { /** \brief Create a numeral of RoundingMode sort which represents the TowardZero rounding mode. + This is the same as #Z3_mk_fpa_round_toward_zero. + \param c logical context + \sa Z3_mk_fpa_rounding_mode_sort + \sa Z3_mk_fpa_rna + \sa Z3_mk_fpa_rne + \sa Z3_mk_fpa_rtn + \sa Z3_mk_fpa_rtp + def_API('Z3_mk_fpa_rtz', AST, (_in(CONTEXT),)) */ Z3_ast Z3_API Z3_mk_fpa_rtz(Z3_context c); @@ -135,6 +221,11 @@ extern "C" { \remark \c ebits must be larger than 1 and \c sbits must be larger than 2. + \sa Z3_mk_fpa_sort_half or Z3_mk_fpa_sort_16 + \sa Z3_mk_fpa_sort_single or Z3_mk_fpa_sort_32 + \sa Z3_mk_fpa_sort_double or Z3_mk_fpa_sort_64 + \sa Z3_mk_fpa_sort_quadruple or Z3_mk_fpa_sort_128 + def_API('Z3_mk_fpa_sort', SORT, (_in(CONTEXT), _in(UINT), _in(UINT))) */ Z3_sort Z3_API Z3_mk_fpa_sort(Z3_context c, unsigned ebits, unsigned sbits); @@ -142,8 +233,15 @@ extern "C" { /** \brief Create the half-precision (16-bit) FloatingPoint sort. + This is the same as #Z3_mk_fpa_sort_16. + \param c logical context + \sa Z3_mk_fpa_sort + \sa Z3_mk_fpa_sort_single + \sa Z3_mk_fpa_sort_double + \sa Z3_mk_fpa_sort_quadruple + def_API('Z3_mk_fpa_sort_half', SORT, (_in(CONTEXT),)) */ Z3_sort Z3_API Z3_mk_fpa_sort_half(Z3_context c); @@ -151,8 +249,15 @@ extern "C" { /** \brief Create the half-precision (16-bit) FloatingPoint sort. + This is the same as #Z3_mk_fpa_sort_half. + \param c logical context + \sa Z3_mk_fpa_sort + \sa Z3_mk_fpa_sort_32 + \sa Z3_mk_fpa_sort_64 + \sa Z3_mk_fpa_sort_128 + def_API('Z3_mk_fpa_sort_16', SORT, (_in(CONTEXT),)) */ Z3_sort Z3_API Z3_mk_fpa_sort_16(Z3_context c); @@ -160,8 +265,15 @@ extern "C" { /** \brief Create the single-precision (32-bit) FloatingPoint sort. + This is the same as #Z3_mk_fpa_sort_32. + \param c logical context. + \sa Z3_mk_fpa_sort + \sa Z3_mk_fpa_sort_half + \sa Z3_mk_fpa_sort_double + \sa Z3_mk_fpa_sort_quadruple + def_API('Z3_mk_fpa_sort_single', SORT, (_in(CONTEXT),)) */ Z3_sort Z3_API Z3_mk_fpa_sort_single(Z3_context c); @@ -169,8 +281,15 @@ extern "C" { /** \brief Create the single-precision (32-bit) FloatingPoint sort. + This is the same as #Z3_mk_fpa_sort_single. + \param c logical context + \sa Z3_mk_fpa_sort + \sa Z3_mk_fpa_sort_16 + \sa Z3_mk_fpa_sort_64 + \sa Z3_mk_fpa_sort_128 + def_API('Z3_mk_fpa_sort_32', SORT, (_in(CONTEXT),)) */ Z3_sort Z3_API Z3_mk_fpa_sort_32(Z3_context c); @@ -178,8 +297,15 @@ extern "C" { /** \brief Create the double-precision (64-bit) FloatingPoint sort. + This is the same as #Z3_mk_fpa_sort_64. + \param c logical context + \sa Z3_mk_fpa_sort + \sa Z3_mk_fpa_sort_half + \sa Z3_mk_fpa_sort_single + \sa Z3_mk_fpa_sort_quadruple + def_API('Z3_mk_fpa_sort_double', SORT, (_in(CONTEXT),)) */ Z3_sort Z3_API Z3_mk_fpa_sort_double(Z3_context c); @@ -187,8 +313,15 @@ extern "C" { /** \brief Create the double-precision (64-bit) FloatingPoint sort. + This is the same as #Z3_mk_fpa_sort_double. + \param c logical context + \sa Z3_mk_fpa_sort + \sa Z3_mk_fpa_sort_16 + \sa Z3_mk_fpa_sort_32 + \sa Z3_mk_fpa_sort_128 + def_API('Z3_mk_fpa_sort_64', SORT, (_in(CONTEXT),)) */ Z3_sort Z3_API Z3_mk_fpa_sort_64(Z3_context c); @@ -196,8 +329,15 @@ extern "C" { /** \brief Create the quadruple-precision (128-bit) FloatingPoint sort. + This is the same as #Z3_mk_fpa_sort_128. + \param c logical context + \sa Z3_mk_fpa_sort + \sa Z3_mk_fpa_sort_half + \sa Z3_mk_fpa_sort_single + \sa Z3_mk_fpa_sort_double + def_API('Z3_mk_fpa_sort_quadruple', SORT, (_in(CONTEXT),)) */ Z3_sort Z3_API Z3_mk_fpa_sort_quadruple(Z3_context c); @@ -205,8 +345,15 @@ extern "C" { /** \brief Create the quadruple-precision (128-bit) FloatingPoint sort. + This is the same as #Z3_mk_fpa_sort_quadruple. + \param c logical context + \sa Z3_mk_fpa_sort + \sa Z3_mk_fpa_sort_16 + \sa Z3_mk_fpa_sort_32 + \sa Z3_mk_fpa_sort_64 + def_API('Z3_mk_fpa_sort_128', SORT, (_in(CONTEXT),)) */ Z3_sort Z3_API Z3_mk_fpa_sort_128(Z3_context c); @@ -218,6 +365,7 @@ extern "C" { \param s target sort \sa Z3_mk_fpa_inf + \sa Z3_mk_fpa_is_nan \sa Z3_mk_fpa_zero def_API('Z3_mk_fpa_nan', AST, (_in(CONTEXT),_in(SORT))) @@ -233,6 +381,7 @@ extern "C" { When \c negative is \c true, -oo will be generated instead of +oo. + \sa Z3_mk_fpa_is_infinite \sa Z3_mk_fpa_nan \sa Z3_mk_fpa_zero @@ -250,6 +399,7 @@ extern "C" { When \c negative is \c true, -zero will be generated instead of +zero. \sa Z3_mk_fpa_inf + \sa Z3_mk_fpa_is_zero \sa Z3_mk_fpa_nan def_API('Z3_mk_fpa_zero', AST, (_in(CONTEXT),_in(SORT),_in(BOOL))) @@ -397,6 +547,10 @@ extern "C" { \param c logical context \param t term of FloatingPoint sort + \sa Z3_mk_fpa_is_negative + \sa Z3_mk_fpa_is_positive + \sa Z3_mk_fpa_neg + def_API('Z3_mk_fpa_abs', AST, (_in(CONTEXT),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_abs(Z3_context c, Z3_ast t); @@ -407,6 +561,10 @@ extern "C" { \param c logical context \param t term of FloatingPoint sort + \sa Z3_mk_fpa_abs + \sa Z3_mk_fpa_is_negative + \sa Z3_mk_fpa_is_positive + def_API('Z3_mk_fpa_neg', AST, (_in(CONTEXT),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_neg(Z3_context c, Z3_ast t); @@ -533,6 +691,8 @@ extern "C" { \c t1, \c t2 must have the same FloatingPoint sort. + \sa Z3_mk_fpa_max + def_API('Z3_mk_fpa_min', AST, (_in(CONTEXT),_in(AST),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_min(Z3_context c, Z3_ast t1, Z3_ast t2); @@ -546,6 +706,8 @@ extern "C" { \c t1, \c t2 must have the same FloatingPoint sort. + \sa Z3_mk_fpa_min + def_API('Z3_mk_fpa_max', AST, (_in(CONTEXT),_in(AST),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_max(Z3_context c, Z3_ast t1, Z3_ast t2); @@ -559,6 +721,11 @@ extern "C" { \c t1 and \c t2 must have the same FloatingPoint sort. + \sa Z3_mk_fpa_eq + \sa Z3_mk_fpa_geq + \sa Z3_mk_fpa_gt + \sa Z3_mk_fpa_lt + def_API('Z3_mk_fpa_leq', AST, (_in(CONTEXT),_in(AST),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_leq(Z3_context c, Z3_ast t1, Z3_ast t2); @@ -572,6 +739,11 @@ extern "C" { \c t1 and \c t2 must have the same FloatingPoint sort. + \sa Z3_mk_fpa_eq + \sa Z3_mk_fpa_geq + \sa Z3_mk_fpa_gt + \sa Z3_mk_fpa_leq + def_API('Z3_mk_fpa_lt', AST, (_in(CONTEXT),_in(AST),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_lt(Z3_context c, Z3_ast t1, Z3_ast t2); @@ -585,6 +757,11 @@ extern "C" { \c t1 and \c t2 must have the same FloatingPoint sort. + \sa Z3_mk_fpa_eq + \sa Z3_mk_fpa_gt + \sa Z3_mk_fpa_leq + \sa Z3_mk_fpa_lt + def_API('Z3_mk_fpa_geq', AST, (_in(CONTEXT),_in(AST),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_geq(Z3_context c, Z3_ast t1, Z3_ast t2); @@ -598,6 +775,11 @@ extern "C" { \c t1 and \c t2 must have the same FloatingPoint sort. + \sa Z3_mk_fpa_eq + \sa Z3_mk_fpa_geq + \sa Z3_mk_fpa_leq + \sa Z3_mk_fpa_lt + def_API('Z3_mk_fpa_gt', AST, (_in(CONTEXT),_in(AST),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_gt(Z3_context c, Z3_ast t1, Z3_ast t2); @@ -613,6 +795,11 @@ extern "C" { \c t1 and \c t2 must have the same FloatingPoint sort. + \sa Z3_mk_fpa_geq + \sa Z3_mk_fpa_gt + \sa Z3_mk_fpa_leq + \sa Z3_mk_fpa_lt + def_API('Z3_mk_fpa_eq', AST, (_in(CONTEXT),_in(AST),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_eq(Z3_context c, Z3_ast t1, Z3_ast t2); @@ -625,6 +812,11 @@ extern "C" { \c t must have FloatingPoint sort. + \sa Z3_mk_fpa_is_infinite + \sa Z3_mk_fpa_is_nan + \sa Z3_mk_fpa_is_subnormal + \sa Z3_mk_fpa_is_zero + def_API('Z3_mk_fpa_is_normal', AST, (_in(CONTEXT),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_is_normal(Z3_context c, Z3_ast t); @@ -637,6 +829,11 @@ extern "C" { \c t must have FloatingPoint sort. + \sa Z3_mk_fpa_is_infinite + \sa Z3_mk_fpa_is_nan + \sa Z3_mk_fpa_is_normal + \sa Z3_mk_fpa_is_zero + def_API('Z3_mk_fpa_is_subnormal', AST, (_in(CONTEXT),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_is_subnormal(Z3_context c, Z3_ast t); @@ -649,6 +846,12 @@ extern "C" { \c t must have FloatingPoint sort. + \sa Z3_mk_fpa_is_infinite + \sa Z3_mk_fpa_is_nan + \sa Z3_mk_fpa_is_normal + \sa Z3_mk_fpa_is_subnormal + \sa Z3_mk_fpa_zero + def_API('Z3_mk_fpa_is_zero', AST, (_in(CONTEXT),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_is_zero(Z3_context c, Z3_ast t); @@ -661,6 +864,12 @@ extern "C" { \c t must have FloatingPoint sort. + \sa Z3_mk_fpa_inf + \sa Z3_mk_fpa_is_nan + \sa Z3_mk_fpa_is_normal + \sa Z3_mk_fpa_is_subnormal + \sa Z3_mk_fpa_is_zero + def_API('Z3_mk_fpa_is_infinite', AST, (_in(CONTEXT),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_is_infinite(Z3_context c, Z3_ast t); @@ -673,6 +882,12 @@ extern "C" { \c t must have FloatingPoint sort. + \sa Z3_mk_fpa_is_infinite + \sa Z3_mk_fpa_is_normal + \sa Z3_mk_fpa_is_subnormal + \sa Z3_mk_fpa_is_zero + \sa Z3_mk_fpa_nan + def_API('Z3_mk_fpa_is_nan', AST, (_in(CONTEXT),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_is_nan(Z3_context c, Z3_ast t); @@ -685,6 +900,10 @@ extern "C" { \c t must have FloatingPoint sort. + \sa Z3_mk_fpa_abs + \sa Z3_mk_fpa_is_positive + \sa Z3_mk_fpa_neg + def_API('Z3_mk_fpa_is_negative', AST, (_in(CONTEXT),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_is_negative(Z3_context c, Z3_ast t); @@ -697,6 +916,10 @@ extern "C" { \c t must have FloatingPoint sort. + \sa Z3_mk_fpa_abs + \sa Z3_mk_fpa_is_negative + \sa Z3_mk_fpa_neg + def_API('Z3_mk_fpa_is_positive', AST, (_in(CONTEXT),_in(AST))) */ Z3_ast Z3_API Z3_mk_fpa_is_positive(Z3_context c, Z3_ast t); @@ -848,6 +1071,8 @@ extern "C" { \param c logical context \param s FloatingPoint sort + \sa Z3_fpa_get_sbits + def_API('Z3_fpa_get_ebits', UINT, (_in(CONTEXT),_in(SORT))) */ unsigned Z3_API Z3_fpa_get_ebits(Z3_context c, Z3_sort s); @@ -858,6 +1083,8 @@ extern "C" { \param c logical context \param s FloatingPoint sort + \sa Z3_fpa_get_ebits + def_API('Z3_fpa_get_sbits', UINT, (_in(CONTEXT),_in(SORT))) */ unsigned Z3_API Z3_fpa_get_sbits(Z3_context c, Z3_sort s); @@ -868,6 +1095,11 @@ extern "C" { \param c logical context \param t a floating-point numeral + \sa Z3_fpa_is_numeral_inf + \sa Z3_fpa_is_numeral_normal + \sa Z3_fpa_is_numeral_subnormal + \sa Z3_fpa_is_numeral_zero + def_API('Z3_fpa_is_numeral_nan', BOOL, (_in(CONTEXT), _in(AST))) */ bool Z3_API Z3_fpa_is_numeral_nan(Z3_context c, Z3_ast t); @@ -878,6 +1110,11 @@ extern "C" { \param c logical context \param t a floating-point numeral + \sa Z3_fpa_is_numeral_nan + \sa Z3_fpa_is_numeral_normal + \sa Z3_fpa_is_numeral_subnormal + \sa Z3_fpa_is_numeral_zero + def_API('Z3_fpa_is_numeral_inf', BOOL, (_in(CONTEXT), _in(AST))) */ bool Z3_API Z3_fpa_is_numeral_inf(Z3_context c, Z3_ast t); @@ -888,6 +1125,11 @@ extern "C" { \param c logical context \param t a floating-point numeral + \sa Z3_fpa_is_numeral_inf + \sa Z3_fpa_is_numeral_nan + \sa Z3_fpa_is_numeral_normal + \sa Z3_fpa_is_numeral_subnormal + def_API('Z3_fpa_is_numeral_zero', BOOL, (_in(CONTEXT), _in(AST))) */ bool Z3_API Z3_fpa_is_numeral_zero(Z3_context c, Z3_ast t); @@ -898,6 +1140,11 @@ extern "C" { \param c logical context \param t a floating-point numeral + \sa Z3_fpa_is_numeral_inf + \sa Z3_fpa_is_numeral_nan + \sa Z3_fpa_is_numeral_subnormal + \sa Z3_fpa_is_numeral_zero + def_API('Z3_fpa_is_numeral_normal', BOOL, (_in(CONTEXT), _in(AST))) */ bool Z3_API Z3_fpa_is_numeral_normal(Z3_context c, Z3_ast t); @@ -908,6 +1155,11 @@ extern "C" { \param c logical context \param t a floating-point numeral + \sa Z3_fpa_is_numeral_inf + \sa Z3_fpa_is_numeral_nan + \sa Z3_fpa_is_numeral_normal + \sa Z3_fpa_is_numeral_zero + def_API('Z3_fpa_is_numeral_subnormal', BOOL, (_in(CONTEXT), _in(AST))) */ bool Z3_API Z3_fpa_is_numeral_subnormal(Z3_context c, Z3_ast t); @@ -918,6 +1170,8 @@ extern "C" { \param c logical context \param t a floating-point numeral + \sa Z3_fpa_is_numeral_negative + def_API('Z3_fpa_is_numeral_positive', BOOL, (_in(CONTEXT), _in(AST))) */ bool Z3_API Z3_fpa_is_numeral_positive(Z3_context c, Z3_ast t); @@ -928,6 +1182,8 @@ extern "C" { \param c logical context \param t a floating-point numeral + \sa Z3_fpa_is_numeral_positive + def_API('Z3_fpa_is_numeral_negative', BOOL, (_in(CONTEXT), _in(AST))) */ bool Z3_API Z3_fpa_is_numeral_negative(Z3_context c, Z3_ast t); diff --git a/src/api/z3_rcf.h b/src/api/z3_rcf.h index 88c27db61f1..b3842f1b6c7 100644 --- a/src/api/z3_rcf.h +++ b/src/api/z3_rcf.h @@ -193,8 +193,124 @@ extern "C" { We have that \ccode{a = n/d}, moreover \c n and \c d are not represented using rational functions. def_API('Z3_rcf_get_numerator_denominator', VOID, (_in(CONTEXT), _in(RCF_NUM), _out(RCF_NUM), _out(RCF_NUM))) - */ - void Z3_API Z3_rcf_get_numerator_denominator(Z3_context c, Z3_rcf_num a, Z3_rcf_num * n, Z3_rcf_num * d); + */ + void Z3_API Z3_rcf_get_numerator_denominator(Z3_context c, Z3_rcf_num a, Z3_rcf_num * n, Z3_rcf_num * d); + + /** + \brief Return \c true if \c a represents a rational number. + + def_API('Z3_rcf_is_rational', BOOL, (_in(CONTEXT), _in(RCF_NUM))) + */ + bool Z3_API Z3_rcf_is_rational(Z3_context c, Z3_rcf_num a); + + /** + \brief Return \c true if \c a represents an algebraic number. + + def_API('Z3_rcf_is_algebraic', BOOL, (_in(CONTEXT), _in(RCF_NUM))) + */ + bool Z3_API Z3_rcf_is_algebraic(Z3_context c, Z3_rcf_num a); + + /** + \brief Return \c true if \c a represents an infinitesimal. + + def_API('Z3_rcf_is_infinitesimal', BOOL, (_in(CONTEXT), _in(RCF_NUM))) + */ + bool Z3_API Z3_rcf_is_infinitesimal(Z3_context c, Z3_rcf_num a); + + /** + \brief Return \c true if \c a represents a transcendental number. + + def_API('Z3_rcf_is_transcendental', BOOL, (_in(CONTEXT), _in(RCF_NUM))) + */ + bool Z3_API Z3_rcf_is_transcendental(Z3_context c, Z3_rcf_num a); + + /** + \brief Return the index of a field extension. + + def_API('Z3_rcf_extension_index', UINT, (_in(CONTEXT), _in(RCF_NUM))) + */ + unsigned Z3_API Z3_rcf_extension_index(Z3_context c, Z3_rcf_num a); + + /** + \brief Return the name of a transcendental. + + \pre Z3_rcf_is_transcendtal(ctx, a); + + def_API('Z3_rcf_transcendental_name', SYMBOL, (_in(CONTEXT), _in(RCF_NUM))) + */ + Z3_symbol Z3_API Z3_rcf_transcendental_name(Z3_context c, Z3_rcf_num a); + + /** + \brief Return the name of an infinitesimal. + + \pre Z3_rcf_is_infinitesimal(ctx, a); + + def_API('Z3_rcf_infinitesimal_name', SYMBOL, (_in(CONTEXT), _in(RCF_NUM))) + */ + Z3_symbol Z3_API Z3_rcf_infinitesimal_name(Z3_context c, Z3_rcf_num a); + + /** + \brief Return the number of coefficients in an algebraic number. + + \pre Z3_rcf_is_algebraic(ctx, a); + + def_API('Z3_rcf_num_coefficients', UINT, (_in(CONTEXT), _in(RCF_NUM))) + */ + unsigned Z3_API Z3_rcf_num_coefficients(Z3_context c, Z3_rcf_num a); + + /** + \brief Extract a coefficient from an algebraic number. + + \pre Z3_rcf_is_algebraic(ctx, a); + + def_API('Z3_rcf_coefficient', RCF_NUM, (_in(CONTEXT), _in(RCF_NUM), _in(UINT))) + */ + Z3_rcf_num Z3_API Z3_rcf_coefficient(Z3_context c, Z3_rcf_num a, unsigned i); + + /** + \brief Extract an interval from an algebraic number. + + \pre Z3_rcf_is_algebraic(ctx, a); + + def_API('Z3_rcf_interval', INT, (_in(CONTEXT), _in(RCF_NUM), _out(INT), _out(INT), _out(RCF_NUM), _out(INT), _out(INT), _out(RCF_NUM))) + */ + int Z3_API Z3_rcf_interval(Z3_context c, Z3_rcf_num a, int * lower_is_inf, int * lower_is_open, Z3_rcf_num * lower, int * upper_is_inf, int * upper_is_open, Z3_rcf_num * upper); + + /** + \brief Return the number of sign conditions of an algebraic number. + + \pre Z3_rcf_is_algebraic(ctx, a); + + def_API('Z3_rcf_num_sign_conditions', UINT, (_in(CONTEXT), _in(RCF_NUM))) + */ + unsigned Z3_API Z3_rcf_num_sign_conditions(Z3_context c, Z3_rcf_num a); + + /** + \brief Extract the sign of a sign condition from an algebraic number. + + \pre Z3_rcf_is_algebraic(ctx, a); + + def_API('Z3_rcf_sign_condition_sign', INT, (_in(CONTEXT), _in(RCF_NUM), _in(UINT))) + */ + int Z3_API Z3_rcf_sign_condition_sign(Z3_context c, Z3_rcf_num a, unsigned i); + + /** + \brief Return the number of sign condition polynomial coefficients of an algebraic number. + + \pre Z3_rcf_is_algebraic(ctx, a); + + def_API('Z3_rcf_num_sign_condition_coefficients', UINT, (_in(CONTEXT), _in(RCF_NUM), _in(UINT))) + */ + unsigned Z3_API Z3_rcf_num_sign_condition_coefficients(Z3_context c, Z3_rcf_num a, unsigned i); + + /** + \brief Extract the j-th polynomial coefficient of the i-th sign condition. + + \pre Z3_rcf_is_algebraic(ctx, a); + + def_API('Z3_rcf_sign_condition_coefficient', RCF_NUM, (_in(CONTEXT), _in(RCF_NUM), _in(UINT), _in(UINT))) + */ + Z3_rcf_num Z3_API Z3_rcf_sign_condition_coefficient(Z3_context c, Z3_rcf_num a, unsigned i, unsigned j); /**@}*/ /**@}*/ diff --git a/src/api/z3_replayer.cpp b/src/api/z3_replayer.cpp index ad5bc952334..aebc2f9a52d 100644 --- a/src/api/z3_replayer.cpp +++ b/src/api/z3_replayer.cpp @@ -72,7 +72,7 @@ struct z3_replayer::imp { void check_arg(unsigned pos, value_kind k) const { if (pos >= m_args.size()) { - TRACE("z3_replayer", tout << "too few arguments " << m_args.size() << " expecting " << kind2string(k) << "\n";); + TRACE("z3_replayer", tout << pos << " too few arguments " << m_args.size() << " expecting " << kind2string(k) << "\n";); throw z3_replayer_exception("invalid argument reference"); } if (m_args[pos].m_kind != k) { diff --git a/src/ast/CMakeLists.txt b/src/ast/CMakeLists.txt index 8dd8709648e..7a4a03a27a4 100644 --- a/src/ast/CMakeLists.txt +++ b/src/ast/CMakeLists.txt @@ -3,6 +3,7 @@ z3_add_component(ast act_cache.cpp arith_decl_plugin.cpp array_decl_plugin.cpp + array_peq.cpp ast.cpp ast_ll_pp.cpp ast_lt.cpp @@ -37,6 +38,8 @@ z3_add_component(ast num_occurs.cpp occurs.cpp pb_decl_plugin.cpp + polymorphism_inst.cpp + polymorphism_util.cpp pp.cpp quantifier_stat.cpp recfun_decl_plugin.cpp diff --git a/src/ast/arith_decl_plugin.cpp b/src/ast/arith_decl_plugin.cpp index 4778caf8933..67a605869ca 100644 --- a/src/ast/arith_decl_plugin.cpp +++ b/src/ast/arith_decl_plugin.cpp @@ -370,7 +370,7 @@ inline func_decl * arith_decl_plugin::mk_func_decl(decl_kind k, bool is_real) { if (is_real) { return m_manager->mk_func_decl(symbol("^0"), m_real_decl, m_real_decl, m_real_decl, func_decl_info(m_family_id, OP_POWER0)); } - return m_manager->mk_func_decl(symbol("^0"), m_int_decl, m_int_decl, m_int_decl, func_decl_info(m_family_id, OP_POWER0)); + return m_manager->mk_func_decl(symbol("^0"), m_int_decl, m_int_decl, m_real_decl, func_decl_info(m_family_id, OP_POWER0)); case OP_TO_REAL: return m_to_real_decl; case OP_TO_INT: return m_to_int_decl; case OP_IS_INT: return m_is_int_decl; @@ -508,6 +508,19 @@ static bool is_const_op(decl_kind k) { //k == OP_0_PW_0_REAL; } +symbol arith_decl_plugin::bv_symbol(decl_kind k) const { + switch (k) { + case OP_ARITH_BAND: return symbol("band"); + case OP_ARITH_SHL: return symbol("shl"); + case OP_ARITH_ASHR: return symbol("ashr"); + case OP_ARITH_LSHR: return symbol("lshr"); + default: + UNREACHABLE(); + } + return symbol(); +} + + func_decl * arith_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, unsigned arity, sort * const * domain, sort * range) { if (k == OP_NUM) @@ -523,6 +536,12 @@ func_decl * arith_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters return m_manager->mk_func_decl(symbol("divisible"), 1, &m_int_decl, m_manager->mk_bool_sort(), func_decl_info(m_family_id, k, num_parameters, parameters)); } + if (k == OP_ARITH_BAND || k == OP_ARITH_SHL || k == OP_ARITH_ASHR || k == OP_ARITH_LSHR) { + if (arity != 2 || domain[0] != m_int_decl || domain[1] != m_int_decl || num_parameters != 1 || !parameters[0].is_int()) + m_manager->raise_exception("invalid bitwise and application. Expects integer parameter and two arguments of sort integer"); + return m_manager->mk_func_decl(bv_symbol(k), 2, domain, m_int_decl, + func_decl_info(m_family_id, k, num_parameters, parameters)); + } if (m_manager->int_real_coercions() && use_coercion(k)) { return mk_func_decl(fix_kind(k, arity), has_real_arg(arity, domain, m_real_decl)); @@ -548,6 +567,14 @@ func_decl * arith_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters return m_manager->mk_func_decl(symbol("divisible"), 1, &m_int_decl, m_manager->mk_bool_sort(), func_decl_info(m_family_id, k, num_parameters, parameters)); } + if (k == OP_ARITH_BAND || k == OP_ARITH_SHL || k == OP_ARITH_ASHR || k == OP_ARITH_LSHR) { + if (num_args != 2 || args[0]->get_sort() != m_int_decl || args[1]->get_sort() != m_int_decl || num_parameters != 1 || !parameters[0].is_int()) + m_manager->raise_exception("invalid bitwise and application. Expects integer parameter and two arguments of sort integer"); + sort* domain[2] = { m_int_decl, m_int_decl }; + return m_manager->mk_func_decl(bv_symbol(k), 2, domain, m_int_decl, + func_decl_info(m_family_id, k, num_parameters, parameters)); + } + if (m_manager->int_real_coercions() && use_coercion(k)) { return mk_func_decl(fix_kind(k, num_args), has_real_arg(m_manager, num_args, args, m_real_decl)); } @@ -693,7 +720,16 @@ expr * arith_decl_plugin::get_some_value(sort * s) { return mk_numeral(rational(0), s == m_int_decl); } -bool arith_recognizers::is_numeral(expr const * n, rational & val, bool & is_int) const { +bool arith_util::is_numeral(expr const * n, rational & val, bool & is_int) const { + if (is_irrational_algebraic_numeral(n)) { + scoped_anum an(am()); + is_irrational_algebraic_numeral2(n, an); + if (am().is_rational(an)) { + am().to_rational(an, val); + is_int = val.is_int(); + return true; + } + } if (!is_app_of(n, arith_family_id, OP_NUM)) return false; func_decl * decl = to_app(n)->get_decl(); @@ -724,7 +760,7 @@ bool arith_recognizers::is_int_expr(expr const *e) const { if (is_to_real(e)) { // pass } - else if (is_numeral(e, r) && r.is_int()) { + else if (is_numeral(e) && is_int(e)) { // pass } else if (is_add(e) || is_mul(e)) { @@ -747,14 +783,14 @@ void arith_util::init_plugin() { m_plugin = static_cast(m_manager.get_plugin(arith_family_id)); } -bool arith_util::is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val) { +bool arith_util::is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val) const { if (!is_app_of(n, arith_family_id, OP_IRRATIONAL_ALGEBRAIC_NUM)) return false; am().set(val, to_irrational_algebraic_numeral(n)); return true; } -algebraic_numbers::anum const & arith_util::to_irrational_algebraic_numeral(expr const * n) { +algebraic_numbers::anum const & arith_util::to_irrational_algebraic_numeral(expr const * n) const { SASSERT(is_irrational_algebraic_numeral(n)); return plugin().aw().to_anum(to_app(n)->get_decl()); } @@ -801,6 +837,29 @@ expr_ref arith_util::mk_add_simplify(unsigned sz, expr* const* args) { return result; } +bool arith_util::is_considered_partially_interpreted(func_decl* f, unsigned n, expr* const* args, func_decl_ref& f_out) { + if (is_decl_of(f, arith_family_id, OP_DIV) && n == 2 && !is_numeral(args[1])) { + f_out = mk_div0(); + return true; + } + if (is_decl_of(f, arith_family_id, OP_IDIV) && n == 2 && !is_numeral(args[1])) { + sort* rs[2] = { mk_int(), mk_int() }; + f_out = m_manager.mk_func_decl(arith_family_id, OP_IDIV0, 0, nullptr, 2, rs, mk_int()); + return true; + } + if (is_decl_of(f, arith_family_id, OP_MOD) && n == 2 && !is_numeral(args[1])) { + sort* rs[2] = { mk_int(), mk_int() }; + f_out = m_manager.mk_func_decl(arith_family_id, OP_MOD0, 0, nullptr, 2, rs, mk_int()); + return true; + } + if (is_decl_of(f, arith_family_id, OP_REM) && n == 2 && !is_numeral(args[1])) { + sort* rs[2] = { mk_int(), mk_int() }; + f_out = m_manager.mk_func_decl(arith_family_id, OP_MOD0, 0, nullptr, 2, rs, mk_int()); + return true; + } + return false; +} + bool arith_util::is_considered_uninterpreted(func_decl* f, unsigned n, expr* const* args, func_decl_ref& f_out) { rational r; if (is_decl_of(f, arith_family_id, OP_DIV) && n == 2 && is_numeral(args[1], r) && r.is_zero()) { @@ -834,7 +893,7 @@ bool arith_util::is_considered_uninterpreted(func_decl* f, unsigned n, expr* con func_decl* arith_util::mk_ipower0() { sort* s = mk_int(); sort* rs[2] = { s, s }; - return m_manager.mk_func_decl(arith_family_id, OP_POWER0, 0, nullptr, 2, rs, s); + return m_manager.mk_func_decl(arith_family_id, OP_POWER0, 0, nullptr, 2, rs, mk_real()); } func_decl* arith_util::mk_rpower0() { @@ -925,7 +984,8 @@ bool arith_util::is_extended_numeral(expr* term, rational& r) const { return true; } return false; - } while (false); + } + while (true); return false; } diff --git a/src/ast/arith_decl_plugin.h b/src/ast/arith_decl_plugin.h index 5dbf3e8cfc8..3f094d43f08 100644 --- a/src/ast/arith_decl_plugin.h +++ b/src/ast/arith_decl_plugin.h @@ -70,6 +70,11 @@ enum arith_op_kind { OP_ASINH, OP_ACOSH, OP_ATANH, + // Bit-vector functions + OP_ARITH_BAND, + OP_ARITH_SHL, + OP_ARITH_ASHR, + OP_ARITH_LSHR, // constants OP_PI, OP_E, @@ -148,6 +153,8 @@ class arith_decl_plugin : public decl_plugin { bool m_convert_int_numerals_to_real; + symbol bv_symbol(decl_kind k) const; + func_decl * mk_func_decl(decl_kind k, bool is_real); void set_manager(ast_manager * m, family_id id) override; decl_kind fix_kind(decl_kind k, unsigned arity); @@ -231,30 +238,22 @@ class arith_decl_plugin : public decl_plugin { executed in different threads. */ class arith_recognizers { + bool is_arith_op(expr const* n, decl_kind k, unsigned& sz, expr*& x, expr*& y) { + if (!is_app_of(n, arith_family_id, k)) + return false; + x = to_app(n)->get_arg(0); + y = to_app(n)->get_arg(1); + sz = to_app(n)->get_parameter(0).get_int(); + return true; + } public: family_id get_family_id() const { return arith_family_id; } bool is_arith_expr(expr const * n) const { return is_app(n) && to_app(n)->get_family_id() == arith_family_id; } - bool is_irrational_algebraic_numeral(expr const * n) const; - bool is_unsigned(expr const * n, unsigned& u) const { - rational val; - bool is_int = true; - return is_numeral(n, val, is_int) && is_int && val.is_unsigned() && (u = val.get_unsigned(), true); - } - bool is_numeral(expr const * n, rational & val, bool & is_int) const; - bool is_numeral(expr const * n, rational & val) const { bool is_int; return is_numeral(n, val, is_int); } - bool is_numeral(expr const * n) const { return is_app_of(n, arith_family_id, OP_NUM); } - bool is_zero(expr const * n) const { rational val; return is_numeral(n, val) && val.is_zero(); } - bool is_minus_one(expr * n) const { rational tmp; return is_numeral(n, tmp) && tmp.is_minus_one(); } - // return true if \c n is a term of the form (* -1 r) - bool is_times_minus_one(expr * n, expr * & r) const { - if (is_mul(n) && to_app(n)->get_num_args() == 2 && is_minus_one(to_app(n)->get_arg(0))) { - r = to_app(n)->get_arg(1); - return true; - } - return false; - } + bool is_irrational_algebraic_numeral(expr const* n) const; + + bool is_numeral(expr const* n) const { return is_app_of(n, arith_family_id, OP_NUM); } bool is_int_expr(expr const * e) const; bool is_le(expr const * n) const { return is_app_of(n, arith_family_id, OP_LE); } @@ -309,6 +308,15 @@ class arith_recognizers { bool is_int_real(sort const * s) const { return s->get_family_id() == arith_family_id; } bool is_int_real(expr const * n) const { return is_int_real(n->get_sort()); } + bool is_band(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_BAND); } + bool is_band(expr const* n, unsigned& sz, expr*& x, expr*& y) { return is_arith_op(n, OP_ARITH_BAND, sz, x, y); } + bool is_shl(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_SHL); } + bool is_shl(expr const* n, unsigned& sz, expr*& x, expr*& y) { return is_arith_op(n, OP_ARITH_SHL, sz, x, y); } + bool is_lshr(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_LSHR); } + bool is_lshr(expr const* n, unsigned& sz, expr*& x, expr*& y) { return is_arith_op(n, OP_ARITH_LSHR, sz, x, y); } + bool is_ashr(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_ASHR); } + bool is_ashr(expr const* n, unsigned& sz, expr*& x, expr*& y) { return is_arith_op(n, OP_ARITH_ASHR, sz, x, y); } + bool is_sin(expr const* n) const { return is_app_of(n, arith_family_id, OP_SIN); } bool is_cos(expr const* n) const { return is_app_of(n, arith_family_id, OP_COS); } bool is_tan(expr const* n) const { return is_app_of(n, arith_family_id, OP_TAN); } @@ -387,13 +395,33 @@ class arith_util : public arith_recognizers { return *m_plugin; } - algebraic_numbers::manager & am() { + algebraic_numbers::manager & am() const { return plugin().am(); } + // return true if \c n is a term of the form (* -1 r) + bool is_zero(expr const* n) const { rational val; return is_numeral(n, val) && val.is_zero(); } + bool is_one(expr const* n) const{ rational val; return is_numeral(n, val) && val.is_one(); } + bool is_minus_one(expr* n) const { rational tmp; return is_numeral(n, tmp) && tmp.is_minus_one(); } + bool is_times_minus_one(expr* n, expr*& r) const { + if (is_mul(n) && to_app(n)->get_num_args() == 2 && is_minus_one(to_app(n)->get_arg(0))) { + r = to_app(n)->get_arg(1); + return true; + } + return false; + } + bool is_unsigned(expr const* n, unsigned& u) const { + rational val; + bool is_int = true; + return is_numeral(n, val, is_int) && is_int && val.is_unsigned() && (u = val.get_unsigned(), true); + } + bool is_numeral(expr const* n) const { return arith_recognizers::is_numeral(n); } + bool is_numeral(expr const* n, rational& val, bool& is_int) const; + bool is_numeral(expr const* n, rational& val) const { bool is_int; return is_numeral(n, val, is_int); } + bool convert_int_numerals_to_real() const { return plugin().convert_int_numerals_to_real(); } - bool is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val); - algebraic_numbers::anum const & to_irrational_algebraic_numeral(expr const * n); + bool is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val) const; + algebraic_numbers::anum const & to_irrational_algebraic_numeral(expr const * n) const; sort * mk_int() { return m_manager.mk_sort(arith_family_id, INT_SORT); } sort * mk_real() { return m_manager.mk_sort(arith_family_id, REAL_SORT); } @@ -471,6 +499,11 @@ class arith_util : public arith_recognizers { app * mk_power(expr* arg1, expr* arg2) { return m_manager.mk_app(arith_family_id, OP_POWER, arg1, arg2); } app * mk_power0(expr* arg1, expr* arg2) { return m_manager.mk_app(arith_family_id, OP_POWER0, arg1, arg2); } + app* mk_band(unsigned n, expr* arg1, expr* arg2) { parameter p(n); expr* args[2] = { arg1, arg2 }; return m_manager.mk_app(arith_family_id, OP_ARITH_BAND, 1, &p, 2, args); } + app* mk_shl(unsigned n, expr* arg1, expr* arg2) { parameter p(n); expr* args[2] = { arg1, arg2 }; return m_manager.mk_app(arith_family_id, OP_ARITH_SHL, 1, &p, 2, args); } + app* mk_ashr(unsigned n, expr* arg1, expr* arg2) { parameter p(n); expr* args[2] = { arg1, arg2 }; return m_manager.mk_app(arith_family_id, OP_ARITH_ASHR, 1, &p, 2, args); } + app* mk_lshr(unsigned n, expr* arg1, expr* arg2) { parameter p(n); expr* args[2] = { arg1, arg2 }; return m_manager.mk_app(arith_family_id, OP_ARITH_LSHR, 1, &p, 2, args); } + app * mk_sin(expr * arg) { return m_manager.mk_app(arith_family_id, OP_SIN, arg); } app * mk_cos(expr * arg) { return m_manager.mk_app(arith_family_id, OP_COS, arg); } app * mk_tan(expr * arg) { return m_manager.mk_app(arith_family_id, OP_TAN, arg); } @@ -498,11 +531,11 @@ class arith_util : public arith_recognizers { if none of them are numerals, then the left-hand-side has a smaller id than the right hand side. */ app * mk_eq(expr * lhs, expr * rhs) { - if (is_numeral(lhs) || (!is_numeral(rhs) && lhs->get_id() > rhs->get_id())) + if (arith_recognizers::is_numeral(lhs) || (!arith_recognizers::is_numeral(rhs) && lhs->get_id() > rhs->get_id())) std::swap(lhs, rhs); if (lhs == rhs) return m_manager.mk_true(); - if (is_numeral(lhs) && is_numeral(rhs)) { + if (arith_recognizers::is_numeral(lhs) && arith_recognizers::is_numeral(rhs)) { SASSERT(lhs != rhs); return m_manager.mk_false(); } @@ -517,6 +550,8 @@ class arith_util : public arith_recognizers { bool is_considered_uninterpreted(func_decl* f, unsigned n, expr* const* args, func_decl_ref& f_out); + bool is_considered_partially_interpreted(func_decl* f, unsigned n, expr* const* args, func_decl_ref& f_out); + bool is_underspecified(expr* e) const; bool is_bounded(expr* e) const; diff --git a/src/ast/array_decl_plugin.cpp b/src/ast/array_decl_plugin.cpp index 6778bec7cb5..bd9d954c774 100644 --- a/src/ast/array_decl_plugin.cpp +++ b/src/ast/array_decl_plugin.cpp @@ -49,7 +49,7 @@ sort * array_decl_plugin::mk_sort(decl_kind k, unsigned num_parameters, paramete m_manager->raise_exception("invalid array sort definition, invalid number of parameters"); return nullptr; } - parameter params[2] = { parameters[0], parameter(m_manager->mk_bool_sort()) }; + parameter params[2] = { parameter(parameters[0]), parameter(m_manager->mk_bool_sort()) }; return mk_sort(ARRAY_SORT, 2, params); } SASSERT(k == ARRAY_SORT); @@ -633,6 +633,12 @@ bool array_decl_plugin::is_value(app * _e) const { } } +bool array_decl_plugin::is_unique_value(app* _e) const { + array_util u(*m_manager); + expr* e = _e; + return u.is_const(e, e) && m_manager->is_unique_value(e); +} + func_decl * array_recognizers::get_as_array_func_decl(expr * n) const { SASSERT(is_as_array(n)); diff --git a/src/ast/array_decl_plugin.h b/src/ast/array_decl_plugin.h index 79c6e682e65..9a57476d491 100644 --- a/src/ast/array_decl_plugin.h +++ b/src/ast/array_decl_plugin.h @@ -137,6 +137,8 @@ class array_decl_plugin : public decl_plugin { bool is_value(app * e) const override; + bool is_unique_value(app* e) const override; + }; class array_recognizers { @@ -184,6 +186,21 @@ class array_recognizers { bool is_store_ext(expr* e, expr_ref& a, expr_ref_vector& args, expr_ref& value); + + bool is_select1(expr* n) const { return is_select(n) && to_app(n)->get_num_args() == 2; } + + bool is_select1(expr* n, expr*& a, expr*& i) const { + return is_select1(n) && (a = to_app(n)->get_arg(0), i = to_app(n)->get_arg(1), true); + } + + bool is_store1(expr* n) const { return is_store(n) && to_app(n)->get_num_args() == 3; } + + bool is_store1(expr* n, expr*& a, expr*& i, expr*& v) const { + app* _n; + return is_store1(n) && (_n = to_app(n), a = _n->get_arg(0), i = _n->get_arg(1), v = _n->get_arg(2), true); + } + + MATCH_BINARY(is_subset); }; @@ -211,6 +228,11 @@ class array_util : public array_recognizers { return mk_store(args.size(), args.data()); } + app * mk_select(expr* a, expr* i) const { + expr* args[2] = { a, i }; + return mk_select(2, args); + } + app * mk_select(unsigned num_args, expr * const * args) const { return m_manager.mk_app(m_fid, OP_SELECT, 0, nullptr, num_args, args); } diff --git a/src/ast/array_peq.cpp b/src/ast/array_peq.cpp new file mode 100644 index 00000000000..9f4f1b10dee --- /dev/null +++ b/src/ast/array_peq.cpp @@ -0,0 +1,107 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + array_peq.cpp + +Abstract: + + Partial equality for arrays + +Author: + + Nikolaj Bjorner (nbjorner) 2015-06-13 + Hari Govind V K + +Revision History: + +--*/ +#include "ast/array_peq.h" + +#define PARTIAL_EQ "!partial_eq" +bool is_partial_eq(const func_decl *f) { + SASSERT(f); + return f->get_name() == PARTIAL_EQ; +} + +bool is_partial_eq(const app *a) { + SASSERT(a); + return is_partial_eq(a->get_decl()); +} + +app_ref mk_peq(expr *e0, expr *e1, vector const &indices, + ast_manager &m) { + peq p(e0, e1, indices, m); + return p.mk_peq(); +} + +app_ref peq::mk_eq(app_ref_vector &aux_consts, bool stores_on_rhs) { + if (!m_eq) { + expr_ref lhs(m_lhs, m), rhs(m_rhs, m); + if (!stores_on_rhs) { std::swap(lhs, rhs); } + // lhs = (...(store (store rhs i0 v0) i1 v1)...) + sort *val_sort = get_array_range(lhs->get_sort()); + for (expr_ref_vector const &diff : m_diff_indices) { + ptr_vector store_args; + store_args.push_back(rhs); + store_args.append(diff.size(), diff.data()); + app_ref val(m.mk_fresh_const("diff", val_sort), m); + store_args.push_back(val); + aux_consts.push_back(val); + rhs = m_arr_u.mk_store(store_args); + } + m_eq = m.mk_eq(lhs, rhs); + } + return m_eq; +} + +app_ref peq::mk_peq() { + if (!m_peq) { + ptr_vector args; + args.push_back(m_lhs); + args.push_back(m_rhs); + for (auto const &v : m_diff_indices) { + args.append(v.size(), v.data()); + } + m_peq = m.mk_app(m_decl, args.size(), args.data()); + } + return m_peq; +} + +peq::peq(expr *lhs, expr *rhs, vector const &diff_indices, + ast_manager &m) + : m(m), m_lhs(lhs, m), m_rhs(rhs, m), m_diff_indices(diff_indices), + m_decl(m), m_peq(m), m_eq(m), m_arr_u(m) { + SASSERT(m_arr_u.is_array(lhs)); + SASSERT(m_arr_u.is_array(rhs)); + SASSERT(lhs->get_sort() == rhs->get_sort()); + ptr_vector sorts; + sorts.push_back(m_lhs->get_sort()); + sorts.push_back(m_rhs->get_sort()); + + for (auto const &v : diff_indices) { + SASSERT(v.size() == get_array_arity(m_lhs->get_sort())); + for (expr *e : v) sorts.push_back(e->get_sort()); + } + m_decl = m.mk_func_decl(symbol(PARTIAL_EQ), sorts.size(), sorts.data(), + m.mk_bool_sort()); +} + +peq::peq(app *p, ast_manager &m) + : m(m), m_lhs(p->get_arg(0), m), m_rhs(p->get_arg(1), m), + m_decl(p->get_decl(), m), m_peq(p, m), m_eq(m), m_arr_u(m), + m_name(symbol(PARTIAL_EQ)) { + SASSERT(is_partial_eq(p)); + + SASSERT(m_arr_u.is_array(m_lhs)); + SASSERT(m_arr_u.is_array(m_rhs)); + SASSERT(m_lhs->get_sort() == m_rhs->get_sort()); + unsigned arity = get_array_arity(m_lhs->get_sort()); + for (unsigned i = 2; i < p->get_num_args(); i += arity) { + SASSERT(arity + i <= p->get_num_args()); + expr_ref_vector vec(m); + vec.append(arity, p->get_args() + i); + m_diff_indices.push_back(std::move(vec)); + } +} diff --git a/src/ast/array_peq.h b/src/ast/array_peq.h new file mode 100644 index 00000000000..9e71791c145 --- /dev/null +++ b/src/ast/array_peq.h @@ -0,0 +1,91 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + array_peq.h + +Abstract: + + Partial equality for arrays + +Author: + + Nikolaj Bjorner (nbjorner) 2015-06-13 + Hari Govind V K + +Revision History: + +--*/ +#pragma once + +#include "ast/array_decl_plugin.h" +#include "ast/ast.h" + +/** + * \brief utility class for partial equalities + * + * A partial equality (a ==I b), for two arrays a, b and a finite set of indices + * I holds iff (forall i :: i \not\in I => a[i] == b[i]). In other words, peq is + * a restricted form of the extensionality axiom + * + * Using this class, we denote (a =I b) as f(a,b,i0,i1,...), + * where f is an uninterpreted predicate with the name PARTIAL_EQ and + * I = {i0,i1,...} + */ + +class peq { + ast_manager &m; + expr_ref m_lhs; + expr_ref m_rhs; + vector m_diff_indices; + func_decl_ref m_decl; // the partial equality declaration + app_ref m_peq; // partial equality application + app_ref m_eq; // equivalent std equality using def. of partial eq + array_util m_arr_u; + symbol m_name; + + public: + peq(app *p, ast_manager &m); + + peq(expr *lhs, expr *rhs, vector const &diff_indices, + ast_manager &m); + + expr_ref lhs() { return m_lhs; } + + expr_ref rhs() { return m_rhs; } + + void get_diff_indices(vector &result) { + result.append(m_diff_indices); + } + + /** Convert peq into a peq expression */ + app_ref mk_peq(); + + /** Convert peq into an equality + + For peq of the form (a =I b) returns (a = b[i0 := v0, i1 := v1, ...]) + where i0, i1 \in I, and v0, v1 are fresh skolem constants + + Skolems are returned in aux_consts + + The left and right hand arguments are reversed when stores_on_rhs is + false + */ + app_ref mk_eq(app_ref_vector &aux_consts, bool stores_on_rhs = true); +}; + +/** + * mk (e0 ==indices e1) + * + * result has stores if either e0 or e1 or an index term has stores + */ +app_ref mk_peq(expr *e0, expr *e1, vector const &indices, + ast_manager &m); + +bool is_partial_eq(const func_decl *f); + +bool is_partial_eq(const app *a); + +inline bool is_peq(const func_decl *f) { return is_partial_eq(f); } +inline bool is_peq(const app *a) { return is_partial_eq(a); } diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 7f9542fe48a..e1e3efe99ee 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -48,21 +48,13 @@ parameter::~parameter() { } } -parameter& parameter::operator=(parameter const& other) { - if (this == &other) { - return *this; - } - - this->~parameter(); - m_val = other.m_val; - +parameter::parameter(parameter const& other) : m_val(other.m_val) { if (auto p = std::get_if(&m_val)) { m_val = alloc(rational, **p); } if (auto p = std::get_if(&m_val)) { m_val = alloc(zstring, **p); } - return *this; } void parameter::init_eh(ast_manager & m) { @@ -202,8 +194,7 @@ unsigned decl_info::hash() const { bool decl_info::operator==(decl_info const & info) const { return m_family_id == info.m_family_id && m_kind == info.m_kind && - m_parameters.size() == info.m_parameters.size() && - compare_arrays(m_parameters.begin(), info.m_parameters.begin(), m_parameters.size()); + m_parameters == info.m_parameters; } std::ostream & operator<<(std::ostream & out, decl_info const & info) { @@ -255,7 +246,8 @@ func_decl_info::func_decl_info(family_id family_id, decl_kind k, unsigned num_pa m_injective(false), m_idempotent(false), m_skolem(false), - m_lambda(false) { + m_lambda(false), + m_polymorphic(false) { } bool func_decl_info::operator==(func_decl_info const & info) const { @@ -283,6 +275,7 @@ std::ostream & operator<<(std::ostream & out, func_decl_info const & info) { if (info.is_idempotent()) out << " :idempotent "; if (info.is_skolem()) out << " :skolem "; if (info.is_lambda()) out << " :lambda "; + if (info.is_polymorphic()) out << " :polymorphic "; return out; } @@ -318,26 +311,6 @@ func_decl::func_decl(symbol const & name, unsigned arity, sort * const * domain, // // ----------------------------------- -static app_flags mk_const_flags() { - app_flags r; - r.m_depth = 1; - r.m_ground = true; - r.m_has_quantifiers = false; - r.m_has_labels = false; - return r; -} - -static app_flags mk_default_app_flags() { - app_flags r; - r.m_depth = 1; - r.m_ground = true; - r.m_has_quantifiers = false; - r.m_has_labels = false; - return r; -} - -app_flags app::g_constant_flags = mk_const_flags(); - app::app(func_decl * decl, unsigned num_args, expr * const * args): expr(AST_APP), m_decl(decl), @@ -447,7 +420,7 @@ bool compare_nodes(ast const * n1, ast const * n2) { if (to_sort(n1)->get_info() != nullptr && !(*to_sort(n1)->get_info() == *to_sort(n2)->get_info())) { return false; } - return to_sort(n1)->get_name() == to_sort(n2)->get_name(); + return to_sort(n1)->get_name() == to_sort(n2)->get_name(); case AST_FUNC_DECL: if ((to_func_decl(n1)->get_info() == nullptr) != (to_func_decl(n2)->get_info() == nullptr)) { return false; @@ -477,15 +450,15 @@ bool compare_nodes(ast const * n1, ast const * n2) { return q1->get_kind() == q2->get_kind() && q1->get_num_decls() == q2->get_num_decls() && + q1->get_expr() == q2->get_expr() && + q1->get_weight() == q2->get_weight() && + q1->get_num_patterns() == q2->get_num_patterns() && compare_arrays(q1->get_decl_sorts(), q2->get_decl_sorts(), q1->get_num_decls()) && compare_arrays(q1->get_decl_names(), q2->get_decl_names(), q1->get_num_decls()) && - q1->get_expr() == q2->get_expr() && - q1->get_weight() == q2->get_weight() && - q1->get_num_patterns() == q2->get_num_patterns() && ((q1->get_qid().is_numerical() && q2->get_qid().is_numerical()) || (q1->get_qid() == q2->get_qid())) && compare_arrays(q1->get_patterns(), @@ -540,22 +513,6 @@ inline unsigned ast_array_hash(T * const * array, unsigned size, unsigned init_v } } } -unsigned get_asts_hash(unsigned sz, ast * const* ns, unsigned init) { - return ast_array_hash(ns, sz, init); -} -unsigned get_apps_hash(unsigned sz, app * const* ns, unsigned init) { - return ast_array_hash(ns, sz, init); -} -unsigned get_exprs_hash(unsigned sz, expr * const* ns, unsigned init) { - return ast_array_hash(ns, sz, init); -} -unsigned get_sorts_hash(unsigned sz, sort * const* ns, unsigned init) { - return ast_array_hash(ns, sz, init); -} -unsigned get_decl_hash(unsigned sz, func_decl* const* ns, unsigned init) { - return ast_array_hash(ns, sz, init); -} - unsigned get_node_hash(ast const * n) { unsigned a, b, c; @@ -1309,10 +1266,7 @@ ast_manager::ast_manager(proof_gen_mode m, char const * trace_file, bool is_form m_expr_array_manager(*this, m_alloc), m_expr_dependency_manager(*this, m_alloc), m_expr_dependency_array_manager(*this, m_alloc), - m_proof_mode(m), - m_trace_stream(nullptr), - m_trace_stream_owner(false), - m_lambda_def(":lambda-def") { + m_proof_mode(m) { if (trace_file) { m_trace_stream = alloc(std::fstream, trace_file, std::ios_base::out); @@ -1333,9 +1287,7 @@ ast_manager::ast_manager(proof_gen_mode m, std::fstream * trace_stream, bool is_ m_expr_dependency_manager(*this, m_alloc), m_expr_dependency_array_manager(*this, m_alloc), m_proof_mode(m), - m_trace_stream(trace_stream), - m_trace_stream_owner(false), - m_lambda_def(":lambda-def") { + m_trace_stream(trace_stream) { if (!is_format_manager) m_format_manager = alloc(ast_manager, PGM_DISABLED, trace_stream, true); @@ -1350,9 +1302,7 @@ ast_manager::ast_manager(ast_manager const & src, bool disable_proofs): m_expr_dependency_manager(*this, m_alloc), m_expr_dependency_array_manager(*this, m_alloc), m_proof_mode(disable_proofs ? PGM_DISABLED : src.m_proof_mode), - m_trace_stream(src.m_trace_stream), - m_trace_stream_owner(false), - m_lambda_def(":lambda-def") { + m_trace_stream(src.m_trace_stream) { SASSERT(!src.is_format_manager()); m_format_manager = alloc(ast_manager, PGM_DISABLED, m_trace_stream, true); init(); @@ -1378,6 +1328,7 @@ void ast_manager::init() { ENSURE(model_value_family_id == mk_family_id("model-value")); ENSURE(user_sort_family_id == mk_family_id("user-sort")); ENSURE(arith_family_id == mk_family_id("arith")); + ENSURE(poly_family_id == mk_family_id("polymorphic")); basic_decl_plugin * plugin = alloc(basic_decl_plugin); register_plugin(basic_family_id, plugin); m_bool_sort = plugin->mk_bool_sort(); @@ -1783,8 +1734,7 @@ ast * ast_manager::register_node_core(ast * n) { inc_ref(t->get_decl()); unsigned num_args = t->get_num_args(); if (num_args > 0) { - app_flags * f = t->flags(); - *f = mk_default_app_flags(); + app_flags * f = &t->m_flags; SASSERT(t->is_ground()); SASSERT(!t->has_quantifiers()); SASSERT(!t->has_labels()); @@ -1797,13 +1747,13 @@ ast * ast_manager::register_node_core(ast * n) { unsigned arg_depth = 0; switch (arg->get_kind()) { case AST_APP: { - app_flags * arg_flags = to_app(arg)->flags(); - arg_depth = arg_flags->m_depth; - if (arg_flags->m_has_quantifiers) + app *app = to_app(arg); + arg_depth = app->get_depth(); + if (app->has_quantifiers()) f->m_has_quantifiers = true; - if (arg_flags->m_has_labels) + if (app->has_labels()) f->m_has_labels = true; - if (!arg_flags->m_ground) + if (!app->is_ground()) f->m_ground = false; break; } @@ -1879,6 +1829,8 @@ void ast_manager::delete_node(ast * n) { break; case AST_FUNC_DECL: { func_decl* f = to_func_decl(n); + if (f->is_polymorphic()) + m_poly_roots.erase(f); if (f->m_info != nullptr) { func_decl_info * info = f->get_info(); if (info->is_lambda()) { @@ -2019,6 +1971,7 @@ sort * ast_manager::mk_uninterpreted_sort(symbol const & name, unsigned num_para return plugin->mk_sort(kind, num_parameters, parameters); } + func_decl * ast_manager::mk_func_decl(symbol const & name, unsigned arity, sort * const * domain, sort * range, bool assoc, bool comm, bool inj) { func_decl_info info(null_family_id, null_decl_kind); @@ -2029,13 +1982,30 @@ func_decl * ast_manager::mk_func_decl(symbol const & name, unsigned arity, sort } func_decl * ast_manager::mk_func_decl(symbol const & name, unsigned arity, sort * const * domain, sort * range, func_decl_info * info) { - SASSERT(arity == 1 || info == 0 || !info->is_injective()); - SASSERT(arity == 2 || info == 0 || !info->is_associative()); - SASSERT(arity == 2 || info == 0 || !info->is_commutative()); + SASSERT(arity == 1 || !info || !info->is_injective()); + SASSERT(arity == 2 || !info || !info->is_associative()); + SASSERT(arity == 2 || !info || !info->is_commutative()); unsigned sz = func_decl::get_obj_size(arity); void * mem = allocate_node(sz); - func_decl * new_node = new (mem) func_decl(name, arity, domain, range, info); - return register_node(new_node); + + // determine if function is a polymorphic root object. + // instances of polymorphic functions are automatically tagged as polymorphic and + // inserted into the m_poly_roots table. + bool is_polymorphic_root = false; + func_decl_info info0; + if (has_type_var(arity, domain, range)) { + if (!info) + info = &info0; + if (!info->is_polymorphic()) { + info->set_polymorphic(true); + is_polymorphic_root = true; + } + } + func_decl* new_node = new (mem) func_decl(name, arity, domain, range, info); + new_node = register_node(new_node); + if (is_polymorphic_root) + m_poly_roots.insert(new_node, new_node); + return new_node; } void ast_manager::check_sort(func_decl const * decl, unsigned num_args, expr * const * args) const { @@ -2300,9 +2270,10 @@ func_decl * ast_manager::mk_fresh_func_decl(symbol const & prefix, symbol const func_decl_info info(null_family_id, null_decl_kind); info.m_skolem = skolem; SASSERT(skolem == info.is_skolem()); + func_decl_info* infop = skolem ? &info : nullptr; func_decl * d; if (prefix == symbol::null && suffix == symbol::null) { - d = mk_func_decl(symbol(m_fresh_id), arity, domain, range, &info); + d = mk_func_decl(symbol(m_fresh_id), arity, domain, range, infop); } else { string_buffer<64> buffer; @@ -2314,10 +2285,10 @@ func_decl * ast_manager::mk_fresh_func_decl(symbol const & prefix, symbol const if (suffix != symbol::null) buffer << suffix << "!"; buffer << m_fresh_id; - d = mk_func_decl(symbol(buffer.c_str()), arity, domain, range, &info); + d = mk_func_decl(symbol(buffer.c_str()), arity, domain, range, infop); } m_fresh_id++; - SASSERT(d->get_info()); + SASSERT(!skolem || d->get_info()); SASSERT(skolem == d->is_skolem()); return d; } @@ -2719,6 +2690,49 @@ bool ast_manager::is_fully_interp(sort * s) const { return false; } +// ----------------------------------------- +// Polymorphism +// ----------------------------------------- +sort * ast_manager::mk_type_var(symbol const& name) { + m_has_type_vars = true; + sort_info si(poly_family_id, 0); + return mk_sort(name, &si); +} + +bool ast_manager::has_type_var(sort* s) const { + if (is_type_var(s)) + return true; + for (parameter const& p : s->parameters()) + if (p.is_ast() && is_sort(p.get_ast()) && has_type_var(to_sort(p.get_ast()))) + return true; + return false; +} + +bool ast_manager::has_type_var(func_decl* f) const { + return has_type_var(f->get_arity(), f->get_domain(), f->get_range()); +} + +bool ast_manager::has_type_var(unsigned n, sort* const* domain, sort* range) const { + if (!has_type_vars()) + return false; + for (unsigned i = n; i-- > 0; ) + if (has_type_var(domain[i])) + return true; + return has_type_var(range); +} + +/** + * \brief create an instantiation of polymorphic function f. + */ + +func_decl* ast_manager::instantiate_polymorphic(func_decl* f, unsigned arity, sort * const* domain, sort * range) { + SASSERT(f->is_polymorphic()); + func_decl* g = mk_func_decl(f->get_name(), arity, domain, range, f->get_info()); + m_poly_roots.insert(g, f); + // SASSERT(g->is_polymorphic()); + return g; +} + // ----------------------------------- // // Proof generation @@ -2841,29 +2855,40 @@ proof * ast_manager::mk_transitivity(proof * p1, proof * p2) { SASSERT(has_fact(p2)); SASSERT(is_app(get_fact(p1))); SASSERT(is_app(get_fact(p2))); - SASSERT(to_app(get_fact(p1))->get_num_args() == 2); - SASSERT(to_app(get_fact(p2))->get_num_args() == 2); - CTRACE("mk_transitivity", to_app(get_fact(p1))->get_decl() != to_app(get_fact(p2))->get_decl(), - tout << mk_pp(get_fact(p1), *this) << "\n\n" << mk_pp(get_fact(p2), *this) << "\n"; - tout << mk_pp(to_app(get_fact(p1))->get_decl(), *this) << "\n"; - tout << mk_pp(to_app(get_fact(p2))->get_decl(), *this) << "\n";); - SASSERT(to_app(get_fact(p1))->get_decl() == to_app(get_fact(p2))->get_decl() || - ( (is_eq(get_fact(p1)) || is_oeq(get_fact(p1))) && - (is_eq(get_fact(p2)) || is_oeq(get_fact(p2))))); - CTRACE("mk_transitivity", to_app(get_fact(p1))->get_arg(1) != to_app(get_fact(p2))->get_arg(0), - tout << mk_pp(get_fact(p1), *this) << "\n\n" << mk_pp(get_fact(p2), *this) << "\n"; + app* fact1 = to_app(get_fact(p1)); + app* fact2 = to_app(get_fact(p2)); + SASSERT(fact1->get_num_args() == 2); + SASSERT(fact2->get_num_args() == 2); + CTRACE("mk_transitivity", fact1->get_decl() != fact2->get_decl(), + tout << mk_pp(fact1, *this) << "\n\n" << mk_pp(fact2, *this) << "\n"; + tout << mk_pp(fact1->get_decl(), *this) << "\n"; + tout << mk_pp(fact2->get_decl(), *this) << "\n";); + SASSERT(fact1->get_decl() == fact2->get_decl() || + ( (is_eq(fact1) || is_oeq(fact1)) && + (is_eq(fact2) || is_oeq(fact2)))); + CTRACE("mk_transitivity", fact1->get_arg(1) != fact2->get_arg(0), + tout << mk_pp(fact1, *this) << "\n\n" << mk_pp(fact2, *this) << "\n"; tout << p1->get_id() << ": " << mk_bounded_pp(p1, *this, 5) << "\n\n"; tout << p2->get_id() << ": " << mk_bounded_pp(p2, *this, 5) << "\n\n"; ); - SASSERT(to_app(get_fact(p1))->get_arg(1) == to_app(get_fact(p2))->get_arg(0)); if (is_reflexivity(p1)) return p2; if (is_reflexivity(p2)) return p1; + // local fixup to admit inline simplifications of not(not(e)) to e + expr* e; + if (is_not(fact1->get_arg(1), e) && is_not(e, e) && e == fact2->get_arg(0)) + p1 = mk_transitivity(p1, mk_rewrite(fact1->get_arg(1), fact2->get_arg(0))); + else if (is_not(fact2->get_arg(0), e) && is_not(e, e) && e == fact1->get_arg(1)) + p1 = mk_transitivity(p1, mk_rewrite(fact1->get_arg(1), fact2->get_arg(0))); + else { + SASSERT(fact1->get_arg(1) == fact2->get_arg(0)); + } // OEQ is compatible with EQ for transitivity. - func_decl* f = to_app(get_fact(p1))->get_decl(); - if (is_oeq(get_fact(p2))) f = to_app(get_fact(p2))->get_decl(); - return mk_app(basic_family_id, PR_TRANSITIVITY, p1, p2, mk_app(f, to_app(get_fact(p1))->get_arg(0), to_app(get_fact(p2))->get_arg(1))); + func_decl* f = fact1->get_decl(); + if (is_oeq(fact2)) + f = fact2->get_decl(); + return mk_app(basic_family_id, PR_TRANSITIVITY, p1, p2, mk_app(f, fact1->get_arg(0), fact2->get_arg(1))); } diff --git a/src/ast/ast.h b/src/ast/ast.h index e0ae7b92f21..05339002273 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -85,6 +85,7 @@ const family_id user_sort_family_id = 4; const family_id last_builtin_family_id = 4; const family_id arith_family_id = 5; +const family_id poly_family_id = 6; // ----------------------------------- // @@ -141,7 +142,7 @@ class parameter { explicit parameter(const char *s): m_val(symbol(s)) {} explicit parameter(const std::string &s): m_val(symbol(s)) {} explicit parameter(unsigned ext_id, bool): m_val(ext_id) {} - parameter(parameter const& other) { *this = other; } + explicit parameter(parameter const& other); parameter(parameter && other) noexcept : m_val(std::move(other.m_val)) { other.m_val = 0; @@ -149,7 +150,10 @@ class parameter { ~parameter(); - parameter& operator=(parameter const& other); + parameter& operator=(parameter && other) { + std::swap(other.m_val, m_val); + return *this; + } kind_t get_kind() const { return static_cast(m_val.index()); } bool is_int() const { return get_kind() == PARAM_INT; } @@ -399,6 +403,7 @@ struct func_decl_info : public decl_info { bool m_idempotent:1; bool m_skolem:1; bool m_lambda:1; + bool m_polymorphic:1; func_decl_info(family_id family_id = null_family_id, decl_kind k = null_decl_kind, unsigned num_parameters = 0, parameter const * parameters = nullptr); @@ -413,6 +418,7 @@ struct func_decl_info : public decl_info { bool is_idempotent() const { return m_idempotent; } bool is_skolem() const { return m_skolem; } bool is_lambda() const { return m_lambda; } + bool is_polymorphic() const { return m_polymorphic; } void set_associative(bool flag = true) { m_left_assoc = flag; m_right_assoc = flag; } void set_left_associative(bool flag = true) { m_left_assoc = flag; } @@ -425,6 +431,7 @@ struct func_decl_info : public decl_info { void set_idempotent(bool flag = true) { m_idempotent = flag; } void set_skolem(bool flag = true) { m_skolem = flag; } void set_lambda(bool flag = true) { m_lambda = flag; } + void set_polymorphic(bool flag = true) { m_polymorphic = flag; } bool operator==(func_decl_info const & info) const; @@ -622,6 +629,7 @@ class sort : public decl { sort_size const & get_num_elements() const { return get_info()->get_num_elements(); } void set_num_elements(sort_size const& s) { get_info()->set_num_elements(s); } unsigned get_size() const { return get_obj_size(); } + bool is_type_var() const { return get_family_id() == poly_family_id; } }; // ----------------------------------- @@ -653,6 +661,7 @@ class func_decl : public decl { bool is_skolem() const { return get_info() != nullptr && get_info()->is_skolem(); } bool is_lambda() const { return get_info() != nullptr && get_info()->is_lambda(); } bool is_idempotent() const { return get_info() != nullptr && get_info()->is_idempotent(); } + bool is_polymorphic() const { return get_info() != nullptr && get_info()->is_polymorphic(); } unsigned get_arity() const { return m_arity; } sort * get_domain(unsigned idx) const { SASSERT(idx < get_arity()); return m_domain[idx]; } sort * const * get_domain() const { return m_domain; } @@ -698,6 +707,7 @@ struct app_flags { unsigned m_ground:1; // application does not have free variables or nested quantifiers. unsigned m_has_quantifiers:1; // application has nested quantifiers. unsigned m_has_labels:1; // application has nested labels. + app_flags() : m_depth(1), m_ground(1), m_has_quantifiers(0), m_has_labels(0) {} }; class app : public expr { @@ -705,19 +715,15 @@ class app : public expr { func_decl * m_decl; unsigned m_num_args; + app_flags m_flags; expr * m_args[0]; - static app_flags g_constant_flags; - - // remark: store term depth in the end of the app. the depth is only stored if the num_args > 0 static unsigned get_obj_size(unsigned num_args) { - return num_args == 0 ? sizeof(app) : sizeof(app) + num_args * sizeof(expr *) + sizeof(app_flags); + return sizeof(app) + num_args * sizeof(expr *); } friend class tmp_app; - app_flags * flags() const { return m_num_args == 0 ? &g_constant_flags : reinterpret_cast(const_cast(m_args + m_num_args)); } - app(func_decl * decl, unsigned num_args, expr * const * args); public: func_decl * get_decl() const { return m_decl; } @@ -738,10 +744,10 @@ class app : public expr { expr * const * end() const { return m_args + m_num_args; } sort * _get_sort() const { return get_decl()->get_range(); } - unsigned get_depth() const { return flags()->m_depth; } - bool is_ground() const { return flags()->m_ground; } - bool has_quantifiers() const { return flags()->m_has_quantifiers; } - bool has_labels() const { return flags()->m_has_labels; } + unsigned get_depth() const { return m_flags.m_depth; } + bool is_ground() const { return m_flags.m_ground; } + bool has_quantifiers() const { return m_flags.m_has_quantifiers; } + bool has_labels() const { return m_flags.m_has_labels; } }; // ----------------------------------- @@ -964,11 +970,6 @@ inline quantifier const * to_quantifier(ast const * n) { SASSERT(is_quantifier(n unsigned get_node_hash(ast const * n); bool compare_nodes(ast const * n1, ast const * n2); unsigned get_node_size(ast const * n); -unsigned get_asts_hash(unsigned sz, ast * const* ns, unsigned init); -unsigned get_apps_hash(unsigned sz, app * const* ns, unsigned init); -unsigned get_exprs_hash(unsigned sz, expr * const* ns, unsigned init); -unsigned get_sorts_hash(unsigned sz, sort * const* ns, unsigned init); -unsigned get_decl_hash(unsigned sz, func_decl* const* ns, unsigned init); // This is the internal comparison functor for hash-consing AST nodes. struct ast_eq_proc { @@ -1101,7 +1102,7 @@ class decl_plugin { // Event handlers for deleting/translating PARAM_EXTERNAL virtual void del(parameter const & p) {} - virtual parameter translate(parameter const & p, decl_plugin & target) { UNREACHABLE(); return p; } + virtual parameter translate(parameter const & p, decl_plugin & target) { UNREACHABLE(); return {}; } virtual bool is_considered_uninterpreted(func_decl * f) { return false; } }; @@ -1509,13 +1510,15 @@ class ast_manager { unsigned m_fresh_id; bool m_debug_ref_count; u_map m_debug_free_indices; - std::fstream* m_trace_stream; - bool m_trace_stream_owner; + std::fstream* m_trace_stream = nullptr; + bool m_trace_stream_owner = false; + bool m_has_type_vars = false; #ifdef Z3DEBUG bool slow_not_contains(ast const * n); #endif ast_manager * m_format_manager; // hack for isolating format objects in a different manager. - symbol m_lambda_def; + symbol m_lambda_def = symbol(":lambda-def"); + obj_map m_poly_roots; void init(); @@ -1709,6 +1712,8 @@ class ast_manager { sort * mk_uninterpreted_sort(symbol const & name) { return mk_uninterpreted_sort(name, 0, nullptr); } + sort * mk_type_var(symbol const& name); + sort * mk_sort(symbol const & name, sort_info const & info) { if (info.get_family_id() == null_family_id) { return mk_uninterpreted_sort(name); @@ -1730,12 +1735,27 @@ class ast_manager { bool is_uninterp(sort const * s) const { return s->get_family_id() == null_family_id || s->get_family_id() == user_sort_family_id; } + bool is_type_var(sort const* s) const { return s->get_family_id() == poly_family_id; } + + bool has_type_vars() const { return m_has_type_vars; } + + func_decl* poly_root(func_decl* f) const { SASSERT(f->is_polymorphic()); return m_poly_roots[f]; } + + + func_decl* instantiate_polymorphic(func_decl* f, unsigned arity, sort * const* domain, sort * range); + /** \brief A sort is "fully" interpreted if it is interpreted, and doesn't depend on other uninterpreted sorts. */ bool is_fully_interp(sort * s) const; + bool has_type_var(sort* s) const; + + bool has_type_var(func_decl* f) const; + + bool has_type_var(unsigned n, sort* const* domain, sort* range) const; + func_decl * mk_func_decl(family_id fid, decl_kind k, unsigned num_parameters, parameter const * parameters, unsigned arity, sort * const * domain, sort * range = nullptr); diff --git a/src/ast/ast_lt.cpp b/src/ast/ast_lt.cpp index 3537dc71e55..869c7bff8db 100644 --- a/src/ast/ast_lt.cpp +++ b/src/ast/ast_lt.cpp @@ -68,8 +68,8 @@ bool lt(ast * n1, ast * n2) { num = to_sort(n1)->get_num_parameters(); SASSERT(num > 0); for (unsigned i = 0; i < num; i++) { - parameter p1 = to_sort(n1)->get_parameter(i); - parameter p2 = to_sort(n2)->get_parameter(i); + const parameter &p1 = to_sort(n1)->get_parameter(i); + const parameter &p2 = to_sort(n2)->get_parameter(i); check_parameter(p1, p2); } UNREACHABLE(); @@ -80,8 +80,8 @@ bool lt(ast * n1, ast * n2) { check_value(to_func_decl(n1)->get_num_parameters(), to_func_decl(n2)->get_num_parameters()); num = to_func_decl(n1)->get_num_parameters(); for (unsigned i = 0; i < num; i++) { - parameter p1 = to_func_decl(n1)->get_parameter(i); - parameter p2 = to_func_decl(n2)->get_parameter(i); + const parameter &p1 = to_func_decl(n1)->get_parameter(i); + const parameter &p2 = to_func_decl(n2)->get_parameter(i); check_parameter(p1, p2); } num = to_func_decl(n1)->get_arity(); diff --git a/src/ast/ast_pp.h b/src/ast/ast_pp.h index 7ccb8ec155b..1f20ce300db 100644 --- a/src/ast/ast_pp.h +++ b/src/ast/ast_pp.h @@ -58,13 +58,13 @@ inline std::ostream& operator<<(std::ostream & out, mk_pp_vec const & pp) { inline std::string operator+(char const* s, mk_pp const& pp) { std::ostringstream strm; strm << s << pp; - return strm.str(); + return std::move(strm).str(); } inline std::string operator+(std::string const& s, mk_pp const& pp) { std::ostringstream strm; strm << s << pp; - return strm.str(); + return std::move(strm).str(); } inline std::string& operator+=(std::string& s, mk_pp const& pp) { diff --git a/src/ast/ast_translation.cpp b/src/ast/ast_translation.cpp index 781593b38bc..e2369a35ad7 100644 --- a/src/ast/ast_translation.cpp +++ b/src/ast/ast_translation.cpp @@ -65,6 +65,13 @@ void ast_translation::collect_decl_extra_children(decl * d) { } void ast_translation::push_frame(ast * n) { + // ensure poly roots are pushed first. + if (m_from_manager.has_type_vars() && n->get_kind() == AST_FUNC_DECL && to_func_decl(n)->is_polymorphic()) { + func_decl* g = m_from_manager.poly_root(to_func_decl(n)); + if (n != g && m_cache.contains(g)) { + m_frame_stack.push_back(frame(n, 0, m_extra_children_stack.size(), m_result_stack.size())); + } + } m_frame_stack.push_back(frame(n, 0, m_extra_children_stack.size(), m_result_stack.size())); switch (n->get_kind()) { case AST_SORT: @@ -153,6 +160,10 @@ void ast_translation::mk_func_decl(func_decl * f, frame & fr) { new_domain, new_range); } + else if (f->is_polymorphic() && m_from_manager.poly_root(f) != f) { + func_decl* fr = to_func_decl(m_cache[m_from_manager.poly_root(f)]); + new_f = m_to_manager.instantiate_polymorphic(fr, f->get_arity(), new_domain, new_range); + } else { buffer ps; copy_params(f, fr.m_rpos, ps); diff --git a/src/ast/bv_decl_plugin.cpp b/src/ast/bv_decl_plugin.cpp index 4d38327a5e7..5dd9f60806f 100644 --- a/src/ast/bv_decl_plugin.cpp +++ b/src/ast/bv_decl_plugin.cpp @@ -454,9 +454,8 @@ func_decl * bv_decl_plugin::mk_num_decl(unsigned num_parameters, parameter const // This cannot be enforced now, since some Z3 modules try to generate these invalid numerals. // After SMT-COMP, I should find all offending modules. // For now, I will just simplify the numeral here. - rational v = parameters[0].get_rational(); - parameter p0(mod2k(v, bv_size)); - parameter ps[2] = { std::move(p0), parameters[1] }; + const rational &v = parameters[0].get_rational(); + parameter ps[2] = { parameter(mod2k(v, bv_size)), parameter(parameters[1]) }; sort * bv = get_bv_sort(bv_size); return m_manager->mk_const_decl(m_bv_sym, bv, func_decl_info(m_family_id, OP_BV_NUM, num_parameters, ps)); } @@ -651,7 +650,7 @@ func_decl * bv_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, p for (unsigned i = 0; i < num_args; ++i) { if (args[i]->get_sort() != r->get_domain(i)) { std::ostringstream buffer; - buffer << "Argument " << mk_pp(args[i], m) << " at position " << i << " has sort " << mk_pp(args[i]->get_sort(), m) << " it does does not match declaration " << mk_pp(r, m); + buffer << "Argument " << mk_pp(args[i], m) << " at position " << i << " has sort " << mk_pp(args[i]->get_sort(), m) << " it does not match declaration " << mk_pp(r, m); m.raise_exception(buffer.str()); return nullptr; } @@ -913,13 +912,9 @@ app * bv_util::mk_numeral(rational const & val, unsigned bv_size) const { if (m_plugin->log_constant_meaning_prelude(r)) { if (bv_size % 4 == 0) { - m_manager.trace_stream() << "#x"; - val.display_hex(m_manager.trace_stream(), bv_size); - m_manager.trace_stream() << "\n"; + m_manager.trace_stream() << "#x" << val.as_hex(bv_size) << "\n"; } else { - m_manager.trace_stream() << "#b"; - val.display_bin(m_manager.trace_stream(), bv_size); - m_manager.trace_stream() << "\n"; + m_manager.trace_stream() << "#b" << val.as_bin(bv_size) << "\n"; } } @@ -942,3 +937,18 @@ app * bv_util::mk_bv2int(expr* e) { parameter p(s); return m_manager.mk_app(get_fid(), OP_BV2INT, 1, &p, 1, &e); } + +app* bv_util::mk_int2bv(unsigned sz, expr* e) { + parameter p(sz); + return m_manager.mk_app(get_fid(), OP_INT2BV, 1, &p, 1, &e); +} + +app* bv_util::mk_bv_rotate_left(expr* arg, unsigned n) { + parameter p(n); + return m_manager.mk_app(get_fid(), OP_ROTATE_LEFT, 1, &p, 1, &arg); +} + +app* bv_util::mk_bv_rotate_right(expr* arg, unsigned n) { + parameter p(n); + return m_manager.mk_app(get_fid(), OP_ROTATE_RIGHT, 1, &p, 1, &arg); +} \ No newline at end of file diff --git a/src/ast/bv_decl_plugin.h b/src/ast/bv_decl_plugin.h index 51faca7edab..137dc754f59 100644 --- a/src/ast/bv_decl_plugin.h +++ b/src/ast/bv_decl_plugin.h @@ -96,7 +96,7 @@ enum bv_op_kind { OP_BUMUL_OVFL, // unsigned multiplication overflow predicate (negation of OP_BUMUL_NO_OVFL) OP_BSMUL_OVFL, // signed multiplication over/underflow predicate - OP_BSDIV_OVFL, // signed division overflow perdicate + OP_BSDIV_OVFL, // signed division overflow predicate OP_BNEG_OVFL, // negation overflow predicate @@ -386,9 +386,31 @@ class bv_recognizers { bool is_bv_shl(expr const * e) const { return is_app_of(e, get_fid(), OP_BSHL); } bool is_sign_ext(expr const * e) const { return is_app_of(e, get_fid(), OP_SIGN_EXT); } bool is_bv_umul_no_ovfl(expr const* e) const { return is_app_of(e, get_fid(), OP_BUMUL_NO_OVFL); } + bool is_redand(expr const* e) const { return is_app_of(e, get_fid(), OP_BREDAND); } + bool is_redor(expr const* e) const { return is_app_of(e, get_fid(), OP_BREDOR); } + bool is_comp(expr const* e) const { return is_app_of(e, get_fid(), OP_BCOMP); } + bool is_rotate_left(expr const* e) const { return is_app_of(e, get_fid(), OP_ROTATE_LEFT); } + bool is_rotate_right(expr const* e) const { return is_app_of(e, get_fid(), OP_ROTATE_RIGHT); } + bool is_ext_rotate_left(expr const* e) const { return is_app_of(e, get_fid(), OP_EXT_ROTATE_LEFT); } + bool is_ext_rotate_right(expr const* e) const { return is_app_of(e, get_fid(), OP_EXT_ROTATE_RIGHT); } + + bool is_rotate_left(expr const* e, unsigned& n, expr*& x) const { + return is_rotate_left(e) && (n = to_app(e)->get_parameter(0).get_int(), x = to_app(e)->get_arg(0), true); + } + bool is_rotate_right(expr const* e, unsigned& n, expr*& x) const { + return is_rotate_right(e) && (n = to_app(e)->get_parameter(0).get_int(), x = to_app(e)->get_arg(0), true); + } + bool is_int2bv(expr const* e, unsigned& n, expr*& x) const { + return is_int2bv(e) && (n = to_app(e)->get_parameter(0).get_int(), x = to_app(e)->get_arg(0), true); + } MATCH_UNARY(is_bv_not); + MATCH_UNARY(is_redand); + MATCH_UNARY(is_redor); + MATCH_BINARY(is_ext_rotate_left); + MATCH_BINARY(is_ext_rotate_right); + MATCH_BINARY(is_comp); MATCH_BINARY(is_bv_add); MATCH_BINARY(is_bv_sub); MATCH_BINARY(is_bv_mul); @@ -411,12 +433,23 @@ class bv_recognizers { MATCH_BINARY(is_bv_sdiv); MATCH_BINARY(is_bv_udiv); MATCH_BINARY(is_bv_smod); + MATCH_BINARY(is_bv_and); + MATCH_BINARY(is_bv_or); + MATCH_BINARY(is_bv_xor); + MATCH_BINARY(is_bv_nand); + MATCH_BINARY(is_bv_nor); + MATCH_BINARY(is_bv_uremi); MATCH_BINARY(is_bv_sremi); MATCH_BINARY(is_bv_sdivi); MATCH_BINARY(is_bv_udivi); MATCH_BINARY(is_bv_smodi); + MATCH_BINARY(is_bv_urem0); + MATCH_BINARY(is_bv_srem0); + MATCH_BINARY(is_bv_sdiv0); + MATCH_BINARY(is_bv_udiv0); + MATCH_BINARY(is_bv_smod0); MATCH_UNARY(is_bit2bool); MATCH_UNARY(is_int2bv); bool is_bit2bool(expr* e, expr*& bv, unsigned& idx) const; @@ -516,17 +549,23 @@ class bv_util : public bv_recognizers { app * mk_bv_lshr(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_BLSHR, arg1, arg2); } app * mk_bv2int(expr* e); + app * mk_int2bv(unsigned sz, expr* e); + + app* mk_bv_rotate_left(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_EXT_ROTATE_LEFT, arg1, arg2); } + app* mk_bv_rotate_right(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_EXT_ROTATE_RIGHT, arg1, arg2); } + app* mk_bv_rotate_left(expr* arg, unsigned n); + app* mk_bv_rotate_right(expr* arg, unsigned n); // TODO: all these binary ops commute (right?) but it'd be more logical to swap `n` & `m` in the `return` - app * mk_bvsmul_no_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_NO_OVFL, n, m); } - app * mk_bvsmul_no_udfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_NO_UDFL, n, m); } - app * mk_bvumul_no_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUMUL_NO_OVFL, n, m); } - app * mk_bvsmul_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_OVFL, n, m); } - app * mk_bvumul_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUMUL_OVFL, n, m); } + app * mk_bvsmul_no_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_NO_OVFL, m, n); } + app * mk_bvsmul_no_udfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_NO_UDFL, m, n); } + app * mk_bvumul_no_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUMUL_NO_OVFL, m, n); } + app * mk_bvsmul_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_OVFL, m, n); } + app * mk_bvumul_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUMUL_OVFL, m, n); } app * mk_bvsdiv_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSDIV_OVFL, m, n); } app * mk_bvneg_ovfl(expr* m) { return m_manager.mk_app(get_fid(), OP_BNEG_OVFL, m); } - app * mk_bvuadd_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUADD_OVFL, n, m); } - app * mk_bvsadd_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSADD_OVFL, n, m); } + app * mk_bvuadd_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUADD_OVFL, m, n); } + app * mk_bvsadd_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSADD_OVFL, m, n); } app * mk_bvusub_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUSUB_OVFL, m, n); } app * mk_bvssub_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSSUB_OVFL, m, n); } diff --git a/src/ast/char_decl_plugin.cpp b/src/ast/char_decl_plugin.cpp index 029312ba326..6a94db12148 100644 --- a/src/ast/char_decl_plugin.cpp +++ b/src/ast/char_decl_plugin.cpp @@ -7,7 +7,7 @@ Module Name: Abstract: - char_plugin for unicode suppport + char_plugin for unicode support Author: diff --git a/src/ast/char_decl_plugin.h b/src/ast/char_decl_plugin.h index 3d934ffe403..686b7105fed 100644 --- a/src/ast/char_decl_plugin.h +++ b/src/ast/char_decl_plugin.h @@ -7,7 +7,7 @@ Module Name: Abstract: - char_plugin for unicode suppport + char_plugin for unicode support Author: diff --git a/src/ast/converters/expr_inverter.cpp b/src/ast/converters/expr_inverter.cpp index 0ee3e130d30..a06d125a5e2 100644 --- a/src/ast/converters/expr_inverter.cpp +++ b/src/ast/converters/expr_inverter.cpp @@ -78,7 +78,7 @@ class basic_expr_inverter : public iexpr_inverter { * * x = t -> fresh * x := if(fresh, t, diff(t)) - * where diff is a diagnonalization function available in domains of size > 1. + * where diff is a diagonalization function available in domains of size > 1. * */ @@ -400,6 +400,7 @@ class bv_expr_inverter : public iexpr_inverter { } bool process_concat(func_decl* f, unsigned num, expr* const* args, expr_ref& r) { +// return false; if (num == 0) return false; if (!uncnstr(num, args)) @@ -807,7 +808,7 @@ bool iexpr_inverter::uncnstr(unsigned num, expr * const * args) const { /** \brief Create a fresh variable for abstracting (f args[0] ... args[num-1]) - Return true if it a new variable was created, and false if the variable already existed for this + Return true if a new variable was created, and false if the variable already existed for this application. Store the variable in v */ void iexpr_inverter::mk_fresh_uncnstr_var_for(sort * s, expr_ref & v) { diff --git a/src/ast/converters/generic_model_converter.cpp b/src/ast/converters/generic_model_converter.cpp index 1e81f913149..c50d86cae6d 100644 --- a/src/ast/converters/generic_model_converter.cpp +++ b/src/ast/converters/generic_model_converter.cpp @@ -43,6 +43,7 @@ void generic_model_converter::operator()(model_ref & md) { expr_ref val(m); unsigned arity; bool reset_ev = false; + obj_map> uninterpreted; for (unsigned i = m_entries.size(); i-- > 0; ) { entry const& e = m_entries[i]; switch (e.m_instruction) { @@ -63,6 +64,13 @@ void generic_model_converter::operator()(model_ref & md) { reset_ev = old_val != nullptr; md->register_decl(e.m_f, val); } + // corner case when uninterpreted constants are eliminated + sort* s = e.m_f->get_range(); + if (m.is_uninterp(s) && !md->has_uninterpreted_sort(s)) { + uninterpreted.insert_if_not_there(s, {}); + if (!uninterpreted[s].contains(val)) + uninterpreted[s].push_back(val); + } } else { func_interp * old_val = md->get_func_interp(e.m_f); @@ -84,6 +92,9 @@ void generic_model_converter::operator()(model_ref & md) { break; } } + for (auto const& [s, u] : uninterpreted) { + md->register_usort(s, u.size(), u.data()); + } TRACE("model_converter", tout << "after generic_model_converter\n"; model_v2_pp(tout, *md);); } diff --git a/src/ast/converters/generic_model_converter.h b/src/ast/converters/generic_model_converter.h index 8a1c6234774..0bc6b21b407 100644 --- a/src/ast/converters/generic_model_converter.h +++ b/src/ast/converters/generic_model_converter.h @@ -68,6 +68,8 @@ class generic_model_converter : public model_converter { void get_units(obj_map& units) override; vector const& entries() const { return m_entries; } + + void reset() { m_entries.reset(); } }; typedef ref generic_model_converter_ref; diff --git a/src/ast/converters/model_converter.cpp b/src/ast/converters/model_converter.cpp index 716970cba2d..d053394cac9 100644 --- a/src/ast/converters/model_converter.cpp +++ b/src/ast/converters/model_converter.cpp @@ -24,7 +24,8 @@ Module Name: * Add or overwrite value in model. */ void model_converter::display_add(std::ostream& out, smt2_pp_environment& env, ast_manager& m, func_decl* f, expr* e) { - VERIFY(e); + if (!e) + return; VERIFY(f->get_range() == e->get_sort()); ast_smt2_pp_rev(out, f, e, env, params_ref(), 0, "model-add") << "\n"; } diff --git a/src/ast/datatype_decl_plugin.cpp b/src/ast/datatype_decl_plugin.cpp index d0214d44f79..a4ddbdaed13 100644 --- a/src/ast/datatype_decl_plugin.cpp +++ b/src/ast/datatype_decl_plugin.cpp @@ -275,7 +275,7 @@ namespace datatype { } parameter const & name = parameters[0]; if (!name.is_symbol()) { - TRACE("datatype", tout << "expected symol parameter at position " << 0 << " got: " << name << "\n";); + TRACE("datatype", tout << "expected symbol parameter at position " << 0 << " got: " << name << "\n";); throw invalid_datatype(); } for (unsigned i = 1; i < num_parameters; ++i) { diff --git a/src/ast/datatype_decl_plugin.h b/src/ast/datatype_decl_plugin.h index 7229636cb40..341f3669b50 100644 --- a/src/ast/datatype_decl_plugin.h +++ b/src/ast/datatype_decl_plugin.h @@ -52,7 +52,7 @@ namespace datatype { class accessor { symbol m_name; sort_ref m_range; - unsigned m_index; // reference to recursive data-type may only get resolved after all mutually recursive data-types are procssed. + unsigned m_index; // reference to recursive data-type may only get resolved after all mutually recursive data-types are processed. constructor* m_constructor{ nullptr }; public: accessor(ast_manager& m, symbol const& n, sort* range): diff --git a/src/ast/euf/CMakeLists.txt b/src/ast/euf/CMakeLists.txt index 8d3fa2e74db..aa71e7fba08 100644 --- a/src/ast/euf/CMakeLists.txt +++ b/src/ast/euf/CMakeLists.txt @@ -1,8 +1,14 @@ z3_add_component(euf SOURCES + euf_ac_plugin.cpp + euf_arith_plugin.cpp + euf_bv_plugin.cpp + euf_egraph.cpp euf_enode.cpp euf_etable.cpp - euf_egraph.cpp + euf_justification.cpp + euf_plugin.cpp + euf_specrel_plugin.cpp COMPONENT_DEPENDENCIES ast util diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp new file mode 100644 index 00000000000..174ef363b21 --- /dev/null +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -0,0 +1,1058 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_ac_plugin.cpp + +Abstract: + + plugin structure for ac functions + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +Completion modulo AC + + E set of eqs + pick critical pair xy = z by j1 xu = v by j2 in E + Add new equation zu = xyu = vy by j1, j2 + + + Notes: + - Some equalities come from shared terms, some do not. + + - V2 can use multiplicities of elements to handle larger domains. + - e.g. 3x + 100000y + +More notes: + + Justifications for new equations are joined (requires extension to egraph/justification) + + Process new merges so use list is updated + Justifications for processed merges are recorded + + Updated equations are recorded for restoration on backtracking + + Keep track of foreign / shared occurrences of AC functions. + - use register_shared to accumulate shared occurrences. + + Shared occurrences are rewritten modulo completion. + When equal to a different shared occurrence, propagate equality. + + - Elimination of redundant rules. + -> forward and backward subsumption + - apply forward subsumption when simplifying equality using processed + - apply backward subsumption when simplifying processed and to_simplify + + Rewrite rules are reoriented after a merge of enodes. + It simulates creating a critical pair: + n -> n' + n + k = j + k + after merge + n' + k = j + k, could be that n' + k < j + k < n + k in term ordering because n' < j, m < n + +TODOs: + +- Efficiency of handling shared terms. + - The shared terms hash table is not incremental. + It could be made incremental by updating it on every merge similar to how the egraph handles it. +- V2 using multiplicities instead of repeated values in monomials. +- Squash trail updates when equations or monomials are modified within the same epoch. + - by an epoch counter that can be updated by the egraph class whenever there is a push/pop. + - store the epoch as a tick on equations and possibly when updating monomials on equations. + +--*/ + +#include "ast/euf/euf_ac_plugin.h" +#include "ast/euf/euf_egraph.h" +#include "ast/ast_pp.h" + +namespace euf { + + ac_plugin::ac_plugin(egraph& g, unsigned fid, unsigned op) : + plugin(g), m_fid(fid), m_op(op), + m_dep_manager(get_region()), + m_hash(*this), m_eq(*this), m_monomial_table(m_hash, m_eq) + { + g.set_th_propagates_diseqs(m_fid); + } + + ac_plugin::ac_plugin(egraph& g, func_decl* f) : + plugin(g), m_fid(f->get_family_id()), m_decl(f), + m_dep_manager(get_region()), + m_hash(*this), m_eq(*this), m_monomial_table(m_hash, m_eq) + { + if (m_fid != null_family_id) + g.set_th_propagates_diseqs(m_fid); + } + + void ac_plugin::register_node(enode* n) { + if (is_op(n)) + return; + for (auto arg : enode_args(n)) + if (is_op(arg)) + register_shared(arg); // TODO optimization to avoid registering shared terms twice + } + + void ac_plugin::register_shared(enode* n) { + if (m_shared_nodes.get(n->get_id(), false)) + return; + auto m = to_monomial(n); + auto const& ns = monomial(m); + for (auto arg : ns) { + arg->shared.push_back(m); + m_node_trail.push_back(arg); + push_undo(is_add_shared_index); + } + m_shared_nodes.setx(n->get_id(), true, false); + sort(monomial(m)); + m_shared_todo.insert(m_shared.size()); + m_shared.push_back({ n, m, justification::axiom(get_id()) }); + push_undo(is_register_shared); + } + + void ac_plugin::undo() { + auto k = m_undo.back(); + m_undo.pop_back(); + switch (k) { + case is_add_eq: { + m_eqs.pop_back(); + break; + } + case is_add_node: { + auto* n = m_node_trail.back(); + m_node_trail.pop_back(); + m_nodes[n->n->get_id()] = nullptr; + n->~node(); + break; + } + case is_add_monomial: { + m_monomials.pop_back(); + break; + } + case is_merge_node: { + auto [other, old_shared, old_eqs] = m_merge_trail.back(); + auto* root = other->root; + std::swap(other->next, root->next); + root->shared.shrink(old_shared); + root->eqs.shrink(old_eqs); + m_merge_trail.pop_back(); + ++m_tick; + break; + } + case is_update_eq: { + auto const& [idx, eq] = m_update_eq_trail.back(); + m_eqs[idx] = eq; + m_update_eq_trail.pop_back(); + break; + } + case is_add_shared_index: { + auto n = m_node_trail.back(); + m_node_trail.pop_back(); + n->shared.pop_back(); + break; + } + case is_add_eq_index: { + auto n = m_node_trail.back(); + m_node_trail.pop_back(); + n->eqs.pop_back(); + break; + } + case is_register_shared: { + auto s = m_shared.back(); + m_shared_nodes[s.n->get_id()] = false; + m_shared.pop_back(); + break; + } + case is_update_shared: { + auto [id, s] = m_update_shared_trail.back(); + m_shared[id] = s; + m_update_shared_trail.pop_back(); + break; + } + default: + UNREACHABLE(); + } + } + + std::ostream& ac_plugin::display_monomial(std::ostream& out, ptr_vector const& m) const { + for (auto n : m) { + if (n->n->num_args() == 0) + out << mk_pp(n->n->get_expr(), g.get_manager()) << " "; + else + out << g.bpp(n->n) << " "; + } + return out; + } + + std::ostream& ac_plugin::display_equation(std::ostream& out, eq const& e) const { + display_status(out, e.status) << " "; + display_monomial(out, monomial(e.l)); + out << "== "; + display_monomial(out, monomial(e.r)); + return out; + } + + std::ostream& ac_plugin::display_status(std::ostream& out, eq_status s) const { + switch (s) { + case eq_status::is_dead: out << "d"; break; + case eq_status::processed: out << "p"; break; + case eq_status::to_simplify: out << "s"; break; + } + return out; + } + + std::ostream& ac_plugin::display(std::ostream& out) const { + unsigned i = 0; + for (auto const& eq : m_eqs) { + out << i << ": " << eq.l << " == " << eq.r << ": "; + display_equation(out, eq); + out << "\n"; + ++i; + } + i = 0; + for (auto m : m_monomials) { + out << i << ": "; + display_monomial(out, m); + out << "\n"; + ++i; + } + for (auto n : m_nodes) { + if (!n) + continue; + if (n->eqs.empty() && n->shared.empty()) + continue; + out << g.bpp(n->n) << " r: " << n->root_id() << " "; + if (!n->eqs.empty()) { + out << "eqs "; + for (auto l : n->eqs) + out << l << " "; + } + if (!n->shared.empty()) { + out << "shared "; + for (auto s : n->shared) + out << s << " "; + } + out << "\n"; + } + return out; + } + + void ac_plugin::merge_eh(enode* l, enode* r) { + if (l == r) + return; + auto j = justification::equality(l, r); + if (!is_op(l) && !is_op(r)) + merge(mk_node(l), mk_node(r), j); + else + init_equation(eq(to_monomial(l), to_monomial(r), j)); + } + + void ac_plugin::diseq_eh(enode* eq) { + SASSERT(g.get_manager().is_eq(eq->get_expr())); + enode* a = eq->get_arg(0), * b = eq->get_arg(1); + a = a->get_closest_th_node(m_fid); + b = b->get_closest_th_node(m_fid); + SASSERT(a && b); + register_shared(a); + register_shared(b); + } + + void ac_plugin::init_equation(eq const& e) { + m_eqs.push_back(e); + auto& eq = m_eqs.back(); + if (orient_equation(eq)) { + + unsigned eq_id = m_eqs.size() - 1; + + for (auto n : monomial(eq.l)) { + if (!n->root->n->is_marked1()) { + n->root->eqs.push_back(eq_id); + n->root->n->mark1(); + push_undo(is_add_eq_index); + m_node_trail.push_back(n->root); + } + } + + for (auto n : monomial(eq.r)) { + if (!n->root->n->is_marked1()) { + n->root->eqs.push_back(eq_id); + n->root->n->mark1(); + push_undo(is_add_eq_index); + m_node_trail.push_back(n->root); + } + } + + for (auto n : monomial(eq.l)) + n->root->n->unmark1(); + + for (auto n : monomial(eq.r)) + n->root->n->unmark1(); + + m_to_simplify_todo.insert(eq_id); + } + else + m_eqs.pop_back(); + } + + bool ac_plugin::orient_equation(eq& e) { + auto& ml = monomial(e.l); + auto& mr = monomial(e.r); + if (ml.size() > mr.size()) + return true; + if (ml.size() < mr.size()) { + std::swap(e.l, e.r); + return true; + } + else { + sort(ml); + sort(mr); + for (unsigned i = ml.size(); i-- > 0;) { + if (ml[i]->root_id() == mr[i]->root_id()) + continue; + if (ml[i]->root_id() < mr[i]->root_id()) + std::swap(e.l, e.r); + return true; + } + return false; + } + } + + void ac_plugin::sort(monomial_t& m) { + std::sort(m.begin(), m.end(), [&](node* a, node* b) { return a->root_id() < b->root_id(); }); + } + + bool ac_plugin::is_sorted(monomial_t const& m) const { + if (m.m_bloom.m_tick == m_tick) + return true; + for (unsigned i = m.size(); i-- > 1; ) + if (m[i - 1]->root_id() > m[i]->root_id()) + return false; + return true; + } + + uint64_t ac_plugin::filter(monomial_t& m) { + auto& bloom = m.m_bloom; + if (bloom.m_tick == m_tick) + return bloom.m_filter; + bloom.m_filter = 0; + for (auto n : m) + bloom.m_filter |= (1ull << (n->root_id() % 64ull)); + if (!is_sorted(m)) + sort(m); + bloom.m_tick = m_tick; + return bloom.m_filter; + } + + bool ac_plugin::can_be_subset(monomial_t& subset, monomial_t& superset) { + if (subset.size() > superset.size()) + return false; + auto f1 = filter(subset); + auto f2 = filter(superset); + return (f1 | f2) == f2; + } + + bool ac_plugin::can_be_subset(monomial_t& subset, ptr_vector const& m, bloom& bloom) { + if (subset.size() > m.size()) + return false; + if (bloom.m_tick != m_tick) { + bloom.m_filter = 0; + for (auto n : m) + bloom.m_filter |= (1ull << (n->root_id() % 64ull)); + bloom.m_tick = m_tick; + } + auto f2 = bloom.m_filter; + return (filter(subset) | f2) == f2; + } + + void ac_plugin::merge(node* root, node* other, justification j) { + for (auto n : equiv(other)) + n->root = root; + m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size() }); + for (auto eq_id : other->eqs) + set_status(eq_id, eq_status::to_simplify); + for (auto m : other->shared) + m_shared_todo.insert(m); + root->shared.append(other->shared); + root->eqs.append(other->eqs); + std::swap(root->next, other->next); + push_undo(is_merge_node); + ++m_tick; + } + + void ac_plugin::push_undo(undo_kind k) { + m_undo.push_back(k); + push_plugin_undo(get_id()); + m_undo_notify(); // tell main plugin to dispatch undo to this module. + } + + unsigned ac_plugin::to_monomial(enode* n) { + enode_vector& ns = m_todo; + ns.reset(); + ptr_vector m; + ns.push_back(n); + for (unsigned i = 0; i < ns.size(); ++i) { + n = ns[i]; + if (is_op(n)) + ns.append(n->num_args(), n->args()); + else + m.push_back(mk_node(n)); + } + return to_monomial(n, m); + } + + unsigned ac_plugin::to_monomial(enode* e, ptr_vector const& ms) { + unsigned id = m_monomials.size(); + m_monomials.push_back({ ms, bloom() }); + push_undo(is_add_monomial); + return id; + } + + ac_plugin::node* ac_plugin::node::mk(region& r, enode* n) { + auto* mem = r.allocate(sizeof(node)); + node* res = new (mem) node(); + res->n = n; + res->root = res; + res->next = res; + return res; + } + + ac_plugin::node* ac_plugin::mk_node(enode* n) { + unsigned id = n->get_id(); + if (m_nodes.size() > id && m_nodes[id]) + return m_nodes[id]; + auto* r = node::mk(get_region(), n); + push_undo(is_add_node); + m_nodes.setx(id, r, nullptr); + m_node_trail.push_back(r); + return r; + } + + void ac_plugin::propagate() { + while (true) { + loop_start: + unsigned eq_id = pick_next_eq(); + if (eq_id == UINT_MAX) + break; + + TRACE("plugin", tout << "propagate " << eq_id << ": " << eq_pp(*this, m_eqs[eq_id]) << "\n"); + + // simplify eq using processed + for (auto other_eq : backward_iterator(eq_id)) + TRACE("plugin", tout << "backward iterator " << eq_id << " vs " << other_eq << " " << is_processed(other_eq) << "\n"); + for (auto other_eq : backward_iterator(eq_id)) + if (is_processed(other_eq) && backward_simplify(eq_id, other_eq)) + goto loop_start; + + set_status(eq_id, eq_status::processed); + + // simplify processed using eq + for (auto other_eq : forward_iterator(eq_id)) + if (is_processed(other_eq)) + forward_simplify(eq_id, other_eq); + + // superpose, create new equations + for (auto other_eq : superpose_iterator(eq_id)) + if (is_processed(other_eq)) + superpose(eq_id, other_eq); + + // simplify to_simplify using eq + for (auto other_eq : forward_iterator(eq_id)) + if (is_to_simplify(other_eq)) + forward_simplify(eq_id, other_eq); + } + propagate_shared(); + + CTRACE("plugin", !m_shared.empty() || !m_eqs.empty(), display(tout)); + } + + unsigned ac_plugin::pick_next_eq() { + while (!m_to_simplify_todo.empty()) { + unsigned id = *m_to_simplify_todo.begin(); + if (id < m_eqs.size() && is_to_simplify(id)) + return id; + m_to_simplify_todo.remove(id); + } + return UINT_MAX; + } + + // reorient equations when the status of equations are set to to_simplify. + void ac_plugin::set_status(unsigned id, eq_status s) { + auto& eq = m_eqs[id]; + if (eq.status == eq_status::is_dead) + return; + if (s == eq_status::to_simplify && are_equal(monomial(eq.l), monomial(eq.r))) + s = eq_status::is_dead; + + if (eq.status != s) { + m_update_eq_trail.push_back({ id, eq }); + eq.status = s; + push_undo(is_update_eq); + } + switch (s) { + case eq_status::processed: + case eq_status::is_dead: + m_to_simplify_todo.remove(id); + break; + case eq_status::to_simplify: + m_to_simplify_todo.insert(id); + orient_equation(eq); + break; + } + } + + // + // superpose iterator enumerates all equations where lhs of eq have element in common. + // + unsigned_vector const& ac_plugin::superpose_iterator(unsigned eq_id) { + auto const& eq = m_eqs[eq_id]; + m_src_r.reset(); + m_src_r.append(monomial(eq.r).m_nodes); + init_ref_counts(monomial(eq.l), m_src_l_counts); + init_overlap_iterator(eq_id, monomial(eq.l)); + return m_eq_occurs; + } + + // + // backward iterator allows simplification of eq + // The rhs of eq is a super-set of lhs of other eq. + // + unsigned_vector const& ac_plugin::backward_iterator(unsigned eq_id) { + auto const& eq = m_eqs[eq_id]; + init_ref_counts(monomial(eq.r), m_dst_r_counts); + init_ref_counts(monomial(eq.l), m_dst_l_counts); + m_dst_r.reset(); + m_dst_r.append(monomial(eq.r).m_nodes); + init_subset_iterator(eq_id, monomial(eq.r)); + return m_eq_occurs; + } + + void ac_plugin::init_overlap_iterator(unsigned eq_id, monomial_t const& m) { + m_eq_occurs.reset(); + for (auto n : m) + m_eq_occurs.append(n->root->eqs); + compress_eq_occurs(eq_id); + } + + // + // add all but one of the use lists. Identify the largest use list and skip it. + // The rationale is that [a, b] is a subset of [a, b, c, d, e] if + // it has at least two elements (otherwise it would not apply as a rewrite over AC). + // then one of the two elements has to be in the set of [a, b, c, d, e] \ { x } + // where x is an arbitrary value from a, b, c, d, e. Not a two-element watch list, but still. + // + void ac_plugin::init_subset_iterator(unsigned eq_id, monomial_t const& m) { + unsigned max_use = 0; + node* max_n = nullptr; + bool has_two = false; + for (auto n : m) + if (n->root->eqs.size() >= max_use) + has_two |= max_n && (max_n != n->root), max_n = n->root, max_use = n->root->eqs.size(); + m_eq_occurs.reset(); + if (has_two) { + for (auto n : m) + if (n->root != max_n) + m_eq_occurs.append(n->root->eqs); + } + else { + for (auto n : m) { + m_eq_occurs.append(n->root->eqs); + break; + } + } + compress_eq_occurs(eq_id); + } + + // prune m_eq_occurs to single occurrences + void ac_plugin::compress_eq_occurs(unsigned eq_id) { + unsigned j = 0; + m_eq_seen.reserve(m_eqs.size() + 1, false); + for (unsigned i = 0; i < m_eq_occurs.size(); ++i) { + unsigned id = m_eq_occurs[i]; + if (m_eq_seen[id]) + continue; + if (id == eq_id) + continue; + m_eq_occurs[j++] = id; + m_eq_seen[id] = true; + } + m_eq_occurs.shrink(j); + for (auto id : m_eq_occurs) + m_eq_seen[id] = false; + } + + // + // forward iterator simplifies other eqs where their rhs is a superset of lhs of eq + // + unsigned_vector const& ac_plugin::forward_iterator(unsigned eq_id) { + auto& eq = m_eqs[eq_id]; + m_src_r.reset(); + m_src_r.append(monomial(eq.r).m_nodes); + init_ref_counts(monomial(eq.l), m_src_l_counts); + init_ref_counts(monomial(eq.r), m_src_r_counts); + unsigned min_r = UINT_MAX; + node* min_n = nullptr; + for (auto n : monomial(eq.l)) + if (n->root->eqs.size() < min_r) + min_n = n, min_r = n->root->eqs.size(); + // found node that occurs in fewest eqs + VERIFY(min_n); + return min_n->eqs; + } + + void ac_plugin::init_ref_counts(monomial_t const& monomial, ref_counts& counts) const { + init_ref_counts(monomial.m_nodes, counts); + } + + void ac_plugin::init_ref_counts(ptr_vector const& monomial, ref_counts& counts) const { + counts.reset(); + for (auto n : monomial) + counts.inc(n->root_id(), 1); + } + + bool ac_plugin::is_correct_ref_count(monomial_t const& m, ref_counts const& counts) const { + return is_correct_ref_count(m.m_nodes, counts); + } + + bool ac_plugin::is_correct_ref_count(ptr_vector const& m, ref_counts const& counts) const { + ref_counts check; + init_ref_counts(m, check); + return + all_of(counts, [&](unsigned i) { return check[i] == counts[i]; }) && + all_of(check, [&](unsigned i) { return check[i] == counts[i]; }); + } + + void ac_plugin::forward_simplify(unsigned src_eq, unsigned dst_eq) { + + if (src_eq == dst_eq) + return; + + // check that left src.l is a subset of dst.r + // dst = A -> BC + // src = B -> D + // post(dst) := A -> CD + auto& src = m_eqs[src_eq]; // src_r_counts, src_l_counts are initialized + auto& dst = m_eqs[dst_eq]; + + TRACE("plugin", tout << "forward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << "\n"); + + + if (forward_subsumes(src_eq, dst_eq)) { + TRACE("plugin", tout << "forward subsumed\n"); + set_status(dst_eq, eq_status::is_dead); + return; + } + + if (!can_be_subset(monomial(src.l), monomial(dst.r))) + return; + + + m_dst_r_counts.reset(); + + unsigned src_l_size = monomial(src.l).size(); + unsigned src_r_size = m_src_r.size(); + + SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); + // subtract src.l from dst.r if src.l is a subset of dst.r + // dst_rhs := dst_rhs - src_lhs + src_rhs + // := src_rhs + (dst_rhs - src_lhs) + // := src_rhs + elements from dst_rhs that are in excess of src_lhs + unsigned num_overlap = 0; + for (auto n : monomial(dst.r)) { + unsigned id = n->root_id(); + unsigned dst_count = m_dst_r_counts[id]; + unsigned src_count = m_src_l_counts[id]; + if (dst_count > src_count) { + m_src_r.push_back(n); + m_dst_r_counts.dec(id, 1); + } + else if (dst_count < src_count) { + m_src_r.shrink(src_r_size); + return; + } + else + ++num_overlap; + } + // The dst.r has to be a superset of src.l, otherwise simplification does not apply + if (num_overlap != src_l_size) { + m_src_r.shrink(src_r_size); + return; + } + auto j = justify_rewrite(src_eq, dst_eq); + reduce(m_src_r, j); + auto new_r = to_monomial(m_src_r); + index_new_r(dst_eq, monomial(m_eqs[dst_eq].r), monomial(new_r)); + m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] }); + m_eqs[dst_eq].r = new_r; + m_eqs[dst_eq].j = j; + push_undo(is_update_eq); + m_src_r.reset(); + m_src_r.append(monomial(src.r).m_nodes); + TRACE("plugin", tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); + } + + bool ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) { + if (src_eq == dst_eq) + return false; + + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; // pre-computed dst_r_counts, dst_l_counts + // + // dst_ids, dst_count contain rhs of dst_eq + // + TRACE("plugin", tout << "backward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n"); + + if (backward_subsumes(src_eq, dst_eq)) { + TRACE("plugin", tout << "backward subsumed\n"); + set_status(dst_eq, eq_status::is_dead); + return true; + } + // check that src.l is a subset of dst.r + if (!can_be_subset(monomial(src.l), monomial(dst.r))) + return false; + if (!is_subset(m_dst_r_counts, m_src_l_counts, monomial(src.l))) { + TRACE("plugin", tout << "not subset\n"); + return false; + } + + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); + + ptr_vector m(m_dst_r); + init_ref_counts(monomial(src.l), m_src_l_counts); + + rewrite1(m_src_l_counts, monomial(src.r), m_dst_r_counts, m); + auto j = justify_rewrite(src_eq, dst_eq); + reduce(m, j); + auto new_r = to_monomial(m); + index_new_r(dst_eq, monomial(m_eqs[dst_eq].r), monomial(new_r)); + m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] }); + m_eqs[dst_eq].r = new_r; + m_eqs[dst_eq].j = j; + TRACE("plugin", tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); + push_undo(is_update_eq); + return true; + } + + // dst_eq is fixed, dst_l_count is pre-computed for monomial(dst.l) + // dst_r_counts is pre-computed for monomial(dst.r). + // is dst_eq subsumed by src_eq? + bool ac_plugin::backward_subsumes(unsigned src_eq, unsigned dst_eq) { + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; + SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts)); + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); + if (!can_be_subset(monomial(src.l), monomial(dst.l))) + return false; + if (!can_be_subset(monomial(src.r), monomial(dst.r))) + return false; + unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size(); + if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) + return false; + if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) + return false; + if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) + return false; + SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); + SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts)); + // add difference betwen dst.l and src.l to both src.l, src.r + for (auto n : monomial(dst.l)) { + unsigned id = n->root_id(); + SASSERT(m_dst_l_counts[id] >= m_src_l_counts[id]); + unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id]; + if (diff > 0) { + m_src_l_counts.inc(id, diff); + m_src_r_counts.inc(id, diff); + } + } + // now dst.r and src.r should align and have the same elements. + // since src.r is a subset of dst.r we iterate over dst.r + return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); + } + + // src_l_counts, src_r_counts are initialized for src.l, src.r + bool ac_plugin::forward_subsumes(unsigned src_eq, unsigned dst_eq) { + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; + SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); + SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts)); + if (!can_be_subset(monomial(src.l), monomial(dst.l))) + return false; + if (!can_be_subset(monomial(src.r), monomial(dst.r))) + return false; + unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size(); + if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) + return false; + if (!is_superset(m_src_l_counts, m_dst_l_counts, monomial(dst.l))) + return false; + if (!is_superset(m_src_r_counts, m_dst_r_counts, monomial(dst.r))) + return false; + SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts)); + SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); + for (auto n : monomial(src.l)) { + unsigned id = n->root_id(); + SASSERT(m_src_l_counts[id] <= m_dst_l_counts[id]); + unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id]; + if (diff == 0) + continue; + m_dst_l_counts.dec(id, diff); + if (m_dst_r_counts[id] < diff) + return false; + m_dst_r_counts.dec(id, diff); + } + + return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); + } + + void ac_plugin::rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_counts, ptr_vector& dst) { + // pre-condition: is-subset is invoked so that src_l is initialized. + // pre-condition: dst_count is also initialized. + // remove from dst elements that are in src_l + // add elements from src_r + SASSERT(is_correct_ref_count(dst, dst_counts)); + SASSERT(&src_r.m_nodes != &dst); + unsigned sz = dst.size(), j = 0; + for (unsigned i = 0; i < sz; ++i) { + auto* n = dst[i]; + unsigned id = n->root_id(); + unsigned dst_count = dst_counts[id]; + unsigned src_count = src_l[id]; + SASSERT(dst_count > 0); + if (src_count == 0) + dst[j++] = n; + else if (src_count < dst_count) { + dst[j++] = n; + dst_counts.dec(id, 1); + } + } + dst.shrink(j); + dst.append(src_r.m_nodes); + } + + // rewrite monomial to normal form. + bool ac_plugin::reduce(ptr_vector& m, justification& j) { + bool change = false; + do { + init_loop: + if (m.size() == 1) + return change; + bloom b; + init_ref_counts(m, m_m_counts); + for (auto n : m) { + for (auto eq : n->root->eqs) { + if (!is_processed(eq)) + continue; + auto& src = m_eqs[eq]; + + if (!can_be_subset(monomial(src.l), m, b)) + continue; + if (!is_subset(m_m_counts, m_eq_counts, monomial(src.l))) + continue; + TRACE("plugin", display_equation(tout << "reduce ", src) << "\n"); + SASSERT(is_correct_ref_count(monomial(src.l), m_eq_counts)); + rewrite1(m_eq_counts, monomial(src.r), m_m_counts, m); + j = join(j, eq); + change = true; + goto init_loop; + } + } + } + while (false); + return change; + } + + // check that src is a subset of dst, where dst_counts are precomputed + bool ac_plugin::is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src) { + SASSERT(&dst_counts != &src_counts); + init_ref_counts(src, src_counts); + return all_of(src_counts, [&](unsigned idx) { return src_counts[idx] <= dst_counts[idx]; }); + } + + // check that dst is a superset of src, where src_counts are precomputed + bool ac_plugin::is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst) { + SASSERT(&dst_counts != &src_counts); + init_ref_counts(dst, dst_counts); + return all_of(src_counts, [&](unsigned idx) { return src_counts[idx] <= dst_counts[idx]; }); + } + + void ac_plugin::index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r) { + for (auto n : old_r) + n->root->n->mark1(); + for (auto n : new_r) + if (!n->root->n->is_marked1()) { + n->root->eqs.push_back(eq); + m_node_trail.push_back(n->root); + n->root->n->mark1(); + push_undo(is_add_eq_index); + } + for (auto n : old_r) + n->root->n->unmark1(); + for (auto n : new_r) + n->root->n->unmark1(); + } + + + void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) { + if (src_eq == dst_eq) + return; + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; + + TRACE("plugin", tout << "superpose: "; display_equation(tout, src); tout << " "; display_equation(tout, dst); tout << "\n";); + // AB -> C, AD -> E => BE ~ CD + // m_src_ids, m_src_counts contains information about src (call it AD -> E) + m_dst_l_counts.reset(); + + m_dst_r.reset(); + m_dst_r.append(monomial(dst.r).m_nodes); + unsigned src_r_size = m_src_r.size(); + unsigned dst_r_size = m_dst_r.size(); + SASSERT(src_r_size == monomial(src.r).size()); + // dst_r contains C + // src_r contains E + + // compute BE, initialize dst_ids, dst_counts + bool overlap = false; + for (auto n : monomial(dst.l)) { + unsigned id = n->root_id(); + m_dst_l_counts.inc(id, 1); + if (m_src_l_counts[id] < m_dst_l_counts[id]) + m_src_r.push_back(n); + overlap |= m_src_l_counts[id] > 0; + } + + if (!overlap) { + m_src_r.shrink(src_r_size); + return; + } + + // compute CD + for (auto n : monomial(src.l)) { + unsigned id = n->root_id(); + if (m_dst_l_counts[id] > 0) + m_dst_l_counts.dec(id, 1); + else + m_dst_r.push_back(n); + } + + if (are_equal(m_src_r, m_dst_r)) { + m_src_r.shrink(src_r_size); + return; + } + + TRACE("plugin", tout << m_pp(*this, m_src_r) << "== " << m_pp(*this, m_dst_r) << "\n";); + + justification j = justify_rewrite(src_eq, dst_eq); + reduce(m_dst_r, j); + reduce(m_src_r, j); + if (m_src_r.size() == 1 && m_dst_r.size() == 1) + push_merge(m_src_r[0]->n, m_dst_r[0]->n, j); + else + init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j)); + + m_src_r.reset(); + m_src_r.append(monomial(src.r).m_nodes); + } + + bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) { + return filter(a) == filter(b) && are_equal(a.m_nodes, b.m_nodes); + } + + bool ac_plugin::are_equal(ptr_vector const& a, ptr_vector const& b) { + if (a.size() != b.size()) + return false; + m_eq_counts.reset(); + for (auto n : a) + m_eq_counts.inc(n->root_id(), 1); + + for (auto n : b) { + unsigned id = n->root_id(); + if (m_eq_counts[id] == 0) + return false; + m_eq_counts.dec(id, 1); + } + return true; + } + + // + // simple version based on propagating all shared + // todo: version touching only newly processed shared, and maintaining incremental data-structures. + // - hash-tables for shared monomials similar to the ones used for euf_table. + // the tables have to be updated (and re-sorted) whenever a child changes root. + // + + void ac_plugin::propagate_shared() { + if (m_shared_todo.empty()) + return; + while (!m_shared_todo.empty()) { + auto idx = *m_shared_todo.begin(); + m_shared_todo.remove(idx); + if (idx < m_shared.size()) + simplify_shared(idx, m_shared[idx]); + } + m_monomial_table.reset(); + for (auto const& s1 : m_shared) { + shared s2; + TRACE("plugin", tout << "shared " << m_pp(*this, monomial(s1.m)) << "\n"); + if (!m_monomial_table.find(s1.m, s2)) + m_monomial_table.insert(s1.m, s1); + else if (s2.n->get_root() != s1.n->get_root()) { + TRACE("plugin", tout << m_pp(*this, monomial(s1.m)) << " == " << m_pp(*this, monomial(s2.m)) << "\n"); + push_merge(s1.n, s2.n, justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s1.j), m_dep_manager.mk_leaf(s2.j)))); + } + } + } + + void ac_plugin::simplify_shared(unsigned idx, shared s) { + auto j = s.j; + auto old_m = s.m; + ptr_vector m1(monomial(old_m).m_nodes); + TRACE("plugin", tout << "simplify " << m_pp(*this, monomial(old_m)) << "\n"); + if (!reduce(m1, j)) + return; + + auto new_m = to_monomial(m1); + // update shared occurrences for members of the new monomial that are not already in the old monomial. + for (auto n : monomial(old_m)) + n->root->n->mark1(); + for (auto n : m1) + if (!n->root->n->is_marked1()) { + n->root->shared.push_back(idx); + m_shared_todo.insert(idx); + m_node_trail.push_back(n->root); + push_undo(is_add_shared_index); + } + for (auto n : monomial(old_m)) + n->root->n->unmark1(); + m_update_shared_trail.push_back({ idx, s }); + push_undo(is_update_shared); + m_shared[idx].m = new_m; + m_shared[idx].j = j; + } + + justification ac_plugin::justify_rewrite(unsigned eq1, unsigned eq2) { + auto* j = m_dep_manager.mk_join(justify_equation(eq1), justify_equation(eq2)); + return justification::dependent(j); + } + + justification::dependency* ac_plugin::justify_equation(unsigned eq) { + auto const& e = m_eqs[eq]; + auto* j = m_dep_manager.mk_leaf(e.j); + j = justify_monomial(j, monomial(e.l)); + j = justify_monomial(j, monomial(e.r)); + return j; + } + + justification::dependency* ac_plugin::justify_monomial(justification::dependency* j, monomial_t const& m) { + for (auto n : m) + if (n->root->n != n->n) + j = m_dep_manager.mk_join(j, m_dep_manager.mk_leaf(justification::equality(n->root->n, n->n))); + return j; + } + + justification ac_plugin::join(justification j, unsigned eq) { + return justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(j), justify_equation(eq))); + } + +} diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h new file mode 100644 index 00000000000..7da346dbafa --- /dev/null +++ b/src/ast/euf/euf_ac_plugin.h @@ -0,0 +1,309 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_ac_plugin.h + +Abstract: + + plugin structure for ac functions + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +ex: +xyz -> xy, then xyzz -> xy by repeated rewriting + +monomials = [0 |-> xyz, 1 |-> xy, 2 |-> xyzz] +parents(x) = [0, 1, 2] +parents(z) = [0, 1] +for p in parents(xyzz): + p != xyzz + p' := simplify_using(xyzz, p) + if p != p': + repeat reduction using p := p' + +--*/ + +#pragma once + +#include +#include "ast/euf/euf_plugin.h" + +namespace euf { + + class ac_plugin : public plugin { + + // enode structure for AC equivalences + struct node { + enode* n; // associated enode + node* root; // path compressed root + node* next; // next in equivalence class + justification j; // justification for equality + node* target = nullptr; // justified next + unsigned_vector shared; // shared occurrences + unsigned_vector eqs; // equality occurrences + + unsigned root_id() const { return root->n->get_id(); } + ~node() {} + static node* mk(region& r, enode* n); + }; + + class equiv { + node& n; + public: + class iterator { + node* m_first; + node* m_last; + public: + iterator(node* n, node* m) : m_first(n), m_last(m) {} + node* operator*() { return m_first; } + iterator& operator++() { if (!m_last) m_last = m_first; m_first = m_first->next; return *this; } + iterator operator++(int) { iterator tmp = *this; ++*this; return tmp; } + bool operator==(iterator const& other) const { return m_last == other.m_last && m_first == other.m_first; } + bool operator!=(iterator const& other) const { return !(*this == other); } + }; + equiv(node& _n) :n(_n) {} + equiv(node* _n) :n(*_n) {} + iterator begin() const { return iterator(&n, nullptr); } + iterator end() const { return iterator(&n, &n); } + }; + + struct bloom { + uint64_t m_tick = 0; + uint64_t m_filter = 0; + }; + + enum eq_status { + processed, to_simplify, is_dead + }; + + // represent equalities added by merge_eh and by superposition + struct eq { + eq(unsigned l, unsigned r, justification j): + l(l), r(r), j(j) {} + unsigned l, r; // refer to monomials + eq_status status = to_simplify; + justification j; // justification for equality + }; + + // represent shared enodes that use the AC symbol. + struct shared { + enode* n; // original shared enode + unsigned m; // monomial index + justification j; // justification for current simplification of monomial + }; + + struct monomial_t { + ptr_vector m_nodes; + bloom m_bloom; + node* operator[](unsigned i) const { return m_nodes[i]; } + unsigned size() const { return m_nodes.size(); } + void set(ptr_vector const& ns) { m_nodes.reset(); m_nodes.append(ns); m_bloom.m_tick = 0; } + node* const* begin() const { return m_nodes.begin(); } + node* const* end() const { return m_nodes.end(); } + node* * begin() { return m_nodes.begin(); } + node* * end() { return m_nodes.end(); } + }; + + + struct monomial_hash { + ac_plugin& p; + monomial_hash(ac_plugin& p) :p(p) {} + unsigned operator()(unsigned i) const { + unsigned h = 0; + auto& m = p.monomial(i); + if (!p.is_sorted(m)) + p.sort(m); + for (auto* n : m) + h = combine_hash(h, n->root_id()); + return h; + } + }; + + struct monomial_eq { + ac_plugin& p; + monomial_eq(ac_plugin& p) :p(p) {} + bool operator()(unsigned i, unsigned j) const { + auto const& m1 = p.monomial(i); + auto const& m2 = p.monomial(j); + if (m1.size() != m2.size()) return false; + for (unsigned k = 0; k < m1.size(); ++k) + if (m1[k]->root_id() != m2[k]->root_id()) + return false; + return true; + } + }; + + theory_id m_fid = 0; + unsigned m_op = null_decl_kind; + func_decl* m_decl = nullptr; + vector m_eqs; + ptr_vector m_nodes; + bool_vector m_shared_nodes; + vector m_monomials; + svector m_shared; + justification::dependency_manager m_dep_manager; + tracked_uint_set m_to_simplify_todo; + tracked_uint_set m_shared_todo; + uint64_t m_tick = 1; + + + + monomial_hash m_hash; + monomial_eq m_eq; + map m_monomial_table; + + + // backtrackable state + enum undo_kind { + is_add_eq, + is_add_monomial, + is_add_node, + is_merge_node, + is_update_eq, + is_add_shared_index, + is_add_eq_index, + is_register_shared, + is_update_shared + }; + svector m_undo; + ptr_vector m_node_trail; + + svector> m_update_shared_trail; + svector> m_merge_trail; + svector> m_update_eq_trail; + + + + node* mk_node(enode* n); + void merge(node* r1, node* r2, justification j); + + bool is_op(enode* n) const { auto d = n->get_decl(); return d && (d == m_decl || (m_fid == d->get_family_id() && m_op == d->get_decl_kind())); } + + std::function m_undo_notify; + void push_undo(undo_kind k); + enode_vector m_todo; + unsigned to_monomial(enode* n); + unsigned to_monomial(enode* n, ptr_vector const& ms); + unsigned to_monomial(ptr_vector const& ms) { return to_monomial(nullptr, ms); } + monomial_t const& monomial(unsigned i) const { return m_monomials[i]; } + monomial_t& monomial(unsigned i) { return m_monomials[i]; } + void sort(monomial_t& monomial); + bool is_sorted(monomial_t const& monomial) const; + uint64_t filter(monomial_t& m); + bool can_be_subset(monomial_t& subset, monomial_t& superset); + bool can_be_subset(monomial_t& subset, ptr_vector const& m, bloom& b); + bool are_equal(ptr_vector const& a, ptr_vector const& b); + bool are_equal(monomial_t& a, monomial_t& b); + bool backward_subsumes(unsigned src_eq, unsigned dst_eq); + bool forward_subsumes(unsigned src_eq, unsigned dst_eq); + + void init_equation(eq const& e); + bool orient_equation(eq& e); + void set_status(unsigned eq_id, eq_status s); + unsigned pick_next_eq(); + + void forward_simplify(unsigned eq_id, unsigned using_eq); + bool backward_simplify(unsigned eq_id, unsigned using_eq); + void superpose(unsigned src_eq, unsigned dst_eq); + + ptr_vector m_src_r, m_src_l, m_dst_r, m_dst_l; + + struct ref_counts { + unsigned_vector ids; + unsigned_vector counts; + void reset() { for (auto idx : ids) counts[idx] = 0; ids.reset(); } + unsigned operator[](unsigned idx) const { return counts.get(idx, 0); } + void inc(unsigned idx, unsigned amount) { counts.reserve(idx + 1, 0); ids.push_back(idx); counts[idx] += amount; } + void dec(unsigned idx, unsigned amount) { counts.reserve(idx + 1, 0); ids.push_back(idx); counts[idx] -= amount; } + unsigned const* begin() const { return ids.begin(); } + unsigned const* end() const { return ids.end(); } + }; + ref_counts m_src_l_counts, m_dst_l_counts, m_src_r_counts, m_dst_r_counts, m_eq_counts, m_m_counts; + unsigned_vector m_eq_occurs; + bool_vector m_eq_seen; + + unsigned_vector const& forward_iterator(unsigned eq); + unsigned_vector const& superpose_iterator(unsigned eq); + unsigned_vector const& backward_iterator(unsigned eq); + void init_ref_counts(monomial_t const& monomial, ref_counts& counts) const; + void init_ref_counts(ptr_vector const& monomial, ref_counts& counts) const; + void init_overlap_iterator(unsigned eq, monomial_t const& m); + void init_subset_iterator(unsigned eq, monomial_t const& m); + void compress_eq_occurs(unsigned eq_id); + // check that src is a subset of dst, where dst_counts are precomputed + bool is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src); + + // check that dst is a superset of dst, where src_counts are precomputed + bool is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst); + void rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_r_counts, ptr_vector& dst_r); + bool reduce(ptr_vector& m, justification& j); + void index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r); + + bool is_to_simplify(unsigned eq) const { return m_eqs[eq].status == eq_status::to_simplify; } + bool is_processed(unsigned eq) const { return m_eqs[eq].status == eq_status::processed; } + bool is_alive(unsigned eq) const { return m_eqs[eq].status != eq_status::is_dead; } + + justification justify_rewrite(unsigned eq1, unsigned eq2); + justification::dependency* justify_equation(unsigned eq); + justification::dependency* justify_monomial(justification::dependency* d, monomial_t const& m); + justification join(justification j1, unsigned eq); + + bool is_correct_ref_count(monomial_t const& m, ref_counts const& counts) const; + bool is_correct_ref_count(ptr_vector const& m, ref_counts const& counts) const; + + void register_shared(enode* n); + void propagate_shared(); + void simplify_shared(unsigned idx, shared s); + + std::ostream& display_monomial(std::ostream& out, monomial_t const& m) const { return display_monomial(out, m.m_nodes); } + std::ostream& display_monomial(std::ostream& out, ptr_vector const& m) const; + std::ostream& display_equation(std::ostream& out, eq const& e) const; + std::ostream& display_status(std::ostream& out, eq_status s) const; + + + public: + + ac_plugin(egraph& g, unsigned fid, unsigned op); + + ac_plugin(egraph& g, func_decl* f); + + ~ac_plugin() override {} + + theory_id get_id() const override { return m_fid; } + + void register_node(enode* n) override; + + void merge_eh(enode* n1, enode* n2) override; + + void diseq_eh(enode* eq) override; + + void undo() override; + + void propagate() override; + + std::ostream& display(std::ostream& out) const override; + + void set_undo(std::function u) { m_undo_notify = u; } + + struct eq_pp { + ac_plugin& p; eq const& e; + eq_pp(ac_plugin& p, eq const& e) : p(p), e(e) {}; + eq_pp(ac_plugin& p, unsigned eq_id): p(p), e(p.m_eqs[eq_id]) {} + std::ostream& display(std::ostream& out) const { return p.display_equation(out, e); } + }; + + struct m_pp { + ac_plugin& p; ptr_vector const& m; + m_pp(ac_plugin& p, monomial_t const& m) : p(p), m(m.m_nodes) {} + m_pp(ac_plugin& p, ptr_vector const& m) : p(p), m(m) {} + std::ostream& display(std::ostream& out) const { return p.display_monomial(out, m); } + }; + }; + + inline std::ostream& operator<<(std::ostream& out, ac_plugin::eq_pp const& d) { return d.display(out); } + inline std::ostream& operator<<(std::ostream& out, ac_plugin::m_pp const& d) { return d.display(out); } +} diff --git a/src/ast/euf/euf_arith_plugin.cpp b/src/ast/euf/euf_arith_plugin.cpp new file mode 100644 index 00000000000..268eff38d81 --- /dev/null +++ b/src/ast/euf/euf_arith_plugin.cpp @@ -0,0 +1,71 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_arith_plugin.cpp + +Abstract: + + plugin structure for arithmetic + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#include "ast/euf/euf_arith_plugin.h" +#include "ast/euf/euf_egraph.h" +#include + +namespace euf { + + arith_plugin::arith_plugin(egraph& g) : + plugin(g), + a(g.get_manager()), + m_add(g, get_id(), OP_ADD), + m_mul(g, get_id(), OP_MUL) { + std::function uadd = [&]() { m_undo.push_back(undo_t::undo_add); }; + m_add.set_undo(uadd); + std::function umul = [&]() { m_undo.push_back(undo_t::undo_mul); }; + m_mul.set_undo(umul); + } + + void arith_plugin::register_node(enode* n) { + // no-op + } + + void arith_plugin::merge_eh(enode* n1, enode* n2) { + m_add.merge_eh(n1, n2); + m_mul.merge_eh(n1, n2); + } + + void arith_plugin::propagate() { + m_add.propagate(); + m_mul.propagate(); + } + + void arith_plugin::undo() { + auto k = m_undo.back(); + m_undo.pop_back(); + switch (k) { + case undo_t::undo_add: + m_add.undo(); + break; + case undo_t::undo_mul: + m_mul.undo(); + break; + default: + UNREACHABLE(); + } + } + + std::ostream& arith_plugin::display(std::ostream& out) const { + out << "add\n"; + m_add.display(out); + out << "mul\n"; + m_mul.display(out); + return out; + } +} diff --git a/src/ast/euf/euf_arith_plugin.h b/src/ast/euf/euf_arith_plugin.h new file mode 100644 index 00000000000..4c2a88d3603 --- /dev/null +++ b/src/ast/euf/euf_arith_plugin.h @@ -0,0 +1,53 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_arith_plugin.h + +Abstract: + + plugin structure for arithmetic +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#pragma once + +#include "ast/arith_decl_plugin.h" +#include "ast/euf/euf_plugin.h" +#include "ast/euf/euf_ac_plugin.h" + +namespace euf { + + class egraph; + + class arith_plugin : public plugin { + enum undo_t { undo_add, undo_mul }; + arith_util a; + svector m_undo; + ac_plugin m_add, m_mul; + + public: + arith_plugin(egraph& g); + + ~arith_plugin() override {} + + theory_id get_id() const override { return a.get_family_id(); } + + void register_node(enode* n) override; + + void merge_eh(enode* n1, enode* n2) override; + + void diseq_eh(enode* eq) override {} + + void undo() override; + + void propagate() override; + + std::ostream& display(std::ostream& out) const override; + + }; +} diff --git a/src/ast/euf/euf_bv_plugin.cpp b/src/ast/euf/euf_bv_plugin.cpp new file mode 100644 index 00000000000..4b0cd6cfc0f --- /dev/null +++ b/src/ast/euf/euf_bv_plugin.cpp @@ -0,0 +1,581 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_bv_plugin.cpp + +Abstract: + + plugin structure for bit-vectors + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + Jakob Rath 2023-11-08 + +Objective: + +satisfies extract/concat axioms. + - concat(n{I],n[J]) = n[IJ] for I, J consecutive. + - concat(v1, v2) = 2^width(v1)*v2 + v1 + - concat(n[width(n)-1:0]) = n + - concat(a, b)[I] = concat(a[I1], b[I2]) + - concat(a, concat(b, c)) = concat(concat(a, b), c) + +E-graph: + +The E-graph contains node definitions of the form + + n := f(n1,n2,..) + +and congruences: + + n ~ n' means root(n) = root(n') + +Saturated state: + + 1. n := n1[I], n' := n2[J], n1 ~ n2 => root(n1) contains tree refining both I, J from smaller intervals + + 2. n := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => n ~ n3[IJ] + + 3. n := concat(n1[I], n2[J]), I and J are consecutive & n1 ~ n2, n1[I] ~ v1, n1[J] ~ v2 => n ~ 2^width(v1)*v2 + v1 + + 4. n := concat(n1[I], n2[J], I, J are consecutive, n1 ~ n2, n ~ v => n1[I] ~ v mod 2^width(n1[I]), n2[J] ~ v div 2^width(n1[I]) + + 5. n' := n[I] => n ~ n[width(n)-1:0] + + 6. n := concat(a, concat(b, c)) => n ~ concat(concat(a, b), c) + - handled by rewriter pre-processing for inputs + - terms created internally are not equated modulo associativity + + 7, n := concat(n1, n2)[I] => n ~ concat(n1[I1],n2[I2]) or n[I1] or n[I2] + - handled by rewriter pre-processing + +Example: + x == (x1 x2) x3 + y == y1 (y2 y3) + x1 == y1, x2 == y2, x3 == y3 + => + x = y + + by x2 == y2, x3 == y3 => (x2 x3) = (y2 y3) + by (2) => x[I23] = (x2 x3) + by (2) => x[I123] = (x1 (x2 x3)) + by (5) => x = x[I123] + +The formal properties of saturation have to be established. + +- Saturation does not complete with respect to associativity. +Instead the claim is along the lines that the resulting E-graph can be used as a canonizer. +If given a set of equations E that are saturated, and terms t1, t2 that are +both simplified with respect to left-associativity of concatentation, and t1, t2 belong to the E-graph, +then t1 = t2 iff t1 ~ t2 in the E-graph. + +TODO: Is saturation for (7) overkill for the purpose of canonization? + +--*/ + + +#include "ast/ast_pp.h" +#include "ast/euf/euf_bv_plugin.h" +#include "ast/euf/euf_egraph.h" + +namespace euf { + + bv_plugin::bv_plugin(egraph& g): + plugin(g), + bv(g.get_manager()) + {} + + enode* bv_plugin::mk_value_concat(enode* hi, enode* lo) { + auto v1 = get_value(hi); + auto v2 = get_value(lo); + auto v3 = v2 + v1 * rational::power_of_two(width(lo)); + return mk_value(v3, width(lo) + width(hi)); + } + + enode* bv_plugin::mk_value(rational const& v, unsigned sz) { + auto e = bv.mk_numeral(v, sz); + auto n = mk(e, 0, nullptr); + if (m_ensure_th_var) + m_ensure_th_var(n); + return n; + } + + void bv_plugin::propagate_merge(enode* x, enode* y) { + if (!bv.is_bv(x->get_expr())) + return; + + TRACE("bv", tout << "merge_eh " << g.bpp(x) << " == " << g.bpp(y) << "\n"); + SASSERT(!m_internal); + flet _internal(m_internal, true); + + propagate_values(x); + + // ensure slices align + if (has_sub(x) || has_sub(y)) { + enode_vector& xs = m_xs, & ys = m_ys; + xs.reset(); + ys.reset(); + xs.push_back(x); + ys.push_back(y); + merge(xs, ys, justification::equality(x, y)); + } + + // ensure p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ] + for (auto* n : enode_class(x)) + propagate_extract(n); + } + + void bv_plugin::register_node(enode* n) { + m_queue.push_back(n); + m_trail.push_back(new (get_region()) push_back_vector(m_queue)); + push_plugin_undo(bv.get_family_id()); + } + + void bv_plugin::merge_eh(enode* n1, enode* n2) { + m_queue.push_back(enode_pair(n1, n2)); + m_trail.push_back(new (get_region()) push_back_vector(m_queue)); + push_plugin_undo(bv.get_family_id()); + } + + void bv_plugin::propagate() { + if (m_qhead == m_queue.size()) + return; + m_trail.push_back(new (get_region()) value_trail(m_qhead)); + push_plugin_undo(bv.get_family_id()); + for (; m_qhead < m_queue.size(); ++m_qhead) { + if (std::holds_alternative(m_queue[m_qhead])) { + auto n = *std::get_if(&m_queue[m_qhead]); + propagate_register_node(n); + } + else { + auto [a, b] = *std::get_if(&m_queue[m_qhead]); + propagate_merge(a, b); + } + } + } + + // enforce concat(v1, v2) = v1*2^|v2| + v2 + void bv_plugin::propagate_values(enode* x) { + if (!is_value(x)) + return; + + auto val_x = get_value(x); + enode* a, * b; + unsigned lo, hi; + for (enode* p : enode_parents(x)) { + if (is_concat(p, a, b) && is_value(a) && is_value(b)) + push_merge(mk_concat(a->get_interpreted(), b->get_interpreted()), mk_value_concat(a, b)); + + if (is_extract(p, lo, hi)) { + auto val_p = mod2k(machine_div2k(val_x, lo), hi - lo + 1); + auto ix = x->get_interpreted(); + auto ex = mk(bv.mk_extract(hi, lo, ix->get_expr()), 1, &ix); + push_merge(ex, mk_value(val_p, width(p))); + } + } + + for (enode* sib : enode_class(x)) { + if (is_concat(sib, a, b)) { + auto val_a = machine_div2k(val_x, width(b)); + auto val_b = mod2k(val_x, width(b)); + push_merge(mk_concat(mk_value(val_a, width(a)), mk_value(val_b, width(b))), x->get_interpreted()); + } + } + } + + // + // p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ] + // + // n is of form arg[I] + // p is of form concat(n, b) or concat(a, n) + // b is congruent to arg[J], I is consecutive with J => ensure that arg[IJ] = p + // a is congruent to arg[J], J is consecutive with I => ensure that arg[JI] = p + // + + void bv_plugin::propagate_extract(enode* n) { + unsigned lo1, hi1, lo2, hi2; + enode* a, * b; + if (!is_extract(n, lo1, hi1)) + return; + + enode* arg = n->get_arg(0); + enode* arg_r = arg->get_root(); + enode* n_r = n->get_root(); + + m_ensure_concat.reset(); + auto ensure_concat = [&](unsigned lo, unsigned mid, unsigned hi) { + // verbose_stream() << lo << " " << mid << " " << hi << "\n"; + TRACE("bv", tout << "ensure-concat " << lo << " " << mid << " " << hi << "\n"); + unsigned lo_, hi_; + for (enode* p1 : enode_parents(n)) + if (is_extract(p1, lo_, hi_) && lo_ == lo && hi_ == hi && p1->get_arg(0)->get_root() == arg_r) + return; + // add the axiom instead of merge(p, mk_extract(arg, lo, hi)), which would require tracking justifications + push_merge(mk_concat(mk_extract(arg, mid + 1, hi), mk_extract(arg, lo, mid)), mk_extract(arg, lo, hi)); + }; + + auto propagate_above = [&](enode* b) { + TRACE("bv", tout << "propagate-above " << g.bpp(b) << "\n"); + for (enode* sib : enode_class(b)) + if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi1 + 1 == lo2) + m_ensure_concat.push_back({lo1, hi1, hi2}); + }; + + auto propagate_below = [&](enode* a) { + TRACE("bv", tout << "propagate-below " << g.bpp(a) << "\n"); + for (enode* sib : enode_class(a)) + if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi2 + 1 == lo1) + m_ensure_concat.push_back({lo2, hi2, hi1}); + }; + + for (enode* p : enode_parents(n)) { + if (is_concat(p, a, b)) { + if (a->get_root() == n_r) + propagate_below(b); + if (b->get_root() == n_r) + propagate_above(a); + } + } + + for (auto [lo, mid, hi] : m_ensure_concat) + ensure_concat(lo, mid, hi); + + } + + class bv_plugin::undo_split : public trail { + bv_plugin& p; + enode* n; + public: + undo_split(bv_plugin& p, enode* n): p(p), n(n) {} + void undo() override { + auto& i = p.info(n); + i.value = nullptr; + i.lo = nullptr; + i.hi = nullptr; + i.cut = null_cut; + } + }; + + void bv_plugin::push_undo_split(enode* n) { + m_trail.push_back(new (get_region()) undo_split(*this, n)); + push_plugin_undo(bv.get_family_id()); + } + + void bv_plugin::undo() { + m_trail.back()->undo(); + m_trail.pop_back(); + } + + + void bv_plugin::propagate_register_node(enode* n) { + TRACE("bv", tout << "register " << g.bpp(n) << "\n"); + enode* a, * b; + unsigned lo, hi; + if (is_concat(n, a, b)) { + auto& i = info(n); + i.value = n; + i.hi = a; + i.lo = b; + i.cut = width(b); + push_undo_split(n); + } + else if (is_concat(n) && n->num_args() != 2) { + SASSERT(n->num_args() != 0); + auto last = n->get_arg(n->num_args() - 1); + for (unsigned i = n->num_args() - 1; i-- > 0;) + last = mk_concat(n->get_arg(i), last); + push_merge(last, n); + } + else if (is_extract(n, lo, hi) && (lo != 0 || hi + 1 != width(n->get_arg(0)))) { + enode* arg = n->get_arg(0); + unsigned w = width(arg); + if (all_of(enode_parents(arg), [&](enode* p) { unsigned _lo, _hi; return !is_extract(p, _lo, _hi) || _lo != 0 || _hi + 1 != w; })) + push_merge(mk_extract(arg, 0, w - 1), arg); + ensure_slice(arg, lo, hi); + } + TRACE("bv", tout << "done register " << g.bpp(n) << "\n"); + } + + // + // Ensure that there are slices at boundaries of n[hi:lo] + // + void bv_plugin::ensure_slice(enode* n, unsigned lo, unsigned hi) { + enode* r = n; + unsigned lb = 0, ub = width(n) - 1; + while (true) { + TRACE("bv", tout << "ensure slice " << g.bpp(n) << " " << lb << " [" << lo << ", " << hi << "] " << ub << "\n"); + SASSERT(lb <= lo && hi <= ub); + SASSERT(ub - lb + 1 == width(r)); + if (lb == lo && ub == hi) + return; + slice_info const& i = info(r); + + if (!i.lo) { + if (lo > lb) { + split(r, lo - lb); + if (hi < ub) // or split(info(r).hi, ...) + ensure_slice(n, lo, hi); + } + else if (hi < ub) + split(r, ub - hi); + break; + } + auto cut = i.cut; + if (cut + lb <= lo) { + lb += cut; + r = i.hi; + continue; + } + if (cut + lb > hi) { + ub = cut + lb - 1; + r = i.lo; + continue; + } + SASSERT(lo < cut + lb && cut + lb <= hi); + ensure_slice(n, lo, cut + lb - 1); + ensure_slice(n, cut + lb, hi); + break; + } + } + + enode* bv_plugin::mk_extract(enode* n, unsigned lo, unsigned hi) { + SASSERT(lo <= hi && width(n) > hi - lo); + unsigned lo1, hi1; + while (is_extract(n, lo1, hi1)) { + lo += lo1; + hi += lo1; + n = n->get_arg(0); + } + if (n->interpreted()) { + auto v = get_value(n); + if (lo > 0) + v = div(v, rational::power_of_two(lo)); + if (hi + 1 != width(n)) + v = mod(v, rational::power_of_two(hi + 1)); + return mk_value(v, hi - lo + 1); + } + return mk(bv.mk_extract(hi, lo, n->get_expr()), 1, &n); + } + + enode* bv_plugin::mk_concat(enode* hi, enode* lo) { + enode* args[2] = { hi, lo }; + return mk(bv.mk_concat(hi->get_expr(), lo->get_expr()), 2, args); + } + + void bv_plugin::merge(enode_vector& xs, enode_vector& ys, justification dep) { + while (!xs.empty()) { + SASSERT(!ys.empty()); + auto x = xs.back(); + auto y = ys.back(); + TRACE("bv", tout << "merge " << g.bpp(x) << " " << g.bpp(y) << "\n"); + if (unfold_sub(x, xs)) + continue; + else if (unfold_sub(y, ys)) + continue; + else if (unfold_width(x, xs, y, ys)) + continue; + else if (unfold_width(y, ys, x, xs)) + continue; + else if (x->get_root() != y->get_root()) + push_merge(x, y, dep); + xs.pop_back(); + ys.pop_back(); + } + SASSERT(ys.empty()); + } + + bool bv_plugin::unfold_sub(enode* x, enode_vector& xs) { + if (!has_sub(x)) + return false; + xs.pop_back(); + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + return true; + } + + bool bv_plugin::unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys) { + if (width(x) <= width(y)) + return false; + split(x, width(y)); + xs.pop_back(); + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + return true; + } + + void bv_plugin::split(enode* n, unsigned cut) { + TRACE("bv", tout << "split: " << g.bpp(n) << " " << cut << "\n"); + unsigned w = width(n); + SASSERT(!info(n).hi); + SASSERT(0 < cut && cut < w); + enode* hi = mk_extract(n, cut, w - 1); + enode* lo = mk_extract(n, 0, cut - 1); + auto& i = info(n); + i.value = n; + i.hi = hi; + i.lo = lo; + i.cut = cut; + push_undo_split(n); + push_merge(mk_concat(hi, lo), n); + } + + void bv_plugin::sub_slices(enode* n, std::function& consumer) { + m_todo.push_back({ n, 0 }); + unsigned lo, hi; + expr* e; + + for (unsigned i = 0; i < m_todo.size(); ++i) { + auto [n, offset] = m_todo[i]; + m_offsets.reserve(n->get_root_id() + 1); + auto& offsets = m_offsets[n->get_root_id()]; + if (offsets.contains(offset)) + continue; + offsets.push_back(offset); + if (!consumer(n, offset)) + continue; + for (auto sib : euf::enode_class(n)) { + if (bv.is_concat(sib->get_expr())) { + unsigned delta = 0; + for (unsigned j = sib->num_args(); j-- > 0; ) { + auto arg = sib->get_arg(j); + m_todo.push_back({ arg, offset + delta }); + delta += width(arg); + } + } + for (auto p : euf::enode_parents(sib)) { + if (bv.is_extract(p->get_expr(), lo, hi, e)) { + SASSERT(g.find(e)->get_root() == n->get_root()); + m_todo.push_back({ p, offset + lo }); + } + } + } + + } + clear_offsets(); + } + + void bv_plugin::super_slices(enode* n, std::function& consumer) { + m_todo.push_back({ n, 0 }); + unsigned lo, hi; + expr* e; + + for (unsigned i = 0; i < m_todo.size(); ++i) { + auto [n, offset] = m_todo[i]; + m_offsets.reserve(n->get_root_id() + 1); + auto& offsets = m_offsets[n->get_root_id()]; + if (offsets.contains(offset)) + continue; + offsets.push_back(offset); + if (!consumer(n, offset)) + continue; + for (auto sib : euf::enode_class(n)) { + if (bv.is_extract(sib->get_expr(), lo, hi, e)) { + auto child = g.find(e); + m_todo.push_back({ child, offset + lo }); + } + for (auto p : euf::enode_parents(sib)) { + if (bv.is_concat(p->get_expr())) { + unsigned delta = 0; + for (unsigned j = p->num_args(); j-- > 0; ) { + auto arg = p->get_arg(j); + if (arg->get_root() == n->get_root()) + m_todo.push_back({ p, offset + delta }); + delta += width(arg); + } + } + } + } + } + clear_offsets(); + } + + // + // Explain that a is a subslice of b at offset + // or that b is a subslice of a at offset + // + void bv_plugin::explain_slice(enode* a, unsigned offset, enode* b, std::function& consumer) { + if (width(a) < width(b)) + std::swap(a, b); + SASSERT(width(a) >= width(b)); + svector> just; + m_jtodo.push_back({ a, 0, UINT_MAX }); + unsigned lo, hi; + expr* e; + + for (unsigned i = 0; i < m_jtodo.size(); ++i) { + auto [n, offs, j] = m_jtodo[i]; + m_offsets.reserve(n->get_root_id() + 1); + auto& offsets = m_offsets[n->get_root_id()]; + if (offsets.contains(offs)) + continue; + offsets.push_back(offs); + if (n->get_root() == b->get_root() && offs == offset) { + if (n != b) + consumer(n, b); + while (j != UINT_MAX) { + auto [x, y, j2] = just[j]; + if (x != y) + consumer(x, y); + j = j2; + } + for (auto const& [n, offset, j] : m_jtodo) { + m_offsets.reserve(n->get_root_id() + 1); + m_offsets[n->get_root_id()].reset(); + } + DEBUG_CODE( + for (auto const& off : m_offsets) { + VERIFY(off.empty()); + }); + m_jtodo.reset(); + return; + } + for (auto sib : euf::enode_class(n)) { + if (bv.is_concat(sib->get_expr())) { + unsigned delta = 0; + unsigned j2 = just.size(); + just.push_back({ n, sib, j }); + for (unsigned j = sib->num_args(); j-- > 0; ) { + auto arg = sib->get_arg(j); + m_jtodo.push_back({ arg, offs + delta, j2 }); + delta += width(arg); + } + } + for (auto p : euf::enode_parents(sib)) { + if (bv.is_extract(p->get_expr(), lo, hi, e)) { + SASSERT(g.find(e)->get_root() == n->get_root()); + unsigned j2 = just.size(); + just.push_back({ g.find(e), n, j }); + m_jtodo.push_back({ p, offs + lo, j2 }); + } + } + } + + } + IF_VERBOSE(0, + g.display(verbose_stream()); + verbose_stream() << g.bpp(a) << " offset " << offset << " " << g.bpp(b) << "\n"; + for (auto const& [n, offset, j] : m_jtodo) + verbose_stream() << g.bpp(n) << " offset " << offset << " " << g.bpp(n->get_root()) << "\n"; + ); + UNREACHABLE(); + } + + void bv_plugin::clear_offsets() { + for (auto const& [n, offset] : m_todo) { + m_offsets.reserve(n->get_root_id() + 1); + m_offsets[n->get_root_id()].reset(); + } + m_todo.reset(); + } + + std::ostream& bv_plugin::display(std::ostream& out) const { + out << "bv\n"; + for (auto const& i : m_info) + if (i.lo) + out << g.bpp(i.value) << " cut " << i.cut << " lo " << g.bpp(i.lo) << " hi " << g.bpp(i.hi) << "\n"; + return out; + } +} diff --git a/src/ast/euf/euf_bv_plugin.h b/src/ast/euf/euf_bv_plugin.h new file mode 100644 index 00000000000..ec2c0b448f4 --- /dev/null +++ b/src/ast/euf/euf_bv_plugin.h @@ -0,0 +1,123 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_bv_plugin.h + +Abstract: + + plugin structure for bit-vectors + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + Jakob Rath 2023-11-08 + + +--*/ + +#pragma once + +#include "util/trail.h" +#include "ast/bv_decl_plugin.h" +#include "ast/euf/euf_plugin.h" + +namespace euf { + + class egraph; + + class bv_plugin : public plugin { + static constexpr unsigned null_cut = std::numeric_limits::max(); + + struct slice_info { + unsigned cut = null_cut; // = bv.get_bv_size(lo) + enode* hi = nullptr; // + enode* lo = nullptr; // + enode* value = nullptr; + void reset() { *this = slice_info(); } + }; + using slice_info_vector = svector; + + bv_util bv; + slice_info_vector m_info; // indexed by enode::get_id() + enode_vector m_xs, m_ys; + + std::function m_ensure_th_var; + + bool is_concat(enode* n) const { return bv.is_concat(n->get_expr()); } + bool is_concat(enode* n, enode*& a, enode*& b) { return is_concat(n) && n->num_args() == 2 && (a = n->get_arg(0), b = n->get_arg(1), true); } + bool is_extract(enode* n, unsigned& lo, unsigned& hi) { expr* body; return bv.is_extract(n->get_expr(), lo, hi, body); } + bool is_extract(enode* n) const { return bv.is_extract(n->get_expr()); } + unsigned width(enode* n) const { return bv.get_bv_size(n->get_expr()); } + + enode* mk_extract(enode* n, unsigned lo, unsigned hi); + enode* mk_concat(enode* hi, enode* lo); + enode* mk_value_concat(enode* hi, enode* lo); + enode* mk_value(rational const& v, unsigned sz); + unsigned width(enode* n) { return bv.get_bv_size(n->get_expr()); } + bool is_value(enode* n) { return n->get_root()->interpreted(); } + rational get_value(enode* n) { rational val; VERIFY(bv.is_numeral(n->get_interpreted()->get_expr(), val)); return val; } + slice_info& info(enode* n) { unsigned id = n->get_id(); m_info.reserve(id + 1); return m_info[id]; } + bool has_sub(enode* n) { return !!info(n).lo; } + enode* sub_lo(enode* n) { return info(n).lo; } + enode* sub_hi(enode* n) { return info(n).hi; } + + bool m_internal = false; + void ensure_slice(enode* n, unsigned lo, unsigned hi); + + + void split(enode* n, unsigned cut); + + bool unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys); + bool unfold_sub(enode* x, enode_vector& xs); + void merge(enode_vector& xs, enode_vector& ys, justification j); + svector> m_ensure_concat; + void propagate_extract(enode* n); + void propagate_values(enode* n); + + vector m_offsets; + svector> m_todo; + svector> m_jtodo; + void clear_offsets(); + + + ptr_vector m_trail; + + class undo_split; + void push_undo_split(enode* n); + + vector> m_queue; + unsigned m_qhead = 0; + void propagate_register_node(enode* n); + void propagate_merge(enode* a, enode* b); + + public: + bv_plugin(egraph& g); + + ~bv_plugin() override {} + + theory_id get_id() const override { return bv.get_family_id(); } + + void register_node(enode* n) override; + + void merge_eh(enode* n1, enode* n2) override; + + void diseq_eh(enode* eq) override {} + + void propagate() override; + + void undo() override; + + void set_ensure_th_var(std::function& f) { m_ensure_th_var = f; } + + std::ostream& display(std::ostream& out) const override; + + void sub_slices(enode* n, std::function& consumer); + + void super_slices(enode* n, std::function& consumer); + + void explain_slice(enode* a, unsigned offset, enode* b, std::function& consumer); + + }; +} diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index ac60a98ba96..f19a146c5c6 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -36,8 +36,10 @@ namespace euf { } m_expr2enode.setx(f->get_id(), n, nullptr); push_node(n); - for (unsigned i = 0; i < num_args; ++i) - set_cgc_enabled(args[i], true); + for (unsigned i = 0; i < num_args; ++i) { + set_cgc_enabled(args[i], true); + args[i]->get_root()->set_is_shared(l_undef); + } return n; } @@ -67,7 +69,7 @@ namespace euf { } enode_bool_pair egraph::insert_table(enode* p) { - TRACE("euf", tout << bpp(p) << "\n"); + TRACE("euf_verbose", tout << "insert_table " << bpp(p) << "\n"); //SASSERT(!m_table.contains_ptr(p)); auto rc = m_table.insert(p); p->m_cg = rc.first; @@ -79,9 +81,14 @@ namespace euf { } void egraph::reinsert_equality(enode* p) { - SASSERT(p->is_equality()); + SASSERT(p->is_equality()); if (p->value() != l_true && p->get_arg(0)->get_root() == p->get_arg(1)->get_root()) - add_literal(p, nullptr); + queue_literal(p, nullptr); + } + + void egraph::queue_literal(enode* p, enode* ante) { + if (m_on_propagate_literal) + m_to_merge.push_back(to_merge(p, ante)); } void egraph::force_push() { @@ -100,6 +107,8 @@ namespace euf { void egraph::update_children(enode* n) { for (enode* child : enode_args(n)) child->get_root()->add_parent(n); + for (enode* child : enode_args(n)) + SASSERT(child->get_root()->m_parents.back() == n); m_updates.push_back(update_record(n, update_record::update_children())); } @@ -123,8 +132,8 @@ namespace euf { if (n2 == n) update_children(n); else - merge(n, n2, justification::congruence(comm, m_congruence_timestamp++)); - + push_merge(n, n2, comm); + return n; } @@ -139,42 +148,55 @@ namespace euf { memory::deallocate(m_tmp_node); } + void egraph::add_plugin(plugin* p) { + m_plugins.reserve(p->get_id() + 1); + m_plugins.set(p->get_id(), p); + } + + void egraph::propagate_plugins() { + for (auto* p : m_plugins) + if (p) + p->propagate(); + } + void egraph::add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) { - TRACE("euf_verbose", tout << "eq: " << v1 << " == " << v2 << "\n";); + TRACE("euf", tout << "eq: " << v1 << " == " << v2 << " - " << bpp(c) << " == " << bpp(r) << "\n";); m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r)); m_updates.push_back(update_record(update_record::new_th_eq())); ++m_stats.m_num_th_eqs; + auto* p = get_plugin(id); + if (p) + p->merge_eh(c, r); } - void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq) { + void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq) { if (!th_propagates_diseqs(id)) return; TRACE("euf_verbose", tout << "eq: " << v1 << " != " << v2 << "\n";); - m_new_th_eqs.push_back(th_eq(id, v1, v2, eq)); + m_new_th_eqs.push_back(th_eq(id, v1, v2, eq->get_expr())); m_updates.push_back(update_record(update_record::new_th_eq())); + auto* p = get_plugin(id); + if (p) + p->diseq_eh(eq); ++m_stats.m_num_th_diseqs; } void egraph::add_literal(enode* n, enode* ante) { + TRACE("euf", tout << "propagate " << bpp(n) << " " << bpp(ante) << "\n"); if (!m_on_propagate_literal) return; if (!ante) ++m_stats.m_num_eqs; else ++m_stats.m_num_lits; if (!ante) m_on_propagate_literal(n, ante); else if (m.is_true(ante->get_expr()) || m.is_false(ante->get_expr())) { - for (enode* k : enode_class(n)) { - if (k != ante) { - //verbose_stream() << "eq: " << k->value() << " " <value() << "\n"; - m_on_propagate_literal(k, ante); - } - } + for (enode* k : enode_class(n)) + if (k != ante) + m_on_propagate_literal(k, ante); } else { for (enode* k : enode_class(n)) { - if (k->value() != ante->value()) { - //verbose_stream() << "eq: " << k->value() << " " <value() << "\n"; - m_on_propagate_literal(k, ante); - } + if (k->value() != ante->value()) + m_on_propagate_literal(k, ante); } } } @@ -200,7 +222,7 @@ namespace euf { return; theory_var v1 = arg1->get_closest_th_var(id); theory_var v2 = arg2->get_closest_th_var(id); - add_th_diseq(id, v1, v2, n->get_expr()); + add_th_diseq(id, v1, v2, n); return; } for (auto const& p : euf::enode_th_vars(r1)) { @@ -208,8 +230,8 @@ namespace euf { continue; for (auto const& q : euf::enode_th_vars(r2)) if (p.get_id() == q.get_id()) - add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n->get_expr()); - } + add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n); + } } @@ -228,7 +250,7 @@ namespace euf { n = n->get_root(); theory_var v2 = n->get_closest_th_var(id); if (v2 != null_theory_var) - add_th_diseq(id, v1, v2, p->get_expr()); + add_th_diseq(id, v1, v2, p); } } } @@ -247,6 +269,10 @@ namespace euf { theory_var w = n->get_th_var(id); enode* r = n->get_root(); + auto* p = get_plugin(id); + if (p) + p->register_node(n); + if (w == null_theory_var) { n->add_th_var(v, id, m_region); m_updates.push_back(update_record(n, id, update_record::add_th_var())); @@ -418,10 +444,15 @@ namespace euf { break; case update_record::tag_t::is_update_children: for (unsigned i = 0; i < p.r1->num_args(); ++i) { + CTRACE("euf", (p.r1->m_args[i]->get_root()->m_parents.back() != p.r1), + display(tout << bpp(p.r1->m_args[i]) << " " << bpp(p.r1->m_args[i]->get_root()) << " ");); SASSERT(p.r1->m_args[i]->get_root()->m_parents.back() == p.r1); p.r1->m_args[i]->get_root()->m_parents.pop_back(); } break; + case update_record::tag_t::is_plugin_undo: + m_plugins[p.m_th_id]->undo(); + break; default: UNREACHABLE(); break; @@ -476,6 +507,7 @@ namespace euf { c->m_root = r2; std::swap(r1->m_next, r2->m_next); r2->inc_class_size(r1->class_size()); + r2->set_is_shared(l_undef); merge_th_eq(r1, r2); reinsert_parents(r1, r2); if (j.is_congruence() && (m.is_false(r2->get_expr()) || m.is_true(r2->get_expr()))) @@ -490,7 +522,8 @@ namespace euf { } void egraph::remove_parents(enode* r) { - TRACE("euf", tout << bpp(r) << "\n"); + TRACE("euf_verbose", tout << bpp(r) << "\n"); + SASSERT(all_of(enode_parents(r), [&](enode* p) { return !p->is_marked1(); })); for (enode* p : enode_parents(r)) { if (p->is_marked1()) continue; @@ -500,7 +533,7 @@ namespace euf { SASSERT(m_table.contains_ptr(p)); p->mark1(); erase_from_table(p); - CTRACE("euf", m_table.contains_ptr(p), tout << bpp(p) << "\n"; display(tout)); + CTRACE("euf_verbose", m_table.contains_ptr(p), tout << bpp(p) << "\n"; display(tout)); SASSERT(!m_table.contains_ptr(p)); } else if (p->is_equality()) @@ -513,11 +546,11 @@ namespace euf { if (!p->is_marked1()) continue; p->unmark1(); - TRACE("euf", tout << "reinsert " << bpp(r1) << " " << bpp(r2) << " " << bpp(p) << " " << p->cgc_enabled() << "\n";); + TRACE("euf_verbose", tout << "reinsert " << bpp(r1) << " " << bpp(r2) << " " << bpp(p) << " " << p->cgc_enabled() << "\n";); if (p->cgc_enabled()) { auto [p_other, comm] = insert_table(p); SASSERT(m_table.contains_ptr(p) == (p_other == p)); - TRACE("euf", tout << "other " << bpp(p_other) << "\n";); + CTRACE("euf_verbose", p_other != p, tout << "reinsert " << bpp(p) << " == " << bpp(p_other) << " " << p->value() << " " << p_other->value() << "\n"); if (p_other != p) m_to_merge.push_back(to_merge(p_other, p, comm)); else @@ -551,13 +584,14 @@ namespace euf { void egraph::undo_eq(enode* r1, enode* n1, unsigned r2_num_parents) { enode* r2 = r1->get_root(); - TRACE("euf", tout << "undo-eq old-root: " << bpp(r1) << " current-root " << bpp(r2) << " node: " << bpp(n1) << "\n";); + TRACE("euf_verbose", tout << "undo-eq old-root: " << bpp(r1) << " current-root " << bpp(r2) << " node: " << bpp(n1) << "\n";); r2->dec_class_size(r1->class_size()); + r2->set_is_shared(l_undef); std::swap(r1->m_next, r2->m_next); auto begin = r2->begin_parents() + r2_num_parents, end = r2->end_parents(); for (auto it = begin; it != end; ++it) { enode* p = *it; - TRACE("euf", tout << "erase " << bpp(p) << "\n";); + TRACE("euf_verbose", tout << "erase " << bpp(p) << "\n";); SASSERT(!p->cgc_enabled() || m_table.contains_ptr(p)); SASSERT(!p->cgc_enabled() || p->is_cgr()); if (p->cgc_enabled()) @@ -576,11 +610,27 @@ namespace euf { bool egraph::propagate() { - SASSERT(m_num_scopes == 0 || m_to_merge.empty()); force_push(); - for (unsigned i = 0; i < m_to_merge.size() && m.limit().inc() && !inconsistent(); ++i) { - auto const& w = m_to_merge[i]; - merge(w.a, w.b, justification::congruence(w.commutativity, m_congruence_timestamp++)); + unsigned i = 0; + bool change = true; + while (change) { + change = false; + propagate_plugins(); + for (; i < m_to_merge.size() && m.limit().inc() && !inconsistent(); ++i) { + auto const& w = m_to_merge[i]; + switch (w.t) { + case to_merge_plain: + case to_merge_comm: + merge(w.a, w.b, justification::congruence(w.commutativity(), m_congruence_timestamp++)); + break; + case to_justified: + merge(w.a, w.b, w.j); + break; + case to_add_literal: + add_literal(w.a, w.b); + break; + } + } } m_to_merge.reset(); return @@ -596,7 +646,7 @@ namespace euf { m_updates.push_back(update_record(false, update_record::inconsistent())); m_n1 = n1; m_n2 = n2; - TRACE("euf", tout << "conflict " << bpp(n1) << " " << bpp(n2) << " " << j << "\n"); + TRACE("euf", tout << "conflict " << bpp(n1) << " " << bpp(n2) << " " << j << " " << n1->get_root()->value() << " " << n2->get_root()->value() << "\n"); m_justification = j; } @@ -623,7 +673,7 @@ namespace euf { SASSERT(n1->get_root()->reaches(n1)); SASSERT(n1->m_target); n1->m_target = nullptr; - n1->m_justification = justification::axiom(); + n1->m_justification = justification::axiom(null_theory_id); n1->get_root()->reverse_justification(); // --------------- // n1 -> ... -> r1 @@ -748,6 +798,16 @@ namespace euf { justifications.push_back(j.ext()); else if (j.is_congruence()) push_congruence(a, b, j.is_commutative()); + else if (j.is_dependent()) { + vector js; + for (auto const& j2 : justification::dependency_manager::s_linearize(j.get_dependency(), js)) + explain_eq(justifications, cc, a, b, j2); + } + else if (j.is_equality()) + explain_eq(justifications, cc, j.lhs(), j.rhs()); + else if (j.is_axiom() && j.get_theory_id() != null_theory_id) { + IF_VERBOSE(20, verbose_stream() << "TODO add theory axiom to justification\n"); + } if (cc && j.is_congruence()) cc->push_back(std::tuple(a->get_app(), b->get_app(), j.timestamp(), j.is_commutative())); } @@ -867,6 +927,9 @@ namespace euf { max_args = std::max(max_args, n->num_args()); for (enode* n : m_nodes) display(out, max_args, n); + for (auto* p : m_plugins) + if (p) + p->display(out); return out; } diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index c1b9b784953..8822b07e793 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -19,7 +19,7 @@ Module Name: - data structures form the (legacy) SMT solver. - it still uses eager path compression. - NB. The worklist is in reality inheritied from the legacy SMT solver. + NB. The worklist is in reality inherited from the legacy SMT solver. It is claimed to have the same effect as delayed congruence table reconstruction from egg. Similar to the legacy solver, parents are partially deduplicated. @@ -29,8 +29,10 @@ Module Name: #include "util/statistics.h" #include "util/trail.h" #include "util/lbool.h" +#include "util/scoped_ptr_vector.h" #include "ast/euf/euf_enode.h" #include "ast/euf/euf_etable.h" +#include "ast/euf/euf_plugin.h" #include "ast/ast_ll_pp.h" #include @@ -82,12 +84,19 @@ namespace euf { class egraph { + friend class plugin; + typedef ptr_vector trail_stack; + enum to_merge_t { to_merge_plain, to_merge_comm, to_justified, to_add_literal }; struct to_merge { enode* a, * b; - bool commutativity; - to_merge(enode* a, enode* b, bool c) : a(a), b(b), commutativity(c) {} + to_merge_t t; + justification j; + bool commutativity() const { return t == to_merge_comm; } + to_merge(enode* a, enode* b, bool c) : a(a), b(b), t(c ? to_merge_comm : to_merge_plain) {} + to_merge(enode* a, enode* b, justification j): a(a), b(b), t(to_justified), j(j) {} + to_merge(enode* p, enode* ante): a(p), b(ante), t(to_add_literal) {} }; struct stats { @@ -113,10 +122,12 @@ namespace euf { struct lbl_set {}; struct update_children {}; struct set_relevant {}; + struct plugin_undo {}; enum class tag_t { is_set_parent, is_add_node, is_toggle_cgc, is_toggle_merge_tf, is_update_children, is_add_th_var, is_replace_th_var, is_new_th_eq, is_lbl_hash, is_new_th_eq_qhead, - is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant }; + is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant, + is_plugin_undo }; tag_t tag; enode* r1; enode* n1; @@ -159,11 +170,14 @@ namespace euf { tag(tag_t::is_update_children), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} update_record(enode* n, set_relevant) : tag(tag_t::is_set_relevant), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} + update_record(unsigned th_id, plugin_undo) : + tag(tag_t::is_plugin_undo), r1(nullptr), n1(nullptr), m_th_id(th_id) {} }; ast_manager& m; svector m_to_merge; etable m_table; region m_region; + scoped_ptr_vector m_plugins; svector m_updates; unsigned_vector m_scopes; enode_vector m_expr2enode; @@ -202,11 +216,18 @@ namespace euf { } void push_node(enode* n) { m_updates.push_back(update_record(n)); } + // plugin related methods + void push_plugin_undo(unsigned th_id) { m_updates.push_back(update_record(th_id, update_record::plugin_undo())); } + void push_merge(enode* a, enode* b, justification j) { SASSERT(a->get_sort() == b->get_sort()); m_to_merge.push_back({ a, b, j }); } + void push_merge(enode* a, enode* b, bool comm) { m_to_merge.push_back({ a, b, comm }); } + void propagate_plugins(); + void add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r); void add_th_diseqs(theory_id id, theory_var v1, enode* r); bool th_propagates_diseqs(theory_id id) const; void add_literal(enode* n, enode* ante); + void queue_literal(enode* n, enode* ante); void undo_eq(enode* r1, enode* n1, unsigned r2_num_parents); void undo_add_th_var(enode* n, theory_id id); enode* mk_enode(expr* f, unsigned generation, unsigned num_args, enode * const* args); @@ -241,11 +262,15 @@ namespace euf { public: egraph(ast_manager& m); ~egraph(); + + void add_plugin(plugin* p); + plugin* get_plugin(family_id fid) const { return m_plugins.get(fid, nullptr); } + enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); } enode* find(expr* f, unsigned n, enode* const* args); enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args); enode_vector const& enodes_of(func_decl* f); - void push() { if (!m_to_merge.empty()) propagate(); ++m_num_scopes; } + void push() { if (can_propagate()) propagate(); ++m_num_scopes; } void pop(unsigned num_scopes); /** @@ -265,6 +290,7 @@ namespace euf { of new equalities. */ bool propagate(); + bool can_propagate() const { return !m_to_merge.empty(); } bool inconsistent() const { return m_inconsistent; } /** @@ -282,7 +308,7 @@ namespace euf { where \c n is an enode and \c is_eq indicates whether the enode is an equality consequence. */ - void add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq); + void add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq); bool has_th_eq() const { return m_new_th_eqs_qhead < m_new_th_eqs.size(); } th_eq get_th_eq() const { return m_new_th_eqs[m_new_th_eqs_qhead]; } void next_th_eq() { force_push(); SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; } diff --git a/src/ast/euf/euf_enode.cpp b/src/ast/euf/euf_enode.cpp index 08df9f4939a..2149059b483 100644 --- a/src/ast/euf/euf_enode.cpp +++ b/src/ast/euf/euf_enode.cpp @@ -93,6 +93,17 @@ namespace euf { return null_theory_var; } + enode* enode::get_closest_th_node(theory_id id) { + enode* n = this; + while (n) { + theory_var v = n->get_th_var(id); + if (v != null_theory_var) + return n; + n = n->m_target; + } + return nullptr; + } + bool enode::acyclic() const { enode const* n = this; enode const* p = this; @@ -124,7 +135,7 @@ namespace euf { enode* prev = this; justification js = m_justification; prev->m_target = nullptr; - prev->m_justification = justification::axiom(); + prev->m_justification = justification::axiom(null_theory_id); while (curr != nullptr) { enode* new_curr = curr->m_target; justification new_js = curr->m_justification; diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index d9ae45074e3..50a7ce479e3 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -36,10 +36,6 @@ namespace euf { typedef std::pair enode_bool_pair; typedef svector enode_bool_pair_vector; typedef id_var_list<> th_var_list; - typedef int theory_var; - typedef int theory_id; - const theory_var null_theory_var = -1; - const theory_id null_theory_id = -1; class enode { expr* m_expr = nullptr; @@ -52,6 +48,7 @@ namespace euf { bool m_merge_tf_enabled = false; bool m_is_equality = false; // Does the expression represent an equality bool m_is_relevant = false; + lbool m_is_shared = l_undef; lbool m_value = l_undef; // Assignment by SAT solver for Boolean node sat::bool_var m_bool_var = sat::null_bool_var; // SAT solver variable associated with Boolean node unsigned m_class_size = 1; // Size of the equivalence class if the enode is the root. @@ -96,6 +93,7 @@ namespace euf { for (unsigned i = 0; i < num_args; ++i) { SASSERT(to_app(f)->get_arg(i) == args[i]->get_expr()); n->m_args[i] = args[i]; + n->m_args[i]->get_root()->set_is_shared(l_undef); } return n; } @@ -181,6 +179,9 @@ namespace euf { void unmark3() { m_mark3 = false; } bool is_marked3() { return m_mark3; } + lbool is_shared() const { return m_is_shared; } + void set_is_shared(lbool s) { m_is_shared = s; } + template void mark1_targets() { enode* n = this; while (n) { @@ -202,6 +203,7 @@ namespace euf { enode* get_root() const { return m_root; } expr* get_expr() const { return m_expr; } sort* get_sort() const { return m_expr->get_sort(); } + enode* get_interpreted() const { return get_root(); } app* get_app() const { return to_app(m_expr); } func_decl* get_decl() const { return is_app(m_expr) ? to_app(m_expr)->get_decl() : nullptr; } unsigned get_expr_id() const { return m_expr->get_id(); } @@ -211,6 +213,10 @@ namespace euf { bool children_are_roots() const; enode* get_next() const { return m_next; } + enode* get_target() const { return m_target; } + justification get_justification() const { return m_justification; } + justification get_lit_justification() const { return m_lit_justification; } + bool has_lbl_hash() const { return m_lbl_hash >= 0; } unsigned char get_lbl_hash() const { SASSERT(m_lbl_hash >= 0 && static_cast(m_lbl_hash) < approx_set_traits::capacity); @@ -224,6 +230,7 @@ namespace euf { theory_var get_th_var(theory_id id) const { return m_th_vars.find(id); } theory_var get_closest_th_var(theory_id id) const; + enode* get_closest_th_node(theory_id id); bool is_attached_to(theory_id id) const { return get_th_var(id) != null_theory_var; } bool has_th_vars() const { return !m_th_vars.empty(); } bool has_one_th_var() const { return !m_th_vars.empty() && !m_th_vars.get_next();} diff --git a/src/ast/euf/euf_etable.cpp b/src/ast/euf/euf_etable.cpp index e007297ef54..1fc8aa0e03f 100644 --- a/src/ast/euf/euf_etable.cpp +++ b/src/ast/euf/euf_etable.cpp @@ -237,6 +237,8 @@ namespace euf { UNTAG(table*, t)->erase(n); break; } + CTRACE("euf", contains_ptr(n), display(tout)); + SASSERT(!contains_ptr(n)); } bool etable::contains(enode* n) const { diff --git a/src/ast/euf/euf_justification.cpp b/src/ast/euf/euf_justification.cpp new file mode 100644 index 00000000000..22b52ea848a --- /dev/null +++ b/src/ast/euf/euf_justification.cpp @@ -0,0 +1,54 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + euf_justification.cpp + +Abstract: + + justification structure for euf + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-23 + +--*/ + + +#include "ast/euf/euf_justification.h" +#include "ast/euf/euf_enode.h" + +namespace euf { + + + std::ostream& justification::display(std::ostream& out, std::function const& ext) const { + switch (m_kind) { + case kind_t::external_t: + if (ext) + ext(out, m_external); + else + out << "external"; + return out; + case kind_t::axiom_t: + return out << "axiom"; + case kind_t::congruence_t: + return out << "congruence"; + case kind_t::dependent_t: { + vector js; + out << "dependent"; + for (auto const& j : dependency_manager::s_linearize(m_dependency, js)) + j.display(out << " ", ext); + return out; + } + case kind_t::equality_t: + return out << "equality #" << m_n1->get_id() << " == #" << m_n2->get_id(); + + default: + UNREACHABLE(); + return out; + } + return out; + } + +} diff --git a/src/ast/euf/euf_justification.h b/src/ast/euf/euf_justification.h index 57b532e3bf1..8a5a58ad979 100644 --- a/src/ast/euf/euf_justification.h +++ b/src/ast/euf/euf_justification.h @@ -16,25 +16,46 @@ Module Name: Notes: - congruence closure justifications are given a timestamp so it is easy to sort them. - See the longer descriptoin in euf_proof_checker.cpp + See the longer description in euf_proof_checker.cpp --*/ #pragma once +#include "util/dependency.h" + namespace euf { + class enode; + + typedef int theory_var; + typedef int theory_id; + const theory_var null_theory_var = -1; + const theory_id null_theory_id = -1; + class justification { + public: + typedef stacked_dependency_manager dependency_manager; + typedef stacked_dependency_manager::dependency dependency; + private: enum class kind_t { axiom_t, congruence_t, - external_t + external_t, + dependent_t, + equality_t }; kind_t m_kind; - bool m_comm; + union { + int m_theory_id; + bool m_comm; + enode* m_n1; + }; union { void* m_external; uint64_t m_timestamp; + dependency* m_dependency; + enode* m_n2; }; justification(bool comm, uint64_t ts): @@ -49,21 +70,49 @@ namespace euf { m_external(ext) {} + justification(dependency* dep, int): + m_kind(kind_t::dependent_t), + m_comm(false), + m_dependency(dep) + {} + + justification(enode* n1, enode* n2): + m_kind(kind_t::equality_t), + m_n1(n1), + m_n2(n2) + {} + + justification(int theory_id): + m_kind(kind_t::axiom_t), + m_theory_id(theory_id), + m_external(nullptr) + {} + public: + justification(): m_kind(kind_t::axiom_t), - m_comm(false), + m_theory_id(null_theory_id), m_external(nullptr) {} - static justification axiom() { return justification(); } + static justification axiom(int theory_id) { return justification(theory_id); } static justification congruence(bool c, uint64_t ts) { return justification(c, ts); } static justification external(void* ext) { return justification(ext); } + static justification dependent(dependency* d) { return justification(d, 1); } + static justification equality(enode* a, enode* b) { return justification(a, b); } + bool is_axiom() const { return m_kind == kind_t::axiom_t; } bool is_external() const { return m_kind == kind_t::external_t; } bool is_congruence() const { return m_kind == kind_t::congruence_t; } bool is_commutative() const { return m_comm; } + bool is_dependent() const { return m_kind == kind_t::dependent_t; } + bool is_equality() const { return m_kind == kind_t::equality_t; } + dependency* get_dependency() const { SASSERT(is_dependent()); return m_dependency; } + enode* lhs() const { SASSERT(is_equality()); return m_n1; } + enode* rhs() const { SASSERT(is_equality()); return m_n2; } uint64_t timestamp() const { SASSERT(is_congruence()); return m_timestamp; } + theory_id get_theory_id() const { SASSERT(is_axiom()); return m_theory_id; } template T* ext() const { SASSERT(is_external()); return static_cast(m_external); } @@ -72,33 +121,20 @@ namespace euf { case kind_t::external_t: return external(copy_justification(m_external)); case kind_t::axiom_t: - return axiom(); + return axiom(m_theory_id); case kind_t::congruence_t: return congruence(m_comm, m_timestamp); + case kind_t::dependent_t: + NOT_IMPLEMENTED_YET(); + return dependent(m_dependency); default: UNREACHABLE(); - return axiom(); + return axiom(-1); } } - std::ostream& display(std::ostream& out, std::function const& ext) const { - switch (m_kind) { - case kind_t::external_t: - if (ext) - ext(out, m_external); - else - out << "external"; - return out; - case kind_t::axiom_t: - return out << "axiom"; - case kind_t::congruence_t: - return out << "congruence"; - default: - UNREACHABLE(); - return out; - } - return out; - } + std::ostream& display(std::ostream& out, std::function const& ext) const; + }; inline std::ostream& operator<<(std::ostream& out, justification const& j) { diff --git a/src/ast/euf/euf_plugin.cpp b/src/ast/euf/euf_plugin.cpp new file mode 100644 index 00000000000..f95107bc3cf --- /dev/null +++ b/src/ast/euf/euf_plugin.cpp @@ -0,0 +1,48 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_plugin.cpp + +Abstract: + + plugin structure for euf + + Plugins allow adding equality saturation for theories. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + +--*/ + +#include "ast/euf/euf_egraph.h" + +namespace euf { + + void plugin::push_plugin_undo(unsigned th_id) { + g.push_plugin_undo(th_id); + } + + void plugin::push_merge(enode* a, enode* b, justification j) { + TRACE("euf", tout << "push-merge " << g.bpp(a) << " == " << g.bpp(b) << " " << j << "\n"); + g.push_merge(a, b, j); + } + + void plugin::push_merge(enode* a, enode* b) { + TRACE("plugin", tout << g.bpp(a) << " == " << g.bpp(b) << "\n"); + g.push_merge(a, b, justification::axiom(get_id())); + } + + enode* plugin::mk(expr* e, unsigned n, enode* const* args) { + enode* r = g.find(e); + if (!r) + r = g.mk(e, 0, n, args); + return r; + } + + region& plugin::get_region() { + return g.m_region; + } +} diff --git a/src/ast/euf/euf_plugin.h b/src/ast/euf/euf_plugin.h new file mode 100644 index 00000000000..8dbd4d7e7c1 --- /dev/null +++ b/src/ast/euf/euf_plugin.h @@ -0,0 +1,58 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_plugin.h + +Abstract: + + plugin structure for euf + + Plugins allow adding equality saturation for theories. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-08 + +--*/ + +#pragma once + +#include "ast/euf/euf_enode.h" +#include "ast/euf/euf_justification.h" + +namespace euf { + + + class plugin { + protected: + egraph& g; + void push_plugin_undo(unsigned th_id); + void push_merge(enode* a, enode* b, justification j); + void push_merge(enode* a, enode* b); + enode* mk(expr* e, unsigned n, enode* const* args); + region& get_region(); + public: + plugin(egraph& g): + g(g) + {} + + virtual ~plugin() {} + + virtual theory_id get_id() const = 0; + + virtual void register_node(enode* n) = 0; + + virtual void merge_eh(enode* n1, enode* n2) = 0; + + virtual void diseq_eh(enode* eq) {}; + + virtual void propagate() = 0; + + virtual void undo() = 0; + + virtual std::ostream& display(std::ostream& out) const = 0; + + }; +} diff --git a/src/ast/euf/euf_specrel_plugin.cpp b/src/ast/euf/euf_specrel_plugin.cpp new file mode 100644 index 00000000000..3220a24e601 --- /dev/null +++ b/src/ast/euf/euf_specrel_plugin.cpp @@ -0,0 +1,71 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_specrel_plugin.cpp + +Abstract: + + plugin structure for specrel + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#include "ast/euf/euf_specrel_plugin.h" +#include "ast/euf/euf_egraph.h" +#include + +namespace euf { + + specrel_plugin::specrel_plugin(egraph& g) : + plugin(g), + sp(g.get_manager()) { + } + + void specrel_plugin::register_node(enode* n) { + func_decl* f = n->get_decl(); + if (!f) + return; + if (!sp.is_ac(f)) + return; + ac_plugin* p = nullptr; + if (!m_decl2plugin.find(f, p)) { + p = alloc(ac_plugin, g, f); + m_decl2plugin.insert(f, p); + m_plugins.push_back(p); + std::function undo_op = [&]() { m_undo.push_back(p); }; + p->set_undo(undo_op); + } + } + + void specrel_plugin::merge_eh(enode* n1, enode* n2) { + for (auto * p : m_plugins) + p->merge_eh(n1, n2); + } + + void specrel_plugin::diseq_eh(enode* eq) { + for (auto* p : m_plugins) + p->diseq_eh(eq); + } + + void specrel_plugin::propagate() { + for (auto * p : m_plugins) + p->propagate(); + } + + void specrel_plugin::undo() { + auto p = m_undo.back(); + m_undo.pop_back(); + p->undo(); + } + + std::ostream& specrel_plugin::display(std::ostream& out) const { + for (auto * p : m_plugins) + p->display(out); + return out; + } +} \ No newline at end of file diff --git a/src/ast/euf/euf_specrel_plugin.h b/src/ast/euf/euf_specrel_plugin.h new file mode 100644 index 00000000000..ae93bd2a5eb --- /dev/null +++ b/src/ast/euf/euf_specrel_plugin.h @@ -0,0 +1,56 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + euf_specrel_plugin.h + +Abstract: + + plugin structure for specrel functions + +Author: + + Nikolaj Bjorner (nbjorner) 2023-11-11 + +--*/ + +#pragma once + +#include +#include "util/scoped_ptr_vector.h" +#include "ast/special_relations_decl_plugin.h" +#include "ast/euf/euf_plugin.h" +#include "ast/euf/euf_ac_plugin.h" + +namespace euf { + + class specrel_plugin : public plugin { + scoped_ptr_vector m_plugins; + ptr_vector m_undo; + obj_map m_decl2plugin; + special_relations_util sp; + + public: + + specrel_plugin(egraph& g); + + ~specrel_plugin() override {} + + theory_id get_id() const override { return sp.get_family_id(); } + + void register_node(enode* n) override; + + void merge_eh(enode* n1, enode* n2) override; + + void diseq_eh(enode* eq) override; + + void undo() override; + + void propagate() override; + + std::ostream& display(std::ostream& out) const override; + + }; + +} diff --git a/src/ast/for_each_expr.cpp b/src/ast/for_each_expr.cpp index 832c1d0bcc0..8feb217dbb3 100644 --- a/src/ast/for_each_expr.cpp +++ b/src/ast/for_each_expr.cpp @@ -109,9 +109,9 @@ bool has_skolem_functions(expr * n) { subterms::subterms(expr_ref_vector const& es, bool include_bound, ptr_vector* esp, expr_mark* vp): m_include_bound(include_bound), m_es(es), m_esp(esp), m_vp(vp) {} subterms::subterms(expr_ref const& e, bool include_bound, ptr_vector* esp, expr_mark* vp) : m_include_bound(include_bound), m_es(e.m()), m_esp(esp), m_vp(vp) { if (e) m_es.push_back(e); } -subterms::iterator subterms::begin() { return iterator(* this, m_esp, m_vp, true); } -subterms::iterator subterms::end() { return iterator(*this, nullptr, nullptr, false); } -subterms::iterator::iterator(subterms& f, ptr_vector* esp, expr_mark* vp, bool start): m_include_bound(f.m_include_bound), m_esp(esp), m_visitedp(vp) { +subterms::iterator subterms::begin() const { return iterator(* this, m_esp, m_vp, true); } +subterms::iterator subterms::end() const { return iterator(*this, nullptr, nullptr, false); } +subterms::iterator::iterator(subterms const& f, ptr_vector* esp, expr_mark* vp, bool start): m_include_bound(f.m_include_bound), m_esp(esp), m_visitedp(vp) { if (!esp) m_esp = &m_es; else diff --git a/src/ast/for_each_expr.h b/src/ast/for_each_expr.h index 0ba0dc9926d..77b01e93928 100644 --- a/src/ast/for_each_expr.h +++ b/src/ast/for_each_expr.h @@ -186,7 +186,7 @@ class subterms { expr_mark m_visited; expr_mark* m_visitedp = nullptr; public: - iterator(subterms& f, ptr_vector* esp, expr_mark* vp, bool start); + iterator(subterms const& f, ptr_vector* esp, expr_mark* vp, bool start); expr* operator*(); iterator operator++(int); iterator& operator++(); @@ -198,8 +198,8 @@ class subterms { static subterms ground(expr_ref const& e, ptr_vector* esp = nullptr, expr_mark* vp = nullptr) { return subterms(e, false, esp, vp); } static subterms all(expr_ref_vector const& e, ptr_vector* esp = nullptr, expr_mark* vp = nullptr) { return subterms(e, true, esp, vp); } static subterms ground(expr_ref_vector const& e, ptr_vector* esp = nullptr, expr_mark* vp = nullptr) { return subterms(e, false, esp, vp); } - iterator begin(); - iterator end(); + iterator begin() const; + iterator end() const; }; class subterms_postorder { diff --git a/src/ast/fpa/bv2fpa_converter.cpp b/src/ast/fpa/bv2fpa_converter.cpp index 00e9d71c320..a3008851bc8 100644 --- a/src/ast/fpa/bv2fpa_converter.cpp +++ b/src/ast/fpa/bv2fpa_converter.cpp @@ -319,12 +319,12 @@ func_interp * bv2fpa_converter::convert_func_interp(model_core * mc, func_decl * if (m_fpa_util.is_to_sbv(f) || m_fpa_util.is_to_ubv(f)) { auto k = m_fpa_util.is_to_sbv(f) ? OP_FPA_TO_SBV_I : OP_FPA_TO_UBV_I; - parameter param = f->get_parameter(0); + const parameter ¶m = f->get_parameter(0); func_decl_ref to_bv_i(m.mk_func_decl(fid, k, 1, ¶m, dom.size(), dom.data()), m); expr_ref else_value(m.mk_app(to_bv_i, dom.size(), dom.data()), m); result->set_else(else_value); } - else if (m_fpa_util.is_to_real(f)) { + else if (m_fpa_util.is_to_real(f)) { SASSERT(dom.size() == 1); func_decl_ref to_real_i(m.mk_func_decl(fid, OP_FPA_TO_REAL_I, 0, nullptr, dom.size(), dom.data()), m); expr_ref else_value(m.mk_app(to_real_i, dom.size(), dom.data()), m); @@ -508,7 +508,7 @@ void bv2fpa_converter::convert_uf2bvuf(model_core * mc, model_core * target_mode } } - TRACE("bv2fpa", tout << "Target model: " << *target_model; ); + TRACE("bv2fpa", tout << "Target model: " << *target_model << std::endl; ); } void bv2fpa_converter::display(std::ostream & out) { diff --git a/src/ast/fpa/fpa2bv_converter.cpp b/src/ast/fpa/fpa2bv_converter.cpp index ab13e751a30..25a0e77ad1a 100644 --- a/src/ast/fpa/fpa2bv_converter.cpp +++ b/src/ast/fpa/fpa2bv_converter.cpp @@ -147,36 +147,11 @@ void fpa2bv_converter::mk_distinct(func_decl * f, unsigned num, expr * const * a void fpa2bv_converter::mk_numeral(func_decl * f, unsigned num, expr * const * args, expr_ref & result) { SASSERT(num == 0); - sort* s = f->get_range(); - if (f->get_num_parameters() == 1) { - SASSERT(f->get_parameter(0).is_external()); - unsigned p_id = f->get_parameter(0).get_ext_id(); - mpf const& v = m_plugin->get_value(p_id); - mk_numeral(s, v, result); - return; - } - scoped_mpf v(m_mpf_manager); - unsigned ebits = m_util.get_ebits(s), sbits = m_util.get_sbits(s); - switch (f->get_decl_kind()) { - case OP_FPA_PLUS_INF: - m_util.fm().mk_pinf(ebits, sbits, v); - break; - case OP_FPA_MINUS_INF: - m_util.fm().mk_ninf(ebits, sbits, v); - break; - case OP_FPA_NAN: - m_util.fm().mk_nan(ebits, sbits, v); - break; - case OP_FPA_PLUS_ZERO: - m_util.fm().mk_pzero(ebits, sbits, v); - break; - case OP_FPA_MINUS_ZERO: - m_util.fm().mk_nzero(ebits, sbits, v); - break; - default: - UNREACHABLE(); - } - mk_numeral(s, v, result); + scoped_mpf v(m_mpf_manager); + expr_ref a(m); + a = m.mk_app(f, num, args); + m_util.is_numeral(a, v); + mk_numeral(f->get_range(), v, result); } void fpa2bv_converter::mk_numeral(sort * s, mpf const & v, expr_ref & result) { @@ -941,8 +916,8 @@ void fpa2bv_converter::mk_div(sort * s, expr_ref & rm, expr_ref & x, expr_ref & dbg_decouple("fpa2bv_div_y_is_pos", y_is_pos); dbg_decouple("fpa2bv_div_y_is_inf", y_is_inf); - expr_ref c1(m), c2(m), c3(m), c4(m), c5(m), c6(m), c7(m); - expr_ref v1(m), v2(m), v3(m), v4(m), v5(m), v6(m), v7(m), v8(m); + expr_ref c1(m), c2(m), c3(m), c4(m), c5(m), c6(m), c7(m), c8(m); + expr_ref v1(m), v2(m), v3(m), v4(m), v5(m), v6(m), v7(m), v8(m), v9(m); // (x is NaN) || (y is NaN) -> NaN m_simp.mk_or(x_is_nan, y_is_nan, c1); @@ -998,6 +973,9 @@ void fpa2bv_converter::mk_div(sort * s, expr_ref & rm, expr_ref & x, expr_ref & a_sig_ext = m_bv_util.mk_concat(a_sig, m_bv_util.mk_numeral(0, sbits + extra_bits)); b_sig_ext = m_bv_util.mk_zero_extend(sbits + extra_bits, b_sig); + dbg_decouple("fpa2bv_div_a_sig_ext", a_sig_ext); + dbg_decouple("fpa2bv_div_b_sig_ext", b_sig_ext); + expr_ref a_exp_ext(m), b_exp_ext(m); a_exp_ext = m_bv_util.mk_sign_extend(2, a_exp); b_exp_ext = m_bv_util.mk_sign_extend(2, b_exp); @@ -1017,14 +995,21 @@ void fpa2bv_converter::mk_div(sort * s, expr_ref & rm, expr_ref & x, expr_ref & expr_ref quotient(m); // b_sig_ext can't be 0 here, so it's safe to use OP_BUDIV_I quotient = m.mk_app(m_bv_util.get_fid(), OP_BUDIV_I, a_sig_ext, b_sig_ext); - dbg_decouple("fpa2bv_div_quotient", quotient); SASSERT(m_bv_util.get_bv_size(quotient) == (sbits + sbits + extra_bits)); - expr_ref sticky(m); + expr_ref sticky(m), upper(m), upper_reduced(m), too_large(m); sticky = m.mk_app(m_bv_util.get_fid(), OP_BREDOR, m_bv_util.mk_extract(extra_bits-2, 0, quotient)); res_sig = m_bv_util.mk_concat(m_bv_util.mk_extract(extra_bits+sbits+1, extra_bits-1, quotient), sticky); + upper = m_bv_util.mk_extract(sbits + sbits + extra_bits-1, extra_bits+sbits+2, quotient); + upper_reduced = m.mk_app(m_bv_util.get_fid(), OP_BREDOR, upper.get()); + too_large = m.mk_eq(upper_reduced, m_bv_util.mk_numeral(1, 1)); + c8 = too_large; + mk_ite(signs_xor, ninf, pinf, v8); + dbg_decouple("fpa2bv_div_res_sig_p4", res_sig); + dbg_decouple("fpa2bv_div_upper", upper); + dbg_decouple("fpa2bv_div_too_large", too_large); SASSERT(m_bv_util.get_bv_size(res_sig) == (sbits + 4)); @@ -1042,10 +1027,14 @@ void fpa2bv_converter::mk_div(sort * s, expr_ref & rm, expr_ref & x, expr_ref & m_simp.mk_ite(shift_cond, res_sig, res_sig_shifted, res_sig); m_simp.mk_ite(shift_cond, res_exp, res_exp_shifted, res_exp); - round(s, rm, res_sgn, res_sig, res_exp, v8); + dbg_decouple("fpa2bv_div_res_sig", res_sig); + dbg_decouple("fpa2bv_div_res_exp", res_exp); + + round(s, rm, res_sgn, res_sig, res_exp, v9); // And finally, we tie them together. - mk_ite(c7, v7, v8, result); + mk_ite(c8, v8, v9, result); + mk_ite(c7, v7, result, result); mk_ite(c6, v6, result, result); mk_ite(c5, v5, result, result); mk_ite(c4, v4, result, result); @@ -2200,7 +2189,13 @@ void fpa2bv_converter::mk_round_to_integral(sort * s, expr_ref & rm, expr_ref & renorm_delta = m.mk_ite(m_bv_util.mk_ule(zero_e2, sig_lz_capped), sig_lz_capped, zero_e2); SASSERT(m_bv_util.get_bv_size(renorm_delta) == ebits + 2); res_exp = m_bv_util.mk_bv_sub(res_exp, renorm_delta); - res_sig = m_bv_util.mk_bv_shl(res_sig, m_bv_util.mk_zero_extend(sbits-ebits-2, renorm_delta)); + if (sbits >= ebits+2) + res_sig = m_bv_util.mk_bv_shl(res_sig, m_bv_util.mk_zero_extend(sbits-ebits-2, renorm_delta)); + else { + // should not overflow because renorm_delta is logarithmic to the size of the leading zero bits + res_sig = m_bv_util.mk_bv_shl(m_bv_util.mk_zero_extend(ebits+2-sbits, res_sig), renorm_delta); + res_sig = m_bv_util.mk_extract(sbits-1, 0, res_sig); + } dbg_decouple("fpa2bv_r2i_renorm_delta", renorm_delta); dbg_decouple("fpa2bv_r2i_sig_lz", sig_lz); dbg_decouple("fpa2bv_r2i_sig_lz_capped", sig_lz_capped); @@ -2809,8 +2804,46 @@ void fpa2bv_converter::mk_to_fp_real(func_decl * f, sort * s, expr * rm, expr * expr * e = m.mk_eq(m_util.mk_to_real(result), x); m_extra_assertions.push_back(e); - // x = 0 -> result = 0+ - m_extra_assertions.push_back(m.mk_implies(m.mk_eq(x, zero), m.mk_eq(result, m_util.mk_pzero(result->get_sort())))); + + expr_ref r_is_nan(m); + mk_is_nan(result, r_is_nan); + m_extra_assertions.push_back(m.mk_not(r_is_nan)); + + rational min_real, max_real; + const mpz& max_exp_z = m_mpf_manager.m_powers2.m1(ebits-1); + SASSERT(m_mpz_manager.is_uint(max_exp_z)); + unsigned max_exp = m_mpz_manager.get_uint(max_exp_z); + rational max_sig = m_mpf_manager.m_powers2.m1(sbits) / m_mpf_manager.m_powers2(sbits-1); + max_real = max_sig * rational(m_mpf_manager.m_powers2(max_exp)); + TRACE("fpa2bv_to_real", tout << "max exp: " << max_exp << " max real: " << max_real << std::endl;); + + expr_ref r_is_pinf(m), r_is_ninf(m); + mk_is_pinf(result, r_is_pinf); + mk_is_ninf(result, r_is_ninf); + + expr_ref e_max_real(m), e_max_real_neg(m); + e_max_real = m_arith_util.mk_numeral(max_real, false); + e_max_real_neg = m_arith_util.mk_numeral(-max_real, false); + + expr_ref rm_nta(m), rm_nte(m), rm_tp(m), rm_tn(m), rm_tz(m); + mk_is_rm(bv_rm, BV_RM_TIES_TO_AWAY, rm_nta); + mk_is_rm(bv_rm, BV_RM_TIES_TO_EVEN, rm_nte); + mk_is_rm(bv_rm, BV_RM_TO_POSITIVE, rm_tp); + mk_is_rm(bv_rm, BV_RM_TO_NEGATIVE, rm_tn); + mk_is_rm(bv_rm, BV_RM_TO_ZERO, rm_tz); + + expr_ref implies_gt_max_real(m), implies_lt_min_real(m); + implies_gt_max_real = m.mk_implies(r_is_pinf, m.mk_and(rm_tp, m_arith_util.mk_gt(x, e_max_real))); + implies_lt_min_real = m.mk_implies(r_is_ninf, m.mk_and(rm_tn, m_arith_util.mk_lt(x, e_max_real_neg))); + + m_extra_assertions.push_back(implies_gt_max_real); + m_extra_assertions.push_back(implies_lt_min_real); + + // x = 0 -> result = +0/-0 + expr_ref pzero(m), nzero(m); + mk_pzero(result->get_sort(), pzero); + mk_nzero(result->get_sort(), nzero); + m_extra_assertions.push_back(m.mk_implies(m.mk_eq(x, zero), m.mk_or(m.mk_eq(result, pzero), m.mk_eq(result, nzero)))); } SASSERT(is_well_sorted(m, result)); @@ -2854,19 +2887,13 @@ void fpa2bv_converter::mk_to_fp_real_int(func_decl * f, unsigned num, expr * con m_mpf_manager.set(tn, ebits, sbits, MPF_ROUND_TOWARD_NEGATIVE, e.to_mpq().numerator(), q.to_mpq()); m_mpf_manager.set(tz, ebits, sbits, MPF_ROUND_TOWARD_ZERO, e.to_mpq().numerator(), q.to_mpq()); - app_ref a_nte(m), a_nta(m), a_tp(m), a_tn(m), a_tz(m); - a_nte = m_plugin->mk_numeral(nte); - a_nta = m_plugin->mk_numeral(nta); - a_tp = m_plugin->mk_numeral(tp); - a_tn = m_plugin->mk_numeral(tn); - a_tz = m_plugin->mk_numeral(tz); - expr_ref bv_nte(m), bv_nta(m), bv_tp(m), bv_tn(m), bv_tz(m); - mk_numeral(a_nte->get_decl(), 0, nullptr, bv_nte); - mk_numeral(a_nta->get_decl(), 0, nullptr, bv_nta); - mk_numeral(a_tp->get_decl(), 0, nullptr, bv_tp); - mk_numeral(a_tn->get_decl(), 0, nullptr, bv_tn); - mk_numeral(a_tz->get_decl(), 0, nullptr, bv_tz); + sort *s = f->get_range(); + mk_numeral(s, nte, bv_nte); + mk_numeral(s, nta, bv_nta); + mk_numeral(s, tp, bv_tp); + mk_numeral(s, tn, bv_tn); + mk_numeral(s, tz, bv_tz); expr_ref c1(m), c2(m), c3(m), c4(m); c1 = m.mk_eq(bv_rm, m_bv_util.mk_numeral(BV_RM_TO_POSITIVE, 3)); @@ -3003,27 +3030,34 @@ void fpa2bv_converter::mk_to_fp_signed(func_decl * f, unsigned num, expr * const unsigned bv_sz = m_bv_util.get_bv_size(x); SASSERT(m_bv_util.get_bv_size(rm) == 3); + expr_ref rm_is_to_neg(m); + mk_is_rm(rm, BV_RM_TO_NEGATIVE, rm_is_to_neg); + expr_ref bv1_1(m), bv0_sz(m); bv1_1 = m_bv_util.mk_numeral(1, 1); bv0_sz = m_bv_util.mk_numeral(0, bv_sz); - expr_ref is_zero(m), pzero(m); + expr_ref is_zero(m), pzero(m), nzero(m); is_zero = m.mk_eq(x, bv0_sz); mk_pzero(f, pzero); + mk_nzero(f, nzero); - // Special case: x == 0 -> p/n zero + // Special case: x == 0 -> +zero expr_ref c1(m), v1(m); c1 = is_zero; - v1 = pzero; + v1 = pzero; // No -zero (IEEE754) // Special case: x != 0 - expr_ref is_neg_bit(m), exp_too_large(m), sig_4(m), exp_2(m); + expr_ref sign_bit(m), exp_too_large(m), sig_4(m), exp_2(m), rest(m); expr_ref is_neg(m), x_abs(m), neg_x(m); - is_neg_bit = m_bv_util.mk_extract(bv_sz - 1, bv_sz - 1, x); - is_neg = m.mk_eq(is_neg_bit, bv1_1); - neg_x = m_bv_util.mk_bv_neg(x); // overflow problem? + sign_bit = m_bv_util.mk_extract(bv_sz - 1, bv_sz - 1, x); + rest = m_bv_util.mk_extract(bv_sz - 2, 0, x); + dbg_decouple("fpa2bv_to_fp_signed_rest", rest); + is_neg = m.mk_eq(sign_bit, bv1_1); + neg_x = m_bv_util.mk_bv_neg(x); // overflow ok, x_abs is now unsigned. x_abs = m.mk_ite(is_neg, neg_x, x); dbg_decouple("fpa2bv_to_fp_signed_is_neg", is_neg); + dbg_decouple("fpa2bv_to_fp_signed_x_abs", x_abs); // x_abs has an extra bit in the front. // x_abs is [bv_sz-1, bv_sz-2] . [bv_sz-3 ... 0] * 2^(bv_sz-2) // bv_sz-2 is the "1.0" bit for the rounder. @@ -3075,7 +3109,7 @@ void fpa2bv_converter::mk_to_fp_signed(func_decl * f, unsigned num, expr * const TRACE("fpa2bv_to_fp_signed", tout << "exp worst case sz: " << exp_worst_case_sz << std::endl;); - if (exp_sz < exp_worst_case_sz) { + if (exp_sz <= exp_worst_case_sz) { // exp_sz < exp_worst_case_sz and exp >= 0. // Take the maximum legal exponent; this // allows us to keep the most precision. @@ -3093,7 +3127,7 @@ void fpa2bv_converter::mk_to_fp_signed(func_decl * f, unsigned num, expr * const dbg_decouple("fpa2bv_to_fp_signed_exp_too_large", exp_too_large); expr_ref sgn(m), sig(m), exp(m); - sgn = is_neg_bit; + sgn = sign_bit; sig = sig_4; exp = exp_2; @@ -3132,6 +3166,9 @@ void fpa2bv_converter::mk_to_fp_unsigned(func_decl * f, unsigned num, expr * con rm = to_app(args[0])->get_arg(0); x = args[1]; + expr_ref rm_is_to_neg(m); + mk_is_rm(rm, BV_RM_TO_NEGATIVE, rm_is_to_neg); + dbg_decouple("fpa2bv_to_fp_unsigned_x", x); unsigned ebits = m_util.get_ebits(f->get_range()); @@ -3143,14 +3180,15 @@ void fpa2bv_converter::mk_to_fp_unsigned(func_decl * f, unsigned num, expr * con bv0_1 = m_bv_util.mk_numeral(0, 1); bv0_sz = m_bv_util.mk_numeral(0, bv_sz); - expr_ref is_zero(m), pzero(m); + expr_ref is_zero(m), pzero(m), nzero(m); is_zero = m.mk_eq(x, bv0_sz); mk_pzero(f, pzero); + mk_nzero(f, nzero); - // Special case: x == 0 -> p/n zero + // Special case: x == 0 -> +zero expr_ref c1(m), v1(m); c1 = is_zero; - v1 = pzero; + v1 = pzero; // No -zero (IEEE754) // Special case: x != 0 expr_ref exp_too_large(m), sig_4(m), exp_2(m); @@ -3194,7 +3232,7 @@ void fpa2bv_converter::mk_to_fp_unsigned(func_decl * f, unsigned num, expr * con unsigned exp_sz = ebits + 2; // (+2 for rounder) exp_2 = m_bv_util.mk_extract(exp_sz - 1, 0, s_exp); // the remaining bits are 0 if ebits is large enough. - exp_too_large = m.mk_false(); // This is always in range. + exp_too_large = m.mk_false(); // The exponent is at most bv_sz, i.e., we need ld(bv_sz)+1 ebits. // exp < bv_sz (+sign bit which is [0]) diff --git a/src/ast/fpa/fpa2bv_converter.h b/src/ast/fpa/fpa2bv_converter.h index 19315129a0f..e237c0dcded 100644 --- a/src/ast/fpa/fpa2bv_converter.h +++ b/src/ast/fpa/fpa2bv_converter.h @@ -208,6 +208,7 @@ class fpa2bv_converter { private: void mk_nan(sort * s, expr_ref & result); + void mk_nzero(sort * s, expr_ref & result); void mk_pzero(sort * s, expr_ref & result); void mk_zero(sort * s, expr_ref & sgn, expr_ref & result); diff --git a/src/ast/fpa_decl_plugin.cpp b/src/ast/fpa_decl_plugin.cpp index 78d300ca79c..76e44278db1 100644 --- a/src/ast/fpa_decl_plugin.cpp +++ b/src/ast/fpa_decl_plugin.cpp @@ -208,8 +208,7 @@ sort * fpa_decl_plugin::mk_float_sort(unsigned ebits, unsigned sbits) { if (ebits > 63) m_manager->raise_exception("maximum number of exponent bits is 63"); - parameter p1(ebits), p2(sbits); - parameter ps[2] = { p1, p2 }; + parameter ps[2] = { parameter(ebits), parameter(sbits) }; sort_size sz; sz = sort_size::mk_very_big(); // TODO: refine return m_manager->mk_sort(symbol("FloatingPoint"), sort_info(m_family_id, FLOATING_POINT_SORT, sz, 2, ps)); diff --git a/src/ast/macros/macro_finder.cpp b/src/ast/macros/macro_finder.cpp index e7452ee9c0b..bc63aae8e10 100644 --- a/src/ast/macros/macro_finder.cpp +++ b/src/ast/macros/macro_finder.cpp @@ -65,7 +65,7 @@ bool macro_finder::is_arith_macro(expr * n, proof * pr, bool deps_valid, expr_de // functions introduced within macros are Skolem functions // To avoid unsound expansion of these as macros (because they // appear in model conversions and are therefore not fully - // replacable) we prevent these from being treated as macro functions. + // replaceable) we prevent these from being treated as macro functions. if (m_macro_manager.contains(f) || f->is_skolem()) return false; diff --git a/src/ast/macros/macro_manager.h b/src/ast/macros/macro_manager.h index a3c1a8d97e9..758e3c1a7ee 100644 --- a/src/ast/macros/macro_manager.h +++ b/src/ast/macros/macro_manager.h @@ -32,7 +32,7 @@ Revision History: where T[X] does not contain f. This class is responsible for storing macros and expanding them. - It has support for backtracking and tagging declarations in an expression as forbidded for being macros. + It has support for backtracking and tagging declarations in an expression as forbidden for being macros. */ class macro_manager { ast_manager & m; diff --git a/src/ast/normal_forms/defined_names.cpp b/src/ast/normal_forms/defined_names.cpp index ad5f8348693..c931c6fad51 100644 --- a/src/ast/normal_forms/defined_names.cpp +++ b/src/ast/normal_forms/defined_names.cpp @@ -207,7 +207,7 @@ void defined_names::impl::mk_definition(expr * e, app * n, sort_ref_buffer & var // the instantiation rules for store(a, i, v) are: // store(a, i, v)[j] = if i = j then v else a[j] with patterns {a[j], store(a, i, v)} { store(a, i, v)[j] } // The first pattern is not included. - // TBD use a model-based scheme for exracting instantiations instead of + // TBD use a model-based scheme for extracting instantiations instead of // using multi-patterns. // diff --git a/src/ast/normal_forms/elim_term_ite.cpp b/src/ast/normal_forms/elim_term_ite.cpp index 3376e9dda0e..077f66d1f31 100644 --- a/src/ast/normal_forms/elim_term_ite.cpp +++ b/src/ast/normal_forms/elim_term_ite.cpp @@ -18,6 +18,7 @@ Revision History: --*/ #include "ast/normal_forms/elim_term_ite.h" #include "ast/ast_smt2_pp.h" +#include "ast/rewriter/rewriter_def.h" br_status elim_term_ite_cfg::reduce_app(func_decl* f, unsigned n, expr * const* args, expr_ref& result, proof_ref& result_pr) { if (!m.is_term_ite(f)) { @@ -38,3 +39,4 @@ br_status elim_term_ite_cfg::reduce_app(func_decl* f, unsigned n, expr * const* return BR_DONE; } +template class rewriter_tpl; diff --git a/src/ast/normal_forms/nnf.cpp b/src/ast/normal_forms/nnf.cpp index b04445d16d8..1f0ce6781ff 100644 --- a/src/ast/normal_forms/nnf.cpp +++ b/src/ast/normal_forms/nnf.cpp @@ -28,6 +28,7 @@ Module Name: #include "ast/rewriter/var_subst.h" #include "ast/normal_forms/name_exprs.h" #include "ast/ast_smt2_pp.h" +#include "ast/ast_pp.h" #include /** @@ -149,7 +150,7 @@ class skolemizer { p = nullptr; if (m_proofs_enabled) { if (q->get_kind() == forall_k) - p = m.mk_skolemization(mk_not(m, q), m.mk_not(r)); + p = m.mk_skolemization(mk_not(m, q), mk_not(m, r)); else p = m.mk_skolemization(q, r); } diff --git a/src/ast/occurs.cpp b/src/ast/occurs.cpp index 2bcd9839655..4e0008373be 100644 --- a/src/ast/occurs.cpp +++ b/src/ast/occurs.cpp @@ -19,6 +19,7 @@ Revision History: #include "ast/occurs.h" #include "ast/for_each_expr.h" +#include "ast/for_each_ast.h" // ----------------------------------- // @@ -49,6 +50,15 @@ namespace { void operator()(quantifier const * n) { } }; + + struct sort_proc { + sort* m_s; + sort_proc(sort* s) :m_s(s) {} + void operator()(sort const* s2) { if (m_s == s2) throw found(); } + void operator()(ast*) {} + }; + + } // Return true if n1 occurs in n2 @@ -74,6 +84,17 @@ bool occurs(func_decl * d, expr * n) { return false; } +bool occurs(sort* s1, sort* s2) { + sort_proc p(s1); + try { + for_each_ast(p, s2, true); + } + catch (const found&) { + return true; + } + return false; +} + void mark_occurs(ptr_vector& to_check, expr* v, expr_mark& occ) { expr_fast_mark2 visited; occ.mark(v, true); @@ -116,4 +137,4 @@ void mark_occurs(ptr_vector& to_check, expr* v, expr_mark& occ) { to_check.pop_back(); } } -} \ No newline at end of file +} diff --git a/src/ast/occurs.h b/src/ast/occurs.h index 7475a292c3a..f2f42aaeefb 100644 --- a/src/ast/occurs.h +++ b/src/ast/occurs.h @@ -31,6 +31,11 @@ bool occurs(expr * n1, expr * n2); */ bool occurs(func_decl * d, expr * n); +/** +* \brief Return true if s1 occurs in s2 +*/ +bool occurs(sort* s1, sort* s2); + /** * \brief Mark sub-expressions of to_check by whether v occurs in these. */ diff --git a/src/ast/pattern/CMakeLists.txt b/src/ast/pattern/CMakeLists.txt index 7393b711065..88dfd198bed 100644 --- a/src/ast/pattern/CMakeLists.txt +++ b/src/ast/pattern/CMakeLists.txt @@ -7,7 +7,7 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/database.h") endif() add_custom_command(OUTPUT "database.h" - COMMAND "${PYTHON_EXECUTABLE}" + COMMAND "${Python3_EXECUTABLE}" "${PROJECT_SOURCE_DIR}/scripts/mk_pat_db.py" "${CMAKE_CURRENT_SOURCE_DIR}/database.smt2" "${CMAKE_CURRENT_BINARY_DIR}/database.h" diff --git a/src/ast/pattern/pattern_inference.cpp b/src/ast/pattern/pattern_inference.cpp index 2fd2b4c82b4..7c7576c84e3 100644 --- a/src/ast/pattern/pattern_inference.cpp +++ b/src/ast/pattern/pattern_inference.cpp @@ -109,6 +109,7 @@ pattern_inference_cfg::pattern_inference_cfg(ast_manager & m, pattern_inference_ m_le(), m_nested_arith_only(true), m_block_loop_patterns(params.m_pi_block_loop_patterns), + m_decompose_patterns(params.m_pi_decompose_patterns), m_candidates(m), m_pattern_weight_lt(m_candidates_info), m_collect(m, *this), @@ -407,6 +408,9 @@ bool pattern_inference_cfg::pattern_weight_lt::operator()(expr * n1, expr * n2) app* pattern_inference_cfg::mk_pattern(app* candidate) { + if (!m_decompose_patterns) + return m.mk_pattern(candidate); + auto has_var_arg = [&](expr* e) { if (!is_app(e)) return false; @@ -620,9 +624,11 @@ bool pattern_inference_cfg::reduce_quantifier( proof_ref & result_pr) { TRACE("pattern_inference", tout << "processing:\n" << mk_pp(q, m) << "\n";); - if (!is_forall(q)) { + if (!m_params.m_pi_enabled) + return false; + + if (!is_forall(q)) return false; - } int weight = q->get_weight(); @@ -649,9 +655,8 @@ bool pattern_inference_cfg::reduce_quantifier( } } - if (q->get_num_patterns() > 0) { + if (q->get_num_patterns() > 0) return false; - } if (m_params.m_pi_nopat_weight >= 0) weight = m_params.m_pi_nopat_weight; diff --git a/src/ast/pattern/pattern_inference.h b/src/ast/pattern/pattern_inference.h index da905dca412..8d179ba33e5 100644 --- a/src/ast/pattern/pattern_inference.h +++ b/src/ast/pattern/pattern_inference.h @@ -20,6 +20,7 @@ Revision History: #include "ast/ast.h" #include "ast/rewriter/rewriter.h" +#include "ast/rewriter/rewriter_def.h" #include "params/pattern_inference_params.h" #include "util/vector.h" #include "util/uint_set.h" @@ -69,6 +70,7 @@ class pattern_inference_cfg : public default_rewriter_cfg { expr * const * m_no_patterns; bool m_nested_arith_only; bool m_block_loop_patterns; + bool m_decompose_patterns; struct info { uint_set m_free_vars; diff --git a/src/ast/polymorphism_inst.cpp b/src/ast/polymorphism_inst.cpp new file mode 100644 index 00000000000..aa9b1e5feb7 --- /dev/null +++ b/src/ast/polymorphism_inst.cpp @@ -0,0 +1,142 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polymorphism_inst.cpp + +Abstract: + + Utilities for instantiating polymorphic assertions. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-7-8 + + +--*/ +#include "ast/polymorphism_inst.h" +#include "ast/ast_pp.h" + +namespace polymorphism { + + void inst::add(expr* e) { + if (!m.has_type_vars()) + return; + + if (m_from_instantiation.contains(e)) + return; + + instances inst; + u.collect_poly_instances(e, inst.m_poly_fns); + if (inst.m_poly_fns.empty()) + return; + if (m_instances.contains(e)) + return; + + add_instantiations(e, inst.m_poly_fns); + + if (!u.has_type_vars(e)) + return; + + // insert e into the occurs list for polymorphic roots + ast_mark seen; + for (auto* f : inst.m_poly_fns) { + f = m.poly_root(f); + if (seen.is_marked(f)) + continue; + seen.mark(f, true); + if (!m_occurs.contains(f)) { + m_occurs.insert(f, ptr_vector()); + t.push(insert_map(m_occurs, f)); + } + auto& es = m_occurs.find(f); + es.push_back(e); + t.push(remove_back(m_occurs, f)); + } + m_assertions.push_back(e); + t.push(push_back_vector(m_assertions)); + u.collect_type_vars(e, inst.m_tvs); + inst.m_subst = alloc(substitutions); + inst.m_subst->insert(alloc(substitution, m)); + m_instances.insert(e, inst); + t.push(new_obj_trail(inst.m_subst)); + t.push(insert_map(m_instances, e)); + } + + void inst::collect_instantiations(expr* e) { + ptr_vector instances; + u.collect_poly_instances(e, instances); + add_instantiations(e, instances); + } + + void inst::add_instantiations(expr* e, ptr_vector const& instances) { + for (auto* f : instances) { + if (m_in_decl_queue.is_marked(f)) + continue; + m_in_decl_queue.mark(f, true); + m_decl_queue.push_back(f); + t.push(add_decl_queue(*this)); + } + } + + void inst::instantiate(vector& instances) { + unsigned num_decls = m_decl_queue.size(); + if (m_assertions_qhead < m_assertions.size()) { + t.push(value_trail(m_assertions_qhead)); + for (; m_assertions_qhead < m_assertions.size(); ++m_assertions_qhead) { + expr* e = m_assertions.get(m_assertions_qhead); + for (unsigned i = 0; i < num_decls; ++i) + instantiate(m_decl_queue.get(i), e, instances); + } + } + if (m_decl_qhead < num_decls) { + t.push(value_trail(m_decl_qhead)); + for (; m_decl_qhead < num_decls; ++m_decl_qhead) { + func_decl* p = m_decl_queue.get(m_decl_qhead); + func_decl* r = m.poly_root(p); + if (!m_occurs.contains(r)) + continue; + for (expr* e : m_occurs[r]) + instantiate(p, e, instances); + } + } + } + + void inst::instantiate(func_decl* f1, expr* e, vector& instances) { + auto const& [tv, fns, substs] = m_instances[e]; + + for (auto* f2 : fns) { + substitution sub1(m), new_sub(m); + if (!u.unify(f1, f2, sub1)) + continue; + if (substs->contains(&sub1)) + continue; + substitutions new_substs; + for (auto* sub2 : *substs) { + if (!u.unify(sub1, *sub2, new_sub)) + continue; + if (substs->contains(&new_sub)) + continue; + if (new_substs.contains(&new_sub)) + continue; + expr_ref e_inst = new_sub(e); + if (!m_from_instantiation.contains(e_inst)) { + collect_instantiations(e_inst); + auto* new_sub1 = alloc(substitution, new_sub); + instances.push_back(instantiation(e, e_inst, new_sub1)); + new_substs.insert(new_sub1); + m_from_instantiation.insert(e_inst); + m.inc_ref(e_inst); + t.push(insert_ref_map(m, m_from_instantiation, e_inst)); + } + } + for (auto* sub2 : new_substs) { + SASSERT(!substs->contains(sub2)); + substs->insert(sub2); + t.push(new_obj_trail(sub2)); + t.push(insert_map(*substs, sub2)); + } + } + } +} diff --git a/src/ast/polymorphism_inst.h b/src/ast/polymorphism_inst.h new file mode 100644 index 00000000000..1d171b31438 --- /dev/null +++ b/src/ast/polymorphism_inst.h @@ -0,0 +1,91 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polymorphism_inst.h + +Abstract: + + Utilities for instantiating polymorphic assertions. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-7-8 + + +--*/ +#pragma once + +#include "util/trail.h" +#include "ast/ast.h" +#include "ast/polymorphism_util.h" + +namespace polymorphism { + + struct instantiation { + expr* orig; + expr_ref inst; + substitution* sub; + instantiation(expr* orig, expr_ref& inst, substitution* s): + orig(orig), inst(inst), sub(s) {} + }; + + class inst { + ast_manager& m; + trail_stack& t; + util u; + + struct instances { + ptr_vector m_tvs; + ptr_vector m_poly_fns; + substitutions* m_subst = nullptr; + }; + + func_decl_ref_vector m_poly_roots; + obj_map> m_occurs; + obj_map m_instances; + func_decl_ref_vector m_decl_queue; + unsigned m_decl_qhead = 0; + ast_mark m_in_decl_queue; + expr_ref_vector m_assertions; + unsigned m_assertions_qhead = 0; + obj_hashtable m_from_instantiation; + + struct add_decl_queue : public trail { + inst& i; + add_decl_queue(inst& i): i(i) {} + void undo() override { + i.m_in_decl_queue.mark(i.m_decl_queue.back(), false); + i.m_decl_queue.pop_back(); + }; + }; + + struct remove_back : public trail { + obj_map>& occ; + func_decl* f; + remove_back(obj_map>& occ, func_decl* f): + occ(occ), f(f) {} + void undo() override { + occ.find(f).pop_back(); + } + }; + + void instantiate(func_decl* p, expr* e, vector& instances); + + void collect_instantiations(expr* e); + + void add_instantiations(expr* e, ptr_vector const& insts); + + public: + inst(ast_manager& m, trail_stack& t): + m(m), t(t), u(m), m_poly_roots(m), m_decl_queue(m), m_assertions(m) {} + + void add(expr* e); + + void instantiate(vector& instances); + + bool pending() const { return m_decl_qhead < m_decl_queue.size() || m_assertions_qhead < m_assertions.size(); } + + }; +} diff --git a/src/ast/polymorphism_util.cpp b/src/ast/polymorphism_util.cpp new file mode 100644 index 00000000000..3431dd03430 --- /dev/null +++ b/src/ast/polymorphism_util.cpp @@ -0,0 +1,353 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polymorphism_util.cpp + +Abstract: + + Utilities for supporting polymorphic type signatures. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-7-8 + +--*/ + +#include "ast/polymorphism_util.h" +#include "ast/for_each_ast.h" +#include "ast/occurs.h" +#include "ast/ast_pp.h" + +namespace polymorphism { + + sort_ref_vector substitution::operator()(sort_ref_vector const& s) { + sort_ref_vector r(m); + for (auto* srt : s) + r.push_back((*this)(srt)); + return r; + } + + sort_ref substitution::operator()(sort* s) { + if (!m.has_type_var(s)) + return sort_ref(s, m); + if (s->is_type_var()) { + if (m_sub.find(s, s)) + return (*this)(s); + return sort_ref(s, m); + } + unsigned n = s->get_num_parameters(); + vector ps; + for (unsigned i = 0; i < n; ++i) { + auto &p = s->get_parameter(i); + if (p.is_ast() && is_sort(p.get_ast())) { + sort_ref s = (*this)(to_sort(p.get_ast())); + ps.push_back(parameter(s.get())); + } + else + ps.push_back(p); + } + sort_info si(s->get_family_id(), s->get_decl_kind(), n, ps.data(), s->private_parameters()); + return sort_ref(m.mk_sort(s->get_name(), si), m); + } + + expr_ref substitution::operator()(expr* e) { + ptr_vector todo; + expr_ref_vector result(m); + todo.push_back(e); + auto in_cache = [&](expr* a) { + return result.size() > a->get_id() && result.get(a->get_id()); + }; + ptr_buffer args; + sort_ref_buffer domain(m); + while (!todo.empty()) { + expr* a = todo.back(); + if (in_cache(a)) { + todo.pop_back(); + continue; + } + if (is_var(a)) { + if (m.has_type_var(a->get_sort())) + result.setx(a->get_id(), m.mk_var(to_var(a)->get_idx(), (*this)(a->get_sort()))); + else + result.setx(a->get_id(), a); + todo.pop_back(); + } + else if (is_quantifier(a)) { + quantifier* q = to_quantifier(a); + bool pending = false; + if (!in_cache(q->get_expr())) { + todo.push_back(q->get_expr()); + pending = true; + } + ptr_buffer patterns, no_patterns; + unsigned np = q->get_num_patterns(); + for (unsigned i = 0; i < np; ++i) { + if (!in_cache(q->get_pattern(i))) { + todo.push_back(q->get_pattern(i)); + pending = true; + } + else + patterns.push_back(result.get(q->get_pattern(i)->get_id())); + } + np = q->get_num_no_patterns(); + for (unsigned i = 0; i < np; ++i) { + if (!in_cache(q->get_no_pattern(i))) { + todo.push_back(q->get_no_pattern(i)); + pending = true; + } + else + no_patterns.push_back(result.get(q->get_no_pattern(i)->get_id())); + } + if (pending) + continue; + todo.pop_back(); + domain.reset(); + for (unsigned i = 0; i < q->get_num_decls(); ++i) + domain.push_back((*this)(q->get_decl_sort(i))); + quantifier* q2 = + m.mk_quantifier(q->get_kind(), q->get_num_decls(), domain.data(), q->get_decl_names(), result.get(q->get_expr()->get_id()), + q->get_weight(), + q->get_qid(), q->get_skid(), + q->get_num_patterns(), patterns.data(), q->get_num_no_patterns(), no_patterns.data() + ); + result.setx(q->get_id(), q2); + } + else if (is_app(a)) { + args.reset(); + unsigned n = todo.size(); + for (expr* arg : *to_app(a)) { + if (!in_cache(arg)) + todo.push_back(arg); + else + args.push_back(result.get(arg->get_id())); + } + if (n < todo.size()) + continue; + func_decl* f = to_app(a)->get_decl(); + if (f->is_polymorphic()) { + domain.reset(); + for (unsigned i = 0; i < f->get_arity(); ++i) + domain.push_back((*this)(f->get_domain(i))); + sort_ref range = (*this)(f->get_range()); + f = m.instantiate_polymorphic(f, f->get_arity(), domain.data(), range); + } + result.setx(a->get_id(), m.mk_app(f, args)); + todo.pop_back(); + } + } + return expr_ref(result.get(e->get_id()), m); + } + + bool substitution::unify(sort* s1, sort* s2) { + if (s1 == s2) + return true; + if (s1->is_type_var() && m_sub.find(s1, s1)) + return unify(s1, s2); + if (s2->is_type_var() && m_sub.find(s2, s2)) + return unify(s1, s2); + if (s2->is_type_var() && !s1->is_type_var()) + std::swap(s1, s2); + if (s1->is_type_var()) { + auto s22 = (*this)(s2); + if (occurs(s1, s22)) + return false; + m_trail.push_back(s22); + m_trail.push_back(s1); + m_sub.insert(s1, s22); + return true; + } + if (s1->get_family_id() != s2->get_family_id()) + return false; + if (s1->get_decl_kind() != s2->get_decl_kind()) + return false; + if (s1->get_name() != s2->get_name()) + return false; + if (s1->get_num_parameters() != s2->get_num_parameters()) + return false; + for (unsigned i = s1->get_num_parameters(); i-- > 0;) { + auto &p1 = s1->get_parameter(i); + auto &p2 = s2->get_parameter(i); + if (p1.is_ast() && is_sort(p1.get_ast())) { + if (!p2.is_ast()) + return false; + if (!is_sort(p2.get_ast())) + return false; + if (!unify(to_sort(p1.get_ast()), to_sort(p2.get_ast()))) + return false; + continue; + } + if (p1 != p2) + return false; + } + return true; + } + + bool substitution::match(sort* s1, sort* s2) { + if (s1 == s2) + return true; + if (s1->is_type_var() && m_sub.find(s1, s1)) + return match(s1, s2); + if (s1->is_type_var()) { + m_trail.push_back(s2); + m_trail.push_back(s1); + m_sub.insert(s1, s2); + return true; + } + if (s1->get_family_id() != s2->get_family_id()) + return false; + if (s1->get_decl_kind() != s2->get_decl_kind()) + return false; + if (s1->get_name() != s2->get_name()) + return false; + if (s1->get_num_parameters() != s2->get_num_parameters()) + return false; + for (unsigned i = s1->get_num_parameters(); i-- > 0;) { + auto &p1 = s1->get_parameter(i); + auto &p2 = s2->get_parameter(i); + if (p1.is_ast() && is_sort(p1.get_ast())) { + if (!p2.is_ast()) + return false; + if (!is_sort(p2.get_ast())) + return false; + if (!match(to_sort(p1.get_ast()), to_sort(p2.get_ast()))) + return false; + continue; + } + if (p1 != p2) + return false; + } + return true; + } + + // util + bool util::unify(sort* s1, sort* s2, substitution& sub) { + return sub.unify(s1, s2); + } + + bool util::unify(func_decl* f1, func_decl* f2, substitution& sub) { + if (f1 == f2) + return true; + if (!f1->is_polymorphic() || !f2->is_polymorphic()) + return false; + if (m.poly_root(f1) != m.poly_root(f2)) + return false; + for (unsigned i = f1->get_arity(); i-- > 0; ) + if (!sub.unify(fresh(f1->get_domain(i)), f2->get_domain(i))) + return false; + return sub.unify(fresh(f1->get_range()), f2->get_range()); + } + + bool util::unify(substitution const& s1, substitution const& s2, + substitution& sub) { + sort* v2; + for (auto const& [k, v] : s1) + sub.insert(k, v); + for (auto const& [k, v] : s2) { + if (sub.find(k, v2)) { + if (!sub.unify(sub(v), v2)) + return false; + } + else + sub.insert(k, sub(v)); + } + return true; + } + + bool util::match(substitution& sub, sort* s1, sort* s_ground) { + return sub.match(s1, s_ground); + } + + /** + * Create fresh variables, but with caching. + * So "fresh" variables are not truly fresh globally. + * This can block some unifications and therefore block some instantiations of + * polymorphic assertions. A different caching scheme could be created to + * ensure that fresh variables are introduced at the right time, or use other + * tricks such as creating variable/offset pairs to distinguish name spaces without + * incurring costs. + */ + sort_ref util::fresh(sort* s) { + sort* s1; + if (m_fresh.find(s, s1)) + return sort_ref(s1, m); + + if (m.is_type_var(s)) { + s1 = m.mk_type_var(symbol("fresh!" + std::to_string(m_counter))); + m_trail.push_back(s1); + m_trail.push_back(s); + m_fresh.insert(s, s1); + return sort_ref(s1, m); + } + vector params; + for (unsigned i = 0; i < s->get_num_parameters(); ++i) { + const parameter &p = s->get_parameter(i); + if (p.is_ast() && is_sort(p.get_ast())) { + sort_ref fs = fresh(to_sort(p.get_ast())); + params.push_back(parameter(fs.get())); + } + else + params.push_back(p); + } + sort_info info(s->get_family_id(), s->get_decl_kind(), params.size(), params.data(), s->private_parameters()); + s1 = m.mk_sort(s->get_name(), info); + m_trail.push_back(s1); + m_trail.push_back(s); + m_fresh.insert(s, s1); + return sort_ref(s1, m); + } + + sort_ref_vector util::fresh(unsigned n, sort* const* s) { + sort_ref_vector r(m); + for (unsigned i = 0; i < n; ++i) + r.push_back(fresh(s[i])); + return r; + } + + void util::collect_poly_instances(expr* e, ptr_vector& instances) { + struct proc { + ast_manager& m; + ptr_vector& instances; + proc(ast_manager& m, ptr_vector& instances) : m(m), instances(instances) {} + void operator()(func_decl* f) { + if (f->is_polymorphic() && !m.is_eq(f) && !is_decl_of(f, pattern_family_id, OP_PATTERN)) + instances.push_back(f); + } + void operator()(ast* a) {} + }; + proc proc(m, instances); + for_each_ast(proc, e, false); + } + + bool util::has_type_vars(expr* e) { + struct proc { + ast_manager& m; + bool found = false; + proc(ast_manager& m) : m(m) {} + void operator()(sort* f) { + if (m.has_type_var(f)) + found = true; + } + void operator()(ast* a) {} + }; + proc proc(m); + for_each_ast(proc, e, false); + return proc.found; + } + + void util::collect_type_vars(expr* e, ptr_vector& tvs) { + struct proc { + ast_manager& m; + ptr_vector& tvs; + proc(ast_manager& m, ptr_vector& tvs) : m(m), tvs(tvs) {} + void operator()(sort* s) { + if (m.is_type_var(s)) + tvs.push_back(s); + } + void operator()(ast* a) {} + }; + proc proc(m, tvs); + for_each_ast(proc, e, true); + } +} diff --git a/src/ast/polymorphism_util.h b/src/ast/polymorphism_util.h new file mode 100644 index 00000000000..3023d033874 --- /dev/null +++ b/src/ast/polymorphism_util.h @@ -0,0 +1,112 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polymorphism_util.h + +Abstract: + + Utilities for supporting polymorphic type signatures. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-7-8 + +--*/ +#pragma once + +#include "ast/ast.h" +#include "util/hashtable.h" + +namespace polymorphism { + + class substitution { + ast_manager& m; + obj_map m_sub; + sort_ref_vector m_trail; + public: + substitution(ast_manager& m): m(m), m_trail(m) {} + + sort_ref_vector operator()(sort_ref_vector const& s); + + sort_ref operator()(sort* s); + + expr_ref operator()(expr* e); + + bool unify(sort* s1, sort* s2); + + bool match(sort* s1, sort* s_ground); + + obj_map::iterator begin() const { return m_sub.begin(); } + obj_map::iterator end() const { return m_sub.end(); } + + void insert(sort* v, sort* t) { m_trail.push_back(v).push_back(t); m_sub.insert(v, t); } + + bool find(sort* v, sort*& t) const { return m_sub.find(v, t); } + + unsigned size() const { return m_sub.size(); } + + /** + * weak equality: strong equality considers applying substitutions recursively in range + * because substitutions may be in triangular form. + */ + struct eq { + bool operator()(substitution const* s1, substitution const* s2) const { + if (s1->size() != s2->size()) + return false; + sort* v2; + for (auto const& [k, v] : *s1) { + if (!s2->find(k, v2)) + return false; + if (v != v2) + return false; + } + return true; + } + }; + + struct hash { + unsigned operator()(substitution const* s) const { + unsigned hash = 0xfabc1234 + s->size(); + for (auto const& [k, v] : *s) + hash ^= k->hash() + 2 * v->hash(); + return hash; + } + }; + }; + + typedef hashtable substitutions; + + class util { + ast_manager& m; + sort_ref_vector m_trail; + obj_map m_fresh; + unsigned m_counter = 0; + + sort_ref fresh(sort* s); + + sort_ref_vector fresh(unsigned n, sort* const* s); + + public: + util(ast_manager& m): m(m), m_trail(m) {} + + bool unify(sort* s1, sort* s2, substitution& sub); + + bool unify(func_decl* f1, func_decl* f2, substitution& sub); + + bool unify(substitution const& s1, substitution const& s2, + substitution& sub); + + bool match(substitution& sub, sort* s1, sort* s_ground); + + // collect instantiations of polymorphic functions + void collect_poly_instances(expr* e, ptr_vector& instances); + + // test if expression contains polymorphic variable. + bool has_type_vars(expr* e); + + void collect_type_vars(expr* e, ptr_vector& tvs); + + }; +} diff --git a/src/ast/proofs/proof_utils.cpp b/src/ast/proofs/proof_utils.cpp index bcc1aed27a8..a3508969bf8 100644 --- a/src/ast/proofs/proof_utils.cpp +++ b/src/ast/proofs/proof_utils.cpp @@ -260,7 +260,7 @@ class reduce_hypotheses { { cls.push_back(cls_fact->get_arg(i)); } } else { cls.push_back(cls_fact); } - // construct new resovent + // construct new resolvent ptr_buffer new_fact_cls; bool found; // XXX quadratic @@ -604,7 +604,7 @@ class reduce_hypotheses0 { // -- otherwise, the fact has not changed. nothing to simplify SASSERT(m.get_fact(tmp) == m.get_fact(m.get_parent(p, i))); parents.push_back(tmp); - // remember that we have this derivation while we have not poped the trail + // remember that we have this derivation while we have not popped the trail // but only if the proof is closed (i.e., a real unit) if (is_closed(tmp) && !m_units.contains(m.get_fact(tmp))) { m_units.insert(m.get_fact(tmp), tmp); diff --git a/src/ast/recfun_decl_plugin.cpp b/src/ast/recfun_decl_plugin.cpp index 7a3e9521da6..495f3cb2086 100644 --- a/src/ast/recfun_decl_plugin.cpp +++ b/src/ast/recfun_decl_plugin.cpp @@ -602,15 +602,6 @@ namespace recfun { m_args.append(n->get_num_args(), n->get_args()); } - case_expansion::case_expansion(case_expansion const & from) - : m_lhs(from.m_lhs), - m_def(from.m_def), - m_args(from.m_args) {} - case_expansion::case_expansion(case_expansion && from) - : m_lhs(from.m_lhs), - m_def(from.m_def), - m_args(std::move(from.m_args)) {} - std::ostream& case_expansion::display(std::ostream & out) const { return out << "case_exp(" << m_lhs << ")"; } diff --git a/src/ast/recfun_decl_plugin.h b/src/ast/recfun_decl_plugin.h index 442bdadbd0b..3a1dc8e84c1 100644 --- a/src/ast/recfun_decl_plugin.h +++ b/src/ast/recfun_decl_plugin.h @@ -301,8 +301,6 @@ namespace recfun { recfun::def * m_def; expr_ref_vector m_args; case_expansion(recfun::util& u, app * n); - case_expansion(case_expansion const & from); - case_expansion(case_expansion && from); std::ostream& display(std::ostream& out) const; }; @@ -323,10 +321,6 @@ namespace recfun { } body_expansion(app_ref & pred, recfun::case_def const & d, expr_ref_vector & args) : m_pred(pred), m_cdef(&d), m_args(args) {} - body_expansion(body_expansion const & from): - m_pred(from.m_pred), m_cdef(from.m_cdef), m_args(from.m_args) {} - body_expansion(body_expansion && from) noexcept : - m_pred(from.m_pred), m_cdef(from.m_cdef), m_args(std::move(from.m_args)) {} std::ostream& display(std::ostream& out) const; }; diff --git a/src/ast/rewriter/CMakeLists.txt b/src/ast/rewriter/CMakeLists.txt index df803b0f186..7f351ecb652 100644 --- a/src/ast/rewriter/CMakeLists.txt +++ b/src/ast/rewriter/CMakeLists.txt @@ -23,7 +23,6 @@ z3_add_component(rewriter factor_rewriter.cpp fpa_rewriter.cpp func_decl_replace.cpp - hoist_rewriter.cpp inj_axiom.cpp label_rewriter.cpp macro_replacer.cpp diff --git a/src/ast/rewriter/arith_rewriter.cpp b/src/ast/rewriter/arith_rewriter.cpp index ed36562caf0..f21a5c4be68 100644 --- a/src/ast/rewriter/arith_rewriter.cpp +++ b/src/ast/rewriter/arith_rewriter.cpp @@ -91,6 +91,10 @@ br_status arith_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * c case OP_SINH: SASSERT(num_args == 1); st = mk_sinh_core(args[0], result); break; case OP_COSH: SASSERT(num_args == 1); st = mk_cosh_core(args[0], result); break; case OP_TANH: SASSERT(num_args == 1); st = mk_tanh_core(args[0], result); break; + case OP_ARITH_BAND: SASSERT(num_args == 2); st = mk_band_core(f->get_parameter(0).get_int(), args[0], args[1], result); break; + case OP_ARITH_SHL: SASSERT(num_args == 2); st = mk_shl_core(f->get_parameter(0).get_int(), args[0], args[1], result); break; + case OP_ARITH_ASHR: SASSERT(num_args == 2); st = mk_ashr_core(f->get_parameter(0).get_int(), args[0], args[1], result); break; + case OP_ARITH_LSHR: SASSERT(num_args == 2); st = mk_lshr_core(f->get_parameter(0).get_int(), args[0], args[1], result); break; default: st = BR_FAILED; break; } CTRACE("arith_rewriter", st != BR_FAILED, tout << st << ": " << mk_pp(f, m); @@ -551,25 +555,10 @@ br_status arith_rewriter::mk_le_ge_eq_core(expr * arg1, expr * arg2, op_kind kin } if (m_anum_simp) { - if (is_numeral(arg1, a1) && m_util.is_irrational_algebraic_numeral(arg2)) { - anum_manager & am = m_util.am(); - scoped_anum v1(am); - am.set(v1, a1.to_mpq()); - anum const & v2 = m_util.to_irrational_algebraic_numeral(arg2); - ANUM_LE_GE_EQ(); - } - if (m_util.is_irrational_algebraic_numeral(arg1) && is_numeral(arg2, a2)) { - anum_manager & am = m_util.am(); - anum const & v1 = m_util.to_irrational_algebraic_numeral(arg1); - scoped_anum v2(am); - am.set(v2, a2.to_mpq()); - ANUM_LE_GE_EQ(); - } - if (m_util.is_irrational_algebraic_numeral(arg1) && m_util.is_irrational_algebraic_numeral(arg2)) { - anum_manager & am = m_util.am(); - anum const & v1 = m_util.to_irrational_algebraic_numeral(arg1); - anum const & v2 = m_util.to_irrational_algebraic_numeral(arg2); - ANUM_LE_GE_EQ(); + auto& am = m_util.am(); + scoped_anum v1(am), v2(am); + if (is_algebraic_numeral(arg1, v1) && is_algebraic_numeral(arg2, v2)) { + ANUM_LE_GE_EQ(); } } br_status st1 = is_separated(arg1, arg2, kind, result); @@ -669,6 +658,7 @@ br_status arith_rewriter::mk_le_ge_eq_core(expr * arg1, expr * arg2, op_kind kin return BR_FAILED; } + br_status arith_rewriter::mk_le_core(expr * arg1, expr * arg2, expr_ref & result) { return mk_le_ge_eq_core(arg1, arg2, LE, result); } @@ -744,18 +734,26 @@ bool arith_rewriter::mk_eq_mod(expr* arg1, expr* arg2, expr_ref& result) { if (g == 1) { expr_ref nb(m_util.mk_numeral(b, true), m); result = m.mk_eq(m_util.mk_mod(u, y), - m_util.mk_mod(m_util.mk_mul(nb, arg2), y)); + m_util.mk_mod(m_util.mk_mul(nb, arg2), y)); return true; } } return false; } -expr_ref arith_rewriter::neg_monomial(expr* e) const { +expr_ref arith_rewriter::neg_monomial(expr* e) { expr_ref_vector args(m); rational a1; if (m_util.is_numeral(e, a1)) args.push_back(m_util.mk_numeral(-a1, e->get_sort())); + else if (m_util.is_irrational_algebraic_numeral(e)) { + auto& n = m_util.to_irrational_algebraic_numeral(e); + auto& am = m_util.am(); + scoped_anum new_n(am); + am.set(new_n, n); + am.neg(new_n); + args.push_back(m_util.mk_numeral(am, new_n, m_util.is_int(e))); + } else if (is_app(e) && m_util.is_mul(e)) { if (is_numeral(to_app(e)->get_arg(0), a1)) { if (!a1.is_minus_one()) { @@ -780,7 +778,7 @@ expr_ref arith_rewriter::neg_monomial(expr* e) const { } } -bool arith_rewriter::is_neg_poly(expr* t, expr_ref& neg) const { +bool arith_rewriter::is_neg_poly(expr* t, expr_ref& neg) { rational r; if (m_util.is_mul(t) && is_numeral(to_app(t)->get_arg(0), r) && r.is_neg()) { neg = neg_monomial(t); @@ -824,6 +822,36 @@ bool arith_rewriter::is_anum_simp_target(unsigned num_args, expr * const * args) return false; } +bool arith_rewriter::is_algebraic_numeral(expr* n, scoped_anum& a) { + auto& am = m_util.am(); + expr* x, *y; + rational r; + if (m_util.is_mul(n, x, y)) { + scoped_anum ax(am), ay(am); + if (is_algebraic_numeral(x, ax) && is_algebraic_numeral(y, ay)) { + am.mul(ax, ay, a); + return true; + } + } + else if (m_util.is_add(n, x, y)) { + scoped_anum ax(am), ay(am); + if (is_algebraic_numeral(x, ax) && is_algebraic_numeral(y, ay)) { + am.add(ax, ay, a); + return true; + } + } + else if (m_util.is_numeral(n, r)) { + am.set(a, r.to_mpq()); + return true; + } + else if (m_util.is_irrational_algebraic_numeral(n)) { + am.set(a, m_util.to_irrational_algebraic_numeral(n)); + return true; + } + return false; +} + + br_status arith_rewriter::mk_add_core(unsigned num_args, expr * const * args, expr_ref & result) { if (is_anum_simp_target(num_args, args)) { expr_ref_buffer new_args(m); @@ -1091,7 +1119,7 @@ br_status arith_rewriter::mk_idiv_core(expr * arg1, expr * arg2, expr_ref & resu return BR_REWRITE3; } } - if (divides(arg1, arg2, result)) { + if (get_divides(arg1, arg2, result)) { expr_ref zero(m_util.mk_int(0), m); result = m.mk_ite(m.mk_eq(zero, arg2), m_util.mk_idiv(arg1, zero), result); return BR_REWRITE_FULL; @@ -1109,7 +1137,7 @@ br_status arith_rewriter::mk_idiv_core(expr * arg1, expr * arg2, expr_ref & resu // // implement div ab ac = floor( ab / ac) = floor (b / c) = div b c // -bool arith_rewriter::divides(expr* num, expr* den, expr_ref& result) { +bool arith_rewriter::get_divides(expr* num, expr* den, expr_ref& result) { expr_fast_mark1 mark; rational num_r(1), den_r(1); expr* num_e = nullptr, *den_e = nullptr; @@ -1121,7 +1149,7 @@ bool arith_rewriter::divides(expr* num, expr* den, expr_ref& result) { if (m_util.is_numeral(arg, num_r)) num_e = arg; } for (expr* arg : args2) { - // dont remove divisor on (div (* -1 x) (* -1 y)) because rewriting would diverge. + // don't remove divisor on (div (* -1 x) (* -1 y)) because rewriting would diverge. if (mark.is_marked(arg) && (!m_util.is_numeral(arg, num_r) || !num_r.is_minus_one())) { result = remove_divisor(arg, num, den); return true; @@ -1201,22 +1229,25 @@ static rational symmod(rational const& a, rational const& b) { if (2*r > b) r -= b; return r; } - + br_status arith_rewriter::mk_mod_core(expr * arg1, expr * arg2, expr_ref & result) { set_curr_sort(arg1->get_sort()); numeral v1, v2; bool is_int; - if (m_util.is_numeral(arg1, v1, is_int) && m_util.is_numeral(arg2, v2, is_int) && !v2.is_zero()) { + bool is_num1 = m_util.is_numeral(arg1, v1, is_int); + bool is_num2 = m_util.is_numeral(arg2, v2, is_int); + + if (is_num1 && is_num2 && !v2.is_zero()) { result = m_util.mk_numeral(mod(v1, v2), is_int); return BR_DONE; } - if (m_util.is_numeral(arg2, v2, is_int) && is_int && (v2.is_one() || v2.is_minus_one())) { + if (is_num2 && is_int && (v2.is_one() || v2.is_minus_one())) { result = m_util.mk_numeral(numeral(0), true); return BR_DONE; } - if (arg1 == arg2 && !m_util.is_numeral(arg2)) { + if (arg1 == arg2 && !is_num2) { expr_ref zero(m_util.mk_int(0), m); result = m.mk_ite(m.mk_eq(arg2, zero), m_util.mk_mod(zero, zero), zero); return BR_DONE; @@ -1224,13 +1255,13 @@ br_status arith_rewriter::mk_mod_core(expr * arg1, expr * arg2, expr_ref & resul // mod is idempotent on non-zero modulus. expr* t1, *t2; - if (m_util.is_mod(arg1, t1, t2) && t2 == arg2 && m_util.is_numeral(arg2, v2, is_int) && is_int && !v2.is_zero()) { + if (m_util.is_mod(arg1, t1, t2) && t2 == arg2 && is_num2 && is_int && !v2.is_zero()) { result = arg1; return BR_DONE; } // propagate mod inside only if there is something to reduce. - if (m_util.is_numeral(arg2, v2, is_int) && is_int && v2.is_pos() && (is_add(arg1) || is_mul(arg1))) { + if (is_num2 && is_int && v2.is_pos() && (is_add(arg1) || is_mul(arg1))) { TRACE("mod_bug", tout << "mk_mod:\n" << mk_ismt2_pp(arg1, m) << "\n" << mk_ismt2_pp(arg2, m) << "\n";); expr_ref_buffer args(m); bool change = false; @@ -1252,17 +1283,43 @@ br_status arith_rewriter::mk_mod_core(expr * arg1, expr * arg2, expr_ref & resul args.push_back(arg); } } - if (!change) { - return BR_FAILED; // did not find any target for applying simplification + if (change) { + result = m_util.mk_mod(m.mk_app(to_app(arg1)->get_decl(), args.size(), args.data()), arg2); + TRACE("mod_bug", tout << "mk_mod result: " << mk_ismt2_pp(result, m) << "\n";); + return BR_REWRITE3; } - result = m_util.mk_mod(m.mk_app(to_app(arg1)->get_decl(), args.size(), args.data()), arg2); - TRACE("mod_bug", tout << "mk_mod result: " << mk_ismt2_pp(result, m) << "\n";); - return BR_REWRITE3; + } + + expr* x, *y; + if (is_num2 && v2.is_pos() && m_util.is_mul(arg1, x, y) && m_util.is_numeral(x, v1, is_int) && v1 > 0 && divides(v1, v2)) { + result = m_util.mk_mul(m_util.mk_int(v1), m_util.mk_mod(y, m_util.mk_int(v2/v1))); + return BR_REWRITE1; } return BR_FAILED; } +bool arith_rewriter::get_range(expr* e, rational& lo, rational& hi) { + expr* x, *y; + rational r; + if (m_util.is_idiv(e, x, y) && m_util.is_numeral(y, r) && get_range(x, lo, hi) && 0 <= lo && r > 0) { + lo = div(lo, r); + hi = div(hi, r); + return true; + } + if (m_util.is_mod(e, x, y) && m_util.is_numeral(y, r) && r > 0) { + lo = 0; + hi = r - 1; + return true; + } + if (m_util.is_numeral(e, r)) { + lo = hi = r; + return true; + } + return false; +} + + br_status arith_rewriter::mk_rem_core(expr * arg1, expr * arg2, expr_ref & result) { set_curr_sort(arg1->get_sort()); numeral v1, v2; @@ -1325,6 +1382,134 @@ app* arith_rewriter_core::mk_power(expr* x, rational const& r, sort* s) { return y; } +br_status arith_rewriter::mk_shl_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result) { + numeral x, y, N; + bool is_num_x = m_util.is_numeral(arg1, x); + bool is_num_y = m_util.is_numeral(arg2, y); + N = rational::power_of_two(sz); + if (is_num_x) + x = mod(x, N); + if (is_num_y) + y = mod(y, N); + if (is_num_x && is_num_y) { + if (y >= sz) + result = m_util.mk_int(0); + else + result = m_util.mk_int(mod(x * rational::power_of_two(y.get_unsigned()), N)); + return BR_DONE; + } + if (is_num_y) { + if (y >= sz) + result = m_util.mk_int(0); + else + result = m_util.mk_mod(m_util.mk_mul(arg1, m_util.mk_int(rational::power_of_two(y.get_unsigned()))), m_util.mk_int(N)); + return BR_REWRITE1; + } + if (is_num_x && x == 0) { + result = m_util.mk_int(0); + return BR_DONE; + } + return BR_FAILED; +} +br_status arith_rewriter::mk_ashr_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result) { + numeral x, y, N; + bool is_num_x = m_util.is_numeral(arg1, x); + bool is_num_y = m_util.is_numeral(arg2, y); + N = rational::power_of_two(sz); + if (is_num_x) + x = mod(x, N); + if (is_num_y) + y = mod(y, N); + if (is_num_x && x == 0) { + result = m_util.mk_int(0); + return BR_DONE; + } + if (is_num_x && is_num_y) { + bool signx = x >= N/2; + rational d = div(x, rational::power_of_two(y.get_unsigned())); + SASSERT(y >= 0); + if (signx) { + if (y >= sz) + result = m_util.mk_int(N-1); + else + result = m_util.mk_int(d); + } + else { + if (y >= sz) + result = m_util.mk_int(0); + else + result = m_util.mk_int(mod(d - rational::power_of_two(sz - y.get_unsigned()), N)); + } + return BR_DONE; + } + return BR_FAILED; +} + +br_status arith_rewriter::mk_lshr_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result) { + numeral x, y, N; + bool is_num_x = m_util.is_numeral(arg1, x); + bool is_num_y = m_util.is_numeral(arg2, y); + N = rational::power_of_two(sz); + if (is_num_x) + x = mod(x, N); + if (is_num_y) + y = mod(y, N); + if (is_num_x && x == 0) { + result = m_util.mk_int(0); + return BR_DONE; + } + if (is_num_y && y == 0) { + result = arg1; + return BR_DONE; + } + if (is_num_x && is_num_y) { + if (y >= sz) + result = m_util.mk_int(0); + else { + rational d = div(x, rational::power_of_two(y.get_unsigned())); + result = m_util.mk_int(d); + } + return BR_DONE; + } + return BR_FAILED; +} + +br_status arith_rewriter::mk_band_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result) { + numeral x, y, N; + bool is_num_x = m_util.is_numeral(arg1, x); + bool is_num_y = m_util.is_numeral(arg2, y); + N = rational::power_of_two(sz); + if (is_num_x) + x = mod(x, N); + if (is_num_y) + y = mod(y, N); + if (is_num_x && x.is_zero()) { + result = m_util.mk_int(0); + return BR_DONE; + } + if (is_num_y && y.is_zero()) { + result = m_util.mk_int(0); + return BR_DONE; + } + if (is_num_x && is_num_y) { + rational r(0); + for (unsigned i = 0; i < sz; ++i) + if (x.get_bit(i) && y.get_bit(i)) + r += rational::power_of_two(i); + result = m_util.mk_int(r); + return BR_DONE; + } + if (is_num_x && (x + 1).is_power_of_two()) { + result = m_util.mk_mod(arg2, m_util.mk_int(x + 1)); + return BR_REWRITE1; + } + if (is_num_y && (y + 1).is_power_of_two()) { + result = m_util.mk_mod(arg1, m_util.mk_int(y + 1)); + return BR_REWRITE1; + } + return BR_FAILED; +} + br_status arith_rewriter::mk_power_core(expr * arg1, expr * arg2, expr_ref & result) { numeral x, y; bool is_num_x = m_util.is_numeral(arg1, x); @@ -1569,21 +1754,48 @@ br_status arith_rewriter::mk_to_real_core(expr * arg, expr_ref & result) { } br_status arith_rewriter::mk_is_int(expr * arg, expr_ref & result) { - numeral a; - if (m_util.is_numeral(arg, a)) { - result = a.is_int() ? m.mk_true() : m.mk_false(); + numeral n; + + if (m_util.is_numeral(arg, n)) { + result = n.is_int() ? m.mk_true() : m.mk_false(); return BR_DONE; } - else if (m_util.is_to_real(arg)) { + + if (m_util.is_to_real(arg)) { result = m.mk_true(); return BR_DONE; } - else { - result = m.mk_eq(m.mk_app(get_fid(), OP_TO_REAL, - m.mk_app(get_fid(), OP_TO_INT, arg)), - arg); - return BR_REWRITE3; + + ptr_buffer todo; + todo.push_back(arg); + expr_fast_mark1 mark; + for (unsigned i = 0; i < todo.size(); ++i) { + expr* e = todo[i]; + if (mark.is_marked(e)) + continue; + mark.mark(e, true); + if (m_util.is_to_real(e)) + continue; + if (m_util.is_numeral(e, n)) { + if (n.is_int()) + continue; + goto bail; + } + if (m_util.is_mul(e) || m_util.is_add(e) || m_util.is_sub(e) || m_util.is_uminus(e)) { + for (expr* a : *to_app(e)) + todo.push_back(a); + continue; + } + goto bail; } + result = m.mk_true(); + return BR_DONE; + + bail: + result = m.mk_eq(m.mk_app(get_fid(), OP_TO_REAL, + m.mk_app(get_fid(), OP_TO_INT, arg)), + arg); + return BR_REWRITE3; } br_status arith_rewriter::mk_abs_core(expr * arg, expr_ref & result) { @@ -1592,7 +1804,7 @@ br_status arith_rewriter::mk_abs_core(expr * arg, expr_ref & result) { } -// Return true if t is of the form c*Pi where c is a numeral. +// Return true if t is of the form c*Pi where c is a numeral. // Store c into k bool arith_rewriter::is_pi_multiple(expr * t, rational & k) { if (m_util.is_pi(t)) { @@ -1603,7 +1815,7 @@ bool arith_rewriter::is_pi_multiple(expr * t, rational & k) { return m_util.is_mul(t, a, b) && m_util.is_pi(b) && m_util.is_numeral(a, k); } -// Return true if t is of the form (+ s c*Pi) where c is a numeral. +// Return true if t is of the form (+ s c*Pi) where c is a numeral. // Store c into k, and c*Pi into m. bool arith_rewriter::is_pi_offset(expr * t, rational & k, expr * & m) { if (m_util.is_add(t)) { @@ -1916,7 +2128,7 @@ br_status arith_rewriter::mk_tan_core(expr * arg, expr_ref & result) { br_status arith_rewriter::mk_asin_core(expr * arg, expr_ref & result) { // Remark: we assume that ForAll x : asin(-x) == asin(x). // Mathematica uses this as an axiom. Although asin is an underspecified function for x < -1 or x > 1. - // Actually, in Mathematica, asin(x) is a total function that returns a complex number fo x < -1 or x > 1. + // Actually, in Mathematica, asin(x) is a total function that returns a complex number for x < -1 or x > 1. rational k; if (is_numeral(arg, k)) { if (k.is_zero()) { diff --git a/src/ast/rewriter/arith_rewriter.h b/src/ast/rewriter/arith_rewriter.h index 3cd9d6165a4..01fea0ac7f0 100644 --- a/src/ast/rewriter/arith_rewriter.h +++ b/src/ast/rewriter/arith_rewriter.h @@ -21,6 +21,7 @@ Module Name: #include "ast/rewriter/poly_rewriter.h" #include "ast/arith_decl_plugin.h" #include "ast/seq_decl_plugin.h" +#include "math/polynomial/algebraic_numbers.h" class arith_rewriter_core { protected: @@ -62,6 +63,7 @@ class arith_rewriter : public poly_rewriter { bool m_eq2ineq; unsigned m_max_degree; + bool get_range(expr* e, rational& lo, rational& hi); void get_coeffs_gcd(expr * t, numeral & g, bool & first, unsigned & num_consts); enum const_treatment { CT_FLOOR, CT_CEIL, CT_FALSE }; bool div_polynomial(expr * t, numeral const & g, const_treatment ct, expr_ref & result); @@ -80,6 +82,7 @@ class arith_rewriter : public poly_rewriter { void updt_local_params(params_ref const & p); bool is_anum_simp_target(unsigned num_args, expr * const * args); + bool is_algebraic_numeral(expr* n, scoped_anum& a); br_status mk_div_irrat_rat(expr * arg1, expr * arg2, expr_ref & result); br_status mk_div_rat_irrat(expr * arg1, expr * arg2, expr_ref & result); @@ -97,11 +100,11 @@ class arith_rewriter : public poly_rewriter { bool is_2_pi_integer_offset(expr * t, expr * & m); bool is_pi_integer(expr * t); bool is_pi_integer_offset(expr * t, expr * & m); - bool is_neg_poly(expr* e, expr_ref& neg) const; - expr_ref neg_monomial(expr * e) const; + bool is_neg_poly(expr* e, expr_ref& neg); + expr_ref neg_monomial(expr * e); expr * mk_sin_value(rational const & k); app * mk_sqrt(rational const & k); - bool divides(expr* d, expr* n, expr_ref& result); + bool get_divides(expr* d, expr* n, expr_ref& result); expr_ref remove_divisor(expr* arg, expr* num, expr* den); void flat_mul(expr* e, ptr_buffer& args); void remove_divisor(expr* d, ptr_buffer& args); @@ -157,6 +160,10 @@ class arith_rewriter : public poly_rewriter { br_status mk_mod_core(expr * arg1, expr * arg2, expr_ref & result); br_status mk_rem_core(expr * arg1, expr * arg2, expr_ref & result); br_status mk_power_core(expr* arg1, expr* arg2, expr_ref & result); + br_status mk_band_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result); + br_status mk_shl_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result); + br_status mk_lshr_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result); + br_status mk_ashr_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result); void mk_div(expr * arg1, expr * arg2, expr_ref & result) { if (mk_div_core(arg1, arg2, result) == BR_FAILED) result = m.mk_app(get_fid(), OP_DIV, arg1, arg2); diff --git a/src/ast/rewriter/array_rewriter.cpp b/src/ast/rewriter/array_rewriter.cpp index e580eb82d0d..bd67a940e9d 100644 --- a/src/ast/rewriter/array_rewriter.cpp +++ b/src/ast/rewriter/array_rewriter.cpp @@ -24,6 +24,7 @@ Module Name: #include "ast/rewriter/var_subst.h" #include "params/array_rewriter_params.hpp" #include "util/util.h" +#include "ast/array_peq.h" void array_rewriter::updt_params(params_ref const & _p) { array_rewriter_params p(_p); @@ -40,8 +41,48 @@ void array_rewriter::get_param_descrs(param_descrs & r) { } br_status array_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) { + br_status st = BR_FAILED; + + // BEGIN: rewrite rules for PEQs + if (is_partial_eq(f)) { + SASSERT(num_args >= 2); + expr *e0, *e1; + e0 = args[0]; + e1 = args[1]; + + expr_ref a(m()), val(m()); + expr_ref_vector vindex(m()); + + if (e0 == e1) { + // t peq t --> true + result = m().mk_true(); + st = BR_DONE; + } + else if (m_util.is_store_ext(e0, a, vindex, val)) { + if (num_args == 2 && a == e1) { + // (a[i := x] peq_{\emptyset} a) ---> a[i] == x + mk_select(vindex.size(), vindex.data(), result); + result = m().mk_eq(result, val); + st = BR_REWRITE_FULL; + } + else if (a == e1 && vindex.size() == num_args + 2) { + // a [i: = x] peq_{i} a -- > true + bool all_eq = true; + for (unsigned i = 0, sz = vindex.size(); all_eq && i < sz; + ++i) { + all_eq &= vindex.get(i) == args[2+i]; + } + if (all_eq) { + result = m().mk_true(); + st = BR_DONE; + } + } + } + return st; + } + // END: rewrite rules for PEQs + SASSERT(f->get_family_id() == get_fid()); - br_status st; switch (f->get_decl_kind()) { case OP_SELECT: st = mk_select_core(num_args, args, result); @@ -165,7 +206,9 @@ br_status array_rewriter::mk_store_core(unsigned num_args, expr * const * args, bool array_rewriter::squash_store(unsigned n, expr* const* args, expr_ref& result) { ptr_buffer parents, sargs; expr* a = args[0]; - while (m_util.is_store(a)) { + unsigned rounds = 0; + while (m_util.is_store(a) && rounds < 10) { + ++rounds; lbool r = compare_args(n - 2, args + 1, to_app(a)->get_args() + 1); switch (r) { case l_undef: diff --git a/src/ast/rewriter/bit_blaster/bit_blaster_rewriter.cpp b/src/ast/rewriter/bit_blaster/bit_blaster_rewriter.cpp index 801ebf4b594..bb8b067e808 100644 --- a/src/ast/rewriter/bit_blaster/bit_blaster_rewriter.cpp +++ b/src/ast/rewriter/bit_blaster/bit_blaster_rewriter.cpp @@ -21,6 +21,7 @@ Module Name: #include "ast/rewriter/bit_blaster/bit_blaster_tpl_def.h" #include "ast/rewriter/rewriter_def.h" #include "ast/rewriter/bool_rewriter.h" +#include "ast/rewriter/th_rewriter.h" #include "util/ref_util.h" #include "ast/ast_smt2_pp.h" @@ -549,10 +550,19 @@ MK_PARAMETRIC_UNARY_REDUCE(reduce_sign_extend, mk_sign_extend); case OP_INT2BV: case OP_BV2INT: return BR_FAILED; - default: + default: TRACE("bit_blaster", tout << "non-supported operator: " << f->get_name() << "\n"; for (unsigned i = 0; i < num; i++) tout << mk_ismt2_pp(args[i], m()) << std::endl;); + { + expr_ref r(m().mk_app(f, num, args), m()); + result = r; + th_rewriter rw(m()); + rw(result); + if (!is_app(result) || to_app(result)->get_decl() != f) + return BR_REWRITE_FULL; + } throw_unsupported(f); + } } @@ -605,6 +615,8 @@ MK_PARAMETRIC_UNARY_REDUCE(reduce_sign_extend, mk_sign_extend); if (m_blast_quant) { if (m_bindings.empty()) return false; + if (!butil().is_bv(t)) + return false; unsigned shift = m_shifts.back(); if (t->get_idx() >= m_bindings.size()) { if (shift == 0) diff --git a/src/ast/rewriter/bool_rewriter.cpp b/src/ast/rewriter/bool_rewriter.cpp index 95c0950d87e..2176ff1297a 100644 --- a/src/ast/rewriter/bool_rewriter.cpp +++ b/src/ast/rewriter/bool_rewriter.cpp @@ -26,6 +26,7 @@ Module Name: void bool_rewriter::updt_params(params_ref const & _p) { bool_rewriter_params p(_p); m_flat_and_or = p.flat_and_or(); + m_sort_disjunctions = p.sort_disjunctions(); m_elim_and = p.elim_and(); m_elim_ite = p.elim_ite(); m_local_ctx = p.local_ctx(); @@ -33,7 +34,6 @@ void bool_rewriter::updt_params(params_ref const & _p) { m_blast_distinct = p.blast_distinct(); m_blast_distinct_threshold = p.blast_distinct_threshold(); m_ite_extra_rules = p.ite_extra_rules(); - m_hoist.set_elim_and(m_elim_and); } void bool_rewriter::get_param_descrs(param_descrs & r) { @@ -183,7 +183,7 @@ br_status bool_rewriter::mk_flat_and_core(unsigned num_args, expr * const * args } br_status bool_rewriter::mk_nflat_or_core(unsigned num_args, expr * const * args, expr_ref & result) { - bool s = false; + bool s = false; // whether we have canceled some disjuncts or found some out or order ptr_buffer buffer; expr_fast_mark1 neg_lits; expr_fast_mark2 pos_lits; @@ -269,31 +269,11 @@ br_status bool_rewriter::mk_nflat_or_core(unsigned num_args, expr * const * args return BR_DONE; } -#if 1 - br_status st; - expr_ref r(m()); - st = m_hoist.mk_or(buffer.size(), buffer.data(), r); - if (st != BR_FAILED) { - m_counts1.reserve(m().get_num_asts() + 1); - m_counts2.reserve(m().get_num_asts() + 1); - get_num_internal_exprs(m_counts1, m_todo1, r); - for (unsigned i = 0; i < num_args; ++i) - get_num_internal_exprs(m_counts2, m_todo2, args[i]); - unsigned count1 = count_internal_nodes(m_counts1, m_todo1); - unsigned count2 = count_internal_nodes(m_counts2, m_todo2); - if (count1 > count2) - st = BR_FAILED; - } - if (st != BR_FAILED) - result = r; - if (st == BR_DONE) - return BR_REWRITE1; - if (st != BR_FAILED) - return st; -#endif if (s) { - ast_lt lt; - std::sort(buffer.begin(), buffer.end(), lt); + if (m_sort_disjunctions) { + ast_lt lt; + std::sort(buffer.begin(), buffer.end(), lt); + } result = m().mk_or(sz, buffer.data()); return BR_DONE; } @@ -329,7 +309,7 @@ br_status bool_rewriter::mk_flat_or_core(unsigned num_args, expr * const * args, } } if (mk_nflat_or_core(flat_args.size(), flat_args.data(), result) == BR_FAILED) { - if (!ordered) { + if (m_sort_disjunctions && !ordered) { ast_lt lt; std::sort(flat_args.begin(), flat_args.end(), lt); } @@ -565,16 +545,9 @@ bool bool_rewriter::local_ctx_simp(unsigned num_args, expr * const * args, expr_ bool simp = false; bool modified = false; bool forward = true; - unsigned rounds = 0; expr* narg; while (true) { - rounds++; -#if 0 - if (rounds > 10) - verbose_stream() << "rounds: " << rounds << "\n"; -#endif - #define PROCESS_ARG() \ { \ @@ -662,12 +635,19 @@ br_status bool_rewriter::try_ite_value(app * ite, app * val, expr_ref & result) SASSERT(m().is_value(val)); if (m().are_distinct(val, e)) { - mk_eq(t, val, result); + if (get_depth(t) < 500) + mk_eq(t, val, result); + else + result = m().mk_eq(t, val); + result = m().mk_and(result, cond); return BR_REWRITE2; } if (m().are_distinct(val, t)) { - mk_eq(e, val, result); + if (get_depth(e) < 500) + mk_eq(e, val, result); + else + result = m().mk_eq(e, val); result = m().mk_and(result, m().mk_not(cond)); return BR_REWRITE2; } @@ -712,6 +692,22 @@ app* bool_rewriter::mk_eq(expr* lhs, expr* rhs) { return m().mk_eq(lhs, rhs); } +bool bool_rewriter::try_ite_eq(expr* lhs, expr* rhs, expr_ref& r) { + expr* c, *t, *e; + if (!m().is_ite(lhs, c, t, e)) + return false; + if (m().are_equal(t, rhs) && m().are_distinct(e, rhs)) { + r = c; + return true; + } + if (m().are_equal(e, rhs) && m().are_distinct(t, rhs)) { + r = m().mk_not(c); + return true; + } + return false; +} + + br_status bool_rewriter::mk_eq_core(expr * lhs, expr * rhs, expr_ref & result) { if (m().are_equal(lhs, rhs)) { result = m().mk_true(); @@ -726,6 +722,12 @@ br_status bool_rewriter::mk_eq_core(expr * lhs, expr * rhs, expr_ref & result) { br_status r = BR_FAILED; + if (try_ite_eq(lhs, rhs, result)) + return BR_REWRITE1; + + if (try_ite_eq(rhs, lhs, result)) + return BR_REWRITE1; + if (m_ite_extra_rules) { if (m().is_ite(lhs) && m().is_value(rhs)) { r = try_ite_value(to_app(lhs), to_app(rhs), result); diff --git a/src/ast/rewriter/bool_rewriter.h b/src/ast/rewriter/bool_rewriter.h index 0693e94ba85..421811ed4d1 100644 --- a/src/ast/rewriter/bool_rewriter.h +++ b/src/ast/rewriter/bool_rewriter.h @@ -20,7 +20,6 @@ Module Name: #include "ast/ast.h" #include "ast/rewriter/rewriter.h" -#include "ast/rewriter/hoist_rewriter.h" #include "util/params.h" /** @@ -51,8 +50,8 @@ Module Name: */ class bool_rewriter { ast_manager & m_manager; - hoist_rewriter m_hoist; bool m_flat_and_or = false; + bool m_sort_disjunctions = true; bool m_local_ctx = false; bool m_elim_and = false; bool m_blast_distinct = false; @@ -72,6 +71,8 @@ class bool_rewriter { void mk_and_as_or(unsigned num_args, expr * const * args, expr_ref & result); + bool try_ite_eq(expr* lhs, expr* rhs, expr_ref& r); + expr * mk_or_app(unsigned num_args, expr * const * args); bool simp_nested_not_or(unsigned num_args, expr * const * args, expr_fast_mark1 & neg_lits, expr_fast_mark2 & pos_lits, expr_ref & result); expr * simp_arg(expr * arg, expr_fast_mark1 & neg_lits, expr_fast_mark2 & pos_lits, bool & modified); @@ -83,7 +84,7 @@ class bool_rewriter { void push_new_arg(expr* arg, expr_ref_vector& new_args, expr_fast_mark1& neg_lits, expr_fast_mark2& pos_lits); public: - bool_rewriter(ast_manager & m, params_ref const & p = params_ref()):m_manager(m), m_hoist(m), m_local_ctx_cost(0) { + bool_rewriter(ast_manager & m, params_ref const & p = params_ref()):m_manager(m), m_local_ctx_cost(0) { updt_params(p); } ast_manager & m() const { return m_manager; } diff --git a/src/ast/rewriter/bv_rewriter.cpp b/src/ast/rewriter/bv_rewriter.cpp index 751608e1290..db87cd0084f 100644 --- a/src/ast/rewriter/bv_rewriter.cpp +++ b/src/ast/rewriter/bv_rewriter.cpp @@ -20,8 +20,9 @@ Module Name: #include "ast/rewriter/bv_rewriter.h" #include "ast/rewriter/poly_rewriter_def.h" #include "ast/rewriter/bool_rewriter.h" -#include "ast/ast_smt2_pp.h" #include "ast/ast_lt.h" +#include "ast/ast_pp.h" + void bv_rewriter::updt_local_params(params_ref const & _p) { @@ -54,45 +55,58 @@ void bv_rewriter::get_param_descrs(param_descrs & r) { br_status bv_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) { SASSERT(f->get_family_id() == get_fid()); + br_status st = BR_FAILED; switch(f->get_decl_kind()) { case OP_BIT0: SASSERT(num_args == 0); result = mk_zero(1); return BR_DONE; case OP_BIT1: SASSERT(num_args == 0); result = mk_one(1); return BR_DONE; case OP_ULEQ: SASSERT(num_args == 2); - return mk_ule(args[0], args[1], result); + st = mk_ule(args[0], args[1], result); + break; case OP_UGEQ: SASSERT(num_args == 2); - return mk_uge(args[0], args[1], result); + st = mk_uge(args[0], args[1], result); + break; case OP_ULT: SASSERT(num_args == 2); - return mk_ult(args[0], args[1], result); + st = mk_ult(args[0], args[1], result); + break; case OP_UGT: SASSERT(num_args == 2); - return mk_ult(args[1], args[0], result); + st = mk_ult(args[1], args[0], result); + break; case OP_SLEQ: SASSERT(num_args == 2); - return mk_sle(args[0], args[1], result); + st = mk_sle(args[0], args[1], result); + break; case OP_SGEQ: SASSERT(num_args == 2); - return mk_sge(args[0], args[1], result); + st = mk_sge(args[0], args[1], result); + break; case OP_SLT: SASSERT(num_args == 2); - return mk_slt(args[0], args[1], result); + st = mk_slt(args[0], args[1], result); + break; case OP_SGT: SASSERT(num_args == 2); - return mk_slt(args[1], args[0], result); + st = mk_slt(args[1], args[0], result); + break; case OP_BADD: SASSERT(num_args > 0); - return mk_bv_add(num_args, args, result); + st = mk_bv_add(num_args, args, result); + break; case OP_BMUL: SASSERT(num_args > 0); - return mk_bv_mul(num_args, args, result); + st = mk_bv_mul(num_args, args, result); + break; case OP_BSUB: SASSERT(num_args > 0); - return mk_sub(num_args, args, result); + st = mk_sub(num_args, args, result); + break; case OP_BNEG: SASSERT(num_args == 1); - return mk_uminus(args[0], result); + st = mk_uminus(args[0], result); + break; case OP_BNEG_OVFL: SASSERT(num_args == 1); return mk_bvneg_overflow(args[0], result); @@ -216,10 +230,17 @@ br_status bv_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * cons case OP_BUSUB_OVFL: return mk_bvusub_underflow(num_args, args, result); case OP_BSSUB_OVFL: - return mk_bvssub_overflow(num_args, args, result); + return mk_bvssub_under_overflow(num_args, args, result); default: return BR_FAILED; } + + CTRACE("bv", st != BR_FAILED, tout << mk_pp(f, m) << "\n"; + for (unsigned i = 0; i < num_args; ++i) + tout << " " << mk_bounded_pp(args[i], m) << "\n"; + tout << mk_bounded_pp(result, m, 3) << "\n"); + + return st; } br_status bv_rewriter::mk_ule(expr * a, expr * b, expr_ref & result) { @@ -541,7 +562,7 @@ br_status bv_rewriter::mk_leq_core(bool is_signed, expr * a, expr * b, expr_ref const br_status cst = rw_leq_concats(is_signed, a, b, result); if (cst != BR_FAILED) { TRACE("le_extra", tout << (is_signed ? "bv_sle\n" : "bv_ule\n") - << mk_ismt2_pp(a, m, 2) << "\n" << mk_ismt2_pp(b, m, 2) << "\n--->\n"<< mk_ismt2_pp(result, m, 2) << "\n";); + << mk_pp(a, m, 2) << "\n" << mk_pp(b, m, 2) << "\n--->\n"<< mk_pp(result, m, 2) << "\n";); return cst; } } @@ -550,7 +571,7 @@ br_status bv_rewriter::mk_leq_core(bool is_signed, expr * a, expr * b, expr_ref const br_status cst = rw_leq_overflow(is_signed, a, b, result); if (cst != BR_FAILED) { TRACE("le_extra", tout << (is_signed ? "bv_sle\n" : "bv_ule\n") - << mk_ismt2_pp(a, m, 2) << "\n" << mk_ismt2_pp(b, m, 2) << "\n--->\n"<< mk_ismt2_pp(result, m, 2) << "\n";); + << mk_pp(a, m, 2) << "\n" << mk_pp(b, m, 2) << "\n--->\n"<< mk_pp(result, m, 2) << "\n";); return cst; } } @@ -802,8 +823,8 @@ br_status bv_rewriter::mk_extract(unsigned high, unsigned low, expr * arg, expr_ const unsigned ep_rm = propagate_extract(high, arg, ep_res); if (ep_rm != 0) { result = m_mk_extract(high, low, ep_res); - TRACE("extract_prop", tout << mk_ismt2_pp(arg, m) << "\n[" << high <<"," << low << "]\n" << ep_rm << "---->\n" - << mk_ismt2_pp(result.get(), m) << "\n";); + TRACE("extract_prop", tout << mk_pp(arg, m) << "\n[" << high <<"," << low << "]\n" << ep_rm << "---->\n" + << mk_pp(result.get(), m) << "\n";); return BR_REWRITE2; } } @@ -1132,7 +1153,7 @@ br_status bv_rewriter::mk_bv_udiv_core(expr * arg1, expr * arg2, bool hi_div0, e m_util.mk_bv_udiv0(arg1), m_util.mk_bv_udiv_i(arg1, arg2)); - TRACE("bv_udiv", tout << mk_ismt2_pp(arg1, m) << "\n" << mk_ismt2_pp(arg2, m) << "\n---->\n" << mk_ismt2_pp(result, m) << "\n";); + TRACE("bv_udiv", tout << mk_pp(arg1, m) << "\n" << mk_pp(arg2, m) << "\n---->\n" << mk_pp(result, m) << "\n";); return BR_REWRITE2; } @@ -1792,8 +1813,8 @@ br_status bv_rewriter::mk_bv_or(unsigned num, expr * const * args, expr_ref & re std::reverse(exs.begin(), exs.end()); result = m_util.mk_concat(exs.size(), exs.data()); TRACE("mask_bug", - tout << "(assert (distinct (bvor (_ bv" << old_v1 << " " << sz << ")\n" << mk_ismt2_pp(t, m) << ")\n"; - tout << mk_ismt2_pp(result, m) << "))\n";); + tout << "(assert (distinct (bvor (_ bv" << old_v1 << " " << sz << ")\n" << mk_pp(t, m) << ")\n"; + tout << mk_pp(result, m) << "))\n";); return BR_REWRITE2; } @@ -2463,7 +2484,7 @@ br_status bv_rewriter::mk_blast_eq_value(expr * lhs, expr * rhs, expr_ref & resu unsigned sz = get_bv_size(lhs); if (sz == 1) return BR_FAILED; - TRACE("blast_eq_value", tout << "sz: " << sz << "\n" << mk_ismt2_pp(lhs, m) << "\n";); + TRACE("blast_eq_value", tout << "sz: " << sz << "\n" << mk_pp(lhs, m) << "\n";); if (is_numeral(lhs)) std::swap(lhs, rhs); @@ -2573,7 +2594,6 @@ void bv_rewriter::mk_t1_add_t2_eq_c(expr * t1, expr * t2, expr * c, expr_ref & r result = m.mk_eq(t1, m_util.mk_bv_sub(c, t2)); } -#include "ast/ast_pp.h" bool bv_rewriter::isolate_term(expr* lhs, expr* rhs, expr_ref& result) { if (!m_util.is_numeral(lhs) || !is_add(rhs)) { @@ -2730,13 +2750,13 @@ br_status bv_rewriter::mk_eq_core(expr * lhs, expr * rhs, expr_ref & result) { st = mk_mul_eq(lhs, rhs, result); if (st != BR_FAILED) { - TRACE("mk_mul_eq", tout << mk_ismt2_pp(lhs, m) << "\n=\n" << mk_ismt2_pp(rhs, m) << "\n----->\n" << mk_ismt2_pp(result,m) << "\n";); + TRACE("mk_mul_eq", tout << mk_pp(lhs, m) << "\n=\n" << mk_pp(rhs, m) << "\n----->\n" << mk_pp(result,m) << "\n";); return st; } st = mk_mul_eq(rhs, lhs, result); if (st != BR_FAILED) { - TRACE("mk_mul_eq", tout << mk_ismt2_pp(lhs, m) << "\n=\n" << mk_ismt2_pp(rhs, m) << "\n----->\n" << mk_ismt2_pp(result,m) << "\n";); + TRACE("mk_mul_eq", tout << mk_pp(lhs, m) << "\n=\n" << mk_pp(rhs, m) << "\n----->\n" << mk_pp(result,m) << "\n";); return st; } @@ -2851,8 +2871,8 @@ bool bv_rewriter::is_eq_bit(expr * t, expr * & x, unsigned & val) { br_status bv_rewriter::mk_ite_core(expr * c, expr * t, expr * e, expr_ref & result) { - TRACE("bv_ite", tout << "mk_ite_core:\n" << mk_ismt2_pp(c, m) << "?\n" - << mk_ismt2_pp(t, m) << "\n:" << mk_ismt2_pp(e, m) << "\n";); + TRACE("bv_ite", tout << "mk_ite_core:\n" << mk_pp(c, m) << "?\n" + << mk_pp(t, m) << "\n:" << mk_pp(e, m) << "\n";); if (m.are_equal(t, e)) { result = e; return BR_REWRITE1; @@ -3015,8 +3035,8 @@ br_status bv_rewriter::mk_bvumul_no_overflow(unsigned num, expr * const * args, br_status bv_rewriter::mk_bvneg_overflow(expr * const arg, expr_ref & result) { unsigned int sz = get_bv_size(arg); - auto maxUnsigned = mk_numeral(rational::power_of_two(sz)-1, sz); - result = m.mk_eq(arg, maxUnsigned); + auto minSigned = mk_numeral(rational::power_of_two(sz - 1), sz); // 0b1000...0 + result = m.mk_eq(arg, minSigned); return BR_REWRITE3; } @@ -3081,24 +3101,34 @@ br_status bv_rewriter::mk_bvusub_underflow(unsigned num, expr * const * args, ex return status; } -br_status bv_rewriter::mk_bvssub_overflow(unsigned num, expr * const * args, expr_ref & result) { +// +// no_overflow := if t2 = min_int then t1 no_underflow(t1 + -t2) +// over_underflow := 0 =s 0 || t2 != min_int & under_overflow+(t1 + -t2) +// := if t2 == min_int then t1 >=s 0 else under_overflow+(t1 + -t2) +// because when 0 { br_status mk_bvsadd_over_underflow(unsigned num, expr * const * args, expr_ref & result); br_status mk_bvusub_underflow(unsigned num, expr * const * args, expr_ref & result); - br_status mk_bvssub_overflow(unsigned num, expr * const * args, expr_ref & result); + // br_status mk_bvssub_overflow(unsigned num, expr * const * args, expr_ref & result); + br_status mk_bvssub_under_overflow(unsigned num, expr * const * args, expr_ref & result); bool is_minus_one_times_t(expr * arg); void mk_t1_add_t2_eq_c(expr * t1, expr * t2, expr * c, expr_ref & result); diff --git a/src/ast/rewriter/datatype_rewriter.cpp b/src/ast/rewriter/datatype_rewriter.cpp index ba0155e975e..001b697e423 100644 --- a/src/ast/rewriter/datatype_rewriter.cpp +++ b/src/ast/rewriter/datatype_rewriter.cpp @@ -21,7 +21,8 @@ Module Name: br_status datatype_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) { SASSERT(f->get_family_id() == get_fid()); switch(f->get_decl_kind()) { - case OP_DT_CONSTRUCTOR: return BR_FAILED; + case OP_DT_CONSTRUCTOR: + return BR_FAILED; case OP_DT_RECOGNISER: SASSERT(num_args == 1); result = m_util.mk_is(m_util.get_recognizer_constructor(f), args[0]); diff --git a/src/ast/rewriter/der.cpp b/src/ast/rewriter/der.cpp index 0e28cf6f6fa..1e2a19d725e 100644 --- a/src/ast/rewriter/der.cpp +++ b/src/ast/rewriter/der.cpp @@ -176,9 +176,9 @@ void der::reduce1(quantifier * q, expr_ref & r, proof_ref & pr) { var * v = nullptr; expr_ref t(m); - if (is_forall(q) && is_var_diseq(e, num_decls, v, t) && !occurs(v, t)) + if (is_forall(q) && is_var_diseq(e, num_decls, v, t) && !has_quantifiers(t) && !occurs(v, t)) r = m.mk_false(); - else if (is_exists(q) && is_var_eq(e, num_decls, v, t) && !occurs(v, t)) + else if (is_exists(q) && is_var_eq(e, num_decls, v, t) && !has_quantifiers(t) && !occurs(v, t)) r = m.mk_true(); else { expr_ref_vector literals(m); diff --git a/src/ast/rewriter/elim_bounds.cpp b/src/ast/rewriter/elim_bounds.cpp index de23537ad02..b6b98e0cb25 100644 --- a/src/ast/rewriter/elim_bounds.cpp +++ b/src/ast/rewriter/elim_bounds.cpp @@ -17,11 +17,10 @@ Revision History: --*/ -#ifndef ELIM_BOUNDS_H_ -#define ELIM_BOUNDS_H_ #include "ast/used_vars.h" #include "util/obj_hashtable.h" +#include "ast/rewriter/rewriter_def.h" #include "ast/rewriter/var_subst.h" #include "ast/rewriter/elim_bounds.h" #include "ast/ast_pp.h" @@ -200,4 +199,3 @@ bool elim_bounds_cfg::reduce_quantifier(quantifier * q, return true; } -#endif /* ELIM_BOUNDS_H_ */ diff --git a/src/ast/rewriter/hoist_rewriter.cpp b/src/ast/rewriter/hoist_rewriter.cpp deleted file mode 100644 index 72a764bfa37..00000000000 --- a/src/ast/rewriter/hoist_rewriter.cpp +++ /dev/null @@ -1,248 +0,0 @@ -/*++ -Copyright (c) 2019 Microsoft Corporation - -Module Name: - - hoist_rewriter.cpp - -Abstract: - - Hoist predicates over disjunctions - -Author: - - Nikolaj Bjorner (nbjorner) 2019-2-4 - ---*/ - - -#include "ast/rewriter/hoist_rewriter.h" -#include "ast/rewriter/bool_rewriter.h" -#include "ast/ast_util.h" -#include "ast/ast_pp.h" -#include "ast/ast_ll_pp.h" - -hoist_rewriter::hoist_rewriter(ast_manager & m, params_ref const & p): - m(m), m_args1(m), m_args2(m), m_refs(m), m_subst(m) { - updt_params(p); -} - -expr_ref hoist_rewriter::mk_and(expr_ref_vector const& args) { - if (m_elim_and) { - expr_ref_vector negs(m); - for (expr* a : args) - if (m.is_false(a)) - return expr_ref(m.mk_false(), m); - else if (m.is_true(a)) - continue; - else - negs.push_back(::mk_not(m, a)); - return ::mk_not(mk_or(negs)); - } - else - return ::mk_and(args); -} - -expr_ref hoist_rewriter::mk_or(expr_ref_vector const& args) { - return ::mk_or(args); -} - -br_status hoist_rewriter::mk_or(unsigned num_args, expr * const * es, expr_ref & result) { - if (num_args < 2) - return BR_FAILED; - - for (unsigned i = 0; i < num_args; ++i) - if (!is_and(es[i], nullptr)) - return BR_FAILED; - - bool turn = false; - m_preds1.reset(); - m_preds2.reset(); - m_uf1.reset(); - m_uf2.reset(); - m_expr2var.reset(); - m_var2expr.reset(); - basic_union_find* uf[2] = { &m_uf1, &m_uf2 }; - obj_hashtable* preds[2] = { &m_preds1, &m_preds2 }; - expr_ref_vector* args[2] = { &m_args1, &m_args2 }; - VERIFY(is_and(es[0], args[turn])); - expr* e1, *e2; - for (expr* e : *(args[turn])) { - if (m.is_eq(e, e1, e2)) - (*uf)[turn].merge(mk_var(e1), mk_var(e2)); - else - (*preds)[turn].insert(e); - } - unsigned round = 0; - for (unsigned j = 1; j < num_args; ++j) { - ++round; - m_es.reset(); - m_mark.reset(); - - bool last = turn; - turn = !turn; - (*preds)[turn].reset(); - reset(m_uf0); - VERIFY(is_and(es[j], args[turn])); - - for (expr* e : *args[turn]) { - if (m.is_eq(e, e1, e2)) { - m_es.push_back(e1); - m_uf0.merge(mk_var(e1), mk_var(e2)); - } - else if ((*preds)[last].contains(e)) - (*preds)[turn].insert(e); - } - - if ((*preds)[turn].empty() && m_es.empty()) - return BR_FAILED; - - m_eqs.reset(); - for (expr* e : m_es) { - if (m_mark.is_marked(e)) - continue; - unsigned u = mk_var(e); - unsigned v = u; - m_roots.reset(); - do { - m_mark.mark(e); - unsigned r = (*uf)[last].find(v); - if (m_roots.find(r, e2)) - m_eqs.push_back({e, e2}); - else - m_roots.insert(r, e); - v = m_uf0.next(v); - e = mk_expr(v); - } - while (u != v); - } - reset((*uf)[turn]); - for (auto const& [e1, e2] : m_eqs) - (*uf)[turn].merge(mk_var(e1), mk_var(e2)); - if ((*preds)[turn].empty() && m_eqs.empty()) - return BR_FAILED; - } - if (m_eqs.empty()) { - result = hoist_predicates((*preds)[turn], num_args, es); - return BR_DONE; - } - // p & eqs & (or fmls) - expr_ref_vector fmls(m); - m_subst.reset(); - for (expr * p : (*preds)[turn]) { - expr* q = nullptr; - if (m.is_not(p, q)) - m_subst.insert(q, m.mk_false()); - else - m_subst.insert(p, m.mk_true()); - fmls.push_back(p); - } - for (auto& p : m_eqs) { - if (m.is_value(p.first)) - std::swap(p.first, p.second); - m_subst.insert(p.first, p.second); - fmls.push_back(m.mk_eq(p.first, p.second)); - } - expr_ref ors(::mk_or(m, num_args, es), m); - m_subst(ors); - fmls.push_back(ors); - result = mk_and(fmls); - TRACE("hoist", tout << ors << " => " << result << "\n";); - return BR_DONE; -} - -unsigned hoist_rewriter::mk_var(expr* e) { - unsigned v = 0; - if (m_expr2var.find(e, v)) - return v; - m_uf1.mk_var(); - v = m_uf2.mk_var(); - SASSERT(v == m_var2expr.size()); - m_expr2var.insert(e, v); - m_var2expr.push_back(e); - return v; -} - -expr_ref hoist_rewriter::hoist_predicates(obj_hashtable const& preds, unsigned num_args, expr* const* es) { - expr_ref_vector args(m), args1(m), fmls(m); - for (unsigned i = 0; i < num_args; ++i) { - VERIFY(is_and(es[i], &args1)); - fmls.reset(); - for (expr* e : args1) - if (!preds.contains(e)) - fmls.push_back(e); - args.push_back(mk_and(fmls)); - } - fmls.reset(); - fmls.push_back(mk_or(args)); - for (auto* p : preds) - fmls.push_back(p); - return mk_and(fmls); -} - - -br_status hoist_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) { - switch (f->get_decl_kind()) { - case OP_OR: - return mk_or(num_args, args, result); - default: - return BR_FAILED; - } -} - -bool hoist_rewriter::is_and(expr * e, expr_ref_vector* args) { -#if 0 - if (!args) - return m.is_and(e) || (m.is_not(e, e) && m.is_or(e)); - expr_fast_mark1 visited; - args->reset(); - args->push_back(e); - m_refs.reset(); - for (unsigned i = 0; i < args->size(); ++i) { - e = args->get(i); - if (visited.is_marked(e)) - goto drop; - m_refs.push_back(e); - visited.mark(e, true); - if (m.is_and(e)) - args->append(to_app(e)->get_num_args(), to_app(e)->get_args()); - else if (m.is_not(e, e) && m.is_or(e)) - for (expr* arg : *to_app(e)) - args->push_back(::mk_not(m, arg)); - else - continue; - drop: - (*args)[i] = args->back(); - args->pop_back(); - --i; - } - return args->size() > 1; -#else - if (m.is_and(e)) { - if (args) { - args->reset(); - args->append(to_app(e)->get_num_args(), to_app(e)->get_args()); - } - return true; - } - if (m.is_not(e, e) && m.is_or(e)) { - if (args) { - args->reset(); - for (expr* arg : *to_app(e)) - args->push_back(::mk_not(m, arg)); - TRACE("hoist", tout << args << " " << * args << "\n"); - } - return true; - } -#endif - return false; -} - - -void hoist_rewriter::reset(basic_union_find& uf) { - uf.reset(); - for (expr* e : m_var2expr) { - (void)e; - uf.mk_var(); - } -} diff --git a/src/ast/rewriter/hoist_rewriter.h b/src/ast/rewriter/hoist_rewriter.h deleted file mode 100644 index b64325584a5..00000000000 --- a/src/ast/rewriter/hoist_rewriter.h +++ /dev/null @@ -1,87 +0,0 @@ -/*++ -Copyright (c) 2019 Microsoft Corporation - -Module Name: - - hoist_rewriter.h - -Abstract: - - Hoist predicates over disjunctions - -Author: - - Nikolaj Bjorner (nbjorner) 2019-2-4 - -Notes: - ---*/ -#pragma once - -#include "ast/ast.h" -#include "ast/rewriter/rewriter.h" -#include "ast/rewriter/expr_safe_replace.h" -#include "util/params.h" -#include "util/union_find.h" -#include "util/obj_hashtable.h" - -class bool_rewriter; - -class hoist_rewriter { - ast_manager & m; - expr_ref_vector m_args1, m_args2, m_refs; - obj_hashtable m_preds1, m_preds2; - basic_union_find m_uf1, m_uf2, m_uf0; - ptr_vector m_es; - svector> m_eqs; - u_map m_roots; - expr_safe_replace m_subst; - obj_map m_expr2var; - ptr_vector m_var2expr; - expr_mark m_mark; - bool m_elim_and = false; - - bool is_and(expr* e, expr_ref_vector* args); - expr_ref mk_and(expr_ref_vector const& args); - expr_ref mk_or(expr_ref_vector const& args); - - bool is_var(expr* e) { return m_expr2var.contains(e); } - expr* mk_expr(unsigned v) { return m_var2expr[v]; } - unsigned mk_var(expr* e); - - void reset(basic_union_find& uf); - - expr_ref hoist_predicates(obj_hashtable const& p, unsigned num_args, expr* const* args); - - -public: - hoist_rewriter(ast_manager & m, params_ref const & p = params_ref()); - family_id get_fid() const { return m.get_basic_family_id(); } - bool is_eq(expr * t) const { return m.is_eq(t); } - void updt_params(params_ref const & p) {} - static void get_param_descrs(param_descrs & r) {} - br_status mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result); - br_status mk_or(unsigned num_args, expr * const * args, expr_ref & result); - void set_elim_and(bool b) { m_elim_and = b; } -}; - -struct hoist_rewriter_cfg : public default_rewriter_cfg { - hoist_rewriter m_r; - bool rewrite_patterns() const { return false; } - br_status reduce_app(func_decl * f, unsigned num, expr * const * args, expr_ref & result, proof_ref & result_pr) { - result_pr = nullptr; - if (f->get_family_id() != m_r.get_fid()) - return BR_FAILED; - return m_r.mk_app_core(f, num, args, result); - } - hoist_rewriter_cfg(ast_manager & m, params_ref const & p):m_r(m, p) {} -}; - -class hoist_rewriter_star : public rewriter_tpl { - hoist_rewriter_cfg m_cfg; -public: - hoist_rewriter_star(ast_manager & m, params_ref const & p = params_ref()): - rewriter_tpl(m, false, m_cfg), - m_cfg(m, p) {} -}; - diff --git a/src/ast/rewriter/pb_rewriter.cpp b/src/ast/rewriter/pb_rewriter.cpp index 2f06710ffaa..4f1fe3fd1a5 100644 --- a/src/ast/rewriter/pb_rewriter.cpp +++ b/src/ast/rewriter/pb_rewriter.cpp @@ -157,9 +157,7 @@ expr_ref pb_rewriter::mk_validate_rewrite(app_ref& e1, app_ref& e2) { continue; } - std::ostringstream strm; - strm << 'x' << i; - name = symbol(strm.str()); + name = symbol('x' + std::to_string(i)); trail.push_back(m.mk_const(name, a.mk_int())); expr* x = trail.back(); m.is_not(e,e); @@ -188,9 +186,7 @@ void pb_rewriter::validate_rewrite(func_decl* f, unsigned sz, expr*const* args, } void pb_rewriter::dump_pb_rewrite(expr* fml) { - std::ostringstream strm; - strm << "pb_rewrite_" << (s_lemma++) << ".smt2"; - std::ofstream out(strm.str()); + std::ofstream out("pb_rewrite_" + std::to_string(s_lemma++) + ".smt2"); ast_smt_pp pp(m()); pp.display_smt2(out, fml); out.close(); diff --git a/src/ast/rewriter/push_app_ite.cpp b/src/ast/rewriter/push_app_ite.cpp index 0282aa60268..b993c29a5b2 100644 --- a/src/ast/rewriter/push_app_ite.cpp +++ b/src/ast/rewriter/push_app_ite.cpp @@ -18,6 +18,7 @@ Revision History: --*/ #include "ast/rewriter/push_app_ite.h" +#include "ast/rewriter/rewriter_def.h" #include "ast/ast_pp.h" diff --git a/src/ast/rewriter/rewriter.cpp b/src/ast/rewriter/rewriter.cpp index 3b25b94095d..51c4764df4d 100644 --- a/src/ast/rewriter/rewriter.cpp +++ b/src/ast/rewriter/rewriter.cpp @@ -17,6 +17,8 @@ Module Name: --*/ #include "ast/rewriter/rewriter_def.h" +#include "ast/rewriter/push_app_ite.h" +#include "ast/rewriter/elim_bounds.h" #include "ast/ast_ll_pp.h" #include "ast/ast_pp.h" #include "ast/ast_smt2_pp.h" @@ -417,3 +419,6 @@ void inv_var_shifter::process_var(var * v) { } template class rewriter_tpl; +template class rewriter_tpl; +template class rewriter_tpl; +template class rewriter_tpl; diff --git a/src/ast/rewriter/rewriter_types.h b/src/ast/rewriter/rewriter_types.h index ff1755b9135..937bcdc6bcb 100644 --- a/src/ast/rewriter/rewriter_types.h +++ b/src/ast/rewriter/rewriter_types.h @@ -33,6 +33,18 @@ enum br_status { BR_FAILED // no builtin rewrite is available }; +inline std::ostream& operator<<(std::ostream& out, br_status st) { + switch (st) { + case BR_REWRITE1: return out << "rewrite1"; + case BR_REWRITE2: return out << "rewrite2"; + case BR_REWRITE3: return out << "rewrite3"; + case BR_REWRITE_FULL: return out << "rewrite_full"; + case BR_DONE: return out << "done"; + case BR_FAILED: return out << "failed"; + default: return out << "unknown"; + } +} + #define RW_UNBOUNDED_DEPTH 3 inline br_status unsigned2br_status(unsigned u) { br_status r = u >= RW_UNBOUNDED_DEPTH ? BR_REWRITE_FULL : static_cast(u); diff --git a/src/ast/rewriter/seq_axioms.cpp b/src/ast/rewriter/seq_axioms.cpp index 4d7da4d7f99..b0d3add0be0 100644 --- a/src/ast/rewriter/seq_axioms.cpp +++ b/src/ast/rewriter/seq_axioms.cpp @@ -538,6 +538,7 @@ namespace seq { expr_ref t_eq_empty = mk_eq_empty(t); expr_ref xsy = mk_concat(x, s, y); + // add_clause(~mk_eq(t, s), i_eq_0); add_clause(cnt, i_eq_m1); add_clause(~t_eq_empty, s_eq_empty, i_eq_m1); add_clause(~s_eq_empty, mk_eq(i, mk_len(t))); diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index 69c3584d42f..004838d62db 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -829,6 +829,7 @@ br_status seq_rewriter::mk_seq_concat(expr* a, expr* b, expr_ref& result) { br_status seq_rewriter::mk_seq_length(expr* a, expr_ref& result) { zstring b; + rational r; m_es.reset(); str().get_concat(a, m_es); unsigned len = 0; @@ -868,6 +869,17 @@ br_status seq_rewriter::mk_seq_length(expr* a, expr_ref& result) { result = str().mk_length(z); return BR_REWRITE1; } + // not suitable for noodler: adds length check which makes it harder + // // len(extract(x, 0, z)) = min(z, len(x)) + // if (str().is_extract(a, x, y, z) && + // m_autil.is_numeral(y, r) && r.is_zero() && + // m_autil.is_numeral(z, r) && r >= 0) { + // expr* len_x = str().mk_length(x); + // result = m().mk_ite(m_autil.mk_le(len_x, z), len_x, z); + // // expr* zero = m_autil.mk_int(0); + // // result = m().mk_ite(m_autil.mk_le(z, zero), zero, result); + // return BR_REWRITE_FULL; + // } #if 0 expr* s = nullptr, *offset = nullptr, *length = nullptr; if (str().is_extract(a, s, offset, length)) { @@ -1213,6 +1225,11 @@ br_status seq_rewriter::mk_seq_extract(expr* a, expr* b, expr* c, expr_ref& resu constantPos &= pos.is_unsigned(); constantLen &= len.is_unsigned(); + if (constantPos && constantLen && len == 1) { + result = str().mk_at(a, b); + return BR_REWRITE1; + } + if (constantPos && constantLen && constantBase) { unsigned _pos = pos.get_unsigned(); unsigned _len = len.get_unsigned(); @@ -1249,6 +1266,15 @@ br_status seq_rewriter::mk_seq_extract(expr* a, expr* b, expr* c, expr_ref& resu result = str().mk_substr(a1, m_autil.mk_add(b1, b), m_autil.mk_sub(c1, b)); return BR_REWRITE3; } + rational r1, r2; + if (str().is_extract(a, a1, b1, c1) && + m_autil.is_numeral(b1, r1) && r1.is_unsigned() && + m_autil.is_numeral(c1, r2) && r2.is_unsigned() && + constantPos && constantLen && + r1 == 0 && r2 >= pos + len) { + result = str().mk_substr(a1, b, c); + return BR_REWRITE1; + } if (str().is_extract(a, a1, b1, c1) && is_prefix(a1, b1, c1) && is_prefix(a, b, c)) { @@ -1547,9 +1573,17 @@ bool seq_rewriter::reduce_by_char(expr_ref& r, expr* ch, unsigned depth) { */ br_status seq_rewriter::mk_seq_at(expr* a, expr* b, expr_ref& result) { zstring c; - rational r; + rational r, offset_r, len_r; + expr* offset, *a1, *len; expr_ref_vector lens(m()); sort* sort_a = a->get_sort(); + if (str().is_extract(a, a1, offset, len) && + m_autil.is_numeral(offset, offset_r) && offset_r.is_zero() && + m_autil.is_numeral(len, len_r) && m_autil.is_numeral(b, r) && + r < len_r) { + result = str().mk_at(a1, b); + return BR_REWRITE1; + } if (!get_lengths(b, lens, r)) { return BR_FAILED; } @@ -1724,6 +1758,10 @@ br_status seq_rewriter::mk_seq_last_index(expr* a, expr* b, expr_ref& result) { result = m_autil.mk_numeral(rational(idx), true); return BR_DONE; } + if (a == b) { + result = m_autil.mk_int(0); + return BR_DONE; + } return BR_FAILED; } @@ -3456,7 +3494,7 @@ expr_ref seq_rewriter::mk_antimirov_deriv_union(expr* d1, expr* d2) { // // restrict(d, false) = [] // -// it is already assumed that the restriction takes place witin a branch +// it is already assumed that the restriction takes place within a branch // so the condition is not added explicitly but propagated down in order to eliminate // infeasible cases expr_ref seq_rewriter::mk_antimirov_deriv_restrict(expr* e, expr* d, expr* cond) { @@ -3695,7 +3733,7 @@ expr_ref seq_rewriter::mk_regex_concat(expr* r, expr* s) { result = re().mk_plus(re().mk_full_char(ele_sort)); else if (re().is_concat(r, r1, r2)) // create the resulting concatenation in right-associative form except for the following case - // TODO: maintain the followig invariant for A ++ B{m,n} + C + // TODO: maintain the following invariant for A ++ B{m,n} + C // concat(concat(A, B{m,n}), C) (if A != () and C != ()) // concat(B{m,n}, C) (if A == () and C != ()) // where A, B, C are regexes @@ -3703,11 +3741,11 @@ expr_ref seq_rewriter::mk_regex_concat(expr* r, expr* s) { // In other words, do not make A ++ B{m,n} into right-assoc form, but keep B{m,n} at the top // This will help to identify this situation in the merge routine: // concat(concat(A, B{0,m}), C) | concat(concat(A, B{0,n}), C) - // simplies to + // simplifies to // concat(concat(A, B{0,max(m,n)}), C) // analogously: // concat(concat(A, B{0,m}), C) & concat(concat(A, B{0,n}), C) - // simplies to + // simplifies to // concat(concat(A, B{0,min(m,n)}), C) result = mk_regex_concat(r1, mk_regex_concat(r2, s)); else { @@ -3828,12 +3866,12 @@ bool seq_rewriter::pred_implies(expr* a, expr* b) { Utility function to decide if two BDDs (nested if-then-else terms) have exactly the same structure and conditions. */ -bool seq_rewriter::ite_bdds_compatabile(expr* a, expr* b) { +bool seq_rewriter::ite_bdds_compatible(expr* a, expr* b) { expr* ca = nullptr, *a1 = nullptr, *a2 = nullptr; expr* cb = nullptr, *b1 = nullptr, *b2 = nullptr; if (m().is_ite(a, ca, a1, a2) && m().is_ite(b, cb, b1, b2)) { - return (ca == cb) && ite_bdds_compatabile(a1, b1) - && ite_bdds_compatabile(a2, b2); + return (ca == cb) && ite_bdds_compatible(a1, b1) + && ite_bdds_compatible(a2, b2); } else if (m().is_ite(a) || m().is_ite(b)) { return false; @@ -3893,7 +3931,7 @@ expr_ref seq_rewriter::mk_der_op_rec(decl_kind k, expr* a, expr* b) { // sophisticated: in an antimirov union of n terms, we really // want to check if any pair of them is compatible. else if (m().is_ite(a) && m().is_ite(b) && - !ite_bdds_compatabile(a, b)) { + !ite_bdds_compatible(a, b)) { k = _OP_RE_ANTIMIROV_UNION; } #endif @@ -4247,7 +4285,7 @@ expr_ref seq_rewriter::mk_derivative_rec(expr* ele, expr* r) { } else if (re().is_reverse(r, r1)) { if (re().is_to_re(r1, r2)) { - // First try to exctract hd and tl such that r = hd ++ tl and |tl|=1 + // First try to extract hd and tl such that r = hd ++ tl and |tl|=1 expr_ref hd(m()), tl(m()); if (get_head_tail_reversed(r2, hd, tl)) { // Use mk_der_cond to normalize @@ -5050,12 +5088,14 @@ br_status seq_rewriter::mk_re_star(expr* a, expr_ref& result) { * (re.range c_1 c_n) */ br_status seq_rewriter::mk_re_range(expr* lo, expr* hi, expr_ref& result) { - zstring s; + zstring slo, shi; unsigned len = 0; bool is_empty = false; - if (str().is_string(lo, s) && s.length() != 1) + if (str().is_string(lo, slo) && slo.length() != 1) is_empty = true; - if (str().is_string(hi, s) && s.length() != 1) + if (str().is_string(hi, shi) && shi.length() != 1) + is_empty = true; + if (slo.length() == 1 && shi.length() == 1 && slo[0] > shi[0]) is_empty = true; len = min_length(lo).second; if (len > 1) @@ -5272,7 +5312,17 @@ br_status seq_rewriter::reduce_re_is_empty(expr* r, expr_ref& result) { else if (re().is_range(r, r1, r2) && str().is_string(r1, s1) && str().is_string(r2, s2) && s1.length() == 1 && s2.length() == 1) { - result = m().mk_bool_val(s1[0] <= s2[0]); + result = m().mk_bool_val(s1[0] > s2[0]); + return BR_DONE; + } + else if (re().is_range(r, r1, r2) && + str().is_string(r1, s1) && s1.length() != 1) { + result = m().mk_true(); + return BR_DONE; + } + else if (re().is_range(r, r1, r2) && + str().is_string(r2, s2) && s2.length() != 1) { + result = m().mk_true(); return BR_DONE; } else if ((re().is_loop(r, r1, lo) || @@ -5333,6 +5383,7 @@ br_status seq_rewriter::mk_le_core(expr * l, expr * r, expr_ref & result) { } br_status seq_rewriter::mk_eq_core(expr * l, expr * r, expr_ref & result) { + TRACE("seq", tout << mk_pp(l, m()) << " == " << mk_pp(r, m()) << "\n"); expr_ref_vector res(m()); expr_ref_pair_vector new_eqs(m()); if (m_util.is_re(l)) { diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h index 92a6a17faf6..af4756576ae 100644 --- a/src/ast/rewriter/seq_rewriter.h +++ b/src/ast/rewriter/seq_rewriter.h @@ -201,7 +201,7 @@ class seq_rewriter { expr_ref mk_der_compl(expr* a); expr_ref mk_der_cond(expr* cond, expr* ele, sort* seq_sort); expr_ref mk_der_antimirov_union(expr* r1, expr* r2); - bool ite_bdds_compatabile(expr* a, expr* b); + bool ite_bdds_compatible(expr* a, expr* b); /* if r has the form deriv(en..deriv(e1,to_re(s))..) returns 's = [e1..en]' else returns '() in r'*/ expr_ref is_nullable_symbolic_regex(expr* r, sort* seq_sort); #ifdef Z3DEBUG diff --git a/src/ast/rewriter/seq_skolem.h b/src/ast/rewriter/seq_skolem.h index 4b828abf644..4e327f0fa43 100644 --- a/src/ast/rewriter/seq_skolem.h +++ b/src/ast/rewriter/seq_skolem.h @@ -8,7 +8,7 @@ Module Name: Abstract: Skolem function support for sequences. - Skolem functions are auxiliary funcions useful for axiomatizing sequence + Skolem functions are auxiliary functions useful for axiomatizing sequence operations. Author: diff --git a/src/ast/rewriter/th_rewriter.cpp b/src/ast/rewriter/th_rewriter.cpp index 9278ae5aed4..ee8b79be53e 100644 --- a/src/ast/rewriter/th_rewriter.cpp +++ b/src/ast/rewriter/th_rewriter.cpp @@ -39,6 +39,7 @@ Module Name: #include "ast/ast_util.h" #include "ast/well_sorted.h" #include "ast/for_each_expr.h" +#include "ast/array_peq.h" namespace { struct th_rewriter_cfg : public default_rewriter_cfg { @@ -74,6 +75,8 @@ struct th_rewriter_cfg : public default_rewriter_cfg { bool m_push_ite_bv = true; bool m_ignore_patterns_on_ground_qbody = true; bool m_rewrite_patterns = true; + bool m_enable_der = true; + bool m_nested_der = false; ast_manager & m() const { return m_b_rw.m(); } @@ -89,6 +92,8 @@ struct th_rewriter_cfg : public default_rewriter_cfg { m_push_ite_bv = p.push_ite_bv(); m_ignore_patterns_on_ground_qbody = p.ignore_patterns_on_ground_qbody(); m_rewrite_patterns = p.rewrite_patterns(); + m_enable_der = p.enable_der(); + m_nested_der = _p.get_bool("nested_der", false); } void updt_params(params_ref const & p) { @@ -642,6 +647,10 @@ struct th_rewriter_cfg : public default_rewriter_cfg { else st = pull_ite(result); } + if (st == BR_FAILED && f->get_family_id() == null_family_id && is_partial_eq(f)) { + st = m_ar_rw.mk_app_core(f, num, args, result); + } + CTRACE("th_rewriter_step", st != BR_FAILED, tout << f->get_name() << "\n"; for (unsigned i = 0; i < num; i++) tout << mk_ismt2_pp(args[i], m()) << "\n"; @@ -827,16 +836,20 @@ struct th_rewriter_cfg : public default_rewriter_cfg { expr_ref r(m()); bool der_change = false; - if (is_quantifier(result) && to_quantifier(result)->get_num_patterns() == 0) { + if (m_enable_der && is_quantifier(result) && to_quantifier(result)->get_num_patterns() == 0) { m_der(to_quantifier(result), r, p2); der_change = result.get() != r.get(); if (m().proofs_enabled() && der_change) - result_pr = m().mk_transitivity(result_pr, p2); + result_pr = m().mk_transitivity(result_pr, p2); + result = r; } - if (der_change) { + if (der_change && !m_nested_der) { th_rewriter rw(m()); + params_ref p; + p.set_bool("nested_der", true); + rw.updt_params(p); rw(result, r, p2); if (m().proofs_enabled() && result.get() != r.get()) result_pr = m().mk_transitivity(result_pr, p2); diff --git a/src/ast/simplifiers/bound_simplifier.cpp b/src/ast/simplifiers/bound_simplifier.cpp index f5c986425a6..1a5d4c10107 100644 --- a/src/ast/simplifiers/bound_simplifier.cpp +++ b/src/ast/simplifiers/bound_simplifier.cpp @@ -58,9 +58,9 @@ struct bound_simplifier::rw : public rewriter_tpl { br_status bound_simplifier::reduce_app(func_decl* f, unsigned num_args, expr* const* args, expr_ref& result, proof_ref& pr) { rational N, hi, lo; if (a.is_mod(f) && num_args == 2 && a.is_numeral(args[1], N)) { - expr* x = args[0]; auto& im = m_interval; scoped_dep_interval i(im); + expr* x = args[0]; get_bounds(x, i); if (im.upper_is_inf(i) || im.lower_is_inf(i)) return BR_FAILED; @@ -83,7 +83,55 @@ br_status bound_simplifier::reduce_app(func_decl* f, unsigned num_args, expr* co } IF_VERBOSE(2, verbose_stream() << "potentially missed simplification: " << mk_pp(x, m) << " " << lo << " " << hi << " not reduced\n"); } - return BR_FAILED; + + expr_ref_buffer new_args(m); + expr_ref new_arg(m); + bool change = false; + for (unsigned i = 0; i < num_args; ++i) { + expr* arg = args[i]; + change = reduce_arg(arg, new_arg) || change; + new_args.push_back(new_arg); + } + if (!change) + return BR_FAILED; + + result = m.mk_app(f, num_args, new_args.data()); + + return BR_DONE; +} + +bool bound_simplifier::reduce_arg(expr* arg, expr_ref& result) { + result = arg; + expr* x, *y; + rational N, lo, hi; + bool strict; + if ((a.is_le(arg, x, y) && a.is_numeral(y, N)) || + (a.is_ge(arg, y, x) && a.is_numeral(y, N))) { + + if (has_upper(x, hi, strict) && !strict && N >= hi) { + result = m.mk_true(); + return true; + } + if (has_lower(x, lo, strict) && !strict && N < lo) { + result = m.mk_false(); + return true; + } + return false; + } + + if ((a.is_le(arg, y, x) && a.is_numeral(y, N)) || + (a.is_ge(arg, x, y) && a.is_numeral(y, N))) { + if (has_lower(x, lo, strict) && !strict && N <= lo) { + result = m.mk_true(); + return true; + } + if (has_upper(x, hi, strict) && !strict && N > hi) { + result = m.mk_false(); + return true; + } + return false; + } + return false; } void bound_simplifier::reduce() { diff --git a/src/ast/simplifiers/bound_simplifier.h b/src/ast/simplifiers/bound_simplifier.h index 7950f418b84..9bd4b19088c 100644 --- a/src/ast/simplifiers/bound_simplifier.h +++ b/src/ast/simplifiers/bound_simplifier.h @@ -37,6 +37,7 @@ class bound_simplifier : public dependent_expr_simplifier { unsynch_mpq_manager nm; small_object_allocator m_alloc; bound_propagator bp; + u_dependency_manager m_dep_manager; dep_intervals m_interval; ptr_vector m_var2expr; unsigned_vector m_expr2var; @@ -77,8 +78,12 @@ class bound_simplifier : public dependent_expr_simplifier { return v; } + bool reduce_arg(expr* arg, expr_ref& result); + br_status reduce_app(func_decl* f, unsigned num_args, expr* const* args, expr_ref& result, proof_ref& pr); + + void assert_lower(expr* x, rational const& n, bool strict); void assert_upper(expr* x, rational const& n, bool strict); @@ -101,7 +106,7 @@ class bound_simplifier : public dependent_expr_simplifier { a(m), m_rewriter(m), bp(nm, m_alloc, p), - m_interval(m.limit()), + m_interval(m_dep_manager, m.limit()), m_trail(m), m_num_buffer(nm) { updt_params(p); diff --git a/src/ast/simplifiers/dependent_expr_state.h b/src/ast/simplifiers/dependent_expr_state.h index b4fe4e9d45d..e187f19c61b 100644 --- a/src/ast/simplifiers/dependent_expr_state.h +++ b/src/ast/simplifiers/dependent_expr_state.h @@ -74,6 +74,8 @@ class dependent_expr_state { virtual bool inconsistent() = 0; virtual model_reconstruction_trail& model_trail() = 0; virtual void flatten_suffix() {} + virtual bool updated() = 0; + virtual void reset_updated() = 0; trail_stack m_trail; void push() { @@ -103,12 +105,15 @@ class dependent_expr_state { class default_dependent_expr_state : public dependent_expr_state { public: default_dependent_expr_state(ast_manager& m): dependent_expr_state(m) {} - virtual unsigned qtail() const { return 0; } - virtual dependent_expr const& operator[](unsigned i) { throw default_exception("unexpected access"); } - virtual void update(unsigned i, dependent_expr const& j) { throw default_exception("unexpected update"); } - virtual void add(dependent_expr const& j) { throw default_exception("unexpected addition"); } - virtual bool inconsistent() { return false; } - virtual model_reconstruction_trail& model_trail() { throw default_exception("unexpected access to model reconstruction"); } + unsigned qtail() const override { return 0; } + dependent_expr const& operator[](unsigned i) override { throw default_exception("unexpected access"); } + void update(unsigned i, dependent_expr const& j) override { throw default_exception("unexpected update"); } + void add(dependent_expr const& j) override { throw default_exception("unexpected addition"); } + bool inconsistent() override { return false; } + model_reconstruction_trail& model_trail() override { throw default_exception("unexpected access to model reconstruction"); } + bool updated() override { return false; } + void reset_updated() override {} + }; inline std::ostream& operator<<(std::ostream& out, dependent_expr_state& st) { @@ -147,7 +152,7 @@ class dependent_expr_simplifier { index_set indices() { return index_set(*this); } proof* mp(proof* a, proof* b) { return (a && b) ? m.mk_modus_ponens(a, b) : nullptr; } - + proof* tr(proof* a, proof* b) { return m.mk_transitivity(a, b); } public: dependent_expr_simplifier(ast_manager& m, dependent_expr_state& s) : m(m), m_fmls(s), m_trail(s.m_trail) {} virtual ~dependent_expr_simplifier() {} diff --git a/src/ast/simplifiers/elim_unconstrained.cpp b/src/ast/simplifiers/elim_unconstrained.cpp index 231858897c3..818800d9947 100644 --- a/src/ast/simplifiers/elim_unconstrained.cpp +++ b/src/ast/simplifiers/elim_unconstrained.cpp @@ -66,7 +66,6 @@ bool elim_unconstrained::is_var_lt(int v1, int v2) const { } void elim_unconstrained::eliminate() { - while (!m_heap.empty()) { expr_ref r(m); int v = m_heap.erase_min(); @@ -86,7 +85,12 @@ void elim_unconstrained::eliminate() { n.m_refcount = 0; continue; } + if (m_heap.contains(root(e))) { + IF_VERBOSE(11, verbose_stream() << "already in heap " << mk_bounded_pp(e, m) << "\n"); + continue; + } app* t = to_app(e); + TRACE("elim_unconstrained", tout << "eliminating " << mk_pp(t, m) << "\n";); unsigned sz = m_args.size(); for (expr* arg : *to_app(t)) m_args.push_back(reconstruct_term(get_node(arg))); @@ -99,14 +103,17 @@ void elim_unconstrained::eliminate() { proof * pr = m.mk_apply_def(s, r, pr1); m_trail.push_back(pr); } + expr_ref rr(m.mk_app(t->get_decl(), t->get_num_args(), m_args.data() + sz), m); n.m_refcount = 0; m_args.shrink(sz); if (!inverted) { IF_VERBOSE(11, verbose_stream() << "not inverted " << mk_bounded_pp(e, m) << "\n"); continue; } + + IF_VERBOSE(11, verbose_stream() << "replace " << mk_pp(t, m) << " / " << rr << " -> " << r << "\n"); - TRACE("elim_unconstrained", tout << mk_pp(t, m) << " -> " << r << "\n"); + TRACE("elim_unconstrained", tout << mk_pp(t, m) << " / " << rr << " -> " << r << "\n"); SASSERT(r->get_sort() == t->get_sort()); m_stats.m_num_eliminated++; m_trail.push_back(r); @@ -119,7 +126,8 @@ void elim_unconstrained::eliminate() { get_node(e).m_term = r; get_node(e).m_proof = pr; get_node(e).m_refcount++; - IF_VERBOSE(11, verbose_stream() << mk_bounded_pp(e, m) << "\n"); + get_node(e).m_dirty = false; + IF_VERBOSE(11, verbose_stream() << "set " << &get_node(e) << " " << root(e) << " " << mk_bounded_pp(e, m) << " := " << mk_bounded_pp(r, m) << "\n"); SASSERT(!m_heap.contains(root(e))); if (is_uninterp_const(r)) m_heap.insert(root(e)); @@ -263,12 +271,18 @@ void elim_unconstrained::gc(expr* t) { while (!todo.empty()) { t = todo.back(); todo.pop_back(); + node& n = get_node(t); if (n.m_refcount == 0) continue; + if (n.m_term && !is_node(n.m_term)) + continue; + dec_ref(t); if (n.m_refcount != 0) continue; + if (n.m_term) + t = n.m_term; if (is_app(t)) { for (expr* arg : *to_app(t)) todo.push_back(arg); @@ -283,13 +297,22 @@ expr_ref elim_unconstrained::reconstruct_term(node& n0) { expr* t = n0.m_term; if (!n0.m_dirty) return expr_ref(t, m); + if (!is_node(t)) + return expr_ref(t, m); ptr_vector todo; todo.push_back(t); while (!todo.empty()) { t = todo.back(); + if (!is_node(t)) { + UNREACHABLE(); + } node& n = get_node(t); unsigned sz0 = todo.size(); - if (is_app(t)) { + if (is_app(t)) { + if (n.m_term != t) { + todo.pop_back(); + continue; + } for (expr* arg : *to_app(t)) if (get_node(arg).m_dirty || !get_node(arg).m_term) todo.push_back(arg); @@ -300,7 +323,6 @@ expr_ref elim_unconstrained::reconstruct_term(node& n0) { for (expr* arg : *to_app(t)) m_args.push_back(get_node(arg).m_term); n.m_term = m.mk_app(to_app(t)->get_decl(), to_app(t)->get_num_args(), m_args.data() + sz); - m_args.shrink(sz); } else if (is_quantifier(t)) { @@ -418,6 +440,6 @@ void elim_unconstrained::reduce() { vector old_fmls; assert_normalized(old_fmls); update_model_trail(*mc, old_fmls); + mc->reset(); } - } diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index 3ede7024ee3..280b5e6bfd8 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -307,7 +307,7 @@ namespace euf { } }; SASSERT(e); - if (num_scopes() > 0) + if (num_scopes() > 0 && m_canonical.size() > n->get_id()) m_trail.push(vtrail(m_canonical, n->get_id())); m_canonical.setx(n->get_id(), e); m_epochs.setx(n->get_id(), m_epoch, 0); diff --git a/src/ast/simplifiers/model_reconstruction_trail.cpp b/src/ast/simplifiers/model_reconstruction_trail.cpp index 95f73fd7a76..47ebea52526 100644 --- a/src/ast/simplifiers/model_reconstruction_trail.cpp +++ b/src/ast/simplifiers/model_reconstruction_trail.cpp @@ -182,11 +182,11 @@ std::ostream& model_reconstruction_trail::display(std::ostream& out) const { out << "hide " << t->m_decl->get_name() << "\n"; else if (t->is_def()) { for (auto const& [f, def, dep] : t->m_defs) - out << f->get_name() << " <- " << mk_pp(def, m) << "\n"; + out << "def: " << f->get_name() << " <- " << mk_pp(def, m) << "\n"; } else { for (auto const& [v, def] : t->m_subst->sub()) - out << mk_pp(v, m) << " <- " << mk_pp(def, m) << "\n"; + out << "sub: " << mk_pp(v, m) << " -> " << mk_pp(def, m) << "\n"; } for (auto const& d : t->m_removed) out << "rm: " << d << "\n"; diff --git a/src/ast/simplifiers/model_reconstruction_trail.h b/src/ast/simplifiers/model_reconstruction_trail.h index 79a38401ab9..c2d8b00019a 100644 --- a/src/ast/simplifiers/model_reconstruction_trail.h +++ b/src/ast/simplifiers/model_reconstruction_trail.h @@ -90,7 +90,7 @@ class model_reconstruction_trail { struct undo_model_var : public trail { model_reconstruction_trail& s; undo_model_var(model_reconstruction_trail& s) : s(s) {} - virtual void undo() { + void undo() override { s.m_model_vars.mark(s.m_model_vars_trail.back(), false); s.m_model_vars_trail.pop_back(); } diff --git a/src/ast/simplifiers/solve_context_eqs.cpp b/src/ast/simplifiers/solve_context_eqs.cpp index b56802caf8d..c58786f0ecc 100644 --- a/src/ast/simplifiers/solve_context_eqs.cpp +++ b/src/ast/simplifiers/solve_context_eqs.cpp @@ -243,8 +243,8 @@ namespace euf { void solve_context_eqs::collect_nested_equalities(dependent_expr const& df, expr_mark& visited, dep_eq_vector& eqs) { - svector> todo; - todo.push_back({ false, 0, df.fml()}); + svector> todo; + todo.push_back({ false, 0, df.fml(), 0}); // even depth is conjunctive context, odd is disjunctive // when alternating between conjunctive and disjunctive context, increment depth. @@ -255,37 +255,85 @@ namespace euf { return (0 == depth % 2) ? depth : depth + 1; }; - while (!todo.empty()) { - auto [s, depth, f] = todo.back(); - todo.pop_back(); + for (unsigned i = 0; i < todo.size(); ++i) { + auto [s, depth, f, p] = todo[i]; if (visited.is_marked(f)) continue; visited.mark(f, true); if (s && m.is_and(f)) { for (auto* arg : *to_app(f)) - todo.push_back({ s, inc_or(depth), arg }); + todo.push_back({ s, inc_or(depth), arg, i }); } else if (!s && m.is_or(f)) { for (auto* arg : *to_app(f)) - todo.push_back({ s, inc_or(depth), arg }); + todo.push_back({ s, inc_or(depth), arg, i }); } if (!s && m.is_and(f)) { for (auto* arg : *to_app(f)) - todo.push_back({ s, inc_and(depth), arg }); + todo.push_back({ s, inc_and(depth), arg, i }); } else if (s && m.is_or(f)) { for (auto* arg : *to_app(f)) - todo.push_back({ s, inc_and(depth), arg }); + todo.push_back({ s, inc_and(depth), arg, i }); } else if (m.is_not(f, f)) - todo.push_back({ !s, depth, f }); + todo.push_back({ !s, depth, f, i }); else if (!s && 1 <= depth) { + unsigned sz = eqs.size(); for (extract_eq* ex : m_solve_eqs.m_extract_plugins) { ex->set_allow_booleans(false); ex->get_eqs(dependent_expr(m, f, nullptr, df.dep()), eqs); ex->set_allow_booleans(true); } + // prune eqs for solutions that are not safe in df.fml() + for (; sz < eqs.size(); ++sz) { + if (!is_safe_var(eqs[sz].var, i, df.fml(), todo)) { + eqs[sz] = eqs.back(); + --sz; + eqs.pop_back(); + } + } } } } + + bool solve_context_eqs::is_safe_var(expr* x, unsigned i, expr* f, svector> const& todo) { + m_contains_v.reset(); + m_todo.push_back(f); + mark_occurs(m_todo, x, m_contains_v); + SASSERT(m_todo.empty()); + + auto is_parent = [&](unsigned p, unsigned i) { + while (p != i && i != 0) { + auto [_s,_depth, _f, _p] = todo[i]; + i = _p; + } + return p == i; + }; + + // retrieve oldest parent of i within the same alternation of and + unsigned pi = i; + auto [_s, _depth, _f, _p] = todo[i]; + while (pi != 0) { + auto [s, depth, f, p] = todo[pi]; + if (depth != _depth) + break; + pi = p; + } + + // determine if j and j have common conjunctive parent + // for every j in todo. + for (unsigned j = 0; j < todo.size(); ++j) { + auto [s, depth, f, p] = todo[j]; + if (i == j || !m_contains_v.is_marked(f)) + continue; + if (is_parent(j, i)) // j is a parent if i + continue; + if (is_parent(pi, j)) // pi is a parent of j + continue; + return false; + } + return true; + } + } diff --git a/src/ast/simplifiers/solve_context_eqs.h b/src/ast/simplifiers/solve_context_eqs.h index 8332d3a73c8..a11a1043b0d 100644 --- a/src/ast/simplifiers/solve_context_eqs.h +++ b/src/ast/simplifiers/solve_context_eqs.h @@ -45,7 +45,9 @@ namespace euf { bool contains_conjunctively(expr* f, bool sign, expr* e, signed_expressions& conjuncts); bool is_conjunction(bool sign, expr* f) const; - void collect_nested_equalities(dependent_expr const& f, expr_mark& visited, dep_eq_vector& eqs); + void collect_nested_equalities(dependent_expr const& f, expr_mark& visited, dep_eq_vector& eqs); + + bool is_safe_var(expr* x, unsigned i, expr* f, svector> const& todo); public: diff --git a/src/ast/simplifiers/solve_eqs.cpp b/src/ast/simplifiers/solve_eqs.cpp index fbc6fbb02db..c97cab7c9ed 100644 --- a/src/ast/simplifiers/solve_eqs.cpp +++ b/src/ast/simplifiers/solve_eqs.cpp @@ -212,12 +212,13 @@ namespace euf { auto [f, p, d] = m_fmls[i](); auto [new_f, new_dep] = rp->replace_with_dep(f); proof_ref new_pr(m); - m_rewriter(new_f, new_f, new_pr); - if (new_f == f) + expr_ref tmp(m); + m_rewriter(new_f, tmp, new_pr); + if (tmp == f) continue; new_dep = m.mk_join(d, new_dep); old_fmls.push_back(m_fmls[i]); - m_fmls.update(i, dependent_expr(m, new_f, mp(p, new_pr), new_dep)); + m_fmls.update(i, dependent_expr(m, tmp, mp(p, new_pr), new_dep)); } } diff --git a/src/ast/simplifiers/then_simplifier.h b/src/ast/simplifiers/then_simplifier.h index 6ee8b94129a..e5a7ca104f6 100644 --- a/src/ast/simplifiers/then_simplifier.h +++ b/src/ast/simplifiers/then_simplifier.h @@ -51,6 +51,10 @@ class then_simplifier : public dependent_expr_simplifier { } }; +protected: + + bool m_bail_on_no_change = false; + public: then_simplifier(ast_manager& m, params_ref const& p, dependent_expr_state& fmls): @@ -72,9 +76,17 @@ class then_simplifier : public dependent_expr_simplifier { break; s->reset_statistics(); collect_stats _cs(*s); - s->reduce(); - m_fmls.flatten_suffix(); + m_fmls.reset_updated(); + try { + s->reduce(); + m_fmls.flatten_suffix(); + } + catch (rewriter_exception &) { + break; + } TRACE("simplifier", tout << s->name() << "\n" << m_fmls); + if (m_bail_on_no_change && !m_fmls.updated()) + break; } } @@ -108,3 +120,14 @@ class then_simplifier : public dependent_expr_simplifier { s->pop(n); } }; + +class if_change_simplifier : public then_simplifier { +public: + if_change_simplifier(ast_manager& m, params_ref const& p, dependent_expr_state& fmls): + then_simplifier(m, p, fmls) { + m_bail_on_no_change = true; + } + + char const* name() const override { return "if-change-then"; } + +}; diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt new file mode 100644 index 00000000000..24eaec4dcac --- /dev/null +++ b/src/ast/sls/CMakeLists.txt @@ -0,0 +1,14 @@ +z3_add_component(ast_sls + SOURCES + bvsls_opt_engine.cpp + bv_sls.cpp + bv_sls_eval.cpp + bv_sls_fixed.cpp + bv_sls_terms.cpp + sls_engine.cpp + sls_valuation.cpp + COMPONENT_DEPENDENCIES + ast + converters + normal_forms +) diff --git a/src/ast/sls/bv_sls.cpp b/src/ast/sls/bv_sls.cpp new file mode 100644 index 00000000000..f80a362ba32 --- /dev/null +++ b/src/ast/sls/bv_sls.cpp @@ -0,0 +1,295 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls.cpp + +Abstract: + + A Stochastic Local Search (SLS) engine + Uses invertibility conditions, + interval annotations + don't care annotations + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/sls/bv_sls.h" +#include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" +#include "params/sls_params.hpp" + +namespace bv { + + sls::sls(ast_manager& m): + m(m), + bv(m), + m_terms(m), + m_eval(m) + {} + + void sls::init() { + m_terms.init(); + } + + void sls::init_eval(std::function& eval) { + m_eval.init_eval(m_terms.assertions(), eval); + m_eval.tighten_range(m_terms.assertions()); + init_repair(); + } + + void sls::init_repair() { + m_repair_down = UINT_MAX; + m_repair_up.reset(); + m_repair_roots.reset(); + for (auto* e : m_terms.assertions()) { + if (!m_eval.bval0(e)) { + m_eval.set(e, true); + m_repair_roots.insert(e->get_id()); + } + } + for (auto* t : m_terms.terms()) { + if (t && !re_eval_is_correct(t)) + m_repair_roots.insert(t->get_id()); + } + } + + void sls::init_repair_goal(app* t) { + if (m.is_bool(t)) + m_eval.set(t, m_eval.bval1(t)); + else if (bv.is_bv(t)) { + auto& v = m_eval.wval(t); + v.bits().copy_to(v.nw, v.eval); + } + } + + void sls::reinit_eval() { + std::function eval = [&](expr* e, unsigned i) { + auto should_keep = [&]() { + return m_rand() % 100 <= 92; + }; + if (m.is_bool(e)) { + if (m_eval.is_fixed0(e) || should_keep()) + return m_eval.bval0(e); + } + else if (bv.is_bv(e)) { + auto& w = m_eval.wval(e); + if (w.fixed.get(i) || should_keep()) + return w.get_bit(i); + } + return m_rand() % 2 == 0; + }; + m_eval.init_eval(m_terms.assertions(), eval); + init_repair(); + } + + std::pair sls::next_to_repair() { + app* e = nullptr; + if (m_repair_down != UINT_MAX) { + e = m_terms.term(m_repair_down); + m_repair_down = UINT_MAX; + return { true, e }; + } + + if (!m_repair_up.empty()) { + unsigned index = m_repair_up.elem_at(m_rand(m_repair_up.size())); + m_repair_up.remove(index); + e = m_terms.term(index); + return { false, e }; + } + + while (!m_repair_roots.empty()) { + unsigned index = m_repair_roots.elem_at(m_rand(m_repair_roots.size())); + e = m_terms.term(index); + if (m_terms.is_assertion(e) && !m_eval.bval1(e)) { + SASSERT(m_eval.bval0(e)); + return { true, e }; + } + if (!re_eval_is_correct(e)) { + init_repair_goal(e); + return { true, e }; + } + m_repair_roots.remove(index); + } + + return { false, nullptr }; + } + + lbool sls::search() { + // init and init_eval were invoked + unsigned n = 0; + for (; n++ < m_config.m_max_repairs && m.inc(); ) { + auto [down, e] = next_to_repair(); + if (!e) + return l_true; + + + trace_repair(down, e); + + ++m_stats.m_moves; + + if (down) + try_repair_down(e); + else + try_repair_up(e); + } + return l_undef; + } + + + lbool sls::operator()() { + lbool res = l_undef; + m_stats.reset(); + m_stats.m_restarts = 0; + do { + res = search(); + if (res != l_undef) + break; + trace(); + reinit_eval(); + } + while (m.inc() && m_stats.m_restarts++ < m_config.m_max_restarts); + + return res; + } + + void sls::try_repair_down(app* e) { + + unsigned n = e->get_num_args(); + if (n == 0) { + if (m.is_bool(e)) + m_eval.set(e, m_eval.bval1(e)); + else + VERIFY(m_eval.wval(e).commit_eval()); + + for (auto p : m_terms.parents(e)) + m_repair_up.insert(p->get_id()); + return; + } + + unsigned s = m_rand(n); + for (unsigned i = 0; i < n; ++i) { + auto j = (i + s) % n; + if (m_eval.try_repair(e, j)) { + set_repair_down(e->get_arg(j)); + return; + } + } + // search a new root / random walk to repair + } + + void sls::try_repair_up(app* e) { + + if (m_terms.is_assertion(e) || !m_eval.repair_up(e)) + m_repair_roots.insert(e->get_id()); + else { + if (!eval_is_correct(e)) { + verbose_stream() << "incorrect eval #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n"; + } + SASSERT(eval_is_correct(e)); + for (auto p : m_terms.parents(e)) + m_repair_up.insert(p->get_id()); + } + } + + bool sls::eval_is_correct(app* e) { + if (!m_eval.can_eval1(e)) + return false; + if (m.is_bool(e)) + return m_eval.bval0(e) == m_eval.bval1(e); + if (bv.is_bv(e)) { + auto const& v = m_eval.wval(e); + return v.eval == v.bits(); + } + UNREACHABLE(); + return false; + } + + + bool sls::re_eval_is_correct(app* e) { + if (!m_eval.can_eval1(e)) + return false; + if (m.is_bool(e)) + return m_eval.bval0(e) == m_eval.bval1(e); + if (bv.is_bv(e)) { + auto const& v = m_eval.eval(e); + return v.eval == v.bits(); + } + UNREACHABLE(); + return false; + } + + model_ref sls::get_model() { + model_ref mdl = alloc(model, m); + auto& terms = m_eval.sort_assertions(m_terms.assertions()); + for (expr* e : terms) { + if (!re_eval_is_correct(to_app(e))) { + verbose_stream() << "missed evaluation #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n"; + if (bv.is_bv(e)) { + auto const& v = m_eval.wval(e); + verbose_stream() << v << "\n" << v.eval << "\n"; + } + } + if (!is_uninterp_const(e)) + continue; + + auto f = to_app(e)->get_decl(); + if (m.is_bool(e)) + mdl->register_decl(f, m.mk_bool_val(m_eval.bval0(e))); + else if (bv.is_bv(e)) { + auto const& v = m_eval.wval(e); + rational n = v.get_value(); + mdl->register_decl(f, bv.mk_numeral(n, v.bw)); + } + } + terms.reset(); + return mdl; + } + + std::ostream& sls::display(std::ostream& out) { + auto& terms = m_eval.sort_assertions(m_terms.assertions()); + for (expr* e : terms) { + out << e->get_id() << ": " << mk_bounded_pp(e, m, 1) << " "; + if (m_eval.is_fixed0(e)) + out << "f "; + if (m_repair_up.contains(e->get_id())) + out << "u "; + if (m_repair_roots.contains(e->get_id())) + out << "r "; + if (bv.is_bv(e)) + out << m_eval.wval(e); + else if (m.is_bool(e)) + out << (m_eval.bval0(e)?"T":"F"); + out << "\n"; + } + terms.reset(); + return out; + } + + void sls::updt_params(params_ref const& _p) { + sls_params p(_p); + m_config.m_max_restarts = p.max_restarts(); + m_rand.set_seed(p.random_seed()); + } + + void sls::trace_repair(bool down, expr* e) { + IF_VERBOSE(20, + verbose_stream() << (down ? "d #" : "u #") + << e->get_id() << ": " + << mk_bounded_pp(e, m, 1) << " "; + if (bv.is_bv(e)) verbose_stream() << m_eval.wval(e) << " " << (m_eval.is_fixed0(e) ? "fixed " : " "); + if (m.is_bool(e)) verbose_stream() << m_eval.bval0(e) << " "; + verbose_stream() << "\n"); + } + + void sls::trace() { + IF_VERBOSE(2, verbose_stream() + << "(bvsls :restarts " << m_stats.m_restarts + << " :repair-up " << m_repair_up.size() + << " :repair-roots " << m_repair_roots.size() << ")\n"); + } +} diff --git a/src/ast/sls/bv_sls.h b/src/ast/sls/bv_sls.h new file mode 100644 index 00000000000..bbcd59aea95 --- /dev/null +++ b/src/ast/sls/bv_sls.h @@ -0,0 +1,110 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls.h + +Abstract: + + A Stochastic Local Search (SLS) engine + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "util/lbool.h" +#include "util/params.h" +#include "util/scoped_ptr_vector.h" +#include "util/uint_set.h" +#include "ast/ast.h" +#include "ast/sls/sls_stats.h" +#include "ast/sls/sls_powers.h" +#include "ast/sls/sls_valuation.h" +#include "ast/sls/bv_sls_terms.h" +#include "ast/sls/bv_sls_eval.h" +#include "ast/bv_decl_plugin.h" +#include "model/model.h" + +namespace bv { + + + class sls { + + struct config { + unsigned m_max_restarts = 1000; + unsigned m_max_repairs = 1000; + }; + + ast_manager& m; + bv_util bv; + sls_terms m_terms; + sls_eval m_eval; + sls_stats m_stats; + indexed_uint_set m_repair_up, m_repair_roots; + unsigned m_repair_down = UINT_MAX; + ptr_vector m_todo; + random_gen m_rand; + config m_config; + + std::pair next_to_repair(); + + bool eval_is_correct(app* e); + bool re_eval_is_correct(app* e); + void init_repair_goal(app* e); + void try_repair_down(app* e); + void try_repair_up(app* e); + void set_repair_down(expr* e) { m_repair_down = e->get_id(); } + + lbool search(); + void reinit_eval(); + void init_repair(); + void trace(); + void trace_repair(bool down, expr* e); + + public: + sls(ast_manager& m); + + /** + * Add constraints + */ + void assert_expr(expr* e) { m_terms.assert_expr(e); } + + /* + * Invoke init after all expressions are asserted. + * No other expressions can be asserted after init. + */ + void init(); + + /** + * Invoke init_eval to initialize, or re-initialize, values of + * uninterpreted constants. + */ + void init_eval(std::function& eval); + + /** + * Run (bounded) local search to find feasible assignments. + */ + lbool operator()(); + + void updt_params(params_ref const& p); + void collect_statistics(statistics & st) const { m_stats.collect_statistics(st); } + void reset_statistics() { m_stats.reset(); } + + sls_stats const& get_stats() const { return m_stats; } + + std::ostream& display(std::ostream& out); + + /** + * Retrieve valuation + */ + sls_valuation const& wval(expr* e) const { return m_eval.wval(e); } + + model_ref get_model(); + + void cancel() { m.limit().cancel(); } + }; +} diff --git a/src/ast/sls/bv_sls_eval.cpp b/src/ast/sls/bv_sls_eval.cpp new file mode 100644 index 00000000000..4b7bf954653 --- /dev/null +++ b/src/ast/sls/bv_sls_eval.cpp @@ -0,0 +1,1735 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls_eval.cpp + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" +#include "ast/sls/bv_sls.h" + +namespace bv { + + sls_eval::sls_eval(ast_manager& m): + m(m), + bv(m), + m_fix(*this) + {} + + void sls_eval::init_eval(expr_ref_vector const& es, std::function const& eval) { + sort_assertions(es); + for (expr* e : m_todo) { + if (!is_app(e)) + continue; + app* a = to_app(e); + if (bv.is_bv(e)) + add_bit_vector(a); + if (a->get_family_id() == basic_family_id) + init_eval_basic(a); + else if (a->get_family_id() == bv.get_family_id()) + init_eval_bv(a); + else if (is_uninterp(e)) { + if (bv.is_bv(e)) { + auto& v = wval(e); + for (unsigned i = 0; i < v.bw; ++i) + m_tmp.set(i, eval(e, i)); + v.set_repair(random_bool(), m_tmp); + } + else if (m.is_bool(e)) + m_eval.setx(e->get_id(), eval(e, 0), false); + } + else { + TRACE("sls", tout << "Unhandled expression " << mk_pp(e, m) << "\n"); + } + } + m_todo.reset(); + } + + /** + * Sort all sub-expressions by depth, smallest first. + */ + ptr_vector& sls_eval::sort_assertions(expr_ref_vector const& es) { + expr_fast_mark1 mark; + for (expr* e : es) { + if (!mark.is_marked(e)) { + mark.mark(e); + m_todo.push_back(e); + } + } + for (unsigned i = 0; i < m_todo.size(); ++i) { + auto e = m_todo[i]; + if (!is_app(e)) + continue; + for (expr* arg : *to_app(e)) { + if (!mark.is_marked(arg)) { + mark.mark(arg); + m_todo.push_back(arg); + } + } + } + std::stable_sort(m_todo.begin(), m_todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); + return m_todo; + } + + bool sls_eval::add_bit_vector(app* e) { + m_values.reserve(e->get_id() + 1); + if (m_values.get(e->get_id())) + return false; + auto v = alloc_valuation(e); + m_values.set(e->get_id(), v); + if (bv.is_sign_ext(e)) { + unsigned p = e->get_parameter(0).get_int(); + v->set_signed(p); + } + return true; + } + + sls_valuation* sls_eval::alloc_valuation(app* e) { + auto bit_width = bv.get_bv_size(e); + auto* r = alloc(sls_valuation, bit_width); + while (m_tmp.size() < 2 * r->nw) { + m_tmp.push_back(0); + m_tmp2.push_back(0); + m_tmp3.push_back(0); + m_tmp4.push_back(0); + m_zero.push_back(0); + m_one.push_back(0); + m_a.push_back(0); + m_b.push_back(0); + m_nexta.push_back(0); + m_nextb.push_back(0); + m_aux.push_back(0); + m_minus_one.push_back(~0); + m_one[0] = 1; + } + return r; + } + + void sls_eval::init_eval_basic(app* e) { + auto id = e->get_id(); + if (m.is_bool(e)) + m_eval.setx(id, bval1(e), false); + else if (m.is_ite(e)) { + SASSERT(bv.is_bv(e->get_arg(1))); + auto& val = wval(e); + auto& val_th = wval(e->get_arg(1)); + auto& val_el = wval(e->get_arg(2)); + if (bval0(e->get_arg(0))) + val.set(val_th.bits()); + else + val.set(val_el.bits()); + } + else { + UNREACHABLE(); + } + } + + void sls_eval::init_eval_bv(app* e) { + if (bv.is_bv(e)) + eval(e).commit_eval(); + else if (m.is_bool(e)) + m_eval.setx(e->get_id(), bval1_bv(e), false); + } + + bool sls_eval::bval1_basic(app* e) const { + SASSERT(m.is_bool(e)); + SASSERT(e->get_family_id() == basic_family_id); + + auto id = e->get_id(); + switch (e->get_decl_kind()) { + case OP_TRUE: + return true; + case OP_FALSE: + return false; + case OP_AND: + return all_of(*to_app(e), [&](expr* arg) { return bval0(arg); }); + case OP_OR: + return any_of(*to_app(e), [&](expr* arg) { return bval0(arg); }); + case OP_NOT: + return !bval0(e->get_arg(0)); + case OP_XOR: { + bool r = false; + for (auto* arg : *to_app(e)) + r ^= bval0(arg); + return r; + } + case OP_IMPLIES: { + auto a = e->get_arg(0); + auto b = e->get_arg(1); + return !bval0(a) || bval0(b); + } + case OP_ITE: { + auto c = bval0(e->get_arg(0)); + return bval0(c ? e->get_arg(1) : e->get_arg(2)); + } + case OP_EQ: { + auto a = e->get_arg(0); + auto b = e->get_arg(1); + if (m.is_bool(a)) + return bval0(a) == bval0(b); + else if (bv.is_bv(a)) { + auto const& va = wval(a); + auto const& vb = wval(b); + return va.eq(vb); + } + return m.are_equal(a, b); + } + case OP_DISTINCT: + default: + verbose_stream() << mk_bounded_pp(e, m) << "\n"; + UNREACHABLE(); + break; + } + UNREACHABLE(); + return false; + } + + bool sls_eval::can_eval1(app* e) const { + expr* x, * y, * z; + if (m.is_eq(e, x, y)) + return m.is_bool(x) || bv.is_bv(x); + if (m.is_ite(e, x, y, z)) + return m.is_bool(y) || bv.is_bv(y); + if (e->get_family_id() == bv.get_fid()) { + switch (e->get_decl_kind()) { + case OP_BNEG_OVFL: + case OP_BSADD_OVFL: + case OP_BSDIV_OVFL: + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BSMUL_OVFL: + return false; + default: + return true; + } + } + if (e->get_family_id() == basic_family_id) + return true; + if (is_uninterp_const(e)) + return m.is_bool(e) || bv.is_bv(e); + return false; + } + + bool sls_eval::bval1_bv(app* e) const { + SASSERT(m.is_bool(e)); + SASSERT(e->get_family_id() == bv.get_fid()); + + auto ucompare = [&](std::function const& f) { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + return f(mpn.compare(a.bits().data(), a.nw, b.bits().data(), b.nw)); + }; + + // x x + 2^{bw-1} const& f) { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + add_p2_1(a, m_tmp); + add_p2_1(b, m_tmp2); + return f(mpn.compare(m_tmp.data(), a.nw, m_tmp2.data(), b.nw)); + }; + + auto umul_overflow = [&]() { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + return a.set_mul(m_tmp2, a.bits(), b.bits()); + }; + + switch (e->get_decl_kind()) { + case OP_ULEQ: + return ucompare([](int i) { return i <= 0; }); + case OP_ULT: + return ucompare([](int i) { return i < 0; }); + case OP_UGT: + return ucompare([](int i) { return i > 0; }); + case OP_UGEQ: + return ucompare([](int i) { return i >= 0; }); + case OP_SLEQ: + return scompare([](int i) { return i <= 0; }); + case OP_SLT: + return scompare([](int i) { return i < 0; }); + case OP_SGT: + return scompare([](int i) { return i > 0; }); + case OP_SGEQ: + return scompare([](int i) { return i >= 0; }); + case OP_BIT2BOOL: { + expr* child; + unsigned idx; + VERIFY(bv.is_bit2bool(e, child, idx)); + auto& a = wval(child); + return a.get_bit(idx); + } + case OP_BUMUL_NO_OVFL: + return !umul_overflow(); + case OP_BUMUL_OVFL: + return umul_overflow(); + case OP_BUADD_OVFL: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + return a.set_add(m_tmp, a.bits(), b.bits()); + } + case OP_BNEG_OVFL: + case OP_BSADD_OVFL: + case OP_BSDIV_OVFL: + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BSMUL_OVFL: + NOT_IMPLEMENTED_YET(); + break; + default: + UNREACHABLE(); + break; + } + return false; + } + + bool sls_eval::bval1(app* e) const { + if (e->get_family_id() == basic_family_id) + return bval1_basic(e); + if (e->get_family_id() == bv.get_fid()) + return bval1_bv(e); + SASSERT(is_uninterp_const(e)); + return bval0(e); + } + + sls_valuation& sls_eval::eval(app* e) const { + auto& val = *m_values[e->get_id()]; + eval(e, val); + return val; + } + + void sls_eval::eval(app* e, sls_valuation& val) const { + SASSERT(bv.is_bv(e)); + if (m.is_ite(e)) { + SASSERT(bv.is_bv(e->get_arg(1))); + auto& val_th = wval(e->get_arg(1)); + auto& val_el = wval(e->get_arg(2)); + if (bval0(e->get_arg(0))) + val.set(val_th.bits()); + else + val.set(val_el.bits()); + return; + } + if (e->get_family_id() == null_family_id) { + val.set(wval(e).bits()); + return; + } + auto set_sdiv = [&]() { + // d = udiv(abs(x), abs(y)) + // y = 0, x >= 0 -> -1 + // y = 0, x < 0 -> 1 + // x = 0, y != 0 -> 0 + // x > 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + bool sign_a = a.sign(); + bool sign_b = b.sign(); + if (b.is_zero()) { + if (sign_a) + val.set(m_one); + else + val.set(m_minus_one); + } + else if (a.is_zero()) + val.set(m_zero); + else { + if (sign_a) + a.set_sub(m_tmp, m_zero, a.bits()); + else + a.get(m_tmp); + + if (sign_b) + b.set_sub(m_tmp2, m_zero, b.bits()); + else + b.get(m_tmp2); + + set_div(m_tmp, m_tmp2, a.bw, m_tmp3, m_tmp4); + if (sign_a == sign_b) + val.set(m_tmp3); + else + val.set_sub(val.eval, m_zero, m_tmp3); + } + }; + + auto mk_rotate_left = [&](unsigned n) { + auto& a = wval(e->get_arg(0)); + VERIFY(try_repair_rotate_left(a.bits(), val, a.bw - n)); + }; + + SASSERT(e->get_family_id() == bv.get_fid()); + switch (e->get_decl_kind()) { + case OP_BV_NUM: { + rational n; + VERIFY(bv.is_numeral(e, n)); + val.set_value(m_tmp, n); + val.set(m_tmp); + break; + } + case OP_BAND: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] = a.bits()[i] & b.bits()[i]; + break; + } + case OP_BOR: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] = a.bits()[i] | b.bits()[i]; + break; + } + case OP_BXOR: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] = a.bits()[i] ^ b.bits()[i]; + break; + } + case OP_BNAND: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] = ~(a.bits()[i] & b.bits()[i]); + break; + } + case OP_BADD: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + val.set_add(val.eval, a.bits(), b.bits()); + break; + } + case OP_BSUB: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + val.set_sub(val.eval, a.bits(), b.bits()); + break; + } + case OP_BMUL: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + val.set_mul(m_tmp2, a.bits(), b.bits()); + val.set(m_tmp2); + break; + } + case OP_CONCAT: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < b.bw; ++i) + val.eval.set(i, b.get_bit(i)); + for (unsigned i = 0; i < a.bw; ++i) + val.eval.set(i + b.bw, a.get_bit(i)); + break; + } + case OP_EXTRACT: { + expr* child; + unsigned lo, hi; + VERIFY(bv.is_extract(e, lo, hi, child)); + auto const& a = wval(child); + SASSERT(lo <= hi && hi + 1 <= a.bw && hi - lo + 1 == val.bw); + for (unsigned i = lo; i <= hi; ++i) + val.eval.set(i - lo, a.get_bit(i)); + break; + } + case OP_BNOT: { + auto& a = wval(e->get_arg(0)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] = ~a.bits()[i]; + break; + } + case OP_BNEG: { + auto& a = wval(e->get_arg(0)); + val.set_sub(val.eval, m_zero, a.bits()); + break; + } + case OP_BIT0: + val.eval.set(0, false); + break; + case OP_BIT1: + val.eval.set(0, true); + break; + case OP_BSHL: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + auto sh = b.to_nat(b.bw); + if (sh == 0) + val.set(a.bits()); + else if (sh >= b.bw) + val.set_zero(); + else { + for (unsigned i = 0; i < a.bw; ++i) + val.eval.set(i, i >= sh && a.get_bit(i - sh)); + } + break; + } + case OP_BLSHR: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + auto sh = b.to_nat(b.bw); + if (sh == 0) + val.set(a.bits()); + else if (sh >= b.bw) + val.set_zero(); + else { + for (unsigned i = 0; i < a.bw; ++i) + val.eval.set(i, i + sh < a.bw && a.get_bit(i + sh)); + } + break; + } + case OP_BASHR: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + auto sh = b.to_nat(b.bw); + auto sign = a.sign(); + if (sh == 0) + val.set(a.bits()); + else if (sh >= b.bw) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = sign ? ~0 : 0; + val.set(m_tmp); + } + else { + a.set_zero(m_tmp); + for (unsigned i = 0; i < a.bw; ++i) + m_tmp.set(i, i + sh < a.bw && a.get_bit(i + sh)); + if (sign) + val.set_range(m_tmp, a.bw - sh, a.bw, true); + val.set(m_tmp); + } + break; + } + case OP_SIGN_EXT: { + auto& a = wval(e->get_arg(0)); + a.get(m_tmp); + bool sign = a.sign(); + val.set_range(m_tmp, a.bw, val.bw, sign); + val.set(m_tmp); + break; + } + case OP_ZERO_EXT: { + auto& a = wval(e->get_arg(0)); + a.get(m_tmp); + val.set_range(m_tmp, a.bw, val.bw, false); + val.set(m_tmp); + break; + } + case OP_BUREM: + case OP_BUREM_I: + case OP_BUREM0: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + + if (b.is_zero()) + val.set(a.bits()); + else { + set_div(a.bits(), b.bits(), b.bw, m_tmp, m_tmp2); + val.set(m_tmp2); + } + break; + } + case OP_BSMOD: + case OP_BSMOD_I: + case OP_BSMOD0: { + // u = mod(abs(x),abs(y)) + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + if (b.is_zero()) + val.set(a.bits()); + else { + if (a.sign()) + a.set_sub(m_tmp3, m_zero, a.bits()); + else + a.set(m_tmp3, a.bits()); + if (b.sign()) + b.set_sub(m_tmp4, m_zero, b.bits()); + else + a.set(m_tmp4, b.bits()); + set_div(m_tmp3, m_tmp4, a.bw, m_tmp, m_tmp2); + if (val.is_zero(m_tmp2)) + val.set(m_tmp2); + else if (a.sign() && b.sign()) + val.set_sub(val.eval, m_zero, m_tmp2); + else if (a.sign()) + val.set_sub(val.eval, b.bits(), m_tmp2); + else if (b.sign()) + val.set_add(val.eval, b.bits(), m_tmp2); + else + val.set(m_tmp2); + } + break; + } + case OP_BUDIV: + case OP_BUDIV_I: + case OP_BUDIV0: { + // x div 0 = -1 + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + if (b.is_zero()) + val.set(m_minus_one); + else { + set_div(a.bits(), b.bits(), a.bw, m_tmp, m_tmp2); + val.set(m_tmp); + } + break; + } + + case OP_BSDIV: + case OP_BSDIV_I: + case OP_BSDIV0: { + set_sdiv(); + break; + } + case OP_BSREM: + case OP_BSREM0: + case OP_BSREM_I: { + // b = 0 -> a + // else a - sdiv(a, b) * b + // + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + if (b.is_zero()) + val.set(a.bits()); + else { + set_sdiv(); + val.set_mul(m_tmp, val.eval, b.bits()); + val.set_sub(val.eval, a.bits(), m_tmp); + } + break; + } + case OP_ROTATE_LEFT: { + unsigned n = e->get_parameter(0).get_int() % val.bw; + mk_rotate_left(n); + break; + } + case OP_ROTATE_RIGHT: { + unsigned n = e->get_parameter(0).get_int() % val.bw; + mk_rotate_left(val.bw - n); + break; + } + case OP_EXT_ROTATE_LEFT: { + auto& b = wval(e->get_arg(1)); + rational n = b.get_value(); + n = mod(n, rational(val.bw)); + SASSERT(n.is_unsigned()); + mk_rotate_left(n.get_unsigned()); + break; + } + case OP_EXT_ROTATE_RIGHT: { + auto& b = wval(e->get_arg(1)); + rational n = b.get_value(); + n = mod(n, rational(val.bw)); + SASSERT(n.is_unsigned()); + mk_rotate_left(val.bw - n.get_unsigned()); + break; + } + case OP_BCOMP: { + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + if (a.bits() == b.bits()) + val.set(val.eval, 1); + else + val.set(val.eval, 0); + break; + } + case OP_BREDAND: + case OP_BREDOR: + case OP_BXNOR: + case OP_INT2BV: + + verbose_stream() << mk_bounded_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + break; + case OP_BIT2BOOL: + case OP_BV2INT: + case OP_BNEG_OVFL: + case OP_BSADD_OVFL: + case OP_BUADD_OVFL: + case OP_BSDIV_OVFL: + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BSMUL_OVFL: + case OP_BUMUL_NO_OVFL: + case OP_BUMUL_OVFL: + case OP_ULEQ: + case OP_UGEQ: + case OP_UGT: + case OP_ULT: + case OP_SLEQ: + case OP_SGEQ: + case OP_SGT: + case OP_SLT: + UNREACHABLE(); + break; + default: + UNREACHABLE(); + break; + } + val.clear_overflow_bits(val.eval); + } + + digit_t sls_eval::random_bits() { + return sls_valuation::random_bits(m_rand); + } + + bool sls_eval::try_repair(app* e, unsigned i) { + if (is_fixed0(e->get_arg(i))) + return false; + else if (e->get_family_id() == basic_family_id) + return try_repair_basic(e, i); + if (e->get_family_id() == bv.get_family_id()) + return try_repair_bv(e, i); + return false; + } + + bool sls_eval::try_repair_basic(app* e, unsigned i) { + switch (e->get_decl_kind()) { + case OP_AND: + return try_repair_and_or(e, i); + case OP_OR: + return try_repair_and_or(e, i); + case OP_NOT: + return try_repair_not(e); + case OP_FALSE: + return false; + case OP_TRUE: + return false; + case OP_EQ: + return try_repair_eq(e, i); + case OP_IMPLIES: + return try_repair_implies(e, i); + case OP_XOR: + return try_repair_xor(e, i); + case OP_ITE: + return try_repair_ite(e, i); + default: + UNREACHABLE(); + return false; + } + } + + bool sls_eval::try_repair_bv(app* e, unsigned i) { + switch (e->get_decl_kind()) { + case OP_BAND: + return try_repair_band(eval_value(e), wval(e, i), wval(e, 1 - i)); + case OP_BOR: + return try_repair_bor(eval_value(e), wval(e, i), wval(e, 1 - i)); + case OP_BXOR: + return try_repair_bxor(eval_value(e), wval(e, i), wval(e, 1 - i)); + case OP_BADD: + return try_repair_add(eval_value(e), wval(e, i), wval(e, 1 - i)); + case OP_BSUB: + return try_repair_sub(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BMUL: + return try_repair_mul(eval_value(e), wval(e, i), wval(e, 1 - i)); + case OP_BNOT: + return try_repair_bnot(eval_value(e), wval(e, i)); + case OP_BNEG: + return try_repair_bneg(eval_value(e), wval(e, i)); + case OP_BIT0: + return false; + case OP_BIT1: + return false; + case OP_BV2INT: + return false; + case OP_INT2BV: + return false; + case OP_ULEQ: + if (i == 0) + return try_repair_ule(bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_uge(bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_UGEQ: + if (i == 0) + return try_repair_uge(bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_ule(bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_UGT: + if (i == 0) + return try_repair_ule(!bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_uge(!bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_ULT: + if (i == 0) + return try_repair_uge(!bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_ule(!bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_SLEQ: + if (i == 0) + return try_repair_sle(bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_sge(bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_SGEQ: + if (i == 0) + return try_repair_sge(bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_sle(bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_SGT: + if (i == 0) + return try_repair_sle(!bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_sge(!bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_SLT: + if (i == 0) + return try_repair_sge(!bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_sle(!bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_BASHR: + return try_repair_ashr(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BLSHR: + return try_repair_lshr(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BSHL: + return try_repair_shl(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BIT2BOOL: { + unsigned idx; + expr* arg; + VERIFY(bv.is_bit2bool(e, arg, idx)); + return try_repair_bit2bool(wval(e, 0), idx); + } + + case OP_BUDIV: + case OP_BUDIV_I: + case OP_BUDIV0: + return try_repair_udiv(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BUREM: + case OP_BUREM_I: + case OP_BUREM0: + return try_repair_urem(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_ROTATE_LEFT: + return try_repair_rotate_left(eval_value(e), wval(e, 0), e->get_parameter(0).get_int()); + case OP_ROTATE_RIGHT: + return try_repair_rotate_left(eval_value(e), wval(e, 0), wval(e).bw - e->get_parameter(0).get_int()); + case OP_EXT_ROTATE_LEFT: + return try_repair_rotate_left(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_EXT_ROTATE_RIGHT: + return try_repair_rotate_right(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_ZERO_EXT: + return try_repair_zero_ext(eval_value(e), wval(e, 0)); + case OP_SIGN_EXT: + return try_repair_sign_ext(eval_value(e), wval(e, 0)); + case OP_CONCAT: + return try_repair_concat(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_EXTRACT: { + unsigned hi, lo; + expr* arg; + VERIFY(bv.is_extract(e, lo, hi, arg)); + return try_repair_extract(eval_value(e), wval(arg), lo); + } + case OP_BUMUL_NO_OVFL: + return try_repair_umul_ovfl(!bval0(e), wval(e, 0), wval(e, 1), i); + case OP_BUMUL_OVFL: + return try_repair_umul_ovfl(bval0(e), wval(e, 0), wval(e, 1), i); + case OP_BCOMP: + return try_repair_comp(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BUADD_OVFL: + + case OP_BNAND: + case OP_BREDAND: + case OP_BREDOR: + case OP_BXNOR: + case OP_BNEG_OVFL: + case OP_BSADD_OVFL: + case OP_BSDIV_OVFL: + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BSMUL_OVFL: + verbose_stream() << mk_pp(e, m) << "\n"; + return false; + case OP_BSREM: + case OP_BSREM_I: + case OP_BSREM0: + case OP_BSMOD: + case OP_BSMOD_I: + case OP_BSMOD0: + case OP_BSDIV: + case OP_BSDIV_I: + case OP_BSDIV0: + // these are currently compiled to udiv and urem. + UNREACHABLE(); + return false; + default: + return false; + } + } + + bool sls_eval::try_repair_and_or(app* e, unsigned i) { + auto b = bval0(e); + auto child = e->get_arg(i); + if (b == bval0(child)) + return false; + m_eval[child->get_id()] = b; + return true; + } + + bool sls_eval::try_repair_not(app* e) { + auto child = e->get_arg(0); + m_eval[child->get_id()] = !bval0(e); + return true; + } + + bool sls_eval::try_repair_eq(app* e, unsigned i) { + auto child = e->get_arg(i); + auto is_true = bval0(e); + if (m.is_bool(child)) { + SASSERT(!is_fixed0(child)); + auto bv = bval0(e->get_arg(1 - i)); + m_eval[child->get_id()] = is_true == bv; + return true; + } + else if (bv.is_bv(child)) { + auto & a = wval(e->get_arg(i)); + auto & b = wval(e->get_arg(1 - i)); + return try_repair_eq(is_true, a, b); + } + return false; + } + + bool sls_eval::try_repair_eq(bool is_true, bvval& a, bvval const& b) { + if (is_true) { + if (m_rand() % 20 != 0) + if (a.try_set(b.bits())) + return true; + + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + else { + bool try_above = m_rand() % 2 == 0; + if (try_above) { + a.set_add(m_tmp, b.bits(), m_one); + if (!a.is_zero(m_tmp) && a.set_random_at_least(m_tmp, m_tmp2, m_rand)) + return true; + } + a.set_sub(m_tmp, b.bits(), m_one); + if (!a.is_zero(m_tmp) && a.set_random_at_most(m_tmp, m_tmp2, m_rand)) + return true; + if (!try_above) { + a.set_add(m_tmp, b.bits(), m_one); + if (!a.is_zero(m_tmp) && a.set_random_at_least(m_tmp, m_tmp2, m_rand)) + return true; + } + return false; + } + } + + bool sls_eval::try_repair_xor(app* e, unsigned i) { + bool ev = bval0(e); + bool bv = bval0(e->get_arg(1 - i)); + auto child = e->get_arg(i); + m_eval[child->get_id()] = ev != bv; + return true; + } + + bool sls_eval::try_repair_ite(app* e, unsigned i) { + auto child = e->get_arg(i); + bool c = bval0(e->get_arg(0)); + if (i == 0) { + m_eval[child->get_id()] = !c; + return true; + } + if (c != (i == 1)) + return false; + if (m.is_bool(e)) { + m_eval[child->get_id()] = bval0(e); + return true; + } + if (bv.is_bv(e)) + return wval(child).try_set(wval(e).bits()); + return false; + } + + bool sls_eval::try_repair_implies(app* e, unsigned i) { + auto child = e->get_arg(i); + bool ev = bval0(e); + bool av = bval0(child); + bool bv = bval0(e->get_arg(1 - i)); + if (i == 0) { + if (ev == (!av || bv)) + return false; + } + else if (ev != (!bv || av)) + return false; + m_eval[child->get_id()] = ev; + return true; + } + + // + // e = a & b + // e[i] = 1 -> a[i] = 1 + // e[i] = 0 & b[i] = 1 -> a[i] = 0 + // e[i] = 0 & b[i] = 0 -> a[i] = random + // a := e[i] | (~b[i] & a[i]) + + bool sls_eval::try_repair_band(bvect const& e, bvval& a, bvval const& b) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = ~a.fixed[i] & (e[i] | (~b.bits()[i] & random_bits())); + return a.set_repair(random_bool(), m_tmp); + } + + // + // e = a | b + // set a[i] to 1 where b[i] = 0, e[i] = 1 + // set a[i] to 0 where e[i] = 0, a[i] = 1 + // + bool sls_eval::try_repair_bor(bvect const& e, bvval& a, bvval const& b) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = e[i] & (~b.bits()[i] | random_bits()); + return a.set_repair(random_bool(), m_tmp); + } + + bool sls_eval::try_repair_bxor(bvect const& e, bvval& a, bvval const& b) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = e[i] ^ b.bits()[i]; + a.clear_overflow_bits(m_tmp); + return a.set_repair(random_bool(), m_tmp); + } + + + // + // first try to set a := e - b + // If this fails, set a to a random value + // + bool sls_eval::try_repair_add(bvect const& e, bvval& a, bvval const& b) { + if (m_rand() % 20 != 0) { + a.set_sub(m_tmp, e, b.bits()); + if (a.try_set(m_tmp)) + return true; + } + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + + bool sls_eval::try_repair_sub(bvect const& e, bvval& a, bvval & b, unsigned i) { + if (m_rand() % 20 != 0) { + if (i == 0) + // e = a - b -> a := e + b + a.set_add(m_tmp, e, b.bits()); + else + // b := a - e + b.set_sub(m_tmp, a.bits(), e); + if (a.try_set(m_tmp)) + return true; + } + // fall back to a random value + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + + /** + * e = a*b, then a = e * b^-1 + * 8*e = a*(2b), then a = 4e*b^-1 + */ + bool sls_eval::try_repair_mul(bvect const& e, bvval& a, bvval const& b) { + unsigned parity_e = b.parity(e); + unsigned parity_b = b.parity(b.bits()); + + if (b.is_zero(e)) { + a.get_variant(m_tmp, m_rand); + for (unsigned i = 0; i < b.bw - parity_b; ++i) + m_tmp.set(i, false); + return a.set_repair(random_bool(), m_tmp); + } + + if (b.is_zero()) { + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + + if (m_rand() % 20 == 0) { + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + +#if 0 + verbose_stream() << "solve for " << e << "\n"; + + rational r = e.get_value(e.nw); + rational root; + verbose_stream() << r.is_int_perfect_square(root) << "\n"; +#endif + + + auto& x = m_tmp; + auto& y = m_tmp2; + auto& quot = m_tmp3; + auto& rem = m_tmp4; + auto& ta = m_a; + auto& tb = m_b; + auto& nexta = m_nexta; + auto& nextb = m_nextb; + auto& aux = m_aux; + auto bw = b.bw; + + + // x*ta + y*tb = x + + b.get(y); + if (parity_b > 0) { + b.shift_right(y, parity_b); +#if 0 + for (unsigned i = parity_b; i < b.bw; ++i) + y.set(i, m_rand() % 2 == 0); +#endif + } + + y[a.nw] = 0; + x[a.nw] = 0; + + + a.set_bw((a.nw + 1)* 8 * sizeof(digit_t)); + y.set_bw(a.bw); // enable comparisons + a.set_zero(x); + x.set(bw, true); // x = 2 ^ b.bw + + a.set_one(ta); + a.set_zero(tb); + a.set_zero(nexta); + a.set_one(nextb); + + rem.reserve(2 * a.nw); + SASSERT(y <= x); + while (y > m_zero) { + SASSERT(y <= x); + set_div(x, y, a.bw, quot, rem); // quot, rem := quot_rem(x, y) + SASSERT(y >= rem); + a.set(x, y); // x := y + a.set(y, rem); // y := rem + a.set(aux, nexta); // aux := nexta + a.set_mul(rem, quot, nexta, false); + a.set_sub(nexta, ta, rem); // nexta := ta - quot*nexta + a.set(ta, aux); // ta := aux + a.set(aux, nextb); // aux := nextb + a.set_mul(rem, quot, nextb, false); + a.set_sub(nextb, tb, rem); // nextb := tb - quot*nextb + a.set(tb, aux); // tb := aux + } + + a.set_bw(bw); + y.set_bw(0); + // x*a + y*b = 1 + + tb.set_bw(0); +#if Z3DEBUG + b.get(y); + if (parity_b > 0) + b.shift_right(y, parity_b); + a.set_mul(m_tmp, tb, y); + SASSERT(a.is_one(m_tmp)); +#endif + e.copy_to(b.nw, m_tmp2); + if (parity_e > 0 && parity_b > 0) + b.shift_right(m_tmp2, std::min(parity_b, parity_e)); + a.set_mul(m_tmp, tb, m_tmp2); + if (a.set_repair(random_bool(), m_tmp)) + return true; + + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + + bool sls_eval::try_repair_bnot(bvect const& e, bvval& a) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = ~e[i]; + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); + } + + bool sls_eval::try_repair_bneg(bvect const& e, bvval& a) { + a.set_sub(m_tmp, m_zero, e); + return a.try_set(m_tmp); + } + + + // a <=s b <-> a + p2 <=u b + p2 + // + // NB: p2 = -p2 + // + // to solve x for x >s b: + // infeasible if b + 1 = p2 + // solve for x >=s b + 1 + // + bool sls_eval::try_repair_sle(bool e, bvval& a, bvval const& b) { + auto& p2 = m_b; + b.set_zero(p2); + p2.set(b.bw - 1, true); + p2.set_bw(b.bw); + bool r = false; + if (e) + r = try_repair_sle(a, b.bits(), p2); + else { + auto& b1 = m_nexta; + a.set_add(b1, b.bits(), m_one); + b1.set_bw(b.bw); + if (p2 == b1) + r = false; + else + r = try_repair_sge(a, b1, p2); + b1.set_bw(0); + } + p2.set_bw(0); + return r; + } + + // to solve x for x = p2 if c >= p2 (b < p2) + // or + // x := random p2 <= x <= b if c < p2 (b >= p2) + // + bool sls_eval::try_repair_sle(bvval& a, bvect const& b, bvect const& p2) { + bool r = false; + if (b < p2) { + bool coin = m_rand() % 2 == 0; + if (coin) + r = a.set_random_at_least(p2, m_tmp3, m_rand); + if (!r) + r = a.set_random_at_most(b, m_tmp3, m_rand); + if (!coin && !r) + r = a.set_random_at_least(p2, m_tmp3, m_rand); + } + else + r = a.set_random_in_range(p2, b, m_tmp3, m_rand); + return r; + } + + // solve for x >=s b + // + // d := b + p2 + // + // x := random b <= x < p2 if d >= p2 (b < p2) + // or + // x := random b <= x or x < p2 if d < p2 + // + + bool sls_eval::try_repair_sge(bvval& a, bvect const& b, bvect const& p2) { + auto& p2_1 = m_tmp4; + a.set_sub(p2_1, p2, m_one); + p2_1.set_bw(a.bw); + bool r = false; + if (p2 < b) + // random b <= x < p2 + r = a.set_random_in_range(b, p2_1, m_tmp3, m_rand); + else { + // random b <= x or x < p2 + bool coin = m_rand() % 2 == 0; + if (coin) + r = a.set_random_at_most(p2_1, m_tmp3, m_rand); + if (!r) + r = a.set_random_at_least(b, m_tmp3, m_rand); + if (!r && !coin) + r = a.set_random_at_most(p2_1, m_tmp3, m_rand); + } + p2_1.set_bw(0); + return r; + } + + void sls_eval::add_p2_1(bvval const& a, bvect& t) const { + m_zero.set(a.bw - 1, true); + a.set_add(t, a.bits(), m_zero); + m_zero.set(a.bw - 1, false); + a.clear_overflow_bits(t); + } + + bool sls_eval::try_repair_ule(bool e, bvval& a, bvval const& b) { + if (e) { + // a <= t + return a.set_random_at_most(b.bits(), m_tmp, m_rand); + } + else { + // a > t + a.set_add(m_tmp, b.bits(), m_one); + if (a.is_zero(m_tmp)) + return false; + return a.set_random_at_least(m_tmp, m_tmp2, m_rand); + } + } + + bool sls_eval::try_repair_uge(bool e, bvval& a, bvval const& b) { + if (e) { + // a >= t + return a.set_random_at_least(b.bits(), m_tmp, m_rand); + } + else { + // a < t + if (b.is_zero()) + return false; + a.set_sub(m_tmp, b.bits(), m_one); + return a.set_random_at_most(m_tmp, m_tmp2, m_rand); + } + } + + bool sls_eval::try_repair_bit2bool(bvval& a, unsigned idx) { + return a.try_set_bit(idx, !a.get_bit(idx)); + } + + bool sls_eval::try_repair_shl(bvect const& e, bvval& a, bvval& b, unsigned i) { + if (i == 0) { + unsigned sh = b.to_nat(b.bw); + if (sh == 0) + return a.try_set(e); + else if (sh >= b.bw) + return false; + else { + // + // e = a << sh + // set bw - sh low order bits to bw - sh high-order of e. + // a[bw - sh - 1: 0] = e[bw - 1: sh] + // a[bw - 1: bw - sh] = unchanged + // + for (unsigned i = 0; i < a.bw - sh; ++i) + m_tmp.set(i, e.get(sh + i)); + for (unsigned i = a.bw - sh; i < a.bw; ++i) + m_tmp.set(i, a.get_bit(i)); + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); + } + } + else { + // NB. blind sub-range of possible values for b + SASSERT(i == 1); + unsigned sh = m_rand(a.bw + 1); + b.set(m_tmp, sh); + return b.try_set(m_tmp); + } + return false; + } + + bool sls_eval::try_repair_ashr(bvect const& e, bvval & a, bvval& b, unsigned i) { + if (i == 0) { + unsigned sh = b.to_nat(b.bw); + if (sh == 0) + return a.try_set(e); + else if (sh >= b.bw) { + if (e.get(a.bw - 1)) + return a.try_set_bit(a.bw - 1, true); + else + return a.try_set_bit(a.bw - 1, false); + } + else { + // e = a >> sh + // a[bw-1:sh] = e[bw-sh-1:0] + // a[sh-1:0] = a[sh-1:0] + // ignore sign + for (unsigned i = sh; i < a.bw; ++i) + m_tmp.set(i, e.get(i - sh)); + for (unsigned i = 0; i < sh; ++i) + m_tmp.set(i, a.get_bit(i)); + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); + } + } + else { + // NB. blind sub-range of possible values for b + SASSERT(i == 1); + unsigned sh = m_rand(a.bw + 1); + b.set(m_tmp, sh); + return b.try_set(m_tmp); + } + } + + bool sls_eval::try_repair_lshr(bvect const& e, bvval& a, bvval& b, unsigned i) { + return try_repair_ashr(e, a, b, i); + } + + bool sls_eval::try_repair_comp(bvect const& e, bvval& a, bvval& b, unsigned i) { + SASSERT(e[0] == 0 || e[0] == 1); + SASSERT(e.bw == 1); + return try_repair_eq(e[0] == 1, i == 0 ? a : b, i == 0 ? b : a); + } + + // e = a udiv b + // e = 0 => a != ones + // b = 0 => e = -1 // nothing to repair on a + // e != -1 => max(a) >=u e + + bool sls_eval::try_repair_udiv(bvect const& e, bvval& a, bvval& b, unsigned i) { + if (i == 0) { + if (a.is_zero(e) && a.is_ones(a.fixed) && a.is_ones()) + return false; + if (b.is_zero()) + return false; + if (!a.is_ones(e)) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = ~a.fixed[i] | a.bits()[i]; + a.clear_overflow_bits(m_tmp); + if (e > m_tmp) + return false; + } + // e = 1 => a := b + if (a.is_one(e)) { + a.set(m_tmp, b.bits()); + return a.set_repair(false, m_tmp); + } + // b * e + r = a + if (mul_overflow_on_fixed(b, e)) { + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + + b.get_variant(m_tmp2, m_rand); + while (b.bits() < m_tmp2) + m_tmp2.set(b.msb(m_tmp2), false); + while (a.set_add(m_tmp3, m_tmp, m_tmp2)) + m_tmp2.set(b.msb(m_tmp2), false); + a.clear_overflow_bits(m_tmp3); + return a.set_repair(true, m_tmp3); + } + else { + if (a.is_one(e) && a.is_zero()) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = random_bits(); + a.clear_overflow_bits(m_tmp); + return b.set_repair(true, m_tmp); + } + if (a.is_one(e)) { + a.set(m_tmp, a.bits()); + return b.set_repair(true, m_tmp); + } + + // e * b + r = a + // b = (a - r) udiv e + // random version of r: + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = random_bits(); + a.clear_overflow_bits(m_tmp); + // ensure r <= m + while (a.bits() < m_tmp) + m_tmp.set(a.msb(m_tmp), false); + a.set_sub(m_tmp2, a.bits(), m_tmp); + set_div(m_tmp2, e, a.bw, m_tmp3, m_tmp4); + return b.set_repair(random_bool(), m_tmp4); + } + } + + // table III in Niemetz et al + // x urem s = t <=> + // ~(-s) >=u t + // ((s = 0 or t = ones) => mcb(x, t)) + // ((s != 0 and t != ones) => exists y . (mcb(x, s*y + t) and ~mulo(s, y) and ~addo(s*y, t)) + // s urem x = t <=> + // (s = t => x can be >u t) + // (s != t => exists y . (mcb(x, y) and y >u t and (s - t) mod y = 0) + + + bool sls_eval::try_repair_urem(bvect const& e, bvval& a, bvval& b, unsigned i) { + + if (i == 0) { + if (b.is_zero()) { + a.set(m_tmp, e); + return a.set_repair(random_bool(), m_tmp); + } + // a urem b = e: b*y + e = a + // ~Ovfl*(b, y) + // ~Ovfl+(b*y, e) + // choose y at random + // lower y as long as y*b overflows with fixed bits in b + + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = random_bits(); + a.clear_overflow_bits(m_tmp); + while (mul_overflow_on_fixed(b, m_tmp)) { + auto i = b.msb(m_tmp); + m_tmp.set(i, false); + } + while (true) { + a.set_mul(m_tmp2, m_tmp, b.bits()); + if (!a.set_add(m_tmp3, m_tmp2, e)) + break; + auto i = b.msb(m_tmp); + m_tmp.set(i, false); + } + return a.set_repair(random_bool(), m_tmp3); + } + else { + // a urem b = e: b*y + e = a + // b*y = a - e + // b = (a - e) div y + // ~Ovfl*(b, y) + // ~Ovfl+(b*y, e) + // choose y at random + // lower y as long as y*b overflows with fixed bits in b + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = random_bits(); + a.set_sub(m_tmp2, a.bits(), e); + set_div(m_tmp2, m_tmp, a.bw, m_tmp3, m_tmp4); + a.clear_overflow_bits(m_tmp3); + return b.set_repair(random_bool(), m_tmp3); + } + } + + bool sls_eval::add_overflow_on_fixed(bvval const& a, bvect const& t) { + a.set(m_tmp3, m_zero); + for (unsigned i = 0; i < a.nw; ++i) + m_tmp3[i] = a.fixed[i] & a.bits()[i]; + return a.set_add(m_tmp4, t, m_tmp3); + } + + bool sls_eval::mul_overflow_on_fixed(bvval const& a, bvect const& t) { + a.set(m_tmp3, m_zero); + for (unsigned i = 0; i < a.nw; ++i) + m_tmp3[i] = a.fixed[i] & a.bits()[i]; + return a.set_mul(m_tmp4, m_tmp3, t); + } + + bool sls_eval::try_repair_rotate_left(bvect const& e, bvval& a, unsigned n) const { + // a := rotate_right(e, n) + n = (a.bw - n) % a.bw; + for (unsigned i = a.bw - n; i < a.bw; ++i) + m_tmp.set(i + n - a.bw, e.get(i)); + for (unsigned i = 0; i < a.bw - n; ++i) + m_tmp.set(i + n, e.get(i)); + return a.set_repair(true, m_tmp); + } + + bool sls_eval::try_repair_rotate_left(bvect const& e, bvval& a, bvval& b, unsigned i) { + if (i == 0) { + rational n = b.get_value(); + n = mod(n, rational(b.bw)); + return try_repair_rotate_left(e, a, n.get_unsigned()); + } + else { + SASSERT(i == 1); + unsigned sh = m_rand(b.bw); + b.set(m_tmp, sh); + return b.set_repair(random_bool(), m_tmp); + } + } + + bool sls_eval::try_repair_rotate_right(bvect const& e, bvval& a, bvval& b, unsigned i) { + if (i == 0) { + rational n = b.get_value(); + n = mod(b.bw - n, rational(b.bw)); + return try_repair_rotate_left(e, a, n.get_unsigned()); + } + else { + SASSERT(i == 1); + unsigned sh = m_rand(b.bw); + b.set(m_tmp, sh); + return b.set_repair(random_bool(), m_tmp); + } + } + + bool sls_eval::try_repair_umul_ovfl(bool e, bvval& a, bvval& b, unsigned i) { + if (e) { + // maximize + if (i == 0) { + a.max_feasible(m_tmp); + return a.set_repair(false, m_tmp); + } + else { + b.max_feasible(m_tmp); + return b.set_repair(false, m_tmp); + } + } + else { + // minimize + if (i == 0) { + a.min_feasible(m_tmp); + return a.set_repair(true, m_tmp); + } + else { + b.min_feasible(m_tmp); + return b.set_repair(true, m_tmp); + } + } + } + + // + // prefix of e must be 1s or 0 and match bit position of last bit in a. + // set a to suffix of e, matching signs. + // + bool sls_eval::try_repair_sign_ext(bvect const& e, bvval& a) { + for (unsigned i = a.bw; i < e.bw; ++i) + if (e.get(i) != e.get(a.bw - 1)) + return false; + + for (unsigned i = 0; i < e.nw; ++i) + m_tmp[i] = e[i]; + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); + } + + // + // prefix of e must be 0s. + // + bool sls_eval::try_repair_zero_ext(bvect const& e, bvval& a) { + for (unsigned i = a.bw; i < e.bw; ++i) + if (e.get(i)) + return false; + + for (unsigned i = 0; i < e.nw; ++i) + m_tmp[i] = e[i]; + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); + } + + bool sls_eval::try_repair_concat(bvect const& e, bvval& a, bvval& b, unsigned idx) { + bool r = false; + if (idx == 0) { + for (unsigned i = 0; i < a.bw; ++i) + m_tmp.set(i, e.get(i + b.bw)); + a.clear_overflow_bits(m_tmp); + r = a.try_set(m_tmp); + } + else { + for (unsigned i = 0; i < b.bw; ++i) + m_tmp.set(i, e.get(i)); + b.clear_overflow_bits(m_tmp); + r = b.try_set(m_tmp); + } + //verbose_stream() << e << " := " << a << " " << b << "\n"; + return r; + } + + // + // e = a[hi:lo], where hi = e.bw + lo - 1 + // for the randomized assignment, + // set a outside of [hi:lo] to random values with preference to 0 or 1 bits + // + bool sls_eval::try_repair_extract(bvect const& e, bvval& a, unsigned lo) { + if (m_rand() % m_config.m_prob_randomize_extract <= 100) { + a.get_variant(m_tmp, m_rand); + if (0 == (m_rand() % 2)) { + auto bit = 0 == (m_rand() % 2); + if (!a.try_set_range(m_tmp, 0, lo, bit)) + a.try_set_range(m_tmp, 0, lo, !bit); + } + if (0 == (m_rand() % 2)) { + auto bit = 0 == (m_rand() % 2); + if (!a.try_set_range(m_tmp, lo + e.bw, a.bw, bit)) + a.try_set_range(m_tmp, lo + e.bw, a.bw, !bit); + } + } + else + a.get(m_tmp); + for (unsigned i = 0; i < e.bw; ++i) + m_tmp.set(i + lo, e.get(i)); + if (a.try_set(m_tmp)) + return true; + a.get_variant(m_tmp, m_rand); + bool res = a.set_repair(random_bool(), m_tmp); + // verbose_stream() << "try set " << res << " " << m_tmp[0] << " " << a << "\n"; + return res; + } + + void sls_eval::set_div(bvect const& a, bvect const& b, unsigned bw, + bvect& quot, bvect& rem) const { + unsigned nw = (bw + 8 * sizeof(digit_t) - 1) / (8 * sizeof(digit_t)); + unsigned bnw = nw; + while (bnw > 1 && b[bnw - 1] == 0) + --bnw; + if (b[bnw-1] == 0) { + for (unsigned i = 0; i < nw; ++i) { + quot[i] = ~0; + rem[i] = 0; + } + quot[nw - 1] = (1 << (bw % (8 * sizeof(digit_t)))) - 1; + } + else { + for (unsigned i = 0; i < nw; ++i) + rem[i] = quot[i] = 0; + mpn.div(a.data(), nw, b.data(), bnw, quot.data(), rem.data()); + } + } + + bool sls_eval::repair_up(expr* e) { + if (!is_app(e)) + return false; + if (m.is_bool(e)) { + auto b = bval1(to_app(e)); + if (is_fixed0(e)) + return b == bval0(e); + m_eval[e->get_id()] = b; + return true; + } + if (bv.is_bv(e)) { + auto& v = eval(to_app(e)); + // verbose_stream() << "committing: " << v << "\n"; + for (unsigned i = 0; i < v.nw; ++i) + if (0 != (v.fixed[i] & (v.bits()[i] ^ v.eval[i]))) { + v.bits().copy_to(v.nw, v.eval); + return false; + } + if (v.commit_eval()) + return true; + v.bits().copy_to(v.nw, v.eval); + return false; + } + return false; + } + + sls_valuation& sls_eval::wval(expr* e) const { + // if (!m_values[e->get_id()]) verbose_stream() << mk_bounded_pp(e, m) << "\n"; + return *m_values[e->get_id()]; + } + + std::ostream& sls_eval::display(std::ostream& out, expr_ref_vector const& es) { + auto& terms = sort_assertions(es); + for (expr* e : terms) { + out << e->get_id() << ": " << mk_bounded_pp(e, m, 1) << " "; + if (is_fixed0(e)) + out << "f "; + if (bv.is_bv(e)) + out << wval(e); + else if (m.is_bool(e)) + out << (bval0(e) ? "T" : "F"); + out << "\n"; + } + terms.reset(); + return out; + } +} diff --git a/src/ast/sls/bv_sls_eval.h b/src/ast/sls/bv_sls_eval.h new file mode 100644 index 00000000000..5422d5b7c02 --- /dev/null +++ b/src/ast/sls/bv_sls_eval.h @@ -0,0 +1,178 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls.h + +Abstract: + + A Stochastic Local Search (SLS) engine + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "ast/ast.h" +#include "ast/sls/sls_valuation.h" +#include "ast/sls/bv_sls_fixed.h" +#include "ast/bv_decl_plugin.h" + +namespace bv { + + class sls_fixed; + + class sls_eval { + struct config { + unsigned m_prob_randomize_extract = 50; + }; + + friend class sls_fixed; + friend class sls_test; + ast_manager& m; + bv_util bv; + sls_fixed m_fix; + mutable mpn_manager mpn; + ptr_vector m_todo; + random_gen m_rand; + config m_config; + + + + scoped_ptr_vector m_values; // expr-id -> bv valuation + bool_vector m_eval; // expr-id -> boolean valuation + bool_vector m_fixed; // expr-id -> is Boolean fixed + + mutable bvect m_tmp, m_tmp2, m_tmp3, m_tmp4, m_zero, m_one, m_minus_one; + bvect m_a, m_b, m_nextb, m_nexta, m_aux; + + using bvval = sls_valuation; + + + void init_eval_basic(app* e); + void init_eval_bv(app* e); + + /** + * Register e as a bit-vector. + * Return true if not already registered, false if already registered. + */ + bool add_bit_vector(app* e); + sls_valuation* alloc_valuation(app* e); + + bool bval1_basic(app* e) const; + bool bval1_bv(app* e) const; + + /** + * Repair operations + */ + bool try_repair_basic(app* e, unsigned i); + bool try_repair_bv(app * e, unsigned i); + bool try_repair_and_or(app* e, unsigned i); + bool try_repair_not(app* e); + bool try_repair_eq(app* e, unsigned i); + bool try_repair_xor(app* e, unsigned i); + bool try_repair_ite(app* e, unsigned i); + bool try_repair_implies(app* e, unsigned i); + bool try_repair_band(bvect const& e, bvval& a, bvval const& b); + bool try_repair_bor(bvect const& e, bvval& a, bvval const& b); + bool try_repair_add(bvect const& e, bvval& a, bvval const& b); + bool try_repair_sub(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_mul(bvect const& e, bvval& a, bvval const& b); + bool try_repair_bxor(bvect const& e, bvval& a, bvval const& b); + bool try_repair_bnot(bvect const& e, bvval& a); + bool try_repair_bneg(bvect const& e, bvval& a); + bool try_repair_ule(bool e, bvval& a, bvval const& b); + bool try_repair_uge(bool e, bvval& a, bvval const& b); + bool try_repair_sle(bool e, bvval& a, bvval const& b); + bool try_repair_sge(bool e, bvval& a, bvval const& b); + bool try_repair_sge(bvval& a, bvect const& b, bvect const& p2); + bool try_repair_sle(bvval& a, bvect const& b, bvect const& p2); + bool try_repair_shl(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_ashr(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_lshr(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_bit2bool(bvval& a, unsigned idx); + bool try_repair_udiv(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_urem(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_rotate_left(bvect const& e, bvval& a, unsigned n) const; + bool try_repair_rotate_left(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_rotate_right(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_ule(bool e, bvval& a, bvect const& t); + bool try_repair_uge(bool e, bvval& a, bvect const& t); + bool try_repair_umul_ovfl(bool e, bvval& a, bvval& b, unsigned i); + bool try_repair_zero_ext(bvect const& e, bvval& a); + bool try_repair_sign_ext(bvect const& e, bvval& a); + bool try_repair_concat(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_extract(bvect const& e, bvval& a, unsigned lo); + bool try_repair_comp(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_eq(bool is_true, bvval& a, bvval const& b); + void add_p2_1(bvval const& a, bvect& t) const; + + bool add_overflow_on_fixed(bvval const& a, bvect const& t); + bool mul_overflow_on_fixed(bvval const& a, bvect const& t); + void set_div(bvect const& a, bvect const& b, unsigned nw, + bvect& quot, bvect& rem) const; + + digit_t random_bits(); + bool random_bool() { return m_rand() % 2 == 0; } + + sls_valuation& wval(app* e, unsigned i) { return wval(e->get_arg(i)); } + + void eval(app* e, sls_valuation& val) const; + + bvect const& eval_value(app* e) const { return wval(e).eval; } + + public: + sls_eval(ast_manager& m); + + void init_eval(expr_ref_vector const& es, std::function const& eval); + + void tighten_range(expr_ref_vector const& es) { m_fix.init(es); } + + ptr_vector& sort_assertions(expr_ref_vector const& es); + + /** + * Retrieve evaluation based on cache. + * bval - Boolean values + * wval - Word (bit-vector) values + */ + + bool bval0(expr* e) const { return m_eval[e->get_id()]; } + + sls_valuation& wval(expr* e) const; + + bool is_fixed0(expr* e) const { return m_fixed.get(e->get_id(), false); } + + /** + * Retrieve evaluation based on immediate children. + */ + bool bval1(app* e) const; + bool can_eval1(app* e) const; + + sls_valuation& eval(app* e) const; + + /** + * Override evaluaton. + */ + + void set(expr* e, bool b) { + m_eval[e->get_id()] = b; + } + + /* + * Try to invert value of child to repair value assignment of parent. + */ + + bool try_repair(app* e, unsigned i); + + /* + * Propagate repair up to parent + */ + bool repair_up(expr* e); + + + std::ostream& display(std::ostream& out, expr_ref_vector const& es); + }; +} diff --git a/src/ast/sls/bv_sls_fixed.cpp b/src/ast/sls/bv_sls_fixed.cpp new file mode 100644 index 00000000000..91ce8e0e2a5 --- /dev/null +++ b/src/ast/sls/bv_sls_fixed.cpp @@ -0,0 +1,423 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls_fixed.cpp + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" +#include "ast/sls/bv_sls_fixed.h" +#include "ast/sls/bv_sls_eval.h" + +namespace bv { + + sls_fixed::sls_fixed(sls_eval& ev): + ev(ev), + m(ev.m), + bv(ev.bv) + {} + + void sls_fixed::init(expr_ref_vector const& es) { + ev.sort_assertions(es); + for (expr* e : ev.m_todo) { + if (!is_app(e)) + continue; + app* a = to_app(e); + ev.m_fixed.setx(a->get_id(), is_fixed1(a), false); + if (a->get_family_id() == basic_family_id) + init_fixed_basic(a); + else if (a->get_family_id() == bv.get_family_id()) + init_fixed_bv(a); + else + ; + } + ev.m_todo.reset(); + init_ranges(es); + } + + + void sls_fixed::init_ranges(expr_ref_vector const& es) { + for (expr* e : es) { + bool sign = m.is_not(e, e); + if (is_app(e)) + init_range(to_app(e), sign); + } + } + + // s <=s t <=> s + K <= t + K, K = 2^{bw-1} + + void sls_fixed::init_range(app* e, bool sign) { + expr* s, * t, * x, * y; + rational a, b; + unsigned idx; + auto N = [&](expr* s) { + auto b = bv.get_bv_size(s); + return b > 0 ? rational::power_of_two(b - 1) : rational(0); + }; + if (bv.is_ule(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(x, a, y, b, sign); + } + else if (bv.is_ult(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(y, b, x, a, !sign); + } + else if (bv.is_uge(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(y, b, x, a, sign); + } + else if (bv.is_ugt(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(x, a, y, b, !sign); + } + else if (bv.is_sle(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(x, a + N(s), y, b + N(s), sign); + } + else if (bv.is_slt(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(y, b + N(s), x, a + N(s), !sign); + } + else if (bv.is_sge(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(y, b + N(s), x, a + N(s), sign); + } + else if (bv.is_sgt(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(x, a + N(s), y, b + N(s), !sign); + } + else if (!sign && m.is_eq(e, s, t)) { + if (bv.is_numeral(s, a)) + // t - a <= 0 + init_range(t, -a, nullptr, rational(0), false); + else if (bv.is_numeral(t, a)) + init_range(s, -a, nullptr, rational(0), false); + } + else if (bv.is_bit2bool(e, s, idx)) { + auto& val = wval(s); + val.try_set_bit(idx, !sign); + val.fixed.set(idx, true); + val.tighten_range(); + } + } + + // + // x + a <= b <=> x in [-a, b - a + 1[ b != -1 + // a <= x + b <=> x in [a - b, -b[ a != 0 + // x + a <= x + b <=> x in [-a, -b[ a != b + // + // x + a < b <=> ! (b <= x + a) <=> x not in [-b, a - b + 1[ <=> x in [a - b + 1, -b [ b != 0 + // a < x + b <=> ! (x + b <= a) <=> x not in [-a, b - a [ <=> x in [b - a, -a [ a != -1 + // x + a < x + b <=> ! (x + b <= x + a) <=> x in [-a, -b [ a != b + // + void sls_fixed::init_range(expr* x, rational const& a, expr* y, rational const& b, bool sign) { + if (!x && !y) + return; + if (!x) { + // a <= y + b + if (a == 0) + return; + auto& v = wval(y); + if (!sign) + v.add_range(a - b, -b); + else + v.add_range(-b, a - b); + } + else if (!y) { + + if (mod(b + 1, rational::power_of_two(bv.get_bv_size(x))) == 0) + return; + auto& v = wval(x); + if (!sign) + v.add_range(-a, b - a + 1); + else + v.add_range(b - a + 1, -a); + } + else if (x == y) { + if (a == b) + return; + auto& v = wval(x); + if (!sign) + v.add_range(-a, -b); + else + v.add_range(-b, -a); + } + + } + + void sls_fixed::get_offset(expr* e, expr*& x, rational& offset) { + expr* s, * t; + x = e; + offset = 0; + if (bv.is_bv_add(e, s, t)) { + if (bv.is_numeral(s, offset)) + x = t; + else if (bv.is_numeral(t, offset)) + x = s; + } + else if (bv.is_numeral(e, offset)) + x = nullptr; + } + + sls_valuation& sls_fixed::wval(expr* e) { + return ev.wval(e); + } + + void sls_fixed::init_fixed_basic(app* e) { + if (bv.is_bv(e) && m.is_ite(e)) { + auto& val = wval(e); + auto& val_th = wval(e->get_arg(1)); + auto& val_el = wval(e->get_arg(2)); + for (unsigned i = 0; i < val.nw; ++i) + val.fixed[i] = val_el.fixed[i] & val_th.fixed[i] & ~(val_el.bits(i) ^ val_th.bits(i)); + } + } + + void sls_fixed::init_fixed_bv(app* e) { + if (bv.is_bv(e)) + set_fixed_bw(e); + } + + bool sls_fixed::is_fixed1(app* e) const { + if (is_uninterp(e)) + return false; + if (e->get_family_id() == basic_family_id) + return is_fixed1_basic(e); + return all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); }); + } + + bool sls_fixed::is_fixed1_basic(app* e) const { + switch (e->get_decl_kind()) { + case OP_TRUE: + case OP_FALSE: + return true; + case OP_AND: + return any_of(*e, [&](expr* arg) { return ev.is_fixed0(arg) && !ev.bval0(e); }); + case OP_OR: + return any_of(*e, [&](expr* arg) { return ev.is_fixed0(arg) && ev.bval0(e); }); + default: + return all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); }); + } + } + + void sls_fixed::set_fixed_bw(app* e) { + SASSERT(bv.is_bv(e)); + SASSERT(e->get_family_id() == bv.get_fid()); + auto& v = ev.wval(e); + if (all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); })) { + for (unsigned i = 0; i < v.bw; ++i) + v.fixed.set(i, true); + ev.m_fixed.setx(e->get_id(), true, false); + return; + } + switch (e->get_decl_kind()) { + case OP_BAND: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + // (a.fixed & b.fixed) | (a.fixed & ~a.bits) | (b.fixed & ~b.bits) + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & ~a.bits(i)) | (b.fixed[i] & ~b.bits(i)); + break; + } + case OP_BOR: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + // (a.fixed & b.fixed) | (a.fixed & a.bits) | (b.fixed & b.bits) + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & a.bits(i)) | (b.fixed[i] & b.bits(i)); + break; + } + case OP_BXOR: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = a.fixed[i] & b.fixed[i]; + break; + } + case OP_BNOT: { + auto& a = wval(e->get_arg(0)); + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = a.fixed[i]; + break; + } + case OP_BADD: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + rational r; + if (bv.is_numeral(e->get_arg(0), r) && b.has_range()) + v.add_range(r + b.lo(), r + b.hi()); + else if (bv.is_numeral(e->get_arg(1), r) && a.has_range()) + v.add_range(r + a.lo(), r + a.hi()); + bool pfixed = true; + for (unsigned i = 0; i < v.bw; ++i) { + if (pfixed && a.fixed.get(i) && b.fixed.get(i)) + v.fixed.set(i, true); + else if (!pfixed && a.fixed.get(i) && b.fixed.get(i) && + !a.get_bit(i) && !b.get_bit(i)) { + pfixed = true; + v.fixed.set(i, false); + } + else { + pfixed = false; + v.fixed.set(i, false); + } + } + + break; + } + case OP_BMUL: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + unsigned j = 0, k = 0, zj = 0, zk = 0, hzj = 0, hzk = 0; + // i'th bit depends on bits j + k = i + // if the first j, resp k bits are 0, the bits j + k are 0 + for (; j < v.bw; ++j) + if (!a.fixed.get(j)) + break; + for (; k < v.bw; ++k) + if (!b.fixed.get(k)) + break; + for (; zj < v.bw; ++zj) + if (!a.fixed.get(zj) || a.get_bit(zj)) + break; + for (; zk < v.bw; ++zk) + if (!b.fixed.get(zk) || b.get_bit(zk)) + break; + for (; hzj < v.bw; ++hzj) + if (!a.fixed.get(v.bw - hzj - 1) || a.get_bit(v.bw - hzj - 1)) + break; + for (; hzk < v.bw; ++hzk) + if (!b.fixed.get(v.bw - hzk - 1) || b.get_bit(v.bw - hzk - 1)) + break; + + + if (j > 0 && k > 0) { + for (unsigned i = 0; i < std::min(k, j); ++i) { + SASSERT(!v.get_bit(i)); + v.fixed.set(i, true); + } + } + // lower zj + jk bits are 0 + if (zk > 0 || zj > 0) { + for (unsigned i = 0; i < zk + zj; ++i) { + SASSERT(!v.get_bit(i)); + v.fixed.set(i, true); + } + } + // upper bits are 0, if enough high order bits of a, b are 0. + // TODO - buggy + if (false && hzj < v.bw && hzk < v.bw && hzj + hzk > v.bw) { + hzj = v.bw - hzj; + hzk = v.bw - hzk; + for (unsigned i = hzj + hzk - 1; i < v.bw; ++i) { + SASSERT(!v.get_bit(i)); + v.fixed.set(i, true); + } + } + break; + } + case OP_CONCAT: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < b.bw; ++i) + v.fixed.set(i, b.fixed.get(i)); + for (unsigned i = 0; i < a.bw; ++i) + v.fixed.set(i + b.bw, a.fixed.get(i)); + break; + } + case OP_EXTRACT: { + expr* child; + unsigned lo, hi; + VERIFY(bv.is_extract(e, lo, hi, child)); + auto& a = wval(child); + for (unsigned i = lo; i <= hi; ++i) + v.fixed.set(i - lo, a.fixed.get(i)); + break; + } + case OP_BNEG: { + auto& a = wval(e->get_arg(0)); + bool pfixed = true; + for (unsigned i = 0; i < v.bw; ++i) { + if (pfixed && a.fixed.get(i)) + v.fixed.set(i, true); + else { + pfixed = false; + v.fixed.set(i, false); + } + } + break; + } + case OP_BSHL: { + // determine range of b. + // if b = 0, then inherit fixed from a + // if b >= v.bw then make e fixed to 0 + // if 0 < b < v.bw is known, then inherit shift of fixed values of a + // if 0 < b < v.bw but not known, then inherit run lengths of equal bits of a + // that are fixed. + break; + } + + case OP_BASHR: + case OP_BLSHR: + case OP_INT2BV: + case OP_BCOMP: + case OP_BNAND: + case OP_BREDAND: + case OP_BREDOR: + case OP_BSDIV: + case OP_BSDIV_I: + case OP_BSDIV0: + case OP_BUDIV: + case OP_BUDIV_I: + case OP_BUDIV0: + case OP_BUREM: + case OP_BUREM_I: + case OP_BUREM0: + case OP_BSMOD: + case OP_BSMOD_I: + case OP_BSMOD0: + case OP_BXNOR: + // NOT_IMPLEMENTED_YET(); + break; + case OP_BV_NUM: + case OP_BIT0: + case OP_BIT1: + case OP_BV2INT: + case OP_BNEG_OVFL: + case OP_BSADD_OVFL: + case OP_BUADD_OVFL: + case OP_BSDIV_OVFL: + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BSMUL_OVFL: + case OP_BUMUL_NO_OVFL: + case OP_BUMUL_OVFL: + case OP_BIT2BOOL: + case OP_ULEQ: + case OP_UGEQ: + case OP_UGT: + case OP_ULT: + case OP_SLEQ: + case OP_SGEQ: + case OP_SGT: + case OP_SLT: + UNREACHABLE(); + break; + } + } +} diff --git a/src/ast/sls/bv_sls_fixed.h b/src/ast/sls/bv_sls_fixed.h new file mode 100644 index 00000000000..14970c20c69 --- /dev/null +++ b/src/ast/sls/bv_sls_fixed.h @@ -0,0 +1,52 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls_fixed.h + +Abstract: + + Initialize fixed information. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "ast/ast.h" +#include "ast/sls/sls_valuation.h" +#include "ast/bv_decl_plugin.h" + +namespace bv { + + class sls_eval; + + class sls_fixed { + sls_eval& ev; + ast_manager& m; + bv_util& bv; + + void init_ranges(expr_ref_vector const& es); + void init_range(app* e, bool sign); + void init_range(expr* x, rational const& a, expr* y, rational const& b, bool sign); + void get_offset(expr* e, expr*& x, rational& offset); + + void init_fixed_basic(app* e); + void init_fixed_bv(app* e); + + bool is_fixed1(app* e) const; + bool is_fixed1_basic(app* e) const; + void set_fixed_bw(app* e); + + sls_valuation& wval(expr* e); + + public: + sls_fixed(sls_eval& ev); + + void init(expr_ref_vector const& es); + + }; +} diff --git a/src/ast/sls/bv_sls_terms.cpp b/src/ast/sls/bv_sls_terms.cpp new file mode 100644 index 00000000000..8702c3c480c --- /dev/null +++ b/src/ast/sls/bv_sls_terms.cpp @@ -0,0 +1,212 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls.cpp + +Abstract: + + A Stochastic Local Search (SLS) engine + Uses invertibility conditions, + interval annotations + don't care annotations + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/ast_ll_pp.h" +#include "ast/sls/bv_sls.h" + +namespace bv { + + sls_terms::sls_terms(ast_manager& m): + m(m), + bv(m), + m_assertions(m), + m_pinned(m), + m_translated(m), + m_terms(m){} + + + void sls_terms::assert_expr(expr* e) { + m_assertions.push_back(ensure_binary(e)); + } + + expr* sls_terms::ensure_binary(expr* e) { + expr* top = e; + m_pinned.push_back(e); + m_todo.push_back(e); + expr_fast_mark1 mark; + for (unsigned i = 0; i < m_todo.size(); ++i) { + expr* e = m_todo[i]; + if (!is_app(e)) + continue; + if (m_translated.get(e->get_id(), nullptr)) + continue; + if (mark.is_marked(e)) + continue; + mark.mark(e); + for (auto arg : *to_app(e)) + m_todo.push_back(arg); + } + std::stable_sort(m_todo.begin(), m_todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); + for (expr* e : m_todo) + ensure_binary_core(e); + m_todo.reset(); + return m_translated.get(top->get_id()); + } + + void sls_terms::ensure_binary_core(expr* e) { + if (m_translated.get(e->get_id(), nullptr)) + return; + + app* a = to_app(e); + auto arg = [&](unsigned i) { + return m_translated.get(a->get_arg(i)->get_id()); + }; + unsigned num_args = a->get_num_args(); + expr_ref r(m); +#define FOLD_OP(oper) \ + r = arg(0); \ + for (unsigned i = 1; i < num_args; ++i)\ + r = oper(r, arg(i)); \ + + if (m.is_and(e)) { + FOLD_OP(m.mk_and); + } + else if (m.is_or(e)) { + FOLD_OP(m.mk_or); + } + else if (m.is_xor(e)) { + FOLD_OP(m.mk_xor); + } + else if (bv.is_bv_and(e)) { + FOLD_OP(bv.mk_bv_and); + } + else if (bv.is_bv_or(e)) { + FOLD_OP(bv.mk_bv_or); + } + else if (bv.is_bv_xor(e)) { + FOLD_OP(bv.mk_bv_xor); + } + else if (bv.is_bv_add(e)) { + FOLD_OP(bv.mk_bv_add); + } + else if (bv.is_bv_mul(e)) { + FOLD_OP(bv.mk_bv_mul); + } + else if (bv.is_concat(e)) { + FOLD_OP(bv.mk_concat); + } + else if (m.is_distinct(e)) { + expr_ref_vector es(m); + for (unsigned i = 0; i < num_args; ++i) + for (unsigned j = i + 1; j < num_args; ++j) + es.push_back(m.mk_not(m.mk_eq(arg(i), arg(j)))); + r = m.mk_and(es); + } + else if (bv.is_bv_sdiv(e) || bv.is_bv_sdiv0(e) || bv.is_bv_sdivi(e)) { + r = mk_sdiv(arg(0), arg(1)); + } + else if (bv.is_bv_smod(e) || bv.is_bv_smod0(e) || bv.is_bv_smodi(e)) { + r = mk_smod(arg(0), arg(1)); + } + else if (bv.is_bv_srem(e) || bv.is_bv_srem0(e) || bv.is_bv_sremi(e)) { + r = mk_srem(arg(0), arg(1)); + } + else { + for (unsigned i = 0; i < num_args; ++i) + m_args.push_back(arg(i)); + r = m.mk_app(a->get_decl(), num_args, m_args.data()); + m_args.reset(); + } + m_translated.setx(e->get_id(), r); + } + + expr* sls_terms::mk_sdiv(expr* x, expr* y) { + // d = udiv(abs(x), abs(y)) + // y = 0, x >= 0 -> -1 + // y = 0, x < 0 -> 1 + // x = 0, y != 0 -> 0 + // x > 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + unsigned sz = bv.get_bv_size(x); + rational N = rational::power_of_two(sz); + expr_ref z(bv.mk_zero(sz), m); + expr* signx = bv.mk_ule(bv.mk_numeral(N / 2, sz), x); + expr* signy = bv.mk_ule(bv.mk_numeral(N / 2, sz), y); + expr* absx = m.mk_ite(signx, bv.mk_bv_sub(bv.mk_numeral(N - 1, sz), x), x); + expr* absy = m.mk_ite(signy, bv.mk_bv_sub(bv.mk_numeral(N - 1, sz), y), y); + expr* d = bv.mk_bv_udiv(absx, absy); + expr* r = m.mk_ite(m.mk_eq(signx, signy), d, bv.mk_bv_neg(d)); + r = m.mk_ite(m.mk_eq(z, y), + m.mk_ite(signx, bv.mk_one(sz), bv.mk_numeral(N - 1, sz)), + m.mk_ite(m.mk_eq(x, z), z, r)); + return r; + } + + expr* sls_terms::mk_smod(expr* x, expr* y) { + // u := umod(abs(x), abs(y)) + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u + unsigned sz = bv.get_bv_size(x); + expr_ref z(bv.mk_zero(sz), m); + expr_ref abs_x(m.mk_ite(bv.mk_sle(z, x), x, bv.mk_bv_neg(x)), m); + expr_ref abs_y(m.mk_ite(bv.mk_sle(z, y), y, bv.mk_bv_neg(y)), m); + expr_ref u(bv.mk_bv_urem(abs_x, abs_y), m); + return + m.mk_ite(m.mk_eq(u, z), z, + m.mk_ite(m.mk_eq(y, z), x, + m.mk_ite(m.mk_and(bv.mk_sle(z, x), bv.mk_sle(z, x)), u, + m.mk_ite(bv.mk_sle(z, x), bv.mk_bv_add(y, u), + m.mk_ite(bv.mk_sle(z, y), bv.mk_bv_sub(y, u), bv.mk_bv_neg(u)))))); + + } + + expr* sls_terms::mk_srem(expr* x, expr* y) { + // y = 0 -> x + // else x - sdiv(x, y) * y + return + m.mk_ite(m.mk_eq(y, bv.mk_zero(bv.get_bv_size(x))), + x, bv.mk_bv_sub(x, bv.mk_bv_mul(y, mk_sdiv(x, y)))); + } + + + void sls_terms::init() { + // populate terms + expr_fast_mark1 mark; + for (expr* e : m_assertions) + m_todo.push_back(e); + while (!m_todo.empty()) { + expr* e = m_todo.back(); + m_todo.pop_back(); + if (mark.is_marked(e) || !is_app(e)) + continue; + mark.mark(e); + m_terms.setx(e->get_id(), to_app(e)); + for (expr* arg : *to_app(e)) + m_todo.push_back(arg); + } + // populate parents + m_parents.reserve(m_terms.size()); + for (expr* e : m_terms) { + if (!e || !is_app(e)) + continue; + for (expr* arg : *to_app(e)) + m_parents[arg->get_id()].push_back(e); + } + for (auto a : m_assertions) + m_assertion_set.insert(a->get_id()); + } + +} diff --git a/src/ast/sls/bv_sls_terms.h b/src/ast/sls/bv_sls_terms.h new file mode 100644 index 00000000000..3baffc78e8a --- /dev/null +++ b/src/ast/sls/bv_sls_terms.h @@ -0,0 +1,75 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls_terms.h + +Abstract: + + A Stochastic Local Search (SLS) engine + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "util/lbool.h" +#include "util/params.h" +#include "util/scoped_ptr_vector.h" +#include "util/uint_set.h" +#include "ast/ast.h" +#include "ast/sls/sls_stats.h" +#include "ast/sls/sls_powers.h" +#include "ast/sls/sls_valuation.h" +#include "ast/bv_decl_plugin.h" + +namespace bv { + + class sls_terms { + ast_manager& m; + bv_util bv; + ptr_vector m_todo, m_args; + expr_ref_vector m_assertions, m_pinned, m_translated; + app_ref_vector m_terms; + vector> m_parents; + tracked_uint_set m_assertion_set; + + expr* ensure_binary(expr* e); + void ensure_binary_core(expr* e); + + expr* mk_sdiv(expr* x, expr* y); + expr* mk_smod(expr* x, expr* y); + expr* mk_srem(expr* x, expr* y); + + public: + sls_terms(ast_manager& m); + + /** + * Add constraints + */ + void assert_expr(expr* e); + + /** + * Initialize structures: assertions, parents, terms + */ + void init(); + + /** + * Accessors. + */ + + ptr_vector const& parents(expr* e) const { return m_parents[e->get_id()]; } + + expr_ref_vector const& assertions() const { return m_assertions; } + + app* term(unsigned id) const { return m_terms.get(id); } + + app_ref_vector const& terms() const { return m_terms; } + + bool is_assertion(expr* e) const { return m_assertion_set.contains(e->get_id()); } + + }; +} diff --git a/src/tactic/sls/bvsls_opt_engine.cpp b/src/ast/sls/bvsls_opt_engine.cpp similarity index 98% rename from src/tactic/sls/bvsls_opt_engine.cpp rename to src/ast/sls/bvsls_opt_engine.cpp index 37454ca7213..7dc71cd8cdf 100644 --- a/src/tactic/sls/bvsls_opt_engine.cpp +++ b/src/ast/sls/bvsls_opt_engine.cpp @@ -17,7 +17,7 @@ Module Name: --*/ #include "ast/normal_forms/nnf.h" -#include "tactic/sls/bvsls_opt_engine.h" +#include "ast/sls/bvsls_opt_engine.h" bvsls_opt_engine::bvsls_opt_engine(ast_manager & m, params_ref const & p) : sls_engine(m, p), @@ -68,7 +68,8 @@ bvsls_opt_engine::optimization_result bvsls_opt_engine::optimize( if (is_sat != l_true) { do { - checkpoint(); + if (!m_manager.inc()) + return res; IF_VERBOSE(1, verbose_stream() << "Satisfying... restarts left:" << (m_max_restarts - m_stats.m_restarts) << std::endl;); is_sat = search(); @@ -136,7 +137,8 @@ expr_ref bvsls_opt_engine::maximize() while (m_mpz_manager.lt(score, max_score) && check_restart(m_stats.m_moves)) { - checkpoint(); + if (!m_manager.inc()) + goto bailout; m_stats.m_moves++; m_mpz_manager.set(old_score, score); new_const = (unsigned)-1; diff --git a/src/tactic/sls/bvsls_opt_engine.h b/src/ast/sls/bvsls_opt_engine.h similarity index 98% rename from src/tactic/sls/bvsls_opt_engine.h rename to src/ast/sls/bvsls_opt_engine.h index c6b3857afe9..435fa3af483 100644 --- a/src/tactic/sls/bvsls_opt_engine.h +++ b/src/ast/sls/bvsls_opt_engine.h @@ -18,7 +18,7 @@ Module Name: --*/ #pragma once -#include "tactic/sls/sls_engine.h" +#include "ast/sls/sls_engine.h" class bvsls_opt_engine : public sls_engine { sls_tracker & m_hard_tracker; diff --git a/src/tactic/sls/sls_engine.cpp b/src/ast/sls/sls_engine.cpp similarity index 90% rename from src/tactic/sls/sls_engine.cpp rename to src/ast/sls/sls_engine.cpp index 58676edfb39..249c771ed32 100644 --- a/src/tactic/sls/sls_engine.cpp +++ b/src/ast/sls/sls_engine.cpp @@ -23,11 +23,10 @@ Module Name: #include "ast/ast_pp.h" #include "ast/rewriter/var_subst.h" #include "model/model_pp.h" -#include "tactic/tactic.h" #include "util/luby.h" -#include "tactic/sls/sls_params.hpp" -#include "tactic/sls/sls_engine.h" +#include "params/sls_params.hpp" +#include "ast/sls/sls_engine.h" sls_engine::sls_engine(ast_manager & m, params_ref const & p) : @@ -52,7 +51,6 @@ sls_engine::~sls_engine() { void sls_engine::updt_params(params_ref const & _p) { sls_params p(_p); - m_produce_models = _p.get_bool("model", false); m_max_restarts = p.max_restarts(); m_tracker.set_random_seed(p.random_seed()); m_walksat = p.walksat(); @@ -78,28 +76,13 @@ void sls_engine::updt_params(params_ref const & _p) { NOT_IMPLEMENTED_YET(); } -void sls_engine::collect_statistics(statistics& st) const { - double seconds = m_stats.m_stopwatch.get_current_seconds(); - st.update("sls restarts", m_stats.m_restarts); - st.update("sls full evals", m_stats.m_full_evals); - st.update("sls incr evals", m_stats.m_incr_evals); - st.update("sls incr evals/sec", m_stats.m_incr_evals / seconds); - st.update("sls FLIP moves", m_stats.m_flips); - st.update("sls INC moves", m_stats.m_incs); - st.update("sls DEC moves", m_stats.m_decs); - st.update("sls INV moves", m_stats.m_invs); - st.update("sls moves", m_stats.m_moves); - st.update("sls moves/sec", m_stats.m_moves / seconds); -} -void sls_engine::checkpoint() { - tactic::checkpoint(m_manager); -} bool sls_engine::full_eval(model & mdl) { model::scoped_model_completion _scm(mdl, true); for (expr* a : m_assertions) { - checkpoint(); + if (!m_manager.inc()) + return false; if (!mdl.is_true(a)) { TRACE("sls", tout << "Evaluation: false\n";); return false; @@ -423,7 +406,8 @@ lbool sls_engine::search() { unsigned sz = m_assertions.size(); while (check_restart(m_stats.m_moves)) { - checkpoint(); + if (!m_manager.inc()) + return l_undef; m_stats.m_moves++; // Andreas: Every base restart interval ... @@ -523,38 +507,6 @@ lbool sls_engine::search() { return res; } -void sls_engine::operator()(goal_ref const & g, model_converter_ref & mc) { - if (g->inconsistent()) { - mc = nullptr; - return; - } - - m_produce_models = g->models_enabled(); - - for (unsigned i = 0; i < g->size(); i++) - assert_expr(g->form(i)); - - lbool res = operator()(); - - if (res == l_true) { - report_tactic_progress("Number of flips:", m_stats.m_moves); - for (unsigned i = 0; i < g->size(); i++) - if (!m_mpz_manager.is_one(m_tracker.get_value(g->form(i)))) - { - verbose_stream() << "Terminated before all assertions were SAT!" << std::endl; - NOT_IMPLEMENTED_YET(); - } - - if (m_produce_models) { - model_ref mdl = m_tracker.get_model(); - mc = model2model_converter(mdl.get()); - TRACE("sls_model", mc->display(tout);); - } - g->reset(); - } - else - mc = nullptr; -} lbool sls_engine::operator()() { m_tracker.initialize(m_assertions); @@ -565,9 +517,10 @@ lbool sls_engine::operator()() { lbool res = l_undef; do { - checkpoint(); + if (!m_manager.inc()) + return l_undef; - report_tactic_progress("Searching... restarts left:", m_max_restarts - m_stats.m_restarts); + // report_tactic_progress("Searching... restarts left:", m_max_restarts - m_stats.m_restarts); res = search(); if (res == l_undef) diff --git a/src/tactic/sls/sls_engine.h b/src/ast/sls/sls_engine.h similarity index 72% rename from src/tactic/sls/sls_engine.h rename to src/ast/sls/sls_engine.h index 5f290c626bb..614534f1a86 100644 --- a/src/tactic/sls/sls_engine.h +++ b/src/ast/sls/sls_engine.h @@ -21,48 +21,19 @@ Module Name: #include "util/stopwatch.h" #include "util/lbool.h" #include "ast/converters/model_converter.h" -#include "tactic/goal.h" -#include "tactic/sls/sls_tracker.h" -#include "tactic/sls/sls_evaluator.h" -#include "util/statistics.h" +#include "ast/sls/sls_stats.h" +#include "ast/sls/sls_tracker.h" +#include "ast/sls/sls_evaluator.h" class sls_engine { -public: - class stats { - public: - unsigned m_restarts; - stopwatch m_stopwatch; - unsigned m_full_evals; - unsigned m_incr_evals; - unsigned m_moves, m_flips, m_incs, m_decs, m_invs; - - stats() : - m_restarts(0), - m_full_evals(0), - m_incr_evals(0), - m_moves(0), - m_flips(0), - m_incs(0), - m_decs(0), - m_invs(0) { - m_stopwatch.reset(); - m_stopwatch.start(); - } - void reset() { - m_full_evals = m_flips = m_incr_evals = 0; - m_stopwatch.reset(); - m_stopwatch.start(); - } - }; protected: ast_manager & m_manager; - stats m_stats; + bv::sls_stats m_stats; unsynch_mpz_manager m_mpz_manager; powers m_powers; mpz m_zero, m_one, m_two; - bool m_produce_models; bv_util m_bv_util; sls_tracker m_tracker; sls_evaluator m_evaluator; @@ -96,8 +67,8 @@ class sls_engine { void assert_expr(expr * e) { m_assertions.push_back(e); } - // stats const & get_stats(void) { return m_stats; } - void collect_statistics(statistics & st) const; + bv::sls_stats const & get_stats(void) { return m_stats; } + void collect_statistics(statistics & st) const { m_stats.collect_statistics(st); } void reset_statistics() { m_stats.reset(); } bool full_eval(model & mdl); @@ -111,10 +82,14 @@ class sls_engine { lbool search(); lbool operator()(); - void operator()(goal_ref const & g, model_converter_ref & mc); + + mpz & get_value(expr * n) { return m_tracker.get_value(n); } + + model_ref get_model() { return m_tracker.get_model(); } + + unsynch_mpz_manager& get_mpz_manager() { return m_mpz_manager; } protected: - void checkpoint(); bool what_if(func_decl * fd, const unsigned & fd_inx, const mpz & temp, double & best_score, unsigned & best_const, mpz & best_value); @@ -135,5 +110,7 @@ class sls_engine { //double get_restart_armin(unsigned cnt_restarts); unsigned check_restart(unsigned curr_value); + + }; diff --git a/src/tactic/sls/sls_evaluator.h b/src/ast/sls/sls_evaluator.h similarity index 99% rename from src/tactic/sls/sls_evaluator.h rename to src/ast/sls/sls_evaluator.h index d386ece1592..2ee03c928f4 100644 --- a/src/tactic/sls/sls_evaluator.h +++ b/src/ast/sls/sls_evaluator.h @@ -21,8 +21,8 @@ Module Name: #include "model/model_evaluator.h" -#include "tactic/sls/sls_powers.h" -#include "tactic/sls/sls_tracker.h" +#include "ast/sls/sls_powers.h" +#include "ast/sls/sls_tracker.h" class sls_evaluator { ast_manager & m_manager; diff --git a/src/tactic/sls/sls_powers.h b/src/ast/sls/sls_powers.h similarity index 97% rename from src/tactic/sls/sls_powers.h rename to src/ast/sls/sls_powers.h index 9616c43abcd..80ccbe04f42 100644 --- a/src/tactic/sls/sls_powers.h +++ b/src/ast/sls/sls_powers.h @@ -20,6 +20,7 @@ Module Name: #pragma once #include "util/mpz.h" +#include "util/map.h" class powers : public u_map { unsynch_mpz_manager & m; diff --git a/src/ast/sls/sls_stats.h b/src/ast/sls/sls_stats.h new file mode 100644 index 00000000000..9468e9c8d37 --- /dev/null +++ b/src/ast/sls/sls_stats.h @@ -0,0 +1,51 @@ +#pragma once +#include "util/statistics.h" +#include "util/stopwatch.h" + + +namespace bv { + class sls_stats { + public: + unsigned m_restarts; + stopwatch m_stopwatch; + unsigned m_full_evals; + unsigned m_incr_evals; + unsigned m_moves, m_flips, m_incs, m_decs, m_invs; + + sls_stats() : + m_restarts(0), + m_full_evals(0), + m_incr_evals(0), + m_moves(0), + m_flips(0), + m_incs(0), + m_decs(0), + m_invs(0) { + m_stopwatch.reset(); + m_stopwatch.start(); + } + void reset() { + m_full_evals = m_flips = m_incr_evals = 0; + m_stopwatch.reset(); + m_stopwatch.start(); + } + + void collect_statistics(statistics& st) const { + double seconds = m_stopwatch.get_current_seconds(); + st.update("sls restarts", m_restarts); + st.update("sls full evals", m_full_evals); + st.update("sls incr evals", m_incr_evals); + if (seconds > 0 && m_incr_evals > 0) + st.update("sls incr evals/sec", m_incr_evals / seconds); + if (seconds > 0 && m_moves > 0) + st.update("sls moves/sec", m_moves / seconds); + st.update("sls FLIP moves", m_flips); + st.update("sls INC moves", m_incs); + st.update("sls DEC moves", m_decs); + st.update("sls INV moves", m_invs); + st.update("sls moves", m_moves); + + } + + }; +} diff --git a/src/tactic/sls/sls_tracker.h b/src/ast/sls/sls_tracker.h similarity index 99% rename from src/tactic/sls/sls_tracker.h rename to src/ast/sls/sls_tracker.h index 951153a5c18..67723828faa 100644 --- a/src/tactic/sls/sls_tracker.h +++ b/src/ast/sls/sls_tracker.h @@ -25,8 +25,8 @@ Module Name: #include "ast/bv_decl_plugin.h" #include "model/model.h" -#include "tactic/sls/sls_params.hpp" -#include "tactic/sls/sls_powers.h" +#include "params/sls_params.hpp" +#include "ast/sls/sls_powers.h" class sls_tracker { ast_manager & m_manager; diff --git a/src/ast/sls/sls_valuation.cpp b/src/ast/sls/sls_valuation.cpp new file mode 100644 index 00000000000..3160e5cf5d0 --- /dev/null +++ b/src/ast/sls/sls_valuation.cpp @@ -0,0 +1,653 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_valuation.cpp + +Abstract: + + A Stochastic Local Search (SLS) engine + Uses invertibility conditions, + interval annotations + don't care annotations + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/sls/sls_valuation.h" + +namespace bv { + + void bvect::set_bw(unsigned bw) { + this->bw = bw; + nw = (bw + sizeof(digit_t) * 8 - 1) / (8 * sizeof(digit_t)); + mask = (1 << (bw % (8 * sizeof(digit_t)))) - 1; + if (mask == 0) + mask = ~(digit_t)0; + reserve(nw + 1); + } + + bool operator==(bvect const& a, bvect const& b) { + SASSERT(a.nw > 0); + return 0 == mpn_manager().compare(a.data(), a.nw, b.data(), a.nw); + } + + bool operator<(bvect const& a, bvect const& b) { + SASSERT(a.nw > 0); + return mpn_manager().compare(a.data(), a.nw, b.data(), a.nw) < 0; + } + + bool operator>(bvect const& a, bvect const& b) { + SASSERT(a.nw > 0); + return mpn_manager().compare(a.data(), a.nw, b.data(), a.nw) > 0; + } + + bool operator<=(bvect const& a, bvect const& b) { + SASSERT(a.nw > 0); + return mpn_manager().compare(a.data(), a.nw, b.data(), a.nw) <= 0; + } + + bool operator>=(bvect const& a, bvect const& b) { + SASSERT(a.nw > 0); + return mpn_manager().compare(a.data(), a.nw, b.data(), a.nw) >= 0; + } + + std::ostream& operator<<(std::ostream& out, bvect const& v) { + out << std::hex; + bool nz = false; + for (unsigned i = v.nw; i-- > 0;) { + auto w = v[i]; + if (i + 1 == v.nw) + w &= v.mask; + if (nz) + out << std::setw(8) << std::setfill('0') << w; + else if (w != 0) + out << w, nz = true; + } + if (!nz) + out << "0"; + out << std::dec; + return out; + } + + rational bvect::get_value(unsigned nw) const { + rational p(1), r(0); + for (unsigned i = 0; i < nw; ++i) { + r += p * rational((*this)[i]); + p *= rational::power_of_two(8 * sizeof(digit_t)); + } + return r; + } + + sls_valuation::sls_valuation(unsigned bw) { + set_bw(bw); + m_lo.set_bw(bw); + m_hi.set_bw(bw); + m_bits.set_bw(bw); + fixed.set_bw(bw); + eval.set_bw(bw); + // have lo, hi bits, fixed point to memory allocated within this of size num_bytes each allocated + for (unsigned i = 0; i < nw; ++i) + m_lo[i] = 0, m_hi[i] = 0, m_bits[i] = 0, fixed[i] = 0, eval[i] = 0; + fixed[nw - 1] = ~mask; + } + + void sls_valuation::set_bw(unsigned b) { + bw = b; + nw = (bw + sizeof(digit_t) * 8 - 1) / (8 * sizeof(digit_t)); + mask = (1 << (bw % (8 * sizeof(digit_t)))) - 1; + if (mask == 0) + mask = ~(digit_t)0; + } + + bool sls_valuation::commit_eval() { + for (unsigned i = 0; i < nw; ++i) + if (0 != (fixed[i] & (m_bits[i] ^ eval[i]))) + return false; + if (!in_range(eval)) + return false; + for (unsigned i = 0; i < nw; ++i) + m_bits[i] = eval[i]; + SASSERT(well_formed()); + return true; + } + + bool sls_valuation::in_range(bvect const& bits) const { + mpn_manager m; + auto c = m.compare(m_lo.data(), nw, m_hi.data(), nw); + SASSERT(!has_overflow(bits)); + // full range + + if (c == 0) + return true; + // lo < hi: then lo <= bits & bits < hi + if (c < 0) + return + m.compare(m_lo.data(), nw, bits.data(), nw) <= 0 && + m.compare(bits.data(), nw, m_hi.data(), nw) < 0; + // hi < lo: bits < hi or lo <= bits + return + m.compare(m_lo.data(), nw, bits.data(), nw) <= 0 || + m.compare(bits.data(), nw, m_hi.data(), nw) < 0; + } + + // + // largest dst <= src and dst is feasible + // set dst := src & (~fixed | bits) + // + // increment dst if dst < src by setting bits below msb(src & ~dst) to 1 + // + // if dst < lo < hi: + // return false + // if lo < hi <= dst: + // set dst := hi - 1 + // if hi <= dst < lo + // set dst := hi - 1 + // + + bool sls_valuation::get_at_most(bvect const& src, bvect& dst) const { + SASSERT(!has_overflow(src)); + for (unsigned i = 0; i < nw; ++i) + dst[i] = src[i] & (~fixed[i] | m_bits[i]); + + // + // If dst < src, then find the most significant + // bit where src[idx] = 1, dst[idx] = 0 + // set dst[j] = bits_j | ~fixed_j for j < idx + // + for (unsigned i = nw; i-- > 0; ) { + if (0 != (~dst[i] & src[i])) { + auto idx = log2(~dst[i] & src[i]); + auto mask = (1 << idx) - 1; + dst[i] = (~fixed[i] & mask) | dst[i]; + for (unsigned j = i; j-- > 0; ) + dst[j] = (~fixed[j] | m_bits[j]); + break; + } + } + SASSERT(!has_overflow(dst)); + return round_down(dst); + } + + // + // smallest dst >= src and dst is feasible with respect to this. + // set dst := (src & ~fixed) | (fixed & bits) + // + // decrement dst if dst > src by setting bits below msb to 0 unless fixed + // + // if lo < hi <= dst + // return false + // if dst < lo < hi: + // set dst := lo + // if hi <= dst < lo + // set dst := lo + // + bool sls_valuation::get_at_least(bvect const& src, bvect& dst) const { + SASSERT(!has_overflow(src)); + for (unsigned i = 0; i < nw; ++i) + dst[i] = (~fixed[i] & src[i]) | (fixed[i] & m_bits[i]); + + // + // If dst > src, then find the most significant + // bit where src[idx] = 0, dst[idx] = 1 + // set dst[j] = dst[j] & fixed_j for j < idx + // + for (unsigned i = nw; i-- > 0; ) { + if (0 != (dst[i] & ~src[i])) { + auto idx = log2(dst[i] & ~src[i]); + auto mask = (1 << idx); + dst[i] = dst[i] & (fixed[i] | mask); + for (unsigned j = i; j-- > 0; ) + dst[j] = dst[j] & fixed[j]; + break; + } + } + SASSERT(!has_overflow(dst)); + return round_up(dst); + } + + bool sls_valuation::round_up(bvect& dst) const { + if (m_lo < m_hi) { + if (m_hi <= dst) + return false; + if (m_lo > dst) + set(dst, m_lo); + } + else if (m_hi <= dst && m_lo > dst) + set(dst, m_lo); + SASSERT(!has_overflow(dst)); + return true; + } + + bool sls_valuation::round_down(bvect& dst) const { + if (m_lo < m_hi) { + if (m_lo > dst) + return false; + if (m_hi <= dst) { + set(dst, m_hi); + sub1(dst); + } + } + else if (m_hi <= dst && m_lo > dst) { + set(dst, m_hi); + sub1(dst); + } + SASSERT(well_formed()); + return true; + } + + bool sls_valuation::set_random_at_most(bvect const& src, bvect& tmp, random_gen& r) { + if (!get_at_most(src, tmp)) + return false; + if (is_zero(tmp) || (0 == r() % 2)) + return try_set(tmp); + + set_random_below(tmp, r); + // random value below tmp + + if (m_lo == m_hi || is_zero(m_lo) || m_lo <= tmp) + return try_set(tmp); + + // for simplicity, bail out if we were not lucky + return get_at_most(src, tmp) && try_set(tmp); + } + + bool sls_valuation::set_random_at_least(bvect const& src, bvect& tmp, random_gen& r) { + if (!get_at_least(src, tmp)) + return false; + if (is_ones(tmp) || (0 == r() % 2)) + return try_set(tmp); + + // random value at least tmp + set_random_above(tmp, r); + + if (m_lo == m_hi || is_zero(m_hi) || m_hi > tmp) + return try_set(tmp); + + // for simplicity, bail out if we were not lucky + return get_at_least(src, tmp) && try_set(tmp); + } + + bool sls_valuation::set_random_in_range(bvect const& lo, bvect const& hi, bvect& tmp, random_gen& r) { + if (0 == r() % 2) { + if (!get_at_least(lo, tmp)) + return false; + SASSERT(in_range(tmp)); + if (hi < tmp) + return false; + + if (is_ones(tmp) || (0 == r() % 2)) + return try_set(tmp); + set_random_above(tmp, r); + round_down(tmp, [&](bvect const& t) { return hi >= t && in_range(t); }); + if (in_range(tmp) && lo <= tmp && hi >= tmp) + return try_set(tmp); + return get_at_least(lo, tmp) && hi >= tmp && try_set(tmp); + } + else { + if (!get_at_most(hi, tmp)) + return false; + SASSERT(in_range(tmp)); + if (lo > tmp) + return false; + if (is_zero(tmp) || (0 == r() % 2)) + return try_set(tmp); + set_random_below(tmp, r); + round_up(tmp, [&](bvect const& t) { return lo <= t && in_range(t); }); + if (in_range(tmp) && lo <= tmp && hi >= tmp) + return try_set(tmp); + return get_at_most(hi, tmp) && lo <= tmp && try_set(tmp); + } + } + + void sls_valuation::round_down(bvect& dst, std::function const& is_feasible) { + for (unsigned i = bw; !is_feasible(dst) && i-- > 0; ) + if (!fixed.get(i) && dst.get(i)) + dst.set(i, false); + repair_sign_bits(dst); + } + + void sls_valuation::round_up(bvect& dst, std::function const& is_feasible) { + for (unsigned i = 0; !is_feasible(dst) && i < bw; ++i) + if (!fixed.get(i) && !dst.get(i)) + dst.set(i, true); + repair_sign_bits(dst); + } + + void sls_valuation::set_random_above(bvect& dst, random_gen& r) { + for (unsigned i = 0; i < nw; ++i) + dst[i] = dst[i] | (random_bits(r) & ~fixed[i]); + repair_sign_bits(dst); + } + + void sls_valuation::set_random_below(bvect& dst, random_gen& r) { + if (is_zero(dst)) + return; + unsigned n = 0, idx = UINT_MAX; + for (unsigned i = 0; i < bw; ++i) + if (dst.get(i) && !fixed.get(i) && (r() % ++n) == 0) + idx = i; + + if (idx == UINT_MAX) + return; + dst.set(idx, false); + for (unsigned i = 0; i < idx; ++i) + if (!fixed.get(i)) + dst.set(i, r() % 2 == 0); + repair_sign_bits(dst); + } + + bool sls_valuation::set_repair(bool try_down, bvect& dst) { + for (unsigned i = 0; i < nw; ++i) + dst[i] = (~fixed[i] & dst[i]) | (fixed[i] & m_bits[i]); + + repair_sign_bits(dst); + if (in_range(dst)) { + set(eval, dst); + return true; + } + bool repaired = false; + dst.set_bw(bw); + if (m_lo < m_hi) { + for (unsigned i = bw; m_hi <= dst && !in_range(dst) && i-- > 0; ) + if (!fixed.get(i) && dst.get(i)) + dst.set(i, false); + for (unsigned i = 0; i < bw && dst < m_lo && !in_range(dst); ++i) + if (!fixed.get(i) && !dst.get(i)) + dst.set(i, true); + } + else { + for (unsigned i = 0; !in_range(dst) && i < bw; ++i) + if (!fixed.get(i) && !dst.get(i)) + dst.set(i, true); + for (unsigned i = bw; !in_range(dst) && i-- > 0;) + if (!fixed.get(i) && dst.get(i)) + dst.set(i, false); + } + repair_sign_bits(dst); + if (in_range(dst)) { + set(eval, dst); + repaired = true; + } + dst.set_bw(0); + return repaired; + } + + void sls_valuation::min_feasible(bvect& out) const { + if (m_lo < m_hi) + m_lo.copy_to(nw, out); + else { + for (unsigned i = 0; i < nw; ++i) + out[i] = fixed[i] & m_bits[i]; + } + repair_sign_bits(out); + SASSERT(!has_overflow(out)); + } + + void sls_valuation::max_feasible(bvect& out) const { + if (m_lo < m_hi) { + m_hi.copy_to(nw, out); + sub1(out); + } + else { + for (unsigned i = 0; i < nw; ++i) + out[i] = ~fixed[i] | m_bits[i]; + } + repair_sign_bits(out); + SASSERT(!has_overflow(out)); + } + + unsigned sls_valuation::msb(bvect const& src) const { + SASSERT(!has_overflow(src)); + for (unsigned i = nw; i-- > 0; ) + if (src[i] != 0) + return i * 8 * sizeof(digit_t) + log2(src[i]); + return bw; + } + + void sls_valuation::set_value(bvect& bits, rational const& n) { + for (unsigned i = 0; i < bw; ++i) + bits.set(i, n.get_bit(i)); + clear_overflow_bits(bits); + } + + void sls_valuation::get(bvect& dst) const { + m_bits.copy_to(nw, dst); + } + + digit_t sls_valuation::random_bits(random_gen& rand) { + digit_t r = 0; + for (digit_t i = 0; i < sizeof(digit_t); ++i) + r ^= rand() << (8 * i); + return r; + } + + void sls_valuation::get_variant(bvect& dst, random_gen& r) const { + for (unsigned i = 0; i < nw; ++i) + dst[i] = (random_bits(r) & ~fixed[i]) | (fixed[i] & m_bits[i]); + repair_sign_bits(dst); + clear_overflow_bits(dst); + } + + void sls_valuation::repair_sign_bits(bvect& dst) const { + if (m_signed_prefix == 0) + return; + bool sign = dst.get(bw - 1); + for (unsigned i = bw; i-- >= bw - m_signed_prefix; ) { + if (dst.get(i) != sign) { + if (fixed.get(i)) { + for (unsigned i = bw; i-- >= bw - m_signed_prefix; ) + if (!fixed.get(i)) + dst.set(i, !sign); + return; + } + else + dst.set(i, sign); + } + } + } + + // + // new_bits != bits => ~fixed + // 0 = (new_bits ^ bits) & fixed + // also check that new_bits are in range + // + bool sls_valuation::can_set(bvect const& new_bits) const { + SASSERT(!has_overflow(new_bits)); + for (unsigned i = 0; i < nw; ++i) + if (0 != ((new_bits[i] ^ m_bits[i]) & fixed[i])) + return false; + return in_range(new_bits); + } + + unsigned sls_valuation::to_nat(unsigned max_n) { + bvect const& d = m_bits; + SASSERT(!has_overflow(d)); + SASSERT(max_n < UINT_MAX / 2); + unsigned p = 1; + unsigned value = 0; + for (unsigned i = 0; i < bw; ++i) { + if (p >= max_n) { + for (unsigned j = i; j < bw; ++j) + if (d.get(j)) + return max_n; + return value; + } + if (d.get(i)) + value += p; + p <<= 1; + } + return value; + } + + void sls_valuation::shift_right(bvect& out, unsigned shift) const { + SASSERT(shift < bw); + for (unsigned i = 0; i < bw; ++i) + out.set(i, i + shift < bw ? m_bits.get(i + shift) : false); + SASSERT(well_formed()); + } + + void sls_valuation::add_range(rational l, rational h) { + + l = mod(l, rational::power_of_two(bw)); + h = mod(h, rational::power_of_two(bw)); + if (h == l) + return; + + //verbose_stream() << "[" << l << ", " << h << "[\n"; + //verbose_stream() << *this << "\n"; + + if (m_lo == m_hi) { + set_value(m_lo, l); + set_value(m_hi, h); + } + else { + auto old_lo = lo(); + auto old_hi = hi(); + if (old_lo < old_hi) { + if (old_lo < l && l < old_hi) + set_value(m_lo, l), + old_lo = l; + if (old_hi < h && h < old_hi) + set_value(m_hi, h); + } + else { + SASSERT(old_hi < old_lo); + if (old_lo < l || l < old_hi) + set_value(m_lo, l), + old_lo = l; + if (old_lo < h && h < old_hi) + set_value(m_hi, h); + else if (old_hi < old_lo && (h < old_hi || old_lo < h)) + set_value(m_hi, h); + } + } + + + + SASSERT(!has_overflow(m_lo)); + SASSERT(!has_overflow(m_hi)); + + tighten_range(); + SASSERT(well_formed()); + // verbose_stream() << *this << "\n"; + } + + // + // update bits based on ranges + // tighten lo/hi based on fixed bits. + // lo[bit_i] != fixedbit[bit_i] + // let bit_i be most significant bit position of disagreement. + // if fixedbit = 1, lo = 0, increment lo + // if fixedbit = 0, lo = 1, lo := fixed & bits + // (hi-1)[bit_i] != fixedbit[bit_i] + // if fixedbit = 0, hi-1 = 1, set hi-1 := 0, maximize below bit_i + // if fixedbit = 1, hi-1 = 0, hi := fixed & bits + // tighten fixed bits based on lo/hi + // lo + 1 = hi -> set bits = lo + // lo < hi, set most significant bits based on hi + // + void sls_valuation::tighten_range() { + + // verbose_stream() << "tighten " << *this << "\n"; + if (m_lo == m_hi) + return; + + if (!in_range(m_bits)) { + // verbose_stream() << "not in range\n"; + bool compatible = true; + for (unsigned i = 0; i < nw && compatible; ++i) + compatible = 0 == (fixed[i] & (m_bits[i] ^ m_lo[i])); + //verbose_stream() << (fixed[0] & (m_bits[0] ^ m_lo[0])) << "\n"; + //verbose_stream() << bw << " " << m_lo[0] << " " << m_bits[0] << "\n"; + if (compatible) { + //verbose_stream() << "compatible\n"; + set(m_bits, m_lo); + } + else { + bvect tmp(m_bits.nw); + tmp.set_bw(bw); + set(tmp, m_lo); + unsigned max_diff = bw; + for (unsigned i = 0; i < bw; ++i) { + if (fixed.get(i) && (m_bits.get(i) ^ m_lo.get(i))) + max_diff = i; + } + SASSERT(max_diff != bw); + + for (unsigned i = 0; i <= max_diff; ++i) + tmp.set(i, fixed.get(i) && m_bits.get(i)); + + bool found0 = false; + for (unsigned i = max_diff + 1; i < bw; ++i) { + if (found0 || m_lo.get(i) || fixed.get(i)) + tmp.set(i, m_lo.get(i) && fixed.get(i)); + else { + tmp.set(i, true); + found0 = true; + } + } + set(m_bits, tmp); + } + } + // update lo, hi to be feasible. + + for (unsigned i = bw; i-- > 0; ) { + if (!fixed.get(i)) + continue; + if (m_bits.get(i) == m_lo.get(i)) + continue; + if (m_bits.get(i)) { + m_lo.set(i, true); + for (unsigned j = i; j-- > 0; ) + m_lo.set(j, fixed.get(j) && m_bits.get(j)); + } + else { + for (unsigned j = bw; j-- > 0; ) + m_lo.set(j, fixed.get(j) && m_bits.get(j)); + } + break; + } + + SASSERT(well_formed()); + } + + void sls_valuation::set_sub(bvect& out, bvect const& a, bvect const& b) const { + digit_t c; + mpn_manager().sub(a.data(), nw, b.data(), nw, out.data(), &c); + clear_overflow_bits(out); + } + + bool sls_valuation::set_add(bvect& out, bvect const& a, bvect const& b) const { + digit_t c; + mpn_manager().add(a.data(), nw, b.data(), nw, out.data(), nw + 1, &c); + bool ovfl = out[nw] != 0 || has_overflow(out); + clear_overflow_bits(out); + return ovfl; + } + + bool sls_valuation::set_mul(bvect& out, bvect const& a, bvect const& b, bool check_overflow) const { + mpn_manager().mul(a.data(), nw, b.data(), nw, out.data()); + bool ovfl = false; + if (check_overflow) { + ovfl = has_overflow(out); + for (unsigned i = nw; i < 2 * nw; ++i) + ovfl |= out[i] != 0; + } + clear_overflow_bits(out); + return ovfl; + } + + bool sls_valuation::is_power_of2(bvect const& src) const { + unsigned c = 0; + for (unsigned i = 0; i < nw; ++i) + c += get_num_1bits(src[i]); + return c == 1; + } + + +} diff --git a/src/ast/sls/sls_valuation.h b/src/ast/sls/sls_valuation.h new file mode 100644 index 00000000000..dcabf04c08a --- /dev/null +++ b/src/ast/sls/sls_valuation.h @@ -0,0 +1,313 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_valuation.h + +Abstract: + + A Stochastic Local Search (SLS) engine + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "util/lbool.h" +#include "util/params.h" +#include "util/scoped_ptr_vector.h" +#include "util/uint_set.h" +#include "ast/ast.h" +#include "ast/sls/sls_stats.h" +#include "ast/sls/sls_powers.h" +#include "ast/bv_decl_plugin.h" + +namespace bv { + + class bvect : public svector { + public: + unsigned bw = 0; + unsigned nw = 0; + unsigned mask = 0; + + bvect() {} + bvect(unsigned sz) : svector(sz, (unsigned)0) {} + void set_bw(unsigned bw); + + void copy_to(unsigned nw, bvect & dst) const { + SASSERT(nw <= this->size()); + for (unsigned i = 0; i < nw; ++i) + dst[i] = (*this)[i]; + } + + void set(unsigned bit_idx, bool val) { + auto _val = static_cast(0 - static_cast(val)); + get_bit_word(bit_idx) ^= (_val ^ get_bit_word(bit_idx)) & get_pos_mask(bit_idx); + } + + bool get(unsigned bit_idx) const { + return (get_bit_word(bit_idx) & get_pos_mask(bit_idx)) != 0; + } + + unsigned parity() const { + SASSERT(bw > 0); + for (unsigned i = 0; i < nw; ++i) + if ((*this)[i] != 0) + return (8 * sizeof(digit_t) * i) + trailing_zeros((*this)[i]); + return bw; + } + + rational get_value(unsigned nw) const; + + friend bool operator==(bvect const& a, bvect const& b); + friend bool operator<(bvect const& a, bvect const& b); + friend bool operator>(bvect const& a, bvect const& b); + friend bool operator<=(bvect const& a, bvect const& b); + friend bool operator>=(bvect const& a, bvect const& b); + friend std::ostream& operator<<(std::ostream& out, bvect const& v); + + private: + + static digit_t get_pos_mask(unsigned bit_idx) { + return (digit_t)1 << (digit_t)(bit_idx % (8 * sizeof(digit_t))); + } + + digit_t get_bit_word(unsigned bit_idx) const { + return (*this)[bit_idx / (8 * sizeof(digit_t))]; + } + + digit_t& get_bit_word(unsigned bit_idx) { + return (*this)[bit_idx / (8 * sizeof(digit_t))]; + } + }; + + bool operator==(bvect const& a, bvect const& b); + bool operator<(bvect const& a, bvect const& b); + bool operator<=(bvect const& a, bvect const& b); + bool operator>=(bvect const& a, bvect const& b); + bool operator>(bvect const& a, bvect const& b); + inline bool operator!=(bvect const& a, bvect const& b) { return !(a == b); } + std::ostream& operator<<(std::ostream& out, bvect const& v); + + class sls_valuation { + protected: + bvect m_bits; + bvect m_lo, m_hi; // range assignment to bit-vector, as wrap-around interval + unsigned m_signed_prefix = 0; + + unsigned mask; + bool round_up(bvect& dst) const; + bool round_down(bvect& dst) const; + + void repair_sign_bits(bvect& dst) const; + + + public: + unsigned bw; // bit-width + unsigned nw; // num words + bvect fixed; // bit assignment and don't care bit + bvect eval; // current evaluation + + sls_valuation(unsigned bw); + + void set_bw(unsigned bw); + void set_signed(unsigned prefix) { m_signed_prefix = prefix; } + + unsigned num_bytes() const { return (bw + 7) / 8; } + + digit_t bits(unsigned i) const { return m_bits[i]; } + bvect const& bits() const { return m_bits; } + bool commit_eval(); + + bool get_bit(unsigned i) const { return m_bits.get(i); } + bool try_set_bit(unsigned i, bool b) { + SASSERT(in_range(m_bits)); + if (fixed.get(i) && get_bit(i) != b) + return false; + eval.set(i, b); + if (in_range(m_bits)) + return true; + eval.set(i, !b); + return false; + } + + void set_value(bvect& bits, rational const& r); + + rational get_value() const { return m_bits.get_value(nw); } + rational get_eval() const { return eval.get_value(nw); } + rational lo() const { return m_lo.get_value(nw); } + rational hi() const { return m_hi.get_value(nw); } + + + void get(bvect& dst) const; + void add_range(rational lo, rational hi); + bool has_range() const { return m_lo != m_hi; } + void tighten_range(); + + void clear_overflow_bits(bvect& bits) const { + SASSERT(nw > 0); + bits[nw - 1] &= mask; + SASSERT(!has_overflow(bits)); + } + + bool in_range(bvect const& bits) const; + bool can_set(bvect const& bits) const; + + bool eq(sls_valuation const& other) const { return eq(other.m_bits); } + bool eq(bvect const& other) const { return other == m_bits; } + + bool is_zero() const { return is_zero(m_bits); } + bool is_zero(bvect const& a) const { + for (unsigned i = 0; i < nw - 1; ++i) + if (a[i] != 0) + return false; + return (a[nw - 1] & mask) == 0; + } + + bool is_ones() const { return is_ones(m_bits); } + + bool is_ones(bvect const& a) const { + SASSERT(!has_overflow(a)); + for (unsigned i = 0; i + 1 < nw; ++i) + if (0 != ~a[i]) + return false; + return 0 == (mask & ~a[nw - 1]); + } + + bool is_one() const { return is_one(m_bits); } + bool is_one(bvect const& a) const { + SASSERT(!has_overflow(a)); + for (unsigned i = 1; i < nw; ++i) + if (0 != a[i]) + return false; + return 1 == a[0]; + } + + bool sign() const { return m_bits.get(bw - 1); } + + bool has_overflow(bvect const& bits) const { return 0 != (bits[nw - 1] & ~mask); } + + unsigned parity(bvect const& bits) const { return bits.parity(); } + + void min_feasible(bvect& out) const; + void max_feasible(bvect& out) const; + + // most significant bit or bw if src = 0 + unsigned msb(bvect const& src) const; + + bool is_power_of2(bvect const& src) const; + + // retrieve largest number at or below (above) src which is feasible + // with respect to fixed, lo, hi. + bool get_at_most(bvect const& src, bvect& dst) const; + bool get_at_least(bvect const& src, bvect& dst) const; + + bool set_random_at_most(bvect const& src, bvect& tmp, random_gen& r); + bool set_random_at_least(bvect const& src, bvect& tmp, random_gen& r); + bool set_random_in_range(bvect const& lo, bvect const& hi, bvect& tmp, random_gen& r); + + bool set_repair(bool try_down, bvect& dst); + void set_random_above(bvect& dst, random_gen& r); + void set_random_below(bvect& dst, random_gen& r); + void round_down(bvect& dst, std::function const& is_feasible); + void round_up(bvect& dst, std::function const& is_feasible); + + + static digit_t random_bits(random_gen& r); + void get_variant(bvect& dst, random_gen& r) const; + + bool try_set(bvect const& src) { + if (!can_set(src)) + return false; + set(src); + return true; + } + + void set(bvect const& src) { + for (unsigned i = nw; i-- > 0; ) + eval[i] = src[i]; + clear_overflow_bits(eval); + } + + void set_zero(bvect& out) const { + for (unsigned i = 0; i < nw; ++i) + out[i] = 0; + } + + void set_one(bvect& out) const { + for (unsigned i = 1; i < nw; ++i) + out[i] = 0; + out[0] = 1; + } + + void set_zero() { + set_zero(eval); + } + + void sub1(bvect& out) const { + for (unsigned i = 0; i < bw; ++i) { + if (out.get(i)) { + out.set(i, false); + return; + } + else + out.set(i, true); + } + } + + void set_sub(bvect& out, bvect const& a, bvect const& b) const; + bool set_add(bvect& out, bvect const& a, bvect const& b) const; + bool set_mul(bvect& out, bvect const& a, bvect const& b, bool check_overflow = true) const; + void shift_right(bvect& out, unsigned shift) const; + + void set_range(bvect& dst, unsigned lo, unsigned hi, bool b) { + for (unsigned i = lo; i < hi; ++i) + dst.set(i, b); + } + + bool try_set_range(bvect& dst, unsigned lo, unsigned hi, bool b) { + for (unsigned i = lo; i < hi; ++i) + if (fixed.get(i) && get_bit(i) != b) + return false; + for (unsigned i = lo; i < hi; ++i) + dst.set(i, b); + return true; + } + + void set(bvect& dst, unsigned v) const { + dst[0] = v; + for (unsigned i = 1; i < nw; ++i) + dst[i] = 0; + } + + void set(bvect& dst, bvect const& src) const { + for (unsigned i = 0; i < nw; ++i) + dst[i] = src[i]; + } + + unsigned to_nat(unsigned max_n); + + std::ostream& display(std::ostream& out) const { + out << m_bits; + out << " ev: " << eval; + if (!is_zero(fixed)) { + out << " fix:"; + out << fixed; + } + if (m_lo != m_hi) + out << " [" << m_lo << ", " << m_hi << "["; + return out; + } + + bool well_formed() const { + return !has_overflow(m_bits) && (!has_range() || in_range(m_bits)); + } + + }; + + inline std::ostream& operator<<(std::ostream& out, sls_valuation const& v) { return v.display(out); } + +} diff --git a/src/ast/special_relations_decl_plugin.cpp b/src/ast/special_relations_decl_plugin.cpp index 24a756bf79e..bbfe819d420 100644 --- a/src/ast/special_relations_decl_plugin.cpp +++ b/src/ast/special_relations_decl_plugin.cpp @@ -26,7 +26,8 @@ special_relations_decl_plugin::special_relations_decl_plugin(): m_po("partial-order"), m_plo("piecewise-linear-order"), m_to("tree-order"), - m_tc("transitive-closure") + m_tc("transitive-closure"), + m_ac("ac-op") {} func_decl * special_relations_decl_plugin::mk_func_decl( @@ -41,24 +42,53 @@ func_decl * special_relations_decl_plugin::mk_func_decl( m_manager->raise_exception("argument sort missmatch. The two arguments should have the same sort"); return nullptr; } + if (!range && k == OP_SPECIAL_RELATION_AC) + range = domain[0]; + if (!range) { range = m_manager->mk_bool_sort(); } - if (!m_manager->is_bool(range)) { - m_manager->raise_exception("range type is expected to be Boolean for special relations"); - } + auto check_bool_range = [&]() { + if (!m_manager->is_bool(range)) + m_manager->raise_exception("range type is expected to be Boolean for special relations"); + }; + + m_has_special_relation = true; func_decl_info info(m_family_id, k, num_parameters, parameters); symbol name; switch(k) { - case OP_SPECIAL_RELATION_PO: name = m_po; break; - case OP_SPECIAL_RELATION_LO: name = m_lo; break; - case OP_SPECIAL_RELATION_PLO: name = m_plo; break; - case OP_SPECIAL_RELATION_TO: name = m_to; break; + case OP_SPECIAL_RELATION_PO: check_bool_range(); name = m_po; break; + case OP_SPECIAL_RELATION_LO: check_bool_range(); name = m_lo; break; + case OP_SPECIAL_RELATION_PLO: check_bool_range(); name = m_plo; break; + case OP_SPECIAL_RELATION_TO: check_bool_range(); name = m_to; break; + case OP_SPECIAL_RELATION_AC: { + if (range != domain[0]) + m_manager->raise_exception("AC operation should have the same range as domain type"); + name = m_ac; + if (num_parameters != 1 || !parameters[0].is_ast() || !is_func_decl(parameters[0].get_ast())) + m_manager->raise_exception("parameter to transitive closure should be a function declaration"); + func_decl* f = to_func_decl(parameters[0].get_ast()); + if (f->get_arity() != 2) + m_manager->raise_exception("ac function should be binary"); + if (f->get_domain(0) != f->get_domain(1)) + m_manager->raise_exception("ac function should have same domain"); + if (f->get_domain(0) != f->get_range()) + m_manager->raise_exception("ac function should have same domain and range"); + break; + } case OP_SPECIAL_RELATION_TC: + check_bool_range(); name = m_tc; if (num_parameters != 1 || !parameters[0].is_ast() || !is_func_decl(parameters[0].get_ast())) m_manager->raise_exception("parameter to transitive closure should be a function declaration"); + func_decl* f = to_func_decl(parameters[0].get_ast()); + if (f->get_arity() != 2) + m_manager->raise_exception("tc relation should be binary"); + if (f->get_domain(0) != f->get_domain(1)) + m_manager->raise_exception("tc relation should have same domain"); + if (!m_manager->is_bool(f->get_range())) + m_manager->raise_exception("tc relation should be Boolean"); break; } return m_manager->mk_func_decl(name, arity, domain, range, info); @@ -71,6 +101,7 @@ void special_relations_decl_plugin::get_op_names(svector & op_name op_names.push_back(builtin_name(m_plo.str(), OP_SPECIAL_RELATION_PLO)); op_names.push_back(builtin_name(m_to.str(), OP_SPECIAL_RELATION_TO)); op_names.push_back(builtin_name(m_tc.str(), OP_SPECIAL_RELATION_TC)); + op_names.push_back(builtin_name(m_ac.str(), OP_SPECIAL_RELATION_AC)); } } @@ -81,6 +112,7 @@ sr_property special_relations_util::get_property(func_decl* f) const { case OP_SPECIAL_RELATION_PLO: return sr_plo; case OP_SPECIAL_RELATION_TO: return sr_to; case OP_SPECIAL_RELATION_TC: return sr_tc; + case OP_SPECIAL_RELATION_AC: return sr_none; default: UNREACHABLE(); return sr_po; diff --git a/src/ast/special_relations_decl_plugin.h b/src/ast/special_relations_decl_plugin.h index c422cbcdc01..a65f98758b9 100644 --- a/src/ast/special_relations_decl_plugin.h +++ b/src/ast/special_relations_decl_plugin.h @@ -16,6 +16,8 @@ Module Name: Revision History: + 2023-11-27: Added ac-op for E-graph plugin + --*/ #pragma once @@ -28,6 +30,7 @@ enum special_relations_op_kind { OP_SPECIAL_RELATION_PLO, OP_SPECIAL_RELATION_TO, OP_SPECIAL_RELATION_TC, + OP_SPECIAL_RELATION_AC, LAST_SPECIAL_RELATIONS_OP }; @@ -37,6 +40,7 @@ class special_relations_decl_plugin : public decl_plugin { symbol m_plo; symbol m_to; symbol m_tc; + symbol m_ac; bool m_has_special_relation = false; public: special_relations_decl_plugin(); @@ -86,13 +90,16 @@ class special_relations_util { public: special_relations_util(ast_manager& m) : m(m), m_fid(null_family_id) { } + family_id get_family_id() const { return fid(); } + bool has_special_relation() const { return static_cast(m.get_plugin(m.mk_family_id("specrels")))->has_special_relation(); } bool is_special_relation(func_decl* f) const { return f->get_family_id() == fid(); } - bool is_special_relation(app* e) const { return is_special_relation(e->get_decl()); } + bool is_special_relation(expr* e) const { return is_app(e) && is_special_relation(to_app(e)->get_decl()); } sr_property get_property(func_decl* f) const; sr_property get_property(app* e) const { return get_property(e->get_decl()); } func_decl* get_relation(func_decl* f) const { SASSERT(is_special_relation(f)); return to_func_decl(f->get_parameter(0).get_ast()); } + func_decl* get_relation(expr* e) const { SASSERT(is_special_relation(e)); return to_func_decl(to_app(e)->get_parameter(0).get_ast()); } func_decl* mk_to_decl(func_decl* f) { return mk_rel_decl(f, OP_SPECIAL_RELATION_TO); } func_decl* mk_po_decl(func_decl* f) { return mk_rel_decl(f, OP_SPECIAL_RELATION_PO); } @@ -105,12 +112,14 @@ class special_relations_util { bool is_plo(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_PLO); } bool is_to(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_TO); } bool is_tc(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_TC); } + bool is_ac(expr const* e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_AC); } bool is_lo(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_LO); } bool is_po(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_PO); } bool is_plo(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_PLO); } bool is_to(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_TO); } bool is_tc(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_TC); } + bool is_ac(func_decl const* e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_AC); } app * mk_lo (expr * arg1, expr * arg2) { return m.mk_app( fid(), OP_SPECIAL_RELATION_LO, arg1, arg2); } app * mk_po (expr * arg1, expr * arg2) { return m.mk_app( fid(), OP_SPECIAL_RELATION_PO, arg1, arg2); } diff --git a/src/ast/well_sorted.cpp b/src/ast/well_sorted.cpp index fa8e2768bfa..84cdf28addb 100644 --- a/src/ast/well_sorted.cpp +++ b/src/ast/well_sorted.cpp @@ -70,8 +70,7 @@ struct well_sorted_proc { strm << "Expected sort: " << mk_pp(expected_sort, m_manager) << '\n'; strm << "Actual sort: " << mk_pp(actual_sort, m_manager) << '\n'; strm << "Function sort: " << mk_pp(decl, m_manager) << '.'; - auto str = strm.str(); - warning_msg("%s", str.c_str()); + warning_msg("%s", std::move(strm).str().c_str()); m_error = true; return; } diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index 8d1e375d101..019b7fd4f8f 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -42,6 +42,7 @@ Module Name: #include "ast/for_each_expr.h" #include "ast/rewriter/th_rewriter.h" #include "ast/rewriter/recfun_replace.h" +#include "ast/polymorphism_util.h" #include "model/model_evaluator.h" #include "model/model_smt2_pp.h" #include "model/model_v2_pp.h" @@ -223,12 +224,48 @@ bool func_decls::check_signature(ast_manager& m, func_decl* f, unsigned arity, s return true; } -func_decl * func_decls::find(ast_manager& m, unsigned arity, sort * const * domain, sort * range) const { +bool func_decls::check_poly_signature(ast_manager& m, func_decl* f, unsigned arity, sort* const* domain, sort* range, func_decl*& g) { + polymorphism::substitution sub(m); + arith_util au(m); + sort_ref range_ref(range, m); + if (range != nullptr && !sub.match(f->get_range(), range)) + return false; + if (f->get_arity() != arity) + return false; + for (unsigned i = 0; i < arity; i++) + if (!sub.match(f->get_domain(i), domain[i])) + return false; + if (!range) + range_ref = sub(f->get_range()); + + recfun::util u(m); + auto& p = u.get_plugin(); + if (!u.has_def(f)) { + g = m.instantiate_polymorphic(f, arity, domain, range_ref); + return true; + } + // this is an instantiation of a recursive polymorphic function. + // create a self-contained polymorphic definition for the instantiation. + auto def = u.get_def(f); + auto promise_def = p.mk_def(f->get_name(), arity, domain, range_ref, false); + recfun_replace replace(m); + expr_ref tt = sub(def.get_rhs()); + p.set_definition(replace, promise_def, def.is_macro(), def.get_vars().size(), def.get_vars().data(), tt); + g = promise_def.get_def()->get_decl(); + insert(m, g); + return true; +} + + +func_decl * func_decls::find(ast_manager& m, unsigned arity, sort * const * domain, sort * range) { bool coerced = false; + func_decl* g = nullptr; if (!more_than_one()) { func_decl* f = first(); if (check_signature(m, f, arity, domain, range, coerced)) - return f; + return f; + if (check_poly_signature(m, f, arity, domain, range, g)) + return g; return nullptr; } func_decl_set * fs = UNTAG(func_decl_set *, m_decls); @@ -241,10 +278,15 @@ func_decl * func_decls::find(ast_manager& m, unsigned arity, sort * const * doma return f; } } - return best_f; + if (best_f != nullptr) + return best_f; + for (func_decl* f : *fs) + if (check_poly_signature(m, f, arity, domain, range, g)) + return g; + return nullptr; } -func_decl * func_decls::find(ast_manager & m, unsigned num_args, expr * const * args, sort * range) const { +func_decl * func_decls::find(ast_manager & m, unsigned num_args, expr * const * args, sort * range) { if (!more_than_one()) first(); ptr_buffer sorts; @@ -376,12 +418,13 @@ void cmd_context::erase_macro(symbol const& s) { decls.erase_last(m()); } -bool cmd_context::macros_find(symbol const& s, unsigned n, expr*const* args, expr_ref_vector& coerced_args, expr*& t) const { +bool cmd_context::macros_find(symbol const& s, unsigned n, expr*const* args, expr_ref_vector& coerced_args, expr_ref& t) { macro_decls decls; if (!m_macros.find(s, decls)) return false; for (macro_decl const& d : decls) { - if (d.m_domain.size() != n) continue; + if (d.m_domain.size() != n) + continue; bool eq = true; coerced_args.reset(); for (unsigned i = 0; eq && i < n; ++i) { @@ -406,6 +449,26 @@ bool cmd_context::macros_find(symbol const& s, unsigned n, expr*const* args, exp return true; } } + for (macro_decl const& d : decls) { + if (d.m_domain.size() != n) + continue; + polymorphism::substitution sub(m()); + bool eq = true; + for (unsigned i = 0; eq && i < n; ++i) { + if (!sub.match(d.m_domain[i], args[i]->get_sort())) + eq = false; + } + if (eq) { + t = d.m_body; + t = sub(t); + verbose_stream() << "macro " << t << "\n"; + ptr_buffer domain; + for (unsigned i = 0; i < n; ++i) + domain.push_back(args[i]->get_sort()); + insert_macro(s, n, domain.data(), t); + return true; + } + } return false; } @@ -508,8 +571,12 @@ class cmd_context::pp_env : public smt2_pp_environment { m_owner.m_func_decls.contains(s); } format_ns::format * pp_sort(sort * s) override { - return m_owner.pp(s); + auto * f = m_owner.try_pp(s); + if (f) + return f; + return smt2_pp_environment::pp_sort(s); } + format_ns::format * pp_fdecl(func_decl * f, unsigned & len) override { symbol s = f->get_name(); func_decls fs; @@ -935,18 +1002,16 @@ void cmd_context::insert(cmd * c) { void cmd_context::insert_user_tactic(symbol const & s, sexpr * d) { sm().inc_ref(d); sexpr * old_d; - if (m_user_tactic_decls.find(s, old_d)) { - sm().dec_ref(old_d); - } + if (m_user_tactic_decls.find(s, old_d)) + sm().dec_ref(old_d); m_user_tactic_decls.insert(s, d); } void cmd_context::insert(symbol const & s, object_ref * r) { r->inc_ref(*this); object_ref * old_r = nullptr; - if (m_object_refs.find(s, old_r)) { - old_r->dec_ref(*this); - } + if (m_object_refs.find(s, old_r)) + old_r->dec_ref(*this); m_object_refs.insert(s, r); } @@ -1050,16 +1115,17 @@ static builtin_decl const & peek_builtin_decl(builtin_decl const & first, family } func_decl * cmd_context::find_func_decl(symbol const & s, unsigned num_indices, unsigned const * indices, - unsigned arity, sort * const * domain, sort * range) const { + unsigned arity, sort * const * domain, sort * range) { if (domain && contains_macro(s, arity, domain)) throw cmd_exception("invalid function declaration reference, named expressions (aka macros) cannot be referenced ", s); func_decl * f = nullptr; - func_decls fs; - if (num_indices == 0 && m_func_decls.find(s, fs)) + if (num_indices == 0 && m_func_decls.contains(s)) { + auto& fs = m_func_decls.find(s); f = fs.find(m(), arity, domain, range); - if (f) + } + if (f) return f; builtin_decl d; if ((arity == 0 || domain) && m_builtin_decls.find(s, d)) { @@ -1085,11 +1151,12 @@ func_decl * cmd_context::find_func_decl(symbol const & s, unsigned num_indices, throw cmd_exception("invalid function declaration reference, invalid builtin reference ", s); return f; } - if (num_indices > 0 && m_func_decls.find(s, fs)) + if (num_indices > 0 && m_func_decls.contains(s)) { + auto& fs = m_func_decls.find(s); f = fs.find(m(), arity, domain, range); - if (f) + } + if (f) return f; - throw cmd_exception("invalid function declaration reference, unknown indexed function ", s); } @@ -1121,7 +1188,7 @@ object_ref * cmd_context::find_object_ref(symbol const & s) const { #define CHECK_SORT(T) if (well_sorted_check_enabled()) m().check_sorts_core(T) -void cmd_context::mk_const(symbol const & s, expr_ref & result) const { +void cmd_context::mk_const(symbol const & s, expr_ref & result) { mk_app(s, 0, nullptr, 0, nullptr, nullptr, result); } @@ -1149,9 +1216,10 @@ bool cmd_context::try_mk_builtin_app(symbol const & s, unsigned num_args, expr * bool cmd_context::try_mk_declared_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - func_decls& fs, expr_ref & result) const { - if (!m_func_decls.find(s, fs)) + expr_ref & result) { + if (!m_func_decls.contains(s)) return false; + func_decls& fs = m_func_decls.find(s); if (num_args == 0 && !range) { if (fs.more_than_one()) @@ -1176,8 +1244,8 @@ bool cmd_context::try_mk_declared_app(symbol const & s, unsigned num_args, expr bool cmd_context::try_mk_macro_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - expr_ref & result) const { - expr* _t; + expr_ref & result) { + expr_ref _t(m()); expr_ref_vector coerced_args(m()); if (macros_find(s, num_args, args, coerced_args, _t)) { TRACE("macro_bug", tout << "well_sorted_check_enabled(): " << well_sorted_check_enabled() << "\n"; @@ -1252,19 +1320,21 @@ bool cmd_context::try_mk_pdecl_app(symbol const & s, unsigned num_args, expr * c void cmd_context::mk_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - expr_ref & result) const { + expr_ref & result) { - func_decls fs; + if (try_mk_macro_app(s, num_args, args, num_indices, indices, range, result)) return; - if (try_mk_declared_app(s, num_args, args, num_indices, indices, range, fs, result)) - return; + if (try_mk_declared_app(s, num_args, args, num_indices, indices, range, result)) + return; if (try_mk_builtin_app(s, num_args, args, num_indices, indices, range, result)) return; if (!range && try_mk_pdecl_app(s, num_args, args, num_indices, indices, result)) return; + func_decls fs; + m_func_decls.find(s, fs); std::ostringstream buffer; buffer << "unknown constant " << s; if (num_args > 0) { @@ -1390,14 +1460,13 @@ void cmd_context::reset_macros() { } void cmd_context::reset_cmds() { - for (auto& kv : m_cmds) { - kv.m_value->reset(*this); + for (auto& [k,v] : m_cmds) { + v->reset(*this); } } void cmd_context::finalize_cmds() { - for (auto& kv : m_cmds) { - cmd * c = kv.m_value; + for (auto& [k,c] : m_cmds) { c->finalize(*this); dealloc(c); } @@ -1429,6 +1498,7 @@ void cmd_context::reset(bool finalize) { m_builtin_decls.reset(); m_extra_builtin_decls.reset(); m_check_logic.reset(); + m_proof_cmds = nullptr; reset_object_refs(); reset_cmds(); reset_psort_decls(); @@ -1608,6 +1678,8 @@ void cmd_context::restore_assertions(unsigned old_sz) { SASSERT(m_assertions.empty()); return; } + if (m_assertions.empty()) + return; if (old_sz == m_assertions.size()) return; SASSERT(old_sz < m_assertions.size()); @@ -2226,6 +2298,8 @@ vector> cmd_context::tracked_assertions() { } void cmd_context::reset_tracked_assertions() { + for (expr* a : m_assertion_names) + m().dec_ref(a); m_assertion_names.reset(); for (expr* a : m_assertions) m().dec_ref(a); @@ -2261,8 +2335,12 @@ bool cmd_context::is_model_available(model_ref& md) const { } format_ns::format * cmd_context::pp(sort * s) const { + return get_pp_env().pp_sort(s); +} + +format_ns::format* cmd_context::try_pp(sort* s) const { TRACE("cmd_context", tout << "pp(sort * s), s: " << mk_pp(s, m()) << "\n";); - return pm().pp(s); + return pm().pp(get_pp_env(), s); } cmd_context::pp_env & cmd_context::get_pp_env() const { diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index b034a9ffc07..c07d888c7e4 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -58,11 +58,12 @@ class func_decls { bool clash(func_decl * f) const; bool empty() const { return m_decls == nullptr; } func_decl * first() const; - func_decl * find(ast_manager & m, unsigned arity, sort * const * domain, sort * range) const; - func_decl * find(ast_manager & m, unsigned arity, expr * const * args, sort * range) const; + func_decl * find(ast_manager & m, unsigned arity, sort * const * domain, sort * range); + func_decl * find(ast_manager & m, unsigned arity, expr * const * args, sort * range); unsigned get_num_entries() const; func_decl * get_entry(unsigned inx); bool check_signature(ast_manager& m, func_decl* f, unsigned arityh, sort * const* domain, sort * range, bool& coerced) const; + bool check_poly_signature(ast_manager& m, func_decl* f, unsigned arity, sort* const* domain, sort* range, func_decl*& g); }; struct macro_decl { @@ -355,7 +356,7 @@ class cmd_context : public progress_callback, public tactic_manager, public ast_ bool contains_macro(symbol const& s, unsigned arity, sort *const* domain) const; void insert_macro(symbol const& s, unsigned arity, sort*const* domain, expr* t); void erase_macro(symbol const& s); - bool macros_find(symbol const& s, unsigned n, expr*const* args, expr_ref_vector& coerced_args, expr*& t) const; + bool macros_find(symbol const& s, unsigned n, expr*const* args, expr_ref_vector& coerced_args, expr_ref& t); recfun::decl::plugin& get_recfun_plugin(); @@ -449,22 +450,22 @@ class cmd_context : public progress_callback, public tactic_manager, public ast_ void insert_rec_fun(func_decl* f, expr_ref_vector const& binding, svector const& ids, expr* e); func_decl * find_func_decl(symbol const & s) const; func_decl * find_func_decl(symbol const & s, unsigned num_indices, unsigned const * indices, - unsigned arity, sort * const * domain, sort * range) const; + unsigned arity, sort * const * domain, sort * range); recfun::promise_def decl_rec_fun(const symbol &name, unsigned int arity, sort *const *domain, sort *range); psort_decl * find_psort_decl(symbol const & s) const; cmd * find_cmd(symbol const & s) const; sexpr * find_user_tactic(symbol const & s) const; object_ref * find_object_ref(symbol const & s) const; - void mk_const(symbol const & s, expr_ref & result) const; + void mk_const(symbol const & s, expr_ref & result); void mk_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - expr_ref & r) const; + expr_ref & r); bool try_mk_macro_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - expr_ref & r) const; + expr_ref & r); bool try_mk_builtin_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, expr_ref & r) const; bool try_mk_declared_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - func_decls& fs, expr_ref & result) const; + expr_ref & result); bool try_mk_pdecl_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, expr_ref & r) const; void erase_cmd(symbol const & s); void erase_func_decl(symbol const & s); @@ -538,6 +539,7 @@ class cmd_context : public progress_callback, public tactic_manager, public ast_ } format_ns::format * pp(sort * s) const; + format_ns::format* try_pp(sort* s) const; void pp(sort * s, format_ns::format_ref & r) const override { r = pp(s); } void pp(func_decl * f, format_ns::format_ref & r) const override; void pp(expr * n, unsigned num_vars, char const * var_prefix, format_ns::format_ref & r, sbuffer & var_names) const override; diff --git a/src/cmd_context/extra_cmds/dbg_cmds.cpp b/src/cmd_context/extra_cmds/dbg_cmds.cpp index b18d435285d..c8f5deb7d40 100644 --- a/src/cmd_context/extra_cmds/dbg_cmds.cpp +++ b/src/cmd_context/extra_cmds/dbg_cmds.cpp @@ -16,6 +16,7 @@ Module Name: --*/ #include +#include "ast/ast.h" #include "cmd_context/cmd_context.h" #include "cmd_context/cmd_util.h" #include "ast/rewriter/rewriter.h" @@ -34,7 +35,9 @@ Module Name: #include "qe/qe_mbp.h" #include "qe/qe_mbi.h" #include "qe/mbp/mbp_term_graph.h" - +#include "qe/mbp/mbp_qel.h" +#include "qe/lite/qe_lite_tactic.h" +#include "qe/lite/qel.h" BINARY_SYM_CMD(get_quantifier_body_cmd, "dbg-get-qbody", @@ -369,7 +372,7 @@ class mbp_cmd : public cmd { } vars.push_back(to_app(v)); } - qe::mbproj mbp(m); + qe::mbproj mbp(m, gparams::get_module("smt")); expr_ref fml(m_fml, m); mbp.spacer(vars, *mdl.get(), fml); ctx.regular_stream() << fml << "\n"; @@ -572,8 +575,192 @@ class euf_project_cmd : public cmd { }; +class mbp_qel_cmd : public cmd { + unsigned m_arg_index; + ptr_vector m_lits; + ptr_vector m_vars; + + public: + mbp_qel_cmd() : cmd("mbp-qel"){}; + char const *get_usage() const override { return "(exprs) (vars)"; } + char const *get_descr(cmd_context &ctx) const override { + return "Model based projection using e-graphs"; + } + unsigned get_arity() const override { return 2; } + cmd_arg_kind next_arg_kind(cmd_context &ctx) const override { + return CPK_EXPR_LIST; + } + void set_next_arg(cmd_context &ctx, unsigned num, + expr *const *args) override { + if (m_arg_index == 0) { + m_lits.append(num, args); + m_arg_index = 1; + } + else { m_vars.append(num, args); } + } + void prepare(cmd_context &ctx) override { + m_arg_index = 0; + m_lits.reset(); + m_vars.reset(); + } + void execute(cmd_context &ctx) override { + ast_manager &m = ctx.m(); + app_ref_vector vars(m); + expr_ref fml(m); + expr_ref_vector lits(m); + for (expr *v : m_vars) vars.push_back(to_app(v)); + for (expr *e : m_lits) lits.push_back(e); + fml = mk_and(lits); + solver_factory &sf = ctx.get_solver_factory(); + params_ref pa; + solver_ref s = sf(m, pa, false, true, true, symbol::null); + s->assert_expr(fml); + lbool r = s->check_sat(); + if (r != l_true) return; + model_ref mdl; + s->get_model(mdl); + mbp::mbp_qel mbptg(m, pa); + mbptg(vars, fml, *mdl.get()); + + ctx.regular_stream() << "------------------------------ " << std::endl; + ctx.regular_stream() << "Orig tg: " << mk_and(lits) << std::endl; + ctx.regular_stream() << "To elim: "; + for (expr *v : m_vars) { + ctx.regular_stream() << to_app(v)->get_decl()->get_name() << " "; + } + ctx.regular_stream() << std::endl; + ctx.regular_stream() << "output " << fml << std::endl; + } +}; + +class qel_cmd : public cmd { + unsigned m_arg_index; + ptr_vector m_lits; + ptr_vector m_vars; + + public: + qel_cmd() : cmd("qel"){}; + char const *get_usage() const override { return "(lits) (vars)"; } + char const *get_descr(cmd_context &ctx) const override { + return "QE lite over e-graphs"; + } + unsigned get_arity() const override { return 2; } + cmd_arg_kind next_arg_kind(cmd_context &ctx) const override { + if (m_arg_index == 0) return CPK_EXPR_LIST; + return CPK_FUNC_DECL_LIST; + } + void set_next_arg(cmd_context &ctx, unsigned num, + expr *const *args) override { + m_lits.append(num, args); + m_arg_index = 1; + } + void set_next_arg(cmd_context &ctx, unsigned num, + func_decl *const *ts) override { + m_vars.append(num, ts); + } + void prepare(cmd_context &ctx) override { + m_arg_index = 0; + m_lits.reset(); + m_vars.reset(); + } + void execute(cmd_context &ctx) override { + ast_manager &m = ctx.m(); + func_decl_ref_vector vars(m); + app_ref_vector vars_apps(m); + expr_ref_vector lits(m); + + ctx.regular_stream() << "------------------------------ " << std::endl; + + for (func_decl *v : m_vars) { + vars.push_back(v); + vars_apps.push_back(m.mk_const(v)); + } + for (expr *e : m_lits) lits.push_back(e); + + expr_ref fml(m.mk_and(lits), m); + ctx.regular_stream() << "[tg] Before: " << fml << std::endl + << "[tg] Vars: "; + for (app *a : vars_apps) ctx.regular_stream() << app_ref(a, m) << " "; + + ctx.regular_stream() << std::endl; + + params_ref pa; + + // the following is the same code as in qe_mbp in spacer + qel qe(m, pa); + qe(vars_apps, fml); + ctx.regular_stream() << "[tg] After: " << fml << std::endl + << "[tg] Vars: "; + for (app *a : vars_apps) ctx.regular_stream() << app_ref(a, m) << " "; + + ctx.regular_stream() << std::endl; + } +}; + +class qe_lite_cmd : public cmd { + unsigned m_arg_index; + ptr_vector m_lits; + ptr_vector m_vars; + + public: + qe_lite_cmd() : cmd("qe-lite"){}; + char const *get_usage() const override { return "(lits) (vars)"; } + char const *get_descr(cmd_context &ctx) const override { + return "QE lite over e-graphs"; + } + unsigned get_arity() const override { return 2; } + cmd_arg_kind next_arg_kind(cmd_context &ctx) const override { + if (m_arg_index == 0) return CPK_EXPR_LIST; + return CPK_FUNC_DECL_LIST; + } + void set_next_arg(cmd_context &ctx, unsigned num, + expr *const *args) override { + m_lits.append(num, args); + m_arg_index = 1; + } + void set_next_arg(cmd_context &ctx, unsigned num, + func_decl *const *ts) override { + m_vars.append(num, ts); + } + void prepare(cmd_context &ctx) override { + m_arg_index = 0; + m_lits.reset(); + m_vars.reset(); + } + void execute(cmd_context &ctx) override { + ast_manager &m = ctx.m(); + func_decl_ref_vector vars(m); + app_ref_vector vars_apps(m); + expr_ref_vector lits(m); + + ctx.regular_stream() << "------------------------------ " << std::endl; + + for (func_decl *v : m_vars) { + vars.push_back(v); + vars_apps.push_back(m.mk_const(v)); + } + for (expr *e : m_lits) lits.push_back(e); + + expr_ref fml(m.mk_and(lits), m); + ctx.regular_stream() << "[der] Before: " << fml << std::endl + << "[der] Vars: "; + for (app *a : vars_apps) ctx.regular_stream() << app_ref(a, m) << " "; + + ctx.regular_stream() << std::endl; + + params_ref pa; + // the following is the same code as in qe_mbp in spacer + qe_lite qe(m, pa, false); + qe(vars_apps, fml); + ctx.regular_stream() << "[der] After: " << fml << std::endl + << "[der] Vars: "; + for (app *a : vars_apps) ctx.regular_stream() << app_ref(a, m) << " "; + + ctx.regular_stream() << std::endl; + } +}; -void install_dbg_cmds(cmd_context & ctx) { +void install_dbg_cmds(cmd_context &ctx) { ctx.insert(alloc(print_dimacs_cmd)); ctx.insert(alloc(get_quantifier_body_cmd)); ctx.insert(alloc(set_cmd)); @@ -598,7 +785,10 @@ void install_dbg_cmds(cmd_context & ctx) { ctx.insert(alloc(set_next_id)); ctx.insert(alloc(get_interpolant_cmd)); ctx.insert(alloc(mbp_cmd)); + ctx.insert(alloc(mbp_qel_cmd)); ctx.insert(alloc(mbi_cmd)); ctx.insert(alloc(euf_project_cmd)); ctx.insert(alloc(eufi_cmd)); + ctx.insert(alloc(qel_cmd)); + ctx.insert(alloc(qe_lite_cmd)); } diff --git a/src/cmd_context/extra_cmds/proof_cmds.cpp b/src/cmd_context/extra_cmds/proof_cmds.cpp index 9c9b9ed62ae..ea585bfae9b 100644 --- a/src/cmd_context/extra_cmds/proof_cmds.cpp +++ b/src/cmd_context/extra_cmds/proof_cmds.cpp @@ -43,6 +43,7 @@ Proof checker for clauses created during search. #include "util/small_object_allocator.h" #include "ast/ast_util.h" #include "ast/ast_ll_pp.h" +#include "ast/arith_decl_plugin.h" #include "smt/smt_solver.h" #include "sat/sat_solver.h" #include "sat/sat_drat.h" @@ -63,6 +64,7 @@ class proof_trim { vector m_clauses; bool_vector m_is_infer; symbol m_rup; + bool m_empty = false; void mk_clause(expr_ref_vector const& clause) { trim.init_clause(); @@ -121,25 +123,32 @@ class proof_trim { */ void infer(expr_ref_vector const& clause, app* hint) { - if (hint && !is_rup(hint) && m_checker.check(hint)) { + if (m_empty) + return; + + if (hint && !is_rup(hint)) { auto clause1 = m_checker.clause(hint); if (clause1.size() != clause.size()) { mk_clause(clause1); - trim.assume(m_clauses.size()); clause1.push_back(hint); + trim.assume(m_clauses.size()); m_clauses.push_back(clause1); m_is_infer.push_back(true); - mk_clause(clause); - trim.infer(m_clauses.size()); - m_clauses.push_back(clause); - m_clauses.back().push_back(hint); - m_is_infer.push_back(true); - if (clause.empty()) + + if (clause.empty()) { + mk_clause(clause); + trim.infer(m_clauses.size()); + m_clauses.push_back(clause); + m_clauses.back().push_back(hint); + m_is_infer.push_back(true); + m_empty = true; do_trim(std::cout); + } return; } } + mk_clause(clause); if (is_rup(hint)) trim.infer(m_clauses.size()); @@ -149,20 +158,32 @@ class proof_trim { if (hint) m_clauses.back().push_back(hint); m_is_infer.push_back(true); - if (clause.empty()) + if (clause.empty()) { + m_empty = true; do_trim(std::cout); + } } void updt_params(params_ref const& p) { trim.updt_params(p); } + expr_ref mk_dep(unsigned id, unsigned_vector const& deps) { + arith_util a(m); + expr_ref_vector args(m); + args.push_back(a.mk_int(id)); + for (auto d : deps) + args.push_back(a.mk_int(d)); + return expr_ref(m.mk_app(symbol("deps"), args.size(), args.data(), m.mk_proof_sort()), m); + } + void do_trim(std::ostream& out) { ast_pp_util pp(m); auto ids = trim.trim(); - for (unsigned id : ids) { - auto const& clause = m_clauses[id]; + for (auto const& [id, deps] : ids) { + auto& clause = m_clauses[id]; bool is_infer = m_is_infer[id]; + clause.push_back(mk_dep(id, deps)); for (expr* e : clause) pp.collect(e); @@ -218,8 +239,10 @@ class proof_saver { class proof_cmds_imp : public proof_cmds { cmd_context& ctx; ast_manager& m; + arith_util m_arith; expr_ref_vector m_lits; app_ref m_proof_hint; + unsigned_vector m_deps; bool m_check = true; bool m_save = false; bool m_trim = false; @@ -245,11 +268,24 @@ class proof_cmds_imp : public proof_cmds { m_del = m.mk_app(symbol("del"), 0, nullptr, m.mk_proof_sort()); return m_del; } + + bool is_dep(expr* e) { + return m.is_proof(e) && symbol("deps") == to_app(e)->get_name(); + } + + void get_deps(expr* e) { + rational n; + bool is_int = false; + for (expr* arg : *to_app(e)) + if (m_arith.is_numeral(arg, n, is_int) && n.is_unsigned()) + m_deps.push_back(n.get_unsigned()); + } public: proof_cmds_imp(cmd_context& ctx): ctx(ctx), - m(ctx.m()), + m(ctx.m()), + m_arith(m), m_lits(m), m_proof_hint(m), m_assumption(m), @@ -258,8 +294,14 @@ class proof_cmds_imp : public proof_cmds { } void add_literal(expr* e) override { - if (m.is_proof(e)) - m_proof_hint = to_app(e); + if (m.is_proof(e)) { + if (is_dep(e)) + get_deps(e); + else if (!m_proof_hint) + m_proof_hint = to_app(e); + } + else if (!m.is_bool(e)) + throw default_exception("literal should be either a Proof or Bool"); else m_lits.push_back(e); } @@ -272,9 +314,10 @@ class proof_cmds_imp : public proof_cmds { if (m_trim) trim().assume(m_lits); if (m_on_clause_eh) - m_on_clause_eh(m_on_clause_ctx, assumption(), m_lits.size(), m_lits.data()); + m_on_clause_eh(m_on_clause_ctx, assumption(), m_deps.size(), m_deps.data(), m_lits.size(), m_lits.data()); m_lits.reset(); m_proof_hint.reset(); + m_deps.reset(); } void end_infer() override { @@ -285,9 +328,10 @@ class proof_cmds_imp : public proof_cmds { if (m_trim) trim().infer(m_lits, m_proof_hint); if (m_on_clause_eh) - m_on_clause_eh(m_on_clause_ctx, m_proof_hint, m_lits.size(), m_lits.data()); + m_on_clause_eh(m_on_clause_ctx, m_proof_hint, m_deps.size(), m_deps.data(), m_lits.size(), m_lits.data()); m_lits.reset(); m_proof_hint.reset(); + m_deps.reset(); } void end_deleted() override { @@ -298,16 +342,17 @@ class proof_cmds_imp : public proof_cmds { if (m_trim) trim().del(m_lits); if (m_on_clause_eh) - m_on_clause_eh(m_on_clause_ctx, del(), m_lits.size(), m_lits.data()); + m_on_clause_eh(m_on_clause_ctx, del(), m_deps.size(), m_deps.data(), m_lits.size(), m_lits.data()); m_lits.reset(); m_proof_hint.reset(); + m_deps.reset(); } void updt_params(params_ref const& p) override { solver_params sp(p); - m_check = sp.proof_check(); m_save = sp.proof_save(); m_trim = sp.proof_trim(); + m_check = sp.proof_check() && !m_trim && !m_save && !m_on_clause_eh; if (m_trim) trim().updt_params(p); } @@ -315,6 +360,8 @@ class proof_cmds_imp : public proof_cmds { void register_on_clause(void* ctx, user_propagator::on_clause_eh_t& on_clause_eh) override { m_on_clause_ctx = ctx; m_on_clause_eh = on_clause_eh; + if (m_on_clause_eh) + m_check = false; } }; diff --git a/src/cmd_context/pdecl.cpp b/src/cmd_context/pdecl.cpp index b8dd01aea64..f343be94d12 100644 --- a/src/cmd_context/pdecl.cpp +++ b/src/cmd_context/pdecl.cpp @@ -348,6 +348,26 @@ std::ostream& psort_user_decl::display(std::ostream & out) const { return out << ")"; } +// ------------------- +// psort_type_var_decl + +psort_type_var_decl::psort_type_var_decl(unsigned id, pdecl_manager & m, symbol const & n): + psort_decl(id, 0, m, n) { + m_psort_kind = PSORT_TV; +} + +void psort_type_var_decl::finalize(pdecl_manager & m) { + psort_decl::finalize(m); +} + +sort * psort_type_var_decl::instantiate(pdecl_manager & m, unsigned n, sort * const * s) { + return m.m().mk_type_var(m_name); +} + +std::ostream& psort_type_var_decl::display(std::ostream & out) const { + return out << "(declare-type-var " << m_name << ")"; +} + // ------------------- // psort_dt_decl @@ -785,7 +805,7 @@ struct pdecl_manager::sort_info { virtual unsigned obj_size() const { return sizeof(sort_info); } virtual void finalize(pdecl_manager & m) { m.dec_ref(m_decl); } virtual void display(std::ostream & out, pdecl_manager const & m) const = 0; - virtual format * pp(pdecl_manager const & m) const = 0; + virtual format * pp(smt2_pp_environment& env, pdecl_manager const & m) const = 0; }; struct pdecl_manager::app_sort_info : public pdecl_manager::sort_info { @@ -817,15 +837,19 @@ struct pdecl_manager::app_sort_info : public pdecl_manager::sort_info { } } - format * pp(pdecl_manager const & m) const override { + format * pp(smt2_pp_environment& env, pdecl_manager const & m) const override { + symbol s = m_decl->get_name(); + std::string name = s.str(); + if (is_smt2_quoted_symbol(s)) + name = mk_smt2_quoted_symbol(s); if (m_args.empty()) { - return mk_string(m.m(), m_decl->get_name().str()); + return mk_string(m.m(), name); } else { ptr_buffer b; for (auto arg : m_args) - b.push_back(m.pp(arg)); - return mk_seq1(m.m(), b.begin(), b.end(), f2f(), m_decl->get_name().str()); + b.push_back(m.pp(env, arg)); + return mk_seq1(m.m(), b.begin(), b.end(), f2f(), name); } } }; @@ -853,13 +877,18 @@ struct pdecl_manager::indexed_sort_info : public pdecl_manager::sort_info { } } - format * pp(pdecl_manager const & m) const override { + format * pp(smt2_pp_environment& env, pdecl_manager const & m) const override { + symbol s = m_decl->get_name(); + std::string name = s.str(); + if (is_smt2_quoted_symbol(s)) + name = mk_smt2_quoted_symbol(s); + if (m_indices.empty()) { - return mk_string(m.m(), m_decl->get_name().str()); + return mk_string(m.m(), name); } else { ptr_buffer b; - b.push_back(mk_string(m.m(), m_decl->get_name().str())); + b.push_back(mk_string(m.m(), name)); for (auto idx : m_indices) b.push_back(mk_unsigned(m.m(), idx)); return mk_seq1(m.m(), b.begin(), b.end(), f2f(), "_"); @@ -969,6 +998,10 @@ psort_decl * pdecl_manager::mk_psort_dt_decl(unsigned num_params, symbol const & return new (a().allocate(sizeof(psort_dt_decl))) psort_dt_decl(m_id_gen.mk(), num_params, *this, n); } +psort_decl * pdecl_manager::mk_psort_type_var_decl(symbol const & n) { + return new (a().allocate(sizeof(psort_type_var_decl))) psort_type_var_decl(m_id_gen.mk(), *this, n); +} + psort_decl * pdecl_manager::mk_psort_builtin_decl(symbol const & n, family_id fid, decl_kind k) { return new (a().allocate(sizeof(psort_builtin_decl))) psort_builtin_decl(m_id_gen.mk(), *this, n, fid, k); @@ -1072,27 +1105,10 @@ void pdecl_manager::display(std::ostream & out, sort * s) const { out << s->get_name(); } -format * pdecl_manager::pp(sort * s) const { +format * pdecl_manager::pp(smt2_pp_environment& env, sort * s) const { sort_info * info = nullptr; - if (m_sort2info.find(s, info)) { - return info->pp(*this); - } - unsigned num_params = s->get_num_parameters(); - if (s->get_family_id() != null_family_id && num_params > 0) { - // Small hack to display FP and BitVec sorts that were not explicitly referenced by the user. - unsigned i = 0; - for (i = 0; i < num_params; i++) { - if (!s->get_parameter(i).is_int()) - break; - } - if (i == num_params) { - // all parameters are integer - ptr_buffer b; - b.push_back(mk_string(m(), s->get_name().str())); - for (unsigned i = 0; i < num_params; i++) - b.push_back(mk_unsigned(m(), s->get_parameter(i).get_int())); - return mk_seq1(m(), b.begin(), b.end(), f2f(), "_"); - } - } - return mk_string(m(), s->get_name().str()); + if (m_sort2info.find(s, info)) + return info->pp(env, *this); + else + return nullptr; } diff --git a/src/cmd_context/pdecl.h b/src/cmd_context/pdecl.h index a55f782f0c6..a3005f182e1 100644 --- a/src/cmd_context/pdecl.h +++ b/src/cmd_context/pdecl.h @@ -23,6 +23,7 @@ Revision History: #include "util/dictionary.h" #include "ast/format.h" #include "ast/datatype_decl_plugin.h" +#include "ast/ast_smt2_pp.h" class pdecl_manager; @@ -61,7 +62,7 @@ class psort_inst_cache; */ class psort : public pdecl { protected: - psort_inst_cache * m_inst_cache; + psort_inst_cache* m_inst_cache; friend class pdecl_manager; psort(unsigned id, unsigned num_params):pdecl(id, num_params), m_inst_cache(nullptr) {} bool is_psort() const override { return true; } @@ -85,7 +86,7 @@ typedef ptr_hashtable psort_table; #define PSORT_DECL_VAR_PARAMS UINT_MAX -typedef enum { PSORT_BASE = 0, PSORT_USER, PSORT_BUILTIN, PSORT_DT } psort_decl_kind; +typedef enum { PSORT_BASE = 0, PSORT_USER, PSORT_BUILTIN, PSORT_DT, PSORT_TV } psort_decl_kind; class psort_decl : public pdecl { protected: @@ -122,7 +123,19 @@ class psort_user_decl : public psort_decl { sort * instantiate(pdecl_manager & m, unsigned n, sort * const * s) override; std::ostream& display(std::ostream & out) const override; }; - + +class psort_type_var_decl : public psort_decl { +protected: + friend class pdecl_manager; + psort * m_def; + psort_type_var_decl(unsigned id, pdecl_manager & m, symbol const & n); + size_t obj_size() const override { return sizeof(psort_type_var_decl); } + void finalize(pdecl_manager & m) override; +public: + sort * instantiate(pdecl_manager & m, unsigned n, sort * const * s) override; + std::ostream& display(std::ostream & out) const override; +}; + class psort_builtin_decl : public psort_decl { protected: friend class pdecl_manager; @@ -303,6 +316,7 @@ class pdecl_manager { psort_decl * mk_psort_dt_decl(unsigned num_params, symbol const & n); psort_decl * mk_psort_user_decl(unsigned num_params, symbol const & n, psort * def); psort_decl * mk_psort_builtin_decl(symbol const & n, family_id fid, decl_kind k); + psort_decl * mk_psort_type_var_decl(symbol const& n); paccessor_decl * mk_paccessor_decl(unsigned num_params, symbol const & s, ptype const & p); pconstructor_decl * mk_pconstructor_decl(unsigned num_params, symbol const & s, symbol const & r, unsigned num, paccessor_decl * const * as); pdatatype_decl * mk_pdatatype_decl(unsigned num_params, symbol const & s, unsigned num, pconstructor_decl * const * cs); @@ -333,7 +347,7 @@ class pdecl_manager { void save_info(sort * s, psort_decl * d, unsigned num_args, sort * const * args); void save_info(sort * s, psort_decl * d, unsigned num_indices, unsigned const * indices); void display(std::ostream & out, sort * s) const; - format_ns::format * pp(sort * s) const; + format_ns::format * pp(smt2_pp_environment& env, sort * s) const; }; diff --git a/src/math/dd/dd_pdd.cpp b/src/math/dd/dd_pdd.cpp index 291af3b5cb7..3ad64acfdab 100644 --- a/src/math/dd/dd_pdd.cpp +++ b/src/math/dd/dd_pdd.cpp @@ -62,6 +62,10 @@ namespace dd { init_nodes(level2var); } + void pdd_manager::set_max_num_nodes(unsigned n) { + m_max_num_nodes = n + m_level2var.size(); + } + void pdd_manager::init_nodes(unsigned_vector const& l2v) { // add dummy nodes for operations, and 0, 1 pdds. for (unsigned i = 0; i < pdd_no_op; ++i) { @@ -109,11 +113,34 @@ namespace dd { pdd pdd_manager::add(rational const& r, pdd const& b) { pdd c(mk_val(r)); return pdd(apply(c.root, b.root, pdd_add_op), this); } pdd pdd_manager::zero() { return pdd(zero_pdd, this); } pdd pdd_manager::one() { return pdd(one_pdd, this); } - - pdd pdd_manager::mk_or(pdd const& p, pdd const& q) { return p + q - (p*q); } - pdd pdd_manager::mk_xor(pdd const& p, pdd const& q) { if (m_semantics == mod2_e) return p + q; return (p*q*2) - p - q; } - pdd pdd_manager::mk_xor(pdd const& p, unsigned x) { pdd q(mk_val(x)); if (m_semantics == mod2_e) return p + q; return (p*q*2) - p - q; } - pdd pdd_manager::mk_not(pdd const& p) { return 1 - p; } + + // NOTE: bit-wise AND cannot be expressed in mod2N_e semantics with the existing operations. + pdd pdd_manager::mk_and(pdd const& p, pdd const& q) { + VERIFY(m_semantics == mod2_e || m_semantics == zero_one_vars_e); + return p * q; + } + + pdd pdd_manager::mk_or(pdd const& p, pdd const& q) { + return p + q - mk_and(p, q); + } + + pdd pdd_manager::mk_xor(pdd const& p, pdd const& q) { + if (m_semantics == mod2_e) + return p + q; + return p + q - 2*mk_and(p, q); + } + + pdd pdd_manager::mk_xor(pdd const& p, unsigned x) { + pdd q(mk_val(x)); + return mk_xor(p, q); + } + + pdd pdd_manager::mk_not(pdd const& p) { + if (m_semantics == mod2N_e) + return -p - 1; + VERIFY(m_semantics == mod2_e || m_semantics == zero_one_vars_e); + return 1 - p; + } pdd pdd_manager::subst_val(pdd const& p, unsigned v, rational const& val) { pdd r = mk_var(v) + val; @@ -169,15 +196,8 @@ namespace dd { if (m_semantics != mod2N_e) return 0; - if (is_val(p)) { - rational v = val(p); - if (v.is_zero()) - return m_power_of_2 + 1; - unsigned r = 0; - while (v.is_even() && v > 0) - r++, v /= 2; - return r; - } + if (is_val(p)) + return val(p).parity(m_power_of_2); init_mark(); PDD q = p; m_todo.push_back(hi(q)); @@ -185,9 +205,9 @@ namespace dd { q = lo(q); m_todo.push_back(hi(q)); } - unsigned p2 = val(q).trailing_zeros(); + unsigned parity = val(q).parity(m_power_of_2); init_mark(); - while (p2 != 0 && !m_todo.empty()) { + while (parity != 0 && !m_todo.empty()) { PDD r = m_todo.back(); m_todo.pop_back(); if (is_marked(r)) @@ -199,11 +219,11 @@ namespace dd { } else if (val(r).is_zero()) continue; - else if (val(r).trailing_zeros() < p2) - p2 = val(r).trailing_zeros(); + else + parity = std::min(parity, val(r).trailing_zeros()); } m_todo.reset(); - return p2; + return parity; } pdd pdd_manager::subst_val(pdd const& p, pdd const& s) { @@ -242,7 +262,7 @@ namespace dd { } pdd_manager::PDD pdd_manager::apply(PDD arg1, PDD arg2, pdd_op op) { - bool first = true; + unsigned count = 0; SASSERT(well_formed()); scoped_push _sp(*this); while (true) { @@ -251,8 +271,9 @@ namespace dd { } catch (const mem_out &) { try_gc(); - if (!first) throw; - first = false; + if (count > 0) + m_max_num_nodes *= 2; + count++; } } SASSERT(well_formed()); @@ -503,7 +524,7 @@ namespace dd { if (m_semantics == mod2_e) { return a; } - bool first = true; + unsigned count = 0; SASSERT(well_formed()); scoped_push _sp(*this); while (true) { @@ -512,8 +533,9 @@ namespace dd { } catch (const mem_out &) { try_gc(); - if (!first) throw; - first = false; + if (count > 0) + m_max_num_nodes *= 2; + ++count; } } SASSERT(well_formed()); @@ -561,7 +583,7 @@ namespace dd { return true; } SASSERT(c.is_int()); - bool first = true; + unsigned count = 0; SASSERT(well_formed()); scoped_push _sp(*this); while (true) { @@ -574,8 +596,9 @@ namespace dd { } catch (const mem_out &) { try_gc(); - if (!first) throw; - first = false; + if (count > 0) + m_max_num_nodes *= 2; + ++count; } } } @@ -1134,6 +1157,7 @@ namespace dd { unsigned pdd_manager::max_pow2_divisor(PDD p) { init_mark(); unsigned min_j = UINT_MAX; + SASSERT(m_todo.empty()); m_todo.push_back(p); while (!m_todo.empty()) { PDD r = m_todo.back(); @@ -1352,9 +1376,8 @@ namespace dd { e->get_data().m_refcount = 0; } if (do_gc) { - if (m_nodes.size() > m_max_num_nodes) { - throw mem_out(); - } + if (m_nodes.size() > m_max_num_nodes) + throw mem_out(); alloc_free_nodes(m_nodes.size()/2); } SASSERT(e->get_data().m_lo == n.m_lo); @@ -1600,7 +1623,8 @@ namespace dd { for (unsigned i = m_nodes.size(); i-- > pdd_no_op; ) { if (!reachable[i]) { if (is_val(i)) { - if (m_freeze_value == val(i)) continue; + if (m_freeze_value == val(i)) + continue; m_free_values.push_back(m_mpq_table.find(val(i)).m_value_index); m_mpq_table.remove(val(i)); } @@ -1615,20 +1639,17 @@ namespace dd { ptr_vector to_delete, to_keep; for (auto* e : m_op_cache) { - if (e->m_result != null_pdd) { - to_delete.push_back(e); - } - else { - to_keep.push_back(e); - } + if (e->m_result != null_pdd) + to_delete.push_back(e); + else + to_keep.push_back(e); } m_op_cache.reset(); - for (op_entry* e : to_delete) { + for (op_entry* e : to_delete) m_alloc.deallocate(sizeof(*e), e); - } - for (op_entry* e : to_keep) { - m_op_cache.insert(e); - } + + for (op_entry* e : to_keep) + m_op_cache.insert(e); m_factor_cache.reset(); @@ -1784,27 +1805,44 @@ namespace dd { } pdd& pdd::operator=(pdd const& other) { + if (m != other.m) { + verbose_stream() << "pdd manager confusion: " << *this << " (mod 2^" << power_of_2() << ") := " << other << " (mod 2^" << other.power_of_2() << ")\n"; + UNREACHABLE(); + // TODO: in the end, this operator should probably be changed to also update the manager. But for now I want to detect such confusions. + reset(*other.m); + } + SASSERT_EQ(power_of_2(), other.power_of_2()); + VERIFY_EQ(power_of_2(), other.power_of_2()); + VERIFY_EQ(m, other.m); unsigned r1 = root; root = other.root; - m.inc_ref(root); - m.dec_ref(r1); + m->inc_ref(root); + m->dec_ref(r1); return *this; } pdd& pdd::operator=(unsigned k) { - m.dec_ref(root); - root = m.mk_val(k).root; - m.inc_ref(root); + m->dec_ref(root); + root = m->mk_val(k).root; + m->inc_ref(root); return *this; } pdd& pdd::operator=(rational const& k) { - m.dec_ref(root); - root = m.mk_val(k).root; - m.inc_ref(root); + m->dec_ref(root); + root = m->mk_val(k).root; + m->inc_ref(root); return *this; } + /* Reset pdd to 0. Allows re-assigning the pdd manager. */ + void pdd::reset(pdd_manager& new_m) { + m->dec_ref(root); + root = 0; + m = &new_m; + SASSERT(is_zero()); + } + rational const& pdd::leading_coefficient() const { pdd p = *this; while (!p.is_val()) @@ -1812,11 +1850,10 @@ namespace dd { return p.val(); } - rational const& pdd::offset() const { - pdd p = *this; - while (!p.is_val()) - p = p.lo(); - return p.val(); + rational const& pdd_manager::offset(PDD p) const { + while (!is_val(p)) + p = lo(p); + return val(p); } pdd pdd::shl(unsigned n) const { @@ -1830,7 +1867,7 @@ namespace dd { pdd pdd::subst_pdd(unsigned v, pdd const& r) const { if (is_val()) return *this; - if (m.m_var2level[var()] < m.m_var2level[v]) + if (m->m_var2level[var()] < m->m_var2level[v]) return *this; pdd l = lo().subst_pdd(v, r); pdd h = hi().subst_pdd(v, r); @@ -1839,7 +1876,7 @@ namespace dd { else if (l == lo() && h == hi()) return *this; else - return m.mk_var(var())*h + l; + return m->mk_var(var())*h + l; } std::pair pdd::var_factors() const { @@ -1870,7 +1907,7 @@ namespace dd { ++i; ++j; } - else if (m.m_var2level[lo_vars[i]] > m.m_var2level[hi_vars[j]]) + else if (m->m_var2level[lo_vars[i]] > m->m_var2level[hi_vars[j]]) hi_vars[jr++] = hi_vars[j++]; else lo_vars[ir++] = lo_vars[i++]; @@ -1881,7 +1918,7 @@ namespace dd { auto mul = [&](unsigned_vector const& vars, pdd p) { for (auto v : vars) - p *= m.mk_var(v); + p *= m->mk_var(v); return p; }; @@ -1907,7 +1944,7 @@ namespace dd { std::ostream& operator<<(std::ostream& out, pdd const& b) { return b.display(out); } void pdd_iterator::next() { - auto& m = m_pdd.m; + auto& m = m_pdd.manager(); while (!m_nodes.empty()) { auto& p = m_nodes.back(); if (p.first && !m.is_val(p.second)) { @@ -1934,13 +1971,16 @@ namespace dd { void pdd_iterator::first() { unsigned n = m_pdd.root; - auto& m = m_pdd.m; + auto& m = m_pdd.manager(); while (!m.is_val(n)) { m_nodes.push_back(std::make_pair(true, n)); m_mono.vars.push_back(m.var(n)); n = m.hi(n); } m_mono.coeff = m.val(n); + // if m_pdd is constant and non-zero, the iterator should return a single monomial + if (m_nodes.empty() && !m_mono.coeff.is_zero()) + m_nodes.push_back(std::make_pair(false, n)); } pdd_iterator pdd::begin() const { return pdd_iterator(*this, true); } @@ -1959,5 +1999,32 @@ namespace dd { return out; } + void pdd_linear_iterator::first() { + m_next = m_pdd.root; + next(); + } + + void pdd_linear_iterator::next() { + SASSERT(m_next != pdd_manager::null_pdd); + auto& m = m_pdd.manager(); + while (!m.is_val(m_next)) { + unsigned const var = m.var(m_next); + rational const val = m.offset(m.hi(m_next)); + m_next = m.lo(m_next); + if (!val.is_zero()) { + m_mono = {val, var}; + return; + } + } + m_next = pdd_manager::null_pdd; + } + + pdd_linear_iterator pdd::pdd_linear_monomials::begin() const { + return pdd_linear_iterator(m_pdd, true); + } + + pdd_linear_iterator pdd::pdd_linear_monomials::end() const { + return pdd_linear_iterator(m_pdd, false); + } -} +} // namespace dd diff --git a/src/math/dd/dd_pdd.h b/src/math/dd/dd_pdd.h index 6dee7977ff6..fa545c06ec0 100644 --- a/src/math/dd/dd_pdd.h +++ b/src/math/dd/dd_pdd.h @@ -45,6 +45,7 @@ namespace dd { class pdd; class pdd_manager; class pdd_iterator; + class pdd_linear_iterator; class pdd_manager { public: @@ -53,13 +54,14 @@ namespace dd { friend test; friend pdd; friend pdd_iterator; + friend pdd_linear_iterator; typedef unsigned PDD; typedef vector> monomials_t; - const PDD null_pdd = UINT_MAX; - const PDD zero_pdd = 0; - const PDD one_pdd = 1; + static constexpr PDD null_pdd = UINT_MAX; + static constexpr PDD zero_pdd = 0; + static constexpr PDD one_pdd = 1; enum pdd_op { pdd_add_op = 2, @@ -261,6 +263,7 @@ namespace dd { inline PDD lo(PDD p) const { return m_nodes[p].m_lo; } inline PDD hi(PDD p) const { return m_nodes[p].m_hi; } inline rational const& val(PDD p) const { SASSERT(is_val(p)); return m_values[lo(p)]; } + inline rational get_signed_val(PDD p) const { SASSERT(m_semantics == mod2_e || m_semantics == mod2N_e); rational const& a = val(p); return a.get_bit(power_of_2() - 1) ? a - two_to_N() : a; } inline void inc_ref(PDD p) { if (m_nodes[p].m_refcount != max_rc) m_nodes[p].m_refcount++; SASSERT(!m_free_nodes.contains(p)); } inline void dec_ref(PDD p) { if (m_nodes[p].m_refcount != max_rc) m_nodes[p].m_refcount--; SASSERT(!m_free_nodes.contains(p)); } inline PDD level2pdd(unsigned l) const { return m_var2pdd[m_level2var[l]]; } @@ -324,8 +327,9 @@ namespace dd { semantics get_semantics() const { return m_semantics; } void reset(unsigned_vector const& level2var); - void set_max_num_nodes(unsigned n) { m_max_num_nodes = n; } + void set_max_num_nodes(unsigned n); unsigned_vector const& get_level2var() const { return m_level2var; } + unsigned num_nodes() const { return m_nodes.size() - m_free_nodes.size(); } pdd mk_var(unsigned i); pdd mk_val(rational const& r); @@ -340,9 +344,10 @@ namespace dd { pdd mul(rational const& c, pdd const& b); pdd div(pdd const& a, rational const& c); bool try_div(pdd const& a, rational const& c, pdd& out_result); + pdd mk_and(pdd const& p, pdd const& q); pdd mk_or(pdd const& p, pdd const& q); pdd mk_xor(pdd const& p, pdd const& q); - pdd mk_xor(pdd const& p, unsigned q); + pdd mk_xor(pdd const& p, unsigned x); pdd mk_not(pdd const& p); pdd reduce(pdd const& a, pdd const& b); pdd subst_val0(pdd const& a, vector> const& s); @@ -367,6 +372,8 @@ namespace dd { bool is_univariate_in(PDD p, unsigned v); void get_univariate_coefficients(PDD p, vector& coeff); + rational const& offset(PDD p) const; + // create an spoly r if leading monomials of a and b overlap bool try_spoly(pdd const& a, pdd const& b, pdd& r); @@ -398,106 +405,120 @@ namespace dd { friend test; friend class pdd_manager; friend class pdd_iterator; + friend class pdd_linear_iterator; unsigned root; - pdd_manager& m; - pdd(unsigned root, pdd_manager& m): root(root), m(m) { m.inc_ref(root); } - pdd(unsigned root, pdd_manager* _m): root(root), m(*_m) { m.inc_ref(root); } + pdd_manager* m; + pdd(unsigned root, pdd_manager& pm): root(root), m(&pm) { m->inc_ref(root); } + pdd(unsigned root, pdd_manager* pm): root(root), m(pm) { m->inc_ref(root); } public: - pdd(pdd_manager& pm): root(0), m(pm) { SASSERT(is_zero()); } - pdd(pdd const& other): root(other.root), m(other.m) { m.inc_ref(root); } - pdd(pdd && other) noexcept : root(0), m(other.m) { std::swap(root, other.root); } + pdd(pdd_manager& m): pdd(0, m) { SASSERT(is_zero()); } + pdd(pdd const& other): pdd(other.root, other.m) { m->inc_ref(root); } + pdd(pdd && other) noexcept : pdd(0, other.m) { std::swap(root, other.root); } pdd& operator=(pdd const& other); pdd& operator=(unsigned k); pdd& operator=(rational const& k); - ~pdd() { m.dec_ref(root); } - pdd lo() const { return pdd(m.lo(root), m); } - pdd hi() const { return pdd(m.hi(root), m); } + // TODO: pdd& operator=(pdd&& other); (just swap like move constructor?) + ~pdd() { m->dec_ref(root); } + void reset(pdd_manager& new_m); + pdd lo() const { return pdd(m->lo(root), m); } + pdd hi() const { return pdd(m->hi(root), m); } unsigned index() const { return root; } - unsigned var() const { return m.var(root); } - rational const& val() const { SASSERT(is_val()); return m.val(root); } + unsigned var() const { return m->var(root); } + rational const& val() const { return m->val(root); } + rational get_signed_val() const { return m->get_signed_val(root); } rational const& leading_coefficient() const; - rational const& offset() const; - bool is_val() const { return m.is_val(root); } - bool is_one() const { return m.is_one(root); } - bool is_zero() const { return m.is_zero(root); } - bool is_linear() const { return m.is_linear(root); } - bool is_var() const { return m.is_var(root); } - bool is_max() const { return m.is_max(root); } + rational const& offset() const { return m->offset(root); } + bool is_val() const { return m->is_val(root); } + bool is_one() const { return m->is_one(root); } + bool is_zero() const { return m->is_zero(root); } + bool is_linear() const { return m->is_linear(root); } + bool is_var() const { return m->is_var(root); } + bool is_max() const { return m->is_max(root); } /** Polynomial is of the form a * x + b for some numerals a, b. */ bool is_unilinear() const { return !is_val() && lo().is_val() && hi().is_val(); } /** Polynomial is of the form a * x for some numeral a. */ bool is_unary() const { return !is_val() && lo().is_zero() && hi().is_val(); } bool is_offset() const { return !is_val() && lo().is_val() && hi().is_one(); } - bool is_binary() const { return m.is_binary(root); } - bool is_monomial() const { return m.is_monomial(root); } - bool is_univariate() const { return m.is_univariate(root); } - bool is_univariate_in(unsigned v) const { return m.is_univariate_in(root, v); } - void get_univariate_coefficients(vector& coeff) const { m.get_univariate_coefficients(root, coeff); } - vector get_univariate_coefficients() const { vector coeff; m.get_univariate_coefficients(root, coeff); return coeff; } - bool is_never_zero() const { return m.is_never_zero(root); } - unsigned min_parity() const { return m.min_parity(root); } - bool var_is_leaf(unsigned v) const { return m.var_is_leaf(root, v); } - - pdd operator-() const { return m.minus(*this); } - pdd operator+(pdd const& other) const { return m.add(*this, other); } - pdd operator-(pdd const& other) const { return m.sub(*this, other); } - pdd operator*(pdd const& other) const { return m.mul(*this, other); } - pdd operator&(pdd const& other) const { return m.mul(*this, other); } - pdd operator|(pdd const& other) const { return m.mk_or(*this, other); } - pdd operator^(pdd const& other) const { return m.mk_xor(*this, other); } - pdd operator^(unsigned other) const { return m.mk_xor(*this, other); } - - pdd operator*(rational const& other) const { return m.mul(other, *this); } - pdd operator+(rational const& other) const { return m.add(other, *this); } - pdd operator~() const { return m.mk_not(*this); } + bool is_binary() const { return m->is_binary(root); } + bool is_monomial() const { return m->is_monomial(root); } + bool is_univariate() const { return m->is_univariate(root); } + bool is_univariate_in(unsigned v) const { return m->is_univariate_in(root, v); } + void get_univariate_coefficients(vector& coeff) const { m->get_univariate_coefficients(root, coeff); } + vector get_univariate_coefficients() const { vector coeff; m->get_univariate_coefficients(root, coeff); return coeff; } + bool is_never_zero() const { return m->is_never_zero(root); } + unsigned min_parity() const { return m->min_parity(root); } + bool var_is_leaf(unsigned v) const { return m->var_is_leaf(root, v); } + + pdd operator-() const { return m->minus(*this); } + pdd operator+(pdd const& other) const { VERIFY_EQ(m, other.m); return m->add(*this, other); } + pdd operator-(pdd const& other) const { VERIFY_EQ(m, other.m); return m->sub(*this, other); } + pdd operator*(pdd const& other) const { VERIFY_EQ(m, other.m); return m->mul(*this, other); } + pdd operator&(pdd const& other) const { VERIFY_EQ(m, other.m); return m->mk_and(*this, other); } + pdd operator|(pdd const& other) const { VERIFY_EQ(m, other.m); return m->mk_or(*this, other); } + pdd operator^(pdd const& other) const { VERIFY_EQ(m, other.m); return m->mk_xor(*this, other); } + pdd operator^(unsigned other) const { return m->mk_xor(*this, m->mk_val(other)); } + + pdd operator*(rational const& other) const { return m->mul(other, *this); } + pdd operator+(rational const& other) const { return m->add(other, *this); } + pdd operator~() const { return m->mk_not(*this); } pdd shl(unsigned n) const; - pdd rev_sub(rational const& r) const { return m.sub(m.mk_val(r), *this); } - pdd div(rational const& other) const { return m.div(*this, other); } - bool try_div(rational const& other, pdd& out_result) const { return m.try_div(*this, other, out_result); } - pdd pow(unsigned j) const { return m.pow(*this, j); } - pdd reduce(pdd const& other) const { return m.reduce(*this, other); } - bool different_leading_term(pdd const& other) const { return m.different_leading_term(*this, other); } - void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) const { m.factor(*this, v, degree, lc, rest); } - bool factor(unsigned v, unsigned degree, pdd& lc) const { return m.factor(*this, v, degree, lc); } - bool resolve(unsigned v, pdd const& other, pdd& result) { return m.resolve(v, *this, other, result); } - pdd reduce(unsigned v, pdd const& other) const { return m.reduce(v, *this, other); } + pdd rev_sub(rational const& r) const { return m->sub(m->mk_val(r), *this); } + pdd div(rational const& other) const { return m->div(*this, other); } + bool try_div(rational const& other, pdd& out_result) const { VERIFY_EQ(m, out_result.m); return m->try_div(*this, other, out_result); } + pdd pow(unsigned j) const { return m->pow(*this, j); } + pdd reduce(pdd const& other) const { VERIFY_EQ(m, other.m); return m->reduce(*this, other); } + bool different_leading_term(pdd const& other) const { VERIFY_EQ(m, other.m); return m->different_leading_term(*this, other); } + void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) const { VERIFY_EQ(m, lc.m); VERIFY_EQ(m, rest.m); m->factor(*this, v, degree, lc, rest); } + bool factor(unsigned v, unsigned degree, pdd& lc) const { VERIFY_EQ(m, lc.m); return m->factor(*this, v, degree, lc); } + bool resolve(unsigned v, pdd const& other, pdd& result) { VERIFY_EQ(m, other.m); VERIFY_EQ(m, result.m); return m->resolve(v, *this, other, result); } + pdd reduce(unsigned v, pdd const& other) const { VERIFY_EQ(m, other.m); return m->reduce(v, *this, other); } /** * \brief factor out variables */ std::pair var_factors() const; - pdd subst_val0(vector> const& s) const { return m.subst_val0(*this, s); } - pdd subst_val(pdd const& s) const { return m.subst_val(*this, s); } - pdd subst_val(unsigned v, rational const& val) const { return m.subst_val(*this, v, val); } - pdd subst_add(unsigned var, rational const& val) const { return m.subst_add(*this, var, val); } - bool subst_get(unsigned var, rational& out_val) const { return m.subst_get(*this, var, out_val); } + pdd subst_val0(vector> const& s) const { return m->subst_val0(*this, s); } + pdd subst_val(pdd const& s) const { VERIFY_EQ(m, s.m); return m->subst_val(*this, s); } + pdd subst_val(unsigned v, rational const& val) const { return m->subst_val(*this, v, val); } + pdd subst_add(unsigned var, rational const& val) const { return m->subst_add(*this, var, val); } + bool subst_get(unsigned var, rational& out_val) const { return m->subst_get(*this, var, out_val); } /** * \brief substitute variable v by r. */ pdd subst_pdd(unsigned v, pdd const& r) const; - std::ostream& display(std::ostream& out) const { return m.display(out, *this); } - bool operator==(pdd const& other) const { return root == other.root; } - bool operator!=(pdd const& other) const { return root != other.root; } + std::ostream& display(std::ostream& out) const { return m->display(out, *this); } + bool operator==(pdd const& other) const { return root == other.root && m == other.m; } + bool operator!=(pdd const& other) const { return !operator==(other); } unsigned hash() const { return root; } - unsigned power_of_2() const { return m.power_of_2(); } + unsigned power_of_2() const { return m->power_of_2(); } - unsigned dag_size() const { return m.dag_size(*this); } - double tree_size() const { return m.tree_size(*this); } - unsigned degree() const { return m.degree(*this); } - unsigned degree(unsigned v) const { return m.degree(root, v); } - unsigned max_pow2_divisor() const { return m.max_pow2_divisor(root); } - unsigned_vector const& free_vars() const { return m.free_vars(*this); } + unsigned dag_size() const { return m->dag_size(*this); } + double tree_size() const { return m->tree_size(*this); } + unsigned degree() const { return m->degree(*this); } + unsigned degree(unsigned v) const { return m->degree(root, v); } + unsigned max_pow2_divisor() const { return m->max_pow2_divisor(root); } + unsigned_vector const& free_vars() const { return m->free_vars(*this); } - void swap(pdd& other) { std::swap(root, other.root); } + void swap(pdd& other) noexcept { VERIFY_EQ(m, other.m); std::swap(root, other.root); } pdd_iterator begin() const; pdd_iterator end() const; - pdd_manager& manager() const { return m; } + class pdd_linear_monomials { + friend class pdd; + pdd const& m_pdd; + pdd_linear_monomials(pdd const& p): m_pdd(p) {} + public: + pdd_linear_iterator begin() const; + pdd_linear_iterator end() const; + }; + pdd_linear_monomials linear_monomials() const { return pdd_linear_monomials(*this); } + + pdd_manager& manager() const { return *m; } }; inline pdd operator*(rational const& r, pdd const& b) { return b * r; } @@ -526,7 +547,7 @@ namespace dd { inline pdd& operator-=(pdd & p, rational const& q) { p = p - q; return p; } inline pdd& operator+=(pdd & p, rational const& q) { p = p + q; return p; } - inline void swap(pdd& p, pdd& q) { p.swap(q); } + inline void swap(pdd& p, pdd& q) noexcept { p.swap(q); } std::ostream& operator<<(std::ostream& out, pdd const& b); @@ -551,7 +572,27 @@ namespace dd { pdd_iterator& operator++() { next(); return *this; } pdd_iterator operator++(int) { auto tmp = *this; next(); return tmp; } bool operator==(pdd_iterator const& other) const { return m_nodes == other.m_nodes; } - bool operator!=(pdd_iterator const& other) const { return m_nodes != other.m_nodes; } + bool operator!=(pdd_iterator const& other) const { return !operator==(other); } + }; + + class pdd_linear_iterator { + friend class pdd::pdd_linear_monomials; + pdd m_pdd; + std::pair m_mono; + pdd_manager::PDD m_next = pdd_manager::null_pdd; + pdd_linear_iterator(pdd const& p, bool at_start): m_pdd(p) { if (at_start) first(); } + void first(); + void next(); + public: + using value_type = std::pair; // coefficient and variable + using reference = value_type const&; + using pointer = value_type const*; + reference operator*() const { return m_mono; } + pointer operator->() const { return &m_mono; } + pdd_linear_iterator& operator++() { next(); return *this; } + pdd_linear_iterator operator++(int) { auto tmp = *this; next(); return tmp; } + bool operator==(pdd_linear_iterator const& other) const { return m_next == other.m_next; } + bool operator!=(pdd_linear_iterator const& other) const { return m_next != other.m_next; } }; class val_pp { diff --git a/src/math/grobner/pdd_simplifier.cpp b/src/math/grobner/pdd_simplifier.cpp index 4a81406818d..6364d9ae62e 100644 --- a/src/math/grobner/pdd_simplifier.cpp +++ b/src/math/grobner/pdd_simplifier.cpp @@ -75,7 +75,7 @@ namespace dd { } } catch (pdd_manager::mem_out) { - IF_VERBOSE(2, verbose_stream() << "simplifier memout\n"); + IF_VERBOSE(1, verbose_stream() << "simplifier memout\n"); // done reduce DEBUG_CODE(s.invariant();); } @@ -89,12 +89,13 @@ namespace dd { bool simplifier::simplify_linear_step(bool binary) { TRACE("dd.solver", tout << "binary " << binary << "\n";); - IF_VERBOSE(2, verbose_stream() << "binary " << binary << "\n"); + IF_VERBOSE(3, verbose_stream() << "binary " << binary << "\n"); equation_vector linear; for (equation* e : s.m_to_simplify) { pdd p = e->poly(); if (binary) { - if (p.is_binary()) linear.push_back(e); + if (p.is_binary()) + linear.push_back(e); } else if (p.is_linear()) { linear.push_back(e); @@ -112,29 +113,33 @@ namespace dd { use_list_t use_list = get_use_list(); compare_top_var ctv; std::stable_sort(linear.begin(), linear.end(), ctv); - equation_vector trivial; + struct trivial { + solver& s; + equation_vector elems; + trivial(solver& s) : s(s) {} + ~trivial () { + for (auto* e : elems) + s.del_equation(e); + } + }; + trivial trivial(s); unsigned j = 0; bool has_conflict = false; for (equation* src : linear) { - if (has_conflict) { - break; - } - if (s.is_trivial(*src)) { - continue; - } + if (has_conflict) + break; + if (s.is_trivial(*src)) + continue; unsigned v = src->poly().var(); equation_vector const& uses = use_list[v]; TRACE("dd.solver", s.display(tout << "uses of: ", *src) << "\n"; - for (equation* e : uses) { - s.display(tout, *e) << "\n"; - }); + for (equation* e : uses) s.display(tout, *e) << "\n";); bool changed_leading_term; bool all_reduced = true; for (equation* dst : uses) { - if (src == dst || s.is_trivial(*dst)) { - continue; - } + if (src == dst || s.is_trivial(*dst)) + continue; pdd q = dst->poly(); if (!src->poly().is_binary() && !q.is_linear()) { all_reduced = false; @@ -142,9 +147,8 @@ namespace dd { } remove_from_use(dst, use_list, v); s.simplify_using(*dst, *src, changed_leading_term); - if (s.is_trivial(*dst)) { - trivial.push_back(dst); - } + if (s.is_trivial(*dst)) + trivial.elems.push_back(dst); else if (s.is_conflict(dst)) { s.pop_equation(dst); s.set_conflict(dst); @@ -158,9 +162,8 @@ namespace dd { // SASSERT(!dst->poly().free_vars().contains(v)); add_to_use(dst, use_list); } - if (all_reduced) { - linear[j++] = src; - } + if (all_reduced) + linear[j++] = src; } if (!has_conflict) { linear.shrink(j); @@ -169,9 +172,7 @@ namespace dd { s.push_equation(solver::solved, src); } } - for (equation* e : trivial) { - s.del_equation(e); - } + DEBUG_CODE(s.invariant();); return j > 0 || has_conflict; } @@ -184,11 +185,12 @@ namespace dd { */ bool simplifier::simplify_cc_step() { TRACE("dd.solver", tout << "cc\n";); - IF_VERBOSE(2, verbose_stream() << "cc\n"); + IF_VERBOSE(3, verbose_stream() << "cc\n"); u_map los; bool reduced = false; - unsigned j = 0; - for (equation* eq1 : s.m_to_simplify) { + solver::scoped_update sc(s.m_to_simplify); + for (; sc.i < sc.sz; ++sc.i) { + auto* eq1 = sc.get(); SASSERT(eq1->state() == solver::to_simplify); pdd p = eq1->poly(); equation* eq2 = los.insert_if_not_there(p.lo().index(), eq1); @@ -201,14 +203,11 @@ namespace dd { s.retire(eq1); continue; } - else if (s.check_conflict(*eq1)) { - continue; - } + else if (s.check_conflict(*eq1)) + continue; } - s.m_to_simplify[j] = eq1; - eq1->set_index(j++); + sc.nextj(); } - s.m_to_simplify.shrink(j); return reduced; } @@ -217,7 +216,7 @@ namespace dd { */ bool simplifier::simplify_leaf_step() { TRACE("dd.solver", tout << "leaf\n";); - IF_VERBOSE(2, verbose_stream() << "leaf\n"); + IF_VERBOSE(3, verbose_stream() << "leaf\n"); use_list_t use_list = get_use_list(); equation_vector leaves; for (unsigned i = 0; i < s.m_to_simplify.size(); ++i) { @@ -225,15 +224,12 @@ namespace dd { pdd p = e->poly(); if (p.is_val()) continue; - if (!p.hi().is_val()) { - continue; - } + if (!p.hi().is_val()) + continue; leaves.reset(); - for (equation* e2 : use_list[p.var()]) { - if (e != e2 && e2->poly().var_is_leaf(p.var())) { - leaves.push_back(e2); - } - } + for (equation* e2 : use_list[p.var()]) + if (e != e2 && e2->poly().var_is_leaf(p.var())) + leaves.push_back(e2); for (equation* e2 : leaves) { bool changed_leading_term; remove_from_use(e2, use_list); @@ -262,24 +258,21 @@ namespace dd { */ bool simplifier::simplify_elim_pure_step() { TRACE("dd.solver", tout << "pure\n";); - IF_VERBOSE(2, verbose_stream() << "pure\n"); - use_list_t use_list = get_use_list(); - unsigned j = 0; - for (equation* e : s.m_to_simplify) { + IF_VERBOSE(3, verbose_stream() << "pure\n"); + use_list_t use_list = get_use_list(); + solver::scoped_update sc(s.m_to_simplify); + bool has_solved = false; + for (; sc.i < sc.sz; ++sc.i) { + equation* e = sc.get(); pdd p = e->poly(); if (!p.is_val() && p.hi().is_val() && use_list[p.var()].size() == 1) { s.push_equation(solver::solved, e); + has_solved = true; } - else { - s.m_to_simplify[j] = e; - e->set_index(j++); - } + else + sc.nextj(); } - if (j != s.m_to_simplify.size()) { - s.m_to_simplify.shrink(j); - return true; - } - return false; + return has_solved; } /** @@ -288,63 +281,59 @@ namespace dd { */ bool simplifier::simplify_elim_dual_step() { use_list_t use_list = get_use_list(); - unsigned j = 0; bool reduced = false; - for (unsigned i = 0; i < s.m_to_simplify.size(); ++i) { - equation* e = s.m_to_simplify[i]; - pdd p = e->poly(); - // check that e is linear in top variable. - if (e->state() != solver::to_simplify) { - reduced = true; - } - else if (!s.done() && !s.is_trivial(*e) && p.hi().is_val() && use_list[p.var()].size() == 2) { - for (equation* e2 : use_list[p.var()]) { - if (e2 == e) continue; - bool changed_leading_term; - - remove_from_use(e2, use_list); - s.simplify_using(*e2, *e, changed_leading_term); - if (s.is_conflict(e2)) { - s.pop_equation(e2); - s.set_conflict(e2); - } - // when e2 is trivial, leading term is changed - SASSERT(!s.is_trivial(*e2) || changed_leading_term); - if (changed_leading_term) { - s.pop_equation(e2); - s.push_equation(solver::to_simplify, e2); + { + solver::scoped_update sc(s.m_to_simplify); + for (; sc.i < sc.sz; ++sc.i) { + equation* e = sc.get(); + pdd p = e->poly(); + // check that e is linear in top variable. + if (e->state() != solver::to_simplify) { + reduced = true; + } + else if (!s.done() && !s.is_trivial(*e) && p.hi().is_val() && use_list[p.var()].size() == 2) { + for (equation* e2 : use_list[p.var()]) { + if (e2 == e) + continue; + bool changed_leading_term; + + remove_from_use(e2, use_list); + s.simplify_using(*e2, *e, changed_leading_term); + if (s.is_conflict(e2)) { + s.pop_equation(e2); + s.set_conflict(e2); + } + // when e2 is trivial, leading term is changed + SASSERT(!s.is_trivial(*e2) || changed_leading_term); + if (changed_leading_term) { + s.pop_equation(e2); + s.push_equation(solver::to_simplify, e2); + } + add_to_use(e2, use_list); + break; } - add_to_use(e2, use_list); - break; + reduced = true; + s.push_equation(solver::solved, e); } - reduced = true; - s.push_equation(solver::solved, e); - } - else { - s.m_to_simplify[j] = e; - e->set_index(j++); + else + sc.nextj(); } } if (reduced) { // clean up elements in s.m_to_simplify // they may have moved. - s.m_to_simplify.shrink(j); - j = 0; - for (equation* e : s.m_to_simplify) { - if (s.is_trivial(*e)) { - s.retire(e); - } - else if (e->state() == solver::to_simplify) { - s.m_to_simplify[j] = e; - e->set_index(j++); - } + solver::scoped_update sc(s.m_to_simplify); + for (; sc.i < sc.sz; ++sc.i) { + equation* e = sc.get(); + if (s.is_trivial(*e)) + s.retire(e); + else if (e->state() == solver::to_simplify) + sc.nextj(); } - s.m_to_simplify.shrink(j); return true; } - else { - return false; - } + else + return false; } void simplifier::add_to_use(equation* e, use_list_t& use_list) { diff --git a/src/math/grobner/pdd_solver.cpp b/src/math/grobner/pdd_solver.cpp index 63c5ad83545..10f1eb88ffa 100644 --- a/src/math/grobner/pdd_solver.cpp +++ b/src/math/grobner/pdd_solver.cpp @@ -59,9 +59,10 @@ namespace dd { */ - solver::solver(reslimit& lim, pdd_manager& m) : + solver::solver(reslimit& lim, u_dependency_manager& dm, pdd_manager& m) : m(m), - m_limit(lim) + m_limit(lim), + m_dep_manager(dm) {} solver::~solver() { @@ -89,11 +90,9 @@ namespace dd { } void solver::saturate() { - simplify(); - if (done()) { - return; - } - init_saturate(); + if (done()) + return; + init_saturate(); TRACE("dd.solver", display(tout);); try { while (!done() && step()) { @@ -104,7 +103,7 @@ namespace dd { DEBUG_CODE(invariant();); } catch (pdd_manager::mem_out) { - IF_VERBOSE(2, verbose_stream() << "mem-out\n"); + IF_VERBOSE(1, verbose_stream() << "mem-out saturate\n"); // don't reduce further } } @@ -123,8 +122,7 @@ namespace dd { solver::scoped_process::~scoped_process() { if (e) { - pdd p = e->poly(); - SASSERT(!p.is_val()); + SASSERT(!e->poly().is_val()); g.push_equation(processed, e); } } @@ -136,9 +134,8 @@ namespace dd { void solver::superpose(equation const & eq) { - for (equation* target : m_processed) { - superpose(eq, *target); - } + for (equation* target : m_processed) + superpose(eq, *target); } /* @@ -165,32 +162,28 @@ namespace dd { TRACE("dd.solver", display(tout << "simplification result: ", eq);); } + void solver::well_formed() { + auto& set = m_to_simplify; + for (unsigned k = 0; k < set.size(); ++k) + for (unsigned l = k + 1; l < set.size(); ++l) { + if (!set[l] || !set[k] || set[k] != set[l]) + continue; + verbose_stream() << k << " " << l << " " << set[k] << "\n"; + for (auto* s : set) + verbose_stream() << s->idx() << "\n"; + VERIFY(set[k] != set[l]); + } + } /* Use the given equation to simplify equations in set */ - void solver::simplify_using(equation_vector& set, std::function& simplifier) { - struct scoped_update { - equation_vector& set; - unsigned i, j, sz; - scoped_update(equation_vector& set): set(set), i(0), j(0), sz(set.size()) {} - void nextj() { - set[j] = set[i]; - set[i]->set_index(j++); - } - ~scoped_update() { - for (; i < sz; ++i) - nextj(); - set.shrink(j); - } - }; - + void solver::simplify_using(equation_vector& set, std::function& simplifier) { scoped_update sr(set); for (; sr.i < sr.sz; ++sr.i) { equation& target = *set[sr.i]; bool changed_leading_term = false; bool simplified = true; simplified = !done() && simplifier(target, changed_leading_term); - if (simplified && is_trivial(target)) retire(&target); @@ -285,21 +278,32 @@ namespace dd { m_stats.m_compute_steps++; IF_VERBOSE(3, if (m_stats.m_compute_steps % 100 == 0) verbose_stream() << "compute steps = " << m_stats.m_compute_steps << "\n";); equation* e = pick_next(); - if (!e) return false; + if (!e) + return false; scoped_process sd(*this, e); equation& eq = *e; SASSERT(eq.state() == to_simplify); simplify_using(eq, m_processed); - if (is_trivial(eq)) { sd.e = nullptr; retire(e); return true; } - if (check_conflict(eq)) { sd.e = nullptr; return false; } + if (is_trivial(eq)) { + sd.e = nullptr; + retire(e); + return true; + } + if (check_conflict(eq)) { + sd.e = nullptr; + return false; + } m_too_complex = false; simplify_using(m_processed, eq); - if (done()) return false; + if (done()) + return false; TRACE("dd.solver", display(tout << "eq = ", eq);); superpose(eq); simplify_using(m_to_simplify, eq); - if (done()) return false; - if (!m_too_complex) sd.done(); + if (done()) + return false; + if (!m_too_complex) + sd.done(); return true; } @@ -344,9 +348,9 @@ namespace dd { } void solver::reset() { - for (equation* e : m_solved) dealloc(e); - for (equation* e : m_to_simplify) dealloc(e); - for (equation* e : m_processed) dealloc(e); + for (equation* e : m_solved) dealloc(e); + for (equation* e : m_to_simplify) dealloc(e); + for (equation* e : m_processed) dealloc(e); m_subst.reset(); m_solved.reset(); m_processed.reset(); @@ -444,7 +448,6 @@ namespace dd { #endif } - void solver::pop_equation(equation& eq) { equation_vector& v = get_queue(eq); unsigned idx = eq.idx(); diff --git a/src/math/grobner/pdd_solver.h b/src/math/grobner/pdd_solver.h index 40f8fdce25e..872fef5fd55 100644 --- a/src/math/grobner/pdd_solver.h +++ b/src/math/grobner/pdd_solver.h @@ -49,30 +49,17 @@ class solver { }; struct config { - unsigned m_eqs_threshold; - unsigned m_expr_size_limit; - unsigned m_expr_degree_limit; - unsigned m_max_steps; - unsigned m_max_simplified; - unsigned m_random_seed; - bool m_enable_exlin; - unsigned m_eqs_growth; - unsigned m_expr_size_growth; - unsigned m_expr_degree_growth; - unsigned m_number_of_conflicts_to_report; - config() : - m_eqs_threshold(UINT_MAX), - m_expr_size_limit(UINT_MAX), - m_expr_degree_limit(UINT_MAX), - m_max_steps(UINT_MAX), - m_max_simplified(UINT_MAX), - m_random_seed(0), - m_enable_exlin(false), - m_eqs_growth(10), - m_expr_size_growth(10), - m_expr_degree_growth(5), - m_number_of_conflicts_to_report(1) - {} + unsigned m_eqs_threshold = UINT_MAX; + unsigned m_expr_size_limit = UINT_MAX; + unsigned m_expr_degree_limit = UINT_MAX; + unsigned m_max_steps = UINT_MAX; + unsigned m_max_simplified = UINT_MAX; + unsigned m_random_seed = 0; + bool m_enable_exlin = false; + unsigned m_eqs_growth = 10; + unsigned m_expr_size_growth = 10; + unsigned m_expr_degree_growth = 5; + unsigned m_number_of_conflicts_to_report = 1; }; enum eq_state { @@ -82,18 +69,14 @@ class solver { }; class equation { - eq_state m_state; - unsigned m_idx; //!< unique index + eq_state m_state = to_simplify; + unsigned m_idx = 0; //!< unique index pdd m_poly; //!< polynomial in pdd form u_dependency * m_dep; //!< justification for the equality public: equation(pdd const& p, u_dependency* d): - m_state(to_simplify), - m_idx(0), m_poly(p), - m_dep(d) - { - + m_dep(d) { } const pdd& poly() const { return m_poly; } @@ -105,13 +88,38 @@ class solver { void set_state(eq_state st) { m_state = st; } void set_index(unsigned idx) { m_idx = idx; } }; -private: typedef ptr_vector equation_vector; + + struct scoped_update { + equation_vector& set; + unsigned i = 0; + unsigned j = 0; + unsigned sz; + scoped_update(equation_vector& set) : + set(set), sz(set.size()) { + } + ~scoped_update() { + for (; i < sz; ++i) + nextj(); + set.shrink(j); + } + equation* get() { return set[i]; } + + void nextj() { + set[j] = set[i]; + set[i]->set_index(j++); + } + }; + +private: + + typedef std::function print_dep_t; pdd_manager& m; reslimit& m_limit; + u_dependency_manager& m_dep_manager; stats m_stats; config m_config; print_dep_t m_print_dep; @@ -119,12 +127,11 @@ class solver { equation_vector m_processed; equation_vector m_to_simplify; vector> m_subst; - mutable u_dependency_manager m_dep_manager; equation_vector m_all_eqs; equation* m_conflict = nullptr; bool m_too_complex; public: - solver(reslimit& lim, pdd_manager& m); + solver(reslimit& lim, u_dependency_manager& dm, pdd_manager& m); ~solver(); pdd_manager& get_manager() { return m; } @@ -144,7 +151,6 @@ class solver { void saturate(); equation_vector const& equations(); - u_dependency_manager& dep() const { return m_dep_manager; } void collect_statistics(statistics & st) const; std::ostream& display(std::ostream& out, const equation& eq) const; @@ -192,6 +198,7 @@ class solver { void push_equation(eq_state st, equation& eq); void push_equation(eq_state st, equation* eq) { push_equation(st, *eq); } + void well_formed(); void invariant() const; struct scoped_process { solver& g; diff --git a/src/math/interval/dep_intervals.h b/src/math/interval/dep_intervals.h index d641a294d6f..f4774c1864c 100644 --- a/src/math/interval/dep_intervals.h +++ b/src/math/interval/dep_intervals.h @@ -27,6 +27,7 @@ #include "math/interval/interval.h" class dep_intervals { + public: enum with_deps_t { with_deps, without_deps }; @@ -142,8 +143,9 @@ class dep_intervals { public: typedef interval_manager::interval interval; + u_dependency_manager& m_dep_manager; mutable unsynch_mpq_manager m_num_manager; - mutable u_dependency_manager m_dep_manager; + im_config m_config; mutable interval_manager m_imanager; @@ -158,9 +160,10 @@ class dep_intervals { public: u_dependency_manager& dep_manager() { return m_dep_manager; } - dep_intervals(reslimit& lim) : - m_config(m_num_manager, m_dep_manager), - m_imanager(lim, im_config(m_num_manager, m_dep_manager)) + dep_intervals(u_dependency_manager& dm, reslimit& lim) : + m_dep_manager(dm), + m_config(m_num_manager, dm), + m_imanager(lim, im_config(m_num_manager, dm)) {} std::ostream& display(std::ostream& out, const interval& i) const; @@ -172,6 +175,8 @@ class dep_intervals { void set_upper_is_inf(interval& a, bool inf) const { m_config.set_upper_is_inf(a, inf); } void set_lower_dep(interval& a, u_dependency* d) const { m_config.set_lower_dep(a, d); } void set_upper_dep(interval& a, u_dependency* d) const { m_config.set_upper_dep(a, d); } + u_dependency* get_lower_dep(interval const& a) const { return a.m_lower_dep; } + u_dependency* get_upper_dep(interval const& a) const { return a.m_upper_dep; } void reset(interval& a) const { set_lower_is_inf(a, true); set_upper_is_inf(a, true); } void set_value(interval& a, rational const& n) const { set_lower(a, n); @@ -335,15 +340,17 @@ class dep_intervals { bool is_empty(interval const& a) const; void set_interval_for_scalar(interval&, const rational&); + template void linearize(u_dependency* dep, T& expl) const { vector v; m_dep_manager.linearize(dep, v); - for (unsigned ci: v) + for (auto ci: v) expl.push_back(ci); } - void reset() { m_dep_manager.reset(); } + + void reset() { } void del(interval& i) { m_imanager.del(i); } diff --git a/src/math/lp/.clang-format b/src/math/lp/.clang-format new file mode 100644 index 00000000000..d7f8d6171ad --- /dev/null +++ b/src/math/lp/.clang-format @@ -0,0 +1,5 @@ +BasedOnStyle: Google +IndentWidth: 4 +ColumnLimit: 0 +NamespaceIndentation: All +BreakBeforeBraces: Stroustrup \ No newline at end of file diff --git a/src/math/lp/bound_analyzer_on_row.h b/src/math/lp/bound_analyzer_on_row.h index 0008a0ee90b..20e45a2188a 100644 --- a/src/math/lp/bound_analyzer_on_row.h +++ b/src/math/lp/bound_analyzer_on_row.h @@ -26,6 +26,8 @@ Revision History: #include "math/lp/test_bound_analyzer.h" namespace lp { + + template // C plays a role of a container, B - lp_bound_propagator class bound_analyzer_on_row { const C& m_row; @@ -91,39 +93,17 @@ public : } bool bound_is_available(unsigned j, bool lower_bound) { - return (lower_bound && lower_bound_is_available(j)) || - (!lower_bound && upper_bound_is_available(j)); - } - - bool upper_bound_is_available(unsigned j) const { - switch (m_bp.get_column_type(j)) { - case column_type::fixed: - case column_type::boxed: - case column_type::upper_bound: - return true; - default: - return false; - } - } - - bool lower_bound_is_available(unsigned j) const { - switch (m_bp.get_column_type(j)) { - case column_type::fixed: - case column_type::boxed: - case column_type::lower_bound: - return true; - default: - return false; - } + return (lower_bound && m_bp.lower_bound_is_available(j)) || + (!lower_bound && m_bp.upper_bound_is_available(j)); } const impq & ub(unsigned j) const { - lp_assert(upper_bound_is_available(j)); + lp_assert(m_bp.upper_bound_is_available(j)); return m_bp.get_upper_bound(j); } const impq & lb(unsigned j) const { - lp_assert(lower_bound_is_available(j)); + lp_assert(m_bp.lower_bound_is_available(j)); return m_bp.get_lower_bound(j); } @@ -301,10 +281,32 @@ public : // */ // } - void limit_j(unsigned j, const mpq& u, bool coeff_before_j_is_pos, bool is_lower_bound, bool strict){ - m_bp.try_add_bound(u, j, is_lower_bound, coeff_before_j_is_pos, m_row_index, strict); + void limit_j(unsigned bound_j, const mpq& u, bool coeff_before_j_is_pos, bool is_lower_bound, bool strict) + { + unsigned row_index = this->m_row_index; + auto* lar = &m_bp.lp(); + auto explain = [bound_j, coeff_before_j_is_pos, is_lower_bound, strict, row_index, lar]() { + (void) strict; + TRACE("bound_analyzer", tout << "explain_bound_on_var_on_coeff, bound_j = " << bound_j << ", coeff_before_j_is_pos = " << coeff_before_j_is_pos << ", is_lower_bound = " << is_lower_bound << ", strict = " << strict << ", row_index = " << row_index << "\n";); + int bound_sign = (is_lower_bound ? 1 : -1); + int j_sign = (coeff_before_j_is_pos ? 1 : -1) * bound_sign; + + u_dependency* ret = nullptr; + for (auto const& r : lar->get_row(row_index)) { + unsigned j = r.var(); + if (j == bound_j) + continue; + mpq const& a = r.coeff(); + int a_sign = is_pos(a) ? 1 : -1; + int sign = j_sign * a_sign; + u_dependency* witness = sign > 0 ? lar->get_column_upper_bound_witness(j) : lar->get_column_lower_bound_witness(j); + ret = lar->join_deps(ret, witness); + } + return ret; + }; + m_bp.add_bound(u, bound_j, is_lower_bound, strict, explain); } - + void advance_u(unsigned j) { m_column_of_u = (m_column_of_u == -1) ? j : -2; } @@ -335,6 +337,9 @@ public : break; } } + }; + + } diff --git a/src/math/lp/column.h b/src/math/lp/column.h new file mode 100644 index 00000000000..1b2d1f2e5f8 --- /dev/null +++ b/src/math/lp/column.h @@ -0,0 +1,79 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Abstract: + + justifications for upper or lower bounds + +Author: + + Lev Nachmanson (levnach) + +--*/ + +#pragma once +#include "util/vector.h" +#include "util/dependency.h" +#include +#include +#include +#include "math/lp/column_info.h" +#include "math/lp/lp_types.h" + +namespace lp { + + +inline bool kind_is_strict(lconstraint_kind kind) { return kind == LT || kind == GT;} + +inline std::ostream& operator<<(std::ostream& out, lconstraint_kind k) { + switch (k) { + case LE: return out << "<="; + case LT: return out << "<"; + case GE: return out << ">="; + case GT: return out << ">"; + case EQ: return out << "="; + case NE: return out << "!="; + } + return out << "??"; +} + +inline bool compare(const std::pair & a, const std::pair & b) { + return a.second < b.second; +} +class lar_term; // forward definition +class column { + u_dependency* m_lower_bound_witness = nullptr; + u_dependency* m_upper_bound_witness = nullptr; + bool m_associated_with_row = false; + lar_term* m_term = nullptr; +public: + lar_term* term() const { return m_term; } + + u_dependency*& lower_bound_witness() { return m_lower_bound_witness; } + u_dependency* lower_bound_witness() const { return m_lower_bound_witness; } + u_dependency*& upper_bound_witness() { return m_upper_bound_witness; } + u_dependency* upper_bound_witness() const { return m_upper_bound_witness; } + + // equality is used by stackedvector operations. + // this appears to be a low level reason + + bool operator!=(const column & p) const { + return !(*this == p); + } + + bool operator==(const column & p) const { + return m_lower_bound_witness == p.m_lower_bound_witness + && m_upper_bound_witness == p.m_upper_bound_witness + && m_associated_with_row == p.m_associated_with_row; + } + column() = delete; + column(bool) = delete; + + + column(bool associated_with_row, lar_term* term) : + m_associated_with_row(associated_with_row), m_term(term) {} + + bool associated_with_row() const { return m_associated_with_row; } +}; + +} diff --git a/src/math/lp/core_solver_pretty_printer_def.h b/src/math/lp/core_solver_pretty_printer_def.h index 27aa6c75d72..cbe67ea36a1 100644 --- a/src/math/lp/core_solver_pretty_printer_def.h +++ b/src/math/lp/core_solver_pretty_printer_def.h @@ -279,10 +279,12 @@ template void core_solver_pretty_printer::print() print_row(i); } m_out << std::endl; - if (m_core_solver.inf_set().size()) { - m_out << "inf columns: "; - print_vector(m_core_solver.inf_set(), m_out); + if (!m_core_solver.inf_heap().empty()) { + m_out << "inf columns: size() = " << m_core_solver.inf_heap().size() << std::endl; + print_vector(m_core_solver.inf_heap(), m_out); m_out << std::endl; + } else { + m_out << "inf columns: none\n"; } } diff --git a/src/math/lp/dense_matrix.cpp b/src/math/lp/dense_matrix.cpp index 25fc65a5d41..cd8e019e619 100644 --- a/src/math/lp/dense_matrix.cpp +++ b/src/math/lp/dense_matrix.cpp @@ -17,6 +17,7 @@ Revision History: --*/ +// clang-format off #include "math/lp/lp_settings.h" #include "math/lp/dense_matrix_def.h" #ifdef Z3DEBUG diff --git a/src/math/lp/dense_matrix.h b/src/math/lp/dense_matrix.h index fcc85cdd17d..6b039a9203f 100644 --- a/src/math/lp/dense_matrix.h +++ b/src/math/lp/dense_matrix.h @@ -17,6 +17,7 @@ Revision History: --*/ +// clang-format off #pragma once #ifdef Z3DEBUG #include "util/vector.h" diff --git a/src/math/lp/emonics.cpp b/src/math/lp/emonics.cpp index bcdb81dd8f4..e6e52e57b26 100644 --- a/src/math/lp/emonics.cpp +++ b/src/math/lp/emonics.cpp @@ -517,11 +517,10 @@ bool emonics::invariant() const { TRACE("nla_solver_mons", display(tout);); // the variable index contains exactly the active monomials unsigned mons = 0; - for (lpvar v = 0; v < m_var2index.size(); v++) { - if (is_monic_var(v)) { + for (lpvar v = 0; v < m_var2index.size(); v++) + if (is_monic_var(v)) mons++; - } - } + if (m_monics.size() != mons) { TRACE("nla_solver_mons", tout << "missmatch of monic vars\n";); return false; @@ -537,7 +536,7 @@ bool emonics::invariant() const { do { auto const& m = m_monics[c->m_index]; bool found = false; - for (lp::var_index w : m.rvars()) { + for (lp::lpvar w : m.rvars()) { auto w1 = m_ve.find(w); found |= v1.var() == w1.var(); } @@ -596,4 +595,35 @@ bool emonics::invariant() const { return true; } + +void emonics::set_propagated(monic const& m) { + struct set_unpropagated : public trail { + emonics& em; + unsigned var; + public: + set_unpropagated(emonics& em, unsigned var): em(em), var(var) {} + void undo() override { + em[var].set_propagated(false); + } + }; + SASSERT(!m.is_propagated()); + (*this)[m.var()].set_propagated(true); + m_u_f_stack.push(set_unpropagated(*this, m.var())); +} + +void emonics::set_bound_propagated(monic const& m) { + struct set_bound_unpropagated : public trail { + emonics& em; + unsigned var; + public: + set_bound_unpropagated(emonics& em, unsigned var): em(em), var(var) {} + void undo() override { + em[var].set_bound_propagated(false); + } + }; + SASSERT(!m.is_bound_propagated()); + (*this)[m.var()].set_bound_propagated(true); + m_u_f_stack.push(set_bound_unpropagated(*this, m.var())); +} + } diff --git a/src/math/lp/emonics.h b/src/math/lp/emonics.h index e4f4f484861..55086515dc8 100644 --- a/src/math/lp/emonics.h +++ b/src/math/lp/emonics.h @@ -142,6 +142,9 @@ class emonics { void merge_eh(unsigned r2, unsigned r1, unsigned v2, unsigned v1) {} void after_merge_eh(unsigned r2, unsigned r1, unsigned v2, unsigned v1) {} + void set_propagated(monic const& m); + void set_bound_propagated(monic const& m); + // this method is required by union_find trail_stack & get_trail_stack() { return m_u_f_stack; } diff --git a/src/math/lp/explanation.h b/src/math/lp/explanation.h index d2e7edc337b..960c5fb4ab5 100644 --- a/src/math/lp/explanation.h +++ b/src/math/lp/explanation.h @@ -48,6 +48,15 @@ class explanation { SASSERT(m_vector.empty()); m_set.insert(j); } + + void remove(constraint_index j) { + m_set.remove(j); + unsigned i = 0; + for (auto& p : m_vector) + if (p.first != j) + m_vector[i++] = p; + m_vector.shrink(i); + } void add_expl(const explanation& e) { if (e.m_vector.empty()) { @@ -112,4 +121,18 @@ class explanation { } }; + + struct equality { + lp::lpvar i, j; + lp::explanation e; + equality(lp::lpvar i, lp::lpvar j, lp::explanation const& e):i(i),j(j),e(e) {} + }; + + struct fixed_equality { + lp::lpvar v; + rational k; + lp::explanation e; + fixed_equality(lp::lpvar v, rational const& k, lp::explanation const& e):v(v),k(k),e(e) {} + }; + } diff --git a/src/math/lp/factorization.cpp b/src/math/lp/factorization.cpp index 229fca61f83..e1dcff626dd 100644 --- a/src/math/lp/factorization.cpp +++ b/src/math/lp/factorization.cpp @@ -1,3 +1,12 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Author: +Lev Nachmanson (levnach) +Nikolaj Bjorner (nbjorner) + +--*/ + #include "util/vector.h" #include "math/lp/factorization.h" namespace nla { @@ -7,11 +16,10 @@ void const_iterator_mon::init_vars_by_the_mask(unsigned_vector & k_vars, unsigne SASSERT(m_mask.size() + 1 == m_ff->m_vars.size()); k_vars.push_back(m_ff->m_vars.back()); for (unsigned j = 0; j < m_mask.size(); j++) { - if (m_mask[j]) { - k_vars.push_back(m_ff->m_vars[j]); - } else { - j_vars.push_back(m_ff->m_vars[j]); - } + if (m_mask[j]) + k_vars.push_back(m_ff->m_vars[j]); + else + j_vars.push_back(m_ff->m_vars[j]); } } // todo : do we need the sign? @@ -29,9 +37,9 @@ bool const_iterator_mon::get_factors(factor& k, factor& j, rational& sign) const m_full_factorization_returned = true; return false; } - if (k_vars.size() == 1) { - k.set(k_vars[0], factor_type::VAR); - } else { + if (k_vars.size() == 1) + k.set(k_vars[0], factor_type::VAR); + else { unsigned i; if (!m_ff->find_canonical_monic_of_vars(k_vars, i)) { ++m_num_failures; @@ -41,9 +49,9 @@ bool const_iterator_mon::get_factors(factor& k, factor& j, rational& sign) const } m_num_failures = 0; - if (j_vars.size() == 1) { - j.set(j_vars[0], factor_type::VAR); - } else { + if (j_vars.size() == 1) + j.set(j_vars[0], factor_type::VAR); + else { unsigned i; if (!m_ff->find_canonical_monic_of_vars(j_vars, i)) { ++m_num_failures; diff --git a/src/math/lp/factorization.h b/src/math/lp/factorization.h index b233894ad9c..e1096a75f60 100644 --- a/src/math/lp/factorization.h +++ b/src/math/lp/factorization.h @@ -2,20 +2,10 @@ /*++ Copyright (c) 2017 Microsoft Corporation - Module Name: - - - - Abstract: - - - Author: Lev Nachmanson (levnach) Nikolaj Bjorner (nbjorner) - Revision History: - --*/ #pragma once @@ -30,13 +20,12 @@ struct factorization_factory; enum class factor_type { VAR, MON }; class factor { - lpvar m_var{ UINT_MAX }; - factor_type m_type{ factor_type::VAR }; - bool m_sign{ false }; + lpvar m_var = UINT_MAX; + factor_type m_type = factor_type::VAR; + bool m_sign = false; public: - factor(): factor(false) {} - factor(bool sign): m_sign(sign) {} - explicit factor(lpvar v, factor_type t) : m_var(v), m_type(t), m_sign(false) {} + factor() { } + explicit factor(lpvar v, factor_type t) : m_var(v), m_type(t) {} unsigned var() const { return m_var; } factor_type type() const { return m_type; } void set(lpvar v, factor_type t) { m_var = v; m_type = t; } diff --git a/src/math/lp/general_matrix.h b/src/math/lp/general_matrix.h index a4f6693a211..fb1030e6b10 100644 --- a/src/math/lp/general_matrix.h +++ b/src/math/lp/general_matrix.h @@ -109,7 +109,7 @@ class general_matrix { auto & row = m_data[adjust_row(i)]; lp_assert(row_is_initialized_correctly(row)); for (lp::lar_term::ival p : c) { - unsigned j = adjust_column(column_fix(p.column().index())); + unsigned j = adjust_column(column_fix(p.j())); row[j] = sign * p.coeff(); } } diff --git a/src/math/lp/gomory.cpp b/src/math/lp/gomory.cpp index 2ecbc49aca6..7b4347af513 100644 --- a/src/math/lp/gomory.cpp +++ b/src/math/lp/gomory.cpp @@ -22,25 +22,24 @@ #include "math/lp/lar_solver.h" #include "math/lp/lp_utils.h" -#define SMALL_CUTS 1 namespace lp { - -class create_cut { + +enum class row_polarity { UNDEF, MIN, MAX, MIXED}; +struct create_cut { lar_term & m_t; // the term to return in the cut mpq & m_k; // the right side of the cut explanation* m_ex; // the conflict explanation unsigned m_inf_col; // a basis column which has to be an integer but has a non integral value const row_strip& m_row; - const int_solver& lia; - mpq m_lcm_den; + int_solver& lia; mpq m_f; mpq m_one_minus_f; mpq m_fj; mpq m_one_minus_fj; -#if SMALL_CUTS mpq m_abs_max, m_big_number; -#endif - struct found_big {}; + row_polarity m_polarity; + bool m_found_big; + u_dependency* m_dep; const impq & get_value(unsigned j) const { return lia.get_value(j); } bool is_int(unsigned j) const { return lia.column_is_int(j) || (lia.is_fixed(j) && @@ -50,9 +49,13 @@ class create_cut { bool at_upper(unsigned j) const { return lia.at_upper(j); } const impq & lower_bound(unsigned j) const { return lia.lower_bound(j); } const impq & upper_bound(unsigned j) const { return lia.upper_bound(j); } - constraint_index column_lower_bound_constraint(unsigned j) const { return lia.column_lower_bound_constraint(j); } - constraint_index column_upper_bound_constraint(unsigned j) const { return lia.column_upper_bound_constraint(j); } + u_dependency* column_lower_bound_constraint(unsigned j) const { return lia.column_lower_bound_constraint(j); } + u_dependency* column_upper_bound_constraint(unsigned j) const { return lia.column_upper_bound_constraint(j); } bool column_is_fixed(unsigned j) const { return lia.lra.column_is_fixed(j); } + void push_explanation(u_dependency* d) { + for (auto ci : lia.lra.flatten(d)) + m_ex->push_back(ci); + } void int_case_in_gomory_cut(unsigned j) { lp_assert(is_int(j) && m_fj.is_pos()); @@ -66,8 +69,8 @@ class create_cut { // here we have the product of new_a*(xj - lb(j)), so new_a*lb(j) is added to m_k new_a = m_fj <= m_one_minus_f ? m_fj / m_one_minus_f : ((1 - m_fj) / m_f); lp_assert(new_a.is_pos()); - m_k.addmul(new_a, lower_bound(j).x); - m_ex->push_back(column_lower_bound_constraint(j)); + m_k.addmul(new_a, lower_bound(j).x); + push_explanation(column_lower_bound_constraint(j)); } else { lp_assert(at_upper(j)); @@ -75,52 +78,59 @@ class create_cut { new_a = - (m_fj <= m_f ? m_fj / m_f : ((1 - m_fj) / m_one_minus_f)); lp_assert(new_a.is_neg()); m_k.addmul(new_a, upper_bound(j).x); - m_ex->push_back(column_upper_bound_constraint(j)); - } + push_explanation(column_upper_bound_constraint(j)); + } m_t.add_monomial(new_a, j); - m_lcm_den = lcm(m_lcm_den, denominator(new_a)); - TRACE("gomory_cut_detail", tout << "new_a = " << new_a << ", k = " << m_k << ", lcm_den = " << m_lcm_den << "\n";); -#if SMALL_CUTS - // if (numerator(new_a).is_big()) throw found_big(); + TRACE("gomory_cut_detail", tout << "new_a = " << new_a << ", k = " << m_k << "\n";); if (numerator(new_a) > m_big_number) - throw found_big(); -#endif + m_found_big = true; + } + + void set_polarity(row_polarity p) { + if (m_polarity == row_polarity::MIXED) return; + if (m_polarity == row_polarity::UNDEF) m_polarity = p; + else if (m_polarity != p) m_polarity = row_polarity::MIXED; } void real_case_in_gomory_cut(const mpq & a, unsigned j) { TRACE("gomory_cut_detail_real", tout << "j = " << j << ", a = " << a << ", m_k = " << m_k << "\n";); mpq new_a; if (at_lower(j)) { - if (a.is_pos()) + if (a.is_pos()) { // the delta is a (x - f) is positive it has to grow and fight m_one_minus_f new_a = a / m_one_minus_f; - else + set_polarity(row_polarity::MIN); // reverse the polarity since a = -p.coeff() + } + else { // the delta is negative and it works again m_f new_a = - a / m_f; + set_polarity(row_polarity::MAX); + } m_k.addmul(new_a, lower_bound(j).x); // is it a faster operation than - // k += lower_bound(j).x * new_a; - m_ex->push_back(column_lower_bound_constraint(j)); + // k += lower_bound(j).x * new_a; + push_explanation(column_lower_bound_constraint(j)); } else { lp_assert(at_upper(j)); - if (a.is_pos()) + if (a.is_pos()) { // the delta is works again m_f - new_a = - a / m_f; - else + new_a = - a / m_f; + set_polarity(row_polarity::MAX); + } + else { // the delta is positive works again m_one_minus_f - new_a = a / m_one_minus_f; - m_k.addmul(new_a, upper_bound(j).x); // k += upper_bound(j).x * new_a; - m_ex->push_back(column_upper_bound_constraint(j)); + new_a = a / m_one_minus_f; + set_polarity(row_polarity::MIN); + } + m_k.addmul(new_a, upper_bound(j).x); // k += upper_bound(j).x * new_a; + push_explanation(column_upper_bound_constraint(j)); } m_t.add_monomial(new_a, j); TRACE("gomory_cut_detail_real", tout << "add " << new_a << "*v" << j << ", k: " << m_k << "\n"; tout << "m_t = "; lia.lra.print_term(m_t, tout) << "\nk: " << m_k << "\n";); -#if SMALL_CUTS - // if (numerator(new_a).is_big()) throw found_big(); if (numerator(new_a) > m_big_number) - throw found_big(); -#endif + m_found_big = true; } lia_move report_conflict_from_gomory_cut() { @@ -128,48 +138,7 @@ class create_cut { // conflict 0 >= k where k is positive return lia_move::conflict; } - - void adjust_term_and_k_for_some_ints_case_gomory() { - lp_assert(!m_t.is_empty()); - // k = 1 + sum of m_t at bounds - auto pol = m_t.coeffs_as_vector(); - m_t.clear(); - if (pol.size() == 1) { - TRACE("gomory_cut_detail", tout << "pol.size() is 1" << std::endl;); - unsigned v = pol[0].second; - lp_assert(is_int(v)); - const mpq& a = pol[0].first; - if (a.is_pos()) { // we have av >= k - m_k /= a; - if (!m_k.is_int()) - m_k = ceil(m_k); - m_t.add_monomial(mpq(1), v); - } - else { - m_k /= -a; - if (!m_k.is_int()) - m_k = ceil(m_k); - m_t.add_monomial(-mpq(1), v); - } - } - else { - m_lcm_den = lcm(m_lcm_den, denominator(m_k)); - lp_assert(m_lcm_den.is_pos()); - TRACE("gomory_cut_detail", tout << "pol.size() > 1 den: " << m_lcm_den << std::endl;); - if (!m_lcm_den.is_one()) { - // normalize coefficients of integer parameters to be integers. - for (auto & pi: pol) { - pi.first *= m_lcm_den; - SASSERT(!is_int(pi.second) || pi.first.is_int()); - } - m_k *= m_lcm_den; - } - for (const auto & pi: pol) - m_t.add_monomial(pi.first, pi.second); - } - TRACE("gomory_cut_detail", tout << "k = " << m_k << std::endl;); - lp_assert(m_k.is_int()); - } + std::string var_name(unsigned j) const { return std::string("x") + std::to_string(j); @@ -178,7 +147,7 @@ class create_cut { std::ostream& dump_coeff_val(std::ostream & out, const mpq & a) const { if (a.is_int()) out << a; - else if ( a >= zero_of_type()) + else if (a >= zero_of_type()) out << "(/ " << numerator(a) << " " << denominator(a) << ")"; else out << "(- (/ " << numerator(-a) << " " << denominator(-a) << "))"; @@ -187,10 +156,7 @@ class create_cut { template void dump_coeff(std::ostream & out, const T& c) const { - out << "( * "; - dump_coeff_val(out, c.coeff()); - auto t = lia.lra.column2tv(c.column()); - out << " " << var_name(t.id()) << ")"; + dump_coeff_val(out << "(* ", c.coeff()) << " " << var_name(c.j()) << ")"; } std::ostream& dump_row_coefficients(std::ostream & out) const { @@ -204,7 +170,7 @@ class create_cut { void dump_the_row(std::ostream& out) const { out << "; the row, excluding fixed vars\n"; - out << "(assert ( = ( +"; + out << "(assert (= (+"; dump_row_coefficients(out) << ") 0))\n"; } @@ -217,16 +183,15 @@ class create_cut { for (const auto & p : m_row) dump_declaration(out, p.var()); for (lar_term::ival p : m_t) { - auto t = lia.lra.column2tv(p.column()); - if (t.is_term()) { - dump_declaration(out, t.id()); - } + if (lia.lra.column_has_term(p.j())) + dump_declaration(out, p.j()); } } void dump_lower_bound_expl(std::ostream & out, unsigned j) const { out << "(assert (>= " << var_name(j) << " " << lower_bound(j).x << "))\n"; } + void dump_upper_bound_expl(std::ostream & out, unsigned j) const { out << "(assert (<= " << var_name(j) << " " << upper_bound(j).x << "))\n"; } @@ -255,12 +220,12 @@ class create_cut { return dump_term_coefficients(out << "(+ ") << ")"; } - std::ostream& dump_term_le_k(std::ostream & out) const { - return dump_term_sum(out << "(<= ") << " " << m_k << ")"; + std::ostream& dump_term_ge_k(std::ostream & out) const { + return dump_term_sum(out << "(>= ") << " " << m_k << ")"; } void dump_the_cut_assert(std::ostream & out) const { - dump_term_le_k(out << "(assert (not ") << "))\n"; + dump_term_ge_k(out << "(assert (not ") << "))\n"; } void dump_cut_and_constraints_as_smt_lemma(std::ostream& out) const { @@ -281,18 +246,40 @@ class create_cut { lia_move cut() { TRACE("gomory_cut", dump(tout);); - - // gomory will be t >= k and the current solution has a property t < k + // If m_polarity is MAX, then + // the row constraints the base variable to be at the maximum, + // MIN - at the minimum, + // MIXED : the row does not constraint the base variable to be at an extremum + // UNDEF is the initial state + m_polarity = row_polarity::UNDEF; + // gomory cut will be m_t >= m_k and the current solution has a property m_t < m_k m_k = 1; m_t.clear(); - mpq m_lcm_den(1); - bool some_int_columns = false; - mpq m_f = fractional_part(get_value(m_inf_col)); + m_ex->clear(); + m_found_big = false; TRACE("gomory_cut_detail", tout << "m_f: " << m_f << ", "; tout << "1 - m_f: " << 1 - m_f << ", get_value(m_inf_col).x - m_f = " << get_value(m_inf_col).x - m_f << "\n";); lp_assert(m_f.is_pos() && (get_value(m_inf_col).x - m_f).is_int()); - -#if SMALL_CUTS + auto set_polarity_for_int = [&](const mpq & a, lpvar j) { + if (a.is_pos()) { + if (at_lower(j)) + set_polarity(row_polarity::MAX); + else if (at_upper(j)) + set_polarity(row_polarity::MIN); + else + set_polarity(row_polarity::MIXED); + } + else { + if (at_lower(j)) + set_polarity(row_polarity::MIN); + else if (at_upper(j)) + set_polarity(row_polarity::MAX); + else + set_polarity(row_polarity::MIXED); + } + }; + + m_abs_max = 0; for (const auto & p : m_row) { mpq t = abs(ceil(p.coeff())); @@ -300,132 +287,274 @@ class create_cut { m_abs_max = t; } m_big_number = m_abs_max.expt(2); -#endif - mpq one_min_m_f = 1 - m_f; + for (const auto & p : m_row) { unsigned j = p.var(); - if (j == m_inf_col) { - lp_assert(p.coeff() == one_of_type()); - TRACE("gomory_cut_detail", tout << "seeing basic var\n";); + if (j == m_inf_col) continue; + // use -p.coeff() to make the format compatible with the format used in: Integrating Simplex with DPLL(T) + + if (lia.is_fixed(j)) { + push_explanation(column_lower_bound_constraint(j)); + push_explanation(column_upper_bound_constraint(j)); continue; } - - // use -p.coeff() to make the format compatible with the format used in: Integrating Simplex with DPLL(T) - try { - if (lia.is_fixed(j)) { - m_ex->push_back(column_lower_bound_constraint(j)); - m_ex->push_back(column_upper_bound_constraint(j)); - continue; - } - if (is_real(j)) - real_case_in_gomory_cut(- p.coeff(), j); - else if (!p.coeff().is_int()) { - some_int_columns = true; + if (is_real(j)) + real_case_in_gomory_cut(- p.coeff(), j); + else { + if (!p.coeff().is_int()) { m_fj = fractional_part(-p.coeff()); m_one_minus_fj = 1 - m_fj; int_case_in_gomory_cut(j); } + if (m_polarity != row_polarity::MIXED) + set_polarity_for_int(p.coeff(), j); + } - catch (found_big) { - m_ex->clear(); - m_t.clear(); - m_k = 1; + + if (m_found_big) { return lia_move::undef; } } - if (m_t.is_empty()) + + if (m_t.is_empty()) { return report_conflict_from_gomory_cut(); - if (some_int_columns) - adjust_term_and_k_for_some_ints_case_gomory(); - TRACE("gomory_cut_detail", dump_cut_and_constraints_as_smt_lemma(tout);); - lp_assert(lia.current_solution_is_inf_on_cut()); // checks that indices are columns - TRACE("gomory_cut", print_linear_combination_of_column_indices_only(m_t.coeffs_as_vector(), tout << "gomory cut:"); tout << " <= " << m_k << std::endl;); + } + TRACE("gomory_cut", print_linear_combination_of_column_indices_only(m_t.coeffs_as_vector(), tout << "gomory cut: "); tout << " >= " << m_k << std::endl;); + + m_dep = nullptr; + for (auto c : *m_ex) + m_dep = lia.lra.join_deps(lia.lra.dep_manager().mk_leaf(c.ci()), m_dep); + + TRACE("gomory_cut_detail", dump_cut_and_constraints_as_smt_lemma(tout); + lia.lra.display(tout)); + SASSERT(lia.current_solution_is_inf_on_cut()); + + lia.settings().stats().m_gomory_cuts++; return lia_move::cut; } - create_cut(lar_term & t, mpq & k, explanation* ex, unsigned basic_inf_int_j, const row_strip& row, const int_solver& lia) : + create_cut(lar_term & t, mpq & k, explanation* ex, unsigned basic_inf_int_j, const row_strip& row, int_solver& lia) : m_t(t), m_k(k), m_ex(ex), m_inf_col(basic_inf_int_j), m_row(row), lia(lia), - m_lcm_den(1), m_f(fractional_part(get_value(basic_inf_int_j).x)), m_one_minus_f(1 - m_f) {} -}; + }; -lia_move gomory::cut(lar_term & t, mpq & k, explanation* ex, unsigned basic_inf_int_j, const row_strip& row) { - create_cut cc(t, k, ex, basic_inf_int_j, row, lia); - return cc.cut(); -} + bool gomory::is_gomory_cut_target(lpvar k) { + SASSERT(lia.is_base(k)); + const row_strip& row = lra.get_row(lia.row_of_basic_column(k)); + // Consider monomial c*x from the row, where x is non-basic. + // Then, for each such monomial, one of following conditions + // has to hold for the row to be eligible for Gomory cut: + // 1) c is integral and x integral varible with an integral value + // 2) the value of x is at a bound and has no infinitesimals. -bool gomory::is_gomory_cut_target(const row_strip& row) { - // All non base variables must be at their bounds and assigned to rationals (that is, infinitesimals are not allowed). - unsigned j; - for (const auto & p : row) { - j = p.var(); - if (!lia.is_base(j) && (!lia.at_bound(j) || !is_zero(lia.get_value(j).y))) { - TRACE("gomory_cut", tout << "row is not gomory cut target:\n"; - lia.display_column(tout, j); - tout << "infinitesimal: " << !is_zero(lia.get_value(j).y) << "\n";); - return false; + + unsigned j; + for (const auto & p : row) { + j = p.var(); + if (k == j) continue; + + if (p.coeff().is_int() && lia.column_is_int(j) && lia.get_value(j).is_int()) continue; + + if ( !lia.at_bound(j) || lia.get_value(j).y != 0) { + TRACE("gomory_cut", tout << "row is not gomory cut target:\n"; + lia.display_column(tout, j); + tout << "infinitesimal: " << !(lia.get_value(j).y ==0) << "\n";); + return false; + } } + return true; + + // Condition 1) above can be relaxed even more, allowing any value for x, but it will change the calculation for m_f. } - return true; -} -int gomory::find_basic_var() { - unsigned n = 0; - int result = -1; - unsigned min_row_size = UINT_MAX; - -#if 0 - result = lia.select_int_infeasible_var(); - - if (result == -1) - return result; - - const row_strip& row = lra.get_row(lia.row_of_basic_column(result)); - if (is_gomory_cut_target(row)) - return result; - result = -1; -#endif - - for (unsigned j : lra.r_basis()) { - if (!lia.column_is_int_inf(j)) - continue; - const row_strip& row = lra.get_row(lia.row_of_basic_column(j)); - if (!is_gomory_cut_target(row)) - continue; - IF_VERBOSE(20, lia.display_row_info(verbose_stream(), lia.row_of_basic_column(j))); - // Prefer smaller row size - if (min_row_size == UINT_MAX || - 2*row.size() < min_row_size || - (4*row.size() < 5*min_row_size && lia.random() % (++n) == 0)) { - result = j; - n = 1; - min_row_size = std::min(min_row_size, row.size()); + // return the minimal distance from the variable value to an integer + mpq get_gomory_score(const int_solver& lia, lpvar j) { + const mpq& val = lia.get_value(j).x; + auto l = val - floor(val); + if (l <= mpq(1, 2)) + return l; + return mpq(1) - l; + } + + unsigned_vector gomory::gomory_select_int_infeasible_vars(unsigned num_cuts) { + std::list sorted_vars; + std::unordered_map score; + for (lpvar j : lra.r_basis()) { + if (!lia.column_is_int_inf(j) || !is_gomory_cut_target(j)) + continue; + SASSERT(!lia.is_fixed(j)); + sorted_vars.push_back(j); + score[j] = get_gomory_score(lia, j); + } + // prefer the variables with the values close to integers + sorted_vars.sort([&](lpvar j, lpvar k) { + auto diff = score[j] - score[k]; + if (diff.is_neg()) + return true; + if (diff.is_pos()) + return false; + return lra.usage_in_terms(j) > lra.usage_in_terms(k); + }); + unsigned_vector ret; + unsigned n = static_cast(sorted_vars.size()); + + while (num_cuts-- && n > 0) { + unsigned k = lia.random() % n; + + double k_ratio = k / (double) n; + k_ratio *= k_ratio*k_ratio; // square k_ratio to make it smaller + k = static_cast(std::floor(k_ratio * n)); + // these operations move k to the beginning of the indices range + SASSERT(0 <= k && k < n); + auto it = sorted_vars.begin(); + while(k--) it++; + + ret.push_back(*it); + sorted_vars.erase(it); + n--; } + return ret; } - return result; -} -lia_move gomory::operator()() { - lra.move_non_basic_columns_to_bounds(true); - int j = find_basic_var(); - if (j == -1) - return lia_move::undef; - unsigned r = lia.row_of_basic_column(j); - const row_strip& row = lra.get_row(r); - SASSERT(lra.row_is_correct(r)); - SASSERT(is_gomory_cut_target(row)); - lia.m_upper = false; - return cut(lia.m_t, lia.m_k, lia.m_ex, j, row); -} + row_polarity test_row_polarity(const int_solver& lia, const row_strip& row, lpvar basic_j) { + row_polarity ret = row_polarity::UNDEF; + for (const auto& p : row) { + lpvar j = p.var(); + if (j == basic_j) + continue; + if (lia.is_fixed(j)) + continue; + + row_polarity rp; + if (p.coeff().is_pos()) { + if (lia.at_lower(j)) + rp = row_polarity::MAX; + else if (lia.at_upper(j)) + rp = row_polarity::MIN; + else + rp = row_polarity::MIXED; + } + else { + if (lia.at_lower(j)) + rp = row_polarity::MIN; + else if (lia.at_upper(j)) + rp = row_polarity::MAX; + else + rp = row_polarity::MIXED; + + } + if (ret == row_polarity::UNDEF) + ret = rp; + if (ret != rp) + return row_polarity::MIXED; + } + return ret; + } + + u_dependency* gomory::add_deps(u_dependency* dep, const row_strip& row, lpvar basic_var) { + u_dependency* ret = dep; + for (const auto& p : row) { + lpvar j = p.var(); + if (j == basic_var) + continue; + if (lia.is_fixed(j)) + continue; + if (lia.is_real(j)) continue; + if (!p.coeff().is_int()) continue; + // the explanation for all above have been already added + if (lia.at_lower(j)) + ret = lia.lra.dep_manager().mk_join(lia.column_lower_bound_constraint(j), ret); + else { + SASSERT(lia.at_upper(j)); + ret = lia.lra.dep_manager().mk_join(lia.column_upper_bound_constraint(j), ret); + } + } + return ret; + } + lia_move gomory::get_gomory_cuts(unsigned num_cuts) { + struct cut_result {lar_term t; mpq k; u_dependency *dep;}; + vector big_cuts; + unsigned_vector columns_for_cuts = gomory_select_int_infeasible_vars(num_cuts); + bool has_small_cut = false; + + // define inline helper functions + auto is_small_cut = [&](lar_term const& t) { + return all_of(t, [&](auto ci) { return ci.coeff().is_small(); }); + }; + auto add_cut = [&](const lar_term& t, const mpq& k, u_dependency * dep) { + lp::lpvar j = lra.add_term(t.coeffs_as_vector(), UINT_MAX); + lra.update_column_type_and_bound(j, lp::lconstraint_kind::GE, k, dep); + }; + auto _check_feasible = [&](void) { + lra.find_feasible_solution(); + if (!lra.is_feasible() && !lia.settings().get_cancel_flag()) { + lra.get_infeasibility_explanation(*lia.m_ex); + return false; + } + return true; + }; + +// start creating cuts + for (unsigned j : columns_for_cuts) { + SASSERT(is_gomory_cut_target(j)); + unsigned row_index = lia.row_of_basic_column(j); + const row_strip& row = lra.get_row(row_index); + create_cut cc(lia.m_t, lia.m_k, lia.m_ex, j, row, lia); + auto r = cc.cut(); + if (r != lia_move::cut) { + if (r == lia_move::conflict) + return lia_move::conflict; + continue; + } + SASSERT(test_row_polarity(lia, row, j) == cc.m_polarity); + if (cc.m_polarity == row_polarity::MAX) + lra.update_column_type_and_bound(j, lp::lconstraint_kind::LE, floor(lra.get_column_value(j).x), add_deps(cc.m_dep, row, j)); + else if (cc.m_polarity == row_polarity::MIN) + lra.update_column_type_and_bound(j, lp::lconstraint_kind::GE, ceil(lra.get_column_value(j).x), add_deps(cc.m_dep, row, j)); + + if (!is_small_cut(lia.m_t)) { + big_cuts.push_back({cc.m_t, cc.m_k, cc.m_dep}); + continue; + } + has_small_cut = true; + add_cut(cc.m_t, cc.m_k, cc.m_dep); + if (lia.settings().get_cancel_flag()) + return lia_move::undef; + } + + if (big_cuts.size()) { + lra.push(); + for (auto const& cut : big_cuts) + add_cut(cut.t, cut.k, cut.dep); + bool feas = _check_feasible(); + lra.pop(1); -gomory::gomory(int_solver& lia): lia(lia), lra(lia.lra) { } + if (!feas) + for (auto const& cut : big_cuts) + add_cut(cut.t, cut.k, cut.dep); + } + + if (!_check_feasible()) + return lia_move::conflict; + + if (!lia.has_inf_int()) + return lia_move::sat; + if (has_small_cut || big_cuts.size()) + return lia_move::continue_with_check; + + lra.move_non_basic_columns_to_bounds(); + return lia_move::undef; + } + + + gomory::gomory(int_solver& lia): lia(lia), lra(lia.lra) { } } diff --git a/src/math/lp/gomory.h b/src/math/lp/gomory.h index 68e42feb947..0fdf4f8fe1c 100644 --- a/src/math/lp/gomory.h +++ b/src/math/lp/gomory.h @@ -27,11 +27,11 @@ namespace lp { class gomory { class int_solver& lia; class lar_solver& lra; - int find_basic_var(); - bool is_gomory_cut_target(const row_strip& row); - lia_move cut(lar_term & t, mpq & k, explanation* ex, unsigned basic_inf_int_j, const row_strip& row); + unsigned_vector gomory_select_int_infeasible_vars(unsigned num_cuts); + bool is_gomory_cut_target(lpvar j); + u_dependency* add_deps(u_dependency*, const row_strip&, lpvar); public: + lia_move get_gomory_cuts(unsigned num_cuts); gomory(int_solver& lia); - lia_move operator()(); }; } diff --git a/src/math/lp/hnf_cutter.cpp b/src/math/lp/hnf_cutter.cpp index 3c4ea10ab92..b120b7aac0e 100644 --- a/src/math/lp/hnf_cutter.cpp +++ b/src/math/lp/hnf_cutter.cpp @@ -20,7 +20,7 @@ namespace lp { lra(lia.lra), m_settings(lia.settings()), m_abs_max(zero_of_type()), - m_var_register(false) {} + m_var_register() {} bool hnf_cutter::is_full() const { return @@ -39,7 +39,7 @@ namespace lp { m_overflow = false; } - void hnf_cutter::add_term(const lar_term* t, const mpq &rs, constraint_index ci, bool upper_bound) { + void hnf_cutter::add_term(const lar_term* t, const mpq &rs, u_dependency* ci, bool upper_bound) { m_terms.push_back(t); m_terms_upper.push_back(upper_bound); if (upper_bound) @@ -50,7 +50,7 @@ namespace lp { m_constraints_for_explanation.push_back(ci); for (lar_term::ival p : *t) { - m_var_register.add_var(p.column().index(), true); // hnf only deals with integral variables for now + m_var_register.add_var(p.j(), true); // hnf only deals with integral variables for now mpq t = abs(ceil(p.coeff())); if (t > m_abs_max) m_abs_max = t; @@ -83,14 +83,13 @@ namespace lp { // consider return from here if b[i] is not an integer and return i } } - + vector hnf_cutter::create_b(const svector & basis_rows) { if (basis_rows.size() == m_right_sides.size()) return m_right_sides; vector b; - for (unsigned i : basis_rows) { - b.push_back(m_right_sides[i]); - } + for (unsigned i : basis_rows) + b.push_back(m_right_sides[i]); return b; } @@ -98,16 +97,15 @@ namespace lp { int ret = -1; int n = 0; for (int i = 0; i < static_cast(b.size()); i++) { - if (is_integer(b[i])) continue; - if (n == 0 ) { + if (is_integer(b[i])) + continue; + if (n == 0) { lp_assert(ret == -1); n = 1; ret = i; - } else { - if (m_settings.random_next() % (++n) == 0) { - ret = i; - } } + else if (m_settings.random_next() % (++n) == 0) + ret = i; } return ret; } @@ -146,7 +144,7 @@ namespace lp { } #endif void hnf_cutter::shrink_explanation(const svector& basis_rows) { - svector new_expl; + ptr_vector new_expl; for (unsigned i : basis_rows) { new_expl.push_back(m_constraints_for_explanation[i]); } @@ -229,39 +227,35 @@ branch y_i >= ceil(y0_i) is impossible. svector hnf_cutter::vars() const { return m_var_register.vars(); } - void hnf_cutter::try_add_term_to_A_for_hnf(tv const &i) { + void hnf_cutter::try_add_term_to_A_for_hnf(lpvar j) { mpq rs; - const lar_term& t = lra.get_term(i); - constraint_index ci; + const lar_term& t = lra.get_term(j); + u_dependency* dep; bool upper_bound; - if (!is_full() && lra.get_equality_and_right_side_for_term_on_current_x(i, rs, ci, upper_bound)) { - add_term(&t, rs, ci, upper_bound); + if (!is_full() && lra.get_equality_and_right_side_for_term_on_current_x(j, rs, dep, upper_bound)) { + add_term(&t, rs, dep, upper_bound); } } bool hnf_cutter::hnf_has_var_with_non_integral_value() const { - for (unsigned j : vars()) - if (!lia.get_value(j).is_int()) - return true; - return false; + return any_of(vars(), [&](unsigned j) { return !lia.get_value(j).is_int(); }); } bool hnf_cutter::init_terms_for_hnf_cut() { clear(); - for (unsigned i = 0; i < lra.terms().size() && !is_full(); i++) - try_add_term_to_A_for_hnf(tv::term(i)); + for (const lar_term* t: lra.terms()) + try_add_term_to_A_for_hnf(t->j()); return hnf_has_var_with_non_integral_value(); } - + lia_move hnf_cutter::make_hnf_cut() { - if (!init_terms_for_hnf_cut()) { + if (!init_terms_for_hnf_cut()) return lia_move::undef; - } lia.settings().stats().m_hnf_cutter_calls++; TRACE("hnf_cut", tout << "settings().stats().m_hnf_cutter_calls = " << lia.settings().stats().m_hnf_cutter_calls << "\n"; - for (unsigned i : constraints_for_explanation()) { - lra.constraints().display(tout, i); - } + for (u_dependency* d : constraints_for_explanation()) + for (auto ci : lra.flatten(d)) + lra.constraints().display(tout, ci); tout << lra.constraints(); ); #ifdef Z3DEBUG @@ -272,23 +266,22 @@ branch y_i >= ceil(y0_i) is impossible. , x0 #endif ); - + if (r == lia_move::cut) { TRACE("hnf_cut", lra.print_term(lia.m_t, tout << "cut:"); tout << " <= " << lia.m_k << std::endl; - for (unsigned i : constraints_for_explanation()) { - lra.constraints().display(tout, i); - } + for (auto* dep : constraints_for_explanation()) + for (auto ci : lra.flatten(dep)) + lra.constraints().display(tout, ci); ); lp_assert(lia.current_solution_is_inf_on_cut()); lia.settings().stats().m_hnf_cuts++; lia.m_ex->clear(); - for (unsigned i : constraints_for_explanation()) { - lia.m_ex->push_back(i); - } + for (u_dependency* dep : constraints_for_explanation()) + for (auto ci : lia.lra.flatten(dep)) + lia.m_ex->push_back(ci); } return r; } - } diff --git a/src/math/lp/hnf_cutter.h b/src/math/lp/hnf_cutter.h index b3530ea2901..07ddc7b12f0 100644 --- a/src/math/lp/hnf_cutter.h +++ b/src/math/lp/hnf_cutter.h @@ -34,7 +34,7 @@ class hnf_cutter { general_matrix m_A; vector m_terms; vector m_terms_upper; - svector m_constraints_for_explanation; + ptr_vector m_constraints_for_explanation; vector m_right_sides; mpq m_abs_max; bool m_overflow; @@ -50,18 +50,18 @@ class hnf_cutter { private: bool init_terms_for_hnf_cut(); bool hnf_has_var_with_non_integral_value() const; - void try_add_term_to_A_for_hnf(tv const& i); + void try_add_term_to_A_for_hnf(lpvar); unsigned terms_count() const { return m_terms.size(); } const mpq & abs_max() const { return m_abs_max; } const vector& terms() const { return m_terms; } - const svector& constraints_for_explanation() const { return m_constraints_for_explanation; } + const ptr_vector& constraints_for_explanation() const { return m_constraints_for_explanation; } const vector & right_sides() const { return m_right_sides; } bool is_full() const; void clear(); - void add_term(const lar_term* t, const mpq &rs, constraint_index ci, bool upper_bound); + void add_term(const lar_term* t, const mpq &rs, u_dependency* ci, bool upper_bound); void print(std::ostream & out); diff --git a/src/math/lp/horner.cpp b/src/math/lp/horner.cpp index 4d4ac497540..0cd62ecafef 100644 --- a/src/math/lp/horner.cpp +++ b/src/math/lp/horner.cpp @@ -40,7 +40,7 @@ bool horner::row_has_monomial_to_refine(const T& row) const { template bool horner::row_is_interesting(const T& row) const { TRACE("nla_solver_details", c().print_row(row, tout);); - if (row.size() > c().m_nla_settings.horner_row_length_limit) { + if (row.size() > c().params().arith_nl_horner_row_length_limit()) { TRACE("nla_solver_details", tout << "disregard\n";); return false; } @@ -76,7 +76,7 @@ bool horner::lemmas_on_expr(cross_nested& cn, nex_sum* e) { template bool horner::lemmas_on_row(const T& row) { SASSERT (row_is_interesting(row)); - c().clear_and_resize_active_var_set(); + c().clear_active_var_set(); u_dependency* dep = nullptr; create_sum_from_row(row, m_nex_creator, m_row_sum, dep); c().set_active_vars_weights(m_nex_creator); // without this call the comparisons will be incorrect @@ -98,19 +98,19 @@ bool horner::lemmas_on_row(const T& row) { } bool horner::horner_lemmas() { - if (!c().m_nla_settings.run_horner) { + if (!c().params().arith_nl_horner()) { TRACE("nla_solver", tout << "not generating horner lemmas\n";); return false; } c().lp_settings().stats().m_horner_calls++; - const auto& matrix = c().m_lar_solver.A_r(); + const auto& matrix = c().lra.A_r(); // choose only rows that depend on m_to_refine variables std::set rows_to_check; for (lpvar j : c().m_to_refine) { for (auto & s : matrix.m_columns[j]) rows_to_check.insert(s.var()); } - c().clear_and_resize_active_var_set(); + c().clear_active_var_set(); svector rows; for (unsigned i : rows_to_check) { if (row_is_interesting(matrix.m_rows[i])) diff --git a/src/math/lp/horner.h b/src/math/lp/horner.h index 2b6fc3cd816..9d530aceeb0 100644 --- a/src/math/lp/horner.h +++ b/src/math/lp/horner.h @@ -23,7 +23,7 @@ #include "math/lp/nla_intervals.h" #include "math/lp/nex.h" #include "math/lp/cross_nested.h" -#include "math/lp/u_set.h" +#include "util/uint_set.h" namespace nla { class core; diff --git a/src/math/lp/implied_bound.h b/src/math/lp/implied_bound.h index 9435edcdc76..195ec03599d 100644 --- a/src/math/lp/implied_bound.h +++ b/src/math/lp/implied_bound.h @@ -21,37 +21,40 @@ Revision History: #include "math/lp/lp_settings.h" #include "math/lp/lar_constraints.h" namespace lp { -struct implied_bound { +class implied_bound { + public: mpq m_bound; - unsigned m_j; // the column for which the bound has been found + // It is either the column for which the bound has been found, or, + // in the case the column was created as + // the slack variable to a term, it is the term index. + // It is the same index that was returned by lar_solver::add_var(), or + // by lar_solver::add_term() + unsigned m_j; bool m_is_lower_bound; - bool m_coeff_before_j_is_pos; - unsigned m_row_or_term_index; bool m_strict; + private: + std::function m_explain_bound = nullptr; + public: + // s is expected to be the pointer to lp_bound_propagator. + u_dependency* explain_implied() const { return m_explain_bound(); } + void set_explain(std::function f) { m_explain_bound = f; } lconstraint_kind kind() const { lconstraint_kind k = m_is_lower_bound? GE : LE; if (m_strict) k = static_cast(k / 2); return k; } - bool operator==(const implied_bound & o) const { - return m_j == o.m_j && m_is_lower_bound == o.m_is_lower_bound && m_bound == o.m_bound && - m_coeff_before_j_is_pos == o.m_coeff_before_j_is_pos && - m_row_or_term_index == o.m_row_or_term_index && m_strict == o.m_strict; - } implied_bound(){} implied_bound(const mpq & a, unsigned j, - bool lower_bound, - bool coeff_before_j_is_pos, - unsigned row_or_term_index, - bool strict): + bool is_lower_bound, + bool is_strict, + std::function get_dep): m_bound(a), m_j(j), - m_is_lower_bound(lower_bound), - m_coeff_before_j_is_pos(coeff_before_j_is_pos), - m_row_or_term_index(row_or_term_index), - m_strict(strict) { + m_is_lower_bound(is_lower_bound), + m_strict(is_strict), + m_explain_bound(get_dep) { } }; } diff --git a/src/math/lp/int_branch.cpp b/src/math/lp/int_branch.cpp index da34f77fd9e..10211c4faf3 100644 --- a/src/math/lp/int_branch.cpp +++ b/src/math/lp/int_branch.cpp @@ -15,7 +15,6 @@ Revision History: --*/ - #include "math/lp/int_solver.h" #include "math/lp/lar_solver.h" #include "math/lp/int_branch.h" @@ -25,16 +24,17 @@ namespace lp { int_branch::int_branch(int_solver& lia):lia(lia), lra(lia.lra) {} lia_move int_branch::operator()() { - lra.move_non_basic_columns_to_bounds(true); + lra.move_non_basic_columns_to_bounds(); int j = find_inf_int_base_column(); return j == -1? lia_move::sat : create_branch_on_column(j); } lia_move int_branch::create_branch_on_column(int j) { TRACE("check_main_int", tout << "branching" << std::endl;); - lp_assert(lia.m_t.is_empty()); + lia.m_t.clear(); + lp_assert(j != -1); - lia.m_t.add_monomial(mpq(1), lra.column_to_reported_index(j)); + lia.m_t.add_monomial(mpq(1), j); if (lia.is_free(j)) { lia.m_upper = lia.random() % 2; lia.m_k = mpq(0); @@ -53,7 +53,7 @@ lia_move int_branch::create_branch_on_column(int j) { int int_branch::find_inf_int_base_column() { -#if 0 +#if 1 return lia.select_int_infeasible_var(); #endif @@ -63,7 +63,7 @@ int int_branch::find_inf_int_base_column() { mpq small_value(1024); unsigned n = 0; lar_core_solver & lcs = lra.m_mpq_lar_core_solver; - unsigned prev_usage = 0; // to quiet down the compile + unsigned prev_usage = 0; // to quiet down the compiler unsigned k = 0; unsigned usage; unsigned j; diff --git a/src/math/lp/int_cube.cpp b/src/math/lp/int_cube.cpp index da724a543fd..c8488ca3737 100644 --- a/src/math/lp/int_cube.cpp +++ b/src/math/lp/int_cube.cpp @@ -43,7 +43,7 @@ namespace lp { if (st != lp_status::FEASIBLE && st != lp_status::OPTIMAL) { TRACE("cube", tout << "cannot find a feasible solution";); lra.pop(); - lra.move_non_basic_columns_to_bounds(false); + lra.move_non_basic_columns_to_bounds(); // it can happen that we found an integer solution here return !lra.r_basis_has_inf_int()? lia_move::sat: lia_move::undef; } @@ -55,21 +55,21 @@ namespace lp { lia.settings().stats().m_cube_success++; return lia_move::sat; } - +// i is the column index having the term bool int_cube::tighten_term_for_cube(unsigned i) { - if (!lra.term_is_used_as_row(i)) + if (!lra.column_associated_with_row(i)) return true; - const lar_term* t = lra.terms()[i]; - impq delta = get_cube_delta_for_term(*t); - TRACE("cube", lra.print_term_as_indices(*t, tout); tout << ", delta = " << delta << "\n";); + const lar_term& t = lra.get_term(i); + impq delta = get_cube_delta_for_term(t); + TRACE("cube", lra.print_term_as_indices(t, tout); tout << ", delta = " << delta << "\n";); if (is_zero(delta)) return true; - return lra.tighten_term_bounds_by_delta(tv::term(i), delta); + return lra.tighten_term_bounds_by_delta(i, delta); } bool int_cube::tighten_terms_for_cube() { - for (unsigned i = 0; i < lra.terms().size(); i++) - if (!tighten_term_for_cube(i)) { + for (const lar_term* t: lra.terms()) + if (!tighten_term_for_cube(t->j())) { TRACE("cube", tout << "cannot tighten";); return false; } @@ -86,7 +86,7 @@ namespace lp { bool seen_minus = false; bool seen_plus = false; for(lar_term::ival p : t) { - if (!lia.column_is_int(p.column())) + if (!lia.column_is_int(p.j())) goto usual_delta; const mpq & c = p.coeff(); if (c == one_of_type()) { @@ -104,7 +104,7 @@ namespace lp { usual_delta: mpq delta = zero_of_type(); for (lar_term::ival p : t) - if (lia.column_is_int(p.column())) + if (lia.column_is_int(p.j())) delta += abs(p.coeff()); delta *= mpq(1, 2); diff --git a/src/math/lp/int_gcd_test.cpp b/src/math/lp/int_gcd_test.cpp index 4801cc43641..f3b1b6389b8 100644 --- a/src/math/lp/int_gcd_test.cpp +++ b/src/math/lp/int_gcd_test.cpp @@ -250,10 +250,9 @@ namespace lp { } void int_gcd_test::add_to_explanation_from_fixed_or_boxed_column(unsigned j) { - constraint_index lc, uc; - lra.get_bound_constraint_witnesses_for_column(j, lc, uc); - lia.m_ex->push_back(lc); - lia.m_ex->push_back(uc); + auto* deps = lra.get_bound_constraint_witnesses_for_column(j); + for (auto d : lra.flatten(deps)) + lia.m_ex->push_back(d); } bool int_gcd_test::accumulate_parity(const row_strip & row, unsigned least_idx) { diff --git a/src/math/lp/int_solver.cpp b/src/math/lp/int_solver.cpp index 6c34ce16fec..f547ba274fb 100644 --- a/src/math/lp/int_solver.cpp +++ b/src/math/lp/int_solver.cpp @@ -2,8 +2,6 @@ Copyright (c) 2017 Microsoft Corporation Author: Lev Nachmanson */ - -#include #include "math/lp/int_solver.h" #include "math/lp/lar_solver.h" #include "math/lp/lp_utils.h" @@ -14,691 +12,823 @@ namespace lp { -int_solver::patcher::patcher(int_solver& lia): - lia(lia), - lra(lia.lra), - lrac(lia.lrac), - m_num_nbasic_patches(0), - m_patch_cost(0), - m_next_patch(0), - m_delay(0) -{} - -bool int_solver::patcher::should_apply() { -#if 1 - return true; -#else - if (m_delay == 0) { - return true; + int_solver::patcher::patcher(int_solver& lia): + lia(lia), + lra(lia.lra), + lrac(lia.lrac) + {} + + unsigned int_solver::patcher::count_non_int() { + unsigned non_int = 0; + for (auto j : lra.r_basis()) + if (lra.column_is_int(j) && !lra.column_value_is_int(j)) + ++non_int; + return non_int; + } + + lia_move int_solver::patcher::patch_basic_columns() { + lia.settings().stats().m_patches++; + lra.remove_fixed_vars_from_base(); + lp_assert(lia.is_feasible()); + for (unsigned j : lra.r_basis()) + if (!lra.get_value(j).is_int() && lra.column_is_int(j) && !lia.is_fixed(j)) + patch_basic_column(j); + if (!lia.has_inf_int()) { + lia.settings().stats().m_patches_success++; + return lia_move::sat; + } + return lia_move::undef; } - --m_delay; - return false; -#endif -} -lia_move int_solver::patcher::operator()() { - return patch_nbasic_columns(); -} + // clang-format on + /** + * \brief find integral and minimal, in the absolute values, deltas such that x - alpha*delta is integral too. + */ + bool get_patching_deltas(const rational& x, const rational& alpha, + rational& delta_plus, rational& delta_minus) { + auto a1 = numerator(alpha); + auto a2 = denominator(alpha); + auto x1 = numerator(x); + auto x2 = denominator(x); + if (!divides(x2, a2)) + return false; -lia_move int_solver::patcher::patch_nbasic_columns() { - lia.settings().stats().m_patches++; - lp_assert(lia.is_feasible()); - m_num_nbasic_patches = 0; - m_patch_cost = 0; - for (unsigned j : lia.lrac.m_r_nbasis) { - patch_nbasic_column(j); - } - lp_assert(lia.is_feasible()); - if (!lia.has_inf_int()) { - lia.settings().stats().m_patches_success++; - m_delay = 0; - m_next_patch = 0; - return lia_move::sat; - } - if (m_patch_cost > 0 && m_num_nbasic_patches * 10 < m_patch_cost) { - m_delay = std::min(20u, m_next_patch++); - } - else { - m_delay = 0; - m_next_patch = 0; - } - return lia_move::undef; -} + // delta has to be integral. + // We need to find delta such that x1/x2 + (a1/a2)*delta is integral (we are going to flip the delta sign later). + // Then a2*x1/x2 + a1*delta is integral, but x2 and x1 are coprime: + // that means that t = a2/x2 is + // integral. We established that a2 = x2*t Then x1 + a1*delta*(x2/a2) = x1 + // + a1*(delta/t) is integral. Taking into account that t and a1 are + // coprime we have delta = t*k, where k is an integer. + rational t = a2 / x2; + // Now we have x1/x2 + (a1/x2)*k is integral, or (x1 + a1*k)/x2 is integral. + // It is equivalent to x1 + a1*k = x2*m, where m is an integer + // We know that a2 and a1 are coprime, and x2 divides a2, so x2 and a1 are + // coprime. We can find u and v such that u*a1 + v*x2 = 1. + rational u, v; + gcd(a1, x2, u, v); + lp_assert(gcd(a1, x2, u, v).is_one()); + lp_assert((x + (a1 / a2) * (-u * t) * x1).is_int()); + // 1 = (u- l*x2 ) * a1 + (v + l*a1)*x2, for every integer l. + rational d = u * t * x1; + // We can prove that x+alpha*d is integral, + // and any other delta, satisfying x+alpha*delta, is equal to d modulo a2. + delta_plus = mod(d, a2); + lp_assert(delta_plus > 0); + delta_minus = delta_plus - a2; + lp_assert(delta_minus < 0); -void int_solver::patcher::patch_nbasic_column(unsigned j) { - impq & val = lrac.m_r_x[j]; - bool inf_l, inf_u; - impq l, u; - mpq m; - bool has_free = lia.get_freedom_interval_for_column(j, inf_l, l, inf_u, u, m); - m_patch_cost += lra.A_r().number_of_non_zeroes_in_column(j); - if (!has_free) { - return; + return true; } - bool m_is_one = m.is_one(); - bool val_is_int = lia.value_is_int(j); + /** + * \brief try to patch the basic column v + */ + bool int_solver::patcher::patch_basic_column_on_row_cell(unsigned v, row_cell const& c) { + if (v == c.var()) + return false; + if (!lra.column_is_int(c.var())) // could use real to patch integer + return false; + if (c.coeff().is_int()) + return false; + mpq a = fractional_part(c.coeff()); + mpq r = fractional_part(lra.get_value(v)); + lp_assert(0 < r && r < 1); + lp_assert(0 < a && a < 1); + mpq delta_plus, delta_minus; + if (!get_patching_deltas(r, a, delta_plus, delta_minus)) + return false; - // check whether value of j is already a multiple of m. - if (val_is_int && (m_is_one || (val.x / m).is_int())) { - return; - } - TRACE("patch_int", - tout << "TARGET j" << j << " -> ["; - if (inf_l) tout << "-oo"; else tout << l; - tout << ", "; - if (inf_u) tout << "oo"; else tout << u; - tout << "]"; - tout << ", m: " << m << ", val: " << val << ", is_int: " << lra.column_is_int(j) << "\n";); - if (m.is_big() || (!inf_l && l.x.is_big()) || (!inf_u && u.x.is_big())) { - return; - } - if (!inf_l) { - l = impq(m_is_one ? ceil(l) : m * ceil(l / m)); - if (inf_u || l <= u) { - TRACE("patch_int", tout << "patching with l: " << l << '\n';); - lra.set_value_for_nbasic_column(j, l); + if (lia.random() % 2) + return try_patch_column(v, c.var(), delta_plus) || + try_patch_column(v, c.var(), delta_minus); + else + return try_patch_column(v, c.var(), delta_minus) || + try_patch_column(v, c.var(), delta_plus); + } + + bool int_solver::patcher::try_patch_column(unsigned v, unsigned j, mpq const& delta) { + const auto & A = lra.A_r(); + if (delta < 0) { + if (lia.has_lower(j) && lia.get_value(j) + impq(delta) < lra.get_lower_bound(j)) + return false; } else { - --m_num_nbasic_patches; - TRACE("patch_int", tout << "not patching " << l << "\n";); + if (lia.has_upper(j) && lia.get_value(j) + impq(delta) > lra.get_upper_bound(j)) + return false; } + for (auto const& c : A.column(j)) { + unsigned row_index = c.var(); + unsigned bj = lrac.m_r_basis[row_index]; + auto old_val = lia.get_value(bj); + auto new_val = old_val - impq(c.coeff()*delta); + if (lia.has_lower(bj) && new_val < lra.get_lower_bound(bj)) + return false; + if (lia.has_upper(bj) && new_val > lra.get_upper_bound(bj)) + return false; + if (old_val.is_int() && !new_val.is_int()){ + return false; // do not waste resources on this case + } + // if bj == v, then, because we are patching the lra.get_value(v), + // we just need to assert that the lra.get_value(v) would be integral. + lp_assert(bj != v || lra.from_model_in_impq_to_mpq(new_val).is_int()); + } + + lra.set_value_for_nbasic_column(j, lia.get_value(j) + impq(delta)); + return true; } - else if (!inf_u) { - u = impq(m_is_one ? floor(u) : m * floor(u / m)); - lra.set_value_for_nbasic_column(j, u); - TRACE("patch_int", tout << "patching with u: " << u << '\n';); - } - else { - lra.set_value_for_nbasic_column(j, impq(0)); - TRACE("patch_int", tout << "patching with 0\n";); - } - ++m_num_nbasic_patches; -} + + void int_solver::patcher::patch_basic_column(unsigned v) { + SASSERT(!lia.is_fixed(v)); + for (auto const& c : lra.basic2row(v)) + if (patch_basic_column_on_row_cell(v, c)) + return; + } + + + int_solver::int_solver(lar_solver& lar_slv) : + lra(lar_slv), + lrac(lra.m_mpq_lar_core_solver), + m_gcd(*this), + m_patcher(*this), + m_number_of_calls(0), + m_hnf_cutter(*this), + m_hnf_cut_period(settings().hnf_cut_period()) { + lra.set_int_solver(this); + } + + // this will allow to enable and disable tracking of the pivot rows + struct check_return_helper { + lar_solver& lra; + bool m_track_touched_rows; + check_return_helper(lar_solver& ls) : + lra(ls), + m_track_touched_rows(lra.touched_rows_are_tracked()) { + lra.track_touched_rows(false); + } + ~check_return_helper() { + lra.track_touched_rows(m_track_touched_rows); + } + }; -int_solver::int_solver(lar_solver& lar_slv) : - lra(lar_slv), - lrac(lra.m_mpq_lar_core_solver), - m_gcd(*this), - m_patcher(*this), - m_number_of_calls(0), - m_hnf_cutter(*this), - m_hnf_cut_period(settings().hnf_cut_period()) { - lra.set_int_solver(this); -} + lia_move int_solver::check(lp::explanation * e) { + SASSERT(lra.ax_is_correct()); + if (!has_inf_int()) + return lia_move::sat; -// this will allow to enable and disable tracking of the pivot rows -struct check_return_helper { - lar_solver& lra; - bool m_track_pivoted_rows; - check_return_helper(lar_solver& ls) : - lra(ls), - m_track_pivoted_rows(lra.get_track_pivoted_rows()) { - lra.set_track_pivoted_rows(false); - } - ~check_return_helper() { - lra.set_track_pivoted_rows(m_track_pivoted_rows); - } -}; + m_t.clear(); + m_k.reset(); + m_ex = e; + m_ex->clear(); + m_upper = false; + m_cut_vars.reset(); + + lia_move r = lia_move::undef; -lia_move int_solver::check(lp::explanation * e) { - SASSERT(lra.ax_is_correct()); - if (!has_inf_int()) return lia_move::sat; + if (m_gcd.should_apply()) + r = m_gcd(); - m_t.clear(); - m_k.reset(); - m_ex = e; - m_ex->clear(); - m_upper = false; - lia_move r = lia_move::undef; + check_return_helper pc(lra); - if (m_gcd.should_apply()) r = m_gcd(); + if (settings().get_cancel_flag()) + return lia_move::undef; - check_return_helper pc(lra); + ++m_number_of_calls; + if (r == lia_move::undef && m_patcher.should_apply()) r = m_patcher(); + if (r == lia_move::undef && should_find_cube()) r = int_cube(*this)(); + if (r == lia_move::undef) lra.move_non_basic_columns_to_bounds(); + if (r == lia_move::undef && should_hnf_cut()) r = hnf_cut(); - if (settings().get_cancel_flag()) - return lia_move::undef; + if (r == lia_move::undef && should_gomory_cut()) r = gomory(*this).get_gomory_cuts(2); - ++m_number_of_calls; - if (r == lia_move::undef && m_patcher.should_apply()) r = m_patcher(); - if (r == lia_move::undef && should_find_cube()) r = int_cube(*this)(); - if (r == lia_move::undef && should_hnf_cut()) r = hnf_cut(); - if (r == lia_move::undef && should_gomory_cut()) r = gomory(*this)(); - if (r == lia_move::undef) r = int_branch(*this)(); - return r; -} + if (r == lia_move::undef) r = int_branch(*this)(); + if (settings().get_cancel_flag()) r = lia_move::undef; + return r; + } -std::ostream& int_solver::display_inf_rows(std::ostream& out) const { - unsigned num = lra.A_r().column_count(); - for (unsigned v = 0; v < num; v++) { - if (column_is_int(v) && !get_value(v).is_int()) { - display_column(out, v); + std::ostream& int_solver::display_inf_rows(std::ostream& out) const { + unsigned num = lra.A_r().column_count(); + for (unsigned v = 0; v < num; v++) { + if (column_is_int(v) && !get_value(v).is_int()) { + display_column(out, v); + } } - } - num = 0; - for (unsigned i = 0; i < lra.A_r().row_count(); i++) { - unsigned j = lrac.m_r_basis[i]; - if (column_is_int_inf(j)) { - num++; - lra.print_row(lra.A_r().m_rows[i], out); - out << "\n"; + num = 0; + for (unsigned i = 0; i < lra.A_r().row_count(); i++) { + unsigned j = lrac.m_r_basis[i]; + if (column_is_int_inf(j)) { + num++; + lra.print_row(lra.A_r().m_rows[i], out); + out << "\n"; + } } + out << "num of int infeasible: " << num << "\n"; + return out; } - out << "num of int infeasible: " << num << "\n"; - return out; -} -bool int_solver::cut_indices_are_columns() const { - for (lar_term::ival p : m_t) { - if (p.column().index() >= lra.A_r().column_count()) - return false; + bool int_solver::cut_indices_are_columns() const { + for (lar_term::ival p : m_t) { + if (p.j() >= lra.A_r().column_count()) + return false; + } + return true; } - return true; -} - -bool int_solver::current_solution_is_inf_on_cut() const { - SASSERT(cut_indices_are_columns()); - const auto & x = lrac.m_r_x; - impq v = m_t.apply(x); - mpq sign = m_upper ? one_of_type() : -one_of_type(); - CTRACE("current_solution_is_inf_on_cut", v * sign <= impq(m_k) * sign, - tout << "m_upper = " << m_upper << std::endl; - tout << "v = " << v << ", k = " << m_k << std::endl; - ); - return v * sign > impq(m_k) * sign; -} -bool int_solver::has_inf_int() const { - return lra.has_inf_int(); -} - -constraint_index int_solver::column_upper_bound_constraint(unsigned j) const { - return lra.get_column_upper_bound_witness(j); -} + bool int_solver::current_solution_is_inf_on_cut() const { + SASSERT(cut_indices_are_columns()); + const auto & x = lrac.m_r_x; + impq v = m_t.apply(x); + mpq sign = m_upper ? one_of_type() : -one_of_type(); + CTRACE("current_solution_is_inf_on_cut", v * sign <= impq(m_k) * sign, + tout << "m_upper = " << m_upper << std::endl; + tout << "v = " << v << ", k = " << m_k << std::endl; + tout << "term:";lra.print_term(m_t, tout) << "\n"; + ); + return v * sign > impq(m_k) * sign; + } -constraint_index int_solver::column_lower_bound_constraint(unsigned j) const { - return lra.get_column_lower_bound_witness(j); -} + bool int_solver::has_inf_int() const { + return lra.has_inf_int(); + } -unsigned int_solver::row_of_basic_column(unsigned j) const { - return lra.row_of_basic_column(j); -} + u_dependency* int_solver::column_upper_bound_constraint(unsigned j) const { + return lra.get_column_upper_bound_witness(j); + } -lp_settings& int_solver::settings() { - return lra.settings(); -} + u_dependency* int_solver::column_lower_bound_constraint(unsigned j) const { + return lra.get_column_lower_bound_witness(j); + } -const lp_settings& int_solver::settings() const { - return lra.settings(); -} + unsigned int_solver::row_of_basic_column(unsigned j) const { + return lra.row_of_basic_column(j); + } -bool int_solver::column_is_int(column_index const& j) const { - return lra.column_is_int(j); -} + lp_settings& int_solver::settings() { + return lra.settings(); + } -bool int_solver::is_real(unsigned j) const { - return !column_is_int(j); -} + const lp_settings& int_solver::settings() const { + return lra.settings(); + } -bool int_solver::value_is_int(unsigned j) const { - return lra.column_value_is_int(j); -} + bool int_solver::column_is_int(lpvar j) const { + return lra.column_is_int(j); + } -unsigned int_solver::random() { - return settings().random_next(); -} + bool int_solver::is_real(unsigned j) const { + return !column_is_int(j); + } -const impq& int_solver::upper_bound(unsigned j) const { - return lra.column_upper_bound(j); -} + bool int_solver::value_is_int(unsigned j) const { + return lra.column_value_is_int(j); + } -const impq& int_solver::lower_bound(unsigned j) const { - return lra.column_lower_bound(j); -} + unsigned int_solver::random() { + return settings().random_next(); + } -bool int_solver::is_term(unsigned j) const { - return lra.column_corresponds_to_term(j); -} + const impq& int_solver::upper_bound(unsigned j) const { + return lra.column_upper_bound(j); + } -unsigned int_solver::column_count() const { - return lra.column_count(); -} + const impq& int_solver::lower_bound(unsigned j) const { + return lra.column_lower_bound(j); + } -bool int_solver::should_find_cube() { - return m_number_of_calls % settings().m_int_find_cube_period == 0; -} + bool int_solver::is_term(unsigned j) const { + return lra.column_has_term(j); + } + unsigned int_solver::column_count() const { + return lra.column_count(); + } -bool int_solver::should_gomory_cut() { - return m_number_of_calls % settings().m_int_gomory_cut_period == 0; -} + bool int_solver::should_find_cube() { + return m_number_of_calls % settings().m_int_find_cube_period == 0; + } -bool int_solver::should_hnf_cut() { - return settings().enable_hnf() && m_number_of_calls % m_hnf_cut_period == 0; -} + bool int_solver::should_gomory_cut() { + return m_number_of_calls % settings().m_int_gomory_cut_period == 0; + } -lia_move int_solver::hnf_cut() { - lia_move r = m_hnf_cutter.make_hnf_cut(); - if (r == lia_move::undef) { - m_hnf_cut_period *= 2; + bool int_solver::should_hnf_cut() { + return settings().enable_hnf() && m_number_of_calls % m_hnf_cut_period == 0; } - else { - m_hnf_cut_period = settings().hnf_cut_period(); + + lia_move int_solver::hnf_cut() { + lia_move r = m_hnf_cutter.make_hnf_cut(); + if (r == lia_move::undef) + m_hnf_cut_period *= 2; + else + m_hnf_cut_period = settings().hnf_cut_period(); + return r; } - return r; -} -bool int_solver::has_lower(unsigned j) const { - switch (lrac.m_column_types()[j]) { - case column_type::fixed: - case column_type::boxed: - case column_type::lower_bound: - return true; - default: - return false; + bool int_solver::has_lower(unsigned j) const { + switch (lrac.m_column_types()[j]) { + case column_type::fixed: + case column_type::boxed: + case column_type::lower_bound: + return true; + default: + return false; + } } -} -bool int_solver::has_upper(unsigned j) const { - switch (lrac.m_column_types()[j]) { - case column_type::fixed: - case column_type::boxed: - case column_type::upper_bound: - return true; - default: - return false; + bool int_solver::has_upper(unsigned j) const { + switch (lrac.m_column_types()[j]) { + case column_type::fixed: + case column_type::boxed: + case column_type::upper_bound: + return true; + default: + return false; + } } -} -static void set_lower(impq & l, bool & inf_l, impq const & v ) { - if (inf_l || v > l) { - l = v; - inf_l = false; + static void set_lower(impq & l, bool & inf_l, impq const & v ) { + if (inf_l || v > l) { + l = v; + inf_l = false; + } } -} -static void set_upper(impq & u, bool & inf_u, impq const & v) { - if (inf_u || v < u) { - u = v; - inf_u = false; + static void set_upper(impq & u, bool & inf_u, impq const & v) { + if (inf_u || v < u) { + u = v; + inf_u = false; + } } -} -// this function assumes that all basic columns dependend on j are feasible -bool int_solver::get_freedom_interval_for_column(unsigned j, bool & inf_l, impq & l, bool & inf_u, impq & u, mpq & m) { - if (lrac.m_r_heading[j] >= 0) // the basic var - return false; + // this function assumes that all basic columns dependend on j are feasible + bool int_solver::get_freedom_interval_for_column(unsigned j, bool & inf_l, impq & l, bool & inf_u, impq & u, mpq & m) { + if (lrac.m_r_heading[j] >= 0 || is_fixed(j)) // basic or fixed var + return false; - TRACE("random_update", display_column(tout, j) << ", is_int = " << column_is_int(j) << "\n";); - impq const & xj = get_value(j); + TRACE("random_update", display_column(tout, j) << ", is_int = " << column_is_int(j) << "\n";); + impq const & xj = get_value(j); - inf_l = true; - inf_u = true; - l = u = zero_of_type(); - m = mpq(1); + inf_l = true; + inf_u = true; + l = u = zero_of_type(); + m = mpq(1); - if (has_lower(j)) - set_lower(l, inf_l, lower_bound(j) - xj); + if (has_lower(j)) + set_lower(l, inf_l, lower_bound(j) - xj); - if (has_upper(j)) - set_upper(u, inf_u, upper_bound(j) - xj); + if (has_upper(j)) + set_upper(u, inf_u, upper_bound(j) - xj); - const auto & A = lra.A_r(); - TRACE("random_update", tout << "m = " << m << "\n";); - - auto delta = [](mpq const& x, impq const& y, impq const& z) { - if (x.is_one()) - return y - z; - if (x.is_minus_one()) - return z - y; - return (y - z) / x; - }; - - for (auto c : A.column(j)) { - unsigned row_index = c.var(); - const mpq & a = c.coeff(); - unsigned i = lrac.m_r_basis[row_index]; - impq const & xi = get_value(i); - lp_assert(lrac.m_r_solver.column_is_feasible(i)); - if (column_is_int(i) && !a.is_int()) - m = lcm(m, denominator(a)); - - - if (!inf_l && !inf_u) { - if (l == u) - continue; + const auto & A = lra.A_r(); + TRACE("random_update", tout << "m = " << m << "\n";); + + auto delta = [](mpq const& x, impq const& y, impq const& z) { + if (x.is_one()) + return y - z; + if (x.is_minus_one()) + return z - y; + return (y - z) / x; + }; + + for (auto c : A.column(j)) { + unsigned row_index = c.var(); + const mpq & a = c.coeff(); + unsigned i = lrac.m_r_basis[row_index]; + impq const & xi = get_value(i); + lp_assert(lrac.m_r_solver.column_is_feasible(i)); + if (column_is_int(i) && !a.is_int() && xi.is_int()) + m = lcm(m, denominator(a)); + + if (!inf_l && !inf_u && l == u) + continue; + + if (a.is_neg()) { + if (has_lower(i)) + set_lower(l, inf_l, delta(a, xi, lra.get_lower_bound(i))); + if (has_upper(i)) + set_upper(u, inf_u, delta(a, xi, lra.get_upper_bound(i))); + } + else { + if (has_upper(i)) + set_lower(l, inf_l, delta(a, xi, lra.get_upper_bound(i))); + if (has_lower(i)) + set_upper(u, inf_u, delta(a, xi, lra.get_lower_bound(i))); + } } - if (a.is_neg()) { - if (has_lower(i)) - set_lower(l, inf_l, delta(a, xi, lrac.m_r_lower_bounds()[i])); - if (has_upper(i)) - set_upper(u, inf_u, delta(a, xi, lrac.m_r_upper_bounds()[i])); - } - else { - if (has_upper(i)) - set_lower(l, inf_l, delta(a, xi, lrac.m_r_upper_bounds()[i])); - if (has_lower(i)) - set_upper(u, inf_u, delta(a, xi, lrac.m_r_lower_bounds()[i])); - } - } + l += xj; + u += xj; - l += xj; - u += xj; - - TRACE("freedom_interval", - tout << "freedom variable for:\n"; - tout << lra.get_variable_name(j); - tout << "["; - if (inf_l) tout << "-oo"; else tout << l; - tout << "; "; - if (inf_u) tout << "oo"; else tout << u; - tout << "]\n"; - tout << "val = " << get_value(j) << "\n"; - tout << "return " << (inf_l || inf_u || l <= u); - ); - return (inf_l || inf_u || l <= u); -} + TRACE("freedom_interval", + tout << "freedom variable for:\n"; + tout << lra.get_variable_name(j); + tout << "["; + if (inf_l) tout << "-oo"; else tout << l; + tout << "; "; + if (inf_u) tout << "oo"; else tout << u; + tout << "]\n"; + tout << "val = " << get_value(j) << "\n"; + tout << "return " << (inf_l || inf_u || l <= u); + ); + return (inf_l || inf_u || l <= u); + } -bool int_solver::is_feasible() const { - lp_assert( - lrac.m_r_solver.calc_current_x_is_feasible_include_non_basis() == - lrac.m_r_solver.current_x_is_feasible()); - return lrac.m_r_solver.current_x_is_feasible(); -} + bool int_solver::is_feasible() const { + lp_assert( + lrac.m_r_solver.calc_current_x_is_feasible_include_non_basis() == + lrac.m_r_solver.current_x_is_feasible()); + return lrac.m_r_solver.current_x_is_feasible(); + } -const impq & int_solver::get_value(unsigned j) const { - return lrac.m_r_x[j]; -} + const impq & int_solver::get_value(unsigned j) const { + return lrac.m_r_x[j]; + } -std::ostream& int_solver::display_column(std::ostream & out, unsigned j) const { - return lrac.m_r_solver.print_column_info(j, out); -} + std::ostream& int_solver::display_column(std::ostream & out, unsigned j) const { + return lrac.m_r_solver.print_column_info(j, out); + } -bool int_solver::column_is_int_inf(unsigned j) const { - return column_is_int(j) && (!value_is_int(j)); -} + bool int_solver::column_is_int_inf(unsigned j) const { + return column_is_int(j) && (!value_is_int(j)); + } -bool int_solver::is_base(unsigned j) const { - return lrac.m_r_heading[j] >= 0; -} + bool int_solver::is_base(unsigned j) const { + return lrac.m_r_heading[j] >= 0; + } -bool int_solver::is_boxed(unsigned j) const { - return lrac.m_column_types[j] == column_type::boxed; -} + bool int_solver::is_boxed(unsigned j) const { + return lrac.m_column_types[j] == column_type::boxed; + } -bool int_solver::is_fixed(unsigned j) const { - return lrac.m_column_types[j] == column_type::fixed; -} + bool int_solver::is_fixed(unsigned j) const { + return lrac.m_column_types[j] == column_type::fixed; + } -bool int_solver::is_free(unsigned j) const { - return lrac.m_column_types[j] == column_type::free_column; -} + bool int_solver::is_free(unsigned j) const { + return lrac.m_column_types[j] == column_type::free_column; + } -bool int_solver::at_bound(unsigned j) const { - auto & mpq_solver = lrac.m_r_solver; - switch (mpq_solver.m_column_types[j] ) { - case column_type::fixed: - case column_type::boxed: - return - mpq_solver.m_lower_bounds[j] == get_value(j) || - mpq_solver.m_upper_bounds[j] == get_value(j); - case column_type::lower_bound: - return mpq_solver.m_lower_bounds[j] == get_value(j); - case column_type::upper_bound: - return mpq_solver.m_upper_bounds[j] == get_value(j); - default: - return false; + bool int_solver::at_bound(unsigned j) const { + auto & mpq_solver = lrac.m_r_solver; + switch (mpq_solver.m_column_types[j] ) { + case column_type::fixed: + case column_type::boxed: + return + mpq_solver.m_lower_bounds[j] == get_value(j) || + mpq_solver.m_upper_bounds[j] == get_value(j); + case column_type::lower_bound: + return mpq_solver.m_lower_bounds[j] == get_value(j); + case column_type::upper_bound: + return mpq_solver.m_upper_bounds[j] == get_value(j); + default: + return false; + } } -} -bool int_solver::at_lower(unsigned j) const { - auto & mpq_solver = lrac.m_r_solver; - switch (mpq_solver.m_column_types[j] ) { - case column_type::fixed: - case column_type::boxed: - case column_type::lower_bound: - return mpq_solver.m_lower_bounds[j] == get_value(j); - default: - return false; + bool int_solver::at_lower(unsigned j) const { + auto & mpq_solver = lrac.m_r_solver; + switch (mpq_solver.m_column_types[j] ) { + case column_type::fixed: + case column_type::boxed: + case column_type::lower_bound: + return mpq_solver.m_lower_bounds[j] == get_value(j); + default: + return false; + } } -} -bool int_solver::at_upper(unsigned j) const { - auto & mpq_solver = lrac.m_r_solver; - switch (mpq_solver.m_column_types[j] ) { - case column_type::fixed: - case column_type::boxed: - case column_type::upper_bound: - return mpq_solver.m_upper_bounds[j] == get_value(j); - default: - return false; + bool int_solver::at_upper(unsigned j) const { + auto & mpq_solver = lrac.m_r_solver; + switch (mpq_solver.m_column_types[j] ) { + case column_type::fixed: + case column_type::boxed: + case column_type::upper_bound: + return mpq_solver.m_upper_bounds[j] == get_value(j); + default: + return false; + } } -} -std::ostream & int_solver::display_row(std::ostream & out, lp::row_strip const & row) const { -bool first = true; - auto & rslv = lrac.m_r_solver; -for (const auto &c : row) - { - if (is_fixed(c.var())) - { - if (!get_value(c.var()).is_zero()) - { - impq val = get_value(c.var()) * c.coeff(); - if (!first && val.is_pos()) - out << "+"; - if (val.y.is_zero()) - out << val.x << " "; - else - out << val << " "; + std::ostream & int_solver::display_row(std::ostream & out, lp::row_strip const & row) const { + bool first = true; + auto & rslv = lrac.m_r_solver; + for (const auto &c : row) { + if (is_fixed(c.var())) { + if (!get_value(c.var()).is_zero()) { + impq val = get_value(c.var()) * c.coeff(); + if (!first && val.is_pos()) + out << "+"; + if (val.y.is_zero()) + out << val.x << " "; + else + out << val << " "; + } + first = false; + continue; } - first = false; - continue; - } - if (c.coeff().is_one()) - { - if (!first) - out << "+"; - } - else if (c.coeff().is_minus_one()) - out << "-"; - else - { - if (c.coeff().is_pos()) - { + if (c.coeff().is_one()) { if (!first) out << "+"; } - if (c.coeff().is_big()) - { - out << " b*"; + else if (c.coeff().is_minus_one()) + out << "-"; + else { + if (c.coeff().is_pos() && !first) + out << "+"; + if (c.coeff().is_big()) + out << " b*"; + else + out << c.coeff(); } - else - out << c.coeff(); + out << rslv.column_name(c.var()) << " "; + first = false; } - out << rslv.column_name(c.var()) << " "; - first = false; - } - out << "\n"; - for (const auto &c : row) - { - if (is_fixed(c.var())) - continue; - rslv.print_column_info(c.var(), out); - if (is_base(c.var())) - out << "j" << c.var() << " base\n"; - } - return out; -} -std::ostream& int_solver::display_row_info(std::ostream & out, unsigned row_index) const { - auto & rslv = lrac.m_r_solver; - auto const& row = rslv.m_A.m_rows[row_index]; - return display_row(out, row); -} + out << "\n"; + for (const auto &c : row) { + if (is_fixed(c.var())) + continue; + rslv.print_column_info(c.var(), out); + if (is_base(c.var())) + out << "j" << c.var() << " base\n"; + } + return out; + } + + std::ostream& int_solver::display_row_info(std::ostream & out, unsigned row_index) const { + auto & rslv = lrac.m_r_solver; + auto const& row = rslv.m_A.m_rows[row_index]; + return display_row(out, row); + } + + bool int_solver::shift_var(unsigned j, unsigned range) { + if (is_fixed(j) || is_base(j)) + return false; + if (settings().get_cancel_flag()) + return false; + bool inf_l = false, inf_u = false; + impq l, u; + mpq m; + if (!get_freedom_interval_for_column(j, inf_l, l, inf_u, u, m)) + return false; + if (settings().get_cancel_flag()) + return false; + const impq & x = get_value(j); + // x, the value of j column, might be shifted on a multiple of m + if (inf_l && inf_u) { + impq new_val = m * impq(random() % (range + 1)) + x; + lra.set_value_for_nbasic_column(j, new_val); + return true; + } + if (column_is_int(j)) { + if (!inf_l) + l = impq(ceil(l)); + if (!inf_u) + u = impq(floor(u)); + } + if (!inf_l && !inf_u && l >= u) + return false; -bool int_solver::shift_var(unsigned j, unsigned range) { - if (is_fixed(j) || is_base(j)) - return false; - if (settings().get_cancel_flag()) - return false; - bool inf_l = false, inf_u = false; - impq l, u; - mpq m; - VERIFY(get_freedom_interval_for_column(j, inf_l, l, inf_u, u, m) || settings().get_cancel_flag()); - if (settings().get_cancel_flag()) - return false; - const impq & x = get_value(j); - // x, the value of j column, might be shifted on a multiple of m - if (inf_l && inf_u) { - impq new_val = m * impq(random() % (range + 1)) + x; - lra.set_value_for_nbasic_column(j, new_val); - return true; - } - if (column_is_int(j)) { - if (!inf_l) { - l = impq(ceil(l)); + if (inf_u) { + SASSERT(!inf_l); + impq new_val = x + m * impq(random() % (range + 1)); + lra.set_value_for_nbasic_column(j, new_val); + return true; } - if (!inf_u) { - u = impq(floor(u)); + + if (inf_l) { + SASSERT(!inf_u); + impq new_val = x - m * impq(random() % (range + 1)); + lra.set_value_for_nbasic_column(j, new_val); + return true; } - } - if (!inf_l && !inf_u && l >= u) - return false; - if (inf_u) { - SASSERT(!inf_l); - impq new_val = x + m * impq(random() % (range + 1)); + SASSERT(!inf_l && !inf_u); + // The shift has to be a multiple of m: let us look for s, such that the shift is m*s. + // We have new_val = x+m*s <= u, so m*s <= u-x and, finally, s <= floor((u- x)/m) = a + // The symmetric reasoning gives us s >= ceil((l-x)/m) = b + // We randomly pick s in the segment [b, a] + mpq a = floor((u - x) / m); + mpq b = ceil((l - x) / m); + mpq r = a - b; + if (!r.is_pos()) + return false; + TRACE("int_solver", tout << "a = " << a << ", b = " << b << ", r = " << r<< ", m = " << m << "\n";); + if (r < mpq(range)) + range = static_cast(r.get_uint64()); + + mpq s = b + mpq(random() % (range + 1)); + impq new_val = x + m * impq(s); + TRACE("int_solver", tout << "new_val = " << new_val << "\n";); + SASSERT(l <= new_val && new_val <= u); lra.set_value_for_nbasic_column(j, new_val); return true; } - if (inf_l) { - SASSERT(!inf_u); - impq new_val = x - m * impq(random() % (range + 1)); - lra.set_value_for_nbasic_column(j, new_val); - return true; - } - SASSERT(!inf_l && !inf_u); - // The shift has to be a multiple of m: let us look for s, such that the shift is m*s. - // We have new_val = x+m*s <= u, so m*s <= u-x and, finally, s <= floor((u- x)/m) = a - // The symmetric reasoning gives us s >= ceil((l-x)/m) = b - // We randomly pick s in the segment [b, a] - mpq a = floor((u - x) / m); - mpq b = ceil((l - x) / m); - mpq r = a - b; - if (!r.is_pos()) - return false; - TRACE("int_solver", tout << "a = " << a << ", b = " << b << ", r = " << r<< ", m = " << m << "\n";); - if (r < mpq(range)) - range = static_cast(r.get_uint64()); - - mpq s = b + mpq(random() % (range + 1)); - impq new_val = x + m * impq(s); - TRACE("int_solver", tout << "new_val = " << new_val << "\n";); - SASSERT(l <= new_val && new_val <= u); - lra.set_value_for_nbasic_column(j, new_val); - return true; -} + int int_solver::select_int_infeasible_var() { + int r_small_box = -1; + int r_small_value = -1; + int r_any_value = -1; + unsigned n_small_box = 1; + unsigned n_small_value = 1; + unsigned n_any_value = 1; + mpq range; + mpq new_range; + mpq small_value(1024); + lar_core_solver & lcs = lra.m_mpq_lar_core_solver; + unsigned prev_usage = 0; -// not used: -bool int_solver::non_basic_columns_are_at_bounds() const { - for (unsigned j : lrac.m_r_nbasis) { - auto & val = lrac.m_r_x[j]; - switch (lrac.m_column_types()[j]) { - case column_type::boxed: - if (val != lrac.m_r_lower_bounds()[j] && val != lrac.m_r_upper_bounds()[j]) - return false; - break; - case column_type::lower_bound: - if (val != lrac.m_r_lower_bounds()[j]) - return false; - break; - case column_type::upper_bound: - if (val != lrac.m_r_upper_bounds()[j]) - return false; - break; - default: - if (column_is_int(j) && !val.is_int()) { - return false; - } - } - } - return true; -} + auto add_column = [&](bool improved, int& result, unsigned& n, unsigned j) { + if (result == -1) + result = j; + else if (improved && ((random() % (++n)) == 0)) + result = j; + }; + + for (unsigned j : lra.r_basis()) { + if (!column_is_int_inf(j)) + continue; + if (m_cut_vars.contains(j)) + continue; -int int_solver::select_int_infeasible_var() { - int result = -1; - mpq range; - mpq new_range; - mpq small_value(1024); - unsigned n = 0; - lar_core_solver & lcs = lra.m_mpq_lar_core_solver; - unsigned prev_usage = 0; // to quiet down the compile - unsigned k = 0; - unsigned usage; - unsigned j; - - enum state { small_box, is_small_value, any_value, not_found }; - state st = not_found; - - // 1. small box - // 2. small value - // 3. any value - for (; k < lra.r_basis().size(); k++) { - j = lra.r_basis()[k]; - if (!column_is_int_inf(j)) - continue; - usage = lra.usage_in_terms(j); - if (is_boxed(j) && (new_range = lcs.m_r_upper_bounds()[j].x - lcs.m_r_lower_bounds()[j].x - rational(2*usage)) <= small_value) { SASSERT(!is_fixed(j)); - if (st != small_box) { - n = 0; - st = small_box; - } - if (n == 0 || new_range < range) { - result = j; - range = new_range; - n = 1; + + unsigned usage = lra.usage_in_terms(j); + if (is_boxed(j) && (new_range = lcs.m_r_upper_bounds()[j].x - lcs.m_r_lower_bounds()[j].x - rational(2*usage)) <= small_value) { + + bool improved = new_range <= range || r_small_box == -1; + if (improved) + range = new_range; + add_column(improved, r_small_box, n_small_box, j); + continue; } - else if (new_range == range && (random() % (++n) == 0)) { - result = j; + impq const& value = get_value(j); + if (abs(value.x) < small_value || + (has_upper(j) && small_value > upper_bound(j).x - value.x) || + (has_lower(j) && small_value > value.x - lower_bound(j).x)) { + TRACE("int_solver", tout << "small j" << j << "\n"); + add_column(true, r_small_value, n_small_value, j); + continue; } - continue; + TRACE("int_solver", tout << "any j" << j << "\n"); + add_column(usage >= prev_usage, r_any_value, n_any_value, j); + if (usage > prev_usage) + prev_usage = usage; + } + + if (r_small_box != -1 && (random() % 3 != 0)) + return r_small_box; + if (r_small_value != -1 && (random() % 3) != 0) + return r_small_value; + if (r_any_value != -1) + return r_any_value; + if (r_small_box != -1) + return r_small_box; + return r_small_value; + } + + void int_solver::simplify(std::function& is_root) { + return; + +#if 0 + + // in-processing simplification can go here, such as bounds improvements. + + if (!lra.is_feasible()) { + lra.find_feasible_solution(); + if (!lra.is_feasible()) + return; } - if (st == small_box) - continue; - impq const& value = get_value(j); - if (abs(value.x) < small_value || - (has_upper(j) && small_value > upper_bound(j).x - value.x) || - (has_lower(j) && small_value > value.x - lower_bound(j).x)) { - if (st != is_small_value) { - n = 0; - st = is_small_value; + + + lp::explanation exp; + m_ex = &exp; + m_t.clear(); + m_k.reset(); + + if (has_inf_int()) + local_gomory(5); + + stopwatch sw; + explanation exp1, exp2; + + // + // identify equalities + // + + m_equalities.reset(); + map value2roots; + + vector> coeffs; + coeffs.push_back({-rational::one(), 0}); + coeffs.push_back({rational::one(), 0}); + + num_checks = 0; + + // make sure values are sampled with respect to the same state of the Simplex. + vector values; + for (lpvar j = 0; j < lra.column_count(); ++j) + values.push_back(get_value(j).x); + + sw.reset(); + sw.start(); + start = random(); + for (lpvar j0 = 0; j0 < lra.column_count(); ++j0) { + lpvar j = (j0 + start) % lra.column_count(); + if (is_fixed(j)) + continue; + if (!lra.column_is_int(j)) + continue; + if (!is_root(j)) + continue; + rational value = values[j]; + if (!value2roots.contains(value)) { + unsigned_vector vec; + vec.push_back(j); + value2roots.insert(value, vec); + continue; } - if (random() % (++n) == 0) - result = j; + auto& roots = value2roots.find(value); + bool has_eq = false; + // + // Super inefficient check. There are better ways. + // 1. call into equality finder: + // the cheap equality finder can also be used. + // 2. value sweeping: + // update partitions of values based on feasible tableaus + // instead of having just the values vector use the values + // collected when the find_feasible_solution succeeds with + // a new assignment. + // 3. a more expensive equality finder: + // use the tableau to extract equalities from tight rows. + // If x = y is implied, there is a set of rows that link x and y + // and such that the variables are at their bounds. + // 4. retain information between calls: + // If simplification is invoked at the same backtracking level (or above) + // form the previous call and it is established that x <= y (but not x == y), then no need to + // recheck the inequality x <= y. + for (auto k : roots) { + bool le = false, ge = false; + u_dependency* dep = nullptr; + lra.push(); + coeffs[0].second = j; + coeffs[1].second = k; + lp::lpvar term_index = lra.add_term(coeffs, UINT_MAX); + term_index = lra.map_term_index_to_column_index(term_index); + lra.push(); + lra.update_column_type_and_bound(term_index, lp::lconstraint_kind::GE, mpq(1), nullptr); + lra.find_feasible_solution(); + if (!lra.is_feasible()) { + lra.get_infeasibility_explanation(exp1); + le = true; + } + lra.pop(1); + ++num_checks; + if (le) { + lra.push(); + lra.update_column_type_and_bound(term_index, lp::lconstraint_kind::LE, mpq(-1), nullptr); + lra.find_feasible_solution(); + if (!lra.is_feasible()) { + lra.get_infeasibility_explanation(exp2); + exp1.add_expl(exp2); + ge = true; + } + lra.pop(1); + ++num_checks; + } + lra.pop(1); + if (le && ge) { + has_eq = true; + m_equalities.push_back({j, k, exp1}); + break; + } + // artificial throttle. + if (num_checks > 10000) + break; + } + if (!has_eq) + roots.push_back(j); + + // artificial throttle. + if (num_checks > 10000) + break; } - if (st == is_small_value) - continue; - SASSERT(st == not_found || st == any_value); - st = any_value; - if (n == 0 /*|| usage > prev_usage*/) { - result = j; - prev_usage = usage; - n = 1; - } - else if (usage > 0 && /*usage == prev_usage && */ (random() % (++n) == 0)) - result = j; - } - - return result; -} + sw.stop(); + std::cout << "equalities " << m_equalities.size() << " num checks " << num_checks << " time: " << sw.get_seconds() << "\n"; + std::cout.flush(); + + // + // Cuts? Eg. for 0-1 variables or bounded integers? + // +#endif + } + + + } diff --git a/src/math/lp/int_solver.h b/src/math/lp/int_solver.h index 822e1cf1e5f..524f8fb2832 100644 --- a/src/math/lp/int_solver.h +++ b/src/math/lp/int_solver.h @@ -20,7 +20,7 @@ Revision History: #pragma once #include "math/lp/lp_settings.h" #include "math/lp/static_matrix.h" -#include "math/lp/u_set.h" +#include "util/uint_set.h" #include "math/lp/lar_term.h" #include "math/lp/lar_constraints.h" #include "math/lp/hnf_cutter.h" @@ -33,7 +33,7 @@ class lar_solver; class lar_core_solver; class int_solver { - friend class create_cut; + friend struct create_cut; friend class gomory; friend class int_cube; friend class int_branch; @@ -44,17 +44,16 @@ class int_solver { int_solver& lia; lar_solver& lra; lar_core_solver& lrac; - unsigned m_num_nbasic_patches; - unsigned m_patch_cost; - unsigned m_next_patch; - unsigned m_delay; public: patcher(int_solver& lia); - bool should_apply(); - lia_move operator()(); - void patch_nbasic_column(unsigned j); + bool should_apply() const { return true; } + lia_move operator()() { return patch_basic_columns(); } + void patch_basic_column(unsigned j); + bool try_patch_column(unsigned v, unsigned j, mpq const& delta); + unsigned count_non_int(); private: - lia_move patch_nbasic_columns(); + bool patch_basic_column_on_row_cell(unsigned v, row_cell const& c); + lia_move patch_basic_columns(); }; lar_solver& lra; @@ -64,10 +63,13 @@ class int_solver { unsigned m_number_of_calls; lar_term m_t; // the term to return in the cut mpq m_k; // the right side of the cut + bool m_upper; // cut is an upper bound explanation *m_ex; // the conflict explanation - bool m_upper; // we have a cut m_t*x <= k if m_upper is true nad m_t*x >= k otherwise hnf_cutter m_hnf_cutter; unsigned m_hnf_cut_period; + unsigned_vector m_cut_vars; // variables that should not be selected for cuts + + vector m_equalities; public: int_solver(lar_solver& lp); @@ -81,16 +83,17 @@ class int_solver { bool is_real(unsigned j) const; const impq & lower_bound(unsigned j) const; const impq & upper_bound(unsigned j) const; - bool column_is_int(column_index const& j) const; + bool column_is_int(lpvar j) const; const impq & get_value(unsigned j) const; bool at_lower(unsigned j) const; bool at_upper(unsigned j) const; - + void simplify(std::function& is_root); + vector const& equalities() const { return m_equalities; } + private: // lia_move patch_nbasic_columns(); bool get_freedom_interval_for_column(unsigned j, bool & inf_l, impq & l, bool & inf_u, impq & u, mpq & m); bool is_boxed(unsigned j) const; - bool is_fixed(unsigned j) const; bool is_free(unsigned j) const; bool value_is_int(unsigned j) const; bool is_feasible() const; @@ -106,13 +109,13 @@ class int_solver { bool has_lower(unsigned j) const; bool has_upper(unsigned j) const; unsigned row_of_basic_column(unsigned j) const; - bool non_basic_columns_are_at_bounds() const; bool cut_indices_are_columns() const; - + public: + bool is_fixed(unsigned j) const; std::ostream& display_column(std::ostream & out, unsigned j) const; - constraint_index column_upper_bound_constraint(unsigned j) const; - constraint_index column_lower_bound_constraint(unsigned j) const; + u_dependency* column_upper_bound_constraint(unsigned j) const; + u_dependency* column_lower_bound_constraint(unsigned j) const; bool current_solution_is_inf_on_cut() const; bool shift_var(unsigned j, unsigned range); @@ -125,12 +128,9 @@ class int_solver { public: bool is_term(unsigned j) const; unsigned column_count() const; - bool all_columns_are_bounded() const; - void find_feasible_solution(); lia_move hnf_cut(); - void patch_nbasic_column(unsigned j) { m_patcher.patch_nbasic_column(j); } int select_int_infeasible_var(); - - }; + +}; } diff --git a/src/math/lp/lar_constraints.h b/src/math/lp/lar_constraints.h index f8cffbe5793..b9069625f5c 100644 --- a/src/math/lp/lar_constraints.h +++ b/src/math/lp/lar_constraints.h @@ -1,21 +1,10 @@ /*++ Copyright (c) 2017 Microsoft Corporation -Module Name: - - - -Abstract: - - - Author: Lev Nachmanson (levnach) -Revision History: - - --*/ #pragma once @@ -27,7 +16,7 @@ Revision History: #include "util/region.h" #include "util/stacked_value.h" #include "math/lp/lp_utils.h" -#include "math/lp/ul_pair.h" +#include "math/lp/column.h" #include "math/lp/lar_term.h" #include "math/lp/column_namer.h" namespace lp { @@ -53,15 +42,19 @@ class lar_base_constraint { mpq m_right_side; bool m_active; unsigned m_j; -public: + u_dependency* m_dep; + + public: - virtual vector> coeffs() const = 0; - lar_base_constraint(unsigned j, lconstraint_kind kind, const mpq& right_side) :m_kind(kind), m_right_side(right_side), m_active(false), m_j(j) {} + virtual vector> coeffs() const = 0; + lar_base_constraint(unsigned j, lconstraint_kind kind, u_dependency* dep, const mpq& right_side) : + m_kind(kind), m_right_side(right_side), m_active(false), m_j(j), m_dep(dep) {} virtual ~lar_base_constraint() = default; lconstraint_kind kind() const { return m_kind; } mpq const& rhs() const { return m_right_side; } unsigned column() const { return m_j; } + u_dependency* dep() const { return m_dep; } void activate() { m_active = true; } void deactivate() { m_active = false; } @@ -73,11 +66,11 @@ class lar_base_constraint { class lar_var_constraint: public lar_base_constraint { public: - lar_var_constraint(unsigned j, lconstraint_kind kind, const mpq& right_side) : - lar_base_constraint(j, kind, right_side) {} + lar_var_constraint(unsigned j, lconstraint_kind kind, u_dependency* dep, const mpq& right_side) : + lar_base_constraint(j, kind, dep, right_side) {} - vector> coeffs() const override { - vector> ret; + vector> coeffs() const override { + vector> ret; ret.push_back(std::make_pair(one_of_type(), column())); return ret; } @@ -88,20 +81,21 @@ class lar_var_constraint: public lar_base_constraint { class lar_term_constraint: public lar_base_constraint { const lar_term * m_term; public: - lar_term_constraint(unsigned j, const lar_term *t, lconstraint_kind kind, const mpq& right_side) : - lar_base_constraint(j, kind, right_side), m_term(t) {} + lar_term_constraint(unsigned j, const lar_term* t, lconstraint_kind kind, u_dependency* dep, const mpq& right_side) : + lar_base_constraint(j, kind, dep, right_side), m_term(t) {} - vector> coeffs() const override { return m_term->coeffs_as_vector(); } + vector> coeffs() const override { return m_term->coeffs_as_vector(); } unsigned size() const override { return m_term->size();} }; class constraint_set { region m_region; column_namer& m_namer; + u_dependency_manager& m_dep_manager; vector m_constraints; stacked_value m_constraint_count; unsigned_vector m_active; - stacked_value m_active_lim; + stacked_value m_active_lim; constraint_index add(lar_base_constraint* c) { constraint_index ci = m_constraints.size(); @@ -137,9 +131,15 @@ class constraint_set { return out << "constraint " << T_to_string(ci) << " is not found" << std::endl; } + u_dependency* mk_dep() { + return m_dep_manager.mk_leaf(m_constraints.size()); + } + public: - constraint_set(column_namer& cn): - m_namer(cn) {} + constraint_set(u_dependency_manager& d, column_namer& cn): + m_namer(cn), + m_dep_manager(d) + {} ~constraint_set() { for (auto* c : m_constraints) @@ -168,12 +168,13 @@ class constraint_set { m_region.pop_scope(k); } - constraint_index add_var_constraint(var_index j, lconstraint_kind k, mpq const& rhs) { - return add(new (m_region) lar_var_constraint(j, k, rhs)); + constraint_index add_var_constraint(lpvar j, lconstraint_kind k, mpq const& rhs) { + return add(new (m_region) lar_var_constraint(j, k, mk_dep(), rhs)); } constraint_index add_term_constraint(unsigned j, const lar_term* t, lconstraint_kind k, mpq const& rhs) { - return add(new (m_region) lar_term_constraint(j, t, k, rhs)); + auto* dep = mk_dep(); + return add(new (m_region) lar_term_constraint(j, t, k, dep, rhs)); } // future behavior uses activation bit. diff --git a/src/math/lp/lar_core_solver.h b/src/math/lp/lar_core_solver.h index 06ef4d50ba3..05f33b94b70 100644 --- a/src/math/lp/lar_core_solver.h +++ b/src/math/lp/lar_core_solver.h @@ -16,15 +16,15 @@ Copyright (c) 2017 Microsoft Corporation #include "math/lp/stacked_vector.h" #include "util/stacked_value.h" namespace lp { - class lar_core_solver { vector> m_infeasible_linear_combination; int m_infeasible_sum_sign; // todo: get rid of this field vector> m_right_sides_dummy; vector m_costs_dummy; - -public: stacked_value m_stacked_simplex_strategy; + +public: + stacked_vector m_column_types; // r - solver fields, for rational numbers vector> m_r_x; // the solution @@ -44,8 +44,6 @@ class lar_core_solver { const column_namer & column_names ); - lp_settings & settings() { return m_r_solver.m_settings;} - const lp_settings & settings() const { return m_r_solver.m_settings;} int get_infeasible_sum_sign() const { return m_infeasible_sum_sign; } @@ -58,8 +56,7 @@ class lar_core_solver { void fill_not_improvable_zero_sum_from_inf_row(); column_type get_column_type(unsigned j) { return m_column_types[j];} - - + void print_pivot_row(std::ostream & out, unsigned row_index) const { for (unsigned j : m_r_solver.m_pivot_row.m_index) { if (numeric_traits::is_pos(m_r_solver.m_pivot_row.m_data[j])) @@ -69,9 +66,9 @@ class lar_core_solver { out << " +" << m_r_solver.column_name(m_r_solver.m_basis[row_index]) << std::endl; - for (unsigned j : m_r_solver.m_pivot_row.m_index) { + for (unsigned j : m_r_solver.m_pivot_row.m_index) m_r_solver.print_column_bound_info(j, out); - } + m_r_solver.print_column_bound_info(m_r_solver.m_basis[row_index], out); } @@ -93,6 +90,8 @@ class lar_core_solver { void solve(); + void pivot(int entering, int leaving) { m_r_solver.pivot(entering, leaving); } + bool lower_bounds_are_set() const { return true; } const indexed_vector & get_pivot_row() const { @@ -109,10 +108,7 @@ class lar_core_solver { m_column_types.push(); // rational m_r_lower_bounds.push(); - m_r_upper_bounds.push(); - - - + m_r_upper_bounds.push(); } void pop(unsigned k) { @@ -126,11 +122,10 @@ class lar_core_solver { m_r_solver.m_d.resize(m_r_A.column_count()); m_stacked_simplex_strategy.pop(k); - settings().set_simplex_strategy(m_stacked_simplex_strategy); + m_r_solver.m_settings.simplex_strategy() = m_stacked_simplex_strategy; + m_infeasible_linear_combination.reset(); lp_assert(m_r_solver.basis_heading_is_correct()); } - - bool r_basis_is_OK() const { #ifdef Z3DEBUG @@ -214,9 +209,7 @@ class lar_core_solver { } bool column_is_fixed(unsigned j) const { - return m_column_types()[j] == column_type::fixed || - ( m_column_types()[j] == column_type::boxed && - m_r_solver.m_lower_bounds[j] == m_r_solver.m_upper_bounds[j]); + return m_column_types()[j] == column_type::fixed; } bool column_is_free(unsigned j) const { diff --git a/src/math/lp/lar_core_solver_def.h b/src/math/lp/lar_core_solver_def.h index 550b6fe3674..e0ffed16e6a 100644 --- a/src/math/lp/lar_core_solver_def.h +++ b/src/math/lp/lar_core_solver_def.h @@ -49,7 +49,7 @@ void lar_core_solver::fill_not_improvable_zero_sum_from_inf_row() { m_infeasible_sum_sign = m_r_solver.inf_sign_of_column(bj); m_infeasible_linear_combination.clear(); for (auto & rc : m_r_solver.m_A.m_rows[m_r_solver.m_inf_row_index_for_tableau]) - m_infeasible_linear_combination.push_back(std::make_pair(rc.coeff(), rc.var())); + m_infeasible_linear_combination.push_back({rc.coeff(), rc.var()}); } void lar_core_solver::fill_not_improvable_zero_sum() { @@ -85,22 +85,21 @@ unsigned lar_core_solver::get_number_of_non_ints() const { void lar_core_solver::solve() { TRACE("lar_solver", tout << m_r_solver.get_status() << "\n";); lp_assert(m_r_solver.non_basic_columns_are_set_correctly()); - lp_assert(m_r_solver.inf_set_is_correct()); - TRACE("find_feas_stats", tout << "infeasibles = " << m_r_solver.inf_set_size() << ", int_infs = " << get_number_of_non_ints() << std::endl;); + lp_assert(m_r_solver.inf_heap_is_correct()); + TRACE("find_feas_stats", tout << "infeasibles = " << m_r_solver.inf_heap_size() << ", int_infs = " << get_number_of_non_ints() << std::endl;); if (m_r_solver.current_x_is_feasible() && m_r_solver.m_look_for_feasible_solution_only) { - m_r_solver.set_status(lp_status::OPTIMAL); - TRACE("lar_solver", tout << m_r_solver.get_status() << "\n";); - return; + m_r_solver.set_status(lp_status::OPTIMAL); + TRACE("lar_solver", tout << m_r_solver.get_status() << "\n";); + return; } - ++settings().stats().m_need_to_solve_inf; + ++m_r_solver.m_settings.stats().m_need_to_solve_inf; lp_assert( r_basis_is_OK()); - - + if (m_r_solver.m_look_for_feasible_solution_only) //todo : should it be set? m_r_solver.find_feasible_solution(); - else { + else m_r_solver.solve(); - } + lp_assert(r_basis_is_OK()); switch (m_r_solver.get_status()) @@ -117,11 +116,9 @@ void lar_core_solver::solve() { } lp_assert(r_basis_is_OK()); lp_assert(m_r_solver.non_basic_columns_are_set_correctly()); - lp_assert(m_r_solver.inf_set_is_correct()); - - TRACE("lar_solver", tout << m_r_solver.get_status() << "\n";); -} - + lp_assert(m_r_solver.inf_heap_is_correct()); + TRACE("lar_solver", tout << m_r_solver.get_status() << "\n";); } +} // namespace lp diff --git a/src/math/lp/lar_solver.cpp b/src/math/lp/lar_solver.cpp index 0eb65e1973b..ef61c220953 100644 --- a/src/math/lp/lar_solver.cpp +++ b/src/math/lp/lar_solver.cpp @@ -2,7 +2,6 @@ Copyright (c) 2017 Microsoft Corporation Author: Nikolaj Bjorner, Lev Nachmanson */ - #include "math/lp/lar_solver.h" #include "smt/params/smt_params_helper.hpp" @@ -17,31 +16,30 @@ namespace lp { void lar_solver::updt_params(params_ref const& _p) { smt_params_helper p(_p); - set_track_pivoted_rows(p.arith_bprop_on_pivoted_rows()); + track_touched_rows(p.arith_bprop_on_pivoted_rows()); set_cut_strategy(p.arith_branch_cut_ratio()); m_settings.updt_params(_p); } lar_solver::lar_solver() : - m_crossed_bounds_column(-1), m_mpq_lar_core_solver(m_settings, *this), - m_var_register(false), - m_term_register(true), - m_constraints(*this) {} + m_var_register(), + m_constraints(m_dependencies, *this) {} - void lar_solver::set_track_pivoted_rows(bool v) { - m_mpq_lar_core_solver.m_r_solver.m_pivoted_rows = v ? (&m_rows_with_changed_bounds) : nullptr; + // start or ends tracking the rows that were changed by solve() + void lar_solver::track_touched_rows(bool v) { + m_mpq_lar_core_solver.m_r_solver.m_touched_rows = v ? (&m_touched_rows) : nullptr; } - - bool lar_solver::get_track_pivoted_rows() const { - return m_mpq_lar_core_solver.m_r_solver.m_pivoted_rows != nullptr; + + // returns true iff the solver tracks the rows that were changed by solve() + bool lar_solver::touched_rows_are_tracked() const { + return m_mpq_lar_core_solver.m_r_solver.m_touched_rows != nullptr; } lar_solver::~lar_solver() { for (auto t : m_terms) delete t; } - bool lar_solver::sizes_are_correct() const { lp_assert(A_r().column_count() == m_mpq_lar_core_solver.m_r_solver.m_column_types.size()); @@ -50,13 +48,12 @@ namespace lp { return true; } - std::ostream& lar_solver::print_implied_bound(const implied_bound& be, std::ostream& out) const { out << "implied bound\n"; unsigned v = be.m_j; - if (tv::is_term(v)) { - out << "it is a term number " << tv::unmask_term(be.m_j) << std::endl; - print_term(*m_terms[tv::unmask_term(v)], out); + if (column_has_term(v)) { + out << "term for column " << v << std::endl; + print_term(*m_columns[v].term(), out); } else { out << get_variable_name(v); @@ -97,7 +94,7 @@ namespace lp { if (strict) kind = static_cast((static_cast(kind) / 2)); - if (!tv::is_term(be.m_j)) { + if (!column_has_term(be.m_j)) { if (coeff_map.size() != 1) return false; auto it = coeff_map.find(be.m_j); @@ -111,13 +108,13 @@ namespace lp { else { lar_term const& t = get_term(be.m_j); auto first_coeff = t.begin(); - unsigned j = (*first_coeff).column(); + unsigned j = (*first_coeff).j(); auto it = coeff_map.find(j); if (it == coeff_map.end()) return false; mpq ratio = it->second / (*first_coeff).coeff(); for (auto p : t) { - it = coeff_map.find(p.column()); + it = coeff_map.find(p.j()); if (it == coeff_map.end()) return false; if (p.coeff() * ratio != it->second) @@ -141,47 +138,12 @@ namespace lp { return false; } - void lar_solver::substitute_basis_var_in_terms_for_row(unsigned i) { - // todo : create a map from term basic vars to the rows where they are used - unsigned basis_j = m_mpq_lar_core_solver.m_r_solver.m_basis[i]; - for (unsigned k = 0; k < m_terms.size(); k++) { - if (term_is_used_as_row(k)) - continue; - if (!m_terms[k]->contains(basis_j)) - continue; - m_terms[k]->subst_in_row(basis_j, m_mpq_lar_core_solver.m_r_solver.m_pivot_row); - } - } - - // Returns the column index without changes, - // but in the case the column was created as - // the slack variable to a term return the term index. - // It is the same index that was returned by add_var(), or - // by add_term() - unsigned lar_solver::column_to_reported_index(unsigned j) const { - if (tv::is_term(j)) - return j; - unsigned ext_var_or_term = m_var_register.local_to_external(j); - if (tv::is_term(ext_var_or_term)) { - j = ext_var_or_term; - } - return j; - } - - unsigned lar_solver::map_term_index_to_column_index(unsigned j) const { - SASSERT(tv::is_term(j)); - return m_var_register.external_to_local(j); - } - - // here i is just the term index - bool lar_solver::term_is_used_as_row(unsigned i) const { - SASSERT(i < m_terms.size()); - return m_var_register.external_is_used(tv::mask_term(i)); - } - lp_status lar_solver::get_status() const { return m_status; } - void lar_solver::set_status(lp_status s) { m_status = s; } + void lar_solver::set_status(lp_status s) { + TRACE("lar_solver", tout << "setting status to " << s << "\n";); + m_status = s; + } lp_status lar_solver::find_feasible_solution() { stats().m_make_feasible++; @@ -189,122 +151,115 @@ namespace lp { stats().m_max_cols = A_r().column_count(); if (A_r().row_count() > stats().m_max_rows) stats().m_max_rows = A_r().row_count(); - if (strategy_is_undecided()) - decide_on_strategy_and_adjust_initial_state(); - auto strategy_was = settings().simplex_strategy(); - settings().set_simplex_strategy(simplex_strategy_enum::tableau_rows); + flet f(settings().simplex_strategy(), simplex_strategy_enum::tableau_rows); m_mpq_lar_core_solver.m_r_solver.m_look_for_feasible_solution_only = true; auto ret = solve(); - settings().set_simplex_strategy(strategy_was); return ret; } lp_status lar_solver::solve() { - if (m_status == lp_status::INFEASIBLE) { + if (m_status == lp_status::INFEASIBLE || m_status == lp_status::CANCELLED) return m_status; - } + solve_with_core_solver(); - if (m_status != lp_status::INFEASIBLE) { - if (m_settings.bound_propagation()) - detect_rows_with_changed_bounds(); - } + if (m_status == lp_status::INFEASIBLE || m_status == lp_status::CANCELLED) + return m_status; + + if (m_settings.bound_propagation()) + detect_rows_with_changed_bounds(); clear_columns_with_changed_bounds(); return m_status; } void lar_solver::fill_explanation_from_crossed_bounds_column(explanation& evidence) const { - lp_assert(static_cast(get_column_type(m_crossed_bounds_column)) >= static_cast(column_type::boxed)); - lp_assert(!m_mpq_lar_core_solver.m_r_solver.column_is_feasible(m_crossed_bounds_column)); - // this is the case when the lower bound is in conflict with the upper one - const ul_pair& ul = m_columns_to_ul_pairs[m_crossed_bounds_column]; - evidence.add_pair(ul.upper_bound_witness(), numeric_traits::one()); - evidence.add_pair(ul.lower_bound_witness(), -numeric_traits::one()); + svector deps; + SASSERT(m_crossed_bounds_deps != nullptr); + m_dependencies.linearize(m_crossed_bounds_deps, deps); + for (auto d : deps) + evidence.add_pair(d, -numeric_traits::one()); } void lar_solver::push() { + m_trail.push_scope(); m_simplex_strategy = m_settings.simplex_strategy(); m_simplex_strategy.push(); - m_columns_to_ul_pairs.push(); - m_crossed_bounds_column.push(); + m_crossed_bounds_column = null_lpvar; + m_crossed_bounds_deps = nullptr; m_mpq_lar_core_solver.push(); - m_term_count = m_terms.size(); - m_term_count.push(); m_constraints.push(); m_usage_in_terms.push(); + m_dependencies.push_scope(); } - void lar_solver::clean_popped_elements(unsigned n, u_set& set) { + void lar_solver::clean_popped_elements(unsigned n, indexed_uint_set& set) { vector to_remove; for (unsigned j : set) if (j >= n) to_remove.push_back(j); for (unsigned j : to_remove) - set.erase(j); + set.remove(j); } + void lar_solver::clean_popped_elements_for_heap(unsigned n, lpvar_heap& heap) { + vector to_remove; + for (unsigned j : heap) + if (j >= n) + to_remove.push_back(j); + for (unsigned j : to_remove) + heap.erase(j); + } void lar_solver::pop(unsigned k) { TRACE("lar_solver", tout << "k = " << k << std::endl;); - m_crossed_bounds_column.pop(k); - unsigned n = m_columns_to_ul_pairs.peek_size(k); + m_crossed_bounds_column = null_lpvar; + m_crossed_bounds_deps = nullptr; + m_trail.pop_scope(k); + unsigned n = m_columns.size(); m_var_register.shrink(n); - pop_tableau(); + + lp_assert(m_mpq_lar_core_solver.m_r_solver.m_costs.size() == A_r().column_count()); + lp_assert(m_mpq_lar_core_solver.m_r_solver.m_basis.size() == A_r().row_count()); + lp_assert(m_mpq_lar_core_solver.m_r_solver.basis_heading_is_correct()); lp_assert(A_r().column_count() == n); - TRACE("lar_solver_details", - for (unsigned j = 0; j < n; j++) { - print_column_info(j, tout) << "\n"; - } - ); - m_columns_to_ul_pairs.pop(k); + TRACE("lar_solver_details", for (unsigned j = 0; j < n; j++) print_column_info(j, tout) << "\n";); m_mpq_lar_core_solver.pop(k); remove_non_fixed_from_fixed_var_table(); - clean_popped_elements(n, m_columns_with_changed_bounds); - clean_popped_elements(n, m_incorrect_columns); for (auto rid : m_row_bounds_to_replay) - insert_row_with_changed_bounds(rid); + add_touched_row(rid); m_row_bounds_to_replay.reset(); unsigned m = A_r().row_count(); - clean_popped_elements(m, m_rows_with_changed_bounds); - clean_inf_set_of_r_solver_after_pop(); - lp_assert( - m_settings.simplex_strategy() == simplex_strategy_enum::undecided || - m_mpq_lar_core_solver.m_r_solver.reduced_costs_are_correct_tableau()); + clean_popped_elements(m, m_touched_rows); + clean_inf_heap_of_r_solver_after_pop(); + SASSERT(m_mpq_lar_core_solver.m_r_solver.reduced_costs_are_correct_tableau()); m_constraints.pop(k); - m_term_count.pop(k); - for (unsigned i = m_term_count; i < m_terms.size(); i++) { - if (m_need_register_terms) - deregister_normalized_term(*m_terms[i]); - delete m_terms[i]; - } - m_term_register.shrink(m_term_count); - m_terms.resize(m_term_count); m_simplex_strategy.pop(k); - m_settings.set_simplex_strategy(m_simplex_strategy); + m_settings.simplex_strategy() = m_simplex_strategy; lp_assert(sizes_are_correct()); lp_assert(m_mpq_lar_core_solver.m_r_solver.reduced_costs_are_correct_tableau()); m_usage_in_terms.pop(k); + m_dependencies.pop_scope(k); + // init the nbasis sorting + require_nbasis_sort(); set_status(lp_status::UNKNOWN); } bool lar_solver::maximize_term_on_tableau(const lar_term& term, impq& term_max) { - if (settings().simplex_strategy() == simplex_strategy_enum::undecided) - decide_on_strategy_and_adjust_initial_state(); - + flet f(m_mpq_lar_core_solver.m_r_solver.m_look_for_feasible_solution_only, false); m_mpq_lar_core_solver.m_r_solver.set_status(lp_status::FEASIBLE); m_mpq_lar_core_solver.solve(); lp_status st = m_mpq_lar_core_solver.m_r_solver.get_status(); TRACE("lar_solver", tout << st << "\n";); SASSERT(m_mpq_lar_core_solver.m_r_solver.calc_current_x_is_feasible_include_non_basis()); - if (st == lp_status::UNBOUNDED) { + if (st == lp_status::UNBOUNDED || st == lp_status::CANCELLED) { return false; } else { @@ -313,6 +268,86 @@ namespace lp { } } + // get dependencies of the corresponding bounds from max_coeffs + u_dependency* lar_solver::get_dependencies_of_maximum(const vector>& max_coeffs) { + // The linear combinations of d_j*x[j] = the term that got maximized, where (d_j, j) is in max_coeffs + // Every j with positive coeff is at its upper bound, + // and every j with negative coeff is at its lower bound: so the sum cannot be increased. + // All variables j in the sum are non-basic. + u_dependency* dep = nullptr; + for (const auto & [d_j, j]: max_coeffs) { + SASSERT (!d_j.is_zero()); + + TRACE("lar_solver_improve_bounds", tout << "d[" << j << "] = " << d_j << "\n"; + this->m_mpq_lar_core_solver.m_r_solver.print_column_info(j, tout);); + const column& ul = m_columns[j]; + u_dependency * bound_dep; + if (d_j.is_pos()) + bound_dep = ul.upper_bound_witness(); + else + bound_dep = ul.lower_bound_witness(); + TRACE("lar_solver_improve_bounds", { + svector cs; + m_dependencies.linearize(bound_dep, cs); + for (auto c : cs) + m_constraints.display(tout, c) << "\n"; + }); + SASSERT(bound_dep != nullptr); + dep = m_dependencies.mk_join(dep, bound_dep); + } + return dep; + } + // returns nullptr if the bound is not improved, otherwise returns the witness of the bound + u_dependency* lar_solver::find_improved_bound(lpvar j, bool lower_bound, mpq& bound) { + + SASSERT(is_feasible()); + if (lower_bound && column_has_lower_bound(j) && get_column_value(j) == column_lower_bound(j)) + return nullptr; // cannot do better + if (!lower_bound && column_has_upper_bound(j) && get_column_value(j) == column_upper_bound(j)) + return nullptr; // cannot do better + + + lar_term term = get_term_to_maximize(j); + if (lower_bound) + term.negate(); + vector> max_coeffs; + TRACE("lar_solver_improve_bounds", tout << "j = " << j << ", "; print_term(term, tout << "term to maximize\n");); + impq term_max; + if (!maximize_term_on_feasible_r_solver(term, term_max, &max_coeffs)) + return nullptr; + // term_max is equal to the sum of m_d[j]*x[j] over all non basic j. + // For the sum to be at the maximum all non basic variables should be at their bounds: if (m_d[j] > 0) x[j] = u[j], otherwise x[j] = l[j]. At upper bounds we have u[j].y <= 0, and at lower bounds we have l[j].y >= 0, therefore for the sum term_max.y <= 0. + SASSERT(!term_max.y.is_pos()); + + // To keep it simpler we ignore possible improvements from non-strict to strict bounds. + bound = term_max.x; + if (lower_bound) { + bound.neg(); + if (column_is_int(j)) + bound = ceil(bound); + + if (column_has_lower_bound(j) && column_is_int(j) && bound <= column_lower_bound(j).x) + return nullptr; + + TRACE("lar_solver_improve_bounds", + tout << "setting lower bound for " << j << " to " << bound << "\n"; + if (column_has_lower_bound(j)) tout << "bound was = " << column_lower_bound(j) << "\n";); + } + else { + if (column_is_int(j)) + bound = floor(bound); + + if (column_has_upper_bound(j)) { + if (bound >= column_upper_bound(j).x) + return nullptr; + } + TRACE("lar_solver_improve_bounds", + tout << "setting upper bound for " << j << " to " << bound << "\n"; + if (column_has_upper_bound(j)) tout << "bound was = " << column_upper_bound(j) << "\n";;); + } + return get_dependencies_of_maximum(max_coeffs); + } + bool lar_solver::costs_are_zeros_for_r_solver() const { for (unsigned j = 0; j < m_mpq_lar_core_solver.m_r_solver.m_costs.size(); j++) { lp_assert(is_zero(m_mpq_lar_core_solver.m_r_solver.m_costs[j])); @@ -328,55 +363,48 @@ namespace lp { void lar_solver::set_costs_to_zero(const lar_term& term) { auto& rslv = m_mpq_lar_core_solver.m_r_solver; - auto& jset = m_mpq_lar_core_solver.m_r_solver.inf_set(); // hijack this set that should be empty right now - lp_assert(jset.empty()); - + auto& d = rslv.m_d; + auto& costs = rslv.m_costs; for (lar_term::ival p : term) { - unsigned j = p.column(); - rslv.m_costs[j] = zero_of_type(); + unsigned j = p.j(); + costs[j] = zero_of_type(); int i = rslv.m_basis_heading[j]; - if (i < 0) - jset.insert(j); - else { + if (i < 0) + d[j] = zero_of_type(); + else for (const auto& rc : A_r().m_rows[i]) - jset.insert(rc.var()); - } + d[rc.var()] = zero_of_type(); } - - for (unsigned j : jset) - rslv.m_d[j] = zero_of_type(); - - jset.clear(); - + lp_assert(reduced_costs_are_zeroes_for_r_solver()); lp_assert(costs_are_zeros_for_r_solver()); } void lar_solver::prepare_costs_for_r_solver(const lar_term& term) { TRACE("lar_solver", print_term(term, tout << "prepare: ") << "\n";); - m_basic_columns_with_changed_cost.resize(m_mpq_lar_core_solver.m_r_x.size()); - move_non_basic_columns_to_bounds(false); auto& rslv = m_mpq_lar_core_solver.m_r_solver; lp_assert(costs_are_zeros_for_r_solver()); lp_assert(reduced_costs_are_zeroes_for_r_solver()); + move_non_basic_columns_to_bounds(); rslv.m_costs.resize(A_r().column_count(), zero_of_type()); for (lar_term::ival p : term) { - unsigned j = p.column(); + unsigned j = p.j(); rslv.m_costs[j] = p.coeff(); if (rslv.m_basis_heading[j] < 0) rslv.m_d[j] += p.coeff(); else rslv.update_reduced_cost_for_basic_column_cost_change(-p.coeff(), j); } - rslv.m_costs_backup = rslv.m_costs; + if (settings().backup_costs) + rslv.m_costs_backup = rslv.m_costs; lp_assert(rslv.reduced_costs_are_correct_tableau()); } - void lar_solver::move_non_basic_columns_to_bounds(bool shift_randomly) { + void lar_solver::move_non_basic_columns_to_bounds() { auto& lcs = m_mpq_lar_core_solver; bool change = false; for (unsigned j : lcs.m_r_nbasis) { - if (move_non_basic_column_to_bounds(j, shift_randomly)) + if (move_non_basic_column_to_bounds(j)) change = true; } if (!change) @@ -387,46 +415,40 @@ namespace lp { find_feasible_solution(); } - bool lar_solver::move_non_basic_column_to_bounds(unsigned j, bool force_change) { + bool lar_solver::move_non_basic_column_to_bounds(unsigned j) { auto& lcs = m_mpq_lar_core_solver; auto& val = lcs.m_r_x[j]; switch (lcs.m_column_types()[j]) { case column_type::boxed: { - bool at_l = val == lcs.m_r_lower_bounds()[j]; - bool at_u = !at_l && (val == lcs.m_r_upper_bounds()[j]); - if (!at_l && !at_u) { - if (m_settings.random_next() % 2) - set_value_for_nbasic_column(j, lcs.m_r_lower_bounds()[j]); - else - set_value_for_nbasic_column(j, lcs.m_r_upper_bounds()[j]); - return true; - } - else if (force_change && m_settings.random_next() % 3 == 0) { - set_value_for_nbasic_column(j, - at_l ? lcs.m_r_upper_bounds()[j] : lcs.m_r_lower_bounds()[j]); - return true; - } - break; - } - case column_type::lower_bound: - if (val != lcs.m_r_lower_bounds()[j]) { - set_value_for_nbasic_column(j, lcs.m_r_lower_bounds()[j]); + const auto& l = lcs.m_r_lower_bounds()[j]; + if (val == l || val == lcs.m_r_upper_bounds()[j]) return false; + set_value_for_nbasic_column(j, l); + return true; + } + + case column_type::lower_bound: { + const auto& l = lcs.m_r_lower_bounds()[j]; + if (val != l) { + set_value_for_nbasic_column(j, l); return true; } - break; + return false; + } case column_type::fixed: - case column_type::upper_bound: - if (val != lcs.m_r_upper_bounds()[j]) { - set_value_for_nbasic_column(j, lcs.m_r_upper_bounds()[j]); + case column_type::upper_bound: { + const auto & u = lcs.m_r_upper_bounds()[j]; + if (val != u) { + set_value_for_nbasic_column(j, u); return true; } - break; + return false; + } case column_type::free_column: if (column_is_int(j) && !val.is_int()) { set_value_for_nbasic_column(j, impq(floor(val))); return true; } - break; + return false; default: SASSERT(false); } @@ -438,77 +460,67 @@ namespace lp { auto& x = m_mpq_lar_core_solver.m_r_x[j]; auto delta = new_val - x; x = new_val; + TRACE("lar_solver_feas", tout << "setting " << j << " to " + << new_val << (column_is_feasible(j)?"feas":"non-feas") << "\n";); change_basic_columns_dependend_on_a_given_nb_column(j, delta); } - - bool lar_solver::maximize_term_on_corrected_r_solver(lar_term& term, - impq& term_max) { + bool lar_solver::maximize_term_on_feasible_r_solver(lar_term& term, + impq& term_max, vector>* max_coeffs = nullptr) { settings().backup_costs = false; bool ret = false; - TRACE("lar_solver", print_term(term, tout << "maximize: ") << "\n" << constraints() << ", strategy = " << (int)settings().simplex_strategy() << "\n";); - switch (settings().simplex_strategy()) { - - case simplex_strategy_enum::tableau_rows: - settings().set_simplex_strategy(simplex_strategy_enum::tableau_costs); - prepare_costs_for_r_solver(term); - ret = maximize_term_on_tableau(term, term_max); - settings().set_simplex_strategy(simplex_strategy_enum::tableau_rows); - set_costs_to_zero(term); - m_mpq_lar_core_solver.m_r_solver.set_status(lp_status::OPTIMAL); - return ret; - - case simplex_strategy_enum::tableau_costs: - prepare_costs_for_r_solver(term); - ret = maximize_term_on_tableau(term, term_max); - set_costs_to_zero(term); - m_mpq_lar_core_solver.m_r_solver.set_status(lp_status::OPTIMAL); - return ret; - - - default: - UNREACHABLE(); // wrong mode + TRACE("lar_solver", print_term(term, tout << "maximize: ") << "\n" + << constraints() << ", strategy = " << (int)settings().simplex_strategy() << "\n";); + if (settings().simplex_strategy() != simplex_strategy_enum::tableau_costs) + require_nbasis_sort(); + flet f(settings().simplex_strategy(), simplex_strategy_enum::tableau_costs); + prepare_costs_for_r_solver(term); + ret = maximize_term_on_tableau(term, term_max); + if (ret && max_coeffs != nullptr) { + for (unsigned j = 0; j < column_count(); j++) { + const mpq& d_j = m_mpq_lar_core_solver.m_r_solver.m_d[j]; + if (d_j.is_zero()) + continue; + max_coeffs->push_back(std::make_pair(d_j, j)); + TRACE("lar_solver", tout<<"m_d["<(), j_or_term); + r.add_monomial(one_of_type(), j); return r; } return lar_term(); // return an empty term } - lp_status lar_solver::maximize_term(unsigned j_or_term, + lp_status lar_solver::maximize_term(unsigned j, impq& term_max) { TRACE("lar_solver", print_values(tout);); - - lar_term term = get_term_to_maximize(j_or_term); - if (term.is_empty()) { - return lp_status::UNBOUNDED; - } - impq prev_value; + SASSERT(m_mpq_lar_core_solver.m_r_solver.calc_current_x_is_feasible_include_non_basis()); + lar_term term = get_term_to_maximize(j); + if (term.is_empty()) return lp_status::UNBOUNDED; + impq prev_value = term.apply(m_mpq_lar_core_solver.m_r_x); auto backup = m_mpq_lar_core_solver.m_r_x; - if (m_mpq_lar_core_solver.m_r_solver.calc_current_x_is_feasible_include_non_basis()) { - prev_value = term.apply(m_mpq_lar_core_solver.m_r_x); - } - else { - m_mpq_lar_core_solver.m_r_solver.m_look_for_feasible_solution_only = false; - if (solve() != lp_status::OPTIMAL) - return lp_status::UNBOUNDED; - } - - m_mpq_lar_core_solver.m_r_solver.m_look_for_feasible_solution_only = false; - if (!maximize_term_on_corrected_r_solver(term, term_max)) { + if (!maximize_term_on_feasible_r_solver(term, term_max, nullptr)) { m_mpq_lar_core_solver.m_r_x = backup; return lp_status::UNBOUNDED; } @@ -528,7 +540,6 @@ namespace lp { return lp_status::FEASIBLE; // it should not happen } } - m_int_solver->patch_nbasic_column(j); if (!column_value_is_integer(j)) { term_max = prev_value; m_mpq_lar_core_solver.m_r_x = backup; @@ -551,13 +562,6 @@ namespace lp { return lp_status::FEASIBLE; } - - - const lar_term& lar_solver::get_term(unsigned j) const { - lp_assert(tv::is_term(j)); - return *m_terms[tv::unmask_term(j)]; - } - void lar_solver::pop_core_solver_params() { pop_core_solver_params(1); } @@ -567,42 +571,39 @@ namespace lp { } - void lar_solver::set_upper_bound_witness(var_index j, constraint_index ci) { - ul_pair ul = m_columns_to_ul_pairs[j]; - ul.upper_bound_witness() = ci; - m_columns_to_ul_pairs[j] = ul; + + void lar_solver::set_upper_bound_witness(lpvar j, u_dependency* dep) { + m_trail.push(vector_value_trail(m_columns, j)); + m_columns[j].upper_bound_witness() = dep; } - void lar_solver::set_lower_bound_witness(var_index j, constraint_index ci) { - ul_pair ul = m_columns_to_ul_pairs[j]; - ul.lower_bound_witness() = ci; - m_columns_to_ul_pairs[j] = ul; + void lar_solver::set_lower_bound_witness(lpvar j, u_dependency* dep) { + m_trail.push(vector_value_trail(m_columns, j)); + m_columns[j].lower_bound_witness() = dep; } - void lar_solver::register_monoid_in_map(std::unordered_map& coeffs, const mpq& a, unsigned j) { + void lar_solver::register_monoid_in_map(std::unordered_map& coeffs, const mpq& a, unsigned j) { auto it = coeffs.find(j); - if (it == coeffs.end()) { + if (it == coeffs.end()) coeffs[j] = a; - } - else { + else it->second += a; - } } - void lar_solver::substitute_terms_in_linear_expression(const vector>& left_side_with_terms, - vector>& left_side) const { - std::unordered_map coeffs; + void lar_solver::substitute_terms_in_linear_expression(const vector>& left_side_with_terms, + vector>& left_side) const { + std::unordered_map coeffs; for (auto& t : left_side_with_terms) { unsigned j = t.second; - if (!tv::is_term(j)) { + if (!column_has_term(j)) { register_monoid_in_map(coeffs, t.first, j); } else { - const lar_term& term = *m_terms[tv::unmask_term(t.second)]; + const lar_term& term = *m_columns[t.second].term(); for (auto p : term) - register_monoid_in_map(coeffs, t.first * p.coeff(), p.column()); + register_monoid_in_map(coeffs, t.first * p.coeff(), p.j()); } } @@ -611,18 +612,59 @@ namespace lp { left_side.push_back(std::make_pair(c, v)); } - void lar_solver::insert_row_with_changed_bounds(unsigned rid) { - m_rows_with_changed_bounds.insert(rid); + void lar_solver::add_touched_row(unsigned rid) { + if (m_settings.bound_propagation()) + m_touched_rows.insert(rid); } + void lar_solver::remove_fixed_vars_from_base() { + // this will allow to disable and restore the tracking of the touched rows + flet f(m_mpq_lar_core_solver.m_r_solver.m_touched_rows, nullptr); + unsigned num = A_r().column_count(); + unsigned_vector to_remove; + for (unsigned j : m_fixed_base_var_set) { + if (j >= num || !is_base(j) || !column_is_fixed(j)) { + to_remove.push_back(j); + continue; + } - - void lar_solver::detect_rows_of_bound_change_column_for_nbasic_column_tableau(unsigned j) { - for (auto& rc : m_mpq_lar_core_solver.m_r_A.m_columns[j]) - insert_row_with_changed_bounds(rc.var()); + lp_assert(is_base(j) && column_is_fixed(j)); + auto const& r = basic2row(j); + for (auto const& c : r) { + unsigned j_entering = c.var(); + if (!column_is_fixed(j_entering)) { + pivot(j_entering, j); + to_remove.push_back(j); + lp_assert(is_base(j_entering)); + break; + } + } + } + for (unsigned j : to_remove) { + m_fixed_base_var_set.remove(j); + } + lp_assert(fixed_base_removed_correctly()); } +#ifdef Z3DEBUG + bool lar_solver::fixed_base_removed_correctly() const { + for (unsigned i = 0; i < A_r().row_count(); i++) { + unsigned j = get_base_column_in_row(i); + if (column_is_fixed(j)) { + for (const auto & c : A_r().m_rows[i] ) { + if (!column_is_fixed(c.var())) { + TRACE("lar_solver", print_row(A_r().m_rows[i], tout) << "\n"; + for(const auto & c : A_r().m_rows[i]) { + print_column_info(c.var(), tout) << "\n"; + }); + return false; + } + } + } + } + return true; + } +#endif - bool lar_solver::use_tableau_costs() const { return m_settings.simplex_strategy() == simplex_strategy_enum::tableau_costs; } @@ -667,16 +709,16 @@ namespace lp { m_mpq_lar_core_solver.m_r_solver.add_delta_to_x_and_track_feasibility(bj, -A_r().get_val(c) * delta); TRACE("change_x_del", tout << "changed basis column " << bj << ", it is " << - (m_mpq_lar_core_solver.m_r_solver.column_is_feasible(bj) ? "feas" : "inf") << std::endl;); + (column_is_feasible(bj) ? "feas" : "inf") << std::endl;); } } void lar_solver::update_x_and_inf_costs_for_column_with_changed_bounds(unsigned j) { if (m_mpq_lar_core_solver.m_r_heading[j] >= 0) { if (costs_are_used()) { - bool was_infeas = m_mpq_lar_core_solver.m_r_solver.inf_set_contains(j); + bool was_infeas = m_mpq_lar_core_solver.m_r_solver.inf_heap_contains(j); m_mpq_lar_core_solver.m_r_solver.track_column_feasibility(j); - if (was_infeas != m_mpq_lar_core_solver.m_r_solver.inf_set_contains(j)) + if (was_infeas != m_mpq_lar_core_solver.m_r_solver.inf_heap_contains(j)) m_basic_columns_with_changed_cost.insert(j); } else { @@ -692,14 +734,16 @@ namespace lp { void lar_solver::detect_rows_with_changed_bounds_for_column(unsigned j) { if (m_mpq_lar_core_solver.m_r_heading[j] >= 0) - insert_row_with_changed_bounds(m_mpq_lar_core_solver.m_r_heading[j]); + add_touched_row(m_mpq_lar_core_solver.m_r_heading[j]); else - detect_rows_of_bound_change_column_for_nbasic_column_tableau(j); + add_column_rows_to_touched_rows(j); } void lar_solver::detect_rows_with_changed_bounds() { for (auto j : m_columns_with_changed_bounds) detect_rows_with_changed_bounds_for_column(j); + if (m_find_monics_with_changed_bounds_func) + m_find_monics_with_changed_bounds_func(m_columns_with_changed_bounds); } void lar_solver::update_x_and_inf_costs_for_columns_with_changed_bounds_tableau() { @@ -710,9 +754,6 @@ namespace lp { void lar_solver::solve_with_core_solver() { m_mpq_lar_core_solver.prefix_r(); - if (costs_are_used()) { - m_basic_columns_with_changed_cost.resize(m_mpq_lar_core_solver.m_r_x.size()); - } update_x_and_inf_costs_for_columns_with_changed_bounds_tableau(); m_mpq_lar_core_solver.solve(); set_status(m_mpq_lar_core_solver.m_r_solver.get_status()); @@ -725,7 +766,8 @@ namespace lp { switch (this->get_status()) { case lp_status::OPTIMAL: case lp_status::FEASIBLE: - case lp_status::UNBOUNDED: + case lp_status::UNBOUNDED: + SASSERT(m_mpq_lar_core_solver.m_r_solver.inf_heap().size() == 0); return true; default: return false; @@ -745,15 +787,12 @@ namespace lp { return r; } - bool lar_solver::var_is_registered(var_index vj) const { - if (tv::is_term(vj)) { - return tv::unmask_term(vj) < m_terms.size(); - } - return vj < A_r().column_count(); + bool lar_solver::var_is_registered(lpvar vj) const { + return vj < A_r().column_count(); } - bool lar_solver::all_constrained_variables_are_registered(const vector>& left_side) { + bool lar_solver::all_constrained_variables_are_registered(const vector>& left_side) { for (auto it : left_side) { if (!var_is_registered(it.second)) return false; @@ -764,7 +803,7 @@ namespace lp { bool lar_solver::all_constraints_hold() const { if (m_settings.get_cancel_flag()) return true; - std::unordered_map var_map; + std::unordered_map var_map; get_model_do_not_care_about_diff_vars(var_map); for (auto const& c : m_constraints.active()) { @@ -780,7 +819,7 @@ namespace lp { return true; } - bool lar_solver::constraint_holds(const lar_base_constraint& constr, std::unordered_map& var_map) const { + bool lar_solver::constraint_holds(const lar_base_constraint& constr, std::unordered_map& var_map) const { mpq left_side_val = get_left_side_val(constr, var_map); switch (constr.kind()) { case LE: return left_side_val <= constr.rhs(); @@ -795,7 +834,7 @@ namespace lp { } - void lar_solver::register_in_map(std::unordered_map& coeffs, const lar_base_constraint& cn, const mpq& a) { + void lar_solver::register_in_map(std::unordered_map& coeffs, const lar_base_constraint& cn, const mpq& a) { for (auto& it : cn.coeffs()) { unsigned j = it.second; auto p = coeffs.find(j); @@ -810,7 +849,7 @@ namespace lp { } bool lar_solver::the_left_sides_sum_to_zero(const vector>& evidence) const { - std::unordered_map coeff_map; + std::unordered_map coeff_map; for (auto const & [coeff, con_ind] : evidence) { lp_assert(m_constraints.valid_index(con_ind)); register_in_map(coeff_map, m_constraints[con_ind], coeff); @@ -872,15 +911,15 @@ namespace lp { return ret; } - bool lar_solver::has_lower_bound(var_index var, constraint_index& ci, mpq& value, bool& is_strict) const { + bool lar_solver::has_lower_bound(lpvar var, u_dependency*& ci, mpq& value, bool& is_strict) const { - if (var >= m_columns_to_ul_pairs.size()) { + if (var >= m_columns.size()) { // TBD: bounds on terms could also be used, caller may have to track these. return false; } - const ul_pair& ul = m_columns_to_ul_pairs[var]; + const column& ul = m_columns[var]; ci = ul.lower_bound_witness(); - if (ci != null_ci) { + if (ci != nullptr) { auto& p = m_mpq_lar_core_solver.m_r_lower_bounds()[var]; value = p.x; is_strict = p.y.is_pos(); @@ -891,15 +930,15 @@ namespace lp { } } - bool lar_solver::has_upper_bound(var_index var, constraint_index& ci, mpq& value, bool& is_strict) const { + bool lar_solver::has_upper_bound(lpvar var, u_dependency*& ci, mpq& value, bool& is_strict) const { - if (var >= m_columns_to_ul_pairs.size()) { + if (var >= m_columns.size()) { // TBD: bounds on terms could also be used, caller may have to track these. return false; } - const ul_pair& ul = m_columns_to_ul_pairs[var]; + const column& ul = m_columns[var]; ci = ul.upper_bound_witness(); - if (ci != null_ci) { + if (ci != nullptr) { auto& p = m_mpq_lar_core_solver.m_r_upper_bounds()[var]; value = p.x; is_strict = p.y.is_neg(); @@ -910,12 +949,12 @@ namespace lp { } } - bool lar_solver::has_value(var_index var, mpq& value) const { - if (tv::is_term(var)) { + bool lar_solver::has_value(lpvar var, mpq& value) const { + if (column_has_term(var)) { lar_term const& t = get_term(var); value = 0; for (lar_term::ival cv : t) { - impq const& r = get_column_value(cv.column()); + impq const& r = get_column_value(cv.j()); if (!numeric_traits::is_zero(r.y)) return false; value += r.x * cv.coeff(); } @@ -931,7 +970,7 @@ namespace lp { void lar_solver::get_infeasibility_explanation(explanation& exp) const { exp.clear(); - if (m_crossed_bounds_column != -1) { + if (m_crossed_bounds_column != null_lpvar) { fill_explanation_from_crossed_bounds_column(exp); return; } @@ -955,16 +994,20 @@ namespace lp { unsigned j = it.second; int adj_sign = coeff.is_pos() ? inf_sign : -inf_sign; - const ul_pair& ul = m_columns_to_ul_pairs[j]; - - constraint_index bound_constr_i = adj_sign < 0 ? ul.upper_bound_witness() : ul.lower_bound_witness(); - lp_assert(m_constraints.valid_index(bound_constr_i)); - exp.add_pair(bound_constr_i, coeff); + const column& ul = m_columns[j]; + + u_dependency* bound_constr_i = adj_sign < 0 ? ul.upper_bound_witness() : ul.lower_bound_witness(); + svector deps; + m_dependencies.linearize(bound_constr_i, deps); + for (auto d : deps) { + lp_assert(m_constraints.valid_index(d)); + exp.add_pair(d, coeff); + } } } // (x, y) != (x', y') => (x + delta*y) != (x' + delta*y') - void lar_solver::get_model(std::unordered_map& variable_values) const { + void lar_solver::get_model(std::unordered_map& variable_values) const { variable_values.clear(); if (!init_model()) return; @@ -972,19 +1015,23 @@ namespace lp { unsigned n = m_mpq_lar_core_solver.m_r_x.size(); for (unsigned j = 0; j < n; j++) - variable_values[j] = get_value(column_index(j)); + variable_values[j] = get_value(j); TRACE("lar_solver_model", tout << "delta = " << m_delta << "\nmodel:\n"; for (auto p : variable_values) tout << this->get_variable_name(p.first) << " = " << p.second << "\n";); } bool lar_solver::init_model() const { - if (get_status() != lp_status::OPTIMAL && get_status() != lp_status::FEASIBLE) + CTRACE("lar_solver_model",!m_columns_with_changed_bounds.empty(), tout << "non-empty changed bounds\n"); + TRACE("lar_solver_model", tout << get_status() << "\n"); + auto status = get_status(); + SASSERT((status != lp_status::OPTIMAL && status != lp_status::FEASIBLE) + || m_mpq_lar_core_solver.m_r_solver.calc_current_x_is_feasible_include_non_basis()); + if (status != lp_status::OPTIMAL && status != lp_status::FEASIBLE) return false; if (!m_columns_with_changed_bounds.empty()) return false; - lp_assert(m_mpq_lar_core_solver.m_r_solver.calc_current_x_is_feasible_include_non_basis()); m_delta = m_mpq_lar_core_solver.find_delta_for_strict_bounds(mpq(1)); unsigned j; unsigned n = m_mpq_lar_core_solver.m_r_x.size(); @@ -1006,7 +1053,7 @@ namespace lp { return true; } - void lar_solver::get_model_do_not_care_about_diff_vars(std::unordered_map& variable_values) const { + void lar_solver::get_model_do_not_care_about_diff_vars(std::unordered_map& variable_values) const { mpq delta = m_mpq_lar_core_solver.find_delta_for_strict_bounds(mpq(1)); for (unsigned i = 0; i < m_mpq_lar_core_solver.m_r_x.size(); i++) { const impq& rp = m_mpq_lar_core_solver.m_r_x[i]; @@ -1014,34 +1061,11 @@ namespace lp { } } - mpq lar_solver::get_value(column_index const& j) const { + mpq lar_solver::get_value(lpvar j) const { SASSERT(get_status() == lp_status::OPTIMAL || get_status() == lp_status::FEASIBLE); - SASSERT(m_columns_with_changed_bounds.empty()); + VERIFY(m_columns_with_changed_bounds.empty()); numeric_pair const& rp = get_column_value(j); - return rp.x + m_delta * rp.y; - } - - mpq lar_solver::get_tv_value(tv const& t) const { - if (t.is_var()) - return get_value(t.column()); -#if 0 - unsigned term_j = 0; - if (m_var_register.term_is_used(t.id(), term_j)) - return get_value(column_index(term_j)); -#endif - mpq r(0); - for (lar_term::ival p : get_term(t)) - r += p.coeff() * get_value(p.column()); - return r; - } - - impq lar_solver::get_tv_ivalue(tv const& t) const { - if (t.is_var()) - return get_column_value(t.column()); - impq r; - for (lar_term::ival p : get_term(t)) - r += p.coeff() * get_column_value(p.column()); - return r; + return from_model_in_impq_to_mpq(rp); } void lar_solver::get_rid_of_inf_eps() { @@ -1056,19 +1080,21 @@ namespace lp { return; mpq delta = m_mpq_lar_core_solver.find_delta_for_strict_bounds(mpq(1)); for (unsigned j = 0; j < number_of_vars(); j++) { - auto& r = m_mpq_lar_core_solver.m_r_x[j]; - if (!r.y.is_zero()) - r = impq(r.x + delta * r.y); + auto& v = m_mpq_lar_core_solver.m_r_x[j]; + if (!v.y.is_zero()) { + v = impq(v.x + delta * v.y); + TRACE("lar_solver_feas", tout << "x[" << j << "] = " << v << "\n";); + } } } - void lar_solver::set_variable_name(var_index vi, std::string name) { + void lar_solver::set_variable_name(lpvar vi, std::string name) { m_var_register.set_name(vi, name); } - std::string lar_solver::get_variable_name(var_index j) const { - if (tv::is_term(j)) - return std::string("_t") + T_to_string(tv::unmask_term(j)); + std::string lar_solver::get_variable_name(lpvar j) const { + if (column_has_term(j)) + return std::string("_t") + T_to_string(j); if (j >= m_var_register.size()) return std::string("_s") + T_to_string(j); @@ -1080,7 +1106,7 @@ namespace lp { return std::string("j") + T_to_string(m_var_register.local_to_external(j)); } else { - std::string s = column_corresponds_to_term(j) ? "t" : "j"; + std::string s = column_has_term(j) ? "t" : "j"; return s + T_to_string(j); } } @@ -1125,7 +1151,7 @@ namespace lp { out << " - "; else if (val != numeric_traits::one()) out << T_to_string(val); - out << this->get_variable_name(p.column()); + out << this->get_variable_name(p.j()); } return out; } @@ -1135,10 +1161,10 @@ namespace lp { return out; } - mpq lar_solver::get_left_side_val(const lar_base_constraint& cns, const std::unordered_map& var_map) const { + mpq lar_solver::get_left_side_val(const lar_base_constraint& cns, const std::unordered_map& var_map) const { mpq ret = cns.get_free_coeff_of_left_side(); for (auto& it : cns.coeffs()) { - var_index j = it.second; + lpvar j = it.second; auto vi = var_map.find(j); lp_assert(vi != var_map.end()); ret += it.first * vi->second; @@ -1147,13 +1173,13 @@ namespace lp { } - void lar_solver::fill_var_set_for_random_update(unsigned sz, var_index const* vars, vector& column_list) { + void lar_solver::fill_var_set_for_random_update(unsigned sz, lpvar const* vars, vector& column_list) { TRACE("lar_solver_rand", tout << "sz = " << sz << "\n";); for (unsigned i = 0; i < sz; i++) { - var_index var = vars[i]; - if (tv::is_term(var)) { - if (term_is_used_as_row(tv::unmask_term(var))) { - column_list.push_back(map_term_index_to_column_index(var)); + lpvar var = vars[i]; + if (column_has_term(var)) { + if (m_columns[var].associated_with_row()) { + column_list.push_back(var); } } else { @@ -1162,17 +1188,17 @@ namespace lp { } } - void lar_solver::random_update(unsigned sz, var_index const* vars) { + void lar_solver::random_update(unsigned sz, lpvar const* vars) { vector column_list; fill_var_set_for_random_update(sz, vars, column_list); random_updater ru(*this, column_list); ru.update(); } - void lar_solver::mark_rows_for_bound_prop(lpvar j) { - auto& column = A_r().m_columns[j]; + void lar_solver::add_column_rows_to_touched_rows(lpvar j) { + const auto& column = A_r().m_columns[j]; for (auto const& r : column) - insert_row_with_changed_bounds(r.var()); + add_touched_row(r.var()); } void lar_solver::pop() { @@ -1180,7 +1206,7 @@ namespace lp { } bool lar_solver::column_represents_row_in_tableau(unsigned j) { - return m_columns_to_ul_pairs()[j].associated_with_row(); + return m_columns[j].associated_with_row(); } void lar_solver::make_sure_that_the_bottom_right_elem_not_zero_in_tableau(unsigned i, unsigned j) { @@ -1282,27 +1308,13 @@ namespace lp { lp_assert(A_r().column_count() == m_mpq_lar_core_solver.m_r_solver.m_costs.size()); } - void lar_solver::pop_tableau() { - lp_assert(m_mpq_lar_core_solver.m_r_solver.m_costs.size() == A_r().column_count()); - lp_assert(m_mpq_lar_core_solver.m_r_solver.m_basis.size() == A_r().row_count()); - lp_assert(m_mpq_lar_core_solver.m_r_solver.basis_heading_is_correct()); - // We remove last variables starting from m_column_names.size() to m_vec_of_canonic_left_sides.size(). - // At this moment m_column_names is already popped - unsigned size = m_var_register.size(); - while (A_r().column_count() > size) - remove_last_column_from_tableau(); - lp_assert(m_mpq_lar_core_solver.m_r_solver.m_costs.size() == A_r().column_count()); - lp_assert(m_mpq_lar_core_solver.m_r_solver.m_basis.size() == A_r().row_count()); - lp_assert(m_mpq_lar_core_solver.m_r_solver.basis_heading_is_correct()); - } - - void lar_solver::clean_inf_set_of_r_solver_after_pop() { + void lar_solver::clean_inf_heap_of_r_solver_after_pop() { vector became_feas; - clean_popped_elements(A_r().column_count(), m_mpq_lar_core_solver.m_r_solver.inf_set()); + clean_popped_elements_for_heap(A_r().column_count(), m_mpq_lar_core_solver.m_r_solver.inf_heap()); std::unordered_set basic_columns_with_changed_cost; m_inf_index_copy.reset(); - for (auto j : m_mpq_lar_core_solver.m_r_solver.inf_set()) + for (auto j : m_mpq_lar_core_solver.m_r_solver.inf_heap()) m_inf_index_copy.push_back(j); for (auto j : m_inf_index_copy) { if (m_mpq_lar_core_solver.m_r_heading[j] >= 0) { @@ -1320,16 +1332,16 @@ namespace lp { lp_assert(m_mpq_lar_core_solver.m_r_solver.m_basis_heading[j] < 0); m_mpq_lar_core_solver.m_r_solver.m_d[j] -= m_mpq_lar_core_solver.m_r_solver.m_costs[j]; m_mpq_lar_core_solver.m_r_solver.m_costs[j] = zero_of_type(); - m_mpq_lar_core_solver.m_r_solver.remove_column_from_inf_set(j); + m_mpq_lar_core_solver.m_r_solver.remove_column_from_inf_heap(j); } became_feas.clear(); - for (unsigned j : m_mpq_lar_core_solver.m_r_solver.inf_set()) { + for (unsigned j : m_mpq_lar_core_solver.m_r_solver.inf_heap()) { lp_assert(m_mpq_lar_core_solver.m_r_heading[j] >= 0); - if (m_mpq_lar_core_solver.m_r_solver.column_is_feasible(j)) + if (column_is_feasible(j)) became_feas.push_back(j); } for (unsigned j : became_feas) - m_mpq_lar_core_solver.m_r_solver.remove_column_from_inf_set(j); + m_mpq_lar_core_solver.m_r_solver.remove_column_from_inf_heap(j); } @@ -1344,26 +1356,21 @@ namespace lp { bool lar_solver::term_is_int(const lar_term* t) const { for (auto const p : *t) - if (!(column_is_int(p.column()) && p.coeff().is_int())) + if (!(column_is_int(p.j()) && p.coeff().is_int())) return false; return true; } bool lar_solver::term_is_int(const vector>& coeffs) const { - for (auto const& p : coeffs) - if (!(column_is_int(p.second) && p.first.is_int())) + for (auto const& [coeff, v] : coeffs) + if (!(column_is_int(v) && coeff.is_int())) return false; return true; } - bool lar_solver::var_is_int(var_index v) const { - if (tv::is_term(v)) { - lar_term const& t = get_term(v); - return term_is_int(&t); - } - else { - return column_is_int(v); - } + bool lar_solver::var_is_int(lpvar v) const { + SASSERT(!column_has_term(v) || term_is_int(&get_term(v)) == column_is_int(v)); + return column_is_int(v); } bool lar_solver::column_is_int(unsigned j) const { @@ -1378,102 +1385,46 @@ namespace lp { return m_mpq_lar_core_solver.column_is_free(j); } - // column is at lower or upper bound, lower and upper bound are different. - // the lower/upper bound is not strict. - // the LP obtained by making the bound strict is infeasible - // -> the column has to be fixed - bool lar_solver::is_fixed_at_bound(column_index const& j) { - if (column_is_fixed(j)) - return false; - mpq val; - if (!has_value(j, val)) - return false; - lp::lconstraint_kind k; - if (column_has_upper_bound(j) && - get_upper_bound(j).x == val) { - verbose_stream() << "check upper " << j << "\n"; - push(); - if (column_is_int(j)) - k = LE, val -= 1; - else - k = LT; - auto ci = mk_var_bound(j, k, val); - update_column_type_and_bound(j, k, val, ci); - auto st = find_feasible_solution(); - pop(1); - return st == lp_status::INFEASIBLE; - } - if (column_has_lower_bound(j) && - get_lower_bound(j).x == val) { - verbose_stream() << "check lower " << j << "\n"; - push(); - if (column_is_int(j)) - k = GE, val += 1; - else - k = GT; - auto ci = mk_var_bound(j, k, val); - update_column_type_and_bound(j, k, val, ci); - auto st = find_feasible_solution(); - pop(1); - return st == lp_status::INFEASIBLE; - } - - return false; - } - - bool lar_solver::has_fixed_at_bound() { - verbose_stream() << "has-fixed-at-bound\n"; - unsigned num_fixed = 0; - for (unsigned j = 0; j < A_r().m_columns.size(); ++j) { - auto ci = column_index(j); - if (is_fixed_at_bound(ci)) { - ++num_fixed; - verbose_stream() << "fixed " << j << "\n"; - } - } - verbose_stream() << "num fixed " << num_fixed << "\n"; - if (num_fixed > 0) - find_feasible_solution(); - return num_fixed > 0; - } - - // below is the initialization functionality of lar_solver - bool lar_solver::strategy_is_undecided() const { - return m_settings.simplex_strategy() == simplex_strategy_enum::undecided; - } - - var_index lar_solver::add_named_var(unsigned ext_j, bool is_int, const std::string& name) { - var_index j = add_var(ext_j, is_int); + lpvar lar_solver::add_named_var(unsigned ext_j, bool is_int, const std::string& name) { + lpvar j = add_var(ext_j, is_int); m_var_register.set_name(j, name); return j; } - unsigned lar_solver::external_to_column_index(unsigned ext_j) const { - unsigned j = external_to_local(ext_j); - if (j == null_lpvar) - return j; - - if (tv::is_term(j)) - return map_term_index_to_column_index(j); - - return j; - } + struct lar_solver::undo_add_column : public trail { + lar_solver& s; + undo_add_column(lar_solver& s) : s(s) {} + void undo() override { + auto& col = s.m_columns.back(); + if (col.term() != nullptr) { + if (s.m_need_register_terms) + s.deregister_normalized_term(*col.term()); + delete col.term(); + s.m_terms.pop_back(); + } + s.remove_last_column_from_tableau(); + s.m_columns.pop_back(); + unsigned j = s.m_columns.size(); + if (s.m_columns_with_changed_bounds.contains(j)) + s.m_columns_with_changed_bounds.remove(j); + if (s.m_incorrect_columns.contains(j)) + s.m_incorrect_columns.remove(j); + } + }; - var_index lar_solver::add_var(unsigned ext_j, bool is_int) { + lpvar lar_solver::add_var(unsigned ext_j, bool is_int) { TRACE("add_var", tout << "adding var " << ext_j << (is_int ? " int" : " nonint") << std::endl;); - var_index local_j; - SASSERT(!m_term_register.external_is_used(ext_j)); - lp_assert(!tv::is_term(ext_j)); + lpvar local_j; if (m_var_register.external_is_used(ext_j, local_j)) return local_j; - lp_assert(m_columns_to_ul_pairs.size() == A_r().column_count()); + lp_assert(m_columns.size() == A_r().column_count()); local_j = A_r().column_count(); - m_columns_to_ul_pairs.push_back(ul_pair(false)); // not associated with a row - while (m_usage_in_terms.size() <= ext_j) { + m_columns.push_back(column(false, nullptr)); // false - not associated with a row, nullptr for term + m_trail.push(undo_add_column(*this)); + while (m_usage_in_terms.size() <= ext_j) m_usage_in_terms.push_back(0); - } add_non_basic_var_to_core_fields(ext_j, is_int); lp_assert(sizes_are_correct()); return local_j; @@ -1483,19 +1434,18 @@ namespace lp { return m_var_register.has_int_var(); } - void lar_solver::register_new_ext_var_index(unsigned ext_v, bool is_int) { + void lar_solver::register_new_external_var(unsigned ext_v, bool is_int) { lp_assert(!m_var_register.external_is_used(ext_v)); m_var_register.add_var(ext_v, is_int); } bool lar_solver::external_is_used(unsigned v) const { - return m_var_register.external_is_used(v) || m_term_register.external_is_used(v); + return m_var_register.external_is_used(v); } void lar_solver::add_non_basic_var_to_core_fields(unsigned ext_j, bool is_int) { - register_new_ext_var_index(ext_j, is_int); + register_new_external_var(ext_j, is_int); m_mpq_lar_core_solver.m_column_types.push_back(column_type::free_column); - increase_by_one_columns_with_changed_bounds(); add_new_var_to_core_fields_for_mpq(false); // false for not adding a row } @@ -1508,7 +1458,7 @@ namespace lp { m_mpq_lar_core_solver.m_r_x.resize(j + 1); m_mpq_lar_core_solver.m_r_lower_bounds.increase_size_by_one(); m_mpq_lar_core_solver.m_r_upper_bounds.increase_size_by_one(); - m_mpq_lar_core_solver.m_r_solver.inf_set_increase_size_by_one(); + m_mpq_lar_core_solver.m_r_solver.inf_heap_increase_size_by_one(); m_mpq_lar_core_solver.m_r_solver.m_costs.resize(j + 1); m_mpq_lar_core_solver.m_r_solver.m_d.resize(j + 1); lp_assert(m_mpq_lar_core_solver.m_r_heading.size() == j); // as A().column_count() on the entry to the method @@ -1516,23 +1466,17 @@ namespace lp { A_r().add_row(); m_mpq_lar_core_solver.m_r_heading.push_back(m_mpq_lar_core_solver.m_r_basis.size()); m_mpq_lar_core_solver.m_r_basis.push_back(j); - if (m_settings.bound_propagation()) - insert_row_with_changed_bounds(A_r().row_count() - 1); + add_touched_row(A_r().row_count() - 1); } else { m_mpq_lar_core_solver.m_r_heading.push_back(-static_cast(m_mpq_lar_core_solver.m_r_nbasis.size()) - 1); m_mpq_lar_core_solver.m_r_nbasis.push_back(j); + require_nbasis_sort(); } } - - var_index lar_solver::add_term_undecided(const vector>& coeffs) { - push_term(new lar_term(coeffs)); - return tv::mask_term(m_terms.size() - 1); - } - #if Z3DEBUG_CHECK_UNIQUE_TERMS - bool lar_solver::term_coeffs_are_ok(const vector>& coeffs) { + bool lar_solver::term_coeffs_are_ok(const vector>& coeffs) { for (const auto& p : coeffs) if (column_is_real(p.second)) @@ -1556,71 +1500,58 @@ namespace lp { return false; } #endif - void lar_solver::push_term(lar_term* t) { - m_terms.push_back(t); - } - + + // terms - bool lar_solver::all_vars_are_registered(const vector>& coeffs) { - for (const auto& p : coeffs) { - if (p.second >= m_var_register.size()) { - return false; - } - } - return true; + bool lar_solver::all_vars_are_registered(const vector>& coeffs) { + return all_of(coeffs, [&](const auto& p) { return p.second < m_var_register.size(); }); } void lar_solver::subst_known_terms(lar_term* t) { std::set seen_terms; for (auto p : *t) { - auto j = p.column(); - if (this->column_corresponds_to_term(j)) + auto j = p.j(); + if (this->column_has_term(j)) seen_terms.insert(j); } while (!seen_terms.empty()) { unsigned j = *seen_terms.begin(); seen_terms.erase(j); - auto tj = this->m_var_register.local_to_external(j); - auto& ot = this->get_term(tj); + const lar_term& ot = this->get_term(j); for (auto p : ot) - if (this->column_corresponds_to_term(p.column())) - seen_terms.insert(p.column()); + if (this->column_has_term(p.j())) + seen_terms.insert(p.j()); t->subst_by_term(ot, j); } } - // do not register in m_var_register this term if ext_i == UINT_MAX - var_index lar_solver::add_term(const vector>& coeffs, unsigned ext_i) { - TRACE("lar_solver_terms", print_linear_combination_of_column_indices_only(coeffs, tout) << ", ext_i =" << ext_i << "\n";); + // if UINT_MAX == null_lpvar then the term does not correspond and external variable + lpvar lar_solver::add_term(const vector>& coeffs, unsigned ext_i) { + TRACE("lar_solver_terms", print_linear_combination_of_column_indices_only(coeffs, tout) << ", ext_i =" << ext_i << "\n";); SASSERT(!m_var_register.external_is_used(ext_i)); - m_term_register.add_var(ext_i, term_is_int(coeffs)); - lp_assert(all_vars_are_registered(coeffs)); - if (strategy_is_undecided()) - return add_term_undecided(coeffs); + SASSERT(all_vars_are_registered(coeffs)); lar_term* t = new lar_term(coeffs); subst_known_terms(t); - push_term(t); - SASSERT(m_terms.size() == m_term_register.size()); - unsigned adjusted_term_index = m_terms.size() - 1; - var_index ret = tv::mask_term(adjusted_term_index); - if (!coeffs.empty()) { - add_row_from_term_no_constraint(m_terms.back(), ret); - if (m_settings.bound_propagation()) - insert_row_with_changed_bounds(A_r().row_count() - 1); - } + SASSERT(t->is_empty() == false); + m_terms.push_back(t); + lpvar ret = A_r().column_count(); + add_row_from_term_no_constraint(t, ext_i); + lp_assert(m_var_register.size() == A_r().column_count()); if (m_need_register_terms) register_normalized_term(*t, A_r().column_count() - 1); return ret; } - - void lar_solver::add_row_from_term_no_constraint(const lar_term* term, unsigned term_ext_index) { + void lar_solver::add_row_from_term_no_constraint(lar_term* term, unsigned ext_index) { TRACE("dump_terms", print_term(*term, tout) << std::endl;); - register_new_ext_var_index(term_ext_index, term_is_int(term)); + register_new_external_var(ext_index, term_is_int(term)); // j will be a new variable unsigned j = A_r().column_count(); - ul_pair ul(true); // to mark this column as associated_with_row - m_columns_to_ul_pairs.push_back(ul); + SASSERT(ext_index == null_lpvar || external_to_local(ext_index) == j); + column ul(true, term); // true - to mark this column as associated_with_row + term->j() = j; // point from the term to the column + m_columns.push_back(ul); + m_trail.push(undo_add_column(*this)); add_basic_var_to_core_fields(); A_r().fill_last_row_with_pivoting(*term, @@ -1630,7 +1561,7 @@ namespace lp { m_mpq_lar_core_solver.m_r_solver.update_x(j, get_basic_var_value_from_row(A_r().row_count() - 1)); for (lar_term::ival c : *term) { - unsigned j = c.column(); + unsigned j = c.j(); while (m_usage_in_terms.size() <= j) m_usage_in_terms.push_back(0); m_usage_in_terms[j] = m_usage_in_terms[j] + 1; @@ -1639,24 +1570,20 @@ namespace lp { void lar_solver::add_basic_var_to_core_fields() { m_mpq_lar_core_solver.m_column_types.push_back(column_type::free_column); - increase_by_one_columns_with_changed_bounds(); - m_incorrect_columns.increase_size_by_one(); - m_rows_with_changed_bounds.increase_size_by_one(); add_new_var_to_core_fields_for_mpq(true); - } bool lar_solver::bound_is_integer_for_integer_column(unsigned j, const mpq& right_side) const { return !column_is_int(j) || right_side.is_int(); } - constraint_index lar_solver::add_var_bound_check_on_equal(var_index j, lconstraint_kind kind, const mpq& right_side, var_index& equal_var) { + constraint_index lar_solver::add_var_bound_check_on_equal(lpvar j, lconstraint_kind kind, const mpq& right_side, lpvar& equal_var) { constraint_index ci = mk_var_bound(j, kind, right_side); activate_check_on_equal(ci, equal_var); return ci; } - constraint_index lar_solver::add_var_bound(var_index j, lconstraint_kind kind, const mpq& right_side) { + constraint_index lar_solver::add_var_bound(lpvar j, lconstraint_kind kind, const mpq& right_side) { constraint_index ci = mk_var_bound(j, kind, right_side); activate(ci); return ci; @@ -1706,7 +1633,7 @@ namespace lp { // SASSERT(column_is_fixed(k)); if (j != k && column_is_fixed(k)) { SASSERT(column_is_int(j) == column_is_int(k)); - equal_to_j = column_to_reported_index(k); + equal_to_j = k; TRACE("lar_solver", tout << "found equal column k = " << k << ", external = " << equal_to_j << "\n";); } @@ -1714,12 +1641,12 @@ namespace lp { void lar_solver::activate_check_on_equal(constraint_index ci, unsigned& equal_column) { auto const& c = m_constraints[ci]; - update_column_type_and_bound_check_on_equal(c.column(), c.kind(), c.rhs(), ci, equal_column); + update_column_type_and_bound_check_on_equal(c.column(), c.rhs(), ci, equal_column); } void lar_solver::activate(constraint_index ci) { auto const& c = m_constraints[ci]; - update_column_type_and_bound(c.column(), c.kind(), c.rhs(), ci); + update_column_type_and_bound(c.column(), c.rhs(), ci); } mpq lar_solver::adjust_bound_for_int(lpvar j, lconstraint_kind& k, const mpq& bound) { @@ -1746,10 +1673,10 @@ namespace lp { } - constraint_index lar_solver::mk_var_bound(var_index j, lconstraint_kind kind, const mpq& right_side) { + constraint_index lar_solver::mk_var_bound(lpvar j, lconstraint_kind kind, const mpq& right_side) { TRACE("lar_solver", tout << "j = " << get_variable_name(j) << " " << lconstraint_kind_string(kind) << " " << right_side << std::endl;); constraint_index ci; - if (!tv::is_term(j)) { // j is a var + if (!column_has_term(j)) { mpq rs = adjust_bound_for_int(j, kind, right_side); lp_assert(bound_is_integer_for_integer_column(j, rs)); ci = m_constraints.add_var_constraint(j, kind, rs); @@ -1761,9 +1688,7 @@ namespace lp { return ci; } - bool lar_solver::compare_values(var_index j, lconstraint_kind k, const mpq& rhs) { - if (tv::is_term(j)) - j = to_column(j); + bool lar_solver::compare_values(lpvar j, lconstraint_kind k, const mpq& rhs) { return compare_values(get_column_value(j), k, rhs); } @@ -1781,217 +1706,289 @@ namespace lp { } void lar_solver::update_column_type_and_bound(unsigned j, - lconstraint_kind kind, const mpq& right_side, constraint_index constr_index) { + TRACE("lar_solver_feas", tout << "j = " << j << " was " << (this->column_is_feasible(j)?"feas":"non-feas") << std::endl;); m_constraints.activate(constr_index); - if (column_has_upper_bound(j)) - update_column_type_and_bound_with_ub(j, kind, right_side, constr_index); - else - update_column_type_and_bound_with_no_ub(j, kind, right_side, constr_index); + lconstraint_kind kind = m_constraints[constr_index].kind(); + u_dependency* dep = m_constraints[constr_index].dep(); + update_column_type_and_bound(j, kind, right_side, dep); } - void lar_solver::update_column_type_and_bound_check_on_equal(unsigned j, - lconstraint_kind kind, - const mpq& right_side, - constraint_index constr_index, - unsigned& equal_to_j) { - update_column_type_and_bound(j, kind, right_side, constr_index); - equal_to_j = null_lpvar; - if (column_is_fixed(j)) { - register_in_fixed_var_table(j, equal_to_j); - } - } - - constraint_index lar_solver::add_var_bound_on_constraint_for_term(var_index j, lconstraint_kind kind, const mpq& right_side) { - lp_assert(tv::is_term(j)); - unsigned adjusted_term_index = tv::unmask_term(j); - // lp_assert(!term_is_int(m_terms[adjusted_term_index]) || right_side.is_int()); - unsigned term_j; - lar_term const* term = m_terms[adjusted_term_index]; - if (m_var_register.external_is_used(j, term_j)) { - mpq rs = adjust_bound_for_int(term_j, kind, right_side); - lp_assert(bound_is_integer_for_integer_column(term_j, rs)); - return m_constraints.add_term_constraint(term_j, term, kind, rs); + bool lar_solver::validate_bound(lpvar j, lconstraint_kind kind, const mpq& rs, u_dependency* dep) { + if (m_validate_blocker) return true; + + lar_solver solver; + solver.m_validate_blocker = true; + TRACE("lar_solver_validate", tout << "j = " << j << " " << lconstraint_kind_string(kind) << " " << rs << std::endl;); + add_dep_constraints_to_solver(solver, dep); + if (solver.external_to_local(j) == null_lpvar) { + return false; // we have to mention j in the dep + } + if (kind != EQ) { + add_bound_negation_to_solver(solver, j, kind, rs); + solver.find_feasible_solution(); + return solver.get_status() == lp_status::INFEASIBLE; } else { - return add_constraint_from_term_and_create_new_column_row(j, term, kind, right_side); + solver.push(); + add_bound_negation_to_solver(solver, j, LE, rs); + solver.find_feasible_solution(); + if (solver.get_status() != lp_status::INFEASIBLE) + return false; + solver.pop(); + add_bound_negation_to_solver(solver, j, GE, rs); + solver.find_feasible_solution(); + return solver.get_status() == lp_status::INFEASIBLE; } } - constraint_index lar_solver::add_constraint_from_term_and_create_new_column_row( - unsigned term_j, const lar_term* term, lconstraint_kind kind, const mpq& right_side) { - add_row_from_term_no_constraint(term, term_j); - unsigned j = A_r().column_count() - 1; + void lar_solver::add_dep_constraints_to_solver(lar_solver& ls, u_dependency* dep) { + auto constraints = flatten(dep); + for (auto c : constraints) + add_constraint_to_validate(ls, c); + } + void lar_solver::add_bound_negation_to_solver(lar_solver& ls, lpvar j, lconstraint_kind kind, const mpq& right_side) { + j = ls.external_to_local(j); + switch (kind) { + case LE: + ls.add_var_bound(j, GT, right_side); + break; + case LT: + ls.add_var_bound(j, GE, right_side); + break; + case GE: + ls.add_var_bound(j, LT, right_side); + break; + case GT: + ls.add_var_bound(j, LE, right_side); + break; + default: + UNREACHABLE(); + break; + } + } + void lar_solver::add_constraint_to_validate(lar_solver& ls, constraint_index ci) { + auto const& c = m_constraints[ci]; + TRACE("lar_solver_validate", tout << "adding constr with column = "<< c.column() << "\n"; m_constraints.display(tout, c); tout << std::endl;); + vector> coeffs; + for (auto p : c.coeffs()) { + lpvar jext = p.second; + lpvar j = ls.external_to_local(jext); + if (j == null_lpvar) { + ls.add_var(jext, column_is_int(jext)); + j = ls.external_to_local(jext); + } + coeffs.push_back(std::make_pair(p.first, j)); + } + + lpvar column_ext = c.column(); + unsigned j = ls.external_to_local(column_ext); + lpvar tv; + if (j == UINT_MAX) { + tv = ls.add_term(coeffs, column_ext); + } + else { + tv = ls.add_term(coeffs, null_lpvar); + } + ls.add_var_bound(tv, c.kind(), c.rhs()); + } + void lar_solver::update_column_type_and_bound(unsigned j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep) { + // SASSERT(validate_bound(j, kind, right_side, dep)); + TRACE( + "lar_solver_feas", + tout << "j" << j << " " << lconstraint_kind_string(kind) << " " << right_side << std::endl; + if (dep) { + tout << "dep:\n"; + auto cs = flatten(dep); + for (auto c : cs) { + constraints().display(tout, c); + tout << std::endl; + } + }); mpq rs = adjust_bound_for_int(j, kind, right_side); - lp_assert(bound_is_integer_for_integer_column(j, rs)); - return m_constraints.add_term_constraint(j, term, kind, rs); + if (column_has_upper_bound(j)) + update_column_type_and_bound_with_ub(j, kind, rs, dep); + else + update_column_type_and_bound_with_no_ub(j, kind, rs, dep); + + if (is_base(j) && column_is_fixed(j)) + m_fixed_base_var_set.insert(j); + TRACE("lar_solver_feas", tout << "j = " << j << " became " << (this->column_is_feasible(j) ? "feas" : "non-feas") << ", and " << (this->column_is_bounded(j) ? "bounded" : "non-bounded") << std::endl;); } - void lar_solver::decide_on_strategy_and_adjust_initial_state() { - lp_assert(strategy_is_undecided()); - - m_settings.set_simplex_strategy(simplex_strategy_enum::tableau_rows); // todo: when to switch to tableau_costs? - - adjust_initial_state(); + void lar_solver::insert_to_columns_with_changed_bounds(unsigned j) { + m_columns_with_changed_bounds.insert(j); + TRACE("lar_solver", tout << "column " << j << (column_is_feasible(j) ? " feas" : " non-feas") << "\n";); } - void lar_solver::adjust_initial_state() { - switch (m_settings.simplex_strategy()) { - case simplex_strategy_enum::tableau_rows: - adjust_initial_state_for_tableau_rows(); - break; - case simplex_strategy_enum::tableau_costs: - UNREACHABLE(); // not implemented - case simplex_strategy_enum::undecided: - adjust_initial_state_for_tableau_rows(); - break; + void lar_solver::update_column_type_and_bound_check_on_equal(unsigned j, + const mpq& right_side, + constraint_index constr_index, + unsigned& equal_to_j) { + update_column_type_and_bound(j, right_side, constr_index); + equal_to_j = null_lpvar; + if (column_is_fixed(j)) { + register_in_fixed_var_table(j, equal_to_j); } } - - void lar_solver::adjust_initial_state_for_tableau_rows() { - for (unsigned i = 0; i < m_terms.size(); i++) { - if (m_var_register.external_is_used(tv::mask_term(i))) - continue; - add_row_from_term_no_constraint(m_terms[i], tv::mask_term(i)); - } + constraint_index lar_solver::add_var_bound_on_constraint_for_term(lpvar j, lconstraint_kind kind, const mpq& right_side) { + mpq rs = adjust_bound_for_int(j, kind, right_side); + SASSERT(bound_is_integer_for_integer_column(j, rs)); + return m_constraints.add_term_constraint(j, m_columns[j].term(), kind, rs); } - - void lar_solver::update_column_type_and_bound_with_ub(unsigned j, lp::lconstraint_kind kind, const mpq& right_side, unsigned constraint_index) { + struct scoped_backup { + lar_solver& m_s; + scoped_backup(lar_solver& s) : m_s(s) { + m_s.backup_x(); + } + ~scoped_backup() { + m_s.restore_x(); + } + }; + + void lar_solver::update_column_type_and_bound_with_ub(unsigned j, lp::lconstraint_kind kind, const mpq& right_side, u_dependency* dep) { SASSERT(column_has_upper_bound(j)); if (column_has_lower_bound(j)) { - update_bound_with_ub_lb(j, kind, right_side, constraint_index); + update_bound_with_ub_lb(j, kind, right_side, dep); } else { - update_bound_with_ub_no_lb(j, kind, right_side, constraint_index); + update_bound_with_ub_no_lb(j, kind, right_side, dep); } } - void lar_solver::update_column_type_and_bound_with_no_ub(unsigned j, lp::lconstraint_kind kind, const mpq& right_side, unsigned constraint_index) { + void lar_solver::update_column_type_and_bound_with_no_ub(unsigned j, lp::lconstraint_kind kind, const mpq& right_side, u_dependency* dep) { SASSERT(!column_has_upper_bound(j)); if (column_has_lower_bound(j)) { - update_bound_with_no_ub_lb(j, kind, right_side, constraint_index); + update_bound_with_no_ub_lb(j, kind, right_side, dep); } else { - update_bound_with_no_ub_no_lb(j, kind, right_side, constraint_index); + update_bound_with_no_ub_no_lb(j, kind, right_side, dep); } } - - void lar_solver::update_bound_with_ub_lb(var_index j, lconstraint_kind kind, const mpq& right_side, constraint_index ci) { + void lar_solver::update_bound_with_ub_lb(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep) { lp_assert(column_has_lower_bound(j) && column_has_upper_bound(j)); lp_assert(m_mpq_lar_core_solver.m_column_types[j] == column_type::boxed || - m_mpq_lar_core_solver.m_column_types[j] == column_type::fixed); + m_mpq_lar_core_solver.m_column_types[j] == column_type::fixed); mpq y_of_bound(0); switch (kind) { - case LT: - y_of_bound = -1; - case LE: - { - auto up = numeric_pair(right_side, y_of_bound); - if (up < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { - set_infeasible_column(j); - } - if (up >= m_mpq_lar_core_solver.m_r_upper_bounds[j]) return; - m_mpq_lar_core_solver.m_r_upper_bounds[j] = up; - set_upper_bound_witness(j, ci); - insert_to_columns_with_changed_bounds(j); - } - break; - case GT: - y_of_bound = 1; - case GE: - { - auto low = numeric_pair(right_side, y_of_bound); - if (low > m_mpq_lar_core_solver.m_r_upper_bounds[j]) { - set_infeasible_column(j); - } - if (low < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { - return; + case LT: + y_of_bound = -1; + case LE: { + auto up = numeric_pair(right_side, y_of_bound); + if (up < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { + set_crossed_bounds_column_and_deps(j, true, dep); + } + else { + if (up >= m_mpq_lar_core_solver.m_r_upper_bounds[j]) return; + m_mpq_lar_core_solver.m_r_upper_bounds[j] = up; + set_upper_bound_witness(j, dep); + insert_to_columns_with_changed_bounds(j); + } + break; } - m_mpq_lar_core_solver.m_r_lower_bounds[j] = low; - insert_to_columns_with_changed_bounds(j); - set_lower_bound_witness(j, ci); - m_mpq_lar_core_solver.m_column_types[j] = (low == m_mpq_lar_core_solver.m_r_upper_bounds[j] ? column_type::fixed : column_type::boxed); - } - break; - case EQ: - { - auto v = numeric_pair(right_side, zero_of_type()); - if (v > m_mpq_lar_core_solver.m_r_upper_bounds[j] || v < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { - set_infeasible_column(j); + case GT: + y_of_bound = 1; + case GE: { + auto low = numeric_pair(right_side, y_of_bound); + if (low > m_mpq_lar_core_solver.m_r_upper_bounds[j]) { + set_crossed_bounds_column_and_deps(j, false, dep); + } else { + if (low < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { + return; + } + m_mpq_lar_core_solver.m_r_lower_bounds[j] = low; + set_lower_bound_witness(j, dep); + m_mpq_lar_core_solver.m_column_types[j] = (low == m_mpq_lar_core_solver.m_r_upper_bounds[j] ? column_type::fixed : column_type::boxed); + insert_to_columns_with_changed_bounds(j); + } + break; + + } + case EQ: { + auto v = numeric_pair(right_side, zero_of_type()); + if (v > m_mpq_lar_core_solver.m_r_upper_bounds[j]){ + set_crossed_bounds_column_and_deps(j, false, dep); + } + else if (v < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { + set_crossed_bounds_column_and_deps(j, true, dep); + } + else { + set_upper_bound_witness(j, dep); + set_lower_bound_witness(j, dep); + m_mpq_lar_core_solver.m_r_upper_bounds[j] = m_mpq_lar_core_solver.m_r_lower_bounds[j] = v; + insert_to_columns_with_changed_bounds(j); + } + break; } - set_upper_bound_witness(j, ci); - set_lower_bound_witness(j, ci); - m_mpq_lar_core_solver.m_r_upper_bounds[j] = m_mpq_lar_core_solver.m_r_lower_bounds[j] = v; - break; - } - default: - UNREACHABLE(); + default: + UNREACHABLE(); } if (m_mpq_lar_core_solver.m_r_upper_bounds[j] == m_mpq_lar_core_solver.m_r_lower_bounds[j]) { m_mpq_lar_core_solver.m_column_types[j] = column_type::fixed; } } - void lar_solver::update_bound_with_no_ub_lb(var_index j, lconstraint_kind kind, const mpq& right_side, constraint_index ci) { + + void lar_solver::update_bound_with_no_ub_lb(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep) { lp_assert(column_has_lower_bound(j) && !column_has_upper_bound(j)); lp_assert(m_mpq_lar_core_solver.m_column_types[j] == column_type::lower_bound); mpq y_of_bound(0); switch (kind) { - case LT: - y_of_bound = -1; - case LE: - { - auto up = numeric_pair(right_side, y_of_bound); - if (up < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { - set_infeasible_column(j); - } - m_mpq_lar_core_solver.m_r_upper_bounds[j] = up; - set_upper_bound_witness(j, ci); - insert_to_columns_with_changed_bounds(j); - m_mpq_lar_core_solver.m_column_types[j] = (up == m_mpq_lar_core_solver.m_r_lower_bounds[j] ? column_type::fixed : column_type::boxed); - } - break; - case GT: - y_of_bound = 1; - case GE: - { - auto low = numeric_pair(right_side, y_of_bound); - if (low < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { - return; - } - m_mpq_lar_core_solver.m_r_lower_bounds[j] = low; - insert_to_columns_with_changed_bounds(j); - set_lower_bound_witness(j, ci); - } - break; - case EQ: - { - auto v = numeric_pair(right_side, zero_of_type()); - if (v < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { - set_infeasible_column(j); + case LT: + y_of_bound = -1; + case LE: { + auto up = numeric_pair(right_side, y_of_bound); + if (up < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { + set_crossed_bounds_column_and_deps(j, true, dep); + } + else { + m_mpq_lar_core_solver.m_r_upper_bounds[j] = up; + set_upper_bound_witness(j, dep); + m_mpq_lar_core_solver.m_column_types[j] = (up == m_mpq_lar_core_solver.m_r_lower_bounds[j] ? column_type::fixed : column_type::boxed); + insert_to_columns_with_changed_bounds(j); + } + break; + } + case GT: + y_of_bound = 1; + case GE: { + auto low = numeric_pair(right_side, y_of_bound); + if (low < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { + return; + } + m_mpq_lar_core_solver.m_r_lower_bounds[j] = low; + set_lower_bound_witness(j, dep); + insert_to_columns_with_changed_bounds(j); + break; + } + case EQ: { + auto v = numeric_pair(right_side, zero_of_type()); + if (v < m_mpq_lar_core_solver.m_r_lower_bounds[j]) { + set_crossed_bounds_column_and_deps(j, true, dep); + } + else { + set_upper_bound_witness(j, dep); + set_lower_bound_witness(j, dep); + m_mpq_lar_core_solver.m_r_upper_bounds[j] = m_mpq_lar_core_solver.m_r_lower_bounds[j] = v; + m_mpq_lar_core_solver.m_column_types[j] = column_type::fixed; + insert_to_columns_with_changed_bounds(j); + } + break; } - set_upper_bound_witness(j, ci); - set_lower_bound_witness(j, ci); - m_mpq_lar_core_solver.m_r_upper_bounds[j] = m_mpq_lar_core_solver.m_r_lower_bounds[j] = v; - m_mpq_lar_core_solver.m_column_types[j] = column_type::fixed; - break; - } - - default: - UNREACHABLE(); + default: + UNREACHABLE(); } - } - - void lar_solver::update_bound_with_ub_no_lb(var_index j, lconstraint_kind kind, const mpq& right_side, constraint_index ci) { + + void lar_solver::update_bound_with_ub_no_lb(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep) { lp_assert(!column_has_lower_bound(j) && column_has_upper_bound(j)); lp_assert(m_mpq_lar_core_solver.m_column_types[j] == column_type::upper_bound); mpq y_of_bound(0); @@ -2003,7 +2000,7 @@ namespace lp { auto up = numeric_pair(right_side, y_of_bound); if (up >= m_mpq_lar_core_solver.m_r_upper_bounds[j]) return; m_mpq_lar_core_solver.m_r_upper_bounds[j] = up; - set_upper_bound_witness(j, ci); + set_upper_bound_witness(j, dep); insert_to_columns_with_changed_bounds(j); } break; @@ -2013,25 +2010,29 @@ namespace lp { { auto low = numeric_pair(right_side, y_of_bound); if (low > m_mpq_lar_core_solver.m_r_upper_bounds[j]) { - set_infeasible_column(j); + set_crossed_bounds_column_and_deps(j, false, dep); + } + else { + m_mpq_lar_core_solver.m_r_lower_bounds[j] = low; + set_lower_bound_witness(j, dep); + m_mpq_lar_core_solver.m_column_types[j] = (low == m_mpq_lar_core_solver.m_r_upper_bounds[j] ? column_type::fixed : column_type::boxed); + insert_to_columns_with_changed_bounds(j); } - m_mpq_lar_core_solver.m_r_lower_bounds[j] = low; - insert_to_columns_with_changed_bounds(j); - set_lower_bound_witness(j, ci); - m_mpq_lar_core_solver.m_column_types[j] = (low == m_mpq_lar_core_solver.m_r_upper_bounds[j] ? column_type::fixed : column_type::boxed); } break; case EQ: { auto v = numeric_pair(right_side, zero_of_type()); if (v > m_mpq_lar_core_solver.m_r_upper_bounds[j]) { - set_infeasible_column(j); + set_crossed_bounds_column_and_deps(j, false, dep); + } + else { + set_upper_bound_witness(j, dep); + set_lower_bound_witness(j, dep); + m_mpq_lar_core_solver.m_r_upper_bounds[j] = m_mpq_lar_core_solver.m_r_lower_bounds[j] = v; + m_mpq_lar_core_solver.m_column_types[j] = column_type::fixed; + insert_to_columns_with_changed_bounds(j); } - - set_upper_bound_witness(j, ci); - set_lower_bound_witness(j, ci); - m_mpq_lar_core_solver.m_r_upper_bounds[j] = m_mpq_lar_core_solver.m_r_lower_bounds[j] = v; - m_mpq_lar_core_solver.m_column_types[j] = column_type::fixed; break; } @@ -2039,62 +2040,50 @@ namespace lp { UNREACHABLE(); } } - void lar_solver::update_bound_with_no_ub_no_lb(var_index j, lconstraint_kind kind, const mpq& right_side, constraint_index ci) { + + void lar_solver::update_bound_with_no_ub_no_lb(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep) { lp_assert(!column_has_lower_bound(j) && !column_has_upper_bound(j)); - insert_to_columns_with_changed_bounds(j); mpq y_of_bound(0); switch (kind) { - case LT: - y_of_bound = -1; - case LE: - { - auto up = numeric_pair(right_side, y_of_bound); - m_mpq_lar_core_solver.m_r_upper_bounds[j] = up; - set_upper_bound_witness(j, ci); - m_mpq_lar_core_solver.m_column_types[j] = column_type::upper_bound; - } - break; - case GT: - y_of_bound = 1; - case GE: - { - auto low = numeric_pair(right_side, y_of_bound); - m_mpq_lar_core_solver.m_r_lower_bounds[j] = low; - insert_to_columns_with_changed_bounds(j); - set_lower_bound_witness(j, ci); - m_mpq_lar_core_solver.m_column_types[j] = column_type::lower_bound; - } - break; - case EQ: - { - auto v = numeric_pair(right_side, zero_of_type()); - set_upper_bound_witness(j, ci); - set_lower_bound_witness(j, ci); - m_mpq_lar_core_solver.m_r_upper_bounds[j] = m_mpq_lar_core_solver.m_r_lower_bounds[j] = v; - m_mpq_lar_core_solver.m_column_types[j] = column_type::fixed; - break; - } + case LT: + y_of_bound = -1; + case LE: { + auto up = numeric_pair(right_side, y_of_bound); + m_mpq_lar_core_solver.m_r_upper_bounds[j] = up; + set_upper_bound_witness(j, dep); + m_mpq_lar_core_solver.m_column_types[j] = column_type::upper_bound; + } break; + case GT: + y_of_bound = 1; + case GE: { + auto low = numeric_pair(right_side, y_of_bound); + m_mpq_lar_core_solver.m_r_lower_bounds[j] = low; + set_lower_bound_witness(j, dep); + m_mpq_lar_core_solver.m_column_types[j] = column_type::lower_bound; + + } break; + case EQ: { + auto v = numeric_pair(right_side, zero_of_type()); + set_upper_bound_witness(j, dep); + set_lower_bound_witness(j, dep); + m_mpq_lar_core_solver.m_r_upper_bounds[j] = m_mpq_lar_core_solver.m_r_lower_bounds[j] = v; + m_mpq_lar_core_solver.m_column_types[j] = column_type::fixed; + break; + } - default: - UNREACHABLE(); + default: + UNREACHABLE(); } - } - - bool lar_solver::column_corresponds_to_term(unsigned j) const { - return tv::is_term(m_var_register.local_to_external(j)); - } + insert_to_columns_with_changed_bounds(j); + } - var_index lar_solver::to_column(unsigned ext_j) const { + lpvar lar_solver::to_column(unsigned ext_j) const { return m_var_register.external_to_local(ext_j); } - bool lar_solver::tighten_term_bounds_by_delta(tv const& t, const impq& delta) { - SASSERT(t.is_term()); - unsigned tj = t.index(); - unsigned j; - if (!m_var_register.external_is_used(tj, j)) - return true; // the term is not a column so it has no bounds + bool lar_solver::tighten_term_bounds_by_delta(lpvar j, const impq& delta) { + SASSERT(column_has_term(j)); auto& slv = m_mpq_lar_core_solver.m_r_solver; TRACE("cube", tout << "delta = " << delta << std::endl; m_int_solver->display_column(tout, j); ); @@ -2107,25 +2096,24 @@ namespace lp { TRACE("cube", tout << "can tighten";); if (slv.column_has_upper_bound(j)) { if (!is_zero(delta.y) || !is_zero(slv.m_upper_bounds[j].y)) - add_var_bound(tj, lconstraint_kind::LT, slv.m_upper_bounds[j].x - delta.x); + add_var_bound(j, lconstraint_kind::LT, slv.m_upper_bounds[j].x - delta.x); else - add_var_bound(tj, lconstraint_kind::LE, slv.m_upper_bounds[j].x - delta.x); + add_var_bound(j, lconstraint_kind::LE, slv.m_upper_bounds[j].x - delta.x); } if (slv.column_has_lower_bound(j)) { if (!is_zero(delta.y) || !is_zero(slv.m_lower_bounds[j].y)) - add_var_bound(tj, lconstraint_kind::GT, slv.m_lower_bounds[j].x + delta.x); + add_var_bound(j, lconstraint_kind::GT, slv.m_lower_bounds[j].x + delta.x); else - add_var_bound(tj, lconstraint_kind::GE, slv.m_lower_bounds[j].x + delta.x); + add_var_bound(j, lconstraint_kind::GE, slv.m_lower_bounds[j].x + delta.x); } return true; } void lar_solver::round_to_integer_solution() { - m_incorrect_columns.resize(column_count()); for (unsigned j = 0; j < column_count(); j++) { if (!column_is_int(j)) continue; - if (column_corresponds_to_term(j)) continue; + if (column_has_term(j)) continue; impq& v = m_mpq_lar_core_solver.m_r_x[j]; if (v.is_int()) continue; @@ -2144,26 +2132,26 @@ namespace lp { } if (!m_incorrect_columns.empty()) { fix_terms_with_rounded_columns(); - m_incorrect_columns.clear(); + m_incorrect_columns.reset(); } } void lar_solver::fix_terms_with_rounded_columns() { - for (unsigned i = 0; i < m_terms.size(); i++) { - if (!term_is_used_as_row(i)) + for (const lar_term* t : m_terms) { + lpvar j = t->j(); + if (!m_columns[j].associated_with_row()) continue; bool need_to_fix = false; - const lar_term& t = *m_terms[i]; - for (lar_term::ival p : t) { - if (m_incorrect_columns.contains(p.column())) { + + for (lar_term::ival p : * t) { + if (m_incorrect_columns.contains(p.j())) { need_to_fix = true; break; } } if (need_to_fix) { - lpvar j = m_var_register.external_to_local(tv::mask_term(i)); - impq v = t.apply(m_mpq_lar_core_solver.m_r_x); + impq v = t->apply(m_mpq_lar_core_solver.m_r_x); m_mpq_lar_core_solver.m_r_solver.update_x(j, v); } } @@ -2175,7 +2163,7 @@ namespace lp { bool lar_solver::sum_first_coords(const lar_term& t, mpq& val) const { val = zero_of_type(); for (lar_term::ival c : t) { - const auto& x = m_mpq_lar_core_solver.m_r_x[c.column()]; + const auto& x = m_mpq_lar_core_solver.m_r_x[c.j()]; if (!is_zero(x.y)) return false; val += x.x * c.coeff(); @@ -2183,18 +2171,14 @@ namespace lp { return true; } - bool lar_solver::get_equality_and_right_side_for_term_on_current_x(tv const& t, mpq& rs, constraint_index& ci, bool& upper_bound) const { - lp_assert(t.is_term()) - unsigned j; - bool is_int; - if (!m_var_register.external_is_used(t.index(), j, is_int)) - return false; // the term does not have a bound because it does not correspond to a column - if (!is_int) // todo - allow for the next version of hnf + bool lar_solver::get_equality_and_right_side_for_term_on_current_x(lpvar j, mpq& rs, u_dependency*& ci, bool& upper_bound) const { + lp_assert(column_has_term(j)); + if (!column_is_int(j)) // todo - allow for the next version of hnf return false; bool rs_is_calculated = false; mpq b; bool is_strict; - const lar_term& term = get_term(t); + const lar_term& term = get_term(j); if (has_upper_bound(j, ci, b, is_strict) && !is_strict) { lp_assert(b.is_int()); if (!sum_first_coords(term, rs)) @@ -2261,9 +2245,8 @@ namespace lp { void lar_solver::register_existing_terms() { if (!m_need_register_terms) { TRACE("nla_solver", tout << "registering " << m_terms.size() << " terms\n";); - for (unsigned k = 0; k < m_terms.size(); k++) { - lpvar j = m_var_register.external_to_local(tv::mask_term(k)); - register_normalized_term(*m_terms[k], j); + for (const lar_term* t : m_terms) { + register_normalized_term(*t, t->j()); } } m_need_register_terms = true; @@ -2271,8 +2254,7 @@ namespace lp { // a_j.first gives the normalised coefficient, // a_j.second givis the column bool lar_solver::fetch_normalized_term_column(const lar_term& c, std::pair& a_j) const { - TRACE("lar_solver_terms", tout << "looking for term "; - print_term_as_indices(c, tout) << "\n";); + TRACE("lar_solver_terms", print_term_as_indices(c, tout << "looking for term ") << "\n";); lp_assert(c.is_normalized()); auto it = m_normalized_terms_to_columns.find(c); if (it != m_normalized_terms_to_columns.end()) { @@ -2285,23 +2267,18 @@ namespace lp { } std::pair lar_solver::add_equality(lpvar j, lpvar k) { - vector> coeffs; - if (tv::is_term(j)) - j = map_term_index_to_column_index(j); - - if (tv::is_term(k)) - k = map_term_index_to_column_index(k); - + vector> coeffs; + coeffs.push_back(std::make_pair(mpq(1), j)); coeffs.push_back(std::make_pair(mpq(-1), k)); - unsigned term_index = add_term(coeffs, UINT_MAX); // UINT_MAX is the external null var + unsigned ej = add_term(coeffs, UINT_MAX); // UINT_MAX is the external null var if (get_column_value(j) != get_column_value(k)) set_status(lp_status::UNKNOWN); return std::pair( - add_var_bound(term_index, lconstraint_kind::LE, mpq(0)), - add_var_bound(term_index, lconstraint_kind::GE, mpq(0))); + add_var_bound(ej, lconstraint_kind::LE, mpq(0)), + add_var_bound(ej, lconstraint_kind::GE, mpq(0))); } bool lar_solver::inside_bounds(lpvar j, const impq& val) const { @@ -2311,6 +2288,39 @@ namespace lp { return false; return true; } + // If lower_bound is true than the new asserted upper bound is less than the existing lower bound. + // Otherwise the new asserted lower bound is is greater than the existing upper bound. + // dep is the reason for the new bound + + void lar_solver::set_crossed_bounds_column_and_deps(unsigned j, bool lower_bound, u_dependency* dep) { + if (m_crossed_bounds_column != null_lpvar) return; // already set + SASSERT(m_crossed_bounds_deps == nullptr); + set_status(lp_status::INFEASIBLE); + m_crossed_bounds_column = j; + const auto& ul = this->m_columns[j]; + u_dependency* bdep = lower_bound? ul.lower_bound_witness() : ul.upper_bound_witness(); + SASSERT(bdep != nullptr); + m_crossed_bounds_deps = m_dependencies.mk_join(bdep, dep); + insert_to_columns_with_changed_bounds(j); + } + + void lar_solver::collect_more_rows_for_lp_propagation(){ + for (auto j : m_columns_with_changed_bounds) + detect_rows_with_changed_bounds_for_column(j); + } + std::ostream& lar_solver::print_explanation( + std::ostream& out, const explanation& exp, + std::function var_str) const { + out << "expl: "; + unsigned i = 0; + for (auto p : exp) { + out << "(" << p.ci() << ")"; + constraints().display(out, var_str, p.ci()); + if (++i < exp.size()) + out << " "; + } + return out; + } } // namespace lp diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index 182ef0be3bf..4d71d0181da 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -18,45 +18,48 @@ --*/ #pragma once -#include "util/vector.h" -#include -#include "util/debug.h" -#include "util/buffer.h" -#include -#include -#include #include -#include #include +#include +#include +#include +#include +#include + +#include "math/lp/bound_analyzer_on_row.h" +#include "math/lp/implied_bound.h" +#include "math/lp/int_solver.h" #include "math/lp/lar_constraints.h" #include "math/lp/lar_core_solver.h" -#include "math/lp/numeric_pair.h" +#include "math/lp/lp_bound_propagator.h" #include "math/lp/lp_primal_core_solver.h" +#include "math/lp/lp_types.h" +#include "math/lp/nra_solver.h" +#include "math/lp/numeric_pair.h" #include "math/lp/random_updater.h" -#include "util/stacked_value.h" #include "math/lp/stacked_vector.h" -#include "math/lp/implied_bound.h" -#include "math/lp/bound_analyzer_on_row.h" -#include "math/lp/int_solver.h" -#include "math/lp/nra_solver.h" -#include "math/lp/lp_types.h" -#include "math/lp/lp_bound_propagator.h" +#include "util/buffer.h" +#include "util/debug.h" +#include "util/stacked_value.h" +#include "util/vector.h" +#include "util/trail.h" namespace lp { class int_branch; class int_solver; + + class lar_solver : public column_namer { struct term_hasher { - std::size_t operator()(const lar_term &t) const - { - using std::size_t; + std::size_t operator()(const lar_term& t) const { using std::hash; + using std::size_t; using std::string; size_t seed = 0; int i = 0; for (const auto p : t) { - hash_combine(seed, (unsigned)p.column()); + hash_combine(seed, (unsigned)p.j()); hash_combine(seed, p.coeff()); if (i++ > 10) break; @@ -66,110 +69,119 @@ class lar_solver : public column_namer { }; struct term_comparer { - bool operator()(const lar_term &a, const lar_term& b) const - { - return a == b; + bool operator()(const lar_term& a, const lar_term& b) const { + return a == b; } }; - + //////////////////// fields ////////////////////////// - lp_settings m_settings; - lp_status m_status = lp_status::UNKNOWN; - stacked_value m_simplex_strategy; + trail_stack m_trail; + lp_settings m_settings; + lp_status m_status = lp_status::UNKNOWN; + stacked_value m_simplex_strategy; // such can be found at the initialization step: u < l - stacked_value m_crossed_bounds_column; - lar_core_solver m_mpq_lar_core_solver; - int_solver * m_int_solver = nullptr; - bool m_need_register_terms = false; - var_register m_var_register; - var_register m_term_register; - stacked_vector m_columns_to_ul_pairs; - constraint_set m_constraints; + lpvar m_crossed_bounds_column = null_lpvar; + u_dependency* m_crossed_bounds_deps = nullptr; + lar_core_solver m_mpq_lar_core_solver; + int_solver* m_int_solver = nullptr; + bool m_need_register_terms = false; + var_register m_var_register; + svector m_columns; + constraint_set m_constraints; // the set of column indices j such that bounds have changed for j - u_set m_columns_with_changed_bounds; - u_set m_rows_with_changed_bounds; - unsigned_vector m_row_bounds_to_replay; - - u_set m_basic_columns_with_changed_cost; + indexed_uint_set m_columns_with_changed_bounds; + indexed_uint_set m_touched_rows; + unsigned_vector m_row_bounds_to_replay; + u_dependency_manager m_dependencies; + svector m_tmp_dependencies; + + indexed_uint_set m_basic_columns_with_changed_cost; // these are basic columns with the value changed, so the corresponding row in the tableau // does not sum to zero anymore - u_set m_incorrect_columns; - // copy of m_r_solver.inf_set() - unsigned_vector m_inf_index_copy; - stacked_value m_term_count; - vector m_terms; - indexed_vector m_column_buffer; + indexed_uint_set m_incorrect_columns; + // copy of m_r_solver.inf_heap() + unsigned_vector m_inf_index_copy; + vector m_terms; + indexed_vector m_column_buffer; std::unordered_map, term_hasher, term_comparer> - m_normalized_terms_to_columns; - vector m_backup_x; - stacked_vector m_usage_in_terms; + m_normalized_terms_to_columns; + vector m_backup_x; + stacked_vector m_usage_in_terms; // ((x[j], is_int(j))->j) for fixed j, used in equalities propagation // maps values to integral fixed vars - map, default_eq> m_fixed_var_table_int; + map, default_eq> m_fixed_var_table_int; // maps values to non-integral fixed vars - map, default_eq> m_fixed_var_table_real; + map, default_eq> m_fixed_var_table_real; + // the set of fixed variables which are also base variables + indexed_uint_set m_fixed_base_var_set; // end of fields + ////////////////// nested structs ///////////////////////// + struct undo_add_column; + ////////////////// methods //////////////////////////////// - - static bool valid_index(unsigned j) { return static_cast(j) >= 0;} - const lar_term & get_term(unsigned j) const; + + static bool valid_index(unsigned j) { return static_cast(j) >= 0; } bool row_has_a_big_num(unsigned i) const; // init region - bool strategy_is_undecided() const; - void register_new_ext_var_index(unsigned ext_v, bool is_int); - bool term_is_int(const lar_term * t) const; - bool term_is_int(const vector> & coeffs) const; + void register_new_external_var(unsigned ext_v, bool is_int); + bool term_is_int(const lar_term* t) const; + bool term_is_int(const vector>& coeffs) const; void add_non_basic_var_to_core_fields(unsigned ext_j, bool is_int); void add_new_var_to_core_fields_for_mpq(bool register_in_basis); mpq adjust_bound_for_int(lpvar j, lconstraint_kind&, const mpq&); // terms - bool all_vars_are_registered(const vector> & coeffs); - var_index add_term_undecided(const vector> & coeffs); - bool term_coeffs_are_ok(const vector> & coeffs); - void push_term(lar_term* t); - void add_row_from_term_no_constraint(const lar_term * term, unsigned term_ext_index); + bool all_vars_are_registered(const vector>& coeffs); + bool term_coeffs_are_ok(const vector>& coeffs); + void add_row_from_term_no_constraint(lar_term* term, unsigned term_ext_index); void add_basic_var_to_core_fields(); - bool compare_values(impq const& lhs, lconstraint_kind k, const mpq & rhs); + bool compare_values(impq const& lhs, lconstraint_kind k, const mpq& rhs); - inline void clear_columns_with_changed_bounds() { m_columns_with_changed_bounds.clear(); } - inline void increase_by_one_columns_with_changed_bounds() { m_columns_with_changed_bounds.increase_size_by_one(); } - inline void insert_to_columns_with_changed_bounds(unsigned j) { m_columns_with_changed_bounds.insert(j); } - - void update_column_type_and_bound_check_on_equal(unsigned j, lconstraint_kind kind, const mpq & right_side, constraint_index constr_index, unsigned&); - void update_column_type_and_bound(unsigned j, lconstraint_kind kind, const mpq & right_side, constraint_index constr_index); - void update_column_type_and_bound_with_ub(var_index j, lconstraint_kind kind, const mpq & right_side, constraint_index constr_index); - void update_column_type_and_bound_with_no_ub(var_index j, lconstraint_kind kind, const mpq & right_side, constraint_index constr_index); - void update_bound_with_ub_lb(var_index j, lconstraint_kind kind, const mpq & right_side, constraint_index constr_index); - void update_bound_with_no_ub_lb(var_index j, lconstraint_kind kind, const mpq & right_side, constraint_index constr_index); - void update_bound_with_ub_no_lb(var_index j, lconstraint_kind kind, const mpq & right_side, constraint_index constr_index); - void update_bound_with_no_ub_no_lb(var_index j, lconstraint_kind kind, const mpq & right_side, constraint_index constr_index); + inline void clear_columns_with_changed_bounds() { m_columns_with_changed_bounds.reset(); } + public: + const auto& columns_with_changed_bounds() const { return m_columns_with_changed_bounds; } + void insert_to_columns_with_changed_bounds(unsigned j); + const u_dependency* crossed_bounds_deps() const { return m_crossed_bounds_deps;} + u_dependency*& crossed_bounds_deps() { return m_crossed_bounds_deps;} + + lpvar crossed_bounds_column() const { return m_crossed_bounds_column; } + lpvar& crossed_bounds_column() { return m_crossed_bounds_column; } + + + private: + bool validate_bound(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep); + void add_dep_constraints_to_solver(lar_solver& ls, u_dependency* dep); + void add_bound_negation_to_solver(lar_solver& ls, lpvar j, lconstraint_kind kind, const mpq& right_side); + void add_constraint_to_validate(lar_solver& ls, constraint_index ci); + bool m_validate_blocker = false; + void update_column_type_and_bound_check_on_equal(unsigned j, const mpq& right_side, constraint_index ci, unsigned&); + void update_column_type_and_bound(unsigned j, const mpq& right_side, constraint_index ci); + public: + bool validate_blocker() const { return m_validate_blocker; } + bool & validate_blocker() { return m_validate_blocker; } + void update_column_type_and_bound(unsigned j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep); + private: + void require_nbasis_sort() { m_mpq_lar_core_solver.m_r_solver.m_nbasis_sort_counter = 0; } + void update_column_type_and_bound_with_ub(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep); + void update_column_type_and_bound_with_no_ub(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep); + void update_bound_with_ub_lb(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep); + void update_bound_with_no_ub_lb(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep); + void update_bound_with_ub_no_lb(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep); + void update_bound_with_no_ub_no_lb(lpvar j, lconstraint_kind kind, const mpq& right_side, u_dependency* dep); void register_in_fixed_var_table(unsigned, unsigned&); void remove_non_fixed_from_fixed_var_table(); - constraint_index add_var_bound_on_constraint_for_term(var_index j, lconstraint_kind kind, const mpq & right_side); - inline void set_infeasible_column(unsigned j) { - set_status(lp_status::INFEASIBLE); - m_crossed_bounds_column = j; - } - constraint_index add_constraint_from_term_and_create_new_column_row(unsigned term_j, const lar_term* term, - lconstraint_kind kind, const mpq & right_side); + constraint_index add_var_bound_on_constraint_for_term(lpvar j, lconstraint_kind kind, const mpq& right_side); + void set_crossed_bounds_column_and_deps(unsigned j, bool lower_bound, u_dependency* dep); unsigned row_of_basic_column(unsigned) const; - void decide_on_strategy_and_adjust_initial_state(); - void adjust_initial_state(); - void adjust_initial_state_for_tableau_rows(); bool sizes_are_correct() const; - bool implied_bound_is_correctly_explained(implied_bound const & be, const vector> & explanation) const; - + bool implied_bound_is_correctly_explained(implied_bound const& be, const vector>& explanation) const; - void substitute_basis_var_in_terms_for_row(unsigned i); - template - unsigned calculate_implied_bounds_for_row(unsigned row_index, lp_bound_propagator & bp) { - - if (A_r().m_rows[row_index].size() > settings().max_row_length_for_bound_propagation || row_has_a_big_num(row_index)) + unsigned calculate_implied_bounds_for_row(unsigned row_index, lp_bound_propagator& bp) { + if (A_r().m_rows[row_index].size() > settings().max_row_length_for_bound_propagation || row_has_a_big_num(row_index)) return 0; - + return bound_analyzer_on_row, lp_bound_propagator>::analyze_row( A_r().m_rows[row_index], null_ci, @@ -177,51 +189,51 @@ class lar_solver : public column_namer { row_index, bp); } - - static void clean_popped_elements(unsigned n, u_set& set); - bool maximize_term_on_tableau(const lar_term & term, - impq &term_max); + + static void clean_popped_elements_for_heap(unsigned n, lpvar_heap& set); + static void clean_popped_elements(unsigned n, indexed_uint_set& set); + bool maximize_term_on_tableau(const lar_term& term, impq& term_max); bool costs_are_zeros_for_r_solver() const; bool reduced_costs_are_zeroes_for_r_solver() const; - void set_costs_to_zero(const lar_term & term); - void prepare_costs_for_r_solver(const lar_term & term); - bool maximize_term_on_corrected_r_solver(lar_term & term, impq &term_max); + void set_costs_to_zero(const lar_term& term); + void prepare_costs_for_r_solver(const lar_term& term); + bool maximize_term_on_feasible_r_solver(lar_term& term, impq& term_max, vector>* max_coeffs); + u_dependency* get_dependencies_of_maximum(const vector>& max_coeffs); + void pop_core_solver_params(); void pop_core_solver_params(unsigned k); - void set_upper_bound_witness(var_index j, constraint_index ci); - void set_lower_bound_witness(var_index j, constraint_index ci); - void substitute_terms_in_linear_expression( const vector>& left_side_with_terms, - vector> &left_side) const; - - void detect_rows_of_bound_change_column_for_nbasic_column_tableau(unsigned j); + void set_upper_bound_witness(lpvar j, u_dependency* ci); + void set_lower_bound_witness(lpvar j, u_dependency* ci); + void substitute_terms_in_linear_expression(const vector>& left_side_with_terms, + vector>& left_side) const; + bool use_tableau_costs() const; bool tableau_with_costs() const; bool costs_are_used() const; - void change_basic_columns_dependend_on_a_given_nb_column(unsigned j, const numeric_pair & delta); + void change_basic_columns_dependend_on_a_given_nb_column(unsigned j, const numeric_pair& delta); void update_x_and_inf_costs_for_column_with_changed_bounds(unsigned j); - unsigned num_changed_bounds() const { return m_rows_with_changed_bounds.size(); } - void insert_row_with_changed_bounds(unsigned rid); + void add_touched_row(unsigned rid); void detect_rows_with_changed_bounds_for_column(unsigned j); void detect_rows_with_changed_bounds(); void update_x_and_inf_costs_for_columns_with_changed_bounds_tableau(); void solve_with_core_solver(); numeric_pair get_basic_var_value_from_row(unsigned i); - bool all_constrained_variables_are_registered(const vector>& left_side); + bool all_constrained_variables_are_registered(const vector>& left_side); bool all_constraints_hold() const; - bool constraint_holds(const lar_base_constraint & constr, std::unordered_map & var_map) const; - static void register_in_map(std::unordered_map & coeffs, const lar_base_constraint & cn, const mpq & a); - static void register_monoid_in_map(std::unordered_map & coeffs, const mpq & a, unsigned j); - bool the_left_sides_sum_to_zero(const vector> & evidence) const; + bool constraint_holds(const lar_base_constraint& constr, std::unordered_map& var_map) const; + static void register_in_map(std::unordered_map& coeffs, const lar_base_constraint& cn, const mpq& a); + static void register_monoid_in_map(std::unordered_map& coeffs, const mpq& a, unsigned j); + bool the_left_sides_sum_to_zero(const vector>& evidence) const; bool explanation_is_correct(explanation&) const; bool inf_explanation_is_correct() const; - mpq sum_of_right_sides_of_explanation(explanation &) const; + mpq sum_of_right_sides_of_explanation(explanation&) const; void get_infeasibility_explanation_for_inf_sign( - explanation & exp, - const vector> & inf_row, + explanation& exp, + const vector>& inf_row, int inf_sign) const; - mpq get_left_side_val(const lar_base_constraint & cns, const std::unordered_map & var_map) const; - void fill_var_set_for_random_update(unsigned sz, var_index const * vars, vector& column_list); + mpq get_left_side_val(const lar_base_constraint& cns, const std::unordered_map& var_map) const; + void fill_var_set_for_random_update(unsigned sz, lpvar const* vars, vector& column_list); bool column_represents_row_in_tableau(unsigned j); void make_sure_that_the_bottom_right_elem_not_zero_in_tableau(unsigned i, unsigned j); void remove_last_row_and_column_from_tableau(unsigned j); @@ -229,26 +241,30 @@ class lar_solver : public column_namer { void remove_last_column_from_basis_tableau(unsigned j); void remove_last_column_from_tableau(); - void pop_tableau(); - void clean_inf_set_of_r_solver_after_pop(); + void clean_inf_heap_of_r_solver_after_pop(); inline bool column_value_is_integer(unsigned j) const { return get_column_value(j).is_int(); } bool model_is_int_feasible() const; - - bool bound_is_integer_for_integer_column(unsigned j, const mpq & right_side) const; - inline lar_core_solver & get_core_solver() { return m_mpq_lar_core_solver; } - var_index to_column(unsigned ext_j) const; + + bool bound_is_integer_for_integer_column(unsigned j, const mpq& right_side) const; + inline lar_core_solver& get_core_solver() { return m_mpq_lar_core_solver; } + lpvar to_column(unsigned ext_j) const; void fix_terms_with_rounded_columns(); bool remove_from_basis(unsigned); lar_term get_term_to_maximize(unsigned ext_j) const; - bool sum_first_coords(const lar_term& t, mpq & val) const; + bool sum_first_coords(const lar_term& t, mpq& val) const; void register_normalized_term(const lar_term&, lpvar); void deregister_normalized_term(const lar_term&); mutable std::unordered_set m_set_of_different_pairs; - mutable std::unordered_set m_set_of_different_singles; + mutable std::unordered_set m_set_of_different_singles; mutable mpq m_delta; public: + u_dependency* find_improved_bound(lpvar j, bool is_lower, mpq& bound); + + std::ostream& print_explanation( + std::ostream& out, const explanation& exp, + std::function var_str = [](lpvar j) { return std::string("j") + T_to_string(j); }) const; // this function just looks at the status bool is_feasible() const; @@ -256,7 +272,6 @@ class lar_solver : public column_namer { return m_fixed_var_table_int; } - const map, default_eq>& fixed_var_table_real() const { return m_fixed_var_table_real; } @@ -265,13 +280,12 @@ class lar_solver : public column_namer { return m_fixed_var_table_real; } - bool find_in_fixed_tables(const rational& mpq, bool is_int, unsigned & j) const { - return is_int? fixed_var_table_int().find(mpq, j) : fixed_var_table_real().find(mpq, j); + bool find_in_fixed_tables(const rational& mpq, bool is_int, unsigned& j) const { + return is_int ? fixed_var_table_int().find(mpq, j) : fixed_var_table_real().find(mpq, j); } - - template void remove_non_fixed_from_table(T&); - unsigned external_to_column_index(unsigned) const; + template + void remove_non_fixed_from_table(T&); bool inside_bounds(lpvar, const impq&) const; @@ -282,157 +296,161 @@ class lar_solver : public column_namer { inline void set_column_value_test(unsigned j, const impq& v) { set_column_value(j, v); } - - var_index add_named_var(unsigned ext_j, bool is_integer, const std::string&); - lp_status maximize_term(unsigned j_or_term, impq &term_max); + lpvar add_named_var(unsigned ext_j, bool is_integer, const std::string&); + + lp_status maximize_term(unsigned j_or_term, impq& term_max); - inline core_solver_pretty_printer pp(std::ostream& out) const { - return core_solver_pretty_printer(m_mpq_lar_core_solver.m_r_solver, out); + inline core_solver_pretty_printer pp(std::ostream& out) const { + return core_solver_pretty_printer(m_mpq_lar_core_solver.m_r_solver, out); } - void get_infeasibility_explanation(explanation &) const; + void get_infeasibility_explanation(explanation&) const; inline void backup_x() { m_backup_x = m_mpq_lar_core_solver.m_r_x; } inline void restore_x() { m_mpq_lar_core_solver.m_r_x = m_backup_x; } template - void explain_implied_bound(const implied_bound & ib, lp_bound_propagator & bp) { - unsigned i = ib.m_row_or_term_index; - int bound_sign = (ib.m_is_lower_bound ? 1 : -1); - int j_sign = (ib.m_coeff_before_j_is_pos ? 1 : -1) * bound_sign; - unsigned bound_j = ib.m_j; - if (tv::is_term(bound_j)) - bound_j = m_var_register.external_to_local(bound_j); - - for (auto const& r : get_row(i)) { - unsigned j = r.var(); - if (j == bound_j) - continue; - mpq const& a = r.coeff(); - int a_sign = is_pos(a) ? 1 : -1; - int sign = j_sign * a_sign; - const ul_pair & ul = m_columns_to_ul_pairs[j]; - auto witness = sign > 0 ? ul.upper_bound_witness() : ul.lower_bound_witness(); - lp_assert(is_valid(witness)); - bp.consume(a, witness); - } + void explain_implied_bound(const implied_bound& ib, lp_bound_propagator& bp) { + u_dependency* dep = ib.explain_implied(); + for (auto ci : flatten(dep)) + bp.consume(mpq(1), ci); // TODO: flatten should provide the coefficients + /* + if (ib.m_is_monic) { + NOT_IMPLEMENTED_YET(); + } else { + unsigned i = ib.m_row_or_term_index; + int bound_sign = (ib.m_is_lower_bound ? 1 : -1); + int j_sign = (ib.m_coeff_before_j_is_pos ? 1 : -1) * bound_sign; + unsigned bound_j = ib.m_j; + if (tv::is_term(bound_j)) + bound_j = m_var_register.external_to_local(bound_j); + + for (auto const& r : get_row(i)) { + unsigned j = r.var(); + if (j == bound_j) + continue; + mpq const& a = r.coeff(); + int a_sign = is_pos(a) ? 1 : -1; + int sign = j_sign * a_sign; + const column& ul = m_columns[j]; + auto* witness = sign > 0 ? ul.upper_bound_witness() : ul.lower_bound_witness(); + lp_assert(witness); + for (auto ci : flatten(witness)) + bp.consume(a, ci); + } + }*/ } void set_value_for_nbasic_column(unsigned j, const impq& new_val); + void remove_fixed_vars_from_base(); + inline unsigned get_base_column_in_row(unsigned row_index) const { return m_mpq_lar_core_solver.m_r_solver.get_base_column_in_row(row_index); } - - // lp_assert(implied_bound_is_correctly_explained(ib, explanation)); } - constraint_index mk_var_bound(var_index j, lconstraint_kind kind, const mpq & right_side); - void activate_check_on_equal(constraint_index, var_index&); +#ifdef Z3DEBUG + bool fixed_base_removed_correctly() const; +#endif + constraint_index mk_var_bound(lpvar j, lconstraint_kind kind, const mpq& right_side); + void activate_check_on_equal(constraint_index, lpvar&); void activate(constraint_index); - void random_update(unsigned sz, var_index const * vars); - void mark_rows_for_bound_prop(lpvar j); + void random_update(unsigned sz, lpvar const* vars); + void add_column_rows_to_touched_rows(lpvar j); template - void propagate_bounds_for_touched_rows(lp_bound_propagator & bp) { - unsigned num_prop = 0; - for (unsigned i : m_rows_with_changed_bounds) { - num_prop += calculate_implied_bounds_for_row(i, bp); - if (settings().get_cancel_flag()) - return; - } - // these two loops should be run sequentially - // since the first loop might change column bounds - // and add fixed columns this way + void propagate_bounds_for_touched_rows(lp_bound_propagator& bp) { if (settings().propagate_eqs()) { + if (settings().random_next() % 10 == 0) + remove_fixed_vars_from_base(); bp.clear_for_eq(); - for (unsigned i : m_rows_with_changed_bounds) { + for (unsigned i : m_touched_rows) { unsigned offset_eqs = stats().m_offset_eqs; - bp.cheap_eq_tree(i); + bp.cheap_eq_on_nbase(i); if (settings().get_cancel_flag()) return; if (stats().m_offset_eqs > offset_eqs) m_row_bounds_to_replay.push_back(i); } } - m_rows_with_changed_bounds.clear(); + for (unsigned i : m_touched_rows) { + calculate_implied_bounds_for_row(i, bp); + if (settings().get_cancel_flag()) + return; + } + m_touched_rows.reset(); } - + void collect_more_rows_for_lp_propagation(); template - void check_missed_propagations(lp_bound_propagator & bp) { - for (unsigned i = 0; i < A_r().row_count(); i++) - if (!m_rows_with_changed_bounds.contains(i)) + void check_missed_propagations(lp_bound_propagator& bp) { + for (unsigned i = 0; i < A_r().row_count(); i++) + if (!m_touched_rows.contains(i)) if (0 < calculate_implied_bounds_for_row(i, bp)) { verbose_stream() << i << ": " << get_row(i) << "\n"; } } - bool is_fixed_at_bound(column_index const& j); - bool has_fixed_at_bound(); - - bool is_fixed(column_index const& j) const { return column_is_fixed(j); } - inline column_index to_column_index(unsigned v) const { return column_index(external_to_column_index(v)); } bool external_is_used(unsigned) const; void pop(unsigned k); - bool compare_values(var_index j, lconstraint_kind kind, const mpq & right_side); - var_index add_term(const vector> & coeffs, unsigned ext_i); + unsigned num_scopes() const { return m_trail.get_num_scopes(); } + bool compare_values(lpvar j, lconstraint_kind kind, const mpq& right_side); + lpvar add_term(const vector>& coeffs, unsigned ext_i); void register_existing_terms(); - constraint_index add_var_bound(var_index, lconstraint_kind, const mpq &); - constraint_index add_var_bound_check_on_equal(var_index, lconstraint_kind, const mpq &, var_index&); - - var_index add_var(unsigned ext_j, bool is_integer); + constraint_index add_var_bound(lpvar, lconstraint_kind, const mpq&); + constraint_index add_var_bound_check_on_equal(lpvar, lconstraint_kind, const mpq&, lpvar&); + + lpvar add_var(unsigned ext_j, bool is_integer); void set_cut_strategy(unsigned cut_frequency); inline unsigned column_count() const { return A_r().column_count(); } - inline var_index local_to_external(var_index idx) const { - return tv::is_term(idx)? - m_term_register.local_to_external(idx) : m_var_register.local_to_external(idx); + inline lpvar local_to_external(lpvar idx) const { + return m_var_register.local_to_external(idx); } - bool column_corresponds_to_term(unsigned) const; + inline bool column_associated_with_row(lpvar j) const { return m_columns[j].associated_with_row(); } inline unsigned row_count() const { return A_r().row_count(); } - bool var_is_registered(var_index vj) const; - void clear_inf_set() { - m_mpq_lar_core_solver.m_r_solver.inf_set().clear(); + bool var_is_registered(lpvar vj) const; + void clear_inf_heap() { + m_mpq_lar_core_solver.m_r_solver.inf_heap().clear(); } - inline void remove_column_from_inf_set(unsigned j) { - m_mpq_lar_core_solver.m_r_solver.remove_column_from_inf_set(j); + + void pivot(int entering, int leaving) { + m_mpq_lar_core_solver.pivot(entering, leaving); } + template void change_basic_columns_dependend_on_a_given_nb_column_report(unsigned j, - const numeric_pair & delta, + const numeric_pair& delta, const ChangeReport& after) { - - for (const auto & c : A_r().m_columns[j]) { - unsigned bj = m_mpq_lar_core_solver.m_r_basis[c.var()]; - if (tableau_with_costs()) { - m_basic_columns_with_changed_cost.insert(bj); - } - m_mpq_lar_core_solver.m_r_solver.add_delta_to_x_and_track_feasibility(bj, - A_r().get_val(c) * delta); - after(bj); - TRACE("change_x_del", - tout << "changed basis column " << bj << ", it is " << - ( m_mpq_lar_core_solver.m_r_solver.column_is_feasible(bj)? "feas":"inf") << std::endl;); - } - } + for (const auto& c : A_r().m_columns[j]) { + unsigned bj = m_mpq_lar_core_solver.m_r_basis[c.var()]; + if (tableau_with_costs()) + m_basic_columns_with_changed_cost.insert(bj); + m_mpq_lar_core_solver.m_r_solver.add_delta_to_x_and_track_feasibility(bj, -A_r().get_val(c) * delta); + after(bj); + TRACE("change_x_del", + tout << "changed basis column " << bj << ", it is " << (m_mpq_lar_core_solver.m_r_solver.column_is_feasible(bj) ? "feas" : "inf") << std::endl;); + } + } template void set_value_for_nbasic_column_report(unsigned j, - const impq & new_val, + const impq& new_val, const ChangeReport& after) { - lp_assert(!is_base(j)); - auto & x = m_mpq_lar_core_solver.m_r_x[j]; + auto& x = m_mpq_lar_core_solver.m_r_x[j]; auto delta = new_val - x; x = new_val; after(j); change_basic_columns_dependend_on_a_given_nb_column_report(j, delta, after); } - + template bool try_to_patch(lpvar j, const mpq& val, const Blocker& is_blocked, const ChangeReport& change_report) { - if (is_base(j)) { + if (is_base(j)) { TRACE("nla_solver", get_int_solver()->display_row_info(tout, row_of_basic_column(j)) << "\n";); - remove_from_basis(j); + if (!remove_from_basis(j)) + return false; } impq ival(val); @@ -442,8 +460,8 @@ class lar_solver : public column_namer { impq delta = get_column_value(j) - ival; for (auto c : A_r().column(j)) { unsigned row_index = c.var(); - const mpq & a = c.coeff(); - unsigned rj = m_mpq_lar_core_solver.m_r_basis[row_index]; + const mpq& a = c.coeff(); + unsigned rj = m_mpq_lar_core_solver.m_r_basis[row_index]; impq rj_new_val = a * delta + get_column_value(rj); // if (column_is_int(rj) && !rj_new_val.is_int()) // return false; @@ -460,65 +478,76 @@ class lar_solver : public column_namer { inline bool column_has_lower_bound(unsigned j) const { return m_mpq_lar_core_solver.m_r_solver.column_has_lower_bound(j); - } + } - inline - constraint_index get_column_upper_bound_witness(unsigned j) const { - if (tv::is_term(j)) { - j = m_var_register.external_to_local(j); - } - return m_columns_to_ul_pairs()[j].upper_bound_witness(); + svector const& flatten(u_dependency* d) { + m_tmp_dependencies.reset(); + m_dependencies.linearize(d, m_tmp_dependencies); + return m_tmp_dependencies; } - inline - const impq& get_upper_bound(column_index j) const { + void push_explanation(u_dependency* d, explanation& ex) { + for (auto ci : flatten(d)) + ex.push_back(ci); + } + + u_dependency_manager& dep_manager() { return m_dependencies; } + + inline u_dependency* get_column_upper_bound_witness(unsigned j) const { + return m_columns[j].upper_bound_witness(); + } + + inline const impq& get_upper_bound(lpvar j) const { return m_mpq_lar_core_solver.m_r_solver.m_upper_bounds[j]; } - inline - const impq& get_lower_bound(column_index j) const { + inline const impq& get_lower_bound(lpvar j) const { return m_mpq_lar_core_solver.m_r_solver.m_lower_bounds[j]; } - bool has_lower_bound(var_index var, constraint_index& ci, mpq& value, bool& is_strict) const; - bool has_upper_bound(var_index var, constraint_index& ci, mpq& value, bool& is_strict) const; - bool has_value(var_index var, mpq& value) const; - bool fetch_normalized_term_column(const lar_term& t, std::pair& ) const; - unsigned map_term_index_to_column_index(unsigned j) const; + + inline mpq bound_span_x(lpvar j) const { + return m_mpq_lar_core_solver.m_r_solver.m_upper_bounds[j].x - m_mpq_lar_core_solver.m_r_solver.m_lower_bounds[j].x; + } + + bool has_lower_bound(lpvar var, u_dependency*& ci, mpq& value, bool& is_strict) const; + bool has_upper_bound(lpvar var, u_dependency*& ci, mpq& value, bool& is_strict) const; + bool has_value(lpvar var, mpq& value) const; + bool fetch_normalized_term_column(const lar_term& t, std::pair&) const; bool column_is_fixed(unsigned j) const; bool column_is_free(unsigned j) const; - unsigned column_to_reported_index(unsigned j) const; - lp_settings & settings(); - lp_settings const & settings() const; + bool column_is_feasible(unsigned j) const { return m_mpq_lar_core_solver.m_r_solver.column_is_feasible(j);} + lp_settings& settings(); + lp_settings const& settings() const; statistics& stats(); - + void updt_params(params_ref const& p); column_type get_column_type(unsigned j) const { return m_mpq_lar_core_solver.m_column_types()[j]; } - const impq & get_lower_bound(unsigned j) const { return m_mpq_lar_core_solver.m_r_lower_bounds()[j]; } - const impq & get_upper_bound(unsigned j) const { return m_mpq_lar_core_solver.m_r_upper_bounds()[j]; } + const vector& get_column_types() const { return m_mpq_lar_core_solver.m_column_types(); } std::ostream& print_terms(std::ostream& out) const; - std::ostream& print_term(lar_term const& term, std::ostream & out) const; - static std::ostream& print_term_as_indices(lar_term const& term, std::ostream & out); - std::ostream& print_constraint_indices_only(const lar_base_constraint * c, std::ostream & out) const; - std::ostream& print_implied_bound(const implied_bound& be, std::ostream & out) const; + std::ostream& print_term(lar_term const& term, std::ostream& out) const; + static std::ostream& print_term_as_indices(lar_term const& term, std::ostream& out); + std::ostream& print_constraint_indices_only(const lar_base_constraint* c, std::ostream& out) const; + std::ostream& print_implied_bound(const implied_bound& be, std::ostream& out) const; std::ostream& print_values(std::ostream& out) const; std::ostream& display(std::ostream& out) const; - + std::ostream& display_constraint(std::ostream& out, constraint_index ci) const { + return m_constraints.display(out, ci); + } bool init_model() const; - mpq get_value(column_index const& j) const; - mpq get_tv_value(tv const& t) const; - impq get_tv_ivalue(tv const& t) const; - void get_model(std::unordered_map & variable_values) const; + mpq from_model_in_impq_to_mpq(const impq& v) const { return v.x + m_delta * v.y; } + mpq get_value(lpvar j) const; + void get_model(std::unordered_map& variable_values) const; void get_rid_of_inf_eps(); - void get_model_do_not_care_about_diff_vars(std::unordered_map & variable_values) const; - std::string get_variable_name(var_index vi) const override; - void set_variable_name(var_index vi, std::string); + void get_model_do_not_care_about_diff_vars(std::unordered_map& variable_values) const; + std::string get_variable_name(lpvar vi) const override; + void set_variable_name(lpvar vi, std::string); inline unsigned number_of_vars() const { return m_var_register.size(); } inline bool is_base(unsigned j) const { return m_mpq_lar_core_solver.m_r_heading[j] >= 0; } - inline const impq & column_lower_bound(unsigned j) const { + inline const impq& column_lower_bound(unsigned j) const { return m_mpq_lar_core_solver.lower_bound(j); } - inline const impq & column_upper_bound(unsigned j) const { + inline const impq& column_upper_bound(unsigned j) const { return m_mpq_lar_core_solver.upper_bound(j); } @@ -531,38 +560,41 @@ class lar_solver : public column_namer { } std::pair add_equality(lpvar j, lpvar k); - - inline void get_bound_constraint_witnesses_for_column(unsigned j, constraint_index & lc, constraint_index & uc) const { - const ul_pair & ul = m_columns_to_ul_pairs[j]; - lc = ul.lower_bound_witness(); - uc = ul.upper_bound_witness(); + + u_dependency* get_bound_constraint_witnesses_for_column(unsigned j) { + const column& ul = m_columns[j]; + return m_dependencies.mk_join(ul.lower_bound_witness(), ul.upper_bound_witness()); } + template + u_dependency* get_bound_constraint_witnesses_for_columns(const T& collection) { + u_dependency* dep = nullptr; + for (auto j : collection) { + u_dependency* d = get_bound_constraint_witnesses_for_column(j); + dep = m_dependencies.mk_join(dep, d); + } + return dep; + } + u_dependency* join_deps(u_dependency* a, u_dependency *b) { return m_dependencies.mk_join(a, b); } inline constraint_set const& constraints() const { return m_constraints; } void push(); void pop(); - inline constraint_index get_column_lower_bound_witness(unsigned j) const { - if (tv::is_term(j)) { - j = m_var_register.external_to_local(j); - } - return m_columns_to_ul_pairs()[j].lower_bound_witness(); - } - inline tv column2tv(column_index const& c) const { - return tv::raw(column_to_reported_index(c)); + inline u_dependency* get_column_lower_bound_witness(unsigned j) const { + return m_columns[j].lower_bound_witness(); } - + inline bool column_has_term(lpvar j) const { return m_columns[j].term() != nullptr; } inline std::ostream& print_column_info(unsigned j, std::ostream& out) const { m_mpq_lar_core_solver.m_r_solver.print_column_info(j, out); - if (tv::is_term(j)) { + if (column_has_term(j)) { print_term_as_indices(get_term(j), out) << "\n"; - - } else if (column_corresponds_to_term(j)) { + + } else if (column_has_term(j)) { const lar_term& t = get_term(m_var_register.local_to_external(j)); print_term_as_indices(t, out) << "\n"; } return out; } - + void subst_known_terms(lar_term*); inline std::ostream& print_column_bound_info(unsigned j, std::ostream& out) const { @@ -573,77 +605,74 @@ class lar_solver : public column_namer { inline bool has_inf_int() const { for (unsigned j = 0; j < column_count(); j++) { - if (column_is_int(j) && ! column_value_is_int(j)) + if (column_is_int(j) && !column_value_is_int(j)) return true; } return false; } - inline const vector & terms() const { return m_terms; } - inline lar_term const& term(unsigned i) const { return *m_terms[i]; } - inline void set_int_solver(int_solver * int_slv) { m_int_solver = int_slv; } - inline int_solver * get_int_solver() { return m_int_solver; } - inline const int_solver * get_int_solver() const { return m_int_solver; } - inline const lar_term & get_term(tv const& t) const { lp_assert(t.is_term()); return *m_terms[t.id()]; } - lp_status find_feasible_solution(); - void move_non_basic_columns_to_bounds(bool); - bool move_non_basic_column_to_bounds(unsigned j, bool); + inline const vector& terms() const { return m_terms; } + + inline void set_int_solver(int_solver* int_slv) { m_int_solver = int_slv; } + inline int_solver* get_int_solver() { return m_int_solver; } + inline const int_solver* get_int_solver() const { return m_int_solver; } + inline const lar_term& get_term(lpvar j) const { + return *m_columns[j].term(); + } + lp_status find_feasible_solution(); + void move_non_basic_columns_to_bounds(); + bool move_non_basic_column_to_bounds(unsigned j); inline bool r_basis_has_inf_int() const { for (unsigned j : r_basis()) { - if (column_is_int(j) && ! column_value_is_int(j)) + if (column_is_int(j) && !column_value_is_int(j)) return true; } return false; } void round_to_integer_solution(); - inline const row_strip & get_row(unsigned i) const { return A_r().m_rows[i]; } - inline const row_strip & basic2row(unsigned i) const { return A_r().m_rows[row_of_basic_column(i)]; } - inline const column_strip & get_column(unsigned i) const { return A_r().m_columns[i]; } + inline const row_strip& get_row(unsigned i) const { return A_r().m_rows[i]; } + inline const row_strip& basic2row(unsigned i) const { return A_r().m_rows[row_of_basic_column(i)]; } + inline const column_strip& get_column(unsigned i) const { return A_r().m_columns[i]; } bool row_is_correct(unsigned i) const; bool ax_is_correct() const; - bool get_equality_and_right_side_for_term_on_current_x(tv const& t, mpq &rs, constraint_index& ci, bool &upper_bound) const; - bool var_is_int(var_index v) const; - inline const vector & r_heading() const { return m_mpq_lar_core_solver.m_r_heading; } - inline const vector & r_basis() const { return m_mpq_lar_core_solver.r_basis(); } - inline const vector & r_nbasis() const { return m_mpq_lar_core_solver.r_nbasis(); } - inline bool column_is_real(unsigned j) const { return !column_is_int(j); } + bool get_equality_and_right_side_for_term_on_current_x(lpvar j, mpq& rs, u_dependency*& ci, bool& upper_bound) const; + bool var_is_int(lpvar v) const; + inline const vector& r_heading() const { return m_mpq_lar_core_solver.m_r_heading; } + inline const vector& r_basis() const { return m_mpq_lar_core_solver.r_basis(); } + inline const vector& r_nbasis() const { return m_mpq_lar_core_solver.r_nbasis(); } + inline bool column_is_real(unsigned j) const { return !column_is_int(j); } lp_status get_status() const; - bool has_changed_columns() const { return !m_columns_with_changed_bounds.empty(); } + bool has_changed_columns() const { return !m_columns_with_changed_bounds.empty(); } void set_status(lp_status s); lp_status solve(); - void fill_explanation_from_crossed_bounds_column(explanation & evidence) const; + void fill_explanation_from_crossed_bounds_column(explanation& evidence) const; bool term_is_used_as_row(unsigned term) const; - bool tighten_term_bounds_by_delta(tv const& t, const impq&); + bool tighten_term_bounds_by_delta(lpvar j, const impq&); lar_solver(); - void set_track_pivoted_rows(bool v); - bool get_track_pivoted_rows() const; + void track_touched_rows(bool v); + bool touched_rows_are_tracked() const; ~lar_solver() override; const vector& r_x() const { return m_mpq_lar_core_solver.m_r_x; } bool column_is_int(unsigned j) const; inline bool column_value_is_int(unsigned j) const { return m_mpq_lar_core_solver.m_r_x[j].is_int(); } - inline static_matrix & A_r() { return m_mpq_lar_core_solver.m_r_A; } - inline const static_matrix & A_r() const { return m_mpq_lar_core_solver.m_r_A; } + inline static_matrix& A_r() { return m_mpq_lar_core_solver.m_r_A; } + inline const static_matrix& A_r() const { return m_mpq_lar_core_solver.m_r_A; } // columns - bool column_is_int(column_index const& j) const { return column_is_int((unsigned)j); } -// const impq& get_ivalue(column_index const& j) const { return get_column_value(j); } - const impq& get_column_value(column_index const& j) const { return m_mpq_lar_core_solver.m_r_x[j]; } - inline - var_index external_to_local(unsigned j) const { - var_index local_j; - if (m_var_register.external_is_used(j, local_j) || - m_term_register.external_is_used(j, local_j)) { + const impq& get_column_value(lpvar j) const { return m_mpq_lar_core_solver.m_r_x[j]; } + inline lpvar external_to_local(unsigned j) const { + lpvar local_j; + if (m_var_register.external_is_used(j, local_j)) { return local_j; - } - else { + } else { return -1; } } - unsigned usage_in_terms(column_index j) const { + unsigned usage_in_terms(lpvar j) const { if (j >= m_usage_in_terms.size()) return 0; return m_usage_in_terms[j]; } + std::function m_find_monics_with_changed_bounds_func = nullptr; friend int_solver; friend int_branch; - }; -} +} // namespace lp diff --git a/src/math/lp/lar_term.h b/src/math/lp/lar_term.h index fc73f949f1f..6547377d365 100644 --- a/src/math/lp/lar_term.h +++ b/src/math/lp/lar_term.h @@ -26,8 +26,12 @@ namespace lp { class lar_term { typedef unsigned lpvar; u_map m_coeffs; + // the column index related to the term + lpvar m_j = -1; public: - lar_term() {} + // the column index related to the term + lpvar j() const { return m_j; } + lpvar& j() { return m_j; } void add_monomial(const mpq& c, unsigned j) { if (c.is_zero()) return; @@ -52,8 +56,7 @@ class lar_term { unsigned size() const { return static_cast(m_coeffs.size()); } - template - const T & coeffs() const { + u_map const & coeffs() const { return m_coeffs; } @@ -63,10 +66,11 @@ class lar_term { mpq a = it->get_data().m_value; this->m_coeffs.erase(term_column); for (auto p : t) { - this->add_monomial(a * p.coeff(), p.column()); + this->add_monomial(a * p.coeff(), p.j()); } } - + // constructors + lar_term() {} lar_term(const vector>& coeffs) { for (auto const& p : coeffs) { add_monomial(p.first, p.second); @@ -97,9 +101,8 @@ class lar_term { vector> coeffs_as_vector() const { vector> ret; - for (const auto & p : m_coeffs) { - ret.push_back(std::make_pair(p.m_value, p.m_key)); - } + for (const auto & [c, v] : m_coeffs) + ret.push_back({v, c}); return ret; } @@ -140,17 +143,23 @@ class lar_term { } return ret; } + + lar_term& operator*=(mpq const& k) { + for (auto & t : m_coeffs) + t.m_value *= k; + return *this; + } void clear() { m_coeffs.reset(); } class ival { - unsigned m_var; + lpvar m_var; const mpq & m_coeff; public: - ival(unsigned var, const mpq & val) : m_var(var), m_coeff(val) { } - column_index column() const { return column_index(m_var); } + ival(lpvar var, const mpq & val) : m_var(var), m_coeff(val) { } + lpvar j() const { return m_var; } const mpq & coeff() const { return m_coeff; } }; @@ -169,13 +178,13 @@ class lar_term { lpvar min_var = -1; mpq c; for (ival p : *this) { - if (p.column() < min_var) { - min_var = p.column(); + if (p.j() < min_var) { + min_var = p.j(); } } lar_term r; for (ival p : *this) { - if (p.column() == min_var) { + if (p.j() == min_var) { return p.coeff().is_one(); } } diff --git a/src/math/lp/lia_move.h b/src/math/lp/lia_move.h index ca61d7b7aba..12e3d8e35ea 100644 --- a/src/math/lp/lia_move.h +++ b/src/math/lp/lia_move.h @@ -19,34 +19,38 @@ Revision History: --*/ #pragma once namespace lp { -enum class lia_move { - sat, - branch, - cut, - conflict, - continue_with_check, - undef, - unsat -}; -inline std::string lia_move_to_string(lia_move m) { - switch (m) { - case lia_move::sat: - return "sat"; - case lia_move::branch: - return "branch"; - case lia_move::cut: - return "cut"; - case lia_move::conflict: - return "conflict"; - case lia_move::continue_with_check: - return "continue_with_check"; - case lia_move::undef: - return "undef"; - case lia_move::unsat: - return "unsat"; - default: - UNREACHABLE(); + enum class lia_move { + sat, + branch, + cut, + conflict, + continue_with_check, + undef, + unsat }; - return "strange"; -} + inline std::string lia_move_to_string(lia_move m) { + switch (m) { + case lia_move::sat: + return "sat"; + case lia_move::branch: + return "branch"; + case lia_move::cut: + return "cut"; + case lia_move::conflict: + return "conflict"; + case lia_move::continue_with_check: + return "continue_with_check"; + case lia_move::undef: + return "undef"; + case lia_move::unsat: + return "unsat"; + default: + UNREACHABLE(); + }; + return "strange"; + } + + inline std::ostream& operator<<(std::ostream& out, lia_move const& m) { + return out << lia_move_to_string(m); + } } diff --git a/src/math/lp/lp_api.h b/src/math/lp/lp_api.h index 2a4e5058d4f..325bc980bc8 100644 --- a/src/math/lp/lp_api.h +++ b/src/math/lp/lp_api.h @@ -31,7 +31,7 @@ namespace lp_api { class bound { Literal m_bv; theory_var m_var; - lp::lpvar m_vi; + lp::lpvar m_column_index; bool m_is_int; rational m_value; bound_kind m_bound_kind; @@ -41,7 +41,7 @@ namespace lp_api { bound(Literal bv, theory_var v, lp::lpvar vi, bool is_int, rational const& val, bound_kind k, lp::constraint_index ct, lp::constraint_index cf) : m_bv(bv), m_var(v), - m_vi(vi), + m_column_index(vi), m_is_int(is_int), m_value(val), m_bound_kind(k) { @@ -53,7 +53,7 @@ namespace lp_api { theory_var get_var() const { return m_var; } - lp::tv tv() const { return lp::tv::raw(m_vi); } + lp::lpvar column_index() const { return m_column_index; } Literal get_lit() const { return m_bv; } @@ -108,6 +108,7 @@ namespace lp_api { unsigned m_gomory_cuts; unsigned m_assume_eqs; unsigned m_branch; + unsigned m_bv_axioms; stats() { reset(); } void reset() { memset(this, 0, sizeof(*this)); @@ -128,6 +129,7 @@ namespace lp_api { st.update("arith-gomory-cuts", m_gomory_cuts); st.update("arith-assume-eqs", m_assume_eqs); st.update("arith-branch", m_branch); + st.update("arith-bv-axioms", m_bv_axioms); } }; diff --git a/src/math/lp/lp_bound_propagator.h b/src/math/lp/lp_bound_propagator.h index dba93398ea9..daeac3b18ed 100644 --- a/src/math/lp/lp_bound_propagator.h +++ b/src/math/lp/lp_bound_propagator.h @@ -5,123 +5,34 @@ Lev Nachmanson (levnach) */ #pragma once -#include "math/lp/lp_settings.h" #include +#include "math/lp/lp_settings.h" +#include "util/uint_set.h" +#include "math/lp/implied_bound.h" +#include "util/vector.h" namespace lp { + template class lp_bound_propagator { - class edge; // forward definition - // vertex represents a column - // The set of vertices is organised in a tree. - // The edges of the tree are rows, - // Vertices with m_neg set to false grow with the same rate as the root. - // Vertices with m_neq set to true diminish with the same rate as the roow grows. - // When two vertices with the same m_neg have the same value of columns - // then we have an equality betweet the columns. - class vertex { - unsigned m_column; - vector m_edges; - edge m_edge_from_parent; - unsigned m_level; // the distance in hops to the root; - // it is handy to find the common ancestor - public: - vertex() {} - vertex(unsigned column) : - m_column(column), - m_level(0) - {} - unsigned column() const { return m_column; } - const vertex* parent() const { return m_edge_from_parent.source(); } - vertex* parent() { return m_edge_from_parent.source(); } - unsigned level() const { return m_level; } - void set_edge_from_parent(edge &e) { m_edge_from_parent = e; } - const edge& edge_from_parent() const { return m_edge_from_parent; } - - void add_child(int row, vertex* child) { - SASSERT(*this != *child); - SASSERT(child->parent() == nullptr); - edge e = edge(this, child, row); - m_edges.push_back(e); - child->set_edge_from_parent(e); - child->m_level = m_level + 1; - } - const vector & edges() const { return m_edges; } - bool operator==(const vertex& o) const { - return m_column == o.m_column; - } - bool operator!=(const vertex& o) const { - return m_column != o.m_column; - } - }; - - class edge { - vertex* m_source; - vertex* m_target; - int m_row; - public: - edge(vertex* source, vertex* target, int row) : m_source(source), m_target(target), m_row(row) {} - edge() : m_source(nullptr), m_target(nullptr), m_row(-1) {} - const vertex* source() const { return m_source; } - vertex* source() { return m_source; } - const vertex* target() const { return m_target; } - vertex* target() { return m_target; } - int row() const { return m_row; } - edge reverse() const { return edge(m_target, m_source, m_row); } - }; - - static int other(int x, int y, int z) { SASSERT(x == z || y == z); return x == z ? y : x; } - std::ostream& print_vert(std::ostream & out, const vertex* v) const { - out << "(c = " << v->column() << ", parent = {"; - if (v->parent()) - out << "(" << v->parent()->column() << ")"; - else - out << "null"; - out << "} , lvl = " << v->level(); - if (m_pol.contains(v->column())) - out << (pol(v) == -1? " -":" +"); - else - out << " not in m_pol"; - out << ')'; - return out; - } - - hashtable m_visited_rows; - hashtable m_visited_columns; - u_map m_vertices; - vertex* m_root = nullptr; - // At some point we can find a row with a single vertex non fixed vertex - // then we can fix the whole tree, - // by adjusting the vertices offsets, so they become absolute. - // If the tree is fixed then in addition to checking with the m_vals_to_verts - // we are going to check with the m_fixed_var_tables. - const vertex* m_fixed_vertex = nullptr; - explanation m_fixed_vertex_explanation; - // a pair (o, j) belongs to m_vals_to_verts iff x[j] = x[m_root->column()] + o - map, default_eq> m_vals_to_verts; - // a pair (o, j) belongs to m_vals_to_verts_neg iff -x[j] = x[m_root->column()] + o - map, default_eq> m_vals_to_verts_neg; - // x[m_root->column()] - m_pol[j].pol()*x[j] == const; - // to bind polarity and the vertex in the table - u_map m_pol; - // if m_pos.contains(j) then x[j] = x[m_root->column()] + o - uint_set m_pos; - + uint_set m_visited_rows; // these maps map a column index to the corresponding index in ibounds - std::unordered_map m_improved_lower_bounds; - std::unordered_map m_improved_upper_bounds; - - T& m_imp; - vector m_ibounds; + u_map m_improved_lower_bounds; + u_map m_improved_upper_bounds; + T& m_imp; + std_vector& m_ibounds; + map, default_eq> m_val2fixed_row; + // works for rows of the form x + y + sum of fixed = 0 + map, default_eq> m_row2index_pos; + // works for rows of the form x - y + sum of fixed = 0 + map, default_eq> m_row2index_neg; - map, default_eq> m_val2fixed_row; - - bool is_fixed_row(unsigned r, unsigned & x) { + const vector* m_column_types; + // returns true iff there is only one non-fixed column in the row + bool only_one_nfixed(unsigned r, unsigned& x) { x = UINT_MAX; - const auto & row = lp().get_row(r); - for (unsigned k = 0; k < row.size(); k++) { - const auto& c = row[k]; + for (const auto& c: lp().get_row(r)) { if (column_is_fixed(c.var())) continue; if (x != UINT_MAX) @@ -130,256 +41,149 @@ class lp_bound_propagator { } return x != UINT_MAX; } - +public: + const lar_solver& lp() const { return m_imp.lp(); } + lar_solver& lp() { return m_imp.lp(); } + bool upper_bound_is_available(unsigned j) const { + switch (get_column_type(j)) { + case column_type::fixed: + case column_type::boxed: + case column_type::upper_bound: + return true; + default: + return false; + } + } + bool lower_bound_is_available(unsigned j) const { + switch (get_column_type(j)) { + case column_type::fixed: + case column_type::boxed: + case column_type::lower_bound: + return true; + default: + return false; + } + } +private: void try_add_equation_with_internal_fixed_tables(unsigned r1) { - SASSERT(m_fixed_vertex); unsigned v1, v2; - if (!is_fixed_row(r1, v1)) + if (!only_one_nfixed(r1, v1)) return; unsigned r2 = UINT_MAX; if (!m_val2fixed_row.find(val(v1), r2) || r2 >= lp().row_count()) { m_val2fixed_row.insert(val(v1), r1); return; } - if (!is_fixed_row(r2, v2) || val(v1) != val(v2) || is_int(v1) != is_int(v2)) { + if (!only_one_nfixed(r2, v2) || val(v1) != val(v2) || is_int(v1) != is_int(v2)) { m_val2fixed_row.insert(val(v1), r1); return; } if (v1 == v2) return; - +#if Z3DEBUG + lp_assert(val(v1) == val(v2)); + unsigned debv1, debv2; + lp_assert(only_one_nfixed(r1, debv1) && only_one_nfixed(r2, debv2)); + lp_assert(debv1 == v1 && debv2 == v2); + lp_assert(ival(v1).y == ival(v2).y); +#endif explanation ex; explain_fixed_in_row(r1, ex); explain_fixed_in_row(r2, ex); TRACE("eq", print_row(tout, r1); print_row(tout, r2); tout << v1 << " == " << v2 << " = " << val(v1) << "\n"); add_eq_on_columns(ex, v1, v2, true); } - - void try_add_equation_with_lp_fixed_tables(unsigned row_index, const vertex *v) { - SASSERT(m_fixed_vertex); - unsigned v_j = v->column(); - unsigned j = null_lpvar; - if (!lp().find_in_fixed_tables(val(v_j), is_int(v_j), j)) { - try_add_equation_with_internal_fixed_tables(row_index); - return; - } - - TRACE("cheap_eq", - tout << "v_j = "; lp().print_column_info(v_j, tout) << std::endl; - tout << "v = "; print_vert(tout, v) << std::endl; - tout << "found j " << j << std::endl; lp().print_column_info(j, tout)<< std::endl; - tout << "found j = " << j << std::endl;); - vector path = connect_in_tree(v, m_fixed_vertex); - explanation ex = get_explanation_from_path(path); - ex.add_expl(m_fixed_vertex_explanation); - explain_fixed_column(j, ex); - add_eq_on_columns(ex, j, v_j, true); - } - void try_add_equation_with_val_table(const vertex *v) { - SASSERT(m_fixed_vertex); - unsigned v_j = v->column(); - const vertex *u = nullptr; - if (!m_vals_to_verts.find(val(v_j), u)) { - m_vals_to_verts.insert(val(v_j), v); - return; - } - unsigned j = u->column(); - if (j == v_j || is_int(j) != is_int(v_j)) - return; - - TRACE("cheap_eq", tout << "found j=" << j << " for v="; - print_vert(tout, v) << "\n in m_vals_to_verts\n";); - vector path = connect_in_tree(u, v); - explanation ex = get_explanation_from_path(path); - ex.add_expl(m_fixed_vertex_explanation); - add_eq_on_columns(ex, j, v_j, true); - } - - static bool not_set(unsigned j) { return j == UINT_MAX; } + static bool is_not_set(unsigned j) { return j == UINT_MAX; } static bool is_set(unsigned j) { return j != UINT_MAX; } - - void create_root(unsigned row_index) { - SASSERT(!m_root && !m_fixed_vertex); - unsigned x, y; - int polarity; - TRACE("cheap_eq_det", print_row(tout, row_index);); - if (!is_tree_offset_row(row_index, x, y, polarity)) { - TRACE("cheap_eq_det", tout << "not an offset row\n";); - return; - } - TRACE("cheap_eq", print_row(tout, row_index);); - m_root = alloc_v(x); - set_polarity(m_root, 1); // keep m_root in the positive table - if (not_set(y)) { - set_fixed_vertex(m_root); - explain_fixed_in_row(row_index, m_fixed_vertex_explanation); - } - else { - vertex *v = add_child_with_check(row_index, y, m_root, polarity); - if (v) - explore_under(v); - } - explore_under(m_root); - } - - void explore_under(vertex * v) { - check_for_eq_and_add_to_val_tables(v); - go_over_vertex_column(v); - } - - // In case of only one non fixed column, and the function returns true, - // this column would be represened by x. - bool is_tree_offset_row(unsigned row_index, unsigned & x, unsigned & y, int & polarity) const { - x = y = UINT_MAX; - const row_cell* x_cell = nullptr; - const row_cell* y_cell = nullptr; - const auto & row = lp().get_row(row_index); - for (unsigned k = 0; k < row.size(); k++) { - const auto& c = row[k]; - if (column_is_fixed(c.var())) - continue; - if (not_set(x)) { - if (c.coeff().is_one() || c.coeff().is_minus_one()) { - x = c.var(); - x_cell = & c; - } - else - return false; - } - else if (not_set(y)) { - if (c.coeff().is_one() || c.coeff().is_minus_one()) { - y = c.var(); - y_cell = & c; - } - else - return false; - } - else - return false; - } - if (is_set(x)) { - if (is_set(y)) - polarity = x_cell->coeff().is_pos() == y_cell->coeff().is_pos()? -1 : 1; - else - polarity = 1; - return true; - } - return false; - } - - void go_over_vertex_column(vertex * v) { - lpvar j = v->column(); - if (!check_insert(m_visited_columns, j)) - return; - - for (const auto & c : lp().get_column(j)) { - unsigned row_index = c.var(); - if (!check_insert(m_visited_rows, row_index)) - continue; - vertex* u = get_child_from_row(row_index, v); - if (u) - explore_under(u); - } - } void reset_cheap_eq_eh() { - if (!m_root) - return; - delete_tree(m_root); - m_root = nullptr; - set_fixed_vertex(nullptr); - m_fixed_vertex_explanation.clear(); - m_vals_to_verts.reset(); - m_vals_to_verts_neg.reset(); - m_pol.reset(); - m_vertices.reset(); + m_row2index_pos.reset(); + m_row2index_neg.reset(); } - + struct reset_cheap_eq { lp_bound_propagator& p; - reset_cheap_eq(lp_bound_propagator& p):p(p) {} + reset_cheap_eq(lp_bound_propagator& p) : p(p) {} ~reset_cheap_eq() { p.reset_cheap_eq_eh(); } }; - public: - - lp_bound_propagator(T& imp): - m_imp(imp) {} + lp_bound_propagator(T& imp, std_vector & ibounds) : m_imp(imp), m_ibounds(ibounds) {} + + const std_vector& ibounds() const { return m_ibounds; } - const vector& ibounds() const { return m_ibounds; } - void init() { - m_improved_upper_bounds.clear(); - m_improved_lower_bounds.clear(); - m_ibounds.reset(); + m_improved_upper_bounds.reset(); + m_improved_lower_bounds.reset(); + m_ibounds.clear(); + m_column_types = &lp().get_column_types(); } - - const lar_solver& lp() const { return m_imp.lp(); } - lar_solver& lp() { return m_imp.lp(); } - + column_type get_column_type(unsigned j) const { - return m_imp.lp().get_column_type(j); + return (*m_column_types)[j]; } - - const impq & get_lower_bound(unsigned j) const { - return m_imp.lp().get_lower_bound(j); + + const impq& get_lower_bound(unsigned j) const { + return lp().get_lower_bound(j); } - const mpq & get_lower_bound_rational(unsigned j) const { - return m_imp.lp().get_lower_bound(j).x; + const mpq& get_lower_bound_rational(unsigned j) const { + return lp().get_lower_bound(j).x; } - - const impq & get_upper_bound(unsigned j) const { - return m_imp.lp().get_upper_bound(j); + + const impq& get_upper_bound(unsigned j) const { + return lp().get_upper_bound(j); } - const mpq & get_upper_bound_rational(unsigned j) const { - return m_imp.lp().get_upper_bound(j).x; + const mpq& get_upper_bound_rational(unsigned j) const { + return lp().get_upper_bound(j).x; } // require also the zero infinitesemal part bool column_is_fixed(lpvar j) const { - return lp().column_is_fixed(j) && get_lower_bound(j).y.is_zero(); + return (*m_column_types)[j] == column_type::fixed && get_lower_bound(j).y.is_zero(); } - - void try_add_bound(mpq const& v, unsigned j, bool is_low, bool coeff_before_j_is_pos, unsigned row_or_term_index, bool strict) { - j = m_imp.lp().column_to_reported_index(j); - lconstraint_kind kind = is_low? GE : LE; + + void add_bound(mpq const& v, unsigned j, bool is_low, bool strict, std::function explain_bound) { + lconstraint_kind kind = is_low ? GE : LE; if (strict) kind = static_cast(kind / 2); - + if (!m_imp.bound_is_interesting(j, kind, v)) return; - unsigned k; // index to ibounds if (is_low) { - if (try_get_value(m_improved_lower_bounds, j, k)) { - auto & found_bound = m_ibounds[k]; + unsigned k; + if (m_improved_lower_bounds.find(j, k)) { + auto& found_bound = m_ibounds[k]; if (v > found_bound.m_bound || (v == found_bound.m_bound && !found_bound.m_strict && strict)) { - found_bound = implied_bound(v, j, is_low, coeff_before_j_is_pos, row_or_term_index, strict); - TRACE("try_add_bound", m_imp.lp().print_implied_bound(found_bound, tout);); + + found_bound.m_bound = v; + found_bound.m_strict = strict; + found_bound.set_explain(explain_bound); + TRACE("add_bound", lp().print_implied_bound(found_bound, tout);); } + } else { + m_improved_lower_bounds.insert(j, static_cast(m_ibounds.size())); + m_ibounds.push_back(implied_bound(v, j, is_low, strict, explain_bound)); + TRACE("add_bound", lp().print_implied_bound(m_ibounds.back(), tout);); } - else { - m_improved_lower_bounds[j] = m_ibounds.size(); - m_ibounds.push_back(implied_bound(v, j, is_low, coeff_before_j_is_pos, row_or_term_index, strict)); - TRACE("try_add_bound", m_imp.lp().print_implied_bound(m_ibounds.back(), tout);); - } - } - else { // the upper bound case - if (try_get_value(m_improved_upper_bounds, j, k)) { - auto & found_bound = m_ibounds[k]; + } else { // the upper bound case + unsigned k; + if (m_improved_upper_bounds.find(j, k)) { + auto& found_bound = m_ibounds[k]; if (v < found_bound.m_bound || (v == found_bound.m_bound && !found_bound.m_strict && strict)) { - found_bound = implied_bound(v, j, is_low, coeff_before_j_is_pos, row_or_term_index, strict); - TRACE("try_add_bound", m_imp.lp().print_implied_bound(found_bound, tout);); + + found_bound.m_bound = v; + found_bound.m_strict = strict; + found_bound.set_explain(explain_bound); + TRACE("add_bound", lp().print_implied_bound(found_bound, tout);); } - } - else { - m_improved_upper_bounds[j] = m_ibounds.size(); - m_ibounds.push_back(implied_bound(v, j, is_low, coeff_before_j_is_pos, row_or_term_index, strict)); - TRACE("try_add_bound", m_imp.lp().print_implied_bound(m_ibounds.back(), tout);); + } else { + m_improved_upper_bounds.insert(j, static_cast(m_ibounds.size())); + m_ibounds.push_back(implied_bound(v, j, is_low, strict, explain_bound)); + TRACE("add_bound", lp().print_implied_bound(m_ibounds.back(), tout);); } } } @@ -388,195 +192,43 @@ class lp_bound_propagator { m_imp.consume(a, ci); } - const mpq& val(unsigned j) const { - return lp().get_column_value(j).x; + const mpq& val(unsigned j) const { + return lp().get_column_value(j).x; // figure out why it is safe to return .x } - const mpq& val(const vertex* v) const { - return val(v->column()); + const impq& ival(unsigned j) const { + return lp().get_column_value(j); // figure out why it is safe to return .x } - - bool tree_contains_r(vertex* root, vertex *v) const { - if (*root == *v) - return true; - for (auto e : root->edges()) - if (tree_contains_r(e.target(), v)) - return true; - return false; - } - - // pol for polarity - int pol(const vertex* v) const { return pol(v->column()); } - int pol(unsigned j) const { return m_pol[j]; } - void set_polarity(const vertex* v, int p) { - SASSERT(p == 1 || p == -1); - unsigned j = v->column(); - SASSERT(!m_pol.contains(j)); - m_pol.insert(j, p); - } - - void check_and_set_polarity(vertex* v, int polarity, unsigned row_index, vertex*v_parent) { - int prev_pol; - if (!m_pol.find(v->column(), prev_pol)) { - set_polarity(v, polarity); - return; - } - if (prev_pol == polarity) - return; - // we have a path L between v and parent with p(L) = -1, that means we can - // create an equality of the form x + x = a, where x = v->column() = u->column() - vector path = connect_in_tree(v, v_parent); - m_fixed_vertex_explanation = get_explanation_from_path(path); - explain_fixed_in_row(row_index, m_fixed_vertex_explanation); - set_fixed_vertex(v); - TRACE("cheap_eq", - tout << "polarity switch: " << polarity << "\nv = "; print_vert(tout , v) << "\nu = "; tout << "fixed vertex explanation\n"; - for (auto p : m_fixed_vertex_explanation) - lp().constraints().display(tout, [this](lpvar j) { return lp().get_variable_name(j);}, p.ci());); - - } - - bool tree_contains(vertex *v) const { - if (!m_root) - return false; - return tree_contains_r(m_root, v); - } - - vertex * alloc_v(unsigned column) { - vertex * v = alloc(vertex, column); - m_vertices.insert(column, v); - SASSERT(!tree_contains(v)); - return v; - } - unsigned column(unsigned row, unsigned index) { return lp().get_row(row)[index].var(); } - bool fixed_phase() const { return m_fixed_vertex; } - - - // Returns the vertex to start exploration from, or nullptr. - // It is assumed that parent->column() is present in the row - vertex* get_child_from_row(unsigned row_index, vertex* parent) { - TRACE("cheap_eq_det", print_row(tout, row_index);); - unsigned x, y; int row_polarity; - if (!is_tree_offset_row(row_index, x, y, row_polarity)) { - TRACE("cheap_eq_det", tout << "not an offset row\n"; ); - return nullptr; - } - if (not_set(y)) { // there is only one fixed variable in the row - if (!fixed_phase()) { - set_fixed_vertex(parent); - explain_fixed_in_row(row_index, m_fixed_vertex_explanation); - } - return nullptr; - } - - SASSERT(is_set(x) && is_set(y)); - unsigned col = other(x, y, parent->column()); - return add_child_with_check(row_index, col, parent, row_polarity); - } - - vertex * add_child_with_check(unsigned row_index, unsigned col, vertex* parent, int row_polarity) { - vertex* vy; - if (m_vertices.find(col, vy)) { - SASSERT(vy != nullptr); - if (!fixed_phase()) { - check_and_set_polarity(vy, pol(parent) * row_polarity, row_index, parent); - } - return nullptr; // it is not a new vertex - } - vy = alloc_v(col); - parent->add_child(row_index, vy); - if (!fixed_phase()) - check_and_set_polarity(vy, row_polarity * pol(parent), row_index, parent); - return vy; - } - bool is_equal(lpvar j, lpvar k) const { + bool is_equal(lpvar j, lpvar k) const { return m_imp.is_equal(col_to_imp(j), col_to_imp(k)); } - void check_for_eq_and_add_to_val_table(vertex* v, map, default_eq>& table) { - TRACE("cheap_eq", tout << "v = "; print_vert(tout, v) << "\n";); - const vertex *k; // the other vertex - if (table.find(val(v), k)) { - TRACE("cheap_eq", tout << "found k " ; print_vert(tout, k) << "\n";); - if (k->column() != v->column() && - is_int(k->column()) == is_int(v->column()) && - !is_equal(k->column(), v->column())) { - report_eq(k, v); - } - else { - TRACE("cheap_eq", tout << "no report\n";); - } - } - else { - TRACE("cheap_eq", tout << "registered: " << val(v) << " -> { "; print_vert(tout, v) << "} \n";); - table.insert(val(v), v); - } - } - - void check_for_eq_and_add_to_val_tables(vertex* v) { - TRACE("cheap_eq_det", print_vert(tout, v) << "\n";); - if (!fixed_phase()) { - if (pol(v->column()) == -1) - check_for_eq_and_add_to_val_table(v, m_vals_to_verts_neg); - else - check_for_eq_and_add_to_val_table(v, m_vals_to_verts); - } - } - void clear_for_eq() { m_visited_rows.reset(); - m_visited_columns.reset(); - m_root = nullptr; } - std::ostream& print_edge(const edge& e, std::ostream& out) const { - out << e.source()->column() << "->" << e.target()->column() << "\n"; - return print_row(out, e.row()); - } - - std::ostream& print_path(const vector& path, std::ostream& out) const { - out << "path = \n"; - for (const edge& k : path) - print_edge(k, out) << "\n"; + std::ostream& print_expl(std::ostream& out, const explanation& exp) const { + for (auto p : exp) + lp().constraints().display( + out, [this](lpvar j) { return lp().get_variable_name(j); }, p.ci()); return out; } - - // we have v_i and v_j, indices of vertices at the same offsets - void report_eq(const vertex* v_i, const vertex* v_j) { - SASSERT(v_i != v_j); - SASSERT(lp().get_column_value(v_i->column()) == lp().get_column_value(v_j->column())); - TRACE("cheap_eq", tout << v_i->column() << " = " << v_j->column() << "\nu = "; - print_vert(tout, v_i) << "\nv = "; print_vert(tout, v_j) <<"\n"); - - vector path = connect_in_tree(v_i, v_j); - lp::explanation exp = get_explanation_from_path(path); - add_eq_on_columns(exp, v_i->column(), v_j->column(), false); - - } - - std::ostream& print_expl(std::ostream & out, const explanation& exp) const { - for (auto p : exp) - lp().constraints().display(out, [this](lpvar j) { return lp().get_variable_name(j);}, p.ci()); - return out; - } - - bool add_eq_on_columns(const explanation& exp, lpvar j, lpvar k, bool is_fixed) { - SASSERT(j != k); - unsigned je = lp().column_to_reported_index(j); - unsigned ke = lp().column_to_reported_index(k); - TRACE("cheap_eq", - tout << "reporting eq " << j << ", " << k << "\n"; + + bool add_eq_on_columns(const explanation& exp, lpvar je, lpvar ke, bool is_fixed) { + lp_assert(je != ke && is_int(je) == is_int(ke)); + lp_assert(ival(je) == ival(ke)); + + TRACE("eq", tout << "reported idx " << je << ", " << ke << "\n"; print_expl(tout, exp); - tout << "theory_vars v" << lp().local_to_external(je) << " == v" << lp().local_to_external(ke) << "\n"; - ); - + tout << "theory_vars v" << lp().local_to_external(je) << " == v" << lp().local_to_external(ke) << "\n";); + bool added = m_imp.add_eq(je, ke, exp, is_fixed); if (added) { if (is_fixed) @@ -589,173 +241,211 @@ class lp_bound_propagator { // column to theory_var unsigned col_to_imp(unsigned j) const { - return lp().local_to_external(lp().column_to_reported_index(j)); + return lp().local_to_external(j); } // theory_var to column unsigned imp_to_col(unsigned j) const { - return lp().external_to_column_index(j); + return lp().external_to_local(j); } bool is_int(lpvar j) const { return lp().column_is_int(j); } - - explanation get_explanation_from_path(vector& path) const { - explanation ex; - for (edge &e : path) - explain_fixed_in_row(e.row(), ex); - return ex; - } - void explain_fixed_in_row(unsigned row, explanation& ex) const { - TRACE("cheap_eq", tout << lp().get_row(row) << std::endl); - for (const auto & c : lp().get_row(row)) - if (lp().is_fixed(c.var())) + void explain_fixed_in_row(unsigned row, explanation& ex) { + TRACE("eq", tout << lp().get_row(row) << std::endl); + for (const auto& c : lp().get_row(row)) + if (lp().column_is_fixed(c.var())) explain_fixed_column(c.var(), ex); } - void explain_fixed_column(unsigned j, explanation & ex) const { - SASSERT(column_is_fixed(j)); - constraint_index lc, uc; - lp().get_bound_constraint_witnesses_for_column(j, lc, uc); - ex.push_back(lc); - ex.push_back(uc); - } - - vector connect_in_tree(const vertex* u, const vertex* v) const { - vector path; - TRACE("cheap_eq_details", tout << "u = " ; print_vert(tout, u); tout << "\nv = ";print_vert(tout, v) << "\n";); - vector v_branch; - // equalize the levels - while (u->level() > v->level()) { - path.push_back(u->edge_from_parent().reverse()); - u = u->parent(); - } - - while (u->level() < v->level()) { - v_branch.push_back(v->edge_from_parent()); - v = v->parent(); - } - SASSERT(u->level() == v->level()); - TRACE("cheap_eq_details", tout << "u = " ; print_vert(tout, u); tout << "\nv = "; print_vert(tout, v) << "\n";); - while (u != v) { - path.push_back(u->edge_from_parent().reverse()); - v_branch.push_back(v->edge_from_parent()); - u = u->parent(); - v = v->parent(); - } - for (unsigned i = v_branch.size(); i--; ) { - path.push_back(v_branch[i]); + unsigned explain_fixed_in_row_and_get_base(unsigned row, explanation& ex) { + unsigned base = UINT_MAX; + TRACE("eq", tout << lp().get_row(row) << std::endl); + for (const auto& c : lp().get_row(row)) { + if (lp().column_is_fixed(c.var())) { + explain_fixed_column(c.var(), ex); + } + else if (lp().is_base(c.var())) { + base = c.var(); + } } - TRACE("cheap_eq", print_path(path, tout);); - return path; + return base; } - - bool tree_is_correct() const { - std::unordered_set vs; - return tree_is_correct(m_root, vs); - } - - bool tree_is_correct(vertex* v, std::unordered_set& visited_verts) const { - if (fixed_phase()) - return true; - if (visited_verts.find(v->column()) != visited_verts.end()) - return false; - visited_verts.insert(v->column()); - for (auto e : v->edges()) - if (!tree_is_correct(e.target(), visited_verts)) + + void explain_fixed_column(unsigned j, explanation& ex) { + SASSERT(column_is_fixed(j)); + auto* deps = lp().get_bound_constraint_witnesses_for_column(j); + for (auto ci : lp().flatten(deps)) + ex.push_back(ci); + } +#ifdef Z3DEBUG + bool all_fixed_in_row(unsigned row) const { + for (const auto& c : lp().get_row(row)) + if (!lp().column_is_fixed(c.var())) return false; return true; } - std::ostream& print_tree(std::ostream & out, vertex* v) const { - print_vert(out, v); - out << "\nchildren :\n"; - for (auto c : v->edges()) { - out << "row = "; - print_row(out, c.row()); - print_tree(out, c.target()); + + // bounded by 2 + unsigned num_of_non_fixed_in_row(unsigned row_index) const { + unsigned n_of_nfixed = 0; + for (const auto& c : lp().get_row(row_index)) { + if (column_is_fixed(c.var())) + continue; + n_of_nfixed++; + if (n_of_nfixed > 1) + return n_of_nfixed; + } + + return n_of_nfixed; + } +#endif + // Let nf is the number of non-fixed columns in the row. + // Then the function returns min(nf, 3). + // if nf == 0, the row is of the form sum of fixed = 0 + // if nf == 1, the row is of the form x + sum of fixed = 0, where x is not fixed base + // if nf == 2, the row is of the form x + ay + sum of fixed = 0, x is a non fixed base and y is not fixed + // y_sign is set to a, if abs(a)= 1, and 0 otherwise + + unsigned extract_non_fixed(unsigned row_index, unsigned& x, unsigned& y, int& y_sign) const { + unsigned nf = 0; // number of non-fixed columns + y = UINT_MAX; + const auto& row = lp().get_row(row_index); + x = lp().get_base_column_in_row(row_index); + if (!column_is_fixed(x)) { + nf++; + } else { + // we have a fixed base column, exiting + return 0; + } + + for (const auto& c : row) { + unsigned j = c.var(); + if (j == x) continue; + if (column_is_fixed(j)) + continue; + if (++nf > 2) + return nf; + lp_assert(is_not_set(y)); + y = j; + if (c.coeff().is_one()) { + y_sign = 1; + } else if (c.coeff().is_minus_one()) { + y_sign = -1; + } else { + // y has a coefficient other than 1 or -1 + y_sign = 0; + return nf; // maybe be too low but we don't care + } } - return out; - } - void try_add_equation_with_fixed_tables(unsigned row_index, const vertex* v) { - try_add_equation_with_lp_fixed_tables(row_index, v); - try_add_equation_with_val_table(v); + return nf; } - - void handle_fixed_phase(unsigned row_index) { - if (!fixed_phase()) + + void try_add_equation_with_lp_fixed_tables(unsigned row_index, unsigned v_j) { + lp_assert(lp().get_base_column_in_row(row_index) == v_j); + lp_assert(num_of_non_fixed_in_row(row_index) == 1 || column_is_fixed(v_j)); + if (column_is_fixed(v_j)) { + return; + } + unsigned j = null_lpvar; + if (!lp().find_in_fixed_tables(val(v_j), is_int(v_j), j)) { + try_add_equation_with_internal_fixed_tables(row_index); return; - const vertex* v = m_root; - try_add_equation_with_fixed_tables(row_index, v); - for (auto e: v->edges()) - try_add_equation_with_fixed_tables(row_index, e.target()); + } + TRACE("eq", + tout << "v_j = "; + lp().print_column_info(v_j, tout) << std::endl; + tout << "found j " << j << std::endl; lp().print_column_info(j, tout) << std::endl; + print_row(tout, row_index) << std::endl; + ); + explanation ex; + explain_fixed_in_row(row_index, ex); + explain_fixed_column(j, ex); + add_eq_on_columns(ex, j, v_j, true); } - - void cheap_eq_tree(unsigned row_index) { + void cheap_eq_on_nbase(unsigned row_index) { reset_cheap_eq _reset(*this); - TRACE("cheap_eq_det", tout << "row_index = " << row_index << "\n";); - if (!check_insert(m_visited_rows, row_index)) + TRACE("eq", tout << "row_index = " << row_index << "\n"; + print_row(tout, row_index) << "\n";); + if (!check_insert(m_visited_rows, row_index)) return; - create_root(row_index); - if (!m_root) + unsigned x, y; + int y_sign; + unsigned nf = extract_non_fixed(row_index, x, y, y_sign); + if (nf == 0 || nf > 2) + return; + if (nf == 1) { + lp_assert(is_not_set(y)); + try_add_equation_with_lp_fixed_tables(row_index, x); + return; + } + if (y_sign == 0) { + // the coefficient before y is not 1 or -1 return; - - TRACE("cheap_eq", tout << "tree = "; print_tree(tout, m_root) << "\n";); - SASSERT(tree_is_correct()); - handle_fixed_phase(row_index); - - TRACE("cheap_eq", - tout << "done for row_index " << row_index << "\n"; - tout << "tree size = " << verts_size();); - } - - std::ostream& print_row(std::ostream & out, unsigned row_index) const { - unsigned x, y; int polarity; - if (true || !is_tree_offset_row(row_index, x, y, polarity)) - return lp().get_int_solver()->display_row_info(out, row_index); - + } + lp_assert(y_sign == -1 || y_sign == 1); + lp_assert(lp().is_base(y) == false); + auto& table = y_sign == 1 ? m_row2index_pos : m_row2index_neg; + table.insert(val(x), row_index); + TRACE("eq", tout << "y = " << y << "\n";); + + for (const column_cell& c : lp().get_column(y)) { + unsigned i = c.var(); // the running index of the row + if (i == row_index) + continue; + if (!check_insert(m_visited_rows, i)) + continue; + unsigned y_nb; + nf = extract_non_fixed(i, x, y_nb, y_sign); + if (nf != 2 || y_sign == 0) + continue; + + lp_assert(y_nb == y); + lp_assert(y_sign == 1 || y_sign == -1); + auto& table = y_sign == 1 ? m_row2index_pos : m_row2index_neg; + const auto& v = val(x); + unsigned found_i;; + + if (!table.find(v, found_i)) { + table.insert(v, i); + } else { + explanation ex; + unsigned base_of_found = lp().get_base_column_in_row(found_i); + if (is_int(x) != is_int(base_of_found) || ival(x).y != ival(base_of_found).y) + continue; + explain_fixed_in_row(found_i, ex); + explain_fixed_in_row(i, ex); + TRACE("eq", { + print_row(tout, i); + print_row(tout, found_i) << "\n"; + lp().print_column_info(base_of_found, tout); + lp().print_column_info(x, tout) << "\n"; + }); + add_eq_on_columns(ex, x, base_of_found, false); + } + } + } + + std::ostream& print_row(std::ostream& out, unsigned row_index) const { bool first = true; - for (const auto &c: lp().A_r().m_rows[row_index]) { + for (const auto& c : lp().A_r().m_rows[row_index]) { if (lp().column_is_fixed(c.var())) continue; if (c.coeff().is_one()) { - if (!first) - out << "+"; - } - else if (c.coeff().is_minus_one()) - out << "-"; + if (!first) + out << "+"; + } else if (c.coeff().is_minus_one()) + out << "-"; out << lp().get_variable_name(c.var()) << " "; - first = false; + first = false; } out << "\n"; return out; } - - void set_fixed_vertex(vertex *v) { - TRACE("cheap_eq", if (v) print_vert(tout, v); else tout << "set m_fixed_vertex to nullptr"; tout << "\n";); - SASSERT(!m_fixed_vertex || v == nullptr); - m_fixed_vertex = v; - } - - unsigned verts_size() const { - return subtree_size(m_root); - } - - unsigned subtree_size(vertex* v) const { - unsigned r = 1; // 1 for v - for (auto e : v->edges()) - r += subtree_size(e.target()); - return r; - } - - void delete_tree(vertex * v) { - for (auto p : v->edges()) - delete_tree(p.target()); - dealloc(v); - } template bool check_insert(C& table, unsigned j) { @@ -763,7 +453,6 @@ class lp_bound_propagator { return false; table.insert(j); return true; - } - + } }; -} +} // namespace lp diff --git a/src/math/lp/lp_core_solver_base.cpp b/src/math/lp/lp_core_solver_base.cpp index f1ae95ea0b9..e91fc15fca6 100644 --- a/src/math/lp/lp_core_solver_base.cpp +++ b/src/math/lp/lp_core_solver_base.cpp @@ -23,10 +23,8 @@ Revision History: #include "util/vector.h" #include #include "math/lp/lp_core_solver_base_def.h" -template bool lp::lp_core_solver_base >::print_statistics_with_iterations_and_nonzeroes_and_cost_and_check_that_the_time_is_over(char const*, std::ostream &); template bool lp::lp_core_solver_base::basis_heading_is_correct() const ; template bool lp::lp_core_solver_base::column_is_dual_feasible(unsigned int) const; -template bool lp::lp_core_solver_base::print_statistics_with_iterations_and_nonzeroes_and_cost_and_check_that_the_time_is_over(char const*, std::ostream &); template void lp::lp_core_solver_base::add_delta_to_entering(unsigned int, const lp::mpq&); template void lp::lp_core_solver_base >::init(); template void lp::lp_core_solver_base >::init_basis_heading_and_non_basic_columns_vector(); @@ -35,7 +33,6 @@ template lp::lp_core_solver_base >::lp_core_s vector&, vector &, vector &, vector >&, vector&, lp::lp_settings&, const column_namer&, const vector&, const vector >&, const vector >&); -template bool lp::lp_core_solver_base >::print_statistics_with_cost_and_check_that_the_time_is_over(lp::numeric_pair, std::ostream&); template void lp::lp_core_solver_base >::add_delta_to_entering(unsigned int, const lp::numeric_pair&); template lp::lp_core_solver_base::lp_core_solver_base( @@ -50,7 +47,6 @@ template lp::lp_core_solver_base::lp_core_solver_base( const vector&, const vector&, const vector&); -template bool lp::lp_core_solver_base >::print_statistics_with_iterations_and_check_that_the_time_is_over(std::ostream &); template std::string lp::lp_core_solver_base::column_name(unsigned int) const; template void lp::lp_core_solver_base::pretty_print(std::ostream & out); template std::string lp::lp_core_solver_base >::column_name(unsigned int) const; @@ -63,8 +59,8 @@ template bool lp::lp_core_solver_base >::colu template bool lp::lp_core_solver_base>::pivot_column_tableau(unsigned int, unsigned int); template bool lp::lp_core_solver_base::pivot_column_tableau(unsigned int, unsigned int); template void lp::lp_core_solver_base >::transpose_rows_tableau(unsigned int, unsigned int); -template bool lp::lp_core_solver_base >::inf_set_is_correct() const; -template bool lp::lp_core_solver_base::inf_set_is_correct() const; -template bool lp::lp_core_solver_base >::remove_from_basis(unsigned int); +template bool lp::lp_core_solver_base >::inf_heap_is_correct() const; +template bool lp::lp_core_solver_base::inf_heap_is_correct() const; +template bool lp::lp_core_solver_base >::remove_from_basis_core(unsigned int, unsigned int); diff --git a/src/math/lp/lp_core_solver_base.h b/src/math/lp/lp_core_solver_base.h index fb0c2850795..a5b79e9c8f5 100644 --- a/src/math/lp/lp_core_solver_base.h +++ b/src/math/lp/lp_core_solver_base.h @@ -19,6 +19,7 @@ Revision History: --*/ #pragma once #include +#include #include "util/vector.h" #include #include "math/lp/lp_utils.h" @@ -27,10 +28,14 @@ Revision History: #include "math/lp/static_matrix.h" #include "math/lp/permutation_matrix.h" #include "math/lp/column_namer.h" -#include "math/lp/u_set.h" - +#include "util/uint_set.h" +#include "util/heap.h" namespace lp { +struct lpvar_lt { + bool operator()(lpvar v1, lpvar v2) const { return v1 < v2; } +}; +typedef heap lpvar_heap; template X dot_product(const vector & a, const vector & b) { lp_assert(a.size() == b.size()); @@ -50,25 +55,25 @@ class lp_core_solver_base { lp_status m_status; public: bool current_x_is_feasible() const { - TRACE("feas", - if (m_inf_set.size()) { - tout << "column " << *m_inf_set.begin() << " is infeasible" << std::endl; - print_column_info(*m_inf_set.begin(), tout); + TRACE("feas_bug", + if (!m_inf_heap.empty()) { + tout << "column " << *m_inf_heap.begin() << " is infeasible" << std::endl; + print_column_info(*m_inf_heap.begin(), tout); } else { tout << "x is feasible\n"; } ); - return m_inf_set.size() == 0; + return m_inf_heap.empty(); } - bool current_x_is_infeasible() const { return m_inf_set.size() != 0; } + bool current_x_is_infeasible() const { return m_inf_heap.size() != 0; } private: - u_set m_inf_set; + lpvar_heap m_inf_heap; public: - const u_set& inf_set() const { return m_inf_set; } - u_set& inf_set() { return m_inf_set; } - void inf_set_increase_size_by_one() { m_inf_set.increase_size_by_one(); } - bool inf_set_contains(unsigned j) const { return m_inf_set.contains(j); } - unsigned inf_set_size() const { return m_inf_set.size(); } + const lpvar_heap& inf_heap() const { return m_inf_heap; } + lpvar_heap& inf_heap() { return m_inf_heap; } + void inf_heap_increase_size_by_one() { m_inf_heap.reserve(m_inf_heap.size() + 1); } + bool inf_heap_contains(unsigned j) const { return m_inf_heap.contains(j); } + unsigned inf_heap_size() const { return m_inf_heap.size(); } indexed_vector m_pivot_row; // this is the real pivot row of the simplex tableu static_matrix & m_A; // the matrix A // vector const & m_b; // the right side @@ -84,10 +89,11 @@ class lp_core_solver_base { const vector & m_column_types; const vector & m_lower_bounds; const vector & m_upper_bounds; - unsigned m_basis_sort_counter; + unsigned m_nbasis_sort_counter; vector m_trace_of_basis_change_vector; // the even positions are entering, the odd positions are leaving bool m_tracing_basis_changes; - u_set* m_pivoted_rows; + // these rows are changed by adding to them a multiple of the pivot row + indexed_uint_set* m_touched_rows = nullptr; bool m_look_for_feasible_solution_only; void start_tracing_basis_changes() { @@ -160,10 +166,6 @@ class lp_core_solver_base { void print_statistics(char const* str, X cost, std::ostream & message_stream); - bool print_statistics_with_iterations_and_check_that_the_time_is_over(std::ostream & message_stream); - - bool print_statistics_with_iterations_and_nonzeroes_and_cost_and_check_that_the_time_is_over(char const* str, std::ostream & message_stream); - bool print_statistics_with_cost_and_check_that_the_time_is_over(X cost, std::ostream & message_stream); unsigned total_iterations() const { return m_total_iterations; } @@ -255,7 +257,7 @@ class lp_core_solver_base { bool calc_current_x_is_feasible_include_non_basis() const; - bool inf_set_is_correct() const; + bool inf_heap_is_correct() const; bool column_is_dual_feasible(unsigned j) const; @@ -272,7 +274,7 @@ class lp_core_solver_base { bool non_basis_has_no_doubles() const; bool basis_is_correctly_represented_in_heading() const ; - bool non_basis_is_correctly_represented_in_heading() const ; + bool non_basis_is_correctly_represented_in_heading(std::list*) const ; bool basis_heading_is_correct() const; @@ -286,8 +288,8 @@ class lp_core_solver_base { X bound_span(unsigned j) const { return m_upper_bounds[j] - m_lower_bounds[j]; } - - std::string column_name(unsigned column) const; + // clang-format on + std::string column_name(unsigned column) const; bool make_column_feasible(unsigned j, numeric_pair & delta) { bool ret = false; @@ -298,15 +300,15 @@ class lp_core_solver_base { lp_assert(m_lower_bounds[j] == m_upper_bounds[j]); if (x != m_lower_bounds[j]) { delta = m_lower_bounds[j] - x; - ret = true;; + ret = true; } break; case column_type::boxed: if (x < m_lower_bounds[j]) { delta = m_lower_bounds[j] - x; - ret = true;; + ret = true; } - if (x > m_upper_bounds[j]) { + else if (x > m_upper_bounds[j]) { delta = m_upper_bounds[j] - x; ret = true; } @@ -333,7 +335,7 @@ class lp_core_solver_base { } - bool remove_from_basis(unsigned j); + bool remove_from_basis_core(unsigned entering, unsigned leaving); bool pivot_column_general(unsigned j, unsigned j_basic, indexed_vector & w); void init_basic_part_of_basis_heading() { unsigned m = m_basis.size(); @@ -411,6 +413,7 @@ class lp_core_solver_base { TRACE("lp_core", tout << "inf col "; print_column_info(j, tout) << "\n";); return false; } + return true; } @@ -514,7 +517,7 @@ class lp_core_solver_base { template - static void swap(vector &v, unsigned i, unsigned j) { + static void swap(vector &v, unsigned i, unsigned j) noexcept { auto t = v[i]; v[i] = v[j]; v[j] = t; @@ -526,63 +529,56 @@ class lp_core_solver_base { swap(m_basis_heading, m_basis[i], m_basis[ii]); } - bool column_is_in_inf_set(unsigned j) const { - return m_inf_set.contains(j); + bool column_is_in_inf_heap(unsigned j) const { + return m_inf_heap.contains(j); } bool column_is_base(unsigned j) const { return m_basis_heading[j] >= 0; } - - void update_x_with_feasibility_tracking(unsigned j, const X & v) { - TRACE("lar_solver", tout << "j = " << j << ", v = " << v << "\n";); - m_x[j] = v; - track_column_feasibility(j); - } - void add_delta_to_x_and_track_feasibility(unsigned j, const X & del) { - TRACE("lar_solver", tout << "del = " << del << ", was x[" << j << "] = " << m_x[j] << "\n";); + TRACE("lar_solver_feas", tout << "del = " << del << ", was x[" << j << "] = " << m_x[j] << "\n";); m_x[j] += del; - TRACE("lar_solver", tout << "became x[" << j << "] = " << m_x[j] << "\n";); + TRACE("lar_solver_feas", tout << "became x[" << j << "] = " << m_x[j] << "\n";); track_column_feasibility(j); } void update_x(unsigned j, const X & v) { - TRACE("lar_solver", tout << "j = " << j << ", v = " << v << "\n";); m_x[j] = v; + TRACE("lar_solver_feas", tout << "not tracking feas j = " << j << ", v = " << v << (column_is_feasible(j)? " feas":" non-feas") << "\n";); } - void add_delta_to_x(unsigned j, const X & delta) { - TRACE("lar_solver", tout << "j = " << j << ", delta = " << delta << "\n";); + void add_delta_to_x(unsigned j, const X& delta) { m_x[j] += delta; + TRACE("lar_solver_feas", tout << "not tracking feas j = " << j << " v = " << m_x[j] << " delta = " << delta << (column_is_feasible(j) ? " feas" : " non-feas") << "\n";); } - + void track_column_feasibility(unsigned j) { if (column_is_feasible(j)) - remove_column_from_inf_set(j); + remove_column_from_inf_heap(j); else - insert_column_into_inf_set(j); + insert_column_into_inf_heap(j); } - void insert_column_into_inf_set(unsigned j) { - TRACE("lar_solver", tout << "j = " << j << "\n";); - m_inf_set.insert(j); + void insert_column_into_inf_heap(unsigned j) { + if (!m_inf_heap.contains(j)) { + m_inf_heap.reserve(j+1); + m_inf_heap.insert(j); + TRACE("lar_solver_inf_heap", tout << "insert into inf_heap j = " << j << "\n";); + } lp_assert(!column_is_feasible(j)); } - void remove_column_from_inf_set(unsigned j) { - TRACE("lar_solver", tout << "j = " << j << "\n";); - m_inf_set.erase(j); + void remove_column_from_inf_heap(unsigned j) { + if (m_inf_heap.contains(j)) { + TRACE("lar_solver_inf_heap", tout << "erase from heap j = " << j << "\n";); + m_inf_heap.erase(j); + } lp_assert(column_is_feasible(j)); } - void resize_inf_set(unsigned size) { - TRACE("lar_solver",); - m_inf_set.resize(size); - } - - void clear_inf_set() { - TRACE("lar_solver",); - m_inf_set.clear(); + void clear_inf_heap() { + TRACE("lar_solver_feas",); + m_inf_heap.clear(); } bool costs_on_nbasis_are_zeros() const { diff --git a/src/math/lp/lp_core_solver_base_def.h b/src/math/lp/lp_core_solver_base_def.h index 8619c926e8d..da48eb5602c 100644 --- a/src/math/lp/lp_core_solver_base_def.h +++ b/src/math/lp/lp_core_solver_base_def.h @@ -42,7 +42,7 @@ lp_core_solver_base(static_matrix & A, m_total_iterations(0), m_iters_with_no_cost_growing(0), m_status(lp_status::FEASIBLE), - m_inf_set(A.column_count()), + m_inf_heap(std::max(static_cast(1024), A.column_count())), m_pivot_row(A.column_count()), m_A(A), m_basis(basis), @@ -56,9 +56,9 @@ lp_core_solver_base(static_matrix & A, m_column_types(column_types), m_lower_bounds(lower_bound_values), m_upper_bounds(upper_bound_values), - m_basis_sort_counter(0), + m_nbasis_sort_counter(0), m_tracing_basis_changes(false), - m_pivoted_rows(nullptr), + m_touched_rows(nullptr), m_look_for_feasible_solution_only(false) { lp_assert(bounds_for_boxed_are_set_correctly()); init(); @@ -115,12 +115,13 @@ pretty_print(std::ostream & out) { template void lp_core_solver_base:: add_delta_to_entering(unsigned entering, const X& delta) { - m_x[entering] += delta; - - for (const auto & c : m_A.m_columns[entering]) { - unsigned i = c.var(); - m_x[m_basis[i]] -= delta * m_A.get_val(c); - } + m_x[entering] += delta; + TRACE("lar_solver_feas", tout << "not tracking feas entering = " << entering << " = " << m_x[entering] << (column_is_feasible(entering) ? " feas" : " non-feas") << "\n";); + for (const auto & c : m_A.m_columns[entering]) { + unsigned i = c.var(); + m_x[m_basis[i]] -= delta * m_A.get_val(c); + TRACE("lar_solver_feas", tout << "not tracking feas m_basis[i] = " << m_basis[i] << " = " << m_x[m_basis[i]] << (column_is_feasible(m_basis[i]) ? " feas" : " non-feas") << "\n";); + } } @@ -132,37 +133,6 @@ print_statistics(char const* str, X cost, std::ostream & out) { << ", nonzeros = " << m_A.number_of_non_zeroes() << std::endl; } -template bool lp_core_solver_base:: -print_statistics_with_iterations_and_check_that_the_time_is_over(std::ostream & str) { - unsigned total_iterations = inc_total_iterations(); - if (m_settings.report_frequency != 0) { - if (m_settings.print_statistics && (total_iterations % m_settings.report_frequency == 0)) { - print_statistics("", X(), str); - } - } - return time_is_over(); -} - -template bool lp_core_solver_base:: -print_statistics_with_iterations_and_nonzeroes_and_cost_and_check_that_the_time_is_over(char const* str, std::ostream & out) { - unsigned total_iterations = inc_total_iterations(); - if (m_settings.report_frequency != 0) - if (m_settings.print_statistics && (total_iterations % m_settings.report_frequency == 0)) { - print_statistics(str, get_cost(), out); - } - return time_is_over(); -} - -template bool lp_core_solver_base:: -print_statistics_with_cost_and_check_that_the_time_is_over(X cost, std::ostream & out) { - unsigned total_iterations = inc_total_iterations(); - if (m_settings.report_frequency != 0) - if (m_settings.print_statistics && (total_iterations % m_settings.report_frequency == 0)) { - print_statistics("", cost, out); - } - return time_is_over(); -} - template bool lp_core_solver_base:: column_is_dual_feasible(unsigned j) const { switch (m_column_types[j]) { @@ -193,18 +163,6 @@ d_is_not_positive(unsigned j) const { return m_d[j] <= numeric_traits::zero(); } - -template bool lp_core_solver_base:: -time_is_over() { - if (m_settings.get_cancel_flag()) { - m_status = lp_status::TIME_EXHAUSTED; - return true; - } - else { - return false; - } -} - template void lp_core_solver_base:: rs_minus_Anx(vector & rs) { unsigned row = m_m(); @@ -250,9 +208,9 @@ template bool lp_core_solver_base::calc_current_x return true; } -template bool lp_core_solver_base::inf_set_is_correct() const { +template bool lp_core_solver_base::inf_heap_is_correct() const { for (unsigned j = 0; j < this->m_n(); j++) { - bool belongs_to_set = m_inf_set.contains(j); + bool belongs_to_set = m_inf_heap.contains(j); bool is_feas = column_is_feasible(j); if (is_feas == belongs_to_set) { TRACE("lp_core", tout << "incorrectly set column in inf set "; print_column_info(j, tout) << "\n";); @@ -324,8 +282,8 @@ pivot_column_tableau(unsigned j, unsigned piv_row_index) { if(! m_A.pivot_row_to_row_given_cell(piv_row_index, c, j)) { return false; } - if (m_pivoted_rows!= nullptr) - m_pivoted_rows->insert(c.var()); + if (m_touched_rows!= nullptr) + m_touched_rows->insert(c.var()); } if (m_settings.simplex_strategy() == simplex_strategy_enum::tableau_costs) @@ -359,7 +317,7 @@ basis_is_correctly_represented_in_heading() const { return true; } template bool lp_core_solver_base:: -non_basis_is_correctly_represented_in_heading() const { +non_basis_is_correctly_represented_in_heading(std::list* non_basis_list) const { for (unsigned i = 0; i < m_nbasis.size(); i++) if (m_basis_heading[m_nbasis[i]] != - static_cast(i) - 1) return false; @@ -367,7 +325,34 @@ non_basis_is_correctly_represented_in_heading() const { for (unsigned j = 0; j < m_A.column_count(); j++) if (m_basis_heading[j] >= 0) lp_assert(static_cast(m_basis_heading[j]) < m_A.row_count() && m_basis[m_basis_heading[j]] == j); - + + if (non_basis_list == nullptr) return true; + + std::unordered_set nbasis_set(this->m_nbasis.size()); + for (unsigned j : this->m_nbasis) + nbasis_set.insert(j); + + if (non_basis_list->size() != nbasis_set.size()) { + TRACE("lp_core", tout << "non_basis_list.size() = " << non_basis_list->size() << ", nbasis_set.size() = " << nbasis_set.size() << "\n";); + return false; + } + for (auto it = non_basis_list->begin(); it != non_basis_list->end(); it++) { + if (nbasis_set.find(*it) == nbasis_set.end()) { + TRACE("lp_core", tout << "column " << *it << " is in m_non_basis_list but not in m_nbasis\n";); + return false; + } + } + + // check for duplicates in m_non_basis_list + nbasis_set.clear(); + for (auto it = non_basis_list->begin(); it != non_basis_list->end(); it++) { + if (nbasis_set.find(*it) != nbasis_set.end()) { + TRACE("lp_core", tout << "column " << *it << " is in m_non_basis_list twice\n";); + return false; + } + nbasis_set.insert(*it); + } + return true; } @@ -389,7 +374,7 @@ template bool lp_core_solver_base:: if (!basis_is_correctly_represented_in_heading()) return false; - if (!non_basis_is_correctly_represented_in_heading()) + if (!non_basis_is_correctly_represented_in_heading(nullptr)) return false; return true; @@ -404,30 +389,22 @@ template void lp_core_solver_base::transpose_row transpose_basis(i, j); m_A.transpose_rows(i, j); } -// j is the new basic column, j_basic - the leaving column -template bool lp_core_solver_base::pivot_column_general(unsigned j, unsigned j_basic, indexed_vector & w) { - lp_assert(m_basis_heading[j] < 0); - lp_assert(m_basis_heading[j_basic] >= 0); - unsigned row_index = m_basis_heading[j_basic]; - // the tableau case - if (pivot_column_tableau(j, row_index)) - change_basis(j, j_basic); - else return false; - - return true; +// entering is the new base column, leaving - the column leaving the basis +template bool lp_core_solver_base::pivot_column_general(unsigned entering, unsigned leaving, indexed_vector & w) { + lp_assert(m_basis_heading[entering] < 0); + lp_assert(m_basis_heading[leaving] >= 0); + unsigned row_index = m_basis_heading[leaving]; + // the tableau case + if (!pivot_column_tableau(entering, row_index)) + return false; + change_basis(entering, leaving); + return true; } -template bool lp_core_solver_base::remove_from_basis(unsigned basic_j) { +template bool lp_core_solver_base::remove_from_basis_core(unsigned entering, unsigned leaving) { indexed_vector w(m_basis.size()); // the buffer - unsigned i = m_basis_heading[basic_j]; - for (auto &c : m_A.m_rows[i]) { - if (c.var() == basic_j) - continue; - if (pivot_column_general(c.var(), basic_j, w)) - return true; - } - return false; + return pivot_column_general(entering, leaving, w); } diff --git a/src/math/lp/lp_primal_core_solver.h b/src/math/lp/lp_primal_core_solver.h index 207428985bf..d34569032a6 100644 --- a/src/math/lp/lp_primal_core_solver.h +++ b/src/math/lp/lp_primal_core_solver.h @@ -17,12 +17,11 @@ Revision History: --*/ - #pragma once #include "math/lp/core_solver_pretty_printer.h" #include "math/lp/lp_core_solver_base.h" #include "math/lp/static_matrix.h" -#include "math/lp/u_set.h" +#include "util/uint_set.h" #include "util/vector.h" #include #include @@ -33,642 +32,626 @@ Revision History: #include #include #include +#include "util/heap.h" + namespace lp { -// This core solver solves (Ax=b, lower_bound_values \leq x \leq -// upper_bound_values, maximize costs*x ) The right side b is given implicitly -// by x and the basis -template -class lp_primal_core_solver : public lp_core_solver_base { -public: - int m_sign_of_entering_delta; - vector m_costs_backup; - unsigned m_inf_row_index_for_tableau; - bool m_bland_mode_tableau; - u_set m_left_basis_tableau; - unsigned m_bland_mode_threshold; - unsigned m_left_basis_repeated; - vector m_leaving_candidates; - - std::list m_non_basis_list; - void sort_non_basis(); - int choose_entering_column_tableau(); - - bool needs_to_grow(unsigned bj) const { - lp_assert(!this->column_is_feasible(bj)); - switch (this->m_column_types[bj]) { - case column_type::free_column: - return false; - case column_type::fixed: - case column_type::lower_bound: - case column_type::boxed: - return this->x_below_low_bound(bj); - default: - return false; - } - UNREACHABLE(); // unreachable - return false; - } - - int inf_sign_of_column(unsigned bj) const { - lp_assert(!this->column_is_feasible(bj)); - switch (this->m_column_types[bj]) { - case column_type::free_column: - return 0; - case column_type::lower_bound: - return 1; - case column_type::fixed: - case column_type::boxed: - return this->x_above_upper_bound(bj) ? -1 : 1; - default: - return -1; - } - UNREACHABLE(); // unreachable - return 0; - } - - bool monoid_can_decrease(const row_cell &rc) const { - unsigned j = rc.var(); - lp_assert(this->column_is_feasible(j)); - switch (this->m_column_types[j]) { - case column_type::free_column: - return true; - case column_type::fixed: - return false; - case column_type::lower_bound: - if (is_pos(rc.coeff())) { - return this->x_above_lower_bound(j); - } - - return true; - case column_type::upper_bound: - if (is_pos(rc.coeff())) { - return true; - } + // This core solver solves (Ax=b, lower_bound_values \leq x \leq + // upper_bound_values, maximize costs*x ) The right side b is given implicitly + // by x and the basis + template + class lp_primal_core_solver : public lp_core_solver_base { + public: + int m_sign_of_entering_delta; + vector m_costs_backup; + unsigned m_inf_row_index_for_tableau; + bool m_bland_mode_tableau; + indexed_uint_set m_left_basis_tableau; + unsigned m_bland_mode_threshold; + unsigned m_left_basis_repeated; + vector m_leaving_candidates; + + std::list m_non_basis_list; + void sort_non_basis(); + int choose_entering_column_tableau(); + + bool needs_to_grow(unsigned bj) const { + lp_assert(!this->column_is_feasible(bj)); + switch (this->m_column_types[bj]) { + case column_type::free_column: + return false; + case column_type::fixed: + case column_type::lower_bound: + case column_type::boxed: + return this->x_below_low_bound(bj); + default: + return false; + } + UNREACHABLE(); // unreachable + return false; + } - return this->x_below_upper_bound(j); - case column_type::boxed: - if (is_pos(rc.coeff())) { - return this->x_above_lower_bound(j); - } + int inf_sign_of_column(unsigned bj) const { + lp_assert(!this->column_is_feasible(bj)); + switch (this->m_column_types[bj]) { + case column_type::free_column: + return 0; + case column_type::lower_bound: + return 1; + case column_type::fixed: + case column_type::boxed: + return this->x_above_upper_bound(bj) ? -1 : 1; + default: + return -1; + } + UNREACHABLE(); // unreachable + return 0; + } - return this->x_below_upper_bound(j); - default: - return false; - } - UNREACHABLE(); // unreachable - return false; - } - - bool monoid_can_increase(const row_cell &rc) const { - unsigned j = rc.var(); - lp_assert(this->column_is_feasible(j)); - switch (this->m_column_types[j]) { - case column_type::free_column: - return true; - case column_type::fixed: - return false; - case column_type::lower_bound: - if (is_neg(rc.coeff())) { - return this->x_above_lower_bound(j); - } - - return true; - case column_type::upper_bound: - if (is_neg(rc.coeff())) { - return true; - } + bool monoid_can_decrease(const row_cell &rc) const { + unsigned j = rc.var(); + lp_assert(this->column_is_feasible(j)); + switch (this->m_column_types[j]) { + case column_type::free_column: + return true; + case column_type::fixed: + return false; + case column_type::lower_bound: + return !is_pos(rc.coeff()) || this->x_above_lower_bound(j); + case column_type::upper_bound: + return is_pos(rc.coeff()) || this->x_below_upper_bound(j); + case column_type::boxed: + if (is_pos(rc.coeff())) + return this->x_above_lower_bound(j); + return this->x_below_upper_bound(j); + default: + return false; + } + UNREACHABLE(); // unreachable + return false; + } - return this->x_below_upper_bound(j); - case column_type::boxed: - if (is_neg(rc.coeff())) { - return this->x_above_lower_bound(j); - } + bool monoid_can_increase(const row_cell &rc) const { + unsigned j = rc.var(); + lp_assert(this->column_is_feasible(j)); + switch (this->m_column_types[j]) { + case column_type::free_column: + return true; + case column_type::fixed: + return false; + case column_type::lower_bound: + if (is_neg(rc.coeff())) + return this->x_above_lower_bound(j); + return true; + case column_type::upper_bound: + if (is_neg(rc.coeff())) + return true; + return this->x_below_upper_bound(j); + case column_type::boxed: + if (is_neg(rc.coeff())) + return this->x_above_lower_bound(j); + return this->x_below_upper_bound(j); + default: + return false; + } + UNREACHABLE(); // unreachable + return false; + } - return this->x_below_upper_bound(j); - default: - return false; - } - UNREACHABLE(); // unreachable - return false; - } - - unsigned get_number_of_basic_vars_that_might_become_inf( - unsigned j) const { // consider looking at the signs here: todo - unsigned r = 0; - for (const auto &cc : this->m_A.m_columns[j]) { - unsigned k = this->m_basis[cc.var()]; - if (this->m_column_types[k] != column_type::free_column) - r++; - } - return r; - } - - int find_beneficial_column_in_row_tableau_rows_bland_mode(int i, T &a_ent) { - int j = -1; - unsigned bj = this->m_basis[i]; - bool bj_needs_to_grow = needs_to_grow(bj); - for (const row_cell &rc : this->m_A.m_rows[i]) { - if (rc.var() == bj) - continue; - if (bj_needs_to_grow) { - if (!monoid_can_decrease(rc)) - continue; - } else { - if (!monoid_can_increase(rc)) - continue; - } - if (rc.var() < static_cast(j)) { - j = rc.var(); - a_ent = rc.coeff(); - } - } - if (j == -1) { - m_inf_row_index_for_tableau = i; + /** + * Return the number of base non-free variables depending on the column j, + * different from bj, + * but take the min with the (bound+1). + * This function is used to select the pivot variable. + */ + unsigned get_num_of_not_free_basic_dependent_vars(unsigned j, unsigned bound, unsigned bj) const { + // consider looking at the signs here: todo + unsigned r = 0; + for (const auto &cc : this->m_A.m_columns[j]) { + unsigned basic_for_row = this->m_basis[cc.var()]; + if (basic_for_row == bj) + continue; + + // std::cout << this->m_A.m_rows[cc.var()] << std::endl; + if (this->m_column_types[basic_for_row] != column_type::free_column) + if (r++ > bound) return r; + } + return r; + } + + int find_beneficial_entering_in_row_tableau_rows_bland_mode(int i, T &a_ent) { + int j = -1; + unsigned bj = this->m_basis[i]; + bool bj_needs_to_grow = needs_to_grow(bj); + for (const row_cell &rc : this->m_A.m_rows[i]) { + if (rc.var() == bj) + continue; + if (bj_needs_to_grow) { + if (!monoid_can_decrease(rc)) + continue; + } + else { + if (!monoid_can_increase(rc)) + continue; + } + if (rc.var() < static_cast(j)) { + j = rc.var(); + a_ent = rc.coeff(); + } + } + if (j == -1) + m_inf_row_index_for_tableau = i; + return j; + } + + int find_beneficial_entering_tableau_rows(int i, T &a_ent) { + if (m_bland_mode_tableau) + return find_beneficial_entering_in_row_tableau_rows_bland_mode(i, a_ent); + // a short row produces short infeasibility explanation and benefits at + // least one pivot operation + int choice = -1; + int nchoices = 0; + unsigned min_non_free_so_far = -1; + unsigned best_col_sz = -1; + unsigned bj = this->m_basis[i]; + bool bj_needs_to_grow = needs_to_grow(bj); + for (unsigned k = 0; k < this->m_A.m_rows[i].size(); k++) { + const row_cell &rc = this->m_A.m_rows[i][k]; + unsigned j = rc.var(); + if (j == bj) + continue; + if (bj_needs_to_grow) { + if (!monoid_can_decrease(rc)) + continue; + } else { + if (!monoid_can_increase(rc)) + continue; + } + unsigned not_free = get_num_of_not_free_basic_dependent_vars(j, min_non_free_so_far, bj); + unsigned col_sz = this->m_A.m_columns[j].size(); + if (not_free < min_non_free_so_far || (not_free == min_non_free_so_far && col_sz < best_col_sz)) { + min_non_free_so_far = not_free; + best_col_sz = this->m_A.m_columns[j].size(); + choice = k; + nchoices = 1; + } + else if (not_free == min_non_free_so_far && + col_sz == best_col_sz) { + if (this->m_settings.random_next(++nchoices) == 0) + choice = k; + } + } + + if (choice == -1) { + m_inf_row_index_for_tableau = i; + return -1; + } + const row_cell &rc = this->m_A.m_rows[i][choice]; + a_ent = rc.coeff(); + return rc.var(); + } + + bool try_jump_to_another_bound_on_entering(unsigned entering, X &t); + + bool try_jump_to_another_bound_on_entering_unlimited(unsigned entering, X &t); + + int find_leaving_and_t_tableau(unsigned entering, X &t); + + void limit_theta(const X &lim, X &theta, bool &unlimited) { + if (unlimited) { + theta = lim; + unlimited = false; + } else + theta = std::min(lim, theta); } - return j; - } - - int find_beneficial_column_in_row_tableau_rows(int i, T &a_ent) { - if (m_bland_mode_tableau) - return find_beneficial_column_in_row_tableau_rows_bland_mode(i, a_ent); - // a short row produces short infeasibility explanation and benefits at - // least one pivot operation - int choice = -1; - int nchoices = 0; - unsigned num_of_non_free_basics = 1000000; - unsigned len = 100000000; - unsigned bj = this->m_basis[i]; - bool bj_needs_to_grow = needs_to_grow(bj); - for (unsigned k = 0; k < this->m_A.m_rows[i].size(); k++) { - const row_cell &rc = this->m_A.m_rows[i][k]; - unsigned j = rc.var(); - if (j == bj) - continue; - if (bj_needs_to_grow) { - if (!monoid_can_decrease(rc)) - continue; - } else { - if (!monoid_can_increase(rc)) - continue; - } - unsigned damage = get_number_of_basic_vars_that_might_become_inf(j); - if (damage < num_of_non_free_basics) { - num_of_non_free_basics = damage; - len = this->m_A.m_columns[j].size(); - choice = k; - nchoices = 1; - } else if (damage == num_of_non_free_basics && - this->m_A.m_columns[j].size() <= len && - (this->m_settings.random_next() % (++nchoices))) { - choice = k; - len = this->m_A.m_columns[j].size(); - } + void limit_theta_on_basis_column_for_inf_case_m_neg_upper_bound( + unsigned j, const T &m, X &theta, bool &unlimited) { + lp_assert(m < 0 && this->m_column_types[j] == column_type::upper_bound); + limit_inf_on_upper_bound_m_neg(m, this->m_x[j], this->m_upper_bounds[j], theta, unlimited); } - if (choice == -1) { - m_inf_row_index_for_tableau = i; - return -1; + void limit_theta_on_basis_column_for_inf_case_m_neg_lower_bound( + unsigned j, const T &m, X &theta, bool &unlimited) { + lp_assert(m < 0 && this->m_column_types[j] == column_type::lower_bound); + limit_inf_on_bound_m_neg(m, this->m_x[j], this->m_lower_bounds[j], theta, unlimited); } - const row_cell &rc = this->m_A.m_rows[i][choice]; - a_ent = rc.coeff(); - return rc.var(); - } - - bool try_jump_to_another_bound_on_entering(unsigned entering, const X &theta, - X &t, bool &unlimited); - bool try_jump_to_another_bound_on_entering_unlimited(unsigned entering, X &t); - int find_leaving_and_t_tableau(unsigned entering, X &t); - - void limit_theta(const X &lim, X &theta, bool &unlimited) { - if (unlimited) { - theta = lim; - unlimited = false; - } else { - theta = std::min(lim, theta); + + void limit_theta_on_basis_column_for_inf_case_m_pos_lower_bound( + unsigned j, const T &m, X &theta, bool &unlimited) { + lp_assert(m > 0 && this->m_column_types[j] == column_type::lower_bound); + limit_inf_on_lower_bound_m_pos(m, this->m_x[j], this->m_lower_bounds[j], theta, unlimited); } - } - - void limit_theta_on_basis_column_for_inf_case_m_neg_upper_bound( - unsigned j, const T &m, X &theta, bool &unlimited) { - lp_assert(m < 0 && this->m_column_types[j] == column_type::upper_bound); - limit_inf_on_upper_bound_m_neg(m, this->m_x[j], this->m_upper_bounds[j], - theta, unlimited); - } - - void limit_theta_on_basis_column_for_inf_case_m_neg_lower_bound( - unsigned j, const T &m, X &theta, bool &unlimited) { - lp_assert(m < 0 && this->m_column_types[j] == column_type::lower_bound); - limit_inf_on_bound_m_neg(m, this->m_x[j], this->m_lower_bounds[j], theta, - unlimited); - } - - void limit_theta_on_basis_column_for_inf_case_m_pos_lower_bound( - unsigned j, const T &m, X &theta, bool &unlimited) { - lp_assert(m > 0 && this->m_column_types[j] == column_type::lower_bound); - limit_inf_on_lower_bound_m_pos(m, this->m_x[j], this->m_lower_bounds[j], - theta, unlimited); - } - - void limit_theta_on_basis_column_for_inf_case_m_pos_upper_bound( - unsigned j, const T &m, X &theta, bool &unlimited) { - lp_assert(m > 0 && this->m_column_types[j] == column_type::upper_bound); - limit_inf_on_bound_m_pos(m, this->m_x[j], this->m_upper_bounds[j], theta, - unlimited); - }; - - void get_bound_on_variable_and_update_leaving_precisely( - unsigned j, vector &leavings, T m, X &t, - T &abs_of_d_of_leaving); - - X get_max_bound(vector &b); + + void limit_theta_on_basis_column_for_inf_case_m_pos_upper_bound( + unsigned j, const T &m, X &theta, bool &unlimited) { + lp_assert(m > 0 && this->m_column_types[j] == column_type::upper_bound); + limit_inf_on_bound_m_pos(m, this->m_x[j], this->m_upper_bounds[j], theta, unlimited); + }; + + void get_bound_on_variable_and_update_leaving_precisely( + unsigned j, vector &leavings, T m, X &t, + T &abs_of_d_of_leaving); #ifdef Z3DEBUG - void check_Ax_equal_b(); - void check_the_bounds(); - void check_bound(unsigned i); - void check_correctness(); + void check_Ax_equal_b(); + void check_the_bounds(); + void check_bound(unsigned i); + void check_correctness(); #endif - // from page 183 of Istvan Maros's book - // the basis structures have not changed yet - void update_reduced_costs_from_pivot_row(unsigned entering, unsigned leaving); - - // return 0 if the reduced cost at entering is close enough to the refreshed - // 1 if it is way off, and 2 if it is unprofitable - int refresh_reduced_cost_at_entering_and_check_that_it_is_off( - unsigned entering); - - void backup_and_normalize_costs(); - - void advance_on_entering_and_leaving_tableau(int entering, int leaving, X &t); - void advance_on_entering_equal_leaving_tableau(int entering, X &t); - - bool need_to_switch_costs() const { - if (this->m_settings.simplex_strategy() == - simplex_strategy_enum::tableau_rows) - return false; - // lp_assert(calc_current_x_is_feasible() == - // current_x_is_feasible()); - return this->current_x_is_feasible() == this->using_infeas_costs(); - } - - void advance_on_entering_tableau(int entering); - - void push_forward_offset_in_non_basis(unsigned &offset_in_nb); - - unsigned get_number_of_non_basic_column_to_try_for_enter(); - - // returns the number of iterations - unsigned solve(); - - void find_feasible_solution(); - - // bool is_tiny() const {return this->m_m < 10 && this->m_n < 20;} - - void one_iteration_tableau(); - - // this version assumes that the leaving already has the right value, and does - // not update it - void update_x_tableau_rows(unsigned entering, unsigned leaving, - const X &delta) { - this->add_delta_to_x(entering, delta); - for (const auto &c : this->m_A.m_columns[entering]) { - if (leaving != this->m_basis[c.var()]) { - this->add_delta_to_x_and_track_feasibility( - this->m_basis[c.var()], -delta * this->m_A.get_val(c)); - } - } - } - - void update_basis_and_x_tableau_rows(int entering, int leaving, X const &tt) { - lp_assert(entering != leaving); - update_x_tableau_rows(entering, leaving, tt); - this->pivot_column_tableau(entering, this->m_basis_heading[leaving]); - this->change_basis(entering, leaving); - } - - void advance_on_entering_and_leaving_tableau_rows(int entering, int leaving, - const X &theta) { - update_basis_and_x_tableau_rows(entering, leaving, theta); - this->track_column_feasibility(entering); - } - - int find_smallest_inf_column() { - int j = -1; - for (unsigned k : this->inf_set()) { - if (k < static_cast(j)) { - j = k; - } + void backup_and_normalize_costs(); + + void advance_on_entering_and_leaving_tableau(int entering, int leaving, X &t); + void advance_on_entering_equal_leaving_tableau(int entering, X &t); + + void pivot(int entering, int leaving) { + this->pivot_column_tableau(entering, this->m_basis_heading[leaving]); + this->change_basis(entering, leaving); } - return j; - } - - const X &get_val_for_leaving(unsigned j) const { - lp_assert(!this->column_is_feasible(j)); - switch (this->m_column_types[j]) { - case column_type::fixed: - case column_type::upper_bound: - return this->m_upper_bounds[j]; - case column_type::lower_bound: - return this->m_lower_bounds[j]; - break; - case column_type::boxed: - if (this->x_above_upper_bound(j)) - return this->m_upper_bounds[j]; - else - return this->m_lower_bounds[j]; - break; - default: - UNREACHABLE(); - return this->m_lower_bounds[j]; + + bool need_to_switch_costs() const { + if (this->m_settings.simplex_strategy() == + simplex_strategy_enum::tableau_rows) + return false; + // lp_assert(calc_current_x_is_feasible() == + // current_x_is_feasible()); + return this->current_x_is_feasible() == this->using_infeas_costs(); } - } - void one_iteration_tableau_rows() { - int leaving = find_smallest_inf_column(); - if (leaving == -1) { - this->set_status(lp_status::OPTIMAL); - return; + void advance_on_entering_tableau(int entering); + + void push_forward_offset_in_non_basis(unsigned &offset_in_nb); + + unsigned get_number_of_non_basic_column_to_try_for_enter(); + + // returns the number of iterations + unsigned solve(); + + void find_feasible_solution(); + + // bool is_tiny() const {return this->m_m < 10 && this->m_n < 20;} + + void one_iteration_tableau(); + + // this version assumes that the leaving already has the right value, and does + // not update it + void update_x_tableau_rows(unsigned entering, unsigned leaving, + const X &delta) { + this->add_delta_to_x(entering, delta); + for (const auto &c : this->m_A.m_columns[entering]) + if (leaving != this->m_basis[c.var()]) + this->add_delta_to_x_and_track_feasibility( + this->m_basis[c.var()], -delta * this->m_A.get_val(c)); } - SASSERT(this->column_is_base(leaving)); + void update_basis_and_x_tableau_rows(int entering, int leaving, X const &tt) { + lp_assert(entering != leaving); + update_x_tableau_rows(entering, leaving, tt); + this->pivot_column_tableau(entering, this->m_basis_heading[leaving]); + this->change_basis(entering, leaving); + } - if (!m_bland_mode_tableau) { - if (m_left_basis_tableau.contains(leaving)) { - if (++m_left_basis_repeated > m_bland_mode_threshold) { - m_bland_mode_tableau = true; - } - } else { - m_left_basis_tableau.insert(leaving); - } + void advance_on_entering_and_leaving_tableau_rows(int entering, int leaving, + const X &theta) { + update_basis_and_x_tableau_rows(entering, leaving, theta); + this->track_column_feasibility(entering); } - T a_ent; - int entering = find_beneficial_column_in_row_tableau_rows( - this->m_basis_heading[leaving], a_ent); - if (entering == -1) { - this->set_status(lp_status::INFEASIBLE); - return; + + int find_smallest_inf_column() { + if (this->inf_heap().empty()) + return -1; + + return this->inf_heap().min_value(); } - const X &new_val_for_leaving = get_val_for_leaving(leaving); - X theta = (this->m_x[leaving] - new_val_for_leaving) / a_ent; - this->m_x[leaving] = new_val_for_leaving; - this->remove_column_from_inf_set(leaving); - advance_on_entering_and_leaving_tableau_rows(entering, leaving, theta); - if (this->current_x_is_feasible()) - this->set_status(lp_status::OPTIMAL); - } - - void decide_on_status_when_cannot_find_entering() { - this->set_status(this->current_x_is_feasible() ? lp_status::OPTIMAL - : lp_status::INFEASIBLE); - } - - void limit_theta_on_basis_column_for_feas_case_m_neg_no_check( - unsigned j, const T &m, X &theta, bool &unlimited) { - lp_assert(m < 0); - limit_theta((this->m_lower_bounds[j] - this->m_x[j]) / m, theta, unlimited); - if (theta < zero_of_type()) - theta = zero_of_type(); - } - - bool limit_inf_on_bound_m_neg(const T &m, const X &x, const X &bound, - X &theta, bool &unlimited) { - // x gets smaller - lp_assert(m < 0); - if (this->below_bound(x, bound)) - return false; - if (this->above_bound(x, bound)) { - limit_theta((bound - x) / m, theta, unlimited); - } else { - theta = zero_of_type(); - unlimited = false; + + const X &get_val_for_leaving(unsigned j) const { + lp_assert(!this->column_is_feasible(j)); + switch (this->m_column_types[j]) { + case column_type::fixed: + case column_type::upper_bound: + return this->m_upper_bounds[j]; + case column_type::lower_bound: + return this->m_lower_bounds[j]; + break; + case column_type::boxed: + if (this->x_above_upper_bound(j)) + return this->m_upper_bounds[j]; + else + return this->m_lower_bounds[j]; + break; + default: + UNREACHABLE(); + return this->m_lower_bounds[j]; + } } - return true; - } - - bool limit_inf_on_bound_m_pos(const T &m, const X &x, const X &bound, - X &theta, bool &unlimited) { - // x gets larger - lp_assert(m > 0); - if (this->above_bound(x, bound)) - return false; - if (this->below_bound(x, bound)) { - limit_theta((bound - x) / m, theta, unlimited); - } else { - theta = zero_of_type(); - unlimited = false; + + void one_iteration_tableau_rows() { + int leaving = find_smallest_inf_column(); + if (leaving == -1) { + this->set_status(lp_status::OPTIMAL); + return; + } + + SASSERT(this->column_is_base(leaving)); + + if (!m_bland_mode_tableau) { + if (m_left_basis_tableau.contains(leaving)) { + if (++m_left_basis_repeated > m_bland_mode_threshold) { + m_bland_mode_tableau = true; + } + } else { + m_left_basis_tableau.insert(leaving); + } + } + T a_ent; + int entering = find_beneficial_entering_tableau_rows( + this->m_basis_heading[leaving], a_ent); + if (entering == -1) { + this->set_status(lp_status::INFEASIBLE); + return; + } + const X &new_val_for_leaving = get_val_for_leaving(leaving); + X theta = (this->m_x[leaving] - new_val_for_leaving) / a_ent; + this->m_x[leaving] = new_val_for_leaving; + TRACE("lar_solver_feas", tout << "entering = " << entering << ", leaving = " << leaving << ", new_val_for_leaving = " << new_val_for_leaving << ", theta = " << theta << "\n";); + TRACE("lar_solver_feas", tout << "leaving = " << leaving + << " removed from inf_heap()\n";); + // this will remove the leaving from the heap + this->inf_heap().erase_min(); + advance_on_entering_and_leaving_tableau_rows(entering, leaving, theta); + if (this->current_x_is_feasible()) + this->set_status(lp_status::OPTIMAL); } - return true; - } + void decide_on_status_when_cannot_find_entering() { + this->set_status(this->current_x_is_feasible() ? lp_status::OPTIMAL + : lp_status::INFEASIBLE); + } - void limit_inf_on_lower_bound_m_pos(const T &m, const X &x, const X &bound, - X &theta, bool &unlimited) { - // x gets larger - lp_assert(m > 0); - if (this->below_bound(x, bound)) { - limit_theta((bound - x) / m, theta, unlimited); + void limit_theta_on_basis_column_for_feas_case_m_neg_no_check( + unsigned j, const T &m, X &theta, bool &unlimited) { + lp_assert(m < 0); + limit_theta((this->m_lower_bounds[j] - this->m_x[j]) / m, theta, unlimited); + if (theta < zero_of_type()) + theta = zero_of_type(); } - } - - void limit_inf_on_upper_bound_m_neg(const T &m, const X &x, const X &bound, - X &theta, bool &unlimited) { - // x gets smaller - lp_assert(m < 0); - if (this->above_bound(x, bound)) { - limit_theta((bound - x) / m, theta, unlimited); + + bool limit_inf_on_bound_m_neg(const T &m, const X &x, const X &bound, + X &theta, bool &unlimited) { + // x gets smaller + lp_assert(m < 0); + if (this->below_bound(x, bound)) + return false; + if (this->above_bound(x, bound)) { + limit_theta((bound - x) / m, theta, unlimited); + } else { + theta = zero_of_type(); + unlimited = false; + } + return true; } - } - - void limit_theta_on_basis_column_for_inf_case_m_pos_boxed(unsigned j, - const T &m, - X &theta, - bool &unlimited) { - const X &x = this->m_x[j]; - const X &lbound = this->m_lower_bounds[j]; - - if (this->below_bound(x, lbound)) { - limit_theta((lbound - x) / m, theta, unlimited); - } else { - const X &ubound = this->m_upper_bounds[j]; - if (this->below_bound(x, ubound)) { - limit_theta((ubound - x) / m, theta, unlimited); - } else if (!this->above_bound(x, ubound)) { - theta = zero_of_type(); - unlimited = false; - } + + bool limit_inf_on_bound_m_pos(const T &m, const X &x, const X &bound, + X &theta, bool &unlimited) { + // x gets larger + lp_assert(m > 0); + if (this->above_bound(x, bound)) + return false; + if (this->below_bound(x, bound)) { + limit_theta((bound - x) / m, theta, unlimited); + } else { + theta = zero_of_type(); + unlimited = false; + } + + return true; } - } - - void limit_theta_on_basis_column_for_inf_case_m_neg_boxed(unsigned j, - const T &m, - X &theta, - bool &unlimited) { - // lp_assert(m < 0 && this->m_column_type[j] == column_type::boxed); - const X &x = this->m_x[j]; - const X &ubound = this->m_upper_bounds[j]; - if (this->above_bound(x, ubound)) { - limit_theta((ubound - x) / m, theta, unlimited); - } else { - const X &lbound = this->m_lower_bounds[j]; - if (this->above_bound(x, lbound)) { - limit_theta((lbound - x) / m, theta, unlimited); - } else if (!this->below_bound(x, lbound)) { - theta = zero_of_type(); - unlimited = false; - } + + void limit_inf_on_lower_bound_m_pos(const T &m, const X &x, const X &bound, + X &theta, bool &unlimited) { + // x gets larger + lp_assert(m > 0); + if (this->below_bound(x, bound)) { + limit_theta((bound - x) / m, theta, unlimited); + } } - } - - void limit_theta_on_basis_column_for_feas_case_m_pos_no_check( - unsigned j, const T &m, X &theta, bool &unlimited) { - lp_assert(m > 0); - limit_theta((this->m_upper_bounds[j] - this->m_x[j]) / m, theta, unlimited); - if (theta < zero_of_type()) { - theta = zero_of_type(); + + void limit_inf_on_upper_bound_m_neg(const T &m, const X &x, const X &bound, + X &theta, bool &unlimited) { + // x gets smaller + lp_assert(m < 0); + if (this->above_bound(x, bound)) { + limit_theta((bound - x) / m, theta, unlimited); + } } - } - - // j is a basic column or the entering, in any case x[j] has to stay feasible. - // m is the multiplier. updating t in a way that holds the following - // x[j] + t * m >= this->m_lower_bounds[j]( if m < 0 ) - // or - // x[j] + t * m <= this->m_upper_bounds[j] ( if m > 0) - void limit_theta_on_basis_column(unsigned j, T m, X &theta, bool &unlimited) { - switch (this->m_column_types[j]) { - case column_type::free_column: - break; - case column_type::upper_bound: - if (this->current_x_is_feasible()) { - if (m > 0) - limit_theta_on_basis_column_for_feas_case_m_pos_no_check(j, m, theta, - unlimited); - } else { // inside of feasibility_loop - if (m > 0) - limit_theta_on_basis_column_for_inf_case_m_pos_upper_bound( - j, m, theta, unlimited); - else - limit_theta_on_basis_column_for_inf_case_m_neg_upper_bound( - j, m, theta, unlimited); - } - break; - case column_type::lower_bound: - if (this->current_x_is_feasible()) { - if (m < 0) - limit_theta_on_basis_column_for_feas_case_m_neg_no_check(j, m, theta, - unlimited); - } else { - if (m < 0) - limit_theta_on_basis_column_for_inf_case_m_neg_lower_bound( - j, m, theta, unlimited); - else - limit_theta_on_basis_column_for_inf_case_m_pos_lower_bound( - j, m, theta, unlimited); - } - break; - // case fixed: - // if (get_this->current_x_is_feasible()) { - // theta = zero_of_type(); - // break; - // } - // if (m < 0) - // limit_theta_on_basis_column_for_inf_case_m_neg_fixed(j, m, - // theta); - // else - // limit_theta_on_basis_column_for_inf_case_m_pos_fixed(j, m, - // theta); - // break; - case column_type::fixed: - case column_type::boxed: - if (this->current_x_is_feasible()) { - if (m > 0) { - limit_theta_on_basis_column_for_feas_case_m_pos_no_check(j, m, theta, - unlimited); + + void limit_theta_on_basis_column_for_inf_case_m_pos_boxed(unsigned j, + const T &m, + X &theta, + bool &unlimited) { + const X &x = this->m_x[j]; + const X &lbound = this->m_lower_bounds[j]; + + if (this->below_bound(x, lbound)) { + limit_theta((lbound - x) / m, theta, unlimited); } else { - limit_theta_on_basis_column_for_feas_case_m_neg_no_check(j, m, theta, - unlimited); + const X &ubound = this->m_upper_bounds[j]; + if (this->below_bound(x, ubound)) { + limit_theta((ubound - x) / m, theta, unlimited); + } else if (!this->above_bound(x, ubound)) { + theta = zero_of_type(); + unlimited = false; + } } - } else { - if (m > 0) { - limit_theta_on_basis_column_for_inf_case_m_pos_boxed(j, m, theta, - unlimited); + } + + void limit_theta_on_basis_column_for_inf_case_m_neg_boxed(unsigned j, + const T &m, + X &theta, + bool &unlimited) { + // lp_assert(m < 0 && this->m_column_type[j] == column_type::boxed); + const X &x = this->m_x[j]; + const X &ubound = this->m_upper_bounds[j]; + if (this->above_bound(x, ubound)) { + limit_theta((ubound - x) / m, theta, unlimited); } else { - limit_theta_on_basis_column_for_inf_case_m_neg_boxed(j, m, theta, - unlimited); + const X &lbound = this->m_lower_bounds[j]; + if (this->above_bound(x, lbound)) { + limit_theta((lbound - x) / m, theta, unlimited); + } else if (!this->below_bound(x, lbound)) { + theta = zero_of_type(); + unlimited = false; + } + } + } + + void limit_theta_on_basis_column_for_feas_case_m_pos_no_check( + unsigned j, const T &m, X &theta, bool &unlimited) { + lp_assert(m > 0); + limit_theta((this->m_upper_bounds[j] - this->m_x[j]) / m, theta, unlimited); + if (theta < zero_of_type()) { + theta = zero_of_type(); + } + } + + // j is a basic column or the entering, in any case x[j] has to stay feasible. + // m is the multiplier. updating t in a way that holds the following + // x[j] + t * m >= this->m_lower_bounds[j]( if m < 0 ) + // or + // x[j] + t * m <= this->m_upper_bounds[j] ( if m > 0) + void limit_theta_on_basis_column(unsigned j, T m, X &theta, bool &unlimited) { + switch (this->m_column_types[j]) { + case column_type::free_column: + break; + case column_type::upper_bound: + if (this->current_x_is_feasible()) { + if (m > 0) + limit_theta_on_basis_column_for_feas_case_m_pos_no_check(j, m, theta, + unlimited); + } else { // inside of feasibility_loop + if (m > 0) + limit_theta_on_basis_column_for_inf_case_m_pos_upper_bound( + j, m, theta, unlimited); + else + limit_theta_on_basis_column_for_inf_case_m_neg_upper_bound( + j, m, theta, unlimited); + } + break; + case column_type::lower_bound: + if (this->current_x_is_feasible()) { + if (m < 0) + limit_theta_on_basis_column_for_feas_case_m_neg_no_check(j, m, theta, + unlimited); + } else { + if (m < 0) + limit_theta_on_basis_column_for_inf_case_m_neg_lower_bound( + j, m, theta, unlimited); + else + limit_theta_on_basis_column_for_inf_case_m_pos_lower_bound( + j, m, theta, unlimited); + } + break; + // case fixed: + // if (get_this->current_x_is_feasible()) { + // theta = zero_of_type(); + // break; + // } + // if (m < 0) + // limit_theta_on_basis_column_for_inf_case_m_neg_fixed(j, m, + // theta); + // else + // limit_theta_on_basis_column_for_inf_case_m_pos_fixed(j, m, + // theta); + // break; + case column_type::fixed: + case column_type::boxed: + if (this->current_x_is_feasible()) { + if (m > 0) { + limit_theta_on_basis_column_for_feas_case_m_pos_no_check(j, m, theta, + unlimited); + } else { + limit_theta_on_basis_column_for_feas_case_m_neg_no_check(j, m, theta, + unlimited); + } + } else { + if (m > 0) { + limit_theta_on_basis_column_for_inf_case_m_pos_boxed(j, m, theta, + unlimited); + } else { + limit_theta_on_basis_column_for_inf_case_m_neg_boxed(j, m, theta, + unlimited); + } + } + + break; + default: + UNREACHABLE(); + } + if (!unlimited && theta < zero_of_type()) { + theta = zero_of_type(); } - } + } + bool correctly_moved_to_bounds(lpvar) const; + bool column_is_benefitial_for_entering_basis(unsigned j) const; + void init_infeasibility_costs(); + void print_column(unsigned j, std::ostream &out); + + void print_bound_info_and_x(unsigned j, std::ostream &out); + + bool basis_column_is_set_correctly(unsigned j) const { + return this->m_A.m_columns[j].size() == 1; + } + + bool basis_columns_are_set_correctly() const { + for (unsigned j : this->m_basis) + if (!basis_column_is_set_correctly(j)) + return false; - break; - default: - UNREACHABLE(); + return this->m_basis_heading.size() == this->m_A.column_count() && + this->m_basis.size() == this->m_A.row_count(); } - if (!unlimited && theta < zero_of_type()) { - theta = zero_of_type(); + + void init_run_tableau(); + void update_x_tableau(unsigned entering, const X &delta); + // the delta is between the old and the new cost (old - new) + void update_reduced_cost_for_basic_column_cost_change(const T &delta, + unsigned j) { + lp_assert(this->m_basis_heading[j] >= 0); + unsigned i = static_cast(this->m_basis_heading[j]); + for (const row_cell &rc : this->m_A.m_rows[i]) { + unsigned k = rc.var(); + if (k == j) + continue; + this->m_d[k] += delta * rc.coeff(); + } } - } - - bool column_is_benefitial_for_entering_basis(unsigned j) const; - void init_infeasibility_costs(); - void print_column(unsigned j, std::ostream &out); - - void print_bound_info_and_x(unsigned j, std::ostream &out); - - bool basis_column_is_set_correctly(unsigned j) const { - return this->m_A.m_columns[j].size() == 1; - } - - bool basis_columns_are_set_correctly() const { - for (unsigned j : this->m_basis) - if (!basis_column_is_set_correctly(j)) - return false; - - return this->m_basis_heading.size() == this->m_A.column_count() && - this->m_basis.size() == this->m_A.row_count(); - } - - void init_run_tableau(); - void update_x_tableau(unsigned entering, const X &delta); - // the delta is between the old and the new cost (old - new) - void update_reduced_cost_for_basic_column_cost_change(const T &delta, - unsigned j) { - lp_assert(this->m_basis_heading[j] >= 0); - unsigned i = static_cast(this->m_basis_heading[j]); - for (const row_cell &rc : this->m_A.m_rows[i]) { - unsigned k = rc.var(); - if (k == j) - continue; - this->m_d[k] += delta * rc.coeff(); + + bool update_basis_and_x_tableau(int entering, int leaving, X const &tt); + void init_reduced_costs_tableau(); + void init_tableau_rows() { + m_bland_mode_tableau = false; + m_left_basis_tableau.reset(); + m_left_basis_repeated = 0; } - } - - bool update_basis_and_x_tableau(int entering, int leaving, X const &tt); - void init_reduced_costs_tableau(); - void init_tableau_rows() { - m_bland_mode_tableau = false; - m_left_basis_tableau.clear(); - m_left_basis_tableau.resize(this->m_A.column_count()); - m_left_basis_repeated = 0; - } - // stage1 constructor - lp_primal_core_solver( - static_matrix &A, - vector &b, // the right side vector - vector &x, // the number of elements in x needs to be at least as large - // as the number of columns in A - vector &basis, vector &nbasis, vector &heading, - vector &costs, const vector &column_type_array, - const vector &lower_bound_values, const vector &upper_bound_values, - lp_settings &settings, const column_namer &column_names) - : lp_core_solver_base(A, // b, - basis, nbasis, heading, x, costs, settings, - column_names, column_type_array, - lower_bound_values, upper_bound_values), - m_bland_mode_threshold(1000) { - this->set_status(lp_status::UNKNOWN); - } - - friend core_solver_pretty_printer; + // stage1 constructor + lp_primal_core_solver( + static_matrix &A, + vector &b, // the right side vector + vector &x, // the number of elements in x needs to be at least as large + // as the number of columns in A + vector &basis, vector &nbasis, vector &heading, + vector &costs, const vector &column_type_array, + const vector &lower_bound_values, const vector &upper_bound_values, + lp_settings &settings, const column_namer &column_names) + : lp_core_solver_base(A, // b, + basis, nbasis, heading, x, costs, settings, + column_names, column_type_array, + lower_bound_values, upper_bound_values), + m_bland_mode_threshold(1000) { + this->set_status(lp_status::UNKNOWN); + } + + friend core_solver_pretty_printer; }; -} // namespace lp +} // namespace lp diff --git a/src/math/lp/lp_primal_core_solver_def.h b/src/math/lp/lp_primal_core_solver_def.h index c3c545fdd8a..f14d268a35c 100644 --- a/src/math/lp/lp_primal_core_solver_def.h +++ b/src/math/lp/lp_primal_core_solver_def.h @@ -37,51 +37,60 @@ void lp_primal_core_solver::sort_non_basis() { unsigned ca = this->m_A.number_of_non_zeroes_in_column(a); unsigned cb = this->m_A.number_of_non_zeroes_in_column(b); if (ca == 0 && cb != 0) return false; + if (ca != 0 && cb == 0) return true; return ca < cb; }); - - m_non_basis_list.clear(); - // reinit m_basis_heading - for (unsigned j = 0; j < this->m_nbasis.size(); j++) { - unsigned col = this->m_nbasis[j]; - this->m_basis_heading[col] = - static_cast(j) - 1; - m_non_basis_list.push_back(col); + m_non_basis_list.resize(this->m_nbasis.size()); + // initialize m_non_basis_list from m_nbasis by using an iterator on m_non_basis_list + auto it = m_non_basis_list.begin(); + unsigned j = 0; + for (; j < this->m_nbasis.size(); j++, ++it) { + unsigned col = *it = this->m_nbasis[j]; + this->m_basis_heading[col] = -static_cast(j) - 1; } } +template +bool lp_primal_core_solver::correctly_moved_to_bounds(unsigned j) const { + switch (this->m_column_types[j]) { + case column_type::fixed: + return this->m_x[j] == this->m_lower_bounds[j]; + case column_type::boxed: + return this->m_x[j] == this->m_lower_bounds[j] || this->m_x[j] == this->m_upper_bounds[j]; + case column_type::lower_bound: + return this->m_x[j] == this->m_lower_bounds[j]; + case column_type::upper_bound: + return this->m_x[j] == this->m_upper_bounds[j]; + case column_type::free_column: + return true; + default: + UNREACHABLE(); + return false; + } +} template bool lp_primal_core_solver::column_is_benefitial_for_entering_basis(unsigned j) const { const T& dj = this->m_d[j]; - TRACE("lar_solver", tout << "dj=" << dj << "\n";); + if (dj.is_zero()) return false; + TRACE("lar_solver", tout << "d[" << j <<"] = " << dj << "\n";); + SASSERT(correctly_moved_to_bounds(j)); switch (this->m_column_types[j]) { case column_type::fixed: break; case column_type::free_column: - if (!is_zero(dj)) - return true; - break; + return true; case column_type::lower_bound: if (dj > zero_of_type()) return true; - if (dj < 0 && this->m_x[j] > this->m_lower_bounds[j]){ - return true; - } break; case column_type::upper_bound: if (dj < zero_of_type()) return true; - if (dj > 0 && this->m_x[j] < this->m_upper_bounds[j]) { - return true; - } break; case column_type::boxed: - if (dj > zero_of_type()) { - if (this->m_x[j] < this->m_upper_bounds[j]) - return true; - break; - } else if (dj < zero_of_type()) { - if (this->m_x[j] > this->m_lower_bounds[j]) - return true; - } + if (dj > zero_of_type() && this->m_x[j] == this->m_lower_bounds[j]) + return true; + if (dj < zero_of_type() && this->m_x[j] == this->m_upper_bounds[j]) + return true; break; default: UNREACHABLE(); @@ -89,48 +98,17 @@ bool lp_primal_core_solver::column_is_benefitial_for_entering_basis(unsign } return false; } - -template bool lp_primal_core_solver::try_jump_to_another_bound_on_entering(unsigned entering, - const X & theta, - X & t, - bool & unlimited) { - switch(this->m_column_types[entering]){ - case column_type::boxed: - if (m_sign_of_entering_delta > 0) { - t = this->m_upper_bounds[entering] - this->m_x[entering]; - if (unlimited || t <= theta){ - lp_assert(t >= zero_of_type()); - return true; - } - } else { // m_sign_of_entering_delta == -1 - t = this->m_x[entering] - this->m_lower_bounds[entering]; - if (unlimited || t <= theta) { - lp_assert(t >= zero_of_type()); - return true; - } - } - return false; - case column_type::upper_bound: - if (m_sign_of_entering_delta > 0) { - t = this->m_upper_bounds[entering] - this->m_x[entering]; - if (unlimited || t <= theta){ - lp_assert(t >= zero_of_type()); - return true; - } - } - return false; - case column_type::lower_bound: - if (m_sign_of_entering_delta < 0) { - t = this->m_x[entering] - this->m_lower_bounds[entering]; - if (unlimited || t <= theta) { - lp_assert(t >= zero_of_type()); - return true; - } - } +// we assume that the columns are at their bounds +template bool lp_primal_core_solver::try_jump_to_another_bound_on_entering(unsigned entering, X & theta) { + if (this->m_column_types[entering] != column_type::boxed) return false; - default:return false; + X t = this->m_upper_bounds[entering] - this->m_lower_bounds[entering]; + if (t <= theta) { + theta = t; + return true; } return false; + } template bool lp_primal_core_solver:: diff --git a/src/math/lp/lp_primal_core_solver_tableau_def.h b/src/math/lp/lp_primal_core_solver_tableau_def.h index 898abb1525b..cec0e316b4c 100644 --- a/src/math/lp/lp_primal_core_solver_tableau_def.h +++ b/src/math/lp/lp_primal_core_solver_tableau_def.h @@ -30,7 +30,7 @@ template void lp_primal_core_solver::one_iteratio else { advance_on_entering_tableau(entering); } - lp_assert(this->inf_set_is_correct()); + lp_assert(this->inf_heap_is_correct()); } template void lp_primal_core_solver::advance_on_entering_tableau(int entering) { @@ -44,21 +44,24 @@ template void lp_primal_core_solver::advance_on_e advance_on_entering_and_leaving_tableau(entering, leaving, t); } + template int lp_primal_core_solver::choose_entering_column_tableau() { //this moment m_y = cB * B(-1) - unsigned number_of_benefitial_columns_to_go_over = get_number_of_non_basic_column_to_try_for_enter(); - - if (number_of_benefitial_columns_to_go_over == 0) - return -1; - if (this->m_basis_sort_counter == 0) { + if (this->m_nbasis_sort_counter == 0) { sort_non_basis(); - this->m_basis_sort_counter = 20; + this->m_nbasis_sort_counter = 20; } else { - this->m_basis_sort_counter--; + this->m_nbasis_sort_counter--; } + unsigned number_of_benefitial_columns_to_go_over = get_number_of_non_basic_column_to_try_for_enter(); + + if (number_of_benefitial_columns_to_go_over == 0) + return -1; + unsigned j_nz = this->m_m() + 1; // this number is greater than the max column size std::list::iterator entering_iter = m_non_basis_list.end(); + unsigned n = 0; for (auto non_basis_iter = m_non_basis_list.begin(); number_of_benefitial_columns_to_go_over && non_basis_iter != m_non_basis_list.end(); ++non_basis_iter) { unsigned j = *non_basis_iter; if (!column_is_benefitial_for_entering_basis(j)) @@ -69,13 +72,13 @@ template void lp_primal_core_solver::advance_on_e if (t < j_nz) { j_nz = t; entering_iter = non_basis_iter; - if (number_of_benefitial_columns_to_go_over) - number_of_benefitial_columns_to_go_over--; + number_of_benefitial_columns_to_go_over--; + n = 1; } - else if (t == j_nz && this->m_settings.random_next() % 2 == 0) { + else if (t == j_nz && this->m_settings.random_next(++n) == 0) { entering_iter = non_basis_iter; } - }// while (number_of_benefitial_columns_to_go_over && initial_offset_in_non_basis != offset_in_nb); + } if (entering_iter == m_non_basis_list.end()) return -1; unsigned entering = *entering_iter; @@ -96,7 +99,8 @@ unsigned lp_primal_core_solver::solve() { } do { - if (this->print_statistics_with_iterations_and_nonzeroes_and_cost_and_check_that_the_time_is_over( "feas t", * this->m_settings.get_message_ostream())) { + if (this->m_settings.get_cancel_flag()) { + this->set_status(lp_status::CANCELLED); return this->total_iterations(); } if (this->m_settings.use_tableau_rows()) { @@ -166,7 +170,8 @@ template void lp_primal_core_solver::advance_on_en } this->update_basis_and_x_tableau(entering, leaving, t); this->iters_with_no_cost_growing() = 0; - } else { + } + else { this->pivot_column_tableau(entering, this->m_basis_heading[leaving]); this->change_basis(entering, leaving); } @@ -240,10 +245,7 @@ template int lp_primal_core_solver::find_leaving_ } } - ratio = t; - unlimited = false; - if (try_jump_to_another_bound_on_entering(entering, t, ratio, unlimited)) { - t = ratio; + if (try_jump_to_another_bound_on_entering(entering, t)) { return entering; } if (m_leaving_candidates.size() == 1) @@ -253,10 +255,8 @@ template int lp_primal_core_solver::find_leaving_ } template void lp_primal_core_solver::init_run_tableau() { lp_assert(basis_columns_are_set_correctly()); - this->m_basis_sort_counter = 0; // to initiate the sort of the basis - // this->set_total_iterations(0); this->iters_with_no_cost_growing() = 0; - lp_assert(this->inf_set_is_correct()); + lp_assert(this->inf_heap_is_correct()); if (this->current_x_is_feasible() && this->m_look_for_feasible_solution_only) return; if (this->m_settings.backup_costs) diff --git a/src/math/lp/lp_settings.h b/src/math/lp/lp_settings.h index c213333e0e3..08fd365055a 100644 --- a/src/math/lp/lp_settings.h +++ b/src/math/lp/lp_settings.h @@ -53,9 +53,8 @@ inline std::ostream& operator<<(std::ostream& out, column_type const& t) { } enum class simplex_strategy_enum { - undecided = 3, - tableau_rows = 0, - tableau_costs = 1 + tableau_rows, + tableau_costs }; std::string column_type_to_string(column_type t); @@ -115,6 +114,13 @@ struct statistics { unsigned m_hnf_cutter_calls; unsigned m_hnf_cuts; unsigned m_nla_calls; + unsigned m_gomory_cuts; + unsigned m_nla_add_bounds; + unsigned m_nla_propagate_bounds; + unsigned m_nla_propagate_eq; + unsigned m_nla_lemmas; + unsigned m_nra_calls; + unsigned m_nla_bounds_improvements; unsigned m_horner_calls; unsigned m_horner_conflicts; unsigned m_cross_nested_forms; @@ -137,6 +143,7 @@ struct statistics { st.update("arith-patches-success", m_patches_success); st.update("arith-hnf-calls", m_hnf_cutter_calls); st.update("arith-hnf-cuts", m_hnf_cuts); + st.update("arith-gomory-cuts", m_gomory_cuts); st.update("arith-horner-calls", m_horner_calls); st.update("arith-horner-conflicts", m_horner_conflicts); st.update("arith-horner-cross-nested-forms", m_cross_nested_forms); @@ -144,6 +151,12 @@ struct statistics { st.update("arith-grobner-conflicts", m_grobner_conflicts); st.update("arith-offset-eqs", m_offset_eqs); st.update("arith-fixed-eqs", m_fixed_eqs); + st.update("arith-nla-add-bounds", m_nla_add_bounds); + st.update("arith-nla-propagate-bounds", m_nla_propagate_bounds); + st.update("arith-nla-propagate-eq", m_nla_propagate_eq); + st.update("arith-nla-lemmas", m_nla_lemmas); + st.update("arith-nra-calls", m_nra_calls); + st.update("arith-bounds-improvements", m_nla_bounds_improvements); } }; @@ -218,6 +231,8 @@ struct lp_settings { unsigned hnf_cut_period() const { return m_hnf_cut_period; } void set_hnf_cut_period(unsigned period) { m_hnf_cut_period = period; } unsigned random_next() { return m_rand(); } + unsigned random_next(unsigned u ) { return m_rand(u); } + void set_random_seed(unsigned s) { m_rand.set_seed(s); } bool bound_progation() const { @@ -248,7 +263,7 @@ struct lp_settings { // the method of lar solver to use simplex_strategy_enum simplex_strategy() const { return m_simplex_strategy; } - void set_simplex_strategy(simplex_strategy_enum s) { m_simplex_strategy = s; } + simplex_strategy_enum & simplex_strategy() { return m_simplex_strategy; } bool use_tableau_rows() const { return m_simplex_strategy == simplex_strategy_enum::tableau_rows; } #ifdef Z3DEBUG diff --git a/src/math/lp/lp_types.h b/src/math/lp/lp_types.h index 3f9c107c55e..5883495fad3 100644 --- a/src/math/lp/lp_types.h +++ b/src/math/lp/lp_types.h @@ -20,7 +20,9 @@ Revision History: #pragma once #include - +#include +#include "util/debug.h" +#include "util/dependency.h" namespace nla { class core; @@ -28,70 +30,12 @@ namespace nla { namespace lp { -typedef unsigned var_index; typedef unsigned constraint_index; typedef unsigned row_index; enum lconstraint_kind { LE = -2, LT = -1 , GE = 2, GT = 1, EQ = 0, NE = 3 }; typedef unsigned lpvar; const lpvar null_lpvar = UINT_MAX; const constraint_index null_ci = UINT_MAX; - -class column_index { - unsigned m_index; - friend class lar_solver; - friend class lar_term; - friend nla::core; - - operator unsigned() const { return m_index; } - -public: - column_index(unsigned j): m_index(j) {} - unsigned index() const { return m_index; } - bool is_null() const { return m_index == null_lpvar; } -}; - - -// index that comes from term or variable. -class tv { - unsigned m_index; - static const unsigned EF = UINT_MAX >> 1; - tv(unsigned i): m_index(i) {} -public: - static tv term(unsigned i) { SASSERT(0 == (i & left_most_bit)); return tv(mask_term(i)); } - static tv var(unsigned i) { SASSERT(0 == (i & left_most_bit)); return tv(i); } - static tv raw(unsigned i) { return tv(i); } - - // retrieve the identifier associated with tv - unsigned id() const { return unmask_term(m_index); } - column_index column() const { SASSERT(is_var()); return column_index(id()); } - - // retrieve the raw index. - unsigned index() const { return m_index; } - - bool is_term() const { return 0 != (m_index & left_most_bit); } - bool is_var() const { return 0 == (m_index & left_most_bit); } - - // utilities useful where tv isn't already encapsulating id's. - static inline unsigned unmask_term(unsigned j) { return j & EF; } - static inline bool is_term(unsigned j) { return j & left_most_bit; } - static inline unsigned mask_term(unsigned j) { return j | left_most_bit; } - - // used by var_register. could we encapsulate even this? - static const unsigned left_most_bit = ~EF; - - std::string to_string() const { - std::ostringstream strm; - strm << (is_term() ? "t" : "j") << id(); - return strm.str(); - } - - bool is_null() const { return m_index == UINT_MAX; } - -}; - } -inline std::ostream& operator<<(std::ostream& out, lp::tv const& t) { - return out << (t.is_term() ? "t":"j") << t.id() << "\n"; -} diff --git a/src/math/lp/lp_utils.h b/src/math/lp/lp_utils.h index 3c1383cb39e..d6943bf50a4 100644 --- a/src/math/lp/lp_utils.h +++ b/src/math/lp/lp_utils.h @@ -116,12 +116,9 @@ template std::ostream& print_linear_combination_of_column_indices_only(const vector> & coeffs, std::ostream & out) { return print_linear_combination_customized( coeffs, - [](unsigned j) {std::stringstream ss; - if (tv::is_term(j)) { - ss << "t" << tv::unmask_term(j); - } else { - ss << "j" << j; - } + [](unsigned j) { + std::stringstream ss; + ss << "j" << j; return ss.str();}, out); } diff --git a/src/math/lp/monic.h b/src/math/lp/monic.h index 884adaaf8df..b5113416629 100644 --- a/src/math/lp/monic.h +++ b/src/math/lp/monic.h @@ -21,15 +21,15 @@ namespace nla { class mon_eq { // fields - lp::var_index m_v; - svector m_vs; + lp::lpvar m_v; + svector m_vs; public: // constructors - mon_eq(lp::var_index v, unsigned sz, lp::var_index const* vs): + mon_eq(lp::lpvar v, unsigned sz, lp::lpvar const* vs): m_v(v), m_vs(sz, vs) { std::sort(m_vs.begin(), m_vs.end()); } - mon_eq(lp::var_index v, const svector &vs): + mon_eq(lp::lpvar v, const svector &vs): m_v(v), m_vs(vs) { std::sort(m_vs.begin(), m_vs.end()); } @@ -37,7 +37,7 @@ class mon_eq { unsigned var() const { return m_v; } unsigned size() const { return m_vs.size(); } - const svector& vars() const { return m_vs; } + const svector& vars() const { return m_vs; } bool empty() const { return m_vs.empty(); } bool is_sorted() const { for (unsigned i = 0; i + 1 < size(); i++) @@ -49,7 +49,7 @@ class mon_eq { return std::binary_search(m_vs.begin(), m_vs.end(), j); } protected: - svector& vars1() { return m_vs; } + svector& vars1() { return m_vs; } }; // support the congruence @@ -58,6 +58,8 @@ class monic: public mon_eq { svector m_rvars; bool m_rsign; mutable unsigned m_visited; + bool m_propagated = false; + bool m_bound_propagated = false; public: // constructors monic(lpvar v, unsigned sz, lpvar const* vs, unsigned idx): @@ -74,6 +76,10 @@ class monic: public mon_eq { void reset_rfields() { m_rsign = false; m_rvars.reset(); SASSERT(m_rvars.size() == 0); } void push_rvar(signed_var sv) { m_rsign ^= sv.sign(); m_rvars.push_back(sv.var()); } void sort_rvars() { std::sort(m_rvars.begin(), m_rvars.end()); } + void set_propagated(bool p) { m_propagated = p; } + bool is_propagated() const { return m_propagated; } + void set_bound_propagated(bool p) { m_bound_propagated = p; } + bool is_bound_propagated() const { return m_bound_propagated; } svector::const_iterator begin() const { return vars().begin(); } svector::const_iterator end() const { return vars().end(); } diff --git a/src/math/lp/monomial_bounds.cpp b/src/math/lp/monomial_bounds.cpp index 1ed0956dcc8..f2a1b0287d0 100644 --- a/src/math/lp/monomial_bounds.cpp +++ b/src/math/lp/monomial_bounds.cpp @@ -10,6 +10,7 @@ #include "math/lp/monomial_bounds.h" #include "math/lp/nla_core.h" #include "math/lp/nla_intervals.h" +#include "math/lp/numeric_pair.h" namespace nla { @@ -17,10 +18,11 @@ namespace nla { common(c), dep(c->m_intervals.get_dep_intervals()) {} - void monomial_bounds::operator()() { + void monomial_bounds::propagate() { for (lpvar v : c().m_to_refine) { - monic const& m = c().emons()[v]; - propagate(m); + propagate(c().emon(v)); + if (add_lemma()) + break; } } @@ -28,7 +30,6 @@ namespace nla { return rational(q).bitsize() > 256; } - /** * Accumulate product of variables in monomial starting at position 'start' */ @@ -51,38 +52,81 @@ namespace nla { * a bounds axiom. */ bool monomial_bounds::propagate_value(dep_interval& range, lpvar v) { - auto val = c().val(v); - if (dep.is_below(range, val)) { + + bool propagated = false; + if (should_propagate_upper(range, v, 1)) { + auto const& upper = dep.upper(range); + auto cmp = dep.upper_is_open(range) ? llc::LT : llc::LE; + ++c().lra.settings().stats().m_nla_propagate_bounds; lp::explanation ex; dep.get_upper_dep(range, ex); - auto const& upper = dep.upper(range); if (is_too_big(upper)) return false; - auto cmp = dep.upper_is_open(range) ? llc::LT : llc::LE; new_lemma lemma(c(), "propagate value - upper bound of range is below value"); lemma &= ex; lemma |= ineq(v, cmp, upper); - TRACE("nla_solver", dep.display(tout << val << " > ", range) << "\n" << lemma << "\n";); - return true; + TRACE("nla_solver", dep.display(tout << c().val(v) << " > ", range) << "\n" << lemma << "\n";); + propagated = true; } - else if (dep.is_above(range, val)) { + if (should_propagate_lower(range, v, 1)) { + auto const& lower = dep.lower(range); + auto cmp = dep.lower_is_open(range) ? llc::GT : llc::GE; + ++c().lra.settings().stats().m_nla_propagate_bounds; lp::explanation ex; dep.get_lower_dep(range, ex); - auto const& lower = dep.lower(range); if (is_too_big(lower)) return false; - auto cmp = dep.lower_is_open(range) ? llc::GT : llc::GE; new_lemma lemma(c(), "propagate value - lower bound of range is above value"); lemma &= ex; lemma |= ineq(v, cmp, lower); - TRACE("nla_solver", dep.display(tout << val << " < ", range) << "\n" << lemma << "\n";); - return true; + TRACE("nla_solver", dep.display(tout << c().val(v) << " < ", range) << "\n" << lemma << "\n";); + propagated = true; } - else { + return propagated; + } + + bool monomial_bounds::should_propagate_lower(dep_interval const& range, lpvar v, unsigned p) { + if (dep.lower_is_inf(range)) + return false; + auto bound = c().val(v); + auto const& lower = dep.lower(range); + if (p > 1) + bound = power(bound, p); + return bound < lower; + } + + bool monomial_bounds::should_propagate_upper(dep_interval const& range, lpvar v, unsigned p) { + if (dep.upper_is_inf(range)) return false; + auto bound = c().val(v); + auto const& upper = dep.upper(range); + if (p > 1) + bound = power(bound, p); + return bound > upper; + } + + /** + * Ensure that bounds are integral when the variable is integer. + */ + void monomial_bounds::propagate_bound(lpvar v, lp::lconstraint_kind cmp, rational const& q, u_dependency* d) { + SASSERT(cmp != llc::EQ && cmp != llc::NE); + if (!c().var_is_int(v)) + c().lra.update_column_type_and_bound(v, cmp, q, d); + else if (q.is_int()) { + if (cmp == llc::GT) + c().lra.update_column_type_and_bound(v, llc::GE, q + 1, d); + else if(cmp == llc::LT) + c().lra.update_column_type_and_bound(v, llc::LE, q - 1, d); + else + c().lra.update_column_type_and_bound(v, cmp, q, d); } + else if (cmp == llc::GE || cmp == llc::GT) + c().lra.update_column_type_and_bound(v, llc::GE, ceil(q), d); + else + c().lra.update_column_type_and_bound(v, llc::LE, floor(q), d); } + /** * val(v)^p should be in range. * if val(v)^p > upper(range) add @@ -97,28 +141,33 @@ namespace nla { SASSERT(p > 0); if (p == 1) return propagate_value(range, v); - auto val_v = c().val(v); - auto val = power(val_v, p); rational r; - if (dep.is_below(range, val)) { + if (should_propagate_upper(range, v, p)) { // v.upper^p > range.upper lp::explanation ex; dep.get_upper_dep(range, ex); + // p even, range.upper < 0, v^p >= 0 -> infeasible if (p % 2 == 0 && rational(dep.upper(range)).is_neg()) { + ++c().lra.settings().stats().m_nla_propagate_bounds; new_lemma lemma(c(), "range requires a non-negative upper bound"); lemma &= ex; return true; } - else if (rational(dep.upper(range)).root(p, r)) { + + if (rational(dep.upper(range)).root(p, r)) { // v = -2, [-4,-3]^3 < v^3 -> add bound v <= -3 - // v = -2, [-1,+1]^2 < v^2 -> add bound v >= -1 - if ((p % 2 == 1) || val_v.is_pos()) { + // v = -2, [-1,+1]^2 < v^2 -> add bound v >= -1 + + if ((p % 2 == 1) || c().val(v).is_pos()) { + ++c().lra.settings().stats().m_nla_propagate_bounds; auto le = dep.upper_is_open(range) ? llc::LT : llc::LE; new_lemma lemma(c(), "propagate value - root case - upper bound of range is below value"); lemma &= ex; lemma |= ineq(v, le, r); return true; } - if (p % 2 == 0 && val_v.is_neg()) { + + if (p % 2 == 0 && c().val(v).is_neg()) { + ++c().lra.settings().stats().m_nla_propagate_bounds; SASSERT(!r.is_neg()); auto ge = dep.upper_is_open(range) ? llc::GT : llc::GE; new_lemma lemma(c(), "propagate value - root case - upper bound of range is below negative value"); @@ -127,44 +176,52 @@ namespace nla { return true; } } - // TBD: add bounds as long as difference to val is above some epsilon. } - else if (dep.is_above(range, val)) { + + if (should_propagate_lower(range, v, p)) { // v.lower^p < range.lower + // + // range.lower < 0 -> v.lower >= root(p, range.lower) + // range.lower >= 0, p odd -> v.lower >= root(p, range.lower) + // range.lower >= 0, p even, v.lower >= 0 -> v.lower >= root(p, range.lower) + // default: + // v.lower >= root(p, range.lower) || (p even & v.upper <= -root(p, range.lower)) + // + // pre-condition: p even -> range.lower >= 0 + // if (rational(dep.lower(range)).root(p, r)) { - lp::explanation ex; - dep.get_lower_dep(range, ex); + ++c().lra.settings().stats().m_nla_propagate_bounds; auto ge = dep.lower_is_open(range) ? llc::GT : llc::GE; auto le = dep.lower_is_open(range) ? llc::LT : llc::LE; + lp::explanation ex; + dep.get_lower_dep(range, ex); new_lemma lemma(c(), "propagate value - root case - lower bound of range is above value"); lemma &= ex; - lemma |= ineq(v, ge, r); - if (p % 2 == 0) { - lemma |= ineq(v, le, -r); - } - return true; + lemma |= ineq(v, ge, r); + if (p % 2 == 0) + lemma |= ineq(v, le, -r); + return true; } - // TBD: add bounds as long as difference to val is above some epsilon. } return false; } void monomial_bounds::var2interval(lpvar v, scoped_dep_interval& i) { - lp::constraint_index ci; + u_dependency* d = nullptr; rational bound; bool is_strict; - if (c().has_lower_bound(v, ci, bound, is_strict)) { + if (c().has_lower_bound(v, d, bound, is_strict)) { dep.set_lower_is_open(i, is_strict); dep.set_lower(i, bound); - dep.set_lower_dep(i, dep.mk_leaf(ci)); + dep.set_lower_dep(i, d); dep.set_lower_is_inf(i, false); } else { dep.set_lower_is_inf(i, true); } - if (c().has_upper_bound(v, ci, bound, is_strict)) { + if (c().has_upper_bound(v, d, bound, is_strict)) { dep.set_upper_is_open(i, is_strict); dep.set_upper(i, bound); - dep.set_upper_dep(i, dep.mk_leaf(ci)); + dep.set_upper_dep(i, d); dep.set_upper_is_inf(i, false); } else { @@ -257,5 +314,153 @@ namespace nla { } } + void monomial_bounds::unit_propagate() { + for (lpvar v : c().m_monics_with_changed_bounds) { + if (!c().is_monic_var(v)) + continue; + monic& m = c().emon(v); + unit_propagate(m); + if (add_lemma()) + break; + if (c().m_conflicts > 0) + break; + } + } + + bool monomial_bounds::add_lemma() { + if (c().lra.get_status() != lp::lp_status::INFEASIBLE) + return false; + lp::explanation exp; + c().lra.get_infeasibility_explanation(exp); + new_lemma lemma(c(), "propagate fixed - infeasible lra"); + lemma &= exp; + return true; + } + + void monomial_bounds::unit_propagate(monic & m) { + if (m.is_propagated()) + return; + lpvar w, fixed_to_zero; + + if (!is_linear(m, w, fixed_to_zero)) + return; + + c().emons().set_propagated(m); + + if (fixed_to_zero != null_lpvar) { + propagate_fixed_to_zero(m, fixed_to_zero); + } + else { + rational k = fixed_var_product(m, w); + if (w == null_lpvar) + propagate_fixed(m, k); + else + propagate_nonfixed(m, k, w); + } + ++c().lra.settings().stats().m_nla_propagate_eq; + } + + lp::explanation monomial_bounds::get_explanation(u_dependency* dep) { + lp::explanation exp; + svector cs; + c().lra.dep_manager().linearize(dep, cs); + for (auto d : cs) + exp.add_pair(d, mpq(1)); + return exp; + } + + void monomial_bounds::propagate_fixed_to_zero(monic const& m, lpvar fixed_to_zero) { + auto* dep = c().lra.get_bound_constraint_witnesses_for_column(fixed_to_zero); + TRACE("nla_solver", tout << "propagate fixed " << m << " = 0, fixed_to_zero = " << fixed_to_zero << "\n";); + c().lra.update_column_type_and_bound(m.var(), lp::lconstraint_kind::EQ, rational(0), dep); + + // propagate fixed equality + auto exp = get_explanation(dep); + c().add_fixed_equality(m.var(), rational(0), exp); + } + + void monomial_bounds::propagate_fixed(monic const& m, rational const& k) { + auto* dep = explain_fixed(m, k); + TRACE("nla_solver", tout << "propagate fixed " << m << " = " << k << "\n";); + c().lra.update_column_type_and_bound(m.var(), lp::lconstraint_kind::EQ, k, dep); + + // propagate fixed equality + auto exp = get_explanation(dep); + c().add_fixed_equality(m.var(), k, exp); + } + + void monomial_bounds::propagate_nonfixed(monic const& m, rational const& k, lpvar w) { + vector> coeffs; + coeffs.push_back({-k, w}); + coeffs.push_back({rational::one(), m.var()}); + lp::lpvar j = c().lra.add_term(coeffs, UINT_MAX); + auto* dep = explain_fixed(m, k); + TRACE("nla_solver", tout << "propagate nonfixed " << m << " = " << k << " " << w << "\n";); + c().lra.update_column_type_and_bound(j, lp::lconstraint_kind::EQ, mpq(0), dep); + + if (k == 1) { + lp::explanation exp = get_explanation(dep); + c().add_equality(m.var(), w, exp); + } + } + + u_dependency* monomial_bounds::explain_fixed(monic const& m, rational const& k) { + u_dependency* dep = nullptr; + auto update_dep = [&](unsigned j) { + dep = c().lra.dep_manager().mk_join(dep, c().lra.get_column_lower_bound_witness(j)); + dep = c().lra.dep_manager().mk_join(dep, c().lra.get_column_upper_bound_witness(j)); + return dep; + }; + + if (k == 0) { + for (auto j : m.vars()) + if (c().var_is_fixed_to_zero(j)) + return update_dep(j); + } + else { + for (auto j : m.vars()) + if (c().var_is_fixed(j)) + update_dep(j); + } + return dep; + } + + + bool monomial_bounds::is_linear(monic const& m, lpvar& w, lpvar & fixed_to_zero) { + w = fixed_to_zero = null_lpvar; + for (lpvar v : m) { + if (!c().var_is_fixed(v)) { + if (w != null_lpvar) + return false; + w = v; + } + else if (c().get_lower_bound(v).is_zero()) { + fixed_to_zero = v; + return true; + } + } + return true; + } + + + rational monomial_bounds::fixed_var_product(monic const& m, lpvar w) { + rational r(1); + for (lpvar v : m) { + // we have to use the column bounds here, because the column value may be outside the bounds + if (v != w ){ + SASSERT(c().var_is_fixed(v)); + r *= c().lra.get_lower_bound(v).x; + } + } + return r; + } + + lpvar monomial_bounds::non_fixed_var(monic const& m) { + for (lpvar v : m) + if (!c().var_is_fixed(v)) + return v; + return null_lpvar; + } + } diff --git a/src/math/lp/monomial_bounds.h b/src/math/lp/monomial_bounds.h index 236f29bc068..19043e072d8 100644 --- a/src/math/lp/monomial_bounds.h +++ b/src/math/lp/monomial_bounds.h @@ -10,12 +10,16 @@ #include "math/lp/nla_common.h" #include "math/lp/nla_intervals.h" -#include "math/lp/u_set.h" +#include "util/uint_set.h" namespace nla { class core; class monomial_bounds : common { dep_intervals& dep; + + bool should_propagate_lower(dep_interval const& range, lpvar v, unsigned p); + bool should_propagate_upper(dep_interval const& range, lpvar v, unsigned p); + void propagate_bound(lpvar v, lp::lconstraint_kind cmp, rational const& q, u_dependency* d); void var2interval(lpvar v, scoped_dep_interval& i); bool is_too_big(mpq const& q) const; bool propagate_down(monic const& m, lpvar u); @@ -23,12 +27,25 @@ namespace nla { bool propagate_value(dep_interval& range, lpvar v, unsigned power); void compute_product(unsigned start, monic const& m, scoped_dep_interval& i); bool propagate(monic const& m); + void propagate_fixed_to_zero(monic const& m, lpvar fixed_to_zero); + void propagate_fixed(monic const& m, rational const& k); + void propagate_nonfixed(monic const& m, rational const& k, lpvar w); + u_dependency* explain_fixed(monic const& m, rational const& k); + lp::explanation get_explanation(u_dependency* dep); bool propagate_down(monic const& m, dep_interval& mi, lpvar v, unsigned power, dep_interval& product); void analyze_monomial(monic const& m, unsigned& num_free, lpvar& free_v, unsigned& power) const; bool is_free(lpvar v) const; bool is_zero(lpvar v) const; + bool add_lemma(); + + // monomial propagation + void unit_propagate(monic & m); + bool is_linear(monic const& m, lpvar& w, lpvar & fixed_to_zero); + rational fixed_var_product(monic const& m, lpvar w); + lpvar non_fixed_var(monic const& m); public: monomial_bounds(core* core); - void operator()(); + void propagate(); + void unit_propagate(); }; } diff --git a/src/math/lp/nla_basics_lemmas.cpp b/src/math/lp/nla_basics_lemmas.cpp index 7124fd4099c..705247aa145 100644 --- a/src/math/lp/nla_basics_lemmas.cpp +++ b/src/math/lp/nla_basics_lemmas.cpp @@ -104,7 +104,7 @@ bool basics::basic_sign_lemma_model_based() { return true; } } - return c().m_lemma_vec->size() > 0; + return c().m_lemmas.size() > 0; } @@ -331,7 +331,7 @@ bool basics::basic_lemma_for_mon_neutral_derived(const monic& rm, const factoriz for (auto fc : f) { lpvar j = var(fc); all_int &= c().var_is_int(j); - if (j == null_lpvar && abs(val(j)) == abs_mv && + if (u == null_lpvar && abs(val(j)) == abs_mv && c().vars_are_equiv(j, mon_var) && (mon_var_is_sep_from_zero || c().var_is_separated_from_zero(j))) u = j; diff --git a/src/math/lp/nla_common.cpp b/src/math/lp/nla_common.cpp index 45898c6139e..eceacbcb418 100644 --- a/src/math/lp/nla_common.cpp +++ b/src/math/lp/nla_common.cpp @@ -60,22 +60,20 @@ unsigned common::random() { } void common::add_deps_of_fixed(lpvar j, u_dependency*& dep) { - unsigned lc, uc; - auto& dep_manager = c().m_intervals.get_dep_intervals().dep_manager(); - c().m_lar_solver.get_bound_constraint_witnesses_for_column(j, lc, uc); - dep = dep_manager.mk_join(dep, dep_manager.mk_leaf(lc)); - dep = dep_manager.mk_join(dep, dep_manager.mk_leaf(uc)); + auto& dm = c().lra.dep_manager(); + auto* deps = c().lra.get_bound_constraint_witnesses_for_column(j); + dep = dm.mk_join(dep, deps); } // creates a nex expression for the coeff and var, nex * common::nexvar(const rational & coeff, lpvar j, nex_creator& cn, u_dependency*& dep) { SASSERT(!coeff.is_zero()); - if (c().m_nla_settings.horner_subs_fixed == 1 && c().var_is_fixed(j)) { + if (c().params().arith_nl_horner_subs_fixed() == 1 && c().var_is_fixed(j)) { add_deps_of_fixed(j, dep); - return cn.mk_scalar(coeff * c().m_lar_solver.column_lower_bound(j).x); + return cn.mk_scalar(coeff * c().lra.column_lower_bound(j).x); } - if (c().m_nla_settings.horner_subs_fixed == 2 && c().var_is_fixed_to_zero(j)) { + if (c().params().arith_nl_horner_subs_fixed() == 2 && c().var_is_fixed_to_zero(j)) { add_deps_of_fixed(j, dep); return cn.mk_scalar(rational(0)); } @@ -89,10 +87,10 @@ nex * common::nexvar(const rational & coeff, lpvar j, nex_creator& cn, u_depende mf *= coeff; u_dependency * initial_dep = dep; for (lpvar k : m.vars()) { - if (c().m_nla_settings.horner_subs_fixed == 1 && c().var_is_fixed(k)) { + if (c().params().arith_nl_horner_subs_fixed() == 1 && c().var_is_fixed(k)) { add_deps_of_fixed(k, dep); - mf *= c().m_lar_solver.column_lower_bound(k).x; - } else if (c().m_nla_settings.horner_subs_fixed == 2 && + mf *= c().lra.column_lower_bound(k).x; + } else if (c().params().arith_nl_horner_subs_fixed() == 2 && c().var_is_fixed_to_zero(k)) { dep = initial_dep; add_deps_of_fixed(k, dep); diff --git a/src/math/lp/nla_common.h b/src/math/lp/nla_common.h index 1302c390944..731b4766e9c 100644 --- a/src/math/lp/nla_common.h +++ b/src/math/lp/nla_common.h @@ -80,21 +80,6 @@ struct common { bool check_monic(const monic&) const; unsigned random(); void add_deps_of_fixed(lpvar j, u_dependency*& dep); - class ci_value_manager { - public: - void inc_ref(lp::constraint_index const & v) { - } - - void dec_ref(lp::constraint_index const & v) { - } - }; - - struct u_dependency_config { - typedef ci_value_manager value_manager; - typedef region allocator; - static const bool ref_count = false; - typedef lp::constraint_index value; - }; nex* nexvar(const rational& coeff, lpvar j, nex_creator&, u_dependency*&); template diff --git a/src/math/lp/nla_core.cpp b/src/math/lp/nla_core.cpp index 4d1cc6edb65..675bf502231 100644 --- a/src/math/lp/nla_core.cpp +++ b/src/math/lp/nla_core.cpp @@ -17,14 +17,15 @@ Module Name: #include "math/grobner/pdd_solver.h" #include "math/dd/pdd_interval.h" #include "math/dd/pdd_eval.h" -namespace nla { +using namespace nla; typedef lp::lar_term term; -core::core(lp::lar_solver& s, reslimit & lim) : +core::core(lp::lar_solver& s, params_ref const& p, reslimit & lim) : m_evars(), - m_lar_solver(s), + lra(s), m_reslim(lim), + m_params(p), m_tangents(this), m_basics(this), m_order(this), @@ -39,7 +40,15 @@ core::core(lp::lar_solver& s, reslimit & lim) : m_use_nra_model(false), m_nra(s, m_nra_lim, *this) { - m_nlsat_delay = lp_settings().nlsat_delay(); + m_nlsat_delay_bound = lp_settings().nlsat_delay(); + lra.m_find_monics_with_changed_bounds_func = [&](const indexed_uint_set& columns_with_changed_bounds) { + for (lpvar j : columns_with_changed_bounds) { + if (is_monic_var(j)) + m_monics_with_changed_bounds.insert(j); + for (const auto & m: m_emons.get_use_list(j)) + m_monics_with_changed_bounds.insert(m.var()); + } + }; } bool core::compare_holds(const rational& ls, llc cmp, const rational& rs) const { @@ -59,21 +68,10 @@ bool core::compare_holds(const rational& ls, llc cmp, const rational& rs) const rational core::value(const lp::lar_term& r) const { rational ret(0); for (lp::lar_term::ival t : r) - ret += t.coeff() * val(t.column()); + ret += t.coeff() * val(t.j()); return ret; } -lp::lar_term core::subs_terms_to_columns(const lp::lar_term& t) const { - lp::lar_term r; - for (lp::lar_term::ival p : t) { - lpvar j = p.column(); - if (lp::tv::is_term(j)) - j = m_lar_solver.map_term_index_to_column_index(j); - r.add_monomial(p.coeff(), j); - } - return r; -} - bool core::ineq_holds(const ineq& n) const { return compare_holds(value(n.term()), n.cmp(), n.rs()); } @@ -130,12 +128,10 @@ bool core::canonize_sign(const factorization& f) const { void core::add_monic(lpvar v, unsigned sz, lpvar const* vs) { m_add_buffer.resize(sz); for (unsigned i = 0; i < sz; i++) { - lpvar j = vs[i]; - if (lp::tv::is_term(j)) - j = m_lar_solver.map_term_index_to_column_index(j); - m_add_buffer[i] = j; + m_add_buffer[i] = vs[i]; } m_emons.add(v, m_add_buffer); + m_monics_with_changed_bounds.insert(v); } void core::push() { @@ -153,7 +149,7 @@ void core::pop(unsigned n) { rational core::product_value(const monic& m) const { rational r(1); for (auto j : m.vars()) { - r *= m_lar_solver.get_column_value(j).x; + r *= lra.get_column_value(j).x; } return r; } @@ -166,8 +162,10 @@ bool core::check_monic(const monic& m) const { if (!is_relevant(m.var())) return true; #endif - SASSERT((!m_lar_solver.column_is_int(m.var())) || m_lar_solver.get_column_value(m.var()).is_int()); - bool ret = product_value(m) == m_lar_solver.get_column_value(m.var()).x; + if (lra.column_is_int(m.var()) && !lra.get_column_value(m.var()).is_int()) + return true; + + bool ret = product_value(m) == lra.get_column_value(m.var()).x; CTRACE("nla_solver_check_monic", !ret, print_monic(m, tout) << '\n';); return ret; } @@ -179,7 +177,7 @@ std::ostream& core::print_product(const T & m, std::ostream& out) const { for (lpvar v : m) { if (!first) out << "*"; else first = false; if (lp_settings().print_external_var_name()) - out << "(" << m_lar_solver.get_variable_name(v) << "=" << val(v) << ")"; + out << "(" << lra.get_variable_name(v) << "=" << val(v) << ")"; else out << "(j" << v << " = " << val(v) << ")"; @@ -225,7 +223,7 @@ std::ostream & core::print_factor_with_vars(const factor& f, std::ostream& out) std::ostream& core::print_monic(const monic& m, std::ostream& out) const { if (lp_settings().print_external_var_name()) - out << "([" << m.var() << "] = " << m_lar_solver.get_variable_name(m.var()) << " = " << val(m.var()) << " = "; + out << "([" << m.var() << "] = " << lra.get_variable_name(m.var()) << " = " << val(m.var()) << " = "; else out << "(j" << m.var() << " = " << val(m.var()) << " = "; print_product(m.vars(), out) << ")\n"; @@ -268,7 +266,7 @@ std::ostream& core::print_explanation(const lp::explanation& exp, std::ostream& unsigned i = 0; for (auto p : exp) { out << "(" << p.ci() << ")"; - m_lar_solver.constraints().display(out, [this](lpvar j) { return var_str(j);}, p.ci()); + lra.constraints().display(out, [this](lpvar j) { return var_str(j);}, p.ci()); if (++i < exp.size()) out << " "; } @@ -313,43 +311,41 @@ bool core::explain_lower_bound(const lp::lar_term& t, const rational& rs, lp::ex bool core::explain_coeff_lower_bound(const lp::lar_term::ival& p, rational& bound, lp::explanation& e) const { const rational& a = p.coeff(); SASSERT(!a.is_zero()); - unsigned c; // the index for the lower or the upper bound if (a.is_pos()) { - unsigned c = m_lar_solver.get_column_lower_bound_witness(p.column()); - if (c + 1 == 0) + auto* dep = lra.get_column_lower_bound_witness(p.j()); + if (!dep) return false; - bound = a * m_lar_solver.get_lower_bound(p.column()).x; - e.push_back(c); + bound = a * lra.get_lower_bound(p.j()).x; + lra.push_explanation(dep, e); return true; } // a.is_neg() - c = m_lar_solver.get_column_upper_bound_witness(p.column()); - if (c + 1 == 0) + auto* dep = lra.get_column_upper_bound_witness(p.j()); + if (!dep) return false; - bound = a * m_lar_solver.get_upper_bound(p.column()).x; - e.push_back(c); + bound = a * lra.get_upper_bound(p.j()).x; + lra.push_explanation(dep, e); return true; } bool core::explain_coeff_upper_bound(const lp::lar_term::ival& p, rational& bound, lp::explanation& e) const { const rational& a = p.coeff(); - lpvar j = p.column(); + lpvar j = p.j(); SASSERT(!a.is_zero()); - unsigned c; // the index for the lower or the upper bound if (a.is_neg()) { - unsigned c = m_lar_solver.get_column_lower_bound_witness(j); - if (c + 1 == 0) + auto *dep = lra.get_column_lower_bound_witness(j); + if (!dep) return false; - bound = a * m_lar_solver.get_lower_bound(j).x; - e.push_back(c); + bound = a * lra.get_lower_bound(j).x; + lra.push_explanation(dep, e); return true; } // a.is_pos() - c = m_lar_solver.get_column_upper_bound_witness(j); - if (c + 1 == 0) + auto* dep = lra.get_column_upper_bound_witness(j); + if (!dep) return false; - bound = a * m_lar_solver.get_upper_bound(j).x; - e.push_back(c); + bound = a * lra.get_upper_bound(j).x; + lra.push_explanation(dep, e); return true; } @@ -410,12 +406,12 @@ bool core::explain_by_equiv(const lp::lar_term& t, lp::explanation& e) const { return false; m_evars.explain(signed_var(i, false), signed_var(j, sign), e); - TRACE("nla_solver", tout << "explained :"; m_lar_solver.print_term_as_indices(t, tout);); + TRACE("nla_solver", tout << "explained :"; lra.print_term_as_indices(t, tout);); return true; } void core::mk_ineq_no_expl_check(new_lemma& lemma, lp::lar_term& t, llc cmp, const rational& rs) { - TRACE("nla_solver_details", m_lar_solver.print_term_as_indices(t, tout << "t = ");); + TRACE("nla_solver_details", lra.print_term_as_indices(t, tout << "t = ");); lemma |= ineq(cmp, t, rs); CTRACE("nla_solver", ineq_holds(ineq(cmp, t, rs)), print_ineq(ineq(cmp, t, rs), tout) << "\n";); SASSERT(!ineq_holds(ineq(cmp, t, rs))); @@ -482,18 +478,18 @@ int core::vars_sign(const svector& v) { } bool core::has_upper_bound(lpvar j) const { - return m_lar_solver.column_has_upper_bound(j); + return lra.column_has_upper_bound(j); } bool core::has_lower_bound(lpvar j) const { - return m_lar_solver.column_has_lower_bound(j); + return lra.column_has_lower_bound(j); } const rational& core::get_upper_bound(unsigned j) const { - return m_lar_solver.get_upper_bound(j).x; + return lra.get_upper_bound(j).x; } const rational& core::get_lower_bound(unsigned j) const { - return m_lar_solver.get_lower_bound(j).x; + return lra.get_lower_bound(j).x; } bool core::zero_is_an_inner_point_of_bounds(lpvar j) const { @@ -537,43 +533,40 @@ bool core::sign_contradiction(const monic& m) const { bool core::var_is_fixed_to_zero(lpvar j) const { return - m_lar_solver.column_is_fixed(j) && - m_lar_solver.get_lower_bound(j) == lp::zero_of_type(); + lra.column_is_fixed(j) && + lra.get_lower_bound(j) == lp::zero_of_type(); } bool core::var_is_fixed_to_val(lpvar j, const rational& v) const { return - m_lar_solver.column_is_fixed(j) && - m_lar_solver.get_lower_bound(j) == lp::impq(v); + lra.column_is_fixed(j) && + lra.get_lower_bound(j) == lp::impq(v); } bool core::var_is_fixed(lpvar j) const { - return m_lar_solver.column_is_fixed(j); + return lra.column_is_fixed(j); } bool core::var_is_free(lpvar j) const { - return m_lar_solver.column_is_free(j); + return lra.column_is_free(j); } - std::ostream & core::print_ineq(const ineq & in, std::ostream & out) const { - m_lar_solver.print_term_as_indices(in.term(), out); - out << " " << lconstraint_kind_string(in.cmp()) << " " << in.rs(); - return out; + lra.print_term_as_indices(in.term(), out); + return out << " " << lconstraint_kind_string(in.cmp()) << " " << in.rs(); } std::ostream & core::print_var(lpvar j, std::ostream & out) const { - if (is_monic_var(j)) { + if (is_monic_var(j)) print_monic(m_emons[j], out); - } - m_lar_solver.print_column_info(j, out); + lra.print_column_info(j, out); signed_var jr = m_evars.find(j); out << "root="; if (jr.sign()) { out << "-"; } - out << m_lar_solver.get_variable_name(jr.var()) << "\n"; + out << lra.get_variable_name(jr.var()) << "\n"; return out; } @@ -595,7 +588,7 @@ std::ostream & core::print_ineqs(const lemma& l, std::ostream & out) const { print_ineq(in, out); if (i + 1 < l.ineqs().size()) out << " or "; for (lp::lar_term::ival p: in.term()) - vars.insert(p.column()); + vars.insert(p.j()); } out << std::endl; for (lpvar j : vars) { @@ -640,11 +633,11 @@ void core::trace_print_monic_and_factorization(const monic& rm, const factorizat bool core::var_has_positive_lower_bound(lpvar j) const { - return m_lar_solver.column_has_lower_bound(j) && m_lar_solver.get_lower_bound(j) > lp::zero_of_type(); + return lra.column_has_lower_bound(j) && lra.get_lower_bound(j) > lp::zero_of_type(); } bool core::var_has_negative_upper_bound(lpvar j) const { - return m_lar_solver.column_has_upper_bound(j) && m_lar_solver.get_upper_bound(j) < lp::zero_of_type(); + return lra.column_has_upper_bound(j) && lra.get_upper_bound(j) < lp::zero_of_type(); } bool core::var_is_separated_from_zero(lpvar j) const { @@ -681,10 +674,10 @@ template bool core::mon_has_zero(const unsigned_vector& product lp::lp_settings& core::lp_settings() { - return m_lar_solver.settings(); + return lra.settings(); } const lp::lp_settings& core::lp_settings() const { - return m_lar_solver.settings(); + return lra.settings(); } unsigned core::random() { return lp_settings().random_next(); } @@ -692,15 +685,15 @@ unsigned core::random() { return lp_settings().random_next(); } // we look for octagon constraints here, with a left part +-x +- y void core::collect_equivs() { - const lp::lar_solver& s = m_lar_solver; + const lp::lar_solver& s = lra; - for (unsigned i = 0; i < s.terms().size(); i++) { - if (!s.term_is_used_as_row(i)) + for (const auto * t : s.terms()) { + if (!s.column_associated_with_row(t->j())) continue; - lpvar j = s.external_to_local(lp::tv::mask_term(i)); + lpvar j = t->j(); if (var_is_fixed_to_zero(j)) { - TRACE("nla_solver_mons", s.print_term_as_indices(*s.terms()[i], tout << "term = ") << "\n";); - add_equivalence_maybe(s.terms()[i], s.get_column_upper_bound_witness(j), s.get_column_lower_bound_witness(j)); + TRACE("nla_solver_mons", s.print_term_as_indices(*t, tout << "term = ") << "\n";); + add_equivalence_maybe(t, s.get_column_upper_bound_witness(j), s.get_column_lower_bound_witness(j)); } } m_emons.ensure_canonized(); @@ -725,16 +718,16 @@ bool core::is_octagon_term(const lp::lar_term& t, bool & sign, lpvar& i, lpvar & return false; } if (i == null_lpvar) - i = p.column(); + i = p.j(); else - j = p.column(); + j = p.j(); } SASSERT(j != null_lpvar); sign = (seen_minus && seen_plus)? false : true; return true; } -void core::add_equivalence_maybe(const lp::lar_term *t, lpci c0, lpci c1) { +void core::add_equivalence_maybe(const lp::lar_term* t, u_dependency* c0, u_dependency* c1) { bool sign; lpvar i, j; if (!is_octagon_term(*t, sign, i, j)) @@ -807,7 +800,12 @@ void core::print_stats(std::ostream& out) { void core::clear() { - m_lemma_vec->clear(); + m_lemmas.clear(); + m_literals.clear(); + m_fixed_equalities.clear(); + m_equalities.clear(); + m_conflicts = 0; + m_check_feasible = false; } void core::init_search() { @@ -826,14 +824,14 @@ void core::insert_to_refine(lpvar j) { void core::erase_from_to_refine(lpvar j) { TRACE("lar_solver", tout << "j=" << j << '\n';); - m_to_refine.erase(j); + if (m_to_refine.contains(j)) + m_to_refine.remove(j); } void core::init_to_refine() { TRACE("nla_solver_details", tout << "emons:" << pp_emons(*this, m_emons);); - m_to_refine.clear(); - m_to_refine.resize(m_lar_solver.number_of_vars()); + m_to_refine.reset(); unsigned r = random(), sz = m_emons.number_of_monics(); for (unsigned k = 0; k < sz; k++) { auto const & m = *(m_emons.begin() + (k + r)% sz); @@ -859,11 +857,11 @@ std::unordered_set core::collect_vars(const lemma& l) const { for (const auto& i : l.ineqs()) { for (lp::lar_term::ival p : i.term()) { - insert_j(p.column()); + insert_j(p.j()); } } for (auto p : l.expl()) { - const auto& c = m_lar_solver.constraints()[p.ci()]; + const auto& c = lra.constraints()[p.ci()]; for (const auto& r : c.coeffs()) { insert_j(r.second); } @@ -1043,7 +1041,7 @@ rational core::val(const factorization& f) const { } new_lemma::new_lemma(core& c, char const* name):name(name), c(c) { - c.m_lemma_vec->push_back(lemma()); + c.m_lemmas.push_back(lemma()); } new_lemma& new_lemma::operator|=(ineq const& ineq) { @@ -1061,11 +1059,14 @@ new_lemma::~new_lemma() { (void)i; (void)name; // code for checking lemma can be added here + if (current().is_conflict()) { + c.m_conflicts++; + } TRACE("nla_solver", tout << name << " " << (++i) << "\n" << *this; ); } lemma& new_lemma::current() const { - return c.m_lemma_vec->back(); + return c.m_lemmas.back(); } new_lemma& new_lemma::operator&=(lp::explanation const& e) { @@ -1122,8 +1123,8 @@ new_lemma& new_lemma::explain_equiv(lpvar a, lpvar b) { new_lemma& new_lemma::explain_var_separated_from_zero(lpvar j) { SASSERT(c.var_is_separated_from_zero(j)); - if (c.m_lar_solver.column_has_upper_bound(j) && - (c.m_lar_solver.get_upper_bound(j)< lp::zero_of_type())) + if (c.lra.column_has_upper_bound(j) && + (c.lra.get_upper_bound(j)< lp::zero_of_type())) explain_existing_upper_bound(j); else explain_existing_lower_bound(j); @@ -1133,7 +1134,7 @@ new_lemma& new_lemma::explain_var_separated_from_zero(lpvar j) { new_lemma& new_lemma::explain_existing_lower_bound(lpvar j) { SASSERT(c.has_lower_bound(j)); lp::explanation ex; - ex.push_back(c.m_lar_solver.get_column_lower_bound_witness(j)); + c.lra.push_explanation(c.lra.get_column_lower_bound_witness(j), ex); *this &= ex; TRACE("nla_solver", tout << j << ": " << *this << "\n";); return *this; @@ -1142,7 +1143,7 @@ new_lemma& new_lemma::explain_existing_lower_bound(lpvar j) { new_lemma& new_lemma::explain_existing_upper_bound(lpvar j) { SASSERT(c.has_upper_bound(j)); lp::explanation ex; - ex.push_back(c.m_lar_solver.get_column_upper_bound_witness(j)); + c.lra.push_explanation(c.lra.get_column_upper_bound_witness(j), ex); *this &= ex; return *this; } @@ -1152,7 +1153,7 @@ std::ostream& new_lemma::display(std::ostream & out) const { for (auto p : lemma.expl()) { out << "(" << p.ci() << ") "; - c.m_lar_solver.constraints().display(out, [this](lpvar j) { return c.var_str(j);}, p.ci()); + c.lra.constraints().display(out, [this](lpvar j) { return c.var_str(j);}, p.ci()); } out << " ==> "; if (lemma.ineqs().empty()) { @@ -1178,15 +1179,11 @@ void core::negate_relation(new_lemma& lemma, unsigned j, const rational& a) { } bool core::conflict_found() const { - for (const auto & l : * m_lemma_vec) { - if (l.is_conflict()) - return true; - } - return false; + return any_of(m_lemmas, [&](const auto& l) { return l.is_conflict(); }); } bool core::done() const { - return m_lemma_vec->size() >= 10 || + return m_lemmas.size() >= 10 || conflict_found() || lp_settings().get_cancel_flag(); } @@ -1237,7 +1234,7 @@ bool core::var_breaks_correct_monic_as_factor(lpvar j, const monic& m) const { bool core::var_breaks_correct_monic(lpvar j) const { if (is_monic_var(j) && !m_to_refine.contains(j)) { - TRACE("nla_solver", tout << "j = " << j << ", m = "; print_monic(emons()[j], tout) << "\n";); + TRACE("nla_solver", tout << "j = " << j << ", m = "; print_monic(emon(j), tout) << "\n";); return true; // changing the value of a correct monic } @@ -1259,7 +1256,7 @@ void core::update_to_refine_of_var(lpvar j) { insert_to_refine(var(m)); } if (is_monic_var(j)) { - const monic& m = emons()[j]; + const monic& m = emon(j); if (var_val(m) == mul_val(m)) erase_from_to_refine(j); else @@ -1300,7 +1297,7 @@ bool core::has_real(const monic& m) const { bool core::is_patch_blocked(lpvar u, const lp::impq& ival) const { TRACE("nla_solver", tout << "u = " << u << '\n';); if (m_cautious_patching && - (!m_lar_solver.inside_bounds(u, ival) || (var_is_int(u) && ival.is_int() == false))) { + (!lra.inside_bounds(u, ival) || (var_is_int(u) && ival.is_int() == false))) { TRACE("nla_solver", tout << "u = " << u << " blocked, for feas or integr\n";); return true; // block } @@ -1331,7 +1328,7 @@ bool core::is_patch_blocked(lpvar u, const lp::impq& ival) const { bool core::try_to_patch(const rational& v) { auto is_blocked = [this](lpvar u, const lp::impq& iv) { return is_patch_blocked(u, iv); }; auto change_report = [this](lpvar u) { update_to_refine_of_var(u); }; - return m_lar_solver.try_to_patch(m_patched_var, v, is_blocked, change_report); + return lra.try_to_patch(m_patched_var, v, is_blocked, change_report); } bool in_power(const svector& vs, unsigned l) { @@ -1340,12 +1337,12 @@ bool in_power(const svector& vs, unsigned l) { } bool core::to_refine_is_correct() const { - for (unsigned j = 0; j < m_lar_solver.number_of_vars(); j++) { + for (unsigned j = 0; j < lra.number_of_vars(); j++) { if (!is_monic_var(j)) continue; - bool valid = check_monic(emons()[j]); + bool valid = check_monic(emon(j)); if (valid == m_to_refine.contains(j)) { TRACE("nla_solver", tout << "inconstency in m_to_refine : "; - print_monic(emons()[j], tout) << "\n"; + print_monic(emon(j), tout) << "\n"; if (valid) tout << "should NOT be in to_refine\n"; else tout << "should be in to_refine\n";); return false; @@ -1355,7 +1352,7 @@ bool core::to_refine_is_correct() const { } void core::patch_monomial(lpvar j) { - m_patched_monic =& (emons()[j]); + m_patched_monic =& (emon(j)); m_patched_var = j; TRACE("nla_solver", tout << "m = "; print_monic(*m_patched_monic, tout) << "\n";); rational v = mul_val(*m_patched_monic); @@ -1405,62 +1402,24 @@ void core::patch_monomial(lpvar j) { } void core::patch_monomials_on_to_refine() { - auto to_refine = m_to_refine.index(); // the rest of the function might change m_to_refine, so have to copy + unsigned_vector to_refine; + for (unsigned j : m_to_refine) + to_refine.push_back(j); + unsigned sz = to_refine.size(); unsigned start = random(); - for (unsigned i = 0; i < sz; i++) { + for (unsigned i = 0; i < sz && !m_to_refine.empty(); i++) patch_monomial(to_refine[(start + i) % sz]); - if (m_to_refine.size() == 0) - break; - } + TRACE("nla_solver", tout << "sz = " << sz << ", m_to_refine = " << m_to_refine.size() << - (sz > m_to_refine.size()? " less" : "same" ) << "\n";); + (sz > m_to_refine.size()? " less" : " same" ) << "\n";); } void core::patch_monomials() { m_cautious_patching = true; patch_monomials_on_to_refine(); - if (m_to_refine.size() == 0 || !m_nla_settings.expensive_patching) { - return; - } - NOT_IMPLEMENTED_YET(); - m_cautious_patching = false; - patch_monomials_on_to_refine(); - m_lar_solver.push(); - save_tableau(); - constrain_nl_in_tableau(); - if (solve_tableau() && integrality_holds()) { - m_lar_solver.pop(1); - } else { - m_lar_solver.pop(); - restore_tableau(); - m_lar_solver.clear_inf_set(); - } - SASSERT(m_lar_solver.ax_is_correct()); -} - -void core::constrain_nl_in_tableau() { - NOT_IMPLEMENTED_YET(); -} - -bool core::solve_tableau() { - NOT_IMPLEMENTED_YET(); - return false; -} - -void core::restore_tableau() { - NOT_IMPLEMENTED_YET(); -} - -void core::save_tableau() { - NOT_IMPLEMENTED_YET(); -} - -bool core::integrality_holds() { - NOT_IMPLEMENTED_YET(); - return false; } /** @@ -1471,7 +1430,7 @@ void core::check_weighted(unsigned sz, std::pair 0 && !done() && m_lemma_vec->empty()) { + while (bound > 0 && !done() && m_lemmas.empty()) { unsigned n = random() % bound; for (unsigned i = 0; i < sz; ++i) { if (seen.contains(i)) @@ -1487,68 +1446,93 @@ void core::check_weighted(unsigned sz, std::pair& l_vec) { - m_lemma_vec = &l_vec; - return m_powers.check(r, x, y, l_vec); +lbool core::check_power(lpvar r, lpvar x, lpvar y) { + clear(); + return m_powers.check(r, x, y, m_lemmas); } -void core::check_bounded_divisions(vector& l_vec) { - m_lemma_vec = &l_vec; +void core::check_bounded_divisions() { + clear(); m_divisions.check_bounded_divisions(); } +// looking for a free variable inside of a monic to split +void core::add_bounds() { + unsigned r = random(), sz = m_to_refine.size(); + for (unsigned k = 0; k < sz; k++) { + lpvar i = m_to_refine[(k + r) % sz]; + auto const& m = m_emons[i]; + for (lpvar j : m.vars()) { + if (!var_is_free(j)) + continue; + if (m.is_bound_propagated()) + continue; + m_emons.set_bound_propagated(m); + // split the free variable (j <= 0, or j > 0), and return + m_literals.push_back(ineq(j, lp::lconstraint_kind::EQ, rational::zero())); + TRACE("nla_solver", print_ineq(m_literals.back(), tout) << "\n"); + ++lp_settings().stats().m_nla_add_bounds; + return; + } + } +} -lbool core::check(vector& l_vec) { +lbool core::check() { lp_settings().stats().m_nla_calls++; TRACE("nla_solver", tout << "calls = " << lp_settings().stats().m_nla_calls << "\n";); - m_lar_solver.get_rid_of_inf_eps(); - m_lemma_vec = &l_vec; - if (!(m_lar_solver.get_status() == lp::lp_status::OPTIMAL || - m_lar_solver.get_status() == lp::lp_status::FEASIBLE)) { - TRACE("nla_solver", tout << "unknown because of the m_lar_solver.m_status = " << m_lar_solver.get_status() << "\n";); + lra.get_rid_of_inf_eps(); + if (!(lra.get_status() == lp::lp_status::OPTIMAL || + lra.get_status() == lp::lp_status::FEASIBLE)) { + TRACE("nla_solver", tout << "unknown because of the lra.m_status = " << lra.get_status() << "\n";); return l_undef; } init_to_refine(); patch_monomials(); set_use_nra_model(false); - if (m_to_refine.empty()) { return l_true; } + if (m_to_refine.empty()) + return l_true; init_search(); lbool ret = l_undef; bool run_grobner = need_run_grobner(); bool run_horner = need_run_horner(); - bool run_bounded_nlsat = should_run_bounded_nlsat(); + bool run_bounds = params().arith_nl_branching(); - if (l_vec.empty() && !done()) - m_monomial_bounds(); + auto no_effect = [&]() { return ret == l_undef && !done() && m_lemmas.empty() && m_literals.empty() && !m_check_feasible; }; - if (l_vec.empty() && !done() && run_horner) - m_horner.horner_lemmas(); - - if (l_vec.empty() && !done() && run_grobner) - m_grobner(); + if (no_effect()) + m_monomial_bounds.propagate(); + + { + std::function check1 = [&]() { if (no_effect() && run_horner) m_horner.horner_lemmas(); }; + std::function check2 = [&]() { if (no_effect() && run_grobner) m_grobner(); }; + std::function check3 = [&]() { if (no_effect() && run_bounds) add_bounds(); }; + + std::pair> checks[] = + { {1, check1}, + {1, check2}, + {1, check3} }; + check_weighted(3, checks); - if (l_vec.empty() && !done()) + if (lp_settings().get_cancel_flag()) + return l_undef; + if (!m_lemmas.empty() || !m_literals.empty() || m_check_feasible) + return l_false; + } + + if (no_effect() && should_run_bounded_nlsat()) + ret = bounded_nlsat(); + + if (no_effect()) m_basics.basic_lemma(true); - if (l_vec.empty() && !done()) + if (no_effect()) m_basics.basic_lemma(false); - if (l_vec.empty() && !done()) - m_divisions.check(); - -#if 0 - if (l_vec.empty() && !done() && !run_horner) - m_horner.horner_lemmas(); - - if (l_vec.empty() && !done() && !run_grobner) - m_grobner(); -#endif - - if (!conflict_found() && !done() && run_bounded_nlsat) - ret = bounded_nlsat(); + if (no_effect()) + m_divisions.check(); - if (l_vec.empty() && !done() && ret == l_undef) { + if (no_effect()) { std::function check1 = [&]() { m_order.order_lemma(); }; std::function check2 = [&]() { m_monotone.monotonicity_lemma(); }; std::function check3 = [&]() { m_tangents.tangent_lemma(); }; @@ -1560,35 +1544,32 @@ lbool core::check(vector& l_vec) { check_weighted(3, checks); unsigned num_calls = lp_settings().stats().m_nla_calls; - if (!conflict_found() && m_nla_settings.run_nra && num_calls % 50 == 0 && num_calls > 500) + if (!conflict_found() && params().arith_nl_nra() && num_calls % 50 == 0 && num_calls > 500) ret = bounded_nlsat(); } - if (l_vec.empty() && !done() && m_nla_settings.run_nra && ret == l_undef) { + if (no_effect() && params().arith_nl_nra()) { ret = m_nra.check(); - m_stats.m_nra_calls++; + lp_settings().stats().m_nra_calls++; } - if (ret == l_undef && !l_vec.empty() && m_reslim.inc()) + if (ret == l_undef && !no_effect() && m_reslim.inc()) ret = l_false; - m_stats.m_nla_lemmas += l_vec.size(); - for (const auto& l : l_vec) - m_stats.m_nla_explanations += static_cast(l.expl().size()); - + lp_settings().stats().m_nla_lemmas += m_lemmas.size(); - TRACE("nla_solver", tout << "ret = " << ret << ", lemmas count = " << l_vec.size() << "\n";); - IF_VERBOSE(2, if(ret == l_undef) {verbose_stream() << "Monomials\n"; print_monics(verbose_stream());}); + TRACE("nla_solver", tout << "ret = " << ret << ", lemmas count = " << m_lemmas.size() << "\n";); + IF_VERBOSE(5, if(ret == l_undef) {verbose_stream() << "Monomials\n"; print_monics(verbose_stream());}); CTRACE("nla_solver", ret == l_undef, tout << "Monomials\n"; print_monics(tout);); return ret; } bool core::should_run_bounded_nlsat() { - if (!m_nla_settings.run_nra) + if (!params().arith_nl_nra()) return false; - if (m_nlsat_delay > m_nlsat_fails) - ++m_nlsat_fails; - return m_nlsat_delay <= m_nlsat_fails; + if (m_nlsat_delay > 0) + --m_nlsat_delay; + return m_nlsat_delay < 2; } lbool core::bounded_nlsat() { @@ -1604,21 +1585,21 @@ lbool core::bounded_nlsat() { } p.set_uint("max_conflicts", UINT_MAX); m_nra.updt_params(p); - m_stats.m_nra_calls++; + lp_settings().stats().m_nra_calls++; if (ret == l_undef) - ++m_nlsat_delay; - else { - m_nlsat_fails = 0; - m_nlsat_delay /= 2; - } - if (ret == l_true) { - m_lemma_vec->reset(); - } + ++m_nlsat_delay_bound; + else if (m_nlsat_delay_bound > 0) + m_nlsat_delay_bound /= 2; + + m_nlsat_delay = m_nlsat_delay_bound; + + if (ret == l_true) + clear(); return ret; } bool core::no_lemmas_hold() const { - for (auto & l : * m_lemma_vec) { + for (auto & l : m_lemmas) { if (lemma_holds(l)) { TRACE("nla_solver", print_lemma(l, tout);); return false; @@ -1627,23 +1608,16 @@ bool core::no_lemmas_hold() const { return true; } -lbool core::test_check(vector& l) { - m_lar_solver.set_status(lp::lp_status::OPTIMAL); - return check(l); + +lbool core::test_check() { + lra.set_status(lp::lp_status::OPTIMAL); + return check(); } std::ostream& core::print_terms(std::ostream& out) const { - for (unsigned i = 0; i< m_lar_solver.terms().size(); i++) { - unsigned ext = lp::tv::mask_term(i); - if (!m_lar_solver.var_is_registered(ext)) { - out << "term is not registered\n"; - continue; - } - - const lp::lar_term & t = *m_lar_solver.terms()[i]; - out << "term:"; print_term(t, out) << std::endl; - lpvar j = m_lar_solver.external_to_local(ext); - print_var(j, out); + for (const auto * t: lra.terms()) { + out << "term:"; print_term(*t, out) << std::endl; + print_var(t->j(), out); } return out; } @@ -1670,19 +1644,19 @@ std::ostream& core::print_term( const lp::lar_term& t, std::ostream& out) const std::unordered_set core::get_vars_of_expr_with_opening_terms(const nex *e ) { auto ret = get_vars_of_expr(e); - auto & ls = m_lar_solver; + auto & ls = lra; svector added; for (auto j : ret) { added.push_back(j); } for (unsigned i = 0; i < added.size(); ++i) { lpvar j = added[i]; - if (ls.column_corresponds_to_term(j)) { - const auto& t = m_lar_solver.get_term(lp::tv::raw(ls.local_to_external(j))); + if (ls.column_has_term(j)) { + const auto& t = lra.get_term(j); for (auto p : t) { - if (ret.find(p.column()) == ret.end()) { - added.push_back(p.column()); - ret.insert(p.column()); + if (ret.find(p.j()) == ret.end()) { + added.push_back(p.j()); + ret.insert(p.j()); } } } @@ -1698,7 +1672,7 @@ bool core::is_nl_var(lpvar j) const { unsigned core::get_var_weight(lpvar j) const { unsigned k; - switch (m_lar_solver.get_column_type(j)) { + switch (lra.get_column_type(j)) { case lp::column_type::fixed: k = 0; @@ -1727,18 +1701,16 @@ unsigned core::get_var_weight(lpvar j) const { void core::set_active_vars_weights(nex_creator& nc) { - nc.set_number_of_vars(m_lar_solver.column_count()); + nc.set_number_of_vars(lra.column_count()); for (lpvar j : active_var_set()) nc.set_var_weight(j, get_var_weight(j)); } bool core::influences_nl_var(lpvar j) const { - if (lp::tv::is_term(j)) - j = lp::tv::unmask_term(j); if (is_nl_var(j)) return true; - for (const auto & c : m_lar_solver.A_r().m_columns[j]) { - lpvar basic_in_row = m_lar_solver.r_basis()[c.var()]; + for (const auto & c : lra.A_r().m_columns[j]) { + lpvar basic_in_row = lra.r_basis()[c.var()]; if (is_nl_var(basic_in_row)) return true; } @@ -1752,12 +1724,20 @@ void core::set_use_nra_model(bool m) { } } -void core::collect_statistics(::statistics & st) { - st.update("arith-nla-explanations", m_stats.m_nla_explanations); - st.update("arith-nla-lemmas", m_stats.m_nla_lemmas); - st.update("arith-nra-calls", m_stats.m_nra_calls); +void core::propagate() { +#if Z3DEBUG + flet f(lra.validate_blocker(), true); +#endif + clear(); + m_monomial_bounds.unit_propagate(); + m_monics_with_changed_bounds.reset(); } +void core::simplify() { + // in-processing simplifiation can go here, such as bounds improvements. + +} + + -} // end of nla diff --git a/src/math/lp/nla_core.h b/src/math/lp/nla_core.h index 938bcbe8367..eb1f8b4369f 100644 --- a/src/math/lp/nla_core.h +++ b/src/math/lp/nla_core.h @@ -22,12 +22,12 @@ #include "math/lp/nla_powers.h" #include "math/lp/nla_divisions.h" #include "math/lp/emonics.h" -#include "math/lp/nla_settings.h" #include "math/lp/nex.h" #include "math/lp/horner.h" #include "math/lp/monomial_bounds.h" #include "math/lp/nla_intervals.h" #include "nlsat/nlsat_solver.h" +#include "smt/params/smt_params_helper.hpp" namespace nra { class solver; @@ -44,7 +44,6 @@ bool try_insert(const A& elem, B& collection) { return true; } - class core { friend struct common; friend class new_lemma; @@ -54,37 +53,31 @@ class core { friend struct tangents; friend class monotone; friend class powers; - friend struct nla_settings; friend class intervals; friend class horner; friend class solver; friend class monomial_bounds; friend class nra::solver; + friend class divisions; - struct stats { - unsigned m_nla_explanations; - unsigned m_nla_lemmas; - unsigned m_nra_calls; - stats() { reset(); } - void reset() { - memset(this, 0, sizeof(*this)); - } - }; - - stats m_stats; - unsigned m_nlsat_delay = 50; - unsigned m_nlsat_fails = 0; + unsigned m_nlsat_delay = 0; + unsigned m_nlsat_delay_bound = 0; bool should_run_bounded_nlsat(); lbool bounded_nlsat(); var_eqs m_evars; - lp::lar_solver& m_lar_solver; + lp::lar_solver& lra; reslimit& m_reslim; + smt_params_helper m_params; std::function m_relevant; - vector * m_lemma_vec; - lp::u_set m_to_refine; + vector m_lemmas; + vector m_literals; + vector m_equalities; + vector m_fixed_equalities; + indexed_uint_set m_to_refine; + indexed_uint_set m_monics_with_changed_bounds; tangents m_tangents; basics m_basics; order m_order; @@ -93,13 +86,13 @@ class core { divisions m_divisions; intervals m_intervals; monomial_bounds m_monomial_bounds; - nla_settings m_nla_settings; - + unsigned m_conflicts; + bool m_check_feasible = false; horner m_horner; grobner m_grobner; emonics m_emons; svector m_add_buffer; - mutable lp::u_set m_active_var_set; + mutable indexed_uint_set m_active_var_set; reslimit m_nra_lim; @@ -110,31 +103,31 @@ class core { monic const* m_patched_monic = nullptr; void check_weighted(unsigned sz, std::pair>* checks); + void add_bounds(); public: // constructor - core(lp::lar_solver& s, reslimit&); - + core(lp::lar_solver& s, params_ref const& p, reslimit&); + const auto& monics_with_changed_bounds() const { return m_monics_with_changed_bounds; } void insert_to_refine(lpvar j); void erase_from_to_refine(lpvar j); - const lp::u_set& active_var_set () const { return m_active_var_set;} + const indexed_uint_set& active_var_set () const { return m_active_var_set;} bool active_var_set_contains(unsigned j) const { return m_active_var_set.contains(j); } - void insert_to_active_var_set(unsigned j) const { m_active_var_set.insert(j); } + void insert_to_active_var_set(unsigned j) const { + m_active_var_set.insert(j); + } - void clear_active_var_set() const { m_active_var_set.clear(); } + void clear_active_var_set() const { m_active_var_set.reset(); } - void clear_and_resize_active_var_set() const { - m_active_var_set.clear(); - m_active_var_set.resize(m_lar_solver.number_of_vars()); - } - unsigned get_var_weight(lpvar) const; reslimit& reslim() { return m_reslim; } emonics& emons() { return m_emons; } const emonics& emons() const { return m_emons; } + monic& emon(unsigned i) { return m_emons[i]; } + monic const& emon(unsigned i) const { return m_emons[i]; } bool has_relevant_monomial() const; @@ -142,17 +135,16 @@ class core { rational value(const lp::lar_term& r) const; - lp::lar_term subs_terms_to_columns(const lp::lar_term& t) const; bool ineq_holds(const ineq& n) const; bool lemma_holds(const lemma& l) const; bool is_monic_var(lpvar j) const { return m_emons.is_monic_var(j); } - const rational& val(lpvar j) const { return m_lar_solver.get_column_value(j).x; } + const rational& val(lpvar j) const { return lra.get_column_value(j).x; } - const rational& var_val(const monic& m) const { return m_lar_solver.get_column_value(m.var()).x; } + const rational& var_val(const monic& m) const { return lra.get_column_value(m.var()).x; } rational mul_val(const monic& m) const { rational r(1); - for (lpvar v : m.vars()) r *= m_lar_solver.get_column_value(v).x; + for (lpvar v : m.vars()) r *= lra.get_column_value(v).x; return r; } @@ -168,13 +160,15 @@ class core { lpvar var(const factor& f) const { return f.var(); } + smt_params_helper const & params() const { return m_params; } + // returns true if the combination of the Horner's schema and Grobner Basis should be called bool need_run_horner() const { - return m_nla_settings.run_horner && lp_settings().stats().m_nla_calls % m_nla_settings.horner_frequency == 0; + return params().arith_nl_horner() && lp_settings().stats().m_nla_calls % params().arith_nl_horner_frequency() == 0; } - bool need_run_grobner() const { - return m_nla_settings.run_grobner && lp_settings().stats().m_nla_calls % m_nla_settings.grobner_frequency == 0; + bool need_run_grobner() const { + return params().arith_nl_grobner(); } void set_active_vars_weights(nex_creator&); @@ -235,6 +229,7 @@ class core { std::ostream & print_factor(const factor& f, std::ostream& out) const; std::ostream & print_factor_with_vars(const factor& f, std::ostream& out) const; + std::ostream & print_factor_with_vars(lpvar j, std::ostream& out) const { return print_var(j, out); } std::ostream& print_monic(const monic& m, std::ostream& out) const; std::ostream& print_bfc(const factorization& m, std::ostream& out) const; std::ostream& print_monic_with_vars(unsigned i, std::ostream& out) const; @@ -288,16 +283,16 @@ class core { } const rational& get_upper_bound(unsigned j) const; const rational& get_lower_bound(unsigned j) const; - bool has_lower_bound(lp::var_index var, lp::constraint_index& ci, lp::mpq& value, bool& is_strict) const { - return m_lar_solver.has_lower_bound(var, ci, value, is_strict); + bool has_lower_bound(lp::lpvar var, u_dependency*& ci, lp::mpq& value, bool& is_strict) const { + return lra.has_lower_bound(var, ci, value, is_strict); } - bool has_upper_bound(lp::var_index var, lp::constraint_index& ci, lp::mpq& value, bool& is_strict) const { - return m_lar_solver.has_upper_bound(var, ci, value, is_strict); + bool has_upper_bound(lp::lpvar var, u_dependency*& ci, lp::mpq& value, bool& is_strict) const { + return lra.has_upper_bound(var, ci, value, is_strict); } bool zero_is_an_inner_point_of_bounds(lpvar j) const; - bool var_is_int(lpvar j) const { return m_lar_solver.column_is_int(j); } + bool var_is_int(lpvar j) const { return lra.column_is_int(j); } int rat_sign(const monic& m) const; inline int rat_sign(lpvar j) const { return nla::rat_sign(val(j)); } @@ -339,7 +334,7 @@ class core { bool is_octagon_term(const lp::lar_term& t, bool & sign, lpvar& i, lpvar &j) const; - void add_equivalence_maybe(const lp::lar_term *t, lpci c0, lpci c1); + void add_equivalence_maybe(const lp::lar_term* t, u_dependency* c0, u_dependency* c1); void init_vars_equivalence(); @@ -381,13 +376,17 @@ class core { bool conflict_found() const; - lbool check(vector& l_vec); - lbool check_power(lpvar r, lpvar x, lpvar y, vector& l_vec); - void check_bounded_divisions(vector&); + lbool check(); + lbool check_power(lpvar r, lpvar x, lpvar y); + void check_bounded_divisions(); bool no_lemmas_hold() const; + + void propagate(); + + void simplify(); - lbool test_check(vector& l); + lbool test_check(); lpvar map_to_root(lpvar) const; std::ostream& print_terms(std::ostream&) const; std::ostream& print_term(const lp::lar_term&, std::ostream&) const; @@ -420,14 +419,15 @@ class core { bool has_real(const monic& m) const; void set_use_nra_model(bool m); bool use_nra_model() const { return m_use_nra_model; } - void collect_statistics(::statistics&); -private: - void restore_patched_values(); - void constrain_nl_in_tableau(); - bool solve_tableau(); - void restore_tableau(); - void save_tableau(); - bool integrality_holds(); + vector const& lemmas() const { return m_lemmas; } + vector const& literals() const { return m_literals; } + vector const& equalities() const { return m_equalities; } + vector const& fixed_equalities() const { return m_fixed_equalities; } + bool should_check_feasible() const { return m_check_feasible; } + + void add_fixed_equality(lp::lpvar v, rational const& k, lp::explanation const& e) { m_fixed_equalities.push_back({v, k, e}); } + void add_equality(lp::lpvar i, lp::lpvar j, lp::explanation const& e) { m_equalities.push_back({i, j, e}); } + }; // end of core struct pp_mon { diff --git a/src/math/lp/nla_defs.h b/src/math/lp/nla_defs.h index df9158b42e6..b576036143d 100644 --- a/src/math/lp/nla_defs.h +++ b/src/math/lp/nla_defs.h @@ -14,11 +14,10 @@ #include "math/lp/explanation.h" namespace nla { -typedef lp::constraint_index lpci; -typedef lp::lconstraint_kind llc; -typedef lp::constraint_index lpci; -typedef lp::explanation expl_set; -typedef lp::var_index lpvar; + typedef lp::constraint_index lpci; + typedef lp::lconstraint_kind llc; + typedef lp::explanation expl_set; + typedef unsigned lpvar; struct from_index_dummy{}; class signed_var { @@ -55,12 +54,12 @@ inline std::ostream& operator<<(std::ostream& out, signed_var const& sv) { retur * where m_vs = [v1, v2, .., vn] */ class monic_coeff { - svector m_vs; + svector m_vs; rational m_coeff; public: - monic_coeff(const svector& vs, rational const& coeff): m_vs(vs), m_coeff(coeff) {} + monic_coeff(const svector& vs, rational const& coeff): m_vs(vs), m_coeff(coeff) {} rational const& coeff() const { return m_coeff; } - const svector & vars() const { return m_vs; } + const svector & vars() const { return m_vs; } }; template bool has_zero(const T& product) { for (const rational & t : product) { diff --git a/src/math/lp/nla_divisions.cpp b/src/math/lp/nla_divisions.cpp index cbb30d9d9f1..6c3bb178c66 100644 --- a/src/math/lp/nla_divisions.cpp +++ b/src/math/lp/nla_divisions.cpp @@ -19,19 +19,17 @@ Module Name: namespace nla { void divisions::add_idivision(lpvar q, lpvar x, lpvar y) { + const auto& lra = m_core.lra; if (x == null_lpvar || y == null_lpvar || q == null_lpvar) return; - if (lp::tv::is_term(x) || lp::tv::is_term(y) || lp::tv::is_term(q)) - return; m_idivisions.push_back({q, x, y}); m_core.trail().push(push_back_vector(m_idivisions)); } void divisions::add_rdivision(lpvar q, lpvar x, lpvar y) { + auto& lra = m_core.lra; if (x == null_lpvar || y == null_lpvar || q == null_lpvar) return; - if (lp::tv::is_term(x) || lp::tv::is_term(y) || lp::tv::is_term(q)) - return; m_rdivisions.push_back({ q, x, y }); m_core.trail().push(push_back_vector(m_rdivisions)); } @@ -39,7 +37,7 @@ namespace nla { void divisions::add_bounded_division(lpvar q, lpvar x, lpvar y) { if (x == null_lpvar || y == null_lpvar || q == null_lpvar) return; - if (lp::tv::is_term(x) || lp::tv::is_term(y) || lp::tv::is_term(q)) + if (m_core.lra.column_has_term(x) || m_core.lra.column_has_term(y) || m_core.lra.column_has_term(q)) return; m_bounded_divisions.push_back({ q, x, y }); m_core.trail().push(push_back_vector(m_bounded_divisions)); @@ -52,7 +50,7 @@ namespace nla { // y2 <= y1 < 0 & x1 <= x2 <= 0 => x1/y1 >= x2/y2 void divisions::check() { - core& c = m_core; + core& c = m_core; if (c.use_nra_model()) return; @@ -132,7 +130,7 @@ namespace nla { auto x2val = c.val(x2); auto y2val = c.val(y2); auto q2val = c.val(q2); - if (monotonicity(x, xval, y, yval, r, rval, x2, x2val, y2, y2val, q2, q2val)) + if (monotonicity(x, xval, y, yval, r, rval, x2, x2val, y2, y2val, q2, q2val)) return; } } diff --git a/src/math/lp/nla_grobner.cpp b/src/math/lp/nla_grobner.cpp index 974c48d1400..4beb7eaff7e 100644 --- a/src/math/lp/nla_grobner.cpp +++ b/src/math/lp/nla_grobner.cpp @@ -13,7 +13,6 @@ Module Name: #include "util/uint_set.h" #include "math/lp/nla_core.h" #include "math/lp/factorization_factory_imp.h" -#include "math/lp/nex.h" #include "math/grobner/pdd_solver.h" #include "math/dd/pdd_interval.h" #include "math/dd/pdd_eval.h" @@ -22,10 +21,10 @@ namespace nla { grobner::grobner(core* c): common(c), - m_pdd_manager(m_core.m_lar_solver.number_of_vars()), - m_solver(m_core.m_reslim, m_pdd_manager), - m_lar_solver(m_core.m_lar_solver) - + m_pdd_manager(m_core.lra.number_of_vars()), + m_solver(m_core.m_reslim, m_core.lra.dep_manager(), m_pdd_manager), + lra(m_core.lra), + m_quota(m_core.params().arith_nl_gr_q()) {} lp::lp_settings& grobner::lp_settings() { @@ -33,73 +32,87 @@ namespace nla { } void grobner::operator()() { - unsigned& quota = c().m_nla_settings.grobner_quota; - if (quota == 1) + + if (lra.column_count() > 5000) return; + if (m_quota == 0) + m_quota = c().params().arith_nl_gr_q(); + + if (m_quota == 1) { + m_delay_base++; + m_delay = m_delay_base; + m_quota = c().params().arith_nl_gr_q(); + } + + if (m_delay > 0) { + --m_delay; + return; + } + lp_settings().stats().m_grobner_calls++; find_nl_cluster(); - configure(); - m_solver.saturate(); - - if (is_conflicting()) + if (!configure()) return; + m_solver.saturate(); + if (m_delay_base > 0) + --m_delay_base; + try { - if (propagate_bounds()) + + if (is_conflicting()) return; if (propagate_eqs()) return; - + if (propagate_factorization()) return; + + if (propagate_linear_equations()) + return; + } catch (...) { } - if (quota > 1) - quota--; + // DEBUG_CODE(for (auto e : m_solver.equations()) check_missing_propagation(*e);); - IF_VERBOSE(2, verbose_stream() << "grobner miss, quota " << quota << "\n"); - IF_VERBOSE(4, diagnose_pdd_miss(verbose_stream())); + // for (auto e : m_solver.equations()) check_missing_propagation(*e); + + ++m_delay_base; + if (m_quota > 0) + --m_quota; -#if 0 - // diagnostics: did we miss something - vector eqs; - for (auto eq : m_solver.equations()) - eqs.push_back(eq->poly()); - c().m_nra.check(eqs); -#endif + IF_VERBOSE(3, verbose_stream() << "grobner miss, quota " << m_quota << "\n"); + IF_VERBOSE(4, diagnose_pdd_miss(verbose_stream())); } - bool grobner::is_conflicting() { - unsigned conflicts = 0; - for (auto eq : m_solver.equations()) - if (is_conflicting(*eq) && ++conflicts >= m_solver.number_of_conflicts_to_report()) - break; - - if (conflicts > 0) - lp_settings().stats().m_grobner_conflicts++; - - TRACE("grobner", m_solver.display(tout)); - IF_VERBOSE(2, if (conflicts > 0) verbose_stream() << "grobner conflict\n"); - - return conflicts > 0; + dd::solver::equation_vector const& grobner::core_equations(bool all_eqs) { + flet _add_all(m_add_all_eqs, all_eqs); + find_nl_cluster(); + if (!configure()) + throw dd::pdd_manager::mem_out(); + return m_solver.equations(); } - bool grobner::propagate_bounds() { - unsigned changed = 0; - for (auto eq : m_solver.equations()) - if (propagate_bounds(*eq) && ++changed >= m_solver.number_of_conflicts_to_report()) + bool grobner::is_conflicting() { + for (auto eq : m_solver.equations()) { + if (is_conflicting(*eq)) { + lp_settings().stats().m_grobner_conflicts++; + TRACE("grobner", m_solver.display(tout)); + IF_VERBOSE(3, verbose_stream() << "grobner conflict\n"); return true; - return changed > 0; + } + } + return false; } bool grobner::propagate_eqs() { unsigned changed = 0; - for (auto eq : m_solver.equations()) + for (auto eq : m_solver.equations()) if (propagate_fixed(*eq) && ++changed >= m_solver.number_of_conflicts_to_report()) return true; return changed > 0; @@ -107,7 +120,7 @@ namespace nla { bool grobner::propagate_factorization() { unsigned changed = 0; - for (auto eq : m_solver.equations()) + for (auto eq : m_solver.equations()) if (propagate_factorization(*eq) && ++changed >= m_solver.number_of_conflicts_to_report()) return true; return changed > 0; @@ -121,11 +134,11 @@ namespace nla { typedef lp::lar_term term; bool grobner::propagate_fixed(const dd::solver::equation& eq) { dd::pdd const& p = eq.poly(); - //IF_VERBOSE(0, verbose_stream() << p << "\n"); if (p.is_unary()) { unsigned v = p.var(); if (c().var_is_fixed(v)) return false; + ineq new_eq(v, llc::EQ, rational::zero()); if (c().ineq_holds(new_eq)) return false; @@ -169,14 +182,22 @@ namespace nla { // IF_VERBOSE(0, verbose_stream() << "factored " << q << " : " << vars << "\n"); term t; + rational lc(1); + auto ql = q; + while (!ql.is_val()) { + lc = lcm(lc, denominator(ql.hi().val())); + ql = ql.lo(); + } + lc = lcm(denominator(ql.val()), lc); + while (!q.is_val()) { - t.add_monomial(q.hi().val(), q.var()); + t.add_monomial(lc*q.hi().val(), q.var()); q = q.lo(); } vector ineqs; for (auto v : vars) ineqs.push_back(ineq(v, llc::EQ, rational::zero())); - ineqs.push_back(ineq(t, llc::EQ, -q.val())); + ineqs.push_back(ineq(t, llc::EQ, -lc*q.val())); for (auto const& i : ineqs) if (c().ineq_holds(i)) return false; @@ -189,41 +210,45 @@ namespace nla { return true; } - - void grobner::add_dependencies(new_lemma& lemma, const dd::solver::equation& eq) { - lp::explanation ex; + void grobner::explain(dd::solver::equation const& eq, lp::explanation& exp) { u_dependency_manager dm; vector lv; dm.linearize(eq.dep(), lv); for (unsigned ci : lv) - ex.push_back(ci); - lemma &= ex; + exp.push_back(ci); } - void grobner::configure() { + + void grobner::add_dependencies(new_lemma& lemma, const dd::solver::equation& eq) { + lp::explanation exp; + explain(eq, exp); + lemma &= exp; + } + + bool grobner::configure() { m_solver.reset(); try { set_level2var(); TRACE("grobner", tout << "base vars: "; for (lpvar j : c().active_var_set()) - if (m_lar_solver.is_base(j)) + if (lra.is_base(j)) tout << "j" << j << " "; tout << "\n"); for (lpvar j : c().active_var_set()) { - if (m_lar_solver.is_base(j)) - add_row(m_lar_solver.basic2row(j)); + if (lra.is_base(j)) + add_row(lra.basic2row(j)); if (c().is_monic_var(j) && c().var_is_fixed(j)) add_fixed_monic(j); } } - catch (...) { + catch (dd::pdd_manager::mem_out) { IF_VERBOSE(2, verbose_stream() << "pdd throw\n"); - return; + return false; } TRACE("grobner", m_solver.display(tout)); - + #if 0 IF_VERBOSE(2, m_pdd_grobner.display(verbose_stream())); dd::pdd_eval eval(m_pdd_manager); @@ -239,14 +264,16 @@ namespace nla { struct dd::solver::config cfg; cfg.m_max_steps = m_solver.equations().size(); - cfg.m_max_simplified = c().m_nla_settings.grobner_max_simplified; - cfg.m_eqs_growth = c().m_nla_settings.grobner_eqs_growth; - cfg.m_expr_size_growth = c().m_nla_settings.grobner_expr_size_growth; - cfg.m_expr_degree_growth = c().m_nla_settings.grobner_expr_degree_growth; - cfg.m_number_of_conflicts_to_report = c().m_nla_settings.grobner_number_of_conflicts_to_report; + cfg.m_max_simplified = c().params().arith_nl_grobner_max_simplified(); + cfg.m_eqs_growth = c().params().arith_nl_grobner_eqs_growth(); + cfg.m_expr_size_growth = c().params().arith_nl_grobner_expr_size_growth(); + cfg.m_expr_degree_growth = c().params().arith_nl_grobner_expr_degree_growth(); + cfg.m_number_of_conflicts_to_report = c().params().arith_nl_grobner_cnfl_to_report(); m_solver.set(cfg); m_solver.adjust_cfg(); m_pdd_manager.set_max_num_nodes(10000); // or something proportional to the number of initial nodes. + + return true; } std::ostream& grobner::diagnose_pdd_miss(std::ostream& out) { @@ -263,30 +290,41 @@ namespace nla { } } - for (unsigned j = 0; j < m_lar_solver.number_of_vars(); ++j) { - if (m_lar_solver.column_has_lower_bound(j) || m_lar_solver.column_has_upper_bound(j)) { + for (unsigned j = 0; j < lra.number_of_vars(); ++j) { + if (lra.column_has_lower_bound(j) || lra.column_has_upper_bound(j)) { out << j << ": ["; - if (m_lar_solver.column_has_lower_bound(j)) out << m_lar_solver.get_lower_bound(j); + if (lra.column_has_lower_bound(j)) out << lra.get_lower_bound(j); out << ".."; - if (m_lar_solver.column_has_upper_bound(j)) out << m_lar_solver.get_upper_bound(j); + if (lra.column_has_upper_bound(j)) out << lra.get_upper_bound(j); out << "]\n"; } } return out; } + bool grobner::equation_is_true(dd::solver::equation const& eq) { + if (any_of(eq.poly().free_vars(), [&](unsigned j) { return lra.column_is_free(j); })) + return true; + dd::pdd_eval eval; + eval.var2val() = [&](unsigned j){ return val(j); }; + return eval(eq.poly()) == 0; + } + + bool grobner::is_conflicting(const dd::solver::equation& e) { + if (equation_is_true(e)) + return false; auto& di = c().m_intervals.get_dep_intervals(); - dd::pdd_interval eval(di); - eval.var2interval() = [this](lpvar j, bool deps, scoped_dep_interval& a) { + dd::pdd_interval evali(di); + evali.var2interval() = [this](lpvar j, bool deps, scoped_dep_interval& a) { if (deps) c().m_intervals.set_var_interval(j, a); else c().m_intervals.set_var_interval(j, a); }; scoped_dep_interval i(di), i_wd(di); - eval.get_interval(e.poly(), i); + evali.get_interval(e.poly(), i); if (!di.separated_from_zero(i)) { TRACE("grobner", m_solver.display(tout << "not separated from 0 ", e) << "\n"; - eval.get_interval_distributed(e.poly(), i); + evali.get_interval_distributed(e.poly(), i); tout << "separated from 0: " << di.separated_from_zero(i) << "\n"; for (auto j : e.poly().free_vars()) { scoped_dep_interval a(di); @@ -294,10 +332,17 @@ namespace nla { c().m_intervals.display(tout << "j" << j << " ", a); tout << " "; } tout << "\n"); - + + + if (add_horner_conflict(e)) + return true; +#if 0 + if (add_nla_conflict(e)) + return true; +#endif return false; } - eval.get_interval(e.poly(), i_wd); + evali.get_interval(e.poly(), i_wd); std::function f = [this](const lp::explanation& e) { new_lemma lemma(m_core, "pdd"); lemma &= e; @@ -307,25 +352,67 @@ namespace nla { return true; } else { +#if 0 + if (add_nla_conflict(e)) + return true; +#endif TRACE("grobner", m_solver.display(tout << "no conflict ", e) << "\n"); return false; } } - bool grobner::propagate_bounds(const dd::solver::equation& e) { - return false; - // TODO - auto& di = c().m_intervals.get_dep_intervals(); - dd::pdd_interval eval(di); - eval.var2interval() = [this](lpvar j, bool deps, scoped_dep_interval& a) { - if (deps) c().m_intervals.set_var_interval(j, a); - else c().m_intervals.set_var_interval(j, a); - }; - scoped_dep_interval i(di), i_wd(di); - eval.get_interval(e.poly(), i); - return false; + bool grobner::propagate_linear_equations() { + unsigned changed = 0; + m_mon2var.clear(); + for (auto const& m : c().emons()) + m_mon2var[m.vars()] = m.var(); + + for (auto eq : m_solver.equations()) + if (propagate_linear_equations(*eq)) + ++changed; + return changed > 0; + } + + bool grobner::propagate_linear_equations(dd::solver::equation const& e) { + if (equation_is_true(e)) + return false; + rational value(0); + for (auto const& [coeff, vars] : e.poly()) { + if (vars.empty()) + value += coeff; + else if (vars.size() == 1) + value += coeff*val(vars[0]); + else if (m_mon2var.find(vars) == m_mon2var.end()) + return false; + else + value += coeff*val(m_mon2var.find(vars)->second); + } + if (value == 0) + return false; + + rational lc(1); + for (auto const& [coeff, vars] : e.poly()) + lc = lcm(denominator(coeff), lc); + + vector> coeffs; + rational offset(0); + + for (auto const& [coeff, vars] : e.poly()) { + if (vars.size() == 0) + offset -= lc*coeff; + else if (vars.size() == 1) + coeffs.push_back({lc*coeff, vars[0]}); + else + coeffs.push_back({lc*coeff, m_mon2var.find(vars)->second}); + } + + lp::lpvar j = c().lra.add_term(coeffs, UINT_MAX); + c().lra.update_column_type_and_bound(j, lp::lconstraint_kind::EQ, offset, e.dep()); + c().m_check_feasible = true; + return true; } + void grobner::add_var_and_its_factors_to_q_and_collect_new_rows(lpvar j, svector & q) { if (c().active_var_set_contains(j)) return; @@ -339,18 +426,23 @@ namespace nla { if (c().var_is_fixed(j)) return; - const auto& matrix = m_lar_solver.A_r(); + const auto& matrix = lra.A_r(); for (auto & s : matrix.m_columns[j]) { unsigned row = s.var(); if (m_rows.contains(row)) continue; m_rows.insert(row); - unsigned k = m_lar_solver.get_base_column_in_row(row); - if (m_lar_solver.column_is_free(k) && k != j) + unsigned k = lra.get_base_column_in_row(row); + // grobner bassis does not know about integer constraints + if (lra.column_is_free(k) && !m_add_all_eqs && k != j) + continue; + // a free column over the reals can be assigned + if (lra.column_is_free(k) && k != j && !lra.var_is_int(k)) continue; - CTRACE("grobner", matrix.m_rows[row].size() > c().m_nla_settings.grobner_row_length_limit, - tout << "ignore the row " << row << " with the size " << matrix.m_rows[row].size() << "\n";); - if (matrix.m_rows[row].size() > c().m_nla_settings.grobner_row_length_limit) + CTRACE("grobner", matrix.m_rows[row].size() > c().params().arith_nl_grobner_row_length_limit(), + tout << "ignore the row " << row << " with the size " << matrix.m_rows[row].size() << "\n";); + // limits overhead of grobner equations, unless this is for extracting a complete COI of the non-satisfied subset. + if (!m_add_all_eqs && matrix.m_rows[row].size() > c().params().arith_nl_horner_row_length_limit()) continue; for (auto& rc : matrix.m_rows[row]) add_var_and_its_factors_to_q_and_collect_new_rows(rc.var(), q); @@ -358,11 +450,10 @@ namespace nla { } const rational& grobner::val_of_fixed_var_with_deps(lpvar j, u_dependency*& dep) { - unsigned lc, uc; - m_lar_solver.get_bound_constraint_witnesses_for_column(j, lc, uc); - dep = c().m_intervals.mk_join(dep, c().m_intervals.mk_leaf(lc)); - dep = c().m_intervals.mk_join(dep, c().m_intervals.mk_leaf(uc)); - return m_lar_solver.column_lower_bound(j).x; + auto* d = lra.get_bound_constraint_witnesses_for_column(j); + if (d) + dep = c().m_intervals.mk_join(dep, d); + return lra.column_lower_bound(j).x; } dd::pdd grobner::pdd_expr(const rational& coeff, lpvar j, u_dependency*& dep) { @@ -373,12 +464,12 @@ namespace nla { while (!vars.empty()) { j = vars.back(); vars.pop_back(); - if (c().m_nla_settings.grobner_subs_fixed > 0 && c().var_is_fixed_to_zero(j)) { + if (c().params().arith_nl_grobner_subs_fixed() > 0 && c().var_is_fixed_to_zero(j)) { r = m_pdd_manager.mk_val(val_of_fixed_var_with_deps(j, zero_dep)); dep = zero_dep; return r; } - if (c().m_nla_settings.grobner_subs_fixed == 1 && c().var_is_fixed(j)) + if (c().params().arith_nl_grobner_subs_fixed() == 1 && c().var_is_fixed(j)) r *= val_of_fixed_var_with_deps(j, dep); else if (!c().is_monic_var(j)) r *= m_pdd_manager.mk_var(j); @@ -411,7 +502,7 @@ namespace nla { SASSERT(r.hi().is_val()); v = r.var(); rational val = r.hi().val(); - switch (m_lar_solver.get_column_type(v)) { + switch (lra.get_column_type(v)) { case lp::column_type::lower_bound: if (val > 0) num_lo++, lo = v, lc = val; else num_hi++, hi = v, hc = val; break; @@ -477,7 +568,6 @@ namespace nla { add_eq(sum, dep); } - void grobner::find_nl_cluster() { prepare_rows_and_active_vars(); svector q; @@ -498,14 +588,13 @@ namespace nla { } void grobner::prepare_rows_and_active_vars() { - m_rows.clear(); - m_rows.resize(m_lar_solver.row_count()); - c().clear_and_resize_active_var_set(); + m_rows.reset(); + c().clear_active_var_set(); } void grobner::display_matrix_of_m_rows(std::ostream & out) const { - const auto& matrix = m_lar_solver.A_r(); + const auto& matrix = lra.A_r(); out << m_rows.size() << " rows" << "\n"; out << "the matrix\n"; for (const auto & r : matrix.m_rows) @@ -513,7 +602,7 @@ namespace nla { } void grobner::set_level2var() { - unsigned n = m_lar_solver.column_count(); + unsigned n = lra.column_count(); unsigned_vector sorted_vars(n), weighted_vars(n); for (unsigned j = 0; j < n; j++) { sorted_vars[j] = j; @@ -547,4 +636,76 @@ namespace nla { tout << "\n"); } + bool grobner::is_nla_conflict(const dd::solver::equation& eq) { + vector eqs; + eqs.push_back(eq.poly()); + return l_false == c().m_nra.check(eqs); + } + + bool grobner::add_horner_conflict(const dd::solver::equation& eq) { + nex_creator& nc = m_nex_creator; + nc.pop(0); + nex_creator::sum_factory sum(nc); + u_map var2nex; + for (auto v : eq.poly().free_vars()) + var2nex.insert(v, nc.mk_var(v)); + unsigned mx = 0; + for (auto v : eq.poly().free_vars()) + mx = std::max(v, mx); + nc.set_number_of_vars(mx + 1); + for (auto const& [coeff, vars] : eq.poly()) { + switch (vars.size()) { + case 0: + sum += nc.mk_scalar(coeff); + break; + case 1: + sum += nc.mk_mul(coeff, var2nex[vars[0]]); + break; + default: + nc.m_mk_mul.reset(); + nc.m_mk_mul *= coeff; + for (auto v : vars) + nc.m_mk_mul *= var2nex[v]; + sum += nc.m_mk_mul.mk(); + break; + } + } + nex* e = nc.simplify(sum.mk()); + if (e->get_degree() < 2 || !e->is_sum()) + return false; + + auto dep = eq.dep(); + cross_nested cn( + [this, dep](const nex* n) { return c().m_intervals.check_nex(n, dep); }, + [this](unsigned j) { return c().var_is_fixed(j); }, + [this]() { return c().random(); }, nc); + cn.run(to_sum(e)); + bool ret = cn.done(); + return ret; + } + + bool grobner::add_nla_conflict(const dd::solver::equation& eq) { + if (is_nla_conflict(eq)) { + new_lemma lemma(m_core,"nla-conflict"); + lp::explanation exp; + explain(eq, exp); + lemma &= exp; + return true; + } + return false; + } + + + void grobner::check_missing_propagation(const dd::solver::equation& e) { + bool is_confl = is_nla_conflict(e); + CTRACE("grobner", is_confl, m_solver.display(tout << "missed conflict ", e);); + if (is_confl) { + IF_VERBOSE(2, verbose_stream() << "missed conflict\n"); + return; + } + //lbool r = c().m_nra.check_tight(e.poly()); + //CTRACE("grobner", r == l_false, m_solver.display(tout << "tight equality ", e);); + } + + } diff --git a/src/math/lp/nla_grobner.h b/src/math/lp/nla_grobner.h index 902ad3a468e..be5f06136e0 100644 --- a/src/math/lp/nla_grobner.h +++ b/src/math/lp/nla_grobner.h @@ -12,7 +12,7 @@ #include "math/lp/nla_intervals.h" #include "math/lp/nex.h" #include "math/lp/cross_nested.h" -#include "math/lp/u_set.h" +#include "util/uint_set.h" #include "math/grobner/pdd_solver.h" namespace nla { @@ -21,28 +21,41 @@ namespace nla { class grobner : common { dd::pdd_manager m_pdd_manager; dd::solver m_solver; - lp::lar_solver& m_lar_solver; - lp::u_set m_rows; + lp::lar_solver& lra; + indexed_uint_set m_rows; + unsigned m_quota = 0; + unsigned m_delay_base = 0; + unsigned m_delay = 0; + bool m_add_all_eqs = false; + std::unordered_map m_mon2var; lp::lp_settings& lp_settings(); // solving bool is_conflicting(); - bool is_conflicting(const dd::solver::equation& eq); - - bool propagate_bounds(); - bool propagate_bounds(const dd::solver::equation& eq); + bool is_conflicting(dd::solver::equation const& eq); bool propagate_eqs(); - bool propagate_fixed(const dd::solver::equation& eq); + bool propagate_fixed(dd::solver::equation const& eq); bool propagate_factorization(); - bool propagate_factorization(const dd::solver::equation& eq); - - void add_dependencies(new_lemma& lemma, const dd::solver::equation& eq); + bool propagate_factorization(dd::solver::equation const& eq); + + bool propagate_linear_equations(); + bool propagate_linear_equations(dd::solver::equation const& eq); + + void add_dependencies(new_lemma& lemma, dd::solver::equation const& eq); + void explain(dd::solver::equation const& eq, lp::explanation& exp); + + bool add_horner_conflict(dd::solver::equation const& eq); + bool is_nla_conflict(dd::solver::equation const& eq); + bool add_nla_conflict(dd::solver::equation const& eq); + void check_missing_propagation(dd::solver::equation const& eq); + + bool equation_is_true(dd::solver::equation const& eq); // setup - void configure(); + bool configure(); void set_level2var(); void find_nl_cluster(); void prepare_rows_and_active_vars(); @@ -60,5 +73,6 @@ namespace nla { public: grobner(core *core); void operator()(); + dd::solver::equation_vector const& core_equations(bool all_eqs); }; } diff --git a/src/math/lp/nla_intervals.cpp b/src/math/lp/nla_intervals.cpp index 4ffbcb7e3fc..791251feedb 100644 --- a/src/math/lp/nla_intervals.cpp +++ b/src/math/lp/nla_intervals.cpp @@ -4,6 +4,11 @@ #include "util/mpq.h" namespace nla { + +intervals::intervals(core* c, reslimit& lim): + m_dep_intervals(c->lra.dep_manager(), lim), + m_core(c) {} + typedef enum dep_intervals::with_deps_t e_with_deps; const nex* intervals::get_inf_interval_child(const nex_sum& e) const { @@ -173,7 +178,7 @@ lp::lar_term intervals::expression_to_normalized_term(const nex_sum* e, rational // where m_terms[k] corresponds to the returned lpvar lpvar intervals::find_term_column(const lp::lar_term & norm_t, rational& a) const { std::pair a_j; - if (m_core->m_lar_solver.fetch_normalized_term_column(norm_t, a_j)) { + if (m_core->lra.fetch_normalized_term_column(norm_t, a_j)) { a /= a_j.first; return a_j.second; } @@ -206,19 +211,10 @@ void intervals::set_zero_interval_deps_for_mult(interval& a) { a.m_upper_dep = a.m_lower_dep; } -u_dependency *intervals::mk_dep(lp::constraint_index ci) { - return m_dep_intervals.mk_leaf(ci); -} - -u_dependency *intervals::mk_dep(const lp::explanation& expl) { +u_dependency* intervals::mk_dep(const lp::explanation& expl) { u_dependency * r = nullptr; - for (auto p : expl) { - if (r == nullptr) { - r = m_dep_intervals.mk_leaf(p.ci()); - } else { - r = m_dep_intervals.mk_join(r, m_dep_intervals.mk_leaf(p.ci())); - } - } + for (auto p : expl) + r = m_dep_intervals.mk_join(r, m_dep_intervals.mk_leaf(p.ci())); return r; } @@ -249,25 +245,25 @@ std::ostream& intervals::display(std::ostream& out, const interval& i) const { template void intervals::set_var_interval(lpvar v, interval& b) { TRACE("nla_intervals_details", m_core->print_var(v, tout) << "\n";); - lp::constraint_index ci; + u_dependency* dep = nullptr; rational val; bool is_strict; - if (ls().has_lower_bound(v, ci, val, is_strict)) { + if (ls().has_lower_bound(v, dep, val, is_strict)) { m_dep_intervals.set_lower(b, val); m_dep_intervals.set_lower_is_open(b, is_strict); m_dep_intervals.set_lower_is_inf(b, false); - if (wd == e_with_deps::with_deps) b.m_lower_dep = mk_dep(ci); + if (wd == e_with_deps::with_deps) b.m_lower_dep = dep; } else { m_dep_intervals.set_lower_is_open(b, true); m_dep_intervals.set_lower_is_inf(b, true); if (wd == e_with_deps::with_deps) b.m_lower_dep = nullptr; } - if (ls().has_upper_bound(v, ci, val, is_strict)) { + if (ls().has_upper_bound(v, dep, val, is_strict)) { m_dep_intervals.set_upper(b, val); m_dep_intervals.set_upper_is_open(b, is_strict); m_dep_intervals.set_upper_is_inf(b, false); - if (wd == e_with_deps::with_deps) b.m_upper_dep = mk_dep(ci); + if (wd == e_with_deps::with_deps) b.m_upper_dep = dep; } else { m_dep_intervals.set_upper_is_open(b, true); @@ -303,7 +299,7 @@ bool intervals::interval_from_term(const nex& e, scoped_dep_interval& i) { m_dep_intervals.set(i, bi); TRACE("nla_intervals", - m_core->m_lar_solver.print_column_info(j, tout) << "\n"; + m_core->lra.print_column_info(j, tout) << "\n"; tout << "a=" << a << ", b=" << b << "\n"; tout << e << ", interval = "; display(tout, i);); return true; @@ -476,9 +472,9 @@ bool intervals::interval_of_expr(const nex* e, unsigned p, scoped_dep_interval& } -lp::lar_solver& intervals::ls() { return m_core->m_lar_solver; } +lp::lar_solver& intervals::ls() { return m_core->lra; } -const lp::lar_solver& intervals::ls() const { return m_core->m_lar_solver; } +const lp::lar_solver& intervals::ls() const { return m_core->lra; } } // end of nla namespace diff --git a/src/math/lp/nla_intervals.h b/src/math/lp/nla_intervals.h index 0545e993369..514049aca36 100644 --- a/src/math/lp/nla_intervals.h +++ b/src/math/lp/nla_intervals.h @@ -25,16 +25,13 @@ class intervals { public: typedef dep_intervals::interval interval; private: - u_dependency* mk_dep(lp::constraint_index ci); u_dependency* mk_dep(lp::explanation const&); lp::lar_solver& ls(); const lp::lar_solver& ls() const; public: - intervals(core* c, reslimit& lim) : - m_dep_intervals(lim), - m_core(c) - {} + intervals(core* c, reslimit& lim); + dep_intervals& get_dep_intervals() { return m_dep_intervals; } u_dependency* mk_join(u_dependency* a, u_dependency* b) { return m_dep_intervals.mk_join(a, b); } u_dependency* mk_leaf(lp::constraint_index ci) { return m_dep_intervals.mk_leaf(ci); } diff --git a/src/math/lp/nla_monotone_lemmas.h b/src/math/lp/nla_monotone_lemmas.h index d13f588e889..fb9c469a84a 100644 --- a/src/math/lp/nla_monotone_lemmas.h +++ b/src/math/lp/nla_monotone_lemmas.h @@ -7,16 +7,14 @@ --*/ #pragma once namespace nla { -class core; -class monotone : common { -public: - monotone(core *core); - void monotonicity_lemma(); -private: - void monotonicity_lemma(monic const& m); - void monotonicity_lemma_gt(const monic& m); - void monotonicity_lemma_lt(const monic& m); - std::vector get_sorted_key(const monic& rm) const; - vector> get_sorted_key_with_rvars(const monic& a) const; -}; + class core; + class monotone : common { + public: + monotone(core *core); + void monotonicity_lemma(); + private: + void monotonicity_lemma(monic const& m); + void monotonicity_lemma_gt(const monic& m); + void monotonicity_lemma_lt(const monic& m); + }; } diff --git a/src/math/lp/nla_order_lemmas.cpp b/src/math/lp/nla_order_lemmas.cpp index 94ddc4d9b04..902e208fc78 100644 --- a/src/math/lp/nla_order_lemmas.cpp +++ b/src/math/lp/nla_order_lemmas.cpp @@ -19,7 +19,7 @@ typedef lp::lar_term term; // a > b && c > 0 => ac > bc void order::order_lemma() { TRACE("nla_solver", ); - if (!c().m_nla_settings.run_order) { + if (!c().params().arith_nl_order()) { TRACE("nla_solver", tout << "not generating order lemmas\n";); return; } @@ -116,7 +116,7 @@ void order::order_lemma_on_factor_binomial_rm(const monic& ac, bool k, const mon tout << "bd=" << pp_mon_with_vars(_(), bd) << "\n"; ); factor d(_().m_evars.find(ac.vars()[k]).var(), factor_type::VAR); - factor b(false); + factor b; if (c().divide(bd, d, b)) { order_lemma_on_binomial_ac_bd(ac, k, bd, b, d.var()); } @@ -192,7 +192,7 @@ bool order::order_lemma_on_ac_and_bc(const monic& rm_ac, tout << "rm_bd = " << pp_mon_with_vars(_(), rm_bd) << "\n"; tout << "ac_f[k] = "; c().print_factor_with_vars(ac_f[k], tout);); - factor b(false); + factor b; return c().divide(rm_bd, ac_f[k], b) && order_lemma_on_ac_and_bc_and_factors(rm_ac, ac_f[!k], ac_f[k], rm_bd, b); diff --git a/src/math/lp/nla_powers.cpp b/src/math/lp/nla_powers.cpp index f389aad93c7..4e521bb0130 100644 --- a/src/math/lp/nla_powers.cpp +++ b/src/math/lp/nla_powers.cpp @@ -80,12 +80,12 @@ namespace nla { lbool powers::check(lpvar r, lpvar x, lpvar y, vector& lemmas) { TRACE("nla", tout << r << " == " << x << "^" << y << "\n"); + core& c = m_core; if (x == null_lpvar || y == null_lpvar || r == null_lpvar) return l_undef; - if (lp::tv::is_term(x) || lp::tv::is_term(y) || lp::tv::is_term(r)) + if (c.lra.column_has_term(x) || c.lra.column_has_term(y) || c.lra.column_has_term(r)) return l_undef; - core& c = m_core; if (c.use_nra_model()) return l_undef; diff --git a/src/math/lp/nla_settings.h b/src/math/lp/nla_settings.h deleted file mode 100644 index ec11ea5b234..00000000000 --- a/src/math/lp/nla_settings.h +++ /dev/null @@ -1,46 +0,0 @@ -/*++ -Copyright (c) 2017 Microsoft Corporation - -Author: - - Lev Nachmanson (levnach) - ---*/ - -#pragma once -namespace nla { - struct nla_settings { - bool run_order = true; - bool run_tangents = true; - - // horner fields - bool run_horner = true; - unsigned horner_frequency = 4; - unsigned horner_row_length_limit = 10; - unsigned horner_subs_fixed = 2; - - - // grobner fields - bool run_grobner = true; - unsigned grobner_row_length_limit = 50; - unsigned grobner_subs_fixed = 1; - unsigned grobner_eqs_growth = 10; - unsigned grobner_tree_size_growth = 2; - unsigned grobner_expr_size_growth = 2; - unsigned grobner_expr_degree_growth = 2; - unsigned grobner_max_simplified = 10000; - unsigned grobner_number_of_conflicts_to_report = 1; - unsigned grobner_quota = 0; - unsigned grobner_frequency = 4; - - - // nra fields - bool run_nra = false; - - // expensive patching - bool expensive_patching = false; - - nla_settings() {} - - }; -} diff --git a/src/math/lp/nla_solver.cpp b/src/math/lp/nla_solver.cpp index bd0f1953ce0..f55eec6a6f7 100644 --- a/src/math/lp/nla_solver.cpp +++ b/src/math/lp/nla_solver.cpp @@ -16,9 +16,6 @@ #include "math/polynomial/algebraic_numbers.h" namespace nla { - - nla_settings& solver::settings() { return m_core->m_nla_settings; } - void solver::add_monic(lpvar v, unsigned sz, lpvar const* vs) { m_core->add_monic(v, sz, vs); } @@ -45,8 +42,12 @@ namespace nla { bool solver::need_check() { return m_core->has_relevant_monomial(); } - lbool solver::check(vector& l) { - return m_core->check(l); + lbool solver::check() { + return m_core->check(); + } + + void solver::propagate() { + m_core->propagate(); } void solver::push(){ @@ -57,8 +58,8 @@ namespace nla { m_core->pop(n); } - solver::solver(lp::lar_solver& s, reslimit& limit): - m_core(alloc(core, s, limit)) { + solver::solver(lp::lar_solver& s, params_ref const& p, reslimit& limit): + m_core(alloc(core, s, p, limit)) { } bool solver::influences_nl_var(lpvar j) const { @@ -82,22 +83,45 @@ namespace nla { nlsat::anum_manager& solver::am() { return m_core->m_nra.am(); } - nlsat::anum const& solver::am_value(lp::var_index v) const { + nlsat::anum const& solver::am_value(lp::lpvar v) const { SASSERT(use_nra_model()); return m_core->m_nra.value(v); } - void solver::collect_statistics(::statistics & st) { - m_core->collect_statistics(st); + scoped_anum& solver::tmp1() { + SASSERT(use_nra_model()); + return m_core->m_nra.tmp1(); } + scoped_anum& solver::tmp2() { + SASSERT(use_nra_model()); + return m_core->m_nra.tmp2(); + } + + // ensure r = x^y, add abstraction/refinement lemmas - lbool solver::check_power(lpvar r, lpvar x, lpvar y, vector& lemmas) { - return m_core->check_power(r, x, y, lemmas); + lbool solver::check_power(lpvar r, lpvar x, lpvar y) { + return m_core->check_power(r, x, y); + } + + void solver::check_bounded_divisions() { + m_core->check_bounded_divisions(); + } + + vector const& solver::lemmas() const { + return m_core->lemmas(); + } + + vector const& solver::literals() const { + return m_core->literals(); + } + + vector const& solver::equalities() const { + return m_core->equalities(); } - void solver::check_bounded_divisions(vector& lemmas) { - m_core->check_bounded_divisions(lemmas); + vector const& solver::fixed_equalities() const { + return m_core->fixed_equalities(); } } diff --git a/src/math/lp/nla_solver.h b/src/math/lp/nla_solver.h index d04ff8e516c..53e62b4f074 100644 --- a/src/math/lp/nla_solver.h +++ b/src/math/lp/nla_solver.h @@ -13,7 +13,6 @@ Copyright (c) 2017 Microsoft Corporation #include "util/params.h" #include "math/lp/lar_solver.h" #include "math/lp/monic.h" -#include "math/lp/nla_settings.h" #include "math/lp/nla_core.h" namespace nra { class solver; @@ -25,28 +24,35 @@ namespace nla { core* m_core; public: - solver(lp::lar_solver& s, reslimit& limit); + solver(lp::lar_solver& s, params_ref const& p, reslimit& limit); ~solver(); - + const auto& monics_with_changed_bounds() const { return m_core->monics_with_changed_bounds(); } void add_monic(lpvar v, unsigned sz, lpvar const* vs); void add_idivision(lpvar q, lpvar x, lpvar y); void add_rdivision(lpvar q, lpvar x, lpvar y); void add_bounded_division(lpvar q, lpvar x, lpvar y); - void check_bounded_divisions(vector&); + void check_bounded_divisions(); void set_relevant(std::function& is_relevant); - nla_settings& settings(); void push(); void pop(unsigned scopes); bool need_check(); - lbool check(vector&); - lbool check_power(lpvar r, lpvar x, lpvar y, vector&); + lbool check(); + void propagate(); + void simplify() { m_core->simplify(); } + lbool check_power(lpvar r, lpvar x, lpvar y); bool is_monic_var(lpvar) const; bool influences_nl_var(lpvar) const; std::ostream& display(std::ostream& out) const; bool use_nra_model() const; core& get_core(); nlsat::anum_manager& am(); - nlsat::anum const& am_value(lp::var_index v) const; - void collect_statistics(::statistics & st); + nlsat::anum const& am_value(lp::lpvar v) const; + scoped_anum& tmp1(); + scoped_anum& tmp2(); + vector const& lemmas() const; + vector const& literals() const; + vector const& fixed_equalities() const; + vector const& equalities() const; + bool should_check_feasible() const { return m_core->should_check_feasible(); } }; } diff --git a/src/math/lp/nla_tangent_lemmas.cpp b/src/math/lp/nla_tangent_lemmas.cpp index 299d8031fdf..56ebbfacc6c 100644 --- a/src/math/lp/nla_tangent_lemmas.cpp +++ b/src/math/lp/nla_tangent_lemmas.cpp @@ -75,8 +75,8 @@ class tangent_imp { c().negate_relation(lemma, m_jy, m_y.rat_sign()*pl.y); #if Z3DEBUG SASSERT(c().val(m_x) == m_xy.x && c().val(m_y) == m_xy.y); - int mult_sign = nla::rat_sign(pl.x - m_xy.x)*nla::rat_sign(pl.y - m_xy.y); - SASSERT((mult_sign == 1) == m_below); + // int mult_sign = nla::rat_sign(pl.x - m_xy.x)*nla::rat_sign(pl.y - m_xy.y); + SASSERT((nla::rat_sign(pl.x - m_xy.x)*nla::rat_sign(pl.y - m_xy.y) == 1) == m_below); // If "mult_sign is 1" then (a - x)(b-y) > 0 and ab - bx - ay + xy > 0 // or -ab + bx + ay < xy or -ay - bx + xy > -ab // val(j) stands for xy. So, finally we have -ay - bx + j > - ab @@ -186,7 +186,7 @@ tangents::tangents(core * c) : common(c) {} void tangents::tangent_lemma() { factorization bf(nullptr); const monic* m = nullptr; - if (c().m_nla_settings.run_tangents && c().find_bfc_to_refine(m, bf)) { + if (c().params().arith_nl_tangents() && c().find_bfc_to_refine(m, bf)) { lpvar j = m->var(); tangent_imp tangent(point(val(bf[0]), val(bf[1])), c().val(j), *m, bf, *this); tangent(); diff --git a/src/math/lp/nla_types.h b/src/math/lp/nla_types.h index 8169266ccf6..f0f79665223 100644 --- a/src/math/lp/nla_types.h +++ b/src/math/lp/nla_types.h @@ -22,8 +22,10 @@ namespace nla { typedef lp::lconstraint_kind llc; typedef lp::constraint_index lpci; typedef lp::explanation expl_set; - typedef lp::var_index lpvar; + typedef lp::lpvar lpvar; const lpvar null_lpvar = UINT_MAX; + + inline int rat_sign(const rational& r) { return r.is_pos()? 1 : ( r.is_neg()? -1 : 0); } inline rational rrat_sign(const rational& r) { return rational(rat_sign(r)); } diff --git a/src/math/lp/nra_solver.cpp b/src/math/lp/nra_solver.cpp index 1f4e0b76abb..9c9db4e4114 100644 --- a/src/math/lp/nra_solver.cpp +++ b/src/math/lp/nra_solver.cpp @@ -9,7 +9,7 @@ #include "math/polynomial/polynomial.h" #include "math/polynomial/algebraic_numbers.h" #include "util/map.h" -#include "math/lp/u_set.h" +#include "util/uint_set.h" #include "math/lp/nla_core.h" @@ -18,18 +18,20 @@ namespace nra { typedef nla::mon_eq mon_eq; typedef nla::variable_map_type variable_map_type; + struct solver::imp { - lp::lar_solver& s; + lp::lar_solver& lra; reslimit& m_limit; params_ref m_params; u_map m_lp2nl; // map from lar_solver variables to nlsat::solver variables - lp::u_set m_term_set; + indexed_uint_set m_term_set; scoped_ptr m_nlsat; - scoped_ptr m_zero; - mutable variable_map_type m_variable_values; // current model - nla::core& m_nla_core; + scoped_ptr m_values; // values provided by LRA solver + scoped_ptr m_tmp1, m_tmp2; + nla::core& m_nla_core; + imp(lp::lar_solver& s, reslimit& lim, params_ref const& p, nla::core& nla_core): - s(s), + lra(s), m_limit(lim), m_params(p), m_nla_core(nla_core) {} @@ -38,6 +40,91 @@ struct solver::imp { return m_nla_core.m_to_refine.size() != 0; } + indexed_uint_set m_mon_set, m_constraint_set; + + struct occurs { + unsigned_vector constraints; + unsigned_vector monics; + unsigned_vector terms; + }; + + void init_cone_of_influence() { + indexed_uint_set visited; + unsigned_vector todo; + vector var2occurs; + m_term_set.reset(); + m_mon_set.reset(); + m_constraint_set.reset(); + + for (auto ci : lra.constraints().indices()) { + auto const& c = lra.constraints()[ci]; + for (auto const& [coeff, v] : c.coeffs()) { + var2occurs.reserve(v + 1); + var2occurs[v].constraints.push_back(ci); + } + } + + for (auto const& m : m_nla_core.emons()) { + for (auto v : m.vars()) { + var2occurs.reserve(v + 1); + var2occurs[v].monics.push_back(m.var()); + } + } + + for (const auto *t : lra.terms() ) { + for (auto const iv : *t) { + auto v = iv.j(); + var2occurs.reserve(v + 1); + var2occurs[v].terms.push_back(t->j()); + } + } + + for (auto const& m : m_nla_core.m_to_refine) + todo.push_back(m); + + for (unsigned i = 0; i < todo.size(); ++i) { + auto v = todo[i]; + if (visited.contains(v)) + continue; + visited.insert(v); + var2occurs.reserve(v + 1); + for (auto ci : var2occurs[v].constraints) { + m_constraint_set.insert(ci); + auto const& c = lra.constraints()[ci]; + for (auto const& [coeff, w] : c.coeffs()) + todo.push_back(w); + } + for (auto w : var2occurs[v].monics) + todo.push_back(w); + + for (auto ti : var2occurs[v].terms) { + for (auto iv : lra.get_term(ti)) + todo.push_back(iv.j()); + todo.push_back(ti); + } + + if (lra.column_has_term(v)) { + m_term_set.insert(v); + for (auto kv : lra.get_term(v)) + todo.push_back(kv.j()); + } + + if (m_nla_core.is_monic_var(v)) { + m_mon_set.insert(v); + for (auto w : m_nla_core.emons()[v]) + todo.push_back(w); + } + } + } + + void reset() { + m_values = nullptr; + m_tmp1 = nullptr; m_tmp2 = nullptr; + m_nlsat = alloc(nlsat::solver, m_limit, m_params, false); + m_values = alloc(scoped_anum_vector, am()); + m_term_set.reset(); + m_lp2nl.reset(); + } /** \brief one-shot nlsat check. @@ -50,30 +137,29 @@ struct solver::imp { TBD: use partial model from lra_solver to prime the state of nlsat_solver. TBD: explore more incremental ways of applying nlsat (using assumptions) */ - lbool check() { + lbool check() { SASSERT(need_check()); - m_zero = nullptr; - m_nlsat = alloc(nlsat::solver, m_limit, m_params, false); - m_zero = alloc(scoped_anum, am()); - m_term_set.clear(); - m_lp2nl.reset(); + reset(); vector core; + init_cone_of_influence(); // add linear inequalities from lra_solver - for (lp::constraint_index ci : s.constraints().indices()) { + for (auto ci : m_constraint_set) add_constraint(ci); - } - + // add polynomial definitions. - for (auto const& m : m_nla_core.emons()) - add_monic_eq(m); - for (unsigned i : m_term_set) + for (auto const& m : m_mon_set) + add_monic_eq(m_nla_core.emons()[m]); + + // add term definitions. + for (unsigned i : m_term_set) add_term(i); - // TBD: add variable bounds? + + TRACE("nra", m_nlsat->display(tout)); lbool r = l_undef; try { - r = m_nlsat->check(); + r = m_nlsat->check(); } catch (z3_exception&) { if (m_limit.is_canceled()) { @@ -83,15 +169,29 @@ struct solver::imp { throw; } } - TRACE("nra", + TRACE("nra", m_nlsat->display(tout << r << "\n"); - display(tout); - for (auto kv : m_lp2nl) - tout << "j" << kv.m_key << " := x" << kv.m_value << "\n"; - ); + display(tout); + for (auto [j, x] : m_lp2nl) tout << "j" << j << " := x" << x << "\n";); switch (r) { - case l_true: + case l_true: m_nla_core.set_use_nra_model(true); + lra.init_model(); + for (lp::constraint_index ci : lra.constraints().indices()) + if (!check_constraint(ci)) { + IF_VERBOSE(0, verbose_stream() << "constraint " << ci << " violated\n"; + lra.constraints().display(verbose_stream())); + UNREACHABLE(); + return l_undef; + } + for (auto const& m : m_nla_core.emons()) { + if (!check_monic(m)) { + IF_VERBOSE(0, verbose_stream() << "monic " << m << " violated\n"; + lra.constraints().display(verbose_stream())); + UNREACHABLE(); + return l_undef; + } + } break; case l_false: { lp::explanation ex; @@ -99,7 +199,7 @@ struct solver::imp { for (auto c : core) { unsigned idx = static_cast(static_cast(c) - this); ex.push_back(idx); - TRACE("arith", tout << "ex: " << idx << "\n";); + TRACE("nra", lra.display_constraint(tout << "ex: " << idx << ": ", idx) << "\n";); } nla::new_lemma lemma(m_nla_core, __FUNCTION__); lemma &= ex; @@ -108,16 +208,35 @@ struct solver::imp { } case l_undef: break; - } + } return r; - } + } + + void add_monic_eq_bound(mon_eq const& m) { + if (!lra.column_has_lower_bound(m.var()) && + !lra.column_has_upper_bound(m.var())) + return; + polynomial::manager& pm = m_nlsat->pm(); + svector vars; + for (auto v : m.vars()) + vars.push_back(lp2nl(v)); + auto v = m.var(); + polynomial::monomial_ref m1(pm.mk_monomial(vars.size(), vars.data()), pm); + polynomial::monomial * mls[1] = { m1 }; + polynomial::scoped_numeral_vector coeffs(pm.m()); + coeffs.push_back(mpz(1)); + polynomial::polynomial_ref p(pm.mk_polynomial(1, coeffs.data(), mls), pm); + if (lra.column_has_lower_bound(v)) + add_lb_p(lra.get_lower_bound(v), p, lra.get_column_lower_bound_witness(v)); + if (lra.column_has_upper_bound(v)) + add_ub_p(lra.get_upper_bound(v), p, lra.get_column_upper_bound_witness(v)); + } void add_monic_eq(mon_eq const& m) { polynomial::manager& pm = m_nlsat->pm(); svector vars; - for (auto v : m.vars()) { + for (auto v : m.vars()) vars.push_back(lp2nl(v)); - } polynomial::monomial_ref m1(pm.mk_monomial(vars.size(), vars.data()), pm); polynomial::monomial_ref m2(pm.mk_monomial(lp2nl(m.var()), 1), pm); polynomial::monomial * mls[2] = { m1, m2 }; @@ -132,7 +251,7 @@ struct solver::imp { } void add_constraint(unsigned idx) { - auto& c = s.constraints()[idx]; + auto& c = lra.constraints()[idx]; auto& pm = m_nlsat->pm(); auto k = c.kind(); auto rhs = c.rhs(); @@ -140,9 +259,9 @@ struct solver::imp { auto sz = lhs.size(); svector vars; rational den = denominator(rhs); - for (auto kv : lhs) { - vars.push_back(lp2nl(kv.second)); - den = lcm(den, denominator(kv.first)); + for (auto [coeff, v] : lhs) { + vars.push_back(lp2nl(v)); + den = lcm(den, denominator(coeff)); } vector coeffs; for (auto kv : lhs) { @@ -176,26 +295,122 @@ struct solver::imp { m_nlsat->mk_clause(1, &lit, a); } + bool check_monic(mon_eq const& m) { + scoped_anum val1(am()), val2(am()); + am().set(val1, value(m.var())); + am().set(val2, rational::one().to_mpq()); + for (auto v : m.vars()) + am().mul(val2, value(v), val2); + return am().eq(val1, val2); + } + + bool check_constraint(unsigned idx) { + auto& c = lra.constraints()[idx]; + auto k = c.kind(); + auto offset = -c.rhs(); + auto lhs = c.coeffs(); + + scoped_anum val(am()), mon(am()); + am().set(val, offset.to_mpq()); + for (auto [coeff, v] : lhs) { + am().set(mon, coeff.to_mpq()); + am().mul(mon, value(v), mon); + am().add(val, mon, val); + } + am().set(mon, rational::zero().to_mpq()); + switch (k) { + case lp::lconstraint_kind::LE: + return am().le(val, mon); + case lp::lconstraint_kind::GE: + return am().ge(val, mon); + case lp::lconstraint_kind::LT: + return am().lt(val, mon); + case lp::lconstraint_kind::GT: + return am().gt(val, mon); + case lp::lconstraint_kind::EQ: + return am().eq(val, mon); + default: + UNREACHABLE(); + } + return false; + } + + lbool check(dd::solver::equation_vector const& eqs) { + reset(); + for (auto const& eq : eqs) + add_eq(*eq); + for (auto const& m : m_nla_core.emons()) + if (any_of(m.vars(), [&](lp::lpvar v) { return m_lp2nl.contains(v); })) + add_monic_eq_bound(m); + for (unsigned i : m_term_set) + add_term(i); + for (auto const& [v, w] : m_lp2nl) { + if (lra.column_has_lower_bound(v)) + add_lb(lra.get_lower_bound(v), w, lra.get_column_lower_bound_witness(v)); + if (lra.column_has_upper_bound(v)) + add_ub(lra.get_upper_bound(v), w, lra.get_column_upper_bound_witness(v)); + } + + lbool r = l_undef; + try { + r = m_nlsat->check(); + } + catch (z3_exception&) { + if (m_limit.is_canceled()) { + r = l_undef; + } + else { + throw; + } + } + + switch (r) { + case l_true: + m_nla_core.set_use_nra_model(true); + lra.init_model(); + for (lp::constraint_index ci : lra.constraints().indices()) + if (!check_constraint(ci)) + return l_undef; + for (auto const& m : m_nla_core.emons()) + if (!check_monic(m)) + return l_undef; + break; + case l_false: { + lp::explanation ex; + vector core; + m_nlsat->get_core(core); + u_dependency_manager dm; + vector lv; + for (auto c : core) + dm.linearize(static_cast(c), lv); + for (auto ci : lv) + ex.push_back(ci); + nla::new_lemma lemma(m_nla_core, __FUNCTION__); + lemma &= ex; + break; + } + case l_undef: + break; + } + return r; + } lbool check(vector const& eqs) { - m_zero = nullptr; - m_nlsat = alloc(nlsat::solver, m_limit, m_params, false); - m_zero = alloc(scoped_anum, am()); - m_lp2nl.reset(); - m_term_set.clear(); + reset(); for (auto const& eq : eqs) add_eq(eq); + for (auto const& m : m_nla_core.emons()) + add_monic_eq(m); for (auto const& [v, w] : m_lp2nl) { - auto& ls = m_nla_core.m_lar_solver; - if (ls.column_has_lower_bound(v)) - add_lb(ls.get_lower_bound(v), w); - if (ls.column_has_upper_bound(v)) - add_ub(ls.get_upper_bound(v), w); + if (lra.column_has_lower_bound(v)) + add_lb(lra.get_lower_bound(v), w); + if (lra.column_has_upper_bound(v)) + add_ub(lra.get_upper_bound(v), w); } - + lbool r = l_undef; try { - r = m_nlsat->check(); + r = m_nlsat->check(); } catch (z3_exception&) { if (m_limit.is_canceled()) { @@ -206,85 +421,109 @@ struct solver::imp { } } + if (r == l_true) + return r; + IF_VERBOSE(0, verbose_stream() << "check-nra " << r << "\n"; m_nlsat->display(verbose_stream()); for (auto const& [v, w] : m_lp2nl) { - auto& ls = m_nla_core.m_lar_solver; - if (ls.column_has_lower_bound(v)) - verbose_stream() << w << " >= " << ls.get_lower_bound(v) << "\n"; - if (ls.column_has_upper_bound(v)) - verbose_stream() << w << " <= " << ls.get_upper_bound(v) << "\n"; - }); - - + if (lra.column_has_lower_bound(v)) + verbose_stream() << "x" << w << " >= " << lra.get_lower_bound(v) << "\n"; + if (lra.column_has_upper_bound(v)) + verbose_stream() << "x" << w << " <= " << lra.get_upper_bound(v) << "\n"; + }); + return r; } - - void add_eq(dd::pdd const& eq) { + + void add_eq(dd::solver::equation const& eq) { + add_eq(eq.poly(), eq.dep()); + } + + void add_eq(dd::pdd const& eq, nlsat::assumption a = nullptr) { dd::pdd normeq = eq; rational lc(1); - for (auto const& [c, m] : eq) + for (auto const& [c, m] : eq) lc = lcm(denominator(c), lc); if (lc != 1) normeq *= lc; polynomial::manager& pm = m_nlsat->pm(); polynomial::polynomial_ref p(pdd2polynomial(normeq), pm); - bool is_even[1] = { false }; - polynomial::polynomial* ps[1] = { p }; - nlsat::literal lit = m_nlsat->mk_ineq_literal(nlsat::atom::kind::EQ, 1, ps, is_even); - m_nlsat->mk_clause(1, &lit, nullptr); + bool is_even[1] = {false}; + polynomial::polynomial* ps[1] = {p}; + nlsat::literal lit = m_nlsat->mk_ineq_literal(nlsat::atom::kind::EQ, 1, ps, is_even); + m_nlsat->mk_clause(1, &lit, a); } - void add_lb(lp::impq const& b, unsigned w) { - add_bound(b.x, w, b.y <= 0, b.y > 0 ? nlsat::atom::kind::GT : nlsat::atom::kind::LT); + void add_lb(lp::impq const& b, unsigned w, nlsat::assumption a = nullptr) { + polynomial::manager& pm = m_nlsat->pm(); + polynomial::polynomial_ref p(pm.mk_polynomial(w), pm); + add_lb_p(b, p, a); } - void add_ub(lp::impq const& b, unsigned w) { - add_bound(b.x, w, b.y >= 0, b.y < 0 ? nlsat::atom::kind::LT : nlsat::atom::kind::GT); + + void add_ub(lp::impq const& b, unsigned w, nlsat::assumption a = nullptr) { + polynomial::manager& pm = m_nlsat->pm(); + polynomial::polynomial_ref p(pm.mk_polynomial(w), pm); + add_ub_p(b, p, a); + } + + void add_lb_p(lp::impq const& b, polynomial::polynomial* p, nlsat::assumption a = nullptr) { + add_bound_p(b.x, p, b.y <= 0, b.y > 0 ? nlsat::atom::kind::GT : nlsat::atom::kind::LT, a); + } + + void add_ub_p(lp::impq const& b, polynomial::polynomial* p, nlsat::assumption a = nullptr) { + add_bound_p(b.x, p, b.y >= 0, b.y < 0 ? nlsat::atom::kind::LT : nlsat::atom::kind::GT, a); } // w - bound < 0 // w - bound > 0 - void add_bound(lp::mpq const& bound, unsigned w, bool neg, nlsat::atom::kind k) { + + void add_bound_p(lp::mpq const& bound, polynomial::polynomial* p1, bool neg, nlsat::atom::kind k, nlsat::assumption a = nullptr) { polynomial::manager& pm = m_nlsat->pm(); - polynomial::polynomial_ref p1(pm.mk_polynomial(w), pm); polynomial::polynomial_ref p2(pm.mk_const(bound), pm); polynomial::polynomial_ref p(pm.sub(p1, p2), pm); - polynomial::polynomial* ps[1] = { p }; - bool is_even[1] = { false }; + polynomial::polynomial* ps[1] = {p}; + bool is_even[1] = {false}; nlsat::literal lit = m_nlsat->mk_ineq_literal(k, 1, ps, is_even); if (neg) lit.neg(); - m_nlsat->mk_clause(1, &lit, nullptr); + m_nlsat->mk_clause(1, &lit, a); } - + + void add_bound(lp::mpq const& bound, unsigned w, bool neg, nlsat::atom::kind k, nlsat::assumption a = nullptr) { + polynomial::manager& pm = m_nlsat->pm(); + polynomial::polynomial_ref p(pm.mk_polynomial(w), pm); + add_bound_p(bound, p, neg, k, a); + } + polynomial::polynomial* pdd2polynomial(dd::pdd const& p) { polynomial::manager& pm = m_nlsat->pm(); - if (p.is_val()) + if (p.is_val()) return pm.mk_const(p.val()); polynomial::polynomial_ref lo(pdd2polynomial(p.lo()), pm); polynomial::polynomial_ref hi(pdd2polynomial(p.hi()), pm); unsigned w, v = p.var(); if (!m_lp2nl.find(v, w)) { - w = m_nlsat->mk_var(false); + w = m_nlsat->mk_var(is_int(v)); m_lp2nl.insert(v, w); } polynomial::polynomial_ref vp(pm.mk_polynomial(w, 1), pm); - return pm.add(lo, pm.mul(vp, hi)); + polynomial::polynomial_ref mp(pm.mul(vp, hi), pm); + return pm.add(lo, mp); } - - bool is_int(lp::var_index v) { - return s.var_is_int(v); + + + + bool is_int(lp::lpvar v) { + return lra.var_is_int(v); } - - polynomial::var lp2nl(lp::var_index v) { + polynomial::var lp2nl(lp::lpvar v) { polynomial::var r; if (!m_lp2nl.find(v, r)) { r = m_nlsat->mk_var(is_int(v)); m_lp2nl.insert(v, r); - if (!m_term_set.contains(v) && s.column_corresponds_to_term(v)) { - if (v >= m_term_set.data_size()) - m_term_set.resize(v + 1); + if (!m_term_set.contains(v) && lra.column_has_term(v)) { m_term_set.insert(v); } } @@ -292,18 +531,17 @@ struct solver::imp { } // void add_term(unsigned term_column) { - lp::tv ti = lp::tv::raw(s.column_to_reported_index(term_column)); - const lp::lar_term& t = s.get_term(ti); - // code that creates a polynomial equality between the linear coefficients and + const lp::lar_term& t = lra.get_term(term_column); + // code that creates a polynomial equality between the linear coefficients and // variable representing the term. svector vars; rational den(1); for (lp::lar_term::ival kv : t) { - vars.push_back(lp2nl(kv.column().index())); + vars.push_back(lp2nl(kv.j())); den = lcm(den, denominator(kv.coeff())); } vars.push_back(lp2nl(term_column)); - + vector coeffs; for (auto kv : t) { coeffs.push_back(den * kv.coeff()); @@ -311,24 +549,43 @@ struct solver::imp { coeffs.push_back(-den); polynomial::manager& pm = m_nlsat->pm(); polynomial::polynomial_ref p(pm.mk_linear(coeffs.size(), coeffs.data(), vars.data(), rational(0)), pm); - polynomial::polynomial* ps[1] = { p }; - bool is_even[1] = { false }; - nlsat::literal lit = m_nlsat->mk_ineq_literal(nlsat::atom::kind::EQ, 1, ps, is_even); + polynomial::polynomial* ps[1] = {p}; + bool is_even[1] = {false}; + nlsat::literal lit = m_nlsat->mk_ineq_literal(nlsat::atom::kind::EQ, 1, ps, is_even); m_nlsat->mk_clause(1, &lit, nullptr); } - nlsat::anum const& value(lp::var_index v) const { + nlsat::anum const& value(lp::lpvar v) { polynomial::var pv; if (m_lp2nl.find(v, pv)) return m_nlsat->value(pv); - else - return *m_zero; + else { + for (unsigned w = m_values->size(); w <= v; ++w) { + scoped_anum a(am()); + am().set(a, m_nla_core.val(w).to_mpq()); + m_values->push_back(a); + } + return (*m_values)[v]; + } } nlsat::anum_manager& am() { return m_nlsat->am(); } + scoped_anum& tmp1() { + if (!m_tmp1) + m_tmp1 = alloc(scoped_anum, am()); + return *m_tmp1; + } + + scoped_anum& tmp2() { + if (!m_tmp2) + m_tmp2 = alloc(scoped_anum, am()); + return *m_tmp2; + } + + void updt_params(params_ref& p) { m_params.append(p); } @@ -363,6 +620,10 @@ lbool solver::check(vector const& eqs) { return m_imp->check(eqs); } +lbool solver::check(dd::solver::equation_vector const& eqs) { + return m_imp->check(eqs); +} + bool solver::need_check() { return m_imp->need_check(); } @@ -371,7 +632,7 @@ std::ostream& solver::display(std::ostream& out) const { return m_imp->display(out); } -nlsat::anum const& solver::value(lp::var_index v) const { +nlsat::anum const& solver::value(lp::lpvar v) { return m_imp->value(v); } @@ -379,6 +640,11 @@ nlsat::anum_manager& solver::am() { return m_imp->am(); } +scoped_anum& solver::tmp1() { return m_imp->tmp1(); } + +scoped_anum& solver::tmp2() { return m_imp->tmp2(); } + + void solver::updt_params(params_ref& p) { m_imp->updt_params(p); } diff --git a/src/math/lp/nra_solver.h b/src/math/lp/nra_solver.h index b8863e44b9b..90f022ba6dc 100644 --- a/src/math/lp/nra_solver.h +++ b/src/math/lp/nra_solver.h @@ -6,9 +6,11 @@ #pragma once #include "util/vector.h" #include "math/lp/lp_settings.h" +#include "math/polynomial/algebraic_numbers.h" #include "util/rlimit.h" #include "util/params.h" #include "nlsat/nlsat_solver.h" +#include "math/grobner/pdd_solver.h" #include "math/dd/dd_pdd.h" namespace lp { @@ -18,8 +20,6 @@ namespace lp { namespace nra { - - class solver { struct imp; imp* m_imp; @@ -38,10 +38,15 @@ namespace nra { lbool check(); /** - \breif Check feasibility of equalities modulo bounds constraints on their variables. + \brief Check feasibility of equalities modulo bounds constraints on their variables. */ lbool check(vector const& eqs); + /** + \brief Check feasibility with respect to a set of reduced constraints. + */ + lbool check(dd::solver::equation_vector const& eqs); + /* \brief determine whether nra check is needed. */ @@ -50,10 +55,14 @@ namespace nra { /* \brief Access model. */ - nlsat::anum const& value(lp::var_index v) const; + nlsat::anum const& value(lp::lpvar v); nlsat::anum_manager& am(); + scoped_anum& tmp1(); + + scoped_anum& tmp2(); + void updt_params(params_ref& p); /* diff --git a/src/math/lp/numeric_pair.h b/src/math/lp/numeric_pair.h index 25127400627..93780f7be39 100644 --- a/src/math/lp/numeric_pair.h +++ b/src/math/lp/numeric_pair.h @@ -24,7 +24,6 @@ #include #ifdef lp_for_z3 #include "util/rational.h" -#include "util/sstream.h" #include "util/z3_exception.h" #else // include "util/numerics/mpq.h" diff --git a/src/math/lp/random_updater.h b/src/math/lp/random_updater.h index d5cd4928caa..f0f76775079 100644 --- a/src/math/lp/random_updater.h +++ b/src/math/lp/random_updater.h @@ -25,14 +25,14 @@ Revision History: #include #include #include "math/lp/lp_settings.h" -#include "math/lp/u_set.h" +#include "util/uint_set.h" // see http://research.microsoft.com/projects/z3/smt07.pdf // The class searches for a feasible solution with as many different values of variables as it can find namespace lp { template struct numeric_pair; // forward definition class lar_solver; // forward definition class random_updater { - u_set m_var_set; + indexed_uint_set m_var_set; lar_solver & m_lar_solver; unsigned m_range; bool shift_var(unsigned j); diff --git a/src/math/lp/random_updater_def.h b/src/math/lp/random_updater_def.h index 7d167a4a03d..861068106c2 100644 --- a/src/math/lp/random_updater_def.h +++ b/src/math/lp/random_updater_def.h @@ -32,7 +32,6 @@ random_updater::random_updater( const vector & column_indices) : m_lar_solver(lar_solver), m_range(100000) { - m_var_set.resize(m_lar_solver.number_of_vars()); for (unsigned j : column_indices) m_var_set.insert(j); TRACE("lar_solver_rand", tout << "size = " << m_var_set.size() << "\n";); @@ -45,7 +44,9 @@ bool random_updater::shift_var(unsigned j) { if (ret) { const auto & A = m_lar_solver.A_r(); for (const auto& c : A.m_columns[j]) { - m_var_set.erase(m_lar_solver.r_basis()[c.var()]); + unsigned k = m_lar_solver.r_basis()[c.var()]; + if (m_var_set.contains(k)) + m_var_set.remove(k); } } return ret; @@ -54,7 +55,11 @@ bool random_updater::shift_var(unsigned j) { void random_updater::update() { // VERIFY(m_lar_solver.check_feasible()); - auto columns = m_var_set.index(); // m_var_set is going to change during the loop + unsigned_vector columns; + // m_var_set is going to change during the loop, make a copy + for (unsigned j : m_var_set) { + columns.push_back(j); + } for (auto j : columns) { if (!m_var_set.contains(j)) { TRACE("lar_solver_rand", tout << "skipped " << j << "\n";); diff --git a/src/math/lp/static_matrix.cpp b/src/math/lp/static_matrix.cpp index efb6e07cf75..a46b5abc003 100644 --- a/src/math/lp/static_matrix.cpp +++ b/src/math/lp/static_matrix.cpp @@ -29,7 +29,6 @@ namespace lp { template std::set> lp::static_matrix::get_domain(); template std::set> lp::static_matrix >::get_domain(); template void static_matrix::add_column_to_vector(mpq const&, unsigned int, mpq*) const; -template void static_matrix::add_columns_at_the_end(unsigned int); template bool static_matrix::is_correct() const; template mpq static_matrix::get_balance() const; diff --git a/src/math/lp/static_matrix.h b/src/math/lp/static_matrix.h index f79ff36ac38..9d6bb859964 100644 --- a/src/math/lp/static_matrix.h +++ b/src/math/lp/static_matrix.h @@ -126,7 +126,6 @@ class static_matrix unsigned lowest_row_in_column(unsigned col); - void add_columns_at_the_end(unsigned delta); void add_new_element(unsigned i, unsigned j, const T & v); void add_row() {m_rows.push_back(row_strip());} @@ -347,14 +346,14 @@ class static_matrix // we use the form -it + 1 = 0 m_work_vector.set_value(one_of_type(), bj); for (auto p : row) { - m_work_vector.set_value(-p.coeff(), p.column().index()); + m_work_vector.set_value(-p.coeff(), p.j()); // but take care of the basis 1 later } // now iterate with pivoting fill_last_row_with_pivoting_loop_block(bj, basis_heading); for (auto p : row) { - fill_last_row_with_pivoting_loop_block(p.column().index(), basis_heading); + fill_last_row_with_pivoting_loop_block(p.j(), basis_heading); } unsigned last_row = row_count() - 1; diff --git a/src/math/lp/static_matrix_def.h b/src/math/lp/static_matrix_def.h index 76c1dec546c..0370ee8990a 100644 --- a/src/math/lp/static_matrix_def.h +++ b/src/math/lp/static_matrix_def.h @@ -124,11 +124,6 @@ template unsigned static_matrix::lowest_row_in_co return ret; } -template void static_matrix::add_columns_at_the_end(unsigned delta) { - for (unsigned i = 0; i < delta; i++) - add_column(); -} - template void static_matrix::forget_last_columns(unsigned how_many_to_forget) { lp_assert(m_columns.size() >= how_many_to_forget); unsigned j = column_count() - 1; diff --git a/src/math/lp/u_set.h b/src/math/lp/u_set.h deleted file mode 100644 index ce59dccb724..00000000000 --- a/src/math/lp/u_set.h +++ /dev/null @@ -1,114 +0,0 @@ -/*++ -Copyright (c) 2017 Microsoft Corporation - -Module Name: - - - -Abstract: - - - -Author: - - Lev Nachmanson (levnach) - -Revision History: - -TBD use indexed_uint_set from src/util/uint_set.h, - ---*/ -#pragma once -#include "util/vector.h" -#include -namespace lp { -// serves at a set of non-negative integers smaller than the set size -class u_set { - svector m_data; - unsigned_vector m_index; - -public: - u_set(unsigned size): m_data(size, -1) {} - u_set() {} - u_set(u_set const& other): - m_data(other.m_data), - m_index(other.m_index) {} - - bool contains(unsigned j) const { - if (j >= m_data.size()) - return false; - return m_data[j] >= 0; - } - void insert(unsigned j) { - lp_assert(j < m_data.size()); - if (contains(j)) return; - m_data[j] = m_index.size(); - m_index.push_back(j); - } - void erase(unsigned j) { - if (!contains(j)) return; - unsigned pos_j = m_data[j]; - unsigned last_pos = m_index.size() - 1; - int last_j = m_index[last_pos]; - if (last_pos != pos_j) { - // move last to j spot - m_data[last_j] = pos_j; - m_index[pos_j] = last_j; - } - m_index.pop_back(); - m_data[j] = -1; - } - - int operator[](unsigned j) const { return m_index[j]; } - - void resize(unsigned size) { - if (size < data_size()) { - bool copy = false; - unsigned i = 0; - for (unsigned j : m_index) { - if (j < size) { - if (copy) { - m_data[j] = i; - m_index[i] = j; - } - i++; - } else { - copy = true; - } - } - m_index.shrink(i); - } - m_data.resize(size, -1); - } - - void increase_size_by_one() { - resize(m_data.size() + 1); - } - - unsigned data_size() const { return m_data.size(); } - unsigned size() const { return m_index.size();} - bool empty() const { return size() == 0; } - void clear() { - for (unsigned j : m_index) - m_data[j] = -1; - m_index.resize(0); - } - - std::ostream& display(std::ostream& out) const { - for (unsigned j : m_index) { - out << j << " "; - } - out << std::endl; - return out; - } - const unsigned * begin() const { return m_index.begin(); } - const unsigned * end() const { return m_index.end(); } - const unsigned_vector& index() { return m_index; } -}; - - -} - -inline std::ostream& operator<<(std::ostream& out, lp::u_set const& s) { - return s.display(out); -} diff --git a/src/math/lp/ul_pair.h b/src/math/lp/ul_pair.h deleted file mode 100644 index abfb4483be7..00000000000 --- a/src/math/lp/ul_pair.h +++ /dev/null @@ -1,82 +0,0 @@ -/*++ -Copyright (c) 2017 Microsoft Corporation - -Module Name: - - - -Abstract: - - - -Author: - - Lev Nachmanson (levnach) - -Revision History: - - ---*/ - -#pragma once -#include "util/vector.h" -#include -#include -#include -#include "math/lp/column_info.h" -#include "math/lp/lp_types.h" - -namespace lp { - - -inline bool kind_is_strict(lconstraint_kind kind) { return kind == LT || kind == GT;} - -inline std::ostream& operator<<(std::ostream& out, lconstraint_kind k) { - switch (k) { - case LE: return out << "<="; - case LT: return out << "<"; - case GE: return out << ">="; - case GT: return out << ">"; - case EQ: return out << "="; - case NE: return out << "!="; - } - return out << "??"; -} - -inline bool compare(const std::pair & a, const std::pair & b) { - return a.second < b.second; -} - -class ul_pair { - constraint_index m_lower_bound_witness; - constraint_index m_upper_bound_witness; - bool m_associated_with_row; -public: - constraint_index& lower_bound_witness() {return m_lower_bound_witness;} - constraint_index lower_bound_witness() const {return m_lower_bound_witness;} - constraint_index& upper_bound_witness() { return m_upper_bound_witness;} - constraint_index upper_bound_witness() const {return m_upper_bound_witness;} - bool operator!=(const ul_pair & p) const { - return !(*this == p); - } - - bool operator==(const ul_pair & p) const { - return m_lower_bound_witness == p.m_lower_bound_witness - && m_upper_bound_witness == p.m_upper_bound_witness && - m_associated_with_row == p.m_associated_with_row; - } - // empty constructor - ul_pair() : - m_lower_bound_witness(UINT_MAX), - m_upper_bound_witness(UINT_MAX), - m_associated_with_row(false) {} - - ul_pair(bool associated_with_row) : - m_lower_bound_witness(UINT_MAX), - m_upper_bound_witness(UINT_MAX), - m_associated_with_row(associated_with_row) {} - - bool associated_with_row() const { return m_associated_with_row; } -}; - -} diff --git a/src/math/lp/var_eqs.h b/src/math/lp/var_eqs.h index 998779dc61a..a639a9a5ee2 100644 --- a/src/math/lp/var_eqs.h +++ b/src/math/lp/var_eqs.h @@ -28,22 +28,24 @@ namespace nla { class eq_justification { - lpci m_cs[4]; + u_dependency* m_cs[4]; public: - eq_justification(std::initializer_list cs) { + eq_justification(std::initializer_list cs) { int i = 0; - for (lpci c: cs) { + for (auto c: cs) { m_cs[i++] = c; } for (; i < 4; i++) { - m_cs[i] = -1; + m_cs[i] = nullptr; } } - void explain(lp::explanation& e) const { - for (lpci c : m_cs) - if (c + 1 != 0) // c != -1 - e.push_back(c); + u_dependency* const* begin() const { return m_cs; } + u_dependency* const* end() const { + unsigned i = 0; + for (; i < 4 && m_cs[i]; ++i) + ; + return m_cs + i; } }; @@ -202,7 +204,7 @@ class var_eqs { } for (eq_justification const& j : m_justtrail) { - j.explain(e); + explain_eq(j, e); } m_stats.m_num_explains += m_justtrail.size(); m_stats.m_num_explain_calls++; @@ -216,6 +218,17 @@ class var_eqs { // IF_VERBOSE(2, verbose_stream() << (double)m_stats.m_num_explains / m_stats.m_num_explain_calls << "\n"); } + void explain_eq(eq_justification const& eq, lp::explanation& e) const { + u_dependency_manager dm; + unsigned_vector deps; + for (auto* dep : eq) { + deps.reset(); + dm.linearize(dep, deps); + for (auto ci : deps) + e.push_back(ci); + } + } + void explain_bfs(signed_var v1, signed_var v2, lp::explanation& e) const { SASSERT(find(v1) == find(v2)); if (v1 == v2) { @@ -249,7 +262,7 @@ class var_eqs { } while (head != 0) { - m_justtrail[head].explain(e); + explain_eq(m_justtrail[head], e); head = m_todo[head].m_index; ++m_stats.m_num_explains; } diff --git a/src/math/lp/var_register.h b/src/math/lp/var_register.h index 49767274de0..bd7e6efe830 100644 --- a/src/math/lp/var_register.h +++ b/src/math/lp/var_register.h @@ -41,10 +41,7 @@ class ext_var_info { class var_register { vector m_local_to_external; std::unordered_map m_external_to_local; - unsigned m_locals_mask; - unsigned m_locals_mask_inverted; public: - var_register(bool mask_locals): m_locals_mask(mask_locals? tv::left_most_bit: 0), m_locals_mask_inverted(~m_locals_mask) {} void set_name(unsigned j, std::string name) { m_local_to_external[j].set_name(name); @@ -63,7 +60,7 @@ class var_register { } m_local_to_external.push_back(ext_var_info(user_var, is_int)); - unsigned local = ( size() - 1 ) | m_locals_mask; + unsigned local = size() - 1; if (user_var != UINT_MAX) m_external_to_local[user_var] = local; @@ -72,15 +69,14 @@ class var_register { svector vars() const { svector ret; - for (const auto& p : m_local_to_external) { + for (const auto& p : m_local_to_external) ret.push_back(p.external_j()); - } return ret; } // returns UINT_MAX if unsigned local_to_external(unsigned local_var) const { - unsigned k = local_var & m_locals_mask_inverted; + unsigned k = local_var; if (k >= m_local_to_external.size()) return UINT_MAX; return m_local_to_external[k].external_j(); @@ -120,27 +116,22 @@ class var_register { local_j = UINT_MAX; return false; } - local_j = it->second & m_locals_mask_inverted; + local_j = it->second; is_int = m_local_to_external[local_j].is_integer(); return true; } bool has_int_var() const { - for (const auto & vi : m_local_to_external) { - if (vi.is_integer()) - return true; - } - return false; + return any_of(m_local_to_external, [&](const auto& vi) { return vi.is_integer(); }); } bool local_is_int(unsigned j) const { - return m_local_to_external[j & m_locals_mask_inverted].is_integer(); + return m_local_to_external[j].is_integer(); } void shrink(unsigned shrunk_size) { - for (unsigned j = size(); j-- > shrunk_size;) { + for (unsigned j = size(); j-- > shrunk_size;) m_external_to_local.erase(m_local_to_external[j].external_j()); - } m_local_to_external.resize(shrunk_size); } diff --git a/src/math/polynomial/algebraic_numbers.cpp b/src/math/polynomial/algebraic_numbers.cpp index 1a4769ac334..06b4465d77e 100644 --- a/src/math/polynomial/algebraic_numbers.cpp +++ b/src/math/polynomial/algebraic_numbers.cpp @@ -180,7 +180,7 @@ namespace algebraic_numbers { return m_upmanager; } - void del(basic_cell * c) { + void del_basic(basic_cell * c) { qm().del(c->m_value); m_allocator.deallocate(sizeof(basic_cell), c); } @@ -204,13 +204,13 @@ namespace algebraic_numbers { } void del(numeral & a) { - if (a.m_cell == nullptr) + if (a.is_null()) return; if (a.is_basic()) - del(a.to_basic()); + del_basic(a.to_basic()); else del(a.to_algebraic()); - a.m_cell = nullptr; + a.clear(); } void reset(numeral & a) { @@ -218,7 +218,7 @@ namespace algebraic_numbers { } bool is_zero(numeral const & a) { - return a.m_cell == nullptr; + return a.is_null(); } bool is_pos(numeral const & a) { @@ -358,8 +358,8 @@ namespace algebraic_numbers { return a.to_algebraic()->m_p_sz - 1; } - void swap(numeral & a, numeral & b) { - std::swap(a.m_cell, b.m_cell); + void swap(numeral & a, numeral & b) noexcept { + a.swap(b); } basic_cell * mk_basic_cell(mpq & n) { @@ -432,13 +432,13 @@ namespace algebraic_numbers { } if (a.is_basic()) { if (is_zero(a)) - a.m_cell = mk_basic_cell(n); + a = mk_basic_cell(n); else qm().set(a.to_basic()->m_value, n); } else { del(a); - a.m_cell = mk_basic_cell(n); + a = mk_basic_cell(n); } } @@ -492,7 +492,7 @@ namespace algebraic_numbers { else { if (a.is_basic()) { del(a); - a.m_cell = TAG(void*, mk_algebraic_cell(sz, p, lower, upper, minimal), ROOT); + a = mk_algebraic_cell(sz, p, lower, upper, minimal); } else { SASSERT(sz > 2); @@ -526,7 +526,7 @@ namespace algebraic_numbers { del(a); void * mem = m_allocator.allocate(sizeof(algebraic_cell)); algebraic_cell * c = new (mem) algebraic_cell(); - a.m_cell = TAG(void *, c, ROOT); + a = c; copy(c, b.to_algebraic()); SASSERT(acell_inv(*c)); } @@ -795,8 +795,8 @@ namespace algebraic_numbers { // root was found scoped_mpq r(qm()); to_mpq(qm(), lower(c), r); - del(c); - a.m_cell = mk_basic_cell(r); + del(a); + a = mk_basic_cell(r); return false; } } @@ -816,8 +816,8 @@ namespace algebraic_numbers { // actual root was found scoped_mpq r(qm()); to_mpq(qm(), lower(c), r); - del(c); - a.m_cell = mk_basic_cell(r); + del(a); + a = mk_basic_cell(r); return false; } SASSERT(acell_inv(*c)); @@ -2935,7 +2935,7 @@ namespace algebraic_numbers { return m_imp->to_rational(const_cast(a), r); } - void manager::swap(numeral & a, numeral & b) { + void manager::swap(numeral & a, numeral & b) noexcept { return m_imp->swap(a, b); } diff --git a/src/math/polynomial/algebraic_numbers.h b/src/math/polynomial/algebraic_numbers.h index 00af9a7a8a6..e2e95367c8a 100644 --- a/src/math/polynomial/algebraic_numbers.h +++ b/src/math/polynomial/algebraic_numbers.h @@ -129,7 +129,7 @@ namespace algebraic_numbers { void set(numeral & a, mpq const & n); void set(numeral & a, numeral const & n); - void swap(numeral & a, numeral & b); + void swap(numeral & a, numeral & b) noexcept; /** \brief Store in b an integer value smaller than 'a'. @@ -360,19 +360,25 @@ namespace algebraic_numbers { struct basic_cell; struct algebraic_cell; - enum anum_kind { BASIC = 0, ROOT }; + + + class anum { + enum anum_kind { BASIC = 0, ROOT }; + void* m_cell; + public: + anum() :m_cell(nullptr) {} + anum(basic_cell* cell) :m_cell(TAG(void*, cell, BASIC)) { } + anum(algebraic_cell * cell):m_cell(TAG(void*, cell, ROOT)) { } - class anum { - friend struct manager::imp; - friend class manager; - void * m_cell; - anum(basic_cell * cell):m_cell(TAG(void*, cell, BASIC)) {} - anum(algebraic_cell * cell):m_cell(TAG(void*, cell, ROOT)) {} bool is_basic() const { return GET_TAG(m_cell) == BASIC; } basic_cell * to_basic() const { SASSERT(is_basic()); return UNTAG(basic_cell*, m_cell); } algebraic_cell * to_algebraic() const { SASSERT(!is_basic()); return UNTAG(algebraic_cell*, m_cell); } - public: - anum():m_cell(nullptr) {} + + bool is_null() const { return m_cell == nullptr; } + void clear() { m_cell = nullptr; } + void swap(anum & other) { std::swap(m_cell, other.m_cell); } + anum& operator=(basic_cell* cell) { SASSERT(is_null()); m_cell = TAG(void*, cell, BASIC); return *this; } + anum& operator=(algebraic_cell* cell) { SASSERT(is_null()); m_cell = TAG(void*, cell, ROOT); return *this; } }; }; @@ -428,6 +434,7 @@ AN_MK_BINARY(operator/, div) #undef AN_MK_BINARY #undef AN_MK_BINARY_CORE + inline scoped_anum root(scoped_anum const & a, unsigned k) { scoped_anum r(a.m()); a.m().root(a, k, r); diff --git a/src/math/polynomial/algebraic_params.pyg b/src/math/polynomial/algebraic_params.pyg index 77354849450..0e53a40b836 100644 --- a/src/math/polynomial/algebraic_params.pyg +++ b/src/math/polynomial/algebraic_params.pyg @@ -5,6 +5,6 @@ def_module_params('algebraic', ('min_mag', UINT, 16, 'Z3 represents algebraic numbers using a (square-free) polynomial p and an isolating interval (which contains one and only one root of p). This interval may be refined during the computations. This parameter specifies whether to cache the value of a refined interval or not. It says the minimal size of an interval for caching purposes is 1/2^16'), ('factor', BOOL, True, 'use polynomial factorization to simplify polynomials representing algebraic numbers'), ('factor_max_prime', UINT, 31, 'parameter for the polynomial factorization procedure in the algebraic number module. Z3 polynomial factorization is composed of three steps: factorization in GF(p), lifting and search. This parameter limits the maximum prime number p to be used in the first step'), - ('factor_num_primes', UINT, 1, 'parameter for the polynomial factorization procedure in the algebraic number module. Z3 polynomial factorization is composed of three steps: factorization in GF(p), lifting and search. The search space may be reduced by factoring the polynomial in different GF(p)\'s. This parameter specify the maximum number of finite factorizations to be considered, before lifiting and searching'), + ('factor_num_primes', UINT, 1, 'parameter for the polynomial factorization procedure in the algebraic number module. Z3 polynomial factorization is composed of three steps: factorization in GF(p), lifting and search. The search space may be reduced by factoring the polynomial in different GF(p)\'s. This parameter specify the maximum number of finite factorizations to be considered, before lifting and searching'), ('factor_search_size', UINT, 5000, 'parameter for the polynomial factorization procedure in the algebraic number module. Z3 polynomial factorization is composed of three steps: factorization in GF(p), lifting and search. This parameter can be used to limit the search space'))) diff --git a/src/math/polynomial/polynomial.cpp b/src/math/polynomial/polynomial.cpp index b85ac1cf592..da6bc7b3925 100644 --- a/src/math/polynomial/polynomial.cpp +++ b/src/math/polynomial/polynomial.cpp @@ -447,7 +447,7 @@ namespace polynomial { } }; - inline void swap(monomial * & m1, monomial * & m2) { std::swap(m1, m2); } + inline void swap(monomial * & m1, monomial * & m2) noexcept { std::swap(m1, m2); } typedef chashtable monomial_table; @@ -1285,10 +1285,7 @@ namespace polynomial { } }); monomial_table new_table; - monomial_table::iterator it = m_monomials.begin(); - monomial_table::iterator end = m_monomials.end(); - for (; it != end; ++it) { - monomial * m = *it; + for (monomial * m : m_monomials) { m->rename(sz, xs); SASSERT(!new_table.contains(m)); new_table.insert(m); @@ -1586,9 +1583,7 @@ namespace polynomial { m_i->display_smt2(out, proc); } else { - out << "(* "; m_i->display_smt2(out, proc); - out << ")"; } } else { @@ -3629,6 +3624,7 @@ namespace polynomial { unsigned counter = 0; while (true) { + (void)counter; SASSERT(degree(pp_u, x) >= degree(pp_v, x)); unsigned delta = degree(pp_u, x) - degree(pp_v, x); TRACE("polynomial_gcd_detail", @@ -4174,6 +4170,7 @@ namespace polynomial { unsigned counter = 0; for (;; counter++) { + (void) counter; while (true) { peek_fresh(interpolator.inputs(), p, val); // the selected value must satisfy lc_g(val) != 0 @@ -5614,90 +5611,6 @@ namespace polynomial { } } - void psc_chain1(polynomial const * p, polynomial const * q, var x, polynomial_ref_vector & S) { - subresultant_chain(p, q, x, S); - unsigned sz = S.size(); - TRACE("psc", tout << "subresultant_chain\n"; - for (unsigned i = 0; i < sz; i++) { tout << "i: " << i << " "; S.get(i)->display(tout, m_manager); tout << "\n"; }); - for (unsigned i = 0; i < sz - 1; i++) { - S.set(i, coeff(S.get(i), x, i)); - } - S.set(sz-1, mk_one()); - } - - // Store in S a list of the non-zero principal subresultant coefficients of A and B - // If i < j then psc_{i}(A,B) precedes psc_{j}(A,B) in S. - // The leading coefficients of A and B are not included in S. - void psc_chain2(polynomial const * A, polynomial const * B, var x, polynomial_ref_vector & S) { - polynomial_ref G1(pm()); - polynomial_ref G2(pm()); - polynomial_ref G3(pm()); - polynomial_ref Gh3(pm()); - polynomial_ref g1(pm()), h0(pm()), hs0(pm()), h1(pm()), hs1(pm()); - unsigned n1 = degree(A, x); - unsigned n2 = degree(B, x); - if (n1 > n2) { - G1 = const_cast(A); - G2 = const_cast(B); - } - else { - G1 = const_cast(B); - G2 = const_cast(A); - std::swap(n1, n2); - } - unsigned d0 = 0; - unsigned d1 = n1 - n2; - unsigned i = 1; - unsigned n3; - S.reset(); - while (true) { - // Compute Gh_{i+2} - if (!is_zero(G2)) { - exact_pseudo_remainder(G1, G2, x, Gh3); - n3 = degree(Gh3, x); - if (!is_zero(Gh3) && d1%2 == 0) - Gh3 = neg(Gh3); - } - - // Compute hi - if (i > 1) { - g1 = lc(G1, x); - pw(g1, d0, h1); - if (i > 2) { - pw(h0, d0 - 1, hs0); - h1 = exact_div(h1, hs0); - S.push_back(h1); - if (is_zero(G2)) { - std::reverse(S.data(), S.data() + S.size()); - return; - } - } - } - - // Compute G_{i+2} - if (i == 1 || is_zero(Gh3)) { - G3 = Gh3; - } - else { - pw(h1, d1, hs1); - hs1 = mul(g1, hs1); - G3 = exact_div(Gh3, hs1); - hs1 = nullptr; - } - - // prepare for next iteration - n1 = n2; - n2 = n3; - d0 = d1; - d1 = n1 - n2; - G1 = G2; - G2 = G3; - if (i > 1) - h0 = h1; - i = i + 1; - } - } - // Optimized calculation of S_e using "Dichotomous Lazard" void Se_Lazard(unsigned d, polynomial const * lc_S_d, polynomial const * S_d_1, var x, polynomial_ref & S_e) { unsigned n = d - degree(S_d_1, x) - 1; @@ -5858,90 +5771,7 @@ namespace polynomial { std::reverse(S.data(), S.data() + S.size()); } - void psc_chain_classic_core(polynomial const * P, polynomial const * Q, var x, polynomial_ref_vector & S) { - TRACE("psc_chain_classic", tout << "P: "; P->display(tout, m_manager); tout << "\nQ: "; Q->display(tout, m_manager); tout << "\n";); - unsigned degP = degree(P, x); - unsigned degQ = degree(Q, x); - SASSERT(degP >= degQ); - polynomial_ref A(pm()), B(pm()), C(pm()), minus_Q(pm()), lc_Q(pm()), lc_B(pm()), lc_A(pm()); - polynomial_ref tmp1(pm()), tmp2(pm()), s_delta(pm()), minus_B(pm()), ps(pm()); - - lc_Q = lc(Q, x); - polynomial_ref s(pm()); - // s <- lc(Q)^(deg(P)-deg(Q)) - pw(lc_Q, degP - degQ, s); - minus_Q = neg(Q); - // A <- Q - A = const_cast(Q); - // B <- prem(P, -Q) - exact_pseudo_remainder(P, minus_Q, x, B); - while (true) { - unsigned d = degree(A, x); - unsigned e = degree(B, x); - if (is_zero(B)) - return; - TRACE("psc_chain_classic", tout << "A: " << A << "\nB: " << B << "\ns: " << s << "\nd: " << d << ", e: " << e << "\n";); - // B is S_{d-1} - ps = coeff(B, x, d-1); - if (!is_zero(ps)) - S.push_back(ps); - unsigned delta = d - e; - if (delta > 1) { - // C <- S_e - // Standard S_e calculation - // C <- (lc(B)^(delta-1) B) / s^(delta-1) - lc_B = lc(B, x); - pw(lc_B, delta-1, lc_B); - lc_B = mul(lc_B, B); - pw(s, delta - 1, s_delta); // s_delta <- s^(delta-1) - C = exact_div(lc_B, s_delta); - - // s_delta <- s^delta - s_delta = mul(s_delta, s); - // C is S_e - ps = coeff(C, x, e); - if (!is_zero(ps)) - S.push_back(ps); - - } - else { - SASSERT(delta == 0 || delta == 1); - C = B; - // s_delta <- s^delta - pw(s, delta, s_delta); - } - if (e == 0) - return; - // B <- prem(A, -B)/(s^delta * lc(A) - lc_A = lc(A, x); - minus_B = neg(B); - exact_pseudo_remainder(A, minus_B, x, tmp1); - tmp2 = mul(lc_A, s_delta); - B = exact_div(tmp1, tmp2); - // A <- C - A = C; - // s <- lc(A) - s = lc(A, x); - } - } - - void psc_chain_classic(polynomial const * P, polynomial const * Q, var x, polynomial_ref_vector & S) { - SASSERT(degree(P, x) > 0); - SASSERT(degree(Q, x) > 0); - S.reset(); - if (degree(P, x) >= degree(Q, x)) - psc_chain_classic_core(P, Q, x, S); - else - psc_chain_classic_core(Q, P, x, S); - if (S.empty()) - S.push_back(mk_zero()); - std::reverse(S.data(), S.data() + S.size()); - } - void psc_chain(polynomial const * A, polynomial const * B, var x, polynomial_ref_vector & S) { - // psc_chain1(A, B, x, S); - //psc_chain2(A, B, x, S); - //psc_chain_classic(A, B, x, S); psc_chain_optimized(A, B, x, S); } diff --git a/src/math/polynomial/rpolynomial.cpp b/src/math/polynomial/rpolynomial.cpp index 0afd22a7566..dd1b82994da 100644 --- a/src/math/polynomial/rpolynomial.cpp +++ b/src/math/polynomial/rpolynomial.cpp @@ -658,9 +658,7 @@ namespace rpolynomial { void display(std::ostream & out, polynomial const * p, display_var_proc const & proc, bool use_star) { var x = p->max_var(); bool first = true; - unsigned i = p->size(); - while (i > 0) { - --i; + for (unsigned i = p->size(); i-- > 0; ) { poly_or_num * pn = p->arg(i); if (pn == nullptr) continue; @@ -697,23 +695,19 @@ namespace rpolynomial { display(out, to_poly(pn), proc, use_star); } else { - bool add_paren = false; - if (i > 0) - add_paren = !is_monomial(to_poly(pn)); + bool add_paren = !is_monomial(to_poly(pn)); if (add_paren) out << "("; display(out, to_poly(pn), proc, use_star); if (add_paren) out << ")"; - if (i > 0) { - if (use_star) - out << "*"; - else - out << " "; - proc(out, x); - if (i > 1) - out << "^" << i; - } + if (use_star) + out << "*"; + else + out << " "; + proc(out, x); + if (i > 1) + out << "^" << i; } } } diff --git a/src/math/polynomial/upolynomial.cpp b/src/math/polynomial/upolynomial.cpp index ac0aec8ef5f..a354ed2951b 100644 --- a/src/math/polynomial/upolynomial.cpp +++ b/src/math/polynomial/upolynomial.cpp @@ -126,7 +126,7 @@ namespace upolynomial { m_factors[i].swap(p); } - void core_manager::factors::swap(factors & other) { + void core_manager::factors::swap(factors & other) noexcept { m_factors.swap(other.m_factors); m_degrees.swap(other.m_degrees); nm().swap(m_constant, other.m_constant); diff --git a/src/math/polynomial/upolynomial.h b/src/math/polynomial/upolynomial.h index faa0e5e0983..2afdbb7b3d1 100644 --- a/src/math/polynomial/upolynomial.h +++ b/src/math/polynomial/upolynomial.h @@ -89,7 +89,7 @@ namespace upolynomial { void push_back_swap(numeral_vector & p, unsigned degree); void swap_factor(unsigned i, numeral_vector & p); - void swap(factors & other); + void swap(factors & other) noexcept; void multiply(numeral_vector & out) const; void display(std::ostream & out) const; diff --git a/src/math/polynomial/upolynomial_factorization_int.h b/src/math/polynomial/upolynomial_factorization_int.h index e66fd2f1ba1..a65e5ee62f7 100644 --- a/src/math/polynomial/upolynomial_factorization_int.h +++ b/src/math/polynomial/upolynomial_factorization_int.h @@ -85,7 +85,7 @@ namespace upolynomial { unsigned max_degree() const { return m_set.size() - 1; } - void swap(factorization_degree_set & other) { + void swap(factorization_degree_set & other) noexcept { m_set.swap(other.m_set); } diff --git a/src/math/realclosure/mpz_matrix.h b/src/math/realclosure/mpz_matrix.h index 878fb72f912..91fe226816b 100644 --- a/src/math/realclosure/mpz_matrix.h +++ b/src/math/realclosure/mpz_matrix.h @@ -50,7 +50,7 @@ class mpz_matrix { SASSERT(j < n); return a_ij[i*n + j]; } mpz & operator()(unsigned i, unsigned j) { SASSERT(i < m); SASSERT(j < n); return a_ij[i*n + j]; } - void swap(mpz_matrix & B) { std::swap(m, B.m); std::swap(n, B.n); std::swap(a_ij, B.a_ij); } + void swap(mpz_matrix & B) noexcept { std::swap(m, B.m); std::swap(n, B.n); std::swap(a_ij, B.a_ij); } mpz * row(unsigned i) const { SASSERT(i < m); return a_ij + i*n; } }; @@ -136,7 +136,7 @@ class scoped_mpz_matrix { mpz_matrix const & get() const { return A; } mpz_matrix & get() { return A; } - void swap(mpz_matrix & B) { A.swap(B); } + void swap(mpz_matrix & B) noexcept { A.swap(B); } void set(unsigned i, unsigned j, mpz const & v) { nm().set(A(i, j), v); } void set(unsigned i, unsigned j, int v) { nm().set(A(i, j), v); } diff --git a/src/math/realclosure/realclosure.cpp b/src/math/realclosure/realclosure.cpp index 68ed35b5d33..ecea560a515 100644 --- a/src/math/realclosure/realclosure.cpp +++ b/src/math/realclosure/realclosure.cpp @@ -135,7 +135,7 @@ namespace realclosure { typedef interval_manager mpbqi_manager; typedef mpbqi_manager::interval mpbqi; - void swap(mpbqi & a, mpbqi & b) { + void swap(mpbqi & a, mpbqi & b) noexcept { swap(a.m_lower, b.m_lower); swap(a.m_upper, b.m_upper); std::swap(a.m_lower_inf, b.m_lower_inf); @@ -2498,6 +2498,35 @@ namespace realclosure { } } + /** + \brief Return true if a is an algebraic number. + */ + bool is_algebraic(numeral const & a) { + return is_rational_function(a) && to_rational_function(a)->ext()->is_algebraic(); + } + + /** + \brief Return true if a represents an infinitesimal. + */ + bool is_infinitesimal(numeral const & a) { + return is_rational_function(a) && to_rational_function(a)->ext()->is_infinitesimal(); + } + + /** + \brief Return true if a is a transcendental. + */ + bool is_transcendental(numeral const & a) { + return is_rational_function(a) && to_rational_function(a)->ext()->is_transcendental(); + } + + /** + \brief Return true if a is a rational. + */ + bool is_rational(numeral const & a) { + return a.m_value->is_rational(); + } + + /** \brief Return true if a depends on infinitesimal extensions. */ @@ -2505,7 +2534,7 @@ namespace realclosure { return depends_on_infinitesimals(a.m_value); } - static void swap(mpbqi & a, mpbqi & b) { + static void swap(mpbqi & a, mpbqi & b) noexcept { realclosure::swap(a, b); } @@ -3330,6 +3359,151 @@ namespace realclosure { set(q, _q); } + unsigned extension_index(numeral const & a) { + if (!is_rational_function(a)) + return -1; + return to_rational_function(a)->ext()->idx(); + } + + symbol transcendental_name(numeral const & a) { + if (!is_transcendental(a)) + return symbol(); + return to_transcendental(to_rational_function(a)->ext())->m_name; + } + + symbol infinitesimal_name(numeral const & a) { + if (!is_infinitesimal(a)) + return symbol(); + return to_infinitesimal(to_rational_function(a)->ext())->m_name; + } + + unsigned num_coefficients(numeral const & a) { + if (!is_algebraic(a)) + return 0; + return to_algebraic(to_rational_function(a)->ext())->p().size(); + } + + numeral get_coefficient(numeral const & a, unsigned i) + { + if (!is_algebraic(a)) + return numeral(); + algebraic * ext = to_algebraic(to_rational_function(a)->ext()); + if (i >= ext->p().size()) + return numeral(); + value_ref v(*this); + v = ext->p()[i]; + numeral r; + set(r, v); + return r; + } + + unsigned num_sign_conditions(numeral const & a) { + unsigned r = 0; + if (is_algebraic(a)) { + algebraic * ext = to_algebraic(to_rational_function(a)->ext()); + const sign_det * sdt = ext->sdt(); + if (sdt) { + sign_condition * sc = sdt->sc(ext->sc_idx()); + while (sc) { + r++; + sc = sc->prev(); + } + } + } + return r; + } + + int get_sign_condition_sign(numeral const & a, unsigned i) + { + if (!is_algebraic(a)) + return 0; + algebraic * ext = to_algebraic(to_rational_function(a)->ext()); + const sign_det * sdt = ext->sdt(); + if (!sdt) + return 0; + else { + sign_condition * sc = sdt->sc(ext->sc_idx()); + while (i) { + if (sc) sc = sc->prev(); + i--; + } + return sc ? sc->sign() : 0; + } + } + + bool get_interval(numeral const & a, int & lower_is_inf, int & lower_is_open, numeral & lower, int & upper_is_inf, int & upper_is_open, numeral & upper) + { + if (!is_algebraic(a)) + return false; + lower = numeral(); + upper = numeral(); + algebraic * ext = to_algebraic(to_rational_function(a)->ext()); + mpbqi &ivl = ext->iso_interval(); + lower_is_inf = ivl.lower_is_inf(); + lower_is_open = ivl.lower_is_open(); + if (!m_bqm.is_zero(ivl.lower())) + set(lower, mk_rational(ivl.lower())); + upper_is_inf = ivl.upper_is_inf(); + upper_is_open = ivl.upper_is_open(); + if (!m_bqm.is_zero(ivl.upper())) + set(upper, mk_rational(ivl.upper())); + return true; + } + + unsigned get_sign_condition_size(numeral const &a, unsigned i) { + algebraic * ext = to_algebraic(to_rational_function(a)->ext()); + const sign_det * sdt = ext->sdt(); + if (!sdt) + return 0; + sign_condition * sc = sdt->sc(ext->sc_idx()); + while (i) { + if (sc) sc = sc->prev(); + i--; + } + return ext->sdt()->qs()[sc->qidx()].size(); + } + + int num_sign_condition_coefficients(numeral const &a, unsigned i) + { + if (!is_algebraic(a)) + return 0; + algebraic * ext = to_algebraic(to_rational_function(a)->ext()); + const sign_det * sdt = ext->sdt(); + if (!sdt) + return 0; + sign_condition * sc = sdt->sc(ext->sc_idx()); + while (i) { + if (sc) sc = sc->prev(); + i--; + } + const polynomial & q = ext->sdt()->qs()[sc->qidx()]; + return q.size(); + } + + numeral get_sign_condition_coefficient(numeral const &a, unsigned i, unsigned j) + { + if (!is_algebraic(a)) + return numeral(); + algebraic * ext = to_algebraic(to_rational_function(a)->ext()); + const sign_det * sdt = ext->sdt(); + if (!sdt) + return numeral(); + sign_condition * sc = sdt->sc(ext->sc_idx()); + while (i) { + if (sc) sc = sc->prev(); + i--; + } + const polynomial & q = ext->sdt()->qs()[sc->qidx()]; + if (j >= q.size()) + return numeral(); + value_ref v(*this); + v = q[j]; + numeral r; + set(r, v); + return r; + } + + // --------------------------------- // // GCD of integer coefficients @@ -6103,6 +6277,22 @@ namespace realclosure { return m_imp->is_int(a); } + bool manager::is_rational(numeral const & a) { + return m_imp->is_rational(a); + } + + bool manager::is_algebraic(numeral const & a) { + return m_imp->is_algebraic(a); + } + + bool manager::is_infinitesimal(numeral const & a) { + return m_imp->is_infinitesimal(a); + } + + bool manager::is_transcendental(numeral const & a) { + return m_imp->is_transcendental(a); + } + bool manager::depends_on_infinitesimals(numeral const & a) { return m_imp->depends_on_infinitesimals(a); } @@ -6123,7 +6313,7 @@ namespace realclosure { m_imp->set(a, n); } - void manager::swap(numeral & a, numeral & b) { + void manager::swap(numeral & a, numeral & b) noexcept { std::swap(a.m_value, b.m_value); } @@ -6251,6 +6441,56 @@ namespace realclosure { save_interval_ctx ctx(this); m_imp->clean_denominators(a, p, q); } + + unsigned manager::extension_index(numeral const & a) + { + return m_imp->extension_index(a); + } + + symbol manager::transcendental_name(numeral const &a) + { + return m_imp->transcendental_name(a); + } + + symbol manager::infinitesimal_name(numeral const &a) + { + return m_imp->infinitesimal_name(a); + } + + unsigned manager::num_coefficients(numeral const &a) + { + return m_imp->num_coefficients(a); + } + + manager::numeral manager::get_coefficient(numeral const &a, unsigned i) + { + return m_imp->get_coefficient(a, i); + } + + unsigned manager::num_sign_conditions(numeral const &a) + { + return m_imp->num_sign_conditions(a); + } + + int manager::get_sign_condition_sign(numeral const &a, unsigned i) + { + return m_imp->get_sign_condition_sign(a, i); + } + + bool manager::get_interval(numeral const & a, int & lower_is_inf, int & lower_is_open, numeral & lower, int & upper_is_inf, int & upper_is_open, numeral & upper) + { + return m_imp->get_interval(a, lower_is_inf, lower_is_open, lower, upper_is_inf, upper_is_open, upper); + } + + unsigned manager::num_sign_condition_coefficients(numeral const &a, unsigned i) + { + return m_imp->num_sign_condition_coefficients(a, i); + } + + manager::numeral manager::get_sign_condition_coefficient(numeral const &a, unsigned i, unsigned j) + { + return m_imp->get_sign_condition_coefficient(a, i, j); + } }; void pp(realclosure::manager::imp * imp, realclosure::polynomial const & p, realclosure::extension * ext) { diff --git a/src/math/realclosure/realclosure.h b/src/math/realclosure/realclosure.h index 788db4bbff6..12247627bb5 100644 --- a/src/math/realclosure/realclosure.h +++ b/src/math/realclosure/realclosure.h @@ -70,14 +70,14 @@ namespace realclosure { */ void mk_infinitesimal(char const * name, char const * pp_name, numeral & r); void mk_infinitesimal(numeral & r); - + /** - \brief Add a new transcendental real value to the field. + \brief Add a new transcendental real value to the field. The functor \c mk_interval is used to compute approximations of the transcendental value. This procedure should be used with care, if the value is not really transcendental with respect to the current field, computations with the new numeral may not terminate. Example: we extended the field with Pi. Pi is transcendental with respect to a field that contains only algebraic real numbers. - So, this step is fine. Let us call the resultant field F. + So, this step is fine. Let us call the resultant field F. Then, we extend the field F with 1 - Pi. 1 - Pi is transcendental with respect to algebraic real numbers, but it is NOT transcendental with respect to F, since F contains Pi. */ @@ -109,12 +109,12 @@ namespace realclosure { \brief Return the sign of a. */ int sign(numeral const & a); - + /** \brief Return true if a is zero. */ bool is_zero(numeral const & a); - + /** \brief Return true if a is positive. */ @@ -129,13 +129,33 @@ namespace realclosure { \brief Return true if a is an integer. */ bool is_int(numeral const & a); - + + /** + \brief Return true if a is a rational. + */ + bool is_rational(numeral const & a); + + /** + \brief Return true if a is an algebraic number. + */ + bool is_algebraic(numeral const & a); + + /** + \brief Return true if a represents an infinitesimal. + */ + bool is_infinitesimal(numeral const & a); + + /** + \brief Return true if a is a transcendental. + */ + bool is_transcendental(numeral const & a); + /** \brief Return true if the representation of \c a depends on infinitesimal extensions. */ bool depends_on_infinitesimals(numeral const & a); - + /** \brief a <- n */ @@ -144,18 +164,18 @@ namespace realclosure { void set(numeral & a, mpq const & n); void set(numeral & a, numeral const & n); - void swap(numeral & a, numeral & b); + void swap(numeral & a, numeral & b) noexcept; /** \brief Return a^{1/k} - + Throws an exception if (a is negative and k is even) or (k is zero). - */ + */ void root(numeral const & a, unsigned k, numeral & b); - + /** \brief Return a^k - + Throws an exception if 0^0. */ void power(numeral const & a, unsigned k, numeral & b); @@ -180,7 +200,7 @@ namespace realclosure { \brief a <- -a */ void neg(numeral & a); - + /** \brief b <- -a */ @@ -190,7 +210,7 @@ namespace realclosure { \brief a <- 1/a if a != 0 */ void inv(numeral & a); - + /** \brief b <- 1/a if a != 0 */ @@ -207,7 +227,7 @@ namespace realclosure { Return 1 if a > b */ int compare(numeral const & a, numeral const & b); - + /** \brief a == b */ @@ -249,7 +269,7 @@ namespace realclosure { bool ge(numeral const & a, numeral const & b) { return !lt(a, b); } bool ge(numeral const & a, mpq const & b) { return !lt(a, b); } bool ge(numeral const & a, mpz const & b) { return !lt(a, b); } - + void display(std::ostream & out, numeral const & a, bool compact=false, bool pp=false) const; /** @@ -259,10 +279,30 @@ namespace realclosure { */ void display_decimal(std::ostream & out, numeral const & a, unsigned precision = 10) const; - + void display_interval(std::ostream & out, numeral const & a) const; - + void clean_denominators(numeral const & a, numeral & p, numeral & q); + + unsigned extension_index(numeral const & a); + + symbol transcendental_name(numeral const &a); + + symbol infinitesimal_name(numeral const &a); + + unsigned num_coefficients(numeral const &a); + + numeral get_coefficient(numeral const &a, unsigned i); + + unsigned num_sign_conditions(numeral const &a); + + int get_sign_condition_sign(numeral const &a, unsigned i); + + bool get_interval(numeral const & a, int & lower_is_inf, int & lower_is_open, numeral & lower, int & upper_is_inf, int & upper_is_open, numeral & upper); + + unsigned num_sign_condition_coefficients(numeral const &a, unsigned i); + + numeral get_sign_condition_coefficient(numeral const &a, unsigned i, unsigned j); }; struct value; diff --git a/src/math/simplex/model_based_opt.cpp b/src/math/simplex/model_based_opt.cpp index 3c38cfb0e7c..e66cab31084 100644 --- a/src/math/simplex/model_based_opt.cpp +++ b/src/math/simplex/model_based_opt.cpp @@ -1,1748 +1,1748 @@ -/*++ -Copyright (c) 2016 Microsoft Corporation - -Module Name: - - model_based_opt.cpp - -Abstract: - - Model-based optimization and projection for linear real, integer arithmetic. - -Author: - - Nikolaj Bjorner (nbjorner) 2016-27-4 - -Revision History: - - ---*/ - -#include "math/simplex/model_based_opt.h" -#include "util/uint_set.h" -#include "util/z3_exception.h" - -std::ostream& operator<<(std::ostream& out, opt::ineq_type ie) { - switch (ie) { - case opt::t_eq: return out << " = "; - case opt::t_lt: return out << " < "; - case opt::t_le: return out << " <= "; - case opt::t_divides: return out << " divides "; - case opt::t_mod: return out << " mod "; - case opt::t_div: return out << " div "; - } - return out; -} - - -namespace opt { - - /** - * Convert a row ax + coeffs + coeff = value into a definition for x - * x = (value - coeffs - coeff)/a - * as backdrop we have existing assignments to x and other variables that - * satisfy the equality with value, and such that value satisfies - * the row constraint ( = , <= , < , mod) - */ - model_based_opt::def::def(row const& r, unsigned x) { - for (var const & v : r.m_vars) { - if (v.m_id != x) { - m_vars.push_back(v); - } - else { - m_div = -v.m_coeff; - } - } - m_coeff = r.m_coeff; - switch (r.m_type) { - case opt::t_lt: - m_coeff += m_div; - break; - case opt::t_le: - // for: ax >= t, then x := (t + a - 1) div a - if (m_div.is_pos()) { - m_coeff += m_div; - m_coeff -= rational::one(); - } - break; - default: - break; - } - normalize(); - SASSERT(m_div.is_pos()); - } - - model_based_opt::def model_based_opt::def::operator+(def const& other) const { - def result; - vector const& vs1 = m_vars; - vector const& vs2 = other.m_vars; - vector & vs = result.m_vars; - rational c1(1), c2(1); - if (m_div != other.m_div) { - c1 = other.m_div; - c2 = m_div; - } - unsigned i = 0, j = 0; - while (i < vs1.size() || j < vs2.size()) { - unsigned v1 = UINT_MAX, v2 = UINT_MAX; - if (i < vs1.size()) v1 = vs1[i].m_id; - if (j < vs2.size()) v2 = vs2[j].m_id; - if (v1 == v2) { - vs.push_back(vs1[i]); - vs.back().m_coeff *= c1; - vs.back().m_coeff += c2 * vs2[j].m_coeff; - ++i; ++j; - if (vs.back().m_coeff.is_zero()) { - vs.pop_back(); - } - } - else if (v1 < v2) { - vs.push_back(vs1[i]); - vs.back().m_coeff *= c1; - ++i; - } - else { - vs.push_back(vs2[j]); - vs.back().m_coeff *= c2; - ++j; - } - } - result.m_div = c1*m_div; - result.m_coeff = (m_coeff*c1) + (other.m_coeff*c2); - result.normalize(); - return result; - } - - /** - a1*x1 + a2*x2 + a3*x3 + coeff1 / c1 - x2 |-> b1*x1 + b4*x4 + ceoff2 / c2 - ------------------------------------------------------------------------ - (a1*x1 + a2*((b1*x1 + b4*x4 + coeff2) / c2) + a3*x3 + coeff1) / c1 - ------------------------------------------------------------------------ - (c2*a1*x1 + a2*b1*x1 + a2*b4*x4 + c2*a3*x3 + c2*coeff1 + coeff2) / c1*c2 - */ - void model_based_opt::def::substitute(unsigned v, def const& other) { - vector const& vs1 = m_vars; - rational coeff(0); - for (auto const& [id, c] : vs1) { - if (id == v) { - coeff = c; - break; - } - } - if (coeff == 0) - return; - - rational c1 = m_div; - rational c2 = other.m_div; - - vector const& vs2 = other.m_vars; - vector vs; - unsigned i = 0, j = 0; - while (i < vs1.size() || j < vs2.size()) { - unsigned v1 = UINT_MAX, v2 = UINT_MAX; - if (i < vs1.size()) v1 = vs1[i].m_id; - if (j < vs2.size()) v2 = vs2[j].m_id; - if (v1 == v) - ++i; - else if (v1 == v2) { - vs.push_back(vs1[i]); - vs.back().m_coeff *= c2; - vs.back().m_coeff += coeff * vs2[j].m_coeff; - ++i; ++j; - if (vs.back().m_coeff.is_zero()) - vs.pop_back(); - } - else if (v1 < v2) { - vs.push_back(vs1[i]); - vs.back().m_coeff *= c2; - ++i; - } - else { - vs.push_back(vs2[j]); - vs.back().m_coeff *= coeff; - ++j; - } - } - m_div *= other.m_div; - m_coeff *= c2; - m_coeff += coeff*other.m_coeff; - m_vars.reset(); - m_vars.append(vs); - normalize(); - } - - model_based_opt::def model_based_opt::def::operator/(rational const& r) const { - def result(*this); - result.m_div *= r; - result.normalize(); - return result; - } - - model_based_opt::def model_based_opt::def::operator*(rational const& n) const { - def result(*this); - for (var& v : result.m_vars) { - v.m_coeff *= n; - } - result.m_coeff *= n; - result.normalize(); - return result; - } - - model_based_opt::def model_based_opt::def::operator+(rational const& n) const { - def result(*this); - result.m_coeff += n * result.m_div; - result.normalize(); - return result; - } - - void model_based_opt::def::normalize() { - if (!m_div.is_int()) { - rational den = denominator(m_div); - SASSERT(den > 1); - for (var& v : m_vars) - v.m_coeff *= den; - m_coeff *= den; - m_div *= den; - - } - if (m_div.is_neg()) { - for (var& v : m_vars) - v.m_coeff.neg(); - m_coeff.neg(); - m_div.neg(); - } - if (m_div.is_one()) - return; - rational g(m_div); - if (!m_coeff.is_int()) - return; - g = gcd(g, m_coeff); - for (var const& v : m_vars) { - if (!v.m_coeff.is_int()) - return; - g = gcd(g, abs(v.m_coeff)); - if (g.is_one()) - break; - } - if (!g.is_one()) { - for (var& v : m_vars) - v.m_coeff /= g; - m_coeff /= g; - m_div /= g; - } - } - - model_based_opt::model_based_opt() { - m_rows.push_back(row()); - } - - bool model_based_opt::invariant() { - for (unsigned i = 0; i < m_rows.size(); ++i) { - if (!invariant(i, m_rows[i])) { - return false; - } - } - return true; - } - -#define PASSERT(_e_) { CTRACE("qe", !(_e_), display(tout, r); display(tout);); SASSERT(_e_); } - - bool model_based_opt::invariant(unsigned index, row const& r) { - vector const& vars = r.m_vars; - for (unsigned i = 0; i < vars.size(); ++i) { - // variables in each row are sorted and have non-zero coefficients - PASSERT(i + 1 == vars.size() || vars[i].m_id < vars[i+1].m_id); - PASSERT(!vars[i].m_coeff.is_zero()); - PASSERT(index == 0 || m_var2row_ids[vars[i].m_id].contains(index)); - } - - PASSERT(r.m_value == eval(r)); - PASSERT(r.m_type != t_eq || r.m_value.is_zero()); - // values satisfy constraints - PASSERT(index == 0 || r.m_type != t_lt || r.m_value.is_neg()); - PASSERT(index == 0 || r.m_type != t_le || !r.m_value.is_pos()); - PASSERT(index == 0 || r.m_type != t_divides || (mod(r.m_value, r.m_mod).is_zero())); - PASSERT(index == 0 || r.m_type != t_mod || r.m_id < m_var2value.size()); - PASSERT(index == 0 || r.m_type != t_div || r.m_id < m_var2value.size()); - return true; - } - - // a1*x + obj - // a2*x + t2 <= 0 - // a3*x + t3 <= 0 - // a4*x + t4 <= 0 - // a1 > 0, a2 > 0, a3 > 0, a4 < 0 - // x <= -t2/a2 - // x <= -t2/a3 - // determine lub among these. - // then resolve lub with others - // e.g., -t2/a2 <= -t3/a3, then - // replace inequality a3*x + t3 <= 0 by -t2/a2 + t3/a3 <= 0 - // mark a4 as invalid. - // - - // a1 < 0, a2 < 0, a3 < 0, a4 > 0 - // x >= t2/a2 - // x >= t3/a3 - // determine glb among these - // the resolve glb with others. - // e.g. t2/a2 >= t3/a3 - // then replace a3*x + t3 by t3/a3 - t2/a2 <= 0 - // - inf_eps model_based_opt::maximize() { - SASSERT(invariant()); - unsigned_vector bound_trail, bound_vars; - TRACE("opt", display(tout << "tableau\n");); - while (!objective().m_vars.empty()) { - var v = objective().m_vars.back(); - unsigned x = v.m_id; - rational const& coeff = v.m_coeff; - unsigned bound_row_index; - rational bound_coeff; - if (find_bound(x, bound_row_index, bound_coeff, coeff.is_pos())) { - SASSERT(!bound_coeff.is_zero()); - TRACE("opt", display(tout << "update: " << v << " ", objective()); - for (unsigned above : m_above) { - display(tout << "resolve: ", m_rows[above]); - }); - for (unsigned above : m_above) { - resolve(bound_row_index, bound_coeff, above, x); - } - for (unsigned below : m_below) { - resolve(bound_row_index, bound_coeff, below, x); - } - // coeff*x + objective <= ub - // a2*x + t2 <= 0 - // => coeff*x <= -t2*coeff/a2 - // objective + t2*coeff/a2 <= ub - - mul_add(false, m_objective_id, - coeff/bound_coeff, bound_row_index); - retire_row(bound_row_index); - bound_trail.push_back(bound_row_index); - bound_vars.push_back(x); - } - else { - TRACE("opt", display(tout << "unbound: " << v << " ", objective());); - update_values(bound_vars, bound_trail); - return inf_eps::infinity(); - } - } - - // - // update the evaluation of variables to satisfy the bound. - // - - update_values(bound_vars, bound_trail); - - rational value = objective().m_value; - if (objective().m_type == t_lt) { - return inf_eps(inf_rational(value, rational(-1))); - } - else { - return inf_eps(inf_rational(value)); - } - } - - - void model_based_opt::update_value(unsigned x, rational const& val) { - rational old_val = m_var2value[x]; - m_var2value[x] = val; - SASSERT(val.is_int() || !is_int(x)); - unsigned_vector const& row_ids = m_var2row_ids[x]; - for (unsigned row_id : row_ids) { - rational coeff = get_coefficient(row_id, x); - if (coeff.is_zero()) { - continue; - } - row & r = m_rows[row_id]; - rational delta = coeff * (val - old_val); - r.m_value += delta; - SASSERT(invariant(row_id, r)); - } - } - - - void model_based_opt::update_values(unsigned_vector const& bound_vars, unsigned_vector const& bound_trail) { - for (unsigned i = bound_trail.size(); i-- > 0; ) { - unsigned x = bound_vars[i]; - row& r = m_rows[bound_trail[i]]; - rational val = r.m_coeff; - rational old_x_val = m_var2value[x]; - rational new_x_val; - rational x_coeff, eps(0); - vector const& vars = r.m_vars; - for (var const& v : vars) { - if (x == v.m_id) { - x_coeff = v.m_coeff; - } - else { - val += m_var2value[v.m_id]*v.m_coeff; - } - } - SASSERT(!x_coeff.is_zero()); - new_x_val = -val/x_coeff; - - if (r.m_type == t_lt) { - eps = abs(old_x_val - new_x_val)/rational(2); - eps = std::min(rational::one(), eps); - SASSERT(!eps.is_zero()); - - // - // ax + t < 0 - // <=> x < -t/a - // <=> x := -t/a - epsilon - // - if (x_coeff.is_pos()) { - new_x_val -= eps; - } - // - // -ax + t < 0 - // <=> -ax < -t - // <=> -x < -t/a - // <=> x > t/a - // <=> x := t/a + epsilon - // - else { - new_x_val += eps; - } - } - TRACE("opt", display(tout << "v" << x - << " coeff_x: " << x_coeff - << " old_x_val: " << old_x_val - << " new_x_val: " << new_x_val - << " eps: " << eps << " ", r); ); - m_var2value[x] = new_x_val; - - r.m_value = eval(r); - SASSERT(invariant(bound_trail[i], r)); - } - - // update and check bounds for all other affected rows. - for (unsigned i = bound_trail.size(); i-- > 0; ) { - unsigned x = bound_vars[i]; - unsigned_vector const& row_ids = m_var2row_ids[x]; - for (unsigned row_id : row_ids) { - row & r = m_rows[row_id]; - r.m_value = eval(r); - SASSERT(invariant(row_id, r)); - } - } - SASSERT(invariant()); - } - - bool model_based_opt::find_bound(unsigned x, unsigned& bound_row_index, rational& bound_coeff, bool is_pos) { - bound_row_index = UINT_MAX; - rational lub_val; - rational const& x_val = m_var2value[x]; - unsigned_vector const& row_ids = m_var2row_ids[x]; - uint_set visited; - m_above.reset(); - m_below.reset(); - for (unsigned row_id : row_ids) { - SASSERT(row_id != m_objective_id); - if (visited.contains(row_id)) - continue; - visited.insert(row_id); - row& r = m_rows[row_id]; - if (!r.m_alive) - continue; - rational a = get_coefficient(row_id, x); - if (a.is_zero()) { - // skip - } - else if (a.is_pos() == is_pos || r.m_type == t_eq) { - rational value = x_val - (r.m_value/a); - if (bound_row_index == UINT_MAX) { - lub_val = value; - bound_row_index = row_id; - bound_coeff = a; - } - else if ((value == lub_val && r.m_type == opt::t_lt) || - (is_pos && value < lub_val) || - - (!is_pos && value > lub_val)) { - m_above.push_back(bound_row_index); - lub_val = value; - bound_row_index = row_id; - bound_coeff = a; - } - else - m_above.push_back(row_id); - } - else - m_below.push_back(row_id); - } - return bound_row_index != UINT_MAX; - } - - void model_based_opt::retire_row(unsigned row_id) { - SASSERT(!m_retired_rows.contains(row_id)); - m_rows[row_id].m_alive = false; - m_retired_rows.push_back(row_id); - } - - rational model_based_opt::eval(unsigned x) const { - return m_var2value[x]; - } - - rational model_based_opt::eval(def const& d) const { - vector const& vars = d.m_vars; - rational val = d.m_coeff; - for (var const& v : vars) { - val += v.m_coeff * eval(v.m_id); - } - val /= d.m_div; - return val; - } - - rational model_based_opt::eval(row const& r) const { - vector const& vars = r.m_vars; - rational val = r.m_coeff; - for (var const& v : vars) { - val += v.m_coeff * eval(v.m_id); - } - return val; - } - - rational model_based_opt::eval(vector const& coeffs) const { - rational val(0); - for (var const& v : coeffs) - val += v.m_coeff * eval(v.m_id); - return val; - } - - rational model_based_opt::get_coefficient(unsigned row_id, unsigned var_id) const { - return m_rows[row_id].get_coefficient(var_id); - } - - rational model_based_opt::row::get_coefficient(unsigned var_id) const { - if (m_vars.empty()) - return rational::zero(); - unsigned lo = 0, hi = m_vars.size(); - while (lo < hi) { - unsigned mid = lo + (hi - lo)/2; - SASSERT(mid < hi); - unsigned id = m_vars[mid].m_id; - if (id == var_id) { - lo = mid; - break; - } - if (id < var_id) - lo = mid + 1; - else - hi = mid; - } - if (lo == m_vars.size()) - return rational::zero(); - unsigned id = m_vars[lo].m_id; - if (id == var_id) - return m_vars[lo].m_coeff; - else - return rational::zero(); - } - - model_based_opt::row& model_based_opt::row::normalize() { -#if 0 - if (m_type == t_divides || m_type == t_mod || m_type == t_div) - return *this; - rational D(denominator(abs(m_coeff))); - if (D == 0) - D = 1; - for (auto const& [id, coeff] : m_vars) - if (coeff != 0) - D = lcm(D, denominator(abs(coeff))); - if (D == 1) - return *this; - SASSERT(D > 0); - for (auto & [id, coeff] : m_vars) - coeff *= D; - m_coeff *= D; -#endif - return *this; - } - - // - // Let - // row1: t1 + a1*x <= 0 - // row2: t2 + a2*x <= 0 - // - // assume a1, a2 have the same signs: - // (t2 + a2*x) <= (t1 + a1*x)*a2/a1 - // <=> t2*a1/a2 - t1 <= 0 - // <=> t2 - t1*a2/a1 <= 0 - // - // assume a1 > 0, -a2 < 0: - // t1 + a1*x <= 0, t2 - a2*x <= 0 - // t2/a2 <= -t1/a1 - // t2 + t1*a2/a1 <= 0 - // assume -a1 < 0, a2 > 0: - // t1 - a1*x <= 0, t2 + a2*x <= 0 - // t1/a1 <= -t2/a2 - // t2 + t1*a2/a1 <= 0 - // - // the resolvent is the same in all cases (simpler proof should exist) - // - // assume a1 < 0, -a1 = a2: - // t1 <= a2*div(t2, a2) - // - - void model_based_opt::resolve(unsigned row_src, rational const& a1, unsigned row_dst, unsigned x) { - - SASSERT(a1 == get_coefficient(row_src, x)); - SASSERT(!a1.is_zero()); - SASSERT(row_src != row_dst); - - if (m_rows[row_dst].m_alive) { - rational a2 = get_coefficient(row_dst, x); - if (is_int(x)) { - TRACE("opt", - tout << x << ": " << a1 << " " << a2 << ": "; - display(tout, m_rows[row_dst]); - display(tout, m_rows[row_src]);); - if (a1.is_pos() != a2.is_pos() || m_rows[row_src].m_type == opt::t_eq) { - mul_add(x, a1, row_src, a2, row_dst); - } - else { - mul(row_dst, abs(a1)); - mul_add(false, row_dst, -abs(a2), row_src); - } - TRACE("opt", display(tout << "result ", m_rows[row_dst]);); - normalize(row_dst); - } - else { - mul_add(row_dst != m_objective_id && a1.is_pos() == a2.is_pos(), row_dst, -a2/a1, row_src); - } - } - } - - /** - * a1 > 0 - * a1*x + r1 = value - * a2*x + r2 <= 0 - * ------------------ - * a1*r2 - a2*r1 <= value - */ - void model_based_opt::solve(unsigned row_src, rational const& a1, unsigned row_dst, unsigned x) { - SASSERT(a1 == get_coefficient(row_src, x)); - SASSERT(a1.is_pos()); - SASSERT(row_src != row_dst); - if (!m_rows[row_dst].m_alive) return; - rational a2 = get_coefficient(row_dst, x); - mul(row_dst, a1); - mul_add(false, row_dst, -a2, row_src); - normalize(row_dst); - SASSERT(get_coefficient(row_dst, x).is_zero()); - } - - // resolution for integer rows. - void model_based_opt::mul_add( - unsigned x, rational src_c, unsigned row_src, rational dst_c, unsigned row_dst) { - row& dst = m_rows[row_dst]; - row const& src = m_rows[row_src]; - SASSERT(is_int(x)); - SASSERT(t_le == dst.m_type && t_le == src.m_type); - SASSERT(src_c.is_int()); - SASSERT(dst_c.is_int()); - SASSERT(m_var2value[x].is_int()); - - rational abs_src_c = abs(src_c); - rational abs_dst_c = abs(dst_c); - rational x_val = m_var2value[x]; - rational slack = (abs_src_c - rational::one()) * (abs_dst_c - rational::one()); - rational dst_val = dst.m_value - x_val*dst_c; - rational src_val = src.m_value - x_val*src_c; - rational distance = abs_src_c * dst_val + abs_dst_c * src_val + slack; - bool use_case1 = distance.is_nonpos() || abs_src_c.is_one() || abs_dst_c.is_one(); - bool use_case2 = false && abs_src_c == abs_dst_c && src_c.is_pos() != dst_c.is_pos() && !abs_src_c.is_one() && t_le == dst.m_type && t_le == src.m_type; - bool use_case3 = false && src_c.is_pos() != dst_c.is_pos() && t_le == dst.m_type && t_le == src.m_type; - - - if (use_case1) { - TRACE("opt", tout << "slack: " << slack << " " << src_c << " " << dst_val << " " << dst_c << " " << src_val << "\n";); - // dst <- abs_src_c*dst + abs_dst_c*src + slack - mul(row_dst, abs_src_c); - add(row_dst, slack); - mul_add(false, row_dst, abs_dst_c, row_src); - return; - } - - if (use_case2 || use_case3) { - // case2: - // x*src_c + s <= 0 - // -x*src_c + t <= 0 - // - // -src_c*div(-s, src_c) + t <= 0 - // - // Example: - // t <= 100*x <= s - // Then t <= 100*div(s, 100) - // - // case3: - // x*src_c + s <= 0 - // -x*dst_c + t <= 0 - // t <= x*dst_c, x*src_c <= -s -> - // t <= dst_c*div(-s, src_c) -> - // -dst_c*div(-s,src_c) + t <= 0 - // - - bool swapped = false; - if (src_c < 0) { - std::swap(row_src, row_dst); - std::swap(src_c, dst_c); - std::swap(abs_src_c, abs_dst_c); - swapped = true; - } - vector src_coeffs, dst_coeffs; - rational src_coeff = m_rows[row_src].m_coeff; - rational dst_coeff = m_rows[row_dst].m_coeff; - for (auto const& v : m_rows[row_src].m_vars) - if (v.m_id != x) - src_coeffs.push_back(var(v.m_id, -v.m_coeff)); - for (auto const& v : m_rows[row_dst].m_vars) - if (v.m_id != x) - dst_coeffs.push_back(v); - unsigned v = UINT_MAX; - if (src_coeffs.empty()) - dst_coeff -= abs_dst_c*div(-src_coeff, abs_src_c); - else - v = add_div(src_coeffs, -src_coeff, abs_src_c); - if (v != UINT_MAX) dst_coeffs.push_back(var(v, -abs_dst_c)); - if (swapped) - std::swap(row_src, row_dst); - retire_row(row_dst); - add_constraint(dst_coeffs, dst_coeff, t_le); - return; - } - - // - // create finite disjunction for |b|. - // exists x, z in [0 .. |b|-2] . b*x + s + z = 0 && ax + t <= 0 && bx + s <= 0 - // <=> - // exists x, z in [0 .. |b|-2] . b*x = -z - s && ax + t <= 0 && bx + s <= 0 - // <=> - // exists x, z in [0 .. |b|-2] . b*x = -z - s && a|b|x + |b|t <= 0 && bx + s <= 0 - // <=> - // exists x, z in [0 .. |b|-2] . b*x = -z - s && a|b|x + |b|t <= 0 && -z - s + s <= 0 - // <=> - // exists x, z in [0 .. |b|-2] . b*x = -z - s && a|b|x + |b|t <= 0 && -z <= 0 - // <=> - // exists x, z in [0 .. |b|-2] . b*x = -z - s && a|b|x + |b|t <= 0 - // <=> - // exists x, z in [0 .. |b|-2] . b*x = -z - s && a*n_sign(b)(s + z) + |b|t <= 0 - // <=> - // exists z in [0 .. |b|-2] . |b| | (z + s) && a*n_sign(b)(s + z) + |b|t <= 0 - // - - TRACE("qe", tout << "finite disjunction " << distance << " " << src_c << " " << dst_c << "\n";); - vector coeffs; - if (abs_dst_c <= abs_src_c) { - rational z = mod(dst_val, abs_dst_c); - if (!z.is_zero()) z = abs_dst_c - z; - mk_coeffs_without(coeffs, dst.m_vars, x); - add_divides(coeffs, dst.m_coeff + z, abs_dst_c); - add(row_dst, z); - mul(row_dst, src_c * n_sign(dst_c)); - mul_add(false, row_dst, abs_dst_c, row_src); - } - else { - // z := b - (s + bx) mod b - // := b - s mod b - // b | s + z <=> b | s + b - s mod b <=> b | s - s mod b - rational z = mod(src_val, abs_src_c); - if (!z.is_zero()) z = abs_src_c - z; - mk_coeffs_without(coeffs, src.m_vars, x); - add_divides(coeffs, src.m_coeff + z, abs_src_c); - mul(row_dst, abs_src_c); - add(row_dst, z * dst_c * n_sign(src_c)); - mul_add(false, row_dst, dst_c * n_sign(src_c), row_src); - } - } - - void model_based_opt::mk_coeffs_without(vector& dst, vector const& src, unsigned x) { - for (var const & v : src) { - if (v.m_id != x) dst.push_back(v); - } - } - - rational model_based_opt::n_sign(rational const& b) const { - return rational(b.is_pos()?-1:1); - } - - void model_based_opt::mul(unsigned dst, rational const& c) { - if (c.is_one()) - return; - row& r = m_rows[dst]; - for (auto & v : r.m_vars) - v.m_coeff *= c; - r.m_mod *= c; - r.m_coeff *= c; - if (r.m_type != t_div && r.m_type != t_mod) - r.m_value *= c; - } - - void model_based_opt::add(unsigned dst, rational const& c) { - row& r = m_rows[dst]; - r.m_coeff += c; - r.m_value += c; - } - - void model_based_opt::sub(unsigned dst, rational const& c) { - row& r = m_rows[dst]; - r.m_coeff -= c; - r.m_value -= c; - } - - void model_based_opt::normalize(unsigned row_id) { - row& r = m_rows[row_id]; - if (!r.m_alive) - return; - if (r.m_vars.empty()) { - retire_row(row_id); - return; - } - if (r.m_type == t_divides) - return; - if (r.m_type == t_mod) - return; - if (r.m_type == t_div) - return; - rational g(abs(r.m_vars[0].m_coeff)); - bool all_int = g.is_int(); - for (unsigned i = 1; all_int && !g.is_one() && i < r.m_vars.size(); ++i) { - rational const& coeff = r.m_vars[i].m_coeff; - if (coeff.is_int()) { - g = gcd(g, abs(coeff)); - } - else { - all_int = false; - } - } - if (all_int && !r.m_coeff.is_zero()) { - if (r.m_coeff.is_int()) { - g = gcd(g, abs(r.m_coeff)); - } - else { - all_int = false; - } - } - if (all_int && !g.is_one()) { - SASSERT(!g.is_zero()); - mul(row_id, rational::one()/g); - } - } - - // - // set row1 <- row1 + c*row2 - // - void model_based_opt::mul_add(bool same_sign, unsigned row_id1, rational const& c, unsigned row_id2) { - if (c.is_zero()) - return; - - - m_new_vars.reset(); - row& r1 = m_rows[row_id1]; - row const& r2 = m_rows[row_id2]; - unsigned i = 0, j = 0; - while (i < r1.m_vars.size() || j < r2.m_vars.size()) { - if (j == r2.m_vars.size()) { - m_new_vars.append(r1.m_vars.size() - i, r1.m_vars.data() + i); - break; - } - if (i == r1.m_vars.size()) { - for (; j < r2.m_vars.size(); ++j) { - m_new_vars.push_back(r2.m_vars[j]); - m_new_vars.back().m_coeff *= c; - if (row_id1 != m_objective_id) - m_var2row_ids[r2.m_vars[j].m_id].push_back(row_id1); - } - break; - } - - unsigned v1 = r1.m_vars[i].m_id; - unsigned v2 = r2.m_vars[j].m_id; - if (v1 == v2) { - m_new_vars.push_back(r1.m_vars[i]); - m_new_vars.back().m_coeff += c*r2.m_vars[j].m_coeff; - ++i; - ++j; - if (m_new_vars.back().m_coeff.is_zero()) - m_new_vars.pop_back(); - } - else if (v1 < v2) { - m_new_vars.push_back(r1.m_vars[i]); - ++i; - } - else { - m_new_vars.push_back(r2.m_vars[j]); - m_new_vars.back().m_coeff *= c; - if (row_id1 != m_objective_id) - m_var2row_ids[r2.m_vars[j].m_id].push_back(row_id1); - ++j; - } - } - r1.m_coeff += c*r2.m_coeff; - r1.m_vars.swap(m_new_vars); - r1.m_value += c*r2.m_value; - - if (!same_sign && r2.m_type == t_lt) - r1.m_type = t_lt; - else if (same_sign && r1.m_type == t_lt && r2.m_type == t_lt) - r1.m_type = t_le; - SASSERT(invariant(row_id1, r1)); - } - - void model_based_opt::display(std::ostream& out) const { - for (auto const& r : m_rows) - display(out, r); - for (unsigned i = 0; i < m_var2row_ids.size(); ++i) { - unsigned_vector const& rows = m_var2row_ids[i]; - out << i << ": "; - for (auto const& r : rows) - out << r << " "; - out << "\n"; - } - } - - void model_based_opt::display(std::ostream& out, vector const& vars, rational const& coeff) { - unsigned i = 0; - for (var const& v : vars) { - if (i > 0 && v.m_coeff.is_pos()) - out << "+ "; - ++i; - if (v.m_coeff.is_one()) - out << "v" << v.m_id << " "; - else - out << v.m_coeff << "*v" << v.m_id << " "; - } - if (coeff.is_pos()) - out << " + " << coeff << " "; - else if (coeff.is_neg()) - out << coeff << " "; - } - - std::ostream& model_based_opt::display(std::ostream& out, row const& r) { - out << (r.m_alive?"a":"d") << " "; - display(out, r.m_vars, r.m_coeff); - switch (r.m_type) { - case opt::t_divides: - out << r.m_type << " " << r.m_mod << " = 0; value: " << r.m_value << "\n"; - break; - case opt::t_mod: - out << r.m_type << " " << r.m_mod << " = v" << r.m_id << " ; mod: " << mod(r.m_value, r.m_mod) << "\n"; - break; - case opt::t_div: - out << r.m_type << " " << r.m_mod << " = v" << r.m_id << " ; div: " << div(r.m_value, r.m_mod) << "\n"; - break; - default: - out << r.m_type << " 0; value: " << r.m_value << "\n"; - break; - } - return out; - } - - std::ostream& model_based_opt::display(std::ostream& out, def const& r) { - display(out, r.m_vars, r.m_coeff); - if (!r.m_div.is_one()) { - out << " / " << r.m_div; - } - return out; - } - - unsigned model_based_opt::add_var(rational const& value, bool is_int) { - unsigned v = m_var2value.size(); - m_var2value.push_back(value); - m_var2is_int.push_back(is_int); - SASSERT(value.is_int() || !is_int); - m_var2row_ids.push_back(unsigned_vector()); - return v; - } - - rational model_based_opt::get_value(unsigned var) { - return m_var2value[var]; - } - - void model_based_opt::set_row(unsigned row_id, vector const& coeffs, rational const& c, rational const& m, ineq_type rel) { - row& r = m_rows[row_id]; - rational val(c); - SASSERT(r.m_vars.empty()); - r.m_vars.append(coeffs.size(), coeffs.data()); - bool is_int_row = !coeffs.empty(); - std::sort(r.m_vars.begin(), r.m_vars.end(), var::compare()); - for (auto const& c : coeffs) { - val += m_var2value[c.m_id] * c.m_coeff; - SASSERT(!is_int(c.m_id) || c.m_coeff.is_int()); - is_int_row &= is_int(c.m_id); - } - r.m_alive = true; - r.m_coeff = c; - r.m_value = val; - r.m_type = rel; - r.m_mod = m; - if (is_int_row && rel == t_lt) { - r.m_type = t_le; - r.m_coeff += rational::one(); - r.m_value += rational::one(); - } - } - - unsigned model_based_opt::new_row() { - unsigned row_id = 0; - if (m_retired_rows.empty()) { - row_id = m_rows.size(); - m_rows.push_back(row()); - } - else { - row_id = m_retired_rows.back(); - m_retired_rows.pop_back(); - SASSERT(!m_rows[row_id].m_alive); - m_rows[row_id].reset(); - m_rows[row_id].m_alive = true; - } - return row_id; - } - - unsigned model_based_opt::copy_row(unsigned src, unsigned excl) { - unsigned dst = new_row(); - row const& r = m_rows[src]; - set_row(dst, r.m_vars, r.m_coeff, r.m_mod, r.m_type); - for (auto const& v : r.m_vars) { - if (v.m_id != excl) - m_var2row_ids[v.m_id].push_back(dst); - } - SASSERT(invariant(dst, m_rows[dst])); - return dst; - } - - // -x + lo <= 0 - void model_based_opt::add_lower_bound(unsigned x, rational const& lo) { - vector coeffs; - coeffs.push_back(var(x, rational::minus_one())); - add_constraint(coeffs, lo, t_le); - } - - // x - hi <= 0 - void model_based_opt::add_upper_bound(unsigned x, rational const& hi) { - vector coeffs; - coeffs.push_back(var(x, rational::one())); - add_constraint(coeffs, -hi, t_le); - } - - void model_based_opt::add_constraint(vector const& coeffs, rational const& c, ineq_type rel) { - add_constraint(coeffs, c, rational::zero(), rel, 0); - } - - void model_based_opt::add_divides(vector const& coeffs, rational const& c, rational const& m) { - rational g(c); - for (auto const& [v, coeff] : coeffs) - g = gcd(coeff, g); - if ((g/m).is_int()) - return; - add_constraint(coeffs, c, m, t_divides, 0); - } - - unsigned model_based_opt::add_mod(vector const& coeffs, rational const& c, rational const& m) { - rational value = c; - for (auto const& var : coeffs) - value += var.m_coeff * m_var2value[var.m_id]; - unsigned v = add_var(mod(value, m), true); - add_constraint(coeffs, c, m, t_mod, v); - return v; - } - - unsigned model_based_opt::add_div(vector const& coeffs, rational const& c, rational const& m) { - rational value = c; - for (auto const& var : coeffs) - value += var.m_coeff * m_var2value[var.m_id]; - unsigned v = add_var(div(value, m), true); - add_constraint(coeffs, c, m, t_div, v); - return v; - } - - unsigned model_based_opt::add_constraint(vector const& coeffs, rational const& c, rational const& m, ineq_type rel, unsigned id) { - auto const& r = m_rows.back(); - if (r.m_vars == coeffs && r.m_coeff == c && r.m_mod == m && r.m_type == rel && r.m_id == id && r.m_alive) - return m_rows.size() - 1; - unsigned row_id = new_row(); - set_row(row_id, coeffs, c, m, rel); - m_rows[row_id].m_id = id; - for (var const& coeff : coeffs) - m_var2row_ids[coeff.m_id].push_back(row_id); - SASSERT(invariant(row_id, m_rows[row_id])); - normalize(row_id); - return row_id; - } - - void model_based_opt::set_objective(vector const& coeffs, rational const& c) { - set_row(m_objective_id, coeffs, c, rational::zero(), t_le); - } - - void model_based_opt::get_live_rows(vector& rows) { - for (row & r : m_rows) - if (r.m_alive) - rows.push_back(r.normalize()); - } - - // - // pick glb and lub representative. - // The representative is picked such that it - // represents the fewest inequalities. - // The constraints that enforce a glb or lub are not forced. - // The constraints that separate the glb from ub or the lub from lb - // are not forced. - // In other words, suppose there are - // . N inequalities of the form t <= x - // . M inequalities of the form s >= x - // . t0 is glb among N under valuation. - // . s0 is lub among M under valuation. - // If N < M - // create the inequalities: - // t <= t0 for each t other than t0 (N-1 inequalities). - // t0 <= s for each s (M inequalities). - // If N >= M the construction is symmetric. - // - model_based_opt::def model_based_opt::project(unsigned x, bool compute_def) { - unsigned_vector& lub_rows = m_lub; - unsigned_vector& glb_rows = m_glb; - unsigned_vector& divide_rows = m_divides; - unsigned_vector& mod_rows = m_mod; - unsigned_vector& div_rows = m_div; - unsigned lub_index = UINT_MAX, glb_index = UINT_MAX; - bool lub_strict = false, glb_strict = false; - rational lub_val, glb_val; - rational const& x_val = m_var2value[x]; - unsigned_vector const& row_ids = m_var2row_ids[x]; - uint_set visited; - lub_rows.reset(); - glb_rows.reset(); - divide_rows.reset(); - mod_rows.reset(); - div_rows.reset(); - bool lub_is_unit = true, glb_is_unit = true; - unsigned eq_row = UINT_MAX; - // select the lub and glb. - for (unsigned row_id : row_ids) { - if (visited.contains(row_id)) - continue; - visited.insert(row_id); - row& r = m_rows[row_id]; - if (!r.m_alive) - continue; - rational a = get_coefficient(row_id, x); - if (a.is_zero()) - continue; - if (r.m_type == t_eq) - eq_row = row_id; - else if (r.m_type == t_mod) - mod_rows.push_back(row_id); - else if (r.m_type == t_div) - div_rows.push_back(row_id); - else if (r.m_type == t_divides) - divide_rows.push_back(row_id); - else if (a.is_pos()) { - rational lub_value = x_val - (r.m_value/a); - if (lub_rows.empty() || - lub_value < lub_val || - (lub_value == lub_val && r.m_type == t_lt && !lub_strict)) { - lub_val = lub_value; - lub_index = row_id; - lub_strict = r.m_type == t_lt; - } - lub_rows.push_back(row_id); - lub_is_unit &= a.is_one(); - } - else { - SASSERT(a.is_neg()); - rational glb_value = x_val - (r.m_value/a); - if (glb_rows.empty() || - glb_value > glb_val || - (glb_value == glb_val && r.m_type == t_lt && !glb_strict)) { - glb_val = glb_value; - glb_index = row_id; - glb_strict = r.m_type == t_lt; - } - glb_rows.push_back(row_id); - glb_is_unit &= a.is_minus_one(); - } - } - - if (!divide_rows.empty()) - return solve_divides(x, divide_rows, compute_def); - - if (!div_rows.empty() || !mod_rows.empty()) - return solve_mod_div(x, mod_rows, div_rows, compute_def); - - if (eq_row != UINT_MAX) - return solve_for(eq_row, x, compute_def); - - def result; - unsigned lub_size = lub_rows.size(); - unsigned glb_size = glb_rows.size(); - unsigned row_index = (lub_size <= glb_size) ? lub_index : glb_index; - - // There are only upper or only lower bounds. - if (row_index == UINT_MAX) { - if (compute_def) { - if (lub_index != UINT_MAX) - result = solve_for(lub_index, x, true); - else if (glb_index != UINT_MAX) - result = solve_for(glb_index, x, true); - else - result = def() + m_var2value[x]; - SASSERT(eval(result) == eval(x)); - } - else { - for (unsigned row_id : lub_rows) retire_row(row_id); - for (unsigned row_id : glb_rows) retire_row(row_id); - } - return result; - } - - SASSERT(lub_index != UINT_MAX); - SASSERT(glb_index != UINT_MAX); - if (compute_def) { - if (lub_size <= glb_size) - result = def(m_rows[lub_index], x); - else - result = def(m_rows[glb_index], x); - } - - // The number of matching lower and upper bounds is small. - if ((lub_size <= 2 || glb_size <= 2) && - (lub_size <= 3 && glb_size <= 3) && - (!is_int(x) || lub_is_unit || glb_is_unit)) { - for (unsigned i = 0; i < lub_size; ++i) { - unsigned row_id1 = lub_rows[i]; - bool last = i + 1 == lub_size; - rational coeff = get_coefficient(row_id1, x); - for (unsigned row_id2 : glb_rows) { - if (last) { - resolve(row_id1, coeff, row_id2, x); - } - else { - unsigned row_id3 = copy_row(row_id2); - resolve(row_id1, coeff, row_id3, x); - } - } - } - for (unsigned row_id : lub_rows) - retire_row(row_id); - - return result; - } - - // General case. - rational coeff = get_coefficient(row_index, x); - - for (unsigned row_id : lub_rows) - if (row_id != row_index) - resolve(row_index, coeff, row_id, x); - - for (unsigned row_id : glb_rows) - if (row_id != row_index) - resolve(row_index, coeff, row_id, x); - retire_row(row_index); - return result; - } - - - // - // Given v = a*x + b mod K - // - // - remove v = a*x + b mod K - // - // case a = 1: - // - add w = b mod K - // - x |-> K*y + z, 0 <= z < K - // - if z.value + w.value < K: - // add z + w - v = 0 - // - if z.value + w.value >= K: - // add z + w - v - K = 0 - // - // case a != 1, gcd(a, K) = 1 - // - x |-> x*y + a^-1*z, 0 <= z < K - // - add w = b mod K - // if z.value + w.value < K - // add z + w - v = 0 - // if z.value + w.value >= K - // add z + w - v - K = 0 - // - // case a != 1, gcd(a,K) = g != 1 - // - x |-> x*y + a^-1*z, 0 <= z < K - // a*x + b mod K = v is now - // g*z + b mod K = v - // - add w = b mod K - // - 0 <= g*z.value + w.value < K*(g+1) - // - add g*z + w - v - k*K = 0 for suitable k from 0 .. g based on model - // - // - // - // Given v = a*x + b div K - // Replace x |-> K*y + z - // - w = b div K - // - v = ((a*K*y + a*z) + b) div K - // = a*y + (a*z + b) div K - // = a*y + b div K + (b mod K + a*z) div K - // = a*y + b div K + k - // where k := (b.value mod K + a*z.value) div K - // k is between 0 and a - // - // - k*K <= b mod K + a*z < (k+1)*K - // - // A better version using a^-1 - // - v = (a*K*y + a^-1*a*z + b) div K - // = a*y + ((K*A + g)*z + b) div K where we write a*a^-1 = K*A + g - // = a*y + A + (g*z + b) div K - // - k*K <= b Kod m + gz < (k+1)*K - // where k is between 0 and g - // when gcd(a, K) = 1, then there are only two cases. - // - model_based_opt::def model_based_opt::solve_mod_div(unsigned x, unsigned_vector const& _mod_rows, unsigned_vector const& _div_rows, bool compute_def) { - def result; - unsigned_vector div_rows(_div_rows), mod_rows(_mod_rows); - SASSERT(!div_rows.empty() || !mod_rows.empty()); - TRACE("opt", display(tout << "solve_div v" << x << "\n")); - - rational K(1); - for (unsigned ri : div_rows) - K = lcm(K, m_rows[ri].m_mod); - for (unsigned ri : mod_rows) - K = lcm(K, m_rows[ri].m_mod); - - rational x_value = m_var2value[x]; - rational z_value = mod(x_value, K); - rational y_value = div(x_value, K); - SASSERT(x_value == K * y_value + z_value); - SASSERT(0 <= z_value && z_value < K); - // add new variables - unsigned z = add_var(z_value, true); - unsigned y = add_var(y_value, true); - - uint_set visited; - unsigned j = 0; - for (unsigned ri : div_rows) { - if (visited.contains(ri)) - continue; - row& r = m_rows[ri]; - mul(ri, K / r.m_mod); - r.m_alive = false; - visited.insert(ri); - div_rows[j++] = ri; - } - div_rows.shrink(j); - - j = 0; - for (unsigned ri : mod_rows) { - if (visited.contains(ri)) - continue; - m_rows[ri].m_alive = false; - visited.insert(ri); - mod_rows[j++] = ri; - } - mod_rows.shrink(j); - - - // replace x by K*y + z in other rows. - for (unsigned ri : m_var2row_ids[x]) { - if (visited.contains(ri)) - continue; - replace_var(ri, x, K, y, rational::one(), z); - visited.insert(ri); - normalize(ri); - } - - // add bounds for z - add_lower_bound(z, rational::zero()); - add_upper_bound(z, K - 1); - - - // solve for x_value = K*y_value + z_value, 0 <= z_value < K. - - unsigned_vector vs; - - for (unsigned ri : div_rows) { - - rational a = get_coefficient(ri, x); - replace_var(ri, x, rational::zero()); - - // add w = b div m - vector coeffs = m_rows[ri].m_vars; - rational coeff = m_rows[ri].m_coeff; - unsigned w = UINT_MAX; - rational offset(0); - if (K == 1) - offset = coeff; - else if (coeffs.empty()) - offset = div(coeff, K); - else - w = add_div(coeffs, coeff, K); - - // - // w = b div K - // v = a*y + w + k - // k = (a*z_value + (b_value mod K)) div K - // k*K <= a*z + b mod K < (k+1)*K - // - /** - * It is based on the following claim (tested for select values of a, K) - * (define-const K Int 13) - * (declare-const b Int) - * (define-const a Int -11) - * (declare-const y Int) - * (declare-const z Int) - * (define-const w Int (div b K)) - * (define-const k1 Int (+ (* a z) (mod b K))) - * (define-const k Int (div k1 K)) - * (define-const x Int (+ (* K y) z)) - * (define-const u Int (+ (* a x) b)) - * (define-const v Int (+ (* a y) w k)) - * (assert (<= 0 z)) - * (assert (< z K)) - * (assert (<= (* K k) k1)) - * (assert (< k1 (* K (+ k 1)))) - * (assert (not (= (div u K) v))) - * (check-sat) - */ - unsigned v = m_rows[ri].m_id; - rational b_value = eval(coeffs) + coeff; - rational k = div(a * z_value + mod(b_value, K), K); - vector div_coeffs; - div_coeffs.push_back(var(v, rational::minus_one())); - div_coeffs.push_back(var(y, a)); - if (w != UINT_MAX) - div_coeffs.push_back(var(w, rational::one())); - else if (K == 1) - div_coeffs.append(coeffs); - add_constraint(div_coeffs, k + offset, t_eq); - - unsigned u = UINT_MAX; - offset = 0; - if (K == 1) - offset = 0; - else if (coeffs.empty()) - offset = mod(coeff, K); - else - u = add_mod(coeffs, coeff, K); - - - // add a*z + (b mod K) < (k + 1)*K - vector bound_coeffs; - bound_coeffs.push_back(var(z, a)); - if (u != UINT_MAX) - bound_coeffs.push_back(var(u, rational::one())); - add_constraint(bound_coeffs, 1 - K * (k + 1) + offset, t_le); - - // add k*K <= az + (b mod K) - for (auto& c : bound_coeffs) - c.m_coeff.neg(); - add_constraint(bound_coeffs, k * K - offset, t_le); - // allow to recycle row. - retire_row(ri); - vs.push_back(v); - } - - for (unsigned ri : mod_rows) { - rational a = get_coefficient(ri, x); - replace_var(ri, x, rational::zero()); - rational rMod = m_rows[ri].m_mod; - - // add w = b mod rMod - vector coeffs = m_rows[ri].m_vars; - rational coeff = m_rows[ri].m_coeff; - unsigned v = m_rows[ri].m_id; - rational v_value = m_var2value[v]; - - unsigned w = UINT_MAX; - rational offset(0); - if (coeffs.empty() || rMod == 1) - offset = mod(coeff, rMod); - else - w = add_mod(coeffs, coeff, rMod); - - - rational w_value = w == UINT_MAX ? offset : m_var2value[w]; - -#if 0 - // V := (a * z_value + w_value) div rMod - // V*rMod <= a*z + w < (V+1)*rMod - // v = a*z + w - V*rMod - SASSERT(a > 0); - SASSERT(z_value >= 0); - SASSERT(w_value >= 0); - SASSERT(a * z_value + w_value >= 0); - rational V = div(a * z_value + w_value, rMod); - vector mod_coeffs; - SASSERT(V >= 0); - SASSERT(a * z_value + w_value >= V*rMod); - SASSERT((V+1)*rMod > a*z_value + w_value); - // -a*z - w + V*rMod <= 0 - mod_coeffs.push_back(var(z, -a)); - if (w != UINT_MAX) mod_coeffs.push_back(var(w, -rational::one())); - add_constraint(mod_coeffs, V*rMod - offset, t_le); - mod_coeffs.reset(); - // a*z + w - (V+1)*rMod + 1 <= 0 - mod_coeffs.push_back(var(z, a)); - if (w != UINT_MAX) mod_coeffs.push_back(var(w, rational::one())); - add_constraint(mod_coeffs, -(V+1)*rMod + offset + 1, t_le); - mod_coeffs.reset(); - // -v + a*z + w - V*rMod = 0 - mod_coeffs.push_back(var(v, rational::minus_one())); - mod_coeffs.push_back(var(z, a)); - if (w != UINT_MAX) mod_coeffs.push_back(var(w, rational::one())); - add_constraint(mod_coeffs, offset - V*rMod, t_eq); - -#else - // add v = a*z + w - V, for V = v_value - a * z_value - w_value - // claim: (= (mod x rMod) (- x (* rMod (div x rMod)))))) is a theorem for every x, rMod != 0 - rational V = v_value - a * z_value - w_value; - vector mod_coeffs; - mod_coeffs.push_back(var(v, rational::minus_one())); - mod_coeffs.push_back(var(z, a)); - if (w != UINT_MAX) mod_coeffs.push_back(var(w, rational::one())); - add_constraint(mod_coeffs, V + offset, t_eq); - add_lower_bound(v, rational::zero()); - add_upper_bound(v, rMod - 1); -#endif - - retire_row(ri); - vs.push_back(v); - } - - - for (unsigned v : vs) { - def v_def = project(v, compute_def); - if (compute_def) - eliminate(v, v_def); - } - - // project internal variables. - def z_def = project(z, compute_def); - def y_def = project(y, compute_def); // may depend on z - - if (compute_def) { - z_def.substitute(y, y_def); - eliminate(y, y_def); - eliminate(z, z_def); - - result = (y_def * K) + z_def; - m_var2value[x] = eval(result); - TRACE("opt", tout << y << " := " << y_def << "\n"; - tout << z << " := " << z_def << "\n"; - tout << x << " := " << result << "\n"); - } - TRACE("opt", display(tout << "solve_div done v" << x << "\n")); - return result; - } - - // - // compute D and u. - // - // D = lcm(d1, d2) - // u = eval(x) mod D - // - // d1 | (a1x + t1) & d2 | (a2x + t2) - // = - // d1 | (a1(D*x' + u) + t1) & d2 | (a2(D*x' + u) + t2) - // = - // d1 | (a1*u + t1) & d2 | (a2*u + t2) - // - // x := D*x' + u - // - - model_based_opt::def model_based_opt::solve_divides(unsigned x, unsigned_vector const& divide_rows, bool compute_def) { - SASSERT(!divide_rows.empty()); - rational D(1); - for (unsigned idx : divide_rows) { - D = lcm(D, m_rows[idx].m_mod); - } - if (D.is_zero()) { - throw default_exception("modulo 0 is not defined"); - } - if (D.is_neg()) D = abs(D); - TRACE("opt1", display(tout << "lcm: " << D << " x: v" << x << " tableau\n");); - rational val_x = m_var2value[x]; - rational u = mod(val_x, D); - SASSERT(u.is_nonneg() && u < D); - for (unsigned idx : divide_rows) { - replace_var(idx, x, u); - SASSERT(invariant(idx, m_rows[idx])); - normalize(idx); - } - TRACE("opt1", display(tout << "tableau after replace x under mod\n");); - // - // update inequalities such that u is added to t and - // D is multiplied to coefficient of x. - // the interpretation of the new version of x is (x-u)/D - // - // a*x + t <= 0 - // a*(D*x' + u) + t <= 0 - // a*D*x' + a*u + t <= 0 - // - rational new_val = (val_x - u) / D; - SASSERT(new_val.is_int()); - unsigned y = add_var(new_val, true); - unsigned_vector const& row_ids = m_var2row_ids[x]; - uint_set visited; - for (unsigned row_id : row_ids) { - if (visited.contains(row_id)) - continue; - // x |-> D*y + u - replace_var(row_id, x, D, y, u); - visited.insert(row_id); - normalize(row_id); - } - TRACE("opt1", display(tout << "tableau after replace x by y := v" << y << "\n");); - def result = project(y, compute_def); - if (compute_def) { - result = (result * D) + u; - m_var2value[x] = eval(result); - } - TRACE("opt1", display(tout << "tableau after project y" << y << "\n");); - - return result; - } - - // update row with: x |-> C - void model_based_opt::replace_var(unsigned row_id, unsigned x, rational const& C) { - row& r = m_rows[row_id]; - SASSERT(!get_coefficient(row_id, x).is_zero()); - unsigned sz = r.m_vars.size(); - unsigned i = 0, j = 0; - rational coeff(0); - for (; i < sz; ++i) { - if (r.m_vars[i].m_id == x) { - coeff = r.m_vars[i].m_coeff; - } - else { - if (i != j) { - r.m_vars[j] = r.m_vars[i]; - } - ++j; - } - } - if (j != sz) { - r.m_vars.shrink(j); - } - r.m_coeff += coeff*C; - r.m_value += coeff*(C - m_var2value[x]); - } - - // update row with: x |-> A*y + B - void model_based_opt::replace_var(unsigned row_id, unsigned x, rational const& A, unsigned y, rational const& B) { - row& r = m_rows[row_id]; - rational coeff = get_coefficient(row_id, x); - if (coeff.is_zero()) return; - if (!r.m_alive) return; - replace_var(row_id, x, B); - r.m_vars.push_back(var(y, coeff*A)); - r.m_value += coeff*A*m_var2value[y]; - if (!r.m_vars.empty() && r.m_vars.back().m_id > y) - std::sort(r.m_vars.begin(), r.m_vars.end(), var::compare()); - m_var2row_ids[y].push_back(row_id); - SASSERT(invariant(row_id, r)); - } - - // update row with: x |-> A*y + B*z - void model_based_opt::replace_var(unsigned row_id, unsigned x, rational const& A, unsigned y, rational const& B, unsigned z) { - row& r = m_rows[row_id]; - rational coeff = get_coefficient(row_id, x); - if (coeff.is_zero() || !r.m_alive) - return; - replace_var(row_id, x, rational::zero()); - if (A != 0) r.m_vars.push_back(var(y, coeff*A)); - if (B != 0) r.m_vars.push_back(var(z, coeff*B)); - r.m_value += coeff*A*m_var2value[y]; - r.m_value += coeff*B*m_var2value[z]; - std::sort(r.m_vars.begin(), r.m_vars.end(), var::compare()); - if (A != 0) m_var2row_ids[y].push_back(row_id); - if (B != 0) m_var2row_ids[z].push_back(row_id); - SASSERT(invariant(row_id, r)); - } - - // 3x + t = 0 & 7 | (c*x + s) & ax <= u - // 3 | -t & 21 | (-ct + 3s) & a-t <= 3u - - model_based_opt::def model_based_opt::solve_for(unsigned row_id1, unsigned x, bool compute_def) { - TRACE("opt", tout << "v" << x << " := " << eval(x) << "\n" << m_rows[row_id1] << "\n"; - display(tout)); - rational a = get_coefficient(row_id1, x), b; - row& r1 = m_rows[row_id1]; - ineq_type ty = r1.m_type; - SASSERT(!a.is_zero()); - SASSERT(r1.m_alive); - if (a.is_neg()) { - a.neg(); - r1.neg(); - } - SASSERT(a.is_pos()); - if (ty == t_lt) { - SASSERT(compute_def); - r1.m_coeff -= r1.m_value; - r1.m_type = t_le; - r1.m_value = 0; - } - - if (m_var2is_int[x] && !a.is_one()) { - r1.m_coeff -= r1.m_value; - r1.m_value = 0; - vector coeffs; - mk_coeffs_without(coeffs, r1.m_vars, x); - rational c = mod(-eval(coeffs), a); - add_divides(coeffs, c, a); - } - unsigned_vector const& row_ids = m_var2row_ids[x]; - uint_set visited; - visited.insert(row_id1); - for (unsigned row_id2 : row_ids) { - if (visited.contains(row_id2)) - continue; - visited.insert(row_id2); - row& r = m_rows[row_id2]; - if (!r.m_alive) - continue; - b = get_coefficient(row_id2, x); - if (b.is_zero()) - continue; - row& dst = m_rows[row_id2]; - switch (dst.m_type) { - case t_eq: - case t_lt: - case t_le: - solve(row_id1, a, row_id2, x); - break; - case t_divides: - case t_mod: - case t_div: - // mod reduction already done. - UNREACHABLE(); - break; - } - } - def result; - if (compute_def) { - result = def(m_rows[row_id1], x); - m_var2value[x] = eval(result); - TRACE("opt1", tout << "updated eval " << x << " := " << eval(x) << "\n";); - } - retire_row(row_id1); - TRACE("opt", display(tout << "solved v" << x << "\n")); - return result; - } - - void model_based_opt::eliminate(unsigned v, def const& new_def) { - for (auto & d : m_result) - d.substitute(v, new_def); - } - - vector model_based_opt::project(unsigned num_vars, unsigned const* vars, bool compute_def) { - m_result.reset(); - for (unsigned i = 0; i < num_vars; ++i) { - m_result.push_back(project(vars[i], compute_def)); - eliminate(vars[i], m_result.back()); - TRACE("opt", display(tout << "After projecting: v" << vars[i] << "\n");); - } - return m_result; - } - -} - +/*++ +Copyright (c) 2016 Microsoft Corporation + +Module Name: + + model_based_opt.cpp + +Abstract: + + Model-based optimization and projection for linear real, integer arithmetic. + +Author: + + Nikolaj Bjorner (nbjorner) 2016-27-4 + +Revision History: + + +--*/ + +#include "math/simplex/model_based_opt.h" +#include "util/uint_set.h" +#include "util/z3_exception.h" + +std::ostream& operator<<(std::ostream& out, opt::ineq_type ie) { + switch (ie) { + case opt::t_eq: return out << " = "; + case opt::t_lt: return out << " < "; + case opt::t_le: return out << " <= "; + case opt::t_divides: return out << " divides "; + case opt::t_mod: return out << " mod "; + case opt::t_div: return out << " div "; + } + return out; +} + + +namespace opt { + + /** + * Convert a row ax + coeffs + coeff = value into a definition for x + * x = (value - coeffs - coeff)/a + * as backdrop we have existing assignments to x and other variables that + * satisfy the equality with value, and such that value satisfies + * the row constraint ( = , <= , < , mod) + */ + model_based_opt::def::def(row const& r, unsigned x) { + for (var const & v : r.m_vars) { + if (v.m_id != x) { + m_vars.push_back(v); + } + else { + m_div = -v.m_coeff; + } + } + m_coeff = r.m_coeff; + switch (r.m_type) { + case opt::t_lt: + m_coeff += m_div; + break; + case opt::t_le: + // for: ax >= t, then x := (t + a - 1) div a + if (m_div.is_pos()) { + m_coeff += m_div; + m_coeff -= rational::one(); + } + break; + default: + break; + } + normalize(); + SASSERT(m_div.is_pos()); + } + + model_based_opt::def model_based_opt::def::operator+(def const& other) const { + def result; + vector const& vs1 = m_vars; + vector const& vs2 = other.m_vars; + vector & vs = result.m_vars; + rational c1(1), c2(1); + if (m_div != other.m_div) { + c1 = other.m_div; + c2 = m_div; + } + unsigned i = 0, j = 0; + while (i < vs1.size() || j < vs2.size()) { + unsigned v1 = UINT_MAX, v2 = UINT_MAX; + if (i < vs1.size()) v1 = vs1[i].m_id; + if (j < vs2.size()) v2 = vs2[j].m_id; + if (v1 == v2) { + vs.push_back(vs1[i]); + vs.back().m_coeff *= c1; + vs.back().m_coeff += c2 * vs2[j].m_coeff; + ++i; ++j; + if (vs.back().m_coeff.is_zero()) { + vs.pop_back(); + } + } + else if (v1 < v2) { + vs.push_back(vs1[i]); + vs.back().m_coeff *= c1; + ++i; + } + else { + vs.push_back(vs2[j]); + vs.back().m_coeff *= c2; + ++j; + } + } + result.m_div = c1*m_div; + result.m_coeff = (m_coeff*c1) + (other.m_coeff*c2); + result.normalize(); + return result; + } + + /** + a1*x1 + a2*x2 + a3*x3 + coeff1 / c1 + x2 |-> b1*x1 + b4*x4 + ceoff2 / c2 + ------------------------------------------------------------------------ + (a1*x1 + a2*((b1*x1 + b4*x4 + coeff2) / c2) + a3*x3 + coeff1) / c1 + ------------------------------------------------------------------------ + (c2*a1*x1 + a2*b1*x1 + a2*b4*x4 + c2*a3*x3 + c2*coeff1 + coeff2) / c1*c2 + */ + void model_based_opt::def::substitute(unsigned v, def const& other) { + vector const& vs1 = m_vars; + rational coeff(0); + for (auto const& [id, c] : vs1) { + if (id == v) { + coeff = c; + break; + } + } + if (coeff == 0) + return; + + rational c1 = m_div; + rational c2 = other.m_div; + + vector const& vs2 = other.m_vars; + vector vs; + unsigned i = 0, j = 0; + while (i < vs1.size() || j < vs2.size()) { + unsigned v1 = UINT_MAX, v2 = UINT_MAX; + if (i < vs1.size()) v1 = vs1[i].m_id; + if (j < vs2.size()) v2 = vs2[j].m_id; + if (v1 == v) + ++i; + else if (v1 == v2) { + vs.push_back(vs1[i]); + vs.back().m_coeff *= c2; + vs.back().m_coeff += coeff * vs2[j].m_coeff; + ++i; ++j; + if (vs.back().m_coeff.is_zero()) + vs.pop_back(); + } + else if (v1 < v2) { + vs.push_back(vs1[i]); + vs.back().m_coeff *= c2; + ++i; + } + else { + vs.push_back(vs2[j]); + vs.back().m_coeff *= coeff; + ++j; + } + } + m_div *= other.m_div; + m_coeff *= c2; + m_coeff += coeff*other.m_coeff; + m_vars.reset(); + m_vars.append(vs); + normalize(); + } + + model_based_opt::def model_based_opt::def::operator/(rational const& r) const { + def result(*this); + result.m_div *= r; + result.normalize(); + return result; + } + + model_based_opt::def model_based_opt::def::operator*(rational const& n) const { + def result(*this); + for (var& v : result.m_vars) { + v.m_coeff *= n; + } + result.m_coeff *= n; + result.normalize(); + return result; + } + + model_based_opt::def model_based_opt::def::operator+(rational const& n) const { + def result(*this); + result.m_coeff += n * result.m_div; + result.normalize(); + return result; + } + + void model_based_opt::def::normalize() { + if (!m_div.is_int()) { + rational den = denominator(m_div); + SASSERT(den > 1); + for (var& v : m_vars) + v.m_coeff *= den; + m_coeff *= den; + m_div *= den; + + } + if (m_div.is_neg()) { + for (var& v : m_vars) + v.m_coeff.neg(); + m_coeff.neg(); + m_div.neg(); + } + if (m_div.is_one()) + return; + rational g(m_div); + if (!m_coeff.is_int()) + return; + g = gcd(g, m_coeff); + for (var const& v : m_vars) { + if (!v.m_coeff.is_int()) + return; + g = gcd(g, abs(v.m_coeff)); + if (g.is_one()) + break; + } + if (!g.is_one()) { + for (var& v : m_vars) + v.m_coeff /= g; + m_coeff /= g; + m_div /= g; + } + } + + model_based_opt::model_based_opt() { + m_rows.push_back(row()); + } + + bool model_based_opt::invariant() { + for (unsigned i = 0; i < m_rows.size(); ++i) { + if (!invariant(i, m_rows[i])) { + return false; + } + } + return true; + } + +#define PASSERT(_e_) { CTRACE("qe", !(_e_), display(tout, r); display(tout);); SASSERT(_e_); } + + bool model_based_opt::invariant(unsigned index, row const& r) { + vector const& vars = r.m_vars; + for (unsigned i = 0; i < vars.size(); ++i) { + // variables in each row are sorted and have non-zero coefficients + PASSERT(i + 1 == vars.size() || vars[i].m_id < vars[i+1].m_id); + PASSERT(!vars[i].m_coeff.is_zero()); + PASSERT(index == 0 || m_var2row_ids[vars[i].m_id].contains(index)); + } + + PASSERT(r.m_value == eval(r)); + PASSERT(r.m_type != t_eq || r.m_value.is_zero()); + // values satisfy constraints + PASSERT(index == 0 || r.m_type != t_lt || r.m_value.is_neg()); + PASSERT(index == 0 || r.m_type != t_le || !r.m_value.is_pos()); + PASSERT(index == 0 || r.m_type != t_divides || (mod(r.m_value, r.m_mod).is_zero())); + PASSERT(index == 0 || r.m_type != t_mod || r.m_id < m_var2value.size()); + PASSERT(index == 0 || r.m_type != t_div || r.m_id < m_var2value.size()); + return true; + } + + // a1*x + obj + // a2*x + t2 <= 0 + // a3*x + t3 <= 0 + // a4*x + t4 <= 0 + // a1 > 0, a2 > 0, a3 > 0, a4 < 0 + // x <= -t2/a2 + // x <= -t2/a3 + // determine lub among these. + // then resolve lub with others + // e.g., -t2/a2 <= -t3/a3, then + // replace inequality a3*x + t3 <= 0 by -t2/a2 + t3/a3 <= 0 + // mark a4 as invalid. + // + + // a1 < 0, a2 < 0, a3 < 0, a4 > 0 + // x >= t2/a2 + // x >= t3/a3 + // determine glb among these + // the resolve glb with others. + // e.g. t2/a2 >= t3/a3 + // then replace a3*x + t3 by t3/a3 - t2/a2 <= 0 + // + inf_eps model_based_opt::maximize() { + SASSERT(invariant()); + unsigned_vector bound_trail, bound_vars; + TRACE("opt", display(tout << "tableau\n");); + while (!objective().m_vars.empty()) { + var v = objective().m_vars.back(); + unsigned x = v.m_id; + rational const& coeff = v.m_coeff; + unsigned bound_row_index; + rational bound_coeff; + if (find_bound(x, bound_row_index, bound_coeff, coeff.is_pos())) { + SASSERT(!bound_coeff.is_zero()); + TRACE("opt", display(tout << "update: " << v << " ", objective()); + for (unsigned above : m_above) { + display(tout << "resolve: ", m_rows[above]); + }); + for (unsigned above : m_above) { + resolve(bound_row_index, bound_coeff, above, x); + } + for (unsigned below : m_below) { + resolve(bound_row_index, bound_coeff, below, x); + } + // coeff*x + objective <= ub + // a2*x + t2 <= 0 + // => coeff*x <= -t2*coeff/a2 + // objective + t2*coeff/a2 <= ub + + mul_add(false, m_objective_id, - coeff/bound_coeff, bound_row_index); + retire_row(bound_row_index); + bound_trail.push_back(bound_row_index); + bound_vars.push_back(x); + } + else { + TRACE("opt", display(tout << "unbound: " << v << " ", objective());); + update_values(bound_vars, bound_trail); + return inf_eps::infinity(); + } + } + + // + // update the evaluation of variables to satisfy the bound. + // + + update_values(bound_vars, bound_trail); + + rational value = objective().m_value; + if (objective().m_type == t_lt) { + return inf_eps(inf_rational(value, rational(-1))); + } + else { + return inf_eps(inf_rational(value)); + } + } + + + void model_based_opt::update_value(unsigned x, rational const& val) { + rational old_val = m_var2value[x]; + m_var2value[x] = val; + SASSERT(val.is_int() || !is_int(x)); + unsigned_vector const& row_ids = m_var2row_ids[x]; + for (unsigned row_id : row_ids) { + rational coeff = get_coefficient(row_id, x); + if (coeff.is_zero()) { + continue; + } + row & r = m_rows[row_id]; + rational delta = coeff * (val - old_val); + r.m_value += delta; + SASSERT(invariant(row_id, r)); + } + } + + + void model_based_opt::update_values(unsigned_vector const& bound_vars, unsigned_vector const& bound_trail) { + for (unsigned i = bound_trail.size(); i-- > 0; ) { + unsigned x = bound_vars[i]; + row& r = m_rows[bound_trail[i]]; + rational val = r.m_coeff; + rational old_x_val = m_var2value[x]; + rational new_x_val; + rational x_coeff, eps(0); + vector const& vars = r.m_vars; + for (var const& v : vars) { + if (x == v.m_id) { + x_coeff = v.m_coeff; + } + else { + val += m_var2value[v.m_id]*v.m_coeff; + } + } + SASSERT(!x_coeff.is_zero()); + new_x_val = -val/x_coeff; + + if (r.m_type == t_lt) { + eps = abs(old_x_val - new_x_val)/rational(2); + eps = std::min(rational::one(), eps); + SASSERT(!eps.is_zero()); + + // + // ax + t < 0 + // <=> x < -t/a + // <=> x := -t/a - epsilon + // + if (x_coeff.is_pos()) { + new_x_val -= eps; + } + // + // -ax + t < 0 + // <=> -ax < -t + // <=> -x < -t/a + // <=> x > t/a + // <=> x := t/a + epsilon + // + else { + new_x_val += eps; + } + } + TRACE("opt", display(tout << "v" << x + << " coeff_x: " << x_coeff + << " old_x_val: " << old_x_val + << " new_x_val: " << new_x_val + << " eps: " << eps << " ", r); ); + m_var2value[x] = new_x_val; + + r.m_value = eval(r); + SASSERT(invariant(bound_trail[i], r)); + } + + // update and check bounds for all other affected rows. + for (unsigned i = bound_trail.size(); i-- > 0; ) { + unsigned x = bound_vars[i]; + unsigned_vector const& row_ids = m_var2row_ids[x]; + for (unsigned row_id : row_ids) { + row & r = m_rows[row_id]; + r.m_value = eval(r); + SASSERT(invariant(row_id, r)); + } + } + SASSERT(invariant()); + } + + bool model_based_opt::find_bound(unsigned x, unsigned& bound_row_index, rational& bound_coeff, bool is_pos) { + bound_row_index = UINT_MAX; + rational lub_val; + rational const& x_val = m_var2value[x]; + unsigned_vector const& row_ids = m_var2row_ids[x]; + uint_set visited; + m_above.reset(); + m_below.reset(); + for (unsigned row_id : row_ids) { + SASSERT(row_id != m_objective_id); + if (visited.contains(row_id)) + continue; + visited.insert(row_id); + row& r = m_rows[row_id]; + if (!r.m_alive) + continue; + rational a = get_coefficient(row_id, x); + if (a.is_zero()) { + // skip + } + else if (a.is_pos() == is_pos || r.m_type == t_eq) { + rational value = x_val - (r.m_value/a); + if (bound_row_index == UINT_MAX) { + lub_val = value; + bound_row_index = row_id; + bound_coeff = a; + } + else if ((value == lub_val && r.m_type == opt::t_lt) || + (is_pos && value < lub_val) || + + (!is_pos && value > lub_val)) { + m_above.push_back(bound_row_index); + lub_val = value; + bound_row_index = row_id; + bound_coeff = a; + } + else + m_above.push_back(row_id); + } + else + m_below.push_back(row_id); + } + return bound_row_index != UINT_MAX; + } + + void model_based_opt::retire_row(unsigned row_id) { + SASSERT(!m_retired_rows.contains(row_id)); + m_rows[row_id].m_alive = false; + m_retired_rows.push_back(row_id); + } + + rational model_based_opt::eval(unsigned x) const { + return m_var2value[x]; + } + + rational model_based_opt::eval(def const& d) const { + vector const& vars = d.m_vars; + rational val = d.m_coeff; + for (var const& v : vars) { + val += v.m_coeff * eval(v.m_id); + } + val /= d.m_div; + return val; + } + + rational model_based_opt::eval(row const& r) const { + vector const& vars = r.m_vars; + rational val = r.m_coeff; + for (var const& v : vars) { + val += v.m_coeff * eval(v.m_id); + } + return val; + } + + rational model_based_opt::eval(vector const& coeffs) const { + rational val(0); + for (var const& v : coeffs) + val += v.m_coeff * eval(v.m_id); + return val; + } + + rational model_based_opt::get_coefficient(unsigned row_id, unsigned var_id) const { + return m_rows[row_id].get_coefficient(var_id); + } + + rational model_based_opt::row::get_coefficient(unsigned var_id) const { + if (m_vars.empty()) + return rational::zero(); + unsigned lo = 0, hi = m_vars.size(); + while (lo < hi) { + unsigned mid = lo + (hi - lo)/2; + SASSERT(mid < hi); + unsigned id = m_vars[mid].m_id; + if (id == var_id) { + lo = mid; + break; + } + if (id < var_id) + lo = mid + 1; + else + hi = mid; + } + if (lo == m_vars.size()) + return rational::zero(); + unsigned id = m_vars[lo].m_id; + if (id == var_id) + return m_vars[lo].m_coeff; + else + return rational::zero(); + } + + model_based_opt::row& model_based_opt::row::normalize() { +#if 0 + if (m_type == t_divides || m_type == t_mod || m_type == t_div) + return *this; + rational D(denominator(abs(m_coeff))); + if (D == 0) + D = 1; + for (auto const& [id, coeff] : m_vars) + if (coeff != 0) + D = lcm(D, denominator(abs(coeff))); + if (D == 1) + return *this; + SASSERT(D > 0); + for (auto & [id, coeff] : m_vars) + coeff *= D; + m_coeff *= D; +#endif + return *this; + } + + // + // Let + // row1: t1 + a1*x <= 0 + // row2: t2 + a2*x <= 0 + // + // assume a1, a2 have the same signs: + // (t2 + a2*x) <= (t1 + a1*x)*a2/a1 + // <=> t2*a1/a2 - t1 <= 0 + // <=> t2 - t1*a2/a1 <= 0 + // + // assume a1 > 0, -a2 < 0: + // t1 + a1*x <= 0, t2 - a2*x <= 0 + // t2/a2 <= -t1/a1 + // t2 + t1*a2/a1 <= 0 + // assume -a1 < 0, a2 > 0: + // t1 - a1*x <= 0, t2 + a2*x <= 0 + // t1/a1 <= -t2/a2 + // t2 + t1*a2/a1 <= 0 + // + // the resolvent is the same in all cases (simpler proof should exist) + // + // assume a1 < 0, -a1 = a2: + // t1 <= a2*div(t2, a2) + // + + void model_based_opt::resolve(unsigned row_src, rational const& a1, unsigned row_dst, unsigned x) { + + SASSERT(a1 == get_coefficient(row_src, x)); + SASSERT(!a1.is_zero()); + SASSERT(row_src != row_dst); + + if (m_rows[row_dst].m_alive) { + rational a2 = get_coefficient(row_dst, x); + if (is_int(x)) { + TRACE("opt", + tout << x << ": " << a1 << " " << a2 << ": "; + display(tout, m_rows[row_dst]); + display(tout, m_rows[row_src]);); + if (a1.is_pos() != a2.is_pos() || m_rows[row_src].m_type == opt::t_eq) { + mul_add(x, a1, row_src, a2, row_dst); + } + else { + mul(row_dst, abs(a1)); + mul_add(false, row_dst, -abs(a2), row_src); + } + TRACE("opt", display(tout << "result ", m_rows[row_dst]);); + normalize(row_dst); + } + else { + mul_add(row_dst != m_objective_id && a1.is_pos() == a2.is_pos(), row_dst, -a2/a1, row_src); + } + } + } + + /** + * a1 > 0 + * a1*x + r1 = value + * a2*x + r2 <= 0 + * ------------------ + * a1*r2 - a2*r1 <= value + */ + void model_based_opt::solve(unsigned row_src, rational const& a1, unsigned row_dst, unsigned x) { + SASSERT(a1 == get_coefficient(row_src, x)); + SASSERT(a1.is_pos()); + SASSERT(row_src != row_dst); + if (!m_rows[row_dst].m_alive) return; + rational a2 = get_coefficient(row_dst, x); + mul(row_dst, a1); + mul_add(false, row_dst, -a2, row_src); + normalize(row_dst); + SASSERT(get_coefficient(row_dst, x).is_zero()); + } + + // resolution for integer rows. + void model_based_opt::mul_add( + unsigned x, rational src_c, unsigned row_src, rational dst_c, unsigned row_dst) { + row& dst = m_rows[row_dst]; + row const& src = m_rows[row_src]; + SASSERT(is_int(x)); + SASSERT(t_le == dst.m_type && t_le == src.m_type); + SASSERT(src_c.is_int()); + SASSERT(dst_c.is_int()); + SASSERT(m_var2value[x].is_int()); + + rational abs_src_c = abs(src_c); + rational abs_dst_c = abs(dst_c); + rational x_val = m_var2value[x]; + rational slack = (abs_src_c - rational::one()) * (abs_dst_c - rational::one()); + rational dst_val = dst.m_value - x_val*dst_c; + rational src_val = src.m_value - x_val*src_c; + rational distance = abs_src_c * dst_val + abs_dst_c * src_val + slack; + bool use_case1 = distance.is_nonpos() || abs_src_c.is_one() || abs_dst_c.is_one(); + bool use_case2 = false && abs_src_c == abs_dst_c && src_c.is_pos() != dst_c.is_pos() && !abs_src_c.is_one() && t_le == dst.m_type && t_le == src.m_type; + bool use_case3 = false && src_c.is_pos() != dst_c.is_pos() && t_le == dst.m_type && t_le == src.m_type; + + + if (use_case1) { + TRACE("opt", tout << "slack: " << slack << " " << src_c << " " << dst_val << " " << dst_c << " " << src_val << "\n";); + // dst <- abs_src_c*dst + abs_dst_c*src + slack + mul(row_dst, abs_src_c); + add(row_dst, slack); + mul_add(false, row_dst, abs_dst_c, row_src); + return; + } + + if (use_case2 || use_case3) { + // case2: + // x*src_c + s <= 0 + // -x*src_c + t <= 0 + // + // -src_c*div(-s, src_c) + t <= 0 + // + // Example: + // t <= 100*x <= s + // Then t <= 100*div(s, 100) + // + // case3: + // x*src_c + s <= 0 + // -x*dst_c + t <= 0 + // t <= x*dst_c, x*src_c <= -s -> + // t <= dst_c*div(-s, src_c) -> + // -dst_c*div(-s,src_c) + t <= 0 + // + + bool swapped = false; + if (src_c < 0) { + std::swap(row_src, row_dst); + std::swap(src_c, dst_c); + std::swap(abs_src_c, abs_dst_c); + swapped = true; + } + vector src_coeffs, dst_coeffs; + rational src_coeff = m_rows[row_src].m_coeff; + rational dst_coeff = m_rows[row_dst].m_coeff; + for (auto const& v : m_rows[row_src].m_vars) + if (v.m_id != x) + src_coeffs.push_back(var(v.m_id, -v.m_coeff)); + for (auto const& v : m_rows[row_dst].m_vars) + if (v.m_id != x) + dst_coeffs.push_back(v); + unsigned v = UINT_MAX; + if (src_coeffs.empty()) + dst_coeff -= abs_dst_c*div(-src_coeff, abs_src_c); + else + v = add_div(src_coeffs, -src_coeff, abs_src_c); + if (v != UINT_MAX) dst_coeffs.push_back(var(v, -abs_dst_c)); + if (swapped) + std::swap(row_src, row_dst); + retire_row(row_dst); + add_constraint(dst_coeffs, dst_coeff, t_le); + return; + } + + // + // create finite disjunction for |b|. + // exists x, z in [0 .. |b|-2] . b*x + s + z = 0 && ax + t <= 0 && bx + s <= 0 + // <=> + // exists x, z in [0 .. |b|-2] . b*x = -z - s && ax + t <= 0 && bx + s <= 0 + // <=> + // exists x, z in [0 .. |b|-2] . b*x = -z - s && a|b|x + |b|t <= 0 && bx + s <= 0 + // <=> + // exists x, z in [0 .. |b|-2] . b*x = -z - s && a|b|x + |b|t <= 0 && -z - s + s <= 0 + // <=> + // exists x, z in [0 .. |b|-2] . b*x = -z - s && a|b|x + |b|t <= 0 && -z <= 0 + // <=> + // exists x, z in [0 .. |b|-2] . b*x = -z - s && a|b|x + |b|t <= 0 + // <=> + // exists x, z in [0 .. |b|-2] . b*x = -z - s && a*n_sign(b)(s + z) + |b|t <= 0 + // <=> + // exists z in [0 .. |b|-2] . |b| | (z + s) && a*n_sign(b)(s + z) + |b|t <= 0 + // + + TRACE("qe", tout << "finite disjunction " << distance << " " << src_c << " " << dst_c << "\n";); + vector coeffs; + if (abs_dst_c <= abs_src_c) { + rational z = mod(dst_val, abs_dst_c); + if (!z.is_zero()) z = abs_dst_c - z; + mk_coeffs_without(coeffs, dst.m_vars, x); + add_divides(coeffs, dst.m_coeff + z, abs_dst_c); + add(row_dst, z); + mul(row_dst, src_c * n_sign(dst_c)); + mul_add(false, row_dst, abs_dst_c, row_src); + } + else { + // z := b - (s + bx) mod b + // := b - s mod b + // b | s + z <=> b | s + b - s mod b <=> b | s - s mod b + rational z = mod(src_val, abs_src_c); + if (!z.is_zero()) z = abs_src_c - z; + mk_coeffs_without(coeffs, src.m_vars, x); + add_divides(coeffs, src.m_coeff + z, abs_src_c); + mul(row_dst, abs_src_c); + add(row_dst, z * dst_c * n_sign(src_c)); + mul_add(false, row_dst, dst_c * n_sign(src_c), row_src); + } + } + + void model_based_opt::mk_coeffs_without(vector& dst, vector const& src, unsigned x) { + for (var const & v : src) { + if (v.m_id != x) dst.push_back(v); + } + } + + rational model_based_opt::n_sign(rational const& b) const { + return rational(b.is_pos()?-1:1); + } + + void model_based_opt::mul(unsigned dst, rational const& c) { + if (c.is_one()) + return; + row& r = m_rows[dst]; + for (auto & v : r.m_vars) + v.m_coeff *= c; + r.m_mod *= c; + r.m_coeff *= c; + if (r.m_type != t_div && r.m_type != t_mod) + r.m_value *= c; + } + + void model_based_opt::add(unsigned dst, rational const& c) { + row& r = m_rows[dst]; + r.m_coeff += c; + r.m_value += c; + } + + void model_based_opt::sub(unsigned dst, rational const& c) { + row& r = m_rows[dst]; + r.m_coeff -= c; + r.m_value -= c; + } + + void model_based_opt::normalize(unsigned row_id) { + row& r = m_rows[row_id]; + if (!r.m_alive) + return; + if (r.m_vars.empty()) { + retire_row(row_id); + return; + } + if (r.m_type == t_divides) + return; + if (r.m_type == t_mod) + return; + if (r.m_type == t_div) + return; + rational g(abs(r.m_vars[0].m_coeff)); + bool all_int = g.is_int(); + for (unsigned i = 1; all_int && !g.is_one() && i < r.m_vars.size(); ++i) { + rational const& coeff = r.m_vars[i].m_coeff; + if (coeff.is_int()) { + g = gcd(g, abs(coeff)); + } + else { + all_int = false; + } + } + if (all_int && !r.m_coeff.is_zero()) { + if (r.m_coeff.is_int()) { + g = gcd(g, abs(r.m_coeff)); + } + else { + all_int = false; + } + } + if (all_int && !g.is_one()) { + SASSERT(!g.is_zero()); + mul(row_id, rational::one()/g); + } + } + + // + // set row1 <- row1 + c*row2 + // + void model_based_opt::mul_add(bool same_sign, unsigned row_id1, rational const& c, unsigned row_id2) { + if (c.is_zero()) + return; + + + m_new_vars.reset(); + row& r1 = m_rows[row_id1]; + row const& r2 = m_rows[row_id2]; + unsigned i = 0, j = 0; + while (i < r1.m_vars.size() || j < r2.m_vars.size()) { + if (j == r2.m_vars.size()) { + m_new_vars.append(r1.m_vars.size() - i, r1.m_vars.data() + i); + break; + } + if (i == r1.m_vars.size()) { + for (; j < r2.m_vars.size(); ++j) { + m_new_vars.push_back(r2.m_vars[j]); + m_new_vars.back().m_coeff *= c; + if (row_id1 != m_objective_id) + m_var2row_ids[r2.m_vars[j].m_id].push_back(row_id1); + } + break; + } + + unsigned v1 = r1.m_vars[i].m_id; + unsigned v2 = r2.m_vars[j].m_id; + if (v1 == v2) { + m_new_vars.push_back(r1.m_vars[i]); + m_new_vars.back().m_coeff += c*r2.m_vars[j].m_coeff; + ++i; + ++j; + if (m_new_vars.back().m_coeff.is_zero()) + m_new_vars.pop_back(); + } + else if (v1 < v2) { + m_new_vars.push_back(r1.m_vars[i]); + ++i; + } + else { + m_new_vars.push_back(r2.m_vars[j]); + m_new_vars.back().m_coeff *= c; + if (row_id1 != m_objective_id) + m_var2row_ids[r2.m_vars[j].m_id].push_back(row_id1); + ++j; + } + } + r1.m_coeff += c*r2.m_coeff; + r1.m_vars.swap(m_new_vars); + r1.m_value += c*r2.m_value; + + if (!same_sign && r2.m_type == t_lt) + r1.m_type = t_lt; + else if (same_sign && r1.m_type == t_lt && r2.m_type == t_lt) + r1.m_type = t_le; + SASSERT(invariant(row_id1, r1)); + } + + void model_based_opt::display(std::ostream& out) const { + for (auto const& r : m_rows) + display(out, r); + for (unsigned i = 0; i < m_var2row_ids.size(); ++i) { + unsigned_vector const& rows = m_var2row_ids[i]; + out << i << ": "; + for (auto const& r : rows) + out << r << " "; + out << "\n"; + } + } + + void model_based_opt::display(std::ostream& out, vector const& vars, rational const& coeff) { + unsigned i = 0; + for (var const& v : vars) { + if (i > 0 && v.m_coeff.is_pos()) + out << "+ "; + ++i; + if (v.m_coeff.is_one()) + out << "v" << v.m_id << " "; + else + out << v.m_coeff << "*v" << v.m_id << " "; + } + if (coeff.is_pos()) + out << " + " << coeff << " "; + else if (coeff.is_neg()) + out << coeff << " "; + } + + std::ostream& model_based_opt::display(std::ostream& out, row const& r) { + out << (r.m_alive?"a":"d") << " "; + display(out, r.m_vars, r.m_coeff); + switch (r.m_type) { + case opt::t_divides: + out << r.m_type << " " << r.m_mod << " = 0; value: " << r.m_value << "\n"; + break; + case opt::t_mod: + out << r.m_type << " " << r.m_mod << " = v" << r.m_id << " ; mod: " << mod(r.m_value, r.m_mod) << "\n"; + break; + case opt::t_div: + out << r.m_type << " " << r.m_mod << " = v" << r.m_id << " ; div: " << div(r.m_value, r.m_mod) << "\n"; + break; + default: + out << r.m_type << " 0; value: " << r.m_value << "\n"; + break; + } + return out; + } + + std::ostream& model_based_opt::display(std::ostream& out, def const& r) { + display(out, r.m_vars, r.m_coeff); + if (!r.m_div.is_one()) { + out << " / " << r.m_div; + } + return out; + } + + unsigned model_based_opt::add_var(rational const& value, bool is_int) { + unsigned v = m_var2value.size(); + m_var2value.push_back(value); + m_var2is_int.push_back(is_int); + SASSERT(value.is_int() || !is_int); + m_var2row_ids.push_back(unsigned_vector()); + return v; + } + + rational model_based_opt::get_value(unsigned var) { + return m_var2value[var]; + } + + void model_based_opt::set_row(unsigned row_id, vector const& coeffs, rational const& c, rational const& m, ineq_type rel) { + row& r = m_rows[row_id]; + rational val(c); + SASSERT(r.m_vars.empty()); + r.m_vars.append(coeffs.size(), coeffs.data()); + bool is_int_row = !coeffs.empty(); + std::sort(r.m_vars.begin(), r.m_vars.end(), var::compare()); + for (auto const& c : coeffs) { + val += m_var2value[c.m_id] * c.m_coeff; + SASSERT(!is_int(c.m_id) || c.m_coeff.is_int()); + is_int_row &= is_int(c.m_id); + } + r.m_alive = true; + r.m_coeff = c; + r.m_value = val; + r.m_type = rel; + r.m_mod = m; + if (is_int_row && rel == t_lt) { + r.m_type = t_le; + r.m_coeff += rational::one(); + r.m_value += rational::one(); + } + } + + unsigned model_based_opt::new_row() { + unsigned row_id = 0; + if (m_retired_rows.empty()) { + row_id = m_rows.size(); + m_rows.push_back(row()); + } + else { + row_id = m_retired_rows.back(); + m_retired_rows.pop_back(); + SASSERT(!m_rows[row_id].m_alive); + m_rows[row_id].reset(); + m_rows[row_id].m_alive = true; + } + return row_id; + } + + unsigned model_based_opt::copy_row(unsigned src, unsigned excl) { + unsigned dst = new_row(); + row const& r = m_rows[src]; + set_row(dst, r.m_vars, r.m_coeff, r.m_mod, r.m_type); + for (auto const& v : r.m_vars) { + if (v.m_id != excl) + m_var2row_ids[v.m_id].push_back(dst); + } + SASSERT(invariant(dst, m_rows[dst])); + return dst; + } + + // -x + lo <= 0 + void model_based_opt::add_lower_bound(unsigned x, rational const& lo) { + vector coeffs; + coeffs.push_back(var(x, rational::minus_one())); + add_constraint(coeffs, lo, t_le); + } + + // x - hi <= 0 + void model_based_opt::add_upper_bound(unsigned x, rational const& hi) { + vector coeffs; + coeffs.push_back(var(x, rational::one())); + add_constraint(coeffs, -hi, t_le); + } + + void model_based_opt::add_constraint(vector const& coeffs, rational const& c, ineq_type rel) { + add_constraint(coeffs, c, rational::zero(), rel, 0); + } + + void model_based_opt::add_divides(vector const& coeffs, rational const& c, rational const& m) { + rational g(c); + for (auto const& [v, coeff] : coeffs) + g = gcd(coeff, g); + if ((g/m).is_int()) + return; + add_constraint(coeffs, c, m, t_divides, 0); + } + + unsigned model_based_opt::add_mod(vector const& coeffs, rational const& c, rational const& m) { + rational value = c; + for (auto const& var : coeffs) + value += var.m_coeff * m_var2value[var.m_id]; + unsigned v = add_var(mod(value, m), true); + add_constraint(coeffs, c, m, t_mod, v); + return v; + } + + unsigned model_based_opt::add_div(vector const& coeffs, rational const& c, rational const& m) { + rational value = c; + for (auto const& var : coeffs) + value += var.m_coeff * m_var2value[var.m_id]; + unsigned v = add_var(div(value, m), true); + add_constraint(coeffs, c, m, t_div, v); + return v; + } + + unsigned model_based_opt::add_constraint(vector const& coeffs, rational const& c, rational const& m, ineq_type rel, unsigned id) { + auto const& r = m_rows.back(); + if (r.m_vars == coeffs && r.m_coeff == c && r.m_mod == m && r.m_type == rel && r.m_id == id && r.m_alive) + return m_rows.size() - 1; + unsigned row_id = new_row(); + set_row(row_id, coeffs, c, m, rel); + m_rows[row_id].m_id = id; + for (var const& coeff : coeffs) + m_var2row_ids[coeff.m_id].push_back(row_id); + SASSERT(invariant(row_id, m_rows[row_id])); + normalize(row_id); + return row_id; + } + + void model_based_opt::set_objective(vector const& coeffs, rational const& c) { + set_row(m_objective_id, coeffs, c, rational::zero(), t_le); + } + + void model_based_opt::get_live_rows(vector& rows) { + for (row & r : m_rows) + if (r.m_alive) + rows.push_back(r.normalize()); + } + + // + // pick glb and lub representative. + // The representative is picked such that it + // represents the fewest inequalities. + // The constraints that enforce a glb or lub are not forced. + // The constraints that separate the glb from ub or the lub from lb + // are not forced. + // In other words, suppose there are + // . N inequalities of the form t <= x + // . M inequalities of the form s >= x + // . t0 is glb among N under valuation. + // . s0 is lub among M under valuation. + // If N < M + // create the inequalities: + // t <= t0 for each t other than t0 (N-1 inequalities). + // t0 <= s for each s (M inequalities). + // If N >= M the construction is symmetric. + // + model_based_opt::def model_based_opt::project(unsigned x, bool compute_def) { + unsigned_vector& lub_rows = m_lub; + unsigned_vector& glb_rows = m_glb; + unsigned_vector& divide_rows = m_divides; + unsigned_vector& mod_rows = m_mod; + unsigned_vector& div_rows = m_div; + unsigned lub_index = UINT_MAX, glb_index = UINT_MAX; + bool lub_strict = false, glb_strict = false; + rational lub_val, glb_val; + rational const& x_val = m_var2value[x]; + unsigned_vector const& row_ids = m_var2row_ids[x]; + uint_set visited; + lub_rows.reset(); + glb_rows.reset(); + divide_rows.reset(); + mod_rows.reset(); + div_rows.reset(); + bool lub_is_unit = true, glb_is_unit = true; + unsigned eq_row = UINT_MAX; + // select the lub and glb. + for (unsigned row_id : row_ids) { + if (visited.contains(row_id)) + continue; + visited.insert(row_id); + row& r = m_rows[row_id]; + if (!r.m_alive) + continue; + rational a = get_coefficient(row_id, x); + if (a.is_zero()) + continue; + if (r.m_type == t_eq) + eq_row = row_id; + else if (r.m_type == t_mod) + mod_rows.push_back(row_id); + else if (r.m_type == t_div) + div_rows.push_back(row_id); + else if (r.m_type == t_divides) + divide_rows.push_back(row_id); + else if (a.is_pos()) { + rational lub_value = x_val - (r.m_value/a); + if (lub_rows.empty() || + lub_value < lub_val || + (lub_value == lub_val && r.m_type == t_lt && !lub_strict)) { + lub_val = lub_value; + lub_index = row_id; + lub_strict = r.m_type == t_lt; + } + lub_rows.push_back(row_id); + lub_is_unit &= a.is_one(); + } + else { + SASSERT(a.is_neg()); + rational glb_value = x_val - (r.m_value/a); + if (glb_rows.empty() || + glb_value > glb_val || + (glb_value == glb_val && r.m_type == t_lt && !glb_strict)) { + glb_val = glb_value; + glb_index = row_id; + glb_strict = r.m_type == t_lt; + } + glb_rows.push_back(row_id); + glb_is_unit &= a.is_minus_one(); + } + } + + if (!divide_rows.empty()) + return solve_divides(x, divide_rows, compute_def); + + if (!div_rows.empty() || !mod_rows.empty()) + return solve_mod_div(x, mod_rows, div_rows, compute_def); + + if (eq_row != UINT_MAX) + return solve_for(eq_row, x, compute_def); + + def result; + unsigned lub_size = lub_rows.size(); + unsigned glb_size = glb_rows.size(); + unsigned row_index = (lub_size <= glb_size) ? lub_index : glb_index; + + // There are only upper or only lower bounds. + if (row_index == UINT_MAX) { + if (compute_def) { + if (lub_index != UINT_MAX) + result = solve_for(lub_index, x, true); + else if (glb_index != UINT_MAX) + result = solve_for(glb_index, x, true); + else + result = def() + m_var2value[x]; + SASSERT(eval(result) == eval(x)); + } + else { + for (unsigned row_id : lub_rows) retire_row(row_id); + for (unsigned row_id : glb_rows) retire_row(row_id); + } + return result; + } + + SASSERT(lub_index != UINT_MAX); + SASSERT(glb_index != UINT_MAX); + if (compute_def) { + if (lub_size <= glb_size) + result = def(m_rows[lub_index], x); + else + result = def(m_rows[glb_index], x); + } + + // The number of matching lower and upper bounds is small. + if ((lub_size <= 2 || glb_size <= 2) && + (lub_size <= 3 && glb_size <= 3) && + (!is_int(x) || lub_is_unit || glb_is_unit)) { + for (unsigned i = 0; i < lub_size; ++i) { + unsigned row_id1 = lub_rows[i]; + bool last = i + 1 == lub_size; + rational coeff = get_coefficient(row_id1, x); + for (unsigned row_id2 : glb_rows) { + if (last) { + resolve(row_id1, coeff, row_id2, x); + } + else { + unsigned row_id3 = copy_row(row_id2); + resolve(row_id1, coeff, row_id3, x); + } + } + } + for (unsigned row_id : lub_rows) + retire_row(row_id); + + return result; + } + + // General case. + rational coeff = get_coefficient(row_index, x); + + for (unsigned row_id : lub_rows) + if (row_id != row_index) + resolve(row_index, coeff, row_id, x); + + for (unsigned row_id : glb_rows) + if (row_id != row_index) + resolve(row_index, coeff, row_id, x); + retire_row(row_index); + return result; + } + + + // + // Given v = a*x + b mod K + // + // - remove v = a*x + b mod K + // + // case a = 1: + // - add w = b mod K + // - x |-> K*y + z, 0 <= z < K + // - if z.value + w.value < K: + // add z + w - v = 0 + // - if z.value + w.value >= K: + // add z + w - v - K = 0 + // + // case a != 1, gcd(a, K) = 1 + // - x |-> x*y + a^-1*z, 0 <= z < K + // - add w = b mod K + // if z.value + w.value < K + // add z + w - v = 0 + // if z.value + w.value >= K + // add z + w - v - K = 0 + // + // case a != 1, gcd(a,K) = g != 1 + // - x |-> x*y + a^-1*z, 0 <= z < K + // a*x + b mod K = v is now + // g*z + b mod K = v + // - add w = b mod K + // - 0 <= g*z.value + w.value < K*(g+1) + // - add g*z + w - v - k*K = 0 for suitable k from 0 .. g based on model + // + // + // + // Given v = a*x + b div K + // Replace x |-> K*y + z + // - w = b div K + // - v = ((a*K*y + a*z) + b) div K + // = a*y + (a*z + b) div K + // = a*y + b div K + (b mod K + a*z) div K + // = a*y + b div K + k + // where k := (b.value mod K + a*z.value) div K + // k is between 0 and a + // + // - k*K <= b mod K + a*z < (k+1)*K + // + // A better version using a^-1 + // - v = (a*K*y + a^-1*a*z + b) div K + // = a*y + ((K*A + g)*z + b) div K where we write a*a^-1 = K*A + g + // = a*y + A + (g*z + b) div K + // - k*K <= b Kod m + gz < (k+1)*K + // where k is between 0 and g + // when gcd(a, K) = 1, then there are only two cases. + // + model_based_opt::def model_based_opt::solve_mod_div(unsigned x, unsigned_vector const& _mod_rows, unsigned_vector const& _div_rows, bool compute_def) { + def result; + unsigned_vector div_rows(_div_rows), mod_rows(_mod_rows); + SASSERT(!div_rows.empty() || !mod_rows.empty()); + TRACE("opt", display(tout << "solve_div v" << x << "\n")); + + rational K(1); + for (unsigned ri : div_rows) + K = lcm(K, m_rows[ri].m_mod); + for (unsigned ri : mod_rows) + K = lcm(K, m_rows[ri].m_mod); + + rational x_value = m_var2value[x]; + rational z_value = mod(x_value, K); + rational y_value = div(x_value, K); + SASSERT(x_value == K * y_value + z_value); + SASSERT(0 <= z_value && z_value < K); + // add new variables + unsigned z = add_var(z_value, true); + unsigned y = add_var(y_value, true); + + uint_set visited; + unsigned j = 0; + for (unsigned ri : div_rows) { + if (visited.contains(ri)) + continue; + row& r = m_rows[ri]; + mul(ri, K / r.m_mod); + r.m_alive = false; + visited.insert(ri); + div_rows[j++] = ri; + } + div_rows.shrink(j); + + j = 0; + for (unsigned ri : mod_rows) { + if (visited.contains(ri)) + continue; + m_rows[ri].m_alive = false; + visited.insert(ri); + mod_rows[j++] = ri; + } + mod_rows.shrink(j); + + + // replace x by K*y + z in other rows. + for (unsigned ri : m_var2row_ids[x]) { + if (visited.contains(ri)) + continue; + replace_var(ri, x, K, y, rational::one(), z); + visited.insert(ri); + normalize(ri); + } + + // add bounds for z + add_lower_bound(z, rational::zero()); + add_upper_bound(z, K - 1); + + + // solve for x_value = K*y_value + z_value, 0 <= z_value < K. + + unsigned_vector vs; + + for (unsigned ri : div_rows) { + + rational a = get_coefficient(ri, x); + replace_var(ri, x, rational::zero()); + + // add w = b div m + vector coeffs = m_rows[ri].m_vars; + rational coeff = m_rows[ri].m_coeff; + unsigned w = UINT_MAX; + rational offset(0); + if (K == 1) + offset = coeff; + else if (coeffs.empty()) + offset = div(coeff, K); + else + w = add_div(coeffs, coeff, K); + + // + // w = b div K + // v = a*y + w + k + // k = (a*z_value + (b_value mod K)) div K + // k*K <= a*z + b mod K < (k+1)*K + // + /** + * It is based on the following claim (tested for select values of a, K) + * (define-const K Int 13) + * (declare-const b Int) + * (define-const a Int -11) + * (declare-const y Int) + * (declare-const z Int) + * (define-const w Int (div b K)) + * (define-const k1 Int (+ (* a z) (mod b K))) + * (define-const k Int (div k1 K)) + * (define-const x Int (+ (* K y) z)) + * (define-const u Int (+ (* a x) b)) + * (define-const v Int (+ (* a y) w k)) + * (assert (<= 0 z)) + * (assert (< z K)) + * (assert (<= (* K k) k1)) + * (assert (< k1 (* K (+ k 1)))) + * (assert (not (= (div u K) v))) + * (check-sat) + */ + unsigned v = m_rows[ri].m_id; + rational b_value = eval(coeffs) + coeff; + rational k = div(a * z_value + mod(b_value, K), K); + vector div_coeffs; + div_coeffs.push_back(var(v, rational::minus_one())); + div_coeffs.push_back(var(y, a)); + if (w != UINT_MAX) + div_coeffs.push_back(var(w, rational::one())); + else if (K == 1) + div_coeffs.append(coeffs); + add_constraint(div_coeffs, k + offset, t_eq); + + unsigned u = UINT_MAX; + offset = 0; + if (K == 1) + offset = 0; + else if (coeffs.empty()) + offset = mod(coeff, K); + else + u = add_mod(coeffs, coeff, K); + + + // add a*z + (b mod K) < (k + 1)*K + vector bound_coeffs; + bound_coeffs.push_back(var(z, a)); + if (u != UINT_MAX) + bound_coeffs.push_back(var(u, rational::one())); + add_constraint(bound_coeffs, 1 - K * (k + 1) + offset, t_le); + + // add k*K <= az + (b mod K) + for (auto& c : bound_coeffs) + c.m_coeff.neg(); + add_constraint(bound_coeffs, k * K - offset, t_le); + // allow to recycle row. + retire_row(ri); + vs.push_back(v); + } + + for (unsigned ri : mod_rows) { + rational a = get_coefficient(ri, x); + replace_var(ri, x, rational::zero()); + rational rMod = m_rows[ri].m_mod; + + // add w = b mod rMod + vector coeffs = m_rows[ri].m_vars; + rational coeff = m_rows[ri].m_coeff; + unsigned v = m_rows[ri].m_id; + rational v_value = m_var2value[v]; + + unsigned w = UINT_MAX; + rational offset(0); + if (coeffs.empty() || rMod == 1) + offset = mod(coeff, rMod); + else + w = add_mod(coeffs, coeff, rMod); + + + rational w_value = w == UINT_MAX ? offset : m_var2value[w]; + +#if 0 + // V := (a * z_value + w_value) div rMod + // V*rMod <= a*z + w < (V+1)*rMod + // v = a*z + w - V*rMod + SASSERT(a > 0); + SASSERT(z_value >= 0); + SASSERT(w_value >= 0); + SASSERT(a * z_value + w_value >= 0); + rational V = div(a * z_value + w_value, rMod); + vector mod_coeffs; + SASSERT(V >= 0); + SASSERT(a * z_value + w_value >= V*rMod); + SASSERT((V+1)*rMod > a*z_value + w_value); + // -a*z - w + V*rMod <= 0 + mod_coeffs.push_back(var(z, -a)); + if (w != UINT_MAX) mod_coeffs.push_back(var(w, -rational::one())); + add_constraint(mod_coeffs, V*rMod - offset, t_le); + mod_coeffs.reset(); + // a*z + w - (V+1)*rMod + 1 <= 0 + mod_coeffs.push_back(var(z, a)); + if (w != UINT_MAX) mod_coeffs.push_back(var(w, rational::one())); + add_constraint(mod_coeffs, -(V+1)*rMod + offset + 1, t_le); + mod_coeffs.reset(); + // -v + a*z + w - V*rMod = 0 + mod_coeffs.push_back(var(v, rational::minus_one())); + mod_coeffs.push_back(var(z, a)); + if (w != UINT_MAX) mod_coeffs.push_back(var(w, rational::one())); + add_constraint(mod_coeffs, offset - V*rMod, t_eq); + +#else + // add v = a*z + w - V, for V = v_value - a * z_value - w_value + // claim: (= (mod x rMod) (- x (* rMod (div x rMod)))))) is a theorem for every x, rMod != 0 + rational V = v_value - a * z_value - w_value; + vector mod_coeffs; + mod_coeffs.push_back(var(v, rational::minus_one())); + mod_coeffs.push_back(var(z, a)); + if (w != UINT_MAX) mod_coeffs.push_back(var(w, rational::one())); + add_constraint(mod_coeffs, V + offset, t_eq); + add_lower_bound(v, rational::zero()); + add_upper_bound(v, rMod - 1); +#endif + + retire_row(ri); + vs.push_back(v); + } + + + for (unsigned v : vs) { + def v_def = project(v, compute_def); + if (compute_def) + eliminate(v, v_def); + } + + // project internal variables. + def z_def = project(z, compute_def); + def y_def = project(y, compute_def); // may depend on z + + if (compute_def) { + z_def.substitute(y, y_def); + eliminate(y, y_def); + eliminate(z, z_def); + + result = (y_def * K) + z_def; + m_var2value[x] = eval(result); + TRACE("opt", tout << y << " := " << y_def << "\n"; + tout << z << " := " << z_def << "\n"; + tout << x << " := " << result << "\n"); + } + TRACE("opt", display(tout << "solve_div done v" << x << "\n")); + return result; + } + + // + // compute D and u. + // + // D = lcm(d1, d2) + // u = eval(x) mod D + // + // d1 | (a1x + t1) & d2 | (a2x + t2) + // = + // d1 | (a1(D*x' + u) + t1) & d2 | (a2(D*x' + u) + t2) + // = + // d1 | (a1*u + t1) & d2 | (a2*u + t2) + // + // x := D*x' + u + // + + model_based_opt::def model_based_opt::solve_divides(unsigned x, unsigned_vector const& divide_rows, bool compute_def) { + SASSERT(!divide_rows.empty()); + rational D(1); + for (unsigned idx : divide_rows) { + D = lcm(D, m_rows[idx].m_mod); + } + if (D.is_zero()) { + throw default_exception("modulo 0 is not defined"); + } + if (D.is_neg()) D = abs(D); + TRACE("opt1", display(tout << "lcm: " << D << " x: v" << x << " tableau\n");); + rational val_x = m_var2value[x]; + rational u = mod(val_x, D); + SASSERT(u.is_nonneg() && u < D); + for (unsigned idx : divide_rows) { + replace_var(idx, x, u); + SASSERT(invariant(idx, m_rows[idx])); + normalize(idx); + } + TRACE("opt1", display(tout << "tableau after replace x under mod\n");); + // + // update inequalities such that u is added to t and + // D is multiplied to coefficient of x. + // the interpretation of the new version of x is (x-u)/D + // + // a*x + t <= 0 + // a*(D*x' + u) + t <= 0 + // a*D*x' + a*u + t <= 0 + // + rational new_val = (val_x - u) / D; + SASSERT(new_val.is_int()); + unsigned y = add_var(new_val, true); + unsigned_vector const& row_ids = m_var2row_ids[x]; + uint_set visited; + for (unsigned row_id : row_ids) { + if (visited.contains(row_id)) + continue; + // x |-> D*y + u + replace_var(row_id, x, D, y, u); + visited.insert(row_id); + normalize(row_id); + } + TRACE("opt1", display(tout << "tableau after replace x by y := v" << y << "\n");); + def result = project(y, compute_def); + if (compute_def) { + result = (result * D) + u; + m_var2value[x] = eval(result); + } + TRACE("opt1", display(tout << "tableau after project y" << y << "\n");); + + return result; + } + + // update row with: x |-> C + void model_based_opt::replace_var(unsigned row_id, unsigned x, rational const& C) { + row& r = m_rows[row_id]; + SASSERT(!get_coefficient(row_id, x).is_zero()); + unsigned sz = r.m_vars.size(); + unsigned i = 0, j = 0; + rational coeff(0); + for (; i < sz; ++i) { + if (r.m_vars[i].m_id == x) { + coeff = r.m_vars[i].m_coeff; + } + else { + if (i != j) { + r.m_vars[j] = r.m_vars[i]; + } + ++j; + } + } + if (j != sz) { + r.m_vars.shrink(j); + } + r.m_coeff += coeff*C; + r.m_value += coeff*(C - m_var2value[x]); + } + + // update row with: x |-> A*y + B + void model_based_opt::replace_var(unsigned row_id, unsigned x, rational const& A, unsigned y, rational const& B) { + row& r = m_rows[row_id]; + rational coeff = get_coefficient(row_id, x); + if (coeff.is_zero()) return; + if (!r.m_alive) return; + replace_var(row_id, x, B); + r.m_vars.push_back(var(y, coeff*A)); + r.m_value += coeff*A*m_var2value[y]; + if (!r.m_vars.empty() && r.m_vars.back().m_id > y) + std::sort(r.m_vars.begin(), r.m_vars.end(), var::compare()); + m_var2row_ids[y].push_back(row_id); + SASSERT(invariant(row_id, r)); + } + + // update row with: x |-> A*y + B*z + void model_based_opt::replace_var(unsigned row_id, unsigned x, rational const& A, unsigned y, rational const& B, unsigned z) { + row& r = m_rows[row_id]; + rational coeff = get_coefficient(row_id, x); + if (coeff.is_zero() || !r.m_alive) + return; + replace_var(row_id, x, rational::zero()); + if (A != 0) r.m_vars.push_back(var(y, coeff*A)); + if (B != 0) r.m_vars.push_back(var(z, coeff*B)); + r.m_value += coeff*A*m_var2value[y]; + r.m_value += coeff*B*m_var2value[z]; + std::sort(r.m_vars.begin(), r.m_vars.end(), var::compare()); + if (A != 0) m_var2row_ids[y].push_back(row_id); + if (B != 0) m_var2row_ids[z].push_back(row_id); + SASSERT(invariant(row_id, r)); + } + + // 3x + t = 0 & 7 | (c*x + s) & ax <= u + // 3 | -t & 21 | (-ct + 3s) & a-t <= 3u + + model_based_opt::def model_based_opt::solve_for(unsigned row_id1, unsigned x, bool compute_def) { + TRACE("opt", tout << "v" << x << " := " << eval(x) << "\n" << m_rows[row_id1] << "\n"; + display(tout)); + rational a = get_coefficient(row_id1, x), b; + row& r1 = m_rows[row_id1]; + ineq_type ty = r1.m_type; + SASSERT(!a.is_zero()); + SASSERT(r1.m_alive); + if (a.is_neg()) { + a.neg(); + r1.neg(); + } + SASSERT(a.is_pos()); + if (ty == t_lt) { + SASSERT(compute_def); + r1.m_coeff -= r1.m_value; + r1.m_type = t_le; + r1.m_value = 0; + } + + if (m_var2is_int[x] && !a.is_one()) { + r1.m_coeff -= r1.m_value; + r1.m_value = 0; + vector coeffs; + mk_coeffs_without(coeffs, r1.m_vars, x); + rational c = mod(-eval(coeffs), a); + add_divides(coeffs, c, a); + } + unsigned_vector const& row_ids = m_var2row_ids[x]; + uint_set visited; + visited.insert(row_id1); + for (unsigned row_id2 : row_ids) { + if (visited.contains(row_id2)) + continue; + visited.insert(row_id2); + row& r = m_rows[row_id2]; + if (!r.m_alive) + continue; + b = get_coefficient(row_id2, x); + if (b.is_zero()) + continue; + row& dst = m_rows[row_id2]; + switch (dst.m_type) { + case t_eq: + case t_lt: + case t_le: + solve(row_id1, a, row_id2, x); + break; + case t_divides: + case t_mod: + case t_div: + // mod reduction already done. + UNREACHABLE(); + break; + } + } + def result; + if (compute_def) { + result = def(m_rows[row_id1], x); + m_var2value[x] = eval(result); + TRACE("opt1", tout << "updated eval " << x << " := " << eval(x) << "\n";); + } + retire_row(row_id1); + TRACE("opt", display(tout << "solved v" << x << "\n")); + return result; + } + + void model_based_opt::eliminate(unsigned v, def const& new_def) { + for (auto & d : m_result) + d.substitute(v, new_def); + } + + vector model_based_opt::project(unsigned num_vars, unsigned const* vars, bool compute_def) { + m_result.reset(); + for (unsigned i = 0; i < num_vars; ++i) { + m_result.push_back(project(vars[i], compute_def)); + eliminate(vars[i], m_result.back()); + TRACE("opt", display(tout << "After projecting: v" << vars[i] << "\n");); + } + return m_result; + } + +} + diff --git a/src/model/model_evaluator.cpp b/src/model/model_evaluator.cpp index 0c2a09e78de..6f218934b38 100644 --- a/src/model/model_evaluator.cpp +++ b/src/model/model_evaluator.cpp @@ -20,6 +20,7 @@ Revision History: #include "ast/ast_util.h" #include "ast/for_each_expr.h" #include "ast/recfun_decl_plugin.h" +#include "ast/polymorphism_util.h" #include "ast/rewriter/rewriter_types.h" #include "ast/rewriter/bool_rewriter.h" #include "ast/rewriter/arith_rewriter.h" @@ -32,6 +33,7 @@ Revision History: #include "ast/rewriter/th_rewriter.h" #include "ast/rewriter/rewriter_def.h" #include "ast/rewriter/var_subst.h" +#include "ast/rewriter/recfun_rewriter.h" #include "model/model_smt2_pp.h" #include "model/model.h" #include "model/model_evaluator_params.hpp" @@ -53,6 +55,7 @@ struct evaluator_cfg : public default_rewriter_cfg { pb_rewriter m_pb_rw; fpa_rewriter m_f_rw; seq_rewriter m_seq_rw; + recfun_rewriter m_rec_rw; array_util m_ar; arith_util m_au; fpa_util m_fpau; @@ -80,6 +83,7 @@ struct evaluator_cfg : public default_rewriter_cfg { m_pb_rw(m), m_f_rw(m), m_seq_rw(m), + m_rec_rw(m), m_ar(m), m_au(m), m_fpau(m), @@ -158,7 +162,7 @@ struct evaluator_cfg : public default_rewriter_cfg { return st; } - bool contains_as_array(expr* e) { + bool contains_redex(expr* e) { if (m_ar.is_as_array(e)) return true; if (is_var(e)) @@ -166,22 +170,26 @@ struct evaluator_cfg : public default_rewriter_cfg { if (is_app(e) && to_app(e)->get_num_args() == 0) return false; - struct has_as_array {}; - struct has_as_array_finder { - array_util& au; - has_as_array_finder(array_util& au): au(au) {} + struct has_redex {}; + struct has_redex_finder { + evaluator_cfg& ev; + has_redex_finder(evaluator_cfg& ev): ev(ev) {} void operator()(var* v) {} void operator()(quantifier* q) {} void operator()(app* a) { - if (au.is_as_array(a->get_decl())) - throw has_as_array(); + if (ev.m_ar.is_as_array(a->get_decl())) + throw has_redex(); + if (ev.m_ar.get_manager().is_eq(a)) + throw has_redex(); + if (ev.m_fpau.is_fp(a)) + throw has_redex(); } }; - has_as_array_finder ha(m_ar); + has_redex_finder ha(*this); try { for_each_expr(ha, e); } - catch (has_as_array) { + catch (has_redex) { return true; } return false; @@ -213,8 +221,8 @@ struct evaluator_cfg : public default_rewriter_cfg { expr* val = m_model.get_const_interp(f); if (val != nullptr) { result = val; - st = contains_as_array(val) ? BR_REWRITE_FULL : BR_DONE; - TRACE("model_evaluator", tout << result << "\n";); + st = contains_redex(val) ? BR_REWRITE_FULL : BR_DONE; + TRACE("model_evaluator", tout << st << " " << result << "\n";); return st; } if (!m_model_completion) @@ -283,6 +291,8 @@ struct evaluator_cfg : public default_rewriter_cfg { st = m_f_rw.mk_app_core(f, num, args, result); else if (fid == m_seq_rw.get_fid()) st = m_seq_rw.mk_app_core(f, num, args, result); + else if (fid == m_rec_rw.get_fid()) + st = m_rec_rw.mk_app_core(f, num, args, result); else if (fid == m.get_label_family_id() && num == 1) { result = args[0]; st = BR_DONE; @@ -364,7 +374,7 @@ struct evaluator_cfg : public default_rewriter_cfg { bool get_macro(func_decl * f, expr * & def, quantifier * & , proof * &) { func_interp * fi = m_model.get_func_interp(f); def = nullptr; - if (fi != nullptr) { + if (fi) { if (fi->is_partial()) { if (m_model_completion) { sort * s = f->get_range(); @@ -377,6 +387,24 @@ struct evaluator_cfg : public default_rewriter_cfg { def = fi->get_interp(); SASSERT(def != nullptr); } + else if (f->is_polymorphic() && (fi = m_model.get_func_interp(m.poly_root(f)))) { + if (fi->is_partial()) { + if (m_model_completion) { + sort * s = f->get_range(); + expr * val = m_model.get_some_value(s); + fi->set_else(val); + } + else + return false; + } + def = fi->get_interp(); + polymorphism::substitution subst(m); + polymorphism::util util(m); + util.unify(f, m.poly_root(f), subst); + def = subst(def); + SASSERT(def != nullptr); + + } else if (m_model_completion && (f->get_family_id() == null_family_id || m.get_plugin(f->get_family_id())->is_considered_uninterpreted(f))) { @@ -421,14 +449,39 @@ struct evaluator_cfg : public default_rewriter_cfg { return BR_DONE; } } + else if (!fi && m_au.is_considered_partially_interpreted(f, num, args, f_ui)) { + fi = m_model.get_func_interp(f_ui); + + if (fi) { + auto interp = fi->get_interp(); + if (interp) { + var_subst vs(m, false); + result = vs(fi->get_interp(), num, args); + result = m.mk_ite(m.mk_eq(m_au.mk_real(rational(0)), args[1]), result, m.mk_app(f, num, args)); + return BR_DONE; + } + } + } else if (!fi && m_fpau.is_considered_uninterpreted(f, num, args)) { result = m.get_some_value(f->get_range()); return BR_DONE; } - else if (m_dt.is_accessor(f) && !is_ground(args[0])) { - result = m.mk_app(f, num, args); - return BR_DONE; + else if (m_dt.is_accessor(f)) { + expr* arg = args[0]; + if (m.is_value(arg) && !fi) { + fi = alloc(func_interp, m, f->get_arity()); + expr* val = m_model.get_some_value(f->get_range()); + fi->set_else(val); + m_model.register_decl(f, fi); + result = val; + return BR_DONE; + } + if (!is_ground(arg)) { + result = m.mk_app(f, num, args); + return BR_DONE; + } } + if (fi) { if (fi->is_partial()) fi->set_else(m.get_some_value(f->get_range())); diff --git a/src/model/model_smt2_pp.cpp b/src/model/model_smt2_pp.cpp index b5ac3fbad00..89ea98075f7 100644 --- a/src/model/model_smt2_pp.cpp +++ b/src/model/model_smt2_pp.cpp @@ -68,7 +68,7 @@ static void pp_uninterp_sorts(std::ostream & out, ast_printer_context & ctx, mod buffer << " "; } buffer << "\n-----------"; - std::string buffer_str = buffer.str(); + std::string buffer_str = std::move(buffer).str(); unsigned len = static_cast(buffer_str.length()); pp_indent(out, indent); out << ";; "; @@ -206,9 +206,7 @@ static void pp_funs(std::ostream & out, ast_printer_context & ctx, model_core co if (f_i->is_partial()) { body = mk_string(m, "#unspecified"); for (unsigned j = 0; j < f->get_arity(); j++) { - std::stringstream strm; - strm << "x!" << (j+1); - var_names.push_back(symbol(strm.str())); + var_names.push_back(symbol("x!" + std::to_string(j+1))); } } else { diff --git a/src/muz/rel/dl_base.h b/src/muz/rel/dl_base.h index f6d03162423..158df4906ff 100644 --- a/src/muz/rel/dl_base.h +++ b/src/muz/rel/dl_base.h @@ -484,7 +484,7 @@ namespace datalog { virtual bool can_swap(const base_object & o) const { return false; } - virtual void swap(base_object & o) { + virtual void swap(base_object & o) noexcept { std::swap(m_kind, o.m_kind); #if DL_LEAK_HUNTING m_leak_guard = get_plugin().get_ast_manager().mk_fresh_sort(get_plugin().get_name().bare_str()); @@ -910,7 +910,7 @@ namespace datalog { public: table_signature() : m_functional_columns(0) {} - void swap(table_signature & s) { + void swap(table_signature & s) noexcept { signature_base::swap(s); std::swap(m_functional_columns, s.m_functional_columns); } diff --git a/src/muz/rel/dl_finite_product_relation.cpp b/src/muz/rel/dl_finite_product_relation.cpp index b1cdf055356..293cc72e2e0 100644 --- a/src/muz/rel/dl_finite_product_relation.cpp +++ b/src/muz/rel/dl_finite_product_relation.cpp @@ -1835,7 +1835,7 @@ namespace datalog { } } - void finite_product_relation::swap(relation_base & r0) { + void finite_product_relation::swap(relation_base & r0) noexcept { SASSERT(can_swap(r0)); finite_product_relation & r = finite_product_relation_plugin::get(r0); SASSERT(get_signature()==r.get_signature()); diff --git a/src/muz/rel/dl_finite_product_relation.h b/src/muz/rel/dl_finite_product_relation.h index f1387437b83..324564469f2 100644 --- a/src/muz/rel/dl_finite_product_relation.h +++ b/src/muz/rel/dl_finite_product_relation.h @@ -316,7 +316,7 @@ namespace datalog { Both relations must come from the same plugin and be of the same signature. */ - void swap(relation_base & r) override; + void swap(relation_base & r) noexcept override; /** \brief Create a \c finite_product_relation object. diff --git a/src/muz/rel/dl_mk_similarity_compressor.cpp b/src/muz/rel/dl_mk_similarity_compressor.cpp index ea1ef42b50b..46415a0de56 100644 --- a/src/muz/rel/dl_mk_similarity_compressor.cpp +++ b/src/muz/rel/dl_mk_similarity_compressor.cpp @@ -101,7 +101,7 @@ namespace datalog { /** \brief Return 0 if r1 and r2 could be similar. If the rough similarity - equaivelance class of r1 is greater than the one of r2, return 1; otherwise return -1. + equivalence class of r1 is greater than the one of r2, return 1; otherwise return -1. Two rules are in the same rough similarity class if they differ only in constant arguments of positive uninterpreted predicates. diff --git a/src/muz/rel/dl_vector_relation.h b/src/muz/rel/dl_vector_relation.h index ad9204b4105..e61fc5a2ae8 100644 --- a/src/muz/rel/dl_vector_relation.h +++ b/src/muz/rel/dl_vector_relation.h @@ -59,7 +59,7 @@ namespace datalog { dealloc(m_elems); } - void swap(relation_base& other) override { + void swap(relation_base& other) noexcept override { vector_relation& o = dynamic_cast(other); if (&o == this) return; std::swap(o.m_eqs, m_eqs); diff --git a/src/muz/spacer/spacer_proof_utils.cpp b/src/muz/spacer/spacer_proof_utils.cpp index 800a6329160..9fbd4e01c2e 100644 --- a/src/muz/spacer/spacer_proof_utils.cpp +++ b/src/muz/spacer/spacer_proof_utils.cpp @@ -209,10 +209,6 @@ namespace spacer { static proof_ref mk_th_lemma(ast_manager &m, ptr_buffer const &parents, unsigned num_params, parameter const *params) { - buffer v; - for (unsigned i = 1; i < num_params; ++i) - v.push_back(params[i]); - SASSERT(params[0].is_symbol()); family_id tid = m.mk_family_id(params[0].get_symbol()); SASSERT(tid != null_family_id); @@ -220,7 +216,7 @@ namespace spacer { proof_ref pf(m); pf = m.mk_th_lemma(tid, m.mk_false(), parents.size(), parents.data(), - v.size(), v.data()); + num_params - 1, params + 1); return pf; } diff --git a/src/muz/spacer/spacer_unsat_core_plugin.cpp b/src/muz/spacer/spacer_unsat_core_plugin.cpp index 5523326f481..c0a9d7e6f97 100644 --- a/src/muz/spacer/spacer_unsat_core_plugin.cpp +++ b/src/muz/spacer/spacer_unsat_core_plugin.cpp @@ -73,7 +73,7 @@ namespace spacer { // the current step needs to be interpolated: expr* fact = m.get_fact(pf); // if we trust the current step and we are able to use it - if (m_ctx.is_b_pure (pf) && (m.is_asserted(pf) || spacer::is_literal(m, fact))) { + if (m_ctx.is_b_pure (pf) && (m.is_asserted(pf) || spacer::is_literal(m, fact)) && !spacer::contains_defaults(fact, m)) { // just add it to the core m_ctx.add_lemma_to_core(fact); } diff --git a/src/muz/spacer/spacer_util.cpp b/src/muz/spacer/spacer_util.cpp index 4e1da577004..3959bd65595 100644 --- a/src/muz/spacer/spacer_util.cpp +++ b/src/muz/spacer/spacer_util.cpp @@ -34,6 +34,7 @@ Revision History: #include "ast/ast_pp.h" #include "ast/bv_decl_plugin.h" #include "ast/datatype_decl_plugin.h" +#include "ast/expr_functors.h" #include "ast/for_each_expr.h" #include "ast/occurs.h" #include "ast/rewriter/bool_rewriter.h" @@ -51,7 +52,7 @@ Revision History: #include "model/model_smt2_pp.h" #include "smt/params/smt_params.h" -#include "qe/lite/qe_lite_tactic.h" +#include "qe/lite/qel.h" #include "qe/mbp/mbp_plugin.h" #include "qe/mbp/mbp_term_graph.h" #include "qe/qe_mbp.h" @@ -69,6 +70,21 @@ Revision History: namespace spacer { +class contains_def_pred : public i_expr_pred { + array_util m_autil; + public: + contains_def_pred(ast_manager& m): m_autil(m) {} + bool operator()(expr* e) override { + return m_autil.is_default(e); + } +}; + +bool contains_defaults(expr *fml, ast_manager &m) { + contains_def_pred pred(m); + check_pred check(pred, m, false); + return check(fml); +} + bool is_clause(ast_manager &m, expr *n) { if (spacer::is_literal(m, n)) return true; if (m.is_or(n)) { @@ -173,7 +189,7 @@ void qe_project_spacer(ast_manager &m, app_ref_vector &vars, expr_ref &fml, while (true) { params_ref p; - qe_lite qe(m, p, false); + qel qe(m, p); qe(vars, fml); rw(fml); diff --git a/src/muz/spacer/spacer_util.h b/src/muz/spacer/spacer_util.h index dbc3083a212..272c6cf304c 100644 --- a/src/muz/spacer/spacer_util.h +++ b/src/muz/spacer/spacer_util.h @@ -121,6 +121,7 @@ void ground_expr(expr *e, expr_ref &out, app_ref_vector &vars); void mbqi_project(model &mdl, app_ref_vector &vars, expr_ref &fml); bool contains_selects(expr *fml, ast_manager &m); +bool contains_defaults(expr *fml, ast_manager &m); void get_select_indices(expr *fml, app_ref_vector &indices); void find_decls(expr *fml, app_ref_vector &decls, std::string &prefix); diff --git a/src/muz/transforms/dl_mk_coi_filter.cpp b/src/muz/transforms/dl_mk_coi_filter.cpp index 73541b0cd74..d59aa66dcba 100644 --- a/src/muz/transforms/dl_mk_coi_filter.cpp +++ b/src/muz/transforms/dl_mk_coi_filter.cpp @@ -102,14 +102,11 @@ namespace datalog { // set to false each unreached predicate if (res && m_context.get_model_converter()) { generic_model_converter* mc0 = alloc(generic_model_converter, m, "dl_coi"); - for (auto const& kv : engine) { - if (!kv.m_value.is_reachable()) { + for (auto const& kv : engine) + if (!kv.m_value.is_reachable()) unreachable.insert(kv.m_key); - } - } - for (func_decl* f : unreachable) { + for (func_decl* f : unreachable) mc0->add(f, m.mk_false()); - } m_context.add_model_converter(mc0); TRACE("dl", m_context.get_model_converter()->display(tout);); } diff --git a/src/muz/transforms/dl_mk_unbound_compressor.cpp b/src/muz/transforms/dl_mk_unbound_compressor.cpp index e93209f95a2..167da3ae86c 100644 --- a/src/muz/transforms/dl_mk_unbound_compressor.cpp +++ b/src/muz/transforms/dl_mk_unbound_compressor.cpp @@ -294,13 +294,11 @@ namespace datalog { replace_original_rule = true; replace_by_decompression_rule(source, rule_index, tail_index, arg_index); // NB. arg_indices becomes stale after original rule is replaced. - if (is_negated_predicate && !can_remove_orig_rule) { + if (is_negated_predicate && !can_remove_orig_rule) break; - } } - else { + else add_decompression_rule(source, r, tail_index, arg_index); - } } return replace_original_rule; } @@ -343,20 +341,19 @@ namespace datalog { } rule_set * mk_unbound_compressor::operator()(rule_set const & source) { - // TODO mc - if (!m_context.compress_unbound()) { + if (!m_context.compress_unbound() || m_context.get_model_converter()) return nullptr; - } m_modified = false; SASSERT(m_rules.empty()); rel_context_base* rel = m_context.get_rel_context(); - if (rel) { + if (rel) rel->collect_non_empty_predicates(m_non_empty_rels); - } + + unsigned init_rule_cnt = source.get_num_rules(); for (unsigned i = 0; i < init_rule_cnt; i++) { rule * r = source.get_rule(i); @@ -390,13 +387,15 @@ namespace datalog { scoped_ptr result; if (m_modified) { result = alloc(rule_set, m_context); - unsigned fin_rule_cnt = m_rules.size(); - for (unsigned i=0; iadd_rule(m_rules.get(i)); - } + for (auto* r : m_rules) + result->add_rule(r); result->inherit_predicates(source); } + if (result && m_context.get_model_converter()) { + // TODO mc + } reset(); + return result.detach(); } diff --git a/src/nlsat/nlsat_assignment.h b/src/nlsat/nlsat_assignment.h index 6729b73cd82..d96c8099e91 100644 --- a/src/nlsat/nlsat_assignment.h +++ b/src/nlsat/nlsat_assignment.h @@ -33,7 +33,7 @@ namespace nlsat { public: assignment(anum_manager & _m):m_values(_m) {} anum_manager & am() const { return m_values.m(); } - void swap(assignment & other) { + void swap(assignment & other) noexcept { m_values.swap(other.m_values); m_assigned.swap(other.m_assigned); } @@ -67,7 +67,7 @@ namespace nlsat { anum_manager & m() const override { return am(); } bool contains(var x) const override { return is_assigned(x); } anum const & operator()(var x) const override { SASSERT(is_assigned(x)); return value(x); } - void swap(var x, var y) { + void swap(var x, var y) noexcept { SASSERT(x < m_values.size() && y < m_values.size()); std::swap(m_assigned[x], m_assigned[y]); std::swap(m_values[x], m_values[y]); diff --git a/src/nlsat/nlsat_evaluator.cpp b/src/nlsat/nlsat_evaluator.cpp index 97e0e3d725e..ca6f9efaeb4 100644 --- a/src/nlsat/nlsat_evaluator.cpp +++ b/src/nlsat/nlsat_evaluator.cpp @@ -286,23 +286,24 @@ namespace nlsat { } bool check_invariant() const { - DEBUG_CODE( - SASSERT(m_sections.size() == m_sorted_sections.size()); - for (unsigned i = 0; i < m_sorted_sections.size(); i++) { - SASSERT(m_sorted_sections[i] < m_sections.size()); - SASSERT(m_sections[m_sorted_sections[i]].m_pos == i); - } - unsigned total_num_sections = 0; - unsigned total_num_signs = 0; - for (unsigned i = 0; i < m_info.size(); i++) { - SASSERT(m_info[i].m_first_section <= m_poly_sections.size()); - SASSERT(m_info[i].m_num_roots == 0 || m_info[i].m_first_section < m_poly_sections.size()); - SASSERT(m_info[i].m_first_sign < m_poly_signs.size()); - total_num_sections += m_info[i].m_num_roots; - total_num_signs += m_info[i].m_num_roots + 1; - } - SASSERT(total_num_sections == m_poly_sections.size()); - SASSERT(total_num_signs == m_poly_signs.size());); +#ifdef Z3DEBUG + SASSERT(m_sections.size() == m_sorted_sections.size()); + for (unsigned i = 0; i < m_sorted_sections.size(); i++) { + SASSERT(m_sorted_sections[i] < m_sections.size()); + SASSERT(m_sections[m_sorted_sections[i]].m_pos == i); + } + unsigned total_num_sections = 0; + unsigned total_num_signs = 0; + for (unsigned i = 0; i < m_info.size(); i++) { + SASSERT(m_info[i].m_first_section <= m_poly_sections.size()); + SASSERT(m_info[i].m_num_roots == 0 || m_info[i].m_first_section < m_poly_sections.size()); + SASSERT(m_info[i].m_first_sign < m_poly_signs.size()); + total_num_sections += m_info[i].m_num_roots; + total_num_signs += m_info[i].m_num_roots + 1; + } + SASSERT(total_num_sections == m_poly_sections.size()); + SASSERT(total_num_signs == m_poly_signs.size()); +#endif return true; } @@ -491,7 +492,7 @@ namespace nlsat { interval_set_ref infeasible_intervals(ineq_atom * a, bool neg, clause const* cls) { sign_table & table = m_sign_table_tmp; table.reset(); - TRACE("nsat_evaluator", m_solver.display(tout, *a) << "\n";); + TRACE("nlsat_evaluator", m_solver.display(tout, *a) << "\n";); unsigned num_ps = a->size(); var x = a->max_var(); for (unsigned i = 0; i < num_ps; i++) { @@ -664,7 +665,7 @@ namespace nlsat { return result; } - interval_set_ref infeasible_intervals(atom * a, bool neg, clause const* cls) { + interval_set_ref infeasible_intervals(atom * a, bool neg, clause const* cls) { return a->is_ineq_atom() ? infeasible_intervals(to_ineq_atom(a), neg, cls) : infeasible_intervals(to_root_atom(a), neg, cls); } }; @@ -685,7 +686,7 @@ namespace nlsat { return m_imp->eval(a, neg); } - interval_set_ref evaluator::infeasible_intervals(atom * a, bool neg, clause const* cls) { + interval_set_ref evaluator::infeasible_intervals(atom * a, bool neg, clause const* cls) { return m_imp->infeasible_intervals(a, neg, cls); } diff --git a/src/nlsat/nlsat_explain.cpp b/src/nlsat/nlsat_explain.cpp index 68a646f9228..87fda76ef2c 100644 --- a/src/nlsat/nlsat_explain.cpp +++ b/src/nlsat/nlsat_explain.cpp @@ -38,7 +38,7 @@ namespace nlsat { polynomial_ref_vector m_ps; polynomial_ref_vector m_ps2; polynomial_ref_vector m_psc_tmp; - polynomial_ref_vector m_factors; + polynomial_ref_vector m_factors, m_factors_save; scoped_anum_vector m_roots_tmp; bool m_simplify_cores; bool m_full_dimensional; @@ -142,6 +142,7 @@ namespace nlsat { m_ps2(m_pm), m_psc_tmp(m_pm), m_factors(m_pm), + m_factors_save(m_pm), m_roots_tmp(m_am), m_todo(u), m_core1(s), @@ -259,22 +260,42 @@ namespace nlsat { */ ptr_vector m_zero_fs; bool_vector m_is_even; + struct restore_factors { + polynomial_ref_vector& m_factors, &m_factors_save; + unsigned num_saved = 0; + restore_factors(polynomial_ref_vector&f, polynomial_ref_vector& fs): + m_factors(f), m_factors_save(fs) + { + num_saved = m_factors_save.size(); + m_factors_save.append(m_factors); + } + + ~restore_factors() { + m_factors.reset(); + m_factors.append(m_factors_save.size() - num_saved, m_factors_save.data() + num_saved); + m_factors_save.shrink(num_saved); + } + + }; void add_zero_assumption(polynomial_ref & p) { // If p is of the form p1^n1 * ... * pk^nk, // then only the factors that are zero in the current interpretation needed to be considered. // I don't want to create a nested conjunction in the clause. // Then, I assert p_i1 * ... * p_im != 0 - factor(p, m_factors); - unsigned num_factors = m_factors.size(); - m_zero_fs.reset(); - m_is_even.reset(); - polynomial_ref f(m_pm); - for (unsigned i = 0; i < num_factors; i++) { - f = m_factors.get(i); - if (is_zero(sign(f))) { - m_zero_fs.push_back(m_factors.get(i)); - m_is_even.push_back(false); - } + { + restore_factors _restore(m_factors, m_factors_save); + factor(p, m_factors); + unsigned num_factors = m_factors.size(); + m_zero_fs.reset(); + m_is_even.reset(); + polynomial_ref f(m_pm); + for (unsigned i = 0; i < num_factors; i++) { + f = m_factors.get(i); + if (is_zero(sign(f))) { + m_zero_fs.push_back(m_factors.get(i)); + m_is_even.push_back(false); + } + } } SASSERT(!m_zero_fs.empty()); // one of the factors must be zero in the current interpretation, since p is zero in it. literal l = m_solver.mk_ineq_literal(atom::EQ, m_zero_fs.size(), m_zero_fs.data(), m_is_even.data()); @@ -582,8 +603,9 @@ namespace nlsat { if (is_const(p)) return; if (m_factor) { - TRACE("nlsat_explain", display(tout << "adding factors of\n", p); tout << "\n";); + restore_factors _restore(m_factors, m_factors_save); factor(p, m_factors); + TRACE("nlsat_explain", display(tout << "adding factors of\n", p); tout << "\n" << m_factors << "\n";); polynomial_ref f(m_pm); for (unsigned i = 0; i < m_factors.size(); i++) { f = m_factors.get(i); @@ -859,6 +881,7 @@ namespace nlsat { */ void mk_linear_root(atom::kind k, var y, unsigned i, poly * p, bool mk_neg) { + TRACE("nlsat_explain", display_var(tout, y); m_pm.display(tout << ": ", p, m_solver.display_proc()); tout << "\n"); polynomial_ref p_prime(m_pm); p_prime = p; bool lsign = false; @@ -1379,7 +1402,7 @@ namespace nlsat { var max_x = max_var(m_ps); TRACE("nlsat_explain", tout << "polynomials in the conflict:\n"; display(tout, m_ps); tout << "\n";); elim_vanishing(m_ps); - TRACE("nlsat_explain", tout << "elim vanishing\n"; display(tout, m_ps); tout << "\n";); + TRACE("nlsat_explain", tout << "elim vanishing x" << max_x << "\n"; display(tout, m_ps); tout << "\n";); project(m_ps, max_x); TRACE("nlsat_explain", tout << "after projection\n"; display(tout, m_ps); tout << "\n";); } @@ -1508,12 +1531,13 @@ namespace nlsat { m_solver.display(tout, num, ls); m_solver.display(tout);); - DEBUG_CODE( - for (unsigned i = 0; i < num; ++i) { - SASSERT(m_solver.value(ls[i]) == l_true); - atom* a = m_atoms[ls[i].var()]; - SASSERT(!a || m_evaluator.eval(a, ls[i].sign())); - }); +#ifdef Z3DEBUG + for (unsigned i = 0; i < num; ++i) { + SASSERT(m_solver.value(ls[i]) == l_true); + atom* a = m_atoms[ls[i].var()]; + SASSERT(!a || m_evaluator.eval(a, ls[i].sign())); + } +#endif split_literals(x, num, ls, lits); collect_polys(lits.size(), lits.data(), m_ps); var mx_var = max_var(m_ps); @@ -1548,13 +1572,13 @@ namespace nlsat { for (unsigned i = 0; i < result.size(); ++i) { result.set(i, ~result[i]); } - DEBUG_CODE( - TRACE("nlsat", m_solver.display(tout, result.size(), result.data()) << "\n"; ); - for (literal l : result) { - CTRACE("nlsat", l_true != m_solver.value(l), m_solver.display(tout, l) << " " << m_solver.value(l) << "\n";); - SASSERT(l_true == m_solver.value(l)); - }); - +#ifdef Z3DEBUG + TRACE("nlsat", m_solver.display(tout, result.size(), result.data()) << "\n"; ); + for (literal l : result) { + CTRACE("nlsat", l_true != m_solver.value(l), m_solver.display(tout, l) << " " << m_solver.value(l) << "\n";); + SASSERT(l_true == m_solver.value(l)); + } +#endif } void split_literals(var x, unsigned n, literal const* ls, svector& lits) { diff --git a/src/nlsat/nlsat_interval_set.cpp b/src/nlsat/nlsat_interval_set.cpp index 70b2bd02c68..f928fd5b6cc 100644 --- a/src/nlsat/nlsat_interval_set.cpp +++ b/src/nlsat/nlsat_interval_set.cpp @@ -98,12 +98,13 @@ namespace nlsat { // Check if the intervals are valid, ordered, and are disjoint. bool check_interval_set(anum_manager & am, unsigned sz, interval const * ints) { - DEBUG_CODE( - for (unsigned i = 0; i < sz; i++) { - interval const & curr = ints[i]; - SASSERT(check_interval(am, curr)); - SASSERT(i >= sz - 1 || check_no_overlap(am, curr, ints[i+1])); - }); +#ifdef Z3DEBUG + for (unsigned i = 0; i < sz; i++) { + interval const & curr = ints[i]; + SASSERT(check_interval(am, curr)); + SASSERT(i >= sz - 1 || check_no_overlap(am, curr, ints[i+1])); + } +#endif return true; } @@ -663,9 +664,8 @@ namespace nlsat { continue; m_already_visited.setx(lidx, true, false); js.push_back(l); - if (s->m_intervals[i].m_clause) { + if (s->m_intervals[i].m_clause) clauses.push_back(const_cast(s->m_intervals[i].m_clause)); - } } for (unsigned i = 0; i < num; i++) { literal l = s->m_intervals[i].m_justification; @@ -695,12 +695,11 @@ namespace nlsat { scoped_mpq _w(m_am.qm()); m_am.qm().set(_w, num, den); m_am.set(w, _w); - return; } else { m_am.set(w, 0); - return; } + return; } unsigned n = 0; @@ -741,7 +740,7 @@ namespace nlsat { for (unsigned i = 1; i < num; i++) { if (s->m_intervals[i-1].m_upper_open && s->m_intervals[i].m_lower_open) { SASSERT(m_am.eq(s->m_intervals[i-1].m_upper, s->m_intervals[i].m_lower)); // otherwise we would have found it in the previous step - if (m_am.is_rational(s->m_intervals[i-1].m_upper)) { + if (m_am.is_rational(s->m_intervals[i-1].m_upper)) { m_am.set(w, s->m_intervals[i-1].m_upper); return; } diff --git a/src/nlsat/nlsat_scoped_literal_vector.h b/src/nlsat/nlsat_scoped_literal_vector.h index a2b617e559c..9b4b62fe8b0 100644 --- a/src/nlsat/nlsat_scoped_literal_vector.h +++ b/src/nlsat/nlsat_scoped_literal_vector.h @@ -67,7 +67,7 @@ namespace nlsat { void append(scoped_literal_vector const& ls) { append(ls.size(), ls.data()); } - void swap(scoped_literal_vector& other) { + void swap(scoped_literal_vector& other) noexcept { SASSERT(&m_solver == &other.m_solver); m_lits.swap(other.m_lits); } diff --git a/src/nlsat/nlsat_solver.cpp b/src/nlsat/nlsat_solver.cpp index 568a25a294c..d99603dcddc 100644 --- a/src/nlsat/nlsat_solver.cpp +++ b/src/nlsat/nlsat_solver.cpp @@ -48,7 +48,7 @@ namespace nlsat { typedef chashtable root_atom_table; // for apply_permutation procedure - void swap(clause * & c1, clause * & c2) { + void swap(clause * & c1, clause * & c2) noexcept { std::swap(c1, c2); } @@ -114,10 +114,10 @@ namespace nlsat { unsigned_vector m_levels; // bool_var -> level svector m_justifications; vector m_bwatches; // bool_var (that are not attached to atoms) -> clauses where it is maximal - bool_vector m_dead; // mark dead boolean variables + bool_vector m_dead; // mark dead boolean variables id_gen m_bid_gen; - bool_vector m_is_int; // m_is_int[x] is true if variable is integer + bool_vector m_is_int; // m_is_int[x] is true if variable is integer vector m_watches; // var -> clauses where variable is maximal interval_set_vector m_infeasible; // var -> to a set of interval where the variable cannot be assigned to. atom_vector m_var2eq; // var -> to asserted equality @@ -486,7 +486,7 @@ namespace nlsat { SASSERT(x == num_vars()); m_is_int. push_back(is_int); m_watches. push_back(clause_vector()); - m_infeasible.push_back(0); + m_infeasible.push_back(nullptr); m_var2eq. push_back(nullptr); m_perm. push_back(x); m_inv_perm. push_back(x); @@ -810,19 +810,24 @@ namespace nlsat { void check_lemma(unsigned n, literal const* cls, bool is_valid, assumption_set a) { TRACE("nlsat", display(tout << "check lemma: ", n, cls) << "\n"; display(tout);); - IF_VERBOSE(0, display(verbose_stream() << "check lemma: ", n, cls) << "\n"); + IF_VERBOSE(2, display(verbose_stream() << "check lemma " << (is_valid?"valid: ":"consequence: "), n, cls) << "\n"); for (clause* c : m_learned) IF_VERBOSE(1, display(verbose_stream() << "lemma: ", *c) << "\n"); - - solver solver2(m_ctx); + scoped_suspend_rlimit _limit(m_rlimit); + ctx c(m_rlimit, m_ctx.m_params, m_ctx.m_incremental); + solver solver2(c); imp& checker = *(solver2.m_imp); checker.m_check_lemmas = false; checker.m_log_lemmas = false; checker.m_inline_vars = false; + auto pconvert = [&](poly* p) { + return convert(m_pm, p, checker.m_pm); + }; + // need to translate Boolean variables and literals scoped_bool_vars tr(checker); for (var x = 0; x < m_is_int.size(); ++x) { - checker.register_var(x, m_is_int[x]); + checker.register_var(x, is_int(x)); } bool_var bv = 0; tr.push_back(bv); @@ -834,10 +839,10 @@ namespace nlsat { else if (a->is_ineq_atom()) { ineq_atom& ia = *to_ineq_atom(a); unsigned sz = ia.size(); - ptr_vector ps; + polynomial_ref_vector ps(checker.m_pm); bool_vector is_even; for (unsigned i = 0; i < sz; ++i) { - ps.push_back(ia.p(i)); + ps.push_back(pconvert(ia.p(i))); is_even.push_back(ia.is_even(i)); } bv = checker.mk_ineq_atom(ia.get_kind(), sz, ps.data(), is_even.data()); @@ -847,7 +852,7 @@ namespace nlsat { if (r.x() >= max_var(r.p())) { // permutation may be reverted after check completes, // but then root atoms are not used in lemmas. - bv = checker.mk_root_atom(r.get_kind(), r.x(), r.i(), r.p()); + bv = checker.mk_root_atom(r.get_kind(), r.x(), r.i(), pconvert(r.p())); } } else { @@ -872,7 +877,6 @@ namespace nlsat { literal nlit(tr[lit.var()], !lit.sign()); checker.mk_clause(1, &nlit, nullptr); } - IF_VERBOSE(0, verbose_stream() << "check\n";); lbool r = checker.check(); if (r == l_true) { for (bool_var b : tr) { @@ -902,16 +906,30 @@ namespace nlsat { TRACE("nlsat", display(tout << "violdated tautology clause: ", *c) << "\n";); } } + throw default_exception("lemma did not check"); UNREACHABLE(); } } void log_lemma(std::ostream& out, clause const& cls) { - display_smt2(out); - out << "(assert (not "; - display_smt2(out, cls) << "))\n"; - display(out << "(echo \"#" << m_lemma_count << " ", cls) << "\")\n"; + log_lemma(out, cls.size(), cls.data(), false); + } + + void log_lemma(std::ostream& out, unsigned n, literal const* cls, bool is_valid) { + ++m_lemma_count; + out << "(set-logic NRA)\n"; + if (is_valid) { + display_smt2_bool_decls(out); + display_smt2_arith_decls(out); + } + else + display_smt2(out); + for (unsigned i = 0; i < n; ++i) + display_smt2(out << "(assert ", ~cls[i]) << ")\n"; + display(out << "(echo \"#" << m_lemma_count << " ", n, cls) << "\")\n"; out << "(check-sat)\n(reset)\n"; + + TRACE("nlsat", display(tout << "(echo \"#" << m_lemma_count << " ", n, cls) << "\")\n"); } clause * mk_clause_core(unsigned num_lits, literal const * lits, bool learned, _assumption_set a) { @@ -928,14 +946,13 @@ namespace nlsat { clause * mk_clause(unsigned num_lits, literal const * lits, bool learned, _assumption_set a) { SASSERT(num_lits > 0); clause * cls = mk_clause_core(num_lits, lits, learned, a); - ++m_lemma_count; TRACE("nlsat_sort", display(tout << "mk_clause:\n", *cls) << "\n";); std::sort(cls->begin(), cls->end(), lit_lt(*this)); - TRACE("nlsat_sort", display(tout << "#" << m_lemma_count << " after sort:\n", *cls) << "\n";); + TRACE("nlsat", display(tout << " after sort:\n", *cls) << "\n";); if (learned && m_log_lemmas) { log_lemma(verbose_stream(), *cls); } - if (learned && m_check_lemmas) { + if (learned && m_check_lemmas && false) { check_lemma(cls->size(), cls->data(), false, cls->assumptions()); } if (learned) @@ -993,7 +1010,8 @@ namespace nlsat { } void undo_set_updt(interval_set * old_set) { - if (m_xk == null_var) return; + if (m_xk == null_var) + return; var x = m_xk; if (x < m_infeasible.size()) { m_ism.dec_ref(m_infeasible[x]); @@ -1134,7 +1152,7 @@ namespace nlsat { \brief Assign literal using the given justification */ void assign(literal l, justification j) { - TRACE("nlsat", + TRACE("nlsat_assign", display(tout << "assigning literal: ", l); display(tout << " <- ", j);); @@ -1254,7 +1272,9 @@ namespace nlsat { m_ism.get_justifications(s, core, clauses); if (include_l) core.push_back(~l); - assign(l, mk_lazy_jst(m_allocator, core.size(), core.data(), clauses.size(), clauses.data())); + auto j = mk_lazy_jst(m_allocator, core.size(), core.data(), clauses.size(), clauses.data()); + TRACE("nlsat_resolve", display(tout, j); display_eval(tout << "evaluated:", j)); + assign(l, j); SASSERT(value(l) == l_true); } @@ -1364,7 +1384,9 @@ namespace nlsat { tmp = m_ism.mk_union(curr_set, xk_set); if (m_ism.is_full(tmp)) { TRACE("nlsat_inf_set", tout << "infeasible set + current set = R, skip literal\n"; - display(tout, cls) << "\n";); + display(tout, cls) << "\n"; + m_ism.display(tout, tmp); tout << "\n"; + ); R_propagate(~l, tmp, false); continue; } @@ -1454,7 +1476,7 @@ namespace nlsat { void select_witness() { scoped_anum w(m_am); SASSERT(!m_ism.is_full(m_infeasible[m_xk])); - m_ism.peek_in_complement(m_infeasible[m_xk], m_is_int[m_xk], w, m_randomize); + m_ism.peek_in_complement(m_infeasible[m_xk], is_int(m_xk), w, m_randomize); TRACE("nlsat", tout << "infeasible intervals: "; m_ism.display(tout, m_infeasible[m_xk]); tout << "\n"; tout << "assigning "; m_display_var(tout, m_xk) << "(x" << m_xk << ") -> " << w << "\n";); @@ -1491,6 +1513,7 @@ namespace nlsat { m_bk = 0; m_xk = null_var; m_conflicts = 0; + m_next_conflict = 100; while (true) { CASSERT("nlsat", check_satisfied()); @@ -1527,6 +1550,7 @@ namespace nlsat { return l_false; if (m_conflicts >= m_max_conflicts) return l_undef; + log(); } if (m_xk == null_var) { @@ -1541,6 +1565,14 @@ namespace nlsat { } } + unsigned m_next_conflict = 100; + void log() { + if (m_conflicts < m_next_conflict) + return; + m_next_conflict += 100; + IF_VERBOSE(2, verbose_stream() << "(nlsat :conflicts " << m_conflicts << " :decisions " << m_decisions << " :propagations " << m_propagations << " :clauses " << m_clauses.size() << " :learned " << m_learned.size() << ")\n"); + } + lbool search_check() { lbool r = l_undef; @@ -1550,7 +1582,7 @@ namespace nlsat { vector> bounds; for (var x = 0; x < num_vars(); x++) { - if (m_is_int[x] && m_assignment.is_assigned(x) && !m_am.is_int(m_assignment.value(x))) { + if (is_int(x) && m_assignment.is_assigned(x) && !m_am.is_int(m_assignment.value(x))) { scoped_anum v(m_am), vlo(m_am); v = m_assignment.value(x); rational lo; @@ -1624,7 +1656,7 @@ namespace nlsat { restore_order(); } CTRACE("nlsat_model", r == l_true, tout << "model\n"; display_assignment(tout);); - CTRACE("nlsat", r == l_false, display(tout);); + CTRACE("nlsat", r == l_false, display(tout << "unsat\n");); SASSERT(r != l_true || check_satisfied(m_clauses)); return r; } @@ -1795,7 +1827,7 @@ namespace nlsat { } void resolve_clause(bool_var b, clause const & c) { - TRACE("nlsat_resolve", tout << "resolving clause for b: " << b << "\n"; display(tout, c) << "\n";); + TRACE("nlsat_resolve", tout << "resolving clause "; if (b != null_bool_var) tout << "for b: " << b << "\n"; display(tout, c) << "\n";); resolve_clause(b, c.size(), c.data()); m_lemma_assumptions = m_asm.mk_join(static_cast<_assumption_set>(c.assumptions()), m_lemma_assumptions); } @@ -1827,11 +1859,17 @@ namespace nlsat { tout << "new valid clause:\n"; display(tout, m_lazy_clause.size(), m_lazy_clause.data()) << "\n";); + + if (m_log_lemmas) + log_lemma(verbose_stream(), m_lazy_clause.size(), m_lazy_clause.data(), true); + if (m_check_lemmas) { + check_lemma(m_lazy_clause.size(), m_lazy_clause.data(), true, nullptr); m_valids.push_back(mk_clause_core(m_lazy_clause.size(), m_lazy_clause.data(), false, nullptr)); } - DEBUG_CODE({ +#ifdef Z3DEBUG + { unsigned sz = m_lazy_clause.size(); for (unsigned i = 0; i < sz; i++) { literal l = m_lazy_clause[i]; @@ -1844,7 +1882,8 @@ namespace nlsat { SASSERT(l.sign() || m_bvalues[b] == l_true); } } - }); + } +#endif checkpoint(); resolve_clause(b, m_lazy_clause.size(), m_lazy_clause.data()); @@ -2066,9 +2105,12 @@ namespace nlsat { TRACE("nlsat", tout << "new lemma:\n"; display(tout, m_lemma.size(), m_lemma.data()); tout << "\n"; tout << "found_decision: " << found_decision << "\n";); - if (false && m_check_lemmas) { + if (m_check_lemmas) { check_lemma(m_lemma.size(), m_lemma.data(), false, m_lemma_assumptions.get()); } + + if (m_log_lemmas) + log_lemma(verbose_stream(), m_lemma.size(), m_lemma.data(), false); // There are two possibilities: // 1) m_lemma contains only literals from previous stages, and they @@ -2151,19 +2193,20 @@ namespace nlsat { // ----------------------- bool check_watches() const { - DEBUG_CODE( - for (var x = 0; x < num_vars(); x++) { +#ifdef Z3DEBUG + for (var x = 0; x < num_vars(); x++) { clause_vector const & cs = m_watches[x]; unsigned sz = cs.size(); for (unsigned i = 0; i < sz; i++) { SASSERT(max_var(*(cs[i])) == x); } - }); + } +#endif return true; } bool check_bwatches() const { - DEBUG_CODE( +#ifdef Z3DEBUG for (bool_var b = 0; b < m_bwatches.size(); b++) { clause_vector const & cs = m_bwatches[b]; unsigned sz = cs.size(); @@ -2172,7 +2215,8 @@ namespace nlsat { SASSERT(max_var(c) == null_var); SASSERT(max_bvar(c) == b); } - }); + } +#endif return true; } @@ -2361,13 +2405,9 @@ namespace nlsat { } bool can_reorder() const { - for (clause* c : m_learned) { - if (has_root_atom(*c)) return false; - } - for (clause* c : m_clauses) { - if (has_root_atom(*c)) return false; - } - return m_patch_var.empty(); + return m_patch_var.empty() + && all_of(m_learned, [&](clause* c) { return !has_root_atom(*c); }) + && all_of(m_clauses, [&](clause* c) { return !has_root_atom(*c); }); } /** @@ -2396,11 +2436,11 @@ namespace nlsat { // undo_until_size(0) undo_until_stage(null_var); m_cache.reset(); - DEBUG_CODE({ - for (var x = 0; x < num_vars(); x++) { - SASSERT(m_watches[x].empty()); - } - }); +#ifdef Z3DEBUG + for (var x = 0; x < num_vars(); x++) { + SASSERT(m_watches[x].empty()); + } +#endif // update m_perm mapping for (unsigned ext_x = 0; ext_x < sz; ext_x++) { // p: internal -> new pos @@ -2416,12 +2456,12 @@ namespace nlsat { SASSERT(m_infeasible[x] == 0); } m_inv_perm.swap(new_inv_perm); - DEBUG_CODE({ - for (var x = 0; x < num_vars(); x++) { - SASSERT(x == m_inv_perm[m_perm[x]]); - SASSERT(m_watches[x].empty()); - } - }); +#ifdef Z3DEBUG + for (var x = 0; x < num_vars(); x++) { + SASSERT(x == m_inv_perm[m_perm[x]]); + SASSERT(m_watches[x].empty()); + } +#endif m_pm.rename(sz, p); TRACE("nlsat_bool_assignment_bug", tout << "before reinit cache\n"; display_bool_assignment(tout);); reinit_cache(); @@ -2441,12 +2481,12 @@ namespace nlsat { var_vector p; p.append(m_perm); reorder(p.size(), p.data()); - DEBUG_CODE({ - for (var x = 0; x < num_vars(); x++) { - SASSERT(m_perm[x] == x); - SASSERT(m_inv_perm[x] == x); - } - }); +#ifdef Z3DEBUG + for (var x = 0; x < num_vars(); x++) { + SASSERT(m_perm[x] == x); + SASSERT(m_inv_perm[x] == x); + } +#endif } /** @@ -2890,7 +2930,8 @@ namespace nlsat { var mx = max_var(p0); if (mx >= m_is_int.size()) return false; for (var x = 0; x <= mx; ++x) { - if (m_is_int[x]) continue; + if (is_int(x)) + continue; if (1 == m_pm.degree(p0, x)) { p = m_pm.coeff(p0, x, 1, q); if (!m_pm.is_const(p)) @@ -2934,10 +2975,10 @@ namespace nlsat { unsigned sz = m_atoms.size(); for (bool_var b = 0; b < sz; b++) { if (m_atoms[b] == nullptr && m_bvalues[b] != l_undef) { - out << "b" << b << " -> " << (m_bvalues[b] == l_true ? "true" : "false") << "\n"; + out << "b" << b << " -> " << (m_bvalues[b] == l_true ? "true" : "false") << " @" << m_levels[b] << "\n"; } else if (m_atoms[b] != nullptr && m_bvalues[b] != l_undef) { - display(out << "b" << b << " ", *m_atoms[b]) << " -> " << (m_bvalues[b] == l_true ? "true" : "false") << "\n"; + display(out << "b" << b << " ", *m_atoms[b]) << " -> " << (m_bvalues[b] == l_true ? "true" : "false") << " @" << m_levels[b] << "\n"; } } TRACE("nlsat_bool_assignment", @@ -2993,8 +3034,14 @@ namespace nlsat { } return out; } + + bool m_display_eval = false; + std::ostream& display_eval(std::ostream& out, justification j) { + flet _display(m_display_eval, true); + return display(out, j); + } - std::ostream& display(std::ostream & out, ineq_atom const & a, display_var_proc const & proc, bool use_star = false) const { + std::ostream& display_ineq(std::ostream & out, ineq_atom const & a, display_var_proc const & proc, bool use_star = false) const { unsigned sz = a.size(); for (unsigned i = 0; i < sz; i++) { if (use_star && i > 0) @@ -3002,7 +3049,7 @@ namespace nlsat { bool is_even = a.is_even(i); if (is_even || sz > 1) out << "("; - m_pm.display(out, a.p(i), proc, use_star); + display_polynomial(out, a.p(i), proc, use_star); if (is_even || sz > 1) out << ")"; if (is_even) @@ -3042,7 +3089,12 @@ namespace nlsat { return out; } - std::ostream& display_smt2(std::ostream & out, ineq_atom const & a, display_var_proc const & proc) const { + std::ostream& display_polynomial_smt2(std::ostream & out, poly const* p, display_var_proc const & proc) const { + m_pm.display_smt2(out, p, proc); + return out; + } + + std::ostream& display_ineq_smt2(std::ostream & out, ineq_atom const & a, display_var_proc const & proc) const { switch (a.get_kind()) { case atom::LT: out << "(< "; break; case atom::GT: out << "(> "; break; @@ -3056,13 +3108,13 @@ namespace nlsat { if (i > 0) out << " "; if (a.is_even(i)) { out << "(* "; - m_pm.display_smt2(out, a.p(i), proc); + display_polynomial_smt2(out, a.p(i), proc); out << " "; - m_pm.display_smt2(out, a.p(i), proc); + display_polynomial_smt2(out, a.p(i), proc); out << ")"; } else { - m_pm.display_smt2(out, a.p(i), proc); + display_polynomial_smt2(out, a.p(i), proc); } } if (sz > 1) @@ -3071,11 +3123,112 @@ namespace nlsat { return out; } - std::ostream& display_smt2(std::ostream & out, root_atom const & a, display_var_proc const & proc) const { - return display(out, a, proc); + std::ostream& display_poly_root(std::ostream& out, char const* y, root_atom const& a, display_var_proc const& proc) const { + out << "(exists (("; proc(out,a.x()); out << " Real))\n"; + out << "(and (= " << y << " "; + proc(out, a.x()); + out << ") (= 0 "; + display_polynomial_smt2(out, a.p(), proc); + out << ")))\n"; + return out; } + + std::ostream& display_binary_smt2(std::ostream& out, poly const* p1, char const* rel, poly const* p2, display_var_proc const& proc) const { + out << "(" << rel << " "; + display_polynomial_smt2(out, p1, proc); + out << " "; + display_polynomial_smt2(out, p2, proc); + out << ")"; + return out; + } + - std::ostream& display(std::ostream & out, root_atom const & a, display_var_proc const & proc) const { + std::ostream& display_linear_root_smt2(std::ostream & out, root_atom const & a, display_var_proc const & proc) const { + polynomial_ref A(m_pm), B(m_pm), Z(m_pm), Ax(m_pm); + polynomial::scoped_numeral zero(m_qm); + m_pm.m().set(zero, 0); + A = m_pm.derivative(a.p(), a.x()); + B = m_pm.neg(m_pm.substitute(a.p(), a.x(), zero)); + Z = m_pm.mk_zero(); + + Ax = m_pm.mul(m_pm.mk_polynomial(a.x()), A); + + // x < root[1](ax + b) == (a > 0 => ax + b < 0) & (a < 0 => ax + b > 0) + // x < root[1](ax + b) == (a > 0 => ax < -b) & (a < 0 => ax > -b) + + char const* rel1 = "<", *rel2 = ">"; + switch (a.get_kind()) { + case atom::ROOT_LT: rel1 = "<"; rel2 = ">"; break; + case atom::ROOT_GT: rel1 = ">"; rel2 = "<"; break; + case atom::ROOT_LE: rel1 = "<="; rel2 = ">="; break; + case atom::ROOT_GE: rel1 = ">="; rel2 = "<="; break; + case atom::ROOT_EQ: rel1 = rel2 = "="; break; + default: UNREACHABLE(); break; + } + + out << "(and "; + out << "(=> "; display_binary_smt2(out, A, ">", Z, proc); display_binary_smt2(out, Ax, rel1, B, proc); out << ") "; + out << "(=> "; display_binary_smt2(out, A, "<", Z, proc); display_binary_smt2(out, Ax, rel2, B, proc); out << ") "; + out << ")"; + + return out; + } + + + std::ostream& display_root_smt2(std::ostream& out, root_atom const& a, display_var_proc const& proc) const { + if (a.i() == 1 && m_pm.degree(a.p(), a.x()) == 1) + return display_linear_root_smt2(out, a, proc); +#if 1 + out << "(exists ("; + for (unsigned j = 0; j < a.i(); ++j) { + std::string y = std::string("y") + std::to_string(j); + out << "(" << y << " Real) "; + } + out << ")\n"; + out << "(and\n"; + for (unsigned j = 0; j < a.i(); ++j) { + std::string y = std::string("y") + std::to_string(j); + display_poly_root(out, y.c_str(), a, proc); + } + for (unsigned j = 0; j + 1 < a.i(); ++j) { + std::string y1 = std::string("y") + std::to_string(j); + std::string y2 = std::string("y") + std::to_string(j+1); + out << "(< " << y1 << " " << y2 << ")\n"; + } + + std::string yn = "y" + std::to_string(a.i() - 1); + + // TODO we need (forall z : z < yn . p(z) => z = y1 or ... z = y_{n-1}) + // to say y1, .., yn are the first n distinct roots. + // + out << "(forall ((z Real)) (=> (and (< z " << yn << ") "; display_poly_root(out, "z", a, proc) << ") "; + if (a.i() == 1) { + out << "false))\n"; + } + else { + out << "(or "; + for (unsigned j = 0; j + 1 < a.i(); ++j) { + std::string y1 = std::string("y") + std::to_string(j); + out << "(= z " << y1 << ") "; + } + out << ")))\n"; + } + switch (a.get_kind()) { + case atom::ROOT_LT: out << "(< "; proc(out, a.x()); out << " " << yn << ")"; break; + case atom::ROOT_GT: out << "(> "; proc(out, a.x()); out << " " << yn << ")"; break; + case atom::ROOT_LE: out << "(<= "; proc(out, a.x()); out << " " << yn << ")"; break; + case atom::ROOT_GE: out << "(>= "; proc(out, a.x()); out << " " << yn << ")"; break; + case atom::ROOT_EQ: out << "(= "; proc(out, a.x()); out << " " << yn << ")"; NOT_IMPLEMENTED_YET(); break; + } + out << "))"; + return out; +#endif + + + return display_root(out, a, proc); + } + + std::ostream& display_root(std::ostream & out, root_atom const & a, display_var_proc const & proc) const { proc(out, a.x()); switch (a.get_kind()) { case atom::ROOT_LT: out << " < "; break; @@ -3086,7 +3239,7 @@ namespace nlsat { default: UNREACHABLE(); break; } out << "root[" << a.i() << "]("; - m_pm.display(out, a.p(), proc); + display_polynomial(out, a.p(), proc); out << ")"; return out; } @@ -3114,21 +3267,16 @@ namespace nlsat { default: UNREACHABLE(); break; } out << "Root["; - m_pm.display(out, a.p(), mathematica_var_proc(a.x()), true); + display_polynomial(out, a.p(), mathematica_var_proc(a.x()), true); out << " &, " << a.i() << "]"; return out; } - - std::ostream& display_smt2(std::ostream & out, root_atom const & a) const { - NOT_IMPLEMENTED_YET(); - return out; - } std::ostream& display(std::ostream & out, atom const & a, display_var_proc const & proc) const { if (a.is_ineq_atom()) - return display(out, static_cast(a), proc); + return display_ineq(out, static_cast(a), proc); else - return display(out, static_cast(a), proc); + return display_root(out, static_cast(a), proc); } std::ostream& display(std::ostream & out, atom const & a) const { @@ -3144,9 +3292,9 @@ namespace nlsat { std::ostream& display_smt2(std::ostream & out, atom const & a, display_var_proc const & proc) const { if (a.is_ineq_atom()) - return display_smt2(out, static_cast(a), proc); + return display_ineq_smt2(out, static_cast(a), proc); else - return display_smt2(out, static_cast(a), proc); + return display_root_smt2(out, static_cast(a), proc); } std::ostream& display_atom(std::ostream & out, bool_var b, display_var_proc const & proc) const { @@ -3289,6 +3437,34 @@ namespace nlsat { return display(out, c, m_display_var); } + + std::ostream& display_polynomial(std::ostream& out, poly* p, display_var_proc const & proc, bool use_star = false) const { + if (m_display_eval) { + polynomial_ref q(m_pm); + q = p; + for (var x = 0; x < num_vars(); x++) + if (m_assignment.is_assigned(x)) { + auto& a = m_assignment.value(x); + if (!m_am.is_rational(a)) + continue; + mpq r; + m_am.to_rational(a, r); + q = m_pm.substitute(q, 1, &x, &r); + } + m_pm.display(out, q, proc, use_star); + } + else + m_pm.display(out, p, proc, use_star); + return out; + } + + // -- + + std::ostream& display_smt2(std::ostream & out, unsigned n, literal const* ls) const { + return display_smt2(out, n, ls, display_var_proc()); + } + + std::ostream& display_smt2(std::ostream & out, unsigned num, literal const * ls, display_var_proc const & proc) const { if (num == 0) { out << "false"; @@ -3445,7 +3621,7 @@ namespace nlsat { std::ostream& display_smt2_arith_decls(std::ostream & out) const { unsigned sz = m_is_int.size(); for (unsigned i = 0; i < sz; i++) { - if (m_is_int[i]) + if (is_int(i)) out << "(declare-fun x" << i << " () Int)\n"; else out << "(declare-fun x" << i << " () Real)\n"; diff --git a/src/nlsat/tactic/nlsat_tactic.cpp b/src/nlsat/tactic/nlsat_tactic.cpp index 3baf9da8765..9426de78ee7 100644 --- a/src/nlsat/tactic/nlsat_tactic.cpp +++ b/src/nlsat/tactic/nlsat_tactic.cpp @@ -125,7 +125,9 @@ class nlsat_tactic : public tactic { continue; // don't care md->register_decl(to_app(a)->get_decl(), val == l_true ? m.mk_true() : m.mk_false()); } - DEBUG_CODE(eval_model(*md.get(), g);); +#ifdef Z3DEBUG + eval_model(*md.get(), g); +#endif // VERIFY(eval_model(*md.get(), g)); mc = model2model_converter(md.get()); return ok; diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 11eddc2eb08..1b57a7200aa 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -213,7 +213,7 @@ namespace opt { void context::add_hard_constraint(expr* f, expr* t) { if (m_calling_on_model) - throw default_exception("adding soft constraints is not supported during callbacks"); + throw default_exception("adding hard constraints is not supported during callbacks"); m_scoped_state.m_asms.push_back(t); m_scoped_state.add(m.mk_implies(t, f)); clear_state(); @@ -905,12 +905,14 @@ namespace opt { ptr_vector deps; expr_dependency_ref core(r->dep(i), m); m.linearize(core, deps); - if (!deps.empty()) { - fmls.push_back(m.mk_implies(m.mk_and(deps.size(), deps.data()), r->form(i))); - } - else { + if (deps.empty()) + fmls.push_back(r->form(i)); + else if (deps.size() == 1 && deps[0] == r->form(i)) + continue; + else if (is_objective(r->form(i))) fmls.push_back(r->form(i)); - } + else + fmls.push_back(m.mk_implies(mk_and(m, deps.size(), deps.data()), r->form(i))); } if (r->inconsistent()) { ptr_vector core_elems; @@ -920,6 +922,10 @@ namespace opt { } } + bool context::is_objective(expr* fml) { + return is_app(fml) && m_objective_fns.contains(to_app(fml)->get_decl()); + } + bool context::is_maximize(expr* fml, app_ref& term, expr_ref& orig_term, unsigned& index) { if (is_app(fml) && m_objective_fns.find(to_app(fml)->get_decl(), index) && m_objectives[index].m_type == O_MAXIMIZE) { diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index 9e61ae92cd6..4e791531e3f 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -303,6 +303,7 @@ namespace opt { void import_scoped_state(); void normalize(expr_ref_vector const& asms); void internalize(); + bool is_objective(expr* fml); bool is_maximize(expr* fml, app_ref& term, expr_ref& orig_term, unsigned& index); bool is_minimize(expr* fml, app_ref& term, expr_ref& orig_term, unsigned& index); bool is_maxsat(expr* fml, expr_ref_vector& terms, diff --git a/src/opt/opt_lns.cpp b/src/opt/opt_lns.cpp index ba5a071d303..878c4a3ea34 100644 --- a/src/opt/opt_lns.cpp +++ b/src/opt/opt_lns.cpp @@ -196,7 +196,7 @@ namespace opt { } unsigned lns::improve_linear(model_ref& mdl) { - scoped_bounding _scoped_bouding(*this); + scoped_bounding _scoped_bounding(*this); unsigned num_improved = 0; unsigned max_conflicts = m_max_conflicts; while (m.inc()) { diff --git a/src/opt/opt_params.pyg b/src/opt/opt_params.pyg index 893b4bfd68a..15f5e5d0525 100644 --- a/src/opt/opt_params.pyg +++ b/src/opt/opt_params.pyg @@ -3,7 +3,7 @@ def_module_params('opt', export=True, params=(('optsmt_engine', SYMBOL, 'basic', "select optimization engine: 'basic', 'symba'"), ('maxsat_engine', SYMBOL, 'maxres', "select engine for maxsat: 'core_maxsat', 'wmax', 'maxres', 'pd-maxres', 'maxres-bin', 'rc2'"), - ('priority', SYMBOL, 'lex', "select how to priortize objectives: 'lex' (lexicographic), 'pareto', 'box'"), + ('priority', SYMBOL, 'lex', "select how to prioritize objectives: 'lex' (lexicographic), 'pareto', 'box'"), ('dump_benchmarks', BOOL, False, 'dump benchmarks for profiling'), ('dump_models', BOOL, False, 'display intermediary models to stdout'), ('solution_prefix', SYMBOL, '', "path prefix to dump intermediary, but non-optimal, solutions"), @@ -22,7 +22,7 @@ def_module_params('opt', ('maxlex.enable', BOOL, True, 'enable maxlex heuristic for lexicographic MaxSAT problems'), ('rc2.totalizer', BOOL, True, 'use totalizer for rc2 encoding'), ('maxres.hill_climb', BOOL, True, 'give preference for large weight cores'), - ('maxres.add_upper_bound_block', BOOL, False, 'restict upper bound with constraint'), + ('maxres.add_upper_bound_block', BOOL, False, 'restrict upper bound with constraint'), ('maxres.max_num_cores', UINT, 200, 'maximal number of cores per round'), ('maxres.max_core_size', UINT, 3, 'break batch of generated cores if size reaches this number'), ('maxres.maximize_assignment', BOOL, False, 'find an MSS/MCS to improve current assignment'), diff --git a/src/opt/opt_solver.cpp b/src/opt/opt_solver.cpp index ee91b06a491..955025fd96c 100644 --- a/src/opt/opt_solver.cpp +++ b/src/opt/opt_solver.cpp @@ -52,7 +52,7 @@ namespace opt { if (m_params.m_case_split_strategy == CS_ACTIVITY_DELAY_NEW) { m_params.m_relevancy_lvl = 0; } - m_params.m_arith_auto_config_simplex = false; + m_params.m_arith_auto_config_simplex = true; m_params.m_threads = 1; // need to interact with the solver that created model so can't have threads // m_params.m_auto_config = false; } @@ -67,7 +67,7 @@ namespace opt { m_dump_benchmarks = p.dump_benchmarks(); m_params.updt_params(_p); m_context.updt_params(_p); - m_params.m_arith_auto_config_simplex = false; + m_params.m_arith_auto_config_simplex = true; } solver* opt_solver::translate(ast_manager& m, params_ref const& p) { diff --git a/src/params/CMakeLists.txt b/src/params/CMakeLists.txt index cdc21da979a..763702caf56 100644 --- a/src/params/CMakeLists.txt +++ b/src/params/CMakeLists.txt @@ -15,6 +15,7 @@ z3_add_component(params poly_rewriter_params.pyg rewriter_params.pyg seq_rewriter_params.pyg + sls_params.pyg solver_params.pyg tactic_params.pyg EXTRA_REGISTER_MODULE_HEADERS diff --git a/src/params/bool_rewriter_params.pyg b/src/params/bool_rewriter_params.pyg index c8d7ddbb74b..87578470eb4 100644 --- a/src/params/bool_rewriter_params.pyg +++ b/src/params/bool_rewriter_params.pyg @@ -4,6 +4,7 @@ def_module_params(module_name='rewriter', params=(("ite_extra_rules", BOOL, True, "extra ite simplifications, these additional simplifications may reduce size locally but increase globally"), ("flat", BOOL, True, "create nary applications for +,*,bvadd,bvmul,bvand,bvor,bvxor"), ("flat_and_or", BOOL, True, "create nary applications for and,or"), + ("sort_disjunctions", BOOL, True, "sort subterms in disjunctions"), ("elim_and", BOOL, False, "conjunctions are rewritten using negation and disjunctions"), ('elim_ite', BOOL, True, "eliminate ite in favor of and/or"), ("local_ctx", BOOL, False, "perform local (i.e., cheap) context simplifications"), diff --git a/src/params/context_params.cpp b/src/params/context_params.cpp index fbdd90b8ce4..1d5d10b399c 100644 --- a/src/params/context_params.cpp +++ b/src/params/context_params.cpp @@ -51,15 +51,18 @@ void context_params::set_uint(unsigned & opt, char const * param, char const * v } } -void context_params::set(char const * param, char const * value) { - std::string p = param; - unsigned n = static_cast(p.size()); - for (unsigned i = 0; i < n; i++) { +static void lower_case(std::string& p) { + for (size_t i = 0; i < p.size(); i++) { if (p[i] >= 'A' && p[i] <= 'Z') p[i] = p[i] - 'A' + 'a'; else if (p[i] == '-') p[i] = '_'; } +} + +void context_params::set(char const * param, char const * value) { + std::string p = param; + lower_case(p); if (p == "timeout") { set_uint(m_timeout, param, value); } @@ -195,5 +198,15 @@ void context_params::get_solver_params(params_ref & p, bool & proofs_enabled, bo p.set_bool("auto_config", false); } +bool context_params::is_shell_only_parameter(char const* _p) const { + std::string p(_p); + lower_case(p); + if (p == "dump_models" || p == "well_sorted_check" || + p == "model_validate" || p == "smtlib2_compliant" || + p == "stats") + return true; + + return false; +} diff --git a/src/params/context_params.h b/src/params/context_params.h index ad4cee31df3..1f169a5bdda 100644 --- a/src/params/context_params.h +++ b/src/params/context_params.h @@ -70,5 +70,10 @@ class context_params { */ params_ref merge_default_params(params_ref const & p); + /** + \brief Is this a parameter that can only be set for the shell. + */ + bool is_shell_only_parameter(char const* p) const; + }; diff --git a/src/params/pattern_inference_params.cpp b/src/params/pattern_inference_params.cpp index 26f606b635e..0e548c89612 100644 --- a/src/params/pattern_inference_params.cpp +++ b/src/params/pattern_inference_params.cpp @@ -21,8 +21,10 @@ Revision History: void pattern_inference_params::updt_params(params_ref const & _p) { pattern_inference_params_helper p(_p); + m_pi_enabled = p.enabled(); m_pi_max_multi_patterns = p.max_multi_patterns(); m_pi_block_loop_patterns = p.block_loop_patterns(); + m_pi_decompose_patterns = p.decompose_patterns(); m_pi_arith = static_cast(p.arith()); m_pi_use_database = p.use_database(); m_pi_arith_weight = p.arith_weight(); @@ -34,8 +36,10 @@ void pattern_inference_params::updt_params(params_ref const & _p) { #define DISPLAY_PARAM(X) out << #X"=" << X << '\n'; void pattern_inference_params::display(std::ostream & out) const { + DISPLAY_PARAM(m_pi_enabled); DISPLAY_PARAM(m_pi_max_multi_patterns); DISPLAY_PARAM(m_pi_block_loop_patterns); + DISPLAY_PARAM(m_pi_decompose_patterns); DISPLAY_PARAM(m_pi_arith); DISPLAY_PARAM(m_pi_use_database); DISPLAY_PARAM(m_pi_arith_weight); diff --git a/src/params/pattern_inference_params.h b/src/params/pattern_inference_params.h index d0510075914..e558a6a7b63 100644 --- a/src/params/pattern_inference_params.h +++ b/src/params/pattern_inference_params.h @@ -27,20 +27,20 @@ enum arith_pattern_inference_kind { }; struct pattern_inference_params { - unsigned m_pi_max_multi_patterns; + bool m_pi_enabled = true; + unsigned m_pi_max_multi_patterns = 1; bool m_pi_block_loop_patterns; + bool m_pi_decompose_patterns; arith_pattern_inference_kind m_pi_arith; bool m_pi_use_database; unsigned m_pi_arith_weight; unsigned m_pi_non_nested_arith_weight; bool m_pi_pull_quantifiers; - int m_pi_nopat_weight; - bool m_pi_avoid_skolems; + int m_pi_nopat_weight = -1; + bool m_pi_avoid_skolems = true; bool m_pi_warnings; - pattern_inference_params(params_ref const & p = params_ref()): - m_pi_nopat_weight(-1), - m_pi_avoid_skolems(true) { + pattern_inference_params(params_ref const & p = params_ref()) { updt_params(p); } diff --git a/src/params/pattern_inference_params_helper.pyg b/src/params/pattern_inference_params_helper.pyg index 52c6c653e2e..80d36e3ec1c 100644 --- a/src/params/pattern_inference_params_helper.pyg +++ b/src/params/pattern_inference_params_helper.pyg @@ -4,8 +4,10 @@ def_module_params(class_name='pattern_inference_params_helper', export=True, params=(('max_multi_patterns', UINT, 0, 'when patterns are not provided, the prover uses a heuristic to infer them, this option sets the threshold on the number of extra multi-patterns that can be created; by default, the prover creates at most one multi-pattern when there is no unary pattern'), ('block_loop_patterns', BOOL, True, 'block looping patterns during pattern inference'), + ('decompose_patterns', BOOL, True, 'allow decomposition of patterns into multipatterns'), ('arith', UINT, 1, '0 - do not infer patterns with arithmetic terms, 1 - use patterns with arithmetic terms if there is no other pattern, 2 - always use patterns with arithmetic terms'), ('use_database', BOOL, False, 'use pattern database'), + ('enabled', BOOL, True, 'enable a heuristic to infer patterns, when they are not provided'), ('arith_weight', UINT, 5, 'default weight for quantifiers where the only available pattern has nested arithmetic terms'), ('non_nested_arith_weight', UINT, 10, 'default weight for quantifiers where the only available pattern has non nested arithmetic terms'), ('pull_quantifiers', BOOL, True, 'pull nested quantifiers, if no pattern was found'), diff --git a/src/params/rewriter_params.pyg b/src/params/rewriter_params.pyg index 290f7b1dafe..20490606cd8 100644 --- a/src/params/rewriter_params.pyg +++ b/src/params/rewriter_params.pyg @@ -8,6 +8,7 @@ def_module_params('rewriter', ("pull_cheap_ite", BOOL, False, "pull if-then-else terms when cheap."), ("bv_ineq_consistency_test_max", UINT, 0, "max size of conjunctions on which to perform consistency test based on inequalities on bitvectors."), ("cache_all", BOOL, False, "cache all intermediate results."), + ("enable_der", BOOL, True, "enable destructive equality resolution to quantifiers."), ("rewrite_patterns", BOOL, False, "rewrite patterns."), ("ignore_patterns_on_ground_qbody", BOOL, True, "ignores patterns on quantifiers that don't mention their bound variables."))) diff --git a/src/tactic/sls/sls_params.pyg b/src/params/sls_params.pyg similarity index 100% rename from src/tactic/sls/sls_params.pyg rename to src/params/sls_params.pyg diff --git a/src/parsers/smt2/smt2parser.cpp b/src/parsers/smt2/smt2parser.cpp index 081a8c83892..98323816ace 100644 --- a/src/parsers/smt2/smt2parser.cpp +++ b/src/parsers/smt2/smt2parser.cpp @@ -102,6 +102,7 @@ namespace smt2 { symbol m_declare_const; symbol m_define_sort; symbol m_declare_sort; + symbol m_declare_type_var; symbol m_declare_datatypes; symbol m_declare_datatype; symbol m_par; @@ -856,6 +857,26 @@ namespace smt2 { if (ct_decls.empty()) throw parser_exception("invalid datatype declaration, datatype does not have any constructors"); } + + void parse_declare_type_var() { + SASSERT(curr_is_identifier()); + SASSERT(curr_id() == m_declare_type_var); + next(); + + check_nonreserved_identifier("invalid sort declaration, symbol expected"); + symbol id = curr_id(); + if (m_ctx.find_psort_decl(id) != nullptr) + throw parser_exception("invalid sort declaration, sort already declared/defined"); + next(); + check_rparen("invalid sort declaration, ')' expected"); + + psort_decl * decl = pm().mk_psort_type_var_decl(id); + m_ctx.insert(decl); + + m_ctx.print_success(); + next(); + + } void parse_declare_datatypes() { SASSERT(curr_is_identifier()); @@ -942,6 +963,7 @@ namespace smt2 { unsigned line = m_scanner.get_line(); unsigned pos = m_scanner.get_pos(); symbol dt_name = curr_id(); + check_identifier("unexpected token used as datatype name"); next(); m_dt_name2idx.reset(); @@ -2975,6 +2997,10 @@ namespace smt2 { parse_declare_sort(); return; } + if (s == m_declare_type_var) { + parse_declare_type_var(); + return; + } if (s == m_declare_datatypes) { parse_declare_datatypes(); return; @@ -3048,6 +3074,7 @@ namespace smt2 { m_declare_const("declare-const"), m_define_sort("define-sort"), m_declare_sort("declare-sort"), + m_declare_type_var("declare-type-var"), m_declare_datatypes("declare-datatypes"), m_declare_datatype("declare-datatype"), m_par("par"), diff --git a/src/parsers/smt2/smt2scanner.cpp b/src/parsers/smt2/smt2scanner.cpp index 2fb45db3b80..8beb782c114 100644 --- a/src/parsers/smt2/smt2scanner.cpp +++ b/src/parsers/smt2/smt2scanner.cpp @@ -106,8 +106,13 @@ namespace smt2 { TRACE("scanner", tout << "new quoted symbol: " << m_id << "\n";); return SYMBOL_TOKEN; } - escape = (c == '\\'); - m_string.push_back(c); + else if (c != '|' && c != '\\' && escape) { + m_string.push_back('\\'); + } + + escape = (c == '\\') && !escape; + if (!escape) + m_string.push_back(c); next(); } } diff --git a/src/qe/lite/CMakeLists.txt b/src/qe/lite/CMakeLists.txt index fc942d4aebc..9b9f5a45aa5 100644 --- a/src/qe/lite/CMakeLists.txt +++ b/src/qe/lite/CMakeLists.txt @@ -1,9 +1,11 @@ z3_add_component(qe_lite SOURCES qe_lite_tactic.cpp + qel.cpp COMPONENT_DEPENDENCIES tactic mbp TACTIC_HEADERS qe_lite_tactic.h + qel.h ) diff --git a/src/qe/lite/qe_lite_tactic.cpp b/src/qe/lite/qe_lite_tactic.cpp index 32d11786cea..03ebc8c4eb2 100644 --- a/src/qe/lite/qe_lite_tactic.cpp +++ b/src/qe/lite/qe_lite_tactic.cpp @@ -487,7 +487,7 @@ namespace qel { ptr_vector vs; expr_ref_vector ts(m); expr_ref t(m); - if (is_var_def(is_exists, args[i], vs, ts)) { + if (is_var_def(is_exists, args[i], vs, ts)) { // vs is the variable, ts is the definition for (unsigned j = 0; j < vs.size(); ++j) { var* v = vs[j]; t = ts.get(j); @@ -2376,7 +2376,7 @@ class qe_lite::impl { m_array_der.set_is_variable_proc(is_var); m_der(fmls); m_fm(fmls); - // AG: disalble m_array_der() since it interferes with other array handling + // AG: disable m_array_der() since it interferes with other array handling if (m_use_array_der) m_array_der(fmls); TRACE("qe_lite", for (unsigned i = 0; i < fmls.size(); ++i) tout << mk_pp(fmls[i].get(), m) << "\n";); } @@ -2392,7 +2392,7 @@ qe_lite::~qe_lite() { } void qe_lite::operator()(app_ref_vector& vars, expr_ref& fml) { - (*m_impl)(vars, fml); + (*m_impl)(vars, fml); } diff --git a/src/qe/lite/qe_lite_tactic.h b/src/qe/lite/qe_lite_tactic.h index 07ce60f35b8..e45c6f7528a 100644 --- a/src/qe/lite/qe_lite_tactic.h +++ b/src/qe/lite/qe_lite_tactic.h @@ -30,6 +30,7 @@ class tactic; class qe_lite { class impl; impl * m_impl; + public: /** use_array_der controls whether equalities over array reads are simplified diff --git a/src/qe/lite/qel.cpp b/src/qe/lite/qel.cpp new file mode 100644 index 00000000000..addd33c3ca1 --- /dev/null +++ b/src/qe/lite/qel.cpp @@ -0,0 +1,52 @@ +/*++ + + Module Name: + + qel.cpp + +Abstract: + Light weight quantifier elimination (QEL) based on term graph. + + The implementation is based on the following paper: + + Isabel Garcia-Contreras, Hari Govind V. K., Sharon Shoham, Arie Gurfinkel: + Fast Approximations of Quantifier Elimination. Computer-Aided Verification + (CAV). 2023. URL: https://arxiv.org/abs/2306.10009 + +Author: + + Hari Govind V K (hgvk94) + Isabel Garcia (igcontreras) + + +--*/ +#include "qe/lite/qel.h" +#include "qe/mbp/mbp_term_graph.h" + +class qel::impl { +private: + ast_manager &m; + +public: + impl(ast_manager &m, params_ref const &p) : m(m) {} + + void operator()(app_ref_vector &vars, expr_ref &fml) { + if (vars.empty()) return; + + mbp::term_graph tg(m); + tg.set_vars(vars); + + expr_ref_vector lits(m); + flatten_and(fml, lits); + tg.add_lits(lits); + tg.qel(vars, fml); + } +}; + +qel::qel(ast_manager &m, params_ref const &p) { m_impl = alloc(impl, m, p); } + +qel::~qel() { dealloc(m_impl); } + +void qel::operator()(app_ref_vector &vars, expr_ref &fml) { + (*m_impl)(vars, fml); +} diff --git a/src/qe/lite/qel.h b/src/qe/lite/qel.h new file mode 100644 index 00000000000..ab960bc476d --- /dev/null +++ b/src/qe/lite/qel.h @@ -0,0 +1,46 @@ +/*++ + + Module Name: + + qel.h + +Abstract: + + Light weight quantifier elimination (QEL) based on term graph. + + The implementation is based on the following paper: + + Isabel Garcia-Contreras, Hari Govind V. K., Sharon Shoham, Arie Gurfinkel: + Fast Approximations of Quantifier Elimination. Computer-Aided Verification + (CAV). 2023. URL: https://arxiv.org/abs/2306.10009 + +Author: + + Hari Govind V K (hgvk94) + Isabel Garcia (igcontreras) + +--*/ + +#pragma once + +#include "ast/ast.h" +#include "ast/ast_util.h" +#include "util/params.h" +#include "util/uint_set.h" + +class qel { + class impl; + impl *m_impl; + +public: + qel(ast_manager &m, params_ref const &p); + + ~qel(); + + /** + \brief Applies light-weight elimination of `vars` provided as vector + of expressions to the cube `fml`. Returns the updated formula and updated + set of variables that were not eliminated. + */ + void operator()(app_ref_vector &vars, expr_ref &fml); +}; diff --git a/src/qe/mbp/CMakeLists.txt b/src/qe/mbp/CMakeLists.txt index 69d5dfd20d5..a5c7e7702a7 100644 --- a/src/qe/mbp/CMakeLists.txt +++ b/src/qe/mbp/CMakeLists.txt @@ -2,7 +2,12 @@ z3_add_component(mbp SOURCES mbp_arith.cpp mbp_arrays.cpp + mbp_arrays_tg.cpp + mbp_basic_tg.cpp mbp_datatypes.cpp + mbp_dt_tg.cpp + mbp_qel.cpp + mbp_qel_util.cpp mbp_plugin.cpp mbp_solve_plugin.cpp mbp_term_graph.cpp diff --git a/src/qe/mbp/mbp_arith.cpp b/src/qe/mbp/mbp_arith.cpp index 5d9d3c19c8c..255c4f8144e 100644 --- a/src/qe/mbp/mbp_arith.cpp +++ b/src/qe/mbp/mbp_arith.cpp @@ -277,8 +277,14 @@ namespace mbp { extract_coefficients(mbo, eval, ts0, tids, coeffs); mbo.add_divides(coeffs, c0, mul1); } - else + else if (a.is_to_real(t)) + throw default_exception("mbp to-real"); + else if (a.is_to_int(t)) + throw default_exception("mbp to-int"); + else { + TRACE("qe", tout << "insert mul " << mk_pp(t, m) << "\n"); insert_mul(t, mul, ts); + } } bool is_numeral(expr* t, rational& r) { @@ -387,8 +393,7 @@ namespace mbp { return false; }; - for (auto& kv : tids) { - expr* e = kv.m_key; + for (auto& [e, v] : tids) { if (is_arith(e) && !is_pure(e) && !var_mark.is_marked(e)) mark_rec(fmls_mark, e); } @@ -669,7 +674,7 @@ namespace mbp { id = mbo.add_var(r, a.is_int(v)); tids.insert(v, id); } - CTRACE("qe", kv.m_value.is_zero(), tout << mk_pp(v, m) << " has coefficeint 0\n";); + CTRACE("qe", kv.m_value.is_zero(), tout << mk_pp(v, m) << " has coefficient 0\n";); if (!kv.m_value.is_zero()) { coeffs.push_back(var(id, kv.m_value)); } diff --git a/src/qe/mbp/mbp_arrays.cpp b/src/qe/mbp/mbp_arrays.cpp index 0f4c805b71b..bf3ad08edc6 100644 --- a/src/qe/mbp/mbp_arrays.cpp +++ b/src/qe/mbp/mbp_arrays.cpp @@ -17,7 +17,6 @@ Revision History: --*/ - #include "util/lbool.h" #include "ast/rewriter/rewriter_def.h" #include "ast/expr_functors.h" @@ -26,134 +25,11 @@ Revision History: #include "ast/rewriter/th_rewriter.h" #include "ast/ast_util.h" #include "ast/ast_pp.h" +#include "ast/array_peq.h" #include "model/model_evaluator.h" #include "qe/mbp/mbp_arrays.h" #include "qe/mbp/mbp_term_graph.h" - -namespace { - bool is_partial_eq (app* a); - - /** - * \brief utility class for partial equalities - * - * A partial equality (a ==I b), for two arrays a,b and a finite set of indices I holds - * iff (Forall i. i \not\in I => a[i] == b[i]); in other words, it is a - * restricted form of the extensionality axiom - * - * using this class, we denote (a =I b) as f(a,b,i0,i1,...) - * where f is an uninterpreted predicate with name PARTIAL_EQ and - * I = {i0,i1,...} - */ - - // TBD: make work for arrays with multiple arguments. - class peq { - ast_manager& m; - expr_ref m_lhs; - expr_ref m_rhs; - vector m_diff_indices; - func_decl_ref m_decl; // the partial equality declaration - app_ref m_peq; // partial equality application - app_ref m_eq; // equivalent std equality using def. of partial eq - array_util m_arr_u; - - public: - static const char* PARTIAL_EQ; - - peq (app* p, ast_manager& m): - m (m), - m_lhs (p->get_arg (0), m), - m_rhs (p->get_arg (1), m), - m_decl (p->get_decl (), m), - m_peq (p, m), - m_eq (m), - m_arr_u (m) - { - VERIFY (is_partial_eq (p)); - SASSERT (m_arr_u.is_array (m_lhs) && - m_arr_u.is_array (m_rhs) && - m_lhs->get_sort() == m_rhs->get_sort()); - unsigned arity = get_array_arity(m_lhs->get_sort()); - for (unsigned i = 2; i < p->get_num_args (); i += arity) { - SASSERT(arity + i <= p->get_num_args()); - expr_ref_vector vec(m); - vec.append(arity, p->get_args() + i); - m_diff_indices.push_back (vec); - } - } - - peq (expr* lhs, expr* rhs, vector const& diff_indices, ast_manager& m): - m (m), - m_lhs (lhs, m), - m_rhs (rhs, m), - m_diff_indices (diff_indices), - m_decl (m), - m_peq (m), - m_eq (m), - m_arr_u (m) { - SASSERT (m_arr_u.is_array (lhs) && - m_arr_u.is_array (rhs) && - lhs->get_sort() == rhs->get_sort()); - ptr_vector sorts; - sorts.push_back (m_lhs->get_sort ()); - sorts.push_back (m_rhs->get_sort ()); - for (auto const& v : diff_indices) { - SASSERT(v.size() == get_array_arity(m_lhs->get_sort())); - for (expr* e : v) - sorts.push_back (e->get_sort()); - } - m_decl = m.mk_func_decl (symbol (PARTIAL_EQ), sorts.size (), sorts.data (), m.mk_bool_sort ()); - } - - expr_ref lhs () { return m_lhs; } - - expr_ref rhs () { return m_rhs; } - - void get_diff_indices (vector& result) { result.append(m_diff_indices); } - - app_ref mk_peq () { - if (!m_peq) { - ptr_vector args; - args.push_back (m_lhs); - args.push_back (m_rhs); - for (auto const& v : m_diff_indices) { - args.append (v.size(), v.data()); - } - m_peq = m.mk_app (m_decl, args.size (), args.data ()); - } - return m_peq; - } - - app_ref mk_eq (app_ref_vector& aux_consts, bool stores_on_rhs = true) { - if (!m_eq) { - expr_ref lhs (m_lhs, m), rhs (m_rhs, m); - if (!stores_on_rhs) { - std::swap (lhs, rhs); - } - // lhs = (...(store (store rhs i0 v0) i1 v1)...) - sort* val_sort = get_array_range (lhs->get_sort()); - for (expr_ref_vector const& diff : m_diff_indices) { - ptr_vector store_args; - store_args.push_back (rhs); - store_args.append (diff.size(), diff.data()); - app_ref val(m.mk_fresh_const ("diff", val_sort), m); - store_args.push_back (val); - aux_consts.push_back (val); - rhs = m_arr_u.mk_store (store_args); - } - m_eq = m.mk_eq (lhs, rhs); - } - return m_eq; - } - }; - - const char* peq::PARTIAL_EQ = "!partial_eq"; - - bool is_partial_eq (app* a) { - return a->get_decl ()->get_name () == peq::PARTIAL_EQ; - } -} - namespace mbp { @@ -366,20 +242,10 @@ namespace mbp { } } - /** - * mk (e0 ==indices e1) - * - * result has stores if either e0 or e1 or an index term has stores - */ - app_ref mk_peq (expr* e0, expr* e1, vector const& indices) { - peq p (e0, e1, indices, m); - return p.mk_peq (); - } - void find_subst_term (app* eq) { SASSERT(m.is_eq(eq)); vector empty; - app_ref p_exp = mk_peq (eq->get_arg (0), eq->get_arg (1), empty); + app_ref p_exp = mk_peq (eq->get_arg (0), eq->get_arg (1), empty, m); bool subst_eq_found = false; while (true) { TRACE ("qe", tout << "processing peq:\n" << p_exp << "\n";); @@ -434,7 +300,7 @@ namespace mbp { ); // arr0 ==I arr1 - p_exp = mk_peq (arr0, arr1, I); + p_exp = mk_peq (arr0, arr1, I, m); TRACE ("qe", tout << "new peq:\n"; @@ -445,7 +311,7 @@ namespace mbp { m_idx_lits_v.append (idx_diseq); // arr0 ==I+idx arr1 I.push_back (idxs); - p_exp = mk_peq (arr0, arr1, I); + p_exp = mk_peq (arr0, arr1, I, m); TRACE ("qe", tout << "new peq:\n" << p_exp << "\n"; ); diff --git a/src/qe/mbp/mbp_arrays_tg.cpp b/src/qe/mbp/mbp_arrays_tg.cpp new file mode 100644 index 00000000000..c3d91ae6455 --- /dev/null +++ b/src/qe/mbp/mbp_arrays_tg.cpp @@ -0,0 +1,425 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + mbp_arrays_tg.cpp + +Abstract: + + Apply rules for model based projection for arrays on a term graph + +Author: + + Hari Govind V K (hgvk94) 2023-03-07 + +Revision History: + +--*/ + +#include "qe/mbp/mbp_arrays_tg.h" +#include "ast/array_decl_plugin.h" +#include "ast/array_peq.h" +#include "qe/mbp/mbp_qel_util.h" +#include "util/obj_hashtable.h" +#include "util/obj_pair_hashtable.h" + +namespace mbp { + +struct mbp_array_tg::impl { + typedef std::pair expr_pair; + ast_manager &m; + array_util m_array_util; + mbp::term_graph &m_tg; + // TODO: cache mdl evaluation eventhough we extend m_mdl + model &m_mdl; + + // set of variables on which to apply MBP rules + obj_hashtable &m_vars_set; + + // variables created in the last iteration of MBP application + app_ref_vector m_new_vars; + + expr_sparse_mark &m_seen; + obj_pair_hashtable m_seenp; + + // apply rules that split on model + bool m_use_mdl; + + // m_has_store.is_marked(t) if t has a subterm store(v) where v is a + // variable to be eliminated + ast_mark m_has_stores; + // variables required for applying rules + vector indices; + expr_ref_vector terms, rdTerms; + + bool has_var(expr *t) { return contains_vars(t, m_vars_set, m); } + + bool has_arr_var(expr *t) { + return contains_vars(t, m_vars_set, m, m_array_util.get_family_id(), + ARRAY_SORT); + } + + bool is_var(expr *t) { return is_uninterp_const(t) && has_var(t); } + + bool is_wr_on_rhs(expr *e) { + return is_app(e) && is_partial_eq(to_app(e)) && + is_wr_on_rhs(to_app(e)->get_arg(0), to_app(e)->get_arg(1)); + } + + bool is_wr_on_rhs(expr *lhs, expr *rhs) { + return (is_arr_write(rhs) && !is_arr_write(lhs)); + } + + bool is_arr_write(expr *t) { + return m_array_util.is_store1(t) && has_var(to_app(t)); + } + + bool is_arr_write(expr *t, expr*& a, expr*& i, expr*& v) { + return m_array_util.is_store1(t, a, i, v) && has_var(to_app(t)); + } + + // Returns true if e has a subterm store(v) where v is a variable to be + // eliminated. Recurses on subexpressions of ee + bool has_stores(expr *e) { + if (m_has_stores.is_marked(e)) return true; + if (!is_app(e)) return false; + if (m_array_util.is_store(e) && is_var(to_app(e)->get_arg(0))) { + m_has_stores.mark(e, true); + return true; + } + if (any_of(*(to_app(e)), [&](expr* c) { return m_has_stores.is_marked(c); })) { + m_has_stores.mark(e, true); + return true; + } + //recurse + for(auto c : *(to_app(e))) { + if (has_stores(c)) { + m_has_stores.mark(e, true); + return true; + } + } + return false; + } + + // + // the code that uses this assumes that select takes only two arguments. + // Note that select may take more than two arguments in general. + // + bool is_rd_wr(expr *t) { + expr* a, *idx; + return m_array_util.is_select1(t, a, idx) && + m_array_util.is_store(a) && + has_stores(a); + } + + bool is_rd_wr(expr* t, expr*& wr_ind, expr*& rd_ind, expr*& b, expr*& v) { + if (!is_rd_wr(t)) + return false; + expr* a; + VERIFY(m_array_util.is_select1(t, a, rd_ind)); + VERIFY(m_array_util.is_store1(a, b, wr_ind, v)); + return true; + } + + bool is_implicit_peq(expr *e) { + expr* a, *b; + return is_implicit_peq(e, a, b); + } + + bool is_implicit_peq(expr *e, expr*& a, expr*& b) { + return m.is_eq(e, a, b) && is_implicit_peq(a, b); + } + + bool is_implicit_peq(expr *lhs, expr *rhs) { + return m_array_util.is_array(lhs) && m_array_util.is_array(rhs) && + (has_var(lhs) || has_var(rhs)); + } + + bool is_neg_peq(expr *e, expr*& a, expr*& b) { + expr* ne; + return m.is_not(e, ne) && is_implicit_peq(ne, a, b); + } + + bool is_neg_peq(expr *e) { + expr* ne; + return m.is_not(e, ne) && is_implicit_peq(ne); + } + + void mark_seen(expr *t) { m_seen.mark(t); } + bool is_seen(expr *t) { return m_seen.is_marked(t); } + void mark_seen(expr *t1, expr *t2) { m_seenp.insert(expr_pair(t1, t2)); } + bool is_seen(expr *t1, expr *t2) { + return m_seenp.contains(expr_pair(t1, t2)) || + m_seenp.contains(expr_pair(t2, t1)); + } + + impl(ast_manager &man, mbp::term_graph &tg, model &mdl, + obj_hashtable &vars_set, expr_sparse_mark &seen) + : m(man), m_array_util(m), m_tg(tg), m_mdl(mdl), m_vars_set(vars_set), + m_new_vars(m), m_seen(seen), m_use_mdl(false), terms(m), rdTerms(m) {} + + // create a peq where write terms are preferred on the left hand side + peq mk_wr_peq(expr *e1, expr *e2) { + vector empty; + return mk_wr_peq(e1, e2, empty); + } + + // create a peq where write terms are preferred on the left hand side + peq mk_wr_peq(expr *e1, expr *e2, vector &indices) { + expr *n_lhs = e1, *n_rhs = e2; + if (is_wr_on_rhs(e1, e2)) + std::swap(n_lhs, n_rhs); + return peq(n_lhs, n_rhs, indices, m); + } + + // rewrite store(x, j, elem) \peq_{indices} y + // into either j = i && x \peq_{indices} y (for some i in + // indices) or &&_{i \in indices} j \neq i && + // x \peq_{indices, j} y && + // select(y, j) = elem + // rewrite negation !(store(x, j, elem) \peq_{indices} y) into + // into either j = i && !(x \peq_{indices} y) (for some i in + // indices) or &&_{i \in indices} j \neq i && + // !(x \peq_{indices, j} y) && + // or &&_{i \in indices} j \neq i && + // !(select(y, j) = elem) + void elimwreq(peq p, bool is_neg) { + expr* a, *j, *elem; + VERIFY(is_arr_write(p.lhs(), a, j, elem)); + TRACE("mbp_tg", + tout << "applying elimwreq on " << expr_ref(p.mk_peq(), m) << " is neg: " << is_neg;); + vector indices; + bool in = false; + p.get_diff_indices(indices); + expr_ref eq_index(m); + expr_ref_vector deq(m); + for (expr_ref_vector &e : indices) { + for (expr *i : e) { + if (m_mdl.are_equal(j, i)) { + in = true; + // save for later + eq_index = i; + break; + } else + deq.push_back(i); + } + } + if (in) { + SASSERT(m_mdl.are_equal(j, eq_index)); + peq p_new = + mk_wr_peq(a, p.rhs(), indices); + m_tg.add_eq(j, eq_index); + expr_ref p_new_expr(m); + p_new_expr = is_neg ? m.mk_not(p_new.mk_peq()) : p_new.mk_peq(); + m_tg.add_lit(p_new_expr); + m_tg.add_eq(p_new.mk_peq(), p.mk_peq()); + return; + } + for (expr *d : deq) { m_tg.add_deq(j, d); } + expr_ref_vector setOne(m); + setOne.push_back(j); + indices.push_back(setOne); + peq p_new = mk_wr_peq(a, p.rhs(), indices); + expr_ref rd(m_array_util.mk_select(p.rhs(), j), m); + if (!is_neg) { + m_tg.add_lit(p_new.mk_peq()); + m_tg.add_eq(rd, elem); + m_tg.add_eq(p.mk_peq(), p_new.mk_peq()); + } else { + expr_ref rd_eq(m.mk_eq(rd, elem), m); + if (m_mdl.is_false(rd_eq)) { m_tg.add_deq(rd, elem); } + else { + expr_ref npeq(mk_not(p_new.mk_peq()), m); + m_tg.add_lit(npeq); + m_tg.add_eq(p.mk_peq(), p_new.mk_peq()); + } + } + } + + // add equality v = rd where v is a fresh variable + void add_rdVar(expr *rd) { + // do not assign new variable if rd is already equal to a value + if (m_tg.has_val_in_class(rd)) return; + TRACE("mbp_tg", tout << "applying add_rdVar on " << expr_ref(rd, m);); + app_ref u = new_var(to_app(rd)->get_sort(), m); + m_new_vars.push_back(u); + m_tg.add_var(u); + m_tg.add_eq(u, rd); + m_mdl.register_decl(u->get_decl(), m_mdl(rd)); + } + + // given a \peq_{indices} t, where a is a variable, merge equivalence class + // of a with store(t, indices, elems) where elems are fresh constants + void elimeq(peq p) { + TRACE("mbp_tg", + tout << "applying elimeq on " << expr_ref(p.mk_peq(), m);); + app_ref_vector aux_consts(m); + expr_ref eq(m); + expr_ref sel(m); + eq = p.mk_eq(aux_consts, true); + vector indices; + p.get_diff_indices(indices); + vector::iterator itr = indices.begin(); + unsigned i = 0; + for (app *a : aux_consts) { + m_new_vars.push_back(a); + m_tg.add_var(a); + auto const &indx = std::next(itr, i); + SASSERT(indx->size() == 1); + sel = m_array_util.mk_select(p.lhs(), indx->get(0)); + m_mdl.register_decl(a->get_decl(), m_mdl(sel)); + i++; + } + m_tg.add_lit(eq); + m_tg.add_eq(p.mk_peq(), m.mk_true()); + TRACE("mbp_tg", tout << "added lit " << eq;); + } + + // rewrite select(store(a, i, k), j) into either select(a, j) or k + void elimrdwr(expr *term) { + TRACE("mbp_tg", tout << "applying elimrdwr on " << expr_ref(term, m);); + expr* wr_ind, *rd_ind, *b, *v; + VERIFY(is_rd_wr(term, wr_ind, rd_ind, b, v)); + if (m_mdl.are_equal(wr_ind, rd_ind)) + m_tg.add_eq(wr_ind, rd_ind); + else { + m_tg.add_deq(wr_ind, rd_ind); + v = m_array_util.mk_select(b, rd_ind); + } + m_tg.add_eq(term, v); + } + + // iterate through all terms in m_tg and apply all array MBP rules once + // returns true if any rules were applied + bool apply() { + TRACE("mbp_tg", tout << "Iterating over terms of tg";); + indices.reset(); + rdTerms.reset(); + m_new_vars.reset(); + expr_ref e(m), rdEq(m), rdDeq(m); + expr *nt, *term; + bool progress = false, is_neg = false; + + // Not resetting terms because get_terms calls resize on terms + m_tg.get_terms(terms, false); + for (unsigned i = 0; i < terms.size(); i++) { + term = terms.get(i); + if (m_seen.is_marked(term)) + continue; + if (m_tg.is_cgr(term)) + continue; + TRACE("mbp_tg", tout << "processing " << expr_ref(term, m);); + expr* a, *b; + if (is_implicit_peq(term, a, b) || is_neg_peq(term, a, b)) { + // rewrite array eq as peq + mark_seen(term); + progress = true; + nt = term; + bool is_not = m.is_not(term, nt); + e = mk_wr_peq(a, b).mk_peq(); + e = is_not ? m.mk_not(e) : e.get(); + m_tg.add_lit(e); + m_tg.add_eq(term, e); + continue; + } + nt = term; + is_neg = m.is_not(term, nt); + if (is_app(nt) && is_partial_eq(to_app(nt))) { + peq p(to_app(nt), m); + if (m_use_mdl && is_arr_write(p.lhs())) { + mark_seen(nt); + mark_seen(term); + progress = true; + elimwreq(p, is_neg); + continue; + } + if (!m_array_util.is_store(p.lhs()) && has_var(p.lhs()) && !is_neg) { + // TODO: don't apply this rule if vars in p.lhs() also + // appear in p.rhs() + + mark_seen(p.lhs()); + mark_seen(nt); + mark_seen(term); + progress = true; + elimeq(p); + continue; + } + // eliminate eq when the variable is on the rhs + if (!m_array_util.is_store(p.rhs()) && has_var(p.rhs()) && !is_neg) { + mark_seen(p.rhs()); + p.get_diff_indices(indices); + peq p_new = mk_wr_peq(p.rhs(), p.lhs(), indices); + mark_seen(nt); + mark_seen(term); + progress = true; + elimeq(p_new); + continue; + } + } + if (m_use_mdl && is_rd_wr(nt)) { + mark_seen(term); + progress = true; + elimrdwr(nt); + continue; + } + } + + // iterate over term graph again to collect read terms + // irrespective of whether they have been marked or not + rdTerms.reset(); + for (unsigned i = 0; i < terms.size(); i++) { + term = terms.get(i); + if (m_array_util.is_select(term) && + has_var(to_app(term)->get_arg(0))) { + rdTerms.push_back(term); + if (is_seen(term)) continue; + add_rdVar(term); + mark_seen(term); + } + } + if (!m_use_mdl) return progress; + expr *e1, *e2, *a1, *a2, *i1, *i2; + for (unsigned i = 0; i < rdTerms.size(); i++) { + e1 = rdTerms.get(i); + a1 = to_app(e1)->get_arg(0); + i1 = to_app(e1)->get_arg(1); + for (unsigned j = i + 1; j < rdTerms.size(); j++) { + e2 = rdTerms.get(j); + a2 = to_app(e2)->get_arg(0); + i2 = to_app(e2)->get_arg(1); + if (!is_seen(e1, e2) && a1->get_id() == a2->get_id()) { + mark_seen(e1, e2); + progress = true; + if (m_mdl.are_equal(i1, i2)) { + m_tg.add_eq(i1, i2); + } else { + SASSERT(!m_mdl.are_equal(i1, i2)); + m_tg.add_deq(i1, i2); + } + continue; + } + } + } + return progress; + } +}; + +void mbp_array_tg::use_model() { m_impl->m_use_mdl = true; } +bool mbp_array_tg::apply() { return m_impl->apply(); } +void mbp_array_tg::reset() { + m_impl->m_seen.reset(); + m_impl->m_vars_set.reset(); +} +void mbp_array_tg::get_new_vars(app_ref_vector *&t) { t = &m_impl->m_new_vars; } +family_id mbp_array_tg::get_family_id() const { + return m_impl->m_array_util.get_family_id(); +} +mbp_array_tg::mbp_array_tg(ast_manager &man, mbp::term_graph &tg, model &mdl, + obj_hashtable &vars_set, + expr_sparse_mark &seen) { + m_impl = alloc(mbp_array_tg::impl, man, tg, mdl, vars_set, seen); +} +mbp_array_tg::~mbp_array_tg() { dealloc(m_impl); } + +} // namespace mbp diff --git a/src/qe/mbp/mbp_arrays_tg.h b/src/qe/mbp/mbp_arrays_tg.h new file mode 100644 index 00000000000..0c634bdd0f0 --- /dev/null +++ b/src/qe/mbp/mbp_arrays_tg.h @@ -0,0 +1,47 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + mbp_arrays_tg.h + +Abstract: + + Apply rules for model based projection for arrays on a term graph + +Author: + + Hari Govind V K (hgvk94) 2023-03-07 + +Revision History: + +--*/ + +#pragma once + +#include "ast/ast.h" +#include "qe/mbp/mbp_qel_util.h" +#include "qe/mbp/mbp_term_graph.h" +#include "qe/mbp/mbp_tg_plugins.h" +#include "util/memory_manager.h" +#include "util/obj_hashtable.h" +#include "util/obj_pair_hashtable.h" + +namespace mbp { +class mbp_array_tg : public mbp_tg_plugin { + struct impl; + impl *m_impl; + + public: + mbp_array_tg(ast_manager &man, mbp::term_graph &tg, model &mdl, + obj_hashtable &vars_set, expr_sparse_mark &seen); + void use_model() override; + void reset(); + // iterate through all terms in m_tg and apply all array MBP rules once + // returns true if any rules were applied + bool apply() override; + ~mbp_array_tg() override; + void get_new_vars(app_ref_vector *&t) override; + family_id get_family_id() const override; +}; +} // namespace mbp diff --git a/src/qe/mbp/mbp_basic_tg.cpp b/src/qe/mbp/mbp_basic_tg.cpp new file mode 100644 index 00000000000..ce5e99eb1bd --- /dev/null +++ b/src/qe/mbp/mbp_basic_tg.cpp @@ -0,0 +1,171 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + mbp_basic_tg.cpp + +Abstract: + + Apply rules for model based projection for basic types, on a term graph + +Author: + + Hari Govind V K (hgvk94) 2023-03-07 + +Revision History: + +--*/ + +#include "qe/mbp/mbp_basic_tg.h" +#include "ast/ast.h" +#include "ast/expr_functors.h" +#include "util/debug.h" +#include "util/memory_manager.h" + +struct mbp_basic_tg::impl { + ast_manager &m; + mbp::term_graph &m_tg; + // TODO: cache mdl evaluation eventhough we extend m_mdl + model &m_mdl; + + // set of variables on which to apply MBP rules + obj_hashtable &m_vars_set; + + // variables created in the last iteration of MBP application + app_ref_vector m_new_vars; + + expr_sparse_mark &m_seen; + + expr_ref_vector terms; + bool m_use_mdl; + + void mark_seen(expr *t) { m_seen.mark(t); } + bool is_seen(expr *t) { return m_seen.is_marked(t); } + + // Split on all ite terms, irrespective of whether + // they contain variables/are c-ground + bool apply() { + std::function should_split, is_true, is_false; + if (!m_use_mdl) { + should_split = [&](expr *t) { return m_tg.has_val_in_class(t); }; + is_true = [&](expr *t) { + return m_tg.has_val_in_class(t) && m_mdl.is_true(t); + }; + is_false = [&](expr *t) { + return m_tg.has_val_in_class(t) && m_mdl.is_false(t); + }; + } else { + should_split = [](expr *t) { return true; }; + is_true = [&](expr *t) { return m_mdl.is_true(t); }; + is_false = [&](expr *t) { return m_mdl.is_false(t); }; + } + + expr *c, *th, *el; + expr_ref nterm(m); + bool progress = false; + TRACE("mbp_tg", tout << "Iterating over terms of tg";); + // Not resetting terms because get_terms calls resize on terms + m_tg.get_terms(terms, false); + for (expr *term : terms) { + if (is_seen(term)) continue; + TRACE("mbp_tg", tout << "Processing " << expr_ref(term, m) << "\n";); + if (m.is_ite(term, c, th, el) && should_split(c)) { + mark_seen(term); + progress = true; + if (m_mdl.is_true(c)) { + m_tg.add_lit(c); + m_tg.add_eq(term, th); + } else { + nterm = mk_not(m, c); + m_tg.add_lit(nterm); + m_tg.add_eq(term, el); + } + } + if (m.is_implies(term, c, th)) { + if (is_true(th) || is_false(c)) { + mark_seen(term); + progress = true; + if (is_true(th)) + m_tg.add_lit(th); + else if (is_false(c)) + m_tg.add_lit(c); + m_tg.add_eq(term, m.mk_true()); + } else if (is_true(c) && is_false(th)) { + mark_seen(term); + progress = true; + m_tg.add_eq(term, m.mk_false()); + } + } + if (m.is_or(term) || m.is_and(term)) { + bool is_or = m.is_or(term); + app *c = to_app(term); + bool t = is_or ? any_of(*c, is_true) : all_of(*c, is_true); + bool f = is_or ? all_of(*c, is_false) : all_of(*c, is_false); + if (t || f) { + mark_seen(term); + progress = true; + m_tg.add_eq(term, t ? m.mk_true() : m.mk_false()); + if (f) { + for (auto a : *c) { + if (is_false(a)) { + m_tg.add_lit(mk_not(m, a)); + if (!is_or) break; + } + } + } else { + for (auto a : *c) { + if (is_true(a)) { + m_tg.add_lit(a); + if (is_or) break; + } + } + } + } + } + if (m_use_mdl && m.is_distinct(term)) { + mark_seen(term); + progress = true; + bool eq = false; + app *c = to_app(term); + for (auto a1 : *c) { + for (auto a2 : *c) { + if (a1 == a2) continue; + expr_ref e(m.mk_eq(a1, a2), m); + if (m_mdl.is_true(e)) { + m_tg.add_eq(a1, a2); + eq = true; + break; + } else { + SASSERT(m_mdl.is_false(e)); + m_tg.add_deq(a1, a2); + } + } + } + if (eq) + m_tg.add_eq(term, m.mk_false()); + else + m_tg.add_eq(term, m.mk_true()); + } + } + return progress; + } + + impl(ast_manager &man, mbp::term_graph &tg, model &mdl, + obj_hashtable &vars_set, expr_sparse_mark &seen) + : m(man), m_tg(tg), m_mdl(mdl), m_vars_set(vars_set), m_new_vars(m), + m_seen(seen), terms(m), m_use_mdl(false) {} +}; + +bool mbp_basic_tg::apply() { return m_impl->apply(); } +void mbp_basic_tg::use_model() { m_impl->m_use_mdl = true; } +void mbp_basic_tg::get_new_vars(app_ref_vector *&t) { t = &m_impl->m_new_vars; } +family_id mbp_basic_tg::get_family_id() const { + return m_impl->m.get_basic_family_id(); +} +mbp_basic_tg::mbp_basic_tg(ast_manager &man, mbp::term_graph &tg, model &mdl, + obj_hashtable &vars_set, + expr_sparse_mark &seen) { + m_impl = alloc(mbp_basic_tg::impl, man, tg, mdl, vars_set, seen); +} +mbp_basic_tg::~mbp_basic_tg() { dealloc(m_impl); } diff --git a/src/qe/mbp/mbp_basic_tg.h b/src/qe/mbp/mbp_basic_tg.h new file mode 100644 index 00000000000..af7c624c0d8 --- /dev/null +++ b/src/qe/mbp/mbp_basic_tg.h @@ -0,0 +1,40 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + mbp_basic_tg.h + +Abstract: + + Apply rules for model based projection for basic types, on a term graph + +Author: + + Hari Govind V K (hgvk94) 2023-03-07 + +Revision History: + +--*/ +#pragma once + +#include "qe/mbp/mbp_qel_util.h" +#include "qe/mbp/mbp_term_graph.h" +#include "qe/mbp/mbp_tg_plugins.h" +#include "util/obj_hashtable.h" + +class mbp_basic_tg : public mbp_tg_plugin { + struct impl; + impl *m_impl; + + public: + mbp_basic_tg(ast_manager &man, mbp::term_graph &tg, model &mdl, + obj_hashtable &vars_set, expr_sparse_mark &seen); + // iterate through all terms in m_tg and apply all basic MBP rules once + // returns true if any rules were applied + bool apply() override; + ~mbp_basic_tg() override; + void use_model() override; + void get_new_vars(app_ref_vector *&t) override; + family_id get_family_id() const override; +}; diff --git a/src/qe/mbp/mbp_dt_tg.cpp b/src/qe/mbp/mbp_dt_tg.cpp new file mode 100644 index 00000000000..626e8a0e48b --- /dev/null +++ b/src/qe/mbp/mbp_dt_tg.cpp @@ -0,0 +1,201 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + mbp_dt_tg.cpp + +Abstract: + + Apply rules for model based projection for datatypes on a term graph + +Author: + + Hari Govind V K (hgvk94) 2023-03-07 + +Revision History: + +--*/ +#include "qe/mbp/mbp_dt_tg.h" +#include "qe/mbp/mbp_qel_util.h" +#include "util/memory_manager.h" + +namespace mbp { + +struct mbp_dt_tg::impl { + ast_manager &m; + datatype_util m_dt_util; + mbp::term_graph &m_tg; + // TODO: cache mdl evaluation eventhough we extend m_mdl + model &m_mdl; + + // set of variables on which to apply MBP rules + obj_hashtable &m_vars_set; + + // variables created in the last iteration of MBP application + app_ref_vector m_new_vars; + + expr_sparse_mark &m_seen; + + expr_ref_vector terms; + bool m_use_mdl; + + void mark_seen(expr *t) { m_seen.mark(t); } + bool is_seen(expr *t) { return m_seen.is_marked(t); } + + bool is_var(expr *t) { return is_uninterp_const(t) && has_var(t); } + + bool has_var(expr *t) { return contains_vars(t, m_vars_set, m); } + + bool is_constructor(expr *t) { + return is_app(t) && m_dt_util.is_constructor(to_app(t)->get_decl()) && + has_var(t); + } + + bool is_constructor_app(expr *e, expr *&cons, expr *&rhs) { + if (!m.is_eq(e, cons, rhs)) return false; + // TODO: does it matter whether vars in cons appear in rhs? + if (is_constructor(cons)) { + return true; + } else if (is_constructor(rhs)) { + cons = rhs; + rhs = to_app(e)->get_arg(0); + return true; + } + return false; + } + + impl(ast_manager &man, mbp::term_graph &tg, model &mdl, + obj_hashtable &vars_set, expr_sparse_mark &seen) + : m(man), m_dt_util(m), m_tg(tg), m_mdl(mdl), m_vars_set(vars_set), + m_new_vars(m), m_seen(seen), terms(m), m_use_mdl(false) {} + + // rewrite head(x) with y + // and x with list(y, z) + void rm_select(expr *term) { + SASSERT(is_app(term) && + m_dt_util.is_accessor(to_app(term)->get_decl()) && + is_var(to_app(term)->get_arg(0))); + TRACE("mbp_tg", tout << "applying rm_select on " << expr_ref(term, m);); + expr *v = to_app(term)->get_arg(0); + expr_ref sel(m); + app_ref u(m); + app_ref_vector new_vars(m); + func_decl *cons = + m_dt_util.get_accessor_constructor(to_app(term)->get_decl()); + ptr_vector const *accessors = + m_dt_util.get_constructor_accessors(cons); + for (unsigned i = 0; i < accessors->size(); i++) { + func_decl *d = accessors->get(i); + sel = m.mk_app(d, v); + u = m_tg.get_const_in_class(sel); + if (u) { + new_vars.push_back(u); + continue; + } + u = new_var(d->get_range(), m); + m_new_vars.push_back(u); + m_tg.add_var(u); + new_vars.push_back(u); + m_tg.add_eq(sel, u); + m_mdl.register_decl(u->get_decl(), m_mdl(sel)); + } + expr_ref new_cons(m.mk_app(cons, new_vars), m); + m_tg.add_eq(v, new_cons); + } + + // rewrite cons(v, u) = x with v = head(x) and u = tail(x) + // where u or v contain variables + void deconstruct_eq(expr *cons, expr *rhs) { + TRACE("mbp_tg", + tout << "applying deconstruct_eq on " << expr_ref(cons, m);); + ptr_vector const *accessors = + m_dt_util.get_constructor_accessors(to_app(cons)->get_decl()); + for (unsigned i = 0; i < accessors->size(); i++) { + expr_ref a(m.mk_app(accessors->get(i), rhs), m); + expr *newRhs = to_app(cons)->get_arg(i); + m_tg.add_eq(a, newRhs); + } + func_decl *is_cons = + m_dt_util.get_constructor_recognizer(to_app(cons)->get_decl()); + expr_ref is(m.mk_app(is_cons, rhs), m); + m_tg.add_lit(is); + } + + // rewrite cons(v, u) != x into one of !cons(x) or v != head(x) or u != + // tail(x) where u or v contain variables + void deconstruct_neq(expr *cons, expr *rhs) { + TRACE("mbp_tg", + tout << "applying deconstruct_neq on " << expr_ref(cons, m);); + ptr_vector const *accessors = + m_dt_util.get_constructor_accessors(to_app(cons)->get_decl()); + func_decl *is_cons = + m_dt_util.get_constructor_recognizer(to_app(cons)->get_decl()); + expr_ref a(m.mk_app(is_cons, rhs), m); + if (m_mdl.is_false(a)) { + expr_ref not_cons(m.mk_not(a), m); + m_tg.add_lit(not_cons); + return; + } + m_tg.add_lit(a); + + for (unsigned i = 0; i < accessors->size(); i++) { + expr_ref a(m.mk_app(accessors->get(i), rhs), m); + expr *newRhs = to_app(cons)->get_arg(i); + if (!m_mdl.are_equal(a, newRhs)) { + m_tg.add_deq(a, newRhs); + break; + } + } + } + + bool apply() { + expr *cons, *rhs, *f, *term; + bool progress = false; + m_new_vars.reset(); + TRACE("mbp_tg", tout << "Iterating over terms of tg";); + // Not resetting terms because get_terms calls resize on terms + m_tg.get_terms(terms, false); + for (unsigned i = 0; i < terms.size(); i++) { + term = terms.get(i); + if (is_seen(term)) continue; + if (m_tg.is_cgr(term)) continue; + if (is_app(term) && + m_dt_util.is_accessor(to_app(term)->get_decl()) && + is_var(to_app(term)->get_arg(0))) { + mark_seen(term); + progress = true; + rm_select(term); + continue; + } + if (is_constructor_app(term, cons, rhs)) { + mark_seen(term); + progress = true; + deconstruct_eq(cons, rhs); + continue; + } + if (m_use_mdl && m.is_not(term, f) && + is_constructor_app(f, cons, rhs)) { + mark_seen(term); + progress = true; + deconstruct_neq(cons, rhs); + continue; + } + } + return progress; + } +}; + +bool mbp_dt_tg::apply() { return m_impl->apply(); } +mbp_dt_tg::mbp_dt_tg(ast_manager &man, mbp::term_graph &tg, model &mdl, + obj_hashtable &vars_set, expr_sparse_mark &seen) { + m_impl = alloc(mbp_dt_tg::impl, man, tg, mdl, vars_set, seen); +} +void mbp_dt_tg::use_model() { m_impl->m_use_mdl = true; } +void mbp_dt_tg::get_new_vars(app_ref_vector *&t) { t = &m_impl->m_new_vars; } +family_id mbp_dt_tg::get_family_id() const { + return m_impl->m_dt_util.get_family_id(); +} +mbp_dt_tg::~mbp_dt_tg() { dealloc(m_impl); } + +} // namespace mbp diff --git a/src/qe/mbp/mbp_dt_tg.h b/src/qe/mbp/mbp_dt_tg.h new file mode 100644 index 00000000000..8ea0d8e482c --- /dev/null +++ b/src/qe/mbp/mbp_dt_tg.h @@ -0,0 +1,44 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + mbp_dt_tg.h + +Abstract: + + Apply rules for model based projection for datatypes on a term graph + +Author: + + Hari Govind V K (hgvk94) 2023-03-07 + +Revision History: + +--*/ + +#pragma once + +#include "ast/datatype_decl_plugin.h" +#include "qe/mbp/mbp_qel_util.h" +#include "qe/mbp/mbp_term_graph.h" +#include "qe/mbp/mbp_tg_plugins.h" +#include "util/obj_hashtable.h" + +namespace mbp { +class mbp_dt_tg : public mbp_tg_plugin { + struct impl; + impl *m_impl; + + public: + mbp_dt_tg(ast_manager &man, mbp::term_graph &tg, model &mdl, + obj_hashtable &vars_set, expr_sparse_mark &seen); + // iterate through all terms in m_tg and apply all datatype MBP rules once + // returns true if any rules were applied + bool apply() override; + ~mbp_dt_tg() override; + void use_model() override; + void get_new_vars(app_ref_vector *&t) override; + family_id get_family_id() const override; +}; +} // namespace mbp diff --git a/src/qe/mbp/mbp_qel.cpp b/src/qe/mbp/mbp_qel.cpp new file mode 100644 index 00000000000..5778419c196 --- /dev/null +++ b/src/qe/mbp/mbp_qel.cpp @@ -0,0 +1,233 @@ +/*++ + + Module Name: + + mbp_qel.cpp + +Abstract: + + Model Based Projection based on term graph + +Author: + + Hari Govind V K (hgvk94) 2022-07-12 + +Revision History: + + +--*/ +#include "qe/mbp/mbp_qel.h" +#include "ast/array_decl_plugin.h" +#include "ast/array_peq.h" +#include "ast/datatype_decl_plugin.h" +#include "model/model.h" +#include "qe/mbp/mbp_arrays.h" +#include "qe/mbp/mbp_arrays_tg.h" +#include "qe/mbp/mbp_basic_tg.h" +#include "qe/mbp/mbp_dt_tg.h" +#include "qe/mbp/mbp_term_graph.h" +#include "qe/mbp/mbp_tg_plugins.h" +#include "util/obj_hashtable.h" + +namespace mbp { + +class mbp_qel::impl { +private: + ast_manager &m; + array_util m_array_util; + datatype_util m_dt_util; + params_ref m_params; + mbp::term_graph m_tg; + + ptr_vector m_plugins; + + // set of non_basic variables to be projected. MBP rules are applied to + // terms containing these variables + obj_hashtable m_non_basic_vars; + + // Utilities to keep track of which terms have been processed + expr_sparse_mark m_seen; + void mark_seen(expr *t) { m_seen.mark(t); } + bool is_seen(expr *t) { return m_seen.is_marked(t); } + + bool is_non_basic(app *v) { + return m_dt_util.is_datatype(v->get_sort()) || m_array_util.is_array(v); + } + + void add_vars(mbp_tg_plugin *p, app_ref_vector &vars) { + app_ref_vector *new_vars; + p->get_new_vars(new_vars); + for (auto v : *new_vars) { + if (is_non_basic(v)) m_non_basic_vars.insert(v); + vars.push_back(v); + } + } + + // apply all plugins till saturation + void saturate(app_ref_vector &vars) { + bool progress; + do { + progress = false; + for (auto *p : m_plugins) { + if (p->apply()) { + progress = true; + add_vars(p, vars); + } + } + } + while (progress); + } + + void init(app_ref_vector &vars, expr_ref &fml, model &mdl) { + // variables to apply projection rules on + for (auto v : vars) + if (is_non_basic(v)) m_non_basic_vars.insert(v); + + // mark vars as non-ground. + m_tg.add_vars(vars); + // treat eq literals as term in the egraph + m_tg.set_explicit_eq(); + + expr_ref_vector fmls(m); + flatten_and(fml, fmls); + m_tg.add_lits(fmls); + + add_plugin(alloc(mbp_array_tg, m, m_tg, mdl, m_non_basic_vars, m_seen)); + add_plugin(alloc(mbp_dt_tg, m, m_tg, mdl, m_non_basic_vars, m_seen)); + add_plugin(alloc(mbp_basic_tg, m, m_tg, mdl, m_non_basic_vars, m_seen)); + } + + void add_plugin(mbp_tg_plugin *p) { m_plugins.push_back(p); } + + void enable_model_splitting() { + for (auto p : m_plugins) p->use_model(); + } + + mbp_tg_plugin *get_plugin(family_id fid) { + for (auto p : m_plugins) + if (p->get_family_id() == fid) + return p; + return nullptr; + } + +public: + impl(ast_manager &m, params_ref const &p) + : m(m), m_array_util(m), m_dt_util(m), m_params(p), m_tg(m) {} + + ~impl() { + std::for_each(m_plugins.begin(), m_plugins.end(), + delete_proc()); + } + + void operator()(app_ref_vector &vars, expr_ref &fml, model &mdl) { + if (vars.empty()) + return; + + init(vars, fml, mdl); + // Apply MBP rules till saturation + + // First, apply rules without splitting on model + saturate(vars); + + enable_model_splitting(); + + // Do complete mbp + saturate(vars); + + TRACE("mbp_tg", + tout << "mbp tg " << m_tg.get_lits() << " and vars " << vars;); + TRACE("mbp_tg_verbose", obj_hashtable vars_tmp; + collect_uninterp_consts(mk_and(m_tg.get_lits()), vars_tmp); + for (auto a + : vars_tmp) tout + << mk_pp(a->get_decl(), m) << "\n"; + for (auto b + : m_tg.get_lits()) tout + << expr_ref(b, m) << "\n"; + for (auto a + : vars) tout + << expr_ref(a, m) << " ";); + + // 1. Apply qe_lite to remove all c-ground variables + // 2. Collect all core variables in the output (variables used as array + // indices/values) + // 3. Re-apply qe_lite to remove non-core variables + + // Step 1. + m_tg.qel(vars, fml); + + // Step 2. + // Variables that appear as array indices or values cannot be + // eliminated if they are not c-ground. They are core variables All + // other Array/ADT variables can be eliminated, they are redundant. + obj_hashtable core_vars; + collect_selstore_vars(fml, core_vars, m); + + std::function is_red = [&](app *v) { + if (!m_dt_util.is_datatype(v->get_sort()) && + !m_array_util.is_array(v)) + return false; + return !core_vars.contains(v); + }; + expr_sparse_mark red_vars; + for (auto v : vars) + if (is_red(v)) red_vars.mark(v); + CTRACE("mbp_tg", !core_vars.empty(), tout << "vars not redundant "; + for (auto v + : core_vars) tout + << " " << app_ref(v, m); + tout << "\n";); + + std::function non_core = [&](expr *e) { + if (is_app(e) && is_partial_eq(to_app(e))) + return true; + if (m.is_ite(e) || m.is_or(e) || m.is_implies(e) || + m.is_distinct(e)) + return true; + return red_vars.is_marked(e); + }; + + // Step 3. + m_tg.qel(vars, fml, &non_core); + + CTRACE("mbp_tg", !vars.empty(), + tout << "before substitution " << fml << "\n";); + // for all remaining non-cgr bool, dt, array variables, add v = mdl(v) + expr_sparse_mark s_vars; + for (auto v : vars) { + if (m_dt_util.is_datatype(v->get_sort()) || + m_array_util.is_array(v) || m.is_bool(v)) { + CTRACE("mbp_tg", + m_array_util.is_array(v) || + m_dt_util.is_datatype(v->get_sort()), + tout << "Could not eliminate " << v->get_name() + << "\n";); + s_vars.mark(v); + m_tg.add_eq(v, mdl(v)); + } + } + + std::function substituted = [&](expr *e) { + return + (is_app(e) && is_partial_eq(to_app(e))) || + m.is_ite(e) || + red_vars.is_marked(e) || + s_vars.is_marked(e); + }; + + // remove all substituted variables + m_tg.qel(vars, fml, &substituted); + } +}; + +mbp_qel::mbp_qel(ast_manager &m, params_ref const &p) { + m_impl = alloc(impl, m, p); +} + +mbp_qel::~mbp_qel() { dealloc(m_impl); } + +void mbp_qel::operator()(app_ref_vector &vars, expr_ref &fml, model &mdl) { + (*m_impl)(vars, fml, mdl); +} + +} // namespace mbp diff --git a/src/qe/mbp/mbp_qel.h b/src/qe/mbp/mbp_qel.h new file mode 100644 index 00000000000..6246ec01f1b --- /dev/null +++ b/src/qe/mbp/mbp_qel.h @@ -0,0 +1,41 @@ +/*++ + + Module Name: + + mbp_qel.h + +Abstract: + + Model Based Projection based on term graph + +Author: + + Hari Govind V K (hgvk94) 2022-07-12 + +Revision History: + + +--*/ + +#pragma once + +#include "ast/ast.h" +#include "model/model.h" +#include "util/params.h" + +namespace mbp { +class mbp_qel { + class impl; + impl *m_impl; + + public: + mbp_qel(ast_manager &m, params_ref const &p); + + ~mbp_qel(); + + /** + Do model based projection + */ + void operator()(app_ref_vector &vars, expr_ref &fml, model &mdl); +}; +} // namespace mbp diff --git a/src/qe/mbp/mbp_qel_util.cpp b/src/qe/mbp/mbp_qel_util.cpp new file mode 100644 index 00000000000..e6537a1d0bf --- /dev/null +++ b/src/qe/mbp/mbp_qel_util.cpp @@ -0,0 +1,110 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + mbp_qel_util.h + +Abstract: + + Utility methods for mbp_qel + +Author: + + Hari Govind V K (hgvk94) 2023-03-07 + +Revision History: + +--*/ + +#include "qe/mbp/mbp_qel_util.h" +#include "ast/array_decl_plugin.h" +#include "ast/ast.h" +#include "ast/ast_pp.h" +#include "ast/ast_util.h" +#include "ast/expr_functors.h" +#include "ast/for_each_expr.h" +#include "ast/pp.h" + +class check_uninterp_consts : public i_expr_pred { + obj_hashtable const &m_vars; + family_id m_fid; + decl_kind m_decl_kind; + + public: + check_uninterp_consts(obj_hashtable const &vars, ast_manager &man, + family_id fid = null_family_id, + decl_kind dk = null_decl_kind) + : m_vars(vars), m_fid(fid), m_decl_kind(dk) {} + bool operator()(expr *n) override { + return (is_app(n) && is_uninterp_const(n) && + m_vars.contains(to_app(n))) && + ((m_fid == null_family_id || m_decl_kind == null_decl_kind) || + (is_sort_of(to_app(n)->get_sort(), m_fid, m_decl_kind))); + } +}; + +// check if e contains any apps from vars +// if fid and dk are not null, check if the variable is of desired sort +bool contains_vars(expr *e, obj_hashtable const &vars, ast_manager &man, + family_id fid, decl_kind dk) { + check_uninterp_consts pred(vars, man, fid, dk); + check_pred check(pred, man, false); + return check(e); +} + +app_ref new_var(sort *s, ast_manager &m) { + return app_ref(m.mk_fresh_const("mbptg", s), m); +} + +namespace collect_uninterp_consts_ns { +struct proc { + obj_hashtable &m_out; + proc(obj_hashtable &out) : m_out(out) {} + void operator()(expr *n) const {} + void operator()(app *n) { + if (is_uninterp_const(n)) m_out.insert(n); + } +}; +} // namespace collect_uninterp_consts_ns + +// Return all uninterpreted constants of \p q +void collect_uninterp_consts(expr *e, obj_hashtable &out) { + collect_uninterp_consts_ns::proc proc(out); + for_each_expr(proc, e); +} + +namespace collect_selstore_vars_ns { +struct proc { + ast_manager &m; + obj_hashtable &m_vars; + array_util m_array_util; + datatype_util m_dt_util; + proc(obj_hashtable &vars, ast_manager &man) + : m(man), m_vars(vars), m_array_util(m), m_dt_util(m) {} + void operator()(expr *n) const {} + void operator()(app *n) { + if (m_array_util.is_select(n)) { + expr *idx = n->get_arg(1); + if (is_app(idx) && m_dt_util.is_accessor(to_app(idx)->get_decl())) + return; + collect_uninterp_consts(idx, m_vars); + } else if (m_array_util.is_store(n)) { + expr *idx = n->get_arg(1), *elem = n->get_arg(2); + if (!(is_app(idx) && + m_dt_util.is_accessor(to_app(idx)->get_decl()))) + collect_uninterp_consts(idx, m_vars); + if (!(is_app(elem) && + m_dt_util.is_accessor(to_app(elem)->get_decl()))) + collect_uninterp_consts(elem, m_vars); + } + } +}; +} // namespace collect_selstore_vars_ns + +// collect all uninterpreted consts used as array indices or values +void collect_selstore_vars(expr *fml, obj_hashtable &vars, + ast_manager &man) { + collect_selstore_vars_ns::proc proc(vars, man); + quick_for_each_expr(proc, fml); +} diff --git a/src/qe/mbp/mbp_qel_util.h b/src/qe/mbp/mbp_qel_util.h new file mode 100644 index 00000000000..dd1a7795ace --- /dev/null +++ b/src/qe/mbp/mbp_qel_util.h @@ -0,0 +1,41 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + mbp_qel_util.h + +Abstract: + + Utility methods for mbp_qel + +Author: + + Hari Govind V K (hgvk94) 2023-03-07 + +Revision History: + +--*/ + +#pragma once + +#include "ast/array_decl_plugin.h" +#include "ast/ast.h" +#include "ast/ast_pp.h" +#include "ast/ast_util.h" +#include "ast/for_each_expr.h" +#include "ast/pp.h" + +// check if e contains any apps from vars +// if fid and dk are not null, check if the variable is of desired sort +bool contains_vars(expr *e, obj_hashtable const &vars, ast_manager &man, + family_id fid = null_family_id, + decl_kind dk = null_decl_kind); + +app_ref new_var(sort *s, ast_manager &m); + +// Return all uninterpreted constants of \p q +void collect_uninterp_consts(expr *e, obj_hashtable &out); +// collect all uninterpreted consts used as array indices or values +void collect_selstore_vars(expr *fml, obj_hashtable &vars, + ast_manager &man); diff --git a/src/qe/mbp/mbp_solve_plugin.cpp b/src/qe/mbp/mbp_solve_plugin.cpp index d80a3665d15..3820af59e12 100644 --- a/src/qe/mbp/mbp_solve_plugin.cpp +++ b/src/qe/mbp/mbp_solve_plugin.cpp @@ -248,30 +248,80 @@ namespace mbp { return false; } - expr_ref solve(expr* atom, bool is_pos) override { - expr *e1, *e2; + // returns `true` if a rewriting happened + bool try_int_mul_solve(expr *atom, bool is_pos, expr_ref &res) { - expr_ref res(atom, m); - if (m.is_eq (atom, e1, e2)) { - expr_ref v(m), t(m); - v = e1; t = e2; - // -- attempt to solve using arithmetic - solve(e1, e2, v, t); - // -- normalize equality - res = mk_eq_core(v, t); - } - else if (a.is_le(atom, e1, e2)) { - mk_le_core(e1, e2, res); - } - else if (a.is_ge(atom, e1, e2)) { - mk_ge_core(e1, e2, res); + if (!is_pos) + return false; // negation of multiplication is not a cube for + // integers + + // we want k*y == x -----> y = x div k && x mod k == 0 + expr *lhs = nullptr, *rhs = nullptr; + if (!m.is_eq(atom, lhs, rhs)) return false; + + if (!a.is_int(lhs)) { return false; } + + if (!a.is_mul(rhs)) { + if (a.is_mul(lhs)) + std::swap(lhs, rhs); + else + return false; // no muls } - // restore negation - if (!is_pos) { - res = mk_not(m, res); + // of the form v = k*expr + expr *first = nullptr, *second = nullptr; + if (!a.is_mul(rhs, first, second)) return false; + + if (!(is_app(first) && a.plugin().is_value(to_app(first)))) { + if (is_app(second) && a.plugin().is_value(to_app(second))) { + std::swap(first, second); + } else { + return false; + } } - return res; + + if (a.is_zero(first)) { + // SASSERT(a.is_int(lhs)); + res = m.mk_eq(lhs, a.mk_int(0)); + return true; + }; + + // `first` is a value, different from 0 + res = m.mk_and(m.mk_eq(second, a.mk_idiv(lhs, first)), + m.mk_eq(a.mk_int(0), a.mk_mod(lhs, first))); + + return true; + } + + expr_ref solve(expr* atom, bool is_pos) override { + + expr_ref res(atom, m); + + if (try_int_mul_solve(atom, is_pos, res)) return res; + + expr *e1, *e2; + + if (m.is_eq(atom, e1, e2)) { + expr_ref v(m), t(m); + v = e1; + t = e2; + // -- attempt to solve using arithmetic + solve(e1, e2, v, t); + // -- normalize equality + res = mk_eq_core(v, t); + } + else if (a.is_le(atom, e1, e2)) { + mk_le_core(e1, e2, res); + } + else if (a.is_ge(atom, e1, e2)) { + mk_ge_core(e1, e2, res); + } + + // restore negation + if (!is_pos) { + res = mk_not(m, res); + } + return res; } }; diff --git a/src/qe/mbp/mbp_term_graph.cpp b/src/qe/mbp/mbp_term_graph.cpp index 624febffd79..e5e5309f357 100644 --- a/src/qe/mbp/mbp_term_graph.cpp +++ b/src/qe/mbp/mbp_term_graph.cpp @@ -12,1291 +12,1859 @@ Module Name: Author: Arie Gurfinkel + Hari Govind V K (hgvk94) + Isabel Garcia (igcontreras) + +Revision History: + + Added implementation of qe_lite using term graph Notes: --*/ -#include "util/util.h" -#include "util/uint_set.h" -#include "util/obj_pair_hashtable.h" +#include "qe/mbp/mbp_term_graph.h" +#include "ast/array_peq.h" +#include "ast/ast.h" #include "ast/ast_pp.h" #include "ast/ast_util.h" #include "ast/for_each_expr.h" #include "ast/occurs.h" #include "ast/rewriter/th_rewriter.h" #include "model/model_evaluator.h" -#include "qe/mbp/mbp_term_graph.h" +#include "qe/mbp/mbp_arrays.h" +#include "util/bit_vector.h" +#include "util/obj_pair_hashtable.h" +#include "util/uint_set.h" +#include "util/util.h" namespace mbp { - static expr_ref mk_neq(ast_manager &m, expr *e1, expr *e2) { - expr *t = nullptr; - // x != !x == true - if ((m.is_not(e1, t) && t == e2) || (m.is_not(e2, t) && t == e1)) - return expr_ref(m.mk_true(), m); - else if (m.are_distinct(e1, e2)) - return expr_ref(m.mk_true(), m); - return expr_ref(m.mk_not(m.mk_eq(e1, e2)), m); - } +static expr_ref mk_neq(ast_manager &m, expr *e1, expr *e2) { + expr *t = nullptr; + // x != !x == true + if ((m.is_not(e1, t) && t == e2) || (m.is_not(e2, t) && t == e1)) + return expr_ref(m.mk_true(), m); + else if (m.are_distinct(e1, e2)) + return expr_ref(m.mk_true(), m); + return expr_ref(m.mk_not(m.mk_eq(e1, e2)), m); +} - namespace { - struct sort_lt_proc { - bool operator()(const expr* a, const expr *b) const { - return a->get_sort()->get_id() < b->get_sort()->get_id(); - } - }; +namespace { +struct sort_lt_proc { // for representatives in model_complete + bool operator()(const expr *a, const expr *b) const { + return a->get_sort()->get_id() < b->get_sort()->get_id(); } - - namespace is_pure_ns { - struct found{}; - struct proc { - is_variable_proc &m_is_var; - proc(is_variable_proc &is_var) : m_is_var(is_var) {} - void operator()(var *n) const {if (m_is_var(n)) throw found();} - void operator()(app const *n) const {if (m_is_var(n)) throw found();} - void operator()(quantifier *n) const {} - }; +}; +struct mark_all_sub_expr { + expr_sparse_mark &m_mark; + mark_all_sub_expr(expr_sparse_mark &mark) : m_mark(mark) {} + void operator()(var *n) const {} + void operator()(app *n) const { m_mark.mark(n); } + void operator()(quantifier *n) const {} +}; +} // namespace + +namespace is_pure_ns { +struct found {}; +struct proc { + is_variable_proc &m_is_var; + proc(is_variable_proc &is_var) : m_is_var(is_var) {} + void operator()(var *n) const { + if (m_is_var(n)) throw found(); } - - bool is_pure(is_variable_proc &is_var, expr *e) { - try { - is_pure_ns::proc v(is_var); - quick_for_each_expr(v, e); - } - catch (const is_pure_ns::found &) { - return false; - } - return true; + void operator()(app const *n) const { + if (m_is_var(n)) throw found(); } + void operator()(quantifier *n) const {} +}; +} // namespace is_pure_ns + +bool is_pure(is_variable_proc &is_var, expr *e) { + try { + is_pure_ns::proc v(is_var); + quick_for_each_expr(v, e); + } catch (const is_pure_ns::found &) { return false; } + return true; +} + +bool term_graph::is_ground(expr *e) { + try { + is_ground_ns::proc v(m_is_var); + quick_for_each_expr(v, e); + } catch (const is_ground_ns::found &) { return false; } + return true; +} - class term { - // -- an app represented by this term - expr_ref m_expr; // NSB: to make usable with exprs - // -- root of the equivalence class - term* m_root; - // -- next element in the equivalence class (cyclic linked list) - term* m_next; +class term { + // -- an app represented by this term + expr_ref m_expr; // NSB: to make usable with exprs + // -- root of the equivalence class + term *m_root; + // -- representative of the equivalence class + term *m_repr; + // -- next element in the equivalence class (cyclic linked list) + term *m_next; + // -- general purpose mark + unsigned m_mark : 1; + // -- general purpose second mark + unsigned m_mark2 : 1; + // -- is an interpreted constant + unsigned m_interpreted : 1; + // caches whether m_expr is an equality + unsigned m_is_eq : 1; + // caches whether m_expr is an inequality + unsigned m_is_neq : 1; + // caches whether m_expr is a distinct + unsigned m_is_distinct : 1; + // caches whether m_expr is a partial equality + unsigned m_is_peq : 1; + // caches whether m_expr is the child of not + unsigned m_is_neq_child : 1; + // caches whether m_expr is peq and the child of not + unsigned m_is_npeq_child : 1; + + // -- the term is a compound term can be rewritten to be ground or it is a + // ground constant + unsigned m_cgr : 1; + // -- the term is ground + unsigned m_gr : 1; + + // -- terms that contain this term as a child (only maintained for root + // nodes) + ptr_vector m_parents; + + // arguments of term. + ptr_vector m_children; + + struct class_props { + // TODO: parents should be here + // -- the class has a ground representative + unsigned m_gr_class : 1; // -- eq class size unsigned m_class_size; - // -- general purpose mark - unsigned m_mark:1; - // -- general purpose second mark - unsigned m_mark2:1; - // -- is an interpreted constant - unsigned m_interpreted:1; - - // -- terms that contain this term as a child - ptr_vector m_parents; - - // arguments of term. - ptr_vector m_children; - - public: - term(expr_ref const& v, u_map& app2term) : - m_expr(v), - m_root(this), - m_next(this), - m_class_size(1), - m_mark(false), - m_mark2(false), - m_interpreted(false) { - if (!is_app(m_expr)) return; - for (expr* e : *to_app(m_expr)) { - term* t = app2term[e->get_id()]; - t->get_root().m_parents.push_back(this); - m_children.push_back(t); - } + // -- disequality sets that the class belongs to + term_graph::deqs m_deqs; + + class_props() : m_gr_class(0), m_class_size(1) {} + void merge(class_props &b) { + m_class_size += b.m_class_size; + m_gr_class |= b.m_gr_class; + m_deqs |= b.m_deqs; // merge disequalities + // -- reset (useful for debugging) + b.m_class_size = 0; + b.m_gr_class = false; + b.m_deqs.reset(); + } + void transfer(class_props &b) { + // TODO replace by std::swap of the whole struct? + m_class_size = b.m_class_size; + b.m_class_size = 0; + std::swap(m_deqs, b.m_deqs); + m_gr_class = b.m_gr_class; + b.m_gr_class = false; } + }; + class_props m_class_props; + + public: + term(expr_ref const &v, u_map &app2term) + : m_expr(v), m_root(this), m_repr(nullptr), m_next(this), m_mark(false), + m_mark2(false), m_interpreted(false), + m_is_eq(m_expr.get_manager().is_eq(m_expr)), m_is_peq(false), + m_is_neq_child(false), m_is_npeq_child(false), + m_cgr(0), m_gr(0) { + m_is_neq = m_expr.get_manager().is_not(m_expr) && + m_expr.get_manager().is_eq(to_app(m_expr)->get_arg(0)); + m_is_distinct = m_expr.get_manager().is_distinct(m_expr); + m_children.reset(); + if (!is_app(m_expr)) return; + for (expr *e : *to_app(m_expr)) { + term *t = app2term[e->get_id()]; + t->get_root().m_parents.push_back(this); + m_children.push_back(t); + } + m_is_peq = is_partial_eq(to_app(m_expr)); + } - ~term() {} + ~term() {} - class parents { - term const& t; - public: - parents(term const& _t):t(_t) {} - parents(term const* _t):t(*_t) {} - ptr_vector::const_iterator begin() const { return t.m_parents.begin(); } - ptr_vector::const_iterator end() const { return t.m_parents.end(); } - }; + class parents { + term const &t; - class children { - term const& t; - public: - children(term const& _t):t(_t) {} - children(term const* _t):t(*_t) {} - ptr_vector::const_iterator begin() const { return t.m_children.begin(); } - ptr_vector::const_iterator end() const { return t.m_children.end(); } - }; - - // Congruence table hash function is based on - // roots of children and function declaration. - - unsigned get_hash() const { - unsigned a, b, c; - a = b = c = get_decl_id(); - for (term * ch : children(this)) { - a = ch->get_root().get_id(); - mix(a, b, c); - } - return c; + public: + parents(term const &_t) : t(_t) {} + parents(term const *_t) : t(*_t) {} + ptr_vector::const_iterator begin() const { + return t.m_parents.begin(); } - - static bool cg_eq(term const * t1, term const * t2) { - if (t1->get_decl_id() != t2->get_decl_id()) return false; - if (t1->m_children.size() != t2->m_children.size()) return false; - for (unsigned i = 0, sz = t1->m_children.size(); i < sz; ++ i) { - if (t1->m_children[i]->get_root().get_id() != t2->m_children[i]->get_root().get_id()) return false; - } - return true; + ptr_vector::const_iterator end() const { + return t.m_parents.end(); } + }; - unsigned get_id() const { return m_expr->get_id();} + class children { + term const &t; - unsigned get_decl_id() const { return is_app(m_expr) ? to_app(m_expr)->get_decl()->get_id() : m_expr->get_id(); } + public: + children(term const &_t) : t(_t) {} + children(term const *_t) : t(*_t) {} + ptr_vector::const_iterator begin() const { + return t.m_children.begin(); + } + ptr_vector::const_iterator end() const { + return t.m_children.end(); + } + }; - bool is_marked() const {return m_mark;} - void set_mark(bool v){m_mark = v;} - bool is_marked2() const {return m_mark2;} // NSB: where is this used? - void set_mark2(bool v){m_mark2 = v;} // NSB: where is this used? + // Congruence table hash function is based on + // roots of children and function declaration. - bool is_interpreted() const {return m_interpreted;} - bool is_theory() const { return !is_app(m_expr) || to_app(m_expr)->get_family_id() != null_family_id; } - void mark_as_interpreted() {m_interpreted=true;} - expr* get_expr() const {return m_expr;} - unsigned get_num_args() const { return is_app(m_expr) ? to_app(m_expr)->get_num_args() : 0; } + unsigned get_hash() const { + unsigned a, b, c; + a = b = c = get_decl_id(); + for (term *ch : children(this)) { + a = ch->get_root().get_id(); + mix(a, b, c); + } + return c; + } - term &get_root() const {return *m_root;} - bool is_root() const {return m_root == this;} - void set_root(term &r) {m_root = &r;} - term &get_next() const {return *m_next;} - void add_parent(term* p) { m_parents.push_back(p); } + static bool cg_eq(term const *t1, term const *t2) { + if (t1->get_decl_id() != t2->get_decl_id()) return false; + if (t1->m_children.size() != t2->m_children.size()) return false; + for (unsigned i = 0, sz = t1->m_children.size(); i < sz; ++i) { + if (t1->m_children[i]->get_root().get_id() != + t2->m_children[i]->get_root().get_id()) + return false; + } + return true; + } - unsigned get_class_size() const {return m_class_size;} + unsigned deg() const { return m_children.size(); } + unsigned get_id() const { return m_expr->get_id(); } + bool is_eq_or_neq() const { return m_is_eq || m_is_neq || m_is_distinct; } + bool is_eq_or_peq() const { return m_is_eq || m_is_peq; } + bool is_neq() const { return m_is_neq; } + void set_neq_child() { m_is_neq_child = true; } + void set_npeq_child() { m_is_npeq_child = true; } + bool is_neq_child() const { return m_is_neq_child; } + bool is_npeq_child() const { return m_is_npeq_child; } + unsigned get_decl_id() const { + return is_app(m_expr) ? to_app(m_expr)->get_decl()->get_id() + : m_expr->get_id(); + } - void merge_eq_class(term &b) { - std::swap(this->m_next, b.m_next); - m_class_size += b.get_class_size(); - // -- reset (useful for debugging) - b.m_class_size = 0; - } + bool is_marked() const { return m_mark; } + void set_mark(bool v) { m_mark = v; } + bool is_marked2() const { return m_mark2; } // NSB: where is this used? + void set_mark2(bool v) { m_mark2 = v; } // NSB: where is this used? - // -- make this term the root of its equivalence class - void mk_root() { - if (is_root()) return; - - term *curr = this; - do { - if (curr->is_root()) { - // found previous root - SASSERT(curr != this); - m_class_size = curr->get_class_size(); - curr->m_class_size = 0; - } - curr->set_root(*this); - curr = &curr->get_next(); - } - while (curr != this); - } + bool is_cgr() const { return m_cgr; } + void set_cgr(bool v) { m_cgr = v; } - std::ostream& display(std::ostream& out) const { - out << get_id() << ": " << m_expr - << (is_root() ? " R" : "") << " - "; - term const* r = &this->get_next(); - while (r != this) { - out << r->get_id() << " "; - r = &r->get_next(); - } - out << "\n"; - return out; - } - }; + bool is_gr() const { return m_gr; } + void set_gr(bool v) { m_gr = v; } - static std::ostream& operator<<(std::ostream& out, term const& t) { - return t.display(out); + bool is_class_gr_root() const { + SASSERT(is_root()); + return m_class_props.m_gr_class; + } + void set_class_gr_root(bool v) { + SASSERT(is_root()); + m_class_props.m_gr_class = v; } + bool is_class_gr() const { return m_root->is_class_gr_root(); } + void set_class_gr(bool v) { m_root->set_class_gr_root(v); } - bool term_graph::is_variable_proc::operator()(const expr * e) const { - if (!is_app(e)) return false; - const app *a = ::to_app(e); - TRACE("qe_verbose", tout << a->get_family_id() << " " << m_solved.contains(a->get_decl()) << " " << m_decls.contains(a->get_decl()) << "\n";); - return - a->get_family_id() == null_family_id && - !m_solved.contains(a->get_decl()) && - m_exclude == m_decls.contains(a->get_decl()); + static bool are_deq(const term &t1, const term &t2) { + term_graph::deqs const &ds1 = t1.get_root().get_deqs(); + term_graph::deqs const &ds2 = t2.get_root().get_deqs(); + + term_graph::deqs tmp(ds1); // copy + + tmp &= ds2; + return tmp != 0; } - bool term_graph::is_variable_proc::operator()(const term &t) const { - return (*this)(t.get_expr()); + static void set_deq(term_graph::deqs &ds, unsigned idx) { + ds.resize(idx + 1); + ds.set(idx); } - void term_graph::is_variable_proc::set_decls(const func_decl_ref_vector &decls, bool exclude) { - reset(); - m_exclude = exclude; - for (auto *d : decls) m_decls.insert(d); + bool all_children_ground() { + SASSERT(deg() != 0); + return all_of(m_children, + [&](const term *t) { return t->is_class_gr(); }); } - void term_graph::is_variable_proc::mark_solved(const expr *e) { - if ((*this)(e) && is_app(e)) - m_solved.insert(::to_app(e)->get_decl()); + + void set_mark2_terms_class(bool v) { // TODO: remove + if (is_marked2()) return; + term *curr = this; + do { + curr->set_mark2(v); + curr = &curr->get_next(); + } + while (curr != this); } + bool is_interpreted() const { return m_interpreted; } + bool is_theory() const { + return !is_app(m_expr) || + to_app(m_expr)->get_family_id() != null_family_id; + } + void mark_as_interpreted() { m_interpreted = true; } + expr *get_expr() const { return m_expr; } + unsigned get_num_args() const { + return is_app(m_expr) ? to_app(m_expr)->get_num_args() : 0; + } - unsigned term_graph::term_hash::operator()(term const* t) const { return t->get_hash(); } + term &get_root() const { return *m_root; } + bool is_root() const { return m_root == this; } + void set_root(term &r) { m_root = &r; } + term *get_repr() const { return m_repr; } + bool is_repr() const { return m_repr == this; } + void set_repr(term *t) { + SASSERT(get_root().get_id() == t->get_root().get_id()); + m_repr = t; + } + void reset_repr() { m_repr = nullptr; } + term &get_next() const { return *m_next; } + void add_parent(term *p) { m_parents.push_back(p); } - bool term_graph::term_eq::operator()(term const* a, term const* b) const { return term::cg_eq(a, b); } + unsigned get_class_size() const { return m_class_props.m_class_size; } - term_graph::term_graph(ast_manager &man) : m(man), m_lits(m), m_pinned(m), m_projector(nullptr) { - m_plugins.register_plugin(mbp::mk_basic_solve_plugin(m, m_is_var)); - m_plugins.register_plugin(mbp::mk_arith_solve_plugin(m, m_is_var)); + void merge_eq_class(term &b) { + std::swap(this->m_next, b.m_next); + m_class_props.merge(b.m_class_props); } - term_graph::~term_graph() { - dealloc(m_projector); - reset(); + // -- make this term the repr of its equivalence class + void mk_repr() { + term *curr = this; + do { + curr->set_repr(this); + curr = &curr->get_next(); + } + while (curr != this); } - bool term_graph::is_pure_def(expr *atom, expr*& v) { - expr *e = nullptr; - return m.is_eq(atom, v, e) && m_is_var(v) && is_pure(m_is_var, e); + std::ostream &display(std::ostream &out) const { + out << get_id() << ": " << m_expr << (is_repr() ? " R" : "") + << (is_gr() ? " G" : "") << (is_class_gr() ? " clsG" : "") + << (is_cgr() ? " CG" : "") << " deg:" << deg() << " - "; + term const *r = &this->get_next(); + while (r != this) { + out << r->get_id() << " " << (r->is_cgr() ? " CG" : "") << " "; + r = &r->get_next(); + } + out << "\n"; + return out; } - static family_id get_family_id(ast_manager &m, expr *lit) { - if (m.is_not(lit, lit)) - return get_family_id(m, lit); + term_graph::deqs &get_deqs() { return m_class_props.m_deqs; } +}; - expr *a = nullptr, *b = nullptr; - // deal with equality using sort of range - if (m.is_eq (lit, a, b)) { - return a->get_sort()->get_family_id(); - } - // extract family_id of top level app - else if (is_app(lit)) { - return to_app(lit)->get_decl()->get_family_id(); +static std::ostream &operator<<(std::ostream &out, term const &t) { + return t.display(out); +} + + +// t1 != t2 +void term_graph::add_deq_proc::operator()(term *t1, term *t2) { + term::set_deq(t1->get_root().get_deqs(), m_deq_cnt); + term::set_deq(t2->get_root().get_deqs(), m_deq_cnt); + inc_count(); +} + +// distinct(ts) +void term_graph::add_deq_proc::operator()(ptr_vector &ts) { + for (auto t : ts) + term::set_deq(t->get_root().get_deqs(), m_deq_cnt); + inc_count(); +} + +void term_graph::add_deq_proc::inc_count() { + m_deq_cnt++; + if (m_deq_cnt == 0) + throw default_exception("unexpected wrap-around on m_deq_cnt"); +} + +bool term_graph::is_variable_proc::operator()(const expr *e) const { + if (!is_app(e)) return false; + const app *a = ::to_app(e); + TRACE("qe_verbose", tout << a->get_family_id() << " " + << m_solved.contains(a->get_decl()) << " " + << m_decls.contains(a->get_decl()) << "\n";); + return a->get_family_id() == null_family_id && + !m_solved.contains(a->get_decl()) && + m_exclude == m_decls.contains(a->get_decl()); +} + +bool term_graph::is_variable_proc::operator()(const term &t) const { + return (*this)(t.get_expr()); +} + +void term_graph::is_variable_proc::set_decls(const func_decl_ref_vector &decls, + bool exclude) { + reset(); + m_exclude = exclude; + for (auto *d : decls) m_decls.insert(d); +} + +void term_graph::is_variable_proc::add_decls(const app_ref_vector &decls) { + for (auto *d : decls) m_decls.insert(d->get_decl()); +} + +void term_graph::is_variable_proc::add_decl(app *d) { + m_decls.insert(d->get_decl()); +} + +void term_graph::is_variable_proc::set_decls(const app_ref_vector &vars, + bool exclude) { + reset(); + m_exclude = exclude; + for (auto *v : vars) m_decls.insert(v->get_decl()); +} + +void term_graph::is_variable_proc::mark_solved(const expr *e) { + if ((*this)(e) && is_app(e)) m_solved.insert(::to_app(e)->get_decl()); +} + +unsigned term_graph::term_hash::operator()(term const *t) const { + return t->get_hash(); +} + +bool term_graph::term_eq::operator()(term const *a, term const *b) const { + return term::cg_eq(a, b); +} + +term_graph::term_graph(ast_manager &man) + : m(man), m_lits(m), m_pinned(m) { + m_is_var.reset(); + m_plugins.register_plugin(mbp::mk_basic_solve_plugin(m, m_is_var)); + m_plugins.register_plugin(mbp::mk_arith_solve_plugin(m, m_is_var)); +} + +term_graph::~term_graph() { + dealloc(m_projector); + reset(); +} + +bool term_graph::is_pure_def(expr *atom, expr *&v) { + expr *e = nullptr; + return m.is_eq(atom, v, e) && m_is_var(v) && is_pure(m_is_var, e); +} + +static family_id get_family_id(ast_manager &m, expr *lit) { + if (m.is_not(lit, lit)) + return get_family_id(m, lit); + + expr *a = nullptr, *b = nullptr; + if (m.is_eq(lit, a, b)) // deal with equality using sort of range + return a->get_sort()->get_family_id(); + else if (is_app(lit)) // extract family_id of top level app + return to_app(lit)->get_decl()->get_family_id(); + else + return null_family_id; +} + +void term_graph::add_lit(expr *l) { + expr_ref lit(m); + expr_ref_vector lits(m); + lits.push_back(l); + for (unsigned i = 0; i < lits.size(); ++i) { + l = lits.get(i); + family_id fid = get_family_id(m, l); + mbp::solve_plugin *pin = m_plugins.get_plugin(fid); + lit = pin ? (*pin)(l) : l; + if (m.is_and(lit)) { + lits.append(::to_app(lit)->get_num_args(), + ::to_app(lit)->get_args()); } else { - return null_family_id; - } - } - void term_graph::add_lit(expr *l) { - expr_ref lit(m); - expr_ref_vector lits(m); - lits.push_back(l); - for (unsigned i = 0; i < lits.size(); ++i) { - l = lits.get(i); - family_id fid = get_family_id(m, l); - mbp::solve_plugin *pin = m_plugins.get_plugin(fid); - lit = pin ? (*pin)(l) : l; - if (m.is_and(lit)) { - lits.append(::to_app(lit)->get_num_args(), ::to_app(lit)->get_args()); - } - else { - m_lits.push_back(lit); - internalize_lit(lit); - } + m_lits.push_back(lit); + internalize_lit(lit); } } +} - bool term_graph::is_internalized(expr *a) { - return m_app2term.contains(a->get_id()); +// collect expressions of all terms in the term graph +// optionally, exclude constructively ground nodes that are not equalities +// overwrites res +void term_graph::get_terms(expr_ref_vector &res, bool exclude_cground) { + std::function fil = nullptr; + if (exclude_cground) { + fil = [](term *t) { + return !t->is_neq_child() && !t->is_npeq_child() && (t->is_eq_or_peq() || !t->is_cgr()); + }; } - - term* term_graph::get_term(expr *a) { - term *res; - return m_app2term.find (a->get_id(), res) ? res : nullptr; + else { + fil = [](term *t) { return !t->is_neq_child() && !t->is_npeq_child(); }; } + auto terms = m_terms.filter_pure(fil); + res.resize(terms.size()); + unsigned i = 0; + for (term *t : terms) + res[i++] = t->get_expr(); +} + +bool term_graph::is_cgr(expr *e) { + if (!is_internalized(e)) return false; + term *t = get_term(e); + return (!t->is_eq_or_peq() && t->is_cgr()); +} - term *term_graph::mk_term(expr *a) { - expr_ref e(a, m); - term * t = alloc(term, e, m_app2term); - if (t->get_num_args() == 0 && m.is_unique_value(a)) - t->mark_as_interpreted(); +bool term_graph::is_internalized(expr *a) { + return m_app2term.contains(a->get_id()); +} + +term *term_graph::get_term(expr *a) { + term *res; + return m_app2term.find(a->get_id(), res) ? res : nullptr; +} - m_terms.push_back(t); - m_app2term.insert(a->get_id(), t); - return t; +term *term_graph::mk_term(expr *a) { + expr_ref e(a, m); + term *t = alloc(term, e, m_app2term); + if (is_ground(a)) { + t->set_gr(true); + t->set_cgr(true); + t->set_class_gr(true); } + else if (t->deg() > 0 && t->all_children_ground()) { + t->set_cgr(true); + t->set_class_gr(true); + } + if (t->get_num_args() == 0 && m.is_unique_value(a)) + t->mark_as_interpreted(); - term* term_graph::internalize_term(expr *t) { - term* res = get_term(t); - if (res) return res; - ptr_buffer todo; - todo.push_back(t); - while (!todo.empty()) { - t = todo.back(); - res = get_term(t); - if (res) { - todo.pop_back(); - continue; - } - unsigned sz = todo.size(); - if (is_app(t)) { - for (expr * arg : *::to_app(t)) { - if (!get_term(arg)) - todo.push_back(arg); - } - } - if (sz < todo.size()) continue; + m_terms.push_back(t); + m_app2term.insert(a->get_id(), t); + return t; +} + +term *term_graph::internalize_term(expr *t) { + term *res = get_term(t); + if (res) return res; + ptr_buffer todo; + todo.push_back(t); + while (!todo.empty()) { + t = todo.back(); + res = get_term(t); + if (res) { todo.pop_back(); - res = mk_term(t); - } - SASSERT(res); - return res; + continue; + } + unsigned sz = todo.size(); + if (is_app(t)) + for (expr *arg : *::to_app(t)) + if (!get_term(arg)) + todo.push_back(arg); + + if (sz < todo.size()) + continue; + todo.pop_back(); + res = mk_term(t); + + // the term was not internalized in this syntactic form, but it + // could be congruent with some other term, if that is the case, we + // need to merge them. + term *res_old = m_cg_table.insert_if_not_there(res); + if (res->is_cgr()) + res_old->set_cgr(true); + SASSERT(res_old->is_cgr() == res->is_cgr()); + if (res_old->get_root().get_id() != res->get_root().get_id()) + m_merge.push_back({res, res_old}); } - - void term_graph::internalize_eq(expr *a1, expr* a2) { - SASSERT(m_merge.empty()); - merge(*internalize_term(a1), *internalize_term(a2)); - merge_flush(); - SASSERT(m_merge.empty()); + merge_flush(); + SASSERT(res); + if (m.is_not(t) && is_app(to_app(t)->get_arg(0)) && is_partial_eq(to_app(to_app(t)->get_arg(0)))) { + term* p = get_term(to_app(t)->get_arg(0)); + SASSERT(p); + p->set_npeq_child(); } + return res; +} - void term_graph::internalize_lit(expr* lit) { - expr *e1 = nullptr, *e2 = nullptr, *v = nullptr; - if (m.is_eq (lit, e1, e2)) { - internalize_eq (e1, e2); - } - else { - internalize_term(lit); - } - if (is_pure_def(lit, v)) { - m_is_var.mark_solved(v); - } - } +void term_graph::internalize_eq(expr *a1, expr *a2) { + SASSERT(m_merge.empty()); + merge(*internalize_term(a1), *internalize_term(a2)); + merge_flush(); + SASSERT(m_merge.empty()); + if (!m_explicit_eq) + return; + expr_ref eq(m.mk_eq(a1, a2), m); + term *res = get_term(eq); + if (!res) + mk_term(eq); +} - void term_graph::merge_flush() { - while (!m_merge.empty()) { - term* t1 = m_merge.back().first; - term* t2 = m_merge.back().second; - m_merge.pop_back(); - merge(*t1, *t2); - } +void term_graph::internalize_distinct(expr *d) { + app *a = to_app(d); + ptr_vector ts(a->get_decl()->get_arity()); + auto tsit = ts.begin(); + for (auto arg : *a) { + *tsit = internalize_term(arg); + tsit++; } + m_add_deq(ts); + m_deq_distinct.push_back(ts); + if (!m_explicit_eq) return; + term *t = get_term(d); + if (!t) mk_term(d); +} - void term_graph::merge(term &t1, term &t2) { - term *a = &t1.get_root(); - term *b = &t2.get_root(); +// Assumes that a1 != a2 is satisfiable +void term_graph::internalize_deq(expr *a1, expr *a2) { + term *t1 = internalize_term(a1); + term *t2 = internalize_term(a2); + m_add_deq(t1, t2); + m_deq_pairs.push_back({t1, t2}); + if (!m_explicit_eq) + return; + expr_ref eq(m.mk_eq(a1, a2), m); + term *eq_term = mk_term(eq); + eq_term->set_neq_child(); + expr_ref deq(m.mk_not(eq), m); + term *res = get_term(deq); + if (!res) + mk_term(deq); +} - if (a == b) return; +void term_graph::internalize_lit(expr *lit) { + expr *e1 = nullptr, *e2 = nullptr, *ne = nullptr, *v = nullptr; + if (m.is_eq(lit, e1, e2)) // internalize equality + internalize_eq(e1, e2); + else if (m.is_distinct(lit)) + internalize_distinct(lit); + else if (m.is_not(lit, ne) && m.is_eq(ne, e1, e2)) + internalize_deq(e1, e2); + else + internalize_term(lit); + if (is_pure_def(lit, v)) + m_is_var.mark_solved(v); +} - // -- merge might invalidate term2app cache - m_term2app.reset(); - m_pinned.reset(); +void term_graph::merge_flush() { + while (!m_merge.empty()) { + term *t1 = m_merge.back().first; + term *t2 = m_merge.back().second; + m_merge.pop_back(); + merge(*t1, *t2); + } +} - if (a->get_class_size() > b->get_class_size()) { - std::swap(a, b); - } +void term_graph::merge(term &t1, term &t2) { + term *a = &t1.get_root(); + term *b = &t2.get_root(); - // Remove parents of b from the cg table. - for (term* p : term::parents(b)) { - if (!p->is_marked()) { - p->set_mark(true); - m_cg_table.erase(p); - } - } - // make 'a' be the root of the equivalence class of 'b' - b->set_root(*a); - for (term *it = &b->get_next(); it != b; it = &it->get_next()) { - it->set_root(*a); - } + if (a == b) + return; - // merge equivalence classes - a->merge_eq_class(*b); - - // Insert parents of b's old equilvalence class into the cg table - for (term* p : term::parents(b)) { - if (p->is_marked()) { - term* p_old = m_cg_table.insert_if_not_there(p); - p->set_mark(false); - a->add_parent(p); - // propagate new equalities. - if (p->get_root().get_id() != p_old->get_root().get_id()) { - m_merge.push_back(std::make_pair(p, p_old)); - } - } + // -- merge might invalidate term2app cache + m_term2app.reset(); + m_pinned.reset(); + m_repick_repr = true; + + if (a->get_class_size() > b->get_class_size()) + std::swap(a, b); + + // Remove parents of b from the cg table + for (term *p : term::parents(b)) { + if (!p->is_marked()) { + p->set_mark(true); + m_cg_table.erase(p); } - SASSERT(marks_are_clear()); } - expr* term_graph::mk_app_core (expr *e) { - if (is_app(e)) { - expr_ref_buffer kids(m); - app* a = ::to_app(e); - for (expr * arg : *a) { - kids.push_back (mk_app(arg)); - } - app* res = m.mk_app(a->get_decl(), a->get_num_args(), kids.data()); - m_pinned.push_back(res); - return res; - } - else { - return e; + bool prop_cgroundness = (b->is_class_gr() != a->is_class_gr()); + // make 'a' be the root of the equivalence class of 'b' + b->set_root(*a); + for (term *it = &b->get_next(); it != b; it = &it->get_next()) + it->set_root(*a); + + // merge equivalence classes + a->merge_eq_class(*b); + + // Insert parents of b's old equivalence class into the cg table + // bottom-up merge of parents + for (term *p : term::parents(b)) { + if (p->is_marked()) { + term *p_old = m_cg_table.insert_if_not_there(p); + p->set_mark(false); + a->add_parent(p); + // propagate new equalities. + if (p->get_root().get_id() != p_old->get_root().get_id()) + m_merge.push_back({p, p_old}); } } + if (prop_cgroundness) + cground_percolate_up(a); - expr_ref term_graph::mk_app(term const &r) { - SASSERT(r.is_root()); + SASSERT(marks_are_clear()); +} - if (r.get_num_args() == 0) { - return expr_ref(r.get_expr(), m); - } +expr *term_graph::mk_app_core(expr *e) { + if (!is_app(e)) + return e; + expr_ref_buffer kids(m); + app *a = ::to_app(e); + for (expr *arg : *a) + kids.push_back(mk_app(arg)); + app *res = m.mk_app(a->get_decl(), a->get_num_args(), kids.data()); + m_pinned.push_back(res); + return res; +} - expr* res = nullptr; - if (m_term2app.find(r.get_id(), res)) { - return expr_ref(res, m); - } +expr_ref term_graph::mk_app(term &r) { + SASSERT(r.is_repr()); - res = mk_app_core (r.get_expr()); - m_term2app.insert(r.get_id(), res); - return expr_ref(res, m); + if (r.get_num_args() == 0) + return expr_ref(r.get_expr(), m); - } + expr *res = nullptr; + if (m_term2app.find(r.get_id(), res)) + return expr_ref(res, m); - expr_ref term_graph::mk_app(expr *a) { - term *t = get_term(a); - if (!t) - return expr_ref(a, m); - else - return mk_app(t->get_root()); + res = mk_app_core(r.get_expr()); + m_term2app.insert(r.get_id(), res); + return expr_ref(res, m); +} - } +expr_ref term_graph::mk_app(expr *a) { + term *t = get_term(a); + SASSERT(!t || t->get_repr()); + if (!t) + return expr_ref(a, m); + else + return mk_app(*t->get_repr()); +} - void term_graph::mk_equalities(term const &t, expr_ref_vector &out) { - SASSERT(t.is_root()); - expr_ref rep(mk_app(t), m); - for (term *it = &t.get_next(); it != &t; it = &it->get_next()) { - expr* mem = mk_app_core(it->get_expr()); - out.push_back (m.mk_eq (rep, mem)); - } +void term_graph::mk_equalities(term &t, expr_ref_vector &out) { + SASSERT(t.is_repr()); + if (t.get_class_size() == 1) + return; + expr_ref rep(mk_app(t), m); + for (term *it = &t.get_next(); it != &t; it = &it->get_next()) { + expr *mem = mk_app_core(it->get_expr()); + out.push_back(m.mk_eq(rep, mem)); } +} - void term_graph::mk_all_equalities(term const &t, expr_ref_vector &out) { - mk_equalities(t, out); +void term_graph::mk_all_equalities(term &t, expr_ref_vector &out) { + if (t.get_class_size() == 1) + return; - for (term *it = &t.get_next(); it != &t; it = &it->get_next ()) { - expr* a1 = mk_app_core (it->get_expr()); - for (term *it2 = &it->get_next(); it2 != &t; it2 = &it2->get_next()) { - expr* a2 = mk_app_core(it2->get_expr()); - out.push_back (m.mk_eq (a1, a2)); - } + mk_equalities(t, out); + + for (term *it = &t.get_next(); it != &t; it = &it->get_next()) { + expr *a1 = mk_app_core(it->get_expr()); + for (term *it2 = &it->get_next(); it2 != &t; it2 = &it2->get_next()) { + expr *a2 = mk_app_core(it2->get_expr()); + out.push_back(m.mk_eq(a1, a2)); } } +} - void term_graph::reset_marks() { - for (term * t : m_terms) { - t->set_mark(false); - } +void term_graph::mk_qe_lite_equalities(term &t, expr_ref_vector &out, + check_pred &contains_nc) { + SASSERT(t.is_repr()); + if (t.get_class_size() == 1) return; + expr_ref rep(m); + rep = mk_app(t); + if (contains_nc(rep)) { + TRACE( + "qe_debug", tout << "repr not in core " << t; + for (term *it = &t.get_next(); it != &t; + it = &it->get_next()) { tout << *it << "\n"; };); + DEBUG_CODE( + for (term *it = &t.get_next(); it != &t; it = &it->get_next()) + SASSERT(!it->is_cgr() || it->is_eq_or_neq() || + contains_nc(mk_app_core(it->get_expr())));); + return; + } + for (term *it = &t.get_next(); it != &t; it = &it->get_next()) { + expr *e = it->get_expr(); + SASSERT(is_app(e)); + app *a = to_app(e); + // don't add equalities for vars to eliminate + if (m_is_var.contains(a->get_decl())) continue; + expr *mem = mk_app_core(e); + if (rep != mem && !contains_nc(mem)) + out.push_back(m.mk_eq(rep, mem)); } +} - bool term_graph::marks_are_clear() { - for (term * t : m_terms) { - if (t->is_marked()) return false; - } +void term_graph::reset_marks() { + for (term *t : m_terms) t->set_mark(false); +} + +void term_graph::reset_marks2() { + for (term *t : m_terms) t->set_mark2(false); +} + +bool term_graph::marks_are_clear() { + return all_of(m_terms, [](term* t) { return !t->is_marked(); }); +} + +/// Order of preference for roots of equivalence classes +/// XXX This should be factored out to let clients control the preference +bool term_graph::term_lt(term const &t1, term const &t2) { + // prefer constants over applications (ground) + // prefer applications over variables (for non-ground) + // prefer uninterpreted constants over values + // prefer smaller expressions over larger ones + + if (t1.get_num_args() == 0 || t2.get_num_args() == 0) { + if (t1.get_num_args() == t2.get_num_args()) { + if (m.is_value(t1.get_expr()) == m.is_value(t2.get_expr())) + return t1.get_id() < t2.get_id(); + return m.is_value(t2.get_expr()); + } + return t1.get_num_args() < t2.get_num_args(); + } + + // XXX this is the internalized size, not the size with the new + // representatives + unsigned sz1 = get_num_exprs(t1.get_expr()); + unsigned sz2 = get_num_exprs(t2.get_expr()); + return sz1 < sz2; +} + +bool all_children_picked(term *t) { + if (t->deg() == 0) return true; + for (term *c : term::children(t)) + if (!c->get_repr()) + return false; + return true; +} + +// pick representatives for all terms in todo. Then, pick representatives for +// all terms whose children have representatives +void term_graph::pick_repr_percolate_up(ptr_vector &todo) { + term *t; + while (!todo.empty()) { + t = todo.back(); + todo.pop_back(); + if (t->get_repr()) + continue; + pick_repr_class(t); + for (auto it : term::parents(t->get_root())) + if (all_children_picked(it)) + todo.push_back(it); } +} - /// Order of preference for roots of equivalence classes - /// XXX This should be factored out to let clients control the preference - bool term_graph::term_lt(term const &t1, term const &t2) { - // prefer constants over applications - // prefer uninterpreted constants over values - // prefer smaller expressions over larger ones - if (t1.get_num_args() == 0 || t2.get_num_args() == 0) { - if (t1.get_num_args() == t2.get_num_args()) { - // t1.get_num_args() == t2.get_num_args() == 0 - if (m.is_value(t1.get_expr()) == m.is_value(t2.get_expr())) - return t1.get_id() < t2.get_id(); - return m.is_value(t2.get_expr()); - } - return t1.get_num_args() < t2.get_num_args(); - } +// iterate through all terms in a class and pick a representative that: +// 1. is cgr and 2. least according to term_lt +void term_graph::pick_repr_class(term *t) { + SASSERT(all_children_picked(t)); + term *r = t; + for (term *it = &t->get_next(); it != t; it = &it->get_next()) { + if (!all_children_picked(it)) + continue; + if ((it->is_cgr() && !r->is_cgr()) || + (it->is_cgr() == r->is_cgr() && term_lt(*it, *r))) + r = it; + } + r->mk_repr(); +} - unsigned sz1 = get_num_exprs(t1.get_expr()); - unsigned sz2 = get_num_exprs(t2.get_expr()); - return sz1 < sz2; +// Choose repr for equivalence classes +// repr has the following properties: +// 1. acyclicity (mk_app terminates) +// 2. maximal wrt cgr +// 3. each class has exactly one repr +// assumes that cgroundness has been computed +void term_graph::pick_repr() { + // invalidates cache + m_term2app.reset(); + DEBUG_CODE(for (term *t + : m_terms) + SASSERT(t->deg() == 0 || !t->all_children_ground() || + t->is_cgr());); + for (term *t : m_terms) t->reset_repr(); + ptr_vector todo; + for (term *t : m_terms) + if (t->deg() == 0 && t->is_cgr()) + todo.push_back(t); + pick_repr_percolate_up(todo); + DEBUG_CODE(for (term *t : m_terms) SASSERT(!t->is_cgr() || t->get_repr());); + + for (term *t : m_terms) { + if (t->get_repr()) + continue; + if (t->deg() == 0) + todo.push_back(t); } + pick_repr_percolate_up(todo); + DEBUG_CODE(for (term *t : m_terms) SASSERT(t->get_repr());); + DEBUG_CODE(for (auto t + : m_terms) + SASSERT(!t->is_cgr() || t->get_repr()->is_cgr());); +} - void term_graph::pick_root (term &t) { - term *r = &t; - for (term *it = &t.get_next(); it != &t; it = &it->get_next()) { - it->set_mark(true); - if (term_lt(*it, *r)) { r = it; } - } +// if t is a variable, attempt to pick non-var +void term_graph::refine_repr_class(term *t) { + SASSERT(t->is_repr()); + auto is_var = [&](term *p) { + SASSERT(is_app(p->get_expr())); + return m_is_var.contains(to_app(p->get_expr())->get_decl()); + }; + if (!is_var(t)) + return; + term *r = t; + for (term *it = &t->get_next(); it != t; it = &it->get_next()) + if (!makes_cycle(it) && is_var(r) && !is_var(it)) + r = it; + r->mk_repr(); +} - // -- if found something better, make it the new root - if (r != &t) { - r->mk_root(); - } +// check if t makes a cycle if chosen as repr. This function assumes that the +// current repr doesn't have cycles. If there is a cycle in a child class the +// function doesn't terminate. +bool term_graph::makes_cycle(term *t) { + term &r = t->get_root(); + ptr_vector todo; + for (auto *it : term::children(t)) + todo.push_back(it->get_repr()); + term *it; + while (!todo.empty()) { + it = todo.back(); + todo.pop_back(); + if (it->get_root().get_id() == r.get_id()) + return true; + for (auto *ch : term::children(it)) + todo.push_back(ch->get_repr()); } + return false; +} + +void term_graph::refine_repr() { + // invalidates cache + m_term2app.reset(); + for (term *t : m_terms) + if (!t->get_repr()->is_cgr()) + refine_repr_class(t->get_repr()); +} + +// returns true if tg ==> e = v where v is a value +bool term_graph::has_val_in_class(expr *e) { + term *r = get_term(e); + if (!r) return false; + auto is_val = [&](term *t) { return m.is_value(t->get_expr()); }; + if (is_val(r)) + return true; + for (term *it = &r->get_next(); it != r; it = &it->get_next()) + if (is_val(it)) + return true; + return false; +} + +// if there exists an uninterpreted const c s.t. tg ==> e = c, return c +// else return nullptr +app *term_graph::get_const_in_class(expr *e) { + term *r = get_term(e); + if (!r) + return nullptr; + auto is_const = [](term *t) { return is_uninterp_const(t->get_expr()); }; + if (is_const(r)) + return ::to_app(r->get_expr()); + for (term *it = &r->get_next(); it != r; it = &it->get_next()) + if (is_const(it)) + return ::to_app(it->get_expr()); + return nullptr; +} + +void term_graph::display(std::ostream &out) { + for (term *t : m_terms) { out << *t; } +} + +void term_graph::to_lits(expr_ref_vector &lits, bool all_equalities, + bool repick_repr) { + if (m_repick_repr || repick_repr) pick_repr(); - /// Choose better roots for equivalence classes - void term_graph::pick_roots() { - SASSERT(marks_are_clear()); - for (term* t : m_terms) { - if (!t->is_marked() && t->is_root()) - pick_root(*t); + for (expr *a : m_lits) { + if (is_internalized(a)) { + if (m_explicit_eq && get_term(a)->is_eq_or_neq()) continue; + lits.push_back(::to_app(mk_app(a))); } - reset_marks(); } - void term_graph::display(std::ostream &out) { - for (term * t : m_terms) { - out << *t; - } + for (term *t : m_terms) { + if (t->is_eq_or_neq()) continue; + if (!t->is_repr()) + continue; + else if (all_equalities) + mk_all_equalities(*t, lits); + else + mk_equalities(*t, lits); } - void term_graph::to_lits (expr_ref_vector &lits, bool all_equalities) { - pick_roots(); + // TODO: use seen to prevent duplicate disequalities + for (auto p : m_deq_pairs) { + lits.push_back(mk_neq(m, mk_app(p.first->get_expr()), + mk_app(p.second->get_expr()))); + } - for (expr * a : m_lits) { - if (is_internalized(a)) { - lits.push_back (::to_app(mk_app(a))); - } - } + for (auto t : m_deq_distinct) { + ptr_vector args(t.size()); + for (auto c : t) args.push_back(mk_app(c->get_expr())); + lits.push_back(m.mk_distinct(args.size(), args.data())); + } +} - for (term * t : m_terms) { - if (!t->is_root()) - continue; - else if (all_equalities) - mk_all_equalities (*t, lits); - else - mk_equalities(*t, lits); - } +// assumes that representatives have already been picked +void term_graph::to_lits_qe_lite(expr_ref_vector &lits, + std::function *non_core) { + DEBUG_CODE(for (auto t : m_terms) SASSERT(t->get_repr());); + DEBUG_CODE(for (auto t + : m_terms) + SASSERT(!t->is_cgr() || t->get_repr()->is_cgr());); + is_non_core not_in_core(non_core); + check_pred contains_nc(not_in_core, m, false); + // literals other than eq, neq, distinct + for (expr *a : m_lits) { + if (!is_internalized(a)) continue; + if (m_explicit_eq && get_term(a)->is_eq_or_neq()) continue; + expr_ref r(m); + r = mk_app(a); + if (non_core == nullptr || !contains_nc(r)) lits.push_back(r); } - expr_ref term_graph::to_expr() { - expr_ref_vector lits(m); - to_lits(lits); - return mk_and(lits); + // equalities + for (term *t : m_terms) { + if (t->is_eq_or_neq()) continue; + if (!t->is_repr()) continue; + mk_qe_lite_equalities(*t, lits, contains_nc); + } + // disequalities and distinct + // TODO: use seen to prevent duplicate disequalities + expr_ref e1(m), e2(m), d(m), distinct(m); + expr_ref_vector args(m); + for (auto p : m_deq_pairs) { + e1 = mk_app(*(p.first->get_repr())); + e2 = mk_app(*(p.second->get_repr())); + if (non_core == nullptr || (!contains_nc(e1) && !contains_nc(e2))) + lits.push_back(mk_neq(m, e1, e2)); } - void term_graph::reset() { - m_term2app.reset(); - m_pinned.reset(); - m_app2term.reset(); - std::for_each(m_terms.begin(), m_terms.end(), delete_proc()); - m_terms.reset(); - m_lits.reset(); - m_cg_table.reset(); - } - - class term_graph::projector { - term_graph &m_tg; - ast_manager &m; - u_map m_term2app; - u_map m_root2rep; - th_rewriter m_rewriter; - - model_ref m_model; - expr_ref_vector m_pinned; // tracks expr in the maps - - expr* mk_pure(term const& t) { - TRACE("qe", t.display(tout);); - expr* e = nullptr; - if (find_term2app(t, e)) return e; - e = t.get_expr(); - if (!is_app(e)) return nullptr; - app* a = ::to_app(e); - expr_ref_buffer kids(m); - for (term* ch : term::children(t)) { - // prefer a node that resembles current child, - // otherwise, pick a root representative, if present. - if (find_term2app(*ch, e)) { - kids.push_back(e); - } - else if (m_root2rep.find(ch->get_root().get_id(), e)) { - kids.push_back(e); - } - else { - return nullptr; - } - TRACE("qe_verbose", tout << *ch << " -> " << mk_pp(e, m) << "\n";); - } - expr_ref pure = m_rewriter.mk_app(a->get_decl(), kids.size(), kids.data()); - m_pinned.push_back(pure); - add_term2app(t, pure); - return pure; + for (auto t : m_deq_distinct) { + args.reset(); + for (auto c : t) { + d = mk_app(*(c->get_repr())); + if (non_core == nullptr || !contains_nc(d)) args.push_back(d); } + if (args.size() < 2) continue; + if (args.size() == 2) + distinct = mk_neq(m, args.get(0), args.get(1)); + else + distinct = m.mk_distinct(args.size(), args.data()); + lits.push_back(distinct); + } +} +expr_ref term_graph::to_expr(bool repick_repr) { + expr_ref_vector lits(m); + to_lits(lits, false, repick_repr); + return mk_and(lits); +} - bool is_better_rep(expr *t1, expr *t2) { - if (!t2) return t1 != nullptr; - return m.is_unique_value(t1) && !m.is_unique_value(t2); - } +void term_graph::reset() { + m_term2app.reset(); + m_pinned.reset(); + m_app2term.reset(); + std::for_each(m_terms.begin(), m_terms.end(), delete_proc()); + m_terms.reset(); + m_lits.reset(); + m_cg_table.reset(); +} + +class term_graph::projector { + term_graph &m_tg; + ast_manager &m; + u_map m_term2app; + u_map m_root2rep; + th_rewriter m_rewriter; + + model_ref m_model; + expr_ref_vector m_pinned; // tracks expr in the maps - struct term_depth { - bool operator()(term const* t1, term const* t2) const { - return get_depth(t1->get_expr()) < get_depth(t2->get_expr()); + expr *mk_pure(term const &t) { + TRACE("qe", t.display(tout);); + expr *e = nullptr; + if (find_term2app(t, e)) return e; + e = t.get_expr(); + if (!is_app(e)) return nullptr; + app *a = ::to_app(e); + expr_ref_buffer kids(m); + for (term *ch : term::children(t)) { + // prefer a node that resembles current child, + // otherwise, pick a root representative, if present. + if (find_term2app(*ch, e)) { kids.push_back(e); } + else if (m_root2rep.find(ch->get_root().get_id(), e)) { + kids.push_back(e); } - }; + else { return nullptr; } + TRACE("qe_verbose", tout << *ch << " -> " << mk_pp(e, m) << "\n";); + } + expr_ref pure = + m_rewriter.mk_app(a->get_decl(), kids.size(), kids.data()); + m_pinned.push_back(pure); + add_term2app(t, pure); + return pure; + } + bool is_better_rep(expr *t1, expr *t2) { + if (!t2) return t1 != nullptr; + return m.is_unique_value(t1) && !m.is_unique_value(t2); + } - void solve_core() { - ptr_vector worklist; - for (term * t : m_tg.m_terms) { - // skip pure terms - if (!in_term2app(*t)) { - worklist.push_back(t); - t->set_mark(true); - } + struct term_depth { + bool operator()(term const *t1, term const *t2) const { + return get_depth(t1->get_expr()) < get_depth(t2->get_expr()); + } + }; + + void solve_core() { + ptr_vector worklist; + for (term *t : m_tg.m_terms) { + // skip pure terms + if (!in_term2app(*t) && !t->is_eq_or_neq()) { + worklist.push_back(t); + t->set_mark(true); } - term_depth td; - std::sort(worklist.begin(), worklist.end(), td); + } + term_depth td; + std::sort(worklist.begin(), worklist.end(), td); - for (unsigned i = 0; i < worklist.size(); ++i) { - term* t = worklist[i]; - t->set_mark(false); - if (in_term2app(*t)) - continue; + for (unsigned i = 0; i < worklist.size(); ++i) { + term *t = worklist[i]; + t->set_mark(false); + if (in_term2app(*t)) continue; - expr* pure = mk_pure(*t); - if (!pure) - continue; + expr *pure = mk_pure(*t); + if (!pure) continue; - add_term2app(*t, pure); - expr* rep = nullptr; - // ensure that the root has a representative - m_root2rep.find(t->get_root().get_id(), rep); - - if (!rep) { - m_root2rep.insert(t->get_root().get_id(), pure); - for (term * p : term::parents(t->get_root())) { - SASSERT(!in_term2app(*p)); - if (!p->is_marked()) { - p->set_mark(true); - worklist.push_back(p); - } + add_term2app(*t, pure); + expr *rep = nullptr; + // ensure that the root has a representative + m_root2rep.find(t->get_root().get_id(), rep); + + if (!rep) { + m_root2rep.insert(t->get_root().get_id(), pure); + for (term *p : term::parents(t->get_root())) { + SASSERT(!in_term2app(*p)); + if (!p->is_marked()) { + p->set_mark(true); + worklist.push_back(p); } } } - m_tg.reset_marks(); } + m_tg.reset_marks(); + } - bool find_app(term &t, expr *&res) { - return - find_term2app(t, res) || - m_root2rep.find(t.get_root().get_id(), res); - } + bool find_app(term &t, expr *&res) { + return find_term2app(t, res) || + m_root2rep.find(t.get_root().get_id(), res); + } - bool find_app(expr *lit, expr *&res) { - term const* t = m_tg.get_term(lit); - return - find_term2app(*t, res) || - m_root2rep.find(t->get_root().get_id(), res); - } + bool find_app(expr *lit, expr *&res) { + term const *t = m_tg.get_term(lit); + return find_term2app(*t, res) || + m_root2rep.find(t->get_root().get_id(), res); + } - void mk_lits(expr_ref_vector &res) { - expr *e = nullptr; - for (auto *lit : m_tg.m_lits) { - if (!m.is_eq(lit) && find_app(lit, e)) - res.push_back(e); - } - TRACE("qe", tout << "literals: " << res << "\n";); + void mk_lits(expr_ref_vector &res) { + expr *e = nullptr; + for (auto *lit : m_tg.m_lits) { + if (!m.is_eq(lit) && find_app(lit, e)) res.push_back(e); } + TRACE("qe", tout << "literals: " << res << "\n";); + } - void lits2pure(expr_ref_vector& res) { - expr *e1 = nullptr, *e2 = nullptr, *p1 = nullptr, *p2 = nullptr; - for (auto *lit : m_tg.m_lits) { - if (m.is_eq(lit, e1, e2)) { - if (find_app(e1, p1) && find_app(e2, p2)) { - if (p1 != p2) - res.push_back(m.mk_eq(p1, p2)); - } - else - TRACE("qe", tout << "skipping " << mk_pp(lit, m) << "\n";); + void lits2pure(expr_ref_vector &res) { + expr *e1 = nullptr, *e2 = nullptr, *e = nullptr, *p1 = nullptr, *p2 = nullptr; + for (auto *lit : m_tg.m_lits) { + if (m.is_eq(lit, e1, e2)) { + if (find_app(e1, p1) && find_app(e2, p2)) { + if (p1 != p2) res.push_back(m.mk_eq(p1, p2)); } - else if (m.is_distinct(lit)) { - ptr_buffer diff; - for (expr* arg : *to_app(lit)) - if (find_app(arg, p1)) - diff.push_back(p1); - if (diff.size() > 1) - res.push_back(m.mk_distinct(diff.size(), diff.data())); - else - TRACE("qe", tout << "skipping " << mk_pp(lit, m) << "\n";); + else + TRACE("qe", tout << "skipping " << mk_pp(lit, m) << "\n";); + } + else if (m.is_not(lit, e) && m.is_eq(e, e1, e2)) { + if (find_app(e1, p1) && find_app(e2, p2)) { + res.push_back(mk_neq(m, p1, p2)); } - else if (find_app(lit, p1)) - res.push_back(p1); - else + else TRACE("qe", tout << "skipping " << mk_pp(lit, m) << "\n";); } - remove_duplicates(res); - TRACE("qe", tout << "literals: " << res << "\n";); + else if (m.is_distinct(lit)) { + ptr_buffer diff; + for (expr *arg : *to_app(lit)) + if (find_app(arg, p1)) diff.push_back(p1); + if (diff.size() > 1) + res.push_back(m.mk_distinct(diff.size(), diff.data())); + else + TRACE("qe", tout << "skipping " << mk_pp(lit, m) << "\n";); + } + else if (find_app(lit, p1)) + res.push_back(p1); + else + TRACE("qe", tout << "skipping " << mk_pp(lit, m) << "\n";); } + remove_duplicates(res); + TRACE("qe", tout << "literals: " << res << "\n";); + } - void remove_duplicates(expr_ref_vector& v) { - obj_hashtable seen; - unsigned j = 0; - for (expr* e : v) { - if (!seen.contains(e)) { - v[j++] = e; - seen.insert(e); - } + void remove_duplicates(expr_ref_vector &v) { + obj_hashtable seen; + unsigned j = 0; + for (expr *e : v) { + if (!seen.contains(e)) { + v[j++] = e; + seen.insert(e); } - v.shrink(j); } + v.shrink(j); + } - vector> m_decl2terms; // terms that use function f - ptr_vector m_decls; - - void collect_decl2terms() { - // Collect the projected function symbols. - m_decl2terms.reset(); - m_decls.reset(); - for (term *t : m_tg.m_terms) { - expr* e = t->get_expr(); - if (!is_app(e)) continue; - if (!is_projected(*t)) continue; - app* a = to_app(e); - func_decl* d = a->get_decl(); - if (d->get_arity() == 0) continue; - unsigned id = d->get_small_id(); - m_decl2terms.reserve(id+1); - if (m_decl2terms[id].empty()) m_decls.push_back(d); - m_decl2terms[id].push_back(t); - } + vector> m_decl2terms; // terms that use function f + ptr_vector m_decls; + + void collect_decl2terms() { + // Collect the projected function symbols. + m_decl2terms.reset(); + m_decls.reset(); + for (term *t : m_tg.m_terms) { + if (t->is_eq_or_neq()) continue; + expr *e = t->get_expr(); + if (!is_app(e)) continue; + if (!is_projected(*t)) continue; + app *a = to_app(e); + func_decl *d = a->get_decl(); + if (d->get_arity() == 0) continue; + unsigned id = d->get_small_id(); + m_decl2terms.reserve(id + 1); + if (m_decl2terms[id].empty()) m_decls.push_back(d); + m_decl2terms[id].push_back(t); } + } - void args_are_distinct(expr_ref_vector& res) { - // - // for each projected function that occurs - // (may occur) in multiple congruence classes, - // produce assertions that non-congruent arguments - // are distinct. - // - for (func_decl* d : m_decls) { - unsigned id = d->get_small_id(); - ptr_vector const& terms = m_decl2terms[id]; - if (terms.size() <= 1) continue; - unsigned arity = d->get_arity(); - for (unsigned i = 0; i < arity; ++i) { - obj_hashtable roots, root_vals; - expr_ref_vector pinned(m); - for (term* t : terms) { - expr* arg = to_app(t->get_expr())->get_arg(i); - term const& root = m_tg.get_term(arg)->get_root(); - expr* r = root.get_expr(); - // if a model is given, then use the equivalence class induced - // by the model. Otherwise, use the congruence class. - if (m_model) { - expr_ref tmp(m); - tmp = (*m_model)(r); - if (!root_vals.contains(tmp)) { - root_vals.insert(tmp); - roots.insert(r); - pinned.push_back(tmp); - } - } - else { + void args_are_distinct(expr_ref_vector &res) { + // + // for each projected function that occurs + // (may occur) in multiple congruence classes, + // produce assertions that non-congruent arguments + // are distinct. + // + for (func_decl *d : m_decls) { + unsigned id = d->get_small_id(); + ptr_vector const &terms = m_decl2terms[id]; + if (terms.size() <= 1) continue; + unsigned arity = d->get_arity(); + for (unsigned i = 0; i < arity; ++i) { + obj_hashtable roots, root_vals; + expr_ref_vector pinned(m); + for (term *t : terms) { + expr *arg = to_app(t->get_expr())->get_arg(i); + term const &root = m_tg.get_term(arg)->get_root(); + expr *r = root.get_expr(); + // if a model is given, then use the equivalence class + // induced by the model. Otherwise, use the congruence + // class. + if (m_model) { + expr_ref tmp(m); + tmp = (*m_model)(r); + if (!root_vals.contains(tmp)) { + root_vals.insert(tmp); roots.insert(r); + pinned.push_back(tmp); } } - if (roots.size() > 1) { - ptr_buffer args; - for (expr* r : roots) { - args.push_back(r); - } - TRACE("qe", tout << "function: " << d->get_name() << "\n";); - res.push_back(m.mk_distinct(args.size(), args.data())); - } + else { roots.insert(r); } } - } - } - - void mk_distinct(expr_ref_vector& res) { - collect_decl2terms(); - args_are_distinct(res); - TRACE("qe", tout << res << "\n";); - } - - void mk_pure_equalities(const term &t, expr_ref_vector &res) { - SASSERT(t.is_root()); - expr *rep = nullptr; - if (!m_root2rep.find(t.get_id(), rep)) return; - obj_hashtable members; - members.insert(rep); - term const * r = &t; - do { - expr* member = nullptr; - if (find_term2app(*r, member) && !members.contains(member)) { - res.push_back (m.mk_eq (rep, member)); - members.insert(member); + if (roots.size() > 1) { + ptr_buffer args; + for (expr *r : roots) { args.push_back(r); } + TRACE("qe", tout << "function: " << d->get_name() << "\n";); + res.push_back(m.mk_distinct(args.size(), args.data())); } - r = &r->get_next(); } - while (r != &t); } + } - bool is_projected(const term &t) { - return m_tg.m_is_var(t); - } + void mk_distinct(expr_ref_vector &res) { + collect_decl2terms(); + args_are_distinct(res); + TRACE("qe", tout << res << "\n";); + } - void mk_unpure_equalities(const term &t, expr_ref_vector &res) { - expr *rep = nullptr; - if (!m_root2rep.find(t.get_id(), rep)) return; - obj_hashtable members; - members.insert(rep); - term const * r = &t; - do { - expr* member = mk_pure(*r); - SASSERT(member); - if (!members.contains(member) && - (!is_projected(*r) || !is_solved_eq(rep, member))) { - res.push_back(m.mk_eq(rep, member)); - members.insert(member); - } - r = &r->get_next(); + void mk_pure_equalities(const term &t, expr_ref_vector &res) { + SASSERT(t.is_root()); + expr *rep = nullptr; + if (!m_root2rep.find(t.get_id(), rep)) return; + obj_hashtable members; + members.insert(rep); + term const *r = &t; + do { + expr *member = nullptr; + if (find_term2app(*r, member) && !members.contains(member)) { + res.push_back(m.mk_eq(rep, member)); + members.insert(member); } - while (r != &t); + r = &r->get_next(); } + while (r != &t); + } - template - void mk_equalities(expr_ref_vector &res) { - for (term *t : m_tg.m_terms) { - if (!t->is_root()) continue; - if (!m_root2rep.contains(t->get_id())) continue; - if (pure) - mk_pure_equalities(*t, res); - else - mk_unpure_equalities(*t, res); + bool is_projected(const term &t) { return m_tg.m_is_var(t); } + + void mk_unpure_equalities(const term &t, expr_ref_vector &res) { + expr *rep = nullptr; + if (!m_root2rep.find(t.get_id(), rep)) return; + obj_hashtable members; + members.insert(rep); + term const *r = &t; + do { + expr *member = mk_pure(*r); + SASSERT(member); + if (!members.contains(member) && + (!is_projected(*r) || !is_solved_eq(rep, member))) { + res.push_back(m.mk_eq(rep, member)); + members.insert(member); } - TRACE("qe", tout << "literals: " << res << "\n";); + r = &r->get_next(); } + while (r != &t); + } - void mk_pure_equalities(expr_ref_vector &res) { - mk_equalities(res); + template void mk_equalities(expr_ref_vector &res) { + for (term *t : m_tg.m_terms) { + if (t->is_eq_or_neq()) continue; + if (!t->is_root()) continue; + if (!m_root2rep.contains(t->get_id())) continue; + if (pure) + mk_pure_equalities(*t, res); + else + mk_unpure_equalities(*t, res); } + TRACE("qe", tout << "literals: " << res << "\n";); + } - void mk_unpure_equalities(expr_ref_vector &res) { - mk_equalities(res); - } + void mk_pure_equalities(expr_ref_vector &res) { mk_equalities(res); } - // TBD: generalize for also the case of a (:var n) - bool is_solved_eq(expr *lhs, expr* rhs) { - return is_uninterp_const(rhs) && !occurs(rhs, lhs); - } - - /// Add equalities and disequalities for all pure representatives - /// based on their equivalence in the model - void model_complete(expr_ref_vector &res) { - if (!m_model) return; - obj_map val2rep; - model_evaluator mev(*m_model); - for (auto &kv : m_root2rep) { - expr *rep = kv.m_value; - expr_ref val(m); - expr *u = nullptr; - if (!mev.eval(rep, val)) continue; - if (val2rep.find(val, u)) { - res.push_back(m.mk_eq(u, rep)); - } - else { - val2rep.insert(val, rep); - } - } + void mk_unpure_equalities(expr_ref_vector &res) { + mk_equalities(res); + } - // TBD: optimize further based on implied values (e.g., - // some literals are forced to be true/false) and based on - // unique_values (e.g., (x=1 & y=1) does not require - // (x!=y) to be added - ptr_buffer reps; - for (auto &kv : val2rep) { - expr *rep = kv.m_value; - if (!m.is_unique_value(rep)) - reps.push_back(kv.m_value); - } + // TBD: generalize for also the case of a (:var n) + bool is_solved_eq(expr *lhs, expr *rhs) { + return is_uninterp_const(rhs) && !occurs(rhs, lhs); + } - if (reps.size() <= 1) return; - - // -- sort representatives, call mk_distinct on any range - // -- of the same sort longer than 1 - std::sort(reps.data(), reps.data() + reps.size(), sort_lt_proc()); - unsigned i = 0; - unsigned sz = reps.size(); - while (i < sz) { - sort* last_sort = res.get(i)->get_sort(); - unsigned j = i + 1; - while (j < sz && last_sort == reps.get(j)->get_sort()) {++j;} - if (j - i == 2) { - expr_ref d(m); - d = mk_neq(m, reps.get(i), reps.get(i+1)); - if (!m.is_true(d)) res.push_back(d); - } - else if (j - i > 2) - res.push_back(m.mk_distinct(j - i, reps.data() + i)); - i = j; + /// Add equalities and disequalities for all pure representatives + /// based on their equivalence in the model + void model_complete(expr_ref_vector &res) { + if (!m_model) return; + obj_map val2rep; + model_evaluator mev(*m_model); + for (auto &kv : m_root2rep) { + expr *rep = kv.m_value; + expr_ref val(m); + expr *u = nullptr; + if (!mev.eval(rep, val)) continue; + if (val2rep.find(val, u)) { res.push_back(m.mk_eq(u, rep)); } + else { val2rep.insert(val, rep); } + } + + // TBD: optimize further based on implied values (e.g., + // some literals are forced to be true/false) and based on + // unique_values (e.g., (x=1 & y=1) does not require + // (x!=y) to be added + ptr_buffer reps; + for (auto &kv : val2rep) { + expr *rep = kv.m_value; + if (!m.is_unique_value(rep)) reps.push_back(kv.m_value); + } + + if (reps.size() <= 1) return; + + // -- sort representatives, call mk_distinct on any range + // -- of the same sort longer than 1 + std::sort(reps.data(), reps.data() + reps.size(), sort_lt_proc()); + unsigned i = 0; + unsigned sz = reps.size(); + while (i < sz) { + sort *last_sort = res.get(i)->get_sort(); + unsigned j = i + 1; + while (j < sz && last_sort == reps.get(j)->get_sort()) { ++j; } + if (j - i == 2) { + expr_ref d(m); + d = mk_neq(m, reps.get(i), reps.get(i + 1)); + if (!m.is_true(d)) res.push_back(d); } - TRACE("qe", tout << "after distinct: " << res << "\n";); + else if (j - i > 2) + res.push_back(m.mk_distinct(j - i, reps.data() + i)); + i = j; } + TRACE("qe", tout << "after distinct: " << res << "\n";); + } - std::ostream& display(std::ostream& out) const { - m_tg.display(out); - out << "term2app:\n"; - for (auto const& kv : m_term2app) { - out << kv.m_key << " |-> " << mk_pp(kv.m_value, m) << "\n"; - } - out << "root2rep:\n"; - for (auto const& kv : m_root2rep) { - out << kv.m_key << " |-> " << mk_pp(kv.m_value, m) << "\n"; - } - return out; + std::ostream &display(std::ostream &out) const { + m_tg.display(out); + out << "term2app:\n"; + for (auto const &kv : m_term2app) { + out << kv.m_key << " |-> " << mk_pp(kv.m_value, m) << "\n"; } - - public: - projector(term_graph &tg) : m_tg(tg), m(m_tg.m), m_rewriter(m), m_pinned(m) {} - - void add_term2app(term const& t, expr* a) { - m_term2app.insert(t.get_id(), a); + out << "root2rep:\n"; + for (auto const &kv : m_root2rep) { + out << kv.m_key << " |-> " << mk_pp(kv.m_value, m) << "\n"; } + return out; + } - void del_term2app(term const& t) { - m_term2app.remove(t.get_id()); - } + public: + projector(term_graph &tg) + : m_tg(tg), m(m_tg.m), m_rewriter(m), m_pinned(m) {} - bool find_term2app(term const& t, expr*& r) { - return m_term2app.find(t.get_id(), r); - } + void add_term2app(term const &t, expr *a) { + m_term2app.insert(t.get_id(), a); + } - expr* find_term2app(term const& t) { - expr* r = nullptr; - find_term2app(t, r); - return r; - } + void del_term2app(term const &t) { m_term2app.remove(t.get_id()); } - bool in_term2app(term const& t) { - return m_term2app.contains(t.get_id()); - } + bool find_term2app(term const &t, expr *&r) { + return m_term2app.find(t.get_id(), r); + } - void set_model(model &mdl) { m_model = &mdl; } + expr *find_term2app(term const &t) { + expr *r = nullptr; + find_term2app(t, r); + return r; + } - void reset() { - m_tg.reset_marks(); - m_term2app.reset(); - m_root2rep.reset(); - m_pinned.reset(); - m_model.reset(); - } + bool in_term2app(term const &t) { return m_term2app.contains(t.get_id()); } - expr_ref_vector project() { - expr_ref_vector res(m); - purify(); - lits2pure(res); - mk_distinct(res); - reset(); - return res; - } + void set_model(model &mdl) { m_model = &mdl; } - expr_ref_vector get_ackerman_disequalities() { - expr_ref_vector res(m); - purify(); - lits2pure(res); - unsigned sz = res.size(); - mk_distinct(res); - reset(); - unsigned j = 0; - for (unsigned i = sz; i < res.size(); ++i) { - res[j++] = res.get(i); - } - res.shrink(j); - return res; - } + void reset() { + m_tg.reset_marks(); + m_term2app.reset(); + m_root2rep.reset(); + m_pinned.reset(); + m_model.reset(); + } - expr_ref_vector solve() { - expr_ref_vector res(m); - purify(); - solve_core(); - mk_lits(res); - mk_unpure_equalities(res); - reset(); - return res; - } + expr_ref_vector project() { + expr_ref_vector res(m); + purify(); + lits2pure(res); + mk_distinct(res); + reset(); + return res; + } - vector get_partition(model& mdl, bool include_bool) { - vector result; - expr_ref_vector pinned(m); - obj_map pid; - auto insert_val = [&](expr* a, expr* val) { - unsigned p = 0; - // NB. works for simple domains Integers, Rationals, - // but not for algebraic numerals. - if (!pid.find(val, p)) { - p = pid.size(); - pid.insert(val, p); - pinned.push_back(val); - result.push_back(expr_ref_vector(m)); - } - result[p].push_back(a); - }; - model::scoped_model_completion _smc(mdl, true); - for (term *t : m_tg.m_terms) { - expr* a = t->get_expr(); - if (!is_app(a)) - continue; - if (m.is_bool(a) && !include_bool) - continue; - expr_ref val = mdl(a); - insert_val(a, val); - } + expr_ref_vector get_ackerman_disequalities() { + expr_ref_vector res(m); + purify(); + lits2pure(res); + unsigned sz = res.size(); + mk_distinct(res); + reset(); + unsigned j = 0; + for (unsigned i = sz; i < res.size(); ++i) { res[j++] = res.get(i); } + res.shrink(j); + return res; + } - return result; - } + expr_ref_vector solve() { + expr_ref_vector res(m); + purify(); + solve_core(); + mk_lits(res); + mk_unpure_equalities(res); + reset(); + return res; + } - expr_ref_vector shared_occurrences(family_id fid) { - expr_ref_vector result(m); - for (term *t : m_tg.m_terms) { - expr* e = t->get_expr(); - if (e->get_sort()->get_family_id() != fid) continue; - for (term * p : term::parents(t->get_root())) { - expr* pe = p->get_expr(); - if (!is_app(pe)) continue; - if (to_app(pe)->get_family_id() == fid) continue; - if (to_app(pe)->get_family_id() == m.get_basic_family_id()) continue; - result.push_back(e); - break; - } + vector get_partition(model &mdl, bool include_bool) { + vector result; + expr_ref_vector pinned(m); + obj_map pid; + auto insert_val = [&](expr *a, expr *val) { + unsigned p = 0; + // NB. works for simple domains Integers, Rationals, + // but not for algebraic numerals. + if (!pid.find(val, p)) { + p = pid.size(); + pid.insert(val, p); + pinned.push_back(val); + result.push_back(expr_ref_vector(m)); } - return result; + result[p].push_back(a); + }; + model::scoped_model_completion _smc(mdl, true); + for (term *t : m_tg.m_terms) { + if (t->is_eq_or_neq()) continue; + expr *a = t->get_expr(); + if (!is_app(a)) continue; + if (m.is_bool(a) && !include_bool) continue; + expr_ref val = mdl(a); + insert_val(a, val); } + return result; + } - void purify() { - // - propagate representatives up over parents. - // use work-list + marking to propagate. - // - produce equalities over represented classes. - // - produce other literals over represented classes - // (walk disequalities in m_lits and represent - // lhs/rhs over decls or excluding decls) - - ptr_vector worklist; - for (term * t : m_tg.m_terms) { - worklist.push_back(t); - t->set_mark(true); - } - // traverse worklist in order of depth. - term_depth td; - std::sort(worklist.begin(), worklist.end(), td); - - for (unsigned i = 0; i < worklist.size(); ++i) { - term* t = worklist[i]; - t->set_mark(false); - if (in_term2app(*t)) - continue; - if (!t->is_theory() && is_projected(*t)) + expr_ref_vector shared_occurrences(family_id fid) { + expr_ref_vector result(m); + for (term *t : m_tg.m_terms) { + if (t->is_eq_or_neq()) continue; + expr *e = t->get_expr(); + if (e->get_sort()->get_family_id() != fid) continue; + for (term *p : term::parents(t->get_root())) { + expr *pe = p->get_expr(); + if (!is_app(pe)) continue; + if (to_app(pe)->get_family_id() == fid) continue; + if (to_app(pe)->get_family_id() == m.get_basic_family_id()) continue; + result.push_back(e); + break; + } + } + return result; + } - expr* pure = mk_pure(*t); - if (!pure) continue; - - add_term2app(*t, pure); - TRACE("qe_verbose", tout << "purified " << *t << " " << mk_pp(pure, m) << "\n";); - expr* rep = nullptr; // ensure that the root has a representative - m_root2rep.find(t->get_root().get_id(), rep); - - // update rep with pure if it is better - if (pure != rep && is_better_rep(pure, rep)) { - m_root2rep.insert(t->get_root().get_id(), pure); - for (term * p : term::parents(t->get_root())) { - del_term2app(*p); - if (!p->is_marked()) { - p->set_mark(true); - worklist.push_back(p); - } + void purify() { + // - propagate representatives up over parents. + // use work-list + marking to propagate. + // - produce equalities over represented classes. + // - produce other literals over represented classes + // (walk disequalities in m_lits and represent + // lhs/rhs over decls or excluding decls) + + ptr_vector worklist; + for (term *t : m_tg.m_terms) { + if (t->is_eq_or_neq()) continue; + worklist.push_back(t); + t->set_mark(true); + } + // traverse worklist in order of depth. + term_depth td; + std::sort(worklist.begin(), worklist.end(), td); + + for (unsigned i = 0; i < worklist.size(); ++i) { + term *t = worklist[i]; + t->set_mark(false); + if (in_term2app(*t)) continue; + if (!t->is_theory() && is_projected(*t)) continue; + + expr *pure = mk_pure(*t); + if (!pure) continue; + + add_term2app(*t, pure); + TRACE("qe_verbose", + tout << "purified " << *t << " " << mk_pp(pure, m) << "\n";); + expr *rep = nullptr; // ensure that the root has a representative + m_root2rep.find(t->get_root().get_id(), rep); + + // update rep with pure if it is better + if (pure != rep && is_better_rep(pure, rep)) { + m_root2rep.insert(t->get_root().get_id(), pure); + for (term *p : term::parents(t->get_root())) { + del_term2app(*p); + if (!p->is_marked()) { + p->set_mark(true); + worklist.push_back(p); } } } - - // Here we could also walk equivalence classes that - // contain interpreted values by sort and extract - // disequalities between non-unique value - // representatives. these disequalities are implied - // and can be mined using other means, such as theory - // aware core minimization - m_tg.reset_marks(); - TRACE("qe", display(tout << "after purify\n");); } - }; - - void term_graph::set_vars(func_decl_ref_vector const& decls, bool exclude) { - m_is_var.set_decls(decls, exclude); + // Here we could also walk equivalence classes that + // contain interpreted values by sort and extract + // disequalities between non-unique value + // representatives. these disequalities are implied + // and can be mined using other means, such as theory + // aware core minimization + m_tg.reset_marks(); + TRACE("qe", display(tout << "after purify\n");); } - - expr_ref_vector term_graph::project() { - // reset solved vars so that they are not considered pure by projector - m_is_var.reset_solved(); - term_graph::projector p(*this); - return p.project(); +}; + +// produce a quantifier reduction of the formula stored in the term graph +// removes from `vars` the variables that have a ground representative +// modifies `vars` to keep the variables that could not be eliminated +void term_graph::qel(app_ref_vector &vars, expr_ref &fml, + std::function *non_core) { + unsigned i = 0; + for (auto v : vars) { + if (is_internalized(v)) { vars[i++] = v; } } - - expr_ref_vector term_graph::project(model &mdl) { - m_is_var.reset_solved(); - term_graph::projector p(*this); - p.set_model(mdl); - return p.project(); + vars.shrink(i); + pick_repr(); + refine_repr(); + + expr_ref_vector lits(m); + to_lits_qe_lite(lits, non_core); + if (lits.size() == 0) + fml = m.mk_true(); + else if (lits.size() == 1) + fml = lits[0].get(); + else + fml = m.mk_and(lits); + + // Remove all variables that are do not appear in the formula + expr_sparse_mark mark; + mark_all_sub_expr marker(mark); + quick_for_each_expr(marker, fml); + i = 0; + for (auto v : vars) { + if (mark.is_marked(v)) vars[i++] = v; } + vars.shrink(i); +} - expr_ref_vector term_graph::solve() { - // reset solved vars so that they are not considered pure by projector - m_is_var.reset_solved(); - term_graph::projector p(*this); - return p.solve(); - } +void term_graph::set_vars(func_decl_ref_vector const &decls, bool exclude) { + m_is_var.set_decls(decls, exclude); +} - expr_ref_vector term_graph::get_ackerman_disequalities() { - m_is_var.reset_solved(); - dealloc(m_projector); - m_projector = alloc(term_graph::projector, *this); - return m_projector->get_ackerman_disequalities(); - } +void term_graph::set_vars(app_ref_vector const &vars, bool exclude) { + m_is_var.set_decls(vars, exclude); +} - vector term_graph::get_partition(model& mdl) { - dealloc(m_projector); - m_projector = alloc(term_graph::projector, *this); - return m_projector->get_partition(mdl, false); - } +void term_graph::add_vars(app_ref_vector const &vars) { + m_is_var.add_decls(vars); +} - expr_ref_vector term_graph::shared_occurrences(family_id fid) { - term_graph::projector p(*this); - return p.shared_occurrences(fid); - } +void term_graph::add_var(app *var) { m_is_var.add_decl(var); } - void term_graph::add_model_based_terms(model& mdl, expr_ref_vector const& terms) { - for (expr* t : terms) { - internalize_term(t); - } - m_is_var.reset_solved(); - - SASSERT(!m_projector); - m_projector = alloc(term_graph::projector, *this); - - // retrieve partition of terms - vector equivs = m_projector->get_partition(mdl, true); - - // merge term graph on equal terms. - for (auto const& cs : equivs) { - term* t0 = get_term(cs[0]); - for (unsigned i = 1; i < cs.size(); ++i) { - merge(*t0, *get_term(cs[i])); - } - } - TRACE("qe", - for (auto & es : equivs) { - tout << "equiv: "; - for (expr* t : es) tout << expr_ref(t, m) << " "; - tout << "\n"; - } - display(tout);); - // create representatives for shared/projected variables. - m_projector->set_model(mdl); - m_projector->purify(); +expr_ref_vector term_graph::project() { + // reset solved vars so that they are not considered pure by projector + m_is_var.reset_solved(); + term_graph::projector p(*this); + return p.project(); +} - } +expr_ref_vector term_graph::project(model &mdl) { + m_is_var.reset_solved(); + term_graph::projector p(*this); + p.set_model(mdl); + return p.project(); +} + +expr_ref_vector term_graph::solve() { + // reset solved vars so that they are not considered pure by projector + m_is_var.reset_solved(); + term_graph::projector p(*this); + return p.solve(); +} + +expr_ref_vector term_graph::get_ackerman_disequalities() { + m_is_var.reset_solved(); + dealloc(m_projector); + m_projector = alloc(term_graph::projector, *this); + return m_projector->get_ackerman_disequalities(); +} + +vector term_graph::get_partition(model &mdl) { + dealloc(m_projector); + m_projector = alloc(term_graph::projector, *this); + return m_projector->get_partition(mdl, false); +} + +expr_ref_vector term_graph::shared_occurrences(family_id fid) { + term_graph::projector p(*this); + return p.shared_occurrences(fid); +} + +void term_graph::add_model_based_terms(model &mdl, + expr_ref_vector const &terms) { + for (expr *t : terms) { internalize_term(t); } + m_is_var.reset_solved(); + + SASSERT(!m_projector); + m_projector = alloc(term_graph::projector, *this); + + // retrieve partition of terms + vector equivs = m_projector->get_partition(mdl, true); - expr* term_graph::rep_of(expr* e) { - SASSERT(m_projector); - term* t = get_term(e); - SASSERT(t && "only get representatives"); - return m_projector->find_term2app(*t); + // merge term graph on equal terms. + for (auto const &cs : equivs) { + term *t0 = get_term(cs[0]); + for (unsigned i = 1; i < cs.size(); ++i) { + merge(*t0, *get_term(cs[i])); + } } - - expr_ref_vector term_graph::dcert(model& mdl, expr_ref_vector const& lits) { - TRACE("qe", tout << "dcert " << lits << "\n";); - struct pair_t { - expr* a, *b; - pair_t(): a(nullptr), b(nullptr) {} - pair_t(expr* _a, expr* _b):a(_a), b(_b) { - if (a->get_id() > b->get_id()) std::swap(a, b); + TRACE( + "qe", for (auto &es + : equivs) { + tout << "equiv: "; + for (expr *t : es) tout << expr_ref(t, m) << " "; + tout << "\n"; + } display(tout);); + // create representatives for shared/projected variables. + m_projector->set_model(mdl); + m_projector->purify(); +} + +expr *term_graph::rep_of(expr *e) { + SASSERT(m_projector); + term *t = get_term(e); + SASSERT(t && "only get representatives"); + return m_projector->find_term2app(*t); +} + +expr_ref_vector term_graph::dcert(model &mdl, expr_ref_vector const &lits) { + TRACE("qe", tout << "dcert " << lits << "\n";); + struct pair_t { + expr *a, *b; + pair_t() : a(nullptr), b(nullptr) {} + pair_t(expr *_a, expr *_b) : a(_a), b(_b) { + if (a->get_id() > b->get_id()) std::swap(a, b); + } + struct hash { + unsigned operator()(pair_t const &p) const { + return mk_mix(p.a ? p.a->hash() : 0, p.b ? p.b->hash() : 0, 1); } - struct hash { - unsigned operator()(pair_t const& p) const { return mk_mix(p.a ? p.a->hash() : 0, p.b ? p.b->hash() : 0, 1); } - }; - struct eq { - bool operator()(pair_t const& a, pair_t const& b) const { return a.a == b.a && a.b == b.b; } - }; }; - hashtable diseqs; - expr_ref_vector result(m); - add_lits(lits); - svector todo; - - for (expr* e : lits) { - expr* ne, *a, *b; - if (m.is_not(e, ne) && m.is_eq(ne, a, b) && (is_uninterp(a) || is_uninterp(b))) { - diseqs.insert(pair_t(a, b)); + struct eq { + bool operator()(pair_t const &a, pair_t const &b) const { + return a.a == b.a && a.b == b.b; } - else if (is_uninterp(e)) { - diseqs.insert(pair_t(e, m.mk_false())); - } - else if (m.is_not(e, ne) && is_uninterp(ne)) { - diseqs.insert(pair_t(ne, m.mk_true())); - } + }; + }; + hashtable diseqs; + expr_ref_vector result(m); + add_lits(lits); + svector todo; + + for (expr *e : lits) { + expr *ne, *a, *b; + if (m.is_not(e, ne) && m.is_eq(ne, a, b) && + (is_uninterp(a) || is_uninterp(b))) { + diseqs.insert(pair_t(a, b)); } - for (auto& p : diseqs) todo.push_back(p); - - auto const partitions = get_partition(mdl); - obj_map term2pid; - unsigned id = 0; - for (auto const& vec : partitions) { - for (expr* e : vec) term2pid.insert(e, id); - ++id; + else if (is_uninterp(e)) { diseqs.insert(pair_t(e, m.mk_false())); } + else if (m.is_not(e, ne) && is_uninterp(ne)) { + diseqs.insert(pair_t(ne, m.mk_true())); } - expr_ref_vector empty(m); - auto partition_of = [&](expr* e) { - unsigned pid; - if (!term2pid.find(e, pid)) - return empty; - return partitions[pid]; - }; - auto in_table = [&](expr* a, expr* b) { - return diseqs.contains(pair_t(a, b)); - }; - auto same_function = [](expr* a, expr* b) { - return is_app(a) && is_app(b) && - to_app(a)->get_decl() == to_app(b)->get_decl() && to_app(a)->get_family_id() == null_family_id; - }; + } + for (auto &p : diseqs) todo.push_back(p); + + auto const partitions = get_partition(mdl); + obj_map term2pid; + unsigned id = 0; + for (auto const &vec : partitions) { + for (expr *e : vec) term2pid.insert(e, id); + ++id; + } + expr_ref_vector empty(m); + auto partition_of = [&](expr *e) { + unsigned pid; + if (!term2pid.find(e, pid)) return empty; + return partitions[pid]; + }; + auto in_table = [&](expr *a, expr *b) { + return diseqs.contains(pair_t(a, b)); + }; + auto same_function = [](expr *a, expr *b) { + return is_app(a) && is_app(b) && + to_app(a)->get_decl() == to_app(b)->get_decl() && + to_app(a)->get_family_id() == null_family_id; + }; - // make sure that diseqs is closed under function applications - // of uninterpreted functions. - for (unsigned idx = 0; idx < todo.size(); ++idx) { - auto p = todo[idx]; - for (expr* t1 : partition_of(p.a)) { - for (expr* t2 : partition_of(p.b)) { - if (same_function(t1, t2)) { - unsigned sz = to_app(t1)->get_num_args(); - bool found = false; - pair_t q(t1, t2); - for (unsigned i = 0; i < sz; ++i) { - expr* arg1 = to_app(t1)->get_arg(i); - expr* arg2 = to_app(t2)->get_arg(i); - if (mdl(arg1) == mdl(t2)) { - continue; - } - if (in_table(arg1, arg2)) { - found = true; - break; - } - q = pair_t(arg1, arg2); - } - if (!found) { - diseqs.insert(q); - todo.push_back(q); - result.push_back(m.mk_not(m.mk_eq(q.a, q.b))); + // make sure that diseqs is closed under function applications + // of uninterpreted functions. + for (unsigned idx = 0; idx < todo.size(); ++idx) { + auto p = todo[idx]; + for (expr *t1 : partition_of(p.a)) { + for (expr *t2 : partition_of(p.b)) { + if (same_function(t1, t2)) { + unsigned sz = to_app(t1)->get_num_args(); + bool found = false; + pair_t q(t1, t2); + for (unsigned i = 0; i < sz; ++i) { + expr *arg1 = to_app(t1)->get_arg(i); + expr *arg2 = to_app(t2)->get_arg(i); + if (mdl(arg1) == mdl(t2)) { continue; } + if (in_table(arg1, arg2)) { + found = true; + break; } + q = pair_t(arg1, arg2); + } + if (!found) { + diseqs.insert(q); + todo.push_back(q); + result.push_back(m.mk_not(m.mk_eq(q.a, q.b))); } } } } - for (auto const& terms : partitions) { - expr* a = nullptr; - for (expr* b : terms) { - if (is_uninterp(b)) { - if (a) - result.push_back(m.mk_eq(a, b)); - else - a = b; - } + } + for (auto const &terms : partitions) { + expr *a = nullptr; + for (expr *b : terms) { + if (is_uninterp(b)) { + if (a) + result.push_back(m.mk_eq(a, b)); + else + a = b; } } - TRACE("qe", tout << result << "\n";); - return result; } + TRACE("qe", tout << result << "\n";); + return result; +} + +void term_graph::cground_percolate_up(term *t) { + SASSERT(t->is_class_gr()); + term *it = t; + // there is a cgr term in all ground classes + while (!it->is_cgr()) { + it = &it->get_next(); + SASSERT(it != t); + } + + ptr_vector todo; + todo.push_back(it); + cground_percolate_up(todo); +} + +void term_graph::cground_percolate_up(ptr_vector &todo) { + term *t; + + while (!todo.empty()) { + t = todo.back(); + todo.pop_back(); + t->set_cgr(true); + t->set_class_gr(true); + for (auto p : term::parents(t->get_root())) + if (!p->is_cgr() && p->all_children_ground()) todo.push_back(p); + } +} +void term_graph::compute_cground() { + for (auto t : m_terms) { + t->set_cgr(false); + t->set_class_gr(false); + } + ptr_vector todo; + for (auto t : m_terms) { + if (t->is_gr()) { todo.push_back(t); } + } + cground_percolate_up(todo); + DEBUG_CODE(for (auto t + : m_terms) { + bool isclsg = true; + for (auto c : term::children(t)) isclsg &= c->is_class_gr(); + SASSERT(t->deg() == 0 || !isclsg || t->is_cgr()); + SASSERT(t->deg() == 0 || isclsg || !t->is_cgr()); + }); } +} // namespace mbp diff --git a/src/qe/mbp/mbp_term_graph.h b/src/qe/mbp/mbp_term_graph.h index 37d1e7b8dce..58afcaf02d6 100644 --- a/src/qe/mbp/mbp_term_graph.h +++ b/src/qe/mbp/mbp_term_graph.h @@ -12,6 +12,12 @@ Module Name: Author: Arie Gurfinkel + Hari Govind V K (hgvk94) + Isabel Garcia (igcontreras) + +Revision History: + + Added implementation of qe_lite using term graph Notes: @@ -19,138 +25,237 @@ Module Name: #pragma once #include "ast/ast.h" +#include "ast/expr_functors.h" #include "ast/is_variable_test.h" -#include "util/plugin_manager.h" -#include "qe/mbp/mbp_solve_plugin.h" #include "model/model.h" +#include "qe/mbp/mbp_solve_plugin.h" +#include "util/plugin_manager.h" namespace mbp { +namespace is_ground_ns { +struct proc; +struct found; +} // namespace is_ground_ns +class term; + +class term_graph { + class projector; + friend struct is_ground_ns::proc; + friend struct is_ground_ns::found; + + class is_variable_proc : public ::is_variable_proc { + bool m_exclude; + obj_hashtable m_decls, m_solved; - class term; - - class term_graph { - class projector; - - class is_variable_proc : public ::is_variable_proc { - bool m_exclude; - obj_hashtable m_decls, m_solved; - public: - bool operator()(const expr *e) const override; - bool operator()(const term &t) const; - - void set_decls(const func_decl_ref_vector &decls, bool exclude); - void mark_solved(const expr *e); - void reset_solved() {m_solved.reset();} - void reset() {m_decls.reset(); m_solved.reset(); m_exclude = true;} - }; - - struct term_hash { unsigned operator()(term const* t) const; }; - struct term_eq { bool operator()(term const* a, term const* b) const; }; - ast_manager & m; - ptr_vector m_terms; - expr_ref_vector m_lits; // NSB: expr_ref_vector? - u_map m_app2term; - ast_ref_vector m_pinned; - projector* m_projector; - u_map m_term2app; - plugin_manager m_plugins; - ptr_hashtable m_cg_table; - vector> m_merge; - - term_graph::is_variable_proc m_is_var; - void merge(term &t1, term &t2); - void merge_flush(); - - term *mk_term(expr *t); - term *get_term(expr *t); - - term *internalize_term(expr *t); - void internalize_eq(expr *a1, expr *a2); - void internalize_lit(expr *lit); - - bool is_internalized(expr *a); - - bool term_lt(term const &t1, term const &t2); - void pick_root (term &t); - void pick_roots(); - - void reset_marks(); - bool marks_are_clear(); - - expr* mk_app_core(expr* a); - expr_ref mk_app(term const &t); - expr* mk_pure(term& t); - expr_ref mk_app(expr *a); - void mk_equalities(term const &t, expr_ref_vector &out); - void mk_all_equalities(term const &t, expr_ref_vector &out); - void display(std::ostream &out); - - bool is_pure_def(expr* atom, expr *& v); - public: - term_graph(ast_manager &m); - ~term_graph(); - - void set_vars(func_decl_ref_vector const& decls, bool exclude); - - ast_manager& get_ast_manager() const { return m;} - - void add_lit(expr *lit); - void add_lits(expr_ref_vector const &lits) { for (expr* e : lits) add_lit(e); } - void add_eq(expr* a, expr* b) { internalize_eq(a, b); } - - void reset(); - - // deprecate? - void to_lits(expr_ref_vector &lits, bool all_equalities = false); - expr_ref to_expr(); - - /** - * Return literals obtained by projecting added literals - * onto the vocabulary of decls (if exclude is false) or outside the - * vocabulary of decls (if exclude is true). - */ - expr_ref_vector project(); - expr_ref_vector solve(); - expr_ref_vector project(model &mdl); - - /** - * Return disequalities to ensure that disequalities between - * excluded functions are preserved. - * For example if f(a) = b, f(c) = d, and b and d are not - * congruent, then produce the disequality a != c. - */ - expr_ref_vector get_ackerman_disequalities(); - - /** - * Produce model-based disequality - * certificate corresponding to - * definition in BGVS 2020. - * A disequality certificate is a reduced set of - * disequalities, true under mdl, such that the literals - * can be satisfied when non-shared symbols are projected. - */ - expr_ref_vector dcert(model& mdl, expr_ref_vector const& lits); - - /** - * Produce a model-based partition. - */ - vector get_partition(model& mdl); - - /** - * Extract shared occurrences of terms whose sort are - * fid, but appear in a context that is not fid. - * for example f(x + y) produces the shared occurrence - * x + y when f is uninterpreted and x + y has sort Int or Real. - */ - expr_ref_vector shared_occurrences(family_id fid); - - /** - * Map expression that occurs in added literals into representative if it exists. - */ - void add_model_based_terms(model& mdl, expr_ref_vector const& terms); - expr* rep_of(expr* e); + bool operator()(const expr *e) const override; + bool operator()(const term &t) const; + + void set_decls(const func_decl_ref_vector &decls, bool exclude); + void set_decls(const app_ref_vector &vars, bool exclude); + void add_decls(const app_ref_vector &vars); + void add_decl(app *var); + void mark_solved(const expr *e); + void reset_solved() { m_solved.reset(); } + void reset() { + m_decls.reset(); + m_solved.reset(); + m_exclude = true; + } + bool contains(func_decl *f) { return m_decls.contains(f) == m_exclude; } + }; + + class is_non_core : public i_expr_pred { + std::function *m_non_core; + public: + is_non_core(std::function *nc) : m_non_core(nc) {} + bool operator()(expr *n) override { + if (m_non_core == nullptr) return false; + return (*m_non_core)(n); + } + }; + + struct term_hash { + unsigned operator()(term const *t) const; + }; + struct term_eq { + bool operator()(term const *a, term const *b) const; + }; + ast_manager & m; + ptr_vector m_terms; + expr_ref_vector m_lits; + u_map m_app2term; + ast_ref_vector m_pinned; + projector * m_projector = nullptr; + bool m_explicit_eq = false; + bool m_repick_repr = false; + u_map m_term2app; // any representative change invalidates this cache + plugin_manager m_plugins; + ptr_hashtable m_cg_table; + vector> m_merge; + + term_graph::is_variable_proc m_is_var; + + void merge(term &t1, term &t2); + void merge_flush(); + + term *mk_term(expr *t); + term *get_term(expr *t); + term *get_term(func_decl *f); + + term *internalize_term(expr *t); + void internalize_eq(expr *a1, expr *a2); + void internalize_lit(expr *lit); + void internalize_distinct(expr *d); + void internalize_deq(expr *a1, expr *a2); + + bool is_internalized(expr *a); + bool is_ground(expr *e); + + bool term_lt(term const &t1, term const &t2); + void pick_repr_percolate_up(ptr_vector &todo); + void pick_repr_class(term *t); + void pick_repr(); + + void reset_marks(); + void reset_marks2(); + bool marks_are_clear(); + + expr *mk_app_core(expr *a); + expr_ref mk_app(term &t); + expr *mk_pure(term &t); + expr_ref mk_app(expr *a); + void mk_equalities(term &t, expr_ref_vector &out); + void mk_all_equalities(term &t, expr_ref_vector &out); + void mk_qe_lite_equalities(term &t, expr_ref_vector &out, + check_pred ¬_in_core); + void display(std::ostream &out); + + bool is_pure_def(expr *atom, expr *&v); + void cground_percolate_up(ptr_vector &); + void cground_percolate_up(term *t); + void compute_cground(); + +public: + term_graph(ast_manager &m); + ~term_graph(); + + const expr_ref_vector &get_lits() const { return m_lits; } + void get_terms(expr_ref_vector &res, bool exclude_cground = true); + bool is_cgr(expr *e); + unsigned size() { return m_terms.size(); } + + void set_vars(func_decl_ref_vector const &decls, bool exclude = true); + void set_vars(app_ref_vector const &vars, bool exclude = true); + void add_vars(app_ref_vector const &vars); + void add_var(app *var); + + ast_manager &get_ast_manager() const { return m; } + + void add_lit(expr *lit); + void add_lits(expr_ref_vector const &lits) { + for (expr *e : lits) add_lit(e); + } + void add_eq(expr *a, expr *b) { internalize_eq(a, b); } + void add_deq(expr *a, expr *b) { internalize_deq(a, b); } + + void reset(); + + // deprecate? + void to_lits(expr_ref_vector &lits, bool all_equalities = false, + bool repick_repr = true); + void to_lits_qe_lite(expr_ref_vector &lits, + std::function *non_core = nullptr); + expr_ref to_expr(bool repick_repr = true); + + /** + * Return literals obtained by projecting added literals + * onto the vocabulary of decls (if exclude is false) or outside the + * vocabulary of decls (if exclude is true). + */ + expr_ref_vector project(); + expr_ref_vector solve(); + expr_ref_vector project(model &mdl); + + /** + * Return disequalities to ensure that disequalities between + * excluded functions are preserved. + * For example if f(a) = b, f(c) = d, and b and d are not + * congruent, then produce the disequality a != c. + */ + expr_ref_vector get_ackerman_disequalities(); + + /** + * Produce model-based disequality + * certificate corresponding to + * definition in BGVS 2020. + * A disequality certificate is a reduced set of + * disequalities, true under mdl, such that the literals + * can be satisfied when non-shared symbols are projected. + */ + expr_ref_vector dcert(model &mdl, expr_ref_vector const &lits); + + /** + * Produce a model-based partition. + */ + vector get_partition(model &mdl); + + /** + * Extract shared occurrences of terms whose sort are + * fid, but appear in a context that is not fid. + * for example f(x + y) produces the shared occurrence + * x + y when f is uninterpreted and x + y has sort Int or Real. + */ + expr_ref_vector shared_occurrences(family_id fid); + + /** + * Map expression that occurs in added literals into representative if it + * exists. + */ + void add_model_based_terms(model &mdl, expr_ref_vector const &terms); + expr *rep_of(expr *e); + + using deqs = bit_vector; + struct add_deq_proc { + unsigned m_deq_cnt = 0; + void inc_count(); + void operator()(term *t1, term *t2); + void operator()(ptr_vector &ts); }; -} + // -- disequalities added for output + vector> m_deq_pairs; + // -- maybe they are not necessary since they are in the original formula + vector> m_deq_distinct; + + expr_ref_vector non_ground_terms(); + void gr_terms_to_lits(expr_ref_vector &lits, bool all_equalities); + // produce a quantifier reduction of the formula stored in the term graph + // output of qel will not contain expression e s.t. non_core(e) == true + void qel(app_ref_vector &vars, expr_ref &fml, + std::function *non_core = nullptr); + bool has_val_in_class(expr *e); + app *get_const_in_class(expr *e); + void set_explicit_eq() { m_explicit_eq = true; } + +private: + add_deq_proc m_add_deq; + void refine_repr_class(term *t); + void refine_repr(); + bool makes_cycle(term *t); +}; + +namespace is_ground_ns { +struct found {}; +struct proc { + term_graph::is_variable_proc &m_is_var; + proc(term_graph::is_variable_proc &is_var) : m_is_var(is_var) {} + void operator()(var *n) const {} + void operator()(app const *n) const { + if (m_is_var.contains(n->get_decl())) throw found(); + } + void operator()(quantifier *n) const {} +}; +} // namespace is_ground_ns +} // namespace mbp diff --git a/src/qe/mbp/mbp_tg_plugins.h b/src/qe/mbp/mbp_tg_plugins.h new file mode 100644 index 00000000000..ab4a5433322 --- /dev/null +++ b/src/qe/mbp/mbp_tg_plugins.h @@ -0,0 +1,34 @@ +/*++ + + Module Name: + + mbp_tg_plugins.h + +Abstract: + + Model Based Projection for a theory + +Author: + + Hari Govind V K (hgvk94) 2022-07-12 + +Revision History: + + +--*/ +#pragma once +#include "ast/ast.h" +#include "qe/mbp/mbp_qel_util.h" +#include "qe/mbp/mbp_term_graph.h" +#include "util/obj_hashtable.h" + +class mbp_tg_plugin { + public: + // iterate through all terms in m_tg and apply all theory MBP rules once + // returns true if any rules were applied + virtual bool apply() { return false; }; + virtual ~mbp_tg_plugin() = default; + virtual void use_model() { }; + virtual void get_new_vars(app_ref_vector*&) { }; + virtual family_id get_family_id() const { return null_family_id; }; +}; diff --git a/src/qe/nlqsat.cpp b/src/qe/nlqsat.cpp index 3a1fa19c6a2..d81048c3d39 100644 --- a/src/qe/nlqsat.cpp +++ b/src/qe/nlqsat.cpp @@ -558,9 +558,13 @@ namespace qe { div_rewriter_cfg(nlqsat& s): m(s.m), a(s.m), m_zero(a.mk_real(0), m) {} ~div_rewriter_cfg() {} br_status reduce_app(func_decl* f, unsigned sz, expr* const* args, expr_ref& result, proof_ref& pr) { - rational r(1); + rational r1, r(1); + if (a.is_div(f) && sz == 2 && a.is_numeral(args[0], r1) && a.is_numeral(args[1], r) && !r.is_zero()) { + result = a.mk_real(r1 / r); + return BR_DONE; + } if (is_decl_of(f, a.get_family_id(), OP_DIV) && - sz == 2 && (!a.is_numeral(args[1], r) || r.is_zero()) && + sz == 2 && is_ground(args[0]) && is_ground(args[1])) { result = m.mk_fresh_const("div", a.mk_real()); m_divs.push_back(div(m, args[0], args[1], to_app(result))); @@ -609,9 +613,6 @@ namespace qe { } expr* n1, *n2; rational r; - if (a.is_div(n, n1, n2) && a.is_numeral(n2, r) && !r.is_zero()) { - return; - } if (a.is_power(n, n1, n2) && a.is_numeral(n2, r) && r.is_unsigned() && r.is_pos()) { return; } diff --git a/src/qe/qe.cpp b/src/qe/qe.cpp index 12365b2040e..325f343b467 100644 --- a/src/qe/qe.cpp +++ b/src/qe/qe.cpp @@ -157,8 +157,8 @@ namespace qe { // // Partition variables into buckets. - // The var_paritions buckets covering disjoint subsets of - // the conjuncts. The remaining variables in vars are non-partioned. + // The var_partitions buckets covering disjoint subsets of + // the conjuncts. The remaining variables in vars are non-partitioned. // bool partition_vars( unsigned num_vars, diff --git a/src/qe/qe_mbp.cpp b/src/qe/qe_mbp.cpp index 9f5d9063cd0..97f12238d41 100644 --- a/src/qe/qe_mbp.cpp +++ b/src/qe/qe_mbp.cpp @@ -18,25 +18,128 @@ Revision History: --*/ -#include "ast/rewriter/expr_safe_replace.h" +#include "qe/qe_mbp.h" #include "ast/ast_pp.h" #include "ast/ast_util.h" -#include "ast/occurs.h" -#include "ast/rewriter/th_rewriter.h" #include "ast/expr_functors.h" #include "ast/for_each_expr.h" +#include "ast/occurs.h" +#include "ast/rewriter/expr_safe_replace.h" +#include "ast/rewriter/th_rewriter.h" +#include "ast/rewriter/rewriter.h" +#include "ast/rewriter/rewriter_def.h" #include "ast/scoped_proof.h" -#include "qe/qe_mbp.h" +#include "ast/seq_decl_plugin.h" +#include "util/gparams.h" +#include "model/model_evaluator.h" +#include "model/model_pp.h" +#include "qe/lite/qe_lite_tactic.h" +#include "qe/lite/qel.h" #include "qe/mbp/mbp_arith.h" #include "qe/mbp/mbp_arrays.h" +#include "qe/mbp/mbp_qel.h" #include "qe/mbp/mbp_datatypes.h" -#include "qe/lite/qe_lite_tactic.h" -#include "model/model_pp.h" -#include "model/model_evaluator.h" - using namespace qe; +namespace { +// rewrite select(store(a, i, k), j) into k if m \models i = j and select(a, j) if m \models i != j + struct rd_over_wr_rewriter : public default_rewriter_cfg { + ast_manager &m; + array_util m_arr; + model_evaluator m_eval; + expr_ref_vector m_sc; + + rd_over_wr_rewriter(ast_manager& man, model& mdl): m(man), m_arr(m), m_eval(mdl), m_sc(m) { + m_eval.set_model_completion(false); + } + + br_status reduce_app(func_decl *f, unsigned num, expr *const *args, + expr_ref &result, proof_ref &result_pr) { + if (m_arr.is_select(f) && m_arr.is_store(args[0])) { + expr_ref ind1(m), ind2(m); + ind1 = m_eval(args[1]); + ind2 = m_eval(to_app(args[0])->get_arg(1)); + if (ind1 == ind2) { + result = to_app(args[0])->get_arg(2); + m_sc.push_back(m.mk_eq(args[1], to_app(args[0])->get_arg(1))); + return BR_DONE; + } + m_sc.push_back(m.mk_not(m.mk_eq(args[1], to_app(args[0])->get_arg(1)))); + expr_ref_vector new_args(m); + new_args.push_back(to_app(args[0])->get_arg(0)); + new_args.push_back(args[1]); + result = m_arr.mk_select(new_args); + return BR_REWRITE1; + } + return BR_FAILED; + } + }; +// rewrite all occurrences of (as const arr c) to (as const arr v) where v = m_eval(c) + struct app_const_arr_rewriter : public default_rewriter_cfg { + ast_manager &m; + array_util m_arr; + datatype_util m_dt_util; + model_evaluator m_eval; + expr_ref val; + + app_const_arr_rewriter(ast_manager& man, model& mdl): m(man), m_arr(m), m_dt_util(m), m_eval(mdl), val(m) { + m_eval.set_model_completion(false); + } + br_status reduce_app(func_decl *f, unsigned num, expr *const *args, + expr_ref &result, proof_ref &result_pr) { + if (m_arr.is_const(f) && !m.is_value(args[0])) { + val = m_eval(args[0]); + SASSERT(m.is_value(val)); + result = m_arr.mk_const_array(f->get_range(), val); + return BR_DONE; + } + if (m_dt_util.is_constructor(f)) { + // cons(head(x), tail(x)) --> x + ptr_vector const *accessors = + m_dt_util.get_constructor_accessors(f); + + SASSERT(num == accessors->size()); + // -- all accessors must have exactly one argument + if (any_of(*accessors, [&](const func_decl* acc) { return acc->get_arity() != 1; })) { + return BR_FAILED; + } + + if (num >= 1 && is_app(args[0]) && to_app(args[0])->get_decl() == accessors->get(0)) { + bool is_all = true; + expr* t = to_app(args[0])->get_arg(0); + for(unsigned i = 1; i < num && is_all; ++i) { + is_all &= (is_app(args[i]) && + to_app(args[i])->get_decl() == accessors->get(i) && + to_app(args[i])->get_arg(0) == t); + } + if (is_all) { + result = t; + return BR_DONE; + } + } + } + return BR_FAILED; + } + }; +} +void rewrite_as_const_arr(expr* in, model& mdl, expr_ref& out) { + app_const_arr_rewriter cfg(out.m(), mdl); + rewriter_tpl rw(out.m(), false, cfg); + rw(in, out); +} + +void rewrite_read_over_write(expr *in, model &mdl, expr_ref &out) { + rd_over_wr_rewriter cfg(out.m(), mdl); + rewriter_tpl rw(out.m(), false, cfg); + rw(in, out); + if (cfg.m_sc.empty()) return; + expr_ref_vector sc(out.m()); + SASSERT(out.m().is_and(out)); + flatten_and(out, sc); + sc.append(cfg.m_sc); + out = mk_and(sc); +} class mbproj::impl { ast_manager& m; @@ -47,6 +150,7 @@ class mbproj::impl { // parameters bool m_reduce_all_selects; bool m_dont_sub; + bool m_use_qel; void add_plugin(mbp::project_plugin* p) { family_id fid = p->get_family_id(); @@ -253,9 +357,35 @@ class mbproj::impl { m_params.append(p); m_reduce_all_selects = m_params.get_bool("reduce_all_selects", false); m_dont_sub = m_params.get_bool("dont_sub", false); + auto q = gparams::get_module("smt"); + m_params.append(q); + m_use_qel = m_params.get_bool("qsat_use_qel", true); } void preprocess_solve(model& model, app_ref_vector& vars, expr_ref_vector& fmls) { + if (m_use_qel) { + extract_literals(model, vars, fmls); + expr_ref e(m); + bool change = true; + while (change && !vars.empty()) { + change = false; + e = mk_and(fmls); + do_qel(vars, e); + fmls.reset(); + flatten_and(e, fmls); + for (auto* p : m_plugins) { + if (p && p->solve(model, vars, fmls)) { + change = true; + } + } + } + //rewrite as_const_arr terms + expr_ref fml(m); + fml = mk_and(fmls); + rewrite_as_const_arr(fml, model, fml); + flatten_and(fml, fmls); + } + else { extract_literals(model, vars, fmls); bool change = true; while (change && !vars.empty()) { @@ -267,6 +397,7 @@ class mbproj::impl { } } } + } bool validate_model(model& model, expr_ref_vector const& fmls) { for (expr* f : fmls) { @@ -275,7 +406,33 @@ class mbproj::impl { return true; } + bool has_unsupported_th(const expr_ref_vector fmls) { + seq_util seq(m); + expr_ref e(m); + e = mk_and(fmls); + return any_of(subterms::all(e), [&](expr* c) { return seq.is_char(c) || seq.is_seq(c); }); + } void operator()(bool force_elim, app_ref_vector& vars, model& model, expr_ref_vector& fmls) { + //don't use mbp_qel on some theories where model evaluation is + //incomplete This is not a limitation of qel. Fix this either by + //making mbp choices less dependent on the model evaluation methods + //or fix theory rewriters to make terms evaluation complete + if (m_use_qel && !has_unsupported_th(fmls)) { + bool dsub = m_dont_sub; + m_dont_sub = !force_elim; + expr_ref fml(m); + fml = mk_and(fmls); + spacer_qel(vars, model, fml); + fmls.reset(); + flatten_and(fml, fmls); + m_dont_sub = dsub; + } + else { + mbp(force_elim, vars, model, fmls); + } + } + + void mbp(bool force_elim, app_ref_vector& vars, model& model, expr_ref_vector& fmls) { SASSERT(validate_model(model, fmls)); expr_ref val(m), tmp(m); app_ref var(m); @@ -341,6 +498,17 @@ class mbproj::impl { SASSERT(!m.is_false(fml)); } + + void do_qel(app_ref_vector &vars, expr_ref &fml) { + qel qe(m, m_params); + qe(vars, fml); + m_rw(fml); + TRACE("qe", tout << "After qel:\n" + << fml << "\n" + << "Vars: " << vars << "\n";); + SASSERT(!m.is_false(fml)); + } + void do_qe_bool(model& mdl, app_ref_vector& vars, expr_ref& fml) { expr_ref_vector fmls(m); fmls.push_back(fml); @@ -348,7 +516,86 @@ class mbproj::impl { fml = mk_and(fmls); } + void qel_project(app_ref_vector &vars, model &mdl, expr_ref &fml, bool reduce_all_selects) { + flatten_and(fml); + mbp::mbp_qel mbptg(m, m_params); + mbptg(vars, fml, mdl); + if (reduce_all_selects) rewrite_read_over_write(fml, mdl, fml); + m_rw(fml); + TRACE("qe", tout << "After mbp_tg:\n" + << fml << " models " << mdl.is_true(fml) << "\n" + << "Vars: " << vars << "\n";); + } + + void spacer_qel(app_ref_vector& vars, model& mdl, expr_ref& fml) { + TRACE("qe", tout << "Before projection:\n" << fml << "\n" << "Vars: " << vars << "\n";); + + model_evaluator eval(mdl, m_params); + eval.set_model_completion(true); + app_ref_vector other_vars(m); + app_ref_vector sub_vars(m); + array_util arr_u(m); + arith_util ari_u(m); + datatype_util dt_u(m); + + do_qel(vars, fml); + qel_project(vars, mdl, fml, m_reduce_all_selects); + flatten_and(fml); + m_rw(fml); + rewrite_as_const_arr(fml, mdl, fml); + + for (app* v : vars) { + SASSERT(!arr_u.is_array(v) && !dt_u.is_datatype(v->get_sort())); + other_vars.push_back(v); + } + + // project reals, ints and other variables. + if (!other_vars.empty()) { + TRACE("qe", tout << "Other vars: " << other_vars << "\n" << mdl;); + + expr_ref_vector fmls(m); + flatten_and(fml, fmls); + + mbp(false, other_vars, mdl, fmls); + fml = mk_and(fmls); + m_rw(fml); + + TRACE("qe", + tout << "Projected other vars:\n" << fml << "\n"; + tout << "Remaining other vars:\n" << other_vars << "\n";); + SASSERT(!m.is_false(fml)); + } + + if (!other_vars.empty()) { + project_vars(mdl, other_vars, fml); + m_rw(fml); + } + + // substitute any remaining other vars + if (!m_dont_sub && !other_vars.empty()) { + subst_vars(eval, other_vars, fml); + TRACE("qe", tout << "After substituting remaining other vars:\n" << fml << "\n";); + // an extra round of simplification because subst_vars is not simplifying + m_rw(fml); + other_vars.reset(); + } + + SASSERT(!eval.is_false(fml)); + + vars.reset(); + vars.append(other_vars); + } + void spacer(app_ref_vector& vars, model& mdl, expr_ref& fml) { + if (m_use_qel) { + spacer_qel(vars, mdl, fml); + } + else { + spacer_qe_lite(vars, mdl, fml); + } + } + + void spacer_qe_lite(app_ref_vector& vars, model& mdl, expr_ref& fml) { TRACE("qe", tout << "Before projection:\n" << fml << "\n" << "Vars: " << vars << "\n";); model_evaluator eval(mdl, m_params); @@ -428,7 +675,6 @@ class mbproj::impl { vars.reset(); vars.append(other_vars); } - }; mbproj::mbproj(ast_manager& m, params_ref const& p) { @@ -447,6 +693,7 @@ void mbproj::updt_params(params_ref const& p) { void mbproj::get_param_descrs(param_descrs& r) { r.insert("reduce_all_selects", CPK_BOOL, "(default: false) reduce selects"); r.insert("dont_sub", CPK_BOOL, "(default: false) disable substitution of values for free variables"); + r.insert("use_qel", CPK_BOOL, "(default: true) use egraph based QEL"); } void mbproj::operator()(bool force_elim, app_ref_vector& vars, model& mdl, expr_ref_vector& fmls) { @@ -468,3 +715,5 @@ opt::inf_eps mbproj::maximize(expr_ref_vector const& fmls, model& mdl, app* t, e scoped_no_proof _sp(fmls.get_manager()); return m_impl->maximize(fmls, mdl, t, ge, gt); } +template class rewriter_tpl; +template class rewriter_tpl; diff --git a/src/qe/qsat.cpp b/src/qe/qsat.cpp index 7e28437f674..dc0613c59d8 100644 --- a/src/qe/qsat.cpp +++ b/src/qe/qsat.cpp @@ -39,6 +39,7 @@ Revision History: #include "qe/qe_mbp.h" #include "qe/qe.h" #include "ast/rewriter/label_rewriter.h" +#include "util/params.h" namespace qe { @@ -84,8 +85,10 @@ namespace qe { if (m_flevel.find(a->get_decl(), lvl)) { lvl0.merge(lvl); } - for (unsigned i = 0; i < a->get_num_args(); ++i) { - app* arg = to_app(a->get_arg(i)); + for (expr* f : *a) { + if (!is_app(f)) + throw tactic_exception("atom is non-ground"); + app* arg = to_app(f); if (m_elevel.find(arg, lvl)) { lvl0.merge(lvl); } @@ -164,8 +167,9 @@ namespace qe { TRACE("qe_assumptions", model_v2_pp(tout, *mdl);); expr_ref val(m); - for (unsigned j = 0; j < m_preds[level - 1].size(); ++j) { - app* p = m_preds[level - 1][j].get(); + for (unsigned i = 0; i <= level-1; ++i) { + for (unsigned j = 0; j < m_preds[i].size(); ++j) { + app* p = m_preds[i][j].get(); eval(p, val); if (!m.inc()) return; @@ -176,6 +180,7 @@ namespace qe { SASSERT(m.is_true(val)); m_asms.push_back(p); } + } } asms.append(m_asms); @@ -262,13 +267,9 @@ namespace qe { continue; } - unsigned sz = a->get_num_args(); - for (unsigned i = 0; i < sz; ++i) { - expr* f = a->get_arg(i); - if (!mark.is_marked(f)) { - todo.push_back(f); - } - } + for (expr* f : *a) + if (!mark.is_marked(f)) + todo.push_back(f); bool is_boolop = (a->get_family_id() == m.get_basic_family_id()) && @@ -321,8 +322,8 @@ namespace qe { unsigned sz = a->get_num_args(); bool diff = false; args.reset(); - for (unsigned i = 0; i < sz; ++i) { - expr* f = a->get_arg(i), *f1; + for (expr* f : *a) { + expr *f1; if (cache.find(f, f1)) { args.push_back(f1); diff |= f != f1; @@ -410,8 +411,8 @@ namespace qe { unsigned sz = a->get_num_args(); args.reset(); bool diff = false; - for (unsigned i = 0; i < sz; ++i) { - expr* f = a->get_arg(i), *f1; + for (expr* f : *a) { + expr *f1; if (cache.find(f, f1)) { args.push_back(f1); diff |= f != f1; @@ -529,11 +530,14 @@ namespace qe { ast_manager& m; params_ref m_params; ref m_solver; + + expr_ref m_last_assert; public: kernel(ast_manager& m): m(m), - m_solver(nullptr) + m_solver(nullptr), + m_last_assert(m) { m_params.set_bool("model", true); m_params.set_uint("relevancy", 0); @@ -544,7 +548,8 @@ namespace qe { solver const& s() const { return *m_solver; } void init() { - m_solver = mk_smt_solver(m, m_params, symbol::null); + m_solver = mk_smt2_solver(m, m_params, symbol::null); + m_last_assert = nullptr; } void collect_statistics(statistics & st) const { if (m_solver) @@ -561,10 +566,22 @@ namespace qe { void clear() { m_solver = nullptr; } + void assert_expr(expr* e) { + if (!m.is_true(e)) + m_solver->assert_expr(e); + } + void assert_blocking_fml(expr* e) { + if (m.is_true(e)) + return; + if (m_last_assert && e == m_last_assert && !m.is_false(e)) { + IF_VERBOSE(0, verbose_stream() << "Asserting this expression twice in a row:\n " << m_last_assert << "\n"); + UNREACHABLE(); + } + m_last_assert = e; m_solver->assert_expr(e); } - + void get_core(expr_ref_vector& core) { core.reset(); m_solver->get_unsat_core(core); @@ -615,10 +632,12 @@ namespace qe { \brief check alternating satisfiability. Even levels are existential, odd levels are universal. */ - lbool check_sat() { + lbool check_sat() { while (true) { ++m_stats.m_num_rounds; - IF_VERBOSE(3, verbose_stream() << "(check-qsat level: " << m_level << " round: " << m_stats.m_num_rounds << ")\n";); + IF_VERBOSE(1, verbose_stream() << "(check-qsat level: " << m_level << " round: " << m_stats.m_num_rounds << ")\n";); + TRACE("qe", + tout << "level: " << m_level << " round: " << m_stats.m_num_rounds << "\n"); check_cancel(); expr_ref_vector asms(m_asms); m_pred_abs.get_assumptions(m_model.get(), asms); @@ -951,7 +970,8 @@ namespace qe { } else { fml = m_pred_abs.mk_abstract(fml); - get_kernel(m_level).assert_expr(fml); + TRACE("qe_block", tout << "Blocking fml at level: " << m_level << "\n" << fml << "\n";); + get_kernel(m_level).assert_blocking_fml(fml); } SASSERT(!m_model.get()); return true; @@ -994,14 +1014,13 @@ namespace qe { case AST_APP: { app* a = to_app(e); expr_ref_vector args(m); - unsigned num_args = a->get_num_args(); bool all_visited = true; - for (unsigned i = 0; i < num_args; ++i) { - if (visited.find(a->get_arg(i), r)) { + for (expr* arg : *a) { + if (visited.find(arg, r)) { args.push_back(r); } else { - todo.push_back(a->get_arg(i)); + todo.push_back(arg); all_visited = false; } } @@ -1235,8 +1254,11 @@ namespace qe { m_value(nullptr), m_was_sat(false), m_gt(m) - { - } + { + params_ref q = params_ref(); + q.set_bool("use_qel", false); + m_mbp.updt_params(q); + } ~qsat() override { clear(); diff --git a/src/sat/sat_aig_finder.cpp b/src/sat/sat_aig_finder.cpp index eba4366ae43..a1013108f87 100644 --- a/src/sat/sat_aig_finder.cpp +++ b/src/sat/sat_aig_finder.cpp @@ -94,7 +94,7 @@ namespace sat { // from clause x, y, z // then ~x, ~y -> z // look for ~y, z -> ~x - contains ternary(y, ~z, ~x) - // look for ~x, y -> u - u is used in a ternary claues (~y, x) + // look for ~x, y -> u - u is used in a ternary clause (~y, x) // look for y, u -> ~x - contains ternary(~u, ~x, ~y) // then ~x = if ~y then z else u diff --git a/src/sat/sat_anf_simplifier.cpp b/src/sat/sat_anf_simplifier.cpp index 6d28d9f98b0..5ed45171c41 100644 --- a/src/sat/sat_anf_simplifier.cpp +++ b/src/sat/sat_anf_simplifier.cpp @@ -42,7 +42,8 @@ namespace sat { void anf_simplifier::operator()() { dd::pdd_manager m(20, dd::pdd_manager::semantics::mod2_e); - pdd_solver solver(s.rlimit(), m); + u_dependency_manager dm; + pdd_solver solver(s.rlimit(), dm, m); report _report(*this); configure_solver(solver); clauses2anf(solver); diff --git a/src/sat/sat_binspr.cpp b/src/sat/sat_binspr.cpp index 460b6a4c0f9..acfd1298023 100644 --- a/src/sat/sat_binspr.cpp +++ b/src/sat/sat_binspr.cpp @@ -47,7 +47,7 @@ Marijn's version: if inconsistent(): learn C (subsumes C or p) else: - candidates' := C union ~(consequencs of propagate(~C)) + candidates' := C union ~(consequences of propagate(~C)) candidates := candidates' intersect candidates pop(1) for q in candidates: @@ -77,7 +77,7 @@ Marijn's version: if inconsistent(): learn C (subsumes C or p) else: - candidates := candicates union C union ~(consequencs of propagate(~C)) + candidates := candidates union C union ~(consequences of propagate(~C)) pop(1) for q in candidates: push(1) diff --git a/src/sat/sat_cutset.cpp b/src/sat/sat_cutset.cpp index aae9bf1abbe..2d31bcf1473 100644 --- a/src/sat/sat_cutset.cpp +++ b/src/sat/sat_cutset.cpp @@ -273,7 +273,7 @@ namespace sat { std::string cut::table2string(unsigned num_input, uint64_t table) { std::ostringstream strm; display_table(strm, num_input, table); - return strm.str(); + return std::move(strm).str(); } diff --git a/src/sat/sat_cutset.h b/src/sat/sat_cutset.h index cb97c41d47c..f8451d4123c 100644 --- a/src/sat/sat_cutset.h +++ b/src/sat/sat_cutset.h @@ -183,7 +183,7 @@ namespace sat { void reset(on_update_t& on_del) { shrink(on_del, 0); } cut const & operator[](unsigned idx) const { return m_cuts[idx]; } void shrink(on_update_t& on_del, unsigned j); - void swap(cut_set& other) { + void swap(cut_set& other) noexcept { std::swap(m_var, other.m_var); std::swap(m_size, other.m_size); std::swap(m_max_size, other.m_max_size); diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp index ca274be51ca..bd7b0d26cbf 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/sat/sat_ddfw.cpp @@ -317,7 +317,7 @@ namespace sat { // cls becomes false: flip any variable in clause to receive reward w switch (ci.m_num_trues) { case 0: { - m_unsat.insert(cls_idx); + m_unsat.insert_fresh(cls_idx); clause const& c = get_clause(cls_idx); for (literal l : c) { inc_reward(l, w); @@ -406,7 +406,7 @@ namespace sat { inc_reward(lit, ci.m_weight); inc_make(lit); } - m_unsat.insert(i); + m_unsat.insert_fresh(i); break; case 1: dec_reward(to_literal(ci.m_trues), ci.m_weight); @@ -494,7 +494,7 @@ namespace sat { unsigned h = value_hash(); unsigned occs = 0; bool contains = m_models.find(h, occs); - if (!m_models.contains(h)) { + if (!contains) { for (unsigned v = 0; v < num_vars(); ++v) bias(v) += value(v) ? 1 : -1; if (m_models.size() > m_config.m_max_num_models) diff --git a/src/sat/sat_ddfw.h b/src/sat/sat_ddfw.h index 9883652859c..ff86e9b8c17 100644 --- a/src/sat/sat_ddfw.h +++ b/src/sat/sat_ddfw.h @@ -178,7 +178,7 @@ namespace sat { inline void inc_make(literal lit) { bool_var v = lit.var(); - if (make_count(v)++ == 0) m_unsat_vars.insert(v); + if (make_count(v)++ == 0) m_unsat_vars.insert_fresh(v); } inline void dec_make(literal lit) { diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index ae99cae12be..b42cc33e21a 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -87,6 +87,7 @@ namespace sat { virtual void init_search() {} virtual bool propagated(sat::literal l, sat::ext_constraint_idx idx) { UNREACHABLE(); return false; } virtual bool unit_propagate() = 0; + virtual bool can_propagate() { return false; } virtual bool is_external(bool_var v) { return false; } virtual double get_reward(literal l, ext_constraint_idx idx, literal_occs_fun& occs) const { return 0; } virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r, bool probing) = 0; diff --git a/src/sat/sat_integrity_checker.cpp b/src/sat/sat_integrity_checker.cpp index 031ce92029c..a8de05e0f9e 100644 --- a/src/sat/sat_integrity_checker.cpp +++ b/src/sat/sat_integrity_checker.cpp @@ -28,7 +28,7 @@ namespace sat { } // for nary clauses - static bool contains_watched(watch_list const & wlist, clause const & c, clause_offset cls_off) { + bool integrity_checker::contains_watched(watch_list const & wlist, clause const & c, clause_offset cls_off) const { for (watched const& w : wlist) { if (w.is_clause()) { if (w.get_clause_offset() == cls_off) { @@ -38,6 +38,8 @@ namespace sat { } } } + TRACE("sat", tout << "clause " << c << " not found in watch-list\n"); + TRACE("sat", s.display_watches(tout)); UNREACHABLE(); return false; } diff --git a/src/sat/sat_integrity_checker.h b/src/sat/sat_integrity_checker.h index 22c8d0f5eab..bc7fecf68d3 100644 --- a/src/sat/sat_integrity_checker.h +++ b/src/sat/sat_integrity_checker.h @@ -25,6 +25,7 @@ Revision History: namespace sat { class integrity_checker { solver const & s; + bool contains_watched(watch_list const & wlist, clause const & c, clause_offset cls_off) const; public: integrity_checker(solver const & s); diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index 2fa7ed040dd..2f3fc91b289 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -990,7 +990,7 @@ namespace sat { m_rating.push_back(0); m_vprefix.push_back(prefix()); if (!m_s.was_eliminated(v)) - m_freevars.insert(v); + m_freevars.insert_fresh(v); } void lookahead::init(bool learned) { @@ -1096,7 +1096,7 @@ namespace sat { literal l = m_trail[i]; set_undef(l); TRACE("sat", tout << "inserting free var v" << l.var() << "\n";); - m_freevars.insert(l.var()); + m_freevars.insert_fresh(l.var()); } m_num_tc1 = m_num_tc1_lim.back(); diff --git a/src/sat/sat_model_converter.cpp b/src/sat/sat_model_converter.cpp index 27cc6823a36..ddb277e9b23 100644 --- a/src/sat/sat_model_converter.cpp +++ b/src/sat/sat_model_converter.cpp @@ -369,7 +369,7 @@ namespace sat { return result; } - void model_converter::swap(bool_var v, unsigned sz, literal_vector& clause) { + void model_converter::swap(bool_var v, unsigned sz, literal_vector& clause) noexcept { for (unsigned j = 0; j < sz; ++j) { if (v == clause[j].var()) { std::swap(clause[0], clause[j]); diff --git a/src/sat/sat_model_converter.h b/src/sat/sat_model_converter.h index 8a3891797d3..a331bef8e57 100644 --- a/src/sat/sat_model_converter.h +++ b/src/sat/sat_model_converter.h @@ -91,7 +91,7 @@ namespace sat { bool legal_to_flip(bool_var v) const; - void swap(bool_var v, unsigned sz, literal_vector& clause); + void swap(bool_var v, unsigned sz, literal_vector& clause) noexcept; void add_elim_stack(entry & e); diff --git a/src/sat/sat_parallel.cpp b/src/sat/sat_parallel.cpp index cdb13706fae..f7b54b99812 100644 --- a/src/sat/sat_parallel.cpp +++ b/src/sat/sat_parallel.cpp @@ -5,7 +5,7 @@ Module Name: sat_parallel.cpp - Abstract: +Abstract: Utilities for parallel SAT solving. @@ -25,12 +25,7 @@ namespace sat { void parallel::vector_pool::next(unsigned& index) { SASSERT(index < m_size); unsigned n = index + 2 + get_length(index); - if (n >= m_size) { - index = 0; - } - else { - index = n; - } + index = (n >= m_size) ? 0 : n; } void parallel::vector_pool::reserve(unsigned num_threads, unsigned sz) { @@ -64,9 +59,8 @@ namespace sat { } void parallel::vector_pool::end_add_vector() { - if (m_tail >= m_size) { + if (m_tail >= m_size) m_tail = 0; - } } @@ -93,9 +87,9 @@ namespace sat { parallel::parallel(solver& s): m_num_clauses(0), m_consumer_ready(false), m_scoped_rlimit(s.rlimit()) {} parallel::~parallel() { - for (unsigned i = 0; i < m_solvers.size(); ++i) { - dealloc(m_solvers[i]); - } + m_limits.reset(); + for (auto* s : m_solvers) + dealloc(s); } void parallel::init_solvers(solver& s, unsigned num_extra_solvers) { @@ -106,9 +100,8 @@ namespace sat { for (unsigned i = 0; i < num_extra_solvers; ++i) { s.m_params.set_uint("random_seed", s.m_rand()); - if (i == 1 + num_threads/2) { + if (i == 1 + num_threads/2) s.m_params.set_sym("phase", symbol("random")); - } m_solvers[i] = alloc(sat::solver, s.m_params, m_limits[i]); m_solvers[i]->copy(s, true); m_solvers[i]->set_par(this, i); @@ -164,9 +157,8 @@ namespace sat { IF_VERBOSE(3, verbose_stream() << owner << ": share " << c << "\n";); lock_guard lock(m_mux); m_pool.begin_add_vector(owner, n); - for (unsigned i = 0; i < n; ++i) { + for (unsigned i = 0; i < n; ++i) m_pool.add_vector_elem(c[i].index()); - } m_pool.end_add_vector(); } @@ -220,9 +212,8 @@ namespace sat { if (m_priorities.empty()) return; - for (bool_var v = 0; v < m_priorities.size(); ++v) { + for (bool_var v = 0; v < m_priorities.size(); ++v) s.update_activity(v, m_priorities[v]); - } s.m_activity_inc = 128; #endif } @@ -269,15 +260,13 @@ namespace sat { bool parallel::copy_solver(solver& s) { bool copied = false; - { - lock_guard lock(m_mux); - m_consumer_ready = true; - if (m_solver_copy && s.m_clauses.size() > m_solver_copy->m_clauses.size()) { - s.copy(*m_solver_copy, true); - copied = true; - m_num_clauses = s.m_clauses.size(); - } - } + lock_guard lock(m_mux); + m_consumer_ready = true; + if (m_solver_copy && s.m_clauses.size() > m_solver_copy->m_clauses.size()) { + s.copy(*m_solver_copy, true); + copied = true; + m_num_clauses = s.m_clauses.size(); + } return copied; } diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index d40d606d1e1..370bf1dea1c 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -64,7 +64,7 @@ def_module_params('sat', ('ddfw_search', BOOL, False, 'use ddfw local search instead of CDCL'), ('ddfw.init_clause_weight', UINT, 8, 'initial clause weight for DDFW local search'), ('ddfw.use_reward_pct', UINT, 15, 'percentage to pick highest reward variable when it has reward 0'), - ('ddfw.restart_base', UINT, 100000, 'number of flips used a starting point for hessitant restart backoff'), + ('ddfw.restart_base', UINT, 100000, 'number of flips used a starting point for hesitant restart backoff'), ('ddfw.reinit_base', UINT, 10000, 'increment basis for geometric backoff scheme of re-initialization of weights'), ('ddfw.threads', UINT, 0, 'number of ddfw threads to run in parallel with sat solver'), ('prob_search', BOOL, False, 'use probsat local search instead of CDCL'), @@ -105,7 +105,7 @@ def_module_params('sat', ('lookahead.preselect', BOOL, False, 'use pre-selection of subset of variables for branching'), ('lookahead_simplify', BOOL, False, 'use lookahead solver during simplification'), ('lookahead_scores', BOOL, False, 'extract lookahead scores. A utility that can only be used from the DIMACS front-end'), - ('lookahead.double', BOOL, True, 'enable doubld lookahead'), + ('lookahead.double', BOOL, True, 'enable double lookahead'), ('lookahead.use_learned', BOOL, False, 'use learned clauses when selecting lookahead literal'), ('lookahead_simplify.bca', BOOL, True, 'add learned binary clauses as part of lookahead simplification'), ('lookahead.global_autarky', BOOL, False, 'prefer to branch on variables that occur in clauses that are reduced'), diff --git a/src/sat/sat_prob.cpp b/src/sat/sat_prob.cpp index 46a098ba896..824ec4fc6cf 100644 --- a/src/sat/sat_prob.cpp +++ b/src/sat/sat_prob.cpp @@ -80,7 +80,7 @@ namespace sat { ci.del(lit); switch (ci.m_num_trues) { case 0: - m_unsat.insert(cls_idx); + m_unsat.insert_fresh(cls_idx); dec_break(lit); break; case 1: @@ -184,7 +184,7 @@ namespace sat { } switch (ci.m_num_trues) { case 0: - m_unsat.insert(i); + m_unsat.insert_fresh(i); break; case 1: inc_break(to_literal(ci.m_trues)); diff --git a/src/sat/sat_proof_trim.cpp b/src/sat/sat_proof_trim.cpp index df55aecd72c..7531da99fb7 100644 --- a/src/sat/sat_proof_trim.cpp +++ b/src/sat/sat_proof_trim.cpp @@ -7,15 +7,13 @@ Abstract: - proof replay and trim + The proof is trimmed by re-running the proof steps and collecting justified literals + at level 0. The proof is obtained by back-tracing the justificiations attached to literals. Author: Nikolaj Bjorner 2023-10-04 - Notes: - - --*/ #include "sat/sat_proof_trim.h" @@ -29,53 +27,50 @@ namespace sat { Output: reduced trail - result */ - unsigned_vector proof_trim::trim() { - unsigned_vector result; - m_core_literals.reset(); - m_core_literals.insert(literal_vector()); + vector> proof_trim::trim() { + m_result.reset(); m_propagated.resize(num_vars(), false); - for (unsigned i = m_trail.size(); i-- > 0; ) { + + + IF_VERBOSE(10, s.display(verbose_stream() << "trim\n")); + + auto const& [id, cl, clp, is_add, is_initial] = m_trail.back(); + SASSERT(cl.empty()); + m_result.push_back({id, unsigned_vector()}); + conflict_analysis_core(m_conflict, m_conflict_clause); + m_trail.pop_back(); + + for (unsigned i = m_trail.size(); i-- > 0; ) { auto const& [id, cl, clp, is_add, is_initial] = m_trail[i]; if (!is_add) { revive(cl, clp); continue; - } + } IF_VERBOSE(10, s.display(verbose_stream())); prune_trail(cl, clp); - IF_VERBOSE(10, verbose_stream() << cl << " " << in_core(cl, clp) << ": "; for (auto const& c : m_core_literals) verbose_stream() << "{" << c << "} "); IF_VERBOSE(10, s.display(verbose_stream() << "\n")); del(cl, clp); - if (!in_core(cl, clp)) + if (!in_core(cl)) continue; - result.push_back(id); + IF_VERBOSE(4, verbose_stream() << cl << " in-core " << in_core(cl) << ": "; for (auto const& [k,v] : m_clauses) verbose_stream() << "{" << v.m_clauses << "} "; verbose_stream() << "\n"); + + m_result.push_back({id, unsigned_vector()}); + m_in_deps.reset(); if (is_initial) continue; conflict_analysis_core(cl, clp); } - result.reverse(); - return result; + m_result.reverse(); + return m_result; } void proof_trim::del(literal_vector const& cl, clause* cp) { + CTRACE("sat", cp, tout << "del " << *cp << "\n"); if (cp) s.detach_clause(*cp); else del(cl); } - - bool proof_trim::match_clause(literal_vector const& cl, literal l1, literal l2) const { - return cl.size() == 2 && ((l1 == cl[0] && l2 == cl[1]) || (l1 == cl[1] && l2 == cl[0])); - } - - bool proof_trim::match_clause(literal_vector const& cl, literal l1, literal l2, literal l3) const { - return cl.size() == 3 && - ((l1 == cl[0] && l2 == cl[1] && l3 == cl[2]) || - (l1 == cl[0] && l2 == cl[2] && l3 == cl[1]) || - (l1 == cl[1] && l2 == cl[0] && l3 == cl[2]) || - (l1 == cl[1] && l2 == cl[2] && l3 == cl[0]) || - (l1 == cl[2] && l2 == cl[1] && l3 == cl[0]) || - (l1 == cl[2] && l2 == cl[0] && l3 == cl[1])); - } /** * cl is on the trail if there is some literal l that is implied by cl @@ -90,6 +85,8 @@ namespace sat { void proof_trim::prune_trail(literal_vector const& cl, clause* cp) { m_in_clause.reset(); m_in_coi.reset(); + + // verbose_stream() << "prune trail " << cl << "\n"; if (cl.empty()) return; @@ -121,12 +118,21 @@ namespace sat { auto js = s.get_justification(l); bool in_coi = false; if (js.is_clause()) - for (literal lit : s.get_clause(j)) + for (literal lit : s.get_clause(js)) in_coi |= m_in_coi.contains(lit.index()); else if (js.is_binary_clause()) in_coi = m_in_coi.contains(js.get_literal().index()); - else + else if (js.is_none()) { + verbose_stream() << "none " << js << "\n"; + } + else if (js.is_ext_justification()) { + verbose_stream() << js << "\n"; + UNREACHABLE(); // approach does not work for external justifications + } + else { + verbose_stream() << js << "\n"; UNREACHABLE(); // approach does not work for external justifications + } if (in_coi) unassign_literal(l); @@ -134,6 +140,7 @@ namespace sat { s.m_trail[j++] = s.m_trail[i]; } s.m_trail.shrink(j); + // verbose_stream() << "trail after " << s.m_trail << "\n"; s.m_inconsistent = false; s.m_qhead = s.m_trail.size(); s.propagate(false); @@ -170,7 +177,8 @@ namespace sat { IF_VERBOSE(3, verbose_stream() << "core " << cl << "\n"); unsigned trail_size0 = s.m_trail.size(); - if (!cl.empty()) { + bool probe = !cl.empty() && !s.inconsistent(); + if (probe) { SASSERT(!s.inconsistent()); s.push(); unsigned lvl = s.scope_lvl(); @@ -188,11 +196,14 @@ namespace sat { m_propagated[s.m_trail[i].var()] = true; } SASSERT(s.inconsistent()); - IF_VERBOSE(3, verbose_stream() << s.m_not_l << " " << s.m_conflict << "\n"); + IF_VERBOSE(3, s.display_justification(verbose_stream() << "conflict " << s.m_not_l << " ", s.m_conflict) << "\n"); + IF_VERBOSE(3, s.display(verbose_stream())); + sat::literal l = sat::null_literal; if (s.m_not_l != null_literal) { - add_core(~s.m_not_l, s.m_conflict); add_dependency(s.m_not_l); + l = ~s.m_not_l; } + add_core(l, s.m_conflict); add_dependency(s.m_conflict); for (unsigned i = s.m_trail.size(); i-- > trail_size0; ) { @@ -201,17 +212,20 @@ namespace sat { if (!s.is_marked(v)) continue; add_core(v); - s.reset_mark(v); + s.reset_mark(v); add_dependency(s.get_justification(v)); } - if (!cl.empty()) + if (probe) s.pop(1); } void proof_trim::add_dependency(literal lit) { + IF_VERBOSE(3, verbose_stream() << "add dependency " << lit << "\n"); bool_var v = lit.var(); - if (m_propagated[v]) // literal was propagated after assuming ~C - s.mark(v); + if (m_propagated[v]) { // literal was propagated after assuming ~C + if (!s.is_marked(v)) + s.mark(v); + } else if (s.lvl(v) == 0) // literal depends on level 0, it is not assumed by ~C // inefficient for repeated insertions ? add_core(v); @@ -241,40 +255,54 @@ namespace sat { add_core(lit, j); } + void proof_trim::insert_dep(unsigned dep) { + if (m_in_deps.contains(dep)) + return; + m_in_deps.insert(dep); + m_result.back().second.push_back(dep); + } void proof_trim::add_core(literal l, justification j) { m_clause.reset(); switch (j.get_kind()) { case justification::NONE: - m_clause.push_back(l); + if (l != null_literal) + m_clause.push_back(l); break; case justification::BINARY: m_clause.push_back(l); m_clause.push_back(j.get_literal()); break; case justification::CLAUSE: - s.get_clause(j).mark_used(); - IF_VERBOSE(3, verbose_stream() << "add core " << s.get_clause(j) << "\n"); - return; + for (auto lit : s.get_clause(j)) + m_clause.push_back(lit); + break; default: + verbose_stream() << j << "\n"; UNREACHABLE(); break; } std::sort(m_clause.begin(), m_clause.end()); - IF_VERBOSE(3, verbose_stream() << "add core " << m_clause << "\n"); - m_core_literals.insert(m_clause); - if (s.lvl(l) == 0) { - m_clause.reset(); - m_clause.push_back(l); - m_core_literals.insert(m_clause); + IF_VERBOSE(3, verbose_stream() << "add core {" << m_clause << "}\n"); + auto& [clauses, id, in_core] = m_clauses.find(m_clause); + in_core = true; + insert_dep(id); + if (m_clause.size() > 1 && l != null_literal && s.lvl(l) == 0) { + for (auto lit : m_clause) { + if (s.lvl(lit) != 0) + continue; + m_clause2.reset(); + m_clause2.push_back(s.value(lit) == l_false ? ~lit : lit); + auto& [clauses, id, in_core] = m_clauses.insert_if_not_there(m_clause2, {{}, UINT_MAX, true }); + in_core = true; + if (id != UINT_MAX) + insert_dep(id); + } } } - bool proof_trim::in_core(literal_vector const& cl, clause* cp) const { - if (cp) - return cp->was_used(); - else - return m_core_literals.contains(cl); + bool proof_trim::in_core(literal_vector const& cl) const { + return m_clauses.find(cl).m_in_core; } void proof_trim::revive(literal_vector const& cl, clause* cp) { @@ -286,7 +314,7 @@ namespace sat { clause* proof_trim::del(literal_vector const& cl) { clause* cp = nullptr; - IF_VERBOSE(3, verbose_stream() << "del: " << cl << "\n"); + TRACE("sat", tout << "del: " << cl << "\n"); if (cl.size() == 2) { s.detach_bin_clause(cl[0], cl[1], true); return cp; @@ -294,23 +322,15 @@ namespace sat { auto* e = m_clauses.find_core(cl); if (!e) return cp; - auto& v = e->get_data().m_value; - if (!v.empty()) { - cp = v.back(); - IF_VERBOSE(3, verbose_stream() << "del: " << *cp << "\n"); + auto& [clauses, id, in_core] = e->get_data().m_value; + if (!clauses.empty()) { + cp = clauses.back(); + TRACE("sat", tout << "del: " << *cp << "\n"); s.detach_clause(*cp); - v.pop_back(); + clauses.pop_back(); } return cp; - } - - void proof_trim::save(literal_vector const& lits, clause* cl) { - if (!cl) - return; - IF_VERBOSE(3, verbose_stream() << "add: " << *cl << "\n"); - auto& v = m_clauses.insert_if_not_there(lits, clause_vector()); - v.push_back(cl); - } + } proof_trim::proof_trim(params_ref const& p, reslimit& lim): s(p, lim) { @@ -318,14 +338,67 @@ namespace sat { } void proof_trim::assume(unsigned id, bool is_initial) { - std::sort(m_clause.begin(), m_clause.end()); + std::sort(m_clause.begin(), m_clause.end()); + unsigned j = 0; + sat::literal prev = null_literal; + for (unsigned i = 0; i < m_clause.size(); ++i) + if (m_clause[i] != prev) + prev = m_clause[j++] = m_clause[i]; + m_clause.shrink(j); if (unit_or_binary_occurs()) - return; + return; + if (!m_conflict.empty() && m_clause.empty()) { + m_clauses.insert(m_clause, { {}, id, m_clause.empty() }); + m_trail.push_back({ id , m_clause, nullptr, true, is_initial }); + } + if (!m_conflict.empty()) + return; + IF_VERBOSE(3, verbose_stream() << (is_initial?"assume ":"rup ") << m_clause << "\n"); auto* cl = s.mk_clause(m_clause, status::redundant()); + auto& [clauses, id2, in_core] = m_clauses.insert_if_not_there(m_clause, { {}, id, m_clause.empty() }); + if (cl) + clauses.push_back(cl); m_trail.push_back({ id, m_clause, cl, true, is_initial }); + + auto is_unit2 = [&]() { + if (s.value(m_clause[0]) == l_false) + std::swap(m_clause[0], m_clause[1]); + return s.value(m_clause[1]) == l_false; + }; + + auto is_unit = [&]() { + unsigned undef_idx = m_clause.size(); + for (unsigned i = 0; i < m_clause.size(); ++i) { + sat::literal lit = (*cl)[i]; + if (s.value(lit) != l_undef) + continue; + if (undef_idx < m_clause.size()) + return false; + undef_idx = i; + } + if (undef_idx < m_clause.size()) { + std::swap((*cl)[undef_idx], (*cl)[0]); + return true; + } + return false; + }; + + if (all_of(m_clause, [&](sat::literal lit) { return s.value(lit) == l_false; })) { + IF_VERBOSE(3, verbose_stream() << "false clause " << m_clause << "\n"); + set_conflict(m_clause, cl); + return; + } + + if (m_clause.size() == 2 && is_unit2()) + s.propagate_bin_clause(m_clause[0], m_clause[1]); + else if (m_clause.size() > 2 && is_unit()) + s.propagate_clause(*cl, true, 0, s.cls_allocator().get_offset(cl)); s.propagate(false); - save(m_clause, cl); + if (s.inconsistent()) { + IF_VERBOSE(3, verbose_stream() << "conflict " << m_clause << "\n"); + set_conflict(m_clause, cl); + } } /** @@ -352,6 +425,4 @@ namespace sat { void proof_trim::infer(unsigned id) { assume(id, false); } - - } diff --git a/src/sat/sat_proof_trim.h b/src/sat/sat_proof_trim.h index 6d996ad6e1b..11e460eb6d8 100644 --- a/src/sat/sat_proof_trim.h +++ b/src/sat/sat_proof_trim.h @@ -30,11 +30,13 @@ namespace sat { class proof_trim { solver s; - literal_vector m_clause; + literal_vector m_clause, m_clause2, m_conflict; + uint_set m_in_deps; uint_set m_in_clause; uint_set m_in_coi; + clause* m_conflict_clause = nullptr; vector> m_trail; - + vector> m_result; struct hash { unsigned operator()(literal_vector const& v) const { @@ -46,16 +48,20 @@ namespace sat { return a == b; } }; - map m_clauses; - hashtable m_core_literals; + + struct clause_info { + clause_vector m_clauses; + unsigned m_id = 0; + bool m_in_core = false; + }; + + + map m_clauses; bool_vector m_propagated; void del(literal_vector const& cl, clause* cp); - bool match_clause(literal_vector const& cl, literal l1, literal l2) const; - bool match_clause(literal_vector const& cl, literal l1, literal l2, literal l3) const; - void prune_trail(literal_vector const& cl, clause* cp); void conflict_analysis_core(literal_vector const& cl, clause* cp); @@ -63,13 +69,15 @@ namespace sat { void add_dependency(justification j); void add_core(bool_var v); void add_core(literal l, justification j); - bool in_core(literal_vector const& cl, clause* cp) const; + bool in_core(literal_vector const& cl) const; void revive(literal_vector const& cl, clause* cp); clause* del(literal_vector const& cl); - void save(literal_vector const& lits, clause* cl); + + void insert_dep(unsigned dep); uint_set m_units; bool unit_or_binary_occurs(); + void set_conflict(literal_vector const& c, clause* cp) { m_conflict.reset(); m_conflict.append(c); m_conflict_clause = cp;} public: @@ -85,7 +93,7 @@ namespace sat { void infer(unsigned id); void updt_params(params_ref const& p) { s.updt_params(p); } - unsigned_vector trim(); + vector> trim(); }; } diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 5c1ed6daeab..2d296294085 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -90,7 +90,7 @@ namespace sat { solver::~solver() { m_ext = nullptr; - SASSERT(m_config.m_num_threads > 1 || check_invariant()); + SASSERT(m_config.m_num_threads > 1 || m_trim || check_invariant()); CTRACE("sat", !m_clauses.empty(), tout << "Delete clauses\n";); del_clauses(m_clauses); CTRACE("sat", !m_learned.empty(), tout << "Delete learned\n";); @@ -238,7 +238,8 @@ namespace sat { } m_user_scope_literals.reset(); - m_user_scope_literals.append(src.m_user_scope_literals); + for (auto lit : src.m_user_scope_literals) + assign_unit(~lit); m_mc = src.m_mc; m_stats.m_units = init_trail_size(); @@ -960,6 +961,8 @@ namespace sat { // ----------------------- bool solver::propagate_core(bool update) { + if (m_ext && (!is_probing() || at_base_lvl())) + m_ext->unit_propagate(); while (m_qhead < m_trail.size() && !m_inconsistent) { do { checkpoint(); @@ -1716,6 +1719,9 @@ namespace sat { if (next == null_bool_var) return false; } + else { + SASSERT(value(next) == l_undef); + } push(); m_stats.m_decision++; @@ -1725,11 +1731,14 @@ namespace sat { phase = guess(next) ? l_true: l_false; literal next_lit(next, false); + SASSERT(value(next_lit) == l_undef); if (m_ext && m_ext->decide(next, phase)) { + if (used_queue) m_case_split_queue.unassign_var_eh(next); next_lit = literal(next, false); + SASSERT(value(next_lit) == l_undef); } if (phase == l_undef) @@ -1781,7 +1790,7 @@ namespace sat { } bool solver::should_propagate() const { - return !inconsistent() && m_qhead < m_trail.size(); + return !inconsistent() && (m_qhead < m_trail.size() || (m_ext && m_ext->can_propagate())); } lbool solver::final_check() { @@ -2271,7 +2280,7 @@ namespace sat { << std::setw(4) << m_stats.m_restart << mk_stat(*this) << " " << std::setw(6) << std::setprecision(2) << m_stopwatch.get_current_seconds() << ")\n"; - std::string str(strm.str()); + std::string str = std::move(strm).str(); svector nums; for (size_t i = 0; i < str.size(); ++i) { while (i < str.size() && str[i] != ' ') ++i; @@ -2425,9 +2434,8 @@ namespace sat { m_conflicts_since_restart++; m_conflicts_since_gc++; m_stats.m_conflict++; - if (m_step_size > m_config.m_step_size_min) { - m_step_size -= m_config.m_step_size_dec; - } + if (m_step_size > m_config.m_step_size_min) + m_step_size -= m_config.m_step_size_dec; bool unique_max; m_conflict_lvl = get_max_lvl(m_not_l, m_conflict, unique_max); @@ -2531,17 +2539,8 @@ namespace sat { case justification::EXT_JUSTIFICATION: { fill_ext_antecedents(consequent, js, false); TRACE("sat", tout << "ext antecedents: " << m_ext_antecedents << "\n";); - for (literal l : m_ext_antecedents) - process_antecedent(l, num_marks); - -#if 0 - if (m_ext_antecedents.size() <= 1) { - for (literal& l : m_ext_antecedents) - l.neg(); - m_ext_antecedents.push_back(consequent); - mk_clause(m_ext_antecedents.size(), m_ext_antecedents.c_ptr(), sat::status::redundant()); - } -#endif + for (literal l : m_ext_antecedents) + process_antecedent(l, num_marks); break; } default: @@ -2559,7 +2558,8 @@ namespace sat { } SASSERT(lvl(c_var) < m_conflict_lvl); } - CTRACE("sat", idx == 0, + CTRACE("sat", idx == 0, + tout << "conflict level " << m_conflict_lvl << "\n"; for (literal lit : m_trail) if (is_marked(lit.var())) tout << "missed " << lit << "@" << lvl(lit) << "\n";); @@ -2814,31 +2814,34 @@ namespace sat { unsigned level = 0; if (not_l != null_literal) { - level = lvl(not_l); + level = lvl(not_l); } + TRACE("sat", tout << "level " << not_l << " is " << level << " " << js << "\n"); switch (js.get_kind()) { case justification::NONE: level = std::max(level, js.level()); - return level; + break; case justification::BINARY: level = update_max_level(js.get_literal(), level, unique_max); - return level; + break; case justification::CLAUSE: for (literal l : get_clause(js)) level = update_max_level(l, level, unique_max); - return level; + break; case justification::EXT_JUSTIFICATION: if (not_l != null_literal) not_l.neg(); fill_ext_antecedents(not_l, js, true); for (literal l : m_ext_antecedents) level = update_max_level(l, level, unique_max); - return level; + break; default: UNREACHABLE(); - return 0; + break; } + TRACE("sat", tout << "max-level " << level << " " << unique_max << "\n"); + return level; } /** @@ -3461,7 +3464,7 @@ namespace sat { } } - // can't eliminat FUIP + // can't eliminate FUIP SASSERT(is_marked_lit(m_lemma[0])); unsigned j = 0; @@ -3488,6 +3491,7 @@ namespace sat { // // ----------------------- void solver::push() { + SASSERT(!m_ext || !m_ext->can_propagate()); SASSERT(!inconsistent()); TRACE("sat_verbose", tout << "q:" << m_qhead << " trail: " << m_trail.size() << "\n";); SASSERT(m_qhead == m_trail.size()); diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 3a437855e9d..0361fc157b9 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -164,7 +164,7 @@ namespace sat { unsigned m_rephase_inc; backoff m_rephase; backoff m_reorder; - var_queue m_case_split_queue; + var_queue m_case_split_queue; unsigned m_qhead; unsigned m_scope_lvl; unsigned m_search_lvl; diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 8bf665ebf02..4574d3da384 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -115,9 +115,6 @@ class inc_sat_solver : public solver { } solver* translate(ast_manager& dst_m, params_ref const& p) override { - if (m_num_scopes > 0) { - throw default_exception("Cannot translate sat solver at non-base level"); - } ast_translation tr(m, dst_m); m_solver.pop_to_base_level(); inc_sat_solver* result = alloc(inc_sat_solver, dst_m, p, is_incremental()); diff --git a/src/sat/sat_solver/sat_smt_solver.cpp b/src/sat/sat_solver/sat_smt_solver.cpp index a5dd9b415f1..ab0e71cc311 100644 --- a/src/sat/sat_solver/sat_smt_solver.cpp +++ b/src/sat/sat_solver/sat_smt_solver.cpp @@ -47,7 +47,7 @@ class sat_smt_solver : public solver { ast_manager& m; trail_stack& m_trail; expr_ref_vector m_refs; - obj_map m_dep2orig; // map original dependency to uninterpeted literal + obj_map m_dep2orig; // map original dependency to uninterpreted literal u_map m_lit2dep; // map from literal assumption to original expression obj_map m_dep2lit; // map uninterpreted literal to sat literal @@ -489,7 +489,6 @@ class sat_smt_solver : public solver { model_converter_ref get_model_converter() const override { const_cast(this)->convert_internalized(); - verbose_stream() << "get model converter " << (m_cached_mc.get() != nullptr) << "\n"; if (m_cached_mc) return m_cached_mc; if (is_internalized() && m_internalized_converted) { diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 4a899ca9d8e..7747b65cb90 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -5,6 +5,7 @@ z3_add_component(sat_smt arith_internalize.cpp arith_sls.cpp arith_solver.cpp + arith_value.cpp array_axioms.cpp array_diagnostics.cpp array_internalize.cpp @@ -27,6 +28,7 @@ z3_add_component(sat_smt euf_proof_checker.cpp euf_relevancy.cpp euf_solver.cpp + intblast_solver.cpp fpa_solver.cpp pb_card.cpp pb_constraint.cpp @@ -44,6 +46,8 @@ z3_add_component(sat_smt q_solver.cpp recfun_solver.cpp sat_th.cpp + sls_solver.cpp + specrel_solver.cpp tseitin_theory_checker.cpp user_solver.cpp COMPONENT_DEPENDENCIES diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 93917042e50..b2dd5e96908 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -205,6 +205,169 @@ namespace arith { add_clause(dgez, neg); } + bool solver::check_bv_term(app* n) { + unsigned sz; + expr* _x, * _y; + if (!ctx.is_relevant(expr2enode(n))) + return true; + expr_ref vx(m), vy(m),vn(m); + rational valn, valx, valy; + bool is_int; + VERIFY(a.is_band(n, sz, _x, _y) || a.is_shl(n, sz, _x, _y) || a.is_ashr(n, sz, _x, _y) || a.is_lshr(n, sz, _x, _y)); + if (!get_value(expr2enode(_x), vx) || !get_value(expr2enode(_y), vy) || !get_value(expr2enode(n), vn)) { + IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n"); + found_unsupported(n); + return true; + } + if (!a.is_numeral(vn, valn, is_int) || !is_int || !a.is_numeral(vx, valx, is_int) || !is_int || !a.is_numeral(vy, valy, is_int) || !is_int) { + IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n"); + found_unsupported(n); + return true; + } + rational N = rational::power_of_two(sz); + valx = mod(valx, N); + valy = mod(valy, N); + expr_ref x(a.mk_mod(_x, a.mk_int(N)), m); + expr_ref y(a.mk_mod(_y, a.mk_int(N)), m); + SASSERT(0 <= valn && valn < N); + + // x mod 2^{i + 1} >= 2^i means the i'th bit is 1. + auto bitof = [&](expr* x, unsigned i) { + expr_ref r(m); + r = a.mk_ge(a.mk_mod(x, a.mk_int(rational::power_of_two(i+1))), a.mk_int(rational::power_of_two(i))); + return mk_literal(r); + }; + + if (a.is_band(n)) { + IF_VERBOSE(2, verbose_stream() << "band: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << "&" << valy << "\n"); + for (unsigned i = 0; i < sz; ++i) { + bool xb = valx.get_bit(i); + bool yb = valy.get_bit(i); + bool nb = valn.get_bit(i); + if (xb && yb && !nb) + add_clause(~bitof(x, i), ~bitof(y, i), bitof(n, i)); + else if (nb && !xb) + add_clause(~bitof(n, i), bitof(x, i)); + else if (nb && !yb) + add_clause(~bitof(n, i), bitof(y, i)); + else + continue; + return false; + } + } + if (a.is_shl(n)) { + SASSERT(valy >= 0); + if (valy >= sz || valy == 0) + return true; + unsigned k = valy.get_unsigned(); + sat::literal eq = eq_internalize(n, a.mk_mod(a.mk_mul(_x, a.mk_int(rational::power_of_two(k))), a.mk_int(N))); + if (s().value(eq) == l_true) + return true; + add_clause(~eq_internalize(y, a.mk_int(k)), eq); + IF_VERBOSE(2, verbose_stream() << "shl: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << " << " << valy << "\n"); + return false; + } + if (a.is_lshr(n)) { + SASSERT(valy >= 0); + if (valy >= sz || valy == 0) + return true; + unsigned k = valy.get_unsigned(); + sat::literal eq = eq_internalize(n, a.mk_idiv(x, a.mk_int(rational::power_of_two(k)))); + if (s().value(eq) == l_true) + return true; + add_clause(~eq_internalize(y, a.mk_int(k)), eq); + IF_VERBOSE(2, verbose_stream() << "lshr: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << " >>l " << valy << "\n"); + return false; + } + if (a.is_ashr(n)) { + SASSERT(valy >= 0); + if (valy >= sz || valy == 0) + return true; + unsigned k = valy.get_unsigned(); + sat::literal signx = mk_literal(a.mk_ge(x, a.mk_int(N/2))); + sat::literal eq; + expr* xdiv2k; + switch (s().value(signx)) { + case l_true: + // x < 0 & y = k -> n = (x div 2^k - 2^{N-k}) mod 2^N + xdiv2k = a.mk_idiv(x, a.mk_int(rational::power_of_two(k))); + eq = eq_internalize(n, a.mk_mod(a.mk_add(xdiv2k, a.mk_int(-rational::power_of_two(sz - k))), a.mk_int(N))); + if (s().value(eq) == l_true) + return true; + break; + case l_false: + // x >= 0 & y = k -> n = x div 2^k + xdiv2k = a.mk_idiv(x, a.mk_int(rational::power_of_two(k))); + eq = eq_internalize(n, xdiv2k); + if (s().value(eq) == l_true) + return true; + break; + case l_undef: + ctx.mark_relevant(signx); + return false; + } + add_clause(~eq_internalize(y, a.mk_int(k)), ~signx, eq); + return false; + } + return true; + } + + bool solver::check_bv_terms() { + for (app* n : m_bv_terms) { + if (!check_bv_term(n)) { + ++m_stats.m_bv_axioms; + return false; + } + } + return true; + } + + + void solver::mk_bv_axiom(app* n) { + unsigned sz; + expr* _x, * _y; + VERIFY(a.is_band(n, sz, _x, _y) || a.is_shl(n, sz, _x, _y) || a.is_ashr(n, sz, _x, _y) || a.is_lshr(n, sz, _x, _y)); + rational N = rational::power_of_two(sz); + expr_ref x(a.mk_mod(_x, a.mk_int(N)), m); + expr_ref y(a.mk_mod(_y, a.mk_int(N)), m); + + if (a.is_band(n)) { + + // 0 <= x&y < 2^sz + // x&y <= x + // x&y <= y + // TODO? x = y => x&y = x + + add_clause(mk_literal(a.mk_ge(n, a.mk_int(0)))); + add_clause(mk_literal(a.mk_le(n, a.mk_int(N - 1)))); + add_clause(mk_literal(a.mk_le(n, x))); + add_clause(mk_literal(a.mk_le(n, y))); + } + else if (a.is_shl(n)) { + // y >= sz => n = 0 + // y = 0 => n = x + add_clause(~mk_literal(a.mk_ge(y, a.mk_int(sz))), mk_literal(m.mk_eq(n, a.mk_int(0)))); + add_clause(~mk_literal(a.mk_eq(y, a.mk_int(0))), mk_literal(m.mk_eq(n, x))); + } + else if (a.is_lshr(n)) { + // y >= sz => n = 0 + // y = 0 => n = x + add_clause(~mk_literal(a.mk_ge(y, a.mk_int(sz))), mk_literal(m.mk_eq(n, a.mk_int(0)))); + add_clause(~mk_literal(a.mk_eq(y, a.mk_int(0))), mk_literal(m.mk_eq(n, x))); + } + else if (a.is_ashr(n)) { + // y >= sz & x < 2^{sz-1} => n = 0 + // y >= sz & x >= 2^{sz-1} => n = -1 + // y = 0 => n = x + auto signx = mk_literal(a.mk_ge(x, a.mk_int(N/2))); + add_clause(~mk_literal(a.mk_ge(a.mk_mod(y, a.mk_int(N)), a.mk_int(sz))), signx, mk_literal(m.mk_eq(n, a.mk_int(0)))); + add_clause(~mk_literal(a.mk_ge(a.mk_mod(y, a.mk_int(N)), a.mk_int(sz))), ~signx, mk_literal(m.mk_eq(n, a.mk_int(N-1)))); + add_clause(~mk_literal(a.mk_eq(a.mk_mod(y, a.mk_int(N)), a.mk_int(0))), mk_literal(m.mk_eq(n, x))); + } + else + UNREACHABLE(); + } + void solver::mk_bound_axioms(api_bound& b) { theory_var v = b.get_var(); lp_api::bound_kind kind1 = b.get_bound_kind(); @@ -386,12 +549,12 @@ namespace arith { ctx.push(push_back_vector>>(m_delayed_eqs)); } - void solver::mk_diseq_axiom(euf::th_eq const& e) { - if (is_bool(e.v1())) + void solver::mk_diseq_axiom(theory_var v1, theory_var v2) { + if (is_bool(v1)) return; force_push(); - expr* e1 = var2expr(e.v1()); - expr* e2 = var2expr(e.v2()); + expr* e1 = var2expr(v1); + expr* e2 = var2expr(v2); if (e1->get_id() > e2->get_id()) std::swap(e1, e2); if (m.are_distinct(e1, e2)) @@ -524,7 +687,7 @@ namespace arith { return all_divs_valid; } - void solver::fixed_var_eh(theory_var v, lp::constraint_index ci1, lp::constraint_index ci2, rational const& bound) { + void solver::fixed_var_eh(theory_var v, u_dependency* dep, rational const& bound) { theory_var w = euf::null_theory_var; enode* x = var2enode(v); if (bound.is_zero()) @@ -539,10 +702,12 @@ namespace arith { if (x->get_root() == y->get_root()) return; reset_evidence(); - set_evidence(ci1); - set_evidence(ci2); + m_explanation.clear(); + for (auto ci : lp().flatten(dep)) + consume(rational::one(), ci); ++m_stats.m_fixed_eqs; - auto* jst = euf::th_explain::propagate(*this, m_core, m_eqs, x, y); + auto* hint = explain_implied_eq(m_explanation, x, y); + auto* jst = euf::th_explain::propagate(*this, m_core, m_eqs, x, y, hint); ctx.propagate(x, y, jst->to_index()); } diff --git a/src/sat/smt/arith_diagnostics.cpp b/src/sat/smt/arith_diagnostics.cpp index a3e48256ded..c408fbf96bc 100644 --- a/src/sat/smt/arith_diagnostics.cpp +++ b/src/sat/smt/arith_diagnostics.cpp @@ -15,6 +15,8 @@ Module Name: --*/ +#include "util/cancel_eh.h" +#include "util/scoped_timer.h" #include "ast/ast_util.h" #include "ast/scoped_proof.h" #include "sat/smt/euf_solver.h" @@ -32,7 +34,7 @@ namespace arith { } arith_proof_hint* arith_proof_hint_builder::mk(euf::solver& s) { - return new (s.get_region()) arith_proof_hint(m_ty, m_num_le, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); + return new (s.get_region()) arith_proof_hint(m_ty, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); } std::ostream& solver::display(std::ostream& out) const { @@ -43,8 +45,7 @@ namespace arith { } unsigned nv = get_num_vars(); for (unsigned v = 0; v < nv; ++v) { - auto t = get_tv(v); - auto vi = lp().external_to_column_index(v); + auto vi = lp().external_to_local(v); out << "v" << v << " "; if (is_bool(v)) { euf::enode* n = var2enode(v); @@ -55,10 +56,10 @@ namespace arith { } } else { - if (t.is_null()) + if (vi == lp::null_lpvar) out << "null"; else - out << (t.is_term() ? "t" : "j") << vi; + out << (lp().column_has_term(vi) ? "t" : "j") << vi; if (m_nla && m_nla->use_nra_model() && is_registered_var(v)) { scoped_anum an(m_nla->am()); m_nla->am().display(out << " = ", nl_value(v, an)); @@ -90,7 +91,6 @@ namespace arith { void solver::collect_statistics(statistics& st) const { m_stats.collect_statistics(st); lp().settings().stats().collect_statistics(st); - if (m_nla) m_nla->collect_statistics(st); } void solver::explain_assumptions(lp::explanation const& e) { @@ -133,12 +133,25 @@ namespace arith { return m_arith_hint.mk(ctx); } - arith_proof_hint const* solver::explain_conflict(sat::literal_vector const& core, euf::enode_pair_vector const& eqs) { + arith_proof_hint const* solver::explain_conflict(hint_type ty, sat::literal_vector const& core, euf::enode_pair_vector const& eqs) { arith_proof_hint* hint = nullptr; if (ctx.use_drat()) { - m_arith_hint.set_type(ctx, hint_type::farkas_h); - for (auto lit : core) - m_arith_hint.add_lit(rational::one(), lit); + m_coeffs.reset(); + for (auto const& e : m_explanation) { + if (inequality_source == m_constraint_sources[e.ci()]) + m_coeffs.push_back(e.coeff()); + } + + m_arith_hint.set_type(ctx, ty); + if (m_coeffs.size() == core.size()) { + unsigned i = 0; + for (auto lit : core) + m_arith_hint.add_lit(m_coeffs[i], lit), ++i; + } + else { + for (auto lit : core) + m_arith_hint.add_lit(rational::one(), lit); + } for (auto const& [a,b] : eqs) m_arith_hint.add_eq(a, b); hint = m_arith_hint.mk(ctx); @@ -151,7 +164,6 @@ namespace arith { return nullptr; m_arith_hint.set_type(ctx, hint_type::implied_eq_h); explain_assumptions(e); - m_arith_hint.set_num_le(1); // TODO m_arith_hint.add_diseq(a, b); return m_arith_hint.mk(ctx); } @@ -160,13 +172,19 @@ namespace arith { if (!ctx.use_drat()) return nullptr; m_arith_hint.set_type(ctx, hint_type::implied_eq_h); - m_arith_hint.set_num_le(1); m_arith_hint.add_lit(rational(1), le); m_arith_hint.add_lit(rational(1), ge); m_arith_hint.add_lit(rational(1), ~eq); return m_arith_hint.mk(ctx); } + /** + * The expected format is: + * 1. all equalities + * 2. all inequalities + * 3. optional disequalities (used for the steps that propagate equalities) + */ + expr* arith_proof_hint::get_hint(euf::solver& s) const { ast_manager& m = s.get_manager(); family_id fid = m.get_family_id("arith"); @@ -179,30 +197,67 @@ namespace arith { case hint_type::farkas_h: name = "farkas"; break; + case hint_type::cut_h: + name = "cut"; + break; case hint_type::bound_h: name = "bound"; break; case hint_type::implied_eq_h: name = "implied-eq"; - args.push_back(arith.mk_int(m_num_le)); + break; + case hint_type::nla_h: + name = "nla"; + break; + default: + name = "unknown-arithmetic"; break; } + + auto push_eq = [&](bool is_eq, enode* x, enode* y) { + if (x->get_id() > y->get_id()) + std::swap(x, y); + expr_ref eq(m.mk_eq(x->get_expr(), y->get_expr()), m); + if (!is_eq) eq = m.mk_not(eq); + args.push_back(arith.mk_int(1)); + args.push_back(eq); + }; rational lc(1); for (unsigned i = m_lit_head; i < m_lit_tail; ++i) lc = lcm(lc, denominator(a.m_arith_hint.lit(i).first)); - + for (unsigned i = m_eq_head; i < m_eq_tail; ++i) { + auto [x, y, is_eq] = a.m_arith_hint.eq(i); + if (is_eq) + push_eq(is_eq, x, y); + } for (unsigned i = m_lit_head; i < m_lit_tail; ++i) { auto const& [coeff, lit] = a.m_arith_hint.lit(i); args.push_back(arith.mk_int(abs(coeff*lc))); args.push_back(s.literal2expr(lit)); } for (unsigned i = m_eq_head; i < m_eq_tail; ++i) { - auto const& [x, y, is_eq] = a.m_arith_hint.eq(i); - expr_ref eq(m.mk_eq(x->get_expr(), y->get_expr()), m); - if (!is_eq) eq = m.mk_not(eq); - args.push_back(arith.mk_int(1)); - args.push_back(eq); + auto [x, y, is_eq] = a.m_arith_hint.eq(i); + if (!is_eq) + push_eq(is_eq, x, y); } + return m.mk_app(symbol(name), args.size(), args.data(), m.mk_proof_sort()); } + + bool solver::validate_conflict() { + scoped_ptr<::solver> vs = mk_smt2_solver(m, ctx.s().params(), symbol::null); + for (auto lit : m_core) + vs->assert_expr(ctx.literal2expr(lit)); + + for (auto [a, b] : m_eqs) + vs->assert_expr(m.mk_eq(a->get_expr(), b->get_expr())); + + cancel_eh eh(m.limit()); + scoped_timer timer(1000, &eh); + bool result = l_true != vs->check_sat(); + CTRACE("arith", !result, vs->display(tout)); + CTRACE("arith", !result, s().display(tout)); + SASSERT(result); + return result; + } } diff --git a/src/sat/smt/arith_internalize.cpp b/src/sat/smt/arith_internalize.cpp index 60ca9651a19..a389d13b889 100644 --- a/src/sat/smt/arith_internalize.cpp +++ b/src/sat/smt/arith_internalize.cpp @@ -61,29 +61,11 @@ namespace arith { void solver::ensure_nla() { if (!m_nla) { - m_nla = alloc(nla::solver, *m_solver.get(), m.limit()); + m_nla = alloc(nla::solver, *m_solver.get(), s().params(), m.limit()); for (auto const& _s : m_scopes) { (void)_s; m_nla->push(); } - smt_params_helper prms(s().params()); - m_nla->settings().run_order = prms.arith_nl_order(); - m_nla->settings().run_tangents = prms.arith_nl_tangents(); - m_nla->settings().run_horner = prms.arith_nl_horner(); - m_nla->settings().horner_subs_fixed = prms.arith_nl_horner_subs_fixed(); - m_nla->settings().horner_frequency = prms.arith_nl_horner_frequency(); - m_nla->settings().horner_row_length_limit = prms.arith_nl_horner_row_length_limit(); - m_nla->settings().run_grobner = prms.arith_nl_grobner(); - m_nla->settings().run_nra = prms.arith_nl_nra(); - m_nla->settings().grobner_subs_fixed = prms.arith_nl_grobner_subs_fixed(); - m_nla->settings().grobner_eqs_growth = prms.arith_nl_grobner_eqs_growth(); - m_nla->settings().grobner_expr_size_growth = prms.arith_nl_grobner_expr_size_growth(); - m_nla->settings().grobner_expr_degree_growth = prms.arith_nl_grobner_expr_degree_growth(); - m_nla->settings().grobner_max_simplified = prms.arith_nl_grobner_max_simplified(); - m_nla->settings().grobner_number_of_conflicts_to_report = prms.arith_nl_grobner_cnfl_to_report(); - m_nla->settings().grobner_quota = prms.arith_nl_gr_q(); - m_nla->settings().grobner_frequency = prms.arith_nl_grobner_frequency(); - m_nla->settings().expensive_patching = false; } } @@ -160,7 +142,6 @@ namespace arith { expr_ref_vector& terms = st.terms(); svector& vars = st.vars(); vector& coeffs = st.coeffs(); - rational& offset = st.offset(); rational r; expr* n1, * n2; unsigned index = 0; @@ -204,7 +185,9 @@ namespace arith { ++index; } else if (a.is_numeral(n, r)) { - offset += coeffs[index] * r; + theory_var v = internalize_numeral(to_app(n), r); + coeffs[vars.size()] = coeffs[index]; + vars.push_back(v); ++index; } else if (a.is_uminus(n, n1)) { @@ -269,6 +252,12 @@ namespace arith { st.to_ensure_var().push_back(n1); st.to_ensure_var().push_back(n2); } + else if (a.is_band(n) || a.is_shl(n) || a.is_ashr(n) || a.is_lshr(n)) { + m_bv_terms.push_back(to_app(n)); + ctx.push(push_back_vector(m_bv_terms)); + mk_bv_axiom(to_app(n)); + ensure_arg_vars(to_app(n)); + } else if (!a.is_div0(n) && !a.is_mod0(n) && !a.is_idiv0(n) && !a.is_rem0(n) && !a.is_power0(n)) { found_unsupported(n); ensure_arg_vars(to_app(n)); @@ -308,6 +297,13 @@ namespace arith { internalize_term(n->get_arg(1)->get_expr()); } + expr* solver::mk_sub(expr* x, expr* y) { + rational r; + if (a.is_numeral(y, r) && r == 0) + return x; + return a.mk_sub(x, y); + } + bool solver::internalize_atom(expr* atom) { TRACE("arith", tout << mk_pp(atom, m) << "\n";); expr* n1, *n2; @@ -336,26 +332,26 @@ namespace arith { k = lp_api::upper_t; } else if (a.is_le(atom, n1, n2)) { - expr_ref n3(a.mk_sub(n1, n2), m); + expr_ref n3(mk_sub(n1, n2), m); v = internalize_def(n3); k = lp_api::upper_t; r = 0; } else if (a.is_ge(atom, n1, n2)) { - expr_ref n3(a.mk_sub(n1, n2), m); + expr_ref n3(mk_sub(n1, n2), m); v = internalize_def(n3); k = lp_api::lower_t; r = 0; } else if (a.is_lt(atom, n1, n2)) { - expr_ref n3(a.mk_sub(n1, n2), m); + expr_ref n3(mk_sub(n1, n2), m); v = internalize_def(n3); k = lp_api::lower_t; r = 0; lit.neg(); } - else if (a.is_gt(atom, n1, n2)) { - expr_ref n3(a.mk_sub(n1, n2), m); + else if (a.is_gt(atom, n1, n2)) { + expr_ref n3(mk_sub(n1, n2), m); v = internalize_def(n3); k = lp_api::upper_t; r = 0; @@ -457,13 +453,26 @@ namespace arith { return v; } + theory_var solver::internalize_numeral(app* n, rational const& val) { + theory_var v = mk_evar(n); + lpvar vi = get_lpvar(v); + if (vi == UINT_MAX) { + vi = lp().add_var(v, a.is_int(n)); + add_def_constraint_and_equality(vi, lp::GE, val); + add_def_constraint_and_equality(vi, lp::LE, val); + register_fixed_var(v, val); + } + return v; + } + + theory_var solver::internalize_mul(app* t) { SASSERT(a.is_mul(t)); internalize_args(t, true); bool _has_var = has_var(t); mk_enode(t); theory_var v = mk_evar(t); - + if (!_has_var) { svector vars; for (expr* n : *t) { @@ -484,57 +493,32 @@ namespace arith { theory_var v = mk_evar(term); TRACE("arith", tout << mk_bounded_pp(term, m) << " v" << v << "\n";); - if (is_unit_var(st) && v == st.vars()[0]) { + if (is_unit_var(st) && v == st.vars()[0]) return st.vars()[0]; - } - else if (is_one(st) && a.is_numeral(term)) { - return lp().local_to_external(get_one(a.is_int(term))); - } - else if (is_zero(st) && a.is_numeral(term)) { - return lp().local_to_external(get_zero(a.is_int(term))); - } - else { - init_left_side(st); - lpvar vi = get_lpvar(v); - if (vi == UINT_MAX) { - if (m_left_side.empty()) { - vi = lp().add_var(v, a.is_int(term)); - add_def_constraint_and_equality(vi, lp::GE, st.offset()); - add_def_constraint_and_equality(vi, lp::LE, st.offset()); - register_fixed_var(v, st.offset()); - return v; - } - if (!st.offset().is_zero()) { - m_left_side.push_back(std::make_pair(st.offset(), get_one(a.is_int(term)))); - } - if (m_left_side.empty()) { - vi = lp().add_var(v, a.is_int(term)); - add_def_constraint_and_equality(vi, lp::GE, rational(0)); - add_def_constraint_and_equality(vi, lp::LE, rational(0)); - } - else { - vi = lp().add_term(m_left_side, v); - SASSERT(lp::tv::is_term(vi)); - TRACE("arith_verbose", - tout << "v" << v << " := " << mk_pp(term, m) - << " slack: " << vi << " scopes: " << m_scopes.size() << "\n"; - lp().print_term(lp().get_term(lp::tv::raw(vi)), tout) << "\n";); - } + + init_left_side(st); + lpvar vi = get_lpvar(v); + + if (vi == UINT_MAX) { + if (m_left_side.empty()) { + vi = lp().add_var(v, a.is_int(term)); + add_def_constraint_and_equality(vi, lp::GE, rational(0)); + add_def_constraint_and_equality(vi, lp::LE, rational(0)); + } + else { + vi = lp().add_term(m_left_side, v); + SASSERT(lp().column_has_term(vi)); + TRACE("arith_verbose", + tout << "v" << v << " := " << mk_pp(term, m) + << " slack: " << vi << " scopes: " << m_scopes.size() << "\n"; + lp().print_term(lp().get_term(vi), tout) << "\n";); } - return v; } + return v; } bool solver::is_unit_var(scoped_internalize_state& st) { - return st.offset().is_zero() && st.vars().size() == 1 && st.coeffs()[0].is_one(); - } - - bool solver::is_one(scoped_internalize_state& st) { - return st.offset().is_one() && st.vars().empty(); - } - - bool solver::is_zero(scoped_internalize_state& st) { - return st.offset().is_zero() && st.vars().empty(); + return st.vars().size() == 1 && st.coeffs()[0].is_one(); } void solver::init_left_side(scoped_internalize_state& st) { @@ -557,8 +541,6 @@ namespace arith { rational const& r = m_columns[var]; if (!r.is_zero()) { auto vi = register_theory_var_in_lar_solver(var); - if (lp::tv::is_term(vi)) - vi = lp().map_term_index_to_column_index(vi); m_left_side.push_back(std::make_pair(r, vi)); m_columns[var].reset(); } @@ -641,9 +623,6 @@ namespace arith { return lp().external_to_local(v); } - lp::tv solver::get_tv(theory_var v) const { - return lp::tv::raw(get_lpvar(v)); - } /** \brief We must redefine this method, because theory of arithmetic contains diff --git a/src/sat/smt/arith_sls.cpp b/src/sat/smt/arith_sls.cpp index 4fe153289de..2168299803b 100644 --- a/src/sat/smt/arith_sls.cpp +++ b/src/sat/smt/arith_sls.cpp @@ -59,18 +59,10 @@ namespace arith { int64_t val = 0; lp::lar_term const& term = s.lp().get_term(t); for (lp::lar_term::ival const& arg : term) { - auto t2 = s.lp().column2tv(arg.column()); - auto w = s.lp().local_to_external(t2.id()); + auto t2 = arg.j(); + auto w = s.lp().local_to_external(t2); val += to_numeral(arg.coeff()) * m_vars[w].m_best_value; } - if (v == 52) { - verbose_stream() << "update v" << v << " := " << val << "\n"; - for (lp::lar_term::ival const& arg : term) { - auto t2 = s.lp().column2tv(arg.column()); - auto w = s.lp().local_to_external(t2.id()); - verbose_stream() << "v" << w << " := " << m_vars[w].m_best_value << " * " << to_numeral(arg.coeff()) << "\n"; - } - } m_vars[v].m_best_value = val; } @@ -81,12 +73,12 @@ namespace arith { continue; int64_t new_value = m_vars[v].m_best_value; s.ensure_column(v); - lp::column_index vj = s.lp().to_column_index(v); - SASSERT(!vj.is_null()); - if (!s.lp().is_base(vj.index())) { + lp::lpvar vj = s.lp().external_to_local(v); + SASSERT(vj != lp::null_lpvar); + if (!s.lp().is_base(vj)) { rational new_value_(new_value, rational::i64()); lp::impq val(new_value_, rational::zero()); - s.lp().set_value_for_nbasic_column(vj.index(), val); + s.lp().set_value_for_nbasic_column(vj, val); } } @@ -460,18 +452,18 @@ namespace arith { return 0; } - void sls::add_args(sat::bool_var bv, ineq& ineq, lp::tv t, theory_var v, int64_t sign) { - if (t.is_term()) { + void sls::add_args(sat::bool_var bv, ineq& ineq, lp::lpvar t, theory_var v, int64_t sign) { + if (s.lp().column_has_term(t)) { lp::lar_term const& term = s.lp().get_term(t); m_terms.push_back({t,v}); for (lp::lar_term::ival arg : term) { - auto t2 = s.lp().column2tv(arg.column()); - auto w = s.lp().local_to_external(t2.id()); + auto t2 = arg.j(); + auto w = s.lp().local_to_external(t2); add_arg(bv, ineq, sign * to_numeral(arg.coeff()), w); } } else - add_arg(bv, ineq, sign, s.lp().local_to_external(t.id())); + add_arg(bv, ineq, sign, s.lp().local_to_external(t)); } void sls::init_bool_var(sat::bool_var bv) { @@ -480,7 +472,7 @@ namespace arith { api_bound* b = nullptr; s.m_bool_var2bound.find(bv, b); if (b) { - auto t = b->tv(); + auto t = b->column_index(); rational bound = b->get_value(); bool should_minus = false; sls::ineq_kind op; @@ -503,8 +495,8 @@ namespace arith { if (e && m.is_eq(e, l, r) && s.a.is_int_real(l)) { theory_var u = s.get_th_var(l); theory_var v = s.get_th_var(r); - lp::tv tu = s.get_tv(u); - lp::tv tv = s.get_tv(v); + lp::lpvar tu = s.get_column(u); + lp::lpvar tv = s.get_column(v); auto& ineq = new_ineq(sls::ineq_kind::EQ, 0); add_args(bv, ineq, tu, u, 1); add_args(bv, ineq, tv, v, -1); diff --git a/src/sat/smt/arith_sls.h b/src/sat/smt/arith_sls.h index 09a56c84efd..55d39b25205 100644 --- a/src/sat/smt/arith_sls.h +++ b/src/sat/smt/arith_sls.h @@ -105,7 +105,7 @@ namespace arith { config m_config; scoped_ptr_vector m_bool_vars; vector m_vars; - svector> m_terms; + svector> m_terms; bool m_dscore_mode = false; @@ -140,7 +140,7 @@ namespace arith { void add_vars(); sls::ineq& new_ineq(ineq_kind op, int64_t const& bound); void add_arg(sat::bool_var bv, ineq& ineq, int64_t const& c, var_t v); - void add_args(sat::bool_var bv, ineq& ineq, lp::tv t, euf::theory_var v, int64_t sign); + void add_args(sat::bool_var bv, ineq& ineq, lp::lpvar j, euf::theory_var v, int64_t sign); void init_bool_var(sat::bool_var v); void init_bool_var_assignment(sat::bool_var v); diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index bd5dd315fe7..3086d75f43f 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -26,7 +26,7 @@ namespace arith { m_model_eqs(DEFAULT_HASHTABLE_INITIAL_CAPACITY, var_value_hash(*this), var_value_eq(*this)), m_local_search(*this), m_resource_limit(*this), - m_bp(*this), + m_bp(*this, m_implied_bounds), a(m), m_bound_terms(m), m_bound_predicate(m) @@ -53,21 +53,25 @@ namespace arith { euf::th_solver* solver::clone(euf::solver& dst_ctx) { arith::solver* result = alloc(arith::solver, dst_ctx, get_id()); + unsigned_vector var2var; + for (unsigned i = 0; i < result->get_num_vars(); ++i) + var2var.push_back(i); + for (unsigned i = result->get_num_vars(); i < get_num_vars(); ++i) - result->mk_evar(ctx.copy(dst_ctx, var2enode(i))->get_expr()); - - unsigned v = 0; - result->m_bounds.resize(m_bounds.size()); - for (auto const& bounds : m_bounds) { - for (auto* b : bounds) { - auto* b2 = result->mk_var_bound(b->get_lit(), v, b->get_bound_kind(), b->get_value()); - result->m_bounds[v].push_back(b2); - result->m_bounds_trail.push_back(v); - result->updt_unassigned_bounds(v, +1); + var2var.push_back(result->mk_evar(ctx.copy(dst_ctx, var2enode(i))->get_expr())); + + result->m_bounds.resize(get_num_vars()); + unsigned nv = std::min(m_bounds.size(), get_num_vars()); + for (unsigned v = 0; v < nv; ++v) { + auto w = var2var[v]; + for (auto* b : m_bounds[v]) { + auto* b2 = result->mk_var_bound(b->get_lit(), w, b->get_bound_kind(), b->get_value()); + result->m_bounds[w].push_back(b2); + result->m_bounds_trail.push_back(w); + result->updt_unassigned_bounds(w, +1); result->m_bool_var2bound.insert(b->get_lit().var(), b2); result->m_new_bounds.push_back(b2); } - ++v; } // clone rows into m_solver, m_nla, m_lia @@ -83,6 +87,7 @@ namespace arith { m_new_eq = false; flush_bound_axioms(); + propagate_nla(); unsigned qhead = m_asserted_qhead; while (m_asserted_qhead < m_asserted.size() && !s().inconsistent() && m.inc()) { @@ -301,11 +306,31 @@ namespace arith { m_explanation.add_pair(j, v); } + void solver::add_equality(lpvar j, rational const& k, lp::explanation const& exp) { + TRACE("arith", tout << "equality " << j << " " << k << "\n"); + theory_var v; + if (k == 1) + v = m_one_var; + else if (k == 0) + v = m_zero_var; + else if (!m_value2var.find(k, v)) + return; + theory_var w = lp().local_to_external(j); + if (w < 0) + return; + lpvar i = register_theory_var_in_lar_solver(v); + add_eq(i, j, exp, true); + } + bool solver::add_eq(lpvar u, lpvar v, lp::explanation const& e, bool is_fixed) { if (s().inconsistent()) return false; theory_var uv = lp().local_to_external(u); // variables that are returned should have external representations theory_var vv = lp().local_to_external(v); // so maybe better to have them already transformed to external form + if (uv == euf::null_theory_var) + return false; + if (vv == euf::null_theory_var) + return false; if (is_equal(uv, vv)) return false; enode* n1 = var2enode(uv); @@ -345,7 +370,7 @@ namespace arith { void solver::refine_bound(theory_var v, const lp::implied_bound& be) { lpvar vi = be.m_j; - if (lp::tv::is_term(vi)) + if (lp().column_has_term(vi)) return; expr_ref w(var2expr(v), m); if (a.is_add(w) || a.is_numeral(w) || m.is_ite(w)) @@ -393,37 +418,37 @@ namespace arith { ++m_stats.m_assert_upper; inf_rational value = b.get_value(is_true); if (propagate_eqs() && value.is_rational()) - propagate_eqs(b.tv(), ci, k, b, value.get_rational()); + propagate_eqs(b.column_index(), ci, k, b, value.get_rational()); #if 0 if (propagation_mode() != BP_NONE) - lp().mark_rows_for_bound_prop(b.tv().id()); + lp().add_column_rows_to_touched_rows(b.tv().id()); #endif } - void solver::propagate_eqs(lp::tv t, lp::constraint_index ci1, lp::lconstraint_kind k, api_bound& b, rational const& value) { - lp::constraint_index ci2; - if (k == lp::GE && set_lower_bound(t, ci1, value) && has_upper_bound(t.index(), ci2, value)) { - fixed_var_eh(b.get_var(), ci1, ci2, value); + void solver::propagate_eqs(lp::lpvar t, lp::constraint_index ci1, lp::lconstraint_kind k, api_bound& b, rational const& value) { + u_dependency* dep; + auto& dm = lp().dep_manager(); + if (k == lp::GE && set_lower_bound(t, ci1, value) && has_upper_bound(t, dep, value)) { + fixed_var_eh(b.get_var(), dm.mk_join(dm.mk_leaf(ci1), dep), value); } - else if (k == lp::LE && set_upper_bound(t, ci1, value) && has_lower_bound(t.index(), ci2, value)) { - fixed_var_eh(b.get_var(), ci1, ci2, value); + else if (k == lp::LE && set_upper_bound(t, ci1, value) && has_lower_bound(t, dep, value)) { + fixed_var_eh(b.get_var(), dm.mk_join(dm.mk_leaf(ci1), dep), value); } } - bool solver::set_bound(lp::tv tv, lp::constraint_index ci, rational const& v, bool is_lower) { - if (tv.is_term()) { - lpvar ti = tv.id(); + bool solver::set_bound(lp::lpvar tv, lp::constraint_index ci, rational const& v, bool is_lower) { + if (lp().column_has_term(tv)) { auto& vec = is_lower ? m_lower_terms : m_upper_terms; - if (vec.size() <= ti) { - vec.resize(ti + 1, constraint_bound(UINT_MAX, rational())); + if (vec.size() <= tv) { + vec.resize(tv + 1, constraint_bound(UINT_MAX, rational())); } - constraint_bound& b = vec[ti]; + constraint_bound& b = vec[tv]; if (b.first == UINT_MAX || (is_lower ? b.second < v : b.second > v)) { - TRACE("arith", tout << "tighter bound " << tv.to_string() << "\n";); - m_history.push_back(vec[ti]); - ctx.push(history_trail(vec, ti, m_history)); + TRACE("arith", tout << "tighter bound " << tv << "\n";); + m_history.push_back(vec[tv]); + ctx.push(history_trail(vec, tv, m_history)); b.first = ci; b.second = v; } @@ -433,11 +458,12 @@ namespace arith { // m_solver already tracks bounds on proper variables, but not on terms. bool is_strict = false; rational b; + u_dependency* dep = nullptr; if (is_lower) { - return lp().has_lower_bound(tv.id(), ci, b, is_strict) && !is_strict && b == v; + return lp().has_lower_bound(tv, dep, b, is_strict) && !is_strict && b == v; } else { - return lp().has_upper_bound(tv.id(), ci, b, is_strict) && !is_strict && b == v; + return lp().has_upper_bound(tv, dep, b, is_strict) && !is_strict && b == v; } } } @@ -592,42 +618,53 @@ namespace arith { } } - void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + bool solver::get_value(euf::enode* n, expr_ref& value) { theory_var v = n->get_th_var(get_id()); expr* o = n->get_expr(); - expr_ref value(m); + if (m.is_value(n->get_root()->get_expr())) { value = n->get_root()->get_expr(); } else if (use_nra_model() && lp().external_to_local(v) != lp::null_lpvar) { - anum const& an = nl_value(v, *m_a1); + anum const& an = nl_value(v, m_nla->tmp1()); if (a.is_int(o) && !m_nla->am().is_int(an)) - value = a.mk_numeral(rational::zero(), a.is_int(o)); + value = a.mk_numeral(rational::zero(), a.is_int(o)); else - value = a.mk_numeral(m_nla->am(), nl_value(v, *m_a1), a.is_int(o)); + value = a.mk_numeral(m_nla->am(), nl_value(v, m_nla->tmp1()), a.is_int(o)); } else if (v != euf::null_theory_var) { rational r = get_value(v); TRACE("arith", tout << mk_pp(o, m) << " v" << v << " := " << r << "\n";); - SASSERT("integer variables should have integer values: " && (ctx.get_config().m_arith_ignore_int || !a.is_int(o) || r.is_int() || m.limit().is_canceled())); - if (a.is_int(o) && !r.is_int()) + SASSERT("integer variables should have integer values: " && (ctx.get_config().m_arith_ignore_int || !a.is_int(o) || r.is_int() || m_not_handled != nullptr || m.limit().is_canceled())); + if (a.is_int(o) && !r.is_int()) r = floor(r); value = a.mk_numeral(r, o->get_sort()); } + else + return false; + + return true; + } + + + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + expr_ref value(m); + expr* o = n->get_expr(); + if (get_value(n, value)) + ; else if (a.is_arith_expr(o) && reflect(o)) { expr_ref_vector args(m); for (auto* arg : *to_app(o)) { if (m.is_value(arg)) args.push_back(arg); - else + else args.push_back(values.get(ctx.get_enode(arg)->get_root_id())); } value = m.mk_app(to_app(o)->get_decl(), args.size(), args.data()); ctx.get_rewriter()(value); } - else { - value = mdl.get_fresh_value(o->get_sort()); - } + else + value = mdl.get_fresh_value(n->get_sort()); mdl.register_value(value); values.set(n->get_root_id(), value); } @@ -692,7 +729,7 @@ namespace arith { void solver::report_equality_of_fixed_vars(unsigned vi1, unsigned vi2) { rational bound; - lp::constraint_index ci1, ci2, ci3, ci4; + u_dependency* ci1 = nullptr, *ci2 = nullptr, *ci3 = nullptr, *ci4 = nullptr; theory_var v1 = lp().local_to_external(vi1); theory_var v2 = lp().local_to_external(vi2); TRACE("arith", tout << "fixed: " << mk_pp(var2expr(v1), m) << " " << mk_pp(var2expr(v2), m) << "\n";); @@ -713,10 +750,11 @@ namespace arith { ++m_stats.m_fixed_eqs; reset_evidence(); - set_evidence(ci1); - set_evidence(ci2); - set_evidence(ci3); - set_evidence(ci4); + m_explanation.clear(); + auto& dm = lp().dep_manager(); + auto* d = dm.mk_join(dm.mk_join(ci1, ci2), dm.mk_join(ci3, ci4)); + for (auto ci : lp().flatten(d)) + consume(rational::one(), ci); enode* x = var2enode(v1); enode* y = var2enode(v2); auto* ex = explain_implied_eq(m_explanation, x, y); @@ -728,26 +766,27 @@ namespace arith { return x == y || var2enode(x)->get_root() == var2enode(y)->get_root(); } - bool solver::has_upper_bound(lpvar vi, lp::constraint_index& ci, rational const& bound) { return has_bound(vi, ci, bound, false); } + bool solver::has_upper_bound(lpvar vi, u_dependency*& ci, rational const& bound) { return has_bound(vi, ci, bound, false); } - bool solver::has_lower_bound(lpvar vi, lp::constraint_index& ci, rational const& bound) { return has_bound(vi, ci, bound, true); } + bool solver::has_lower_bound(lpvar vi, u_dependency*& ci, rational const& bound) { return has_bound(vi, ci, bound, true); } - bool solver::has_bound(lpvar vi, lp::constraint_index& ci, rational const& bound, bool is_lower) { - if (lp::tv::is_term(vi)) { + bool solver::has_bound(lpvar vi, u_dependency*& dep, rational const& bound, bool is_lower) { + if (lp().column_has_term(vi)) { theory_var v = lp().local_to_external(vi); rational val; TRACE("arith", tout << lp().get_variable_name(vi) << " " << v << "\n";); if (v != euf::null_theory_var && a.is_numeral(var2expr(v), val) && bound == val) { - ci = UINT_MAX; + dep = nullptr; return bound == val; } auto& vec = is_lower ? m_lower_terms : m_upper_terms; - lpvar ti = lp::tv::unmask_term(vi); - if (vec.size() > ti) { - constraint_bound& b = vec[ti]; - ci = b.first; - return ci != UINT_MAX && bound == b.second; + if (vec.size() > vi) { + auto& [ci, coeff] = vec[vi]; + if (ci == UINT_MAX) + return false; + dep = lp().dep_manager().mk_leaf(ci); + return bound == coeff; } else { return false; @@ -757,10 +796,10 @@ namespace arith { bool is_strict = false; rational b; if (is_lower) { - return lp().has_lower_bound(vi, ci, b, is_strict) && b == bound && !is_strict; + return lp().has_lower_bound(vi, dep, b, is_strict) && b == bound && !is_strict; } else { - return lp().has_upper_bound(vi, ci, b, is_strict) && b == bound && !is_strict; + return lp().has_upper_bound(vi, dep, b, is_strict) && b == bound && !is_strict; } } } @@ -835,19 +874,23 @@ namespace arith { lp::impq solver::get_ivalue(theory_var v) const { SASSERT(is_registered_var(v)); - return m_solver->get_tv_ivalue(get_tv(v)); + return m_solver->get_column_value(get_column(v)); } + lp::lpvar solver::get_column(theory_var v) const { + SASSERT(is_registered_var(v)); + return m_solver->external_to_local(v); + } + rational solver::get_value(theory_var v) const { - return is_registered_var(v) ? m_solver->get_tv_value(get_tv(v)) : rational::zero(); + return is_registered_var(v) ? m_solver->get_value(get_column(v)) : rational::zero(); } void solver::random_update() { if (m_nla) return; TRACE("arith", tout << s().scope_lvl() << "\n"; tout.flush();); - m_tmp_var_set.clear(); - m_tmp_var_set.resize(get_num_vars()); + m_tmp_var_set.reset(); m_model_eqs.reset(); svector vars; theory_var sz = static_cast(get_num_vars()); @@ -855,18 +898,18 @@ namespace arith { if (is_bool(v)) continue; ensure_column(v); - lp::column_index vj = lp().to_column_index(v); - SASSERT(!vj.is_null()); + lp::lpvar vj = lp().external_to_local(v); + SASSERT(vj != lp::null_lpvar); theory_var other = m_model_eqs.insert_if_not_there(v); if (is_equal(v, other)) continue; - if (!lp().is_fixed(vj)) - vars.push_back(vj.index()); + if (!lp().column_is_fixed(vj)) + vars.push_back(vj); else if (!m_tmp_var_set.contains(other)) { - lp::column_index other_j = lp().to_column_index(other); - if (!lp().is_fixed(other_j)) { + lp::lpvar other_j = lp().external_to_local(other); + if (!lp().column_is_fixed(other_j)) { m_tmp_var_set.insert(other); - vars.push_back(other_j.index()); + vars.push_back(other_j); } } } @@ -874,7 +917,15 @@ namespace arith { lp().random_update(vars.size(), vars.data()); } + bool solver::include_func_interp(enode* n) const { + func_decl* d = n->get_decl(); + return d && include_func_interp(d); + } + bool solver::assume_eqs() { + if (delayed_assume_eqs()) + return true; + TRACE("arith", display(tout);); random_update(); m_model_eqs.reset(); @@ -923,26 +974,26 @@ namespace arith { continue; literal eq = eq_internalize(n1, n2); ctx.mark_relevant(eq); - if (s().value(eq) != l_true) + switch (s().value(eq)) { + case l_true: + break; + case l_undef: + return true; + case l_false: + mk_diseq_axiom(v1, v2); return true; + } } return false; } bool solver::use_nra_model() { - if (m_nla && m_nla->use_nra_model()) { - if (!m_a1) { - m_a1 = alloc(scoped_anum, m_nla->am()); - m_a2 = alloc(scoped_anum, m_nla->am()); - } - return true; - } - return false; + return m_nla && m_nla->use_nra_model(); } bool solver::is_eq(theory_var v1, theory_var v2) { if (use_nra_model()) { - return m_nla->am().eq(nl_value(v1, *m_a1), nl_value(v2, *m_a2)); + return m_nla->am().eq(nl_value(v1, m_nla->tmp1()), nl_value(v2, m_nla->tmp2())); } else { return get_ivalue(v1) == get_ivalue(v2); @@ -975,9 +1026,6 @@ namespace arith { TRACE("arith", ctx.display(tout);); - if (!check_delayed_eqs()) - return sat::check_result::CR_CONTINUE; - switch (check_lia()) { case l_true: break; @@ -1000,17 +1048,17 @@ namespace arith { st = sat::check_result::CR_GIVEUP; break; } - - if (delayed_assume_eqs()) { - ++m_stats.m_assume_eqs; - return sat::check_result::CR_CONTINUE; - } + if (assume_eqs()) { ++m_stats.m_assume_eqs; return sat::check_result::CR_CONTINUE; } if (!check_delayed_eqs()) return sat::check_result::CR_CONTINUE; + + if (!int_undef && !check_bv_terms()) + return sat::check_result::CR_CONTINUE; + if (ctx.get_config().m_arith_ignore_int && int_undef) return sat::check_result::CR_GIVEUP; if (m_not_handled != nullptr) { @@ -1023,11 +1071,14 @@ namespace arith { nlsat::anum const& solver::nl_value(theory_var v, scoped_anum& r) const { SASSERT(m_nla); SASSERT(m_nla->use_nra_model()); - auto t = get_tv(v); - if (t.is_term()) { + auto t = get_column(v); + if (!lp().column_has_term(t)) { + m_nla->am().set(r, m_nla->am_value(t)); + } + else { m_todo_terms.push_back(std::make_pair(t, rational::one())); - TRACE("nl_value", tout << "v" << v << " " << t.to_string() << "\n";); - TRACE("nl_value", tout << "v" << v << " := w" << t.to_string() << "\n"; + TRACE("nl_value", tout << "v" << v << " " << t << "\n";); + TRACE("nl_value", tout << "v" << v << " := w" << t << "\n"; lp().print_term(lp().get_term(t), tout) << "\n";); m_nla->am().set(r, 0); @@ -1042,23 +1093,20 @@ namespace arith { m_nla->am().set(r1, c1.to_mpq()); m_nla->am().add(r, r1, r); for (lp::lar_term::ival arg : term) { - auto wi = lp().column2tv(arg.column()); + auto wi = arg.j(); c1 = arg.coeff() * wcoeff; - if (wi.is_term()) { + if (lp().column_has_term(wi)) { m_todo_terms.push_back(std::make_pair(wi, c1)); } else { m_nla->am().set(r1, c1.to_mpq()); - m_nla->am().mul(m_nla->am_value(wi.id()), r1, r1); + m_nla->am().mul(m_nla->am_value(wi), r1, r1); m_nla->am().add(r1, r, r); } } } - return r; - } - else { - return m_nla->am_value(t.id()); } + return r; } lbool solver::make_feasible() { @@ -1092,7 +1140,7 @@ namespace arith { if (p.second) new_eq_eh(e); else if (is_eq(e.v1(), e.v2())) { - mk_diseq_axiom(e); + mk_diseq_axiom(e.v1(), e.v2()); found_diseq = true; break; } @@ -1147,21 +1195,22 @@ namespace arith { app_ref b = mk_bound(m_lia->get_term(), m_lia->get_offset(), !m_lia->is_upper()); IF_VERBOSE(4, verbose_stream() << "cut " << b << "\n"); literal lit = expr2literal(b); - assign(lit, m_core, m_eqs, explain(hint_type::bound_h, lit)); + assign(lit, m_core, m_eqs, explain(hint_type::cut_h, lit)); lia_check = l_false; break; } case lp::lia_move::conflict: TRACE("arith", tout << "conflict\n";); // ex contains unsat core - set_conflict(); + set_conflict(hint_type::cut_h); return l_false; case lp::lia_move::undef: TRACE("arith", tout << "lia undef\n";); lia_check = l_undef; break; case lp::lia_move::continue_with_check: - lia_check = l_undef; + TRACE("arith", tout << "continue-with-check\n"); + lia_check = l_false; break; default: UNREACHABLE(); @@ -1186,15 +1235,15 @@ namespace arith { void solver::get_infeasibility_explanation_and_set_conflict() { m_explanation.clear(); lp().get_infeasibility_explanation(m_explanation); - set_conflict(); + set_conflict(hint_type::farkas_h); } - void solver::set_conflict() { + void solver::set_conflict(hint_type ty) { literal_vector core; - set_conflict_or_lemma(core, true); + set_conflict_or_lemma(ty, core, true); } - void solver::set_conflict_or_lemma(literal_vector const& core, bool is_conflict) { + void solver::set_conflict_or_lemma(hint_type ty, literal_vector const& core, bool is_conflict) { reset_evidence(); m_core.append(core); for (auto ev : m_explanation) @@ -1205,13 +1254,16 @@ namespace arith { for (literal c : m_core) tout << c << ": " << literal2expr(c) << "\n"; for (auto p : m_eqs) tout << ctx.bpp(p.first) << " == " << ctx.bpp(p.second) << "\n";); + if (ctx.get_config().m_arith_validate) + VERIFY(validate_conflict()); + if (is_conflict) { DEBUG_CODE( for (literal c : m_core) VERIFY(s().value(c) == l_true); for (auto p : m_eqs) VERIFY(p.first->get_root() == p.second->get_root())); ++m_num_conflicts; ++m_stats.m_conflicts; - auto* hint = explain_conflict(m_core, m_eqs); + auto* hint = explain_conflict(ty, m_core, m_eqs); ctx.set_conflict(euf::th_explain::conflict(*this, m_core, m_eqs, hint)); } else { @@ -1220,7 +1272,7 @@ namespace arith { for (literal& c : m_core) c.neg(); - add_redundant(m_core, explain(hint_type::farkas_h)); + add_redundant(m_core, explain(ty)); } } @@ -1344,17 +1396,17 @@ namespace arith { TRACE("arith", lp().print_term(term, tout) << "\n";); for (lp::lar_term::ival ti : term) { theory_var w; - auto tv = lp().column2tv(ti.column()); - if (tv.is_term()) { + auto tv = ti.j(); + if (lp().column_has_term(tv)) { lp::lar_term const& term1 = lp().get_term(tv); rational coeff2 = coeff * ti.coeff(); term2coeffs(term1, coeffs, coeff2); continue; } else { - w = lp().local_to_external(tv.id()); + w = lp().local_to_external(tv); SASSERT(w >= 0); - TRACE("arith", tout << (tv.id()) << ": " << w << "\n";); + TRACE("arith", tout << tv << ": " << w << "\n";); } rational c0(0); coeffs.find(w, c0); @@ -1406,30 +1458,35 @@ namespace arith { m_lemma = l; //todo avoid the copy m_explanation = l.expl(); literal_vector core; - for (auto const& ineq : m_lemma.ineqs()) { - bool is_lower = true, pos = true, is_eq = false; - switch (ineq.cmp()) { - case lp::LE: is_lower = false; pos = false; break; - case lp::LT: is_lower = true; pos = true; break; - case lp::GE: is_lower = true; pos = false; break; - case lp::GT: is_lower = false; pos = true; break; - case lp::EQ: is_eq = true; pos = false; break; - case lp::NE: is_eq = true; pos = true; break; - default: UNREACHABLE(); - } - TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); - // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); - // then term is used instead of ineq.m_term - sat::literal lit; - if (is_eq) - lit = mk_eq(ineq.term(), ineq.rs()); - else - lit = ctx.expr2literal(mk_bound(ineq.term(), ineq.rs(), is_lower)); - core.push_back(pos ? lit : ~lit); - } - set_conflict_or_lemma(core, false); + for (auto const& ineq : m_lemma.ineqs()) + core.push_back(~mk_ineq_literal(ineq)); + set_conflict_or_lemma(hint_type::nla_h, core, false); + } + + sat::literal solver::mk_ineq_literal(nla::ineq const& ineq) { + bool is_lower = true, sign = true, is_eq = false; + switch (ineq.cmp()) { + case lp::LE: is_lower = false; sign = false; break; + case lp::LT: is_lower = true; sign = true; break; + case lp::GE: is_lower = true; sign = false; break; + case lp::GT: is_lower = false; sign = true; break; + case lp::EQ: is_eq = true; sign = false; break; + case lp::NE: is_eq = true; sign = true; break; + default: UNREACHABLE(); + } + // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); + // then term is used instead of ineq.m_term + sat::literal lit; + if (is_eq) + lit = mk_eq(ineq.term(), ineq.rs()); + else + lit = ctx.expr2literal(mk_bound(ineq.term(), ineq.rs(), is_lower)); + + TRACE("arith", tout << "is_lower: " << is_lower << " sign " << sign << " " << ctx.literal2expr(lit) << "\n";); + return sign ? ~lit : lit; } + lbool solver::check_nla() { if (!m.inc()) { TRACE("arith", tout << "canceled\n";); @@ -1441,16 +1498,12 @@ namespace arith { if (!m_nla->need_check()) return l_true; - m_a1 = nullptr; m_a2 = nullptr; - lbool r = m_nla->check(m_nla_lemma_vector); + lbool r = m_nla->check(); switch (r) { - case l_false: - for (const nla::lemma& l : m_nla_lemma_vector) - false_case_of_check_nla(l); + case l_false: + add_lemmas(); break; case l_true: - if (assume_eqs()) - return l_false; break; case l_undef: break; @@ -1458,9 +1511,48 @@ namespace arith { return r; } + void solver::add_lemmas() { + if (m_nla->should_check_feasible()) { + auto is_sat = make_feasible(); + if (l_false == is_sat) { + get_infeasibility_explanation_and_set_conflict(); + return; + } + } + for (auto const& ineq : m_nla->literals()) { + auto lit = mk_ineq_literal(ineq); + ctx.mark_relevant(lit); + s().set_phase(lit); + // force trichotomy axiom for equality literals + if (ineq.cmp() == lp::EQ) { + nla::lemma l; + l.push_back(ineq); + l.push_back(nla::ineq(lp::LT, ineq.term(), ineq.rs())); + l.push_back(nla::ineq(lp::GT, ineq.term(), ineq.rs())); + false_case_of_check_nla(l); + } + } + for (const nla::lemma& l : m_nla->lemmas()) + false_case_of_check_nla(l); + if (!propagate_eqs()) + return; + for (auto const& [v,k,e] : m_nla->fixed_equalities()) + add_equality(v, k, e); + for (auto const& [i,j,e] : m_nla->equalities()) + add_eq(i,j,e,false); + } + + void solver::propagate_nla() { + if (m_nla) { + m_nla->propagate(); + add_lemmas(); + lp().collect_more_rows_for_lp_propagation(); + } + } + void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) { auto& jst = euf::th_explain::from_index(idx); - ctx.get_antecedents(l, jst, r, probing); + ctx.get_th_antecedents(l, jst, r, probing); } bool solver::include_func_interp(func_decl* f) const { diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 68d5f802592..755611474bd 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -38,7 +38,7 @@ namespace euf { namespace arith { typedef ptr_vector> lp_bounds; - typedef lp::var_index lpvar; + typedef lp::lpvar lpvar; typedef euf::theory_var theory_var; typedef euf::theory_id theory_id; typedef euf::enode enode; @@ -51,15 +51,16 @@ namespace arith { enum class hint_type { farkas_h, bound_h, - implied_eq_h + cut_h, + implied_eq_h, + nla_h }; struct arith_proof_hint : public euf::th_proof_hint { hint_type m_ty; - unsigned m_num_le; unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail; - arith_proof_hint(hint_type t, unsigned num_le, unsigned lh, unsigned lt, unsigned eh, unsigned et): - m_ty(t), m_num_le(num_le), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {} + arith_proof_hint(hint_type t, unsigned lh, unsigned lt, unsigned eh, unsigned et): + m_ty(t), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {} expr* get_hint(euf::solver& s) const override; }; @@ -67,7 +68,6 @@ namespace arith { vector> m_literals; svector> m_eqs; hint_type m_ty; - unsigned m_num_le = 0; unsigned m_lit_head = 0, m_lit_tail = 0, m_eq_head = 0, m_eq_tail = 0; void reset() { m_lit_head = m_lit_tail; m_eq_head = m_eq_tail; } void add(euf::enode* a, euf::enode* b, bool is_eq) { @@ -79,7 +79,6 @@ namespace arith { } public: void set_type(euf::solver& ctx, hint_type ty); - void set_num_le(unsigned n) { m_num_le = n; } void add_eq(euf::enode* a, euf::enode* b) { add(a, b, true); } void add_diseq(euf::enode* a, euf::enode* b) { add(a, b, false); } void add_lit(rational const& coeff, literal lit) { @@ -145,13 +144,11 @@ namespace arith { expr_ref_vector m_terms; vector m_coeffs; svector m_vars; - rational m_offset; ptr_vector m_to_ensure_enode, m_to_ensure_var; internalize_state(ast_manager& m) : m_terms(m) {} void reset() { m_terms.reset(); m_coeffs.reset(); - m_offset.reset(); m_vars.reset(); m_to_ensure_enode.reset(); m_to_ensure_var.reset(); @@ -178,7 +175,6 @@ namespace arith { expr_ref_vector& terms() { return m_st.m_terms; } vector& coeffs() { return m_st.m_coeffs; } svector& vars() { return m_st.m_vars; } - rational& offset() { return m_st.m_offset; } ptr_vector& to_ensure_enode() { return m_st.m_to_ensure_enode; } ptr_vector& to_ensure_var() { return m_st.m_to_ensure_var; } void push(expr* e, rational c) { m_st.m_terms.push_back(e); m_st.m_coeffs.push_back(c); } @@ -218,6 +214,7 @@ namespace arith { expr* m_not_handled = nullptr; ptr_vector m_underspecified; ptr_vector m_idiv_terms; + ptr_vector m_bv_terms; vector > m_use_list; // bounds where variables are used. // attributes for incremental version: @@ -229,7 +226,7 @@ namespace arith { svector > m_assume_eq_candidates; unsigned m_assume_eq_head = 0; - lp::u_set m_tmp_var_set; + indexed_uint_set m_tmp_var_set; unsigned m_num_conflicts = 0; lp_api::stats m_stats; @@ -237,7 +234,6 @@ namespace arith { // non-linear arithmetic scoped_ptr m_nla; - scoped_ptr m_a1, m_a2; // integer arithmetic scoped_ptr m_lia; @@ -247,13 +243,14 @@ namespace arith { resource_limit m_resource_limit; lp_bounds m_new_bounds; symbol m_farkas; + std_vector m_implied_bounds; lp::lp_bound_propagator m_bp; - mutable vector> m_todo_terms; + mutable vector> m_todo_terms; // lemmas lp::explanation m_explanation; - vector m_nla_lemma_vector; literal_vector m_core, m_core2; + vector m_coeffs; svector m_eqs; vector m_params; nla::lemma m_lemma; @@ -290,15 +287,17 @@ namespace arith { void ensure_arg_vars(app* t); theory_var internalize_power(app* t, app* n, unsigned p); theory_var internalize_mul(app* t); + theory_var internalize_numeral(app* t, rational const& v); theory_var internalize_def(expr* term); theory_var internalize_def(expr* term, scoped_internalize_state& st); theory_var internalize_linearized_def(expr* term, scoped_internalize_state& st); void init_left_side(scoped_internalize_state& st); bool internalize_term(expr* term); - bool internalize_atom(expr* atom); + bool internalize_atom(expr* atom); bool is_unit_var(scoped_internalize_state& st); bool is_one(scoped_internalize_state& st); bool is_zero(scoped_internalize_state& st); + expr* mk_sub(expr* a, expr* b); enode* mk_enode(expr* e); lpvar register_theory_var_in_lar_solver(theory_var v); @@ -307,7 +306,7 @@ namespace arith { bool reflect(expr* n) const; lpvar get_lpvar(theory_var v) const; - lp::tv get_tv(theory_var v) const; + lp::lpvar get_column(theory_var v) const; // axioms void mk_div_axiom(expr* p, expr* q); @@ -319,6 +318,7 @@ namespace arith { void mk_bound_axioms(api_bound& b); void mk_bound_axiom(api_bound& b1, api_bound& b2); void mk_power0_axioms(app* t, app* n); + void mk_bv_axiom(app* n); void flush_bound_axioms(); void add_farkas_clause(sat::literal l1, sat::literal l2); @@ -348,7 +348,7 @@ namespace arith { iterator end, bool& found_compatible); - void propagate_eqs(lp::tv t, lp::constraint_index ci, lp::lconstraint_kind k, api_bound& b, rational const& value); + void propagate_eqs(lp::lpvar t, lp::constraint_index ci, lp::lconstraint_kind k, api_bound& b, rational const& value); void propagate_basic_bounds(unsigned qhead); void propagate_bounds_with_lp_solver(); void propagate_bound(literal lit, api_bound& b); @@ -357,14 +357,14 @@ namespace arith { literal is_bound_implied(lp::lconstraint_kind k, rational const& value, api_bound const& b) const; void assert_bound(bool is_true, api_bound& b); void mk_eq_axiom(bool is_eq, euf::th_eq const& eq); - void mk_diseq_axiom(euf::th_eq const& eq); + void mk_diseq_axiom(theory_var v1, theory_var v2); void assert_idiv_mod_axioms(theory_var u, theory_var v, theory_var w, rational const& r); api_bound* mk_var_bound(sat::literal lit, theory_var v, lp_api::bound_kind bk, rational const& bound); lp::lconstraint_kind bound2constraint_kind(bool is_int, lp_api::bound_kind bk, bool is_true); - void fixed_var_eh(theory_var v1, lp::constraint_index ci1, lp::constraint_index ci2, rational const& bound); - bool set_upper_bound(lp::tv t, lp::constraint_index ci, rational const& v) { return set_bound(t, ci, v, false); } - bool set_lower_bound(lp::tv t, lp::constraint_index ci, rational const& v) { return set_bound(t, ci, v, true); } - bool set_bound(lp::tv tv, lp::constraint_index ci, rational const& v, bool is_lower); + void fixed_var_eh(theory_var v1, u_dependency* dep, rational const& bound); + bool set_upper_bound(lp::lpvar t, lp::constraint_index ci, rational const& v) { return set_bound(t, ci, v, false); } + bool set_lower_bound(lp::lpvar t, lp::constraint_index ci, rational const& v) { return set_bound(t, ci, v, true); } + bool set_bound(lp::lpvar tv, lp::constraint_index ci, rational const& v, bool is_lower); typedef std::pair constraint_bound; vector m_lower_terms; @@ -404,19 +404,25 @@ namespace arith { bool delayed_assume_eqs(); bool is_eq(theory_var v1, theory_var v2); bool use_nra_model(); + bool include_func_interp(enode* n) const; lbool make_feasible(); bool check_delayed_eqs(); lbool check_lia(); lbool check_nla(); + bool check_bv_terms(); + bool check_bv_term(app* n); + void add_lemmas(); + void propagate_nla(); + void add_equality(lpvar v, rational const& k, lp::explanation const& exp); bool is_infeasible() const; nlsat::anum const& nl_value(theory_var v, scoped_anum& r) const; - bool has_bound(lpvar vi, lp::constraint_index& ci, rational const& bound, bool is_lower); - bool has_lower_bound(lpvar vi, lp::constraint_index& ci, rational const& bound); - bool has_upper_bound(lpvar vi, lp::constraint_index& ci, rational const& bound); + bool has_bound(lpvar vi, u_dependency*& ci, rational const& bound, bool is_lower); + bool has_lower_bound(lpvar vi, u_dependency*& ci, rational const& bound); + bool has_upper_bound(lpvar vi, u_dependency*& ci, rational const& bound); /* * Facility to put a small box around integer variables used in branch and bounds. @@ -460,11 +466,12 @@ namespace arith { void term2coeffs(lp::lar_term const& term, u_map& coeffs); void get_infeasibility_explanation_and_set_conflict(); - void set_conflict(); - void set_conflict_or_lemma(literal_vector const& core, bool is_conflict); + void set_conflict(hint_type ty); + void set_conflict_or_lemma(hint_type ty, literal_vector const& core, bool is_conflict); void set_evidence(lp::constraint_index idx); void assign(literal lit, literal_vector const& core, svector const& eqs, euf::th_proof_hint const* pma); + sat::literal mk_ineq_literal(nla::ineq const& ineq); void false_case_of_check_nla(const nla::lemma& l); void dbg_finalize_model(model& mdl); @@ -473,9 +480,10 @@ namespace arith { arith_proof_hint const* explain(hint_type ty, sat::literal lit = sat::null_literal); arith_proof_hint const* explain_implied_eq(lp::explanation const& e, euf::enode* a, euf::enode* b); arith_proof_hint const* explain_trichotomy(sat::literal le, sat::literal ge, sat::literal eq); - arith_proof_hint const* explain_conflict(sat::literal_vector const& core, euf::enode_pair_vector const& eqs); + arith_proof_hint const* explain_conflict(hint_type ty, sat::literal_vector const& core, euf::enode_pair_vector const& eqs); void explain_assumptions(lp::explanation const& e); + bool validate_conflict(); public: solver(euf::solver& ctx, theory_id id); @@ -519,6 +527,8 @@ namespace arith { bool add_eq(lpvar u, lpvar v, lp::explanation const& e, bool is_fixed); void consume(rational const& v, lp::constraint_index j); bool bound_is_interesting(unsigned vi, lp::lconstraint_kind kind, const rational& bval) const; + + bool get_value(euf::enode* n, expr_ref& val); }; diff --git a/src/sat/smt/arith_theory_checker.h b/src/sat/smt/arith_theory_checker.h index 65c64765857..87868d940c6 100644 --- a/src/sat/smt/arith_theory_checker.h +++ b/src/sat/smt/arith_theory_checker.h @@ -35,6 +35,15 @@ The module assumes a limited repertoire of arithmetic proof rules. namespace arith { class theory_checker : public euf::theory_checker_plugin { + + enum rule_type_t { + cut_t, + farkas_t, + implied_eq_t, + bound_t, + none_t + }; + struct row { obj_map m_coeffs; rational m_coeff; @@ -42,6 +51,9 @@ namespace arith { m_coeffs.reset(); m_coeff = 0; } + bool is_zero() const { + return m_coeffs.empty() && m_coeff == 0; + } }; ast_manager& m; @@ -50,10 +62,24 @@ namespace arith { bool m_strict = false; row m_ineq; row m_conseq; - vector m_eqs; - symbol m_farkas; - symbol m_implied_eq; - symbol m_bound; + vector m_eqs, m_ineqs; + symbol m_farkas = symbol("farkas"); + symbol m_implied_eq = symbol("implied-eq"); + symbol m_bound = symbol("bound"); + symbol m_cut = symbol("cut"); + + rule_type_t rule_type(app* jst) const { + if (jst->get_name() == m_cut) + return cut_t; + if (jst->get_name() == m_bound) + return bound_t; + if (jst->get_name() == m_implied_eq) + return implied_eq_t; + if (jst->get_name() == m_farkas) + return farkas_t; + return none_t; + } + void add(row& r, expr* v, rational const& coeff) { rational coeff1; @@ -90,10 +116,10 @@ namespace arith { // X = lcm(a,b)/b, Y = -lcm(a,b)/a if v is integer // X = 1/b, Y = -1/a if v is real // - void resolve(expr* v, row& dst, rational const& A, row const& src) { + bool resolve(expr* v, row& dst, rational const& A, row const& src) { rational B, x, y; if (!dst.m_coeffs.find(v, B)) - return; + return false; if (a.is_int(v)) { rational lc = lcm(abs(A), abs(B)); x = lc / abs(B); @@ -109,6 +135,7 @@ namespace arith { y.neg(); mul(dst, x); add(dst, src, y); + return true; } void cut(row& r) { @@ -197,6 +224,8 @@ namespace arith { resolve(v, m_eqs[j], coeff, r); resolve(v, m_ineq, coeff, r); resolve(v, m_conseq, coeff, r); + for (auto& ineq : m_ineqs) + resolve(v, ineq, coeff, r); } return true; } @@ -269,6 +298,81 @@ namespace arith { return false; } + /** + Check implied equality lemma: + + inequalities & equalities => equality + + + We may assume the set of inequality assumptions we are given are all tight, non-strict and imply equalities. + In other words, given a set of inequalities a1x + b1 <= 0, ..., anx + bn <= 0 + the equalities a1x + b1 = 0, ..., anx + bn = 0 are all consequences. + + We use a weaker property: We derive implied equalities by applying exhaustive Fourier-Motzkin + elimination and then collect the tight 0 <= 0 inequalities that are derived. + + Claim: the set of inequalities used to derive 0 <= 0 are all tight equalities. + */ + + svector> m_deps; + unsigned_vector m_tight_inequalities; + uint_set m_ineqs_that_are_eqs; + + bool check_implied_eq() { + if (!reduce_eq()) + return true; + if (m_conseq.is_zero()) + return true; + + m_eqs.reset(); + m_deps.reset(); + unsigned orig_size = m_ineqs.size(); + m_deps.reserve(orig_size); + for (unsigned i = 0; i < m_ineqs.size(); ++i) { + row& r = m_ineqs[i]; + if (r.is_zero()) { + m_tight_inequalities.push_back(i); + continue; + } + auto const& [v, coeff] = *r.m_coeffs.begin(); + unsigned sz = m_ineqs.size(); + + for (unsigned j = i + 1; j < sz; ++j) { + rational B; + row& r2 = m_ineqs[j]; + if (!r2.m_coeffs.find(v, B) || (coeff > 0 && B > 0) || (coeff < 0 && B < 0)) + continue; + row& r3 = fresh(m_ineqs); + add(r3, m_ineqs[j], rational::one()); + resolve(v, r3, coeff, m_ineqs[i]); + m_deps.push_back({i, j}); + } + SASSERT(m_deps.size() == m_ineqs.size()); + } + + m_ineqs_that_are_eqs.reset(); + while (!m_tight_inequalities.empty()) { + unsigned j = m_tight_inequalities.back(); + m_tight_inequalities.pop_back(); + if (m_ineqs_that_are_eqs.contains(j)) + continue; + m_ineqs_that_are_eqs.insert(j); + if (j < orig_size) { + m_eqs.push_back(m_ineqs[j]); + } + else { + auto [a, b] = m_deps[j]; + m_tight_inequalities.push_back(a); + m_tight_inequalities.push_back(b); + } + } + m_ineqs.reset(); + + VERIFY (reduce_eq()); + + return m_conseq.is_zero(); + } + std::ostream& display_row(std::ostream& out, row const& r) { bool first = true; for (auto const& [v, coeff] : r.m_coeffs) { @@ -306,22 +410,21 @@ namespace arith { public: theory_checker(ast_manager& m): m(m), - a(m), - m_farkas("farkas"), - m_implied_eq("implied-eq"), - m_bound("bound") {} + a(m) {} void reset() { m_ineq.reset(); m_conseq.reset(); m_eqs.reset(); + m_ineqs.reset(); m_strict = false; } - bool add_ineq(rational const& coeff, expr* e, bool sign) { - return add_literal(m_ineq, abs(coeff), e, sign); + bool add_ineq(rule_type_t rt, rational const& coeff, expr* e, bool sign) { + row& r = rt == implied_eq_t ? fresh(m_ineqs) : m_ineq; + return add_literal(r, abs(coeff), e, sign); } - + bool add_conseq(rational const& coeff, expr* e, bool sign) { return add_literal(m_conseq, abs(coeff), e, sign); } @@ -332,11 +435,17 @@ namespace arith { linearize(r, rational(-1), b); } - bool check() { - if (m_conseq.m_coeffs.empty()) + bool check(rule_type_t rt) { + switch (rt) { + case farkas_t: return check_farkas(); - else + case bound_t: + return check_bound(); + case implied_eq_t: + return check_implied_eq(); + default: return check_bound(); + } } std::ostream& display(std::ostream& out) { @@ -359,7 +468,7 @@ namespace arith { /** Add implied equality as an inequality */ - bool add_implied_ineq(bool sign, app* jst) { + bool add_implied_diseq(bool sign, app* jst) { unsigned n = jst->get_num_args(); if (n < 2) return false; @@ -374,90 +483,57 @@ namespace arith { return false; if (!sign) coeff.neg(); - auto& r = m_ineq; + auto& r = m_conseq; linearize(r, coeff, arg1); linearize(r, -coeff, arg2); - m_strict = true; return true; } bool check(app* jst) override { reset(); - bool is_bound = jst->get_name() == m_bound; - bool is_implied_eq = jst->get_name() == m_implied_eq; - bool is_farkas = jst->get_name() == m_farkas; - if (!is_farkas && !is_bound && !is_implied_eq) { + + auto rt = rule_type(jst); + switch (rt) { + case cut_t: + return false; + case none_t: IF_VERBOSE(0, verbose_stream() << "unhandled inference " << mk_pp(jst, m) << "\n"); return false; + default: + break; } bool even = true; rational coeff; expr* x, * y; - unsigned j = 0, num_le = 0; - + unsigned j = 0; for (expr* arg : *jst) { + if (even) { if (!a.is_numeral(arg, coeff)) { IF_VERBOSE(0, verbose_stream() << "not numeral " << mk_pp(jst, m) << "\n"); return false; } - if (is_implied_eq) { - is_implied_eq = false; - if (!coeff.is_unsigned()) { - IF_VERBOSE(0, verbose_stream() << "not unsigned " << mk_pp(jst, m) << "\n"); - return false; - } - num_le = coeff.get_unsigned(); - if (!add_implied_ineq(false, jst)) { - IF_VERBOSE(0, display(verbose_stream() << "did not add implied eq")); - return false; - } - ++j; - continue; - } } else { bool sign = m.is_not(arg, arg); if (a.is_le(arg) || a.is_lt(arg) || a.is_ge(arg) || a.is_gt(arg)) { - if (is_bound && j + 1 == jst->get_num_args()) + if (rt == bound_t && j + 1 == jst->get_num_args()) add_conseq(coeff, arg, sign); - else if (num_le > 0) { - add_ineq(coeff, arg, sign); - --num_le; - if (num_le == 0) { - // we processed all the first inequalities, - // check that they imply one half of the implied equality. - if (!check()) { - // we might have added the wrong direction of the implied equality. - // so try the opposite inequality. - add_implied_ineq(true, jst); - add_implied_ineq(true, jst); - if (check()) { - reset(); - add_implied_ineq(false, jst); - } - else { - IF_VERBOSE(0, display(verbose_stream() << "failed to check implied eq ")); - return false; - } - } - else { - reset(); - VERIFY(add_implied_ineq(true, jst)); - } - } - } else - add_ineq(coeff, arg, sign); + add_ineq(rt, coeff, arg, sign); } else if (m.is_eq(arg, x, y)) { - if (is_bound && j + 1 == jst->get_num_args()) + if (rt == bound_t && j + 1 == jst->get_num_args()) add_conseq(coeff, arg, sign); - else if (sign) - return check(); // it should be an implied equality - else + else if (rt == implied_eq_t && j + 1 == jst->get_num_args()) + return add_implied_diseq(sign, jst) && check(rt); + else if (!sign) add_eq(x, y); + else { + IF_VERBOSE(0, verbose_stream() << "unexpected disequality in justification " << mk_pp(arg, m) << "\n"); + return false; + } } else { IF_VERBOSE(0, verbose_stream() << "not a recognized arithmetical relation " << mk_pp(arg, m) << "\n"); @@ -467,13 +543,14 @@ namespace arith { even = !even; ++j; } - return check(); + return check(rt); } void register_plugins(euf::theory_checker& pc) override { pc.register_plugin(m_farkas, this); pc.register_plugin(m_bound, this); pc.register_plugin(m_implied_eq, this); + pc.register_plugin(m_cut, this); } }; diff --git a/src/sat/smt/arith_value.cpp b/src/sat/smt/arith_value.cpp new file mode 100644 index 00000000000..bb301808e40 --- /dev/null +++ b/src/sat/smt/arith_value.cpp @@ -0,0 +1,145 @@ +/*++ +Copyright (c) 2018 Microsoft Corporation + +Module Name: + + smt_arith_value.cpp + +Abstract: + + Utility to extract arithmetic values from context. + +Author: + + Nikolaj Bjorner (nbjorner) 2018-12-08. + +Revision History: + +--*/ + +#include "ast/ast_pp.h" +#include "sat/smt/arith_value.h" +#include "sat/smt/euf_solver.h" +#include "sat/smt/arith_solver.h" + +namespace arith { + + arith_value::arith_value(euf::solver& s) : + s(s), m(s.get_manager()), a(m) {} + + void arith_value::init() { + if (!as) + as = dynamic_cast(s.fid2solver(a.get_family_id())); + } + + bool arith_value::get_value(expr* e, rational& val) { + auto n = s.get_enode(e); + expr_ref _val(m); + init(); + return n && as->get_value(n, _val) && a.is_numeral(_val, val); + } + +#if 0 + bool arith_value::get_lo_equiv(expr* e, rational& lo, bool& is_strict) { + if (!s.get_enode(e)) + return false; + init(); + is_strict = false; + bool found = false; + bool is_strict1; + rational lo1; + for (auto sib : euf::enode_class(s.get_enode(e))) { + if (!as->get_lower(sib, lo1, is_strict1)) + continue; + if (!found || lo1 > lo || lo == lo1 && is_strict1) + lo = lo1, is_strict = is_strict1; + found = true; + } + CTRACE("arith_value", !found, tout << "value not found for " << mk_pp(e, m) << "\n";); + return found; + } + + bool arith_value::get_up_equiv(expr* e, rational& hi, bool& is_strict) { + if (!s.get_enode(e)) + return false; + init(); + is_strict = false; + bool found = false; + bool is_strict1; + rational hi1; + for (auto sib : euf::enode_class(s.get_enode(e))) { + if (!as->get_upper(sib, hi1, is_strict1)) + continue; + if (!found || hi1 < hi || hi == hi1 && is_strict1) + hi = hi1, is_strict = is_strict1; + found = true; + } + CTRACE("arith_value", !found, tout << "value not found for " << mk_pp(e, m) << "\n";); + return found; + } + + bool arith_value::get_up(expr* e, rational& up, bool& is_strict) const { + init(); + enode* n = s.get_enode(e); + is_strict = false; + return n && as->get_upper(n, up, is_strict); + } + + bool arith_value::get_lo(expr* e, rational& lo, bool& is_strict) const { + init(); + enode* n = s.get_enode(e); + is_strict = false; + return n && as->get_lower(n, lo, is_strict); + } + +#endif + + +#if 0 + + + bool arith_value::get_value_equiv(expr* e, rational& val) const { + if (!m_ctx->e_internalized(e)) return false; + expr_ref _val(m); + enode* next = m_ctx->get_enode(e), * n = next; + do { + e = next->get_expr(); + if (m_tha && m_tha->get_value(next, _val) && a.is_numeral(_val, val)) return true; + if (m_thi && m_thi->get_value(next, _val) && a.is_numeral(_val, val)) return true; + if (m_thr && m_thr->get_value(next, val)) return true; + next = next->get_next(); + } while (next != n); + TRACE("arith_value", tout << "value not found for " << mk_pp(e, m_ctx->get_manager()) << "\n";); + return false; + } + + expr_ref arith_value::get_lo(expr* e) const { + rational lo; + bool s = false; + if ((a.is_int_real(e) || b.is_bv(e)) && get_lo(e, lo, s) && !s) { + return expr_ref(a.mk_numeral(lo, e->get_sort()), m); + } + return expr_ref(e, m); + } + + expr_ref arith_value::get_up(expr* e) const { + rational up; + bool s = false; + if ((a.is_int_real(e) || b.is_bv(e)) && get_up(e, up, s) && !s) { + return expr_ref(a.mk_numeral(up, e->get_sort()), m); + } + return expr_ref(e, m); + } + + expr_ref arith_value::get_fixed(expr* e) const { + rational lo, up; + bool s = false; + if (a.is_int_real(e) && get_lo(e, lo, s) && !s && get_up(e, up, s) && !s && lo == up) { + return expr_ref(a.mk_numeral(lo, e->get_sort()), m); + } + return expr_ref(e, m); + } + +#endif + +}; diff --git a/src/sat/smt/arith_value.h b/src/sat/smt/arith_value.h new file mode 100644 index 00000000000..b858ff8965a --- /dev/null +++ b/src/sat/smt/arith_value.h @@ -0,0 +1,52 @@ + +/*++ +Copyright (c) 2018 Microsoft Corporation + +Module Name: + + arith_value.h + +Abstract: + + Utility to extract arithmetic values from context. + +Author: + + Nikolaj Bjorner (nbjorner) 2018-12-08. + +Revision History: + +--*/ +#pragma once + +#include "ast/arith_decl_plugin.h" + +namespace euf { + class solver; +} +namespace arith { + + class solver; + + class arith_value { + euf::solver& s; + ast_manager& m; + arith_util a; + solver* as = nullptr; + void init(); + public: + arith_value(euf::solver& s); + bool get_value(expr* e, rational& value); + +#if 0 + bool get_lo_equiv(expr* e, rational& lo, bool& strict); + bool get_up_equiv(expr* e, rational& up, bool& strict); + bool get_lo(expr* e, rational& lo, bool& strict); + bool get_up(expr* e, rational& up, bool& strict); + bool get_value_equiv(expr* e, rational& value); + expr_ref get_lo(expr* e); + expr_ref get_up(expr* e); + expr_ref get_fixed(expr* e); +#endif + }; +}; diff --git a/src/sat/smt/array_axioms.cpp b/src/sat/smt/array_axioms.cpp index ff60bf67575..6a08cf76e88 100644 --- a/src/sat/smt/array_axioms.cpp +++ b/src/sat/smt/array_axioms.cpp @@ -409,6 +409,11 @@ namespace array { def1 = a.mk_default(store); def2 = a.mk_default(store->get_arg(0)); + prop |= !ctx.get_enode(def1) || !ctx.get_enode(def2); + + euf::enode* ndef1 = e_internalize(def1); + euf::enode* ndef2 = e_internalize(def2); + if (has_unitary_domain(store)) { def2 = store->get_arg(num_args - 1); } @@ -417,8 +422,8 @@ namespace array { // let A = store(B, i, v) // // Add: - // default(A) = ite(epsilon1 = i, v, default(B)) - // A[diag(i)] = B[diag(i)] + // default(A) = A[epsilon] + // default(B) = B[epsilon] // expr_ref_vector eqs(m); expr_ref_vector args1(m), args2(m); @@ -428,22 +433,21 @@ namespace array { for (unsigned i = 1; i + 1 < num_args; ++i) { expr* arg = store->get_arg(i); sort* srt = arg->get_sort(); - auto ep = mk_epsilon(srt); - eqs.push_back(m.mk_eq(ep.first, arg)); - args1.push_back(m.mk_app(ep.second, arg)); - args2.push_back(m.mk_app(ep.second, arg)); + auto [ep, d] = mk_epsilon(srt); + eqs.push_back(m.mk_eq(ep, arg)); + args1.push_back(ep); + args2.push_back(ep); } - expr_ref eq(m.mk_and(eqs), m); - def2 = m.mk_ite(eq, store->get_arg(num_args - 1), def2); app_ref sel1(m), sel2(m); sel1 = a.mk_select(args1); sel2 = a.mk_select(args2); - prop |= !ctx.get_enode(sel1) || !ctx.get_enode(sel2); - if (ctx.propagate(e_internalize(sel1), e_internalize(sel2), array_axiom())) - prop = true; + return + ctx.propagate(e_internalize(sel1), ndef1, array_axiom()) || + ctx.propagate(e_internalize(sel2), ndef2, array_axiom()) || + prop; } - prop |= !ctx.get_enode(def1) || !ctx.get_enode(def2); - if (ctx.propagate(e_internalize(def1), e_internalize(def2), array_axiom())) + // default(A) == default(B) + if (ctx.propagate(ndef1, ndef2, array_axiom())) prop = true; return prop; } diff --git a/src/sat/smt/array_solver.h b/src/sat/smt/array_solver.h index c63eedaca4f..8dc6e4e8446 100644 --- a/src/sat/smt/array_solver.h +++ b/src/sat/smt/array_solver.h @@ -32,7 +32,7 @@ namespace array { typedef sat::literal literal; typedef sat::bool_var bool_var; typedef sat::literal_vector literal_vector; - typedef union_find array_union_find; + typedef union_find array_union_find; struct stats { diff --git a/src/sat/smt/bv_internalize.cpp b/src/sat/smt/bv_internalize.cpp index 99d2a34ae11..f1e9e837443 100644 --- a/src/sat/smt/bv_internalize.cpp +++ b/src/sat/smt/bv_internalize.cpp @@ -191,8 +191,8 @@ namespace bv { case OP_BAND: internalize_ac(mk_and); break; case OP_BOR: internalize_ac(mk_or); break; case OP_BXOR: internalize_ac(mk_xor); break; - case OP_BNAND: internalize_bin(mk_nand); break; - case OP_BNOR: internalize_bin(mk_nor); break; + case OP_BNAND: if_unary(mk_not); internalize_bin(mk_nand); break; + case OP_BNOR: if_unary(mk_not); internalize_bin(mk_nor); break; case OP_BXNOR: if_unary(mk_not); internalize_bin(mk_xnor); break; case OP_BCOMP: internalize_bin(mk_comp); break; case OP_SIGN_EXT: internalize_pun(mk_sign_extend); break; diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index a0bcea43b96..53a00414074 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -96,8 +96,6 @@ namespace bv { } void solver::add_fixed_eq(theory_var v1, theory_var v2) { - if (!get_config().m_bv_eq_axioms) - return; m_ackerman.used_eq_eh(v1, v2); } @@ -128,7 +126,7 @@ namespace bv { /** \brief Find an unassigned bit for m_wpos[v], if such bit cannot be found invoke fixed_var_eh */ - void solver::find_wpos(theory_var v) { + bool solver::find_wpos(theory_var v) { literal_vector const& bits = m_bits[v]; unsigned sz = bits.size(); unsigned& wpos = m_wpos[v]; @@ -137,19 +135,18 @@ namespace bv { if (s().value(bits[idx]) == l_undef) { wpos = idx; TRACE("bv", tout << "moved wpos of v" << v << " to " << wpos << "\n";); - return; + return false; } } TRACE("bv", tout << "v" << v << " is a fixed variable.\n";); fixed_var_eh(v); + return true; } /** *\brief v[idx] = ~v'[idx], then v /= v' is a theory axiom. */ void solver::find_new_diseq_axioms(atom& a, theory_var v, unsigned idx) { - if (!get_config().m_bv_eq_axioms) - return; literal l = m_bits[v][idx]; l.neg(); for (auto vp : a) { @@ -270,7 +267,7 @@ namespace bv { ++num_undef; undef_idx = -static_cast(i + 1); } - if (num_undef > 1 && get_config().m_bv_eq_axioms) + if (num_undef > 1) return; } if (num_undef == 0) @@ -292,8 +289,6 @@ namespace bv { ++m_stats.m_num_ne2bit; s().assign(consequent, mk_ne2bit_justification(undef_idx, v1, v2, consequent, antecedent)); } - else if (!get_config().m_bv_eq_axioms) - ; else if (s().at_search_lvl()) { force_push(); assert_ackerman(v1, v2); @@ -313,7 +308,7 @@ namespace bv { case bv_justification::kind_t::eq2bit: SASSERT(s().value(c.m_antecedent) == l_true); r.push_back(c.m_antecedent); - ctx.add_antecedent(probing, var2enode(c.m_v1), var2enode(c.m_v2)); + ctx.add_eq_antecedent(probing, var2enode(c.m_v1), var2enode(c.m_v2)); break; case bv_justification::kind_t::ne2bit: { r.push_back(c.m_antecedent); @@ -381,8 +376,8 @@ namespace bv { break; } case bv_justification::kind_t::bv2int: { - ctx.add_antecedent(probing, c.a, c.b); - ctx.add_antecedent(probing, c.a, c.c); + ctx.add_eq_antecedent(probing, c.a, c.b); + ctx.add_eq_antecedent(probing, c.a, c.c); break; } } @@ -853,7 +848,17 @@ namespace bv { values[n->get_root_id()] = bv.mk_numeral(val, m_bits[v].size()); } - trail_stack& solver::get_trail_stack() { + sat::bool_var solver::get_bit(unsigned bit, euf::enode *n) const { + theory_var v = n->get_th_var(get_id()); + if (v == euf::null_theory_var) + return sat::null_bool_var; + auto &bits = m_bits[v]; + if (bit >= bits.size()) + return sat::null_bool_var; + return bits[bit].var(); + } + + trail_stack &solver::get_trail_stack() { return ctx.get_trail_stack(); } diff --git a/src/sat/smt/bv_solver.h b/src/sat/smt/bv_solver.h index dc9cd1456b8..df4e5c9c2e9 100644 --- a/src/sat/smt/bv_solver.h +++ b/src/sat/smt/bv_solver.h @@ -49,7 +49,7 @@ namespace bv { typedef std::pair value_sort_pair; typedef pair_hash, unsigned_hash> value_sort_pair_hash; typedef map > value2var; - typedef union_find bv_union_find; + typedef union_find bv_union_find; typedef std::pair var_pos; friend class ackerman; @@ -321,7 +321,7 @@ namespace bv { // solving theory_var find(theory_var v) const { return m_find.find(v); } - void find_wpos(theory_var v); + bool find_wpos(theory_var v); void find_new_diseq_axioms(atom& a, theory_var v, unsigned idx); void mk_new_diseq_axiom(theory_var v1, theory_var v2, unsigned idx); bool get_fixed_value(theory_var v, numeral& result) const; @@ -334,7 +334,6 @@ namespace bv { numeral const& power2(unsigned i) const; sat::literal mk_true(); - // invariants bool check_zero_one_bits(theory_var v); void check_missing_propagation() const; @@ -391,6 +390,7 @@ namespace bv { euf::theory_var mk_var(euf::enode* n) override; void apply_sort_cnstr(euf::enode * n, sort * s) override; + bool_var get_bit(unsigned bit, euf::enode* n) const; void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2); void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) { SASSERT(check_zero_one_bits(r1)); } diff --git a/src/sat/smt/dt_solver.cpp b/src/sat/smt/dt_solver.cpp index 56a224d362f..0b3ca2a9911 100644 --- a/src/sat/smt/dt_solver.cpp +++ b/src/sat/smt/dt_solver.cpp @@ -7,7 +7,7 @@ Module Name: Abstract: - Theory plugin for altegraic datatypes + Theory plugin for algebraic datatypes Author: @@ -400,7 +400,7 @@ namespace dt { return; } SASSERT(val == l_undef || (val == l_false && !d->m_constructor)); - ctx.push(set_vector_idx_trail(d->m_recognizers, c_idx)); + ctx.push(set_vector_idx_trail(d->m_recognizers, c_idx)); d->m_recognizers[c_idx] = recognizer; if (val == l_false) propagate_recognizer(v, recognizer); @@ -554,12 +554,12 @@ namespace dt { } // Assuming `app` is equal to a constructor term, return the constructor enode - inline euf::enode* solver::oc_get_cstor(enode* app) { + inline euf::enode* solver::oc_get_cstor(enode* app) const { theory_var v = app->get_root()->get_th_var(get_id()); - SASSERT(v != euf::null_theory_var); + if (v == euf::null_theory_var) + return nullptr; v = m_find.find(v); var_data* d = m_var_data[v]; - SASSERT(d->m_constructor); return d->m_constructor; } @@ -757,7 +757,7 @@ namespace dt { void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) { auto& jst = euf::th_explain::from_index(idx); - ctx.get_antecedents(l, jst, r, probing); + ctx.get_th_antecedents(l, jst, r, probing); } void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { @@ -783,7 +783,7 @@ namespace dt { if (v == euf::null_theory_var) return false; euf::enode* con = m_var_data[m_find.find(v)]->m_constructor; - CTRACE("dt", !con, display(tout) << ctx.bpp(n) << "\n";); + TRACE("dt", display(tout) << ctx.bpp(n) << " con: " << ctx.bpp(con) << "\n";); if (con->num_args() == 0) dep.insert(n, nullptr); for (enode* arg : euf::enode_args(con)) @@ -794,16 +794,15 @@ namespace dt { bool solver::include_func_interp(func_decl* f) const { if (!dt.is_accessor(f)) return false; - func_decl* con = dt.get_accessor_constructor(f); - for (enode* app : ctx.get_egraph().enodes_of(f)) { - enode* arg = app->get_arg(0)->get_root(); - if (is_constructor(arg) && arg->get_decl() != con) + func_decl* con_decl = dt.get_accessor_constructor(f); + for (enode* app : ctx.get_egraph().enodes_of(f)) { + enode* con = oc_get_cstor(app->get_arg(0)); + if (con && is_constructor(con) && con->get_decl() != con_decl) return true; } return false; } - sat::literal solver::internalize(expr* e, bool sign, bool root) { if (!visit_rec(m, e, sign, root)) return sat::null_literal; diff --git a/src/sat/smt/dt_solver.h b/src/sat/smt/dt_solver.h index 51a7679fdd0..02f1300b80a 100644 --- a/src/sat/smt/dt_solver.h +++ b/src/sat/smt/dt_solver.h @@ -7,7 +7,7 @@ Module Name: Abstract: - Theory plugin for altegraic datatypes + Theory plugin for algebraic datatypes Author: @@ -36,7 +36,7 @@ namespace dt { typedef sat::bool_var bool_var; typedef sat::literal literal; typedef sat::literal_vector literal_vector; - typedef union_find dt_union_find; + typedef union_find dt_union_find; struct var_data { ptr_vector m_recognizers; //!< recognizers of this equivalence class that are being watched. @@ -116,7 +116,7 @@ namespace dt { void pop_core(unsigned n) override; - enode * oc_get_cstor(enode * n); + enode * oc_get_cstor(enode * n) const; bool occurs_check(enode * n); bool occurs_check_enter(enode * n); void occurs_check_explain(enode * top, enode * root); diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index a1d383e4578..f750f186d56 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -106,7 +106,6 @@ namespace euf { attach_node(mk_enode(e, 0, nullptr)); return true; } - bool solver::post_visit(expr* e, bool sign, bool root) { unsigned num = is_app(e) ? to_app(e)->get_num_args() : 0; m_args.reset(); @@ -355,8 +354,16 @@ namespace euf { bool solver::is_shared(enode* n) const { n = n->get_root(); - if (m.is_ite(n->get_expr())) + switch (n->is_shared()) { + case l_true: return true; + case l_false: return false; + default: break; + } + + if (m.is_ite(n->get_expr())) { + n->set_is_shared(l_true); return true; + } // the variable is shared if the equivalence class of n // contains a parent application. @@ -366,21 +373,27 @@ namespace euf { family_id id = p.get_id(); if (m.get_basic_family_id() != id) { - if (th_id != m.get_basic_family_id()) + if (th_id != m.get_basic_family_id()) { + n->set_is_shared(l_true); return true; + } th_id = id; } } - if (m.is_bool(n->get_expr()) && th_id != m.get_basic_family_id()) + if (m.is_bool(n->get_expr()) && th_id != m.get_basic_family_id()) { + n->set_is_shared(l_true); return true; + } for (enode* parent : euf::enode_parents(n)) { app* p = to_app(parent->get_expr()); family_id fid = p->get_family_id(); if (is_beta_redex(parent, n)) continue; - if (fid != th_id && fid != m.get_basic_family_id()) + if (fid != th_id && fid != m.get_basic_family_id()) { + n->set_is_shared(l_true); return true; + } } // Some theories implement families of theories. Examples: @@ -411,15 +424,18 @@ namespace euf { // not marked as shared. for (auto const& p : euf::enode_th_vars(n)) - if (fid2solver(p.get_id())->is_shared(p.get_var())) + if (fid2solver(p.get_id()) && fid2solver(p.get_id())->is_shared(p.get_var())) { + n->set_is_shared(l_true); return true; + } + n->set_is_shared(l_false); return false; } bool solver::is_beta_redex(enode* p, enode* n) const { for (auto const& th : enode_th_vars(p)) - if (fid2solver(th.get_id())->is_beta_redex(p, n)) + if (fid2solver(th.get_id()) && fid2solver(th.get_id())->is_beta_redex(p, n)) return true; return false; } diff --git a/src/sat/smt/euf_model.cpp b/src/sat/smt/euf_model.cpp index b117ac1e32e..2035e16b643 100644 --- a/src/sat/smt/euf_model.cpp +++ b/src/sat/smt/euf_model.cpp @@ -282,7 +282,7 @@ namespace euf { } void solver::display_validation_failure(std::ostream& out, model& mdl, enode* n) { - out << "Failed to validate " << n->bool_var() << " " << bpp(n) << " " << mdl(n->get_expr()) << "\n"; + out << "Failed to validate b" << n->bool_var() << " " << bpp(n) << " " << mdl(n->get_expr()) << "\n"; s().display(out); euf::enode_vector nodes; nodes.push_back(n); @@ -302,7 +302,7 @@ namespace euf { if (mval != sval) { if (r->bool_var() != sat::null_bool_var) out << "b" << r->bool_var() << " "; - out << bpp(r) << " :=\neval: " << sval << "\nmval: " << mval << "\n"; + out << bpp(r) << " :=\nvalue obtained from model: " << sval << "\nvalue of the root expression: " << mval << "\n"; continue; } if (!m.is_bool(val)) @@ -310,7 +310,7 @@ namespace euf { auto bval = s().value(r->bool_var()); bool tt = l_true == bval; if (tt != m.is_true(sval)) - out << bpp(r) << " :=\neval: " << sval << "\nmval: " << bval << "\n"; + out << bpp(r) << " :=\nvalue according to model: " << sval << "\nvalue of Boolean literal: " << bval << "\n"; } for (euf::enode* r : nodes) if (r) @@ -357,6 +357,7 @@ namespace euf { if (!tt && !mdl.is_true(e)) continue; CTRACE("euf", first, display_validation_failure(tout, mdl, n);); + CTRACE("euf", first, display(tout)); IF_VERBOSE(0, display_validation_failure(verbose_stream(), mdl, n);); (void)first; first = false; diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index ac9e8131145..39c9879a608 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -44,17 +44,73 @@ namespace euf { } /** - * \brief logs antecedents to a proof trail. - * - * NB with theories, this is not a pure EUF justification, - * It is true modulo EUF and previously logged certificates - * so it isn't necessarily an axiom over EUF, - * We will here leave it to the EUF checker to perform resolution steps. - */ + * Log justifications. + * is_euf - true if l is justified by congruence closure. In this case create a congruence closure proof. + * explain_size - the relevant portion of premises for the congruence closure proof. + * The EUF solver manages equality propagation. Each propagated equality is justified by a congruence closure. + */ + void solver::log_justifications(literal l, unsigned explain_size, bool is_euf) { + + unsigned nv = s().num_vars(); + expr_ref_vector eqs(m); + + auto add_hint_literals = [&](unsigned sz) { + eqs.reset(); + m_hint_lits.reset(); + nv = s().num_vars(); + for (unsigned i = 0; i < sz; ++i) { + size_t* e = m_explain[i]; + if (is_literal(e)) + m_hint_lits.push_back(get_literal(e)); + else { + auto [x, y] = th_explain::from_index(get_justification(e)).eq_consequent(); + eqs.push_back(m.mk_eq(x->get_expr(), y->get_expr())); + set_tmp_bool_var(nv, eqs.back()); + m_hint_lits.push_back(literal(nv, false)); + ++nv; + } + } + }; + + auto clear_hint_literals = [&]() { + for (unsigned v = s().num_vars(); v < nv; ++v) + set_tmp_bool_var(v, nullptr); + }; + + // log EUF justifications + if (is_euf) { + add_hint_literals(explain_size); + auto* hint = mk_hint(m_euf, l); + log_antecedents(l, m_hint_lits, hint); + clear_hint_literals(); + } + + // explain equalities + for (auto const& [a, b] : m_hint_eqs) { + m_egraph.begin_explain(); + m_explain.reset(); + m_egraph.explain_eq(m_explain, &m_explain_cc, a, b); + m_egraph.end_explain(); + // Detect shortcut if equality is explained directly by a theory + if (m_explain.size() == 1 && !is_literal(m_explain[0])) { + auto const& [x, y] = th_explain::from_index(get_justification(m_explain[0])).eq_consequent(); + if (x == a && y == b) + continue; + } + add_hint_literals(m_explain.size()); + eqs.push_back(m.mk_eq(a->get_expr(), b->get_expr())); + set_tmp_bool_var(nv, eqs.back()); + sat::literal eql = literal(nv, false); + ++nv; + auto* hint = mk_hint(m_euf, eql); + log_antecedents(eql, m_hint_lits, hint); + clear_hint_literals(); + } + } + void solver::log_antecedents(literal l, literal_vector const& r, th_proof_hint* hint) { - TRACE("euf", log_antecedents(tout, l, r);); - if (!use_drat()) - return; + SASSERT(hint && use_drat()); + TRACE("euf", log_antecedents(tout, l, r); tout << mk_pp(hint->get_hint(*this), m) << "\n"); literal_vector lits; for (literal lit : r) lits.push_back(~lit); @@ -63,6 +119,15 @@ namespace euf { get_drat().add(lits, sat::status::th(true, get_id(), hint)); } + void solver::log_rup(literal l, literal_vector const& r) { + literal_vector lits; + for (literal lit : r) + lits.push_back(~lit); + if (l != sat::null_literal) + lits.push_back(l); + get_drat().add(lits, sat::status::redundant()); + } + void solver::log_antecedents(std::ostream& out, literal l, literal_vector const& r) { for (sat::literal l : r) { expr* n = m_bool_var2expr[l.var()]; @@ -79,7 +144,7 @@ namespace euf { } } - eq_proof_hint* solver::mk_hint(symbol const& th, literal conseq, literal_vector const& r) { + eq_proof_hint* solver::mk_hint(symbol const& th, literal conseq) { if (!use_drat()) return nullptr; push(value_trail(m_lit_tail)); @@ -87,7 +152,7 @@ namespace euf { push(restore_vector(m_proof_literals)); if (conseq != sat::null_literal) m_proof_literals.push_back(~conseq); - m_proof_literals.append(r); + m_proof_literals.append(m_hint_lits); m_lit_head = m_lit_tail; m_cc_head = m_cc_tail; m_lit_tail = m_proof_literals.size(); @@ -159,6 +224,7 @@ namespace euf { }; for (unsigned i = m_lit_head; i < m_lit_tail; ++i) args.push_back(s.literal2expr(s.m_proof_literals[i])); + std::sort(s.m_explain_cc.data() + m_cc_head, s.m_explain_cc.data() + m_cc_tail, compare_ts); for (unsigned i = m_cc_head; i < m_cc_tail; ++i) { auto const& [a, b, ts, comm] = s.m_explain_cc[i]; @@ -197,7 +263,6 @@ namespace euf { if (!literal2expr(lits[i])) IF_VERBOSE(0, verbose_stream() << lits[i] << "\n"; display(verbose_stream())); - SASSERT(literal2expr(lits[i])); m_proof_literals.push_back(lits[i]); } @@ -260,10 +325,7 @@ namespace euf { auto const& [a, b] = s.m_proof_deqs[i]; args.push_back(m.mk_not(m.mk_eq(a, b))); } - for (auto * arg : args) - sorts.push_back(arg->get_sort()); - func_decl* f = m.mk_func_decl(m_name, sorts.size(), sorts.data(), proof); - return m.mk_app(f, args); + return m.mk_app(m_name, args.size(), args.data(), proof); } void solver::set_tmp_bool_var(bool_var b, expr* e) { @@ -298,7 +360,7 @@ namespace euf { } void solver::on_clause(unsigned n, literal const* lits, sat::status st) { - TRACE("euf", tout << "on-clause " << n << "\n"); + TRACE("euf_verbose", tout << "on-clause " << n << "\n"); on_lemma(n, lits, st); on_proof(n, lits, st); on_check(n, lits, st); @@ -312,7 +374,7 @@ namespace euf { for (unsigned i = 0; i < n; ++i) m_clause.push_back(literal2expr(lits[i])); auto hint = status2proof_hint(st); - m_on_clause(m_on_clause_ctx, hint, m_clause.size(), m_clause.data()); + m_on_clause(m_on_clause_ctx, hint, 0, nullptr, m_clause.size(), m_clause.data()); } void solver::on_proof(unsigned n, literal const* lits, sat::status st) { @@ -417,7 +479,7 @@ namespace euf { if (proof_hint) return display_expr(out << " ", proof_hint); else - return out; + return out; } app_ref solver::status2proof_hint(sat::status st) { diff --git a/src/sat/smt/euf_proof_checker.cpp b/src/sat/smt/euf_proof_checker.cpp index a538b2a804f..42cda4bfb3c 100644 --- a/src/sat/smt/euf_proof_checker.cpp +++ b/src/sat/smt/euf_proof_checker.cpp @@ -501,8 +501,9 @@ namespace euf { for (expr* arg : clause) std::cout << "\n " << mk_bounded_pp(arg, m); std::cout << ")\n"; + std::cout.flush(); - if (is_rup(proof_hint)) + if (false && is_rup(proof_hint)) diagnose_rup_failure(clause); add_clause(clause); @@ -527,9 +528,6 @@ namespace euf { for (expr* f : core) std::cout << mk_pp(f, m) << "\n"; } - SASSERT(false); - - exit(0); } void smt_proof_checker::collect_statistics(statistics& st) const { diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 0ae56beb3fa..b108430d8a5 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -21,6 +21,7 @@ Module Name: #include "sat/smt/sat_smt.h" #include "sat/smt/pb_solver.h" #include "sat/smt/bv_solver.h" +#include "sat/smt/intblast_solver.h" #include "sat/smt/euf_solver.h" #include "sat/smt/array_solver.h" #include "sat/smt/arith_solver.h" @@ -28,6 +29,7 @@ Module Name: #include "sat/smt/fpa_solver.h" #include "sat/smt/dt_solver.h" #include "sat/smt/recfun_solver.h" +#include "sat/smt/specrel_solver.h" namespace euf { @@ -130,10 +132,19 @@ namespace euf { arith_util arith(m); datatype_util dt(m); recfun::util rf(m); + special_relations_util sp(m); if (pb.get_family_id() == fid) ext = alloc(pb::solver, *this, fid); - else if (bvu.get_family_id() == fid) - ext = alloc(bv::solver, *this, fid); + else if (bvu.get_family_id() == fid) { + if (get_config().m_bv_solver == 0) + ext = alloc(bv::solver, *this, fid); + else if (get_config().m_bv_solver == 1) + throw default_exception("polysat solver is not integrated"); + else if (get_config().m_bv_solver == 2) + ext = alloc(intblast::solver, *this); + else + throw default_exception("unknown bit-vector solver. Accepted values 0 (bit blast), 1 (polysat), 2 (int blast)"); + } else if (au.get_family_id() == fid) ext = alloc(array::solver, *this, fid); else if (fpa.get_family_id() == fid) @@ -144,6 +155,8 @@ namespace euf { ext = alloc(dt::solver, *this, fid); else if (rf.get_family_id() == fid) ext = alloc(recfun::solver, *this); + else if (sp.get_family_id() == fid) + ext = alloc(specrel::solver, *this, fid); if (ext) add_solver(ext); @@ -205,6 +218,15 @@ namespace euf { s().assign(lit, sat::justification::mk_ext_justification(s().scope_lvl(), idx)); } + lbool solver::resolve_conflict() { + for (auto* s : m_solvers) { + lbool r = s->resolve_conflict(); + if (r != l_undef) + return r; + } + return l_undef; + } + /** Retrieve set of literals r that imply r. Since the set of literals are retrieved modulo multiple theories in a single implication @@ -227,64 +249,93 @@ namespace euf { */ void solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) { - m_egraph.begin_explain(); - m_explain.reset(); - if (use_drat() && !probing) { + bool create_hint = use_drat() && !probing; + if (create_hint) { push(restore_vector(m_explain_cc)); + m_hint_eqs.reset(); } auto* ext = sat::constraint_base::to_extension(idx); - th_proof_hint* hint = nullptr; - bool has_theory = false; - if (ext == this) - get_antecedents(l, constraint::from_idx(idx), r, probing); - else { + bool is_euf = ext == this; + bool multiple_theories = false; + + m_egraph.begin_explain(); + m_explain.reset(); + if (is_euf) + get_euf_antecedents(l, constraint::from_idx(idx), r, probing); + else ext->get_antecedents(l, idx, r, probing); - has_theory = true; - } + + unsigned ez = m_explain.size(); + for (unsigned qhead = 0; qhead < m_explain.size(); ++qhead) { size_t* e = m_explain[qhead]; if (is_literal(e)) r.push_back(get_literal(e)); else { + multiple_theories = true; size_t idx = get_justification(e); auto* ext = sat::constraint_base::to_extension(idx); SASSERT(ext != this); sat::literal lit = sat::null_literal; ext->get_antecedents(lit, idx, r, probing); - has_theory = true; } } - m_egraph.end_explain(); - if (use_drat() && !probing) - hint = mk_hint(has_theory ? m_smt : m_euf, l, r); - + m_egraph.end_explain(); + + CTRACE("euf", probing, tout << "explain " << l << " <- " << r << "\n"); unsigned j = 0; - for (sat::literal lit : r) - if (s().lvl(lit) > 0) r[j++] = lit; + for (auto lit : r) + if (s().lvl(lit) > 0) + r[j++] = lit; + bool reduced = j < r.size(); r.shrink(j); - CTRACE("euf", probing, tout << "explain " << l << " <- " << r << "\n"); + DEBUG_CODE(for (auto lit : r) SASSERT(s().value(lit) == l_true);); - if (!probing) - log_antecedents(l, r, hint); + if (create_hint) { + log_justifications(l, ez, is_euf); + if (l != sat::null_literal && (reduced || multiple_theories)) + log_rup(l, r); + } } - void solver::get_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing) { + void solver::get_eq_antecedents(enode* a, enode* b, literal_vector& r) { + m_egraph.begin_explain(); + m_explain.reset(); + m_egraph.explain_eq(m_explain, nullptr, a, b); + for (unsigned qhead = 0; qhead < m_explain.size(); ++qhead) { + size_t* e = m_explain[qhead]; + if (is_literal(e)) + r.push_back(get_literal(e)); + else { + size_t idx = get_justification(e); + auto* ext = sat::constraint_base::to_extension(idx); + SASSERT(ext != this); + sat::literal lit = sat::null_literal; + ext->get_antecedents(lit, idx, r, true); + } + } + m_egraph.end_explain(); + } + + + void solver::get_th_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing) { for (auto lit : euf::th_explain::lits(jst)) r.push_back(lit); for (auto eq : euf::th_explain::eqs(jst)) - add_antecedent(probing, eq.first, eq.second); + add_eq_antecedent(probing, eq.first, eq.second); if (!probing && use_drat()) log_justification(l, jst); } - void solver::add_antecedent(bool probing, enode* a, enode* b) { - cc_justification* cc = (!probing && use_drat()) ? &m_explain_cc : nullptr; - m_egraph.explain_eq(m_explain, cc, a, b); + void solver::add_eq_antecedent(bool probing, enode* a, enode* b) { + if (!probing && use_drat()) + m_hint_eqs.push_back({a, b}); + m_egraph.explain_eq(m_explain, nullptr, a, b); } - void solver::add_diseq_antecedent(ptr_vector& ex, cc_justification* cc, enode* a, enode* b) { + void solver::explain_diseq(ptr_vector& ex, cc_justification* cc, enode* a, enode* b) { sat::bool_var v = get_egraph().explain_diseq(ex, cc, a, b); SASSERT(v == sat::null_bool_var || s().value(v) == l_false); if (v != sat::null_bool_var) @@ -298,7 +349,7 @@ namespace euf { return true; } - void solver::get_antecedents(literal l, constraint& j, literal_vector& r, bool probing) { + void solver::get_euf_antecedents(literal l, constraint& j, literal_vector& r, bool probing) { expr* e = nullptr; euf::enode* n = nullptr; cc_justification* cc = nullptr; @@ -307,7 +358,7 @@ namespace euf { init_ackerman(); if (!probing && use_drat()) cc = &m_explain_cc; - + switch (j.kind()) { case constraint::kind_t::conflict: SASSERT(m_egraph.inconsistent()); @@ -333,8 +384,9 @@ namespace euf { bool_var v = ante->bool_var(); lbool val = ante->value(); SASSERT(val != l_undef); - literal ante(v, val == l_false); - m_explain.push_back(to_ptr(ante)); + literal ante_lit(v, val == l_false); + TRACE("euf", tout << "explain " << bpp(n) << " by " << bpp(ante) << "\n"); + m_explain.push_back(to_ptr(ante_lit)); } break; } @@ -421,6 +473,9 @@ namespace euf { } + bool solver::can_propagate() { + return m_egraph.can_propagate(); + } bool solver::unit_propagate() { bool propagated = false; @@ -463,6 +518,7 @@ namespace euf { SASSERT(m.is_bool(e)); size_t cnstr; literal lit; + if (!ante) { VERIFY(m.is_eq(e, a, b)); cnstr = eq_constraint().to_index(); @@ -480,7 +536,7 @@ namespace euf { if (val == l_undef) { SASSERT(m.is_value(ante->get_expr())); val = m.is_true(ante->get_expr()) ? l_true : l_false; - } + } auto& c = lit_constraint(ante); cnstr = c.to_index(); lit = literal(v, val == l_false); @@ -510,7 +566,7 @@ namespace euf { bool solver::is_self_propagated(th_eq const& e) { if (!e.is_eq()) return false; - + m_egraph.begin_explain(); m_explain.reset(); m_egraph.explain_eq(m_explain, nullptr, e.child(), e.root()); @@ -998,8 +1054,10 @@ namespace euf { return out << "euf conflict"; case constraint::kind_t::eq: return out << "euf equality propagation"; - case constraint::kind_t::lit: - return out << "euf literal propagation " << m_egraph.bpp(c.node()) ; + case constraint::kind_t::lit: { + euf::enode* n = c.node(); + return out << "euf literal propagation " << (sat::literal(n->bool_var(), n->value() == l_false)) << " " << m_egraph.bpp(n); + } default: UNREACHABLE(); return out; @@ -1047,14 +1105,14 @@ namespace euf { SASSERT(true_lit != sat::null_literal); return (void*)(r->to_ptr(true_lit)); }; - r->m_egraph.copy_from(m_egraph, copy_justification); r->set_solver(s); + r->m_egraph.copy_from(m_egraph, copy_justification); for (euf::enode* n : r->m_egraph.nodes()) { auto b = n->bool_var(); if (b != sat::null_bool_var) { r->m_bool_var2expr.setx(b, n->get_expr(), nullptr); SASSERT(r->m.is_bool(n->get_sort())); - IF_VERBOSE(11, verbose_stream() << "set bool_var " << b << " " << r->bpp(n) << " " << mk_bounded_pp(n->get_expr(), m) << "\n"); + IF_VERBOSE(20, verbose_stream() << "set bool_var " << b << " " << r->bpp(n) << " " << mk_bounded_pp(n->get_expr(), m) << "\n"); } } for (auto* s_orig : m_id2solver) { diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 72776b7ffa7..7d2d01473ea 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -147,6 +147,8 @@ namespace euf { ptr_vector m_bool_var2expr; ptr_vector m_explain; euf::cc_justification m_explain_cc; + enode_pair_vector m_hint_eqs; + sat::literal_vector m_hint_lits; unsigned m_num_scopes = 0; unsigned_vector m_var_trail; svector m_scopes; @@ -218,7 +220,7 @@ namespace euf { void propagate_literal(enode* n, enode* ante); void propagate_th_eqs(); bool is_self_propagated(th_eq const& e); - void get_antecedents(literal l, constraint& j, literal_vector& r, bool probing); + void get_euf_antecedents(literal l, constraint& j, literal_vector& r, bool probing); void new_diseq(enode* a, enode* b, literal lit); bool merge_shared_bools(); @@ -226,9 +228,11 @@ namespace euf { void log_antecedents(std::ostream& out, literal l, literal_vector const& r); void log_antecedents(literal l, literal_vector const& r, th_proof_hint* hint); void log_justification(literal l, th_explain const& jst); + void log_justifications(literal l, unsigned explain_size, bool is_euf); + void log_rup(literal l, literal_vector const& r); - eq_proof_hint* mk_hint(symbol const& th, literal lit, literal_vector const& r); + eq_proof_hint* mk_hint(symbol const& th, literal lit); @@ -344,6 +348,7 @@ namespace euf { bool is_external(bool_var v) override; bool propagated(literal l, ext_constraint_idx idx) override; bool unit_propagate() override; + bool can_propagate() override; bool should_research(sat::literal_vector const& core) override; void add_assumptions(sat::literal_set& assumptions) override; bool tracking_assumptions() override; @@ -358,14 +363,16 @@ namespace euf { bool propagate(enode* a, enode* b, th_explain* p) { return propagate(a, b, p->to_index()); } size_t* to_justification(sat::literal l) { return to_ptr(l); } void set_conflict(th_explain* p) { set_conflict(p->to_index()); } + bool inconsistent() const { return s().inconsistent() || m_egraph.inconsistent(); } bool set_root(literal l, literal r) override; void flush_roots() override; void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) override; - void get_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing); - void add_antecedent(bool probing, enode* a, enode* b); - void add_diseq_antecedent(ptr_vector& ex, cc_justification* cc, enode* a, enode* b); + void get_eq_antecedents(enode* a, enode* b, literal_vector& r); + void get_th_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing); + void add_eq_antecedent(bool probing, enode* a, enode* b); + void explain_diseq(ptr_vector& ex, cc_justification* cc, enode* a, enode* b); void add_explain(size_t* p) { m_explain.push_back(p); } void reset_explain() { m_explain.reset(); } void set_eliminated(bool_var v) override; @@ -373,6 +380,7 @@ namespace euf { bool get_case_split(bool_var& var, lbool& phase) override; void asserted(literal l) override; sat::check_result check() override; + lbool resolve_conflict() override; void push() override; void pop(unsigned n) override; void user_push() override; diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp new file mode 100644 index 00000000000..2c373f6b935 --- /dev/null +++ b/src/sat/smt/intblast_solver.cpp @@ -0,0 +1,1141 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + intblast_solver.cpp + +Author: + + Nikolaj Bjorner (nbjorner) 2023-12-10 + +--*/ + +#include "ast/ast_util.h" +#include "ast/for_each_expr.h" +#include "ast/rewriter/bv_rewriter.h" +#include "params/bv_rewriter_params.hpp" +#include "sat/smt/intblast_solver.h" +#include "sat/smt/euf_solver.h" +#include "sat/smt/arith_value.h" + + +namespace intblast { + + solver::solver(euf::solver& ctx) : + th_euf_solver(ctx, symbol("intblast"), ctx.get_manager().get_family_id("bv")), + ctx(ctx), + s(ctx.s()), + m(ctx.get_manager()), + bv(m), + a(m), + m_translate(m), + m_args(m), + m_pinned(m) + {} + + euf::theory_var solver::mk_var(euf::enode* n) { + auto r = euf::th_euf_solver::mk_var(n); + ctx.attach_th_var(n, this, r); + TRACE("bv", tout << "mk-var: v" << r << " " << ctx.bpp(n) << "\n";); + return r; + } + + sat::literal solver::internalize(expr* e, bool sign, bool root) { + force_push(); + SASSERT(m.is_bool(e)); + if (!visit_rec(m, e, sign, root)) + return sat::null_literal; + sat::literal lit = expr2literal(e); + if (sign) + lit.neg(); + return lit; + } + + void solver::internalize(expr* e) { + force_push(); + visit_rec(m, e, false, false); + } + + bool solver::visit(expr* e) { + if (!is_app(e) || to_app(e)->get_family_id() != get_id()) { + ctx.internalize(e); + return true; + } + m_stack.push_back(sat::eframe(e)); + return false; + } + + bool solver::visited(expr* e) { + euf::enode* n = expr2enode(e); + return n && n->is_attached_to(get_id()); + } + + + + bool solver::post_visit(expr* e, bool sign, bool root) { + euf::enode* n = expr2enode(e); + app* a = to_app(e); + if (visited(e)) + return true; + SASSERT(!n || !n->is_attached_to(get_id())); + if (!n) + n = mk_enode(e, false); + SASSERT(!n->is_attached_to(get_id())); + mk_var(n); + SASSERT(n->is_attached_to(get_id())); + internalize_bv(a); + return true; + } + + void solver::eq_internalized(euf::enode* n) { + expr* e = n->get_expr(); + expr* x, * y; + VERIFY(m.is_eq(n->get_expr(), x, y)); + SASSERT(bv.is_bv(x)); + if (!is_translated(e)) { + ensure_translated(x); + ensure_translated(y); + m_args.reset(); + m_args.push_back(a.mk_sub(translated(x), translated(y))); + set_translated(e, m.mk_eq(umod(x, 0), a.mk_int(0))); + } + m_preds.push_back(e); + ctx.push(push_back_vector(m_preds)); + } + + void solver::set_translated(expr* e, expr* r) { + SASSERT(r); + SASSERT(!is_translated(e)); + m_translate.setx(e->get_id(), r); + ctx.push(set_vector_idx_trail(m_translate, e->get_id())); + } + + void solver::internalize_bv(app* e) { + ensure_translated(e); + if (m.is_bool(e)) { + m_preds.push_back(e); + ctx.push(push_back_vector(m_preds)); + } + } + + bool solver::add_bound_axioms() { + if (m_vars_qhead == m_vars.size()) + return false; + ctx.push(value_trail(m_vars_qhead)); + for (; m_vars_qhead < m_vars.size(); ++m_vars_qhead) { + auto v = m_vars[m_vars_qhead]; + auto w = translated(v); + auto sz = rational::power_of_two(bv.get_bv_size(v->get_sort())); + auto lo = ctx.mk_literal(a.mk_ge(w, a.mk_int(0))); + auto hi = ctx.mk_literal(a.mk_le(w, a.mk_int(sz - 1))); + ctx.mark_relevant(lo); + ctx.mark_relevant(hi); + add_unit(lo); + add_unit(hi); + } + return true; + } + + bool solver::add_predicate_axioms() { + if (m_preds_qhead == m_preds.size()) + return false; + ctx.push(value_trail(m_preds_qhead)); + for (; m_preds_qhead < m_preds.size(); ++m_preds_qhead) { + expr* e = m_preds[m_preds_qhead]; + expr_ref r(translated(e), m); + ctx.get_rewriter()(r); + auto a = expr2literal(e); + auto b = mk_literal(r); + ctx.mark_relevant(b); +// verbose_stream() << "add-predicate-axiom: " << mk_pp(e, m) << " == " << r << "\n"; + add_equiv(a, b); + } + return true; + } + + bool solver::unit_propagate() { + return add_bound_axioms() || add_predicate_axioms(); + } + void solver::ensure_translated(expr* e) { + if (m_translate.get(e->get_id(), nullptr)) + return; + ptr_vector todo; + ast_fast_mark1 visited; + todo.push_back(e); + visited.mark(e); + for (unsigned i = 0; i < todo.size(); ++i) { + expr* e = todo[i]; + if (!is_app(e)) + continue; + app* a = to_app(e); + if (m.is_bool(e) && a->get_family_id() != bv.get_family_id()) + continue; + for (auto arg : *a) + if (!visited.is_marked(arg) && !m_translate.get(arg->get_id(), nullptr)) { + visited.mark(arg); + todo.push_back(arg); + } + } + std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); + for (expr* e : todo) + translate_expr(e); + } + + lbool solver::check_axiom(sat::literal_vector const& lits) { + sat::literal_vector core; + for (auto lit : lits) + core.push_back(~lit); + return check_core(core, {}); + } + lbool solver::check_propagation(sat::literal lit, sat::literal_vector const& lits, euf::enode_pair_vector const& eqs) { + sat::literal_vector core; + core.append(lits); + core.push_back(~lit); + return check_core(core, eqs); + } + + lbool solver::check_core(sat::literal_vector const& lits, euf::enode_pair_vector const& eqs) { + m_core.reset(); + m_vars.reset(); + m_is_plugin = false; + m_solver = mk_smt2_solver(m, s.params(), symbol::null); + + for (unsigned i = 0; i < m_translate.size(); ++i) + m_translate[i] = nullptr; + + expr_ref_vector es(m), original_es(m); + for (auto lit : lits) + es.push_back(ctx.literal2expr(lit)); + for (auto [a, b] : eqs) + es.push_back(m.mk_eq(a->get_expr(), b->get_expr())); + + original_es.append(es); + + lbool r; + if (false) { + r = m_solver->check_sat(es); + } + else { + + translate(es); + + for (auto e : m_vars) { + auto v = translated(e); + auto b = rational::power_of_two(bv.get_bv_size(e)); + m_solver->assert_expr(a.mk_le(a.mk_int(0), v)); + m_solver->assert_expr(a.mk_lt(v, a.mk_int(b))); + } + + for (unsigned i = 0; i < es.size(); ++i) { + expr_ref tmp(es.get(i), m); + ctx.get_rewriter()(tmp); + es[i] = tmp; + } + + IF_VERBOSE(2, verbose_stream() << "check\n" << original_es << "\n"); + + IF_VERBOSE(2, + { + m_solver->push(); + m_solver->assert_expr(es); + m_solver->display(verbose_stream()) << "(check-sat)\n"; + m_solver->pop(1); + }); + + + r = m_solver->check_sat(es); + } + + m_solver->collect_statistics(m_stats); + + IF_VERBOSE(2, verbose_stream() << "(sat.intblast :result " << r << ")\n"); + if (r == l_true) { + IF_VERBOSE(0, + model_ref mdl; + m_solver->get_model(mdl); + verbose_stream() << original_es << "\n"; + verbose_stream() << *mdl << "\n"; + verbose_stream() << es << "\n"; + m_solver->display(verbose_stream());); + SASSERT(false); + } + + m_solver = nullptr; + + return r; + } + + + + lbool solver::check_solver_state() { + sat::literal_vector literals; + uint_set selected; + for (auto const& clause : s.clauses()) { + if (any_of(*clause, [&](auto lit) { return selected.contains(lit.index()); })) + continue; + if (any_of(*clause, [&](auto lit) { return s.value(lit) == l_true && !is_bv(lit); })) + continue; + // TBD: if we associate "status" with clauses, we can also remove theory axioms from polysat + sat::literal selected_lit = sat::null_literal; + for (auto lit : *clause) { + if (s.value(lit) != l_true) + continue; + SASSERT(is_bv(lit)); + if (selected_lit == sat::null_literal || s.lvl(selected_lit) > s.lvl(lit)) + selected_lit = lit; + } + if (selected_lit == sat::null_literal) { + UNREACHABLE(); + return l_undef; + } + selected.insert(selected_lit.index()); + literals.push_back(selected_lit); + } + unsigned trail_sz = s.init_trail_size(); + for (unsigned i = 0; i < trail_sz; ++i) { + auto lit = s.trail_literal(i); + if (selected.contains(lit.index()) || !is_bv(lit)) + continue; + selected.insert(lit.index()); + literals.push_back(lit); + } + svector> bin; + s.collect_bin_clauses(bin, false, false); + for (auto [a, b] : bin) { + if (selected.contains(a.index())) + continue; + if (selected.contains(b.index())) + continue; + if (s.value(a) == l_true && !is_bv(a)) + continue; + if (s.value(b) == l_true && !is_bv(b)) + continue; + if (s.value(a) == l_false) + std::swap(a, b); + if (s.value(b) == l_true && s.value(a) == l_true && s.lvl(b) < s.lvl(a)) + std::swap(a, b); + selected.insert(a.index()); + literals.push_back(a); + } + + m_core.reset(); + m_is_plugin = false; + m_solver = mk_smt2_solver(m, s.params(), symbol::null); + + expr_ref_vector es(m); + for (auto lit : literals) + es.push_back(ctx.literal2expr(lit)); + + translate(es); + + for (auto e : m_vars) { + auto v = translated(e); + auto b = rational::power_of_two(bv.get_bv_size(e)); + m_solver->assert_expr(a.mk_le(a.mk_int(0), v)); + m_solver->assert_expr(a.mk_lt(v, a.mk_int(b))); + } + + IF_VERBOSE(10, verbose_stream() << "check\n"; + m_solver->display(verbose_stream()); + verbose_stream() << es << "\n"); + + lbool r = m_solver->check_sat(es); + + m_solver->collect_statistics(m_stats); + + IF_VERBOSE(2, verbose_stream() << "(sat.intblast :result " << r << ")\n"); + + if (r == l_false) { + expr_ref_vector core(m); + m_solver->get_unsat_core(core); + obj_map e2index; + for (unsigned i = 0; i < es.size(); ++i) + e2index.insert(es.get(i), i); + for (auto e : core) { + unsigned idx = e2index[e]; + if (idx < literals.size()) + m_core.push_back(literals[idx]); + else + m_core.push_back(ctx.mk_literal(e)); + } + } + return r; + }; + + bool solver::is_bv(sat::literal lit) { + expr* e = ctx.bool_var2expr(lit.var()); + if (!e) + return false; + if (m.is_and(e) || m.is_or(e) || m.is_not(e) || m.is_implies(e) || m.is_iff(e)) + return false; + return any_of(subterms::all(expr_ref(e, m)), [&](auto* p) { return bv.is_bv_sort(p->get_sort()); }); + } + + void solver::sorted_subterms(expr_ref_vector& es, ptr_vector& sorted) { + expr_fast_mark1 visited; + for (expr* e : es) { + if (is_translated(e)) + continue; + if (visited.is_marked(e)) + continue; + sorted.push_back(e); + visited.mark(e); + } + for (unsigned i = 0; i < sorted.size(); ++i) { + expr* e = sorted[i]; + if (is_app(e)) { + app* a = to_app(e); + for (expr* arg : *a) { + if (!visited.is_marked(arg) && !is_translated(arg)) { + visited.mark(arg); + sorted.push_back(arg); + } + } + + } + else if (is_quantifier(e)) { + quantifier* q = to_quantifier(e); + expr* b = q->get_expr(); + if (!visited.is_marked(b) && !is_translated(b)) { + visited.mark(b); + sorted.push_back(b); + } + } + } + std::stable_sort(sorted.begin(), sorted.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); + } + + void solver::translate(expr_ref_vector& es) { + ptr_vector todo; + + sorted_subterms(es, todo); + + for (expr* e : todo) + translate_expr(e); + + TRACE("bv", + for (expr* e : es) + tout << mk_pp(e, m) << "\n->\n" << mk_pp(translated(e), m) << "\n"; + ); + + for (unsigned i = 0; i < es.size(); ++i) + es[i] = translated(es.get(i)); + } + + sat::check_result solver::check() { + // ensure that bv2int is injective + for (auto e : m_bv2int) { + euf::enode* n = expr2enode(e); + euf::enode* r1 = n->get_arg(0)->get_root(); + for (auto sib : euf::enode_class(n)) { + if (sib == n) + continue; + if (!bv.is_bv2int(sib->get_expr())) + continue; + if (sib->get_arg(0)->get_root() == r1) + continue; + auto a = eq_internalize(n, sib); + auto b = eq_internalize(sib->get_arg(0), n->get_arg(0)); + ctx.mark_relevant(a); + ctx.mark_relevant(b); + add_clause(~a, b, nullptr); + return sat::check_result::CR_CONTINUE; + } + } + // ensure that int2bv respects values + // bv2int(int2bv(x)) = x mod N + for (auto e : m_int2bv) { + auto n = expr2enode(e); + auto x = n->get_arg(0)->get_expr(); + auto bv2int = bv.mk_bv2int(e); + ctx.internalize(bv2int); + auto N = rational::power_of_two(bv.get_bv_size(e)); + auto xModN = a.mk_mod(x, a.mk_int(N)); + ctx.internalize(xModN); + auto nBv2int = ctx.get_enode(bv2int); + auto nxModN = ctx.get_enode(xModN); + if (nBv2int->get_root() != nxModN->get_root()) { + auto a = eq_internalize(nBv2int, nxModN); + ctx.mark_relevant(a); + add_unit(a); + return sat::check_result::CR_CONTINUE; + } + } + return sat::check_result::CR_DONE; + } + + bool solver::is_bounded(expr* x, rational const& N) { + return any_of(m_vars, [&](expr* v) { + return is_translated(v) && translated(v) == x && bv.get_bv_size(v) <= N; + }); + } + + bool solver::is_non_negative(expr* bv_expr, expr* e) { + auto N = rational::power_of_two(bv.get_bv_size(bv_expr)); + rational r; + if (a.is_numeral(e, r)) + return r >= 0; + if (is_bounded(e, N)) + return true; + expr* x, * y; + if (a.is_mul(e, x, y)) + return is_non_negative(bv_expr, x) && is_non_negative(bv_expr, y); + if (a.is_add(e, x, y)) + return is_non_negative(bv_expr, x) && is_non_negative(bv_expr, y); + return false; + } + + expr* solver::umod(expr* bv_expr, unsigned i) { + expr* x = arg(i); + rational N = bv_size(bv_expr); + return amod(bv_expr, x, N); + } + + expr* solver::smod(expr* bv_expr, unsigned i) { + expr* x = arg(i); + auto N = bv_size(bv_expr); + auto shift = N / 2; + rational r; + if (a.is_numeral(x, r)) + return a.mk_int(mod(r + shift, N)); + return amod(bv_expr, add(x, a.mk_int(shift)), N); + } + + expr_ref solver::mul(expr* x, expr* y) { + expr_ref _x(x, m), _y(y, m); + if (a.is_zero(x)) + return _x; + if (a.is_zero(y)) + return _y; + if (a.is_one(x)) + return _y; + if (a.is_one(y)) + return _x; + rational v1, v2; + if (a.is_numeral(x, v1) && a.is_numeral(y, v2)) + return expr_ref(a.mk_int(v1 * v2), m); + _x = a.mk_mul(x, y); + return _x; + } + + expr_ref solver::add(expr* x, expr* y) { + expr_ref _x(x, m), _y(y, m); + if (a.is_zero(x)) + return _y; + if (a.is_zero(y)) + return _x; + rational v1, v2; + if (a.is_numeral(x, v1) && a.is_numeral(y, v2)) + return expr_ref(a.mk_int(v1 + v2), m); + _x = a.mk_add(x, y); + return _x; + } + + /* + * Perform simplifications that are claimed sound when the bit-vector interpretations of + * mod/div always guard the mod and dividend to be non-zero. + * Potentially shady area is for arithmetic expressions created by int2bv. + * They will be guarded by a modulus which dose not disappear. + */ + expr* solver::amod(expr* bv_expr, expr* x, rational const& N) { + rational v; + expr* r, *c, * t, * e; + if (m.is_ite(x, c, t, e)) + r = m.mk_ite(c, amod(bv_expr, t, N), amod(bv_expr, e, N)); + else if (a.is_idiv(x, t, e) && a.is_numeral(t, v) && 0 <= v && v < N && is_non_negative(bv_expr, e)) + r = x; + else if (a.is_mod(x, t, e) && a.is_numeral(t, v) && 0 <= v && v < N) + r = x; + else if (a.is_numeral(x, v)) + r = a.mk_int(mod(v, N)); + else if (is_bounded(x, N)) + r = x; + else + r = a.mk_mod(x, a.mk_int(N)); + return r; + } + + rational solver::bv_size(expr* bv_expr) { + return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); + } + + void solver::translate_expr(expr* e) { + if (is_quantifier(e)) + translate_quantifier(to_quantifier(e)); + else if (is_var(e)) + translate_var(to_var(e)); + else { + app* ap = to_app(e); + if (m_is_plugin && ap->get_family_id() == basic_family_id && m.is_bool(ap)) { + set_translated(e, e); + return; + } + m_args.reset(); + for (auto arg : *ap) + m_args.push_back(translated(arg)); + + if (ap->get_family_id() == basic_family_id) + translate_basic(ap); + else if (ap->get_family_id() == bv.get_family_id()) + translate_bv(ap); + else + translate_app(ap); + } + } + + void solver::translate_quantifier(quantifier* q) { + if (is_lambda(q)) + throw default_exception("lambdas are not supported in intblaster"); + if (m_is_plugin) { + set_translated(q, q); + return; + } + expr* b = q->get_expr(); + unsigned nd = q->get_num_decls(); + ptr_vector sorts; + for (unsigned i = 0; i < nd; ++i) { + auto s = q->get_decl_sort(i); + if (bv.is_bv_sort(s)) { + NOT_IMPLEMENTED_YET(); + sorts.push_back(a.mk_int()); + } + else + + sorts.push_back(s); + } + b = translated(b); + // TODO if sorts contain integer, then created bounds variables. + set_translated(q, m.update_quantifier(q, b)); + } + + void solver::translate_var(var* v) { + if (bv.is_bv_sort(v->get_sort())) + set_translated(v, m.mk_var(v->get_idx(), a.mk_int())); + else + set_translated(v, v); + } + + // Translate functions that are not built-in or bit-vectors. + // Base method uses fresh functions. + // Other method could use bv2int, int2bv axioms and coercions. + // f(args) = bv2int(f(int2bv(args')) + // + + void solver::translate_app(app* e) { + + if (m_is_plugin && m.is_bool(e)) { + set_translated(e, e); + return; + } + + bool has_bv_sort = bv.is_bv(e); + func_decl* f = e->get_decl(); + + for (unsigned i = 0; i < m_args.size(); ++i) + if (bv.is_bv(e->get_arg(i))) + m_args[i] = bv.mk_int2bv(bv.get_bv_size(e->get_arg(i)), m_args.get(i)); + + if (has_bv_sort) + m_vars.push_back(e); + if (m_is_plugin) { + expr* r = m.mk_app(f, m_args); + if (has_bv_sort) { + ctx.push(push_back_vector(m_vars)); + r = bv.mk_bv2int(r); + } + set_translated(e, r); + return; + } + else if (has_bv_sort) { + if (f->get_family_id() != null_family_id) + throw default_exception("conversion for interpreted functions is not supported by intblast solver"); + func_decl* g = nullptr; + if (!m_new_funs.find(f, g)) { + g = m.mk_fresh_func_decl(e->get_decl()->get_name(), symbol("bv"), f->get_arity(), f->get_domain(), a.mk_int()); + m_new_funs.insert(f, g); + } + f = g; + m_pinned.push_back(f); + } + set_translated(e, m.mk_app(f, m_args)); + } + + void solver::translate_bv(app* e) { + + auto bnot = [&](expr* e) { + return a.mk_sub(a.mk_int(-1), e); + }; + + auto band = [&](expr_ref_vector const& args) { + expr* r = arg(0); + for (unsigned i = 1; i < args.size(); ++i) + r = a.mk_band(bv.get_bv_size(e), r, arg(i)); + return r; + }; + + auto rotate_left = [&](unsigned n) { + auto sz = bv.get_bv_size(e); + n = n % sz; + expr* r = arg(0); + if (n != 0 && sz != 1) { + // r[sz - n - 1 : 0] ++ r[sz - 1 : sz - n] + // r * 2^(sz - n) + (r div 2^n) mod 2^(sz - n)??? + // r * A + (r div B) mod A + auto N = bv_size(e); + auto A = rational::power_of_two(sz - n); + auto B = rational::power_of_two(n); + auto hi = mul(r, a.mk_int(A)); + auto lo = amod(e, a.mk_idiv(umod(e, 0), a.mk_int(B)), A); + r = add(hi, lo); + } + return r; + }; + + expr* bv_expr = e; + expr_ref r(m); + auto const& args = m_args; + switch (e->get_decl_kind()) { + case OP_BADD: + r = a.mk_add(args); + break; + case OP_BSUB: + r = a.mk_sub(args.size(), args.data()); + break; + case OP_BMUL: + r = a.mk_mul(args); + break; + case OP_ULEQ: + bv_expr = e->get_arg(0); + r = a.mk_le(umod(bv_expr, 0), umod(bv_expr, 1)); + break; + case OP_UGEQ: + bv_expr = e->get_arg(0); + r = a.mk_ge(umod(bv_expr, 0), umod(bv_expr, 1)); + break; + case OP_ULT: + bv_expr = e->get_arg(0); + r = a.mk_lt(umod(bv_expr, 0), umod(bv_expr, 1)); + break; + case OP_UGT: + bv_expr = e->get_arg(0); + r = a.mk_gt(umod(bv_expr, 0), umod(bv_expr, 1)); + break; + case OP_SLEQ: + bv_expr = e->get_arg(0); + r = a.mk_le(smod(bv_expr, 0), smod(bv_expr, 1)); + break; + case OP_SGEQ: + r = a.mk_ge(smod(bv_expr, 0), smod(bv_expr, 1)); + break; + case OP_SLT: + bv_expr = e->get_arg(0); + r = a.mk_lt(smod(bv_expr, 0), smod(bv_expr, 1)); + break; + case OP_SGT: + bv_expr = e->get_arg(0); + r = a.mk_gt(smod(bv_expr, 0), smod(bv_expr, 1)); + break; + case OP_BNEG: + r = a.mk_uminus(arg(0)); + break; + case OP_CONCAT: { + unsigned sz = 0; + expr_ref new_arg(m); + for (unsigned i = args.size(); i-- > 0;) { + expr* old_arg = e->get_arg(i); + new_arg = umod(old_arg, i); + if (sz > 0) { + new_arg = mul(new_arg, a.mk_int(rational::power_of_two(sz))); + r = add(r, new_arg); + } + else + r = new_arg; + sz += bv.get_bv_size(old_arg->get_sort()); + } + break; + } + case OP_EXTRACT: { + unsigned lo, hi; + expr* old_arg; + VERIFY(bv.is_extract(e, lo, hi, old_arg)); + r = arg(0); + if (lo > 0) + r = a.mk_idiv(r, a.mk_int(rational::power_of_two(lo))); + break; + } + case OP_BV_NUM: { + rational val; + unsigned sz; + VERIFY(bv.is_numeral(e, val, sz)); + r = a.mk_int(val); + break; + } + case OP_BUREM: + case OP_BUREM_I: { + expr* x = umod(e, 0), * y = umod(e, 1); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, a.mk_mod(x, y)); + break; + } + case OP_BUDIV: + case OP_BUDIV_I: { + expr* x = arg(0), * y = umod(e, 1); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(-1), a.mk_idiv(x, y)); + break; + } + case OP_BUMUL_NO_OVFL: { + bv_expr = e->get_arg(0); + r = a.mk_lt(mul(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(bv_size(bv_expr))); + break; + } + case OP_BSHL: { + if (!a.is_numeral(arg(0)) && !a.is_numeral(arg(1))) + r = a.mk_shl(bv.get_bv_size(e), arg(0),arg(1)); + else { + expr* x = arg(0), * y = umod(e, 1); + r = a.mk_int(0); + IF_VERBOSE(2, verbose_stream() << "shl " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), mul(x, a.mk_int(rational::power_of_two(i))), r); + } + break; + } + case OP_BNOT: + r = bnot(arg(0)); + break; + case OP_BLSHR: + if (!a.is_numeral(arg(0)) && !a.is_numeral(arg(1))) + r = a.mk_lshr(bv.get_bv_size(e), arg(0), arg(1)); + else { + expr* x = arg(0), * y = umod(e, 1); + r = a.mk_int(0); + IF_VERBOSE(2, verbose_stream() << "lshr " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); + } + break; + case OP_BASHR: + if (!a.is_numeral(arg(1))) + r = a.mk_ashr(bv.get_bv_size(e), arg(0), arg(1)); + else { + + // + // ashr(x, y) + // if y = k & x >= 0 -> x / 2^k + // if y = k & x < 0 -> (x / 2^k) - 2^{N-k} + // + unsigned sz = bv.get_bv_size(e); + rational N = bv_size(e); + expr* x = umod(e, 0), *y = umod(e, 1); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + r = m.mk_ite(signx, a.mk_int(- 1), a.mk_int(0)); + IF_VERBOSE(1, verbose_stream() << "ashr " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + for (unsigned i = 0; i < sz; ++i) { + expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), + m.mk_ite(signx, add(d, a.mk_int(- rational::power_of_two(sz-i))), d), + r); + } + } + break; + case OP_BOR: { + // p | q := (p + q) - band(p, q) + IF_VERBOSE(2, verbose_stream() << "bor " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + r = arg(0); + for (unsigned i = 1; i < args.size(); ++i) + r = a.mk_sub(add(r, arg(i)), a.mk_band(bv.get_bv_size(e), r, arg(i))); + break; + } + case OP_BNAND: + r = bnot(band(args)); + break; + case OP_BAND: + IF_VERBOSE(2, verbose_stream() << "band " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + r = band(args); + break; + case OP_BXNOR: + case OP_BXOR: { + // p ^ q := (p + q) - 2*band(p, q); + unsigned sz = bv.get_bv_size(e); + IF_VERBOSE(2, verbose_stream() << "bxor " << bv.get_bv_size(e) << "\n"); + r = arg(0); + for (unsigned i = 1; i < args.size(); ++i) { + expr* q = arg(i); + r = a.mk_sub(add(r, q), mul(a.mk_int(2), a.mk_band(sz, r, q))); + } + if (e->get_decl_kind() == OP_BXNOR) + r = bnot(r); + break; + } + case OP_ZERO_EXT: + bv_expr = e->get_arg(0); + r = umod(bv_expr, 0); + SASSERT(bv.get_bv_size(e) >= bv.get_bv_size(bv_expr)); + break; + case OP_SIGN_EXT: { + bv_expr = e->get_arg(0); + r = umod(bv_expr, 0); + SASSERT(bv.get_bv_size(e) >= bv.get_bv_size(bv_expr)); + unsigned arg_sz = bv.get_bv_size(bv_expr); + unsigned sz = bv.get_bv_size(e); + rational N = rational::power_of_two(sz); + rational M = rational::power_of_two(arg_sz); + expr* signbit = a.mk_ge(r, a.mk_int(M / 2)); + r = m.mk_ite(signbit, a.mk_uminus(r), r); + break; + } + case OP_INT2BV: + m_int2bv.push_back(e); + ctx.push(push_back_vector(m_int2bv)); + r = arg(0); + break; + case OP_BV2INT: + m_bv2int.push_back(e); + ctx.push(push_back_vector(m_bv2int)); + r = umod(e->get_arg(0), 0); + break; + case OP_BCOMP: + bv_expr = e->get_arg(0); + r = m.mk_ite(m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(1), a.mk_int(0)); + break; + case OP_BSMOD_I: + case OP_BSMOD: { + expr* x = umod(e, 0), *y = umod(e, 1); + rational N = bv_size(e); + expr* signx = a.mk_ge(x, a.mk_int(N/2)); + expr* signy = a.mk_ge(y, a.mk_int(N/2)); + expr* u = a.mk_mod(x, y); + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u + r = a.mk_uminus(u); + r = m.mk_ite(m.mk_and(m.mk_not(signx), signy), add(u, y), r); + r = m.mk_ite(m.mk_and(signx, m.mk_not(signy)), a.mk_sub(y, u), r); + r = m.mk_ite(m.mk_and(m.mk_not(signx), m.mk_not(signy)), u, r); + r = m.mk_ite(m.mk_eq(u, a.mk_int(0)), a.mk_int(0), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); + break; + } + case OP_BSDIV_I: + case OP_BSDIV: { + // d = udiv(abs(x), abs(y)) + // y = 0, x > 0 -> 1 + // y = 0, x <= 0 -> -1 + // x = 0, y != 0 -> 0 + // x > 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + expr* x = umod(e, 0), * y = umod(e, 1); + rational N = bv_size(e); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); + x = m.mk_ite(signx, a.mk_sub(a.mk_int(N), x), x); + y = m.mk_ite(signy, a.mk_sub(a.mk_int(N), y), y); + expr* d = a.mk_idiv(x, y); + r = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), m.mk_ite(signx, a.mk_int(1), a.mk_int(-1)), r); + break; + } + case OP_BSREM_I: + case OP_BSREM: { + // y = 0 -> x + // else x - sdiv(x, y) * y + expr* x = umod(e, 0), * y = umod(e, 1); + rational N = bv_size(e); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); + expr* absx = m.mk_ite(signx, a.mk_sub(a.mk_int(N), x), x); + expr* absy = m.mk_ite(signy, a.mk_sub(a.mk_int(N), y), y); + expr* d = a.mk_idiv(absx, absy); + d = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); + r = a.mk_sub(x, mul(d, y)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); + break; + } + case OP_ROTATE_LEFT: { + auto n = e->get_parameter(0).get_int(); + r = rotate_left(n); + break; + } + case OP_ROTATE_RIGHT: { + unsigned sz = bv.get_bv_size(e); + auto n = e->get_parameter(0).get_int(); + r = rotate_left(sz - n); + break; + } + case OP_EXT_ROTATE_LEFT: { + unsigned sz = bv.get_bv_size(e); + expr* y = umod(e, 1); + r = a.mk_int(0); + for (unsigned i = 0; i < sz; ++i) + r = m.mk_ite(m.mk_eq(a.mk_int(i), y), rotate_left(i), r); + break; + } + case OP_EXT_ROTATE_RIGHT: { + unsigned sz = bv.get_bv_size(e); + expr* y = umod(e, 1); + r = a.mk_int(0); + for (unsigned i = 0; i < sz; ++i) + r = m.mk_ite(m.mk_eq(a.mk_int(i), y), rotate_left(sz - i), r); + break; + } + case OP_REPEAT: { + unsigned n = e->get_parameter(0).get_int(); + expr* x = umod(e->get_arg(0), 0); + r = x; + rational N = bv_size(e->get_arg(0)); + rational N0 = N; + for (unsigned i = 1; i < n; ++i) + r = add(mul(a.mk_int(N), x), r), N *= N0; + break; + } + case OP_BREDOR: { + r = umod(e->get_arg(0), 0); + r = m.mk_not(m.mk_eq(r, a.mk_int(0))); + break; + } + case OP_BREDAND: { + rational N = bv_size(e->get_arg(0)); + r = umod(e->get_arg(0), 0); + r = m.mk_not(m.mk_eq(r, a.mk_int(N - 1))); + break; + } + default: + verbose_stream() << mk_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + } + set_translated(e, r); + } + + void solver::translate_basic(app* e) { + if (m.is_eq(e)) { + bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); + if (has_bv_arg) { + expr* bv_expr = e->get_arg(0); + rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); + if (a.is_numeral(arg(0)) || a.is_numeral(arg(1)) || + is_bounded(arg(0), N) || is_bounded(arg(1), N)) { + set_translated(e, m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1))); + } + else { + m_args[0] = a.mk_sub(arg(0), arg(1)); + set_translated(e, m.mk_eq(umod(bv_expr, 0), a.mk_int(0))); + } + } + else + set_translated(e, m.mk_eq(arg(0), arg(1))); + } + else if (m.is_ite(e)) + set_translated(e, m.mk_ite(arg(0), arg(1), arg(2))); + else if (m_is_plugin) + set_translated(e, e); + else + set_translated(e, m.mk_app(e->get_decl(), m_args)); + } + + rational solver::get_value(expr* e) const { + SASSERT(bv.is_bv(e)); + model_ref mdl; + m_solver->get_model(mdl); + expr_ref r(m); + r = translated(e); + rational val; + if (!mdl->eval_expr(r, r, true)) + return rational::zero(); + if (!a.is_numeral(r, val)) + return rational::zero(); + return val; + } + + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + if (m_is_plugin) + add_value_plugin(n, mdl, values); + else + add_value_solver(n, mdl, values); + } + + bool solver::add_dep(euf::enode* n, top_sort& dep) { + if (!is_app(n->get_expr())) + return false; + app* e = to_app(n->get_expr()); + if (n->num_args() == 0) { + dep.insert(n, nullptr); + return true; + } + if (e->get_family_id() != bv.get_family_id()) + return false; + for (euf::enode* arg : euf::enode_args(n)) + dep.add(n, arg->get_root()); + return true; + } + + // TODO: handle dependencies properly by using arithmetical model to retrieve values of translated + // bit-vectors directly. + void solver::add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values) { + expr* e = n->get_expr(); + SASSERT(bv.is_bv(e)); + if (bv.is_numeral(e)) { + values.setx(n->get_root_id(), e); + return; + } + + rational r, N = rational::power_of_two(bv.get_bv_size(e)); + expr* te = translated(e); + model_ref mdlr; + m_solver->get_model(mdlr); + expr_ref value(m); + if (mdlr->eval_expr(te, value, true) && a.is_numeral(value, r)) { + values.setx(n->get_root_id(), bv.mk_numeral(mod(r, N), bv.get_bv_size(e))); + return; + } + ctx.s().display(verbose_stream()); + verbose_stream() << "failed to evaluate " << mk_pp(te, m) << " " << value << "\n"; + UNREACHABLE(); + } + + void solver::add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values) { + expr_ref value(m); + if (n->interpreted()) + value = n->get_expr(); + else if (to_app(n->get_expr())->get_family_id() == bv.get_family_id()) { + bv_rewriter rw(m); + expr_ref_vector args(m); + for (auto arg : euf::enode_args(n)) + args.push_back(values.get(arg->get_root_id())); + rw.mk_app(n->get_decl(), args.size(), args.data(), value); + } + else { + expr_ref bv2int(bv.mk_bv2int(n->get_expr()), m); + euf::enode* b2i = ctx.get_enode(bv2int); + if (!b2i) verbose_stream() << bv2int << "\n"; + SASSERT(b2i); + VERIFY(b2i); + arith::arith_value av(ctx); + rational r; + VERIFY(av.get_value(b2i->get_expr(), r)); + value = bv.mk_numeral(r, bv.get_bv_size(n->get_expr())); + verbose_stream() << ctx.bpp(n) << " := " << value << "\n"; + } + values.set(n->get_root_id(), value); + TRACE("model", tout << "add_value " << ctx.bpp(n) << " := " << value << "\n"); + } + + sat::literal_vector const& solver::unsat_core() { + return m_core; + } + + std::ostream& solver::display(std::ostream& out) const { + if (m_solver) + m_solver->display(out); + return out; + } + + void solver::collect_statistics(statistics& st) const { + st.copy(m_stats); + } + +} diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h new file mode 100644 index 00000000000..0aceb8b2bb6 --- /dev/null +++ b/src/sat/smt/intblast_solver.h @@ -0,0 +1,152 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + intblast_solver.h + +Abstract: + + Int-blast solver. + + check_solver_state assumes a full assignment to literals in + irredundant clauses. + It picks a satisfying Boolean assignment and + checks if it is feasible for bit-vectors using + an arithmetic solver. + + The solver plugin is self-contained. + + Internalize: + - internalize bit-vector terms bottom-up by updating m_translate. + - add axioms of the form: + - ule(b,a) <=> translate(ule(b, a)) + - let arithmetic solver handle bit-vector constraints. + - For shared b + - Ensure: int2bv(translate(b)) = b + - but avoid bit-blasting by ensuring int2bv is injective (mod N) during final check + +Author: + + Nikolaj Bjorner (nbjorner) 2023-12-10 + +--*/ +#pragma once + +#include "ast/arith_decl_plugin.h" +#include "ast/bv_decl_plugin.h" +#include "solver/solver.h" +#include "sat/smt/sat_th.h" +#include "util/statistics.h" + +namespace euf { + class solver; +} + +namespace intblast { + + class solver : public euf::th_euf_solver { + euf::solver& ctx; + sat::solver& s; + ast_manager& m; + bv_util bv; + arith_util a; + scoped_ptr<::solver> m_solver; + obj_map m_new_funs; + expr_ref_vector m_translate, m_args; + ast_ref_vector m_pinned; + sat::literal_vector m_core; + ptr_vector m_bv2int, m_int2bv; + statistics m_stats; + bool m_is_plugin = true; // when the solver is used as a plugin, then do not translate below quantifiers. + + bool is_bv(sat::literal lit); + void translate(expr_ref_vector& es); + void sorted_subterms(expr_ref_vector& es, ptr_vector& sorted); + + + + bool is_translated(expr* e) const { return !!m_translate.get(e->get_id(), nullptr); } + expr* translated(expr* e) const { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } + void set_translated(expr* e, expr* r); + expr* arg(unsigned i) { return m_args.get(i); } + + expr* umod(expr* bv_expr, unsigned i); + expr* smod(expr* bv_expr, unsigned i); + bool is_bounded(expr* v, rational const& N); + bool is_non_negative(expr* bv_expr, expr* e); + expr_ref mul(expr* x, expr* y); + expr_ref add(expr* x, expr* y); + expr* amod(expr* bv_expr, expr* x, rational const& N); + rational bv_size(expr* bv_expr); + + void translate_expr(expr* e); + void translate_bv(app* e); + void translate_basic(app* e); + void translate_app(app* e); + void translate_quantifier(quantifier* q); + void translate_var(var* v); + + void ensure_translated(expr* e); + void internalize_bv(app* e); + + unsigned m_vars_qhead = 0, m_preds_qhead = 0; + ptr_vector m_vars, m_preds; + bool add_bound_axioms(); + bool add_predicate_axioms(); + + euf::theory_var mk_var(euf::enode* n) override; + + void add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values); + void add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values); + + public: + solver(euf::solver& ctx); + + ~solver() override {} + + lbool check_axiom(sat::literal_vector const& lits); + lbool check_core(sat::literal_vector const& lits, euf::enode_pair_vector const& eqs); + lbool check_propagation(sat::literal lit, sat::literal_vector const& lits, euf::enode_pair_vector const& eqs); + + lbool check_solver_state(); + + sat::literal_vector const& unsat_core(); + + void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; + + bool add_dep(euf::enode* n, top_sort& dep) override; + + std::ostream& display(std::ostream& out) const override; + + void collect_statistics(statistics& st) const override; + + bool unit_propagate() override; + + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override {} + + sat::check_result check() override; + + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return out; } + + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { return out; } + + euf::th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); } + + void internalize(expr* e) override; + + bool visited(expr* e) override; + + bool post_visit(expr* e, bool sign, bool root) override; + + bool visit(expr* e) override; + + sat::literal internalize(expr* e, bool, bool) override; + + void eq_internalized(euf::enode* n) override; + + rational get_value(expr* e) const; + + }; + +} diff --git a/src/sat/smt/pb_card.cpp b/src/sat/smt/pb_card.cpp index 7356c3e8402..30120cb7314 100644 --- a/src/sat/smt/pb_card.cpp +++ b/src/sat/smt/pb_card.cpp @@ -176,17 +176,17 @@ namespace pb { return false; } else if (j == bound) { - for (unsigned i = 0; i < bound; ++i) { - s.assign(c, c[i]); - } + for (unsigned i = 0; i < bound; ++i) + s.assign(c, c[i]); return false; } else { - if (c.is_watched()) return true; + if (c.is_watched()) + return true; clear_watch(s); - for (unsigned i = 0; i <= bound; ++i) { - c.watch_literal(s, c[i]); - } + for (unsigned i = 0; i <= bound; ++i) + if (!c.is_watched(s, c[i])) + c.watch_literal(s, c[i]); c.set_watch(); return true; } diff --git a/src/sat/smt/pb_card.h b/src/sat/smt/pb_card.h index 6df50d81c2b..26c6d3bb1a4 100644 --- a/src/sat/smt/pb_card.h +++ b/src/sat/smt/pb_card.h @@ -33,7 +33,7 @@ namespace pb { literal const* begin() const { return m_lits; } literal const* end() const { return static_cast(m_lits) + m_size; } void negate() override; - void swap(unsigned i, unsigned j) override { std::swap(m_lits[i], m_lits[j]); } + void swap(unsigned i, unsigned j) noexcept override { std::swap(m_lits[i], m_lits[j]); } literal_vector literals() const override { return literal_vector(m_size, m_lits); } bool is_watching(literal l) const override; literal get_lit(unsigned i) const override { return m_lits[i]; } diff --git a/src/sat/smt/pb_constraint.h b/src/sat/smt/pb_constraint.h index d8cec6de9d7..26c5f8b0886 100644 --- a/src/sat/smt/pb_constraint.h +++ b/src/sat/smt/pb_constraint.h @@ -102,7 +102,7 @@ namespace pb { virtual bool is_watching(literal l) const { UNREACHABLE(); return false; }; virtual literal_vector literals() const { UNREACHABLE(); return literal_vector(); } - virtual void swap(unsigned i, unsigned j) { UNREACHABLE(); } + virtual void swap(unsigned i, unsigned j) noexcept { UNREACHABLE(); } virtual literal get_lit(unsigned i) const { UNREACHABLE(); return sat::null_literal; } virtual void set_lit(unsigned i, literal l) { UNREACHABLE(); } virtual void negate() { UNREACHABLE(); } diff --git a/src/sat/smt/pb_internalize.cpp b/src/sat/smt/pb_internalize.cpp index abbd79c44d2..391b643f569 100644 --- a/src/sat/smt/pb_internalize.cpp +++ b/src/sat/smt/pb_internalize.cpp @@ -30,7 +30,7 @@ namespace pb { if (m_pb.is_pb(e)) { sat::literal lit = internalize_pb(e, sign, root); if (m_ctx && !root && lit != sat::null_literal) - m_ctx->attach_lit(lit, e); + m_ctx->attach_lit(literal(lit.var(), false), e); return lit; } UNREACHABLE(); @@ -171,6 +171,15 @@ namespace pb { wl.second.neg(); k += rational(wl.first); } + if (k < 0) { + bool_var v = s().add_var(false); + literal l(v, false); + s().assign_unit(~l); + si.cache(t, l); + if (sign) l.neg(); + return l; + } + check_unsigned(k); add_pb_ge(v2, false, wlits, k.get_unsigned()); if (base_assert) { diff --git a/src/sat/smt/pb_pb.h b/src/sat/smt/pb_pb.h index 5db177d0797..169fe247919 100644 --- a/src/sat/smt/pb_pb.h +++ b/src/sat/smt/pb_pb.h @@ -46,7 +46,7 @@ namespace pb { bool is_cardinality() const; void negate() override; void set_k(unsigned k) override { m_k = k; VERIFY(k < 4000000000); update_max_sum(); } - void swap(unsigned i, unsigned j) override { std::swap(m_wlits[i], m_wlits[j]); } + void swap(unsigned i, unsigned j) noexcept override { std::swap(m_wlits[i], m_wlits[j]); } literal_vector literals() const override { literal_vector lits; for (auto wl : *this) lits.push_back(wl.second); return lits; } bool is_watching(literal l) const override; literal get_lit(unsigned i) const override { return m_wlits[i].second; } diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index df832a675f7..76d234d8d5a 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -125,7 +125,7 @@ namespace q { if (a->get_root() == b->get_root()) ctx.get_egraph().explain_eq(m_explain, cc, a, b); else - ctx.add_diseq_antecedent(m_explain, cc, a, b); + ctx.explain_diseq(m_explain, cc, a, b); } ctx.get_egraph().end_explain(); @@ -390,7 +390,7 @@ namespace q { m_qs.log_instantiation(lits, &j); euf::th_proof_hint* ph = nullptr; if (ctx.use_drat()) - ph = q_proof_hint::mk(ctx, j.m_generation, lits, j.m_clause.num_decls(), j.m_binding); + ph = q_proof_hint::mk(ctx, m_ematch, j.m_generation, lits, j.m_clause.num_decls(), j.m_binding); m_qs.add_clause(lits, ph); } @@ -486,7 +486,7 @@ namespace q { * basic clausifier, assumes q has been normalized. */ clause* ematch::clausify(quantifier* _q) { - clause* cl = alloc(clause, m, m_clauses.size()); + scoped_ptr cl = alloc(clause, m, m_clauses.size()); cl->m_literal = ctx.mk_literal(_q); quantifier_ref q(_q, m); q = m_qs.flatten(q); @@ -514,7 +514,7 @@ namespace q { unsigned generation = nq ? nq->generation() : ctx.generation(); cl->m_stat = m_qstat_gen(_q, generation); SASSERT(ctx.s().value(cl->m_literal) == l_true); - return cl; + return cl.detach(); } lit ematch::clausify_literal(expr* arg) { @@ -576,14 +576,12 @@ namespace q { void ematch::add(quantifier* _q) { TRACE("q", tout << "add " << mk_pp(_q, m) << "\n"); - clause* c = clausify(_q); + scoped_ptr c = clausify(_q); quantifier* q = c->q(); - if (m_q2clauses.contains(q)) { - dealloc(c); + if (m_q2clauses.contains(q)) return; - } ensure_ground_enodes(*c); - m_clauses.push_back(c); + m_clauses.push_back(c.get()); m_q2clauses.insert(q, c->index()); ctx.push(pop_clause(*this)); init_watch(*c); @@ -614,6 +612,7 @@ namespace q { if (!unary) j++; } + c.detach(); } diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h index cbeb34679bf..f7de55fb813 100644 --- a/src/sat/smt/q_ematch.h +++ b/src/sat/smt/q_ematch.h @@ -90,6 +90,7 @@ namespace q { unsigned_vector m_clause_queue; euf::enode_pair_vector m_evidence; bool m_enable_propagate = true; + symbol m_ematch = symbol("ematch"); euf::enode* const* copy_nodes(clause& c, euf::enode* const* _binding); binding* tmp_binding(clause& c, app* pat, euf::enode* const* _binding); diff --git a/src/sat/smt/q_mbi.cpp b/src/sat/smt/q_mbi.cpp index c66f1b3a22c..07d4880c966 100644 --- a/src/sat/smt/q_mbi.cpp +++ b/src/sat/smt/q_mbi.cpp @@ -71,7 +71,7 @@ namespace q { for (auto const& [qlit, fml, inst, generation] : m_instantiations) { euf::solver::scoped_generation sg(ctx, generation + 1); sat::literal lit = ~ctx.mk_literal(fml); - auto* ph = ctx.use_drat()? q_proof_hint::mk(ctx, generation, ~qlit, lit, inst.size(), inst.data()) : nullptr; + auto* ph = ctx.use_drat()? q_proof_hint::mk(ctx, m_mbqi, generation, ~qlit, lit, inst.size(), inst.data()) : nullptr; m_qs.add_clause(~qlit, lit, ph); m_qs.log_instantiation(~qlit, lit); } @@ -498,8 +498,8 @@ namespace q { if (m_model->is_false(eq)) { IF_VERBOSE(0, verbose_stream() << mk_pp(s, m) << " := " << (*m_model)(s) << "\n"; - verbose_stream() << mk_pp(term, m) << " := " << (*m_model)(term) << "\n"; - verbose_stream() << value << " -> " << (*m_model)(ctx.values2root()[value]->get_expr()) << "\n"; + verbose_stream() << term << " := " << (*m_model)(term) << "\n"; + verbose_stream() << value << " -> " << (*m_model)(ctx.values2root()[(*m_model)(term)]->get_expr()) << "\n"; verbose_stream() << (*m_model)(s) << " -> " << (*m_model)(ctx.values2root()[(*m_model)(s)]->get_expr()) << "\n"; verbose_stream() << *m_model << "\n";); } diff --git a/src/sat/smt/q_mbi.h b/src/sat/smt/q_mbi.h index 96e3ba56f97..71a15be7473 100644 --- a/src/sat/smt/q_mbi.h +++ b/src/sat/smt/q_mbi.h @@ -72,6 +72,7 @@ namespace q { unsigned m_max_choose_candidates = 10; unsigned m_generation_bound = UINT_MAX; unsigned m_generation_max = UINT_MAX; + symbol m_mbqi = symbol("mbqi"); typedef std::tuple instantiation_t; vector m_instantiations; vector m_defs; diff --git a/src/sat/smt/q_solver.cpp b/src/sat/smt/q_solver.cpp index fff11898c7d..ed6cb643b3a 100644 --- a/src/sat/smt/q_solver.cpp +++ b/src/sat/smt/q_solver.cpp @@ -120,7 +120,6 @@ namespace q { } sat::literal solver::instantiate(quantifier* _q, bool negate, std::function& mk_var) { - sat::literal sk; expr_ref tmp(m); quantifier_ref q(_q, m); expr_ref_vector vars(m); @@ -364,10 +363,10 @@ namespace q { } } - q_proof_hint* q_proof_hint::mk(euf::solver& s, unsigned generation, sat::literal_vector const& lits, unsigned n, euf::enode* const* bindings) { + q_proof_hint* q_proof_hint::mk(euf::solver& s, symbol const& method, unsigned generation, sat::literal_vector const& lits, unsigned n, euf::enode* const* bindings) { SASSERT(n > 0); auto* mem = s.get_region().allocate(q_proof_hint::get_obj_size(n, lits.size())); - q_proof_hint* ph = new (mem) q_proof_hint(generation, n, lits.size()); + q_proof_hint* ph = new (mem) q_proof_hint(method, generation, n, lits.size()); for (unsigned i = 0; i < n; ++i) ph->m_bindings[i] = bindings[i]->get_expr(); for (unsigned i = 0; i < lits.size(); ++i) @@ -375,10 +374,10 @@ namespace q { return ph; } - q_proof_hint* q_proof_hint::mk(euf::solver& s, unsigned generation, sat::literal l1, sat::literal l2, unsigned n, expr* const* bindings) { + q_proof_hint* q_proof_hint::mk(euf::solver& s, symbol const& method, unsigned generation, sat::literal l1, sat::literal l2, unsigned n, expr* const* bindings) { SASSERT(n > 0); auto* mem = s.get_region().allocate(q_proof_hint::get_obj_size(n, 2)); - q_proof_hint* ph = new (mem) q_proof_hint(generation, n, 2); + q_proof_hint* ph = new (mem) q_proof_hint(method, generation, n, 2); for (unsigned i = 0; i < n; ++i) ph->m_bindings[i] = bindings[i]; ph->m_literals[0] = l1; @@ -402,6 +401,7 @@ namespace q { args.push_back(s.literal2expr(~m_literals[i])); args.push_back(binding); args.push_back(m.mk_app(symbol("gen"), 1, gens, range)); + args.push_back(m.mk_const(m_method, range)); return m.mk_app(symbol("inst"), args.size(), args.data(), range); } diff --git a/src/sat/smt/q_solver.h b/src/sat/smt/q_solver.h index d0581f85203..a7220e68b9d 100644 --- a/src/sat/smt/q_solver.h +++ b/src/sat/smt/q_solver.h @@ -30,21 +30,23 @@ namespace euf { namespace q { struct q_proof_hint : public euf::th_proof_hint { + symbol m_method; unsigned m_generation; unsigned m_num_bindings; unsigned m_num_literals; sat::literal* m_literals; expr* m_bindings[0]; - q_proof_hint(unsigned g, unsigned b, unsigned l) { + q_proof_hint(symbol const& method, unsigned g, unsigned b, unsigned l) { + m_method = method; m_generation = g; m_num_bindings = b; m_num_literals = l; m_literals = reinterpret_cast(m_bindings + m_num_bindings); } static size_t get_obj_size(unsigned num_bindings, unsigned num_lits) { return sizeof(q_proof_hint) + num_bindings*sizeof(expr*) + num_lits*sizeof(sat::literal); } - static q_proof_hint* mk(euf::solver& s, unsigned generation, sat::literal_vector const& lits, unsigned n, euf::enode* const* bindings); - static q_proof_hint* mk(euf::solver& s, unsigned generation, sat::literal l1, sat::literal l2, unsigned n, expr* const* bindings); + static q_proof_hint* mk(euf::solver& s, symbol const& method, unsigned generation, sat::literal_vector const& lits, unsigned n, euf::enode* const* bindings); + static q_proof_hint* mk(euf::solver& s, symbol const& method, unsigned generation, sat::literal l1, sat::literal l2, unsigned n, expr* const* bindings); expr* get_hint(euf::solver& s) const override; }; diff --git a/src/sat/smt/sat_th.cpp b/src/sat/smt/sat_th.cpp index 17d167829ce..21e3883e8a0 100644 --- a/src/sat/smt/sat_th.cpp +++ b/src/sat/smt/sat_th.cpp @@ -228,6 +228,8 @@ namespace euf { th_explain::th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& p, th_proof_hint const* pma) { m_consequent = c; m_eq = p; + if (m_eq.first && m_eq.first->get_id() > m_eq.second->get_id()) + std::swap(m_eq.first, m_eq.second); m_proof_hint = pma; m_num_literals = n_lits; m_num_eqs = n_eqs; @@ -238,8 +240,11 @@ namespace euf { m_literals[i] = lits[i]; base_ptr += sizeof(literal) * n_lits; m_eqs = reinterpret_cast(base_ptr); - for (i = 0; i < n_eqs; ++i) + for (i = 0; i < n_eqs; ++i) { m_eqs[i] = eqs[i]; + if (m_eqs[i].first->get_id() > m_eqs[i].second->get_id()) + std::swap(m_eqs[i].first, m_eqs[i].second); + } } th_explain* th_explain::mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y, th_proof_hint const* pma) { diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp new file mode 100644 index 00000000000..8feb9f83eac --- /dev/null +++ b/src/sat/smt/sls_solver.cpp @@ -0,0 +1,130 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_solver + +Abstract: + + Interface to Concurrent SLS solver + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-21 + +--*/ + +#include "sat/smt/sls_solver.h" +#include "sat/smt/euf_solver.h" + + + +namespace sls { + + solver::solver(euf::solver& ctx): + th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")) {} + + solver::~solver() { + if (m_bvsls) { + m_bvsls->cancel(); + m_thread.join(); + } + } + + void solver::push_core() { + if (s().scope_lvl() == s().search_lvl() + 1) + init_local_search(); + } + + void solver::pop_core(unsigned n) { + if (s().scope_lvl() - n <= s().search_lvl()) + sample_local_search(); + } + + void solver::simplify() { + + } + + + void solver::init_local_search() { + if (m_bvsls) { + m_bvsls->cancel(); + m_thread.join(); + if (m_result == l_true) { + verbose_stream() << "Found model using local search - INIT\n"; + exit(1); + } + } + // set up state for local search solver here + + m_m = alloc(ast_manager, m); + ast_translation tr(m, *m_m); + + m_completed = false; + m_result = l_undef; + m_bvsls = alloc(bv::sls, *m_m); + // walk clauses, add them + // walk trail stack until search level, add units + // encapsulate bvsls within the arguments of run-local-search. + // ensure bvsls does not touch ast-manager. + + unsigned trail_sz = s().trail_size(); + for (unsigned i = 0; i < trail_sz; ++i) { + auto lit = s().trail_literal(i); + if (s().lvl(lit) > s().search_lvl()) + break; + expr_ref fml = literal2expr(lit); + m_bvsls->assert_expr(tr(fml.get())); + } + unsigned num_vars = s().num_vars(); + for (unsigned i = 0; i < 2*num_vars; ++i) { + auto l1 = ~sat::to_literal(i); + auto const& wlist = s().get_wlist(l1); + for (sat::watched const& w : wlist) { + if (!w.is_binary_non_learned_clause()) + continue; + sat::literal l2 = w.get_literal(); + if (l1.index() > l2.index()) + continue; + expr_ref fml(m.mk_or(literal2expr(l1), literal2expr(l2)), m); + m_bvsls->assert_expr(tr(fml.get())); + } + } + for (auto clause : s().clauses()) { + expr_ref_vector cls(m); + for (auto lit : *clause) + cls.push_back(literal2expr(lit)); + expr_ref fml(m.mk_or(cls), m); + m_bvsls->assert_expr(tr(fml.get())); + } + + // use phase assignment from literals? + std::function eval = [&](expr* e, unsigned r) { + return false; + }; + + m_bvsls->init(); + m_bvsls->init_eval(eval); + m_bvsls->updt_params(s().params()); + + m_thread = std::thread([this]() { run_local_search(); }); + } + + void solver::sample_local_search() { + if (m_completed) { + m_thread.join(); + if (m_result == l_true) { + verbose_stream() << "Found model using local search\n"; + exit(1); + } + } + } + + void solver::run_local_search() { + lbool r = (*m_bvsls)(); + m_result = r; + m_completed = true; + } + +} diff --git a/src/sat/smt/sls_solver.h b/src/sat/smt/sls_solver.h new file mode 100644 index 00000000000..c473264acca --- /dev/null +++ b/src/sat/smt/sls_solver.h @@ -0,0 +1,63 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_solver + +Abstract: + + Interface to Concurrent SLS solver + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-21 + +--*/ +#pragma once + +#include +#include "util/rlimit.h" +#include "ast/sls/bv_sls.h" +#include "sat/smt/sat_th.h" + + +namespace euf { + class solver; +} + +namespace sls { + + class solver : public euf::th_euf_solver { + std::atomic m_result; + std::atomic m_completed; + std::thread m_thread; + scoped_ptr m_m; + scoped_ptr m_bvsls; + + void run_local_search(); + void init_local_search(); + void sample_local_search(); + public: + solver(euf::solver& ctx); + ~solver(); + + void push_core() override; + void pop_core(unsigned n) override; + void simplify() override; + + sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; } + void internalize(expr* e) override { UNREACHABLE(); } + th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); } + + + bool unit_propagate() override { return false; } + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); } + sat::check_result check() override { return sat::check_result::CR_DONE; } + std::ostream & display(std::ostream & out) const override { return out; } + std::ostream & display_justification(std::ostream & out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; } + std::ostream & display_constraint(std::ostream & out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; } + + }; + +} diff --git a/src/sat/smt/specrel_solver.cpp b/src/sat/smt/specrel_solver.cpp new file mode 100644 index 00000000000..d59029e6bc3 --- /dev/null +++ b/src/sat/smt/specrel_solver.cpp @@ -0,0 +1,120 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + specrel_solver.h + +Abstract: + + Theory plugin for special relations + +Author: + + Nikolaj Bjorner (nbjorner) 2020-09-08 + +--*/ + +#include "sat/smt/specrel_solver.h" +#include "sat/smt/euf_solver.h" +#include "ast/euf/euf_specrel_plugin.h" + +namespace euf { + class solver; +} + +namespace specrel { + + solver::solver(euf::solver& ctx, theory_id id) : + th_euf_solver(ctx, ctx.get_manager().get_family_name(id), id), + sp(m) + { + ctx.get_egraph().add_plugin(alloc(euf::specrel_plugin, ctx.get_egraph())); + } + + solver::~solver() { + } + + void solver::asserted(sat::literal l) { + + } + + sat::check_result solver::check() { + return sat::check_result::CR_DONE; + } + + std::ostream& solver::display(std::ostream& out) const { + return out; + } + + void solver::collect_statistics(statistics& st) const { + } + + euf::th_solver* solver::clone(euf::solver& ctx) { + return alloc(solver, ctx, get_id()); + } + + void solver::new_eq_eh(euf::th_eq const& eq) { + TRACE("specrel", tout << "new-eq\n"); + if (eq.is_eq()) { + auto* p = ctx.get_egraph().get_plugin(sp.get_family_id()); + p->merge_eh(var2enode(eq.v1()), var2enode(eq.v2())); + TRACE("specrel", tout << eq.v1() << " " << eq.v2() << "\n"); + } + } + + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + } + + bool solver::add_dep(euf::enode* n, top_sort& dep) { + return false; + } + + bool solver::include_func_interp(func_decl* f) const { + return false; + } + + sat::literal solver::internalize(expr* e, bool sign, bool root) { + if (!visit_rec(m, e, sign, root)) + return sat::null_literal; + auto lit = ctx.expr2literal(e); + if (sign) + lit.neg(); + return lit; + } + + void solver::internalize(expr* e) { + visit_rec(m, e, false, false); + } + + bool solver::visit(expr* e) { + if (visited(e)) + return true; + m_stack.push_back(sat::eframe(e)); + return false; + } + + bool solver::visited(expr* e) { + euf::enode* n = expr2enode(e); + return n && n->is_attached_to(get_id()); + } + + bool solver::post_visit(expr* term, bool sign, bool root) { + euf::enode* n = expr2enode(term); + SASSERT(!n || !n->is_attached_to(get_id())); + if (!n) + n = mk_enode(term); + SASSERT(!n->is_attached_to(get_id())); + mk_var(n); + TRACE("specrel", tout << ctx.bpp(n) << "\n"); + return true; + } + + euf::theory_var solver::mk_var(euf::enode* n) { + if (is_attached_to_var(n)) + return n->get_th_var(get_id()); + euf::theory_var r = th_euf_solver::mk_var(n); + ctx.attach_th_var(n, this, r); + return r; + } +} diff --git a/src/sat/smt/specrel_solver.h b/src/sat/smt/specrel_solver.h new file mode 100644 index 00000000000..9ebb769160c --- /dev/null +++ b/src/sat/smt/specrel_solver.h @@ -0,0 +1,75 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + specrel_solver.h + +Abstract: + + Theory plugin for special relations + +Author: + + Nikolaj Bjorner (nbjorner) 2020-09-08 + +--*/ +#pragma once + +#include "sat/smt/sat_th.h" +#include "ast/special_relations_decl_plugin.h" + +namespace euf { + class solver; +} + +namespace specrel { + + class solver : public euf::th_euf_solver { + typedef euf::theory_var theory_var; + typedef euf::theory_id theory_id; + typedef euf::enode enode; + typedef euf::enode_pair enode_pair; + typedef euf::enode_pair_vector enode_pair_vector; + typedef sat::bool_var bool_var; + typedef sat::literal literal; + typedef sat::literal_vector literal_vector; + + special_relations_util sp; + + public: + solver(euf::solver& ctx, theory_id id); + ~solver() override; + + bool is_external(bool_var v) override { return false; } + void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override {} + void asserted(literal l) override; + sat::check_result check() override; + + std::ostream& display(std::ostream& out) const override; + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return euf::th_explain::from_index(idx).display(out); } + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { return display_justification(out, idx); } + void collect_statistics(statistics& st) const override; + euf::th_solver* clone(euf::solver& ctx) override; + void new_eq_eh(euf::th_eq const& eq) override; + bool unit_propagate() override { return false; } + void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; + bool add_dep(euf::enode* n, top_sort& dep) override; + bool include_func_interp(func_decl* f) const override; + sat::literal internalize(expr* e, bool sign, bool root) override; + void internalize(expr* e) override; + bool visit(expr* e) override; + bool visited(expr* e) override; + bool post_visit(expr* e, bool sign, bool root) override; + + euf::theory_var mk_var(euf::enode* n) override; + void apply_sort_cnstr(euf::enode* n, sort* s) override {} + bool is_shared(theory_var v) const override { return false; } + lbool get_phase(bool_var v) override { return l_true; } + bool enable_self_propagate() const override { return true; } + + void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2); + void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) {} + void unmerge_eh(theory_var v1, theory_var v2) {} + }; +} diff --git a/src/sat/smt/tseitin_theory_checker.cpp b/src/sat/smt/tseitin_theory_checker.cpp index 74f4e55b000..ffddc90631c 100644 --- a/src/sat/smt/tseitin_theory_checker.cpp +++ b/src/sat/smt/tseitin_theory_checker.cpp @@ -72,7 +72,7 @@ namespace tseitin { complement_mark(arg); for (expr* arg : *to_app(main_expr)) - if (!is_complement(arg)) + if (!is_complement(arg) && !m.is_true(arg)) return false; return true; @@ -178,7 +178,7 @@ namespace tseitin { for (expr* arg : *jst) mark(arg); for (expr* arg : *to_app(a)) - if (!is_marked(arg)) + if (!is_marked(arg) && !m.is_false(arg)) return false; return true; } diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 34f2b10b470..0204573dcb1 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -15,8 +15,9 @@ Module Name: --*/ -#include "sat/smt/user_solver.h" +#include "sat/smt/bv_solver.h" #include "sat/smt/euf_solver.h" +#include "sat/smt/user_solver.h" namespace user_solver { @@ -39,39 +40,47 @@ namespace user_solver { expr_ref r(m); sat::literal_vector explain; if (ctx.is_fixed(n, r, explain)) - m_prop.push_back(prop_info(explain, v, r)); + m_prop.push_back(prop_info(explain, v, r)); } - void solver::propagate_cb( - unsigned num_fixed, expr* const* fixed_ids, - unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, - expr* conseq) { + bool solver::propagate_cb( + unsigned num_fixed, expr* const* fixed_ids, + unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, + expr* conseq) { + auto* n = ctx.get_enode(conseq); + if (n && s().value(ctx.enode2literal(n)) == l_true) + return false; m_fixed_ids.reset(); for (unsigned i = 0; i < num_fixed; ++i) m_fixed_ids.push_back(get_th_var(fixed_ids[i])); m_prop.push_back(prop_info(num_fixed, m_fixed_ids.data(), num_eqs, eq_lhs, eq_rhs, expr_ref(conseq, m))); DEBUG_CODE(validate_propagation();); + return true; } void solver::register_cb(expr* e) { add_expr(e); } - - void solver::next_split_cb(expr* e, unsigned idx, lbool phase) { + + bool solver::next_split_cb(expr* e, unsigned idx, lbool phase) { if (e == nullptr) { - m_next_split_expr = nullptr; - return; + m_next_split_var = sat::null_bool_var; + return true; } force_push(); ctx.internalize(e); - m_next_split_expr = e; - m_next_split_idx = idx; + sat::bool_var var = enode_to_bool(ctx.get_enode(e), idx); + m_next_split_phase = phase; + if (var == sat::null_bool_var || s().value(var) != l_undef) + return false; + m_next_split_var = var; m_next_split_phase = phase; + return true; } sat::check_result solver::check() { if (!(bool)m_final_eh) - return sat::check_result::CR_DONE; + return sat::check_result::CR_DONE; unsigned sz = m_prop.size(); m_final_eh(m_user_context, this); return sz == m_prop.size() ? sat::check_result::CR_DONE : sat::check_result::CR_CONTINUE; @@ -84,39 +93,41 @@ namespace user_solver { m_id2justification.setx(v, sat::literal_vector(num_lits, jlits), sat::literal_vector()); m_fixed_eh(m_user_context, this, var2expr(v), value); } - + bool solver::decide(sat::bool_var& var, lbool& phase) { - + if (!m_decide_eh) return false; - + euf::enode* original_enode = bool_var2enode(var); - + if (!original_enode || !is_attached_to_var(original_enode)) return false; - + unsigned new_bit = 0; // ignored; currently no bv-support expr* e = original_enode->get_expr(); - - m_decide_eh(m_user_context, this, &e, &new_bit, &phase); - - euf::enode* new_enode = ctx.get_enode(e); - - if (original_enode == new_enode || new_enode->bool_var() == sat::null_bool_var) + + m_decide_eh(m_user_context, this, e, new_bit, phase); + sat::bool_var new_var; + if (!get_case_split(new_var, phase) || new_var == var) + // The user did not interfere return false; - - var = new_enode->bool_var(); + var = new_var; + + // check if the new variable is unassigned + if (s().value(var) != l_undef) + throw default_exception("expression in \"decide\" is already assigned"); return true; } - - bool solver::get_case_split(sat::bool_var& var, lbool& phase){ - if (!m_next_split_expr) + + bool solver::get_case_split(sat::bool_var& var, lbool& phase) { + if (m_next_split_var == sat::null_bool_var) return false; - - euf::enode* n = ctx.get_enode(m_next_split_expr); - var = n->bool_var(); + + var = m_next_split_var; phase = m_next_split_phase; - m_next_split_expr = nullptr; + m_next_split_var = sat::null_bool_var; + m_next_split_phase = l_undef; return true; } @@ -134,14 +145,14 @@ namespace user_solver { m_id2justification.setx(v, lits, sat::literal_vector()); m_fixed_eh(m_user_context, this, var2expr(v), lit.sign() ? m.mk_false() : m.mk_true()); } - + void solver::new_eq_eh(euf::th_eq const& eq) { if (!m_eq_eh) return; force_push(); m_eq_eh(m_user_context, this, var2expr(eq.v1()), var2expr(eq.v2())); } - + void solver::new_diseq_eh(euf::th_eq const& de) { if (!m_diseq_eh) return; @@ -167,8 +178,10 @@ namespace user_solver { void solver::propagate_consequence(prop_info const& prop) { sat::literal lit = ctx.internalize(prop.m_conseq, false, false); if (s().value(lit) != l_true) { - s().assign(lit, mk_justification(m_qhead)); + auto j = mk_justification(m_qhead); + s().assign(lit, j); ++m_stats.m_num_propagations; + persist_clause(lit, j); } } @@ -177,9 +190,17 @@ namespace user_solver { } bool solver::unit_propagate() { - if (m_qhead == m_prop.size()) + if (m_qhead == m_prop.size() && m_replay_qhead == m_clauses_to_replay.size()) return false; force_push(); + + bool replayed = false; + if (m_replay_qhead < m_clauses_to_replay.size()) { + replayed = true; + ctx.push(value_trail(m_replay_qhead)); + for (; m_replay_qhead < m_clauses_to_replay.size(); ++m_replay_qhead) + replay_clause(m_clauses_to_replay.get(m_replay_qhead)); + } ctx.push(value_trail(m_qhead)); unsigned np = m_stats.m_num_propagations; for (; m_qhead < m_prop.size() && !s().inconsistent(); ++m_qhead) { @@ -188,8 +209,38 @@ namespace user_solver { propagate_consequence(prop); else propagate_new_fixed(prop); - } - return np < m_stats.m_num_propagations; + } + return np < m_stats.m_num_propagations || replayed; + } + + void solver::replay_clause(expr_ref_vector const& clause) { + sat::literal_vector lits; + for (expr* e : clause) + lits.push_back(ctx.mk_literal(e)); + add_clause(lits); + } + + void solver::persist_clause(sat::literal lit, sat::justification const& sj) { + if (!ctx.get_config().m_up_persist_clauses) + return; + + expr_ref_vector clause(m); + auto idx = sj.get_ext_justification_idx(); + auto& j = justification::from_index(idx); + auto const& prop = m_prop[j.m_propagation_index]; + sat::literal_vector r; + for (unsigned id : prop.m_ids) + r.append(m_id2justification[id]); + for (auto lit : r) + clause.push_back(ctx.literal2expr(~lit)); + for (auto const& [a,b] : prop.m_eqs) + clause.push_back(m.mk_not(m.mk_eq(a, b))); + clause.push_back(ctx.literal2expr(lit)); + + m_clauses_to_replay.push_back(clause); + if (m_replay_qhead + 1 < m_clauses_to_replay.size()) + std::swap(m_clauses_to_replay[m_replay_qhead], m_clauses_to_replay[m_clauses_to_replay.size()-1]); + ++m_replay_qhead; } void solver::collect_statistics(::statistics& st) const { @@ -208,9 +259,9 @@ namespace user_solver { auto& j = justification::from_index(idx); auto const& prop = m_prop[j.m_propagation_index]; for (unsigned id : prop.m_ids) - r.append(m_id2justification[id]); + r.append(m_id2justification[id]); for (auto const& p : prop.m_eqs) - ctx.add_antecedent(probing, expr2enode(p.first), expr2enode(p.second)); + ctx.add_eq_antecedent(probing, expr2enode(p.first), expr2enode(p.second)); } /* @@ -243,7 +294,7 @@ namespace user_solver { } std::ostream& solver::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const { - return display_justification(out, idx); + return display_justification(out, idx); } euf::th_solver* solver::clone(euf::solver& dst_ctx) { @@ -278,26 +329,35 @@ namespace user_solver { return true; } m_stack.push_back(sat::eframe(e)); - return false; + return false; } - + bool solver::visited(expr* e) { euf::enode* n = expr2enode(e); - return n && n->is_attached_to(get_id()); + return n && n->is_attached_to(get_id()); } - + bool solver::post_visit(expr* e, bool sign, bool root) { euf::enode* n = expr2enode(e); SASSERT(!n || !n->is_attached_to(get_id())); - if (!n) - n = mk_enode(e, false); + if (!n) + n = mk_enode(e, false); add_expr(e); if (m_created_eh) m_created_eh(m_user_context, this, e); return true; } - + sat::bool_var solver::enode_to_bool(euf::enode* n, unsigned idx) { + if (n->bool_var() != sat::null_bool_var) { + // expression is a boolean + return n->bool_var(); + } + // expression is a bit-vector + bv_util bv(m); + th_solver* th = ctx.fid2solver(bv.get_fid()); + return ((bv::solver*) th)->get_bit(idx, n); + } } diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index cb1c6fe94c8..373b046b80e 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -75,9 +75,10 @@ namespace user_solver { euf::enode_pair_vector m_eqs; unsigned_vector m_fixed_ids; stats m_stats; - expr* m_next_split_expr = nullptr; - unsigned m_next_split_idx; - lbool m_next_split_phase; + sat::bool_var m_next_split_var = sat::null_bool_var; + lbool m_next_split_phase = l_undef; + vector m_clauses_to_replay; + unsigned m_replay_qhead = 0; struct justification { unsigned m_propagation_index { 0 }; @@ -104,6 +105,11 @@ namespace user_solver { bool visited(expr* e) override; bool post_visit(expr* e, bool sign, bool root) override; + sat::bool_var enode_to_bool(euf::enode* n, unsigned idx); + + void replay_clause(expr_ref_vector const& clause); + void persist_clause(sat::literal lit, sat::justification const& j); + public: solver(euf::solver& ctx); @@ -134,9 +140,9 @@ namespace user_solver { bool has_fixed() const { return (bool)m_fixed_eh; } - void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; + bool propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; - void next_split_cb(expr* e, unsigned idx, lbool phase) override; + bool next_split_cb(expr* e, unsigned idx, lbool phase) override; void new_fixed_eh(euf::theory_var v, expr* value, unsigned num_lits, sat::literal const* jlits); diff --git a/src/smt/arith_eq_adapter.cpp b/src/smt/arith_eq_adapter.cpp index b77a389270e..f251d01ae01 100644 --- a/src/smt/arith_eq_adapter.cpp +++ b/src/smt/arith_eq_adapter.cpp @@ -67,6 +67,8 @@ namespace smt { m_ge(ge) { } + ~arith_eq_relevancy_eh() override {} + void operator()(relevancy_propagator & rp) override { if (!rp.is_relevant(m_n1)) return; diff --git a/src/smt/fingerprints.h b/src/smt/fingerprints.h index b1308e9b0d9..b1904a4ee3b 100644 --- a/src/smt/fingerprints.h +++ b/src/smt/fingerprints.h @@ -25,11 +25,11 @@ namespace smt { class fingerprint { protected: - void* m_data{ nullptr }; - unsigned m_data_hash{ 0 }; - expr* m_def{ nullptr }; - unsigned m_num_args{ 0 }; - enode** m_args{ nullptr }; + void* m_data = nullptr; + unsigned m_data_hash = 0; + expr* m_def = nullptr; + unsigned m_num_args = 0; + enode** m_args = nullptr; friend class fingerprint_set; fingerprint() {} diff --git a/src/smt/mam.cpp b/src/smt/mam.cpp index 3804b72284e..b9ac45039f2 100644 --- a/src/smt/mam.cpp +++ b/src/smt/mam.cpp @@ -2413,7 +2413,10 @@ namespace { m_n2 = static_cast(m_pc)->m_enode; SASSERT(m_n1 != 0); SASSERT(m_n2 != 0); - if (m_n1->get_root() != m_n2->get_root()) + + // hack to handle dynamically generated patterns: + // if the pattern is ground and an if-expression, ignore equality check. + if (m_n1->get_root() != m_n2->get_root() && !m.is_ite(m_n2->get_expr())) goto backtrack; // we used the equality m_n1 = m_n2 for the match and need to make sure it ends up in the log diff --git a/src/smt/old_interval.cpp b/src/smt/old_interval.cpp index e719f7e2bd8..c4d49d00a59 100644 --- a/src/smt/old_interval.cpp +++ b/src/smt/old_interval.cpp @@ -238,7 +238,7 @@ interval & interval::operator=(interval const & other) { return *this; } -interval & interval::operator=(interval && other) { +interval & interval::operator=(interval && other) noexcept { SASSERT(&m_manager == &other.m_manager); m_lower = std::move(other.m_lower); m_upper = std::move(other.m_upper); diff --git a/src/smt/old_interval.h b/src/smt/old_interval.h index 7a6f41e269f..bf68a0d2665 100644 --- a/src/smt/old_interval.h +++ b/src/smt/old_interval.h @@ -92,7 +92,7 @@ class old_interval { rational const & get_lower_value() const { SASSERT(!minus_infinity()); return m_lower.to_rational(); } rational const & get_upper_value() const { SASSERT(!plus_infinity()); return m_upper.to_rational(); } old_interval & operator=(old_interval const & other); - old_interval & operator=(old_interval && other); + old_interval & operator=(old_interval && other) noexcept; old_interval & operator+=(old_interval const & other); old_interval & operator-=(old_interval const & other); old_interval & operator*=(old_interval const & other); diff --git a/src/smt/params/smt_params.cpp b/src/smt/params/smt_params.cpp index c2841761316..0b9aa674d28 100644 --- a/src/smt/params/smt_params.cpp +++ b/src/smt/params/smt_params.cpp @@ -51,6 +51,7 @@ void smt_params::updt_local_params(params_ref const & _p) { m_core_validate = p.core_validate(); m_logic = _p.get_sym("logic", m_logic); m_string_solver = p.string_solver(); + m_up_persist_clauses = p.up_persist_clauses(); validate_string_solver(m_string_solver); if (_p.get_bool("arith.greatest_error_pivot", false)) m_arith_pivot_strategy = arith_pivot_strategy::ARITH_PIVOT_GREATEST_ERROR; @@ -145,6 +146,7 @@ void smt_params::display(std::ostream & out) const { DISPLAY_PARAM(m_agility_factor); DISPLAY_PARAM(m_restart_agility_threshold); + DISPLAY_PARAM(m_up_persist_clauses); DISPLAY_PARAM(m_lemma_gc_strategy); DISPLAY_PARAM(m_lemma_gc_half); DISPLAY_PARAM(m_recent_lemmas_size); diff --git a/src/smt/params/smt_params.h b/src/smt/params/smt_params.h index fda89e8e29d..b71b71fa1cb 100644 --- a/src/smt/params/smt_params.h +++ b/src/smt/params/smt_params.h @@ -172,6 +172,14 @@ struct smt_params : public preprocessor_params, unsigned m_old_clause_relevancy = 6; //!< Max. number of unassigned literals to be considered relevant. double m_inv_clause_decay = 1; //!< clause activity decay + // ----------------------------------- + // + // User propagator configuration + // + // ----------------------------------- + + bool m_up_persist_clauses = false; + // ----------------------------------- // // SMT-LIB (debug) pretty printer diff --git a/src/smt/params/smt_params_helper.pyg b/src/smt/params/smt_params_helper.pyg index 2ee1c38bbfa..76c96ee005d 100644 --- a/src/smt/params/smt_params_helper.pyg +++ b/src/smt/params/smt_params_helper.pyg @@ -53,13 +53,14 @@ def_module_params(module_name='smt', ('bv.enable_int2bv', BOOL, True, 'enable support for int2bv and bv2int operators'), ('bv.watch_diseq', BOOL, False, 'use watch lists instead of eager axioms for bit-vectors'), ('bv.delay', BOOL, False, 'delay internalize expensive bit-vector operations'), - ('bv.eq_axioms', BOOL, True, 'enable redundant equality axioms for bit-vectors'), ('bv.size_reduce', BOOL, False, 'pre-processing; turn assertions that set the upper bits of a bit-vector to constants into a substitution that replaces the bit-vector with constant bits. Useful for minimizing circuits as many input bits to circuits are constant'), + ('bv.solver', UINT, 0, 'bit-vector solver engine: 0 - bit-blasting, 1 - polysat, 2 - intblast, requires sat.smt=true'), ('arith.random_initial_value', BOOL, False, 'use random initial values in the simplex-based procedure for linear arithmetic'), ('arith.solver', UINT, 6, 'arithmetic solver: 0 - no solver, 1 - bellman-ford based solver (diff. logic only), 2 - simplex based solver, 3 - floyd-warshall based solver (diff. logic only) and no theory combination 4 - utvpi, 5 - infinitary lra, 6 - lra solver'), ('arith.nl', BOOL, True, '(incomplete) nonlinear arithmetic support based on Groebner basis and interval propagation, relevant only if smt.arith.solver=2'), ('arith.nl.nra', BOOL, True, 'call nra_solver when incremental linearization does not produce a lemma, this option is ignored when arith.nl=false, relevant only if smt.arith.solver=6'), - ('arith.nl.branching', BOOL, True, 'branching on integer variables in non linear clusters, relevant only if smt.arith.solver=2'), + ('arith.nl.branching', BOOL, True, 'branching on integer variables in non linear clusters'), + ('arith.nl.expensive_patching', BOOL, False, 'use the expensive of monomials'), ('arith.nl.rounds', UINT, 1024, 'threshold for number of (nested) final checks for non linear arithmetic, relevant only if smt.arith.solver=2'), ('arith.nl.order', BOOL, True, 'run order lemmas'), ('arith.nl.expp', BOOL, False, 'expensive patching'), @@ -68,6 +69,7 @@ def_module_params(module_name='smt', ('arith.nl.horner_subs_fixed', UINT, 2, '0 - no subs, 1 - substitute, 2 - substitute fixed zeros only'), ('arith.nl.horner_frequency', UINT, 4, 'horner\'s call frequency'), ('arith.nl.horner_row_length_limit', UINT, 10, 'row is disregarded by the heuristic if its length is longer than the value'), + ('arith.nl.grobner_row_length_limit', UINT, 10, 'row is disregarded by the heuristic if its length is longer than the value'), ('arith.nl.grobner_frequency', UINT, 4, 'grobner\'s call frequency'), ('arith.nl.grobner', BOOL, True, 'run grobner\'s basis heuristic'), ('arith.nl.grobner_eqs_growth', UINT, 10, 'grobner\'s number of equalities growth '), @@ -77,7 +79,10 @@ def_module_params(module_name='smt', ('arith.nl.grobner_cnfl_to_report', UINT, 1, 'grobner\'s maximum number of conflicts to report'), ('arith.nl.gr_q', UINT, 10, 'grobner\'s quota'), ('arith.nl.grobner_subs_fixed', UINT, 1, '0 - no subs, 1 - substitute, 2 - substitute fixed zeros only'), - ('arith.nl.delay', UINT, 500, 'number of calls to final check before invoking bounded nlsat check'), + ('arith.nl.delay', UINT, 10, 'number of calls to final check before invoking bounded nlsat check'), + ('arith.nl.propagate_linear_monomials', BOOL, True, 'propagate linear monomials'), + ('arith.nl.optimize_bounds', BOOL, True, 'enable bounds optimization'), + ('arith.nl.cross_nested', BOOL, True, 'enable cross-nested consistency checking'), ('arith.propagate_eqs', BOOL, True, 'propagate (cheap) equalities'), ('arith.propagation_mode', UINT, 1, '0 - no propagation, 1 - propagate existing literals, 2 - refine finite bounds'), ('arith.branch_cut_ratio', UINT, 2, 'branch/cut ratio for linear integer arithmetic'), @@ -90,17 +95,19 @@ def_module_params(module_name='smt', ('arith.rep_freq', UINT, 0, 'the report frequency, in how many iterations print the cost and other info'), ('arith.min', BOOL, False, 'minimize cost'), ('arith.print_stats', BOOL, False, 'print statistic'), + ('arith.validate', BOOL, False, 'validate lemmas generated by arithmetic solver'), ('arith.simplex_strategy', UINT, 0, 'simplex strategy for the solver'), ('arith.enable_hnf', BOOL, True, 'enable hnf (Hermite Normal Form) cuts'), ('arith.bprop_on_pivoted_rows', BOOL, True, 'propagate bounds on rows changed by the pivot operation'), ('arith.print_ext_var_names', BOOL, False, 'print external variable names'), ('pb.conflict_frequency', UINT, 1000, 'conflict frequency for Pseudo-Boolean theory'), ('pb.learn_complements', BOOL, True, 'learn complement literals for Pseudo-Boolean theory'), + ('up.persist_clauses', BOOL, True, 'replay propagated clauses below the levels they are asserted'), ('array.weak', BOOL, False, 'weak array theory'), ('array.extensional', BOOL, True, 'extensional array theory'), ('clause_proof', BOOL, False, 'record a clausal proof'), ('dack', UINT, 1, '0 - disable dynamic ackermannization, 1 - expand Leibniz\'s axiom if a congruence is the root of a conflict, 2 - expand Leibniz\'s axiom if a congruence is used during conflict resolution'), - ('dack.eq', BOOL, False, 'enable dynamic ackermannization for transtivity of equalities'), + ('dack.eq', BOOL, False, 'enable dynamic ackermannization for transitivity of equalities'), ('dack.factor', DOUBLE, 0.1, 'number of instance per conflict'), ('dack.gc', UINT, 2000, 'Dynamic ackermannization garbage collection frequency (per conflict)'), ('dack.gc_inv_decay', DOUBLE, 0.8, 'Dynamic ackermannization garbage collection decay'), @@ -138,6 +145,7 @@ def_module_params(module_name='smt', ('core.extend_patterns.max_distance', UINT, UINT_MAX, 'limits the distance of a pattern-extended unsat core'), ('core.extend_nonlocal_patterns', BOOL, False, 'extend unsat cores with literals that have quantifiers with patterns that contain symbols which are not in the quantifier\'s body'), ('lemma_gc_strategy', UINT, 0, 'lemma garbage collection strategy: 0 - fixed, 1 - geometric, 2 - at restart, 3 - none'), - ('dt_lazy_splits', UINT, 1, 'How lazy datatype splits are performed: 0- eager, 1- lazy for infinite types, 2- lazy') + ('dt_lazy_splits', UINT, 1, 'How lazy datatype splits are performed: 0- eager, 1- lazy for infinite types, 2- lazy'), + ('qsat_use_qel', BOOL, True, 'Use QEL for lite quantifier elimination and model-based projection in QSAT') )) diff --git a/src/smt/params/theory_arith_params.cpp b/src/smt/params/theory_arith_params.cpp index 7f3f1ca23ce..fef7508f4f1 100644 --- a/src/smt/params/theory_arith_params.cpp +++ b/src/smt/params/theory_arith_params.cpp @@ -36,6 +36,10 @@ void theory_arith_params::updt_params(params_ref const & _p) { m_arith_bound_prop = static_cast(p.arith_propagation_mode()); m_arith_eager_eq_axioms = p.arith_eager_eq_axioms(); m_arith_auto_config_simplex = p.arith_auto_config_simplex(); + m_arith_validate = p.arith_validate(); + m_nl_arith_propagate_linear_monomials = p.arith_nl_propagate_linear_monomials(); + m_nl_arith_optimize_bounds = p.arith_nl_optimize_bounds(); + m_nl_arith_cross_nested = p.arith_nl_cross_nested(); arith_rewriter_params ap(_p); m_arith_eq2ineq = ap.eq2ineq(); @@ -89,4 +93,8 @@ void theory_arith_params::display(std::ostream & out) const { DISPLAY_PARAM(m_nl_arith_max_degree); DISPLAY_PARAM(m_nl_arith_branching); DISPLAY_PARAM(m_nl_arith_rounds); + DISPLAY_PARAM(m_nl_arith_propagate_linear_monomials); + DISPLAY_PARAM(m_nl_arith_optimize_bounds); + DISPLAY_PARAM(m_nl_arith_cross_nested); + DISPLAY_PARAM(m_arith_validate); } diff --git a/src/smt/params/theory_arith_params.h b/src/smt/params/theory_arith_params.h index 526cb6f0907..0a6b9edca45 100644 --- a/src/smt/params/theory_arith_params.h +++ b/src/smt/params/theory_arith_params.h @@ -82,6 +82,7 @@ struct theory_arith_params { bool m_arith_adaptive_gcd = false; unsigned m_arith_propagation_threshold = UINT_MAX; + bool m_arith_validate = false; arith_pivot_strategy m_arith_pivot_strategy = arith_pivot_strategy::ARITH_PIVOT_SMALLEST; // used in diff-logic @@ -105,7 +106,9 @@ struct theory_arith_params { unsigned m_nl_arith_max_degree = 6; bool m_nl_arith_branching = true; unsigned m_nl_arith_rounds = 1024; - + bool m_nl_arith_propagate_linear_monomials = true; + bool m_nl_arith_optimize_bounds = true; + bool m_nl_arith_cross_nested = true; theory_arith_params(params_ref const & p = params_ref()) { updt_params(p); diff --git a/src/smt/params/theory_bv_params.cpp b/src/smt/params/theory_bv_params.cpp index 09fa4513fe2..8a3ddcf3729 100644 --- a/src/smt/params/theory_bv_params.cpp +++ b/src/smt/params/theory_bv_params.cpp @@ -27,8 +27,8 @@ void theory_bv_params::updt_params(params_ref const & _p) { m_bv_reflect = p.bv_reflect(); m_bv_enable_int2bv2int = p.bv_enable_int2bv(); m_bv_delay = p.bv_delay(); - m_bv_eq_axioms = p.bv_eq_axioms(); m_bv_size_reduce = p.bv_size_reduce(); + m_bv_solver = p.bv_solver(); } #define DISPLAY_PARAM(X) out << #X"=" << X << '\n'; @@ -38,10 +38,10 @@ void theory_bv_params::display(std::ostream & out) const { DISPLAY_PARAM(m_hi_div0); DISPLAY_PARAM(m_bv_reflect); DISPLAY_PARAM(m_bv_lazy_le); - DISPLAY_PARAM(m_bv_eq_axioms); DISPLAY_PARAM(m_bv_cc); DISPLAY_PARAM(m_bv_blast_max_size); DISPLAY_PARAM(m_bv_enable_int2bv2int); DISPLAY_PARAM(m_bv_delay); DISPLAY_PARAM(m_bv_size_reduce); + DISPLAY_PARAM(m_bv_solver); } diff --git a/src/smt/params/theory_bv_params.h b/src/smt/params/theory_bv_params.h index e83b0b5db4d..97428c8ba28 100644 --- a/src/smt/params/theory_bv_params.h +++ b/src/smt/params/theory_bv_params.h @@ -31,12 +31,12 @@ struct theory_bv_params { bool m_bv_reflect = true; bool m_bv_lazy_le = false; bool m_bv_cc = false; - bool m_bv_eq_axioms = true; unsigned m_bv_blast_max_size = INT_MAX; bool m_bv_enable_int2bv2int = true; bool m_bv_watch_diseq = false; bool m_bv_delay = true; bool m_bv_size_reduce = false; + unsigned m_bv_solver = 0; theory_bv_params(params_ref const & p = params_ref()) { updt_params(p); } diff --git a/src/smt/proto_model/proto_model.h b/src/smt/proto_model/proto_model.h index f069e0a60a5..d12f56dee97 100644 --- a/src/smt/proto_model/proto_model.h +++ b/src/smt/proto_model/proto_model.h @@ -63,6 +63,9 @@ class proto_model : public model_core { void register_factory(value_factory * f) { m_factories.register_plugin(f); } bool eval(expr * e, expr_ref & result, bool model_completion = false); + bool are_equal(expr* a, expr* b) { return m_eval.are_equal(a, b); } + bool is_false(expr* e) { return m_eval.are_equal(e, m.mk_false()); } + expr_ref operator()(expr* e) { expr_ref result(e, m); eval(e, result, false); return result; } value_factory * get_factory(family_id fid); diff --git a/src/smt/qi_queue.cpp b/src/smt/qi_queue.cpp index 582bcc66452..52399abb6e1 100644 --- a/src/smt/qi_queue.cpp +++ b/src/smt/qi_queue.cpp @@ -131,6 +131,8 @@ namespace smt { // max_top_generation and min_top_generation are not available for computing inc_gen set_values(q, nullptr, generation, 0, 0, cost); float r = m_evaluator(m_new_gen_function, m_vals.size(), m_vals.data()); + if (q->get_weight() > 0 || r > 0) + return static_cast(r); return std::max(generation + 1, static_cast(r)); } @@ -396,6 +398,7 @@ namespace smt { bool qi_queue::final_check_eh() { TRACE("qi_queue", display_delayed_instances_stats(tout); tout << "lazy threshold: " << m_params.m_qi_lazy_threshold << ", scope_level: " << m_context.get_scope_level() << "\n";); + if (m_params.m_qi_conservative_final_check) { bool init = false; float min_cost = 0.0; diff --git a/src/smt/smt_clause_proof.cpp b/src/smt/smt_clause_proof.cpp index 3bb2a1fdfb3..5cb3729ab65 100644 --- a/src/smt/smt_clause_proof.cpp +++ b/src/smt/smt_clause_proof.cpp @@ -90,14 +90,14 @@ namespace smt { return proof_ref(m); } - void clause_proof::add(clause& c) { + void clause_proof::add(clause& c, literal_buffer const* simp_lits) { if (!is_enabled()) return; justification* j = c.get_justification(); auto st = kind2st(c.get_kind()); auto pr = justification2proof(st, j); CTRACE("mk_clause", pr.get(), tout << mk_bounded_pp(pr, m, 4) << "\n";); - update(c, st, pr); + update(c, st, pr, simp_lits); } void clause_proof::add(unsigned n, literal const* lits, clause_kind k, justification* j) { @@ -137,12 +137,15 @@ namespace smt { update(st, m_lits, pr); } - void clause_proof::add(literal lit1, literal lit2, clause_kind k, justification* j) { + void clause_proof::add(literal lit1, literal lit2, clause_kind k, justification* j, literal_buffer const* simp_lits) { if (!is_enabled()) return; m_lits.reset(); m_lits.push_back(ctx.literal2expr(lit1)); m_lits.push_back(ctx.literal2expr(lit2)); + if (simp_lits) + for (auto lit : *simp_lits) + m_lits.push_back(ctx.literal2expr(~lit)); auto st = kind2st(k); auto pr = justification2proof(st, j); update(st, m_lits, pr); @@ -160,7 +163,7 @@ namespace smt { } void clause_proof::del(clause& c) { - update(c, status::deleted, justification2proof(status::deleted, nullptr)); + update(c, status::deleted, justification2proof(status::deleted, nullptr), nullptr); } std::ostream& clause_proof::display_literals(std::ostream& out, expr_ref_vector const& v) { @@ -190,7 +193,8 @@ namespace smt { if (ctx.get_fparams().m_clause_proof) m_trail.push_back(info(st, v, p)); if (m_on_clause_eh) - m_on_clause_eh(m_on_clause_ctx, p, v.size(), v.data()); + m_on_clause_eh(m_on_clause_ctx, p, 0, nullptr, v.size(), v.data()); + if (m_has_log) { init_pp_out(); auto& out = *m_pp_out; @@ -220,12 +224,15 @@ namespace smt { } } - void clause_proof::update(clause& c, status st, proof* p) { + void clause_proof::update(clause& c, status st, proof* p, literal_buffer const* simp_lits) { if (!is_enabled()) return; m_lits.reset(); for (literal lit : c) - m_lits.push_back(ctx.literal2expr(lit)); + m_lits.push_back(ctx.literal2expr(lit)); + if (simp_lits) + for (auto lit : *simp_lits) + m_lits.push_back(ctx.literal2expr(~lit)); update(st, m_lits, p); } diff --git a/src/smt/smt_clause_proof.h b/src/smt/smt_clause_proof.h index 1c593113669..d7cc421cfb2 100644 --- a/src/smt/smt_clause_proof.h +++ b/src/smt/smt_clause_proof.h @@ -68,7 +68,7 @@ namespace smt { void init_pp_out(); void update(status st, expr_ref_vector& v, proof* p); - void update(clause& c, status st, proof* p); + void update(clause& c, status st, proof* p, literal_buffer const* simp_lits); status kind2st(clause_kind k); proof_ref justification2proof(status st, justification* j); void log(status st, proof* p); @@ -79,8 +79,8 @@ namespace smt { clause_proof(context& ctx); void shrink(clause& c, unsigned new_size); void add(literal lit, clause_kind k, justification* j); - void add(literal lit1, literal lit2, clause_kind k, justification* j); - void add(clause& c); + void add(literal lit1, literal lit2, clause_kind k, justification* j, literal_buffer const* simp_lits = nullptr); + void add(clause& c, literal_buffer const* simp_lits = nullptr); void add(unsigned n, literal const* lits, clause_kind k, justification* j); void propagate(literal lit, justification const& j, literal_vector const& ante); void del(clause& c); diff --git a/src/smt/smt_conflict_resolution.cpp b/src/smt/smt_conflict_resolution.cpp index d075c06522b..7932434738f 100644 --- a/src/smt/smt_conflict_resolution.cpp +++ b/src/smt/smt_conflict_resolution.cpp @@ -601,6 +601,7 @@ namespace smt { finalize_resolve(conflict, not_l); + return true; } @@ -752,7 +753,7 @@ namespace smt { } else { if (j != i) { - m_lemma[j] = m_lemma[i]; + m_lemma[j] = l; m_lemma_atoms.set(j, m_lemma_atoms.get(i)); } j++; diff --git a/src/smt/smt_consequences.cpp b/src/smt/smt_consequences.cpp index 657e222dadf..fe2bd71494c 100644 --- a/src/smt/smt_consequences.cpp +++ b/src/smt/smt_consequences.cpp @@ -243,7 +243,7 @@ namespace smt { lit.neg(); literal lit = mk_diseq(k, v); - literals.push_back(lit); + literals.push_back(~lit); mk_clause(literals.size(), literals.data(), nullptr); TRACE("context", display_literals_verbose(tout, literals.size(), literals.data());); } @@ -274,6 +274,12 @@ namespace smt { expr_ref_vector& conseq, expr_ref_vector& unfixed) { + for (expr* a : assumptions0) + if (!m.is_bool(a)) { + std::string msg = std::string("assumption ") + mk_pp(a, m) + std::string(" is not Boolean"); + warning_msg(msg.c_str()); + throw default_exception(msg.c_str()); + } m_antecedents.reset(); m_antecedents.insert(true_literal.var(), index_set()); pop_to_base_lvl(); diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index d84535a9c86..62c7c8ddd10 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -36,6 +36,7 @@ Revision History: #include "smt/smt_quick_checker.h" #include "smt/uses_theory.h" #include "smt/theory_special_relations.h" +#include "smt/theory_polymorphism.h" #include "smt/smt_for_each_relevant_expr.h" #include "smt/smt_model_generator.h" #include "smt/smt_model_checker.h" @@ -153,9 +154,8 @@ namespace smt { src_af.get_macro_manager().copy_to(dst_af.get_macro_manager()); - if (!src_ctx.m_setup.already_configured()) { + if (!src_ctx.m_setup.already_configured()) return; - } for (unsigned i = 0; !src_m.proofs_enabled() && i < src_ctx.m_assigned_literals.size(); ++i) { literal lit = src_ctx.m_assigned_literals[i]; @@ -560,6 +560,7 @@ namespace smt { // Update "equivalence" class size r2->m_class_size += r1->m_class_size; + r2->m_is_shared = 2; CASSERT("add_eq", check_invariant()); } @@ -920,6 +921,7 @@ namespace smt { // restore r2 class size r2->m_class_size -= r1->m_class_size; + r2->m_is_shared = 2; // unmerge "equivalence" classes std::swap(r1->m_next, r2->m_next); @@ -1670,12 +1672,7 @@ namespace smt { } bool context::can_theories_propagate() const { - for (theory* t : m_theory_set) { - if (t->can_propagate()) { - return true; - } - } - return false; + return any_of(m_theory_set, [&](theory* t) { return t->can_propagate(); }); } bool context::can_propagate() const { @@ -2918,7 +2915,9 @@ namespace smt { bool context::has_split_candidate(bool_var& var, bool& is_pos) { if (!m_user_propagator) return false; - return m_user_propagator->get_case_split(var, is_pos); + if (!m_user_propagator->get_case_split(var, is_pos)) + return false; + return get_assignment(var) == l_undef; } bool context::decide_user_interference(bool_var& var, bool& is_pos) { @@ -3024,7 +3023,8 @@ namespace smt { SASSERT(is_well_sorted(m, e)); TRACE("begin_assert_expr", tout << mk_pp(e, m) << " " << mk_pp(pr, m) << "\n";); TRACE("begin_assert_expr_ll", tout << mk_ll_pp(e, m) << "\n";); - pop_to_base_lvl(); + if (!m_searching) + pop_to_base_lvl(); if (pr == nullptr) m_asserted_formulas.assert_expr(e); else @@ -3229,7 +3229,7 @@ namespace smt { } expr * f = m_asserted_formulas.get_formula(qhead); proof * pr = m_asserted_formulas.get_formula_proof(qhead); - SASSERT(!pr || f == m.get_fact(pr)); + SASSERT(!pr || f == m.get_fact(pr)); internalize_assertion(f, pr, 0); ++qhead; } @@ -3334,6 +3334,7 @@ namespace smt { reset_assumptions(); m_literal2assumption.reset(); m_unsat_core.reset(); + if (!asms.empty()) { // We must give a chance to the theories to propagate before we create a new scope... propagate(); @@ -3343,6 +3344,7 @@ namespace smt { return; if (get_cancel_flag()) return; + del_inactive_lemmas(); push_scope(); vector> asm2proxy; internalize_proxies(asms, asm2proxy); @@ -3465,13 +3467,9 @@ namespace smt { } if (r == l_true && gparams::get_value("model_validate") == "true") { recfun::util u(m); - model_ref mdl; - get_model(mdl); - if (u.get_rec_funs().empty()) { - if (mdl.get()) { - for (theory* t : m_theory_set) { - t->validate_model(*mdl); - } + if (u.get_rec_funs().empty() && m_proto_model) { + for (theory* t : m_theory_set) { + t->validate_model(*m_proto_model); } } #if 0 @@ -3564,7 +3562,6 @@ namespace smt { try { internalize_assertions(); } catch (cancel_exception&) { - VERIFY(resource_limits_exceeded()); return l_undef; } expr_ref_vector theory_assumptions(m); @@ -3636,7 +3633,6 @@ namespace smt { TRACE("unsat_core_bug", tout << asms << '\n';); init_assumptions(asms); } catch (cancel_exception&) { - VERIFY(resource_limits_exceeded()); return l_undef; } TRACE("before_search", display(tout);); @@ -3663,7 +3659,6 @@ namespace smt { for (auto const& clause : clauses) if (!validate_assumptions(clause)) return l_undef; init_assumptions(asms); } catch (cancel_exception&) { - VERIFY(resource_limits_exceeded()); return l_undef; } for (auto const& clause : clauses) init_clause(clause); @@ -3698,6 +3693,8 @@ namespace smt { m_phase_default = false; m_case_split_queue ->init_search_eh(); m_next_progress_sample = 0; + if (m.has_type_vars() && !m_theories.get_plugin(poly_family_id)) + register_plugin(alloc(theory_polymorphism, *this)); TRACE("literal_occ", display_literal_num_occs(tout);); } @@ -4120,7 +4117,6 @@ namespace smt { // Moreover, I backtrack only one level. bool delay_forced_restart = m_fparams.m_delay_units && - internalized_quantifiers() && num_lits == 1 && conflict_lvl > m_search_lvl + 1 && !m.proofs_enabled() && @@ -4269,9 +4265,11 @@ namespace smt { SASSERT(num_lits == 1); expr * unit = bool_var2expr(lits[0].var()); bool unit_sign = lits[0].sign(); + while (m.is_not(unit, unit)) + unit_sign = !unit_sign; m_units_to_reassert.push_back(unit); m_units_to_reassert_sign.push_back(unit_sign); - TRACE("reassert_units", tout << "asserting #" << unit->get_id() << " " << unit_sign << " @ " << m_scope_lvl << "\n";); + TRACE("reassert_units", tout << "asserting " << mk_pp(unit, m) << " #" << unit->get_id() << " " << unit_sign << " @ " << m_scope_lvl << "\n";); } m_conflict_resolution->release_lemma_atoms(); @@ -4504,8 +4502,15 @@ namespace smt { bool context::is_shared(enode * n) const { n = n->get_root(); + switch (n->is_shared()) { + case l_true: return true; + case l_false: return false; + default: break; + } + unsigned num_th_vars = n->get_num_th_vars(); if (m.is_ite(n->get_expr())) { + n->set_is_shared(l_true); return true; } switch (num_th_vars) { @@ -4531,6 +4536,7 @@ namespace smt { TRACE("is_shared", tout << enode_pp(n, *this) << "\nis shared because of:\n" << enode_pp(parent, *this) << "\n";); + n->set_is_shared(l_true); return true; } } @@ -4561,7 +4567,9 @@ namespace smt { // the theories of (array int int) and (array (array int int) int). // Remark: The inconsistency is not going to be detected if they are // not marked as shared. - return get_theory(th_id)->is_shared(l->get_var()); + bool r = get_theory(th_id)->is_shared(l->get_var()); + n->set_is_shared(to_lbool(r)); + return r; } default: return true; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 7a267fdeca1..4feeae1b50c 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -61,9 +61,16 @@ Revision History: namespace smt { class model_generator; + class context; struct cancel_exception {}; + struct enode_pp { + context const& ctx; + enode* n; + enode_pp(enode* n, context const& ctx): ctx(ctx), n(n) {} + }; + class context { friend class model_generator; friend class lookahead; @@ -134,7 +141,6 @@ namespace smt { enode * m_lhs; enode * m_rhs; eq_justification m_justification; - new_eq() {} new_eq(enode * lhs, enode * rhs, eq_justification const & js): m_lhs(lhs), m_rhs(rhs), m_justification(js) {} }; @@ -143,7 +149,6 @@ namespace smt { theory_id m_th_id; theory_var m_lhs; theory_var m_rhs; - new_th_eq():m_th_id(null_theory_id), m_lhs(null_theory_var), m_rhs(null_theory_var) {} new_th_eq(theory_id id, theory_var l, theory_var r):m_th_id(id), m_lhs(l), m_rhs(r) {} }; svector m_th_eq_propagation_queue; @@ -215,7 +220,7 @@ namespace smt { // ----------------------------------- proto_model_ref m_proto_model; model_ref m_model; - std::string m_unknown; + const char * m_unknown; void mk_proto_model(); void reset_model() { m_model = nullptr; m_proto_model = nullptr; } @@ -1370,6 +1375,8 @@ namespace smt { void display_asserted_formulas(std::ostream & out) const; + enode_pp pp(enode* n) { return enode_pp(n, *this); } + std::ostream& display_literal(std::ostream & out, literal l) const; std::ostream& display_detailed_literal(std::ostream & out, literal l) const { return smt::display(out, l, m, m_bool_var2expr.data()); } @@ -1846,11 +1853,6 @@ namespace smt { std::ostream& operator<<(std::ostream& out, enode_eq_pp const& p); - struct enode_pp { - context const& ctx; - enode* n; - enode_pp(enode* n, context const& ctx): ctx(ctx), n(n) {} - }; std::ostream& operator<<(std::ostream& out, enode_pp const& p); diff --git a/src/smt/smt_context_pp.cpp b/src/smt/smt_context_pp.cpp index a6088fdf7df..57875a8bdfa 100644 --- a/src/smt/smt_context_pp.cpp +++ b/src/smt/smt_context_pp.cpp @@ -66,6 +66,7 @@ namespace smt { std::string context::last_failure_as_string() const { std::string r; switch(m_last_search_failure) { + case UNKNOWN: case OK: r = m_unknown; break; case MEMOUT: r = "memout"; break; case CANCELED: r = "canceled"; break; @@ -82,7 +83,6 @@ namespace smt { case RESOURCE_LIMIT: r = "(resource limits reached)"; break; case QUANTIFIERS: r = "(incomplete quantifiers)"; break; case LAMBDAS: r = "(incomplete lambdas)"; break; - case UNKNOWN: r = m_unknown; break; } return r; } @@ -132,7 +132,7 @@ namespace smt { void context::display_literal_info(std::ostream & out, literal l) const { smt::display_compact(out, l, m_bool_var2expr.data()); - display_literal_smt2(out, l); + display_literal_smt2(out << " " << l << ": ", l); out << "relevant: " << is_relevant(bool_var2expr(l.var())) << ", val: " << get_assignment(l) << "\n"; } @@ -510,7 +510,7 @@ namespace smt { #else strm << "lemma_" << (++m_lemma_id) << ".smt2"; #endif - return strm.str(); + return std::move(strm).str(); } @@ -635,7 +635,7 @@ namespace smt { literal_vector lits; const_cast(*m_conflict_resolution).justification2literals(j.get_justification(), lits); out << "justification " << j.get_justification()->get_from_theory() << ": "; - // display_literals_smt2(out, lits); + display_literals_smt2(out, lits); break; } default: @@ -722,7 +722,7 @@ namespace smt { << std::setw(4) << m_stats.m_num_del_clauses << " " << std::setw(7) << mem_stat() << ")\n"; - std::string str(strm.str()); + std::string str = std::move(strm).str(); svector offsets; for (size_t i = 0; i < str.size(); ++i) { while (i < str.size() && str[i] != ' ') ++i; diff --git a/src/smt/smt_enode.cpp b/src/smt/smt_enode.cpp index 49f05b01931..86b83af4c9b 100644 --- a/src/smt/smt_enode.cpp +++ b/src/smt/smt_enode.cpp @@ -49,10 +49,12 @@ namespace smt { n->m_iscope_lvl = iscope_lvl; n->m_lbl_hash = -1; n->m_proof_is_logged = false; + n->m_is_shared = 2; unsigned num_args = n->get_num_args(); for (unsigned i = 0; i < num_args; i++) { enode * arg = app2enode[owner->get_arg(i)->get_id()]; n->m_args[i] = arg; + arg->get_root()->m_is_shared = 2; SASSERT(n->get_arg(i) == arg); if (update_children_parent) arg->get_root()->m_parents.push_back(n); diff --git a/src/smt/smt_enode.h b/src/smt/smt_enode.h index b3a3bbf69a4..92902ea0b1b 100644 --- a/src/smt/smt_enode.h +++ b/src/smt/smt_enode.h @@ -77,6 +77,7 @@ namespace smt { unsigned m_bool:1; //!< True if it is a boolean enode unsigned m_merge_tf:1; //!< True if the enode should be merged with true/false when the associated boolean variable is assigned. unsigned m_cgc_enabled:1; //!< True if congruence closure is enabled for this enode. + unsigned m_is_shared:2; //!< 0 - not shared, 1 - shared, 2 - invalid state unsigned m_iscope_lvl; //!< When the enode was internalized bool m_proof_is_logged; //!< Indicates that the proof for the enode being equal to its root is in the log. signed char m_lbl_hash; //!< It is different from -1, if enode is used in a pattern @@ -179,6 +180,21 @@ namespace smt { return m_owner->hash(); } + lbool is_shared() const { + switch (m_is_shared) { + case 0: return l_false; + case 1: return l_true; + default: return l_undef; + } + } + + void set_is_shared(lbool s) { + switch (s) { + case l_true: m_is_shared = 1; break; + case l_false: m_is_shared = 0; break; + default: m_is_shared = 2; break; + } + } enode * get_root() const { return m_root; diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index 68879b8aced..2b18d9a3f33 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -297,9 +297,10 @@ namespace smt { void context::assert_default(expr * n, proof * pr) { internalize(n, true); literal l = get_literal(n); - if (l == false_literal) { + if (l == false_literal) set_conflict(mk_justification(justification_proof_wrapper(*this, pr))); - } + else if (l == true_literal) + return; else { justification* j = mk_justification(justification_proof_wrapper(*this, pr)); m_clause_proof.add(l, CLS_AUX, j); @@ -1377,12 +1378,12 @@ namespace smt { clause * context::mk_clause(unsigned num_lits, literal * lits, justification * j, clause_kind k, clause_del_eh * del_eh) { TRACE("mk_clause", display_literals_verbose(tout << "creating clause: " << literal_vector(num_lits, lits) << "\n", num_lits, lits) << "\n";); m_clause_proof.add(num_lits, lits, k, j); + literal_buffer simp_lits; switch (k) { case CLS_TH_AXIOM: dump_axiom(num_lits, lits); Z3_fallthrough; case CLS_AUX: { - literal_buffer simp_lits; if (m_searching) dump_lemma(num_lits, lits); if (!simplify_aux_clause_literals(num_lits, lits, simp_lits)) { @@ -1450,7 +1451,7 @@ namespace smt { else if (get_assignment(l2) == l_false) { assign(l1, b_justification(~l2)); } - m_clause_proof.add(l1, l2, k, j); + m_clause_proof.add(l1, l2, k, j, &simp_lits); m_stats.m_num_mk_bin_clause++; return nullptr; } @@ -1463,7 +1464,7 @@ namespace smt { bool reinit = save_atoms; SASSERT(!lemma || j == 0 || !j->in_region()); clause * cls = clause::mk(m, num_lits, lits, k, j, del_eh, save_atoms, m_bool_var2expr.data()); - m_clause_proof.add(*cls); + m_clause_proof.add(*cls, &simp_lits); if (lemma) { cls->set_activity(activity); if (k == CLS_LEARNED) { @@ -1829,7 +1830,7 @@ namespace smt { // Case) there is a variable old_v in the var-list of n. // // Remark: This variable was moved to the var-list of n due to a add_eq. - SASSERT(th->get_enode(old_v) != n); // this varialbe is not owned by n + SASSERT(th->get_enode(old_v) != n); // this variable is not owned by n SASSERT(n->get_root()->get_th_var(th_id) != null_theory_var); // the root has also a variable in its var-list. n->replace_th_var(v, th_id); push_trail(replace_th_var_trail( n, th_id, old_v)); diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index c4ecf6787b1..74f0bded65d 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -63,8 +63,8 @@ namespace smt { return m_imp->m_kernel.get_manager(); } - void kernel::copy(kernel& src, kernel& dst) { - context::copy(src.m_imp->m_kernel, dst.m_imp->m_kernel); + void kernel::copy(kernel& src, kernel& dst, bool override_base) { + context::copy(src.m_imp->m_kernel, dst.m_imp->m_kernel, override_base); } bool kernel::set_logic(symbol logic) { diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index ccea5caf8f8..dacbb525e09 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -50,7 +50,7 @@ namespace smt { ~kernel(); - static void copy(kernel& src, kernel& dst); + static void copy(kernel& src, kernel& dst, bool override_base); ast_manager & m() const; diff --git a/src/smt/smt_model_finder.cpp b/src/smt/smt_model_finder.cpp index 94a0992f3ba..4139b310971 100644 --- a/src/smt/smt_model_finder.cpp +++ b/src/smt/smt_model_finder.cpp @@ -574,9 +574,9 @@ namespace smt { to_delete.push_back(n); } } - for (expr* e : to_delete) { + for (expr* e : to_delete) s->remove(e); - } + reset_eval_cache(); } } } diff --git a/src/smt/smt_relevancy.cpp b/src/smt/smt_relevancy.cpp index 649d75737d7..ce6de3ec3eb 100644 --- a/src/smt/smt_relevancy.cpp +++ b/src/smt/smt_relevancy.cpp @@ -123,7 +123,7 @@ namespace smt { } struct relevancy_propagator_imp : public relevancy_propagator { - unsigned m_qhead; + unsigned m_qhead = 0; expr_ref_vector m_relevant_exprs; uint_set m_is_relevant; typedef list relevancy_ehs; @@ -144,14 +144,18 @@ namespace smt { unsigned m_trail_lim; }; svector m_scopes; - bool m_propagating; + bool m_propagating = false; relevancy_propagator_imp(context & ctx): - relevancy_propagator(ctx), m_qhead(0), m_relevant_exprs(ctx.get_manager()), - m_propagating(false) {} + relevancy_propagator(ctx), m_relevant_exprs(ctx.get_manager()) {} ~relevancy_propagator_imp() override { - undo_trail(0); + ast_manager & m = get_manager(); + unsigned i = m_trail.size(); + while (i != 0) { + --i; + m.dec_ref(m_trail[i].get_node()); + } } relevancy_ehs * get_handlers(expr * n) { diff --git a/src/smt/smt_relevancy.h b/src/smt/smt_relevancy.h index f64b7d05973..f6e3c46592f 100644 --- a/src/smt/smt_relevancy.h +++ b/src/smt/smt_relevancy.h @@ -41,13 +41,14 @@ namespace smt { /** \brief Fallback for the two previous methods. */ - virtual void operator()(relevancy_propagator & rp) {} + virtual void operator()(relevancy_propagator & rp) = 0; }; class simple_relevancy_eh : public relevancy_eh { expr * m_target; public: simple_relevancy_eh(expr * t):m_target(t) {} + ~simple_relevancy_eh() override {} void operator()(relevancy_propagator & rp) override; }; @@ -60,6 +61,7 @@ namespace smt { expr * m_target; public: pair_relevancy_eh(expr * s1, expr * s2, expr * t):m_source1(s1), m_source2(s2), m_target(t) {} + ~pair_relevancy_eh() override {} void operator()(relevancy_propagator & rp) override; }; diff --git a/src/smt/smt_setup.cpp b/src/smt/smt_setup.cpp index 4b02d172123..c0a4158f3ff 100644 --- a/src/smt/smt_setup.cpp +++ b/src/smt/smt_setup.cpp @@ -38,6 +38,7 @@ Revision History: #include "smt/theory_pb.h" #include "smt/theory_fpa.h" #include "smt/theory_str.h" +#include "smt/theory_polymorphism.h" #include "smt/theory_str_noodler/theory_str_noodler.h" namespace smt { @@ -797,6 +798,11 @@ namespace smt { m_context.register_plugin(alloc(smt::theory_special_relations, m_context, m_manager)); } + void setup::setup_polymorphism() { + if (m_manager.has_type_vars()) + m_context.register_plugin(alloc(theory_polymorphism, m_context)); + } + void setup::setup_unknown() { static_features st(m_manager); ptr_vector fmls; @@ -812,6 +818,7 @@ namespace smt { setup_seq_str(st); setup_fpa(); setup_special_relations(); + setup_polymorphism(); } void setup::setup_unknown(static_features & st) { @@ -828,6 +835,7 @@ namespace smt { setup_fpa(); setup_recfuns(); setup_special_relations(); + setup_polymorphism(); return; } diff --git a/src/smt/smt_setup.h b/src/smt/smt_setup.h index dee1fb9a382..d5690b1df3e 100644 --- a/src/smt/smt_setup.h +++ b/src/smt/smt_setup.h @@ -82,6 +82,7 @@ namespace smt { void setup_LRA(); void setup_CSP(); void setup_special_relations(); + void setup_polymorphism(); void setup_AUFLIA(bool simple_array = true); void setup_AUFLIA(static_features const & st); void setup_AUFLIRA(bool simple_array = true); diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 4be78b20a75..f91a3111184 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -20,7 +20,7 @@ Module Name: #include "util/dec_ref_util.h" #include "ast/reg_decl_plugins.h" #include "ast/for_each_expr.h" -#include "ast/ast_smt2_pp.h" +#include "ast/ast_pp.h" #include "ast/func_decl_dependencies.h" #include "smt/smt_kernel.h" #include "smt/params/smt_params.h" @@ -88,14 +88,14 @@ namespace { ast_translation translator(get_manager(), m); smt_solver * result = alloc(smt_solver, m, p, m_logic); - smt::kernel::copy(m_context, result->m_context); + smt::kernel::copy(m_context, result->m_context, true); if (mc0()) result->set_model_converter(mc0()->translate(translator)); - for (auto & kv : m_name2assertion) { - expr* val = translator(kv.m_value); - expr* key = translator(kv.m_key); + for (auto & [k, v] : m_name2assertion) { + expr* val = translator(k); + expr* key = translator(v); result->assert_expr(val, key); } @@ -104,9 +104,9 @@ namespace { ~smt_solver() override { dealloc(m_cuber); - for (auto& kv : m_name2assertion) { - get_manager().dec_ref(kv.m_key); - get_manager().dec_ref(kv.m_value); + for (auto& [k,v] : m_name2assertion) { + get_manager().dec_ref(k); + get_manager().dec_ref(v); } } @@ -267,6 +267,7 @@ namespace { }; void get_unsat_core(expr_ref_vector & r) override { + unsigned sz = m_context.get_unsat_core_size(); for (unsigned i = 0; i < sz; i++) { r.push_back(m_context.get_unsat_core_expr(i)); diff --git a/src/smt/smt_theory.h b/src/smt/smt_theory.h index c715f215f1e..d0e73cc9296 100644 --- a/src/smt/smt_theory.h +++ b/src/smt/smt_theory.h @@ -368,7 +368,8 @@ namespace smt { // // ---------------------------------------------------- - virtual void validate_model(model& mdl) {} + virtual void validate_model(proto_model& mdl) {} + // ---------------------------------------------------- // diff --git a/src/smt/theory_arith.h b/src/smt/theory_arith.h index 92aa4baec66..86c05aec64e 100644 --- a/src/smt/theory_arith.h +++ b/src/smt/theory_arith.h @@ -301,7 +301,7 @@ namespace smt { inf_numeral const & get_value() const { return m_value; } virtual bool has_justification() const { return false; } virtual void push_justification(antecedents& antecedents, numeral const& coeff, bool proofs_enabled) {} - virtual void display(theory_arith const& th, std::ostream& out) const; + virtual std::ostream& display(theory_arith const& th, std::ostream& out) const; }; @@ -327,7 +327,7 @@ namespace smt { void push_justification(antecedents& a, numeral const& coeff, bool proofs_enabled) override { a.push_lit(literal(get_bool_var(), !m_is_true), coeff, proofs_enabled); } - void display(theory_arith const& th, std::ostream& out) const override; + std::ostream& display(theory_arith const& th, std::ostream& out) const override; }; class eq_bound : public bound { @@ -345,7 +345,7 @@ namespace smt { SASSERT(m_lhs->get_root() == m_rhs->get_root()); a.push_eq(enode_pair(m_lhs, m_rhs), coeff, proofs_enabled); } - void display(theory_arith const& th, std::ostream& out) const override; + std::ostream& display(theory_arith const& th, std::ostream& out) const override; }; class derived_bound : public bound { @@ -361,7 +361,7 @@ namespace smt { void push_justification(antecedents& a, numeral const& coeff, bool proofs_enabled) override; virtual void push_lit(literal l, numeral const&) { m_lits.push_back(l); } virtual void push_eq(enode_pair const& p, numeral const&) { m_eqs.push_back(p); } - void display(theory_arith const& th, std::ostream& out) const override; + std::ostream& display(theory_arith const& th, std::ostream& out) const override; }; @@ -788,7 +788,7 @@ namespace smt { // // ----------------------------------- void mark_row_for_bound_prop(unsigned r1); - void mark_rows_for_bound_prop(theory_var v); + void add_column_rows_to_touched_rows(theory_var v); void is_row_useful_for_bound_prop(row const & r, int & lower_idx, int & upper_idx) const; unsigned imply_bound_for_monomial(row const & r, int idx, bool lower); unsigned imply_bound_for_all_monomials(row const & r, bool lower); @@ -824,7 +824,7 @@ namespace smt { unsigned m_assume_eq_head = 0; bool random_update(theory_var v); void mutate_assignment(); - bool assume_eqs_core(); + bool assume_eqs(); bool delayed_assume_eqs(); // ----------------------------------- diff --git a/src/smt/theory_arith_aux.h b/src/smt/theory_arith_aux.h index 470ea5f7b4a..8141377c455 100644 --- a/src/smt/theory_arith_aux.h +++ b/src/smt/theory_arith_aux.h @@ -370,8 +370,8 @@ namespace smt { template - void theory_arith::bound::display(theory_arith const& th, std::ostream& out) const { - out << "v" << get_var() << " " << get_bound_kind() << " " << get_value(); + std::ostream& theory_arith::bound::display(theory_arith const& th, std::ostream& out) const { + return out << "v" << get_var() << " " << get_bound_kind() << " " << get_value(); } @@ -414,11 +414,10 @@ namespace smt { } template - void theory_arith::atom::display(theory_arith const& th, std::ostream& out) const { + std::ostream& theory_arith::atom::display(theory_arith const& th, std::ostream& out) const { literal l(get_bool_var(), !m_is_true); - // out << "v" << bound::get_var() << " " << bound::get_bound_kind() << " " << get_k() << " "; - // out << l << ":"; th.ctx.display_detailed_literal(out, l); + return out; } // ----------------------------------- @@ -428,10 +427,10 @@ namespace smt { // ----------------------------------- template - void theory_arith::eq_bound::display(theory_arith const& th, std::ostream& out) const { + std::ostream& theory_arith::eq_bound::display(theory_arith const& th, std::ostream& out) const { ast_manager& m = th.get_manager(); - out << "#" << m_lhs->get_owner_id() << " " << mk_pp(m_lhs->get_expr(), m) << " = " - << "#" << m_rhs->get_owner_id() << " " << mk_pp(m_rhs->get_expr(), m); + return out << "#" << m_lhs->get_owner_id() << " " << mk_pp(m_lhs->get_expr(), m) << " = " + << "#" << m_rhs->get_owner_id() << " " << mk_pp(m_rhs->get_expr(), m); } // ----------------------------------- @@ -752,7 +751,7 @@ namespace smt { } template - void theory_arith::derived_bound::display(theory_arith const& th, std::ostream& out) const { + std::ostream& theory_arith::derived_bound::display(theory_arith const& th, std::ostream& out) const { ast_manager& m = th.get_manager(); out << "v" << bound::get_var() << " " << bound::get_bound_kind() << " " << bound::get_value() << "\n"; out << "expr: " << mk_pp(th.var2expr(bound::get_var()), m) << "\n"; @@ -765,8 +764,9 @@ namespace smt { << "#" << b->get_owner_id() << " " << mk_pp(b->get_expr(), m) << "\n"; } for (literal l : m_lits) { - out << l << ":"; th.ctx.display_detailed_literal(out, l) << "\n"; + out << l << ":"; th.ctx.display_detailed_literal(out, l) << "\n"; } + return out; } @@ -2195,33 +2195,27 @@ namespace smt { } template - bool theory_arith::assume_eqs_core() { + bool theory_arith::assume_eqs() { // See comment in m_liberal_final_check declaration if (m_liberal_final_check) mutate_assignment(); TRACE("assume_eq_int", display(tout);); unsigned old_sz = m_assume_eq_candidates.size(); - TRACE("func_interp_bug", display(tout);); m_var_value_table.reset(); bool result = false; int num = get_num_vars(); for (theory_var v = 0; v < num; v++) { enode * n = get_enode(v); - TRACE("func_interp_bug", tout << mk_pp(n->get_expr(), get_manager()) << " -> " << m_value[v] << " root #" << n->get_root()->get_owner_id() << " " << is_relevant_and_shared(n) << "\n";); - if (!is_relevant_and_shared(n)) { + if (!is_relevant_and_shared(n)) continue; - } theory_var other = null_theory_var; other = m_var_value_table.insert_if_not_there(v); - if (other == v) { + if (other == v) continue; - } enode * n2 = get_enode(other); - if (n->get_root() == n2->get_root()) { + if (n->get_root() == n2->get_root()) continue; - } - TRACE("func_interp_bug", tout << "adding to assume_eq queue #" << n->get_owner_id() << " #" << n2->get_owner_id() << "\n";); m_assume_eq_candidates.push_back({ other , v }); result = true; } @@ -2242,10 +2236,9 @@ namespace smt { enode* n1 = get_enode(v1); enode* n2 = get_enode(v2); m_assume_eq_head++; - CTRACE("func_interp_bug", - get_value(v1) == get_value(v2) && - n1->get_root() != n2->get_root(), - tout << "assuming eq: #" << n1->get_owner_id() << " = #" << n2->get_owner_id() << "\n";); + CTRACE("arith", + get_value(v1) == get_value(v2) && n1->get_root() != n2->get_root(), + tout << "assuming eq: " << ctx.pp(n1) << " = #" << ctx.pp(n2) << "\n";); if (get_value(v1) == get_value(v2) && n1->get_root() != n2->get_root() && assume_eq(n1, n2)) { diff --git a/src/smt/theory_arith_core.h b/src/smt/theory_arith_core.h index 159758c80a1..4e4464a5226 100644 --- a/src/smt/theory_arith_core.h +++ b/src/smt/theory_arith_core.h @@ -17,7 +17,7 @@ Revision History: --*/ #pragma once - +// clang-format off #include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" #include "ast/ast_smt2_pp.h" @@ -1167,10 +1167,7 @@ namespace smt { --i; } } - CTRACE("arith", atoms.size() > 1, - for (unsigned i = 0; i < atoms.size(); ++i) { - atoms[i]->display(*this, tout); tout << "\n"; - }); + CTRACE("arith", atoms.size() > 1, for (auto* a : atoms) a->display(*this, tout) << "\n";); ptr_vector occs(m_var_occs[v]); std::sort(atoms.begin(), atoms.end(), compare_atoms()); @@ -1277,7 +1274,7 @@ namespace smt { template bool theory_arith::internalize_atom(app * n, bool gate_ctx) { - TRACE("arith_internalize", tout << "internalizing atom:\n" << mk_pp(n, m) << "\n";); + TRACE("arith_internalize", tout << "internalizing atom:\n" << mk_bounded_pp(n, m) << "\n";); SASSERT(m_util.is_le(n) || m_util.is_ge(n) || m_util.is_is_int(n)); SASSERT(!ctx.b_internalized(n)); atom_kind kind; @@ -1302,7 +1299,8 @@ namespace smt { app * lhs = to_app(n->get_arg(0)); app * rhs = to_app(n->get_arg(1)); expr * rhs2; - if (m_util.is_to_real(rhs, rhs2) && is_app(rhs2)) { rhs = to_app(rhs2); } + if (m_util.is_to_real(rhs, rhs2) && is_app(rhs2)) + rhs = to_app(rhs2); if (!m_util.is_numeral(rhs)) { throw default_exception("malformed atomic constraint"); } @@ -1335,16 +1333,14 @@ namespace smt { occs.push_back(a); m_atoms.push_back(a); insert_bv2a(bv, a); - TRACE("arith_internalize", tout << "succeeded... v" << v << " " << kind << " " << k << "\n"; - for (unsigned i = 0; i + 1 < occs.size(); ++i) tout << occs[i] << "\n";); + TRACE("arith_internalize", tout << "succeeded... v" << v << " " << kind << " " << k << "\n"); return true; } template bool theory_arith::internalize_term(app * term) { - TRACE("arith_internalize", tout << "internalising term:\n" << mk_pp(term, m) << "\n";); theory_var v = internalize_term_core(term); - TRACE("arith_internalize", tout << "theory_var: " << v << "\n";); + TRACE("arith_internalize", tout << "internalising term: v" << v << " " << mk_bounded_pp(term, m) << "\n";); return v != null_theory_var; } @@ -1375,9 +1371,9 @@ namespace smt { template void theory_arith::assign_eh(bool_var v, bool is_true) { - TRACE("arith_verbose", tout << "p" << v << " := " << (is_true?"true":"false") << "\n";); atom * a = get_bv2a(v); if (!a) return; + TRACE("arith", tout << "assign p" << literal(v,!is_true) << " : " << mk_bounded_pp(ctx.bool_var2expr(v), m) << "\n";); SASSERT(ctx.get_assignment(a->get_bool_var()) != l_undef); SASSERT((ctx.get_assignment(a->get_bool_var()) == l_true) == is_true); a->assign_eh(is_true, get_epsilon(a->get_var())); @@ -1401,9 +1397,7 @@ namespace smt { template void theory_arith::new_eq_eh(theory_var v1, theory_var v2) { - TRACE("arith_new_eq_eh", tout << "#" << get_enode(v1)->get_owner_id() << " = #" << get_enode(v2)->get_owner_id() << "\n";); - TRACE("arith_new_eq_eh_detail", tout << mk_pp(get_enode(v1)->get_expr(), m) << "\n" << - mk_pp(get_enode(v2)->get_expr(), m) << "\n";); + TRACE("arith_new_eq_eh", tout << ctx.pp(get_enode(v1)) << "\n" << ctx.pp(get_enode(v2)) << "\n";); enode * n1 = get_enode(v1); @@ -1503,7 +1497,7 @@ namespace smt { TRACE("arith", tout << "check_int_feasibility(), ok: " << ok << "\n";); break; case 1: - if (assume_eqs_core()) + if (assume_eqs()) ok = FC_CONTINUE; else ok = FC_DONE; @@ -1571,7 +1565,6 @@ namespace smt { template void theory_arith::propagate() { - TRACE("arith_propagate", tout << "propagate\n"; display(tout);); if (!process_atoms()) return; propagate_core(); @@ -1597,9 +1590,9 @@ namespace smt { failed(); return false; } - if (ctx.get_cancel_flag()) { + if (ctx.get_cancel_flag()) return true; - } + CASSERT("arith", satisfy_bounds()); discard_update_trail(); @@ -2486,7 +2479,7 @@ namespace smt { set_bound(b, false); if (propagation_mode() != bound_prop_mode::BP_NONE) - mark_rows_for_bound_prop(v); + add_column_rows_to_touched_rows(v); return true; } @@ -2534,7 +2527,7 @@ namespace smt { set_bound(b, true); if (propagation_mode() != bound_prop_mode::BP_NONE) - mark_rows_for_bound_prop(v); + add_column_rows_to_touched_rows(v); return true; } @@ -2603,7 +2596,7 @@ namespace smt { \brief Mark all rows that contain v for bound propagation. */ template - void theory_arith::mark_rows_for_bound_prop(theory_var v) { + void theory_arith::add_column_rows_to_touched_rows(theory_var v) { for (col_entry const& ce : m_columns[v]) { if (!ce.is_dead()) mark_row_for_bound_prop(ce.m_row_id); @@ -2812,7 +2805,8 @@ namespace smt { template void theory_arith::explain_bound(row const & r, int idx, bool is_lower, inf_numeral & delta, antecedents& ante) { SASSERT(delta >= inf_numeral::zero()); - TRACE("arith_conflict", tout << "relax: " << relax_bounds() << " lits: " << ante.lits().size() << " eqs: " << ante.eqs().size() << " idx: " << idx << "\n";); + TRACE("arith_conflict", tout << "delta: " << delta << " relax: " << relax_bounds() << " lits: " << ante.lits().size() << " eqs: " << ante.eqs().size() << " idx: " << idx << "\n";); + if (!relax_bounds() && (!ante.lits().empty() || !ante.eqs().empty())) { return; } @@ -3002,12 +2996,10 @@ namespace smt { TRACE("propagate_bounds", ante.display(tout) << " --> "; - ctx.display_detailed_literal(tout, l); - tout << "\n";); - - + ctx.display_detailed_literal(tout, l) << "\n"); - TRACE("arith", tout << ctx.get_scope_level() << "\n"; + TRACE("arith", tout << "@" << ctx.get_scope_level() << ": "; + ante.display(tout) << " --> "; ctx.display_detailed_literal(tout, l) << "\n"); if (ante.lits().size() < small_lemma_size() && ante.eqs().empty()) { @@ -3078,7 +3070,6 @@ namespace smt { } } - TRACE("arith_eq", tout << "done\n";); m_to_check.reset(); m_in_to_check.reset(); } @@ -3108,7 +3099,7 @@ namespace smt { TRACE("arith_conflict", if (proof_rule) tout << proof_rule << "\n"; - tout << "scope: " << ctx.get_scope_level() << "\n"; + tout << "@" << ctx.get_scope_level() << "\n"; for (unsigned i = 0; i < num_literals; i++) { ctx.display_detailed_literal(tout, lits[i]); tout << " "; @@ -3392,7 +3383,7 @@ namespace smt { } template - void theory_arith::pop_scope_eh(unsigned num_scopes) { + void theory_arith::pop_scope_eh(unsigned num_scopes) { CASSERT("arith", wf_rows()); CASSERT("arith", wf_columns()); CASSERT("arith", valid_row_assignment()); diff --git a/src/smt/theory_arith_eq.h b/src/smt/theory_arith_eq.h index cffac245988..ebdd6e73eec 100644 --- a/src/smt/theory_arith_eq.h +++ b/src/smt/theory_arith_eq.h @@ -43,7 +43,6 @@ namespace smt { return; numeral const & val = lower_bound(v).get_rational(); value_sort_pair key(val, is_int_src(v)); - TRACE("arith_eq", tout << mk_pp(get_enode(v)->get_expr(), get_manager()) << " = " << val << "\n";); theory_var v2; if (m_fixed_var_table.find(key, v2)) { if (v2 < static_cast(get_num_vars()) && is_fixed(v2) && lower_bound(v2).get_rational() == val) { @@ -310,26 +309,22 @@ namespace smt { } // add new entry m_var_offset2row_id.insert(key, rid); - } - + } } template void theory_arith::propagate_eq_to_core(theory_var x, theory_var y, antecedents& antecedents) { // Ignore equality if variables are already known to be equal. - ast_manager& m = get_manager(); - (void)m; if (is_equal(x, y)) return; - // I doesn't make sense to propagate an equality (to the core) of variables of different sort. - if (var2expr(x)->get_sort() != var2expr(y)->get_sort()) { - TRACE("arith", tout << mk_pp(var2expr(x), m) << " = " << mk_pp(var2expr(y), m) << "\n";); - return; - } - context & ctx = get_context(); enode * _x = get_enode(x); enode * _y = get_enode(y); + // I doesn't make sense to propagate an equality (to the core) of variables of different sort. + CTRACE("arith", _x->get_sort() != _y->get_sort(), tout << enode_pp(_x, ctx) << " = " << enode_pp(_y, ctx) << "\n"); + if (_x->get_sort() != _y->get_sort()) + return; + eq_vector const& eqs = antecedents.eqs(); literal_vector const& lits = antecedents.lits(); justification * js = @@ -346,9 +341,9 @@ namespace smt { for (literal lit : lits) ctx.display_detailed_literal(tout, lit) << "\n"; for (auto const& p : eqs) - tout << pp(p.first, m) << " = " << pp(p.second, m) << "\n"; + tout << enode_pp(p.first, ctx) << " = " << enode_pp(p.second, ctx) << "\n"; tout << " ==> "; - tout << pp(_x, m) << " = " << pp(_y, m) << "\n";); + tout << enode_pp(_x, ctx) << " = " << enode_pp(_y, ctx) << "\n";); ctx.assign_eq(_x, _y, eq_justification(js)); } }; diff --git a/src/smt/theory_arith_int.h b/src/smt/theory_arith_int.h index c9bc9f31abb..75c8785a7a6 100644 --- a/src/smt/theory_arith_int.h +++ b/src/smt/theory_arith_int.h @@ -16,6 +16,7 @@ Module Name: Revision History: --*/ +// clang-format off #pragma once #include "util/numeral_buffer.h" @@ -98,7 +99,7 @@ namespace smt { theory_var result = null_theory_var; numeral range; numeral new_range; - numeral small_range_thresold(1024); + numeral small_range_threshold(1024); unsigned n = 0; for (row const& row : m_rows) { theory_var v = row.get_base_var(); @@ -116,7 +117,7 @@ namespace smt { numeral const & u = upper_bound(v).get_rational(); new_range = u; new_range -= l; - if (new_range > small_range_thresold) { + if (new_range > small_range_threshold) { // } else if (result == null_theory_var || new_range < range) { @@ -145,13 +146,10 @@ namespace smt { */ template theory_var theory_arith::find_infeasible_int_base_var() { - theory_var v = find_bounded_infeasible_int_base_var(); - if (v != null_theory_var) { - TRACE("find_infeasible_int_base_var", display_var(tout, v);); - return v; - } + theory_var r = find_bounded_infeasible_int_base_var(); + CTRACE("find_infeasible_int_base_var", r != null_theory_var, display_var(tout << "bounded infeasible", r);); + unsigned n = 0; - theory_var r = null_theory_var; #define SELECT_VAR(VAR) if (r == null_theory_var) { n = 1; r = VAR; } else { n++; SASSERT(n >= 2); if (m_random() % n == 0) r = VAR; } @@ -171,6 +169,7 @@ namespace smt { } } } + CTRACE("find_infeasible_int_base_var", r != null_theory_var, tout << "found small value v" << r << "\n"); } if (r == null_theory_var) { @@ -180,6 +179,8 @@ namespace smt { SELECT_VAR(v); } } + CTRACE("find_infeasible_int_base_var", r != null_theory_var, tout << "found base v" << r << "\n"); + } if (r == null_theory_var) { @@ -190,6 +191,7 @@ namespace smt { SELECT_VAR(v); } } + CTRACE("find_infeasible_int_base_var", r != null_theory_var, tout << "found quasi base v" << r << "\n"); } CASSERT("arith", wf_rows()); CASSERT("arith", wf_columns()); @@ -435,19 +437,11 @@ namespace smt { */ template bool theory_arith::is_gomory_cut_target(row const & r) { - TRACE("gomory_cut", r.display(tout);); theory_var b = r.get_base_var(); - typename vector::const_iterator it = r.begin_entries(); - typename vector::const_iterator end = r.end_entries(); - for (; it != end; ++it) { + for (auto& e : r) { // All non base variables must be at their bounds and assigned to rationals (that is, infinitesimals are not allowed). - if (!it->is_dead() && it->m_var != b && (!at_bound(it->m_var) || !get_value(it->m_var).is_rational())) { - TRACE("gomory_cut", tout << "row is not gomory cut target:\n"; - display_var(tout, it->m_var); - tout << "at_bound: " << at_bound(it->m_var) << "\n"; - tout << "infinitesimal: " << !get_value(it->m_var).is_rational() << "\n";); + if (!e.is_dead() && e.m_var != b && (!at_bound(e.m_var) || !get_value(e.m_var).is_rational())) return false; - } } return true; } @@ -509,7 +503,7 @@ namespace smt { theory_var x_i = r.get_base_var(); SASSERT(is_int(x_i)); - // The following assertion is wrong. It may be violated in mixed-real-interger problems. + // The following assertion is wrong. It may be violated in mixed-real-integer problems. // The check is_gomory_cut_target will discard rows where any variable contains infinitesimals. // SASSERT(m_value[x_i].is_rational()); // infinitesimals are not used for integer variables SASSERT(!m_value[x_i].is_int()); // the base variable is not assigned to an integer value. @@ -540,12 +534,10 @@ namespace smt { numeral lcm_den(1); unsigned num_ints = 0; - typename vector::const_iterator it = r.begin_entries(); - typename vector::const_iterator end = r.end_entries(); - for (; it != end; ++it) { - if (!it->is_dead() && it->m_var != x_i) { - theory_var x_j = it->m_var; - numeral a_ij = it->m_coeff; + for (row_entry const& e : r) { + if (!e.is_dead() && e.m_var != x_i) { + theory_var x_j = e.m_var; + numeral a_ij = e.m_coeff; a_ij.neg(); // make the used format compatible with the format used in: Integrating Simplex with DPLL(T) if (is_real(x_j)) { numeral new_a_ij; @@ -708,38 +700,36 @@ namespace smt { numeral gcds(0); numeral least_coeff(0); bool least_coeff_is_bounded = false; - typename vector::const_iterator it = r.begin_entries(); - typename vector::const_iterator end = r.end_entries(); - for (; it != end; ++it) { - if (!it->is_dead()) { - if (is_fixed(it->m_var)) { - // WARNING: it is not safe to use get_value(it->m_var) here, since - // get_value(it->m_var) may not satisfy it->m_var bounds at this point. - numeral aux = lcm_den * it->m_coeff; - consts += aux * lower_bound(it->m_var).get_rational(); + for (row_entry const& e : r) { + if (!e.is_dead()) { + if (is_fixed(e.m_var)) { + // WARNING: it is not safe to use get_value(e.m_var) here, since + // get_value(e.m_var) may not satisfy e.m_var bounds at this point. + numeral aux = lcm_den * e.m_coeff; + consts += aux * lower_bound(e.m_var).get_rational(); } - else if (is_real(it->m_var)) { + else if (is_real(e.m_var)) { return true; } else if (gcds.is_zero()) { - gcds = abs(lcm_den * it->m_coeff); + gcds = abs(lcm_den * e.m_coeff); least_coeff = gcds; - least_coeff_is_bounded = is_bounded(it->m_var); + least_coeff_is_bounded = is_bounded(e.m_var); } else { - numeral aux = abs(lcm_den * it->m_coeff); + numeral aux = abs(lcm_den * e.m_coeff); gcds = gcd(gcds, aux); if (aux < least_coeff) { least_coeff = aux; - least_coeff_is_bounded = is_bounded(it->m_var); + least_coeff_is_bounded = is_bounded(e.m_var); } else if (least_coeff_is_bounded && aux == least_coeff) { - least_coeff_is_bounded = is_bounded(it->m_var); + least_coeff_is_bounded = is_bounded(e.m_var); } } SASSERT(gcds.is_int()); SASSERT(least_coeff.is_int()); - TRACE("gcd_test_bug", tout << "coeff: " << it->m_coeff << ", gcds: " << gcds + TRACE("gcd_test_bug", tout << "coeff: " << e.m_coeff << ", gcds: " << gcds << " least_coeff: " << least_coeff << " consts: " << consts << "\n";); } } @@ -789,14 +779,11 @@ namespace smt { antecedents ante(*this); - - typename vector::const_iterator it = r.begin_entries(); - typename vector::const_iterator end = r.end_entries(); - for (; it != end; ++it) { - if (!it->is_dead() && !is_fixed(it->m_var)) { - theory_var v = it->m_var; + for (auto const& e : r) { + if (!e.is_dead() && !is_fixed(e.m_var)) { + theory_var v = e.m_var; SASSERT(!is_real(v)); - numeral ncoeff = lcm_den * it->m_coeff; + numeral ncoeff = lcm_den * e.m_coeff; SASSERT(ncoeff.is_int()); numeral abs_ncoeff = abs(ncoeff); if (abs_ncoeff == least_coeff) { @@ -813,8 +800,8 @@ namespace smt { // u += ncoeff * lower_bound(v).get_rational(); u.addmul(ncoeff, lower_bound(v).get_rational()); } - lower(v)->push_justification(ante, it->m_coeff, coeffs_enabled()); - upper(v)->push_justification(ante, it->m_coeff, coeffs_enabled()); + lower(v)->push_justification(ante, e.m_coeff, coeffs_enabled()); + upper(v)->push_justification(ante, e.m_coeff, coeffs_enabled()); } else if (gcds.is_zero()) { gcds = abs_ncoeff; diff --git a/src/smt/theory_arith_nl.h b/src/smt/theory_arith_nl.h index f44516cad0f..0a2b6e938f3 100644 --- a/src/smt/theory_arith_nl.h +++ b/src/smt/theory_arith_nl.h @@ -765,10 +765,8 @@ typename theory_arith::numeral theory_arith::get_monomial_fixed_var_pr template expr * theory_arith::get_monomial_non_fixed_var(expr * m) const { SASSERT(is_pure_monomial(m)); - for (unsigned i = 0; i < to_app(m)->get_num_args(); i++) { - expr * arg = to_app(m)->get_arg(i); - theory_var _var = expr2var(arg); - if (!is_fixed(_var)) + for (expr* arg : *to_app(m)) { + if (!is_fixed(expr2var(arg))) return arg; } return nullptr; @@ -780,7 +778,7 @@ expr * theory_arith::get_monomial_non_fixed_var(expr * m) const { */ template bool theory_arith::propagate_linear_monomial(theory_var v) { - TRACE("non_linear", tout << "checking whether v" << v << " became linear...\n";); + TRACE("non_linear_verbose", tout << "checking whether v" << v << " became linear...\n";); if (m_data[v].m_nl_propagated) return false; // already propagated this monomial. expr * m = var2expr(v); @@ -819,6 +817,11 @@ bool theory_arith::propagate_linear_monomial(theory_var v) { ctx.mark_as_relevant(rhs); } TRACE("non_linear_bug", tout << "enode: " << ctx.get_enode(rhs) << " enode_id: " << ctx.get_enode(rhs)->get_owner_id() << "\n";); + IF_VERBOSE(3, + for (auto* arg : *to_app(m)) + if (is_fixed(expr2var(arg))) + verbose_stream() << mk_pp(arg, get_manager()) << " = " << -k << "\n"); + theory_var new_v = expr2var(rhs); SASSERT(new_v != null_theory_var); new_lower = alloc(derived_bound, new_v, inf_numeral(0), B_LOWER); @@ -902,9 +905,11 @@ bool theory_arith::propagate_linear_monomial(theory_var v) { */ template bool theory_arith::propagate_linear_monomials() { + if (!m_params.m_nl_arith_propagate_linear_monomials) + return false; if (!reflection_enabled()) return false; - TRACE("non_linear", tout << "propagating linear monomials...\n";); + TRACE("non_linear_verbose", tout << "propagating linear monomials...\n";); bool p = false; // CMW: m_nl_monomials can grow during this loop, so // don't use iterators. @@ -2278,6 +2283,8 @@ typename theory_arith::gb_result theory_arith::compute_grobner(svector */ template bool theory_arith::max_min_nl_vars() { + if (!m_params.m_nl_arith_optimize_bounds) + return true; var_set already_found; svector vars; for (theory_var v : m_nl_monomials) { @@ -2360,7 +2367,7 @@ final_check_status theory_arith::process_non_linear() { } break; case 1: - if (!is_cross_nested_consistent(vars)) + if (m_params.m_nl_arith_cross_nested && !is_cross_nested_consistent(vars)) progress = true; break; case 2: diff --git a/src/smt/theory_arith_pp.h b/src/smt/theory_arith_pp.h index b0d43bc00ac..edc640d86aa 100644 --- a/src/smt/theory_arith_pp.h +++ b/src/smt/theory_arith_pp.h @@ -83,9 +83,11 @@ namespace smt { template void theory_arith::display_row(std::ostream & out, row const & r, bool compact) const { - + if (static_cast(r.get_base_var()) >= m_columns.size()) + return; column const & c = m_columns[r.get_base_var()]; - out << "(v" << r.get_base_var() << " r" << c[0].m_row_id << ") : "; + if (c.size() > 0) + out << "(v" << r.get_base_var() << " r" << c[0].m_row_id << ") : "; bool first = true; for (auto const& e : r) { if (!e.is_dead()) { diff --git a/src/smt/theory_array_base.cpp b/src/smt/theory_array_base.cpp index 6c2f4038f31..b766451dfc5 100644 --- a/src/smt/theory_array_base.cpp +++ b/src/smt/theory_array_base.cpp @@ -969,7 +969,6 @@ namespace smt { } model_value_proc * theory_array_base::mk_value(enode * n, model_generator & mg) { - SASSERT(ctx.is_relevant(n)); theory_var v = n->get_th_var(get_id()); SASSERT(v != null_theory_var); sort * s = n->get_expr()->get_sort(); diff --git a/src/smt/theory_array_full.cpp b/src/smt/theory_array_full.cpp index 079c2f62e0b..345663b6bf6 100644 --- a/src/smt/theory_array_full.cpp +++ b/src/smt/theory_array_full.cpp @@ -546,7 +546,7 @@ namespace smt { expr_ref def2(m.mk_app(f, args2.size(), args2.data()), m); ctx.get_rewriter()(def2); - expr* def1 = mk_default(map); + expr_ref def1(mk_default(map), m); ctx.internalize(def1, false); ctx.internalize(def2, false); return try_assign_eq(def1, def2); @@ -561,7 +561,7 @@ namespace smt { SASSERT(is_const(cnst)); TRACE("array", tout << mk_bounded_pp(cnst->get_expr(), m) << "\n";); expr* val = cnst->get_arg(0)->get_expr(); - expr* def = mk_default(cnst->get_expr()); + expr_ref def(mk_default(cnst->get_expr()), m); ctx.internalize(def, false); return try_assign_eq(val, def); } @@ -598,7 +598,7 @@ namespace smt { return false; m_stats.m_num_default_lambda_axiom++; expr* e = arr->get_expr(); - expr* def = mk_default(e); + expr_ref def(mk_default(e), m); quantifier* lam = m.is_lambda_def(arr->get_decl()); TRACE("array", tout << mk_pp(lam, m) << "\n" << mk_pp(e, m) << "\n"); expr_ref_vector args(m); diff --git a/src/smt/theory_bv.cpp b/src/smt/theory_bv.cpp index 7adab35f409..55d3a1d6245 100644 --- a/src/smt/theory_bv.cpp +++ b/src/smt/theory_bv.cpp @@ -429,12 +429,9 @@ namespace smt { }; void theory_bv::add_fixed_eq(theory_var v1, theory_var v2) { - if (!params().m_bv_eq_axioms) - return; - if (v1 > v2) { + if (v1 > v2) std::swap(v1, v2); - } unsigned act = m_eq_activity[hash_u_u(v1, v2) & 0xFF]++; if ((act & 0xFF) != 0xFF) { @@ -1163,8 +1160,6 @@ namespace smt { } void theory_bv::expand_diseq(theory_var v1, theory_var v2) { - if (!params().m_bv_eq_axioms) - return; SASSERT(get_bv_size(v1) == get_bv_size(v2)); if (v1 > v2) { @@ -1331,29 +1326,27 @@ namespace smt { } else { ctx.assign(consequent, mk_bit_eq_justification(v1, v2, consequent, antecedent)); - if (params().m_bv_eq_axioms) { - - literal_vector lits; - lits.push_back(~consequent); - lits.push_back(antecedent); - literal eq = mk_eq(get_expr(v1), get_expr(v2), false); - lits.push_back(~eq); - // - // Issue #3035: - // merge_eh invokes assign_bit, which updates the propagation queue and includes the - // theory axiom for the propagated equality. When relevancy is non-zero, propagation may get - // lost on backtracking because the propagation queue is reset on conflicts. - // An alternative approach is to ensure the propagation queue is chronological with - // backtracking scopes (ie., it doesn't get reset, but shrunk to a previous level, and similar - // with a qhead indicator. - // - ctx.mark_as_relevant(lits[0]); - ctx.mark_as_relevant(lits[1]); - ctx.mark_as_relevant(lits[2]); - { - scoped_trace_stream _sts(*this, lits); - ctx.mk_th_axiom(get_id(), lits.size(), lits.data()); - } + + literal_vector lits; + lits.push_back(~consequent); + lits.push_back(antecedent); + literal eq = mk_eq(get_expr(v1), get_expr(v2), false); + lits.push_back(~eq); + // + // Issue #3035: + // merge_eh invokes assign_bit, which updates the propagation queue and includes the + // theory axiom for the propagated equality. When relevancy is non-zero, propagation may get + // lost on backtracking because the propagation queue is reset on conflicts. + // An alternative approach is to ensure the propagation queue is chronological with + // backtracking scopes (ie., it doesn't get reset, but shrunk to a previous level, and similar + // with a qhead indicator. + // + ctx.mark_as_relevant(lits[0]); + ctx.mark_as_relevant(lits[1]); + ctx.mark_as_relevant(lits[2]); + { + scoped_trace_stream _sts(*this, lits); + ctx.mk_th_axiom(get_id(), lits.size(), lits.data()); } if (m_wpos[v2] == idx) @@ -1382,7 +1375,7 @@ namespace smt { } } } - + void theory_bv::relevant_eh(app * n) { TRACE("arith", tout << "relevant: #" << n->get_id() << " " << ctx.e_internalized(n) << ": " << mk_bounded_pp(n, m) << "\n";); TRACE("bv", tout << "relevant: #" << n->get_id() << " " << ctx.e_internalized(n) << ": " << mk_pp(n, m) << "\n";); @@ -1889,21 +1882,14 @@ namespace smt { return var_enode_pos(nullptr, UINT32_MAX); } - bool_var theory_bv::get_first_unassigned(unsigned start_bit, enode* n) const { + bool_var theory_bv::get_bit(unsigned bit, enode* n) const { theory_var v = n->get_th_var(get_family_id()); + if (v == null_theory_var) + return null_bool_var; auto& bits = m_bits[v]; - unsigned sz = bits.size(); - - for (unsigned i = start_bit; i < sz; ++i) { - if (ctx.get_assignment(bits[i].var()) == l_undef) - return bits[i].var(); - } - for (unsigned i = 0; i < start_bit; ++i) { - if (ctx.get_assignment(bits[i].var()) == l_undef) - return bits[i].var(); - } - - return null_bool_var; + if (bit >= bits.size()) + return null_bool_var; + return bits[bit].var(); } bool theory_bv::check_assignment(theory_var v) { diff --git a/src/smt/theory_bv.h b/src/smt/theory_bv.h index 73d659c688c..10cf005e343 100644 --- a/src/smt/theory_bv.h +++ b/src/smt/theory_bv.h @@ -291,7 +291,7 @@ namespace smt { bool is_fixed_propagated(theory_var v, expr_ref& val, literal_vector& explain) override; var_enode_pos get_bv_with_theory(bool_var v, theory_id id) const; - bool_var get_first_unassigned(unsigned start_bit, enode* n) const; + bool_var get_bit(unsigned bit, enode* n) const; bool check_assignment(theory_var v); bool check_invariant(); diff --git a/src/smt/theory_datatype.cpp b/src/smt/theory_datatype.cpp index d1216c1718b..b794a44b5c7 100644 --- a/src/smt/theory_datatype.cpp +++ b/src/smt/theory_datatype.cpp @@ -509,10 +509,10 @@ namespace smt { // Assuming `app` is equal to a constructor term, return the constructor enode inline enode * theory_datatype::oc_get_cstor(enode * app) { theory_var v = app->get_root()->get_th_var(get_id()); - SASSERT(v != null_theory_var); + if (v == null_theory_var) + return nullptr; v = m_find.find(v); var_data * d = m_var_data[v]; - SASSERT(d->m_constructor); return d->m_constructor; } @@ -802,8 +802,9 @@ namespace smt { return false; func_decl* con = m_util.get_accessor_constructor(f); for (enode* app : ctx.enodes_of(f)) { - enode* arg = app->get_arg(0)->get_root(); - if (is_constructor(arg) && arg->get_decl() != con) + enode* arg = app->get_arg(0); + enode* arg_con = oc_get_cstor(arg); + if (arg_con && is_constructor(arg_con) && arg_con->get_decl() != con) return true; } return false; @@ -914,7 +915,7 @@ namespace smt { } SASSERT(val == l_undef || (val == l_false && d->m_constructor == nullptr)); d->m_recognizers[c_idx] = recognizer; - m_trail_stack.push(set_vector_idx_trail(d->m_recognizers, c_idx)); + m_trail_stack.push(set_vector_idx_trail(d->m_recognizers, c_idx)); if (val == l_false) { propagate_recognizer(v, recognizer); } diff --git a/src/smt/theory_fpa.cpp b/src/smt/theory_fpa.cpp index 2ecc17c45e3..fb31a61ed1f 100644 --- a/src/smt/theory_fpa.cpp +++ b/src/smt/theory_fpa.cpp @@ -220,7 +220,7 @@ namespace smt { TRACE("t_fpa_detail", tout << "asserting " << mk_ismt2_pp(e, m) << "\n";); if (m.has_trace_stream()) log_axiom_instantiation(e); ctx.internalize(e, false); - if (m.has_trace_stream()) m.trace_stream() << "[end-of-instance]\n"; + if (m.has_trace_stream()) m.trace_stream() << "[end-of-instance]\n"; literal lit(ctx.get_literal(e)); ctx.mark_as_relevant(lit); ctx.mk_th_axiom(get_id(), 1, &lit); @@ -239,11 +239,11 @@ namespace smt { if (ctx.b_internalized(atom)) return true; - ctx.internalize(atom->get_args(), atom->get_num_args(), false); - literal l(ctx.mk_bool_var(atom)); ctx.set_var_theory(l.var(), get_id()); + ctx.internalize(atom->get_args(), atom->get_num_args(), false); + expr_ref bv_atom(m_rw.convert_atom(m_th_rw, atom)); expr_ref bv_atom_w_side_c(m), atom_eq(m); bv_atom_w_side_c = m.mk_and(bv_atom, mk_side_conditions()); diff --git a/src/smt/theory_fpa.h b/src/smt/theory_fpa.h index 9aa70d9bfbf..262a239dda2 100644 --- a/src/smt/theory_fpa.h +++ b/src/smt/theory_fpa.h @@ -121,7 +121,6 @@ namespace smt { void attach_new_th_var(enode * n); void assert_cnstr(expr * e); - enode* ensure_enode(expr* e); enode* get_root(expr* a) { return ensure_enode(a)->get_root(); } app* get_ite_value(expr* e); diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index c6bd12f03a1..f0a96ddd1ba 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -47,7 +47,7 @@ #include "util/scoped_timer.h" #include "util/distribution.h" -typedef lp::var_index lpvar; +typedef lp::lpvar lpvar; namespace smt { @@ -77,7 +77,6 @@ class theory_lra::imp { bool get_cancel_flag() override { return !m_imp.m.inc(); } }; - theory_lra& th; ast_manager& m; arith_util a; @@ -90,13 +89,11 @@ class theory_lra::imp { expr_ref_vector m_terms; vector m_coeffs; svector m_vars; - rational m_offset; ptr_vector m_to_ensure_enode, m_to_ensure_var; internalize_state(ast_manager& m): m_terms(m) {} void reset() { m_terms.reset(); m_coeffs.reset(); - m_offset.reset(); m_vars.reset(); m_to_ensure_enode.reset(); m_to_ensure_var.reset(); @@ -123,7 +120,6 @@ class theory_lra::imp { expr_ref_vector& terms() { return m_st.m_terms; } vector& coeffs() { return m_st.m_coeffs; } svector& vars() { return m_st.m_vars; } - rational& offset() { return m_st.m_offset; } ptr_vector& to_ensure_enode() { return m_st.m_to_ensure_enode; } ptr_vector& to_ensure_var() { return m_st.m_to_ensure_var; } void push(expr* e, rational c) { m_st.m_terms.push_back(e); m_st.m_coeffs.push_back(c); } @@ -171,13 +167,12 @@ class theory_lra::imp { svector > m_assume_eq_candidates; unsigned m_assume_eq_head; - lp::u_set m_tmp_var_set; + indexed_uint_set m_tmp_var_set; unsigned m_num_conflicts; // non-linear arithmetic scoped_ptr m_nla; - mutable scoped_ptr m_a1, m_a2; // integer arithmetic scoped_ptr m_lia; @@ -195,26 +190,17 @@ class theory_lra::imp { }; bool use_nra_model() const { - if (m_nla && m_nla->use_nra_model()) { - if (!m_a1) { - m_a1 = alloc(scoped_anum, m_nla->am()); - m_a2 = alloc(scoped_anum, m_nla->am()); - } - return true; - } - return false; + return m_nla && m_nla->use_nra_model(); } struct var_value_hash { imp & m_th; var_value_hash(imp & th):m_th(th) {} unsigned operator()(theory_var v) const { - if (m_th.use_nra_model()) { + if (m_th.use_nra_model()) return m_th.is_int(v); - } - else { + else return (unsigned)std::hash()(m_th.get_ivalue(v)); - } } }; int_hashtable m_model_eqs; @@ -228,6 +214,7 @@ class theory_lra::imp { lp_bounds m_new_bounds; symbol m_farkas; vector m_bound_params; + std_vector m_implied_bounds; lp::lp_bound_propagator m_bp; context& ctx() const { return th.get_context(); } @@ -239,12 +226,14 @@ class theory_lra::imp { bool is_real(enode* n) const { return a.is_real(n->get_expr()); } enode* get_enode(theory_var v) const { return th.get_enode(v); } enode* get_enode(expr* e) const { return ctx().get_enode(e); } - expr* get_owner(theory_var v) const { return get_enode(v)->get_expr(); } + expr* get_owner(theory_var v) const { return get_enode(v)->get_expr(); } + enode_pp pp(enode* n) const { return enode_pp(n, ctx()); } + enode_pp pp(theory_var v) const { return pp(get_enode(v)); } + mk_bounded_pp bpp(expr* e) { return mk_bounded_pp(e, m); } lpvar add_const(int c, lpvar& var, bool is_int) { - if (var != UINT_MAX) { + if (var != UINT_MAX) return var; - } app_ref cnst(a.mk_numeral(rational(c), is_int), m); mk_enode(cnst); theory_var v = mk_var(cnst); @@ -266,7 +255,7 @@ class theory_lra::imp { void ensure_nla() { if (!m_nla) { - m_nla = alloc(nla::solver, *m_solver.get(), m.limit()); + m_nla = alloc(nla::solver, *m_solver.get(), ctx().get_params(), m.limit()); for (auto const& _s : m_scopes) { (void)_s; m_nla->push(); @@ -276,24 +265,6 @@ class theory_lra::imp { return ctx().is_relevant(th.get_enode(u)); }; m_nla->set_relevant(is_relevant); - smt_params_helper prms(ctx().get_params()); - m_nla->settings().run_order = prms.arith_nl_order(); - m_nla->settings().run_tangents = prms.arith_nl_tangents(); - m_nla->settings().run_horner = prms.arith_nl_horner(); - m_nla->settings().horner_subs_fixed = prms.arith_nl_horner_subs_fixed(); - m_nla->settings().horner_frequency = prms.arith_nl_horner_frequency(); - m_nla->settings().horner_row_length_limit = prms.arith_nl_horner_row_length_limit(); - m_nla->settings().run_grobner = prms.arith_nl_grobner(); - m_nla->settings().run_nra = prms.arith_nl_nra(); - m_nla->settings().grobner_subs_fixed = prms.arith_nl_grobner_subs_fixed(); - m_nla->settings().grobner_eqs_growth = prms.arith_nl_grobner_eqs_growth(); - m_nla->settings().grobner_expr_size_growth = prms.arith_nl_grobner_expr_size_growth(); - m_nla->settings().grobner_expr_degree_growth = prms.arith_nl_grobner_expr_degree_growth(); - m_nla->settings().grobner_max_simplified = prms.arith_nl_grobner_max_simplified(); - m_nla->settings().grobner_number_of_conflicts_to_report = prms.arith_nl_grobner_cnfl_to_report(); - m_nla->settings().grobner_quota = prms.arith_nl_gr_q(); - m_nla->settings().grobner_frequency = prms.arith_nl_grobner_frequency(); - m_nla->settings().expensive_patching = false; } } @@ -345,18 +316,33 @@ class theory_lra::imp { st.push(rhs, rational::minus_one()); linearize(st); } + + theory_var internalize_numeral(app* n, rational const& val) { + + if (!ctx().e_internalized(n)) + mk_enode(n); + theory_var v = mk_var(n); + lpvar vi = get_lpvar(v); + if (vi == UINT_MAX) { + vi = lp().add_var(v, a.is_int(n)); + add_def_constraint_and_equality(vi, lp::GE, val); + add_def_constraint_and_equality(vi, lp::LE, val); + register_fixed_var(v, val); + } + return v; + } + void linearize(scoped_internalize_state& st) { expr_ref_vector & terms = st.terms(); svector& vars = st.vars(); vector& coeffs = st.coeffs(); - rational& offset = st.offset(); rational r; expr* n1, *n2; unsigned index = 0; while (index < terms.size()) { SASSERT(index >= vars.size()); - expr* n = terms[index].get(); + expr* n = terms.get(index); st.to_ensure_enode().push_back(n); if (a.is_add(n)) { for (expr* arg : *to_app(n)) { @@ -394,7 +380,9 @@ class theory_lra::imp { ++index; } else if (a.is_numeral(n, r)) { - offset += coeffs[index]*r; + theory_var v = internalize_numeral(to_app(n), r); + coeffs[vars.size()] = coeffs[index]; + vars.push_back(v); ++index; } else if (a.is_uminus(n, n1)) { @@ -562,7 +550,7 @@ class theory_lra::imp { theory_var v = mk_var(n); vars.push_back(register_theory_var_in_lar_solver(v)); } - TRACE("arith", tout << "v" << v << " := " << mk_pp(t, m) << "\n" << vars << "\n";); + TRACE("arith", tout << "v" << v << " := " << bpp(t) << "\n" << vars << "\n";); m_solver->register_existing_terms(); ensure_nla(); m_nla->add_monic(register_theory_var_in_lar_solver(v), vars.size(), vars.data()); @@ -571,7 +559,7 @@ class theory_lra::imp { } enode * mk_enode(app * n) { - TRACE("arith", tout << mk_bounded_pp(n, m) << " internalized: " << ctx().e_internalized(n) << "\n";); + TRACE("arith_verbose", tout << bpp(n) << " internalized: " << ctx().e_internalized(n) << "\n";); if (reflect(n)) for (expr* arg : *n) if (!ctx().e_internalized(arg)) @@ -617,20 +605,18 @@ class theory_lra::imp { } theory_var mk_var(expr* n) { - if (!ctx().e_internalized(n)) { + if (!ctx().e_internalized(n)) ctx().internalize(n, false); - } enode* e = get_enode(n); theory_var v; - if (!th.is_attached_to_var(e)) { + if (th.is_attached_to_var(e)) + v = e->get_th_var(get_id()); + else { v = th.mk_var(e); SASSERT(m_bounds.size() <= static_cast(v) || m_bounds[v].empty()); reserve_bounds(v); ctx().attach_th_var(e, &th, v); } - else { - v = e->get_th_var(get_id()); - } SASSERT(null_theory_var != v); return v; } @@ -651,12 +637,10 @@ class theory_lra::imp { for (unsigned i = 0; i < vars.size(); ++i) { theory_var var = vars[i]; rational const& coeff = coeffs[i]; - if (m_columns.size() <= static_cast(var)) { + if (m_columns.size() <= static_cast(var)) m_columns.setx(var, coeff, rational::zero()); - } - else { + else m_columns[var] += coeff; - } } m_left_side.clear(); // reset the coefficients after they have been used. @@ -664,7 +648,7 @@ class theory_lra::imp { theory_var var = vars[i]; rational const& r = m_columns[var]; if (!r.is_zero()) { - m_left_side.push_back(std::make_pair(r, register_theory_var_in_lar_solver(var))); + m_left_side.push_back({r, register_theory_var_in_lar_solver(var)}); m_columns[var].reset(); } } @@ -672,12 +656,7 @@ class theory_lra::imp { } bool all_zeros(vector const& v) const { - for (rational const& r : v) { - if (!r.is_zero()) { - return false; - } - } - return true; + return all_of(v, [](rational const& r) { return r.is_zero(); }); } void add_eq_constraint(lp::constraint_index index, enode* n1, enode* n2) { @@ -700,7 +679,6 @@ class theory_lra::imp { m_definitions.setx(index, v, null_theory_var); } - bool is_infeasible() const { return lp().get_status() == lp::lp_status::INFEASIBLE; } @@ -729,9 +707,8 @@ class theory_lra::imp { lpvar vi_equal; lp::constraint_index ci = lp().add_var_bound_check_on_equal(vi, kind, bound, vi_equal); add_def_constraint(ci); - if (vi_equal != lp::null_lpvar) { + if (vi_equal != lp::null_lpvar) report_equality_of_fixed_vars(vi, vi_equal); - } m_new_def = true; } @@ -745,22 +722,9 @@ class theory_lra::imp { theory_var z = internalize_linearized_def(term, st); lpvar vi = register_theory_var_in_lar_solver(z); add_def_constraint_and_equality(vi, lp::LE, rational::zero()); - if (is_infeasible()) { - IF_VERBOSE(0, verbose_stream() << "infeasible\n";); - // process_conflict(); // exit here? - } add_def_constraint_and_equality(vi, lp::GE, rational::zero()); - if (is_infeasible()) { - IF_VERBOSE(0, verbose_stream() << "infeasible\n";); - // process_conflict(); // exit here? - } TRACE("arith", - { - expr* o1 = get_enode(v1)->get_expr(); - expr* o2 = get_enode(v2)->get_expr(); - tout << "v" << v1 << " = " << "v" << v2 << ": " - << mk_pp(o1, m) << " = " << mk_pp(o2, m) << "\n"; - }); + tout << "v" << v1 << " = " << "v" << v2 << ": " << pp(v1) << " = " << pp(v2) << "\n"); } void del_bounds(unsigned old_size) { @@ -775,22 +739,15 @@ class theory_lra::imp { } void updt_unassigned_bounds(theory_var v, int inc) { - TRACE("arith", tout << "v" << v << " " << m_unassigned_bounds[v] << " += " << inc << "\n";); + TRACE("arith_verbose", tout << "v" << v << " " << m_unassigned_bounds[v] << " += " << inc << "\n";); ctx().push_trail(vector_value_trail(m_unassigned_bounds, v)); m_unassigned_bounds[v] += inc; } bool is_unit_var(scoped_internalize_state& st) { - return st.offset().is_zero() && st.vars().size() == 1 && st.coeffs()[0].is_one(); + return st.vars().size() == 1 && st.coeffs()[0].is_one(); } - bool is_one(scoped_internalize_state& st) { - return st.offset().is_one() && st.vars().empty(); - } - - bool is_zero(scoped_internalize_state& st) { - return st.offset().is_zero() && st.vars().empty(); - } theory_var internalize_def(app* term, scoped_internalize_state& st) { TRACE("arith", tout << expr_ref(term, m) << "\n";); @@ -833,54 +790,33 @@ class theory_lra::imp { return v == null_theory_var ? lp::null_lpvar : lp().external_to_local(v); } - lp::tv get_tv(theory_var v) const { - return lp::tv::raw(get_lpvar(v)); - } - theory_var internalize_linearized_def(app* term, scoped_internalize_state& st) { theory_var v = mk_var(term); - TRACE("arith", tout << mk_bounded_pp(term, m) << " v" << v << "\n";); + TRACE("arith_internalize", tout << "v" << v << " " << bpp(term) << "\n";); - if (is_unit_var(st) && v == st.vars()[0]) { + if (is_unit_var(st) && v == st.vars()[0]) return st.vars()[0]; - } - else if (is_one(st) && a.is_numeral(term)) { - return lp().local_to_external(get_one(a.is_int(term))); - } - else if (is_zero(st) && a.is_numeral(term)) { - return lp().local_to_external(get_zero(a.is_int(term))); - } - else { - init_left_side(st); - lpvar vi = get_lpvar(v); - if (vi == UINT_MAX) { - if (m_left_side.empty()) { - vi = lp().add_var(v, a.is_int(term)); - add_def_constraint_and_equality(vi, lp::GE, st.offset()); - add_def_constraint_and_equality(vi, lp::LE, st.offset()); - register_fixed_var(v, st.offset()); - return v; - } - if (!st.offset().is_zero()) { - m_left_side.push_back(std::make_pair(st.offset(), get_one(a.is_int(term)))); - } - if (m_left_side.empty()) { - vi = lp().add_var(v, a.is_int(term)); - add_def_constraint_and_equality(vi, lp::GE, rational(0)); - add_def_constraint_and_equality(vi, lp::LE, rational(0)); - } - else { - vi = lp().add_term(m_left_side, v); - SASSERT(lp::tv::is_term(vi)); - TRACE("arith_verbose", - tout << "v" << v << " := " << mk_pp(term, m) - << " slack: " << vi << " scopes: " << m_scopes.size() << "\n"; - lp().print_term(lp().get_term(lp::tv::raw(vi)), tout) << "\n";); - } - } - return v; + init_left_side(st); + lpvar vi = get_lpvar(v); + + if (vi == UINT_MAX) { + if (m_left_side.empty()) { + vi = lp().add_var(v, a.is_int(term)); + add_def_constraint_and_equality(vi, lp::GE, rational(0)); + add_def_constraint_and_equality(vi, lp::LE, rational(0)); + } + else { + vi = lp().add_term(m_left_side, v); + SASSERT(lp().column_has_term(vi)); + TRACE("arith_verbose", + tout << "v" << v << " := " << mk_pp(term, m) + << " slack: " << vi << " scopes: " << m_scopes.size() << "\n"; + lp().print_term(lp().get_term(vi), tout) << "\n";); + } } + + return v; } @@ -901,7 +837,7 @@ class theory_lra::imp { m_solver(nullptr), m_resource_limit(*this), m_farkas("farkas"), - m_bp(*this), + m_bp(*this, m_implied_bounds), m_bounded_range_idx(0), m_bounded_range_lit(null_literal), m_bound_terms(m), @@ -953,7 +889,7 @@ class theory_lra::imp { } bool internalize_atom(app * atom, bool gate_ctx) { - TRACE("arith", tout << mk_pp(atom, m) << "\n";); + TRACE("arith_internalize", tout << bpp(atom) << "\n";); SASSERT(!ctx().b_internalized(atom)); expr* n1, *n2; rational r; @@ -988,17 +924,16 @@ class theory_lra::imp { return true; } - if (is_int(v) && !r.is_int()) { + if (is_int(v) && !r.is_int()) r = (k == lp_api::upper_t) ? floor(r) : ceil(r); - } + api_bound* b = mk_var_bound(bv, v, k, r); m_bounds[v].push_back(b); updt_unassigned_bounds(v, +1); m_bounds_trail.push_back(v); m_bool_var2bound.insert(bv, b); - TRACE("arith_verbose", tout << "Internalized " << bv << ": " << mk_pp(atom, m) << "\n";); mk_bound_axioms(*b); - //add_use_lists(b); + TRACE("arith_internalize", tout << "Internalized " << bv << ": " << bpp(atom) << "\n";); return true; } @@ -1017,17 +952,19 @@ class theory_lra::imp { } void internalize_eq_eh(app * atom, bool_var) { + if (!ctx().get_fparams().m_arith_eager_eq_axioms) + return; expr* lhs = nullptr, *rhs = nullptr; VERIFY(m.is_eq(atom, lhs, rhs)); enode * n1 = get_enode(lhs); enode * n2 = get_enode(rhs); - TRACE("arith_verbose", tout << mk_pp(atom, m) << " " << is_arith(n1) << " " << is_arith(n2) << "\n";); + if (is_arith(n1) && is_arith(n2) && n1 != n2) m_arith_eq_adapter.mk_axioms(n1, n2); } void assign_eh(bool_var v, bool is_true) { - TRACE("arith", tout << mk_bounded_pp(ctx().bool_var2expr(v), m) << " " << (literal(v, !is_true)) << "\n";); + TRACE("arith", tout << "assign p" << literal(v, !is_true) << ": " << bpp(ctx().bool_var2expr(v)) << "\n";); m_asserted_atoms.push_back(delayed_atom(v, is_true)); } @@ -1072,16 +1009,14 @@ class theory_lra::imp { } void apply_sort_cnstr(enode* n, sort*) { - TRACE("arith", tout << "sort constraint: " << enode_pp(n, ctx()) << "\n";); + TRACE("arith", tout << "sort constraint: " << pp(n) << "\n";); #if 0 - if (!th.is_attached_to_var(n)) { + if (!th.is_attached_to_var(n)) mk_var(n->get_owner()); - } #endif } void push_scope_eh() { - TRACE("arith", tout << "push\n";); m_scopes.push_back(scope()); scope& sc = m_scopes.back(); sc.m_bounds_lim = m_bounds_trail.size(); @@ -1090,14 +1025,11 @@ class theory_lra::imp { lp().push(); if (m_nla) m_nla->push(); - } void pop_scope_eh(unsigned num_scopes) { - TRACE("arith", tout << "pop " << num_scopes << "\n";); - if (num_scopes == 0) { + if (num_scopes == 0) return; - } unsigned old_size = m_scopes.size() - num_scopes; del_bounds(m_scopes[old_size].m_bounds_lim); m_asserted_atoms.shrink(m_scopes[old_size].m_asserted_atoms_lim); @@ -1114,6 +1046,28 @@ class theory_lra::imp { void restart_eh() { m_arith_eq_adapter.restart_eh(); +#if 1 + // experiment + if (m_lia) { + std::function is_root = [&](unsigned j) { + theory_var v = lp().local_to_external(j); + if (v < 0) + return false; + auto* n = get_enode(v); + if (!th.is_relevant_and_shared(n)) + return false; + if (n->is_root()) + return true; + theory_var w = n->get_root()->get_th_var(get_id()); + return w == v; + }; + m_lia->simplify(is_root); + for (auto const& [i, j, e] : m_lia->equalities()) + add_eq(i, j, e, false); + } +#endif + if (m_nla) + m_nla->simplify(); } void relevant_eh(app* n) { @@ -1474,15 +1428,15 @@ class theory_lra::imp { register_theory_var_in_lar_solver(v); } - mutable vector> m_todo_terms; - + mutable vector> m_todo_terms; + lp::impq get_ivalue(theory_var v) const { SASSERT(is_registered_var(v)); - return lp().get_tv_ivalue(get_tv(v)); + return lp().get_column_value(get_lpvar(v)); } rational get_value(theory_var v) const { - return is_registered_var(v) ? lp().get_tv_value(get_tv(v)) : rational::zero(); + return is_registered_var(v) ? lp().get_value(get_lpvar(v)) : rational::zero(); } bool m_model_is_initialized{ false }; @@ -1499,8 +1453,7 @@ class theory_lra::imp { void random_update() { if (m_nla && m_nla->need_check()) return; - m_tmp_var_set.clear(); - m_tmp_var_set.resize(th.get_num_vars()); + m_tmp_var_set.reset(); m_model_eqs.reset(); svector vars; theory_var sz = static_cast(th.get_num_vars()); @@ -1510,8 +1463,8 @@ class theory_lra::imp { continue; } ensure_column(v); - lp::column_index vj = lp().to_column_index(v); - SASSERT(!vj.is_null()); + lp::lpvar vj = lp().external_to_local(v); + SASSERT(vj != lp::null_lpvar); theory_var other = m_model_eqs.insert_if_not_there(v); if (other == v) { continue; @@ -1519,14 +1472,14 @@ class theory_lra::imp { enode * n2 = get_enode(other); if (n1->get_root() == n2->get_root()) continue; - if (!lp().is_fixed(vj)) { - vars.push_back(vj.index()); + if (!lp().column_is_fixed(vj)) { + vars.push_back(vj); } else if (!m_tmp_var_set.contains(other) ) { - lp::column_index other_j = lp().to_column_index(other); - if (!lp().is_fixed(other_j)) { + lp::lpvar other_j = lp().external_to_local(other); + if (!lp().column_is_fixed(other_j)) { m_tmp_var_set.insert(other); - vars.push_back(other_j.index()); + vars.push_back(other_j); } } } @@ -1541,13 +1494,14 @@ class theory_lra::imp { } bool assume_eqs() { + if (delayed_assume_eqs()) return true; - - TRACE("arith", display(tout);); + + TRACE("arith_verbose", display(tout);); random_update(); m_model_eqs.reset(); - + theory_var sz = static_cast(th.get_num_vars()); unsigned old_sz = m_assume_eq_candidates.size(); unsigned num_candidates = 0; @@ -1555,30 +1509,23 @@ class theory_lra::imp { for (theory_var i = 0; i < sz; ++i) { theory_var v = (i + start) % sz; enode* n1 = get_enode(v); - if (!th.is_relevant_and_shared(n1)) { + if (!th.is_relevant_and_shared(n1)) continue; - } ensure_column(v); if (!is_registered_var(v)) - continue; + continue; theory_var other = m_model_eqs.insert_if_not_there(v); - TRACE("arith", tout << "insert: v" << v << " := " << get_value(v) << " found: v" << other << "\n";); - if (other == v) { + if (other == v) continue; - } enode* n2 = get_enode(other); - if (n1->get_root() != n2->get_root()) { - TRACE("arith", tout << pp(n1, m) << " = " << pp(n2, m) << "\n"; - tout << pp(n1, m) << " = " << pp(n2, m) << "\n"; - tout << "v" << v << " = " << "v" << other << "\n";); - m_assume_eq_candidates.push_back(std::make_pair(v, other)); - num_candidates++; - } + if (n1->get_root() == n2->get_root()) + continue; + m_assume_eq_candidates.push_back({v, other}); + num_candidates++; } - if (num_candidates > 0) { + if (num_candidates > 0) ctx().push_trail(restore_vector(m_assume_eq_candidates, old_sz)); - } return delayed_assume_eqs(); } @@ -1606,7 +1553,7 @@ class theory_lra::imp { bool is_eq(theory_var v1, theory_var v2) { if (use_nra_model()) - return m_nla->am().eq(nl_value(v1, *m_a1), nl_value(v2, *m_a2)); + return m_nla->am().eq(nl_value(v1, m_nla->tmp1()), nl_value(v2, m_nla->tmp2())); else return get_ivalue(v1) == get_ivalue(v2); } @@ -1623,12 +1570,11 @@ class theory_lra::imp { return FC_DONE; if (!m_nla) return FC_GIVEUP; - switch (m_nla->check_power(get_lpvar(e), get_lpvar(x), get_lpvar(y), m_nla_lemma_vector)) { + switch (m_nla->check_power(get_lpvar(e), get_lpvar(x), get_lpvar(y))) { case l_true: return FC_DONE; case l_false: - for (const nla::lemma & l : m_nla_lemma_vector) - false_case_of_check_nla(l); + add_lemmas(); return FC_CONTINUE; case l_undef: return FC_GIVEUP; @@ -1645,9 +1591,6 @@ class theory_lra::imp { return FC_DONE; return FC_GIVEUP; } - - unsigned m_final_check_idx = 0; - distribution m_dist { 0 }; final_check_status final_check_eh() { if (propagate_core()) @@ -1656,89 +1599,41 @@ class theory_lra::imp { IF_VERBOSE(12, verbose_stream() << "final-check " << lp().get_status() << "\n"); lbool is_sat = l_true; SASSERT(lp().ax_is_correct()); - if (!lp().is_feasible() || lp().has_changed_columns()) { + if (!lp().is_feasible() || lp().has_changed_columns()) is_sat = make_feasible(); - } - bool giveup = false; final_check_status st = FC_DONE; - m_final_check_idx = 0; // remove to experiment. - unsigned old_idx = m_final_check_idx; switch (is_sat) { case l_true: - TRACE("arith", display(tout)); - - // if (lp().has_fixed_at_bound()) // explain and propagate. - -#if 0 - m_dist.reset(); - m_dist.push(0, 1); - m_dist.push(1, 1); - m_dist.push(2, 1); - - for (auto idx : m_dist) { - if (!m.inc()) - return FC_GIVEUP; - - switch (idx) { - case 0: - if (assume_eqs()) - st = FC_CONTINUE; - break; - case 1: - st = check_nla(); - break; - case 2: - st = check_lia(); - break; - default: - UNREACHABLE(); - break; - } - switch (st) { - case FC_DONE: - break; - case FC_CONTINUE: - return st; - case FC_GIVEUP: - giveup = true; - break; - } - } - -#else - - do { - if (!m.inc()) - return FC_GIVEUP; + TRACE("arith", display(tout)); - switch (m_final_check_idx) { - case 0: - if (assume_eqs()) - st = FC_CONTINUE; - break; - case 1: - st = check_lia(); - break; - case 2: - st = check_nla(); - break; - } - m_final_check_idx = (m_final_check_idx + 1) % 3; - switch (st) { - case FC_DONE: - break; - case FC_CONTINUE: - return st; - case FC_GIVEUP: - giveup = true; - break; - } + switch (check_lia()) { + case FC_DONE: + break; + case FC_CONTINUE: + return FC_CONTINUE; + case FC_GIVEUP: + TRACE("arith", tout << "check-lia giveup\n";); + if (ctx().get_fparams().m_arith_ignore_int) + st = FC_CONTINUE; + break; } - while (old_idx != m_final_check_idx); -#endif - if (giveup) - return FC_GIVEUP; + switch (check_nla()) { + case FC_DONE: + break; + case FC_CONTINUE: + return FC_CONTINUE; + case FC_GIVEUP: + TRACE("arith", tout << "check-nra giveup\n";); + st = FC_GIVEUP; + break; + } + + if (assume_eqs()) { + ++m_stats.m_assume_eqs; + return FC_CONTINUE; + } + for (expr* e : m_not_handled) { if (!ctx().is_relevant(e)) continue; @@ -1791,13 +1686,13 @@ class theory_lra::imp { } } // create a bound atom representing term >= k is lower_bound is true, and term <= k if it is false - app_ref mk_bound(lp::lar_term const& term, rational const& k, bool lower_bound) { + expr_ref mk_bound(lp::lar_term const& term, rational const& k, bool lower_bound) { rational offset; expr_ref t(m); return mk_bound(term, k, lower_bound, offset, t); } - app_ref mk_bound(lp::lar_term const& term, rational const& k, bool lower_bound, rational& offset, expr_ref& t) { + expr_ref mk_bound(lp::lar_term const& term, rational const& k, bool lower_bound, rational& offset, expr_ref& t) { offset = k; u_map coeffs; term2coeffs(term, coeffs); @@ -1843,15 +1738,17 @@ class theory_lra::imp { // lp().print_term(term, tout << "term: ") << "\n"; // tout << "offset: " << offset << " gcd: " << g << "\n";); - app_ref atom(m); + expr_ref atom(m); t = coeffs2app(coeffs, rational::zero(), is_int); - if (lower_bound) { + if (lower_bound) atom = a.mk_ge(t, a.mk_numeral(offset, is_int)); - } - else { - atom = a.mk_le(t, a.mk_numeral(offset, is_int)); - } + else + atom = a.mk_le(t, a.mk_numeral(offset, is_int)); + // ctx().get_rewriter()(atom); + // Note: it is not safe to rewrite atom because the rewriter can + // destroy structure, such as (div x 24) >= 0 becomes x >= 0 and the internal variable + // corresponding to (div x 24) is not constrained. TRACE("arith", tout << t << ": " << atom << "\n"; lp().print_term(term, tout << "bound atom: ") << (lower_bound?" >= ":" <= ") << k << "\n";); ctx().internalize(atom, true); @@ -1873,11 +1770,9 @@ class theory_lra::imp { bool check_idiv_bounds() { if (!m_nla) return true; - m_nla_lemma_vector.reset(); - m_nla->check_bounded_divisions(m_nla_lemma_vector); - for (auto & lemma : m_nla_lemma_vector) - false_case_of_check_nla(lemma); - return m_nla_lemma_vector.empty(); + m_nla->check_bounded_divisions(); + add_lemmas(); + return m_nla->lemmas().empty(); } expr_ref var2expr(lpvar v) { @@ -1895,12 +1790,12 @@ class theory_lra::imp { expr_ref t(m); expr_ref_vector ts(m); for (lp::lar_term::ival p : term) { - auto ti = lp().column2tv(p.column()); - if (ti.is_term()) { + auto ti = p.j(); + if (lp().column_has_term(ti)) { ts.push_back(multerm(p.coeff(), term2expr(lp().get_term(ti)))); } else { - ts.push_back(multerm(p.coeff(), var2expr(ti.id()))); + ts.push_back(multerm(p.coeff(), var2expr(ti))); } } if (ts.size() == 1) { @@ -1937,13 +1832,13 @@ class theory_lra::imp { lp().print_term(term, out << "bound: "); out << (upper?" <= ":" >= ") << k << "\n"; for (lp::lar_term::ival p : term) { - auto ti = lp().column2tv(p.column()); + auto ti = p.j(); out << p.coeff() << " * "; - if (ti.is_term()) { + if (lp().column_has_term(ti)) { lp().print_term(lp().get_term(ti), out) << "\n"; } else { - out << "v" << lp().local_to_external(ti.id()) << "\n"; + out << "v" << lp().local_to_external(ti) << "\n"; } } for (auto ev : ex) { @@ -1983,12 +1878,11 @@ class theory_lra::imp { case lp::lia_move::branch: { TRACE("arith", tout << "branch\n";); - app_ref b(m); bool u = m_lia->is_upper(); auto const & k = m_lia->get_offset(); rational offset; expr_ref t(m); - b = mk_bound(m_lia->get_term(), k, !u, offset, t); + expr_ref b = mk_bound(m_lia->get_term(), k, !u, offset, t); if (m.has_trace_stream()) { app_ref body(m); body = m.mk_or(b, m.mk_not(b)); @@ -2016,7 +1910,7 @@ class theory_lra::imp { } // The call mk_bound() can set the m_infeasible_column in lar_solver // so the explanation is safer to take before this call. - app_ref b = mk_bound(m_lia->get_term(), m_lia->get_offset(), !m_lia->is_upper()); + expr_ref b = mk_bound(m_lia->get_term(), m_lia->get_offset(), !m_lia->is_upper()); if (m.has_trace_stream()) { th.log_axiom_instantiation(b); m.trace_stream() << "[end-of-instance]\n"; @@ -2045,57 +1939,63 @@ class theory_lra::imp { } if (!check_idiv_bounds()) return FC_CONTINUE; - - if (assume_eqs()) - return FC_CONTINUE; return FC_DONE; } nla::lemma m_lemma; - + + literal mk_literal(nla::ineq const& ineq) { + bool is_lower = true, pos = true, is_eq = false; + switch (ineq.cmp()) { + case lp::LE: is_lower = false; pos = false; break; + case lp::LT: is_lower = true; pos = true; break; + case lp::GE: is_lower = true; pos = false; break; + case lp::GT: is_lower = false; pos = true; break; + case lp::EQ: is_eq = true; pos = false; break; + case lp::NE: is_eq = true; pos = true; break; + default: UNREACHABLE(); + } + TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); + expr_ref atom(m); + // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); + // then term is used instead of ineq.m_term + if (is_eq) + atom = mk_eq(ineq.term(), ineq.rs()); + else + // create term >= 0 (or term <= 0) + atom = mk_bound(ineq.term(), ineq.rs(), is_lower); + return literal(ctx().get_bool_var(atom), pos); + } + void false_case_of_check_nla(const nla::lemma & l) { m_lemma = l; //todo avoid the copy m_explanation = l.expl(); literal_vector core; for (auto const& ineq : m_lemma.ineqs()) { - bool is_lower = true, pos = true, is_eq = false; - switch (ineq.cmp()) { - case lp::LE: is_lower = false; pos = false; break; - case lp::LT: is_lower = true; pos = true; break; - case lp::GE: is_lower = true; pos = false; break; - case lp::GT: is_lower = false; pos = true; break; - case lp::EQ: is_eq = true; pos = false; break; - case lp::NE: is_eq = true; pos = true; break; - default: UNREACHABLE(); - } - TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); - app_ref atom(m); - // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); - // then term is used instead of ineq.m_term - if (is_eq) { - atom = mk_eq(ineq.term(), ineq.rs()); - } - else { - // create term >= 0 (or term <= 0) - atom = mk_bound(ineq.term(), ineq.rs(), is_lower); - } - literal lit(ctx().get_bool_var(atom), pos); + auto lit = mk_literal(ineq); core.push_back(~lit); } set_conflict_or_lemma(core, false); } + + void assume_literal(nla::ineq const& i) { + auto lit = mk_literal(i); + ctx().mark_as_relevant(lit); + ctx().set_true_first_flag(lit.var()); + } final_check_status check_nla_continue() { - m_a1 = nullptr; m_a2 = nullptr; - lbool r = m_nla->check(m_nla_lemma_vector); +#if Z3DEBUG + flet f(lp().validate_blocker(), true); +#endif + lbool r = m_nla->check(); switch (r) { - case l_false: - for (const nla::lemma & l : m_nla_lemma_vector) - false_case_of_check_nla(l); + case l_false: + add_lemmas(); return FC_CONTINUE; case l_true: - return assume_eqs()? FC_CONTINUE: FC_DONE; + return FC_DONE; default: return FC_GIVEUP; } @@ -2133,7 +2033,7 @@ class theory_lra::imp { enode * n = get_enode(v); enode * r = n->get_root(); unsigned usz = m_underspecified.size(); - TRACE("shared", tout << ctx().get_scope_level() << " " << v << " " << r->get_num_parents() << "\n";); + TRACE("shared", tout << ctx().get_scope_level() << " " << enode_pp(n, ctx()) << " " << v << " underspecified " << usz << " parents " << r->get_num_parents() << "\n";); if (r->get_num_parents() > 2*usz) { for (unsigned i = 0; i < usz; ++i) { app* u = m_underspecified[i]; @@ -2181,29 +2081,31 @@ class theory_lra::imp { bool propagate_core() { m_model_is_initialized = false; flush_bound_axioms(); - if (!can_propagate_core()) + propagate_nla(); + if (ctx().inconsistent()) + return true; + if (!can_propagate_core()) return false; + m_new_def = false; while (m_asserted_qhead < m_asserted_atoms.size() && !ctx().inconsistent() && m.inc()) { auto [bv, is_true] = m_asserted_atoms[m_asserted_qhead]; - - // m_bv_to_propagate.push_back(bv); - + api_bound* b = nullptr; TRACE("arith", tout << "propagate: " << literal(bv, !is_true) << "\n"; if (!m_bool_var2bound.contains(bv)) tout << "not found\n"); - if (m_bool_var2bound.find(bv, b)) - assert_bound(bv, is_true, *b); + if (m_bool_var2bound.find(bv, b) && !assert_bound(bv, is_true, *b)) { + get_infeasibility_explanation_and_set_conflict(); + return true; + } ++m_asserted_qhead; } - if (ctx().inconsistent()) { - m_bv_to_propagate.reset(); + if (ctx().inconsistent()) return true; - } lbool lbl = make_feasible(); if (!m.inc()) - return false; + return true; switch(lbl) { case l_false: @@ -2211,15 +2113,59 @@ class theory_lra::imp { get_infeasibility_explanation_and_set_conflict(); break; case l_true: - propagate_basic_bounds(); propagate_bounds_with_lp_solver(); break; case l_undef: + UNREACHABLE(); break; } return true; } + void propagate_nla() { + if (m_nla) { + m_nla->propagate(); + add_lemmas(); + lp().collect_more_rows_for_lp_propagation(); + } + } + + void add_equality(lpvar j, rational const& k, lp::explanation const& exp) { + TRACE("arith", tout << "equality " << j << " " << k << "\n"); + theory_var v; + if (k == 1) + v = m_one_var; + else if (k == 0) + v = m_zero_var; + else if (!m_value2var.find(k, v)) + return; + theory_var w = lp().local_to_external(j); + if (w < 0) + return; + lpvar i = register_theory_var_in_lar_solver(v); + add_eq(i, j, exp, true); + } + + void add_lemmas() { + if (m_nla->should_check_feasible()) { + auto is_sat = make_feasible(); + if (l_false == is_sat) { + get_infeasibility_explanation_and_set_conflict(); + return; + } + } + for (const nla::ineq& i : m_nla->literals()) + assume_literal(i); + for (const nla::lemma & l : m_nla->lemmas()) + false_case_of_check_nla(l); + if (!propagate_eqs()) + return; + for (auto const& [v, k, e] : m_nla->fixed_equalities()) + add_equality(v, k, e); + for (auto const& [i, j, e] : m_nla->equalities()) + add_eq(i, j, e, false); + } + bool should_propagate() const { return bound_prop_mode::BP_NONE != propagation_mode(); } @@ -2242,19 +2188,15 @@ class theory_lra::imp { if (!m.inc()) return; - if (is_infeasible()) { get_infeasibility_explanation_and_set_conflict(); - // verbose_stream() << "unsat\n"; } else { - unsigned count = 0, prop = 0; for (auto& ib : m_bp.ibounds()) { m.inc(); if (ctx().inconsistent()) break; - ++prop; - count += propagate_lp_solver_bound(ib); + propagate_lp_solver_bound(ib); } } } @@ -2378,7 +2320,7 @@ class theory_lra::imp { void refine_bound(theory_var v, const lp::implied_bound& be) { lpvar vi = be.m_j; - if (lp::tv::is_term(vi)) + if (lp().column_has_term(vi)) return; expr_ref w(get_enode(v)->get_expr(), m); if (a.is_add(w) || a.is_numeral(w) || m.is_ite(w)) @@ -2418,10 +2360,14 @@ class theory_lra::imp { return false; theory_var uv = lp().local_to_external(u); // variables that are returned should have external representations theory_var vv = lp().local_to_external(v); // so maybe better to have them already transformed to external form + if (uv == null_theory_var) + return false; + if (vv == null_theory_var) + return false; enode* n1 = get_enode(uv); enode* n2 = get_enode(vv); - TRACE("arith", tout << "add-eq " << mk_pp(n1->get_expr(), m) << " == " << mk_pp(n2->get_expr(), m) << " " << n1->get_expr_id() << " == " << n2->get_expr_id() << "\n";); + TRACE("arith", tout << "add-eq " << pp(n1) << " == " << pp(n2) << "\n";); if (n1->get_root() == n2->get_root()) return false; expr* e1 = n1->get_expr(); @@ -2439,7 +2385,9 @@ class theory_lra::imp { literal_vector m_core2; - void assign(literal lit, literal_vector const& core, svector const& eqs, vector const& params) { + void assign(literal lit, literal_vector const& core, svector const& eqs, vector const& ps) { + if (params().m_arith_validate) + VERIFY(validate_assign(lit, core, eqs)); if (core.size() < small_lemma_size() && eqs.empty()) { m_core2.reset(); for (auto const& c : core) { @@ -2449,7 +2397,7 @@ class theory_lra::imp { justification * js = nullptr; if (proofs_enabled()) { js = alloc(theory_lemma_justification, get_id(), ctx(), m_core2.size(), m_core2.data(), - params.size(), params.data()); + ps.size(), ps.data()); } ctx().mk_clause(m_core2.size(), m_core2.data(), js, CLS_TH_LEMMA, nullptr); } @@ -2458,7 +2406,7 @@ class theory_lra::imp { lit, ctx().mk_justification( ext_theory_propagation_justification( get_id(), ctx(), core.size(), core.data(), - eqs.size(), eqs.data(), lit, params.size(), params.data()))); + eqs.size(), eqs.data(), lit, ps.size(), ps.data()))); } } @@ -2725,18 +2673,6 @@ class theory_lra::imp { } return end; } - - void propagate_basic_bounds() { - for (auto const& bv : m_bv_to_propagate) { - api_bound* b = nullptr; - if (m_bool_var2bound.find(bv, b)) { - propagate_bound(bv, ctx().get_assignment(bv) == l_true, *b); - if (ctx().inconsistent()) - break; - } - } - m_bv_to_propagate.reset(); - } // for glb lo': lo' < lo: // lo <= x -> lo' <= x @@ -2807,27 +2743,27 @@ class theory_lra::imp { ++m_stats.m_bounds_propagations; } - svector m_todo_vars; + svector m_todo_vars; void add_use_lists(api_bound* b) { theory_var v = b->get_var(); lpvar vi = register_theory_var_in_lar_solver(v); - if (!lp::tv::is_term(vi)) { + if (!lp().column_has_term(vi)) { return; } - m_todo_vars.push_back(lp::tv::raw(vi)); + m_todo_vars.push_back(vi); while (!m_todo_vars.empty()) { auto ti = m_todo_vars.back(); - SASSERT(ti.is_term()); + SASSERT(lp().column_has_term(ti)); m_todo_vars.pop_back(); lp::lar_term const& term = lp().get_term(ti); for (auto p : term) { - lp::tv wi = lp().column2tv(p.column()); - if (wi.is_term()) { + lp::lpvar wi = p.j(); + if (lp().column_has_term(wi)) { m_todo_vars.push_back(wi); } else { - unsigned w = lp().local_to_external(wi.id()); + unsigned w = lp().local_to_external(wi); m_use_list.reserve(w + 1, ptr_vector()); m_use_list[w].push_back(b); } @@ -2838,22 +2774,22 @@ class theory_lra::imp { void del_use_lists(api_bound* b) { theory_var v = b->get_var(); lpvar vi = get_lpvar(v); - if (!lp::tv::is_term(vi)) { + if (!lp().column_has_term(vi)) { return; } - m_todo_vars.push_back(lp::tv::raw(vi)); + m_todo_vars.push_back(vi); while (!m_todo_vars.empty()) { auto ti = m_todo_vars.back(); - SASSERT(ti.is_term()); + SASSERT(lp().column_has_term(ti)); m_todo_vars.pop_back(); lp::lar_term const& term = lp().get_term(ti); for (auto coeff : term) { - auto wi = lp().column2tv(coeff.column()); - if (wi.is_term()) { + auto wi = coeff.j(); + if (lp().column_has_term(wi)) { m_todo_vars.push_back(wi); } else { - unsigned w = lp().local_to_external(wi.id()); + unsigned w = lp().local_to_external(wi); SASSERT(m_use_list[w].back() == b); m_use_list[w].pop_back(); } @@ -2869,7 +2805,7 @@ class theory_lra::imp { // void propagate_bound_compound(bool_var bv, bool is_true, api_bound& b) { theory_var v = b.get_var(); - TRACE("arith", tout << mk_pp(get_owner(v), m) << "\n";); + TRACE("arith", tout << pp(v) << "\n";); if (static_cast(v) >= m_use_list.size()) { return; } @@ -2932,20 +2868,20 @@ class theory_lra::imp { reset_evidence(); r.reset(); theory_var v = b.get_var(); - auto ti = get_tv(v); - SASSERT(ti.is_term()); + lp::lpvar ti = get_lpvar(v); + SASSERT(lp().column_has_term(ti)); lp::lar_term const& term = lp().get_term(ti); for (auto const mono : term) { - auto wi = lp().column2tv(mono.column()); - lp::constraint_index ci; + auto wi = mono.j(); + u_dependency* ci = nullptr; rational value; bool is_strict; - if (wi.is_term()) { + if (lp().column_has_term(wi)) { return false; } if (mono.coeff().is_neg() == is_lub) { // -3*x ... <= lub based on lower bound for x. - if (!lp().has_lower_bound(wi.id(), ci, value, is_strict)) { + if (!lp().has_lower_bound(wi, ci, value, is_strict)) { return false; } if (is_strict) { @@ -2953,7 +2889,7 @@ class theory_lra::imp { } } else { - if (!lp().has_upper_bound(wi.id(), ci, value, is_strict)) { + if (!lp().has_upper_bound(wi, ci, value, is_strict)) { return false; } if (is_strict) { @@ -2978,13 +2914,12 @@ class theory_lra::imp { return lp::EQ; } - void assert_bound(bool_var bv, bool is_true, api_bound& b) { + bool assert_bound(bool_var bv, bool is_true, api_bound& b) { TRACE("arith", tout << b << "\n";); lp::constraint_index ci = b.get_constraint(is_true); lp().activate(ci); - if (is_infeasible()) { - return; - } + if (is_infeasible()) + return false; lp::lconstraint_kind k = bound2constraint_kind(b.is_int(), b.get_bound_kind(), is_true); if (k == lp::LT || k == lp::LE) { ++m_stats.m_assert_lower; @@ -2993,12 +2928,12 @@ class theory_lra::imp { ++m_stats.m_assert_upper; } inf_rational value = b.get_value(is_true); - if (propagate_eqs() && value.is_rational()) { - propagate_eqs(b.tv(), ci, k, b, value.get_rational()); - } + if (propagate_eqs() && value.is_rational()) + propagate_eqs(b.column_index(), ci, k, b, value.get_rational()); + return true; #if 0 if (should_propagate()) - lp().mark_rows_for_bound_prop(b.tv().id()); + lp().add_column_rows_to_touched_rows(b.tv().id()); #endif } @@ -3039,13 +2974,14 @@ class theory_lra::imp { vector m_lower_terms; vector m_upper_terms; - void propagate_eqs(lp::tv t, lp::constraint_index ci1, lp::lconstraint_kind k, api_bound& b, rational const& value) { - lp::constraint_index ci2; - if (k == lp::GE && set_lower_bound(t, ci1, value) && has_upper_bound(t.index(), ci2, value)) { - fixed_var_eh(b.get_var(), t, ci1, ci2, value); + void propagate_eqs(lp::lpvar t, lp::constraint_index ci1, lp::lconstraint_kind k, api_bound& b, rational const& value) { + u_dependency* ci2 = nullptr; + auto pair = [&]() { return lp().dep_manager().mk_join(lp().dep_manager().mk_leaf(ci1), ci2); }; + if (k == lp::GE && set_lower_bound(t, ci1, value) && has_upper_bound(t, ci2, value)) { + fixed_var_eh(b.get_var(), t, pair(), value); } - else if (k == lp::LE && set_upper_bound(t, ci1, value) && has_lower_bound(t.index(), ci2, value)) { - fixed_var_eh(b.get_var(), t, ci1, ci2, value); + else if (k == lp::LE && set_upper_bound(t, ci1, value) && has_lower_bound(t, ci2, value)) { + fixed_var_eh(b.get_var(), t, pair(), value); } } @@ -3058,39 +2994,38 @@ class theory_lra::imp { bool proofs_enabled() const { return m.proofs_enabled(); } - bool set_upper_bound(lp::tv t, lp::constraint_index ci, rational const& v) { return set_bound(t, ci, v, false); } + bool set_upper_bound(lp::lpvar t, lp::constraint_index ci, rational const& v) { return set_bound(t, ci, v, false); } - bool set_lower_bound(lp::tv t, lp::constraint_index ci, rational const& v) { return set_bound(t, ci, v, true); } + bool set_lower_bound(lp::lpvar t, lp::constraint_index ci, rational const& v) { return set_bound(t, ci, v, true); } vector m_history; - bool set_bound(lp::tv tv, lp::constraint_index ci, rational const& v, bool is_lower) { - if (tv.is_term()) { - lpvar ti = tv.id(); + bool set_bound(lp::lpvar tv, lp::constraint_index ci, rational const& v, bool is_lower) { + if (lp().column_has_term(tv)) { auto& vec = is_lower ? m_lower_terms : m_upper_terms; - if (vec.size() <= ti) { - vec.resize(ti + 1, constraint_bound(UINT_MAX, rational())); + if (vec.size() <= tv) { + vec.resize(tv + 1, constraint_bound(UINT_MAX, rational())); } - constraint_bound& b = vec[ti]; + constraint_bound& b = vec[tv]; if (b.first == UINT_MAX || (is_lower? b.second < v : b.second > v)) { - TRACE("arith", tout << "tighter bound " << tv.to_string() << "\n";); - m_history.push_back(vec[ti]); - ctx().push_trail(history_trail(vec, ti, m_history)); + TRACE("arith", tout << "tighter bound " << tv << "\n";); + m_history.push_back(vec[tv]); + ctx().push_trail(history_trail(vec, tv, m_history)); b.first = ci; b.second = v; } return true; } else { - TRACE("arith", tout << "not a term " << tv.to_string() << "\n";); // m_solver already tracks bounds on proper variables, but not on terms. bool is_strict = false; rational b; + u_dependency* dep = nullptr; if (is_lower) { - return lp().has_lower_bound(tv.id(), ci, b, is_strict) && !is_strict && b == v; + return lp().has_lower_bound(tv, dep, b, is_strict) && !is_strict && b == v; } else { - return lp().has_upper_bound(tv.id(), ci, b, is_strict) && !is_strict && b == v; + return lp().has_upper_bound(tv, dep, b, is_strict) && !is_strict && b == v; } } } @@ -3098,35 +3033,36 @@ class theory_lra::imp { bool var_has_bound(lpvar vi, bool is_lower) { bool is_strict = false; rational b; - lp::constraint_index ci; + u_dependency* dep; if (is_lower) { - return lp().has_lower_bound(vi, ci, b, is_strict); + return lp().has_lower_bound(vi, dep, b, is_strict); } else { - return lp().has_upper_bound(vi, ci, b, is_strict); + return lp().has_upper_bound(vi, dep, b, is_strict); } } - bool has_upper_bound(lpvar vi, lp::constraint_index& ci, rational const& bound) { return has_bound(vi, ci, bound, false); } + bool has_upper_bound(lpvar vi, u_dependency*& ci, rational const& bound) { return has_bound(vi, ci, bound, false); } - bool has_lower_bound(lpvar vi, lp::constraint_index& ci, rational const& bound) { return has_bound(vi, ci, bound, true); } + bool has_lower_bound(lpvar vi, u_dependency*& ci, rational const& bound) { return has_bound(vi, ci, bound, true); } - bool has_bound(lpvar vi, lp::constraint_index& ci, rational const& bound, bool is_lower) { - if (lp::tv::is_term(vi)) { + bool has_bound(lpvar vi, u_dependency*& dep, rational const& bound, bool is_lower) { + if (lp().column_has_term(vi)) { theory_var v = lp().local_to_external(vi); rational val; TRACE("arith", tout << lp().get_variable_name(vi) << " " << v << "\n";); if (v != null_theory_var && a.is_numeral(get_owner(v), val) && bound == val) { - ci = UINT_MAX; + dep = nullptr; return bound == val; } auto& vec = is_lower ? m_lower_terms : m_upper_terms; - lpvar ti = lp::tv::unmask_term(vi); - if (vec.size() > ti) { - constraint_bound& b = vec[ti]; - ci = b.first; - return ci != UINT_MAX && bound == b.second; + if (vec.size() > vi) { + auto const& [ci, coeff] = vec[vi]; + if (ci == UINT_MAX) + return false; + dep = lp().dep_manager().mk_leaf(ci); + return bound == coeff; } else { return false; @@ -3136,10 +3072,10 @@ class theory_lra::imp { bool is_strict = false; rational b; if (is_lower) { - return lp().has_lower_bound(vi, ci, b, is_strict) && b == bound && !is_strict; + return lp().has_lower_bound(vi, dep, b, is_strict) && b == bound && !is_strict; } else { - return lp().has_upper_bound(vi, ci, b, is_strict) && b == bound && !is_strict; + return lp().has_upper_bound(vi, dep, b, is_strict) && b == bound && !is_strict; } } } @@ -3152,10 +3088,10 @@ class theory_lra::imp { void report_equality_of_fixed_vars(unsigned vi1, unsigned vi2) { rational bound(0); - lp::constraint_index ci1, ci2, ci3, ci4; + u_dependency* ci1 = nullptr, *ci2 = nullptr, *ci3 = nullptr, *ci4 = nullptr; theory_var v1 = lp().local_to_external(vi1); theory_var v2 = lp().local_to_external(vi2); - TRACE("arith", tout << "fixed: " << mk_pp(get_owner(v1), m) << " " << mk_pp(get_owner(v2), m) << "\n";); + TRACE("arith", tout << "fixed: " << pp(v1) << " " << pp(v2) << "\n";); // we expect lp() to ensure that none of these returns happen. if (is_equal(v1, v2)) return; @@ -3191,42 +3127,44 @@ class theory_lra::imp { for (auto c : m_core) ctx().display_detailed_literal(tout << ctx().get_assign_level(c.var()) << " " << c << " ", c) << "\n"; for (auto e : m_eqs) - tout << pp(e.first, m) << " = " << pp(e.second, m) << "\n"; - tout << " ==> "; - tout << pp(x, m) << " = " << pp(y, m) << "\n"; + tout << pp(e.first) << " = " << pp(e.second) << "\n"; + tout << " ==> " << pp(x) << " = " << pp(y) << "\n"; ); std::function fn = [&]() { return m.mk_eq(x->get_expr(), y->get_expr()); }; scoped_trace_stream _sts(th, fn); - - // SASSERT(validate_eq(x, y)); + if (params().m_arith_validate) + VERIFY(validate_eq(x, y)); ctx().assign_eq(x, y, eq_justification(js)); } - void fixed_var_eh(theory_var v, lp::tv t, lp::constraint_index ci1, lp::constraint_index ci2, rational const& bound) { + void fixed_var_eh(theory_var v, lp::lpvar t, u_dependency* dep, rational const& bound) { theory_var w = null_theory_var; enode* x = get_enode(v); - if (bound.is_zero()) + if (m_value2var.find(bound, w)) + ; + else if (bound.is_zero()) w = lp().local_to_external(get_zero(a.is_int(x->get_expr()))); else if (bound.is_one()) w = lp().local_to_external(get_one(a.is_int(x->get_expr()))); - else if (!m_value2var.find(bound, w)) + else return; enode* y = get_enode(w); + TRACE("arith", tout << pp(x) << " == " << pp(y) << "\n"); if (x->get_sort() != y->get_sort()) return; if (x->get_root() == y->get_root()) return; reset_evidence(); - set_evidence(ci1, m_core, m_eqs); - set_evidence(ci2, m_core, m_eqs); + set_evidence(dep, m_core, m_eqs); ++m_stats.m_fixed_eqs; assign_eq(v, w); } lbool make_feasible() { TRACE("pcs", tout << lp().constraints();); + TRACE("arith_verbose", tout << "before calling lp().find_feasible_solution()\n"; display(tout);); auto status = lp().find_feasible_solution(); TRACE("arith_verbose", display(tout);); if (lp().is_feasible()) @@ -3240,7 +3178,6 @@ class theory_lra::imp { } lp::explanation m_explanation; - vector m_nla_lemma_vector; literal_vector m_core; svector m_eqs; vector m_params; @@ -3253,6 +3190,11 @@ class theory_lra::imp { // lp::constraint_index const null_constraint_index = UINT_MAX; // not sure what a correct fix is + void set_evidence(u_dependency* dep, literal_vector& core, svector& eqs) { + for (auto ci : lp().flatten(dep)) + set_evidence(ci, core, eqs); + } + void set_evidence(lp::constraint_index idx, literal_vector& core, svector& eqs) { if (idx == UINT_MAX) { return; @@ -3299,16 +3241,17 @@ class theory_lra::imp { // lp().shrink_explanation_to_minimum(m_explanation); // todo, enable when perf is fixed ++m_num_conflicts; ++m_stats.m_conflicts; - TRACE("arith", - tout << "lemma scope: " << ctx().get_scope_level(); + TRACE("arith_conflict", + tout << "@" << ctx().get_scope_level() << (is_conflict ? " conflict":" lemma"); for (auto const& p : m_params) tout << " " << p; tout << "\n"; - display_evidence(tout, m_explanation); - display(tout << "is-conflict: " << is_conflict << "\n");); + display_evidence(tout, m_explanation);); for (auto ev : m_explanation) set_evidence(ev.ci(), m_core, m_eqs); - - // SASSERT(validate_conflict(m_core, m_eqs)); + + + if (params().m_arith_validate) + VERIFY(validate_conflict(m_core, m_eqs)); if (is_conflict) { ctx().set_conflict( ctx().mk_justification( @@ -3326,16 +3269,6 @@ class theory_lra::imp { ctx().mark_as_relevant(c); } TRACE("arith", ctx().display_literals_verbose(tout, m_core) << "\n";); - // DEBUG_CODE( - // for (literal const& c : m_core) { - // if (ctx().get_assignment(c) == l_true) { - // TRACE("arith", ctx().display_literal_verbose(tout, c) << " is true\n";); - // SASSERT(false); - // } - // }); // TODO: this check seems to be too strict. - // The lemmas can come in batches - // and the same literal can appear in several lemmas in a batch: it becomes l_true - // in earlier processing, but it was not so when the lemma was produced ctx().mk_th_axiom(get_id(), m_core.size(), m_core.data()); } } @@ -3355,7 +3288,6 @@ class theory_lra::imp { m_assume_eq_head = 0; m_scopes.reset(); m_stats.reset(); - m_bv_to_propagate.reset(); m_model_is_initialized = false; } @@ -3376,12 +3308,14 @@ class theory_lra::imp { nlsat::anum const& nl_value(theory_var v, scoped_anum& r) const { SASSERT(use_nra_model()); - auto t = get_tv(v); - if (t.is_term()) { + auto t = get_lpvar(v); + if (!lp().column_has_term(t)) + m_nla->am().set(r, m_nla->am_value(t)); + else { - m_todo_terms.push_back(std::make_pair(t, rational::one())); - TRACE("nl_value", tout << "v" << v << " " << t.to_string() << "\n";); - TRACE("nl_value", tout << "v" << v << " := w" << t.to_string() << "\n"; + m_todo_terms.push_back({t, rational::one()}); + TRACE("nl_value", tout << "v" << v << " " << t << "\n";); + TRACE("nl_value", tout << "v" << v << " := w" << t << "\n"; lp().print_term(lp().get_term(t), tout) << "\n";); m_nla->am().set(r, 0); @@ -3396,34 +3330,31 @@ class theory_lra::imp { m_nla->am().set(r1, c1.to_mpq()); m_nla->am().add(r, r1, r); for (lp::lar_term::ival arg : term) { - auto wi = lp().column2tv(arg.column()); + auto wi = arg.j(); c1 = arg.coeff() * wcoeff; - if (wi.is_term()) { - m_todo_terms.push_back(std::make_pair(wi, c1)); + if (lp().column_has_term(wi)) { + m_todo_terms.push_back({wi, c1}); } else { m_nla->am().set(r1, c1.to_mpq()); - m_nla->am().mul(m_nla->am_value(wi.id()), r1, r1); + m_nla->am().mul(m_nla->am_value(wi), r1, r1); m_nla->am().add(r1, r, r); } } } - return r; - } - else { - return m_nla->am_value(t.id()); } + return r; } model_value_proc * mk_value(enode * n, model_generator & mg) { theory_var v = n->get_th_var(get_id()); expr* o = n->get_expr(); if (use_nra_model() && lp().external_to_local(v) != lp::null_lpvar) { - anum const& an = nl_value(v, *m_a1); + anum const& an = nl_value(v, m_nla->tmp1()); if (a.is_int(o) && !m_nla->am().is_int(an)) { return alloc(expr_wrapper_proc, a.mk_numeral(rational::zero(), a.is_int(o))); } - return alloc(expr_wrapper_proc, a.mk_numeral(m_nla->am(), nl_value(v, *m_a1), a.is_int(o))); + return alloc(expr_wrapper_proc, a.mk_numeral(m_nla->am(), nl_value(v, m_nla->tmp1()), a.is_int(o))); } else { rational r = get_value(v); @@ -3468,12 +3399,16 @@ class theory_lra::imp { a.is_mod0(f); } + bool include_func_interp(enode* n) { + return include_func_interp(n->get_decl()); + } + bool get_lower(enode* n, rational& val, bool& is_strict) { theory_var v = n->get_th_var(get_id()); if (!is_registered_var(v)) return false; lpvar vi = get_lpvar(v); - lp::constraint_index ci; + u_dependency* ci; return lp().has_lower_bound(vi, ci, val, is_strict); } @@ -3492,8 +3427,8 @@ class theory_lra::imp { if (!is_registered_var(v)) return false; lpvar vi = get_lpvar(v); - lp::constraint_index ci; - return lp().has_upper_bound(vi, ci, val, is_strict); + u_dependency* dep = nullptr; + return lp().has_upper_bound(vi, dep, val, is_strict); } @@ -3522,13 +3457,15 @@ class theory_lra::imp { bool validate_conflict(literal_vector const& core, svector const& eqs) { if (params().m_arith_mode != arith_solver_id::AS_NEW_ARITH) return true; + + VERIFY(!m_core.empty() || !m_eqs.empty()); scoped_arith_mode _sa(ctx().get_fparams()); context nctx(m, ctx().get_fparams(), ctx().get_params()); add_background(nctx); cancel_eh eh(m.limit()); scoped_timer timer(1000, &eh); bool result = l_true != nctx.check(); - CTRACE("arith", !result, ctx().display_lemma_as_smt_problem(tout, core.size(), core.data(), eqs.size(), eqs.data(), false_literal);); + CTRACE("arith", !result, ctx().display_lemma_as_smt_problem(tout, core.size(), core.data(), eqs.size(), eqs.data(), false_literal);); return result; } @@ -3556,12 +3493,14 @@ class theory_lra::imp { flet _svalid(s_validating, true); context nctx(m, ctx().get_fparams(), ctx().get_params()); add_background(nctx); - nctx.assert_expr(m.mk_not(m.mk_eq(x->get_expr(), y->get_expr()))); + expr_ref neq(m.mk_not(m.mk_eq(x->get_expr(), y->get_expr())), m); + nctx.assert_expr(neq); cancel_eh eh(m.limit()); scoped_timer timer(1000, &eh); lbool r = nctx.check(); if (r == l_true) { nctx.display_asserted_formulas(std::cout); + std::cout.flush(); } return l_true != r; } @@ -3674,17 +3613,17 @@ class theory_lra::imp { TRACE("arith", lp().print_term(term, tout) << "\n";); for (lp::lar_term::ival ti : term) { theory_var w; - auto tv = lp().column2tv(ti.column()); - if (tv.is_term()) { + auto tv = ti.j(); + if (lp().column_has_term(tv)) { lp::lar_term const& term1 = lp().get_term(tv); rational coeff2 = coeff * ti.coeff(); term2coeffs(term1, coeffs, coeff2); continue; } else { - w = lp().local_to_external(tv.id()); + w = lp().local_to_external(tv); SASSERT(w >= 0); - TRACE("arith", tout << (tv.id()) << ": " << w << "\n";); + TRACE("arith", tout << tv << ": " << w << "\n";); } rational c0(0); coeffs.find(w, c0); @@ -3694,17 +3633,16 @@ class theory_lra::imp { app_ref coeffs2app(u_map const& coeffs, rational const& offset, bool is_int) { expr_ref_vector args(m); - for (auto const& kv : coeffs) { - theory_var w = kv.m_key; + for (auto const& [w, coeff] : coeffs) { expr* o = get_enode(w)->get_expr(); - if (kv.m_value.is_zero()) { + if (coeff.is_zero()) { // continue } - else if (kv.m_value.is_one()) { + else if (coeff.is_one()) { args.push_back(o); } else { - args.push_back(a.mk_mul(a.mk_numeral(kv.m_value, is_int), o)); + args.push_back(a.mk_mul(a.mk_numeral(coeff, is_int), o)); } } if (!offset.is_zero()) { @@ -3742,9 +3680,9 @@ class theory_lra::imp { } app_ref mk_obj(theory_var v) { - auto t = get_tv(v); + auto t = get_lpvar(v); bool is_int = a.is_int(get_enode(v)->get_expr()); - if (t.is_term()) { + if (lp().column_has_term(t)) { return mk_term(lp().get_term(t), is_int); } else { @@ -3779,6 +3717,7 @@ class theory_lra::imp { m_bounds[v].push_back(a); m_bounds_trail.push_back(v); m_bool_var2bound.insert(bv, a); + TRACE("arith", tout << "internalized " << bv << ": " << mk_pp(b, m) << "\n";); } if (is_strict) { @@ -3799,16 +3738,15 @@ class theory_lra::imp { } unsigned nv = th.get_num_vars(); for (unsigned v = 0; v < nv; ++v) { - auto t = get_tv(v); - auto vi = lp().external_to_column_index(v); + auto vi = get_lpvar(v); if (!ctx().is_relevant(get_enode(v))) out << "irr: "; out << "v" << v << " "; - if (t.is_null()) out << "null"; else out << (t.is_term() ? "t":"j") << vi; - if (use_nra_model() && is_registered_var(v)) m_nla->am().display(out << " = ", nl_value(v, *m_a1)); + if (vi == lp::null_lpvar) out << "null"; else out << (lp().column_has_term(vi) ? "t":"j") << vi; + if (use_nra_model() && is_registered_var(v)) m_nla->am().display(out << " = ", nl_value(v, m_nla->tmp1())); else if (can_get_value(v)) out << " = " << get_value(v); if (is_int(v)) out << ", int"; if (ctx().is_shared(get_enode(v))) out << ", shared"; - out << " := " << enode_pp(get_enode(v), ctx()) << "\n"; + out << " := " << pp(v) << "\n"; } } @@ -3824,17 +3762,17 @@ class theory_lra::imp { case inequality_source: { literal lit = m_inequalities[idx]; ctx().literal2expr(lit, e); - out << e << " " << ctx().get_assignment(lit) << "\n"; + out << bpp(e) << " " << ctx().get_assignment(lit) << "\n"; break; } case equality_source: - out << pp(m_equalities[idx].first, m) << " = " - << pp(m_equalities[idx].second, m) << "\n"; + out << pp(m_equalities[idx].first) << " = " + << pp(m_equalities[idx].second) << "\n"; break; case definition_source: { theory_var v = m_definitions[idx]; if (v != null_theory_var) - out << "def: v" << v << " := " << pp(th.get_enode(v), m) << "\n"; + out << "def: v" << v << " := " << pp(th.get_enode(v)) << "\n"; break; } case null_source: @@ -3845,16 +3783,14 @@ class theory_lra::imp { break; } } - for (lp::explanation::cimpq ev : evidence) { + for (lp::explanation::cimpq ev : evidence) lp().constraints().display(out << ev.coeff() << ": ", ev.ci()); - } } void collect_statistics(::statistics & st) const { m_arith_eq_adapter.collect_statistics(st); m_stats.collect_statistics(st); lp().settings().stats().collect_statistics(st); - if (m_nla) m_nla->collect_statistics(st); } /* @@ -3888,6 +3824,30 @@ class theory_lra::imp { } + void validate_model(proto_model& mdl) { + + rational r1, r2; + expr_ref res(m); + if (!m_model_is_initialized) + return; + for (unsigned v = 0; v < th.get_num_vars(); ++v) { + if (!is_registered_var(v)) + continue; + enode* n = get_enode(v); + if (!n) + continue; + if (!th.is_relevant_and_shared(n)) + continue; + rational r1 = get_value(v); + if (!mdl.eval(n->get_expr(), res, false)) + continue; + if (!a.is_numeral(res, r2)) + continue; + if (r1 != r2) + IF_VERBOSE(1, verbose_stream() << enode_pp(n, ctx()) << " evaluates to " << r2 << " but arith solver has " << r1 << "\n"); + } + } + }; theory_lra::theory_lra(context& ctx): @@ -4015,6 +3975,10 @@ void theory_lra::setup() { m_imp->setup(); } +void theory_lra::validate_model(proto_model& mdl) { + m_imp->validate_model(mdl); +} + } template class lp::lp_bound_propagator; template void lp::lar_solver::propagate_bounds_for_touched_rows(lp::lp_bound_propagator&); diff --git a/src/smt/theory_lra.h b/src/smt/theory_lra.h index b7d27107925..4c2351c85d5 100644 --- a/src/smt/theory_lra.h +++ b/src/smt/theory_lra.h @@ -82,6 +82,7 @@ namespace smt { void init_model(model_generator & m) override; model_value_proc * mk_value(enode * n, model_generator & mg) override; + void validate_model(proto_model& mdl) override; bool get_value(enode* n, expr_ref& r) override; bool include_func_interp(func_decl* f) override; diff --git a/src/smt/theory_pb.cpp b/src/smt/theory_pb.cpp index 718d5c65ab8..25c8e7195ee 100644 --- a/src/smt/theory_pb.cpp +++ b/src/smt/theory_pb.cpp @@ -1807,7 +1807,7 @@ namespace smt { bool theory_pb::resolve_conflict(card& c, literal_vector const& confl) { - TRACE("pb", display(tout, c, true); ); + TRACE("pb", display(tout << "resolve conflict\n", c, true); ); bool_var v; m_conflict_lvl = 0; @@ -1839,8 +1839,19 @@ namespace smt { literal conseq = ~confl[2]; int bound = 1; + auto clear_marks = [&]() { + while (m_num_marks > 0 && idx > 0) { + v = lits[idx].var(); + if (ctx.is_marked(v)) { + ctx.unset_mark(v); + } + --idx; + } + }; + while (m_num_marks > 0) { + TRACE("pb", tout << "conseq: " << conseq << "\n"); v = conseq.var(); int offset = get_abs_coeff(v); @@ -1850,13 +1861,7 @@ namespace smt { } SASSERT(validate_lemma()); if (offset > 1000) { - while (m_num_marks > 0 && idx > 0) { - v = lits[idx].var(); - if (ctx.is_marked(v)) { - ctx.unset_mark(v); - } - --idx; - } + clear_marks(); return false; } @@ -1884,8 +1889,11 @@ namespace smt { clause& cls = *js.get_clause(); justification* cjs = cls.get_justification(); unsigned num_lits = cls.get_num_literals(); - if (cjs && typeid(smt::unit_resolution_justification) == typeid(*cjs)) - ; + CTRACE("pb", cjs, tout << (typeid(smt::unit_resolution_justification) == typeid(*cjs)) << "\n"); + if (cjs && typeid(smt::unit_resolution_justification) == typeid(*cjs)) { + clear_marks(); + return false; + } else if (cjs && !is_proof_justification(*cjs)) { TRACE("pb", tout << "not processing justification over: " << conseq << " " << typeid(*cjs).name() << "\n";); break; @@ -1954,7 +1962,8 @@ namespace smt { while (true) { conseq = lits[idx]; v = conseq.var(); - if (ctx.is_marked(v)) break; + if (ctx.is_marked(v)) + break; SASSERT(idx > 0); --idx; } diff --git a/src/smt/theory_polymorphism.h b/src/smt/theory_polymorphism.h new file mode 100644 index 00000000000..4c64a0a9c7d --- /dev/null +++ b/src/smt/theory_polymorphism.h @@ -0,0 +1,105 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + theory_polymorphism.h + +Abstract: + + Plugin for handling polymorphism + The plugin instantiates polymorphic axioms based on occurrences of polymorphic functions in other axioms. + It uses blocking literals to restart search when there are new axioms that can be instantiated. + +Author: + + Nikolaj Bjorner (nbjorner) 2013-07-11 + +--*/ +#pragma once + +#include "ast/polymorphism_inst.h" +#include "smt/smt_theory.h" + +namespace smt { + + class theory_polymorphism : public theory { + trail_stack m_trail; + polymorphism::inst m_inst; + expr_ref m_assumption; + unsigned m_qhead = 0; + bool m_pending = true; + + bool internalize_atom(app*, bool) override { return false; } + bool internalize_term(app*) override { return false; } + void new_eq_eh(theory_var, theory_var) override { } + void new_diseq_eh(theory_var, theory_var) override {} + theory* mk_fresh(context* new_ctx) override { return alloc(theory_polymorphism, *new_ctx); } + char const * get_name() const override { return "polymorphism"; } + void display(std::ostream& out) const override {} + + void push_scope_eh() override { + m_trail.push_scope(); + } + + void pop_scope_eh(unsigned n) override { + m_trail.pop_scope(n); + } + + bool can_propagate() override { + return m_pending; + } + + /** + * Assert instances of polymorphic axioms + */ + void propagate() override { + if (!m_pending) + return; + m_pending = false; + vector instances; + m_inst.instantiate(instances); + if (instances.empty()) + return; + for (auto const& [orig, inst, sub] : instances) + ctx.add_asserted(inst); + ctx.internalize_assertions(); + } + + final_check_status final_check_eh() override { + if (m_inst.pending()) + ctx.assign(~mk_literal(m_assumption), nullptr); + return FC_DONE; + } + + void add_theory_assumptions(expr_ref_vector & assumptions) override { + if (m_qhead == ctx.get_num_asserted_formulas()) + return; + m_assumption = m.mk_fresh_const("poly", m.mk_bool_sort()); + assumptions.push_back(m_assumption); + ctx.push_trail(value_trail(m_qhead)); + for (; m_qhead < ctx.get_num_asserted_formulas(); ++m_qhead) + m_inst.add(ctx.get_asserted_formula(m_qhead)); + m_pending = true; + } + + bool should_research(expr_ref_vector & assumptions) override { + for (auto * a : assumptions) + if (a == m_assumption) + return true; + return false; + } + + + public: + theory_polymorphism(context& ctx): + theory(ctx, poly_family_id), + m_inst(ctx.get_manager(), m_trail), + m_assumption(ctx.get_manager()) {} + + void init_model(model_generator & mg) override { } + }; + +}; + + diff --git a/src/smt/theory_seq.cpp b/src/smt/theory_seq.cpp index 7b0955518f5..0d16120b19b 100644 --- a/src/smt/theory_seq.cpp +++ b/src/smt/theory_seq.cpp @@ -1244,7 +1244,7 @@ bool theory_seq::get_length(expr* e, expr_ref& len, literal_vector& lits) { /** * solve for fold/map (recursive function that depends on a sequence) - * Assumption: the Seq argument of fold/map expands into a concatentation of units + * Assumption: the Seq argument of fold/map expands into a concatenation of units * The assumption is enforced by tracking the length of the seq argument. * This is ensured in relevant_eh. * Under the assumption, evern occurrence of fold/map gets simplified by expanding @@ -1270,7 +1270,6 @@ bool theory_seq::solve_nc(unsigned idx) { expr_ref c(m); expr* a = nullptr, *b = nullptr; VERIFY(m_util.str.is_contains(n.contains(), a, b)); - literal pre, cnt, ctail, emp; lbool is_gt = ctx.get_assignment(len_gt); TRACE("seq", ctx.display_literal_smt2(tout << len_gt << " := " << is_gt << "\n", len_gt) << "\n";); @@ -2160,7 +2159,7 @@ app* theory_seq::mk_value(app* e) { } -void theory_seq::validate_model(model& mdl) { +void theory_seq::validate_model(proto_model& mdl) { return; for (auto const& eq : m_eqs) { sort* srt = eq.ls[0]->get_sort(); diff --git a/src/smt/theory_seq.h b/src/smt/theory_seq.h index 49213dbd49f..6ed4fc41f27 100644 --- a/src/smt/theory_seq.h +++ b/src/smt/theory_seq.h @@ -407,7 +407,7 @@ namespace smt { void init_model(model_generator & mg) override; void finalize_model(model_generator & mg) override; void init_search_eh() override; - void validate_model(model& mdl) override; + void validate_model(proto_model& mdl) override; bool is_beta_redex(enode* p, enode* n) const override; void init_model(expr_ref_vector const& es); diff --git a/src/smt/theory_special_relations.cpp b/src/smt/theory_special_relations.cpp index ddddfbc000a..30eac685d06 100644 --- a/src/smt/theory_special_relations.cpp +++ b/src/smt/theory_special_relations.cpp @@ -130,7 +130,11 @@ namespace smt { } bool theory_special_relations::internalize_term(app * term) { - verbose_stream() << mk_pp(term, m) << "\n"; + m_terms.push_back(term); + ctx.push_trail(push_back_vector(m_terms)); + std::stringstream strm; + strm << "term not not handled by special relations procedure. Use sat.smt=true " << mk_pp(term, m); + warning_msg(strm.str().c_str()); return false; } @@ -207,9 +211,10 @@ namespace smt { if (new_equality) { return FC_CONTINUE; } - else { + else if (!m_terms.empty()) + return FC_GIVEUP; + else return FC_DONE; - } } lbool theory_special_relations::final_check_lo(relation& r) { @@ -888,9 +893,20 @@ namespace smt { func_decl* memf, *nextf, *connectedf; + std::string member, next, connected_sym, id; + auto const& pa = r.decl()->get_parameter(0); + if (pa.is_int()) + id = std::to_string(pa.get_int()); + else if (pa.is_ast() && is_func_decl(pa.get_ast())) + id = to_func_decl(pa.get_ast())->get_name().str(); + else + throw default_exception("expected an integer or function declaration"); + member = "member" + id; + next = "next" + id; + connected_sym = "connected" + id; { sort* dom[2] = { s, listS }; - recfun::promise_def mem = p.ensure_def(symbol("member"), 2, dom, m.mk_bool_sort(), true); + recfun::promise_def mem = p.ensure_def(symbol(member), 2, dom, m.mk_bool_sort(), true); memf = mem.get_def()->get_decl(); var_ref xV(m.mk_var(1, s), m); @@ -913,7 +929,7 @@ namespace smt { { sort* dom[5] = { s, s, listS, listS, tup }; - recfun::promise_def nxt = p.ensure_def(symbol("next"), 5, dom, tup, true); + recfun::promise_def nxt = p.ensure_def(symbol(next), 5, dom, tup, true); nextf = nxt.get_def()->get_decl(); expr_ref next_body(m); @@ -934,7 +950,7 @@ namespace smt { { sort* dom[3] = { listS, s, listS }; - recfun::promise_def connected = p.ensure_def(symbol("connected"), 3, dom, m.mk_bool_sort(), true); + recfun::promise_def connected = p.ensure_def(symbol(connected_sym), 3, dom, m.mk_bool_sort(), true); connectedf = connected.get_def()->get_decl(); var_ref AV(m.mk_var(2, listS), m); var_ref dstV(m.mk_var(1, s), m); diff --git a/src/smt/theory_special_relations.h b/src/smt/theory_special_relations.h index 73e889a5dd9..65ce179070d 100644 --- a/src/smt/theory_special_relations.h +++ b/src/smt/theory_special_relations.h @@ -131,6 +131,7 @@ namespace smt { special_relations_util m_util; atoms m_atoms; unsigned_vector m_atoms_lim; + ptr_vector m_terms; obj_map m_relations; bool_var2atom m_bool_var2atom; bool m_can_propagate; diff --git a/src/smt/theory_str_noodler/theory_str_noodler.cpp b/src/smt/theory_str_noodler/theory_str_noodler.cpp index a7515e15dc5..b83a9b1ec2f 100644 --- a/src/smt/theory_str_noodler/theory_str_noodler.cpp +++ b/src/smt/theory_str_noodler/theory_str_noodler.cpp @@ -1536,7 +1536,7 @@ namespace smt::noodler { if(m_util_s.str.is_string(a, str_a) && str_a.length() == 1) { // s = emp -> v = t.a // NOTE: if we use ~s_emp, this diseqation does not become relevant - add_axiom({mk_literal(m.mk_not(m.mk_eq(s, eps))), mk_eq(v, mk_concat(t, a),false)}); + add_axiom({mk_literal(m.mk_not(m.mk_eq(s, eps))), mk_literal(m.mk_eq(r, mk_concat(t, a)))}); // s = a -> v = t // NOTE: if we use ~mk_eq(s, a), this diseqation does not become relevant add_axiom({mk_literal(m.mk_not(m.mk_eq(s, a))), mk_eq(v, t,false)}); diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 8eeaf4382da..93faf9e4e6a 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -63,12 +63,11 @@ void theory_user_propagator::add_expr(expr* term, bool ensure_enode) { if (is_attached_to_var(n)) return; - theory_var v = mk_var(n); m_var2expr.reserve(v + 1); m_var2expr[v] = term; m_expr2var.setx(term->get_id(), v, null_theory_var); - + if (m.is_bool(e) && !ctx.b_internalized(e)) { bool_var bv = ctx.mk_bool_var(e); ctx.set_var_theory(bv, get_id()); @@ -80,10 +79,9 @@ void theory_user_propagator::add_expr(expr* term, bool ensure_enode) { literal_vector explain; if (ctx.is_fixed(n, r, explain)) m_prop.push_back(prop_info(explain, v, r)); - } -void theory_user_propagator::propagate_cb( +bool theory_user_propagator::propagate_cb( unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) { @@ -95,9 +93,10 @@ void theory_user_propagator::propagate_cb( if (!ctx.get_manager().is_true(_conseq) && !ctx.get_manager().is_false(_conseq)) ctx.mark_as_relevant((expr*)_conseq); - if (ctx.lit_internalized(_conseq) && ctx.get_assignment(ctx.get_literal(_conseq)) == l_true) - return; - m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, _conseq)); + if (ctx.lit_internalized(_conseq) && ctx.get_assignment(ctx.get_literal(_conseq)) == l_true) + return false; + m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, _conseq)); + return true; } void theory_user_propagator::register_cb(expr* e) { @@ -107,15 +106,23 @@ void theory_user_propagator::register_cb(expr* e) { add_expr(e, true); } -void theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) { +bool theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) { if (e == nullptr) { // clear - m_next_split_expr = nullptr; - return; + m_next_split_var = nullptr; + return true; + } + if (!ctx.e_internalized(e)) { + // We may not eagerly internalize it (might crash when done in pop) => delay + m_next_split_var = e; + return true; } - ensure_enode(e); - m_next_split_expr = e; + bool_var b = enode_to_bool(ctx.get_enode(e), idx); + if (b == null_bool_var || ctx.get_assignment(b) != l_undef) + return false; + m_next_split_var = e; m_next_split_idx = idx; m_next_split_phase = phase; + return true; } theory * theory_user_propagator::mk_fresh(context * new_ctx) { @@ -142,18 +149,18 @@ final_check_status theory_user_propagator::final_check_eh() { return FC_DONE; force_push(); unsigned sz1 = m_prop.size(); - unsigned sz2 = m_expr2var.size(); + unsigned sz2 = get_num_vars(); try { m_final_eh(m_user_context, this); } catch (...) { - throw default_exception("Exception thrown in \"final\"-callback"); + throw default_exception("Exception thrown in \"final\"-callback"); } CTRACE("user_propagate", can_propagate(), tout << "can propagate\n"); propagate(); CTRACE("user_propagate", ctx.inconsistent(), tout << "inconsistent\n"); // check if it became inconsistent or something new was propagated/registered - bool done = (sz1 == m_prop.size()) && (sz2 == m_expr2var.size()) && !ctx.inconsistent(); + bool done = (sz1 == m_prop.size()) && (sz2 == get_num_vars()) && !ctx.inconsistent(); return done ? FC_DONE : FC_CONTINUE; } @@ -167,42 +174,39 @@ void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned nu ctx.push_trail(insert_map(m_fixed, v)); m_id2justification.setx(v, literal_vector(num_lits, jlits), literal_vector()); try { - m_fixed_eh(m_user_context, this, var2expr(v), value); - } - catch (...) { + m_fixed_eh(m_user_context, this, var2expr(v), value); + } + catch (...) { throw default_exception("Exception thrown in \"fixed\"-callback"); - } + } } -bool_var theory_user_propagator::enode_to_bool(enode* n, unsigned bit) { +bool_var theory_user_propagator::enode_to_bool(enode* n, unsigned idx) { if (n->is_bool()) { // expression is a boolean - bool_var new_var = ctx.enode2bool_var(n); - if (ctx.get_assignment(new_var) == l_undef) - return new_var; - return null_bool_var; + return ctx.enode2bool_var(n); } // expression is a bit-vector bv_util bv(m); auto th_bv = (theory_bv*)ctx.get_theory(bv.get_fid()); - return th_bv->get_first_unassigned(bit, n); + return th_bv->get_bit(idx, n); } void theory_user_propagator::decide(bool_var& var, bool& is_pos) { if (!m_decide_eh) return; - + const bool_var_data& d = ctx.get_bdata(var); - - if (!d.is_enode() && !d.is_theory_atom()) + + if (!d.is_enode() && !d.is_theory_atom()) return; - - enode* original_enode = nullptr; + + enode* original_enode = nullptr; unsigned original_bit = 0; bv_util bv(m); theory* th = nullptr; theory_var v = null_theory_var; - + // get the associated theory if (!d.is_enode()) { // it might be a value that does not have an enode @@ -216,7 +220,7 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) { th = ctx.get_theory(d.get_theory()); } } - + if (v == null_theory_var && !th) return; @@ -225,7 +229,7 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) { if (v == null_theory_var) { // it is not a registered boolean value but it is a bitvector - auto registered_bv = ((theory_bv*)th)->get_bv_with_theory(var, get_family_id()); + auto registered_bv = ((theory_bv*) th)->get_bv_with_theory(var, get_family_id()); if (!registered_bv.first) // there is no registered bv associated with the bit return; @@ -236,47 +240,38 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) { // call the registered callback unsigned new_bit = original_bit; - lbool phase = is_pos ? l_true : l_false; - - expr* e = var2expr(v); - m_decide_eh(m_user_context, this, &e, &new_bit, &phase); - enode* new_enode = ctx.get_enode(e); - - // check if the callback changed something - if (original_enode == new_enode && (new_enode->is_bool() || original_bit == new_bit)) { - if (phase != l_undef) - // it only affected the truth value - is_pos = phase == l_true; + + force_push(); + expr *e = var2expr(v); + m_decide_eh(m_user_context, this, e, new_bit, is_pos); + + bool_var new_var; + if (!get_case_split(new_var, is_pos) || new_var == var) + // The user did not interfere return; - } + var = new_var; - // get unassigned variable from enode - var = enode_to_bool(new_enode, new_bit); - - if (var == null_bool_var) - // selected variable is already assigned + // check if the new variable is unassigned + if (ctx.get_assignment(var) != l_undef) throw default_exception("expression in \"decide\" is already assigned"); - - // in case the callback did not decide on a truth value -> let Z3 decide - is_pos = ctx.guess(var, phase); } -bool theory_user_propagator::get_case_split(bool_var& var, bool& is_pos){ - if (!m_next_split_expr) +bool theory_user_propagator::get_case_split(bool_var& var, bool& is_pos) { + if (m_next_split_var == nullptr) return false; - enode* n = ctx.get_enode(m_next_split_expr); - - var = enode_to_bool(n, m_next_split_idx); - - if (var == null_bool_var) + ensure_enode(m_next_split_var); + bool_var b = enode_to_bool(ctx.get_enode(m_next_split_var), m_next_split_idx); + if (b == null_bool_var || ctx.get_assignment(b) != l_undef) return false; - + var = b; is_pos = ctx.guess(var, m_next_split_phase); - m_next_split_expr = nullptr; + m_next_split_var = nullptr; + m_next_split_idx = 0; + m_next_split_phase = l_undef; return true; } -void theory_user_propagator::push_scope_eh() { +void theory_user_propagator::push_scope_eh() { ++m_num_scopes; } @@ -298,16 +293,16 @@ void theory_user_propagator::pop_scope_eh(unsigned num_scopes) { } bool theory_user_propagator::can_propagate() { - return m_qhead < m_prop.size() || m_to_add_qhead < m_to_add.size(); + return m_qhead < m_prop.size() || m_to_add_qhead < m_to_add.size() || m_replay_qhead < m_clauses_to_replay.size(); } void theory_user_propagator::propagate_consequence(prop_info const& prop) { justification* js; - m_lits.reset(); + m_lits.reset(); m_eqs.reset(); - for (expr* id : prop.m_ids) + for (expr* id: prop.m_ids) m_lits.append(m_id2justification[expr2var(id)]); - for (auto const& [a,b] : prop.m_eqs) + for (auto const& [a, b]: prop.m_eqs) if (a != b) m_eqs.push_back(enode_pair(get_enode(expr2var(a)), get_enode(expr2var(b)))); DEBUG_CODE(for (auto const& [a, b] : m_eqs) VERIFY(a->get_root() == b->get_root());); @@ -326,12 +321,10 @@ void theory_user_propagator::propagate_consequence(prop_info const& prop) { ctx.set_conflict(js); } else { -#if 1 for (auto& lit : m_lits) lit.neg(); for (auto const& [a,b] : m_eqs) m_lits.push_back(~mk_eq(a->get_expr(), b->get_expr(), false)); -#endif literal lit; if (has_quantifiers(prop.m_conseq)) { @@ -342,22 +335,23 @@ void theory_user_propagator::propagate_consequence(prop_info const& prop) { lit = mk_literal(fn); } else - lit = mk_literal(prop.m_conseq); + lit = mk_literal(prop.m_conseq); ctx.mark_as_relevant(lit); -#if 0 - justification* js = - ctx.mk_justification( - ext_theory_propagation_justification( - get_id(), ctx, m_lits.size(), m_lits.data(), m_eqs.size(), m_eqs.data(), lit)); - - ctx.assign(lit, js); -#endif - -#if 1 m_lits.push_back(lit); - ctx.mk_th_lemma(get_id(), m_lits); -#endif + if (ctx.get_fparams().m_up_persist_clauses) { + ctx.mk_th_axiom(get_id(), m_lits); + expr_ref_vector clause(m); + for (auto lit : m_lits) + clause.push_back(ctx.literal2expr(lit)); + m_clauses_to_replay.push_back(clause); + if (m_replay_qhead + 1 < m_clauses_to_replay.size()) + std::swap(m_clauses_to_replay[m_replay_qhead], m_clauses_to_replay[m_clauses_to_replay.size()-1]); + ++m_replay_qhead; + } + else { + ctx.mk_th_lemma(get_id(), m_lits); + } TRACE("user_propagate", ctx.display(tout);); } } @@ -368,12 +362,20 @@ void theory_user_propagator::propagate_new_fixed(prop_info const& prop) { void theory_user_propagator::propagate() { - if (m_qhead == m_prop.size() && m_to_add_qhead == m_to_add.size()) + if (m_qhead == m_prop.size() && m_to_add_qhead == m_to_add.size() && m_replay_qhead == m_clauses_to_replay.size()) return; TRACE("user_propagate", tout << "propagating queue head: " << m_qhead << " prop queue: " << m_prop.size() << "\n"); force_push(); - - unsigned qhead = m_to_add_qhead; + + unsigned qhead = m_replay_qhead; + if (qhead < m_clauses_to_replay.size()) { + for (; qhead < m_clauses_to_replay.size() && !ctx.inconsistent(); ++qhead) + replay_clause(m_clauses_to_replay.get(qhead)); + ctx.push_trail(value_trail(m_replay_qhead)); + m_replay_qhead = qhead; + } + + qhead = m_to_add_qhead; if (qhead < m_to_add.size()) { for (; qhead < m_to_add.size(); ++qhead) add_expr(m_to_add.get(qhead), true); @@ -396,11 +398,18 @@ void theory_user_propagator::propagate() { } +void theory_user_propagator::replay_clause(expr_ref_vector const& clause) { + m_lits.reset(); + for (expr* e : clause) + m_lits.push_back(mk_literal(e)); + ctx.mk_th_axiom(get_id(), m_lits); +} + bool theory_user_propagator::internalize_atom(app* atom, bool gate_ctx) { return internalize_term(atom); } -bool theory_user_propagator::internalize_term(app* term) { +bool theory_user_propagator::internalize_term(app* term) { for (auto arg : *term) ensure_enode(arg); if (term->get_family_id() == get_id() && !ctx.e_internalized(term)) @@ -421,9 +430,9 @@ bool theory_user_propagator::internalize_term(app* term) { return true; } -void theory_user_propagator::collect_statistics(::statistics & st) const { +void theory_user_propagator::collect_statistics(::statistics& st) const { st.update("user-propagations", m_stats.m_num_propagations); - st.update("user-watched", get_num_vars()); + st.update("user-watched", get_num_vars()); } diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index 2ed1acbdff6..5dbae59ed5d 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -83,9 +83,11 @@ namespace smt { expr_ref_vector m_to_add; unsigned_vector m_to_add_lim; unsigned m_to_add_qhead = 0; - expr* m_next_split_expr = nullptr; - unsigned m_next_split_idx; - lbool m_next_split_phase; + expr* m_next_split_var = nullptr; + unsigned m_next_split_idx = 0; + lbool m_next_split_phase = l_undef; + vector m_clauses_to_replay; + unsigned m_replay_qhead = 0; expr* var2expr(theory_var v) { return m_var2expr.get(v); } theory_var expr2var(expr* e) { check_defined(e); return m_expr2var[e->get_id()]; } @@ -101,6 +103,8 @@ namespace smt { bool_var enode_to_bool(enode* n, unsigned bit); + void replay_clause(expr_ref_vector const& clause); + public: theory_user_propagator(context& ctx); @@ -131,9 +135,9 @@ namespace smt { bool has_fixed() const { return (bool)m_fixed_eh; } - void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; + bool propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; - void next_split_cb(expr* e, unsigned idx, lbool phase) override; + bool next_split_cb(expr* e, unsigned idx, lbool phase) override; void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); void decide(bool_var& var, bool& is_pos); diff --git a/src/smt/theory_utvpi.cpp b/src/smt/theory_utvpi.cpp index da2883938fa..2072474b017 100644 --- a/src/smt/theory_utvpi.cpp +++ b/src/smt/theory_utvpi.cpp @@ -11,7 +11,7 @@ Module Name: Revision History: - The implementaton is derived from theory_diff_logic. + The implementation is derived from theory_diff_logic. --*/ #include "smt/theory_utvpi.h" diff --git a/src/solver/assertions/asserted_formulas.cpp b/src/solver/assertions/asserted_formulas.cpp index 4e64ee39fd0..ee94ac3555b 100644 --- a/src/solver/assertions/asserted_formulas.cpp +++ b/src/solver/assertions/asserted_formulas.cpp @@ -161,6 +161,9 @@ void asserted_formulas::assert_expr(expr * e, proof * _in_pr) { if (inconsistent()) return; + if (m.is_true(e)) + return; + if (m_smt_params.m_preprocess) { TRACE("assert_expr_bug", tout << r << "\n";); set_eliminate_and(false); // do not eliminate and before nnf. @@ -507,7 +510,8 @@ void asserted_formulas::simplify_fmls::operator()() { else { af.push_assertion(result, result_pr, new_fmls); } - if (af.canceled()) return; + if (af.canceled()) + return; } af.swap_asserted_formulas(new_fmls); TRACE("asserted_formulas", af.display(tout);); diff --git a/src/solver/check_logic.cpp b/src/solver/check_logic.cpp index 231c21a807e..36c08c4d590 100644 --- a/src/solver/check_logic.cpp +++ b/src/solver/check_logic.cpp @@ -38,7 +38,7 @@ struct check_logic::imp { datatype_util m_dt_util; pb_util m_pb_util; bool m_uf; // true if the logic supports uninterpreted functions - bool m_dt; // true if the lgoic supports dattypes + bool m_dt; // true if the logic supports dattypes bool m_arrays; // true if the logic supports arbitrary arrays bool m_bv_arrays; // true if the logic supports only bv arrays bool m_reals; // true if the logic supports reals diff --git a/src/solver/check_sat_result.cpp b/src/solver/check_sat_result.cpp index c0f3979aa3a..fd7939d9503 100644 --- a/src/solver/check_sat_result.cpp +++ b/src/solver/check_sat_result.cpp @@ -22,7 +22,7 @@ void check_sat_result::set_reason_unknown(event_handler& eh) { switch (eh.caller_id()) { case UNSET_EH_CALLER: if (reason_unknown() == "") - set_reason_unknown("unclassifed exception"); + set_reason_unknown("unclassified exception"); break; case CTRL_C_EH_CALLER: set_reason_unknown("interrupted from keyboard"); diff --git a/src/solver/parallel_params.pyg b/src/solver/parallel_params.pyg index 628d4242af6..60a77d49aac 100644 --- a/src/solver/parallel_params.pyg +++ b/src/solver/parallel_params.pyg @@ -10,7 +10,7 @@ def_module_params('parallel', ('conquer.delay', UINT, 10, 'delay of cubes until applying conquer'), ('conquer.backtrack_frequency', UINT, 10, 'frequency to apply core minimization during conquer'), ('simplify.exp', DOUBLE, 1, 'restart and inprocess max is multiplied by simplify.exp ^ depth'), - ('simplify.max_conflicts', UINT, UINT_MAX, 'maximal number of conflicts during simplifcation phase'), + ('simplify.max_conflicts', UINT, UINT_MAX, 'maximal number of conflicts during simplification phase'), ('simplify.restart.max', UINT, 5000, 'maximal number of restarts during simplification phase'), ('simplify.inprocess.max', UINT, 2, 'maximal number of inprocessing steps during simplification'), )) diff --git a/src/solver/simplifier_solver.cpp b/src/solver/simplifier_solver.cpp index d70d232e47a..ed645dae0b2 100644 --- a/src/solver/simplifier_solver.cpp +++ b/src/solver/simplifier_solver.cpp @@ -24,6 +24,7 @@ Module Name: #include "ast/rewriter/expr_safe_replace.h" #include "ast/simplifiers/dependent_expr_state.h" #include "ast/simplifiers/then_simplifier.h" +#include "ast/rewriter/th_rewriter.h" #include "solver/solver.h" #include "solver/simplifier_solver.h" #include "solver/solver_preprocess.h" @@ -35,17 +36,21 @@ class simplifier_solver : public solver { struct dep_expr_state : public dependent_expr_state { simplifier_solver& s; model_reconstruction_trail m_reconstruction_trail; + bool m_updated = false; dep_expr_state(simplifier_solver& s) :dependent_expr_state(s.m), s(s), m_reconstruction_trail(s.m, m_trail) {} ~dep_expr_state() override {} - virtual unsigned qtail() const override { return s.m_fmls.size(); } + unsigned qtail() const override { return s.m_fmls.size(); } dependent_expr const& operator[](unsigned i) override { return s.m_fmls[i]; } void update(unsigned i, dependent_expr const& j) override { SASSERT(j.fml()); check_false(j.fml()); - s.m_fmls[i] = j; + s.m_fmls[i] = j; + m_updated = true; } - void add(dependent_expr const& j) override { check_false(j.fml()); s.m_fmls.push_back(j); } + void add(dependent_expr const& j) override { m_updated = true; check_false(j.fml()); s.m_fmls.push_back(j); } bool inconsistent() override { return s.m_inconsistent; } + bool updated() override { return m_updated; } + void reset_updated() override { m_updated = false; } model_reconstruction_trail& model_trail() override { return m_reconstruction_trail; } std::ostream& display(std::ostream& out) const override { unsigned i = 0; @@ -62,12 +67,23 @@ class simplifier_solver : public solver { if (s.m.is_false(f)) s.set_inconsistent(); } - void replay(unsigned qhead, expr_ref_vector& assumptions) { m_reconstruction_trail.replay(qhead, assumptions, *this); } + void replay(unsigned qhead, expr_ref_vector& assumptions) { + m_reconstruction_trail.replay(qhead, assumptions, *this); + th_rewriter rw(s.m); + expr_ref tmp(s.m); + for (unsigned i = 0; i < assumptions.size(); ++i) { + tmp = assumptions.get(i); + rw(tmp); + assumptions[i] = tmp; + } + } void flatten_suffix() override { expr_mark seen; unsigned j = qhead(); + expr_ref_vector pinned(s.m); for (unsigned i = qhead(); i < qtail(); ++i) { - expr* f = s.m_fmls[i].fml(); + expr* f = s.m_fmls[i].fml(), *g = nullptr; + pinned.push_back(f); if (seen.is_marked(f)) continue; seen.mark(f, true); @@ -79,6 +95,12 @@ class simplifier_solver : public solver { add(dependent_expr(s.m, arg, nullptr, d)); continue; } + if (s.m.is_not(f, g) && s.m.is_or(g)) { + auto* d = s.m_fmls[i].dep(); + for (expr* arg : *to_app(g)) + add(dependent_expr(s.m, mk_not(s.m, arg), nullptr, d)); + continue; + } if (i != j) s.m_fmls[j] = s.m_fmls[i]; ++j; @@ -109,14 +131,16 @@ class simplifier_solver : public solver { unsigned qhead = m_preprocess_state.qhead(); expr_ref_vector orig_assumptions(assumptions); m_core_replace.reset(); - if (qhead < m_fmls.size() || !assumptions.empty()) { - TRACE("solver", tout << "qhead " << qhead << "\n"); - m_preprocess_state.replay(qhead, assumptions); - m_preprocess_state.freeze(assumptions); + if (qhead < m_fmls.size()) { m_preprocess.reduce(); if (!m.inc()) return; + TRACE("solver", tout << "qhead " << qhead << "\n"; + m_preprocess_state.display(tout)); m_preprocess_state.advance_qhead(); + } + if (!assumptions.empty()) { + m_preprocess_state.replay(m_preprocess_state.qhead(), assumptions); for (unsigned i = 0; i < assumptions.size(); ++i) m_core_replace.insert(assumptions.get(i), orig_assumptions.get(i)); } @@ -202,6 +226,7 @@ class simplifier_solver : public solver { lbool check_sat_core(unsigned num_assumptions, expr* const* assumptions) override { expr_ref_vector _assumptions(m, num_assumptions, assumptions); flush(_assumptions); + TRACE("simplifier", tout << _assumptions); return s->check_sat_core(num_assumptions, _assumptions.data()); } diff --git a/src/solver/solver_preprocess.cpp b/src/solver/solver_preprocess.cpp index 9cac4b835bd..1ac7bb8b44e 100644 --- a/src/solver/solver_preprocess.cpp +++ b/src/solver/solver_preprocess.cpp @@ -47,6 +47,17 @@ Module Name: void init_preprocess(ast_manager& m, params_ref const& p, then_simplifier& s, dependent_expr_state& st) { + auto mk_bound_simplifier = [&]() { + auto* s1 = alloc(bound_simplifier, m, p, st); + auto* s2 = alloc(then_simplifier, m, p, st); + s2->add_simplifier(alloc(rewriter_simplifier, m, p, st)); + s2->add_simplifier(alloc(propagate_values, m, p, st)); + s2->add_simplifier(alloc(euf::solve_eqs, m, st)); + auto* r = alloc(if_change_simplifier, m, p, st); + r->add_simplifier(s1); + r->add_simplifier(s2); + return r; + }; smt_params smtp(p); s.add_simplifier(alloc(rewriter_simplifier, m, p, st)); if (smtp.m_propagate_values) s.add_simplifier(alloc(propagate_values, m, p, st)); @@ -60,7 +71,7 @@ void init_preprocess(ast_manager& m, params_ref const& p, then_simplifier& s, de if (smtp.m_refine_inj_axiom) s.add_simplifier(alloc(refine_inj_axiom_simplifier, m, p, st)); if (smtp.m_bv_size_reduce) s.add_simplifier(alloc(bv::slice, m, st)); if (smtp.m_distribute_forall) s.add_simplifier(alloc(distribute_forall_simplifier, m, p, st)); - if (smtp.m_bound_simplifier) s.add_simplifier(alloc(bound_simplifier, m, p, st)); + if (smtp.m_bound_simplifier) s.add_simplifier(mk_bound_simplifier()); if (smtp.m_eliminate_bounds) s.add_simplifier(alloc(elim_bounds_simplifier, m, p, st)); if (smtp.m_simplify_bit2int) s.add_simplifier(alloc(bit2int_simplifier, m, p, st)); if (smtp.m_bb_quantifiers) s.add_simplifier(alloc(bv::elim_simplifier, m, p, st)); diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index cc3ac9336db..861a8318545 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -25,6 +25,9 @@ Module Name: #include "solver/tactic2solver.h" #include "solver/solver_na2as.h" #include "solver/mus.h" +#include "smt/params/smt_params.h" +#include "smt/params/smt_params_helper.hpp" + /** \brief Simulates the incremental solver interface using a tactic. @@ -48,6 +51,7 @@ class tactic2solver : public solver_na2as { bool m_produce_proofs; bool m_produce_unsat_cores; statistics m_stats; + bool m_minimizing = false; public: tactic2solver(ast_manager & m, tactic * t, params_ref const & p, bool produce_proofs, bool produce_models, bool produce_unsat_cores, symbol const & logic); @@ -173,6 +177,7 @@ tactic2solver::~tactic2solver() { void tactic2solver::updt_params(params_ref const & p) { solver::updt_params(p); + m_produce_unsat_cores |= p.get_bool("unsat_core", false); } void tactic2solver::collect_param_descrs(param_descrs & r) { @@ -289,9 +294,6 @@ solver* tactic2solver::translate(ast_manager& m, params_ref const& p) { tactic* t = m_tactic->translate(m); tactic2solver* r = alloc(tactic2solver, m, t, p, m_produce_proofs, m_produce_models, m_produce_unsat_cores, m_logic); r->m_result = nullptr; - if (!m_scopes.empty()) { - throw default_exception("translation of contexts is only supported at base level"); - } ast_translation tr(m_assertions.get_manager(), m, false); for (unsigned i = 0; i < get_num_assertions(); ++i) { @@ -309,6 +311,16 @@ void tactic2solver::collect_statistics(statistics & st) const { void tactic2solver::get_unsat_core(expr_ref_vector & r) { if (m_result.get()) { m_result->get_unsat_core(r); + if (!m_minimizing && smt_params_helper(get_params()).core_minimize()) { + flet minimizing(m_minimizing, true); + mus mus(*this); + mus.add_soft(r.size(), r.data()); + expr_ref_vector r2(m); + if (l_true == mus.get_mus(r2)) { + r.reset(); + r.append(r2); + } + } } } diff --git a/src/tactic/arith/purify_arith_tactic.cpp b/src/tactic/arith/purify_arith_tactic.cpp index db19863987a..40ee967ba8c 100644 --- a/src/tactic/arith/purify_arith_tactic.cpp +++ b/src/tactic/arith/purify_arith_tactic.cpp @@ -444,7 +444,7 @@ struct purify_arith_proc { expr * x = args[0]; bool is_int = u().is_int(x); - expr * k = mk_fresh_var(is_int); + expr * k = mk_fresh_var(false); result = k; mk_def_proof(k, t, result_pr); cache_result(t, result, result_pr); @@ -454,7 +454,7 @@ struct purify_arith_proc { if (y.is_zero()) { expr* p0; if (is_int) { - if (!m_ipower0) m_ipower0 = mk_fresh_var(true); + if (!m_ipower0) m_ipower0 = mk_fresh_var(false); p0 = m_ipower0; } else { diff --git a/src/tactic/core/simplify_tactic.cpp b/src/tactic/core/simplify_tactic.cpp index 8d9ff759f3d..f05b4c4fc62 100644 --- a/src/tactic/core/simplify_tactic.cpp +++ b/src/tactic/core/simplify_tactic.cpp @@ -42,6 +42,10 @@ struct simplify_tactic::imp { m_num_steps = 0; } + void collect_statistics(statistics& st) { + st.update("rewriter.steps", m_num_steps); + } + void operator()(goal & g) { tactic_report report("simplifier", g); m_num_steps = 0; @@ -108,6 +112,11 @@ void simplify_tactic::cleanup() { new (m_imp) imp(m, p); } +void simplify_tactic::collect_statistics(statistics& st) const { + if (m_imp) + m_imp->collect_statistics(st); +} + unsigned simplify_tactic::get_num_steps() const { return m_imp->get_num_steps(); } diff --git a/src/tactic/core/simplify_tactic.h b/src/tactic/core/simplify_tactic.h index 1594b3d3795..7baabb8d6e8 100644 --- a/src/tactic/core/simplify_tactic.h +++ b/src/tactic/core/simplify_tactic.h @@ -81,6 +81,8 @@ class simplify_tactic : public tactic { static void get_param_descrs(param_descrs & r); void collect_param_descrs(param_descrs & r) override { get_param_descrs(r); } + + void collect_statistics(statistics& st) const override; void operator()(goal_ref const & in, goal_ref_buffer & result) override; diff --git a/src/tactic/core/tseitin_cnf_tactic.cpp b/src/tactic/core/tseitin_cnf_tactic.cpp index 411b8aa6edb..dce1fe459fe 100644 --- a/src/tactic/core/tseitin_cnf_tactic.cpp +++ b/src/tactic/core/tseitin_cnf_tactic.cpp @@ -905,10 +905,10 @@ class tseitin_cnf_tactic : public tactic { void collect_param_descrs(param_descrs & r) override { insert_max_memory(r); - r.insert("common_patterns", CPK_BOOL, "minimize the number of auxiliary variables during CNF encoding by identifing commonly used patterns", "true"); + r.insert("common_patterns", CPK_BOOL, "minimize the number of auxiliary variables during CNF encoding by identifying commonly used patterns", "true"); r.insert("distributivity", CPK_BOOL, "minimize the number of auxiliary variables during CNF encoding by applying distributivity over unshared subformulas", "true"); r.insert("distributivity_blowup", CPK_UINT, "maximum overhead for applying distributivity during CNF encoding", "32"); - r.insert("ite_chaing", CPK_BOOL, "minimize the number of auxiliary variables during CNF encoding by identifing if-then-else chains", "true"); + r.insert("ite_chaing", CPK_BOOL, "minimize the number of auxiliary variables during CNF encoding by identifying if-then-else chains", "true"); r.insert("ite_extra", CPK_BOOL, "add redundant clauses (that improve unit propagation) when encoding if-then-else formulas", "true"); } diff --git a/src/tactic/dependent_expr_state_tactic.h b/src/tactic/dependent_expr_state_tactic.h index 347e147fb33..79c1993b211 100644 --- a/src/tactic/dependent_expr_state_tactic.h +++ b/src/tactic/dependent_expr_state_tactic.h @@ -33,6 +33,7 @@ class dependent_expr_state_tactic : public tactic, public dependent_expr_state { expr_ref_vector m_frozen; scoped_ptr m_simp; scoped_ptr m_model_trail; + bool m_updated = false; void init() { if (!m_simp) { @@ -61,9 +62,9 @@ class dependent_expr_state_tactic : public tactic, public dependent_expr_state { if (m_simp) pop(1); } - + /** - * size(), [](), update() and inconsisent() implement the abstract interface of dependent_expr_state + * size(), [](), update() and inconsistent() implement the abstract interface of dependent_expr_state */ unsigned qtail() const override { return m_goal->size(); } @@ -75,6 +76,7 @@ class dependent_expr_state_tactic : public tactic, public dependent_expr_state { void update(unsigned i, dependent_expr const& j) override { if (inconsistent()) return; + m_updated = true; auto [f, p, d] = j(); m_goal->update(i, f, p, d); } @@ -82,6 +84,7 @@ class dependent_expr_state_tactic : public tactic, public dependent_expr_state { void add(dependent_expr const& j) override { if (inconsistent()) return; + m_updated = true; auto [f, p, d] = j(); m_goal->assert_expr(f, p, d); } @@ -96,6 +99,10 @@ class dependent_expr_state_tactic : public tactic, public dependent_expr_state { char const* name() const override { return m_simp ? m_simp->name() : "null"; } + bool updated() override { return m_updated; } + + void reset_updated() override { m_updated = false; } + void updt_params(params_ref const& p) override { m_params.append(p); init(); @@ -133,6 +140,12 @@ class dependent_expr_state_tactic : public tactic, public dependent_expr_state { cleanup(); } + void collect_statistics(statistics& st) const override { + if (m_simp) + m_simp->collect_statistics(st); + st.copy(m_st); + } + void cleanup() override { if (m_simp) { m_simp->collect_statistics(m_st); @@ -144,13 +157,6 @@ class dependent_expr_state_tactic : public tactic, public dependent_expr_state { m_dep = dependent_expr(m, m.mk_true(), nullptr, nullptr); } - void collect_statistics(statistics& st) const override { - if (m_simp) - m_simp->collect_statistics(st); - else - st.copy(m_st); - } - void reset_statistics() override { if (m_simp) m_simp->reset_statistics(); diff --git a/src/tactic/fd_solver/smtfd_solver.cpp b/src/tactic/fd_solver/smtfd_solver.cpp index 01370812f75..4d0912fdca7 100644 --- a/src/tactic/fd_solver/smtfd_solver.cpp +++ b/src/tactic/fd_solver/smtfd_solver.cpp @@ -433,7 +433,7 @@ namespace smtfd { void populate_model(model_ref& mdl, expr_ref_vector const& terms); /** - * \brief check consistency properties that can only be achived using a global analysis of terms + * \brief check consistency properties that can only be achieved using a global analysis of terms */ void global_check(expr_ref_vector const& core); diff --git a/src/tactic/goal_proof_converter.h b/src/tactic/goal_proof_converter.h index a17ff0ea1ea..cfe0d970998 100644 --- a/src/tactic/goal_proof_converter.h +++ b/src/tactic/goal_proof_converter.h @@ -37,7 +37,7 @@ class subgoal_proof_converter : public proof_converter { } proof_ref operator()(ast_manager & m, unsigned num_source, proof * const * source) override { - // ignore the proofs from the arguments, instead obtain the proofs fromt he subgoals. + // ignore the proofs from the arguments, instead obtain the proofs from the subgoals. SASSERT(num_source == 0); proof_converter_ref_buffer pc_buffer; for (goal_ref g : m_goals) { diff --git a/src/tactic/sls/CMakeLists.txt b/src/tactic/sls/CMakeLists.txt index 436b1742fc0..83599b82792 100644 --- a/src/tactic/sls/CMakeLists.txt +++ b/src/tactic/sls/CMakeLists.txt @@ -1,15 +1,12 @@ z3_add_component(sls_tactic SOURCES - bvsls_opt_engine.cpp - sls_engine.cpp sls_tactic.cpp COMPONENT_DEPENDENCIES bv_tactics core_tactics normal_forms tactic - PYG_FILES - sls_params.pyg + ast_sls TACTIC_HEADERS sls_tactic.h ) diff --git a/src/tactic/sls/sls_tactic.cpp b/src/tactic/sls/sls_tactic.cpp index e631c23e986..198204d90b0 100644 --- a/src/tactic/sls/sls_tactic.cpp +++ b/src/tactic/sls/sls_tactic.cpp @@ -27,8 +27,9 @@ Module Name: #include "tactic/core/nnf_tactic.h" #include "util/stopwatch.h" #include "tactic/sls/sls_tactic.h" -#include "tactic/sls/sls_params.hpp" -#include "tactic/sls/sls_engine.h" +#include "params/sls_params.hpp" +#include "ast/sls/sls_engine.h" +#include "ast/sls/bv_sls.h" class sls_tactic : public tactic { ast_manager & m; @@ -60,6 +61,38 @@ class sls_tactic : public tactic { void collect_param_descrs(param_descrs & r) override { sls_params::collect_param_descrs(r); } + + void run(goal_ref const& g, model_converter_ref& mc) { + if (g->inconsistent()) { + mc = nullptr; + return; + } + + for (unsigned i = 0; i < g->size(); i++) + m_engine->assert_expr(g->form(i)); + + lbool res = m_engine->operator()(); + auto const& stats = m_engine->get_stats(); + if (res == l_true) { + report_tactic_progress("Number of flips:", stats.m_moves); + + for (unsigned i = 0; i < g->size(); i++) + if (!m_engine->get_mpz_manager().is_one(m_engine->get_value(g->form(i)))) { + verbose_stream() << "Terminated before all assertions were SAT!" << std::endl; + NOT_IMPLEMENTED_YET(); + } + + if (g->models_enabled()) { + model_ref mdl = m_engine->get_model(); + mc = model2model_converter(mdl.get()); + TRACE("sls_model", mc->display(tout);); + } + g->reset(); + } + else + mc = nullptr; + + } void operator()(goal_ref const & g, goal_ref_buffer & result) override { @@ -69,7 +102,7 @@ class sls_tactic : public tactic { tactic_report report("sls", *g); model_converter_ref mc; - m_engine->operator()(g, mc); + run(g, mc); g->add(mc.get()); g->inc_depth(); result.push_back(g.get()); @@ -91,11 +124,115 @@ class sls_tactic : public tactic { }; +class bv_sls_tactic : public tactic { + ast_manager& m; + params_ref m_params; + bv::sls* m_sls; + statistics m_st; + +public: + bv_sls_tactic(ast_manager& _m, params_ref const& p) : + m(_m), + m_params(p) { + m_sls = alloc(bv::sls, m); + } + + tactic* translate(ast_manager& m) override { + return alloc(bv_sls_tactic, m, m_params); + } + + ~bv_sls_tactic() override { + dealloc(m_sls); + } + + char const* name() const override { return "bv-sls"; } + + void updt_params(params_ref const& p) override { + m_params.append(p); + m_sls->updt_params(m_params); + } + + void collect_param_descrs(param_descrs& r) override { + sls_params::collect_param_descrs(r); + } + + void run(goal_ref const& g, model_converter_ref& mc) { + if (g->inconsistent()) { + mc = nullptr; + return; + } + + for (unsigned i = 0; i < g->size(); i++) + m_sls->assert_expr(g->form(i)); + + m_sls->init(); + std::function false_eval = [&](expr* e, unsigned idx) { + return false; + }; + m_sls->init_eval(false_eval); + + lbool res = m_sls->operator()(); + auto const& stats = m_sls->get_stats(); + report_tactic_progress("Number of flips:", stats.m_moves); + IF_VERBOSE(20, verbose_stream() << res << "\n"); + IF_VERBOSE(20, m_sls->display(verbose_stream())); + m_st.reset(); + m_sls->collect_statistics(m_st); + if (res == l_true) { + if (g->models_enabled()) { + model_ref mdl = m_sls->get_model(); + mc = model2model_converter(mdl.get()); + TRACE("sls_model", mc->display(tout);); + } + g->reset(); + } + else + mc = nullptr; + + } + + void operator()(goal_ref const& g, + goal_ref_buffer& result) override { + result.reset(); + + TRACE("sls", g->display(tout);); + tactic_report report("sls", *g); + + model_converter_ref mc; + run(g, mc); + g->add(mc.get()); + g->inc_depth(); + result.push_back(g.get()); + } + + void cleanup() override { + + auto* d = alloc(bv::sls, m); + std::swap(d, m_sls); + dealloc(d); + } + + void collect_statistics(statistics& st) const override { + st.copy(m_st); + } + + void reset_statistics() override { + m_sls->reset_statistics(); + m_st.reset(); + } + +}; + static tactic * mk_sls_tactic(ast_manager & m, params_ref const & p) { return and_then(fail_if_not(mk_is_qfbv_probe()), // Currently only QF_BV is supported. clean(alloc(sls_tactic, m, p))); } +tactic* mk_bv_sls_tactic(ast_manager& m, params_ref const& p) { + return and_then(fail_if_not(mk_is_qfbv_probe()), // Currently only QF_BV is supported. + clean(alloc(bv_sls_tactic, m, p))); +} + static tactic * mk_preamble(ast_manager & m, params_ref const & p) { params_ref main_p; @@ -139,3 +276,9 @@ tactic * mk_qfbv_sls_tactic(ast_manager & m, params_ref const & p) { t->updt_params(p); return t; } + +tactic* mk_qfbv_new_sls_tactic(ast_manager& m, params_ref const& p) { + tactic* t = and_then(mk_preamble(m, p), mk_bv_sls_tactic(m, p)); + t->updt_params(p); + return t; +} diff --git a/src/tactic/sls/sls_tactic.h b/src/tactic/sls/sls_tactic.h index 3c0612e6e19..d58d310e3c1 100644 --- a/src/tactic/sls/sls_tactic.h +++ b/src/tactic/sls/sls_tactic.h @@ -24,7 +24,16 @@ class tactic; tactic * mk_qfbv_sls_tactic(ast_manager & m, params_ref const & p = params_ref()); +tactic* mk_qfbv_new_sls_tactic(ast_manager& m, params_ref const& p = params_ref()); + +tactic* mk_bv_sls_tactic(ast_manager& m, params_ref const& p = params_ref()); + /* ADD_TACTIC("qfbv-sls", "(try to) solve using stochastic local search for QF_BV.", "mk_qfbv_sls_tactic(m, p)") + + ADD_TACTIC("qfbv-new-sls", "(try to) solve using stochastic local search for QF_BV.", "mk_qfbv_new_sls_tactic(m, p)") + + ADD_TACTIC("qfbv-new-sls-core", "(try to) solve using stochastic local search for QF_BV.", "mk_bv_sls_tactic(m, p)") + */ diff --git a/src/tactic/smtlogics/qfnia_tactic.cpp b/src/tactic/smtlogics/qfnia_tactic.cpp index 3dd66606df6..0dad56a196e 100644 --- a/src/tactic/smtlogics/qfnia_tactic.cpp +++ b/src/tactic/smtlogics/qfnia_tactic.cpp @@ -117,10 +117,8 @@ tactic * mk_qfnia_tactic(ast_manager & m, params_ref const & p) { mk_report_verbose_tactic("(qfnia-tactic)", 10), mk_qfnia_preamble(m, p), or_else(mk_qfnia_sat_solver(m, p), - try_for(mk_qfnia_smt_solver(m, p), 2000), - mk_qfnia_nlsat_solver(m, p), - mk_qfnia_smt_solver(m, p)) - ) - ; + try_for(mk_qfnia_smt_solver(m, p), 2000), + mk_qfnia_nlsat_solver(m, p), + mk_qfnia_smt_solver(m, p))); } diff --git a/src/tactic/tactic.h b/src/tactic/tactic.h index ddd18733716..652bf8130bd 100644 --- a/src/tactic/tactic.h +++ b/src/tactic/tactic.h @@ -62,7 +62,7 @@ class tactic : public user_propagator::core { */ virtual void operator()(goal_ref const & in, goal_ref_buffer& result) = 0; - virtual void collect_statistics(statistics & st) const { } + virtual void collect_statistics(statistics& st) const { } virtual void reset_statistics() {} virtual void cleanup() = 0; virtual void reset() { cleanup(); } @@ -130,6 +130,7 @@ class skip_tactic : public tactic { void cleanup() override {} tactic * translate(ast_manager & m) override { return this; } char const* name() const override { return "skip"; } + void collect_statistics(statistics& st) const override {} }; tactic * mk_skip_tactic(); diff --git a/src/tactic/tactical.cpp b/src/tactic/tactical.cpp index 5b1ea95875f..0b8189e8df4 100644 --- a/src/tactic/tactical.cpp +++ b/src/tactic/tactical.cpp @@ -1190,6 +1190,9 @@ class fail_if_tactic : public tactic { tactic * translate(ast_manager & m) override { return this; } + + void collect_statistics(statistics& st) const override { + } }; tactic * fail_if(probe * p) { @@ -1216,6 +1219,7 @@ class if_no_proofs_tactical : public unary_tactical { } tactic * translate(ast_manager & m) override { return translate_core(m); } + }; class if_no_unsat_cores_tactical : public unary_tactical { diff --git a/src/tactic/ufbv/quasi_macros_tactic.h b/src/tactic/ufbv/quasi_macros_tactic.h index faa939954fb..ec2a455276e 100644 --- a/src/tactic/ufbv/quasi_macros_tactic.h +++ b/src/tactic/ufbv/quasi_macros_tactic.h @@ -18,7 +18,7 @@ Tactic Documentation ## Tactic quasi-macro-finder ### Short Description -dentifies and applies quasi-macros. +Identifies and applies quasi-macros. ### Long Description diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index 68e55be7522..58904a12d16 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -9,9 +9,9 @@ namespace user_propagator { class callback { public: virtual ~callback() = default; - virtual void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0; + virtual bool propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0; virtual void register_cb(expr* e) = 0; - virtual void next_split_cb(expr* e, unsigned idx, lbool phase) = 0; + virtual bool next_split_cb(expr* e, unsigned idx, lbool phase) = 0; }; class context_obj { @@ -26,8 +26,8 @@ namespace user_propagator { typedef std::function push_eh_t; typedef std::function pop_eh_t; typedef std::function created_eh_t; - typedef std::function decide_eh_t; - typedef std::function on_clause_eh_t; + typedef std::function decide_eh_t; + typedef std::function on_clause_eh_t; class plugin : public decl_plugin { public: diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 2a2977681c8..d2dcb92ff80 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -41,6 +41,8 @@ add_executable(test-z3 doc.cpp egraph.cpp escaped.cpp + euf_bv_plugin.cpp + euf_arith_plugin.cpp ex.cpp expr_rand.cpp expr_substitution.cpp @@ -108,6 +110,7 @@ add_executable(test-z3 simple_parser.cpp simplex.cpp simplifier.cpp + sls_test.cpp small_object_allocator.cpp smt2print_parse.cpp smt_context.cpp diff --git a/src/test/euf_arith_plugin.cpp b/src/test/euf_arith_plugin.cpp new file mode 100644 index 00000000000..41d629ad573 --- /dev/null +++ b/src/test/euf_arith_plugin.cpp @@ -0,0 +1,106 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +--*/ + +#include "util/util.h" +#include "util/timer.h" +#include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_arith_plugin.h" +#include "ast/reg_decl_plugins.h" +#include "ast/ast_pp.h" +#include + +unsigned s_var = 0; + +static euf::enode* get_node(euf::egraph& g, arith_util& a, expr* e) { + auto* n = g.find(e); + if (n) + return n; + euf::enode_vector args; + for (expr* arg : *to_app(e)) + args.push_back(get_node(g, a, arg)); + n = g.mk(e, 0, args.size(), args.data()); + g.add_th_var(n, s_var++, a.get_family_id()); + return n; +} + +// +static void test1() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::arith_plugin, g)); + arith_util a(m); + sort_ref I(a.mk_int(), m); + + expr_ref x(m.mk_const("x", I), m); + expr_ref y(m.mk_const("y", I), m); + auto* nx = get_node(g, a, a.mk_add(a.mk_add(y, y), a.mk_add(x, x))); + auto* ny = get_node(g, a, a.mk_add(a.mk_add(y, x), x)); + TRACE("plugin", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + + TRACE("plugin", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("plugin", tout << "after propagate\n" << g << "\n"); + g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr); + g.propagate(); + std::cout << g << "\n"; +} + +static void test2() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::arith_plugin, g)); + arith_util a(m); + sort_ref I(a.mk_int(), m); + + expr_ref x(m.mk_const("x", I), m); + expr_ref y(m.mk_const("y", I), m); + auto* nxy = get_node(g, a, a.mk_add(x, y)); + auto* nyx = get_node(g, a, a.mk_add(y, x)); + auto* nx = get_node(g, a, x); + auto* ny = get_node(g, a, y); + + TRACE("plugin", tout << "before merge\n" << g << "\n"); + g.merge(nxy, nx, nullptr); + g.merge(nyx, ny, nullptr); + TRACE("plugin", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("plugin", tout << "after propagate\n" << g << "\n"); + SASSERT(nx->get_root() == ny->get_root()); + g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr); + g.propagate(); + std::cout << g << "\n"; +} + +static void test3() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::arith_plugin, g)); + arith_util a(m); + sort_ref I(a.mk_int(), m); + + expr_ref x(m.mk_const("x", I), m); + expr_ref y(m.mk_const("y", I), m); + auto* nxyy = get_node(g, a, a.mk_add(a.mk_add(x, y), y)); + auto* nyxx = get_node(g, a, a.mk_add(a.mk_add(y, x), x)); + auto* nx = get_node(g, a, x); + auto* ny = get_node(g, a, y); + g.merge(nxyy, nx, nullptr); + g.merge(nyxx, ny, nullptr); + TRACE("plugin", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("plugin", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; +} + +void tst_euf_arith_plugin() { + enable_trace("plugin"); + test1(); + test2(); + test3(); +} diff --git a/src/test/euf_bv_plugin.cpp b/src/test/euf_bv_plugin.cpp new file mode 100644 index 00000000000..501bd7b14eb --- /dev/null +++ b/src/test/euf_bv_plugin.cpp @@ -0,0 +1,183 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +--*/ + +#include "util/util.h" +#include "util/timer.h" +#include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_bv_plugin.h" +#include "ast/reg_decl_plugins.h" +#include "ast/ast_pp.h" +#include + +static unsigned s_var = 0; +static euf::enode* get_node(euf::egraph& g, bv_util& b, expr* e) { + auto* n = g.find(e); + if (n) + return n; + euf::enode_vector args; + for (expr* arg : *to_app(e)) + args.push_back(get_node(g, b, arg)); + n = g.mk(e, 0, args.size(), args.data()); + g.add_th_var(n, s_var++, b.get_family_id()); + return n; +} + +// align slices, and propagate extensionality +static void test1() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref y(m.mk_const("y", u32), m); + expr_ref x3(bv.mk_extract(31, 16, x), m); + expr_ref x2(bv.mk_extract(15, 8, x), m); + expr_ref x1(bv.mk_extract(7, 0, x), m); + expr_ref y3(bv.mk_extract(31, 24, y), m); + expr_ref y2(bv.mk_extract(23, 8, y), m); + expr_ref y1(bv.mk_extract(7, 0, y), m); + expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m); + expr_ref yy(bv.mk_concat(y1, bv.mk_concat(y2, y3)), m); + auto* nx = get_node(g, bv, xx); + auto* ny = get_node(g, bv, yy); + TRACE("bv", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + TRACE("bv", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("bv", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; + SASSERT(nx->get_root() == ny->get_root()); +} + +// propagate values down +static void test2() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x3(bv.mk_extract(31, 16, x), m); + expr_ref x2(bv.mk_extract(15, 8, x), m); + expr_ref x1(bv.mk_extract(7, 0, x), m); + expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m); + g.merge(get_node(g, bv, xx), get_node(g, bv, bv.mk_numeral((1 << 27) + (1 << 17) + (1 << 3), 32)), nullptr); + g.propagate(); + SASSERT(get_node(g, bv, x1)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x2)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x3)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x)->get_root()->interpreted()); +} + + +// propagate values up +static void test3() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x3(bv.mk_extract(31, 16, x), m); + expr_ref x2(bv.mk_extract(15, 8, x), m); + expr_ref x1(bv.mk_extract(7, 0, x), m); + expr_ref xx(bv.mk_concat(bv.mk_concat(x1, x2), x3), m); + expr_ref y(m.mk_const("y", u32), m); + g.merge(get_node(g, bv, xx), get_node(g, bv, y), nullptr); + g.merge(get_node(g, bv, x1), get_node(g, bv, bv.mk_numeral(2, 8)), nullptr); + g.merge(get_node(g, bv, x2), get_node(g, bv, bv.mk_numeral(8, 8)), nullptr); + g.propagate(); + SASSERT(get_node(g, bv, bv.mk_concat(x1, x2))->get_root()->interpreted()); + SASSERT(get_node(g, bv, x1)->get_root()->interpreted()); + SASSERT(get_node(g, bv, x2)->get_root()->interpreted()); +} + +// propagate extract up +static void test4() { + // concat(a, x[J]), a = x[I] => x[IJ] = concat(x[I],x[J]) + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + sort_ref u8(bv.mk_sort(8), m); + sort_ref u16(bv.mk_sort(16), m); + expr_ref a(m.mk_const("a", u8), m); + expr_ref x(m.mk_const("x", u32), m); + expr_ref y(m.mk_const("y", u16), m); + expr_ref x1(bv.mk_extract(15, 8, x), m); + expr_ref x2(bv.mk_extract(23, 16, x), m); + g.merge(get_node(g, bv, bv.mk_concat(a, x2)), get_node(g, bv, y), nullptr); + g.merge(get_node(g, bv, x1), get_node(g, bv, a), nullptr); + g.propagate(); + TRACE("bv", tout << g << "\n"); + SASSERT(get_node(g, bv, bv.mk_extract(23, 8, x))->get_root() == get_node(g, bv, y)->get_root()); +} + +// iterative slicing +static void test5() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x1(bv.mk_extract(31, 4, x), m); + expr_ref x2(bv.mk_extract(27, 0, x), m); + auto* nx = get_node(g, bv, x1); + auto* ny = get_node(g, bv, x2); + TRACE("bv", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + TRACE("bv", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("bv", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; +} + +// iterative slicing +static void test6() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugin(alloc(euf::bv_plugin, g)); + bv_util bv(m); + sort_ref u32(bv.mk_sort(32), m); + + expr_ref x(m.mk_const("x", u32), m); + expr_ref x1(bv.mk_extract(31, 3, x), m); + expr_ref x2(bv.mk_extract(28, 0, x), m); + auto* nx = get_node(g, bv, x1); + auto* ny = get_node(g, bv, x2); + TRACE("bv", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + TRACE("bv", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("bv", tout << "after propagate\n" << g << "\n"); + std::cout << g << "\n"; +} + + +void tst_euf_bv_plugin() { + enable_trace("bv"); + enable_trace("plugin"); + test6(); + return; + test1(); + test2(); + test3(); + test4(); + test5(); + test6(); +} diff --git a/src/test/fuzzing/expr_rand.cpp b/src/test/fuzzing/expr_rand.cpp index 65d5f70c9d2..c0ffcc767aa 100644 --- a/src/test/fuzzing/expr_rand.cpp +++ b/src/test/fuzzing/expr_rand.cpp @@ -241,8 +241,7 @@ void expr_rand::initialize_bv(unsigned num_vars) { void expr_rand::initialize_array(unsigned num_vars, sort* dom, sort* rng) { family_id afid = m_manager.mk_family_id("array"); - parameter p1(dom), p2(rng); - parameter ps[2] = { p1, p2 }; + parameter ps[2] = { parameter(dom), parameter(rng) }; sort* a = m_manager.mk_sort(afid, ARRAY_SORT, 2, ps); sort* ss[3] = { a, dom, rng }; diff --git a/src/test/hwf.cpp b/src/test/hwf.cpp index 8a019ec02b6..b81a9cef308 100644 --- a/src/test/hwf.cpp +++ b/src/test/hwf.cpp @@ -103,7 +103,9 @@ static void bug_to_rational() { static void bug_is_int() { unsigned raw_val[2] = { 2147483648u, 1077720461u }; - double val = *(double*)(raw_val); + double val; + static_assert(sizeof(raw_val) == sizeof(val)); + memcpy(&val, raw_val, sizeof(val)); std::cout << val << "\n"; hwf_manager m; hwf a; diff --git a/src/test/lp/.clang-format b/src/test/lp/.clang-format new file mode 100644 index 00000000000..d70ca6b63d8 --- /dev/null +++ b/src/test/lp/.clang-format @@ -0,0 +1,3 @@ +BasedOnStyle: Google +IndentWidth: 4 +ColumnLimit: 0 diff --git a/src/test/lp/gomory_test.h b/src/test/lp/gomory_test.h index c64c0103653..9ac675d1a35 100644 --- a/src/test/lp/gomory_test.h +++ b/src/test/lp/gomory_test.h @@ -189,7 +189,7 @@ struct gomory_test { void print_term(lar_term & t, std::ostream & out) { vector> row; for (auto p : t) - row.push_back(std::make_pair(p.coeff(), p.column().index())); + row.push_back(std::make_pair(p.coeff(), p.j())); print_row(out, row); } diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 9120d64cfdf..84077b98e20 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -19,81 +19,81 @@ --*/ #include + +#include "util/rational.h" #ifndef _WINDOWS #include #endif -#include -#include -#include -#include -#include +#include #include +#include + +#include #include #include -#include +#include +#include +#include #include -#include "math/lp/lp_utils.h" -#include "test/lp/smt_reader.h" -#include "test/lp/argument_parser.h" -#include "test/lp/test_file_reader.h" + +#include "math/lp/cross_nested.h" +#include "math/lp/emonics.h" +#include "math/lp/general_matrix.h" +#include "math/lp/hnf.h" +#include "math/lp/horner.h" #include "math/lp/indexed_value.h" +#include "math/lp/int_cube.h" #include "math/lp/lar_solver.h" -#include "math/lp/numeric_pair.h" -#include "util/stacked_value.h" -#include "math/lp/u_set.h" -#include "util/stopwatch.h" -#include -#include "test/lp/gomory_test.h" -#include "math/lp/matrix.h" -#include "math/lp/hnf.h" -#include "math/lp/general_matrix.h" #include "math/lp/lp_bound_propagator.h" +#include "math/lp/lp_utils.h" +#include "math/lp/matrix.h" #include "math/lp/nla_solver.h" -#include "math/lp/horner.h" -#include "math/lp/cross_nested.h" -#include "math/lp/int_cube.h" -#include "math/lp/emonics.h" +#include "math/lp/numeric_pair.h" #include "math/lp/static_matrix.h" - -bool my_white_space(const char & a) { - return a == ' ' || a == '\t'; -} -size_t number_of_whites(const std::string & s) { +#include "util/uint_set.h" +#include "test/lp/argument_parser.h" +#include "test/lp/gomory_test.h" +#include "test/lp/smt_reader.h" +#include "test/lp/test_file_reader.h" +#include "util/stacked_value.h" +#include "util/stopwatch.h" +void test_patching(); +bool my_white_space(const char &a) { return a == ' ' || a == '\t'; } +size_t number_of_whites(const std::string &s) { size_t i = 0; - for(;i < s.size(); i++) - if (!my_white_space(s[i])) return i; + for (; i < s.size(); i++) + if (!my_white_space(s[i])) + return i; return i; } -size_t number_of_whites_from_end(const std::string & s) { +size_t number_of_whites_from_end(const std::string &s) { size_t ret = 0; - for(int i = static_cast(s.size()) - 1;i >= 0; i--) - if (my_white_space(s[i])) ret++;else break; - + for (int i = static_cast(s.size()) - 1; i >= 0; i--) + if (my_white_space(s[i])) + ret++; + else + break; + return ret; } - std::string <rim(std::string &s) { s.erase(0, number_of_whites(s)); return s; } - - - - // trim from end +// trim from end inline std::string &rtrim(std::string &s) { - // s.erase(std::find_if(s.rbegin(), s.rend(), std::not1(std::ptr_fun(std::isspace))).base(), s.end()); + // s.erase(std::find_if(s.rbegin(), s.rend(), + // std::not1(std::ptr_fun(std::isspace))).base(), s.end()); s.erase(s.end() - number_of_whites_from_end(s), s.end()); return s; } - // trim from both ends -inline std::string &trim(std::string &s) { - return ltrim(rtrim(s)); -} - +// trim from both ends +inline std::string &trim(std::string &s) { return ltrim(rtrim(s)); } -vector string_split(const std::string &source, const char *delimiter, bool keep_empty) { +vector string_split(const std::string &source, + const char *delimiter, bool keep_empty) { vector results; size_t prev = 0; size_t next = 0; @@ -118,7 +118,6 @@ vector split_and_trim(const std::string &line) { return ret; } - namespace nla { void test_horner(); void test_monics(); @@ -131,7 +130,7 @@ void test_basic_lemma_for_mon_zero_from_factors_to_monomial(); void test_basic_lemma_for_mon_neutral_from_monomial_to_factors(); void test_basic_lemma_for_mon_neutral_from_factors_to_monomial(); -void test_cn_on_expr(nex_sum *t, cross_nested& cn) { +void test_cn_on_expr(nex_sum *t, cross_nested &cn) { t = to_sum(cn.get_nex_creator().simplify(t)); TRACE("nla_test", tout << "t=" << *t << '\n';); cn.run(t); @@ -147,35 +146,34 @@ void test_nex_order() { r.set_number_of_vars(3); for (unsigned j = 0; j < r.get_number_of_vars(); j++) r.set_var_weight(j, 10 - j); - nex_var* a = r.mk_var(0); - nex_var* b = r.mk_var(1); - nex_var* c = r.mk_var(2); + nex_var *a = r.mk_var(0); + nex_var *b = r.mk_var(1); + nex_var *c = r.mk_var(2); ENSURE(r.gt(a, b)); ENSURE(r.gt(b, c)); ENSURE(r.gt(a, c)); - - - nex* ab = r.mk_mul(a, b); - nex* ba = r.mk_mul(b, a); - nex* ac = r.mk_mul(a, c); + nex *ab = r.mk_mul(a, b); + nex *ba = r.mk_mul(b, a); + nex *ac = r.mk_mul(a, c); ENSURE(r.gt(ab, ac)); ENSURE(!r.gt(ac, ab)); - nex* _3ac = r.mk_mul(rational(3), a, c); - nex* _2ab = r.mk_mul(rational(2), a, b); + nex *_3ac = r.mk_mul(rational(3), a, c); + nex *_2ab = r.mk_mul(rational(2), a, b); ENSURE(r.gt(ab, _3ac)); ENSURE(!r.gt(_3ac, ab)); ENSURE(!r.gt(a, ab)); ENSURE(r.gt(ab, a)); ENSURE(r.gt(_2ab, _3ac)); ENSURE(!r.gt(_3ac, _2ab)); - nex* _2a = r.mk_mul(rational(2), a); + nex *_2a = r.mk_mul(rational(2), a); ENSURE(!r.gt(_2a, _2ab)); ENSURE(r.gt(_2ab, _2a)); ENSURE(nex_creator::equal(ab, ba)); - nex_sum * five_a_pl_one = r.mk_sum(r.mk_mul(rational(5), a), r.mk_scalar(rational(1))); - nex_mul * poly = r.mk_mul(five_a_pl_one, b); - nex * p = r.simplify(poly); + nex_sum *five_a_pl_one = + r.mk_sum(r.mk_mul(rational(5), a), r.mk_scalar(rational(1))); + nex_mul *poly = r.mk_mul(five_a_pl_one, b); + nex *p = r.simplify(poly); std::cout << "poly = " << *poly << " , p = " << *p << "\n"; #endif } @@ -184,216 +182,208 @@ void test_simplify() { #ifdef Z3DEBUG nex_creator r; cross_nested cn( - [](const nex* n) { - TRACE("nla_cn_test", tout << *n << "\n";); - return false; - } , - [](unsigned) { return false; }, - []() { return 1; }, // for random - r); + [](const nex *n) { + TRACE("nla_cn_test", tout << *n << "\n";); + return false; + }, + [](unsigned) { return false; }, []() { return 1; }, // for random + r); enable_trace("nla_cn"); enable_trace("nla_cn_details"); // enable_trace("nla_cn_details_"); enable_trace("nla_test"); - + r.set_number_of_vars(3); for (unsigned j = 0; j < r.get_number_of_vars(); j++) r.set_var_weight(j, j); - nex_var* a = r.mk_var(0); - nex_var* b = r.mk_var(1); - nex_var* c = r.mk_var(2); + nex_var *a = r.mk_var(0); + nex_var *b = r.mk_var(1); + nex_var *c = r.mk_var(2); auto bc = r.mk_mul(b, c); auto a_plus_bc = r.mk_sum(a, bc); auto two_a_plus_bc = r.mk_mul(r.mk_scalar(rational(2)), a_plus_bc); auto simp_two_a_plus_bc = r.simplify(two_a_plus_bc); - TRACE("nla_test", tout << * simp_two_a_plus_bc << "\n";); + TRACE("nla_test", tout << *simp_two_a_plus_bc << "\n";); ENSURE(nex_creator::equal(simp_two_a_plus_bc, two_a_plus_bc)); auto simp_a_plus_bc = r.simplify(a_plus_bc); ENSURE(to_sum(simp_a_plus_bc)->size() > 1); auto three_ab = r.mk_mul(r.mk_scalar(rational(3)), a, b); auto three_ab_square = r.mk_mul(three_ab, three_ab, three_ab); - + TRACE("nla_test", tout << "before simplify " << *three_ab_square << "\n";); three_ab_square = to_mul(r.simplify(three_ab_square)); TRACE("nla_test", tout << *three_ab_square << "\n";); - const rational& s = three_ab_square->coeff(); + const rational &s = three_ab_square->coeff(); ENSURE(s == rational(27)); auto m = r.mk_mul(a, a); TRACE("nla_test_", tout << "m = " << *m << "\n";); /* - auto n = r.mk_mul(b, b, b, b, b, b, b); - n->add_child_in_power(b, 7); - n->add_child(r.mk_scalar(rational(3))); - n->add_child_in_power(r.mk_scalar(rational(2)), 2); - n->add_child(r.mk_scalar(rational(1))); - TRACE("nla_test_", tout << "n = " << *n << "\n";); - m->add_child_in_power(n, 3); - n->add_child_in_power(r.mk_scalar(rational(1, 3)), 2); - TRACE("nla_test_", tout << "m = " << *m << "\n";); - - nex_sum * e = r.mk_sum(a, r.mk_sum(b, m)); - TRACE("nla_test", tout << "before simplify e = " << *e << "\n";); - e = to_sum(r.simplify(e)); - TRACE("nla_test", tout << "simplified e = " << *e << "\n";); - ENSURE(e->children().size() > 2); - nex_sum * e_m = r.mk_sum(); - for (const nex* ex: to_sum(e)->children()) { - nex* ce = r.mk_mul(r.clone(ex), r.mk_scalar(rational(3))); - TRACE("nla_test", tout << "before simpl ce = " << *ce << "\n";); - ce = r.simplify(ce); - TRACE("nla_test", tout << "simplified ce = " << *ce << "\n";); - e_m->add_child(ce); - } - e->add_child(e_m); - TRACE("nla_test", tout << "before simplify sum e = " << *e << "\n";); - e = to_sum(r.simplify(e)); - TRACE("nla_test", tout << "simplified sum e = " << *e << "\n";); - - nex * pr = r.mk_mul(a, b, b); - TRACE("nla_test", tout << "before simplify pr = " << *pr << "\n";); - r.simplify(pr); - TRACE("nla_test", tout << "simplified sum e = " << *pr << "\n";); - */ + auto n = r.mk_mul(b, b, b, b, b, b, b); + n->add_child_in_power(b, 7); + n->add_child(r.mk_scalar(rational(3))); + n->add_child_in_power(r.mk_scalar(rational(2)), 2); + n->add_child(r.mk_scalar(rational(1))); + TRACE("nla_test_", tout << "n = " << *n << "\n";); + m->add_child_in_power(n, 3); + n->add_child_in_power(r.mk_scalar(rational(1, 3)), 2); + TRACE("nla_test_", tout << "m = " << *m << "\n";); + + nex_sum * e = r.mk_sum(a, r.mk_sum(b, m)); + TRACE("nla_test", tout << "before simplify e = " << *e << "\n";); + e = to_sum(r.simplify(e)); + TRACE("nla_test", tout << "simplified e = " << *e << "\n";); + ENSURE(e->children().size() > 2); + nex_sum * e_m = r.mk_sum(); + for (const nex* ex: to_sum(e)->children()) { + nex* ce = r.mk_mul(r.clone(ex), r.mk_scalar(rational(3))); + TRACE("nla_test", tout << "before simpl ce = " << *ce << "\n";); + ce = r.simplify(ce); + TRACE("nla_test", tout << "simplified ce = " << *ce << "\n";); + e_m->add_child(ce); + } + e->add_child(e_m); + TRACE("nla_test", tout << "before simplify sum e = " << *e << "\n";); + e = to_sum(r.simplify(e)); + TRACE("nla_test", tout << "simplified sum e = " << *e << "\n";); + + nex * pr = r.mk_mul(a, b, b); + TRACE("nla_test", tout << "before simplify pr = " << *pr << "\n";); + r.simplify(pr); + TRACE("nla_test", tout << "simplified sum e = " << *pr << "\n";); + */ #endif } void test_cn_shorter() { -// nex_sum *clone; -// nex_creator cr; -// cross_nested cn( -// [](const nex* n) { -// TRACE("nla_test", tout <<"cn form = " << *n << "\n"; - -// ); -// return false; -// } , -// [](unsigned) { return false; }, -// []{ return 1; }, cr); -// enable_trace("nla_test"); -// enable_trace("nla_cn"); -// enable_trace("nla_cn_test"); -// enable_trace("nla_cn_details"); -// // enable_trace("nla_cn_details_"); -// enable_trace("nla_test_details"); -// cr.set_number_of_vars(20); -// for (unsigned j = 0; j < cr.get_number_of_vars(); j++) -// cr.set_var_weight(j,j); - -// nex_var* a = cr.mk_var(0); -// nex_var* b = cr.mk_var(1); -// nex_var* c = cr.mk_var(2); -// nex_var* d = cr.mk_var(3); -// nex_var* e = cr.mk_var(4); -// nex_var* g = cr.mk_var(6); - -// nex* min_1 = cr.mk_scalar(rational(-1)); -// // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c); -// nex_mul* bcg = cr.mk_mul(b, c, g); -// /* -// bcg->add_child(min_1); -// nex* abcd = cr.mk_mul(a, b, c, d); -// nex* eae = cr.mk_mul(e, a, e); -// nex* three_eac = cr.mk_mul(e, a, c); to_mul(three_eac)->coeff() = rational(3); -// nex* _6aad = cr.mk_mul(cr.mk_scalar(rational(6)), a, a, d); -// clone = to_sum(cr.clone(cr.mk_sum(_6aad, abcd, eae, three_eac))); -// clone = to_sum(cr.simplify(clone)); -// TRACE("nla_test", tout << "clone = " << *clone << "\n";); -// // test_cn_on_expr(cr.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn); -// test_cn_on_expr(clone, cn); -// */ + // nex_sum *clone; + // nex_creator cr; + // cross_nested cn( + // [](const nex* n) { + // TRACE("nla_test", tout <<"cn form = " << *n << "\n"; + + // ); + // return false; + // } , + // [](unsigned) { return false; }, + // []{ return 1; }, cr); + // enable_trace("nla_test"); + // enable_trace("nla_cn"); + // enable_trace("nla_cn_test"); + // enable_trace("nla_cn_details"); + // // enable_trace("nla_cn_details_"); + // enable_trace("nla_test_details"); + // cr.set_number_of_vars(20); + // for (unsigned j = 0; j < cr.get_number_of_vars(); j++) + // cr.set_var_weight(j,j); + + // nex_var* a = cr.mk_var(0); + // nex_var* b = cr.mk_var(1); + // nex_var* c = cr.mk_var(2); + // nex_var* d = cr.mk_var(3); + // nex_var* e = cr.mk_var(4); + // nex_var* g = cr.mk_var(6); + + // nex* min_1 = cr.mk_scalar(rational(-1)); + // // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c); + // nex_mul* bcg = cr.mk_mul(b, c, g); + // /* + // bcg->add_child(min_1); + // nex* abcd = cr.mk_mul(a, b, c, d); + // nex* eae = cr.mk_mul(e, a, e); + // nex* three_eac = cr.mk_mul(e, a, c); to_mul(three_eac)->coeff() = + // rational(3); nex* _6aad = cr.mk_mul(cr.mk_scalar(rational(6)), a, a, + // d); clone = to_sum(cr.clone(cr.mk_sum(_6aad, abcd, eae, three_eac))); + // clone = to_sum(cr.simplify(clone)); + // TRACE("nla_test", tout << "clone = " << *clone << "\n";); + // // test_cn_on_expr(cr.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), + // cn); test_cn_on_expr(clone, cn); + // */ } void test_cn() { -// #ifdef Z3DEBUG -// test_cn_shorter(); -// nex_creator cr; -// cross_nested cn( -// [](const nex* n) { -// TRACE("nla_test", tout <<"cn form = " << *n << "\n";); -// return false; -// } , -// [](unsigned) { return false; }, -// []{ return 1; }, cr); -// enable_trace("nla_test"); -// enable_trace("nla_cn_test"); -// // enable_trace("nla_cn"); -// // enable_trace("nla_test_details"); -// cr.set_number_of_vars(20); -// for (unsigned j = 0; j < cr.get_number_of_vars(); j++) -// cr.set_var_weight(j, j); - -// nex_var* a = cr.mk_var(0); -// nex_var* b = cr.mk_var(1); -// nex_var* c = cr.mk_var(2); -// nex_var* d = cr.mk_var(3); -// nex_var* e = cr.mk_var(4); -// nex_var* g = cr.mk_var(6); -// nex_sum * a_p_ae_sq = cr.mk_sum(a, cr.mk_mul(a, e, e)); -// a_p_ae_sq = to_sum(cr.simplify(a_p_ae_sq)); -// test_cn_on_expr(a_p_ae_sq, cn); - -// nex* min_1 = cr.mk_scalar(rational(-1)); -// // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c); -// nex* bcd = cr.mk_mul(b, c, d); -// nex_mul* bcg = cr.mk_mul(b, c, g); -// /* -// bcg->add_child(min_1); -// nex_sum* t = cr.mk_sum(bcd, bcg); -// test_cn_on_expr(t, cn); -// nex* abd = cr.mk_mul(a, b, d); -// nex* abc = cr.mk_mul(a, b, c); -// nex* abcd = cr.mk_mul(a, b, c, d); -// nex* aaccd = cr.mk_mul(a, a, c, c, d); -// nex* add = cr.mk_mul(a, d, d); -// nex* eae = cr.mk_mul(e, a, e); -// nex* eac = cr.mk_mul(e, a, c); -// nex* ed = cr.mk_mul(e, d); -// nex* cbd = cr.mk_mul(c, b, d); -// nex* acd = cr.mk_mul(a, c, d); - -// nex* _6aad = cr.mk_mul(cr.mk_scalar(rational(6)), a, a, d); -// nex * clone = cr.clone(cr.mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed)); -// clone = cr.simplify(clone); -// ENSURE(cr.is_simplified(clone)); -// TRACE("nla_test", tout << "clone = " << *clone << "\n";); -// // test_cn_on_expr(cr.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn); -// test_cn_on_expr(to_sum(clone), cn); -// TRACE("nla_test", tout << "done\n";); -// test_cn_on_expr(cr.mk_sum(abd, abc, cbd, acd), cn); -// TRACE("nla_test", tout << "done\n";);*/ -// #endif -// // test_cn_on_expr(a*b*b*d*d + a*b*b*c*d + c*b*b*d); -// // TRACE("nla_test", tout << "done\n";); -// // test_cn_on_expr(a*b*d + a*b*c + c*b*d); -} - -} // end of namespace nla + // #ifdef Z3DEBUG + // test_cn_shorter(); + // nex_creator cr; + // cross_nested cn( + // [](const nex* n) { + // TRACE("nla_test", tout <<"cn form = " << *n << "\n";); + // return false; + // } , + // [](unsigned) { return false; }, + // []{ return 1; }, cr); + // enable_trace("nla_test"); + // enable_trace("nla_cn_test"); + // // enable_trace("nla_cn"); + // // enable_trace("nla_test_details"); + // cr.set_number_of_vars(20); + // for (unsigned j = 0; j < cr.get_number_of_vars(); j++) + // cr.set_var_weight(j, j); + + // nex_var* a = cr.mk_var(0); + // nex_var* b = cr.mk_var(1); + // nex_var* c = cr.mk_var(2); + // nex_var* d = cr.mk_var(3); + // nex_var* e = cr.mk_var(4); + // nex_var* g = cr.mk_var(6); + // nex_sum * a_p_ae_sq = cr.mk_sum(a, cr.mk_mul(a, e, e)); + // a_p_ae_sq = to_sum(cr.simplify(a_p_ae_sq)); + // test_cn_on_expr(a_p_ae_sq, cn); + + // nex* min_1 = cr.mk_scalar(rational(-1)); + // // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c); + // nex* bcd = cr.mk_mul(b, c, d); + // nex_mul* bcg = cr.mk_mul(b, c, g); + // /* + // bcg->add_child(min_1); + // nex_sum* t = cr.mk_sum(bcd, bcg); + // test_cn_on_expr(t, cn); + // nex* abd = cr.mk_mul(a, b, d); + // nex* abc = cr.mk_mul(a, b, c); + // nex* abcd = cr.mk_mul(a, b, c, d); + // nex* aaccd = cr.mk_mul(a, a, c, c, d); + // nex* add = cr.mk_mul(a, d, d); + // nex* eae = cr.mk_mul(e, a, e); + // nex* eac = cr.mk_mul(e, a, c); + // nex* ed = cr.mk_mul(e, d); + // nex* cbd = cr.mk_mul(c, b, d); + // nex* acd = cr.mk_mul(a, c, d); + + // nex* _6aad = cr.mk_mul(cr.mk_scalar(rational(6)), a, a, d); + // nex * clone = cr.clone(cr.mk_sum(_6aad, abcd, aaccd, add, eae, eac, + // ed)); clone = cr.simplify(clone); ENSURE(cr.is_simplified(clone)); + // TRACE("nla_test", tout << "clone = " << *clone << "\n";); + // // test_cn_on_expr(cr.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), + // cn); test_cn_on_expr(to_sum(clone), cn); TRACE("nla_test", tout << + // "done\n";); test_cn_on_expr(cr.mk_sum(abd, abc, cbd, acd), cn); + // TRACE("nla_test", tout << "done\n";);*/ + // #endif + // // test_cn_on_expr(a*b*b*d*d + a*b*b*c*d + c*b*b*d); + // // TRACE("nla_test", tout << "done\n";); + // // test_cn_on_expr(a*b*d + a*b*c + c*b*d); +} + +} // end of namespace nla namespace lp { unsigned seed = 1; - random_gen g_rand; -static unsigned my_random() { - return g_rand(); -} -struct simple_column_namer:public column_namer -{ +static unsigned my_random() { return g_rand(); } +struct simple_column_namer : public column_namer { std::string get_variable_name(unsigned j) const override { - return std::string("x") + T_to_string(j); + return std::string("x") + T_to_string(j); } }; - - -vector allocate_basis_heading(unsigned count) { // the rest of initialization will be handled by lu_QR +vector allocate_basis_heading( + unsigned count) { // the rest of initialization will be handled by lu_QR vector basis_heading(count, -1); return basis_heading; } - -void init_basic_part_of_basis_heading(vector & basis, vector & basis_heading) { +void init_basic_part_of_basis_heading(vector &basis, + vector &basis_heading) { lp_assert(basis_heading.size() >= basis.size()); unsigned m = basis.size(); for (unsigned i = 0; i < m; i++) { @@ -402,26 +392,28 @@ void init_basic_part_of_basis_heading(vector & basis, vector & ba } } -void init_non_basic_part_of_basis_heading(vector & basis_heading, vector & non_basic_columns) { +void init_non_basic_part_of_basis_heading(vector &basis_heading, + vector &non_basic_columns) { non_basic_columns.clear(); - for (int j = basis_heading.size(); j--;){ + for (int j = basis_heading.size(); j--;) { if (basis_heading[j] < 0) { non_basic_columns.push_back(j); // the index of column j in m_nbasis is (- basis_heading[j] - 1) - basis_heading[j] = - static_cast(non_basic_columns.size()); + basis_heading[j] = -static_cast(non_basic_columns.size()); } } } -void init_basis_heading_and_non_basic_columns_vector(vector & basis, - vector & basis_heading, - vector & non_basic_columns) { +void init_basis_heading_and_non_basic_columns_vector( + vector &basis, vector &basis_heading, + vector &non_basic_columns) { init_basic_part_of_basis_heading(basis, basis_heading); init_non_basic_part_of_basis_heading(basis_heading, non_basic_columns); } -void change_basis(unsigned entering, unsigned leaving, vector& basis, vector& nbasis, vector & basis_heading) { - int place_in_basis = basis_heading[leaving]; - int place_in_non_basis = - basis_heading[entering] - 1; +void change_basis(unsigned entering, unsigned leaving, vector &basis, + vector &nbasis, vector &basis_heading) { + int place_in_basis = basis_heading[leaving]; + int place_in_non_basis = -basis_heading[entering] - 1; basis_heading[entering] = place_in_basis; basis_heading[leaving] = -place_in_non_basis - 1; basis[place_in_basis] = entering; @@ -430,7 +422,8 @@ void change_basis(unsigned entering, unsigned leaving, vector& basis, int perm_id = 0; -bool get_int_from_args_parser(const char * option, argument_parser & args_parser, unsigned & n) { +bool get_int_from_args_parser(const char *option, argument_parser &args_parser, + unsigned &n) { std::string s = args_parser.get_option_value(option); if (!s.empty()) { n = atoi(s.c_str()); @@ -439,7 +432,8 @@ bool get_int_from_args_parser(const char * option, argument_parser & args_parser return false; } -bool get_double_from_args_parser(const char * option, argument_parser & args_parser, double & n) { +bool get_double_from_args_parser(const char *option, + argument_parser &args_parser, double &n) { std::string s = args_parser.get_option_value(option); if (!s.empty()) { n = atof(s.c_str()); @@ -448,196 +442,18 @@ bool get_double_from_args_parser(const char * option, argument_parser & args_par return false; } +void get_time_limit_and_max_iters_from_parser( + argument_parser &args_parser, unsigned &time_limit); // forward definition +int get_random_rows() { return 5 + my_random() % 2; } - - -void get_time_limit_and_max_iters_from_parser(argument_parser & args_parser, unsigned & time_limit); // forward definition - - - - - - -int get_random_rows() { - return 5 + my_random() % 2; -} - -int get_random_columns() { - return 5 + my_random() % 3; -} +int get_random_columns() { return 5 + my_random() % 3; } int get_random_int() { - return -1 + my_random() % 2; // (1.0 + RAND_MAX); -} - -#ifndef _WINDOWS -void fill_file_names(vector &file_names, std::set & minimums) { - char *home_dir = getenv("HOME"); - if (home_dir == nullptr) { - std::cout << "cannot find home directory, don't know how to find the files"; - return; - } - std::string home_dir_str(home_dir); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/l0redund.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/l1.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/l2.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/l3.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/l4.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/l4fix.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/plan.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/samp2.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/murtagh.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/l0.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/AFIRO.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SC50B.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SC50A.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/KB2.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SC105.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/STOCFOR1.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/ADLITTLE.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/BLEND.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCAGR7.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SC205.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SHARE2B.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/RECIPELP.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/LOTFI.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/VTP-BASE.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SHARE1B.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/BOEING2.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/BORE3D.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCORPION.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/CAPRI.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/BRANDY.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCAGR25.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCTAP1.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/ISRAEL.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCFXM1.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/BANDM.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/E226.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/AGG.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/GROW7.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/ETAMACRO.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/FINNIS.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCSD1.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/STANDATA.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/STANDGUB.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/BEACONFD.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/STAIR.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/STANDMPS.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/GFRD-PNC.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCRS8.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/BOEING1.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/MODSZK1.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/DEGEN2.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/FORPLAN.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/AGG2.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/AGG3.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCFXM2.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SHELL.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/PILOT4.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCSD6.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SHIP04S.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SEBA.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/GROW15.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/FFFFF800.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/BNL1.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/PEROLD.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/QAP8.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCFXM3.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SHIP04L.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/GANGES.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCTAP2.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/GROW22.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SHIP08S.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/PILOT-WE.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/MAROS.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/STOCFOR2.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/25FV47.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SHIP12S.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCSD8.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/FIT1P.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SCTAP3.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SIERRA.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/PILOTNOV.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/CZPROB.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/FIT1D.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/PILOT-JA.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SHIP08L.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/BNL2.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/NESM.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/CYCLE.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/acc-tight5.mps"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/SHIP12L.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/DEGEN3.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/GREENBEA.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/GREENBEB.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/80BAU3B.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/TRUSS.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/D2Q06C.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/WOODW.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/QAP12.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/D6CUBE.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/PILOT.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/DFL001.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/WOOD1P.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/FIT2P.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/PILOT87.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/STOCFOR3.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/QAP15.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/FIT2D.SIF"); - file_names.push_back(home_dir_str + "/projects/lp/src/tests/math/lp/test_files/netlib/MAROS-R7.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/FIT2P.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/DFL001.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/D2Q06C.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/80BAU3B.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/GREENBEB.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/GREENBEA.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/BNL2.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/SHIP08L.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/FIT1D.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/SCTAP3.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/SCSD8.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/SCSD6.SIF"); - minimums.insert("/projects/lp/src/tests/math/lp/test_files/netlib/MAROS-R7.SIF"); -} - -void test_out_dir(std::string out_dir) { - auto *out_dir_p = opendir(out_dir.c_str()); - if (out_dir_p == nullptr) { - std::cout << "creating directory " << out_dir << std::endl; -#ifdef LEAN_WINDOWS - int res = mkdir(out_dir.c_str()); -#else - int res = mkdir(out_dir.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); -#endif - if (res) { - std::cout << "Cannot open output directory \"" << out_dir << "\"" << std::endl; - } - return; - } - closedir(out_dir_p); + return -1 + my_random() % 2; // (1.0 + RAND_MAX); } -void find_dir_and_file_name(std::string a, std::string & dir, std::string& fn) { - // todo: make it system independent - size_t last_slash_pos = a.find_last_of('/'); - if (last_slash_pos >= a.size()) { - std::cout << "cannot find file name in " << a << std::endl; - throw; - } - dir = a.substr(0, last_slash_pos); - // std::cout << "dir = " << dir << std::endl; - fn = a.substr(last_slash_pos + 1); - // std::cout << "fn = " << fn << std::endl; -} - - -#endif - - - -std::string read_line(bool & end, std::ifstream & file) { +std::string read_line(bool &end, std::ifstream &file) { std::string s; if (!getline(file, s)) { end = true; @@ -647,67 +463,108 @@ std::string read_line(bool & end, std::ifstream & file) { return s; } -bool contains(std::string const & s, char const * pattern) { +bool contains(std::string const &s, char const *pattern) { return s.find(pattern) != std::string::npos; } - - -void setup_args_parser(argument_parser & parser) { +void setup_args_parser(argument_parser &parser) { parser.add_option_with_help_string("-monics", "test emonics"); parser.add_option_with_help_string("-nex_order", "test nex order"); parser.add_option_with_help_string("-nla_cn", "test cross nornmal form"); parser.add_option_with_help_string("-nla_sim", "test nex simplify"); - parser.add_option_with_help_string("-nla_blfmz_mf", "test_basic_lemma_for_mon_zero_from_factor_to_monomial"); - parser.add_option_with_help_string("-nla_blfmz_fm", "test_basic_lemma_for_mon_zero_from_monomials_to_factor"); - parser.add_option_with_help_string("-nla_order", "test nla_solver order lemma"); - parser.add_option_with_help_string("-nla_monot", "test nla_solver order lemma"); + parser.add_option_with_help_string( + "-nla_blfmz_mf", "test_basic_lemma_for_mon_zero_from_factor_to_monomial"); + parser.add_option_with_help_string( + "-nla_blfmz_fm", + "test_basic_lemma_for_mon_zero_from_monomials_to_factor"); + parser.add_option_with_help_string("-nla_order", + "test nla_solver order lemma"); + parser.add_option_with_help_string("-nla_monot", + "test nla_solver order lemma"); parser.add_option_with_help_string("-nla_tan", "test_tangent_lemma"); parser.add_option_with_help_string("-nla_bsl", "test_basic_sign_lemma"); parser.add_option_with_help_string("-horner", "test horner's heuristic"); - parser.add_option_with_help_string("-nla_blnt_mf", "test_basic_lemma_for_mon_neutral_from_monomial_to_factors"); - parser.add_option_with_help_string("-nla_blnt_fm", "test_basic_lemma_for_mon_neutral_from_factors_to_monomial"); + parser.add_option_with_help_string( + "-nla_blnt_mf", + "test_basic_lemma_for_mon_neutral_from_monomial_to_factors"); + parser.add_option_with_help_string( + "-nla_blnt_fm", + "test_basic_lemma_for_mon_neutral_from_factors_to_monomial"); parser.add_option_with_help_string("-hnf", "test hermite normal form"); parser.add_option_with_help_string("-gomory", "gomory"); parser.add_option_with_help_string("-intd", "test integer_domain"); - parser.add_option_with_help_string("-xyz_sample", "run a small interactive scenario"); - parser.add_option_with_after_string_with_help("--percent_for_enter", "which percent of columns check for entering column"); - parser.add_option_with_help_string("--totalinf", "minimizes the total infeasibility instead of diminishing infeasibility of the rows"); - parser.add_option_with_after_string_with_help("--rep_frq", "the report frequency, in how many iterations print the cost and other info "); + parser.add_option_with_help_string("-xyz_sample", + "run a small interactive scenario"); + parser.add_option_with_after_string_with_help( + "--percent_for_enter", + "which percent of columns check for entering column"); + parser.add_option_with_help_string( + "--totalinf", + "minimizes the total infeasibility instead of diminishing " + "infeasibility of the rows"); + parser.add_option_with_after_string_with_help( + "--rep_frq", + "the report frequency, in how many iterations print the " + "cost and other info "); parser.add_option_with_help_string("--smt", "smt file format"); - parser.add_option_with_after_string_with_help("--filelist", "the file containing the list of files"); - parser.add_option_with_after_string_with_help("--file", "the input file name"); + parser.add_option_with_after_string_with_help( + "--filelist", "the file containing the list of files"); + parser.add_option_with_after_string_with_help("--file", + "the input file name"); parser.add_option_with_after_string_with_help("--random_seed", "random seed"); parser.add_option_with_help_string("--bp", "bound propagation"); - parser.add_option_with_help_string("--min", "will look for the minimum for the given file if --file is used; the default is looking for the max"); - parser.add_option_with_help_string("--max", "will look for the maximum for the given file if --file is used; it is the default behavior"); - parser.add_option_with_after_string_with_help("--max_iters", "maximum total iterations in a core solver stage"); - parser.add_option_with_after_string_with_help("--time_limit", "time limit in seconds"); + parser.add_option_with_help_string( + "--min", + "will look for the minimum for the given file if --file is " + "used; the default is looking for the max"); + parser.add_option_with_help_string( + "--max", + "will look for the maximum for the given file if --file is " + "used; it is the default behavior"); + parser.add_option_with_after_string_with_help( + "--max_iters", "maximum total iterations in a core solver stage"); + parser.add_option_with_after_string_with_help("--time_limit", + "time limit in seconds"); parser.add_option_with_help_string("--mpq", "solve for rational numbers"); - parser.add_option_with_after_string_with_help("--simplex_strategy", "sets simplex strategy for rational number"); + parser.add_option_with_after_string_with_help( + "--simplex_strategy", "sets simplex strategy for rational number"); parser.add_option_with_help_string("--test_lp_0", "solve a small lp"); - parser.add_option_with_help_string("--solve_some_mps", "solves a list of mps problems"); - parser.add_option_with_after_string_with_help("--test_file_directory", "loads files from the directory for testing"); - parser.add_option_with_after_string_with_help("--out_dir", "setting the output directory for tests, if not set /tmp is used"); + parser.add_option_with_help_string("--solve_some_mps", + "solves a list of mps problems"); + parser.add_option_with_after_string_with_help( + "--test_file_directory", "loads files from the directory for testing"); + parser.add_option_with_after_string_with_help( + "--out_dir", + "setting the output directory for tests, if not set /tmp is used"); parser.add_option_with_help_string("--dual", "using the dual simplex solver"); - parser.add_option_with_help_string("--compare_with_primal", "using the primal simplex solver for comparison"); + parser.add_option_with_help_string( + "--compare_with_primal", + "using the primal simplex solver for comparison"); parser.add_option_with_help_string("--lar", "test lar_solver"); - parser.add_option_with_after_string_with_help("--maxng", "max iterations without progress"); - parser.add_option_with_help_string("--randomize_lar", "test randomize functionality"); + parser.add_option_with_after_string_with_help( + "--maxng", "max iterations without progress"); + parser.add_option_with_help_string("--randomize_lar", + "test randomize functionality"); parser.add_option_with_help_string("--smap", "test stacked_map"); parser.add_option_with_help_string("--term", "simple term test"); - parser.add_option_with_help_string("--eti"," run a small evidence test for total infeasibility scenario"); - parser.add_option_with_help_string("--row_inf", "forces row infeasibility search"); + parser.add_option_with_help_string( + "--eti", " run a small evidence test for total infeasibility scenario"); + parser.add_option_with_help_string("--row_inf", + "forces row infeasibility search"); parser.add_option_with_help_string("-pd", "presolve with double solver"); parser.add_option_with_help_string("--test_int_set", "test int_set"); parser.add_option_with_help_string("--test_mpq", "test rationals"); parser.add_option_with_help_string("--test_mpq_np", "test rationals"); - parser.add_option_with_help_string("--test_mpq_np_plus", "test rationals using plus instead of +="); + parser.add_option_with_help_string("--test_mpq_np_plus", + "test rationals using plus instead of +="); parser.add_option_with_help_string("--maximize_term", "test maximize_term()"); + parser.add_option_with_help_string("--patching", "test patching"); } -struct fff { int a; int b;}; - +struct fff { + int a; + int b; +}; void test_stacked_unsigned() { std::cout << "test stacked unsigned" << std::endl; @@ -719,24 +576,21 @@ void test_stacked_unsigned() { v = 4; v.pop(); lp_assert(v == 2); - v ++; + v++; v++; std::cout << "before push v=" << v << std::endl; v.push(); v++; v.push(); - v+=1; + v += 1; std::cout << "v = " << v << std::endl; v.pop(2); lp_assert(v == 4); - const unsigned & rr = v; - std::cout << rr << std:: endl; - + const unsigned &rr = v; + std::cout << rr << std::endl; } -void test_stacked_value() { - test_stacked_unsigned(); -} +void test_stacked_value() { test_stacked_unsigned(); } void test_stacked_vector() { std::cout << "test_stacked_vector" << std::endl; @@ -751,31 +605,29 @@ void test_stacked_vector() { v.push_back(3); v.push_back(34); v.push(); - v[1]=3; + v[1] = 3; v[2] = 3; v.push(); - v[0]= 7; + v[0] = 7; v[1] = 9; v.pop(2); if (v.size()) - v[v.size() -1 ] = 7; + v[v.size() - 1] = 7; v.push(); v.push_back(33); v[0] = 13; v.pop(); - } - void test_stacked() { test_stacked_value(); test_stacked_vector(); } -char * find_home_dir() { +char *find_home_dir() { #ifdef _WINDOWS #else - char * home_dir = getenv("HOME"); + char *home_dir = getenv("HOME"); if (home_dir == nullptr) { std::cout << "cannot find home directory" << std::endl; return nullptr; @@ -784,9 +636,8 @@ char * find_home_dir() { return nullptr; } - template -void print_chunk(T * arr, unsigned len) { +void print_chunk(T *arr, unsigned len) { for (unsigned i = 0; i < len; i++) { std::cout << arr[i] << ", "; } @@ -794,7 +645,7 @@ void print_chunk(T * arr, unsigned len) { } struct mem_cpy_place_holder { - static void mem_copy_hook(int * destination, unsigned num) { + static void mem_copy_hook(int *destination, unsigned num) { if (destination == nullptr || num == 0) { throw "bad parameters"; } @@ -803,13 +654,14 @@ struct mem_cpy_place_holder { void finalize(unsigned ret) { /* - finalize_util_module(); - finalize_numerics_module(); - */ + finalize_util_module(); + finalize_numerics_module(); + */ // return ret; } -void get_time_limit_and_max_iters_from_parser(argument_parser & args_parser, unsigned & time_limit) { +void get_time_limit_and_max_iters_from_parser(argument_parser &args_parser, + unsigned &time_limit) { std::string time_limit_string = args_parser.get_option_value("--time_limit"); if (!time_limit_string.empty()) { time_limit = atoi(time_limit_string.c_str()); @@ -818,21 +670,28 @@ void get_time_limit_and_max_iters_from_parser(argument_parser & args_parser, uns } } - -std::string create_output_file_name(bool minimize, std::string file_name, bool use_mpq) { - std::string ret = file_name + "_lp_tst_" + (minimize?"min":"max"); - if (use_mpq) return ret + "_mpq.out"; +std::string create_output_file_name(bool minimize, std::string file_name, + bool use_mpq) { + std::string ret = file_name + "_lp_tst_" + (minimize ? "min" : "max"); + if (use_mpq) + return ret + "_mpq.out"; return ret + ".out"; } -std::string create_output_file_name_for_glpsol(bool minimize, std::string file_name){ - return file_name + (minimize?"_min":"_max") + "_glpk_out"; +std::string create_output_file_name_for_glpsol(bool minimize, + std::string file_name) { + return file_name + (minimize ? "_min" : "_max") + "_glpk_out"; } -int run_glpk(std::string file_name, std::string glpk_out_file_name, bool minimize, unsigned time_limit) { - std::string minmax(minimize?"--min":"--max"); - std::string tmlim = time_limit > 0 ? std::string(" --tmlim ") + std::to_string(time_limit)+ " ":std::string(); - std::string command_line = std::string("glpsol --nointopt --nomip ") + minmax + tmlim + + " -o " + glpk_out_file_name +" " + file_name + " > /dev/null"; +int run_glpk(std::string file_name, std::string glpk_out_file_name, + bool minimize, unsigned time_limit) { + std::string minmax(minimize ? "--min" : "--max"); + std::string tmlim = time_limit > 0 ? std::string(" --tmlim ") + + std::to_string(time_limit) + " " + : std::string(); + std::string command_line = std::string("glpsol --nointopt --nomip ") + + minmax + tmlim + +" -o " + glpk_out_file_name + + " " + file_name + " > /dev/null"; return system(command_line.c_str()); } @@ -857,22 +716,13 @@ std::string get_status(std::string file_name) { throw 0; } - - - struct sort_pred { - bool operator()(const std::pair &left, const std::pair &right) { + bool operator()(const std::pair &left, + const std::pair &right) { return left.second < right.second; } }; - - - - - - - vector get_file_names_from_file_list(std::string filelist) { std::ifstream file(filelist); if (!file.is_open()) { @@ -892,17 +742,16 @@ vector get_file_names_from_file_list(std::string filelist) { return ret; } - void test_numeric_pair() { numeric_pair a; numeric_pair b(2, lp::mpq(6, 2)); a = b; numeric_pair c(0.1, 0.5); - a += 2*c; + a += 2 * c; a -= c; - lp_assert (a == b + c); + lp_assert(a == b + c); numeric_pair d = a * 2; - std::cout << a << std::endl; + std::cout << a << std::endl; lp_assert(b == b); lp_assert(b < a); lp_assert(b <= a); @@ -913,12 +762,12 @@ void test_numeric_pair() { lp_assert(a < 2 * b); lp_assert(b + b > a); lp_assert(lp::mpq(2.1) * b + b > a); - lp_assert(-b * lp::mpq(2.1) - b < lp::mpq(0.99) * a); - std::cout << - b * lp::mpq(2.1) - b << std::endl; - lp_assert(-b *(lp::mpq(2.1) + 1) == - b * lp::mpq(2.1) - b); + lp_assert(-b * lp::mpq(2.1) - b < lp::mpq(0.99) * a); + std::cout << -b * lp::mpq(2.1) - b << std::endl; + lp_assert(-b * (lp::mpq(2.1) + 1) == -b * lp::mpq(2.1) - b); } -void get_matrix_dimensions(std::ifstream & f, unsigned & m, unsigned & n) { +void get_matrix_dimensions(std::ifstream &f, unsigned &m, unsigned &n) { std::string line; getline(f, line); getline(f, line); @@ -929,36 +778,33 @@ void get_matrix_dimensions(std::ifstream & f, unsigned & m, unsigned & n) { n = atoi(r[1].c_str()); } - void print_st(lp_status status) { std::cout << lp_status_to_string(status) << std::endl; } - - void test_term() { lar_solver solver; unsigned _x = 0; unsigned _y = 1; - var_index x = solver.add_named_var(_x, true, "x"); - var_index y = solver.add_named_var(_y, true, "y"); + lpvar x = solver.add_named_var(_x, true, "x"); + lpvar y = solver.add_named_var(_y, true, "y"); enable_trace("lar_solver"); enable_trace("cube"); - vector> pairs; - pairs.push_back(std::pair(mpq(2), x)); - pairs.push_back(std::pair(mpq(1), y)); + vector> pairs; + pairs.push_back(std::pair(mpq(2), x)); + pairs.push_back(std::pair(mpq(1), y)); int ti = 0; unsigned x_plus_y = solver.add_term(pairs, ti++); solver.add_var_bound(x_plus_y, lconstraint_kind::GE, mpq(5, 3)); solver.add_var_bound(x_plus_y, lconstraint_kind::LE, mpq(14, 3)); pairs.pop_back(); - pairs.push_back(std::pair(mpq(-1), y)); - unsigned x_minus_y = solver.add_term(pairs, ti++); + pairs.push_back(std::pair(mpq(-1), y)); + unsigned x_minus_y = solver.add_term(pairs, ti++); solver.add_var_bound(x_minus_y, lconstraint_kind::GE, mpq(5, 3)); solver.add_var_bound(x_minus_y, lconstraint_kind::LE, mpq(14, 3)); - auto status = solver.solve(); + auto status = solver.solve(); std::cout << lp_status_to_string(status) << std::endl; - std::unordered_map model; + std::unordered_map model; if (status != lp_status::OPTIMAL) { std::cout << "non optimal" << std::endl; return; @@ -971,85 +817,84 @@ void test_term() { solver.set_int_solver(&i_s); int_cube cuber(i_s); lia_move m = cuber(); - - std::cout <<"\n" << lia_move_to_string(m) << std::endl; + + std::cout << "\n" + << lia_move_to_string(m) << std::endl; model.clear(); solver.get_model(model); - for (auto & t : model) { - std::cout << solver.get_variable_name(t.first) << " = " << t.second.get_double() << ","; + for (auto &t : model) { + std::cout << solver.get_variable_name(t.first) << " = " + << t.second.get_double() << ","; } std::cout << "\ntableu after cube\n"; solver.pp(std::cout).print(); std::cout << "Ax_is_correct = " << solver.ax_is_correct() << "\n"; - } -void test_evidence_for_total_inf_simple(argument_parser & args_parser) { +void test_evidence_for_total_inf_simple(argument_parser &args_parser) { lar_solver solver; - var_index x = solver.add_var(0, false); - var_index y = solver.add_var(1, false); + lpvar x = solver.add_var(0, false); + lpvar y = solver.add_var(1, false); solver.add_var_bound(x, LE, mpq(-1)); solver.add_var_bound(y, GE, mpq(0)); - vector> ls; - - ls.push_back(std::pair(mpq(1), x)); - ls.push_back(std::pair(mpq(1), y)); + vector> ls; + + ls.push_back(std::pair(mpq(1), x)); + ls.push_back(std::pair(mpq(1), y)); unsigned j = solver.add_term(ls, 1); solver.add_var_bound(j, GE, mpq(1)); ls.pop_back(); - ls.push_back(std::pair(- mpq(1), y)); + ls.push_back(std::pair(-mpq(1), y)); j = solver.add_term(ls, 2); solver.add_var_bound(j, GE, mpq(0)); auto status = solver.solve(); std::cout << lp_status_to_string(status) << std::endl; - std::unordered_map model; + std::unordered_map model; lp_assert(solver.get_status() == lp_status::INFEASIBLE); } void test_bound_propagation_one_small_sample1() { /* - (<= (+ a (* (- 1.0) b)) 0.0) - (<= (+ b (* (- 1.0) x_13)) 0.0) - --> (<= (+ a (* (- 1.0) c)) 0.0) - - the inequality on (<= a c) is obtained from a triangle inequality (<= a b) (<= b c). - If b becomes basic variable, then it is likely the old solver ends up with a row that implies (<= a c). - a - b <= 0.0 - b - c <= 0.0 - - got to get a <= c - */ - std::function bound_is_relevant = - [&](unsigned j, bool is_lower_bound, bool strict, const rational& bound_val) { - return true; - }; + (<= (+ a (* (- 1.0) b)) 0.0) + (<= (+ b (* (- 1.0) x_13)) 0.0) + --> (<= (+ a (* (- 1.0) c)) 0.0) + + the inequality on (<= a c) is obtained from a triangle inequality (<= a b) + (<= b c). If b becomes basic variable, then it is likely the old solver ends + up with a row that implies (<= a c). a - b <= 0.0 b - c <= 0.0 + + got to get a <= c + */ + std::function bound_is_relevant = + [&](unsigned j, bool is_lower_bound, bool strict, + const rational &bound_val) { return true; }; lar_solver ls; unsigned a = ls.add_var(0, false); unsigned b = ls.add_var(1, false); unsigned c = ls.add_var(2, false); - vector> coeffs; - coeffs.push_back(std::pair(mpq(1), a)); - coeffs.push_back(std::pair(mpq(-1), c)); + vector> coeffs; + coeffs.push_back(std::pair(mpq(1), a)); + coeffs.push_back(std::pair(mpq(-1), c)); ls.add_term(coeffs, -1); coeffs.pop_back(); - coeffs.push_back(std::pair(mpq(-1), b)); + coeffs.push_back(std::pair(mpq(-1), b)); ls.add_term(coeffs, -1); coeffs.clear(); - coeffs.push_back(std::pair(mpq(1), a)); - coeffs.push_back(std::pair(mpq(-1), b)); + coeffs.push_back(std::pair(mpq(1), a)); + coeffs.push_back(std::pair(mpq(-1), b)); // ls.add_constraint(coeffs, LE, zero_of_type()); // coeffs.clear(); - // coeffs.push_back(std::pair(mpq(1), b)); - // coeffs.push_back(std::pair(mpq(-1), c)); + // coeffs.push_back(std::pair(mpq(1), b)); + // coeffs.push_back(std::pair(mpq(-1), c)); // ls.add_constraint(coeffs, LE, zero_of_type()); // vector ev; // ls.add_var_bound(a, LE, mpq(1)); // ls.solve(); // my_bound_propagator bp(ls); // ls.propagate_bounds_for_touched_rows(bp); - // std::cout << " bound ev from test_bound_propagation_one_small_sample1" << std::endl; - // for (auto & be : bp.m_ibounds) { + // std::cout << " bound ev from test_bound_propagation_one_small_sample1" << + // std::endl; for (auto & be : bp.m_ibounds) { // std::cout << "bound\n"; // ls.print_implied_bound(be, std::cout); // } // todo: restore test @@ -1058,43 +903,48 @@ void test_bound_propagation_one_small_sample1() { void test_bound_propagation_one_small_samples() { test_bound_propagation_one_small_sample1(); /* - (>= x_46 0.0) - (<= x_29 0.0) - (not (<= x_68 0.0)) - (<= (+ (* (/ 1001.0 1998.0) x_10) (* (- 1.0) x_151) x_68) (- (/ 1001.0 999.0))) - (<= (+ (* (/ 1001.0 999.0) x_9) - (* (- 1.0) x_152) - (* (/ 1001.0 999.0) x_151) - (* (/ 1001.0 999.0) x_68)) - (- (/ 1502501.0 999000.0))) - (not (<= (+ (* (/ 999.0 2.0) x_10) (* (- 1.0) x_152) (* (- (/ 999.0 2.0)) x_151)) - (/ 1001.0 2.0))) - (not (<= x_153 0.0))z - (>= (+ x_9 (* (- (/ 1001.0 999.0)) x_10) (* (- 1.0) x_153) (* (- 1.0) x_68)) - (/ 5003.0 1998.0)) - --> (not (<= (+ x_10 x_46 (* (- 1.0) x_29)) 0.0)) - - and - - (<= (+ a (* (- 1.0) b)) 0.0) - (<= (+ b (* (- 1.0) x_13)) 0.0) - --> (<= (+ a (* (- 1.0) x_13)) 0.0) - - In the first case, there typically are no atomic formulas for bounding x_10. So there is never some - basic lemma of the form (>= x46 0), (<= x29 0), (>= x10 0) -> (not (<= (+ x10 x46 (- x29)) 0)). - Instead the bound on x_10 falls out from a bigger blob of constraints. - - In the second case, the inequality on (<= x19 x13) is obtained from a triangle inequality (<= x19 x9) (<= x9 x13). - If x9 becomes basic variable, then it is likely the old solver ends up with a row that implies (<= x19 x13). - */ + (>= x_46 0.0) + (<= x_29 0.0) + (not (<= x_68 0.0)) + (<= (+ (* (/ 1001.0 1998.0) x_10) (* (- 1.0) x_151) x_68) (- (/ 1001.0 + 999.0))) + (<= (+ (* (/ 1001.0 999.0) x_9) + (* (- 1.0) x_152) + (* (/ 1001.0 999.0) x_151) + (* (/ 1001.0 999.0) x_68)) + (- (/ 1502501.0 999000.0))) + (not (<= (+ (* (/ 999.0 2.0) x_10) (* (- 1.0) x_152) (* (- (/ 999.0 2.0)) + x_151)) + (/ 1001.0 2.0))) + (not (<= x_153 0.0))z + (>= (+ x_9 (* (- (/ 1001.0 999.0)) x_10) (* (- 1.0) x_153) (* (- 1.0) x_68)) + (/ 5003.0 1998.0)) + --> (not (<= (+ x_10 x_46 (* (- 1.0) x_29)) 0.0)) + + and + + (<= (+ a (* (- 1.0) b)) 0.0) + (<= (+ b (* (- 1.0) x_13)) 0.0) + --> (<= (+ a (* (- 1.0) x_13)) 0.0) + + In the first case, there typically are no atomic formulas for bounding x_10. + So there is never some basic lemma of the form (>= x46 0), (<= x29 0), (>= + x10 0) -> (not (<= (+ x10 x46 (- x29)) 0)). Instead the bound on x_10 falls + out from a bigger blob of constraints. + + In the second case, the inequality on (<= x19 x13) is obtained from a + triangle inequality (<= x19 x9) (<= x9 x13). If x9 becomes basic variable, + then it is likely the old solver ends up with a row that implies (<= x19 + x13). + */ } void test_bound_propagation_one_row() { lar_solver ls; unsigned x0 = ls.add_var(0, false); unsigned x1 = ls.add_var(1, false); - vector> c; - c.push_back(std::pair(mpq(1), x0)); - c.push_back(std::pair(mpq(-1), x1)); + vector> c; + c.push_back(std::pair(mpq(1), x0)); + c.push_back(std::pair(mpq(-1), x1)); // todo : restore test // ls.add_constraint(c, EQ, one_of_type()); // vector ev; @@ -1102,14 +952,14 @@ void test_bound_propagation_one_row() { // ls.solve(); // my_bound_propagator bp(ls); // ls.propagate_bounds_for_touched_rows(bp); -} +} void test_bound_propagation_one_row_with_bounded_vars() { lar_solver ls; unsigned x0 = ls.add_var(0, false); unsigned x1 = ls.add_var(1, false); - vector> c; - c.push_back(std::pair(mpq(1), x0)); - c.push_back(std::pair(mpq(-1), x1)); + vector> c; + c.push_back(std::pair(mpq(1), x0)); + c.push_back(std::pair(mpq(-1), x1)); // todo: restore test // ls.add_constraint(c, EQ, one_of_type()); // vector ev; @@ -1124,9 +974,9 @@ void test_bound_propagation_one_row_mixed() { lar_solver ls; unsigned x0 = ls.add_var(0, false); unsigned x1 = ls.add_var(1, false); - vector> c; - c.push_back(std::pair(mpq(1), x0)); - c.push_back(std::pair(mpq(-1), x1)); + vector> c; + c.push_back(std::pair(mpq(1), x0)); + c.push_back(std::pair(mpq(-1), x1)); // todo: restore test // ls.add_constraint(c, EQ, one_of_type()); // vector ev; @@ -1134,23 +984,23 @@ void test_bound_propagation_one_row_mixed() { // ls.solve(); // my_bound_propagator bp(ls); // ls.propagate_bounds_for_touched_rows(bp); -} +} void test_bound_propagation_two_rows() { lar_solver ls; unsigned x = ls.add_var(0, false); unsigned y = ls.add_var(1, false); unsigned z = ls.add_var(2, false); - vector> c; - c.push_back(std::pair(mpq(1), x)); - c.push_back(std::pair(mpq(2), y)); - c.push_back(std::pair(mpq(3), z)); + vector> c; + c.push_back(std::pair(mpq(1), x)); + c.push_back(std::pair(mpq(2), y)); + c.push_back(std::pair(mpq(3), z)); // todo: restore test // ls.add_constraint(c, GE, one_of_type()); // c.clear(); - // c.push_back(std::pair(mpq(3), x)); - // c.push_back(std::pair(mpq(2), y)); - // c.push_back(std::pair(mpq(y), z)); + // c.push_back(std::pair(mpq(3), x)); + // c.push_back(std::pair(mpq(2), y)); + // c.push_back(std::pair(mpq(y), z)); // ls.add_constraint(c, GE, one_of_type()); // ls.add_var_bound(x, LE, mpq(2)); // vector ev; @@ -1158,7 +1008,7 @@ void test_bound_propagation_two_rows() { // ls.solve(); // my_bound_propagator bp(ls); // ls.propagate_bounds_for_touched_rows(bp); -} +} void test_total_case_u() { std::cout << "test_total_case_u\n"; @@ -1166,10 +1016,10 @@ void test_total_case_u() { unsigned x = ls.add_var(0, false); unsigned y = ls.add_var(1, false); unsigned z = ls.add_var(2, false); - vector> c; - c.push_back(std::pair(mpq(1), x)); - c.push_back(std::pair(mpq(2), y)); - c.push_back(std::pair(mpq(3), z)); + vector> c; + c.push_back(std::pair(mpq(1), x)); + c.push_back(std::pair(mpq(2), y)); + c.push_back(std::pair(mpq(3), z)); // todo: restore test // ls.add_constraint(c, LE, one_of_type()); // ls.add_var_bound(x, GE, zero_of_type()); @@ -1180,23 +1030,24 @@ void test_total_case_u() { // my_bound_propagator bp(ls); // ls.propagate_bounds_for_touched_rows(bp); } -bool contains_j_kind(unsigned j, lconstraint_kind kind, const mpq & rs, const vector & ev) { - for (auto & e : ev) { +bool contains_j_kind(unsigned j, lconstraint_kind kind, const mpq &rs, + const vector &ev) { + for (auto &e : ev) { if (e.m_j == j && e.m_bound == rs && e.kind() == kind) return true; } return false; } -void test_total_case_l(){ +void test_total_case_l() { std::cout << "test_total_case_l\n"; lar_solver ls; unsigned x = ls.add_var(0, false); unsigned y = ls.add_var(1, false); unsigned z = ls.add_var(2, false); - vector> c; - c.push_back(std::pair(mpq(1), x)); - c.push_back(std::pair(mpq(2), y)); - c.push_back(std::pair(mpq(3), z)); + vector> c; + c.push_back(std::pair(mpq(1), x)); + c.push_back(std::pair(mpq(2), y)); + c.push_back(std::pair(mpq(3), z)); // todo: restore test // ls.add_constraint(c, GE, one_of_type()); // ls.add_var_bound(x, LE, one_of_type()); @@ -1218,54 +1069,50 @@ void test_bound_propagation() { test_bound_propagation_two_rows(); test_bound_propagation_one_row_mixed(); test_total_case_l(); - } void test_int_set() { - u_set s(4); - s.insert(2); + indexed_uint_set s; s.insert(1); s.insert(2); lp_assert(s.contains(2)); lp_assert(s.size() == 2); - s.erase(2); - lp_assert(s.size() == 1); - s.erase(2); + s.remove(2); lp_assert(s.size() == 1); s.insert(3); s.insert(2); - s.clear(); + s.reset(); lp_assert(s.size() == 0); - - + std::cout << "done test_int_set\n"; } void test_rationals_no_numeric_pairs() { stopwatch sw; vector c; - for (unsigned j = 0; j < 10; j ++) - c.push_back(mpq(my_random()%100, 1 + my_random()%100 )); - + for (unsigned j = 0; j < 10; j++) + c.push_back(mpq(my_random() % 100, 1 + my_random() % 100)); + vector x; - for (unsigned j = 0; j < 10; j ++) - x.push_back(mpq(my_random()%100, 1 + my_random()%100 )); + for (unsigned j = 0; j < 10; j++) + x.push_back(mpq(my_random() % 100, 1 + my_random() % 100)); unsigned k = 500000; - mpq r=zero_of_type(); + mpq r = zero_of_type(); sw.start(); - - for (unsigned j = 0; j < k; j++){ + + for (unsigned j = 0; j < k; j++) { mpq val = zero_of_type(); - for (unsigned j=0;j< c.size(); j++){ - val += c[j]*x[j]; + for (unsigned j = 0; j < c.size(); j++) { + val += c[j] * x[j]; } - + r += val; } - + sw.stop(); - std::cout << "operation with rationals no pairs " << sw.get_seconds() << std::endl; + std::cout << "operation with rationals no pairs " << sw.get_seconds() + << std::endl; std::cout << T_to_string(r) << std::endl; } @@ -1273,64 +1120,62 @@ void test_rationals_no_numeric_pairs_plus() { stopwatch sw; vector c; - for (unsigned j = 0; j < 10; j ++) - c.push_back(mpq(my_random()%100, 1 + my_random()%100 )); - + for (unsigned j = 0; j < 10; j++) + c.push_back(mpq(my_random() % 100, 1 + my_random() % 100)); + vector x; - for (unsigned j = 0; j < 10; j ++) - x.push_back(mpq(my_random()%100, 1 + my_random()%100 )); + for (unsigned j = 0; j < 10; j++) + x.push_back(mpq(my_random() % 100, 1 + my_random() % 100)); unsigned k = 500000; - mpq r=zero_of_type(); + mpq r = zero_of_type(); sw.start(); - - for (unsigned j = 0; j < k; j++){ + + for (unsigned j = 0; j < k; j++) { mpq val = zero_of_type(); - for (unsigned j=0;j< c.size(); j++){ - val = val + c[j]*x[j]; + for (unsigned j = 0; j < c.size(); j++) { + val = val + c[j] * x[j]; } - + r = r + val; } - + sw.stop(); - std::cout << "operation with rationals no pairs " << sw.get_seconds() << std::endl; + std::cout << "operation with rationals no pairs " << sw.get_seconds() + << std::endl; std::cout << T_to_string(r) << std::endl; } - - void test_rationals() { stopwatch sw; vector c; - for (unsigned j = 0; j < 10; j ++) - c.push_back(rational(my_random()%100, 1 + my_random()%100)); + for (unsigned j = 0; j < 10; j++) + c.push_back(rational(my_random() % 100, 1 + my_random() % 100)); - - vector> x; - for (unsigned j = 0; j < 10; j ++) - x.push_back(numeric_pair(rational(my_random()%100, 1 + my_random()%100 ))); + for (unsigned j = 0; j < 10; j++) + x.push_back(numeric_pair( + rational(my_random() % 100, 1 + my_random() % 100))); std::cout << "x = "; print_vector(x, std::cout); - + unsigned k = 1000000; - numeric_pair r=zero_of_type>(); + numeric_pair r = zero_of_type>(); sw.start(); - + for (unsigned j = 0; j < k; j++) { for (unsigned i = 0; i < c.size(); i++) { - r+= c[i] * x[i]; + r += c[i] * x[i]; } - } + } sw.stop(); std::cout << "operation with rationals " << sw.get_seconds() << std::endl; std::cout << T_to_string(r) << std::endl; } -void get_random_interval(bool& neg_inf, bool& pos_inf, int& x, int &y) { +void get_random_interval(bool &neg_inf, bool &pos_inf, int &x, int &y) { int i = my_random() % 10; if (i == 0) { neg_inf = true; @@ -1346,29 +1191,29 @@ void get_random_interval(bool& neg_inf, bool& pos_inf, int& x, int &y) { if (!neg_inf) { y = x + my_random() % (101 - x); lp_assert(y >= x); - } - else { + } else { y = my_random() % 100; } } - lp_assert((neg_inf || (0 <= x && x <= 100)) && (pos_inf || (0 <= y && y <= 100))); + lp_assert((neg_inf || (0 <= x && x <= 100)) && + (pos_inf || (0 <= y && y <= 100))); } - void test_gomory_cut_0() { gomory_test g( - [](unsigned j) { return "v" + T_to_string(j);} // name_function_p + [](unsigned j) { return "v" + T_to_string(j); } // name_function_p , - [](unsigned j) { //get_value_p + [](unsigned j) { // get_value_p if (j == 1) return mpq(2730, 1727); if (j == 2) return zero_of_type(); - if (j == 3) return mpq(3); + if (j == 3) + return mpq(3); UNREACHABLE(); return zero_of_type(); }, - [](unsigned j) { // at_low_p + [](unsigned j) { // at_low_p if (j == 1) return false; if (j == 2) @@ -1378,7 +1223,7 @@ void test_gomory_cut_0() { UNREACHABLE(); return false; }, - [](unsigned j) { // at_upper + [](unsigned j) { // at_upper if (j == 1) return false; if (j == 2) @@ -1388,9 +1233,9 @@ void test_gomory_cut_0() { UNREACHABLE(); return false; }, - [](unsigned j) { // lower_bound + [](unsigned j) { // lower_bound if (j == 1) { - UNREACHABLE(); //unlimited from below + UNREACHABLE(); // unlimited from below return impq(0); } if (j == 2) @@ -1400,9 +1245,9 @@ void test_gomory_cut_0() { UNREACHABLE(); return impq(0); }, - [](unsigned j) { // upper + [](unsigned j) { // upper if (j == 1) { - UNREACHABLE(); //unlimited from above + UNREACHABLE(); // unlimited from above return impq(0); } if (j == 2) @@ -1412,9 +1257,7 @@ void test_gomory_cut_0() { UNREACHABLE(); return impq(0); }, - [] (unsigned) { return 0; }, - [] (unsigned) { return 0; } - ); + [](unsigned) { return 0; }, [](unsigned) { return 0; }); lar_term t; mpq k; explanation expl; @@ -1423,14 +1266,14 @@ void test_gomory_cut_0() { row.push_back(std::make_pair(mpq(1), 1)); row.push_back(std::make_pair(mpq(2731, 1727), 2)); row.push_back(std::make_pair(mpq(-910, 1727), 3)); - g.mk_gomory_cut(t, k, expl, inf_col, row); + g.mk_gomory_cut(t, k, expl, inf_col, row); } void test_gomory_cut_1() { gomory_test g( - [](unsigned j) { return "v" + T_to_string(j);} // name_function_p + [](unsigned j) { return "v" + T_to_string(j); } // name_function_p , - [](unsigned j) { //get_value_p + [](unsigned j) { // get_value_p if (j == 1) return mpq(-2); if (j == 2) @@ -1440,7 +1283,7 @@ void test_gomory_cut_1() { UNREACHABLE(); return zero_of_type(); }, - [](unsigned j) { // at_low_p + [](unsigned j) { // at_low_p if (j == 1) return false; if (j == 2) @@ -1450,7 +1293,7 @@ void test_gomory_cut_1() { UNREACHABLE(); return false; }, - [](unsigned j) { // at_upper + [](unsigned j) { // at_upper if (j == 1) return true; if (j == 2) @@ -1460,9 +1303,9 @@ void test_gomory_cut_1() { UNREACHABLE(); return false; }, - [](unsigned j) { // lower_bound + [](unsigned j) { // lower_bound if (j == 1) { - UNREACHABLE(); //unlimited from below + UNREACHABLE(); // unlimited from below return impq(0); } if (j == 2) @@ -1472,7 +1315,7 @@ void test_gomory_cut_1() { UNREACHABLE(); return impq(0); }, - [](unsigned j) { // upper + [](unsigned j) { // upper if (j == 1) { return impq(-2); } @@ -1483,9 +1326,7 @@ void test_gomory_cut_1() { UNREACHABLE(); return impq(0); }, - [] (unsigned) { return 0; }, - [] (unsigned) { return 0; } - ); + [](unsigned) { return 0; }, [](unsigned) { return 0; }); lar_term t; mpq k; explanation expl; @@ -1494,10 +1335,10 @@ void test_gomory_cut_1() { row.push_back(std::make_pair(mpq(1726667, 2730001), 1)); row.push_back(std::make_pair(mpq(-910000, 2730001), 3)); row.push_back(std::make_pair(mpq(1), 2)); - g.mk_gomory_cut(t, k, expl, inf_col, row); + g.mk_gomory_cut(t, k, expl, inf_col, row); } -void call_hnf(general_matrix & A); +void call_hnf(general_matrix &A); void test_hnf_m_less_than_n() { #ifdef Z3DEBUG @@ -1547,7 +1388,6 @@ void test_hnf_m_greater_than_n() { #endif } - void cutting_the_mix_example_1() { mpq sev(7); mpq nine(9); @@ -1560,9 +1400,9 @@ void cutting_the_mix_example_1() { hnf_calc::extended_gcd_minimal_uv(-nine, -nine, d, u, vv); std::cout << "d = " << d << ", u = " << u << ", vv = " << vv << std::endl; - hnf_calc::extended_gcd_minimal_uv(-sev*2, sev, d, u, vv); + hnf_calc::extended_gcd_minimal_uv(-sev * 2, sev, d, u, vv); std::cout << "d = " << d << ", u = " << u << ", vv = " << vv << std::endl; - + hnf_calc::extended_gcd_minimal_uv(mpq(24), mpq(-7), d, u, vv); std::cout << "d = " << d << ", u = " << u << ", vv = " << vv << std::endl; hnf_calc::extended_gcd_minimal_uv(-mpq(24), mpq(7), d, u, vv); @@ -1578,7 +1418,7 @@ void cutting_the_mix_example_1() { #ifdef Z3DEBUG -void fill_general_matrix(general_matrix & M) { +void fill_general_matrix(general_matrix &M) { unsigned m = M.row_count(); unsigned n = M.column_count(); for (unsigned i = 0; i < m; i++) @@ -1586,14 +1426,14 @@ void fill_general_matrix(general_matrix & M) { M[i][j] = mpq(static_cast(my_random() % 13) - 6); } -void call_hnf(general_matrix& A) { +void call_hnf(general_matrix &A) { svector r; - mpq d = hnf_calc::determinant_of_rectangular_matrix(A, r, mpq((int)1000000000)); + mpq d = + hnf_calc::determinant_of_rectangular_matrix(A, r, mpq((int)1000000000)); A.shrink_to_rank(r); hnf h(A, d); } - void test_hnf_for_dim(int m) { general_matrix M(m, m + my_random() % m); fill_general_matrix(M); @@ -1643,7 +1483,7 @@ void test_hnf_3_3() { v.push_back(mpq(-4)); v.push_back(mpq(-3)); A.push_row(v); - + call_hnf(A); std::cout << "test_hnf_3_3 passed" << std::endl; } @@ -1768,29 +1608,29 @@ void test_larger_generated_hnf() { void test_maximize_term() { std::cout << "test_maximize_term\n"; lar_solver solver; - int_solver i_solver(solver); // have to create it too + int_solver i_solver(solver); // have to create it too unsigned _x = 0; unsigned _y = 1; - var_index x = solver.add_var(_x, false); - var_index y = solver.add_var(_y, true); - vector> term_ls; - term_ls.push_back(std::pair(mpq(1), x)); - term_ls.push_back(std::pair(mpq(-1), y)); + lpvar x = solver.add_var(_x, false); + lpvar y = solver.add_var(_y, true); + vector> term_ls; + term_ls.push_back(std::pair(mpq(1), x)); + term_ls.push_back(std::pair(mpq(-1), y)); unsigned term_x_min_y = solver.add_term(term_ls, -1); term_ls.clear(); - term_ls.push_back(std::pair(mpq(2), x)); - term_ls.push_back(std::pair(mpq(2), y)); - + term_ls.push_back(std::pair(mpq(2), x)); + term_ls.push_back(std::pair(mpq(2), y)); + unsigned term_2x_pl_2y = solver.add_term(term_ls, -1); - solver.add_var_bound(term_x_min_y, LE, zero_of_type()); + solver.add_var_bound(term_x_min_y, LE, zero_of_type()); solver.add_var_bound(term_2x_pl_2y, LE, mpq(5)); solver.find_feasible_solution(); lp_assert(solver.get_status() == lp_status::OPTIMAL); std::cout << solver.constraints(); - std::unordered_map model; + std::unordered_map model; solver.get_model(model); for (auto p : model) { - std::cout<< "v[" << p.first << "] = " << p.second << std::endl; + std::cout << "v[" << p.first << "] = " << p.second << std::endl; } std::cout << "calling int_solver\n"; explanation ex; @@ -1798,14 +1638,13 @@ void test_maximize_term() { VERIFY(lm == lia_move::sat); impq term_max; lp_status st = solver.maximize_term(term_2x_pl_2y, term_max); - + std::cout << "status = " << lp_status_to_string(st) << std::endl; std::cout << "term_max = " << term_max << std::endl; solver.get_model(model); for (auto p : model) { - std::cout<< "v[" << p.first << "] = " << p.second << std::endl; + std::cout << "v[" << p.first << "] = " << p.second << std::endl; } - } #ifdef Z3DEBUG void test_hnf() { @@ -1816,7 +1655,7 @@ void test_hnf() { test_hnf_4_4(); test_hnf_5_5(); test_hnf_2_2(); - for (unsigned k=1000; k>0; k--) + for (unsigned k = 1000; k > 0; k--) for (int i = 1; i < 8; i++) test_hnf_for_dim(i); cutting_the_mix_example_1(); @@ -1829,13 +1668,9 @@ void test_gomory_cut() { test_gomory_cut_1(); } -void test_nla_order_lemma() { - nla::test_order_lemma(); -} - +void test_nla_order_lemma() { nla::test_order_lemma(); } -void test_lp_local(int argn, char**argv) { - +void test_lp_local(int argn, char **argv) { // initialize_util_module(); // initialize_numerics_module(); int ret; @@ -1854,7 +1689,10 @@ void test_lp_local(int argn, char**argv) { return finalize(0); } - + if (args_parser.option_is_used("--patching")) { + test_patching(); + return finalize(0); + } if (args_parser.option_is_used("-nla_cn")) { #ifdef Z3DEBUG nla::test_cn(); @@ -1874,7 +1712,6 @@ void test_lp_local(int argn, char**argv) { return finalize(0); } - if (args_parser.option_is_used("-nla_order")) { #ifdef Z3DEBUG test_nla_order_lemma(); @@ -1882,7 +1719,6 @@ void test_lp_local(int argn, char**argv) { return finalize(0); } - if (args_parser.option_is_used("-nla_monot")) { #ifdef Z3DEBUG nla::test_monotone_lemma(); @@ -1890,68 +1726,67 @@ void test_lp_local(int argn, char**argv) { return finalize(0); } - if (args_parser.option_is_used("-nla_bsl")) { + if (args_parser.option_is_used("-nla_bsl")) { #ifdef Z3DEBUG nla::test_basic_sign_lemma(); #endif return finalize(0); } - if (args_parser.option_is_used("-nla_horner")) { + if (args_parser.option_is_used("-nla_horner")) { #ifdef Z3DEBUG nla::test_horner(); #endif return finalize(0); } - if (args_parser.option_is_used("-nla_tan")) { + if (args_parser.option_is_used("-nla_tan")) { #ifdef Z3DEBUG nla::test_tangent_lemma(); #endif return finalize(0); } - if (args_parser.option_is_used("-nla_blfmz_mf")) { + if (args_parser.option_is_used("-nla_blfmz_mf")) { #ifdef Z3DEBUG nla::test_basic_lemma_for_mon_zero_from_monomial_to_factors(); #endif return finalize(0); } - if (args_parser.option_is_used("-nla_blfmz_fm")) { + if (args_parser.option_is_used("-nla_blfmz_fm")) { #ifdef Z3DEBUG nla::test_basic_lemma_for_mon_zero_from_factors_to_monomial(); #endif return finalize(0); } - if (args_parser.option_is_used("-nla_blnt_mf")) { + if (args_parser.option_is_used("-nla_blnt_mf")) { #ifdef Z3DEBUG nla::test_basic_lemma_for_mon_neutral_from_monomial_to_factors(); #endif return finalize(0); } - if (args_parser.option_is_used("-nla_blnt_fm")) { + if (args_parser.option_is_used("-nla_blnt_fm")) { #ifdef Z3DEBUG nla::test_basic_lemma_for_mon_neutral_from_factors_to_monomial(); #endif return finalize(0); } - + if (args_parser.option_is_used("-hnf")) { #ifdef Z3DEBUG test_hnf(); #endif return finalize(0); } - + if (args_parser.option_is_used("-gomory")) { test_gomory_cut(); return finalize(0); } - if (args_parser.option_is_used("--test_int_set")) { test_int_set(); return finalize(0); @@ -1960,10 +1795,127 @@ void test_lp_local(int argn, char**argv) { test_bound_propagation(); return finalize(0); } - - return finalize(0); // has_violations() ? 1 : 0); -} + + return finalize(0); // has_violations() ? 1 : 0); } -void tst_lp(char ** argv, int argc, int& i) { +} // namespace lp +void tst_lp(char **argv, int argc, int &i) { lp::test_lp_local(argc - 2, argv + 2); } +// clang-format on +bool coprime(int a, int b) { + return gcd(rational(a), rational(b)).is_one(); +} +bool coprime(rational &a, rational &b) { + return gcd(a, b).is_one(); +} +void asserts_on_patching(const rational &x, const rational &alpha) { + auto a1 = numerator(alpha); + auto a2 = denominator(alpha); + auto x1 = numerator(x); + auto x2 = denominator(x); + lp_assert(a1.is_pos()); + lp_assert(abs(a1) < abs(a2)); + lp_assert(coprime(a1, a2)); + lp_assert(x1.is_pos()); + lp_assert(x1 < x2); + lp_assert(coprime(x1, x2)); + lp_assert((a2 / x2).is_int()); +} +void get_patching_deltas(const rational &x, const rational &alpha, rational &delta_0, rational &delta_1) { + std::cout << "get_patching_deltas(" << x << ", " << alpha << ")" << std::endl; + auto a1 = numerator(alpha); + auto a2 = denominator(alpha); + auto x1 = numerator(x); + auto x2 = denominator(x); + lp_assert(divides(x2, a2)); + // delta has to be integral. + // We need to find delta such that x1/x2 + (a1/a2)*delta is integral. + // Then a2*x1/x2 + a1*delta is integral, that means that t = a2/x2 is integral. + // We established that a2 = x2*t + // Then x1 + a1*delta*(x2/a2) = x1 + a1*(delta/t) is integral. Taking into account + // that t and a1 are coprime we have delta = t*k, where k is an integer. + rational t = a2 / x2; + std::cout << "t = " << t << std::endl; + // Now we have x1/x2 + (a1/x2)*k is integral, or (x1 + a1*k)/x2 is integral. + // It is equivalent to x1 + a1*k = x2*m, where m is an integer + // We know that a2 and a1 are coprime, and x2 divides a2, so x2 and a1 are coprime. + rational u, v; + auto g = gcd(a1, x2, u, v); + lp_assert(g.is_one() && u.is_int() && v.is_int() && g == u * a1 + v * x2); + std::cout << "u = " << u << ", v = " << v << std::endl; + std::cout << "x= " << (x1 / x2) << std::endl; + std::cout << "x + (a1 / a2) * (-u * t) * x1 = " << x + (a1 / a2) * (-u * t) * x1 << std::endl; + lp_assert((x + (a1 / a2) * (-u * t) * x1).is_int()); + // 1 = (u- l*x2 ) * a1 + (v + l*a1)*x2, for every integer l. + rational d = u * t * x1; + delta_0 = mod(d, a2); + lp_assert(delta_0 > 0); + delta_1 = delta_0 - a2; + lp_assert(delta_1 < 0); + std::cout << "delta_0 = " << delta_0 << std::endl; + std::cout << "delta_1 = " << delta_1 << std::endl; +} + +void try_find_smaller_delta(const rational &x, const rational &alpha, rational &delta_0, rational &delta_1) { + auto a1 = numerator(alpha); + auto a2 = denominator(alpha); + auto x1 = numerator(x); + auto x2 = denominator(x); + rational delta_minus, delta_plus; + auto del_min = delta_0 < delta_1 ? delta_0 : delta_1; + auto del_plus = delta_0 < delta_1 ? delta_1 : delta_0; + for (auto i = del_min + rational(1); i < del_plus; i += 1) { + if ((x - alpha * i).is_int()) { + std::cout << "found smaller delta = " << i << std::endl; + std::cout << "i - del_min = " << i - del_min << std::endl; + std::cout << "x - alpha*i = " << x - alpha * i << std::endl; + } + } +} + +void test_patching_alpha(const rational &x, const rational &alpha) { + std::cout << "\nstart patching x = " << x << ", alpha = " << alpha << "\n"; + asserts_on_patching(x, alpha); + rational delta_0, delta_1; + get_patching_deltas(x, alpha, delta_0, delta_1); + + lp_assert(delta_0 * delta_1 < 0); + + lp_assert((x - alpha * delta_0).is_int()); + lp_assert((x - alpha * delta_1).is_int()); + try_find_smaller_delta(x, alpha, delta_0, delta_1); + // std::cout << "delta_minus = " << delta_minus << ", delta_1 = " << delta_1 << "\n"; + // std::cout << "x + alpha*delta_minus = " << x + alpha * delta_minus << "\n"; + // std::cout << "x + alpha*delta_1 = " << x + alpha * delta_1 << "\n"; +} + +void find_a1_x1_x2_and_fix_a2(int &x1, int &x2, int &a1, int &a2) { + x2 = (rand() % a2) + (int)(a2 / 3); + auto g = gcd(rational(a2), rational(x2)); + a2 *= (x2 / numerator(g).get_int32()); + lp_assert(rational(a2, x2).is_int()); + do { + x1 = rand() % (unsigned)x2 + 1; + } while (!coprime(x1, x2)); + + do { + a1 = rand() % (unsigned)a2 + 1; + } while (!coprime(a1, a2)); +} + +void test_patching() { + srand(1); + // repeat the test 100 times + + int range = 40; + for (int i = 0; i < 100; i++) { + int a1; + int a2 = std::max((int)rand() % range, (int)range / 3); + + int x1, x2; + find_a1_x1_x2_and_fix_a2(x1, x2, a1, a2); + + test_patching_alpha(rational(x1, x2), rational(a1, a2)); + } +} diff --git a/src/test/lp/nla_solver_test.cpp b/src/test/lp/nla_solver_test.cpp index 6e2e0336a9b..fa5a29e99f9 100644 --- a/src/test/lp/nla_solver_test.cpp +++ b/src/test/lp/nla_solver_test.cpp @@ -169,7 +169,7 @@ void test_basic_lemma_for_mon_neutral_from_factors_to_monomial_0() { reslimit l; params_ref p; - solver nla(s, l); + solver nla(s, p, l); svector v; v.push_back(lp_b);v.push_back(lp_d);v.push_back(lp_e); nla.add_monic(lp_bde, v.size(), v.begin()); v.clear(); @@ -179,7 +179,6 @@ void test_basic_lemma_for_mon_neutral_from_factors_to_monomial_0() { v.push_back(lp_a);v.push_back(lp_c); nla.add_monic(lp_ac, v.size(), v.begin()); - vector lv; // set abcde = ac * bde // ac = 1 then abcde = bde, but we have abcde < bde @@ -193,9 +192,10 @@ void test_basic_lemma_for_mon_neutral_from_factors_to_monomial_0() { s.set_column_value_test(lp_bde, lp::impq(rational(16))); - VERIFY(nla.get_core().test_check(lv) == l_false); - - nla.get_core().print_lemma(lv.back(), std::cout); + VERIFY(nla.get_core().test_check() == l_false); + auto const& lemmas = nla.get_core().lemmas(); + nla.get_core().print_lemma(lemmas.back(), std::cout); + ineq i0(lp_ac, llc::NE, 1); lp::lar_term t1, t2; @@ -208,7 +208,7 @@ void test_basic_lemma_for_mon_neutral_from_factors_to_monomial_0() { bool found0 = false; bool found1 = false; bool found2 = false; - for (const auto& k : lv[0].ineqs()){ + for (const auto& k : lemmas[0].ineqs()){ if (k == i0) { found0 = true; } else if (k == i1) { @@ -246,11 +246,10 @@ void test_basic_lemma_for_mon_neutral_from_factors_to_monomial_1() { reslimit l; params_ref p; - solver nla(s, l); + solver nla(s, p, l); svector v; v.push_back(lp_b);v.push_back(lp_d);v.push_back(lp_e); nla.add_monic(lp_bde, v.size(), v.begin()); - vector lemma; s_set_column_value_test(s, lp_a, rational(1)); s_set_column_value_test(s, lp_b, rational(1)); @@ -259,7 +258,8 @@ void test_basic_lemma_for_mon_neutral_from_factors_to_monomial_1() { s_set_column_value_test(s, lp_e, rational(1)); s_set_column_value_test(s, lp_bde, rational(3)); - VERIFY(nla.get_core().test_check(lemma) == l_false); + VERIFY(nla.get_core().test_check() == l_false); + auto const& lemma = nla.get_core().lemmas(); SASSERT(lemma[0].size() == 4); nla.get_core().print_lemma(lemma.back(), std::cout); @@ -317,7 +317,7 @@ void test_basic_lemma_for_mon_zero_from_factors_to_monomial() { reslimit l; params_ref p; - solver nla(s, l); + solver nla(s, p, l); create_abcde(nla, lp_a, @@ -330,7 +330,6 @@ void test_basic_lemma_for_mon_zero_from_factors_to_monomial() { lp_bde, lp_acd, lp_be); - vector lemma; // set vars s_set_column_value_test(s, lp_a, rational(1)); @@ -344,7 +343,8 @@ void test_basic_lemma_for_mon_zero_from_factors_to_monomial() { s_set_column_value_test(s, lp_acd, rational(1)); s_set_column_value_test(s, lp_be, rational(1)); - VERIFY(nla.get_core().test_check(lemma) == l_false); + VERIFY(nla.get_core().test_check() == l_false); + auto const& lemma = nla.get_core().lemmas(); nla.get_core().print_lemma(lemma.back(), std::cout); SASSERT(lemma.size() == 1 && lemma[0].size() == 2); lp::lar_term t0, t1; @@ -379,7 +379,7 @@ void test_basic_lemma_for_mon_zero_from_monomial_to_factors() { reslimit l; params_ref p; - solver nla(s, l); + solver nla(s, p, l); // create monomial acd unsigned_vector vec; @@ -389,14 +389,14 @@ void test_basic_lemma_for_mon_zero_from_monomial_to_factors() { vec.push_back(lp_d); nla.add_monic(lp_acd, vec.size(), vec.begin()); - vector lemma; s_set_column_value_test(s, lp_a, rational(1)); s_set_column_value_test(s, lp_c, rational(1)); s_set_column_value_test(s, lp_d, rational(1)); s_set_column_value_test(s, lp_acd, rational(0)); - VERIFY(nla.get_core().test_check(lemma) == l_false); - + VERIFY(nla.get_core().test_check() == l_false); + auto const& lemma = nla.get_core().lemmas(); + nla.get_core().print_lemma(lemma.back(), std::cout); ineq i0(lp_a, llc::EQ, 0); @@ -439,7 +439,7 @@ void test_basic_lemma_for_mon_neutral_from_monomial_to_factors() { reslimit l; params_ref p; - solver nla(s, l); + solver nla(s, p, l); create_abcde(nla, lp_a, @@ -452,7 +452,6 @@ void test_basic_lemma_for_mon_neutral_from_monomial_to_factors() { lp_bde, lp_acd, lp_be); - vector lemma; // set all vars to 1 s_set_column_value_test(s, lp_a, rational(1)); @@ -471,8 +470,8 @@ void test_basic_lemma_for_mon_neutral_from_monomial_to_factors() { s_set_column_value_test(s, lp_b, - rational(2)); // we have bde = -b, therefore d = +-1 and e = +-1 s_set_column_value_test(s, lp_d, rational(3)); - VERIFY(nla.get_core().test_check(lemma) == l_false); - + VERIFY(nla.get_core().test_check() == l_false); + auto const& lemma = nla.get_core().lemmas(); nla.get_core().print_lemma(lemma.back(), std::cout); ineq i0(lp_d, llc::EQ, 1); @@ -514,7 +513,8 @@ void test_horner() { reslimit l; params_ref p; - solver nla(s, l); + std_vector ib; + solver nla(s, p, l, ib); vector v; v.push_back(a); v.push_back(b); nla.add_monic(lp_ab, v.size(), v.begin()); @@ -551,7 +551,7 @@ void test_basic_sign_lemma() { reslimit l; params_ref p; - solver nla(s, l); + solver nla(s, p, l); // create monomial bde vector vec; @@ -584,8 +584,8 @@ void test_basic_sign_lemma() { s_set_column_value_test(s, lp_bde, rational(5)); s_set_column_value_test(s, lp_acd, rational(3)); - vector lemmas; - VERIFY(nla.get_core().test_check(lemmas) == l_false); + VERIFY(nla.get_core().test_check() == l_false); + auto const& lemmas = nla.get_core().lemmas(); lp::lar_term t; t.add_var(lp_bde); @@ -626,7 +626,8 @@ void test_order_lemma_params(bool var_equiv, int sign) { reslimit l; params_ref p; - solver nla(s,l); + std_vector ib; + solver nla(s,p,l,ib); // create monomial ab vector vec; vec.push_back(lp_a); @@ -707,9 +708,9 @@ void test_order_lemma_params(bool var_equiv, int sign) { s_set_column_value_test(s, lp_abef, nla.get_core().mon_value_by_vars(mon_cdij) + rational(1)); } - vector lemma; - VERIFY(nla.get_core().test_check(lemma) == l_false); + VERIFY(nla.get_core().test_check() == l_false); + auto const& lemma = nla.get_core().lemmas(); // lp::lar_term t; // t.add_monomial(lp_bde); // t.add_monomial(lp_acd); @@ -757,7 +758,8 @@ void test_monotone_lemma() { reslimit l; params_ref p; - solver nla(s, l); + std_vector ib; + solver nla(s, p, l, ib); // create monomial ab vector vec; vec.push_back(lp_a); @@ -792,8 +794,8 @@ void test_monotone_lemma() { // set ef = ij while it has to be ef > ij s_set_column_value_test(s, lp_ef, s.get_column_value(lp_ij)); - vector lemma; - VERIFY(nla.get_core().test_check(lemma) == l_false); + VERIFY(nla.get_core().test_check() == l_false); + auto const& lemma = nla.get_core().lemmas(); nla.get_core().print_lemma(lemma.back(), std::cout); */ } @@ -814,15 +816,15 @@ void test_tangent_lemma_rat() { s_set_column_value_test(s, lp_ab, v); reslimit l; params_ref p; - solver nla(s, l); + solver nla(s, p, l); // create monomial ab vector vec; vec.push_back(lp_a); vec.push_back(lp_b); nla.add_monic(lp_ab, vec.size(), vec.begin()); - vector lemma; - VERIFY(nla.get_core().test_check(lemma) == l_false); + VERIFY(nla.get_core().test_check() == l_false); + auto const& lemma = nla.get_core().lemmas(); nla.get_core().print_lemma(lemma.back(), std::cout); } @@ -841,16 +843,15 @@ void test_tangent_lemma_reg() { s_set_column_value_test(s, lp_ab, rational(11)); reslimit l; params_ref p; - solver nla(s, l); + solver nla(s, p, l); // create monomial ab vector vec; vec.push_back(lp_a); vec.push_back(lp_b); nla.add_monic(lp_ab, vec.size(), vec.begin()); - vector lemma; - VERIFY(nla.get_core().test_check(lemma) == l_false); - nla.get_core().print_lemma(lemma.back(), std::cout); + VERIFY(nla.get_core().test_check() == l_false); + nla.get_core().print_lemma(nla.get_core().lemmas().back(), std::cout); } void test_tangent_lemma_equiv() { @@ -885,7 +886,8 @@ void test_tangent_lemma_equiv() { s_set_column_value_test(s, lp_a, - s.get_column_value(lp_k)); reslimit l; params_ref p; - solver nla(s, l); + std_vector ib; + solver nla(s, p, l, ib); // create monomial ab vector vec; vec.push_back(lp_a); @@ -893,10 +895,9 @@ void test_tangent_lemma_equiv() { int mon_ab = nla.add_monic(lp_ab, vec.size(), vec.begin()); s_set_column_value_test(s, lp_ab, nla.get_core().mon_value_by_vars(mon_ab) + rational(10)); // greater by ten than the correct value - vector lemma; - VERIFY(nla.get_core().test_check(lemma) == l_false); - nla.get_core().print_lemma(lemma.back(), std::cout); + VERIFY(nla.get_core().test_check() == l_false); + nla.get_core().print_lemma(nla.get_core().lemmas().back(), std::cout); */ } diff --git a/src/test/lp/smt_reader.h b/src/test/lp/smt_reader.h index 7843d5714d2..0e638399fb8 100644 --- a/src/test/lp/smt_reader.h +++ b/src/test/lp/smt_reader.h @@ -28,7 +28,7 @@ Revision History: #include #include #include -#include "math/lp/ul_pair.h" +#include "math/lp/column.h" #include "math/lp/lar_constraints.h" #include #include @@ -343,7 +343,7 @@ namespace lp { solver->add_constraint(&c); } - void create_equality_contraint_for_var(column* col, bound * b, lar_solver *solver) { + void create_equality_constraint_for_var(column* col, bound * b, lar_solver *solver) { lar_constraint c(EQ, b->m_fixed_value); var_index i = solver->add_var(col->m_name); c.add_variable_to_constraint(i, numeric_traits::one()); @@ -366,7 +366,7 @@ namespace lp { create_upper_constraint_for_var(col, b, solver); } if (b->m_value_is_fixed) { - create_equality_contraint_for_var(col, b, solver); + create_equality_constraint_for_var(col, b, solver); } } } @@ -383,7 +383,7 @@ namespace lp { } void add_constraint_to_solver(lar_solver * solver, formula_constraint & fc, unsigned i) { - vector> ls; + vector> ls; for (auto & it : fc.m_coeffs) { ls.push_back(std::make_pair(it.first, solver->add_var(register_name(it.second), false))); } diff --git a/src/test/main.cpp b/src/test/main.cpp index 7cd4b6cf9d6..0c3d0e01af7 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -265,4 +265,7 @@ int main(int argc, char ** argv) { TST(finder); TST(totalizer); TST(distribution); + TST(euf_bv_plugin); + TST(euf_arith_plugin); + TST(sls_test); } diff --git a/src/test/mpz.cpp b/src/test/mpz.cpp index 694f13e9650..2c76b3a6636 100644 --- a/src/test/mpz.cpp +++ b/src/test/mpz.cpp @@ -322,7 +322,7 @@ void tst_scoped() { #define NUM_PRIMES 168 unsigned g_primes[NUM_PRIMES] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997 }; -// Return a big number by multipling powers of the first NUM_PRIMES. +// Return a big number by multiplying powers of the first NUM_PRIMES. // - ratio: rand() % ratio == 0 is used to decide whether a specific prime will be included or not. // - max_pw: if condition above is satisfied, then we use (rand() % max_pw) + 1 as the power. void mk_big_num(unsynch_mpz_manager & m, unsigned ratio, unsigned max_pw, mpz & r) { diff --git a/src/test/pdd.cpp b/src/test/pdd.cpp index 0c9b0f85c2f..740ae4f2b7d 100644 --- a/src/test/pdd.cpp +++ b/src/test/pdd.cpp @@ -153,12 +153,49 @@ class test { pdd b = m.mk_var(1); pdd c = m.mk_var(2); pdd d = m.mk_var(3); - pdd p = (a + b)*(c + 3*d) + 2; + + auto const check = [](unsigned const expected_num_monomials, pdd const& p) { + unsigned count = 0; + std::cout << p << "\n"; + for (auto const& m : p) { + std::cout << " " << m << "\n"; + ++count; + } + VERIFY_EQ(expected_num_monomials, count); + }; + + check(9, (a + b + 2)*(c + 3*d + 5) + 2); + check(5, (a + b)*(c + 3*d) + 2); + check(1, a); + check(2, a + 5); + check(1, m.mk_val(5)); + check(0, m.mk_val(0)); + } + + static void linear_iterator() { + std::cout << "test linear iterator\n"; + pdd_manager m(4); + pdd a = m.mk_var(0); + pdd b = m.mk_var(1); + pdd c = m.mk_var(2); + pdd d = m.mk_var(3); + pdd p = (a + b + 2)*(c + 3*d + 5) + 2; std::cout << p << "\n"; - for (auto const& m : p) { - std::cout << m << "\n"; - } + for (auto const& m : p.linear_monomials()) + std::cout << " " << m << "\n"; + std::cout << a << "\n"; + for (auto const& m : a.linear_monomials()) + std::cout << " " << m << "\n"; + pdd one = m.mk_val(5); + std::cout << one << "\n"; + for (auto const& m : one.linear_monomials()) + std::cout << " " << m << "\n"; + pdd zero = m.mk_val(0); + std::cout << zero << "\n"; + for (auto const& m : zero.linear_monomials()) + std::cout << " " << m << "\n"; } + static void order() { std::cout << "order\n"; pdd_manager m(4); @@ -693,6 +730,7 @@ void tst_pdd() { dd::test::canonize(); dd::test::reset(); dd::test::iterator(); + dd::test::linear_iterator(); dd::test::order(); dd::test::order_lm(); dd::test::mod4_operations(); diff --git a/src/test/pdd_solver.cpp b/src/test/pdd_solver.cpp index a8650426697..33b0abf779b 100644 --- a/src/test/pdd_solver.cpp +++ b/src/test/pdd_solver.cpp @@ -20,13 +20,15 @@ namespace dd { } void test1() { pdd_manager m(4); + u_dependency_manager dm; reslimit lim; pdd v0 = m.mk_var(0); pdd v1 = m.mk_var(1); pdd v2 = m.mk_var(2); pdd v3 = m.mk_var(3); - solver gb(lim, m); + + solver gb(lim, dm, m); gb.add(v1*v2 + v1*v3); gb.add(v1 - 1); gb.display(std::cout); @@ -198,10 +200,11 @@ namespace dd { void test_simplify(expr_ref_vector& fmls, bool use_mod2) { ast_manager& m = fmls.get_manager(); unsigned_vector id2var; + u_dependency_manager dm; collect_id2var(id2var, fmls); pdd_manager p(id2var.size(), use_mod2 ? pdd_manager::mod2_e : pdd_manager::zero_one_vars_e); - solver g(m.limit(), p); + solver g(m.limit(), dm, p); for (expr* e : subterms::ground(fmls)) { add_def(id2var, to_app(e), m, p, g); diff --git a/src/test/rational.cpp b/src/test/rational.cpp index 711958de56a..8fe565c6201 100644 --- a/src/test/rational.cpp +++ b/src/test/rational.cpp @@ -466,6 +466,28 @@ static void tst12() { std::cout << i << ": " << r.get_bit(i) << "\n"; } +static void tst13() { + std::cout << "test13\n"; + rational const step = rational(1) / rational(3); + for (rational r; r < 5000; r += step) { + { + unsigned k = r.prev_power_of_two(); + if (r >= 1) { + VERIFY(rational::power_of_two(k) <= r); + VERIFY(r < rational::power_of_two(k + 1)); + } + else { + VERIFY_EQ(k, 0); + } + } + { + unsigned k = r.next_power_of_two(); + VERIFY(r <= rational::power_of_two(k)); + VERIFY(k == 0 || rational::power_of_two(k - 1) < r); + } + } +} + void tst_rational() { TRACE("rational", tout << "starting rational test...\n";); @@ -492,4 +514,5 @@ void tst_rational() { tst10(true); tst10(false); tst12(); + tst13(); } diff --git a/src/test/sls_test.cpp b/src/test/sls_test.cpp new file mode 100644 index 00000000000..d99035398e3 --- /dev/null +++ b/src/test/sls_test.cpp @@ -0,0 +1,243 @@ + +#include "ast/sls/bv_sls_eval.h" +#include "ast/rewriter/th_rewriter.h" +#include "ast/reg_decl_plugins.h" +#include "ast/ast_pp.h" + +namespace bv { + class sls_test { + ast_manager& m; + bv_util bv; + + public: + sls_test(ast_manager& m): + m(m), + bv(m) + {} + + void check_eval(expr* a, expr* b, unsigned j) { + auto es = create_exprs(a, b, j); + for (expr* e : es) + check_eval(e); + } + + void check_eval(expr* e) { + std::function value = [](expr*, unsigned) { + return false; + }; + expr_ref_vector es(m); + bv_util bv(m); + es.push_back(e); + sls_eval ev(m); + ev.init_eval(es, value); + ev.tighten_range(es); + th_rewriter rw(m); + expr_ref r(e, m); + rw(r); + + if (bv.is_bv(e)) { + auto const& val = ev.wval(e); + rational n1, n2; + + n1 = val.get_value(); + + VERIFY(bv.is_numeral(r, n2)); + if (n1 != n2) { + verbose_stream() << mk_pp(e, m) << " computed value " << val << "\n"; + verbose_stream() << "should be " << n2 << "\n"; + } + SASSERT(n1 == n2); + VERIFY(n1 == n2); + } + else if (m.is_bool(e)) { + auto val1 = ev.bval0(e); + auto val2 = m.is_true(r); + if (val1 != val2) { + verbose_stream() << mk_pp(e, m) << " computed value " << val1 << " at odds with definition\n"; + } + SASSERT(val1 == val2); + VERIFY(val1 == val2); + } + } + + expr_ref_vector create_exprs(expr* a, expr* b, unsigned j) { + expr_ref_vector result(m); + result.push_back(bv.mk_bv_add(a, b)) + .push_back(bv.mk_bv_mul(a, b)) + .push_back(bv.mk_bv_sub(a, b)) + .push_back(bv.mk_bv_udiv(a, b)) + .push_back(bv.mk_bv_sdiv(a, b)) + .push_back(bv.mk_bv_srem(a, b)) + .push_back(bv.mk_bv_urem(a, b)) + .push_back(bv.mk_bv_smod(a, b)) + .push_back(bv.mk_bv_shl(a, b)) + .push_back(bv.mk_bv_ashr(a, b)) + .push_back(bv.mk_bv_lshr(a, b)) + .push_back(bv.mk_bv_and(a, b)) + .push_back(bv.mk_bv_or(a, b)) + .push_back(bv.mk_bv_xor(a, b)) + .push_back(bv.mk_bv_neg(a)) + .push_back(bv.mk_bv_not(a)) + .push_back(bv.mk_bvumul_ovfl(a, b)) + .push_back(bv.mk_bvumul_no_ovfl(a, b)) + .push_back(bv.mk_zero_extend(3, a)) + .push_back(bv.mk_sign_extend(3, a)) + .push_back(bv.mk_ule(a, b)) + .push_back(bv.mk_sle(a, b)) + .push_back(bv.mk_concat(a, b)) + .push_back(bv.mk_extract(4, 2, a)) + .push_back(bv.mk_bvuadd_ovfl(a, b)) + .push_back(bv.mk_bv_rotate_left(a, j)) + .push_back(bv.mk_bv_rotate_right(a, j)) + .push_back(bv.mk_bv_rotate_left(a, b)) + .push_back(bv.mk_bv_rotate_right(a, b)) + // .push_back(bv.mk_bvsadd_ovfl(a, b)) + // .push_back(bv.mk_bvneg_ovfl(a)) + // .push_back(bv.mk_bvsmul_no_ovfl(a, b)) + // .push_back(bv.mk_bvsmul_no_udfl(a, b)) + // .push_back(bv.mk_bvsmul_ovfl(a, b)) + // .push_back(bv.mk_bvsdiv_ovfl(a, b)) + ; + return result; + } + + + // e = op(a, b), + // update value of a to "random" + // repair a based on computed values. + void check_repair(expr* a, expr* b, unsigned j) { + expr_ref x(m.mk_const("x", bv.mk_sort(bv.get_bv_size(a))), m); + expr_ref y(m.mk_const("y", bv.mk_sort(bv.get_bv_size(b))), m); + auto es1 = create_exprs(a, b, j); + auto es2 = create_exprs(x, b, j); + auto es3 = create_exprs(a, y, j); + for (unsigned i = 0; i < es1.size(); ++i) { + auto e1 = es1.get(i); + auto e2 = es2.get(i); + auto e3 = es3.get(i); + if (bv.is_bv_sdiv(e1)) + continue; + if (bv.is_bv_srem(e1)) + continue; + if (bv.is_bv_smod(e1)) + continue; + if (is_app_of(e1, bv.get_fid(), OP_BUADD_OVFL)) + continue; + check_repair_idx(e1, e2, 0, x); + if (is_app(e1) && to_app(e1)->get_num_args() == 2) + check_repair_idx(e1, e3, 1, y); + } + } + + random_gen rand; + + void check_repair_idx(expr* e1, expr* e2, unsigned idx, expr* x) { + std::function value = [&](expr*, unsigned) { + return rand() % 2 == 0; + }; + expr_ref_vector es(m); + bv_util bv(m); + th_rewriter rw(m); + expr_ref r(e1, m); + rw(r); + es.push_back(m.is_false(r) ? m.mk_not(e1) : e1); + es.push_back(m.is_false(r) ? m.mk_not(e2) : e2); + sls_eval ev(m); + ev.init_eval(es, value); + ev.tighten_range(es); + + if (m.is_bool(e1)) { + SASSERT(m.is_true(r) || m.is_false(r)); + auto val = m.is_true(r); + auto val2 = ev.bval0(e2); + if (val != val2) { + ev.set(e2, val); + auto rep1 = ev.try_repair(to_app(e2), idx); + if (!rep1) { + verbose_stream() << "Not repaired " << mk_pp(e1, m) << " " << mk_pp(e2, m) << " r: " << r << "\n"; + } + auto val3 = ev.bval0(e2); + if (val3 != val) { + verbose_stream() << "Repaired but not corrected " << mk_pp(e2, m) << "\n"; + ev.display(std::cout, es); + exit(0); + } + //SASSERT(rep1); + } + } + if (bv.is_bv(e1)) { + auto& val1 = ev.wval(e1); + auto& val2 = ev.wval(e2); + if (!val1.eq(val2)) { + val2.set(val1.bits()); + auto rep2 = ev.try_repair(to_app(e2), idx); + if (!rep2) { + verbose_stream() << "Not repaired " << mk_pp(e2, m) << "\n"; + } + auto val3 = ev.wval(e2); + val3.commit_eval(); + if (!val3.eq(val1)) { + verbose_stream() << "Repaired but not corrected " << mk_pp(e2, m) << "\n"; + } + //SASSERT(rep2); + } + } + } + + // todo: + void test_fixed() { + + } + }; +} + + +static void test_eval1() { + ast_manager m; + reg_decl_plugins(m); + bv_util bv(m); + + expr_ref e(m); + + bv::sls_test validator(m); + + unsigned k = 0; + unsigned bw = 6; + for (unsigned i = 0; i < 1ul << bw; ++i) { + expr_ref a(bv.mk_numeral(rational(i), bw), m); + for (unsigned j = 0; j < 1ul << bw; ++j) { + expr_ref b(bv.mk_numeral(rational(j), bw), m); + ++k; + if (k % 1000 == 0) + verbose_stream() << "tests " << k << "\n"; + validator.check_eval(a, b, j); + } + } +} + +static void test_repair1() { + ast_manager m; + reg_decl_plugins(m); + bv_util bv(m); + expr_ref e(m); + bv::sls_test validator(m); + + unsigned k = 0; + unsigned bw = 6; + for (unsigned i = 0; i < 1ul << bw; ++i) { + expr_ref a(bv.mk_numeral(rational(i), bw), m); + for (unsigned j = 0; j < 1ul << bw; ++j) { + expr_ref b(bv.mk_numeral(rational(j), bw), m); + ++k; + if (k % 1000 == 0) + verbose_stream() << "tests " << k << "\n"; + validator.check_repair(a, b, j); + } + } +} + +void tst_sls_test() { + test_eval1(); + test_repair1(); + +} diff --git a/src/test/smt2print_parse.cpp b/src/test/smt2print_parse.cpp index 765a78060bf..76b169a4ae8 100644 --- a/src/test/smt2print_parse.cpp +++ b/src/test/smt2print_parse.cpp @@ -8,6 +8,7 @@ Copyright (c) 2015 Microsoft Corporation // for SMT-LIB2. #include "api/z3.h" +#include "util/debug.h" #include void test_print(Z3_context ctx, Z3_ast_vector av) { @@ -64,26 +65,22 @@ void test_parseprint(char const* spec) { Z3_del_context(ctx); } +static bool is_error = false; +void setError(Z3_context c, Z3_error_code e) { + is_error = true; +} + void test_eval(Z3_context ctx, Z3_string spec, bool shouldFail) { std::cout << "spec:\n" << spec << "\n"; std::string resp; - bool failed = false; - try { - resp = Z3_eval_smtlib2_string(ctx, spec); - } - catch (std::runtime_error& e) { - resp = e.what(); - failed = true; - } - catch (...) { - resp = "unknown exception"; - failed = true; - } + is_error = false; + resp = Z3_eval_smtlib2_string(ctx, spec); - std::cout << "response:\n" << resp << "\n"; + if (!is_error) + std::cout << "response:\n" << resp << "\n"; - if (shouldFail != failed) { + if (shouldFail != is_error) { if (shouldFail) throw std::runtime_error("should have failed"); else @@ -91,9 +88,6 @@ void test_eval(Z3_context ctx, Z3_string spec, bool shouldFail) { } } -void throwError(Z3_context c, Z3_error_code e) { - throw std::runtime_error(Z3_get_error_msg(c, e)); -} void test_repeated_eval() { // Z3_eval_smtlib2_string reuses the parser and the scanner @@ -142,7 +136,7 @@ void test_repeated_eval() { "(pop)\n"; Z3_context ctx = Z3_mk_context(nullptr); - Z3_set_error_handler(ctx, throwError); + Z3_set_error_handler(ctx, setError); std::cout << "testing Z3_eval_smtlib2_string\n"; try { @@ -166,6 +160,70 @@ void test_repeated_eval() { Z3_del_context(ctx); } +void test_name(Z3_string spec, Z3_string expected_name) { + Z3_context ctx = Z3_mk_context(nullptr); + Z3_set_error_handler(ctx, setError); + std::cout << "spec:\n" << spec << "\n"; + is_error = false; + + Z3_ast_vector a = + Z3_parse_smtlib2_string(ctx, + spec, + 0, + nullptr, + nullptr, + 0, + nullptr, + nullptr); + + std::cout << "done parsing\n"; + ENSURE(is_error == (expected_name == nullptr)); + if (is_error) { + Z3_del_context(ctx); + return; + } + Z3_ast_vector_inc_ref(ctx, a); + + ENSURE(Z3_ast_vector_size(ctx, a) == 1) + Z3_ast c = Z3_ast_vector_get(ctx, a, 0); + Z3_inc_ref(ctx, c); + Z3_app app = Z3_to_app(ctx, c); + Z3_func_decl decl = Z3_get_app_decl(ctx, app); + Z3_symbol symbol = Z3_get_decl_name(ctx, decl); + Z3_string name = Z3_get_symbol_string(ctx, symbol); + bool success = std::string(name) == std::string(expected_name); + Z3_dec_ref(ctx, c); + Z3_ast_vector_dec_ref(ctx, a); + Z3_del_context(ctx); + ENSURE(success); +} + +void test_symbol_escape() { + +#define SYMBOL_ASSERTION(N) \ + "(declare-const " N " Bool)\n" \ + "(assert " N ")\n" \ + "(check-sat)\n" + + std::cout << "testing Z3_eval_smtlib2_string\n"; + + try { + test_name(SYMBOL_ASSERTION("|a|"), "a"); + test_name(SYMBOL_ASSERTION("|a\\|"), nullptr); + test_name(SYMBOL_ASSERTION("|a\\||"), "a|"); + test_name(SYMBOL_ASSERTION("|a\\\\|"), "a\\"); + test_name(SYMBOL_ASSERTION("|a\\\\||"), nullptr); + test_name(SYMBOL_ASSERTION("|a\\a|"), "a\\a"); + test_name(SYMBOL_ASSERTION("|a\\a"), nullptr); + } + catch(...) { + std::cout << "Error: uncaught exception\n"; + throw; + } + + std::cout << "done evaluating\n"; +} + void tst_smt2print_parse() { // test basic datatypes @@ -232,4 +290,6 @@ void tst_smt2print_parse() { test_repeated_eval(); + test_symbol_escape(); + } diff --git a/src/util/approx_set.h b/src/util/approx_set.h index aa6f8d38300..a1835be6f2d 100644 --- a/src/util/approx_set.h +++ b/src/util/approx_set.h @@ -41,7 +41,7 @@ static_assert(sizeof(unsigned) == 4, "unsigned are 4 bytes"); template class approx_set_tpl : private T2U_Proc { protected: - R m_set; + R m_set = approx_set_traits::zero; unsigned e2u(T const & e) const { return T2U_Proc::operator()(e); } @@ -52,24 +52,17 @@ class approx_set_tpl : private T2U_Proc { static approx_set_tpl r2s(R const & s) { approx_set_tpl r; r.m_set = s; return r; } public: - approx_set_tpl(): - m_set(approx_set_traits::zero) { - } + approx_set_tpl() = default; explicit approx_set_tpl(T const & e): m_set(e2s(e)) { } - approx_set_tpl(unsigned sz, T const * es): - m_set(approx_set_traits::zero) { + approx_set_tpl(unsigned sz, T const * es) { for (unsigned i = 0; i < sz; i++) insert(es[i]); } - approx_set_tpl(approx_set_tpl const & s): - m_set(s.m_set) { - } - void set(R s) { m_set = s; } R get() const { return m_set; } diff --git a/src/util/array.h b/src/util/array.h index a4c2aa2c7b0..2954035f641 100644 --- a/src/util/array.h +++ b/src/util/array.h @@ -184,7 +184,7 @@ class array { T const * data() const { return m_data; } T * data() { return m_data; } - void swap(array & other) { + void swap(array & other) noexcept { std::swap(m_data, other.m_data); } diff --git a/src/util/basic_interval.h b/src/util/basic_interval.h index ca5868b8817..ea35259a036 100644 --- a/src/util/basic_interval.h +++ b/src/util/basic_interval.h @@ -52,8 +52,8 @@ class basic_interval_manager { interval const & get() const { return m_interval; } interval & get() { return m_interval; } void reset() { m().reset(m_interval); } - void swap(scoped_interval & a) { m().swap(m_interval, a.m_interval); } - void swap(interval & a) { m().swap(m_interval, a); } + void swap(scoped_interval & a) noexcept { m().swap(m_interval, a.m_interval); } + void swap(interval & a) noexcept { m().swap(m_interval, a); } bound const & lower() const { return m_interval.lower(); } bound const & upper() const { return m_interval.upper(); } bound & lower() { return m_interval.lower(); } @@ -146,7 +146,7 @@ class basic_interval_manager { m().set(a.m_upper, n); } - void swap(interval & a, interval & b) { + void swap(interval & a, interval & b) noexcept { m().swap(a.m_lower, b.m_lower); m().swap(a.m_upper, b.m_upper); } diff --git a/src/util/bit_vector.h b/src/util/bit_vector.h index cb29bdd9cc6..12f86dd0084 100644 --- a/src/util/bit_vector.h +++ b/src/util/bit_vector.h @@ -98,7 +98,7 @@ class bit_vector { m_num_bits = 0; } - void swap(bit_vector & other) { + void swap(bit_vector & other) noexcept { std::swap(m_data, other.m_data); std::swap(m_num_bits, other.m_num_bits); std::swap(m_capacity, other.m_capacity); diff --git a/src/util/chashtable.h b/src/util/chashtable.h index 3a2d2685b71..b15d6017f77 100644 --- a/src/util/chashtable.h +++ b/src/util/chashtable.h @@ -161,8 +161,12 @@ class chashtable : private HashProc, private EqProc { unsigned curr_cellar = (m_capacity - m_slots); unsigned new_slots = m_slots * 2; unsigned new_cellar = curr_cellar * 2; + if (new_slots < m_slots || new_cellar < curr_cellar) + throw default_exception("table overflow"); while (true) { unsigned new_capacity = new_slots + new_cellar; + if (new_capacity < new_slots) + throw default_exception("table overflow"); cell * new_table = alloc_table(new_capacity); cell * next_cell = copy_table(m_table, m_slots, m_capacity, new_table, new_slots, new_capacity, @@ -179,6 +183,8 @@ class chashtable : private HashProc, private EqProc { return; } dealloc_vect(new_table, new_capacity); + if (2*new_cellar < new_cellar) + throw default_exception("table overflow"); new_cellar *= 2; } } @@ -553,7 +559,7 @@ class chashtable : private HashProc, private EqProc { iterator begin() const { return iterator(m_table, m_table + m_slots); } iterator end() const { return iterator(); } - void swap(chashtable & other) { + void swap(chashtable & other) noexcept { std::swap(m_table, other.m_table); std::swap(m_capacity, other.m_capacity); std::swap(m_init_slots, other.m_init_slots); diff --git a/src/util/dependency.h b/src/util/dependency.h index 7ccba716a53..6094cc555fc 100644 --- a/src/util/dependency.h +++ b/src/util/dependency.h @@ -44,8 +44,39 @@ class dependency_manager { public: unsigned get_ref_count() const { return m_ref_count; } bool is_leaf() const { return m_leaf == 1; } + value const& leaf_value() const { SASSERT(is_leaf()); return static_cast(this)->m_value; } }; + static void linearize_todo(ptr_vector& todo, vector& vs) { + unsigned qhead = 0; + while (qhead < todo.size()) { + dependency* d = todo[qhead]; + qhead++; + if (d->is_leaf()) { + vs.push_back(to_leaf(d)->m_value); + } + else { + for (unsigned i = 0; i < 2; i++) { + dependency* child = to_join(d)->m_children[i]; + if (!child->is_marked()) { + todo.push_back(child); + child->mark(); + } + } + } + } + for (auto* d : todo) + d->unmark(); + } + + static void s_linearize(dependency* d, vector& vs) { + if (!d) + return; + ptr_vector todo; + todo.push_back(d); + linearize_todo(todo, vs); + } + private: struct join : public dependency { dependency * m_children[2]; @@ -69,7 +100,7 @@ class dependency_manager { value_manager & m_vmanager; allocator & m_allocator; - ptr_vector m_todo; + mutable ptr_vector m_todo; void inc_ref(value const & v) { if (C::ref_count) @@ -83,6 +114,7 @@ class dependency_manager { void del(dependency * d) { SASSERT(d); + SASSERT(m_todo.empty()); m_todo.push_back(d); while (!m_todo.empty()) { d = m_todo.back(); @@ -106,12 +138,9 @@ class dependency_manager { } } - void unmark_todo() { - typename ptr_vector::iterator it = m_todo.begin(); - typename ptr_vector::iterator end = m_todo.end(); - for (; it != end; ++it) { - (*it)->unmark(); - } + void unmark_todo() const { + for (auto* d : m_todo) + d->unmark(); m_todo.reset(); } @@ -193,30 +222,30 @@ class dependency_manager { return false; } - void linearize(dependency * d, vector & vs) { - if (d) { - m_todo.reset(); - d->mark(); - m_todo.push_back(d); - unsigned qhead = 0; - while (qhead < m_todo.size()) { - d = m_todo[qhead]; - qhead++; - if (d->is_leaf()) { - vs.push_back(to_leaf(d)->m_value); - } - else { - for (unsigned i = 0; i < 2; i++) { - dependency * child = to_join(d)->m_children[i]; - if (!child->is_marked()) { - m_todo.push_back(child); - child->mark(); - } - } - } + + + void linearize(dependency * d, vector & vs) const { + if (!d) + return; + SASSERT(m_todo.empty()); + d->mark(); + m_todo.push_back(d); + linearize_todo(m_todo, vs); + m_todo.reset(); + } + + void linearize(ptr_vector& deps, vector & vs) const { + if (deps.empty()) + return; + SASSERT(m_todo.empty()); + for (auto* d : deps) { + if (d && !d->is_marked()) { + d->mark(); + m_todo.push_back(d); } - unmark_todo(); } + linearize_todo(m_todo, vs); + m_todo.reset(); } }; @@ -300,7 +329,16 @@ class scoped_dependency_manager { return m_dep_manager.contains(d, v); } - void linearize(dependency * d, vector & vs) { + void linearize(dependency * d, vector & vs) const { + return m_dep_manager.linearize(d, vs); + } + + static vector const& s_linearize(dependency* d, vector& vs) { + dep_manager::s_linearize(d, vs); + return vs; + } + + void linearize(ptr_vector& d, vector & vs) const { return m_dep_manager.linearize(d, vs); } @@ -323,4 +361,83 @@ typedef scoped_dependency_manager::dependency v_dependency; typedef scoped_dependency_manager u_dependency_manager; typedef scoped_dependency_manager::dependency u_dependency; +/** + \brief Version of the scoped-depenendcy-manager where region scopes are handled externally. +*/ +template +class stacked_dependency_manager { + + class config { + public: + static const bool ref_count = true; + typedef Value value; + + class value_manager { + public: + void inc_ref(value const& v) { + } + + void dec_ref(value const& v) { + } + }; + + class allocator { + region& m_region; + public: + allocator(region& r) : m_region(r) {} + + void* allocate(size_t sz) { + return m_region.allocate(sz); + } + + void deallocate(size_t sz, void* mem) { + } + }; + }; + + typedef dependency_manager dep_manager; +public: + typedef typename dep_manager::dependency dependency; + typedef Value value; + +private: + typename config::value_manager m_vmanager; + typename config::allocator m_allocator; + dep_manager m_dep_manager; + +public: + stacked_dependency_manager(region& r) : + m_allocator(r), + m_dep_manager(m_vmanager, m_allocator) { + } + + dependency* mk_empty() { + return m_dep_manager.mk_empty(); + } + + dependency* mk_leaf(value const& v) { + return m_dep_manager.mk_leaf(v); + } + + dependency* mk_join(dependency* d1, dependency* d2) { + return m_dep_manager.mk_join(d1, d2); + } + + bool contains(dependency* d, value const& v) { + return m_dep_manager.contains(d, v); + } + + void linearize(dependency* d, vector& vs) { + return m_dep_manager.linearize(d, vs); + } + + static vector const& s_linearize(dependency* d, vector& vs) { + dep_manager::s_linearize(d, vs); + return vs; + } + + void linearize(ptr_vector& d, vector& vs) { + return m_dep_manager.linearize(d, vs); + } +}; diff --git a/src/util/dlist.h b/src/util/dlist.h index e5c95b8cf38..4c0e51e5898 100644 --- a/src/util/dlist.h +++ b/src/util/dlist.h @@ -188,14 +188,14 @@ class dll_iterator { dll_iterator(T const* elem, bool first): m_elem(elem), m_first(first) { } public: - static dll_iterator mk_begin(T const* elem) { - // Setting first==(bool)elem makes this also work for elem==nullptr; + static dll_iterator mk_begin(T const* list) { + // Setting first==(bool)list makes this also work for list==nullptr; // but we can't implement top-level begin/end for pointers because it clashes with the definition for arrays. - return {elem, (bool)elem}; + return {list, (bool)list}; } - static dll_iterator mk_end(T const* elem) { - return {elem, false}; + static dll_iterator mk_end(T const* list) { + return {list, false}; } using value_type = T; @@ -223,17 +223,26 @@ class dll_iterator { } }; +template +class dll_elements { + T const* m_list; +public: + dll_elements(T const* list) : m_list(list) {} + dll_iterator begin() const { return dll_iterator::mk_begin(m_list); } + dll_iterator end() const { return dll_iterator::mk_end(m_list); } +}; + template < typename T , typename U = std::enable_if_t, T>> // should only match if T actually inherits from dll_base > -dll_iterator begin(T const& elem) { - return dll_iterator::mk_begin(&elem); +dll_iterator begin(T const& list) { + return dll_iterator::mk_begin(&list); } template < typename T , typename U = std::enable_if_t, T>> // should only match if T actually inherits from dll_base > -dll_iterator end(T const& elem) +dll_iterator end(T const& list) { - return dll_iterator::mk_end(&elem); + return dll_iterator::mk_end(&list); } diff --git a/src/util/double_manager.h b/src/util/double_manager.h index 6bd691d2420..c2189458220 100644 --- a/src/util/double_manager.h +++ b/src/util/double_manager.h @@ -76,7 +76,7 @@ class double_manager { static void set(double & a, unsigned val) { a = static_cast(val); } static void set(double & a, int64_t val) { a = static_cast(val); } static void set(double & a, uint64_t val) { a = static_cast(val); } - static void swap(double & a, double & b) { std::swap(a, b); } + static void swap(double & a, double & b) noexcept { std::swap(a, b); } bool is_pos(double a) const { return a > m_zero_tolerance; } bool is_neg(double a) const { return a < m_zero_tolerance; } bool is_zero(double a) const { return -m_zero_tolerance <= a && a <= m_zero_tolerance; } diff --git a/src/util/f2n.h b/src/util/f2n.h index 2b6c9799a31..e94c21ce26d 100644 --- a/src/util/f2n.h +++ b/src/util/f2n.h @@ -45,7 +45,7 @@ class f2n { m_manager.set(m_one, ebits, sbits, 1); } - f2n(f2n && other) : m_manager(other.m_manager), m_mode(other.m_mode), m_ebits(other.m_ebits), m_sbits(other.m_sbits), + f2n(f2n && other) noexcept : m_manager(other.m_manager), m_mode(other.m_mode), m_ebits(other.m_ebits), m_sbits(other.m_sbits), m_tmp1(std::move(other.m_tmp1)), m_one(std::move(other.m_one)) {} ~f2n() { @@ -86,7 +86,7 @@ class f2n { void set(numeral & o, numeral const & x) { m().set(o, x); check(o); } void set(numeral & o, mpq const & x) { m().set(o, m_ebits, m_sbits, m_mode, x); check(o); } void reset(numeral & o) { m().reset(o, m_ebits, m_sbits); } - static void swap(numeral & x, numeral & y) { x.swap(y); } + static void swap(numeral & x, numeral & y) noexcept { x.swap(y); } void add(numeral const & x, numeral const & y, numeral & o) { m().add(m_mode, x, y, o); check(o); } void sub(numeral const & x, numeral const & y, numeral & o) { m().sub(m_mode, x, y, o); check(o); } diff --git a/src/util/hashtable.h b/src/util/hashtable.h index cfe4d20cabf..b2830326bc4 100644 --- a/src/util/hashtable.h +++ b/src/util/hashtable.h @@ -283,7 +283,7 @@ class core_hashtable : private HashProc, private EqProc { delete_table(); } - void swap(core_hashtable & source) { + void swap(core_hashtable & source) noexcept { std::swap(m_table, source.m_table); std::swap(m_capacity, source.m_capacity); std::swap(m_size, source.m_size); diff --git a/src/util/heap.h b/src/util/heap.h index df67b31be48..c080c6ebde9 100644 --- a/src/util/heap.h +++ b/src/util/heap.h @@ -159,7 +159,7 @@ class heap : private LT { } unsigned size() const { - return m_value2indices.size(); + return m_values.size() - 1; } void reserve(int s) { @@ -259,7 +259,7 @@ class heap : private LT { return m_values.end(); } - void swap(heap & other) { + void swap(heap & other) noexcept { if (this != &other) { CASSERT("heap", other.check_invariant()); CASSERT("heap", check_invariant()); diff --git a/src/util/hwf.h b/src/util/hwf.h index 4e1081ebf0d..209a8fe77a9 100644 --- a/src/util/hwf.h +++ b/src/util/hwf.h @@ -35,7 +35,7 @@ class hwf { } public: - void swap(hwf & other) { std::swap(value, other.value); } + void swap(hwf & other) noexcept { std::swap(value, other.value); } }; diff --git a/src/util/inf_eps_rational.h b/src/util/inf_eps_rational.h index b28a6d794cd..ea7ec52fcc6 100644 --- a/src/util/inf_eps_rational.h +++ b/src/util/inf_eps_rational.h @@ -38,7 +38,7 @@ class inf_eps_rational { struct eq_proc { bool operator()(inf_eps_rational const& r1, inf_eps_rational const& r2) const { return r1 == r2; } }; - void swap(inf_eps_rational & n) { + void swap(inf_eps_rational & n) noexcept { m_infty.swap(n.m_infty); m_r.swap(n.m_r); } diff --git a/src/util/inf_int_rational.h b/src/util/inf_int_rational.h index 458843cb037..2025af319f7 100644 --- a/src/util/inf_int_rational.h +++ b/src/util/inf_int_rational.h @@ -43,7 +43,7 @@ class inf_int_rational { struct eq_proc { bool operator()(inf_int_rational const& r1, inf_int_rational const& r2) const { return r1 == r2; } }; - void swap(inf_int_rational & n) { + void swap(inf_int_rational & n) noexcept { m_first.swap(n.m_first); std::swap(m_second, n.m_second); } diff --git a/src/util/inf_rational.h b/src/util/inf_rational.h index fb8a71e5f9a..2d7963ff080 100644 --- a/src/util/inf_rational.h +++ b/src/util/inf_rational.h @@ -44,7 +44,7 @@ class inf_rational { struct eq_proc { bool operator()(inf_rational const& r1, inf_rational const& r2) const { return r1 == r2; } }; - void swap(inf_rational & n) { + void swap(inf_rational & n) noexcept { m_first.swap(n.m_first); m_second.swap(n.m_second); } diff --git a/src/util/inf_s_integer.h b/src/util/inf_s_integer.h index b99167b3261..a02ba70fda9 100644 --- a/src/util/inf_s_integer.h +++ b/src/util/inf_s_integer.h @@ -37,7 +37,7 @@ class inf_s_integer { struct eq_proc { bool operator()(inf_s_integer const& r1, inf_s_integer const& r2) const { return r1 == r2; } }; - void swap(inf_s_integer & n) { + void swap(inf_s_integer & n) noexcept { std::swap(m_first, n.m_first); std::swap(m_second, n.m_second); } diff --git a/src/util/map.h b/src/util/map.h index e9880e0a074..0068be31b1b 100644 --- a/src/util/map.h +++ b/src/util/map.h @@ -184,7 +184,7 @@ class table2map { unsigned long long get_num_collision() const { return m_table.get_num_collision(); } - void swap(table2map & other) { + void swap(table2map & other) noexcept { m_table.swap(other.m_table); } diff --git a/src/util/memory_manager.h b/src/util/memory_manager.h index 7dab520df7b..af56c4507f4 100644 --- a/src/util/memory_manager.h +++ b/src/util/memory_manager.h @@ -128,6 +128,29 @@ void dealloc_svect(T * ptr) { memory::deallocate(ptr); } +template +struct std_allocator { + using value_type = T; + // the constructors must be provided according to cpp docs + std_allocator() = default; + template constexpr std_allocator(const std_allocator&) noexcept {} + + + T* allocate(std::size_t n) { + return static_cast(memory::allocate(n * sizeof(T))); + } + + void deallocate(T* p, std::size_t n) { + memory::deallocate(p); + } +}; + +// the comparison operators must be provided according to cpp docs +template +bool operator==(const std_allocator&, const std_allocator&) { return true; } +template +bool operator!=(const std_allocator&, const std_allocator&) { return false; } + struct mem_stat { }; diff --git a/src/util/mpbq.h b/src/util/mpbq.h index 13cd6ae0c5f..edd864dc3e2 100644 --- a/src/util/mpbq.h +++ b/src/util/mpbq.h @@ -40,10 +40,10 @@ class mpbq { mpbq(int v, unsigned k):m_num(v), m_k(k) {} mpz const & numerator() const { return m_num; } unsigned k() const { return m_k; } - void swap(mpbq & other) { m_num.swap(other.m_num); std::swap(m_k, other.m_k); } + void swap(mpbq & other) noexcept { m_num.swap(other.m_num); std::swap(m_k, other.m_k); } }; -inline void swap(mpbq & m1, mpbq & m2) { m1.swap(m2); } +inline void swap(mpbq & m1, mpbq & m2) noexcept { m1.swap(m2); } typedef svector mpbq_vector; @@ -72,7 +72,7 @@ class mpbq_manager { mpbq_manager(unsynch_mpz_manager & m); ~mpbq_manager(); - static void swap(mpbq & a, mpbq & b) { a.swap(b); } + static void swap(mpbq & a, mpbq & b) noexcept { a.swap(b); } void del(mpbq & a) { m_manager.del(a.m_num); } void reset(mpbq & a) { m_manager.reset(a.m_num); a.m_k = 0; } diff --git a/src/util/mpf.cpp b/src/util/mpf.cpp index d2f30e708e8..da0e10571a0 100644 --- a/src/util/mpf.cpp +++ b/src/util/mpf.cpp @@ -41,7 +41,7 @@ mpf::mpf(unsigned _ebits, unsigned _sbits): set(ebits, sbits); } -void mpf::swap(mpf & other) { +void mpf::swap(mpf & other) noexcept { unsigned tmp = ebits; ebits = other.ebits; other.ebits = tmp; @@ -913,34 +913,38 @@ void mpf_manager::fma(mpf_rounding_mode rm, mpf const & x, mpf const & y, mpf co TRACE("mpf_dbg", tout << "R*= " << to_string_binary(res, 2, 0) << " (renormalized, delta=" << renorm_delta << ")" << std::endl;); - if (exp(res) <= mk_max_exp(x.ebits)) - { - set(o, x.ebits, x.sbits, res.sign(), res.exponent(), mpz(0)); + set(o, x.ebits, x.sbits, res.sign(), res.exponent(), mpz(0)); - if (x.sbits >= 4) { - m_mpz_manager.machine_div_rem(res.significand(), m_powers2(x.sbits - 4 + 3), o.significand, sticky_rem); - renorm_sticky |= !m_mpz_manager.is_zero(sticky_rem); - } - else { - m_mpz_manager.mul2k(res.significand(), 4 - x.sbits + 3, o.significand); - } + if (x.sbits >= 4) { + m_mpz_manager.machine_div_rem(res.significand(), m_powers2(x.sbits - 4 + 3), o.significand, sticky_rem); + renorm_sticky |= !m_mpz_manager.is_zero(sticky_rem); + } + else { + m_mpz_manager.mul2k(res.significand(), 4 - x.sbits + 3, o.significand); + o.exponent -= 4 - x.sbits + 3; + } - if (renorm_sticky && m_mpz_manager.is_even(o.significand)) - m_mpz_manager.inc(o.significand); + if (renorm_sticky && m_mpz_manager.is_even(o.significand)) + m_mpz_manager.inc(o.significand); - TRACE("mpf_dbg", tout << "sum[-1:sbits+2] = " << m_mpz_manager.to_string(o.significand) << std::endl; - tout << "R = " << to_string_binary(o, 1, 3) << std::endl;); + TRACE("mpf_dbg", tout << "sum[-1:sbits+2] = " << m_mpz_manager.to_string(o.significand) << std::endl; + tout << "R = " << to_string_binary(o, 1, 3) << std::endl;); - if (m_mpz_manager.is_zero(o.significand)) - mk_zero(x.ebits, x.sbits, rm == MPF_ROUND_TOWARD_NEGATIVE, o); - else - round(rm, o); - } - else { - mk_inf(x.ebits, x.sbits, res.sign(), o); + unsigned max_size = o.sbits+4; + unsigned sig_size = m_mpz_manager.bitsize(o.significand); + if (sig_size > max_size) { + unsigned d = sig_size - max_size; + m_mpz_manager.machine_div2k(o.significand, d); + o.exponent += d; } + + if (m_mpz_manager.is_zero(o.significand)) + mk_zero(x.ebits, x.sbits, rm == MPF_ROUND_TOWARD_NEGATIVE, o); + else + round(rm, o); } + TRACE("mpf_dbg", tout << "FMA = " << to_string(o) << std::endl;); } void my_mpz_sqrt(unsynch_mpz_manager & m, unsigned sbits, bool odd_exp, mpz & in, mpz & o) { diff --git a/src/util/mpf.h b/src/util/mpf.h index 2c3e528d35a..b979e78c1e9 100644 --- a/src/util/mpf.h +++ b/src/util/mpf.h @@ -52,7 +52,7 @@ class mpf { mpf & operator=(mpf const & other) = delete; unsigned get_ebits() const { return ebits; } unsigned get_sbits() const { return sbits; } - void swap(mpf & other); + void swap(mpf & other) noexcept; }; class mpf_manager { @@ -87,7 +87,7 @@ class mpf_manager { void neg(mpf & o); void neg(mpf const & x, mpf & o); - void swap(mpf& a, mpf& b) { a.swap(b); } + void swap(mpf& a, mpf& b) noexcept { a.swap(b); } bool is_zero(mpf const & x); bool is_neg(mpf const & x); diff --git a/src/util/mpff.h b/src/util/mpff.h index c9012a2d053..254c52572c5 100644 --- a/src/util/mpff.h +++ b/src/util/mpff.h @@ -44,14 +44,14 @@ class mpff { m_exponent(0) { } - void swap(mpff & other) { + void swap(mpff & other) noexcept { unsigned sign = m_sign; m_sign = other.m_sign; other.m_sign = sign; unsigned sig_idx = m_sig_idx; m_sig_idx = other.m_sig_idx; other.m_sig_idx = sig_idx; std::swap(m_exponent, other.m_exponent); } }; -inline void swap(mpff & m1, mpff & m2) { m1.swap(m2); } +inline void swap(mpff & m1, mpff & m2) noexcept { m1.swap(m2); } class mpz; class mpq; @@ -316,7 +316,7 @@ class mpff_manager { */ static void abs(mpff & a) { a.m_sign = 0; } - static void swap(mpff & a, mpff & b) { a.swap(b); } + static void swap(mpff & a, mpff & b) noexcept { a.swap(b); } /** \brief c <- a + b diff --git a/src/util/mpfx.h b/src/util/mpfx.h index 7f34e7dc1b5..c563c287232 100644 --- a/src/util/mpfx.h +++ b/src/util/mpfx.h @@ -38,13 +38,13 @@ class mpfx { m_sig_idx(0) { } - void swap(mpfx & other) { + void swap(mpfx & other) noexcept { unsigned sign = m_sign; m_sign = other.m_sign; other.m_sign = sign; unsigned sig_idx = m_sig_idx; m_sig_idx = other.m_sig_idx; other.m_sig_idx = sig_idx; } }; -inline void swap(mpfx & m1, mpfx & m2) { m1.swap(m2); } +inline void swap(mpfx & m1, mpfx & m2) noexcept { m1.swap(m2); } class mpz; class mpq; @@ -228,7 +228,7 @@ class mpfx_manager { */ static void abs(mpfx & a) { a.m_sign = 0; } - static void swap(mpfx & a, mpfx & b) { a.swap(b); } + static void swap(mpfx & a, mpfx & b) noexcept { a.swap(b); } /** \brief c <- a + b diff --git a/src/util/mpn.cpp b/src/util/mpn.cpp index bc9017726e2..e84c15d5e3c 100644 --- a/src/util/mpn.cpp +++ b/src/util/mpn.cpp @@ -21,24 +21,20 @@ Revision History: #include "util/buffer.h" #include "util/mpn.h" -#define max(a,b) (((a) > (b)) ? (a) : (b)) - typedef uint64_t mpn_double_digit; static_assert(sizeof(mpn_double_digit) == 2 * sizeof(mpn_digit), "size alignment"); -const mpn_digit mpn_manager::zero = 0; - int mpn_manager::compare(mpn_digit const * a, unsigned lnga, mpn_digit const * b, unsigned lngb) const { int res = 0; trace(a, lnga); - unsigned j = max(lnga, lngb); + unsigned j = std::max(lnga, lngb); for (; j-- > 0 && res == 0;) { - mpn_digit const & u_j = (j < lnga) ? a[j] : zero; - mpn_digit const & v_j = (j < lngb) ? b[j] : zero; - if (u_j > v_j) + mpn_digit u_j = (j < lnga) ? a[j] : 0; + mpn_digit v_j = (j < lngb) ? b[j] : 0; + if (u_j > v_j) res = 1; else if (u_j < v_j) res = -1; @@ -56,14 +52,14 @@ bool mpn_manager::add(mpn_digit const * a, unsigned lnga, unsigned * plngc) const { trace(a, lnga, b, lngb, "+"); // Essentially Knuth's Algorithm A - unsigned len = max(lnga, lngb); + unsigned len = std::max(lnga, lngb); SASSERT(lngc_alloc == len+1 && len > 0); mpn_digit k = 0; mpn_digit r; bool c1, c2; for (unsigned j = 0; j < len; j++) { - mpn_digit const & u_j = (j < lnga) ? a[j] : zero; - mpn_digit const & v_j = (j < lngb) ? b[j] : zero; + mpn_digit u_j = (j < lnga) ? a[j] : 0; + mpn_digit v_j = (j < lngb) ? b[j] : 0; r = u_j + v_j; c1 = r < u_j; c[j] = r + k; c2 = c[j] < r; k = c1 | c2; @@ -81,13 +77,13 @@ bool mpn_manager::sub(mpn_digit const * a, unsigned lnga, mpn_digit * c, mpn_digit * pborrow) const { trace(a, lnga, b, lngb, "-"); // Essentially Knuth's Algorithm S - unsigned len = max(lnga, lngb); + unsigned len = std::max(lnga, lngb); mpn_digit & k = *pborrow; k = 0; mpn_digit r; bool c1, c2; for (unsigned j = 0; j < len; j++) { - mpn_digit const & u_j = (j < lnga) ? a[j] : zero; - mpn_digit const & v_j = (j < lngb) ? b[j] : zero; + mpn_digit u_j = (j < lnga) ? a[j] : 0; + mpn_digit v_j = (j < lngb) ? b[j] : 0; r = u_j - v_j; c1 = r > u_j; c[j] = r - k; c2 = c[j] > r; k = c1 | c2; @@ -112,14 +108,14 @@ bool mpn_manager::mul(mpn_digit const * a, unsigned lnga, c[i] = 0; for (unsigned j = 0; j < lngb; j++) { - mpn_digit const & v_j = b[j]; + mpn_digit v_j = b[j]; if (v_j == 0) { // This branch may be omitted according to Knuth. c[j+lnga] = 0; } else { k = 0; for (i = 0; i < lnga; i++) { - mpn_digit const & u_i = a[i]; + mpn_digit u_i = a[i]; mpn_double_digit t; t = ((mpn_double_digit)u_i * (mpn_double_digit)v_j) + (mpn_double_digit) c[i+j] + @@ -156,15 +152,6 @@ bool mpn_manager::div(mpn_digit const * numer, unsigned lnum, return false; } - bool all_zero = true; - for (unsigned i = 0; i < lden && all_zero; i++) - if (denom[i] != zero) all_zero = false; - - if (all_zero) { - UNREACHABLE(); - return res; - } - SASSERT(denom[lden-1] != 0); if (lnum == 1 && lden == 1) { diff --git a/src/util/mpn.h b/src/util/mpn.h index 7cf3eafb6c9..285edd09106 100644 --- a/src/util/mpn.h +++ b/src/util/mpn.h @@ -53,7 +53,6 @@ class mpn_manager { private: using mpn_sbuffer = sbuffer; - static const mpn_digit zero; void display_raw(std::ostream & out, mpn_digit const * a, unsigned lng) const; unsigned div_normalize(mpn_digit const * numer, unsigned lnum, diff --git a/src/util/mpq.cpp b/src/util/mpq.cpp index 324750cfa9e..f90843e36f7 100644 --- a/src/util/mpq.cpp +++ b/src/util/mpq.cpp @@ -316,6 +316,12 @@ unsigned mpq_manager::prev_power_of_two(mpq const & a) { return prev_power_of_two(_tmp); } +template +unsigned mpq_manager::next_power_of_two(mpq const & a) { + _scoped_numeral > _tmp(*this); + ceil(a, _tmp); + return next_power_of_two(_tmp); +} template template diff --git a/src/util/mpq.h b/src/util/mpq.h index e254ade697b..286c2758d86 100644 --- a/src/util/mpq.h +++ b/src/util/mpq.h @@ -47,7 +47,7 @@ class mpq_manager : public mpz_manager { void reset_denominator(mpq & a) { del(a.m_den); - a.m_den.m_val = 1; + a.m_den.set(1); } void normalize(mpq & a) { @@ -761,9 +761,9 @@ class mpq_manager : public mpz_manager { return temp; } - void swap(mpz & a, mpz & b) { mpz_manager::swap(a, b); } + void swap(mpz & a, mpz & b) noexcept { mpz_manager::swap(a, b); } - void swap(mpq & a, mpq & b) { + void swap(mpq & a, mpq & b) noexcept { swap(a.m_num, b.m_num); swap(a.m_den, b.m_den); } @@ -848,6 +848,14 @@ class mpq_manager : public mpz_manager { unsigned prev_power_of_two(mpz const & a) { return mpz_manager::prev_power_of_two(a); } unsigned prev_power_of_two(mpq const & a); + /** + \brief Return the smallest k s.t. a <= 2^k. + + \remark Return 0 if a is not positive. + */ + unsigned next_power_of_two(mpz const & a) { return mpz_manager::next_power_of_two(a); } + unsigned next_power_of_two(mpq const & a); + bool is_int_perfect_square(mpq const & a, mpq & r) { SASSERT(is_int(a)); reset_denominator(r); diff --git a/src/util/mpq_inf.h b/src/util/mpq_inf.h index e2bcbcba61c..d8e9148e1ee 100644 --- a/src/util/mpq_inf.h +++ b/src/util/mpq_inf.h @@ -52,7 +52,7 @@ class mpq_inf_manager { m.del(a.second); } - void swap(mpq_inf & a, mpq_inf & b) { + void swap(mpq_inf & a, mpq_inf & b) noexcept { m.swap(a.first, b.first); m.swap(a.second, b.second); } diff --git a/src/util/mpz.cpp b/src/util/mpz.cpp index c3ba3016160..296b4426e76 100644 --- a/src/util/mpz.cpp +++ b/src/util/mpz.cpp @@ -2288,6 +2288,19 @@ unsigned mpz_manager::bitsize(mpz const & a) { return mlog2(a) + 1; } +template +unsigned mpz_manager::next_power_of_two(mpz const & a) { + if (is_nonpos(a)) + return 0; + if (is_one(a)) + return 0; + unsigned shift; + if (is_power_of_two(a, shift)) + return shift; + else + return log2(a) + 1; +} + template bool mpz_manager::is_perfect_square(mpz const & a, mpz & root) { if (is_neg(a)) diff --git a/src/util/mpz.h b/src/util/mpz.h index a1bb1939598..6d1b3449df1 100644 --- a/src/util/mpz.h +++ b/src/util/mpz.h @@ -84,6 +84,7 @@ class mpz { #else typedef mpz_t mpz_type; #endif +protected: int m_val; unsigned m_kind:1; unsigned m_owner:1; @@ -110,12 +111,23 @@ class mpz { return *this; } - void swap(mpz & other) { + void swap(mpz & other) noexcept { std::swap(m_val, other.m_val); std::swap(m_ptr, other.m_ptr); unsigned o = m_owner; m_owner = other.m_owner; other.m_owner = o; unsigned k = m_kind; m_kind = other.m_kind; other.m_kind = k; } + + void set(int v) { + m_val = v; + m_kind = mpz_small; + } + + inline bool is_small() const { return m_kind == mpz_small; } + + inline int value() const { SASSERT(is_small()); return m_val; } + + inline int sign() const { SASSERT(!is_small()); return m_val; } }; #ifndef _MP_GMP @@ -131,7 +143,7 @@ class mpz_stack : public mpz { class mpz_stack : public mpz {}; #endif -inline void swap(mpz & m1, mpz & m2) { m1.swap(m2); } +inline void swap(mpz & m1, mpz & m2) noexcept { m1.swap(m2); } template class mpz_manager { @@ -242,14 +254,13 @@ class mpz_manager { mpz m_two64; - static int64_t i64(mpz const & a) { return static_cast(a.m_val); } + static int64_t i64(mpz const & a) { return static_cast(a.value()); } void set_big_i64(mpz & c, int64_t v); void set_i64(mpz & c, int64_t v) { if (v >= INT_MIN && v <= INT_MAX) { - c.m_val = static_cast(v); - c.m_kind = mpz_small; + c.set(static_cast(v)); } else { set_big_i64(c, v); @@ -306,25 +317,25 @@ class mpz_manager { void get_sign_cell(mpz const & a, int & sign, mpz_cell * & cell, mpz_cell* reserve) { if (is_small(a)) { - if (a.m_val == INT_MIN) { + if (a.value() == INT_MIN) { sign = -1; cell = m_int_min.m_ptr; } else { cell = reserve; cell->m_size = 1; - if (a.m_val < 0) { + if (a.value() < 0) { sign = -1; - cell->m_digits[0] = -a.m_val; + cell->m_digits[0] = -a.value(); } else { sign = 1; - cell->m_digits[0] = a.m_val; + cell->m_digits[0] = a.value(); } } } else { - sign = a.m_val; + sign = a.sign(); cell = a.m_ptr; } } @@ -398,7 +409,7 @@ class mpz_manager { ~mpz_manager(); - static bool is_small(mpz const & a) { return a.m_kind == mpz_small; } + static bool is_small(mpz const & a) { return a.is_small(); } static mpz mk_z(int val) { return mpz(val); } @@ -461,7 +472,7 @@ class mpz_manager { bool eq(mpz const & a, mpz const & b) { if (is_small(a) && is_small(b)) { - return a.m_val == b.m_val; + return a.value() == b.value(); } else { return big_compare(a, b) == 0; @@ -470,7 +481,7 @@ class mpz_manager { bool lt(mpz const& a, int b) { if (is_small(a)) { - return a.m_val < b; + return a.value() < b; } else { return lt(a, mpz(b)); @@ -479,7 +490,7 @@ class mpz_manager { bool lt(mpz const & a, mpz const & b) { if (is_small(a) && is_small(b)) { - return a.m_val < b.m_val; + return a.value() < b.value(); } else { return big_compare(a, b) < 0; @@ -526,8 +537,7 @@ class mpz_manager { void set(mpz & target, mpz const & source) { if (is_small(source)) { - target.m_val = source.m_val; - target.m_kind = mpz_small; + target.set(source.value()); } else { big_set(target, source); @@ -535,8 +545,7 @@ class mpz_manager { } void set(mpz & a, int val) { - a.m_val = val; - a.m_kind = mpz_small; + a.set(val); } void set(mpz & a, unsigned val) { @@ -554,8 +563,7 @@ class mpz_manager { void set(mpz & a, uint64_t val) { if (val < INT_MAX) { - a.m_val = static_cast(val); - a.m_kind = mpz_small; + a.set(static_cast(val)); } else { set_big_ui64(a, val); @@ -573,11 +581,8 @@ class mpz_manager { // deallocates any memory. void reset(mpz & a); - void swap(mpz & a, mpz & b) { - std::swap(a.m_val, b.m_val); - std::swap(a.m_ptr, b.m_ptr); - auto o = a.m_owner; a.m_owner = b.m_owner; b.m_owner = o; - auto k = a.m_kind; a.m_kind = b.m_kind; b.m_kind = k; + void swap(mpz & a, mpz & b) noexcept { + a.swap(b); } bool is_uint64(mpz const & a) const; @@ -624,20 +629,20 @@ class mpz_manager { static bool is_one(mpz const & a) { #ifndef _MP_GMP - return is_small(a) && a.m_val == 1; + return is_small(a) && a.value() == 1; #else if (is_small(a)) - return a.m_val == 1; + return a.value() == 1; return mpz_cmp_si(*a.m_ptr, 1) == 0; #endif } static bool is_minus_one(mpz const & a) { #ifndef _MP_GMP - return is_small(a) && a.m_val == -1; + return is_small(a) && a.value() == -1; #else if (is_small(a)) - return a.m_val == -1; + return a.value() == -1; return mpz_cmp_si(*a.m_ptr, -1) == 0; #endif } @@ -692,6 +697,13 @@ class mpz_manager { \remark Return 0 if a is not positive. */ unsigned prev_power_of_two(mpz const & a) { return log2(a); } + + /** + \brief Return the smallest k s.t. a <= 2^k. + + \remark Return 0 if a is not positive. + */ + unsigned next_power_of_two(mpz const & a); /** \brief Return true if a^{1/n} is an integer, and store the result in a. @@ -705,7 +717,7 @@ class mpz_manager { bool is_even(mpz const & a) { if (is_small(a)) - return !(a.m_val & 0x1); + return !(a.value() & 0x1); #ifndef _MP_GMP return !(0x1 & digits(a)[0]); #else diff --git a/src/util/mpzzp.h b/src/util/mpzzp.h index 731b116ceec..24fb64ea309 100644 --- a/src/util/mpzzp.h +++ b/src/util/mpzzp.h @@ -183,7 +183,7 @@ class mpzzp_manager { } } - void swap(mpz & a, mpz & b) { + void swap(mpz & a, mpz & b) noexcept { SASSERT(is_p_normalized(a) && is_p_normalized(b)); m().swap(a, b); } diff --git a/src/util/numeral_buffer.h b/src/util/numeral_buffer.h index 5bd678d782f..1951b54499e 100644 --- a/src/util/numeral_buffer.h +++ b/src/util/numeral_buffer.h @@ -81,7 +81,7 @@ class numeral_buffer { m_buffer.reserve(sz); } - void swap(svector & other) { + void swap(svector & other) noexcept { m_buffer.swap(other); } diff --git a/src/util/obj_hashtable.h b/src/util/obj_hashtable.h index 254f5f70d7d..49b37ca6189 100644 --- a/src/util/obj_hashtable.h +++ b/src/util/obj_hashtable.h @@ -209,7 +209,7 @@ class obj_map { } } - void swap(obj_map & other) { + void swap(obj_map & other) noexcept { m_table.swap(other.m_table); } }; diff --git a/src/util/obj_ref.h b/src/util/obj_ref.h index 04ae9e4124f..9ab68676da9 100644 --- a/src/util/obj_ref.h +++ b/src/util/obj_ref.h @@ -93,7 +93,7 @@ class obj_ref { return *this; } - obj_ref & operator=(obj_ref && n) { + obj_ref & operator=(obj_ref && n) noexcept { SASSERT(&m_manager == &n.m_manager); std::swap(m_obj, n.m_obj); n.reset(); @@ -105,7 +105,7 @@ class obj_ref { m_obj = nullptr; } - void swap(obj_ref & n) { + void swap(obj_ref & n) noexcept { std::swap(m_obj, n.m_obj); } diff --git a/src/util/obj_ref_hashtable.h b/src/util/obj_ref_hashtable.h index d3e17a4a3d2..dba6550d40b 100644 --- a/src/util/obj_ref_hashtable.h +++ b/src/util/obj_ref_hashtable.h @@ -104,7 +104,7 @@ class obj_ref_map { unsigned long long get_num_collision() const { return m_table.get_num_collision(); } - void swap(obj_ref_map & other) { + void swap(obj_ref_map & other) noexcept { m_table.swap(other.m_table); } diff --git a/src/util/optional.h b/src/util/optional.h index 50a2148d10a..8e515f4be7e 100644 --- a/src/util/optional.h +++ b/src/util/optional.h @@ -94,7 +94,7 @@ class optional { return * this; } - optional & operator=(optional && val) { + optional & operator=(optional && val) noexcept { std::swap(m_obj, val.m_obj); return *this; } diff --git a/src/util/params.cpp b/src/util/params.cpp index d8902615245..700a5310928 100644 --- a/src/util/params.cpp +++ b/src/util/params.cpp @@ -99,25 +99,24 @@ struct param_descrs::imp { return CPK_INVALID; } - bool split_name(symbol const& name, symbol & prefix, symbol & suffix) const { + bool split_name(symbol const& name, std::string_view & prefix, symbol & suffix) const { if (name.is_numerical()) return false; char const* str = name.bare_str(); char const* period = strchr(str,'.'); if (!period) return false; - svector prefix_((unsigned)(period-str), str); - prefix_.push_back(0); - prefix = symbol(prefix_.data()); + prefix = std::string_view(str, period - str); suffix = symbol(period + 1); return true; } param_kind get_kind_in_module(symbol & name) const { param_kind k = get_kind(name); - symbol prefix, suffix; + std::string_view prefix; + symbol suffix; if (k == CPK_INVALID && split_name(name, prefix, suffix)) { k = get_kind(suffix); if (k != CPK_INVALID) { - if (symbol(get_module(suffix)) == prefix) { + if (get_module(suffix) == prefix) { name = suffix; } else { @@ -170,8 +169,8 @@ struct param_descrs::imp { if (names.empty()) return; if (markdown) { - out << " Parameter | Type | Description | Default\n"; - out << " ----------|------|-------------|--------\n"; + out << " Parameter | Type | Description | Default\n" + " ----------|------|-------------|--------\n"; } for (symbol const& name : names) { for (unsigned i = 0; i < indent; i++) out << " "; @@ -197,16 +196,14 @@ struct param_descrs::imp { else out << " (" << d.m_kind << ")"; if (markdown) { - out << " | "; - std::string desc; - for (auto ch : std::string(d.m_descr)) { + out << " | "; + for (auto ch : std::string_view(d.m_descr)) { switch (ch) { - case '<': desc += "<"; break; - case '>': desc += ">"; break; - default: desc.push_back(ch); + case '<': out << "<"; break; + case '>': out << ">"; break; + default: out << ch; break; } } - out << " " << desc; } else if (include_descr) out << " " << d.m_descr; @@ -549,8 +546,7 @@ params_ref::~params_ref() { m_params->dec_ref(); } -params_ref::params_ref(params_ref const & p): - m_params(nullptr) { +params_ref::params_ref(params_ref const & p) { set(p); } diff --git a/src/util/params.h b/src/util/params.h index bc45bdbce03..200b4aa2c43 100644 --- a/src/util/params.h +++ b/src/util/params.h @@ -32,12 +32,12 @@ class param_descrs; class params_ref { static params_ref g_empty_params_ref; - params * m_params; + params * m_params = nullptr; void init(); void copy_core(params const * p); void set(params_ref const& p); public: - params_ref():m_params(nullptr) {} + params_ref() = default; params_ref(params_ref const & p); ~params_ref(); diff --git a/src/util/permutation.cpp b/src/util/permutation.cpp index 5b20580eb0d..8b7adedf316 100644 --- a/src/util/permutation.cpp +++ b/src/util/permutation.cpp @@ -31,7 +31,7 @@ void permutation::reset(unsigned size) { } } -void permutation::swap(unsigned i, unsigned j) { +void permutation::swap(unsigned i, unsigned j) noexcept { unsigned i_prime = m_p[i]; unsigned j_prime = m_p[j]; std::swap(m_p[i], m_p[j]); diff --git a/src/util/permutation.h b/src/util/permutation.h index 1fb7dc9e727..dfc52b370ac 100644 --- a/src/util/permutation.h +++ b/src/util/permutation.h @@ -31,7 +31,7 @@ class permutation { unsigned operator()(unsigned i) const { return m_p[i]; } unsigned inv(unsigned i_prime) const { return m_inv_p[i_prime]; } - void swap(unsigned i, unsigned j); + void swap(unsigned i, unsigned j) noexcept; void move_after(unsigned i, unsigned j); void display(std::ostream & out) const; diff --git a/src/util/rational.h b/src/util/rational.h index f47fddefefc..e9924bca751 100644 --- a/src/util/rational.h +++ b/src/util/rational.h @@ -30,6 +30,10 @@ class rational { static synch_mpq_manager & m() { return *g_mpq_manager; } + void display_hex(std::ostream & out, unsigned num_bits) const { SASSERT(is_int()); m().display_hex(out, m_val.numerator(), num_bits); } + + void display_bin(std::ostream& out, unsigned num_bits) const { SASSERT(is_int()); m().display_bin(out, m_val.numerator(), num_bits); } + public: static void initialize(); static void finalize(); @@ -55,7 +59,7 @@ class rational { explicit rational(double z) { UNREACHABLE(); } explicit rational(char const * v) { m().set(m_val, v); } - + explicit rational(unsigned const * v, unsigned sz) { m().set(m_val, sz, v); } struct i64 {}; @@ -86,7 +90,7 @@ class rational { struct eq_proc { bool operator()(rational const& r1, rational const& r2) const { return r1 == r2; } }; - void swap(rational & n) { m().swap(m_val, n.m_val); } + void swap(rational & n) noexcept { m().swap(m_val, n.m_val); } std::string to_string() const { return m().to_string(m_val); } @@ -96,9 +100,33 @@ class rational { void display_smt2(std::ostream & out) const { return m().display_smt2(out, m_val, false); } - void display_hex(std::ostream & out, unsigned num_bits) const { SASSERT(is_int()); return m().display_hex(out, m_val.numerator(), num_bits); } - void display_bin(std::ostream & out, unsigned num_bits) const { SASSERT(is_int()); return m().display_bin(out, m_val.numerator(), num_bits); } + struct as_hex_wrapper { + rational const& r; + unsigned bw; + }; + + as_hex_wrapper as_hex(unsigned bw) const { return as_hex_wrapper{*this, bw}; } + + friend inline std::ostream& operator<<(std::ostream& out, as_hex_wrapper const& ab) { + ab.r.display_hex(out, ab.bw); + return out; + } + + + + struct as_bin_wrapper { + rational const& r; + unsigned bw; + }; + + as_bin_wrapper as_bin(unsigned bw) const { return as_bin_wrapper{*this, bw}; } + + friend inline std::ostream& operator<<(std::ostream& out, as_bin_wrapper const& ab) { + ab.r.display_bin(out, ab.bw); + return out; + } + bool is_uint64() const { return m().is_uint64(m_val); } @@ -489,6 +517,18 @@ class rational { return get_num_digits(rational(10)); } + /** + * \brief Return the biggest k s.t. 2^k <= a. + * \remark Return 0 if a is not positive. + */ + unsigned prev_power_of_two() const { return m().prev_power_of_two(m_val); } + + /** + * \brief Return the smallest k s.t. a <= 2^k. + * \remark Return 0 if a is not positive. + */ + unsigned next_power_of_two() const { return m().next_power_of_two(m_val); } + bool get_bit(unsigned index) const { return m().get_bit(m_val, index); } @@ -501,6 +541,15 @@ class rational { return k; } + /** Number of trailing zeros in an N-bit representation */ + unsigned parity(unsigned num_bits) const { + SASSERT(!is_neg()); + SASSERT(*this < rational::power_of_two(num_bits)); + if (is_zero()) + return num_bits; + return trailing_zeros(); + } + static bool limit_denominator(rational &num, rational const& limit); }; @@ -649,3 +698,7 @@ inline rational gcd(rational const & r1, rational const & r2, rational & a, rati rational::m().gcd(r1.m_val, r2.m_val, a.m_val, b.m_val, result.m_val); return result; } + +inline void swap(rational& r1, rational& r2) noexcept { + r1.swap(r2); +} diff --git a/src/util/ref.h b/src/util/ref.h index 1f30aae827f..849e23b38cc 100644 --- a/src/util/ref.h +++ b/src/util/ref.h @@ -97,7 +97,7 @@ class ref { return *this; } - ref & operator=(ref &&r) { + ref & operator=(ref &&r) noexcept { if (this != &r) { dec_ref (); m_ptr = r.detach (); @@ -123,7 +123,7 @@ class ref { friend bool operator!=(const ref & r1, const ref & r2) { return r1.m_ptr != r2.m_ptr; } - friend void swap (ref &r1, ref &r2) { + friend void swap (ref &r1, ref &r2) noexcept { T* tmp = r1.m_ptr; r1.m_ptr = r2.m_ptr; r2.m_ptr = tmp; diff --git a/src/util/ref_pair_vector.h b/src/util/ref_pair_vector.h index 489767912d6..40f845af666 100644 --- a/src/util/ref_pair_vector.h +++ b/src/util/ref_pair_vector.h @@ -143,7 +143,7 @@ class ref_pair_vector_core : public Ref { push_back(other[i]); } - void swap(unsigned idx1, unsigned idx2) { + void swap(unsigned idx1, unsigned idx2) noexcept { std::swap(m_nodes[idx1], m_nodes[idx2]); } @@ -179,7 +179,7 @@ class ref_pair_vector : public ref_pair_vector_coreappend(other); } - ref_pair_vector(ref_pair_vector && other) : super(std::move(other)) {} + ref_pair_vector(ref_pair_vector && other) noexcept : super(std::move(other)) {} ref_pair_vector(TManager & m, unsigned sz, elem_t const * data): super(ref_manager_wrapper(m)) { @@ -194,7 +194,7 @@ class ref_pair_vector : public ref_pair_vector_corem_manager) == &(other.m_manager)); this->m_nodes.swap(other.m_nodes); } diff --git a/src/util/ref_vector.h b/src/util/ref_vector.h index 406f7be61e0..3a8492c2927 100644 --- a/src/util/ref_vector.h +++ b/src/util/ref_vector.h @@ -197,14 +197,14 @@ class ref_vector_core : public Ref { push_back(data[i]); } - void operator=(ref_vector_core && other) { + void operator=(ref_vector_core && other) noexcept { if (this != &other) { reset(); m_nodes = std::move(other.m_nodes); } } - void swap(unsigned idx1, unsigned idx2) { + void swap(unsigned idx1, unsigned idx2) noexcept { std::swap(m_nodes[idx1], m_nodes[idx2]); } @@ -262,7 +262,7 @@ class ref_vector : public ref_vector_core > return get_manager(); } - void swap(ref_vector & other) { + void swap(ref_vector & other) noexcept { SASSERT(&(this->m_manager) == &(other.m_manager)); this->m_nodes.swap(other.m_nodes); } diff --git a/src/util/rlimit.h b/src/util/rlimit.h index 9be53cd5597..0abb06cb341 100644 --- a/src/util/rlimit.h +++ b/src/util/rlimit.h @@ -91,8 +91,9 @@ class scoped_suspend_rlimit { struct scoped_limits { reslimit& m_limit; - unsigned m_sz; - scoped_limits(reslimit& lim): m_limit(lim), m_sz(0) {} - ~scoped_limits() { for (unsigned i = 0; i < m_sz; ++i) m_limit.pop_child(); } + unsigned m_sz = 0; + scoped_limits(reslimit& lim): m_limit(lim) {} + ~scoped_limits() { reset(); } + void reset() { for (unsigned i = 0; i < m_sz; ++i) m_limit.pop_child(); m_sz = 0; } void push_child(reslimit* lim) { m_limit.push_child(lim); ++m_sz; } }; diff --git a/src/util/s_integer.h b/src/util/s_integer.h index 506b4f0a363..6ddd2bf6716 100644 --- a/src/util/s_integer.h +++ b/src/util/s_integer.h @@ -34,7 +34,7 @@ class s_integer { struct hash_proc { unsigned operator()(s_integer const& r) const { return r.hash(); } }; struct eq_proc { bool operator()(s_integer const& r1, s_integer const& r2) const { return r1 == r2; } }; - void swap(s_integer & n) { + void swap(s_integer & n) noexcept { std::swap(m_val, n.m_val); } diff --git a/src/util/sat_literal.h b/src/util/sat_literal.h index aeb23bddd20..58088e6281a 100644 --- a/src/util/sat_literal.h +++ b/src/util/sat_literal.h @@ -30,7 +30,7 @@ namespace sat { typedef svector bool_var_vector; - const bool_var null_bool_var = UINT_MAX >> 1; + inline constexpr bool_var null_bool_var = UINT_MAX >> 1; /** \brief The literal b is represented by the value 2*b, and @@ -39,9 +39,7 @@ namespace sat { class literal { unsigned m_val; public: - literal():m_val(null_bool_var << 1) { - SASSERT(var() == null_bool_var && !sign()); - } + constexpr literal(): m_val(null_bool_var << 1) { } explicit literal(bool_var v, bool _sign = false): m_val((v << 1) + static_cast(_sign)) { @@ -49,11 +47,11 @@ namespace sat { SASSERT(sign() == _sign); } - bool_var var() const { + constexpr bool_var var() const { return m_val >> 1; } - bool sign() const { + constexpr bool sign() const { return m_val & 1ul; } @@ -86,7 +84,10 @@ namespace sat { friend bool operator!=(literal const & l1, literal const & l2); }; - const literal null_literal; + inline constexpr literal null_literal; + static_assert(null_literal.var() == null_bool_var); + static_assert(!null_literal.sign()); + using literal_hash = obj_hash; inline literal to_literal(unsigned x) { literal l; l.m_val = x; return l; } diff --git a/src/util/scoped_numeral.h b/src/util/scoped_numeral.h index 65b75053d68..f70f5f18553 100644 --- a/src/util/scoped_numeral.h +++ b/src/util/scoped_numeral.h @@ -28,8 +28,8 @@ class _scoped_numeral { numeral m_num; public: _scoped_numeral(Manager & m):m_manager(m) {} - _scoped_numeral(_scoped_numeral const & n):m_manager(n.m_manager) { m().set(m_num, n.m_num); } - _scoped_numeral(_scoped_numeral &&) = default; + _scoped_numeral(_scoped_numeral const& n) :m_manager(n.m_manager) { m().set(m_num, n.m_num); } + _scoped_numeral(_scoped_numeral && n) noexcept: m_manager(n.m_manager) { m().swap(m_num, n.m_num); } ~_scoped_numeral() { m_manager.del(m_num); } Manager & m() const { return m_manager; } @@ -60,11 +60,11 @@ class _scoped_numeral { m().reset(m_num); } - void swap(_scoped_numeral & n) { + void swap(_scoped_numeral & n) noexcept { m().swap(m_num, n.m_num); } - void swap(numeral & n) { + void swap(numeral & n) noexcept { m().swap(m_num, n); } diff --git a/src/util/scoped_ptr_vector.h b/src/util/scoped_ptr_vector.h index 3c33cc70845..90afbb6ef5b 100644 --- a/src/util/scoped_ptr_vector.h +++ b/src/util/scoped_ptr_vector.h @@ -34,7 +34,7 @@ class scoped_ptr_vector { scoped_ptr_vector(scoped_ptr_vector&& other) noexcept { m_vector.swap(other.m_vector); } - scoped_ptr_vector& operator=(scoped_ptr_vector&& other) { + scoped_ptr_vector& operator=(scoped_ptr_vector&& other) noexcept { if (this == &other) return *this; reset(); @@ -55,7 +55,7 @@ class scoped_ptr_vector { dealloc(m_vector[idx]); m_vector[idx] = ptr; } - void swap(unsigned i, unsigned j) { std::swap(m_vector[i], m_vector[j]); } + void swap(unsigned i, unsigned j) noexcept { std::swap(m_vector[i], m_vector[j]); } unsigned size() const { return m_vector.size(); } bool empty() const { return m_vector.empty(); } void resize(unsigned sz) { diff --git a/src/util/sexpr.cpp b/src/util/sexpr.cpp index e20c56234bf..dcf427dfe0e 100644 --- a/src/util/sexpr.cpp +++ b/src/util/sexpr.cpp @@ -26,11 +26,11 @@ Module Name: #endif struct sexpr_composite : public sexpr { - unsigned m_num_chilren; + unsigned m_num_children; sexpr * m_children[0]; sexpr_composite(unsigned num_children, sexpr * const * children, unsigned line, unsigned pos): sexpr(kind_t::COMPOSITE, line, pos), - m_num_chilren(num_children) { + m_num_children(num_children) { for (unsigned i = 0; i < num_children; i++) { m_children[i] = children[i]; children[i]->inc_ref(); @@ -107,7 +107,7 @@ std::string const & sexpr::get_string() const { unsigned sexpr::get_num_children() const { SASSERT(is_composite()); - return static_cast(this)->m_num_chilren; + return static_cast(this)->m_num_children; } sexpr * sexpr::get_child(unsigned idx) const { diff --git a/src/util/sstream.h b/src/util/sstream.h deleted file mode 100644 index fba13c5d560..00000000000 --- a/src/util/sstream.h +++ /dev/null @@ -1,31 +0,0 @@ -/* -Copyright (c) 2018 Microsoft Corporation - -Module Name: - - nat_set.h - -Abstract: - - Wrapper for sstream. - -Author: - - Leonardo de Moura (leonardo) 2013 - -Revision History: - -*/ -#pragma once -#include -#include - -namespace lean { -/** \brief Wrapper for std::ostringstream */ -class sstream { - std::ostringstream m_strm; -public: - std::string str() const { return m_strm.str(); } - template sstream & operator<<(T const & t) { m_strm << t; return *this; } -}; -} diff --git a/src/util/tbv.cpp b/src/util/tbv.cpp index 017ca0eb702..5048267ade0 100644 --- a/src/util/tbv.cpp +++ b/src/util/tbv.cpp @@ -142,12 +142,8 @@ void tbv_manager::set(tbv& dst, rational const& r, unsigned hi, unsigned lo) { set(dst, r.get_uint64(), hi, lo); return; } - for (unsigned i = 0; i < hi - lo + 1; ++i) { - if (bitwise_and(r, rational::power_of_two(i)).is_zero()) - set(dst, lo + i, BIT_0); - else - set(dst, lo + i, BIT_1); - } + for (unsigned i = 0; i < hi - lo + 1; ++i) + set(dst, lo + i, r.get_bit(i) ? BIT_1 : BIT_0); } void tbv_manager::set(tbv& dst, tbv const& other, unsigned hi, unsigned lo) { diff --git a/src/util/tbv.h b/src/util/tbv.h index cffdc2460c4..3bda89b3f97 100644 --- a/src/util/tbv.h +++ b/src/util/tbv.h @@ -135,7 +135,7 @@ class tbv_ref { public: tbv_ref(tbv_manager& mgr) : mgr(mgr), d(nullptr) {} tbv_ref(tbv_manager& mgr, tbv* d) : mgr(mgr), d(d) {} - tbv_ref(tbv_ref&& d) : mgr(d.mgr), d(d.detach()) {} + tbv_ref(tbv_ref&& d) noexcept : mgr(d.mgr), d(d.detach()) {} ~tbv_ref() { if (d) mgr.deallocate(d); } diff --git a/src/util/tptr.h b/src/util/tptr.h index 6213b2efa40..99abe34a7a2 100644 --- a/src/util/tptr.h +++ b/src/util/tptr.h @@ -19,7 +19,9 @@ Revision History: #pragma once +#include #include "util/machine.h" +#include "util/debug.h" #define TAG_SHIFT PTR_ALIGNMENT #define ALIGNMENT_VALUE (1 << PTR_ALIGNMENT) @@ -41,4 +43,22 @@ Revision History: #define UNBOXINT(PTR) static_cast(reinterpret_cast(PTR) >> PTR_ALIGNMENT) - +template +U unbox(T* ptr) { + return static_cast(reinterpret_cast(ptr) >> PTR_ALIGNMENT); +} + +template +unsigned get_tag(T* ptr) { + return reinterpret_cast(ptr) & TAG_MASK; +} + +template +T* box(U val, std::uintptr_t tag = 0) { + static_assert( sizeof(T*) >= sizeof(U) + PTR_ALIGNMENT ); + SASSERT_EQ(tag & PTR_MASK, 0); + T* ptr = reinterpret_cast((static_cast(val) << PTR_ALIGNMENT) | tag); + SASSERT_EQ(val, unbox(ptr)); // roundtrip of conversion integer -> pointer -> integer is not actually guaranteed by the C++ standard (but seems fine in practice, as indicated by previous usage of BOXINT/UNBOXINT) + SASSERT_EQ(tag, get_tag(ptr)); + return ptr; +} diff --git a/src/util/trail.h b/src/util/trail.h index 1aa7e44418a..43e6982342c 100644 --- a/src/util/trail.h +++ b/src/util/trail.h @@ -219,12 +219,12 @@ class push_back_vector : public trail { } }; -template +template class set_vector_idx_trail : public trail { - ptr_vector & m_vector; + V & m_vector; unsigned m_idx; public: - set_vector_idx_trail(ptr_vector & v, unsigned idx): + set_vector_idx_trail(V & v, unsigned idx): m_vector(v), m_idx(idx) { } diff --git a/src/util/uint_set.h b/src/util/uint_set.h index 6e64cc7ae30..73c3bce1fed 100644 --- a/src/util/uint_set.h +++ b/src/util/uint_set.h @@ -28,7 +28,7 @@ class uint_set : unsigned_vector { typedef unsigned data; - void swap(uint_set & other) { + void swap(uint_set & other) noexcept { unsigned_vector::swap(other); } @@ -318,7 +318,7 @@ class indexed_uint_set { m_size(0) {} - void insert(unsigned x) { + void insert_fresh(unsigned x) { SASSERT(!contains(x)); m_index.reserve(x + 1, UINT_MAX); m_elems.reserve(m_size + 1); @@ -327,6 +327,11 @@ class indexed_uint_set { m_size++; SASSERT(contains(x)); } + + void insert(unsigned x) { + if (!contains(x)) + insert_fresh(x); + } void remove(unsigned x) { SASSERT(contains(x)); @@ -345,6 +350,10 @@ class indexed_uint_set { SASSERT(index < m_size); return m_elems[index]; } + unsigned operator[](unsigned index) const { + SASSERT(index < m_size); + return m_elems[index]; + } bool contains(unsigned x) const { return x < m_index.size() && m_index[x] < m_size && m_elems[m_index[x]] == x; } void reset() { m_size = 0; } diff --git a/src/util/union_find.h b/src/util/union_find.h index 7e42e1bbad9..0c08ac44623 100644 --- a/src/util/union_find.h +++ b/src/util/union_find.h @@ -35,7 +35,7 @@ class union_find_default_ctx { _trail_stack m_stack; }; -template +template class union_find { Ctx & m_ctx; trail_stack & m_trail_stack; diff --git a/src/util/util.h b/src/util/util.h index 6d4efb671c7..a4bf78073ab 100644 --- a/src/util/util.h +++ b/src/util/util.h @@ -265,7 +265,7 @@ class scoped_ptr { return *this; } - scoped_ptr& operator=(scoped_ptr&& other) { + scoped_ptr& operator=(scoped_ptr&& other) noexcept { *this = other.detach(); return *this; }; @@ -276,7 +276,7 @@ class scoped_ptr { return tmp; } - void swap(scoped_ptr & p) { + void swap(scoped_ptr & p) noexcept { std::swap(m_ptr, p.m_ptr); } }; @@ -363,7 +363,7 @@ void set_fatal_error_handler(void (*pfn)(int error_code)); template -bool any_of(S& set, T const& p) { +bool any_of(S const& set, T const& p) { for (auto const& s : set) if (p(s)) return true; @@ -371,13 +371,21 @@ bool any_of(S& set, T const& p) { } template -bool all_of(S& set, T const& p) { +bool all_of(S const& set, T const& p) { for (auto const& s : set) if (!p(s)) return false; return true; } +template +R find(S const& set, std::function p) { + for (auto const& s : set) + if (p(s)) + return s; + throw default_exception("element not found"); +} + /** \brief Iterator for the [0..sz[0]) X [0..sz[1]) X ... X [0..sz[n-1]). it contains the current value. diff --git a/src/util/var_queue.h b/src/util/var_queue.h index 7245153ca03..0af4de3b837 100644 --- a/src/util/var_queue.h +++ b/src/util/var_queue.h @@ -21,20 +21,20 @@ Revision History: #include "util/heap.h" - +template class var_queue { typedef unsigned var; struct lt { - svector & m_activity; - lt(svector & act):m_activity(act) {} + ActivityVector & m_activity; + lt(ActivityVector & act):m_activity(act) {} bool operator()(var v1, var v2) const { return m_activity[v1] > m_activity[v2]; } }; heap m_queue; -public: +public: - var_queue(svector & act):m_queue(128, lt(act)) {} + var_queue(ActivityVector & act):m_queue(128, lt(act)) {} void activity_increased_eh(var v) { if (m_queue.contains(v)) @@ -68,6 +68,8 @@ class var_queue { void reset() { m_queue.reset(); } + + bool contains(var v) const { return m_queue.contains(v); } bool empty() const { return m_queue.empty(); } @@ -90,11 +92,12 @@ class var_queue { return out; } - using const_iterator = decltype(m_queue)::const_iterator; + using const_iterator = const int *; const_iterator begin() const { return m_queue.begin(); } const_iterator end() const { return m_queue.end(); } }; -inline std::ostream& operator<<(std::ostream& out, var_queue const& queue) { +template +inline std::ostream& operator<<(std::ostream& out, var_queue const& queue) { return queue.display(out); } diff --git a/src/util/vector.h b/src/util/vector.h index 1cb25a8c456..d684f43ebe8 100644 --- a/src/util/vector.h +++ b/src/util/vector.h @@ -30,9 +30,11 @@ Revision History: #include #include #include +#include #include "util/memory_manager.h" #include "util/hash.h" #include "util/z3_exception.h" +#include // disable warning for constant 'if' expressions. // these are used heavily in templates. @@ -40,6 +42,8 @@ Revision History: #pragma warning(disable:4127) #endif +template +using std_vector = std::vector>; #if 0 @@ -327,7 +331,7 @@ class vector { return *this; } - vector & operator=(vector && source) { + vector & operator=(vector && source) noexcept { if (this == &source) { return *this; } @@ -555,7 +559,7 @@ class vector { for(; pos != e; ++pos, ++prev) { *prev = std::move(*pos); } - reinterpret_cast(m_data)[SIZE_IDX]--; + pop_back(); } void erase(T const & elem) { @@ -565,6 +569,20 @@ class vector { } } + /** Erase all elements that satisfy the given predicate. Returns the number of erased elements. */ + template + SZ erase_if(UnaryPredicate should_erase) { + iterator i = begin(); + iterator const e = end(); + for (iterator j = begin(); j != e; ++j) + if (!should_erase(std::as_const(*j))) + *(i++) = std::move(*j); + SZ const count = e - i; + SASSERT_EQ(i - begin(), size() - count); + shrink(size() - count); + return count; + } + void shrink(SZ s) { if (m_data) { SASSERT(s <= reinterpret_cast(m_data)[SIZE_IDX]); @@ -753,7 +771,8 @@ using bool_vector = svector; template inline std::ostream& operator<<(std::ostream& out, svector const& v) { - for (unsigned u : v) out << u << " "; + for (auto const& x : v) + out << x << " "; return out; } diff --git a/src/util/zstring.cpp b/src/util/zstring.cpp index eaa5bb5eef0..8e08820f657 100644 --- a/src/util/zstring.cpp +++ b/src/util/zstring.cpp @@ -152,7 +152,7 @@ std::string zstring::encode() const { unsigned ch = m_buffer[i]; if (ch < 32 || ch >= 128 || ('\\' == ch && i + 1 < m_buffer.size() && 'u' == m_buffer[i+1])) { _flush(); - strm << "\\u{" << std::hex << ch << std::dec << "}"; + strm << "\\u{" << std::hex << ch << std::dec << '}'; } else { if (offset == 99) @@ -161,7 +161,7 @@ std::string zstring::encode() const { } } _flush(); - return strm.str(); + return std::move(strm).str(); } bool zstring::suffixof(zstring const& other) const {