1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package android.autofillservice.cts.testcore;
17 
18 import static android.autofillservice.cts.testcore.Timeouts.CONNECTION_TIMEOUT;
19 import static android.autofillservice.cts.testcore.Timeouts.FILL_TIMEOUT;
20 
21 import android.app.assist.AssistStructure;
22 import android.os.CancellationSignal;
23 import android.os.Handler;
24 import android.os.HandlerThread;
25 import android.os.OutcomeReceiver;
26 import android.service.assist.classification.FieldClassificationRequest;
27 import android.service.assist.classification.FieldClassificationResponse;
28 import android.service.assist.classification.FieldClassificationService;
29 import android.util.Log;
30 
31 import androidx.annotation.NonNull;
32 import androidx.annotation.Nullable;
33 
34 import com.android.compatibility.common.util.RetryableException;
35 
36 import java.util.concurrent.BlockingQueue;
37 import java.util.concurrent.CountDownLatch;
38 import java.util.concurrent.LinkedBlockingQueue;
39 import java.util.concurrent.TimeUnit;
40 
41 /**
42  * {@link FieldClassificationService} instrumented FieldClassificationService
43  */
44 public class InstrumentedFieldClassificationService extends FieldClassificationService {
45 
46     private static final String TAG = InstrumentedFieldClassificationService.class.getSimpleName();
47     public static final String SERVICE_PACKAGE = Helper.MY_PACKAGE;
48     public static final String SERVICE_CLASS =
49             InstrumentedFieldClassificationService.class.getSimpleName();
50 
51     public static final String SERVICE_NAME = SERVICE_PACKAGE + "/.testcore." + SERVICE_CLASS;
52 
53     private static final Replier sReplier = new Replier();
54 
55     // We must handle all requests in a separate thread as the service's main thread is the also
56     // the UI thread of the test process and we don't want to hose it in case of failures here
57     private static final HandlerThread sMyThread =
58             new HandlerThread("MyInstrumentedFieldClassificationServiceThread");
59 
60     private final Handler mHandler;
61 
62     private final CountDownLatch mConnectedLatch = new CountDownLatch(1);
63     private final CountDownLatch mDisconnectedLatch = new CountDownLatch(1);
64 
65     private static volatile ServiceWatcher sServiceWatcher;
66 
67     static {
Log.i(TAG, "Starting thread " + sMyThread)68         Log.i(TAG, "Starting thread " + sMyThread);
sMyThread.start()69         sMyThread.start();
70     }
71 
InstrumentedFieldClassificationService()72     public InstrumentedFieldClassificationService() {
73         mHandler = Handler.createAsync(sMyThread.getLooper());
74         sReplier.setHandler(mHandler);
75     }
76 
setServiceWatcher()77     public static ServiceWatcher setServiceWatcher() {
78         if (sServiceWatcher != null) {
79             throw new IllegalStateException("There can be only one pcc service");
80         }
81         sServiceWatcher = new ServiceWatcher();
82         return sServiceWatcher;
83     }
84 
85     /**
86      * Waits until the system calls {@link #onConnected()}.
87      */
waitUntilConnected()88     public void waitUntilConnected() throws InterruptedException {
89         await(mConnectedLatch, "not connected");
90     }
91 
92     /**
93      * Awaits for a latch to be counted down.
94      */
await(@onNull CountDownLatch latch, @NonNull String fmt, @Nullable Object... args)95     public static void await(@NonNull CountDownLatch latch, @NonNull String fmt,
96             @Nullable Object... args) throws InterruptedException {
97         final boolean called = latch.await(CONNECTION_TIMEOUT.ms(), TimeUnit.MILLISECONDS);
98         if (!called) {
99             throw new IllegalStateException(String.format(fmt, args)
100                 + " in " + CONNECTION_TIMEOUT.ms() + "ms");
101         }
102     }
103 
104     @Override
onClassificationRequest( android.service.assist.classification.FieldClassificationRequest request, CancellationSignal cancellationSignal, OutcomeReceiver<FieldClassificationResponse, Exception> outcomeReceiver)105     public void onClassificationRequest(
106             android.service.assist.classification.FieldClassificationRequest request,
107             CancellationSignal cancellationSignal,
108             OutcomeReceiver<FieldClassificationResponse, Exception> outcomeReceiver) {
109 
110         sReplier.onClassificationRequest(request.getAssistStructure(), cancellationSignal,
111                 outcomeReceiver);
112     }
113 
114     @Override
onConnected()115     public void onConnected() {
116         Log.i(TAG, "onConnected(): sServiceWatcher=" + sServiceWatcher);
117 
118         if (sServiceWatcher == null) {
119             Log.w(TAG, "onConnected() without a watcher");
120             return;
121         }
122 
123         if (sServiceWatcher.mService != null) {
124             Log.w(TAG, "onConnected(): already created: " + sServiceWatcher);
125             return;
126         }
127 
128         sServiceWatcher.mService = this;
129         sServiceWatcher.mCreated.countDown();
130 
131         if (mConnectedLatch.getCount() == 0) {
132             Log.w(TAG, "already connected: " + mConnectedLatch);
133         }
134         mConnectedLatch.countDown();
135     }
136 
137     @Override
onDisconnected()138     public void onDisconnected() {
139         Log.i(TAG, "onDisconnected(): sServiceWatcher=" + sServiceWatcher);
140 
141         if (mDisconnectedLatch.getCount() == 0) {
142             Log.w(TAG, "already disconnected: " +  mConnectedLatch);
143         }
144         mDisconnectedLatch.countDown();
145 
146         if (sServiceWatcher == null) {
147             Log.w(TAG, "onDisconnected() without a watcher");
148             return;
149         }
150         if (sServiceWatcher.mService == null) {
151             Log.w(TAG, "onDisconnected(): no service on " + sServiceWatcher);
152             return;
153         }
154         sServiceWatcher.mDestroyed.countDown();
155         sServiceWatcher.mService = null;
156         sServiceWatcher = null;
157     }
158 
159     /**
160      * Gets the {@link Replier} singleton.
161      */
getReplier()162     public static Replier getReplier() {
163         return sReplier;
164     }
165 
166     /**
167      * POJO representation of a FieldClassificationRequest
168      */
169     public static final class FieldClassificationRequest {
170         public final AssistStructure assistStructure;
171         public final CancellationSignal cancellationSignal;
172         public final OutcomeReceiver<FieldClassificationResponse, Exception> outcomeReceiver;
173 
FieldClassificationRequest(AssistStructure assistStructure, CancellationSignal cancellationSignal, OutcomeReceiver<FieldClassificationResponse, Exception> outcomeReceiver)174         private FieldClassificationRequest(AssistStructure assistStructure,
175                 CancellationSignal cancellationSignal,
176                 OutcomeReceiver<FieldClassificationResponse, Exception> outcomeReceiver) {
177             this.assistStructure = assistStructure;
178             this.cancellationSignal = cancellationSignal;
179             this.outcomeReceiver = outcomeReceiver;
180         }
181     }
182 
183     /**
184      * Object used to answer a
185      * {@link FieldClassificationService#onClassificationRequest(
186      * android.service.assist.classification.FieldClassificationRequest,
187      * CancellationSignal, OutcomeReceiver<FieldClassificationResponse, Exception>)}
188      * on behalf of a unit test method.
189      */
190     public static final class Replier {
191         private final BlockingQueue<CannedFieldClassificationResponse> mResponses =
192                 new LinkedBlockingQueue<>();
193         private final BlockingQueue<FieldClassificationRequest> mFieldClassificationRequests =
194                 new LinkedBlockingQueue<>();
195 
196         private Handler mHandler;
197 
Replier()198         private Replier() {
199         }
200 
setHandler(Handler handler)201         public void setHandler(Handler handler) {
202             mHandler = handler;
203         }
204 
205         /**
206          * Enqueue the new FieldClassification Request
207          */
onClassificationRequest(AssistStructure assistStructure, CancellationSignal cancellationSignal, OutcomeReceiver<FieldClassificationResponse, Exception> outcomeReceiver)208         public void onClassificationRequest(AssistStructure assistStructure,
209                 CancellationSignal cancellationSignal,
210                 OutcomeReceiver<FieldClassificationResponse, Exception> outcomeReceiver) {
211             try {
212                 CannedFieldClassificationResponse response = null;
213                 try {
214                     response = mResponses.poll(CONNECTION_TIMEOUT.ms(), TimeUnit.MILLISECONDS);
215                 } catch (InterruptedException e) {
216                     Log.w(TAG, "Interrupted getting CannedResponse: " + e);
217                     Thread.currentThread().interrupt();
218                     return;
219                 }
220                 FieldClassificationResponse classificationResponse = response.asResponse(
221                         (id) -> Helper.findNodeByResourceId(assistStructure, id));
222 
223                 Log.v(TAG, "onClassificationRequest: FieldClassificationResponse = "
224                         + classificationResponse);
225                 outcomeReceiver.onResult(classificationResponse);
226             } finally {
227                 Helper.offer(mFieldClassificationRequests, new FieldClassificationRequest(
228                         assistStructure, cancellationSignal, outcomeReceiver),
229                         CONNECTION_TIMEOUT.ms());
230             }
231         }
232 
233         /**
234          * Gets the next field classification request, in the order received.
235          */
getNextFieldClassificationRequest()236         public FieldClassificationRequest getNextFieldClassificationRequest() {
237             FieldClassificationRequest request;
238             try {
239                 request =
240                     mFieldClassificationRequests.poll(FILL_TIMEOUT.ms(), TimeUnit.MILLISECONDS);
241             } catch (InterruptedException e) {
242                 Thread.currentThread().interrupt();
243                 throw new IllegalStateException("Interrupted", e);
244             }
245             if (request == null) {
246                 throw new RetryableException(FILL_TIMEOUT, "onClassificationRequest() not called");
247             }
248             return request;
249         }
250 
251         /**
252          * Asserts all {@link FieldClassificationService#onClassificationRequest(
253          * android.service.assist.classification.FieldClassificationRequest,
254          * CancellationSignal, OutcomeReceiver<FieldClassificationResponse, Exception>)}
255          * received by the service were properly {@link #getNextFieldClassificationRequest()}
256          * handled by the test case.
257          */
assertNoUnhandledFieldClassificationRequests()258         public void assertNoUnhandledFieldClassificationRequests() {
259             if (mFieldClassificationRequests.isEmpty()) return; // Good job, test case!
260 
261             throw new AssertionError(mFieldClassificationRequests.size()
262                 + " unhandled field classification requests: " + mFieldClassificationRequests);
263         }
264 
addResponse(CannedFieldClassificationResponse response)265         public void addResponse(CannedFieldClassificationResponse response) {
266             mResponses.add(response);
267         }
268 
269         /**
270          * Resets its internal state.
271          */
reset()272         public void reset() {
273             mFieldClassificationRequests.clear();
274         }
275     }
276 
277     public static final class ServiceWatcher {
278 
279         private final CountDownLatch mCreated = new CountDownLatch(1);
280         private final CountDownLatch mDestroyed = new CountDownLatch(1);
281 
282         private InstrumentedFieldClassificationService mService;
283 
284         @NonNull
waitOnConnected()285         public InstrumentedFieldClassificationService waitOnConnected()
286                 throws InterruptedException {
287             await(mCreated, "not created");
288 
289             if (mService == null) {
290                 throw new IllegalStateException("not created");
291             }
292 
293             return mService;
294         }
295 
waitOnDisconnected()296         public void waitOnDisconnected() throws InterruptedException {
297             await(mDestroyed, "not destroyed");
298         }
299 
300         @Override
toString()301         public String toString() {
302             return "mService: " + mService + " created: " + (mCreated.getCount() == 0)
303                 + " destroyed: " + (mDestroyed.getCount() == 0);
304         }
305     }
306 }
307