Commit

aLadNamedPat committed Sep 5, 2024
2 parents c83815c + a1df45b commit 8f6aedb
Showing 8 changed files with 34 additions and 60 deletions.
13 changes: 5 additions & 8 deletions Dreamer.py
@@ -9,7 +9,7 @@
 import gzip
 import torch.nn.functional as F
 
-device = torch.device("cuda")
+device = torch.device("cpu")
 
 wandb.init(
     project="Dreamer",
@@ -26,7 +26,6 @@ def __init__(
         env,
         state_dims : int,
         latent_dims : int,
-        observation_dim : int,
         o_feature_dim : int,
         reward_dim : int,
         gamma : float = 0.99,
@@ -44,7 +43,6 @@ def __init__(
         self.action_space = env.action_spec()
         self.state_dims = state_dims
         self.latent_dims = latent_dims
-        self.observation_dim = observation_dim
         self.o_feature_dim = o_feature_dim
         self.reward_dim = reward_dim
         self.gamma = gamma
@@ -74,7 +72,6 @@ def __init__(
         self.RSSM = RSSM(
             state_dim=self.state_dims,
             action_dim=self.action_space,
-            observation_dim=self.observation_dim,
             o_feature_dim=self.o_feature_dim,
             latent_dim=self.latent_dims,
             reward_dim=self.reward_dim
@@ -104,7 +101,7 @@ def latent_imagine(self, latents, posterior, horizon : int):
         action_list = [action]
 
         for _ in range(horizon):
-            state = self.RSSM(imagined_state, action, imagined_latent)
+            state = self.RSSM(imagined_state, action_list, imagined_latent)
             imagined_state, imagined_latent = state[0], state[1]
             action = self.actor(torch.cat([imagined_state, imagined_latent], -1))
             # action.reshape(x, y, -1)
@@ -142,7 +139,7 @@ def model_update(self):
             actions.squeeze().float().to(device),
             prev_latent_space.to(device),
             nonterminals=torch.logical_not(dones).to(device),
-            observation=states.to(device)
+            observations=states.to(device)
         )
 
         # Calculate the MSE loss for observation and decoded observation
@@ -249,13 +246,13 @@ def rollout(
             timestep = self.env.step(action.cpu())
             obs = torch.tensor(self.env.physics.render(camera_id=0, height=128, width=192).copy())
             obs = obs.reshape(1, obs.shape[0], obs.shape[1], obs.shape[2]).detach()
 
             action = action.reshape(1, action.shape[0], action.shape[1])
             states = self.RSSM(
                 self.prev_state.to(device),
                 action.to(device),
                 self.prev_latent_space.to(device),
                 nonterminals=1-timestep.last(),
-                observation=obs.to(device)
+                observations=obs.to(device)
             )
 
             # print(f"States {states}")
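A note on the first hunk: the device is still hard-coded, just switched from "cuda" to "cpu". A more portable pattern, shown below as a minimal sketch that is not part of this commit, selects the device at runtime:

    import torch

    # Fall back to CPU automatically when no GPU is visible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Everything downstream can keep using the same .to(device) calls that appear throughout Dreamer.py.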
79 changes: 29 additions & 50 deletions RSSM.py
@@ -13,7 +13,7 @@ class RSSM(nn.Module):
     state_dim --> the dimensions for the state
     '''
 
-    def __init__(self, state_dim, action_dim, observation_dim, o_feature_dim, latent_dim, reward_dim):
+    def __init__(self, state_dim, action_dim, o_feature_dim, latent_dim, reward_dim):
         super(RSSM, self).__init__()
         self.state_dim = state_dim
         self.latent_dim = latent_dim
@@ -31,58 +31,37 @@ def __init__(self, state_dim, action_dim, observation_dim, o_feature_dim, latent_dim, reward_dim):
         self.representation_post = nn.Linear(latent_dim, 2 * state_dim)
         self.relu = nn.ReLU()
 
     ## Heavily based off of: https://github.com/zhaoyi11/dreamer-pytorch/blob/master/models.py
-    def forward(self, prev_state, action, prev_latent_space, nonterminals = None, observation = None):
-        latent_space = prev_latent_space
-        prior_state = prev_state
-        posterior_state = prev_state
-
-        if observation is None:
-            state = prior_state
-        else:
-            state = posterior_state
+    def forward(self, prev_state, actions, prev_belief, observations = None, nonterminals = None):
+        print(f"Actions {actions}")
+        T = actions.size(0) + 1
+        beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs, decoded_observations = [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * (T - 1)
+        beliefs[0], prior_states[0], posterior_states[0] = prev_belief, prev_state, prev_state
 
-        if nonterminals is not None:
-            state = state * nonterminals
-
-        print(state.shape)
-        print(latent_space.shape)
-        # latent_space = self.rnn(state.view(-1, state.size(-1)), latent_space.view(-1, latent_space.size(-1)))
-        latent_space = self.rnn(state, latent_space)
+        for t in range(T - 1):
+            _state = prior_states[t] if observations is None else posterior_states[t]
+            _state = _state if (nonterminals is None or t == 0) else _state * nonterminals[t-1]
+            print(f"State {_state}")
+            hidden = self.relu(self.transition_pre(torch.cat([_state, actions[t]], dim=1)))
+            beliefs[t + 1] = self.rnn(hidden, beliefs[t])
 
-        ## TODO : Do we need to reduce the size of the state
-
-        # state = state.view(-1)
-        hidden = self.relu(self.transition_pre(torch.cat([state, action], dim = -1)))
-        prior_mean, _prior_std_dev = torch.chunk(self.transition_post(hidden), 2, dim = -1)
-        prior_std_dev = F.softplus(_prior_std_dev) + 1e-5
-        cov_matrix = torch.diag_embed(prior_std_dev**2)
-        sampled_state = torch.distributions.MultivariateNormal(prior_mean, cov_matrix)
-        prior_state = sampled_state.rsample()
+            hidden = self.relu(self.transition_pre(beliefs[t + 1]))
+            prior_means[t + 1], _prior_std_dev = torch.chunk(self.transition_post(hidden), 2, dim=1)
+            prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + 1e-5
+            prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1])
 
-        if observation is not None:
-            observation = observation.float()
-            # print(f"Observation Shape: {observation.shape}")
-            encoded_observation = self.encoder(observation)
-            # print(f"Latent Space Shape: {latent_space.shape}")
-            # print(f"Encoded Observation Shape: {encoded_observation.shape}")
-            hidden = self.relu(self.representation_pre(torch.cat([latent_space, encoded_observation], dim=-1)))
-            posterior_mean, _posterior_std_dev = torch.chunk(self.representation_post(hidden), 2, dim=-1)
-            posterior_std_dev = F.softplus(_posterior_std_dev) + 1e-5
-            cov_matrix = torch.diag_embed(posterior_std_dev**2)
-            sampled_state = torch.distributions.MultivariateNormal(posterior_mean, cov_matrix)
-            posterior_state = sampled_state.rsample()
-
-        reward = self.reward_model(latent_space, posterior_state)
-
-        states = [latent_space, prior_state, prior_mean, prior_std_dev]
-        if observation is not None:
-            states = states + [posterior_state, posterior_mean, posterior_std_dev]
-            decoded_observation = self.decoder(posterior_state)
-            states.append(decoded_observation)
-
-        states.append(reward)
-        return states
+            if observations is not None:
+                t_ = t - 1
+                encoded_observation = self.encoder(observations[t_ + 1].float())
+                hidden = self.relu(self.representation_pre(torch.cat([beliefs[t + 1], encoded_observation], dim=1)))
+                posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.representation_post(hidden), 2, dim=1)
+                posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + 1e-5
+                posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(posterior_means[t + 1])
+                decoded_observations[t] = self.decoder(posterior_states[t + 1])
 
+        hidden = [torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)]
+        if observations is not None:
+            hidden += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0), torch.stack(decoded_observations, dim=0)]
+        return hidden
 
 ## Reward Model as defined by Reward Model qθ(rt | st):
 class RewardModel(nn.Module):
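The rewritten forward no longer performs a single filtering step; it unrolls the recurrence over an entire action sequence and returns each quantity stacked over time. A sketch of the calling convention it implies follows; the dimensions are illustrative assumptions, not values taken from this repository:

    import torch

    # Illustrative sizes only; the real ones come from config.yaml.
    T, B, state_dim, action_dim, latent_dim = 10, 16, 30, 6, 30

    rssm = RSSM(state_dim=state_dim, action_dim=action_dim,
                o_feature_dim=1024, latent_dim=latent_dim, reward_dim=1)

    prev_state = torch.zeros(B, state_dim)       # s_0 for each sequence in the batch
    prev_belief = torch.zeros(B, latent_dim)     # h_0, the deterministic belief
    actions = torch.randn(T - 1, B, action_dim)  # time-major: actions[t] drives step t -> t+1

    # With observations omitted, only the prior quantities come back,
    # each stacked over time with shape (T - 1, B, ...).
    beliefs, prior_states, prior_means, prior_std_devs = rssm(prev_state, actions, prev_belief)

One thing to watch: self.transition_pre is applied both to torch.cat([_state, actions[t]], dim=1) and, a few lines later, to beliefs[t + 1], which only type-checks when latent_dim equals state_dim + action_dim; the upstream models.py this method cites uses a separate linear layer to embed the belief.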
Binary file modified __pycache__/Dreamer.cpython-311.pyc
Binary file modified __pycache__/RSSM.cpython-311.pyc
Binary file modified __pycache__/ReplayBuffer.cpython-311.pyc
Binary file modified __pycache__/conv_env_dec.cpython-311.pyc
1 change: 0 additions & 1 deletion config.yaml
@@ -5,7 +5,6 @@ env:
 dreamer:
   state_dims: 30
   latent_dims: 30
-  observation_dim: [120, 160]
   o_feature_dim: 1024
   reward_dim: 1
   gamma: 0.99
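With observation_dim gone from config.yaml, anything that reads the dreamer section must drop the key as well (main.py below does exactly that). A small sketch of how the trimmed section is consumed, assuming the standard PyYAML loader:

    import yaml

    # Keys as shown in the diff above; observation_dim is intentionally absent.
    with open("config.yaml") as f:
        config = yaml.safe_load(f)

    dreamer_cfg = config["dreamer"]
    assert "observation_dim" not in dreamer_cfg  # removed by this commit
    print(dreamer_cfg["state_dims"], dreamer_cfg["o_feature_dim"])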
1 change: 0 additions & 1 deletion main.py
@@ -61,7 +61,6 @@ def run(configurations, dreamer : Dreamer):
         env=env,
         state_dims=config['dreamer']['state_dims'],
         latent_dims=config['dreamer']['latent_dims'],
-        observation_dim=tuple(config['dreamer']['observation_dim']),
         o_feature_dim=config['dreamer']['o_feature_dim'],
         reward_dim=config['dreamer']['reward_dim'],
         gamma=config['dreamer']['gamma'],
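After this hunk the Dreamer constructor and config.yaml agree again. A hypothetical pattern for keeping them from drifting, not code from this repository, is to pass the whole section through as keyword arguments:

    # Assumes every key under config['dreamer'] matches a Dreamer.__init__
    # parameter; a key removed from the YAML then cannot linger in code.
    dreamer = Dreamer(env=env, **config["dreamer"])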
