Skip to content

Rendering Issue while training #94

Open
@brunoleej

Description

First of all, thank you for providing such a wonderful environment.

I am currently trying to implement the RecordVideoWrapper in the Navix environment.

I want to render a video while training.

Based on the tutorial in this repository, it seems that recording a video requires passing `observation_fn=nx.observations.rgb`. If I omit this option, the observation is instead a grid of shape (height, width, 3).

My question is: how can I record a video of my agent during training in the Navix environment?

Below are the results with and without `nx.observations.rgb`, together with the code I used.

  1. With `env = nx.make(env_name, observation_fn=nx.observations.rgb)`:
    [Attached image: d82c8073-a8b2-451c-b733-c1eaf1876189]

  2. Without specifying `observation_fn`:

class NavixGymnaxWrapper:
    """Expose a Navix environment through the gymnax-style reset/step API.

    The Navix timestep object doubles as the environment state: ``reset``
    returns ``(observation, timestep)`` and ``step`` returns
    ``(observation, timestep, reward, done, info)``.
    """

    def __init__(self, env_name):
        # Underlying Navix environment, built with its default observation_fn.
        self._env = nx.make(env_name)

    def reset(self, key, params=None):
        """Reset the environment. ``params`` is accepted for API parity and ignored."""
        first = self._env.reset(key)
        return first.observation, first

    def step(self, key, state, action, params=None):
        """Advance one step from ``state`` (a Navix timestep).

        ``key`` and ``params`` are unused: Navix's ``step`` is called with only
        the state and the action. ``info`` is always an empty dict.
        """
        nxt = self._env.step(state, action)
        obs = nxt.observation
        reward = nxt.reward
        done = nxt.is_done()
        return obs, nxt, reward, done, {}

    def observation_space(self, params):
        """Mirror the Navix observation spec as a gymnax ``spaces.Box``."""
        spec = self._env.observation_space
        return spaces.Box(
            low=spec.minimum,
            high=spec.maximum,
            shape=spec.shape,
            dtype=spec.dtype,
        )

    def action_space(self, params):
        """Expose the Navix action set as ``spaces.Discrete`` over 0..maximum."""
        n_actions = self._env.action_space.maximum.item() + 1
        return spaces.Discrete(num_categories=n_actions)

@struct.dataclass
class VideoEnvState:
    """Wrapped environment state plus an in-state buffer of recorded frames.

    Built and updated by RecordVideoWrapper.reset/step.
    """

    env_state: environment.EnvState  # state returned by the wrapped (inner) environment
    frames: chex.Array  # frame buffer; RecordVideoWrapper allocates it as (max_episode_length, H, W, C)
    frame_counter: int  # index of the next buffer slot to write; inside the jitted step this is a JAX scalar, not a Python int

class RecordVideoWrapper(GymnaxWrapper):
    """Record every observation into a fixed-size, JIT-compatible frame buffer.

    Observations returned by ``reset`` and ``step`` are written into
    ``VideoEnvState.frames`` (shape ``(max_episode_length, H, W, C)``). After
    copying the buffer to the host, :meth:`save_video` writes it out as a GIF.

    Notes:
        * The frames are the wrapped environment's *observations*, not a
          rendering of the environment state. Unless the observations are RGB
          images (e.g. Navix built with ``observation_fn=nx.observations.rgb``),
          the GIF will not look like the environment.
          TODO(review): render from ``env_state`` to record RGB video while
          training on a different observation type.
        * ``frame_counter`` is never reset, including across episode
          boundaries. Once ``max_episode_length`` frames have been written,
          further frames are dropped.
        * The buffer lives inside the environment state. If the environment
          is vmapped, every copy carries its own buffer.
    """

    def __init__(self, env: environment.Environment, max_episode_length: int = 1000,
                 video_dir: str = "./videos", video_name: str = "episode"):
        """
        Args:
            env: environment to wrap; must follow the gymnax reset/step API.
            max_episode_length: number of frame slots pre-allocated in the buffer.
            video_dir: directory used by ``save_video`` when no path is given;
                created if it does not exist.
            video_name: file stem used by ``save_video`` when no path is given.
        """
        super().__init__(env)
        self.max_episode_length = max_episode_length
        self.video_dir = video_dir
        self.video_name = video_name
        # exist_ok=True avoids the race between an existence check and makedirs.
        os.makedirs(self.video_dir, exist_ok=True)

    @staticmethod
    def _as_frame(obs: chex.Array) -> chex.Array:
        """Return ``obs`` as an (H, W, C) frame; 2-D observations get a channel axis.

        Raises:
            ValueError: if ``obs`` is neither 2-D nor 3-D.
        """
        # ndim is static under jax.jit, so Python branching is fine here.
        if obs.ndim == 2:
            return obs[..., None]
        if obs.ndim == 3:
            return obs
        raise ValueError(
            "RecordVideoWrapper expects 2-D (H, W) or 3-D (H, W, C) observations, "
            f"got shape {obs.shape}"
        )

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
              ) -> Tuple[chex.Array, VideoEnvState]:
        """Reset the wrapped env, allocate a zeroed buffer and store the first frame.

        Returns the unmodified observation and a VideoEnvState with
        ``frame_counter`` set to 1.
        """
        obs, env_state = self._env.reset(key, params)
        frame = self._as_frame(obs)
        frames = jnp.zeros((self.max_episode_length, *frame.shape), dtype=frame.dtype)
        frames = frames.at[0].set(frame, mode="drop")
        state = VideoEnvState(
            env_state=env_state,
            frames=frames,
            # Explicit int32 so reset and step produce the same counter dtype.
            frame_counter=jnp.asarray(1, dtype=jnp.int32),
        )
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(self, key: chex.PRNGKey, state: VideoEnvState, action: Union[int, float],
             params: Optional[environment.EnvParams] = None
             ) -> Tuple[chex.Array, VideoEnvState, float, bool, dict]:
        """Step the wrapped env and append the new observation to the buffer.

        Reward, done and info are passed through unchanged.
        """
        obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params)
        # mode="drop": once the buffer is full, writes are discarded explicitly
        # instead of relying on JAX's default out-of-bounds handling.
        frames = state.frames.at[state.frame_counter].set(self._as_frame(obs), mode="drop")
        state = VideoEnvState(
            env_state=env_state,
            frames=frames,
            frame_counter=state.frame_counter + 1,
        )
        return obs, state, reward, done, info

    @staticmethod
    def _to_rgb(frame: np.ndarray) -> np.ndarray:
        """Expand a grayscale (H, W) or (H, W, 1) frame to 3 channels; leave others as-is.

        np.repeat is used instead of ``cv2.cvtColor(..., COLOR_GRAY2RGB)``,
        which rejects integer dtypes such as int32. No dtype conversion is
        done here.
        """
        if frame.ndim == 2:
            frame = frame[..., None]
        if frame.ndim == 3 and frame.shape[2] == 1:
            frame = np.repeat(frame, 3, axis=2)
        return frame

    def save_video(self, frames: np.ndarray, video_path: Optional[str] = None,
                   num_frames: Optional[int] = None):
        """Write ``frames`` to disk as a 30 fps GIF.

        Args:
            frames: frames of shape (H, W) or (H, W, C), e.g. a host copy of
                ``VideoEnvState.frames``.
            video_path: destination file. Defaults to
                ``<video_dir>/<video_name>.gif``.
            num_frames: if given, only the first ``num_frames`` frames are
                written. Pass ``VideoEnvState.frame_counter`` to skip the
                unused, zero-filled slots at the end of the buffer. By default
                every frame is written.
        """
        if video_path is None:
            video_path = os.path.join(self.video_dir, f"{self.video_name}.gif")
        if num_frames is not None:
            frames = frames[: int(num_frames)]
        if len(frames) == 0:
            return
        frames_list = [self._to_rgb(np.array(frame)) for frame in frames]
        imageio.mimsave(video_path, frames_list, format='GIF', fps=30)
        print(f"GIF is saved: {video_path}")
    
# NOTE(review): `env_name` is not defined in the visible code; it must be set
# (to a Navix environment id) before this point.
# Uses NavixGymnaxWrapper, the class defined above; `NavixGymnaxWrapper2` is not defined here.
env, env_params = NavixGymnaxWrapper(env_name), None
env = LogWrapper(env)
env = SparseRewardWrapper(env)
# FlattenObservationWrapper wraps the recorder, so the recorder sees the
# observations before any flattening and can store them as (H, W, C) frames.
env = RecordVideoWrapper(env, max_episode_length=1000, video_dir='./videos', video_name='episode1')
env = FlattenObservationWrapper(env)

[Attachment: episode1]

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    • Status

      TODOs

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions