1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.android.textclassifier;
18 
19 import java.util.Collection;
20 import java.util.concurrent.atomic.AtomicBoolean;
21 
22 /**
23  * Java wrapper for Annotator native library interface. This library is used for detecting entities
24  * in text.
25  *
26  * @hide
27  */
28 public final class AnnotatorModel implements AutoCloseable {
29   private final AtomicBoolean isClosed = new AtomicBoolean(false);
30 
31   static {
32     System.loadLibrary("textclassifier");
33   }
34 
35   // Keep these in sync with the constants defined in AOSP.
36   static final String TYPE_UNKNOWN = "";
37   static final String TYPE_OTHER = "other";
38   static final String TYPE_EMAIL = "email";
39   static final String TYPE_PHONE = "phone";
40   static final String TYPE_ADDRESS = "address";
41   static final String TYPE_URL = "url";
42   static final String TYPE_DATE = "date";
43   static final String TYPE_DATE_TIME = "datetime";
44   static final String TYPE_FLIGHT_NUMBER = "flight";
45 
46   private long annotatorPtr;
47 
48   /** Enumeration for specifying the usecase of the annotations. */
49   public static enum AnnotationUsecase {
50     /** Results are optimized for Smart{Select,Share,Linkify}. */
51     SMART(0),
52 
53     /**
54      * Results are optimized for using TextClassifier as an infrastructure that annotates as much as
55      * possible.
56      */
57     RAW(1);
58 
59     private final int value;
60 
AnnotationUsecase(int value)61     AnnotationUsecase(int value) {
62       this.value = value;
63     }
64 
getValue()65     public int getValue() {
66       return value;
67     }
68   };
69 
70   /**
71    * Creates a new instance of SmartSelect predictor, using the provided model image, given as a
72    * file descriptor.
73    */
AnnotatorModel(int fileDescriptor)74   public AnnotatorModel(int fileDescriptor) {
75     annotatorPtr = nativeNewAnnotator(fileDescriptor);
76     if (annotatorPtr == 0L) {
77       throw new IllegalArgumentException("Couldn't initialize TC from file descriptor.");
78     }
79   }
80 
81   /**
82    * Creates a new instance of SmartSelect predictor, using the provided model image, given as a
83    * file path.
84    */
AnnotatorModel(String path)85   public AnnotatorModel(String path) {
86     annotatorPtr = nativeNewAnnotatorFromPath(path);
87     if (annotatorPtr == 0L) {
88       throw new IllegalArgumentException("Couldn't initialize TC from given file.");
89     }
90   }
91 
92   /** Initializes the knowledge engine, passing the given serialized config to it. */
initializeKnowledgeEngine(byte[] serializedConfig)93   public void initializeKnowledgeEngine(byte[] serializedConfig) {
94     if (!nativeInitializeKnowledgeEngine(annotatorPtr, serializedConfig)) {
95       throw new IllegalArgumentException("Couldn't initialize the KG engine");
96     }
97   }
98 
99   /** Initializes the contact engine, passing the given serialized config to it. */
initializeContactEngine(byte[] serializedConfig)100   public void initializeContactEngine(byte[] serializedConfig) {
101     if (!nativeInitializeContactEngine(annotatorPtr, serializedConfig)) {
102       throw new IllegalArgumentException("Couldn't initialize the contact engine");
103     }
104   }
105 
106   /** Initializes the installed app engine, passing the given serialized config to it. */
initializeInstalledAppEngine(byte[] serializedConfig)107   public void initializeInstalledAppEngine(byte[] serializedConfig) {
108     if (!nativeInitializeInstalledAppEngine(annotatorPtr, serializedConfig)) {
109       throw new IllegalArgumentException("Couldn't initialize the installed app engine");
110     }
111   }
112 
113   /**
114    * Given a string context and current selection, computes the selection suggestion.
115    *
116    * <p>The begin and end are character indices into the context UTF8 string. selectionBegin is the
117    * character index where the selection begins, and selectionEnd is the index of one character past
118    * the selection span.
119    *
120    * <p>The return value is an array of two ints: suggested selection beginning and end, with the
121    * same semantics as the input selectionBeginning and selectionEnd.
122    */
suggestSelection( String context, int selectionBegin, int selectionEnd, SelectionOptions options)123   public int[] suggestSelection(
124       String context, int selectionBegin, int selectionEnd, SelectionOptions options) {
125     return nativeSuggestSelection(annotatorPtr, context, selectionBegin, selectionEnd, options);
126   }
127 
128   /**
129    * Given a string context and current selection, classifies the type of the selected text.
130    *
131    * <p>The begin and end params are character indices in the context string.
132    *
133    * <p>Returns an array of ClassificationResult objects with the probability scores for different
134    * collections.
135    */
classifyText( String context, int selectionBegin, int selectionEnd, ClassificationOptions options)136   public ClassificationResult[] classifyText(
137       String context, int selectionBegin, int selectionEnd, ClassificationOptions options) {
138     return classifyText(
139         context,
140         selectionBegin,
141         selectionEnd,
142         options,
143         /*appContext=*/ null,
144         /*deviceLocales=*/ null);
145   }
146 
classifyText( String context, int selectionBegin, int selectionEnd, ClassificationOptions options, Object appContext, String deviceLocales)147   public ClassificationResult[] classifyText(
148       String context,
149       int selectionBegin,
150       int selectionEnd,
151       ClassificationOptions options,
152       Object appContext,
153       String deviceLocales) {
154     return nativeClassifyText(
155         annotatorPtr, context, selectionBegin, selectionEnd, options, appContext, deviceLocales);
156   }
157 
158   /**
159    * Annotates given input text. The annotations should cover the whole input context except for
160    * whitespaces, and are sorted by their position in the context string.
161    */
annotate(String text, AnnotationOptions options)162   public AnnotatedSpan[] annotate(String text, AnnotationOptions options) {
163     return nativeAnnotate(annotatorPtr, text, options);
164   }
165 
166   /**
167    * Looks up a knowledge entity by its identifier. Returns null if the entity is not found or on
168    * error.
169    */
lookUpKnowledgeEntity(String id)170   public byte[] lookUpKnowledgeEntity(String id) {
171     return nativeLookUpKnowledgeEntity(annotatorPtr, id);
172   }
173 
174   /** Frees up the allocated memory. */
175   @Override
close()176   public void close() {
177     if (isClosed.compareAndSet(false, true)) {
178       nativeCloseAnnotator(annotatorPtr);
179       annotatorPtr = 0L;
180     }
181   }
182 
183   @Override
finalize()184   protected void finalize() throws Throwable {
185     try {
186       close();
187     } finally {
188       super.finalize();
189     }
190   }
191 
192   /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
getLocales(int fd)193   public static String getLocales(int fd) {
194     return nativeGetLocales(fd);
195   }
196 
197   /** Returns the version of the model. */
getVersion(int fd)198   public static int getVersion(int fd) {
199     return nativeGetVersion(fd);
200   }
201 
202   /** Returns the name of the model. */
getName(int fd)203   public static String getName(int fd) {
204     return nativeGetName(fd);
205   }
206 
207   /** Information about a parsed time/date. */
208   public static final class DatetimeResult {
209 
210     public static final int GRANULARITY_YEAR = 0;
211     public static final int GRANULARITY_MONTH = 1;
212     public static final int GRANULARITY_WEEK = 2;
213     public static final int GRANULARITY_DAY = 3;
214     public static final int GRANULARITY_HOUR = 4;
215     public static final int GRANULARITY_MINUTE = 5;
216     public static final int GRANULARITY_SECOND = 6;
217 
218     private final long timeMsUtc;
219     private final int granularity;
220 
DatetimeResult(long timeMsUtc, int granularity)221     public DatetimeResult(long timeMsUtc, int granularity) {
222       this.timeMsUtc = timeMsUtc;
223       this.granularity = granularity;
224     }
225 
getTimeMsUtc()226     public long getTimeMsUtc() {
227       return timeMsUtc;
228     }
229 
getGranularity()230     public int getGranularity() {
231       return granularity;
232     }
233   }
234 
235   /** Classification result for classifyText method. */
236   public static final class ClassificationResult {
237     private final String collection;
238     private final float score;
239     private final DatetimeResult datetimeResult;
240     private final byte[] serializedKnowledgeResult;
241     private final String contactName;
242     private final String contactGivenName;
243     private final String contactNickname;
244     private final String contactEmailAddress;
245     private final String contactPhoneNumber;
246     private final String contactId;
247     private final String appName;
248     private final String appPackageName;
249     private final NamedVariant[] entityData;
250     private final byte[] serializedEntityData;
251     private final RemoteActionTemplate[] remoteActionTemplates;
252     private final long durationMs;
253     private final long numericValue;
254 
ClassificationResult( String collection, float score, DatetimeResult datetimeResult, byte[] serializedKnowledgeResult, String contactName, String contactGivenName, String contactNickname, String contactEmailAddress, String contactPhoneNumber, String contactId, String appName, String appPackageName, NamedVariant[] entityData, byte[] serializedEntityData, RemoteActionTemplate[] remoteActionTemplates, long durationMs, long numericValue)255     public ClassificationResult(
256         String collection,
257         float score,
258         DatetimeResult datetimeResult,
259         byte[] serializedKnowledgeResult,
260         String contactName,
261         String contactGivenName,
262         String contactNickname,
263         String contactEmailAddress,
264         String contactPhoneNumber,
265         String contactId,
266         String appName,
267         String appPackageName,
268         NamedVariant[] entityData,
269         byte[] serializedEntityData,
270         RemoteActionTemplate[] remoteActionTemplates,
271         long durationMs,
272         long numericValue) {
273       this.collection = collection;
274       this.score = score;
275       this.datetimeResult = datetimeResult;
276       this.serializedKnowledgeResult = serializedKnowledgeResult;
277       this.contactName = contactName;
278       this.contactGivenName = contactGivenName;
279       this.contactNickname = contactNickname;
280       this.contactEmailAddress = contactEmailAddress;
281       this.contactPhoneNumber = contactPhoneNumber;
282       this.contactId = contactId;
283       this.appName = appName;
284       this.appPackageName = appPackageName;
285       this.entityData = entityData;
286       this.serializedEntityData = serializedEntityData;
287       this.remoteActionTemplates = remoteActionTemplates;
288       this.durationMs = durationMs;
289       this.numericValue = numericValue;
290     }
291 
292     /** Returns the classified entity type. */
getCollection()293     public String getCollection() {
294       return collection;
295     }
296 
297     /** Confidence score between 0 and 1. */
getScore()298     public float getScore() {
299       return score;
300     }
301 
getDatetimeResult()302     public DatetimeResult getDatetimeResult() {
303       return datetimeResult;
304     }
305 
getSerializedKnowledgeResult()306     public byte[] getSerializedKnowledgeResult() {
307       return serializedKnowledgeResult;
308     }
309 
getContactName()310     public String getContactName() {
311       return contactName;
312     }
313 
getContactGivenName()314     public String getContactGivenName() {
315       return contactGivenName;
316     }
317 
getContactNickname()318     public String getContactNickname() {
319       return contactNickname;
320     }
321 
getContactEmailAddress()322     public String getContactEmailAddress() {
323       return contactEmailAddress;
324     }
325 
getContactPhoneNumber()326     public String getContactPhoneNumber() {
327       return contactPhoneNumber;
328     }
329 
getContactId()330     public String getContactId() {
331       return contactId;
332     }
333 
getAppName()334     public String getAppName() {
335       return appName;
336     }
337 
getAppPackageName()338     public String getAppPackageName() {
339       return appPackageName;
340     }
341 
getEntityData()342     public NamedVariant[] getEntityData() {
343       return entityData;
344     }
345 
getSerializedEntityData()346     public byte[] getSerializedEntityData() {
347       return serializedEntityData;
348     }
349 
getRemoteActionTemplates()350     public RemoteActionTemplate[] getRemoteActionTemplates() {
351       return remoteActionTemplates;
352     }
353 
getDurationMs()354     public long getDurationMs() {
355       return durationMs;
356     }
357 
getNumericValue()358     public long getNumericValue() {
359       return numericValue;
360     }
361   }
362 
363   /** Represents a result of Annotate call. */
364   public static final class AnnotatedSpan {
365     private final int startIndex;
366     private final int endIndex;
367     private final ClassificationResult[] classification;
368 
AnnotatedSpan(int startIndex, int endIndex, ClassificationResult[] classification)369     AnnotatedSpan(int startIndex, int endIndex, ClassificationResult[] classification) {
370       this.startIndex = startIndex;
371       this.endIndex = endIndex;
372       this.classification = classification;
373     }
374 
getStartIndex()375     public int getStartIndex() {
376       return startIndex;
377     }
378 
getEndIndex()379     public int getEndIndex() {
380       return endIndex;
381     }
382 
getClassification()383     public ClassificationResult[] getClassification() {
384       return classification;
385     }
386   }
387 
388   /** Represents options for the suggestSelection call. */
389   public static final class SelectionOptions {
390     private final String locales;
391     private final String detectedTextLanguageTags;
392     private final int annotationUsecase;
393 
SelectionOptions( String locales, String detectedTextLanguageTags, int annotationUsecase)394     public SelectionOptions(
395         String locales, String detectedTextLanguageTags, int annotationUsecase) {
396       this.locales = locales;
397       this.detectedTextLanguageTags = detectedTextLanguageTags;
398       this.annotationUsecase = annotationUsecase;
399     }
400 
SelectionOptions(String locales, String detectedTextLanguageTags)401     public SelectionOptions(String locales, String detectedTextLanguageTags) {
402       this(locales, detectedTextLanguageTags, AnnotationUsecase.SMART.getValue());
403     }
404 
getLocales()405     public String getLocales() {
406       return locales;
407     }
408 
409     /** Returns a comma separated list of BCP 47 language tags. */
getDetectedTextLanguageTags()410     public String getDetectedTextLanguageTags() {
411       return detectedTextLanguageTags;
412     }
413 
getAnnotationUsecase()414     public int getAnnotationUsecase() {
415       return annotationUsecase;
416     }
417   }
418 
419   /** Represents options for the classifyText call. */
420   public static final class ClassificationOptions {
421     private final long referenceTimeMsUtc;
422     private final String referenceTimezone;
423     private final String locales;
424     private final String detectedTextLanguageTags;
425     private final int annotationUsecase;
426 
ClassificationOptions( long referenceTimeMsUtc, String referenceTimezone, String locales, String detectedTextLanguageTags, int annotationUsecase)427     public ClassificationOptions(
428         long referenceTimeMsUtc,
429         String referenceTimezone,
430         String locales,
431         String detectedTextLanguageTags,
432         int annotationUsecase) {
433       this.referenceTimeMsUtc = referenceTimeMsUtc;
434       this.referenceTimezone = referenceTimezone;
435       this.locales = locales;
436       this.detectedTextLanguageTags = detectedTextLanguageTags;
437       this.annotationUsecase = annotationUsecase;
438     }
439 
ClassificationOptions( long referenceTimeMsUtc, String referenceTimezone, String locales, String detectedTextLanguageTags)440     public ClassificationOptions(
441         long referenceTimeMsUtc,
442         String referenceTimezone,
443         String locales,
444         String detectedTextLanguageTags) {
445       this(
446           referenceTimeMsUtc,
447           referenceTimezone,
448           locales,
449           detectedTextLanguageTags,
450           AnnotationUsecase.SMART.getValue());
451     }
452 
getReferenceTimeMsUtc()453     public long getReferenceTimeMsUtc() {
454       return referenceTimeMsUtc;
455     }
456 
getReferenceTimezone()457     public String getReferenceTimezone() {
458       return referenceTimezone;
459     }
460 
getLocale()461     public String getLocale() {
462       return locales;
463     }
464 
465     /** Returns a comma separated list of BCP 47 language tags. */
getDetectedTextLanguageTags()466     public String getDetectedTextLanguageTags() {
467       return detectedTextLanguageTags;
468     }
469 
getAnnotationUsecase()470     public int getAnnotationUsecase() {
471       return annotationUsecase;
472     }
473   }
474 
475   /** Represents options for the annotate call. */
476   public static final class AnnotationOptions {
477     private final long referenceTimeMsUtc;
478     private final String referenceTimezone;
479     private final String locales;
480     private final String detectedTextLanguageTags;
481     private final String[] entityTypes;
482     private final int annotationUsecase;
483     private final boolean isSerializedEntityDataEnabled;
484 
AnnotationOptions( long referenceTimeMsUtc, String referenceTimezone, String locales, String detectedTextLanguageTags, Collection<String> entityTypes, int annotationUsecase, boolean isSerializedEntityDataEnabled)485     public AnnotationOptions(
486         long referenceTimeMsUtc,
487         String referenceTimezone,
488         String locales,
489         String detectedTextLanguageTags,
490         Collection<String> entityTypes,
491         int annotationUsecase,
492         boolean isSerializedEntityDataEnabled) {
493       this.referenceTimeMsUtc = referenceTimeMsUtc;
494       this.referenceTimezone = referenceTimezone;
495       this.locales = locales;
496       this.detectedTextLanguageTags = detectedTextLanguageTags;
497       this.entityTypes = entityTypes == null ? new String[0] : entityTypes.toArray(new String[0]);
498       this.annotationUsecase = annotationUsecase;
499       this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled;
500     }
501 
AnnotationOptions( long referenceTimeMsUtc, String referenceTimezone, String locales, String detectedTextLanguageTags)502     public AnnotationOptions(
503         long referenceTimeMsUtc,
504         String referenceTimezone,
505         String locales,
506         String detectedTextLanguageTags) {
507       this(
508           referenceTimeMsUtc,
509           referenceTimezone,
510           locales,
511           detectedTextLanguageTags,
512           null,
513           AnnotationUsecase.SMART.getValue(),
514           /* isSerializedEntityDataEnabled */ false);
515     }
516 
getReferenceTimeMsUtc()517     public long getReferenceTimeMsUtc() {
518       return referenceTimeMsUtc;
519     }
520 
getReferenceTimezone()521     public String getReferenceTimezone() {
522       return referenceTimezone;
523     }
524 
getLocale()525     public String getLocale() {
526       return locales;
527     }
528 
529     /** Returns a comma separated list of BCP 47 language tags. */
getDetectedTextLanguageTags()530     public String getDetectedTextLanguageTags() {
531       return detectedTextLanguageTags;
532     }
533 
getEntityTypes()534     public String[] getEntityTypes() {
535       return entityTypes;
536     }
537 
getAnnotationUsecase()538     public int getAnnotationUsecase() {
539       return annotationUsecase;
540     }
541 
isSerializedEntityDataEnabled()542     public boolean isSerializedEntityDataEnabled() {
543       return isSerializedEntityDataEnabled;
544     }
545   }
546 
547   /**
548    * Retrieves the pointer to the native object. Note: Need to keep the AnnotatorModel alive as long
549    * as the pointer is used.
550    */
getNativeAnnotator()551   long getNativeAnnotator() {
552     return nativeGetNativeModelPtr(annotatorPtr);
553   }
554 
nativeNewAnnotator(int fd)555   private static native long nativeNewAnnotator(int fd);
556 
nativeNewAnnotatorFromPath(String path)557   private static native long nativeNewAnnotatorFromPath(String path);
558 
nativeGetLocales(int fd)559   private static native String nativeGetLocales(int fd);
560 
nativeGetVersion(int fd)561   private static native int nativeGetVersion(int fd);
562 
nativeGetName(int fd)563   private static native String nativeGetName(int fd);
564 
nativeGetNativeModelPtr(long context)565   private native long nativeGetNativeModelPtr(long context);
566 
nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig)567   private native boolean nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig);
568 
nativeInitializeContactEngine(long context, byte[] serializedConfig)569   private native boolean nativeInitializeContactEngine(long context, byte[] serializedConfig);
570 
nativeInitializeInstalledAppEngine(long context, byte[] serializedConfig)571   private native boolean nativeInitializeInstalledAppEngine(long context, byte[] serializedConfig);
572 
nativeSuggestSelection( long context, String text, int selectionBegin, int selectionEnd, SelectionOptions options)573   private native int[] nativeSuggestSelection(
574       long context, String text, int selectionBegin, int selectionEnd, SelectionOptions options);
575 
nativeClassifyText( long context, String text, int selectionBegin, int selectionEnd, ClassificationOptions options, Object appContext, String deviceLocales)576   private native ClassificationResult[] nativeClassifyText(
577       long context,
578       String text,
579       int selectionBegin,
580       int selectionEnd,
581       ClassificationOptions options,
582       Object appContext,
583       String deviceLocales);
584 
nativeAnnotate( long context, String text, AnnotationOptions options)585   private native AnnotatedSpan[] nativeAnnotate(
586       long context, String text, AnnotationOptions options);
587 
nativeLookUpKnowledgeEntity(long context, String id)588   private native byte[] nativeLookUpKnowledgeEntity(long context, String id);
589 
nativeCloseAnnotator(long context)590   private native void nativeCloseAnnotator(long context);
591 }
592