Keras implementation of retention networks.
This has no affiliation with either keras or the authors of the original work.
Note there are currently no publicly available weights to load. If you are aware of any, please let us know in the issues. While considerable effort has gone into ensuring consistency with the original, it's difficult to say if this is the case or not without integration tests with pretrained weights.
While this repository supports all three keras backends (tensorflow, pytorch and jax), tensorflow and pytorch are supported via jax2tf and jax2torch respectively, so all backends need jax (most of these functions could be rewritten in their native backends, but that's not a high priority right now). Data processing in the examples is done via tensorflow. Once you've installed jax, keras-core and keras-nlp, this package can be installed via
git clone https://github.com/jackd/keras-retnet.git
pip install -e keras-retnet
Note getting all backends to work in the same environment is non-trivial. I had success using conda to install jax and pip for tensorflow/torch (following conda installation instructions for tensorflow/torch tends to break jax installations).
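Before debugging backend issues, it can help to confirm which packages can be found in the current environment. The following is a minimal sketch using only the standard library. The helper name `available_modules` is hypothetical and not part of this package, and the module names listed are assumed to be the import names of the packages mentioned above.

```python
import importlib.util


def available_modules(names):
    """Map each top-level module name to whether it can be located, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed import names for the packages discussed above.
    modules = ["jax", "tensorflow", "torch", "keras_core", "keras_nlp"]
    for name, found in available_modules(modules).items():
        print(f"{name}: {'found' if found else 'missing'}")
```

This only checks that a module can be located. It does not verify that the installation actually works, so a broken jax install may still be reported as found.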
See examples/train.py and examples/generate.py. Things might work without specifying a keras backend, in which case behaviour reverts to tf.keras, but this hasn't been tested. jax is probably the most stable backend, since it doesn't use jax2torch or jax2tf.
KERAS_BACKEND=jax python examples/train.py --smoke
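The backend can also be chosen from within Python, provided the environment variable is set before keras is first imported. This is a minimal sketch, assuming keras-core reads `KERAS_BACKEND` at import time. `select_backend` is a hypothetical helper, not part of this package.

```python
import os

# Backend names assumed to be accepted by the KERAS_BACKEND setting.
VALID_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name="jax"):
    """Set KERAS_BACKEND; call this before keras_core is first imported."""
    if name not in VALID_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}; expected one of {VALID_BACKENDS}")
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")
# import keras_core as keras  # import only after the backend has been selected
```

Setting the variable after keras has already been imported is not expected to change the active backend, so the call belongs at the very top of a script.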
This package uses pre-commit to ensure commits meet minimum criteria. To install it, use
pip install pre-commit
pre-commit install
This will ensure git hooks are run before each commit. While it is not advised to do so, you can skip these hooks with
git commit --no-verify -m "commit message"