Skip to content

Commit

Permalink
Make random_search() return a CompilerEnv.
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisCummins committed Sep 10, 2021
1 parent a8408e8 commit 977485c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
25 changes: 15 additions & 10 deletions compiler_gym/bin/random_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,23 +105,28 @@ def make_env():
with make_env() as env:
env.reset()

best_reward, _ = random_search(
env = random_search(
make_env=make_env,
outdir=Path(FLAGS.output_dir) if FLAGS.output_dir else None,
patience=FLAGS.patience,
total_runtime=FLAGS.runtime,
nproc=FLAGS.nproc,
skip_done=FLAGS.skip_done,
)

# Exit with error if --fail_threshold was set and the best reward does not
# meet this value.
if FLAGS.fail_threshold is not None and best_reward < FLAGS.fail_threshold:
print(
f"Best reward {best_reward:.3f} below threshold of {FLAGS.fail_threshold}",
file=sys.stderr,
)
sys.exit(1)
try:
# Exit with error if --fail_threshold was set and the best reward does not
# meet this value.
if (
FLAGS.fail_threshold is not None
and env.episode_reward < FLAGS.fail_threshold
):
print(
f"Best reward {env.episode_reward:.3f} below threshold of {FLAGS.fail_threshold}",
file=sys.stderr,
)
sys.exit(1)
finally:
env.close()


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions compiler_gym/random_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path
from threading import Thread
from time import sleep, time
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Union

import humanize

Expand Down Expand Up @@ -106,7 +106,7 @@ def random_search(
patience: int = 0,
nproc: int = cpu_count(),
skip_done: bool = False,
) -> Tuple[float, List[int]]:
) -> CompilerEnv:
with make_env() as env:
env.reset()
if not isinstance(env.unwrapped, CompilerEnv):
Expand Down Expand Up @@ -251,4 +251,4 @@ def random_search(
env.reset()
replay_actions(env, best_action_names, outdir)

return best_returns, best_actions
return env
3 changes: 2 additions & 1 deletion tests/random_search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ def test_random_search_smoke_test():
with tempfile.TemporaryDirectory() as tmp:
outdir = Path(tmp)
set_command_line_flags(["argv0"])
random_search(
env = random_search(
make_env=make_env,
outdir=outdir,
patience=50,
total_runtime=3,
nproc=1,
skip_done=False,
)
env.close()

assert (outdir / "random_search.json").is_file()
assert (outdir / "random_search_progress.csv").is_file()
Expand Down

0 comments on commit 977485c

Please sign in to comment.