Open
Description
First of all, thank you for providing such a wonderful environment.
I am currently trying to implement the RecordVideoWrapper in the Navix environment.
I want to render a video while training.
Based on the tutorial in this repository, it seems we need to set `observation_fn` to `nx.observations.rgb`. However, if I remove this option, the state that is returned contains a grid of size (height, width, 3).
I would like to ask: how can I record a video of my training in the Navix environment?
I have attached the differences between using nx.observations.rgb and not using this option.
class NavixGymnaxWrapper:
    """Expose a Navix environment through a ``reset(key, params)`` /
    ``step(key, state, action, params)`` calling convention.

    The timestep object returned by the wrapped environment is handed back
    to the caller as the environment state.
    """

    def __init__(self, env_name):
        """Create the wrapped environment via ``nx.make(env_name)``."""
        self._env = nx.make(env_name)

    def reset(self, key, params=None):
        """Reset the wrapped environment.

        Args:
            key: Forwarded to the wrapped environment's ``reset``.
            params: Accepted for signature compatibility; not used.

        Returns:
            A ``(observation, timestep)`` pair.
        """
        timestep = self._env.reset(key)
        return timestep.observation, timestep

    def step(self, key, state, action, params=None):
        """Advance the wrapped environment by one action.

        Args:
            key: Accepted for signature compatibility; not used.
            state: Timestep previously returned by ``reset`` or ``step``.
            action: Action forwarded to the wrapped environment.
            params: Accepted for signature compatibility; not used.

        Returns:
            ``(observation, timestep, reward, done, info)`` where ``info``
            is always an empty dict.
        """
        timestep = self._env.step(state, action)
        return (
            timestep.observation,
            timestep,
            timestep.reward,
            timestep.is_done(),
            {},
        )

    def observation_space(self, params=None):
        """Return a ``spaces.Box`` built from the wrapped observation space.

        ``params`` is unused; it defaults to ``None`` to match ``reset`` and
        ``step``.
        """
        space = self._env.observation_space
        return spaces.Box(
            low=space.minimum,
            high=space.maximum,
            shape=space.shape,
            dtype=space.dtype,
        )

    def action_space(self, params=None):
        """Return a ``spaces.Discrete`` with ``maximum + 1`` categories.

        ``params`` is unused; it defaults to ``None`` to match ``reset`` and
        ``step``.
        """
        return spaces.Discrete(
            num_categories=self._env.action_space.maximum.item() + 1,
        )
@struct.dataclass
class VideoEnvState:
    """State carried by RecordVideoWrapper: the inner state plus recorded frames."""

    # State of the wrapped environment.
    env_state: environment.EnvState
    # Frame buffer; RecordVideoWrapper.reset allocates it with shape
    # (max_episode_length, H, W, C).
    frames: chex.Array
    # Index of the next slot in `frames` that will be written.
    frame_counter: int
class RecordVideoWrapper(GymnaxWrapper):
    """Wrapper that copies every observation into a fixed-size frame buffer.

    The buffer lives in ``VideoEnvState.frames`` and is allocated in
    ``reset`` with ``max_episode_length`` slots. ``save_video`` writes a
    sequence of frames to disk as a GIF.

    NOTE(review): ``frame_counter`` is only initialised in ``reset`` and is
    incremented on every ``step``; nothing here caps it at
    ``max_episode_length``. Writes past the end of the buffer are presumably
    dropped by JAX's out-of-bounds update semantics -- confirm.
    """

    def __init__(self, env: environment.Environment, max_episode_length: int = 1000,
                 video_dir: str = "./videos", video_name: str = "episode"):
        """Store the recording settings and make sure ``video_dir`` exists.

        Args:
            env: Environment to wrap.
            max_episode_length: Number of frame slots allocated per episode.
            video_dir: Directory created if it does not exist yet.
            video_name: Stored on the instance; not used by this class itself.
        """
        super().__init__(env)
        self.max_episode_length = max_episode_length
        self.video_dir = video_dir
        self.video_name = video_name
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.video_dir, exist_ok=True)

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
              ) -> Tuple[chex.Array, VideoEnvState]:
        """Reset the wrapped environment and start a fresh frame buffer.

        Returns:
            ``(obs, state)`` where ``state.frames[0]`` holds the first
            observation and ``state.frame_counter`` is 1.
        """
        obs, env_state = self._env.reset(key, params)
        # Observations that are not 3-D get a trailing channel axis of size 1.
        height, width, channels = obs.shape if obs.ndim == 3 else (*obs.shape, 1)
        frames = jnp.zeros(
            (self.max_episode_length, height, width, channels), dtype=obs.dtype
        )
        # Reshape before storing: an (H, W) observation does not broadcast
        # into an (H, W, 1) slot on its own.
        frames = frames.at[0].set(obs.reshape((height, width, channels)))
        state = VideoEnvState(env_state=env_state, frames=frames, frame_counter=1)
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(self, key: chex.PRNGKey, state: VideoEnvState, action: Union[int, float],
             params: Optional[environment.EnvParams] = None
             ) -> Tuple[chex.Array, VideoEnvState, float, bool, dict]:
        """Step the wrapped environment and record the new observation.

        Returns:
            ``(obs, state, reward, done, info)`` as produced by the wrapped
            environment, with ``state`` replaced by an updated
            ``VideoEnvState``.
        """
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        # Match the per-frame shape chosen in reset (adds the channel axis
        # for observations that lack one).
        frame = obs.reshape(state.frames.shape[1:])
        frames = state.frames.at[state.frame_counter].set(frame)
        state = VideoEnvState(
            env_state=env_state,
            frames=frames,
            frame_counter=state.frame_counter + 1,
        )
        return obs, state, reward, done, info

    def save_video(self, frames: np.ndarray, video_path: str):
        """Write ``frames`` to ``video_path`` as a GIF at 30 fps.

        Does nothing when ``frames`` is empty. Single-channel frames (2-D, or
        3-D with a last axis of size 1) are converted to RGB first.
        """
        if len(frames) == 0:
            return
        frames_list = []
        for frame in frames:
            frame = np.array(frame)
            if frame.ndim == 2 or (frame.ndim == 3 and frame.shape[2] == 1):
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            frames_list.append(frame)
        imageio.mimsave(video_path, frames_list, format='GIF', fps=30)
        print(f"GIF is saved: {video_path}")
# Base environment; no separate parameter object is used.
env = NavixGymnaxWrapper2(env_name)
env_params = None

# Wrappers are applied innermost-first, so RecordVideoWrapper is applied
# before FlattenObservationWrapper.
env = LogWrapper(env)
env = SparseRewardWrapper(env)
env = RecordVideoWrapper(
    env,
    max_episode_length=1000,
    video_dir='./videos',
    video_name='episode1',
)
env = FlattenObservationWrapper(env)
Metadata
Assignees
Labels
No labels
Projects
Status
TODOs