2️⃣↕️ Add DoubleMarginLoss (pykeen#539)
Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com>
Co-authored-by: PyKEEN_bot <pykeen2019@gmail.com>
3 people authored Jul 20, 2021
1 parent 6ba0a38 commit 44b3337
Showing 4 changed files with 266 additions and 14 deletions.
23 changes: 12 additions & 11 deletions README.md
@@ -172,17 +172,18 @@ have a suggestion for another dataset to include in PyKEEN, please let us know
| TuckER | [`pykeen.models.TuckER`](https://pykeen.readthedocs.io/en/latest/api/pykeen.models.TuckER.html) | [Balažević *et al.*, 2019](https://arxiv.org/abs/1901.09590) |
| Unstructured Model | [`pykeen.models.UnstructuredModel`](https://pykeen.readthedocs.io/en/latest/api/pykeen.models.UnstructuredModel.html) | [Bordes *et al.*, 2014](https://link.springer.com/content/pdf/10.1007%2Fs10994-013-5363-6.pdf) |

### Losses (7)

| Name | Reference | Description |
|--------------------------------------|---------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------|
| Binary cross entropy (after sigmoid) | [`pykeen.losses.BCEAfterSigmoidLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.BCEAfterSigmoidLoss.html) | A module for the numerically unstable version of explicit Sigmoid + BCE loss. |
| Binary cross entropy (with logits) | [`pykeen.losses.BCEWithLogitsLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.BCEWithLogitsLoss.html) | A module for the binary cross entropy loss. |
| Cross entropy | [`pykeen.losses.CrossEntropyLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.CrossEntropyLoss.html) | A module for the cross entropy loss that evaluates the cross entropy after softmax output. |
| Margin ranking | [`pykeen.losses.MarginRankingLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.MarginRankingLoss.html) | A module for the margin ranking loss. |
| Mean square error | [`pykeen.losses.MSELoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.MSELoss.html) | A module for the mean square error loss. |
| Self-adversarial negative sampling | [`pykeen.losses.NSSALoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.NSSALoss.html) | An implementation of the self-adversarial negative sampling loss function proposed by [sun2019]_. |
| Softplus | [`pykeen.losses.SoftplusLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.SoftplusLoss.html) | A module for the softplus loss. |
### Losses (8)

| Name | Reference | Description |
|--------------------------------------|---------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------|
| Binary cross entropy (after sigmoid) | [`pykeen.losses.BCEAfterSigmoidLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.BCEAfterSigmoidLoss.html) | A module for the numerically unstable version of explicit Sigmoid + BCE loss. |
| Binary cross entropy (with logits) | [`pykeen.losses.BCEWithLogitsLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.BCEWithLogitsLoss.html) | A module for the binary cross entropy loss. |
| Cross entropy | [`pykeen.losses.CrossEntropyLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.CrossEntropyLoss.html) | A module for the cross entropy loss that evaluates the cross entropy after softmax output. |
| Double Margin | [`pykeen.losses.DoubleMarginLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.DoubleMarginLoss.html) | A limit-based scoring loss, with separate margins for positive and negative elements from [sun2018]_. |
| Margin ranking | [`pykeen.losses.MarginRankingLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.MarginRankingLoss.html) | A module for the margin ranking loss. |
| Mean square error | [`pykeen.losses.MSELoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.MSELoss.html) | A module for the mean square error loss. |
| Self-adversarial negative sampling | [`pykeen.losses.NSSALoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.NSSALoss.html) | An implementation of the self-adversarial negative sampling loss function proposed by [sun2019]_. |
| Softplus | [`pykeen.losses.SoftplusLoss`](https://pykeen.readthedocs.io/en/latest/api/pykeen.losses.SoftplusLoss.html) | A module for the softplus loss. |

### Regularizers (5)

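
As a quick illustration of the new entry, one might select it by name in the pipeline (this snippet is ours, not part of the commit, and assumes the loss resolver normalizes `DoubleMarginLoss` to `"doublemargin"` as it does for the other losses):

```python
from pykeen.pipeline import pipeline

result = pipeline(
    dataset="Nations",
    model="TransE",
    loss="doublemargin",  # resolved to pykeen.losses.DoubleMarginLoss
    loss_kwargs=dict(positive_margin=1.0, negative_margin=0.0),
)
```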
4 changes: 4 additions & 0 deletions docs/source/references.rst
@@ -52,3 +52,7 @@ References
.. [galkin2020] Galkin, M., *et al.* (2020). `Message Passing for Hyper-Relational Knowledge Graphs
<https://doi.org/10.18653/v1/2020.emnlp-main.596>`_. Proceedings of the 2020 Conference on Empirical
Methods in Natural Language Processing (EMNLP), 7346–7359.
.. [sun2018] Sun, Z., *et al.* (2018). `Bootstrapping Entity Alignment with Knowledge Graph Embedding
   <https://dl.acm.org/doi/10.5555/3304222.3304381>`_. *Proceedings of the 27th International Joint
   Conference on Artificial Intelligence*, 4396–4402.
242 changes: 241 additions & 1 deletion src/pykeen/losses.py
@@ -101,6 +101,21 @@
    Pairwise Logistic               $h(\Delta) = \log(1 + \exp(\Delta))$
    =============================== ==============================================
Atypical Pairwise Loss Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The following pairwise loss functions use the full generalized form of $L(k, \bar{k}) = \dots$
for their definitions:

.. table::
    :align: center
    :widths: auto

    ==============  =============================================
    Pairwise Loss   Formulation
    ==============  =============================================
    Double Margin   $h(\bar{\lambda} + \bar{k}) + h(\lambda - k)$
    ==============  =============================================
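
For illustration, the double margin formulation can be evaluated directly (a minimal
sketch of ours, not part of this commit; the margins and scores below are arbitrary):

.. code-block:: python

    import torch

    h = torch.nn.functional.relu  # the margin activation
    pos_score, neg_score = torch.tensor(0.8), torch.tensor(-0.3)
    # h(negative_margin + negative_score) + h(positive_margin - positive_score)
    loss = h(0.0 + neg_score) + h(1.0 - pos_score)  # = 0.0 + 0.2
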
Batching
~~~~~~~~
The pairwise loss for a set of pairs of positive/negative triples $\mathcal{L}_L: 2^{\mathcal{K} \times
@@ -133,7 +148,9 @@
    \mathcal{L}_L(\mathcal{B}) = \frac{1}{|\mathcal{B}|} \sum \limits_{\mathcal{b} \in \mathcal{B}} L(\mathcal{b})
"""

from typing import Any, ClassVar, Mapping, Optional, Set
import logging
from textwrap import dedent
from typing import Any, ClassVar, Mapping, Optional, Set, Tuple

import torch
from class_resolver import Hint, Resolver
@@ -156,10 +173,13 @@
    'MSELoss',
    'NSSALoss',
    'SoftplusLoss',
    'DoubleMarginLoss',
    # Utils
    'loss_resolver',
]

logger = logging.getLogger(__name__)


def apply_label_smoothing(
    labels: torch.FloatTensor,
@@ -510,6 +530,226 @@ def forward(
))


@parse_docdata
class DoubleMarginLoss(PointwiseLoss):
    r"""A limit-based scoring loss, with separate margins for positive and negative elements from [sun2018]_.

    Despite its similarity to the margin-based loss, this loss is quite different from it, since it uses absolute
    margins for the positive/negative scores rather than comparing their difference. Hence, it has a natural decision
    boundary (somewhere between the positive and negative margin), while still resulting in sparse losses with no
    gradients for sufficiently correct examples.

    .. math ::
        L(k, \bar{k}) = h(\bar{\lambda} + \bar{k}) + h(\lambda - k)

    where $k$ are the positive scores, $\bar{k}$ the negative scores, $\lambda$ the positive margin,
    $\bar{\lambda}$ the negative margin, and $h$ an activation function, such as the ReLU or softplus.

    ---
    name: Double Margin
    """

    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        positive_margin=dict(type=float, low=-1, high=1),
        offset=dict(type=float, low=0, high=1),
        positive_negative_balance=dict(type=float, low=1.0e-03, high=1.0 - 1.0e-03),
        margin_activation=dict(
            type='categorical',
            choices=margin_activation_resolver.options,
        ),
    )

    @staticmethod
    def resolve_margin(
        positive_margin: Optional[float],
        negative_margin: Optional[float],
        offset: Optional[float],
    ) -> Tuple[float, float]:
        """Resolve the pair of margins from the several supported ways of specifying them.

        The method supports three combinations:

        - positive_margin & negative_margin: returns the values as-is.
        - negative_margin & offset: sets positive_margin = negative_margin + offset.
        - positive_margin & offset: sets negative_margin = positive_margin - offset.

        .. note ::
            This method does not apply a precedence between the three combinations, but requires the remaining
            parameter to be None. This is done to fail fast on ambiguous input rather than delay a failure to a
            later point in time where it might be harder to find its cause.

        :param positive_margin:
            The (absolute) margin for the positive scores. Should be larger than the negative one.
        :param negative_margin:
            The (absolute) margin for the negative scores. Should be smaller than the positive one.
        :param offset:
            The offset between positive and negative margin. Must be non-negative.
        :returns:
            A pair of the positive and negative margin. Guaranteed to fulfil positive_margin >= negative_margin.
        :raises ValueError:
            In case of an invalid combination.
        """
        # 1. positive & negative margin
        if positive_margin is not None and negative_margin is not None and offset is None:
            if negative_margin > positive_margin:
                raise ValueError(
                    f"Positive margin ({positive_margin}) must not be smaller than the negative one "
                    f"({negative_margin}).",
                )
            return positive_margin, negative_margin

        # 2. negative margin & offset
        if negative_margin is not None and offset is not None and positive_margin is None:
            if offset < 0:
                raise ValueError(f"The offset must not be negative, but it is: {offset}")
            return negative_margin + offset, negative_margin

        # 3. positive margin & offset
        if positive_margin is not None and offset is not None and negative_margin is None:
            if offset < 0:
                raise ValueError(f"The offset must not be negative, but it is: {offset}")
            return positive_margin, positive_margin - offset

        raise ValueError(dedent(f"""\
            Invalid combination of margins and offset:
                positive_margin={positive_margin}
                negative_margin={negative_margin}
                offset={offset}
            Supported are:
            1. positive & negative margin
            2. negative margin & offset
            3. positive margin & offset
        """))

    def __init__(
        self,
        *,
        positive_margin: Optional[float] = 1.0,
        negative_margin: Optional[float] = 0.0,
        offset: Optional[float] = None,
        positive_negative_balance: float = 0.5,
        margin_activation: Hint[nn.Module] = 'relu',
        reduction: str = 'mean',
    ):
        r"""Initialize the double margin loss.

        .. note ::
            There are multiple ways to set the pair of margins. Full documentation is provided in
            :meth:`DoubleMarginLoss.resolve_margin`.

        :param positive_margin:
            The (absolute) margin for the positive scores. Should be larger than the negative one.
        :param negative_margin:
            The (absolute) margin for the negative scores. Should be smaller than the positive one.
        :param offset:
            The offset between positive and negative margin. Must be non-negative.
        :param positive_negative_balance:
            The balance between the positive and negative term. Must be in (0, 1).
        :param margin_activation:
            A margin activation. Defaults to ``'relu'``, i.e. $h(\Delta) = \max(0, \Delta + \lambda)$, which is the
            default "margin loss". Using ``'softplus'`` leads to a "soft-margin" formulation as discussed in
            https://arxiv.org/abs/1703.07737.
        :param reduction:
            The name of the reduction operation to aggregate the individual loss values from a batch to a scalar
            loss value. From {'mean', 'sum'}.
        :raises ValueError: If the positive/negative balance is not within the valid range.
        """
        super().__init__(reduction=reduction)
        if not (0 < positive_negative_balance < 1):
            raise ValueError(
                f"The positive-negative balance weight must be in (0, 1), but is {positive_negative_balance}",
            )
        self.positive_margin, self.negative_margin = self.resolve_margin(
            positive_margin=positive_margin,
            negative_margin=negative_margin,
            offset=offset,
        )
        self.negative_weight = 1.0 - positive_negative_balance
        self.positive_weight = positive_negative_balance
        self.margin_activation = margin_activation_resolver.make(margin_activation)
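
    # Note (ours, illustrative): since both margins have non-None defaults, the
    # offset-based variants require explicitly passing None for one of them, e.g.
    #   DoubleMarginLoss(positive_margin=None, negative_margin=0.0, offset=1.0)
    # resolves to positive_margin=1.0 and negative_margin=0.0.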

    def process_slcwa_scores(
        self,
        positive_scores: torch.FloatTensor,
        negative_scores: torch.FloatTensor,
        label_smoothing: Optional[float] = None,
        batch_filter: Optional[torch.BoolTensor] = None,
        num_entities: Optional[int] = None,
    ) -> torch.FloatTensor:  # noqa: D102
        # Sanity check
        if label_smoothing:
            raise UnsupportedLabelSmoothingError(self)

        # positive term
        if batch_filter is None:
            # implicitly repeat positive scores
            positive_loss = self.margin_activation(self.positive_margin - positive_scores)
            positive_loss = self._reduction_method(positive_loss)
            if self.reduction == "sum":
                positive_loss = positive_loss * negative_scores.shape[1]
            elif self.reduction != "mean":
                raise NotImplementedError(
                    f"There is no implementation for reduction={self.reduction}.",
                )
        else:
            num_neg_per_pos = batch_filter.shape[1]
            positive_scores = positive_scores.unsqueeze(dim=1).repeat(1, num_neg_per_pos, 1)[batch_filter]
            # shape: (nnz,)
            positive_loss = self._reduction_method(self.margin_activation(self.positive_margin - positive_scores))

        # negative term
        # negative_scores have already been filtered in the sampler!
        negative_loss = self._reduction_method(self.margin_activation(self.negative_margin + negative_scores))
        return self.positive_weight * positive_loss + self.negative_weight * negative_loss
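
    # Note (ours, illustrative): in the unfiltered case, each positive score is
    # implicitly compared against all of its negatives; for reduction='sum', the
    # positive term is therefore scaled by negative_scores.shape[1] to match the
    # number of summands in the negative term.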

    def process_lcwa_scores(
        self,
        predictions: torch.FloatTensor,
        labels: torch.FloatTensor,
        label_smoothing: Optional[float] = None,
        num_entities: Optional[int] = None,
    ) -> torch.FloatTensor:  # noqa: D102
        # apply label smoothing, if requested
        if label_smoothing:
            labels = apply_label_smoothing(
                labels=labels,
                epsilon=label_smoothing,
                num_classes=num_entities,
            )

        return self(predictions=predictions, labels=labels)

    def forward(
        self,
        predictions: torch.FloatTensor,
        labels: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Compute the double margin loss.

        The predictions and labels have to be of broadcastable shape.

        :param predictions:
            The predicted scores.
        :param labels:
            The labels.
        :return:
            A scalar loss term.
        """
        return self.positive_weight * self._reduction_method(
            labels * self.margin_activation(self.positive_margin - predictions),
        ) + self.negative_weight * self._reduction_method(
            (1.0 - labels) * self.margin_activation(self.negative_margin + predictions),
        )


@parse_docdata
class SoftplusLoss(PointwiseLoss):
r"""
Expand Down
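
For reference, a minimal sketch of how the new loss behaves on toy inputs (the snippet and its values are ours, not part of this commit):

```python
import torch

from pykeen.losses import DoubleMarginLoss

# Defaults: positive_margin=1.0, negative_margin=0.0, balance=0.5, h=ReLU.
loss = DoubleMarginLoss()
predictions = torch.tensor([0.8, -0.3])
labels = torch.tensor([1.0, 0.0])
# positive term: mean([1 * relu(1.0 - 0.8), 0 * relu(1.0 + 0.3)]) = 0.1
# negative term: mean([0 * relu(0.0 + 0.8), 1 * relu(0.0 - 0.3)]) = 0.0
value = loss(predictions=predictions, labels=labels)  # 0.5 * 0.1 + 0.5 * 0.0 = 0.05
```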
11 changes: 9 additions & 2 deletions tests/test_losses.py
@@ -9,8 +9,9 @@
import unittest_templates

from pykeen.losses import (
BCEAfterSigmoidLoss, BCEWithLogitsLoss, CrossEntropyLoss, Loss, MSELoss, MarginRankingLoss, NSSALoss, PairwiseLoss,
PointwiseLoss, SetwiseLoss, SoftplusLoss, UnsupportedLabelSmoothingError, apply_label_smoothing,
BCEAfterSigmoidLoss, BCEWithLogitsLoss, CrossEntropyLoss, DoubleMarginLoss, Loss, MSELoss, MarginRankingLoss,
NSSALoss, PairwiseLoss, PointwiseLoss, SetwiseLoss, SoftplusLoss, UnsupportedLabelSmoothingError,
apply_label_smoothing,
)
from pykeen.pipeline import PipelineResult, pipeline
from tests import cases
@@ -28,6 +29,12 @@ class BCEAfterSigmoidLossTests(cases.PointwiseLossTestCase):
    cls = BCEAfterSigmoidLoss


class DoubleMarginLossTests(cases.PointwiseLossTestCase):
    """Unit test for DoubleMarginLoss."""

    cls = DoubleMarginLoss


class SoftplusLossTests(cases.PointwiseLossTestCase):
"""Unit test for SoftplusLoss."""

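These template-based tests can be run locally with, for example, `pytest tests/test_losses.py -k DoubleMargin`.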
