diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java
index 3cb096ef289e..d72bbba2ce37 100644
--- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java
+++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java
@@ -113,6 +113,7 @@ protected List<OperatorFactory> createOperatorFactories()
                 getColumnTypes("lineitem", "returnflag", "linestatus"),
                 Ints.asList(0, 1),
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 ImmutableList.of(
                         doubleSum.bind(ImmutableList.of(2), Optional.empty()),
diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java
index 2c8514cc11e8..b887ba8cbecd 100644
--- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java
+++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java
@@ -58,6 +58,7 @@ protected List<OperatorFactory> createOperatorFactories()
                 ImmutableList.of(tableTypes.get(0)),
                 Ints.asList(0),
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 ImmutableList.of(doubleSum.bind(ImmutableList.of(1), Optional.empty())),
                 Optional.empty(),
diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
index 3d1c0009535f..2a8586a32fc3 100644
--- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
+++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
@@ -1215,7 +1215,7 @@ public SystemSessionProperties(
                         SEGMENTED_AGGREGATION_ENABLED,
                         "Enable segmented aggregation.",
                         featuresConfig.isSegmentedAggregationEnabled(),
-                        true),
+                        false),
                 new PropertyMetadata<>(
                         AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY,
                         format("Set the strategy used to rewrite AGG IF to AGG FILTER. Options are %s",
diff --git a/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java
index e2ee11c0968b..5ad821b8d82f 100644
--- a/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java
+++ b/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java
@@ -15,6 +15,7 @@
 import com.facebook.presto.common.Page;
 import com.facebook.presto.common.PageBuilder;
+import com.facebook.presto.common.block.Block;
 import com.facebook.presto.common.type.BigintType;
 import com.facebook.presto.common.type.Type;
 import com.facebook.presto.operator.aggregation.Accumulator;
@@ -29,11 +30,14 @@
 import com.facebook.presto.sql.gen.JoinCompiler;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.primitives.Ints;
 import com.google.common.util.concurrent.ListenableFuture;
 import io.airlift.units.DataSize;
 
 import java.util.List;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.stream.Collectors;
 
 import static com.facebook.presto.operator.aggregation.builder.InMemoryHashAggregationBuilder.toTypes;
@@ -41,6 +45,7 @@
 import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE;
 import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.base.Verify.verify;
+import static com.google.common.collect.ImmutableList.toImmutableList;
 import static io.airlift.units.DataSize.Unit.MEGABYTE;
 import static java.util.Objects.requireNonNull;
 
@@ -56,6 +61,8 @@ public static class HashAggregationOperatorFactory
         private final PlanNodeId planNodeId;
         private final List<Type> groupByTypes;
         private final List<Integer> groupByChannels;
+        // A subset of groupByChannels, containing channels that are already sorted.
+        private final List<Integer> preGroupedChannels;
         private final List<Integer> globalAggregationGroupIds;
         private final Step step;
         private final boolean produceDefaultOutput;
@@ -80,6 +87,7 @@ public HashAggregationOperatorFactory(
                 PlanNodeId planNodeId,
                 List<Type> groupByTypes,
                 List<Integer> groupByChannels,
+                List<Integer> preGroupedChannels,
                 List<Integer> globalAggregationGroupIds,
                 Step step,
                 List<AccumulatorFactory> accumulatorFactories,
@@ -94,6 +102,7 @@ public HashAggregationOperatorFactory(
                     planNodeId,
                     groupByTypes,
                     groupByChannels,
+                    preGroupedChannels,
                     globalAggregationGroupIds,
                     step,
                     false,
@@ -117,6 +126,7 @@ public HashAggregationOperatorFactory(
                 PlanNodeId planNodeId,
                 List<Type> groupByTypes,
                 List<Integer> groupByChannels,
+                List<Integer> preGroupedChannels,
                 List<Integer> globalAggregationGroupIds,
                 Step step,
                 boolean produceDefaultOutput,
@@ -135,6 +145,7 @@ public HashAggregationOperatorFactory(
                     planNodeId,
                     groupByTypes,
                     groupByChannels,
+                    preGroupedChannels,
                     globalAggregationGroupIds,
                     step,
                     produceDefaultOutput,
@@ -157,6 +168,7 @@ public HashAggregationOperatorFactory(
                 PlanNodeId planNodeId,
                 List<Type> groupByTypes,
                 List<Integer> groupByChannels,
+                List<Integer> preGroupedChannels,
                 List<Integer> globalAggregationGroupIds,
                 Step step,
                 boolean produceDefaultOutput,
@@ -178,6 +190,7 @@ public HashAggregationOperatorFactory(
             this.groupIdChannel = requireNonNull(groupIdChannel, "groupIdChannel is null");
             this.groupByTypes = ImmutableList.copyOf(groupByTypes);
             this.groupByChannels = ImmutableList.copyOf(groupByChannels);
+            this.preGroupedChannels = ImmutableList.copyOf(preGroupedChannels);
             this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds);
             this.step = step;
             this.produceDefaultOutput = produceDefaultOutput;
@@ -202,6 +215,7 @@ public Operator createOperator(DriverContext driverContext)
                     operatorContext,
                     groupByTypes,
                     groupByChannels,
+                    preGroupedChannels,
                     globalAggregationGroupIds,
                     step,
                     produceDefaultOutput,
@@ -233,6 +247,7 @@ public OperatorFactory duplicate()
                     planNodeId,
                     groupByTypes,
                     groupByChannels,
+                    preGroupedChannels,
                     globalAggregationGroupIds,
                     step,
                     produceDefaultOutput,
@@ -253,6 +268,7 @@ public OperatorFactory duplicate()
     private final OperatorContext operatorContext;
     private final List<Type> groupByTypes;
    private final List<Integer> groupByChannels;
+    private final int[] preGroupedChannels;
     private final List<Integer> globalAggregationGroupIds;
     private final Step step;
     private final boolean produceDefaultOutput;
@@ -267,6 +283,7 @@ public OperatorFactory duplicate()
     private final SpillerFactory spillerFactory;
     private final JoinCompiler joinCompiler;
     private final boolean useSystemMemory;
+    private final Optional<PagesHashStrategy> preGroupedHashStrategy;
 
     private final List<Type> types;
     private final HashCollisionsCounter hashCollisionsCounter;
@@ -276,6 +293,8 @@
     private boolean inputProcessed;
     private boolean finishing;
     private boolean finished;
+    private Page firstUnfinishedSegment;
+    private Page remainingPageForSegmentedAggregation;
 
     // for yield when memory is not available
     private Work<?> unfinishedWork;
@@ -284,6 +303,7 @@ public HashAggregationOperator(
             OperatorContext operatorContext,
             List<Type> groupByTypes,
             List<Integer> groupByChannels,
+            List<Integer> preGroupedChannels,
             List<Integer> globalAggregationGroupIds,
             Step step,
             boolean produceDefaultOutput,
@@ -306,6 +326,7 @@ public HashAggregationOperator(
 
         this.groupByTypes = ImmutableList.copyOf(groupByTypes);
         this.groupByChannels = ImmutableList.copyOf(groupByChannels);
+        this.preGroupedChannels = Ints.toArray(requireNonNull(preGroupedChannels, "preGroupedChannels is null"));
         this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds);
         this.accumulatorFactories = ImmutableList.copyOf(accumulatorFactories);
         this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
@@ -323,6 +344,13 @@ public HashAggregationOperator(
         this.hashCollisionsCounter = new HashCollisionsCounter(operatorContext);
         operatorContext.setInfoSupplier(hashCollisionsCounter);
         this.useSystemMemory = useSystemMemory;
+
+        checkState(ImmutableSet.copyOf(groupByChannels).containsAll(preGroupedChannels), "groupByChannels must include all channels in preGroupedChannels");
+        this.preGroupedHashStrategy = preGroupedChannels.isEmpty()
+                ? Optional.empty()
+                : Optional.of(joinCompiler.compilePagesHashStrategyFactory(
+                        preGroupedChannels.stream().map(groupByTypes::get).collect(toImmutableList()), preGroupedChannels, Optional.empty())
+                        .createPagesHashStrategy(groupByTypes.stream().map(type -> ImmutableList.<Block>of()).collect(toImmutableList()), OptionalInt.empty()));
     }
 
     @Override
@@ -348,13 +376,15 @@ public boolean isFinished()
     // - 2. Current page has been processed.
     // - 3. Aggregation builder has not been triggered or has finished processing.
     // - 4. If this is partial aggregation then it must have not reached the memory limit.
+    // - 5. If running in segmented aggregation mode, there must be no remaining page to process.
     @Override
     public boolean needsInput()
     {
         return !finishing
                 && unfinishedWork == null
                 && outputPages == null
-                && !partialAggregationReachedMemoryLimit();
+                && !partialAggregationReachedMemoryLimit()
+                && remainingPageForSegmentedAggregation == null;
     }
 
     @Override
@@ -366,10 +396,10 @@ public void addInput(Page page)
         inputProcessed = true;
 
         initializeAggregationBuilderIfNeeded();
+        processInputPage(page);
 
         // process the current page; save the unfinished work if we are waiting for memory
-        unfinishedWork = aggregationBuilder.processPage(page);
-        if (unfinishedWork.process()) {
+        if (unfinishedWork != null && unfinishedWork.process()) {
             unfinishedWork = null;
         }
         aggregationBuilder.updateMemory();
@@ -436,6 +466,7 @@ public Page getOutput()
 
         if (outputPages.isFinished()) {
             closeAggregationBuilder();
+            processRemainingPageForSegmentedAggregation();
             return null;
         }
 
@@ -454,6 +485,53 @@ public HashAggregationBuilder getAggregationBuilder()
         return aggregationBuilder;
     }
 
+    private void processInputPage(Page page)
+    {
+        // 1. normal aggregation
+        if (!preGroupedHashStrategy.isPresent()) {
+            unfinishedWork = aggregationBuilder.processPage(page);
+            return;
+        }
+
+        // 2. segmented aggregation
+        if (firstUnfinishedSegment == null) {
+            // If this is the first page, treat the first segment in this page as the current segment.
+            firstUnfinishedSegment = page.getRegion(0, 1);
+        }
+
+        Page pageOnPreGroupedChannels = page.extractChannels(preGroupedChannels);
+        int lastRowInPage = page.getPositionCount() - 1;
+        int lastSegmentStart = findLastSegmentStart(preGroupedHashStrategy.get(), pageOnPreGroupedChannels);
+        if (lastSegmentStart == 0) {
+            // The whole page is in one segment.
+            if (preGroupedHashStrategy.get().rowEqualsRow(0, firstUnfinishedSegment.extractChannels(preGroupedChannels), 0, pageOnPreGroupedChannels)) {
+                // All rows in this page belong to the previous unfinished segment, process the whole page.
+                unfinishedWork = aggregationBuilder.processPage(page);
+            }
+            else {
+                // The current page starts a new segment, flush before processing it.
+                remainingPageForSegmentedAggregation = page;
+            }
+        }
+        else {
+            // The page spans multiple segments: process (and flush) every segment except the last one, which may continue into the next page.
+            unfinishedWork = aggregationBuilder.processPage(page.getRegion(0, lastSegmentStart));
+            remainingPageForSegmentedAggregation = page.getRegion(lastSegmentStart, lastRowInPage - lastSegmentStart + 1);
+        }
+        // Record the last segment of this page as the new first unfinished segment.
+        firstUnfinishedSegment = page.getRegion(lastRowInPage, 1);
+    }
+
+    private int findLastSegmentStart(PagesHashStrategy pagesHashStrategy, Page page)
+    {
+        for (int i = page.getPositionCount() - 1; i > 0; i--) {
+            if (!pagesHashStrategy.rowEqualsRow(i - 1, page, i, page)) {
+                return i;
+            }
+        }
+        return 0;
+    }
+
     private void closeAggregationBuilder()
     {
         outputPages = null;
@@ -468,6 +546,16 @@ private void closeAggregationBuilder()
         operatorContext.localRevocableMemoryContext().setBytes(0);
     }
 
+    private void processRemainingPageForSegmentedAggregation()
+    {
+        // In segmented aggregation mode, reopen the aggregation builder and process the remaining page.
+        if (remainingPageForSegmentedAggregation != null) {
+            initializeAggregationBuilderIfNeeded();
+            unfinishedWork = aggregationBuilder.processPage(remainingPageForSegmentedAggregation);
+            remainingPageForSegmentedAggregation = null;
+        }
+    }
+
     private void initializeAggregationBuilderIfNeeded()
     {
         if (aggregationBuilder != null) {
@@ -509,9 +597,10 @@ private void initializeAggregationBuilderIfNeeded()
     // Flush if one of the following is true:
     // - received finish() signal (no more input to come).
     // - it is a partial aggregation and has reached memory limit
+    // - running in segmented aggregation mode and at least one segment has been fully processed
     private boolean shouldFlush()
     {
-        return finishing || partialAggregationReachedMemoryLimit();
+        return finishing || partialAggregationReachedMemoryLimit() || remainingPageForSegmentedAggregation != null;
     }
 
     private boolean partialAggregationReachedMemoryLimit()
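The core of the change is the segmented path in `processInputPage`: input pages arrive sorted on the pre-grouped channels, so the operator only needs to find where the last segment of each page begins, flush everything before it, and carry the last (possibly unfinished) segment forward. Below is a minimal standalone sketch of that boundary logic, with plain `long[]` arrays standing in for Presto's `Page` and long equality standing in for the compiled `PagesHashStrategy`; the names are illustrative, not Presto APIs.

```java
import java.util.ArrayList;
import java.util.List;

public class SegmentBoundarySketch
{
    // Mirrors findLastSegmentStart: walk backwards until two adjacent rows differ.
    static int findLastSegmentStart(long[] sortedColumn)
    {
        for (int i = sortedColumn.length - 1; i > 0; i--) {
            if (sortedColumn[i - 1] != sortedColumn[i]) {
                return i;
            }
        }
        return 0;
    }

    public static void main(String[] args)
    {
        long[][] pages = {{1, 1, 2, 2}, {2, 2, 3}, {3, 3, 3}};
        List<Long> buffered = new ArrayList<>();   // rows accumulated for the unfinished segment
        Long firstUnfinishedSegment = null;

        for (long[] page : pages) {
            int lastSegmentStart = findLastSegmentStart(page);
            if (lastSegmentStart == 0 && firstUnfinishedSegment != null && firstUnfinishedSegment == page[0]) {
                // Whole page continues the unfinished segment: keep accumulating.
                for (long row : page) {
                    buffered.add(row);
                }
            }
            else {
                // Flush everything before the page's last segment, then buffer the rest.
                for (int i = 0; i < lastSegmentStart; i++) {
                    buffered.add(page[i]);
                }
                System.out.println("flush: " + buffered);
                buffered.clear();
                for (int i = lastSegmentStart; i < page.length; i++) {
                    buffered.add(page[i]);
                }
            }
            firstUnfinishedSegment = page[page.length - 1];
        }
        System.out.println("final flush: " + buffered);
    }
}
```

Run against the pages `{1,1,2,2}`, `{2,2,3}`, `{3,3,3}`, it prints `flush: [1, 1]`, `flush: [2, 2, 2, 2]`, `final flush: [3, 3, 3, 3]`: each segment is flushed exactly once, even when it spans page boundaries, which is what keeps the hash table bounded to one segment's worth of groups.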
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java
index 7750a83dc2c6..da01d8661d3a 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java
@@ -2541,6 +2541,7 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl
                         aggregation.getAggregations(),
                         ImmutableSet.of(),
                         groupingVariables,
+                        ImmutableList.of(),
                         PARTIAL,
                         Optional.empty(),
                         Optional.empty(),
@@ -2646,6 +2647,7 @@ public PhysicalOperation visitTableWriteMerge(TableWriterMergeNode node, LocalEx
                         aggregation.getAggregations(),
                         ImmutableSet.of(),
                         groupingVariables,
+                        ImmutableList.of(),
                         INTERMEDIATE,
                         Optional.empty(),
                         Optional.empty(),
@@ -2700,6 +2702,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl
                         aggregation.getAggregations(),
                         ImmutableSet.of(),
                         groupingVariables,
+                        ImmutableList.of(),
                         FINAL,
                         Optional.empty(),
                         Optional.empty(),
@@ -3075,6 +3078,7 @@ private PhysicalOperation planGroupByAggregation(
                     node.getAggregations(),
                     node.getGlobalGroupingSets(),
                     node.getGroupingKeys(),
+                    node.getPreGroupedVariables(),
                     node.getStep(),
                     node.getHashVariable(),
                     node.getGroupIdVariable(),
@@ -3099,6 +3103,7 @@ private OperatorFactory createHashAggregationOperatorFactory(
             Map<VariableReferenceExpression, Aggregation> aggregations,
             Set<Integer> globalGroupingSets,
             List<VariableReferenceExpression> groupbyVariables,
+            List<VariableReferenceExpression> preGroupedVariables,
             Step step,
             Optional<VariableReferenceExpression> hashVariable,
             Optional<VariableReferenceExpression> groupIdVariable,
@@ -3167,11 +3172,13 @@ private OperatorFactory createHashAggregationOperatorFactory(
         }
         else {
             Optional<Integer> hashChannel = hashVariable.map(variableChannelGetter(source));
+            List<Integer> preGroupedChannels = getChannelsForVariables(preGroupedVariables, source.getLayout());
             return new HashAggregationOperatorFactory(
                     context.getNextOperatorId(),
                     planNodeId,
                     groupByTypes,
                     groupByChannels,
+                    preGroupedChannels,
                     ImmutableList.copyOf(globalGroupingSets),
                     step,
                     hasDefaultOutput,
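On the planner side, the `preGroupedVariables` carried by the `AggregationNode` are resolved to input channel indexes before being handed to the operator factory; the patch does this with `getChannelsForVariables(preGroupedVariables, source.getLayout())`. A hedged sketch of that resolution, with a plain `Map` standing in for the source layout (the `getChannels` helper below is something I wrote for illustration, and `Map.of`/`List.of` assume Java 9+ for brevity):

```java
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.toList;

public class ChannelResolutionSketch
{
    // Simplified stand-in for the planner's variable-to-channel lookup:
    // each grouping variable maps to its position in the source layout.
    static List<Integer> getChannels(List<String> variables, Map<String, Integer> layout)
    {
        return variables.stream().map(layout::get).collect(toList());
    }

    public static void main(String[] args)
    {
        Map<String, Integer> layout = Map.of("orderkey", 0, "shipmode", 1, "quantity", 2);
        // A source sorted on orderkey with GROUP BY (orderkey, shipmode) yields
        // preGroupedChannels = [0], a subset of groupByChannels = [0, 1].
        System.out.println(getChannels(List.of("orderkey"), layout));             // [0]
        System.out.println(getChannels(List.of("orderkey", "shipmode"), layout)); // [0, 1]
    }
}
```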
diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java
new file mode 100644
index 000000000000..7959472e6604
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java
@@ -0,0 +1,249 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator;
+
+import com.facebook.presto.RowPagesBuilder;
+import com.facebook.presto.common.Page;
+import com.facebook.presto.common.block.BlockBuilder;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
+import com.facebook.presto.metadata.MetadataManager;
+import com.facebook.presto.operator.HashAggregationOperator.HashAggregationOperatorFactory;
+import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.plan.PlanNodeId;
+import com.facebook.presto.spiller.SpillerFactory;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.facebook.presto.sql.gen.JoinCompiler;
+import com.facebook.presto.testing.TestingTaskContext;
+import com.google.common.collect.ImmutableList;
+import io.airlift.units.DataSize;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.runner.Runner;
+import org.openjdk.jmh.runner.RunnerException;
+import org.openjdk.jmh.runner.options.Options;
+import org.openjdk.jmh.runner.options.OptionsBuilder;
+import org.openjdk.jmh.runner.options.VerboseMode;
+import org.testng.annotations.Test;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
+import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
+import static com.facebook.presto.block.BlockAssertions.createLongRepeatBlock;
+import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.facebook.presto.operator.BenchmarkHashAndSegmentedAggregationOperators.Context.ROWS_PER_PAGE;
+import static com.facebook.presto.operator.BenchmarkHashAndSegmentedAggregationOperators.Context.TOTAL_PAGES;
+import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
+import static io.airlift.units.DataSize.Unit.GIGABYTE;
+import static io.airlift.units.DataSize.Unit.MEGABYTE;
+import static io.airlift.units.DataSize.succinctBytes;
+import static java.util.concurrent.Executors.newCachedThreadPool;
+import static java.util.concurrent.Executors.newScheduledThreadPool;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.openjdk.jmh.annotations.Mode.AverageTime;
+import static org.openjdk.jmh.annotations.Scope.Thread;
+import static org.testng.Assert.assertEquals;
+
+@State(Thread)
+@OutputTimeUnit(MILLISECONDS)
+@BenchmarkMode(AverageTime)
+@Fork(3)
+@Warmup(iterations = 5)
+@Measurement(iterations = 10, time = 2, timeUnit = SECONDS)
+public class BenchmarkHashAndSegmentedAggregationOperators
+{
+    private static final MetadataManager metadata = MetadataManager.createTestMetadataManager();
+    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = metadata.getFunctionAndTypeManager();
+
+    private static final InternalAggregationFunction LONG_SUM = FUNCTION_AND_TYPE_MANAGER.getAggregateFunctionImplementation(
+            FUNCTION_AND_TYPE_MANAGER.lookupFunction("sum", fromTypes(BIGINT)));
+    private static final InternalAggregationFunction COUNT = FUNCTION_AND_TYPE_MANAGER.getAggregateFunctionImplementation(
+            FUNCTION_AND_TYPE_MANAGER.lookupFunction("count", ImmutableList.of()));
+
+    @State(Thread)
+    public static class Context
+    {
+        public static final int TOTAL_PAGES = 100;
+        public static final int ROWS_PER_PAGE = 1000;
+
+        @Param({"1", "10", "800", "100000"})
+        public int rowsPerSegment;
+
+        @Param({"segmented", "hash"})
+        public String operatorType;
+
+        private ExecutorService executor;
+        private ScheduledExecutorService scheduledExecutor;
+        private OperatorFactory operatorFactory;
+        private List<Page> pages;
+        private int outputRows;
+
+        @Setup
+        public void setup()
+        {
+            executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s"));
+            scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s"));
+            outputRows = 0;
+
+            boolean segmentedAggregation = operatorType.equalsIgnoreCase("segmented");
+            RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(true, ImmutableList.of(0, 1), VARCHAR, BIGINT, BIGINT);
+            for (int i = 0; i < TOTAL_PAGES; i++) {
+                BlockBuilder sortedBlockBuilder = VARCHAR.createBlockBuilder(null, ROWS_PER_PAGE);
+                for (int j = 0; j < ROWS_PER_PAGE; j++) {
+                    int currentSegment = (i * ROWS_PER_PAGE + j) / rowsPerSegment;
+                    VARCHAR.writeString(sortedBlockBuilder, String.valueOf(currentSegment));
+                }
+                outputRows += (ROWS_PER_PAGE - 1) / rowsPerSegment + 1;
+                pagesBuilder.addBlocksPage(sortedBlockBuilder, createLongRepeatBlock(i, ROWS_PER_PAGE), createLongSequenceBlock(0, ROWS_PER_PAGE));
+            }
+
+            pages = pagesBuilder.build();
+            operatorFactory = createHashAggregationOperatorFactory(pagesBuilder.getHashChannel(), segmentedAggregation);
+        }
+
+        private OperatorFactory createHashAggregationOperatorFactory(Optional<Integer> hashChannel, boolean segmentedAggregation)
+        {
+            JoinCompiler joinCompiler = new JoinCompiler(metadata, new FeaturesConfig());
+            SpillerFactory spillerFactory = (types, localSpillContext, aggregatedMemoryContext) -> null;
+
+            return new HashAggregationOperatorFactory(
+                    0,
+                    new PlanNodeId("test"),
+                    ImmutableList.of(VARCHAR, BIGINT),
+                    ImmutableList.of(0, 1),
+                    segmentedAggregation ? ImmutableList.of(0) : ImmutableList.of(),
+                    ImmutableList.of(),
+                    AggregationNode.Step.SINGLE,
+                    false,
+                    ImmutableList.of(COUNT.bind(ImmutableList.of(2), Optional.empty()),
+                            LONG_SUM.bind(ImmutableList.of(2), Optional.empty())),
+                    hashChannel,
+                    Optional.empty(),
+                    100_000,
+                    Optional.of(new DataSize(16, MEGABYTE)),
+                    false,
+                    succinctBytes(8),
+                    succinctBytes(Integer.MAX_VALUE),
+                    spillerFactory,
+                    joinCompiler,
+                    false);
+        }
+
+        public TaskContext createTaskContext()
+        {
+            return TestingTaskContext.createTaskContext(executor, scheduledExecutor, TEST_SESSION, new DataSize(2, GIGABYTE));
+        }
+
+        public OperatorFactory getOperatorFactory()
+        {
+            return operatorFactory;
+        }
+
+        public List<Page> getPages()
+        {
+            return pages;
+        }
+    }
+
+    @Benchmark
+    public List<Page> benchmark(Context context)
+    {
+        DriverContext driverContext = context.createTaskContext().addPipelineContext(0, true, true, false).addDriverContext();
+        Operator operator = context.getOperatorFactory().createOperator(driverContext);
+
+        Iterator<Page> input = context.getPages().iterator();
+        ImmutableList.Builder<Page> outputPages = ImmutableList.builder();
+
+        boolean finishing = false;
+        for (int loops = 0; !operator.isFinished() && loops < 1_000_000; loops++) {
+            if (operator.needsInput()) {
+                if (input.hasNext()) {
+                    Page inputPage = input.next();
+                    operator.addInput(inputPage);
+                }
+                else if (!finishing) {
+                    operator.finish();
+                    finishing = true;
+                }
+            }
+
+            Page outputPage = operator.getOutput();
+            if (outputPage != null) {
+                outputPages.add(outputPage);
+            }
+        }
+
+        return outputPages.build();
+    }
+
+    @Test
+    public void verifyHash()
+    {
+        verify(1, "hash");
+        verify(10, "hash");
+        verify(800, "hash");
+        verify(100000, "hash");
+    }
+
+    @Test
+    public void verifySegmented()
+    {
+        verify(1, "segmented");
+        verify(10, "segmented");
+        verify(800, "segmented");
+        verify(100000, "segmented");
+    }
+
+    private void verify(int rowsPerSegment, String operatorType)
+    {
+        Context context = new Context();
+        context.operatorType = operatorType;
+        context.rowsPerSegment = rowsPerSegment;
+        context.setup();
+
+        assertEquals(TOTAL_PAGES, context.getPages().size());
+        for (int i = 0; i < TOTAL_PAGES; i++) {
+            assertEquals(ROWS_PER_PAGE, context.getPages().get(i).getPositionCount());
+        }
+
+        List<Page> outputPages = benchmark(context);
+        assertEquals(context.outputRows, outputPages.stream().mapToInt(Page::getPositionCount).sum());
+    }
+
+    public static void main(String[] args)
+            throws RunnerException
+    {
+        Options options = new OptionsBuilder()
+                .verbosity(VerboseMode.NORMAL)
+                .include(".*" + BenchmarkHashAndSegmentedAggregationOperators.class.getSimpleName() + ".*")
+                .build();
+
+        new Runner(options).run();
+    }
+}
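The benchmark's `verify` methods check that the operator emits exactly `outputRows` groups, accumulated per page as `(ROWS_PER_PAGE - 1) / rowsPerSegment + 1`. Since the second grouping column is constant within a page, each page contributes one group per distinct segment value it contains, so the closed form must match a direct count. The sketch below cross-checks the two (same constants and parameter values as the benchmark; the class name and structure are mine, written only to illustrate the arithmetic):

```java
public class ExpectedGroupsCheck
{
    public static void main(String[] args)
    {
        int totalPages = 100;   // Context.TOTAL_PAGES
        int rowsPerPage = 1000; // Context.ROWS_PER_PAGE
        for (int rowsPerSegment : new int[] {1, 10, 800, 100000}) {
            int closedForm = 0;
            int direct = 0;
            for (int i = 0; i < totalPages; i++) {
                closedForm += (rowsPerPage - 1) / rowsPerSegment + 1;
                // Count the distinct segment ids actually present in page i:
                // global row indexes [i * rowsPerPage, i * rowsPerPage + rowsPerPage - 1].
                int first = (i * rowsPerPage) / rowsPerSegment;
                int last = (i * rowsPerPage + rowsPerPage - 1) / rowsPerSegment;
                direct += last - first + 1;
            }
            System.out.printf("rowsPerSegment=%d closedForm=%d direct=%d%n", rowsPerSegment, closedForm, direct);
        }
    }
}
```

For all four parameter values the two counts agree (for 800, each 1000-row page always straddles exactly two segments because 1000 mod 800 boundaries never split a page three ways), which is why the simpler per-page formula is safe here.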
diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java
index 12a4d575b8f8..7d84befea2e4 100644
--- a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java
+++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java
@@ -159,6 +159,7 @@ private OperatorFactory createHashAggregationOperatorFactory(Optional<Integer> h
                     ImmutableList.of(VARCHAR),
                     ImmutableList.of(0),
                     ImmutableList.of(),
+                    ImmutableList.of(),
                     AggregationNode.Step.SINGLE,
                     false,
                     ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()),
diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java
index 44f9b939902b..9f9cd8db11c9 100644
--- a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java
+++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java
@@ -176,6 +176,7 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boole
                 ImmutableList.of(VARCHAR),
                 hashChannels,
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 false,
                 ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()),
@@ -228,6 +229,7 @@ public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEna
                 new PlanNodeId("test"),
                 ImmutableList.of(VARCHAR, BIGINT),
                 groupByChannels,
+                ImmutableList.of(),
                 globalAggregationGroupIds,
                 Step.SINGLE,
                 true,
@@ -280,6 +282,7 @@ public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean sp
                 ImmutableList.of(BIGINT),
                 hashChannels,
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 true,
                 ImmutableList.of(arrayAggColumn.bind(ImmutableList.of(0), Optional.empty())),
@@ -322,6 +325,7 @@ public void testMemoryLimit(boolean hashEnabled)
                 ImmutableList.of(BIGINT),
                 hashChannels,
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()),
                         LONG_SUM.bind(ImmutableList.of(3), Optional.empty()),
@@ -360,6 +364,7 @@ public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boo
                 ImmutableList.of(VARCHAR),
                 hashChannels,
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 false,
                 ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty())),
@@ -387,6 +392,7 @@ public void testMemoryReservationYield(Type type)
                 ImmutableList.of(type),
                 ImmutableList.of(0),
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty())),
                 Optional.of(1),
@@ -439,6 +445,7 @@ public void testHashBuilderResizeLimit(boolean hashEnabled)
                 ImmutableList.of(VARCHAR),
                 hashChannels,
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty())),
                 rowPagesBuilder.getHashChannel(),
@@ -472,6 +479,7 @@ public void testMultiSliceAggregationOutput(boolean hashEnabled)
                 ImmutableList.of(BIGINT),
                 hashChannels,
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()),
                         LONG_AVERAGE.bind(ImmutableList.of(1), Optional.empty())),
@@ -504,6 +512,7 @@ public void testMultiplePartialFlushes(boolean hashEnabled)
                 ImmutableList.of(BIGINT),
                 hashChannels,
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.PARTIAL,
                 ImmutableList.of(LONG_SUM.bind(ImmutableList.of(0), Optional.empty())),
                 rowPagesBuilder.getHashChannel(),
@@ -585,6 +594,7 @@ public void testMergeWithMemorySpill()
                 ImmutableList.of(BIGINT),
                 ImmutableList.of(0),
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 false,
                 ImmutableList.of(LONG_SUM.bind(ImmutableList.of(0), Optional.empty())),
@@ -636,6 +646,7 @@ public void testSpillerFailure()
                 ImmutableList.of(BIGINT),
                 hashChannels,
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 false,
                 ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()),
@@ -679,6 +690,7 @@ public void testMask()
                 ImmutableList.of(BIGINT),
                 ImmutableList.of(0),
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 false,
                 ImmutableList.of(COUNT.bind(ImmutableList.of(1), Optional.of(2))),
@@ -723,6 +735,7 @@ private void testMemoryTracking(boolean useSystemMemory)
                 ImmutableList.of(BIGINT),
                 hashChannels,
                 ImmutableList.of(),
+                ImmutableList.of(),
                 Step.SINGLE,
                 ImmutableList.of(LONG_SUM.bind(ImmutableList.of(0), Optional.empty())),
                 rowPagesBuilder.getHashChannel(),
diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java
new file mode 100644
index 000000000000..425c8cccff4a
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java
@@ -0,0 +1,219 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator;
+
+import com.facebook.presto.common.Page;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.LongArrayBlock;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
+import com.facebook.presto.metadata.MetadataManager;
+import com.facebook.presto.operator.HashAggregationOperator.HashAggregationOperatorFactory;
+import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
+import com.facebook.presto.spi.plan.AggregationNode.Step;
+import com.facebook.presto.spi.plan.PlanNodeId;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.facebook.presto.sql.gen.JoinCompiler;
+import com.facebook.presto.testing.MaterializedResult;
+import com.facebook.presto.testing.TestingTaskContext;
+import com.google.common.collect.ImmutableList;
+import io.airlift.units.DataSize;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
+import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.operator.OperatorAssertion.assertPagesEqualIgnoreOrder;
+import static com.facebook.presto.operator.OperatorAssertion.toPages;
+import static com.facebook.presto.testing.MaterializedResult.resultBuilder;
+import static io.airlift.units.DataSize.Unit.MEGABYTE;
+import static io.airlift.units.DataSize.succinctBytes;
+import static java.util.concurrent.Executors.newCachedThreadPool;
+import static java.util.concurrent.Executors.newScheduledThreadPool;
+import static org.testng.Assert.assertEquals;
+
+@Test(singleThreaded = true)
+public class TestHashAggregationOperatorInSegmentedAggregationMode
+{
+    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
+
+    private static final InternalAggregationFunction COUNT = FUNCTION_AND_TYPE_MANAGER.getAggregateFunctionImplementation(
+            FUNCTION_AND_TYPE_MANAGER.lookupFunction("count", ImmutableList.of()));
+
+    private ExecutorService executor;
+    private ScheduledExecutorService scheduledExecutor;
+    private JoinCompiler joinCompiler = new JoinCompiler(MetadataManager.createTestMetadataManager(), new FeaturesConfig());
+
+    private HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory(
+            0,
+            new PlanNodeId("test"),
+            ImmutableList.of(BIGINT, BIGINT),
+            ImmutableList.of(0, 1),
+            ImmutableList.of(0),
+            ImmutableList.of(),
+            Step.SINGLE,
+            false,
+            ImmutableList.of(COUNT.bind(ImmutableList.of(2), Optional.empty())),
+            Optional.empty(),
+            Optional.empty(),
+            4,
+            Optional.of(new DataSize(16, MEGABYTE)),
+            false,
+            new DataSize(16, MEGABYTE),
+            new DataSize(16, MEGABYTE),
+            new DummySpillerFactory(),
+            joinCompiler,
+            false);
+
+    @BeforeMethod
+    public void setUp()
+    {
+        executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s"));
+        scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s"));
+    }
+
+    @AfterMethod
+    public void tearDown()
+    {
+        executor.shutdownNow();
+        scheduledExecutor.shutdownNow();
+    }
+
+    @Test
+    public void testSegmentedAggregationSinglePage()
+    {
+        int numberOfRows = 10;
+        Block sortedBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 1, 1, 2, 2, 2, 3, 3, 3, 3});
+        Block groupingBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 1, 1, 2, 1, 2, 2, 1, 2, 1});
+        Block countBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
+        Page inputPage = new Page(sortedBlock, groupingBlock, countBlock);
+
+        DriverContext driverContext = createDriverContext();
+        MaterializedResult.Builder expectedBuilder = resultBuilder(driverContext.getSession(), BIGINT, BIGINT, BIGINT);
+        expectedBuilder.row(1L, 1L, 3L);
+        expectedBuilder.row(2L, 1L, 1L);
+        expectedBuilder.row(2L, 2L, 2L);
+        expectedBuilder.row(3L, 1L, 2L);
+        expectedBuilder.row(3L, 2L, 2L);
+
+        MaterializedResult expected = expectedBuilder.build();
+        List<Page> outputPages = toPages(operatorFactory, driverContext, ImmutableList.of(inputPage));
+        assertEquals(outputPages.size(), 2);
+        assertPagesEqualIgnoreOrder(driverContext, outputPages, expected, true, Optional.empty());
+    }
+
+    @Test
+    public void testSegmentedAggregationSingleSegment()
+    {
+        int numberOfRows = 5;
+
+        Block sortedBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 1, 1, 1, 1});
+        Block groupingBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 1, 2, 1});
+        Block countBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 3, 4, 5});
+        Page inputPage1 = new Page(sortedBlock, groupingBlock, countBlock);
+
+        sortedBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 1, 1, 1, 1});
+        groupingBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 1, 2, 1});
+        countBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 3, 4, 5});
+        Page inputPage2 = new Page(sortedBlock, groupingBlock, countBlock);
+
+        sortedBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 1, 1, 1, 1});
+        groupingBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 1, 2, 1});
+        countBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 3, 4, 5});
+        Page inputPage3 = new Page(sortedBlock, groupingBlock, countBlock);
+
+        DriverContext driverContext = createDriverContext();
+        MaterializedResult.Builder expectedBuilder = resultBuilder(driverContext.getSession(), BIGINT, BIGINT, BIGINT);
+        expectedBuilder.row(1L, 1L, 9L);
+        expectedBuilder.row(1L, 2L, 6L);
+
+        MaterializedResult expected = expectedBuilder.build();
+        List<Page> outputPages = toPages(operatorFactory, driverContext, ImmutableList.of(inputPage1, inputPage2, inputPage3));
+        assertEquals(outputPages.size(), 1);
+        assertPagesEqualIgnoreOrder(driverContext, outputPages, expected, true, Optional.empty());
+    }
+
+    @Test
+    public void testSegmentedAggregationMultiplePages()
+    {
+        int numberOfRows = 5;
+
+        Block sortedBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 1, 1, 2, 2});
+        Block groupingBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 1, 2, 1});
+        Block countBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 3, 4, 5});
+        Page inputPage1 = new Page(sortedBlock, groupingBlock, countBlock);
+
+        sortedBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {2, 2, 2, 2, 2});
+        groupingBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 1, 2, 1});
+        countBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 3, 4, 5});
+        Page inputPage2 = new Page(sortedBlock, groupingBlock, countBlock);
+
+        sortedBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {2, 2, 3, 3, 4});
+        groupingBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 1, 2, 1});
+        countBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 3, 4, 5});
+        Page inputPage3 = new Page(sortedBlock, groupingBlock, countBlock);
+
+        sortedBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {5, 5, 5, 5, 5});
+        groupingBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 1, 2, 1});
+        countBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 3, 4, 5});
+        Page inputPage4 = new Page(sortedBlock, groupingBlock, countBlock);
+
+        sortedBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {5, 6, 7, 8, 8});
+        groupingBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 1, 2, 1});
+        countBlock = new LongArrayBlock(numberOfRows, Optional.of(new boolean[numberOfRows]), new long[] {1, 2, 3, 4, 5});
+        Page inputPage5 = new Page(sortedBlock, groupingBlock, countBlock);
+
+        DriverContext driverContext = createDriverContext();
+        MaterializedResult.Builder expectedBuilder = resultBuilder(driverContext.getSession(), BIGINT, BIGINT, BIGINT);
+        expectedBuilder.row(1L, 1L, 2L);
+        expectedBuilder.row(1L, 2L, 1L);
+        expectedBuilder.row(2L, 1L, 5L);
+        expectedBuilder.row(2L, 2L, 4L);
+        expectedBuilder.row(3L, 1L, 1L);
+        expectedBuilder.row(3L, 2L, 1L);
+        expectedBuilder.row(4L, 1L, 1L);
+        expectedBuilder.row(5L, 1L, 4L);
+        expectedBuilder.row(5L, 2L, 2L);
+        expectedBuilder.row(6L, 2L, 1L);
+        expectedBuilder.row(7L, 1L, 1L);
+        expectedBuilder.row(8L, 1L, 1L);
+        expectedBuilder.row(8L, 2L, 1L);
+
+        MaterializedResult expected = expectedBuilder.build();
+        List<Page> outputPages = toPages(operatorFactory, driverContext, ImmutableList.of(inputPage1, inputPage2, inputPage3, inputPage4, inputPage5));
+        // segment 1: [1] | segment 2: [2 - 3] | segment 3: [4] | segment 4: [5 - 7] | segment 5: [8]
+        assertEquals(outputPages.size(), 5);
+        assertPagesEqualIgnoreOrder(driverContext, outputPages, expected, true, Optional.empty());
+    }
+
+    private DriverContext createDriverContext()
+    {
+        return createDriverContext(Integer.MAX_VALUE);
+    }
+
+    private DriverContext createDriverContext(long memoryLimit)
+    {
+        return TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION)
+                .setMemoryPoolSize(succinctBytes(memoryLimit))
+                .build()
+                .addPipelineContext(0, true, true, false)
+                .addDriverContext();
+    }
+}
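The `assertEquals(outputPages.size(), 5)` assertion in the multi-page test counts one output page per flush. A speculative sketch of that accounting, reducing each input page to its sorted column (plain `long[]`s; this models the flush conditions described above rather than reusing the operator's code, so treat it as an aid to reading the test, not as Presto behavior):

```java
public class FlushCountSketch
{
    // Same backwards scan as the operator's findLastSegmentStart, on one sorted column.
    static int lastSegmentStart(long[] sortedColumn)
    {
        for (int i = sortedColumn.length - 1; i > 0; i--) {
            if (sortedColumn[i - 1] != sortedColumn[i]) {
                return i;
            }
        }
        return 0;
    }

    static int countFlushes(long[][] pages)
    {
        int flushes = 0;
        Long firstUnfinished = null;
        boolean builderHasRows = false;
        for (long[] page : pages) {
            int start = lastSegmentStart(page);
            boolean continuesPrevious = firstUnfinished != null && firstUnfinished == page[0];
            if (start > 0 || (!continuesPrevious && builderHasRows)) {
                flushes++; // earlier segments are complete, so the builder is flushed
            }
            builderHasRows = true; // the page's last segment stays buffered
            firstUnfinished = page[page.length - 1];
        }
        return flushes + (builderHasRows ? 1 : 0); // final flush on finish()
    }

    public static void main(String[] args)
    {
        // The sorted (first) column of the five pages in testSegmentedAggregationMultiplePages.
        long[][] pages = {
                {1, 1, 1, 2, 2},
                {2, 2, 2, 2, 2},
                {2, 2, 3, 3, 4},
                {5, 5, 5, 5, 5},
                {5, 6, 7, 8, 8},
        };
        System.out.println(countFlushes(pages)); // prints 5, matching the test's expected page count
    }
}
```

The five flushes line up with the comment in the test: key 1 is flushed when page 1 reveals key 2, keys 2 and 3 when page 3 reveals key 4, key 4 when page 4 opens with key 5, keys 5 through 7 when page 5 reveals key 8, and key 8 on `finish()`.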