diff --git a/examples/contrib/oed/ab_test.py b/examples/contrib/oed/ab_test.py
index 5aca37fcd6..6d7b9b3ea5 100644
--- a/examples/contrib/oed/ab_test.py
+++ b/examples/contrib/oed/ab_test.py
@@ -1,12 +1,16 @@
 import argparse
 import torch
+from torch.distributions import constraints
 import numpy as np
 
 import pyro
 import pyro.distributions as dist
 from pyro import optim
-from pyro.infer import Trace_ELBO
+from pyro.infer import TraceEnum_ELBO
 from pyro.contrib.oed.eig import vi_ape
+import pyro.contrib.gp as gp
+
+from gp_bayes_opt import GPBayesOptimizer
 
 """
 Example builds on the Bayesian regression tutorial [1]. It demonstrates how
@@ -35,7 +39,7 @@
 N = 100  # number of participants
 p_treatments = 2  # number of treatment groups
 p = p_treatments  # number of features
-prior_stdevs = torch.tensor([1, .5])
+prior_stdevs = torch.tensor([10., 2.5])
 
 softplus = torch.nn.functional.softplus
 
@@ -93,61 +97,70 @@ def design_to_matrix(design):
         if i > 0:
             X[t:t+i, col] = 1.
         t += i
+    if t < n:
+        X[t:, -1] = 1.
     return X
 
 
 def analytic_posterior_entropy(prior_cov, x):
-    posterior_cov = prior_cov - prior_cov.mm(x.t().mm(torch.inverse(
-        x.mm(prior_cov.mm(x.t())) + torch.eye(N)).mm(x.mm(prior_cov))))
-    return 0.5*torch.logdet(2*np.pi*np.e*posterior_cov)
+    # Apply the matrix inversion identity (kernel trick) so that only a
+    # p x p matrix needs to be inverted
+    SigmaXX = prior_cov.mm(x.t().mm(x))
+    posterior_cov = prior_cov - torch.inverse(
+        SigmaXX + torch.eye(p)).mm(SigmaXX.mm(prior_cov))
+    y = 0.5*torch.logdet(2*np.pi*np.e*posterior_cov)
+    return y
 
 
-def main(num_steps):
+def main(num_vi_steps, num_acquisitions, num_bo_steps):
 
     pyro.set_rng_seed(42)
     pyro.clear_param_store()
 
-    ns = range(0, N, 5)
-    designs = [design_to_matrix(torch.tensor([n1, N-n1])) for n1 in ns]
-    X = torch.stack(designs)
-
-    # Estimated loss (linear transform of EIG)
-    est_ape = vi_ape(
-        model,
-        X,
-        observation_labels="y",
-        vi_parameters={
-            "guide": guide,
-            "optim": optim.Adam({"lr": 0.0025}),
-            "loss": Trace_ELBO(),
-            "num_steps": num_steps},
-        is_parameters={"num_samples": 2}
-    )
-
-    # Analytic loss
-    true_ape = []
-    prior_cov = torch.diag(prior_stdevs**2)
-    for i in range(len(ns)):
-        x = X[i, :, :]
-        true_ape.append(analytic_posterior_entropy(prior_cov, x))
-
-    print("Estimated APE values")
-    print(est_ape)
-    print("True APE values")
-    print(true_ape)
-
-    # # Plot to compare
-    # import matplotlib.pyplot as plt
-    # ns = np.array(ns)
-    # est_ape = np.array(est_ape.detach())
-    # true_ape = np.array(true_ape)
-    # plt.scatter(ns, est_ape)
-    # plt.scatter(ns, true_ape, color='r')
-    # plt.show()
+    def estimated_ape(ns):
+        designs = [design_to_matrix(torch.tensor([n1, N-n1])) for n1 in ns]
+        X = torch.stack(designs)
+        est_ape = vi_ape(
+            model,
+            X,
+            observation_labels="y",
+            vi_parameters={
+                "guide": guide,
+                "optim": optim.Adam({"lr": 0.0025}),
+                "loss": TraceEnum_ELBO(strict_enumeration_warning=False).differentiable_loss,
+                "num_steps": num_vi_steps},
+            is_parameters={"num_samples": 1}
+        )
+        return est_ape
+
+    def true_ape(ns):
+        true_ape = []
+        prior_cov = torch.diag(prior_stdevs**2)
+        designs = [design_to_matrix(torch.tensor([n1, N-n1])) for n1 in ns]
+        for i in range(len(ns)):
+            x = designs[i]
+            true_ape.append(analytic_posterior_entropy(prior_cov, x))
+        return torch.tensor(true_ape)
+
+    for f in [true_ape, estimated_ape]:
+        X = torch.tensor([25., 75.])
+        y = f(X)
+        pyro.clear_param_store()
+        gpmodel = gp.models.GPRegression(
+            X, y, gp.kernels.Matern52(input_dim=1, lengthscale=torch.tensor(5.)),
+            noise=torch.tensor(0.1), jitter=1e-6)
+        gpmodel.optimize(loss=TraceEnum_ELBO(strict_enumeration_warning=False).differentiable_loss)
+        gpbo = GPBayesOptimizer(constraints.interval(0, 100), gpmodel,
+                                num_acquisitions=num_acquisitions)
+        for i in range(num_bo_steps):
+            result = gpbo.get_step(f, None)
+
+        print(result)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="A/B test experiment design using VI")
-    parser.add_argument("-n", "--num-steps", nargs="?", default=3000, type=int)
+    parser.add_argument("-n", "--num-vi-steps", nargs="?", default=5000, type=int)
+    parser.add_argument('--num-acquisitions', nargs="?", default=10, type=int)
+    parser.add_argument('--num-bo-steps', nargs="?", default=6, type=int)
     args = parser.parse_args()
-    main(args.num_steps)
+    main(args.num_vi_steps, args.num_acquisitions, args.num_bo_steps)
diff --git a/examples/contrib/oed/gp_bayes_opt.py b/examples/contrib/oed/gp_bayes_opt.py
new file mode 100644
index 0000000000..eca4efd09f
--- /dev/null
+++ b/examples/contrib/oed/gp_bayes_opt.py
@@ -0,0 +1,120 @@
+import torch
+import torch.autograd as autograd
+import torch.optim as optim
+from torch.distributions import transform_to
+
+import pyro
+from pyro.infer import TraceEnum_ELBO
+
+
+class GPBayesOptimizer(pyro.optim.multi.MultiOptimizer):
+    """Performs Bayesian Optimization using a Gaussian Process as an
+    emulator for the unknown function.
+    """
+
+    def __init__(self, constraints, gpmodel, num_acquisitions, acquisition_func=None):
+        """
+        :param torch.constraint constraints: constraints defining the domain of `f`
+        :param gp.models.GPRegression gpmodel: a (possibly initialized) GP
+            regression model. The kernel, etc. are specified via `gpmodel`.
+        :param int num_acquisitions: number of points to acquire at each step
+        :param function acquisition_func: a function to generate acquisitions.
+            It should return a torch.Tensor of new points to query.
+        """
+        if acquisition_func is None:
+            acquisition_func = self.acquire_thompson
+
+        self.constraints = constraints
+        self.gpmodel = gpmodel
+        self.num_acquisitions = num_acquisitions
+        self.acquisition_func = acquisition_func
+
+    def update_posterior(self, X, y):
+        X = torch.cat([self.gpmodel.X, X])
+        y = torch.cat([self.gpmodel.y, y])
+        self.gpmodel.set_data(X, y)
+        self.gpmodel.optimize(loss=TraceEnum_ELBO(strict_enumeration_warning=False).differentiable_loss)
+
+    def find_a_candidate(self, differentiable, x_init):
+        """Given a starting point, `x_init`, takes one LBFGS step
+        to optimize the differentiable function.
+
+        :param function differentiable: a function amenable to torch
+            autograd
+        :param torch.Tensor x_init: the initial point
+
+        """
+        # transform x to an unconstrained domain
+        unconstrained_x_init = transform_to(self.constraints).inv(x_init)
+        unconstrained_x = unconstrained_x_init.new_tensor(
+            unconstrained_x_init, requires_grad=True)
+        # TODO: use LBFGS with line search once pytorch #8824 is merged
+        minimizer = optim.LBFGS([unconstrained_x], max_eval=20)
+
+        def closure():
+            minimizer.zero_grad()
+            if (torch.log(torch.abs(unconstrained_x)) > 25.).any():
+                return torch.tensor(float('inf'))
+            x = transform_to(self.constraints)(unconstrained_x)
+            y = differentiable(x)
+            autograd.backward(unconstrained_x,
+                              autograd.grad(y, unconstrained_x, retain_graph=True))
+            return y
+
+        minimizer.step(closure)
+        # after finding a candidate in the unconstrained domain,
+        # convert it back to original domain.
+        x = transform_to(self.constraints)(unconstrained_x)
+        opt_y = differentiable(x)
+        return x.detach(), opt_y.detach()
+
+    def opt_differentiable(self, differentiable, num_candidates=5):
+        """Optimizes a differentiable function by choosing `num_candidates`
+        initial points at random and calling :func:`find_a_candidate` on
+        each. The best candidate is returned with its function value.
+
+        :param function differentiable: a function amenable to torch autograd
+        :param int num_candidates: the number of random starting points to
+            use
+        :return: the minimiser and its function value
+        :rtype: tuple
+        """
+
+        candidates = []
+        values = []
+        for j in range(num_candidates):
+            x_init = self.gpmodel.X.new_empty(1).uniform_(
+                self.constraints.lower_bound, self.constraints.upper_bound)
+            x, y = self.find_a_candidate(differentiable, x_init)
+            candidates.append(x)
+            values.append(y)
+
+        mvalue, argmin = torch.min(torch.cat(values), dim=0)
+        return candidates[argmin.item()], mvalue
+
+    def acquire_thompson(self, num_acquisitions=1, **opt_params):
+        """Selects `num_acquisitions` query points at which to query the
+        original function by Thompson sampling.
+
+        :param int num_acquisitions: the number of points to generate
+        :param dict opt_params: additional parameters for optimization
+            routines
+        :return: a tensor of points at which to evaluate the objective
+        :rtype: torch.Tensor
+        """
+
+        # Initialize the return tensor
+        X = self.gpmodel.X.new_empty(num_acquisitions, *self.gpmodel.X.shape[1:])
+
+        for i in range(num_acquisitions):
+            sampler = self.gpmodel.iter_sample(noiseless=False)
+            x, _ = self.opt_differentiable(sampler, **opt_params)
+            X[i, ...] = x
+
+        return X
+
+    def get_step(self, loss, params):
+        X = self.acquisition_func(num_acquisitions=self.num_acquisitions)
+        y = loss(X)
+        self.update_posterior(X, y)
+        return self.opt_differentiable(lambda x: self.gpmodel(x)[0])
diff --git a/pyro/contrib/gp/models/gpr.py b/pyro/contrib/gp/models/gpr.py
index 77f80ae438..8b7ea0c758 100644
--- a/pyro/contrib/gp/models/gpr.py
+++ b/pyro/contrib/gp/models/gpr.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function
 
-from torch.distributions import constraints
+import torch
+import torch.distributions as torchdist
 from torch.nn import Parameter
 
 import pyro
@@ -8,6 +9,7 @@
 from pyro.contrib.gp.models.model import GPModel
 from pyro.contrib.gp.util import conditional
 from pyro.params import param_with_module_name
+from pyro.util import warn_if_nan
 
 
 class GPRegression(GPModel):
@@ -69,7 +71,7 @@ def __init__(self, X, y, kernel, noise=None, mean_function=None, jitter=1e-6,
 
         noise = self.X.new_ones(()) if noise is None else noise
         self.noise = Parameter(noise)
-        self.set_constraint("noise", constraints.greater_than(self.jitter))
+        self.set_constraint("noise", torchdist.constraints.greater_than(self.jitter))
 
     def model(self):
         self.set_mode("model")
@@ -141,3 +143,76 @@ def forward(self, Xnew, full_cov=False, noiseless=True):
             cov = cov + noise
 
         return loc + self.mean_function(Xnew), cov
+
+    def iter_sample(self, noiseless=True):
+        r"""
+        Iteratively constructs a sample from the Gaussian Process posterior.
+
+        Recall that at test input points :math:`X_{new}`, the posterior is
+        multivariate Gaussian distributed with mean and covariance matrix
+        given by :func:`forward`.
+
+        This method samples lazily from this multivariate Gaussian. The advantage
+        of this approach is that later query points can depend upon earlier ones.
+        This is particularly useful when the querying is to be done by an
+        optimisation routine.
+
+        .. note:: The noise parameter ``noise`` (:math:`\epsilon`), together with
+            the kernel's parameters, is assumed to have been learned from a
+            training procedure (MCMC or SVI).
+
+        :param bool noiseless: A flag to decide if we want to add sampling noise
+            to the samples beyond the noise inherent in the GP posterior.
+        :returns: sampler
+        :rtype: function
+        """
+        noise = self.guide().detach()
+        X = self.X.clone().detach()
+        y = self.y.clone().detach()
+        N = X.shape[0]
+        Kff = self.kernel(X).contiguous()
+        Kff.view(-1)[::N + 1] += noise  # add noise to the diagonal
+
+        outside_vars = {"X": X, "y": y, "N": N, "Kff": Kff}
+
+        def sample_next(xnew, outside_vars):
+            """Repeatedly samples from the Gaussian process posterior,
+            conditioning on previously sampled values.
+            """
+            warn_if_nan(xnew)
+
+            # Variables from outer scope
+            X, y, Kff = outside_vars["X"], outside_vars["y"], outside_vars["Kff"]
+
+            # Compute Cholesky decomposition of kernel matrix
+            Lff = Kff.potrf(upper=False)
+            y_residual = y - self.mean_function(X)
+
+            # Compute conditional mean and variance
+            loc, cov = conditional(xnew, X, self.kernel, y_residual, None, Lff,
+                                   False, jitter=self.jitter)
+            if not noiseless:
+                cov = cov + noise
+
+            ynew = torchdist.Normal(loc + self.mean_function(xnew), cov.sqrt()).rsample()
+
+            # Update kernel matrix
+            N = outside_vars["N"]
+            Kffnew = Kff.new_empty(N+1, N+1)
+            Kffnew[:N, :N] = Kff
+            cross = self.kernel(X, xnew).squeeze()
+            end = self.kernel(xnew, xnew).squeeze()
+            Kffnew[N, :N] = cross
+            Kffnew[:N, N] = cross
+            # No noise, just jitter for numerical stability
+            Kffnew[N, N] = end + self.jitter
+            # Heuristic to avoid adding degenerate points
+            if Kffnew.logdet() > -15.:
+                outside_vars["Kff"] = Kffnew
+                outside_vars["N"] += 1
+                outside_vars["X"] = torch.cat((X, xnew))
+                outside_vars["y"] = torch.cat((y, ynew))
+
+            return ynew
+
+        return lambda xnew: sample_next(xnew, outside_vars)
diff --git a/pyro/infer/svi.py b/pyro/infer/svi.py
index 121c1a7528..529938f462 100644
--- a/pyro/infer/svi.py
+++ b/pyro/infer/svi.py
@@ -53,7 +53,7 @@ def __init__(self,
         if loss_and_grads is None:
             def _loss_and_grads(*args, **kwargs):
                 loss_val = loss(*args, **kwargs)
-                loss_val.backward()
+                loss_val.backward(retain_graph=True)
                 return loss_val
             loss_and_grads = _loss_and_grads
         self.loss = loss
diff --git a/pyro/optim/__init__.py b/pyro/optim/__init__.py
index 4cd0be4745..a7c4486c26 100644
--- a/pyro/optim/__init__.py
+++ b/pyro/optim/__init__.py
@@ -4,6 +4,7 @@
 from pyro.optim.optim import AdagradRMSProp, ClippedAdam, PyroOptim
 from pyro.optim.pytorch_optimizers import __all__ as pytorch_optims
 from pyro.optim.pytorch_optimizers import *  # noqa F403
+import pyro.optim.multi  # noqa F403
 
 __all__ = [
     "AdagradRMSProp",
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 048d66a947..0964fd831e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -19,7 +19,7 @@
     ['contrib/autoname/mixture.py', '--num-epochs=1'],
     ['contrib/autoname/tree_data.py', '--num-epochs=1'],
     ['contrib/gp/sv-dkl.py', '--epochs=1', '--num-inducing=4'],
-    ['contrib/oed/ab_test.py', '--num-steps=500'],
+    ['contrib/oed/ab_test.py', '--num-vi-steps=1000', '--num-acquisitions=2'],
     ['dmm/dmm.py', '--num-epochs=1'],
     ['dmm/dmm.py', '--num-epochs=1', '--num-iafs=1'],
     ['eight_schools/mcmc.py', '--num-samples=500', '--warmup-steps=100'],
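
For reference, the sketch below (not part of the patch) shows how the new GPBayesOptimizer is intended to be driven; it mirrors the loop in examples/contrib/oed/ab_test.py above. The GP, kernel, constraint and ELBO calls are exactly the ones this diff introduces; the toy objective, its domain and the seed points are made up purely for illustration.

# Illustrative usage sketch only -- assumes gp_bayes_opt.py is importable
# (it lives next to ab_test.py) and uses the same calls as that example.
import torch
from torch.distributions import constraints

import pyro
import pyro.contrib.gp as gp
from pyro.infer import TraceEnum_ELBO

from gp_bayes_opt import GPBayesOptimizer


def objective(x):
    # Hypothetical stand-in for estimated_ape / true_ape; minimised at x = 30.
    return (x - 30.)**2 / 100.


pyro.clear_param_store()
# Seed the GP emulator with two evaluations, as ab_test.py does with [25., 75.]
X = torch.tensor([25., 75.])
y = objective(X)
gpmodel = gp.models.GPRegression(
    X, y, gp.kernels.Matern52(input_dim=1, lengthscale=torch.tensor(5.)),
    noise=torch.tensor(0.1), jitter=1e-6)
gpmodel.optimize(loss=TraceEnum_ELBO(strict_enumeration_warning=False).differentiable_loss)

gpbo = GPBayesOptimizer(constraints.interval(0, 100), gpmodel, num_acquisitions=2)
for _ in range(6):
    # Each step Thompson-samples acquisition points, evaluates the objective
    # there, refits the GP, and returns the current minimiser of the posterior mean.
    xmin, val = gpbo.get_step(objective, None)
print(xmin, val)

As in ab_test.py, the second argument of get_step is passed as None here; the new GPBayesOptimizer.get_step ignores its params argument.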