diff --git a/chemprop/nn/agg.py b/chemprop/nn/agg.py index 150851928..f027a795e 100644 --- a/chemprop/nn/agg.py +++ b/chemprop/nn/agg.py @@ -117,6 +117,7 @@ class AttentiveAggregation(Aggregation): def __init__(self, dim: int = 0, *args, output_size: int, **kwargs): super().__init__(dim, *args, **kwargs) + self.hparams["output_size"] = output_size self.W = nn.Linear(output_size, 1) def forward(self, H: Tensor, batch: Tensor) -> Tensor: