# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import torch.nn
import torch.optim
from torchrl.data import Composite
from torchrl.envs import (
    CatFrames,
    DoubleToFloat,
    EndOfLifeTransform,
    GrayScale,
    GymEnv,
    NoopResetEnv,
    Resize,
    RewardSum,
    SignTransform,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
    VecNorm,
)
from torchrl.modules import ConvNet, MLP, QValueActor
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
# --------------------------------------------------------------------


def make_env(env_name, frame_skip, device, is_test=False):
    env = GymEnv(
        env_name,
        frame_skip=frame_skip,
        from_pixels=True,
        pixels_only=False,
        device=device,
        categorical_action_encoding=True,
    )
    env = TransformedEnv(env)
    env.append_transform(NoopResetEnv(noops=30, random=True))
    # End-of-life termination and reward clipping are training-time tricks;
    # evaluation environments report the raw game signal.
    if not is_test:
        env.append_transform(EndOfLifeTransform())
        env.append_transform(SignTransform(in_keys=["reward"]))
    # Standard Atari preprocessing: 84x84 grayscale frames stacked 4 deep.
    env.append_transform(ToTensorImage())
    env.append_transform(GrayScale())
    env.append_transform(Resize(84, 84))
    env.append_transform(CatFrames(N=4, dim=-3))
    env.append_transform(RewardSum())
    env.append_transform(StepCounter(max_steps=4500))
    env.append_transform(DoubleToFloat())
    env.append_transform(VecNorm(in_keys=["pixels"]))
    return env


# ====================================================================
# Model utils
# --------------------------------------------------------------------


def make_dqn_modules_pixels(proof_environment, device):

    # Define input shape
    input_shape = proof_environment.observation_spec["pixels"].shape
    env_specs = proof_environment.specs
    num_actions = env_specs["input_spec", "full_action_spec", "action"].space.n
    action_spec = env_specs["input_spec", "full_action_spec", "action"]

    # Define Q-Value Module
    cnn = ConvNet(
        activation_class=torch.nn.ReLU,
        num_cells=[32, 64, 64],
        kernel_sizes=[8, 4, 3],
        strides=[4, 2, 1],
        device=device,
    )
    # Run a dummy forward pass to infer the flattened CNN output size.
    cnn_output = cnn(torch.ones(input_shape, device=device))
    mlp = MLP(
        in_features=cnn_output.shape[-1],
        activation_class=torch.nn.ReLU,
        out_features=num_actions,
        num_cells=[512],
        device=device,
    )
    qvalue_module = QValueActor(
        module=torch.nn.Sequential(cnn, mlp),
        spec=Composite(action=action_spec).to(device),
        in_keys=["pixels"],
    )
    return qvalue_module


def make_dqn_model(env_name, frame_skip, device):
    # Build a throwaway env only to read its specs, then discard it.
    proof_environment = make_env(env_name, frame_skip, device=device)
    qvalue_module = make_dqn_modules_pixels(proof_environment, device=device)
    del proof_environment
    return qvalue_module


# ====================================================================
# Evaluation utils
# --------------------------------------------------------------------


def eval_model(actor, test_env, num_episodes=3):
    test_rewards = torch.zeros(num_episodes, dtype=torch.float32)
    for i in range(num_episodes):
        td_test = test_env.rollout(
            policy=actor,
            auto_reset=True,
            auto_cast_to_device=True,
            break_when_any_done=True,
            max_steps=10_000_000,
        )
        test_env.apply(dump_video)
        # Keep the accumulated episode reward at the step(s) where the episode ended.
        reward = td_test["next", "episode_reward"][td_test["next", "done"]]
        test_rewards[i] = reward.sum()
        del td_test
    return test_rewards.mean()


def dump_video(module):
    # Flush any pending VideoRecorder frames to disk.
    if isinstance(module, VideoRecorder):
        module.dump()
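

# --------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): builds the
# Q-value actor and a test environment, then runs a short evaluation rollout.
# The environment name "PongNoFrameskip-v4", frame_skip=4, the CPU device and
# num_episodes=1 are assumptions chosen for this example; in practice these
# values come from the training configuration.
if __name__ == "__main__":
    device = "cpu"  # assumption: use CPU for the sketch
    model = make_dqn_model("PongNoFrameskip-v4", frame_skip=4, device=device)
    test_env = make_env(
        "PongNoFrameskip-v4", frame_skip=4, device=device, is_test=True
    )
    with torch.no_grad():
        mean_reward = eval_model(model, test_env, num_episodes=1)
    print(f"mean episode reward: {mean_reward.item():.2f}")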