From 388ad69e43850674efe1e78a121f0c256ff62de4 Mon Sep 17 00:00:00 2001 From: Ana Pantilie <45069775+ana-pantilie@users.noreply.github.com> Date: Fri, 18 Oct 2024 16:41:20 +0300 Subject: [PATCH] Add simplifier tags in UPLC simplifier (#6540) Co-authored-by: Ramsay Taylor --- .../executables/plutus/AnyProgram/Compile.hs | 7 +- plutus-core/plutus-core.cabal | 1 + .../plutus-core/src/PlutusCore/Compiler.hs | 30 +--- .../src/PlutusCore/Compiler/Types.hs | 13 -- .../Transform/StrictLetRec/Tests/Lib.hs | 2 +- plutus-core/testlib/PlutusCore/Test.hs | 2 +- .../src/UntypedPlutusCore/Simplify.hs | 81 +++++---- .../UntypedPlutusCore/Transform/CaseOfCase.hs | 16 +- .../UntypedPlutusCore/Transform/CaseReduce.hs | 12 +- .../src/UntypedPlutusCore/Transform/Cse.hs | 8 +- .../UntypedPlutusCore/Transform/FloatDelay.hs | 19 +- .../UntypedPlutusCore/Transform/ForceDelay.hs | 12 +- .../src/UntypedPlutusCore/Transform/Inline.hs | 23 ++- .../UntypedPlutusCore/Transform/Simplifier.hs | 96 +++++++++++ .../test/Transform/CaseOfCase/Test.hs | 10 +- .../test/Transform/Simplify/Lib.hs | 4 - plutus-executables/executables/pir/Main.hs | 2 +- plutus-executables/executables/uplc/Main.hs | 20 +-- plutus-metatheory/plutus-metatheory.cabal | 6 + .../src/VerifiedCompilation.lagda.md | 163 +++++++++--------- .../VerifiedCompilation/UCaseOfCase.lagda.md | 10 +- plutus-tx-plugin/src/PlutusTx/Plugin.hs | 4 +- plutus-tx/src/PlutusTx/Lift.hs | 7 +- 23 files changed, 340 insertions(+), 208 deletions(-) create mode 100644 plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Simplifier.hs diff --git a/plutus-core/executables/plutus/AnyProgram/Compile.hs b/plutus-core/executables/plutus/AnyProgram/Compile.hs index 41cbde1f67d..610b5e12900 100644 --- a/plutus-core/executables/plutus/AnyProgram/Compile.hs +++ b/plutus-core/executables/plutus/AnyProgram/Compile.hs @@ -31,10 +31,8 @@ import Control.Lens hiding ((%~)) import Control.Monad.Error.Lens import Control.Monad.Except (MonadError) import Control.Monad.Reader -import Control.Monad.State (StateT (runStateT)) import Data.Singletons.Decide import Data.Text -import PlutusCore.Compiler.Types (initUPLCSimplifierTrace) import PlutusPrelude hiding ((%~)) -- Note that we use for erroring the original term's annotation @@ -113,7 +111,7 @@ compileProgram = curry $ \case -- first self-"compile" to plc (just for reusing code) compileProgram sng1 (SPlc n2 a2) -- PLC.compileProgram subsumes uplcOptimise - >=> (PLC.runQuoteT . PLC.evalCompile PLC.defaultCompilationOpts . + >=> (PLC.runQuoteT . flip runReaderT PLC.defaultCompilationOpts . plcToUplcViaName n2 PLC.compileProgram) >=> pure . UPLC.UnrestrictedProgram @@ -230,8 +228,7 @@ uplcOptimise = case safeOrUnsafe of SafeOptimise -> set UPLC.soConservativeOpts True UnsafeOptimise -> id - in (fmap . fmap) fst - . fmap (PLC.runQuoteT . flip runStateT initUPLCSimplifierTrace) + in fmap PLC.runQuoteT . _Wrapped . uplcViaName (UPLC.simplifyProgram sOpts def) diff --git a/plutus-core/plutus-core.cabal b/plutus-core/plutus-core.cabal index 2ac8f789cf7..fdd3affde3d 100644 --- a/plutus-core/plutus-core.cabal +++ b/plutus-core/plutus-core.cabal @@ -212,6 +212,7 @@ library UntypedPlutusCore.Rename UntypedPlutusCore.Size UntypedPlutusCore.Transform.CaseOfCase + UntypedPlutusCore.Transform.Simplifier other-modules: Data.Aeson.Flatten diff --git a/plutus-core/plutus-core/src/PlutusCore/Compiler.hs b/plutus-core/plutus-core/src/PlutusCore/Compiler.hs index 1ada02bf5ef..50c8b5546aa 100644 --- a/plutus-core/plutus-core/src/PlutusCore/Compiler.hs +++ b/plutus-core/plutus-core/src/PlutusCore/Compiler.hs @@ -2,8 +2,6 @@ module PlutusCore.Compiler ( module Opts , compileTerm , compileProgram - , runCompile - , evalCompile ) where import PlutusCore.Compiler.Erase @@ -16,14 +14,12 @@ import UntypedPlutusCore.Core.Type qualified as UPLC import UntypedPlutusCore.Simplify qualified as UPLC import Control.Lens (view) -import Control.Monad.Reader (MonadReader, ReaderT (runReaderT)) -import Control.Monad.State (MonadState (..), StateT (runStateT)) +import Control.Monad.Reader (MonadReader) -- | Compile a PLC term to UPLC, and optimize it. compileTerm :: (Compiling m uni fun name a , MonadReader (CompilationOpts name fun a) m - , MonadState (UPLCSimplifierTrace name uni fun a) m ) => Term tyname name uni fun a -> m (UPLC.Term name uni fun a) @@ -38,31 +34,7 @@ compileTerm t = do compileProgram :: (Compiling m uni fun name a , MonadReader (CompilationOpts name fun a) m - , MonadState (UPLCSimplifierTrace name uni fun a) m ) => Program tyname name uni fun a -> m (UPLC.Program name uni fun a) compileProgram (Program a v t) = UPLC.Program a v <$> compileTerm t - -type Compile m name uni fun a = - ReaderT - (CompilationOpts name fun a) - (StateT - (UPLCSimplifierTrace name uni fun a) - m - ) - -runCompile - :: CompilationOpts name fun a - -> Compile m name uni fun a b - -> m (b, UPLCSimplifierTrace name uni fun a) -runCompile opts = - flip runStateT initUPLCSimplifierTrace - . flip runReaderT opts - -evalCompile - :: Functor m - => CompilationOpts name fun a - -> Compile m name uni fun a b - -> m b -evalCompile opts = fmap fst . runCompile opts diff --git a/plutus-core/plutus-core/src/PlutusCore/Compiler/Types.hs b/plutus-core/plutus-core/src/PlutusCore/Compiler/Types.hs index 8012a7e767a..a32d6096ff1 100644 --- a/plutus-core/plutus-core/src/PlutusCore/Compiler/Types.hs +++ b/plutus-core/plutus-core/src/PlutusCore/Compiler/Types.hs @@ -6,19 +6,6 @@ import Data.Hashable import PlutusCore.Builtin import PlutusCore.Name.Unique import PlutusCore.Quote -import UntypedPlutusCore.Core.Type qualified as UPLC - --- TODO1: move somewhere more appropriate? --- TODO2: we probably don't want this in memory so after MVP --- we should consider serializing this to disk -newtype UPLCSimplifierTrace name uni fun a = - UPLCSimplifierTrace - { uplcSimplifierTrace - :: [UPLC.Term name uni fun a] - } - -initUPLCSimplifierTrace :: UPLCSimplifierTrace name uni fun a -initUPLCSimplifierTrace = UPLCSimplifierTrace [] type Compiling m uni fun name a = ( ToBuiltinMeaning uni fun diff --git a/plutus-core/plutus-ir/test/PlutusIR/Transform/StrictLetRec/Tests/Lib.hs b/plutus-core/plutus-ir/test/PlutusIR/Transform/StrictLetRec/Tests/Lib.hs index 354154d375c..17f53f66da2 100644 --- a/plutus-core/plutus-ir/test/PlutusIR/Transform/StrictLetRec/Tests/Lib.hs +++ b/plutus-core/plutus-ir/test/PlutusIR/Transform/StrictLetRec/Tests/Lib.hs @@ -75,7 +75,7 @@ compileTplcProgramOrFail compileTplcProgramOrFail plcProgram = handlePirErrorByFailing @SrcSpan =<< do TPLC.compileProgram plcProgram - & TPLC.evalCompile TPLC.defaultCompilationOpts + & flip runReaderT TPLC.defaultCompilationOpts & runQuoteT & runExceptT diff --git a/plutus-core/testlib/PlutusCore/Test.hs b/plutus-core/testlib/PlutusCore/Test.hs index dcba344b762..cf31956bdc1 100644 --- a/plutus-core/testlib/PlutusCore/Test.hs +++ b/plutus-core/testlib/PlutusCore/Test.hs @@ -193,7 +193,7 @@ instance toUPlc = pure . TPLC.runQuote - . TPLC.evalCompile TPLC.defaultCompilationOpts + . flip runReaderT TPLC.defaultCompilationOpts . TPLC.compileProgram instance ToUPlc (UPLC.Program UPLC.NamedDeBruijn uni fun ()) uni fun where diff --git a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Simplify.hs b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Simplify.hs index cf34d24c054..f4cdf837503 100644 --- a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Simplify.hs +++ b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Simplify.hs @@ -6,6 +6,7 @@ module UntypedPlutusCore.Simplify ( module Opts, simplifyTerm, simplifyProgram, + simplifyProgramWithTrace, InlineHints (..), ) where @@ -21,18 +22,15 @@ import UntypedPlutusCore.Transform.Cse import UntypedPlutusCore.Transform.FloatDelay (floatDelay) import UntypedPlutusCore.Transform.ForceDelay (forceDelay) import UntypedPlutusCore.Transform.Inline (InlineHints (..), inline) +import UntypedPlutusCore.Transform.Simplifier import Control.Monad -import Control.Monad.State.Class (MonadState) -import Control.Monad.State.Class qualified as State import Data.List as List (foldl') import Data.Typeable simplifyProgram :: forall name uni fun m a. - (Compiling m uni fun name a - , MonadState (UPLCSimplifierTrace name uni fun a) m - ) => + (Compiling m uni fun name a) => SimplifyOpts name a -> BuiltinSemanticsVariant fun -> Program name uni fun a -> @@ -40,55 +38,80 @@ simplifyProgram :: simplifyProgram opts builtinSemanticsVariant (Program a v t) = Program a v <$> simplifyTerm opts builtinSemanticsVariant t +simplifyProgramWithTrace :: + forall name uni fun m a. + (Compiling m uni fun name a) => + SimplifyOpts name a -> + BuiltinSemanticsVariant fun -> + Program name uni fun a -> + m (Program name uni fun a, SimplifierTrace name uni fun a) +simplifyProgramWithTrace opts builtinSemanticsVariant (Program a v t) = do + (result, trace) <- + runSimplifierT + $ termSimplifier opts builtinSemanticsVariant t + pure (Program a v result, trace) + simplifyTerm :: forall name uni fun m a. - ( Compiling m uni fun name a - , MonadState (UPLCSimplifierTrace name uni fun a) m - ) => + (Compiling m uni fun name a) => SimplifyOpts name a -> BuiltinSemanticsVariant fun -> Term name uni fun a -> m (Term name uni fun a) -simplifyTerm opts builtinSemanticsVariant = +simplifyTerm opts builtinSemanticsVariant term = + evalSimplifierT $ termSimplifier opts builtinSemanticsVariant term + +termSimplifier :: + forall name uni fun m a. + (Compiling m uni fun name a) => + SimplifyOpts name a -> + BuiltinSemanticsVariant fun -> + Term name uni fun a -> + SimplifierT name uni fun a m (Term name uni fun a) +termSimplifier opts builtinSemanticsVariant = simplifyNTimes (_soMaxSimplifierIterations opts) >=> cseNTimes cseTimes where -- Run the simplifier @n@ times - simplifyNTimes :: Int -> Term name uni fun a -> m (Term name uni fun a) + simplifyNTimes :: + Int -> + Term name uni fun a -> + SimplifierT name uni fun a m (Term name uni fun a) simplifyNTimes n = List.foldl' (>=>) pure $ map simplifyStep [1..n] -- Run CSE @n@ times, interleaved with the simplifier. -- See Note [CSE] - cseNTimes :: Int -> Term name uni fun a -> m (Term name uni fun a) + cseNTimes :: + Int -> + Term name uni fun a -> + SimplifierT name uni fun a m (Term name uni fun a) cseNTimes n = foldl' (>=>) pure $ concatMap (\i -> [cseStep i, simplifyStep i]) [1..n] -- generate simplification step - simplifyStep :: Int -> Term name uni fun a -> m (Term name uni fun a) + simplifyStep :: + Int -> + Term name uni fun a -> + SimplifierT name uni fun a m (Term name uni fun a) simplifyStep _ = - traceAST - >=> floatDelay - >=> traceAST - >=> pure . forceDelay - >=> traceAST - >=> pure . caseOfCase' - >=> traceAST - >=> pure . caseReduce - >=> traceAST + floatDelay + >=> forceDelay + >=> caseOfCase' + >=> caseReduce >=> inline (_soInlineConstants opts) (_soInlineHints opts) builtinSemanticsVariant - >=> traceAST - caseOfCase' :: Term name uni fun a -> Term name uni fun a + caseOfCase' :: + Term name uni fun a -> + SimplifierT name uni fun a m (Term name uni fun a) caseOfCase' = case eqT @fun @DefaultFun of Just Refl -> caseOfCase - Nothing -> id + Nothing -> pure - cseStep :: Int -> Term name uni fun a -> m (Term name uni fun a) + cseStep :: + Int -> + Term name uni fun a -> + SimplifierT name uni fun a m (Term name uni fun a) cseStep _ = case (eqT @name @Name, eqT @uni @PLC.DefaultUni) of (Just Refl, Just Refl) -> cse builtinSemanticsVariant _ -> pure - traceAST ast = do - State.modify' (\st -> st { uplcSimplifierTrace = uplcSimplifierTrace st ++ [ast] }) - return ast - cseTimes = if _soConservativeOpts opts then 0 else _soMaxCseIterations opts diff --git a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/CaseOfCase.hs b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/CaseOfCase.hs index 67f376f02e5..096bdf4ce76 100644 --- a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/CaseOfCase.hs +++ b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/CaseOfCase.hs @@ -21,11 +21,23 @@ module UntypedPlutusCore.Transform.CaseOfCase (caseOfCase) where import PlutusCore qualified as PLC import PlutusCore.MkPlc import UntypedPlutusCore.Core +import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (CaseOfCase), SimplifierT, + recordSimplification) import Control.Lens -caseOfCase :: (fun ~ PLC.DefaultFun) => Term name uni fun a -> Term name uni fun a -caseOfCase = transformOf termSubterms $ \case +caseOfCase + :: fun ~ PLC.DefaultFun + => Monad m + => Term name uni fun a + -> SimplifierT name uni fun a m (Term name uni fun a) +caseOfCase term = do + let result = transformOf termSubterms processTerm term + recordSimplification term CaseOfCase result + return result + +processTerm :: (fun ~ PLC.DefaultFun) => Term name uni fun a -> Term name uni fun a +processTerm = \case Case ann scrut alts | ( ite@(Force a (Builtin _ PLC.IfThenElse)) , [cond, (trueAnn, true@Constr{}), (falseAnn, false@Constr{})] diff --git a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/CaseReduce.hs b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/CaseReduce.hs index f49f151a04e..d17a44236ae 100644 --- a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/CaseReduce.hs +++ b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/CaseReduce.hs @@ -6,12 +6,20 @@ module UntypedPlutusCore.Transform.CaseReduce import PlutusCore.MkPlc import UntypedPlutusCore.Core +import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (CaseReduce), SimplifierT, + recordSimplification) import Control.Lens (transformOf) import Data.Vector qualified as V -caseReduce :: Term name uni fun a -> Term name uni fun a -caseReduce = transformOf termSubterms processTerm +caseReduce + :: Monad m + => Term name uni fun a + -> SimplifierT name uni fun a m (Term name uni fun a) +caseReduce term = do + let result = transformOf termSubterms processTerm term + recordSimplification term CaseReduce result + return result processTerm :: Term name uni fun a -> Term name uni fun a processTerm = \case diff --git a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Cse.hs b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Cse.hs index 46e7abaa36c..6663f314484 100644 --- a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Cse.hs +++ b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Cse.hs @@ -11,6 +11,8 @@ import PlutusCore.Builtin (ToBuiltinMeaning (BuiltinSemanticsVariant)) import UntypedPlutusCore.Core import UntypedPlutusCore.Purity (isWorkFree) import UntypedPlutusCore.Size (termSize) +import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (CSE), SimplifierT, + recordSimplification) import Control.Arrow ((>>>)) import Control.Lens (foldrOf, transformOf) @@ -215,7 +217,7 @@ cse :: ) => BuiltinSemanticsVariant fun -> Term Name uni fun ann -> - m (Term Name uni fun ann) + SimplifierT Name uni fun ann m (Term Name uni fun ann) cse builtinSemanticsVariant t0 = do t <- rename t0 let annotated = annotate t @@ -229,7 +231,9 @@ cse builtinSemanticsVariant t0 = do . join . Map.elems $ countOccs builtinSemanticsVariant annotated - mkCseTerm commonSubexprs annotated + result <- mkCseTerm commonSubexprs annotated + recordSimplification t0 CSE result + return result -- | The second pass. See Note [CSE]. annotate :: Term name uni fun ann -> Term name uni fun (Path, ann) diff --git a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/FloatDelay.hs b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/FloatDelay.hs index 752ec643cea..4070d3d0fed 100644 --- a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/FloatDelay.hs +++ b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/FloatDelay.hs @@ -61,18 +61,25 @@ import PlutusCore.Name.UniqueMap qualified as UMap import PlutusCore.Name.UniqueSet qualified as USet import UntypedPlutusCore.Core.Plated (termSubterms) import UntypedPlutusCore.Core.Type (Term (..)) +import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (FloatDelay), SimplifierT, + recordSimplification) import Control.Lens (forOf, forOf_, transformOf) -import Control.Monad ((>=>)) import Control.Monad.Trans.Writer.CPS (Writer, execWriter, runWriter, tell) floatDelay :: - (PLC.MonadQuote m, PLC.Rename (Term name uni fun a), PLC.HasUnique name PLC.TermUnique) => + ( PLC.MonadQuote m + , PLC.Rename (Term name uni fun a) + , PLC.HasUnique name PLC.TermUnique + ) => Term name uni fun a -> - m (Term name uni fun a) -floatDelay = - PLC.rename >=> \t -> - pure . uncurry (flip simplifyBodies) $ simplifyArgs (unforcedVars t) t + SimplifierT name uni fun a m (Term name uni fun a) +floatDelay term = do + result <- + PLC.rename term >>= \t -> + pure . uncurry (flip simplifyBodies) $ simplifyArgs (unforcedVars t) t + recordSimplification term FloatDelay result + return result {- | First pass. Returns the names of all variables, at least one occurrence of which is not under `Force`. diff --git a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/ForceDelay.hs b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/ForceDelay.hs index db95d330b42..4f9ccc38bbb 100644 --- a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/ForceDelay.hs +++ b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/ForceDelay.hs @@ -136,6 +136,8 @@ module UntypedPlutusCore.Transform.ForceDelay ) where import UntypedPlutusCore.Core +import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (ForceDelay), SimplifierT, + recordSimplification) import Control.Lens (transformOf) import Control.Monad (guard) @@ -144,8 +146,14 @@ import Data.Foldable as Foldable (foldl') {- | Traverses the term, for each node applying the optimisation detailed above. For implementation details see 'optimisationProcedure'. -} -forceDelay :: Term name uni fun a -> Term name uni fun a -forceDelay = transformOf termSubterms processTerm +forceDelay + :: Monad m + => Term name uni fun a + -> SimplifierT name uni fun a m (Term name uni fun a) +forceDelay term = do + let result = transformOf termSubterms processTerm term + recordSimplification term ForceDelay result + return result {- | Checks whether the term is of the right form, and "pushes" the 'Force' down into the underlying lambda abstractions. diff --git a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Inline.hs b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Inline.hs index 30b7dc4b45c..7c4b69a99aa 100644 --- a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Inline.hs +++ b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Inline.hs @@ -41,6 +41,8 @@ import UntypedPlutusCore.Purity import UntypedPlutusCore.Rename () import UntypedPlutusCore.Size import UntypedPlutusCore.Subst +import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (Inline), SimplifierT, + recordSimplification) import Control.Lens hiding (Strict) import Control.Monad.Extra @@ -173,20 +175,23 @@ See Note [Inlining and global uniqueness] -} inline :: forall name uni fun m a. - (ExternalConstraints name uni fun m) => + ExternalConstraints name uni fun m => -- | inline constants Bool -> InlineHints name a -> PLC.BuiltinSemanticsVariant fun -> Term name uni fun a -> - m (Term name uni fun a) -inline inlineConstants hints builtinSemanticsVariant t = - liftQuote $ flip evalStateT mempty $ runReaderT (processTerm t) InlineInfo - { _iiUsages = Usages.termUsages t - , _iiHints = hints - , _iiBuiltinSemanticsVariant = builtinSemanticsVariant - , _iiInlineConstants = inlineConstants - } + SimplifierT name uni fun a m (Term name uni fun a) +inline inlineConstants hints builtinSemanticsVariant t = do + result <- + liftQuote $ flip evalStateT mempty $ runReaderT (processTerm t) InlineInfo + { _iiUsages = Usages.termUsages t + , _iiHints = hints + , _iiBuiltinSemanticsVariant = builtinSemanticsVariant + , _iiInlineConstants = inlineConstants + } + recordSimplification t Inline result + return result -- See Note [Differences from PIR inliner] 3 diff --git a/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Simplifier.hs b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Simplifier.hs new file mode 100644 index 00000000000..4047f09c6ae --- /dev/null +++ b/plutus-core/untyped-plutus-core/src/UntypedPlutusCore/Transform/Simplifier.hs @@ -0,0 +1,96 @@ +{-# LANGUAGE NamedFieldPuns #-} + +module UntypedPlutusCore.Transform.Simplifier ( + SimplifierT (..), + SimplifierTrace (..), + SimplifierStage (..), + Simplification (..), + runSimplifierT, + evalSimplifierT, + execSimplifierT, + Simplifier, + runSimplifier, + evalSimplifier, + execSimplifier, + initSimplifierTrace, + recordSimplification, +) where + +import Control.Monad.State (MonadTrans, StateT) +import Control.Monad.State qualified as State + +import Control.Monad.Identity (Identity, runIdentity) +import PlutusCore.Quote (MonadQuote) +import UntypedPlutusCore.Core.Type (Term) + +newtype SimplifierT name uni fun ann m a = + SimplifierT + { getSimplifierT :: StateT (SimplifierTrace name uni fun ann) m a + } + deriving newtype (Functor, Applicative, Monad, MonadTrans) + +instance MonadQuote m => MonadQuote (SimplifierT name uni fun ann m) + +runSimplifierT + :: SimplifierT name uni fun ann m a + -> m (a, SimplifierTrace name uni fun ann) +runSimplifierT = flip State.runStateT initSimplifierTrace . getSimplifierT + +evalSimplifierT + :: Monad m => SimplifierT name uni fun ann m a -> m a +evalSimplifierT = flip State.evalStateT initSimplifierTrace . getSimplifierT + +execSimplifierT + :: Monad m => SimplifierT name uni fun ann m a -> m (SimplifierTrace name uni fun ann) +execSimplifierT = flip State.execStateT initSimplifierTrace . getSimplifierT + +type Simplifier name uni fun ann = SimplifierT name uni fun ann Identity + +runSimplifier :: Simplifier name uni fun ann a -> (a, SimplifierTrace name uni fun ann) +runSimplifier = runIdentity . runSimplifierT + +evalSimplifier :: Simplifier name uni fun ann a -> a +evalSimplifier = runIdentity . evalSimplifierT + +execSimplifier :: Simplifier name uni fun ann a -> SimplifierTrace name uni fun ann +execSimplifier = runIdentity . execSimplifierT + +data SimplifierStage + = FloatDelay + | ForceDelay + | CaseOfCase + | CaseReduce + | Inline + | CSE + +data Simplification name uni fun a = + Simplification + { beforeAST :: Term name uni fun a + , stage :: SimplifierStage + , afterAST :: Term name uni fun a + } + +-- TODO2: we probably don't want this in memory so after MVP +-- we should consider serializing this to disk +newtype SimplifierTrace name uni fun a = + SimplifierTrace + { simplifierTrace + :: [Simplification name uni fun a] + } + +initSimplifierTrace :: SimplifierTrace name uni fun a +initSimplifierTrace = SimplifierTrace [] + +recordSimplification + :: Monad m + => Term name uni fun a + -> SimplifierStage + -> Term name uni fun a + -> SimplifierT name uni fun a m () +recordSimplification beforeAST stage afterAST = + let simplification = Simplification { beforeAST, stage, afterAST } + in + modify $ \st -> + st { simplifierTrace = simplification : simplifierTrace st } + where + modify f = SimplifierT $ State.modify' f diff --git a/plutus-core/untyped-plutus-core/test/Transform/CaseOfCase/Test.hs b/plutus-core/untyped-plutus-core/test/Transform/CaseOfCase/Test.hs index e5497483b65..f942dd431f7 100644 --- a/plutus-core/untyped-plutus-core/test/Transform/CaseOfCase/Test.hs +++ b/plutus-core/untyped-plutus-core/test/Transform/CaseOfCase/Test.hs @@ -26,6 +26,7 @@ import UntypedPlutusCore.Evaluation.Machine.Cek (CekMachineCosts, CekValue, Eval evaluateCek, noEmitter, unsafeSplitStructuralOperational) import UntypedPlutusCore.Transform.CaseOfCase (caseOfCase) +import UntypedPlutusCore.Transform.Simplifier (evalSimplifier) test_caseOfCase :: TestTree test_caseOfCase = @@ -111,12 +112,17 @@ caseOfCaseWithError = testCaseOfCaseWithError :: TestTree testCaseOfCaseWithError = testCase "Transformation doesn't evaluate error eagerly" do - let simplifiedTerm = caseOfCase caseOfCaseWithError + let simplifiedTerm = evalCaseOfCase caseOfCaseWithError evaluateUplc simplifiedTerm @?= evaluateUplc caseOfCaseWithError ---------------------------------------------------------------------------------------------------- -- Helper functions -------------------------------------------------------------------------------- +evalCaseOfCase + :: Term Name DefaultUni DefaultFun () + -> Term Name DefaultUni DefaultFun () +evalCaseOfCase term = evalSimplifier $ caseOfCase term + evaluateUplc :: UPLC.Term Name DefaultUni DefaultFun () -> EvaluationResult (UPLC.Term Name DefaultUni DefaultFun ()) @@ -136,4 +142,4 @@ goldenVsSimplified name = . encodeUtf8 . render . prettyClassicSimple - . caseOfCase + . evalCaseOfCase diff --git a/plutus-core/untyped-plutus-core/test/Transform/Simplify/Lib.hs b/plutus-core/untyped-plutus-core/test/Transform/Simplify/Lib.hs index c45d69afe2f..476ff9d7c54 100644 --- a/plutus-core/untyped-plutus-core/test/Transform/Simplify/Lib.hs +++ b/plutus-core/untyped-plutus-core/test/Transform/Simplify/Lib.hs @@ -3,12 +3,10 @@ module Transform.Simplify.Lib where import Control.Lens ((&), (.~)) -import Control.Monad.State (evalStateT) import Data.ByteString.Lazy qualified as BSL import Data.Text.Encoding (encodeUtf8) import PlutusCore qualified as PLC import PlutusCore.Builtin (BuiltinSemanticsVariant) -import PlutusCore.Compiler.Types (initUPLCSimplifierTrace) import PlutusCore.Pretty (PrettyPlc, Render (render), prettyPlcReadableSimple) import PlutusPrelude (Default (def)) import Test.Tasty (TestTree) @@ -27,7 +25,6 @@ goldenVsSimplified :: String -> Term Name PLC.DefaultUni PLC.DefaultFun () -> Te goldenVsSimplified name = goldenVsPretty ".uplc.golden" name . PLC.runQuote - . flip evalStateT initUPLCSimplifierTrace . simplifyTerm ( defaultSimplifyOpts -- Just run one iteration, to see what that does @@ -40,7 +37,6 @@ goldenVsCse :: String -> Term Name PLC.DefaultUni PLC.DefaultFun () -> TestTree goldenVsCse name = goldenVsPretty ".uplc.golden" name . PLC.runQuote - . flip evalStateT initUPLCSimplifierTrace . simplifyTerm ( defaultSimplifyOpts -- Just run one iteration, to see what that does diff --git a/plutus-executables/executables/pir/Main.hs b/plutus-executables/executables/pir/Main.hs index 98557593564..2075b3884dd 100644 --- a/plutus-executables/executables/pir/Main.hs +++ b/plutus-executables/executables/pir/Main.hs @@ -170,7 +170,7 @@ compileToUplc optimise plcProg = then PLC.defaultCompilationOpts else PLC.defaultCompilationOpts & PLC.coSimplifyOpts . UPLC.soMaxSimplifierIterations .~ 0 - in runQuote $ PLC.evalCompile plcCompilerOpts $ PLC.compileProgram plcProg + in runQuote $ flip runReaderT plcCompilerOpts $ PLC.compileProgram plcProg loadPirAndCompile :: CompileOptions -> IO () loadPirAndCompile (CompileOptions language optimise test inp ifmt outp ofmt mode) = do diff --git a/plutus-executables/executables/uplc/Main.hs b/plutus-executables/executables/uplc/Main.hs index 8f2fa60ac8c..eb6ca71648c 100644 --- a/plutus-executables/executables/uplc/Main.hs +++ b/plutus-executables/executables/uplc/Main.hs @@ -8,11 +8,11 @@ {-# LANGUAGE TypeSynonymInstances #-} {-# OPTIONS_GHC -Wno-orphans #-} +{-# LANGUAGE NamedFieldPuns #-} module Main (main) where import PlutusCore qualified as PLC import PlutusCore.Annotation (SrcSpan) -import PlutusCore.Compiler.Types (UPLCSimplifierTrace (..), initUPLCSimplifierTrace) import PlutusCore.Data (Data) import PlutusCore.Default (BuiltinSemanticsVariant (..)) import PlutusCore.Evaluation.Machine.ExBudget (ExBudget (..), ExRestrictingBudget (..)) @@ -33,11 +33,11 @@ import UntypedPlutusCore.Evaluation.Machine.SteppableCek.Internal qualified as D import UntypedPlutusCore qualified as UPLC import UntypedPlutusCore.DeBruijn (FreeVariableError) import UntypedPlutusCore.Evaluation.Machine.Cek qualified as Cek +import UntypedPlutusCore.Transform.Simplifier import Control.DeepSeq (force) import Control.Monad.Except (runExcept) import Control.Monad.IO.Class (liftIO) -import Control.Monad.State (runStateT) import Criterion (benchmarkWith, whnf) import Criterion.Main (defaultConfig) import Criterion.Types (Config (..)) @@ -269,17 +269,17 @@ runOptimisations (OptimiseOptions inp ifmt outp ofmt mode cert) = do renamed <- PLC.rename prog let defaultBuiltinSemanticsVariant :: BuiltinSemanticsVariant PLC.DefaultFun defaultBuiltinSemanticsVariant = def - flip runStateT initUPLCSimplifierTrace - $ UPLC.simplifyProgram UPLC.defaultSimplifyOpts defaultBuiltinSemanticsVariant renamed + UPLC.simplifyProgramWithTrace UPLC.defaultSimplifyOpts defaultBuiltinSemanticsVariant renamed writeProgram outp ofmt mode simplified runCertifier cert simplificationTrace where - runCertifier (Just certName) (UPLCSimplifierTrace uplcSimplTrace) = do - let processAgdaAST t = - case UPLC.deBruijnTerm t of - Right res -> res - Left (err :: UPLC.FreeVariableError) -> error $ show err - rawAgdaTrace = AgdaFFI.conv . processAgdaAST . void <$> uplcSimplTrace + runCertifier (Just certName) (SimplifierTrace simplTrace) = do + let processAgdaAST Simplification {beforeAST, stage, afterAST} = + case (UPLC.deBruijnTerm beforeAST, UPLC.deBruijnTerm afterAST) of + (Right before', Right after') -> (stage, (AgdaFFI.conv (void before'), AgdaFFI.conv (void after'))) + (Left (err :: UPLC.FreeVariableError), _) -> error $ show err + (_, Left (err :: UPLC.FreeVariableError)) -> error $ show err + rawAgdaTrace = reverse $ processAgdaAST <$> simplTrace Agda.runCertifier (T.pack certName) rawAgdaTrace runCertifier Nothing _ = pure () diff --git a/plutus-metatheory/plutus-metatheory.cabal b/plutus-metatheory/plutus-metatheory.cabal index 9840d96054c..283aa0abb37 100644 --- a/plutus-metatheory/plutus-metatheory.cabal +++ b/plutus-metatheory/plutus-metatheory.cabal @@ -328,7 +328,10 @@ library MAlonzo.Code.Utils.Reflection MAlonzo.Code.VerifiedCompilation MAlonzo.Code.VerifiedCompilation.Equality + MAlonzo.Code.VerifiedCompilation.Purity MAlonzo.Code.VerifiedCompilation.UCaseOfCase + MAlonzo.Code.VerifiedCompilation.UCSE + MAlonzo.Code.VerifiedCompilation.UFloatDelay MAlonzo.Code.VerifiedCompilation.UForceDelay MAlonzo.Code.VerifiedCompilation.UntypedTranslation MAlonzo.Code.VerifiedCompilation.UntypedViews @@ -580,6 +583,9 @@ library MAlonzo.Code.Utils.Reflection MAlonzo.RTE MAlonzo.RTE.Float + MAlonzo.Code.VerifiedCompilation.Purity + MAlonzo.Code.VerifiedCompilation.UCSE + MAlonzo.Code.VerifiedCompilation.UFloatDelay executable plc-agda import: lang, os-support diff --git a/plutus-metatheory/src/VerifiedCompilation.lagda.md b/plutus-metatheory/src/VerifiedCompilation.lagda.md index 3796dea4635..9d8f3b1af6e 100644 --- a/plutus-metatheory/src/VerifiedCompilation.lagda.md +++ b/plutus-metatheory/src/VerifiedCompilation.lagda.md @@ -27,6 +27,7 @@ containing the generated proof object, a.k.a. the _certificate_. The certificate it into Agda and checking that it is correctly typed. ``` +{-# OPTIONS --allow-unsolved-metas #-} module VerifiedCompilation where ``` @@ -48,6 +49,8 @@ open import Agda.Builtin.Unit using (⊤;tt) import IO.Primitive as IO using (return;_>>=_) import VerifiedCompilation.UCaseOfCase as UCC import VerifiedCompilation.UForceDelay as UFD +import VerifiedCompilation.UFloatDelay as UFlD +import VerifiedCompilation.UCSE as UCSE open import Data.Empty using (⊥) open import Scoped using (ScopeError;deBError) open import VerifiedCompilation.Equality using (DecEq) @@ -55,83 +58,95 @@ import Relation.Binary as Binary using (Decidable) open import VerifiedCompilation.UntypedTranslation using (Translation; Relation; translation?) import Relation.Binary as Binary using (Decidable) import Relation.Unary as Unary using (Decidable) +import Relation.Nary as Nary using (Decidable) + ``` ## Compiler optimisation traces -A `Trace` represents a sequence of optimisation transformations applied to a program. It is a list of pairs of ASTs, -where each pair represents the before and after of a transformation application. -The `IsTransformation` type is a sum type that represents the possible transformations which are implemented in their -respective modules. Adding a new transformation requires extending this type. +A `Trace` represents a sequence of optimisation transformations applied to a program. It is a list of pairs of ASTs +and a tag (`SimplifierTag`), where each pair represents the before and after of a transformation application and the +tag indicates which transformation was applied. +The `Transformation` type is a sum type that represents the possible transformations which are implemented in their +respective modules. The `isTrace?` decision procedure is at the core of the certification process. It produces the proof that the given -list of ASTs are in relation with one another according to the transformations implemented in the project. It is -parametrised by the relation type in order to provide a generic interface for testing, but in practice it is always -instantiated with the `IsTransformation` type. +list of ASTs are in relation with one another according to the transformations implemented in the project. -The `IsTransformation?` decision procedure implements a logical disjunction of the different transformation types. +The `isTransformation?` decision procedure just dispatches to the decision procedure indicated by the tag. **TODO**: The `Trace` type or decision procedure should also enforce that the second element of a pair is the first element of the next pair in the list. This might not be necessary if we decide that we can assume that the function which produces a `Trace` always produces a correct one, although it might be useful to make this explicit in the type. - -**TODO**: The compiler should provide information on which transformation was applied at each step in the trace. -`IsTransformation?` is currently quadratic in the number of transformations, which is not ideal. - ``` -data Trace (R : Relation) : { X : Set } {{_ : DecEq X}} → List ((X ⊢) × (X ⊢)) → Set₁ where - empty : {X : Set}{{_ : DecEq X}} → Trace R {X} [] - cons : {X : Set}{{_ : DecEq X}} {x x' : X ⊢} {xs : List ((X ⊢) × (X ⊢))} → R x x' → Trace R {X} xs → Trace R {X} ((x , x') ∷ xs) - -data IsTransformation : Relation where - isCoC : {X : Set}{{_ : DecEq X}} → (ast ast' : X ⊢) → UCC.CoC ast ast' → IsTransformation ast ast' - isFD : {X : Set}{{_ : DecEq X}} → (ast ast' : X ⊢) → UFD.FD zero zero ast ast' → IsTransformation ast ast' - -isTrace? : {X : Set} {{_ : DecEq X}} {R : Relation} → Binary.Decidable (R {X}) → Unary.Decidable (Trace R {X}) -isTrace? {X} {R} isR? [] = yes empty -isTrace? {X} {R} isR? ((x₁ , x₂) ∷ xs) with isTrace? {X} {R} isR? xs -... | no ¬p = no λ {(cons a as) → ¬p as} -... | yes p with isR? x₁ x₂ -... | no ¬p = no λ {(cons x x₁) → ¬p x} -... | yes p₁ = yes (cons p₁ p) - -isTransformation? : {X : Set} {{_ : DecEq X}} → Binary.Decidable (IsTransformation {X}) -isTransformation? ast₁ ast₂ with UCC.isCoC? ast₁ ast₂ -... | scrt with UFD.isFD? zero zero ast₁ ast₂ -isTransformation? ast₁ ast₂ | no ¬p | no ¬p₁ = no λ {(isCoC .ast₁ .ast₂ x) → ¬p x - ; (isFD .ast₁ .ast₂ x) → ¬p₁ x} -isTransformation? ast₁ ast₂ | no ¬p | yes p = yes (isFD ast₁ ast₂ p) -isTransformation? ast₁ ast₂ | yes p | no ¬p = yes (isCoC ast₁ ast₂ p) --- TODO: this does not make much sense, see TODO above -isTransformation? ast₁ ast₂ | yes p | yes p₁ = yes (isCoC ast₁ ast₂ p) +data SimplifierTag : Set where + floatDelayT : SimplifierTag + forceDelayT : SimplifierTag + caseOfCaseT : SimplifierTag + caseReduceT : SimplifierTag + inlineT : SimplifierTag + cseT : SimplifierTag + +{-# FOREIGN GHC import UntypedPlutusCore.Transform.Simplifier #-} +{-# COMPILE GHC SimplifierTag = data SimplifierStage (FloatDelay | ForceDelay | CaseOfCase | CaseReduce | Inline | CSE) #-} + +data Transformation : SimplifierTag → Relation where + isCoC : {X : Set}{{_ : DecEq X}} → {ast ast' : X ⊢} → UCC.CaseOfCase ast ast' → Transformation caseOfCaseT ast ast' + isFD : {X : Set}{{_ : DecEq X}} → {ast ast' : X ⊢} → UFD.ForceDelay ast ast' → Transformation forceDelayT ast ast' + isFlD : {X : Set}{{_ : DecEq X}} → {ast ast' : X ⊢} → UFlD.FloatDelay ast ast' → Transformation floatDelayT ast ast' + isCSE : {X : Set}{{_ : DecEq X}} → {ast ast' : X ⊢} → UCSE.UntypedCSE ast ast' → Transformation cseT ast ast' + inlineNotImplemented : {X : Set}{{_ : DecEq X}} → {ast ast' : X ⊢} → Transformation inlineT ast ast' + caseReduceNotImplemented : {X : Set}{{_ : DecEq X}} → {ast ast' : X ⊢} → Transformation caseReduceT ast ast' + +data Trace : { X : Set } {{_ : DecEq X}} → List (SimplifierTag × (X ⊢) × (X ⊢)) → Set₁ where + empty : {X : Set}{{_ : DecEq X}} → Trace {X} [] + cons + : {X : Set}{{_ : DecEq X}} + {tag : SimplifierTag} {x x' : X ⊢} + {xs : List (SimplifierTag × (X ⊢) × (X ⊢))} + → Transformation tag x x' + → Trace xs + → Trace ((tag , x , x') ∷ xs) + +isTransformation? : {X : Set} {{_ : DecEq X}} → (tag : SimplifierTag) → (ast ast' : X ⊢) → Nary.Decidable (Transformation tag ast ast') +isTransformation? tag ast ast' with tag +isTransformation? tag ast ast' | floatDelayT with UFlD.isFloatDelay? ast ast' +... | no ¬p = no λ { (isFlD x) → ¬p x } +... | yes p = yes (isFlD p) +isTransformation? tag ast ast' | forceDelayT with UFD.isForceDelay? ast ast' +... | no ¬p = no λ { (isFD x) → ¬p x } +... | yes p = yes (isFD p) +isTransformation? tag ast ast' | caseOfCaseT with UCC.isCaseOfCase? ast ast' +... | no ¬p = no λ { (isCoC x) → ¬p x } +... | yes p = yes (isCoC p) +isTransformation? tag ast ast' | caseReduceT = yes caseReduceNotImplemented +isTransformation? tag ast ast' | inlineT = yes inlineNotImplemented +isTransformation? tag ast ast' | cseT with UCSE.isUntypedCSE? ast ast' +... | no ¬p = no λ { (isCSE x) → ¬p x } +... | yes p = yes (isCSE p) + +isTrace? : {X : Set} {{_ : DecEq X}} → Unary.Decidable (Trace {X}) +isTrace? [] = yes empty +isTrace? ((tag , x₁ , x₂) ∷ xs) with isTrace? xs +... | no ¬pₜ = no λ {(cons _ rest) → ¬pₜ rest} +... | yes pₜ with isTransformation? tag x₁ x₂ +... | no ¬pₑ = no λ {(cons x _) → ¬pₑ x} +... | yes pₑ = yes (cons pₑ pₜ) + ``` ## Serialising the proofs The proof objects are converted to a textual representation which can be written to a file. -**TODO**: Finish the implementation. A textual representation is not usually ideal, but it is a good starting point. +**TODO**: This is currently not supported. The `showTrace` function is a placeholder for the actual serialisation function. ``` -showTranslation : {X : Set} {{_ : DecEq X}} {ast ast' : X ⊢} → Translation IsTransformation ast ast' → String -showTranslation (Translation.istranslation x) = "istranslation TODO" -showTranslation Translation.var = "var" -showTranslation (Translation.ƛ t) = "(ƛ " ++ showTranslation t ++ ")" -showTranslation (Translation.app t t₁) = "(app " ++ showTranslation t ++ " " ++ showTranslation t₁ ++ ")" -showTranslation (Translation.force t) = "(force " ++ showTranslation t ++ ")" -showTranslation (Translation.delay t) = "(delay " ++ showTranslation t ++ ")" -showTranslation Translation.con = "con" -showTranslation (Translation.constr x) = "(constr TODO)" -showTranslation (Translation.case x t) = "(case TODO " ++ showTranslation t ++ ")" -showTranslation Translation.builtin = "builtin" -showTranslation Translation.error = "error" - -showTrace : {X : Set} {{_ : DecEq X}} {xs : List ((X ⊢) × (X ⊢))} → Trace (Translation IsTransformation) xs → String -showTrace empty = "empty" -showTrace (cons x bla) = "(cons " ++ showTranslation x ++ showTrace bla ++ ")" - -serializeTraceProof : {X : Set} {{_ : DecEq X}} {xs : List ((X ⊢) × (X ⊢))} → Dec (Trace (Translation IsTransformation) xs) → String +showTrace : {X : Set} {{_ : DecEq X}} {xs : List (SimplifierTag × (X ⊢) × (X ⊢))} → Trace xs → String +showTrace _ = "TODO" + +serializeTraceProof : {X : Set} {{_ : DecEq X}} {xs : List (SimplifierTag × (X ⊢) × (X ⊢))} → Dec (Trace xs) → String serializeTraceProof (no ¬p) = "no" serializeTraceProof (yes p) = "yes " ++ showTrace p @@ -141,14 +156,14 @@ serializeTraceProof (yes p) = "yes " ++ showTrace p The `runCertifier` function is the top-level function which can be called by the compiler through the foreign function interface. It represents the "impure top layer" which receives the list of ASTs produced by the compiler and writes the certificate -generated by the `certifier` function to disk. Again, the `certifier` is generic for testing purposes but it is instantiated -with the top-level decision procedures by the `runCertifier` function. +generated by the `certifier` function to disk. ``` {-# FOREIGN GHC import qualified Data.Text.IO as TextIO #-} {-# FOREIGN GHC import qualified System.IO as IO #-} {-# FOREIGN GHC import qualified Data.Text as Text #-} +{-# FOREIGN GHC import PlutusCore.Compiler.Types #-} postulate FileHandle : Set {-# COMPILE GHC FileHandle = type IO.Handle #-} @@ -162,32 +177,26 @@ postulate {-# COMPILE GHC stderr = IO.stderr #-} {-# COMPILE GHC hPutStrLn = TextIO.hPutStr #-} -buildPairs : {X : Set} → List (Maybe X ⊢) -> List ((Maybe X ⊢) × (Maybe X ⊢)) -buildPairs [] = [] -buildPairs (x ∷ []) = (x , x) ∷ [] -buildPairs (x₁ ∷ (x₂ ∷ xs)) = (x₁ , x₂) ∷ buildPairs (x₂ ∷ xs) - -traverseEitherList : {A B E : Set} → (A → Either E B) → List A → Either E (List B) +traverseEitherList : {A B E : Set} → (A → Either E B) → List (SimplifierTag × A × A) → Either E (List (SimplifierTag × B × B)) traverseEitherList _ [] = inj₂ [] -traverseEitherList f (x ∷ xs) with f x -... | inj₁ err = inj₁ err -... | inj₂ x' with traverseEitherList f xs -... | inj₁ err = inj₁ err -... | inj₂ resList = inj₂ (x' ∷ resList) +traverseEitherList f ((tag , before , after) ∷ xs) with f before +... | inj₁ e = inj₁ e +... | inj₂ b with f after +... | inj₁ e = inj₁ e +... | inj₂ a with traverseEitherList f xs +... | inj₁ e = inj₁ e +... | inj₂ xs' = inj₂ (((tag , b , a)) ∷ xs') certifier : {X : Set} {{_ : DecEq X}} - → List Untyped - → Unary.Decidable (Trace (Translation IsTransformation) {Maybe X}) + → List (SimplifierTag × Untyped × Untyped) → Either ScopeError String -certifier {X} rawInput isRTrace? with traverseEitherList toWellScoped rawInput +certifier {X} rawInput with traverseEitherList (toWellScoped {X}) rawInput ... | inj₁ err = inj₁ err -... | inj₂ rawTrace = - let inputTrace = buildPairs rawTrace - in inj₂ (serializeTraceProof (isRTrace? inputTrace)) +... | inj₂ inputTrace = inj₂ (serializeTraceProof (isTrace? inputTrace)) -runCertifier : String → List Untyped → IO ⊤ -runCertifier fileName rawInput with certifier rawInput (isTrace? {Maybe ⊥} {Translation IsTransformation} (translation? isTransformation?)) +runCertifier : String → List (SimplifierTag × Untyped × Untyped) → IO ⊤ +runCertifier fileName rawInput with certifier {⊥} rawInput ... | inj₁ err = hPutStrLn stderr "error" -- TODO: pretty print error ... | inj₂ result = writeFile (fileName ++ ".agda") result {-# COMPILE GHC runCertifier as runCertifier #-} diff --git a/plutus-metatheory/src/VerifiedCompilation/UCaseOfCase.lagda.md b/plutus-metatheory/src/VerifiedCompilation/UCaseOfCase.lagda.md index 2bcce6d9f3f..3fe82818914 100644 --- a/plutus-metatheory/src/VerifiedCompilation/UCaseOfCase.lagda.md +++ b/plutus-metatheory/src/VerifiedCompilation/UCaseOfCase.lagda.md @@ -48,8 +48,8 @@ data CoC : Relation where (case ((((force (builtin ifThenElse)) · b) · (constr tn tt)) · (constr fn ft)) alts) (force ((((force (builtin ifThenElse)) · b) · (delay (case (constr tn tt') alts'))) · (delay (case (constr fn ft') alts')))) -UntypedCaseOfCase : {X : Set} {{_ : DecEq X}} → (ast : X ⊢) → (ast' : X ⊢) → Set₁ -UntypedCaseOfCase = Translation CoC +CaseOfCase : {X : Set} {{_ : DecEq X}} → (ast : X ⊢) → (ast' : X ⊢) → Set₁ +CaseOfCase = Translation CoC ``` ## Decision Procedure @@ -97,18 +97,18 @@ the individual pattern decision `isCoC?` and the overall translation decision `i recursive, so the `isUntypedCaseOfCase?` type declaration comes first, with the implementation later. ``` -isUntypedCaseOfCase? : {X : Set} {{_ : DecEq X}} → Binary.Decidable (Translation CoC {X}) +isCaseOfCase? : {X : Set} {{_ : DecEq X}} → Binary.Decidable (Translation CoC {X}) {-# TERMINATING #-} isCoC? : {X : Set} {{_ : DecEq X}} → Binary.Decidable (CoC {X}) isCoC? ast ast' with (isCoCCase? ast) ×-dec (isCoCForce? ast') ... | no ¬cf = no λ { (isCoC b tn fn tt tt' ft ft' alts alts' x x₁ x₂) → ¬cf (isCoCCase b tn fn tt ft alts , isCoCForce b tn fn tt' ft' alts') } -... | yes (isCoCCase b tn fn tt ft alts , isCoCForce b₁ tn₁ fn₁ tt' ft' alts') with (b ≟ b₁) ×-dec (tn ≟ tn₁) ×-dec (fn ≟ fn₁) ×-dec (decPointwise isUntypedCaseOfCase? tt tt') ×-dec (decPointwise isUntypedCaseOfCase? ft ft') ×-dec (decPointwise isUntypedCaseOfCase? alts alts') +... | yes (isCoCCase b tn fn tt ft alts , isCoCForce b₁ tn₁ fn₁ tt' ft' alts') with (b ≟ b₁) ×-dec (tn ≟ tn₁) ×-dec (fn ≟ fn₁) ×-dec (decPointwise isCaseOfCase? tt tt') ×-dec (decPointwise isCaseOfCase? ft ft') ×-dec (decPointwise isCaseOfCase? alts alts') ... | yes (refl , refl , refl , ttpw , ftpw , altpw) = yes (isCoC b tn fn tt tt' ft ft' alts alts' altpw ttpw ftpw) ... | no ¬p = no λ { (isCoC .b .tn .fn .tt .tt' .ft .ft' .alts .alts' x x₁ x₂) → ¬p (refl , refl , refl , x₁ , x₂ , x) } -isUntypedCaseOfCase? {X} = translation? {X} isCoC? +isCaseOfCase? {X} = translation? {X} isCoC? ``` ## Semantic Equivalence diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin.hs b/plutus-tx-plugin/src/PlutusTx/Plugin.hs index f9e4936f63a..2dc5f8d925c 100644 --- a/plutus-tx-plugin/src/PlutusTx/Plugin.hs +++ b/plutus-tx-plugin/src/PlutusTx/Plugin.hs @@ -85,8 +85,6 @@ import Data.Set qualified as Set import Data.Text qualified as Text import Data.Type.Bool qualified as PlutusTx.Bool import GHC.Num.Integer qualified -import PlutusCore.Compiler.Types (UPLCSimplifierTrace (UPLCSimplifierTrace), - initUPLCSimplifierTrace) import PlutusCore.Default (DefaultFun, DefaultUni) import PlutusIR.Analysis.Builtins import PlutusIR.Compiler.Provenance (noProvenance, original) @@ -569,7 +567,7 @@ runCompiler moduleName opts expr = do when (opts ^. posDoTypecheck) . void $ liftExcept $ PLC.inferTypeOfProgram plcTcConfig (plcP $> annMayInline) - uplcP <- PLC.evalCompile plcOpts $ PLC.compileProgram plcP + uplcP <- flip runReaderT plcOpts $ PLC.compileProgram plcP dbP <- liftExcept $ traverseOf UPLC.progTerm UPLC.deBruijnTerm uplcP when (opts ^. posDumpUPlc) . liftIO $ dumpFlat diff --git a/plutus-tx/src/PlutusTx/Lift.hs b/plutus-tx/src/PlutusTx/Lift.hs index 017eef94232..ac0a1f91063 100644 --- a/plutus-tx/src/PlutusTx/Lift.hs +++ b/plutus-tx/src/PlutusTx/Lift.hs @@ -49,14 +49,12 @@ import Control.Lens hiding (lifted) import Control.Monad (void) import Control.Monad.Except (ExceptT, MonadError, liftEither, runExceptT) import Control.Monad.Reader (runReaderT) -import Control.Monad.State (evalStateT) import Data.Bifunctor import Data.Default.Class import Data.Hashable import Data.Proxy -- We do not use qualified import because the whole module contains off-chain code -import PlutusCore.Compiler.Types (initUPLCSimplifierTrace) import Prelude as Haskell -- | Get a Plutus Core term corresponding to the given value. @@ -89,7 +87,7 @@ safeLift v x = do & PLC.coSimplifyOpts . UPLC.soMaxSimplifierIterations .~ 0 & PLC.coSimplifyOpts . UPLC.soMaxCseIterations .~ 0 plc <- flip runReaderT ccConfig $ compileProgram (Program () v pir) - uplc <- flip evalStateT initUPLCSimplifierTrace $ flip runReaderT ucOpts $ PLC.compileProgram plc + uplc <- flip runReaderT ucOpts $ PLC.compileProgram plc UPLC.Program _ _ db <- traverseOf UPLC.progTerm UPLC.deBruijnTerm uplc pure (void pir, void db) @@ -268,8 +266,7 @@ typeCode typeCode p prog = do _ <- typeCheckAgainst p prog compiled <- - flip evalStateT initUPLCSimplifierTrace - $ flip runReaderT PLC.defaultCompilationOpts + flip runReaderT PLC.defaultCompilationOpts $ PLC.compileProgram prog db <- traverseOf UPLC.progTerm UPLC.deBruijnTerm compiled pure $ DeserializedCode (mempty <$ db) Nothing mempty