Skip to content

Commit

Permalink
Adds DML to the AST Rewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn committed Dec 6, 2024
1 parent 1e1ccb3 commit 5d525fe
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 2 deletions.
40 changes: 39 additions & 1 deletion partiql-ast/api/partiql-ast.api
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,24 @@ public abstract class org/partiql/ast/AstRewriter : org/partiql/ast/AstVisitor {
public fun <init> ()V
public synthetic fun defaultReturn (Lorg/partiql/ast/AstNode;Ljava/lang/Object;)Ljava/lang/Object;
public fun defaultReturn (Lorg/partiql/ast/AstNode;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitConflictActionDoNothing (Lorg/partiql/ast/ConflictAction$DoNothing;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitConflictActionDoNothing (Lorg/partiql/ast/ConflictAction$DoNothing;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitConflictActionDoReplace (Lorg/partiql/ast/ConflictAction$DoReplace;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitConflictActionDoReplace (Lorg/partiql/ast/ConflictAction$DoReplace;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitConflictActionDoUpdate (Lorg/partiql/ast/ConflictAction$DoUpdate;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitConflictActionDoUpdate (Lorg/partiql/ast/ConflictAction$DoUpdate;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitConflictTargetConstraint (Lorg/partiql/ast/ConflictTarget$Constraint;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitConflictTargetConstraint (Lorg/partiql/ast/ConflictTarget$Constraint;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitConflictTargetIndex (Lorg/partiql/ast/ConflictTarget$Index;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitConflictTargetIndex (Lorg/partiql/ast/ConflictTarget$Index;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitDdl (Lorg/partiql/ast/ddl/Ddl;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitDdl (Lorg/partiql/ast/ddl/Ddl;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitDelete (Lorg/partiql/ast/Delete;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitDelete (Lorg/partiql/ast/Delete;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitDoReplaceActionExcluded (Lorg/partiql/ast/DoReplaceAction$Excluded;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitDoReplaceActionExcluded (Lorg/partiql/ast/DoReplaceAction$Excluded;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitDoUpdateActionExcluded (Lorg/partiql/ast/DoUpdateAction$Excluded;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitDoUpdateActionExcluded (Lorg/partiql/ast/DoUpdateAction$Excluded;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitExclude (Lorg/partiql/ast/Exclude;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitExclude (Lorg/partiql/ast/Exclude;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitExcludePath (Lorg/partiql/ast/ExcludePath;Ljava/lang/Object;)Ljava/lang/Object;
Expand Down Expand Up @@ -250,10 +266,18 @@ public abstract class org/partiql/ast/AstRewriter : org/partiql/ast/AstVisitor {
public fun visitIdentifier (Lorg/partiql/ast/Identifier;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitIdentifierChain (Lorg/partiql/ast/IdentifierChain;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitIdentifierChain (Lorg/partiql/ast/IdentifierChain;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitInsert (Lorg/partiql/ast/Insert;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitInsert (Lorg/partiql/ast/Insert;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitInsertSourceFromDefault (Lorg/partiql/ast/InsertSource$FromDefault;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitInsertSourceFromDefault (Lorg/partiql/ast/InsertSource$FromDefault;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitInsertSourceFromExpr (Lorg/partiql/ast/InsertSource$FromExpr;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitInsertSourceFromExpr (Lorg/partiql/ast/InsertSource$FromExpr;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitLet (Lorg/partiql/ast/Let;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitLet (Lorg/partiql/ast/Let;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitLetBinding (Lorg/partiql/ast/Let$Binding;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitLetBinding (Lorg/partiql/ast/Let$Binding;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitOnConflict (Lorg/partiql/ast/OnConflict;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitOnConflict (Lorg/partiql/ast/OnConflict;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitPathStepAllElements (Lorg/partiql/ast/expr/PathStep$AllElements;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitPathStepAllElements (Lorg/partiql/ast/expr/PathStep$AllElements;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitPathStepAllFields (Lorg/partiql/ast/expr/PathStep$AllFields;Ljava/lang/Object;)Ljava/lang/Object;
Expand All @@ -268,6 +292,8 @@ public abstract class org/partiql/ast/AstRewriter : org/partiql/ast/AstVisitor {
public fun visitQueryBodySFW (Lorg/partiql/ast/QueryBody$SFW;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitQueryBodySetOp (Lorg/partiql/ast/QueryBody$SetOp;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitQueryBodySetOp (Lorg/partiql/ast/QueryBody$SetOp;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitReplace (Lorg/partiql/ast/Replace;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitReplace (Lorg/partiql/ast/Replace;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitSelectItemExpr (Lorg/partiql/ast/SelectItem$Expr;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitSelectItemExpr (Lorg/partiql/ast/SelectItem$Expr;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitSelectItemStar (Lorg/partiql/ast/SelectItem$Star;Ljava/lang/Object;)Ljava/lang/Object;
Expand All @@ -280,8 +306,20 @@ public abstract class org/partiql/ast/AstRewriter : org/partiql/ast/AstVisitor {
public fun visitSelectStar (Lorg/partiql/ast/SelectStar;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitSelectValue (Lorg/partiql/ast/SelectValue;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitSelectValue (Lorg/partiql/ast/SelectValue;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitSetClause (Lorg/partiql/ast/SetClause;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitSetClause (Lorg/partiql/ast/SetClause;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitSetOp (Lorg/partiql/ast/SetOp;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitSetOp (Lorg/partiql/ast/SetOp;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitUpdate (Lorg/partiql/ast/Update;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpdate (Lorg/partiql/ast/Update;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitUpdateTarget (Lorg/partiql/ast/UpdateTarget;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpdateTarget (Lorg/partiql/ast/UpdateTarget;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitUpdateTargetStepElement (Lorg/partiql/ast/UpdateTargetStep$Element;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpdateTargetStepElement (Lorg/partiql/ast/UpdateTargetStep$Element;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitUpdateTargetStepField (Lorg/partiql/ast/UpdateTargetStep$Field;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpdateTargetStepField (Lorg/partiql/ast/UpdateTargetStep$Field;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
public synthetic fun visitUpsert (Lorg/partiql/ast/Upsert;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpsert (Lorg/partiql/ast/Upsert;Ljava/lang/Object;)Lorg/partiql/ast/AstNode;
}

public abstract class org/partiql/ast/AstVisitor {
Expand Down Expand Up @@ -416,7 +454,7 @@ public abstract class org/partiql/ast/AstVisitor {
public fun visitUnique (Lorg/partiql/ast/ddl/TableConstraint$Unique;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpdate (Lorg/partiql/ast/Update;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpdateTarget (Lorg/partiql/ast/UpdateTarget;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpdateTargetStep (Lorg/partiql/ast/UpdateTarget;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpdateTargetStep (Lorg/partiql/ast/UpdateTargetStep;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpdateTargetStepElement (Lorg/partiql/ast/UpdateTargetStep$Element;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpdateTargetStepField (Lorg/partiql/ast/UpdateTargetStep$Field;Ljava/lang/Object;)Ljava/lang/Object;
public fun visitUpsert (Lorg/partiql/ast/Upsert;Ljava/lang/Object;)Ljava/lang/Object;
Expand Down
2 changes: 1 addition & 1 deletion partiql-ast/src/main/java/org/partiql/ast/AstVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public R visitUpdateTarget(UpdateTarget node, C ctx) {
return defaultVisit(node, ctx);
}

public R visitUpdateTargetStep(UpdateTarget node, C ctx) {
public R visitUpdateTargetStep(UpdateTargetStep node, C ctx) {
return node.accept(this, ctx);
}

Expand Down
152 changes: 152 additions & 0 deletions partiql-ast/src/main/kotlin/org/partiql/ast/AstRewriter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -757,4 +757,156 @@ public abstract class AstRewriter<C> : AstVisitor<AstNode, C>() {
override fun visitDdl(node: Ddl, ctx: C): AstNode {
throw UnsupportedOperationException("DDL has not been supported yet in AstRewriter")
}

override fun visitInsert(node: Insert, ctx: C): AstNode {
val source = visitInsertSource(node.source, ctx) as InsertSource
val target = visitIdentifierChain(node.tableName, ctx) as IdentifierChain
val asAlias = node.asAlias?.let { visitIdentifier(it, ctx) as Identifier }
val onConflict = node.onConflict?.let { visitOnConflict(it, ctx) as OnConflict }
if (source !== node.source || target !== node.tableName || asAlias !== node.asAlias || onConflict !== node.onConflict) {
return Insert(target, asAlias, source, onConflict)
}
return node
}

override fun visitInsertSourceFromExpr(node: InsertSource.FromExpr, ctx: C): AstNode {
val expr = visitExpr(node.expr, ctx) as Expr
val columns = node.columns?.let { _visitList(it, ctx, ::visitIdentifier) }
if (expr !== node.expr || columns != node.columns) {
return InsertSource.FromExpr(columns, expr)
}
return node
}

override fun visitInsertSourceFromDefault(node: InsertSource.FromDefault, ctx: C): AstNode {
return node
}

override fun visitOnConflict(node: OnConflict, ctx: C): AstNode {
val action = visitConflictAction(node.action, ctx) as ConflictAction
val target = visitConflictTarget(node.target, ctx) as ConflictTarget
if (action !== node.action || target !== node.target) {
return OnConflict(action, target)
}
return node
}

override fun visitConflictActionDoNothing(node: ConflictAction.DoNothing, ctx: C): AstNode {
return node
}

override fun visitConflictActionDoReplace(node: ConflictAction.DoReplace, ctx: C): AstNode {
val action = visitDoReplaceAction(node.action, ctx) as DoReplaceAction
val condition = node.condition?.let { visitExpr(it, ctx) as Expr }
if (action !== node.action || condition !== node.condition) {
return ConflictAction.DoReplace(action, condition)
}
return node
}

override fun visitConflictActionDoUpdate(node: ConflictAction.DoUpdate, ctx: C): AstNode {
val action = visitDoUpdateAction(node.action, ctx) as DoUpdateAction
val condition = node.condition?.let { visitExpr(it, ctx) as Expr }
if (action !== node.action || condition !== node.condition) {
return ConflictAction.DoUpdate(action, condition)
}
return node
}

override fun visitConflictTargetConstraint(node: ConflictTarget.Constraint, ctx: C): AstNode {
val constraint = visitIdentifier(node.constraintName, ctx) as Identifier
if (constraint !== node.constraintName) {
return ConflictTarget.Constraint(constraint)
}
return node
}

override fun visitConflictTargetIndex(node: ConflictTarget.Index, ctx: C): AstNode {
val indexes = _visitList(node.indexes, ctx, ::visitIdentifier)
if (indexes !== node.indexes) {
return ConflictTarget.Index(indexes)
}
return node
}

override fun visitDoReplaceActionExcluded(node: DoReplaceAction.Excluded, ctx: C): AstNode {
return node
}

override fun visitDoUpdateActionExcluded(node: DoUpdateAction.Excluded, ctx: C): AstNode {
return node
}

override fun visitDelete(node: Delete, ctx: C): AstNode {
val tableName = visitIdentifierChain(node.tableName, ctx) as IdentifierChain
val condition = node.condition?.let { visitExpr(it, ctx) as Expr }
if (tableName !== node.tableName || condition !== node.condition) {
return Delete(tableName, condition)
}
return node
}

override fun visitUpsert(node: Upsert, ctx: C): AstNode {
val tableName = visitIdentifierChain(node.tableName, ctx) as IdentifierChain
val source = visitInsertSource(node.source, ctx) as InsertSource
val asAlias = node.asAlias?.let { visitIdentifier(it, ctx) as Identifier }
if (tableName !== node.tableName || source !== node.source || asAlias !== node.asAlias) {
return Upsert(tableName, asAlias, source)
}
return node
}

override fun visitReplace(node: Replace, ctx: C): AstNode {
val tableName = visitIdentifierChain(node.tableName, ctx) as IdentifierChain
val source = visitInsertSource(node.source, ctx) as InsertSource
val asAlias = node.asAlias?.let { visitIdentifier(it, ctx) as Identifier }
if (tableName !== node.tableName || source !== node.source || asAlias !== node.asAlias) {
return Replace(tableName, asAlias, source)
}
return node
}

override fun visitSetClause(node: SetClause, ctx: C): AstNode {
val target = visitUpdateTarget(node.target, ctx) as UpdateTarget
val expr = visitExpr(node.expr, ctx) as Expr
if (target !== node.target || expr !== node.expr) {
return SetClause(target, expr)
}
return node
}

override fun visitUpdate(node: Update, ctx: C): AstNode {
val tableName = visitIdentifierChain(node.tableName, ctx) as IdentifierChain
val setClauses = _visitList(node.setClauses, ctx, ::visitSetClause)
val condition = node.condition?.let { visitExpr(it, ctx) as Expr }
if (tableName !== node.tableName || setClauses !== node.setClauses || condition !== node.condition) {
return Update(tableName, setClauses, condition)
}
return node
}

override fun visitUpdateTarget(node: UpdateTarget, ctx: C): AstNode {
val root = visitIdentifier(node.root, ctx) as Identifier
val steps = _visitList(node.steps, ctx, ::visitUpdateTargetStep)
if (root !== node.root || steps !== node.steps) {
return UpdateTarget(root, steps)
}
return node
}

override fun visitUpdateTargetStepElement(node: UpdateTargetStep.Element, ctx: C): AstNode {
val exprLit = visitExprLit(node.key, ctx) as ExprLit
if (exprLit !== node.key) {
return UpdateTargetStep.Element(exprLit)
}
return node
}

override fun visitUpdateTargetStepField(node: UpdateTargetStep.Field, ctx: C): AstNode {
val key = visitIdentifier(node.key, ctx) as Identifier
if (key !== node.key) {
return UpdateTargetStep.Field(key)
}
return node
}
}

0 comments on commit 5d525fe

Please sign in to comment.