Remove on_update hook call from train_step (#399)
Summary:
Pull Request resolved: #399

We want people to write their own train_step functions, and having a hook call in
there makes that complicated. on_update was already happening at the end of
train_step, so move it out of train_step and into the step function.

Another subtle change here is that we now call on_step for both train and test
phases.

Reviewed By: aadcock

Differential Revision: D19906465

fbshipit-source-id: 9e4a094e1aa2f9b2c2ddd4a72554febe05b715e4
vreis authored and facebook-github-bot committed Feb 24, 2020
1 parent e2f2acd commit 4ff72b7
Showing 20 changed files with 40 additions and 36 deletions.
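
Before the file-by-file diff, a minimal sketch of what a user-defined hook looks like after this rename. The sketch is illustrative and not part of the commit: the class name PrintLatestLossHook, the printed message, and the import paths are assumptions, while ClassyHook._noop, the on_step signature, task.train, and task.losses are all taken from the changes below. Because on_step now fires after eval_step as well as train_step, train-only logic needs an explicit task.train guard, just like the built-in hooks updated here.

from typing import Any, Dict

from classy_vision.hooks import ClassyHook
from classy_vision.tasks import ClassyTask


class PrintLatestLossHook(ClassyHook):
    # Hook events this example does not need are no-ops; the exact abstract
    # set is defined by ClassyHook (see classy_hook.py below).
    on_start = ClassyHook._noop
    on_phase_start = ClassyHook._noop
    on_forward = ClassyHook._noop
    on_loss_and_meter = ClassyHook._noop
    on_backward = ClassyHook._noop
    on_phase_end = ClassyHook._noop
    on_end = ClassyHook._noop

    def on_step(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
        # on_step is now invoked for both train and test phases, so skip
        # anything that should only happen while training.
        if not task.train:
            return
        if task.losses:
            print(f"batch {len(task.losses)}: loss {task.losses[-1]:.4f}")

However such a hook is registered on the task, step() will invoke its on_step once per train or eval step after this change.
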
classy_vision/hooks/checkpoint_hook.py (1 addition & 1 deletion)
@@ -24,7 +24,7 @@ class CheckpointHook(ClassyHook):
     on_phase_start = ClassyHook._noop
     on_forward = ClassyHook._noop
     on_loss_and_meter = ClassyHook._noop
-    on_update = ClassyHook._noop
+    on_step = ClassyHook._noop
     on_end = ClassyHook._noop

     def __init__(

classy_vision/hooks/classy_hook.py (2 additions & 2 deletions)
@@ -30,7 +30,7 @@ class ClassyHook(ABC):
     are listed below in the chronological order.

         on_start -> on_phase_start -> on_forward -> on_loss_and_meter ->
-            on_update -> on_phase_end -> on_end
+            on_step -> on_phase_end -> on_end

     Deriving classes should call ``super().__init__()`` and store any state in
     ``self.state``. Any state added to this property should be serializable.
@@ -93,7 +93,7 @@ def on_loss_and_meter(
         pass

     @abstractmethod
-    def on_update(
+    def on_step(
         self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
     ) -> None:
         """Called each time after parameters have been updated by the optimizer."""

classy_vision/hooks/constants.py (1 addition & 1 deletion)
@@ -16,6 +16,6 @@ class ClassyHookFunctions(Enum):
     on_phase_start = auto()
     on_forward = auto()
     on_loss_and_meter = auto()
-    on_update = auto()
+    on_step = auto()
     on_phase_end = auto()
     on_end = auto()

classy_vision/hooks/exponential_moving_average_model_hook.py (4 additions & 1 deletion)
@@ -103,7 +103,10 @@ def on_phase_end(self, task: ClassyTask, local_variables: Dict[str, Any]) -> Non
         # state in the test phase
         self._save_current_model_state(task.base_model, self.state.model_state)

-    def on_update(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+    def on_step(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+        if not task.train:
+            return
+
         with torch.no_grad():
             for name, param in self.get_model_state_iterator(task.base_model):
                 self.state.ema_model_state[

classy_vision/hooks/loss_lr_meter_logging_hook.py (2 additions & 2 deletions)
@@ -62,13 +62,13 @@ def on_phase_end(
         if task.train:
             self._log_lr(task, local_variables)

-    def on_update(
+    def on_step(
         self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
     ) -> None:
         """
         Log the LR every log_freq batches, if log_freq is not None.
         """
-        if self.log_freq is None:
+        if self.log_freq is None or not task.train:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:

classy_vision/hooks/model_complexity_hook.py (1 addition & 1 deletion)
@@ -24,7 +24,7 @@ class ModelComplexityHook(ClassyHook):
     on_phase_start = ClassyHook._noop
     on_forward = ClassyHook._noop
     on_loss_and_meter = ClassyHook._noop
-    on_update = ClassyHook._noop
+    on_step = ClassyHook._noop
     on_phase_end = ClassyHook._noop
     on_end = ClassyHook._noop

classy_vision/hooks/model_tensorboard_hook.py (1 addition & 1 deletion)
@@ -30,7 +30,7 @@ class ModelTensorboardHook(ClassyHook):
     on_phase_start = ClassyHook._noop
     on_forward = ClassyHook._noop
     on_loss_and_meter = ClassyHook._noop
-    on_update = ClassyHook._noop
+    on_step = ClassyHook._noop
     on_phase_end = ClassyHook._noop
     on_end = ClassyHook._noop

classy_vision/hooks/profiler_hook.py (1 addition & 1 deletion)
@@ -21,7 +21,7 @@ class ProfilerHook(ClassyHook):
     on_phase_start = ClassyHook._noop
     on_forward = ClassyHook._noop
     on_loss_and_meter = ClassyHook._noop
-    on_update = ClassyHook._noop
+    on_step = ClassyHook._noop
     on_phase_end = ClassyHook._noop
     on_end = ClassyHook._noop

classy_vision/hooks/progress_bar_hook.py (2 additions & 2 deletions)
@@ -51,11 +51,11 @@ def on_phase_start(
             self.progress_bar = progressbar.ProgressBar(self.bar_size)
             self.progress_bar.start()

-    def on_update(
+    def on_step(
         self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
     ) -> None:
         """Update the progress bar with the batch size."""
-        if is_master() and self.progress_bar is not None:
+        if task.train and is_master() and self.progress_bar is not None:
             self.batches += 1
             self.progress_bar.update(min(self.batches, self.bar_size))

classy_vision/hooks/tensorboard_plot_hook.py (1 addition & 1 deletion)
@@ -64,7 +64,7 @@ def on_phase_start(
         self.wall_times = []
         self.num_steps_global = []

-    def on_update(
+    def on_step(
         self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
     ) -> None:
         """Store the observed learning rates."""

classy_vision/hooks/time_metrics_hook.py (1 addition & 1 deletion)
@@ -21,7 +21,7 @@ class TimeMetricsHook(ClassyHook):

     on_start = ClassyHook._noop
     on_forward = ClassyHook._noop
-    on_update = ClassyHook._noop
+    on_step = ClassyHook._noop
     on_end = ClassyHook._noop

     def __init__(self, log_freq: Optional[int] = None) -> None:

classy_vision/hooks/visdom_hook.py (1 addition & 1 deletion)
@@ -35,7 +35,7 @@ class VisdomHook(ClassyHook):
     on_phase_start = ClassyHook._noop
     on_forward = ClassyHook._noop
     on_loss_and_meter = ClassyHook._noop
-    on_update = ClassyHook._noop
+    on_step = ClassyHook._noop
     on_end = ClassyHook._noop

     def __init__(

classy_vision/tasks/classification_task.py (0 additions & 2 deletions)
@@ -746,8 +746,6 @@ def train_step(self, use_gpu, local_variables=None):
         self.optimizer.update_schedule_on_step(self.where)
         self.optimizer.step()

-        self.run_hooks(local_variables, ClassyHookFunctions.on_update.name)
-
         self.num_updates += self.get_global_batchsize()

     def compute_loss(self, model_output, sample):

classy_vision/tasks/classy_task.py (4 additions & 0 deletions)
@@ -171,11 +171,15 @@ def eval_step(self, use_gpu, local_variables: Optional[Dict] = None) -> None:
         pass

     def step(self, use_gpu, local_variables: Optional[Dict] = None) -> None:
+        from classy_vision.hooks import ClassyHookFunctions
+
         if self.train:
             self.train_step(use_gpu, local_variables)
         else:
             self.eval_step(use_gpu, local_variables)

+        self.run_hooks(local_variables, ClassyHookFunctions.on_step.name)
+
     def run_hooks(self, local_variables: Dict[str, Any], hook_function: str) -> None:
         """
         Helper function that runs a hook function for all the

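To connect this back to the motivation in the summary: with the hook call moved into step(), a user-supplied train_step no longer has to know about hooks at all. A rough sketch under stated assumptions (the MyTask subclass, its body, and the import path are hypothetical; the train_step signature and ClassificationTask come from the diffs above):

from classy_vision.tasks import ClassificationTask


class MyTask(ClassificationTask):
    def train_step(self, use_gpu, local_variables=None):
        # Reuse the standard forward / backward / optimizer update.
        super().train_step(use_gpu, local_variables)
        # Custom per-step logic can go here. No run_hooks(...) call is
        # needed: ClassyTask.step() now fires on_step for both train and
        # eval phases after this method returns.
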
test/hooks_classy_hook_test.py (1 addition & 1 deletion)
@@ -17,7 +17,7 @@ class TestHook(ClassyHook):
     on_forward = ClassyHook._noop
     on_loss_and_meter = ClassyHook._noop
     on_backward = ClassyHook._noop
-    on_update = ClassyHook._noop
+    on_step = ClassyHook._noop
     on_phase_end = ClassyHook._noop
     on_end = ClassyHook._noop

test/hooks_exponential_moving_average_model_hook_test.py (1 addition & 1 deletion)
@@ -53,7 +53,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):
         task.base_model.update_fc_weight()
         fc_weight = model.fc.weight.clone()
         for _ in range(num_updates):
-            exponential_moving_average_hook.on_update(task, local_variables)
+            exponential_moving_average_hook.on_step(task, local_variables)
         exponential_moving_average_hook.on_phase_end(task, local_variables)
         # the model weights shouldn't have changed
         self.assertTrue(torch.allclose(model.fc.weight, fc_weight))

test/hooks_loss_lr_meter_logging_hook_test.py (3 additions & 7 deletions)
@@ -37,17 +37,13 @@ def test_logging(self, mock_get_rank: mock.MagicMock) -> None:
         local_variables = {}
         task.phase_idx = 0

-        loss_vals = {"train": 1.425, "test": 0.57}
-
-        for log_freq, phase_type in product([5, None], loss_vals):
-            task.train = phase_type == "train"
-
+        for log_freq in [5, None]:
             # create a loss lr meter hook
             loss_lr_meter_hook = LossLrMeterLoggingHook(log_freq=log_freq)

             # check that _log_loss_meters() is called after on_loss_and_meter() every
             # log_freq batches and after on_phase_end()
-            # and _log_lr() is called after on_update() every log_freq batches
+            # and _log_lr() is called after on_step() every log_freq batches
             # and after on_phase_end()
             with mock.patch.object(loss_lr_meter_hook, "_log_loss_meters") as mock_fn:
                 with mock.patch.object(loss_lr_meter_hook, "_log_lr") as mock_lr_fn:
@@ -56,7 +52,7 @@ def test_logging(self, mock_get_rank: mock.MagicMock) -> None:
                     for i in range(num_batches):
                         task.losses = list(range(i))
                         loss_lr_meter_hook.on_loss_and_meter(task, local_variables)
-                        loss_lr_meter_hook.on_update(task, local_variables)
+                        loss_lr_meter_hook.on_step(task, local_variables)
                         if log_freq is not None and i and i % log_freq == 0:
                             mock_fn.assert_called_with(task, local_variables)
                             mock_fn.reset_mock()

test/manual/hooks_progress_bar_hook_test.py (6 additions & 6 deletions)
@@ -46,16 +46,16 @@ def test_progress_bar(
         mock_progress_bar.start.reset_mock()
         mock_progressbar_pkg.ProgressBar.reset_mock()

-        # on_update should update the progress bar correctly
+        # on_step should update the progress bar correctly
         for i in range(num_batches):
-            progress_bar_hook.on_update(task, local_variables)
+            progress_bar_hook.on_step(task, local_variables)
             mock_progress_bar.update.assert_called_once_with(i + 1)
             mock_progress_bar.update.reset_mock()

-        # check that even if on_update is called again, the progress bar is
+        # check that even if on_step is called again, the progress bar is
         # only updated with num_batches
         for _ in range(num_batches):
-            progress_bar_hook.on_update(task, local_variables)
+            progress_bar_hook.on_step(task, local_variables)
             mock_progress_bar.update.assert_called_once_with(num_batches)
             mock_progress_bar.update.reset_mock()

@@ -68,7 +68,7 @@ def test_progress_bar(
         # crash
         progress_bar_hook = ProgressBarHook()
         try:
-            progress_bar_hook.on_update(task, local_variables)
+            progress_bar_hook.on_step(task, local_variables)
             progress_bar_hook.on_phase_end(task, local_variables)
         except Exception as e:
             self.fail(
@@ -81,7 +81,7 @@ def test_progress_bar(
         progress_bar_hook = ProgressBarHook()
         try:
             progress_bar_hook.on_phase_start(task, local_variables)
-            progress_bar_hook.on_update(task, local_variables)
+            progress_bar_hook.on_step(task, local_variables)
             progress_bar_hook.on_phase_end(task, local_variables)
         except Exception as e:
             self.fail("Received Exception when is_master() is False: {}".format(e))

test/manual/hooks_tensorboard_plot_hook_test.py (3 additions & 3 deletions)
@@ -60,9 +60,9 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:

         # test that the hook logs a warning and doesn't write anything to
         # the writer if on_phase_start() is not called for initialization
-        # before on_update() is called.
+        # before on_step() is called.
         with self.assertLogs() as log_watcher:
-            tensorboard_plot_hook.on_update(task, local_variables)
+            tensorboard_plot_hook.on_step(task, local_variables)

         self.assertTrue(
             len(log_watcher.records) == 1
@@ -88,7 +88,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:

         for loss in losses:
             task.losses.append(loss)
-            tensorboard_plot_hook.on_update(task, local_variables)
+            tensorboard_plot_hook.on_step(task, local_variables)

         tensorboard_plot_hook.on_phase_end(task, local_variables)

test/optim_param_scheduler_test.py (4 additions & 1 deletion)
@@ -207,7 +207,10 @@ class TestHook(ClassyHook):
     on_phase_end = ClassyHook._noop
     on_end = ClassyHook._noop

-    def on_update(self, task: ClassyTask, local_variables) -> None:
+    def on_step(self, task: ClassyTask, local_variables) -> None:
+        if not task.train:
+            return
+
         # make sure we have non-zero param groups
         test_instance.assertGreater(
             len(task.optimizer.optimizer.param_groups), 0
