Skip to content

Commit

Permalink
Import DynamicScale directly.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 433342694
  • Loading branch information
jihwanlee-alphago authored and copybara-github committed Mar 9, 2022
1 parent e367204 commit 33fdafd
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion aqt/jax/imagenet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from flax import jax_utils
from flax import optim
from flax.metrics import tensorboard
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.training import checkpoints
from flax.training import common_utils
import jax
Expand Down Expand Up @@ -256,7 +257,7 @@ def main(argv):
else:
model_dtype = jnp.float16
input_dtype = tf.float16
dynamic_scale = optim.DynamicScale()
dynamic_scale = dynamic_scale_lib.DynamicScale()
else:
model_dtype = jnp.float32
input_dtype = tf.float32
Expand Down
3 changes: 2 additions & 1 deletion aqt/jax/imagenet/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import flax
from flax import jax_utils
from flax import optim
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.training import common_utils
import jax
from jax import lax
Expand Down Expand Up @@ -192,7 +193,7 @@ class TrainState:
step: int
optimizer: optim.Optimizer
model_state: Mapping[str, Any]
dynamic_scale: optim.DynamicScale
dynamic_scale: dynamic_scale_lib.DynamicScale


def sync_batch_stats(state):
Expand Down

0 comments on commit 33fdafd

Please sign in to comment.