Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of prior terms in ExactMarginalLogLikelihood #2039

Merged
merged 9 commits into from
Jul 29, 2022

Conversation

saitcakmak
Copy link
Collaborator

@saitcakmak saitcakmak commented Jun 14, 2022

This fixes an issue originally discovered in pytorch/botorch#1259. The bug relates to handling of prior terms in https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/mlls/exact_marginal_log_likelihood.py#L42-L43

When training batched model of n_batch batches, the res term (which is the log probability of observations under the prior) has shape n_batch. The prior term prior.log_prob(closure(module)) also has a shape that is n_batch x (potential other dimensions). When we add the sum of all prior terms via res.add_ we end up adding each prior term to each one of the terms in res, when they should only be added to the term corresponding to the same batch. This is corrected by replacing the sum() with a series of sums that reduce the prior term to have the same ndim as res.

Test plan:
Units

cc @Balandat

@saitcakmak
Copy link
Collaborator Author

This seems to be the same issue as #1318, so that should also get fixed.

Copy link
Collaborator

@Balandat Balandat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@Balandat Balandat merged commit 271f53d into cornellius-gp:master Jul 29, 2022
@saitcakmak saitcakmak deleted the mll_bug branch August 31, 2022 21:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants