Skip to content

Commit

Permalink
Avoid updating partitions table when unnecessary
Browse files Browse the repository at this point in the history
This commit refactors the loader code to avoid unnecessary
partition updates. Partition updates should be avoided
when unnecessary as they can have a large performance impact
for partitions containing many items.
  • Loading branch information
lossyrob committed May 14, 2022
1 parent 3a639e8 commit 787fc09
Showing 1 changed file with 100 additions and 74 deletions.
174 changes: 100 additions & 74 deletions pypgstac/pypgstac/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@
logger = logging.getLogger(__name__)


@dataclass
class Partition:
name: str
collection: str
datetime_range_min: str
datetime_range_max: str
end_datetime_range_min: str
end_datetime_range_max: str
requires_update: bool


def chunked_iterable(iterable: Iterable, size: Optional[int] = 10000) -> Iterable:
"""Chunk an iterable."""
it = iter(iterable)
Expand Down Expand Up @@ -137,15 +148,18 @@ def read_json(file: Union[Path, str, Iterator[Any]] = "stdin") -> Iterable:
yield orjson.loads(line)


@dataclass
class Loader:
"""Utilities for loading data."""

db: PgstacDB
_partition_cache: Optional[dict] = None
_partition_cache: Dict[str, Partition]

def __init__(self, db: PgstacDB):
self.db = db
self._partition_cache: Dict[str, Partition] = {}

@lru_cache
def collection_json(self, collection_id: str) -> Tuple[dict, int, str]:
def collection_json(self, collection_id: str) -> Tuple[Dict[str, Any], int, str]:
"""Get collection."""
res = self.db.query_one(
"SELECT base_item, key, partition_trunc FROM collections WHERE id=%s",
Expand Down Expand Up @@ -230,8 +244,8 @@ def load_collections(
)
def load_partition(
self,
partition: dict,
items: Iterable,
partition: Partition,
items: Iterable[Dict[str, Any]],
insert_mode: Optional[Methods] = Methods.insert,
) -> None:
"""Load items data for a single partition."""
Expand All @@ -240,34 +254,40 @@ def load_partition(

logger.debug(f"Loading data for partition: {partition}.")
with conn.cursor() as cur:
if partition.requires_update:
with conn.transaction():
cur.execute(
"SELECT * FROM partitions WHERE name = %s FOR UPDATE;",
(partition.name,),
)
cur.execute(
"""
INSERT INTO partitions
(collection, datetime_range, end_datetime_range)
VALUES
(%s, tstzrange(%s, %s, '[]'), tstzrange(%s,%s, '[]'))
ON CONFLICT (name) DO UPDATE SET
datetime_range = EXCLUDED.datetime_range,
end_datetime_range = EXCLUDED.end_datetime_range
;
""",
(
partition.collection,
partition.datetime_range_min,
partition.datetime_range_max,
partition.end_datetime_range_min,
partition.end_datetime_range_max,
),
)
logger.debug(
f"Adding or updating partition {partition.name} "
f"took {time.perf_counter() - t}s"
)
partition.requires_update = False
else:
logger.debug(f"Partition {partition.name} does not require an update.")

with conn.transaction():
cur.execute(
"SELECT * FROM partitions WHERE name = %s FOR UPDATE;",
(partition["partition"],),
)
cur.execute(
"""
INSERT INTO partitions
(collection, datetime_range, end_datetime_range)
VALUES
(%s, tstzrange(%s, %s, '[]'), tstzrange(%s,%s, '[]'))
ON CONFLICT (name) DO UPDATE SET
datetime_range = EXCLUDED.datetime_range,
end_datetime_range = EXCLUDED.end_datetime_range
;
""",
(
partition["collection"],
partition["mindt"],
partition["maxdt"],
partition["minedt"],
partition["maxedt"],
),
)
with conn.transaction():
logger.debug(
f"Adding partition {partition} took {time.perf_counter() - t}s"
)
t = time.perf_counter()
if insert_mode is None or insert_mode == "insert":
with cur.copy(
Expand All @@ -277,7 +297,7 @@ def load_partition(
(id, collection, datetime, end_datetime, geometry, content)
FROM stdin;
"""
).format(sql.Identifier(partition["partition"]))
).format(sql.Identifier(partition.name))
) as copy:
for item in items:
item.pop("partition")
Expand Down Expand Up @@ -333,7 +353,7 @@ def load_partition(
"""
LOCK TABLE ONLY {} IN EXCLUSIVE MODE;
"""
).format(sql.Identifier(partition["partition"]))
).format(sql.Identifier(partition.name))
)
if insert_mode in ("ignore", "insert_ignore"):
cur.execute(
Expand All @@ -343,7 +363,7 @@ def load_partition(
SELECT *
FROM items_ingest_temp ON CONFLICT DO NOTHING;
"""
).format(sql.Identifier(partition["partition"]))
).format(sql.Identifier(partition.name))
)
logger.debug(cur.statusmessage)
logger.debug(f"Rows affected: {cur.rowcount}")
Expand All @@ -362,7 +382,7 @@ def load_partition(
WHERE t IS DISTINCT FROM EXCLUDED
;
"""
).format(sql.Identifier(partition["partition"]))
).format(sql.Identifier(partition.name))
)
logger.debug(cur.statusmessage)
logger.debug(f"Rows affected: {cur.rowcount}")
Expand All @@ -385,7 +405,7 @@ def load_partition(
USING (id, collection);
;
"""
).format(sql.Identifier(partition["partition"]))
).format(sql.Identifier(partition.name))
)
logger.debug(cur.statusmessage)
logger.debug(f"Rows affected: {cur.rowcount}")
Expand All @@ -398,8 +418,12 @@ def load_partition(
f"Copying data for {partition} took {time.perf_counter() - t} seconds"
)

def _partition_update(self, item: dict) -> str:
def _partition_update(self, item: Dict[str, Any]) -> str:
"""Update the cached partition with the item information and return the name.
This method will mark the partition as dirty if the bounds of the partition
need to be updated based on this item.
"""
p = item.get("partition", None)
if p is None:
_, key, partition_trunc = self.collection_json(item["collection"])
Expand All @@ -413,37 +437,36 @@ def _partition_update(self, item: dict) -> str:
p = f"_items_{key}"
item["partition"] = p

if self._partition_cache is None:
self._partition_cache = {}

partition = self._partition_cache.get(
item["partition"],
{
"partition": None,
"collection": None,
"mindt": None,
"maxdt": None,
"minedt": None,
"maxedt": None,
},
)

partition["partition"] = item["partition"]
partition["collection"] = item["collection"]
if partition["mindt"] is None or item["datetime"] < partition["mindt"]:
partition["mindt"] = item["datetime"]

if partition["maxdt"] is None or item["datetime"] > partition["maxdt"]:
partition["maxdt"] = item["datetime"]

if partition["minedt"] is None or item["end_datetime"] < partition["minedt"]:
partition["minedt"] = item["end_datetime"]
partition_name: str = p

if partition_name not in self._partition_cache:
partition = Partition(
name=partition_name,
collection=item["collection"],
datetime_range_min=item["datetime"],
datetime_range_max=item["datetime"],
end_datetime_range_min=item["end_datetime"],
end_datetime_range_max=item["end_datetime"],
requires_update=True,
)
else:
partition = self._partition_cache[partition_name]

if item["datetime"] < partition.datetime_range_min:
partition.datetime_range_min = item["datetime"]
partition.requires_update = True
if item["datetime"] > partition.datetime_range_max:
partition.datetime_range_max = item["datetime"]
partition.requires_update = True
if item["end_datetime"] < partition.end_datetime_range_min:
partition.end_datetime_range_min = item["end_datetime"]
partition.requires_update = True
if item["end_datetime"] > partition.end_datetime_range_max:
partition.end_datetime_range_max = item["end_datetime"]

if partition["maxedt"] is None or item["end_datetime"] > partition["maxedt"]:
partition["maxedt"] = item["end_datetime"]
self._partition_cache[item["partition"]] = partition

return p
return partition_name

def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator:
if file is None:
Expand Down Expand Up @@ -497,10 +520,10 @@ def load_items(

logger.debug(f"Adding data to database took {time.perf_counter() - t} seconds.")

def format_item(self, _item: Union[Path, str, dict]) -> dict:
def format_item(self, _item: Union[Path, str, Dict[str, Any]]) -> Dict[str, Any]:
"""Format an item to insert into a record."""
out = {}
item: dict
out: Dict[str, Any] = {}
item: Dict[str, Any]
if not isinstance(_item, dict):
try:
item = orjson.loads(str(_item).replace("\\\\", "\\"))
Expand All @@ -513,11 +536,11 @@ def format_item(self, _item: Union[Path, str, dict]) -> dict:

out["id"] = item.get("id")
out["collection"] = item.get("collection")
properties: dict = item.get("properties", {})
properties: Dict[str, Any] = item.get("properties", {})

dt = properties.get("datetime")
edt = properties.get("end_datetime")
sdt = properties.get("start_datetime")
dt: Optional[str] = properties.get("datetime")
edt: Optional[str] = properties.get("end_datetime")
sdt: Optional[str] = properties.get("start_datetime")

if edt is not None and sdt is not None:
out["datetime"] = sdt
Expand Down Expand Up @@ -548,7 +571,10 @@ def format_item(self, _item: Union[Path, str, dict]) -> dict:
if geojson is None:
geometry = None
else:
geometry = str(Geometry.from_geojson(geojson).wkb)
geom = Geometry.from_geojson(geojson)
if geom is None:
raise Exception(f"Invalid geometry encountered: {geojson}")
geometry = str(geom.wkb)
out["geometry"] = geometry

content = dehydrate(base_item, item)
Expand Down

0 comments on commit 787fc09

Please sign in to comment.