From be679c72ade4c47b39dc907ff79df304ba87d80e Mon Sep 17 00:00:00 2001 From: sebastianament Date: Mon, 16 Oct 2023 13:18:26 -0700 Subject: [PATCH] qLogEHVI (#2036) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2036 This commit adds `qLogEHVI`, a member of the LogEI family of acquisition functions, for multi-objective optimization problems. Differential Revision: https://internalfb.com/D49967862 fbshipit-source-id: 3b45182b94a2289d10840c6414d95c1adeb749b8 --- botorch/acquisition/analytic.py | 8 + botorch/acquisition/logei.py | 23 +- botorch/acquisition/monte_carlo.py | 2 + botorch/acquisition/multi_objective/logei.py | 305 +++++ .../multi_objective/monte_carlo.py | 47 +- botorch/utils/multi_objective/hypervolume.py | 62 + botorch/utils/safe_math.py | 40 +- sphinx/source/acquisition.rst | 3 + .../acquisition/multi_objective/test_logei.py | 144 +++ .../multi_objective/test_monte_carlo.py | 1043 +++++++++-------- test/acquisition/test_logei.py | 2 - 11 files changed, 1157 insertions(+), 522 deletions(-) create mode 100644 botorch/acquisition/multi_objective/logei.py create mode 100644 test/acquisition/multi_objective/test_logei.py diff --git a/botorch/acquisition/analytic.py b/botorch/acquisition/analytic.py index ad848c94c3..2ad538acc2 100644 --- a/botorch/acquisition/analytic.py +++ b/botorch/acquisition/analytic.py @@ -362,6 +362,8 @@ class LogExpectedImprovement(AnalyticAcquisitionFunction): to avoid numerical issues in the computation of the acquisition value and its gradient in regions where improvement is predicted to be virtually impossible. + See [Ament2023logei]_ for details. Formally, + `LogEI(x) = log(E(max(f(x) - best_f, 0))),` where the expectation is taken over the value of stochastic function `f` at `x`. @@ -423,7 +425,10 @@ class LogConstrainedExpectedImprovement(AnalyticAcquisitionFunction): multi-outcome, with the index of the objective and constraints passed to the constructor. + See [Ament2023logei]_ for details. Formally, + `LogConstrainedEI(x) = log(EI(x)) + Sum_i log(P(y_i \in [lower_i, upper_i]))`, + where `y_i ~ constraint_i(x)` and `lower_i`, `upper_i` are the lower and upper bounds for the i-th constraint, respectively. @@ -569,7 +574,10 @@ class LogNoisyExpectedImprovement(AnalyticAcquisitionFunction): `q=1`. Assumes that the posterior distribution of the model is Gaussian. The model must be single-outcome. + See [Ament2023logei]_ for details. Formally, + `LogNEI(x) = log(E(max(y - max Y_base), 0))), (y, Y_base) ~ f((x, X_base))`, + where `X_base` are previously observed points. Note: This acquisition function currently relies on using a FixedNoiseGP (required diff --git a/botorch/acquisition/logei.py b/botorch/acquisition/logei.py index e12228c1a5..677e356c33 100644 --- a/botorch/acquisition/logei.py +++ b/botorch/acquisition/logei.py @@ -4,7 +4,15 @@ # LICENSE file in the root directory of this source tree. r""" -Batch implementations of the LogEI family of improvements-based acquisition functions. +Monte-Carlo variants of the LogEI family of improvements-based acquisition functions, +see [Ament2023logei]_ for details. + +References + +.. [Ament2023logei] + S. Ament, S. Daulton, D. Eriksson, M. Balandat, and E. Bakshy. + Unexpected Improvements to Expected Improvement for Bayesian Optimization. Advances + in Neural Information Processing Systems 36, 2023. """ from __future__ import annotations @@ -138,9 +146,11 @@ class qLogExpectedImprovement(LogImprovementMCAcquisitionFunction): (3) smoothly maximizing over q, and (4) averaging over the samples in log space. - `qLogEI(X) ~ log(qEI(X)) = log(E(max(max Y - best_f, 0)))`, + See [Ament2023logei]_ for details. Formally, + + `qLogEI(X) ~ log(qEI(X)) = log(E(max(max Y - best_f, 0)))`. - where `Y ~ f(X)`, and `X = (x_1,...,x_q)`. + where `Y ~ f(X)`, and `X = (x_1,...,x_q)`, . Example: >>> model = SingleTaskGP(train_X, train_Y) @@ -237,8 +247,11 @@ class qLogNoisyExpectedImprovement( to the canonical improvement over previously observed points is computed for each sample and the logarithm of the average is returned. - `qLogNEI(X) ~ log(qNEI(X)) = Log E(max(max Y - max Y_baseline, 0))`, where - `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)` + See [Ament2023logei]_ for details. Formally, + + `qLogNEI(X) ~ log(qNEI(X)) = Log E(max(max Y - max Y_baseline, 0))`, + + where `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)`. Example: >>> model = SingleTaskGP(train_X, train_Y) diff --git a/botorch/acquisition/monte_carlo.py b/botorch/acquisition/monte_carlo.py index 34c14f0402..8da9d67a1b 100644 --- a/botorch/acquisition/monte_carlo.py +++ b/botorch/acquisition/monte_carlo.py @@ -9,6 +9,8 @@ with (quasi) Monte-Carlo sampling. See [Rezende2014reparam]_, [Wilson2017reparam]_ and [Balandat2020botorch]_. +References + .. [Rezende2014reparam] D. J. Rezende, S. Mohamed, and D. Wierstra. Stochastic backpropagation and approximate inference in deep generative models. ICML 2014. diff --git a/botorch/acquisition/multi_objective/logei.py b/botorch/acquisition/multi_objective/logei.py new file mode 100644 index 0000000000..6bd3cfa733 --- /dev/null +++ b/botorch/acquisition/multi_objective/logei.py @@ -0,0 +1,305 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Multi-objective variants of the LogEI family of acquisition functions, see +[Ament2023logei]_ for details. +""" + +from __future__ import annotations + +from typing import Callable, List, Optional, Tuple, Union + +import torch +from botorch.acquisition.logei import TAU_MAX, TAU_RELU +from botorch.acquisition.multi_objective import MultiObjectiveMCAcquisitionFunction +from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective +from botorch.models.model import Model +from botorch.sampling.base import MCSampler +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + NondominatedPartitioning, +) +from botorch.utils.multi_objective.hypervolume import SubsetIndexCachingMixin +from botorch.utils.objective import compute_smoothed_feasibility_indicator +from botorch.utils.safe_math import ( + fatmin, + log_fatplus, + log_softplus, + logdiffexp, + logmeanexp, + logplusexp, + smooth_amin, +) +from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform +from torch import Tensor + + +class qLogExpectedHypervolumeImprovement( + MultiObjectiveMCAcquisitionFunction, SubsetIndexCachingMixin +): + def __init__( + self, + model: Model, + ref_point: Union[List[float], Tensor], + partitioning: NondominatedPartitioning, + sampler: Optional[MCSampler] = None, + objective: Optional[MCMultiOutputObjective] = None, + constraints: Optional[List[Callable[[Tensor], Tensor]]] = None, + X_pending: Optional[Tensor] = None, + eta: Optional[Union[Tensor, float]] = 1e-2, + fat: bool = True, + tau_relu: float = TAU_RELU, + tau_max: float = TAU_MAX, + ) -> None: + r"""Parallel Log Expected Hypervolume Improvement supporting m>=2 outcomes. + + See [Ament2023logei]_ for details and the methodology behind the LogEI family of + acquisition function. Line-by-line differences to the original differentiable + expected hypervolume formulation of [Daulton2020qehvi]_ are described via inline + comments in `forward`. + + Example: + >>> model = SingleTaskGP(train_X, train_Y) + >>> ref_point = [0.0, 0.0] + >>> acq = qLogExpectedHypervolumeImprovement(model, ref_point, partitioning) + >>> value = acq(test_X) + + Args: + model: A fitted model. + ref_point: A list or tensor with `m` elements representing the reference + point (in the outcome space) w.r.t. to which compute the hypervolume. + This is a reference point for the objective values (i.e. after + applying`objective` to the samples). + partitioning: A `NondominatedPartitioning` module that provides the non- + dominated front and a partitioning of the non-dominated space in hyper- + rectangles. If constraints are present, this partitioning must only + include feasible points. + sampler: The sampler used to draw base samples. If not given, + a sampler is generated using `get_sampler`. + objective: The MCMultiOutputObjective under which the samples are evaluated. + Defaults to `IdentityMultiOutputObjective()`. + constraints: A list of callables, each mapping a Tensor of dimension + `sample_shape x batch-shape x q x m` to a Tensor of dimension + `sample_shape x batch-shape x q`, where negative values imply + feasibility. The acqusition function will compute expected feasible + hypervolume. + X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points that have + points that have been submitted for function evaluation but have not yet + been evaluated. Concatenated into `X` upon forward call. Copied and set + to have no gradient. + eta: The temperature parameter for the sigmoid function used for the + differentiable approximation of the constraints. In case of a float the + same eta is used for every constraint in constraints. In case of a + tensor the length of the tensor must match the number of provided + constraints. The i-th constraint is then estimated with the i-th + eta value. + fat: Toggles the logarithmic / linear asymptotic behavior of the smooth + approximation to the ReLU and the maximum. + tau_relu: Temperature parameter controlling the sharpness of the + approximation to the ReLU over the `q` candidate points. For further + details, see the comments above the definition of `TAU_RELU`. + tau_max: Temperature parameter controlling the sharpness of the + approximation to the `max` operator over the `q` candidate points. + For further details, see the comments above the definition of `TAU_MAX`. + """ + if len(ref_point) != partitioning.num_outcomes: + raise ValueError( + "The dimensionality of the reference point must match the number of " + f"outcomes. Got ref_point with {len(ref_point)} elements, but expected " + f"{partitioning.num_outcomes}." + ) + ref_point = torch.as_tensor( + ref_point, + dtype=partitioning.pareto_Y.dtype, + device=partitioning.pareto_Y.device, + ) + super().__init__( + model=model, + sampler=sampler, + objective=objective, + constraints=constraints, + eta=eta, + X_pending=X_pending, + ) + self.register_buffer("ref_point", ref_point) + cell_bounds = partitioning.get_hypercell_bounds() + self.register_buffer("cell_lower_bounds", cell_bounds[0]) + self.register_buffer("cell_upper_bounds", cell_bounds[1]) + SubsetIndexCachingMixin.__init__(self) + self.tau_relu = tau_relu + self.tau_max = tau_max + self.fat = fat + + def _compute_log_qehvi(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor: + r"""Compute the expected (feasible) hypervolume improvement given MC samples. + + Args: + samples: A `sample_shape x batch_shape x q' x m`-dim tensor of samples. + X: A `batch_shape x q x d`-dim tensor of inputs. + + Returns: + A `batch_shape x (model_batch_shape)`-dim tensor of expected hypervolume + improvement for each batch. + """ + # Note that the objective may subset the outcomes (e.g. this will usually happen + # if there are constraints present). + obj = self.objective(samples, X=X) # mc_samples x batch_shape x q x m + q = obj.shape[-2] + if self.constraints is not None: + log_feas_weights = compute_smoothed_feasibility_indicator( + constraints=self.constraints, + samples=samples, + eta=self.eta, + log=True, + fat=self.fat, + ) + device = self.ref_point.device + q_subset_indices = self.compute_q_subset_indices(q_out=q, device=device) + batch_shape = obj.shape[:-2] # mc_samples x batch_shape + # areas tensor is `mc_samples x batch_shape x num_cells x 2`-dim + log_areas_per_segment = torch.full( + size=( + *batch_shape, + self.cell_lower_bounds.shape[-2], # num_cells + 2, # for even and odd terms + ), + fill_value=-torch.inf, + dtype=obj.dtype, + device=device, + ) + + cell_batch_ndim = self.cell_lower_bounds.ndim - 2 + # conditionally adding mc_samples dim if cell_batch_ndim > 0 + # adding ones to shape equal in number to to batch_shape_ndim - cell_batch_ndim + # adding cell_bounds batch shape w/o 1st dimension + sample_batch_view_shape = torch.Size( + [ + batch_shape[0] if cell_batch_ndim > 0 else 1, + *[1 for _ in range(len(batch_shape) - max(cell_batch_ndim, 1))], + *self.cell_lower_bounds.shape[1:-2], + ] + ) + view_shape = ( + *sample_batch_view_shape, + self.cell_upper_bounds.shape[-2], # num_cells + 1, # adding for q_choose_i dimension + self.cell_upper_bounds.shape[-1], # num_objectives + ) + + for i in range(1, self.q_out + 1): + # TODO: we could use batches to compute (q choose i) and (q choose q-i) + # simultaneously since subsets of size i and q-i have the same number of + # elements. This would decrease the number of iterations, but increase + # memory usage. + q_choose_i = q_subset_indices[f"q_choose_{i}"] # q_choose_i x i + # this tensor is mc_samples x batch_shape x i x q_choose_i x m + obj_subsets = obj.index_select(dim=-2, index=q_choose_i.view(-1)) + obj_subsets = obj_subsets.view( + obj.shape[:-2] + q_choose_i.shape + obj.shape[-1:] + ) # mc_samples x batch_shape x q_choose_i x i x m + + # NOTE: the order of operations in non-log _compute_qehvi is 3), 1), 2). + # since 3) moved above 1), _log_improvement adds another Tensor dimension + # that keeps track of num_cells. + + # 1) computes log smoothed improvement over the cell lower bounds. + # mc_samples x batch_shape x num_cells x q_choose_i x i x m + log_improvement_i = self._log_improvement(obj_subsets, view_shape) + + # 2) take the minimum log improvement over all i subsets. + # since all hyperrectangles share one vertex, the opposite vertex of the + # overlap is given by the component-wise minimum. + # negative of maximum of negative log_improvement is approximation to min. + log_improvement_i = self._smooth_min( + log_improvement_i, + dim=-2, + ) # mc_samples x batch_shape x num_cells x q_choose_i x m + + # 3) compute the log lengths of the cells' sides. + # mc_samples x batch_shape x num_cells x q_choose_i x m + log_lengths_i = self._log_cell_lengths(log_improvement_i, view_shape) + + # 4) take product over hyperrectangle side lengths to compute area (m-dim). + # after, log_areas_i is mc_samples x batch_shape x num_cells x q_choose_i + log_areas_i = log_lengths_i.sum(dim=-1) # areas_i = lengths_i.prod(dim=-1) + + # 5) if constraints are present, apply a differentiable approximation of + # the indicator function. + if self.constraints is not None: + log_feas_subsets = log_feas_weights.index_select( + dim=-1, index=q_choose_i.view(-1) + ).view(log_feas_weights.shape[:-1] + q_choose_i.shape) + log_areas_i = log_areas_i + log_feas_subsets.unsqueeze(-3).sum(dim=-1) + + # 6) sum over all subsets of size i, i.e. reduce over q_choose_i-dim + # after, log_areas_i is mc_samples x batch_shape x num_cells + log_areas_i = torch.logsumexp(log_areas_i, dim=-1) # areas_i.sum(dim=-1) + + # 7) Using the inclusion-exclusion principle, set the sign to be positive + # for subsets of odd sizes and negative for subsets of even size + # in non-log space: areas_per_segment += (-1) ** (i + 1) * areas_i, + # but here in log space, we need to keep track of sign: + log_areas_per_segment[..., i % 2] = logplusexp( + log_areas_per_segment[..., i % 2], + log_areas_i, + ) + + # 8) subtract even from odd log area terms + log_areas_per_segment = logdiffexp( + log_a=log_areas_per_segment[..., 0], log_b=log_areas_per_segment[..., 1] + ) + + # 9) sum over segments (n_cells-dim) and average over MC samples + return logmeanexp(torch.logsumexp(log_areas_per_segment, dim=-1), dim=0) + + def _log_improvement( + self, obj_subsets: Tensor, view_shape: Union[Tuple, torch.Size] + ) -> Tensor: + # smooth out the clamp and take the log (previous step 3) + # substract cell lower bounds, clamp min at zero, but first + # make obj_subsets broadcastable with cell bounds: + # mc_samples x batch_shape x (num_cells = 1) x q_choose_i x i x m + obj_subsets = obj_subsets.unsqueeze(-4) + # making cell bounds broadcastable with obj_subsets: + # (mc_samples = 1) x (batch_shape = 1) x num_cells x 1 x (i = 1) x m + cell_lower_bounds = self.cell_lower_bounds.view(view_shape).unsqueeze(-3) + Z = obj_subsets - cell_lower_bounds + log_Zi = self._log_smooth_relu(Z) + return log_Zi # mc_samples x batch_shape x num_cells x q_choose_i x i x m + + def _log_cell_lengths( + self, log_improvement_i: Tensor, view_shape: Union[Tuple, torch.Size] + ) -> Tensor: + cell_upper_bounds = self.cell_upper_bounds.clamp_max( + 1e10 if log_improvement_i.dtype == torch.double else 1e8 + ) # num_cells x num_objectives + # add batch-dim to compute area for each segment (pseudo-pareto-vertex) + log_cell_lengths = ( + (cell_upper_bounds - self.cell_lower_bounds).log().view(view_shape) + ) # (mc_samples = 1) x (batch_shape = 1) x n_cells x (q_choose_i = 1) x m + # mc_samples x batch_shape x num_cells x q_choose_i x m + return self._smooth_minimum( + log_improvement_i, + log_cell_lengths, + ) + + def _log_smooth_relu(self, X: Tensor) -> Tensor: + f = log_fatplus if self.fat else log_softplus + return f(X, tau=self.tau_relu) + + def _smooth_min(self, X: Tensor, dim: int, keepdim: bool = False) -> Tensor: + f = fatmin if self.fat else smooth_amin + return f(X, tau=self.tau_max, dim=dim) + + def _smooth_minimum(self, X: Tensor, Y: Tensor) -> Tensor: + XY = torch.stack(torch.broadcast_tensors(X, Y), dim=-1) + return self._smooth_min(XY, dim=-1, keepdim=False) + + @concatenate_pending_points + @t_batch_mode_transform() + def forward(self, X: Tensor) -> Tensor: + posterior = self.model.posterior(X) + samples = self.get_posterior_samples(posterior) + return self._compute_log_qehvi(samples=samples, X=X) diff --git a/botorch/acquisition/multi_objective/monte_carlo.py b/botorch/acquisition/multi_objective/monte_carlo.py index 10e2f70e61..7ce2ff2915 100644 --- a/botorch/acquisition/multi_objective/monte_carlo.py +++ b/botorch/acquisition/multi_objective/monte_carlo.py @@ -26,7 +26,6 @@ import warnings from abc import ABC, abstractmethod from copy import deepcopy -from itertools import combinations from typing import Callable, List, Optional, Union import torch @@ -57,6 +56,7 @@ from botorch.utils.multi_objective.box_decompositions.utils import ( _pad_batch_pareto_frontier, ) +from botorch.utils.multi_objective.hypervolume import SubsetIndexCachingMixin from botorch.utils.objective import compute_smoothed_feasibility_indicator from botorch.utils.torch import BufferDict from botorch.utils.transforms import ( @@ -154,7 +154,9 @@ def forward(self, X: Tensor) -> Tensor: pass # pragma: no cover -class qExpectedHypervolumeImprovement(MultiObjectiveMCAcquisitionFunction): +class qExpectedHypervolumeImprovement( + MultiObjectiveMCAcquisitionFunction, SubsetIndexCachingMixin +): def __init__( self, model: Model, @@ -229,39 +231,7 @@ def __init__( cell_bounds = partitioning.get_hypercell_bounds() self.register_buffer("cell_lower_bounds", cell_bounds[0]) self.register_buffer("cell_upper_bounds", cell_bounds[1]) - self.q_out = -1 - self.q_subset_indices = BufferDict() - - def _cache_q_subset_indices(self, q_out: int) -> None: - r"""Cache indices corresponding to all subsets of `q_out`. - - This means that consecutive calls to `forward` with the same - `q_out` will not recompute the indices for all (2^q_out - 1) subsets. - - Note: this will use more memory than regenerating the indices - for each i and then deleting them, but it will be faster for - repeated evaluations (e.g. during optimization). - - Args: - q_out: The batch size of the objectives. This is typically equal - to the q-batch size of `X`. However, if using a set valued - objective (e.g., MVaR) that produces `s` objective values for - each point on the q-batch of `X`, we need to properly account - for each objective while calculating the hypervolume contributions - by using `q_out = q * s`. - """ - if q_out != self.q_out: - indices = list(range(q_out)) - tkwargs = {"dtype": torch.long, "device": self.ref_point.device} - self.q_subset_indices = BufferDict( - { - f"q_choose_{i}": torch.tensor( - list(combinations(indices, i)), **tkwargs - ) - for i in range(1, q_out + 1) - } - ) - self.q_out = q_out + SubsetIndexCachingMixin.__init__(self) def _compute_qehvi(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor: r"""Compute the expected (feasible) hypervolume improvement given MC samples. @@ -282,14 +252,15 @@ def _compute_qehvi(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor: feas_weights = compute_smoothed_feasibility_indicator( constraints=self.constraints, samples=samples, eta=self.eta ) # `sample_shape x batch-shape x q` - self._cache_q_subset_indices(q_out=q) + device = self.ref_point.device + q_subset_indices = self.compute_q_subset_indices(q_out=q, device=device) batch_shape = obj.shape[:-2] # this is n_samples x input_batch_shape x areas_per_segment = torch.zeros( *batch_shape, self.cell_lower_bounds.shape[-2], dtype=obj.dtype, - device=obj.device, + device=device, ) cell_batch_ndim = self.cell_lower_bounds.ndim - 2 sample_batch_view_shape = torch.Size( @@ -310,7 +281,7 @@ def _compute_qehvi(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor: # simultaneously since subsets of size i and q-i have the same number of # elements. This would decrease the number of iterations, but increase # memory usage. - q_choose_i = self.q_subset_indices[f"q_choose_{i}"] + q_choose_i = q_subset_indices[f"q_choose_{i}"] # this tensor is mc_samples x batch_shape x i x q_choose_i x m obj_subsets = obj.index_select(dim=-2, index=q_choose_i.view(-1)) obj_subsets = obj_subsets.view( diff --git a/botorch/utils/multi_objective/hypervolume.py b/botorch/utils/multi_objective/hypervolume.py index 3a160ecd72..7242926a71 100644 --- a/botorch/utils/multi_objective/hypervolume.py +++ b/botorch/utils/multi_objective/hypervolume.py @@ -22,10 +22,13 @@ from __future__ import annotations +from itertools import combinations + from typing import List, Optional import torch from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError +from botorch.utils.torch import BufferDict from torch import Tensor MIN_Y_RANGE = 1e-7 @@ -412,3 +415,62 @@ def reinsert(self, node: Node, index: int, bounds: Tensor) -> None: node.prev[i].next[i] = node node.next[i].prev[i] = node bounds.data = torch.min(bounds, node.data) + + +class SubsetIndexCachingMixin: + """A Mixin class that adds q-subset index computations and caching.""" + + def __init__(self): + """Initializes the class with q_out = -1 and an empty q_subset_indices dict.""" + self.q_out: int = -1 + self.q_subset_indices: BufferDict[str, Tensor] = BufferDict() + + def compute_q_subset_indices( + self, q_out: int, device: torch.device + ) -> BufferDict[str, Tensor]: + r"""Returns and caches a dict of indices equal to subsets of `{1, ..., q_out}`. + + This means that consecutive calls to `self.compute_q_subset_indices` with + the same `q_out` do not recompute the indices for all (2^q_out - 1) subsets. + + NOTE: This will use more memory than regenerating the indices + for each i and then deleting them, but it will be faster for + repeated evaluations (e.g. during optimization). + + Args: + q_out: The batch size of the objectives. This is typically equal + to the q-batch size of `X`. However, if using a set valued + objective (e.g., MVaR) that produces `s` objective values for + each point on the q-batch of `X`, we need to properly account + for each objective while calculating the hypervolume contributions + by using `q_out = q * s`. + + Returns: + A dict that maps "q choose i" to all size-i subsets of `{1, ..., q_out}`. + """ + if q_out != self.q_out: + self.q_subset_indices = compute_subset_indices(q_out, device=device) + self.q_out = q_out + return self.q_subset_indices + + +def compute_subset_indices( + q: int, device: Optional[torch.device] = None +) -> BufferDict[str, Tensor]: + r"""Compute all (2^q - 1) distinct subsets of {1, ..., `q`}. + + Args: + q: An integer defininig the set {1, ..., `q`} whose subsets to compute. + + Returns: + A dict that maps "q choose i" to all size-i subsets of {1, ..., `q_out`}. + """ + indices = torch.arange(q, dtype=torch.long, device=device) + return BufferDict( + { + f"q_choose_{i}": torch.tensor( + list(combinations(indices, i)), dtype=torch.long, device=device + ) + for i in range(1, q + 1) + } + ) diff --git a/botorch/utils/safe_math.py b/botorch/utils/safe_math.py index ac4d6bb11d..7c4c30c984 100644 --- a/botorch/utils/safe_math.py +++ b/botorch/utils/safe_math.py @@ -269,6 +269,16 @@ def smooth_amax( return logsumexp(X / tau, dim=dim, keepdim=keepdim) * tau # ~ X.amax(dim=dim) +def smooth_amin( + X: Tensor, + dim: Union[int, Tuple[int, ...]] = -1, + keepdim: bool = False, + tau: Union[float, Tensor] = 1.0, +) -> Tensor: + """A smooth approximation to `min(X, dim=dim)`, similar to `smooth_amax`.""" + return -smooth_amax(X=-X, dim=dim, keepdim=keepdim, tau=tau) + + def check_dtype_float32_or_float64(X: Tensor) -> None: if X.dtype != torch.float32 and X.dtype != torch.float64: raise UnsupportedError( @@ -316,7 +326,7 @@ def fatmax( """Computes a smooth approximation to amax(X, dim=dim) with a fat tail. Args: - X: A Tensor from which to compute the smoothed amax. + X: A Tensor from which to compute the smoothed maximum. dim: The dimensions to reduce over. keepdim: If True, keeps the reduced dimensions. tau: Temperature parameter controlling the smooth approximation @@ -327,7 +337,7 @@ def fatmax( recommended to keep this value low or moderate, e.g. < 10. Returns: - A Tensor of smooth approximations to `max(X, dim=dim)` with a fat tail. + A Tensor of smooth approximations to `amax(X, dim=dim)` with a fat tail. """ def max_fun( @@ -338,6 +348,32 @@ def max_fun( return _inf_max_helper(max_fun=max_fun, x=x, dim=dim, keepdim=keepdim) +def fatmin( + x: Tensor, + dim: Union[int, Tuple[int, ...]], + keepdim: bool = False, + tau: Union[float, Tensor] = TAU, + alpha: float = ALPHA, +) -> Tensor: + """Computes a smooth approximation to amin(X, dim=dim) with a fat tail. + + Args: + X: A Tensor from which to compute the smoothed minimum. + dim: The dimensions to reduce over. + keepdim: If True, keeps the reduced dimensions. + tau: Temperature parameter controlling the smooth approximation + to min operator, becomes tighter as tau goes to 0. Needs to be positive. + alpha: The exponent of the asymptotic power decay of the approximation. The + default value is 2. Higher alpha parameters make the function behave more + similarly to the standard logsumexp approximation to the max, so it is + recommended to keep this value low or moderate, e.g. < 10. + + Returns: + A Tensor of smooth approximations to `amin(X, dim=dim)` with a fat tail. + """ + return -fatmax(-x, dim=dim, keepdim=keepdim, tau=tau, alpha=alpha) + + def fatmaximum( a: Tensor, b: Tensor, tau: Union[float, Tensor] = TAU, alpha: float = ALPHA ) -> Tensor: diff --git a/sphinx/source/acquisition.rst b/sphinx/source/acquisition.rst index 0e4d72fa7d..4fe407eb79 100644 --- a/sphinx/source/acquisition.rst +++ b/sphinx/source/acquisition.rst @@ -90,6 +90,9 @@ Multi-Objective Monte-Carlo Acquisition Functions :members: :exclude-members: MultiObjectiveMCAcquisitionFunction +.. automodule:: botorch.acquisition.multi_objective.logei + :members: + Multi-Objective Multi-Fidelity Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.multi_fidelity diff --git a/test/acquisition/multi_objective/test_logei.py b/test/acquisition/multi_objective/test_logei.py new file mode 100644 index 0000000000..ec3f532544 --- /dev/null +++ b/test/acquisition/multi_objective/test_logei.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools + +import torch +from botorch.acquisition.multi_objective.logei import qLogExpectedHypervolumeImprovement +from botorch.acquisition.multi_objective.monte_carlo import ( + MultiObjectiveMCAcquisitionFunction, +) +from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective +from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + NondominatedPartitioning, +) +from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior + + +class DummyMultiObjectiveMCAcquisitionFunction(MultiObjectiveMCAcquisitionFunction): + def forward(self, X): + pass + + +class DummyMCMultiOutputObjective(MCMultiOutputObjective): + def forward(self, samples, X=None): + if X is not None: + return samples[..., : X.shape[-2], :] + else: + return samples + + +class TestLogQExpectedHypervolumeImprovement(BotorchTestCase): + def test_q_log_expected_hypervolume_improvement(self): + for dtype, fat in itertools.product((torch.float, torch.double), (True, False)): + with self.subTest(dtype=dtype, fat=fat): + self._qLogEHVI_test(dtype, fat) + + def _qLogEHVI_test(self, dtype: torch.dtype, fat: bool): + """NOTE: The purpose of this test is to test the numerical particularities + of the qLogEHVI. For further tests including the non-numerical features of the + acquisition function, please see the corresponding tests - unified with qEHVI - + in `multi_objective/test_monte_carlo.py`. + """ + tkwargs = {"device": self.device, "dtype": dtype} + ref_point = [0.0, 0.0] + t_ref_point = torch.tensor(ref_point, **tkwargs) + pareto_Y = torch.tensor( + [[4.0, 5.0], [5.0, 5.0], [8.5, 3.5], [8.5, 3.0], [9.0, 1.0]], **tkwargs + ) + partitioning = NondominatedPartitioning(ref_point=t_ref_point) + # the event shape is `b x q x m` = 1 x 1 x 2 + samples = torch.zeros(1, 1, 2, **tkwargs) + mm = MockModel(MockPosterior(samples=samples)) + partitioning.update(Y=pareto_Y) + + X = torch.zeros(1, 1, **tkwargs) + # basic test + sampler = IIDNormalSampler(sample_shape=torch.Size([1])) + acqf = qLogExpectedHypervolumeImprovement( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + fat=fat, + ) + res = acqf(X) + exp_log_res = res.exp().item() + + # The log value is never -inf due to the smooth approximations. + self.assertFalse(res.isinf().item()) + + # Due to the smooth approximation, the value at zero should be close to, but + # not exactly zero, and upper-bounded by the tau hyperparameter. + if fat: + self.assertTrue(0 < exp_log_res) + self.assertTrue(exp_log_res <= acqf.tau_relu) + else: # This is an interesting difference between the exp and the fat tail. + # Even though the log value is never -inf, softmax's exponential tail gives + # rise to a zero value upon the exponentiation of the log acquisition value. + self.assertEqual(0, exp_log_res) + + # similar test for q=2 + X2 = torch.zeros(2, 1, **tkwargs) + samples2 = torch.zeros(1, 2, 2, **tkwargs) + mm2 = MockModel(MockPosterior(samples=samples2)) + acqf.model = mm2 + self.assertEqual(acqf.model, mm2) + self.assertIn("model", acqf._modules) + self.assertEqual(acqf._modules["model"], mm2) + + # see detailed comments for the tests around the first set of test above. + res = acqf(X2) + exp_log_res = res.exp().item() + self.assertFalse(res.isinf().item()) + if fat: + self.assertTrue(0 < exp_log_res) + self.assertTrue(exp_log_res <= acqf.tau_relu) + else: # This is an interesting difference between the exp and the fat tail. + self.assertEqual(0, exp_log_res) + + X = torch.zeros(1, 1, **tkwargs) + samples = torch.zeros(1, 1, 2, **tkwargs) + mm = MockModel(MockPosterior(samples=samples)) + # basic test + sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345) + acqf = qLogExpectedHypervolumeImprovement( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + fat=fat, + ) + res = acqf(X) + # non-log EHVI is zero, but qLogEHVI is not -Inf. + self.assertFalse(res.isinf().item()) + exp_log_res = res.exp().item() + if fat: + self.assertTrue(0 < exp_log_res) + self.assertTrue(exp_log_res <= 1e-10) # should be *very* small + else: # This is an interesting difference between the exp and the fat tail. + self.assertEqual(0, exp_log_res) + + # basic test, qmc + sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2])) + acqf = qLogExpectedHypervolumeImprovement( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + fat=fat, + ) + res = acqf(X) + exp_log_res = res.exp().item() + # non-log EHVI is zero, but qLogEHVI is not -Inf. + self.assertFalse(res.isinf().item()) + + if fat: + self.assertTrue(0 < exp_log_res) + self.assertTrue(exp_log_res <= 1e-10) # should be *very* small + else: # This is an interesting difference between the exp and the fat tail. + self.assertEqual(0, exp_log_res) diff --git a/test/acquisition/multi_objective/test_monte_carlo.py b/test/acquisition/multi_objective/test_monte_carlo.py index fd84c2049a..c0219c1364 100644 --- a/test/acquisition/multi_objective/test_monte_carlo.py +++ b/test/acquisition/multi_objective/test_monte_carlo.py @@ -8,11 +8,14 @@ from copy import deepcopy from itertools import product from math import pi +from typing import Any, Dict, Optional, Type from unittest import mock import torch from botorch import settings +from botorch.acquisition import AcquisitionFunction from botorch.acquisition.cached_cholesky import _get_cache_root_not_supported_message +from botorch.acquisition.multi_objective.logei import qLogExpectedHypervolumeImprovement from botorch.acquisition.multi_objective.monte_carlo import ( MultiObjectiveMCAcquisitionFunction, qExpectedHypervolumeImprovement, @@ -50,8 +53,18 @@ FastNondominatedPartitioning, NondominatedPartitioning, ) +from botorch.utils.safe_math import sigmoid from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior from botorch.utils.transforms import match_batch_shape, standardize +from torch import Tensor + + +def evaluate(acqf, X: Tensor) -> Tensor: + return ( + acqf(X).exp() + if isinstance(acqf, qLogExpectedHypervolumeImprovement) + else acqf(X) + ) class DummyMultiObjectiveMCAcquisitionFunction(MultiObjectiveMCAcquisitionFunction): @@ -104,500 +117,580 @@ def test_init(self): model=mm, constraints=[lambda Z: -100.0 * torch.ones_like(Z[..., -1])] ) - -class TestQExpectedHypervolumeImprovement(BotorchTestCase): def test_q_expected_hypervolume_improvement(self): - tkwargs = {"device": self.device} for dtype in (torch.float, torch.double): - tkwargs["dtype"] = dtype - ref_point = [0.0, 0.0] - t_ref_point = torch.tensor(ref_point, **tkwargs) - pareto_Y = torch.tensor( - [[4.0, 5.0], [5.0, 5.0], [8.5, 3.5], [8.5, 3.0], [9.0, 1.0]], **tkwargs - ) - partitioning = NondominatedPartitioning(ref_point=t_ref_point) - # the event shape is `b x q x m` = 1 x 1 x 2 - samples = torch.zeros(1, 1, 2, **tkwargs) - mm = MockModel(MockPosterior(samples=samples)) - # test error if there is not pareto_Y initialized in partitioning - with self.assertRaises(BotorchError): - qExpectedHypervolumeImprovement( - model=mm, ref_point=ref_point, partitioning=partitioning - ) - partitioning.update(Y=pareto_Y) - # test error if ref point has wrong shape - with self.assertRaises(ValueError): - qExpectedHypervolumeImprovement( - model=mm, ref_point=ref_point[:1], partitioning=partitioning + with self.subTest(dtype=dtype): + self._test_q_expected_hypervolume_improvement( + acqf_class=qExpectedHypervolumeImprovement, dtype=dtype ) - X = torch.zeros(1, 1, **tkwargs) - # basic test - sampler = IIDNormalSampler(sample_shape=torch.Size([1])) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - ) - res = acqf(X) - self.assertEqual(res.item(), 0.0) - # check ref point - self.assertTrue( - torch.equal(acqf.ref_point, torch.tensor(ref_point, **tkwargs)) - ) - # check cached indices - self.assertTrue(hasattr(acqf, "q_subset_indices")) - self.assertIn("q_choose_1", acqf.q_subset_indices) - self.assertTrue( - torch.equal( - acqf.q_subset_indices["q_choose_1"], - torch.tensor([[0]], device=self.device), + # separating out LogEI test to avoid timeouts + def test_q_log_expected_hypervolume_improvement(self): + for dtype in (torch.float, torch.double): + with self.subTest(dtype=dtype): + self._test_q_expected_hypervolume_improvement( + acqf_class=qLogExpectedHypervolumeImprovement, + dtype=dtype, + # setting tau_max tighter than default to more easily unify tests + # with qEHVI. + acqf_kwargs={"tau_max": 1e-3, "fat": False}, ) - ) - # test q=2 - X2 = torch.zeros(2, 1, **tkwargs) - samples2 = torch.zeros(1, 2, 2, **tkwargs) - mm2 = MockModel(MockPosterior(samples=samples2)) - acqf.model = mm2 - self.assertEqual(acqf.model, mm2) - self.assertIn("model", acqf._modules) - self.assertEqual(acqf._modules["model"], mm2) - res = acqf(X2) - self.assertEqual(res.item(), 0.0) - # check cached indices - self.assertTrue(hasattr(acqf, "q_subset_indices")) - self.assertIn("q_choose_1", acqf.q_subset_indices) - self.assertTrue( - torch.equal( - acqf.q_subset_indices["q_choose_1"], - torch.tensor([[0], [1]], device=self.device), - ) - ) - self.assertIn("q_choose_2", acqf.q_subset_indices) - self.assertTrue( - torch.equal( - acqf.q_subset_indices["q_choose_2"], - torch.tensor([[0, 1]], device=self.device), - ) - ) - self.assertNotIn("q_choose_3", acqf.q_subset_indices) - # now back to 1 and sure all caches were cleared - acqf.model = mm - res = acqf(X) - self.assertNotIn("q_choose_2", acqf.q_subset_indices) - self.assertIn("q_choose_1", acqf.q_subset_indices) - self.assertTrue( - torch.equal( - acqf.q_subset_indices["q_choose_1"], - torch.tensor([[0]], device=self.device), + def test_fat_q_log_expected_hypervolume_improvement(self): + for dtype in (torch.float, torch.double): + with self.subTest(dtype=dtype): + self._test_q_expected_hypervolume_improvement( + acqf_class=qLogExpectedHypervolumeImprovement, + dtype=dtype, + # setting tau_max tighter than default to more easily unify tests + # with qEHVI. + acqf_kwargs={"tau_max": 1e-3, "fat": True}, ) - ) - X = torch.zeros(1, 1, **tkwargs) - samples = torch.zeros(1, 1, 2, **tkwargs) - mm = MockModel(MockPosterior(samples=samples)) - # basic test - sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - ) - res = acqf(X) - self.assertEqual(res.item(), 0.0) - self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 1, 2])) - bs = acqf.sampler.base_samples.clone() - res = acqf(X) - self.assertTrue(torch.equal(acqf.sampler.base_samples, bs)) - - # basic test, qmc - sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2])) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, + def _test_q_expected_hypervolume_improvement( + self, + acqf_class: Type[AcquisitionFunction], + dtype: torch.dtype, + acqf_kwargs: Optional[Dict[str, Any]] = None, + ): + if acqf_kwargs is None: + acqf_kwargs = {} + tkwargs = {"device": self.device, "dtype": dtype} + ref_point = [0.0, 0.0] + t_ref_point = torch.tensor(ref_point, **tkwargs) + pareto_Y = torch.tensor( + [[4.0, 5.0], [5.0, 5.0], [8.5, 3.5], [8.5, 3.0], [9.0, 1.0]], **tkwargs + ) + partitioning = NondominatedPartitioning(ref_point=t_ref_point) + # the event shape is `b x q x m` = 1 x 1 x 2 + samples = torch.zeros(1, 1, 2, **tkwargs) + mm = MockModel(MockPosterior(samples=samples)) + # test error if there is not pareto_Y initialized in partitioning + with self.assertRaises(BotorchError): + acqf_class(model=mm, ref_point=ref_point, partitioning=partitioning) + partitioning.update(Y=pareto_Y) + # test error if ref point has wrong shape + with self.assertRaises(ValueError): + acqf_class(model=mm, ref_point=ref_point[:1], partitioning=partitioning) + + X = torch.zeros(1, 1, **tkwargs) + # basic test + sampler = IIDNormalSampler(sample_shape=torch.Size([1])) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 0.0) + # check ref point + self.assertTrue(torch.equal(acqf.ref_point, torch.tensor(ref_point, **tkwargs))) + # check cached indices + self.assertTrue(hasattr(acqf, "q_subset_indices")) + self.assertIn("q_choose_1", acqf.q_subset_indices) + self.assertTrue( + torch.equal( + acqf.q_subset_indices["q_choose_1"], + torch.tensor([[0]], device=self.device), ) - res = acqf(X) - self.assertEqual(res.item(), 0.0) - self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 1, 2])) - bs = acqf.sampler.base_samples.clone() - acqf(X) - self.assertTrue(torch.equal(acqf.sampler.base_samples, bs)) - - # basic test for X_pending and warning - acqf.set_X_pending() - self.assertIsNone(acqf.X_pending) - acqf.set_X_pending(None) - self.assertIsNone(acqf.X_pending) - acqf.set_X_pending(X) - self.assertEqual(acqf.X_pending, X) - # get mm sample shape to match shape of X + X_pending - acqf.model._posterior._samples = torch.zeros(1, 2, 2, **tkwargs) - res = acqf(X) - X2 = torch.zeros(1, 1, 1, requires_grad=True, **tkwargs) - with warnings.catch_warnings(record=True) as ws, settings.debug(True): - acqf.set_X_pending(X2) - self.assertEqual(acqf.X_pending, X2) - self.assertEqual( - sum(issubclass(w.category, BotorchWarning) for w in ws), 1 - ) + ) - # test objective - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - objective=IdentityMCMultiOutputObjective(), + # test q=2 + X2 = torch.zeros(2, 1, **tkwargs) + samples2 = torch.zeros(1, 2, 2, **tkwargs) + mm2 = MockModel(MockPosterior(samples=samples2)) + acqf.model = mm2 + self.assertEqual(acqf.model, mm2) + self.assertIn("model", acqf._modules) + self.assertEqual(acqf._modules["model"], mm2) + res = evaluate(acqf, X2) + self.assertAlmostEqual(res.item(), 0.0) + # check cached indices + self.assertTrue(hasattr(acqf, "q_subset_indices")) + self.assertIn("q_choose_1", acqf.q_subset_indices) + self.assertTrue( + torch.equal( + acqf.q_subset_indices["q_choose_1"], + torch.tensor([[0], [1]], device=self.device), ) - # get mm sample shape to match shape of X - acqf.model._posterior._samples = torch.zeros(1, 1, 2, **tkwargs) - res = acqf(X) - self.assertEqual(res.item(), 0.0) - - # Test that the hypervolume improvement is correct for given sample - # test q = 1 - X = torch.zeros(1, 1, **tkwargs) - # basic test - samples = torch.tensor([[[6.5, 4.5]]], **tkwargs) - mm = MockModel(MockPosterior(samples=samples)) - sampler = IIDNormalSampler(sample_shape=torch.Size([1])) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, + ) + self.assertIn("q_choose_2", acqf.q_subset_indices) + self.assertTrue( + torch.equal( + acqf.q_subset_indices["q_choose_2"], + torch.tensor([[0, 1]], device=self.device), ) - res = acqf(X) - self.assertEqual(res.item(), 1.5) - # test q = 1, does not contribute - samples = torch.tensor([0.0, 1.0], **tkwargs).view(1, 1, 2) - sampler = IIDNormalSampler(sample_shape=torch.Size([1])) - mm = MockModel(MockPosterior(samples=samples)) - acqf.model = mm - res = acqf(X) - self.assertEqual(res.item(), 0.0) - - # test q = 2, both points contribute - X = torch.zeros(2, 1, **tkwargs) - samples = torch.tensor([[6.5, 4.5], [7.0, 4.0]], **tkwargs).unsqueeze(0) - mm = MockModel(MockPosterior(samples=samples)) - acqf.model = mm - res = acqf(X) - self.assertEqual(res.item(), 1.75) - - # test q = 2, only 1 point contributes - samples = torch.tensor([[6.5, 4.5], [6.0, 4.0]], **tkwargs).unsqueeze(0) - mm = MockModel(MockPosterior(samples=samples)) - acqf.model = mm - res = acqf(X) - self.assertEqual(res.item(), 1.5) - - # test q = 2, neither contributes - samples = torch.tensor([[2.0, 2.0], [0.0, 0.1]], **tkwargs).unsqueeze(0) - mm = MockModel(MockPosterior(samples=samples)) - acqf.model = mm - res = acqf(X) - self.assertEqual(res.item(), 0.0) + ) + self.assertNotIn("q_choose_3", acqf.q_subset_indices) + # now back to 1 and sure all caches were cleared + acqf.model = mm + res = evaluate(acqf, X) + self.assertNotIn("q_choose_2", acqf.q_subset_indices) + self.assertIn("q_choose_1", acqf.q_subset_indices) + self.assertTrue( + torch.equal( + acqf.q_subset_indices["q_choose_1"], + torch.tensor([[0]], device=self.device), + ) + ) - # test q = 2, test point better than current best second objective - samples = torch.tensor([[6.5, 4.5], [6.0, 6.0]], **tkwargs).unsqueeze(0) - mm = MockModel(MockPosterior(samples=samples)) - acqf.model = mm - res = acqf(X) - self.assertEqual(res.item(), 8.0) + X = torch.zeros(1, 1, **tkwargs) + samples = torch.zeros(1, 1, 2, **tkwargs) + mm = MockModel(MockPosterior(samples=samples)) + # basic test + sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 0.0) + self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 1, 2])) + bs = acqf.sampler.base_samples.clone() + res = evaluate(acqf, X) + self.assertTrue(torch.equal(acqf.sampler.base_samples, bs)) + + # basic test, qmc + sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2])) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 0.0) + self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 1, 2])) + bs = acqf.sampler.base_samples.clone() + evaluate(acqf, X) + self.assertTrue(torch.equal(acqf.sampler.base_samples, bs)) + + # basic test for X_pending and warning + acqf.set_X_pending() + self.assertIsNone(acqf.X_pending) + acqf.set_X_pending(None) + self.assertIsNone(acqf.X_pending) + acqf.set_X_pending(X) + self.assertEqual(acqf.X_pending, X) + # get mm sample shape to match shape of X + X_pending + acqf.model._posterior._samples = torch.zeros(1, 2, 2, **tkwargs) + res = evaluate(acqf, X) + X2 = torch.zeros(1, 1, 1, requires_grad=True, **tkwargs) + with warnings.catch_warnings(record=True) as ws, settings.debug(True): + acqf.set_X_pending(X2) + self.assertEqual(acqf.X_pending, X2) + self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1) + + # test objective + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + objective=IdentityMCMultiOutputObjective(), + **acqf_kwargs, + ) + # get mm sample shape to match shape of X + acqf.model._posterior._samples = torch.zeros(1, 1, 2, **tkwargs) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 0.0) + + # Test that the hypervolume improvement is correct for given sample + # test q = 1 + X = torch.zeros(1, 1, **tkwargs) + # basic test + samples = torch.tensor([[[6.5, 4.5]]], **tkwargs) + mm = MockModel(MockPosterior(samples=samples)) + sampler = IIDNormalSampler(sample_shape=torch.Size([1])) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + tol = ( + 1e-3 # due to smooth approximations + if isinstance(acqf, qLogExpectedHypervolumeImprovement) + else 1e-12 + ) + self.assertAlmostEqual(res.item(), 1.5, delta=tol) + # test q = 1, does not contribute + samples = torch.tensor([0.0, 1.0], **tkwargs).view(1, 1, 2) + sampler = IIDNormalSampler(sample_shape=torch.Size([1])) + mm = MockModel(MockPosterior(samples=samples)) + acqf.model = mm + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 0.0) + + # test q = 2, both points contribute + X = torch.zeros(2, 1, **tkwargs) + samples = torch.tensor([[6.5, 4.5], [7.0, 4.0]], **tkwargs).unsqueeze(0) + mm = MockModel(MockPosterior(samples=samples)) + acqf.model = mm + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 1.75, delta=tol) + + # test q = 2, only 1 point contributes + samples = torch.tensor([[6.5, 4.5], [6.0, 4.0]], **tkwargs).unsqueeze(0) + mm = MockModel(MockPosterior(samples=samples)) + acqf.model = mm + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 1.5, delta=tol) + + # test q = 2, neither contributes + samples = torch.tensor([[2.0, 2.0], [0.0, 0.1]], **tkwargs).unsqueeze(0) + mm = MockModel(MockPosterior(samples=samples)) + acqf.model = mm + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 0.0) + + # test q = 2, test point better than current best second objective + samples = torch.tensor([[6.5, 4.5], [6.0, 6.0]], **tkwargs).unsqueeze(0) + mm = MockModel(MockPosterior(samples=samples)) + acqf.model = mm + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 8.0, delta=tol) + + # test q = 2, test point better than current-best first objective + samples = torch.tensor([[6.5, 4.5], [9.0, 2.0]], **tkwargs).unsqueeze(0) + mm = MockModel(MockPosterior(samples=samples)) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + # Giving this a higher tolerance due to higher q + self.assertAlmostEqual(res.item(), 2.0, delta=tol) + # test q = 3, all contribute + X = torch.zeros(3, 1, **tkwargs) + samples = torch.tensor( + [[6.5, 4.5], [9.0, 2.0], [7.0, 4.0]], **tkwargs + ).unsqueeze(0) + mm = MockModel(MockPosterior(samples=samples)) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 2.25, delta=tol) + # test q = 3, not all contribute + samples = torch.tensor( + [[6.5, 4.5], [9.0, 2.0], [7.0, 5.0]], **tkwargs + ).unsqueeze(0) + mm = MockModel(MockPosterior(samples=samples)) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 3.5, delta=tol) + # test q = 3, none contribute + samples = torch.tensor( + [[0.0, 4.5], [1.0, 2.0], [3.0, 0.0]], **tkwargs + ).unsqueeze(0) + mm = MockModel(MockPosterior(samples=samples)) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 0.0) - # test q = 2, test point better than current-best first objective - samples = torch.tensor([[6.5, 4.5], [9.0, 2.0]], **tkwargs).unsqueeze(0) - mm = MockModel(MockPosterior(samples=samples)) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - ) - res = acqf(X) - self.assertEqual(res.item(), 2.0) - # test q = 3, all contribute - X = torch.zeros(3, 1, **tkwargs) - samples = torch.tensor( - [[6.5, 4.5], [9.0, 2.0], [7.0, 4.0]], **tkwargs - ).unsqueeze(0) - mm = MockModel(MockPosterior(samples=samples)) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - ) - res = acqf(X) - self.assertEqual(res.item(), 2.25) - # test q = 3, not all contribute - samples = torch.tensor( - [[6.5, 4.5], [9.0, 2.0], [7.0, 5.0]], **tkwargs - ).unsqueeze(0) - mm = MockModel(MockPosterior(samples=samples)) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - ) - res = acqf(X) - self.assertEqual(res.item(), 3.5) - # test q = 3, none contribute - samples = torch.tensor( - [[0.0, 4.5], [1.0, 2.0], [3.0, 0.0]], **tkwargs - ).unsqueeze(0) - mm = MockModel(MockPosterior(samples=samples)) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - ) - res = acqf(X) - self.assertEqual(res.item(), 0.0) + # test m = 3, q=1 + pareto_Y = torch.tensor( + [[4.0, 2.0, 3.0], [3.0, 5.0, 1.0], [2.0, 4.0, 2.0], [1.0, 3.0, 4.0]], + **tkwargs, + ) + ref_point = [-1.0] * 3 + t_ref_point = torch.tensor(ref_point, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) + samples = torch.tensor([[1.0, 2.0, 6.0]], **tkwargs).unsqueeze(0) + mm = MockModel(MockPosterior(samples=samples)) - # test m = 3, q=1 - pareto_Y = torch.tensor( - [[4.0, 2.0, 3.0], [3.0, 5.0, 1.0], [2.0, 4.0, 2.0], [1.0, 3.0, 4.0]], - **tkwargs, - ) - ref_point = [-1.0] * 3 - t_ref_point = torch.tensor(ref_point, **tkwargs) - partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) - samples = torch.tensor([[1.0, 2.0, 6.0]], **tkwargs).unsqueeze(0) - mm = MockModel(MockPosterior(samples=samples)) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + X = torch.zeros(1, 2, **tkwargs) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 12.0, delta=tol) + + # change reference point + ref_point = [0.0] * 3 + t_ref_point = torch.tensor(ref_point, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 4.0, delta=tol) + + # test m = 3, no contribution + ref_point = [1.0] * 3 + t_ref_point = torch.tensor(ref_point, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 0.0, delta=tol) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - ) - X = torch.zeros(1, 2, **tkwargs) - res = acqf(X) - self.assertEqual(res.item(), 12.0) - - # change reference point - ref_point = [0.0] * 3 - t_ref_point = torch.tensor(ref_point, **tkwargs) - partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - ) - res = acqf(X) - self.assertEqual(res.item(), 4.0) - - # test m = 3, no contribution - ref_point = [1.0] * 3 - t_ref_point = torch.tensor(ref_point, **tkwargs) - partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - ) - res = acqf(X) - self.assertEqual(res.item(), 0.0) + # test m = 3, q = 2 + pareto_Y = torch.tensor( + [[4.0, 2.0, 3.0], [3.0, 5.0, 1.0], [2.0, 4.0, 2.0]], **tkwargs + ) + samples = torch.tensor([[1.0, 2.0, 6.0], [1.0, 3.0, 4.0]], **tkwargs).unsqueeze( + 0 + ) + mm = MockModel(MockPosterior(samples=samples)) + ref_point = [-1.0] * 3 + t_ref_point = torch.tensor(ref_point, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + X = torch.zeros(2, 2, **tkwargs) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 22.0, delta=22.0 * tol) - # test m = 3, q = 2 - pareto_Y = torch.tensor( - [[4.0, 2.0, 3.0], [3.0, 5.0, 1.0], [2.0, 4.0, 2.0]], **tkwargs - ) - samples = torch.tensor( - [[1.0, 2.0, 6.0], [1.0, 3.0, 4.0]], **tkwargs - ).unsqueeze(0) - mm = MockModel(MockPosterior(samples=samples)) - ref_point = [-1.0] * 3 - t_ref_point = torch.tensor(ref_point, **tkwargs) - partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) - acqf = qExpectedHypervolumeImprovement( + # test batched model + pareto_Y = torch.tensor( + [[4.0, 2.0, 3.0], [3.0, 5.0, 1.0], [2.0, 4.0, 2.0]], **tkwargs + ) + samples = torch.tensor([[1.0, 2.0, 6.0], [1.0, 3.0, 4.0]], **tkwargs).unsqueeze( + 0 + ) + samples = torch.stack([samples, samples + 1], dim=1) + mm = MockModel(MockPosterior(samples=samples)) + ref_point = [-1.0] * 3 + t_ref_point = torch.tensor(ref_point, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + **acqf_kwargs, + ) + X = torch.zeros(2, 2, **tkwargs) + res = evaluate(acqf, X) + self.assertAllClose( + res, + # batch_shape x model_batch_shape + torch.tensor([[22.0, 60.0]], **tkwargs), + rtol=tol, + ) + # test batched model with batched partitioning with multiple batch dims + pareto_Y = torch.tensor( + [[4.0, 5.0], [5.0, 5.0], [8.5, 3.5], [8.5, 3.0], [9.0, 1.0]], **tkwargs + ) + pareto_Y = torch.stack( + [ + pareto_Y, + pareto_Y + 0.5, + ], + dim=0, + ) + samples = torch.tensor([[6.5, 4.5], [7.0, 4.0]], **tkwargs).unsqueeze(0) + samples = torch.stack([samples, samples + 1], dim=1) + mm = MockModel(MockPosterior(samples=samples)) + ref_point = [-1.0] * 2 + t_ref_point = torch.tensor(ref_point, **tkwargs) + partitioning = FastNondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) + cell_bounds = partitioning.get_hypercell_bounds().unsqueeze(1) + with mock.patch.object( + partitioning, "get_hypercell_bounds", return_value=cell_bounds + ): + acqf = acqf_class( model=mm, ref_point=ref_point, partitioning=partitioning, sampler=sampler, - ) - X = torch.zeros(2, 2, **tkwargs) - res = acqf(X) - self.assertEqual(res.item(), 22.0) + **acqf_kwargs, + ) + # test multiple batch dims + self.assertEqual(acqf.cell_lower_bounds.shape, torch.Size([1, 2, 4, 2])) + self.assertEqual(acqf.cell_upper_bounds.shape, torch.Size([1, 2, 4, 2])) + X = torch.zeros(2, 2, **tkwargs) + res = evaluate(acqf, X) + self.assertAllClose( + res, + # batch_shape x model_batch_shape + torch.tensor([[1.75, 3.5]], dtype=samples.dtype, device=samples.device), + rtol=tol, + ) - # test batched model - pareto_Y = torch.tensor( - [[4.0, 2.0, 3.0], [3.0, 5.0, 1.0], [2.0, 4.0, 2.0]], **tkwargs - ) - samples = torch.tensor( - [[1.0, 2.0, 6.0], [1.0, 3.0, 4.0]], **tkwargs - ).unsqueeze(0) - samples = torch.stack([samples, samples + 1], dim=1) - mm = MockModel(MockPosterior(samples=samples)) - ref_point = [-1.0] * 3 - t_ref_point = torch.tensor(ref_point, **tkwargs) - partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - ) - X = torch.zeros(2, 2, **tkwargs) - res = acqf(X) - self.assertTrue( - torch.equal( - res, - # batch_shape x model_batch_shape - torch.tensor([[22.0, 60.0]], **tkwargs), - ) - ) - # test batched model with batched partitioning with multiple batch dims - pareto_Y = torch.tensor( - [[4.0, 5.0], [5.0, 5.0], [8.5, 3.5], [8.5, 3.0], [9.0, 1.0]], **tkwargs - ) - pareto_Y = torch.stack( - [ - pareto_Y, - pareto_Y + 0.5, - ], - dim=0, - ) - samples = torch.tensor([[6.5, 4.5], [7.0, 4.0]], **tkwargs).unsqueeze(0) - samples = torch.stack([samples, samples + 1], dim=1) - mm = MockModel(MockPosterior(samples=samples)) - ref_point = [-1.0] * 2 - t_ref_point = torch.tensor(ref_point, **tkwargs) - partitioning = FastNondominatedPartitioning( - ref_point=t_ref_point, Y=pareto_Y - ) - cell_bounds = partitioning.get_hypercell_bounds().unsqueeze(1) - with mock.patch.object( - partitioning, "get_hypercell_bounds", return_value=cell_bounds - ): - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, + def test_constrained_q_expected_hypervolume_improvement(self): + for dtype in (torch.float, torch.double): + with self.subTest(dtype=dtype): + self._test_constrained_q_expected_hypervolume_improvement( + acqf_class=qExpectedHypervolumeImprovement, + dtype=dtype, ) - # test multiple batch dims - self.assertEqual(acqf.cell_lower_bounds.shape, torch.Size([1, 2, 4, 2])) - self.assertEqual(acqf.cell_upper_bounds.shape, torch.Size([1, 2, 4, 2])) - X = torch.zeros(2, 2, **tkwargs) - res = acqf(X) - self.assertTrue( - torch.equal( - res, - # batch_shape x model_batch_shape - torch.tensor( - [[1.75, 3.5]], dtype=samples.dtype, device=samples.device - ), + + def test_constrained_q_log_expected_hypervolume_improvement(self): + for dtype in (torch.float, torch.double): + with self.subTest(dtype=dtype): + self._test_constrained_q_expected_hypervolume_improvement( + acqf_class=qLogExpectedHypervolumeImprovement, + dtype=dtype, + # setting tau_max tighter than default to more easily unify tests + # with qEHVI. + acqf_kwargs={"tau_max": 1e-3, "fat": False}, ) - ) - def test_constrained_q_expected_hypervolume_improvement(self): + def test_fat_constrained_q_log_expected_hypervolume_improvement(self): for dtype in (torch.float, torch.double): - tkwargs = {"device": self.device, "dtype": dtype} - ref_point = [0.0, 0.0] - t_ref_point = torch.tensor(ref_point, **tkwargs) - pareto_Y = torch.tensor( - [[4.0, 5.0], [5.0, 5.0], [8.5, 3.5], [8.5, 3.0], [9.0, 1.0]], **tkwargs - ) - partitioning = NondominatedPartitioning(ref_point=t_ref_point) - partitioning.update(Y=pareto_Y) + with self.subTest(dtype=dtype): + self._test_constrained_q_expected_hypervolume_improvement( + acqf_class=qLogExpectedHypervolumeImprovement, + dtype=dtype, + # setting tau_max tighter than default to more easily unify tests + # with qEHVI. + acqf_kwargs={"tau_max": 1e-3, "fat": True}, + ) - # test q=1 - # the event shape is `b x q x m` = 1 x 1 x 2 - samples = torch.tensor([[[6.5, 4.5]]], **tkwargs) - mm = MockModel(MockPosterior(samples=samples)) - sampler = IIDNormalSampler(sample_shape=torch.Size([1])) - X = torch.zeros(1, 1, **tkwargs) - # test zero slack - for eta in (1e-1, 1e-2): - expected_values = [0.5 * 1.5, 0.5 * 0.5 * 1.5] - for i, constraints in enumerate( + def _test_constrained_q_expected_hypervolume_improvement( + self, + acqf_class: Type[AcquisitionFunction], + dtype: torch.dtype, + acqf_kwargs: Optional[Dict[str, Any]] = None, + ): + if acqf_kwargs is None: + acqf_kwargs = {} + tkwargs = {"device": self.device, "dtype": dtype} + ref_point = [0.0, 0.0] + t_ref_point = torch.tensor(ref_point, **tkwargs) + pareto_Y = torch.tensor( + [[4.0, 5.0], [5.0, 5.0], [8.5, 3.5], [8.5, 3.0], [9.0, 1.0]], **tkwargs + ) + partitioning = NondominatedPartitioning(ref_point=t_ref_point) + partitioning.update(Y=pareto_Y) + + # test q=1 + # the event shape is `b x q x m` = 1 x 1 x 2 + samples = torch.tensor([[[6.5, 4.5]]], **tkwargs) + mm = MockModel(MockPosterior(samples=samples)) + sampler = IIDNormalSampler(sample_shape=torch.Size([1])) + X = torch.zeros(1, 1, **tkwargs) + + def get_zeros_like_last(Z: Tensor) -> Tensor: + return torch.zeros_like(Z[..., -1]) + + # test zero slack + for eta in (1e-1, 1e-2): + expected_values = [0.5 * 1.5, 0.5 * 0.5 * 1.5] + for i, constraints in enumerate( + [ + [get_zeros_like_last], [ - [lambda Z: torch.zeros_like(Z[..., -1])], - [ - lambda Z: torch.zeros_like(Z[..., -1]), - lambda Z: torch.zeros_like(Z[..., -1]), - ], - ] - ): - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - constraints=constraints, - eta=eta, - ) - res = acqf(X) - self.assertAlmostEqual(res.item(), expected_values[i], places=4) - # test multiple constraints one and multiple etas - constraints = [ - lambda Z: torch.ones_like(Z[..., -1]), - lambda Z: torch.ones_like(Z[..., -1]), - ] - etas = [1, torch.tensor([1, 10])] - expected_values = [ - ( - torch.sigmoid(torch.as_tensor(-1.0)) - * torch.sigmoid(torch.as_tensor(-1.0)) - * 1.5 - ).item(), - ( - torch.sigmoid(torch.as_tensor(-1.0)) - * torch.sigmoid(torch.as_tensor(-1.0 / 10.0)) - * 1.5 - ).item(), - ] - for eta, expected_value in zip(etas, expected_values): - acqf = qExpectedHypervolumeImprovement( + get_zeros_like_last, + get_zeros_like_last, + ], + ] + ): + acqf = acqf_class( model=mm, ref_point=ref_point, partitioning=partitioning, sampler=sampler, constraints=constraints, eta=eta, + **acqf_kwargs, ) - res = acqf(X) - self.assertAlmostEqual( - res.item(), - expected_value, - places=4, - ) - # test feasible - acqf = qExpectedHypervolumeImprovement( - model=mm, - ref_point=ref_point, - partitioning=partitioning, - sampler=sampler, - constraints=[lambda Z: -100.0 * torch.ones_like(Z[..., -1])], - eta=1e-3, - ) - res = acqf(X) - self.assertAlmostEqual(res.item(), 1.5, places=4) - # test infeasible - acqf = qExpectedHypervolumeImprovement( + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), expected_values[i], places=4) + # test multiple constraints one and multiple etas + constraints = [ + lambda Z: torch.ones_like(Z[..., -1]), + lambda Z: torch.ones_like(Z[..., -1]), + ] + etas = [1, torch.tensor([1, 10])] + fat = getattr(acqf, "fat", False) + sigmoid_neg1 = sigmoid(torch.as_tensor(-1.0), fat=fat) + expected_values = [ + (sigmoid_neg1**2 * 1.5).item(), + ( + sigmoid_neg1 * sigmoid(torch.as_tensor(-1.0 / 10.0), fat=fat) * 1.5 + ).item(), + ] + for eta, expected_value in zip(etas, expected_values): + acqf = acqf_class( model=mm, ref_point=ref_point, partitioning=partitioning, sampler=sampler, - constraints=[lambda Z: 100.0 * torch.ones_like(Z[..., -1])], - eta=1e-3, - ) - res = acqf(X) - self.assertAlmostEqual(res.item(), 0.0, places=4) + constraints=constraints, + eta=eta, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual( + res.item(), + expected_value, + places=4, + ) + # test feasible + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + constraints=[lambda Z: -100.0 * torch.ones_like(Z[..., -1])], + eta=1e-3, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 1.5, places=4) + # test infeasible + acqf = acqf_class( + model=mm, + ref_point=ref_point, + partitioning=partitioning, + sampler=sampler, + constraints=[lambda Z: 100.0 * torch.ones_like(Z[..., -1])], + eta=1e-3, + **acqf_kwargs, + ) + res = evaluate(acqf, X) + self.assertAlmostEqual(res.item(), 0.0, places=4) - # TODO: Test non-trivial constraint values, multiple constraints, and q > 1 + # TODO: Test non-trivial constraint values, multiple constraints, and q > 1 class TestQNoisyExpectedHypervolumeImprovement(BotorchTestCase): @@ -667,7 +760,7 @@ def test_q_noisy_expected_hypervolume_improvement(self): # set the MockPosterior to use samples over baseline points and new # candidates acqf.model._posterior._samples = samples - res = acqf(X) + res = evaluate(acqf, X) self.assertEqual(res.item(), 0.0) # check ref point self.assertTrue( @@ -701,7 +794,7 @@ def test_q_noisy_expected_hypervolume_improvement(self): # set the MockPosterior to use samples over baseline points and new # candidates acqf.model._posterior._samples = samples2 - res = acqf(X2) + res = evaluate(acqf, X2) self.assertEqual(res.item(), 0.0) # check cached indices self.assertTrue(hasattr(acqf, "q_subset_indices")) @@ -722,7 +815,7 @@ def test_q_noisy_expected_hypervolume_improvement(self): self.assertNotIn("q_choose_3", acqf.q_subset_indices) # now back to 1 and sure all caches were cleared acqf.model = mm - res = acqf(X) + res = evaluate(acqf, X) self.assertNotIn("q_choose_2", acqf.q_subset_indices) self.assertIn("q_choose_1", acqf.q_subset_indices) self.assertTrue( @@ -760,7 +853,7 @@ def test_q_noisy_expected_hypervolume_improvement(self): # set the MockPosterior to use samples over baseline points and new # candidates mm._posterior._samples = samples - res = acqf(X) + res = evaluate(acqf, X) self.assertEqual(res.item(), 0.0) # test that original base samples were retained self.assertTrue( @@ -787,7 +880,7 @@ def test_q_noisy_expected_hypervolume_improvement(self): # candidates mm._posterior._samples = samples with torch.no_grad(): - acqf(X) + evaluate(acqf, X) self.assertTrue( torch.equal( orig_base_sampler.base_samples, acqf.base_sampler.base_samples @@ -882,7 +975,7 @@ def test_q_noisy_expected_hypervolume_improvement(self): # test forward X_test = torch.rand(1, 1, dtype=dtype, device=self.device) with torch.no_grad(): - val = acqf(X_test) + val = evaluate(acqf, X_test) bd.update(mm._posterior._samples[0, -1:]) expected_val = bd.compute_hypervolume() - initial_hv self.assertTrue(torch.equal(val, expected_val.view(-1))) @@ -963,7 +1056,7 @@ def test_q_noisy_expected_hypervolume_improvement(self): ).unsqueeze(0) X_test = torch.rand(1, 1, dtype=dtype, device=self.device) with torch.no_grad(): - val = acqf(X_test) + val = evaluate(acqf, X_test) if incremental_nehvi: # set initial hv to include X_pending initial_hv = bd.compute_hypervolume() @@ -1067,7 +1160,7 @@ def test_q_noisy_expected_hypervolume_improvement(self): ] ).unsqueeze(0) with torch.no_grad(): - val = acqf(X_test) + val = evaluate(acqf, X_test) bd = DominatedPartitioning( ref_point=torch.tensor(ref_point).to(**tkwargs), Y=pareto_Y ) @@ -1158,7 +1251,7 @@ def test_constrained_q_noisy_expected_hypervolume_improvement(self): # set the MockPosterior to use samples over baseline points and new # candidates mm._posterior._samples = samples - res = acqf(X) + res = evaluate(acqf, X) self.assertAlmostEqual(res.item(), 0.5 * 0.5 * 1.5, places=4) # test zero slack single constraint for eta in (1e-1, 1e-2): @@ -1177,7 +1270,7 @@ def test_constrained_q_noisy_expected_hypervolume_improvement(self): # set the MockPosterior to use samples over baseline points and new # candidates mm._posterior._samples = samples - res = acqf(X) + res = evaluate(acqf, X) self.assertAlmostEqual(res.item(), 0.5 * 1.5, places=4) # set X_pending X_pending = torch.rand(1, 1, **tkwargs) @@ -1190,7 +1283,7 @@ def test_constrained_q_noisy_expected_hypervolume_improvement(self): dim=1, ) mm._posterior._samples = samples - res = acqf(X) + res = evaluate(acqf, X) self.assertAlmostEqual(res.item(), 0.5 * 0.5, places=4) # test incremental nehvi=False @@ -1213,7 +1306,7 @@ def test_constrained_q_noisy_expected_hypervolume_improvement(self): dim=1, ) mm._posterior._samples = samples - res = acqf(X) + res = evaluate(acqf, X) self.assertAlmostEqual(res.item(), 0.5 * 1.5, places=4) acqf.set_X_pending(X_pending) samples = torch.cat( @@ -1224,7 +1317,7 @@ def test_constrained_q_noisy_expected_hypervolume_improvement(self): dim=1, ) mm._posterior._samples = samples - res = acqf(X) + res = evaluate(acqf, X) # test that HVI is not incremental # Note that the cached pending point uses strict constraint evaluation # so the HVI from the cached pending point is 1.5. @@ -1253,7 +1346,7 @@ def test_constrained_q_noisy_expected_hypervolume_improvement(self): dim=1, ) mm._posterior._samples = samples - res = acqf(X) + res = evaluate(acqf, X) self.assertAlmostEqual(res.item(), 1.5, places=4) # test multiple constraints one eta with # this crashes for large etas, and I do not why @@ -1279,7 +1372,7 @@ def test_constrained_q_noisy_expected_hypervolume_improvement(self): ): acqf.constraints = constraint acqf.eta = eta - res = acqf(X) + res = evaluate(acqf, X) self.assertAlmostEqual( res.item(), @@ -1302,7 +1395,7 @@ def test_constrained_q_noisy_expected_hypervolume_improvement(self): # set the MockPosterior to use samples over baseline points and new # candidates mm._posterior._samples = samples - res = acqf(X) + res = evaluate(acqf, X) self.assertAlmostEqual(res.item(), 0.0, places=4) # test >2 objectives @@ -1338,7 +1431,7 @@ def test_constrained_q_noisy_expected_hypervolume_improvement(self): dim=1, ) mm._posterior._samples = samples - res = acqf(X) + res = evaluate(acqf, X) self.assertAlmostEqual(res.item(), 1.5, places=4) def test_prune_baseline(self): @@ -1459,7 +1552,7 @@ def test_cache_root(self): sample_cached_path, wraps=sample_cached_cholesky ) as mock_sample_cached: torch.manual_seed(0) - val = acqf(test_X) + val = evaluate(acqf, test_X) mock_sample_cached.assert_called_once() val.sum().backward() base_samples = acqf.sampler.base_samples.detach().clone() @@ -1487,7 +1580,7 @@ def test_cache_root(self): acqf._baseline_L = torch.zeros_like(acqf._baseline_L) with warnings.catch_warnings(record=True) as ws, settings.debug(True): with torch.no_grad(): - acqf(test_X) + evaluate(acqf, test_X) self.assertEqual( sum(issubclass(w.category, BotorchWarning) for w in ws), 1 ) @@ -1544,7 +1637,7 @@ def forward(self, samples, X=None): dim=1, ) acqf.model._posterior._samples = samples - res = acqf(test_x) + res = evaluate(acqf, test_x) self.assertTrue(torch.equal(res, torch.zeros(3, **tkwargs))) self.assertEqual(acqf.q_in, 6) self.assertEqual(acqf.q_out, 4) @@ -1567,7 +1660,7 @@ def test_deterministic(self): ) self.assertFalse(acqf._cache_root) self.assertEqual( - acqf(torch.rand(3, 2, 2, **tkwargs)).shape, torch.Size([3]) + evaluate(acqf, torch.rand(3, 2, 2, **tkwargs)).shape, torch.Size([3]) ) def test_with_multitask(self): diff --git a/test/acquisition/test_logei.py b/test/acquisition/test_logei.py index 83b66b8bf0..1174d4fca5 100644 --- a/test/acquisition/test_logei.py +++ b/test/acquisition/test_logei.py @@ -223,12 +223,10 @@ def test_q_log_expected_improvement_batch(self): acqf = qLogExpectedImprovement(model=mm, best_f=0, sampler=sampler) exp_log_res = acqf(X).exp() # with no approximations (qEI): self.assertEqual(res[0].item(), 1.0) - # in the batch case, the values get adjusted toward self.assertEqual(exp_log_res.dtype, dtype) self.assertEqual(exp_log_res.device.type, self.device.type) self.assertTrue(1.0 <= exp_log_res[0].item()) self.assertTrue(exp_log_res[0].item() <= 1.0 + acqf.tau_relu) - # self.assertAllClose(exp_log_res[0], torch.ones_like(exp_log_res[0]), ) # with no approximations (qEI): self.assertEqual(res[1].item(), 0.0) self.assertTrue(0 < exp_log_res[1].item())