Add Bayesian Optimisation to AB testing example #1250

Merged: 19 commits, Jul 25, 2018

Changes from 1 commit
Tidy up BO for OED
Adam Foster committed Jul 25, 2018
commit 523362b9ad43e5f68b20c6ba802fe199bc74e73c
131 changes: 90 additions & 41 deletions examples/contrib/oed/gp_bayes_opt.py
@@ -12,19 +12,37 @@


class GPBayesOptimizer:
Member:
We should be able to refactor this a bit into a pyro.optim.multi.MultiOptimizer so that it looks more like other Pyro optimizers. That way, when we experiment with other optimizers (e.g. gradient-based ones), it'll be easy to swap them out.

Contributor Author:
I think we're best to talk this over together; we need to implement a step function which would replace my run function. Would we also need to refactor the example to use pyro.param? Right now, it's not really true that after a step, the params are updated to new near-optimal values. But obviously I could add self.opt_differentiable(lambda x: self.gpmodel(x)[0]) into each step (minimize the current GP mean function).
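(Editor's sketch, to make the shape of that refactor concrete: a hypothetical step() method for this class, reusing acquire, update_posterior and opt_differentiable. Illustrative only, not the committed design.)

    def step(self, num_acquisitions=1):
        """One round of acquisition and posterior update; returns the current
        minimiser of the GP mean and its value, so each step ends with a
        usable near-optimal estimate, as discussed above."""
        X = self.acquire(num_acquisitions=num_acquisitions)
        y = self.f(X)
        self.update_posterior(X, y)
        # Minimise the current GP posterior mean to report a running best guess
        return self.opt_differentiable(lambda x: self.gpmodel(x)[0])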

"""Performs Bayesian Optimization using a Gaussian Process as an
emulator for the unknown function.
Member:
Once the code is a little more settled down, you should expand this docstring and add a couple usage examples
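(Editor's sketch of the kind of usage example that could go in the docstring, assuming pyro.contrib.gp's RBF kernel and GPRegression and an interval constraint; the toy objective and data below are made up for illustration.)

    import torch
    from torch.distributions import constraints
    import pyro.contrib.gp as gp

    def f(x):
        # Toy objective with a minimum at x = 30
        return (x - 30.)**2 / 100.

    X = torch.tensor([5., 50., 95.])   # initial design points
    y = f(X)
    gpmodel = gp.models.GPRegression(X, y, gp.kernels.RBF(input_dim=1),
                                     noise=torch.tensor(0.1))
    bo = GPBayesOptimizer(f, constraints.interval(0., 100.), gpmodel)
    xmin, fmin = bo.run(num_steps=10, num_acquisitions=2)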

"""

def __init__(self, f, constraints, gpmodel):
"""
:param function f: the objective function which should accept `torch.Tensor`
inputs
:param torch.constraint constraints: constraints defining the domain of `f`
:param gp.models.GPRegression gpmodel: a (possibly initialized) GP
regression model. The kernel, etc is specified via `gpmodel`.
"""
self.f = f
self.constraints = constraints
self.gpmodel = gpmodel

def update_posterior(self, gpmodel, X, y):
X = torch.cat([gpmodel.X, X])
y = torch.cat([gpmodel.y, y])
gpmodel.set_data(X, y)
gpmodel.optimize()
def update_posterior(self, X, y):
X = torch.cat([self.gpmodel.X, X])
y = torch.cat([self.gpmodel.y, y])
self.gpmodel.set_data(X, y)
self.gpmodel.optimize()
Contributor Author:
Not very pyronic

Member:
Can we do some GP magic to make these updates cheaper than recomputing the full posterior?

Contributor Author:
Yes, it depends how much magic you want. Here https://github.com/uber/pyro/blob/dev/pyro/contrib/gp/models/gpr.py#L80 we could have cached Kff and just added the kernel computations we need. One step further is to use the magic of Schur complements https://en.wikipedia.org/wiki/Schur_complement to avoid having to invert the whole kernel matrix. I actually started implementing a Schur complement method for iter_sample, but in the end the existing solution works and reuses more existing code. In either case, we would have to refactor the internals of GPRegression.
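(Editor's sketch of the block-update idea, written against the modern torch.linalg API rather than the potrf/trtrs calls contrib.gp used at the time: given the Cholesky factor of the old Kff, the factor of the enlarged matrix can be grown from the Schur complement of the new block instead of refactorising from scratch.)

    import torch

    def extend_cholesky(L, B, D):
        """Given L with L @ L.T == K (N x N), cross-covariance B (N x m) and new
        block D (m x m), return the Cholesky factor of [[K, B], [B.T, D]]
        without refactorising K. Illustrative sketch only."""
        # C = B.T @ inv(L.T), via a triangular solve rather than an explicit inverse
        C = torch.linalg.solve_triangular(L, B, upper=False).T
        # Cholesky factor of the (small) Schur complement D - B.T @ inv(K) @ B
        S = torch.linalg.cholesky(D - C @ C.T)
        top = torch.cat([L, L.new_zeros(L.shape[0], D.shape[0])], dim=1)
        bottom = torch.cat([C, S], dim=1)
        return torch.cat([top, bottom], dim=0)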

Member:
Let's open a separate issue about this; it seems like generally useful functionality for contrib.gp.

Member (@fehiepsi, Jul 21, 2018):
I guess it is not expensive for Bayesian Optimization because num_data in BO is small. When num_data is large, I guess the most expensive operation is Cholesky(Kff).

Contributor Author:
I think the other important factor is how you plan to add points to the GP: all at once or drip by drip. In the all-at-once setting, the current code is optimal: you want to avoid explicitly inverting Kff, so using Cholesky plus trtrs is best. On the other hand, if you are going drip by drip, you can invert the small kernel matrices formed from the new points, and then get the overall precision matrix using Schur complements. But yeah, I'll open an issue about this.
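(Editor's sketch of the all-at-once path for reference, again using current torch.linalg names instead of potrf/trtrs: solve against the Cholesky factor rather than forming inv(Kff) explicitly.)

    import torch

    def gp_weights(Kff, y, jitter=1e-6):
        """Return alpha = inv(Kff) @ y using a Cholesky factorisation and
        triangular solves, never forming the inverse explicitly. Sketch only."""
        L = torch.linalg.cholesky(Kff + jitter * torch.eye(Kff.shape[0], dtype=Kff.dtype))
        # cholesky_solve performs two triangular solves against L
        return torch.cholesky_solve(y.unsqueeze(-1), L).squeeze(-1)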

def find_a_candidate(self, sampler, x_init):
def find_a_candidate(self, differentiable, x_init):
"""Given a starting point, `x_init`, takes one LBFGS step
to optimize the differentiable function.

:param function differentiable: a function amenable to torch
autograd
:param torch.Tensor x_init: the initial point

"""
# transform x to an unconstrained domain
unconstrained_x_init = transform_to(self.constraints).inv(x_init)
unconstrained_x = torch.tensor(unconstrained_x_init, requires_grad=True)
Member:
nit: use new_tensor here
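(Editor's note on the nit, with the variable names from the diff above: new_tensor copies the data while inheriting dtype and device from the source tensor.)

    unconstrained_x_init = transform_to(self.constraints).inv(x_init)
    # Copies unconstrained_x_init into a fresh leaf tensor with the same dtype/device
    unconstrained_x = unconstrained_x_init.new_tensor(unconstrained_x_init,
                                                      requires_grad=True)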

@@ -35,23 +53,29 @@ def closure():
if (torch.log(torch.abs(unconstrained_x)) > 15.).any():
return torch.tensor(float('inf'))
x = transform_to(self.constraints)(unconstrained_x)
y = sampler(x)
autograd.backward(unconstrained_x, autograd.grad(y, unconstrained_x, retain_graph=True))
# print(x, y, unconstrained_x, unconstrained_x.grad)
y = differentiable(x)
autograd.backward(unconstrained_x,
autograd.grad(y, unconstrained_x, retain_graph=True))
return y

# print("Starting off minimizer from", unconstrained_x, x_init)
minimizer.step(closure)
# after finding a candidate in the unconstrained domain,
# convert it back to original domain.
x = transform_to(self.constraints)(unconstrained_x)
opt_y = sampler(x)
if torch.isnan(opt_y).any():
print('opt_y', opt_y)
print('x', x)
opt_y = differentiable(x)
return x.detach(), opt_y.detach()

def opt_differentiable(self, differentiable, num_candidates=5):
"""Optimizes a differentiable function by choosing `num_candidates`
initial points at random and calling :func:`find_a_candidate` on
each. The best candidate is returned with its function value.

:param function differentiable: a function amenable to torch autograd
:param int num_candidates: the number of random starting points to
use
:return: the minimiser and its function value
:rtype: tuple
"""

x_init = self.gpmodel.X[-1:].new_empty(1).uniform_(self.constraints.lower_bound,
self.constraints.upper_bound)
@@ -67,7 +91,18 @@ def opt_differentiable(self, differentiable, num_candidates=5):
mvalue, argmin = torch.min(torch.cat(values), dim=0)
return candidates[argmin.item()], mvalue

def acquire(self, method="Thompson", num_candidates=1, num_acquisitions=1):
def acquire(self, method="Thompson", num_acquisitions=1, **opt_params):
"""Selects `num_acquisitions` query points at which to query the
original function.

:param str method: the method to use for acquisition. Choose from:
Thompson
:param int num_acquisitions: the number of points to generate
:param dict opt_params: additional parameters for optimization
routines
:return: a tensor of points to evaluate `self.f` at
:rtype: torch.Tensor
"""

if method == "Thompson":
Member:
Let's split this out into an acquisition function parameter instead of implementing acquisition functions inside acquire, so it's easier to experiment with different Monte Carlo acquisition functions.
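(Editor's sketch of roughly what that might look like, with the acquisition strategy passed as a callable; the thompson_sample helper and the acquire signature shown here are hypothetical, not the committed code.)

    def thompson_sample(gpmodel):
        # One lazily-constructed posterior sample path, as the current code uses
        return gpmodel.iter_sample(noiseless=False)

    # (method of GPBayesOptimizer)
    def acquire(self, acquisition_fn=thompson_sample, num_acquisitions=1, **opt_params):
        X = self.gpmodel.X.new_empty(num_acquisitions, *self.gpmodel.X.shape[1:])
        for i in range(num_acquisitions):
            x, _ = self.opt_differentiable(acquisition_fn(self.gpmodel), **opt_params)
            X[i, ...] = x
        return X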


@@ -76,7 +111,7 @@ def acquire(self, method="Thompson", num_candidates=1, num_acquisitions=1):

for i in range(num_acquisitions):
sampler = self.gpmodel.iter_sample(noiseless=False)
x, _ = self.opt_differentiable(sampler, num_candidates=5)
x, _ = self.opt_differentiable(sampler, **opt_params)
X[i, ...] = x

return X
@@ -85,38 +120,52 @@ def acquire(self, method="Thompson", num_candidates=1, num_acquisitions=1):
raise NotImplementedError("Only method Thompson implemented for acquisition")

def run(self, num_steps, num_acquisitions):
plt.figure(figsize=(12, 30))
outer_gs = gridspec.GridSpec(num_steps, 1)
"""
Optimizes `self.f` in `num_steps` steps, acquiring `num_acquisitions`
new function evaluations at each step.

:param int num_steps: the number of optimization steps to run
:param int num_acquisitions: the number of points to acquire at each step
:return: the minimiser and the minimum value
:rtype: tuple
"""

for i in range(num_steps):
X = self.acquire(num_acquisitions=num_acquisitions)
y = self.f(X)
gs = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=outer_gs[i])
self.plot(gs, xlabel=i+1, with_title=(i % 2 == 0))
self.update_posterior(self.gpmodel, X, y)

plt.show()
self.update_posterior(X, y)

return self.opt_differentiable(lambda x: self.gpmodel(x)[0])

# def run(self, num_steps, num_acquisitions):
# plt.figure(figsize=(12, 30))
# outer_gs = gridspec.GridSpec(num_steps, 1)

def plot(self, gs, xlabel=None, with_title=True):
xlabel = "xmin" if xlabel is None else "x{}".format(xlabel)
Xnew = torch.linspace(-1., 101.)
ax1 = plt.subplot(gs[0])
ax1.plot(self.gpmodel.X.detach().numpy(), self.gpmodel.y.detach().numpy(), "kx") # plot all observed data
with torch.no_grad():
loc, var = self.gpmodel(Xnew, full_cov=False, noiseless=False)
sd = var.sqrt()
ax1.plot(Xnew.numpy(), loc.numpy(), "r", lw=2) # plot predictive mean
ax1.fill_between(Xnew.numpy(), loc.numpy() - 2*sd.numpy(), loc.numpy() + 2*sd.numpy(),
color="C0", alpha=0.3) # plot uncertainty intervals
ax1.set_xlim(-1, 101)
ax1.set_title("Find {}".format(xlabel))
if with_title:
ax1.set_ylabel("Gaussian Process Regression")



# for i in range(num_steps):
# X = self.acquire(num_acquisitions=num_acquisitions)
# y = self.f(X)
# gs = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=outer_gs[i])
# self.plot(gs, xlabel=i+1, with_title=(i % 2 == 0))
# self.update_posterior(self.gpmodel, X, y)

# plt.show()

# return self.opt_differentiable(lambda x: self.gpmodel(x)[0])


# def plot(self, gs, xlabel=None, with_title=True):
# xlabel = "xmin" if xlabel is None else "x{}".format(xlabel)
# Xnew = torch.linspace(-1., 101.)
# ax1 = plt.subplot(gs[0])
# ax1.plot(self.gpmodel.X.detach().numpy(), self.gpmodel.y.detach().numpy(), "kx") # plot all observed data
# with torch.no_grad():
# loc, var = self.gpmodel(Xnew, full_cov=False, noiseless=False)
# sd = var.sqrt()
# ax1.plot(Xnew.numpy(), loc.numpy(), "r", lw=2) # plot predictive mean
# ax1.fill_between(Xnew.numpy(), loc.numpy() - 2*sd.numpy(), loc.numpy() + 2*sd.numpy(),
# color="C0", alpha=0.3) # plot uncertainty intervals
# ax1.set_xlim(-1, 101)
# ax1.set_title("Find {}".format(xlabel))
# if with_title:
# ax1.set_ylabel("Gaussian Process Regression")
Contributor Author:
Commented code is unsightly but useful for debugging. test_examples will not admit matplotlib in example code.

Member:
Feel free to use local imports for plotting dependencies:

def plot(self, gs, ...):
    from matplotlib import pyplot as plt
    ...

That's not perfect, but it's better than commented-out code 😉

Member:
You should probably just add a separate test for this and remove the test code here

Contributor Author:
I'll remove the code, keep it on a local branch


38 changes: 19 additions & 19 deletions pyro/contrib/gp/models/gpr.py
@@ -144,64 +144,64 @@ def forward(self, Xnew, full_cov=False, noiseless=True):

return loc + self.mean_function(Xnew), cov

def iter_sample(self, full_cov=False, noiseless=True):
def iter_sample(self, noiseless=True):
r"""
Iteratively constructs a sample from the Gaussian Process posterior.

Recall that at test input points :math:`X_{new}`, the posterior is
multivariate Gaussian distributed with mean and covariance matrix
given by :func:`forward`.

This method samples lazily from multivariate Gaussian. The advantage
This method samples lazily from this multivariate Gaussian. The advantage
of this approach is that later query points can depend upon earlier ones.
Particularly useful when the querying is to be done by an optimisation
routine.

.. note:: The noise parameter ``noise`` (:math:`\epsilon`) together with
kernel's parameters have been learned from a training procedure (MCMC or
SVI).

:param bool noiseless: A flag to decide if we want to include noise in the
sample or not.
:param bool noiseless: A flag to decide if we want to add sampling noise
to the samples beyond the noise inherent in the GP posterior.
:returns: sampler
:rtype: function
"""
noise = self.guide().detach()
# Make these visible in the inner function
global X, y, Kff, N
Member:
Instead of using global here and below, how about making sample_next take these variables as arguments and returning a curried function with functools.partial?

def sample_next(X, y, Kff, N, xnew):
    ...

return functools.partial(sample_next, X, y, Kff, N)

Contributor Author:
The reason I didn't do that was that we change X inside the inner function (by writing X = Xnew). I thought scoping rules would mean that that change wasn't saved, but honestly I didn't actually try it

Member:
Will LBFGS add many samples to the globals "X", "Y" during its optimization for sampler?

Contributor Author:
Yes it will, but I passed max_eval=20 to cap this.
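(For reference: max_eval is a constructor argument of torch.optim.LBFGS, so the cap sits where the minimizer is built in find_a_candidate, roughly as below.)

    # Cap function evaluations per .step(closure) call, which also caps how many
    # points the lazily-sampled GP path accumulates during one candidate search
    minimizer = torch.optim.LBFGS([unconstrained_x], max_eval=20)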

Member:
You can also put these variables into a dictionary and mutate that instead of using global
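(Editor's note: a toy illustration of both alternatives to global for state that the inner function rebinds between calls: nonlocal, or a mutated dict as suggested here. The GP version would hold X, y, Kff and N instead of a counter.)

    # Option 1: nonlocal rebinds the enclosing function's variable directly
    def make_counter_nonlocal():
        n = 0
        def step():
            nonlocal n
            n += 1
            return n
        return step

    # Option 2: mutate a dict captured by the closure (no declaration needed)
    def make_counter_dict():
        state = {"n": 0}
        def step():
            state["n"] += 1
            return state["n"]
        return step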


noise = self.guide().detach()
X = self.X.clone().detach()
y = self.y.clone().detach()

N = X.shape[0]
Kff = self.kernel(X).contiguous()
Kff.view(-1)[::N + 1] += noise # add noise to the diagonal

def sample_next(xnew):
"""Repeatedly samples from the Gaussian process posterior,
conditioning on previously sampled values.
"""
if torch.isnan(xnew).any():
raise ValueError("Cannot evaluate GP at value: {}".format(xnew))
Member:
nit: you might want to use warn_if_nan https://github.com/uber/pyro/blob/dev/pyro/util.py#L49
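(Editor's sketch of that nit, assuming pyro.util.warn_if_nan keeps its value-and-message signature: it warns rather than raising when the check fires.)

    from pyro.util import warn_if_nan

    # Emits a UserWarning when xnew contains NaNs, instead of raising a ValueError
    warn_if_nan(xnew, "xnew passed to sample_next")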


# Variables from outer scope
global X, y, Kff, N
Lff = Kff.potrf(upper=False)

# Compute Cholesky decomposition of kernel matrix
Lff = Kff.potrf(upper=False)
y_residual = y - self.mean_function(X)

# Compute conditional mean and variance
loc, cov = conditional(xnew, X, self.kernel, y_residual, None, Lff,
False, jitter=self.jitter)

if not noiseless:
cov = cov + noise

# Todo use pyro.sample
d = normal.Normal(torch.tensor(0.), torch.tensor(1.))
# cov = torch.max(cov, torch.tensor(0.))
if torch.isnan(loc) or torch.isnan(cov) or torch.isnan(cov.sqrt()):
print('loc', loc)
print('cov', cov)
print('X', X)
print('xnew', xnew)
print('Kff', Kff)
print('LL^T', Lff.mm(Lff.t()))
print('logdet', Kff.logdet())
print('N', N)
raise
# Reparametrize explicitly - aids autograd
ynew = (loc + self.mean_function(xnew)) + d.sample()*cov.sqrt()
Contributor Author:
It would be nice to make this a true pyronic sampler if possible

Member:
Why can't you write torch.distributions.Normal(loc + self.mean_function(xnew), cov.sqrt()).rsample()?

Contributor Author:
Yes, I had it in my mind that that didn't work; it does.
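(Editor's note: the two forms give the same gradient path, since rsample() applies the loc + eps * scale reparameterisation that the current code writes out by hand. A minimal illustration:)

    import torch
    from torch.distributions import Normal

    loc = torch.tensor(1.0, requires_grad=True)
    scale = torch.tensor(2.0)

    # Manual reparameterisation, as in the current code
    eps = Normal(torch.tensor(0.), torch.tensor(1.)).sample()
    y_manual = loc + eps * scale

    # Same thing via rsample(): gradients still flow back to loc
    y_rsample = Normal(loc, scale).rsample()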


# Update kernel matrix
Kffnew = Kff.new_empty(N+1,N+1)
Kffnew[:N, :N] = Kff
cross = self.kernel(X, xnew).squeeze()