Skip to content

Commit

Permalink
Add a JWT authentication to sandbox (digital-asset#3283)
Browse files Browse the repository at this point in the history
  • Loading branch information
rautenrieth-da authored Nov 7, 2019
1 parent df11293 commit 89d6c73
Show file tree
Hide file tree
Showing 14 changed files with 532 additions and 141 deletions.
28 changes: 28 additions & 0 deletions docs/source/tools/sandbox.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,34 @@ Due to possible conflicts between the ``&`` character and various terminal shell
If you're not familiar with JDBC URLs, see the JDBC docs for more information: https://jdbc.postgresql.org/documentation/head/connect.html

Running with authentication
***************************

By default, Sandbox does not use any authentication and accepts all valid ledger API requests.

To start Sandbox with authentication based on `JWT <https://jwt.io/>`_, run ``daml sandbox --auth-jwt-hs256=<secret>`` where ``<secret>`` is the secret used to sign the token with the HMAC256 algorithm.

The JWT payload has the following schema:

.. code-block:: json
{
"ledgerId": "aaaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
"participantId": null,
"applicationId": null,
"exp": 1300819380,
"admin": true,
"actAs": ["Alice"],
"readAs": ["Alice", "Bob"],
}
where
``ledgerId``, ``participantId``, ``applicationId`` restricts the validity of the token to the given ledger, participant, or application;
``exp`` is the standard JWT expiration date;
``admin`` determines whether the token bearer is authorized to use admin endpoints of the ledger API;
``actAs`` lists all DAML parties the token bearer can act as (e.g., as submitter of a command); and
``readAs`` lists all DAML parties the token bearer can read data for.

Command-line reference
**********************

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import scalaz.syntax.traverse._
class JwtVerifier(verifier: com.auth0.jwt.interfaces.JWTVerifier) {

def verify(jwt: domain.Jwt): Error \/ domain.DecodedJwt[String] = {
// The auth0 library verification already fails if the token has expired,
// but we still need to do manual expiration checks in ongoing streams
\/.fromTryCatchNonFatal(verifier.verify(jwt.value))
.bimap(
e => Error('verify, e.getMessage),
Expand Down
16 changes: 16 additions & 0 deletions ledger/ledger-api-auth/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,34 @@ da_scala_library(
"//ledger-api/grpc-definitions:ledger-api-scalapb",
"//ledger-api/rs-grpc-akka",
"//ledger-api/rs-grpc-bridge",
"//ledger-service/jwt",
"//ledger/ledger-api-akka",
"//ledger/ledger-api-client",
"//ledger/ledger-api-common",
"//ledger/ledger-api-domain",
"//ledger/ledger-api-scala-logging",
"@maven//:com_auth0_java_jwt",
"@maven//:com_typesafe_akka_akka_actor_2_12",
"@maven//:com_typesafe_akka_akka_stream_2_12",
"@maven//:io_grpc_grpc_api",
"@maven//:io_grpc_grpc_context",
"@maven//:io_grpc_grpc_core",
"@maven//:io_grpc_grpc_services",
"@maven//:io_spray_spray_json_2_12",
"@maven//:org_scala_lang_modules_scala_java8_compat_2_12",
"@maven//:org_scalaz_scalaz_core_2_12",
"@maven//:org_slf4j_slf4j_api",
],
)

da_scala_test_suite(
name = "ledger-api-auth-scala-tests",
srcs = glob(["src/test/suite/**/*.scala"]),
deps = [
":ledger-api-auth",
"@maven//:io_spray_spray_json_2_12",
"@maven//:org_scalacheck_scalacheck_2_12",
"@maven//:org_scalatest_scalatest_2_12",
"@maven//:org_scalaz_scalaz_scalacheck_binding_2_12",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright (c) 2019 The DAML Authors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.digitalasset.ledger.api.auth

import java.util.concurrent.{CompletableFuture, CompletionStage}

import com.digitalasset.daml.lf.data.Ref
import com.digitalasset.jwt.JwtVerifier
import com.digitalasset.ledger.api.auth.AuthServiceJWT.Error
import io.grpc.Metadata
import org.slf4j.{Logger, LoggerFactory}
import spray.json._

import scala.collection.mutable.ListBuffer
import scala.util.Try

/** An AuthService that reads a JWT token from a `Authorization: Bearer` HTTP header.
* The token is expected to use the format as defined in [[AuthServiceJWTPayload]]:
*/
class AuthServiceJWT(verifier: JwtVerifier) extends AuthService {

protected val logger: Logger = LoggerFactory.getLogger(AuthServiceJWT.getClass)

override def decodeMetadata(headers: Metadata): CompletionStage[Claims] = {
decodeAndParse(headers).fold(
error => {
logger.warn("Authorization error: " + error.message)
CompletableFuture.completedFuture(Claims.empty)
},
token => CompletableFuture.completedFuture(payloadToClaims(token))
)
}

private[this] def parsePayload(jwtPayload: String): Either[Error, AuthServiceJWTPayload] = {
import AuthServiceJWTCodec.JsonImplicits._
Try(JsonParser(jwtPayload).convertTo[AuthServiceJWTPayload]).toEither.left.map(t =>
Error("Could not parse JWT token: " + t.getMessage))
}

private[this] def decodeAndParse(headers: Metadata): Either[Error, AuthServiceJWTPayload] = {
val bearerTokenRegex = "Bearer (.*)".r

for {
headerValue <- Option
.apply(headers.get(AuthServiceJWT.AUTHORIZATION_KEY))
.toRight(Error("Authorization header not found"))
token <- bearerTokenRegex
.findFirstMatchIn(headerValue)
.map(_.group(1))
.toRight(Error("Authorization header does not use Bearer format"))
decoded <- verifier
.verify(com.digitalasset.jwt.domain.Jwt(token))
.toEither
.left
.map(e => Error("Could not verify JWT token: " + e.message))
parsed <- parsePayload(decoded.payload)
} yield parsed
}

private[this] def payloadToClaims(payload: AuthServiceJWTPayload): Claims = {
val claims = ListBuffer[Claim]()

// Any valid token authorizes the user to use public services
claims.append(ClaimPublic)

if (payload.admin)
claims.append(ClaimAdmin)

payload.actAs
.foreach(party => claims.append(ClaimActAsParty(Ref.Party.assertFromString(party))))

Claims(claims.toList, payload.exp)
}
}

object AuthServiceJWT {
final case class Error(message: String)

val AUTHORIZATION_KEY: Metadata.Key[String] =
Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)

def apply(verifier: com.auth0.jwt.interfaces.JWTVerifier) =
new AuthServiceJWT(new JwtVerifier(verifier))

def apply(verifier: JwtVerifier) =
new AuthServiceJWT(verifier)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// Copyright (c) 2019 The DAML Authors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.digitalasset.ledger.api.auth

import java.time.Instant

import spray.json._

/** The JWT token payload used in [[AuthServiceJWT]]
*
*
* @param ledgerId If set, the token is only valid for the given ledger ID.
* May also be used to fill in missing ledger ID fields in ledger API requests.
*
* @param participantId If set, the token is only valid for the given participant ID.
* May also be used to fill in missing participant ID fields in ledger API requests.
*
* @param applicationId If set, the token is only valid for the given application ID.
* May also be used to fill in missing application ID fields in ledger API requests.
*
* @param exp If set, the token is only valid before the given instant.
* Note: This is a registered claim in JWT
*
* @param admin Whether the token bearer is authorized to use admin endpoints of the ledger API.
*
* @param actAs List of parties the token bearer can act as.
* May also be used to fill in missing party fields in ledger API requests (e.g., submitter).
*
* @param readAs List of parties the token bearer can read data for.
* May also be used to fill in missing party fields in ledger API requests (e.g., transaction filter).
*/
case class AuthServiceJWTPayload(
ledgerId: Option[String],
participantId: Option[String],
applicationId: Option[String],
exp: Option[Instant],
admin: Boolean,
actAs: List[String],
readAs: List[String]
)

/**
* Codec for writing and reading [[AuthServiceJWTPayload]] to and from JSON.
*
* In general:
* - All fields are optional in JSON for forward/backward compatibility reasons.
* - Extra JSON fields are ignored when reading.
* - Null values and missing JSON fields map to None or a safe default value (if there is one).
*
* Example:
* ```
* {
* "ledgerId": "aaaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
* "participantId": null,
* "applicationId": null,
* "exp": 1300819380,
* "admin": true,
* "actAs": ["Alice"],
* "readAs": ["Alice", "Bob"],
* }
* ```
*/
object AuthServiceJWTCodec {

// ------------------------------------------------------------------------------------------------------------------
// Constants used in the encoding
// ------------------------------------------------------------------------------------------------------------------
private[this] final val propLedgerId: String = "ledgerId"
private[this] final val propParticipantId: String = "participantId"
private[this] final val propApplicationId: String = "applicationId"
private[this] final val propAdmin: String = "admin"
private[this] final val propActAs: String = "actAs"
private[this] final val propReadAs: String = "readAs"
private[this] final val propExp: String = "exp"

// ------------------------------------------------------------------------------------------------------------------
// Encoding
// ------------------------------------------------------------------------------------------------------------------
def writePayload(v: AuthServiceJWTPayload): JsValue = JsObject(
propLedgerId -> writeOptionalString(v.ledgerId),
propParticipantId -> writeOptionalString(v.participantId),
propApplicationId -> writeOptionalString(v.applicationId),
propAdmin -> JsBoolean(v.admin),
propExp -> writeOptionalInstant(v.exp),
propActAs -> writeStringList(v.actAs),
propReadAs -> writeStringList(v.readAs)
)

/** Writes the given payload to a compact JSON string */
def compactPrint(v: AuthServiceJWTPayload): String = writePayload(v).compactPrint

private[this] def writeOptionalString(value: Option[String]): JsValue =
value.fold[JsValue](JsNull)(JsString(_))

private[this] def writeStringList(value: List[String]): JsValue =
JsArray(value.map(JsString(_)): _*)

private[this] def writeOptionalInstant(value: Option[Instant]): JsValue =
value.fold[JsValue](JsNull)(i => JsNumber(i.getEpochSecond))

// ------------------------------------------------------------------------------------------------------------------
// Decoding
// ------------------------------------------------------------------------------------------------------------------
def readPayload(value: JsValue): AuthServiceJWTPayload = value match {
case JsObject(fields) =>
AuthServiceJWTPayload(
ledgerId = readOptionalString(propLedgerId, fields),
participantId = readOptionalString(propParticipantId, fields),
applicationId = readOptionalString(propApplicationId, fields),
exp = readInstant(propExp, fields),
admin = readOptionalBoolean(propAdmin, fields).getOrElse(false),
actAs = readOptionalStringList(propActAs, fields),
readAs = readOptionalStringList(propReadAs, fields)
)
case _ =>
deserializationError(s"Can't read ${value.prettyPrint} as AuthServiceJWTPayload")
}

private[this] def readOptionalString(name: String, fields: Map[String, JsValue]): Option[String] =
fields.get(name) match {
case None => None
case Some(JsNull) => None
case Some(JsString(value)) => Some(value)
case Some(value) =>
deserializationError(s"Can't read ${value.prettyPrint} as string for $name")
}

private[this] def readOptionalStringList(
name: String,
fields: Map[String, JsValue]): List[String] = fields.get(name) match {
case None => List.empty
case Some(JsNull) => List.empty
case Some(JsArray(values)) =>
values.toList.map {
case JsString(value) => value
case value =>
deserializationError(s"Can't read ${value.prettyPrint} as string element for $name")
}
case Some(value) =>
deserializationError(s"Can't read ${value.prettyPrint} as string list for $name")
}

private[this] def readOptionalBoolean(
name: String,
fields: Map[String, JsValue]): Option[Boolean] = fields.get(name) match {
case None => None
case Some(JsNull) => None
case Some(JsBoolean(value)) => Some(value)
case Some(value) =>
deserializationError(s"Can't read ${value.prettyPrint} as boolean for $name")
}

private[this] def readInstant(name: String, fields: Map[String, JsValue]): Option[Instant] =
fields.get(name) match {
case None => None
case Some(JsNull) => None
case Some(JsNumber(epochSeconds)) => Some(Instant.ofEpochSecond(epochSeconds.longValue()))
case Some(value) =>
deserializationError(s"Can't read ${value.prettyPrint} as epoch seconds for $name")
}

// ------------------------------------------------------------------------------------------------------------------
// Implicits that can be imported to write JSON
// ------------------------------------------------------------------------------------------------------------------
object JsonImplicits extends DefaultJsonProtocol {
implicit object AuthServiceJWTPayloadFormat extends RootJsonFormat[AuthServiceJWTPayload] {
override def write(v: AuthServiceJWTPayload): JsValue = writePayload(v)
override def read(json: JsValue): AuthServiceJWTPayload = readPayload(json)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ final class Authorizer(now: () => Instant) {
requireClaimsForAllPartiesOnStream(party.toList, call)

/** Checks whether the current Claims authorize to act as the given party, if any.
* Note: An missing party does NOT result in an authorization error.
* Note: A missing party does NOT result in an authorization error.
*/
def requireClaimsForParty[Req, Res](
party: Option[String],
Expand Down
Loading

0 comments on commit 89d6c73

Please sign in to comment.