Skip to content

Commit

Permalink
Merge 69091e6 into 2aef8e2
Browse files Browse the repository at this point in the history
  • Loading branch information
ItsMrLin authored Sep 28, 2023
2 parents 2aef8e2 + 69091e6 commit 4c84ce0
Show file tree
Hide file tree
Showing 3 changed files with 703 additions and 684 deletions.
25 changes: 13 additions & 12 deletions botorch/acquisition/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from botorch.models.model import Model
from botorch.models.transforms.outcome import Standardize
from botorch.posteriors.gpytorch import GPyTorchPosterior, scalarize_posterior
from botorch.sampling import IIDNormalSampler, MCSampler
from botorch.sampling import IIDNormalSampler
from botorch.utils import apply_constraints
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from linear_operator.operators.dense_linear_operator import to_linear_operator
Expand Down Expand Up @@ -547,7 +547,8 @@ class LearnedObjective(MCAcquisitionObjective):
def __init__(
self,
pref_model: Model,
sampler: Optional[MCSampler] = None,
sample_shape: Optional[torch.Size] = None,
seed: Optional[int] = None,
):
r"""
Args:
Expand All @@ -564,13 +565,13 @@ def __init__(
super().__init__()
self.pref_model = pref_model
if isinstance(pref_model, DeterministicModel):
assert sampler is None
assert sample_shape is None
self.sampler = None
else:
if sampler is None:
self.sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
else:
self.sampler = sampler
if sample_shape is None:
sample_shape = torch.Size([1])
self.sampler = IIDNormalSampler(sample_shape=sample_shape, seed=seed)
self.sampler.batch_range_override = (1, -1)

def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
r"""Sample each element of samples.
Expand All @@ -583,7 +584,7 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
A `(sample_size * num_samples) x batch_shape x N`-dim Tensor of
objective values sampled from utility posterior using `pref_model`.
"""
if samples.dtype == torch.float32 and any(
if samples.dtype != torch.float64 and any(
d == torch.float64 for d in self.pref_model.dtypes_of_buffers
):
warnings.warn(
Expand All @@ -593,11 +594,11 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
)
samples = samples.to(torch.float64)

post = self.pref_model.posterior(samples)
posterior = self.pref_model.posterior(samples)
if isinstance(self.pref_model, DeterministicModel):
# return preference posterior mean
return post.mean.squeeze(-1)
return posterior.mean.squeeze(-1)
else:
# return preference posterior sample mean
samples = self.sampler(post).squeeze(-1)
# return preference posterior augmented samples
samples = self.sampler(posterior).squeeze(-1)
return samples.reshape(-1, *samples.shape[2:]) # batch_shape x N
40 changes: 33 additions & 7 deletions test/acquisition/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import Normalize
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils import apply_constraints
from botorch.utils.testing import _get_test_posterior, BotorchTestCase
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
Expand Down Expand Up @@ -457,12 +456,17 @@ def test_learned_preference_objective(self) -> None:
pref_model = self._get_pref_model(dtype=torch.float64)

og_sample_shape = 3
large_sample_shape = 256
batch_size = 2
n = 8
test_X = torch.rand(
torch.Size((og_sample_shape, batch_size, n, self.x_dim)),
dtype=torch.float64,
)
large_X = torch.rand(
torch.Size((large_sample_shape, batch_size, n, self.x_dim)),
dtype=torch.float64,
)

# test default setting where sampler =
# IIDNormalSampler(sample_shape=torch.Size([1]))
Expand All @@ -472,19 +476,35 @@ def test_learned_preference_objective(self) -> None:
self.assertEqual(
first_call_output.shape, torch.Size([og_sample_shape, batch_size, n])
)

# test when sampler has num_samples = 16
with self.subTest("SobolQMCNormalSampler"):
num_samples = 16
# Making sure the sampler has correct base_samples shape
self.assertEqual(
pref_obj.sampler.base_samples.shape,
torch.Size([1, og_sample_shape, 1, n]),
)
# Passing through a same-shaped X again shouldn't change the base sample
previous_base_samples = pref_obj.sampler.base_samples
another_test_X = torch.rand_like(test_X)
pref_obj(another_test_X)
self.assertIs(pref_obj.sampler.base_samples, previous_base_samples)

# test when sampler has multiple preference samples
with self.subTest("Multiple samples"):
num_samples = 256
pref_obj = LearnedObjective(
pref_model=pref_model,
sampler=SobolQMCNormalSampler(sample_shape=torch.Size([num_samples])),
sample_shape=torch.Size([num_samples]),
)
self.assertEqual(
pref_obj(test_X).shape,
torch.Size([num_samples * og_sample_shape, batch_size, n]),
)

avg_obj_val = pref_obj(large_X).mean(dim=0)
flipped_avg_obj_val = pref_obj(large_X.flip(dims=[0])).mean(dim=0)
# Check if they are approximately close.
# The variance is large hence the loose atol.
self.assertAllClose(avg_obj_val, flipped_avg_obj_val, atol=1e-2)

# test posterior mean
with self.subTest("PosteriorMeanModel"):
mean_pref_model = PosteriorMeanModel(model=pref_model)
Expand All @@ -493,11 +513,17 @@ def test_learned_preference_objective(self) -> None:
pref_obj(test_X).shape, torch.Size([og_sample_shape, batch_size, n])
)

# the order of samples shouldn't matter
avg_obj_val = pref_obj(large_X).mean(dim=0)
flipped_avg_obj_val = pref_obj(large_X.flip(dims=[0])).mean(dim=0)
# When we use the posterior mean objective, they should be very close
self.assertAllClose(avg_obj_val, flipped_avg_obj_val)

# cannot use a deterministic model together with a sampler
with self.subTest("deterministic model"), self.assertRaises(AssertionError):
LearnedObjective(
pref_model=mean_pref_model,
sampler=SobolQMCNormalSampler(sample_shape=torch.Size([num_samples])),
sample_shape=torch.Size([num_samples]),
)

def test_dtype_compatibility_with_PairwiseGP(self) -> None:
Expand Down
Loading

0 comments on commit 4c84ce0

Please sign in to comment.