RSSM Converges
diegoPasini committed Dec 23, 2024
1 parent 618ed04 commit d54bbbd
Showing 10 changed files with 310 additions and 81 deletions.
3 changes: 3 additions & 0 deletions Dreamer.py
@@ -146,6 +146,7 @@ def model_update(self):

# import pdb; pdb.set_trace()

# print(f"STATES : {states}")
latent_spaces, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs, decoded_observations, rewards = self.RSSM(
prev_state.to(device),
actions.squeeze().float().to(device),
@@ -328,7 +329,9 @@ def train(

while (self.num_timesteps < timesteps):
# wandb.init(project="dreamer_training", reinit=True)
print(f"In Rollout")
self.rollout()
print(f"Out Rollout")
total_actor_loss = 0
total_critic_loss = 0
total_reward_loss = 0
45 changes: 33 additions & 12 deletions RSSM.py
@@ -1,7 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from conv_env_dec import ConvDecoder, ConvEncoder, compute_encoder_output_size
from conv_env_dec import ConvDecoder, ConvEncoder
import matplotlib.pyplot as plt

class RSSM(nn.Module):

@@ -21,26 +22,33 @@ def __init__(self, state_dim, action_dim, o_feature_dim, o_dim, latent_dim, rewa
self.state_dim = state_dim
self.latent_dim = latent_dim
self.o_feature_dim = o_feature_dim
self.reward_dim = reward_dim
self.o_dim = o_dim
self.o_feature_dim = o_feature_dim
self.state_dim = state_dim
self.latent_dim = latent_dim
self.reward_dim = reward_dim

self.encoder = ConvEncoder(self.o_feature_dim, self.latent_dim)
self.decoder = ConvDecoder(self.o_feature_dim, latent_size=self.latent_dim * 2, shape=(o_dim[0], o_dim[1], 3))
self.rnn = nn.GRUCell(input_size=self.latent_dim, hidden_size=self.latent_dim)

self.encoder = ConvEncoder(self.o_feature_dim)
self.decoder = ConvDecoder(latent_dim * 2, shape = (o_dim[0], o_dim[1], 3))
self.rnn = nn.GRUCell(input_size=latent_dim, hidden_size=latent_dim)

self.reward_model = RewardModel(latent_dim, state_dim, reward_dim)
action_dim = action_dim.shape[0]
self.transition_pre = nn.Linear(state_dim + action_dim, latent_dim)
self.transition_post = nn.Linear(latent_dim, 2 * state_dim)
self.representation_pre = nn.Linear((latent_dim + compute_encoder_output_size((1, 1, o_dim[0], o_dim[1], 3), self.encoder)), latent_dim)
self.representation_post = nn.Linear(latent_dim, 2 * state_dim)
self.reward_model = RewardModel(self.latent_dim, self.state_dim, self.reward_dim)
self.action_dim = action_dim.shape[0]
self.transition_pre = nn.Linear(self.state_dim + self.action_dim, self.latent_dim)
self.transition_post = nn.Linear(self.latent_dim, 2 * self.state_dim)
self.representation_pre = nn.Linear(self.latent_dim + self.o_feature_dim, self.latent_dim)
self.representation_post = nn.Linear(self.latent_dim, 2 * self.state_dim)
self.relu = nn.ReLU()

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def forward(self, prev_state, actions, prev_belief, observations = None, nonterminals = None):
prev_state = prev_state.to(self.device)
actions = actions.to(self.device)
if observations is not None:
observations = observations.to(self.device)

# print(f"observations : {observations}")
encoded_observation = self.encoder(observations.float())
T = actions.size(1) + 1
batch_size = actions.size(0)
@@ -98,6 +106,19 @@ def forward(self, prev_state, actions, prev_belief, observations = None, nonterm
hidden = [beliefs[:, 1:], prior_states[:, 1:], prior_means[:, 1:], prior_std_devs[:, 1:]]
if observations is not None:
hidden += [posterior_states[:, 1:], posterior_means[:, 1:], posterior_std_devs[:, 1:], decoded_observations, rewards]

# Plot the observation input and output
# fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# axes[0].imshow(observations[0, -1].cpu().detach())
# axes[0].axis('off')
# axes[0].set_title("Last Input Observation")

# axes[1].imshow(decoded_observations[0, -1].cpu().detach())
# axes[1].axis('off')
# axes[1].set_title("Last Decoded Observation")

# plt.show()
return hidden

## Reward Model as defined by Reward Model qθ(rt | st):
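The shape-sensitive change in this file: representation_pre is now nn.Linear(latent_dim + o_feature_dim, latent_dim), because the rewritten ConvEncoder emits a fixed-size latent vector rather than a raw conv feature map, so the compute_encoder_output_size helper and its import can be dropped. Below is a minimal sketch of the update these layers wire together; the layer names come from the diff and the dimensions from config.yaml, but the exact data flow (and the assumed action_dim of 6) is an illustration, not the repo's forward() internals.

import torch
import torch.nn as nn
import torch.nn.functional as F

latent_dim, state_dim, action_dim, o_feature_dim, batch = 64, 64, 6, 64, 8

rnn = nn.GRUCell(input_size=latent_dim, hidden_size=latent_dim)
transition_pre = nn.Linear(state_dim + action_dim, latent_dim)
transition_post = nn.Linear(latent_dim, 2 * state_dim)
representation_pre = nn.Linear(latent_dim + o_feature_dim, latent_dim)
representation_post = nn.Linear(latent_dim, 2 * state_dim)

state = torch.zeros(batch, state_dim)             # s_{t-1}
belief = torch.zeros(batch, latent_dim)           # deterministic h_{t-1}
action = torch.zeros(batch, action_dim)           # a_{t-1}
encoded_obs = torch.zeros(batch, o_feature_dim)   # e_t from ConvEncoder

# Deterministic path: h_t from (s_{t-1}, a_{t-1}) and h_{t-1}
belief = rnn(F.relu(transition_pre(torch.cat([state, action], dim=-1))), belief)

# Prior p(s_t | h_t): split the 2 * state_dim output into mean and pre-activation std
prior_mean, prior_std = torch.chunk(transition_post(belief), 2, dim=-1)

# Posterior q(s_t | h_t, o_t): the concat is latent_dim + o_feature_dim wide,
# exactly the new representation_pre in_features
hidden = F.relu(representation_pre(torch.cat([belief, encoded_obs], dim=-1)))
post_mean, post_std = torch.chunk(representation_post(hidden), 2, dim=-1)
# A real implementation constrains std to be positive, e.g. F.softplus(post_std).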
26 changes: 21 additions & 5 deletions ReplayBuffer.py
@@ -43,14 +43,19 @@ def sample(self, batch_size : int, data_length : int, random_flag : bool = False
# else:

sampled_indices = [self.sample_idx(data_length) for _ in range(batch_size)]
print(f"Batch Size : {batch_size}")
print(f"Data Length : {data_length}")
# print(f"Batch Size : {batch_size}")
# print(f"Data Length : {data_length}")
# print(f"Buffer : {self.buffer}")
batch_states = torch.zeros((batch_size, data_length) + self.buffer[0][0].shape).to(self.device)
batch_actions = torch.zeros((batch_size, data_length) + self.buffer[0][1].shape).to(self.device)
batch_rewards = torch.zeros((batch_size, data_length) + self.buffer[0][2].shape).to(self.device)
batch_next_states = torch.zeros((batch_size, data_length) + self.buffer[0][3].shape).to(self.device)
batch_dones = torch.zeros((batch_size, data_length) + self.buffer[0][4].shape).to(self.device)
# print(f"Batch States Shape: {batch_states.shape}")
# print(f"Batch Actions Shape: {batch_actions.shape}")
# print(f"Batch Rewards Shape: {batch_rewards.shape}")
# print(f"Batch Next States Shape: {batch_next_states.shape}")
# print(f"Batch Dones Shape: {batch_dones.shape}")

# batch_states = torch.zeros((batch_size, data_length))
# batch_actions = torch.zeros((batch_size, data_length))
@@ -65,20 +70,31 @@
# print(batch_dones.shape)
# print(self.curr_idx)
# print(sampled_indices)

# print(f"Sampled Indices: {sampled_indices}")
for i, idxs in enumerate(sampled_indices):
idx_sequence_states = torch.stack([self.buffer[idx][0] for idx in idxs])
# print(f"Sampling index: {i}")
idx_sequence_states = torch.stack([self.buffer[idx][0] / 255.0 for idx in idxs])
idx_sequence_actions = torch.stack([self.buffer[idx][1] for idx in idxs])
idx_sequence_rewards = torch.stack([self.buffer[idx][2] for idx in idxs])
idx_sequence_next_states = torch.stack([self.buffer[idx][3] for idx in idxs])
idx_sequence_dones = torch.stack([self.buffer[idx][4] for idx in idxs])


# print(f"States shape: {idx_sequence_states.shape}")
# print(f"Actions shape: {idx_sequence_actions.shape}")
# print(f"Rewards shape: {idx_sequence_rewards.shape}")
# print(f"Next States shape: {idx_sequence_next_states.shape}")
# print(f"Dones shape: {idx_sequence_dones.shape}")
# print(f"idx_sequence_states: {idx_sequence_states}")
# print(f"Rewards : {idx_sequence_rewards}")
batch_states[i] = idx_sequence_states

batch_actions[i] = idx_sequence_actions
batch_rewards[i] = idx_sequence_rewards
batch_next_states[i] = idx_sequence_next_states
batch_dones[i] = idx_sequence_dones

return batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones
return batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones

def get_size(self):
return len(self.buffer)
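The functional change in sampling is the division by 255.0: observations are stored as raw 0-255 pixel tensors but leave the buffer scaled to [0, 1], matching the sigmoid output head the decoder gains below. A hedged usage sketch (the populated buffer and its 64x64x3 observations are assumptions, not a test from the repo):

# Assumes `buffer` is a ReplayBuffer already filled with 64x64x3 frames in 0-255 range.
states, actions, rewards, next_states, dones = buffer.sample(batch_size=8, data_length=10)

assert states.shape[:2] == (8, 10)                  # (batch, sequence, *obs_shape)
assert 0.0 <= states.min() <= states.max() <= 1.0   # pixels now normalized for the encoder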
Binary file modified __pycache__/Dreamer.cpython-311.pyc
Binary file modified __pycache__/RSSM.cpython-311.pyc
Binary file modified __pycache__/ReplayBuffer.cpython-311.pyc
Binary file modified __pycache__/conv_env_dec.cpython-311.pyc
8 changes: 4 additions & 4 deletions config.yaml
@@ -3,9 +3,9 @@ env:
task_name: "run"

dreamer:
state_dims: 512
latent_dims: 512
o_feature_dim: 32
state_dims: 64
latent_dims: 64
o_feature_dim: 64
img_h : 64
img_w : 64
reward_dim: 1
@@ -14,7 +14,7 @@ dreamer:
batch_size: 8
batch_train_freq: 10
buffer_size: 100000000
sample_steps: 50
sample_steps: 10
steps_of_sampling: 500
horizon: 10

94 changes: 34 additions & 60 deletions conv_env_dec.py
@@ -3,24 +3,23 @@
import torch.distributions as dist
import torch.nn.functional as F

def compute_encoder_output_size(input_shape, encoder):
dummy_input = torch.zeros(input_shape).to(encoder.device)
with torch.no_grad():
output = encoder(dummy_input)
return output.shape[-1]

## VAE Architecture comes from Ha's World Model Paper https://arxiv.org/pdf/1803.10122

class ConvEncoder(nn.Module):
def __init__(self, depth=32, act=nn.ReLU()):
def __init__(self, depth=32, latent_size=32, act=nn.ReLU()):
super().__init__()
self.depth = depth
self.act = act
self.latent_size = latent_size
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.conv1 = nn.Conv2d(in_channels=3, out_channels=1*depth, kernel_size=4, stride=2).to(self.device)
self.conv2 = nn.Conv2d(in_channels=1*depth, out_channels=2*depth, kernel_size=4, stride=2).to(self.device)
self.conv3 = nn.Conv2d(in_channels=2*depth, out_channels=4*depth, kernel_size=4, stride=2).to(self.device)
self.conv4 = nn.Conv2d(in_channels=4*depth, out_channels=8*depth, kernel_size=4, stride=2).to(self.device)

self.fc_mu = nn.Linear(8*depth*2*2, latent_size).to(self.device)
self.fc_logvar = nn.Linear(8*depth*2*2, latent_size).to(self.device)

def forward(self, obs):
obs = obs.to(self.device)
@@ -31,49 +30,34 @@ def forward(self, obs):
x = self.act(self.conv2(x))
x = self.act(self.conv3(x))
x = self.act(self.conv4(x))
x = x.reshape(x.size(0), -1)
x = x.view(B, T, -1)
return x
x = x.view(x.size(0), -1)
mu = self.fc_mu(x)
logvar = self.fc_logvar(x)

# Latent Vector Z is sampled from Gaussian prior
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std

z = z.view(B, T, -1)
print("Shape of the encoded observation:", z.shape)
return z
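The encoder is now a full VAE head: four stride-2 convolutions take a 64x64 RGB frame down to an (8*depth, 2, 2) feature map, fc_mu and fc_logvar read the flattened features, and z is drawn with the reparameterization trick z = mu + eps * sigma with eps ~ N(0, I), which keeps the sample differentiable with respect to mu and sigma. The arithmetic behind the hard-coded 8*depth*2*2 in_features, as a sketch assuming the 64x64 input from config.yaml:

# Conv2d output size with no padding: out = (in - kernel) // stride + 1
def conv_out(size, kernel=4, stride=2):
    return (size - kernel) // stride + 1

size = 64
for _ in range(4):   # conv1..conv4 all use kernel_size=4, stride=2
    size = conv_out(size)
print(size)  # 64 -> 31 -> 14 -> 6 -> 2, so fc in_features = 8 * depth * 2 * 2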

class ConvDecoder(nn.Module):
def calculate_flattened_size(self, channels, height, width):
return channels * height * width


def __init__(self, depth=32, act=nn.ReLU(), shape=(128, 192, 3)):
def __init__(self, depth=32, latent_size=32, act=nn.ReLU(), shape=(64, 64, 3)):
super().__init__()
self.depth = depth
self.latent_size = latent_size
self.act = act
self.out_height, self.out_width, self.out_channels = shape
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.fc = nn.Linear(depth, 32 * depth).to(self.device)

### Hard Coded, should not be ###
self.projection_layer = nn.Linear(12288, self.out_height * self.out_width * self.out_channels).to(self.device)
self.deconv1 = nn.ConvTranspose2d(
in_channels=32*depth,
out_channels=4*depth,
kernel_size=5,
stride=2
)
self.deconv2 = nn.ConvTranspose2d(
in_channels=4*depth,
out_channels=2*depth,
kernel_size=5,
stride=2
)
self.deconv3 = nn.ConvTranspose2d(
in_channels=2*depth,
out_channels=1*depth,
kernel_size=6,
stride=2
)
self.deconv4 = nn.ConvTranspose2d(
in_channels=1*depth,
out_channels=self.out_channels,
kernel_size=6,
stride=2
)

self.fc = nn.Linear(latent_size, 8 * depth * 1 * 1).to(self.device)
self.deconv1 = nn.ConvTranspose2d(in_channels=8 * depth, out_channels=8 * depth, kernel_size=5, stride=2).to(self.device)
self.deconv2 = nn.ConvTranspose2d(in_channels=8 * depth, out_channels=4 * depth, kernel_size=5, stride=2).to(self.device)
self.deconv3 = nn.ConvTranspose2d(in_channels=4 * depth, out_channels=2 * depth, kernel_size=6, stride=2).to(self.device)
self.deconv4 = nn.ConvTranspose2d(in_channels=2 * depth, out_channels=self.out_channels, kernel_size=6, stride=2).to(self.device)
self.sigmoid = nn.Sigmoid()

def forward(self, features):
features = features.to(self.device)
@@ -83,21 +67,11 @@ def forward(self, features):
x = self.fc(features)
x = self.act(x)

x = x.view(B * T, 32 * self.depth, 1, 1)
x = self.deconv1(x)
x = self.act(x)
x = self.deconv2(x)
x = self.act(x)
x = self.deconv3(x)
x = self.act(x)
x = self.deconv4(x)

x = x.reshape(x.size(0), -1)
x = self.projection_layer(x)
x = x.view(B, T, self.out_height, self.out_width, self.out_channels)
x = x.view(B * T, 8 * self.depth, 1, 1)
x = self.act(self.deconv1(x))
x = self.act(self.deconv2(x))
x = self.act(self.deconv3(x))
x = self.sigmoid(self.deconv4(x))

mean = x
normal_dist = dist.Normal(loc=mean, scale=1.0)
out_dist = dist.Independent(normal_dist, reinterpreted_batch_ndims=3)
sample = out_dist.sample()
return sample
x = x.view(B, T, self.out_height, self.out_width, self.out_channels)
return x
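This rewrite is also why the "Hard Coded, should not be" projection_layer could be deleted: reshaped to (8*depth, 1, 1), the latent passes through four transposed convolutions that land exactly on 64x64, and the sigmoid bounds the output to [0, 1] to match the normalized inputs, replacing the earlier Normal-distribution sampling. The shape check, as a sketch:

# ConvTranspose2d output size with no padding: out = (in - 1) * stride + kernel
def deconv_out(size, kernel, stride=2):
    return (size - 1) * stride + kernel

size = 1
for kernel in (5, 5, 6, 6):   # deconv1..deconv4 above
    size = deconv_out(size, kernel)
print(size)  # 1 -> 5 -> 13 -> 30 -> 64, matching shape=(64, 64, 3) with no extra projection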