[Feature] reset_parameters for multiagent nets #1970
Conversation
```python
return torch.vmap(reset_module, *args, **kwargs)
```

```python
if not self.share_params:
    vmap_reset_module(self._empty_net, randomness="different")(self.params)
```
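For context, here is a minimal sketch of how such a reset can dispatch over stacked per-agent parameters. This is not the PR's exact code: the `_Sketch` class is a hypothetical stand-in exposing the `share_params`, `_empty_net` and `params` attributes discussed in this thread.

```python
import torch

class _Sketch:  # hypothetical stand-in for the multiagent module
    def reset_parameters(self):
        def reset_module(params):
            # Load one agent's parameter slice into the stateless template
            # net, re-initialize it in place, and return the slice.
            with params.to_module(self._empty_net):
                for m in self._empty_net.modules():
                    if hasattr(m, "reset_parameters"):
                        m.reset_parameters()
            return params

        if not self.share_params:
            # One independent re-init per agent: "different" draws fresh
            # random numbers for every element of the vmapped batch.
            torch.vmap(reset_module, randomness="different")(self.params)
        else:
            reset_module(self.params)
```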
I would like to discuss the vmap randomness of this class:

```python
@property
def _vmap_randomness(self):
    if self.initialized:
        return "error"
    return "same"
```
Why is this a property, and why do we have these values? For me this should simply be:

```python
@property
def _vmap_randomness(self):
    return "different"
```
In this case it needs to be `"different"` so that each agent gets different reset values. But I feel like it should also be `"different"` in the forward pass.
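As a standalone illustration of what `"different"` buys in a forward pass (plain `torch`, no TorchRL), vmapping a dropout layer only gives each batch element its own mask under that setting:

```python
import torch

drop = torch.nn.Dropout(p=0.5)  # nn.Module defaults to training mode
x = torch.ones(3, 4)

# "different": each of the 3 vmapped elements gets an independent mask.
# "same" would share one mask; "error" (the default) would raise.
out = torch.vmap(drop, randomness="different")(x)
print(out)
```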
IMO it should be:

```python
@property
def _vmap_randomness(self):
    if self.initialized:
        return self.vmap_randomness
    return "different"
```

and users are in charge of telling the module what randomness they want.
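A minimal sketch of that proposal, with a hypothetical `MultiAgentNetSketch` class standing in for the real module and `vmap_randomness` / `initialized` mirroring the attributes discussed above:

```python
import torch

class MultiAgentNetSketch(torch.nn.Module):
    def __init__(self, *, vmap_randomness: str = "error"):
        super().__init__()
        # User-chosen post-init behaviour: "error", "same" or "different".
        self.vmap_randomness = vmap_randomness
        self.initialized = False

    @property
    def _vmap_randomness(self):
        # Init always needs independent draws per agent; afterwards the
        # user's choice applies.
        if self.initialized:
            return self.vmap_randomness
        return "different"
```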
Could you explain this a bit? Why do we have a switch on initialization?
For init you need `"different"` because each net must get different weights. But in other settings you can't tell, and the best is to let the user choose: they may as well want the same random numbers for each element of the batch.
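The three settings are easy to see with a toy `torch.vmap` call (standalone, not TorchRL code):

```python
import torch

x = torch.zeros(4)
# "same": one shared draw across the batch dimension -> four equal values.
print(torch.vmap(lambda _: torch.rand(()), randomness="same")(x))
# "different": an independent draw per batch element -> four distinct values.
print(torch.vmap(lambda _: torch.rand(()), randomness="different")(x))
```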
But what I do not understand is why we previously had:

```python
@property
def _vmap_randomness(self):
    if self.initialized:
        return "error"
    return "same"
```
If I try to change this to:

```python
@property
def _vmap_randomness(self):
    if self.initialized:
        return self.vmap_randomness
    return "different"
```

the lazy layers will crash.
> the lazy layers will crash

This is a statement that is hard to reproduce; can you share more?

For instance, this works fine on my end:

```python
import torch
from tensordict import TensorDict

# Stack the parameters of three identical linear layers into one TensorDict.
modules = [torch.nn.Linear(2, 3) for _ in range(3)]
td = TensorDict.from_modules(*modules, as_module=True)

def reset(td):
    # Swap this parameter slice into the template module, re-initialize it
    # in place, and hand the slice back.
    with td.to_module(modules[0]):
        modules[0].reset_parameters()
    return td

td = torch.vmap(reset, randomness="same")(td)
print(td["weight"])
td = torch.vmap(reset, randomness="different")(td)
print(td["weight"])
```

The first produces a stack of identical tensors, the second a stack of different ones.
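And, continuing the snippet above, the strict default refuses to run random ops at all:

```python
try:
    torch.vmap(reset, randomness="error")(td)
except RuntimeError as err:
    # "error" is vmap's default: hitting a random op inside the vmapped
    # function raises instead of silently picking a semantic.
    print(err)
```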
Title changed: "[Feature] reset_params for multiagent nets" → "[Feature] reset_parameters for multiagent nets"
Thanks!
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Fixes #1967
Todo: