From 01d86142fd33cdeec17e4c2100ad7810881f3b68 Mon Sep 17 00:00:00 2001 From: Oleksandr Karpovich Date: Fri, 17 Jun 2022 11:55:45 +0200 Subject: [PATCH] A prototype of ThreadContextElement on k/js and k/native (copy/paste from jvm sources) --- .../common/src/ThreadContextElement.common.kt | 77 +++++++++++ .../src/internal/ThreadContext.common.kt | 93 ++++++++++++- .../test/ThreadContextElementTest.common.kt | 127 ++++++++++++++++++ .../js/src/CoroutineContext.kt | 97 ++++++++++++- .../js/src/internal/ThreadContext.kt | 2 +- .../jvm/src/ThreadContextElement.kt | 70 ---------- .../jvm/src/internal/ThreadContext.kt | 89 ------------ .../native/src/CoroutineContext.kt | 94 ++++++++++++- .../native/src/internal/ThreadContext.kt | 2 +- 9 files changed, 483 insertions(+), 168 deletions(-) create mode 100644 kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt create mode 100644 kotlinx-coroutines-core/common/test/ThreadContextElementTest.common.kt diff --git a/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt b/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt new file mode 100644 index 0000000000..81f4017bff --- /dev/null +++ b/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt @@ -0,0 +1,77 @@ +/* + * Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlin.coroutines.* + +/** + * Defines elements in [CoroutineContext] that are installed into thread context + * every time the coroutine with this element in the context is resumed on a thread. + * + * Implementations of this interface define a type [S] of the thread-local state that they need to store on + * resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage. + * + * Example usage looks like this: + * + * ``` + * // Appends "name" of a coroutine to a current thread name when coroutine is executed + * class CoroutineName(val name: String) : ThreadContextElement { + * // declare companion object for a key of this element in coroutine context + * companion object Key : CoroutineContext.Key + * + * // provide the key of the corresponding context element + * override val key: CoroutineContext.Key + * get() = Key + * + * // this is invoked before coroutine is resumed on current thread + * override fun updateThreadContext(context: CoroutineContext): String { + * val previousName = Thread.currentThread().name + * Thread.currentThread().name = "$previousName # $name" + * return previousName + * } + * + * // this is invoked after coroutine has suspended on current thread + * override fun restoreThreadContext(context: CoroutineContext, oldState: String) { + * Thread.currentThread().name = oldState + * } + * } + * + * // Usage + * launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... } + * ``` + * + * Every time this coroutine is resumed on a thread, UI thread name is updated to + * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when + * this coroutine suspends. + * + * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function. + */ +public interface ThreadContextElement : CoroutineContext.Element { + /** + * Updates context of the current thread. + * This function is invoked before the coroutine in the specified [context] is resumed in the current thread + * when the context of the coroutine this element. + * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext]. + * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which + * context is updated in an undefined state and may crash an application. + * + * @param context the coroutine context. + */ + public fun updateThreadContext(context: CoroutineContext): S + + /** + * Restores context of the current thread. + * This function is invoked after the coroutine in the specified [context] is suspended in the current thread + * if [updateThreadContext] was previously invoked on resume of this coroutine. + * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should + * be restored in the thread-local state by this function. + * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which + * context is updated in an undefined state and may crash an application. + * + * @param context the coroutine context. + * @param oldState the value returned by the previous invocation of [updateThreadContext]. + */ + public fun restoreThreadContext(context: CoroutineContext, oldState: S) +} diff --git a/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt b/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt index 6d14e59230..20b4cb53e3 100644 --- a/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt +++ b/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt @@ -4,6 +4,97 @@ package kotlinx.coroutines.internal +import kotlinx.coroutines.* import kotlin.coroutines.* +import kotlin.jvm.* -internal expect fun threadContextElements(context: CoroutineContext): Any + +@JvmField +internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") + +// Used when there are >= 2 active elements in the context +@Suppress("UNCHECKED_CAST") +private class ThreadState(@JvmField val context: CoroutineContext, n: Int) { + private val values = arrayOfNulls(n) + private val elements = arrayOfNulls>(n) + private var i = 0 + + fun append(element: ThreadContextElement<*>, value: Any?) { + values[i] = value + elements[i++] = element as ThreadContextElement + } + + fun restore(context: CoroutineContext) { + for (i in elements.indices.reversed()) { + elements[i]!!.restoreThreadContext(context, values[i]) + } + } +} + +// Counts ThreadContextElements in the context +// Any? here is Int | ThreadContextElement (when count is one) +internal val countAll = + fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { + if (element is ThreadContextElement<*>) { + val inCount = countOrElement as? Int ?: 1 + return if (inCount == 0) element else inCount + 1 + } + return countOrElement + } + +// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one +private val findOne = + fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { + if (found != null) return found + return element as? ThreadContextElement<*> + } + +// Updates state for ThreadContextElements in the context using the given ThreadState +private val updateState = + fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { + if (element is ThreadContextElement<*>) { + state.append(element, element.updateThreadContext(state.context)) + } + return state + } + +// countOrElement is pre-cached in dispatched continuation +// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements +internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { + @Suppress("NAME_SHADOWING") + val countOrElement = countOrElement ?: threadContextElements(context) + @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") + return when { + countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements + // ^^^ identity comparison for speed, we know zero always has the same identity + countOrElement is Int -> { + // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values + context.fold(ThreadState(context, countOrElement), updateState) + } + else -> { + // fast path for one ThreadContextElement (no allocations, no additional context scan) + @Suppress("UNCHECKED_CAST") + val element = countOrElement as ThreadContextElement + element.updateThreadContext(context) + } + } +} + +internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { + when { + oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements + oldState is ThreadState -> { + // slow path with multiple stored ThreadContextElements + oldState.restore(context) + } + else -> { + // fast path for one ThreadContextElement, but need to find it + @Suppress("UNCHECKED_CAST") + val element = context.fold(null, findOne) as ThreadContextElement + element.restoreThreadContext(context, oldState) + } + } +} +//internal expect fun threadContextElements(context: CoroutineContext): Any + +internal fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!! diff --git a/kotlinx-coroutines-core/common/test/ThreadContextElementTest.common.kt b/kotlinx-coroutines-core/common/test/ThreadContextElementTest.common.kt new file mode 100644 index 0000000000..44e31041f6 --- /dev/null +++ b/kotlinx-coroutines-core/common/test/ThreadContextElementTest.common.kt @@ -0,0 +1,127 @@ +/* + * Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlin.coroutines.* +import kotlin.test.* + +class ThreadContextElementCommonTest : TestBase() { + + interface TestThreadContextElement : ThreadContextElement { + companion object Key : CoroutineContext.Key + } + + @Test + fun updatesAndRestores() = runTest { + expect(1) + var update = 0 + var restore = 0 + val threadContextElement = object : TestThreadContextElement { + override fun updateThreadContext(context: CoroutineContext): Int { + update++ + return 0 + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: Int) { + restore++ + } + + override val key: CoroutineContext.Key<*> + get() = TestThreadContextElement.Key + } + launch(Dispatchers.Unconfined + threadContextElement) { + assertEquals(1, update) + assertEquals(0, restore) + } + assertEquals(1, update) + assertEquals(1, restore) + finish(2) + } + + class TestThreadContextIntElement( + val update: () -> Int, + val restore: (Int) -> Unit + ) : TestThreadContextElement { + override val key: CoroutineContext.Key<*> + get() = TestThreadContextElement.Key + + override fun updateThreadContext(context: CoroutineContext): Int { + return update() + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: Int) { + restore(oldState) + } + } + + @Test + fun twoCoroutinesUpdateAndRestore() = runTest { + expect(1) + var state = 0 + + var updateA = 0 + var restoreA = 0 + var updateB = 0 + var restoreB = 0 + + val lock = Job() + println("Launch A") + val jobA = launch(Dispatchers.Unconfined + TestThreadContextIntElement( + update = { + updateA++ + state = 10; 0 + }, + restore = { + restoreA++ + state = it + } + )) { + println("A started") + assertEquals(1, updateA) + assertEquals(10, state) + println("A lock reached") + lock.join() + assertEquals(1, restoreA) + assertEquals(1, updateB) + assertEquals(1, restoreB) + assertEquals(2, updateA) + println("A resumed") + assertEquals(10, state) + println("A completes") + } + + println("Launch B") + launch(Dispatchers.Unconfined + TestThreadContextIntElement( + update = { + updateB++ + state = 20; 0 + }, + restore = { + restoreB++ + state = it + } + )) { + println("B started") + assertEquals(1, updateB) + assertEquals(20, state) + println("B lock complete") + lock.complete() + println("B wait join A") + jobA.join() + assertEquals(2, updateB) + assertEquals(1, restoreB) + assertEquals(2, updateA) + assertEquals(2, restoreA) + println("B resumed") + assertEquals(20, state) + println("B completes") + } + + println("All complete") + assertEquals(0, state) + + finish(2) + } +} diff --git a/kotlinx-coroutines-core/js/src/CoroutineContext.kt b/kotlinx-coroutines-core/js/src/CoroutineContext.kt index 8036c88a10..4fb304f059 100644 --- a/kotlinx-coroutines-core/js/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/js/src/CoroutineContext.kt @@ -47,8 +47,69 @@ public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineCo } // No debugging facilities on JS -internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() -internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() +internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T { + val oldValue = updateThreadContext(context, countOrElement) + try { + return block() + } finally { + restoreThreadContext(context, oldValue) + } +} + +internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { + val context = continuation.context + val oldValue = updateThreadContext(context, countOrElement) + val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) { + // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them + continuation.updateUndispatchedCompletion(context, oldValue) + } else { + null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context + } + try { + return block() + } finally { + if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) { + restoreThreadContext(context, oldValue) + } + } +} + +private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { + override val key: CoroutineContext.Key<*> + get() = this +} + +internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? { + // Find direct completion of this continuation + val completion: CoroutineStackFrame = when (this) { + is DispatchedCoroutine<*> -> return null + else -> callerFrame ?: return null // something else -- not supported + } + if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine! + return completion.undispatchedCompletion() // walk up the call stack with tail call +} + +internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { + if (this !is CoroutineStackFrame) return null + /* + * Fast-path to detect whether we have undispatched coroutine at all in our stack. + * + * Implementation note. + * If we ever find that stackwalking for thread-locals is way too slow, here is another idea: + * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance + * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker` + * from the context when creating dispatched coroutine in `withContext`. + * Another option is to "unmark it" instead of removing to save an allocation. + * Both options should work, but it requires more careful studying of the performance + * and, mostly, maintainability impact. + */ + val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null + if (!potentiallyHasUndispatchedCoroutine) return null + val completion = undispatchedCompletion() + completion?.saveThreadContext(context, oldValue) + return completion +} + internal actual fun Continuation<*>.toDebugString(): String = toString() internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on JS @@ -56,5 +117,35 @@ internal actual class UndispatchedCoroutine actual constructor( context: CoroutineContext, uCont: Continuation ) : ScopeCoroutine(context, uCont) { - override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont)) + + private var threadStateToRecover : Pair? = null + + init { + val values = updateThreadContext(context, null) + restoreThreadContext(context, values) + saveThreadContext(context, values) + } + + fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { + threadStateToRecover = context to oldValue + } + + fun clearThreadContext(): Boolean { + if (threadStateToRecover == null) return false + threadStateToRecover = null + return true + } + + override fun afterResume(state: Any?) { + threadStateToRecover?.let { (ctx, value) -> + restoreThreadContext(ctx, value) + threadStateToRecover = null + } + // resume undispatched -- update context but stay on the same dispatcher + val result = recoverResult(state, uCont) + withContinuationContext(uCont, null) { + uCont.resumeWith(result) + } + //uCont.resumeWith(recoverResult(state, uCont)) + } } diff --git a/kotlinx-coroutines-core/js/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/js/src/internal/ThreadContext.kt index 2370e42ff4..552bad3f4f 100644 --- a/kotlinx-coroutines-core/js/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/js/src/internal/ThreadContext.kt @@ -6,4 +6,4 @@ package kotlinx.coroutines.internal import kotlin.coroutines.* -internal actual fun threadContextElements(context: CoroutineContext): Any = 0 +//internal actual fun threadContextElements(context: CoroutineContext): Any = 0 diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt index d2b6b6b988..ed0cccc1d0 100644 --- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt +++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt @@ -7,76 +7,6 @@ package kotlinx.coroutines import kotlinx.coroutines.internal.* import kotlin.coroutines.* -/** - * Defines elements in [CoroutineContext] that are installed into thread context - * every time the coroutine with this element in the context is resumed on a thread. - * - * Implementations of this interface define a type [S] of the thread-local state that they need to store on - * resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage. - * - * Example usage looks like this: - * - * ``` - * // Appends "name" of a coroutine to a current thread name when coroutine is executed - * class CoroutineName(val name: String) : ThreadContextElement { - * // declare companion object for a key of this element in coroutine context - * companion object Key : CoroutineContext.Key - * - * // provide the key of the corresponding context element - * override val key: CoroutineContext.Key - * get() = Key - * - * // this is invoked before coroutine is resumed on current thread - * override fun updateThreadContext(context: CoroutineContext): String { - * val previousName = Thread.currentThread().name - * Thread.currentThread().name = "$previousName # $name" - * return previousName - * } - * - * // this is invoked after coroutine has suspended on current thread - * override fun restoreThreadContext(context: CoroutineContext, oldState: String) { - * Thread.currentThread().name = oldState - * } - * } - * - * // Usage - * launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... } - * ``` - * - * Every time this coroutine is resumed on a thread, UI thread name is updated to - * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when - * this coroutine suspends. - * - * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function. - */ -public interface ThreadContextElement : CoroutineContext.Element { - /** - * Updates context of the current thread. - * This function is invoked before the coroutine in the specified [context] is resumed in the current thread - * when the context of the coroutine this element. - * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext]. - * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which - * context is updated in an undefined state and may crash an application. - * - * @param context the coroutine context. - */ - public fun updateThreadContext(context: CoroutineContext): S - - /** - * Restores context of the current thread. - * This function is invoked after the coroutine in the specified [context] is suspended in the current thread - * if [updateThreadContext] was previously invoked on resume of this coroutine. - * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should - * be restored in the thread-local state by this function. - * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which - * context is updated in an undefined state and may crash an application. - * - * @param context the coroutine context. - * @param oldState the value returned by the previous invocation of [updateThreadContext]. - */ - public fun restoreThreadContext(context: CoroutineContext, oldState: S) -} - /** * A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it. * diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt index 8536cef65d..8911934c4d 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt @@ -7,95 +7,6 @@ package kotlinx.coroutines.internal import kotlinx.coroutines.* import kotlin.coroutines.* -@JvmField -internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") - -// Used when there are >= 2 active elements in the context -@Suppress("UNCHECKED_CAST") -private class ThreadState(@JvmField val context: CoroutineContext, n: Int) { - private val values = arrayOfNulls(n) - private val elements = arrayOfNulls>(n) - private var i = 0 - - fun append(element: ThreadContextElement<*>, value: Any?) { - values[i] = value - elements[i++] = element as ThreadContextElement - } - - fun restore(context: CoroutineContext) { - for (i in elements.indices.reversed()) { - elements[i]!!.restoreThreadContext(context, values[i]) - } - } -} - -// Counts ThreadContextElements in the context -// Any? here is Int | ThreadContextElement (when count is one) -private val countAll = - fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { - if (element is ThreadContextElement<*>) { - val inCount = countOrElement as? Int ?: 1 - return if (inCount == 0) element else inCount + 1 - } - return countOrElement - } - -// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one -private val findOne = - fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { - if (found != null) return found - return element as? ThreadContextElement<*> - } - -// Updates state for ThreadContextElements in the context using the given ThreadState -private val updateState = - fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { - if (element is ThreadContextElement<*>) { - state.append(element, element.updateThreadContext(state.context)) - } - return state - } - -internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!! - -// countOrElement is pre-cached in dispatched continuation -// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements -internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { - @Suppress("NAME_SHADOWING") - val countOrElement = countOrElement ?: threadContextElements(context) - @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") - return when { - countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements - // ^^^ identity comparison for speed, we know zero always has the same identity - countOrElement is Int -> { - // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values - context.fold(ThreadState(context, countOrElement), updateState) - } - else -> { - // fast path for one ThreadContextElement (no allocations, no additional context scan) - @Suppress("UNCHECKED_CAST") - val element = countOrElement as ThreadContextElement - element.updateThreadContext(context) - } - } -} - -internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { - when { - oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements - oldState is ThreadState -> { - // slow path with multiple stored ThreadContextElements - oldState.restore(context) - } - else -> { - // fast path for one ThreadContextElement, but need to find it - @Suppress("UNCHECKED_CAST") - val element = context.fold(null, findOne) as ThreadContextElement - element.restoreThreadContext(context, oldState) - } - } -} - // top-level data class for a nicer out-of-the-box toString representation and class name @PublishedApi internal data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key> diff --git a/kotlinx-coroutines-core/native/src/CoroutineContext.kt b/kotlinx-coroutines-core/native/src/CoroutineContext.kt index 6e2dac1a29..f66dd8b686 100644 --- a/kotlinx-coroutines-core/native/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/native/src/CoroutineContext.kt @@ -52,10 +52,70 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { return this + addedContext } +private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { + override val key: CoroutineContext.Key<*> + get() = this +} + +internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? { + // Find direct completion of this continuation + val completion: CoroutineStackFrame = when (this) { + is DispatchedCoroutine<*> -> return null + else -> callerFrame ?: return null // something else -- not supported + } + if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine! + return completion.undispatchedCompletion() // walk up the call stack with tail call +} + +internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { + if (this !is CoroutineStackFrame) return null + /* + * Fast-path to detect whether we have undispatched coroutine at all in our stack. + * + * Implementation note. + * If we ever find that stackwalking for thread-locals is way too slow, here is another idea: + * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance + * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker` + * from the context when creating dispatched coroutine in `withContext`. + * Another option is to "unmark it" instead of removing to save an allocation. + * Both options should work, but it requires more careful studying of the performance + * and, mostly, maintainability impact. + */ + val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null + if (!potentiallyHasUndispatchedCoroutine) return null + val completion = undispatchedCompletion() + completion?.saveThreadContext(context, oldValue) + return completion +} + // No debugging facilities on native -internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() -internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() +internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T { + val oldValue = updateThreadContext(context, countOrElement) + try { + return block() + } finally { + restoreThreadContext(context, oldValue) + } +} + +internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { + val context = continuation.context + val oldValue = updateThreadContext(context, countOrElement) + val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) { + // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them + continuation.updateUndispatchedCompletion(context, oldValue) + } else { + null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context + } + try { + return block() + } finally { + if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) { + restoreThreadContext(context, oldValue) + } + } +} internal actual fun Continuation<*>.toDebugString(): String = toString() internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on native @@ -63,5 +123,33 @@ internal actual class UndispatchedCoroutine actual constructor( context: CoroutineContext, uCont: Continuation ) : ScopeCoroutine(context, uCont) { - override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont)) + + private var threadStateToRecover : Pair? = null + init { + val values = updateThreadContext(context, null) + restoreThreadContext(context, values) + saveThreadContext(context, values) + } + + fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { + threadStateToRecover = context to oldValue + } + + fun clearThreadContext(): Boolean { + if (threadStateToRecover == null) return false + threadStateToRecover = null + return true + } + + override fun afterResume(state: Any?) { + threadStateToRecover?.let { (ctx, value) -> + restoreThreadContext(ctx, value) + threadStateToRecover = null + } + // resume undispatched -- update context but stay on the same dispatcher + val result = recoverResult(state, uCont) + withContinuationContext(uCont, null) { + uCont.resumeWith(result) + } + } } diff --git a/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt index 2370e42ff4..552bad3f4f 100644 --- a/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt @@ -6,4 +6,4 @@ package kotlinx.coroutines.internal import kotlin.coroutines.* -internal actual fun threadContextElements(context: CoroutineContext): Any = 0 +//internal actual fun threadContextElements(context: CoroutineContext): Any = 0