1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.android.textclassifier; 18 19 import java.util.Collection; 20 import java.util.concurrent.atomic.AtomicBoolean; 21 22 /** 23 * Java wrapper for Annotator native library interface. This library is used for detecting entities 24 * in text. 25 * 26 * @hide 27 */ 28 public final class AnnotatorModel implements AutoCloseable { 29 private final AtomicBoolean isClosed = new AtomicBoolean(false); 30 31 static { 32 System.loadLibrary("textclassifier"); 33 } 34 35 // Keep these in sync with the constants defined in AOSP. 36 static final String TYPE_UNKNOWN = ""; 37 static final String TYPE_OTHER = "other"; 38 static final String TYPE_EMAIL = "email"; 39 static final String TYPE_PHONE = "phone"; 40 static final String TYPE_ADDRESS = "address"; 41 static final String TYPE_URL = "url"; 42 static final String TYPE_DATE = "date"; 43 static final String TYPE_DATE_TIME = "datetime"; 44 static final String TYPE_FLIGHT_NUMBER = "flight"; 45 46 private long annotatorPtr; 47 48 /** Enumeration for specifying the usecase of the annotations. */ 49 public static enum AnnotationUsecase { 50 /** Results are optimized for Smart{Select,Share,Linkify}. */ 51 SMART(0), 52 53 /** 54 * Results are optimized for using TextClassifier as an infrastructure that annotates as much as 55 * possible. 56 */ 57 RAW(1); 58 59 private final int value; 60 AnnotationUsecase(int value)61 AnnotationUsecase(int value) { 62 this.value = value; 63 } 64 getValue()65 public int getValue() { 66 return value; 67 } 68 }; 69 70 /** 71 * Creates a new instance of SmartSelect predictor, using the provided model image, given as a 72 * file descriptor. 73 */ AnnotatorModel(int fileDescriptor)74 public AnnotatorModel(int fileDescriptor) { 75 annotatorPtr = nativeNewAnnotator(fileDescriptor); 76 if (annotatorPtr == 0L) { 77 throw new IllegalArgumentException("Couldn't initialize TC from file descriptor."); 78 } 79 } 80 81 /** 82 * Creates a new instance of SmartSelect predictor, using the provided model image, given as a 83 * file path. 84 */ AnnotatorModel(String path)85 public AnnotatorModel(String path) { 86 annotatorPtr = nativeNewAnnotatorFromPath(path); 87 if (annotatorPtr == 0L) { 88 throw new IllegalArgumentException("Couldn't initialize TC from given file."); 89 } 90 } 91 92 /** Initializes the knowledge engine, passing the given serialized config to it. */ initializeKnowledgeEngine(byte[] serializedConfig)93 public void initializeKnowledgeEngine(byte[] serializedConfig) { 94 if (!nativeInitializeKnowledgeEngine(annotatorPtr, serializedConfig)) { 95 throw new IllegalArgumentException("Couldn't initialize the KG engine"); 96 } 97 } 98 99 /** Initializes the contact engine, passing the given serialized config to it. */ initializeContactEngine(byte[] serializedConfig)100 public void initializeContactEngine(byte[] serializedConfig) { 101 if (!nativeInitializeContactEngine(annotatorPtr, serializedConfig)) { 102 throw new IllegalArgumentException("Couldn't initialize the contact engine"); 103 } 104 } 105 106 /** Initializes the installed app engine, passing the given serialized config to it. */ initializeInstalledAppEngine(byte[] serializedConfig)107 public void initializeInstalledAppEngine(byte[] serializedConfig) { 108 if (!nativeInitializeInstalledAppEngine(annotatorPtr, serializedConfig)) { 109 throw new IllegalArgumentException("Couldn't initialize the installed app engine"); 110 } 111 } 112 113 /** 114 * Given a string context and current selection, computes the selection suggestion. 115 * 116 * <p>The begin and end are character indices into the context UTF8 string. selectionBegin is the 117 * character index where the selection begins, and selectionEnd is the index of one character past 118 * the selection span. 119 * 120 * <p>The return value is an array of two ints: suggested selection beginning and end, with the 121 * same semantics as the input selectionBeginning and selectionEnd. 122 */ suggestSelection( String context, int selectionBegin, int selectionEnd, SelectionOptions options)123 public int[] suggestSelection( 124 String context, int selectionBegin, int selectionEnd, SelectionOptions options) { 125 return nativeSuggestSelection(annotatorPtr, context, selectionBegin, selectionEnd, options); 126 } 127 128 /** 129 * Given a string context and current selection, classifies the type of the selected text. 130 * 131 * <p>The begin and end params are character indices in the context string. 132 * 133 * <p>Returns an array of ClassificationResult objects with the probability scores for different 134 * collections. 135 */ classifyText( String context, int selectionBegin, int selectionEnd, ClassificationOptions options)136 public ClassificationResult[] classifyText( 137 String context, int selectionBegin, int selectionEnd, ClassificationOptions options) { 138 return classifyText( 139 context, 140 selectionBegin, 141 selectionEnd, 142 options, 143 /*appContext=*/ null, 144 /*deviceLocales=*/ null); 145 } 146 classifyText( String context, int selectionBegin, int selectionEnd, ClassificationOptions options, Object appContext, String deviceLocales)147 public ClassificationResult[] classifyText( 148 String context, 149 int selectionBegin, 150 int selectionEnd, 151 ClassificationOptions options, 152 Object appContext, 153 String deviceLocales) { 154 return nativeClassifyText( 155 annotatorPtr, context, selectionBegin, selectionEnd, options, appContext, deviceLocales); 156 } 157 158 /** 159 * Annotates given input text. The annotations should cover the whole input context except for 160 * whitespaces, and are sorted by their position in the context string. 161 */ annotate(String text, AnnotationOptions options)162 public AnnotatedSpan[] annotate(String text, AnnotationOptions options) { 163 return nativeAnnotate(annotatorPtr, text, options); 164 } 165 166 /** 167 * Looks up a knowledge entity by its identifier. Returns null if the entity is not found or on 168 * error. 169 */ lookUpKnowledgeEntity(String id)170 public byte[] lookUpKnowledgeEntity(String id) { 171 return nativeLookUpKnowledgeEntity(annotatorPtr, id); 172 } 173 174 /** Frees up the allocated memory. */ 175 @Override close()176 public void close() { 177 if (isClosed.compareAndSet(false, true)) { 178 nativeCloseAnnotator(annotatorPtr); 179 annotatorPtr = 0L; 180 } 181 } 182 183 @Override finalize()184 protected void finalize() throws Throwable { 185 try { 186 close(); 187 } finally { 188 super.finalize(); 189 } 190 } 191 192 /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */ getLocales(int fd)193 public static String getLocales(int fd) { 194 return nativeGetLocales(fd); 195 } 196 197 /** Returns the version of the model. */ getVersion(int fd)198 public static int getVersion(int fd) { 199 return nativeGetVersion(fd); 200 } 201 202 /** Returns the name of the model. */ getName(int fd)203 public static String getName(int fd) { 204 return nativeGetName(fd); 205 } 206 207 /** Information about a parsed time/date. */ 208 public static final class DatetimeResult { 209 210 public static final int GRANULARITY_YEAR = 0; 211 public static final int GRANULARITY_MONTH = 1; 212 public static final int GRANULARITY_WEEK = 2; 213 public static final int GRANULARITY_DAY = 3; 214 public static final int GRANULARITY_HOUR = 4; 215 public static final int GRANULARITY_MINUTE = 5; 216 public static final int GRANULARITY_SECOND = 6; 217 218 private final long timeMsUtc; 219 private final int granularity; 220 DatetimeResult(long timeMsUtc, int granularity)221 public DatetimeResult(long timeMsUtc, int granularity) { 222 this.timeMsUtc = timeMsUtc; 223 this.granularity = granularity; 224 } 225 getTimeMsUtc()226 public long getTimeMsUtc() { 227 return timeMsUtc; 228 } 229 getGranularity()230 public int getGranularity() { 231 return granularity; 232 } 233 } 234 235 /** Classification result for classifyText method. */ 236 public static final class ClassificationResult { 237 private final String collection; 238 private final float score; 239 private final DatetimeResult datetimeResult; 240 private final byte[] serializedKnowledgeResult; 241 private final String contactName; 242 private final String contactGivenName; 243 private final String contactNickname; 244 private final String contactEmailAddress; 245 private final String contactPhoneNumber; 246 private final String contactId; 247 private final String appName; 248 private final String appPackageName; 249 private final NamedVariant[] entityData; 250 private final byte[] serializedEntityData; 251 private final RemoteActionTemplate[] remoteActionTemplates; 252 private final long durationMs; 253 private final long numericValue; 254 ClassificationResult( String collection, float score, DatetimeResult datetimeResult, byte[] serializedKnowledgeResult, String contactName, String contactGivenName, String contactNickname, String contactEmailAddress, String contactPhoneNumber, String contactId, String appName, String appPackageName, NamedVariant[] entityData, byte[] serializedEntityData, RemoteActionTemplate[] remoteActionTemplates, long durationMs, long numericValue)255 public ClassificationResult( 256 String collection, 257 float score, 258 DatetimeResult datetimeResult, 259 byte[] serializedKnowledgeResult, 260 String contactName, 261 String contactGivenName, 262 String contactNickname, 263 String contactEmailAddress, 264 String contactPhoneNumber, 265 String contactId, 266 String appName, 267 String appPackageName, 268 NamedVariant[] entityData, 269 byte[] serializedEntityData, 270 RemoteActionTemplate[] remoteActionTemplates, 271 long durationMs, 272 long numericValue) { 273 this.collection = collection; 274 this.score = score; 275 this.datetimeResult = datetimeResult; 276 this.serializedKnowledgeResult = serializedKnowledgeResult; 277 this.contactName = contactName; 278 this.contactGivenName = contactGivenName; 279 this.contactNickname = contactNickname; 280 this.contactEmailAddress = contactEmailAddress; 281 this.contactPhoneNumber = contactPhoneNumber; 282 this.contactId = contactId; 283 this.appName = appName; 284 this.appPackageName = appPackageName; 285 this.entityData = entityData; 286 this.serializedEntityData = serializedEntityData; 287 this.remoteActionTemplates = remoteActionTemplates; 288 this.durationMs = durationMs; 289 this.numericValue = numericValue; 290 } 291 292 /** Returns the classified entity type. */ getCollection()293 public String getCollection() { 294 return collection; 295 } 296 297 /** Confidence score between 0 and 1. */ getScore()298 public float getScore() { 299 return score; 300 } 301 getDatetimeResult()302 public DatetimeResult getDatetimeResult() { 303 return datetimeResult; 304 } 305 getSerializedKnowledgeResult()306 public byte[] getSerializedKnowledgeResult() { 307 return serializedKnowledgeResult; 308 } 309 getContactName()310 public String getContactName() { 311 return contactName; 312 } 313 getContactGivenName()314 public String getContactGivenName() { 315 return contactGivenName; 316 } 317 getContactNickname()318 public String getContactNickname() { 319 return contactNickname; 320 } 321 getContactEmailAddress()322 public String getContactEmailAddress() { 323 return contactEmailAddress; 324 } 325 getContactPhoneNumber()326 public String getContactPhoneNumber() { 327 return contactPhoneNumber; 328 } 329 getContactId()330 public String getContactId() { 331 return contactId; 332 } 333 getAppName()334 public String getAppName() { 335 return appName; 336 } 337 getAppPackageName()338 public String getAppPackageName() { 339 return appPackageName; 340 } 341 getEntityData()342 public NamedVariant[] getEntityData() { 343 return entityData; 344 } 345 getSerializedEntityData()346 public byte[] getSerializedEntityData() { 347 return serializedEntityData; 348 } 349 getRemoteActionTemplates()350 public RemoteActionTemplate[] getRemoteActionTemplates() { 351 return remoteActionTemplates; 352 } 353 getDurationMs()354 public long getDurationMs() { 355 return durationMs; 356 } 357 getNumericValue()358 public long getNumericValue() { 359 return numericValue; 360 } 361 } 362 363 /** Represents a result of Annotate call. */ 364 public static final class AnnotatedSpan { 365 private final int startIndex; 366 private final int endIndex; 367 private final ClassificationResult[] classification; 368 AnnotatedSpan(int startIndex, int endIndex, ClassificationResult[] classification)369 AnnotatedSpan(int startIndex, int endIndex, ClassificationResult[] classification) { 370 this.startIndex = startIndex; 371 this.endIndex = endIndex; 372 this.classification = classification; 373 } 374 getStartIndex()375 public int getStartIndex() { 376 return startIndex; 377 } 378 getEndIndex()379 public int getEndIndex() { 380 return endIndex; 381 } 382 getClassification()383 public ClassificationResult[] getClassification() { 384 return classification; 385 } 386 } 387 388 /** Represents options for the suggestSelection call. */ 389 public static final class SelectionOptions { 390 private final String locales; 391 private final String detectedTextLanguageTags; 392 private final int annotationUsecase; 393 SelectionOptions( String locales, String detectedTextLanguageTags, int annotationUsecase)394 public SelectionOptions( 395 String locales, String detectedTextLanguageTags, int annotationUsecase) { 396 this.locales = locales; 397 this.detectedTextLanguageTags = detectedTextLanguageTags; 398 this.annotationUsecase = annotationUsecase; 399 } 400 SelectionOptions(String locales, String detectedTextLanguageTags)401 public SelectionOptions(String locales, String detectedTextLanguageTags) { 402 this(locales, detectedTextLanguageTags, AnnotationUsecase.SMART.getValue()); 403 } 404 getLocales()405 public String getLocales() { 406 return locales; 407 } 408 409 /** Returns a comma separated list of BCP 47 language tags. */ getDetectedTextLanguageTags()410 public String getDetectedTextLanguageTags() { 411 return detectedTextLanguageTags; 412 } 413 getAnnotationUsecase()414 public int getAnnotationUsecase() { 415 return annotationUsecase; 416 } 417 } 418 419 /** Represents options for the classifyText call. */ 420 public static final class ClassificationOptions { 421 private final long referenceTimeMsUtc; 422 private final String referenceTimezone; 423 private final String locales; 424 private final String detectedTextLanguageTags; 425 private final int annotationUsecase; 426 ClassificationOptions( long referenceTimeMsUtc, String referenceTimezone, String locales, String detectedTextLanguageTags, int annotationUsecase)427 public ClassificationOptions( 428 long referenceTimeMsUtc, 429 String referenceTimezone, 430 String locales, 431 String detectedTextLanguageTags, 432 int annotationUsecase) { 433 this.referenceTimeMsUtc = referenceTimeMsUtc; 434 this.referenceTimezone = referenceTimezone; 435 this.locales = locales; 436 this.detectedTextLanguageTags = detectedTextLanguageTags; 437 this.annotationUsecase = annotationUsecase; 438 } 439 ClassificationOptions( long referenceTimeMsUtc, String referenceTimezone, String locales, String detectedTextLanguageTags)440 public ClassificationOptions( 441 long referenceTimeMsUtc, 442 String referenceTimezone, 443 String locales, 444 String detectedTextLanguageTags) { 445 this( 446 referenceTimeMsUtc, 447 referenceTimezone, 448 locales, 449 detectedTextLanguageTags, 450 AnnotationUsecase.SMART.getValue()); 451 } 452 getReferenceTimeMsUtc()453 public long getReferenceTimeMsUtc() { 454 return referenceTimeMsUtc; 455 } 456 getReferenceTimezone()457 public String getReferenceTimezone() { 458 return referenceTimezone; 459 } 460 getLocale()461 public String getLocale() { 462 return locales; 463 } 464 465 /** Returns a comma separated list of BCP 47 language tags. */ getDetectedTextLanguageTags()466 public String getDetectedTextLanguageTags() { 467 return detectedTextLanguageTags; 468 } 469 getAnnotationUsecase()470 public int getAnnotationUsecase() { 471 return annotationUsecase; 472 } 473 } 474 475 /** Represents options for the annotate call. */ 476 public static final class AnnotationOptions { 477 private final long referenceTimeMsUtc; 478 private final String referenceTimezone; 479 private final String locales; 480 private final String detectedTextLanguageTags; 481 private final String[] entityTypes; 482 private final int annotationUsecase; 483 private final boolean isSerializedEntityDataEnabled; 484 AnnotationOptions( long referenceTimeMsUtc, String referenceTimezone, String locales, String detectedTextLanguageTags, Collection<String> entityTypes, int annotationUsecase, boolean isSerializedEntityDataEnabled)485 public AnnotationOptions( 486 long referenceTimeMsUtc, 487 String referenceTimezone, 488 String locales, 489 String detectedTextLanguageTags, 490 Collection<String> entityTypes, 491 int annotationUsecase, 492 boolean isSerializedEntityDataEnabled) { 493 this.referenceTimeMsUtc = referenceTimeMsUtc; 494 this.referenceTimezone = referenceTimezone; 495 this.locales = locales; 496 this.detectedTextLanguageTags = detectedTextLanguageTags; 497 this.entityTypes = entityTypes == null ? new String[0] : entityTypes.toArray(new String[0]); 498 this.annotationUsecase = annotationUsecase; 499 this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled; 500 } 501 AnnotationOptions( long referenceTimeMsUtc, String referenceTimezone, String locales, String detectedTextLanguageTags)502 public AnnotationOptions( 503 long referenceTimeMsUtc, 504 String referenceTimezone, 505 String locales, 506 String detectedTextLanguageTags) { 507 this( 508 referenceTimeMsUtc, 509 referenceTimezone, 510 locales, 511 detectedTextLanguageTags, 512 null, 513 AnnotationUsecase.SMART.getValue(), 514 /* isSerializedEntityDataEnabled */ false); 515 } 516 getReferenceTimeMsUtc()517 public long getReferenceTimeMsUtc() { 518 return referenceTimeMsUtc; 519 } 520 getReferenceTimezone()521 public String getReferenceTimezone() { 522 return referenceTimezone; 523 } 524 getLocale()525 public String getLocale() { 526 return locales; 527 } 528 529 /** Returns a comma separated list of BCP 47 language tags. */ getDetectedTextLanguageTags()530 public String getDetectedTextLanguageTags() { 531 return detectedTextLanguageTags; 532 } 533 getEntityTypes()534 public String[] getEntityTypes() { 535 return entityTypes; 536 } 537 getAnnotationUsecase()538 public int getAnnotationUsecase() { 539 return annotationUsecase; 540 } 541 isSerializedEntityDataEnabled()542 public boolean isSerializedEntityDataEnabled() { 543 return isSerializedEntityDataEnabled; 544 } 545 } 546 547 /** 548 * Retrieves the pointer to the native object. Note: Need to keep the AnnotatorModel alive as long 549 * as the pointer is used. 550 */ getNativeAnnotator()551 long getNativeAnnotator() { 552 return nativeGetNativeModelPtr(annotatorPtr); 553 } 554 nativeNewAnnotator(int fd)555 private static native long nativeNewAnnotator(int fd); 556 nativeNewAnnotatorFromPath(String path)557 private static native long nativeNewAnnotatorFromPath(String path); 558 nativeGetLocales(int fd)559 private static native String nativeGetLocales(int fd); 560 nativeGetVersion(int fd)561 private static native int nativeGetVersion(int fd); 562 nativeGetName(int fd)563 private static native String nativeGetName(int fd); 564 nativeGetNativeModelPtr(long context)565 private native long nativeGetNativeModelPtr(long context); 566 nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig)567 private native boolean nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig); 568 nativeInitializeContactEngine(long context, byte[] serializedConfig)569 private native boolean nativeInitializeContactEngine(long context, byte[] serializedConfig); 570 nativeInitializeInstalledAppEngine(long context, byte[] serializedConfig)571 private native boolean nativeInitializeInstalledAppEngine(long context, byte[] serializedConfig); 572 nativeSuggestSelection( long context, String text, int selectionBegin, int selectionEnd, SelectionOptions options)573 private native int[] nativeSuggestSelection( 574 long context, String text, int selectionBegin, int selectionEnd, SelectionOptions options); 575 nativeClassifyText( long context, String text, int selectionBegin, int selectionEnd, ClassificationOptions options, Object appContext, String deviceLocales)576 private native ClassificationResult[] nativeClassifyText( 577 long context, 578 String text, 579 int selectionBegin, 580 int selectionEnd, 581 ClassificationOptions options, 582 Object appContext, 583 String deviceLocales); 584 nativeAnnotate( long context, String text, AnnotationOptions options)585 private native AnnotatedSpan[] nativeAnnotate( 586 long context, String text, AnnotationOptions options); 587 nativeLookUpKnowledgeEntity(long context, String id)588 private native byte[] nativeLookUpKnowledgeEntity(long context, String id); 589 nativeCloseAnnotator(long context)590 private native void nativeCloseAnnotator(long context); 591 } 592