Commit

Fix internal blocks.
PiperOrigin-RevId: 432311753
shivaniag authored and copybara-github committed Mar 4, 2022
1 parent 51960c0 commit a300d0a
Showing 8 changed files with 6 additions and 96 deletions.
2 changes: 1 addition & 1 deletion aqt/jax/flax_layers.py
@@ -84,7 +84,7 @@ class DenseAqt(nn.Module):

@dataclass
class HParams:
"""Hyperparameter class to quantize/sparsify Dense Layer."""
"""Hyperparameter class to quantize Dense Layer."""
# Target integer precision of weights in bits.
# If None, no weight quantization will be applied.
weight_prec: Union[None, int, QuantOps.FloatQuant]
7 changes: 0 additions & 7 deletions aqt/jax/imagenet/configs/base_config.py
@@ -108,13 +108,6 @@ def get_base_config(imagenet_type, quant_target):
"activation_bound_update_freq": -1,
"activation_bound_start_step": -1,
"prec": None,
"sparsity": {
"type": "N_M_STRUCTURED",
"prune_rate": None,
"order": "C",
"absolute": True,
"smallest": True,
},
"quant_type": "fake_quant",
"weight_quant_granularity": "per_channel",
"act_function": "relu",
12 changes: 0 additions & 12 deletions aqt/jax/imagenet/configs_script/config_schema_test.py
@@ -110,13 +110,6 @@ def test_schema_matches_expected(self, num_blocks):
# configuration hierarchy. A value of 'None' in the expected schemas defined
# below indicates a real configuration would have a concrete scalar value
# there.
-sparsity_schema = {
-    'type': None,
-    'prune_rate': [None, None],  # set to default structured
-    'smallest': None,
-    'order': None,
-    'absolute': None
-}
quant_act_schema = {
'bounds': {
'initial_bound': None,
@@ -146,8 +139,6 @@ def test_schema_matches_expected(self, num_blocks):
'quant_type': None,
'quant_act': quant_act_schema,
'weight_half_shift': None,
-'weight_sparsity': sparsity_schema,
-'act_sparsity': sparsity_schema,
}

conv_schema = {
@@ -196,9 +187,6 @@ def test_schema_matches_expected(self, num_blocks):
'teacher_model': None,
'is_teacher': None,
'seed': None,
-'sparsity': sparsity_schema,
-'weight_sparsity': sparsity_schema,
-'act_sparsity': sparsity_schema,
'lr_scheduler': {
'warmup_epochs': None,
'cooldown_epochs': None,
12 changes: 0 additions & 12 deletions aqt/jax/wmt_mlperf/hparams_config_scripts/config_schema_test.py
@@ -172,13 +172,6 @@ def test_schema_matches_expected(self, n_layers):
# below indicates a real configuration would have a concrete scalar value
# there.

-sparsity_schema = {
-    'type': None,
-    'prune_rate': [None, None],  # set to default structured
-    'smallest': None,
-    'order': None,
-    'absolute': None
-}

quant_act_schema = {
'bounds': {
@@ -204,8 +197,6 @@ def test_schema_matches_expected(self, n_layers):
'quant_type': None,
'quant_act': quant_act_schema,
'weight_half_shift': None,
-'weight_sparsity': sparsity_schema,
-'act_sparsity': sparsity_schema,
}

embedding_schema = {
@@ -285,9 +276,6 @@ def test_schema_matches_expected(self, n_layers):
'embedding': embedding_schema,
'mlp_block': mlp_block_schema,
'attention': attention_schema,
-'sparsity': sparsity_schema,
-'weight_sparsity': sparsity_schema,
-'act_sparsity': sparsity_schema,
'model_hparams': {
'emb_dim': None,
'num_heads': None,
8 changes: 1 addition & 7 deletions aqt/jax/wmt_mlperf/hparams_configs/base_config.py
@@ -102,13 +102,7 @@ def get_base_config(n_layers, use_auto_acts, fp_quant):
},
"weight_outlier_regularization_regex": "^.*kernel$",
"weight_quant_granularity": "per_channel",
"sparsity": {
"type": "N_M_STRUCTURED",
"prune_rate": None,
"order": "C",
"absolute": True,
"smallest": True,
},

})
if not fp_quant:
config.prec = None
22 changes: 1 addition & 21 deletions aqt/utils/config_schema_utils.py
@@ -125,8 +125,6 @@ def get_dense_config(
"quant_type",
"quant_act",
"weight_half_shift",
"act_sparsity",
"weight_sparsity",
])
config.lock()
return config
@@ -163,26 +161,14 @@ def get_fp_config():
return config


-def get_sparse_config(use_unstructured):
-  """Returns a sparse ConfigDict based on sparsity type argument."""
-  prune_rate = float_ph() if use_unstructured else (int_ph(), int_ph())
-  config = ml_collections.ConfigDict({
-      "type": str_ph(),
-      "prune_rate": prune_rate,
-      "smallest": bool_ph(),
-      "order": str_ph(),
-      "absolute": bool_ph()
-  })
-  config.lock()
-  return config


# TODO(shivaniagrawal): base config should be more generic and only model
# specific configs should be updated.
def get_base_config(
use_auto_acts,
fp_quant,
-use_unstructured = False):
+):
"""Return a base ConfigDict for AQT; does not have model specific fields."""
if use_auto_acts:
bounds = ml_collections.ConfigDict({
@@ -203,7 +189,6 @@ def get_base_config(
prec = get_fp_quant_config()
else:
prec = int_ph()
-sparsity = get_sparse_config(use_unstructured)
base_config = ml_collections.ConfigDict({
"metadata": {
"description": "Base configuration",
@@ -213,7 +198,6 @@ def get_base_config(
"activation_bound_update_freq": int_ph(),
"activation_bound_start_step": int_ph(),
"prec": prec,
"sparsity": sparsity,
"half_shift": bool_ph(),
"quant_type": str_ph(),
"quant_act": {
@@ -230,10 +214,6 @@ def get_base_config(
set_default_reference(
base_config, base_config, "weight_prec", parent_field="prec")
set_default_reference(base_config.quant_act, base_config, "prec")
-set_default_reference(
-    base_config, base_config, "act_sparsity", parent_field="sparsity")
-set_default_reference(
-    base_config, base_config, "weight_sparsity", parent_field="sparsity")

set_default_reference(
base_config, base_config, "weight_half_shift", parent_field="half_shift")
38 changes: 3 additions & 35 deletions aqt/utils/config_schema_utils_test.py
@@ -168,27 +168,6 @@ def test_fp_precision_propagates(self, use_auto_acts):
self.assertEqual(config.weight_prec.to_dict(), expected_prec_dict)
self.assertEqual(config.quant_act.prec.to_dict(), expected_prec_dict)

-@parameterized.parameters(
-    dict(use_unstructured=True), dict(use_unstructured=False))
-def test_sparsity_propagates(self, use_unstructured):
-  config = config_schema_utils.get_base_config(
-      use_auto_acts=True, fp_quant=False, use_unstructured=use_unstructured)
-
-  prune_rate = 0.2 if use_unstructured else (1, 2)
-
-  expected_sparse_dict = {
-      'type': 'unstructured',
-      'prune_rate': prune_rate,
-      'smallest': True,
-      'absolute': False,
-      'order': 'C'
-  }
-
-  config.sparsity.update(expected_sparse_dict)
-
-  # Test that this sets the weight and activation to 4 as well.
-  self.assertEqual(config.weight_sparsity.to_dict(), expected_sparse_dict)
-  self.assertEqual(config.act_sparsity.to_dict(), expected_sparse_dict)

def test_auto_acts_parameter(self):
# If use_auto_acts is False, then the bounds should be a single scalar that
@@ -211,11 +190,11 @@ def test_auto_acts_parameter(self):
dict(use_auto_acts=True, fp_quant=False),
dict(use_auto_acts=False, fp_quant=False),
dict(use_auto_acts=False, fp_quant=True),
-dict(use_auto_acts=False, fp_quant=True, use_unstructured=True))
+)
def test_schema_matches_expected(self,
use_auto_acts,
fp_quant,
-use_unstructured=False):
+):
# This tests that the schema of the configdict returned by 'base_config',
# once all references are resolved, matches an expected schema. 'Schema'
# here means the names and structure of fields at each level of the
@@ -260,14 +239,6 @@ def test_schema_matches_expected(self,
'prec': prec,
'half_shift': None,
}
-prune_rate = None if use_unstructured else [None, None]
-sparsity = {
-    'type': None,
-    'prune_rate': prune_rate,
-    'smallest': None,
-    'order': None,
-    'absolute': None
-}
expected_top_level_schema = {
'metadata': {
'description': None,
@@ -283,15 +254,12 @@ def test_schema_matches_expected(self,
'quant_type': None,
'quant_act': quant_act_schema,
'weight_quant_granularity': None,
-'sparsity': sparsity,
-'weight_sparsity': sparsity,
-'act_sparsity': sparsity,
}

config = config_schema_utils.get_base_config(
use_auto_acts=use_auto_acts,
fp_quant=fp_quant,
-use_unstructured=use_unstructured)
+)
# This round-trip conversion from JSON forces all references to resolve to
# concrete values.
config_reified = json.loads(config.to_json())
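
The comment above describes a handy property of ConfigDict serialization: to_json() resolves field references to concrete values, so a JSON round trip yields a plain nested dict that can be compared directly against an expected schema. A small sketch under the same illustrative field names as before:

```python
import json

import ml_collections

config = ml_collections.ConfigDict({"prec": 8})
config.weight_prec = config.get_ref("prec")

# The round trip reifies the reference into a concrete value.
config_reified = json.loads(config.to_json())
assert config_reified == {"prec": 8, "weight_prec": 8}
```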
1 change: 0 additions & 1 deletion aqt/utils/hparams_utils.py
@@ -109,7 +109,6 @@ def load_dataclass_from_dict(dataclass_name,
enum_classes = [
quantization.QuantOps.ActHParams.InputDistribution,
quantization.QuantType, quant_config.QuantGranularity,
-
]
data_dict = _convert_lists_to_tuples(data_dict)
return dacite.from_dict(
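
For context on the truncated call above: dacite.from_dict rebuilds a dataclass instance from a plain dict, and a list of enum classes like enum_classes is typically passed through dacite.Config(cast=...) so raw values are converted back into enum members. A self-contained sketch of that pattern — the dataclass, enum, and values here are illustrative, and whether AQT wires enum_classes into cast exactly this way is an assumption:

```python
import dataclasses
import enum

import dacite


class QuantType(enum.Enum):
  FAKE_QUANT = "fake_quant"
  AQT = "aqt"


@dataclasses.dataclass
class HParamsSketch:
  quant_type: QuantType
  weight_prec: int


hparams = dacite.from_dict(
    data_class=HParamsSketch,
    data={"quant_type": "fake_quant", "weight_prec": 8},
    config=dacite.Config(cast=[QuantType]),
)
assert hparams.quant_type is QuantType.FAKE_QUANT
```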
