In the DQN loss's `forward`, we have this code to check the device:
```python
if self.device is not None:
    warnings.warn(
        "The use of a device for the objective function will soon be deprecated",
        category=DeprecationWarning,
    )
    device = self.device
else:
    device = tensordict.device
```
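The branch can be exercised in isolation by wrapping it in a function. This is a minimal sketch: `pick_device` is a hypothetical helper, and its two arguments stand in for `self.device` and `tensordict.device`. Whenever the first argument is not `None`, the deprecation warning fires and the tensordict's device is ignored.

```python
import warnings

import torch


def pick_device(loss_device, td_device):
    """Hypothetical helper mirroring the device branch in the DQN forward."""
    if loss_device is not None:
        warnings.warn(
            "The use of a device for the objective function will soon be deprecated",
            category=DeprecationWarning,
        )
        return loss_device
    return td_device


with warnings.catch_warnings(record=True) as caught:
    # DeprecationWarning is hidden by default outside __main__, so force it on.
    warnings.simplefilter("always")
    chosen = pick_device(torch.device("cpu"), torch.device("cuda"))

print(chosen, [w.category.__name__ for w in caught])
```

Constructing `torch.device("cuda")` does not require a GPU, so the sketch runs on a CPU-only machine.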
The problem is that `loss.device` is implemented as:
```python
@property
def device(self) -> torch.device:
    for p in self.parameters():
        return p.device
    return torch.device("cpu")
```
This property never returns `None` (it falls back to `torch.device("cpu")` when the module has no parameters), so the first branch is always taken and the warning is always triggered.
So this logic has never made any sense.
How should we sort this out?
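One way to confirm that the property can never yield `None` is to copy it onto a toy module, with and without parameters. This is a minimal sketch; `ToyLoss` and its `value_network` attribute are hypothetical names used only for illustration, not the library's loss class.

```python
import torch
from torch import nn


class ToyLoss(nn.Module):
    """Hypothetical stand-in for a loss module (not the library's class)."""

    def __init__(self, with_params: bool = True):
        super().__init__()
        if with_params:
            self.value_network = nn.Linear(3, 1)

    @property
    def device(self) -> torch.device:
        # Same logic as the `device` property quoted above.
        for p in self.parameters():
            return p.device
        return torch.device("cpu")


with_params_device = ToyLoss(with_params=True).device
no_params_device = ToyLoss(with_params=False).device
print(with_params_device, no_params_device)  # cpu cpu
```

Since neither case yields `None`, the `if self.device is not None` test is always true.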
Not sure about the "never", but I agree that we should fix that.
We should simply not check the device.
If you give a tensor on the wrong device to `torch.nn.Linear`, `Linear` does not check it up front; things only break when the op is executed. We should just remove all device checks from the losses.
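A minimal sketch of what a loss without any device bookkeeping could look like, assuming plain tensors instead of a tensordict. `ToyDQNLoss`, its argument names and the one-step TD target are hypothetical illustrations, not the library's DQN loss. Inputs are used on whatever device they arrive on, and a mismatch is reported only by the op that consumes them. The CUDA branch runs only when a GPU is present.

```python
import torch
from torch import nn


class ToyDQNLoss(nn.Module):
    """Hypothetical, simplified DQN-style loss with no device checks at all."""

    def __init__(self, n_obs: int, n_actions: int, gamma: float = 0.99):
        super().__init__()
        self.value_network = nn.Linear(n_obs, n_actions)
        self.gamma = gamma

    def forward(self, obs, action, reward, next_obs, done):
        # No device resolution: if the inputs live on a different device than
        # the parameters, the op inside nn.Linear raises a RuntimeError.
        q = self.value_network(obs).gather(-1, action.unsqueeze(-1)).squeeze(-1)
        with torch.no_grad():
            next_q = self.value_network(next_obs).max(dim=-1).values
            target = reward + self.gamma * (1.0 - done.float()) * next_q
        return nn.functional.mse_loss(q, target)


batch = 8
loss_fn = ToyDQNLoss(n_obs=4, n_actions=2)
loss = loss_fn(
    torch.randn(batch, 4),                # obs
    torch.randint(0, 2, (batch,)),        # action indices
    torch.randn(batch),                   # reward
    torch.randn(batch, 4),                # next_obs
    torch.zeros(batch, dtype=torch.bool), # done
)
print(loss.shape)  # torch.Size([])

if torch.cuda.is_available():
    # Parameters on CPU, input on CUDA: nothing complains until the op runs.
    try:
        loss_fn.value_network(torch.randn(batch, 4, device="cuda"))
    except RuntimeError as err:
        print("raised at op execution:", type(err).__name__)
```

Leaving the check to the op keeps the loss free of a `device` attribute that can drift out of sync with its parameters.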