Commit
aLadNamedPat committed Aug 31, 2024
2 parents 09dfeb3 + db3d2f6 commit 73bd33a
Showing 4 changed files with 55 additions and 69 deletions.
Dreamer.py (23 changes: 13 additions & 10 deletions)
@@ -82,7 +82,7 @@ def latent_imagine(self, latents, posterior, horizon : int):
         imagined_latent = latents.reshape(x * y, -1)

         action = self.actor(torch.cat([imagined_state, imagined_belief]).to(device=device))
-
+        print(f"Action Reshape {action.reshape(x, y, -1)}")
         latent_list = [imagined_latent]
         state_list = [imagined_state]
         action_list = [action]
@@ -91,10 +91,10 @@ def latent_imagine(self, latents, posterior, horizon : int):
             state = self.RSSM(imagined_state, action, imagined_belief)
             imagined_state, imagined_belief = state[0], state[1]
             action = self.actor(torch.cat([imagined_state, imagined_belief]).to(device=device))
-
+            action.reshape(x, y, -1)
             latent_list.append(imagined_latent)
             state_list.append(imagined_state)
-            action_list.append(action_list)
+            action_list.append(action)


         latent_list = torch.stack(latent_list, dim = 0).to(device = device)
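
Side note (not part of the commit): the latent_imagine loop above collects one tensor per imagined step in plain Python lists and then stacks them into a single tensor with a leading time dimension. A minimal sketch of that pattern, with the sizes below assumed purely for illustration:

    import torch

    horizon, batch, state_dim = 5, 4, 30  # illustrative sizes, not taken from the repository

    # One (batch, state_dim) tensor per imagined step, as in state_list above.
    state_list = [torch.randn(batch, state_dim) for _ in range(horizon + 1)]

    # torch.stack adds a new leading time dimension: (horizon + 1, batch, state_dim).
    states = torch.stack(state_list, dim=0)
    print(states.shape)  # torch.Size([6, 4, 30])
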
@@ -115,9 +115,7 @@ def model_update(self):
         prev_latent_space = torch.zeros((self.batch_size, self.RSSM.latent_dim))

         # Forward pass through the RSSM
-        latent_spaces, prior_states, prior_means, prior_std_devs, \
-        posterior_states, posterior_means, posterior_std_devs, \
-        decoded_observations, rewards = self.RSSM(
+        latent_spaces, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs, decoded_observations, rewards = self.RSSM(
             prev_state,
             actions,
             prev_latent_space,
@@ -210,20 +208,25 @@ def rollout(
             self.num_timesteps += 1
             done = False
             action = self.sample_action(torch.cat([self.prev_state, self.prev_latent_space]))
+            print(f"First Action: {action}")
             timestep = self.env.step(action)
             obs = torch.tensor(self.env.physics.render(camera_id=0, height=120, width=160).copy())
             if (t == self.batch_train_freq - 1):
                 done = True
-            latent_spaces, prior_states, prior_means, prior_std_devs, \
-            posterior_states, posterior_means, posterior_std_devs, \
-            decoded_observations, rewards = self.RSSM(
+            states = self.RSSM(
                 self.prev_state,
                 action,
                 self.prev_latent_space,
                 nonterminals=1-done,
-                observations=obs
+                observation=obs
             )
+
+            print(f"States {states}")
+            if obs is not None:
+                latent_spaces, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs, decoded_observations, rewards = states
+            else:
+                latent_spaces, prior_states, prior_means, prior_std_devs, rewards = states

             self.prev_state = posterior_states
             self.prev_latent_space = latent_spaces

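Side note (not part of the commit): in dm_control, physics.render(camera_id=0, height=120, width=160) returns a (height, width, 3) uint8 NumPy array, which the rollout above wraps directly in torch.tensor. A hedged sketch of the more usual preprocessing before a convolutional encoder, shown only as an assumption about typical usage rather than what this repository does:

    import numpy as np
    import torch

    # Stand-in for env.physics.render(camera_id=0, height=120, width=160)
    frame = np.zeros((120, 160, 3), dtype=np.uint8)

    # Typical conversion: float32, channel-first (C, H, W), scaled to [0, 1].
    obs = torch.tensor(frame.copy(), dtype=torch.float32).permute(2, 0, 1) / 255.0
    print(obs.shape)  # torch.Size([3, 120, 160])
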
RSSM.py (101 changes: 42 additions & 59 deletions)
@@ -34,69 +34,52 @@ def __init__(self, state_dim, action_dim, observation_dim, o_feature_dim, latent
         self.relu = nn.ReLU()

     ## Heavily based off of: https://github.com/zhaoyi11/dreamer-pytorch/blob/master/models.py
-    def forward(self, prev_state, actions, prev_latent_space, nonterminals = None, observations = None):
-        time_steps = actions.shape[0] + 1
-        latent_spaces = [torch.empty(0)] * time_steps
-        prior_states = [torch.empty(0)] * time_steps
-        prior_means = [torch.empty(0)] * time_steps
-        prior_std_devs = [torch.empty(0)] * time_steps
-        posterior_states = [torch.empty(0)] * time_steps
-        posterior_means = [torch.empty(0)] * time_steps
-        posterior_std_devs = [torch.empty(0)] * time_steps
-        rewards = [torch.empty(0)] * time_steps
-        latent_spaces[0] = prev_latent_space
-        prior_states[0] = prev_state
-        posterior_states[0] = prev_state
+    def forward(self, prev_state, action, prev_latent_space, nonterminals = None, observation = None):
+        latent_space = prev_latent_space
+        prior_state = prev_state
+        posterior_state = prev_state

-        for t in range(time_steps - 1):
-            if observations is None:
-                state = prior_states[t]
-            else:
-                state = posterior_states[t]
+        if observation is None:
+            state = prior_state
+        else:
+            state = posterior_state

-            if nonterminals and t != 0:
-                state = state * nonterminals[t-1]
-
-            latent_spaces[t+1] = self.rnn(hidden, latent_spaces[t])
+        if nonterminals:
+            state = state * nonterminals

-            hidden = self.relu(self.transition_pre(torch.cat([state, actions[t]], dim = 1)))
-            prior_means[t+1], _prior_std_dev = torch.chunk(self.transition_post(hidden), 2, dim = 1)
-            prior_std_devs[t+1] = F.softplus(_prior_std_dev)
-            cov_matrix = torch.diag_embed(prior_std_devs[t+1]**2)
-            sampled_state = torch.distributions.MultivariateNormal(prior_means[t+1], cov_matrix).rsample()
-            prior_states[t+1] = sampled_state
-
-            if observations is not None:
-                encoded_observation = self.encoder(observations[t])
-                hidden = self.relu(self.representation_pre(torch.cat([latent_spaces[t+1], encoded_observation], dim=1)))
-                posterior_means[t+1], _posterior_std_dev = torch.chunk(self.representation_post(hidden), 2, dim=1)
-                posterior_std_devs[t+1] = F.softplus(_posterior_std_dev)
-                cov_matrix = torch.diag_embed(posterior_std_devs[t+1]**2)
-                sampled_state = torch.distributions.MultivariateNormal(posterior_means[t+1], cov_matrix).rsample()
-                posterior_states[t+1] = sampled_state
-
-            rewards[t+1] = self.reward_model(latent_spaces[t+1], sampled_state)
-
-        ## Returns the latent spaces, states, means, and standard deviations
-        states = [torch.stack(latent_spaces[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)]
-        if observations:
-            states += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0)]
-            decoded_observations = [self.decoder(state) for state in posterior_states[1:]]
-            states.append(torch.stack(decoded_observations, dim=0))
-
-        states.append(torch.stack(rewards[1:], dim=0))
-        return states
+        latent_space = self.rnn(state, latent_space)
+        action = torch.tensor(action, dtype=torch.float32)
+
+        ## TODO : Do we need to reduce the size of the state
+        print("State: ", state)
+        print("Action: ", action)
+        # state = state.view(-1)
+        hidden = self.relu(self.transition_pre(torch.cat([state, action], dim = -1)))
+        prior_mean, _prior_std_dev = torch.chunk(self.transition_post(hidden), 2, dim = 1)
+        prior_std_dev = F.softplus(_prior_std_dev)
+        cov_matrix = torch.diag_embed(prior_std_dev**2)
+        sampled_state = torch.distributions.MultivariateNormal(prior_mean, cov_matrix).rsample()
+        prior_state = sampled_state
+
+        if observation is not None:
+            encoded_observation = self.encoder(observation)
+            hidden = self.relu(self.representation_pre(torch.cat([latent_space, encoded_observation], dim=1)))
+            posterior_mean, _posterior_std_dev = torch.chunk(self.representation_post(hidden), 2, dim=1)
+            posterior_std_dev = F.softplus(_posterior_std_dev)
+            cov_matrix = torch.diag_embed(posterior_std_dev**2)
+            sampled_state = torch.distributions.MultivariateNormal(posterior_mean, cov_matrix).rsample()
+            posterior_state = sampled_state
+
+        reward = self.reward_model(latent_space, sampled_state)
+
+        states = [latent_space, prior_state, prior_mean, prior_std_dev]
+        if observation:
+            states += [posterior_state, posterior_mean, posterior_std_dev]
+            decoded_observation = self.decoder(posterior_state)
+            states.append(decoded_observation)
+
+        states.append(reward)
+        return states

     def save_model(self,
                    num_steps):
         model_path = f"ModelCheckpoint/RSSM_{num_steps}.pth"
         torch.save(self.state_dict(), model_path)

     def load_model(self,
                    num_steps):
         model_path = f"ModelCheckpoint/RSSM_{num_steps}.pth"
         self.load_state_dict(torch.load(model_path))

 ## Reward Model as defined by Reward Model qθ(rt | st):
 class RewardModel(nn.Module):
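Side note (not part of the commit): with this change, RSSM.forward handles a single timestep and returns a plain list (prior terms always; posterior, decoded-observation, and reward terms only when an observation is passed), so any loop over a sequence now has to live in the caller. A rough sketch of how such a driver might look, where the helper name and tensor shapes are assumptions for illustration rather than code from this repository:

    import torch

    def rollout_rssm(rssm, init_state, init_latent, actions, observations=None):
        # Step the single-step RSSM over a sequence of actions; the old forward did this loop internally.
        state, latent = init_state, init_latent
        outputs = []
        for t in range(actions.shape[0]):
            obs_t = observations[t] if observations is not None else None
            step = rssm(state, actions[t], latent, nonterminals=None, observation=obs_t)
            latent, prior_state = step[0], step[1]
            # With an observation, index 4 is the posterior state; otherwise fall back to the prior.
            state = step[4] if obs_t is not None else prior_state
            outputs.append(step)
        return outputs
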
Binary file modified __pycache__/Dreamer.cpython-311.pyc
Binary file modified __pycache__/RSSM.cpython-311.pyc
