-
Notifications
You must be signed in to change notification settings - Fork 561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement variance reduction in SLQ logdet backward pass. #1836
Conversation
a144a73
to
d542553
Compare
d542553
to
475ca01
Compare
@jacobrgardner @JonathanWenger ready for review |
475ca01
to
70b79c2
Compare
(actually ready for review now. I just fixed broken tests.) |
@gpleiss something seems off about how this is computing the preconditioner log determinant now. We're still computing it efficiently using the QR decomposition in the Entirely possible I just missed the relevant code here... |
Ideally, this would be the logdet value we'd return in the forward pass:
|
3e69d22
to
f2d487c
Compare
This reverts commit d2bff48.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks solid to me now 👍
I just profiled this PR on the KeOps example notebook - just to double check. It is just as fast as what is on is on master. |
Based on "Reducing the Variance of Gaussian Process Hyperparameter Optimization with Preconditioning" by Wenger et al., 2021.
When using iterative methods (i.e. CG/SLQ) to compute the log determinant, the forward pass currently computes:
logdet K \approx logdet P + SLQ( P^{-1/2} K P^{-1/2} )
,where
P
is a preconditioner, andSLQ
is a stochastic estimate of the log determinant. If the preconditioner is a good approximation of K, then this forward pass can be seen as a form of variance reduction.In this PR, we apply this same variance reduction strategy to the backward pass. We compute the backward pass as:
d logdet(K)/dtheta \approx d logdet(P)/dtheta + d SLQ/dtheta
TODOs:
inv_quad_logdet
function to apply variance reduction in the forward and backward passes.