Skip to content

Commit

Permalink
LF: Imporve safety of the Serialization of proto message. (digital-as…
Browse files Browse the repository at this point in the history
…set#12686)

This is a follow up of digital-asset#12638, applied to LF support for KV.

CHANGELOG_BEGIN
CHANGELOG_END
  • Loading branch information
remyhaemmerle-da authored Feb 1, 2022
1 parent b4ed15b commit 183f936
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ package archive

import java.io.File

sealed abstract class Error(val msg: String) extends RuntimeException(msg)
sealed abstract class Error(val msg: String)
extends RuntimeException(msg)
with Product
with Serializable

object Error {

Expand Down Expand Up @@ -42,4 +45,6 @@ object Error {
extends Error(s"Unsupported file extension: ${file.getAbsolutePath}")

final case class Parsing(override val msg: String) extends Error(msg)

final case class Encoding(override val msg: String) extends Error(msg)
}
1 change: 1 addition & 0 deletions daml-lf/encoder/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ da_scala_library(
"//daml-lf/archive:daml_lf_archive_reader",
"//daml-lf/data",
"//daml-lf/language",
"//libs-scala/safe-proto",
"@maven//:com_google_protobuf_protobuf_java",
],
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.archive.testing
package com.daml.lf
package archive.testing

import java.security.MessageDigest
import com.daml.SafeProto

import java.security.MessageDigest
import com.daml.lf.data.Ref.PackageId
import com.daml.lf.language.Ast.Package
import com.daml.lf.language.{LanguageMajorVersion, LanguageVersion}
Expand Down Expand Up @@ -35,7 +37,7 @@ object Encode {

final def encodeArchive(pkg: (PackageId, Package), version: LanguageVersion): PLF.Archive = {

val payload = encodePayloadOfVersion(pkg, version).toByteString
val payload = data.assertRight(SafeProto.toByteString(encodePayloadOfVersion(pkg, version)))
val hash = PackageId.assertFromString(
MessageDigest.getInstance("SHA-256").digest(payload.toByteArray).map("%02x" format _).mkString
)
Expand Down
1 change: 1 addition & 0 deletions daml-lf/kv-support/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ da_scala_library(
"//daml-lf/transaction",
"//daml-lf/transaction:transaction_proto_java",
"//daml-lf/transaction:value_proto_java",
"//libs-scala/safe-proto",
"@maven//:com_google_protobuf_protobuf_java",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ object ConversionError {
extends ConversionError(errorMessage)
final case class DecodeError(cause: ValueCoder.DecodeError)
extends ConversionError(cause.errorMessage)
final case class EncodeError(cause: ValueCoder.EncodeError)
extends ConversionError(cause.errorMessage)
final case class InternalError(override val errorMessage: String)
extends ConversionError(errorMessage)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.daml.lf.kv.archives

import com.daml.SafeProto
import com.daml.lf.archive.{ArchiveParser, Decode, Error => ArchiveError}
import com.daml.lf.data.Ref
import com.daml.lf.data.Ref.PackageId
Expand All @@ -21,12 +22,15 @@ object ArchiveConversions {

def parsePackageIdsAndRawArchives(
archives: List[com.daml.daml_lf_dev.DamlLf.Archive]
): Either[ArchiveError.Parsing, Map[Ref.PackageId, RawArchive]] =
): Either[ArchiveError, Map[Ref.PackageId, RawArchive]] =
archives.partitionMap { archive =>
Ref.PackageId.fromString(archive.getHash).map(_ -> RawArchive(archive.toByteString))
for {
pkgId <- Ref.PackageId.fromString(archive.getHash).left.map(ArchiveError.Parsing)
bytes <- SafeProto.toByteString(archive).left.map(ArchiveError.Encoding)
} yield pkgId -> RawArchive(bytes)
} match {
case (Nil, hashesAndRawArchives) => Right(hashesAndRawArchives.toMap)
case (errors, _) => Left(ArchiveError.Parsing(errors.head))
case (errors, _) => Left(errors.head)
}

def decodePackages(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.daml.lf.kv.contracts

import com.daml.SafeProto
import com.daml.lf.kv.ConversionError
import com.daml.lf.transaction.{TransactionCoder, TransactionOuterClass}
import com.daml.lf.value.{Value, ValueCoder}
Expand All @@ -14,9 +15,10 @@ object ContractConversions {
def encodeContractInstance(
coinst: Value.VersionedContractInstance
): Either[ValueCoder.EncodeError, RawContractInstance] =
TransactionCoder
.encodeContractInstance(ValueCoder.CidEncoder, coinst)
.map(contractInstance => RawContractInstance(contractInstance.toByteString))
for {
message <- TransactionCoder.encodeContractInstance(ValueCoder.CidEncoder, coinst)
bytes <- SafeProto.toByteString(message).left.map(ValueCoder.EncodeError(_))
} yield RawContractInstance(bytes)

def decodeContractInstance(
rawContractInstance: RawContractInstance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.daml.lf.kv.transactions

import com.daml.SafeProto
import com.daml.lf.data.{FrontStack, FrontStackCons, ImmArray}
import com.daml.lf.kv.ConversionError
import com.daml.lf.transaction.TransactionOuterClass.Node.NodeTypeCase
Expand All @@ -25,9 +26,11 @@ object TransactionConversions {
def encodeTransaction(
tx: VersionedTransaction
): Either[ValueCoder.EncodeError, RawTransaction] =
TransactionCoder
.encodeTransaction(TransactionCoder.NidEncoder, ValueCoder.CidEncoder, tx)
.map(transaction => RawTransaction(transaction.toByteString))
for {
msg <-
TransactionCoder.encodeTransaction(TransactionCoder.NidEncoder, ValueCoder.CidEncoder, tx)
bytes <- SafeProto.toByteString(msg).left.map(ValueCoder.EncodeError(_))
} yield RawTransaction(bytes)

def decodeTransaction(
rawTx: RawTransaction
Expand Down Expand Up @@ -63,7 +66,7 @@ object TransactionConversions {
def reconstructTransaction(
transactionVersion: String,
nodesWithIds: Seq[TransactionNodeIdWithNode],
): Either[ConversionError.ParseError, RawTransaction] = {
): Either[ConversionError, RawTransaction] = {
import scalaz.std.either._
import scalaz.std.list._
import scalaz.syntax.traverse._
Expand Down Expand Up @@ -94,7 +97,14 @@ object TransactionConversions {
}
.toList
.sequence_
.map(_ => RawTransaction(transactionBuilder.build.toByteString))
.flatMap(_ =>
SafeProto.toByteString(transactionBuilder.build()) match {
case Right(bytes) =>
Right(RawTransaction(bytes))
case Left(msg) =>
Left(ConversionError.EncodeError(ValueCoder.EncodeError(msg)))
}
)
}

/** Decodes and extracts outputs of a submitted transaction, that is the IDs and keys of contracts created or updated
Expand Down Expand Up @@ -210,7 +220,7 @@ object TransactionConversions {
}
}

goNodesToKeep(transaction.getRootsList.asScala.to(FrontStack), Set.empty).map {
goNodesToKeep(transaction.getRootsList.asScala.to(FrontStack), Set.empty).flatMap {
nodesToKeep =>
val filteredRoots = transaction.getRootsList.asScala.filter(nodesToKeep)

Expand Down Expand Up @@ -239,7 +249,14 @@ object TransactionConversions {
.addAllNodes(filteredNodes.asJavaCollection)
.setVersion(transaction.getVersion)
.build()
RawTransaction(newTransaction.toByteString)

SafeProto.toByteString(newTransaction) match {
case Right(bytes) =>
Right(RawTransaction(bytes))
case Left(msg) =>
// Should not happen as removing nodes should results into a smaller transaction.
Left(ConversionError.InternalError(msg))
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ object TransactionTraversal {
case Left(error) => Left(ConversionError.DecodeError(error))
case Right(nodeWitnesses) =>
val witnesses = parentWitnesses union nodeWitnesses
// Here node.toByteString is safe.
// Indeed node is a submessage of the transaction `rawTx` we got serialized
// as input of `traverseTransactionWithWitnesses` and successfully decoded, i.e.
// `rawTx` requires less than 2GB to be serialized, so does <node`.
// See com.daml.SafeProto for more details about issues with the toByteString method.
f(nodeId, RawTransaction.Node(node.toByteString), witnesses)
// Recurse into children (if any).
node.getNodeTypeCase match {
Expand All @@ -62,7 +67,7 @@ object TransactionTraversal {
}
}

private def informeesOfNode(
private[this] def informeesOfNode(
txVersion: TransactionVersion,
node: TransactionOuterClass.Node,
): Either[ValueCoder.DecodeError, Set[Ref.Party]] =
Expand Down
2 changes: 1 addition & 1 deletion libs-scala/safe-proto/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ da_scala_library(
da_scala_test_suite(
name = "safe-protot-test",
srcs = glob(["src/test/scala/**/*.scala"]),
max_heap_size = "4g",
max_heap_size = "3g",
deps = [
":safe-proto",
"@maven//:com_google_protobuf_protobuf_java",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object SafeProto {
case e: RuntimeException
if e.isInstanceOf[NegativeArraySizeException] ||
e.getCause != null && e.getCause.isInstanceOf[CodedOutputStream.OutOfSpaceException] =>
Left(s"the ${message.getClass.getName} message is too big to be serialized")
Left(s"the ${message.getClass.getName} message is too large to be serialized")
}

def toByteString(message: AbstractMessageLite[_, _]): Either[String, ByteString] =
Expand Down

0 comments on commit 183f936

Please sign in to comment.