[SPARK-45033][SQL] Support maps by parameterized sql()
### What changes were proposed in this pull request?
In the PR, I propose to allow a few more column expressions as parameters of `sql()` for constructing arrays, structs, and maps: `CreateArray`, `CreateNamedStruct`, `CreateMap`, `MapFromArrays` and `MapFromEntries`. Such expressions need to be allowed because Spark SQL doesn't support constructing literals of the map type.
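For illustration, a minimal sketch of what this enables (assuming a running `SparkSession` in scope; `map_from_arrays` and `lit` are the standard `org.apache.spark.sql.functions` helpers):
```scala
import org.apache.spark.sql.functions.{lit, map_from_arrays}

// A map column built from literal arrays can now be passed as a named
// parameter; previously only plain Literal arguments were accepted.
val m = map_from_arrays(lit(Array("a", "b")), lit(Array(1, 2)))
spark.sql("SELECT element_at(:m, 'a')", Map("m" -> m)).show()
// expected output: 1
```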

### Why are the changes needed?
To improve user experience with Spark SQL, and to support the rest of the built-in types.

### Does this PR introduce _any_ user-facing change?
No. It extends the existing API.

### How was this patch tested?
By running new tests:
```
$ build/sbt "test:testOnly *ParametersSuite"
```

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#42752 from MaxGekk/parameterized-sql-create-map.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
MaxGekk committed Sep 4, 2023
1 parent b0b7835 commit 4162076
Showing 2 changed files with 72 additions and 5 deletions.
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Literal, SubqueryExpression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
@@ -96,7 +96,12 @@ case class PosParameterizedQuery(child: LogicalPlan, args: Array[Expression])
 */
object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
  private def checkArgs(args: Iterable[(String, Expression)]): Unit = {
-    args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
+    def isNotAllowed(expr: Expression): Boolean = expr.exists {
+      case _: Literal | _: CreateArray | _: CreateNamedStruct |
+           _: CreateMap | _: MapFromArrays | _: MapFromEntries => false
+      case _ => true
+    }
+    args.find(arg => isNotAllowed(arg._2)).foreach { case (name, expr) =>
      expr.failAnalysis(
        errorClass = "INVALID_SQL_ARG",
        messageParameters = Map("name" -> name))
@@ -119,11 +124,13 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
    plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
      // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
      // relations are not children of `UnresolvedWith`.
-      case p @ NameParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
+      case NameParameterizedQuery(child, args)
+          if !child.containsPattern(UNRESOLVED_WITH) && args.forall(_._2.resolved) =>
        checkArgs(args)
        bind(child) { case NamedParameter(name) if args.contains(name) => args(name) }

-      case p @ PosParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
+      case PosParameterizedQuery(child, args)
+          if !child.containsPattern(UNRESOLVED_WITH) && args.forall(_.resolved) =>
        val indexedArgs = args.zipWithIndex
        checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
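Since `isNotAllowed` is implemented with `Expression.exists`, the check traverses the whole argument tree: every node must be a `Literal` or one of the allowed constructors. A sketch of the resulting behavior (hypothetical snippets, using the standard `org.apache.spark.sql.functions` API):
```scala
import org.apache.spark.sql.functions.{lit, map, rand}

// Accepted: the argument tree contains only CreateMap and Literal nodes.
spark.sql("SELECT :m['a']", Map("m" -> map(lit("a"), lit(1))))

// Rejected with INVALID_SQL_ARG: rand() is neither a literal nor an
// allowed constructor, even though the top-level node is CreateMap.
spark.sql("SELECT :m['a']", Map("m" -> map(lit("a"), rand())))
```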
@@ -19,8 +19,9 @@ package org.apache.spark.sql

import java.time.{Instant, LocalDate, LocalDateTime, ZoneId}

+import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{array, lit, map, map_from_arrays, map_from_entries, str_to_map, struct}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

@@ -529,4 +530,63 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
spark.sql("SELECT ?[?][?]", Array(Array(Array(1f, 2f), Array.empty[Float], Array(3f)), 0, 1)),
Row(2f))
}

test("SPARK-45033: maps as parameters") {
def fromArr(keys: Array[_], values: Array[_]): Column = {
map_from_arrays(Column(Literal(keys)), Column(Literal(values)))
}
def createMap(keys: Array[_], values: Array[_]): Column = {
val zipped = keys.map(k => Column(Literal(k))).zip(values.map(v => Column(Literal(v))))
map(zipped.map { case (k, v) => Seq(k, v) }.flatten: _*)
}
def fromEntries(keys: Array[_], values: Array[_]): Column = {
val structures = keys.zip(values)
.map { case (k, v) => struct(Column(Literal(k)), Column(Literal(v)))}
map_from_entries(array(structures: _*))
}

Seq(fromArr(_, _), createMap(_, _)).foreach { f =>
checkAnswer(
spark.sql("SELECT map_contains_key(:mapParam, 0)",
Map("mapParam" -> f(Array.empty[Int], Array.empty[String]))),
Row(false))
checkAnswer(
spark.sql("SELECT map_contains_key(?, 'a')",
Array(f(Array.empty[String], Array.empty[Double]))),
Row(false))
}
Seq(fromArr(_, _), createMap(_, _), fromEntries(_, _)).foreach { f =>
checkAnswer(
spark.sql("SELECT element_at(:mapParam, 'a')",
Map("mapParam" -> f(Array("a"), Array(0)))),
Row(0))
checkAnswer(
spark.sql("SELECT element_at(?, 'a')", Array(f(Array("a"), Array(0)))),
Row(0))
checkAnswer(
spark.sql("SELECT :m[10]", Map("m" -> f(Array(10, 20, 30), Array(0, 1, 2)))),
Row(0))
checkAnswer(
spark.sql("SELECT ?[?]", Array(f(Array(1f, 2f, 3f), Array(1, 2, 3)), 2f)),
Row(2))
}
checkAnswer(
spark.sql("SELECT :m['a'][1]",
Map("m" ->
map_from_arrays(
Column(Literal(Array("a"))),
array(map_from_arrays(Column(Literal(Array(1))), Column(Literal(Array(2)))))))),
Row(2))
// `str_to_map` is not supported
checkError(
exception = intercept[AnalysisException] {
spark.sql("SELECT :m['a'][1]",
Map("m" ->
map_from_arrays(
Column(Literal(Array("a"))),
array(str_to_map(Column(Literal("a:1,b:2,c:3")))))))
},
errorClass = "INVALID_SQL_ARG",
parameters = Map("name" -> "m"))
}
}
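As the last test shows, `str_to_map` is rejected even when all of its inputs are literals, because it is not in the allowed list. A hedged workaround sketch: build the same map from literal key/value arrays instead:
```scala
import org.apache.spark.sql.functions.{lit, map_from_arrays}

// Equivalent to str_to_map('a:1,b:2,c:3'), but expressed with an allowed
// constructor (MapFromArrays), so it passes the argument check.
val m = map_from_arrays(lit(Array("a", "b", "c")), lit(Array(1, 2, 3)))
spark.sql("SELECT :m['b']", Map("m" -> m)).show()
// expected output: 2
```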
