Surjection layers for density estimation with normalizing flows
Surjectors is a lightweight library for density estimation using inference and generative surjective normalizing flows, i.e., flows that can reduce or increase dimensionality. Surjectors builds on Distrax and Haiku and is fully compatible with both of them.
Surjectors makes use of:
- Haiku's module system for neural networks,
- Distrax for probability distributions and some base bijectors,
- Optax for gradient-based optimization,
- JAX for autodiff and XLA computation.
Documentation can be found here.
You can find several self-contained examples of how to use the algorithms in the examples folder.
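To give a feel for what a dimensionality-reducing (inference) surjection computes, here is a minimal conceptual sketch in plain Python. It does not use the Surjectors API: the function names, the fixed `weight` parameter and the Gaussian conditional are all assumptions made for illustration, and in practice the conditional density would be parameterized by a neural network. The sketch slices a two-dimensional input down to one latent dimension and accounts for the dropped dimension with a conditional log-density.

```python
import math


def normal_logpdf(x: float, mean: float, std: float) -> float:
    """Log-density of a univariate normal distribution."""
    return (
        -0.5 * math.log(2.0 * math.pi)
        - math.log(std)
        - 0.5 * ((x - mean) / std) ** 2
    )


def slice_surjection_logprob(x, weight: float = 0.5, cond_std: float = 1.0):
    """Toy slicing surjection from 2 dimensions to 1 (hypothetical helper).

    The first coordinate is kept as the latent variable z. The second
    coordinate is dropped, and its contribution to the likelihood is the
    conditional log-density log p(x_drop | z).
    """
    x_keep, x_drop = x
    z = x_keep  # forward (inference) pass: reduce dimensionality
    base_logprob = normal_logpdf(z, 0.0, 1.0)  # standard normal base density
    # Likelihood contribution of the dropped dimension; the linear-Gaussian
    # conditional stands in for a learned conditional density.
    contribution = normal_logpdf(x_drop, weight * z, cond_std)
    return z, base_logprob + contribution


z, logprob = slice_surjection_logprob((0.3, -1.2))
print(f"latent: {z}, log-density: {logprob:.4f}")
```

For this toy model the resulting log-density is exact, because the joint density factorizes as p(z) p(x_drop | z).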
Make sure to have a working JAX installation. Depending on whether you want to use CPU/GPU/TPU, please follow these instructions.
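One quick way to check the JAX installation before installing Surjectors is a small script along these lines. The helper name `describe_jax_install` is made up for this sketch; it only relies on `jax.__version__`, `jax.default_backend()` and `jax.devices()`.

```python
import importlib.util


def describe_jax_install() -> str:
    """Return a short summary of the local JAX installation (hypothetical helper)."""
    # Check for the package first so the script also runs without JAX.
    if importlib.util.find_spec("jax") is None:
        return "JAX is not installed"

    import jax

    return (
        f"JAX {jax.__version__}, "
        f"default backend: {jax.default_backend()}, "
        f"devices: {jax.devices()}"
    )


print(describe_jax_install())
```

If the reported backend is not the one you expected (for instance `cpu` on a GPU machine), the JAX installation likely needs to be redone for the target accelerator.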
To install the package from PyPI, call:
pip install surjectors
To install the latest GitHub release, call the following on the command line:
pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue" (https://github.com/dirmeier/surjectors/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).
In order to contribute:
- clone Surjectors and install hatch via pip install hatch,
- create a new branch locally via git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
- implement your contribution and ideally a test case,
- test it by calling hatch run test on the (Unix) command line,
- submit a PR 🙂
Simon Dirmeier sfyrbnd @ pm me