1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package android.view.textclassifier.cts;
17 
18 import static com.android.compatibility.common.util.ShellUtils.runShellCommand;
19 
20 import android.content.Context;
21 import android.provider.DeviceConfig;
22 import android.support.test.uiautomator.UiDevice;
23 import android.text.TextUtils;
24 import android.util.Log;
25 
26 import androidx.annotation.NonNull;
27 import androidx.annotation.Nullable;
28 import androidx.test.InstrumentationRegistry;
29 import androidx.test.core.app.ApplicationProvider;
30 
31 import com.android.compatibility.common.util.DeviceConfigStateManager;
32 import com.android.compatibility.common.util.SafeCleanerRule;
33 
34 import org.junit.rules.TestWatcher;
35 import org.junit.runner.Description;
36 
37 import java.util.ArrayList;
38 import java.util.Collections;
39 import java.util.List;
40 import java.util.concurrent.CountDownLatch;
41 import java.util.concurrent.TimeUnit;
42 
43 /**
44  * Custom {@link TestWatcher} that does TextClassifierService setup and reset tasks for the tests.
45  */
46 final class TextClassifierTestWatcher extends TestWatcher {
47 
48     private static final String TAG = "TextClassifierTestWatcher";
49     private static final long GENERIC_TIMEOUT_MS = 10_000;
50     // TODO: Use default value defined in TextClassificationConstants when TestApi is ready
51     private static final String DEFAULT_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE = null;
52     private static final boolean SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
53     private static final String DEVICECONFIG_NAME_SPACE = "textclassifier";
54 
55     private String mOriginalOverrideService;
56     private boolean mOriginalSystemTextClassifierEnabled;
57 
58     private static final ArrayList<Throwable> sExceptions = new ArrayList<>();
59 
60     private static ServiceWatcher sServiceWatcher;
61 
62     @Override
starting(Description description)63     protected void starting(Description description) {
64         super.starting(description);
65         // get original settings
66         mOriginalOverrideService = getOriginalOverrideService();
67         mOriginalSystemTextClassifierEnabled = isSystemTextClassifierEnabled();
68 
69         // set system TextClassifier enabled
70         setAndAssertSystemTextclassifierEnabledSetIfNeeded(true);
71 
72         setService();
73     }
74 
75     @Override
finished(Description description)76     protected void finished(Description description) {
77         super.finished(description);
78         // restore original settings
79         setAndAssertSystemTextclassifierEnabledSetIfNeeded(mOriginalSystemTextClassifierEnabled);
80         // restore service and make sure service disconnected.
81         // clear the static values.
82         try {
83             resetService();
84         } catch (Exception e) {
85             throw new RuntimeException(e);
86         } finally {
87             resetStaticState();
88         }
89     }
90 
91     /**
92      * Wait for the TextClassifierService to connect. Note that the system requires a query to the
93      * TextClassifierService before it is first connected.
94      *
95      * @return the CtsTextClassifierService when connected.
96      *
97      * @throws InterruptedException if the current thread is interrupted while waiting.
98      * @throws AssertionError if no CtsTextClassifierService is returned.
99      */
getService()100     CtsTextClassifierService getService() throws InterruptedException, AssertionError {
101         CtsTextClassifierService service = waitServiceLazyConnect();
102         if (service == null) {
103             throw new AssertionError("Can not get service.");
104         }
105         return service;
106     }
107 
108     /**
109      * Waits for the current application to idle. Default wait timeout is 10 seconds
110      */
waitForIdle()111     static void waitForIdle() {
112         UiDevice.getInstance(InstrumentationRegistry.getInstrumentation())
113                 .waitForIdle();
114     }
115 
setServiceWatcher()116     private static void setServiceWatcher() {
117         if (sServiceWatcher == null) {
118             sServiceWatcher = new ServiceWatcher();
119         }
120     }
121 
clearServiceWatcher()122     private static void clearServiceWatcher() {
123         if (sServiceWatcher != null) {
124             sServiceWatcher.mService = null;
125             sServiceWatcher = null;
126         }
127     }
128 
resetStaticState()129     private static void resetStaticState() {
130         sExceptions.clear();
131         clearServiceWatcher();
132     }
133 
134     @Nullable
getOriginalOverrideService()135     private String getOriginalOverrideService() {
136         final String deviceConfigSetting = runShellCommand(
137                 "device_config get textclassifier textclassifier_service_package_override");
138         if (!TextUtils.isEmpty(deviceConfigSetting) && !deviceConfigSetting.equals("null")) {
139             return deviceConfigSetting;
140         }
141         return DEFAULT_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE;
142     }
143 
isSystemTextClassifierEnabled()144     private boolean isSystemTextClassifierEnabled() {
145         final String deviceConfigSetting = runShellCommand(
146                 "device_config get textclassifier system_textclassifier_enabled");
147         if (!TextUtils.isEmpty(deviceConfigSetting) && !deviceConfigSetting.equals("null")) {
148             return deviceConfigSetting.toLowerCase().equals("true");
149         }
150         return SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT;
151     }
152 
setService()153     private void setService() {
154         setServiceWatcher();
155         // set the test service
156         setAndAssertServicePackageOverrideSetIfNeeded(CtsTextClassifierService.MY_PACKAGE);
157 
158         // Wait for the current bound TCS to be unbounded.
159         try {
160             Thread.sleep(1_000);
161         } catch (InterruptedException e) {
162             Log.e(TAG, "Error while sleeping");
163         }
164     }
165 
resetOriginalService()166     private void resetOriginalService() {
167         Log.d(TAG, "reset to " + mOriginalOverrideService);
168         setAndAssertServicePackageOverrideSetIfNeeded(mOriginalOverrideService);
169     }
170 
setAndAssertSystemTextclassifierEnabledSetIfNeeded(boolean value)171     private void setAndAssertSystemTextclassifierEnabledSetIfNeeded(boolean value) {
172         boolean currentValue = isSystemTextClassifierEnabled();
173         if (currentValue != value) {
174             final Context context = ApplicationProvider.getApplicationContext();
175             DeviceConfigStateManager stateManager =
176                     new DeviceConfigStateManager(context, DEVICECONFIG_NAME_SPACE,
177                             "system_textclassifier_enabled");
178             stateManager.set(Boolean.toString(value));
179         }
180     }
181 
setAndAssertServicePackageOverrideSetIfNeeded(String value)182     private void setAndAssertServicePackageOverrideSetIfNeeded(String value) {
183         String currentValue = getOriginalOverrideService();
184         if (!TextUtils.equals(currentValue, value)) {
185             final Context context = ApplicationProvider.getApplicationContext();
186             DeviceConfigStateManager stateManager =
187                     new DeviceConfigStateManager(context, DEVICECONFIG_NAME_SPACE,
188                             "textclassifier_service_package_override");
189             stateManager.set(value);
190         }
191     }
192 
resetService()193     private void resetService() throws InterruptedException {
194         resetOriginalService();
195         if (sServiceWatcher != null && sServiceWatcher.mService != null) {
196             sServiceWatcher.waitOnDisconnected();
197         } else {
198             waitForIdle();
199         }
200     }
201 
202     /**
203      * Returns the TestRule that runs clean up after a test is finished. See {@link SafeCleanerRule}
204      * for more details.
205      */
newSafeCleaner()206     public SafeCleanerRule newSafeCleaner() {
207         return new SafeCleanerRule()
208                 .add(() -> {
209                     return getExceptions();
210                 });
211     }
212 
213     /**
214      * Gets the exceptions that were thrown while the service handled requests.
215      */
216     @NonNull
getExceptions()217     private static List<Throwable> getExceptions() throws Exception {
218         return Collections.unmodifiableList(sExceptions);
219     }
220 
addException(@onNull String fmt, @Nullable Object...args)221     private static void addException(@NonNull String fmt, @Nullable Object...args) {
222         final String msg = String.format(fmt, args);
223         Log.e(TAG, msg);
224         sExceptions.add(new IllegalStateException(msg));
225     }
226 
waitServiceLazyConnect()227     private CtsTextClassifierService waitServiceLazyConnect() throws InterruptedException {
228         if (sServiceWatcher != null) {
229             return sServiceWatcher.waitOnConnected();
230         }
231         return null;
232     }
233 
234     public static final class ServiceWatcher {
235         private final CountDownLatch mCreated = new CountDownLatch(1);
236         private final CountDownLatch mDestroyed = new CountDownLatch(1);
237 
238         CtsTextClassifierService mService;
239 
onConnected(CtsTextClassifierService service)240         public static void onConnected(CtsTextClassifierService service) {
241             Log.i(TAG, "onConnected:  sServiceWatcher=" + sServiceWatcher);
242 
243             if (sServiceWatcher == null) {
244                 addException("onConnected() without a watcher");
245                 return;
246             }
247 
248             if (sServiceWatcher.mService != null) {
249                 addException("onConnected(): already created: " + sServiceWatcher);
250                 return;
251             }
252 
253             sServiceWatcher.mService = service;
254             sServiceWatcher.mCreated.countDown();
255         }
256 
onDisconnected()257         public static void onDisconnected() {
258             Log.i(TAG, "onDisconnected:  sServiceWatcher=" + sServiceWatcher);
259 
260             if (sServiceWatcher == null) {
261                 addException("onDisconnected() without a watcher");
262                 return;
263             }
264 
265             if (sServiceWatcher.mService == null) {
266                 addException("onDisconnected(): no service on %s", sServiceWatcher);
267                 return;
268             }
269             sServiceWatcher.mDestroyed.countDown();
270         }
271 
272         @NonNull
waitOnConnected()273         public CtsTextClassifierService waitOnConnected() throws InterruptedException {
274             await(mCreated, "not created");
275 
276             if (mService == null) {
277                 throw new IllegalStateException("not created");
278             }
279             return mService;
280         }
281 
waitOnDisconnected()282         public void waitOnDisconnected() throws InterruptedException {
283             await(mDestroyed, "not destroyed");
284         }
285 
await(@onNull CountDownLatch latch, @NonNull String fmt, @Nullable Object... args)286         private void await(@NonNull CountDownLatch latch, @NonNull String fmt,
287                 @Nullable Object... args)
288                 throws InterruptedException {
289             final boolean called = latch.await(GENERIC_TIMEOUT_MS, TimeUnit.MILLISECONDS);
290             if (!called) {
291                 throw new IllegalStateException(String.format(fmt, args)
292                         + " in " + GENERIC_TIMEOUT_MS + "ms");
293             }
294         }
295     }
296 }
297