[Feature] LLMHashingEnv
ghstack-source-id: d1a20ecd023008683cf18cf9e694340cfdbdac8a
Pull Request resolved: #2635
vmoens committed Dec 12, 2024
1 parent 57dc25a commit 30d21e5
Showing 9 changed files with 293 additions and 17 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/envs.rst
@@ -347,6 +347,8 @@ TorchRL offers a series of custom built-in environments.

PendulumEnv
TicTacToeEnv
+LLMHashingEnv


Multi-agent environments
------------------------
4 changes: 2 additions & 2 deletions docs/source/reference/trainers.rst
@@ -79,7 +79,7 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
-logger (``LogScaler``) and such. Hooks should return a dictionary (or a None value) containing the
+logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the
data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value
should be displayed on the progression bar printed on the training log.

@@ -174,7 +174,7 @@ Trainer and hooks
BatchSubSampler
ClearCudaCache
CountFramesLog
-LogScaler
+LogScalar
OptimizerHook
LogValidationReward
ReplayBufferTrainer
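To make the logging-hook contract described in the trainers.rst excerpt above concrete, here is a minimal sketch of a custom hook. The class name, the logged key and the `trainer.register_op("post_steps_log", ...)` registration point are illustrative assumptions, not part of this commit.

```python
from tensordict import TensorDict


class LogMeanReward:
    """Minimal logging hook, mirroring what ``LogScalar`` does for the reward."""

    def __call__(self, batch: TensorDict) -> dict:
        # Read some information from the collected batch of data...
        mean_reward = batch.get(("next", "reward")).mean().item()
        # ...and return a dict of values to log. "log_pbar" tells the trainer to
        # also display the value on the progression bar.
        return {"mean_reward": mean_reward, "log_pbar": True}


# Assumed registration point: logging hooks are typically attached to the
# trainer's logging slot, e.g. trainer.register_op("post_steps_log", LogMeanReward()).
```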
25 changes: 25 additions & 0 deletions test/test_env.py
@@ -8,6 +8,7 @@
import functools
import gc
import os.path
+import random
import re
from collections import defaultdict
from functools import partial
@@ -114,6 +115,7 @@
    DoubleToFloat,
    EnvBase,
    EnvCreator,
+    LLMHashingEnv,
    ParallelEnv,
    PendulumEnv,
    SerialEnv,
@@ -3419,6 +3421,29 @@ def test_pendulum_env(self, device):
        r = env.rollout(10, tensordict=TensorDict(batch_size=[5], device=device))
        assert r.shape == torch.Size((5, 10))

+    def test_llm_hashing_env(self):
+        vocab_size = 5
+
+        class Tokenizer:
+            def __call__(self, obj):
+                return torch.randint(vocab_size, (len(obj.split(" ")),)).tolist()
+
+            def decode(self, obj):
+                words = ["apple", "banana", "cherry", "date", "elderberry"]
+                return " ".join(random.choice(words) for _ in obj)
+
+            def batch_decode(self, obj):
+                return [self.decode(_obj) for _obj in obj]
+
+            def encode(self, obj):
+                return self(obj)
+
+        tokenizer = Tokenizer()
+        env = LLMHashingEnv(tokenizer=tokenizer, vocab_size=vocab_size)
+        td = env.make_tensordict("some sentence")
+        assert isinstance(td, TensorDict)
+        env.check_env_specs(tensordict=td)


@pytest.mark.parametrize("device", [None, *get_default_devices()])
@pytest.mark.parametrize("env_device", [None, *get_default_devices()])
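The new test above drives `LLMHashingEnv` with a dummy tokenizer; the sketch below shows the same calls with a real tokenizer. It assumes the `transformers` package and the `gpt2` checkpoint are available, neither of which is required by this commit.

```python
from transformers import AutoTokenizer

from torchrl.envs import LLMHashingEnv

tokenizer = AutoTokenizer.from_pretrained("gpt2")
env = LLMHashingEnv(tokenizer=tokenizer, vocab_size=tokenizer.vocab_size)

# Seed the environment with a prompt and check that the specs match the data,
# exactly as test_llm_hashing_env does above.
td = env.make_tensordict("The quick brown fox")
env.check_env_specs(tensordict=td)
```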
15 changes: 10 additions & 5 deletions torchrl/data/map/tree.py
@@ -135,35 +135,40 @@ def make_node(
    def full_observation_spec(self):
        """The observation spec of the tree.
-        This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`."""
+        This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`.
+        """
        return self.specs["output_spec", "full_observation_spec"]

    @property
    def full_reward_spec(self):
        """The reward spec of the tree.
-        This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`."""
+        This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`.
+        """
        return self.specs["output_spec", "full_reward_spec"]

    @property
    def full_done_spec(self):
        """The done spec of the tree.
-        This is an alias for `Tree.specs['output_spec', 'full_done_spec']`."""
+        This is an alias for `Tree.specs['output_spec', 'full_done_spec']`.
+        """
        return self.specs["output_spec", "full_done_spec"]

    @property
    def full_state_spec(self):
        """The state spec of the tree.
-        This is an alias for `Tree.specs['input_spec', 'full_state_spec']`."""
+        This is an alias for `Tree.specs['input_spec', 'full_state_spec']`.
+        """
        return self.specs["input_spec", "full_state_spec"]

    @property
    def full_action_spec(self):
        """The action spec of the tree.
-        This is an alias for `Tree.specs['input_spec', 'full_action_spec']`."""
+        This is an alias for `Tree.specs['input_spec', 'full_action_spec']`.
+        """
        return self.specs["input_spec", "full_action_spec"]

    @property
2 changes: 1 addition & 1 deletion torchrl/envs/__init__.py
@@ -5,7 +5,7 @@

from .batched_envs import ParallelEnv, SerialEnv
from .common import EnvBase, EnvMetaData, make_tensordict
-from .custom import PendulumEnv, TicTacToeEnv
+from .custom import LLMHashingEnv, PendulumEnv, TicTacToeEnv
from .env_creator import env_creator, EnvCreator, get_env_metadata
from .gym_like import default_info_dict_reader, GymLikeEnv
from .libs import (
37 changes: 31 additions & 6 deletions torchrl/envs/common.py
@@ -14,8 +14,14 @@
import numpy as np
import torch
import torch.nn as nn
-from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key
-from tensordict.utils import NestedKey
+from tensordict import (
+    is_tensor_collection,
+    LazyStackedTensorDict,
+    TensorDictBase,
+    unravel_key,
+)
+from tensordict.base import _is_leaf_nontensor
+from tensordict.utils import is_non_tensor, NestedKey
from torchrl._utils import (
_ends_with,
_make_ordinal_device,
@@ -25,7 +31,13 @@
seed_generator,
)

-from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded
+from torchrl.data.tensor_specs import (
+    Categorical,
+    Composite,
+    NonTensor,
+    TensorSpec,
+    Unbounded,
+)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.utils import (
_make_compatible_policy,
@@ -430,7 +442,6 @@ def auto_specs_(
        done_key: NestedKey | List[NestedKey] | None = None,
        observation_key: NestedKey | List[NestedKey] = "observation",
        reward_key: NestedKey | List[NestedKey] = "reward",
-        batch_size: torch.Size | None = None,
    ):
"""Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
@@ -484,6 +495,7 @@ def auto_specs_(
tensordict2,
named=True,
nested_keys=True,
+is_leaf=_is_leaf_nontensor,
)
input_spec = Composite(input_spec_stack, batch_size=batch_size)
if not self.batch_locked and batch_size != self.batch_size:
@@ -501,6 +513,7 @@ def auto_specs_(
nexts_1,
named=True,
nested_keys=True,
+is_leaf=_is_leaf_nontensor,
)

output_spec = Composite(output_spec_stack, batch_size=batch_size)
@@ -523,7 +536,8 @@ def auto_specs_(
        full_observation_spec = output_spec.separates(*observation_key, default=None)
        if not output_spec.is_empty(recurse=True):
            raise RuntimeError(
-                f"Keys {list(output_spec.keys(True, True))} are unaccounted for."
+                f"Keys {list(output_spec.keys(True, True))} are unaccounted for. "
+                f"Make sure you have passed all the leaf names to the auto_specs_ method."
            )

if full_action_spec is not None:
@@ -541,6 +555,8 @@ def auto_specs_(

    @wraps(check_env_specs_func)
    def check_env_specs(self, *args, **kwargs):
+        return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
+        kwargs["return_contiguous"] = return_contiguous
        return check_env_specs_func(self, *args, **kwargs)

check_env_specs.__doc__ = check_env_specs_func.__doc__
@@ -3206,7 +3222,10 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""
if self._simple_done:
done = tensordict._get_str("done", default=None)
any_done = done.any()
if done is not None:
any_done = done.any()
else:
any_done = False
if any_done:
tensordict._set_str(
"_reset",
@@ -3572,6 +3591,12 @@ def _has_dynamic_specs(spec: Composite):


def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack):
+    if not (isinstance(leaf, torch.Tensor) or is_tensor_collection(leaf)):
+        stack[name] = NonTensor(shape=())
+        return
+    elif is_non_tensor(leaf):
+        stack[name] = NonTensor(shape=leaf.shape)
+        return
    shape = leaf.shape
    if leaf_compare is not None:
        shape_compare = leaf_compare.shape
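The `check_env_specs` wrapper added above changes the default of `return_contiguous` so that it depends on whether the environment has dynamic specs. The snippet below is a rough sketch of how that surfaces to users, using `PendulumEnv` purely as a convenient stand-in; it is not part of this commit.

```python
from torchrl.envs import PendulumEnv

env = PendulumEnv()

# Default: equivalent to return_contiguous=not env._has_dynamic_specs,
# so environments with dynamic specs no longer need to pass the flag by hand.
env.check_env_specs()

# An explicit value is still honoured, since the kwarg is popped and re-inserted.
env.check_env_specs(return_contiguous=True)
```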
1 change: 1 addition & 0 deletions torchrl/envs/custom/__init__.py
@@ -3,5 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

+from .llm import LLMHashingEnv
from .pendulum import PendulumEnv
from .tictactoeenv import TicTacToeEnv
