fix LearnedObjective base sample shape #2021

Status: Closed · 1 commit
fix LearnedObjective base sample shape (#2021)
Summary:

Explicitly specify the batch shape in LearnedObjective's sampler to prevent base samples from being shared across elements of the same batch.

Reviewed By: Balandat, esantorella

Differential Revision: D49574252
ItsMrLin authored and facebook-github-bot committed Sep 28, 2023
commit 69091e69f48d4f265fea2e5ecc36ca9dc30fcf3f
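
A minimal usage sketch of the revised API (the PairwiseGP training data below is illustrative, not taken from this diff):

import torch
from botorch.acquisition.objective import LearnedObjective
from botorch.models.pairwise_gp import PairwiseGP

# Hypothetical preference data: six 2-d points and three comparisons,
# where each row of train_comp is [winner_index, loser_index].
train_X = torch.rand(6, 2, dtype=torch.float64)
train_comp = torch.tensor([[0, 1], [2, 3], [4, 5]])
pref_model = PairwiseGP(train_X, train_comp)

# Previously a sampler instance was passed in; now the objective builds its
# own IIDNormalSampler from sample_shape/seed and overrides its batch range
# so base samples are not shared within a batch.
obj = LearnedObjective(
    pref_model=pref_model,
    sample_shape=torch.Size([16]),
    seed=1234,
)

test_X = torch.rand(3, 2, 8, 2, dtype=torch.float64)  # sample x batch x q x d
utility = obj(test_X)  # shape: (16 * 3) x 2 x 8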
25 changes: 13 additions & 12 deletions botorch/acquisition/objective.py
@@ -19,7 +19,7 @@
 from botorch.models.model import Model
 from botorch.models.transforms.outcome import Standardize
 from botorch.posteriors.gpytorch import GPyTorchPosterior, scalarize_posterior
-from botorch.sampling import IIDNormalSampler, MCSampler
+from botorch.sampling import IIDNormalSampler
 from botorch.utils import apply_constraints
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from linear_operator.operators.dense_linear_operator import to_linear_operator
@@ -547,7 +547,8 @@ class LearnedObjective(MCAcquisitionObjective):
     def __init__(
         self,
         pref_model: Model,
-        sampler: Optional[MCSampler] = None,
+        sample_shape: Optional[torch.Size] = None,
+        seed: Optional[int] = None,
     ):
         r"""
         Args:
@@ -564,13 +565,13 @@ def __init__(
         super().__init__()
         self.pref_model = pref_model
         if isinstance(pref_model, DeterministicModel):
-            assert sampler is None
+            assert sample_shape is None
             self.sampler = None
         else:
-            if sampler is None:
-                self.sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
-            else:
-                self.sampler = sampler
+            if sample_shape is None:
+                sample_shape = torch.Size([1])
+            self.sampler = IIDNormalSampler(sample_shape=sample_shape, seed=seed)
+            self.sampler.batch_range_override = (1, -1)
 
     def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
         r"""Sample each element of samples.
@@ -583,7 +584,7 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
             A `(sample_size * num_samples) x batch_shape x N`-dim Tensor of
             objective values sampled from utility posterior using `pref_model`.
         """
-        if samples.dtype == torch.float32 and any(
+        if samples.dtype != torch.float64 and any(
             d == torch.float64 for d in self.pref_model.dtypes_of_buffers
         ):
             warnings.warn(
@@ -593,11 +594,11 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
             )
             samples = samples.to(torch.float64)
 
-        post = self.pref_model.posterior(samples)
+        posterior = self.pref_model.posterior(samples)
         if isinstance(self.pref_model, DeterministicModel):
             # return preference posterior mean
-            return post.mean.squeeze(-1)
+            return posterior.mean.squeeze(-1)
         else:
-            # return preference posterior sample mean
-            samples = self.sampler(post).squeeze(-1)
+            # return preference posterior augmented samples
+            samples = self.sampler(posterior).squeeze(-1)
             return samples.reshape(-1, *samples.shape[2:])  # batch_shape x N
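
The effect of the batch-range override shows up in the sampler's base samples; a quick check mirroring the shape asserted in the updated test below (the PairwiseGP setup is again illustrative):

import torch
from botorch.acquisition.objective import LearnedObjective
from botorch.models.pairwise_gp import PairwiseGP

train_X = torch.rand(6, 2, dtype=torch.float64)
train_comp = torch.tensor([[0, 1], [2, 3], [4, 5]])
# Default sampler: IIDNormalSampler with sample_shape=torch.Size([1])
obj = LearnedObjective(pref_model=PairwiseGP(train_X, train_comp))

test_X = torch.rand(3, 2, 8, 2, dtype=torch.float64)  # sample x batch x q x d
obj(test_X)
# The outer sample dim (3) is kept in the base samples rather than collapsed,
# so draws are no longer shared across it (shape taken from the updated test).
assert obj.sampler.base_samples.shape == torch.Size([1, 3, 1, 8])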
40 changes: 33 additions & 7 deletions test/acquisition/test_objective.py
@@ -28,7 +28,6 @@
 from botorch.models.pairwise_gp import PairwiseGP
 from botorch.models.transforms.input import Normalize
 from botorch.posteriors import GPyTorchPosterior
-from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils import apply_constraints
 from botorch.utils.testing import _get_test_posterior, BotorchTestCase
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
@@ -457,12 +456,17 @@ def test_learned_preference_objective(self) -> None:
         pref_model = self._get_pref_model(dtype=torch.float64)
 
         og_sample_shape = 3
+        large_sample_shape = 256
         batch_size = 2
         n = 8
         test_X = torch.rand(
             torch.Size((og_sample_shape, batch_size, n, self.x_dim)),
             dtype=torch.float64,
         )
+        large_X = torch.rand(
+            torch.Size((large_sample_shape, batch_size, n, self.x_dim)),
+            dtype=torch.float64,
+        )
 
         # test default setting where sampler =
         # IIDNormalSampler(sample_shape=torch.Size([1]))
@@ -472,19 +476,35 @@ def test_learned_preference_objective(self) -> None:
             self.assertEqual(
                 first_call_output.shape, torch.Size([og_sample_shape, batch_size, n])
             )
+            # Making sure the sampler has correct base_samples shape
+            self.assertEqual(
+                pref_obj.sampler.base_samples.shape,
+                torch.Size([1, og_sample_shape, 1, n]),
+            )
+            # Passing through a same-shaped X again shouldn't change the base sample
+            previous_base_samples = pref_obj.sampler.base_samples
+            another_test_X = torch.rand_like(test_X)
+            pref_obj(another_test_X)
+            self.assertIs(pref_obj.sampler.base_samples, previous_base_samples)
 
-        # test when sampler has num_samples = 16
-        with self.subTest("SobolQMCNormalSampler"):
-            num_samples = 16
+        # test when sampler has multiple preference samples
+        with self.subTest("Multiple samples"):
+            num_samples = 256
             pref_obj = LearnedObjective(
                 pref_model=pref_model,
-                sampler=SobolQMCNormalSampler(sample_shape=torch.Size([num_samples])),
+                sample_shape=torch.Size([num_samples]),
             )
             self.assertEqual(
                 pref_obj(test_X).shape,
                 torch.Size([num_samples * og_sample_shape, batch_size, n]),
             )
+
+            avg_obj_val = pref_obj(large_X).mean(dim=0)
+            flipped_avg_obj_val = pref_obj(large_X.flip(dims=[0])).mean(dim=0)
+            # Check if they are approximately close.
+            # The variance is large hence the loose atol.
+            self.assertAllClose(avg_obj_val, flipped_avg_obj_val, atol=1e-2)
 
         # test posterior mean
         with self.subTest("PosteriorMeanModel"):
             mean_pref_model = PosteriorMeanModel(model=pref_model)
@@ -493,11 +513,17 @@
                 pref_obj(test_X).shape, torch.Size([og_sample_shape, batch_size, n])
             )
 
+            # the order of samples shouldn't matter
+            avg_obj_val = pref_obj(large_X).mean(dim=0)
+            flipped_avg_obj_val = pref_obj(large_X.flip(dims=[0])).mean(dim=0)
+            # When we use the posterior mean objective, they should be very close
+            self.assertAllClose(avg_obj_val, flipped_avg_obj_val)
+
         # cannot use a deterministic model together with a sampler
         with self.subTest("deterministic model"), self.assertRaises(AssertionError):
             LearnedObjective(
                 pref_model=mean_pref_model,
-                sampler=SobolQMCNormalSampler(sample_shape=torch.Size([num_samples])),
+                sample_shape=torch.Size([num_samples]),
             )
 
     def test_dtype_compatibility_with_PairwiseGP(self) -> None:
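
Relatedly, the new seed argument should make separately constructed objectives reproducible, since equal seeds yield equal base samples (a sketch under the same illustrative setup as above, not a check taken from this diff):

import torch
from botorch.acquisition.objective import LearnedObjective
from botorch.models.pairwise_gp import PairwiseGP

train_X = torch.rand(6, 2, dtype=torch.float64)
train_comp = torch.tensor([[0, 1], [2, 3], [4, 5]])
pref_model = PairwiseGP(train_X, train_comp)

obj_a = LearnedObjective(pref_model=pref_model, sample_shape=torch.Size([4]), seed=0)
obj_b = LearnedObjective(pref_model=pref_model, sample_shape=torch.Size([4]), seed=0)

X = torch.rand(2, 2, 8, 2, dtype=torch.float64)
# Same seed -> same base samples -> identical Monte Carlo utility values.
assert torch.allclose(obj_a(X), obj_b(X))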