Fix batch size calculation for multicomponent #1098
Merged
+30
−12
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
A
LightningModule
has aself.log
method for logging metrics. This takes thebatch_size
argument to help it know what your batch size is when you have a complex data structure. Our batch is a tuple of 7ish things, so by defaultlightning
would see our batch size as 7. We get around this by explicitly passingbatch_size=len(batch[0])
. This calls the__len__
of the batch mol graphs which is set to the number of mol graphs.MulticomponentMPNN
inherits fromMPNN
and reuses a lot of its code. Butbatch[0]
for the multicomponent case is a list of BatchMolGraphs.len(batch[0])
will be the number of components and not the true batch size. I fixed this by adding aget_batch_size
method to bothMulticomponentMPNN
andMPNN
. I also updated the type hints to reflect the fact thatMulticomponentMPNN
reuses code fromMPNN
.