
[Performance] Faster losses (#1272)
vmoens authored Jul 2, 2023
1 parent 12cbe72 commit c42ff78
Showing 13 changed files with 205 additions and 109 deletions.
18 changes: 9 additions & 9 deletions .github/workflows/nightly_build.yml
@@ -34,7 +34,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python_version: [["3.7", "cp37-cp37m"], ["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
cuda_support: [["", "cpu", "cpu"]]
container: pytorch/manylinux-cuda116
steps:
@@ -73,7 +73,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
python_version: [["3.7", "3.7"], ["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -106,7 +106,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
python_version: [["3.7", "3.7"], ["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -158,7 +158,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python_version: [["3.7", "cp37-cp37m"], ["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
cuda_support: [["", "cpu", "cpu"]]
container: pytorch/manylinux-${{ matrix.cuda_support[2] }}
steps:
@@ -189,7 +189,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
python_version: [["3.7", "3.7"], ["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
steps:
- name: Checkout torchrl
uses: actions/checkout@v2
@@ -217,7 +217,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python_version: [["3.7", "cp37-cp37"], ["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
cuda_support: [["", "cpu", "cpu"]]
steps:
- name: Setup Python
@@ -279,7 +279,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
python_version: [["3.7", "3.7"], ["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -312,7 +312,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
python_version: [["3.7", "3.7"], ["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -369,7 +369,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
python_version: [["3.7", "3.7"], ["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
steps:
- name: Checkout torchrl
uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
@@ -267,7 +267,7 @@ def _get_policy_and_device(

try:
policy_device = next(policy.parameters()).device
-except: # noqa
+except Exception:
policy_device = (
torch.device(device) if device is not None else torch.device("cpu")
)
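
A note on the collectors.py change: a bare except also catches BaseException subclasses such as KeyboardInterrupt and SystemExit, so it could swallow a Ctrl-C during collection. Below is a minimal sketch of the fallback pattern; the helper name is hypothetical, torchrl inlines this logic in _get_policy_and_device.

import torch

def policy_device(policy, device=None):
    # Hypothetical helper mirroring the fallback above.
    try:
        return next(policy.parameters()).device
    except Exception:
        # e.g. StopIteration (no parameters) or AttributeError (plain
        # callable policy); KeyboardInterrupt still propagates.
        return torch.device(device) if device is not None else torch.device("cpu")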
8 changes: 7 additions & 1 deletion torchrl/objectives/a2c.py
@@ -15,6 +15,7 @@
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
+_cache_values,
default_value_kwargs,
distance_loss,
ValueEstimators,
@@ -333,14 +334,19 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
)
return self.critic_coef * loss_value

+@property
+@_cache_values
+def _cached_detach_critic_params(self):
+return self.critic_params.detach()
+
@dispatch()
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None:
self.value_estimator(
tensordict,
-params=self.critic_params.detach(),
+params=self._cached_detach_critic_params,
target_params=self.target_critic_params,
)
advantage = tensordict.get(self.tensor_keys.advantage)
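
Throughout this commit, _cache_values (imported from torchrl.objectives.utils) memoizes parameter getters so that work such as critic_params.detach() runs once instead of on every forward call. Its implementation is not part of this diff; the following is only a sketch of the idea, assuming results are stored in the module's _cache dict so that _erase_cache in common.py (below) can invalidate them:

import functools

def _cache_values(func):
    # Sketch: memoize a method's result on the instance, keyed by the
    # method name and its positional args. _erase_cache deletes every
    # "_cache"-prefixed attribute, which wipes this dict and forces
    # recomputation after the module is cast or moved.
    @functools.wraps(func)
    def wrapper(self, *args):
        cache = self.__dict__.setdefault("_cache", {})
        key = (func.__name__, *args)
        if key not in cache:
            cache[key] = func(self, *args)
        return cache[key]

    return wrapper

Stacked under @property, as in _cached_detach_critic_params above, the cached getter reads like a plain attribute.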
33 changes: 18 additions & 15 deletions torchrl/objectives/common.py
@@ -26,7 +26,7 @@
from torchrl._utils import RL_WARNINGS
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules.utils import Buffer
-from torchrl.objectives.utils import ValueEstimators
+from torchrl.objectives.utils import _cache_values, ValueEstimators
from torchrl.objectives.value import ValueEstimatorBase

_has_functorch = False
@@ -104,10 +104,12 @@ def tensor_keys(self) -> _AcceptedKeys:
def __new__(cls, *args, **kwargs):
cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
cls._tensor_keys = cls._AcceptedKeys()
-return super().__new__(cls)
+self = super().__new__(cls)
+return self

def __init__(self):
super().__init__()
+self._cache = {}
self._param_maps = {}
self._value_estimator = None
self._has_update_associated = False
@@ -389,6 +391,7 @@ def _compare_and_expand(param):
property(lambda _self=self: _self._target_param_getter(module_name)),
)

+@_cache_values
def _param_getter(self, network_name):
name = "_" + network_name + "_params"
param_name = network_name + "_params"
@@ -417,6 +420,7 @@ def _param_getter(self, network_name):
f"{self.__class__.__name__} does not have the target param {name}"
)

+@_cache_values
def _target_param_getter(self, network_name):
target_name = "_target_" + network_name + "_params"
param_name = network_name + "_params"
@@ -452,6 +456,17 @@ def _target_param_getter(self, network_name):
f"{self.__class__.__name__} does not have the target param {target_name}"
)

+def _apply(self, fn):
+# any call to _apply erases the cache: detached params would not be
+# cast by fn, so the cache must be rebuilt after the cast
+self._erase_cache()
+return super()._apply(fn)
+
+def _erase_cache(self):
+for key in list(self.__dict__):
+if key.startswith("_cache"):
+del self.__dict__[key]
+
def _networks(self) -> Iterator[nn.Module]:
for item in self.__dir__():
if isinstance(item, nn.Module):
@@ -491,19 +506,7 @@ def to(self, *args, **kwargs):
origin_value = getattr(self, origin)
target_value = getattr(self, target)
setattr(self, target, origin_value.expand_as(target_value))

-# lists_of_params = {
-# name: value
-# for name, value in self.__dict__.items()
-# if name.endswith("_params") and isinstance(value, TensorDictBase)
-# }
-# for list_of_params in lists_of_params.values():
-# for key, param in list(list_of_params.items(True)):
-# if isinstance(param, TensorDictBase):
-# continue
-# # we replace the param by the expanded form if needs be
-# if param in self._param_maps:
-# list_of_params[key] = self._param_maps[param].data.expand_as(param)
+out._cache = {}
return out

def cuda(self, device: Optional[Union[int, device]] = None) -> LossModule:
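
The _apply override is the price of caching: nn.Module._apply (called by .to(), .cuda(), .float(), ...) casts or moves the registered tensors, while a cached detached copy is a separate tensor that would silently keep the old device and dtype. A self-contained sketch of the failure mode the erasure prevents (toy class, not torchrl code):

import torch
from torch import nn

class TinyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3))
        self._cache = {}

    @property
    def detached_weight(self):
        # Cache the detached copy, as the loss modules cache params.
        if "detached_weight" not in self._cache:
            self._cache["detached_weight"] = self.weight.detach()
        return self._cache["detached_weight"]

    def _apply(self, fn):
        # Without this, the cached tensor would stay float32 after .double().
        self._cache.clear()
        return super()._apply(fn)

loss = TinyLoss()
_ = loss.detached_weight  # populate the cache
loss = loss.double()      # _apply drops the stale entry
assert loss.detached_weight.dtype == torch.float64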
24 changes: 18 additions & 6 deletions torchrl/objectives/cql.py
@@ -22,6 +22,7 @@
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
+_cache_values,
default_value_kwargs,
distance_loss,
ValueEstimators,
@@ -339,6 +340,9 @@ def __init__(
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)

+self._vmap_qvalue_networkN0 = vmap(self.qvalue_network, (None, 0))
+self._vmap_qvalue_network00 = vmap(self.qvalue_network)
+
@property
def target_entropy(self):
target_entropy = self.target_entropy_buffer
@@ -491,6 +495,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

return TensorDict(out, [])

+@property
+@_cache_values
+def _cached_detach_qvalue_params(self):
+return self.qvalue_network_params.detach()
+
def _loss_actor(self, tensordict: TensorDictBase) -> Tensor:
with set_exploration_type(ExplorationType.RANDOM):
dist = self.actor_network.get_dist(
@@ -502,8 +511,9 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor:

td_q = tensordict.select(*self.qvalue_network.in_keys)
td_q.set(self.tensor_keys.action, a_reparm)
-td_q = vmap(self.qvalue_network, (None, 0))(
-td_q, self.qvalue_network_params.detach().clone()
+td_q = self._vmap_qvalue_networkN0(
+td_q,
+self._cached_detach_qvalue_params,
)
min_q_logprob = (
td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
@@ -554,7 +564,7 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):

# get q-values
if not self.max_q_backup:
-next_tensordict_expand = vmap(self.qvalue_network, (None, 0))(
+next_tensordict_expand = self._vmap_qvalue_networkN0(
next_tensordict, qval_params
)
next_state_value = next_tensordict_expand.get(
@@ -575,7 +585,7 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
actor_params,
num_actions=self.num_random,
)
-next_tensordict_expand = vmap(self.qvalue_network, (None, 0))(
+next_tensordict_expand = self._vmap_qvalue_networkN0(
next_tensordict, qval_params
)

@@ -605,7 +615,7 @@ def _loss_qvalue_v(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
)

tensordict_pred_q = tensordict.select(*self.qvalue_network.in_keys)
-q_pred = vmap(self.qvalue_network, (None, 0))(
+q_pred = self._vmap_qvalue_networkN0(
tensordict_pred_q, self.qvalue_network_params
).get(self.tensor_keys.state_action_value)

@@ -670,7 +680,9 @@ def _loss_qvalue_v(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
)
cql_tensordict = cql_tensordict.contiguous()

-cql_tensordict_expand = vmap(self.qvalue_network)(cql_tensordict, qvalue_params)
+cql_tensordict_expand = self._vmap_qvalue_network00(
+cql_tensordict, qvalue_params
+)
# get q values
state_action_value = cql_tensordict_expand.get(
self.tensor_keys.state_action_value
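
The cql.py changes show the commit's second optimization: the vmap-transformed callables (_vmap_qvalue_networkN0 with in_dims=(None, 0), and _vmap_qvalue_network00 mapping over both inputs) are built once in __init__ and reused, instead of re-creating them inside every loss call. A sketch of the same pattern using torch.func's model-ensembling recipe; the toy network is hypothetical, and torchrl itself vmaps a TensorDictModule over a stacked params TensorDict:

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, obs):
        return self.fc(obs)

# Two Q-networks whose parameters are stacked along a leading dim.
params, buffers = stack_module_state([QNet(), QNet()])
base = QNet().to("meta")  # structure only, no storage

def qnet_call(p, b, obs):
    return functional_call(base, (p, b), (obs,))

# Build the transformed callable ONCE (as in __init__ above): the same
# obs is broadcast (in_dims=None) while params are mapped over dim 0.
vmap_qnet_N0 = vmap(qnet_call, in_dims=(0, 0, None))

obs = torch.randn(8, 4)
q_values = vmap_qnet_N0(params, buffers, obs)  # shape [2, 8, 1]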
43 changes: 26 additions & 17 deletions torchrl/objectives/ddpg.py
@@ -19,9 +19,9 @@
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
+_cache_values,
default_value_kwargs,
distance_loss,
-hold_out_params,
ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
@@ -299,11 +299,10 @@ def _loss_actor(
td_copy,
params=self.actor_network_params,
)
-with hold_out_params(self.value_network_params) as params:
-td_copy = self.value_network(
-td_copy,
-params=params,
-)
+td_copy = self.value_network(
+td_copy,
+params=self._cached_detached_value_params,
+)
return -td_copy.get(self.tensor_keys.state_action_value)

def _loss_value(
@@ -318,18 +317,8 @@ def _loss_value(
)
pred_val = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)

-target_params = TensorDict(
-{
-"module": {
-"0": self.target_actor_network_params,
-"1": self.target_value_network_params,
-}
-},
-batch_size=self.target_actor_network_params.batch_size,
-device=self.target_actor_network_params.device,
-)
target_value = self.value_estimator.value_estimate(
-tensordict, target_params=target_params
+tensordict, target_params=self._cached_target_params
).squeeze(-1)

# td_error = pred_val - target_value
@@ -368,3 +357,23 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
"done": self.tensor_keys.done,
}
self._value_estimator.set_keys(**tensor_keys)

+@property
+@_cache_values
+def _cached_target_params(self):
+target_params = TensorDict(
+{
+"module": {
+"0": self.target_actor_network_params,
+"1": self.target_value_network_params,
+}
+},
+batch_size=self.target_actor_network_params.batch_size,
+device=self.target_actor_network_params.device,
+)
+return target_params
+
+@property
+@_cache_values
+def _cached_detached_value_params(self):
+return self.value_network_params.detach()
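
Caching the assembled target parameters is sound because a TensorDict nests by reference: the cached structure points at the very tensors held by the target networks, so in-place target updates (as performed by the soft/hard updaters) stay visible through it. A small sketch of that aliasing behaviour, with hypothetical keys:

import torch
from tensordict import TensorDict

actor_target = TensorDict({"weight": torch.zeros(2)}, [])
value_target = TensorDict({"weight": torch.ones(2)}, [])

# Assemble once, as the cached property does:
combined = TensorDict({"module": {"0": actor_target, "1": value_target}}, [])

# An in-place (soft) update of a source tensor...
actor_target["weight"].add_(1.0)

# ...is visible through the cached, combined structure:
assert (combined["module", "0", "weight"] == 1.0).all()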