Skip to content

Commit

Permalink
support type synonyms in scala (#4101)
Browse files Browse the repository at this point in the history
* Support DAML-LF type synonyms in scala.

CHANGELOG_BEGIN
CHANGELOG_END

* dont create synonymns in GenerateSimpleDalf

* extend DAML-LF parser to support type synonyms

* test: expand type synonyms correctly
  • Loading branch information
nickchapman-da authored Jan 23, 2020
1 parent 20804a4 commit 62d592e
Show file tree
Hide file tree
Showing 25 changed files with 235 additions and 14 deletions.
1 change: 1 addition & 0 deletions compiler/daml-lf-ast/src/DA/Daml/LF/Ast/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ instance Pretty Type where
TVar v -> pretty v
TCon c -> pretty c
TSynApp s args ->
maybeParens (prec > precTApp) $
pretty s <-> hsep [pPrintPrec lvl (succ precTApp) arg | arg <- args ]
TApp (TApp (TBuiltin BTArrow) tx) ty ->
maybeParens (prec > precTFun)
Expand Down
8 changes: 1 addition & 7 deletions compiler/damlc/tests/src/DA/Test/GenerateSimpleDalf.hs
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,11 @@ main = do
, tplChoices = NM.fromList ([chc,chc2] <> [arc | withArchiveChoice])
, tplKey = Nothing
}
let syn = DefTypeSyn
{ synLocation = Nothing
, synName = TypeSynName ["MySyn1"]
, synParams = []
, synType = TUnit
}
let mod = Module
{ moduleName = ModuleName ["Module"]
, moduleSource = Nothing
, moduleFeatureFlags = FeatureFlags{forbidPartyLiterals = True}
, moduleSynonyms = NM.fromList [syn]
, moduleSynonyms = NM.fromList []
, moduleDataTypes = NM.fromList ([tplRec, chcArg, chcArg2] <> [emptyRec | withArchiveChoice])
, moduleValues = NM.empty
, moduleTemplates = NM.fromList [tpl]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,27 @@ private[archive] class DecodeV1(minor: LV.Minor) extends Decode.OfPackage[PLF.Pa
val defs = mutable.ArrayBuffer[(DottedName, Definition)]()
val templates = mutable.ArrayBuffer[(DottedName, Template)]()

if (versionIsOlderThan(LV.Features.typeSynonyms)) {
assertEmpty(lfModule.getSynonymsList, "Module.synonyms")
} else if (!onlySerializableDataDefs) {
// collect type synonyms
lfModule.getSynonymsList.asScala
.foreach { defn =>
val defName = handleDottedName(
defn.getNameCase,
PLF.DefTypeSyn.NameCase.NAME_DNAME,
defn.getNameDname,
PLF.DefTypeSyn.NameCase.NAME_INTERNED_DNAME,
defn.getNameInternedDname,
"DefTypeSyn.name.name"
)
currentDefinitionRef =
Some(DefinitionRef(packageId, QualifiedName(moduleName, defName)))
val d = decodeDefTypeSyn(defn)
defs += (defName -> d)
}
}

// collect data types
lfModule.getDataTypesList.asScala
.filter(!onlySerializableDataDefs || _.getSerializable)
Expand Down Expand Up @@ -278,6 +299,14 @@ private[archive] class DecodeV1(minor: LV.Minor) extends Decode.OfPackage[PLF.Pa
)
}

private[this] def decodeDefTypeSyn(lfTypeSyn: PLF.DefTypeSyn): DTypeSyn = {
val params = lfTypeSyn.getParamsList.asScala
DTypeSyn(
ImmArray(params).map(decodeTypeVarWithKind),
decodeType(lfTypeSyn.getType)
)
}

def handleInternedName[Case](
actualCase: Case,
stringCase: Case,
Expand Down Expand Up @@ -548,8 +577,11 @@ private[archive] class DecodeV1(minor: LV.Minor) extends Decode.OfPackage[PLF.Pa
(TTyCon(decodeTypeConName(tcon.getTycon)) /: [Type] tcon.getArgsList.asScala)(
(typ, arg) => TApp(typ, decodeType(arg)))
case PLF.Type.SumCase.SYN =>
// FIXME https://github.com/digital-asset/daml/issues/3616
throw ParseError("PLF.Type.SumCase.SYN") //TODO
val tsyn = lfType.getSyn
TSynApp(
decodeTypeSynName(tsyn.getTysyn),
ImmArray(tsyn.getArgsList.asScala.map(decodeType))
)
case PLF.Type.SumCase.PRIM =>
val prim = lfType.getPrim
val baseType =
Expand Down Expand Up @@ -629,6 +661,19 @@ private[archive] class DecodeV1(minor: LV.Minor) extends Decode.OfPackage[PLF.Pa
Identifier(packageId, QualifiedName(module, name))
}

private[this] def decodeTypeSynName(lfTySynName: PLF.TypeSynName): TypeSynName = {
val (packageId, module) = decodeModuleRef(lfTySynName.getModule)
val name = handleDottedName(
lfTySynName.getNameCase,
PLF.TypeSynName.NameCase.NAME_DNAME,
lfTySynName.getNameDname,
PLF.TypeSynName.NameCase.NAME_INTERNED_DNAME,
lfTySynName.getNameInternedDname,
"TypeSynName.name.name"
)
Identifier(packageId, QualifiedName(module, name))
}

private[this] def decodeTypeConApp(lfTyConApp: PLF.Type.Con): TypeConApp =
TypeConApp(
decodeTypeConName(lfTyConApp.getTycon),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ object Ref {
type TypeConName = Identifier
val TypeConName = Identifier

/** Reference to a type synonym. */
type TypeSynName = Identifier
val TypeSynName = Identifier

/**
* Used to reference to leger objects like contractIds, ledgerIds,
* transactionId, ... We use the same type for those ids, because we
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ private[digitalasset] class EncodeV1(val minor: LV.Minor) {
}
case value @ DValue(_, _, _, _) =>
builder.addValues(name -> value)

case DTypeSyn(params @ _, typ @ _) =>
throw new RuntimeException("TODO #3616, EncodeV1, DTypeSyn")

}
builder
}
Expand Down Expand Up @@ -205,6 +209,9 @@ private[digitalasset] class EncodeV1(val minor: LV.Minor) {
case _ => typ0 -> ImmArray.empty
}
val builder = PLF.Type.newBuilder()
// Be warned: Both the use of the unapply pattern TForalls and the pattern
// case TBuiltin(BTArrow) if versionIsOlderThan(LV.Features.arrowType) =>
// cause scala's exhaustivty checking to be disabled in the following match.
typ match {
case TVar(varName) =>
val b = PLF.Type.Var.newBuilder()
Expand Down Expand Up @@ -242,6 +249,8 @@ private[digitalasset] class EncodeV1(val minor: LV.Minor) {
case TStruct(fields) =>
expect(args.isEmpty)
builder.setStruct(PLF.Type.Struct.newBuilder().accumulateLeft(fields)(_ addFields _))
case TSynApp(_, _) =>
throw new RuntimeException("TODO #3616,encodeTypeBuilder")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ object PackageLookup {
for {
defn <- lookupDefinition(pkg, identifier)
dataTyp <- defn match {
case dataType: DDataType => Right(dataType)
case _: DValue =>
Left(Error(s"Got value definition instead of datatype when looking up $identifier"))
case dataType: DDataType => Right(dataType)
case _: DTypeSyn =>
Left(
Error(s"Got type synonym definition instead of datatype when looking up $identifier"))
}
} yield dataTyp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ private[engine] class ValueTranslator(compiledPackages: CompiledPackages) {
case struct: TStruct =>
fail(
s"Unexpected struct when replacing parameters in command translation -- all types should be serializable, and structs are not: $struct")
case syn: TSynApp =>
fail(
s"Unexpected type synonym application when replacing parameters in command translation -- all types should be serializable, and synonyms are not: $syn")
}

go(typ0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,17 @@ object Ast {
def prettyType(t0: Type, prec: Int = precTForall): String = t0 match {
case TVar(n) => n
case TNat(n) => n.toString
case TSynApp(syn, args) =>
maybeParens(
prec > precTApp,
syn.qualifiedName.name.toString + " " +
args
.map { t =>
prettyType(t, precTApp + 1)
}
.toSeq
.mkString(" ")
)
case TTyCon(con) => con.qualifiedName.name.toString
case TBuiltin(BTArrow) => "(->)"
case TBuiltin(bt) => bt.toString.stripPrefix("BT")
Expand Down Expand Up @@ -251,6 +262,9 @@ object Ast {
val Decimal: TNat = values(10)
}

/** Fully applied type synonym. */
final case class TSynApp(tysyn: TypeSynName, args: ImmArray[Type]) extends Type

/** Reference to a type constructor. */
final case class TTyCon(tycon: TypeConName) extends Type

Expand Down Expand Up @@ -513,6 +527,7 @@ object Ast {

sealed abstract class Definition extends Product with Serializable

final case class DTypeSyn(params: ImmArray[(TypeVarName, Kind)], typ: Type) extends Definition
final case class DDataType(
serializable: Boolean,
params: ImmArray[(TypeVarName, Kind)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ object LanguageVersion {
val anyType = v1_7
val typeRep = v1_7
val genMap = v1_dev
val typeSynonyms = v1_dev

/** Unstable, experimental features. This should stay in 1.dev forever.
* Features implemented with this flag should be moved to a separate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ private[digitalasset] class AstRewriter(
if (typeRule.isDefinedAt(x)) typeRule(x)
else
x match {
case TSynApp(_, _) => throw new RuntimeException("TODO #3616,AstRewriter,TSynApp")
case TVar(_) | TNat(_) | TBuiltin(_) => x
case TTyCon(typeCon) =>
TTyCon(apply(typeCon))
Expand Down Expand Up @@ -205,6 +206,9 @@ private[digitalasset] class AstRewriter(
x
case DValue(typ, noPartyLiterals, body, isTest) =>
DValue(apply(typ), noPartyLiterals, apply(body), isTest)

case DTypeSyn(params @ _, typ @ _) =>
throw new RuntimeException("TODO #3616,AstRewriter,DTypeSyn")
}

def apply(x: Template): Template =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ private[parser] class ModParser[P](parameters: ParserParameters[P]) {
}

private lazy val definition: Parser[Def] =
recDefinition | variantDefinition | enumDefinition | valDefinition | templateDefinition
synDefinition | recDefinition | variantDefinition | enumDefinition | valDefinition | templateDefinition

private def tags(allowed: Set[String]): Parser[Set[String]] =
rep(`@` ~> id) ^^ { tags =>
Expand All @@ -56,6 +56,13 @@ private[parser] class ModParser[P](parameters: ParserParameters[P]) {
private lazy val binder: Parser[(Name, Type)] =
id ~ `:` ~ typ ^^ { case id ~ _ ~ typ => id -> typ }

private lazy val synDefinition: Parser[DataDef] =
Id("synonym") ~>! dottedName ~ rep(typeBinder) ~
(`=` ~> typ) ^^ {
case id ~ params ~ typ =>
DataDef(id, DTypeSyn(ImmArray(params), typ))
}

private lazy val recDefinition: Parser[DataDef] =
Id("record") ~>! tags(dataDefTags) ~ dottedName ~ rep(typeBinder) ~
(`=` ~ `{` ~> repsep(binder, `,`) <~ `}`) ^^ {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,15 @@ private[parser] class TypeParser[P](parameters: ParserParameters[P]) {
private lazy val tStruct: Parser[Type] =
`<` ~>! rep1sep(fieldType, `,`) <~ `>` ^^ (fs => TStruct(ImmArray(fs)))

private lazy val tTypeSynApp: Parser[Type] =
`|` ~> fullIdentifier ~ rep(typ0) <~ `|` ^^ { case id ~ tys => TSynApp(id, ImmArray(tys)) }

lazy val typ0: Parser[Type] =
`(` ~> typ <~ `)` |
tNat |
tForall |
tStruct |
tTypeSynApp |
(id ^? builtinTypes) ^^ TBuiltin |
fullIdentifier ^^ TTyCon.apply |
id ^^ TVar.apply
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ object Repl {

def prettyDefinitionType(defn: Definition, pkgId: PackageId, modId: ModuleName): String =
defn match {
case DTypeSyn(_, _) => "<type synonym>" // FIXME: pp this
case DDataType(_, _, _) => "<data type>" // FIXME(JM): pp this
case DValue(typ, _, _, _) => prettyType(typ, pkgId, modId)
}
Expand All @@ -295,6 +296,12 @@ object Repl {
if (needParens) s"($s)" else s

def prettyType(t0: Type, prec: Int = precTForall): String = t0 match {
case TSynApp(syn, args) =>
maybeParens(
prec > precTApp,
prettyQualified(pkgId, modId, syn)
+ " " + args.map(t => prettyType(t, precTApp + 1)).toSeq.mkString(" ")
)
case TVar(n) => n
case TNat(n) => n.toString
case TTyCon(con) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ private[validation] object Collision {
// ignore values
// List(NValDef(module, defName, vDef))
List.empty

case _: Ast.DTypeSyn =>
List.empty // TODO #3616: check type synonyms

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ private[validation] object Recursion {
val modRefsInType: Set[ModuleName] = {

def modRefsInType(acc: Set[ModuleName], typ0: Type): Set[ModuleName] = typ0 match {
case TSynApp(typeSynName, _) if typeSynName.packageId == pkgId =>
((acc + typeSynName.qualifiedName.module) /: TypeTraversable(typ0))(modRefsInType)
case TTyCon(typeConName) if typeConName.packageId == pkgId =>
acc + typeConName.qualifiedName.module
case otherwise =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ private[validation] object Serializability {
if (!vars(name)) unserializable(URFreeVar(name))
case TNat(_) =>
unserializable(URNat)
case TSynApp(syn, _) => unserializable(URTypeSyn(syn))
case TTyCon(tycon) =>
lookupDefinition(ctx, tycon) match {
case DDataType(true, _, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ private[validation] object TypeSubst {

private def go(fv0: Set[TypeVarName], subst0: Map[TypeVarName, Type], typ0: Type): Type =
typ0 match {
case TSynApp(syn, args) => TSynApp(syn, args.map(go(fv0, subst0, _)))
case TVar(name) => subst0.getOrElse(name, typ0)
case TTyCon(_) | TBuiltin(_) | TNat(_) => typ0
case TApp(t1, t2) => TApp(go(fv0, subst0, t1), go(fv0, subst0, t2))
Expand Down
Loading

0 comments on commit 62d592e

Please sign in to comment.