Skip to content

Commit

Permalink
Move Trajectory to lagom.metric and Runner to root folder
Browse files Browse the repository at this point in the history
  • Loading branch information
zuoxingdong committed Jun 28, 2019
1 parent dda54ee commit ecf8311
Show file tree
Hide file tree
Showing 15 changed files with 178 additions and 199 deletions.
2 changes: 1 addition & 1 deletion baselines/ppo/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from gym.spaces import Box
from gym.wrappers import ClipAction

from lagom import EpisodeRunner
from lagom.utils import pickle_dump
from lagom.utils import set_global_seeds
from lagom.experiment import Config
Expand All @@ -17,7 +18,6 @@
from lagom.envs.wrappers import VecStandardizeObservation
from lagom.envs.wrappers import VecStandardizeReward
from lagom.envs.wrappers import VecStepInfo
from lagom.runner import EpisodeRunner

from baselines.ppo.agent import Agent
from baselines.ppo.engine import Engine
Expand Down
2 changes: 1 addition & 1 deletion baselines/vpg/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from gym.spaces import Box
from gym.wrappers import ClipAction

from lagom import EpisodeRunner
from lagom.utils import pickle_dump
from lagom.utils import set_global_seeds
from lagom.experiment import Config
Expand All @@ -17,7 +18,6 @@
from lagom.envs.wrappers import VecStandardizeObservation
from lagom.envs.wrappers import VecStandardizeReward
from lagom.envs.wrappers import VecStepInfo
from lagom.runner import EpisodeRunner

from baselines.vpg.agent import Agent
from baselines.vpg.engine import Engine
Expand Down
4 changes: 1 addition & 3 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,10 @@ A graphical illustration is coming soon.
lagom.envs <envs>
lagom.experiment <experiment>
lagom.metric <metric>
lagom.multiprocessing <multiprocessing>
lagom.networks <networks>
lagom.runner <runner>
lagom.transform <transform>
lagom.vis <vis>
lagom.utils <utils>
lagom.vis <vis>

Indices and tables
==================
Expand Down
8 changes: 8 additions & 0 deletions docs/source/lagom.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ Engine
.. autoclass:: BaseEngine
:members:

Runner
----------------------------
.. autoclass:: BaseRunner
:members:

.. autoclass:: EpisodeRunner
:members:

Evolution Strategies
----------------------------
.. autoclass:: BaseES
Expand Down
3 changes: 3 additions & 0 deletions docs/source/metric.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ lagom.metric: Metrics
.. automodule:: lagom.metric
.. currentmodule:: lagom.metric

.. autoclass:: Trajectory
:members:

.. autofunction:: returns

.. autofunction:: bootstrapped_returns
Expand Down
13 changes: 0 additions & 13 deletions docs/source/runner.rst

This file was deleted.

7 changes: 5 additions & 2 deletions lagom/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from .version import __version__

from .logger import Logger

from .agent import BaseAgent
from .agent import RandomAgent

Expand All @@ -10,3 +8,8 @@
from .es import BaseES
from .es import CMAES
from .es import CEM

from .logger import Logger

from .runner import BaseRunner
from .runner import EpisodeRunner
2 changes: 2 additions & 0 deletions lagom/metric/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .trajectory import Trajectory

from .returns import returns
from .returns import bootstrapped_returns

Expand Down
53 changes: 26 additions & 27 deletions lagom/runner/trajectory.py → lagom/metric/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,12 @@ def __init__(self):
self.rewards = []
self.step_infos = []

def __len__(self):
    # Trajectory length = number of recorded time steps (one step info per step).
    return len(self.step_infos)

@property
def completed(self):
    # Completed once at least one step was recorded and the final step info
    # is flagged `last` (i.e. the episode has ended).
    return len(self.step_infos) > 0 and self.step_infos[-1].last

def add_observation(self, observation):
    """Append an observation; forbidden once the trajectory is completed."""
    assert not self.completed
    self.observations.append(observation)

@property
def numpy_observations(self):
    """All stored observations concatenated along axis 0 into one array."""
    out = np.concatenate(self.observations, axis=0)
    return out

@property
def last_observation(self):
    """The most recently added observation."""
    return self.observations[-1]

@property
def reach_time_limit(self):
Expand All @@ -32,29 +22,41 @@ def reach_time_limit(self):
@property
def reach_terminal(self):
    # True when the final step is marked `terminal` — presumably a genuine
    # episode end as opposed to a time-limit truncation; confirm against
    # the companion `reach_time_limit` property.
    return self.step_infos[-1].terminal

def add_observation(self, observation):
    """Append an observation; forbidden once the trajectory is completed."""
    assert not self.completed
    self.observations.append(observation)

def add_action(self, action):
    """Append an action; forbidden once the trajectory is completed."""
    assert not self.completed
    self.actions.append(action)

@property
def numpy_actions(self):
    """All stored actions concatenated along axis 0 into one array."""
    return np.concatenate(self.actions, axis=0)


def add_reward(self, reward):
    """Append a reward; forbidden once the trajectory is completed."""
    assert not self.completed
    self.rewards.append(reward)

@property
def numpy_rewards(self):
    """Rewards as a 1-D numpy array."""
    return np.asarray(self.rewards)

def add_step_info(self, step_info):
    """Append a step info; forbidden once the trajectory is completed.

    If the new step info is flagged `last`, the trajectory must now report
    itself as completed — the trailing assert is a consistency check.
    """
    assert not self.completed
    self.step_infos.append(step_info)
    if step_info.last:
        assert self.completed


@property
def last_observation(self):
    """The most recently added observation."""
    return self.observations[-1]

@property
def numpy_observations(self):
    """All stored observations concatenated along axis 0 into one array."""
    return np.concatenate(self.observations, axis=0)

@property
def numpy_actions(self):
    """All stored actions concatenated along axis 0 into one array."""
    return np.concatenate(self.actions, axis=0)

@property
def numpy_rewards(self):
    """Rewards as a 1-D numpy array."""
    return np.asarray(self.rewards)

@property
def numpy_dones(self):
    """One `done` flag per recorded step, as a numpy array."""
    return np.asarray([step_info.done for step_info in self.step_infos])
Expand All @@ -70,8 +72,5 @@ def infos(self):
def get_all_info(self, key):
    """Collect the value stored under `key` from every step info, in order."""
    return [step_info[key] for step_info in self.step_infos]

def __len__(self):
    # Trajectory length = number of recorded time steps (one step info per step).
    return len(self.step_infos)

def __repr__(self):
    """Summarize the trajectory: length, completion and termination status."""
    # NOTE(review): the original text had two consecutive return statements
    # (an artifact of the old/new diff being fused); the second was
    # unreachable dead code. Keep the more informative variant.
    return f'Trajectory(T: {len(self)}, Completed: {self.completed}, Reach time limit: {self.reach_time_limit}, Reach terminal: {self.reach_terminal})'
39 changes: 31 additions & 8 deletions lagom/runner/episode_runner.py → lagom/runner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,35 @@
from abc import ABC
from abc import abstractmethod

from lagom.metric import Trajectory
from lagom.envs import VecEnv
from lagom.envs.wrappers import VecStepInfo

from .base_runner import BaseRunner
from .trajectory import Trajectory

class BaseRunner(ABC):
    r"""Abstract base class for all runners.

    A runner is the data-collection interface between an agent and an
    environment: each call lets the agent act in (and observe from) the
    environment, gathering interaction data for a certain number of
    trajectories/segments over a certain number of time steps.

    .. note::
        By default, the agent handles batched data returned from
        :class:`VecEnv` type of environment.
    """

    @abstractmethod
    def __call__(self, agent, env, T, **kwargs):
        r"""Run the agent in the environment for a number of time steps and
        collect all necessary interaction data.

        Args:
            agent (BaseAgent): agent
            env (VecEnv): VecEnv type of environment
            T (int): number of time steps
            **kwargs: keyword arguments for more specifications.
        """
        ...


class EpisodeRunner(BaseRunner):
Expand All @@ -11,9 +38,7 @@ def __init__(self, reset_on_call=True):
self.observation = None

def __call__(self, agent, env, T, **kwargs):
assert isinstance(env, VecEnv)
assert isinstance(env, VecStepInfo)
assert len(env) == 1, 'for cleaner API, one should use single VecEnv'
assert isinstance(env, VecEnv) and isinstance(env, VecStepInfo) and len(env) == 1

D = [Trajectory()]
if self.reset_on_call:
Expand All @@ -26,9 +51,7 @@ def __call__(self, agent, env, T, **kwargs):
for t in range(T):
out_agent = agent.choose_action(observation, **kwargs)
action = out_agent.pop('raw_action')
next_observation, reward, step_info = env.step(action)
# unbatch for [reward, step_info]
reward, step_info = map(lambda x: x[0], [reward, step_info])
next_observation, [reward], [step_info] = env.step(action)
step_info.info = {**step_info.info, **out_agent}
if step_info.last:
D[-1].add_observation([step_info['last_observation']]) # add a batch dim
Expand Down
3 changes: 0 additions & 3 deletions lagom/runner/__init__.py

This file was deleted.

28 changes: 0 additions & 28 deletions lagom/runner/base_runner.py

This file was deleted.

49 changes: 45 additions & 4 deletions test/test_lagom.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from pathlib import Path
import pytest

import numpy as np

import gym
from gym.wrappers import TimeLimit

from pathlib import Path

from lagom.envs import make_vec_env
from lagom import RandomAgent
from lagom import Logger
from lagom import EpisodeRunner
from lagom.metric import Trajectory
from lagom.envs import make_vec_env
from lagom.envs.wrappers import VecStepInfo
from lagom.utils import pickle_load

from .sanity_env import SanityEnv


def test_logger():
logger = Logger()
Expand Down Expand Up @@ -83,3 +87,40 @@ def test_random_agent(env_id, num_env):
assert isinstance(out, dict)
assert len(out['raw_action']) == num_env
assert all(action in env.action_space for action in out['raw_action'])


@pytest.mark.parametrize('env_id', ['Sanity', 'CartPole-v1', 'Pendulum-v0', 'Pong-v0'])
@pytest.mark.parametrize('num_env', [1, 3])
@pytest.mark.parametrize('init_seed', [0, 10])
@pytest.mark.parametrize('T', [1, 5, 100])
def test_episode_runner(env_id, num_env, init_seed, T):
    """EpisodeRunner collects shape-consistent Trajectory data across env
    types, seeds and horizons, and rejects multi-env or non-VecStepInfo envs.
    """
    # 'Sanity' uses the local sanity-check env; everything else comes from gym.
    if env_id == 'Sanity':
        make_env = lambda: TimeLimit(SanityEnv())
    else:
        make_env = lambda: gym.make(env_id)
    env = make_vec_env(make_env, num_env, init_seed)
    env = VecStepInfo(env)
    agent = RandomAgent(None, env, None)
    runner = EpisodeRunner()

    if num_env > 1:
        # Runner asserts a single sub-environment; multi-env must fail fast.
        with pytest.raises(AssertionError):
            D = runner(agent, env, T)
    else:
        with pytest.raises(AssertionError):
            runner(agent, env.env, T)  # must be VecStepInfo
        D = runner(agent, env, T)
        for traj in D:
            assert isinstance(traj, Trajectory)
            assert len(traj) <= env.spec.max_episode_steps
            # Observations include the final one, hence len(traj) + 1 rows.
            assert traj.numpy_observations.shape == (len(traj) + 1, *env.observation_space.shape)
            if isinstance(env.action_space, gym.spaces.Discrete):
                assert traj.numpy_actions.shape == (len(traj),)
            else:
                assert traj.numpy_actions.shape == (len(traj), *env.action_space.shape)
            assert traj.numpy_rewards.shape == (len(traj),)
            assert traj.numpy_dones.shape == (len(traj), )
            assert traj.numpy_masks.shape == (len(traj), )
            assert len(traj.step_infos) == len(traj)
            if traj.completed:
                # Final stored observation must match the env-reported last one.
                assert np.allclose(traj.observations[-1], traj.step_infos[-1]['last_observation'])
Loading

0 comments on commit ecf8311

Please sign in to comment.