Releases: evanatyourservice/psgd_jax
psgd-jax 0.2.9
What's Changed
- Replaced `normalize_grads` with clipping the outputs by RMS. This is more stable and more accurate, and works in a wider variety of situations. Normalizing the input gradients is worse because it discards information that is valuable to the preconditioner.
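A minimal JAX sketch contrasting the two approaches. It assumes that "unit norm" means the global L2 norm across the whole gradient pytree, and that RMS clipping only scales an update down when its root-mean-square exceeds a threshold. The helper names and the `max_rms` default are hypothetical, not psgd-jax's API.

```python
import jax
import jax.numpy as jnp


def normalize_to_unit_norm(grads, eps=1e-12):
    """Input-side normalization (hypothetical): rescale the gradient pytree to unit global L2 norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))
    # Every leaf is divided by the same scalar, so the gradient's magnitude
    # is discarded before the preconditioner ever sees it.
    return jax.tree_util.tree_map(lambda x: x / (global_norm + eps), grads)


def clip_by_rms(update, max_rms=1.0, eps=1e-12):
    """Output-side clipping (hypothetical): shrink an update only if its RMS exceeds max_rms."""
    rms = jnp.sqrt(jnp.mean(jnp.square(update)))
    scale = jnp.minimum(1.0, max_rms / (rms + eps))
    return update * scale


# Usage: clip each leaf of an already-preconditioned update independently.
updates = {"w": jnp.full((3, 3), 2.0), "b": jnp.full((3,), 0.1)}
clipped = jax.tree_util.tree_map(clip_by_rms, updates)
```

Clipping after preconditioning leaves the preconditioner's inputs untouched and only bounds the step that is actually applied; here `"w"` (RMS 2.0) is scaled down to RMS 1.0 while `"b"` (RMS 0.1) passes through unchanged.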
psgd-jax 0.2.8
What's Changed
- Added a `normalize_grads` option that normalizes incoming gradients to unit norm, which can help when gradients are poorly distributed.
- Removed the trust region in favor of `normalize_grads`.
- Deterministic preconditioner update
- Added damping based on machine precision to properly handle singular or near-singular $gg^T$
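To show why damping matters, here is a minimal sketch. The outer product $gg^T$ has rank 1, so it is singular for any gradient with more than one entry. This example swaps in simple Tikhonov (identity) damping with a scale tied to the dtype's machine epsilon. The `damped_outer` helper and its damping rule are hypothetical illustrations, not the damping psgd-jax actually uses.

```python
import jax.numpy as jnp


def damped_outer(g, dtype=jnp.float32):
    """Return g g^T + lam * I, with lam tied to the dtype's machine epsilon (hypothetical rule)."""
    g = jnp.asarray(g, dtype=dtype)
    outer = jnp.outer(g, g)  # rank 1, so singular whenever g has more than one entry
    eps = jnp.finfo(dtype).eps
    # Scale the damping with the typical squared gradient entry so it stays
    # relative to the data; the extra eps keeps lam > 0 even when g == 0.
    lam = jnp.sqrt(eps) * (jnp.mean(jnp.abs(g)) ** 2 + eps)
    return outer + lam * jnp.eye(g.shape[0], dtype=dtype)


g = jnp.array([1.0, 2.0, 3.0])
damped = damped_outer(g)
chol = jnp.linalg.cholesky(damped)  # finite because damped is positive definite
```

Tying the damping to `jnp.finfo(dtype).eps` keeps it just large enough to matter at the working precision, instead of relying on a hand-tuned constant.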
psgd-jax 0.2.7
What's Changed
- small improvements
psgd-jax 0.2.6
What's Changed
- New trust region clipping that should need very little (or no) tuning
- scalar bug fix from ClashLuke
psgd-jax 0.2.5
What's Changed
- Exposed the trust region clipping value. It usually doesn't need to be changed, so try a lower learning rate first; if training still looks unstable on hard problems or with rough gradients, you can try lowering this value.
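To illustrate what lowering a clipping value does, here is a minimal sketch. It assumes a smooth, elementwise tanh-based clip. The `soft_clip` function and its `clip_value` parameter are hypothetical stand-ins; psgd-jax's actual trust region function may have a different shape.

```python
import jax.numpy as jnp


def soft_clip(update, clip_value=1.0):
    """Smoothly squash each element of update into [-clip_value, clip_value].

    tanh is close to the identity for |x| << clip_value, so small entries pass
    through almost unchanged while large outliers saturate at +/- clip_value.
    """
    return clip_value * jnp.tanh(update / clip_value)


update = jnp.array([0.001, 0.5, 3.0, -50.0])
default = soft_clip(update)                   # clip_value = 1.0
tighter = soft_clip(update, clip_value=0.5)   # lower value for rough gradients
```

A smooth clip like this needs little tuning because it only bites on outliers; lowering `clip_value` tightens the bound on every element.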
psgd-jax 0.2.4
What's Changed
- Swapped `max_skew_triangular` for `memory_save_mode`, giving easy ways to use different preconditioner setups and save memory and compute.
- README updates
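As an illustration of how a memory-saving mode might trade dense Kronecker factors for diagonal ones, here is a sketch. The mode strings `"one_diag"` and `"all_diag"` and both helper functions are assumptions made for this example; check the psgd-jax README for the values `memory_save_mode` actually accepts.

```python
from typing import List, Optional, Sequence


def choose_factor_types(shape: Sequence[int], memory_save_mode: Optional[str] = None) -> List[str]:
    """Pick a 'triangular' (dense) or 'diagonal' Kronecker factor for each dimension (hypothetical)."""
    if len(shape) == 0:
        return []
    if memory_save_mode is None:
        return ["triangular"] * len(shape)
    if memory_save_mode == "one_diag":
        # Make only the largest dimension diagonal, since its dense factor costs the most.
        largest = max(range(len(shape)), key=lambda i: shape[i])
        return ["diagonal" if i == largest else "triangular" for i in range(len(shape))]
    if memory_save_mode == "all_diag":
        return ["diagonal"] * len(shape)
    raise ValueError(f"unknown memory_save_mode: {memory_save_mode!r}")


def factor_memory(shape: Sequence[int], types: Sequence[str]) -> int:
    """Count stored entries: d*(d+1)/2 for a triangular factor, d for a diagonal one."""
    return sum(d * (d + 1) // 2 if t == "triangular" else d for d, t in zip(shape, types))


shape = (1024, 256)
for mode in (None, "one_diag", "all_diag"):
    types = choose_factor_types(shape, mode)
    print(mode, types, factor_memory(shape, types))
```

Making only the largest dimension diagonal removes most of the memory cost while keeping a full factor on the smaller side.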
psgd-jax 0.2.3
What's Changed
- remove use of opt_einsum
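For context, `jnp.einsum` alone can express multi-operand contractions, so no direct `opt_einsum` import is needed for them. A typical case is applying a Kronecker-factored preconditioner to a matrix gradient. This minimal sketch assumes upper-triangular factors $Q_1, Q_2$ and a preconditioner $P = (Q_1^T Q_1) \otimes (Q_2^T Q_2)$ as in PSGD's Kronecker form. The function name is hypothetical, and this is not psgd-jax's actual code.

```python
import jax
import jax.numpy as jnp


def apply_kron_preconditioner(q1, q2, grad):
    """Compute (Q1^T Q1) @ grad @ (Q2^T Q2) for a 2-D gradient with one jnp.einsum call.

    q1: (m, m) factor for the rows, q2: (n, n) factor for the columns, grad: (m, n).
    """
    return jnp.einsum("ji,jk,kl,nl,nm->im", q1, q1, grad, q2, q2)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q1 = jnp.triu(jax.random.normal(k1, (4, 4)))
q2 = jnp.triu(jax.random.normal(k2, (3, 3)))
grad = jax.random.normal(k3, (4, 3))
precond = apply_kron_preconditioner(q1, q2, grad)
reference = q1.T @ q1 @ grad @ q2.T @ q2  # same product written as explicit matmuls
```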
psgd-jax 0.2.2
What's Changed
- Simplify code a little bit
psgd-jax 0.2.1
What's Changed
- Added a min ndim triangular argument to better route bias and scale parameters to diagonal preconditioners without also catching linear parameters with shapes like (512, 1)
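Here is a sketch of the routing rule as described, assuming the argument compares a parameter's number of dimensions against a threshold. The function `uses_diag_preconditioner`, the parameter name `min_ndim_triangular`, and the contrasting `squeezed_rule` heuristic are all hypothetical names chosen for illustration.

```python
from typing import Sequence


def uses_diag_preconditioner(shape: Sequence[int], min_ndim_triangular: int = 2) -> bool:
    """Route params with fewer than min_ndim_triangular dims to a diagonal preconditioner."""
    return len(shape) < min_ndim_triangular


def squeezed_rule(shape: Sequence[int]) -> bool:
    """Hypothetical contrast: counting only dims > 1 lumps (512, 1) in with biases."""
    return len([d for d in shape if d > 1]) < 2


shapes = {
    "bias": (512,),
    "layernorm_scale": (512,),
    "linear_to_one_output": (512, 1),
    "linear": (512, 2048),
    "scalar": (),
}
routing = {name: uses_diag_preconditioner(s) for name, s in shapes.items()}
```

Because the rule looks at `ndim` rather than at which dimensions have size 1, a (512, 1) linear weight keeps its triangular preconditioner while 1-D biases and scales go diagonal.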
psgd-jax 0.2.0
What's Changed
- Use jax instead of optax for init momentum
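Initializing momentum needs nothing beyond JAX's own tree utilities. Here is a minimal sketch. The `init_momentum` name and its optional `dtype` argument are hypothetical, and psgd-jax's actual init code may differ.

```python
import jax
import jax.numpy as jnp


def init_momentum(params, dtype=None):
    """Zero-initialized momentum buffers matching each parameter's shape.

    dtype=None keeps each parameter's own dtype; pass e.g. jnp.bfloat16 to
    store momentum in lower precision.
    """
    return jax.tree_util.tree_map(lambda p: jnp.zeros_like(p, dtype=dtype), params)


params = {"dense": {"w": jnp.ones((4, 2)), "b": jnp.ones((2,))}}
mu = init_momentum(params)
```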