Skip to content

Commit

Permalink
🏄 🤙 Fix gradient error in CompGCN buffering (pykeen#573)
Browse files Browse the repository at this point in the history
* Invalidate buffered representations when returning from evaluation mode
* Add model test case for pipeline with early stopping
  • Loading branch information
mberr authored Aug 19, 2021
1 parent 8405d82 commit 7a87578
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/pykeen/nn/emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,14 @@ def post_parameter_update(self) -> None: # noqa: D102
# invalidate enriched embeddings
self.enriched_representations = None

def train(self, mode: bool = True): # noqa: D102
# when changing from evaluation to training mode, the buffered representations have been computed without
# gradient tracking. hence, we need to invalidate them.
# note: this occurs in practice when continuing training after evaluation.
if mode and not self.training:
self.enriched_representations = None
return super().train(mode=mode)

def forward(
self,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
Expand Down
21 changes: 21 additions & 0 deletions tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from pykeen.nn.emb import RepresentationModule
from pykeen.nn.modules import FunctionalInteraction, Interaction, LiteralInteraction
from pykeen.optimizers import optimizer_resolver
from pykeen.pipeline import pipeline
from pykeen.regularizers import LpRegularizer, Regularizer
from pykeen.trackers import ResultTracker
from pykeen.training import LCWATrainingLoop, SLCWATrainingLoop, TrainingLoop
Expand Down Expand Up @@ -1057,6 +1058,26 @@ def test_cli_training_nations(self):
"""Test running the pipeline on almost all models with only training data."""
self._help_test_cli(['-t', NATIONS_TRAIN_PATH] + self._cli_extras)

@pytest.mark.slow
def test_pipeline_nations_early_stopper(self):
"""Test running the pipeline with early stopping."""
model_kwargs = dict(self.instance_kwargs)
# triples factory is added by the pipeline
model_kwargs.pop("triples_factory")
pipeline(
model=self.cls,
model_kwargs=model_kwargs,
dataset="nations",
dataset_kwargs=dict(create_inverse_triples=self.create_inverse_triples),
stopper="early",
training_loop_kwargs=self.training_loop_kwargs,
stopper_kwargs=dict(frequency=1),
training_kwargs=dict(
batch_size=self.train_batch_size,
num_epochs=self.train_num_epochs,
),
)

@pytest.mark.slow
def test_cli_training_kinships(self):
"""Test running the pipeline on almost all models with only training data."""
Expand Down

0 comments on commit 7a87578

Please sign in to comment.