Skip to content

Commit

Permalink
Add batch_shape argument to optax.tree_utils.tree_random_like.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Dec 26, 2024
1 parent 1e08bcc commit 8fa173b
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion optax/tree_utils/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def tree_random_like(
[chex.PRNGKey, chex.Shape, chex.ArrayDType], chex.Array
] = jax.random.normal,
dtype: Optional[chex.ArrayDType] = None,
batch_shape: tuple[int, ...] = (),
) -> chex.ArrayTree:
"""Create tree with random entries of the same shape as target tree.
Expand All @@ -54,6 +55,8 @@ def tree_random_like(
sampler: the noise sampling function, by default ``jax.random.normal``.
dtype: the desired dtype for the random numbers, passed to ``sampler``. If
None, the dtype of the target tree is used if possible.
batch_shape: additional leading batch dimensions to insert before the shape
of each leaf, by default ``()``.
Returns:
a random tree with the same structure as ``target_tree``, whose leaves have
Expand All @@ -69,7 +72,11 @@ def tree_random_like(
"""
keys_tree = tree_split_key_like(rng_key, target_tree)
return jax.tree.map(
lambda leaf, key: sampler(key, leaf.shape, dtype or leaf.dtype),
lambda leaf, key: sampler(
key,
batch_shape + leaf.shape,
dtype or leaf.dtype,
),
target_tree,
keys_tree,
)

0 comments on commit 8fa173b

Please sign in to comment.