Skip to content

Commit

Permalink
Adds a kdtree for efficiently positioning the agents randomly in the …
Browse files Browse the repository at this point in the history
…field avoiding collision. This can also be used for checking collisions efficiently (log(num_of_objects)).
  • Loading branch information
hfbassani authored and goncamateus committed Apr 30, 2021
1 parent 6b84104 commit 4081c4f
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 8 deletions.
88 changes: 88 additions & 0 deletions envs/rc_gym/Utils/kdtree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import math


def closest_node(values, node1, node2):

if node1 is None:
return node2, node2.distance2_to(values) if node2 is not None else math.inf

if node2 is None:
return node1, node1.distance2_to(values) if node1 is not None else math.inf

node1_dist2 = node1.distance2_to(values)
node2_dist2 = node2.distance2_to(values)

if node1_dist2 < node2_dist2:
return node1, node1_dist2
else:
return node2, node2_dist2


class KDTree:
class KDTreeNode:

def __init__(self, values, left=None, right=None):
self.values = values
self.left = left
self.right = right

def insert(self, values, depth=0):
if self.values is None:
self.values = values
else:
if values[depth % len(values)] < self.values[depth % len(self.values)]:
if self.left is None:
self.left = KDTree.KDTreeNode(values)
else:
self.left.insert(values, depth+1)
else:
if self.right is None:
self.right = KDTree.KDTreeNode(values)
else:
self.right.insert(values, depth+1)

def distance2_to(self, values):
d2 = 0
for i in range(len(values)):
d2 += (values[i] - self.values[i])**2

return d2

def get_nearest(self, values, depth=0):
if self.values is None:
return None, math.inf

if self.left is None and self.right is None:
return self, self.distance2_to(values)

if values[depth % len(values)] < self.values[depth % len(self.values)]:
next_branch = self.left
other_branch = self.right
else:
next_branch = self.left
other_branch = self.right

if next_branch is not None:
other, _ = next_branch.get_nearest(values, depth+1)
closest, closest_dist2 = closest_node(values, other, self)
else:
closest, closest_dist2 = self, self.distance2_to(values)

line_dist = values[depth % len(values)] - self.values[depth % len(self.values)]

if other_branch is not None:
if closest_dist2 >= line_dist**2:
other, _ = other_branch.get_nearest(values, depth+1)
closest, closest_dist2 = closest_node(values, other, closest)

return closest, closest_dist2

def __init__(self):
self.root = KDTree.KDTreeNode(None)

def insert(self, values):
self.root.insert(values)

def get_nearest(self, values):
node, dist2 = self.root.get_nearest(values)
return node.values, math.sqrt(dist2)
24 changes: 24 additions & 0 deletions envs/rc_gym/Utils/kdtree_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from kdtree import KDTree
import unittest


class TestKDTree(unittest.TestCase):

def test_kdtree(self):

tree = KDTree()
tree.insert((2, 6))
tree.insert((3, 1))
tree.insert((8, 7))
tree.insert((10, 2))
tree.insert((13, 3))

assert tree.get_nearest((9, 4)) == ((10, 2), 2.23606797749979)
assert tree.get_nearest((4, 1.5))[0] == (3, 1)
assert tree.get_nearest((7, 8))[0] == (8, 7)
assert tree.get_nearest((11, 1))[0] == (10, 2)
assert tree.get_nearest((13, 3))[0] == (13, 3)


if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions rc_gym/Utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from rc_gym.Utils.Utils import *
from rc_gym.Utils.kdtree import KDTree
4 changes: 2 additions & 2 deletions rc_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@

register(id='SSLStaticDefenders-v0',
entry_point='rc_gym.ssl.ssl_hw_challenge.static_defenders:SSLHWStaticDefendersEnv',
kwargs={'field_type': 2},
max_episode_steps=1000
kwargs={'random_init': True, 'field_type': 2, 'n_robots_yellow': 4},
max_episode_steps=1200
)

register(id='SSLDribbling-v0',
Expand Down
15 changes: 9 additions & 6 deletions rc_gym/ssl/ssl_hw_challenge/static_defenders.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,17 @@ def in_gk_area(obj):
while in_gk_area(pos_frame.ball):
pos_frame.ball = Ball(x=x(), y=y())

d_ball_rbt = (self.field.ball_radius + self.field.rbt_radius) * 1.1
d_rbt_rbt = (self.field.rbt_radius * 2) * 1.1
radius_robot = 0.1

places = KDTree()
places.insert((pos_frame.ball.x, pos_frame.ball.y))
for i in range(len(agents)):
for j in range(i):
while same_position_ref(agents[i], agents[j], d_rbt_rbt)\
or same_position_ref(agents[i], pos_frame.ball, d_ball_rbt):
agents[i] = Robot(x=x(), y=y(), theta=theta())
pos = (x(), y())
while in_gk_area(pos) or places.get_nearest(pos)[1] < radius_robot:
pos = (x(), y())

places.insert(pos)
agents[i] = Robot(x=pos[0], y=pos[1], theta=theta())

for i in range(self.n_robots_yellow):
pos_frame.robots_yellow[i] = agents[i]
Expand Down

0 comments on commit 4081c4f

Please sign in to comment.