From 788de21a65fbf05011c6d681b2664e6fa290b77a Mon Sep 17 00:00:00 2001 From: Sergei Laktionov Date: Sun, 2 Jul 2023 18:28:52 +0100 Subject: [PATCH] Revert "add comment to play_and_record fun" This reverts commit fe796548c76595dd7b3837360321b8b2a53f0fa8. --- week04_approx_rl/homework_pytorch_debug.ipynb | 275 ++++++-------- week04_approx_rl/homework_pytorch_main.ipynb | 351 +++++++----------- week04_approx_rl/seminar_pytorch.ipynb | 156 +++----- 3 files changed, 293 insertions(+), 489 deletions(-) diff --git a/week04_approx_rl/homework_pytorch_debug.ipynb b/week04_approx_rl/homework_pytorch_debug.ipynb index dc562122..cc724d1c 100644 --- a/week04_approx_rl/homework_pytorch_debug.ipynb +++ b/week04_approx_rl/homework_pytorch_debug.ipynb @@ -144,7 +144,7 @@ "env = make_env()\n", "env.reset()\n", "plt.imshow(env.render())\n", - "state_shape, n_actions = env.observation_space.shape, env.action_space.n\n" + "state_shape, n_actions = env.observation_space.shape, env.action_space.n" ] }, { @@ -176,11 +176,10 @@ "source": [ "import torch\n", "import torch.nn as nn\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "# those who have a GPU but feel unfair to use it can uncomment:\n", "# device = torch.device('cpu')\n", - "device\n" + "device" ] }, { @@ -203,6 +202,7 @@ " state_dim = state_shape[0]\n", " \n", "\n", + "\n", " def forward(self, state_t):\n", " \"\"\"\n", " takes agent's observation (tensor), returns qvalues (tensor)\n", @@ -230,15 +230,16 @@ " return qvalues.data.cpu().numpy()\n", "\n", " def sample_actions(self, qvalues):\n", - " \"\"\"pick actions given qvalues. Uses epsilon-greedy exploration strategy.\"\"\"\n", + " \"\"\"pick actions given qvalues. Uses epsilon-greedy exploration strategy. \"\"\"\n", " epsilon = self.epsilon\n", " batch_size, n_actions = qvalues.shape\n", "\n", " random_actions = np.random.choice(n_actions, size=batch_size)\n", " best_actions = qvalues.argmax(axis=-1)\n", "\n", - " should_explore = np.random.choice([0, 1], batch_size, p=[1 - epsilon, epsilon])\n", - " return np.where(should_explore, random_actions, best_actions)\n" + " should_explore = np.random.choice(\n", + " [0, 1], batch_size, p=[1-epsilon, epsilon])\n", + " return np.where(should_explore, random_actions, best_actions)" ] }, { @@ -249,7 +250,7 @@ }, "outputs": [], "source": [ - "agent = DQNAgent(state_shape, n_actions, epsilon=0.5).to(device)\n" + "agent = DQNAgent(state_shape, n_actions, epsilon=0.5).to(device)" ] }, { @@ -270,25 +271,21 @@ "outputs": [], "source": [ "def evaluate(env, agent, n_games=1, greedy=False, t_max=10000, seed=None):\n", - " \"\"\"Plays n_games full games. If greedy, picks actions as argmax(qvalues). Returns mean reward.\"\"\"\n", + " \"\"\" Plays n_games full games. If greedy, picks actions as argmax(qvalues). Returns mean reward. 
\"\"\"\n", " rewards = []\n", " for _ in range(n_games):\n", " s, _ = env.reset(seed)\n", " reward = 0\n", " for _ in range(t_max):\n", " qvalues = agent.get_qvalues([s])\n", - " action = (\n", - " qvalues.argmax(axis=-1)[0]\n", - " if greedy\n", - " else agent.sample_actions(qvalues)[0]\n", - " )\n", + " action = qvalues.argmax(axis=-1)[0] if greedy else agent.sample_actions(qvalues)[0]\n", " s, r, terminated, truncated, _ = env.step(action)\n", " reward += r\n", " if terminated or truncated:\n", " break\n", "\n", " rewards.append(reward)\n", - " return np.mean(rewards)\n" + " return np.mean(rewards)" ] }, { @@ -324,19 +321,14 @@ "outputs": [], "source": [ "from replay_buffer import ReplayBuffer\n", - "\n", "exp_replay = ReplayBuffer(10)\n", "\n", "for _ in range(30):\n", - " exp_replay.add(\n", - " env.reset()[0], env.action_space.sample(), 1.0, env.reset()[0], done=False\n", - " )\n", + " exp_replay.add(env.reset()[0], env.action_space.sample(), 1.0, env.reset()[0], done=False)\n", "\n", "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(5)\n", "\n", - "assert (\n", - " len(exp_replay) == 10\n", - "), \"experience replay size should be 10 because that's what maximum capacity is\"\n" + "assert len(exp_replay) == 10, \"experience replay size should be 10 because that's what maximum capacity is\"" ] }, { @@ -363,7 +355,7 @@ " # Play the game for n_steps as per instructions above\n", " \n", "\n", - " return sum_rewards, s\n" + " return sum_rewards, s" ] }, { @@ -386,49 +378,31 @@ "\n", "# if you're using your own experience replay buffer, some of those tests may need correction.\n", "# just make sure you know what your code does\n", - "assert len(exp_replay) == 1000, (\n", - " \"play_and_record should have added exactly 1000 steps, \"\n", + "assert len(exp_replay) == 1000, \\\n", + " \"play_and_record should have added exactly 1000 steps, \" \\\n", " \"but instead added %i\" % len(exp_replay)\n", - ")\n", "is_dones = list(zip(*exp_replay._storage))[-1]\n", "\n", - "assert 0 < np.mean(is_dones) < 0.1, (\n", - " \"Please make sure you restart the game whenever it is 'done' and \"\n", - " \"record the is_done correctly into the buffer. Got %f is_done rate over \"\n", - " \"%i steps. [If you think it's your tough luck, just re-run the test]\"\n", - " % (np.mean(is_dones), len(exp_replay))\n", - ")\n", + "assert 0 < np.mean(is_dones) < 0.1, \\\n", + " \"Please make sure you restart the game whenever it is 'done' and \" \\\n", + " \"record the is_done correctly into the buffer. Got %f is_done rate over \" \\\n", + " \"%i steps. 
[If you think it's your tough luck, just re-run the test]\" % (\n", + " np.mean(is_dones), len(exp_replay))\n", "\n", "for _ in range(100):\n", - " (\n", - " obs_batch,\n", - " act_batch,\n", - " reward_batch,\n", - " next_obs_batch,\n", - " is_done_batch,\n", - " ) = exp_replay.sample(10)\n", + " obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", " assert obs_batch.shape == next_obs_batch.shape == (10,) + state_shape\n", - " assert act_batch.shape == (\n", - " 10,\n", - " ), \"actions batch should have shape (10,) but is instead %s\" % str(act_batch.shape)\n", - " assert reward_batch.shape == (\n", - " 10,\n", - " ), \"rewards batch should have shape (10,) but is instead %s\" % str(\n", - " reward_batch.shape\n", - " )\n", - " assert is_done_batch.shape == (\n", - " 10,\n", - " ), \"is_done batch should have shape (10,) but is instead %s\" % str(\n", - " is_done_batch.shape\n", - " )\n", - " assert [\n", - " int(i) in (0, 1) for i in is_dones\n", - " ], \"is_done should be strictly True or False\"\n", - " assert [\n", - " 0 <= a < n_actions for a in act_batch\n", - " ], \"actions should be within [0, n_actions)\"\n", - "\n", - "print(\"Well done!\")\n" + " assert act_batch.shape == (10,), \\\n", + " \"actions batch should have shape (10,) but is instead %s\" % str(act_batch.shape)\n", + " assert reward_batch.shape == (10,), \\\n", + " \"rewards batch should have shape (10,) but is instead %s\" % str(reward_batch.shape)\n", + " assert is_done_batch.shape == (10,), \\\n", + " \"is_done batch should have shape (10,) but is instead %s\" % str(is_done_batch.shape)\n", + " assert [int(i) in (0, 1) for i in is_dones], \\\n", + " \"is_done should be strictly True or False\"\n", + " assert [0 <= a < n_actions for a in act_batch], \"actions should be within [0, n_actions)\"\n", + "\n", + "print(\"Well done!\")" ] }, { @@ -462,7 +436,7 @@ "source": [ "target_network = DQNAgent(agent.state_shape, agent.n_actions, epsilon=0.5).to(device)\n", "# This is how you can load weights from agent into target network\n", - "target_network.load_state_dict(agent.state_dict())\n" + "target_network.load_state_dict(agent.state_dict())" ] }, { @@ -508,32 +482,19 @@ }, "outputs": [], "source": [ - "def compute_td_loss(\n", - " states,\n", - " actions,\n", - " rewards,\n", - " next_states,\n", - " is_done,\n", - " agent,\n", - " target_network,\n", - " gamma=0.99,\n", - " check_shapes=False,\n", - " device=device,\n", - "):\n", - " \"\"\"Compute td loss using torch operations only. Use the formulae above.\"\"\"\n", - " states = torch.tensor(\n", - " states, device=device, dtype=torch.float32\n", - " ) # shape: [batch_size, *state_shape]\n", - " actions = torch.tensor(\n", - " actions, device=device, dtype=torch.int64\n", - " ) # shape: [batch_size]\n", - " rewards = torch.tensor(\n", - " rewards, device=device, dtype=torch.float32\n", - " ) # shape: [batch_size]\n", + "def compute_td_loss(states, actions, rewards, next_states, is_done,\n", + " agent, target_network,\n", + " gamma=0.99,\n", + " check_shapes=False,\n", + " device=device):\n", + " \"\"\" Compute td loss using torch operations only. Use the formulae above. 
\"\"\"\n", + " states = torch.tensor(states, device=device, dtype=torch.float32) # shape: [batch_size, *state_shape]\n", + " actions = torch.tensor(actions, device=device, dtype=torch.int64) # shape: [batch_size]\n", + " rewards = torch.tensor(rewards, device=device, dtype=torch.float32) # shape: [batch_size]\n", " # shape: [batch_size, *state_shape]\n", " next_states = torch.tensor(next_states, device=device, dtype=torch.float)\n", " is_done = torch.tensor(\n", - " is_done.astype(\"float32\"),\n", + " is_done.astype('float32'),\n", " device=device,\n", " dtype=torch.float32,\n", " ) # shape: [batch_size]\n", @@ -543,44 +504,34 @@ " predicted_qvalues = agent(states) # shape: [batch_size, n_actions]\n", "\n", " # compute q-values for all actions in next states\n", - " predicted_next_qvalues = target_network(\n", - " next_states\n", - " ) # shape: [batch_size, n_actions]\n", + " predicted_next_qvalues = target_network(next_states) # shape: [batch_size, n_actions]\n", "\n", " # select q-values for chosen actions\n", - " predicted_qvalues_for_actions = predicted_qvalues[\n", - " range(len(actions)), actions\n", - " ] # shape: [batch_size]\n", + " predicted_qvalues_for_actions = predicted_qvalues[range(len(actions)), actions] # shape: [batch_size]\n", "\n", " # compute V*(next_states) using predicted next q-values\n", " next_state_values = \n", "\n", - " assert (\n", - " next_state_values.dim() == 1 and next_state_values.shape[0] == states.shape[0]\n", - " ), \"must predict one value per state\"\n", + " assert next_state_values.dim() == 1 and next_state_values.shape[0] == states.shape[0], \\\n", + " \"must predict one value per state\"\n", "\n", - " # # compute \"target q-values\" for loss - it's what's inside square parentheses in the above formula.\n", - " # # at the last state use the simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n", - " # # you can multiply next state values by is_not_done to achieve this.\n", + " # compute \"target q-values\" for loss - it's what's inside square parentheses in the above formula.\n", + " # at the last state use the simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n", + " # you can multiply next state values by is_not_done to achieve this.\n", " target_qvalues_for_actions = \n", "\n", " # mean squared error loss to minimize\n", - " loss = torch.mean(\n", - " (predicted_qvalues_for_actions - target_qvalues_for_actions.detach()) ** 2\n", - " )\n", + " loss = torch.mean((predicted_qvalues_for_actions - target_qvalues_for_actions.detach()) ** 2)\n", "\n", " if check_shapes:\n", - " assert (\n", - " predicted_next_qvalues.data.dim() == 2\n", - " ), \"make sure you predicted q-values for all actions in next state\"\n", - " assert (\n", - " next_state_values.data.dim() == 1\n", - " ), \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", - " assert (\n", - " target_qvalues_for_actions.data.dim() == 1\n", - " ), \"there's something wrong with target q-values, they must be a vector\"\n", + " assert predicted_next_qvalues.data.dim() == 2, \\\n", + " \"make sure you predicted q-values for all actions in next state\"\n", + " assert next_state_values.data.dim() == 1, \\\n", + " \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", + " assert target_qvalues_for_actions.data.dim() == 1, \\\n", + " \"there's something wrong with target q-values, they must be a vector\"\n", "\n", - " return loss\n" + " return loss" ] }, { @@ -600,32 +551,19 @@ }, "outputs": [], "source": [ - 
"obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(\n", - " 10\n", - ")\n", + "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", "\n", - "loss = compute_td_loss(\n", - " obs_batch,\n", - " act_batch,\n", - " reward_batch,\n", - " next_obs_batch,\n", - " is_done_batch,\n", - " agent,\n", - " target_network,\n", - " gamma=0.99,\n", - " check_shapes=True,\n", - ")\n", + "loss = compute_td_loss(obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch,\n", + " agent, target_network,\n", + " gamma=0.99, check_shapes=True)\n", "loss.backward()\n", "\n", - "assert (\n", - " loss.requires_grad and tuple(loss.data.size()) == ()\n", - "), \"you must return scalar loss - mean over batch\"\n", - "assert np.any(\n", - " next(agent.parameters()).grad.data.cpu().numpy() != 0\n", - "), \"loss must be differentiable w.r.t. network weights\"\n", - "assert np.all(\n", - " next(target_network.parameters()).grad is None\n", - "), \"target network should not have grads\"\n" + "assert loss.requires_grad and tuple(loss.data.size()) == (), \\\n", + " \"you must return scalar loss - mean over batch\"\n", + "assert np.any(next(agent.parameters()).grad.data.cpu().numpy() != 0), \\\n", + " \"loss must be differentiable w.r.t. network weights\"\n", + "assert np.all(next(target_network.parameters()).grad is None), \\\n", + " \"target network should not have grads\"" ] }, { @@ -649,7 +587,7 @@ "source": [ "from tqdm import trange\n", "from IPython.display import clear_output\n", - "import matplotlib.pyplot as plt\n" + "import matplotlib.pyplot as plt" ] }, { @@ -688,7 +626,7 @@ "\n", "agent = DQNAgent(state_dim, n_actions, epsilon=1).to(device)\n", "target_network = DQNAgent(state_dim, n_actions, epsilon=1).to(device)\n", - "target_network.load_state_dict(agent.state_dict())\n" + "target_network.load_state_dict(agent.state_dict())" ] }, { @@ -708,18 +646,17 @@ "exp_replay = ReplayBuffer(REPLAY_BUFFER_SIZE)\n", "for i in range(100):\n", " if not utils.is_enough_ram(min_available_gb=0.1):\n", - " print(\n", - " \"\"\"\n", + " print(\"\"\"\n", " Less than 100 Mb RAM available.\n", " Make sure the buffer size in not too huge.\n", " Also check, maybe other processes consume RAM heavily.\n", " \"\"\"\n", - " )\n", + " )\n", " break\n", " play_and_record(state, agent, env, exp_replay, n_steps=10**2)\n", " if len(exp_replay) == REPLAY_BUFFER_SIZE:\n", " break\n", - "print(len(exp_replay))\n" + "print(len(exp_replay))" ] }, { @@ -746,7 +683,7 @@ "# refresh_target_network_freq = 1000\n", "# eval_freq = 5000\n", "\n", - "# max_grad_norm = 5000\n" + "# max_grad_norm = 5000" ] }, { @@ -771,7 +708,7 @@ "refresh_target_network_freq = 100\n", "eval_freq = 1000\n", "\n", - "max_grad_norm = 5000\n" + "max_grad_norm = 5000" ] }, { @@ -786,7 +723,7 @@ "td_loss_history = []\n", "grad_norm_history = []\n", "initial_state_v_history = []\n", - "step = 0\n" + "step = 0" ] }, { @@ -799,13 +736,12 @@ "source": [ "import time\n", "\n", - "\n", "def wait_for_keyboard_interrupt():\n", " try:\n", " while True:\n", " time.sleep(1)\n", " except KeyboardInterrupt:\n", - " pass\n" + " pass" ] }, { @@ -820,13 +756,11 @@ "with trange(step, total_steps + 1) as progress_bar:\n", " for step in progress_bar:\n", " if not utils.is_enough_ram():\n", - " print(\"less that 100 Mb RAM available, freezing\")\n", - " print(\"make sure everything is ok and use KeyboardInterrupt to continue\")\n", + " print('less that 100 Mb RAM available, freezing')\n", + " print('make sure 
everything is ok and use KeyboardInterrupt to continue')\n", " wait_for_keyboard_interrupt()\n", "\n", - " agent.epsilon = utils.linear_decay(\n", - " init_epsilon, final_epsilon, step, decay_steps\n", - " )\n", + " agent.epsilon = utils.linear_decay(init_epsilon, final_epsilon, step, decay_steps)\n", "\n", " # play\n", " _, state = play_and_record(state, agent, env, exp_replay, timesteps_per_epoch)\n", @@ -850,16 +784,17 @@ " \n", "\n", " if step % eval_freq == 0:\n", - " mean_rw_history.append(\n", - " evaluate(\n", - " make_env(), agent, n_games=3, greedy=True, t_max=1000, seed=step\n", - " )\n", + " mean_rw_history.append(evaluate(\n", + " make_env(), agent, n_games=3, greedy=True, t_max=1000, seed=step)\n", + " )\n", + " initial_state_q_values = agent.get_qvalues(\n", + " [make_env().reset(seed=step)[0]]\n", " )\n", - " initial_state_q_values = agent.get_qvalues([make_env().reset(seed=step)[0]])\n", " initial_state_v_history.append(np.max(initial_state_q_values))\n", "\n", " clear_output(True)\n", - " print(\"buffer size = %i, epsilon = %.5f\" % (len(exp_replay), agent.epsilon))\n", + " print(\"buffer size = %i, epsilon = %.5f\" %\n", + " (len(exp_replay), agent.epsilon))\n", "\n", " plt.figure(figsize=[16, 9])\n", "\n", @@ -884,7 +819,7 @@ " plt.plot(utils.smoothen(grad_norm_history))\n", " plt.grid()\n", "\n", - " plt.show()\n" + " plt.show()" ] }, { @@ -895,10 +830,13 @@ }, "outputs": [], "source": [ - "final_score = evaluate(make_env(), agent, n_games=30, greedy=True, t_max=1000)\n", - "print(\"final score:\", final_score)\n", - "assert final_score > 300, \"not good enough for DQN\"\n", - "print(\"Well done\")\n" + "final_score = evaluate(\n", + " make_env(),\n", + " agent, n_games=30, greedy=True, t_max=1000\n", + ")\n", + "print('final score:', final_score)\n", + "assert final_score > 300, 'not good enough for DQN'\n", + "print('Well done')" ] }, { @@ -920,9 +858,9 @@ "source": [ "eval_env = make_env()\n", "record = utils.play_and_log_episode(eval_env, agent)\n", - "print(\"total reward for life:\", np.sum(record[\"rewards\"]))\n", + "print('total reward for life:', np.sum(record['rewards']))\n", "for key in record:\n", - " print(key)\n" + " print(key)" ] }, { @@ -936,18 +874,17 @@ "fig = plt.figure(figsize=(5, 5))\n", "ax = fig.add_subplot(1, 1, 1)\n", "\n", - "ax.scatter(record[\"v_mc\"], record[\"v_agent\"])\n", - "ax.plot(\n", - " sorted(record[\"v_mc\"]), sorted(record[\"v_mc\"]), \"black\", linestyle=\"--\", label=\"x=y\"\n", - ")\n", + "ax.scatter(record['v_mc'], record['v_agent'])\n", + "ax.plot(sorted(record['v_mc']), sorted(record['v_mc']),\n", + " 'black', linestyle='--', label='x=y')\n", "\n", "ax.grid()\n", "ax.legend()\n", - "ax.set_title(\"State Value Estimates\")\n", - "ax.set_xlabel(\"Monte-Carlo\")\n", - "ax.set_ylabel(\"Agent\")\n", + "ax.set_title('State Value Estimates')\n", + "ax.set_xlabel('Monte-Carlo')\n", + "ax.set_ylabel('Agent')\n", "\n", - "plt.show()\n" + "plt.show()" ] } ], diff --git a/week04_approx_rl/homework_pytorch_main.ipynb b/week04_approx_rl/homework_pytorch_main.ipynb index 64e699b3..3b843b81 100644 --- a/week04_approx_rl/homework_pytorch_main.ipynb +++ b/week04_approx_rl/homework_pytorch_main.ipynb @@ -84,7 +84,7 @@ "import random\n", "import numpy as np\n", "import torch\n", - "import utils\n" + "import utils" ] }, { @@ -121,7 +121,7 @@ }, "outputs": [], "source": [ - "ENV_NAME = \"BreakoutNoFrameskip-v4\"\n" + "ENV_NAME = \"BreakoutNoFrameskip-v4\"" ] }, { @@ -162,7 +162,7 @@ " ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 
1)\n", " ax.imshow(env.render())\n", " env.step(env.action_space.sample())\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -189,7 +189,7 @@ "\n", "# from gymnasium.utils.play import play\n", "\n", - "# play(env=gym.make(ENV_NAME, render_mode=\"rgb_array\"), zoom=4, fps=40)\n" + "# play(env=gym.make(ENV_NAME, render_mode=\"rgb_array\"), zoom=4, fps=40)" ] }, { @@ -237,9 +237,11 @@ " self.img_size = (1, 64, 64)\n", " self.observation_space = Box(0.0, 1.0, self.img_size)\n", "\n", + "\n", " def _to_gray_scale(self, rgb, channel_weights=[0.8, 0.1, 0.1]):\n", " \n", "\n", + "\n", " def observation(self, img):\n", " \"\"\"what happens to each observation\"\"\"\n", "\n", @@ -252,7 +254,7 @@ " # * cast image to grayscale\n", " # * convert image pixels to (0,1) range, float32 type\n", " \n", - " return \n" + " return " ] }, { @@ -264,7 +266,6 @@ "outputs": [], "source": [ "import gymnasium as gym\n", - "\n", "# spawn game instance for tests\n", "env = gym.make(ENV_NAME, render_mode=\"rgb_array\") # create raw env\n", "env = PreprocessAtariObs(env)\n", @@ -274,13 +275,12 @@ "obs, _, _, _, _ = env.step(env.action_space.sample())\n", "\n", "# test observation\n", - "assert (\n", - " obs.ndim == 3\n", - "), \"observation must be [channel, h, w] even if there's just one channel\"\n", + "assert obs.ndim == 3, \"observation must be [channel, h, w] even if there's just one channel\"\n", "assert obs.shape == observation_shape, obs.shape\n", - "assert obs.dtype == \"float32\"\n", + "assert obs.dtype == 'float32'\n", "assert len(np.unique(obs)) > 2, \"your image must not be binary\"\n", - "assert 0 <= np.min(obs) and np.max(obs) <= 1, \"convert image pixels to [0,1] range\"\n", + "assert 0 <= np.min(obs) and np.max(\n", + " obs) <= 1, \"convert image pixels to [0,1] range\"\n", "\n", "assert np.max(obs) >= 0.5, \"It would be easier to see a brighter observation\"\n", "assert np.mean(obs) >= 0.1, \"It would be easier to see a brighter observation\"\n", @@ -294,7 +294,7 @@ "for row in range(n_rows):\n", " for col in range(n_cols):\n", " ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)\n", - " ax.imshow(obs[0, :, :], interpolation=\"none\", cmap=\"gray\")\n", + " ax.imshow(obs[0, :, :], interpolation='none', cmap='gray')\n", " obs, _, _, _, _ = env.step(env.action_space.sample())\n", "plt.show()\n" ] @@ -327,9 +327,8 @@ "source": [ "import atari_wrappers\n", "\n", - "\n", "def PrimaryAtariWrap(env, clip_rewards=True):\n", - " assert \"NoFrameskip\" in env.spec.id\n", + " assert 'NoFrameskip' in env.spec.id\n", "\n", " # This wrapper holds the same action for frames and outputs\n", " # the maximal pixel value of 2 last frames (to handle blinking\n", @@ -352,7 +351,7 @@ "\n", " # This wrapper is yours :)\n", " env = PreprocessAtariObs(env)\n", - " return env\n" + " return env" ] }, { @@ -385,7 +384,7 @@ "# env = atari_wrappers.AntiTorchWrapper(env)\n", "# return env\n", "\n", - "# play(make_play_env(), zoom=4, fps=3)\n" + "# play(make_play_env(), zoom=4, fps=3)" ] }, { @@ -415,19 +414,17 @@ "source": [ "from framebuffer import FrameBuffer\n", "\n", - "\n", "def make_env(clip_rewards=True):\n", " env = gym.make(ENV_NAME, render_mode=\"rgb_array\") # create raw env\n", " env = PrimaryAtariWrap(env, clip_rewards)\n", - " env = FrameBuffer(env, n_frames=4, dim_order=\"pytorch\")\n", + " env = FrameBuffer(env, n_frames=4, dim_order='pytorch')\n", " return env\n", "\n", - "\n", "env = make_env()\n", "env.reset()\n", "n_actions = env.action_space.n\n", "state_shape = env.observation_space.shape\n", - 
"n_actions, state_shape\n" + "n_actions, state_shape" ] }, { @@ -441,15 +438,15 @@ "for _ in range(12):\n", " obs, _, _, _, _ = env.step(env.action_space.sample())\n", "\n", - "plt.figure(figsize=[12, 10])\n", + "plt.figure(figsize=[12,10])\n", "plt.title(\"Game image\")\n", "plt.imshow(env.render())\n", "plt.show()\n", "\n", - "plt.figure(figsize=[15, 15])\n", + "plt.figure(figsize=[15,15])\n", "plt.title(\"Agent observation (4 frames top to bottom)\")\n", - "plt.imshow(utils.img_by_obs(obs, state_shape), cmap=\"gray\")\n", - "plt.show()\n" + "plt.imshow(utils.img_by_obs(obs, state_shape), cmap='gray')\n", + "plt.show()" ] }, { @@ -512,11 +509,10 @@ "source": [ "import torch\n", "import torch.nn as nn\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "# those who have a GPU but feel unfair to use it can uncomment:\n", "# device = torch.device('cpu')\n", - "device\n" + "device" ] }, { @@ -534,7 +530,7 @@ " cur_layer_img_h = conv2d_size_out(cur_layer_img_h, kernel_size, stride)\n", " to understand the shape for dense layer's input\n", " \"\"\"\n", - " return (size - (kernel_size - 1) - 1) // stride + 1\n" + " return (size - (kernel_size - 1) - 1) // stride + 1" ] }, { @@ -556,6 +552,7 @@ " # Define your network body here. Please make sure agent is fully contained here\n", " # nn.Flatten() can be useful\n", " \n", + " \n", "\n", " def forward(self, state_t):\n", " \"\"\"\n", @@ -567,9 +564,9 @@ "\n", " assert qvalues.requires_grad, \"qvalues must be a torch tensor with grad\"\n", " assert (\n", - " len(qvalues.shape) == 2\n", - " and qvalues.shape[0] == state_t.shape[0]\n", - " and qvalues.shape[1] == n_actions\n", + " len(qvalues.shape) == 2 and \n", + " qvalues.shape[0] == state_t.shape[0] and \n", + " qvalues.shape[1] == n_actions\n", " )\n", "\n", " return qvalues\n", @@ -584,15 +581,16 @@ " return qvalues.data.cpu().numpy()\n", "\n", " def sample_actions(self, qvalues):\n", - " \"\"\"pick actions given qvalues. Uses epsilon-greedy exploration strategy.\"\"\"\n", + " \"\"\"pick actions given qvalues. Uses epsilon-greedy exploration strategy. \"\"\"\n", " epsilon = self.epsilon\n", " batch_size, n_actions = qvalues.shape\n", "\n", " random_actions = np.random.choice(n_actions, size=batch_size)\n", " best_actions = qvalues.argmax(axis=-1)\n", "\n", - " should_explore = np.random.choice([0, 1], batch_size, p=[1 - epsilon, epsilon])\n", - " return np.where(should_explore, random_actions, best_actions)\n" + " should_explore = np.random.choice(\n", + " [0, 1], batch_size, p=[1-epsilon, epsilon])\n", + " return np.where(should_explore, random_actions, best_actions)" ] }, { @@ -603,7 +601,7 @@ }, "outputs": [], "source": [ - "agent = DQNAgent(state_shape, n_actions, epsilon=0.5).to(device)\n" + "agent = DQNAgent(state_shape, n_actions, epsilon=0.5).to(device)" ] }, { @@ -624,25 +622,21 @@ "outputs": [], "source": [ "def evaluate(env, agent, n_games=1, greedy=False, t_max=10000, seed=None):\n", - " \"\"\"Plays n_games full games. If greedy, picks actions as argmax(qvalues). Returns mean reward.\"\"\"\n", + " \"\"\" Plays n_games full games. If greedy, picks actions as argmax(qvalues). Returns mean reward. 
\"\"\"\n", " rewards = []\n", " for _ in range(n_games):\n", " s, _ = env.reset(seed=seed)\n", " reward = 0\n", " for _ in range(t_max):\n", " qvalues = agent.get_qvalues([s])\n", - " action = (\n", - " qvalues.argmax(axis=-1)[0]\n", - " if greedy\n", - " else agent.sample_actions(qvalues)[0]\n", - " )\n", + " action = qvalues.argmax(axis=-1)[0] if greedy else agent.sample_actions(qvalues)[0]\n", " s, r, terminated, truncated, _ = env.step(action)\n", " reward += r\n", " if terminated or truncated:\n", " break\n", "\n", " rewards.append(reward)\n", - " return np.mean(rewards)\n" + " return np.mean(rewards)" ] }, { @@ -657,7 +651,7 @@ }, "outputs": [], "source": [ - "evaluate(env, agent, n_games=1)\n" + "evaluate(env, agent, n_games=1)" ] }, { @@ -693,19 +687,14 @@ "outputs": [], "source": [ "from replay_buffer import ReplayBuffer\n", - "\n", "exp_replay = ReplayBuffer(10)\n", "\n", "for _ in range(30):\n", - " exp_replay.add(\n", - " env.reset()[0], env.action_space.sample(), 1.0, env.reset()[0], done=False\n", - " )\n", + " exp_replay.add(env.reset()[0], env.action_space.sample(), 1.0, env.reset()[0], done=False)\n", "\n", "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(5)\n", "\n", - "assert (\n", - " len(exp_replay) == 10\n", - "), \"experience replay size should be 10 because that's what maximum capacity is\"\n" + "assert len(exp_replay) == 10, \"experience replay size should be 10 because that's what maximum capacity is\"" ] }, { @@ -719,7 +708,7 @@ "def play_and_record(initial_state, agent, env, exp_replay, n_steps=1):\n", " \"\"\"\n", " Play the game for exactly n_steps, record every (s,a,r,s', done) to replay buffer.\n", - " Whenever game ends due to termination or truncation, add record with done=terminated and reset the game.\n", + " Whenever game ends, add record with done=True and reset the game.\n", " It is guaranteed that env has terminated=False when passed to this function.\n", "\n", " PLEASE DO NOT RESET ENV UNLESS IT IS \"DONE\"\n", @@ -755,49 +744,31 @@ "\n", "# if you're using your own experience replay buffer, some of those tests may need correction.\n", "# just make sure you know what your code does\n", - "assert len(exp_replay) == 1000, (\n", - " \"play_and_record should have added exactly 1000 steps, \"\n", + "assert len(exp_replay) == 1000, \\\n", + " \"play_and_record should have added exactly 1000 steps, \" \\\n", " \"but instead added %i\" % len(exp_replay)\n", - ")\n", "is_dones = list(zip(*exp_replay._storage))[-1]\n", "\n", - "assert 0 < np.mean(is_dones) < 0.1, (\n", - " \"Please make sure you restart the game whenever it is 'done' and \"\n", - " \"record the is_done correctly into the buffer. Got %f is_done rate over \"\n", - " \"%i steps. [If you think it's your tough luck, just re-run the test]\"\n", - " % (np.mean(is_dones), len(exp_replay))\n", - ")\n", + "assert 0 < np.mean(is_dones) < 0.1, \\\n", + " \"Please make sure you restart the game whenever it is 'done' and \" \\\n", + " \"record the is_done correctly into the buffer. Got %f is_done rate over \" \\\n", + " \"%i steps. 
[If you think it's your tough luck, just re-run the test]\" % (\n", + " np.mean(is_dones), len(exp_replay))\n", "\n", "for _ in range(100):\n", - " (\n", - " obs_batch,\n", - " act_batch,\n", - " reward_batch,\n", - " next_obs_batch,\n", - " is_done_batch,\n", - " ) = exp_replay.sample(10)\n", + " obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(10)\n", " assert obs_batch.shape == next_obs_batch.shape == (10,) + state_shape\n", - " assert act_batch.shape == (\n", - " 10,\n", - " ), \"actions batch should have shape (10,) but is instead %s\" % str(act_batch.shape)\n", - " assert reward_batch.shape == (\n", - " 10,\n", - " ), \"rewards batch should have shape (10,) but is instead %s\" % str(\n", - " reward_batch.shape\n", - " )\n", - " assert is_done_batch.shape == (\n", - " 10,\n", - " ), \"is_done batch should have shape (10,) but is instead %s\" % str(\n", - " is_done_batch.shape\n", - " )\n", - " assert [\n", - " int(i) in (0, 1) for i in is_dones\n", - " ], \"is_done should be strictly True or False\"\n", - " assert [\n", - " 0 <= a < n_actions for a in act_batch\n", - " ], \"actions should be within [0, n_actions)\"\n", - "\n", - "print(\"Well done!\")\n" + " assert act_batch.shape == (10,), \\\n", + " \"actions batch should have shape (10,) but is instead %s\" % str(act_batch.shape)\n", + " assert reward_batch.shape == (10,), \\\n", + " \"rewards batch should have shape (10,) but is instead %s\" % str(reward_batch.shape)\n", + " assert is_done_batch.shape == (10,), \\\n", + " \"is_done batch should have shape (10,) but is instead %s\" % str(is_done_batch.shape)\n", + " assert [int(i) in (0, 1) for i in is_dones], \\\n", + " \"is_done should be strictly True or False\"\n", + " assert [0 <= a < n_actions for a in act_batch], \"actions should be within [0, n_actions)\"\n", + "\n", + "print(\"Well done!\")" ] }, { @@ -831,7 +802,7 @@ "source": [ "target_network = DQNAgent(agent.state_shape, agent.n_actions, epsilon=0.5).to(device)\n", "# This is how you can load weights from agent into target network\n", - "target_network.load_state_dict(agent.state_dict())\n" + "target_network.load_state_dict(agent.state_dict())" ] }, { @@ -886,34 +857,19 @@ }, "outputs": [], "source": [ - "def compute_td_loss(\n", - " states,\n", - " actions,\n", - " rewards,\n", - " next_states,\n", - " is_done,\n", - " agent,\n", - " target_network,\n", - " gamma=0.99,\n", - " check_shapes=False,\n", - " device=device,\n", - "):\n", - " \"\"\"Compute td loss using torch operations only. Use the formulae above.\"\"\"\n", - " states = torch.tensor(\n", - " states, device=device, dtype=torch.float32\n", - " ) # shape: [batch_size, *state_shape]\n", - " actions = torch.tensor(\n", - " actions, device=device, dtype=torch.int64\n", - " ) # shape: [batch_size]\n", - " rewards = torch.tensor(\n", - " rewards, device=device, dtype=torch.float32\n", - " ) # shape: [batch_size]\n", + "def compute_td_loss(states, actions, rewards, next_states, is_done,\n", + " agent, target_network,\n", + " gamma=0.99,\n", + " check_shapes=False,\n", + " device=device):\n", + " \"\"\" Compute td loss using torch operations only. Use the formulae above. 
\"\"\"\n", + " states = torch.tensor(states, device=device, dtype=torch.float32) # shape: [batch_size, *state_shape]\n", + " actions = torch.tensor(actions, device=device, dtype=torch.int64) # shape: [batch_size]\n", + " rewards = torch.tensor(rewards, device=device, dtype=torch.float32) # shape: [batch_size]\n", " # shape: [batch_size, *state_shape]\n", - " next_states = torch.tensor(\n", - " next_states, device=device, dtype=torch.float\n", - " ) # shape: [batch_size, *state_shape]\n", + " next_states = torch.tensor(next_states, device=device, dtype=torch.float)\n", " is_done = torch.tensor(\n", - " is_done.astype(\"float32\"),\n", + " is_done.astype('float32'),\n", " device=device,\n", " dtype=torch.float32,\n", " ) # shape: [batch_size]\n", @@ -923,21 +879,16 @@ " predicted_qvalues = agent(states) # shape: [batch_size, n_actions]\n", "\n", " # compute q-values for all actions in next states\n", - " predicted_next_qvalues = target_network(\n", - " next_states\n", - " ) # shape: [batch_size, n_actions]\n", - "\n", + " predicted_next_qvalues = target_network(next_states) # shape: [batch_size, n_actions]\n", + " \n", " # select q-values for chosen actions\n", - " predicted_qvalues_for_actions = predicted_qvalues[\n", - " range(len(actions)), actions\n", - " ] # shape: [batch_size]\n", + " predicted_qvalues_for_actions = predicted_qvalues[range(len(actions)), actions] # shape: [batch_size]\n", "\n", " # compute V*(next_states) using predicted next q-values\n", " next_state_values = \n", "\n", - " assert (\n", - " next_state_values.dim() == 1 and next_state_values.shape[0] == states.shape[0]\n", - " ), \"must predict one value per state\"\n", + " assert next_state_values.dim() == 1 and next_state_values.shape[0] == states.shape[0], \\\n", + " \"must predict one value per state\"\n", "\n", " # compute \"target q-values\" for loss - it's what's inside square parentheses in the above formula.\n", " # at the last state use the simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n", @@ -945,22 +896,17 @@ " target_qvalues_for_actions = \n", "\n", " # mean squared error loss to minimize\n", - " loss = torch.mean(\n", - " (predicted_qvalues_for_actions - target_qvalues_for_actions.detach()) ** 2\n", - " )\n", + " loss = torch.mean((predicted_qvalues_for_actions - target_qvalues_for_actions.detach()) ** 2)\n", "\n", " if check_shapes:\n", - " assert (\n", - " predicted_next_qvalues.data.dim() == 2\n", - " ), \"make sure you predicted q-values for all actions in next state\"\n", - " assert (\n", - " next_state_values.data.dim() == 1\n", - " ), \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", - " assert (\n", - " target_qvalues_for_actions.data.dim() == 1\n", - " ), \"there's something wrong with target q-values, they must be a vector\"\n", + " assert predicted_next_qvalues.data.dim() == 2, \\\n", + " \"make sure you predicted q-values for all actions in next state\"\n", + " assert next_state_values.data.dim() == 1, \\\n", + " \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", + " assert target_qvalues_for_actions.data.dim() == 1, \\\n", + " \"there's something wrong with target q-values, they must be a vector\"\n", "\n", - " return loss\n" + " return loss" ] }, { @@ -980,32 +926,19 @@ }, "outputs": [], "source": [ - "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(\n", - " 10\n", - ")\n", + "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = 
exp_replay.sample(10)\n", "\n", - "loss = compute_td_loss(\n", - " obs_batch,\n", - " act_batch,\n", - " reward_batch,\n", - " next_obs_batch,\n", - " is_done_batch,\n", - " agent,\n", - " target_network,\n", - " gamma=0.99,\n", - " check_shapes=True,\n", - ")\n", + "loss = compute_td_loss(obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch,\n", + " agent, target_network,\n", + " gamma=0.99, check_shapes=True)\n", "loss.backward()\n", "\n", - "assert (\n", - " loss.requires_grad and tuple(loss.data.size()) == ()\n", - "), \"you must return scalar loss - mean over batch\"\n", - "assert np.any(\n", - " next(agent.parameters()).grad.data.cpu().numpy() != 0\n", - "), \"loss must be differentiable w.r.t. network weights\"\n", - "assert np.all(\n", - " next(target_network.parameters()).grad is None\n", - "), \"target network should not have grads\"\n" + "assert loss.requires_grad and tuple(loss.data.size()) == (), \\\n", + " \"you must return scalar loss - mean over batch\"\n", + "assert np.any(next(agent.parameters()).grad.data.cpu().numpy() != 0), \\\n", + " \"loss must be differentiable w.r.t. network weights\"\n", + "assert np.all(next(target_network.parameters()).grad is None), \\\n", + " \"target network should not have grads\"" ] }, { @@ -1032,7 +965,7 @@ "source": [ "from tqdm import trange\n", "from IPython.display import clear_output\n", - "import matplotlib.pyplot as plt\n" + "import matplotlib.pyplot as plt" ] }, { @@ -1050,7 +983,7 @@ "seed = \n", "random.seed(seed)\n", "np.random.seed(seed)\n", - "torch.manual_seed(seed)\n" + "torch.manual_seed(seed)" ] }, { @@ -1072,7 +1005,7 @@ "\n", "agent = DQNAgent(state_shape, n_actions, epsilon=1).to(device)\n", "target_network = DQNAgent(state_shape, n_actions).to(device)\n", - "target_network.load_state_dict(agent.state_dict())\n" + "target_network.load_state_dict(agent.state_dict())" ] }, { @@ -1104,18 +1037,17 @@ "exp_replay = ReplayBuffer(REPLAY_BUFFER_SIZE)\n", "for i in trange(REPLAY_BUFFER_SIZE // N_STEPS):\n", " if not utils.is_enough_ram(min_available_gb=0.1):\n", - " print(\n", - " \"\"\"\n", + " print(\"\"\"\n", " Less than 100 Mb RAM available.\n", " Make sure the buffer size in not too huge.\n", " Also check, maybe other processes consume RAM heavily.\n", " \"\"\"\n", - " )\n", + " )\n", " break\n", " play_and_record(state, agent, env, exp_replay, n_steps=N_STEPS)\n", " if len(exp_replay) == REPLAY_BUFFER_SIZE:\n", " break\n", - "print(len(exp_replay))\n" + "print(len(exp_replay))" ] }, { @@ -1142,7 +1074,7 @@ "\n", "max_grad_norm = 50\n", "\n", - "n_lives = 5\n" + "n_lives = 5" ] }, { @@ -1157,7 +1089,7 @@ "td_loss_history = []\n", "grad_norm_history = []\n", "initial_state_v_history = []\n", - "step = 0\n" + "step = 0" ] }, { @@ -1170,13 +1102,12 @@ "source": [ "import time\n", "\n", - "\n", "def wait_for_keyboard_interrupt():\n", " try:\n", " while True:\n", " time.sleep(1)\n", " except KeyboardInterrupt:\n", - " pass\n" + " pass" ] }, { @@ -1191,13 +1122,11 @@ "with trange(step, total_steps + 1) as progress_bar:\n", " for step in progress_bar:\n", " if not utils.is_enough_ram():\n", - " print(\"less that 100 Mb RAM available, freezing\")\n", - " print(\"make sure everything is ok and use KeyboardInterrupt to continue\")\n", + " print('less that 100 Mb RAM available, freezing')\n", + " print('make sure everything is ok and use KeyboardInterrupt to continue')\n", " wait_for_keyboard_interrupt()\n", "\n", - " agent.epsilon = utils.linear_decay(\n", - " init_epsilon, final_epsilon, step, decay_steps\n", - " )\n", 
+ " agent.epsilon = utils.linear_decay(init_epsilon, final_epsilon, step, decay_steps)\n", "\n", " # play\n", " _, state = play_and_record(state, agent, env, exp_replay, timesteps_per_epoch)\n", @@ -1221,20 +1150,17 @@ " \n", "\n", " if step % eval_freq == 0:\n", - " mean_rw_history.append(\n", - " evaluate(\n", - " make_env(clip_rewards=True),\n", - " agent,\n", - " n_games=3 * n_lives,\n", - " greedy=True,\n", - " seed=step,\n", - " )\n", + " mean_rw_history.append(evaluate(\n", + " make_env(clip_rewards=True), agent, n_games=3 * n_lives, greedy=True, seed=step)\n", + " )\n", + " initial_state_q_values = agent.get_qvalues(\n", + " [make_env().reset(seed=step)[0]]\n", " )\n", - " initial_state_q_values = agent.get_qvalues([make_env().reset(seed=step)[0]])\n", " initial_state_v_history.append(np.max(initial_state_q_values))\n", "\n", " clear_output(True)\n", - " print(\"buffer size = %i, epsilon = %.5f\" % (len(exp_replay), agent.epsilon))\n", + " print(\"buffer size = %i, epsilon = %.5f\" %\n", + " (len(exp_replay), agent.epsilon))\n", "\n", " plt.figure(figsize=[16, 9])\n", "\n", @@ -1259,7 +1185,7 @@ " plt.plot(utils.smoothen(grad_norm_history))\n", " plt.grid()\n", "\n", - " plt.show()\n" + " plt.show()" ] }, { @@ -1291,16 +1217,12 @@ "outputs": [], "source": [ "final_score = evaluate(\n", - " make_env(clip_rewards=False),\n", - " agent,\n", - " n_games=30,\n", - " greedy=True,\n", - " t_max=10 * 1000,\n", - " seed=9,\n", + " make_env(clip_rewards=False),\n", + " agent, n_games=30, greedy=True, t_max=10 * 1000, seed=9\n", ")\n", - "print(\"final score:\", final_score)\n", - "assert final_score >= 3, \"not as cool as DQN can\"\n", - "print(\"Cool!\")\n" + "print('final score:', final_score)\n", + "assert final_score >= 3, 'not as cool as DQN can'\n", + "print('Cool!')" ] }, { @@ -1395,26 +1317,22 @@ "from base64 import b64encode\n", "from IPython.display import HTML\n", "\n", - "video_paths = sorted([s for s in Path(\"videos\").iterdir() if s.suffix == \".mp4\"])\n", + "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", "video_path = video_paths[-1] # You can also try other indices\n", "\n", - "if \"google.colab\" in sys.modules:\n", + "if 'google.colab' in sys.modules:\n", " # https://stackoverflow.com/a/57378660/1214547\n", - " with video_path.open(\"rb\") as fp:\n", + " with video_path.open('rb') as fp:\n", " mp4 = fp.read()\n", - " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", + " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", "else:\n", " data_url = str(video_path)\n", "\n", - "HTML(\n", - " \"\"\"\n", + "HTML(\"\"\"\n", "\n", - "\"\"\".format(\n", - " data_url\n", - " )\n", - ")\n" + "\"\"\".format(data_url))" ] }, { @@ -1438,9 +1356,9 @@ "source": [ "eval_env = make_env(clip_rewards=False)\n", "record = utils.play_and_log_episode(eval_env, agent)\n", - "print(\"total reward for life:\", np.sum(record[\"rewards\"]))\n", + "print('total reward for life:', np.sum(record['rewards']))\n", "for key in record:\n", - " print(key)\n" + " print(key)" ] }, { @@ -1454,18 +1372,17 @@ "fig = plt.figure(figsize=(5, 5))\n", "ax = fig.add_subplot(1, 1, 1)\n", "\n", - "ax.scatter(record[\"v_mc\"], record[\"v_agent\"])\n", - "ax.plot(\n", - " sorted(record[\"v_mc\"]), sorted(record[\"v_mc\"]), \"black\", linestyle=\"--\", label=\"x=y\"\n", - ")\n", + "ax.scatter(record['v_mc'], record['v_agent'])\n", + "ax.plot(sorted(record['v_mc']), sorted(record['v_mc']),\n", + " 'black', linestyle='--', label='x=y')\n", "\n", 
"ax.grid()\n", "ax.legend()\n", - "ax.set_title(\"State Value Estimates\")\n", - "ax.set_xlabel(\"Monte-Carlo\")\n", - "ax.set_ylabel(\"Agent\")\n", + "ax.set_title('State Value Estimates')\n", + "ax.set_xlabel('Monte-Carlo')\n", + "ax.set_ylabel('Agent')\n", "\n", - "plt.show()\n" + "plt.show()" ] }, { diff --git a/week04_approx_rl/seminar_pytorch.ipynb b/week04_approx_rl/seminar_pytorch.ipynb index b32f4041..2315a06e 100644 --- a/week04_approx_rl/seminar_pytorch.ipynb +++ b/week04_approx_rl/seminar_pytorch.ipynb @@ -71,7 +71,7 @@ "state_dim = env.observation_space.shape\n", "\n", "plt.imshow(env.render())\n", - "env.close()\n" + "env.close()" ] }, { @@ -106,7 +106,7 @@ "source": [ "import torch\n", "import torch.nn as nn\n", - "import torch.nn.functional as F\n" + "import torch.nn.functional as F" ] }, { @@ -144,7 +144,7 @@ "\n", " \n", "\n", - " return int( )\n" + " return int( )" ] }, { @@ -156,44 +156,24 @@ "outputs": [], "source": [ "s, _ = env.reset()\n", - "assert tuple(network(torch.tensor([s] * 3, dtype=torch.float32)).size()) == (\n", - " 3,\n", - " n_actions,\n", - "), \"please make sure your model maps state s -> [Q(s,a0), ..., Q(s, a_last)]\"\n", - "assert isinstance(\n", - " list(network.modules())[-1], nn.Linear\n", - "), \"please make sure you predict q-values without nonlinearity (ignore if you know what you're doing)\"\n", - "assert isinstance(\n", - " get_action(network, s), int\n", - "), \"get_action(s) must return int, not %s. try int(action)\" % (\n", - " type(get_action(network, s))\n", - ")\n", - " list(network.modules())[-1], nn.Linear\n", - "), \"please make sure you predict q-values without nonlinearity (ignore if you know what you're doing)\"\n", - "for eps in [0.0, 0.1, 0.5, 1.0]:\n", - " get_action(network, s), int\n", - " [get_action(network, s, epsilon=eps) for i in range(10000)], minlength=n_actions\n", - " )\n", - " type(get_action(network, s))\n", - " assert (\n", - " abs(state_frequencies[best_action] - 10000 * (1 - eps + eps / n_actions)) < 200\n", - " )\n", + "assert tuple(network(torch.tensor([s]*3, dtype=torch.float32)).size()) == (\n", + " 3, n_actions), \"please make sure your model maps state s -> [Q(s,a0), ..., Q(s, a_last)]\"\n", + "assert isinstance(list(network.modules(\n", + "))[-1], nn.Linear), \"please make sure you predict q-values without nonlinearity (ignore if you know what you're doing)\"\n", + "assert isinstance(get_action(network, s), int), \"get_action(s) must return int, not %s. 
try int(action)\" % (type(get_action(network, s)))\n", + "\n", "# test epsilon-greedy exploration\n", - "for eps in [0.0, 0.1, 0.5, 1.0]:\n", - " assert (\n", - " abs(state_frequencies[other_action] - 10000 * (eps / n_actions)) < 200\n", - " )\n", - " print(\"e=%.1f tests passed\" % eps)\n", + "for eps in [0., 0.1, 0.5, 1.0]:\n", + " state_frequencies = np.bincount(\n", + " [get_action(network, s, epsilon=eps) for i in range(10000)], minlength=n_actions)\n", " best_action = state_frequencies.argmax()\n", - " assert (\n", - " abs(state_frequencies[best_action] - 10000 * (1 - eps + eps / n_actions)) < 200\n", - " )\n", + " assert abs(state_frequencies[best_action] -\n", + " 10000 * (1 - eps + eps / n_actions)) < 200\n", " for other_action in range(n_actions):\n", " if other_action != best_action:\n", - " assert (\n", - " abs(state_frequencies[other_action] - 10000 * (eps / n_actions)) < 200\n", - " )\n", - " print(\"e=%.1f tests passed\" % eps)\n" + " assert abs(state_frequencies[other_action] -\n", + " 10000 * (eps / n_actions)) < 200\n", + " print('e=%.1f tests passed' % eps)" ] }, { @@ -225,25 +205,22 @@ }, "outputs": [], "source": [ - "def compute_td_loss(\n", - " states, actions, rewards, next_states, is_done, gamma=0.99, check_shapes=False\n", - "):\n", - " \"\"\"Compute td loss using torch operations only. Use the formula above.\"\"\"\n", + "def compute_td_loss(states, actions, rewards, next_states, is_done, gamma=0.99, check_shapes=False):\n", + " \"\"\" Compute td loss using torch operations only. Use the formula above. \"\"\"\n", " states = torch.tensor(\n", - " states, dtype=torch.float32\n", - " ) # shape: [batch_size, state_size]\n", - " actions = torch.tensor(actions, dtype=torch.long) # shape: [batch_size]\n", - " rewards = torch.tensor(rewards, dtype=torch.float32) # shape: [batch_size]\n", + " states, dtype=torch.float32) # shape: [batch_size, state_size]\n", + " actions = torch.tensor(actions, dtype=torch.long) # shape: [batch_size]\n", + " rewards = torch.tensor(rewards, dtype=torch.float32) # shape: [batch_size]\n", " # shape: [batch_size, state_size]\n", " next_states = torch.tensor(next_states, dtype=torch.float32)\n", - " is_done = torch.tensor(is_done, dtype=torch.uint8) # shape: [batch_size]\n", + " is_done = torch.tensor(is_done, dtype=torch.uint8) # shape: [batch_size]\n", "\n", " # get q-values for all actions in current states\n", - " predicted_qvalues = network(states) # shape: [batch_size, n_actions]\n", + " predicted_qvalues = network(states) # shape: [batch_size, n_actions]\n", "\n", " # select q-values for chosen actions\n", - " predicted_qvalues_for_actions = predicted_qvalues[ # shape: [batch_size]\n", - " range(states.shape[0]), actions\n", + " predicted_qvalues_for_actions = predicted_qvalues[ # shape: [batch_size]\n", + " range(states.shape[0]), actions\n", " ]\n", "\n", " # compute q-values for all actions in next states\n", @@ -258,26 +235,21 @@ "\n", " # at the last state we shall use simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n", " target_qvalues_for_actions = torch.where(\n", - " is_done, rewards, target_qvalues_for_actions\n", - " )\n", + " is_done, rewards, target_qvalues_for_actions)\n", "\n", " # mean squared error loss to minimize\n", - " loss = torch.mean(\n", - " (predicted_qvalues_for_actions - target_qvalues_for_actions.detach()) ** 2\n", - " )\n", + " loss = torch.mean((predicted_qvalues_for_actions -\n", + " target_qvalues_for_actions.detach()) ** 2)\n", "\n", " if check_shapes:\n", - " assert (\n", - " 
predicted_next_qvalues.data.dim() == 2\n", - " ), \"make sure you predicted q-values for all actions in next state\"\n", - " assert (\n", - " next_state_values.data.dim() == 1\n", - " ), \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", - " assert (\n", - " target_qvalues_for_actions.data.dim() == 1\n", - " ), \"there's something wrong with target q-values, they must be a vector\"\n", - "\n", - " return loss\n" + " assert predicted_next_qvalues.data.dim(\n", + " ) == 2, \"make sure you predicted q-values for all actions in next state\"\n", + " assert next_state_values.data.dim(\n", + " ) == 1, \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n", + " assert target_qvalues_for_actions.data.dim(\n", + " ) == 1, \"there's something wrong with target q-values, they must be a vector\"\n", + "\n", + " return loss" ] }, { @@ -296,10 +268,8 @@ "loss.backward()\n", "\n", "assert len(loss.size()) == 0, \"you must return scalar loss - mean over batch\"\n", - "assert np.any(\n", - " next(network.parameters()).grad.detach().numpy() != 0\n", - "), \"loss must be differentiable w.r.t. network weights\"\n", - "), \"loss must be differentiable w.r.t. network weights\"\n" + "assert np.any(next(network.parameters()).grad.detach().numpy() !=\n", + " 0), \"loss must be differentiable w.r.t. network weights\"" ] }, { @@ -319,7 +289,7 @@ }, "outputs": [], "source": [ - "opt = torch.optim.Adam(network.parameters(), lr=1e-4)\n" + "opt = torch.optim.Adam(network.parameters(), lr=1e-4)" ] }, { @@ -349,7 +319,7 @@ " if terminated or truncated:\n", " break\n", "\n", - " return total_reward\n" + " return total_reward" ] }, { @@ -360,7 +330,7 @@ }, "outputs": [], "source": [ - "epsilon = 0.5\n" + "epsilon = 0.5" ] }, { @@ -372,27 +342,15 @@ "outputs": [], "source": [ "for i in range(1000):\n", - " session_rewards = [\n", - " generate_session(env, epsilon=epsilon, train=True) for _ in range(100)\n", - " ]\n", - " print(\n", - " \"epoch #{}\\tmean reward = {:.3f}\\tepsilon = {:.3f}\".format(\n", - " i, np.mean(session_rewards), epsilon\n", - " )\n", - " )\n", - " ]\n", - " print(\n", - " \"epoch #{}\\tmean reward = {:.3f}\\tepsilon = {:.3f}\".format(\n", - " i, np.mean(session_rewards), epsilon\n", - " )\n", - " )\n", - " break\n", + " session_rewards = [generate_session(env, epsilon=epsilon, train=True) for _ in range(100)]\n", + " print(\"epoch #{}\\tmean reward = {:.3f}\\tepsilon = {:.3f}\".format(i, np.mean(session_rewards), epsilon))\n", + "\n", " epsilon *= 0.99\n", " assert epsilon >= 1e-4, \"Make sure epsilon is always nonzero during training\"\n", "\n", " if np.mean(session_rewards) > 300:\n", " print(\"You Win!\")\n", - " break\n" + " break" ] }, { @@ -457,30 +415,22 @@ "from base64 import b64encode\n", "from IPython.display import HTML\n", "\n", - "video_paths = sorted([s for s in Path(\"videos\").iterdir() if s.suffix == \".mp4\"])\n", + "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n", "video_path = video_paths[-1] # You can also try other indices\n", "\n", - "if \"google.colab\" in sys.modules:\n", + "if 'google.colab' in sys.modules:\n", " # https://stackoverflow.com/a/57378660/1214547\n", - " with video_path.open(\"rb\") as fp:\n", + " with video_path.open('rb') as fp:\n", " mp4 = fp.read()\n", - " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", + " data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n", "else:\n", " data_url = str(video_path)\n", "\n", - 
"HTML(\n", - " \"\"\"\n", - " \"\"\"\n", + "HTML(\"\"\"\n", "\n", + "\"\"\".format(data_url))" ] } ],