Skip to content

Commit

Permalink
parametrize iface types only by type... (digital-asset#678)
Browse files Browse the repository at this point in the history
...rather than by "field with type". in preparation to enums (digital-asset#105)
  • Loading branch information
bitonic authored Apr 25, 2019
1 parent 756b2c9 commit 14f6728
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package com.digitalasset.daml.lf.iface

import scalaz.std.map._
import scalaz.std.tuple._
import scalaz.syntax.applicative.^
import scalaz.syntax.traverse._
import scalaz.{Applicative, Bifunctor, Bitraverse, Functor, Traverse}
Expand All @@ -23,8 +24,11 @@ case class DefDataType[+RF, +VF](typeVars: ImmArraySeq[String], dataType: DataTy

object DefDataType {

/** Alias for application to [[FieldWithType]]. */
type FWT = DefDataType[FieldWithType, FieldWithType]
/** Alias for application to [[Type]]. Note that FWT stands for "Field with
* type", because before we parametrized over both the field and the type,
* while now we only parametrize over the type.
*/
type FWT = DefDataType[Type, Type]

implicit val `DDT bitraverse`: Bitraverse[DefDataType] =
new Bitraverse[DefDataType] {
Expand All @@ -36,20 +40,23 @@ object DefDataType {
}
}

sealed trait DataType[+RF, +VF] extends Product with Serializable {
def bimap[C, D](f: RF => C, g: VF => D): DataType[C, D] =
sealed trait DataType[+RT, +VT] extends Product with Serializable {
def bimap[C, D](f: RT => C, g: VT => D): DataType[C, D] =
Bifunctor[DataType].bimap(this)(f, g)

def fold[Z](record: Record[RF] => Z, variant: Variant[VF] => Z): Z = this match {
def fold[Z](record: Record[RT] => Z, variant: Variant[VT] => Z): Z = this match {
case r @ Record(_) => record(r)
case v @ Variant(_) => variant(v)
}
}

object DataType {

/** Alias for application to [[FieldWithType]]. */
type FWT = DataType[FieldWithType, FieldWithType]
/** Alias for application to [[Type]]. Note that FWT stands for "Field with
* type", because before we parametrized over both the field and the type,
* while now we only parametrize over the type.
*/
type FWT = DataType[Type, Type]

// While this instance appears to overlap the subclasses' traversals,
// naturality holds with respect to those instances and this one, so there is
Expand All @@ -67,44 +74,44 @@ object DataType {
}

sealed trait GetFields[+A] {
def fields: ImmArraySeq[A]
final def getFields: j.List[_ <: A] = fields.asJava
def fields: ImmArraySeq[(String, A)]
final def getFields: j.List[_ <: (String, A)] = fields.asJava
}
}

// Record TypeDecl`s have an object generated for them in their own file
final case class Record[+RF](fields: ImmArraySeq[RF])
extends DataType[RF, Nothing]
with DataType.GetFields[RF] {
final case class Record[+RT](fields: ImmArraySeq[(String, RT)])
extends DataType[RT, Nothing]
with DataType.GetFields[RT] {

/** Widen to DataType, in Java. */
def asDataType[PRF >: RF, VF]: DataType[PRF, VF] = this
def asDataType[PRT >: RT, VT]: DataType[PRT, VT] = this
}

object Record extends FWTLike[Record] {
implicit val `R traverse`: Traverse[Record] =
new Traverse[Record] {
override def traverseImpl[G[_]: Applicative, A, B](fa: Record[A])(
f: A => G[B]): G[Record[B]] =
Applicative[G].map(fa.fields traverse f)(bs => fa.copy(fields = bs))
Applicative[G].map(fa.fields traverse (_ traverse f))(bs => fa.copy(fields = bs))
}
}

// Variant TypeDecl`s have an object generated for them in their own file
final case class Variant[+VF](fields: ImmArraySeq[VF])
extends DataType[Nothing, VF]
with DataType.GetFields[VF] {
final case class Variant[+VT](fields: ImmArraySeq[(String, VT)])
extends DataType[Nothing, VT]
with DataType.GetFields[VT] {

/** Widen to DataType, in Java. */
def asDataType[RF, PVF >: VF]: DataType[RF, PVF] = this
def asDataType[RT, PVT >: VT]: DataType[RT, PVT] = this
}

object Variant extends FWTLike[Variant] {
implicit val `V traverse`: Traverse[Variant] =
new Traverse[Variant] {
override def traverseImpl[G[_]: Applicative, A, B](fa: Variant[A])(
f: A => G[B]): G[Variant[B]] =
Applicative[G].map(fa.fields traverse f)(bs => fa.copy(fields = bs))
Applicative[G].map(fa.fields traverse (_ traverse f))(bs => fa.copy(fields = bs))
}
}

Expand Down Expand Up @@ -149,6 +156,9 @@ object TemplateChoice {
/** Add aliases to companions. */
sealed abstract class FWTLike[F[+ _]] {

/** Alias for application to [[FieldWithType]]. */
type FWT = F[FieldWithType]
/** Alias for application to [[Type]]. Note that FWT stands for "Field with
* type", because before we parametrized over both the field and the type,
* while now we only parametrize over the type.
*/
type FWT = F[Type]
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ final case class TypeCon(name: TypeConName, typArgs: ImmArraySeq[Type])
defn.dataType
} else {
val paramsMap = Map(defn.typeVars.zip(typArgs): _*)
val instantiateFWT: FieldWithType => FieldWithType = {
case (field, typ) => (field, typ.mapTypeVars(v => paramsMap.getOrElse(v.name, v)))
}
val instantiateFWT: Type => Type = _.mapTypeVars(v => paramsMap.getOrElse(v.name, v))
defn.dataType.bimap(instantiateFWT, instantiateFWT)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class TypeSpec extends WordSpec with Matchers {
Record(ImmArraySeq("fld1" -> t"List a", "fld2" -> t"Mod:V b"))
)
)
inst shouldBe Record[FieldWithType](
ImmArraySeq("fld1" -> t"List Int64", "fld2" -> t"Mod:V Text"))
inst shouldBe Record[Type](ImmArraySeq("fld1" -> t"List Int64", "fld2" -> t"Mod:V Text"))
}

"mapTypeVars should replace all type variables in List(List a)" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import scalaz._

object Types {
final case class TypeDecls(
templates: Map[Identifier, iface.Record[iface.FieldWithType]] = Map.empty,
records: Map[Identifier, iface.Record[iface.FieldWithType]] = Map.empty,
variants: Map[Identifier, iface.Variant[iface.FieldWithType]] = Map.empty
templates: Map[Identifier, iface.Record.FWT] = Map.empty,
records: Map[Identifier, iface.Record.FWT] = Map.empty,
variants: Map[Identifier, iface.Variant.FWT] = Map.empty
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ trait DataFormat[S <: DataFormatState] {
}

object DataFormat {
type TemplateInfo = (Identifier, iface.Record[iface.FieldWithType])
type TemplateInfo = (Identifier, iface.Record.FWT)
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class MultiTableDataFormat(

private def createIOForTable(
tableName: String,
params: iface.Record[(String, iface.Type)],
params: iface.Record.FWT,
templateId: Identifier
): ConnectionIO[Unit] = {
val drop = dropTableIfExists(tableName).update.run
Expand Down Expand Up @@ -262,7 +262,7 @@ class MultiTableDataFormat(
case TypeCon(_, _) => "JSONB"
}

private def mapColumnTypes(params: iface.Record[(String, iface.Type)]): List[String] = {
private def mapColumnTypes(params: iface.Record.FWT): List[String] = {
params.fields.toList.map(_._2.fat).map {
case TypePrim(iface.PrimTypeOptional, typeArg :: _) => mapSQLType(typeArg) + " NULL"
case other => mapSQLType(other) + " NOT NULL"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,20 @@ object CodeGen {
val (unassociatedRecords, splattedVariants) = splatVariants(recordsAndVariants)

// 2. put templates/types into single Namespace.fromHierarchy
val treeified = Namespace.fromHierarchy {
def widenDDT[R, V](iddt: Iterable[ScopedDataType.DT[R, V]]) = iddt
val ntdRights =
(widenDDT(unassociatedRecords.map {
case ((q, tp), rec) => ScopedDataType(q, ImmArraySeq(tp: _*), rec)
}) ++ splattedVariants)
.map(sdt => (sdt.name, \/-(sdt)))
val tmplLefts = supportedTemplateIds.transform((_, v) => -\/(v))
(ntdRights ++ tmplLefts) map {
case (ddtIdent @ Identifier(_, qualName), body) =>
(qualName.module.segments.toList ++ qualName.name.segments.toList, (ddtIdent, body))
val treeified: Namespace[String, Option[lf.HierarchicalOutput.TemplateOrDatatype]] =
Namespace.fromHierarchy {
def widenDDT[R, V](iddt: Iterable[ScopedDataType.DT[R, V]]) = iddt
val ntdRights =
(widenDDT(unassociatedRecords.map {
case ((q, tp), rec) => ScopedDataType(q, ImmArraySeq(tp: _*), rec)
}) ++ splattedVariants)
.map(sdt => (sdt.name, \/-(sdt)))
val tmplLefts = supportedTemplateIds.transform((_, v) => -\/(v))
(ntdRights ++ tmplLefts) map {
case (ddtIdent @ Identifier(_, qualName), body) =>
(qualName.module.segments.toList ++ qualName.name.segments.toList, (ddtIdent, body))
}
}
}

// fold up the tree to discover the hierarchy's roots, each of which produces a file
val (treeErrors, topFiles) = lf.HierarchicalOutput.discoverFiles(treeified, util)
Expand All @@ -214,15 +215,15 @@ object CodeGen {
filePlans ++ specialPlans
}

type LHSIndexedRecords[+RF] = Map[(Identifier, List[String]), Record[RF]]
type LHSIndexedRecords[+RT] = Map[(Identifier, List[String]), Record[RT]]

private[this] def splitNTDs[RF, VF](recordsAndVariants: Iterable[ScopedDataType.DT[RF, VF]])
: (LHSIndexedRecords[RF], List[ScopedDataType[Variant[VF]]]) =
private[this] def splitNTDs[RT, VT](recordsAndVariants: Iterable[ScopedDataType.DT[RT, VT]])
: (LHSIndexedRecords[RT], List[ScopedDataType[Variant[VT]]]) =
partitionEithers(recordsAndVariants map {
case sdt @ ScopedDataType(qualName, typeVars, ddt) =>
ddt match {
case r: Record[RF] => Left(((qualName, typeVars.toList), r))
case v: Variant[VF] => Right(sdt copy (dataType = v))
case r: Record[RT] => Left(((qualName, typeVars.toList), r))
case v: Variant[VT] => Right(sdt copy (dataType = v))
}
})(breakOut, breakOut)

Expand All @@ -234,9 +235,9 @@ object CodeGen {
* figured by examining the _2: left means splatted, right means
* unchanged.
*/
private[this] def splatVariants[RF, VN <: String, VT <: iface.Type](
recordsAndVariants: Iterable[ScopedDataType.DT[RF, (VN, VT)]])
: (LHSIndexedRecords[RF], List[ScopedDataType[Variant[(VN, List[RF] \/ VT)]]]) = {
private[this] def splatVariants[RT <: iface.Type, VT <: iface.Type](
recordsAndVariants: Iterable[ScopedDataType.DT[RT, VT]])
: (LHSIndexedRecords[RT], List[ScopedDataType[Variant[List[(String, RT)] \/ VT]]]) = {

val (recordMap, variants) = splitNTDs(recordsAndVariants)

Expand All @@ -245,22 +246,21 @@ object CodeGen {
// or Scala 2.13
val (deletedRecords, newVariants) =
variants.traverseU {
case vsdt @ ScopedDataType(Identifier(packageId, qualName), vTypeVars, _) =>
case ScopedDataType(ident @ Identifier(packageId, qualName), vTypeVars, Variant(fields)) =>
val typeVarDelegate = Util simplyDelegates vTypeVars
vsdt.traverseU {
_.traverseU {
case (vn, vt) =>
val syntheticRecord = Identifier(
packageId,
qualName copy (name =
DottedName.assertFromSegments(qualName.name.segments.slowSnoc(vn).toSeq)))
val key = (syntheticRecord, vTypeVars.toList)
typeVarDelegate(vt)
.filter((_: Identifier) == syntheticRecord)
.flatMap(_ => recordMap get key)
.cata(nr => (Set(key), (vn, -\/(nr.fields.toList))), (noDeletion, (vn, \/-(vt))))
}
val (deleted, sdt) = fields.traverseU {
case (vn, vt) =>
val syntheticRecord = Identifier(
packageId,
qualName copy (name =
DottedName.assertFromSegments(qualName.name.segments.slowSnoc(vn).toSeq)))
val key = (syntheticRecord, vTypeVars.toList)
typeVarDelegate(vt)
.filter((_: Identifier) == syntheticRecord)
.flatMap(_ => recordMap get key)
.cata(nr => (Set(key), (vn, -\/(nr.fields.toList))), (noDeletion, (vn, \/-(vt))))
}
(deleted, ScopedDataType(ident, vTypeVars, Variant(sdt)))
}

(recordMap -- deletedRecords, newVariants)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import com.digitalasset.codegen.{Util, lf}
import lf.{DefTemplateWithRecord, EnvironmentInterface}

import scalaz.std.list._
import scalaz.std.tuple._
import scalaz.syntax.bifoldable._
import scalaz.syntax.foldable._
import scalaz.Bifoldable
Expand Down Expand Up @@ -39,7 +38,7 @@ private final case class LFDependencyGraph(private val util: lf.LFUtil)
}
val templateNodes = decls.collect {
case (qualName, InterfaceType.Template(typ, tpl)) =>
val recDeps = typ.foldMap((fwt: FieldWithType) => Util.genTypeTopLevelDeclNames(fwt._2))
val recDeps = typ.foldMap(Util.genTypeTopLevelDeclNames)
val choiceDeps = tpl.foldMap(Util.genTypeTopLevelDeclNames)
(
qualName,
Expand All @@ -51,12 +50,9 @@ private final case class LFDependencyGraph(private val util: lf.LFUtil)
Graph.cyclicDependencies(internalNodes = typeDeclNodes, roots = templateNodes)
}

private[this] def genTypeDependencies[B[_, _]: Bifoldable, I](gts: B[I, Type]): List[Identifier] =
Bifoldable[B].rightFoldable.foldMap(gts)(Util.genTypeTopLevelDeclNames)

private[this] def symmGenTypeDependencies[B[_, _]: Bifoldable, I, J](
gts: B[(I, Type), (J, Type)]): List[Identifier] =
gts.bifoldMap(genTypeDependencies(_))(genTypeDependencies(_))
private[this] def symmGenTypeDependencies[B[_, _]: Bifoldable](
gts: B[Type, Type]): List[Identifier] =
gts.bifoldMap(Util.genTypeTopLevelDeclNames)(Util.genTypeTopLevelDeclNames)
}

object DependencyGraph {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ object DamlRecordOrVariantTypeGen {

private val logger: Logger = Logger(getClass)

type VariantField = (String, List[FieldWithType] \/ iface.Type)
type RecordOrVariant = ScopedDataType.DT[FieldWithType, VariantField]
type VariantField = List[FieldWithType] \/ iface.Type
type RecordOrVariant = ScopedDataType.DT[iface.Type, VariantField]

def generate(
util: LFUtil,
Expand Down Expand Up @@ -142,7 +142,7 @@ object DamlRecordOrVariantTypeGen {
* - A type class instance (i.e. implicit object) for serializing/deserializing
* to/from the ArgumentValue type (see typed-ledger-api project)
*/
def toScalaDamlVariantType(fields: List[VariantField]): (Tree, Tree) = {
def toScalaDamlVariantType(fields: List[(String, VariantField)]): (Tree, Tree) = {
lazy val damlVariant =
if (fields.isEmpty) damlVariantZeroFields
else damlVariantOneOrMoreFields
Expand Down Expand Up @@ -194,7 +194,7 @@ object DamlRecordOrVariantTypeGen {
}"""
}

def variantWriteCase(variant: VariantField): CaseDef = variant match {
def variantWriteCase(variant: (String, VariantField)): CaseDef = variant match {
case (label, \/-(genTyp)) =>
cq"${TermName(label.capitalize)}(a) => ${typeObjectFromVariant(label, genTyp, Util.toIdent("a"))}"
case (label, -\/(record)) =>
Expand Down Expand Up @@ -250,7 +250,7 @@ object DamlRecordOrVariantTypeGen {
})
}

def variantGetBody(valueExpr: Tree, field: VariantField): Tree =
def variantGetBody(valueExpr: Tree, field: (String, VariantField)): Tree =
field match {
case (label, \/-(genType)) => fieldGetBody(valueExpr, label, genType)
case (label, -\/(record)) => recordGetBody(valueExpr, label, record)
Expand Down Expand Up @@ -378,7 +378,7 @@ object DamlRecordOrVariantTypeGen {
damlRecord
}

def lfEncodableForVariant(fields: Seq[VariantField]): Tree = {
def lfEncodableForVariant(fields: Seq[(String, VariantField)]): Tree = {
val lfEncodableName = TermName(s"${damlScalaName.name} LfEncodable")

val variantsWithNestedRecords: Seq[(String, List[(String, iface.Type)])] =
Expand Down Expand Up @@ -584,7 +584,7 @@ object DamlRecordOrVariantTypeGen {
private def generateVariantCaseDefList(util: LFUtil)(
appliedValueType: Tree,
typeArgs: List[TypeName],
fields: Seq[VariantField],
fields: Seq[(String, VariantField)],
recordFieldsByName: Map[String, TermName]): Seq[Tree] =
fields.map {
case (n, -\/(r)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import scalaz.syntax.id._
import scalaz.syntax.foldable._
import scalaz.syntax.std.option._

case class DefTemplateWithRecord[+Type](`type`: Record[(String, Type)], template: DefTemplate[Type])
case class DefTemplateWithRecord[+Type](`type`: Record[Type], template: DefTemplate[Type])
object DefTemplateWithRecord {
type FWT = DefTemplateWithRecord[IType]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ object UsedTypeParams {

private def foldMapGenTypes[Z: Monoid](typeDecl: RecordOrVariant)(f: Type => Z): Z = {
val notAGT = (s: String) => mzero[Z]
typeDecl.foldMap(
_.bifoldMap(_.bifoldMap(notAGT)(f))(
_.bifoldMap(notAGT)(_.bifoldMap(_ foldMap (_.bifoldMap(notAGT)(f)))(f))))
typeDecl.foldMap(_.bifoldMap(f)(_.bifoldMap(_ foldMap (_.bifoldMap(notAGT)(f)))(f)))
}

private def collectTypeParams(field: Type): Set[String] = field match {
Expand Down
Loading

0 comments on commit 14f6728

Please sign in to comment.