diff --git a/doc/source/rllib/images/rollout_worker_class_overview.svg b/doc/source/rllib/images/rollout_worker_class_overview.svg deleted file mode 100644 index 6ee606b800ec..000000000000 --- a/doc/source/rllib/images/rollout_worker_class_overview.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/doc/source/rllib/new-api-stack-migration-guide.rst b/doc/source/rllib/new-api-stack-migration-guide.rst index 1dde24cf4d70..99e518033d4e 100644 --- a/doc/source/rllib/new-api-stack-migration-guide.rst +++ b/doc/source/rllib/new-api-stack-migration-guide.rst @@ -376,7 +376,7 @@ The following callback methods are no longer available on the new API stack: **`on_sub_environment_created()`**: The new API stack uses `Farama's gymnasium `__ vector Envs leaving no control for RLlib to call a callback on each individual env-index's creation. -**`on_create_policy()`**: This method is no longer available on the new API stack because only :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` calls it. +**`on_create_policy()`**: This method is no longer available on the new API stack because only ``RolloutWorker`` calls it. **`on_postprocess_trajectory()`**: The new API stack no longer triggers and calls this method, because :py:class:`~ray.rllib.connectors.connector_v2.ConnectorV2` pipelines handle trajectory processing entirely. @@ -388,14 +388,14 @@ The documention for :py:class:`~ray.rllib.connectors.connector_v2.ConnectorV2` d ModelV2 to RLModule ------------------- -If you're using a custom :py:class:`~ray.rllib.models.modelv2.ModelV2` class and want to translate +If you're using a custom ``ModelV2`` class and want to translate the entire NN architecture and possibly action distribution logic to the new API stack, see :ref:`RL Modules ` in addition to this section. Also, see these example scripts on `how to write a custom CNN-containing RL Module `__ and `how to write a custom LSTM-containing RL Module `__. -There are various options for translating an existing, custom :py:class:`~ray.rllib.models.modelv2.ModelV2` from the old API stack, +There are various options for translating an existing, custom ``ModelV2`` from the old API stack, to the new API stack's :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`: #. Move your ModelV2 code to a new, custom `RLModule` class. See :ref:`RL Modules ` for details). @@ -411,8 +411,7 @@ and distributions. Translating Policy.compute_actions_from_input_dict ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This old API stack method, as well as :py:meth:`~ray.rllib.policy.policy.Policy.compute_actions` and -:py:meth:`~ray.rllib.policy.policy.Policy.compute_single_action`, directly translate to +This old API stack method, as well as ``compute_actions`` and ``compute_single_action``, directly translate to :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule._forward_inference` and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule._forward_exploration`. :ref:`The RLModule guide explains how to implement this method `. @@ -421,7 +420,7 @@ and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule._forward_exploration` Translating Policy.action_distribution_fn ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To translate :py:meth:`~ray.rllib.policy.torch_policy_v2.TorchPolicyV2.action_distribution_fn`, write the following custom RLModule code: +To translate ``action_distribution_fn``, write the following custom RLModule code: .. 
tab-set:: @@ -464,7 +463,7 @@ To translate :py:meth:`~ray.rllib.policy.torch_policy_v2.TorchPolicyV2.action_di
Translating Policy.action_sampler_fn
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-To translate :py:meth:`~ray.rllib.policy.torch_policy_v2.TorchPolicyV2.action_sampler_fn`, write the following custom RLModule code:
+To translate ``action_sampler_fn``, write the following custom RLModule code:
.. testcode::
    :skipif: True
diff --git a/doc/source/rllib/package_ref/algorithm.rst b/doc/source/rllib/package_ref/algorithm.rst
index b904492b11c1..bc72e6be3e29 100644
--- a/doc/source/rllib/package_ref/algorithm.rst
+++ b/doc/source/rllib/package_ref/algorithm.rst
@@ -9,7 +9,7 @@ Algorithms
The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` class is the highest-level API in RLlib responsible for **WHEN** and **WHAT** of RL algorithms.
Things like **WHEN** should we sample the algorithm, **WHEN** should we perform a neural network update, and so on.
-The **HOW** will be delegated to components such as :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`, etc..
+The **HOW** will be delegated to components such as ``RolloutWorker``, etc.
It is the main entry point for RLlib users to interact with RLlib's algorithms.
@@ -24,9 +24,9 @@ and thus fully supports distributed hyperparameter tuning for RL.
    :align: left
    **A typical RLlib Algorithm object:** Algorithms are normally comprised of
-   N :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` that
+   N ``RolloutWorkers`` that are
    orchestrated via a :py:class:`~ray.rllib.env.env_runner_group.EnvRunnerGroup` object.
-   Each worker own its own a set of :py:class:`~ray.rllib.policy.policy.Policy` objects and their NN models per worker, plus a :py:class:`~ray.rllib.env.base_env.BaseEnv` instance per worker.
+   Each worker owns its own set of ``Policy`` objects and their NN models, plus a :py:class:`~ray.rllib.env.base_env.BaseEnv` instance per worker.
Building Custom Algorithm Classes
---------------------------------
diff --git a/doc/source/rllib/package_ref/distributions.rst b/doc/source/rllib/package_ref/distributions.rst
new file mode 100644
index 000000000000..954e4ded09f8
--- /dev/null
+++ b/doc/source/rllib/package_ref/distributions.rst
@@ -0,0 +1,20 @@
+.. include:: /_includes/rllib/we_are_hiring.rst
+
+.. include:: /_includes/rllib/new_api_stack.rst
+
+
+.. _rllib-distributions-reference-docs:
+
+Distribution API
+================
+
+.. currentmodule:: ray.rllib.models.distributions
+
+Base Distribution class
+-----------------------
+
+.. autosummary::
+    :nosignatures:
+    :toctree: doc/
+
+    ~Distribution
diff --git a/doc/source/rllib/package_ref/env.rst b/doc/source/rllib/package_ref/env.rst
index 8e273901c684..b5eaf762abbc 100644
--- a/doc/source/rllib/package_ref/env.rst
+++ b/doc/source/rllib/package_ref/env.rst
@@ -50,6 +50,7 @@ Environment API Reference
.. toctree::
    :maxdepth: 1
+   env/env_runner.rst
    env/single_agent_env_runner.rst
    env/single_agent_episode.rst
    env/multi_agent_env.rst
diff --git a/doc/source/rllib/package_ref/env/env_runner.rst b/doc/source/rllib/package_ref/env/env_runner.rst
new file mode 100644
index 000000000000..6dc1f9fee626
--- /dev/null
+++ b/doc/source/rllib/package_ref/env/env_runner.rst
@@ -0,0 +1,57 @@
+
+.. include:: /_includes/rllib/we_are_hiring.rst
+
+.. include:: /_includes/rllib/new_api_stack.rst
+
+.. _env-runner-reference-docs:
+
+EnvRunner API
+=============
+
+rllib.env.env_runner.EnvRunner
+------------------------------
+
+..
currentmodule:: ray.rllib.env.env_runner

Construction and setup
~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::
    :nosignatures:
    :toctree: doc/

    EnvRunner
    EnvRunner.make_env
    EnvRunner.make_module
    EnvRunner.get_spaces
    EnvRunner.assert_healthy

Sampling
~~~~~~~~

.. autosummary::
    :nosignatures:
    :toctree: doc/

    EnvRunner.sample
    EnvRunner.get_metrics

Cleanup
~~~~~~~

.. autosummary::
    :nosignatures:
    :toctree: doc/

    EnvRunner.stop



Single-agent and multi-agent EnvRunners
---------------------------------------

By default, RLlib uses two built-in subclasses of EnvRunner, one for :ref:`single-agent ` and one
for :ref:`multi-agent ` setups. RLlib determines which one to use based on your config.

Check your ``config.is_multi_agent`` property to find out which of these setups you have configured,
and see :ref:`the docs on setting up RLlib multi-agent ` for more details.
diff --git a/doc/source/rllib/package_ref/evaluation.rst b/doc/source/rllib/package_ref/evaluation.rst
deleted file mode 100644
index d2e0560d0527..000000000000
--- a/doc/source/rllib/package_ref/evaluation.rst
+++ /dev/null
@@ -1,370 +0,0 @@
-
-.. include:: /_includes/rllib/we_are_hiring.rst
-
-.. include:: /_includes/rllib/new_api_stack.rst
-
-
-.. _evaluation-reference-docs:
-
-Sampling the Environment or offline data
-========================================
-
-Data ingest via either environment rollouts or other data-generating methods
-(e.g. reading from offline files) is done in RLlib by :py:class:`~ray.rllib.env.env_runner.EnvRunner` instances,
-which sit inside a :py:class:`~ray.rllib.env.env_runner_group.EnvRunnerGroup`
-(together with other parallel ``EnvRunners``) in the RLlib :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`
-(under the ``self.env_runner_group`` property):
-
-
-.. https://docs.google.com/drawings/d/1OewMLAu6KZNon7zpDfZnTh9qiT6m-3M9wnkqWkQQMRc/edit
-.. figure:: ../images/rollout_worker_class_overview.svg
-    :width: 600
-    :align: left
-
-    **A typical RLlib EnvRunnerGroup setup inside an RLlib Algorithm:** Each :py:class:`~ray.rllib.env.env_runner_group.EnvRunnerGroup` contains
-    exactly one local :py:class:`~ray.rllib.env.env_runner.EnvRunner` object and N ray remote
-    :py:class:`~ray.rllib.env.env_runner.EnvRunner` (Ray actors).
-    The workers contain a policy map (with one or more policies), and - in case a simulator
-    (env) is available - a vectorized :py:class:`~ray.rllib.env.base_env.BaseEnv`
-    (containing M sub-environments) and a :py:class:`~ray.rllib.evaluation.sampler.SamplerInput` (either synchronous or asynchronous) which controls
-    the environment data collection loop.
-    In the online case (i.e. environment is available) as well as the offline case (i.e. no environment),
-    :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` uses the :py:meth:`~ray.rllib.env.env_runner.EnvRunner.sample` method to
-    get :py:class:`~ray.rllib.policy.sample_batch.SampleBatch` objects for training.
-
-
-.. _rolloutworker-reference-docs:
-
-RolloutWorker API
------------------
-
-.. currentmodule:: ray.rllib.evaluation.rollout_worker
-
-Constructor
-~~~~~~~~~~~
-
-
-.. autosummary::
-    :nosignatures:
-    :toctree: doc/
-
-    RolloutWorker
-
-Multi agent
-~~~~~~~~~~~
-
-..
autosummary:: - :nosignatures: - :toctree: doc/ - - ~RolloutWorker.add_policy - ~RolloutWorker.remove_policy - ~RolloutWorker.get_policy - ~RolloutWorker.set_is_policy_to_train - ~RolloutWorker.set_policy_mapping_fn - ~RolloutWorker.for_policy - ~RolloutWorker.foreach_policy - ~RolloutWorker.foreach_policy_to_train - -Setter and getter methods -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~RolloutWorker.get_filters - ~RolloutWorker.get_global_vars - ~RolloutWorker.set_global_vars - ~RolloutWorker.get_host - ~RolloutWorker.get_metrics - ~RolloutWorker.get_node_ip - ~RolloutWorker.get_weights - ~RolloutWorker.set_weights - ~RolloutWorker.get_state - ~RolloutWorker.set_state - -Threading -~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~RolloutWorker.lock - ~RolloutWorker.unlock - -Sampling API -~~~~~~~~~~~~ - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~RolloutWorker.sample - ~RolloutWorker.sample_with_count - ~RolloutWorker.sample_and_learn - -Training API -~~~~~~~~~~~~ - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~RolloutWorker.learn_on_batch - ~RolloutWorker.setup_torch_data_parallel - ~RolloutWorker.compute_gradients - ~RolloutWorker.apply_gradients - -Environment API -~~~~~~~~~~~~~~~ - - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~RolloutWorker.foreach_env - ~RolloutWorker.foreach_env_with_context - - -Miscellaneous -~~~~~~~~~~~~~ - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~RolloutWorker.stop - ~RolloutWorker.apply - ~RolloutWorker.sync_filters - ~RolloutWorker.find_free_port - ~RolloutWorker.creation_args - ~RolloutWorker.assert_healthy - - -.. _workerset-reference-docs: - -EnvRunner API -------------- - -.. currentmodule:: ray.rllib.env.env_runner - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - EnvRunner - -EnvRunnerGroup API ------------------- - -.. currentmodule:: ray.rllib.env.env_runner_group - -Constructor -~~~~~~~~~~~ - - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - EnvRunnerGroup - EnvRunnerGroup.stop - EnvRunnerGroup.reset - - -Worker Orchestration -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~EnvRunnerGroup.add_workers - ~EnvRunnerGroup.foreach_worker - ~EnvRunnerGroup.foreach_worker_with_id - ~EnvRunnerGroup.foreach_worker_async - ~EnvRunnerGroup.fetch_ready_async_reqs - ~EnvRunnerGroup.num_in_flight_async_reqs - ~EnvRunnerGroup.local_worker - ~EnvRunnerGroup.remote_workers - ~EnvRunnerGroup.num_healthy_remote_workers - ~EnvRunnerGroup.num_healthy_workers - ~EnvRunnerGroup.num_remote_worker_restarts - ~EnvRunnerGroup.probe_unhealthy_workers - -Pass-through methods -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~EnvRunnerGroup.add_policy - ~EnvRunnerGroup.foreach_env - ~EnvRunnerGroup.foreach_env_with_context - ~EnvRunnerGroup.foreach_policy - ~EnvRunnerGroup.foreach_policy_to_train - ~EnvRunnerGroup.sync_weights - - - -Sampler API ------------ -:py:class:`~ray.rllib.offline.input_reader.InputReader` instances are used to collect and return experiences from the envs. -For more details on `InputReader` used for offline RL (e.g. reading files of -pre-recorded data), see the :ref:`offline RL API reference here `. - - - - -Input Reader API -~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: ray.rllib.offline.input_reader - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - InputReader - InputReader.next - - -Input Sampler API -~~~~~~~~~~~~~~~~~~~~ - -.. 
currentmodule:: ray.rllib.evaluation.sampler - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - SamplerInput - SamplerInput.get_data - SamplerInput.get_extra_batches - SamplerInput.get_metrics - -Synchronous Sampler API -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: ray.rllib.evaluation.sampler - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - SyncSampler - - - -.. _offline-reference-docs: - -Offline Sampler API -~~~~~~~~~~~~~~~~~~~~~~~ - -The InputReader API is used by an individual :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` -to produce batches of experiences either from an simulator or from an -offline source (e.g. a file). - -Here are some example extentions of the InputReader API: - -JSON reader API -++++++++++++++++ - -.. currentmodule:: ray.rllib.offline.json_reader - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - JsonReader - JsonReader.read_all_files - -.. currentmodule:: ray.rllib.offline.mixed_input - -Mixed input reader -++++++++++++++++++ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - MixedInput - -.. currentmodule:: ray.rllib.offline.d4rl_reader - -D4RL reader -+++++++++++ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - D4RLReader - -.. currentmodule:: ray.rllib.offline.io_context - -IOContext -~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - IOContext - IOContext.default_sampler_input - - - -Policy Map API --------------- - -.. currentmodule:: ray.rllib.policy.policy_map - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - PolicyMap - PolicyMap.items - PolicyMap.keys - PolicyMap.values - -Sample batch API ----------------- - -.. currentmodule:: ray.rllib.policy.sample_batch - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - SampleBatch - SampleBatch.set_get_interceptor - SampleBatch.is_training - SampleBatch.set_training - SampleBatch.as_multi_agent - SampleBatch.get - SampleBatch.to_device - SampleBatch.right_zero_pad - SampleBatch.slice - SampleBatch.split_by_episode - SampleBatch.shuffle - SampleBatch.columns - SampleBatch.rows - SampleBatch.copy - SampleBatch.is_single_trajectory - SampleBatch.is_terminated_or_truncated - SampleBatch.env_steps - SampleBatch.agent_steps - - -MultiAgent batch API --------------------- - -.. currentmodule:: ray.rllib.policy.sample_batch - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - MultiAgentBatch - MultiAgentBatch.env_steps - MultiAgentBatch.agent_steps diff --git a/doc/source/rllib/package_ref/external-app.rst b/doc/source/rllib/package_ref/external-app.rst deleted file mode 100644 index fe8504c443d6..000000000000 --- a/doc/source/rllib/package_ref/external-app.rst +++ /dev/null @@ -1,19 +0,0 @@ - -.. include:: /_includes/rllib/we_are_hiring.rst - -.. include:: /_includes/rllib/new_api_stack.rst - -External Application API ------------------------- - -In some cases, for instance when interacting with an externally hosted simulator or -production environment, it makes more sense to interact with RLlib as if it were an -independently running service, rather than RLlib hosting the simulations itself. -This is possible via RLlib's external applications interface -`(full documentation) `__. - -.. autoclass:: ray.rllib.env.policy_client.PolicyClient - :members: - -.. 
autoclass:: ray.rllib.env.policy_server_input.PolicyServerInput - :members: diff --git a/doc/source/rllib/package_ref/index.rst b/doc/source/rllib/package_ref/index.rst index c77f86332f71..9f03c643a9ea 100644 --- a/doc/source/rllib/package_ref/index.rst +++ b/doc/source/rllib/package_ref/index.rst @@ -23,12 +23,9 @@ If you think there is anything missing, please open an issue on `Github`_. algorithm-config.rst algorithm.rst env.rst - policy.rst - models.rst rl_modules.rst + distributions.rst learner.rst offline.rst - evaluation.rst replay-buffers.rst utils.rst - external-app.rst diff --git a/doc/source/rllib/package_ref/models.rst b/doc/source/rllib/package_ref/models.rst deleted file mode 100644 index c632615fdd1b..000000000000 --- a/doc/source/rllib/package_ref/models.rst +++ /dev/null @@ -1,60 +0,0 @@ -.. include:: /_includes/rllib/we_are_hiring.rst - -.. include:: /_includes/rllib/new_api_stack.rst - - -.. _model-reference-docs: - -Model APIs -========== - -.. currentmodule:: ray.rllib.models - -Base Model classes -------------------- - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~modelv2.ModelV2 - ~torch.torch_modelv2.TorchModelV2 - ~tf.tf_modelv2.TFModelV2 - -Feed Forward methods ---------------------- -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~modelv2.ModelV2.forward - ~modelv2.ModelV2.value_function - ~modelv2.ModelV2.last_output - -Recurrent Models API ---------------------- -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~modelv2.ModelV2.get_initial_state - ~modelv2.ModelV2.is_time_major - -Acessing variables ---------------------- -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~modelv2.ModelV2.variables - ~modelv2.ModelV2.trainable_variables - ~distributions.Distribution - -Customization --------------- -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~modelv2.ModelV2.custom_loss - ~modelv2.ModelV2.metrics diff --git a/doc/source/rllib/package_ref/policy.rst b/doc/source/rllib/package_ref/policy.rst deleted file mode 100644 index db3fd42765a6..000000000000 --- a/doc/source/rllib/package_ref/policy.rst +++ /dev/null @@ -1,272 +0,0 @@ -.. include:: /_includes/rllib/we_are_hiring.rst - -.. include:: /_includes/rllib/new_api_stack.rst - -.. _policy-reference-docs: - - -Policy API -========== - -The :py:class:`~ray.rllib.policy.policy.Policy` class contains functionality to compute -actions for decision making in an environment, as well as computing loss(es) and gradients, -updating a neural network model as well as postprocessing a collected environment trajectory. -One or more :py:class:`~ray.rllib.policy.policy.Policy` objects sit inside a -:py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`'s :py:class:`~ray.rllib.policy.policy_map.PolicyMap` and -are - if more than one - are selected based on a multi-agent ``policy_mapping_fn``, -which maps agent IDs to a policy ID. - -.. https://docs.google.com/drawings/d/1eFAVV1aU47xliR5XtGqzQcdvuYs2zlVj1Gb8Gg0gvnc/edit -.. figure:: ../images/policy_classes_overview.svg - :align: left - - **RLlib's Policy class hierarchy:** Policies are deep-learning framework - specific as they hold functionality to handle a computation graph (e.g. a - TensorFlow 1.x graph in a session). You can define custom policy behavior - by sub-classing either of the available, built-in classes, depending on your - needs. - -.. include:: policy/custom_policies.rst - -.. currentmodule:: ray.rllib - -Base Policy classes -------------------- - -.. 
autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.policy.Policy - - ~policy.eager_tf_policy_v2.EagerTFPolicyV2 - - ~policy.torch_policy_v2.TorchPolicyV2 - - -.. -------------------------------------------- - -Making models --------------------- - - -Torch Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.torch_policy_v2.TorchPolicyV2.make_model - ~policy.torch_policy_v2.TorchPolicyV2.make_model_and_action_dist - - -Tensorflow Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.make_model - -.. -------------------------------------------- - -Inference --------------------- - -Base Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.policy.Policy.compute_actions - ~policy.policy.Policy.compute_actions_from_input_dict - ~policy.policy.Policy.compute_single_action - -Torch Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.torch_policy_v2.TorchPolicyV2.action_sampler_fn - ~policy.torch_policy_v2.TorchPolicyV2.action_distribution_fn - ~policy.torch_policy_v2.TorchPolicyV2.extra_action_out - -Tensorflow Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.action_sampler_fn - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.action_distribution_fn - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.extra_action_out_fn - -.. -------------------------------------------- - -Computing, processing, and applying gradients ---------------------------------------------- - -Base Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.Policy.compute_gradients - ~policy.Policy.apply_gradients - -Torch Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.torch_policy_v2.TorchPolicyV2.extra_compute_grad_fetches - ~policy.torch_policy_v2.TorchPolicyV2.extra_grad_process - - -Tensorflow Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.grad_stats_fn - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.compute_gradients_fn - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.apply_gradients_fn - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.extra_learn_fetches_fn - - - -.. -------------------------------------------- - -Updating the Policy's model ----------------------------- - - -Base Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.Policy.learn_on_batch - ~policy.Policy.load_batch_into_buffer - ~policy.Policy.learn_on_loaded_batch - ~policy.Policy.learn_on_batch_from_replay_buffer - ~policy.Policy.get_num_samples_loaded_into_buffer - - -.. -------------------------------------------- - -Loss, Logging, optimizers, and trajectory processing ----------------------------------------------------- - -Base Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.Policy.loss - ~policy.Policy.compute_log_likelihoods - ~policy.Policy.on_global_var_update - ~policy.Policy.postprocess_trajectory - - - -Torch Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.torch_policy_v2.TorchPolicyV2.optimizer - ~policy.torch_policy_v2.TorchPolicyV2.get_tower_stats - - -Tensorflow Policy -~~~~~~~~~~~~~~~~~~~~ -.. 
autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.optimizer - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.stats_fn - - -.. -------------------------------------------- - -Saving and restoring --------------------- - -Base Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.Policy.from_checkpoint - ~policy.Policy.export_checkpoint - ~policy.Policy.export_model - ~policy.Policy.from_state - ~policy.Policy.get_weights - ~policy.Policy.set_weights - ~policy.Policy.get_state - ~policy.Policy.set_state - ~policy.Policy.import_model_from_h5 - -.. -------------------------------------------- - - -Recurrent Policies --------------------- - -Base Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - Policy.get_initial_state - Policy.num_state_tensors - Policy.is_recurrent - - -.. -------------------------------------------- - -Miscellaneous --------------------- - -Base Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.Policy.apply - ~policy.Policy.get_session - ~policy.Policy.init_view_requirements - ~policy.Policy.get_host - ~policy.Policy.get_exploration_state - - -Torch Policy -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.torch_policy_v2.TorchPolicyV2.get_batch_divisibility_req - - -Tensorflow Policy -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.variables - ~policy.eager_tf_policy_v2.EagerTFPolicyV2.get_batch_divisibility_req diff --git a/doc/source/rllib/package_ref/policy/custom_policies.rst b/doc/source/rllib/package_ref/policy/custom_policies.rst deleted file mode 100644 index 2234143900ba..000000000000 --- a/doc/source/rllib/package_ref/policy/custom_policies.rst +++ /dev/null @@ -1,30 +0,0 @@ -.. include:: /_includes/rllib/we_are_hiring.rst - -.. include:: /_includes/rllib/new_api_stack.rst - -Building Custom Policy Classes ------------------------------- - -.. currentmodule:: ray.rllib - -.. warning:: - As of Ray >= 1.9, it is no longer recommended to use the ``build_policy_class()`` or - ``build_tf_policy()`` utility functions for creating custom Policy sub-classes. - Instead, follow the simple guidelines here for directly sub-classing from - either one of the built-in types: - :py:class:`~policy.eager_tf_policy_v2.EagerTFPolicyV2` - or - :py:class:`~policy.torch_policy_v2.TorchPolicyV2` - -In order to create a custom Policy, sub-class :py:class:`~policy.policy.Policy` (for a generic, -framework-agnostic policy), -:py:class:`~policy.torch_policy_v2.TorchPolicyV2` -(for a PyTorch specific policy), or -:py:class:`~policy.eager_tf_policy_v2.EagerTFPolicyV2` -(for a TensorFlow specific policy) and override one or more of their methods. Those are in particular: - -* :py:meth:`~policy.policy.Policy.compute_actions_from_input_dict` -* :py:meth:`~policy.policy.Policy.postprocess_trajectory` -* :py:meth:`~policy.policy.Policy.loss` - -`See here for an example on how to override TorchPolicy `_. diff --git a/doc/source/rllib/package_ref/utils.rst b/doc/source/rllib/package_ref/utils.rst index e492f0b8e7c7..d0bfca886353 100644 --- a/doc/source/rllib/package_ref/utils.rst +++ b/doc/source/rllib/package_ref/utils.rst @@ -10,119 +10,44 @@ RLlib Utilities Here is a list of all the utilities available in RLlib. 
-Exploration API
----------------
-
-Exploration is crucial in RL for enabling a learning agent to find new, potentially high-reward states by reaching unexplored areas of the environment.
-
-RLlib has several built-in exploration components that
-the different algorithms use. You can also customize an algorithm's exploration
-behavior by sub-classing the Exploration base class and implementing
-your own logic:
-
-Built-in Exploration components
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. currentmodule:: ray.rllib.utils.exploration
-
-.. autosummary::
-    :nosignatures:
-    :toctree: doc/
-
-    ~exploration.Exploration
-    ~random.Random
-    ~stochastic_sampling.StochasticSampling
-    ~epsilon_greedy.EpsilonGreedy
-    ~gaussian_noise.GaussianNoise
-    ~ornstein_uhlenbeck_noise.OrnsteinUhlenbeckNoise
-    ~random_encoder.RE3
-    ~curiosity.Curiosity
-    ~parameter_noise.ParameterNoise
-
-
-Inference
-~~~~~~~~~
-.. autosummary::
-    :nosignatures:
-    :toctree: doc/
-
-    ~exploration.Exploration.get_exploration_action
-
-Callback hooks
-~~~~~~~~~~~~~~
-
-.. autosummary::
-    :nosignatures:
-    :toctree: doc/
-
-    ~exploration.Exploration.before_compute_actions
-    ~exploration.Exploration.on_episode_start
-    ~exploration.Exploration.on_episode_end
-    ~exploration.Exploration.postprocess_trajectory
-
-
-Setting and getting states
-~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autosummary::
-    :nosignatures:
-    :toctree: doc/
-
-    ~exploration.Exploration.get_state
-    ~exploration.Exploration.set_state
-
-
-
 Scheduler API
 -------------
-Use a scheduler to set scheduled values for variables (in Python, PyTorch, or
-TensorFlow) based on an (int64) timestep input. The computed values are usually float32
-types.
-
-
+RLlib uses the Scheduler API to set scheduled values for variables, in Python or PyTorch,
+based on an int timestep input. The type of the schedule is always a ``PiecewiseSchedule``, which defines a list
+of increasing timesteps, starting at 0, associated with values to be reached at these particular timesteps.
+``PiecewiseSchedule`` interpolates values for all intermediate timesteps.
+The computed values are usually float32 types.
+For example:
-Built-in Scheduler components
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. testcode::
-.. currentmodule:: ray.rllib.utils.schedules
-
-.. autosummary::
-    :nosignatures:
-    :toctree: doc/
-
-    ~schedule.Schedule
-    ~constant_schedule.ConstantSchedule
-    ~linear_schedule.LinearSchedule
-    ~piecewise_schedule.PiecewiseSchedule
-    ~exponential_schedule.ExponentialSchedule
-    ~polynomial_schedule.PolynomialSchedule
-
-Methods
-~~~~~~~
-
-.. autosummary::
-    :nosignatures:
-    :toctree: doc/
+    from ray.rllib.utils.schedules.scheduler import Scheduler
-    ~schedule.Schedule.value
-    ~schedule.Schedule.__call__
+    scheduler = Scheduler([[0, 0.1], [50, 0.05], [60, 0.001]])
+    print(scheduler.get_current_value())  # <- expect 0.1
+    # Up the timestep.
+    scheduler.update(timestep=45)
+    print(scheduler.get_current_value())  # <- expect 0.055
-.. _train-ops-docs:
+    # Up the timestep.
+    scheduler.update(timestep=100)
+    print(scheduler.get_current_value())  # <- expect 0.001 (keep final value)
-Training Operations Utilities
------------------------------
-.. currentmodule:: ray.rllib.execution.train_ops
+.. currentmodule:: ray.rllib.utils.schedules.scheduler
..
autosummary:: - :nosignatures: - :toctree: doc/ + :nosignatures: + :toctree: doc/ - ~multi_gpu_train_one_step - ~train_one_step + Scheduler + Scheduler.validate + Scheduler.get_current_value + Scheduler.update + Scheduler._create_tensor_variable Framework Utilities @@ -138,60 +63,28 @@ Import utilities :toctree: doc/ ~try_import_torch - ~try_import_tf - ~try_import_tfp - - -Tensorflow utilities -~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: ray.rllib.utils.tf_utils - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - ~explained_variance - ~flatten_inputs_to_1d_tensor - ~get_gpu_devices - ~get_placeholder - ~huber_loss - ~l2_loss - ~make_tf_callable - ~minimize_and_clip - ~one_hot - ~reduce_mean_ignore_inf - ~scope_vars - ~warn_if_infinite_kl_divergence - ~zero_logps_from_actions - Torch utilities ~~~~~~~~~~~~~~~ .. currentmodule:: ray.rllib.utils.torch_utils - .. autosummary:: - :nosignatures: - :toctree: doc/ - - ~apply_grad_clipping - ~concat_multi_gpu_td_errors - ~convert_to_torch_tensor - ~explained_variance - ~flatten_inputs_to_1d_tensor - ~global_norm - ~huber_loss - ~l2_loss - ~minimize_and_clip - ~one_hot - ~reduce_mean_ignore_inf - ~sequence_mask - ~warn_if_infinite_kl_divergence - ~set_torch_seed - ~softmax_cross_entropy_with_logits - + :nosignatures: + :toctree: doc/ + + ~clip_gradients + ~compute_global_norm + ~convert_to_torch_tensor + ~explained_variance + ~flatten_inputs_to_1d_tensor + ~global_norm + ~one_hot + ~reduce_mean_ignore_inf + ~sequence_mask + ~set_torch_seed + ~softmax_cross_entropy_with_logits + ~update_target_network Numpy utilities ~~~~~~~~~~~~~~~ @@ -216,8 +109,9 @@ Numpy utilities ~sigmoid ~softmax + Checkpoint utilities -~~~~~~~~~~~~~~~~~~~~ +-------------------- .. currentmodule:: ray.rllib.utils.checkpoints @@ -225,34 +119,5 @@ Checkpoint utilities :nosignatures: :toctree: doc/ - Checkpointable - convert_to_msgpack_checkpoint - convert_to_msgpack_policy_checkpoint - get_checkpoint_info try_import_msgpack - -Policy utilities -~~~~~~~~~~~~~~~~ - -.. currentmodule:: ray.rllib.utils.policy - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - compute_log_likelihoods_from_input_dict - create_policy_for_framework - local_policy_inference - parse_policy_specs_from_checkpoint - -Other utilities -~~~~~~~~~~~~~~~ - -.. currentmodule:: ray.rllib - -.. autosummary:: - :nosignatures: - :toctree: doc/ - - utils.tensor_dtype.get_np_dtype - core.rl_module.validate_module_id + Checkpointable diff --git a/doc/source/rllib/rllib-advanced-api.rst b/doc/source/rllib/rllib-advanced-api.rst index aa9730e28379..40145f99cbd7 100644 --- a/doc/source/rllib/rllib-advanced-api.rst +++ b/doc/source/rllib/rllib-advanced-api.rst @@ -115,7 +115,7 @@ and .. tip:: You can create custom logic that can run on each evaluation episode by checking - if the :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` is in + if the ``RolloutWorker`` is in evaluation mode, through accessing ``worker.policy_config["in_evaluation"]``. You can then implement this check in ``on_episode_start()`` or ``on_episode_end()`` in your subclass of :py:class:`~ray.rllib.callbacks.callbacks.RLlibCallback`. 
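As an illustration of the tip above, here is a minimal callback sketch. It assumes the old API stack, where RLlib passes the ``RolloutWorker`` to the callback under the ``worker`` keyword argument; the class name and the exact set of keyword arguments are illustrative and can differ between Ray versions and API stacks.

.. code-block:: python

    from ray.rllib.callbacks.callbacks import RLlibCallback


    class EvaluationEpisodeTagger(RLlibCallback):
        """Sketch: run custom logic only on evaluation episodes (old API stack)."""

        def on_episode_start(self, *, worker=None, episode=None, **kwargs):
            # On the old API stack, `worker` is the RolloutWorker; guard for safety.
            if worker is not None and worker.policy_config.get("in_evaluation", False):
                # Evaluation-only logic goes here, for example tagging the episode.
                episode.custom_metrics["is_evaluation_episode"] = 1.0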
diff --git a/doc/source/rllib/rllib-fault-tolerance.rst b/doc/source/rllib/rllib-fault-tolerance.rst index 0ad69df9db15..4559a4cd4cb7 100644 --- a/doc/source/rllib/rllib-fault-tolerance.rst +++ b/doc/source/rllib/rllib-fault-tolerance.rst @@ -19,8 +19,7 @@ Worker Recovery --------------- RLlib supports self-recovering and elastic :py:class:`~ray.rllib.env.env_runner_group.EnvRunnerGroup` for both -:ref:`training and evaluation EnvRunner workers `. -This provides fault tolerance at worker level. +training and evaluation EnvRunner workers. This provides fault tolerance at worker level. This means that if you have n :py:class:`~ray.rllib.env.env_runner.EnvRunner` workers sitting on different machines and a machine is pre-empted, RLlib can continue training and evaluation with minimal interruption. diff --git a/doc/source/rllib/rllib-new-api-stack.rst b/doc/source/rllib/rllib-new-api-stack.rst index 37094aba2881..e35092f586ca 100644 --- a/doc/source/rllib/rllib-new-api-stack.rst +++ b/doc/source/rllib/rllib-new-api-stack.rst @@ -87,10 +87,10 @@ Applying the above principles, the Ray Team reduced the important **must-know** for the average RLlib user from seven on the old stack, to only four on the new stack. The **core** new API stack classes are: -* :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` (replaces :py:class:`~ray.rllib.models.modelv2.ModelV2` and :py:class:`~ray.rllib.policy.policy_map.PolicyMap` APIs) -* :py:class:`~ray.rllib.core.learner.learner.Learner` (replaces :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` and some of :py:class:`~ray.rllib.policy.policy.Policy`) -* :py:class:`~ray.rllib.env.single_agent_episode.SingleAgentEpisode` and :py:class:`~ray.rllib.env.multi_agent_episode.MultiAgentEpisode` (replaces :py:class:`~ray.rllib.policy.view_requirement.ViewRequirement`, :py:class:`~ray.rllib.evaluation.collectors.SampleCollector`, :py:class:`~ray.rllib.evaluation.episode.Episode`, and :py:class:`~ray.rllib.evaluation.episode_v2.EpisodeV2`) -* :py:class:`~ray.rllib.connector.connector_v2.ConnectorV2` (replaces :py:class:`~ray.rllib.connector.connector.Connector` and some of :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` and :py:class:`~ray.rllib.policy.policy.Policy`) +* :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` (replaces ``ModelV2`` and ``PolicyMap`` APIs) +* :py:class:`~ray.rllib.core.learner.learner.Learner` (replaces ``RolloutWorker`` and some of ``Policy``) +* :py:class:`~ray.rllib.env.single_agent_episode.SingleAgentEpisode` and :py:class:`~ray.rllib.env.multi_agent_episode.MultiAgentEpisode` (replaces ``ViewRequirement``, ``SampleCollector``, ``Episode``, and ``EpisodeV2``) +* :py:class:`~ray.rllib.connector.connector_v2.ConnectorV2` (replaces ``Connector`` and some of ``RolloutWorker`` and ``Policy``) The :py:class:`~ray.rllib.algorithm.algorithm_config.AlgorithmConfig` and :py:class:`~ray.rllib.algorithm.algorithm.Algorithm` APIs remain as-is. These are already established APIs on the old stack. 
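To illustrate that the ``AlgorithmConfig`` and ``Algorithm`` entry points stay the same, here is a minimal configuration sketch for the new API stack. The ``api_stack()`` setter and its two flags are assumptions about recent Ray releases (in which the new stack may already be the default), so treat the exact argument names as illustrative rather than authoritative.

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        # Explicitly opt into the new API stack (RLModule/Learner plus
        # EnvRunner/ConnectorV2); flag names are assumptions and may vary by version.
        .api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True,
        )
        .env_runners(num_env_runners=2)
    )
    algo = config.build()
    print(algo.train())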
diff --git a/rllib/core/rl_module/__init__.py b/rllib/core/rl_module/__init__.py index fa88c9e75ecd..490cd7942947 100644 --- a/rllib/core/rl_module/__init__.py +++ b/rllib/core/rl_module/__init__.py @@ -7,12 +7,12 @@ MultiRLModuleSpec, ) from ray.util import log_once -from ray.util.annotations import PublicAPI +from ray.util.annotations import DeveloperAPI logger = logging.getLogger("ray.rllib") -@PublicAPI(stability="alpha") +@DeveloperAPI def validate_module_id(policy_id: str, error: bool = False) -> None: """Makes sure the given `policy_id` is valid. diff --git a/rllib/utils/checkpoints.py b/rllib/utils/checkpoints.py index dde74deaf889..1c8e9531fc34 100644 --- a/rllib/utils/checkpoints.py +++ b/rllib/utils/checkpoints.py @@ -685,7 +685,7 @@ def _is_dir(file_info: pyarrow.fs.FileInfo) -> bool: return file_info.type == pyarrow.fs.FileType.Directory -@PublicAPI(stability="alpha") +@OldAPIStack def get_checkpoint_info( checkpoint: Union[str, Checkpoint], filesystem: Optional["pyarrow.fs.FileSystem"] = None, diff --git a/rllib/utils/policy.py b/rllib/utils/policy.py index 9cadcb08b054..a5b6b2ccfda6 100644 --- a/rllib/utils/policy.py +++ b/rllib/utils/policy.py @@ -19,6 +19,7 @@ from ray.rllib.models.preprocessors import ATARI_OBS_SHAPE from ray.rllib.policy.policy import PolicySpec from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import OldAPIStack from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import ( @@ -31,7 +32,6 @@ TensorType, ) from ray.util import log_once -from ray.util.annotations import PublicAPI if TYPE_CHECKING: from ray.rllib.policy.policy import Policy @@ -41,7 +41,7 @@ tf1, tf, tfv = try_import_tf() -@PublicAPI +@OldAPIStack def create_policy_for_framework( policy_id: str, policy_class: Type["Policy"], @@ -108,7 +108,7 @@ def create_policy_for_framework( return policy_class(observation_space, action_space, merged_config) -@PublicAPI(stability="alpha") +@OldAPIStack def parse_policy_specs_from_checkpoint( path: str, ) -> Tuple[PartialAlgorithmConfigDict, Dict[str, PolicySpec], Dict[str, PolicyState]]: @@ -137,7 +137,7 @@ def parse_policy_specs_from_checkpoint( return policy_config, policy_specs, policy_states -@PublicAPI(stability="alpha") +@OldAPIStack def local_policy_inference( policy: "Policy", env_id: str, @@ -242,7 +242,7 @@ def local_policy_inference( return outputs -@PublicAPI +@OldAPIStack def compute_log_likelihoods_from_input_dict( policy: "Policy", batch: Union[SampleBatch, Dict[str, TensorStructType]] ): diff --git a/rllib/utils/schedules/scheduler.py b/rllib/utils/schedules/scheduler.py index ea7eca808457..901b5c785acd 100644 --- a/rllib/utils/schedules/scheduler.py +++ b/rllib/utils/schedules/scheduler.py @@ -19,8 +19,8 @@ class Scheduler: def __init__( self, - *, fixed_value_or_schedule: LearningRateOrSchedule, + *, framework: str = "torch", device: Optional[str] = None, ): diff --git a/rllib/utils/tensor_dtype.py b/rllib/utils/tensor_dtype.py index c1865756d893..83677d80a46a 100644 --- a/rllib/utils/tensor_dtype.py +++ b/rllib/utils/tensor_dtype.py @@ -2,7 +2,7 @@ from ray.rllib.utils.typing import TensorType from ray.rllib.utils.framework import try_import_torch, try_import_tf -from ray.util.annotations import PublicAPI +from ray.util.annotations import DeveloperAPI torch, _ = try_import_torch() _, tf, _ = try_import_tf() @@ -52,7 +52,7 @@ tf_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_tf_dtype_dict.items()} 
-@PublicAPI(stability="alpha") +@DeveloperAPI def get_np_dtype(x: TensorType) -> np.dtype: """Returns the NumPy dtype of the given tensor or array.""" if torch and isinstance(x, torch.Tensor): diff --git a/rllib/utils/torch_utils.py b/rllib/utils/torch_utils.py index 0d360d4d1488..1016edb31f0f 100644 --- a/rllib/utils/torch_utils.py +++ b/rllib/utils/torch_utils.py @@ -10,7 +10,7 @@ import tree # pip install dm_tree from ray.rllib.models.repeated_values import RepeatedValues -from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI +from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI, OldAPIStack from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import SMALL_NUMBER from ray.rllib.utils.typing import ( @@ -42,9 +42,7 @@ ) -# TODO (sven): Deprecate this function once we have moved completely to the Learner API. -# Replaced with `clip_gradients()`. -@PublicAPI +@OldAPIStack def apply_grad_clipping( policy: "TorchPolicy", optimizer: LocalOptimizer, loss: TensorType ) -> Dict[str, TensorType]: @@ -202,7 +200,7 @@ def compute_global_norm(gradients_list: "ParamList") -> TensorType: return total_norm -@PublicAPI +@OldAPIStack def concat_multi_gpu_td_errors( policy: Union["TorchPolicy", "TorchPolicyV2"] ) -> Dict[str, TensorType]: @@ -482,7 +480,7 @@ def global_norm(tensors: List[TensorType]) -> TensorType: return torch.pow(sum(torch.pow(l2, 2.0) for l2 in single_l2s), 0.5) -@PublicAPI +@OldAPIStack def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType: """Computes the huber loss for a given term and delta parameter. @@ -507,7 +505,7 @@ def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType: ) -@PublicAPI +@OldAPIStack def l2_loss(x: TensorType) -> TensorType: """Computes half the L2 norm over a tensor's values without the sqrt. @@ -522,27 +520,6 @@ def l2_loss(x: TensorType) -> TensorType: return 0.5 * torch.sum(torch.pow(x, 2.0)) -@PublicAPI -def minimize_and_clip( - optimizer: "torch.optim.Optimizer", clip_val: float = 10.0 -) -> None: - """Clips grads found in `optimizer.param_groups` to given value in place. - - Ensures the norm of the gradients for each variable is clipped to - `clip_val`. - - Args: - optimizer: The torch.optim.Optimizer to get the variables from. - clip_val: The global norm clip value. Will clip around -clip_val and - +clip_val. - """ - # Loop through optimizer's variables and norm per variable. - for param_group in optimizer.param_groups: - for p in param_group["params"]: - if p.grad is not None: - torch.nn.utils.clip_grad_norm_(p.grad, clip_val) - - @PublicAPI def one_hot(x: TensorType, space: gym.Space) -> TensorType: """Returns a one-hot tensor, given and int tensor and a space. @@ -661,9 +638,11 @@ def update_target_network( ) -> None: """Updates a torch.nn.Module target network using Polyak averaging. - new_target_net_weight = ( - tau * main_net_weight + (1.0 - tau) * current_target_net_weight - ) + .. code-block:: text + + new_target_net_weight = ( + tau * main_net_weight + (1.0 - tau) * current_target_net_weight + ) Args: main_net: The nn.Module to update from.
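For reference, the Polyak update described in the docstring above can be written as the following self-contained sketch. This is only an illustration of the formula, not RLlib's actual ``update_target_network`` implementation, and the helper name is hypothetical.

.. code-block:: python

    import torch


    def polyak_update(
        main_net: torch.nn.Module, target_net: torch.nn.Module, tau: float
    ) -> None:
        # In-place update: target <- tau * main + (1 - tau) * target.
        with torch.no_grad():
            for target_param, main_param in zip(
                target_net.parameters(), main_net.parameters()
            ):
                target_param.mul_(1.0 - tau).add_(main_param, alpha=tau)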