Skip to content

Commit

Permalink
Internal fix for NumPy 1.24 test failures.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 503544399
Change-Id: Icf9acb48a940870a8479c89b3850f3dc01d49382
  • Loading branch information
hawkinsp authored and lanctot committed Jan 23, 2023
1 parent 87b4cce commit bbe0007
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Dataset:
def __init__(self, train_dataset: List[Tuple[List[List[float]],
InfostateNode]],
batch_size: int):
self._train_dataset = np.array(train_dataset)
self._train_dataset = np.array(train_dataset, dtype=object)
self._size = self._train_dataset.shape[0]
self._batch_size = batch_size

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def training_optimizer(self):
for _ in range(FLAGS.num_batches):
batch = next(data_loader)
cfvalues, infoset = zip(*batch)
cfvalues = np.array(list(cfvalues))
cfvalues = np.array(list(cfvalues), dtype=object)
cfvalues = utils.mask(cfvalues, infoset, len(self._all_actions),
FLAGS.batch_size)
self.optimize_infoset(cfvalues, infoset, self._infostate_map,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def get_batched_input(input_list: List[jax.numpy.DeviceArray],
1) - len(input_list)
idx_sample = np.random.choice(len(input_list), items_to_sample)
input_zip = np.array(
list(zip(input_list, infostate_list, illegal_action_list)))
list(zip(input_list, infostate_list, illegal_action_list)),
dtype=object)
input_lst_sample = input_zip[idx_sample]
input_sample, infostate_sample, illegal_action_sample = zip(*input_lst_sample)

Expand Down

0 comments on commit bbe0007

Please sign in to comment.