Skip to content

Commit

Permalink
LF: Add "requires" field to Scala Ast. (#12028)
Browse files Browse the repository at this point in the history
Part of #11978. Adds typechecking for this field on the interface side,
and enforces that any template that implements A must implement B if A requires B.

CHANGELOG_BEGIN
CHANGELOG_END
  • Loading branch information
remyhaemmerle-da authored Dec 7, 2021
1 parent a159b50 commit 9d7eb07
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ private[archive] class DecodeV1(minor: LV.Minor) {
lfInterface: PLF.DefInterface,
): DefInterface =
DefInterface.build(
requires = lfInterface.getRequiresList.asScala.view.map(decodeTypeConName),
param = getInternedName(lfInterface.getParamInternedStr, "DefInterface.param"),
fixedChoices = lfInterface.getFixedChoicesList.asScala.view.map(decodeChoice(id, _)),
methods = lfInterface.getMethodsList.asScala.view.map(decodeInterfaceMethod),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ object Ast {
val TemplateKeySignature = new GenTemplateKeyCompanion[Unit]

final case class GenDefInterface[E](
requires: Set[TypeConName],
param: ExprVarName, // Binder for template argument.
fixedChoices: Map[ChoiceName, GenTemplateChoice[E]],
methods: Map[MethodName, InterfaceMethod],
Expand All @@ -657,11 +658,16 @@ object Ast {
final class GenDefInterfaceCompanion[E] {
@throws[PackageError]
def build(
requires: Iterable[TypeConName],
param: ExprVarName, // Binder for template argument.
fixedChoices: Iterable[GenTemplateChoice[E]],
methods: Iterable[InterfaceMethod],
precond: E,
): GenDefInterface[E] = {
val requiresSet = toSetWithoutDuplicate(
requires,
(name: TypeConName) => PackageError(s"repeated required interface $name"),
)
val fixedChoiceMap = toMapWithoutDuplicate(
fixedChoices.view.map(c => c.name -> c),
(name: ChoiceName) => PackageError(s"collision on interface choice name $name"),
Expand All @@ -670,26 +676,28 @@ object Ast {
methods.view.map(c => c.name -> c),
(name: MethodName) => PackageError(s"collision on interface method name $name"),
)
GenDefInterface(param, fixedChoiceMap, methodMap, precond)
GenDefInterface(requiresSet, param, fixedChoiceMap, methodMap, precond)
}

def apply(
requires: Set[TypeConName],
param: ExprVarName,
fixedChoices: Map[ChoiceName, GenTemplateChoice[E]],
methods: Map[MethodName, InterfaceMethod],
precond: E,
): GenDefInterface[E] =
GenDefInterface(param, fixedChoices, methods, precond)
GenDefInterface(requires, param, fixedChoices, methods, precond)

def unapply(arg: GenDefInterface[E]): Some[
(
Set[TypeConName],
ExprVarName,
Map[ChoiceName, GenTemplateChoice[E]],
Map[MethodName, InterfaceMethod],
E,
)
] =
Some((arg.param, arg.fixedChoices, arg.methods, arg.precond))
Some((arg.requires, arg.param, arg.fixedChoices, arg.methods, arg.precond))
}

type DefInterface = GenDefInterface[Expr]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ object Util {

private def toSignature(interface: DefInterface): DefInterfaceSignature =
interface match {
case DefInterface(param, fixedChoices, methods, _) =>
case DefInterface(requires, param, fixedChoices, methods, _) =>
DefInterfaceSignature(
requires,
param,
fixedChoices.transform((_, choice) => toSignature(choice)),
methods,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,9 @@ private[daml] class AstRewriter(

def apply(x: DefInterface): DefInterface =
x match {
case DefInterface(param, fixedChoices, methods, precond) =>
case DefInterface(requires, param, fixedChoices, methods, precond) =>
DefInterface(
requires,
param,
fixedChoices.transform((_, v) => apply(v)),
methods.transform((_, v) => apply(v)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ private[parser] class ModParser[P](parameters: ParserParameters[P]) {
choices =>
IfaceDef(
tycon,
DefInterface.build(x, choices, methods, precond),
DefInterface.build(Set.empty, x, choices, methods, precond),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ class ParsersSpec extends AnyWordSpec with ScalaCheckPropertyChecks with Matcher

val interface =
DefInterface(
requires = Set.empty,
param = n"this",
precond = e"False",
methods = Map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,13 +449,14 @@ private[validation] object Typing {
checkExpr(key.maintainers, TFun(key.typ, TParties))
()
}
implementations.values.foreach(env.checkIfaceImplementation(tplName, _))
env.checkIfaceImplementations(tplName, implementations)
}

def checkDefIface(ifaceName: TypeConName, iface: DefInterface): Unit =
iface match {
case DefInterface(param, fixedChoices, methods, precond) =>
case DefInterface(requires, param, fixedChoices, methods, precond) =>
val env = introExprVar(param, TTyCon(ifaceName))
requires.foreach(required => handleLookup(ctx, interface.lookupInterface(required)))
env.checkExpr(precond, TBool)
methods.values.foreach(checkIfaceMethod)
fixedChoices.values.foreach(env.checkChoice(ifaceName, _))
Expand All @@ -469,31 +470,41 @@ private[validation] object Typing {
AlphaEquiv.alphaEquiv(t1, t2) ||
AlphaEquiv.alphaEquiv(expandTypeSynonyms(t1), expandTypeSynonyms(t2))

def checkIfaceImplementation(tplTcon: TypeConName, impl: TemplateImplements): Unit = {
val DefInterfaceSignature(_, fixedChoices, methods, _) =
handleLookup(ctx, interface.lookupInterface(impl.interfaceId))

val fixedChoiceSet = fixedChoices.keySet
if (impl.inheritedChoices != fixedChoiceSet) {
throw EBadInheritedChoices(
ctx,
impl.interfaceId,
tplTcon,
fixedChoiceSet,
impl.inheritedChoices,
)
}
def checkIfaceImplementations(
tplTcon: TypeConName,
impls: Map[TypeConName, TemplateImplements],
): Unit = {

methods.values.foreach { (method: InterfaceMethod) =>
if (!impl.methods.contains(method.name))
throw EMissingInterfaceMethod(ctx, tplTcon, impl.interfaceId, method.name)
}
impl.methods.values.foreach { (tplMethod: TemplateImplementsMethod) =>
methods.get(tplMethod.name) match {
case None =>
throw EUnknownInterfaceMethod(ctx, tplTcon, impl.interfaceId, tplMethod.name)
case Some(method) =>
checkExpr(tplMethod.value, TFun(TTyCon(tplTcon), method.returnType))
impls.foreach { case (iface, impl) =>
val DefInterfaceSignature(requires, _, fixedChoices, methods, _) =
handleLookup(ctx, interface.lookupInterface(impl.interfaceId))

requires
.filterNot(impls.contains)
.foreach(required => throw EMissingRequiredInterface(ctx, tplTcon, iface, required))

val fixedChoiceSet = fixedChoices.keySet
if (impl.inheritedChoices != fixedChoiceSet) {
throw EBadInheritedChoices(
ctx,
impl.interfaceId,
tplTcon,
fixedChoiceSet,
impl.inheritedChoices,
)
}

methods.values.foreach { (method: InterfaceMethod) =>
if (!impl.methods.contains(method.name))
throw EMissingInterfaceMethod(ctx, tplTcon, impl.interfaceId, method.name)
}
impl.methods.values.foreach { (tplMethod: TemplateImplementsMethod) =>
methods.get(tplMethod.name) match {
case None =>
throw EUnknownInterfaceMethod(ctx, tplTcon, impl.interfaceId, tplMethod.name)
case Some(method) =>
checkExpr(tplMethod.value, TFun(TTyCon(tplTcon), method.returnType))
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,13 @@ final case class ETemplateDoesNotImplementInterface(
override protected def prettyInternal: String =
s"Template $template does not implement interface $iface"
}

final case class EMissingRequiredInterface(
context: Context,
template: TypeConName,
requiringIface: TypeConName,
missingRequiredIface: TypeConName,
) extends ValidationError {
override protected def prettyInternal: String =
s"Template $template is missing an implementation of interface $missingRequiredIface required by interface $requiringIface"
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ private[validation] object ExprIterable {
private[iterable] def iterator(x: DefInterface): Iterator[Expr] =
x match {
case DefInterface(
requires @ _,
param @ _,
fixedChoices,
methods @ _,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,9 @@ private[validation] object TypeIterable {

private[validation] def iterator(interface: DefInterface): Iterator[Type] =
interface match {
case DefInterface(_, fixedChoice, methods, precond) =>
iterator(precond) ++
case DefInterface(requires, _, fixedChoice, methods, precond) =>
requires.iterator.map(TTyCon) ++
iterator(precond) ++
fixedChoice.values.iterator.flatMap(iterator) ++
methods.values.iterator.flatMap(iterator)
}
Expand Down

0 comments on commit 9d7eb07

Please sign in to comment.