1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package android.car.test.mocks;
17 
18 import static com.android.dx.mockito.inline.extended.ExtendedMockito.doAnswer;
19 import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession;
20 
21 import static org.mockito.ArgumentMatchers.any;
22 import static org.mockito.ArgumentMatchers.anyInt;
23 import static org.mockito.ArgumentMatchers.anyString;
24 import static org.mockito.ArgumentMatchers.notNull;
25 import static org.mockito.Mockito.when;
26 
27 import static java.lang.annotation.ElementType.METHOD;
28 import static java.lang.annotation.RetentionPolicy.RUNTIME;
29 
30 import android.annotation.NonNull;
31 import android.annotation.Nullable;
32 import android.annotation.UserIdInt;
33 import android.app.ActivityManager;
34 import android.os.Handler;
35 import android.os.HandlerThread;
36 import android.os.Trace;
37 import android.os.UserManager;
38 import android.provider.Settings;
39 import android.util.Log;
40 import android.util.Slog;
41 import android.util.TimingsTraceLog;
42 
43 import com.android.dx.mockito.inline.extended.StaticMockitoSessionBuilder;
44 import com.android.internal.util.Preconditions;
45 
46 import org.junit.After;
47 import org.junit.Before;
48 import org.junit.Rule;
49 import org.junit.rules.TestRule;
50 import org.junit.runner.Description;
51 import org.junit.runners.model.Statement;
52 import org.mockito.MockitoSession;
53 import org.mockito.invocation.InvocationOnMock;
54 import org.mockito.quality.Strictness;
55 import org.mockito.session.MockitoSessionBuilder;
56 import org.mockito.stubbing.Answer;
57 
58 import java.lang.annotation.Retention;
59 import java.lang.annotation.Target;
60 import java.lang.reflect.Method;
61 import java.util.ArrayList;
62 import java.util.HashMap;
63 import java.util.List;
64 import java.util.Set;
65 
66 /**
67  * Base class for tests that must use {@link com.android.dx.mockito.inline.extended.ExtendedMockito}
68  * to mock static classes and final methods.
69  *
70  * <p><b>Note: </b> this class automatically spy on {@link Log} and {@link Slog} and fail tests that
71  * all any of their {@code wtf()} methods. If a test is expect to call {@code wtf()}, it should be
72  * annotated with {@link ExpectWtf}.
73  *
74  * <p><b>Note: </b>when using this class, you must include the following
75  * dependencies on {@code Android.bp} (or {@code Android.mk}:
76  * <pre><code>
77     jni_libs: [
78         "libdexmakerjvmtiagent",
79         "libstaticjvmtiagent",
80     ],
81 
82    LOCAL_JNI_SHARED_LIBRARIES := \
83       libdexmakerjvmtiagent \
84       libstaticjvmtiagent \
85  *  </code></pre>
86  */
87 public abstract class AbstractExtendedMockitoTestCase {
88 
89     private static final String TAG = AbstractExtendedMockitoTestCase.class.getSimpleName();
90 
91     private static final boolean TRACE = false;
92     private static final boolean VERBOSE = false;
93 
94     private final List<Class<?>> mStaticSpiedClasses = new ArrayList<>();
95 
96     // Tracks (S)Log.wtf() calls made during code execution, then used on verifyWtfNeverLogged()
97     private final List<RuntimeException> mWtfs = new ArrayList<>();
98 
99     private MockitoSession mSession;
100     private MockSettings mSettings;
101 
102     @Nullable
103     private final TimingsTraceLog mTracer;
104 
105     @Rule
106     public final WtfCheckerRule mWtfCheckerRule = new WtfCheckerRule();
107 
AbstractExtendedMockitoTestCase()108     protected AbstractExtendedMockitoTestCase() {
109         mTracer = TRACE ? new TimingsTraceLog(TAG, Trace.TRACE_TAG_APP) : null;
110     }
111 
112     @Before
startSession()113     public final void startSession() {
114         beginTrace("startSession()");
115 
116         beginTrace("startMocking()");
117         mSession = newSessionBuilder().startMocking();
118         endTrace();
119 
120         beginTrace("MockSettings()");
121         mSettings = new MockSettings();
122         endTrace();
123 
124         beginTrace("interceptWtfCalls()");
125         interceptWtfCalls();
126         endTrace();
127 
128         endTrace(); // startSession
129     }
130 
131     @After
finishSession()132     public final void finishSession() {
133         beginTrace("finishSession()");
134         completeAllHandlerThreadTasks();
135         if (mSession != null) {
136             beginTrace("finishMocking()");
137             mSession.finishMocking();
138             endTrace();
139         } else {
140             Log.w(TAG, getClass().getSimpleName() + ".finishSession(): no session");
141         }
142         endTrace();
143     }
144 
145     /**
146      * Waits for completion of all pending Handler tasks for all HandlerThread in the process.
147      *
148      * <p>This can prevent pending Handler tasks of one test from affecting another. This does not
149      * work if the message is posted with delay.
150      */
completeAllHandlerThreadTasks()151     protected void completeAllHandlerThreadTasks() {
152         beginTrace("completeAllHandlerThreadTasks");
153         Set<Thread> threadSet = Thread.getAllStackTraces().keySet();
154         ArrayList<HandlerThread> handlerThreads = new ArrayList<>(threadSet.size());
155         Thread currentThread = Thread.currentThread();
156         for (Thread t : threadSet) {
157             if (t != currentThread && t instanceof HandlerThread) {
158                 handlerThreads.add((HandlerThread) t);
159             }
160         }
161         ArrayList<SyncRunnable> syncs = new ArrayList<>(handlerThreads.size());
162         Log.i(TAG, "will wait for " + handlerThreads.size() + " HandlerThreads");
163         for (int i = 0; i < handlerThreads.size(); i++) {
164             Handler handler = new Handler(handlerThreads.get(i).getLooper());
165             SyncRunnable sr = new SyncRunnable(() -> { });
166             handler.post(sr);
167             syncs.add(sr);
168         }
169         beginTrace("waitForComplete");
170         for (int i = 0; i < syncs.size(); i++) {
171             syncs.get(i).waitForComplete();
172         }
173         endTrace(); // waitForComplete
174         endTrace(); // completeAllHandlerThreadTasks
175     }
176 
177     /**
178      * Adds key-value(int) pair in mocked Settings.Global and Settings.Secure
179      */
putSettingsInt(@onNull String key, int value)180     protected void putSettingsInt(@NonNull String key, int value) {
181         mSettings.insertObject(key, value);
182     }
183 
184     /**
185      * Gets value(int) from mocked Settings.Global and Settings.Secure
186      */
getSettingsInt(@onNull String key)187     protected int getSettingsInt(@NonNull String key) {
188         return mSettings.getInt(key);
189     }
190 
191     /**
192      * Adds key-value(String) pair in mocked Settings.Global and Settings.Secure
193      */
putSettingsString(@onNull String key, @NonNull String value)194     protected void putSettingsString(@NonNull String key, @NonNull String value) {
195         mSettings.insertObject(key, value);
196     }
197 
198     /**
199      * Gets value(String) from mocked Settings.Global and Settings.Secure
200      */
getSettingsString(@onNull String key)201     protected String getSettingsString(@NonNull String key) {
202         return mSettings.getString(key);
203     }
204 
205     /**
206      * Asserts that the giving settings was not set.
207      */
assertSettingsNotSet(String key)208     protected void assertSettingsNotSet(String key) {
209         mSettings.assertDoesNotContainsKey(key);
210     }
211 
212     /**
213      * Subclasses can use this method to initialize the Mockito session that's started before every
214      * test on {@link #startSession()}.
215      *
216      * <p>Typically, it should be overridden when mocking static methods.
217      */
onSessionBuilder(@onNull CustomMockitoSessionBuilder session)218     protected void onSessionBuilder(@NonNull CustomMockitoSessionBuilder session) {
219         if (VERBOSE) Log.v(TAG, getLogPrefix() + "onSessionBuilder()");
220     }
221 
222     /**
223      * Changes the value of the session created by
224      * {@link #onSessionBuilder(CustomMockitoSessionBuilder)}.
225      *
226      * <p>By default it's set to {@link Strictness.LENIENT}, but subclasses can overwrite this
227      * method to change the behavior.
228      */
229     @NonNull
getSessionStrictness()230     protected Strictness getSessionStrictness() {
231         return Strictness.LENIENT;
232     }
233 
234     /**
235      * Mocks a call to {@link ActivityManager#getCurrentUser()}.
236      *
237      * @param userId result of such call
238      *
239      * @throws IllegalStateException if class didn't override {@link #newSessionBuilder()} and
240      * called {@code spyStatic(ActivityManager.class)} on the session passed to it.
241      */
mockGetCurrentUser(@serIdInt int userId)242     protected final void mockGetCurrentUser(@UserIdInt int userId) {
243         if (VERBOSE) Log.v(TAG, getLogPrefix() + "mockGetCurrentUser(" + userId + ")");
244         assertSpied(ActivityManager.class);
245 
246         beginTrace("mockAmGetCurrentUser-" + userId);
247         AndroidMockitoHelper.mockAmGetCurrentUser(userId);
248         endTrace();
249     }
250 
251     /**
252      * Mocks a call to {@link UserManager#isHeadlessSystemUserMode()}.
253      *
254      * @param mode result of such call
255      *
256      * @throws IllegalStateException if class didn't override {@link #newSessionBuilder()} and
257      * called {@code spyStatic(UserManager.class)} on the session passed to it.
258      */
mockIsHeadlessSystemUserMode(boolean mode)259     protected final void mockIsHeadlessSystemUserMode(boolean mode) {
260         if (VERBOSE) Log.v(TAG, getLogPrefix() + "mockIsHeadlessSystemUserMode(" + mode + ")");
261         assertSpied(UserManager.class);
262 
263         beginTrace("mockUmIsHeadlessSystemUserMode");
264         AndroidMockitoHelper.mockUmIsHeadlessSystemUserMode(mode);
265         endTrace();
266     }
267 
268     /**
269      * Starts a tracing message.
270      *
271      * <p>MUST be followed by a {@link #endTrace()} calls.
272      *
273      * <p>Ignored if {@value #VERBOSE} is {@code false}.
274      */
beginTrace(@onNull String message)275     protected final void beginTrace(@NonNull String message) {
276         if (mTracer == null) return;
277 
278         Log.d(TAG, getLogPrefix() + message);
279         mTracer.traceBegin(message);
280     }
281 
282     /**
283      * Ends a tracing call.
284      *
285      * <p>MUST be called after {@link #beginTrace(String)}.
286      *
287      * <p>Ignored if {@value #VERBOSE} is {@code false}.
288      */
endTrace()289     protected final void endTrace() {
290         if (mTracer == null) return;
291 
292         mTracer.traceEnd();
293     }
294 
interceptWtfCalls()295     private void interceptWtfCalls() {
296         doAnswer((invocation) -> {
297             return addWtf(invocation);
298         }).when(() -> Log.wtf(anyString(), anyString()));
299         doAnswer((invocation) -> {
300             return addWtf(invocation);
301         }).when(() -> Log.wtf(anyString(), anyString(), notNull()));
302         doAnswer((invocation) -> {
303             return addWtf(invocation);
304         }).when(() -> Slog.wtf(anyString(), anyString()));
305         doAnswer((invocation) -> {
306             return addWtf(invocation);
307         }).when(() -> Slog.wtf(anyString(), anyString(), notNull()));
308     }
309 
addWtf(InvocationOnMock invocation)310     private Object addWtf(InvocationOnMock invocation) {
311         String message = "Called " + invocation;
312         Log.d(TAG, message); // Log always, as some test expect it
313         mWtfs.add(new IllegalStateException(message));
314         return null;
315     }
316 
verifyWtfLogged()317     private void verifyWtfLogged() {
318         Preconditions.checkState(!mWtfs.isEmpty(), "no wtf() called");
319     }
320 
verifyWtfNeverLogged()321     private void verifyWtfNeverLogged() {
322         int size = mWtfs.size();
323 
324         switch (size) {
325             case 0:
326                 return;
327             case 1:
328                 throw mWtfs.get(0);
329             default:
330                 StringBuilder msg = new StringBuilder("wtf called ").append(size).append(" times")
331                         .append(": ").append(mWtfs);
332                 throw new AssertionError(msg.toString());
333         }
334     }
335 
336     @NonNull
newSessionBuilder()337     private MockitoSessionBuilder newSessionBuilder() {
338         // TODO (b/155523104): change from mock to spy
339         StaticMockitoSessionBuilder builder = mockitoSession()
340                 .strictness(getSessionStrictness())
341                 .mockStatic(Settings.Global.class)
342                 .mockStatic(Settings.System.class)
343                 .mockStatic(Settings.Secure.class);
344 
345         CustomMockitoSessionBuilder customBuilder =
346                 new CustomMockitoSessionBuilder(builder, mStaticSpiedClasses)
347                     .spyStatic(Log.class)
348                     .spyStatic(Slog.class);
349 
350         onSessionBuilder(customBuilder);
351 
352         if (VERBOSE) Log.v(TAG, "spied classes" + customBuilder.mStaticSpiedClasses);
353 
354         return builder.initMocks(this);
355     }
356 
357     /**
358      * Gets a prefix for {@link Log} calls
359      */
getLogPrefix()360     protected String getLogPrefix() {
361         return getClass().getSimpleName() + ".";
362     }
363 
364     /**
365      * Asserts the given class is being spied in the Mockito session.
366      */
assertSpied(Class<?> clazz)367     protected void assertSpied(Class<?> clazz) {
368         Preconditions.checkArgument(mStaticSpiedClasses.contains(clazz),
369                 "did not call spyStatic() on %s", clazz.getName());
370     }
371 
372     /**
373      * Custom {@code MockitoSessionBuilder} used to make sure some pre-defined mock stations
374      * (like {@link AbstractExtendedMockitoTestCase#mockGetCurrentUser(int)} fail if the test case
375      * didn't explicitly set it to spy / mock the required classes.
376      *
377      * <p><b>NOTE: </b>for now it only provides simple {@link #spyStatic(Class)}, but more methods
378      * (as provided by {@link StaticMockitoSessionBuilder}) could be provided as needed.
379      */
380     public static final class CustomMockitoSessionBuilder {
381         private final StaticMockitoSessionBuilder mBuilder;
382         private final List<Class<?>> mStaticSpiedClasses;
383 
CustomMockitoSessionBuilder(StaticMockitoSessionBuilder builder, List<Class<?>> staticSpiedClasses)384         private CustomMockitoSessionBuilder(StaticMockitoSessionBuilder builder,
385                 List<Class<?>> staticSpiedClasses) {
386             mBuilder = builder;
387             mStaticSpiedClasses = staticSpiedClasses;
388         }
389 
390         /**
391          * Same as {@link StaticMockitoSessionBuilder#spyStatic(Class)}.
392          */
spyStatic(Class<T> clazz)393         public <T> CustomMockitoSessionBuilder spyStatic(Class<T> clazz) {
394             Preconditions.checkState(!mStaticSpiedClasses.contains(clazz),
395                     "already called spyStatic() on " + clazz);
396             mStaticSpiedClasses.add(clazz);
397             mBuilder.spyStatic(clazz);
398             return this;
399         }
400     }
401 
402     private final class WtfCheckerRule implements TestRule {
403 
404         @Override
apply(Statement base, Description description)405         public Statement apply(Statement base, Description description) {
406             return new Statement() {
407                 @Override
408                 public void evaluate() throws Throwable {
409                     String testName = description.getMethodName();
410                     if (VERBOSE) Log.v(TAG, "running " + testName);
411                     beginTrace("evaluate-" + testName);
412                     base.evaluate();
413                     endTrace();
414 
415                     Method testMethod = AbstractExtendedMockitoTestCase.this.getClass()
416                             .getMethod(testName);
417                     ExpectWtf expectWtfAnnotation = testMethod.getAnnotation(ExpectWtf.class);
418 
419                     beginTrace("verify-wtfs");
420                     try {
421                         if (expectWtfAnnotation != null) {
422                             if (VERBOSE) Log.v(TAG, "expecting wtf()");
423                             verifyWtfLogged();
424                         } else {
425                             if (VERBOSE) Log.v(TAG, "NOT expecting wtf()");
426                             verifyWtfNeverLogged();
427                         }
428                     } finally {
429                         endTrace();
430                     }
431                 }
432             };
433         }
434     }
435 
436     // TODO (b/155523104): Add log
437     // TODO (b/156033195): Clean settings API
438     private static final class MockSettings {
439         private static final int INVALID_DEFAULT_INDEX = -1;
440         private HashMap<String, Object> mSettingsMapping = new HashMap<>();
441 
442         MockSettings() {
443 
444             Answer<Object> insertObjectAnswer =
445                     invocation -> insertObjectFromInvocation(invocation, 1, 2);
446             Answer<Integer> getIntAnswer = invocation ->
447                     getAnswer(invocation, Integer.class, 1, 2);
448             Answer<String> getStringAnswer = invocation ->
449                     getAnswer(invocation, String.class, 1, INVALID_DEFAULT_INDEX);
450 
451             when(Settings.Global.putInt(any(), any(), anyInt())).thenAnswer(insertObjectAnswer);
452 
453             when(Settings.Global.getInt(any(), any(), anyInt())).thenAnswer(getIntAnswer);
454 
455             when(Settings.Secure.putIntForUser(any(), any(), anyInt(), anyInt()))
456                     .thenAnswer(insertObjectAnswer);
457 
458             when(Settings.Secure.getIntForUser(any(), any(), anyInt(), anyInt()))
459                     .thenAnswer(getIntAnswer);
460 
461             when(Settings.Secure.putStringForUser(any(), anyString(), anyString(), anyInt()))
462                     .thenAnswer(insertObjectAnswer);
463 
464             when(Settings.Global.putString(any(), any(), any()))
465                     .thenAnswer(insertObjectAnswer);
466 
467             when(Settings.Global.getString(any(), any())).thenAnswer(getStringAnswer);
468 
469             when(Settings.System.putIntForUser(any(), any(), anyInt(), anyInt()))
470                     .thenAnswer(insertObjectAnswer);
471 
472             when(Settings.System.getIntForUser(any(), any(), anyInt(), anyInt()))
473                     .thenAnswer(getIntAnswer);
474 
475             when(Settings.System.putStringForUser(any(), any(), anyString(), anyInt()))
476                     .thenAnswer(insertObjectAnswer);
477         }
478 
479         private Object insertObjectFromInvocation(InvocationOnMock invocation,
480                 int keyIndex, int valueIndex) {
481             String key = (String) invocation.getArguments()[keyIndex];
482             Object value = invocation.getArguments()[valueIndex];
483             insertObject(key, value);
484             return null;
485         }
486 
487         private void insertObject(String key, Object value) {
488             if (VERBOSE) Log.v(TAG, "Inserting Setting " + key + ": " + value);
489             mSettingsMapping.put(key, value);
490         }
491 
492         private <T> T getAnswer(InvocationOnMock invocation, Class<T> clazz,
493                 int keyIndex, int defaultValueIndex) {
494             String key = (String) invocation.getArguments()[keyIndex];
495             T defaultValue = null;
496             if (defaultValueIndex > INVALID_DEFAULT_INDEX) {
497                 defaultValue = safeCast(invocation.getArguments()[defaultValueIndex], clazz);
498             }
499             return get(key, defaultValue, clazz);
500         }
501 
502         @Nullable
503         private <T> T get(String key, T defaultValue, Class<T> clazz) {
504             if (VERBOSE) {
505                 Log.v(TAG, "get(): key=" + key + ", default=" + defaultValue + ", class=" + clazz);
506             }
507             Object value = mSettingsMapping.get(key);
508             if (value == null) {
509                 if (VERBOSE) Log.v(TAG, "not found");
510                 return defaultValue;
511             }
512 
513             if (VERBOSE) Log.v(TAG, "returning " + value);
514             return safeCast(value, clazz);
515         }
516 
517         private static <T> T safeCast(Object value, Class<T> clazz) {
518             if (value == null) {
519                 return null;
520             }
521             Preconditions.checkArgument(value.getClass() == clazz,
522                     "Setting value has class %s but requires class %s",
523                     value.getClass(), clazz);
524             return clazz.cast(value);
525         }
526 
527         private String getString(String key) {
528             return get(key, null, String.class);
529         }
530 
531         public int getInt(String key) {
532             return get(key, null, Integer.class);
533         }
534 
535         public void assertDoesNotContainsKey(String key) {
536             if (mSettingsMapping.containsKey(key)) {
537                 throw new AssertionError("Should not have key " + key + ", but has: "
538                         + mSettingsMapping.get(key));
539             }
540         }
541     }
542 
543     /**
544      * Annotation used on test methods that are expect to call {@code wtf()} methods on {@link Log}
545      * or {@link Slog} - if such methods are not annotated with this annotation, they will fail.
546      */
547     @Retention(RUNTIME)
548     @Target({METHOD})
549     public static @interface ExpectWtf {
550     }
551 
552     private static final class SyncRunnable implements Runnable {
553         private final Runnable mTarget;
554         private volatile boolean mComplete = false;
555 
556         private SyncRunnable(Runnable target) {
557             mTarget = target;
558         }
559 
560         @Override
561         public void run() {
562             mTarget.run();
563             synchronized (this) {
564                 mComplete = true;
565                 notifyAll();
566             }
567         }
568 
569         private void waitForComplete() {
570             synchronized (this) {
571                 while (!mComplete) {
572                     try {
573                         wait();
574                     } catch (InterruptedException e) {
575                     }
576                 }
577             }
578         }
579     }
580 }
581