Skip to content

Commit

Permalink
Make CommandTracker distinguish submissions of the same command usi…
Browse files Browse the repository at this point in the history
…ng `submissionId` [KVL-1104]

CHANGELOG_BEGIN
CHANGELOG_END
  • Loading branch information
hubert-da committed Sep 13, 2021
1 parent b50bb8e commit 832a1e8
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
package com.daml.ledger.client.services.commands

import java.time.Duration

import akka.NotUsed
import akka.stream.scaladsl.{Concat, Flow, GraphDSL, Merge, Source}
import akka.stream.{DelayOverflowStrategy, FlowShape, OverflowStrategy}
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset
import com.daml.ledger.client.services.commands.tracker.CommandTracker
import com.daml.ledger.client.services.commands.tracker.{TrackedCommandKey, CommandTracker}
import com.daml.ledger.client.services.commands.tracker.CompletionResponse.{
CompletionFailure,
CompletionSuccess,
Expand All @@ -32,13 +31,13 @@ object CommandTrackerFlow {

final case class Materialized[SubmissionMat, Context](
submissionMat: SubmissionMat,
trackingMat: Future[immutable.Map[String, Context]],
trackingMat: Future[immutable.Map[TrackedCommandKey, Context]],
)

def apply[Context, SubmissionMat](
commandSubmissionFlow: Flow[
Ctx[(Context, String), CommandSubmission],
Ctx[(Context, String), Try[
Ctx[(Context, TrackedCommandKey), CommandSubmission],
Ctx[(Context, TrackedCommandKey), Try[
Empty
]],
SubmissionMat,
Expand All @@ -62,12 +61,13 @@ object CommandTrackerFlow {
implicit builder => (submissionFlow, tracker) =>
import GraphDSL.Implicits._

val wrapResult = builder.add(Flow[Ctx[(Context, String), Try[Empty]]].map(Left.apply))
val wrapResult =
builder.add(Flow[Ctx[(Context, TrackedCommandKey), Try[Empty]]].map(Left.apply))

val wrapCompletion = builder.add(Flow[CompletionStreamElement].map(Right.apply))

val merge = builder.add(
Merge[Either[Ctx[(Context, String), Try[Empty]], CompletionStreamElement]](
Merge[Either[Ctx[(Context, TrackedCommandKey), Try[Empty]], CompletionStreamElement]](
inputPorts = 2,
eagerComplete = false,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,30 @@ private[commands] class CommandTracker[Context](
timeoutDetectionPeriod: FiniteDuration,
) extends GraphStageWithMaterializedValue[
CommandTrackerShape[Context],
Future[Map[String, Context]],
Future[Map[TrackedCommandKey, Context]],
] {

private val logger = LoggerFactory.getLogger(this.getClass.getName)

val submitRequestIn: Inlet[Ctx[Context, CommandSubmission]] =
Inlet[Ctx[Context, CommandSubmission]]("submitRequestIn")
val submitRequestOut: Outlet[Ctx[(Context, String), CommandSubmission]] =
Outlet[Ctx[(Context, String), CommandSubmission]]("submitRequestOut")
val commandResultIn: Inlet[Either[Ctx[(Context, String), Try[Empty]], CompletionStreamElement]] =
Inlet[Either[Ctx[(Context, String), Try[Empty]], CompletionStreamElement]]("commandResultIn")
val submitRequestOut: Outlet[Ctx[(Context, TrackedCommandKey), CommandSubmission]] =
Outlet[Ctx[(Context, TrackedCommandKey), CommandSubmission]]("submitRequestOut")
val commandResultIn
: Inlet[Either[Ctx[(Context, TrackedCommandKey), Try[Empty]], CompletionStreamElement]] =
Inlet[Either[Ctx[(Context, TrackedCommandKey), Try[Empty]], CompletionStreamElement]](
"commandResultIn"
)
val resultOut: Outlet[Ctx[Context, Either[CompletionFailure, CompletionSuccess]]] =
Outlet[Ctx[Context, Either[CompletionFailure, CompletionSuccess]]]("resultOut")
val offsetOut: Outlet[LedgerOffset] =
Outlet[LedgerOffset]("offsetOut")

override def createLogicAndMaterializedValue(
inheritedAttributes: Attributes
): (GraphStageLogic, Future[Map[String, Context]]) = {
): (GraphStageLogic, Future[Map[TrackedCommandKey, Context]]) = {

val promise = Promise[immutable.Map[String, Context]]()
val promise = Promise[immutable.Map[TrackedCommandKey, Context]]()

val logic: TimerGraphStageLogic = new TimerGraphStageLogic(shape) {

Expand All @@ -96,7 +99,7 @@ private[commands] class CommandTracker[Context](
}
}

private val pendingCommands = new mutable.HashMap[String, TrackingData[Context]]()
private val pendingCommands = new mutable.HashMap[TrackedCommandKey, TrackingData[Context]]()

setHandler(
submitRequestOut,
Expand All @@ -116,11 +119,20 @@ private[commands] class CommandTracker[Context](
override def onPush(): Unit = {
val submitRequest = grab(submitRequestIn)
registerSubmission(submitRequest)
val commands = submitRequest.value.commands
val submissionId = commands.submissionId
val commandId = commands.commandId
logger.trace(
"Submitted command {}",
submitRequest.value.commands.commandId,
"Submitted command {} in submission {}",
commandId,
submissionId,
)
push(
submitRequestOut,
submitRequest.enrich((context, _) =>
context -> TrackedCommandKey(submissionId, commandId)
),
)
push(submitRequestOut, submitRequest.enrich(_ -> _.commands.commandId))
}

override def onUpstreamFinish(): Unit = {
Expand Down Expand Up @@ -197,33 +209,40 @@ private[commands] class CommandTracker[Context](

import CommandTracker.nonTerminalCodes

private def handleSubmitResponse(submitResponse: Ctx[(Context, String), Try[Empty]]) = {
val Ctx((_, commandId), value, _) = submitResponse
private def handleSubmitResponse(
submitResponse: Ctx[(Context, TrackedCommandKey), Try[Empty]]
) = {
val Ctx((_, commandKey), value, _) = submitResponse
value match {
case Failure(GrpcException(status @ GrpcStatus(code, _), metadata))
if !nonTerminalCodes(code) =>
getOutputForTerminalStatusCode(commandId, GrpcStatus.toProto(status, metadata))
getOutputForTerminalStatusCode(commandKey, GrpcStatus.toProto(status, metadata))
case Failure(throwable) =>
logger.warn(
s"Service responded with error for submitting command with context ${submitResponse.context}. Status of command is unknown. watching for completion...",
throwable,
)
None
case Success(_) =>
logger.trace("Received confirmation that command {} was accepted.", commandId)
logger.trace(
"Received confirmation that command {} from submission {} was accepted.",
commandKey.commandId,
commandKey.submissionId,
)
None
}
}

@nowarn("msg=deprecated")
private def registerSubmission(submission: Ctx[Context, CommandSubmission]): Unit = {
val commands = submission.value.commands
val submissionId = commands.submissionId
val commandId = commands.commandId
logger.trace("Begin tracking of command {}", commandId)
if (pendingCommands.contains(commandId)) {
logger.trace("Begin tracking of command {} for submission {}", commandId, submissionId)
if (pendingCommands.contains(TrackedCommandKey(submissionId, commandId))) {
// TODO return an error identical to the server side duplicate command error once that's defined.
throw new IllegalStateException(
s"A command with id $commandId is already being tracked. CommandIds submitted to the CommandTracker must be unique."
s"A command $commandId from a submission $submissionId is already being tracked. CommandIds submitted to the CommandTracker must be unique."
) with NoStackTrace
}
val commandTimeout = submission.value.timeout match {
Expand All @@ -249,19 +268,20 @@ private[commands] class CommandTracker[Context](
commandTimeout = Instant.now().plus(commandTimeout),
context = submission.context,
)
pendingCommands += commandId -> trackingData
pendingCommands += TrackedCommandKey(submissionId, commandId) -> trackingData
()
}

private def getOutputForTimeout(instant: Instant) = {
logger.trace("Checking timeouts at {}", instant)
pendingCommands.view
.flatMap { case (commandId, trackingData) =>
.flatMap { case (commandKey, trackingData) =>
if (trackingData.commandTimeout.isBefore(instant)) {
pendingCommands -= commandId
pendingCommands -= commandKey
logger.info(
s"Command {} (command timeout {}) timed out at checkpoint {}.",
commandId,
s"Command {} from submission {} (command timeout {}) timed out at checkpoint {}.",
commandKey.commandId,
commandKey.submissionId,
trackingData.commandTimeout,
instant,
)
Expand All @@ -279,33 +299,56 @@ private[commands] class CommandTracker[Context](
}

private def getOutputForCompletion(completion: Completion) = {
val (commandId, errorText) = {
val (commandKey, errorText) = {
completion.status match {
case Some(StatusProto(code, _, _, _)) if code == Status.Code.OK.value =>
completion.commandId -> "successful completion of command"
TrackedCommandKey(
completion.submissionId,
completion.commandId,
) -> "successful completion of command"
case _ =>
completion.commandId -> "failed completion of command"
TrackedCommandKey(
completion.submissionId,
completion.commandId,
) -> "failed completion of command"
}
}

logger.trace("Handling {} {}", errorText, completion.commandId: Any)
pendingCommands.remove(commandId).map { t =>
pendingCommands.remove(commandKey).map { t =>
Ctx(t.context, tracker.CompletionResponse(completion))
}
}

private def getOutputForTerminalStatusCode(
commandId: String,
commandKey: TrackedCommandKey,
status: StatusProto,
): Option[Ctx[Context, Either[CompletionFailure, CompletionSuccess]]] = {
logger.trace("Handling failure of command {}", commandId)
logger.trace(
"Handling failure of command {} from submission {}",
commandKey.commandId,
commandKey.submissionId,
)
pendingCommands
.remove(commandId)
.remove(commandKey)
.map { t =>
Ctx(t.context, tracker.CompletionResponse(Completion(commandId, Some(status))))
Ctx(
t.context,
tracker.CompletionResponse(
Completion(
commandKey.commandId,
Some(status),
submissionId = commandKey.submissionId,
)
),
)
}
.orElse {
logger.trace("Platform signaled failure for unknown command {}", commandId)
logger.trace(
"Platform signaled failure for unknown command {} from submission {}",
commandKey.commandId,
commandKey.submissionId,
)
None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import scala.util.Try

private[tracker] final case class CommandTrackerShape[Context](
submitRequestIn: Inlet[Ctx[Context, CommandSubmission]],
submitRequestOut: Outlet[Ctx[(Context, String), CommandSubmission]],
commandResultIn: Inlet[Either[Ctx[(Context, String), Try[Empty]], CompletionStreamElement]],
submitRequestOut: Outlet[Ctx[(Context, TrackedCommandKey), CommandSubmission]],
commandResultIn: Inlet[
Either[Ctx[(Context, TrackedCommandKey), Try[Empty]], CompletionStreamElement]
],
resultOut: Outlet[Ctx[Context, Either[CompletionFailure, CompletionSuccess]]],
offsetOut: Outlet[LedgerOffset],
) extends Shape {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.client.services.commands.tracker

case class TrackedCommandKey(submissionId: String, commandId: String)
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.daml.ledger.api.validation.CommandsValidator
import com.daml.ledger.client.LedgerClient
import com.daml.ledger.client.configuration.CommandClientConfiguration
import com.daml.ledger.client.services.commands.CommandTrackerFlow.Materialized
import com.daml.ledger.client.services.commands.tracker.TrackedCommandKey
import com.daml.ledger.client.services.commands.tracker.CompletionResponse.{
CompletionFailure,
CompletionSuccess,
Expand Down Expand Up @@ -149,7 +150,7 @@ private[daml] final class CommandClient(
.via(commandUpdaterFlow[Context](ledgerIdToUse))
.viaMat(
CommandTrackerFlow[Context, NotUsed](
commandSubmissionFlow = CommandSubmissionFlow[(Context, String)](
commandSubmissionFlow = CommandSubmissionFlow[(Context, TrackedCommandKey)](
submit(token),
config.maxParallelSubmissions,
),
Expand Down
Loading

0 comments on commit 832a1e8

Please sign in to comment.