Skip to content

Commit

Permalink
Customizable JWT audiences (digital-asset#16330)
Browse files Browse the repository at this point in the history
  • Loading branch information
skisel-da authored Feb 27, 2023
1 parent 87d79f3 commit b33d635
Show file tree
Hide file tree
Showing 47 changed files with 699 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ trait JsonApiFixture
ApiServerConfig.copy(
timeProviderType = TimeProviderType.WallClock
),
authentication = UnsafeJwtHmac256(secret),
authentication = UnsafeJwtHmac256(secret, None),
),
)
def httpPort: Int = suiteResource.value._3.localAddress.getPort
Expand Down Expand Up @@ -151,6 +151,7 @@ trait JsonApiFixture
participantId = None,
exp = None,
format = StandardJWTTokenFormat.Scope,
audiences = List.empty,
)
val header = """{"alg": "HS256", "typ": "JWT"}"""
val jwt = DecodedJwt[String](header, AuthServiceJWTCodec.writeToString(payload))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ trait SandboxAuthParticipantFixture
ApiServerConfig.copy(
timeProviderType = TimeProviderType.WallClock
),
authentication = UnsafeJwtHmac256(secret),
authentication = UnsafeJwtHmac256(secret, None),
)
)
override def timeMode = ScriptTimeMode.WallClock
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ final class LedgerServices(val ledgerId: String) {
private val participantId = "LedgerServicesParticipant"
private val authorizer =
new Authorizer(
() => Clock.systemUTC().instant(),
ledgerId,
participantId,
new InMemoryUserManagementStore(),
executionContext,
now = () => Clock.systemUTC().instant(),
ledgerId = ledgerId,
participantId = participantId,
userManagementStore = new InMemoryUserManagementStore(),
ec = executionContext,
userRightsCheckIntervalInSeconds = 1,
akkaScheduler = akkaSystem.scheduler,
)(LoggingContext.ForTesting)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ message IdentityProviderConfig {
// Required
// Modifiable
string jwks_url = 4;

// Specifies the audience of the JWT token.
// When set, the callers using JWT tokens issued by this identity provider are allowed to get an access
// only if the "aud" claim includes the string specified here
// Optional,
// Modifiable
string audience = 5;
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ object JwtVerifierConfigurationCliSpec {
private def parseConfig(args: Array[String]): AuthService = {
val parser = new OptionParser[AtomicReference[AuthService]]("test") {}
JwtVerifierConfigurationCli.parse(parser) { (verifier, config) =>
config.set(AuthServiceJWT(verifier))
config.set(AuthServiceJWT(verifier, targetAudience = None))
config
}
parser.parse(args, new AtomicReference[AuthService](AuthServiceWildcard)).get.get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ import scala.util.Try
/** An AuthService that reads a JWT token from a `Authorization: Bearer` HTTP header.
* The token is expected to use the format as defined in [[AuthServiceJWTPayload]]:
*/
class AuthServiceJWT(verifier: JwtVerifierBase) extends AuthService {
class AuthServiceJWT(verifier: JwtVerifierBase, targetAudience: Option[String])
extends AuthService {

protected val logger: Logger = LoggerFactory.getLogger(AuthServiceJWT.getClass)

Expand All @@ -44,10 +45,37 @@ class AuthServiceJWT(verifier: JwtVerifierBase) extends AuthService {
)

private[this] def parsePayload(jwtPayload: String): Either[Error, AuthServiceJWTPayload] = {
val parsed =
if (targetAudience.isDefined)
Try(parseAudienceBasedPayload(jwtPayload))
else
Try(parseAuthServicePayload(jwtPayload))

parsed.toEither.left
.map(t => Error(Symbol("parsePayload"), "Could not parse JWT token: " + t.getMessage))
.flatMap(checkAudience)
}

private def checkAudience(payload: AuthServiceJWTPayload): Either[Error, AuthServiceJWTPayload] =
(payload, targetAudience) match {
case (payload: StandardJWTPayload, Some(audience)) if payload.audiences.contains(audience) =>
Right(payload)
case (payload, None) =>
Right(payload)
case _ =>
Left(Error(Symbol("checkAudience"), "Could not check the audience"))
}

private[this] def parseAuthServicePayload(jwtPayload: String): AuthServiceJWTPayload = {
import AuthServiceJWTCodec.JsonImplicits._
Try(JsonParser(jwtPayload).convertTo[AuthServiceJWTPayload]).toEither.left.map(t =>
Error(Symbol("parsePayload"), "Could not parse JWT token: " + t.getMessage)
)
JsonParser(jwtPayload).convertTo[AuthServiceJWTPayload]
}

private[this] def parseAudienceBasedPayload(
jwtPayload: String
): AuthServiceJWTPayload = {
import AuthServiceJWTCodec.AudienceBasedTokenJsonImplicits._
JsonParser(jwtPayload).convertTo[AuthServiceJWTPayload]
}

private[this] def parseJWTPayload(header: String): Either[Error, AuthServiceJWTPayload] =
Expand Down Expand Up @@ -98,9 +126,12 @@ class AuthServiceJWT(verifier: JwtVerifierBase) extends AuthService {
}

object AuthServiceJWT {
def apply(verifier: com.auth0.jwt.interfaces.JWTVerifier) =
new AuthServiceJWT(new JwtVerifier(verifier))

def apply(verifier: JwtVerifierBase) =
new AuthServiceJWT(verifier)
def apply(
verifier: com.auth0.jwt.interfaces.JWTVerifier,
targetAudience: Option[String],
): AuthServiceJWT =
new AuthServiceJWT(new JwtVerifier(verifier), targetAudience)

def apply(verifier: JwtVerifierBase, targetAudience: Option[String]): AuthServiceJWT =
new AuthServiceJWT(verifier, targetAudience)
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,16 @@ object StandardJWTTokenFormat {
* If set then the user is authenticated for the given participantId.
*
* @param exp If set, the token is only valid before the given instant.
* @param audiences If non-empty and it is an audience-based token,
* the token is only valid for the intended recipients.
*/
final case class StandardJWTPayload(
issuer: Option[String],
userId: String,
participantId: Option[String],
exp: Option[Instant],
format: StandardJWTTokenFormat,
audiences: List[String],
) extends AuthServiceJWTPayload

/** Codec for writing and reading [[AuthServiceJWTPayload]] to and from JSON.
Expand Down Expand Up @@ -122,6 +125,8 @@ object AuthServiceJWTCodec {
private[this] final val propActAs: String = "actAs"
private[this] final val propReadAs: String = "readAs"
private[this] final val propExp: String = "exp"
private[this] final val propSub: String = "sub"
private[this] final val propScope: String = "scope"
private[this] final val propParty: String = "party" // Legacy JSON API payload

// ------------------------------------------------------------------------------------------------------------------
Expand All @@ -147,21 +152,39 @@ object AuthServiceJWTCodec {
JsObject(
propIss -> writeOptionalString(v.issuer),
propAud -> writeOptionalString(v.participantId),
"sub" -> JsString(v.userId),
"exp" -> writeOptionalInstant(v.exp),
"scope" -> JsString(scopeLedgerApiFull),
propSub -> JsString(v.userId),
propExp -> writeOptionalInstant(v.exp),
propScope -> JsString(scopeLedgerApiFull),
)
case v: StandardJWTPayload =>
JsObject(
propIss -> writeOptionalString(v.issuer),
propAud -> JsString(audPrefix + v.participantId.getOrElse("")),
"sub" -> JsString(v.userId),
"exp" -> writeOptionalInstant(v.exp),
propSub -> JsString(v.userId),
propExp -> writeOptionalInstant(v.exp),
)
}

def writeAudienceBasedPayload: AuthServiceJWTPayload => JsValue = {
case v: StandardJWTPayload if v.format == StandardJWTTokenFormat.ParticipantId =>
JsObject(
propIss -> writeOptionalString(v.issuer),
propAud -> writeStringList(v.audiences),
propSub -> JsString(v.userId),
propExp -> writeOptionalInstant(v.exp),
)
case _: StandardJWTPayload =>
serializationError(
s"Could not write StandardJWTPayload of Scope format as audience-based payload"
)
case _: CustomDamlJWTPayload =>
serializationError(s"Could not write CustomDamlJWTPayload as audience-based payload")
}

/** Writes the given payload to a compact JSON string */
def compactPrint(v: AuthServiceJWTPayload): String = writePayload(v).compactPrint
def compactPrint(v: AuthServiceJWTPayload, audienceBasedToken: Boolean = false): String =
if (audienceBasedToken) writeAudienceBasedPayload(v).compactPrint
else writePayload(v).compactPrint

private[this] def writeOptionalString(value: Option[String]): JsValue =
value.fold[JsValue](JsNull)(JsString(_))
Expand All @@ -182,9 +205,25 @@ object AuthServiceJWTCodec {
} yield parsed
}

def readAudienceBasedToken(value: JsValue): AuthServiceJWTPayload = value match {
case JsObject(fields) =>
StandardJWTPayload(
issuer = readOptionalString(propIss, fields),
participantId = None,
userId = readString(propSub, fields),
exp = readInstant(propExp, fields),
format = StandardJWTTokenFormat.ParticipantId,
audiences = readOptionalStringOrArray(propAud, fields),
)
case _ =>
deserializationError(
s"Could not read ${value.prettyPrint} as AuthServiceJWTPayload: value is not an object"
)
}

def readPayload(value: JsValue): AuthServiceJWTPayload = value match {
case JsObject(fields) =>
val scope = fields.get("scope")
val scope = fields.get(propScope)
val scopes = scope.toList.collect({ case JsString(scope) => scope.split(" ") }).flatten
// We're using this rather restrictive test to ensure we continue parsing all legacy sandbox tokens that
// are in use before the 2.0 release; and thereby maintain full backwards compatibility.
Expand All @@ -199,11 +238,13 @@ object AuthServiceJWTCodec {
.filter(_.nonEmpty) match {
case participantId :: Nil =>
StandardJWTPayload(
issuer = readOptionalString("iss", fields),
issuer = readOptionalString(propIss, fields),
participantId = Some(participantId),
userId = readOptionalString("sub", fields).get, // guarded by if-clause above
exp = readInstant("exp", fields),
userId = readOptionalString(propSub, fields).get, // guarded by if-clause above
exp = readInstant(propExp, fields),
format = StandardJWTTokenFormat.ParticipantId,
audiences =
List.empty, // we do not read or extract audience claims for ParticipantId-based tokens
)
case Nil =>
deserializationError(
Expand All @@ -228,11 +269,12 @@ object AuthServiceJWTCodec {
)
}
StandardJWTPayload(
issuer = readOptionalString("iss", fields),
issuer = readOptionalString(propIss, fields),
participantId = participantId,
userId = readOptionalString("sub", fields).get, // guarded by if-clause above
exp = readInstant("exp", fields),
userId = readString(propSub, fields),
exp = readInstant(propExp, fields),
format = StandardJWTTokenFormat.Scope,
audiences = List.empty, // we do not read or extract audience claims for Scope-based tokens
)
} else {
if (scope.nonEmpty)
Expand Down Expand Up @@ -293,6 +335,15 @@ object AuthServiceJWTCodec {
deserializationError(s"Could not read ${value.prettyPrint} as string for $name")
}

private[this] def readString(name: String, fields: Map[String, JsValue]): String =
fields.get(name) match {
case Some(JsString(value)) => value
case Some(value) =>
deserializationError(s"Could not read ${value.prettyPrint} as string for $name")
case _ =>
deserializationError(s"Could not read value for $name")
}

private[this] def readOptionalStringOrArray(
name: String,
fields: Map[String, JsValue],
Expand Down Expand Up @@ -355,4 +406,11 @@ object AuthServiceJWTCodec {
override def read(json: JsValue): AuthServiceJWTPayload = readPayload(json)
}
}
object AudienceBasedTokenJsonImplicits extends DefaultJsonProtocol {
implicit object AuthServiceJWTPayloadFormat extends RootJsonFormat[AuthServiceJWTPayload] {
override def write(v: AuthServiceJWTPayload): JsValue = writeAudienceBasedPayload(v)

override def read(json: JsValue): AuthServiceJWTPayload = readAudienceBasedToken(json)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ object ClaimSet {
* @param expiration If set, the claims will cease to be valid at the given time.
* @param resolvedFromUser If set, then the claims were resolved from a user in the user management service.
* @param identityProviderId If set, the claims will only be valid on the given Identity Provider configuration.
* @param audience Claims which identifies the intended recipients.
*/
final case class Claims(
claims: Seq[Claim],
Expand All @@ -83,6 +84,7 @@ object ClaimSet {
identityProviderId: IdentityProviderId,
resolvedFromUser: Boolean,
) extends ClaimSet {

def validForLedger(id: String): Either[AuthorizationError, Unit] =
Either.cond(ledgerId.forall(_ == id), (), AuthorizationError.InvalidLedger(ledgerId.get, id))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,28 @@ class IdentityProviderAwareAuthServiceImpl(
keyId,
)
decodedJwt <- verifyToken(token, verifier)
payload <- Future(parse(decodedJwt.payload))
payload <- Future(
parse(decodedJwt.payload, targetAudience = identityProviderConfig.audience)
)
_ <- checkAudience(payload, identityProviderConfig.audience)
jwtPayload <- parsePayload(payload)
} yield toAuthenticatedUser(jwtPayload, identityProviderConfig.identityProviderId)
}
}

private def checkAudience(
payload: AuthServiceJWTPayload,
targetAudience: Option[String],
): Future[Unit] =
(payload, targetAudience) match {
case (payload: StandardJWTPayload, Some(audience)) if payload.audiences.contains(audience) =>
Future.unit
case (_, None) =>
Future.unit
case _ =>
Future.failed(new Exception(s"JWT token has an audience which is not recognized"))
}

private def verifyToken(token: String, verifier: JwtVerifier): Future[DecodedJwt[String]] =
toFuture(verifier.verify(com.daml.jwt.domain.Jwt(token)).toEither)

Expand All @@ -90,11 +106,24 @@ class IdentityProviderAwareAuthServiceImpl(
Future.successful(payload)
}

private def parse(jwtPayload: String): AuthServiceJWTPayload = {
private def parse(jwtPayload: String, targetAudience: Option[String]): AuthServiceJWTPayload =
if (targetAudience.isDefined)
parseAudienceBasedPayload(jwtPayload)
else
parseAuthServicePayload(jwtPayload)

private def parseAuthServicePayload(jwtPayload: String): AuthServiceJWTPayload = {
import AuthServiceJWTCodec.JsonImplicits._
JsonParser(jwtPayload).convertTo[AuthServiceJWTPayload]
}

private[this] def parseAudienceBasedPayload(
jwtPayload: String
): AuthServiceJWTPayload = {
import AuthServiceJWTCodec.AudienceBasedTokenJsonImplicits._
JsonParser(jwtPayload).convertTo[AuthServiceJWTPayload]
}

private def toAuthenticatedUser(payload: StandardJWTPayload, id: IdentityProviderId.Id) =
ClaimSet.AuthenticatedUser(
identityProviderId = id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ final class AuthorizationInterceptor(
Future.failed(
LedgerApiErrors.AuthorizationChecks.PermissionDenied
.Reject(
s"Could not resolve is_deactivated status for user '$userId' due to '$msg'"
s"Could not resolve is_deactivated status for user '$userId' and identity_provider_id '$identityProviderId' due to '$msg'"
)(errorLogger)
.asGrpcError
)
Expand Down
Loading

0 comments on commit b33d635

Please sign in to comment.