/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.view.textclassifier.cts;

import static com.google.common.truth.Truth.assertThat;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import android.icu.util.ULocale;
import android.os.Bundle;
import android.os.LocaleList;
import android.os.Parcel;
import android.os.Parcelable;
import android.service.textclassifier.TextClassifierService;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.SelectionEvent;
import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassificationManager;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextClassifierEvent;
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;

import androidx.core.os.BuildCompat;
import androidx.test.InstrumentationRegistry;
import androidx.test.filters.SmallTest;

import com.google.common.collect.Range;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;

/**
 * Tests for {@link TextClassifier}. Every test runs once per classifier flavour listed in
 * {@link #textClassifierTypes()}: the manager's current classifier, a classification session,
 * the service's default implementation, and {@link TextClassifier#NO_OP}.
 */
@SmallTest
@RunWith(Parameterized.class)
public class TextClassifierTest {
    private static final String BUNDLE_KEY = "key";
    private static final String BUNDLE_VALUE = "value";
    private static final Bundle BUNDLE = new Bundle();

    static {
        BUNDLE.putString(BUNDLE_KEY, BUNDLE_VALUE);
    }

    private static final LocaleList LOCALES = LocaleList.forLanguageTags("en");
    private static final int START = 1;
    private static final int END = 3;
    // This text has lots of things that are probably entities in many cases.
    private static final String TEXT = "An email address is test@example.com. A phone number"
            + " might be +12122537077. Somebody lives at 123 Main Street, Mountain View, CA,"
            + " and there's good stuff at https://www.android.com :)";
    private static final TextSelection.Request TEXT_SELECTION_REQUEST =
            new TextSelection.Request.Builder(TEXT, START, END)
                    .setDefaultLocales(LOCALES)
                    .build();
    private static final TextClassification.Request TEXT_CLASSIFICATION_REQUEST =
            new TextClassification.Request.Builder(TEXT, START, END)
                    .setDefaultLocales(LOCALES)
                    .build();
    private static final TextLanguage.Request TEXT_LANGUAGE_REQUEST =
            new TextLanguage.Request.Builder(TEXT)
                    .setExtras(BUNDLE)
                    .build();
    private static final ConversationActions.Message FIRST_MESSAGE =
            new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_SELF)
                    .setText(TEXT)
                    .build();
    private static final ConversationActions.Message SECOND_MESSAGE =
            new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
                    .setText(TEXT)
                    .build();
    private static final ConversationActions.Request CONVERSATION_ACTIONS_REQUEST =
            new ConversationActions.Request.Builder(
                    Arrays.asList(FIRST_MESSAGE, SECOND_MESSAGE)).build();

    private static final String CURRENT = "current";
    private static final String SESSION = "session";
    private static final String DEFAULT = "default";
    private static final String NO_OP = "no_op";

    /** Returns the classifier flavours each test is parameterized over. */
    @Parameterized.Parameters(name = "{0}")
    public static Iterable<Object> textClassifierTypes() {
        return Arrays.asList(CURRENT, SESSION, DEFAULT, NO_OP);
    }

    @Parameterized.Parameter
    public String mTextClassifierType;

    private TextClassifier mClassifier;

    /** Resets the manager's classifier and builds the classifier under test for this run. */
    @Before
    public void setup() {
        TextClassificationManager manager = InstrumentationRegistry.getTargetContext()
                .getSystemService(TextClassificationManager.class);
        manager.setTextClassifier(null); // Resets the classifier.
        if (mTextClassifierType.equals(CURRENT)) {
            mClassifier = manager.getTextClassifier();
        } else if (mTextClassifierType.equals(SESSION)) {
            mClassifier = manager.createTextClassificationSession(
                    new TextClassificationContext.Builder(
                            InstrumentationRegistry.getTargetContext().getPackageName(),
                            TextClassifier.WIDGET_TYPE_TEXTVIEW)
                            .build());
        } else if (mTextClassifierType.equals(NO_OP)) {
            mClassifier = TextClassifier.NO_OP;
        } else {
            mClassifier = TextClassifierService.getDefaultTextClassifierImplementation(
                    InstrumentationRegistry.getTargetContext());
        }
    }

    /** Destroys the classifier created in {@link #setup()}. */
    @After
    public void tearDown() {
        // @After also runs when @Before fails; avoid a secondary NPE hiding the real failure.
        if (mClassifier != null) {
            mClassifier.destroy();
        }
    }

    @Test
    public void testTextClassifierDestroy() {
        mClassifier.destroy();
        // isDestroyed() is only asserted for the session classifier.
        if (mTextClassifierType.equals(SESSION)) {
            assertTrue(mClassifier.isDestroyed());
        }
    }

    @Test
    public void testGetMaxGenerateLinksTextLength() {
        // TODO(b/143249163): Verify the value get from TextClassificationConstants
        assertTrue(mClassifier.getMaxGenerateLinksTextLength() >= 0);
    }

    @Test
    public void testSmartSelection() {
        assertValidResult(mClassifier.suggestSelection(TEXT_SELECTION_REQUEST));
    }

    @Test
    public void testSuggestSelectionWith4Param() {
        assertValidResult(mClassifier.suggestSelection(TEXT, START, END, LOCALES));
    }

    @Test
    public void testClassifyText() {
        assertValidResult(mClassifier.classifyText(TEXT_CLASSIFICATION_REQUEST));
    }

    @Test
    public void testClassifyTextWith4Param() {
        assertValidResult(mClassifier.classifyText(TEXT, START, END, LOCALES));
    }

    @Test
    public void testGenerateLinks() {
        assertValidResult(mClassifier.generateLinks(new TextLinks.Request.Builder(TEXT).build()));
    }

    @Test
    public void testSuggestConversationActions() {
        ConversationActions conversationActions =
                mClassifier.suggestConversationActions(CONVERSATION_ACTIONS_REQUEST);

        assertValidResult(conversationActions);
    }

    @Test
    public void testLanguageDetection() {
        assertValidResult(mClassifier.detectLanguage(TEXT_LANGUAGE_REQUEST));
    }

    @Test(expected = RuntimeException.class)
    public void testLanguageDetection_nullRequest() {
        assertValidResult(mClassifier.detectLanguage(null));
    }

    @Test
    public void testOnSelectionEvent() {
        // Doesn't crash.
        mClassifier.onSelectionEvent(
                SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_MANUAL, 0));
    }

    @Test
    public void testOnTextClassifierEvent() {
        // Doesn't crash.
        mClassifier.onTextClassifierEvent(
                new TextClassifierEvent.ConversationActionsEvent.Builder(
                        TextClassifierEvent.TYPE_SMART_ACTION)
                        .build());
    }

    @Test
    public void testResolveEntityListModifications_only_hints() {
        TextClassifier.EntityConfig entityConfig = TextClassifier.EntityConfig.createWithHints(
                Arrays.asList("some_hint"));
        assertEquals(1, entityConfig.getHints().size());
        assertTrue(entityConfig.getHints().contains("some_hint"));
        // Compare as sets so the check does not depend on the concrete Collection returned.
        assertEquals(new HashSet<>(Arrays.asList("foo", "bar")),
                new HashSet<>(entityConfig.resolveEntityListModifications(
                        Arrays.asList("foo", "bar"))));
    }

    @Test
    public void testResolveEntityListModifications_include_exclude() {
        TextClassifier.EntityConfig entityConfig = TextClassifier.EntityConfig.create(
                Arrays.asList("some_hint"),
                Arrays.asList("a", "b", "c"),
                Arrays.asList("b", "d", "x"));
        assertEquals(1, entityConfig.getHints().size());
        assertTrue(entityConfig.getHints().contains("some_hint"));
        assertEquals(new HashSet<>(Arrays.asList("a", "c", "w")),
                new HashSet<>(entityConfig.resolveEntityListModifications(
                        Arrays.asList("c", "w", "x"))));
    }

    @Test
    public void testResolveEntityListModifications_explicit() {
        TextClassifier.EntityConfig entityConfig =
                TextClassifier.EntityConfig.createWithExplicitEntityList(Arrays.asList("a", "b"));
        assertThat(entityConfig.getHints()).isEmpty();
        // Compare as sets so the check does not depend on the concrete Collection returned.
        assertEquals(new HashSet<>(Arrays.asList("a", "b")),
                new HashSet<>(entityConfig.resolveEntityListModifications(
                        Arrays.asList("w", "x"))));
    }

    @Test
    public void testEntityConfig_full() {
        TextClassifier.EntityConfig entityConfig =
                new TextClassifier.EntityConfig.Builder()
                        .setIncludedTypes(
                                Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
                        .setExcludedTypes(
                                Collections.singletonList(ConversationAction.TYPE_CALL_PHONE))
                        .build();

        TextClassifier.EntityConfig recovered =
                parcelizeDeparcelize(entityConfig, TextClassifier.EntityConfig.CREATOR);

        assertFullEntityConfig(entityConfig);
        assertFullEntityConfig(recovered);
    }

    @Test
    public void testEntityConfig_full_notIncludeTypesFromTextClassifier() {
        TextClassifier.EntityConfig entityConfig =
                new TextClassifier.EntityConfig.Builder()
                        .includeTypesFromTextClassifier(false)
                        .setIncludedTypes(
                                Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
                        .setExcludedTypes(
                                Collections.singletonList(ConversationAction.TYPE_CALL_PHONE))
                        .build();

        TextClassifier.EntityConfig recovered =
                parcelizeDeparcelize(entityConfig, TextClassifier.EntityConfig.CREATOR);

        assertFullEntityConfig_notIncludeTypesFromTextClassifier(entityConfig);
        assertFullEntityConfig_notIncludeTypesFromTextClassifier(recovered);
    }

    @Test
    public void testEntityConfig_minimal() {
        TextClassifier.EntityConfig entityConfig =
                new TextClassifier.EntityConfig.Builder().build();

        TextClassifier.EntityConfig recovered =
                parcelizeDeparcelize(entityConfig, TextClassifier.EntityConfig.CREATOR);

        assertMinimalEntityConfig(entityConfig);
        assertMinimalEntityConfig(recovered);
    }

    /**
     * Checks that a selection is non-null, has a non-empty range starting at a non-negative
     * index, and that every entity has a confidence score in [0, 1]. On S and above it also
     * checks that no {@link TextClassification} is attached.
     */
    private static void assertValidResult(TextSelection selection) {
        assertNotNull(selection);
        assertTrue(selection.getSelectionStartIndex() >= 0);
        assertTrue(selection.getSelectionEndIndex() > selection.getSelectionStartIndex());
        assertTrue(selection.getEntityCount() >= 0);
        for (int i = 0; i < selection.getEntityCount(); i++) {
            final String entity = selection.getEntity(i);
            assertNotNull(entity);
            final float confidenceScore = selection.getConfidenceScore(entity);
            assertTrue(confidenceScore >= 0);
            assertTrue(confidenceScore <= 1);
        }
        if (BuildCompat.isAtLeastS()) {
            assertThat(selection.getTextClassification()).isNull();
        }
    }

    /**
     * Checks that a classification is non-null, every entity has a confidence score in [0, 1],
     * and the action list is non-null.
     */
    private static void assertValidResult(TextClassification classification) {
        assertNotNull(classification);
        assertTrue(classification.getEntityCount() >= 0);
        for (int i = 0; i < classification.getEntityCount(); i++) {
            final String entity = classification.getEntity(i);
            assertNotNull(entity);
            final float confidenceScore = classification.getConfidenceScore(entity);
            assertTrue(confidenceScore >= 0);
            assertTrue(confidenceScore <= 1);
        }
        assertNotNull(classification.getActions());
    }

    /**
     * Checks that every link has at least one entity, a range with
     * {@code 0 <= start <= end}, and entity confidence scores in [0, 1].
     */
    private static void assertValidResult(TextLinks links) {
        assertNotNull(links);
        for (TextLinks.TextLink link : links.getLinks()) {
            assertTrue(link.getEntityCount() > 0);
            assertTrue(link.getStart() >= 0);
            assertTrue(link.getStart() <= link.getEnd());
            for (int i = 0; i < link.getEntityCount(); i++) {
                String entityType = link.getEntity(i);
                assertNotNull(entityType);
                final float confidenceScore = link.getConfidenceScore(entityType);
                assertTrue(confidenceScore >= 0);
                assertTrue(confidenceScore <= 1);
            }
        }
    }

    /**
     * Checks that a language result is non-null with non-null extras, and that every locale
     * hypothesis is non-null with a confidence score in [0, 1].
     */
    private static void assertValidResult(TextLanguage language) {
        assertNotNull(language);
        assertNotNull(language.getExtras());
        assertTrue(language.getLocaleHypothesisCount() >= 0);
        for (int i = 0; i < language.getLocaleHypothesisCount(); i++) {
            final ULocale locale = language.getLocale(i);
            assertNotNull(locale);
            final float confidenceScore = language.getConfidenceScore(locale);
            assertTrue(confidenceScore >= 0);
            assertTrue(confidenceScore <= 1);
        }
    }

    /**
     * Checks that the result and its action list are non-null, and that every action has a
     * type and a confidence score in [0, 1].
     */
    private static void assertValidResult(ConversationActions conversationActions) {
        assertNotNull(conversationActions);
        List<ConversationAction> conversationActionsList =
                conversationActions.getConversationActions();
        assertNotNull(conversationActionsList);
        for (ConversationAction conversationAction : conversationActionsList) {
            assertThat(conversationAction.getType()).isNotNull();
            assertThat(conversationAction.getConfidenceScore()).isIn(Range.closed(0f, 1.0f));
        }
    }

    /**
     * Verifies a config that includes OPEN_URL, excludes CALL_PHONE and ignores types supplied
     * by the text classifier: only OPEN_URL is ever resolved.
     */
    private static void assertFullEntityConfig_notIncludeTypesFromTextClassifier(
            TextClassifier.EntityConfig typeConfig) {
        List<String> extraTypesFromTextClassifier = Arrays.asList(
                ConversationAction.TYPE_CALL_PHONE,
                ConversationAction.TYPE_CREATE_REMINDER);

        Collection<String> resolvedTypes =
                typeConfig.resolveEntityListModifications(extraTypesFromTextClassifier);

        assertThat(typeConfig.shouldIncludeTypesFromTextClassifier()).isFalse();
        assertThat(typeConfig.resolveEntityListModifications(Collections.emptyList()))
                .containsExactly(ConversationAction.TYPE_OPEN_URL);
        assertThat(resolvedTypes).containsExactly(ConversationAction.TYPE_OPEN_URL);
    }

    /**
     * Verifies a config that includes OPEN_URL and excludes CALL_PHONE while still accepting
     * classifier-supplied types: CREATE_REMINDER passes through, CALL_PHONE is dropped.
     */
    private static void assertFullEntityConfig(TextClassifier.EntityConfig typeConfig) {
        List<String> extraTypesFromTextClassifier = Arrays.asList(
                ConversationAction.TYPE_CALL_PHONE,
                ConversationAction.TYPE_CREATE_REMINDER);

        Collection<String> resolvedTypes =
                typeConfig.resolveEntityListModifications(extraTypesFromTextClassifier);

        assertThat(typeConfig.shouldIncludeTypesFromTextClassifier()).isTrue();
        assertThat(typeConfig.resolveEntityListModifications(Collections.emptyList()))
                .containsExactly(ConversationAction.TYPE_OPEN_URL);
        assertThat(resolvedTypes).containsExactly(
                ConversationAction.TYPE_OPEN_URL, ConversationAction.TYPE_CREATE_REMINDER);
    }

    /** Verifies a default-built config: it adds nothing and passes input types through. */
    private static void assertMinimalEntityConfig(TextClassifier.EntityConfig typeConfig) {
        assertThat(typeConfig.shouldIncludeTypesFromTextClassifier()).isTrue();
        assertThat(typeConfig.resolveEntityListModifications(Collections.emptyList())).isEmpty();
        assertThat(typeConfig.resolveEntityListModifications(
                Collections.singletonList(ConversationAction.TYPE_OPEN_URL))).containsExactly(
                ConversationAction.TYPE_OPEN_URL);
    }

    /**
     * Writes {@code parcelable} to a fresh {@link Parcel}, rewinds it, and returns the copy
     * rebuilt by {@code creator}. The parcel is always recycled, even if (de)serialization
     * throws, so obtained parcels are not leaked.
     */
    private static <T extends Parcelable> T parcelizeDeparcelize(
            T parcelable, Parcelable.Creator<T> creator) {
        Parcel parcel = Parcel.obtain();
        try {
            parcelable.writeToParcel(parcel, 0);
            parcel.setDataPosition(0);
            return creator.createFromParcel(parcel);
        } finally {
            parcel.recycle();
        }
    }
}