diff --git a/tests/nn/test_batch_norm.py b/tests/nn/test_batch_norm.py deleted file mode 100644 index 77487a7c..00000000 --- a/tests/nn/test_batch_norm.py +++ /dev/null @@ -1,193 +0,0 @@ -import unittest - -import hypothesis as hp -import jax -import numpy as np -from flax import linen -from hypothesis import strategies as st - -import treex as tx - -INITS = ( - tx.initializers.zeros, - tx.initializers.ones, - tx.initializers.normal(), - tx.initializers.uniform(), -) - - -class BatchNormTest(unittest.TestCase): - @hp.given( - batch_size=st.integers(min_value=1, max_value=32), - length=st.integers(min_value=1, max_value=32), - channels=st.integers(min_value=1, max_value=32), - axis=st.sampled_from([-1]), # flax has an error with other axis - momentum=st.floats(min_value=0.01, max_value=1.0), - epsilon=st.floats(min_value=0.000001, max_value=0.01), - use_bias=st.booleans(), - use_scale=st.booleans(), - bias_init=st.sampled_from(INITS), - scale_init=st.sampled_from(INITS), - training=st.booleans(), - frozen=st.booleans(), - ) - @hp.settings(deadline=None, max_examples=20) - def test_equivalence( - self, - batch_size, - length, - channels, - axis, - momentum, - epsilon, - use_bias, - use_scale, - bias_init, - scale_init, - training, - frozen, - ): - use_running_average = not training or frozen - shape = (batch_size, length, channels) - - x = np.random.uniform(size=shape) - - key = tx.Key(42) - - flax_module = linen.BatchNorm( - use_running_average=use_running_average, - axis=axis, - momentum=momentum, - epsilon=epsilon, - use_bias=use_bias, - use_scale=use_scale, - bias_init=bias_init, - scale_init=scale_init, - ) - treex_module = ( - tx.BatchNorm( - axis=axis, - momentum=momentum, - epsilon=epsilon, - use_bias=use_bias, - use_scale=use_scale, - bias_init=bias_init, - scale_init=scale_init, - ) - .train(training) - .freeze(frozen) - ) - - flax_key, _ = tx.iter_split(key) # emulate init split - variables = flax_module.init(flax_key, x) - treex_module = treex_module.init(key, x) - - if use_bias: - assert np.allclose(variables["params"]["bias"], treex_module.bias) - - if use_scale: - assert np.allclose(variables["params"]["scale"], treex_module.scale) - - assert np.allclose(variables["batch_stats"]["mean"], treex_module.mean) - assert np.allclose(variables["batch_stats"]["var"], treex_module.var) - - y_flax, updates = flax_module.apply(variables, x, mutable=["batch_stats"]) - variables = variables.copy(updates) - - y_treex = treex_module(x) - - assert np.allclose(y_flax, y_treex) - - if use_bias: - assert np.allclose(variables["params"]["bias"], treex_module.bias) - - if use_scale: - assert np.allclose(variables["params"]["scale"], treex_module.scale) - - assert np.allclose(variables["batch_stats"]["mean"], treex_module.mean) - assert np.allclose(variables["batch_stats"]["var"], treex_module.var) - - def test_call(self): - x = np.random.uniform(size=(10, 2)) - module = tx.BatchNorm().init(42, x) - - y = module(x) - - assert y.shape == (10, 2) - - def test_tree(self): - x = np.random.uniform(size=(10, 2)) - module = tx.BatchNorm().init(42, x) - - flat = jax.tree_leaves(module) - - assert len(flat) == 5 - - def test_slice(self): - x = np.random.uniform(size=(10, 2)) - module = tx.BatchNorm().init(42, x) - - flat = jax.tree_leaves(module.filter(tx.Parameter)) - assert len(flat) == 2 - - flat = jax.tree_leaves(module.filter(tx.BatchStat)) - assert len(flat) == 2 - - flat = jax.tree_leaves( - module.filter(lambda field: not issubclass(field.kind, tx.TreePart)) - ) - assert len(flat) == 1 - - def test_jit(self): - x = np.random.uniform(size=(10, 2)) - module = tx.BatchNorm().init(42, x) - - @jax.jit - def f(module, x): - return module, module(x) - - module2, y = f(module, x) - - assert y.shape == (10, 2) - assert all( - np.allclose(a, b) - for a, b in zip( - jax.tree_leaves(module.filter(tx.Parameter)), - jax.tree_leaves(module2.filter(tx.Parameter)), - ) - ) - assert not all( - np.allclose(a, b) - for a, b in zip( - jax.tree_leaves(module.filter(tx.BatchStat)), - jax.tree_leaves(module2.filter(tx.BatchStat)), - ) - ) - - def test_eval(self): - x = np.random.uniform(size=(10, 2)) - module = tx.BatchNorm().init(42, x) - - @jax.jit - def f(module, x): - module = module.eval() - - return module, module(x) - - module2, y = f(module, x) - - assert y.shape == (10, 2) - assert all( - np.allclose(a, b) - for a, b in zip( - jax.tree_leaves(module.filter(tx.Parameter)), - jax.tree_leaves(module2.filter(tx.Parameter)), - ) - ) - assert all( - np.allclose(a, b) - for a, b in zip( - jax.tree_leaves(module.filter(tx.BatchStat)), - jax.tree_leaves(module2.filter(tx.BatchStat)), - ) - ) diff --git a/tests/nn/test_norm.py b/tests/nn/test_norm.py new file mode 100644 index 00000000..c482369e --- /dev/null +++ b/tests/nn/test_norm.py @@ -0,0 +1,512 @@ +import unittest + +import hypothesis as hp +import jax +import numpy as np +from flax import linen +from hypothesis import strategies as st + +import treex as tx + +INITS = ( + tx.initializers.zeros, + tx.initializers.ones, + tx.initializers.normal(), + tx.initializers.uniform(), +) + + +class BatchNormTest(unittest.TestCase): + @hp.given( + batch_size=st.integers(min_value=1, max_value=32), + length=st.integers(min_value=1, max_value=32), + channels=st.integers(min_value=1, max_value=32), + axis=st.sampled_from([-1]), # flax has an error with other axis + momentum=st.floats(min_value=0.01, max_value=1.0), + epsilon=st.floats(min_value=0.000001, max_value=0.01), + use_bias=st.booleans(), + use_scale=st.booleans(), + bias_init=st.sampled_from(INITS), + scale_init=st.sampled_from(INITS), + training=st.booleans(), + frozen=st.booleans(), + ) + @hp.settings(deadline=None, max_examples=20) + def test_equivalence( + self, + batch_size, + length, + channels, + axis, + momentum, + epsilon, + use_bias, + use_scale, + bias_init, + scale_init, + training, + frozen, + ): + use_running_average = not training or frozen + shape = (batch_size, length, channels) + + x = np.random.uniform(size=shape) + + key = tx.Key(42) + + flax_module = linen.BatchNorm( + use_running_average=use_running_average, + axis=axis, + momentum=momentum, + epsilon=epsilon, + use_bias=use_bias, + use_scale=use_scale, + bias_init=bias_init, + scale_init=scale_init, + ) + treex_module = ( + tx.BatchNorm( + axis=axis, + momentum=momentum, + epsilon=epsilon, + use_bias=use_bias, + use_scale=use_scale, + bias_init=bias_init, + scale_init=scale_init, + ) + .train(training) + .freeze(frozen) + ) + + flax_key, _ = tx.iter_split(key) # emulate init split + variables = flax_module.init(flax_key, x) + treex_module = treex_module.init(key, x) + + if use_bias: + assert np.allclose(variables["params"]["bias"], treex_module.bias) + + if use_scale: + assert np.allclose(variables["params"]["scale"], treex_module.scale) + + assert np.allclose(variables["batch_stats"]["mean"], treex_module.mean) + assert np.allclose(variables["batch_stats"]["var"], treex_module.var) + + y_flax, updates = flax_module.apply(variables, x, mutable=["batch_stats"]) + variables = variables.copy(updates) + + y_treex = treex_module(x) + + assert np.allclose(y_flax, y_treex) + + if use_bias: + assert np.allclose(variables["params"]["bias"], treex_module.bias) + + if use_scale: + assert np.allclose(variables["params"]["scale"], treex_module.scale) + + assert np.allclose(variables["batch_stats"]["mean"], treex_module.mean) + assert np.allclose(variables["batch_stats"]["var"], treex_module.var) + + def test_call(self): + x = np.random.uniform(size=(10, 2)) + module = tx.BatchNorm().init(42, x) + + y = module(x) + + assert y.shape == (10, 2) + + def test_tree(self): + x = np.random.uniform(size=(10, 2)) + module = tx.BatchNorm().init(42, x) + + flat = jax.tree_leaves(module) + + assert len(flat) == 5 + + def test_slice(self): + x = np.random.uniform(size=(10, 2)) + module = tx.BatchNorm().init(42, x) + + flat = jax.tree_leaves(module.filter(tx.Parameter)) + assert len(flat) == 2 + + flat = jax.tree_leaves(module.filter(tx.BatchStat)) + assert len(flat) == 2 + + flat = jax.tree_leaves( + module.filter(lambda field: not issubclass(field.kind, tx.TreePart)) + ) + assert len(flat) == 1 + + def test_jit(self): + x = np.random.uniform(size=(10, 2)) + module = tx.BatchNorm().init(42, x) + + @jax.jit + def f(module, x): + return module, module(x) + + module2, y = f(module, x) + + assert y.shape == (10, 2) + assert all( + np.allclose(a, b) + for a, b in zip( + jax.tree_leaves(module.filter(tx.Parameter)), + jax.tree_leaves(module2.filter(tx.Parameter)), + ) + ) + assert not all( + np.allclose(a, b) + for a, b in zip( + jax.tree_leaves(module.filter(tx.BatchStat)), + jax.tree_leaves(module2.filter(tx.BatchStat)), + ) + ) + + def test_eval(self): + x = np.random.uniform(size=(10, 2)) + module = tx.BatchNorm().init(42, x) + + @jax.jit + def f(module, x): + module = module.eval() + + return module, module(x) + + module2, y = f(module, x) + + assert y.shape == (10, 2) + assert all( + np.allclose(a, b) + for a, b in zip( + jax.tree_leaves(module.filter(tx.Parameter)), + jax.tree_leaves(module2.filter(tx.Parameter)), + ) + ) + assert all( + np.allclose(a, b) + for a, b in zip( + jax.tree_leaves(module.filter(tx.BatchStat)), + jax.tree_leaves(module2.filter(tx.BatchStat)), + ) + ) + + +class LayerNormTest(unittest.TestCase): + @hp.given( + batch_size=st.integers(min_value=1, max_value=32), + length=st.integers(min_value=1, max_value=32), + channels=st.integers(min_value=1, max_value=32), + epsilon=st.floats(min_value=0.000001, max_value=0.01), + use_bias=st.booleans(), + use_scale=st.booleans(), + bias_init=st.sampled_from(INITS), + scale_init=st.sampled_from(INITS), + ) + @hp.settings(deadline=None, max_examples=20) + def test_equivalence( + self, + batch_size, + length, + channels, + epsilon, + use_bias, + use_scale, + bias_init, + scale_init, + ): + shape = (batch_size, length, channels) + + x = np.random.uniform(size=shape) + + key = tx.Key(42) + + flax_module = linen.LayerNorm( + epsilon=epsilon, + use_bias=use_bias, + use_scale=use_scale, + bias_init=bias_init, + scale_init=scale_init, + ) + treex_module = tx.LayerNorm( + epsilon=epsilon, + use_bias=use_bias, + use_scale=use_scale, + bias_init=bias_init, + scale_init=scale_init, + ) + + flax_key, _ = tx.iter_split(key) # emulate init split + variables = flax_module.init(flax_key, x) + treex_module = treex_module.init(key, x) + + if use_bias: + assert np.allclose(variables["params"]["bias"], treex_module.bias) + + if use_scale: + assert np.allclose(variables["params"]["scale"], treex_module.scale) + + y_flax = flax_module.apply(variables, x) + + y_treex = treex_module(x) + + assert np.allclose(y_flax, y_treex) + + if use_bias: + assert np.allclose(variables["params"]["bias"], treex_module.bias) + + if use_scale: + assert np.allclose(variables["params"]["scale"], treex_module.scale) + + def test_call(self): + x = np.random.uniform(size=(10, 2)) + module = tx.LayerNorm().init(42, x) + + y = module(x) + + assert y.shape == (10, 2) + + def test_tree(self): + x = np.random.uniform(size=(10, 2)) + module = tx.LayerNorm().init(42, x) + + flat = jax.tree_leaves(module) + + assert len(flat) == 2 + + def test_slice(self): + x = np.random.uniform(size=(10, 2)) + module = tx.LayerNorm().init(42, x) + + flat = jax.tree_leaves(module.filter(tx.Parameter)) + assert len(flat) == 2 + + flat = jax.tree_leaves(module.filter(tx.BatchStat)) + assert len(flat) == 0 + + flat = jax.tree_leaves( + module.filter(lambda field: not issubclass(field.kind, tx.TreePart)) + ) + assert len(flat) == 0 + + def test_jit(self): + x = np.random.uniform(size=(10, 2)) + module = tx.LayerNorm().init(42, x) + + @jax.jit + def f(module, x): + return module, module(x) + + module2, y = f(module, x) + + assert y.shape == (10, 2) + assert all( + np.allclose(a, b) + for a, b in zip( + jax.tree_leaves(module.filter(tx.Parameter)), + jax.tree_leaves(module2.filter(tx.Parameter)), + ) + ) + + def test_eval(self): + x = np.random.uniform(size=(10, 2)) + module = tx.LayerNorm().init(42, x) + + @jax.jit + def f(module, x): + module = module.eval() + + return module, module(x) + + module2, y = f(module, x) + + assert y.shape == (10, 2) + assert all( + np.allclose(a, b) + for a, b in zip( + jax.tree_leaves(module.filter(tx.Parameter)), + jax.tree_leaves(module2.filter(tx.Parameter)), + ) + ) + + +class GroupNormTest(unittest.TestCase): + def _test_equivalence( + self, + batch_size, + length, + channels, + num_groups, + group_size, + epsilon, + use_bias, + use_scale, + bias_init, + scale_init, + ): + shape = (batch_size, length, channels) + + x = np.random.uniform(size=shape) + + key = tx.Key(42) + + flax_module = linen.GroupNorm( + num_groups=num_groups, + group_size=group_size, + epsilon=epsilon, + use_bias=use_bias, + use_scale=use_scale, + bias_init=bias_init, + scale_init=scale_init, + ) + treex_module = tx.GroupNorm( + num_groups=num_groups, + group_size=group_size, + epsilon=epsilon, + use_bias=use_bias, + use_scale=use_scale, + bias_init=bias_init, + scale_init=scale_init, + ) + + flax_key, _ = tx.iter_split(key) # emulate init split + variables = flax_module.init(flax_key, x) + treex_module = treex_module.init(key, x) + + if use_bias: + assert np.allclose(variables["params"]["bias"], treex_module.bias) + + if use_scale: + assert np.allclose(variables["params"]["scale"], treex_module.scale) + + y_flax = flax_module.apply(variables, x) + + y_treex = treex_module(x) + + assert np.allclose(y_flax, y_treex) + + if use_bias: + assert np.allclose(variables["params"]["bias"], treex_module.bias) + + if use_scale: + assert np.allclose(variables["params"]["scale"], treex_module.scale) + + @hp.given( + batch_size=st.integers(min_value=1, max_value=32), + length=st.integers(min_value=1, max_value=32), + channels=st.integers(min_value=1, max_value=32), + num_groups=st.none(), + group_size=st.just(1), + epsilon=st.floats(min_value=0.000001, max_value=0.01), + use_bias=st.booleans(), + use_scale=st.booleans(), + bias_init=st.sampled_from(INITS), + scale_init=st.sampled_from(INITS), + ) + @hp.settings(deadline=None, max_examples=20) + def test_equivalence_channels(self, **kwargs): + self._test_equivalence(**kwargs) + + @hp.given( + batch_size=st.integers(min_value=1, max_value=32), + length=st.integers(min_value=1, max_value=32), + channels=st.just(32), + num_groups=st.sampled_from([2 ** i for i in range(5)]), + group_size=st.none(), + epsilon=st.floats(min_value=0.000001, max_value=0.01), + use_bias=st.booleans(), + use_scale=st.booleans(), + bias_init=st.sampled_from(INITS), + scale_init=st.sampled_from(INITS), + ) + @hp.settings(deadline=None, max_examples=20) + def test_equivalence_num_groups(self, **kwargs): + self._test_equivalence(**kwargs) + + @hp.given( + batch_size=st.integers(min_value=1, max_value=32), + length=st.integers(min_value=1, max_value=32), + channels=st.just(32), + num_groups=st.none(), + group_size=st.sampled_from([2 ** i for i in range(5)]), + epsilon=st.floats(min_value=0.000001, max_value=0.01), + use_bias=st.booleans(), + use_scale=st.booleans(), + bias_init=st.sampled_from(INITS), + scale_init=st.sampled_from(INITS), + ) + @hp.settings(deadline=None, max_examples=20) + def test_equivalence_group_size(self, **kwargs): + self._test_equivalence(**kwargs) + + def test_call(self): + x = np.random.uniform(size=(10, 32)) + module = tx.GroupNorm().init(42, x) + + y = module(x) + + assert y.shape == (10, 32) + + def test_tree(self): + x = np.random.uniform(size=(10, 32)) + module = tx.GroupNorm().init(42, x) + + flat = jax.tree_leaves(module) + + assert len(flat) == 2 + + def test_slice(self): + x = np.random.uniform(size=(10, 32)) + module = tx.GroupNorm().init(42, x) + + flat = jax.tree_leaves(module.filter(tx.Parameter)) + assert len(flat) == 2 + + flat = jax.tree_leaves(module.filter(tx.BatchStat)) + assert len(flat) == 0 + + flat = jax.tree_leaves( + module.filter(lambda field: not issubclass(field.kind, tx.TreePart)) + ) + assert len(flat) == 0 + + def test_jit(self): + x = np.random.uniform(size=(10, 32)) + module = tx.GroupNorm().init(42, x) + + @jax.jit + def f(module, x): + return module, module(x) + + module2, y = f(module, x) + + assert y.shape == (10, 32) + assert all( + np.allclose(a, b) + for a, b in zip( + jax.tree_leaves(module.filter(tx.Parameter)), + jax.tree_leaves(module2.filter(tx.Parameter)), + ) + ) + + def test_eval(self): + x = np.random.uniform(size=(10, 32)) + module = tx.GroupNorm().init(42, x) + + @jax.jit + def f(module, x): + module = module.eval() + + return module, module(x) + + module2, y = f(module, x) + + assert y.shape == (10, 32) + assert all( + np.allclose(a, b) + for a, b in zip( + jax.tree_leaves(module.filter(tx.Parameter)), + jax.tree_leaves(module2.filter(tx.Parameter)), + ) + ) diff --git a/treex/nn/__init__.py b/treex/nn/__init__.py index dea2ab4b..42bfe4af 100644 --- a/treex/nn/__init__.py +++ b/treex/nn/__init__.py @@ -1,6 +1,5 @@ from treex import types -from .batch_norm import BatchNorm from .conv import Conv from .dropout import Dropout from .embed import Embed @@ -8,6 +7,7 @@ from .flax_module import FlaxModule from .linear import Linear from .mlp import MLP +from .norm import BatchNorm, GroupNorm, LayerNorm from .sequential import Lambda, Sequential, sequence try: @@ -24,6 +24,8 @@ "Embed", "Flatten", "FlaxModule", + "GroupNorm", + "LayerNorm", "Linear", "MLP", "Lambda", diff --git a/treex/nn/batch_norm.py b/treex/nn/batch_norm.py deleted file mode 100644 index e1707188..00000000 --- a/treex/nn/batch_norm.py +++ /dev/null @@ -1,195 +0,0 @@ -import typing as tp - -import jax -import jax.numpy as jnp -import numpy as np -import treeo as to -from flax.linen import normalization as flax_module - -from treex import types, utils -from treex.module import Module, next_key - - -class BatchNorm(Module): - """BatchNorm Module. - - `BatchNorm` is implemented as a wrapper over `flax.linen.BatchNorm`, its constructor - arguments accept almost the same arguments including any Flax artifacts such as initializers. - Main differences: - - * `use_running_average` is not a constructor argument, but remains a `__call__` argument. - * `self.training` state is used to indicate how BatchNorm should behave, interally - `use_running_average = not self.training or self.frozen` is used unless `use_running_average` is explicitly - passed via `__call__`. - """ - - # pytree - mean: tp.Optional[jnp.ndarray] = types.BatchStat.node() - var: tp.Optional[jnp.ndarray] = types.BatchStat.node() - scale: tp.Optional[jnp.ndarray] = types.Parameter.node() - bias: tp.Optional[jnp.ndarray] = types.Parameter.node() - momentum: jnp.ndarray = to.node() - - # props - axis: int - epsilon: float - dtype: flax_module.Dtype - use_bias: bool - use_scale: bool - bias_init: tp.Callable[ - [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], - flax_module.Array, - ] - scale_init: tp.Callable[ - [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], - flax_module.Array, - ] - axis_name: tp.Optional[str] - axis_index_groups: tp.Any - - def __init__( - self, - *, - axis: int = -1, - momentum: tp.Union[float, jnp.ndarray] = 0.99, - epsilon: float = 1e-5, - dtype: flax_module.Dtype = jnp.float32, - use_bias: bool = True, - use_scale: bool = True, - bias_init: tp.Callable[ - [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], - flax_module.Array, - ] = flax_module.initializers.zeros, - scale_init: tp.Callable[ - [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], - flax_module.Array, - ] = flax_module.initializers.ones, - axis_name: tp.Optional[str] = None, - axis_index_groups: tp.Any = None, - ): - """ - Arguments: - features_in: the number of input features. - axis: the feature or non-batch axis of the input. - momentum: decay rate for the exponential moving average of - the batch statistics. - epsilon: a small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - use_bias: if True, bias (beta) is added. - use_scale: if True, multiply by scale (gamma). - When the next layer is linear (also e.g. nn.relu), this can be disabled - since the scaling will be done by the next layer. - bias_init: initializer for bias, by default, zero. - scale_init: initializer for scale, by default, one. - axis_name: the axis name used to combine batch statistics from multiple - devices. See `jax.pmap` for a description of axis names (default: None). - axis_index_groups: groups of axis indices within that named axis - representing subsets of devices to reduce over (default: None). For - example, `[[0, 1], [2, 3]]` would independently batch-normalize over - the examples on the first two and last two devices. See `jax.lax.psum` - for more details. - """ - - self.axis = axis - self.momentum = jnp.asarray(momentum) - self.epsilon = epsilon - self.dtype = dtype - self.use_bias = use_bias - self.use_scale = use_scale - self.bias_init = bias_init - self.scale_init = scale_init - self.axis_name = axis_name - self.axis_index_groups = axis_index_groups - - self.mean = None - self.var = None - self.scale = None - self.bias = None - - @property - def module(self) -> flax_module.BatchNorm: - return flax_module.BatchNorm( - use_running_average=None, - axis=self.axis, - momentum=self.momentum, - epsilon=self.epsilon, - dtype=self.dtype, - use_bias=self.use_bias, - use_scale=self.use_scale, - bias_init=self.bias_init, - scale_init=self.scale_init, - axis_name=self.axis_name, - axis_index_groups=self.axis_index_groups, - ) - - def __call__( - self, x: jnp.ndarray, use_running_average: tp.Optional[bool] = None - ) -> jnp.ndarray: - """Normalizes the input using batch statistics. - - Arguments: - x: the input to be normalized. - use_running_average: if true, the statistics stored in batch_stats - will be used instead of computing the batch statistics on the input. - - Returns: - Normalized inputs (the same shape as inputs). - """ - if self.initializing(): - variables = self.module.init( - next_key(), - x, - use_running_average=True, - ).unfreeze() - - # Extract collections - if "params" in variables: - params = variables["params"] - - if self.use_bias: - self.bias = params["bias"] - - if self.use_scale: - self.scale = params["scale"] - - self.mean = variables["batch_stats"]["mean"] - self.var = variables["batch_stats"]["var"] - - params = {} - - if self.use_bias: - params["bias"] = self.bias - - if self.use_scale: - params["scale"] = self.scale - - variables = dict( - batch_stats=dict( - mean=self.mean, - var=self.var, - ), - params=params, - ) - # use_running_average = True means batch_stats will not be mutated - # self.training = True means batch_stats will be mutated - training = ( - not use_running_average - if use_running_average is not None - else self.training and not self.frozen and self.initialized - ) - - # call apply - output, variables = self.module.apply( - variables, - x, - mutable=["batch_stats"] if training else [], - use_running_average=not training, - ) - variables = variables.unfreeze() - - # update batch_stats - if "batch_stats" in variables: - self.mean = variables["batch_stats"]["mean"] - self.var = variables["batch_stats"]["var"] - - return tp.cast(jnp.ndarray, output) diff --git a/treex/nn/norm.py b/treex/nn/norm.py new file mode 100644 index 00000000..59e3158a --- /dev/null +++ b/treex/nn/norm.py @@ -0,0 +1,455 @@ +import typing as tp + +import jax +import jax.numpy as jnp +import numpy as np +import treeo as to +from flax.linen import normalization as flax_module + +from treex import types, utils +from treex.module import Module, next_key + + +class BatchNorm(Module): + """BatchNorm Module. + + `BatchNorm` is implemented as a wrapper over `flax.linen.BatchNorm`, its constructor + arguments accept almost the same arguments including any Flax artifacts such as initializers. + Main differences: + + * `use_running_average` is not a constructor argument, but remains a `__call__` argument. + * `self.training` state is used to indicate how BatchNorm should behave, interally + `use_running_average = not self.training or self.frozen` is used unless `use_running_average` is explicitly + passed via `__call__`. + """ + + # pytree + mean: tp.Optional[jnp.ndarray] = types.BatchStat.node() + var: tp.Optional[jnp.ndarray] = types.BatchStat.node() + scale: tp.Optional[jnp.ndarray] = types.Parameter.node() + bias: tp.Optional[jnp.ndarray] = types.Parameter.node() + momentum: jnp.ndarray = to.node() + + # props + axis: int + epsilon: float + dtype: flax_module.Dtype + use_bias: bool + use_scale: bool + bias_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] + scale_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] + axis_name: tp.Optional[str] + axis_index_groups: tp.Any + + def __init__( + self, + *, + axis: int = -1, + momentum: tp.Union[float, jnp.ndarray] = 0.99, + epsilon: float = 1e-5, + dtype: flax_module.Dtype = jnp.float32, + use_bias: bool = True, + use_scale: bool = True, + bias_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] = flax_module.initializers.zeros, + scale_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] = flax_module.initializers.ones, + axis_name: tp.Optional[str] = None, + axis_index_groups: tp.Any = None, + ): + """ + Arguments: + features_in: the number of input features. + axis: the feature or non-batch axis of the input. + momentum: decay rate for the exponential moving average of + the batch statistics. + epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the computation (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). + When the next layer is linear (also e.g. nn.relu), this can be disabled + since the scaling will be done by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over + the examples on the first two and last two devices. See `jax.lax.psum` + for more details. + """ + + self.axis = axis + self.momentum = jnp.asarray(momentum) + self.epsilon = epsilon + self.dtype = dtype + self.use_bias = use_bias + self.use_scale = use_scale + self.bias_init = bias_init + self.scale_init = scale_init + self.axis_name = axis_name + self.axis_index_groups = axis_index_groups + + self.mean = None + self.var = None + self.scale = None + self.bias = None + + @property + def module(self) -> flax_module.BatchNorm: + return flax_module.BatchNorm( + use_running_average=None, + axis=self.axis, + momentum=self.momentum, + epsilon=self.epsilon, + dtype=self.dtype, + use_bias=self.use_bias, + use_scale=self.use_scale, + bias_init=self.bias_init, + scale_init=self.scale_init, + axis_name=self.axis_name, + axis_index_groups=self.axis_index_groups, + ) + + def __call__( + self, x: jnp.ndarray, use_running_average: tp.Optional[bool] = None + ) -> jnp.ndarray: + """Normalizes the input using batch statistics. + + Arguments: + x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. + + Returns: + Normalized inputs (the same shape as inputs). + """ + if self.initializing(): + variables = self.module.init( + next_key(), + x, + use_running_average=True, + ).unfreeze() + + # Extract collections + if "params" in variables: + params = variables["params"] + + if self.use_bias: + self.bias = params["bias"] + + if self.use_scale: + self.scale = params["scale"] + + self.mean = variables["batch_stats"]["mean"] + self.var = variables["batch_stats"]["var"] + + params = {} + + if self.use_bias: + params["bias"] = self.bias + + if self.use_scale: + params["scale"] = self.scale + + variables = dict( + batch_stats=dict( + mean=self.mean, + var=self.var, + ), + params=params, + ) + # use_running_average = True means batch_stats will not be mutated + # self.training = True means batch_stats will be mutated + training = ( + not use_running_average + if use_running_average is not None + else self.training and not self.frozen and self.initialized + ) + + # call apply + output, variables = self.module.apply( + variables, + x, + mutable=["batch_stats"] if training else [], + use_running_average=not training, + ) + variables = variables.unfreeze() + + # update batch_stats + if "batch_stats" in variables: + self.mean = variables["batch_stats"]["mean"] + self.var = variables["batch_stats"]["var"] + + +class LayerNorm(Module): + """LayerNorm Module. + + `LayerNorm` is implemented as a wrapper over `flax.linen.LayerNorm`, its constructor + arguments accept the same arguments including any Flax artifacts such as initializers. + + It normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1. + + """ + + # pytree + scale: tp.Optional[jnp.ndarray] = types.Parameter.node() + bias: tp.Optional[jnp.ndarray] = types.Parameter.node() + + # props + epsilon: float + dtype: flax_module.Dtype + use_bias: bool + use_scale: bool + bias_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] + scale_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] + + def __init__( + self, + *, + epsilon: float = 1e-5, + dtype: flax_module.Dtype = jnp.float32, + use_bias: bool = True, + use_scale: bool = True, + bias_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] = flax_module.initializers.zeros, + scale_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] = flax_module.initializers.ones, + ): + """ + Arguments: + epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the computation (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). + When the next layer is linear (also e.g. nn.relu), this can be disabled + since the scaling will be done by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + """ + + self.epsilon = epsilon + self.dtype = dtype + self.use_bias = use_bias + self.use_scale = use_scale + self.bias_init = bias_init + self.scale_init = scale_init + + self.scale = None + self.bias = None + + @property + def module(self) -> flax_module.LayerNorm: + return flax_module.LayerNorm( + epsilon=self.epsilon, + dtype=self.dtype, + use_bias=self.use_bias, + use_scale=self.use_scale, + bias_init=self.bias_init, + scale_init=self.scale_init, + ) + + def __call__( + self, + x: jnp.ndarray, + ) -> jnp.ndarray: + """Normalizes individual input on the last axis (channels) of the input data. + + Arguments: + x: the input to be normalized. + + Returns: + Normalized inputs (the same shape as inputs). + """ + if self.initializing(): + variables = self.module.init( + next_key(), + x, + use_running_average=True, + ).unfreeze() + + # Extract collections + if "params" in variables: + params = variables["params"] + + if self.use_bias: + self.bias = params["bias"] + + if self.use_scale: + self.scale = params["scale"] + + params = {} + + if self.use_bias: + params["bias"] = self.bias + + if self.use_scale: + params["scale"] = self.scale + + variables = dict( + params=params, + ) + + # call apply + output, variables = self.module.apply( + variables, + x, + ) + variables = variables.unfreeze() + + return tp.cast(jnp.ndarray, output) + + +class GroupNorm(Module): + """Group normalization Module (arxiv.org/abs/1803.08494). + + + `GroupNorm` is implemented as a wrapper over `flax.linen.GroupNorm`, its constructor arguments accept the same arguments including any Flax artifacts such as initializers. + + This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should either specify the total number of channel groups or the number of channels per group.. + + """ + + # pytree + scale: tp.Optional[jnp.ndarray] = types.Parameter.node() + bias: tp.Optional[jnp.ndarray] = types.Parameter.node() + + # props + num_groups: tp.Optional[int] + group_size: tp.Optional[int] + epsilon: float + dtype: flax_module.Dtype + use_bias: bool + use_scale: bool + bias_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] + scale_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] + + def __init__( + self, + *, + num_groups: tp.Optional[int] = 32, + group_size: tp.Optional[int] = None, + epsilon: float = 1e-5, + dtype: flax_module.Dtype = jnp.float32, + use_bias: bool = True, + use_scale: bool = True, + bias_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] = flax_module.initializers.zeros, + scale_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] = flax_module.initializers.ones, + ): + """ + Arguments: + num_groups: the total number of channel groups. The default value of 32 is proposed by the original group normalization paper. + group_size: the number of channels in a group. + epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the computation (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). + When the next layer is linear (also e.g. nn.relu), this can be disabled + since the scaling will be done by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + """ + + self.num_groups = num_groups + self.group_size = group_size + self.epsilon = epsilon + self.dtype = dtype + self.use_bias = use_bias + self.use_scale = use_scale + self.bias_init = bias_init + self.scale_init = scale_init + + self.scale = None + self.bias = None + + @property + def module(self) -> flax_module.LayerNorm: + return flax_module.LayerNorm( + epsilon=self.epsilon, + dtype=self.dtype, + use_bias=self.use_bias, + use_scale=self.use_scale, + bias_init=self.bias_init, + scale_init=self.scale_init, + ) + + def __call__( + self, + x: jnp.ndarray, + ) -> jnp.ndarray: + """Normalizes the individual input over equally-sized group of channels. + + Arguments: + x: the input to be normalized. + + Returns: + Normalized inputs (the same shape as inputs). + """ + if self.initializing(): + variables = self.module.init( + next_key(), + x, + ).unfreeze() + + # Extract collections + if "params" in variables: + params = variables["params"] + + if self.use_bias: + self.bias = params["bias"] + + if self.use_scale: + self.scale = params["scale"] + + params = {} + + if self.use_bias: + params["bias"] = self.bias + + if self.use_scale: + params["scale"] = self.scale + + variables = dict( + params=params, + ) + + # call apply + output = self.module.apply( + variables, + x, + ) + + return tp.cast(jnp.ndarray, output)