Add epoch_bound parameter to RunningAverage. (pytorch#488)
* Add `epoch_bound` parameter to `RunningAverage`.

This optional parameter controls whether the running average is reset after each epoch; a brief usage sketch follows below.

* Add an integration test.

* Fix lint.
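
A minimal usage sketch of the new flag (illustrative, not part of the commit: the toy `update_fn`, the data, and the handler name `log_avg` are placeholders; only the `RunningAverage(..., epoch_bound=...)` signature comes from this change):

```python
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage

# Toy process function standing in for a real training step: pretend the
# batch value itself is the loss.
def update_fn(engine, batch):
    return float(batch)

trainer = Engine(update_fn)

# With epoch_bound=False the average is NOT reset at EPOCH_STARTED, so it
# keeps decaying smoothly across epoch boundaries instead of restarting.
avg_loss = RunningAverage(output_transform=lambda loss: loss,
                          alpha=0.98, epoch_bound=False)
avg_loss.attach(trainer, "running_avg_loss")

@trainer.on(Events.EPOCH_COMPLETED)
def log_avg(engine):
    print(engine.state.epoch, engine.state.metrics["running_avg_loss"])

trainer.run(list(range(5)), max_epochs=2)
```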
Evpok authored and vfdev-5 committed Apr 12, 2019
1 parent f53f2ab commit ca3ba0c
Showing 2 changed files with 74 additions and 3 deletions.
10 changes: 7 additions & 3 deletions ignite/metrics/running_average.py
@@ -11,6 +11,8 @@ class RunningAverage(Metric):
         alpha (float, optional): running average decay factor, default 0.98
         output_transform (callable, optional): a function to use to transform the output if `src` is None and
             corresponds the output of process function. Otherwise it should be None.
+        epoch_bound (boolean, optional): whether the running average should be reset after each epoch (defaults
+            to True).
 
     Examples:
@@ -30,7 +32,7 @@ def log_running_avg_metrics(engine):
     """
 
-    def __init__(self, src=None, alpha=0.98, output_transform=None):
+    def __init__(self, src=None, alpha=0.98, output_transform=None, epoch_bound=True):
         if not (isinstance(src, Metric) or src is None):
             raise TypeError("Argument src should be a Metric or None.")
         if not (0.0 < alpha <= 1.0):
@@ -50,6 +52,7 @@ def __init__(self, src=None, alpha=0.98, output_transform=None):
             self.update = self._output_update
 
         self.alpha = alpha
+        self.epoch_bound = epoch_bound
         super(RunningAverage, self).__init__(output_transform=output_transform)
 
     def reset(self):
@@ -67,8 +70,9 @@ def compute(self):
         return self._value
 
     def attach(self, engine, name):
-        # restart average every epoch
-        engine.add_event_handler(Events.EPOCH_STARTED, self.started)
+        if self.epoch_bound:
+            # restart average every epoch
+            engine.add_event_handler(Events.EPOCH_STARTED, self.started)
         # compute metric
         engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
         # apply running average
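
For reference, the update rule behind both the metric and the manual checks in the test below is a standard exponential moving average (the formula is implied by the code rather than stated in the diff): with decay factor α (default 0.98), the first observation seeds the value and each later observation x_t folds in as

```latex
v_0 = x_0, \qquad v_t = \alpha\, v_{t-1} + (1 - \alpha)\, x_t \quad (t \ge 1)
```

With `epoch_bound=True` the stored value is discarded at every `EPOCH_STARTED`, so the recursion restarts from the first batch of each epoch; with `epoch_bound=False` it runs uninterrupted across epochs, which is exactly what the new test verifies.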
67 changes: 67 additions & 0 deletions tests/ignite/metrics/test_running_average.py
@@ -107,6 +107,73 @@ def assert_equal_running_avg_output_values(engine):
     trainer.run(data, max_epochs=1)
 
 
+def test_epoch_unbound():
+
+    n_iters = 10
+    n_epochs = 3
+    batch_size = 10
+    n_classes = 10
+    data = list(range(n_iters))
+    loss_values = iter(range(n_epochs * n_iters))
+    y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_epochs * n_iters, batch_size)))
+    y_pred_batch_values = iter(np.random.rand(n_epochs * n_iters, batch_size, n_classes))
+
+    def update_fn(engine, batch):
+        loss_value = next(loss_values)
+        y_true_batch = next(y_true_batch_values)
+        y_pred_batch = next(y_pred_batch_values)
+        return loss_value, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+
+    trainer = Engine(update_fn)
+    alpha = 0.98
+
+    acc_metric = RunningAverage(Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha, epoch_bound=False)
+    acc_metric.attach(trainer, 'running_avg_accuracy')
+
+    avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha, epoch_bound=False)
+    avg_output.attach(trainer, 'running_avg_output')
+
+    running_avg_acc = [None]
+
+    @trainer.on(Events.STARTED)
+    def running_avg_output_init(engine):
+        engine.state.running_avg_output = None
+
+    @trainer.on(Events.ITERATION_COMPLETED, running_avg_acc)
+    def manual_running_avg_acc(engine, running_avg_acc):
+        _, y_pred, y = engine.state.output
+        indices = torch.max(y_pred, 1)[1]
+        correct = torch.eq(indices, y).view(-1)
+        num_correct = torch.sum(correct).item()
+        num_examples = correct.shape[0]
+        batch_acc = num_correct * 1.0 / num_examples
+        if running_avg_acc[0] is None:
+            running_avg_acc[0] = batch_acc
+        else:
+            running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc
+        engine.state.running_avg_acc = running_avg_acc[0]
+
+    @trainer.on(Events.ITERATION_COMPLETED)
+    def running_avg_output_update(engine):
+        if engine.state.running_avg_output is None:
+            engine.state.running_avg_output = engine.state.output[0]
+        else:
+            engine.state.running_avg_output = engine.state.running_avg_output * alpha + \
+                (1.0 - alpha) * engine.state.output[0]
+
+    @trainer.on(Events.ITERATION_COMPLETED)
+    def assert_equal_running_avg_acc_values(engine):
+        assert engine.state.running_avg_acc == engine.state.metrics['running_avg_accuracy'], \
+            "{} vs {}".format(engine.state.running_avg_acc, engine.state.metrics['running_avg_accuracy'])
+
+    @trainer.on(Events.ITERATION_COMPLETED)
+    def assert_equal_running_avg_output_values(engine):
+        assert engine.state.running_avg_output == engine.state.metrics['running_avg_output'], \
+            "{} vs {}".format(engine.state.running_avg_output, engine.state.metrics['running_avg_output'])
+
+    trainer.run(data, max_epochs=3)
+
+
 def test_multiple_attach():
     n_iters = 100
     errD_values = iter(np.random.rand(n_iters, ))
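
A quick way to observe the difference the test pins down (a sketch under the same API, not part of the commit; the metric names "bound"/"unbound" and the handler `report` are illustrative):

```python
import numpy as np

from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage

values = iter(np.random.rand(20))  # 10 iterations x 2 epochs

def update_fn(engine, batch):
    return float(next(values))

trainer = Engine(update_fn)

# Same output stream, two averages: one reset each epoch, one not.
RunningAverage(output_transform=lambda x: x, alpha=0.98,
               epoch_bound=True).attach(trainer, "bound")
RunningAverage(output_transform=lambda x: x, alpha=0.98,
               epoch_bound=False).attach(trainer, "unbound")

@trainer.on(Events.EPOCH_COMPLETED)
def report(engine):
    # After epoch 1 the two agree; during epoch 2 "bound" has restarted
    # from scratch while "unbound" kept its history, so they diverge.
    print(engine.state.epoch, engine.state.metrics["bound"],
          engine.state.metrics["unbound"])

trainer.run(list(range(10)), max_epochs=2)
```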
