1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.crashtest.core;
18 
19 import android.content.ComponentName;
20 import android.content.Context;
21 import android.content.Intent;
22 import android.content.ServiceConnection;
23 import android.os.Handler;
24 import android.os.IBinder;
25 import android.os.Message;
26 import android.os.Messenger;
27 import android.os.RemoteException;
28 import android.util.Log;
29 
30 import java.util.Optional;
31 import java.util.Timer;
32 import java.util.concurrent.atomic.AtomicBoolean;
33 import java.util.concurrent.atomic.AtomicReference;
34 
35 public class CrashTestCoordinator {
36 
37     private static String TAG = "CrashTestCoordinator";
38 
39     private final Context mContext;
40     private static final Timer mTestTimeoutTimer = new Timer("TestTimeoutTimer");
41     private AtomicBoolean mServiceBound = new AtomicBoolean(false);
42     private final AtomicBoolean mAlreadyNotified = new AtomicBoolean(false);
43     private String mTestName;
44 
45     public interface CrashTestIntentInitializer {
addIntentParams(Intent intent)46         void addIntentParams(Intent intent);
47     }
48 
49     public interface CrashTestCompletionListener {
testCrashed()50         void testCrashed();
51 
testSucceeded()52         void testSucceeded();
53 
testFailed(String cause)54         void testFailed(String cause);
55 
testProgressing(Optional<String> description)56         void testProgressing(Optional<String> description);
57     }
58 
CrashTestCoordinator(Context context)59     public CrashTestCoordinator(Context context) {
60         mContext = context;
61     }
62 
63     class KeepAliveServiceConnection implements ServiceConnection {
64         private final CrashTestCompletionListener mTestCompletionListener;
65         private Messenger mMessenger = null;
66         private IBinder mService = null;
67 
KeepAliveServiceConnection( CrashTestCompletionListener testCompletionListener)68         KeepAliveServiceConnection(
69                 CrashTestCompletionListener testCompletionListener) {
70             mTestCompletionListener = testCompletionListener;
71         }
72 
isServiceAlive()73         public boolean isServiceAlive() {
74             if (mService == null) {
75                 Log.w(TAG, "Keep alive service connection is not bound.");
76                 return false;
77             }
78             return mService.isBinderAlive();
79         }
80 
killServiceProcess()81         public void killServiceProcess() throws RemoteException {
82             mMessenger.send(Message.obtain(null, CrashTestService.KILL_PROCESS));
83         }
84 
onServiceCrashed()85         protected void onServiceCrashed() {
86             Log.w(TAG, "Test service crashed, unbinding and notifying listener");
87             unbindService();
88             mTestCompletionListener.testCrashed();
89         }
90 
91         @Override
onServiceConnected(ComponentName name, IBinder service)92         public void onServiceConnected(ComponentName name, IBinder service) {
93             Log.d(TAG, String.format("Service '%s' connected with binder %s", name, service));
94 
95             mService = service;
96             mMessenger = new Messenger(service);
97 
98             try {
99                 service.linkToDeath(this::onServiceCrashed, 0);
100 
101                 final Message setCommChannelMsg = Message.obtain(null,
102                         CrashTestService.SET_COMM_CHANNEL);
103                 setCommChannelMsg.replyTo = new Messenger(new Handler(msgFromTest -> {
104                     switch (msgFromTest.what) {
105                         case CrashTestService.SUCCESS:
106                             if (!mAlreadyNotified.getAndSet(true)) {
107                                 Log.d(TAG, String.format("Test '%s' succeeded", mTestName));
108                                 mTestCompletionListener.testSucceeded();
109                             }
110                             unbindService();
111                             break;
112 
113                         case CrashTestService.FAILURE:
114                             if (!mAlreadyNotified.getAndSet(true)) {
115                                 String reason = msgFromTest.getData().getString(
116                                         CrashTestService.DESCRIPTION);
117                                 Log.i(TAG,
118                                         String.format("Test '%s' failed with reason: %s", mTestName,
119                                                 reason));
120                                 mTestCompletionListener.testFailed(reason);
121                             }
122                             unbindService();
123                             break;
124 
125                         case CrashTestService.PROGRESS:
126                             String description = msgFromTest.getData().getString(
127                                     CrashTestService.DESCRIPTION);
128                             Log.d(TAG, "Test progress message: " + description);
129                             mTestCompletionListener.testProgressing(
130                                     Optional.ofNullable(description));
131                             break;
132                     }
133                     return true;
134                 }));
135                 mMessenger.send(setCommChannelMsg);
136             } catch (RemoteException serviceShutDown) {
137                 Log.w(TAG, "Unable to talk to service, it might have been shut down",
138                         serviceShutDown);
139                 if (!mAlreadyNotified.getAndSet(true)) {
140                     mTestCompletionListener.testCrashed();
141                 }
142             }
143         }
144 
145         @Override
onServiceDisconnected(ComponentName name)146         public void onServiceDisconnected(ComponentName name) {
147             Log.d(TAG, "Service disconnected");
148             unbindService();
149         }
150     }
151 
152     private final AtomicReference<KeepAliveServiceConnection> mServiceConnection =
153             new AtomicReference<>(null);
154 
155     /**
156      * @throws IllegalStateException if unable to start the service
157      */
startTest(Class<? extends CrashTest> crashTestClass, CrashTestIntentInitializer intentParamsProvider, CrashTestCompletionListener testCompletionListener, boolean separateProcess, String testName)158     public void startTest(Class<? extends CrashTest> crashTestClass,
159             CrashTestIntentInitializer intentParamsProvider,
160             CrashTestCompletionListener testCompletionListener,
161             boolean separateProcess, String testName) {
162 
163         final Intent crashTestServiceIntent = new Intent(mContext,
164                 separateProcess ? OutOfProcessCrashTestService.class
165                         : InProcessCrashTestService.class);
166         crashTestServiceIntent.putExtra(CrashTestService.EXTRA_KEY_CRASH_TEST_CLASS,
167                 crashTestClass.getName());
168         intentParamsProvider.addIntentParams(crashTestServiceIntent);
169 
170         mServiceConnection.set(new KeepAliveServiceConnection(testCompletionListener));
171 
172         mServiceBound.set(mContext.bindService(crashTestServiceIntent, mServiceConnection.get(),
173                 Context.BIND_AUTO_CREATE));
174 
175         if (!mServiceBound.get()) {
176             Log.e(TAG, String.format("Crash test service failed to start %s for test '%s'.",
177                     separateProcess ? " in a separate process"
178                             : "in a local process", testName));
179 
180             throw new IllegalStateException("Unsable to start service");
181         }
182 
183         Log.i(TAG, String.format("Crash test service started %s for test '%s'.",
184                 separateProcess ? " in a separate process"
185                         : "in a local process", testName));
186 
187         mTestName = testName;
188     }
189 
shutdown()190     public void shutdown() {
191         unbindService();
192     }
193 
killCrashTestService()194     public void killCrashTestService() throws RemoteException, IllegalArgumentException {
195         if (!mServiceBound.get()) {
196             throw new IllegalArgumentException("No service bound!");
197         }
198         mServiceConnection.get().killServiceProcess();
199     }
200 
unbindService()201     private void unbindService() {
202         try {
203             KeepAliveServiceConnection sc = mServiceConnection.get();
204             if (sc != null) {
205                 if (mServiceBound.get()) {
206                     Log.i(TAG, "Unbinding service");
207                     mServiceBound.set(false);
208                     mContext.unbindService(sc);
209                 } else {
210                     Log.w(TAG, "Service was not bound!!");
211                 }
212                 mServiceConnection.compareAndSet(sc, null);
213             }
214         } catch (Exception e) {
215             Log.w(TAG,
216                     "Error trying to unbind service, this might be expected if the service crashed.",
217                     e);
218         }
219     }
220 }
221