-
-
Notifications
You must be signed in to change notification settings - Fork 192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
🚚🪞Make sure inductive representation is on same device #1229
Conversation
Maybe it might be more consistent to simply override def to(self, *args, **kwargs):
for mode in [TRAINING, VALIDATION, TESTING]:
self._mode_to_representations[mode] = self._mode_to_representations[mode].to(*args, **kwargs)
return super().to(*args, **kwargs) |
Why are the representations not a I tend towards making |
Looks like the decision to not make it a |
@@ -97,8 +100,9 @@ def _get_entity_representations_from_inductive_mode( | |||
raise ValueError( | |||
f"{self.__class__.__name__} does not support the transductive setting (i.e., when mode is None)" | |||
) | |||
if mode in self._mode_to_representations: | |||
return self._mode_to_representations[mode] | |||
key = f"{mode}_factory" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this mode_factory and not just mode?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For a (key, value)
pair torch.nn.ModuleDict
registers value
as an attribute with name key
(cf. here). Since ModuleDict
is a Module
, too, it already has an attribute training
, which is why we cannot have the key training
in the dict. Thus, I suffixed them with _factory
. I did not want to change the InductiveMode
constants though, thus we need to derive a key different from the mode
.
EDIT: the relevant commit is a086760 and there is a small note above the declaration of the dictionary.
EDIT2: we could have refactored key = f"{mode}_factory"
into a helper method, but I thought this would make the code less readable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay just making sure you thought about this. feel free to merge when ready
Contributing to bug fixes
Link to the relevant Bug(s)
#1228
Description of the Change
Before returning
entity_representation
for the right mode, sent them to the same device as model.Possible Drawbacks
None?
Verification Process
I tested this on the code in the original issue and also on https://github.com/pykeen/ilpc2022 for the small dataset. Both previously failed, but now run through as expected.
Release Notes