[ANTBO]: support pre-computed embeddings
AntGro committed Jan 15, 2025
1 parent 45d8011 commit 2acf5d7
Showing 19 changed files with 947 additions and 808 deletions.
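The headline feature is that pre-computed embeddings can now be used as GP inputs. A hedged usage sketch, assuming per-sequence embeddings (e.g. from a protein language model) are passed as train_x together with one of the embedding kernels ('rbfBERT' or 'cosine-BERT') visible in the gp.py diff below; the import path and the return value being the fitted model are assumptions:

    import torch

    from bo.gp import train_gp  # assumed import path

    # Pre-computed embeddings, one row per sequence (shapes illustrative).
    embeddings = torch.randn(32, 768, dtype=torch.float64)
    targets = torch.randn(32, dtype=torch.float64)

    # 'rbfBERT' applies a min-max input transform before the kernel (see gp.py below).
    model = train_gp(train_x=embeddings, train_y=targets, use_ard=False,
                     num_steps=100, kern='rbfBERT')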
10 changes: 6 additions & 4 deletions AntBO/bo/botask.py
@@ -1,14 +1,17 @@
 import numpy as np
+import torch
+
 from bo.base import TestFunction
 from task.tools import Absolut, Manual, TableFilling, RandomBlackBox
-import torch
 
+
 class BOTask(TestFunction):
     """
     BO Task Class
     """
+    # this should be changed if we are tackling a mixed or continuous problem, for example
     problem_type = 'categorical'
 
     def __init__(self,
                  device,
                  n_categories,
@@ -33,8 +36,7 @@ def __init__(self,
         elif self.bbox['tool'] == 'random':
             self.fbox = RandomBlackBox(self.bbox)
         else:
-            assert 0,f"{self.bbox['tool']} Not Implemented"
-
+            assert 0, f"{self.bbox['tool']} Not Implemented"
 
     def compute(self, x):
         '''
@@ -48,4 +50,4 @@ def idx_to_seq(self, x):
         seqs = []
         for seq in x:
             seqs.append(''.join(self.fbox.idx_to_AA[int(aa)] for aa in seq))
-        return seqs
\ No newline at end of file
+        return seqs
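For reference, a minimal sketch of the conversion idx_to_seq performs, using a hypothetical idx_to_AA mapping (in BOTask the real mapping comes from the black-box tool stored in self.fbox):

    # Hypothetical index-to-amino-acid mapping; BOTask gets the real one from self.fbox.
    idx_to_AA = {0: 'A', 1: 'C', 2: 'D'}

    def idx_to_seq(x):
        # Each row of x is a sequence of category indices; join the matching letters.
        return [''.join(idx_to_AA[int(aa)] for aa in seq) for seq in x]

    print(idx_to_seq([[0, 1, 2], [2, 2, 0]]))  # ['ACD', 'DDA']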
2 changes: 1 addition & 1 deletion AntBO/bo/config.yaml
@@ -8,7 +8,7 @@ seq_len: 11
 normalise: True
 batch_size: 1
 save_path: './results' # Put path where you want results to be stored '/ABS/PATH/TO/results/'
-kernel_type: 'transformed_overlap'
+kernel_type: 'transformed_overlap' # mat52
 noise_variance: 1e-6
 resume: ???
 search_strategy: 'local' # local (CASMO), 'batch_local' (CASMO with NSGA con), 'global' (NSGA), or glocal (HEBO variant)
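The new comment documents 'mat52' as an alternative value for kernel_type. A rough sketch of consuming this config, assuming it is read with PyYAML (the actual loading code lives elsewhere in the BO pipeline):

    import yaml

    # Load the BO configuration; path assumed relative to the repository root.
    with open('AntBO/bo/config.yaml') as f:
        config = yaml.safe_load(f)

    # 'transformed_overlap' by default; 'mat52' selects the Matern-5/2 kernel added in this commit.
    print(config['kernel_type'])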
12 changes: 6 additions & 6 deletions AntBO/bo/custom_init.py
@@ -19,22 +19,22 @@ def get_top_cut_ratio_per_cat(top_cut_ratio_loosers: int, top_cut_ratio_mascotte
 
 class InitialBODataset:
 
-    def __init__(self, data: pd.DataFrame):
+    def __init__(self, data: pd.DataFrame) -> None:
         self.data = data
 
-    def get_categories(self):
+    def get_categories(self) -> np.ndarray:
         return self.data['Type'].values
 
-    def get_index_encoded_x(self):
+    def get_index_encoded_x(self) -> np.ndarray:
         return np.vstack(self.data['AA to ind'].values)
 
-    def get_protein_names(self):
+    def get_protein_names(self) -> pd.Series:
         return self.data['Protein']
 
-    def get_protein_binding_energy(self):
+    def get_protein_binding_energy(self) -> pd.Series:
         return self.data['Binding Energy']
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.data)
 
 
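A minimal usage sketch for InitialBODataset, assuming a toy DataFrame with the column names the getters above rely on ('Type', 'AA to ind', 'Protein', 'Binding Energy') and that the class is importable from this module:

    import numpy as np
    import pandas as pd

    from bo.custom_init import InitialBODataset  # assumed import path

    # Toy frame with the columns used by the getters above.
    data = pd.DataFrame({
        'Type': ['looser', 'mascotte'],
        'AA to ind': [np.array([0, 1, 2]), np.array([2, 0, 1])],
        'Protein': ['P1', 'P2'],
        'Binding Energy': [-1.2, -3.4],
    })
    dataset = InitialBODataset(data)
    print(len(dataset))                         # 2
    print(dataset.get_index_encoded_x().shape)  # (2, 3)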
54 changes: 24 additions & 30 deletions AntBO/bo/gp.py
@@ -1,32 +1,32 @@
 from __future__ import annotations
 
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Any
 
+import numpy as np
 import torch.nn.functional
 from botorch.fit import fit_gpytorch_model
 from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
 from gpytorch.constraints import Interval
 from gpytorch.distributions import MultivariateNormal
-from gpytorch.kernels import ScaleKernel, RBFKernel, CosineKernel
-from gpytorch.likelihoods import GaussianLikelihood
+from gpytorch.kernels import ScaleKernel, RBFKernel, CosineKernel, MaternKernel
+from gpytorch.likelihoods import GaussianLikelihood, Likelihood
 from gpytorch.means import ConstantMean
 from gpytorch.mlls import ExactMarginalLogLikelihood
 from gpytorch.models import ExactGP
 
-import numpy as np
-
 from bo import CategoricalOverlap, TransformedCategorical, OrdinalKernel, FastStringKernel
 from bo.kernels import BERTWarpRBF, BERTWarpCosine
+from bo.localbo_utils import SEARCH_STRATS
 
 
-def identity(x):
+def identity(x: Any) -> Any:
     return x
 
 
 class GP(ExactGP):
-    def __init__(self, train_x, train_y, likelihood,
-                 outputscale_constraint=None, ard_dims=None, kern=None, MeanMod=ConstantMean, cat_dims=None,
+    def __init__(self, train_x: torch.Tensor, train_y: torch.Tensor, likelihood: Likelihood,
+                 outputscale_constraint=None, ard_dims=None, kern=None, mean_mode=ConstantMean, cat_dims=None,
                  batch_shape=torch.Size(), transform_inputs=None):
         if transform_inputs is None:
             transform_inputs = identity
@@ -37,15 +37,15 @@ def __init__(self, train_x, train_y, likelihood,
         self.dim = train_x.shape[1]
         self.ard_dims = ard_dims
         self.cat_dims = cat_dims
-        self.mean_module = MeanMod(batch_shape=batch_shape)
+        self.mean_module = mean_mode(batch_shape=batch_shape)
         self.covar_module = ScaleKernel(kern, outputscale_constraint=outputscale_constraint, batch_shape=batch_shape)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> MultivariateNormal:
         mean_x = self.mean_module(x)
         covar_x = self.covar_module(x)
         return MultivariateNormal(mean_x, covar_x)
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> MultivariateNormal:
         return super().__call__(*[self.transform_inputs(input_point) for input_point in args], **kwargs)
 
     def dmu_dphi(self, num_cats: int, xs: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -92,7 +92,7 @@ def ag_ev_phi(self, num_cats: int, dmu_dphi: torch.Tensor = None, xs: torch.Tens
         Parameters
         ----------
         num_cats: number of categories
-        dmu_dphi: matrix of partial derivatives d mu / d phi of shape (n_points, n_dim, n_categories) --> compute it if None
+        dmu_dphi: matrix of partial derivatives d mu / d phi of shape (n_pts, n_dim, n_cats) --> compute it if None
         xs: points for which derivatives have been computed --> assume it is the training points of the GP if None
         n_samples_threshold: if number of samples having feature phi_ij is less than this threshold, AG_ij will be nan
@@ -126,19 +126,13 @@ def ag_ev_phi(self, num_cats: int, dmu_dphi: torch.Tensor = None, xs: torch.Tens
     return ag_phi, ev_phi
 
 
-def train_gp(train_x, train_y, use_ard, num_steps, kern='transformed_overlap', hypers: dict | None =None,
-             noise_variance=None,
-             cat_configs=None,
-             antigen=None,
-             search_strategy='local',
-             acq='EI',
-             num_samples=51,
-             warmup_steps=102,
-             thinning=1,
-             max_tree_depth=6,
-             **params):
-    """Fit a GP model where train_x is in [0, 1]^d and train_y is standardized.
+def train_gp(train_x: torch.Tensor, train_y: torch.Tensor, use_ard: bool, num_steps: int,
+             kern: str = 'transformed_overlap', hypers: Optional[dict] = None, noise_variance: Optional[float] = None,
+             cat_configs=None, antigen: Optional[str] = None, search_strategy: SEARCH_STRATS = 'local', **params):
+    """
+    Fit a GP model where train_x is in [0, 1]^d and train_y is standardized.
     (train_x, train_y): pairs of x and y (trained)
     noise_variance: if provided, this value will be used as the noise variance for the GP model. Otherwise, the noise
     variance will be inferred from the model.
     """
@@ -167,8 +161,7 @@ def train_gp(train_x, train_y, use_ard, num_steps, kern='transformed_overlap', h
 
     outputscale_constraint = Interval(0.5, 5.)
 
-    likelihood = GaussianLikelihood(noise_constraint=noise_constraint).to(device=train_x.device,
-                                                                          dtype=train_y.dtype)
+    likelihood = GaussianLikelihood(noise_constraint=noise_constraint).to(device=train_x.device, dtype=train_y.dtype)
 
     ard_dims = train_x.shape[1] if use_ard else None
     transform_inputs = None
@@ -208,11 +201,13 @@ def train_gp(train_x, train_y, use_ard, num_steps, kern='transformed_overlap', h
     if kern in ['rbfBERT', "cosine-BERT"]:
         min_x = train_x.min(0)[0]
         max_x = train_x.max(0)[0]
-        transform_inputs = lambda input_x: (input_x - min_x.to(input_x)) / (
-                max_x.to(input_x) - min_x.to(input_x) + 1e-8)
 
+        def transform_inputs(input_x: torch.Tensor) -> torch.Tensor:
+            return (input_x - min_x.to(input_x)) / (max_x.to(input_x) - min_x.to(input_x) + 1e-8)
     elif kern == 'rbf':
         kernel = RBFKernel(lengthscale_constraint=lengthscale_constraint, ard_num_dims=ard_dims)
+    elif kern == "mat52":
+        kernel = MaternKernel(nu=2.5, lengthscale_constraint=lengthscale_constraint, ard_num_dims=ard_dims)
     else:
         raise ValueError('Unknown kernel choice %s' % kern)
 
@@ -223,7 +218,7 @@ def train_gp(train_x, train_y, use_ard, num_steps, kern='transformed_overlap', h
         kern=kernel,
         outputscale_constraint=outputscale_constraint,
         ard_dims=ard_dims,
-        transform_inputs=transform_inputs
+        transform_inputs=transform_inputs,
     ).to(device=train_x.device, dtype=train_x.dtype)
 
     # Find optimal model hyperparameters
@@ -257,7 +252,6 @@ def train_gp(train_x, train_y, use_ard, num_steps, kern='transformed_overlap', h
         output = model(train_x, )
         loss = -mll(output, train_y).float()
         loss.backward()
-        # print(f"Loss Step {i} = {loss.item()}")
         optimizer.step()
 
     # Switch to eval mode
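A hedged usage sketch of the reworked train_gp with the new 'mat52' kernel; the data shapes and step count are illustrative, the import path is an assumption, and the return value is assumed to be the fitted model:

    import torch

    from bo.gp import train_gp  # assumed import path

    # Illustrative inputs: 20 points in [0, 1]^8 with standardized targets.
    train_x = torch.rand(20, 8, dtype=torch.float64)
    train_y = torch.randn(20, dtype=torch.float64)

    # kern='mat52' now selects a Matern-5/2 kernel (with ARD when use_ard=True).
    model = train_gp(train_x=train_x, train_y=train_y, use_ard=True,
                     num_steps=100, kern='mat52')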
57 changes: 30 additions & 27 deletions AntBO/bo/kernels.py
@@ -1,6 +1,5 @@
 # Implementation of various kernels
 
-import numpy as np
 import torch
 from gpytorch.constraints import Interval
 from gpytorch.kernels import Kernel
@@ -176,13 +175,17 @@ def __init__(self, seq_length: int, alphabet_size: int, gap_decay=.5, match_deca
         for i in range(self.maxlen - 1):
             self.exp[i, i + 1:] = torch.arange(self.maxlen - i - 1)
 
-    def K_diag(self, X: Tensor):
+        self.symmetric = None
+        self.D = None
+
+    @staticmethod
+    def K_diag(x: torch.Tensor) -> torch.Tensor:
         r"""
         The diagonal elements of the string kernel are always unity (due to normalisation)
         """
-        return torch.ones(X.shape[:-1], dtype=torch.double)
+        return torch.ones(x.shape[:-1], dtype=torch.double)
 
-    def forward(self, X1, X2, diag=False, last_dim_is_batch=False, **params):
+    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
         r"""
         Vectorized kernel calc.
         Following notation from Beck (2017), i.e have tensors S,D,Kpp,Kp
@@ -196,43 +199,43 @@ def forward(self, X1, X2, diag=False, last_dim_is_batch=False, **params):
         # pad until all have length of self.maxlen
         if diag:
             raise ValueError()
-        if X2 is None:
-            X2 = X1
+        if x2 is None:
+            x2 = x1
             self.symmetric = True
         else:
             self.symmetric = False
         # keep track of original input sizes
-        X1_shape = X1.shape[0]
-        X2_shape = X2.shape[0]
+        x1_shape = x1.shape[0]
+        x2_shape = x2.shape[0]
 
         # prep the decay tensor D
-        self.D = self._precalc().to(X1)
+        self.D = self._precalc().to(x1)
 
         # turn into one-hot i.e. shape (# strings, #characters+1, alphabet size)
-        X1 = torch.nn.functional.one_hot(X1.to(int), self.alphabet_size).to(X1)
-        X2 = torch.nn.functional.one_hot(X2.to(int), self.alphabet_size).to(X2)
+        x1 = torch.nn.functional.one_hot(x1.to(int), self.alphabet_size).to(x1)
+        x2 = torch.nn.functional.one_hot(x2.to(int), self.alphabet_size).to(x2)
 
-        # get indicies of all possible pairings from X and X2
+        # get indices of all possible pairings from x1 and x2
         # this way allows maximum number of kernel calcs to be squished onto the GPU (rather than just doing individual rows of gram)
-        indicies_2, indicies_1 = torch.meshgrid(torch.arange(0, X2.shape[0]), torch.arange(0, X1.shape[0]))
+        indicies_2, indicies_1 = torch.meshgrid(torch.arange(0, x2.shape[0]), torch.arange(0, x1.shape[0]))
         indicies = torch.cat([torch.reshape(indicies_1.T, (-1, 1)), torch.reshape(indicies_2.T, (-1, 1))], axis=1)
 
         # if symmetric then only calc upper matrix (fill in rest later)
         if self.symmetric:
             indicies = indicies[indicies[:, 1] >= indicies[:, 0]]
 
-        X1_full = torch.repeat_interleave(X1.unsqueeze(0), len(indicies), dim=0)[
+        x1_full = torch.repeat_interleave(x1.unsqueeze(0), len(indicies), dim=0)[
             np.arange(len(indicies)), indicies[:, 0]]
-        X2_full = torch.repeat_interleave(X2.unsqueeze(0), len(indicies), dim=0)[
+        x2_full = torch.repeat_interleave(x2.unsqueeze(0), len(indicies), dim=0)[
             np.arange(len(indicies)), indicies[:, 1]]
 
         if not self.symmetric:
             # also need to calculate some extra kernel evals for the normalization terms
-            X1_full = torch.cat([X1_full, X1, X2], 0)
-            X2_full = torch.cat([X2_full, X1, X2], 0)
+            x1_full = torch.cat([x1_full, x1, x2], 0)
+            x2_full = torch.cat([x2_full, x1, x2], 0)
 
         # Make S: the similarity tensor of shape (# strings, #characters, # characters)
-        S = torch.matmul(X1_full, torch.transpose(X2_full, 1, 2))
+        S = torch.matmul(x1_full, torch.transpose(x2_full, 1, 2))
 
         # store squared match coef
         match_sq = self.match_decay ** 2
@@ -253,9 +256,9 @@ def forward(self, X1, X2, diag=False, last_dim_is_batch=False, **params):
         # put results into the right places in the gram matrix and normalize
         if self.symmetric:
             # if symmetric then only put in top triangle (inc diag)
-            mask = torch.triu(torch.ones((X1_shape, X2_shape)), 0).to(S)
+            mask = torch.triu(torch.ones((x1_shape, x2_shape)), 0).to(S)
             non_zero = mask > 0
-            k_results = torch.zeros((X1_shape, X2_shape)).to(S)
+            k_results = torch.zeros((x1_shape, x2_shape)).to(S)
             k_results[non_zero] = k.squeeze()
             # add in missing elements (lower diagonal)
             k_results = k_results + k_results.T - torch.diag(k_results.diag())
@@ -270,15 +273,15 @@ def forward(self, X1, X2, diag=False, last_dim_is_batch=False, **params):
 
             # COULD SPEED THIS UP FOR PREDICTIONS, AS MANY NORM TERMS ALREADY IN GRAM
 
-            X_diag_Ks = k[X1_shape * X2_shape:X1_shape * X2_shape + X1_shape].flatten()
+            X_diag_Ks = k[x1_shape * x2_shape:x1_shape * x2_shape + x1_shape].flatten()
 
-            X2_diag_Ks = k[-X2_shape:].flatten()
+            x2_diag_Ks = k[-x2_shape:].flatten()
 
-            k = k[0:X1_shape * X2_shape]
-            k_results = k.reshape(X1_shape, X2_shape)
+            k = k[0:x1_shape * x2_shape]
+            k_results = k.reshape(x1_shape, x2_shape)
 
             # normalise
-            norm = torch.matmul(X_diag_Ks[:, None], X2_diag_Ks[None, :])
+            norm = torch.matmul(X_diag_Ks[:, None], x2_diag_Ks[None, :])
             k_results = torch.divide(k_results, torch.sqrt(norm))
 
             return k_results
@@ -316,7 +319,7 @@ def forward(self, x1, x2, diag=False, **params):
     import numpy as np
    import matplotlib.pyplot as plt
 
-    x1 = torch.tensor([[13., 4.],
+    x1_ = torch.tensor([[13., 4.],
                        [43., 15.],
                        [32., 19.],
                        [41., 9.],
@@ -339,7 +342,7 @@ def forward(self, x1, x2, diag=False, **params):
 
     o = OrdinalKernel(config=[51, 51])
     o.lengthscale = 1.
-    K = o.forward(x1, x1).detach().numpy()
+    K = o.forward(x1_, x1_).detach().numpy()
     plt.imshow(K)
     plt.colorbar()
     plt.show()
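A rough sketch of exercising FastStringKernel directly, assuming the constructor arguments visible in the hunk header above (seq_length, alphabet_size) and integer-encoded sequences; per K_diag, the normalised Gram matrix should have a unit diagonal:

    import torch

    from bo.kernels import FastStringKernel  # assumed import path

    # Two integer-encoded sequences of length 11 over a 20-letter alphabet.
    kern = FastStringKernel(seq_length=11, alphabet_size=20)
    x = torch.randint(0, 20, (2, 11)).double()

    K = kern.forward(x, None)  # x2=None triggers the symmetric code path
    print(torch.diag(K))       # expected to be (close to) ones after normalisation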