diff --git a/discrete_mixed_bo/experiment_utils.py b/discrete_mixed_bo/experiment_utils.py index 813e5a7..c623d9d 100644 --- a/discrete_mixed_bo/experiment_utils.py +++ b/discrete_mixed_bo/experiment_utils.py @@ -716,8 +716,8 @@ def get_problem(name: str, dim: Optional[int] = None, **kwargs) -> DiscreteTestP def generate_discrete_options( - base_function: DiscreteTestProblem, -) -> List[Dict[int, float]]: + base_function: DiscreteTestProblem, return_tensor: bool = False, +) -> Union[List[Dict[int, float]], Tensor]: categorical_features = base_function.categorical_features discrete_indices = torch.cat( [base_function.integer_indices, base_function.categorical_indices], dim=0 @@ -764,4 +764,6 @@ def generate_discrete_options( start_idx = end_idx # create a list of dictionaries of mapping indices to values # the list has a dictionary for each discrete configuration + if return_tensor: + return discrete_options.to(base_function.bounds) return [dict(zip(indices, xi)) for xi in discrete_options.tolist()] diff --git a/discrete_mixed_bo/run_one_replication.py b/discrete_mixed_bo/run_one_replication.py index 689aae9..2549bea 100644 --- a/discrete_mixed_bo/run_one_replication.py +++ b/discrete_mixed_bo/run_one_replication.py @@ -67,7 +67,7 @@ "exact_round__fin_diff__nehvi-1", "exact_round__ste__nehvi-1", "enumerate__nehvi-1", - "nevergrad_porfolio", + "nevergrad_portfolio", ] @@ -405,7 +405,7 @@ def run_one_replication( # construct a list of dictionaries mapping indices in one-hot space # to parameter values. discrete_options = generate_discrete_options( - base_function=base_function, + base_function=base_function, return_tensor=base_function.cont_indices.shape[0]==0, ) if base_function.cont_indices.shape[0] > 0: # optimize mixed