1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software distributed under the
11  * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
12  * KIND, either express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 package android.platform.uiautomator_helpers
16 
17 import android.os.SystemClock.sleep
18 import android.os.SystemClock.uptimeMillis
19 import android.os.Trace
20 import android.platform.uiautomator_helpers.TracingUtils.trace
21 import android.platform.uiautomator_helpers.WaitUtils.LoggerImpl.Companion.withEventualLogging
22 import android.util.Log
23 import androidx.test.uiautomator.StaleObjectException
24 import java.io.Closeable
25 import java.time.Duration
26 import java.time.Instant.now
27 
/** Outcome of polling a condition with [WaitUtils.waitToBecomeTrue]. */
sealed interface WaitResult {
    /** The condition returned true before the timeout elapsed. */
    data object WaitSuccess : WaitResult

    /** The condition never returned true before the timeout elapsed. */
    data object WaitFailure : WaitResult

    /** Evaluating the condition threw [thrown]; polling stopped at that point. */
    data class WaitThrown(val thrown: Throwable?) : WaitResult
}
33 
/**
 * Summary of a [WaitUtils.waitToBecomeTrue] call.
 *
 * [result] is how the wait ended; [iterations] is the polling iteration counter reached when the
 * wait ended.
 */
data class WaitReport(
    val result: WaitResult,
    val iterations: Int,
)
35 
/**
 * Collection of utilities to ensure that certain conditions are met.
 *
 * They are meant to make tests more understandable from perfetto traces, and less flaky.
 */
object WaitUtils {
    private val DEFAULT_DEADLINE = Duration.ofSeconds(10)
    private val POLLING_WAIT = Duration.ofMillis(100)
    private val DEFAULT_SETTLE_TIME = Duration.ofSeconds(3)
    private const val TAG = "WaitUtils"
    private const val VERBOSE = true

    /**
     * Ensures that [condition] succeeds within [timeout], or fails with [errorProvider] message.
     *
     * This also logs with atrace each iteration, and its entire execution. Those traces are then
     * visible in perfetto. Note that logs are output only after the end of the method, all
     * together.
     *
     * Example of usage:
     * ```
     * ensureThat("screen is on") { uiDevice.isScreenOn }
     * ```
     *
     * @param description human-readable description used in traces, logs and the default error.
     * @param timeout maximum time to wait for [condition] to become true.
     * @param errorProvider builds the failure message. When null, a default message mentioning
     *   [description] and [timeout] is used.
     * @param ignoreFailure when true, a timeout is logged as a warning instead of thrown.
     * @param ignoreException when true, an exception thrown by [condition] is logged as a warning
     *   instead of rethrown.
     * @param condition polled until it returns true, it throws, or [timeout] elapses.
     * @throws FailedEnsureException if [condition] never became true and [ignoreFailure] is false.
     * @throws RuntimeException wrapping the exception thrown by [condition], unless
     *   [ignoreException] is true.
     */
    @JvmStatic
    @JvmOverloads
    fun ensureThat(
        description: String? = null,
        timeout: Duration = DEFAULT_DEADLINE,
        errorProvider: (() -> String)? = null,
        ignoreFailure: Boolean = false,
        ignoreException: Boolean = false,
        condition: () -> Boolean,
    ) {
        // Separate name so the parameter is not shadowed.
        val errorMessage: () -> String =
            errorProvider
                ?: { "Error ensuring that \"$description\" within ${timeout.toMillis()}ms" }
        val report = waitToBecomeTrue(description, timeout, condition)
        when (val result = report.result) {
            WaitResult.WaitSuccess -> Unit
            WaitResult.WaitFailure ->
                if (ignoreFailure) {
                    Log.w(TAG, "Ignoring ensureThat failure: ${errorMessage()}")
                } else {
                    throw FailedEnsureException(errorMessage())
                }
            is WaitResult.WaitThrown ->
                if (ignoreException) {
                    // Keep the exception visible in logcat rather than swallowing it silently.
                    Log.w(
                        TAG,
                        "Ignoring exception thrown while ensuring that \"$description\"",
                        result.thrown,
                    )
                } else {
                    throw RuntimeException(
                        "[#${report.iterations}] iteration failed.",
                        result.thrown,
                    )
                }
        }
    }

    /**
     * Waits until [timeout] for [condition] to become true, and then returns a [WaitReport] with
     * the result.
     *
     * [condition] is always evaluated at least once, even for a zero or negative [timeout]. It is
     * evaluated once more when the deadline is reached, so a condition that turns true during the
     * last polling interval is not missed. An exception thrown by [condition] stops the wait and
     * is reported as [WaitResult.WaitThrown] rather than rethrown.
     *
     * This can be a useful replacement for [ensureThat] in situations where you want to wait for
     * the condition to become true, but want a chance to recover if it does not.
     *
     * @return a [WaitReport] whose `iterations` is the number of times [condition] was evaluated.
     */
    @JvmStatic
    @JvmOverloads
    fun waitToBecomeTrue(
        description: String? = null,
        timeout: Duration = DEFAULT_DEADLINE,
        condition: () -> Boolean,
    ): WaitReport {
        val traceName =
            if (description != null) {
                "Ensuring $description"
            } else {
                "ensure"
            }
        trace(traceName) {
            val deadline = uptimeMillis() + timeout.toMillis()
            Log.d(TAG, "Starting $traceName")
            withEventualLogging(logTimeDelta = true) {
                log(traceName)
                var i = 0
                while (true) {
                    i++
                    trace("iteration $i") {
                        try {
                            if (condition()) {
                                log("[#$i] Condition true")
                                return WaitReport(WaitResult.WaitSuccess, i)
                            }
                        } catch (t: Throwable) {
                            log("[#$i] Condition failing with exception")
                            return WaitReport(WaitResult.WaitThrown(t), i)
                        }
                        log("[#$i] Condition false, might retry.")
                    }
                    val remainingMs = deadline - uptimeMillis()
                    if (remainingMs <= 0) break
                    // Never sleep past the deadline: the last check then happens right at it.
                    sleep(minOf(POLLING_WAIT.toMillis(), remainingMs))
                }
                // `i` is the exact number of evaluations performed.
                log("[#$i] Condition has always been false. Failing.")
                return WaitReport(WaitResult.WaitFailure, i)
            }
        }
    }

    /**
     * Same as [waitForNullableValueToSettle], but assumes that [supplier] return value is non-null.
     *
     * @throws IllegalStateException with [errorProvider]'s message if the value does not settle
     *   within [timeout], or if it settles on null.
     */
    @JvmStatic
    @JvmOverloads
    fun <T> waitForValueToSettle(
        description: String? = null,
        minimumSettleTime: Duration = DEFAULT_SETTLE_TIME,
        timeout: Duration = DEFAULT_DEADLINE,
        errorProvider: () -> String =
            defaultWaitForSettleError(minimumSettleTime, description, timeout),
        supplier: () -> T,
    ): T {
        return waitForNullableValueToSettle(
            description,
            minimumSettleTime,
            timeout,
            errorProvider,
            supplier
        )
            ?: error(errorProvider())
    }

    /**
     * Waits for [supplier] to return the same value for at least [minimumSettleTime].
     *
     * If the value changes, the timer gets restarted. Fails when reaching [timeout]. The minimum
     * running time of this method is [minimumSettleTime], in case the value is stable since the
     * beginning. Time is measured with the monotonic uptime clock, so wall-clock changes during
     * the wait do not affect it.
     *
     * Fails if [supplier] throws an exception.
     *
     * Outputs atraces visible with perfetto: one "New value" section per distinct observed value.
     *
     * Example of usage:
     * ```
     * val screenOn = waitForValueToSettle("Screen on") { uiDevice.isScreenOn }
     * ```
     *
     * Note: Prefer using [waitForValueToSettle] when [supplier] doesn't return a null value.
     *
     * @return the settled value (possibly null).
     * @throws IllegalStateException with [errorProvider]'s message if the value doesn't settle.
     * @throws RuntimeException wrapping any exception thrown by [supplier].
     */
    @JvmStatic
    @JvmOverloads
    fun <T> waitForNullableValueToSettle(
        description: String? = null,
        minimumSettleTime: Duration = DEFAULT_SETTLE_TIME,
        timeout: Duration = DEFAULT_DEADLINE,
        errorProvider: () -> String =
            defaultWaitForSettleError(minimumSettleTime, description, timeout),
        supplier: () -> T?,
    ): T? {
        val prefix =
            if (description != null) {
                "Waiting for \"$description\" to settle"
            } else {
                "waitForValueToSettle"
            }
        val traceName =
            prefix +
                " (settleTime=${minimumSettleTime.toMillis()}ms, deadline=${timeout.toMillis()}ms)"
        trace(traceName) {
            Log.d(TAG, "Starting $traceName")
            withEventualLogging(logTimeDelta = true) {
                log(traceName)

                val startTime = uptimeMillis()
                val deadline = startTime + timeout.toMillis()
                val settleTimeMs = minimumSettleTime.toMillis()
                var settledSince = startTime
                var previousValue: T? = null
                // True once a value has been observed; a "New value" trace section is open
                // exactly while this is true.
                var previousValueSet = false
                try {
                    while (uptimeMillis() < deadline) {
                        val newValue =
                            try {
                                supplier()
                            } catch (t: Throwable) {
                                log("Supplier has thrown an exception")
                                throw RuntimeException(t)
                            }
                        val currentTime = uptimeMillis()
                        if (!previousValueSet || previousValue != newValue) {
                            log("value changed to $newValue")
                            settledSince = currentTime
                            if (previousValueSet) {
                                Trace.endSection()
                                previousValueSet = false
                            }
                            TracingUtils.beginSectionSafe("New value: $newValue")
                            previousValue = newValue
                            previousValueSet = true
                        } else if (currentTime - settledSince > settleTimeMs) {
                            log("Got settled value. Returning \"$previousValue\"")
                            return previousValue
                        }
                        sleep(POLLING_WAIT.toMillis())
                    }
                } finally {
                    // Single place that closes the open "New value" section on every exit path.
                    if (previousValueSet) {
                        Trace.endSection()
                    }
                }
                error(errorProvider())
            }
        }
    }

    /** Builds the default error message used when a value fails to settle. */
    private fun defaultWaitForSettleError(
        minimumSettleTime: Duration,
        description: String?,
        timeout: Duration
    ): () -> String {
        return {
            "Error getting settled (${minimumSettleTime.toMillis()}ms) " +
                "value for \"$description\" within ${timeout.toMillis()}ms."
        }
    }

    /**
     * Waits for [supplier] to return a value accepted by [checker] (by default: non-null) within
     * [timeout].
     *
     * Returns the last value produced by [supplier]. That value may not satisfy [checker] if the
     * timeout finished, and it is null if [supplier] was never called. An exception thrown by
     * [supplier] is rethrown wrapped in a [RuntimeException] (see [ensureThat]).
     */
    fun <T> waitForNullable(
        description: String,
        timeout: Duration = DEFAULT_DEADLINE,
        checker: (T?) -> Boolean = { it != null },
        supplier: () -> T?,
    ): T? {
        var result: T? = null

        ensureThat("Waiting for \"$description\"", timeout, ignoreFailure = true) {
            result = supplier()
            checker(result)
        }
        return result
    }

    /** Wraps [waitForNullable] using the default checker, and allowing kotlin supplier syntax. */
    fun <T> waitForNullable(
        description: String,
        timeout: Duration = DEFAULT_DEADLINE,
        supplier: () -> T?,
    ): T? = waitForNullable(description, timeout, checker = { it != null }, supplier)

    /**
     * Waits for [supplier] to return a not null and not empty list within [timeout].
     *
     * Returns the not-empty list as soon as it's received. Once the timeout is reached, returns
     * the last list received, or an empty list if that was null.
     */
    fun <T> waitForPossibleEmpty(
        description: String,
        timeout: Duration = DEFAULT_DEADLINE,
        supplier: () -> List<T>?
    ): List<T> =
        waitForNullable(description, timeout, { !it.isNullOrEmpty() }, supplier) ?: emptyList()

    /**
     * Waits for [supplier] to return a non-null value within [timeout].
     *
     * @throws IllegalStateException with the [errorProvider] message if [supplier] failed to
     *   produce a non-null value within [timeout].
     */
    fun <T> waitFor(
        description: String,
        timeout: Duration = DEFAULT_DEADLINE,
        errorProvider: () -> String = {
            "Didn't get a non-null value for \"$description\" within ${timeout.toMillis()}ms"
        },
        supplier: () -> T?
    ): T = waitForNullable(description, timeout, supplier) ?: error(errorProvider())

    /**
     * Runs [block], retrying it up to [times] times if it throws a [StaleObjectException].
     *
     * [block] is executed at most `times + 1` times. In the first [times] attempts a
     * [StaleObjectException] is logged and swallowed; in the final attempt every exception
     * propagates.
     *
     * This can be used to reduce flakiness in cases where waitForObj throws although the object
     * does seem to be present.
     */
    fun <T> retryIfStale(description: String, times: Int, block: () -> T): T {
        return trace("retryIfStale: $description") outerTrace@{
            repeat(times) {
                trace("attempt #$it") {
                    try {
                        return@outerTrace block()
                    } catch (e: StaleObjectException) {
                        Log.w(TAG, "Caught a StaleObjectException ($e). Retrying.")
                    }
                }
            }
            // Run the block once without catching.
            trace("final attempt") { block() }
        }
    }

    /** Generic logging interface. */
    private interface Logger {
        fun log(s: String)
    }

    /** Buffers messages and prints them all as a single debug log entry when closed. */
    private class LoggerImpl private constructor(private val logTimeDelta: Boolean) :
        Closeable, Logger {
        private val logs = mutableListOf<String>()
        private val startTime = uptimeMillis()

        companion object {
            /**
             * Executes [block] and prints all logs at the end. When [logTimeDelta] is true, each
             * message is prefixed with the milliseconds elapsed since the logger was created.
             */
            inline fun <T> withEventualLogging(
                logTimeDelta: Boolean = false,
                block: Logger.() -> T
            ): T = LoggerImpl(logTimeDelta).use { it.block() }
        }

        override fun log(s: String) {
            logs += if (logTimeDelta) "+${uptimeMillis() - startTime}ms $s" else s
        }

        override fun close() {
            if (VERBOSE) {
                Log.d(TAG, logs.joinToString("\n"))
            }
        }
    }
}
369 
/**
 * Signals that [WaitUtils.ensureThat] gave up: its condition never became true within the
 * timeout, and failures were not ignored.
 *
 * @param message optional explanation of what was being ensured.
 */
class FailedEnsureException(
    message: String? = null,
) : IllegalStateException(message)
372