Skip to content

Commit

Permalink
Refine solution
Browse files Browse the repository at this point in the history
  • Loading branch information
hvanhovell committed Aug 14, 2023
1 parent 6d4891b commit 8907488
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,9 @@ import org.apache.spark.util.SparkClassUtils

object OuterScopes {
private[this] val queue = new ReferenceQueue[AnyRef]
private class HashableWeakReference(v: AnyRef) extends WeakReference[AnyRef](v, queue) {
private[this] val hash = v.hashCode()
override def hashCode(): Int = hash
override def equals(obj: Any): Boolean = {
obj match {
case other: HashableWeakReference =>
// Note that referential equality is used to identify & purge
// references from the map whose referent went out of scope.
if (this eq other) {
true
} else {
val referent = get()
val otherReferent = other.get()
referent != null && otherReferent != null && Objects.equals(referent, otherReferent)
}
case _ => false
}
}
}

private def classLoaderRef(c: Class[_]): HashableWeakReference = {
new HashableWeakReference(c.getClassLoader)
new HashableWeakReference(c.getClassLoader, queue)
}

private[this] val outerScopes = {
Expand Down Expand Up @@ -154,3 +135,31 @@ object OuterScopes {
// e.g. `ammonite.$sess.cmd8$Helper$Foo` -> `ammonite.$sess.cmd8.instance.Foo`
private[this] val AmmoniteREPLClass = """^(ammonite\.\$sess\.cmd(?:\d+)\$).*""".r
}

/**
 * A [[WeakReference]] with a stable hash code. While the referent is still alive, equality is
 * based on the referent; once it is dead we fall back to referential equality. This way you can
 * still do lookups in a map while the referent is alive, and you are able to remove dead entries
 * after GC (using a [[ReferenceQueue]]).
 *
 * @param v the referent. May be `null` (e.g. the bootstrap class loader is represented as a
 *          `null` `ClassLoader`), in which case the hash is 0 and this reference only equals
 *          itself.
 * @param queue the [[ReferenceQueue]] the reference is appended to once the referent is
 *              collected, or `null` if no GC notification is needed.
 */
private[catalyst] class HashableWeakReference(v: AnyRef, queue: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](v, queue) {
  /** Convenience constructor for references that do not need GC-based eviction. */
  def this(v: AnyRef) = this(v, null)
  // Capture the hash eagerly: hash-based collections need a stable hashCode for the
  // lifetime of the entry, even after the referent has been collected.
  // Objects.hashCode tolerates a null referent (returns 0) where v.hashCode() would NPE.
  private[this] val hash = Objects.hashCode(v)
  override def hashCode(): Int = hash
  override def equals(obj: Any): Boolean = {
    obj match {
      case other: HashableWeakReference =>
        // Note that referential equality is used to identify & purge
        // references from the map whose referent went out of scope.
        if (this eq other) {
          true
        } else {
          val referent = get()
          val otherReferent = other.get()
          referent != null && otherReferent != null && Objects.equals(referent, otherReferent)
        }
      case _ => false
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import java.io.ByteArrayInputStream
import java.util.UUID

import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
import org.codehaus.commons.compiler.{CompileException, InternalCompilerException}
import org.codehaus.janino.ClassBodyEvaluator
Expand All @@ -37,6 +35,7 @@ import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.HashableWeakReference
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.types._
Expand Down Expand Up @@ -1441,7 +1440,8 @@ object CodeGenerator extends Logging {
* @return a pair of a generated class and the bytecode statistics of generated functions.
*/
def compile(code: CodeAndComment): (GeneratedClass, ByteCodeStats) = try {
cache.get((classLoaderUUID.get(Utils.getContextOrSparkClassLoader), code))
val classLoaderRef = new HashableWeakReference(Utils.getContextOrSparkClassLoader)
cache.get((classLoaderRef, code))
} catch {
// Cache.get() may wrap the original exception. See the following URL
// https://guava.dev/releases/14.0.1/api/docs/com/google/common/cache/
Expand Down Expand Up @@ -1583,7 +1583,7 @@ object CodeGenerator extends Logging {
* aborted. See [[NonFateSharingCache]] for more details.
*/
private val cache = {
val loadFunc: ((ClassLoaderId, CodeAndComment)) => (GeneratedClass, ByteCodeStats) = {
val loadFunc: ((HashableWeakReference, CodeAndComment)) => (GeneratedClass, ByteCodeStats) = {
case (_, code) =>
val startTime = System.nanoTime()
val result = doCompile(code)
Expand All @@ -1599,16 +1599,6 @@ object CodeGenerator extends Logging {
NonFateSharingCache(loadFunc, SQLConf.get.codegenCacheMaxEntries)
}

type ClassLoaderId = String
private val classLoaderUUID = {
NonFateSharingCache(CacheBuilder.newBuilder()
.weakKeys
.maximumSize(SQLConf.get.codegenCacheMaxEntries)
.build(new CacheLoader[ClassLoader, ClassLoaderId]() {
override def load(code: ClassLoader): ClassLoaderId = UUID.randomUUID.toString
}))
}

/**
* Name of Java primitive data type
*/
Expand Down

0 comments on commit 8907488

Please sign in to comment.