From 4081c4fcc75a55d640cfb4c4e74db6d54faa3985 Mon Sep 17 00:00:00 2001 From: hans Date: Mon, 26 Apr 2021 18:32:28 -0400 Subject: [PATCH] Adds a kdtree for efficiently positioning the agents randomly in the field avoiding collision. This can also be used for checking collisions efficiently (log(num_of_objects)). --- envs/rc_gym/Utils/kdtree.py | 88 +++++++++++++++++++ envs/rc_gym/Utils/kdtree_test.py | 24 +++++ rc_gym/Utils/__init__.py | 1 + rc_gym/__init__.py | 4 +- .../ssl/ssl_hw_challenge/static_defenders.py | 15 ++-- 5 files changed, 124 insertions(+), 8 deletions(-) create mode 100644 envs/rc_gym/Utils/kdtree.py create mode 100644 envs/rc_gym/Utils/kdtree_test.py diff --git a/envs/rc_gym/Utils/kdtree.py b/envs/rc_gym/Utils/kdtree.py new file mode 100644 index 00000000..c6e1a0ca --- /dev/null +++ b/envs/rc_gym/Utils/kdtree.py @@ -0,0 +1,88 @@ +import math + + +def closest_node(values, node1, node2): + + if node1 is None: + return node2, node2.distance2_to(values) if node2 is not None else math.inf + + if node2 is None: + return node1, node1.distance2_to(values) if node1 is not None else math.inf + + node1_dist2 = node1.distance2_to(values) + node2_dist2 = node2.distance2_to(values) + + if node1_dist2 < node2_dist2: + return node1, node1_dist2 + else: + return node2, node2_dist2 + + +class KDTree: + class KDTreeNode: + + def __init__(self, values, left=None, right=None): + self.values = values + self.left = left + self.right = right + + def insert(self, values, depth=0): + if self.values is None: + self.values = values + else: + if values[depth % len(values)] < self.values[depth % len(self.values)]: + if self.left is None: + self.left = KDTree.KDTreeNode(values) + else: + self.left.insert(values, depth+1) + else: + if self.right is None: + self.right = KDTree.KDTreeNode(values) + else: + self.right.insert(values, depth+1) + + def distance2_to(self, values): + d2 = 0 + for i in range(len(values)): + d2 += (values[i] - self.values[i])**2 + + return d2 + + def get_nearest(self, values, depth=0): + if self.values is None: + return None, math.inf + + if self.left is None and self.right is None: + return self, self.distance2_to(values) + + if values[depth % len(values)] < self.values[depth % len(self.values)]: + next_branch = self.left + other_branch = self.right + else: + next_branch = self.left + other_branch = self.right + + if next_branch is not None: + other, _ = next_branch.get_nearest(values, depth+1) + closest, closest_dist2 = closest_node(values, other, self) + else: + closest, closest_dist2 = self, self.distance2_to(values) + + line_dist = values[depth % len(values)] - self.values[depth % len(self.values)] + + if other_branch is not None: + if closest_dist2 >= line_dist**2: + other, _ = other_branch.get_nearest(values, depth+1) + closest, closest_dist2 = closest_node(values, other, closest) + + return closest, closest_dist2 + + def __init__(self): + self.root = KDTree.KDTreeNode(None) + + def insert(self, values): + self.root.insert(values) + + def get_nearest(self, values): + node, dist2 = self.root.get_nearest(values) + return node.values, math.sqrt(dist2) diff --git a/envs/rc_gym/Utils/kdtree_test.py b/envs/rc_gym/Utils/kdtree_test.py new file mode 100644 index 00000000..44e27496 --- /dev/null +++ b/envs/rc_gym/Utils/kdtree_test.py @@ -0,0 +1,24 @@ +from kdtree import KDTree +import unittest + + +class TestKDTree(unittest.TestCase): + + def test_kdtree(self): + + tree = KDTree() + tree.insert((2, 6)) + tree.insert((3, 1)) + tree.insert((8, 7)) + tree.insert((10, 2)) + tree.insert((13, 3)) + + assert tree.get_nearest((9, 4)) == ((10, 2), 2.23606797749979) + assert tree.get_nearest((4, 1.5))[0] == (3, 1) + assert tree.get_nearest((7, 8))[0] == (8, 7) + assert tree.get_nearest((11, 1))[0] == (10, 2) + assert tree.get_nearest((13, 3))[0] == (13, 3) + + +if __name__ == '__main__': + unittest.main() diff --git a/rc_gym/Utils/__init__.py b/rc_gym/Utils/__init__.py index ca3eea3b..00b2ea11 100644 --- a/rc_gym/Utils/__init__.py +++ b/rc_gym/Utils/__init__.py @@ -1 +1,2 @@ from rc_gym.Utils.Utils import * +from rc_gym.Utils.kdtree import KDTree diff --git a/rc_gym/__init__.py b/rc_gym/__init__.py index 2fc4bec9..885fc896 100644 --- a/rc_gym/__init__.py +++ b/rc_gym/__init__.py @@ -58,8 +58,8 @@ register(id='SSLStaticDefenders-v0', entry_point='rc_gym.ssl.ssl_hw_challenge.static_defenders:SSLHWStaticDefendersEnv', - kwargs={'field_type': 2}, - max_episode_steps=1000 + kwargs={'random_init': True, 'field_type': 2, 'n_robots_yellow': 4}, + max_episode_steps=1200 ) register(id='SSLDribbling-v0', diff --git a/rc_gym/ssl/ssl_hw_challenge/static_defenders.py b/rc_gym/ssl/ssl_hw_challenge/static_defenders.py index 73a5ec6c..f15da20d 100644 --- a/rc_gym/ssl/ssl_hw_challenge/static_defenders.py +++ b/rc_gym/ssl/ssl_hw_challenge/static_defenders.py @@ -234,14 +234,17 @@ def in_gk_area(obj): while in_gk_area(pos_frame.ball): pos_frame.ball = Ball(x=x(), y=y()) - d_ball_rbt = (self.field.ball_radius + self.field.rbt_radius) * 1.1 - d_rbt_rbt = (self.field.rbt_radius * 2) * 1.1 + radius_robot = 0.1 + places = KDTree() + places.insert((pos_frame.ball.x, pos_frame.ball.y)) for i in range(len(agents)): - for j in range(i): - while same_position_ref(agents[i], agents[j], d_rbt_rbt)\ - or same_position_ref(agents[i], pos_frame.ball, d_ball_rbt): - agents[i] = Robot(x=x(), y=y(), theta=theta()) + pos = (x(), y()) + while in_gk_area(pos) or places.get_nearest(pos)[1] < radius_robot: + pos = (x(), y()) + + places.insert(pos) + agents[i] = Robot(x=pos[0], y=pos[1], theta=theta()) for i in range(self.n_robots_yellow): pos_frame.robots_yellow[i] = agents[i]