Refactor ResNet-50 model code.
PiperOrigin-RevId: 414013885
ychzhang authored and copybara-github committed Dec 3, 2021
1 parent 9671c78 commit 58eea29
Showing 11 changed files with 261 additions and 229 deletions.
5 changes: 3 additions & 2 deletions aqt/jax/flax_layers.py
@@ -19,11 +19,11 @@
 """

 import contextlib
+import dataclasses
 import typing
 from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Type, Union

 from absl import flags
-import dataclasses
 import flax
 from flax import linen as nn
 import jax
@@ -324,7 +324,8 @@ def __call__(self, inputs):
                 half_shift=hparams.weight_half_shift,
                 axis=kernel_reduction_axis,
                 expected_scale_shape=expected_scale_shape),
-            quantized_type=quantized_type)
+            quantized_type=quantized_type,
+            quantize_weights=self.quant_context.quantize_weights)

     # Convolution
     dimension_numbers = flax.nn.linear._conv_dimension_numbers(inputs.shape)  # pylint: disable=protected-access
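The new quantize_weights argument threads a switch from the quantization context into the weight quantizer, so weight quantization can be turned off at runtime (for example, until the new weight_quant_start_step is reached). A minimal sketch of that kind of gating, assuming a symmetric fake-quantizer; this is not the repository's QuantOps implementation:

import jax.numpy as jnp

def maybe_quantize_weights(w, *, prec=8, quantize_weights=True):
  # Pass weights through untouched when quantization is switched off.
  if not quantize_weights:
    return w
  # Symmetric per-tensor fake quantization to `prec` bits
  # (straight-through gradient handling omitted for brevity).
  n_levels = 2 ** (prec - 1) - 1                       # e.g. 127 for 8 bits
  scale = n_levels / jnp.maximum(jnp.max(jnp.abs(w)), 1e-8)
  return jnp.round(w * scale) / scale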
9 changes: 9 additions & 0 deletions aqt/jax/imagenet/configs/base_config.py
@@ -117,13 +117,20 @@ def get_base_config(imagenet_type, quant_target):
          "cooldown_epochs": 50,
          "scheduler": "cosine",
          "num_epochs": 250,
+         "endlr": 0.0,
+         "knee_lr": 1e-5,
+         "knee_epochs": 125,
      },
      "optimizer": "sgd",
      "adam": {
          "beta1": 0.9,
          "beta2": 0.999
      },
      "early_stop_steps": -1,  # -1 means no early stop
+     "weight_quant_start_step": 0,  # 0 means turned on by default
+     "teacher_model": "labels",
+     "is_teacher": True,  # by default train the vanilla resnet
+     "seed": 0,
  })

  proj_layers = [sum(resnet_layers[:x]) for x in range(len(resnet_layers))]
@@ -134,6 +141,8 @@
          idx].conv_1.quant_act.input_distribution = "positive"

  config.model_hparams.filter_multiplier = 1.
+ config.model_hparams.se_ratio = 0.5
+ config.model_hparams.init_group = 32
  config.half_shift = False

  return config
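The new endlr, knee_lr, and knee_epochs defaults describe a schedule that reaches a small "knee" learning rate of 1e-5 at epoch 125 and then continues toward a final rate of 0.0 at epoch 250. One plausible piecewise-linear reading of those fields (base_lr and the warmup handling are assumptions; the repository's scheduler may interpolate differently):

def knee_schedule(epoch, *, base_lr=0.1, warmup_epochs=5, knee_lr=1e-5,
                  knee_epochs=125, endlr=0.0, num_epochs=250):
  # Linear warmup up to the base learning rate.
  if epoch < warmup_epochs:
    return base_lr * (epoch + 1) / warmup_epochs
  # Decay from base_lr down to knee_lr at the knee epoch.
  if epoch < knee_epochs:
    frac = (epoch - warmup_epochs) / (knee_epochs - warmup_epochs)
    return base_lr + frac * (knee_lr - base_lr)
  # Final stretch from the knee down to endlr.
  frac = (epoch - knee_epochs) / (num_epochs - knee_epochs)
  return knee_lr + frac * (endlr - knee_lr)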
11 changes: 10 additions & 1 deletion aqt/jax/imagenet/configs_script/config_schema.py
@@ -45,13 +45,20 @@ def get_base_config(use_auto_acts):
          "cooldown_epochs": int_ph(),
          "scheduler": str_ph(),  # "cosine", "linear", or "step" lr decay
          "num_epochs": int_ph(),
+         "endlr": float_ph(),
+         "knee_lr": float_ph(),
+         "knee_epochs": int_ph(),
      },
      "optimizer": str_ph(),
      "adam": {
          "beta1": float_ph(),
          "beta2": float_ph()
      },
      "early_stop_steps": int_ph(),
+     "weight_quant_start_step": int_ph(),
+     "teacher_model": str_ph(),
+     "is_teacher": bool_ph(),  # whether to train vanilla resnet or PokeBNN
+     "seed": int_ph(),
  })
  if use_auto_acts:
    # config_schema_utils is shared by wmt. To not make other code libraries
@@ -90,7 +97,7 @@ def get_residual_config(
   config = ml_collections.ConfigDict()
   config_schema_utils.set_default_reference(
       config,
-      parent_config, ["conv_proj", "conv_1", "conv_2", "conv_3"],
+      parent_config, ["conv_se", "conv_proj", "conv_1", "conv_2", "conv_3"],
       parent_field="conv")
   # TODO(b/179063860): The input distribution is an intrinsic model
   # property and shouldn't be part of the model configuration. Update
@@ -154,6 +161,8 @@ def get_config(num_blocks,
       # of conv filters in each layer by this number.
       "filter_multiplier": float_ph(),
       "act_function": str_ph(),
+      "se_ratio": float_ph(),
+      "init_group": int_ph(),  # feature group in the second group conv layer
   })
   config_schema_utils.set_default_reference(
       model_hparams, base_config, "act_function", parent_field="act_function")
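The conv_se entry and the se_ratio hyperparameter point to a squeeze-and-excitation branch in the residual block. A shape-level flax.linen sketch of what an se_ratio of 0.5 controls; the repository's conv_se is a quantized convolution, so its wiring may differ:

import flax.linen as nn
import jax.numpy as jnp

class SqueezeExcite(nn.Module):
  se_ratio: float = 0.5

  @nn.compact
  def __call__(self, x):                     # x: [N, H, W, C]
    channels = x.shape[-1]
    hidden = max(1, int(channels * self.se_ratio))
    s = jnp.mean(x, axis=(1, 2))             # squeeze: global average pool
    s = nn.relu(nn.Dense(hidden)(s))         # bottleneck sized by se_ratio
    s = nn.sigmoid(nn.Dense(channels)(s))    # per-channel gates in [0, 1]
    return x * s[:, None, None, :]           # excite: rescale feature maps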
10 changes: 10 additions & 0 deletions aqt/jax/imagenet/configs_script/config_schema_test.py
@@ -151,6 +151,7 @@ def test_schema_matches_expected(self, num_blocks):
     }

     residual_block_schema = {
+        'conv_se': conv_schema,
         'conv_proj': conv_schema,
         'conv_1': conv_schema,
         'conv_2': conv_schema,
@@ -171,6 +172,7 @@ def test_schema_matches_expected(self, num_blocks):
         'weight_decay': None,
         'activation_bound_update_freq': None,
         'activation_bound_start_step': None,
+        'weight_quant_start_step': None,
         'prec': None,
         'half_shift': None,
         'weight_prec': None,
@@ -183,11 +185,17 @@
         'shortcut_ch_shrink_method': None,
         'shortcut_ch_expand_method': None,
         'shortcut_spatial_method': None,
+        'teacher_model': None,
+        'is_teacher': None,
+        'seed': None,
         'lr_scheduler': {
             'warmup_epochs': None,
             'cooldown_epochs': None,
             'scheduler': None,
             'num_epochs': None,
+            'endlr': None,
+            'knee_lr': None,
+            'knee_epochs': None,
         },
         'optimizer': None,
         'adam': {
@@ -203,6 +211,8 @@
         'residual_blocks': [residual_block_schema] * num_blocks,
         'filter_multiplier': None,
         'act_function': None,
+        'se_ratio': None,
+        'init_group': None,
     },
 }

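The expected schema is a tree of None values because every field declared through int_ph(), float_ph(), and friends is a typed ml_collections placeholder, which reads back as None until it is set. A small standalone example of that pattern:

from ml_collections import config_dict

config = config_dict.ConfigDict()
config.se_ratio = config_dict.placeholder(float)
assert config.se_ratio is None   # unset placeholders read back as None
config.se_ratio = 0.5            # OK: value matches the declared type
# config.se_ratio = 'half'       # would raise a TypeError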
7 changes: 7 additions & 0 deletions aqt/jax/imagenet/hparams_config.py
@@ -31,6 +31,9 @@ class LrScheduler:
   cooldown_epochs: int
   scheduler: str
   num_epochs: int
+  endlr: float
+  knee_lr: float
+  knee_epochs: int


 @dataclass
@@ -54,6 +57,9 @@ class TrainingHParams:
   optimizer: str
   adam: Adam  # only used when optimizer=='adam'
   early_stop_steps: int
+  teacher_model: str
+  is_teacher: bool
+  seed: int

   # Auto-clip activation quantization hparams. See
   # train_utils.should_update_bounds for more details. We use -1 instead of None
@@ -64,6 +70,7 @@
   # incomplete configuration.
   activation_bound_update_freq: int
   activation_bound_start_step: int
+  weight_quant_start_step: int

   # Model hparams
   model_hparams: Any
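The teacher_model and is_teacher fields hint at a distillation setup: with is_teacher=True (the default) the model trains directly on ground-truth labels, while a student run would instead be supervised by a teacher's predictions. A hedged sketch of that split, with the loss functions chosen here as assumptions rather than taken from the repository:

import jax
import optax

def training_loss(student_logits, labels, teacher_logits=None, is_teacher=True):
  # Default path: vanilla supervised training against integer labels.
  if is_teacher or teacher_logits is None:
    return optax.softmax_cross_entropy_with_integer_labels(
        student_logits, labels).mean()
  # Student path: match the teacher's softened class distribution.
  teacher_probs = jax.nn.softmax(teacher_logits)
  return optax.softmax_cross_entropy(student_logits, teacher_probs).mean()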