Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
docusaurus-bot committed Jul 28, 2022
1 parent 2402a28 commit deba329
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 4 additions & 2 deletions discrete_mixed_bo/experiment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,8 @@ def get_problem(name: str, dim: Optional[int] = None, **kwargs) -> DiscreteTestP


def generate_discrete_options(
base_function: DiscreteTestProblem,
) -> List[Dict[int, float]]:
base_function: DiscreteTestProblem, return_tensor: bool = False,
) -> Union[List[Dict[int, float]], Tensor]:
categorical_features = base_function.categorical_features
discrete_indices = torch.cat(
[base_function.integer_indices, base_function.categorical_indices], dim=0
Expand Down Expand Up @@ -764,4 +764,6 @@ def generate_discrete_options(
start_idx = end_idx
# create a list of dictionaries of mapping indices to values
# the list has a dictionary for each discrete configuration
if return_tensor:
return discrete_options.to(base_function.bounds)
return [dict(zip(indices, xi)) for xi in discrete_options.tolist()]
4 changes: 2 additions & 2 deletions discrete_mixed_bo/run_one_replication.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"exact_round__fin_diff__nehvi-1",
"exact_round__ste__nehvi-1",
"enumerate__nehvi-1",
"nevergrad_porfolio",
"nevergrad_portfolio",
]


Expand Down Expand Up @@ -405,7 +405,7 @@ def run_one_replication(
# construct a list of dictionaries mapping indices in one-hot space
# to parameter values.
discrete_options = generate_discrete_options(
base_function=base_function,
base_function=base_function, return_tensor=base_function.cont_indices.shape[0]==0,
)
if base_function.cont_indices.shape[0] > 0:
# optimize mixed
Expand Down

0 comments on commit deba329

Please sign in to comment.