Skip to content

Commit

Permalink
Merge pull request #341 from bwasti/reward_reset
Browse files Browse the repository at this point in the history
Add ObservationView to reward reset() arguments
ChrisCummins authored Jul 27, 2021
2 parents 19047db + 5e168f4 commit f3dd7a8
Showing 5 changed files with 14 additions and 9 deletions.
2 changes: 1 addition & 1 deletion compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
@@ -793,7 +793,7 @@ def reset( # pylint: disable=arguments-differ
self.action_space.name, reply.new_action_space.action
)

self.reward.reset(benchmark=self.benchmark)
self.reward.reset(benchmark=self.benchmark, observation_view=self.observation)
if self.reward_space:
self.episode_reward = 0.0

7 changes: 4 additions & 3 deletions compiler_gym/envs/llvm/llvm_rewards.py
Original file line number Diff line number Diff line change
@@ -36,9 +36,10 @@ def __init__(self, cost_function: str, init_cost_function: str, **kwargs):
self.init_cost_function: str = init_cost_function
self.previous_cost: Optional[ObservationType] = None

def reset(self, benchmark: Benchmark) -> None:
def reset(self, benchmark: Benchmark, observation_view: ObservationView) -> None:
"""Called on env.reset(). Reset incremental progress."""
del benchmark # unused
del observation_view # unused
self.previous_cost = None

def update(
@@ -68,9 +69,9 @@ def __init__(self, **kwargs):
self.cost_norm: Optional[ObservationType] = None
self.benchmark: Benchmark = None

def reset(self, benchmark: str) -> None:
def reset(self, benchmark: str, observation_view: ObservationView) -> None:
"""Called on env.reset(). Reset incremental progress."""
super().reset(benchmark)
super().reset(benchmark, observation_view)
# The benchmark has changed so we must compute a new cost normalization
# value. If the benchmark has not changed then the previously computed
# value is still valid.
8 changes: 6 additions & 2 deletions compiler_gym/spaces/reward.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@

import numpy as np

import compiler_gym
from compiler_gym.spaces.scalar import Scalar
from compiler_gym.util.gym_type_hints import ObservationType, RewardType

@@ -88,12 +89,15 @@ def __init__(
self.deterministic = deterministic
self.platform_dependent = platform_dependent

def reset(self, benchmark: str) -> None:
def reset(
self, benchmark: str, observation_view: "compiler_gym.views.ObservationView"
) -> None:
"""Reset the rewards space. This is called on
:meth:`env.reset() <compiler_gym.envs.CompilerEnv.reset>`.
:param benchmark: The URI of the benchmark that is used for this
episode.
:param observation_view: An observation view for reward initialization.
"""
pass

@@ -153,7 +157,7 @@ def __init__(self, observation_name: str, **kwargs):
)
self.previous_value: Optional[ObservationType] = None

def reset(self, benchmark: str) -> None:
def reset(self, benchmark: str, observation_view) -> None:
"""Called on env.reset(). Reset incremental progress."""
del benchmark # unused
self.previous_value = None
4 changes: 2 additions & 2 deletions compiler_gym/views/reward.py
Original file line number Diff line number Diff line change
@@ -60,15 +60,15 @@ def __getitem__(self, reward_space: str) -> float:
observations = [self._observation_view[obs] for obs in space.observation_spaces]
return space.update(self.previous_action, observations, self._observation_view)

def reset(self, benchmark: Benchmark) -> None:
def reset(self, benchmark: Benchmark, observation_view: ObservationView) -> None:
"""Reset the rewards space view. This is called on
:meth:`env.reset() <compiler_gym.envs.CompilerEnv.reset>`.
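:param benchmark: The benchmark that is used for this episode.
:param observation_view: The observation view that is forwarded to the
reset() method of every registered reward space.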
"""
self.previous_action = None
for space in self.spaces.values():
space.reset(benchmark=benchmark)
space.reset(benchmark=benchmark, observation_view=observation_view)

def add_space(self, space: Reward) -> None:
"""Register a new :class:`Reward <compiler_gym.spaces.Reward>` space.
2 changes: 1 addition & 1 deletion examples/example_compiler_gym_service/__init__.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ def __init__(self):
)
self.previous_runtime = None

def reset(self, benchmark: str):
def reset(self, benchmark: str, observation_view):
del benchmark # unused
self.previous_runtime = None

0 comments on commit f3dd7a8

Please sign in to comment.