Skip to content

Commit

Permalink
[ANTBO]: fix batch results report.
Browse files Browse the repository at this point in the history
  • Loading branch information
AntGro committed Jan 10, 2025
1 parent 0d846c8 commit 8c2b424
Showing 1 changed file with 14 additions and 21 deletions.
35 changes: 14 additions & 21 deletions AntBO/bo/main.py
Original file line number Diff line number Diff line change
@@ -193,27 +193,20 @@ def save(self, optim) -> None:
save_w_pickle(np.random.get_state(), self.np_rd_state_path)
save_w_pickle(random.getstate(), self.random_rd_state_path)

def results(self, optim, x, itern, rtime):
Y = np.array(optim.casmopolitan.fX)
if Y[:(itern + 1)].shape[0]:

# sequential
if self.config['batch_size'] == 1:
argmin = np.argmin(Y[:(itern + 1) * self.config['batch_size']])
x_best = ''.join([self.f_obj.fbox.idx_to_AA[j] for j in
optim.casmopolitan.X[:(itern + 1) * self.config['batch_size']][argmin].flatten()])
self.res.iloc[itern, :] = [itern, float(Y[-1]), float(np.min(Y[:(itern + 1)])), rtime,
self.f_obj.idx_to_seq(x)[0], x_best]
# batch
else:
for idx, j in enumerate(
range(itern * self.config['batch_size'], (itern + 1) * self.config['batch_size'])):
argmin = np.argmin(Y[:(j + 1) * self.config['batch_size']])
x_best = ''.join([self.f_obj.fbox.idx_to_AA[ind] for ind in
optim.casmopolitan.X[:(j + 1) * self.config['batch_size']][argmin].flatten()])
self.res.iloc[j, :] = [j, float(Y[-idx]), float(np.min(Y[:(j + 1) * self.config['batch_size']])),
rtime,
self.f_obj.idx_to_seq(x)[idx], x_best]
def results(self, optim: Optimizer, x: np.ndarray, itern: int, rtime: float) -> None:
y = np.array(optim.casmopolitan.fX)
if y[:itern + 1].shape[0] == 0:
return

antibodies = self.f_obj.idx_to_seq(x)

def add_best(step: int, y_val: float, last_protein: str) -> None:
argmin = np.argmin(y[:step + 1])
x_best = ''.join([self.f_obj.fbox.idx_to_AA[ind] for ind in optim.casmopolitan.X[argmin].flatten()])
self.res.iloc[step, :] = [step, y_val, float(np.min(y[:(step + 1)])), rtime, last_protein, x_best]

for idx, j in enumerate(range(itern * self.config['batch_size'], (itern + 1) * self.config['batch_size'])):
add_best(step=j, y_val=float(y[j]), last_protein=antibodies[idx])

def run(self):
random.seed(self.seed)

0 comments on commit 8c2b424

Please sign in to comment.