1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package android.view.textclassifier.cts; 17 18 import static com.android.compatibility.common.util.ShellUtils.runShellCommand; 19 20 import android.content.Context; 21 import android.provider.DeviceConfig; 22 import android.support.test.uiautomator.UiDevice; 23 import android.text.TextUtils; 24 import android.util.Log; 25 26 import androidx.annotation.NonNull; 27 import androidx.annotation.Nullable; 28 import androidx.test.InstrumentationRegistry; 29 import androidx.test.core.app.ApplicationProvider; 30 31 import com.android.compatibility.common.util.DeviceConfigStateManager; 32 import com.android.compatibility.common.util.SafeCleanerRule; 33 34 import org.junit.rules.TestWatcher; 35 import org.junit.runner.Description; 36 37 import java.util.ArrayList; 38 import java.util.Collections; 39 import java.util.List; 40 import java.util.concurrent.CountDownLatch; 41 import java.util.concurrent.TimeUnit; 42 43 /** 44 * Custom {@link TestWatcher} that does TextClassifierService setup and reset tasks for the tests. 45 */ 46 final class TextClassifierTestWatcher extends TestWatcher { 47 48 private static final String TAG = "TextClassifierTestWatcher"; 49 private static final long GENERIC_TIMEOUT_MS = 10_000; 50 // TODO: Use default value defined in TextClassificationConstants when TestApi is ready 51 private static final String DEFAULT_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE = null; 52 private static final boolean SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT = true; 53 private static final String DEVICECONFIG_NAME_SPACE = "textclassifier"; 54 55 private String mOriginalOverrideService; 56 private boolean mOriginalSystemTextClassifierEnabled; 57 58 private static final ArrayList<Throwable> sExceptions = new ArrayList<>(); 59 60 private static ServiceWatcher sServiceWatcher; 61 62 @Override starting(Description description)63 protected void starting(Description description) { 64 super.starting(description); 65 // get original settings 66 mOriginalOverrideService = getOriginalOverrideService(); 67 mOriginalSystemTextClassifierEnabled = isSystemTextClassifierEnabled(); 68 69 // set system TextClassifier enabled 70 setAndAssertSystemTextclassifierEnabledSetIfNeeded(true); 71 72 setService(); 73 } 74 75 @Override finished(Description description)76 protected void finished(Description description) { 77 super.finished(description); 78 // restore original settings 79 setAndAssertSystemTextclassifierEnabledSetIfNeeded(mOriginalSystemTextClassifierEnabled); 80 // restore service and make sure service disconnected. 81 // clear the static values. 82 try { 83 resetService(); 84 } catch (Exception e) { 85 throw new RuntimeException(e); 86 } finally { 87 resetStaticState(); 88 } 89 } 90 91 /** 92 * Wait for the TextClassifierService to connect. Note that the system requires a query to the 93 * TextClassifierService before it is first connected. 94 * 95 * @return the CtsTextClassifierService when connected. 96 * 97 * @throws InterruptedException if the current thread is interrupted while waiting. 98 * @throws AssertionError if no CtsTextClassifierService is returned. 99 */ getService()100 CtsTextClassifierService getService() throws InterruptedException, AssertionError { 101 CtsTextClassifierService service = waitServiceLazyConnect(); 102 if (service == null) { 103 throw new AssertionError("Can not get service."); 104 } 105 return service; 106 } 107 108 /** 109 * Waits for the current application to idle. Default wait timeout is 10 seconds 110 */ waitForIdle()111 static void waitForIdle() { 112 UiDevice.getInstance(InstrumentationRegistry.getInstrumentation()) 113 .waitForIdle(); 114 } 115 setServiceWatcher()116 private static void setServiceWatcher() { 117 if (sServiceWatcher == null) { 118 sServiceWatcher = new ServiceWatcher(); 119 } 120 } 121 clearServiceWatcher()122 private static void clearServiceWatcher() { 123 if (sServiceWatcher != null) { 124 sServiceWatcher.mService = null; 125 sServiceWatcher = null; 126 } 127 } 128 resetStaticState()129 private static void resetStaticState() { 130 sExceptions.clear(); 131 clearServiceWatcher(); 132 } 133 134 @Nullable getOriginalOverrideService()135 private String getOriginalOverrideService() { 136 final String deviceConfigSetting = runShellCommand( 137 "device_config get textclassifier textclassifier_service_package_override"); 138 if (!TextUtils.isEmpty(deviceConfigSetting) && !deviceConfigSetting.equals("null")) { 139 return deviceConfigSetting; 140 } 141 return DEFAULT_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE; 142 } 143 isSystemTextClassifierEnabled()144 private boolean isSystemTextClassifierEnabled() { 145 final String deviceConfigSetting = runShellCommand( 146 "device_config get textclassifier system_textclassifier_enabled"); 147 if (!TextUtils.isEmpty(deviceConfigSetting) && !deviceConfigSetting.equals("null")) { 148 return deviceConfigSetting.toLowerCase().equals("true"); 149 } 150 return SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT; 151 } 152 setService()153 private void setService() { 154 setServiceWatcher(); 155 // set the test service 156 setAndAssertServicePackageOverrideSetIfNeeded(CtsTextClassifierService.MY_PACKAGE); 157 158 // Wait for the current bound TCS to be unbounded. 159 try { 160 Thread.sleep(1_000); 161 } catch (InterruptedException e) { 162 Log.e(TAG, "Error while sleeping"); 163 } 164 } 165 resetOriginalService()166 private void resetOriginalService() { 167 Log.d(TAG, "reset to " + mOriginalOverrideService); 168 setAndAssertServicePackageOverrideSetIfNeeded(mOriginalOverrideService); 169 } 170 setAndAssertSystemTextclassifierEnabledSetIfNeeded(boolean value)171 private void setAndAssertSystemTextclassifierEnabledSetIfNeeded(boolean value) { 172 boolean currentValue = isSystemTextClassifierEnabled(); 173 if (currentValue != value) { 174 final Context context = ApplicationProvider.getApplicationContext(); 175 DeviceConfigStateManager stateManager = 176 new DeviceConfigStateManager(context, DEVICECONFIG_NAME_SPACE, 177 "system_textclassifier_enabled"); 178 stateManager.set(Boolean.toString(value)); 179 } 180 } 181 setAndAssertServicePackageOverrideSetIfNeeded(String value)182 private void setAndAssertServicePackageOverrideSetIfNeeded(String value) { 183 String currentValue = getOriginalOverrideService(); 184 if (!TextUtils.equals(currentValue, value)) { 185 final Context context = ApplicationProvider.getApplicationContext(); 186 DeviceConfigStateManager stateManager = 187 new DeviceConfigStateManager(context, DEVICECONFIG_NAME_SPACE, 188 "textclassifier_service_package_override"); 189 stateManager.set(value); 190 } 191 } 192 resetService()193 private void resetService() throws InterruptedException { 194 resetOriginalService(); 195 if (sServiceWatcher != null && sServiceWatcher.mService != null) { 196 sServiceWatcher.waitOnDisconnected(); 197 } else { 198 waitForIdle(); 199 } 200 } 201 202 /** 203 * Returns the TestRule that runs clean up after a test is finished. See {@link SafeCleanerRule} 204 * for more details. 205 */ newSafeCleaner()206 public SafeCleanerRule newSafeCleaner() { 207 return new SafeCleanerRule() 208 .add(() -> { 209 return getExceptions(); 210 }); 211 } 212 213 /** 214 * Gets the exceptions that were thrown while the service handled requests. 215 */ 216 @NonNull getExceptions()217 private static List<Throwable> getExceptions() throws Exception { 218 return Collections.unmodifiableList(sExceptions); 219 } 220 addException(@onNull String fmt, @Nullable Object...args)221 private static void addException(@NonNull String fmt, @Nullable Object...args) { 222 final String msg = String.format(fmt, args); 223 Log.e(TAG, msg); 224 sExceptions.add(new IllegalStateException(msg)); 225 } 226 waitServiceLazyConnect()227 private CtsTextClassifierService waitServiceLazyConnect() throws InterruptedException { 228 if (sServiceWatcher != null) { 229 return sServiceWatcher.waitOnConnected(); 230 } 231 return null; 232 } 233 234 public static final class ServiceWatcher { 235 private final CountDownLatch mCreated = new CountDownLatch(1); 236 private final CountDownLatch mDestroyed = new CountDownLatch(1); 237 238 CtsTextClassifierService mService; 239 onConnected(CtsTextClassifierService service)240 public static void onConnected(CtsTextClassifierService service) { 241 Log.i(TAG, "onConnected: sServiceWatcher=" + sServiceWatcher); 242 243 if (sServiceWatcher == null) { 244 addException("onConnected() without a watcher"); 245 return; 246 } 247 248 if (sServiceWatcher.mService != null) { 249 addException("onConnected(): already created: " + sServiceWatcher); 250 return; 251 } 252 253 sServiceWatcher.mService = service; 254 sServiceWatcher.mCreated.countDown(); 255 } 256 onDisconnected()257 public static void onDisconnected() { 258 Log.i(TAG, "onDisconnected: sServiceWatcher=" + sServiceWatcher); 259 260 if (sServiceWatcher == null) { 261 addException("onDisconnected() without a watcher"); 262 return; 263 } 264 265 if (sServiceWatcher.mService == null) { 266 addException("onDisconnected(): no service on %s", sServiceWatcher); 267 return; 268 } 269 sServiceWatcher.mDestroyed.countDown(); 270 } 271 272 @NonNull waitOnConnected()273 public CtsTextClassifierService waitOnConnected() throws InterruptedException { 274 await(mCreated, "not created"); 275 276 if (mService == null) { 277 throw new IllegalStateException("not created"); 278 } 279 return mService; 280 } 281 waitOnDisconnected()282 public void waitOnDisconnected() throws InterruptedException { 283 await(mDestroyed, "not destroyed"); 284 } 285 await(@onNull CountDownLatch latch, @NonNull String fmt, @Nullable Object... args)286 private void await(@NonNull CountDownLatch latch, @NonNull String fmt, 287 @Nullable Object... args) 288 throws InterruptedException { 289 final boolean called = latch.await(GENERIC_TIMEOUT_MS, TimeUnit.MILLISECONDS); 290 if (!called) { 291 throw new IllegalStateException(String.format(fmt, args) 292 + " in " + GENERIC_TIMEOUT_MS + "ms"); 293 } 294 } 295 } 296 } 297