Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DAML-LF: restrict value versions #6109

Merged
merged 2 commits into from
May 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ object SBuiltin {
case SBool(true) =>
()
case SBool(false) =>
asVersionedValue(args.get(0).toValue) match {
asVersionedValue(args.get(0).toValue, machine.supportedValueVersions) match {
case Left(err) => crash(err)
case Right(createArg) =>
throw DamlETemplatePreconditionViolated(
Expand Down Expand Up @@ -765,17 +765,15 @@ object SBuiltin {
def execute(args: util.ArrayList[SValue], machine: Machine): Unit = {
checkToken(args.get(5))
val createArg = args.get(0)
val createArgValue = asVersionedValue(createArg.toValue) match {
case Left(err) => crash(err)
case Right(x) => x
}
val createArgValue =
asVersionedValue(createArg.toValue, machine.supportedValueVersions).fold(crash, identity)
val agreement = args.get(1) match {
case SText(t) => t
case v => crash(s"agreement not text: $v")
}
val sigs = extractParties(args.get(2))
val obs = extractParties(args.get(3))
val key = extractOptionalKeyWithMaintainers(args.get(4))
val key = extractOptionalKeyWithMaintainers(args.get(4), machine.supportedValueVersions)

val (coid, newPtx) = machine.ptx
.insertCreate(
Expand Down Expand Up @@ -831,7 +829,7 @@ object SBuiltin {
val obs = extractParties(args.get(5))
val ctrls = extractParties(args.get(6))

val mbKey = extractOptionalKeyWithMaintainers(args.get(7))
val mbKey = extractOptionalKeyWithMaintainers(args.get(7), machine.supportedValueVersions)

machine.ptx = machine.ptx
.beginExercises(
Expand All @@ -846,10 +844,7 @@ object SBuiltin {
controllers = ctrls,
mbKey = mbKey,
byKey = byKey,
chosenValue = asVersionedValue(arg) match {
case Left(err) => crash(err)
case Right(x) => x
},
chosenValue = asVersionedValue(arg, machine.supportedValueVersions).fold(crash, identity)
)
.fold(err => throw DamlETransactionError(err), identity)
checkAborted(machine.ptx)
Expand All @@ -867,7 +862,7 @@ object SBuiltin {
checkToken(args.get(0))
val exerciseResult = args.get(1).toValue
machine.ptx = machine.ptx
.endExercises(asVersionedValue(exerciseResult) match {
.endExercises(asVersionedValue(exerciseResult, machine.supportedValueVersions) match {
case Left(err) => crash(err)
case Right(x) => x
})
Expand Down Expand Up @@ -939,7 +934,7 @@ object SBuiltin {
}
val signatories = extractParties(args.get(1))
val observers = extractParties(args.get(2))
val key = extractOptionalKeyWithMaintainers(args.get(3))
val key = extractOptionalKeyWithMaintainers(args.get(3), machine.supportedValueVersions)

val stakeholders = observers union signatories
val contextActors = machine.ptx.context.exeContext match {
Expand Down Expand Up @@ -972,7 +967,8 @@ object SBuiltin {
final case class SBULookupKey(templateId: TypeConName) extends SBuiltin(2) {
def execute(args: util.ArrayList[SValue], machine: Machine): Unit = {
checkToken(args.get(1))
val keyWithMaintainers = extractKeyWithMaintainers(args.get(0))
val keyWithMaintainers =
extractKeyWithMaintainers(args.get(0), machine.supportedValueVersions)
val gkey = GlobalKey(templateId, keyWithMaintainers.key.value)
// check if we find it locally
machine.ptx.keys.get(gkey) match {
Expand Down Expand Up @@ -1016,7 +1012,8 @@ object SBuiltin {
final case class SBUInsertLookupNode(templateId: TypeConName) extends SBuiltin(3) {
def execute(args: util.ArrayList[SValue], machine: Machine): Unit = {
checkToken(args.get(2))
val keyWithMaintainers = extractKeyWithMaintainers(args.get(0))
val keyWithMaintainers =
extractKeyWithMaintainers(args.get(0), machine.supportedValueVersions)
val mbCoid = args.get(1) match {
case SOptional(mb) =>
mb.map {
Expand Down Expand Up @@ -1047,7 +1044,8 @@ object SBuiltin {
final case class SBUFetchKey(templateId: TypeConName) extends SBuiltin(2) {
def execute(args: util.ArrayList[SValue], machine: Machine): Unit = {
checkToken(args.get(1))
val keyWithMaintainers = extractKeyWithMaintainers(args.get(0))
val keyWithMaintainers =
extractKeyWithMaintainers(args.get(0), machine.supportedValueVersions)
val gkey = GlobalKey(templateId, keyWithMaintainers.key.value)
// check if we find it locally
machine.ptx.keys.get(gkey) match {
Expand Down Expand Up @@ -1463,7 +1461,10 @@ object SBuiltin {
crash(s"value not a list of parties or party: $v")
}

private def extractKeyWithMaintainers(v: SValue): KeyWithMaintainers[Tx.Value[Nothing]] =
private def extractKeyWithMaintainers(
v: SValue,
supportedValueVersions: VersionRange[value.ValueVersion],
): KeyWithMaintainers[Tx.Value[Nothing]] =
v match {
case SStruct(flds, vals)
if flds.length == 2 && flds(0) == Ast.keyFieldName && flds(1) == Ast.maintainersFieldName =>
Expand All @@ -1475,7 +1476,7 @@ object SBuiltin {
.ensureNoCid
.left
.map(coid => s"Unexpected contract id in key: $coid")
versionedKeyVal <- asVersionedValue(keyVal)
versionedKeyVal <- asVersionedValue(keyVal, supportedValueVersions)
} yield
KeyWithMaintainers(
key = versionedKeyVal,
Expand All @@ -1485,9 +1486,11 @@ object SBuiltin {
}

private def extractOptionalKeyWithMaintainers(
optKey: SValue): Option[KeyWithMaintainers[Tx.Value[Nothing]]] =
optKey: SValue,
supportedValueVersions: VersionRange[value.ValueVersion],
): Option[KeyWithMaintainers[Tx.Value[Nothing]]] =
optKey match {
case SOptional(mbKey) => mbKey.map(extractKeyWithMaintainers)
case SOptional(mbKey) => mbKey.map(extractKeyWithMaintainers(_, supportedValueVersions))
case v => crash(s"Expected optional key with maintainers, got: $v")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ object Speedy {

/** The speedy CEK machine. */
final case class Machine(
/* Value versions that the machine can output */
supportedValueVersions: VersionRange[value.ValueVersion],
/* The control is what the machine should be evaluating. If this is not
* null, then `returnValue` must be null.
*/
Expand Down Expand Up @@ -527,6 +529,7 @@ object Speedy {
globalCids: Set[V.ContractId]
) =
Machine(
supportedValueVersions = value.ValueVersions.DefaultSupportedVersions,
ctrl = null,
returnValue = null,
frame = null,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf

final case class VersionRange[V](
min: V,
max: V,
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import com.daml.lf.crypto.Hash
import com.daml.lf.data.Ref.{Identifier, Name}
import com.daml.lf.data._
import data.ScalazEqual._
import com.daml.lf.language.LanguageVersion

import scala.annotation.tailrec
import scalaz.{@@, Equal, Order, Tag}
Expand Down Expand Up @@ -223,13 +222,6 @@ object Value extends CidContainer1[Value] {
private[lf] def map1[Cid2](f: Cid => Cid2): VersionedValue[Cid2] =
VersionedValue.map1(f)(this)

/** Increase the `version` if appropriate for `languageVersions`. */
def typedBy(languageVersions: LanguageVersion*): VersionedValue[Cid] = {
import com.daml.lf.transaction.VersionTimeline, VersionTimeline._, Implicits._
copy(version =
latestWhenAllPresent(version, languageVersions map (a => a: SpecifiedVersion): _*))
}

def foreach1(f: Cid => Unit) =
VersionedValue.foreach1(f)(self)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.value
package com.daml.lf
package value

import com.daml.lf.data.Ref._
import com.daml.lf.data._
Expand Down Expand Up @@ -221,9 +222,10 @@ object ValueCoder {
def encodeVersionedValue[Cid](
encodeCid: EncodeCid[Cid],
value: Value[Cid],
supportedVersions: VersionRange[ValueVersion],
): Either[EncodeError, proto.VersionedValue] =
ValueVersions
.assignVersion(value)
.assignVersion(value, supportedVersions)
.fold(
err => Left(EncodeError(err)),
version => encodeVersionedValueWithCustomVersion(encodeCid, VersionedValue(version, value)),
Expand Down Expand Up @@ -561,9 +563,9 @@ object ValueCoder {
private[value] def valueToBytes[Cid](
encodeCid: EncodeCid[Cid],
v: Value[Cid],
): Either[EncodeError, Array[Byte]] = {
encodeVersionedValue(encodeCid, v).map(_.toByteArray)
}
supportedVersions: VersionRange[ValueVersion] = ValueVersions.DefaultSupportedVersions,
): Either[EncodeError, Array[Byte]] =
encodeVersionedValue(encodeCid, v, supportedVersions).map(_.toByteArray)

private[value] def valueFromBytes[Cid](
decodeCid: DecodeCid[Cid],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.value
package com.daml.lf
package value

import com.daml.lf.value.Value._
import com.daml.lf.LfVersions
import com.daml.lf.data.{Decimal, FrontStack, FrontStackCons, ImmArray}
import com.daml.lf.transaction.VersionTimeline

Expand Down Expand Up @@ -32,10 +32,14 @@ object ValueVersions
// Older versions are deprecated https://github.com/digital-asset/daml/issues/5220
// We force output of recent version, but keep reading older version as long as
// Sandbox is alive.
private[value] val minOutputVersion = ValueVersion("6")
val DefaultSupportedVersions = VersionRange(ValueVersion("6"), acceptedVersions.last)

def assignVersion[Cid](v0: Value[Cid]): Either[String, ValueVersion] = {
def assignVersion[Cid](
v0: Value[Cid],
supportedVersions: VersionRange[ValueVersion] = DefaultSupportedVersions,
): Either[String, ValueVersion] = {
import VersionTimeline.{maxVersion => maxVV}
import VersionTimeline.Implicits._

@tailrec
def go(
Expand Down Expand Up @@ -80,25 +84,32 @@ object ValueVersions
}
}

go(minOutputVersion, FrontStack(v0))
go(supportedVersions.min, FrontStack(v0)) match {
case Right(inferredVersion) if supportedVersions.max precedes inferredVersion =>
Left(s"inferred version $inferredVersion is not supported")
case res =>
res
}

}

@throws[IllegalArgumentException]
def assertAssignVersion[Cid](v0: Value[Cid]): ValueVersion =
assignVersion(v0) match {
case Left(err) => throw new IllegalArgumentException(err)
case Right(x) => x
}
def assertAssignVersion[Cid](
v0: Value[Cid],
supportedVersions: VersionRange[ValueVersion] = DefaultSupportedVersions,
): ValueVersion =
data.assertRight(assignVersion(v0, supportedVersions))

def asVersionedValue[Cid](
value: Value[Cid],
supportedVersions: VersionRange[ValueVersion] = DefaultSupportedVersions,
): Either[String, VersionedValue[Cid]] =
assignVersion(value).map(version => VersionedValue(version = version, value = value))
assignVersion(value, supportedVersions).map(VersionedValue(_, value))

@throws[IllegalArgumentException]
def assertAsVersionedValue[Cid](value: Value[Cid]): VersionedValue[Cid] =
asVersionedValue(value) match {
case Left(err) => throw new IllegalArgumentException(err)
case Right(x) => x
}
def assertAsVersionedValue[Cid](
value: Value[Cid],
supportedVersions: VersionRange[ValueVersion] = DefaultSupportedVersions,
): VersionedValue[Cid] =
data.assertRight(asVersionedValue(value, supportedVersions))
}