-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from jorzel/sorting
Sorting
Showing
10 changed files
with
364 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,34 @@ | ||
# graphene-sqlalchemy-sort | ||
# graphene-sqlalchemy-sort | ||
|
||
``` | ||
class ExampleSort(SortSet): | ||
class Meta: | ||
model = Example | ||
fields = ["first_name", "second_name", "name"] | ||
@classmethod | ||
def name_sort(cls): | ||
return case( | ||
[ | ||
(Example.second_name.is_(None), Example.first_name), | ||
(Example.first_name.is_(None), Example.second_name), | ||
], | ||
else_=Example.second_name, | ||
) | ||
Query: | ||
""" | ||
{ | ||
examples (sort: {name: "ASC"}) { | ||
edges { | ||
node { | ||
firstName | ||
} | ||
} | ||
} | ||
} | ||
""" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ colorama==0.4.4 | |
flake8==4.0.1 | ||
isort==5.10.1 | ||
pre-commit==2.18.1 | ||
pytest==7.1.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from sqlalchemy import Column, Integer, MetaData, String, create_engine | ||
from sqlalchemy.orm import configure_mappers, declarative_base | ||
|
||
# db config | ||
DB_URI = "sqlite:///:memory:" | ||
engine = create_engine(DB_URI) | ||
metadata = MetaData() | ||
Base = declarative_base(metadata=metadata) | ||
|
||
|
||
class Example(Base): | ||
__tablename__ = "example" | ||
|
||
id = Column(Integer, primary_key=True, autoincrement=True) | ||
counter = Column(Integer, default=0, nullable=False) | ||
first_name = Column(String) | ||
second_name = Column(String) | ||
|
||
def __str__(self): | ||
return f"Example(id={self.id}, first_name={self.first_name}, second_name={self.second_name}" | ||
|
||
__repr__ = __str__ | ||
|
||
|
||
configure_mappers() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import graphene | ||
from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType | ||
from sqlalchemy import case | ||
|
||
from models import Example | ||
from sort import SortSet | ||
|
||
|
||
class ExampleNode(SQLAlchemyObjectType): | ||
class Meta: | ||
model = Example | ||
interfaces = (graphene.relay.Node,) | ||
|
||
|
||
class ExampleSort(SortSet): | ||
class Meta: | ||
model = Example | ||
fields = ["first_name", "second_name", "name"] | ||
|
||
@classmethod | ||
def name_sort(cls): | ||
return case( | ||
[ | ||
(Example.second_name.is_(None), Example.first_name), | ||
(Example.first_name.is_(None), Example.second_name), | ||
], | ||
else_=Example.second_name, | ||
) | ||
|
||
|
||
class SchemaQuery(graphene.ObjectType): | ||
examples = SQLAlchemyConnectionField( | ||
ExampleNode, | ||
sort=ExampleSort(), | ||
) | ||
|
||
def resolve_examples(self, info, **kwargs): | ||
query = info.context["session"].query(Example) | ||
if kwargs.get("sort"): | ||
query = ExampleSort().sort(query, kwargs["sort"]) | ||
return query | ||
|
||
|
||
schema = graphene.Schema(query=SchemaQuery) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,94 @@ | ||
import logging | ||
from typing import TYPE_CHECKING | ||
|
||
from sqlalchemy import Column, Integer, MetaData, String, create_engine | ||
from sqlalchemy.orm import configure_mappers, declarative_base | ||
import graphene | ||
from graphene.types.inputobjecttype import InputObjectTypeOptions | ||
from sqlalchemy import desc, inspection, nullslast | ||
from sqlalchemy.orm.attributes import InstrumentedAttribute | ||
|
||
# logging config | ||
LOGGER_FORMAT = "%(asctime)s [%(levelname)s] %(message)s" | ||
logging.basicConfig(format=LOGGER_FORMAT) | ||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.INFO) | ||
if TYPE_CHECKING: | ||
from typing import Iterable | ||
|
||
# db config | ||
DB_URI = "sqlite:///:memory:" | ||
engine = create_engine(DB_URI) | ||
metadata = MetaData() | ||
Base = declarative_base(metadata=metadata) | ||
|
||
def _custom_field_func_name(custom_field: str) -> str: | ||
return custom_field + "_sort" | ||
|
||
class Example(Base): | ||
__tablename__ = "example" | ||
|
||
id = Column(Integer, primary_key=True, autoincrement=True) | ||
counter = Column(Integer, default=0, nullable=False) | ||
first_name = Column(String) | ||
second_name = Column(String) | ||
class SortSetOptions(InputObjectTypeOptions): | ||
model = None | ||
fields = None | ||
|
||
def __str__(self): | ||
return f"Example(id={self.id}, important_counter={self.important_counter}" | ||
|
||
__repr__ = __str__ | ||
class SortSet(graphene.InputObjectType): | ||
model = None | ||
|
||
class Meta: | ||
abstract = True | ||
|
||
configure_mappers() | ||
@classmethod | ||
def __init_subclass_with_meta__( | ||
cls, model=None, fields=None, _meta=None, **options | ||
): | ||
if model is None and fields: | ||
raise AttributeError("Model not specified") | ||
|
||
if not _meta: | ||
_meta = SortSetOptions(cls) | ||
|
||
cls.model = model | ||
_meta.model = model | ||
|
||
_meta.fields = cls._generate_default_sort_fields(model, fields) | ||
_meta.fields.update(cls._generate_custom_sort_fields(model, fields)) | ||
if not _meta.fields: | ||
_meta.fields = {} | ||
super().__init_subclass_with_meta__(_meta=_meta, **options) | ||
|
||
@classmethod | ||
def _generate_default_sort_fields(cls, model, sort_fields: "Iterable[str]"): | ||
graphql_sort_fields = {} | ||
model_fields = cls._get_model_fields(model, sort_fields) | ||
for field_name, field_object in model_fields.items(): | ||
graphql_sort_fields[field_name] = graphene.InputField(graphene.String) | ||
return graphql_sort_fields | ||
|
||
@classmethod | ||
def _generate_custom_sort_fields(cls, model, sort_fields: "Iterable[str]"): | ||
graphql_sort_fields = {} | ||
for field in sort_fields: | ||
if not hasattr(cls, _custom_field_func_name(field)): | ||
continue | ||
graphql_sort_fields[field] = graphene.InputField(graphene.String) | ||
return graphql_sort_fields | ||
|
||
@classmethod | ||
def _get_model_fields(cls, model, only_fields: "Iterable[str]"): | ||
model_fields = {} | ||
inspected = inspection.inspect(model) | ||
for descr in inspected.all_orm_descriptors: | ||
if isinstance(descr, InstrumentedAttribute): | ||
attr = descr.property | ||
name = attr.key | ||
if name not in only_fields: | ||
continue | ||
|
||
column = attr.columns[0] | ||
model_fields[name] = { | ||
"column": column, | ||
"type": column.type, | ||
"nullable": column.nullable, | ||
} | ||
return model_fields | ||
|
||
@classmethod | ||
def sort(cls, query, args): | ||
sort_fields = [] | ||
for field, ordering in args.items(): | ||
_field = field | ||
if hasattr(cls, _custom_field_func_name(_field)): | ||
_field = getattr(cls, _custom_field_func_name(_field))() | ||
if ordering.strip().lower() == "desc": | ||
_field = nullslast(desc(_field)) | ||
else: | ||
_field = nullslast(_field) | ||
sort_fields.append(_field) | ||
return query.order_by(*sort_fields) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import pytest | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import sessionmaker | ||
|
||
from models import Base, Example | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def db_connection(): | ||
SQLALCHEMY_DATABASE_URL = "sqlite:///" | ||
engine = create_engine(SQLALCHEMY_DATABASE_URL) | ||
Base.metadata.drop_all(engine) | ||
Base.metadata.create_all(engine) | ||
connection = engine.connect() | ||
|
||
yield connection | ||
|
||
Base.metadata.drop_all(engine) | ||
engine.dispose() | ||
|
||
|
||
@pytest.fixture | ||
def db_session(db_connection): | ||
transaction = db_connection.begin() | ||
session = sessionmaker(bind=db_connection) | ||
db_session = session() | ||
|
||
yield db_session | ||
|
||
transaction.rollback() | ||
db_session.close() | ||
|
||
|
||
@pytest.fixture | ||
def example_factory(db_session): | ||
def _example_factory(**kwargs): | ||
example = Example(**kwargs) | ||
db_session.add(example) | ||
db_session.flush() | ||
return example | ||
|
||
yield _example_factory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from models import Example | ||
|
||
|
||
def test_example(example_factory): | ||
example = example_factory(first_name="test") | ||
|
||
assert example.first_name == "test" | ||
|
||
|
||
def test_example_ordering(example_factory, db_session): | ||
example1 = example_factory(first_name="test", second_name="cest") | ||
example2 = example_factory(first_name="atest", second_name="atest") | ||
example3 = example_factory(first_name="atest", second_name="ttest") | ||
|
||
results = ( | ||
db_session.query(Example) | ||
.order_by(Example.first_name) | ||
.order_by(Example.second_name.desc()) | ||
).all() | ||
|
||
assert results == [example3, example2, example1] |
Oops, something went wrong.