Skip to content

Commit

Permalink
feat: InlineIdentity optimization pass
Browse files Browse the repository at this point in the history
  • Loading branch information
nau committed Dec 20, 2024
1 parent dde4328 commit 3bd4e4d
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
3 changes: 2 additions & 1 deletion shared/src/main/scala/scalus/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import scalus.sir.SIR
import scalus.sir.SimpleSirToUplcLowering
import scalus.uplc.Constant
import scalus.uplc.DefaultUni
import scalus.uplc.InlineIdentity
import scalus.uplc.Program
import scalus.uplc.Term
import scalus.uplc.eval.Result
Expand All @@ -32,7 +33,7 @@ package object scalus {
SimpleSirToUplcLowering(sir, generateErrorTraces).lower()
def toUplcOptimized(generateErrorTraces: Boolean = false): Term =
OptimizingSirToUplcLowering(sir |> RemoveRecursivity.apply, generateErrorTraces)
.lower() |> EtaReduce.apply
.lower() |> EtaReduce.apply |> InlineIdentity.inlineIdentity

def toPlutusProgram(
version: (Int, Int, Int),
Expand Down
19 changes: 19 additions & 0 deletions shared/src/main/scala/scalus/uplc/InlineIdentity.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package scalus.uplc
import scalus.uplc.Term.*

/** Inlines identity function application */
object InlineIdentity:
/** Inlines identity function application */
def apply(term: Term): Term = inlineIdentity(term)

/** Inlines identity function application */
def inlineIdentity(term: Term): Term = term match
case Apply(LamAbs(param, Var(NamedDeBruijn(name, _))), arg) if param == name =>
inlineIdentity(arg)
case Apply(f, arg) => Apply(inlineIdentity(f), inlineIdentity(arg))
case Force(term) => Force(inlineIdentity(term))
case Delay(term) => Delay(inlineIdentity(term))
case LamAbs(param, body) => LamAbs(param, inlineIdentity(body))
case Constr(tag, args) => Constr(tag, args.map(inlineIdentity))
case Case(arg, cases) => Case(inlineIdentity(arg), cases.map(inlineIdentity))
case _ => term
13 changes: 13 additions & 0 deletions shared/src/test/scala/scalus/uplc/InlineIdentitySpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package scalus
package uplc

import scalus.uplc.Term.*
import org.scalatest.funsuite.AnyFunSuite

class InlineIdentitySpec extends AnyFunSuite {
test("inlineIdentity should inline identity function application") {
val term = Apply(LamAbs("x", Var(NamedDeBruijn("x"))), Var(NamedDeBruijn("y")))
val expected = Var(NamedDeBruijn("y"))
assert(InlineIdentity.inlineIdentity(term) == expected)
}
}

0 comments on commit 3bd4e4d

Please sign in to comment.