Skip to content

Commit

Permalink
Merge pull request #1 from jorzel/sorting
Browse files Browse the repository at this point in the history
Sorting
jorzel authored Apr 20, 2022
2 parents 99c079f + 736cbe9 commit 16b28dc
Showing 10 changed files with 364 additions and 25 deletions.
35 changes: 34 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,34 @@
# graphene-sqlalchemy-sort
# graphene-sqlalchemy-sort

```
class ExampleSort(SortSet):
class Meta:
model = Example
fields = ["first_name", "second_name", "name"]
@classmethod
def name_sort(cls):
return case(
[
(Example.second_name.is_(None), Example.first_name),
(Example.first_name.is_(None), Example.second_name),
],
else_=Example.second_name,
)
Query:
"""
{
examples (sort: {name: "ASC"}) {
edges {
node {
firstName
}
}
}
}
"""
```
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -3,3 +3,4 @@ colorama==0.4.4
flake8==4.0.1
isort==5.10.1
pre-commit==2.18.1
pytest==7.1.1
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -4,4 +4,4 @@ multi_line_output = 3
src_paths = "src/"

[tool.pytest.ini_options]
python_files = "app/tests/*.py"
python_files = "src/tests/*.py"
File renamed without changes.
25 changes: 25 additions & 0 deletions src/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from sqlalchemy import Column, Integer, MetaData, String, create_engine
from sqlalchemy.orm import configure_mappers, declarative_base

# db config
DB_URI = "sqlite:///:memory:"
engine = create_engine(DB_URI)
metadata = MetaData()
Base = declarative_base(metadata=metadata)


class Example(Base):
__tablename__ = "example"

id = Column(Integer, primary_key=True, autoincrement=True)
counter = Column(Integer, default=0, nullable=False)
first_name = Column(String)
second_name = Column(String)

def __str__(self):
return f"Example(id={self.id}, first_name={self.first_name}, second_name={self.second_name}"

__repr__ = __str__


configure_mappers()
44 changes: 44 additions & 0 deletions src/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import graphene
from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType
from sqlalchemy import case

from models import Example
from sort import SortSet


class ExampleNode(SQLAlchemyObjectType):
class Meta:
model = Example
interfaces = (graphene.relay.Node,)


class ExampleSort(SortSet):
class Meta:
model = Example
fields = ["first_name", "second_name", "name"]

@classmethod
def name_sort(cls):
return case(
[
(Example.second_name.is_(None), Example.first_name),
(Example.first_name.is_(None), Example.second_name),
],
else_=Example.second_name,
)


class SchemaQuery(graphene.ObjectType):
examples = SQLAlchemyConnectionField(
ExampleNode,
sort=ExampleSort(),
)

def resolve_examples(self, info, **kwargs):
query = info.context["session"].query(Example)
if kwargs.get("sort"):
query = ExampleSort().sort(query, kwargs["sort"])
return query


schema = graphene.Schema(query=SchemaQuery)
107 changes: 84 additions & 23 deletions src/sort.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,94 @@
import logging
from typing import TYPE_CHECKING

from sqlalchemy import Column, Integer, MetaData, String, create_engine
from sqlalchemy.orm import configure_mappers, declarative_base
import graphene
from graphene.types.inputobjecttype import InputObjectTypeOptions
from sqlalchemy import desc, inspection, nullslast
from sqlalchemy.orm.attributes import InstrumentedAttribute

# logging config
LOGGER_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=LOGGER_FORMAT)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if TYPE_CHECKING:
from typing import Iterable

# db config
DB_URI = "sqlite:///:memory:"
engine = create_engine(DB_URI)
metadata = MetaData()
Base = declarative_base(metadata=metadata)

def _custom_field_func_name(custom_field: str) -> str:
return custom_field + "_sort"

class Example(Base):
__tablename__ = "example"

id = Column(Integer, primary_key=True, autoincrement=True)
counter = Column(Integer, default=0, nullable=False)
first_name = Column(String)
second_name = Column(String)
class SortSetOptions(InputObjectTypeOptions):
model = None
fields = None

def __str__(self):
return f"Example(id={self.id}, important_counter={self.important_counter}"

__repr__ = __str__
class SortSet(graphene.InputObjectType):
model = None

class Meta:
abstract = True

configure_mappers()
@classmethod
def __init_subclass_with_meta__(
cls, model=None, fields=None, _meta=None, **options
):
if model is None and fields:
raise AttributeError("Model not specified")

if not _meta:
_meta = SortSetOptions(cls)

cls.model = model
_meta.model = model

_meta.fields = cls._generate_default_sort_fields(model, fields)
_meta.fields.update(cls._generate_custom_sort_fields(model, fields))
if not _meta.fields:
_meta.fields = {}
super().__init_subclass_with_meta__(_meta=_meta, **options)

@classmethod
def _generate_default_sort_fields(cls, model, sort_fields: "Iterable[str]"):
graphql_sort_fields = {}
model_fields = cls._get_model_fields(model, sort_fields)
for field_name, field_object in model_fields.items():
graphql_sort_fields[field_name] = graphene.InputField(graphene.String)
return graphql_sort_fields

@classmethod
def _generate_custom_sort_fields(cls, model, sort_fields: "Iterable[str]"):
graphql_sort_fields = {}
for field in sort_fields:
if not hasattr(cls, _custom_field_func_name(field)):
continue
graphql_sort_fields[field] = graphene.InputField(graphene.String)
return graphql_sort_fields

@classmethod
def _get_model_fields(cls, model, only_fields: "Iterable[str]"):
model_fields = {}
inspected = inspection.inspect(model)
for descr in inspected.all_orm_descriptors:
if isinstance(descr, InstrumentedAttribute):
attr = descr.property
name = attr.key
if name not in only_fields:
continue

column = attr.columns[0]
model_fields[name] = {
"column": column,
"type": column.type,
"nullable": column.nullable,
}
return model_fields

@classmethod
def sort(cls, query, args):
sort_fields = []
for field, ordering in args.items():
_field = field
if hasattr(cls, _custom_field_func_name(_field)):
_field = getattr(cls, _custom_field_func_name(_field))()
if ordering.strip().lower() == "desc":
_field = nullslast(desc(_field))
else:
_field = nullslast(_field)
sort_fields.append(_field)
return query.order_by(*sort_fields)
42 changes: 42 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from models import Base, Example


@pytest.fixture(scope="session")
def db_connection():
SQLALCHEMY_DATABASE_URL = "sqlite:///"
engine = create_engine(SQLALCHEMY_DATABASE_URL)
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
connection = engine.connect()

yield connection

Base.metadata.drop_all(engine)
engine.dispose()


@pytest.fixture
def db_session(db_connection):
transaction = db_connection.begin()
session = sessionmaker(bind=db_connection)
db_session = session()

yield db_session

transaction.rollback()
db_session.close()


@pytest.fixture
def example_factory(db_session):
def _example_factory(**kwargs):
example = Example(**kwargs)
db_session.add(example)
db_session.flush()
return example

yield _example_factory
21 changes: 21 additions & 0 deletions src/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from models import Example


def test_example(example_factory):
example = example_factory(first_name="test")

assert example.first_name == "test"


def test_example_ordering(example_factory, db_session):
example1 = example_factory(first_name="test", second_name="cest")
example2 = example_factory(first_name="atest", second_name="atest")
example3 = example_factory(first_name="atest", second_name="ttest")

results = (
db_session.query(Example)
.order_by(Example.first_name)
.order_by(Example.second_name.desc())
).all()

assert results == [example3, example2, example1]
Loading
Oops, something went wrong.

0 comments on commit 16b28dc

Please sign in to comment.