Skip to content

Commit

Permalink
Scenario service: optimize for common usecases of context update (dig…
Browse files Browse the repository at this point in the history
…ital-asset#5666)

CHANGELOG_BEGIN
CHANGELOG_END
  • Loading branch information
remyhaemmerle-da authored Apr 22, 2020
1 parent 276bc71 commit d90b036
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import java.net.{InetAddress, InetSocketAddress}
import java.util.logging.{Level, Logger}

import com.daml.lf.archive.Decode.ParseError
import com.daml.lf.data.Ref
import com.daml.lf.data.Ref.ModuleName
import com.daml.lf.scenario.api.v1.{Map => _, _}
import io.grpc.stub.StreamObserver
import io.grpc.{Status, StatusRuntimeException}
Expand Down Expand Up @@ -117,7 +119,7 @@ class ScenarioService extends ScenarioServiceGrpc.ScenarioServiceImplBase {
req: NewContextRequest,
respObs: StreamObserver[NewContextResponse],
): Unit = {
val ctx = Context.newContext()
val ctx = Context.newContext
contexts += (ctx.contextId -> ctx)
val response = NewContextResponse.newBuilder.setContextId(ctx.contextId).build
respObs.onNext(response)
Expand Down Expand Up @@ -183,12 +185,13 @@ class ScenarioService extends ScenarioServiceGrpc.ScenarioServiceImplBase {

case Some(ctx) =>
try {

val unloadModules =
if (req.hasUpdateModules)
req.getUpdateModules.getUnloadModulesList.asScala
.map(ModuleName.assertFromString)
.toSet
else
Seq.empty
Set.empty[ModuleName]

val loadModules =
if (req.hasUpdateModules)
Expand All @@ -199,8 +202,10 @@ class ScenarioService extends ScenarioServiceGrpc.ScenarioServiceImplBase {
val unloadPackages =
if (req.hasUpdatePackages)
req.getUpdatePackages.getUnloadPackagesList.asScala
.map(Ref.PackageId.assertFromString)
.toSet
else
Seq.empty
Set.empty[Ref.PackageId]

val loadPackages =
if (req.hasUpdatePackages)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package com.daml.lf
package scenario

import java.util.concurrent.atomic.AtomicLong

import com.daml.lf.archive.Decode
import com.daml.lf.archive.Decode.ParseError
import com.daml.lf.data.Ref.{Identifier, ModuleName, PackageId, QualifiedName}
Expand All @@ -20,22 +22,19 @@ import com.daml.lf.speedy.SExpr.{LfDefRef, SDefinitionRef}
import com.daml.lf.validation.Validation
import com.google.protobuf.ByteString

import scala.collection.immutable.HashMap

/**
* Scenario interpretation context: maintains a set of modules and external packages, with which
* scenarios can be interpreted.
*/
object Context {
type ContextId = Long
case class ContextException(err: String) extends RuntimeException(err, null, true, false)
case class ContextException(err: String) extends RuntimeException(err)

var nextContextId: ContextId = 0
private val contextCounter = new AtomicLong()

def newContext(): Context = {
this.synchronized {
nextContextId += 1
new Context(nextContextId)
}
}
def newContext: Context = new Context(contextCounter.incrementAndGet())

private def assert[X](either: Either[String, X]): X =
either.fold(e => throw new ParseError(e), identity)
Expand All @@ -51,20 +50,23 @@ class Context(val contextId: Context.ContextId) {
* self-references. We only care that the identifier is disjunct from the package ids
* in extPackages.
*/
val homePackageId: PackageId =
PackageId.assertFromString("-homePackageId-")
val homePackageId: PackageId = PackageId.assertFromString("-homePackageId-")

private var modules: Map[ModuleName, Ast.Module] = Map.empty
private var extPackages: Map[PackageId, Ast.Package] = Map.empty
private var defns: Map[SDefinitionRef, SExpr] = Map.empty
private var extPackages: Map[PackageId, Ast.Package] = HashMap.empty
private var extDefns: Map[SDefinitionRef, SExpr] = HashMap.empty
private var modules: Map[ModuleName, Ast.Module] = HashMap.empty
private var modDefns: Map[ModuleName, Map[SDefinitionRef, SExpr]] = HashMap.empty
private var defns: Map[SDefinitionRef, SExpr] = HashMap.empty

def loadedModules(): Iterable[ModuleName] = modules.keys
def loadedPackages(): Iterable[PackageId] = extPackages.keys

def cloneContext(): Context = this.synchronized {
def cloneContext(): Context = synchronized {
val newCtx = Context.newContext
newCtx.modules = modules
newCtx.extPackages = extPackages
newCtx.extDefns = extDefns
newCtx.modules = modules
newCtx.modDefns = modDefns
newCtx.defns = defns
newCtx
}
Expand All @@ -83,60 +85,58 @@ class Context(val contextId: Context.ContextId) {
dop.decodeScenarioModule(homePackageId, lfScenarioModule)
}

private def validate(pkgIds: Traversable[PackageId]): Unit =
pkgIds.foreach(
Validation.checkPackage(allPackages, _).left.foreach(e => throw ParseError(e.pretty)),
)

@throws[ParseError]
def update(
unloadModules: Seq[String],
unloadModules: Set[ModuleName],
loadModules: Seq[ProtoScenarioModule],
unloadPackages: Seq[String],
unloadPackages: Set[PackageId],
loadPackages: Seq[ByteString],
omitValidation: Boolean,
): Unit = this.synchronized {
): Unit = synchronized {

val newModules = loadModules.map(module =>
decodeModule(LanguageVersion.Major.V1, module.getMinor, module.getDamlLf1))
modules --= unloadModules
newModules.foreach(mod => modules += mod.name -> mod)

// First we unload modules and packages
unloadModules.foreach { moduleId =>
val lfModuleId = assert(ModuleName.fromString(moduleId))
modules -= lfModuleId
defns = defns.filterKeys(ref => ref.packageId != homePackageId || ref.modName != lfModuleId)
}
unloadPackages.foreach { pkgId =>
val lfPkgId = assert(PackageId.fromString(pkgId))
extPackages -= lfPkgId
defns = defns.filterKeys(ref => ref.packageId != lfPkgId)
}
// Now we can load the new packages.
val newPackages =
loadPackages.map { archive =>
Decode.decodeArchiveFromInputStream(archive.newInput)
}.toMap
extPackages ++= newPackages
defns ++= Compiler.compilePackages(extPackages, !omitValidation).right.get

// And now the new modules can be loaded.
val lfModules = loadModules.map(module =>
decodeModule(LanguageVersion.Major.V1, module.getMinor, module.getDamlLf1))
val modulesToCompile =
if (unloadPackages.nonEmpty || newPackages.nonEmpty) {
// if any change we recompile everything
extPackages --= unloadPackages
extPackages ++= newPackages
extDefns = assert(Compiler.compilePackages(extPackages))
modDefns = HashMap.empty
modules.values
} else {
modDefns --= unloadModules
newModules
}

val pkgs = allPackages
val compiler = Compiler(pkgs)

modules ++= lfModules.map(m => m.name -> m)

// At this point 'allPackages' is consistent and we can
// compile the new modules.
val compiler = Compiler(allPackages)
defns = lfModules.foldLeft(defns)(
(newDefns, m) =>
newDefns.filterKeys(ref => ref.packageId != homePackageId || ref.modName != m.name)
++ m.definitions.flatMap {
case (defName, defn) =>
compiler
.unsafeCompileDefn(Identifier(homePackageId, QualifiedName(m.name, defName)), defn)
})
modulesToCompile.foreach { mod =>
if (!omitValidation)
assert(Validation.checkModule(pkgs, homePackageId, mod.name).left.map(_.pretty))
modDefns += mod.name -> mod.definitions.flatMap {
case (defName, defn) =>
compiler
.unsafeCompileDefn(Identifier(homePackageId, QualifiedName(mod.name, defName)), defn)
}
}

defns = extDefns
modDefns.values.foreach(defns ++= _)
}

def allPackages: Map[PackageId, Ast.Package] =
def allPackages: Map[PackageId, Ast.Package] = synchronized {
extPackages + (homePackageId -> Ast.Package(modules, extPackages.keySet, None))
}

// We use a fix Hash and fix time to seed the contract id, so we get reproducible run.
private val submissionTime =
Expand All @@ -145,19 +145,18 @@ class Context(val contextId: Context.ContextId) {
speedy.InitialSeeding.TransactionSeed(crypto.Hash.hashPrivateKey(s"scenario-service"))

private def buildMachine(identifier: Identifier): Option[Speedy.Machine] = {
val defns = this.defns
for {
defn <- defns.get(LfDefRef(identifier))
} yield
// note that the use of `Map#mapValues` here is intentional: we lazily project the
// definition out rather than rebuilding the map.
Speedy.Machine
.build(
checkSubmitterInMaintainers = false,
sexpr = defn,
compiledPackages = PureCompiledPackages(allPackages, defns),
submissionTime,
initialSeeding,
)
Speedy.Machine
.build(
checkSubmitterInMaintainers = false,
sexpr = defn,
compiledPackages = PureCompiledPackages(allPackages, defns),
submissionTime,
initialSeeding,
)
}

def interpretScenario(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ final class PureCompiledPackages private (
packages: Map[PackageId, Package],
defns: Map[SDefinitionRef, SExpr],
) extends CompiledPackages {
override def packageIds = packages.keySet
override def packageIds: Set[PackageId] = packages.keySet
override def getPackage(pkgId: PackageId): Option[Package] = packages.get(pkgId)
override def getDefinition(dref: SDefinitionRef): Option[SExpr] = defns.get(dref)
}
Expand All @@ -35,7 +35,7 @@ object PureCompiledPackages {
/** Important: use this method only if you _know_ you have all the definitions! Otherwise
* use the other apply, which will compile them for you.
*/
def apply(
private[lf] def apply(
packages: Map[PackageId, Package],
defns: Map[SDefinitionRef, SExpr],
): PureCompiledPackages =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,43 @@ import com.daml.lf.language.Ast.{Module, Package}

object Validation {

private def runSafely[X](x: => X): Either[ValidationError, X] =
try {
Right(x)
} catch {
case e: ValidationError => Left(e)
}

def checkPackage(
pkgs: PartialFunction[PackageId, Package],
pkgId: PackageId
): Either[ValidationError, Unit] =
try {
runSafely {
val world = new World(pkgs)
Right(checkPackage(world, pkgId, world.lookupPackage(NoContext, pkgId).modules))
} catch {
case e: ValidationError =>
Left(e)
unsafeCheckPackage(world, pkgId, world.lookupPackage(NoContext, pkgId).modules)
}

private def checkPackage(
private def unsafeCheckPackage(
world: World,
pkgId: PackageId,
modules: Map[ModuleName, Module]
): Unit = {
Collision.checkPackage(pkgId, modules)
Recursion.checkPackage(pkgId, modules)
modules.values.foreach(checkModule(world, pkgId, _))
modules.values.foreach(unsafeCheckModule(world, pkgId, _))
}

private def checkModule(world: World, pkgId: PackageId, mod: Module): Unit = {
def checkModule(
pkgs: PartialFunction[PackageId, Package],
pkgId: PackageId,
modName: ModuleName,
): Either[ValidationError, Unit] =
runSafely {
val world = new World(pkgs)
unsafeCheckModule(world, pkgId, world.lookupModule(NoContext, pkgId, modName))
}

private def unsafeCheckModule(world: World, pkgId: PackageId, mod: Module): Unit = {
Typing.checkModule(world, pkgId, mod)
Serializability.checkModule(world, pkgId, mod)
PartyLiterals.checkModule(world, pkgId, mod)
Expand Down

0 comments on commit d90b036

Please sign in to comment.