Skip to content

Commit

Permalink
Add simplifier tags in UPLC simplifier (#6540)
Browse files Browse the repository at this point in the history
Co-authored-by: Ramsay Taylor <ramsay.taylor@iohk.io>
  • Loading branch information
ana-pantilie and ramsay-t authored Oct 18, 2024
1 parent 557a2c7 commit 388ad69
Show file tree
Hide file tree
Showing 23 changed files with 340 additions and 208 deletions.
7 changes: 2 additions & 5 deletions plutus-core/executables/plutus/AnyProgram/Compile.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ import Control.Lens hiding ((%~))
import Control.Monad.Error.Lens
import Control.Monad.Except (MonadError)
import Control.Monad.Reader
import Control.Monad.State (StateT (runStateT))
import Data.Singletons.Decide
import Data.Text
import PlutusCore.Compiler.Types (initUPLCSimplifierTrace)
import PlutusPrelude hiding ((%~))

-- Note that we use for erroring the original term's annotation
Expand Down Expand Up @@ -113,7 +111,7 @@ compileProgram = curry $ \case
-- first self-"compile" to plc (just for reusing code)
compileProgram sng1 (SPlc n2 a2)
-- PLC.compileProgram subsumes uplcOptimise
>=> (PLC.runQuoteT . PLC.evalCompile PLC.defaultCompilationOpts .
>=> (PLC.runQuoteT . flip runReaderT PLC.defaultCompilationOpts .
plcToUplcViaName n2 PLC.compileProgram)
>=> pure . UPLC.UnrestrictedProgram

Expand Down Expand Up @@ -230,8 +228,7 @@ uplcOptimise =
case safeOrUnsafe of
SafeOptimise -> set UPLC.soConservativeOpts True
UnsafeOptimise -> id
in (fmap . fmap) fst
. fmap (PLC.runQuoteT . flip runStateT initUPLCSimplifierTrace)
in fmap PLC.runQuoteT
. _Wrapped
. uplcViaName (UPLC.simplifyProgram sOpts def)

Expand Down
1 change: 1 addition & 0 deletions plutus-core/plutus-core.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ library
UntypedPlutusCore.Rename
UntypedPlutusCore.Size
UntypedPlutusCore.Transform.CaseOfCase
UntypedPlutusCore.Transform.Simplifier

other-modules:
Data.Aeson.Flatten
Expand Down
30 changes: 1 addition & 29 deletions plutus-core/plutus-core/src/PlutusCore/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ module PlutusCore.Compiler
( module Opts
, compileTerm
, compileProgram
, runCompile
, evalCompile
) where

import PlutusCore.Compiler.Erase
Expand All @@ -16,14 +14,12 @@ import UntypedPlutusCore.Core.Type qualified as UPLC
import UntypedPlutusCore.Simplify qualified as UPLC

import Control.Lens (view)
import Control.Monad.Reader (MonadReader, ReaderT (runReaderT))
import Control.Monad.State (MonadState (..), StateT (runStateT))
import Control.Monad.Reader (MonadReader)

-- | Compile a PLC term to UPLC, and optimize it.
compileTerm
:: (Compiling m uni fun name a
, MonadReader (CompilationOpts name fun a) m
, MonadState (UPLCSimplifierTrace name uni fun a) m
)
=> Term tyname name uni fun a
-> m (UPLC.Term name uni fun a)
Expand All @@ -38,31 +34,7 @@ compileTerm t = do
compileProgram
:: (Compiling m uni fun name a
, MonadReader (CompilationOpts name fun a) m
, MonadState (UPLCSimplifierTrace name uni fun a) m
)
=> Program tyname name uni fun a
-> m (UPLC.Program name uni fun a)
compileProgram (Program a v t) = UPLC.Program a v <$> compileTerm t

type Compile m name uni fun a =
ReaderT
(CompilationOpts name fun a)
(StateT
(UPLCSimplifierTrace name uni fun a)
m
)

runCompile
:: CompilationOpts name fun a
-> Compile m name uni fun a b
-> m (b, UPLCSimplifierTrace name uni fun a)
runCompile opts =
flip runStateT initUPLCSimplifierTrace
. flip runReaderT opts

evalCompile
:: Functor m
=> CompilationOpts name fun a
-> Compile m name uni fun a b
-> m b
evalCompile opts = fmap fst . runCompile opts
13 changes: 0 additions & 13 deletions plutus-core/plutus-core/src/PlutusCore/Compiler/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,6 @@ import Data.Hashable
import PlutusCore.Builtin
import PlutusCore.Name.Unique
import PlutusCore.Quote
import UntypedPlutusCore.Core.Type qualified as UPLC

-- TODO1: move somewhere more appropriate?
-- TODO2: we probably don't want this in memory so after MVP
-- we should consider serializing this to disk
newtype UPLCSimplifierTrace name uni fun a =
UPLCSimplifierTrace
{ uplcSimplifierTrace
:: [UPLC.Term name uni fun a]
}

initUPLCSimplifierTrace :: UPLCSimplifierTrace name uni fun a
initUPLCSimplifierTrace = UPLCSimplifierTrace []

type Compiling m uni fun name a =
( ToBuiltinMeaning uni fun
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ compileTplcProgramOrFail
compileTplcProgramOrFail plcProgram =
handlePirErrorByFailing @SrcSpan =<< do
TPLC.compileProgram plcProgram
& TPLC.evalCompile TPLC.defaultCompilationOpts
& flip runReaderT TPLC.defaultCompilationOpts
& runQuoteT
& runExceptT

Expand Down
2 changes: 1 addition & 1 deletion plutus-core/testlib/PlutusCore/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ instance
toUPlc =
pure
. TPLC.runQuote
. TPLC.evalCompile TPLC.defaultCompilationOpts
. flip runReaderT TPLC.defaultCompilationOpts
. TPLC.compileProgram

instance ToUPlc (UPLC.Program UPLC.NamedDeBruijn uni fun ()) uni fun where
Expand Down
81 changes: 52 additions & 29 deletions plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module UntypedPlutusCore.Simplify (
module Opts,
simplifyTerm,
simplifyProgram,
simplifyProgramWithTrace,
InlineHints (..),
) where

Expand All @@ -21,74 +22,96 @@ import UntypedPlutusCore.Transform.Cse
import UntypedPlutusCore.Transform.FloatDelay (floatDelay)
import UntypedPlutusCore.Transform.ForceDelay (forceDelay)
import UntypedPlutusCore.Transform.Inline (InlineHints (..), inline)
import UntypedPlutusCore.Transform.Simplifier

import Control.Monad
import Control.Monad.State.Class (MonadState)
import Control.Monad.State.Class qualified as State
import Data.List as List (foldl')
import Data.Typeable

simplifyProgram ::
forall name uni fun m a.
(Compiling m uni fun name a
, MonadState (UPLCSimplifierTrace name uni fun a) m
) =>
(Compiling m uni fun name a) =>
SimplifyOpts name a ->
BuiltinSemanticsVariant fun ->
Program name uni fun a ->
m (Program name uni fun a)
simplifyProgram opts builtinSemanticsVariant (Program a v t) =
Program a v <$> simplifyTerm opts builtinSemanticsVariant t

simplifyProgramWithTrace ::
forall name uni fun m a.
(Compiling m uni fun name a) =>
SimplifyOpts name a ->
BuiltinSemanticsVariant fun ->
Program name uni fun a ->
m (Program name uni fun a, SimplifierTrace name uni fun a)
simplifyProgramWithTrace opts builtinSemanticsVariant (Program a v t) = do
(result, trace) <-
runSimplifierT
$ termSimplifier opts builtinSemanticsVariant t
pure (Program a v result, trace)

simplifyTerm ::
forall name uni fun m a.
( Compiling m uni fun name a
, MonadState (UPLCSimplifierTrace name uni fun a) m
) =>
(Compiling m uni fun name a) =>
SimplifyOpts name a ->
BuiltinSemanticsVariant fun ->
Term name uni fun a ->
m (Term name uni fun a)
simplifyTerm opts builtinSemanticsVariant =
simplifyTerm opts builtinSemanticsVariant term =
evalSimplifierT $ termSimplifier opts builtinSemanticsVariant term

termSimplifier ::
forall name uni fun m a.
(Compiling m uni fun name a) =>
SimplifyOpts name a ->
BuiltinSemanticsVariant fun ->
Term name uni fun a ->
SimplifierT name uni fun a m (Term name uni fun a)
termSimplifier opts builtinSemanticsVariant =
simplifyNTimes (_soMaxSimplifierIterations opts) >=> cseNTimes cseTimes
where
-- Run the simplifier @n@ times
simplifyNTimes :: Int -> Term name uni fun a -> m (Term name uni fun a)
simplifyNTimes ::
Int ->
Term name uni fun a ->
SimplifierT name uni fun a m (Term name uni fun a)
simplifyNTimes n = List.foldl' (>=>) pure $ map simplifyStep [1..n]

-- Run CSE @n@ times, interleaved with the simplifier.
-- See Note [CSE]
cseNTimes :: Int -> Term name uni fun a -> m (Term name uni fun a)
cseNTimes ::
Int ->
Term name uni fun a ->
SimplifierT name uni fun a m (Term name uni fun a)
cseNTimes n = foldl' (>=>) pure $ concatMap (\i -> [cseStep i, simplifyStep i]) [1..n]

-- generate simplification step
simplifyStep :: Int -> Term name uni fun a -> m (Term name uni fun a)
simplifyStep ::
Int ->
Term name uni fun a ->
SimplifierT name uni fun a m (Term name uni fun a)
simplifyStep _ =
traceAST
>=> floatDelay
>=> traceAST
>=> pure . forceDelay
>=> traceAST
>=> pure . caseOfCase'
>=> traceAST
>=> pure . caseReduce
>=> traceAST
floatDelay
>=> forceDelay
>=> caseOfCase'
>=> caseReduce
>=> inline (_soInlineConstants opts) (_soInlineHints opts) builtinSemanticsVariant
>=> traceAST

caseOfCase' :: Term name uni fun a -> Term name uni fun a
caseOfCase' ::
Term name uni fun a ->
SimplifierT name uni fun a m (Term name uni fun a)
caseOfCase' = case eqT @fun @DefaultFun of
Just Refl -> caseOfCase
Nothing -> id
Nothing -> pure

cseStep :: Int -> Term name uni fun a -> m (Term name uni fun a)
cseStep ::
Int ->
Term name uni fun a ->
SimplifierT name uni fun a m (Term name uni fun a)
cseStep _ =
case (eqT @name @Name, eqT @uni @PLC.DefaultUni) of
(Just Refl, Just Refl) -> cse builtinSemanticsVariant
_ -> pure

traceAST ast = do
State.modify' (\st -> st { uplcSimplifierTrace = uplcSimplifierTrace st ++ [ast] })
return ast

cseTimes = if _soConservativeOpts opts then 0 else _soMaxCseIterations opts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,23 @@ module UntypedPlutusCore.Transform.CaseOfCase (caseOfCase) where
import PlutusCore qualified as PLC
import PlutusCore.MkPlc
import UntypedPlutusCore.Core
import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (CaseOfCase), SimplifierT,
recordSimplification)

import Control.Lens

caseOfCase :: (fun ~ PLC.DefaultFun) => Term name uni fun a -> Term name uni fun a
caseOfCase = transformOf termSubterms $ \case
caseOfCase
:: fun ~ PLC.DefaultFun
=> Monad m
=> Term name uni fun a
-> SimplifierT name uni fun a m (Term name uni fun a)
caseOfCase term = do
let result = transformOf termSubterms processTerm term
recordSimplification term CaseOfCase result
return result

processTerm :: (fun ~ PLC.DefaultFun) => Term name uni fun a -> Term name uni fun a
processTerm = \case
Case ann scrut alts
| ( ite@(Force a (Builtin _ PLC.IfThenElse))
, [cond, (trueAnn, true@Constr{}), (falseAnn, false@Constr{})]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@ module UntypedPlutusCore.Transform.CaseReduce

import PlutusCore.MkPlc
import UntypedPlutusCore.Core
import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (CaseReduce), SimplifierT,
recordSimplification)

import Control.Lens (transformOf)
import Data.Vector qualified as V

caseReduce :: Term name uni fun a -> Term name uni fun a
caseReduce = transformOf termSubterms processTerm
caseReduce
:: Monad m
=> Term name uni fun a
-> SimplifierT name uni fun a m (Term name uni fun a)
caseReduce term = do
let result = transformOf termSubterms processTerm term
recordSimplification term CaseReduce result
return result

processTerm :: Term name uni fun a -> Term name uni fun a
processTerm = \case
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import PlutusCore.Builtin (ToBuiltinMeaning (BuiltinSemanticsVariant))
import UntypedPlutusCore.Core
import UntypedPlutusCore.Purity (isWorkFree)
import UntypedPlutusCore.Size (termSize)
import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (CSE), SimplifierT,
recordSimplification)

import Control.Arrow ((>>>))
import Control.Lens (foldrOf, transformOf)
Expand Down Expand Up @@ -215,7 +217,7 @@ cse ::
) =>
BuiltinSemanticsVariant fun ->
Term Name uni fun ann ->
m (Term Name uni fun ann)
SimplifierT Name uni fun ann m (Term Name uni fun ann)
cse builtinSemanticsVariant t0 = do
t <- rename t0
let annotated = annotate t
Expand All @@ -229,7 +231,9 @@ cse builtinSemanticsVariant t0 = do
. join
. Map.elems
$ countOccs builtinSemanticsVariant annotated
mkCseTerm commonSubexprs annotated
result <- mkCseTerm commonSubexprs annotated
recordSimplification t0 CSE result
return result

-- | The second pass. See Note [CSE].
annotate :: Term name uni fun ann -> Term name uni fun (Path, ann)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,25 @@ import PlutusCore.Name.UniqueMap qualified as UMap
import PlutusCore.Name.UniqueSet qualified as USet
import UntypedPlutusCore.Core.Plated (termSubterms)
import UntypedPlutusCore.Core.Type (Term (..))
import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (FloatDelay), SimplifierT,
recordSimplification)

import Control.Lens (forOf, forOf_, transformOf)
import Control.Monad ((>=>))
import Control.Monad.Trans.Writer.CPS (Writer, execWriter, runWriter, tell)

floatDelay ::
(PLC.MonadQuote m, PLC.Rename (Term name uni fun a), PLC.HasUnique name PLC.TermUnique) =>
( PLC.MonadQuote m
, PLC.Rename (Term name uni fun a)
, PLC.HasUnique name PLC.TermUnique
) =>
Term name uni fun a ->
m (Term name uni fun a)
floatDelay =
PLC.rename >=> \t ->
pure . uncurry (flip simplifyBodies) $ simplifyArgs (unforcedVars t) t
SimplifierT name uni fun a m (Term name uni fun a)
floatDelay term = do
result <-
PLC.rename term >>= \t ->
pure . uncurry (flip simplifyBodies) $ simplifyArgs (unforcedVars t) t
recordSimplification term FloatDelay result
return result

{- | First pass. Returns the names of all variables, at least one occurrence
of which is not under `Force`.
Expand Down
Loading

1 comment on commit 388ad69

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'Plutus Benchmarks'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.05.

Benchmark suite Current: 388ad69 Previous: 557a2c7 Ratio
validation-multisig-sm-9 557.9 μs 514.9 μs 1.08
validation-multisig-sm-10 791.6 μs 554.6 μs 1.43
validation-ping-pong-1 464.3 μs 324.2 μs 1.43
validation-ping-pong-2 464.8 μs 326.5 μs 1.42
validation-ping-pong_2-1 284.4 μs 199.9 μs 1.42
validation-prism-1 239.5 μs 167.9 μs 1.43
validation-prism-2 582.8 μs 411.7 μs 1.42
validation-prism-3 525.8 μs 368.3 μs 1.43
validation-pubkey-1 201.5 μs 141.7 μs 1.42
validation-stablecoin_1-1 1297 μs 909.6 μs 1.43
validation-stablecoin_1-2 277.4 μs 196.1 μs 1.41
validation-stablecoin_1-3 1491 μs 1047 μs 1.42
validation-stablecoin_1-4 296.5 μs 207.9 μs 1.43
validation-stablecoin_1-5 1918 μs 1348 μs 1.42
validation-stablecoin_1-6 327.4 μs 258 μs 1.27
validation-uniswap-2 256.3 μs 219.7 μs 1.17
validation-uniswap-4 479.3 μs 336 μs 1.43
validation-uniswap-5 1225 μs 1166 μs 1.05
validation-decode-auction_1-1 262.6 μs 189.7 μs 1.38
validation-decode-auction_1-2 731.7 μs 524.6 μs 1.39
validation-decode-auction_1-3 731.2 μs 524.9 μs 1.39
validation-decode-auction_1-4 262.6 μs 191.2 μs 1.37
validation-decode-auction_2-1 263.9 μs 196.3 μs 1.34
validation-decode-auction_2-3 568.9 μs 524.6 μs 1.08
validation-decode-game-sm-success_2-6 224.7 μs 161.6 μs 1.39
validation-decode-multisig-sm-2 757.4 μs 566.8 μs 1.34
validation-decode-multisig-sm-7 718.9 μs 567.8 μs 1.27
validation-decode-multisig-sm-8 787.3 μs 567.9 μs 1.39
validation-decode-multisig-sm-9 789.2 μs 567.7 μs 1.39
validation-decode-multisig-sm-10 788.7 μs 567.4 μs 1.39
validation-decode-ping-pong-1 655.7 μs 476.5 μs 1.38
validation-decode-ping-pong_2-1 661.4 μs 476.3 μs 1.39
validation-decode-prism-1 220.8 μs 175.3 μs 1.26
validation-decode-token-account-1 319 μs 229.9 μs 1.39
validation-decode-token-account-2 268.8 μs 215.5 μs 1.25
nofib-primetest/10digits 28750 μs 23740 μs 1.21
nofib-primetest/30digits 88650 μs 62250 μs 1.42
nofib-primetest/50digits 146600 μs 103100 μs 1.42
nofib-queens4x4/bt 7586 μs 5604 μs 1.35
marlowe-semantics/0000020002010200020101020201000100010001020101020201010000020102 456.1 μs 321.5 μs 1.42
marlowe-semantics/004025fd712d6c325ffa12c16d157064192992faf62e0b991d7310a2f91666b8 1150 μs 808 μs 1.42
marlowe-semantics/0104010200020000040103020102020004040300030304040400010301040303 1102 μs 998.3 μs 1.10
marlowe-semantics/04000f0b04051006000e060f09080d0b090d0104050a0b0f0506070f0a070008 1040 μs 733.2 μs 1.42
marlowe-semantics/0543a00ba1f63076c1db6bf94c6ff13ae7d266dd7544678743890b0e8e1add63 1442 μs 1012.9999999999999 μs 1.42
marlowe-semantics/0705030002040601010206030604080208020207000101060706050502040301 1405 μs 986.8 μs 1.42
marlowe-semantics/07070c070510030509010e050d00040907050e0a0d06030f1006030701020607 1460 μs 1025 μs 1.42
marlowe-semantics/0bcfd9487614104ec48de2ea0b2c0979866a95115748c026f9ec129384c262c4 1593 μs 1121 μs 1.42
marlowe-semantics/0be82588e4e4bf2ef428d2f44b7687bbb703031d8de696d90ec789e70d6bc1d8 1941 μs 1568 μs 1.24
marlowe-semantics/1d56060c3b271226064c672a282663643b1b0823471c67737f0b076870331260 1101 μs 1039 μs 1.06
marlowe-semantics/1d6e3c137149a440f35e0efc685b16bfb8052ebcf66ec4ad77e51c11501381c7 429.5 μs 302 μs 1.42
marlowe-semantics/1f0f02191604101e1f201016171604060d010d1d1c150e110a110e1006160a0d 1452 μs 1016.9999999999999 μs 1.43

This comment was automatically generated by workflow using github-action-benchmark.

CC: @IntersectMBO/plutus-core

Please sign in to comment.