1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
5  * except in compliance with the License. You may obtain a copy of the License at
6  *
7  *      http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software distributed under the
10  * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
11  * KIND, either express or implied. See the License for the specific language governing
12  * permissions and limitations under the License.
13  */
14 
15 package android.testing;
16 
17 import android.os.Handler;
18 import android.os.HandlerThread;
19 import android.os.Looper;
20 import android.os.Message;
21 import android.os.MessageQueue;
22 import android.os.TestLooperManager;
23 import android.support.test.InstrumentationRegistry;
24 import android.util.ArrayMap;
25 
26 import org.junit.runners.model.FrameworkMethod;
27 
28 import java.lang.annotation.ElementType;
29 import java.lang.annotation.Retention;
30 import java.lang.annotation.RetentionPolicy;
31 import java.lang.annotation.Target;
32 import java.lang.reflect.Field;
33 import java.util.Map;
34 
35 /**
36  * Creates a looper on the current thread with control over if/when messages are
37  * executed. Warning: This class works through some reflection and may break/need
38  * to be updated from time to time.
39  */
40 public class TestableLooper {
41 
42     private Looper mLooper;
43     private MessageQueue mQueue;
44     private boolean mMain;
45     private Object mOriginalMain;
46     private MessageHandler mMessageHandler;
47 
48     private Handler mHandler;
49     private Runnable mEmptyMessage;
50     private TestLooperManager mQueueWrapper;
51 
TestableLooper(Looper l)52     public TestableLooper(Looper l) throws Exception {
53         this(InstrumentationRegistry.getInstrumentation().acquireLooperManager(l), l);
54     }
55 
TestableLooper(TestLooperManager wrapper, Looper l)56     private TestableLooper(TestLooperManager wrapper, Looper l) throws Exception {
57         mQueueWrapper = wrapper;
58         setupQueue(l);
59     }
60 
TestableLooper(Looper looper, boolean b)61     private TestableLooper(Looper looper, boolean b) throws Exception {
62         setupQueue(looper);
63     }
64 
getLooper()65     public Looper getLooper() {
66         return mLooper;
67     }
68 
setupQueue(Looper l)69     private void setupQueue(Looper l) throws Exception {
70         mLooper = l;
71         mQueue = mLooper.getQueue();
72         mHandler = new Handler(mLooper);
73     }
74 
setAsMainLooper()75     public void setAsMainLooper() throws NoSuchFieldException, IllegalAccessException {
76         mMain = true;
77         setAsMainInt();
78     }
79 
setAsMainInt()80     private void setAsMainInt() throws NoSuchFieldException, IllegalAccessException {
81         Field field = mLooper.getClass().getDeclaredField("sMainLooper");
82         field.setAccessible(true);
83         if (mOriginalMain == null) {
84             mOriginalMain = field.get(null);
85         }
86         field.set(null, mLooper);
87     }
88 
89     /**
90      * Must be called if setAsMainLooper is called to restore the main looper when the
91      * test is complete, otherwise the main looper will not be available for any subsequent
92      * tests.
93      */
destroy()94     public void destroy() throws NoSuchFieldException, IllegalAccessException {
95         mQueueWrapper.release();
96         if (mMain && mOriginalMain != null) {
97             Field field = mLooper.getClass().getDeclaredField("sMainLooper");
98             field.setAccessible(true);
99             field.set(null, mOriginalMain);
100             mOriginalMain = null;
101         }
102     }
103 
setMessageHandler(MessageHandler handler)104     public void setMessageHandler(MessageHandler handler) {
105         mMessageHandler = handler;
106     }
107 
108     /**
109      * Parse num messages from the message queue.
110      *
111      * @param num Number of messages to parse
112      */
processMessages(int num)113     public int processMessages(int num) {
114         for (int i = 0; i < num; i++) {
115             if (!parseMessageInt()) {
116                 return i + 1;
117             }
118         }
119         return num;
120     }
121 
processAllMessages()122     public void processAllMessages() {
123         while (processQueuedMessages() != 0) ;
124     }
125 
processQueuedMessages()126     private int processQueuedMessages() {
127         int count = 0;
128         mEmptyMessage = () -> { };
129         mHandler.post(mEmptyMessage);
130         waitForMessage(mQueueWrapper, mHandler, mEmptyMessage);
131         while (parseMessageInt()) count++;
132         return count;
133     }
134 
parseMessageInt()135     private boolean parseMessageInt() {
136         try {
137             Message result = mQueueWrapper.next();
138             if (result != null) {
139                 // This is a break message.
140                 if (result.getCallback() == mEmptyMessage) {
141                     mQueueWrapper.recycle(result);
142                     return false;
143                 }
144 
145                 if (mMessageHandler != null) {
146                     if (mMessageHandler.onMessageHandled(result)) {
147                         result.getTarget().dispatchMessage(result);
148                         mQueueWrapper.recycle(result);
149                     } else {
150                         mQueueWrapper.recycle(result);
151                         // Message handler indicated it doesn't want us to continue.
152                         return false;
153                     }
154                 } else {
155                     result.getTarget().dispatchMessage(result);
156                     mQueueWrapper.recycle(result);
157                 }
158             } else {
159                 // No messages, don't continue parsing
160                 return false;
161             }
162         } catch (Exception e) {
163             throw new RuntimeException(e);
164         }
165         return true;
166     }
167 
168     /**
169      * Runs an executable with myLooper set and processes all messages added.
170      */
runWithLooper(RunnableWithException runnable)171     public void runWithLooper(RunnableWithException runnable) throws Exception {
172         new Handler(getLooper()).post(() -> {
173             try {
174                 runnable.run();
175             } catch (Exception e) {
176                 throw new RuntimeException(e);
177             }
178         });
179         processAllMessages();
180     }
181 
182     public interface RunnableWithException {
run()183         void run() throws Exception;
184     }
185 
186     @Retention(RetentionPolicy.RUNTIME)
187     @Target({ElementType.METHOD, ElementType.TYPE})
188     public @interface RunWithLooper {
setAsMainLooper()189         boolean setAsMainLooper() default false;
190     }
191 
waitForMessage(TestLooperManager queueWrapper, Handler handler, Runnable execute)192     private static void waitForMessage(TestLooperManager queueWrapper, Handler handler,
193             Runnable execute) {
194         for (int i = 0; i < 10; i++) {
195             if (!queueWrapper.hasMessages(handler, null, execute)) {
196                 try {
197                     Thread.sleep(1);
198                 } catch (InterruptedException e) {
199                 }
200             }
201         }
202         if (!queueWrapper.hasMessages(handler, null, execute)) {
203             throw new RuntimeException("Message didn't queue...");
204         }
205     }
206 
207     private static final Map<Object, TestableLooper> sLoopers = new ArrayMap<>();
208 
get(Object test)209     public static TestableLooper get(Object test) {
210         return sLoopers.get(test);
211     }
212 
213     public static class LooperFrameworkMethod extends FrameworkMethod {
214         private HandlerThread mHandlerThread;
215 
216         private final TestableLooper mTestableLooper;
217         private final Looper mLooper;
218         private final Handler mHandler;
219 
LooperFrameworkMethod(FrameworkMethod base, boolean setAsMain, Object test)220         public LooperFrameworkMethod(FrameworkMethod base, boolean setAsMain, Object test) {
221             super(base.getMethod());
222             try {
223                 mLooper = setAsMain ? Looper.getMainLooper() : createLooper();
224                 mTestableLooper = new TestableLooper(mLooper, false);
225             } catch (Exception e) {
226                 throw new RuntimeException(e);
227             }
228             sLoopers.put(test, mTestableLooper);
229             mHandler = new Handler(mLooper);
230         }
231 
LooperFrameworkMethod(TestableLooper other, FrameworkMethod base)232         public LooperFrameworkMethod(TestableLooper other, FrameworkMethod base) {
233             super(base.getMethod());
234             mLooper = other.mLooper;
235             mTestableLooper = other;
236             mHandler = new Handler(mLooper);
237         }
238 
get(FrameworkMethod base, boolean setAsMain, Object test)239         public static FrameworkMethod get(FrameworkMethod base, boolean setAsMain, Object test) {
240             if (sLoopers.containsKey(test)) {
241                 return new LooperFrameworkMethod(sLoopers.get(test), base);
242             }
243             return new LooperFrameworkMethod(base, setAsMain, test);
244         }
245 
246         @Override
invokeExplosively(Object target, Object... params)247         public Object invokeExplosively(Object target, Object... params) throws Throwable {
248             if (Looper.myLooper() == mLooper) {
249                 // Already on the right thread from another statement, just execute then.
250                 return super.invokeExplosively(target, params);
251             }
252             boolean set = mTestableLooper.mQueueWrapper == null;
253             if (set) {
254                 mTestableLooper.mQueueWrapper = InstrumentationRegistry.getInstrumentation()
255                         .acquireLooperManager(mLooper);
256             }
257             try {
258                 Object[] ret = new Object[1];
259                 // Run the execution on the looper thread.
260                 Runnable execute = () -> {
261                     try {
262                         ret[0] = super.invokeExplosively(target, params);
263                     } catch (Throwable throwable) {
264                         throw new LooperException(throwable);
265                     }
266                 };
267                 Message m = Message.obtain(mHandler, execute);
268 
269                 // Dispatch our message.
270                 try {
271                     mTestableLooper.mQueueWrapper.execute(m);
272                 } catch (LooperException e) {
273                     throw e.getSource();
274                 } catch (RuntimeException re) {
275                     // If the TestLooperManager has to post, it will wrap what it throws in a
276                     // RuntimeException, make sure we grab the actual source.
277                     if (re.getCause() instanceof LooperException) {
278                         throw ((LooperException) re.getCause()).getSource();
279                     } else {
280                         throw re.getCause();
281                     }
282                 } finally {
283                     m.recycle();
284                 }
285                 return ret[0];
286             } finally {
287                 if (set) {
288                     mTestableLooper.mQueueWrapper.release();
289                     mTestableLooper.mQueueWrapper = null;
290                 }
291             }
292         }
293 
createLooper()294         private Looper createLooper() {
295             // TODO: Find way to share these.
296             mHandlerThread = new HandlerThread(TestableLooper.class.getSimpleName());
297             mHandlerThread.start();
298             return mHandlerThread.getLooper();
299         }
300 
301         @Override
finalize()302         protected void finalize() throws Throwable {
303             super.finalize();
304             if (mHandlerThread != null) {
305                 mHandlerThread.quit();
306             }
307         }
308 
309         private static class LooperException extends RuntimeException {
310             private final Throwable mSource;
311 
LooperException(Throwable t)312             public LooperException(Throwable t) {
313                 mSource = t;
314             }
315 
getSource()316             public Throwable getSource() {
317                 return mSource;
318             }
319         }
320     }
321 
322     public interface MessageHandler {
323         /**
324          * Return true to have the message executed and delivered to target.
325          * Return false to not execute the message and stop executing messages.
326          */
onMessageHandled(Message m)327         boolean onMessageHandled(Message m);
328     }
329 }
330