diff --git a/Dreamer.py b/Dreamer.py index 32134c9..e82a5e3 100644 --- a/Dreamer.py +++ b/Dreamer.py @@ -146,6 +146,7 @@ def model_update(self): # import pdb; pdb.set_trace() + # print(f"STATES : {states}") latent_spaces, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs, decoded_observations, rewards = self.RSSM( prev_state.to(device), actions.squeeze().float().to(device), @@ -328,7 +329,9 @@ def train( while (self.num_timesteps < timesteps): # wandb.init(project="dreamer_training", reinit=True) + print(f"In Rollout") self.rollout() + print(f"Out Rollout") total_actor_loss = 0 total_critic_loss = 0 total_reward_loss = 0 diff --git a/RSSM.py b/RSSM.py index f45f2f8..cae8262 100644 --- a/RSSM.py +++ b/RSSM.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from conv_env_dec import ConvDecoder, ConvEncoder, compute_encoder_output_size +from conv_env_dec import ConvDecoder, ConvEncoder +import matplotlib.pyplot as plt class RSSM(nn.Module): @@ -21,18 +22,25 @@ def __init__(self, state_dim, action_dim, o_feature_dim, o_dim, latent_dim, rewa self.state_dim = state_dim self.latent_dim = latent_dim self.o_feature_dim = o_feature_dim + self.reward_dim = reward_dim + self.o_dim = o_dim + self.o_feature_dim = o_feature_dim + self.state_dim = state_dim + self.latent_dim = latent_dim + self.reward_dim = reward_dim + + self.encoder = ConvEncoder(self.o_feature_dim, self.latent_dim) + self.decoder = ConvDecoder(self.o_feature_dim, latent_size=self.latent_dim * 2, shape=(o_dim[0], o_dim[1], 3)) + self.rnn = nn.GRUCell(input_size=self.latent_dim, hidden_size=self.latent_dim) - self.encoder = ConvEncoder(self.o_feature_dim) - self.decoder = ConvDecoder(latent_dim * 2, shape = (o_dim[0], o_dim[1], 3)) - self.rnn = nn.GRUCell(input_size=latent_dim, hidden_size=latent_dim) - - self.reward_model = RewardModel(latent_dim, state_dim, reward_dim) - action_dim = action_dim.shape[0] - self.transition_pre = nn.Linear(state_dim + action_dim, latent_dim) - self.transition_post = nn.Linear(latent_dim, 2 * state_dim) - self.representation_pre = nn.Linear((latent_dim + compute_encoder_output_size((1, 1, o_dim[0], o_dim[1], 3), self.encoder)), latent_dim) - self.representation_post = nn.Linear(latent_dim, 2 * state_dim) + self.reward_model = RewardModel(self.latent_dim, self.state_dim, self.reward_dim) + self.action_dim = action_dim.shape[0] + self.transition_pre = nn.Linear(self.state_dim + self.action_dim, self.latent_dim) + self.transition_post = nn.Linear(self.latent_dim, 2 * self.state_dim) + self.representation_pre = nn.Linear(self.latent_dim + self.o_feature_dim, self.latent_dim) + self.representation_post = nn.Linear(self.latent_dim, 2 * self.state_dim) self.relu = nn.ReLU() + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def forward(self, prev_state, actions, prev_belief, observations = None, nonterminals = None): @@ -40,7 +48,7 @@ def forward(self, prev_state, actions, prev_belief, observations = None, nonterm actions = actions.to(self.device) if observations is not None: observations = observations.to(self.device) - + # print(f"observations : {observations}") encoded_observation = self.encoder(observations.float()) T = actions.size(1) + 1 batch_size = actions.size(0) @@ -98,6 +106,19 @@ def forward(self, prev_state, actions, prev_belief, observations = None, nonterm hidden = [beliefs[:, 1:], prior_states[:, 1:], prior_means[:, 1:], prior_std_devs[:, 1:]] if observations is not None: hidden += [posterior_states[:, 1:], posterior_means[:, 1:], posterior_std_devs[:, 1:], decoded_observations, rewards] + + # Plot the observation input and output + # fig, axes = plt.subplots(1, 2, figsize=(10, 5)) + + # axes[0].imshow(observations[0, -1].cpu().detach()) + # axes[0].axis('off') + # axes[0].set_title("Last Input Observation") + + # axes[1].imshow(decoded_observations[0, -1].cpu().detach()) + # axes[1].axis('off') + # axes[1].set_title("Last Decoded Observation") + + # plt.show() return hidden ## Reward Model as defined by Reward Model qθ(rt | st): diff --git a/ReplayBuffer.py b/ReplayBuffer.py index d706bd3..2b71e0b 100644 --- a/ReplayBuffer.py +++ b/ReplayBuffer.py @@ -43,14 +43,19 @@ def sample(self, batch_size : int, data_length : int, random_flag : bool = False # else: sampled_indices = [self.sample_idx(data_length) for _ in range(batch_size)] - print(f"Batch Size : {batch_size}") - print(f"Data Length : {data_length}") + # print(f"Batch Size : {batch_size}") + # print(f"Data Length : {data_length}") # print(f"Buffer : {self.buffer}") batch_states = torch.zeros((batch_size, data_length) + self.buffer[0][0].shape).to(self.device) batch_actions = torch.zeros((batch_size, data_length) + self.buffer[0][1].shape).to(self.device) batch_rewards = torch.zeros((batch_size, data_length) + self.buffer[0][2].shape).to(self.device) batch_next_states = torch.zeros((batch_size, data_length) + self.buffer[0][3].shape).to(self.device) batch_dones = torch.zeros((batch_size, data_length) + self.buffer[0][4].shape).to(self.device) + # print(f"Batch States Shape: {batch_states.shape}") + # print(f"Batch Actions Shape: {batch_actions.shape}") + # print(f"Batch Rewards Shape: {batch_rewards.shape}") + # print(f"Batch Next States Shape: {batch_next_states.shape}") + # print(f"Batch Dones Shape: {batch_dones.shape}") # batch_states = torch.zeros((batch_size, data_length)) # batch_actions = torch.zeros((batch_size, data_length)) @@ -65,20 +70,31 @@ def sample(self, batch_size : int, data_length : int, random_flag : bool = False # print(batch_dones.shape) # print(self.curr_idx) # print(sampled_indices) + + # print(f"Sampled Indices: {sampled_indices}") for i, idxs in enumerate(sampled_indices): - idx_sequence_states = torch.stack([self.buffer[idx][0] for idx in idxs]) + # print(f"Sampling index: {i}") + idx_sequence_states = torch.stack([self.buffer[idx][0] / 255.0 for idx in idxs]) idx_sequence_actions = torch.stack([self.buffer[idx][1] for idx in idxs]) idx_sequence_rewards = torch.stack([self.buffer[idx][2] for idx in idxs]) idx_sequence_next_states = torch.stack([self.buffer[idx][3] for idx in idxs]) idx_sequence_dones = torch.stack([self.buffer[idx][4] for idx in idxs]) - + + # print(f"States shape: {idx_sequence_states.shape}") + # print(f"Actions shape: {idx_sequence_actions.shape}") + # print(f"Rewards shape: {idx_sequence_rewards.shape}") + # print(f"Next States shape: {idx_sequence_next_states.shape}") + # print(f"Dones shape: {idx_sequence_dones.shape}") + # print(f"idx_sequence_states: {idx_sequence_states}") + # print(f"Rewards : {idx_sequence_rewards}") batch_states[i] = idx_sequence_states + batch_actions[i] = idx_sequence_actions batch_rewards[i] = idx_sequence_rewards batch_next_states[i] = idx_sequence_next_states batch_dones[i] = idx_sequence_dones - return batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones + return batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones def get_size(self): return len(self.buffer) diff --git a/__pycache__/Dreamer.cpython-311.pyc b/__pycache__/Dreamer.cpython-311.pyc index 46dc622..4a55e5e 100644 Binary files a/__pycache__/Dreamer.cpython-311.pyc and b/__pycache__/Dreamer.cpython-311.pyc differ diff --git a/__pycache__/RSSM.cpython-311.pyc b/__pycache__/RSSM.cpython-311.pyc index 2433020..46406c6 100644 Binary files a/__pycache__/RSSM.cpython-311.pyc and b/__pycache__/RSSM.cpython-311.pyc differ diff --git a/__pycache__/ReplayBuffer.cpython-311.pyc b/__pycache__/ReplayBuffer.cpython-311.pyc index 7d151a3..ffd6bf3 100644 Binary files a/__pycache__/ReplayBuffer.cpython-311.pyc and b/__pycache__/ReplayBuffer.cpython-311.pyc differ diff --git a/__pycache__/conv_env_dec.cpython-311.pyc b/__pycache__/conv_env_dec.cpython-311.pyc index ab2e3ea..0aabe82 100644 Binary files a/__pycache__/conv_env_dec.cpython-311.pyc and b/__pycache__/conv_env_dec.cpython-311.pyc differ diff --git a/config.yaml b/config.yaml index b043ee8..88a79c7 100644 --- a/config.yaml +++ b/config.yaml @@ -3,9 +3,9 @@ env: task_name: "run" dreamer: - state_dims: 512 - latent_dims: 512 - o_feature_dim: 32 + state_dims: 64 + latent_dims: 64 + o_feature_dim: 64 img_h : 64 img_w : 64 reward_dim: 1 @@ -14,7 +14,7 @@ dreamer: batch_size: 8 batch_train_freq: 10 buffer_size: 100000000 - sample_steps: 50 + sample_steps: 10 steps_of_sampling: 500 horizon: 10 diff --git a/conv_env_dec.py b/conv_env_dec.py index f605034..1e402db 100644 --- a/conv_env_dec.py +++ b/conv_env_dec.py @@ -3,24 +3,23 @@ import torch.distributions as dist import torch.nn.functional as F -def compute_encoder_output_size(input_shape, encoder): - dummy_input = torch.zeros(input_shape).to(encoder.device) - with torch.no_grad(): - output = encoder(dummy_input) - return output.shape[-1] + +## VAE Architecture comes form Ha's World Model Paper https://arxiv.org/pdf/1803.10122 class ConvEncoder(nn.Module): - def __init__(self, depth=32, act=nn.ReLU()): + def __init__(self, depth=32, latent_size=32, act=nn.ReLU()): super().__init__() self.depth = depth self.act = act + self.latent_size = latent_size self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.conv1 = nn.Conv2d(in_channels=3, out_channels=1*depth, kernel_size=4, stride=2).to(self.device) self.conv2 = nn.Conv2d(in_channels=1*depth, out_channels=2*depth, kernel_size=4, stride=2).to(self.device) self.conv3 = nn.Conv2d(in_channels=2*depth, out_channels=4*depth, kernel_size=4, stride=2).to(self.device) self.conv4 = nn.Conv2d(in_channels=4*depth, out_channels=8*depth, kernel_size=4, stride=2).to(self.device) - + self.fc_mu = nn.Linear(8*depth*2*2, latent_size).to(self.device) + self.fc_logvar = nn.Linear(8*depth*2*2, latent_size).to(self.device) def forward(self, obs): obs = obs.to(self.device) @@ -31,49 +30,34 @@ def forward(self, obs): x = self.act(self.conv2(x)) x = self.act(self.conv3(x)) x = self.act(self.conv4(x)) - x = x.reshape(x.size(0), -1) - x = x.view(B, T, -1) - return x + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_logvar(x) + + # Latent Vector Z is sampled from Gaussian prior + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + z = mu + eps * std + + z = z.view(B, T, -1) + print("Shape of the encoded observation:", z.shape) + return z class ConvDecoder(nn.Module): - def calculate_flattened_size(self, channels, height, width): - return channels * height * width - - - def __init__(self, depth=32, act=nn.ReLU(), shape=(128, 192, 3)): + def __init__(self, depth=32, latent_size=32, act=nn.ReLU(), shape=(64, 64, 3)): super().__init__() self.depth = depth + self.latent_size = latent_size self.act = act self.out_height, self.out_width, self.out_channels = shape self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.fc = nn.Linear(depth, 32 * depth).to(self.device) - - ### Hard Coded, should not be ### - self.projection_layer = nn.Linear(12288, self.out_height * self.out_width * self.out_channels).to(self.device) - self.deconv1 = nn.ConvTranspose2d( - in_channels=32*depth, - out_channels=4*depth, - kernel_size=5, - stride=2 - ) - self.deconv2 = nn.ConvTranspose2d( - in_channels=4*depth, - out_channels=2*depth, - kernel_size=5, - stride=2 - ) - self.deconv3 = nn.ConvTranspose2d( - in_channels=2*depth, - out_channels=1*depth, - kernel_size=6, - stride=2 - ) - self.deconv4 = nn.ConvTranspose2d( - in_channels=1*depth, - out_channels=self.out_channels, - kernel_size=6, - stride=2 - ) + + self.fc = nn.Linear(latent_size, 8 * depth * 1 * 1).to(self.device) + self.deconv1 = nn.ConvTranspose2d(in_channels=8 * depth, out_channels=8 * depth, kernel_size=5, stride=2).to(self.device) + self.deconv2 = nn.ConvTranspose2d(in_channels=8 * depth, out_channels=4 * depth, kernel_size=5, stride=2).to(self.device) + self.deconv3 = nn.ConvTranspose2d(in_channels=4 * depth, out_channels=2 * depth, kernel_size=6, stride=2).to(self.device) + self.deconv4 = nn.ConvTranspose2d(in_channels=2 * depth, out_channels=self.out_channels, kernel_size=6, stride=2).to(self.device) + self.sigmoid = nn.Sigmoid() def forward(self, features): features = features.to(self.device) @@ -83,21 +67,11 @@ def forward(self, features): x = self.fc(features) x = self.act(x) - x = x.view(B * T, 32 * self.depth, 1, 1) - x = self.deconv1(x) - x = self.act(x) - x = self.deconv2(x) - x = self.act(x) - x = self.deconv3(x) - x = self.act(x) - x = self.deconv4(x) - - x = x.reshape(x.size(0), -1) - x = self.projection_layer(x) - x = x.view(B, T, self.out_height, self.out_width, self.out_channels) + x = x.view(B * T, 8 * self.depth, 1, 1) + x = self.act(self.deconv1(x)) + x = self.act(self.deconv2(x)) + x = self.act(self.deconv3(x)) + x = self.sigmoid(self.deconv4(x)) - mean = x - normal_dist = dist.Normal(loc=mean, scale=1.0) - out_dist = dist.Independent(normal_dist, reinterpreted_batch_ndims=3) - sample = out_dist.sample() - return sample \ No newline at end of file + x = x.view(B, T, self.out_height, self.out_width, self.out_channels) + return x diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f63ef95 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,215 @@ +aesara==2.9.4 +aiocontextvars==0.2.2 +amdsmi==6.2.4 +apache_beam==2.61.0 +AppKit==0.2.8 +argh==0.31.3 +array_api_strict==2.2 +asttokens==2.0.5 +astunparse==1.6.3 +attr==0.3.2 +azure_storage==0.37.0 +BeautifulSoup==3.2.2 +beautifulsoup4==4.12.3 +billiard==4.2.1 +blinker==1.6.2 +botocore==1.31.64 +Brotli==1.0.9 +brotlicffi==1.1.0.0 +cairocffi==1.7.1 +Carbon==1.1.10 +channels==4.2.0 +chardet==4.0.0 +chex==0.1.87 +cloudpickle==2.2.1 +colabtools==0.0.1 +colorama==0.4.6 +conda==4.3.16 +coremltools==8.1 +cryptography==42.0.2 +ctypes_snappy==1.03 +cyordereddict==1.0.0 +Cython==3.0.11 +dataclass_array==1.5.2 +deeplearning==0.0.1 +defcon==0.10.3 +defusedxml==0.7.1 +Dialog==0.2.0 +dill==0.3.7 +djangorestframework==3.15.2 +dl==0.1.0 +docutils==0.18.1 +docutils==0.18.1 +email_validator==2.2.0 +etcd==2.0.8 +eval_type_backport==0.2.2 +eventlet==0.38.2 +expecttest==0.3.0 +fastcore==1.7.27 +fastparquet==2024.11.0 +fbscribelogger==0.1.7 +fiddle==0.3.0 +flask_login==0.6.3 +flax==0.10.0 +flint==0.0.1 +foo==.1 +freetype_py==2.5.1 +fs==2.4.16 +gin==0.1.006 +gitdb_speedups==0.1.0 +glfw_preview==0.0.3 +glyphsLib==6.9.5 +gmpy2==2.1.2 +grain==0.2.2 +graphviz==0.20.3 +greenlet==3.0.1 +grpc_tools==1.0.0 +gymnasium==1.0.0 +h2==4.1.0 +h5py==3.9.0 +html5lib==1.1 +httpcore==1.0.7 +hypothesis==6.122.6 +imageio==2.33.1 +ini2toml==0.15 +ipykernel==6.29.5 +ipython==8.12.3 +ipywidgets==7.6.5 +isal==1.7.1 +jax==0.4.35 +jaxtyping==0.2.36 +jnius==1.1.0 +joblib==1.2.0 +kerchunk==0.2.7 +kombu==5.4.2 +kubernetes_asyncio==32.0.0 +langchain_core==0.3.28 +ldclient==0.0.1 +lxml_html_clean==0.4.1 +lz4==4.3.2 +lzmaffi==0.3.0 +matchpy==0.5.5 +mediapy==1.2.2 +ml_collections==1.0.0 +mock==5.1.0 +monkeytype==23.3.0 +moviepy==2.1.1 +mtrand==0.1 +multipart==1.2.1 +multiset==3.2.0 +munkres==1.1.4 +nbformat==5.10.4 +numarray==1.5.1 +Numeric==24.2 +olefile==0.47 +onnxscript==0.1.0.dev20241222 +opencv_python==4.10.0.84 +optax==0.2.4 +optree==0.13.1 +orbax==0.1.9 +orjson==3.10.12 +pandas==2.2.3 +panel==1.3.8 +paramiko==3.5.0 +pickle5==0.0.12 +pip==23.3.1 +pkg1==0.0.3 +polars==1.17.1 +pooch==1.8.2 +prometheus_client==0.14.1 +py4j==0.10.9.8 +pyarrow==14.0.2 +pycolmap==3.11.1 +pycosat==0.6.6 +pydot==3.0.3 +pydy==0.7.1 +pygit2==1.16.0 +pyglet==2.0.20 +pygraphviz==1.14 +PyInstaller==6.11.1 +pyinstrument==5.0.0 +pymc==5.19.1 +pymc3==3.11.6 +pyobjc_framework_Cocoa==10.3.1 +pyobjc_framework_FSEvents==9.0 +pyodide==0.0.2 +pyOpenSSL==24.0.0 +pyOpenSSL==24.3.0 +PyQt5==5.15.11 +PyQt5_sip==12.13.0 +PyQt6==6.8.0 +pysat==3.2.1 +PySide2==5.15.2.1 +PySide6==6.8.1.1 +pyspark==3.5.4 +pytest==8.2.2 +pytest_fail_slow==0.6.0 +pytest_timeout==2.3.1 +python-dotenv==1.0.1 +python_multipart==0.0.20 +pytorch_lightning==2.5.0.post0 +pytz==2023.3.post1 +quart_auth==0.10.1 +Quartz==0.0.1.dev0 +railroad==0.5.0 +rdkit==2024.3.6 +rediscluster==0.5.3 +reportlab==4.2.5 +requests_kerberos==0.15.0 +Res==0.1.7 +safetensors==0.4.5 +sage==0.0.0 +scikits.talkbox==0.2.5 +scipy_doctest==1.6 +seaborn==0.13.2 +select_backport==0.2 +selenium==4.27.1 +setuptools_scm==8.1.0 +shiboken2==5.15.2.1 +shiboken6==6.8.1.1 +simple_parsing==0.1.6 +simplejson==3.19.3 +sip==6.7.12 +slack_sdk==3.34.0 +smbprotocol==1.15.0 +Sphinx==5.0.2 +sphinx_gallery==0.18.0 +stable_baselines3==2.4.0 +sunds==0.4.1 +symengine==0.13.0 +tabulate==0.9.0 +tenacity==8.2.2 +tensorflow_datasets==4.9.7 +theano==1.0.5 +threadpoolctl==3.5.0 +tiktoken==0.7.0 +toml==0.10.2 +tomli_w==1.1.0 +torch_xla==2.5.1 +torcharrow==0.1.0 +torchaudio==2.5.1 +torchrec==1.0.0 +torchtext==0.18.0 +torchvision==0.18.1 +trove_classifiers==2024.10.21.16 +ufoLib2==0.17.0 +uharfbuzz==0.43.0 +ujson==5.4.0 +unicodedata2==15.1.0 +unittest2==1.1.0 +uwsgi==2.0.28 +visu3d==1.5.3 +wandb_workspaces==0.1.8 +watchdog==2.1.6 +WebOb==1.8.9 +Werkzeug==3.1.3 +wmi==1.5.1 +wrapt==1.14.1 +xattr==1.1.0 +xmanager==0.6.0 +xmlrunner==1.7.7 +xx==3.3.2 +yarl==1.9.3 +z3==0.2.0 +zopfli==0.2.3.post1 +zstandard==0.19.0