Commit

[ANTBO]: update kernels and some extra files.
AntGro committed Oct 10, 2024
1 parent 15ef274 commit 62bc78c
Showing 9 changed files with 187 additions and 109 deletions.
4 changes: 3 additions & 1 deletion AntBO/bo/botask.py
@@ -1,6 +1,6 @@
 import numpy as np
 from bo.base import TestFunction
-from task.tools import Absolut, Manual, TableFilling
+from task.tools import Absolut, Manual, TableFilling, RandomBlackBox
 import torch

 class BOTask(TestFunction):
@@ -30,6 +30,8 @@ def __init__(self,
             self.fbox = Manual(self.bbox)
         elif self.bbox['tool'] == 'table_filling':
             self.fbox = TableFilling(self.bbox)
+        elif self.bbox['tool'] == 'random':
+            self.fbox = RandomBlackBox(self.bbox)
         else:
             assert 0,f"{self.bbox['tool']} Not Implemented"
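
Editorial aside: the growing if/elif chain above is the commit's dispatch point for black-box tools, and the same dispatch can be written as a lookup table. The sketch below is not part of the commit, and every registry key except 'table_filling' and 'random' is an assumption (the other branch conditions fall outside the hunk):

# Hypothetical table-driven equivalent of the tool dispatch in BOTask.__init__.
from task.tools import Absolut, Manual, TableFilling, RandomBlackBox

TOOL_REGISTRY = {
    'absolut': Absolut,            # assumed key name
    'manual': Manual,              # assumed key name
    'table_filling': TableFilling,
    'random': RandomBlackBox,      # branch added in this commit
}

def make_fbox(bbox: dict):
    # Same fallback as the original chain, raised as an exception
    # instead of `assert 0`.
    if bbox['tool'] not in TOOL_REGISTRY:
        raise NotImplementedError(f"{bbox['tool']} Not Implemented")
    return TOOL_REGISTRY[bbox['tool']](bbox)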

2 changes: 1 addition & 1 deletion AntBO/bo/config.yaml
@@ -7,7 +7,7 @@ device: 'cuda'
 seq_len: 11
 normalise: True
 batch_size: 1
-save_path: './results/BO_transformed_overlap/' # Put path where you want results to be stored '/ABS/PATH/TO/results/'
+save_path: './results' # Put path where you want results to be stored '/ABS/PATH/TO/results/'
 kernel_type: 'transformed_overlap'
 noise_variance: 1e-6
 resume: ???
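
For context on this setting: elsewhere in the repo the config is read with plain YAML (see the get_config helper in AntBO/bo/utils.py further down). A minimal loading sketch, with the path assumed:

import yaml

# Hedged sketch: read the BO config the way bo/utils.get_config does.
with open('AntBO/bo/config.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['save_path'])   # './results' after this commit
print(cfg['resume'])      # note: 'resume: ???' parses as the literal string '???'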
15 changes: 12 additions & 3 deletions AntBO/bo/gp.py
@@ -1,17 +1,23 @@
+from __future__ import annotations
+
 import warnings
+from typing import Optional, Tuple

 import torch.nn.functional
 from botorch.fit import fit_gpytorch_model
 from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
 from gpytorch.constraints import Interval
 from gpytorch.distributions import MultivariateNormal
-from gpytorch.kernels import ScaleKernel
+from gpytorch.kernels import ScaleKernel, RBFKernel, CosineKernel
 from gpytorch.likelihoods import GaussianLikelihood
 from gpytorch.means import ConstantMean
 from gpytorch.mlls import ExactMarginalLogLikelihood
 from gpytorch.models import ExactGP

-from bo.kernels import *
+import numpy as np
+
+from bo import CategoricalOverlap, TransformedCategorical, OrdinalKernel, FastStringKernel
+from bo.kernels import BERTWarpRBF, BERTWarpCosine


 def identity(x):
@@ -120,7 +126,7 @@ def ag_ev_phi(self, num_cats: int, dmu_dphi: torch.Tensor = None, xs: torch.Tens
     return ag_phi, ev_phi


-def train_gp(train_x, train_y, use_ard, num_steps, kern='transformed_overlap', hypers={},
+def train_gp(train_x, train_y, use_ard, num_steps, kern='transformed_overlap', hypers: dict | None = None,
              noise_variance=None,
              cat_configs=None,
              antigen=None,
@@ -140,6 +146,9 @@ def train_gp(train_x, train_y, use_ard, num_steps, kern='transformed_overlap', h
     assert train_y.ndim == 1
     assert train_x.shape[0] == train_y.shape[0]

+    if hypers is None:
+        hypers = {}
+
     device = train_x.device
     # Create hyper parameter bounds
     if noise_variance is None:
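
The hypers={} to hypers=None change above, paired with the new `if hypers is None` guard, is the standard fix for Python's mutable-default-argument pitfall: a default dict is created once at function definition time and shared by every call. A self-contained illustration with toy names, not AntBO code:

def train_bad(hypers={}):
    hypers.setdefault('noise_variance', 1e-6)  # mutates the one shared dict
    return hypers

assert train_bad() is train_bad()   # every call returns the same dict object

def train_good(hypers: dict | None = None):   # Python 3.10+ union syntax, as in the commit
    if hypers is None:                         # the pattern adopted by train_gp above
        hypers = {}
    hypers.setdefault('noise_variance', 1e-6)
    return hypers

assert train_good() is not train_good()  # fresh dict per call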
114 changes: 22 additions & 92 deletions AntBO/bo/kernels.py
@@ -1,11 +1,12 @@
 # Implementation of various kernels

+import numpy as np
+import torch
+from gpytorch.constraints import Interval
 from gpytorch.kernels import Kernel
 from gpytorch.kernels.cosine_kernel import CosineKernel
 from gpytorch.kernels.matern_kernel import MaternKernel
 from gpytorch.kernels.rbf_kernel import RBFKernel
-from gpytorch.constraints import Interval
-import torch
-import numpy as np
+from torch import Tensor


@@ -111,34 +112,6 @@ def mat52(d, ard):
             return torch.diag(k_cat).float()
         return k_cat.float()
-
-    def forward_one_hot(self, x1: torch.tensor, x2: torch.tensor, diag=False, last_dim_is_batch=False, exp='rbf'):
-        assert not last_dim_is_batch
-        assert x1.shape[-2:] == x2.shape[-2:], (x1.shape, x2.shape)
-
-        diff = x1[:, None] - x2[None, :]
-        assert diff.shape == (*x1.shape[:-2], *x2.shape[:-2], *x1.shape[-2:]), diff.shape
-        diff_per_var = diff.abs().sum(dim=-1).div(2)
-
-        def rbf(d, ard):
-            if ard:
-                return torch.exp(torch.sum(d * self.lengthscale, dim=-1) / torch.sum(self.lengthscale))
-            else:
-                return torch.exp(self.lengthscale * torch.sum(d, dim=-1) / x1.shape[1])
-
-        def mat52(d, ard):
-            raise NotImplementedError
-
-        if exp == 'rbf':
-            k_cat = rbf(diff_per_var, self.ard_num_dims is not None and self.ard_num_dims > 1)
-        elif exp == 'mat52':
-            k_cat = mat52(diff_per_var, self.ard_num_dims is not None and self.ard_num_dims > 1)
-        else:
-            raise ValueError('Exponentiation scheme %s is not recognised!' % exp)
-
-        if diag:
-            return torch.diag(k_cat).float()
-        return k_cat.float()


 class OrdinalKernel(Kernel):
     """
@@ -318,67 +291,24 @@ def _precalc(self):
         return torch.pow(self.gap_decay * self.tril, self.exp)


-# from transformers import pipeline,\
-#     AutoTokenizer, \
-#     Trainer, \
-#     AutoModel
-# import os
-# from einops import rearrange
-#
-# def batch_iterator(data1, data2, step=8):
-#     assert len(data1)==len(data2), "The data sets should be of same size"
-#     size = len(data1)
-#     for i in range(0, size, step):
-#         yield data1[i:min(i+step, size)], data2[i:min(i+step, size)]
-#
-#
-# BERT_config = {'path':'/nfs/aiml/asif/ProtBERT',
-#                'modelname':'prot_bert_bfd',
-#                'batch_size': 16}
-# BERT_device = 'cuda:1'
-# BERT_tokenizer = AutoTokenizer.from_pretrained(f"{BERT_config['path']}/{BERT_config['modelname']}")
-# model = AutoModel.from_pretrained(f"{BERT_config['path']}/{BERT_config['modelname']}").to(BERT_device)
-#
-# class BERTWarpRBF(RBFKernel):
-#     """Similar to above, but applied to RBF."""
-#
-#     def __init__(self, **kwargs):
-#         super(BERTWarpRBF, self).__init__(**kwargs)
-#         AAs = 'ACDEFGHIKLMNPQRSTVWY'
-#         self.AA_to_idx = {aa: i for i, aa in enumerate(AAs)}
-#         self.idx_to_AA = {value: key for key, value in self.AA_to_idx.items()}
-#
-#     def compute_features(self, x1, x2):
-#         with torch.no_grad():
-#             x1 = [" ".join(self.idx_to_AA[i.item()] for i in x_i) for x_i in x1]
-#             ids1 = BERT_tokenizer.batch_encode_plus(x1, add_special_tokens=False, padding=True)
-#             input_ids1 = torch.tensor(ids1['input_ids']).to(BERT_device)
-#             attention_mask1 = torch.tensor(ids1['attention_mask']).to(BERT_device)
-#             reprsn1 = model(input_ids=input_ids1, attention_mask=attention_mask1)[0]
-#
-#             x2 = [" ".join(self.idx_to_AA[i.item()] for i in x_i) for x_i in x2]
-#             ids2 = BERT_tokenizer.batch_encode_plus(x2, add_special_tokens=False, padding=True)
-#             input_ids2 = torch.tensor(ids2['input_ids']).to(BERT_device)
-#             attention_mask2 = torch.tensor(ids2['attention_mask']).to(BERT_device)
-#             reprsn2 = model(input_ids=input_ids2, attention_mask=attention_mask2)[0]
-#         return reprsn1, reprsn2
-#
-#     def forward(self, samples1, samples2, diag=False, **params):
-#         inp_device = samples1.device
-#         nm_samples = samples1.shape[0]
-#         if nm_samples>BERT_config['batch_size']:
-#             reprsn1, reprsn2 = [], []
-#             for x1, x2 in batch_iterator(samples1, samples2, BERT_config['batch_size']):
-#                 features1, features2 = self.compute_features(x1, x2)
-#                 reprsn1.append(features1)
-#                 reprsn2.append(features2)
-#             reprsn1 = torch.cat(reprsn1, 0)
-#             reprsn2 = torch.cat(reprsn2, 0)
-#         else:
-#             reprsn1, reprsn2 = self.compute_features(samples1, samples2)
-#         reprsn1 = rearrange(reprsn1, 'b l d -> b (l d)').to(inp_device)
-#         reprsn2 = rearrange(reprsn2, 'b l d -> b (l d)').to(inp_device)
-#         return super().forward(reprsn1, reprsn2, diag=diag, **params)
+class BERTWarpCosine(CosineKernel):
+    """Applied to Cosine."""
+
+    def __init__(self, **kwargs):
+        super(BERTWarpCosine, self).__init__(**kwargs)
+
+    def forward(self, x1, x2, diag=False, **params):
+        return super().forward(x1, x2, diag=diag, **params)
+
+
+class BERTWarpRBF(RBFKernel):
+    """Similar to above, but applied to RBF."""
+
+    def __init__(self, **kwargs):
+        super(BERTWarpRBF, self).__init__(**kwargs)
+
+    def forward(self, x1, x2, diag=False, **params):
+        return super().forward(x1, x2, diag=diag, **params)


 if __name__ == '__main__':
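
With the tokeniser-laden prototype deleted, BERTWarpCosine and BERTWarpRBF are now thin pass-throughs to their gpytorch parents, so any BERT feature extraction must happen before the kernel is called (the BERTFeatures helper used in localbo_cat.py below plays that role). A hedged usage sketch; the embedding width and the .evaluate() call are assumptions (newer gpytorch releases spell it .to_dense()):

import torch
from bo.kernels import BERTWarpRBF

x1 = torch.randn(8, 1024)       # hypothetical precomputed BERT embeddings
x2 = torch.randn(8, 1024)
kernel = BERTWarpRBF()
K = kernel(x1, x2).evaluate()   # dense (8, 8) covariance matrix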
1 change: 0 additions & 1 deletion AntBO/bo/localbo_cat.py
@@ -302,7 +302,6 @@ def _ei(X, augmented=True):
         # flip for minimization problems
         if self.kernel_type in ['rbfBERT', 'rbf-pca-BERT', 'cosine-BERT', 'cosine-pca-BERT']:
             from bo.utils import BERTFeatures
-            from einops import rearrange
             bert = BERTFeatures(self.BERT_model, self.BERT_tokeniser)
             x_reprsn = bert.compute_features(X.to(device))
             x_center_reprsn = bert.compute_features(torch.tensor(x_center[0].reshape(1, -1)))
6 changes: 1 addition & 5 deletions AntBO/bo/utils.py
@@ -1,5 +1,4 @@
 import numpy as np
-import os
 import pickle
 from typing import Any, Optional

@@ -73,9 +72,6 @@ def get_config(config):
         return yaml.safe_load(f)


-from einops import rearrange
-
-
 def batch_iterator(data1, step=8):
     size = len(data1)
     for i in range(0, size, step):
@@ -202,7 +198,7 @@ def update_table_of_candidates(original_table: np.ndarray, observed_candidates:


 def update_table_of_candidates_torch(original_table: torch.Tensor, observed_candidates: torch.Tensor,
-                                    check_candidates_in_table: bool) -> np.ndarray:
+                                     check_candidates_in_table: bool) -> np.ndarray:
     """ Update the table of candidates, removing the newly observed candidates from the table
     Args:
21 changes: 17 additions & 4 deletions AntBO/task/tools.py
@@ -1,10 +1,12 @@
 # import pymol
 import __main__
-import numpy as np
 import os
-import pandas as pd
 import subprocess
 import time
+
+import numpy as np
+import pandas as pd
+
 # from pymol import cmd
 from task.base import BaseTool

@@ -227,6 +229,11 @@ def Energy(self, x):
             sequences.append(seq2char)
             print(seq2char)

+        energies = self.get_energies(sequences=sequences)
+
+        return np.array(energies), sequences
+
+    def get_energies(self, sequences) -> list[float]:
         energies = []
         for i in range(len(sequences)):
             default = np.random.randn()
@@ -240,8 +247,14 @@ def Energy(self, x):
                 energy2 = float(custom_input(message=f"[{self.antigen}] Confirm energy for {sequences[i]}:",
                                              default=default))
             energies.append(energy1)
+        return energies

-        return np.array(energies), sequences
+
+class RandomBlackBox(Manual):
+    """ Suitable for quick debugging """
+
+    def get_energies(self, sequences) -> list[float]:
+        return [np.random.random() for _ in range(len(sequences))]


 class TableFilling(BaseTool):
@@ -281,7 +294,7 @@ def Energy(self, x):
             validations = np.ones(len(sequences))
         else:
             print(f"Saved candidates to evaluate in {self.path_to_eval_csv}")
-            values = [None for i in range(len(sequences))]
+            values = [None for _ in range(len(sequences))]
             validations = np.zeros(len(sequences))

         to_eval = pd.DataFrame(
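
Taken together, the tools.py hunks apply a small template-method refactor: Manual.Energy keeps the sequence decoding and delegates scoring to the new get_energies hook, which RandomBlackBox overrides with random values. A hypothetical further subclass, just to show the extension point (not in the commit):

from task.tools import Manual

class ConstantBlackBox(Manual):
    """Hypothetical smoke-test tool: every sequence scores 0.0."""

    def get_energies(self, sequences) -> list[float]:
        # Same contract as Manual/RandomBlackBox: one float per input sequence.
        return [0.0 for _ in sequences]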
(Diffs for the remaining 2 of the 9 changed files did not load.)
