diff --git a/tensorflow_addons/optimizers/average_wrapper.py b/tensorflow_addons/optimizers/average_wrapper.py index ded496478f..d7c46d2d71 100644 --- a/tensorflow_addons/optimizers/average_wrapper.py +++ b/tensorflow_addons/optimizers/average_wrapper.py @@ -46,6 +46,7 @@ def __init__( raise TypeError("sequential_update must be of bool type") self._optimizer = optimizer + self._track_trackable(self._optimizer, "awg_optimizer") if sequential_update is not None: warnings.warn( diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 31fe13b043..e18ccaa8e0 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -80,6 +80,7 @@ def __init__( self._set_hyper("sync_period", sync_period) self._set_hyper("slow_step_size", slow_step_size) self._initialized = False + self._track_trackable(self._optimizer, "lh_base_optimizer") def _create_slots(self, var_list): self._optimizer._create_slots(