Skip to content

Commit

Permalink
Store statistically equivalent plan nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
pranjalssh authored and rschlussel committed Aug 5, 2022
1 parent be314f5 commit e6470a9
Show file tree
Hide file tree
Showing 47 changed files with 953 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorTableHandle;
import com.facebook.presto.spi.ConnectorTableLayoutHandle;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.SourceLocation;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.plan.PlanNode;
Expand All @@ -34,7 +35,9 @@
import java.util.Objects;
import java.util.Optional;

import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Collections.unmodifiableList;
import static java.util.Collections.unmodifiableMap;
Expand All @@ -55,7 +58,18 @@ public CanonicalTableScanNode(
@JsonProperty("outputVariables") List<VariableReferenceExpression> outputVariables,
@JsonProperty("assignments") Map<VariableReferenceExpression, ColumnHandle> assignments)
{
super(sourceLocation, id);
this(sourceLocation, id, Optional.empty(), table, outputVariables, assignments);
}

public CanonicalTableScanNode(
Optional<SourceLocation> sourceLocation,
PlanNodeId id,
Optional<PlanNode> statsEquivalentPlanNode,
CanonicalTableHandle table,
List<VariableReferenceExpression> outputVariables,
Map<VariableReferenceExpression, ColumnHandle> assignments)
{
super(sourceLocation, id, statsEquivalentPlanNode);
this.table = requireNonNull(table, "table is null");
this.outputVariables = unmodifiableList(requireNonNull(outputVariables, "outputVariables is null"));
this.assignments = unmodifiableMap(new HashMap<>(requireNonNull(assignments, "assignments is null")));
Expand All @@ -82,6 +96,12 @@ public PlanNode replaceChildren(List<PlanNode> newChildren)
return this;
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
throw new PrestoException(GENERIC_INTERNAL_ERROR, format("Cannot assign canonical plan id to Canonical table scan node: %s", this));
}

@JsonProperty
public CanonicalTableHandle getTable()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package com.facebook.presto.sql.planner.iterative;

import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.SourceLocation;
import com.facebook.presto.spi.plan.LogicalProperties;
import com.facebook.presto.spi.plan.PlanNode;
Expand All @@ -25,6 +26,9 @@
import java.util.List;
import java.util.Optional;

import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static java.lang.String.format;

public class GroupReference
extends InternalPlanNode
{
Expand All @@ -34,7 +38,7 @@ public class GroupReference

public GroupReference(Optional<SourceLocation> sourceLocation, PlanNodeId id, int groupId, List<VariableReferenceExpression> outputs, Optional<LogicalProperties> logicalProperties)
{
super(sourceLocation, id);
super(sourceLocation, id, Optional.empty());
this.groupId = groupId;
this.outputs = ImmutableList.copyOf(outputs);
this.logicalProperties = logicalProperties;
Expand Down Expand Up @@ -69,6 +73,12 @@ public PlanNode replaceChildren(List<PlanNode> newChildren)
throw new UnsupportedOperationException();
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
throw new PrestoException(GENERIC_INTERNAL_ERROR, format("Cannot assign canonical plan id to Group Reference node: %s", this));
}

public Optional<LogicalProperties> getLogicalProperties()
{
return logicalProperties;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
public abstract class AbstractJoinNode
extends InternalPlanNode
{
protected AbstractJoinNode(Optional<SourceLocation> sourceLocation, PlanNodeId planNodeId)
protected AbstractJoinNode(Optional<SourceLocation> sourceLocation, PlanNodeId planNodeId, Optional<PlanNode> statsEquivalentPlanNode)
{
super(sourceLocation, planNodeId);
super(sourceLocation, planNodeId, statsEquivalentPlanNode);
}

public abstract Map<String, VariableReferenceExpression> getDynamicFilters();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,20 @@ public ApplyNode(
@JsonProperty("correlation") List<VariableReferenceExpression> correlation,
@JsonProperty("originSubqueryError") String originSubqueryError)
{
super(sourceLocation, id);
this(sourceLocation, id, Optional.empty(), input, subquery, subqueryAssignments, correlation, originSubqueryError);
}

public ApplyNode(
Optional<SourceLocation> sourceLocation,
PlanNodeId id,
Optional<PlanNode> statsEquivalentPlanNode,
PlanNode input,
PlanNode subquery,
Assignments subqueryAssignments,
List<VariableReferenceExpression> correlation,
String originSubqueryError)
{
super(sourceLocation, id, statsEquivalentPlanNode);
checkArgument(input.getOutputVariables().containsAll(correlation), "Input does not contain symbols from correlation");
verifySubquerySupported(subqueryAssignments);

Expand Down Expand Up @@ -145,6 +158,12 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes");
return new ApplyNode(getSourceLocation(), getId(), newChildren.get(0), newChildren.get(1), subqueryAssignments, correlation, originSubqueryError);
return new ApplyNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), newChildren.get(0), newChildren.get(1), subqueryAssignments, correlation, originSubqueryError);
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
return new ApplyNode(getSourceLocation(), getId(), statsEquivalentPlanNode, input, subquery, subqueryAssignments, correlation, originSubqueryError);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,17 @@ public AssignUniqueId(
@JsonProperty("source") PlanNode source,
@JsonProperty("idVariable") VariableReferenceExpression idVariable)
{
super(sourceLocation, id);
this(sourceLocation, id, Optional.empty(), source, idVariable);
}

public AssignUniqueId(
Optional<SourceLocation> sourceLocation,
PlanNodeId id,
Optional<PlanNode> statsEquivalentPlanNode,
PlanNode source,
VariableReferenceExpression idVariable)
{
super(sourceLocation, id, statsEquivalentPlanNode);
this.source = requireNonNull(source, "source is null");
this.idVariable = requireNonNull(idVariable, "idVariable is null");
}
Expand Down Expand Up @@ -92,6 +102,12 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
checkArgument(newChildren.size() == 1, "expected newChildren to contain 1 node");
return new AssignUniqueId(newChildren.get(0).getSourceLocation(), getId(), Iterables.getOnlyElement(newChildren), idVariable);
return new AssignUniqueId(newChildren.get(0).getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), idVariable);
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
return new AssignUniqueId(getSourceLocation(), getId(), statsEquivalentPlanNode, source, idVariable);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,18 @@ public DeleteNode(
@JsonProperty("rowId") VariableReferenceExpression rowId,
@JsonProperty("outputVariables") List<VariableReferenceExpression> outputVariables)
{
super(sourceLocation, id);
this(sourceLocation, id, Optional.empty(), source, rowId, outputVariables);
}

public DeleteNode(
Optional<SourceLocation> sourceLocation,
PlanNodeId id,
Optional<PlanNode> statsEquivalentPlanNode,
PlanNode source,
VariableReferenceExpression rowId,
List<VariableReferenceExpression> outputVariables)
{
super(sourceLocation, id, statsEquivalentPlanNode);

this.source = requireNonNull(source, "source is null");
this.rowId = requireNonNull(rowId, "rowId is null");
Expand Down Expand Up @@ -86,6 +97,12 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
return new DeleteNode(getSourceLocation(), getId(), Iterables.getOnlyElement(newChildren), rowId, outputVariables);
return new DeleteNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), rowId, outputVariables);
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
return new DeleteNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, rowId, outputVariables);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,16 @@ public EnforceSingleRowNode(
@JsonProperty("id") PlanNodeId id,
@JsonProperty("source") PlanNode source)
{
super(sourceLocation, id);
this(sourceLocation, id, Optional.empty(), source);
}

public EnforceSingleRowNode(
Optional<SourceLocation> sourceLocation,
PlanNodeId id,
Optional<PlanNode> statsEquivalentPlanNode,
PlanNode source)
{
super(sourceLocation, id, statsEquivalentPlanNode);

this.source = requireNonNull(source, "source is null");
}
Expand Down Expand Up @@ -73,6 +82,12 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
return new EnforceSingleRowNode(getSourceLocation(), getId(), Iterables.getOnlyElement(newChildren));
return new EnforceSingleRowNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren));
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
return new EnforceSingleRowNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,22 @@ public ExchangeNode(
@JsonProperty("ensureSourceOrdering") boolean ensureSourceOrdering,
@JsonProperty("orderingScheme") Optional<OrderingScheme> orderingScheme)
{
super(sourceLocation, id);
this(sourceLocation, id, Optional.empty(), type, scope, partitioningScheme, sources, inputs, ensureSourceOrdering, orderingScheme);
}

public ExchangeNode(
Optional<SourceLocation> sourceLocation,
PlanNodeId id,
Optional<PlanNode> statsEquivalentPlanNode,
Type type,
Scope scope,
PartitioningScheme partitioningScheme,
List<PlanNode> sources,
List<List<VariableReferenceExpression>> inputs,
boolean ensureSourceOrdering,
Optional<OrderingScheme> orderingScheme)
{
super(sourceLocation, id, statsEquivalentPlanNode);

requireNonNull(type, "type is null");
requireNonNull(scope, "scope is null");
Expand Down Expand Up @@ -314,6 +329,12 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
return new ExchangeNode(getSourceLocation(), getId(), type, scope, partitioningScheme, newChildren, inputs, ensureSourceOrdering, orderingScheme);
return new ExchangeNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), type, scope, partitioningScheme, newChildren, inputs, ensureSourceOrdering, orderingScheme);
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
return new ExchangeNode(getSourceLocation(), getId(), statsEquivalentPlanNode, type, scope, partitioningScheme, sources, inputs, ensureSourceOrdering, orderingScheme);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,18 @@ public ExplainAnalyzeNode(
@JsonProperty("outputVariable") VariableReferenceExpression outputVariable,
@JsonProperty("verbose") boolean verbose)
{
super(sourceLocation, id);
this(sourceLocation, id, Optional.empty(), source, outputVariable, verbose);
}

public ExplainAnalyzeNode(
Optional<SourceLocation> sourceLocation,
PlanNodeId id,
Optional<PlanNode> statsEquivalentPlanNode,
PlanNode source,
VariableReferenceExpression outputVariable,
boolean verbose)
{
super(sourceLocation, id, statsEquivalentPlanNode);
this.source = requireNonNull(source, "source is null");
this.outputVariable = requireNonNull(outputVariable, "outputVariable is null");
this.verbose = verbose;
Expand Down Expand Up @@ -90,6 +101,12 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
return new ExplainAnalyzeNode(getSourceLocation(), getId(), Iterables.getOnlyElement(newChildren), outputVariable, isVerbose());
return new ExplainAnalyzeNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), outputVariable, isVerbose());
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
return new ExplainAnalyzeNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, outputVariable, isVerbose());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,20 @@ public GroupIdNode(
@JsonProperty("aggregationArguments") List<VariableReferenceExpression> aggregationArguments,
@JsonProperty("groupIdVariable") VariableReferenceExpression groupIdVariable)
{
super(sourceLocation, id);
this(sourceLocation, id, Optional.empty(), source, groupingSets, groupingColumns, aggregationArguments, groupIdVariable);
}

public GroupIdNode(
Optional<SourceLocation> sourceLocation,
PlanNodeId id,
Optional<PlanNode> statsEquivalentPlanNode,
PlanNode source,
List<List<VariableReferenceExpression>> groupingSets,
Map<VariableReferenceExpression, VariableReferenceExpression> groupingColumns,
List<VariableReferenceExpression> aggregationArguments,
VariableReferenceExpression groupIdVariable)
{
super(sourceLocation, id, statsEquivalentPlanNode);
this.source = requireNonNull(source);
this.groupingSets = listOfListsCopy(requireNonNull(groupingSets, "groupingSets is null"));
this.groupingColumns = ImmutableMap.copyOf(requireNonNull(groupingColumns));
Expand Down Expand Up @@ -155,7 +168,13 @@ public Set<VariableReferenceExpression> getCommonGroupingColumns()
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
return new GroupIdNode(getSourceLocation(), getId(), Iterables.getOnlyElement(newChildren), groupingSets, groupingColumns, aggregationArguments, groupIdVariable);
return new GroupIdNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), groupingSets, groupingColumns, aggregationArguments, groupIdVariable);
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
return new GroupIdNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, groupingSets, groupingColumns, aggregationArguments, groupIdVariable);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,21 @@ public IndexJoinNode(
@JsonProperty("probeHashVariable") Optional<VariableReferenceExpression> probeHashVariable,
@JsonProperty("indexHashVariable") Optional<VariableReferenceExpression> indexHashVariable)
{
super(sourceLocation, id);
this(sourceLocation, id, Optional.empty(), type, probeSource, indexSource, criteria, probeHashVariable, indexHashVariable);
}

public IndexJoinNode(
Optional<SourceLocation> sourceLocation,
PlanNodeId id,
Optional<PlanNode> statsEquivalentPlanNode,
Type type,
PlanNode probeSource,
PlanNode indexSource,
List<EquiJoinClause> criteria,
Optional<VariableReferenceExpression> probeHashVariable,
Optional<VariableReferenceExpression> indexHashVariable)
{
super(sourceLocation, id, statsEquivalentPlanNode);
this.type = requireNonNull(type, "type is null");
this.probeSource = requireNonNull(probeSource, "probeSource is null");
this.indexSource = requireNonNull(indexSource, "indexSource is null");
Expand Down Expand Up @@ -139,7 +153,13 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes");
return new IndexJoinNode(getSourceLocation(), getId(), type, newChildren.get(0), newChildren.get(1), criteria, probeHashVariable, indexHashVariable);
return new IndexJoinNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), type, newChildren.get(0), newChildren.get(1), criteria, probeHashVariable, indexHashVariable);
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
return new IndexJoinNode(getSourceLocation(), getId(), statsEquivalentPlanNode, type, probeSource, indexSource, criteria, probeHashVariable, indexHashVariable);
}

public static class EquiJoinClause
Expand Down
Loading

0 comments on commit e6470a9

Please sign in to comment.