Skip to content

Commit

Permalink
Refactors persister connection params into a function
Browse files Browse the repository at this point in the history
So that it's easier to override and adjust from
an inheriting class.

This could be overengineering, but I think this is
probably fine for now.
  • Loading branch information
skrawcz committed Dec 26, 2024
1 parent c2cc4bb commit c5deb2d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 30 deletions.
23 changes: 16 additions & 7 deletions burr/integrations/persisters/b_mongodb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
from datetime import datetime, timezone
from typing import Literal, Optional
from typing import Any, Literal, Optional

from pymongo import MongoClient

Expand Down Expand Up @@ -35,6 +35,11 @@ class MongoDBBasePersister(persistence.BaseStatePersister):
this change backwards compatible.
"""

@classmethod
def default_client(cls) -> Any:
"""Returns the default client for the persister."""
return MongoClient

@classmethod
def from_values(
cls,
Expand All @@ -47,7 +52,7 @@ def from_values(
"""Initializes the MongoDBBasePersister class."""
if mongo_client_kwargs is None:
mongo_client_kwargs = {}
client = MongoClient(uri, **mongo_client_kwargs)
client = cls.default_client()(uri, **mongo_client_kwargs)
return cls(
client=client,
db_name=db_name,
Expand Down Expand Up @@ -130,14 +135,18 @@ def save(
def __del__(self):
self.client.close()

def __getstate__(self) -> dict:
state = self.__dict__.copy()
state["connection_params"] = {
def get_connection_params(self) -> dict:
"""Get the connection parameters for the MongoDB persister."""
return {
"uri": self.client.address[0],
"port": self.client.address[1],
"db_name": self.db.name,
"collection_name": self.collection.name,
}

def __getstate__(self) -> dict:
state = self.__dict__.copy()
state["connection_params"] = self.get_connection_params()
del state["client"]
del state["db"]
del state["collection"]
Expand All @@ -146,7 +155,7 @@ def __getstate__(self) -> dict:
def __setstate__(self, state: dict):
connection_params = state.pop("connection_params")
# we assume MongoClient.
self.client = MongoClient(connection_params["uri"], connection_params["port"])
self.client = self.default_client()(connection_params["uri"], connection_params["port"])
self.db = self.client[connection_params["db_name"]]
self.collection = self.db[connection_params["collection_name"]]
self.__dict__.update(state)
Expand All @@ -169,7 +178,7 @@ def __init__(
"""Initializes the MongoDBPersister class."""
if mongo_client_kwargs is None:
mongo_client_kwargs = {}
client = MongoClient(uri, **mongo_client_kwargs)
client = self.default_client()(uri, **mongo_client_kwargs)
super(MongoDBPersister, self).__init__(
client=client,
db_name=db_name,
Expand Down
29 changes: 18 additions & 11 deletions burr/integrations/persisters/b_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json
import logging
from datetime import datetime, timezone
from typing import Literal, Optional
from typing import Any, Literal, Optional

from burr.core import persistence, state

Expand All @@ -28,6 +28,11 @@ class RedisBasePersister(persistence.BaseStatePersister):
so this is an attempt to fix that in a backwards compatible way.
"""

@classmethod
def default_client(cls) -> Any:
"""Returns the default client for the persister."""
return redis.Redis

@classmethod
def from_values(
cls,
Expand All @@ -42,7 +47,7 @@ def from_values(
"""Creates a new instance of the RedisBasePersister from passed in values."""
if redis_client_kwargs is None:
redis_client_kwargs = {}
connection = redis.Redis(
connection = cls.default_client()(
host=host, port=port, db=db, password=password, **redis_client_kwargs
)
return cls(connection, serde_kwargs, namespace)
Expand Down Expand Up @@ -160,24 +165,26 @@ def save(
def __del__(self):
self.connection.close()

def __getstate__(self) -> dict:
state = self.__dict__.copy()
if not hasattr(self.connection, "connection_pool"):
logger.warning("Redis connection is not serializable.")
return state
state["connection_params"] = {
def get_connection_params(self) -> dict:
"""Get the connection parameters for the Redis connection."""
return {
"host": self.connection.connection_pool.connection_kwargs["host"],
"port": self.connection.connection_pool.connection_kwargs["port"],
"db": self.connection.connection_pool.connection_kwargs["db"],
"password": self.connection.connection_pool.connection_kwargs["password"],
}

def __getstate__(self) -> dict:
state = self.__dict__.copy()
# override self.get_connection_params if needed
state["connection_params"] = self.get_connection_params()
del state["connection"]
return state

def __setstate__(self, state: dict):
connection_params = state.pop("connection_params")
# we assume normal redis client.
self.connection = redis.Redis(**connection_params)
# override self.default_client if needed
self.connection = self.default_client()(**connection_params)
self.__dict__.update(state)


Expand Down Expand Up @@ -211,7 +218,7 @@ def __init__(
"""
if redis_client_kwargs is None:
redis_client_kwargs = {}
connection = redis.Redis(
connection = self.default_client()(
host=host, port=port, db=db, password=password, **redis_client_kwargs
)
super(RedisPersister, self).__init__(connection, serde_kwargs, namespace)
Expand Down
27 changes: 15 additions & 12 deletions burr/integrations/persisters/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import json
import logging
from typing import Literal, Optional
from typing import Any, Literal, Optional

from burr.core import persistence, state

Expand Down Expand Up @@ -51,6 +51,11 @@ def from_config(cls, config: dict) -> "PostgreSQLPersister":
table_name=config.get("table_name", "burr_state"),
)

@classmethod
def default_client(cls) -> Any:
"""Returns the default client for the persister."""
return psycopg2.connect

@classmethod
def from_values(
cls,
Expand All @@ -70,7 +75,7 @@ def from_values(
:param port: the port of the PostgreSQL database.
:param table_name: the table name to store things under.
"""
connection = psycopg2.connect(
connection = cls.default_client()(
dbname=db_name, user=user, password=password, host=host, port=port
)
return cls(connection, table_name)
Expand Down Expand Up @@ -246,27 +251,25 @@ def __del__(self):
# closes connection at end when things are being shutdown.
self.connection.close()

def __getstate__(self) -> dict:
state = self.__dict__.copy()
if not hasattr(self.connection, "info"):
logger.warning(
"Postgresql information for connection object not available. Cannot serialize persister."
)
return state
state["connection_params"] = {
def get_connection_params(self) -> dict:
"""Returns the connection parameters for the persister."""
return {
"dbname": self.connection.info.dbname,
"user": self.connection.info.user,
"password": self.connection.info.password,
"host": self.connection.info.host,
"port": self.connection.info.port,
}

def __getstate__(self) -> dict:
state = self.__dict__.copy()
state["connection_params"] = self.get_connection_params()
del state["connection"]
return state

def __setstate__(self, state: dict):
connection_params = state.pop("connection_params")
# we assume normal psycopg2 client.
self.connection = psycopg2.connect(**connection_params)
self.connection = self.default_client()(**connection_params)
self.__dict__.update(state)


Expand Down

0 comments on commit c5deb2d

Please sign in to comment.