
[v2 BUG]: Loading a model with AttentiveAggregation fails from a checkpoint #1120

Open
2 tasks
SwarnavaCB opened this issue Dec 5, 2024 · 1 comment
Labels
bug Something isn't working

Comments


SwarnavaCB commented Dec 5, 2024

Describe the bug
When a model using AttentiveAggregation is saved to a checkpoint (e.g. the best model weights from the epoch that performs best on the validation set), loading that .ckpt file with load_model or MPNN.load_from_checkpoint() throws an error saying that AttentiveAggregation's output_size was not provided.

Example(s)

# Imports assumed for this snippet (chemprop v2 / Lightning API):
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from chemprop import nn
from chemprop.models import MPNN
from chemprop.nn import AttentiveAggregation, BinaryClassificationFFN, BondMessagePassing

mp = BondMessagePassing(d_h=600, bias=True, depth=5, dropout=0.5, activation="relu", undirected=True)
agg = AttentiveAggregation(output_size=600, bias=True, activation="relu")
binary_class_ffn = BinaryClassificationFFN(n_layers=2, input_dim=1196, hidden_dim=600, dropout=0.2, activation="relu")
metrics_list = [nn.metrics.BinaryAccuracy(), nn.metrics.BinaryAUROC(),nn.metrics.BinaryAUPRC(), nn.metrics.BinaryF1Score(), nn.metrics.BinaryMCCMetric()]

model = MPNN(mp, agg, binary_class_ffn, batch_norm=True, metrics=metrics_list)

checkpointing = ModelCheckpoint(
    dirpath="chemprop_checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)

trainer = pl.Trainer(
    logger=False,
    enable_progress_bar=True,
    enable_checkpointing=True,
    devices=1,
    max_epochs=20,  # number of epochs to train for
    accelerator="gpu",
    callbacks=[checkpointing],
)

# Loading the model from the saved checkpoint file then fails:

MPNN.load_from_checkpoint("/raid/home/debarka/KG_Works/Recreate_Molsorter/chemprop_checkpoints/best-epoch=17-val_loss=0.45.ckpt")  # errors out!

Expected behavior
It should just load the model.

Error Stack Trace

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[58], line 1
----> 1 MPNN.load_from_checkpoint("/raid/home/debarka/KG_Works/Recreate_Molsorter/chemprop_checkpoints/best-epoch=17-val_loss=0.45.ckpt")

File /raid/home/debarka/miniconda3/envs/chemprop/lib/python3.11/site-packages/chemprop/models/model.py:282, in MPNN.load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
    275 @classmethod
    276 def load_from_checkpoint(
    277     cls, checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs
    278 ) -> MPNN:
    279     submodules = {
    280         k: v for k, v in kwargs.items() if k in ["message_passing", "agg", "predictor"]
    281     }
--> 282     submodules, state_dict, hparams = cls._load(checkpoint_path, map_location, **submodules)
    283     kwargs.update(submodules)
    285     state_dict = cls._add_metric_task_weights_to_state_dict(state_dict, hparams)

File /raid/home/debarka/miniconda3/envs/chemprop/lib/python3.11/site-packages/chemprop/models/model.py:250, in MPNN._load(cls, path, map_location, **submodules)
    247 except KeyError:
    248     raise KeyError(f"Could not find hyper parameters and/or state dict in {path}.")
--> 250 submodules |= {
    251     key: hparams[key].pop("cls")(**hparams[key])
    252     for key in ("message_passing", "agg", "predictor")
    253     if key not in submodules
    254 }
    256 if not hasattr(submodules["predictor"].criterion, "_defaults"):
    257     submodules["predictor"].criterion = submodules["predictor"].criterion.__class__(
    258         task_weights=submodules["predictor"].criterion.task_weights
    259     )

File /raid/home/debarka/miniconda3/envs/chemprop/lib/python3.11/site-packages/chemprop/models/model.py:251, in <dictcomp>(.0)
    247 except KeyError:
    248     raise KeyError(f"Could not find hyper parameters and/or state dict in {path}.")
    250 submodules |= {
--> 251     key: hparams[key].pop("cls")(**hparams[key])
    252     for key in ("message_passing", "agg", "predictor")
    253     if key not in submodules
    254 }
    256 if not hasattr(submodules["predictor"].criterion, "_defaults"):
    257     submodules["predictor"].criterion = submodules["predictor"].criterion.__class__(
    258         task_weights=submodules["predictor"].criterion.task_weights
    259     )

TypeError: AttentiveAggregation.__init__() missing 1 required keyword-only argument: 'output_size'
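The trace shows that `_load` rebuilds each submodule via `hparams[key].pop("cls")(**hparams[key])`, skipping only the keys already supplied as keyword arguments to `load_from_checkpoint`. A minimal self-contained sketch of that mechanism (using a hypothetical `FakeAgg` class and `load_submodules` helper, not chemprop code) shows why a missing `output_size` in the saved hparams raises exactly this TypeError, and why passing the submodule explicitly, e.g. `MPNN.load_from_checkpoint(path, agg=agg)`, should work around it:

```python
# Minimal sketch of chemprop's MPNN._load submodule reconstruction.
# FakeAgg and load_submodules are hypothetical stand-ins, not chemprop code.

class FakeAgg:
    def __init__(self, *, output_size, bias=True):
        self.output_size = output_size
        self.bias = bias

def load_submodules(hparams, **submodules):
    # Mirrors the dict comprehension in model.py: rebuild every submodule
    # that was NOT passed in explicitly as a keyword argument.
    submodules |= {
        key: hparams[key].pop("cls")(**hparams[key])
        for key in ("agg",)
        if key not in submodules
    }
    return submodules

# hparams as apparently stored in the checkpoint: 'output_size' is missing.
bad_hparams = {"agg": {"cls": FakeAgg, "bias": True}}

try:
    load_submodules({"agg": dict(bad_hparams["agg"])})
except TypeError as e:
    print(e)  # → FakeAgg.__init__() missing 1 required keyword-only argument: 'output_size'

# Workaround: supply the submodule yourself, so reconstruction is skipped.
agg = FakeAgg(output_size=600)
loaded = load_submodules({"agg": dict(bad_hparams["agg"])}, agg=agg)
assert loaded["agg"] is agg
```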

Environment

  • python version = 3.11.10
  • package versions:

aimsim_core 2.2.2
aiohappyeyeballs 2.4.4
aiohttp 3.11.9
aiosignal 1.3.1
astartes 1.3.0
asttokens 3.0.0
attrs 24.2.0
chemprop 2.1.0
comm 0.2.2
ConfigArgParse 1.7
debugpy 1.8.9
decorator 5.1.1
descriptastorus 2.8.0
dill 0.3.9
exceptiongroup 1.2.2
executing 2.1.0
filelock 3.13.1
frozenlist 1.5.0
fsspec 2024.2.0
h5py 3.12.1
idna 3.10
importlib_metadata 8.5.0
ipykernel 6.29.5
ipython 8.30.0
jedi 0.19.2
Jinja2 3.1.3
joblib 1.4.2
jupyter_client 8.6.3
jupyter_core 5.7.2
lightning 2.4.0
lightning-utilities 0.11.9
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib-inline 0.1.7
mdurl 0.1.2
mhfp 1.9.6
mordredcommunity 2.0.6
mpmath 1.3.0
multidict 6.1.0
multiprocess 0.70.17
nest_asyncio 1.6.0
networkx 3.2.1
numpy 1.26.3
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.1.105
nvidia-nvtx-cu12 12.1.105
packaging 24.2
padelpy 0.1.16
pandas 2.2.3
pandas-flavor 0.6.0
parso 0.8.4
pexpect 4.9.0
pickleshare 0.7.5
pillow 10.2.0
pip 24.2
platformdirs 4.3.6
prompt_toolkit 3.0.48
propcache 0.2.1
psutil 6.1.0
ptyprocess 0.7.0
pure_eval 0.2.3
Pygments 2.18.0
python-dateutil 2.9.0.post0
pytorch-lightning 2.4.0
pytz 2024.2
PyYAML 6.0.2
pyzmq 26.2.0
rdkit 2024.3.6
rich 13.9.4
scikit-learn 1.5.2
scipy 1.14.1
setuptools 75.1.0
six 1.16.0
stack-data 0.6.2
sympy 1.13.1
tabulate 0.9.0
threadpoolctl 3.5.0
torch 2.5.0+cu121
torchaudio 2.5.0+cu121
torchmetrics 1.6.0
torchvision 0.20.0+cu121
tornado 6.4.2
tqdm 4.67.1
traitlets 5.14.3
triton 3.1.0
typing_extensions 4.9.0
tzdata 2024.2
wcwidth 0.2.13
wheel 0.44.0
xarray 2024.11.0
yarl 1.18.3
zipp 3.21.0

  • Linux Ubuntu 22.04 server

Checklist

  • all dependencies are satisfied: conda list or pip list shows the packages listed in the pyproject.toml
  • the unit tests are working: pytest -v reports no errors

Additional context
I think the checkpoint is not properly saving the output_size hyperparameter of AttentiveAggregation.

@SwarnavaCB SwarnavaCB added the bug Something isn't working label Dec 5, 2024
@JacksonBurns
Member

@KnathanM have you seen this?
