Describe the bug
When a model that uses AttentiveAggregation is saved in a checkpoint (e.g. the weights from the epoch with the best validation loss), loading that .ckpt file with load_model or MPNN.load_from_checkpoint() fails with a TypeError saying that AttentiveAggregation's required output_size argument is missing.
Example(s)
```python
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

from chemprop import nn
from chemprop.models import MPNN
from chemprop.nn import AttentiveAggregation, BinaryClassificationFFN, BondMessagePassing

mp = BondMessagePassing(d_h=600, bias=True, depth=5, dropout=0.5, activation="relu", undirected=True)
agg = AttentiveAggregation(output_size=600, bias=True, activation="relu")
binary_class_ffn = BinaryClassificationFFN(n_layers=2, input_dim=1196, hidden_dim=600, dropout=0.2, activation="relu")
metrics_list = [
    nn.metrics.BinaryAccuracy(),
    nn.metrics.BinaryAUROC(),
    nn.metrics.BinaryAUPRC(),
    nn.metrics.BinaryF1Score(),
    nn.metrics.BinaryMCCMetric(),
]
model = MPNN(mp, agg, binary_class_ffn, batch_norm=True, metrics=metrics_list)

checkpointing = ModelCheckpoint(
    dirpath="chemprop_checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)
trainer = pl.Trainer(
    logger=False,
    enable_progress_bar=True,
    enable_checkpointing=True,
    devices=1,
    max_epochs=20,  # number of epochs to train for
    accelerator="gpu",
    callbacks=[checkpointing],
)

# After training with trainer.fit(...), loading the best checkpoint fails:
MPNN.load_from_checkpoint("/raid/home/debarka/KG_Works/Recreate_Molsorter/chemprop_checkpoints/best-epoch=17-val_loss=0.45.ckpt")  # Errors out!
```
Expected behavior
MPNN.load_from_checkpoint() should rebuild the model, including its AttentiveAggregation, from the checkpoint without errors.
Error Stack Trace
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[58], line 1
----> 1 MPNN.load_from_checkpoint("/raid/home/debarka/KG_Works/Recreate_Molsorter/chemprop_checkpoints/best-epoch=17-val_loss=0.45.ckpt")

File /raid/home/debarka/miniconda3/envs/chemprop/lib/python3.11/site-packages/chemprop/models/model.py:282, in MPNN.load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
    275 @classmethod
    276 def load_from_checkpoint(
    277     cls, checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs
    278 ) -> MPNN:
    279     submodules = {
    280         k: v for k, v in kwargs.items() if k in ["message_passing", "agg", "predictor"]
    281     }
--> 282     submodules, state_dict, hparams = cls._load(checkpoint_path, map_location, **submodules)
    283     kwargs.update(submodules)
    285     state_dict = cls._add_metric_task_weights_to_state_dict(state_dict, hparams)

File /raid/home/debarka/miniconda3/envs/chemprop/lib/python3.11/site-packages/chemprop/models/model.py:250, in MPNN._load(cls, path, map_location, **submodules)
    247     except KeyError:
    248         raise KeyError(f"Could not find hyper parameters and/or state dict in {path}.")
--> 250     submodules |= {
    251         key: hparams[key].pop("cls")(**hparams[key])
    252         for key in ("message_passing", "agg", "predictor")
    253         if key not in submodules
    254     }
    256     if not hasattr(submodules["predictor"].criterion, "_defaults"):
    257         submodules["predictor"].criterion = submodules["predictor"].criterion.__class__(
    258             task_weights=submodules["predictor"].criterion.task_weights
    259         )

File /raid/home/debarka/miniconda3/envs/chemprop/lib/python3.11/site-packages/chemprop/models/model.py:251, in <dictcomp>(.0)
    247     except KeyError:
    248         raise KeyError(f"Could not find hyper parameters and/or state dict in {path}.")
    250     submodules |= {
--> 251         key: hparams[key].pop("cls")(**hparams[key])
    252         for key in ("message_passing", "agg", "predictor")
    253         if key not in submodules
    254     }
    256     if not hasattr(submodules["predictor"].criterion, "_defaults"):
    257         submodules["predictor"].criterion = submodules["predictor"].criterion.__class__(
    258             task_weights=submodules["predictor"].criterion.task_weights
    259         )

TypeError: AttentiveAggregation.__init__() missing 1 required keyword-only argument: 'output_size'
```
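According to the traceback, `MPNN._load` rebuilds each submodule from the saved hyperparameters with `hparams[key].pop("cls")(**hparams[key])`, skipping any key the caller already passed to `load_from_checkpoint`. The sketch below mirrors that dict comprehension using hypothetical stand-in classes (`StandInAttentiveAggregation` and `StandInModule` are not chemprop classes) and an assumed hyperparameter dictionary, to show why a stored `agg` entry without `output_size` raises this `TypeError` and why a pre-built `agg` avoids the rebuild.

```python
from typing import Any


class StandInAttentiveAggregation:
    """Hypothetical stand-in whose output_size is a required keyword-only argument."""

    def __init__(self, dim: int = 0, *, output_size: int, **kwargs: Any):
        self.dim = dim
        self.output_size = output_size
        self.kwargs = kwargs


class StandInModule:
    """Hypothetical stand-in for message passing / predictor modules that rebuild fine."""

    def __init__(self, **kwargs: Any):
        self.kwargs = kwargs


def rebuild_submodules(hparams: dict[str, dict[str, Any]], **submodules: Any) -> dict[str, Any]:
    """Mirror of the dict comprehension at model.py:250-254 in the traceback."""
    submodules |= {
        key: hparams[key].pop("cls")(**hparams[key])
        for key in ("message_passing", "agg", "predictor")
        if key not in submodules  # caller-supplied submodules are never rebuilt
    }
    return submodules


def make_hparams() -> dict[str, dict[str, Any]]:
    """Assumed checkpoint contents: 'agg' kept its class and dim but not output_size."""
    return {
        "message_passing": {"cls": StandInModule, "d_h": 600},
        "agg": {"cls": StandInAttentiveAggregation, "dim": 0},
        "predictor": {"cls": StandInModule, "n_layers": 2},
    }


# Rebuilding everything from the stored hparams reproduces the reported error.
try:
    rebuild_submodules(make_hparams())
except TypeError as err:
    print(err)

# Passing agg up front skips its rebuild, so output_size is never looked up.
agg = StandInAttentiveAggregation(output_size=600)
modules = rebuild_submodules(make_hparams(), agg=agg)
print(modules["agg"] is agg)  # True
```

If chemprop follows the same logic, a possible (untested) workaround is `MPNN.load_from_checkpoint(path, agg=AttentiveAggregation(output_size=600, bias=True, activation="relu"))`, building the aggregation with the same arguments used during training so its weights match the saved state dict.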
Environment
aimsim_core 2.2.2
aiohappyeyeballs 2.4.4
aiohttp 3.11.9
aiosignal 1.3.1
astartes 1.3.0
asttokens 3.0.0
attrs 24.2.0
chemprop 2.1.0
comm 0.2.2
ConfigArgParse 1.7
debugpy 1.8.9
decorator 5.1.1
descriptastorus 2.8.0
dill 0.3.9
exceptiongroup 1.2.2
executing 2.1.0
filelock 3.13.1
frozenlist 1.5.0
fsspec 2024.2.0
h5py 3.12.1
idna 3.10
importlib_metadata 8.5.0
ipykernel 6.29.5
ipython 8.30.0
jedi 0.19.2
Jinja2 3.1.3
joblib 1.4.2
jupyter_client 8.6.3
jupyter_core 5.7.2
lightning 2.4.0
lightning-utilities 0.11.9
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib-inline 0.1.7
mdurl 0.1.2
mhfp 1.9.6
mordredcommunity 2.0.6
mpmath 1.3.0
multidict 6.1.0
multiprocess 0.70.17
nest_asyncio 1.6.0
networkx 3.2.1
numpy 1.26.3
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.1.105
nvidia-nvtx-cu12 12.1.105
packaging 24.2
padelpy 0.1.16
pandas 2.2.3
pandas-flavor 0.6.0
parso 0.8.4
pexpect 4.9.0
pickleshare 0.7.5
pillow 10.2.0
pip 24.2
platformdirs 4.3.6
prompt_toolkit 3.0.48
propcache 0.2.1
psutil 6.1.0
ptyprocess 0.7.0
pure_eval 0.2.3
Pygments 2.18.0
python-dateutil 2.9.0.post0
pytorch-lightning 2.4.0
pytz 2024.2
PyYAML 6.0.2
pyzmq 26.2.0
rdkit 2024.3.6
rich 13.9.4
scikit-learn 1.5.2
scipy 1.14.1
setuptools 75.1.0
six 1.16.0
stack-data 0.6.2
sympy 1.13.1
tabulate 0.9.0
threadpoolctl 3.5.0
torch 2.5.0+cu121
torchaudio 2.5.0+cu121
torchmetrics 1.6.0
torchvision 0.20.0+cu121
tornado 6.4.2
tqdm 4.67.1
traitlets 5.14.3
triton 3.1.0
typing_extensions 4.9.0
tzdata 2024.2
wcwidth 0.2.13
wheel 0.44.0
xarray 2024.11.0
yarl 1.18.3
zipp 3.21.0
Checklist
- `conda list` or `pip list` shows the packages listed in the `pyproject.toml`
- `pytest -v` reports no errors

Additional context
I think the checkpoint does not store AttentiveAggregation's `output_size` hyperparameter, so it cannot be passed back to the constructor when the model is rebuilt.
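One way to check this hypothesis is to load the raw checkpoint (for example with `torch.load(path, map_location="cpu")`; Lightning stores hyperparameters under `"hyper_parameters"`) and compare the stored `agg` entry, which the traceback shows `_load` treats as a dictionary with a `"cls"` key, against the constructor's signature. The sketch below does that comparison with the standard-library `inspect` module; `missing_required_kwargs` and the stand-in class are hypothetical, and `stored` is an assumed example rather than real checkpoint contents.

```python
import inspect
from typing import Any


def missing_required_kwargs(cls: type, stored: dict[str, Any]) -> list[str]:
    """List required keyword-only parameters of cls.__init__ that stored lacks."""
    params = inspect.signature(cls).parameters
    return [
        name
        for name, param in params.items()
        if param.kind is inspect.Parameter.KEYWORD_ONLY
        and param.default is inspect.Parameter.empty
        and name not in stored
    ]


class StandInAttentiveAggregation:
    """Hypothetical stand-in mirroring the keyword-only output_size from the error message."""

    def __init__(self, dim: int = 0, *, output_size: int, **kwargs: Any):
        self.output_size = output_size


# Assumed shape of ckpt["hyper_parameters"]["agg"]: class and dim saved, output_size not.
stored = {"cls": StandInAttentiveAggregation, "dim": 0}
print(missing_required_kwargs(StandInAttentiveAggregation, stored))  # ['output_size']
```

If the real `agg` entry in the checkpoint also reports `['output_size']`, that would support the missing-metadata explanation.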