From 815eece70821926b0e273fe448d21bf37ea09082 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 14:59:59 +0100 Subject: [PATCH] [Feature] inline `hold_out_net` ghstack-source-id: c315202c8af55f0852195fe488ae855966386c4c Pull Request resolved: https://github.com/pytorch/rl/pull/2499 --- torchrl/objectives/utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 31954005195..afd28e861c7 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -408,11 +408,18 @@ def __init__(self, network: nn.Module) -> None: def __enter__(self) -> None: if self.mode: - self.network.requires_grad_(False) + if is_dynamo_compiling(): + self._params = TensorDict.from_module(self.network) + self._params.data.to_module(self.network) + else: + self.network.requires_grad_(False) def __exit__(self, exc_type, exc_val, exc_tb) -> None: if self.mode: - self.network.requires_grad_() + if is_dynamo_compiling(): + self._params.to_module(self.network) + else: + self.network.requires_grad_() class hold_out_params(_context_manager):