Add `find_MAP` with close JAX integration and fix bug with Laplace fit #385
Conversation
Hey, nice one, yeah I agree, we should only have one.

The current behaviour when you only pass a subset of variables isn't really desirable in my opinion (see #345 (comment)), so we put a warning. So as you say:

Agree, that's the best plan. Judging by your docs and a quick glance at your code, I think you're basically doing the same thing. The current implementation is a few lines of code and a few docs, so I reckon it should then be safe to delete the existing code and we can go back to one.
I would love a generic optimiser in pure pytensor, but I can see looking at your code that there are a lot of fancy extras that would take a large effort to write in pytensor. Still, if we want to go back to one of our efforts with a fixed point operator (pymc-devs/pytensor#978 and pymc-devs/pytensor#944), we could probably write one. Happy to look at your code and review properly later in the week if you'd like me to. Let me know. Otherwise, I'll leave it to the core devs.
That would be appreciated.
Agree with what @theorashid said.
No objections about your custom library wrapper.
Tagging @theorashid -- I couldn't pick you as a reviewer? I did a major refactor of this. I broke the marriage to jax and generalized the find_MAP function. Files have been renamed to reflect this. I also merged the two laplace approaches. The biggest change is that I removed the ability to choose a subset of model variables to fit.
Yeah, sorry, I'm just a normal contributor, but I'll give it a review. Will do it at some point in the next 2 weeks.
Minor suggestions, PR looks amazing!
@jessegrabowski can we close #376 with this PR? Do you have a test that covers something like it?
- Rename function `laplace` -> `sample_laplace_posterior`
Changed the title: Add `find_MAP` function -> Add `find_MAP` with close JAX integration and fix bug with Laplace fit
Sweet, all done?
For now, though I'd still appreciate it if you could have a look and open issues on any bugs/shortcomings you find.
I managed to follow the code through and it looks good to me. Happy you got rid of the option to fit on a subset of variables, which didn't make sense to me anyway. If it passes the original test then it should be good. You can do something about the other comments if you want, but maybe not, because we are experimental.
return f_loss_and_grad, f_hess, f_hessp

def _compile_functions(
NIT: Maybe `_compile_functions` and `_compile_jax_gradients` are slightly too generic function names. I found it a little tricky to remember exactly what they were doing when reading through the code.
use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]

if use_hess and use_hessp:
I was going through all the methods thinking about when you would need `hess` and `hessp`, and then came back to this. I would probably warn the user / not let them pass both `use_hess` and `use_hessp`.
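A minimal sketch of that guard, reusing the names from the quoted snippet (`use_hess`, `use_hessp`); whether to warn and fall back or to raise outright is the author's call:

```python
import warnings

if use_hess and use_hessp:
    # scipy.optimize.minimize only consumes one of hess/hessp, so
    # compiling both wastes work. Prefer the Hessian-vector product,
    # which avoids materializing the full Hessian.
    warnings.warn(
        "use_hess and use_hessp were both set; only the Hessian-vector "
        "product will be compiled and passed to the optimizer.",
        UserWarning,
    )
    use_hess = False
```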
return idata

def fit_mvn_to_MAP(
`fit_mvn_at_MAP`? I mean, technically this function just fits an MVN at a point; the user doesn't necessarily have to pass the MAP.
H_inv = get_nearest_psd(H_inv)
if on_bad_cov == "warn":
    _log.warning(
        "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
For my understanding, what sort of scenarios/models would produce a non-PSD Hessian? And is using the closest PSD matrix a good idea?
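For context, non-PSD Hessians typically arise when the optimizer terminates at a saddle point or along a nearly flat direction of the log-posterior (e.g. weakly identified parameters), in which case the Gaussian approximation is questionable regardless of the fix. `get_nearest_psd` itself isn't shown in this thread; a minimal sketch of the standard eigenvalue-clipping projection it presumably performs (the function name matches the quoted code, the body is an assumption):

```python
import numpy as np

def get_nearest_psd(A: np.ndarray) -> np.ndarray:
    """Project a symmetric matrix onto the PSD cone by zeroing out
    negative eigenvalues (the Frobenius-nearest PSD matrix)."""
    A_sym = (A + A.T) / 2  # symmetrize; numerical Hessians rarely are
    eigvals, eigvecs = np.linalg.eigh(A_sym)
    return eigvecs @ np.diag(np.clip(eigvals, 0.0, None)) @ eigvecs.T
```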
Parameters
----------
mu
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add docs here
and 1).

.. warning::
    This argumnet should be considered highly experimental. It has not been verified if this method produces
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
*argument
gradient_backend: str, default "pytensor"
    The backend to use for gradient computations. Must be one of "pytensor" or "jax".
chains: int, default: 2
    The number of sampling chains running in parallel.
I'd probably add something here reiterating that this isn't a sampling-based inference method: it just samples from the approximated posterior. There were already people on the forum asking about the differences between these methods.
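As a concrete illustration of the distinction the docs could draw (a hypothetical snippet; `fit_laplace`, `chains`, and `gradient_backend` come from the quoted docstring, while the import path and exact call signature are assumptions):

```python
import pymc as pm
from pymc_extras.inference import fit_laplace  # assumed import path

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu, 1, observed=[0.2, -0.5, 0.1])

    # MCMC: a step sampler explores the exact posterior.
    idata_mcmc = pm.sample(chains=2)

    # Laplace: find the MAP, fit a Gaussian there, then draw i.i.d.
    # samples from that Gaussian. `chains` only controls how draws are
    # grouped in the returned InferenceData; no sampler is run.
    idata_laplace = fit_laplace(chains=2)
```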
@pytest.mark.parametrize(
    "method, use_grad, use_hess",
Any `use_hessp` tests? Or are we just testing whether scipy works here?
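For instance, the quoted parametrization could grow a `use_hessp` column so the Hessian-vector-product path is exercised too; a hypothetical sketch (the method names are valid `scipy.optimize.minimize` methods, the test body is elided):

```python
import pytest

@pytest.mark.parametrize(
    "method, use_grad, use_hess, use_hessp",
    [
        ("nelder-mead", False, False, False),  # gradient-free
        ("L-BFGS-B", True, False, False),      # gradient only
        ("trust-krylov", True, False, True),   # Hessian-vector product
        ("trust-exact", True, True, False),    # full Hessian
    ],
)
def test_find_MAP(method, use_grad, use_hess, use_hessp):
    ...
```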
Closes #376
This PR adds code to run `find_MAP` using JAX. I'm using JAX for gradients, because I found the compile times were faster. Open to suggestions/rebuke.

It also adds a `fit_laplace` function, which is bad because we already have a `fit_laplace` function. This one has a slightly different objective though -- it isn't meant to be used as a step sampler on a subset of model variables. Instead, it is meant to be used on the MAP result to give an approximation to the full posterior. My function also lets you do the Laplace approximation in the transformed space, then do sample-wise reverse transformation. I think this is legit, and it lets you obtain approximate posteriors that respect the domain of the prior. Tagging @theorashid so we can resolve the differences.

The last point is that I added a dependency on `better_optimize`. This is a package I wrote that basically rips out the wrapper code used in PyMC's `find_MAP` and applies it to arbitrary optimization problems. It is more feature complete than the PyMC wrapper -- it supports all optimizer modes for `scipy.optimize.minimize` and `scipy.optimize.root`, and it also helps get keywords to the right place in those functions (who can ever remember if an argument goes in `method_kwargs` or in the function itself?). I plan to add support for `basinhopping` as well, which will be nice for really hairy minimizations.

I could see an objection to adding another dependency, but 1) it's a lightweight wrapper around functionality that doesn't really belong in PyMC anyway, and 2) it's a big value-add compared to working directly with the `scipy.optimize` functions, which have gnarly, inconsistent signatures.
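A hypothetical end-to-end sketch of the workflow this PR describes (the function names and `gradient_backend` option come from the PR itself; the import path and exact signatures are assumptions):

```python
import pymc as pm
from pymc_extras.inference import find_MAP, fit_laplace  # assumed import path

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 10)
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("y", mu, sigma, observed=[1.2, 0.7, 1.9, 0.4])

    # Optimize the log-posterior: JAX supplies the gradients, while the
    # better_optimize wrapper routes keywords to scipy.optimize.minimize.
    map_point = find_MAP(method="L-BFGS-B", gradient_backend="jax")

    # Fit a multivariate normal at the MAP in the transformed space,
    # then reverse-transform each draw back to the constrained space.
    idata = fit_laplace()
```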