Skip to content

Commit

Permalink
Add extension methods in Kotlin for request scoping (#5171)
Browse files Browse the repository at this point in the history
Motivation:
- `ArmeriaRequestCoroutineContext` for Kotlin is an internal API that users shouldn't use. We need to provide a public API for that.
- It would be nice if we could add `ArmeriaRequestCoroutineContext(ctx)` automatically for `ContextAwareExecutor` dispatcher so that the context is automatically propagated even when threads are changed.

Modifications:
- Add `RequestContext.asCoroutineContext()` extension method that replaces the usage of `ArmeriaRequestCoroutineContext(ctx)`
- Add `ContextAwareExecutor.asCoroutineDispatcher()` extension method that returns the dispatcher with `RequestContext.asCoroutineContext()`.

Result:
- You can now easily propagate an Armeria `RequestContext` in kotlin using `ContextAwareExecutor.asCoroutineDispatcher()`.
  • Loading branch information
minwoox authored Sep 25, 2023
1 parent 33f1b57 commit d8be51b
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 23 deletions.
1 change: 1 addition & 0 deletions examples/context-propagation/kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ plugins {

dependencies {
implementation(project(":core"))
implementation(project(":kotlin"))
runtimeOnly(libs.slf4j.simple)

implementation(kotlin("stdlib-jdk8"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,23 @@ import com.linecorp.armeria.common.HttpRequest
import com.linecorp.armeria.common.HttpResponse
import com.linecorp.armeria.server.HttpService
import com.linecorp.armeria.server.ServiceRequestContext
import java.time.Duration
import java.util.concurrent.CompletableFuture
import java.util.function.Supplier
import java.util.stream.Collectors
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.future.asDeferred
import kotlinx.coroutines.future.await
import kotlinx.coroutines.future.future
import java.time.Duration
import java.util.concurrent.CompletableFuture
import java.util.function.Supplier
import java.util.stream.Collectors
import kotlinx.coroutines.withContext

class MainService(private val backendClient: WebClient) : HttpService {
override fun serve(ctx: ServiceRequestContext, req: HttpRequest): HttpResponse {
val ctxExecutor = ctx.eventLoop()
val response = GlobalScope.future(ctxExecutor.asCoroutineDispatcher()) {
val response = GlobalScope.future(ctx.eventLoop().asCoroutineDispatcher()) {
val numsFromRequest = async { fetchFromRequest(ctx, req) }
val numsFromDb = async { fetchFromFakeDb(ctx) }
val nums = awaitAll(numsFromRequest, numsFromDb).flatten()
Expand Down Expand Up @@ -75,24 +76,29 @@ class MainService(private val backendClient: WebClient) : HttpService {
}

private suspend fun fetchFromRequest(ctx: ServiceRequestContext, req: HttpRequest): List<Long> {
// The context is mounted in a thread-local, meaning it is available to all logic such as tracing.
require(ServiceRequestContext.current() === ctx)
require(ctx.eventLoop().inEventLoop())
// Switch to the default dispatcher.
val nums = withContext(Dispatchers.Default) {
// The thread is switched.
require(!ctx.eventLoop().inEventLoop())
// The context is still mounted in a thread-local.
require(ServiceRequestContext.current() === ctx)

val aggregatedHttpRequest = req.aggregate().await()
val aggregatedHttpRequest = req.aggregate().await()

// The context is kept after resume.
require(ServiceRequestContext.current() === ctx)
require(ctx.eventLoop().inEventLoop())
// The context is kept after resume.
require(ServiceRequestContext.current() === ctx)
require(!ctx.eventLoop().inEventLoop())

val nums = mutableListOf<Long>()
for (
token in Iterables.concat(
NUM_SPLITTER.split(aggregatedHttpRequest.path().substring(1)),
NUM_SPLITTER.split(aggregatedHttpRequest.contentUtf8())
)
) {
nums.add(token.toLong())
val nums = mutableListOf<Long>()
for (
token in Iterables.concat(
NUM_SPLITTER.split(aggregatedHttpRequest.path().substring(1)),
NUM_SPLITTER.split(aggregatedHttpRequest.contentUtf8())
)
) {
nums.add(token.toLong())
}
nums
}
return nums
}
Expand Down
1 change: 1 addition & 0 deletions kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dependencies {
implementation(libs.kotlin.coroutines.jdk8)
implementation(libs.kotlin.reflect)

testImplementation(libs.kotlin.coroutines.test)
testImplementation(libs.reactivestreams.tck)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright 2023 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.common.kotlin

import com.linecorp.armeria.common.ContextAwareExecutor
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher

/**
* Converts an instance of [ContextAwareExecutor] to an implementation of [CoroutineDispatcher].
* The returned [CoroutineContext] also contains an [ArmeriaRequestCoroutineContext] that automatically
* propagates the [ContextAwareExecutor.context] when the coroutine is resumed on a thread.
*/
fun ContextAwareExecutor.asCoroutineDispatcher(): CoroutineContext {
return this.withoutContext().asCoroutineDispatcher() + context().asCoroutineContext()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2023 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.common.kotlin

import com.linecorp.armeria.common.RequestContext
import com.linecorp.armeria.common.util.SafeCloseable
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.ThreadContextElement

/**
* Converts an instance of [RequestContext] to an implementation of [CoroutineContext] that automatically
* propagates the [RequestContext]. The propagation is done by [RequestContext.push] when the coroutine is
* resumed on a thread.
*/
fun RequestContext.asCoroutineContext(): ArmeriaRequestCoroutineContext {
return ArmeriaRequestCoroutineContext(this)
}

/**
* Propagates [RequestContext] over coroutines.
*/
class ArmeriaRequestCoroutineContext internal constructor(
private val requestContext: RequestContext
) : ThreadContextElement<SafeCloseable>, AbstractCoroutineContextElement(Key) {

companion object Key : CoroutineContext.Key<ArmeriaRequestCoroutineContext>

override fun updateThreadContext(context: CoroutineContext): SafeCloseable {
return requestContext.push()
}

override fun restoreThreadContext(context: CoroutineContext, oldState: SafeCloseable) {
oldState.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

package com.linecorp.armeria.internal.common.kotlin

import com.linecorp.armeria.common.ContextAwareExecutor
import com.linecorp.armeria.common.kotlin.CoroutineContexts
import com.linecorp.armeria.common.kotlin.asCoroutineContext
import com.linecorp.armeria.common.kotlin.asCoroutineDispatcher
import com.linecorp.armeria.internal.common.stream.StreamMessageUtil
import com.linecorp.armeria.server.ServiceRequestContext
import io.netty.util.concurrent.EventExecutor
Expand Down Expand Up @@ -82,6 +85,8 @@ internal fun <T : Any> Flow<T>.asPublisher(

private fun newCoroutineCtx(executorService: ExecutorService, ctx: ServiceRequestContext): CoroutineContext {
val userContext = CoroutineContexts.get(ctx) ?: EmptyCoroutineContext
val requestContext = ArmeriaRequestCoroutineContext(ctx)
return executorService.asCoroutineDispatcher() + requestContext + userContext
if (executorService is ContextAwareExecutor) {
return (executorService as ContextAwareExecutor).asCoroutineDispatcher() + userContext
}
return executorService.asCoroutineDispatcher() + ctx.asCoroutineContext() + userContext
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import kotlin.coroutines.CoroutineContext
/**
* Propagates [ServiceRequestContext] over coroutines.
*/
@Deprecated("Use RequestContext.asCoroutineContext() instead.", ReplaceWith("RequestContext.asCoroutineContext()"))
class ArmeriaRequestCoroutineContext(
private val requestContext: ServiceRequestContext
) : ThreadContextElement<SafeCloseable>, AbstractCoroutineContextElement(Key) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2023 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License
*/

package com.linecorp.armeria.common.kotlin

import com.linecorp.armeria.client.ClientRequestContext
import com.linecorp.armeria.common.HttpMethod
import com.linecorp.armeria.common.HttpRequest
import com.linecorp.armeria.internal.testing.GenerateNativeImageTrace
import com.linecorp.armeria.server.ServiceRequestContext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test

@GenerateNativeImageTrace
class CoroutineContextAwareExecutorTest {

@Test
fun serviceRequestContext() {
val ctx = ServiceRequestContext.builder(HttpRequest.of(HttpMethod.GET, "/")).build()
runTest {
withContext(ctx.eventLoop().asCoroutineDispatcher()) {
assertThat(ServiceRequestContext.current()).isSameAs(ctx)
assertThat(ctx.eventLoop().inEventLoop()).isTrue()
withContext(Dispatchers.Default) {
assertThat(ServiceRequestContext.current()).isSameAs(ctx)
assertThat(ctx.eventLoop().inEventLoop()).isFalse()
}
}
}
}

@Test
fun clientRequestContext() {
val ctx = ClientRequestContext.builder(HttpRequest.of(HttpMethod.GET, "/")).build()
runTest {
withContext(ctx.eventLoop().asCoroutineDispatcher()) {
assertThat(ClientRequestContext.current()).isSameAs(ctx)
assertThat(ctx.eventLoop().inEventLoop()).isTrue()
withContext(Dispatchers.Default) {
assertThat(ClientRequestContext.current()).isSameAs(ctx)
assertThat(ctx.eventLoop().inEventLoop()).isFalse()
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright 2023 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License
*/

package com.linecorp.armeria.common.kotlin

import com.linecorp.armeria.client.ClientRequestContext
import com.linecorp.armeria.common.HttpMethod
import com.linecorp.armeria.common.HttpRequest
import com.linecorp.armeria.internal.testing.GenerateNativeImageTrace
import com.linecorp.armeria.server.ServiceRequestContext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test

@GenerateNativeImageTrace
class CoroutineRequestContextTest {

@Test
fun serviceRequestContext() {
val ctx = ServiceRequestContext.builder(HttpRequest.of(HttpMethod.GET, "/")).build()
runTest(ctx.asCoroutineContext()) {
assertThat(ServiceRequestContext.current()).isSameAs(ctx)
withContext(Dispatchers.Default) {
assertThat(ServiceRequestContext.current()).isSameAs(ctx)
}
}
}

@Test
fun clientRequestContext() {
val ctx = ClientRequestContext.builder(HttpRequest.of(HttpMethod.GET, "/")).build()
runTest(ctx.asCoroutineContext()) {
assertThat(ClientRequestContext.current()).isSameAs(ctx)
withContext(Dispatchers.Default) {
assertThat(ClientRequestContext.current()).isSameAs(ctx)
}
}
}
}

0 comments on commit d8be51b

Please sign in to comment.