Skip to content

Commit

Permalink
Daml lf type safty (Party & PackageId) (digital-asset#761)
Browse files Browse the repository at this point in the history
* daml-lf: split SimpleString into Party and PackageId

* daml-lf remove parameter from DefinitionRef
  • Loading branch information
remyhaemmerle-da authored May 6, 2019
1 parent 78bf1b8 commit 0489c6e
Show file tree
Hide file tree
Showing 132 changed files with 472 additions and 466 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class ScenarioService extends ScenarioServiceGrpc.ScenarioServiceImplBase {
.flatMap { context =>
val packageId = scenarioId.getPackage.getSumCase match {
case PackageIdentifier.SumCase.SELF =>
context.homePackageId.underlyingString
context.homePackageId
case PackageIdentifier.SumCase.PACKAGE_ID =>
scenarioId.getPackage.getPackageId
case PackageIdentifier.SumCase.SUM_NOT_SET =>
Expand Down Expand Up @@ -204,7 +204,7 @@ class ScenarioService extends ScenarioServiceGrpc.ScenarioServiceImplBase {
)

resp.addAllLoadedModules(ctx.loadedModules().map(_.toString).asJava)
resp.addAllLoadedPackages(ctx.loadedPackages().map(_.underlyingString).asJava)
resp.addAllLoadedPackages((ctx.loadedPackages(): Iterable[String]).asJava)
respObs.onNext(resp.build)
respObs.onCompleted()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import com.digitalasset.daml.lf.data.Ref.{
ModuleName,
PackageId,
QualifiedName,
SimpleString
}
import com.digitalasset.daml.lf.archive.LanguageVersion
import com.digitalasset.daml.lf.lfpackage.Ast
Expand Down Expand Up @@ -60,11 +59,11 @@ class Context(val contextId: Context.ContextId) {
* in extPackages.
*/
val homePackageId: PackageId =
SimpleString.assertFromString("-homePackageId-")
PackageId.assertFromString("-homePackageId-")

private var modules: Map[ModuleName, Ast.Module] = Map.empty
private var extPackages: Map[PackageId, Ast.Package] = Map.empty
private var defns: Map[DefinitionRef[PackageId], SExpr] = Map.empty
private var defns: Map[DefinitionRef, SExpr] = Map.empty

def loadedModules(): Iterable[ModuleName] = modules.keys
def loadedPackages(): Iterable[PackageId] = extPackages.keys
Expand Down Expand Up @@ -104,7 +103,7 @@ class Context(val contextId: Context.ContextId) {
ref.packageId != homePackageId || ref.qualifiedName.module != lfModuleId)
}
unloadPackages.foreach { pkgId =>
val lfPkgId = assert(SimpleString.fromString(pkgId))
val lfPkgId = assert(PackageId.fromString(pkgId))
extPackages -= lfPkgId
defns = defns.filterKeys(ref => ref.packageId != lfPkgId)
}
Expand Down Expand Up @@ -165,7 +164,7 @@ class Context(val contextId: Context.ContextId) {
name: String
): Option[(Ledger, Speedy.Machine, Either[SError, SValue])] =
buildMachine(
Identifier(assert(SimpleString.fromString(pkgId)), assert(QualifiedName.fromString(name))))
Identifier(assert(PackageId.fromString(pkgId)), assert(QualifiedName.fromString(name))))
.map { machine =>
ScenarioRunner(machine).run() match {
case Right((diff @ _, steps @ _, ledger)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ case class Conversions(homePackageId: Ref.PackageId) {
case SError.ScenarioErrorMustFailSucceeded(tx @ _) =>
builder.setScenarioMustfailSucceeded(empty)

case SError.ScenarioErrorInvalidPartyName(party) =>
case SError.ScenarioErrorInvalidPartyName(party, _) =>
builder.setScenarioInvalidPartyName(party)

}
Expand Down Expand Up @@ -514,7 +514,7 @@ case class Conversions(homePackageId: Ref.PackageId) {
// Reconstitute the self package reference.
packageIdSelf
else
PackageIdentifier.newBuilder.setPackageId(pkg.underlyingString).build
PackageIdentifier.newBuilder.setPackageId(pkg).build

def convertIdentifier(identifier: Ref.Identifier): Identifier =
Identifier.newBuilder
Expand Down Expand Up @@ -604,7 +604,7 @@ case class Conversions(homePackageId: Ref.PackageId) {
case V.ValueText(t) => builder.setText(t)
case V.ValueTimestamp(ts) => builder.setTimestamp(ts.micros)
case V.ValueDate(d) => builder.setDate(d.days)
case V.ValueParty(p) => builder.setParty(p.underlyingString)
case V.ValueParty(p) => builder.setParty(p)
case V.ValueBool(b) => builder.setBool(b)
case V.ValueUnit => builder.setUnit(empty)
case V.ValueOptional(mbV) =>
Expand All @@ -627,6 +627,6 @@ case class Conversions(homePackageId: Ref.PackageId) {
}

def convertParty(p: Ref.Party): Party =
Party.newBuilder.setParty(p.underlyingString).build
Party.newBuilder.setParty(p).build

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package archive
import java.io.InputStream
import java.security.MessageDigest

import com.digitalasset.daml.lf.data.Ref.SimpleString
import com.digitalasset.daml.lf.data.Ref.PackageId
import com.digitalasset.daml_lf.DamlLf
import com.google.protobuf.CodedInputStream

Expand All @@ -22,7 +22,7 @@ abstract class Reader[+Pkg] {
def withRecursionLimit(recursionLimit: Int): Reader[Pkg] = new Reader[Pkg] {
override val PROTOBUF_RECURSION_LIMIT = recursionLimit
protected[this] override def readArchivePayloadOfVersion(
hash: SimpleString,
hash: PackageId,
lf: DamlLf.ArchivePayload,
version: LanguageVersion
): Pkg =
Expand All @@ -44,12 +44,12 @@ abstract class Reader[+Pkg] {
lf.getHashFunction match {
case DamlLf.HashFunction.SHA256 =>
val payload = lf.getPayload.toByteArray()
val theirHash = SimpleString.fromString(lf.getHash) match {
val theirHash = PackageId.fromString(lf.getHash) match {
case Right(hash) => hash
case Left(err) => throw ParseError(s"Invalid hash: $err")
}
val ourHash =
SimpleString.assertFromString(
PackageId.assertFromString(
MessageDigest.getInstance("SHA-256").digest(payload).map("%02x" format _).mkString)
if (ourHash != theirHash) {
throw ParseError(s"Mismatching hashes! Expected $ourHash but got $theirHash")
Expand All @@ -66,12 +66,12 @@ abstract class Reader[+Pkg] {
readArchiveAndVersion(lf)._1

@throws[ParseError]
final def readArchivePayload(hash: SimpleString, lf: DamlLf.ArchivePayload): Pkg =
final def readArchivePayload(hash: PackageId, lf: DamlLf.ArchivePayload): Pkg =
readArchivePayloadAndVersion(hash, lf)._1

@throws[ParseError]
final def readArchivePayloadAndVersion(
hash: SimpleString,
hash: PackageId,
lf: DamlLf.ArchivePayload): (Pkg, LanguageMajorVersion) = {
val majorVersion = readArchiveVersion(lf)
// for DAML-LF v1, we translate "no version" to minor version 0,
Expand All @@ -93,12 +93,12 @@ abstract class Reader[+Pkg] {
}

protected[this] def readArchivePayloadOfVersion(
hash: SimpleString,
hash: PackageId,
lf: DamlLf.ArchivePayload,
version: LanguageVersion): Pkg
}

object Reader extends Reader[(SimpleString, DamlLf.ArchivePayload)] {
object Reader extends Reader[(PackageId, DamlLf.ArchivePayload)] {
final case class ParseError(error: String) extends RuntimeException(error)

def damlLfCodedInputStreamFromBytes(
Expand Down Expand Up @@ -129,8 +129,8 @@ object Reader extends Reader[(SimpleString, DamlLf.ArchivePayload)] {
}

protected[this] override def readArchivePayloadOfVersion(
hash: SimpleString,
hash: PackageId,
lf: DamlLf.ArchivePayload,
version: LanguageVersion,
): (SimpleString, DamlLf.ArchivePayload) = (hash, lf)
): (PackageId, DamlLf.ArchivePayload) = (hash, lf)
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ class DarReaderTest extends WordSpec with Matchers with Inside {
((packageId2, archive2), LanguageMajorVersion.V1) :: (
(packageId3, archive3),
LanguageMajorVersion.V1) :: Nil)) =>
packageId1.underlyingString shouldNot be('empty)
packageId2.underlyingString shouldNot be('empty)
packageId3.underlyingString shouldNot be('empty)
packageId1 shouldNot be('empty)
packageId2 shouldNot be('empty)
packageId3 shouldNot be('empty)
archive1.getDamlLf1.getModulesCount should be > 0
archive2.getDamlLf1.getModulesCount should be > 0
archive3.getDamlLf1.getModulesCount should be > 0
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2019 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.digitalasset.daml.lf.data

import scala.util.matching.Regex

sealed abstract class MatchingStringModule {
type T <: String

def fromString(s: String): Either[String, T]

@throws[IllegalArgumentException]
def assertFromString(s: String): T =
fromString(s).fold(e => throw new IllegalArgumentException(e), identity)

def unapply(x: T): Some[String] = Some(x)
}

object MatchingStringModule extends (Regex => MatchingStringModule) {

override def apply(regex: Regex): MatchingStringModule = new MatchingStringModule {
type T = String

private val pattern = regex.pattern

def fromString(s: String): Either[String, T] =
Either.cond(pattern.matcher(s).matches(), s, s"""string "$s" does not match regex "$regex"""")
}

}
61 changes: 16 additions & 45 deletions daml-lf/data/src/main/scala/com/digitalasset/daml/lf/data/Ref.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,51 +5,9 @@ package com.digitalasset.daml.lf.data

object Ref {

// SimpleString are non empty US-ASCII strings built with letters, digits, space, minus and,
// underscore. We use them to represent packageIds and party literals. In this way, we avoid
// empty identifiers, escaping problems, and other similar pitfalls.
final case class SimpleString private (underlyingString: String) extends Ordered[SimpleString] {
def compare(that: SimpleString): Int =
underlyingString.compareTo(that.underlyingString)
}

object SimpleString {

private def valid(c: Char) =
('a' <= c && c <= 'z') ||
('A' <= c && c <= 'Z') ||
('0' <= c && c <= '9') ||
c == ' ' || c == '-' || c == '_'

def fromString(string: String): Either[String, SimpleString] =
if (string.isEmpty)
Left(s"Expected a non-empty string")
else
string.find(c => !valid(c)) match {
case None =>
Right(new SimpleString(string))
case Some(c) =>
Left(s"""Invalid character ${c.toInt.formatted("%#x")} found in "$string"""")
}

/** Crashes if the string is not a valid [[SimpleString]]. */
@throws[IllegalArgumentException]
def assertFromString(s: String): SimpleString =
assert(fromString(s))
}

type Party = SimpleString
val Party = SimpleString

/* Location annotation */
case class Location(packageId: PackageId, module: ModuleName, start: (Int, Int), end: (Int, Int))

/* Choice name in a template. */
type ChoiceName = String

type ModuleName = DottedName
val ModuleName = DottedName

// we do not use String.split because `":foo".split(":")`
// results in `List("foo")` rather than `List("", "foo")`
private def split(s: String, splitCh: Char): ImmArray[String] = {
Expand Down Expand Up @@ -157,17 +115,30 @@ object Ref {
* specified package. */
case class Identifier(packageId: PackageId, qualifiedName: QualifiedName)

/* Choice name in a template. */
type ChoiceName = String

type ModuleName = DottedName
val ModuleName = DottedName

/** Party are non empty US-ASCII strings built with letters, digits, space, minus and,
underscore. We use them to represent [PackageId]s and [Party] literals. In this way, we avoid
empty identifiers, escaping problems, and other similar pitfalls.
*/
val Party = MatchingStringModule("""[a-zA-Z0-9\-_ ]+""".r)
type Party = Party.T

/** Reference to a package via a package identifier. The identifier is the ascii7
* lowercase hex-encoded hash of the package contents found in the DAML LF Archive. */
type PackageId = SimpleString
val PackageId = SimpleString
val PackageId = MatchingStringModule("""[a-zA-Z0-9\-_ ]+""".r)
type PackageId = PackageId.T

/** Reference to a value defined in the specified module. */
type ValueRef = Identifier
val ValueRef = Identifier

/** Reference to a value defined in the specified module. */
type DefinitionRef[PkgId] = Identifier
type DefinitionRef = Identifier
val DefinitionRef = Identifier

/** Reference to a type constructor. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
// SPDX-License-Identifier: Apache-2.0

package com.digitalasset.daml.lf.data
import com.digitalasset.daml.lf.data.Ref.{DottedName, SimpleString, QualifiedName}

import com.digitalasset.daml.lf.data.Ref.{DottedName, PackageId, Party, QualifiedName}
import org.scalatest.{FreeSpec, Matchers}

class RefTest extends FreeSpec with Matchers {
Expand Down Expand Up @@ -64,34 +65,42 @@ class RefTest extends FreeSpec with Matchers {
}
}

"String" - {
"Party and PackageId" - {

val simpleChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_ "

"accepts simple characters" in {
for (c <- simpleChars)
SimpleString.fromString(s"the character $c is simple") shouldBe 'right
for (c <- simpleChars) {
Party.fromString(s"the character $c is simple") shouldBe 'right
PackageId.fromString(s"the character $c is simple") shouldBe 'right
}
}

"rejects the empty string" in {
SimpleString.fromString("") shouldBe 'left
Party.fromString("") shouldBe 'left
PackageId.fromString("") shouldBe 'left
}

"rejects non simple US-ASCII characters" in {
for {
c <- '\u0001' to '\u007f' if !simpleChars.contains(c)
} SimpleString.fromString(s"the US-ASCII character $c is not simple") shouldBe 'left
for (c <- '\u0001' to '\u007f' if !simpleChars.contains(c)) {
Party.fromString(s"the US-ASCII character $c is not simple") shouldBe 'left
PackageId.fromString(s"the US-ASCII character $c is not simple") shouldBe 'left
}
}

"rejects no US-ASCII characters" in {
for (c <- '\u0080' to '\u00ff')
SimpleString.fromString(s"the character $c is not US-ASCII") shouldBe 'left
for (c <- '\u0080' to '\u00ff') {
Party.fromString(s"the character $c is not US-ASCII") shouldBe 'left
PackageId.fromString(s"the character $c is not US-ASCII") shouldBe 'left
}
for (s <- List(
"español",
"東京",
"Λ (τ : ⋆) (σ: ⋆ → ⋆). λ (e : ∀ (α : ⋆). σ α) → (( e @τ ))"
))
SimpleString.fromString(s) shouldBe 'left
)) {
Party.fromString(s) shouldBe 'left
PackageId.fromString(s) shouldBe 'left
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ import com.digitalasset.daml.lf.speedy.{Compiler, SExpr}
final class ConcurrentCompiledPackages extends CompiledPackages {
private[this] val _packages: ConcurrentHashMap[PackageId, Package] =
new ConcurrentHashMap()
private[this] val _defns: ConcurrentHashMap[DefinitionRef[PackageId], SExpr] =
private[this] val _defns: ConcurrentHashMap[DefinitionRef, SExpr] =
new ConcurrentHashMap()

def getPackage(pId: PackageId): Option[Package] = Option(_packages.get(pId))
def getDefinition(dref: DefinitionRef[PackageId]): Option[SExpr] = Option(_defns.get(dref))
def getDefinition(dref: DefinitionRef): Option[SExpr] = Option(_defns.get(dref))

/** Might ask for a package if the package you're trying to add references it.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package com.digitalasset.daml.lf.engine

import com.digitalasset.daml.lf.command._
import com.digitalasset.daml.lf.data._
import com.digitalasset.daml.lf.data.Ref.{Party, SimpleString}
import com.digitalasset.daml.lf.data.Ref.Party
import com.digitalasset.daml.lf.lfpackage.Ast._
import com.digitalasset.daml.lf.speedy.Compiler
import com.digitalasset.daml.lf.speedy.Pretty
Expand Down Expand Up @@ -140,7 +140,7 @@ final class Engine {
*/
def validatePartial(
tx: GenTransaction.WithTxValue[Tx.NodeId, AbsoluteContractId],
submitter: Option[SimpleString],
submitter: Option[Party],
ledgerEffectiveTime: Time.Timestamp,
requestor: Party,
contractIdMaping: ContractId => AbsoluteContractId,
Expand Down
Loading

0 comments on commit 0489c6e

Please sign in to comment.