{
"cells": [
{
"cell_type": "markdown",
"id": "e8966967-97bc-406e-a2f4-4a62d8f9e895",
"metadata": {},
"source": [
"[Open In Colab](https://colab.research.google.com/github/pytorch/rl/blob/main/tutorials/envs.ipynb)\n",
"\n",
"# TorchRL envs (`torchrl.envs`)\n",
"\n",
"Environments play a crucial role in RL settings, often comparable to the role datasets play in supervised and unsupervised settings.\n",
"The RL community has become quite familiar with the OpenAI Gym API, which offers a flexible way of building environments, initializing them and interacting with them.\n",
"However, many other libraries exist, and the way one interacts with them can be quite different from what is expected with gym.\n",
"\n",
"Let us start by describing how TorchRL interacts with gym, which will serve as an introduction to other frameworks."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b331338",
"metadata": {},
"outputs": [],
"source": [
"!pip install functorch torchvision\n",
"!pip install \"gym[classic_control]\"\n",
"!pip install dm_control matplotlib\n",
"!pip install torchrl"
]
},
{
"cell_type": "markdown",
"id": "f461815d-dfd2-4d48-8d9b-21cc25f55464",
"metadata": {},
"source": [
"## Gym environments\n",
"\n",
"To run this part of the tutorial, you will need a recent version of the gym library, as well as the Atari suite.\n",
"You can install them with the following command:\n",
"\n",
"```\n",
"pip install gym atari-py ale-py \"gym[accept-rom-license]\" pygame\n",
"```\n",
"\n",
"To provide a unified interface across frameworks, TorchRL environments are built inside the `__init__` method by a private method called `_build_env`, which passes the positional and keyword arguments to the builder of the underlying library.\n",
"\n",
"With gym, this means that building an environment is as easy as:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "09a90ffb-eba0-458e-912d-568ea006e15c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n"
]
}
],
"source": [
"from torchrl.envs.libs.gym import GymEnv\n",
"from matplotlib import pyplot as plt\n",
"from torchrl.data import TensorDict\n",
"import torch\n",
"env = GymEnv(\"Pendulum-v1\")"
]
},
{
"cell_type": "markdown",
"id": "b508f501-20a0-4e44-b928-17410cf27eb6",
"metadata": {},
"source": [
"The list of available environments can be accessed through this command:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a2b3c152-be95-4140-ab2e-92c9df1c40bc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['ALE/Adventure-ram-v5',\n",
" 'ALE/Adventure-v5',\n",
" 'ALE/AirRaid-ram-v5',\n",
" 'ALE/AirRaid-v5',\n",
" 'ALE/Alien-ram-v5',\n",
" 'ALE/Alien-v5',\n",
" 'ALE/Amidar-ram-v5',\n",
" 'ALE/Amidar-v5',\n",
" 'ALE/Assault-ram-v5',\n",
" 'ALE/Assault-v5']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"GymEnv.available_envs[:10]"
]
},
{
"cell_type": "markdown",
"id": "330e470a-ec2e-436c-b1f7-ff2e1f4704c8",
"metadata": {},
"source": [
"### Env specs\n",
"\n",
"Like other frameworks, TorchRL envs have attributes that describe the spaces of the observations, actions and rewards.\n",
"Because an environment often returns more than one observation, the observation spec is expected to be of type `CompositeSpec`. The reward and action specs do not have this restriction:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "36c1475f-c14a-4c76-ac0f-9fd9177ed5e1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Env observation_spec: \n",
" CompositeSpec(\n",
" next_observation: NdBoundedTensorSpec(\n",
" shape=torch.Size([3]), space=ContinuousBox(minimum=tensor([-1., -1., -8.]), maximum=tensor([1., 1., 8.])), device=cpu, dtype=torch.float32, domain=continuous))\n",
"Env action_spec: \n",
" NdBoundedTensorSpec(\n",
" shape=torch.Size([1]), space=ContinuousBox(minimum=tensor([-2.]), maximum=tensor([2.])), device=cpu, dtype=torch.float32, domain=continuous)\n",
"Env reward_spec: \n",
" UnboundedContinuousTensorSpec(\n",
" shape=torch.Size([1]), space=ContinuousBox(minimum=-inf, maximum=inf), device=cpu, dtype=torch.float32, domain=composite)\n"
]
}
],
"source": [
"print(\"Env observation_spec: \\n\", env.observation_spec)\n",
"print(\"Env action_spec: \\n\", env.action_spec)\n",
"print(\"Env reward_spec: \\n\", env.reward_spec)"
]
},
{
"cell_type": "markdown",
"id": "ab3e1a6b-06a8-47e6-b43d-d9cb4b82150a",
"metadata": {},
"source": [
"These specs come with a series of useful tools: one can check whether a sample belongs to the defined space, use a heuristic to project an out-of-space sample back into the space, and generate random (possibly uniformly distributed) samples from that space:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ad7d55fe-6dda-4757-ada8-5b8dddc41729",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"action is in bounds?\n",
" False\n",
"projected action: \n",
" tensor([2.])\n"
]
}
],
"source": [
"action = torch.ones(1) * 3\n",
"print(\"action is in bounds?\\n\", bool(env.action_spec.is_in(action)))\n",
"print(\"projected action: \\n\", env.action_spec.project(action))\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fa103f09-c3a4-4c3e-b39c-6253599d0fec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"random action: \n",
" tensor([-0.8754])\n"
]
}
],
"source": [
"print(\"random action: \\n\", env.action_spec.rand())"
]
},
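{
"cell_type": "markdown",
"id": "7d2c9a41-6b3e-4f8a-a1c5-2e9b0d4f7a13",
"metadata": {},
"source": [
"The three spec operations above can be illustrated without TorchRL. Below is a minimal, framework-free sketch of a one-dimensional bounded box. `BoundedBox` is a hypothetical helper, not the TorchRL class, and we assume that projection clips values to the bounds, which is consistent with the output above where `3.` is projected to `2.`:\n",
"\n",
"```python\n",
"import random\n",
"\n",
"class BoundedBox:\n",
"    # hypothetical stand-in for a bounded continuous spec\n",
"    def __init__(self, low, high):\n",
"        self.low, self.high = low, high\n",
"\n",
"    def is_in(self, x):\n",
"        # True if every entry lies within [low, high]\n",
"        return all(self.low <= v <= self.high for v in x)\n",
"\n",
"    def project(self, x):\n",
"        # clip each entry to the nearest bound\n",
"        return [min(max(v, self.low), self.high) for v in x]\n",
"\n",
"    def rand(self, n=1):\n",
"        # uniform samples in [low, high]\n",
"        return [random.uniform(self.low, self.high) for _ in range(n)]\n",
"\n",
"box = BoundedBox(-2.0, 2.0)  # same bounds as the Pendulum action spec\n",
"print(box.is_in([3.0]))       # False\n",
"print(box.project([3.0]))     # [2.0]\n",
"print(box.is_in(box.rand()))  # True\n",
"```"
]
},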
{
"cell_type": "markdown",
"id": "41045768-1c5b-46a9-9941-a5797bb3185f",
"metadata": {},
"source": [
"Envs also come with an `env.input_spec` attribute of type `CompositeSpec`. In brief, `input_spec` should contain the specs of all the inputs an env requires to execute a step. For stateful envs (e.g. gym), this includes the action.\n",
"For stateless environments (e.g. Brax), it should also include a representation of the previous state."
]
},
{
"cell_type": "markdown",
"id": "b4d99bce-99ef-44f5-8406-de7da52cb23f",
"metadata": {},
"source": [
"### Seeding, resetting and steps\n",
"\n",
"The basic operations on an environment are (1) `set_seed`, (2) `reset` and (3) `step`.\n",
"\n",
"Let's see how these methods work with TorchRL:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0e8bff90-5046-4888-8750-cfc7f050167a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" observation: Tensor(torch.Size([3]), dtype=torch.float32)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n"
]
}
],
"source": [
"torch.manual_seed(0)  # make sure that all torch code is also reproducible\n",
"env.set_seed(0)\n",
"tensordict = env.reset()\n",
"print(tensordict)"
]
},
{
"cell_type": "markdown",
"id": "2936a240-8a03-4726-94fe-6151cc4f7f3e",
"metadata": {},
"source": [
"We can now execute a step in the environment. \n",
"Since we don't have a policy, we can just generate a random action:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "676c3ddc-3396-4e93-a474-2e0a403ec14d",
"metadata": {},
"outputs": [],
"source": [
"def policy(tensordict):\n",
"    # placeholder policy: writes a random action into the tensordict\n",
"    tensordict.set(\"action\", env.action_spec.rand())\n",
"    return tensordict\n",
"\n",
"policy(tensordict)\n",
"tensordict_out = env.step(tensordict)"
]
},
{
"cell_type": "markdown",
"id": "c9a7364c-d8e9-44de-a9da-977e8d14c094",
"metadata": {},
"source": [
"By default, the tensordict returned by `step` is the same as the input..."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "903ad364-0843-4674-9625-11fdece3eb18",
"metadata": {},
"outputs": [],
"source": [
"assert tensordict_out is tensordict"
]
},
{
"cell_type": "markdown",
"id": "64aac817-a77d-4c8b-b6da-75a90f1ac1be",
"metadata": {},
"source": [
"... but with new keys"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9cabe1d8-d904-4795-abc3-9b404ee9e9b4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorDict(\n",
" fields={\n",
" action: Tensor(torch.Size([1]), dtype=torch.float32),\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" next_observation: Tensor(torch.Size([3]), dtype=torch.float32),\n",
" observation: Tensor(torch.Size([3]), dtype=torch.float32),\n",
" reward: Tensor(torch.Size([1]), dtype=torch.float32)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensordict"
]
},
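{
"cell_type": "markdown",
"id": "c4e81f0a-92d7-4b56-8e3a-5f1d6c2b9e07",
"metadata": {},
"source": [
"This in-place convention can be mimicked with a plain Python dict standing in for a `TensorDict`. `toy_step` and its dynamics are hypothetical; the sketch only shows a step that writes new keys into its input and returns that same object:\n",
"\n",
"```python\n",
"def toy_step(td):\n",
"    # hypothetical toy dynamics: the next observation is the\n",
"    # current one shifted by the action\n",
"    obs, action = td['observation'], td['action']\n",
"    next_obs = [o + action[0] for o in obs]\n",
"    # new entries are written into the *same* dict\n",
"    td['next_observation'] = next_obs\n",
"    td['reward'] = [-sum(o * o for o in next_obs)]\n",
"    td['done'] = [False]\n",
"    return td  # the input object itself is returned\n",
"\n",
"td = {'observation': [0.0, 1.0, 0.5], 'action': [0.5]}\n",
"out = toy_step(td)\n",
"print(out is td)   # True\n",
"print(sorted(td))  # ['action', 'done', 'next_observation', 'observation', 'reward']\n",
"```"
]
},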
{
"cell_type": "markdown",
"id": "581b7ab6-f542-444f-970f-9755b18051cc",
"metadata": {},
"source": [
"What we just did (a random step using `action_spec.rand()`) can also be done with the `rand_step` shortcut:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2c08ee82-8d0e-4735-ba80-5944f393340e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorDict(\n",
" fields={\n",
" action: Tensor(torch.Size([1]), dtype=torch.float32),\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" next_observation: Tensor(torch.Size([3]), dtype=torch.float32),\n",
" reward: Tensor(torch.Size([1]), dtype=torch.float32)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.rand_step()"
]
},
{
"cell_type": "markdown",
"id": "46b3e9c0-ad70-475b-90ae-11787b900ed3",
"metadata": {},
"source": [
"The new key `\"next_observation\"` (like all keys starting with `\"next_\"`) has a special role in TorchRL: it holds the value that comes, one step later, after the key with the same name but without the prefix.\n",
"\n",
"We provide a function, `step_mdp`, that moves the tensordict one step forward in time: it returns a new tensordict updated such that $t \\leftarrow t'$:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "ad697e31-ec9d-4607-942a-93f96b5f0e85",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TensorDict(\n",
" fields={\n",
" observation: Tensor(torch.Size([3]), dtype=torch.float32),\n",
" some other key: Tensor(torch.Size([1]), dtype=torch.float32)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n",
"tensor(True)\n"
]
}
],
"source": [
"from torchrl.envs.utils import step_mdp\n",
"tensordict.set(\"some other key\", torch.randn(1))\n",
"tensordict_tprime = step_mdp(tensordict)\n",
"print(tensordict_tprime)\n",
"print((tensordict_tprime.get(\"observation\") == tensordict.get(\"next_observation\")).all())"
]
},
{
"cell_type": "markdown",
"id": "21925dd8-6492-401c-a09b-60ad6a7774d8",
"metadata": {},
"source": [
"We can observe that `step_mdp` has removed the time-dependent key-value pairs (`\"action\"`, `\"reward\"` and `\"done\"`) but kept `\"some other key\"`. Also, the new `\"observation\"` matches the previous `\"next_observation\"`."
]
},
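{
"cell_type": "markdown",
"id": "1a9f3e6d-4c28-47b1-b0e5-8d7c2a6f3e94",
"metadata": {},
"source": [
"To make this bookkeeping explicit, here is a minimal sketch of what `step_mdp` does, written with plain dicts. `toy_step_mdp` is hypothetical: it only reproduces the behaviour visible in the output above (the tuple of excluded keys is an assumption) and is not TorchRL's implementation:\n",
"\n",
"```python\n",
"def toy_step_mdp(td, exclude=('action', 'reward', 'done')):\n",
"    # hypothetical re-implementation inferred from the printed output\n",
"    out = {}\n",
"    for key, value in td.items():\n",
"        if key.startswith('next_'):\n",
"            out[key[len('next_'):]] = value  # 'next_x' becomes 'x'\n",
"        elif key not in exclude and 'next_' + key not in td:\n",
"            out[key] = value  # keys without a 'next_' counterpart are kept\n",
"    return out\n",
"\n",
"td = {'observation': [0.0], 'next_observation': [1.0], 'action': [0.3],\n",
"      'reward': [-1.0], 'done': [False], 'some other key': [7.0]}\n",
"print(toy_step_mdp(td))  # {'observation': [1.0], 'some other key': [7.0]}\n",
"```"
]
},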
{
"cell_type": "markdown",
"id": "14d2951e-c263-4d06-903b-686191ecf97b",
"metadata": {},
"source": [
"Finally, note that the `env.reset` method also accepts a tensordict to update:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bc928092-ade0-46b3-836b-4f21caafc3a7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" observation: Tensor(torch.Size([3]), dtype=torch.float32)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensordict = TensorDict({}, [])\n",
"assert env.reset(tensordict) is tensordict\n",
"tensordict"
]
},
{
"cell_type": "markdown",
"id": "14ae176d-a7af-4c3d-82fa-bf69d375bae8",
"metadata": {},
"source": [
"### Rollouts\n",
"\n",
"The generic environment class provided by TorchRL allows you to run rollouts easily for a given number of steps:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "41d820f7-a063-4947-935d-6018f05c12ff",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TensorDict(\n",
" fields={\n",
" action: Tensor(torch.Size([20, 1]), dtype=torch.float32),\n",
" done: Tensor(torch.Size([20, 1]), dtype=torch.bool),\n",
" next_observation: Tensor(torch.Size([20, 3]), dtype=torch.float32),\n",
" observation: Tensor(torch.Size([20, 3]), dtype=torch.float32),\n",
" reward: Tensor(torch.Size([20, 1]), dtype=torch.float32)},\n",
" batch_size=torch.Size([20]),\n",
" device=cpu,\n",
" is_shared=False)\n"
]
}
],
"source": [
"tensordict_rollout = env.rollout(max_steps=20, policy=policy)\n",
"print(tensordict_rollout)"
]
},
{
"cell_type": "markdown",
"id": "1e0bb1e5-8661-4187-b97b-69bf64389a71",
"metadata": {},
"source": [
"The resulting tensordict has a `batch_size` of `[20]`, which is the length of the trajectory. We can check that each observation matches the `\"next_observation\"` of the previous step:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "eb449dec-640b-43d8-8f46-70a2de3cf469",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(tensordict_rollout.get(\"observation\")[1:] == tensordict_rollout.get(\"next_observation\")[:-1]).all()"
]
},
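{
"cell_type": "markdown",
"id": "e5b07d23-8f1a-4c69-9d42-3a6e1b8c5f20",
"metadata": {},
"source": [
"The same bookkeeping can be sketched for a whole rollout. `ToyEnv`, `rollout` and `random_policy` are hypothetical, framework-free helpers that return a list of per-step dicts rather than a stacked `TensorDict`; they only illustrate why `\"observation\"` at step t+1 equals `\"next_observation\"` at step t:\n",
"\n",
"```python\n",
"import random\n",
"\n",
"class ToyEnv:\n",
"    # hypothetical 1-D env: the state drifts by the action at each step\n",
"    def reset(self):\n",
"        self.state = 0.0\n",
"        return {'observation': self.state, 'done': False}\n",
"\n",
"    def step(self, td):\n",
"        self.state += td['action']\n",
"        td['next_observation'] = self.state\n",
"        td['reward'] = -abs(self.state)\n",
"        td['done'] = False\n",
"        return td\n",
"\n",
"def rollout(env, policy, max_steps):\n",
"    # collect one dict per step; each step starts from the previous 'next_observation'\n",
"    trajectory = []\n",
"    td = env.reset()\n",
"    for _ in range(max_steps):\n",
"        td = env.step(policy(td))\n",
"        trajectory.append(dict(td))  # snapshot of the current step\n",
"        if td['done']:\n",
"            break\n",
"        td = {'observation': td['next_observation'], 'done': td['done']}\n",
"    return trajectory\n",
"\n",
"def random_policy(td):\n",
"    td['action'] = random.uniform(-1.0, 1.0)\n",
"    return td\n",
"\n",
"traj = rollout(ToyEnv(), random_policy, max_steps=20)\n",
"print(len(traj))  # 20\n",
"print(all(traj[t + 1]['observation'] == traj[t]['next_observation']\n",
"          for t in range(len(traj) - 1)))  # True\n",
"```"
]
},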
{
"cell_type": "markdown",
"id": "630dfdb7-d448-4c27-865f-4bb455d016b4",
"metadata": {},
"source": [
"### frame_skip\n",
"\n",
"In some instances, it is useful to repeat the same action over several consecutive frames, which is what the `frame_skip` argument does.\n",
"\n",
"The resulting tensordict will contain only the last frame observed in the sequence, but the rewards will be summed over all the frames.\n",
"\n",
"If the environment reaches a done state during this process, it stops and returns the result of the truncated chain."
]
},
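{
"cell_type": "markdown",
"id": "9f4c2b18-d3e7-4a05-86b9-c1e2f7a4d368",
"metadata": {},
"source": [
"The frame-skipping logic described above can be sketched in a few lines of plain Python, assuming a gym-style `(obs, reward, done)` step signature. `CountingEnv` and `skip_step` are hypothetical and do not reflect TorchRL's internals:\n",
"\n",
"```python\n",
"class CountingEnv:\n",
"    # hypothetical env: reward 1 per frame, done after `horizon` frames\n",
"    def __init__(self, horizon):\n",
"        self.horizon, self.t = horizon, 0\n",
"\n",
"    def step(self, action):\n",
"        self.t += 1\n",
"        return self.t, 1.0, self.t >= self.horizon\n",
"\n",
"def skip_step(env, action, frame_skip):\n",
"    # repeat `action`, sum the rewards, keep only the last observation\n",
"    total_reward = 0.0\n",
"    for _ in range(frame_skip):\n",
"        obs, reward, done = env.step(action)\n",
"        total_reward += reward\n",
"        if done:  # stop early and return the truncated chain\n",
"            break\n",
"    return obs, total_reward, done\n",
"\n",
"print(skip_step(CountingEnv(horizon=10), action=0, frame_skip=4))  # (4, 4.0, False)\n",
"print(skip_step(CountingEnv(horizon=3), action=0, frame_skip=4))   # (3, 3.0, True)\n",
"```"
]
},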
{
"cell_type": "code",
"execution_count": 15,
"id": "6b2cb9da-d976-410e-92d3-19ad75bd228e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n"
]
},
{
"data": {
"text/plain": [
"TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" observation: Tensor(torch.Size([3]), dtype=torch.float32)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env = GymEnv(\"Pendulum-v1\", frame_skip=4)\n",
"env.reset()"
]
},
{
"cell_type": "markdown",
"id": "be11c29c-1a68-4cd6-a8dd-4aebdf72785e",
"metadata": {},
"source": [
"### Rendering\n",
"\n",
"Rendering plays an important role in many RL settings, which is why the generic environment class from TorchRL provides a `from_pixels` keyword argument that lets the user quickly ask for image-based environments:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "19cf0c74-ab92-4a8a-9c0a-984282a295ea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n"
]
}
],
"source": [
"env = GymEnv(\"Pendulum-v1\", from_pixels=True)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "8b914a63-1734-4fa8-96b6-c2a859f273d6",
"metadata": {},
"outputs": [],
"source": [
"tensordict = env.reset()\n",
"env.close()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "39500114-9f7b-4150-bc11-ecdb7d38c3ff",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD8CAYAAAB3lxGOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAARbElEQVR4nO3dbYyV5Z3H8e9vnmdAeRwRGBBUbDW6WiWWpk3aaM1S21RjdKNptmRDwot1E7tt0mo27abJvmjf1LbpptFdm9JNW+1aE4kx6bJou9k0VaEoFREZ8AEQYRAYh8dhZv774lywIwzODXPuOWfm+n2Sk7mv6/7PnP8ww2/ux3MUEZhZvhpq3YCZ1ZZDwCxzDgGzzDkEzDLnEDDLnEPALHOlhICk5ZK2SuqW9GAZz2Fm1aFqXycgqRF4A7gN2AW8BNwXEa9V9YnMrCrK2BK4GeiOiB0R0Q88DtxRwvOYWRU0lfA15wM7h413AZ/8qE+YPXt2LFq0qIRWzOyUDRs27I+IzjPnywiBQiStAlYBLFy4kPXr19eqFbMsSHp7pPkydgd2AwuGjbvS3IdExKMRsTQilnZ2nhVOZjZOygiBl4AlkhZLagHuBdaU8DxmVgVV3x2IiAFJ/wD8DmgEfhYRm6v9PGZWHaUcE4iIZ4Fny/jaZlZdvmLQLHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDI3aghI+pmkfZJeHTY3U9JaSdvSxxlpXpJ+LKlb0iZJN5bZvJmNXZEtgZ8Dy8+YexBYFxFLgHVpDPAFYEl6rAJ+Wp02zawso4ZARPwPcOCM6TuA1Wl5NXDnsPlfRMWfgOmS5lapVzMrwYUeE5gTEXvS8nvAnLQ8H9g5rG5XmjuLpFWS1kta39PTc4FtmNlYjfnAYEQEEBfweY9GxNKIWNrZ2TnWNszsAl1oCOw9tZmfPu5L87uBBcPqutKcmdWpCw2BNcCKtLwCeHrY/FfTWYJlQO+w3QYzq0NNoxVI+jXwOWC2pF3APwPfA34jaSXwNvA3qfxZ4HagGzgK/F0JPZtZFY0aAhFx3zlW3TpCbQD3j7UpMxs/vmLQLHMOAbPMOQTMMucQMMucQ8Asc6OeHbDJLSIYOnaMDzZupG/zZoZOnKB94UKm3XwzrZdeiqRat2glcwhkLCI48e67vPPII/T95S8wOHh63b41a5i/YgUzPvMZ1OANxsnMP92MDRw6xNs/+Ql9L7/8oQAA6O/p4Z1HHqF3wwYql3/YZOUQyFREsH/tWg6/9to5awb7+tjzq18xdPToOHZm480hkKsIDv3xjzDKX/mj27ez58kniaGhcWrMxptDIGNFN/N7X3qJwSNHSu7GasUhYKPq37vXITCJOQQy1rF4caG6GBzk2M6doxfahOQQyJXElI99rFBpDAxwZOtWnyWYpBwCmZJES2cnam4uVN/f0wM+ODgpOQQy1nHFFTROmVKo9vCrrzJ04kTJHVktOAQy1tDWRtNFFxWqHervZ8AHByclh0DGGlpbmXrttYVqB/r6OLJlS8kdWS04BHIm0TJzZrHaoSEGent9cHAScghkTBJTr7228MHB3g0bRr3C0CYeh0DmWmbPRo2NhWpPHjhAnHGjkU18DoHMNV10Ee2XXVaotn//fk68+27JHdl4cwhkrqG9nZaCbwM3ePgwJw8eLLkjG28OgcxJomPJksL1R7ZtK7EbqwWHgNFx5ZWFa4+8/rrPEEwyDgGjqaOj8BmCk4cOMXTsWMkd2XhyCBhtXV20LVgweiFw7K23OHngQMkd2XhyCBhqaaGxo6NQbQwN+eDgJOMQMACmLV1arHBwkN
7168ttxsaVQ8AAaLnkksK1A4cP+zUHJxGHgCGJ9ssuo7HgHYV9mzYx6FcgnjQcAgZAS2dn4eMCg0eO+LUFJhGHgAGgpibaFy4sVDt0/DhHXn+95I5svDgEDAA1NjLlqqsK1cbAAP09Pb5oaJIYNQQkLZD0vKTXJG2W9ECanylpraRt6eOMNC9JP5bULWmTpBvL/iZs7CRVDg4WvKPwxJ49vq14kiiyJTAAfCMirgGWAfdLugZ4EFgXEUuAdWkM8AVgSXqsAn5a9a6tFFOvvpqG1tZCtX2bN/u24kli1BCIiD0R8ee03AdsAeYDdwCrU9lq4M60fAfwi6j4EzBd0txqN27V19DeTmN7e6HawaNHGejrK7kjGw/ndUxA0iLgE8ALwJyI2JNWvQfMScvzgeHvVLErzVmda5oyhalXX12o9uTBgxx/+20fF5gECoeApKnAb4GvRcQHw9dF5TfhvH4bJK2StF7S+p6envP5VCtLYyNN06cXqx0cpH///lLbsfFRKAQkNVMJgF9GxFNpeu+pzfz0cV+a3w0MvxulK819SEQ8GhFLI2JpZ8EXtbBySeLiG24AqVD9B6+8Um5DNi6KnB0Q8BiwJSJ+MGzVGmBFWl4BPD1s/qvpLMEyoHfYboPVuZZLLoGGYhuI/fv2EQMDJXdkZSvy0/408LfALZJeTo/bge8Bt0naBnw+jQGeBXYA3cC/AX9f/batLC0zZ9I2b16h2hPvvcfJ998vuSMrW9NoBRHxv8C5tg9vHaE+gPvH2JfVSOPUqTTPnMnxAu9CPNDby8lDh2i99NJx6MzK4isG7UPU0EDHFVcUK47g6Pbt5TZkpXMI2FmKvmU5wJE33vBpwgnOIWBnaZk9m4aCFw319/T4jsIJziFgZ2mdN4/mGTMK1R7bsYPBw4dL7sjK5BCwszS2tRUOgaGBAfp9hmBCcwjY2RoauOi66wqVRn8/hzdv9nGBCcwhYCNqLnr5MJU3KvVtxROXQ8DOIon2yy+noa2tUH3fpk3EyZMld2VlcQjYiNrmzy/82gIDfX0M+gzBhOUQsBGpuZnWgpcPD/T1cWzHjpI7srI4BGxEDa2tTCn4RqXR30///v0+ODhBOQRsRJJo6ewsfFvxkS1bSu7IyuIQsHO66PrrUdOo95gBcPTNN8HvSjQhFfsJW5YaOzpoaG1lcNiR/w/6+3nqnXfoOX6cv543j+tmzEDS6TsKW2bNqmHHdiG8JWDn1DxzJlOWLDk97jt5ku9s3MhPtmzhiTff5B9ffJE/pZeG6+/poX/v3lq1amPgELBzamhu/tD7E+4+epQ/7tt3etx78iT/9e67tWjNqsghYB9p2k03nV5uaWig9Yw3J7m4uXm8W7IqcwjYR5p2003MXr4cgMVTp/Kt665jdmsrrY2N3DJ3LivT7kLLJZfQ4lcYmpB8YNA+UtPFF3Pp3Xdz5PXXOfbWW3yxq4sbZ83i2MAA86dMoa2xEbW0MOeuuwrfeWj1xVsCNqqWzk4W3n8/7YsXo4YG5nV0cMXFF9PW2EhDRweX3nMPsz//eVTwmgKrL94SsFFJYspVV3Hlt7/NgT/8gb5Nmxg8fpz2BQuY+dnPMuXqq2koeD2B1R//5KwQSbTMns2cu+5izl13VW4dTn/5vQUwsTkE7Lyc/g/v//iTho8JmGXOIWCWOYeAWeYcAmaZcwiYZc4hYJY5h4BZ5hwCZplzCJhlziFgljmHgFnmHAJmmRs1BCS1SXpR0iuSNkv6bppfLOkFSd2SnpDUkuZb07g7rV9U8vdgZmNQZEvgBHBLRFwP3AAsl7QM+D7wcERcCRwEVqb6lcDBNP9wqjOzOjVqCETF4TRsTo8AbgGeTPOrgTvT8h1pTFp/q3zDuVndKnRMQFKjpJeBfcBaYDtwKCIGUskuYH5ang/sBEjre4Gz3pFC0ipJ6yWt70mvXW9m469QCETEYETcAHQBNwMfH+sTR8SjEbE0IpZ2dnaO9cuZ2QU6r7MDEXEIeB
74FDBd0qlXJuoCdqfl3cACgLR+GvB+NZo1s+orcnagU9L0tNwO3AZsoRIGd6eyFcDTaXlNGpPWPxd+z2qzulXkNQbnAqslNVIJjd9ExDOSXgMel/QvwEbgsVT/GPAfkrqBA8C9JfRtZlUyaghExCbgEyPM76ByfODM+ePAPVXpzsxK5ysGzTLnEDDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy1zhEJDUKGmjpGfSeLGkFyR1S3pCUkuab03j7rR+UUm9m1kVnM+WwAPAlmHj7wMPR8SVwEFgZZpfCRxM8w+nOjOrU4VCQFIX8EXg39NYwC3Ak6lkNXBnWr4jjUnrb031ZlaHim4J/BD4JjCUxrOAQxExkMa7gPlpeT6wEyCt7031ZlaHRg0BSV8C9kXEhmo+saRVktZLWt/T01PNL21m56HIlsCngS9Legt4nMpuwI+A6ZKaUk0XsDst7wYWAKT104D3z/yiEfFoRCyNiKWdnZ1j+ibM7MKNGgIR8VBEdEXEIuBe4LmI+ArwPHB3KlsBPJ2W16Qxaf1zERFV7drMqmYs1wl8C/i6pG4q+/yPpfnHgFlp/uvAg2Nr0czK1DR6yf+LiN8Dv0/LO4CbR6g5DtxThd7MbBz4ikGzzDkEzDLnEDDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDLnEDDLnEPALHMOAbPMOQTMMucQMMucQ8Ascw4Bs8w5BMwy5xAwy5xDwCxzDgGzzDkEzDKniKh1D0jqA7bWuo/zMBvYX+smCppIvcLE6nci9QpwWUR0njnZVItORrA1IpbWuomiJK2fKP1OpF5hYvU7kXr9KN4dMMucQ8Asc/USAo/WuoHzNJH6nUi9wsTqdyL1ek51cWDQzGqnXrYEzKxGah4CkpZL2iqpW9KDddDPzyTtk/TqsLmZktZK2pY+zkjzkvTj1PsmSTfWoN8Fkp6X9JqkzZIeqNeeJbVJelHSK6nX76b5xZJeSD09IaklzbemcXdav2i8eh3Wc6OkjZKeqfdeL1RNQ0BSI/CvwBeAa4D7JF1Ty56AnwPLz5h7EFgXEUuAdWkMlb6XpMcq4Kfj1ONwA8A3IuIaYBlwf/o3rMeeTwC3RMT1wA3AcknLgO8DD0fElcBBYGWqXwkcTPMPp7rx9gCwZdi4nnu9MBFRswfwKeB3w8YPAQ/VsqfUxyLg1WHjrcDctDyXynUNAI8A941UV8PenwZuq/eegQ7gz8AnqVxw03Tm7wTwO+BTabkp1Wkce+yiEqC3AM8Aqtdex/Ko9e7AfGDnsPGuNFdv5kTEnrT8HjAnLddV/2kT9BPAC9Rpz2nz+mVgH7AW2A4cioiBEfo53Wta3wvMGq9egR8C3wSG0ngW9dvrBat1CEw4UYn6ujulImkq8FvgaxHxwfB19dRzRAxGxA1U/sreDHy8th2NTNKXgH0RsaHWvZSt1iGwG1gwbNyV5urNXklzAdLHfWm+LvqX1EwlAH4ZEU+l6bruOSIOAc9T2aSeLunUJezD+znda1o/DXh/nFr8NPBlSW8Bj1PZJfhRnfY6JrUOgZeAJemIawtwL7Cmxj2NZA2wIi2voLLffWr+q+mI+zKgd9gm+LiQJOAxYEtE/GDYqrrrWVKnpOlpuZ3KsYstVMLg7nP0eup7uBt4Lm3VlC4iHoqIrohYROX38rmI+Eo99jpmtT4oAdwOvEFl3/Cf6qCfXwN7gJNU9vlWUtm3WwdsA/4bmJlqReXsxnbgL8DSGv
T7GSqb+puAl9Pj9nrsGfgrYGPq9VXgO2n+cuBFoBv4T6A1zbelcXdaf3mNfic+BzwzEXq9kIevGDTLXK13B8ysxhwCZplzCJhlziFgljmHgFnmHAJmmXMImGXOIWCWuf8DsXqv/xEXH5gAAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(tensordict.get(\"pixels\").numpy())"
]
},
{
"cell_type": "markdown",
"id": "6f6d85ad-dcde-426f-b143-5e188c8c4afc",
"metadata": {},
"source": [
"Let's have a look at what the tensordict contains:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "899fa1c2-e59a-40c6-bb50-450545984e8b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" pixels: Tensor(torch.Size([500, 500, 3]), dtype=torch.uint8)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensordict"
]
},
{
"cell_type": "markdown",
"id": "95654812-ccce-46be-bf96-17ff48abd65d",
"metadata": {},
"source": [
"Besides the pixels, the tensordict may also contain a `\"state\"` entry that describes what `\"observation\"` described in the previous case. The naming difference comes from the fact that gym can return a dictionary: TorchRL takes the names from the dictionary keys if a dictionary is returned, and otherwise names the step output `\"observation\"`. In short, this is due to inconsistencies in the object type returned by the `step` method of gym environments.\n",
"\n",
"One can also discard this supplementary output by asking for the pixels only:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "aee540c0-b51e-45df-be09-bf6a7549592a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n"
]
}
],
"source": [
"env = GymEnv(\"Pendulum-v1\", from_pixels=True, pixels_only=True)\n",
"env.reset()\n",
"env.close()"
]
},
{
"cell_type": "markdown",
"id": "c9df9805-6e58-4c10-912d-a4bd228e9a11",
"metadata": {},
"source": [
"Some environments only come in image-based format:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "30fd300a-7b91-490c-a5d6-2d34ca4635f8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"from pixels: True\n",
"tensordict: TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" pixels: Tensor(torch.Size([210, 160, 3]), dtype=torch.uint8)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"A.L.E: Arcade Learning Environment (version 0.8.0+919230b)\n",
"[Powered by Stella]\n"
]
}
],
"source": [
"env = GymEnv(\"ALE/Pong-v5\")\n",
"print('from pixels: ', env.from_pixels)\n",
"print('tensordict: ', env.reset())\n",
"env.close()"
]
},
{
"cell_type": "markdown",
"id": "f93140da-dc1c-4a09-94a9-626f5a1ff42d",
"metadata": {},
"source": [
"___\n",
"## DeepMind Control environments\n",
"\n",
"To run this part of the tutorial, make sure you have installed dm_control:\n",
"\n",
"```\n",
"pip install dm_control\n",
"```\n",
"\n",
"Also make sure to restart the notebook kernel between this demo and the previous one, as gym and dm_control rendering can conflict.\n",
"\n",
"We also provide a wrapper for the DeepMind Control suite. Again, building an environment is easy. First, let's look at which environments can be accessed: `available_envs` now returns a dict mapping each environment to its possible tasks:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "1060ddb7-3880-473e-ab81-30e02add0e4d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'acrobot': ['swingup', 'swingup_sparse'],\n",
" 'ball_in_cup': ['catch'],\n",
" 'cartpole': ['balance',\n",
" 'balance_sparse',\n",
" 'swingup',\n",
" 'swingup_sparse',\n",
" 'three_poles',\n",
" 'two_poles'],\n",
" 'cheetah': ['run'],\n",
" 'finger': ['spin', 'turn_easy', 'turn_hard'],\n",
" 'fish': ['upright', 'swim'],\n",
" 'hopper': ['stand', 'hop'],\n",
" 'humanoid': ['stand', 'walk', 'run', 'run_pure_state'],\n",
" 'manipulator': ['bring_ball', 'bring_peg', 'insert_ball', 'insert_peg'],\n",
" 'pendulum': ['swingup'],\n",
" 'point_mass': ['easy', 'hard'],\n",
" 'reacher': ['easy', 'hard'],\n",
" 'swimmer': ['swimmer6', 'swimmer15'],\n",
" 'walker': ['stand', 'walk', 'run'],\n",
" 'dog': ['fetch', 'run', 'stand', 'trot', 'walk'],\n",
" 'humanoid_CMU': ['run', 'stand'],\n",
" 'lqr': ['lqr_2_1', 'lqr_6_2'],\n",
" 'quadruped': ['escape', 'fetch', 'run', 'walk'],\n",
" 'stacker': ['stack_2', 'stack_4']}"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchrl.envs.libs.dm_control import DMControlEnv\n",
"from matplotlib import pyplot as plt\n",
"DMControlEnv.available_envs"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "bb712ed0-aad8-4718-9dda-6eac875c78a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"result of reset: TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" orientations: Tensor(torch.Size([4]), dtype=torch.float64),\n",
" velocity: Tensor(torch.Size([2]), dtype=torch.float64)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n"
]
}
],
"source": [
"env = DMControlEnv('acrobot', 'swingup')\n",
"tensordict = env.reset()\n",
"print('result of reset: ', tensordict)\n",
"env.close()"
]
},
{
"cell_type": "markdown",
"id": "f4f5f4d5-b3c0-401d-8934-0dc9ebd3d72a",
"metadata": {},
"source": [
"Of course we can also use pixel-based environments:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "db64ab96-a5bc-4d77-990a-ab4b7357e291",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"result of reset: TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" pixels: Tensor(torch.Size([240, 320, 3]), dtype=torch.uint8)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from torchrl.envs.libs.dm_control import DMControlEnv\n",
"env = DMControlEnv('acrobot', 'swingup', from_pixels=True, pixels_only=True)\n",
"tensordict = env.reset()\n",
"print('result of reset: ', tensordict)\n",
"plt.imshow(tensordict.get(\"pixels\").numpy())\n",
"env.close()"
]
},
{
"cell_type": "markdown",
"id": "e0e93b95-fa48-48a3-9acc-9e8d8594103b",
"metadata": {},
"source": [
"___\n",
"## Transforming envs\n",
"\n",
"It is common to pre-process the output of an environment before having it read by the policy or stored in a buffer.\n",
"\n",
"In many instances, the RL community has adopted a wrapping scheme of the type\n",
"\n",
"```\n",
"env_transformed = wrapper1(wrapper2(env))\n",
"```\n",
"\n",
"to transform environments. This has numerous advantages: it makes accessing the environment specs obvious (the outer wrapper is the source of truth for the external world), and it makes it easy to interact with vectorized environments.\n",
"However, it also makes inner environments hard to access: to remove a wrapper (e.g. `wrapper2`) from the chain, one has to write\n",
"```\n",
"env0 = env_transformed.env.env\n",
"env_transformed_bis = wrapper1(env0)\n",
"```\n",
"\n",
"TorchRL takes the stance of using sequences of transforms instead, as is done in other PyTorch domain libraries (e.g. `torchvision`). This approach is also similar to the way distributions are transformed in `torch.distributions`, where a `TransformedDistribution` object is built around a `base_dist` distribution and (a sequence of) `transforms`."
]
},
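{
"cell_type": "markdown",
"id": "3b6e9d50-a4f2-4e17-b8c3-7d5a0f9e2c41",
"metadata": {},
"source": [
"The advantage of a flat sequence over nested wrappers can be illustrated in plain Python. `ToyCompose`, `scale` and `shift` are hypothetical helpers, not TorchRL's `Compose`; the point is that the transforms live in a list, so removing one does not require unwrapping anything:\n",
"\n",
"```python\n",
"class ToyCompose:\n",
"    # hypothetical: applies a list of callables, in order, to a dict\n",
"    def __init__(self, *transforms):\n",
"        self.transforms = list(transforms)\n",
"\n",
"    def __call__(self, td):\n",
"        for t in self.transforms:\n",
"            td = t(td)\n",
"        return td\n",
"\n",
"def scale(td):\n",
"    td['obs'] = td['obs'] * 2\n",
"    return td\n",
"\n",
"def shift(td):\n",
"    td['obs'] = td['obs'] + 1\n",
"    return td\n",
"\n",
"pipeline = ToyCompose(scale, shift)\n",
"print(pipeline({'obs': 3}))  # {'obs': 7}\n",
"\n",
"# removing a transform is a list operation: no unwrapping needed\n",
"pipeline.transforms.remove(scale)\n",
"print(pipeline({'obs': 3}))  # {'obs': 4}\n",
"```"
]
},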
{
"cell_type": "code",
"execution_count": 1,
"id": "cf9ae717-2f7a-4722-9ce1-01484d53b984",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"reset before transform: TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" pixels: Tensor(torch.Size([240, 320, 3]), dtype=torch.uint8)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n",
"reset after transform: TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" pixels: Tensor(torch.Size([3, 240, 320]), dtype=torch.float32)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n"
]
}
],
"source": [
"from torchrl.envs.libs.dm_control import DMControlEnv\n",
"import torch\n",
"from torchrl.envs.transforms import TransformedEnv, ToTensorImage\n",
"# ToTensorImage converts a uint8 image of shape [H, W, C] into a float tensor of shape [C, H, W]\n",
"env = DMControlEnv('acrobot', 'swingup', from_pixels=True, pixels_only=True)\n",
"print('reset before transform: ', env.reset())\n",
"\n",
"env = TransformedEnv(env, ToTensorImage())\n",
"print('reset after transform: ', env.reset())\n",
"env.close()"
]
},
{
"cell_type": "markdown",
"id": "f0fdb760-bd1b-4688-ba54-da156a63c36b",
"metadata": {},
"source": [
"To compose transforms, simply use the `Compose` class:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f0a5081d-2afc-4f0a-ad8a-4df681cfc917",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" pixels: Tensor(torch.Size([3, 32, 32]), dtype=torch.float32)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchrl.envs.transforms import Compose, Resize\n",
"env = DMControlEnv('acrobot', 'swingup', from_pixels=True, pixels_only=True)\n",
"env = TransformedEnv(env, Compose(ToTensorImage(), Resize(32, 32)))\n",
"env.reset()"
]
},
{
"cell_type": "markdown",
"id": "566b0c94-6022-477a-9e2c-32f9009bcaaa",
"metadata": {},
"source": [
"Transforms can also be added one at a time:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "22da21c4-e9d6-44bc-996d-268bc37e4909",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" pixels: Tensor(torch.Size([1, 32, 32]), dtype=torch.float32)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchrl.envs.transforms import GrayScale\n",
"env.append_transform(GrayScale())\n",
"env.reset()"
]
},
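{
"cell_type": "markdown",
"id": "sketch-image-transforms",
"metadata": {},
"source": [
"As a rough illustration of the pipelines above, here is a plain-PyTorch sketch of what `ToTensorImage`, `Resize` and `GrayScale` are assumed to do to a single frame. The steps are applied in order, as `Compose` does. This is not TorchRL's implementation. The helper names are hypothetical. `ToTensorImage` is assumed to permute (H, W, C) into (C, H, W) and rescale to [0, 1]. The grayscale weights are assumed to be the ITU-R 601-2 luma coefficients used by torchvision. The grayscale step expects the 3 color channels first, so it must run after the channel permutation.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def to_tensor_image(img):\n",
"    # (H, W, C) uint8 -> (C, H, W) float32 scaled to [0, 1]\n",
"    return img.permute(2, 0, 1).float().div(255)\n",
"\n",
"def resize(img, h, w):\n",
"    # interpolate expects a batch dim: (C, H, W) -> (1, C, H, W) -> (C, h, w)\n",
"    return F.interpolate(img.unsqueeze(0), size=(h, w), mode='bilinear', align_corners=False).squeeze(0)\n",
"\n",
"def grayscale(img):\n",
"    # channels-first RGB (3, H, W) -> (1, H, W), weighted sum over the channel dim\n",
"    weights = torch.tensor([0.2989, 0.587, 0.114]).view(3, 1, 1)\n",
"    return (img * weights).sum(dim=0, keepdim=True)\n",
"\n",
"def compose(*fns):\n",
"    # apply the functions one after the other, like Compose\n",
"    def apply(x):\n",
"        for fn in fns:\n",
"            x = fn(x)\n",
"        return x\n",
"    return apply\n",
"\n",
"pipeline = compose(to_tensor_image, lambda x: resize(x, 32, 32), grayscale)\n",
"frame = torch.randint(0, 256, (240, 320, 3), dtype=torch.uint8)  # fake rendered frame\n",
"out = pipeline(frame)\n",
"print(out.shape, out.dtype)  # torch.Size([1, 32, 32]) torch.float32\n",
"```"
]
},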
{
"cell_type": "markdown",
"id": "ef5a2176-20d3-4270-b3ff-a1d8b5c75fcf",
"metadata": {},
"source": [
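"The transforms update the environment metadata as well. Let us compare the observation spec of the base environment with that of the transformed one:"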
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "734c07ec-ff03-4df8-844e-acca466d19e6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"original obs spec: CompositeSpec(\n",
" next_pixels: NdUnboundedDiscreteTensorSpec(\n",
" shape=(240, 320, 3), space=ContinuousBox(minimum=tensor([[[0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" ...,\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0]],\n",
"\n",
" [[0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" ...,\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0]],\n",
"\n",
" [[0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" ...,\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0]],\n",
"\n",
" ...,\n",
"\n",
" [[0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" ...,\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0]],\n",
"\n",
" [[0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" ...,\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0]],\n",
"\n",
" [[0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" ...,\n",
" [0, 0, 0],\n",
" [0, 0, 0],\n",
" [0, 0, 0]]]), maximum=tensor([[[255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" ...,\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255]],\n",
"\n",
" [[255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" ...,\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255]],\n",
"\n",
" [[255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" ...,\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255]],\n",
"\n",
" ...,\n",
"\n",
" [[255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" ...,\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255]],\n",
"\n",
" [[255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" ...,\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255]],\n",
"\n",
" [[255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" ...,\n",
" [255, 255, 255],\n",
" [255, 255, 255],\n",
" [255, 255, 255]]])), device=cpu, dtype=torch.uint8, domain=continuous))\n"
]
},
{
"ename": "TypeError",
"evalue": "Input image tensor permitted channel values are [3], but found240",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mTypeError\u001B[0m Traceback (most recent call last)",
"\u001B[0;32m/var/folders/zs/9lq15k8x61l1g0c_sf__63c80000gn/T/ipykernel_13887/2654911180.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[0mprint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'original obs spec: '\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0menv\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbase_env\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mobservation_spec\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 2\u001B[0;31m \u001B[0mprint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'current obs spec: '\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0menv\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mobservation_spec\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m",
"\u001B[0;32m~/Repos/RL/torch_rl/torchrl/envs/transforms/transforms.py\u001B[0m in \u001B[0;36mobservation_spec\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 338\u001B[0m \u001B[0;34m\"\"\"Observation spec of the transformed_in environment\"\"\"\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 339\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_observation_spec\u001B[0m \u001B[0;32mis\u001B[0m \u001B[0;32mNone\u001B[0m \u001B[0;32mor\u001B[0m \u001B[0;32mnot\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcache_specs\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 340\u001B[0;31m observation_spec = self.transform.transform_observation_spec(\n\u001B[0m\u001B[1;32m 341\u001B[0m \u001B[0mdeepcopy\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbase_env\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mobservation_spec\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 342\u001B[0m )\n",
"\u001B[0;32m~/Repos/RL/torch_rl/torchrl/envs/transforms/transforms.py\u001B[0m in \u001B[0;36mtransform_observation_spec\u001B[0;34m(self, observation_spec)\u001B[0m\n\u001B[1;32m 604\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mtransform_observation_spec\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mobservation_spec\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0mTensorSpec\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;34m->\u001B[0m \u001B[0mTensorSpec\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 605\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0mt\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtransforms\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 606\u001B[0;31m \u001B[0mobservation_spec\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mt\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtransform_observation_spec\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mobservation_spec\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 607\u001B[0m \u001B[0;32mreturn\u001B[0m \u001B[0mobservation_spec\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 608\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m~/Repos/RL/torch_rl/torchrl/envs/transforms/transforms.py\u001B[0m in \u001B[0;36mnew_fun\u001B[0;34m(self, observation_spec)\u001B[0m\n\u001B[1;32m 76\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0mkey_in\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mkey_out\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mzip\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mkeys_in\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mkeys_out\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 77\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mkey_in\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mobservation_spec\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mkeys\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 78\u001B[0;31m \u001B[0md\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mkey_out\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mfunction\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mobservation_spec\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mkey_in\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 79\u001B[0m \u001B[0;32mreturn\u001B[0m \u001B[0mCompositeSpec\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m**\u001B[0m\u001B[0md\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 80\u001B[0m \u001B[0;32melse\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m~/Repos/RL/torch_rl/torchrl/envs/transforms/transforms.py\u001B[0m in \u001B[0;36mtransform_observation_spec\u001B[0;34m(self, observation_spec)\u001B[0m\n\u001B[1;32m 1204\u001B[0m \u001B[0mspace\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mobservation_spec\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mspace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1205\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0misinstance\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mspace\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mContinuousBox\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m-> 1206\u001B[0;31m \u001B[0mspace\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mminimum\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_apply_transform\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mspace\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mminimum\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1207\u001B[0m \u001B[0mspace\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmaximum\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_apply_transform\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mspace\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmaximum\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1208\u001B[0m \u001B[0mobservation_spec\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mshape\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mspace\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mminimum\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mshape\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m~/Repos/RL/torch_rl/torchrl/envs/transforms/transforms.py\u001B[0m in \u001B[0;36m_apply_transform\u001B[0;34m(self, observation)\u001B[0m\n\u001B[1;32m 1197\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1198\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0m_apply_transform\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mobservation\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mTensor\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;34m->\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mTensor\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m-> 1199\u001B[0;31m \u001B[0mobservation\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mF\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrgb_to_grayscale\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mobservation\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1200\u001B[0m \u001B[0;32mreturn\u001B[0m \u001B[0mobservation\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1201\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m~/Repos/RL/torch_rl/torchrl/envs/transforms/functional.py\u001B[0m in \u001B[0;36mrgb_to_grayscale\u001B[0;34m(img, num_output_channels)\u001B[0m\n\u001B[1;32m 34\u001B[0m \u001B[0;34m\"{}\"\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mformat\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mimg\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mndim\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 35\u001B[0m )\n\u001B[0;32m---> 36\u001B[0;31m \u001B[0m_assert_channels\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mimg\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;36m3\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 37\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 38\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mnum_output_channels\u001B[0m \u001B[0;32mnot\u001B[0m \u001B[0;32min\u001B[0m \u001B[0;34m(\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m3\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m~/Repos/RL/torch_rl/torchrl/envs/transforms/functional.py\u001B[0m in \u001B[0;36m_assert_channels\u001B[0;34m(img, permitted)\u001B[0m\n\u001B[1;32m 22\u001B[0m \u001B[0mc\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0m_get_image_num_channels\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mimg\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 23\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mc\u001B[0m \u001B[0;32mnot\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mpermitted\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 24\u001B[0;31m raise TypeError(\n\u001B[0m\u001B[1;32m 25\u001B[0m \u001B[0;34m\"Input image tensor permitted channel values are {}, but found\"\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 26\u001B[0m \u001B[0;34m\"{}\"\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mformat\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mpermitted\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mc\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;31mTypeError\u001B[0m: Input image tensor permitted channel values are [3], but found240"
]
}
],
"source": [
"print('original obs spec: ', env.base_env.observation_spec)\n",
"print('current obs spec: ', env.observation_spec)"
]
},
{
"cell_type": "markdown",
"id": "ff001409-5c34-46be-95e2-47b2f78114ac",
"metadata": {},
"source": [
"We can also concatenate tensors if needed:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "cd294681-b15c-4735-9215-ea754b395fb0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"keys before concat: TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" orientations: Tensor(torch.Size([4]), dtype=torch.float64),\n",
" velocity: Tensor(torch.Size([2]), dtype=torch.float64)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n",
"keys after concat: TensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([1]), dtype=torch.bool),\n",
" observation: Tensor(torch.Size([6]), dtype=torch.float64)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n"
]
}
],
"source": [
"from torchrl.envs.transforms import CatTensors\n",
"env = DMControlEnv('acrobot', 'swingup')\n",
"print(\"keys before concat: \", env.reset())\n",
"# use the \"next_\"-prefixed keys, as these are the ones returned by step\n",
"env = TransformedEnv(env, CatTensors(in_keys=[\"next_orientations\", \"next_velocity\"], out_key=\"next_observation\"))\n",
"print(\"keys after concat: \", env.reset())"
]
},
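{
"cell_type": "markdown",
"id": "sketch-cat-tensors",
"metadata": {},
"source": [
"To make the effect of `CatTensors` explicit, here is a hypothetical plain-PyTorch helper. It reproduces the change visible in the output above: `orientations` and `velocity` are replaced by a single 6-dimensional `observation`. A Python dict stands in for a `TensorDict`. This is a sketch of the idea, not the transform's actual code. It assumes that the entries are concatenated along the last dimension, in the order given by `in_keys`, and that the input entries are then dropped.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"def toy_cat_tensors(td, in_keys, out_key):\n",
"    # concatenate the in_keys entries (in order) along the last dim, removing them from td\n",
"    td[out_key] = torch.cat([td.pop(key) for key in in_keys], dim=-1)\n",
"    return td\n",
"\n",
"td = {\n",
"    'done': torch.zeros(1, dtype=torch.bool),\n",
"    'orientations': torch.zeros(4, dtype=torch.float64),\n",
"    'velocity': torch.zeros(2, dtype=torch.float64),\n",
"}\n",
"td = toy_cat_tensors(td, ['orientations', 'velocity'], 'observation')\n",
"print(sorted(td), td['observation'].shape)  # ['done', 'observation'] torch.Size([6])\n",
"```"
]
},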
{
"cell_type": "markdown",
"id": "81b62090-d878-4cfb-8e83-dbefddaf3405",
"metadata": {},
"source": [
"This design makes it easy to modify the set of transforms applied to an environment's inputs and outputs.\n",
"Transforms are run both before and after a step is executed. In the pre-step pass, the entries listed in `in_keys_inv` are passed to the `_inv_apply_transform` method. One example is a transform that casts single-precision actions (as output by a neural network) to the double-precision dtype required by the wrapped environment.\n",
"After the step is executed, the `_apply_transform` method is applied to the entries listed in `in_keys`."
]
},
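{
"cell_type": "markdown",
"id": "sketch-transform-inv-forward",
"metadata": {},
"source": [
"The pre-step and post-step passes described above can be sketched in plain Python. All class names are hypothetical, and dicts stand in for `TensorDict`s. Only the `in_keys`, `in_keys_inv`, `_apply_transform` and `_inv_apply_transform` names mirror the text. This is an illustration of the flow, not TorchRL's actual code.\n",
"\n",
"```python\n",
"class ToyEnv:\n",
"    # stand-in for a wrapped env that requires float actions\n",
"    def step(self, td):\n",
"        assert isinstance(td['action'], float), 'the wrapped env expects float actions'\n",
"        td['next_observation'] = int(td['action'] * 10)\n",
"        return td\n",
"\n",
"class ToyTransform:\n",
"    def __init__(self, in_keys, in_keys_inv):\n",
"        self.in_keys = in_keys          # entries processed after the step\n",
"        self.in_keys_inv = in_keys_inv  # entries processed before the step\n",
"\n",
"    def _inv_apply_transform(self, value):\n",
"        return float(value)  # e.g. cast the action to the type the env needs\n",
"\n",
"    def _apply_transform(self, value):\n",
"        return value / 10  # e.g. rescale the observation\n",
"\n",
"    def inv(self, td):\n",
"        for key in self.in_keys_inv:\n",
"            td[key] = self._inv_apply_transform(td[key])\n",
"        return td\n",
"\n",
"    def forward(self, td):\n",
"        for key in self.in_keys:\n",
"            td[key] = self._apply_transform(td[key])\n",
"        return td\n",
"\n",
"class ToyTransformedEnv:\n",
"    def __init__(self, env, transform):\n",
"        self.env = env\n",
"        self.transform = transform\n",
"\n",
"    def step(self, td):\n",
"        td = self.transform.inv(td)        # pre-step pass on in_keys_inv\n",
"        td = self.env.step(td)\n",
"        return self.transform.forward(td)  # post-step pass on in_keys\n",
"\n",
"env = ToyTransformedEnv(ToyEnv(), ToyTransform(in_keys=['next_observation'], in_keys_inv=['action']))\n",
"print(env.step({'action': 3}))  # {'action': 3.0, 'next_observation': 3.0}\n",
"```"
]
},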
{
"cell_type": "markdown",
"id": "34fb4aa3-6193-44a5-bd79-2ebf087155e8",
"metadata": {},
"source": [
"Another useful feature of environment transforms is that they let the user retrieve the equivalent of `env.env` in the wrapped case. In other words, they give access to the parent environment.\n",
"The parent environment is returned by `transform.parent`. It consists of a `TransformedEnv` with all the transforms up to (but not including) the current one.\n",
"This is used, for instance, by the `NoopResetEnv` transform. When the environment is reset, it resets the parent environment and then executes a certain number of random steps in it."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ede057e5-11da-41b7-9635-bcf90ff10711",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: \n",
" TransformedEnv(env=DMControlEnv(env=acrobot, task=swingup, batch_size=torch.Size([])), transform=Compose(\n",
" CatTensors(in_keys=['next_orientations', 'next_velocity'], out_key=next_observation),\n",
" GrayScale(keys=['next_pixels'])))\n",
"GrayScale transform parent env: \n",
" TransformedEnv(env=DMControlEnv(env=acrobot, task=swingup, batch_size=torch.Size([])), transform=Compose(\n",
" CatTensors(in_keys=['next_orientations', 'next_velocity'], out_key=next_observation)))\n",
"CatTensors transform parent env: \n",
" TransformedEnv(env=DMControlEnv(env=acrobot, task=swingup, batch_size=torch.Size([])), transform=Compose(\n",
"))\n"
]
}
],
"source": [
"env = DMControlEnv('acrobot', 'swingup')\n",
"env = TransformedEnv(env)\n",
"env.append_transform(CatTensors(in_keys=[\"next_orientations\", \"next_velocity\"], out_key=\"next_observation\"))\n",
"env.append_transform(GrayScale())\n",
"print(\"env: \\n\", env)\n",
"print(\"GrayScale transform parent env: \\n\", env.transform[1].parent)\n",
"print(\"CatTensors transform parent env: \\n\", env.transform[0].parent)"
]
},
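{
"cell_type": "markdown",
"id": "sketch-noop-reset",
"metadata": {},
"source": [
"The `NoopResetEnv` behavior described above relies on `parent`. On reset, the transform resets the parent environment and then takes some random steps in it. Below is a plain-Python sketch of that idea, using a hypothetical stand-in for the parent environment. The `noops` cap and the uniform draw of the number of steps are assumptions of this sketch, not a description of TorchRL's exact logic.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"class ToyParentEnv:\n",
"    # hypothetical stand-in for transform.parent; counts steps since the last reset\n",
"    def __init__(self):\n",
"        self.steps_since_reset = 0\n",
"\n",
"    def reset(self):\n",
"        self.steps_since_reset = 0\n",
"        return {'steps_since_reset': 0}\n",
"\n",
"    def rand_step(self):\n",
"        self.steps_since_reset += 1\n",
"        return {'steps_since_reset': self.steps_since_reset}\n",
"\n",
"def noop_reset(parent, noops, rng):\n",
"    # reset the parent env, then take a random number (0..noops) of random steps in it\n",
"    td = parent.reset()\n",
"    for _ in range(rng.randint(0, noops)):\n",
"        td = parent.rand_step()\n",
"    return td\n",
"\n",
"td = noop_reset(ToyParentEnv(), noops=30, rng=random.Random(0))\n",
"print(td)\n",
"```"
]
},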
{
"cell_type": "markdown",
"id": "5bd8908e-a0b9-4844-8bc4-c95657acd07b",
"metadata": {},
"source": [
"___\n",
"## Environment device\n",
"Transforms can run on a device such as a GPU. This can bring a significant speedup when the operations are moderately or highly computationally demanding, as with `ToTensorImage`, `Resize` or `GrayScale`.\n",
"\n",
"One could legitimately ask what this implies on the wrapped environment side. For regular environments, very little: the operations of the wrapped library still happen on the device where they are supposed to happen. In TorchRL, the environment `device` attribute indicates on which device the incoming data is expected to be and on which device the output data will be placed. Casting data to and from that device is the responsibility of the TorchRL environment class. Storing data on GPU has two main advantages: (1) faster transforms, as mentioned above, and (2) data sharing amongst workers in multiprocessing settings.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a7538009-c098-47ee-8129-c7535aa9eb97",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchrl.envs.libs.dm_control import DMControlEnv\n",
"from torchrl.envs.transforms import CatTensors, GrayScale, TransformedEnv\n",
"env = DMControlEnv('acrobot', 'swingup')\n",
"env = TransformedEnv(env)\n",
"env.append_transform(CatTensors(in_keys=[\"next_orientations\", \"next_velocity\"], out_key=\"next_observation\"))\n",
"\n",
"if torch.cuda.is_available():\n",
" env.to('cuda:0')\n",
" env.reset()"
]
},
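{
"cell_type": "markdown",
"id": "sketch-device-casting",
"metadata": {},
"source": [
"Here is a minimal sketch of the casting responsibility described above. It assumes a wrapped simulator that only works on CPU tensors. `ToyDeviceEnv` and `_simulate` are hypothetical. The caller only ever sees tensors on `env.device`, while the wrapper moves data to and from the device the simulator needs.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"class ToyDeviceEnv:\n",
"    def __init__(self, device='cpu'):\n",
"        self.device = torch.device(device)\n",
"\n",
"    def _simulate(self, action_cpu):\n",
"        # stand-in for a CPU-only simulator step\n",
"        return action_cpu * 2.0\n",
"\n",
"    def step(self, action):\n",
"        action_cpu = action.to('cpu')   # incoming data: env.device -> simulator device\n",
"        obs_cpu = self._simulate(action_cpu)\n",
"        return obs_cpu.to(self.device)  # outgoing data: simulator device -> env.device\n",
"\n",
"device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
"env = ToyDeviceEnv(device)\n",
"obs = env.step(torch.ones(3, device=device))\n",
"print(obs.device, obs.tolist())\n",
"```"
]
},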
{
"cell_type": "markdown",
"id": "288f91d7-6736-46db-8e06-4eca34711d0d",
"metadata": {},
"source": [
"___\n",
"## Running environments in parallel\n",
"\n",
"TorchRL provides utilities to run environments in parallel. The sub-environments are expected to read and return tensors of the same shapes and dtypes (masking functions could be designed to handle tensors whose shapes differ). Creating such environments is straightforward. Let us look at the simplest case:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ef7cbd08-e0c3-41af-b367-cf08cae9adc0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n"
]
}
],
"source": [
"from torchrl.envs import ParallelEnv, SerialEnv\n",
"from torchrl.envs.libs.gym import GymEnv\n",
"env_make = lambda: GymEnv(\"Pendulum-v1\")\n",
"parallel_env = ParallelEnv(3, env_make) # -> creates 3 envs in parallel\n",
"parallel_env = ParallelEnv(3, [env_make, env_make, env_make]) # similar to the previous command"
]
},
{
"cell_type": "markdown",
"id": "d6d4f2ae-35da-41c7-94e0-4e6fd7311918",
"metadata": {},
"source": [
"The `SerialEnv` class is similar to `ParallelEnv`, except that the environments are run sequentially. This is mostly useful for debugging purposes.\n",
"\n",
"`ParallelEnv` instances are created in lazy mode: the environments only start running when they are first used. This allows us to move `ParallelEnv` objects from process to process without worrying too much about running processes.\n",
"A `ParallelEnv` can be started by calling `start` or `reset`, or simply by calling `step` (if `reset` does not need to be called first)."
]
},
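{
"cell_type": "markdown",
"id": "sketch-lazy-start",
"metadata": {},
"source": [
"The lazy behavior described above can be illustrated with a plain-Python sketch. It is hypothetical and much simpler than `ParallelEnv`, with no processes and no shared memory. The sub-environments are only created when `start` is called, either explicitly or implicitly through `reset`.\n",
"\n",
"```python\n",
"class ToySubEnv:\n",
"    def reset(self):\n",
"        return {'observation': 0.0}\n",
"\n",
"class ToyLazyParallelEnv:\n",
"    def __init__(self, num_workers, create_env_fn):\n",
"        self.num_workers = num_workers\n",
"        self.create_env_fn = create_env_fn\n",
"        self._envs = None  # nothing runs yet, so the object is cheap to pass around\n",
"\n",
"    @property\n",
"    def started(self):\n",
"        return self._envs is not None\n",
"\n",
"    def start(self):\n",
"        if not self.started:\n",
"            # a real implementation would spawn one worker process per sub-env here\n",
"            self._envs = [self.create_env_fn() for _ in range(self.num_workers)]\n",
"\n",
"    def reset(self):\n",
"        self.start()  # reset implicitly starts the env\n",
"        return [env.reset() for env in self._envs]\n",
"\n",
"    def close(self):\n",
"        self._envs = None\n",
"\n",
"penv = ToyLazyParallelEnv(3, ToySubEnv)\n",
"print(penv.started)  # False: constructed, but nothing is running\n",
"print(penv.reset())  # [{'observation': 0.0}, {'observation': 0.0}, {'observation': 0.0}]\n",
"print(penv.started)  # True\n",
"penv.close()\n",
"```"
]
},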
{
"cell_type": "code",
"execution_count": 9,
"id": "c3e4766f-9975-4cc0-96fc-fd4a7344337d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LazyStackedTensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([3, 1]), dtype=torch.bool),\n",
" observation: Tensor(torch.Size([3, 3]), dtype=torch.float32)},\n",
" batch_size=torch.Size([3]),\n",
" device=cpu,\n",
" is_shared=False)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"parallel_env.reset()"
]
},
{
"cell_type": "markdown",
"id": "a5ecee3d-5e87-4351-bd03-a979e6e8bc79",
"metadata": {},
"source": [
"We can check that the parallel environment has the right batch size. By convention, the first dimension of `batch_size` indexes the environments (the batch) and the second one indexes the time steps. Let's check this with the `rollout` method:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a764b5c2-a17d-49ff-9cbb-b77903b89cad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorDict(\n",
" fields={\n",
" action: Tensor(torch.Size([3, 20, 1]), dtype=torch.float32),\n",
" done: Tensor(torch.Size([3, 20, 1]), dtype=torch.bool),\n",
" next_observation: Tensor(torch.Size([3, 20, 3]), dtype=torch.float32),\n",
" observation: Tensor(torch.Size([3, 20, 3]), dtype=torch.float32),\n",
" reward: Tensor(torch.Size([3, 20, 1]), dtype=torch.float32)},\n",
" batch_size=torch.Size([3, 20]),\n",
" device=cpu,\n",
" is_shared=False)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"parallel_env.rollout(max_steps=20)"
]
},
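{
"cell_type": "markdown",
"id": "sketch-batch-time-dims",
"metadata": {},
"source": [
"Here is a plain-PyTorch sketch of the stacking convention behind the `[3, 20]` batch size. Sub-environment outputs are stacked along the first dimension, and successive time steps along the second. `ToyPendulum` and `toy_serial_rollout` are hypothetical stand-ins that produce random observations and actions. As in the output above, the sketch assumes 3-dimensional observations and 1-dimensional actions.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"class ToyPendulum:\n",
"    # stand-in sub-env: 3-dim observations, 1-dim actions\n",
"    def reset(self):\n",
"        return torch.zeros(3)\n",
"\n",
"    def step(self, action):\n",
"        return torch.rand(3)\n",
"\n",
"def toy_serial_rollout(envs, max_steps):\n",
"    for env in envs:\n",
"        env.reset()\n",
"    frames = []\n",
"    for _ in range(max_steps):\n",
"        actions = torch.rand(len(envs), 1) * 4 - 2  # random actions in [-2, 2)\n",
"        # step each sub-env in turn and stack their outputs on dim 0: (n_envs, 3)\n",
"        obs = torch.stack([env.step(a) for env, a in zip(envs, actions)])\n",
"        frames.append(obs)\n",
"    return torch.stack(frames, dim=1)  # time goes on dim 1: (n_envs, max_steps, 3)\n",
"\n",
"traj = toy_serial_rollout([ToyPendulum() for _ in range(3)], max_steps=20)\n",
"print(traj.shape)  # torch.Size([3, 20, 3])\n",
"```"
]
},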
{
"cell_type": "markdown",
"id": "8a18c530-d8ac-4d00-bcd9-02e8350005f1",
"metadata": {},
"source": [
"### Closing parallel environments\n",
"\n",
"**Important**: the parallel environment should be closed before the program ends. More generally, even with regular environments, it is good practice to finish by calling `close`. In some instances, TorchRL will throw an error if this is not done, often at the very end of a program when the environment goes out of scope!"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "cb805b61-b29c-485c-b224-94ad8bdba05f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n"
]
}
],
"source": [
"parallel_env.close()"
]
},
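{
"cell_type": "markdown",
"id": "sketch-closing",
"metadata": {},
"source": [
"The standard-library `contextlib.closing` helper is one way to make sure `close` is always called, even if an exception is raised. It works with any object that exposes a `close()` method, so the same pattern applies to the `ParallelEnv` above. The sketch uses a hypothetical `ToyEnv` so that it runs on its own.\n",
"\n",
"```python\n",
"from contextlib import closing\n",
"\n",
"class ToyEnv:\n",
"    # stand-in for an env holding resources such as worker processes\n",
"    def __init__(self):\n",
"        self.closed = False\n",
"\n",
"    def rollout(self, max_steps):\n",
"        return list(range(max_steps))\n",
"\n",
"    def close(self):\n",
"        self.closed = True\n",
"\n",
"with closing(ToyEnv()) as env:\n",
"    data = env.rollout(max_steps=5)\n",
"# on leaving the block, env.close() has been called (it would also be called on an exception)\n",
"print(data, env.closed)  # [0, 1, 2, 3, 4] True\n",
"```"
]
},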
{
"cell_type": "markdown",
"id": "0fcdb552-3c82-4be2-b5d2-20d87362669b",
"metadata": {},
"source": [
"### Seeding\n",
"When seeding a parallel environment, the difficulty is that we don't want to provide the same seed to all environments. TorchRL's heuristic is to produce a deterministic chain of seeds from the input seed in a Markovian way, where each seed depends only on the previous one. This means the chain can be reconstructed from any of its elements. All `set_seed` methods return the next seed to be used, so one can easily keep the chain going from the last seed. This is useful when several collectors each contain a `ParallelEnv` instance and we want each of the sub-sub-environments to have a different seed."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2c10bc47-c386-4c00-b07e-97ee1db28316",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3288080526\n"
]
}
],
"source": [
"out_seed = parallel_env.set_seed(10)\n",
"print(out_seed)"
]
},
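{
"cell_type": "markdown",
"id": "sketch-seed-chain",
"metadata": {},
"source": [
"The seed-chaining idea can be sketched in plain Python. This is *not* TorchRL's seed function: the hash-based `next_seed` is a hypothetical choice, and the 32-bit bound is an assumption. The sketch shows the property described above. Each seed determines the next one, so the chain can be continued (or rebuilt) from any of its elements, and a second batch of workers can pick up where the first one stopped.\n",
"\n",
"```python\n",
"import hashlib\n",
"\n",
"MAX_SEED = 2**32 - 1  # assumed upper bound for seeds\n",
"\n",
"def next_seed(seed):\n",
"    # deterministic: the next seed depends only on the current one\n",
"    digest = hashlib.sha256(str(seed).encode()).digest()\n",
"    return int.from_bytes(digest[:8], 'big') % MAX_SEED\n",
"\n",
"def seed_workers(seed, num_workers):\n",
"    # give each sub-env its own seed and return the seed to use next\n",
"    seeds = []\n",
"    for _ in range(num_workers):\n",
"        seeds.append(seed)\n",
"        seed = next_seed(seed)\n",
"    return seeds, seed\n",
"\n",
"seeds_a, last = seed_workers(10, 3)\n",
"seeds_b, _ = seed_workers(last, 3)  # a second parallel env continues the chain\n",
"print(seeds_a, seeds_b)\n",
"```"
]
},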
{
"cell_type": "markdown",
"id": "52c84cdb-f024-4c88-a462-7f50524d80ac",
"metadata": {},
"source": [
"### Accessing environment attributes\n",
"Sometimes a wrapped environment has an attribute of interest.\n",
"TorchRL environment wrappers contain the tooling needed to access such attributes. Here's an example:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3f317630-6ee7-42b4-89ee-01f5bee14f5e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n"
]
}
],
"source": [
"from uuid import uuid1\n",
"from time import sleep\n",
"def env_make():\n",
" env = GymEnv(\"Pendulum-v1\")\n",
" env._env.foo = f\"bar_{uuid1()}\"\n",
" env._env.get_something = lambda r: r+1 \n",
" return env\n",
"env = env_make()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "16d5621a-14e3-427d-b3d5-1ffa497a117d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'bar_542ef942-3257-11ed-b93c-aa665a2328e0'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# goes through env._env\n",
"env.foo"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1ddbb4f0-7418-4b91-97dc-808c0cb268af",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Aargh what did I do!\n"
]
}
],
"source": [
"parallel_env = ParallelEnv(3, env_make) # -> creates 3 envs in parallel\n",
"# env has not been started --> error:\n",
"try:\n",
" parallel_env.foo\n",
"except Exception:\n",
" print(\"Aargh what did I do!\")\n",
" sleep(10) # make sure we don't get ahead of ourselves"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "702838c3-7aa1-4af3-974f-e312629940e6",
"metadata": {},
"outputs": [],
"source": [
"parallel_env.start()\n",
"foo_list = parallel_env.foo"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "965b8d5f-f549-4e38-80a9-fe82a571b209",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"foo_list # needs to be instantiated, for instance using list"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "15e0e5f5-d1f9-4f55-8437-8e393b4754f5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['bar_5cdf70ee-3257-11ed-acfd-aa665a2328e0',\n",
" 'bar_5cdf70da-3257-11ed-8393-aa665a2328e0',\n",
" 'bar_5cdf7102-3257-11ed-8191-aa665a2328e0']"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(foo_list)"
]
},
{
"cell_type": "markdown",
"id": "da844a71-f313-4e42-b352-0ec54a1e3b58",
"metadata": {},
"source": [
"Similarly, methods can also be accessed:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "bce02ca2-b0fc-47fb-b57d-125410d8979e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1, 1, 1]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"something = parallel_env.get_something(0)\n",
"something"
]
},
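{
"cell_type": "markdown",
"id": "sketch-getattr-forwarding",
"metadata": {},
"source": [
"The attribute and method forwarding shown above can be sketched with Python's `__getattr__`, which is only called when normal attribute lookup fails. The classes here are hypothetical and simplified. `ParallelEnv` returns a lazy object for attributes that has to be materialized with `list`, as seen above. This sketch instead builds the list of values eagerly.\n",
"\n",
"```python\n",
"class ToyWrapper:\n",
"    # unknown attributes are looked up on the wrapped env (like env.foo -> env._env.foo)\n",
"    def __init__(self, env):\n",
"        self._env = env\n",
"\n",
"    def __getattr__(self, name):\n",
"        return getattr(self._env, name)\n",
"\n",
"class ToyParallel:\n",
"    # attribute access is broadcast to every sub-env\n",
"    def __init__(self, envs):\n",
"        self._envs = envs\n",
"\n",
"    def __getattr__(self, name):\n",
"        values = [getattr(env, name) for env in self._envs]\n",
"        if all(callable(v) for v in values):\n",
"            # methods: return a function that calls each sub-env's method\n",
"            return lambda *args, **kwargs: [v(*args, **kwargs) for v in values]\n",
"        return values\n",
"\n",
"class Raw:\n",
"    def __init__(self, i):\n",
"        self.foo = f'bar_{i}'\n",
"\n",
"    def get_something(self, r):\n",
"        return r + 1\n",
"\n",
"envs = [ToyWrapper(Raw(i)) for i in range(3)]\n",
"par = ToyParallel(envs)\n",
"print(envs[0].foo)            # bar_0\n",
"print(par.foo)                # ['bar_0', 'bar_1', 'bar_2']\n",
"print(par.get_something(0))   # [1, 1, 1]\n",
"```"
]
},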
{
"cell_type": "code",
"execution_count": 20,
"id": "cefbe2dc-9906-4afc-950a-3039f8eebdca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n"
]
}
],
"source": [
"parallel_env.close()"
]
},
{
"cell_type": "markdown",
"id": "521d423e-6468-4ea6-b1b6-ac4befca8d05",
"metadata": {},
"source": [
"### kwargs for parallel environments\n",
"\n",
"One may want to provide kwargs to the various environments. This can be achieved either at construction time or afterwards:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "3787fd8a-dfee-4006-8870-6be019d8dfc3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"A.L.E: Arcade Learning Environment (version 0.8.0+919230b)\n",
"[Powered by Stella]\n",
"A.L.E: Arcade Learning Environment (version 0.8.0+919230b)\n",
"[Powered by Stella]A.L.E: Arcade Learning Environment (version 0.8.0+919230b)\n",
"[Powered by Stella]\n",
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATkAAAChCAYAAAC8o8hrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlzUlEQVR4nO2dW6wkx3nf/1/PzLnvLndXyu5Gy5tD0gEVwZGj2HREBIIFAY7sSH4QBClCQCcC+OIEMmDAppy3IA92HmwJQRCAkGLoQYDkyIYkC0YMm5IixEBora7OakmLlHfDpbjcJbXLc/Zc5tZfHrpnpi9VXZeu7q6ZU7/F2ZnurstXX1V9XVVfdQ8xMwKBQGBViboWIBAIBJokGLlAILDSBCMXCARWmmDkAoHAShOMXCAQWGmCkQsEAitNLSNHRL9ERM8T0QtE9JQroQIBILSvgBvIdp8cEfUA/C2A9wC4DuCbAD7MzD9wJ17guBLaV8AVdUZyPwfgBWb+ETOPAHwOwPvdiBUIhPYVcEO/Rty3AHgpc3wdwM9XRSCi8HjF8eU1Zn6zQXjj9rW51ecTJwcgG+kM6fcibK33ERGhH8nHCsPJFHtHI2CFWj6DAJwB8xaANQC93FXMayAG0Y9BOGhABuTq+earR9L2VcfIaUFETwJ4EgAiEE7hRNNZBjzkNnavNZFutn3tnBzgQ0/8g3KgtEcUO4aQWaCKwAzg3MlN/Oz9b8b2+gCnt9cRMzCZxgAl0SIi9CLCizffwP967mXEXJ2ml0jkZe5jMv0w4vhnwHwfgJMAjgBMkEwOIyTGb4jB4D+iF32nnKaGLuZBZHJkTv+X/3xZ2r7qGLmXAdybOb6YnssLwvw0gKcB4E10mt+Hd9XIMrCs/CG+aBrFuH2dO78pHi9R7qMaKnxWBMlyOJrgxhv7ICKs9XvYWuvjzPa6OqIRplayplXVjhojoucBeg3ADsDriPmnAAzkaWqkTaUv4uuqQXIdI/dNAA8T0YNIGt+HAPyrqggRCJvIVzyB0Ecv/Zan/Rtf3RxbboQta0iUW4MSGLcvmVxFGTk9YyW3pMDj6RRvHIwQRYTNQR8RASi09fqYStxW22CAXkdE18F8D5i2Ab4IoZFTp4RkQky5c7K6hOC4iLWRY+YJEf07AH+OZFL+35n5clWcCaa4hdu5cwP0cRan0MvN6xOowS4kTrluXm03wnL8Js2eqD6aysumfcnkKk55cuatSmEGPWrvaIznXrmNfi/Cqc01XLhnGxdObeuKu+RM0Yu+iSj6FpjvBfNZMD8A5kL5NXSdXM7XD2UObdpbrTU5Zv4zAH9mEicqdZIqsZvrxD6PnxgAb20g3lgrXYv2DxENx9K4ruRq5iZgKINF+5IiE92s+SUI5kdxzBhOpoiZMZ7GmMSxoYCF9IV56y5oKcI5b8AM4BDAHoB9AJsABOW30bWOnIr5auOOh3xmfbwZZ3LnCEBksJPFh3XbNmQ4euuDGP70/fmTzNj6P/8X6y+Wlqac44OerTBxMlQnIccmYROBpOF0F7QU4eqsqDjwEjupm2wiisRaNXIEpOtvugVdEXeURXheGyDe2kjG6kQAMxDH4H6rVVaUCqrydV5jWSeDxiKizRpPHbmkmVZfaAiNkZ5IOTUNnaiEJUkkdUeahi1LZ8+u6snYfIWb1VeLm52mMWg8AZjBUWLkaDoFsXga1I5k6vro+paU04NIGKo8bAflSK0tHIz0HFHKUlF3XPisosthgQUL8159z9O/I+qsbOiF1rmuHz66e4j+a3fAgz641wNNpsB0CjoaOci5YgSZu9L5uMwYqbSyoij2wylHFylr/R7O7mygH0U4sTnA1lq/sfmy7oqcfQBTIjBfAPPDYD4Pxj0QeZYrszWsH5MBnedGrlhCEnwTYdrl8y5ru1TcsvH8Nay/eD0VIp0jMECjstNBJr8cedhqo748Rq+kE5m1qiiO1ugi5e+d2sR73novQISIgB
5ZblFR5KNxWZ2vhWDVNT/AZPoRYDpCYlIiJM6H/A1Zf51TXj/ltq5uk54bObNOZtsFTZvjbIjcRHefpz2eJNNVDWy7k/kK23IYuIQqWettTmIkTzfcHY4xjdNdd5LE9oeT0pRKdxzdPOL81OuUDKI7ILqRXtmRpH8EwpGFRNV1V31cpjMjlxRGts1PVtn1x28mzaiLEV71NER/xKZTTvsVNnm9mY4pm8Kk7EKPXfFaIe7e0RiXX/4JelH1s6tH4wmKL/rRHUc3jzi/qlljUpQYUfTXiOgygA1AsMcVaTiil3LxRYkaOEqtNNS6kZsVKBFWNkpobppkNqnzobsuMJHHreRVI7r8d9VtqkmmzNgby/cQuqQ3iXAUx4goeU5VxngaJ4auFalaggnAGwAOwfPpqTAgCAcAtVMnMlo3cuYTBbvNJhbrzCuFbvn19DG7h5sY2fY1fTSN8fztvVbyovlaqdoJsJq/bbyH7HBFToxWdyUIaNXITTHFT7CbO9dDhB1soye9G+h1lCTUomPJYpl3cBOaMeA2qemUH0aTS1M527+VMBgTLYPiQO88/08/1WW7wyrcoc0Vx23KrRq5uzjEs/ib3LkT2MI/wVszD+5nZ/+mDgH1dMlk/m9Ot4bAdCRrt35mUj/+9Oq8JM2MM5Xp+aEKfRTyNuV4kz2cLwutolUjl4zatnLntrBZeJ5VzwSpPEB6I5km6XbCrFN+O+nE9SPeoO7PqmZVW2msRkw22nmJxuMhjnOpqqcyesK0auS2sYmfx9ty5xavWjLDRtdNbv0o51RlZppv7LZltZVKFK+dktanMflKCfuuiSJUedhQLs5p+dlVwpriHVO1d3NX5t80ed+xWpLmTEDdVF3VQ5vdukeEEwPPt34GWse7FqFrHvzE/8V5XZaxHtZ7PTx0SrYxNbDKfLXimndGLhCwhQAMKjbmBo4nwcgFVgofR5iBblkKI+fP4nX5YbTmcsl/6xI/pFBDAKji6YPaFB5DWha9tEVJL54oyNuxfXab5azDd79vXL3D3V0ubeRUDaf/F7db2KTiDbbiCF6M4UH/9YqSXmYWzwaHzaZ1I6cre7kBud1zpSsHFz5N4spT0IthE9ulSRGZdN06WMjRkSmQKaK8mU8P2aZCk7RqVo7p02G1nyYzLY8ovHmDqY5nUabWjZy6zO3c+bObOKryF921zbqt+X1fb2t0+ZEis1zqUl1PXY1y5lLJK1gcTnSVS2fymO1crTeyAaSvc3IVPofOVFOj/Cw8ECiBpFfyF1ThBHg4XZV35ybMX7ndUUU+rm/VduklZ3UffbGnWu9+TtaYGfvjCfbHE8TMUjGrp+DpAgkVzxQzMxVOlpBm3Dbj2Rhkue1KLpPorGhpSiJPJrCJGjsxcqa6a3oNRP+G7Hr3mF16pgMIW2z13uUq3DiOcePgCK8eHmEa680nZQaMMt9Lg5CZE0JR2Nxly8VNlgqpARnWh8TiKG/HWa9DJgJnvlPme0FEDWGUgaV04l31bQxgfoNtz7/qPnTzdClLRITNfm/+I2d5JJKJFFhYMygNQlIDp5oSKkcmKjkAUM0KlsZXlVt9enG9qI+sk0a11iMtn5uWtBRbSJrGXJXt+Vfdh15t+kQ4v7UBwEAvlgoUGTjR9hJtGyULVLeCa6YrKovQPtYYbTaJh2ty3ePZpocSvsvXJUSEKP2rvWfOYjFYZABc9OHstK9SpkIYF22lWBbjgWWTi+oaLKWRa1pXVa4HH2haPr9Lb0DdgogWJTXStDIAOmnqWE7ZNNuRHFVZS9NztahuWZ9eG7niDaCpzldenvZ7AliUr2m9dHwjtqM4x6oKWt4xUkpqjmqbhCm6BpQtMlLF0XJr6mVTQsNpMW9XupsMLLfgeG3kijeAprysbXkrXdG0vG3pvVF0Frtnl9POk+tD2esaWen2vUp/b1VGpCGIaZwKQ2RSHqM9ApkbyszbSipDX3Pu36mRa2JksFSjjQ45VrqXdJyigcl1WI
POZLJG1cQNyvW6m+i4Kp5R/hmva3FrjHSrTM0CdmrkTCpYt5xLNdroEF09mbQvb3WvuS1CNOKokbwYrTmxGbX07kAe2/JrG9aaDUtp5IjoXiL6GhH9gIguE9HH0vNniOgviOiH6efpeqJU67UJg9gEttOWNnFluJyMIFpsXzO01ho1GpxV+SscBqKpcmWGpseF89KRU9X5CrT0UbGm2dTar85IbgLgN5n5UQCPAfh1InoUwFMAnmHmhwE8kx7XYrFOXK+Y2SG07VpDnbxdhrNFVu6a+0rTNJLUHJWhtfY1w9Vao6Zfwzg9bcNjelw476oNZv0CNjS99qs0csz8CjN/O/2+B+AKgLcAeD+Az6TBPgPgV10J5eJ9IzKFeTulcoys3C7K7/J9MK20LwMrZOPEdKYNk6FMnWmm4yGTafmNddzmmhwRPQDg7QCeBXCOmV9JL90AcM4sa2+XqQO1sK9XJ+1L5Z2ripd6WEln20UTzVf3TlQ1HNfxBDTpLtfQD/HCs6qlx5rOCG0jR0Q7AP4YwG8w824uP2apuET0JBFdIqJLhxhmr+hLaUwwoNU0qR+7enXSvg6n9eZMMwORpiFdq1PM+ZVra7bozgtdz6ML+VcmW6EfFoWpO8fVQMvIEdEASQP8LDP/SXr6VSK6kF6/AOCmKC4zP83M72Dmd2xiXV+yWpSfCeD0/7qOgbbMZ938ZeV1uprmCGfta7Pi93u58jAjjPCrNEzlZddqttkn10D+ymR1vNmSMKV6UYyqdfqDjneVAHwawBVm/v3MpS8DeCL9/gSAL2nk1xriNSm91STVbCAbziXZ9OrMRhbxy+X1y7y12L6o8tCOJhuA6LzEUEtniKqpY1vyG1Cql6qKIr161HkLyTsB/GsAf0NE303P/Q6A3wXwR0T0UQDXAHxQIy0lWQPDYKeL3Lp05SH1xTMrI1sfLry0Ka22ryrybc9+xGKN6s4qMdTSJTbV2lsHDdhYxw5QGjlm/t8VsrzbrTj5jGwMXFuKW3Zs9JStD2fbD1puX1WQ5Ls2qVKt22Dbjdcyv3k0i/i1dWyB18+u2tCs4tp2aDSXX7gRNECN/WeVG3ObguxaWGNrjqZoCr9yRq4ps9DNgn1zL1UK/ucW0FRyl7MP3TVewCRgS2gqzTMjV0eLRTPkska6e/nSYgXMFS71VNdX3Samz74YWCjZcXakJAvHir7qSnVV+9YK8ojilUaapurUDmiSsF6mnhm5OqbEnR+xrLqux+X5/Ou1e5f+Vt9dJVlMy60psyJZ6dRO5RQwFEOJbj4SOZVeT205TQuscK9q0LmRc+XdlnvLzU2CD12yCjv5WHhkq2fd851jPeJwnL+3ClLQlfwO661zI+f6RlVOz3eT1RYkPGproNAZXW8UdK3oLG0Y8Cbl18lXdmxAB0ZOrya67jS+3ni7lku/XrqW9BjQtQFfEjowcstRE75K6atcZZZHUm/oekq4onQ+XQ0ElhWe/2cbuUDXU0KgVnm6spmqfIORCwRsmG39oPlh4YsCXwe6unJx7mPxHGkHlk4lckdGjiFeNTXZE+Non9Oxw/U+JdN6bAOWfFcEVYde/HxecauIrtvZd2+rYr9ckVK5Z0bfTO0a+jCo0wI6D+g3gMj2utw/o3M9YVWeddUvR9P7lHzQpsETkrJ9bbLgqofdVerrylupi6NySPUkiMa8jjh+HMynAXoDhH1E0fdBdFcSw0x5HRg5F2bFnWnyta2Z4rYcftVRPfJyqKWahWhAfl9Uoksj8or0u4HJ9F+C+SEQXQXRqyC6VjBy9sIsoXd12VrKMmL0RGNFGj4g3h+oCm/8IJ/OKoAvKtFFpxkYNhNWDQErhbGTYQkdD8vWUpaV461n49LLIkjW8IyWoFxiulZGsguF65o00qoUiS6hkVPh64qubwQ9tYJszcoyXm3Mlii143WGRjNeQSPna234RtBTDsc2f1l9/63L3cKqyAoaOV+bj28EPeVwbPPd+P7bp3W5XS
zRK1hBI+dr8/GNoKdWkHTCpVmT04zXGas5kvNV26vG8dazcelVC/MsPi3FlzU5ieOkdF2TRlqVf95VF5Pw490Bm8fVPjkfEL9HTxXe+MXzOvM8X1Sii04zMPau2j7yURFeIUMHm4Fd3KLc3eaWcfuSCLfl8KuO6mG3T64R+X1RiS6N7vfIJn6Efu9Pc088EL3hTJiOHusSjYF1xsXF8FUF1+v2y9buZOiXw1RvVeFt6rENsjIrylu4rNQOSx5Zkj0oITv29e7qqBxSPQmSIxqiFz2j0IdBnRboaE2OIH4YrkpwUXhVHoEypnpT1YlpPbYBSb4rgqpDLzpu8cF+WUSZOrtWkQyZfBJ5S+VO9WLy7GpV+uIAZspbQsdDIOABlPZnnh8WvijwdX1OVy7KfQCcRu3AeKtEDkYuELCE5v/ZRi7QtuET5VejPF0NTlX5LqF3tR18ldJXucosj6Te0LaV8HXK7Bhv30LSdRfxtf67lku/XrqW9Big+WLL407n01VX9SLffRNqPkG8X8y1/r2ja0PgWtFZTH1xNjQpv06+smMDOjdyuk4p3XTK8cxr3tsOm2Inn3i/mK2edc93ThuGQCd/bxWkoCv5HdabtpEjoh4RfYeIvpIeP0hEzxLRC0T0eSJa00mHJf9mV+1xZ/rL+uza7OXzr9feXA5tdOOqw7lqX/oyOHrfhiJZVoVr67UfuvlI5CxFt25GpgWuCq+XqclI7mMArmSOfw/AHzDzQwBuA/ioKoF9HOISLuf+LuNFjDFJQ9TpvklcLhy7wfghH2fMHjJyh0s96cbVCle7fZnJ4GifZVWyXLG1JDNCqmxbrqpesTFXtT+O5gGrw9USxDhhvUy1jBwRXQTwywA+lR4TgF8E8IU0yGcA/KoqnSFGeAEv5f5ewo2MkatPU6Pqomng2R8BTLQ4dpqr8Uu4tfFp9uSqfXmHgZ3s8ibqzD61jabSdB/r+gSA3wJwIj0+C+AOM8+s03UAb1Elso1N/FO8NXduDWtYw0BTDDXNPi2zSHly4U0YPfj3ge0N0Mlt9K6+gv4PXwINx8Bw5EiG5kri2VNFn4CD9tU5qVJtdDsfKbVZKYpHryqidSNvEc28lUaOiH4FwE1m/hYRvctYDqInATwJACewjYdwX2X4/BNqDDLUYlM6L940pqd3MHzkXtCZU6BzpzGIIkQv30LEDBqNkmf3GpLFBTayZevDVft22r5O1rtZ2j8dORMm92FO14v7ptHqGEi0ZyN1RnLvBPA+InovgA0AJwF8EsA9RNRP77YXAbwsiszMTwN4GgDO0VnlADP/hFo3ZkKk/PjUDibnzmAyHGF8cIT+7gG2/+r7mJ4/i8n95zG98ToOCYgunEV08j4MXnkd/ZdvGZVAt9K7uoFm68Nh/u7a1/nNWrM+46cjXVeELD3Jw/DF06XoqpcBtCV/BmMdO0C5JsfMH2fmi8z8AIAPAfgqM38EwNcAfCAN9gSALzUmpQViHw2XzosQrZHEWxsYnz+D0ZtO4fDUFuLhEGsvXMfg6o8RvXQTvHsX44gwPrWD0f0XML1nx1jmwpq1EBNDKCpv137iIq21L5nXsw6ue6lqfw5Vny5FV23/aEt+A5Re3MI1nXqs86ql3wbwOSL6TwC+A+DTNdJyTHmaS+n/tv7A3p09rD///3D2vh5O/bMBjm6fxd2bF0Fbm+id2MLmmSm2Hh5j/5U93L10G7037grT1aVuu5SV1+1ks1Hcti+JgSiRUY1USwr1NbZmZZpeQ/krk5UEyJ2WhJEaahGkVzwjI8fMXwfw9fT7jwD8nEn8Qmpo2hfqkuhwiOhwiJP3bePiA6fw2qktHGyenw+FT5xcx5nTA/Ru3sXo2uvO83dLkwbOvl6dtC/b7GdDgsz0rtLAVeRTZ82qEt3FrNn1hvLXMnACGUsGjosXDDCo545emgkswUhCyO6Phrj6xdsYjg4wPVwYs73BCKP1EYa3Rh1K5wMd16vt2lNm9FY5QlBNAe
ugYUDneddZZ9PNxwYN/czeQaft2RXJaSB3h0ZOjo1XtZyGuB7r1uvw9QmGr08AHAK4szif/vmCrNwu2rWL+mkVA1EbHlToZa6TWJ31DMeG2rT8xtnWlLPzZ1ezLEavbgwcUOsGYJ23y3C2yMotcqqYp71Y2VtGWPJpm46rNsWlL7IAlseF867qLzuLtsFVfcjwysiplhlcpNM0to6NNjFdu3aRjk8ovZKAVoOzKn+F4VGu5anu2Lp39Ozamq5B1UBLH4J0teqjBp1OV02GubbT92htDYMTp+Zn49EQ4703tGXsrW+gv3Nyfjw9OsRkf087vq/o6r5LZ54zdLx9xWODghiVu84000X+uvm2UH7tpaSaDatTI9dEh5g/Sk8ARYxoLcJgZ3O+yjk5iDE5YDCzcuGTGYjWexic2FhIS2NMj2IwA4izCfjQvTO3yai6fLXX5WLKZEdO0mwMyQI9UfnYZvuHyVqn6zXiWf51sZXLWP7iNpRMAlK91yygl46HGcXFcp3FcxpMMTgxxNrpA+w88iqitQi9tY359XjEGO9NgIgRDaYgYiDKGAcmcEzgKYHHPUQbEQY7vUX8yRjxeISjV05i/+pZTA8HmB7UfAuQIVUNsr8zRLQxxomHb2Ht7H4j+cejHvaeO4/x3gYmu+vgaU8dqW2ySlF0HC4auOJ1QXRZVioq+29VQjYLgKo4RQMjk6sCHQMv0us8HiH/84WSG1JlXSrw2sjZzNUpYvQ2xhicPsDOQ68h6se56/EowuBgDVE/RrQ+AUWcG/VwDHAcgccRpsM+orUpeptj4ajo6NWT4EmEqW0BLaFCTeen51P0t8bYvHgHWxfvGKSq33qmh30c/fgU4nEPk7traF0BOmR7n6JYJOno2aTmVO3/sqHKS1Y3E1UcR0NK1RRTpVflVpIaBg7w3MjJqNOoqMforScjOYoYoMJKKCEZ3fVj9DAB9bKjvJqZO6L2Bg5hOSpS9KTcxtSek6M8VNFI08kanShNnRFNIQyTZhau1+BEBrRuO7KMu5RGrhbEoH6c3uFZrDhiUASgF+ensp50dFsxOGvPs9/TjsAFez+/7Em5W8dkCpEi6s/O1t10loALYVyv11nZKws9usSrLSStQAD1YlAUVzeUKDGGpZFcK5hlZBQ6JvA0ykUqqYGRrElOj4l1s6xX0U1B1J+1tehwO4fLdEVlEQ4i626Ua4hOjFxrtkIAEUBR+jc7psJ1yoTLXm+tz5tlZLQWHVPiFWYqVER2h2h6PY6sKqvL+q1GNlRVBOXMIS8+Z4vmFjmKL0gqkmu2O2l8VbnVpxfXZ86DjH7mH7IZRJUcWrnq0YmRM62zgt6aE6R2BqYJKJuO1lmTXONxD9NhHzyNEoM3Z+Z5SbzL8aiPeNSz6mFejP+EShFLJgua3SFDme/zz+I2FFWOLLugkUaNUZJRfRTX0gSnhdEka5fZaTZnvhdE1BBGGViKh9PVcimcDKQEyimdInE4fUwlNHF/Fc9y7lh7xJBOVxMDVxzNpSOTdLoaTyJw3WFEy2QNU/6EJBxEWk6uZn1SBjZTTp32ZVsN1gu4FnEVNmmhz7wSJHZVnBApwglo3ciphWuoUwmSFebkPHvzcahodC83/aIjyQVORnJxOpLLb2aehUmMYDKS65fDaBajqynrXFounlCEE12l0pk8pkPqmh4I0zUv6zUyQM8ga5SfhAdya6hsx6pwAlo3crrClfWl91Zf/ZQ5fyjPuCbm41DdhV4beBqBJ8lIjrPzh6xa5o6HqNxRNAVZBOvI3MnktCyP1lqS3cBcG1Mvd22vuGl5DNb4lGnJ4lmUycPpakJ5COvq5T6iRQNJxh3Bgm+uiMc9TEeZkVxxGsAEjtOR3LBfXpMzFskDhWapM+0r3Bf9dbB0Q0kvdUauDpuNt0YuS93y2i3vdwcJvtUhV76YwJPEwM1+1TXvZE0MH8dRutUkI4PvP0FWA602ULgv5vaRSyYGq4poeW2mj6pZqTQdN9
IIWQojVxfV1GlF+y0WY+AZVFqTK9ktRupdTcNln4pbsV3Bls5OcSTJxGBVERoyi8JrOR2MUilzLIzcgrI3UarZlbgll6eaPI0W3lWRx5lp4XyYRos0OPO3LCg86tbOThbE89whY4q2nEXvvFHkfDylwbNUnndGrvGpZdUanOrWrszcVLr2m/xinxwttpHkRKLMSK632EsnCOolijoUVr9pNZAgnqgdKbyNIueaHnInHBumIzuq9CZXlJukkavF0FoalylPUdxWn11lMI4a/iWECBPE8Qg8HePgKAb1YnUkCw5GUxzGI0yYMIE/rxqa8gj9eISD0RR8mC87x4SD4RTj4RQ8jDEYxOhFeR1NjmLEI8LhMEY8JvSOYvQLb3KZDmMcTicYxiOMeAD26TUkkt0Jlf3O2HinKcriCVziDOQfsBdmqyuI3AlHhulU5s7IP+CvcvXn9uUYKFURVJhabppcbeVaNXJ3sIcvJ7841xg0YtBPYkR3JxjcOGps9DE92sPk7svJ1M+jAXG0F4MOYvT/aohobVK6Pj3cBU8iRFeTN6wUn+GdPfY1Hd4CmNC7Oi7dKDgmTHb3EE964GnPrymYoEe4bwLVKUq7OFVdVLOIqkokuV5rW15FZHW6bjUuHVVqejhaNXJjTPBj3Go2EwYwSv/q/b6zghGAZl5KWYtJ+icdMKc/maj7Bvdd2QWffpssg0n/mnWUmpt0dZIklE+aZqvvdTcbzwnlSQtR9DU1oDK7RA0y92cIEgi0zcz3ojFdMkhSuMYkDesCR0NpbZkNdxIpxdOsB1uCkauAiDAYDLC9vY319fXctcFggJ2dHayvr4OIQCu2veI4UOkhze7/qopb9C4KRmuVCbmghXSL5Sg9CVOxT7C4RUmcaHNuuOP30kxN1tbWsL29jUceeQSPPfYYLl++jG984xvz629729vw+OOP48qVK7h06RIODw9xeHjYocSBEhmDM/+aOUeZYEBhakaS2VPR51Bc/2tyPdB0juhwTlnyfkrcoVW+By7Gy+xikuk5l4xlecJITkIURej1etjc3MSZM2ewvb2dG61lz/d6PURRUKV3iDyZkk4i2kpC5VPanaxiF4k9ph3coYU1Lo9kLa94Uipi0WgK0tQljOQkDIdDTCYTfPvb38aLL76I/f19jEaj+fXvfe97uHbtGvb397G7u4s4bmarSsAOPf9jQpUBlE1VK/tbZsHeLL7T5XwNxPmJzuZ0pRCzKr7oZPaaSb3pEoycBGbGZDLB7u4udnfLLsa9vT3s7S3/j0yvLlXdod5PASljKjyo8vhtr+uK81MaZoWYdUqh3iIjejaiOscwxwqsJCUjJvMAVMy/Kh0T2nJYoMin9lNBHZWjMlvN+infnNSSaY3kiOgeAJ8C8I/SbP8tgOcBfB7AAwCuAvggM9/WSS8QyOKqfcUARlOPnr4IeIHudPWTAP4nM3+AiNYAbAH4HQDPMPPvEtFTAJ4C8NsNyRlYbZy0r6PpFM/dCUsIgTxKI0dEpwD8cwC/BgDMPAIwIqL3A3hXGuwzAL6OYOQChrhsXzEzDiblkVzby/nmdC1h1/lXU1c6nTW5BwHcAvCHRPQdIvoUEW0DOMfMr6RhbgA4V0OOwPGl8fZVt/tK15Kc7Q/J79hr+llgLn1zZOAkgtctT13pdIxcH8DPAvhvzPx2JA9sPpUNwDz7jSeBgERPEtElIrpUU9bAauKsfU2OYlkwAfpdT7WXyx2kmWw91wMJvjmhag+iNu7rT8fIXQdwnZmfTY+/gKRRvkpEFwAg/bwpFIX5aWZ+BzO/Q1uqwHHCWfvqb2Re8qmkma7XDirZ/Zl6muvNff0pjRwz3wDwEhH9dHrq3QB+AODLAJ5Izz0B4EvauQYCKe21r8JDk4a9zx+zsVwY643n/2VP1ELXu/rvAXw29Xz9CMC/QWIg/4iIPgrgGoAP1pYmcFxpqH0JnmuYPUDZotXye1lfTavyU/ofZ77XlETLyDHzdwGIpp
vvNs4xECjQXPtadAi9PtNMd15mAwc0Kb/glS2ie1JNScITD4FjQdWzk4pQgcagykPJKWOCkQsEAkuNatWOuPT2u+YgoltItgi81lqmZrwJ/soGLLd89zPzm5vMfAnaF+B3HfosG2DZvlo1cgBARJd83U7is2xAkG9ZZKjCZ/l8lg2wly9MVwOBwEoTjFwgEFhpujByT3eQpy4+ywYE+XTwQYYqfJbPZ9kAS/laX5MLBAKBNgnT1UAgsNK0ZuSI6JeI6HkieiF9CWKnENG9RPQ1IvoBEV0moo+l588Q0V8Q0Q/Tz9MdythLXz/0lfT4QSJ6NtXh59PHoLqS7R4i+gIRPUdEV4joF7rUXWhfVjIei/bVipEjoh6A/wrgXwB4FMCHiejRNvKuYALgN5n5UQCPAfj1VKankLyR9mEAz6Dw2p+W+RiAK5nj3wPwB8z8EIDbAD7aiVQJs7f5/kMAP4NEzk50F9qXNcejfTFz438AfgHAn2eOPw7g423kbSDjlwC8B8lvC1xIz10A8HxH8lxMK/IXAXwFyRMurwHoi3TasmynAPwd0jXdzPlOdBfaV2hfVX9tTVffAuClzPH19JwXENEDAN4O4Fn488bjTwD4LSS/zwIAZwHcYeZJetylDn17W3RoX+Z8AsekfR17xwMR7QD4YwC/wcy5H1jl5JbRuvuZiH4FwE1m/lbbeWtS622+x4nQvqxw2r7aMnIvA7g3c3wxPdcpRDRA0gA/y8x/kp7WeiNtw7wTwPuI6CqAzyGZUnwSwD1ENHs9Vpc6rPU23wYI7cuMY9W+2jJy3wTwcOq9WQPwISRvfu0MIiIAnwZwhZl/P3Op8zceM/PHmfkiMz+ARFdfZeaPAPgagA90KVsqn29viw7ty4Bj175aXEx8L4C/BfAigP/QxYJmQZ7HkQx3vw/gu+nfe5GsTTwD4IcA/hLAmY7lfBeAr6TffwrAXwN4AcD/ALDeoVz/GMClVH9fBHC6S92F9hXal+wvPPEQCARWmmPveAgEAqtNMHKBQGClCUYuEAisNMHIBQKBlSYYuUAgsNIEIxcIBFaaYOQCgcBKE4xcIBBYaf4/bV59SsKqNxoAAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from torchrl.envs import ParallelEnv, TransformedEnv, ToTensorImage, Resize, Compose\n",
"from torchrl.envs.libs.gym import GymEnv\n",
"from matplotlib import pyplot as plt\n",
"\n",
"def env_make(env_name):\n",
" env = TransformedEnv(GymEnv(env_name, from_pixels=True, pixels_only=True), Compose(ToTensorImage(), Resize(64, 64)))\n",
" return env\n",
"\n",
"parallel_env = ParallelEnv(2, [env_make, env_make], [{\"env_name\": \"ALE/AirRaid-v5\"}, {\"env_name\": \"ALE/Pong-v5\"}])\n",
"tensordict = parallel_env.reset()\n",
"\n",
"plt.figure(figsize=(5, 10))\n",
"plt.subplot(121)\n",
"plt.imshow(tensordict[0].get(\"pixels\").permute(1, 2, 0).numpy())\n",
"plt.subplot(122)\n",
"plt.imshow(tensordict[1].get(\"pixels\").permute(1, 2, 0).numpy())\n",
"parallel_env.close()"
]
},
{
"cell_type": "markdown",
"id": "3b7d913e-4456-4ba5-84c1-eb78f0d58933",
"metadata": {},
"source": [
"## Transforming parallel environments\n",
"\n",
"There are two equivalent ways of transforming parallel environments: in each process separately, or on the main process. It is even possible to do both. It is therefore worth designing the transforms carefully, for instance to leverage the device capabilities (e.g. transforms on CUDA devices) and to vectorize operations on the main process where possible. In the following example, `ToTensorImage` and `Resize` run in each worker process, while `GrayScale` runs once on the main process over the stacked batch."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ccd43ff0-b866-4d21-8f7d-4f5e53a051ba",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"A.L.E: Arcade Learning Environment (version 0.8.0+919230b)\n",
"[Powered by Stella]\n",
"A.L.E: Arcade Learning Environment (version 0.8.0+919230b)\n",
"[Powered by Stella]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grayscale tensordict: LazyStackedTensorDict(\n",
" fields={\n",
" done: Tensor(torch.Size([2, 1]), dtype=torch.bool),\n",
" pixels: Tensor(torch.Size([2, 1, 64, 64]), dtype=torch.float32)},\n",
" batch_size=torch.Size([2]),\n",
" device=cpu,\n",
" is_shared=False)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATkAAAChCAYAAAC8o8hrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAASz0lEQVR4nO3da4wdd3nH8e8zM+e2F++u7cQxtsG5OSSQ0lQRhQalCIoKIQWEUgRBVaARfgNtQFSQUPVF1b6ANxBetJVcaBsqqkC5JYpQI+okSAgRSJoIkjiOTRJjx/Elsdfe27nMzNMX59jZtffq3TmXOb+PdLR75sw585zZZ3/7nzOXNXdHRCSvgk4XICKSJYWciOSaQk5Eck0hJyK5ppATkVxTyIlIrq0q5MzsvWa218z2m9mda1WUCKi/ZG3YhR4nZ2Yh8BzwHuAQ8CvgY+7+zNqVJ/1K/SVrZTUjubcC+939eXevA/cCH1ybskTUX7I2olU8dwtwcNb9Q8AfLvaEopW8zOAqFim9aoKTr7j7RSt4yor7KyoPemlo/YWUt2JJGdaPTDAQ1BkJEgybd74X60NMv1LBcnZiURoBBm7NrwsJGmBp9vVMv3powf5aTcgti5ntBHYClIMh/mj0lqwXKV3owRP/eiCL153dX8XBMa7+s89lsZjznNoBt978U64beJH3D0wSk3A8qRECoRkDFjIUlPnUwRt4YtfvEcRtKas9DGpjRlyGtAgeOpbOTToPmtMqR51oOvuSHv+Pzy/YX6sJuZeAbbPub21Nm8PddwG7AIbWb/NTf7JjFYuUnvXdFT9jxf01uHFbx8ZLT9Wdfz76XkphzKbiad5cOcSHBsc7VU7bBTUIEvCgObpLi7boCK+dVhNyvwKuNLNLaTbfR4FbF3uCGySl8995GnbJ2pBusuL+6qTjyTBPHn8dxShhy9AQA2EN+iXk3AiS5qZpGoIFkDq9H3LuHpvZZ4AHgRD4N3d/erHnBIlTGp+7ge4R1IZDXEfsySwX0l+d9LPJHcQPbaRahqObN3LsqmH+euzZTpfVNtE0RNNOUgQPjaRA1/xOr+ozOXf/MfDjFT3pnHR365K47zKNISOuzJ1mDoVJJ6x2pqZ2u6D+6pCZpEBh0rHEqE8HzDQKnS6pvdLmDgZLuyfczsh8x8NsaWRUR8PzpnfbSuk4g/FrY66+6tCcyakbz//i9Yzu7VBdIj2orSEHCrRlK6ZsqkwQWUpgKakH1NKQ357/N0JEFqHI6VaJMdEoEXtAwVJiD6gmBcjZ8VYiWWv7SE6WwSGYDNn36kVUig2iIKWehDTikKhqKOm6z5bSOKcvhbSYkl5cZ9u6kwR9NIZIC5CUIC0YaUjX7FkFhVzXGnvGSPaNUQNqs3JtsKaA60Y7R5/hfbc+BUDZUoYDo2B9cnaPOfUxo+6tZGudCdEtZ3ko5LpUUIeg3iVdIucJ6sYT49s4HZdJeY6A+c9d2jt+cZsraw9Lmgf/0gBPzh+2Ga/tbe00hZzIBagcc/buvpw9hcu4v7jwKbWF08ZQkr8/VtF081CmpXYkBo321LMYhZzkRhA7lRNJW5aVTBnRdICHrZPVFxBVndKppCtGNGspjey1E/QXESSOtedHsiCFnORGcHqGgZ/8uk0LC+DMgezBIsOZNIWkw7/lWVnsfZ+Rdj7d2xpylkBpYu6bTkOIK8GSfxFEluLupNU+OR1Elq2tIRfONBj+zfE503ywzPjV60iKSjkRWXvt3VwNAnywPGdSUil01TE1IpIvbQ25uBJy8k3r5k40XWpJRLLT3pGctfbKiGTAChHRRZd0ugzphMMLP6S9q5IbyVCJkzdu73QZ0gn3LvyQQk5ywwO0A0vOo5CTXNGlvORcagkRyTWFnIjkmkJORHJNISciuaYdD5IrlkJpIsHNqA8F2hEhGslJvkQ1Z2jPCYb3niSI83cdN1k5jeQkV9IQ4vWDrUvTdroa6QYKOcmVpGicum
JgWRd0lP6gkJPc0edwMpvaQURyTSEnIrmmkBORXFPIiUiuKeREJNeWDDkz22ZmD5vZM2b2tJnd0Zq+3sx+Ymb7Wl/Hsi9X8kb9JVlbzkguBj7v7tcAbwM+bWbXAHcCu939SmB3677ISqm/JFNLhpy7v+zu/9f6fgLYA2wBPgjc05rtHuBDGdUoOab+kqyt6DM5M9sOXAc8Cmxy95dbDx0BNq1tadJv1F+ShWWHnJkNAd8HPuvup2c/5u4OzHs2tJntNLPHzOyxuDq1qmIlv9RfkpVlhZyZFWg24Lfd/QetyUfNbHPr8c3Asfme6+673P16d78+Kg+uRc2SM+ovydJy9q4a8E1gj7t/ddZD9wO3tb6/Dbhv7cuTvFN/SdaWc4L+DcBfAL8xsydb074EfBn4rpndDhwAPpJJhZJ36i/J1JIh5+4/Y+Erc717bcuRfqP+kqzpjAcRyTWFnIjkmkJORHJNISciuaaQE5FcU8iJSK4p5EQk1xRyIpJrCjkRyTWFnIjkmkJORHJtOSfoi4hkxgNoDBkegiVgKUQzjiVr8/oKORHpLIPGMCQlCOoQxBDWWbOQ0+aqiOSaQk5Eck0hJyK5ppATkVxTyIlIrinkRCTXFHIikmsKORHJNR0MLCKd5VCYgGj6tTMe1upAYFDIiUiHWQrF057Z62tzVURyTSEnIrmmkBORXFPIiUiuKeREJNcUciKSawo5Ecm1ZYecmYVm9oSZPdC6f6mZPWpm+83sO2ZWzK5MyTv1l2RlJSO5O4A9s+5/Bfiau18BnARuX+oFwmrK2J7JObeRF6sEcXYHAkrPWHV/icxnWSFnZluB9wPfaN034F3A91qz3AN8aMnXqTUIDxydcyscPkWwhqdwtJU1bx60vpcLslb9JTKf5Z7WdTfwBWC4dX8DMO7ucev+IWDLUi+SVopUr902d1oxIO3Bk8umthjTl9cpDde4eGSSgy9cxLpnI8KqE810urqeczdr0F8i81kyXszsZuCYuz9uZu9c6QLMbCewE6A4OMbklh7+aOXMaM2htj7luisPcP3o7/jT4d/wN8Gfc+zAFiw1qPrZ+WRxa91fIudazhjqBuADZnYTUAbWAV8HRs0sav213Qq8NN+T3X0XsAtgcOO2nv21r603pjc50YxRPA2lE8ZTP7+CJ1+3jZ+//jIOvLSRgRSmL3FOXZ1QeSli6GDPvt12Un9Jppb8TM7d73L3re6+Hfgo8JC7fxx4GLilNdttwH2ZVdkF4gokm2vUNiQ0hiCagZG9UPptmT0HL8FOFJrzjaRcsv1V6mNphyvuDeovydpqPg37InCvmf0j8ATwzbUpqTsVx4Fny0xfXueGm59i7/jFHHp5PaXBOpcMT1HeHDNYqPP0wc2c/OUmBsdB26ur0lf9JdlZUci5+yPAI63vnwfeuvYldafClFOYgpmrU/5u84M8NHIZ91fecvbxa0cOc+PQs3zm2K0MP69wuxD93F+SnR7cr9lZlT1l/njmc1g1pDD52nEjvy5dzn9W3kH55RCN4ES6h0JuhQYPO4OH51ttrYPmFHAiXUXnropIrvXHSO7csxFWOtha7fNFpGNyH3JxBaob7GxQRTNQftWXHVSNIaM2ytnnFyahdFIpJ9Ir8hlyBt46pzQpG40Rx4NmMHkYUBwH8+Y8C75EK8fiCnOPefOAwkTzcTszuZsyr/WePFj8/a1qEd363kXmkbuQS4tQX2dUNziFq05TLjZ4faV69vHT1TIndwwSBClRISEIHLPmDSBNA9whjkOSOKRcqfOGdRNnnz/dKDBVKzJxeJih5yOiGShMdslvukF92EgqMLmjzvDGqUwWU6sVCPYMUTgNxVNOEC/9HJFOyV/Ihc3RV2N9wk3b9zIQ1Oc8PpmUODK6jmKQsLE0ScESCrP+k21CQC2NmIpLnKxXGClU2VQ6fd5yHuBNNI6Mdd0veFKCeAC2bD3BjZv2Z7KM4/VhHjp8LUE9gIml5xfppNyF3FIKlj
BSqFIImuEW2PmjsIIlVMI6FGl+FZGe1XchFwUplbBOYD5vwIWkYEFzdBfW54zyRKT39F3IhaRUwsbZ7xeahwCiBR7vZbU0IiE4O4qdbx0kBMRp8xDKKEgXXE8ivaDvQg4WDreVztOLmgEWEgbpvCPZhIDUjYaHpB4QBbUOVCmydvoy5PrZRKNMLY0YLQA0IJgb6GcCbiIu00hDSkGDMMhn4Et/0GldfaaWRlSTiIYHJPP8+FM3UjdmkgL1NJx3HpFeog7uIwkB03GBU/UKM0mBOA3Pm6fhIbU0YrJR4lStMu88Ir1EIddn6mlELY5IPSClOWo7V+oBjSSkroCTHNBncn1mqlFkql6gmkRUwhDOybHYQ2IPmGwUqcfNzVqRXqYO7jNJ2hylxWk47ygOIHFrzpMEpAo56XEayfWR1I2ZRoFqtTmSO7u5OivrGmlIPY2YqRdoNJqjOpFepg7uM3EStC5CYCTnjOTOHCOXupEkAWkSnDePSK/RSK7P1OOQpBpRSyLieQ4RSTHiNKBRj0jqofauSs9rb8g5hPVsL0vkoRHWIJgO+N3UGOUwm8uETE+VKdcgqEPYWP5FODNlENaMtAAnpyq8OL1hzsP1NGRmooxNhrwyOQhANYnOrqMU49XqIJP1IvFkAasHHJ4aOW8xJ2oDhNXWeq5n/zMVWY22hlw0UWP9Iy9mu5AwhGIBr5SY+OFWJjLaIL9isk5wahxrxFBvZLOQC1Eq4lFIsnuI45Xtcx9LYcdkFWskpAMFPBphIhybs44sdoZTZ3R6EtxJhkY4Xhib+zqJc8Wrr2C1RvO9J7qIgXSvtoacN2Lil4+0bXlZfuDoQLf/as/3/p3XBp1n/r/YfNIl5un29y5yhnY8iEiuacfDQsywMCQYGMBG1uHT0yQnTp59OBgaIhgdwaenSU+dxpMEXJ9NiXQbjeQWEI6OEuy4jEOfejOX//Ao+75wFeHoKOHwMOHwMC9/4lqu/NERnrtrB8FVlxNturjTJYvIPBRyCwkDvBhRH3E+PPY48YYGhEFzx0YYUl8HHx57jGRDg7QYQaRBsUg30m/mAtLxU9hMlcu+dTH/8NNPcvWRCdLxU3ja3CS99L9e4u9/cTtvPDoFLxwk6aY9rCJylkJuAR7HeBzD/heI9r9w3t7E+IUDRC8cyOn1g0XyQ5urIpJryxrJmdko8A3gzTQPs/pLYC/wHWA78CLwEXc/Of8riCxsrfrLUoiq2sMtcy13c/XrwP+4+y1mVgQGgC8Bu939y2Z2J3An8MWM6pR8W5P+CieqjO7el3210lOWDDkzGwFuBD4B4O51oG5mHwTe2ZrtHuARFHKyQmvZXx4nJK+8mlWp0qOW85ncpcBx4N/N7Akz+4aZDQKb3P3l1jxHgE1ZFSm5pv6STC0n5CLgD4B/cffrgCmamw5nufvsUyLnMLOdZvaYmT3WQP/DU86j/pJMLSfkDgGH3P3R1v3v0WzKo2a2GaD19dh8T3b3Xe5+vbtfX6C0FjVLvqi/JFNLhpy7HwEOmtlVrUnvBp4B7gdua027Dbgvkwol19RfkrXl7l39K+DbrT1fzwOfpBmQ3zWz24EDwEeyKVH6gPpLMrOskHP3J4Hr53no3WtajfQl9ZdkSWc8iEiuKeREJNfM23ihRzM7TvMQgVfattCV2Uj31ga9Xd8b3P2iLBfeA/0F3f0z7Oba4AL7q60hB2Bmj7n7fJ+/dFw31waqr1dqWEw319fNtcGF16fNVRHJNYWciORaJ0JuVweWuVzdXBuovuXohhoW0831dXNtcIH1tf0zORGRdtLmqojkWttCzszea2Z7zWx/6yKIHWVm28zsYTN7xsyeNrM7WtPXm9lPzGxf6+tYB2sMW5cfeqB1/1Ize7S1Dr/TOg2qU7WNmtn3zOxZM9tjZm/v5LpTf11QjX3RX20JOTMLgX8C3gdcA3zMzK5px7IXEQOfd/drgLcBn27VdCfNK9JeCezmnMv+tNkdwJ5Z978CfM
3drwBOArd3pKqmM1fzfSPwFpp1dmTdqb8uWH/0l7tnfgPeDjw46/5dwF3tWPYKarwPeA/N/y2wuTVtM7C3Q/Vsbf0g3wU8ABjNAyGj+dZpm2sbAV6g9ZnurOkdWXfqL/XXYrd2ba5uAQ7Oun+oNa0rmNl24DrgUbrnirR3A1+As//1cAMw7u5x634n12G3Xc1X/bVyd9Mn/dX3Ox7MbAj4PvBZdz89+zFv/slo++5nM7sZOObuj7d72cu0qqv59hP11wVZ0/5qV8i9BGybdX9ra1pHmVmBZgN+291/0Jq8rCvSZuwG4ANm9iJwL81Niq8Do2Z25vJYnVyHq7qabwbUXyvTV/3VrpD7FXBla+9NEfgozSu/doyZGfBNYI+7f3XWQx2/Iq273+XuW919O8119ZC7fxx4GLilk7W16uu2q/mqv1ag7/qrjR8m3gQ8B/wW+NtOfKB5Tj3voDnc/TXwZOt2E83PJnYD+4D/BdZ3uM53Ag+0vr8M+CWwH/hvoNTBun4feKy1/n4EjHVy3am/1F8L3XTGg4jkWt/veBCRfFPIiUiuKeREJNcUciKSawo5Eck1hZyI5JpCTkRyTSEnIrn2/2o103OBWxCiAAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from torchrl.envs import ParallelEnv, TransformedEnv, ToTensorImage, Resize, Compose, GrayScale\n",
"from torchrl.envs.libs.gym import GymEnv\n",
"from matplotlib import pyplot as plt\n",
"\n",
"def env_make(env_name):\n",
" env = TransformedEnv(GymEnv(env_name, from_pixels=True, pixels_only=True), Compose(ToTensorImage(), Resize(64, 64))) # transforms on remote processes\n",
" return env\n",
"\n",
"parallel_env = ParallelEnv(2, [env_make, env_make], [{\"env_name\": \"ALE/AirRaid-v5\"}, {\"env_name\": \"ALE/Pong-v5\"}])\n",
"parallel_env = TransformedEnv(parallel_env, GrayScale()) # transforms on main process\n",
"tensordict = parallel_env.reset()\n",
"print(\"grayscale tensordict: \", tensordict)\n",
"plt.figure(figsize=(5, 10))\n",
"plt.subplot(121)\n",
"plt.imshow(tensordict[0].get(\"pixels\").permute(1, 2, 0).numpy(), cmap=\"gray\")\n",
"plt.subplot(122)\n",
"plt.imshow(tensordict[1].get(\"pixels\").permute(1, 2, 0).numpy(), cmap=\"gray\")\n",
"parallel_env.close()"
]
},
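{
"cell_type": "markdown",
"id": "batched-grayscale-sketch",
"metadata": {},
"source": [
"To see why transforming on the main process can pay off, here is a minimal sketch in plain PyTorch. It is not TorchRL's `GrayScale` implementation: it assumes the common ITU-R BT.601 luma weights for the RGB-to-grayscale conversion and uses random tensors as hypothetical frames. Converting each environment's frame separately and then stacking gives the same result as a single vectorized call on the stacked batch, which is what a main-process transform operates on:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Assumption: BT.601 luma weights; illustrative only, not TorchRL's GrayScale code\n",
"LUMA = torch.tensor([0.299, 0.587, 0.114]).view(3, 1, 1)\n",
"\n",
"def to_gray(pixels):\n",
"    # pixels: [..., 3, H, W] -> [..., 1, H, W], for any number of leading batch dims\n",
"    return (pixels * LUMA).sum(dim=-3, keepdim=True)\n",
"\n",
"# Hypothetical frames from 2 environments, already resized to 64x64\n",
"frames = torch.rand(2, 3, 64, 64)\n",
"\n",
"# Per-process style: one call per environment, then stack the results\n",
"per_env = torch.stack([to_gray(f) for f in frames])\n",
"# Main-process style: a single vectorized call on the stacked batch\n",
"batched = to_gray(frames)\n",
"\n",
"print(batched.shape)  # torch.Size([2, 1, 64, 64])\n",
"print(torch.allclose(per_env, batched))  # True\n",
"```\n",
"\n",
"Both paths compute the same values; the batched version simply issues one operation for all environments, which is the kind of vectorization (possibly on a CUDA device) that main-process transforms can exploit."
]
},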
{
"cell_type": "markdown",
"id": "d66e276d-bd61-431e-852e-55198595fe34",
"metadata": {},
"source": [
"## VecNorm\n",
"\n",
"In RL, we commonly need to normalize data before feeding it to a model.\n",
"Sometimes, we can get a good approximation of the normalizing statistics from data gathered in the environment with, say, a random policy (or demonstrations). It might, however, be advisable to normalize the data \"on-the-fly\", progressively updating the normalizing constants to match what has been observed so far. This is particularly useful when we expect the normalizing statistics to change as performance on the task evolves, or when the environment itself changes due to external factors.\n",
"\n",
"**Caution**: this feature should be used carefully with off-policy learning, as old data will be \"deprecated\" because it was normalized with statistics that were valid at the time but no longer are. In on-policy settings too, this feature makes learning non-stationary and may have unexpected effects. We therefore advise users to rely on this feature with caution and to compare it with normalization based on a fixed set of normalizing constants.\n",
"\n",
"In a regular (single-process) setting, using VecNorm is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "10d665f6-daf8-41bb-9767-5316387db737",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"mean: : tensor([-0.2824, -0.3740, -0.1690])\n",
"std: : tensor([0.9514, 1.0710, 1.1238])\n"
]
}
],
"source": [
"from torchrl.envs.libs.gym import GymEnv\n",
"from torchrl.envs.transforms import VecNorm, TransformedEnv\n",
"\n",
"env = TransformedEnv(GymEnv(\"Pendulum-v1\"), VecNorm())\n",
"tensordict = env.rollout(max_steps=100)\n",
"\n",
"print(\"mean: :\", tensordict.get(\"observation\").mean(0)) # Approx 0\n",
"print(\"std: :\", tensordict.get(\"observation\").std(0)) # Approx 1"
]
},
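{
"cell_type": "markdown",
"id": "running-norm-sketch",
"metadata": {},
"source": [
"The on-the-fly normalization described above can be sketched in a few lines of plain PyTorch. The `RunningNorm` class below is hypothetical and is not TorchRL's `VecNorm` implementation; it only assumes the kind of bookkeeping suggested by the `next_observation_sum`, `next_observation_ssq` and `next_observation_count` entries that VecNorm keeps in its state in this tutorial: a decayed running sum, sum of squares and count, from which the current mean and standard deviation are derived. A `decay` of `1.0` means that old data is never forgotten, while smaller values let the statistics track a changing data distribution:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"class RunningNorm:\n",
"    # Hypothetical sketch, not TorchRL's VecNorm: decayed running sum,\n",
"    # sum of squares and count, used to normalize with the current estimates\n",
"    def __init__(self, shape, decay=0.99, eps=1e-4):\n",
"        self.decay = decay\n",
"        self.eps = eps  # lower bound on the variance, avoids dividing by ~0\n",
"        self.sum = torch.zeros(shape)\n",
"        self.ssq = torch.zeros(shape)\n",
"        self.count = torch.zeros(())\n",
"\n",
"    def update(self, x):\n",
"        # x: [batch, *shape]; older statistics are down-weighted by decay\n",
"        self.sum = self.decay * self.sum + x.sum(0)\n",
"        self.ssq = self.decay * self.ssq + x.pow(2).sum(0)\n",
"        self.count = self.decay * self.count + x.shape[0]\n",
"\n",
"    def normalize(self, x):\n",
"        mean = self.sum / self.count\n",
"        var = (self.ssq / self.count - mean.pow(2)).clamp_min(self.eps)\n",
"        return (x - mean) / var.sqrt()\n",
"\n",
"torch.manual_seed(0)\n",
"norm = RunningNorm(shape=(3,), decay=1.0)\n",
"for _ in range(100):\n",
"    obs = 5.0 + 2.0 * torch.randn(8, 3)  # hypothetical observations\n",
"    norm.update(obs)\n",
"\n",
"normed = norm.normalize(5.0 + 2.0 * torch.randn(1000, 3))\n",
"print(normed.mean(0))  # close to 0\n",
"print(normed.std(0))  # close to 1\n",
"```\n",
"\n",
"Because the statistics keep moving as new data arrives, the same raw observation can be normalized differently at different times, which is the source of the off-policy caveat mentioned above."
]
},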
{
"cell_type": "markdown",
"id": "34c31e5c-82fc-4795-bb74-917ea5babc7e",
"metadata": {},
"source": [
"With **parallel envs**, things are slightly more complicated, as the running statistics must be shared among the processes. TorchRL provides the `EnvCreator` class for this purpose: it inspects an environment creation method, retrieves the tensordicts that must be shared among processes, and points each process to the same shared tensordict once it has been created:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "55f3b06d-d4cf-4d5b-b0e5-eda4d46cfcd9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"tensordict: TensorDict(\n",
" fields={\n",
" action: Tensor(torch.Size([3, 5, 2]), dtype=torch.int64),\n",
" done: Tensor(torch.Size([3, 5, 1]), dtype=torch.bool),\n",
" next_observation: Tensor(torch.Size([3, 5, 4]), dtype=torch.float32),\n",
" observation: Tensor(torch.Size([3, 5, 4]), dtype=torch.float32),\n",
" reward: Tensor(torch.Size([3, 5, 1]), dtype=torch.float32)},\n",
" batch_size=torch.Size([3, 5]),\n",
" device=cpu,\n",
" is_shared=False)\n",
"mean: : tensor([ 0.1187, -0.0427, -0.1390])\n",
"std: : tensor([1.1470, 1.1814, 1.1676])\n",
"update counts: tensor([18.])\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n",
"Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.\n"
]
}
],
"source": [
"from torchrl.envs import EnvCreator, ParallelEnv\n",
"from torchrl.envs.libs.gym import GymEnv\n",
"from torchrl.envs.transforms import VecNorm, TransformedEnv\n",
"\n",
"make_env = EnvCreator(lambda: TransformedEnv(GymEnv(\"CartPole-v1\"), VecNorm(decay=1.0)))\n",
"env = ParallelEnv(3, make_env)\n",
"make_env.state_dict()['_extra_state']['td'][\"next_observation_count\"].fill_(0.0)\n",
"make_env.state_dict()['_extra_state']['td'][\"next_observation_ssq\"].fill_(0.0)\n",
"make_env.state_dict()['_extra_state']['td'][\"next_observation_sum\"].fill_(0.0)\n",
"\n",
"tensordict = env.rollout(max_steps=5)\n",
"\n",
"print('tensordict: ', tensordict)\n",
"# flatten(0, 1) merges the env and time dims; CartPole observations have 4 features\n",
"print(\"mean: :\", tensordict.get(\"observation\").flatten(0, 1).mean(0)) # Approx 0\n",
"print(\"std: :\", tensordict.get(\"observation\").flatten(0, 1).std(0)) # Approx 1\n",
"\n",
"# The count is slightly higher than the number of steps (since we did not use any decay)\n",
"# The difference between the two is due to the fact that ParallelEnv creates a dummy environment to initialize the shared TensorDict\n",
"# that is used to collect data from the dispatched environments. This small difference will usually be absorbed during training.\n",
"print(\"update counts: \", make_env.state_dict()['_extra_state']['td'][\"next_observation_count\"])\n",
"env.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}