From 40943c2748fdd28d970d017cb8ee86c294ee62df Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Tue, 5 Sep 2023 15:35:12 +0200
Subject: [PATCH] [SPARK-45072][CONNECT] Fix outer scopes for ammonite classes

### What changes were proposed in this pull request?
Ammonite places all user code inside Helper classes, which are nested inside the class it creates for each command. This PR adds a custom code class wrapper for the Ammonite REPL. The wrapper ensures that the Helper classes generated by Ammonite are always registered as an outer scope immediately. This way we can instantiate classes defined inside the Helper class, even when we execute Spark code as part of the Helper's constructor (a sketch of the wrapped output is shown after the diff).

### Why are the changes needed?
Currently, when you define a class and execute a Spark command using that class inside the same cell/line, it fails with a NullPointerException. The reason is that we cannot resolve the outer scope needed to instantiate the class. This PR fixes that issue. The following code will now execute successfully (including the curly braces):
```scala
{
  case class Thing(val value: String)
  val r = (0 to 10).map( value => Thing(value.toString) )
  spark.createDataFrame(r)
}
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
I added more tests to the `ReplE2ESuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #42807 from hvanhovell/SPARK-45072.

Authored-by: Herman van Hovell
Signed-off-by: Herman van Hovell
---
 .../spark/sql/application/ConnectRepl.scala   | 29 +++++++++--
 .../spark/sql/application/ReplE2ESuite.scala  | 48 +++++++++++++++----
 .../CheckConnectJvmClientCompatibility.scala  |  6 +++
 3 files changed, 71 insertions(+), 12 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
index e6ada566398c7..0360a40578869 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
@@ -22,7 +22,8 @@ import java.util.concurrent.Semaphore
 import scala.util.control.NonFatal
 
 import ammonite.compiler.CodeClassWrapper
-import ammonite.util.Bind
+import ammonite.compiler.iface.CodeWrapper
+import ammonite.util.{Bind, Imports, Name, Util}
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SparkSession
@@ -94,8 +95,8 @@ object ConnectRepl {
     val main = ammonite.Main(
       welcomeBanner = Option(splash),
       predefCode = predefCode,
-      replCodeWrapper = CodeClassWrapper,
-      scriptCodeWrapper = CodeClassWrapper,
+      replCodeWrapper = ExtendedCodeClassWrapper,
+      scriptCodeWrapper = ExtendedCodeClassWrapper,
       inputStream = inputStream,
       outputStream = outputStream,
       errorStream = errorStream)
@@ -107,3 +108,25 @@ object ConnectRepl {
     }
   }
 }
+
+/**
+ * [[CodeWrapper]] that makes sure new Helper classes are always registered as an outer scope.
+ */ +@DeveloperApi +object ExtendedCodeClassWrapper extends CodeWrapper { + override def wrapperPath: Seq[Name] = CodeClassWrapper.wrapperPath + override def apply( + code: String, + source: Util.CodeSource, + imports: Imports, + printCode: String, + indexedWrapper: Name, + extraCode: String): (String, String, Int) = { + val (top, bottom, level) = + CodeClassWrapper(code, source, imports, printCode, indexedWrapper, extraCode) + // Make sure we register the Helper before anything else, so outer scopes work as expected. + val augmentedTop = top + + "\norg.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this)\n" + (augmentedTop, bottom, level) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index 4106d298dbe2b..5bb8cbf3543b0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -79,12 +79,10 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { override def afterEach(): Unit = { semaphore.drainPermits() - if (ammoniteOut != null) { - ammoniteOut.reset() - } } def runCommandsInShell(input: String): String = { + ammoniteOut.reset() require(input.nonEmpty) // Pad the input with a semaphore release so that we know when the execution of the provided // input is complete. @@ -105,6 +103,10 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { getCleanString(ammoniteOut) } + def runCommandsUsingSingleCellInShell(input: String): String = { + runCommandsInShell("{\n" + input + "\n}") + } + def assertContains(message: String, output: String): Unit = { val isContain = output.contains(message) assert( @@ -263,6 +265,31 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output) } + test("Single Cell Compilation") { + val input = + """ + |case class C1(value: Int) + |case class C2(value: Int) + |val h1 = classOf[C1].getDeclaringClass + |val h2 = classOf[C2].getDeclaringClass + |val same = h1 == h2 + |""".stripMargin + assertContains("same: Boolean = false", runCommandsInShell(input)) + assertContains("same: Boolean = true", runCommandsUsingSingleCellInShell(input)) + } + + test("Local relation containing REPL generated class") { + val input = + """ + |case class MyTestClass(value: Int) + |val data = (0 to 10).map(MyTestClass) + |spark.createDataset(data).map(mtc => mtc.value).select(sum($"value")).as[Long].head + |""".stripMargin + val expected = "Long = 55L" + assertContains(expected, runCommandsInShell(input)) + assertContains(expected, runCommandsUsingSingleCellInShell(input)) + } + test("Collect REPL generated class") { val input = """ @@ -275,8 +302,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { | map(mtc => s"MyTestClass(${mtc.value})"). 
| mkString("[", ", ", "]") """.stripMargin - val output = runCommandsInShell(input) - assertContains("""String = "[MyTestClass(1), MyTestClass(3)]"""", output) + val expected = """String = "[MyTestClass(1), MyTestClass(3)]"""" + assertContains(expected, runCommandsInShell(input)) + assertContains(expected, runCommandsUsingSingleCellInShell(input)) } test("REPL class in encoder") { @@ -288,8 +316,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { | map(mtc => mtc.value). | collect() """.stripMargin - val output = runCommandsInShell(input) - assertContains("Array[Int] = Array(0, 1, 2)", output) + val expected = "Array[Int] = Array(0, 1, 2)" + assertContains(expected, runCommandsInShell(input)) + assertContains(expected, runCommandsUsingSingleCellInShell(input)) } test("REPL class in UDF") { @@ -301,8 +330,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { | map(mtc => s"MyTestClass(${mtc.value})"). | mkString("[", ", ", "]") """.stripMargin - val output = runCommandsInShell(input) - assertContains("""String = "[MyTestClass(0), MyTestClass(1)]"""", output) + val expected = """String = "[MyTestClass(0), MyTestClass(1)]"""" + assertContains(expected, runCommandsInShell(input)) + assertContains(expected, runCommandsUsingSingleCellInShell(input)) } test("streaming works with REPL generated code") { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 1e536cd37fec1..a6cd20aff68c0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -357,6 +357,12 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.application.ConnectRepl$" // developer API ), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.application.ExtendedCodeClassWrapper" // developer API + ), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.application.ExtendedCodeClassWrapper$" // developer API + ), // SparkSession // developer API
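
For reference, a minimal sketch of the shape of a wrapped cell after this change. It is illustrative only: `cmd1`, `Helper`, and `wrapper` are hypothetical names (Ammonite's real generated wrappers use different names and contain more plumbing). The point is that `OuterScopes.addOuterScope(this)` is emitted directly after the generated class header, so the Helper instance is registered before any user code in its constructor runs:

```scala
// Illustrative sketch of an Ammonite-style class wrapper; names are hypothetical.
object cmd1 {
  val wrapper = new Helper

  class Helper {
    // Emitted by ExtendedCodeClassWrapper right after `top` (the generated
    // class header), i.e. before any user code, so Spark's encoder machinery
    // can resolve this instance as an outer scope.
    org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this)

    // User code follows. `Thing` is an inner class of Helper, so constructing
    // it through an encoder requires the outer instance registered above.
    case class Thing(value: String)
    val r = (0 to 10).map(value => Thing(value.toString))
    // Previously, spark.createDataFrame(r) threw a NullPointerException here:
    // the command was still inside Helper's constructor, so no outer scope
    // had been registered yet.
  }
}
```

Registering at the top of the class body, rather than after the user code, is what allows Spark commands executed inside the Helper's constructor itself to succeed.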