-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdb_utils.py
73 lines (57 loc) · 1.98 KB
/
db_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import contextlib
import uuid
from sqlalchemy.dialects.postgresql import UUID as _UUID
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.types import CHAR, TypeDecorator
from utils import get_env_var
# Shared async engine used by session_scope() and prepare_db(). Its URL comes
# from get_env_var("DB_URL") (presumably an environment variable -- see utils).
engine = create_async_engine(get_env_var("DB_URL"))


class Base(DeclarativeBase):
    """Declarative base class for this project's ORM models.

    prepare_db() creates every table registered on ``Base.metadata``.
    """
# noinspection PyAbstractClass
class UUID(TypeDecorator):
    """Portable UUID column type.

    On PostgreSQL the native ``UUID`` column type is used. Every other dialect
    gets a ``CHAR(32)`` column holding the UUID as 32 lowercase hex digits
    (no hyphens). Non-``None`` values loaded back from the database are
    returned as ``uuid.UUID`` instances.

    Adapted from the SQLAlchemy "Backend-agnostic GUID Type" recipe:
    http://docs.sqlalchemy.org/en/latest/core/custom_types.html#backend-agnostic-guid-type
    """

    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the concrete column type to use for ``dialect``."""
        if dialect.name != "postgresql":
            return dialect.type_descriptor(CHAR(32))
        return dialect.type_descriptor(_UUID(as_uuid=True))

    # noinspection PyUnresolvedReferences
    def process_bind_param(self, value, dialect):
        """Convert ``value`` into the form sent to the database.

        ``None`` is passed through unchanged. On PostgreSQL the value is sent
        as ``str(value)``. On any other dialect it is normalised to a
        32-character, zero-padded lowercase hex string; a value that is not
        already a ``uuid.UUID`` is first parsed with ``uuid.UUID(value)``,
        which raises if it cannot be parsed.
        """
        if value is None:
            return value
        if dialect.name == "postgresql":
            return str(value)
        parsed = value if isinstance(value, uuid.UUID) else uuid.UUID(value)
        # 128-bit integer rendered as exactly 32 hex digits (the CHAR(32) form).
        return f"{parsed.int:032x}"

    def process_result_value(self, value, dialect):
        """Turn a value fetched from the database into a ``uuid.UUID``.

        ``None`` and values that are already ``uuid.UUID`` instances are
        returned as-is; anything else (e.g. the 32-character hex string stored
        on non-PostgreSQL backends) is parsed with ``uuid.UUID(value)``.
        """
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)
@contextlib.asynccontextmanager
async def session_scope(autocommit=True):
    """Yield an ``AsyncSession`` bound to the module-level ``engine``.

    When the ``async with`` body finishes without raising and ``autocommit``
    is true, the session is committed. If the body -- or that commit --
    raises anything (``BaseException`` is caught, so task cancellation is
    included), the session is rolled back and the exception is re-raised.
    The session is closed by ``AsyncSession``'s own context manager on exit.

    Args:
        autocommit: Commit automatically on a clean exit. Pass ``False`` to
            leave committing to the caller inside the block.

    Usage::

        async with session_scope() as session:
            session.add(obj)
    """
    async with AsyncSession(engine) as db_session:
        try:
            yield db_session
            # Reached only when the caller's block completed without raising.
            if not autocommit:
                return
            await db_session.commit()
        except BaseException:
            # Also covers a failing commit() above, not just the caller's code.
            await db_session.rollback()
            raise
async def prepare_db():
    """Create the tables registered on ``Base.metadata``.

    ``MetaData.create_all`` is synchronous, so it is executed through
    ``AsyncConnection.run_sync`` on a connection obtained from
    ``engine.begin()``, which commits when the block exits without error.
    By default ``create_all`` skips tables that already exist.
    """
    async with engine.begin() as connection:
        create_tables = Base.metadata.create_all
        await connection.run_sync(create_tables)