Skip to content

Commit

Permalink
feat: port fastbin to casbin (#318)
Browse files Browse the repository at this point in the history
* feat: port fastbin

* feat: implement FastEnforcer

* fix: remove redundant init code
  • Loading branch information
BustDot authored Sep 16, 2023
1 parent 94b2172 commit 67537d6
Show file tree
Hide file tree
Showing 11 changed files with 373 additions and 15 deletions.
1 change: 1 addition & 0 deletions casbin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .enforcer import *
from .synced_enforcer import SyncedEnforcer
from .distributed_enforcer import DistributedEnforcer
from .fast_enforcer import FastEnforcer
from .async_enforcer import AsyncEnforcer
from . import util
from .persist import *
Expand Down
6 changes: 1 addition & 5 deletions casbin/core_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import copy
import logging

from casbin.effect import Effector, get_effector, effect_to_bool
from casbin.model import Model, FunctionMap
Expand Down Expand Up @@ -202,7 +202,6 @@ def load_policy(self):
new_model.clear_policy()

try:

self.adapter.load_policy(new_model)

new_model.sort_policies_by_subject_hierarchy()
Expand All @@ -212,7 +211,6 @@ def load_policy(self):
new_model.print_policy()

if self.auto_build_role_links:

need_to_rebuild = True
for rm in self.rm_map.values():
rm.clear()
Expand All @@ -222,7 +220,6 @@ def load_policy(self):
self.model = new_model

except Exception as e:

if self.auto_build_role_links and need_to_rebuild:
self.build_role_links()

Expand Down Expand Up @@ -315,7 +312,6 @@ def add_named_domain_matching_func(self, ptype, fn):
return False

def new_enforce_context(self, suffix: str) -> EnforceContext:

return EnforceContext(
rtype="r" + suffix,
ptype="p" + suffix,
Expand Down
6 changes: 2 additions & 4 deletions casbin/distributed_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from casbin import SyncedEnforcer
import logging

from casbin.persist import batch_adapter
from casbin.model.policy_op import PolicyOp
from casbin.persist import batch_adapter
from casbin.persist.adapters import update_adapter
from casbin.synced_enforcer import SyncedEnforcer


class DistributedEnforcer(SyncedEnforcer):
Expand Down
41 changes: 41 additions & 0 deletions casbin/fast_enforcer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import logging
from typing import Sequence

from casbin.enforcer import Enforcer
from casbin.model import Model, FastModel, fast_policy_filter, FunctionMap
from casbin.persist.adapters import FileAdapter
from casbin.util.log import configure_logging


class FastEnforcer(Enforcer):
_cache_key_order: Sequence[int] = None

def __init__(self, model=None, adapter=None, enable_log=False, cache_key_order: Sequence[int] = None):
self._cache_key_order = cache_key_order
super().__init__(model, adapter, enable_log)

def new_model(self, path="", text=""):
"""creates a model."""
if self._cache_key_order is None:
m = Model()
else:
m = FastModel(self._cache_key_order)
if len(path) > 0:
m.load_model(path)
else:
m.load_model_from_text(text)

return m

def enforce(self, *rvals):
"""decides whether a "subject" can access a "object" with the operation "action",
input parameters are usually: (sub, obj, act).
"""
if FastEnforcer._cache_key_order is None:
result, _ = self.enforce_ex(*rvals)
else:
keys = [rvals[x] for x in self._cache_key_order]
with fast_policy_filter(self.model.model["p"]["p"].policy, *keys):
result, _ = self.enforce_ex(*rvals)

return result
4 changes: 3 additions & 1 deletion casbin/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

from .assertion import Assertion
from .function import FunctionMap
from .model import Model
from .model_fast import FastModel
from .policy import Policy
from .function import FunctionMap
from .policy_fast import FastPolicy, fast_policy_filter
36 changes: 36 additions & 0 deletions casbin/model/model_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2023 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Sequence

from .model import Model
from .policy_fast import FastPolicy


class FastModel(Model):
_cache_key_order: Sequence[int]

def __init__(self, cache_key_order: Sequence[int]) -> None:
super().__init__()
self._cache_key_order = cache_key_order

def add_def(self, sec: str, key: str, value: Any) -> None:
super().add_def(sec, key, value)
if sec == "p" and key == "p":
self.model[sec][key].policy = FastPolicy(self._cache_key_order)

def clear_policy(self) -> None:
"""clears all current policy."""
super().clear_policy()
self.model["p"]["p"].policy = FastPolicy(self._cache_key_order)
101 changes: 101 additions & 0 deletions casbin/model/policy_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2023 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from typing import Any, Container, Dict, Iterable, Iterator, Optional, Sequence, Set, cast


def in_cache(cache: Dict[str, Any], keys: Sequence[str]) -> Optional[Set[Sequence[str]]]:
if keys[0] in cache:
if len(keys) > 1:
return in_cache(cache[keys[-0]], keys[1:])
return cast(Set[Sequence[str]], cache[keys[0]])
else:
return None


class FastPolicy(Container[Sequence[str]]):
_cache: Dict[str, Any]
_current_filter: Optional[Set[Sequence[str]]]
_cache_key_order: Sequence[int]

def __init__(self, cache_key_order: Sequence[int]) -> None:
self._cache = {}
self._current_filter = None
self._cache_key_order = cache_key_order

def __iter__(self) -> Iterator[Sequence[str]]:
yield from self.__get_policy()

def __len__(self) -> int:
return len(list(self.__get_policy()))

def __contains__(self, item: object) -> bool:
if not isinstance(item, (list, tuple)) or len(self._cache_key_order) >= len(item):
return False
keys = [item[x] for x in self._cache_key_order]
exists = in_cache(self._cache, keys)
if not exists:
return False
return tuple(item) in exists

def __getitem__(self, item: int) -> Sequence[str]:
for i, entry in enumerate(self):
if i == item:
return entry
raise KeyError("No such value exists")

def append(self, item: Sequence[str]) -> None:
cache = self._cache
keys = [item[x] for x in self._cache_key_order]

for key in keys[:-1]:
if key not in cache:
cache[key] = dict()
cache = cache[key]
if keys[-1] not in cache:
cache[keys[-1]] = set()

cache[keys[-1]].add(tuple(item))

def remove(self, policy: Sequence[str]) -> bool:
keys = [policy[x] for x in self._cache_key_order]
exists = in_cache(self._cache, keys)
if not exists:
return True

exists.remove(tuple(policy))
return True

def __get_policy(self) -> Iterable[Sequence[str]]:
if self._current_filter is not None:
return (list(x) for x in self._current_filter)
else:
return (list(v2) for v in self._cache.values() for v1 in v.values() for v2 in v1)

def apply_filter(self, *keys: str) -> None:
value = in_cache(self._cache, keys)
self._current_filter = value or set()

def clear_filter(self) -> None:
self._current_filter = None


@contextmanager
def fast_policy_filter(policy: FastPolicy, *keys: str) -> Iterator[None]:
try:
policy.apply_filter(*keys)
yield
finally:
policy.clear_filter()
11 changes: 6 additions & 5 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import benchmarks
from . import config
from . import model
from . import rbac
from . import util
from .test_distributed_api import TestDistributedApi
from .test_enforcer import *
from .test_fast_enforcer import TestFastEnforcer
from .test_filter import TestFilteredAdapter
from .test_frontend import TestFrontend
from .test_management_api import TestManagementApi, TestManagementApiSynced
from .test_rbac_api import TestRbacApi, TestRbacApiSynced
from . import benchmarks
from . import config
from . import model
from . import rbac
from . import util
1 change: 1 addition & 0 deletions tests/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
# limitations under the License.

from .test_policy import TestPolicy
from .test_policy_fast import TestContextManager, TestFastPolicy
101 changes: 101 additions & 0 deletions tests/model/test_policy_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2023 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import TestCase

from casbin.model import FastPolicy, fast_policy_filter


class TestFastPolicy(TestCase):
def test_able_to_add_rules(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert list(policy) == [["sub", "obj", "read"]]

def test_does_not_add_duplicates(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read"])

assert list(policy) == [["sub", "obj", "read"]]

def test_can_remove_rules(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.remove(["sub", "obj", "read"])

assert list(policy) == []

def test_returns_lengtt(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert len(policy) == 1

def test_supports_in_keyword(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert ["sub", "obj", "read"] in policy

def test_supports_filters(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

policy.apply_filter("read2", "obj2")

assert list(policy) == [["sub", "obj2", "read2"]]

def test_clears_filters(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

policy.apply_filter("read2", "obj2")
policy.clear_filter()

assert list(policy) == [
["sub", "obj", "read"],
["sub", "obj", "read2"],
["sub", "obj2", "read2"],
]


class TestContextManager:
def test_fast_policy_filter(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

with fast_policy_filter(policy, "read2", "obj2"):
assert list(policy) == [["sub", "obj2", "read2"]]

assert list(policy) == [
["sub", "obj", "read"],
["sub", "obj", "read2"],
["sub", "obj2", "read2"],
]
Loading

0 comments on commit 67537d6

Please sign in to comment.