{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Deep Q-Network Training\n", "\n", "In this worksheet we are going to go over training a Reinforcement Learning (RL) agent trained on ATARI 2600 games. We will use the famous game of pong as a our training test bed.\n", "\n", "One can download the code required to train the DQN agent at https://sites.google.com/a/deepmind.com/dqn/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "We first start by setting up a function to update our lua path so that DQN related classes can be loaded." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "require 'cutorch'\n", "require 'cunn'\n", "require 'alewrap'\n", "\n", "torch.setdefaulttensortype('torch.FloatTensor')\n", "torch.setnumthreads(4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Game Environment\n", "We create an RL environment that wraps an ATARI game so that we can train an agent.\n", "The agent will have no knowledge of what game it is playing." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "\n", "Playing:\tpong\t\n" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-- Create the Game Object\n", "game_options = {\n", " -- name of the game to play (you need the ROM file for this game)\n", " env='pong',\n", " -- directory where the ROMS are stored\n", " game_path='/home/ubuntu/torch/install/share/lua/5.1/dqn/roms/',\n", " -- we want get RGB frames\n", " env_params = {useRGB = true},\n", " -- we will repeat each action 4 times\n", " actrep = 4,\n", " -- for every new episode, play null actions a random number of time [0,30]\n", " random_starts = 30,\n", " -- use gpu\n", " gpu = 1,\n", " -- have some info logs\n", " verbose = 2\n", "}\n", "game_env = alewrap.GameEnvironment(game_options)\n", "game_actions = game_env:getActions()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "{\n", " 1 : 0\n", " 2 : 3\n", " 3 : 4\n", "}\n", "\n" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-- These are the valid actions for this game. The total possible action set is in (0-17)\n", "print(game_actions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Agent\n", "We create an agent that can **learn to play** any game" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "Creating Agent Network from dqn.convnet_atari3\t\n" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "nn.Sequential {\n", " [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> output]\n", " (1): nn.Reshape(4x84x84)\n", " (2): nn.SpatialConvolution(in: 4, out: 32, kW: 8, kH: 8, dW: 4, dH: 4, padding: 1)\n", " (3): nn.Rectifier\n", " (4): nn.SpatialConvolution(in: 32, out: 64, kW: 4, kH: 4, dW: 2, dH: 2)\n", " (5): nn.Rectifier\n", " (6): nn.SpatialConvolution(in: 64, out: 64, kW: 3, kH: 3)\n", " (7): nn.Rectifier\n", " (8): nn.Reshape(3136)\n", " (9): nn.Linear(3136 -> 512)\n", " (10): nn.Rectifier\n", " (11): nn.Linear(512 -> 3)\n", "}\n", "{\n", " gradInput : CudaTensor - empty\n", " modules : \n", " {\n", " 1 : \n", " nn.Reshape(4x84x84)\n", " {\n", " nelement : 28224\n", " _input : CudaTensor - empty\n", " output : CudaTensor - size: 1x4x84x84\n", " gradInput : CudaTensor - empty\n", " size : LongStorage - size: 3\n", " _gradOutput : CudaTensor - empty\n", " batchsize : LongStorage - size: 4\n", " }\n", " 2 : \n", " nn.SpatialConvolution(in: 4, out: 32, kW: 8, kH: 8, dW: 4, dH: 4, padding: 1)\n", " {\n", " dH : 4\n", " dW : 4\n", " nOutputPlane : 32\n", " output : CudaTensor - size: 1x32x20x20\n", " gradInput : CudaTensor - empty\n", " " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "finput : CudaTensor - empty\n", " fgradInput : CudaTensor - empty\n", " gradBias : CudaTensor - size: 32\n", " weight : CudaTensor - size: 32x4x8x8\n", " bias : CudaTensor - size: 32\n", " gradWeight : CudaTensor - size: 32x4x8x8\n", " padding : 1\n", " nInputPlane : 4\n", " kW : 8\n", " kH : 8\n", " }\n", " 3 : \n", " nn.Rectifier\n", " {\n", " gradInput : CudaTensor - empty\n", " output : CudaTensor - size: 1x32x20x20\n", " }\n", " 4 : \n", " nn.SpatialConvolution(in: 32, out: 64, kW: 4, kH: 4, dW: 2, dH: 2)\n", " {\n", " dH : 2\n", " dW : 2\n", " nOutputPlane : 64\n", " output : CudaTensor - size: 1x64x9x9\n", " gradInput : CudaTensor - empty\n", " finput : CudaTensor - empty\n", " fgradInput : CudaTensor - empty\n", " gradBias : CudaTensor - size: 64\n", " weight : CudaTensor - size: 64x32x4x4\n", " bias : CudaTensor - size: 64\n", " gradWeight : CudaTensor - size: 64x32x4x4\n", " padding : 0\n", " nInputPlane : 32\n", " kW : 4\n", " kH : 4\n", " }\n", " 5 : \n", " nn.Rectifier\n", " {\n", " gradInput : CudaTensor - empty\n", " output : CudaTensor - size: 1x64x9x9\n", " }\n", " 6 : \n", " nn.SpatialConvolution(in: 64, out: 64, kW: 3, kH: 3)\n", " {\n", " dH : 1\n", " dW : 1\n", " nOutputPlane : 64\n", " output : CudaTensor - size: 1x64x7x7\n" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ " gradInput : CudaTensor - empty\n", " finput : CudaTensor - empty\n", " fgradInput : CudaTensor - empty\n", " gradBias : CudaTensor - size: 64\n", " weight : CudaTensor - size: 64x64x3x3\n", " bias : CudaTensor - size: 64\n", " gradWeight : CudaTensor - size: 64x64x3x3\n", " padding : 0\n", " nInputPlane : 64\n", " kW : 3\n", " kH : 3\n", " }\n", " 7 : \n", " nn.Rectifier\n", " {\n", " gradInput : CudaTensor - empty\n", " output : CudaTensor - size: 1x64x7x7\n", " }\n", " 8 : \n", " nn.Reshape(3136)\n", " {\n", " nelement : 3136\n", " _input : CudaTensor - empty\n", " output : CudaTensor - empty\n", " gradInput : CudaTensor - empty\n", " size : LongStorage - size: 1\n", " _gradOutput : CudaTensor - empty\n", " batchsize : LongStorage - size: 2\n", " }\n", " 9 : \n", " nn.Linear(3136 -> 512)\n", " {\n", " gradBias : CudaTensor - size: 512\n", " weight : CudaTensor - size: 512x3136\n", " bias : CudaTensor - size: 512\n", " gradInput : CudaTensor - empty\n", " gradWeight : CudaTensor - size: 512x3136\n", " output : CudaTensor - empty\n", " }\n", " 10 : \n", " nn.Rectifier\n", " {\n", " gradInput : CudaTensor - empty\n", " output : CudaTensor - empty\n", " }\n", " 11 : \n", " nn.Linear(512 -> 3)\n", " {\n", " gradBias : CudaTensor - size: 3\n", " weight : CudaTensor - size: 3x512\n", " bias : CudaTensor - size: 3\n", " gradInput : CudaTensor - empty\n" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ " gradWeight : CudaTensor - size: 3x512\n", " output : CudaTensor - empty\n", " }\n", " }\n", " output : CudaTensor - empty\n", "}\n", "Convolutional layers flattened output size:\t3136\t\n" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "require 'dqn'\n", "\n", "agent_params = {\n", " -- The agent only knows about actions it can take in the environment\n", " actions = game_actions,\n", " -- we will use gpu\n", " gpu = 1,\n", " -- we will print info\n", " verbose = 2,\n", " -- learning rate for SGD\n", " lr=0.00025,\n", " -- Random exploration ratio, start from 100% exploration\n", " ep=1,\n", " -- Drop down to 10% exploration\n", " ep_end=0.1,\n", " -- Linear decay over 1M steps\n", " ep_endt=1000000,\n", " -- Discount factor \\gamma for Q-Learning\n", " discount=0.99,\n", " -- Number of frames to input into convolutional net\n", " hist_len=4,\n", " -- Learning starts after a delay of 50K actions, we do not want to overfit onto early experience\n", " learn_start=50000,\n", " -- We will store last 1M transitions\n", " replay_memory=1000000,\n", " -- We will update every 4 actions\n", " update_freq=4,\n", " -- Will update only once\n", " n_replay=1,\n", " -- Network spec\n", " network= \"dqn.convnet_atari3\",\n", " -- pre-processing spec (just scale down to grayscale 84x84)\n", " preproc=\"dqn.net_downsample_2x_full_y\",\n", " -- size of inputs after rescale (84*84)\n", " state_dim=7056,\n", " -- size of minibatch for SGD\n", " minibatch_size=32,\n", " -- we will scale reward values to limit to 1,-1\n", " rescale_r=1,\n", " -- we use Y channel\n", " ncols=1,\n", " -- buffer on GPU\n", " bufferSize=512,\n", " -- set of validation transitions to track training progress\n", " valid_size=500,\n", " -- update target Q network every 10K updates\n", " target_q=10000,\n", " -- we will clip errors that go into DQN\n", " clip_delta=1,\n", " -- clip reward between -1,1\n", " min_reward=-1,\n", " -- clip reward between -1,1\n", " max_reward=1\n", "}\n", "agent = dqn.NeuralQLearner(agent_params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Logs to Track\n", "We will log the following quantities to track training" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "-- The following quantities are collected during evaluation of the agent\n", "\n", "-- How many times did the agent end up in reward states\n", "reward_counts = {}\n", "-- How many episodes did the agent finish\n", "episode_counts = {}\n", "-- How long does it take to learn/test\n", "time_history = { 0 }\n", "-- The reward agent gets during testing\n", "reward_history = {}\n", "\n", "-- The following are training measures collected by agent during training\n", "\n", "-- Maximum q-value during training\n", "qmax_history = {}\n", "-- Value of the validation states\n", "v_history = {}\n", "-- TD error over the validation states\n", "td_history = {}\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What does the input to agent look like?\n", "The current state of the game can be grabbed by getState() fucnction call on the environment." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "{\n", " reward : 0\n", " screen : CudaTensor - size: 1x3x210x160\n", " terminal : false\n", "}\n" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "screen, reward, terminal = game_env:getState()\n", "print({screen = screen, reward = reward, terminal = terminal})\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What does the game look like?\n", "We can play some random actions and look at the screens. We should be able to play animations with ipython!" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA8wAAAT4CAIAAABCBtwMAAAgAElEQVR4nO3dMW4cV9qo4eqLTm9i5eNJZhagiDcQtADBAhj9UM41aALGDDxbsHKFBBrwAgwnjLSAmWQ8OdfQN/CAv+ChRIp8u6urzvPAQQNsog/wQfCrj0fFzX6/nwAAgM52mqbNZnPv13764bvjHoaDuNjdTqa8dqY8AlMegSmPwJRHcLG73U7HHeeb168efM/Pv/x6hJMcx83l+YPvObu6PsJJTPlwTPlLTPkQTPlwTPlLTPkQTPlwTmTK/+fQHwAAAKMR2QAAEBPZAAAQ28742Z/f/nnMbaGl+/z2z2NuC62DKY/AlEdgyiMw5RGY8tHYZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBAbDvjZ795/WrGTz++m8vzuY8wA1MegSmPwJRHYMojMOWjsckGAICYyAYAgJjIBgCA2Ga/3394+2LuY3BAF7vbaZpMed1MeQSmPAJTHoEpj+Bid2uTDQAAMZENAACxzX6/n/sMAACwKjbZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAENtO0/Th7Yu5j8EBXexuJ1NeO1MegSmPwJRHYMojuNjd2mQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBsO/cBAAAW5uby/O712dX1jCfhZNlkAwBAzCYbAAD+6Jk/r7DJBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmOdkA0DJ7wIEJptsAADIiWwAAIiJbAAAiIlsAACIiWwAAIh5usgy+LfqAHA6/L+YB9lkAwBAzCYbAAD+6Jk/r7DJBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmOdkAxyP3946ApMFJptsAADIiWwAAIiJbAAAiLmTDafCbV0AWA2bbAAAiNlkL4O9JgDAgthkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQMxvfAQ4Hr+9FWAQNtkAABAT2QAAEBPZAAAQcycbToXbugCwGjbZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAENvs9/u5zwAAAKtikw0AADGRDQAAMZENAAAxkQ0AADGRDQAAMZENAAAxkQ0AADGRDQAAMZENAAAxkQ0AADGRDQAAMZENAAAxkQ0AADGRDQAAMZENAAAxkQ0AADGRDQAAse00TZvN5t6v/fTDd8c9DAdxsbudTHntTHkEpjwCUx6BKY/gYne7nY47zjevXz34np9/+fUIJzmOm8vzB99zdnV9hJOY8uGY8peY8iGY8uGY8peY8iGY8uGcyJRdFwEAgJjIBgCAmMgGAIDYdsbP/vz2z2NuCy3d57d/HnNbaB1MeQSmPAJTHoEpj8CUj8YmGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABi2xk/+83rVzN++vHdXJ7PfYQZmPIITHkEpjwCUx6BKR+NTTYAAMRENgAAxEQ2AADENvv9/sPbF3MfgwO62N1O02TK62bKIzDlEZjyCEx5BBe7W5tsAACIiWwAAIht9vv93GcAAIBVsckGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCA2Haapg9vX8x9DA7oYnc7mfLamfIITHkEpjwCUx7Bxe7WJhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYtu5D8Cj3Fye370+u7qe8SQAADzIJhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABi27kPwKOcXV3PfQQO7uby/O61iQPAotlkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBAbDv3AQAGcnN5fvf67Op6xpMAcFA22QAAELPJhlNhrwkAq2GTDQAAMZtsACi5eQ9MNtkAAJCzyQYA+DZ+XjGCZ07ZJhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABifuMjwPH4zXAAg7DJBgCAmMgGAICYyAYAgJg72QBQcvMemGyyAQAgZ5MNAPBt/LxiBM+csk02AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMQ2+/1+7jMAAMCq2GQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQGw7TdNms7n3az/98N1xD8NBXOxuJ1NeO1MegSmPwJRHYMojuNjdbqfjjvPN61cPvufnX349wkmO4+by/MH3nF1dH+Ekpnw4pvwlpnwIpnw4pvwlpnwIpnw4JzJl10UAACAmsgEAICayAQAgtp3xsz+//fOY20JL9/ntn8fcFloHUx6BKY/AlEdgyiMw5aOxyQYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCA2HbGz37z+tWMn358N5fncx9hBqY8AlMegSmPwJRHYMpHY5MNAAAxkQ0AADGRDQAAsc1+v//w9sXcx+CALna30zSZ8rqZ8ghMeQSmPAJTHsHF7tYmGwAAYiIbAABim/1+P/cZAABgVWyyAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAILadpunD2xdzH4MDutjdTqa8dqY8AlMegSmPwJRHcLG7tckGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAIDYdu4DAADAybm5PL97fXZ1/a3fbpMNAAAxkQ0AADGRDQAAMZENAAAxkQ0AADGRDQAAMZENAAAxkQ0AADGRDQAAMZENAAAxkQ0AADGRDQAAMZENAAAxkQ0AADGRDQAAse3cBwAAWJiby/O712dX1zOehMN55mRtsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACC2nfsAAAALc3Z1PfcROHU22QAAEBPZAAAQE9kAABAT2QAAEPMPH2Fcf/nz93ev//mv32Y8CQCsjMgGFublu/fTNH36+Pe5DwL3u7k8v3vtGRQwLNdFYFx//f5Pd//NfZbHevnu/ct3f5v7FADwAJvsZbAXAQBYEJENsGZu3gPMwnURYGE+ffzRhexVevnu/e8X7gFWwCYbWBJ5/a0+v3B/ypvs32/bf/r449wHAWjYZAMAQMwmGwCWzc17OEEiG06FZ8gwOLft18qz7RmTyAZgfgrsOU755r3b9gzLnWwAAIjZZAMAnDo37xfHJhsAOCC37dfKs+2/ziYbADgUeV05tZv3bts/yCYbAABiNtkwrn/89u+5jwAA6ySyAY7H09BHYLIMwm37rxPZMK5TuNUHwBLJ6weJ7GWwFwEAWBCRDbBmbt4DzMLTRQAAIGaTDbBmbt4DzMImGwAAYjbZALBsbt7DCRLZcCo8QwYAVkNkA8CyuXkPJ8idbAAAiNlkAwCcOjfvF8cmGwAAYjbZAACnzs37xbHJBgCAmMgGAICY6yIAx+Np6ACDsMkGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCAmMgGAICYyAYAgJjIBgCA2Ga/3899BgAAWBWbbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIbadp2mw2937tpx++O+5hOIiL3e1kymtnyiMw5RGY8ghMeQQXu9vtdNxxvnn96sH3/PzLr0c4yXHcXJ4/+J6zq+sjnMSUD8eUv8SUD8GUD8eUv8SUD8GUD+dEpuy6CAAAxEQ2AADERDYAAMS2M37257d/HnNbaOk+v/3zmNtC62DKIzDlEZjyCEx5BKZ8NDbZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABDbzvjZb16/mvHTj+/m8nzuI8zAlEdgyiMw5RGY8ghM+WhssgEAICayAQAgJrIBACC22e/3H96+mPsYHNDF7naaJlNeN1MegSmPwJRHYMojuNjd2mQDAEBMZAMAQGyz3+/nPgMAAKyKTTYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADEttM0fXj7Yu5jcEAXu9vJlNfOlEdgyiMw5aW4uTy/e312df1N32vKS/HMKdtkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBAbDv3AQBgVW4uz+9en11dz3gSDsdkeZBNNgAAxGyyAQDgj5758wqbbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiG3nPgAArMrZ1fXcRwDmZ5MNAAAxkQ0AADGRDQAAMZENAAAxkQ0AADFPFwE4npvL87vXnkEBsGI22QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABDzGx+XwW+JAwBYEJtsAACIiWwAAIiJbAAAiIlsAACIiWwAAIh5ugicCs+QGYHJAgzCJhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABifuPjMvgtcQAAC2KTDQAAMZENAAAxkQ0AADGRDQAAMZENAAAxTxeBU+EZMgCwGjbZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQ2z7he/7y5+/vXv/zX791hwEAgDVY8Cb75bv3L9+9n/sUAADwR0+J7L9+/6e7//IDPdLLd+9fvvvbXJ8OAABfseBNNgAAnKan3MlmBG7eAwA82YI32Z8+/vjp49/nPgU9t+0BgKVb6iZbXh/a5xfuj7nJ/v22/aePPx7tEwEAcgveZAMAwGla6iYbeD437wHgQGyyOTlu26+V2/YAjMMmm9Mir4/pmDfv3bYHYCg22QAAELPJBlgzN+8BZmGTDRyJ2/Zr5bY9wH+zyQaOQV7P5dA37922B7iXTTYAAMSessn+x2//zs8BADyNm/dwgmyyAXgWt+3Xym17eI6nbLL9LRmA38nrU3CIm/du28Mz2WQDAEDM00W4n5v3AHA63LxfHJtsAOAebtuvldv2x2GTzf38LRlgZPL61FQ37922PxqbbAAAiNlkw7jcvAeAA7HJBgAYiNv2x2GTDeNy8x5gNPL6aGyyAQAgZpMNsGZu3gPMwiYbAABiNtkAa+bmPcAsbLIBACBmkw0Ay+bmPZwgm2wAAIjZZAPAsrl5DyfIJhsAAGI22QAAp87N+8WxyQYAgJhNNgDAqXPzfnFssgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACAmsgEAICayAQAgJrIBACC22e/3c58BAABWxSYbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGIiGwAAYiIbAABiIhsAAGLbaZo2m829X/vph++OexgO4mJ3O5ny2pnyCEx5BKY8AlMewcXudjsdd5xvXr968D0///LrEU5yHDeX5w++5+zq+ggnMeXDMeUvMeVDMOXDMeUvMeVDMOXDOZEpuy4CAAAxkQ0AADGRDQAAse2Mn/357Z/H3BZaus9v/zzmttA6mPIITHkEpjwCUx6BKR+NTTYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxLYzfvab169m/PTju7k8n/sIMzDlEZjyCEx5BKY8AlM+GptsAACIiWwAAIiJbAAAiG32+/2Hty/mPgYHdLG7nabJlNfNlEdgyiMw5RGY8ggudrc22QAAEBPZAAAQ2+z3+7nPAAAAq2KTDQAAMZENAAAxkQ0AADGRDQAAMZENAAAxkQ0AADGRDQAAMZENAAAxkQ0AALHtNE0f3r6Y+xgc0MXudjLltTPlEZjyCEx5BKY8govdrU02AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMS2cx+AR7m5PL97fXZ1PeNJAAB4kE02AADEbLLhVPh5BQCshk02AADERDYAAMRENgAAxEQ2AADERDYAAMQ8XQTgeDxDBmAQNtkAABAT2QAAEBPZAAAQcycbAODb+PcVPMgmGwAAYiIbAABiIhsAAGLuZC+D+14AS+G2LjDZZAMAQM4mG06FjRcArIZNNgAAxEQ2AADERDYAAMRENgAAxEQ2AADEPF0E4Hg8QwZgEDbZAAAQE9kAABAT2QAAEHMnGwDg2/j3FSO4uTy/e/2EidtkAwBATGQDAEBMZAMAQMydbAAoua0LTDbZAACQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABDbzn0AAAA4OWdX18/5dptsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIbZ/wPX/58/d3r//5r9+6wwAAwBrMsMl++e79y3fvj/+5AABwHE+J7L9+/6e7/771e1++e//y3d+e8KEAALAU7mQDAEDsKXeyGYGb9wAATzbDJvvTxx8/ffz78T+XQ3PbHgDgd8feZMvrpfj8wv1jNtm/37b/9PHHQx4KAGAZ3MkGAICYO9kwLjfvAeBAbLLJuG2/Vm7bA8C3ssmmIa+X6DE37922B4AnsMkGAICYTTbAmrl5D+vgz/Li2GQDD3Dbfq3ctodF80f4xNlkA18jr5fuSzfv3baHZfnW31/B7GyyAQAW5ve/J899Cr7mKZvsf/z27/wcAMDTuK0LJ8gmG2BQbtuvlau6g/BH+MQ9ZZPtb8kAS+f/zWvy+W3d//v//sdtezgF/uEjAMDC+Hvy6RPZ3M/NewCAJ3MnGwBWxVVdOAU22dzPzXuAJZLXcCJssgEAIGaTDeNy8x4ADsQmGwAAYjbZMC437wHgQGyyAQAgZpMNsGZu3sM6+LO8ODbZAAAQs8kGWDM372Ed/FleHJtsAACI2WQDwLK5rQsnyCYbAABiNtkAsGxu68IJsskGAICYyAYAgJjIBgCAmMgGAIDYZr/fz30GAABYFZtsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIiJbAAAiIlsAACIiWwAAIhtp2nabDb3fu2nH7477mE4iIvd7WTKa2fKIzDlEZjyCEx5BBe72+103HG+ef3qwff8/MuvRzjJcdxcnj/4nrOr6yOcxJQPx5S/xJQPwZQPx5S/xJQPwZQP50Sm7LoIAADERDYAAMRENgAAxLYzfvbnt38ec1to6T6//fOY20LrYMojMOURmPIITHkEpnw0NtkAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAENvO+NlvXr+a8dOP7+byfO4jzMCUR2DKIzDlEZjyCEz5aGyyAQAgJrIBACAmsgEAILbZ7/cf3r6Y+xgc0MXudpomU143Ux6BKY/AlEdgyiO42N3aZAMAQExkAwBAbLPf7+c+AwAArIpNNgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMS20zR9ePti7mNwQBe728mU186UR2DKIzDlEZjyCC52tzbZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABDbzn0AAAA4OTeX53evz66uv/XbbbIBACBmkw0A8G2eueNkBDbZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQ2859AAAAODlnV9fP+XabbAAAiNlkA0Dp5vL87vUzN2GcLJPlQTbZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABAT2QAAEBPZAAAQE9kAABDbzn0A4D9uLs/vXp9dXc94EgDgmWyyAQAgZpO9DHacAAALYpMNAAAxm0APsNwAAAOPSURBVGyA4/FTqRGYLDDZZAMAQE5kAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBAbDv3AYD/OLu6nvsIAEDDJhsAAGI22ctgxwkAsCA22QAAELPJBjgeP5UCGIRNNgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAwLhevnv/8t37uU+xQiIbAABiIhsAYFAv371/+e5vc59inUQ2AADERDYAwLg+ffzx08e/z32KFdrOfQAAAB7wlz9/f/f6n//6bcaT8EgiGwBgUHbYhyOyAQBO3V+//9Pda5vsRXAnGwD4X56aDAmRDQAAMZENAPyHpyZDRWQDAEBMZAMA/8tTkyHh6SIAsGyeoAwnSGQDAP9hhw0VkQ0Ay+YJynCC3MkG4Gs8NRngCUQ2AADERDYAX+SpyQBP4042wJp57gTALGyyAfgaT00GeAKbbIA189wJWId//PbvuY/AtxHZAHyRHTbA04hsAIBT5ydRi+NONvBtPDUZAB4ksgEAICaygW/gqckA8BjuZMO4PEEZAA7EJhv4Np6aDAAPssmGcXmCMqyDJyjDCRLZwDewwwaAxxDZALBsfhIFJ8idbJ7LU5MBAP7AJpv7ee4EAMCT2WTzLJ6aDADw32yyuZ/nTsA6eO4EwCxssnkuT00GAPgDm2yANfOTKIBZiGyexQ4bAOC/uS4CAACxPrI9NRkAgME95bqIJygDAMBXxJtsT00GAICnbLI9QRnWwROUAeBA+jvZnpoMAMDgPMIPxuUnUQBwIHFk22EDAIDnZAMAQExkAwBAzJ1s7ue5EwAAT2aTDQAAMZts7ue5EwAAT2aTDQAAMZENAAAxkQ0AADGRDQAAMZENAACxpzxdxBOUAQDgK2yyAQAg9pRNticoAwDAV9hkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBMZAMAQExkAwBATGQDAEBss9/v5z4DAACsik02AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADERDYAAMRENgAAxEQ2AADE/j+umjTeRM8EkwAAAABJRU5ErkJggg==", "text/plain": [ "Console does not support images" ] }, "metadata": { "image/png": { "height": 1272, "width": 972 } }, "output_type": "display_data" } ], "source": [ "local screens = {}\n", "for i=1,36 do\n", " local screen, reward, terminal = game_env:step(game_actions[torch.random(3)])\n", " table.insert(screens, screen[1]:clone())\n", "end\n", "itorch.image(screens)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training\n", "Now we will train the agent to play the game" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [], "source": [ "-- Training options\n", "opt = {\n", " -- number of evaluation steps\n", " eval_steps = 125000,\n", " -- frequency of evaluation\n", " eval_freq = 250000,\n", " -- total number of training steps\n", " steps = 50000000,\n", " -- frequency of progress reporting\n", " prog_freq = 10000,\n", " -- frequency to save agent on disk\n", " save_freq = 125000,\n", " -- filename for saved agent\n", " name = 'Itorch_DQN3_0_1_pong_FULL_Y',\n", " -- we want to use random starts of up to 30 nil steps\n", " random_starts = 30,\n", " -- we repeat every action 4 times\n", " actrep = 4,\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "Iteration ..\t0\t\n" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "local total_reward\n", "local nrewards\n", "local nepisodes\n", "local episode_reward\n", "\n", "local learn_start = agent.learn_start\n", "local start_time = sys.clock()\n", "local step = 0\n", "\n", "print(\"Iteration ..\", step)\n", "while step < opt.steps do\n", " step = step + 1\n", " local action_index = agent:perceive(reward, screen, terminal)\n", "\n", " -- game over? get next game! \n", " if not terminal then\n", " screen, reward, terminal = game_env:step(game_actions[action_index], true)\n", " else\n", " if opt.random_starts > 0 then\n", " screen, reward, terminal = game_env:nextRandomGame()\n", " else\n", " screen, reward, terminal = game_env:newGame()\n", " end\n", " end\n", "\n", " if step % opt.prog_freq == 0 then\n", " assert(step==agent.numSteps, 'trainer step: ' .. step ..\n", " ' & agent.numSteps: ' .. agent.numSteps)\n", " print(\"Training Steps: \", step)\n", " agent:report()\n", " collectgarbage()\n", " end\n", "\n", " if step%1000 == 0 then collectgarbage() end\n", "\n", " if step % opt.eval_freq == 0 and step > learn_start then\n", "\n", " print('Evaluating')\n", " screen, reward, terminal = game_env:newGame()\n", "\n", " total_reward = 0\n", " nrewards = 0\n", " nepisodes = 0\n", " episode_reward = 0\n", "\n", " local eval_time = sys.clock()\n", " for estep=1,opt.eval_steps do\n", " local action_index = agent:perceive(reward, screen, terminal, true, 0.05)\n", " -- Play game in test mode (episodes don't end when losing a life) \n", " screen, reward, terminal = game_env:step(game_actions[action_index])\n", "\n", " if estep%1000 == 0 then collectgarbage() end\n", "\n", " -- record every reward \n", " episode_reward = episode_reward + reward\n", " if reward ~= 0 then\n", " nrewards = nrewards + 1\n", " end\n", "\n", " if terminal then\n", " total_reward = total_reward + episode_reward\n", " episode_reward = 0\n", " nepisodes = nepisodes + 1\n", " screen, reward, terminal = game_env:nextRandomGame()\n", " end\n", " end\n", "\n", " eval_time = sys.clock() - eval_time\n", " start_time = start_time + eval_time\n", " agent:compute_validation_statistics()\n", " local ind = #reward_history+1\n", " total_reward = total_reward/math.max(1, nepisodes)\n", "\n", " if #reward_history == 0 or total_reward > torch.Tensor(reward_history):max() then\n", " agent.best_network = agent.network:clone()\n", " end\n", "\n", " if agent.v_avg then\n", " v_history[ind] = agent.v_avg\n", " td_history[ind] = agent.tderr_avg\n", " qmax_history[ind] = agent.q_max\n", " end\n", " print(\"V\", v_history[ind], \"TD error\", td_history[ind], \"Qmax\", qmax_history[ind])\n", "\n", " reward_history[ind] = total_reward\n", " reward_counts[ind] = nrewards\n", " episode_counts[ind] = nepisodes\n", " time_history[ind+1] = sys.clock() - start_time\n", "\n", " local time_dif = time_history[ind+1] - time_history[ind]\n", "\n", " local training_rate = opt.actrep*opt.eval_freq/time_dif\n", "\n", " print(string.format(\n", " '\\nSteps: %d (frames: %d), reward: %.2f, epsilon: %.2f, lr: %G, ' ..\n", " 'training time: %ds, training rate: %dfps, testing time: %ds, ' ..\n", " 'testing rate: %dfps, num. ep.: %d, num. rewards: %d',\n", " step, step*opt.actrep, total_reward, agent.ep, agent.lr, time_dif,\n", " training_rate, eval_time, opt.actrep*opt.eval_steps/eval_time,\n", " nepisodes, nrewards))\n", " end\n", "\n", " if step % opt.save_freq == 0 or step == opt.steps then\n", " local filename = opt.name\n", " torch.save(filename .. \".t7\", {\n", " reward_history = reward_history,\n", " reward_counts = reward_counts,\n", " episode_counts = episode_counts,\n", " time_history = time_history,\n", " v_history = v_history,\n", " td_history = td_history,\n", " qmax_history = qmax_history,\n", " opt = opt})\n", "-- we are not going to bother saving agents for now.\n", "-- torch.save(filename .. \"_agent.t7\", {\n", "-- agent = agent,\n", "-- best_model = agent.best_network,\n", "-- })\n", " print('Saved:', filename .. '.t7')\n", " io.flush()\n", " collectgarbage()\n", " end\n", "end\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "iTorch", "language": "lua", "name": "itorch" }, "language_info": { "name": "lua", "version": "5.1" } }, "nbformat": 4, "nbformat_minor": 0 }