Skip to content

Commit

Permalink
feat: add AddNamedDomainMatchingFunc and AddNamedMatchingFunc to enfo…
Browse files Browse the repository at this point in the history
…rcer (#122)

* feat: add AddNamedDomainMatchingFunc and AddNamedMatchingFunc to enforcer

Signed-off-by: Zxilly <zhouxinyu1001@gmail.com>

* refactor: remove log output

Signed-off-by: Zxilly <zhouxinyu1001@gmail.com>

* refactor: rename rm to rmMap

Signed-off-by: Zxilly <zhouxinyu1001@gmail.com>

* style: format

Signed-off-by: Zxilly <zhouxinyu1001@gmail.com>

* refactor: rename to python-style variable

Signed-off-by: Zxilly <zhouxinyu1001@gmail.com>

* refactor: remove unnecessary statements

Signed-off-by: Zxilly <zhouxinyu1001@gmail.com>
  • Loading branch information
Zxilly authored Feb 23, 2021
1 parent 4e1c864 commit e01f393
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 27 deletions.
1 change: 1 addition & 0 deletions casbin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .enforcer import *
from .synced_enforcer import SyncedEnforcer
from . import util
48 changes: 38 additions & 10 deletions casbin/core_enforcer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from casbin.persist.adapters import FileAdapter
import logging

from casbin.effect import DefaultEffector, Effector
from casbin.model import Model, FunctionMap
from casbin.persist.adapters import FileAdapter
from casbin.rbac import default_role_manager
from casbin.util import generate_g_function, SimpleEval, util
from casbin.effect import DefaultEffector, Effector
import logging


class CoreEnforcer:
Expand All @@ -16,7 +17,7 @@ class CoreEnforcer:

adapter = None
watcher = None
rm = None
rm_map = None

enabled = False
auto_save = False
Expand Down Expand Up @@ -63,14 +64,16 @@ def init_with_model_and_adapter(self, m, adapter=None):
self.load_policy()

def _initialize(self):
self.rm = default_role_manager.RoleManager(10)
self.rm_map = dict()
self.eft = DefaultEffector()
self.watcher = None

self.enabled = True
self.auto_save = True
self.auto_build_role_links = True

self.init_rm_map()

@staticmethod
def new_model(path="", text=""):
"""creates a model."""
Expand Down Expand Up @@ -122,12 +125,11 @@ def set_watcher(self, watcher):

def get_role_manager(self):
"""gets the current role manager."""
return self.rm
return self.rm_map['g']

def set_role_manager(self, rm):
"""sets the current role manager."""

self.rm = rm
self.rm_map['g'] = rm

def set_effector(self, eft):
"""sets the current effector."""
Expand All @@ -139,12 +141,18 @@ def clear_policy(self):

self.model.clear_policy()

def init_rm_map(self):
if 'g' in self.model.model.keys():
for ptype in self.model.model['g']:
self.rm_map[ptype] = default_role_manager.RoleManager(10)

def load_policy(self):
"""reloads the policy from file/database."""

self.model.clear_policy()
self.adapter.load_policy(self.model)

self.init_rm_map()
self.model.print_policy()
if self.auto_build_role_links:
self.build_role_links()
Expand All @@ -157,6 +165,7 @@ def load_filtered_policy(self, filter):
raise ValueError("filtered policies are not supported by this adapter")

self.adapter.load_filtered_policy(self.model, filter)
self.init_rm_map()
self.model.print_policy()
if self.auto_build_role_links:
self.build_role_links()
Expand Down Expand Up @@ -203,8 +212,25 @@ def enable_auto_build_role_links(self, auto_build_role_links):
def build_role_links(self):
"""manually rebuild the role inheritance relations."""

self.rm.clear()
self.model.build_role_links(self.rm)
for rm in self.rm_map.values():
rm.clear()
self.model.build_role_links(self.rm_map)

def add_named_matching_func(self, ptype, fn):
"""add_named_matching_func add MatchingFunc by ptype RoleManager"""
try:
self.rm_map[ptype].add_matching_func(fn)
return True
except:
return False

def add_named_domain_matching_func(self, ptype, fn):
"""add_named_domain_matching_func add MatchingFunc by ptype to RoleManager"""
try:
self.rm_map[ptype].add_domain_matching_func(fn)
return True
except:
return False

def enforce(self, *rvals):
"""decides whether a "subject" can access a "object" with the operation "action",
Expand Down Expand Up @@ -320,6 +346,8 @@ def enforce(self, *rvals):

return result



@staticmethod
def _get_expression(expr, functions=None):
expr = expr.replace("&&", "and")
Expand Down
13 changes: 7 additions & 6 deletions casbin/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,13 @@ def get_implicit_roles_for_user(self, name, *domain):
name = q[0]
q = q[1:]

roles = self.rm.get_roles(name, *domain)
for r in roles:
if r not in roleSet:
res.append(r)
q.append(r)
roleSet[r] = True
for rm in self.rm_map.values():
roles = rm.get_roles(name, *domain)
for r in roles:
if r not in roleSet:
res.append(r)
q.append(r)
roleSet[r] = True

return res

Expand Down
5 changes: 3 additions & 2 deletions casbin/model/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ def __init__(self):
self.logger = logging.getLogger()
self.model = {}

def build_role_links(self, rm):
def build_role_links(self, rm_map):
"""initializes the roles in RBAC."""

if "g" not in self.model.keys():
return

for ast in self.model["g"].values():
for ptype, ast in self.model["g"].items():
rm = rm_map[ptype]
ast.build_role_links(rm)

def print_policy(self):
Expand Down
14 changes: 12 additions & 2 deletions casbin/rbac/default_role_manager/role_manager.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
from casbin.rbac import RoleManager
import logging

from casbin.rbac import RoleManager


class RoleManager(RoleManager):
"""provides a default implementation for the RoleManager interface"""

all_roles = dict()
max_hierarchy_level = 0


def __init__(self, max_hierarchy_level):
self.logger = logging.getLogger()
self.all_roles = dict()
self.max_hierarchy_level = max_hierarchy_level
self.matching_func = None
self.domain_matching_func = None
self.has_pattern = None
self.has_domain_pattern = None

def add_matching_func(self, fn):
def add_matching_func(self, fn=None):
self.has_pattern = True
self.matching_func = fn

def add_domain_matching_func(self, fn=None):
self.has_domain_pattern = True
self.domain_matching_func = fn

def has_role(self, name):
if self.matching_func is None:
return name in self.all_roles.keys()
Expand Down
16 changes: 14 additions & 2 deletions casbin/synced_enforcer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from casbin.enforcer import Enforcer
from casbin.util.rwlock import RWLockWrite
import threading
import time

from casbin.enforcer import Enforcer
from casbin.util.rwlock import RWLockWrite


class AtomicBool():

def __init__(self, value):
Expand Down Expand Up @@ -497,6 +499,16 @@ def enable_enforce(self, enabled=True):
with self._wl:
return self._e.enable_enforce(enabled)

def add_named_matching_func(self, ptype, fn):
"""add_named_matching_func add MatchingFunc by ptype RoleManager"""
with self._wl:
self._e.add_named_matching_func(ptype, fn)

def add_named_domain_matching_func(self, ptype, fn):
"""add_named_domain_matching_func add MatchingFunc by ptype to RoleManager"""
with self._wl:
self._e.add_named_domain_matching_func(ptype, fn)

def is_filtered(self):
"""returns true if the loaded policy has been filtered."""
with self._rl:
Expand Down
2 changes: 1 addition & 1 deletion casbin/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .builtin_operators import *
from .util import *
from .expression import *
from .util import *
10 changes: 6 additions & 4 deletions tests/test_enforcer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import casbin
import os
from unittest import TestCase
import time
from unittest import TestCase

import casbin


def get_examples(path):
examples_path = os.path.split(os.path.realpath(__file__))[0] + "/../examples/"
Expand Down Expand Up @@ -176,7 +178,7 @@ def test_enforce_rbac_with_pattern(self):
get_examples("rbac_with_pattern_policy.csv"))

#set matching function to key_match2
e.get_role_manager().add_matching_func(casbin.util.key_match2)
e.add_named_matching_func('g2', casbin.util.key_match2)

self.assertTrue(e.enforce("alice", "/book/1", "GET"))
self.assertTrue(e.enforce("alice", "/book/2", "GET"))
Expand All @@ -188,7 +190,7 @@ def test_enforce_rbac_with_pattern(self):
self.assertTrue(e.enforce("bob", "/pen/2", "GET"))

#replace key_match2 with key_match3
e.get_role_manager().add_matching_func(casbin.util.key_match3)
e.add_named_matching_func('g2', casbin.util.key_match3)
self.assertTrue(e.enforce("alice", "/book2/1", "GET"))
self.assertTrue(e.enforce("alice", "/book2/2", "GET"))
self.assertTrue(e.enforce("alice", "/pen2/1", "GET"))
Expand Down

0 comments on commit e01f393

Please sign in to comment.