Skip to content

Commit

Permalink
Replaces references to jax.numpy.DeviceArray with jax.Array.\n
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 515673111
Change-Id: I87d1cbd2bb49aad7b8a4029b20d2b1bf058a573d
  • Loading branch information
hawkinsp authored and lanctot committed Mar 13, 2023
1 parent 59bfea3 commit ca2b942
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions open_spiel/python/examples/bridge_supervised_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def loss(
params: Params,
inputs: np.ndarray,
targets: np.ndarray,
) -> jnp.DeviceArray:
) -> jax.Array:
"""Cross-entropy loss."""
assert targets.dtype == np.int32
log_probs = net.apply(params, inputs)
Expand All @@ -140,7 +140,7 @@ def accuracy(
params: Params,
inputs: np.ndarray,
targets: np.ndarray,
) -> jnp.DeviceArray:
) -> jax.Array:
"""Classification accuracy."""
predictions = net.apply(params, inputs)
return jnp.mean(jnp.argmax(predictions, axis=-1) == targets)
Expand Down
4 changes: 2 additions & 2 deletions open_spiel/python/examples/hearts_supervised_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def loss(
params: Params,
inputs: np.ndarray,
targets: np.ndarray,
) -> jnp.DeviceArray:
) -> jax.Array:
"""Cross-entropy loss."""
assert targets.dtype == np.int32
log_probs = net.apply(params, inputs)
Expand All @@ -140,7 +140,7 @@ def accuracy(
params: Params,
inputs: np.ndarray,
targets: np.ndarray,
) -> jnp.DeviceArray:
) -> jax.Array:
"""Classification accuracy."""
predictions = net.apply(params, inputs)
return jnp.mean(jnp.argmax(predictions, axis=-1) == targets)
Expand Down
6 changes: 3 additions & 3 deletions open_spiel/python/examples/meta_cfr/sequential_games/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from open_spiel.python.examples.meta_cfr.sequential_games.typing import Params


def get_batched_input(input_list: List[jax.numpy.DeviceArray],
def get_batched_input(input_list: List[jax.Array],
infostate_list: List[InfostateNode],
illegal_action_list: List[List[int]], batch_size: int):
"""Returns list of function arguments extended to be consistent with batch size.
Expand Down Expand Up @@ -95,7 +95,7 @@ def filter_terminal_infostates(infostates_map: InfostateMapping):

def get_network_output(net_apply: ApplyFn, net_params: Params,
net_input: np.ndarray, illegal_actions: List[int],
key: hk.PRNGSequence) -> jax.numpy.DeviceArray:
key: hk.PRNGSequence) -> jax.Array:
"""Returns policy generated as output of model.
Args:
Expand All @@ -119,7 +119,7 @@ def get_network_output(net_apply: ApplyFn, net_params: Params,
def get_network_output_batched(
net_apply: ApplyFn, net_params: Params, net_input: np.ndarray,
all_illegal_actions: List[List[int]],
key: hk.PRNGSequence) -> List[jax.numpy.DeviceArray]:
key: hk.PRNGSequence) -> List[jax.Array]:
"""Returns policy of batched input generated as output of model.
Args:
Expand Down

0 comments on commit ca2b942

Please sign in to comment.