1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.app.tracing.coroutines
18 
19 import com.android.systemui.Flags.coroutineTracing
20 import kotlin.coroutines.CoroutineContext
21 import kotlin.coroutines.EmptyCoroutineContext
22 import kotlinx.coroutines.CopyableThreadContextElement
23 
private const val DEBUG = false

/**
 * Prints [message] to standard output, prefixed with the ID of the calling thread.
 *
 * Nothing is printed, and [message] is never evaluated, unless [DEBUG] is `true`.
 */
private inline fun debug(message: () -> String) {
    if (!DEBUG) return
    val threadId = Thread.currentThread().id
    println("Thread #$threadId: ${message()}")
}
30 
/**
 * [ThreadLocal] specialization that stores a nullable [TraceData].
 *
 * A dedicated final subclass is used, rather than `ThreadLocal<TraceData?>` directly, to avoid
 * virtual calls (b/316642146).
 */
@PublishedApi
internal class TraceDataThreadLocal : ThreadLocal<TraceData?>()
33 
/**
 * Per-thread slot holding the [TraceData] of the coroutine currently running on that thread. It
 * is only meaningful when paired with a [TraceContextElement].
 *
 * The stored value is `null` when either:
 * 1. the current thread is not running a coroutine, or
 * 2. the current coroutine context does not contain a [TraceContextElement].
 *
 * In both cases, writing to a non-null value would be undefined behavior, which is why the
 * default value is `null` rather than an empty [TraceData].
 *
 * @see traceCoroutine
 */
@PublishedApi
internal val traceThreadLocal: TraceDataThreadLocal = TraceDataThreadLocal()
46 
/**
 * Creates the [CoroutineContext] to add to a coroutine in order to trace it, hiding the internal
 * implementation from callers.
 *
 * @return a new [TraceContextElement] wrapping a fresh [TraceData] when the `coroutineTracing`
 *   flag is enabled; otherwise [EmptyCoroutineContext].
 */
fun createCoroutineTracingContext(): CoroutineContext {
    if (!coroutineTracing()) return EmptyCoroutineContext
    return TraceContextElement(TraceData())
}
53 
/**
 * Used for safely persisting [TraceData] state when coroutines are suspended and resumed.
 *
 * This is internal machinery for [traceCoroutine] and is not part of the public API; callers
 * typically obtain an instance through [createCoroutineTracingContext].
 *
 * @param traceData the trace state that is installed into [traceThreadLocal] while a coroutine
 *   carrying this element runs on a thread. May be `null`, in which case `null` is installed.
 * @see traceCoroutine
 */
internal class TraceContextElement(internal val traceData: TraceData? = TraceData()) :
    CopyableThreadContextElement<TraceData?> {

    internal companion object Key : CoroutineContext.Key<TraceContextElement>

    /** All instances share the single [Key], so a context holds at most one of these elements. */
    override val key: CoroutineContext.Key<*>
        get() = Key

    init {
        debug { "$this #init" }
    }

    /**
     * This function is invoked before the coroutine is resumed on the current thread. When a
     * multi-threaded dispatcher is used, calls to `updateThreadContext` may happen in parallel to
     * the prior `restoreThreadContext` in the same context. However, calls to `updateThreadContext`
     * will not run in parallel on the same context.
     *
     * ```
     * Thread #1 | [updateThreadContext]....^              [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |                           [updateThreadContext]...........^[restoreThreadContext]
     * ```
     *
     * (`...` indicate coroutine body is running; whitespace indicates the thread is not scheduled;
     * `^` is a suspension point)
     *
     * If the thread-local already holds this element's [traceData], nothing is changed.
     *
     * @return the value [traceThreadLocal] held on this thread before the call; it is handed back
     *   to [restoreThreadContext] as `oldState`.
     */
    override fun updateThreadContext(context: CoroutineContext): TraceData? {
        val oldState = traceThreadLocal.get()
        debug { "$this #updateThreadContext oldState=$oldState" }
        if (oldState !== traceData) {
            traceThreadLocal.set(traceData)
            // Calls to `updateThreadContext` will not happen in parallel on the same context, and
            // they cannot happen before the prior suspension point. Additionally,
            // `restoreThreadContext` does not modify `traceData`, so it is safe to iterate over the
            // collection here:
            traceData?.beginAllOnThread()
        }
        return oldState
    }

    /**
     * This function is invoked after the coroutine has suspended on the current thread. When a
     * multi-threaded dispatcher is used, calls to `restoreThreadContext` may happen in parallel to
     * the subsequent `updateThreadContext` and `restoreThreadContext` operations. The coroutine
     * body itself will not run in parallel, but `TraceData` could be modified by a coroutine body
     * after the suspension point in parallel to `restoreThreadContext` associated with the
     * coroutine body _prior_ to the suspension point.
     *
     * ```
     * Thread #1 | [updateThreadContext].x..^              [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |                           [updateThreadContext]..x..x.....^[restoreThreadContext]
     * ```
     *
     * OR
     *
     * ```
     * Thread #1 |                                 [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |     [updateThreadContext]...x....x..^[restoreThreadContext]
     * ```
     *
     * (`...` indicate coroutine body is running; whitespace indicates the thread is not scheduled;
     * `^` is a suspension point; `x` are calls to modify the thread-local trace data)
     *
     * If the thread-local already holds [oldState], nothing is changed.
     *
     * @param oldState the value previously returned by [updateThreadContext] on this thread.
     */
    override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
        debug { "$this #restoreThreadContext restoring=$oldState" }
        // `traceData` may have been modified on another thread after the last suspension point
        // (see the diagrams above), so ending trace sections here must restore the thread to its
        // state prior to the last call to [updateThreadContext].
        // NOTE(review): this comment used to say `TraceData` is not used here and referred to a
        // `TraceStateHolder` for that purpose; none is visible in this file and the code below
        // does call `traceData?.endAllOnThread()` - confirm it ends the correct number of
        // trace sections under concurrent modification.
        if (oldState !== traceThreadLocal.get()) {
            traceData?.endAllOnThread()
            traceThreadLocal.set(oldState)
        }
    }

    /** Returns a new element for a child coroutine, wrapping `traceData?.clone()`. */
    override fun copyForChild(): CopyableThreadContextElement<TraceData?> {
        debug { "$this #copyForChild" }
        return TraceContextElement(traceData?.clone())
    }

    /**
     * Returns a new element wrapping `traceData?.clone()`; [overwritingElement] is not consulted.
     */
    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        debug { "$this #mergeForChild" }
        // For our use-case, we always give precedence to the parent trace context, and the
        // child context (overwritingElement) is ignored
        return TraceContextElement(traceData?.clone())
    }

    /** Debug representation: the hex-formatted hash code followed by the wrapped [traceData]. */
    override fun toString(): String {
        return "TraceContextElement@${hashCode().toHexString()}[$traceData]"
    }
}
158