Skip to content

Commit

Permalink
[BugFix] Solve R3MTransform init problem (#803)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 9, 2023
1 parent 569161e commit 9974033
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 42 deletions.
39 changes: 14 additions & 25 deletions torchrl/envs/transforms/r3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ class R3MTransform(Compose):
"r3m_vec" is assumed.
size (int, optional): Size of the image to feed to resnet.
Defaults to 244.
stack_images (bool, optional): if False, the images given in the :obj:`in_keys`
argument will be treated separately and each will be given a single,
separate entry in the output tensordict. Defaults to :obj:`True`.
download (bool, optional): if True, the weights will be downloaded using
the torch.hub download API (i.e. weights will be cached for future use).
Defaults to False.
Expand All @@ -179,7 +182,6 @@ class R3MTransform(Compose):

@classmethod
def __new__(cls, *args, **kwargs):
cls._is_3d = None
cls.initialized = False
cls._device = None
cls._dtype = None
Expand All @@ -205,8 +207,11 @@ def __init__(
self.size = size
self.stack_images = stack_images
self.tensor_pixels_keys = tensor_pixels_keys
self._init()

def _init(self):
"""Initializer for R3M."""
self.initialized = True
in_keys = self.in_keys
model_name = self.model_name
out_keys = self.out_keys
Expand Down Expand Up @@ -263,13 +268,13 @@ def _init(self):
)

if stack_images and len(in_keys) > 1:
if self.is_3d:
unsqueeze = UnsqueezeTransform(
in_keys=in_keys,
out_keys=in_keys,
unsqueeze_dim=-4,
)
transforms.append(unsqueeze)

unsqueeze = UnsqueezeTransform(
in_keys=in_keys,
out_keys=in_keys,
unsqueeze_dim=-4,
)
transforms.append(unsqueeze)

cattensors = CatTensors(
in_keys,
Expand All @@ -284,6 +289,7 @@ def _init(self):
)
flatten = FlattenObservation(-2, -1, out_keys)
transforms = [*transforms, cattensors, network, flatten]

else:
network = _R3MNet(
in_keys=in_keys,
Expand All @@ -297,22 +303,12 @@ def _init(self):
self.append(transform)
if self.download:
self[-1].load_weights(dir_prefix=self.download_path)
self.initialized = True

if self._device is not None:
self.to(self._device)
if self._dtype is not None:
self.to(self._dtype)

@property
def is_3d(self):
if self._is_3d is None:
parent = self.parent
for key in parent.observation_spec.keys():
self._is_3d = len(parent.observation_spec[key].shape) == 3
break
return self._is_3d

def to(self, dest: Union[DEVICE_TYPING, torch.dtype]):
if isinstance(dest, torch.dtype):
self._dtype = dest
Expand All @@ -327,10 +323,3 @@ def device(self):
@property
def dtype(self):
return self._dtype

forward = _init_first(Compose.forward)
transform_observation_spec = _init_first(Compose.transform_observation_spec)
transform_input_spec = _init_first(Compose.transform_input_spec)
transform_reward_spec = _init_first(Compose.transform_reward_spec)
reset = _init_first(Compose.reset)
init = _init_first(Compose.init)
41 changes: 24 additions & 17 deletions torchrl/envs/transforms/vip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.transforms import (
from torchrl.envs.transforms.transforms import (
CatTensors,
Compose,
FlattenObservation,
Expand Down Expand Up @@ -142,6 +142,9 @@ class VIPTransform(Compose):
"vip_vec" is assumed.
size (int, optional): Size of the image to feed to resnet.
Defaults to 244.
stack_images (bool, optional): if False, the images given in the :obj:`in_keys`
argument will be treated separately and each will be given a single,
separate entry in the output tensordict. Defaults to :obj:`True`.
download (bool, optional): if True, the weights will be downloaded using
the torch.hub download API (i.e. weights will be cached for future use).
Defaults to False.
Expand All @@ -154,7 +157,6 @@ class VIPTransform(Compose):

@classmethod
def __new__(cls, *args, **kwargs):
cls._is_3d = None
cls.initialized = False
cls._device = None
cls._dtype = None
Expand All @@ -180,8 +182,11 @@ def __init__(
self.size = size
self.stack_images = stack_images
self.tensor_pixels_keys = tensor_pixels_keys
self._init()

def _init(self):
"""Initializer for VIP."""
self.initialized = True
in_keys = self.in_keys
model_name = self.model_name
out_keys = self.out_keys
Expand Down Expand Up @@ -238,13 +243,12 @@ def _init(self):
)

if stack_images and len(in_keys) > 1:
if self.is_3d:
unsqueeze = UnsqueezeTransform(
in_keys=in_keys,
out_keys=in_keys,
unsqueeze_dim=-4,
)
transforms.append(unsqueeze)
unsqueeze = UnsqueezeTransform(
in_keys=in_keys,
out_keys=in_keys,
unsqueeze_dim=-4,
)
transforms.append(unsqueeze)

cattensors = CatTensors(
in_keys,
Expand Down Expand Up @@ -272,7 +276,6 @@ def _init(self):
self.append(transform)
if self.download:
self[-1].load_weights(dir_prefix=self.download_path)
self.initialized = True

if self._device is not None:
self.to(self._device)
Expand All @@ -281,8 +284,19 @@ def _init(self):

@property
def is_3d(self):
"""Whether the input image has 3 dims (non-batched) or more.
If no parent environment exists, it defaults to True.
The main usage is this: if there is more than one image and they need to be
stacked, we must know if the input image has dim 3 or 4. If 3, we need to unsqueeze
before stacking. If 4, we can cat along the first dimension.
"""
if self._is_3d is None:
parent = self.parent
if parent is None:
return True
for key in parent.observation_spec.keys():
self._is_3d = len(parent.observation_spec[key].shape) == 3
break
Expand All @@ -303,13 +317,6 @@ def device(self):
def dtype(self):
return self._dtype

forward = _init_first(Compose.forward)
transform_observation_spec = _init_first(Compose.transform_observation_spec)
transform_input_spec = _init_first(Compose.transform_input_spec)
transform_reward_spec = _init_first(Compose.transform_reward_spec)
reset = _init_first(Compose.reset)
init = _init_first(Compose.init)


class VIPRewardTransform(VIPTransform):
"""A VIP transform to compute rewards based on embedded similarity.
Expand Down

0 comments on commit 9974033

Please sign in to comment.