Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#2271 #2272

Merged
merged 3 commits into from
Jan 18, 2025
Merged

#2271 #2272

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 994
__build__ = 996

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
1 change: 1 addition & 0 deletions opteryx/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .aggregate_node import AGGREGATORS
from .aggregate_node import AggregateNode # aggregate data
from .async_read_node import AsyncReaderNode
from .simple_aggregate_node import SimpleAggregateNode # aggregate data

# from .build_statistics_node import BuildStatisticsNode # Analyze Tables
from .cross_join_node import CrossJoinNode # CROSS JOIN
Expand Down
154 changes: 154 additions & 0 deletions opteryx/operators/simple_aggregate_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# See the License at http://www.apache.org/licenses/LICENSE-2.0
# Distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND.

"""
Simple Aggregation Node

This is a SQL Query Execution Plan Node.

This node performs aggregates without performing groupings, this is a specialized version
which focuses on building aggregates which don't require seeing the entire dataset at a time.

We avoid doing some work by not creating entire columns of data where possible.
"""

import pyarrow

from opteryx import EOS
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import evaluate_and_append
from opteryx.models import QueryProperties
from opteryx.operators.aggregate_node import extract_evaluations

from . import BasePlanNode


class SimpleAggregateCollector:
    """Incrementally accumulates a single aggregate (SUM, AVG, MIN, MAX or COUNT).

    Morsels are folded in one at a time, so the full dataset never needs to be
    held in memory at once.
    """

    def __init__(self, aggregate_type, column_id, count_nulls=False):
        """
        Parameters:
            aggregate_type: str
                One of "SUM", "AVG", "MIN", "MAX", "COUNT".
            column_id: str
                Identity of the schema column this collector produces.
            count_nulls: bool
                When True and the aggregate is COUNT, nulls are included in the
                count. Ignored for every other aggregate type.
        """
        self.aggregate_type = aggregate_type
        # running aggregate value; stays None until a non-null value is seen
        self.current_value = None
        self.count_nulls = count_nulls if aggregate_type == "COUNT" else False
        # running row count - the COUNT result and the AVG divisor
        self.counter = 0
        self.column_id = column_id

    def collect(self, values):
        """Fold a pyarrow array of column values into the running aggregate.

        Raises:
            ValueError: if the aggregate type is not supported.
        """
        if self.count_nulls:
            # mode="all" counts every row including nulls; the previous code used
            # the default mode ("only_valid"), which made count_nulls a no-op
            self.counter += pyarrow.compute.count(values, mode="all").as_py()
        else:
            self.counter += pyarrow.compute.count(values, mode="only_valid").as_py()

        if self.aggregate_type == "COUNT":
            return  # COUNT needs only the counter

        if self.aggregate_type in ("SUM", "AVG"):
            partial = pyarrow.compute.sum(values).as_py()
        elif self.aggregate_type == "MIN":
            partial = pyarrow.compute.min(values).as_py()
        elif self.aggregate_type == "MAX":
            partial = pyarrow.compute.max(values).as_py()
        else:
            raise ValueError(f"Unsupported aggregate type: {self.aggregate_type}")

        # an empty or all-null batch contributes nothing; skipping it also avoids
        # a TypeError from combining None with the running value
        if partial is None:
            return

        if self.current_value is None:
            self.current_value = partial
        elif self.aggregate_type in ("SUM", "AVG"):
            self.current_value += partial
        elif self.aggregate_type == "MIN":
            self.current_value = min(self.current_value, partial)
        else:  # MAX
            self.current_value = max(self.current_value, partial)

    def collect_literal(self, literal, count):
        """Fold `count` repetitions of a literal value into the running aggregate.

        Raises:
            ValueError: if the aggregate type is not supported.
        """
        self.counter += count

        if self.aggregate_type == "COUNT":
            return
        if self.aggregate_type not in ("SUM", "AVG", "MIN", "MAX"):
            raise ValueError(f"Unsupported aggregate type: {self.aggregate_type}")
        if literal is None:
            # a null literal adds to the row count but contributes no value
            return

        if self.aggregate_type in ("SUM", "AVG"):
            contribution = literal * count
            if self.current_value is None:
                self.current_value = contribution
            else:
                self.current_value += contribution
        elif self.current_value is None:
            self.current_value = literal
        elif self.aggregate_type == "MIN":
            self.current_value = min(self.current_value, literal)
        else:  # MAX
            self.current_value = max(self.current_value, literal)

    def get_result(self):
        """Return the final aggregate value (None for an empty/all-null AVG)."""
        if self.aggregate_type == "AVG":
            if self.counter == 0 or self.current_value is None:
                return None
            return self.current_value / self.counter
        if self.aggregate_type == "COUNT":
            return self.counter
        return self.current_value


class SimpleAggregateNode(BasePlanNode):
    """Aggregation without grouping, limited to aggregates which can be folded
    incrementally (one morsel at a time) without materialising whole columns."""

    SIMPLE_AGGREGATES = {"SUM", "MIN", "MAX", "AVG", "COUNT"}

    def __init__(self, properties: QueryProperties, **parameters):
        BasePlanNode.__init__(self, properties=properties, **parameters)

        self.aggregates = parameters.get("aggregates", [])

        # functions referenced by the aggregates must be evaluated first
        self.evaluatable_nodes = extract_evaluations(self.aggregates)

        # one collector per output column, keyed by the column's identity
        self.accumulator = {
            aggregate.schema_column.identity: SimpleAggregateCollector(
                aggregate.value, aggregate.schema_column.identity
            )
            for aggregate in self.aggregates
        }

    @property
    def config(self):  # pragma: no cover
        return str(self.aggregates)

    @property
    def name(self):  # pragma: no cover
        return "Aggregation (Simple)"

    def execute(self, morsel: pyarrow.Table, **kwargs) -> pyarrow.Table:
        if morsel == EOS:
            # emit a single-row table holding every aggregate's final value
            column_names = list(self.accumulator.keys())
            column_values = [[collector.get_result()] for collector in self.accumulator.values()]
            yield pyarrow.Table.from_arrays(column_values, names=column_names)
            yield EOS
            return

        # evaluate any functions the aggregates depend on before collecting
        if self.evaluatable_nodes:
            morsel = evaluate_and_append(self.evaluatable_nodes, morsel)

        for aggregate in self.aggregates:
            if aggregate.node_type not in (NodeType.AGGREGATOR,):
                continue

            collector = self.accumulator[aggregate.schema_column.identity]
            source = aggregate.parameters[0]

            if source.node_type == NodeType.LITERAL:
                # a literal aggregand repeats once per row
                collector.collect_literal(source.value, morsel.num_rows)
            elif source.node_type == NodeType.WILDCARD:
                # a "$COUNT(*)" column, when present, carries a precomputed row count
                if "$COUNT(*)" in morsel.column_names:
                    collector.collect_literal(1, morsel["$COUNT(*)"][0].as_py())
                else:
                    collector.collect_literal(1, morsel.num_rows)
            else:
                collector.collect(morsel[source.schema_column.identity])
9 changes: 4 additions & 5 deletions opteryx/planner/physical_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# Distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND.


from orso.schema import OrsoTypes

from opteryx import operators as operators
from opteryx.exceptions import InvalidInternalStateError
from opteryx.exceptions import UnsupportedSyntaxError
Expand All @@ -23,11 +21,12 @@ def create_physical_plan(logical_plan, query_properties) -> PhysicalPlan:

# fmt: off
if node_type == LogicalPlanStepType.Aggregate:
node = operators.AggregateNode(query_properties, **{k:v for k,v in node_config.items() if k in ("aggregates", "all_relations")})
if all(agg.value in operators.SimpleAggregateNode.SIMPLE_AGGREGATES for agg in node_config["aggregates"]):
node = operators.SimpleAggregateNode(query_properties, **{k:v for k,v in node_config.items() if k in ("aggregates", "all_relations")})
else:
node = operators.AggregateNode(query_properties, **{k:v for k,v in node_config.items() if k in ("aggregates", "all_relations")})
elif node_type == LogicalPlanStepType.AggregateAndGroup:
node = operators.AggregateAndGroupNode(query_properties, **{k:v for k,v in node_config.items() if k in ("aggregates", "groups", "projection", "all_relations")})
# elif node_type == LogicalPlanStepType.Defragment:
# node = operators.MorselDefragmentNode(query_properties, **node_config)
elif node_type == LogicalPlanStepType.Distinct:
node = operators.DistinctNode(query_properties, **node_config)
elif node_type == LogicalPlanStepType.Exit:
Expand Down
41 changes: 41 additions & 0 deletions tests/query_execution/test_agg_avg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import opteryx

def test_avg_parquet():
    """AVG from a parquet source: only the aggregated column should be read."""
    cur = opteryx.query("SELECT AVG(followers) FROM testdata.flat.formats.parquet")
    stats = cur.stats
    for metric, expected in (("columns_read", 1), ("rows_read", 100000), ("rows_seen", 100000)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 3081.258, first

def test_avg_non_parquet():
    """AVG over a non-parquet source still reads only the aggregated column."""
    cur = opteryx.query("SELECT AVG(followers) FROM testdata.flat.ten_files;")
    stats = cur.stats
    for metric, expected in (("columns_read", 1), ("rows_read", 250), ("rows_seen", 250)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 7500362.668, first

def test_avg_group_by():
    """Grouped AVG reads the grouping column too (2 columns, all rows)."""
    cur = opteryx.query(
        "SELECT AVG(followers) FROM testdata.flat.formats.parquet GROUP BY tweet_id ORDER BY tweet_id;"
    )
    stats = cur.stats
    for metric, expected in (("columns_read", 2), ("rows_read", 100000), ("rows_seen", 100000)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 6.0, first

if __name__ == "__main__": # pragma: no cover
from tests.tools import run_tests

run_tests()
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def test_count_star_parquet():
assert stats["columns_read"] <= 1, stats["columns_read"]
assert stats["rows_read"] == 1, stats["rows_read"]
assert stats["rows_seen"] == 100000, stats["rows_seen"]
assert cur.fetchone()[0] == 100000
first = cur.fetchone()[0]
assert first == 100000, first

def test_count_star_non_parquet():
""" if is just SELECT COUNT(*) for non-parquet files, we read the rows"""
Expand All @@ -38,7 +39,8 @@ def test_count_star_parquest_read_the_rows():
assert stats["columns_read"] <= 1, stats["columns_read"]
assert stats["rows_read"] == 100000, stats["rows_read"]
assert stats["rows_seen"] == 100000, stats["rows_seen"]
assert cur.fetchone()[0] == 100000
first = cur.fetchone()[0]
assert first == 100000, first

def test_count_star_group_by():
""" we're reading data from the file, even though it starts SELECT COUNT(*) FROM """
Expand Down
41 changes: 41 additions & 0 deletions tests/query_execution/test_agg_max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import opteryx

def test_max_parquet():
    """MAX from a parquet source: only the aggregated column should be read."""
    cur = opteryx.query("SELECT MAX(followers) FROM testdata.flat.formats.parquet")
    stats = cur.stats
    for metric, expected in (("columns_read", 1), ("rows_read", 100000), ("rows_seen", 100000)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 8266250, first

def test_max_non_parquet():
    """MAX over a non-parquet source still reads only the aggregated column."""
    cur = opteryx.query("SELECT MAX(followers) FROM testdata.flat.ten_files;")
    stats = cur.stats
    for metric, expected in (("columns_read", 1), ("rows_read", 250), ("rows_seen", 250)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 10415949, first

def test_max_group_by():
    """Grouped MAX reads the grouping column too (2 columns, all rows)."""
    cur = opteryx.query(
        "SELECT MAX(followers) FROM testdata.flat.formats.parquet GROUP BY tweet_id ORDER BY tweet_id;"
    )
    stats = cur.stats
    for metric, expected in (("columns_read", 2), ("rows_read", 100000), ("rows_seen", 100000)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 6.0, first

if __name__ == "__main__": # pragma: no cover
from tests.tools import run_tests

run_tests()
41 changes: 41 additions & 0 deletions tests/query_execution/test_agg_min.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import opteryx

def test_min_parquet():
    """MIN from a parquet source: only the aggregated column should be read."""
    cur = opteryx.query("SELECT MIN(followers) FROM testdata.flat.formats.parquet")
    stats = cur.stats
    for metric, expected in (("columns_read", 1), ("rows_read", 100000), ("rows_seen", 100000)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 0, first

def test_min_non_parquet():
    """MIN over a non-parquet source still reads only the aggregated column."""
    cur = opteryx.query("SELECT MIN(followers) FROM testdata.flat.ten_files;")
    stats = cur.stats
    for metric, expected in (("columns_read", 1), ("rows_read", 250), ("rows_seen", 250)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 100, first

def test_min_group_by():
    """Grouped MIN reads the grouping column too (2 columns, all rows)."""
    cur = opteryx.query(
        "SELECT MIN(followers) FROM testdata.flat.formats.parquet GROUP BY tweet_id ORDER BY tweet_id;"
    )
    stats = cur.stats
    for metric, expected in (("columns_read", 2), ("rows_read", 100000), ("rows_seen", 100000)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 6.0, first

if __name__ == "__main__": # pragma: no cover
from tests.tools import run_tests

run_tests()
41 changes: 41 additions & 0 deletions tests/query_execution/test_agg_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import opteryx

def test_sum_parquet():
    """SUM from a parquet source: only the aggregated column should be read."""
    cur = opteryx.query("SELECT SUM(followers) FROM testdata.flat.formats.parquet")
    stats = cur.stats
    for metric, expected in (("columns_read", 1), ("rows_read", 100000), ("rows_seen", 100000)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 308125800, first

def test_sum_non_parquet():
    """SUM over a non-parquet source still reads only the aggregated column."""
    cur = opteryx.query("SELECT SUM(followers) FROM testdata.flat.ten_files;")
    stats = cur.stats
    for metric, expected in (("columns_read", 1), ("rows_read", 250), ("rows_seen", 250)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 1875090667, first

def test_sum_group_by():
    """Grouped SUM reads the grouping column too (2 columns, all rows)."""
    cur = opteryx.query(
        "SELECT SUM(followers) FROM testdata.flat.formats.parquet GROUP BY tweet_id ORDER BY tweet_id;"
    )
    stats = cur.stats
    for metric, expected in (("columns_read", 2), ("rows_read", 100000), ("rows_seen", 100000)):
        assert stats[metric] == expected, stats[metric]
    first = cur.fetchone()[0]
    assert first == 6.0, first

if __name__ == "__main__": # pragma: no cover
from tests.tools import run_tests

run_tests()
Loading