This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

Commit

add LayerNorm and GroupNorm
lkhphuc committed Jan 26, 2022
1 parent e021728 commit f94f4c5
Showing 3 changed files with 584 additions and 1 deletion.
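For orientation, here is a minimal usage sketch of the two new layers, inferred from the tests in this commit (assumes `import numpy as np` and `import treex as tx`, matching the tests' conventions; it is an illustration, not part of the diff):

import numpy as np
import treex as tx

x = np.random.uniform(size=(10, 32))  # (batch, channels)

# LayerNorm: normalizes each sample over its feature axis.
layer_norm = tx.LayerNorm().init(42, x)
y = layer_norm(x)  # shape (10, 32)

# GroupNorm: normalizes over groups of channels. The tests pass either
# `num_groups` or `group_size`, leaving the other as None.
group_norm = tx.GroupNorm(num_groups=4, group_size=None).init(42, x)
y = group_norm(x)  # shape (10, 32)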
319 changes: 319 additions & 0 deletions tests/nn/test_norm.py
@@ -191,3 +191,322 @@ def f(module, x):
                jax.tree_leaves(module2.filter(tx.BatchStat)),
            )
        )


class LayerNormTest(unittest.TestCase):
    @hp.given(
        batch_size=st.integers(min_value=1, max_value=32),
        length=st.integers(min_value=1, max_value=32),
        channels=st.integers(min_value=1, max_value=32),
        epsilon=st.floats(min_value=0.000001, max_value=0.01),
        use_bias=st.booleans(),
        use_scale=st.booleans(),
        bias_init=st.sampled_from(INITS),
        scale_init=st.sampled_from(INITS),
    )
    @hp.settings(deadline=None, max_examples=20)
    def test_equivalence(
        self,
        batch_size,
        length,
        channels,
        epsilon,
        use_bias,
        use_scale,
        bias_init,
        scale_init,
    ):
        shape = (batch_size, length, channels)

        x = np.random.uniform(size=shape)

        key = tx.Key(42)

        flax_module = linen.LayerNorm(
            epsilon=epsilon,
            use_bias=use_bias,
            use_scale=use_scale,
            bias_init=bias_init,
            scale_init=scale_init,
        )
        treex_module = tx.LayerNorm(
            epsilon=epsilon,
            use_bias=use_bias,
            use_scale=use_scale,
            bias_init=bias_init,
            scale_init=scale_init,
        )

        flax_key, _ = tx.iter_split(key)  # emulate init split
        variables = flax_module.init(flax_key, x)
        treex_module = treex_module.init(key, x)

        if use_bias:
            assert np.allclose(variables["params"]["bias"], treex_module.bias)

        if use_scale:
            assert np.allclose(variables["params"]["scale"], treex_module.scale)

        y_flax = flax_module.apply(variables, x)

        y_treex = treex_module(x)

        assert np.allclose(y_flax, y_treex)

        if use_bias:
            assert np.allclose(variables["params"]["bias"], treex_module.bias)

        if use_scale:
            assert np.allclose(variables["params"]["scale"], treex_module.scale)

    def test_call(self):
        x = np.random.uniform(size=(10, 2))
        module = tx.LayerNorm().init(42, x)

        y = module(x)

        assert y.shape == (10, 2)

    def test_tree(self):
        x = np.random.uniform(size=(10, 2))
        module = tx.LayerNorm().init(42, x)

        flat = jax.tree_leaves(module)

        assert len(flat) == 2

    def test_slice(self):
        x = np.random.uniform(size=(10, 2))
        module = tx.LayerNorm().init(42, x)

        flat = jax.tree_leaves(module.filter(tx.Parameter))
        assert len(flat) == 2

        flat = jax.tree_leaves(module.filter(tx.BatchStat))
        assert len(flat) == 0

        flat = jax.tree_leaves(
            module.filter(lambda field: not issubclass(field.kind, tx.TreePart))
        )
        assert len(flat) == 0

    def test_jit(self):
        x = np.random.uniform(size=(10, 2))
        module = tx.LayerNorm().init(42, x)

        @jax.jit
        def f(module, x):
            return module, module(x)

        module2, y = f(module, x)

        assert y.shape == (10, 2)
        assert all(
            np.allclose(a, b)
            for a, b in zip(
                jax.tree_leaves(module.filter(tx.Parameter)),
                jax.tree_leaves(module2.filter(tx.Parameter)),
            )
        )

    def test_eval(self):
        x = np.random.uniform(size=(10, 2))
        module = tx.LayerNorm().init(42, x)

        @jax.jit
        def f(module, x):
            module = module.eval()

            return module, module(x)

        module2, y = f(module, x)

        assert y.shape == (10, 2)
        assert all(
            np.allclose(a, b)
            for a, b in zip(
                jax.tree_leaves(module.filter(tx.Parameter)),
                jax.tree_leaves(module2.filter(tx.Parameter)),
            )
        )


class GroupNormTest(unittest.TestCase):
    def _test_equivalence(
        self,
        batch_size,
        length,
        channels,
        num_groups,
        group_size,
        epsilon,
        use_bias,
        use_scale,
        bias_init,
        scale_init,
    ):
        shape = (batch_size, length, channels)

        x = np.random.uniform(size=shape)

        key = tx.Key(42)

        flax_module = linen.GroupNorm(
            num_groups=num_groups,
            group_size=group_size,
            epsilon=epsilon,
            use_bias=use_bias,
            use_scale=use_scale,
            bias_init=bias_init,
            scale_init=scale_init,
        )
        treex_module = tx.GroupNorm(
            num_groups=num_groups,
            group_size=group_size,
            epsilon=epsilon,
            use_bias=use_bias,
            use_scale=use_scale,
            bias_init=bias_init,
            scale_init=scale_init,
        )

        flax_key, _ = tx.iter_split(key)  # emulate init split
        variables = flax_module.init(flax_key, x)
        treex_module = treex_module.init(key, x)

        if use_bias:
            assert np.allclose(variables["params"]["bias"], treex_module.bias)

        if use_scale:
            assert np.allclose(variables["params"]["scale"], treex_module.scale)

        y_flax = flax_module.apply(variables, x)

        y_treex = treex_module(x)

        assert np.allclose(y_flax, y_treex)

        if use_bias:
            assert np.allclose(variables["params"]["bias"], treex_module.bias)

        if use_scale:
            assert np.allclose(variables["params"]["scale"], treex_module.scale)

    @hp.given(
        batch_size=st.integers(min_value=1, max_value=32),
        length=st.integers(min_value=1, max_value=32),
        channels=st.integers(min_value=1, max_value=32),
        num_groups=st.none(),
        group_size=st.just(1),
        epsilon=st.floats(min_value=0.000001, max_value=0.01),
        use_bias=st.booleans(),
        use_scale=st.booleans(),
        bias_init=st.sampled_from(INITS),
        scale_init=st.sampled_from(INITS),
    )
    @hp.settings(deadline=None, max_examples=20)
    def test_equivalence_channels(self, **kwargs):
        self._test_equivalence(**kwargs)

    @hp.given(
        batch_size=st.integers(min_value=1, max_value=32),
        length=st.integers(min_value=1, max_value=32),
        channels=st.just(32),
        num_groups=st.sampled_from([2 ** i for i in range(5)]),
        group_size=st.none(),
        epsilon=st.floats(min_value=0.000001, max_value=0.01),
        use_bias=st.booleans(),
        use_scale=st.booleans(),
        bias_init=st.sampled_from(INITS),
        scale_init=st.sampled_from(INITS),
    )
    @hp.settings(deadline=None, max_examples=20)
    def test_equivalence_num_groups(self, **kwargs):
        self._test_equivalence(**kwargs)

    @hp.given(
        batch_size=st.integers(min_value=1, max_value=32),
        length=st.integers(min_value=1, max_value=32),
        channels=st.just(32),
        num_groups=st.none(),
        group_size=st.sampled_from([2 ** i for i in range(5)]),
        epsilon=st.floats(min_value=0.000001, max_value=0.01),
        use_bias=st.booleans(),
        use_scale=st.booleans(),
        bias_init=st.sampled_from(INITS),
        scale_init=st.sampled_from(INITS),
    )
    @hp.settings(deadline=None, max_examples=20)
    def test_equivalence_group_size(self, **kwargs):
        self._test_equivalence(**kwargs)

    def test_call(self):
        x = np.random.uniform(size=(10, 32))
        module = tx.GroupNorm().init(42, x)

        y = module(x)

        assert y.shape == (10, 32)

    def test_tree(self):
        x = np.random.uniform(size=(10, 32))
        module = tx.GroupNorm().init(42, x)

        flat = jax.tree_leaves(module)

        assert len(flat) == 2

    def test_slice(self):
        x = np.random.uniform(size=(10, 32))
        module = tx.GroupNorm().init(42, x)

        flat = jax.tree_leaves(module.filter(tx.Parameter))
        assert len(flat) == 2

        flat = jax.tree_leaves(module.filter(tx.BatchStat))
        assert len(flat) == 0

        flat = jax.tree_leaves(
            module.filter(lambda field: not issubclass(field.kind, tx.TreePart))
        )
        assert len(flat) == 0

    def test_jit(self):
        x = np.random.uniform(size=(10, 32))
        module = tx.GroupNorm().init(42, x)

        @jax.jit
        def f(module, x):
            return module, module(x)

        module2, y = f(module, x)

        assert y.shape == (10, 32)
        assert all(
            np.allclose(a, b)
            for a, b in zip(
                jax.tree_leaves(module.filter(tx.Parameter)),
                jax.tree_leaves(module2.filter(tx.Parameter)),
            )
        )

    def test_eval(self):
        x = np.random.uniform(size=(10, 32))
        module = tx.GroupNorm().init(42, x)

        @jax.jit
        def f(module, x):
            module = module.eval()

            return module, module(x)

        module2, y = f(module, x)

        assert y.shape == (10, 32)
        assert all(
            np.allclose(a, b)
            for a, b in zip(
                jax.tree_leaves(module.filter(tx.Parameter)),
                jax.tree_leaves(module2.filter(tx.Parameter)),
            )
        )
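The three equivalence tests exercise the two interchangeable ways of specifying a grouping: `group_size=1` over arbitrary channel counts, and power-of-two `num_groups` or `group_size` values against a fixed `channels=32`. A quick sanity check of the arithmetic those strategies rely on (an illustration, not part of the commit):

channels = 32
for num_groups in [2 ** i for i in range(5)]:  # 1, 2, 4, 8, 16
    group_size = channels // num_groups        # 32, 16, 8, 4, 2 (implied)
    assert num_groups * group_size == channels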
4 changes: 3 additions & 1 deletion treex/nn/__init__.py
@@ -7,7 +7,7 @@
from .flax_module import FlaxModule
from .linear import Linear
from .mlp import MLP
from .norm import BatchNorm
from .norm import BatchNorm, GroupNorm, LayerNorm
from .sequential import Lambda, Sequential, sequence

try:
@@ -24,6 +24,8 @@
"Embed",
"Flatten",
"FlaxModule",
"GroupNorm",
"LayerNorm",
"Linear",
"MLP",
"Lambda",
