Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* Fixes pytorch#689
- handles (y_pred, y) or {'y_pred': y_pred, 'y': y, ...} as output argument for update function

* Update documentation
  • Loading branch information
vfdev-5 authored Jan 15, 2020
1 parent 8c8c3c2 commit 4f5b7fc
Show file tree
Hide file tree
Showing 33 changed files with 106 additions and 52 deletions.
2 changes: 1 addition & 1 deletion docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ value is then computed using the output of the engine's `process_function`:
metric = Accuracy()
metric.attach(engine, "accuracy")
If the engine's output is not in the format `y_pred, y`, the user can
If the engine's output is not in the format `(y_pred, y)` or `{'y_pred': y_pred, 'y': y, ...}`, the user can
use the `output_transform` argument to transform it:

.. code-block:: python
Expand Down
2 changes: 1 addition & 1 deletion ignite/contrib/metrics/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class AveragePrecision(EpochMetric):
def activated_output_transform(output):
y_pred, y = output
y_pred = torch.softmax(y_pred)
y_pred = torch.softmax(y_pred, dim=1)
return y_pred, y
avg_precision = AveragePrecision(activated_output_transform)
Expand Down
2 changes: 1 addition & 1 deletion ignite/contrib/metrics/regression/canberra_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class CanberraMetric(_BaseRegression):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
__ https://arxiv.org/abs/1809.03006
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class FractionalAbsoluteError(_BaseRegression):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
__ https://arxiv.org/abs/1809.03006
Expand Down
2 changes: 1 addition & 1 deletion ignite/contrib/metrics/regression/fractional_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class FractionalBias(_BaseRegression):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
__ https://arxiv.org/abs/1809.03006
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class GeometricMeanAbsoluteError(_BaseRegression):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
__ https://arxiv.org/abs/1809.03006
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class GeometricMeanRelativeAbsoluteError(_BaseRegression):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
Expand Down
2 changes: 1 addition & 1 deletion ignite/contrib/metrics/regression/manhattan_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class ManhattanDistance(_BaseRegression):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
__ https://arxiv.org/abs/1809.03006
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class MaximumAbsoluteError(_BaseRegression):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
__ https://arxiv.org/abs/1809.03006
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class MeanAbsoluteRelativeError(_BaseRegression):
More details can be found in the reference `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
__ https://arxiv.org/ftp/arxiv/papers/1809/1809.03006.pdf
Expand Down
2 changes: 1 addition & 1 deletion ignite/contrib/metrics/regression/mean_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class MeanError(_BaseRegression):
More details can be found in the reference `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
__ https://arxiv.org/abs/1809.03006
Expand Down
2 changes: 1 addition & 1 deletion ignite/contrib/metrics/regression/mean_normalized_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class MeanNormalizedBias(_BaseRegression):
More details can be found in the reference `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
__ https://arxiv.org/abs/1809.03006
Expand Down
2 changes: 1 addition & 1 deletion ignite/contrib/metrics/regression/median_absolute_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MedianAbsoluteError(_BaseRegressionEpoch):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`.
.. warning::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MedianAbsolutePercentageError(_BaseRegressionEpoch):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`.
.. warning::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MedianRelativeAbsoluteError(_BaseRegressionEpoch):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`.
.. warning::
Expand Down
2 changes: 1 addition & 1 deletion ignite/contrib/metrics/regression/r2_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class R2Score(_BaseRegression):
where :math:`A_j` is the ground truth, :math:`P_j` is the predicted value and
:math:`\bar{A}` is the mean of the ground truth.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`.
"""
def reset(self):
Expand Down
2 changes: 1 addition & 1 deletion ignite/contrib/metrics/regression/wave_hedges_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class WaveHedgesDistance(_BaseRegression):
More details can be found in `Botchkarev 2018`__.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
__ https://arxiv.org/abs/1809.03006
Expand Down
2 changes: 1 addition & 1 deletion ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import time
from collections import defaultdict, OrderedDict
from collections import Mapping
from collections.abc import Mapping
from enum import Enum
import weakref
import numbers
Expand Down
1 change: 1 addition & 0 deletions ignite/metrics/accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class VariableAccumulation(Metric):
initialized and available, device is set to `cuda`.
"""
_required_output_keys = None

def __init__(self, op, output_transform=lambda x: x, device=None):
if not callable(op):
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class Accuracy(_BaseClassification):
"""
Calculates the accuracy for binary, multiclass and multilabel data.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).
- `y` must be in the following shape (batch_size, ...).
- `y` and `y_pred` must be in the following shape of (batch_size, num_categories, ...) for multilabel cases.
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class ConfusionMatrix(Metric):
"""Calculates confusion matrix for multi-class data.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y_pred` must contain logits and has the following shape (batch_size, num_categories, ...)
- `y` should have the following shape (batch_size, ...) and contains ground-truth class indices
with or without the background class. During the computation, argmax of `y_pred` is taken to determine
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/epoch_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class EpochMetric(Metric):
Current implementation does not work with distributed computations. Results are not gather across all devices
and computed results are valid for a single device only.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
If target shape is `(batch_size, n_classes)` and `n_classes > 1` than it should be binary: e.g. `[[0, 1, 0, 1], ]`.
Expand Down
5 changes: 3 additions & 2 deletions ignite/metrics/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class Loss(Metric):
form expected by the metric.
This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
The output is is expected to be a tuple (prediction, target) or
The output is expected to be a tuple `(prediction, target)` or
(prediction, target, kwargs) where kwargs is a dictionary of extra
keywords arguments.
keywords arguments. If extra keywords arguments are provided they are passed to `loss_fn`.
batch_size (callable): a callable taking a target tensor that returns the
first dimension size (usually the batch size).
device (str of torch.device, optional): device specification in case of distributed computation usage.
Expand All @@ -29,6 +29,7 @@ class Loss(Metric):
initialized and available, device is set to `cuda`.
"""
_required_output_keys = None

def __init__(self, loss_fn, output_transform=lambda x: x,
batch_size=lambda x: len(x), device=None):
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/mean_absolute_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class MeanAbsoluteError(Metric):
"""
Calculates the mean absolute error.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
"""
@reinit__is_reduced
def reset(self):
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/mean_pairwise_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class MeanPairwiseDistance(Metric):
"""
Calculates the mean pairwise distance: average of pairwise distances computed on provided batches.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
"""
def __init__(self, p=2, eps=1e-6, output_transform=lambda x: x, device=None):
super(MeanPairwiseDistance, self).__init__(output_transform, device=device)
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class MeanSquaredError(Metric):
"""
Calculates the mean squared error.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
"""
@reinit__is_reduced
def reset(self):
Expand Down
12 changes: 12 additions & 0 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numbers
from abc import ABCMeta, abstractmethod
from functools import wraps
from collections.abc import Mapping
import warnings

import torch
Expand All @@ -18,12 +19,14 @@ class Metric(metaclass=ABCMeta):
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
device (str of torch.device, optional): device specification in case of distributed computation usage.
In most of the cases, it can be defined as "cuda:local_rank" or "cuda"
if already set `torch.cuda.set_device(local_rank)`. By default, if a distributed process group is
initialized and available, device is set to `cuda`.
"""
_required_output_keys = ("y_pred", "y")

def __init__(self, output_transform=lambda x: x, device=None):
self._output_transform = output_transform
Expand Down Expand Up @@ -110,6 +113,15 @@ def started(self, engine):
@torch.no_grad()
def iteration_completed(self, engine):
output = self._output_transform(engine.state.output)
if isinstance(output, Mapping):
if self._required_output_keys is None:
raise TypeError("Transformed engine output for {} metric should be a tuple/list, but given {}"
.format(self.__class__.__name__, type(output)))
if not all([k in output for k in self._required_output_keys]):
raise ValueError("When transformed engine's output is a mapping, "
"it should contain {} keys, but given {}".format(self._required_output_keys,
list(output.keys())))
output = tuple(output[k] for k in self._required_output_keys)
self.update(output)

def completed(self, engine, name):
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Precision(_BasePrecisionRecall):
"""
Calculates precision for binary and multiclass data.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).
- `y` must be in the following shape (batch_size, ...).
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Recall(_BasePrecisionRecall):
"""
Calculates recall for binary and multiclass data.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).
- `y` must be in the following shape (batch_size, ...).
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/root_mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class RootMeanSquaredError(MeanSquaredError):
"""
Calculates the root mean squared error.
- `update` must receive output of the form (y_pred, y).
- `update` must receive output of the form (y_pred, y) or `{'y_pred': y_pred, 'y': y}`.
"""
def compute(self):
mse = super(RootMeanSquaredError, self).compute()
Expand Down
1 change: 1 addition & 0 deletions ignite/metrics/running_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def log_running_avg_metrics(engine):
print("running avg loss:", engine.state.metrics['running_avg_loss'])
"""
_required_output_keys = None

def __init__(self, src=None, alpha=0.98, output_transform=None, epoch_bound=True, device=None):
if not (isinstance(src, Metric) or src is None):
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/top_k_categorical_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class TopKCategoricalAccuracy(Metric):
"""
Calculates the top-k categorical accuracy.
- `update` must receive output of the form `(y_pred, y)`.
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
"""
def __init__(self, k=5, output_transform=lambda x: x, device=None):
super(TopKCategoricalAccuracy, self).__init__(output_transform, device=device)
Expand Down
Loading

0 comments on commit 4f5b7fc

Please sign in to comment.