Skip to content

Commit

Permalink
Use Absolute-indexes as keys for the Env-mapping during closure-conve…
Browse files Browse the repository at this point in the history
…rsion (#11912)

* Use Absolute-indexes as keys for the Env-mapping during closure-conversion.

Do runtime check to confirm behaviour matches the existing Relative-indexes.

changelog_begin
changelog_end

* remove quadratic shift!

remove (dev)pretty-print code
remove relative-index keys from Env-mapping
remove runtime *diff* check
increase depth for stack-safety tests

* improve/simplify indexing calculation for Env-keys
  • Loading branch information
nickchapman-da authored Nov 30, 2021
1 parent 1d7bca8 commit 0ee4154
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,45 @@ private[speedy] object ClosureConversion {

private[speedy] def closureConvert(source0: source.SExpr): target.SExpr = {

// TODO: Recode the 'Env' management to avoid the polynomial-complexity of 'shift'. Issue #11830
case class Env(depth: Int, mapping: Map[Int, target.SELoc]) {
case class Abs(a: Int) // absolute variable index, determined by tracking sourceDepth

def lookup(i: Int): target.SELoc =
mapping.get(i) match {
case class Env(sourceDepth: Int, mapping: Map[Abs, target.SELoc], targetDepth: Int) {

def lookup(abs: Abs): target.SELoc = {
mapping.get(abs) match {
case Some(loc) => loc
case None =>
throw sys.error(s"lookup($i),in:$mapping")
throw sys.error(s"lookup($abs),in:$mapping")
}
}

def shift(n: Int): Env = {
// We just update the keys of the map (the relative-indexes from the original SEVar)
val m1 = mapping.map { case (k, loc) => (n + k, loc) }
// And create mappings for the `n` new stack items
val m2 = (1 to n).view.map { rel =>
val abs = this.depth + n - rel
(rel, target.SELocAbsoluteS(abs))
def extend(n: Int): Env = {
// Create mappings for `n` new stack items, and combine with the (unshifted!) existing mapping.
val m2 = (0 until n).view.map { i =>
val abs = Abs(sourceDepth + i)
(abs, target.SELocAbsoluteS(targetDepth + i))
}
Env(this.depth + n, m1 ++ m2)
Env(sourceDepth + n, mapping ++ m2, targetDepth + n)
}
}

object Env {
def apply(): Env = {
Env(0, Map.empty)
}
def absBody(arity: Int, fvs: List[Int]): Env = {
val newRemapsF: Map[Int, target.SELoc] = fvs.view.zipWithIndex.map { case (orig, i) =>
(orig + arity) -> target.SELocF(i)
def absBody(arity: Int, fvs: List[Abs]): Env = {
val newRemapsF: Map[Abs, target.SELoc] = fvs.view.zipWithIndex.map { case (abs, i) =>
abs -> target.SELocF(i)
}.toMap
val newRemapsA = (1 to arity).view.map { case i =>
i -> target.SELocA(arity - i)
val newRemapsA = (0 until arity).view.map { case i =>
val abs = Abs(sourceDepth + i)
abs -> target.SELocA(i)
}
// The keys in newRemapsF and newRemapsA are disjoint
val m1 = newRemapsF ++ newRemapsA
Env(0, m1)
// Only targetDepth is reset to 0 in an abstraction body
Env(sourceDepth + arity, m1, 0)
}
}

object Env {
def apply(): Env = {
Env(0, Map.empty, 0)
}
}

Expand Down Expand Up @@ -173,7 +176,10 @@ private[speedy] object ClosureConversion {
// Going Down: match on expression form...
case Down(exp, env) =>
exp match {
case source.SEVar(i) => loop(Up(env.lookup(i)), conts)
case source.SEVar(r) =>
val abs = Abs(env.sourceDepth - r)
loop(Up(env.lookup(abs)), conts)

case source.SEVal(x) => loop(Up(target.SEVal(x)), conts)
case source.SEBuiltin(x) => loop(Up(target.SEBuiltin(x)), conts)
case source.SEValue(x) => loop(Up(target.SEValue(x)), conts)
Expand All @@ -182,9 +188,11 @@ private[speedy] object ClosureConversion {
loop(Down(body, env), Cont.Location(loc) :: conts)

case source.SEAbs(arity, body) =>
val fvsAsListInt = freeVars(body, arity).toList.sorted
val fvs = fvsAsListInt.map(i => env.lookup(i))
loop(Down(body, Env.absBody(arity, fvsAsListInt)), Cont.Abs(arity, fvs) :: conts)
val fvsAsListAbs = freeVars(body, arity).toList.sorted.map { r =>
Abs(env.sourceDepth - r)
}
val fvs = fvsAsListAbs.map { abs => env.lookup(abs) }
loop(Down(body, env.absBody(arity, fvsAsListAbs)), Cont.Abs(arity, fvs) :: conts)

case source.SEApp(fun, args) =>
loop(Down(fun, env), Cont.App1(env, args) :: conts)
Expand Down Expand Up @@ -255,7 +263,8 @@ private[speedy] object ClosureConversion {
loop(Up(target.SECase(scrut, Nil)), conts)
case source.SCaseAlt(pat, rhs) :: alts =>
val n = pat.numArgs
loop(Down(rhs, env.shift(n)), Cont.Case2(scrut, Nil, pat, env, alts) :: conts)
val env1 = env.extend(n)
loop(Down(rhs, env1), Cont.Case2(scrut, Nil, pat, env, alts) :: conts)
}

case Cont.Case2(scrut, altsDone0, pat, env, alts) =>
Expand All @@ -265,14 +274,14 @@ private[speedy] object ClosureConversion {
loop(Up(target.SECase(scrut, altsDone.reverse)), conts)
case source.SCaseAlt(pat, rhs) :: alts =>
val n = pat.numArgs
val env1 = env.shift(n)
val env1 = env.extend(n)
loop(Down(rhs, env1), Cont.Case2(scrut, altsDone, pat, env, alts) :: conts)
}

case Cont.Let1(boundsDone0, env, bounds, body) =>
val boundsDone = result :: boundsDone0
val depth = boundsDone.length
val env1 = env.shift(depth)
val n = boundsDone.length
val env1 = env.extend(n)
bounds match {
case Nil =>
loop(Down(body, env1), Cont.Let2(boundsDone) :: conts)
Expand All @@ -286,7 +295,7 @@ private[speedy] object ClosureConversion {

case Cont.TryCatch1(env, handler) =>
val body = result
loop(Down(handler, env.shift(1)), Cont.TryCatch2(body) :: conts)
loop(Down(handler, env.extend(1)), Cont.TryCatch2(body) :: conts)

case Cont.TryCatch2(body) =>
val handler = result
Expand Down Expand Up @@ -321,7 +330,12 @@ private[speedy] object ClosureConversion {
def go(expr: source.SExpr, bound: Int, free: Set[Int]): Set[Int] =
expr match {
case source.SEVar(i) =>
if (i > bound) free + (i - bound) else free /* adjust to caller's environment */
if (i > bound) {
val rel = (i - bound) /* adjust to caller's environment */
free + rel
} else {
free
}
case _: source.SEVal => free
case _: source.SEBuiltin => free
case _: source.SEValue => free
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ class ClosureConversionTest extends AnyFreeSpec with Matchers with TableDrivenPr
("Let1", let1),
("TryCatch2", tryCatch2),
("Labelclosure", labelClosure),
("Alt1", alt1),
("Alt2", alt2),
("Let2", let2),
("LetBody", letBody),
("TryCatch1", tryCatch1),
)
}

Expand All @@ -82,28 +87,10 @@ class ClosureConversionTest extends AnyFreeSpec with Matchers with TableDrivenPr
Table[String, SExpr => SExpr](
("name", "recursion-point"),
("Abs", abs1),
("Alt1", alt1),
("Alt2", alt2),
("Let2", let2),
("LetBody", letBody),
("TryCatch1", tryCatch1),
)
}

{
// All tests. Shallow enough for pre-stack-safe closure-conversion code to pass.
val depth = 100
s"depth = $depth" - {
forEvery(testCases1 ++ testCases2) { (name: String, recursionPoint: SExpr => SExpr) =>
name in {
runTest(depth, recursionPoint)
}
}
}
}

{
// Only first set. At this depth we can be really sure that we are stack-safe.
val depth = 100000
s"depth = $depth" - {
forEvery(testCases1) { (name: String, recursionPoint: SExpr => SExpr) =>
Expand All @@ -115,11 +102,10 @@ class ClosureConversionTest extends AnyFreeSpec with Matchers with TableDrivenPr
}

{
// Only 2nd set. This depth is not really deep enough to ensure stack-safety, but
// much deeper and the quadratic-or-worse time-complexity starts to seriously slow
// down the test run.
// TODO: fix quadratic time issue to allow these tests to be run at depth 100000.
val depth = 1000
// TODO: There remains a quadratic issue with the freeVars calculation (#11830).
// This affects only Abs testcase. It takes 12s when run to a larger depth of 100k.
// So we only run to 10k.
val depth = 10000
s"depth = $depth" - {
forEvery(testCases2) { (name: String, recursionPoint: SExpr => SExpr) =>
name in {
Expand Down

0 comments on commit 0ee4154

Please sign in to comment.