Skip to content

Commit

Permalink
fix: role manager with matching_func
Browse files Browse the repository at this point in the history
Signed-off-by: Andreas Bichinger <andreas.bichinger@gmail.com>
  • Loading branch information
abichinger committed Jan 26, 2021
1 parent 5e16bff commit 8079cda
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 16 deletions.
47 changes: 31 additions & 16 deletions casbin/rbac/default_role_manager/role_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ def create_role(self, name):
if name not in self.all_roles.keys():
self.all_roles[name] = Role(name)

if self.matching_func is not None:
for key, role in self.all_roles.items():
if self.matching_func(name, key) and name != key:
self.all_roles[name].add_role(role)

return self.all_roles[name]

def clear(self):
Expand All @@ -50,6 +45,17 @@ def add_link(self, name1, name2, *domain):
role2 = self.create_role(name2)
role1.add_role(role2)

if self.matching_func is not None:
for key, role in self.all_roles.items():
if self.matching_func(key, name1) and name1 != key:
self.all_roles[key].add_role(role1)
if self.matching_func(key, name2) and name2 != key:
self.all_roles[name2].add_role(role)
if self.matching_func(name1, key) and name1 != key:
self.all_roles[key].add_role(role1)
if self.matching_func(name2, key) and name2 != key:
self.all_roles[name2].add_role(role)

def delete_link(self, name1, name2, *domain):
if len(domain) == 1:
name1 = domain[0] + "::" + name1
Expand Down Expand Up @@ -77,9 +83,14 @@ def has_link(self, name1, name2, *domain):
if not self.has_role(name1) or not self.has_role(name2):
return False

role1 = self.create_role(name1)

return role1.has_role(name2, self.max_hierarchy_level)
if self.matching_func is None:
role1 = self.create_role(name1)
return role1.has_role(name2, self.max_hierarchy_level)
else:
for key, role in self.all_roles.items():
if self.matching_func(name1, key) and role.has_role(name2, self.max_hierarchy_level, self.matching_func):
return True
return False

def get_roles(self, name, *domain):
"""
Expand Down Expand Up @@ -158,23 +169,27 @@ def delete_role(self, role):
self.roles.remove(rr)
return

def has_role(self, name, hierarchy_level):
if name == self.name:
def has_role(self, name, hierarchy_level, matching_func=None):
if self.has_direct_role(name, matching_func):
return True
if hierarchy_level <= 0:
return False

for role in self.roles:
if role.has_role(name, hierarchy_level - 1):
if role.has_role(name, hierarchy_level - 1, matching_func):
return True

return False

def has_direct_role(self, name):
for role in self.roles:
if role.name == name:
return True

def has_direct_role(self, name, matching_func=None):
if matching_func is None:
for role in self.roles:
if role.name == name:
return True
else:
for role in self.roles:
if matching_func(name, role.name):
return True
return False

def to_string(self):
Expand Down
Empty file added tests/rbac/__init__.py
Empty file.
234 changes: 234 additions & 0 deletions tests/rbac/test_role_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
from unittest import TestCase
from casbin.rbac import default_role_manager
from casbin.util import regex_match_func
import time
from concurrent.futures import ThreadPoolExecutor

def get_role_manager():
return default_role_manager.RoleManager(max_hierarchy_level=10)

class TestDefaultRoleManager(TestCase):

def test_role(self):
rm = get_role_manager()
rm.add_link("u1", "g1")
rm.add_link("u2", "g1")
rm.add_link("u3", "g2")
rm.add_link("u4", "g2")
rm.add_link("u4", "g3")
rm.add_link("g1", "g3")

# Current role inheritance tree:
# g3 g2
# / \ / \
# g1 u4 u3
# / \
# u1 u2

self.assertTrue(rm.has_link("u1", "g1"))
self.assertFalse(rm.has_link("u1", "g2"))
self.assertTrue(rm.has_link("u1", "g3"))
self.assertTrue(rm.has_link("u2", "g1"))
self.assertFalse(rm.has_link("u2", "g2"))
self.assertTrue(rm.has_link("u2", "g3"))
self.assertFalse(rm.has_link("u3", "g1"))
self.assertTrue(rm.has_link("u3", "g2"))
self.assertFalse(rm.has_link("u3", "g3"))
self.assertFalse(rm.has_link("u4", "g1"))
self.assertTrue(rm.has_link("u4", "g2"))
self.assertTrue(rm.has_link("u4", "g3"))

self.assertCountEqual(rm.get_roles("u1"), ["g1"])
self.assertCountEqual(rm.get_roles("u2"), ["g1"])
self.assertCountEqual(rm.get_roles("u3"), ["g2"])
self.assertCountEqual(rm.get_roles("u4"), ["g2", "g3"])
self.assertCountEqual(rm.get_roles("g1"), ["g3"])
self.assertCountEqual(rm.get_roles("g2"), [])
self.assertCountEqual(rm.get_roles("g3"), [])

rm.delete_link("g1", "g3")
rm.delete_link("u4", "g2")

# Current role inheritance tree after deleting the links:
# g3 g2
# \ \
# g1 u4 u3
# / \
# u1 u2

self.assertTrue(rm.has_link("u1", "g1"))
self.assertFalse(rm.has_link("u1", "g2"))
self.assertFalse(rm.has_link("u1", "g3"))
self.assertTrue(rm.has_link("u2", "g1"))
self.assertFalse(rm.has_link("u2", "g2"))
self.assertFalse(rm.has_link("u2", "g3"))
self.assertFalse(rm.has_link("u3", "g1"))
self.assertTrue(rm.has_link("u3", "g2"))
self.assertFalse(rm.has_link("u3", "g3"))
self.assertFalse(rm.has_link("u4", "g1"))
self.assertFalse(rm.has_link("u4", "g2"))
self.assertTrue(rm.has_link("u4", "g3"))

self.assertCountEqual(rm.get_roles("u1"), ["g1"])
self.assertCountEqual(rm.get_roles("u2"), ["g1"])
self.assertCountEqual(rm.get_roles("u3"), ["g2"])
self.assertCountEqual(rm.get_roles("u4"), ["g3"])
self.assertCountEqual(rm.get_roles("g1"), [])
self.assertCountEqual(rm.get_roles("g2"), [])
self.assertCountEqual(rm.get_roles("g3"), [])

def test_domain_role(self):
rm = get_role_manager()
rm.add_link("u1", "g1", "domain1")
rm.add_link("u2", "g1", "domain1")
rm.add_link("u3", "admin", "domain2")
rm.add_link("u4", "admin", "domain2")
rm.add_link("u4", "admin", "domain1")
rm.add_link("g1", "admin", "domain1")

# Current role inheritance tree:
# domain1:admin domain2:admin
# / \ / \
# domain1:g1 u4 u3
# / \
# u1 u2

self.assertTrue(rm.has_link("u1", "g1", "domain1"))
self.assertFalse(rm.has_link("u1", "g1", "domain2"))
self.assertTrue(rm.has_link("u1", "admin", "domain1"))
self.assertFalse(rm.has_link("u1", "admin", "domain2"))

self.assertTrue(rm.has_link("u2", "g1", "domain1"))
self.assertFalse(rm.has_link("u2", "g1", "domain2"))
self.assertTrue(rm.has_link("u2", "admin", "domain1"))
self.assertFalse(rm.has_link("u2", "admin", "domain2"))

self.assertFalse(rm.has_link("u3", "g1", "domain1"))
self.assertFalse(rm.has_link("u3", "g1", "domain2"))
self.assertFalse(rm.has_link("u3", "admin", "domain1"))
self.assertTrue(rm.has_link("u3", "admin", "domain2"))

self.assertFalse(rm.has_link("u4", "g1", "domain1"))
self.assertFalse(rm.has_link("u4", "g1", "domain2"))
self.assertTrue(rm.has_link("u4", "admin", "domain1"))
self.assertTrue(rm.has_link("u4", "admin", "domain2"))

def test_clear(self):
rm = get_role_manager()
rm.add_link("u1", "g1")
rm.add_link("u2", "g1")
rm.add_link("u3", "g2")
rm.add_link("u4", "g2")
rm.add_link("u4", "g3")
rm.add_link("g1", "g3")

# Current role inheritance tree:
# g3 g2
# / \ / \
# g1 u4 u3
# / \
# u1 u2

rm.clear()

# All data is cleared.
# No role inheritance now.

self.assertFalse(rm.has_link("u1", "g1"))
self.assertFalse(rm.has_link("u1", "g2"))
self.assertFalse(rm.has_link("u1", "g3"))
self.assertFalse(rm.has_link("u2", "g1"))
self.assertFalse(rm.has_link("u2", "g2"))
self.assertFalse(rm.has_link("u2", "g3"))
self.assertFalse(rm.has_link("u3", "g1"))
self.assertFalse(rm.has_link("u3", "g2"))
self.assertFalse(rm.has_link("u3", "g3"))
self.assertFalse(rm.has_link("u4", "g1"))
self.assertFalse(rm.has_link("u4", "g2"))
self.assertFalse(rm.has_link("u4", "g3"))

def test_matching_func(self):
rm = get_role_manager()
rm.add_matching_func(regex_match_func)

rm.add_link("u1", "g1")
rm.add_link("u3", "g2")
rm.add_link("u3", "g3")
rm.add_link(r"u\d+", "g2")

self.assertTrue(rm.has_link("u1", "g1"))
self.assertTrue(rm.has_link("u1", "g2"))
self.assertFalse(rm.has_link("u1", "g3"))

self.assertFalse(rm.has_link("u2", "g1"))
self.assertTrue(rm.has_link("u2", "g2"))
self.assertFalse(rm.has_link("u2", "g3"))

self.assertFalse(rm.has_link("u3", "g1"))
self.assertTrue(rm.has_link("u3", "g2"))
self.assertTrue(rm.has_link("u3", "g3"))

def test_one_to_many(self):
rm = get_role_manager()
rm.add_matching_func(regex_match_func)

rm.add_link("u1", r"g\d+")
self.assertTrue(rm.has_link("u1", "g1"))
self.assertTrue(rm.has_link("u1", "g2"))
self.assertFalse(rm.has_link("u2", "g1"))
self.assertFalse(rm.has_link("u2", "g2"))

def test_many_to_one(self):
rm = get_role_manager()
rm.add_matching_func(regex_match_func)

rm.add_link(r"u\d+", "g1")
self.assertTrue(rm.has_link("u1", "g1"))
self.assertFalse(rm.has_link("u1", "g2"))
self.assertTrue(rm.has_link("u2", "g1"))
self.assertFalse(rm.has_link("u2", "g2"))

def test_matching_func_order(self):
rm = get_role_manager()
rm.add_matching_func(regex_match_func)

rm.add_link(r"g\d+", "root")
rm.add_link("u1", "g1")
self.assertTrue(rm.has_link("u1", "root"))

rm.clear()

rm.add_link("u1", "g1")
rm.add_link(r"g\d+", "root")
self.assertTrue(rm.has_link("u1", "root"))

rm.clear()

rm.add_link("u1", r"g\d+")
rm.add_link("g1", "root")
self.assertTrue(rm.has_link("u1", "root"))

rm.clear()

rm.add_link("g1", "root")
rm.add_link("u1", r"g\d+")
self.assertTrue(rm.has_link("u1", "root"))

def test_concurrent_has_link_with_matching_func(self):

def matching_func(*args):
time.sleep(0.01)
return regex_match_func(*args)

rm = get_role_manager()
rm.add_matching_func(matching_func)
rm.add_link(r"u\d+", "users")

def test_has_link(role):
return rm.has_link(role, "users")

executor = ThreadPoolExecutor(10)
futures = [executor.submit(test_has_link, "u"+str(i)) for i in range(10)]
for future in futures:
self.assertTrue(future.result())

0 comments on commit 8079cda

Please sign in to comment.