Skip to content

Commit

Permalink
Implement AnyTemplate DAML-LF type on the Scala side (digital-asset#2905
Browse files Browse the repository at this point in the history
)

* Implement AnyTemplate DAML-LF type on the Scala side

This is the first part of
digital-asset#2876. The PR adds
AnyTemplate to Speedy and to the internal expression representation
and adapts all the relevant infrastructure (e.g., the typechecker) and
the tests.

It does not yet change the protobuf representation, the Haskell side
or the spec. I’ll update the spec together with changing the protobuf.

* Add comments to SBToAnyTemplate and SBFromAnyTemplate

* Address some comments from Remy

* Only allocate TBuiltin(BTAnyTemplate) once
  • Loading branch information
cocreature authored and mergify[bot] committed Sep 16, 2019
1 parent dc32abb commit 1703a51
Show file tree
Hide file tree
Showing 20 changed files with 201 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ object InterfaceReader {
unserializableDataType(
ctx,
s"Unserializable primitive type: $a must be applied to one and only one TNat")
case Ast.BTUpdate | Ast.BTScenario | Ast.BTArrow =>
case Ast.BTUpdate | Ast.BTScenario | Ast.BTArrow | Ast.BTAnyTemplate =>
unserializableDataType(ctx, s"Unserializable primitive type: $a")
}
(arity, primType) = ab
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class TypeSpec extends WordSpec with Matchers {
TypePrim(PrimTypeBool, ImmArraySeq(assertOneArg(args)))
case Pkg.BTOptional => TypePrim(PrimTypeOptional, ImmArraySeq(assertOneArg(args)))
case Pkg.BTArrow => sys.error("cannot use arrow in interface type")
case Pkg.BTAnyTemplate => sys.error("cannot use anytemplate in interface type")
}
case Pkg.TTyCon(tycon) => TypeCon(TypeConName(tycon), args.toImmArray.toSeq)
case Pkg.TNat(_) => sys.error("cannot use nat type in interface type")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,12 @@ final case class Compiler(packages: PackageId PartialFunction Package) {

case ELocation(loc, e) =>
SELocation(loc, translate(e))

case EToAnyTemplate(e) =>
SEApp(SEBuiltin(SBToAnyTemplate), Array(translate(e)))

case EFromAnyTemplate(tmplId, e) =>
SEApp(SEBuiltin(SBFromAnyTemplate(tmplId)), Array(translate(e)))
}

@tailrec
Expand Down Expand Up @@ -1030,6 +1036,7 @@ final case class Compiler(packages: PackageId PartialFunction Package) {
case SRecord(_, _, args) => args.forEach(goV)
case SVariant(_, _, value) => goV(value)
case SEnum(_, _) => ()
case SAnyTemplate(SRecord(_, _, args)) => args.forEach(goV)
case _: SPAP | SToken | STuple(_, _) =>
throw CompileError("validate: unexpected SEValue")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,33 @@ object SBuiltin {
throw DamlEUserError(args.get(0).asInstanceOf[SText].value)
}

/** $to_any_template
* :: arg (template argument)
* -> AnyTemplate
*/
final case object SBToAnyTemplate extends SBuiltin(1) {
def execute(args: util.ArrayList[SValue], machine: Machine): Unit = {
machine.ctrl = CtrlValue(args.get(0) match {
case r @ SRecord(_, _, _) => SAnyTemplate(r)
case v => crash(s"ToAnyTemplate on non-record: $v")
})
}
}

/** $from_any_template
* :: AnyTemplate
* -> Optional t (where t = TTyCon(expectedTemplateId))
*/
final case class SBFromAnyTemplate(expectedTemplateId: TypeConName) extends SBuiltin(1) {
def execute(args: util.ArrayList[SValue], machine: Machine): Unit = {
machine.ctrl = CtrlValue(args.get(0) match {
case SAnyTemplate(r @ SRecord(actualTemplateId, _, _)) =>
SOptional(if (actualTemplateId == expectedTemplateId) Some(r) else None)
case v => crash(s"FromAnyTemplate applied to non-AnyTemplate: $v")
})
}
}

// Helpers
//

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ sealed trait SValue {
V.ValueMap(SortedLookupList(mVal).mapValue(_.toValue))
case SContractId(coid) =>
V.ValueContractId(coid)
case SAnyTemplate(_) =>
throw SErrorCrash("SValue.toValue: unexpected SAnyTemplate")
case STNat(_) =>
throw SErrorCrash("SValue.toValue: unexpected STNat")
case _: SPAP =>
Expand Down Expand Up @@ -109,7 +111,8 @@ sealed trait SValue {
case SContractId(coid) =>
SContractId(f(coid))
case SEnum(_, _) | _: SPrimLit | SToken | STNat(_) => this

case SAnyTemplate(SRecord(tycon, fields, values)) =>
SAnyTemplate(SRecord(tycon, fields, mapArrayList(values, v => v.mapContractId(f))))
}

def equalTo(v2: SValue): Boolean = {
Expand Down Expand Up @@ -181,6 +184,8 @@ object SValue {

final case class SMap(value: HashMap[String, SValue]) extends SValue

final case class SAnyTemplate(t: SRecord) extends SValue

// Corresponds to a DAML-LF Nat type reified as a Speedy value.
// It is currently used to track at runtime the scale of the
// Numeric builtin's arguments/output. Should never be translated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,8 @@ object Speedy {
}
}
case SContractId(_) | SDate(_) | SNumeric(_) | SInt64(_) | SParty(_) | SText(_) |
STimestamp(_) | STuple(_, _) | SMap(_) | SRecord(_, _, _) | STNat(_) | _: SPAP |
SToken =>
STimestamp(_) | STuple(_, _) | SMap(_) | SRecord(_, _, _) | SAnyTemplate(_) | STNat(_) |
_: SPAP | SToken =>
crash("Match on non-matchable value")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package com.digitalasset.daml.lf.speedy

import java.util

import com.digitalasset.daml.lf.data.Ref._
import com.digitalasset.daml.lf.PureCompiledPackages
import com.digitalasset.daml.lf.data.{FrontStack, Ref}
import com.digitalasset.daml.lf.language.Ast
Expand All @@ -19,6 +20,7 @@ import org.scalatest.{Matchers, WordSpec}
class SpeedyTest extends WordSpec with Matchers {

import SpeedyTest._
import defaultParserParameters.{defaultPackageId => pkgId}

"pattern matching" should {

Expand Down Expand Up @@ -91,6 +93,74 @@ class SpeedyTest extends WordSpec with Matchers {

}

val anyTemplatePkg =
p"""
module Test {
record @serializable T1 = { party: Party } ;
template (record : T1) = {
precondition True,
signatories Cons @Party [(Test:T1 {party} record)] (Nil @Party),
observers Nil @Party,
agreement "Agreement",
choices {
}
} ;
record @serializable T2 = { party: Party } ;
template (record : T2) = {
precondition True,
signatories Cons @Party [(Test:T2 {party} record)] (Nil @Party),
observers Nil @Party,
agreement "Agreement",
choices {
}
} ;
}
"""

val anyTemplatePkgs = typeAndCompile(anyTemplatePkg)

"to_any_template" should {

"throw an exception on Int64" in {
eval(e"""to_any_template 1""", anyTemplatePkgs) shouldBe 'left
}
"succeed on template type" in {
eval(e"""to_any_template (Test:T1 {party = 'Alice'})""", anyTemplatePkgs) shouldBe
Right(
SAnyTemplate(SRecord(
Identifier(pkgId, QualifiedName.assertFromString("Test:T1")),
Name.Array(Name.assertFromString("party")),
ArrayList(SParty(Party.assertFromString("Alice")))
)))
}

}

"from_any_template" should {

"throw an exception on Int64" in {
eval(e"""from_any_template @Test:T1 1""", anyTemplatePkgs) shouldBe 'left
}

"return Some(tpl) if template id matches" in {
eval(
e"""from_any_template @Test:T1 (to_any_template (Test:T1 {party = 'Alice'}))""",
anyTemplatePkgs) shouldBe
Right(
SOptional(Some(SRecord(
Identifier(pkgId, QualifiedName.assertFromString("Test:T1")),
Name.Array(Name.assertFromString("party")),
ArrayList(SParty(Party.assertFromString("Alice")))
))))
}

"return None if template id does not match" in {
eval(
e"""from_any_template @Test:T2 (to_any_template (Test:T1 {party = 'Alice'}))""",
anyTemplatePkgs) shouldBe Right(SOptional(None))
}
}

}

object SpeedyTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ object Ast {

final case class ESome(typ: Type, body: Expr) extends Expr

/** AnyTemplate constructor **/
final case class EToAnyTemplate(body: Expr) extends Expr

/** Extract the underlying template if it matches the tmplId **/
final case class EFromAnyTemplate(tmplId: TypeConName, body: Expr) extends Expr

//
// Kinds
//
Expand Down Expand Up @@ -269,6 +275,7 @@ object Ast {
case object BTDate extends BuiltinType
case object BTContractId extends BuiltinType
case object BTArrow extends BuiltinType
case object BTAnyTemplate extends BuiltinType

//
// Primitive literals
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ object Util {
val TTimestamp = TBuiltin(BTTimestamp)
val TDate = TBuiltin(BTDate)
val TParty = TBuiltin(BTParty)
val TAnyTemplate = TBuiltin(BTAnyTemplate)

val TNumeric = new ParametricType1(BTNumeric)
val TList = new ParametricType1(BTList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ private[digitalasset] class AstRewriter(
ENone(apply(typ))
case ESome(typ, body) =>
ESome(apply(typ), apply(body))
case EToAnyTemplate(body) =>
EToAnyTemplate(apply(body))
case EFromAnyTemplate(tmplId, body) =>
EFromAnyTemplate(tmplId, apply(body))
}

def apply(x: TypeConApp): TypeConApp = x match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ private[parser] class ExprParser[P](parserParameters: ParserParameters[P]) {
eAbs |
eTyAbs |
eLet |
eToAnyTemplate |
eFromAnyTemplate |
contractId |
fullIdentifier ^^ EVal |
(id ^? builtinFunctions) ^^ EBuiltin |
Expand Down Expand Up @@ -174,6 +176,14 @@ private[parser] class ExprParser[P](parserParameters: ParserParameters[P]) {
case b ~ body => ELet(b, body)
}

private lazy val eToAnyTemplate: Parser[Expr] =
`to_any_template` ~>! expr0 ^^ EToAnyTemplate

private lazy val eFromAnyTemplate: Parser[Expr] =
`from_any_template` ~>! `@` ~> fullIdentifier ~ expr0 ^^ {
case tyCon ~ e => EFromAnyTemplate(tyCon, e)
}

private lazy val pattern: Parser[CasePat] =
primCon ^^ CPPrimCon |
`nil` ^^^ CPNil |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ private[parser] object Lexer extends RegexParsers {
"fetch_by_key" -> `fetch_by_key`,
"lookup_by_key" -> `lookup_by_key`,
"by" -> `by`,
"to" -> `to`
"to" -> `to`,
"to_any_template" -> `to_any_template`,
"from_any_template" -> `from_any_template`
)

val token: Parser[Token] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ private[parser] object Token {
case object `lookup_by_key` extends Token
case object `by` extends Token
case object `to` extends Token
case object `to_any_template` extends Token
case object `from_any_template` extends Token

final case class Id(s: String) extends Token
final case class ContractId(s: String) extends Token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ private[parser] class TypeParser[P](parameters: ParserParameters[P]) {
"ContractId" -> BTContractId,
"Arrow" -> BTArrow,
"Map" -> BTMap,
"AnyTemplate" -> BTAnyTemplate,
)

private[parser] def fullIdentifier: Parser[Ref.Identifier] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ private[validation] object Serializability {
unserializable(URContractId)
case BTArrow =>
unserializable(URFunction)
case BTAnyTemplate =>
unserializable(URAnyTemplate)
}
case TForall(_, _) =>
unserializable(URForall)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ private[validation] case class TypeSubst(map: Map[TypeVarName, Type], private va
ENone(apply(typ))
case ESome(typ, body) =>
ESome(apply(typ), apply(body))
case EToAnyTemplate(body) =>
EToAnyTemplate(apply(body))
case EFromAnyTemplate(tmplId, body) =>
EFromAnyTemplate(tmplId, apply(body))

}

def apply(choice: TemplateChoice): TemplateChoice = choice match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ private[validation] object Typing {
}

private def kindOfBuiltin(bType: BuiltinType): Kind = bType match {
case BTInt64 | BTText | BTTimestamp | BTParty | BTBool | BTDate | BTUnit => KStar
case BTInt64 | BTText | BTTimestamp | BTParty | BTBool | BTDate | BTUnit | BTAnyTemplate =>
KStar
case BTNumeric => KArrow(KNat, KStar)
case BTList | BTUpdate | BTScenario | BTContractId | BTOptional | BTMap => KArrow(KStar, KStar)
case BTArrow => KArrow(KStar, KArrow(KStar, KStar))
Expand Down Expand Up @@ -714,6 +715,21 @@ private[validation] object Typing {
checkExpr(exp, TScenario(typ))
}

private def typeOfToAnyTemplate(body: Expr): Type =
typeOf(body) match {
case TTyCon(tmplId) =>
lookupTemplate(ctx, tmplId)
TAnyTemplate
case typ =>
throw EExpectedTemplateType(ctx, typ)
}

private def typeOfFromAnyTemplate(tpl: TypeConName, body: Expr): Type = {
lookupTemplate(ctx, tpl)
checkExpr(body, TAnyTemplate)
TOptional(TTyCon(tpl))
}

def typeOf(expr0: Expr): Type = expr0 match {
case EVar(name) =>
lookupExpVar(name)
Expand Down Expand Up @@ -777,6 +793,10 @@ private[validation] object Typing {
checkType(typ, KStar)
val _ = checkExpr(body, typ)
TOptional(typ)
case EToAnyTemplate(body) =>
typeOfToAnyTemplate(body)
case EFromAnyTemplate(tmplId, body) =>
typeOfFromAnyTemplate(tmplId, body)
}

def checkExpr(expr: Expr, typ: Type): Type = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ final case class URHigherKinded(varName: TypeVarName, kind: Kind) extends Unseri
case object URUninhabitatedType extends UnserializabilityReason {
def pretty: String = "variant type without constructors"
}
case object URAnyTemplate extends UnserializabilityReason {
def pretty: String = "AnyTemplate"
}

abstract class ValidationError extends java.lang.RuntimeException with Product with Serializable {
def context: Context
Expand Down Expand Up @@ -285,6 +288,11 @@ final case class EExpectedTemplatableType(context: Context, conName: TypeConName
protected def prettyInternal: String =
s"expected monomorphic record type in template definition, but found: ${conName.qualifiedName}"

}
final case class EExpectedTemplateType(context: Context, typ: Type) extends ValidationError {
protected def prettyInternal: String =
s"expected template type as argument to toAnyTemplate, but found: ${typ.pretty}"

}
final case class EImportCycle(context: Context, modName: List[ModuleName]) extends ValidationError {
protected def prettyInternal: String = s"cycle in module dependency ${modName.mkString(" -> ")}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ private[validation] object ExprTraversable {
case ENone(typ @ _) =>
case ESome(typ @ _, body) =>
f(body)
case EToAnyTemplate(body) =>
f(body)
case EFromAnyTemplate(tmplId @ _, body) =>
f(body)
}
()
}
Expand Down
Loading

0 comments on commit 1703a51

Please sign in to comment.