From e88d5cf33d7927ffcbc250a7c36aa220b4e3d5a8 Mon Sep 17 00:00:00 2001
From: lizun
Date: Mon, 22 Aug 2022 11:31:31 -0600
Subject: [PATCH] add python ismcts
---
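Notes: a minimal self-play sketch of the new bot (illustrative only, not part
of the commit; assumes the modules added below). The included
ismcts_agent_test.py shows the rl_environment-based usage instead.

    import numpy as np
    import pyspiel
    from open_spiel.python.algorithms import ismcts
    from open_spiel.python.algorithms import mcts

    game = pyspiel.load_game("kuhn_poker")
    bot = ismcts.ISMCTSBot(
        game, mcts.RandomRolloutEvaluator(), uct_c=1.5, max_simulations=100)
    state = game.new_initial_state()
    while not state.is_terminal():
      if state.is_chance_node():
        # IS-MCTS searches only at decision nodes; sample chance here.
        actions, probs = zip(*state.chance_outcomes())
        state.apply_action(np.random.choice(actions, p=probs))
      else:
        state.apply_action(bot.step(state))
    print(state.returns())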
open_spiel/python/CMakeLists.txt | 1 +
open_spiel/python/algorithms/ismcts.py | 326 ++++++++++++++++++
.../python/algorithms/ismcts_agent_test.py | 52 +++
3 files changed, 379 insertions(+)
create mode 100644 open_spiel/python/algorithms/ismcts.py
create mode 100644 open_spiel/python/algorithms/ismcts_agent_test.py
diff --git a/open_spiel/python/CMakeLists.txt b/open_spiel/python/CMakeLists.txt
index 6576117d02..541df034ba 100644
--- a/open_spiel/python/CMakeLists.txt
+++ b/open_spiel/python/CMakeLists.txt
@@ -186,6 +186,7 @@ set(PYTHON_TESTS ${PYTHON_TESTS}
algorithms/gambit_test.py
algorithms/generate_playthrough_test.py
algorithms/get_all_states_test.py
+ algorithms/ismcts_agent_test.py
algorithms/mcts_agent_test.py
algorithms/mcts_test.py
algorithms/minimax_test.py
diff --git a/open_spiel/python/algorithms/ismcts.py b/open_spiel/python/algorithms/ismcts.py
new file mode 100644
index 0000000000..6ace557e3d
--- /dev/null
+++ b/open_spiel/python/algorithms/ismcts.py
@@ -0,0 +1,326 @@
+# Copyright 2019 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import copy
+from enum import Enum
+
+import numpy as np
+
+import pyspiel
+
+# Resample a fresh world state for every simulation.
+UNLIMITED_NUM_WORLD_SAMPLES = -1
+# Sentinel visit count for nodes that have been created but not yet expanded.
+UNEXPANDED_VISIT_COUNT = -1
+# Selection values within this tolerance of the maximum are treated as ties.
+TIE_TOLERANCE = 1e-5
+
+
+class ISMCTSFinalPolicyType(Enum):
+ """A enumeration class for final ISMCTS policy type."""
+ NORMALIZED_VISITED_COUNT = 1
+ MAX_VISIT_COUNT = 2
+ MAX_VALUE = 3
+
+
+class ChildSelectionPolicy(Enum):
+ """A enumeration class for children selection in ISMCTS."""
+ UCT = 1
+ PUCT = 2
+
+
+class ChildInfo(object):
+ """Child node information for the search tree."""
+
+ def __init__(self, visits, return_sum, prior):
+ self.visits = visits
+ self.return_sum = return_sum
+ self.prior = prior
+
+  def value(self):
+    # Mean return observed across all visits through this child.
+    return self.return_sum / self.visits
+
+
+class ISMCTSNode(object):
+ """Node data structure for the search tree."""
+
+ def __init__(self):
+ self.child_info = {}
+ self.total_visits = 0
+ self.prior_map = {}
+
+
+class ISMCTSBot(pyspiel.Bot):
+ """Adapted from the C++ implementation."""
+
+ def __init__(self,
+ game,
+ evaluator,
+ uct_c,
+ max_simulations,
+ max_world_samples=UNLIMITED_NUM_WORLD_SAMPLES,
+ random_state=None,
+ final_policy_type=ISMCTSFinalPolicyType.MAX_VISIT_COUNT,
+ use_observation_string=False,
+ allow_inconsistent_action_sets=False,
+ child_selection_policy=ChildSelectionPolicy.PUCT):
+
+ pyspiel.Bot.__init__(self)
+ self._game = game
+ self._evaluator = evaluator
+ self._uct_c = uct_c
+ self._max_simulations = max_simulations
+ self._max_world_samples = max_world_samples
+ self._final_policy_type = final_policy_type
+ self._use_observation_string = use_observation_string
+ self._allow_inconsistent_action_sets = allow_inconsistent_action_sets
+ self._nodes = {}
+ self._node_pool = []
+ self._root_samples = []
+ self._random_state = random_state or np.random.RandomState()
+ self._child_selection_policy = child_selection_policy
+ self._resampler_cb = None
+
+ def random_number(self):
+ return self._random_state.uniform()
+
+ def reset(self):
+ self._nodes = {}
+ self._node_pool = []
+ self._root_samples = []
+
+ def get_state_key(self, state):
+ if self._use_observation_string:
+ return state.current_player(), state.observation_string()
+ else:
+ return state.current_player(), state.information_state_string()
+
+ def run_search(self, state):
+ self.reset()
+    game_type = state.get_game().get_type()
+    assert game_type.dynamics == pyspiel.GameType.Dynamics.SEQUENTIAL
+    assert (game_type.information ==
+            pyspiel.GameType.Information.IMPERFECT_INFORMATION)
+
+ legal_actions = state.legal_actions()
+ if len(legal_actions) == 1:
+ return [(legal_actions[0], 1.0)]
+
+ self._root_node = self.create_new_node(state)
+
+ assert self._root_node
+
+ root_infostate_key = self.get_state_key(state)
+
+    for _ in range(self._max_simulations):
+      # Sample a world state consistent with the root information state
+      # (a determinization), then run one simulation from it.
+      sampled_root_state = self.sample_root_state(state)
+      assert root_infostate_key == self.get_state_key(sampled_root_state)
+      assert sampled_root_state
+      self.run_simulation(sampled_root_state)
+
+    if self._allow_inconsistent_action_sets:
+      # Legal action sets can differ across the world states sampled for one
+      # information set; filter the root statistics down to the true set.
+      legal_actions = state.legal_actions()
+      temp_node = self.filter_illegals(self._root_node, legal_actions)
+ assert temp_node.total_visits > 0
+ return self.get_final_policy(state, temp_node)
+ else:
+ return self.get_final_policy(state, self._root_node)
+
+ def step(self, state):
+ action_list, prob_list = zip(*self.run_search(state))
+ return self._random_state.choice(action_list, p=prob_list)
+
+ def get_policy(self, state):
+ return self.run_search(state)
+
+ def step_with_policy(self, state):
+ policy = self.get_policy(state)
+ action_list, prob_list = zip(*policy)
+ sampled_action = self._random_state.choice(action_list, p=prob_list)
+ return policy, sampled_action
+
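+  # The returned final policy is either the normalized visit counts, or a
+  # uniform distribution over the children of maximal visit count (or of
+  # maximal value), depending on final_policy_type.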
+ def get_final_policy(self, state, node):
+ assert node
+ if self._final_policy_type == ISMCTSFinalPolicyType.NORMALIZED_VISITED_COUNT:
+ assert node.total_visits > 0
+ total_visits = node.total_visits
+ policy = [(action, child.visits/total_visits)
+ for action, child in node.child_info.items()]
+ elif self._final_policy_type == ISMCTSFinalPolicyType.MAX_VISIT_COUNT:
+ assert node.total_visits > 0
+ max_visits = -float('inf')
+ count = 0
+ for action, child in node.child_info.items():
+ if child.visits == max_visits:
+ count += 1
+ elif child.visits > max_visits:
+ max_visits = child.visits
+ count = 1
+ policy = [(action, 1./count if child.visits == max_visits else 0.0)
+ for action, child in node.child_info.items()]
+ elif self._final_policy_type == ISMCTSFinalPolicyType.MAX_VALUE:
+ assert node.total_visits > 0
+ max_value = -float('inf')
+ count = 0
+ for action, child in node.child_info.items():
+ if child.value() == max_value:
+ count += 1
+ elif child.value() > max_value:
+ max_value = child.value()
+ count = 1
+      policy = [(action, 1./count if child.value() == max_value else 0.0)
+                for action, child in node.child_info.items()]
+    else:
+      raise pyspiel.SpielError("Unsupported final policy type.")
+
+    policy_size = len(policy)
+    legal_actions = state.legal_actions()
+    if policy_size < len(legal_actions):
+      # Pad the policy with zero probability for legal actions that were
+      # never added to the tree, so it covers the full legal action set.
+      for action in legal_actions:
+        if action not in node.child_info:
+          policy.append((action, 0.0))
+    return policy
+
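+  # Root determinizations: with UNLIMITED_NUM_WORLD_SAMPLES, a fresh world
+  # state is resampled for every simulation; otherwise up to max_world_samples
+  # states are drawn once, then reused uniformly at random thereafter.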
+ def sample_root_state(self, state):
+ if self._max_world_samples == UNLIMITED_NUM_WORLD_SAMPLES:
+ return self.resample_from_infostate(state)
+ elif len(self._root_samples) < self._max_world_samples:
+ self._root_samples.append(self.resample_from_infostate(state))
+ return self._root_samples[-1].clone()
+ elif len(self._root_samples) == self._max_world_samples:
+ idx = self._random_state.randint(len(self._root_samples))
+ return self._root_samples[idx].clone()
+ else:
+ raise pyspiel.SpielError(
+ "Case not handled (badly set max_world_samples..?)")
+
+  def resample_from_infostate(self, state):
+    if self._resampler_cb:
+      return self._resampler_cb(state, state.current_player())
+    else:
+      return state.resample_from_infostate(
+          state.current_player(), pyspiel.UniformProbabilitySampler(0., 1.))
+
+ def create_new_node(self, state):
+ infostate_key = self.get_state_key(state)
+ self._node_pool.append(ISMCTSNode())
+ node = self._node_pool[-1]
+ self._nodes[infostate_key] = node
+ node.total_visits = UNEXPANDED_VISIT_COUNT
+ return node
+
+ def set_resampler(self, cb):
+ self._resampler_cb = cb
+
+  def lookup_node(self, state):
+    return self._nodes.get(self.get_state_key(state))
+
+ def lookup_or_create_node(self, state):
+ node = self.lookup_node(state)
+ if node:
+ return node
+ return self.create_new_node(state)
+
+  def filter_illegals(self, node, legal_actions):
+ new_node = copy.deepcopy(node)
+ for action, child in node.child_info.items():
+ if action not in legal_actions:
+ new_node.total_visits -= child.visits
+ del new_node.child_info[action]
+ return new_node
+
+ def expand_if_necessary(self, node, action):
+ if action not in node.child_info:
+ node.child_info[action] = ChildInfo(0.0, 0.0, node.prior_map[action])
+
+ def select_action_tree_policy(self, node, legal_actions):
+ if self._allow_inconsistent_action_sets:
+ temp_node = self.filter_illegals(node, legal_actions)
+      if temp_node.total_visits == 0:
+        # None of the filtered children has been visited yet; fall back to a
+        # uniform random legal action.
+        action = legal_actions[self._random_state.randint(
+            len(legal_actions))]
+        self.expand_if_necessary(node, action)
+        return action
+ else:
+ return self.select_action(temp_node)
+ else:
+ return self.select_action(node)
+
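+  # Child selection value: the child's mean return plus an exploration bonus.
+  #   UCT:  Q(a) + uct_c * sqrt(log(N) / n(a))
+  #   PUCT: Q(a) + uct_c * prior(a) * sqrt(N) / (1 + n(a))
+  # where N is the node's total visit count and n(a) the child's visit count.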
+ def select_action(self, node):
+ candidates = []
+ max_value = -float('inf')
+ for action, child in node.child_info.items():
+ assert child.visits > 0
+
+ action_value = child.value()
+ if self._child_selection_policy == ChildSelectionPolicy.UCT:
+ action_value += self._uct_c * \
+ np.sqrt(np.log(node.total_visits)/child.visits)
+ elif self._child_selection_policy == ChildSelectionPolicy.PUCT:
+ action_value += self._uct_c * child.prior * \
+ np.sqrt(node.total_visits)/(1 + child.visits)
+ else:
+ raise pyspiel.SpielError("Child selection policy unrecognized.")
+ if action_value > max_value + TIE_TOLERANCE:
+ candidates = [action]
+ max_value = action_value
+      elif action_value > max_value - TIE_TOLERANCE:
+        candidates.append(action)
+        max_value = action_value
+
+ assert len(candidates) >= 1
+ return candidates[self._random_state.randint(len(candidates))]
+
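+  # Expansion check: return one unexpanded legal action chosen uniformly at
+  # random, or INVALID_ACTION when every legal action already has a child.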
+  def check_expand(self, node, legal_actions):
+    if (not self._allow_inconsistent_action_sets and
+        len(node.child_info) == len(legal_actions)):
+      return pyspiel.INVALID_ACTION
+    legal_actions_copy = list(legal_actions)
+    self._random_state.shuffle(legal_actions_copy)
+ for action in legal_actions_copy:
+ if action not in node.child_info:
+ return action
+ return pyspiel.INVALID_ACTION
+
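+  # One simulation: descend the sampled world state recursively. Chance nodes
+  # are sampled; the first visit to an information state backs up the
+  # evaluator's value estimate; otherwise at most one new child is expanded
+  # before falling back to the tree policy.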
+ def run_simulation(self, state):
+ if state.is_terminal():
+ return state.returns()
+ elif state.is_chance_node():
+ action_list, prob_list = zip(*state.chance_outcomes())
+ chance_action = self._random_state.choice(action_list, p=prob_list)
+ state.apply_action(chance_action)
+ return self.run_simulation(state)
+ legal_actions = state.legal_actions()
+ cur_player = state.current_player()
+ node = self.lookup_or_create_node(state)
+
+ assert node
+
+    if node.total_visits == UNEXPANDED_VISIT_COUNT:
+      # First visit to this information state: record the evaluator's prior
+      # and back up its value estimate instead of searching deeper.
+      node.total_visits = 0
+      for action, prob in self._evaluator.prior(state):
+        node.prior_map[action] = prob
+      return self._evaluator.evaluate(state)
+    else:
+      # Expand at most one new child per simulation; once the node is fully
+      # expanded, choose among existing children with the tree policy.
+      chosen_action = self.check_expand(node, legal_actions)
+      if chosen_action != pyspiel.INVALID_ACTION:
+        self.expand_if_necessary(node, chosen_action)
+      else:
+        chosen_action = self.select_action_tree_policy(node, legal_actions)
+
+ assert chosen_action != pyspiel.INVALID_ACTION
+
+ node.total_visits += 1
+ node.child_info[chosen_action].visits += 1
+ state.apply_action(chosen_action)
+ returns = self.run_simulation(state)
+ node.child_info[chosen_action].return_sum += returns[cur_player]
+ return returns
diff --git a/open_spiel/python/algorithms/ismcts_agent_test.py b/open_spiel/python/algorithms/ismcts_agent_test.py
new file mode 100644
index 0000000000..67d3f43d78
--- /dev/null
+++ b/open_spiel/python/algorithms/ismcts_agent_test.py
@@ -0,0 +1,52 @@
+# Copyright 2022 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test the IS-MCTS Agent."""
+
+from absl.testing import absltest
+from open_spiel.python import rl_environment
+from open_spiel.python.algorithms import ismcts
+from open_spiel.python.algorithms import mcts
+from open_spiel.python.algorithms import mcts_agent
+
+
+class ISMCTSAgentTest(absltest.TestCase):
+
+  def test_kuhn_poker_episode(self):
+ env = rl_environment.Environment("kuhn_poker", include_full_state=True)
+ num_players = env.num_players
+ num_actions = env.action_spec()["num_actions"]
+
+    # Create the IS-MCTS bot. Both agents can share the same bot in this case
+    # since no state is kept between searches. See mcts.py for more info about
+    # the arguments.
+    ismcts_bot = ismcts.ISMCTSBot(
+        game=env.game,
+        uct_c=1.5,
+        max_simulations=100,
+        evaluator=mcts.RandomRolloutEvaluator())
+
+ agents = [
+ mcts_agent.MCTSAgent(player_id=idx, num_actions=num_actions,
+ mcts_bot=ismcts_bot)
+ for idx in range(num_players)
+ ]
+
+ time_step = env.reset()
+ while not time_step.last():
+ player_id = time_step.observations["current_player"]
+ agent_output = agents[player_id].step(time_step)
+ time_step = env.step([agent_output.action])
+ for agent in agents:
+ agent.step(time_step)
+
+
+if __name__ == "__main__":
+ absltest.main()