Skip to content

Commit

Permalink
🔮 💃 Reimplement SE with new-style model (pykeen#521)
Browse files Browse the repository at this point in the history
* Reimplement SE with new-style model

Trigger CI

* Update README.md

Trigger CI

* Update test_models.py

Trigger CI

* Reuse pykeen function

Trigger CI

* Update structured_embedding.py

Trigger CI
  • Loading branch information
cthoyt authored Jul 3, 2021
1 parent f14c6b1 commit 85519d8
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 130 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ have a suggestion for another dataset to include in PyKEEN, please let us know
| R-GCN | [`pykeen.models.RGCN`](https://pykeen.readthedocs.io/en/latest/api/pykeen.models.RGCN.html) | [Schlichtkrull *et al.*, 2018](https://arxiv.org/pdf/1703.06103) |
| RotatE | [`pykeen.models.RotatE`](https://pykeen.readthedocs.io/en/latest/api/pykeen.models.RotatE.html) | [Sun *et al.*, 2019](https://arxiv.org/abs/1902.10197v1) |
| SimplE | [`pykeen.models.SimplE`](https://pykeen.readthedocs.io/en/latest/api/pykeen.models.SimplE.html) | [Kazemi *et al.*, 2018](https://papers.nips.cc/paper/7682-simple-embedding-for-link-prediction-in-knowledge-graphs) |
| Structured Embedding | [`pykeen.models.StructuredEmbedding`](https://pykeen.readthedocs.io/en/latest/api/pykeen.models.StructuredEmbedding.html) | [Bordes *et al.*, 2011](http://www.aaai.org/ocs/index.php/AAAI/AAAI11/paper/download/3659/3898) |
| Structured Embedding | [`pykeen.models.StructuredEmbedding`](https://pykeen.readthedocs.io/en/latest/api/pykeen.models.StructuredEmbedding.html) | [Bordes *et al.*, 2011](https://www.aaai.org/ocs/index.php/AAAI/AAAI11/paper/download/3659/3898) |
| TorusE | [`pykeen.models.TorusE`](https://pykeen.readthedocs.io/en/latest/api/pykeen.models.TorusE.html) | [Ebisu *et al.*, 2018](https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/16227) |
| TransD | [`pykeen.models.TransD`](https://pykeen.readthedocs.io/en/latest/api/pykeen.models.TransD.html) | [Ji *et al.*, 2015](http://www.aclweb.org/anthology/P15-1067) |
| TransE | [`pykeen.models.TransE`](https://pykeen.readthedocs.io/en/latest/api/pykeen.models.TransE.html) | [Bordes *et al.*, 2013](http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf) |
Expand Down
136 changes: 25 additions & 111 deletions src/pykeen/models/unimodal/structured_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,24 @@

"""Implementation of structured model (SE)."""

import functools
from typing import Any, ClassVar, Mapping, Optional

import numpy as np
import torch
from torch import nn
from class_resolver import Hint
from torch.nn import functional

from ..base import EntityEmbeddingModel
from ..nbase import ERModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...nn.emb import Embedding, EmbeddingSpecification
from ...nn.init import xavier_uniform_
from ...typing import Constrainer, Hint, Initializer
from ...utils import compose
from ...nn import EmbeddingSpecification
from ...nn.init import xavier_uniform_, xavier_uniform_norm_
from ...nn.modules import StructuredEmbeddingInteraction
from ...typing import Constrainer, Initializer

__all__ = [
'StructuredEmbedding',
]


class StructuredEmbedding(EntityEmbeddingModel):
class StructuredEmbedding(ERModel):
r"""An implementation of the Structured Embedding (SE) published by [bordes2011]_.
SE applies role- and relation-specific projection matrices
Expand All @@ -41,7 +38,7 @@ class StructuredEmbedding(EntityEmbeddingModel):
citation:
author: Bordes
year: 2011
link: http://www.aaai.org/ocs/index.php/AAAI/AAAI11/paper/download/3659/3898
link: https://www.aaai.org/ocs/index.php/AAAI/AAAI11/paper/download/3659/3898
"""

#: The default strategy for optimizing the model's hyper-parameters
Expand All @@ -58,6 +55,7 @@ def __init__(
entity_initializer: Hint[Initializer] = xavier_uniform_,
entity_constrainer: Hint[Constrainer] = functional.normalize,
entity_constrainer_kwargs: Optional[Mapping[str, Any]] = None,
relation_initializer: Hint[Initializer] = xavier_uniform_norm_,
**kwargs,
) -> None:
r"""Initialize SE.
Expand All @@ -67,115 +65,31 @@ def __init__(
:param entity_initializer: Entity initializer function. Defaults to :func:`pykeen.nn.init.xavier_uniform_`
:param entity_constrainer: Entity constrainer function. Defaults to :func:`torch.nn.functional.normalize`
:param entity_constrainer_kwargs: Keyword arguments to be used when calling the entity constrainer
:param relation_initializer: Relation initializer function. Defaults to
:func:`pykeen.nn.init.xavier_uniform_norm_`
:param kwargs:
Remaining keyword arguments to forward to :class:`pykeen.models.ERModel`
"""
super().__init__(
interaction=StructuredEmbeddingInteraction(
p=scoring_fct_norm,
power_norm=False,
),
entity_representations=EmbeddingSpecification(
embedding_dim=embedding_dim,
initializer=entity_initializer,
constrainer=entity_constrainer,
constrainer_kwargs=entity_constrainer_kwargs,
),
relation_representations=[
EmbeddingSpecification(
shape=(embedding_dim, embedding_dim),
initializer=relation_initializer,
),
EmbeddingSpecification(
shape=(embedding_dim, embedding_dim),
initializer=relation_initializer,
),
],
**kwargs,
)

self.scoring_fct_norm = scoring_fct_norm

# Embeddings
init_bound = 6 / np.sqrt(self.embedding_dim)
# Initialise relation embeddings to unit length
initializer = compose(
functools.partial(nn.init.uniform_, a=-init_bound, b=+init_bound),
functional.normalize,
)

self.left_relation_embeddings = Embedding.init_with_device(
num_embeddings=self.num_relations,
embedding_dim=embedding_dim ** 2,
device=self.device,
initializer=initializer,
)
self.right_relation_embeddings = Embedding.init_with_device(
num_embeddings=self.num_relations,
embedding_dim=embedding_dim ** 2,
device=self.device,
initializer=initializer,
)

def _reset_parameters_(self): # noqa: D102
super()._reset_parameters_()
self.left_relation_embeddings.reset_parameters()
self.right_relation_embeddings.reset_parameters()

def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102
# Get embeddings
h = self.entity_embeddings(indices=hrt_batch[:, 0]).view(-1, self.embedding_dim, 1)
rel_h = self.left_relation_embeddings(indices=hrt_batch[:, 1]).view(-1, self.embedding_dim, self.embedding_dim)
rel_t = self.right_relation_embeddings(indices=hrt_batch[:, 1]).view(-1, self.embedding_dim, self.embedding_dim)
t = self.entity_embeddings(indices=hrt_batch[:, 2]).view(-1, self.embedding_dim, 1)

# Project entities
proj_h = rel_h @ h
proj_t = rel_t @ t

scores = -torch.norm(proj_h - proj_t, dim=1, p=self.scoring_fct_norm)
return scores

def score_t(self, hr_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor: # noqa: D102
# Get embeddings
h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, self.embedding_dim, 1)
rel_h = self.left_relation_embeddings(indices=hr_batch[:, 1]).view(-1, self.embedding_dim, self.embedding_dim)
rel_t = self.right_relation_embeddings(indices=hr_batch[:, 1])
rel_t = rel_t.view(-1, 1, self.embedding_dim, self.embedding_dim)
t_all = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim, 1)

if slice_size is not None:
proj_t_arr = []
# Project entities
proj_h = rel_h @ h

for t in torch.split(t_all, slice_size, dim=1):
# Project entities
proj_t = rel_t @ t
proj_t_arr.append(proj_t)

proj_t = torch.cat(proj_t_arr, dim=1)

else:
# Project entities
proj_h = rel_h @ h
proj_t = rel_t @ t_all

scores = -torch.norm(proj_h[:, None, :, 0] - proj_t[:, :, :, 0], dim=-1, p=self.scoring_fct_norm)

return scores

def score_h(self, rt_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor: # noqa: D102
# Get embeddings
h_all = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim, 1)
rel_h = self.left_relation_embeddings(indices=rt_batch[:, 0])
rel_h = rel_h.view(-1, 1, self.embedding_dim, self.embedding_dim)
rel_t = self.right_relation_embeddings(indices=rt_batch[:, 0]).view(-1, self.embedding_dim, self.embedding_dim)
t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, self.embedding_dim, 1)

if slice_size is not None:
proj_h_arr = []

# Project entities
proj_t = rel_t @ t

for h in torch.split(h_all, slice_size, dim=1):
# Project entities
proj_h = rel_h @ h
proj_h_arr.append(proj_h)

proj_h = torch.cat(proj_h_arr, dim=1)
else:
# Project entities
proj_h = rel_h @ h_all
proj_t = rel_t @ t

scores = -torch.norm(proj_h[:, :, :, 0] - proj_t[:, None, :, 0], dim=-1, p=self.scoring_fct_norm)

return scores
20 changes: 2 additions & 18 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ class TestSimplE(cases.ModelTestCase):
cls = pykeen.models.SimplE


class _BaseTestSE(cases.ModelTestCase):
class TestSE(cases.ModelTestCase):
"""Test the Structured Embedding model."""

cls = pykeen.models.StructuredEmbedding
Expand All @@ -337,26 +337,10 @@ def _check_constraints(self):
Entity embeddings have to have unit L2 norm.
"""
norms = self.instance.entity_embeddings(indices=None).norm(p=2, dim=-1)
norms = self.instance.entity_representations[0](indices=None).norm(p=2, dim=-1)
assert torch.allclose(norms, torch.ones_like(norms))


class TestSELowMemory(_BaseTestSE):
"""Tests SE with low memory."""

training_loop_kwargs = {
'automatic_memory_optimization': True,
}


class TestSEHighMemory(_BaseTestSE):
"""Tests SE with high memory."""

training_loop_kwargs = {
'automatic_memory_optimization': False,
}


class TestTorusE(cases.DistanceModelTestCase):
"""Test the TorusE model."""

Expand Down

0 comments on commit 85519d8

Please sign in to comment.