#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import warnings

import torch
from botorch.acquisition import ExpectedImprovement, qExpectedImprovement
from botorch.exceptions.warnings import OptimizationWarning
from botorch.fit import fit_gpytorch_mll
from botorch.models import FixedNoiseGP, SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.utils.testing import BotorchTestCase
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood


EPS = 1e-8

NOISE = [
    [0.127],
    [-0.113],
    [-0.345],
    [-0.034],
    [-0.069],
    [-0.272],
    [0.013],
    [0.056],
    [0.087],
    [-0.081],
]
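

# End-to-end test: fit a SingleTaskGP and a FixedNoiseGP to noisy sin(2 * pi * x)
# observations, then optimize the EI / qEI acquisition functions over [0, 1].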
class TestEndToEnd(BotorchTestCase):
    def _setUp(self, double=False):
        dtype = torch.double if double else torch.float
        # Training data: noisy observations of sin(2 * pi * x) at 10 points in [0, 1].
        train_x = torch.linspace(0, 1, 10, device=self.device, dtype=dtype).view(
            -1, 1
        )
        train_y = torch.sin(train_x * (2 * math.pi))
        train_yvar = torch.tensor(0.1**2, device=self.device, dtype=dtype)
        noise = torch.tensor(NOISE, device=self.device, dtype=dtype)
        self.train_x = train_x
        self.train_y = train_y + noise
        self.train_yvar = train_yvar
        self.bounds = torch.tensor([[0.0], [1.0]], device=self.device, dtype=dtype)
        # Fit a SingleTaskGP (observation noise inferred from the data).
        model_st = SingleTaskGP(self.train_x, self.train_y)
        self.model_st = model_st.to(device=self.device, dtype=dtype)
        self.mll_st = ExactMarginalLogLikelihood(
            self.model_st.likelihood, self.model_st
        )
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=OptimizationWarning)
            self.mll_st = fit_gpytorch_mll(
                self.mll_st,
                optimizer_kwargs={"options": {"maxiter": 5}},
                max_attempts=1,
            )
        # Fit a FixedNoiseGP (fixed, known observation noise).
        model_fn = FixedNoiseGP(
            self.train_x, self.train_y, self.train_yvar.expand_as(self.train_y)
        )
        self.model_fn = model_fn.to(device=self.device, dtype=dtype)
        self.mll_fn = ExactMarginalLogLikelihood(
            self.model_fn.likelihood, self.model_fn
        )
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=OptimizationWarning)
            self.mll_fn = fit_gpytorch_mll(
                self.mll_fn,
                optimizer_kwargs={"options": {"maxiter": 5}},
                max_attempts=1,
            )
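
    # qEI with q=3 should return candidates that respect the [0, 1] bounds
    # (up to EPS) for both the SingleTaskGP and the FixedNoiseGP model.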
    def test_qEI(self):
        for double in (True, False):
            self._setUp(double=double)
            qEI = qExpectedImprovement(self.model_st, best_f=0.0)
            candidates, _ = optimize_acqf(
                acq_function=qEI,
                bounds=self.bounds,
                q=3,
                num_restarts=10,
                raw_samples=20,
                options={"maxiter": 5},
            )
            self.assertTrue(torch.all(-EPS <= candidates))
            self.assertTrue(torch.all(candidates <= 1 + EPS))
            qEI = qExpectedImprovement(self.model_fn, best_f=0.0)
            candidates, _ = optimize_acqf(
                acq_function=qEI,
                bounds=self.bounds,
                q=3,
                num_restarts=10,
                raw_samples=20,
                options={"maxiter": 5},
            )
            self.assertTrue(torch.all(-EPS <= candidates))
            self.assertTrue(torch.all(candidates <= 1 + EPS))
            candidates_batch_limit, _ = optimize_acqf(
                acq_function=qEI,
                bounds=self.bounds,
                q=3,
                num_restarts=10,
                raw_samples=20,
                options={"maxiter": 5, "batch_limit": 5},
            )
            self.assertTrue(torch.all(-EPS <= candidates_batch_limit))
            self.assertTrue(torch.all(candidates_batch_limit <= 1 + EPS))
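
    # Analytic EI (q=1) should likewise return a candidate inside [0, 1]
    # (up to EPS) for both models.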
    def test_EI(self):
        for double in (True, False):
            self._setUp(double=double)
            EI = ExpectedImprovement(self.model_st, best_f=0.0)
            candidates, _ = optimize_acqf(
                acq_function=EI,
                bounds=self.bounds,
                q=1,
                num_restarts=10,
                raw_samples=20,
                options={"maxiter": 5},
            )
            self.assertTrue(-EPS <= candidates <= 1 + EPS)
            EI = ExpectedImprovement(self.model_fn, best_f=0.0)
            candidates, _ = optimize_acqf(
                acq_function=EI,
                bounds=self.bounds,
                q=1,
                num_restarts=10,
                raw_samples=20,
                options={"maxiter": 5},
            )
            self.assertTrue(-EPS <= candidates <= 1 + EPS)
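

# A minimal sketch of how this module could be run locally, assuming pytest or
# the standard unittest runner is available (the runner used by the upstream
# BoTorch CI is not shown in this file):
#   pytest test_end_to_end.py
#   python -m unittest test_end_to_end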