Skip to content

Commit

Permalink
Search - better type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
RoDuth committed Oct 3, 2024
1 parent bbe18dd commit 8e02560
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 51 deletions.
9 changes: 6 additions & 3 deletions bauble/search/clauses.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,14 @@
AGGREGATE_FUNC_NAMES = ["sum", "avg", "min", "max", "count", "total"]


Q = typing.TypeVar("Q", Query, Select)


@dataclass
class QueryHandler:
class QueryHandler(typing.Generic[Q]):
session: Session
domain: Base
query: Query | Select
domain: type[Base]
query: Q


class ClauseAction(ABC):
Expand Down
72 changes: 29 additions & 43 deletions bauble/search/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,73 +34,59 @@
from abc import abstractmethod

from pyparsing import ParseResults
from sqlalchemy.orm import Query
from sqlalchemy.orm import QueryableAttribute
from sqlalchemy.orm import aliased
from sqlalchemy.sql import Select
from sqlalchemy.sql import distinct
from sqlalchemy.sql.elements import ColumnElement

from bauble.db import Base
from bauble.db import get_related_class

from .clauses import Q
from .clauses import QueryHandler
from .operations import OPERATIONS


@typing.overload
def create_joins(
query: Query,
cls: Base,
query: Q,
cls: type[Base],
steps: list[str],
alias: bool = False,
alias: bool = True,
alias_return: str = "",
_current: Base | None = None,
) -> tuple[Query, Base, Base]:
"""When provided a Query and alias_return return a Query and include the
alias."""
_current: type[Base] | None = None,
) -> tuple[Q, type[Base], type[Base]]:
"""When alias is True and alias_return is provided include the alias."""


@typing.overload
def create_joins(
query: Query,
cls: Base,
query: Q,
cls: type[Base],
steps: list[str],
alias: bool = False,
alias_return: None = None,
_current: Base | None = None,
) -> tuple[Query, Base, Base | None]:
"""When provided a Query return a Query"""


@typing.overload
def create_joins(
query: Select,
cls: Base,
steps: list[str],
alias: bool = False,
alias_return: str | None = None,
_current: Base | None = None,
) -> tuple[Select, Base, None]:
"""When provided a Select return a Select"""
_current: type[Base] | None = None,
) -> tuple[Q, type[Base], type[Base] | None]:
"""When alias is False don't include the alias."""


def create_joins(
query: Query | Select,
cls: Base,
query: Q,
cls: type[Base],
steps: list[str],
alias: bool = False,
alias_return: str | None = None,
_current: Base | None = None,
) -> tuple[Query | Select, Base, Base | None]:
_current: type[Base] | None = None,
) -> tuple[Q, type[Base], type[Base] | None]:
"""Given a starting query, class and steps add the appropriate `join()`
clauses to the query. Returns the query and the last class in the joins.
"""
# pylint: disable=protected-access
if not hasattr(query, "_to_join"):
# monkeypatch _to_join so it is available at all steps of creating the
# query or will not alias correctly
query._to_join = [cls] # type: ignore[union-attr]
query._to_join = [cls] # type: ignore[attr-defined]
if not steps:
return (query, cls, _current)
step = steps[0]
Expand All @@ -116,15 +102,15 @@ def create_joins(

attribute = getattr(cls, step)

if joinee in query._to_join or alias:
if joinee in query._to_join or alias: # type: ignore[attr-defined]
logger.debug("Aliasing %s", joinee)
joinee = aliased(joinee)
query = query.join(attribute.of_type(joinee))
if step == alias_return:
_current = joinee
else:
query = query.join(attribute)
query._to_join.append(joinee)
query._to_join.append(joinee) # type: ignore[attr-defined]

return create_joins(query, joinee, steps, alias, alias_return, _current)

Expand All @@ -145,8 +131,8 @@ def __repr__(self) -> str:

@abstractmethod
def evaluate(
self, handler: QueryHandler
) -> tuple[Query | Select, QueryableAttribute]:
self, handler: QueryHandler[Q]
) -> tuple[Q, QueryableAttribute]:
"""return pair (query, attribute) where query is an altered query where
the joinpoint is the one relative to the attribute, and attribute is
the attribute itself.
Expand All @@ -168,8 +154,8 @@ def __repr__(self) -> str:
return ".".join(self.steps + [self.leaf])

def evaluate(
self, handler: QueryHandler
) -> tuple[Query | Select, QueryableAttribute]:
self, handler: QueryHandler[Q]
) -> tuple[Q, QueryableAttribute]:
logger.debug("%s::evaluate %s", self.__class__.__name__, self)
if len(self.steps) == 0:
# identifier is an attribute of the table being queried
Expand Down Expand Up @@ -212,7 +198,7 @@ def __repr__(self) -> str:

@staticmethod
def add_filter_clauses(
filter_: ParseResults, handler: QueryHandler, this_cls: Base
filter_: ParseResults, handler: QueryHandler, this_cls: type[Base]
) -> None:
filter_attr = filter_[0]
filter_op = filter_[1]
Expand All @@ -231,8 +217,8 @@ def clause(attr, operation, val) -> ColumnElement:
)

def evaluate(
self, handler: QueryHandler
) -> tuple[Query | Select, QueryableAttribute]:
self, handler: QueryHandler[Q]
) -> tuple[Q, QueryableAttribute]:
logger.debug("%s::evaluate %s", self.__class__.__name__, self)

this_cls = handler.domain
Expand All @@ -250,7 +236,7 @@ def evaluate(
self.add_filter_clauses(filter_, handler, this_cls)

handler.query, cls, _this = create_joins(
typing.cast(Query, handler.query),
handler.query,
this_cls,
self.leaf_indentifier.steps,
)
Expand Down Expand Up @@ -279,8 +265,8 @@ def __repr__(self) -> str:
return f"{self.function}({distinct_str}{self.identifier})"

def evaluate(
self, handler: QueryHandler
) -> tuple[Query | Select, QueryableAttribute]:
self, handler: QueryHandler[Q]
) -> tuple[Q, QueryableAttribute]:
"""Let the identifier compute the query and its attribute, no need to
alter anything right now since the condition on the identifier is
applied in the HAVING and not in the WHERE for aggreate functions and
Expand Down
10 changes: 5 additions & 5 deletions bauble/search/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@
class SearchStrategy(ABC):
"""interface for adding search strategies to a view."""

domains: dict[str, tuple[Base, list[str]]] = {}
domains: dict[str, tuple[type[Base], list[str]]] = {}
shorthand: dict[str, str] = {}
properties: dict[Base, list[str]] = {}
properties: dict[type[Base], list[str]] = {}
# placed here for simple search convenience.
completion_funcs: dict[str, Callable] = {}

Expand All @@ -90,7 +90,7 @@ def __init__(self):
self.session = None

def add_meta(
self, domain: tuple[str, ...], cls: Base, properties: list[str]
self, domain: tuple[str, ...], cls: type[Base], properties: list[str]
) -> None:
"""Add a domain to the search space
Expand Down Expand Up @@ -126,14 +126,14 @@ def add_meta(
self.properties[cls] = properties

@classmethod
def get_domain_classes(cls) -> dict[str, Base]:
def get_domain_classes(cls) -> dict[str, type[Base]]:
"""Returns a dictionary of domains names, as strings, to the classes
they point to.
Only the first domain name per class, as added via add_meta, is
returned.
"""
domains: dict[str, Base] = {}
domains: dict[str, type[Base]] = {}
for domain, item in cls.domains.items():
if item[0] not in domains.values():
domains.setdefault(domain, item[0])
Expand Down

0 comments on commit 8e02560

Please sign in to comment.