1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.view.textclassifier.cts;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 
21 import static org.junit.Assert.assertEquals;
22 import static org.junit.Assert.assertNotNull;
23 import static org.junit.Assert.assertTrue;
24 
25 import android.icu.util.ULocale;
26 import android.os.Bundle;
27 import android.os.LocaleList;
28 import android.os.Parcel;
29 import android.os.Parcelable;
30 import android.service.textclassifier.TextClassifierService;
31 import android.view.textclassifier.ConversationAction;
32 import android.view.textclassifier.ConversationActions;
33 import android.view.textclassifier.SelectionEvent;
34 import android.view.textclassifier.TextClassification;
35 import android.view.textclassifier.TextClassificationContext;
36 import android.view.textclassifier.TextClassificationManager;
37 import android.view.textclassifier.TextClassifier;
38 import android.view.textclassifier.TextClassifierEvent;
39 import android.view.textclassifier.TextLanguage;
40 import android.view.textclassifier.TextLinks;
41 import android.view.textclassifier.TextSelection;
42 
43 import androidx.core.os.BuildCompat;
44 import androidx.test.InstrumentationRegistry;
45 import androidx.test.filters.SmallTest;
46 
47 import com.google.common.collect.Range;
48 
49 import org.junit.After;
50 import org.junit.Before;
51 import org.junit.Test;
52 import org.junit.runner.RunWith;
53 import org.junit.runners.Parameterized;
54 
55 import java.util.Arrays;
56 import java.util.Collection;
57 import java.util.Collections;
58 import java.util.HashSet;
59 import java.util.List;
60 
61 @SmallTest
62 @RunWith(Parameterized.class)
63 public class TextClassifierTest {
64     private static final String BUNDLE_KEY = "key";
65     private static final String BUNDLE_VALUE = "value";
66     private static final Bundle BUNDLE = new Bundle();
67     static {
BUNDLE.putString(BUNDLE_KEY, BUNDLE_VALUE)68         BUNDLE.putString(BUNDLE_KEY, BUNDLE_VALUE);
69     }
70     private static final LocaleList LOCALES = LocaleList.forLanguageTags("en");
71     private static final int START = 1;
72     private static final int END = 3;
73     // This text has lots of things that are probably entities in many cases.
74     private static final String TEXT = "An email address is test@example.com. A phone number"
75             + " might be +12122537077. Somebody lives at 123 Main Street, Mountain View, CA,"
76             + " and there's good stuff at https://www.android.com :)";
77     private static final TextSelection.Request TEXT_SELECTION_REQUEST =
78             new TextSelection.Request.Builder(TEXT, START, END)
79                     .setDefaultLocales(LOCALES)
80                     .build();
81     private static final TextClassification.Request TEXT_CLASSIFICATION_REQUEST =
82             new TextClassification.Request.Builder(TEXT, START, END)
83                     .setDefaultLocales(LOCALES)
84                     .build();
85     private static final TextLanguage.Request TEXT_LANGUAGE_REQUEST =
86             new TextLanguage.Request.Builder(TEXT)
87                     .setExtras(BUNDLE)
88                     .build();
89     private static final ConversationActions.Message FIRST_MESSAGE =
90             new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_SELF)
91                     .setText(TEXT)
92                     .build();
93     private static final ConversationActions.Message SECOND_MESSAGE =
94             new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
95                     .setText(TEXT)
96                     .build();
97     private static final ConversationActions.Request CONVERSATION_ACTIONS_REQUEST =
98             new ConversationActions.Request.Builder(
99                     Arrays.asList(FIRST_MESSAGE, SECOND_MESSAGE)).build();
100 
101     private static final String CURRENT = "current";
102     private static final String SESSION = "session";
103     private static final String DEFAULT = "default";
104     private static final String NO_OP = "no_op";
105 
106     @Parameterized.Parameters(name = "{0}")
textClassifierTypes()107     public static Iterable<Object> textClassifierTypes() {
108         return Arrays.asList(CURRENT, SESSION, DEFAULT, NO_OP);
109     }
110 
111     @Parameterized.Parameter
112     public String mTextClassifierType;
113 
114     private TextClassifier mClassifier;
115 
116     @Before
setup()117     public void setup() {
118         TextClassificationManager manager = InstrumentationRegistry.getTargetContext()
119                 .getSystemService(TextClassificationManager.class);
120         manager.setTextClassifier(null); // Resets the classifier.
121         if (mTextClassifierType.equals(CURRENT)) {
122             mClassifier = manager.getTextClassifier();
123         } else if (mTextClassifierType.equals(SESSION)) {
124             mClassifier = manager.createTextClassificationSession(
125                     new TextClassificationContext.Builder(
126                             InstrumentationRegistry.getTargetContext().getPackageName(),
127                             TextClassifier.WIDGET_TYPE_TEXTVIEW)
128                             .build());
129         } else if (mTextClassifierType.equals(NO_OP)) {
130             mClassifier = TextClassifier.NO_OP;
131         } else {
132             mClassifier = TextClassifierService.getDefaultTextClassifierImplementation(
133                     InstrumentationRegistry.getTargetContext());
134         }
135     }
136 
137     @After
tearDown()138     public void tearDown() {
139         mClassifier.destroy();
140     }
141 
142     @Test
testTextClassifierDestroy()143     public void testTextClassifierDestroy() {
144         mClassifier.destroy();
145         if (mTextClassifierType.equals(SESSION)) {
146             assertEquals(true, mClassifier.isDestroyed());
147         }
148     }
149 
150     @Test
testGetMaxGenerateLinksTextLength()151     public void testGetMaxGenerateLinksTextLength() {
152         // TODO(b/143249163): Verify the value get from TextClassificationConstants
153         assertTrue(mClassifier.getMaxGenerateLinksTextLength() >= 0);
154     }
155 
156     @Test
testSmartSelection()157     public void testSmartSelection() {
158         assertValidResult(mClassifier.suggestSelection(TEXT_SELECTION_REQUEST));
159     }
160 
161     @Test
testSuggestSelectionWith4Param()162     public void testSuggestSelectionWith4Param() {
163         assertValidResult(mClassifier.suggestSelection(TEXT, START, END, LOCALES));
164     }
165 
166     @Test
testClassifyText()167     public void testClassifyText() {
168         assertValidResult(mClassifier.classifyText(TEXT_CLASSIFICATION_REQUEST));
169     }
170 
171     @Test
testClassifyTextWith4Param()172     public void testClassifyTextWith4Param() {
173         assertValidResult(mClassifier.classifyText(TEXT, START, END, LOCALES));
174     }
175 
176     @Test
testGenerateLinks()177     public void testGenerateLinks() {
178         assertValidResult(mClassifier.generateLinks(new TextLinks.Request.Builder(TEXT).build()));
179     }
180 
181     @Test
testSuggestConversationActions()182     public void testSuggestConversationActions() {
183         ConversationActions conversationActions =
184                 mClassifier.suggestConversationActions(CONVERSATION_ACTIONS_REQUEST);
185 
186         assertValidResult(conversationActions);
187     }
188 
189     @Test
testLanguageDetection()190     public void testLanguageDetection() {
191         assertValidResult(mClassifier.detectLanguage(TEXT_LANGUAGE_REQUEST));
192     }
193 
194     @Test(expected = RuntimeException.class)
testLanguageDetection_nullRequest()195     public void testLanguageDetection_nullRequest() {
196         assertValidResult(mClassifier.detectLanguage(null));
197     }
198 
199     @Test
testOnSelectionEvent()200     public void testOnSelectionEvent() {
201         // Doesn't crash.
202         mClassifier.onSelectionEvent(
203                 SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_MANUAL, 0));
204     }
205 
206     @Test
testOnTextClassifierEvent()207     public void testOnTextClassifierEvent() {
208         // Doesn't crash.
209         mClassifier.onTextClassifierEvent(
210                 new TextClassifierEvent.ConversationActionsEvent.Builder(
211                         TextClassifierEvent.TYPE_SMART_ACTION)
212                         .build());
213     }
214 
215     @Test
testResolveEntityListModifications_only_hints()216     public void testResolveEntityListModifications_only_hints() {
217         TextClassifier.EntityConfig entityConfig = TextClassifier.EntityConfig.createWithHints(
218                 Arrays.asList("some_hint"));
219         assertEquals(1, entityConfig.getHints().size());
220         assertTrue(entityConfig.getHints().contains("some_hint"));
221         assertEquals(new HashSet<String>(Arrays.asList("foo", "bar")),
222                 entityConfig.resolveEntityListModifications(Arrays.asList("foo", "bar")));
223     }
224 
225     @Test
testResolveEntityListModifications_include_exclude()226     public void testResolveEntityListModifications_include_exclude() {
227         TextClassifier.EntityConfig entityConfig = TextClassifier.EntityConfig.create(
228                 Arrays.asList("some_hint"),
229                 Arrays.asList("a", "b", "c"),
230                 Arrays.asList("b", "d", "x"));
231         assertEquals(1, entityConfig.getHints().size());
232         assertTrue(entityConfig.getHints().contains("some_hint"));
233         assertEquals(new HashSet(Arrays.asList("a", "c", "w")),
234                 new HashSet(entityConfig.resolveEntityListModifications(
235                         Arrays.asList("c", "w", "x"))));
236     }
237 
238     @Test
testResolveEntityListModifications_explicit()239     public void testResolveEntityListModifications_explicit() {
240         TextClassifier.EntityConfig entityConfig =
241                 TextClassifier.EntityConfig.createWithExplicitEntityList(Arrays.asList("a", "b"));
242         assertEquals(Collections.EMPTY_LIST, entityConfig.getHints());
243         assertEquals(new HashSet<String>(Arrays.asList("a", "b")),
244                 entityConfig.resolveEntityListModifications(Arrays.asList("w", "x")));
245     }
246 
247     @Test
testEntityConfig_full()248     public void testEntityConfig_full() {
249         TextClassifier.EntityConfig entityConfig =
250                 new TextClassifier.EntityConfig.Builder()
251                         .setIncludedTypes(
252                                 Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
253                         .setExcludedTypes(
254                                 Collections.singletonList(ConversationAction.TYPE_CALL_PHONE))
255                         .build();
256 
257         TextClassifier.EntityConfig recovered =
258                 parcelizeDeparcelize(entityConfig, TextClassifier.EntityConfig.CREATOR);
259 
260         assertFullEntityConfig(entityConfig);
261         assertFullEntityConfig(recovered);
262     }
263 
264     @Test
testEntityConfig_full_notIncludeTypesFromTextClassifier()265     public void testEntityConfig_full_notIncludeTypesFromTextClassifier() {
266         TextClassifier.EntityConfig entityConfig =
267                 new TextClassifier.EntityConfig.Builder()
268                         .includeTypesFromTextClassifier(false)
269                         .setIncludedTypes(
270                                 Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
271                         .setExcludedTypes(
272                                 Collections.singletonList(ConversationAction.TYPE_CALL_PHONE))
273                         .build();
274 
275         TextClassifier.EntityConfig recovered =
276                 parcelizeDeparcelize(entityConfig, TextClassifier.EntityConfig.CREATOR);
277 
278         assertFullEntityConfig_notIncludeTypesFromTextClassifier(entityConfig);
279         assertFullEntityConfig_notIncludeTypesFromTextClassifier(recovered);
280     }
281 
282     @Test
testEntityConfig_minimal()283     public void testEntityConfig_minimal() {
284         TextClassifier.EntityConfig entityConfig =
285                 new TextClassifier.EntityConfig.Builder().build();
286 
287         TextClassifier.EntityConfig recovered =
288                 parcelizeDeparcelize(entityConfig, TextClassifier.EntityConfig.CREATOR);
289 
290         assertMinimalEntityConfig(entityConfig);
291         assertMinimalEntityConfig(recovered);
292     }
293 
assertValidResult(TextSelection selection)294     private static void assertValidResult(TextSelection selection) {
295         assertNotNull(selection);
296         assertTrue(selection.getSelectionStartIndex() >= 0);
297         assertTrue(selection.getSelectionEndIndex() > selection.getSelectionStartIndex());
298         assertTrue(selection.getEntityCount() >= 0);
299         for (int i = 0; i < selection.getEntityCount(); i++) {
300             final String entity = selection.getEntity(i);
301             assertNotNull(entity);
302             final float confidenceScore = selection.getConfidenceScore(entity);
303             assertTrue(confidenceScore >= 0);
304             assertTrue(confidenceScore <= 1);
305         }
306         if (BuildCompat.isAtLeastS()) {
307             assertThat(selection.getTextClassification()).isNull();
308         }
309     }
310 
assertValidResult(TextClassification classification)311     private static void assertValidResult(TextClassification classification) {
312         assertNotNull(classification);
313         assertTrue(classification.getEntityCount() >= 0);
314         for (int i = 0; i < classification.getEntityCount(); i++) {
315             final String entity = classification.getEntity(i);
316             assertNotNull(entity);
317             final float confidenceScore = classification.getConfidenceScore(entity);
318             assertTrue(confidenceScore >= 0);
319             assertTrue(confidenceScore <= 1);
320         }
321         assertNotNull(classification.getActions());
322     }
323 
assertValidResult(TextLinks links)324     private static void assertValidResult(TextLinks links) {
325         assertNotNull(links);
326         for (TextLinks.TextLink link : links.getLinks()) {
327             assertTrue(link.getEntityCount() > 0);
328             assertTrue(link.getStart() >= 0);
329             assertTrue(link.getStart() <= link.getEnd());
330             for (int i = 0; i < link.getEntityCount(); i++) {
331                 String entityType = link.getEntity(i);
332                 assertNotNull(entityType);
333                 final float confidenceScore = link.getConfidenceScore(entityType);
334                 assertTrue(confidenceScore >= 0);
335                 assertTrue(confidenceScore <= 1);
336             }
337         }
338     }
339 
assertValidResult(TextLanguage language)340     private static void assertValidResult(TextLanguage language) {
341         assertNotNull(language);
342         assertNotNull(language.getExtras());
343         assertTrue(language.getLocaleHypothesisCount() >= 0);
344         for (int i = 0; i < language.getLocaleHypothesisCount(); i++) {
345             final ULocale locale = language.getLocale(i);
346             assertNotNull(locale);
347             final float confidenceScore = language.getConfidenceScore(locale);
348             assertTrue(confidenceScore >= 0);
349             assertTrue(confidenceScore <= 1);
350         }
351     }
352 
assertValidResult(ConversationActions conversationActions)353     private static void assertValidResult(ConversationActions conversationActions) {
354         assertNotNull(conversationActions);
355         List<ConversationAction> conversationActionsList =
356                 conversationActions.getConversationActions();
357         assertNotNull(conversationActionsList);
358         for (ConversationAction conversationAction : conversationActionsList) {
359             assertThat(conversationAction.getType()).isNotNull();
360             assertThat(conversationAction.getConfidenceScore()).isIn(Range.closed(0f, 1.0f));
361         }
362     }
363 
assertFullEntityConfig_notIncludeTypesFromTextClassifier( TextClassifier.EntityConfig typeConfig)364     private static void assertFullEntityConfig_notIncludeTypesFromTextClassifier(
365             TextClassifier.EntityConfig typeConfig) {
366         List<String> extraTypesFromTextClassifier = Arrays.asList(
367                 ConversationAction.TYPE_CALL_PHONE,
368                 ConversationAction.TYPE_CREATE_REMINDER);
369 
370         Collection<String> resolvedTypes =
371                 typeConfig.resolveEntityListModifications(extraTypesFromTextClassifier);
372 
373         assertThat(typeConfig.shouldIncludeTypesFromTextClassifier()).isFalse();
374         assertThat(typeConfig.resolveEntityListModifications(Collections.emptyList()))
375                 .containsExactly(ConversationAction.TYPE_OPEN_URL);
376         assertThat(resolvedTypes).containsExactly(ConversationAction.TYPE_OPEN_URL);
377     }
378 
assertFullEntityConfig(TextClassifier.EntityConfig typeConfig)379     private static void assertFullEntityConfig(TextClassifier.EntityConfig typeConfig) {
380         List<String> extraTypesFromTextClassifier = Arrays.asList(
381                 ConversationAction.TYPE_CALL_PHONE,
382                 ConversationAction.TYPE_CREATE_REMINDER);
383 
384         Collection<String> resolvedTypes =
385                 typeConfig.resolveEntityListModifications(extraTypesFromTextClassifier);
386 
387         assertThat(typeConfig.shouldIncludeTypesFromTextClassifier()).isTrue();
388         assertThat(typeConfig.resolveEntityListModifications(Collections.emptyList()))
389                 .containsExactly(ConversationAction.TYPE_OPEN_URL);
390         assertThat(resolvedTypes).containsExactly(
391                 ConversationAction.TYPE_OPEN_URL, ConversationAction.TYPE_CREATE_REMINDER);
392     }
393 
assertMinimalEntityConfig(TextClassifier.EntityConfig typeConfig)394     private static void assertMinimalEntityConfig(TextClassifier.EntityConfig typeConfig) {
395         assertThat(typeConfig.shouldIncludeTypesFromTextClassifier()).isTrue();
396         assertThat(typeConfig.resolveEntityListModifications(Collections.emptyList())).isEmpty();
397         assertThat(typeConfig.resolveEntityListModifications(
398                 Collections.singletonList(ConversationAction.TYPE_OPEN_URL))).containsExactly(
399                 ConversationAction.TYPE_OPEN_URL);
400     }
401 
parcelizeDeparcelize( T parcelable, Parcelable.Creator<T> creator)402     private static <T extends Parcelable> T parcelizeDeparcelize(
403             T parcelable, Parcelable.Creator<T> creator) {
404         Parcel parcel = Parcel.obtain();
405         parcelable.writeToParcel(parcel, 0);
406         parcel.setDataPosition(0);
407         return creator.createFromParcel(parcel);
408     }
409 }
410