Skip to content

Commit

Permalink
Moves fsm into rebar
Browse files Browse the repository at this point in the history
  • Loading branch information
andyljones committed Jul 7, 2020
1 parent 72dc820 commit 1a1d078
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions onedee/fsm.py → rebar/fsm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import numpy as np
import torch
from rebar import arrdict, dotdict
from . import spaces
import pandas as pd

__all__ = ['dataframe']

class MultiVector:

def __init__(self, n_agents, dim):
super().__init__()
self.shape = (n_agents, dim)

class MultiDiscrete:

def __init__(self, n_agents, n_actions):
super().__init__()
self.shape = (n_agents, n_actions)

def _dataframe(traj):
if isinstance(traj, dict):
return [([k] + kk, vv) for k, v in traj.items() for kk, vv in _dataframe(v)]
Expand All @@ -32,8 +43,8 @@ def __init__(self, n_envs, fsm, device='cuda'):

self._token = torch.full((self.n_envs,), -1, dtype=torch.long, device=device)

self.observation_space = spaces.MultiVector(1, fsm.d_obs) if fsm.d_obs else spaces.MultiEmpty()
self.action_space = spaces.MultiDiscrete(1, fsm.n_actions)
self.observation_space = MultiVector(1, fsm.d_obs) if fsm.d_obs else spaces.MultiEmpty()
self.action_space = MultiDiscrete(1, fsm.n_actions)

def _reset(self, reset):
if reset.any():
Expand Down

0 comments on commit 1a1d078

Please sign in to comment.