Fix batch size calculation for multicomponent #1098

Merged: 4 commits, Dec 24, 2024

Conversation

@KnathanM (Member) commented Nov 5, 2024

A `LightningModule` has a `self.log` method for logging metrics. It takes a `batch_size` argument so that Lightning knows the batch size when the batch is a complex data structure. Our batch is a tuple of about seven items, so by default Lightning would see our batch size as 7. We get around this by explicitly passing `batch_size=len(batch[0])`. This calls the `__len__` of the `BatchMolGraph`, which returns the number of mol graphs in the batch.
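To illustrate the difference between the two lengths, here is a minimal sketch. `FakeBatchMolGraph` is a hypothetical stand-in, not chemprop's actual `BatchMolGraph`, and the remaining tuple fields are placeholders; the only assumption carried over from the description is that `__len__` returns the number of mol graphs.

```python
class FakeBatchMolGraph:
    """Hypothetical stand-in for a batched molecular graph."""

    def __init__(self, n_mols: int):
        self.n_mols = n_mols

    def __len__(self) -> int:
        # The length of the batched graph is the number of mol graphs it holds.
        return self.n_mols


# A batch shaped like the one described: the batched graph first,
# followed by six other fields (placeholders here).
batch = (FakeBatchMolGraph(32), None, None, None, None, None, None)

tuple_len = len(batch)      # number of fields in the tuple
batch_size = len(batch[0])  # number of mol graphs, the value to pass as batch_size

print(tuple_len, batch_size)
```

The second value is what gets passed explicitly as `batch_size` when logging.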

`MulticomponentMPNN` inherits from `MPNN` and reuses much of its code. In the multicomponent case, however, `batch[0]` is a list of `BatchMolGraph`s, so `len(batch[0])` is the number of components rather than the true batch size. I fixed this by adding a `get_batch_size` method to both `MulticomponentMPNN` and `MPNN`. I also updated the type hints to reflect the fact that `MulticomponentMPNN` reuses code from `MPNN`.
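The shape of the fix can be sketched roughly as follows. The class names `MPNNSketch`, `MulticomponentMPNNSketch`, and `FakeBatchMolGraph` are hypothetical stand-ins, and the method bodies are an assumption about how such an override could look, not a copy of the merged code.

```python
class FakeBatchMolGraph:
    """Hypothetical stand-in whose length is the number of mol graphs."""

    def __init__(self, n_mols: int):
        self.n_mols = n_mols

    def __len__(self) -> int:
        return self.n_mols


class MPNNSketch:
    def get_batch_size(self, batch) -> int:
        # Single component: batch[0] is one batched graph.
        return len(batch[0])


class MulticomponentMPNNSketch(MPNNSketch):
    def get_batch_size(self, batch) -> int:
        # Multicomponent: batch[0] is a list with one batched graph per
        # component, so take the length of the first component instead.
        return len(batch[0][0])


single_batch = (FakeBatchMolGraph(32), None)
multi_batch = ([FakeBatchMolGraph(32), FakeBatchMolGraph(32)], None)

single_size = MPNNSketch().get_batch_size(single_batch)
multi_size = MulticomponentMPNNSketch().get_batch_size(multi_batch)
naive_multi_size = len(multi_batch[0])  # number of components, not the batch size

print(single_size, multi_size, naive_multi_size)
```

With a method like this, shared logging code in the parent class can call `self.get_batch_size(batch)` instead of `len(batch[0])`, and the subclass only overrides the one piece that differs.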

@KnathanM KnathanM requested a review from akshatzalte November 5, 2024 18:37
@KnathanM (Member, Author) commented:

@akshatzalte If this looks good to you, can you merge it?

@akshatzalte akshatzalte merged commit 6ddf25d into chemprop:main Dec 24, 2024
13 checks passed