Skip to content

Commit

Permalink
fixup! Implement index definition and creation
Browse files Browse the repository at this point in the history
  • Loading branch information
art049 committed Sep 7, 2022
1 parent 3e3d41e commit eeb0596
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 9 deletions.
3 changes: 2 additions & 1 deletion odmantic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Type, Union

import pymongo
from pydantic.main import BaseConfig
from pydantic.typing import AnyCallable

Expand Down Expand Up @@ -30,7 +31,7 @@ class BaseODMConfig:
parse_doc_with_default_factories: bool = False

@staticmethod
def indexes() -> Iterable[ODMIndex.Index]:
def indexes() -> Iterable[Union[ODMIndex.Index, pymongo.IndexModel]]:
return []

# Inherited from pydantic
Expand Down
17 changes: 13 additions & 4 deletions odmantic/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from odmantic.exceptions import DocumentNotFoundError, DuplicateKeyError
from odmantic.field import FieldProxy, ODMReference
from odmantic.index import ODMBaseIndex
from odmantic.model import Model
from odmantic.query import QueryExpression, SortExpression, and_
from odmantic.session import (
Expand Down Expand Up @@ -366,7 +367,11 @@ async def configure_database(
for model in models:
collection = self.get_collection(model)
for index in model.__indexes__():
pymongo_index = index.get_pymongo_index()
pymongo_index = (
index.get_pymongo_index()
if isinstance(index, ODMBaseIndex)
else index
)
try:
await collection.create_indexes(
[pymongo_index], session=driver_session
Expand All @@ -377,7 +382,7 @@ async def configure_database(
86, # aka IndexKeySpecsConflict for MongoDB > 5
):
await collection.drop_index(
index.get_index_specifier(), session=driver_session
pymongo_index.document["name"], session=driver_session
)
await collection.create_indexes(
[pymongo_index], session=driver_session
Expand Down Expand Up @@ -776,7 +781,11 @@ def configure_database(
for model in models:
collection = self.get_collection(model)
for index in model.__indexes__():
pymongo_index = index.get_pymongo_index()
pymongo_index = (
index.get_pymongo_index()
if isinstance(index, ODMBaseIndex)
else index
)
try:
collection.create_indexes([pymongo_index], session=driver_session)
except pymongo.errors.OperationFailure as exc:
Expand All @@ -785,7 +794,7 @@ def configure_database(
86, # aka IndexKeySpecsConflict for MongoDB > 5
):
collection.drop_index(
index.get_index_specifier(), session=driver_session
pymongo_index.document["name"], session=driver_session
)
collection.create_indexes(
[pymongo_index], session=driver_session
Expand Down
9 changes: 5 additions & 4 deletions odmantic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import bson
import pydantic
import pymongo
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import Field as PDField
from pydantic.fields import FieldInfo as PDFieldInfo
Expand Down Expand Up @@ -61,7 +62,7 @@
ODMFieldInfo,
ODMReference,
)
from odmantic.index import ODMBaseIndex, ODMSingleFieldIndex
from odmantic.index import Index, ODMBaseIndex, ODMSingleFieldIndex
from odmantic.reference import ODMReferenceInfo
from odmantic.utils import (
is_dunder,
Expand Down Expand Up @@ -780,8 +781,8 @@ def __setattr__(self, name: str, value: Any) -> None:
super().__setattr__(name, value)

@classmethod
def __indexes__(cls) -> Tuple[ODMBaseIndex, ...]:
indexes: List[ODMBaseIndex] = []
def __indexes__(cls) -> Tuple[Union[ODMBaseIndex, pymongo.IndexModel], ...]:
indexes: List[Union[ODMBaseIndex, pymongo.IndexModel]] = []
for field in cls.__odm_fields__.values():
if isinstance(field, ODMBaseIndexableField) and (
field.index or field.unique
Expand All @@ -794,7 +795,7 @@ def __indexes__(cls) -> Tuple[ODMBaseIndex, ...]:
)

for index in cast(BaseODMConfig, cls.Config).indexes():
indexes.append(index.to_odm_index())
indexes.append(index.to_odm_index() if isinstance(index, Index) else index)
return tuple(indexes)

def update(
Expand Down
35 changes: 35 additions & 0 deletions tests/integration/test_index.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pymongo
import pytest

from odmantic.engine import AIOEngine, SyncEngine
Expand Down Expand Up @@ -310,3 +311,37 @@ class Config:
)
is not None
)


async def test_custom_text_index(aio_engine: AIOEngine):
class Post(Model):
title: str
content: str

class Config:
@staticmethod
def indexes():
yield pymongo.IndexModel(
[("title", pymongo.TEXT), ("content", pymongo.TEXT)]
)

await aio_engine.configure_database([Post])
await aio_engine.save(Post(title="My post on python", content="It's awesome!"))
assert await aio_engine.find_one(Post, {"$text": {"$search": "python"}}) is not None


async def test_sync_custom_text_index(sync_engine: SyncEngine):
class Post(Model):
title: str
content: str

class Config:
@staticmethod
def indexes():
yield pymongo.IndexModel(
[("title", pymongo.TEXT), ("content", pymongo.TEXT)]
)

sync_engine.configure_database([Post])
sync_engine.save(Post(title="My post on python", content="It's awesome!"))
assert sync_engine.find_one(Post, {"$text": {"$search": "python"}}) is not None

0 comments on commit eeb0596

Please sign in to comment.