Skip to content

Commit

Permalink
[PIR] Don't generate 'fixBy' if you don't need to (#5954)
Browse files Browse the repository at this point in the history
Removes unnecessary generation of `fixBy` when we only need `fix`.
  • Loading branch information
effectfully committed Aug 6, 2024
1 parent fb168b1 commit d683fcc
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 299 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,51 @@ applyFun = runQuote $ do
. lamAbs () x (TyVar () a)
$ apply () (var () f) (var () x)

{- Note [Recursion combinators]
We create singly recursive and mutually recursive functions using different combinators.
For singly recursive functions we use the Z combinator (a strict cousin of the Y combinator) that in
UPLC looks like this:
\f -> (\s -> s s) (\s -> f (\x -> s s x))
We have benchmarked its Haskell version at
https://github.com/IntersectMBO/plutus/tree/9538fc9829426b2ecb0628d352e2d7af96ec8204/doc/notes/fomega/z-combinator-benchmarks
and observed that in Haskell there's no detectable difference in performance of functions defined
using explicit recursion versus the Z combinator. However Haskell is a compiled language and Plutus
is interpreted, so it's very likely that natively supporting recursion in Plutus instead of
compiling recursive functions to combinators would significantly boost performance.
We've tried using
\f -> (\s -> s s) (\s x -> f (s s) x)
instead of
\f -> (\s -> s s) (\s -> f (\x -> s s x))
and while it worked OK at the PLC level, it wasn't a suitable primitive for compilation of recursive
functions, because it would add laziness in unexpected places, see
https://github.com/IntersectMBO/plutus/issues/5961
so we had to change it.
We use
\f -> (\s -> s s) (\s x -> f (s s) x)
instead of the more standard
\f -> (\s x -> f (s s) x) (\s x -> f (s s) x)
because in practice @f@ gets inlined and we wouldn't be able to do so if it occurred twice in the
term. Plus the former also allows us to save on the size of the term.
For mutually recursive functions we use the 'fixBy' combinator, which is, to the best of our
knowledge, our own invention. It was first described at
https://github.com/IntersectMBO/plutus/blob/067e74f0606fddc5e183dd45209b461e293a6224/doc/notes/fomega/mutual-term-level-recursion/FixN.agda
and fully specified in our "Unraveling recursion: compiling an IR with recursion to System F" paper.
-}

-- | @Self@ as a PLC type.
--
-- > fix \(self :: * -> *) (a :: *) -> self a -> a
Expand Down Expand Up @@ -144,7 +189,6 @@ fixAndType = runQuote $ do
$ TyFun () (TyFun () funAB funAB) funAB
pure (fixTerm, fixType)


-- | A type that looks like a transformation.
--
-- > trans F G Q : F Q -> G Q
Expand Down Expand Up @@ -337,6 +381,7 @@ fixNAndType n fixByTerm = runQuote $ do
]
pure (fixNTerm, fixNType)

-- See Note [Recursion combinators].
-- | Get the fixed-point of a single recursive function.
getSingleFixOf
:: (TermLike term TyName Name uni fun)
Expand All @@ -346,6 +391,7 @@ getSingleFixOf ann fix1 fun@FunctionDef{_functionDefType=(FunctionType _ dom cod
abstractedBody = mkIterLamAbs [functionDefVarDecl fun] $ _functionDefTerm fun
in apply ann instantiatedFix abstractedBody

-- See Note [Recursion combinators].
-- | Get the fixed-point of a list of mutually recursive functions.
--
-- > MutualFixOf _ fixN [ FunctionDef _ fN1 (FunctionType _ a1 b1) f1
Expand Down
11 changes: 6 additions & 5 deletions plutus-core/plutus-ir/src/PlutusIR/Compiler/Recursion.hs
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,16 @@ mkFixpoint bs = do
name <- liftQuote $ toProgramName fixByKey
let (fixByTerm, fixByType) = Function.fixByAndType
pure (PLC.Def (PLC.VarDecl noProvenance name (noProvenance <$ fixByType)) (noProvenance <$ fixByTerm, Strict), mempty)
fixBy <- lookupOrDefineTerm p0 fixByKey mkFixByDef

let mkFixNDef = do
name <- liftQuote $ toProgramName fixNKey
let ((fixNTerm, fixNType), fixNDeps) =
if arity == 1
then (Function.fixAndType, mempty)
((fixNTerm, fixNType), fixNDeps) <-
if arity == 1
then pure (Function.fixAndType, mempty)
-- fixN depends on fixBy
else (Function.fixNAndType arity (void fixBy), Set.singleton fixByKey)
else do
fixBy <- lookupOrDefineTerm p0 fixByKey mkFixByDef
pure (Function.fixNAndType arity (void fixBy), Set.singleton fixByKey)
pure (PLC.Def (PLC.VarDecl noProvenance name (noProvenance <$ fixNType)) (noProvenance <$ fixNTerm, Strict), fixNDeps)
fixN <- lookupOrDefineTerm p0 fixNKey mkFixNDef

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,27 @@
1.1.0
[
[
(lam s_1651 [ s_1651 s_1651 ])
(lam s_1609 [ s_1609 s_1609 ])
(lam
s_1652
s_1610
(lam
i_1653
i_1611
[
[
[
[
(force (builtin ifThenElse))
[ [ (builtin equalsInteger) (con integer 0) ] i_1653 ]
[ [ (builtin equalsInteger) (con integer 0) ] i_1611 ]
]
(lam u_1654 (con integer 1))
(lam u_1612 (con integer 1))
]
(lam
u_1655
u_1613
[
[ (builtin multiplyInteger) i_1653 ]
[ (builtin multiplyInteger) i_1611 ]
[
(lam x_1656 [ [ s_1652 s_1652 ] x_1656 ])
[ [ (builtin subtractInteger) i_1653 ] (con integer 1) ]
(lam x_1614 [ [ s_1610 s_1610 ] x_1614 ])
[ [ (builtin subtractInteger) i_1611 ] (con integer 1) ]
]
]
)
Expand Down
Loading

0 comments on commit d683fcc

Please sign in to comment.