1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier; 18 19 import static java.util.stream.Collectors.toCollection; 20 21 import android.app.PendingIntent; 22 import android.app.RemoteAction; 23 import android.content.Context; 24 import android.content.Intent; 25 import android.icu.util.ULocale; 26 import android.os.Bundle; 27 import android.os.LocaleList; 28 import android.os.Looper; 29 import android.os.ParcelFileDescriptor; 30 import android.util.ArrayMap; 31 import android.view.View.OnClickListener; 32 import android.view.textclassifier.ConversationAction; 33 import android.view.textclassifier.ConversationActions; 34 import android.view.textclassifier.SelectionEvent; 35 import android.view.textclassifier.TextClassification; 36 import android.view.textclassifier.TextClassificationSessionId; 37 import android.view.textclassifier.TextClassifier; 38 import android.view.textclassifier.TextClassifierEvent; 39 import android.view.textclassifier.TextLanguage; 40 import android.view.textclassifier.TextLinks; 41 import android.view.textclassifier.TextSelection; 42 import androidx.annotation.GuardedBy; 43 import androidx.annotation.WorkerThread; 44 import androidx.core.util.Pair; 45 import com.android.textclassifier.ModelFileManager.ModelFile; 46 import com.android.textclassifier.common.base.TcLog; 47 import com.android.textclassifier.common.intent.LabeledIntent; 48 import com.android.textclassifier.common.intent.TemplateIntentFactory; 49 import com.android.textclassifier.common.logging.ResultIdUtils; 50 import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo; 51 import com.android.textclassifier.common.statsd.GenerateLinksLogger; 52 import com.android.textclassifier.common.statsd.SelectionEventConverter; 53 import com.android.textclassifier.common.statsd.TextClassificationSessionIdConverter; 54 import com.android.textclassifier.common.statsd.TextClassifierEventConverter; 55 import com.android.textclassifier.common.statsd.TextClassifierEventLogger; 56 import com.android.textclassifier.utils.IndentingPrintWriter; 57 import com.google.android.textclassifier.ActionsSuggestionsModel; 58 import com.google.android.textclassifier.AnnotatorModel; 59 import com.google.android.textclassifier.LangIdModel; 60 import com.google.common.base.Optional; 61 import com.google.common.base.Preconditions; 62 import com.google.common.collect.FluentIterable; 63 import com.google.common.collect.ImmutableList; 64 import java.io.File; 65 import java.io.FileNotFoundException; 66 import java.io.IOException; 67 import java.time.ZoneId; 68 import java.time.ZonedDateTime; 69 import java.util.ArrayList; 70 import java.util.Collection; 71 import java.util.List; 72 import java.util.Map; 73 import java.util.Objects; 74 import javax.annotation.Nullable; 75 76 /** 77 * A text classifier that is running locally. 78 * 79 * <p>This class uses machine learning to recognize entities in text. Unless otherwise stated, 80 * methods of this class are blocking operations and should most likely not be called on the UI 81 * thread. 82 */ 83 final class TextClassifierImpl { 84 85 private static final String TAG = "TextClassifierImpl"; 86 87 private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/"); 88 // Annotator 89 private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX = 90 "textclassifier\\.(.*)\\.model"; 91 private static final File ANNOTATOR_UPDATED_MODEL_FILE = 92 new File("/data/misc/textclassifier/textclassifier.model"); 93 94 // LangIdModel 95 private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model"; 96 private static final File UPDATED_LANG_ID_MODEL_FILE = 97 new File("/data/misc/textclassifier/lang_id.model"); 98 99 // Actions 100 private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX = 101 "actions_suggestions\\.(.*)\\.model"; 102 private static final File UPDATED_ACTIONS_MODEL = 103 new File("/data/misc/textclassifier/actions_suggestions.model"); 104 105 private final Context context; 106 private final TextClassifier fallback; 107 private final GenerateLinksLogger generateLinksLogger; 108 109 private final Object lock = new Object(); 110 111 @GuardedBy("lock") 112 private ModelFileManager.ModelFile annotatorModelInUse; 113 114 @GuardedBy("lock") 115 private AnnotatorModel annotatorImpl; 116 117 @GuardedBy("lock") 118 private ModelFileManager.ModelFile langIdModelInUse; 119 120 @GuardedBy("lock") 121 private LangIdModel langIdImpl; 122 123 @GuardedBy("lock") 124 private ModelFileManager.ModelFile actionModelInUse; 125 126 @GuardedBy("lock") 127 private ActionsSuggestionsModel actionsImpl; 128 129 private final TextClassifierEventLogger textClassifierEventLogger = 130 new TextClassifierEventLogger(); 131 132 private final TextClassifierSettings settings; 133 134 private final ModelFileManager annotatorModelFileManager; 135 private final ModelFileManager langIdModelFileManager; 136 private final ModelFileManager actionsModelFileManager; 137 private final TemplateIntentFactory templateIntentFactory; 138 139 TextClassifierImpl(Context context, TextClassifierSettings settings, TextClassifier fallback) { 140 this.context = Preconditions.checkNotNull(context); 141 this.fallback = Preconditions.checkNotNull(fallback); 142 this.settings = Preconditions.checkNotNull(settings); 143 generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate()); 144 annotatorModelFileManager = 145 new ModelFileManager( 146 new ModelFileManager.ModelFileSupplierImpl( 147 FACTORY_MODEL_DIR, 148 ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX, 149 ANNOTATOR_UPDATED_MODEL_FILE, 150 AnnotatorModel::getVersion, 151 AnnotatorModel::getLocales)); 152 langIdModelFileManager = 153 new ModelFileManager( 154 new ModelFileManager.ModelFileSupplierImpl( 155 FACTORY_MODEL_DIR, 156 LANG_ID_FACTORY_MODEL_FILENAME_REGEX, 157 UPDATED_LANG_ID_MODEL_FILE, 158 LangIdModel::getVersion, 159 fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT)); 160 actionsModelFileManager = 161 new ModelFileManager( 162 new ModelFileManager.ModelFileSupplierImpl( 163 FACTORY_MODEL_DIR, 164 ACTIONS_FACTORY_MODEL_FILENAME_REGEX, 165 UPDATED_ACTIONS_MODEL, 166 ActionsSuggestionsModel::getVersion, 167 ActionsSuggestionsModel::getLocales)); 168 169 templateIntentFactory = new TemplateIntentFactory(); 170 } 171 172 TextClassifierImpl(Context context, TextClassifierSettings settings) { 173 this(context, settings, TextClassifier.NO_OP); 174 } 175 176 @WorkerThread 177 TextSelection suggestSelection(TextSelection.Request request) { 178 Preconditions.checkNotNull(request); 179 checkMainThread(); 180 try { 181 final int rangeLength = request.getEndIndex() - request.getStartIndex(); 182 final String string = request.getText().toString(); 183 if (string.length() > 0 && rangeLength <= settings.getSuggestSelectionMaxRangeLength()) { 184 final String localesString = concatenateLocales(request.getDefaultLocales()); 185 final Optional<LangIdModel> langIdModel = getLangIdImpl(); 186 final String detectLanguageTags = 187 String.join(",", detectLanguageTags(langIdModel, request.getText())); 188 final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault()); 189 final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales()); 190 final int[] startEnd = 191 annotatorImpl.suggestSelection( 192 string, 193 request.getStartIndex(), 194 request.getEndIndex(), 195 new AnnotatorModel.SelectionOptions(localesString, detectLanguageTags)); 196 final int start = startEnd[0]; 197 final int end = startEnd[1]; 198 if (start < end 199 && start >= 0 200 && end <= string.length() 201 && start <= request.getStartIndex() 202 && end >= request.getEndIndex()) { 203 final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end); 204 final AnnotatorModel.ClassificationResult[] results = 205 annotatorImpl.classifyText( 206 string, 207 start, 208 end, 209 new AnnotatorModel.ClassificationOptions( 210 refTime.toInstant().toEpochMilli(), 211 refTime.getZone().getId(), 212 localesString, 213 detectLanguageTags), 214 // Passing null here to suppress intent generation 215 // TODO: Use an explicit flag to suppress it. 216 /* appContext */ null, 217 /* deviceLocales */ null); 218 final int size = results.length; 219 for (int i = 0; i < size; i++) { 220 tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore()); 221 } 222 final String resultId = 223 createAnnotatorId(string, request.getStartIndex(), request.getEndIndex()); 224 return tsBuilder.setId(resultId).build(); 225 } else { 226 // We can not trust the result. Log the issue and ignore the result. 227 TcLog.d(TAG, "Got bad indices for input text. Ignoring result."); 228 } 229 } 230 } catch (Throwable t) { 231 // Avoid throwing from this method. Log the error. 232 TcLog.e(TAG, "Error suggesting selection for text. No changes to selection suggested.", t); 233 } 234 // Getting here means something went wrong, return a NO_OP result. 235 return fallback.suggestSelection(request); 236 } 237 238 @WorkerThread 239 TextClassification classifyText(TextClassification.Request request) { 240 Preconditions.checkNotNull(request); 241 checkMainThread(); 242 try { 243 Optional<LangIdModel> langId = getLangIdImpl(); 244 List<String> detectLanguageTags = detectLanguageTags(langId, request.getText()); 245 final int rangeLength = request.getEndIndex() - request.getStartIndex(); 246 final String string = request.getText().toString(); 247 if (string.length() > 0 && rangeLength <= settings.getClassifyTextMaxRangeLength()) { 248 final String localesString = concatenateLocales(request.getDefaultLocales()); 249 final ZonedDateTime refTime = 250 request.getReferenceTime() != null 251 ? request.getReferenceTime() 252 : ZonedDateTime.now(ZoneId.systemDefault()); 253 final AnnotatorModel.ClassificationResult[] results = 254 getAnnotatorImpl(request.getDefaultLocales()) 255 .classifyText( 256 string, 257 request.getStartIndex(), 258 request.getEndIndex(), 259 new AnnotatorModel.ClassificationOptions( 260 refTime.toInstant().toEpochMilli(), 261 refTime.getZone().getId(), 262 localesString, 263 String.join(",", detectLanguageTags), 264 AnnotatorModel.AnnotationUsecase.SMART.getValue(), 265 LocaleList.getDefault().toLanguageTags()), 266 context, 267 getResourceLocalesString()); 268 if (results.length > 0) { 269 return createClassificationResult( 270 results, string, request.getStartIndex(), request.getEndIndex(), langId); 271 } 272 } 273 } catch (Throwable t) { 274 // Avoid throwing from this method. Log the error. 275 TcLog.e(TAG, "Error getting text classification info.", t); 276 } 277 // Getting here means something went wrong, return a NO_OP result. 278 return fallback.classifyText(request); 279 } 280 281 @WorkerThread 282 TextLinks generateLinks(TextLinks.Request request) { 283 Preconditions.checkNotNull(request); 284 Preconditions.checkArgument( 285 request.getText().length() <= getMaxGenerateLinksTextLength(), 286 "text.length() cannot be greater than %s", 287 getMaxGenerateLinksTextLength()); 288 checkMainThread(); 289 290 final String textString = request.getText().toString(); 291 final TextLinks.Builder builder = new TextLinks.Builder(textString); 292 293 try { 294 final long startTimeMs = System.currentTimeMillis(); 295 final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault()); 296 final Collection<String> entitiesToIdentify = 297 request.getEntityConfig() != null 298 ? request 299 .getEntityConfig() 300 .resolveEntityListModifications( 301 getEntitiesForHints(request.getEntityConfig().getHints())) 302 : settings.getEntityListDefault(); 303 final String localesString = concatenateLocales(request.getDefaultLocales()); 304 Optional<LangIdModel> langId = getLangIdImpl(); 305 ImmutableList<String> detectLanguageTags = detectLanguageTags(langId, request.getText()); 306 final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales()); 307 final boolean isSerializedEntityDataEnabled = 308 ExtrasUtils.isSerializedEntityDataEnabled(request); 309 final AnnotatorModel.AnnotatedSpan[] annotations = 310 annotatorImpl.annotate( 311 textString, 312 new AnnotatorModel.AnnotationOptions( 313 refTime.toInstant().toEpochMilli(), 314 refTime.getZone().getId(), 315 localesString, 316 String.join(",", detectLanguageTags), 317 entitiesToIdentify, 318 AnnotatorModel.AnnotationUsecase.SMART.getValue(), 319 isSerializedEntityDataEnabled)); 320 for (AnnotatorModel.AnnotatedSpan span : annotations) { 321 final AnnotatorModel.ClassificationResult[] results = span.getClassification(); 322 if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) { 323 continue; 324 } 325 final Map<String, Float> entityScores = new ArrayMap<>(); 326 for (int i = 0; i < results.length; i++) { 327 entityScores.put(results[i].getCollection(), results[i].getScore()); 328 } 329 Bundle extras = new Bundle(); 330 if (isSerializedEntityDataEnabled) { 331 ExtrasUtils.putEntities(extras, results); 332 } 333 builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras); 334 } 335 final TextLinks links = builder.build(); 336 final long endTimeMs = System.currentTimeMillis(); 337 final String callingPackageName = 338 request.getCallingPackageName() == null 339 ? context.getPackageName() // local (in process) TC. 340 : request.getCallingPackageName(); 341 Optional<ModelInfo> annotatorModelInfo; 342 Optional<ModelInfo> langIdModelInfo; 343 synchronized (lock) { 344 annotatorModelInfo = 345 Optional.fromNullable(annotatorModelInUse).transform(ModelFile::toModelInfo); 346 langIdModelInfo = Optional.fromNullable(langIdModelInUse).transform(ModelFile::toModelInfo); 347 } 348 generateLinksLogger.logGenerateLinks( 349 request.getText(), 350 links, 351 callingPackageName, 352 endTimeMs - startTimeMs, 353 annotatorModelInfo, 354 langIdModelInfo); 355 return links; 356 } catch (Throwable t) { 357 // Avoid throwing from this method. Log the error. 358 TcLog.e(TAG, "Error getting links info.", t); 359 } 360 return fallback.generateLinks(request); 361 } 362 363 int getMaxGenerateLinksTextLength() { 364 return settings.getGenerateLinksMaxTextLength(); 365 } 366 367 private Collection<String> getEntitiesForHints(Collection<String> hints) { 368 final boolean editable = hints.contains(TextClassifier.HINT_TEXT_IS_EDITABLE); 369 final boolean notEditable = hints.contains(TextClassifier.HINT_TEXT_IS_NOT_EDITABLE); 370 371 // Use the default if there is no hint, or conflicting ones. 372 final boolean useDefault = editable == notEditable; 373 if (useDefault) { 374 return settings.getEntityListDefault(); 375 } else if (editable) { 376 return settings.getEntityListEditable(); 377 } else { // notEditable 378 return settings.getEntityListNotEditable(); 379 } 380 } 381 382 void onSelectionEvent(SelectionEvent event) { 383 TextClassifierEvent textClassifierEvent = SelectionEventConverter.toTextClassifierEvent(event); 384 if (textClassifierEvent == null) { 385 return; 386 } 387 onTextClassifierEvent(event.getSessionId(), textClassifierEvent); 388 } 389 390 void onTextClassifierEvent( 391 @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) { 392 textClassifierEventLogger.writeEvent( 393 TextClassificationSessionIdConverter.fromPlatform(sessionId), 394 TextClassifierEventConverter.fromPlatform(event)); 395 } 396 397 TextLanguage detectLanguage(TextLanguage.Request request) { 398 Preconditions.checkNotNull(request); 399 checkMainThread(); 400 try { 401 final TextLanguage.Builder builder = new TextLanguage.Builder(); 402 Optional<LangIdModel> langIdImpl = getLangIdImpl(); 403 if (langIdImpl.isPresent()) { 404 final LangIdModel.LanguageResult[] langResults = 405 langIdImpl.get().detectLanguages(request.getText().toString()); 406 for (int i = 0; i < langResults.length; i++) { 407 builder.putLocale( 408 ULocale.forLanguageTag(langResults[i].getLanguage()), langResults[i].getScore()); 409 } 410 return builder.build(); 411 } 412 } catch (Throwable t) { 413 // Avoid throwing from this method. Log the error. 414 TcLog.e(TAG, "Error detecting text language.", t); 415 } 416 return fallback.detectLanguage(request); 417 } 418 419 ConversationActions suggestConversationActions(ConversationActions.Request request) { 420 Preconditions.checkNotNull(request); 421 checkMainThread(); 422 try { 423 ActionsSuggestionsModel actionsImpl = getActionsImpl(); 424 if (actionsImpl == null) { 425 // Actions model is optional, fallback if it is not available. 426 return fallback.suggestConversationActions(request); 427 } 428 Optional<LangIdModel> langId = getLangIdImpl(); 429 ActionsSuggestionsModel.ConversationMessage[] nativeMessages = 430 ActionsSuggestionsHelper.toNativeMessages( 431 request.getConversation(), text -> detectLanguageTags(langId, text)); 432 if (nativeMessages.length == 0) { 433 return fallback.suggestConversationActions(request); 434 } 435 ActionsSuggestionsModel.Conversation nativeConversation = 436 new ActionsSuggestionsModel.Conversation(nativeMessages); 437 438 ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions = 439 actionsImpl.suggestActionsWithIntents( 440 nativeConversation, 441 null, 442 context, 443 getResourceLocalesString(), 444 getAnnotatorImpl(LocaleList.getDefault())); 445 return createConversationActionResult(request, nativeSuggestions); 446 } catch (Throwable t) { 447 // Avoid throwing from this method. Log the error. 448 TcLog.e(TAG, "Error suggesting conversation actions.", t); 449 } 450 return fallback.suggestConversationActions(request); 451 } 452 453 /** 454 * Returns the {@link ConversationAction} result, with a non-null extras. 455 * 456 * <p>Whenever the RemoteAction is non-null, you can expect its corresponding intent with a 457 * non-null component name is in the extras. 458 */ 459 private ConversationActions createConversationActionResult( 460 ConversationActions.Request request, 461 ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions) { 462 Collection<String> expectedTypes = resolveActionTypesFromRequest(request); 463 List<ConversationAction> conversationActions = new ArrayList<>(); 464 for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) { 465 String actionType = nativeSuggestion.getActionType(); 466 if (!expectedTypes.contains(actionType)) { 467 continue; 468 } 469 LabeledIntent.Result labeledIntentResult = 470 ActionsSuggestionsHelper.createLabeledIntentResult( 471 context, templateIntentFactory, nativeSuggestion); 472 RemoteAction remoteAction = null; 473 Bundle extras = new Bundle(); 474 if (labeledIntentResult != null) { 475 remoteAction = labeledIntentResult.remoteAction.toRemoteAction(); 476 ExtrasUtils.putActionIntent( 477 extras, stripPackageInfoFromIntent(labeledIntentResult.resolvedIntent)); 478 } 479 ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData()); 480 ExtrasUtils.putEntitiesExtras( 481 extras, TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData())); 482 conversationActions.add( 483 new ConversationAction.Builder(actionType) 484 .setConfidenceScore(nativeSuggestion.getScore()) 485 .setTextReply(nativeSuggestion.getResponseText()) 486 .setAction(remoteAction) 487 .setExtras(extras) 488 .build()); 489 } 490 conversationActions = ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions); 491 if (request.getMaxSuggestions() >= 0 492 && conversationActions.size() > request.getMaxSuggestions()) { 493 conversationActions = conversationActions.subList(0, request.getMaxSuggestions()); 494 } 495 synchronized (lock) { 496 String resultId = 497 ActionsSuggestionsHelper.createResultId( 498 context, 499 request.getConversation(), 500 Optional.fromNullable(actionModelInUse), 501 Optional.fromNullable(annotatorModelInUse), 502 Optional.fromNullable(langIdModelInUse)); 503 return new ConversationActions(conversationActions, resultId); 504 } 505 } 506 507 private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) { 508 List<String> defaultActionTypes = 509 request.getHints().contains(ConversationActions.Request.HINT_FOR_NOTIFICATION) 510 ? settings.getNotificationConversationActionTypes() 511 : settings.getInAppConversationActionTypes(); 512 return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes); 513 } 514 515 private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException { 516 synchronized (lock) { 517 localeList = localeList == null ? LocaleList.getDefault() : localeList; 518 final ModelFileManager.ModelFile bestModel = 519 annotatorModelFileManager.findBestModelFile(localeList); 520 if (bestModel == null) { 521 throw new FileNotFoundException("No annotator model for " + localeList.toLanguageTags()); 522 } 523 if (annotatorImpl == null || !Objects.equals(annotatorModelInUse, bestModel)) { 524 TcLog.d(TAG, "Loading " + bestModel); 525 final ParcelFileDescriptor pfd = 526 ParcelFileDescriptor.open( 527 new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); 528 try { 529 if (pfd != null) { 530 // The current annotator model may be still used by another thread / model. 531 // Do not call close() here, and let the GC to clean it up when no one else 532 // is using it. 533 annotatorImpl = new AnnotatorModel(pfd.getFd()); 534 Optional<LangIdModel> langIdModel = getLangIdImpl(); 535 if (langIdModel.isPresent()) { 536 annotatorImpl.setLangIdModel(langIdModel.get()); 537 } 538 annotatorModelInUse = bestModel; 539 } 540 } finally { 541 maybeCloseAndLogError(pfd); 542 } 543 } 544 return annotatorImpl; 545 } 546 } 547 548 private Optional<LangIdModel> getLangIdImpl() { 549 synchronized (lock) { 550 final ModelFileManager.ModelFile bestModel = langIdModelFileManager.findBestModelFile(null); 551 if (bestModel == null) { 552 return Optional.absent(); 553 } 554 if (langIdImpl == null || !Objects.equals(langIdModelInUse, bestModel)) { 555 TcLog.d(TAG, "Loading " + bestModel); 556 final ParcelFileDescriptor pfd; 557 try { 558 pfd = 559 ParcelFileDescriptor.open( 560 new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); 561 } catch (FileNotFoundException e) { 562 TcLog.e(TAG, "Failed to open the LangID model file", e); 563 return Optional.absent(); 564 } 565 try { 566 if (pfd != null) { 567 langIdImpl = new LangIdModel(pfd.getFd()); 568 langIdModelInUse = bestModel; 569 } 570 } finally { 571 maybeCloseAndLogError(pfd); 572 } 573 } 574 return Optional.of(langIdImpl); 575 } 576 } 577 578 @Nullable 579 private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException { 580 synchronized (lock) { 581 // TODO: Use LangID to determine the locale we should use here? 582 final ModelFileManager.ModelFile bestModel = 583 actionsModelFileManager.findBestModelFile(LocaleList.getDefault()); 584 if (bestModel == null) { 585 return null; 586 } 587 if (actionsImpl == null || !Objects.equals(actionModelInUse, bestModel)) { 588 TcLog.d(TAG, "Loading " + bestModel); 589 final ParcelFileDescriptor pfd = 590 ParcelFileDescriptor.open( 591 new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); 592 try { 593 if (pfd == null) { 594 TcLog.d(TAG, "Failed to read the model file: " + bestModel.getPath()); 595 return null; 596 } 597 actionsImpl = new ActionsSuggestionsModel(pfd.getFd()); 598 actionModelInUse = bestModel; 599 } finally { 600 maybeCloseAndLogError(pfd); 601 } 602 } 603 return actionsImpl; 604 } 605 } 606 607 private String createAnnotatorId(String text, int start, int end) { 608 synchronized (lock) { 609 return ResultIdUtils.createId( 610 context, 611 text, 612 start, 613 end, 614 ModelFile.toModelInfos( 615 Optional.fromNullable(annotatorModelInUse), Optional.fromNullable(langIdModelInUse))); 616 } 617 } 618 619 private static String concatenateLocales(@Nullable LocaleList locales) { 620 return (locales == null) ? "" : locales.toLanguageTags(); 621 } 622 623 private TextClassification createClassificationResult( 624 AnnotatorModel.ClassificationResult[] classifications, 625 String text, 626 int start, 627 int end, 628 Optional<LangIdModel> langId) { 629 final String classifiedText = text.substring(start, end); 630 final TextClassification.Builder builder = 631 new TextClassification.Builder().setText(classifiedText); 632 633 final int typeCount = classifications.length; 634 AnnotatorModel.ClassificationResult highestScoringResult = 635 typeCount > 0 ? classifications[0] : null; 636 for (int i = 0; i < typeCount; i++) { 637 builder.setEntityType(classifications[i].getCollection(), classifications[i].getScore()); 638 if (classifications[i].getScore() > highestScoringResult.getScore()) { 639 highestScoringResult = classifications[i]; 640 } 641 } 642 643 boolean isPrimaryAction = true; 644 final ImmutableList<LabeledIntent> labeledIntents = 645 highestScoringResult == null 646 ? ImmutableList.of() 647 : templateIntentFactory.create(highestScoringResult.getRemoteActionTemplates()); 648 final LabeledIntent.TitleChooser titleChooser = 649 (labeledIntent, resolveInfo) -> labeledIntent.titleWithoutEntity; 650 651 ArrayList<Intent> actionIntents = new ArrayList<>(); 652 for (LabeledIntent labeledIntent : labeledIntents) { 653 final LabeledIntent.Result result = labeledIntent.resolve(context, titleChooser); 654 if (result == null) { 655 continue; 656 } 657 658 final Intent intent = result.resolvedIntent; 659 final RemoteAction action = result.remoteAction.toRemoteAction(); 660 if (isPrimaryAction) { 661 // For O backwards compatibility, the first RemoteAction is also written to the 662 // legacy API fields. 663 builder.setIcon(action.getIcon().loadDrawable(context)); 664 builder.setLabel(action.getTitle().toString()); 665 builder.setIntent(intent); 666 builder.setOnClickListener( 667 createIntentOnClickListener( 668 createPendingIntent(context, intent, labeledIntent.requestCode))); 669 isPrimaryAction = false; 670 } 671 builder.addAction(action); 672 actionIntents.add(intent); 673 } 674 Bundle extras = new Bundle(); 675 Optional<Bundle> foreignLanguageExtra = 676 langId 677 .transform(model -> maybeCreateExtrasForTranslate(actionIntents, model)) 678 .or(Optional.<Bundle>absent()); 679 if (foreignLanguageExtra.isPresent()) { 680 ExtrasUtils.putForeignLanguageExtra(extras, foreignLanguageExtra.get()); 681 } 682 if (actionIntents.stream().anyMatch(Objects::nonNull)) { 683 ArrayList<Intent> strippedIntents = 684 actionIntents.stream() 685 .map(TextClassifierImpl::stripPackageInfoFromIntent) 686 .collect(toCollection(ArrayList::new)); 687 ExtrasUtils.putActionsIntents(extras, strippedIntents); 688 } 689 ExtrasUtils.putEntities(extras, classifications); 690 builder.setExtras(extras); 691 String resultId = createAnnotatorId(text, start, end); 692 return builder.setId(resultId).build(); 693 } 694 695 private static OnClickListener createIntentOnClickListener(final PendingIntent intent) { 696 Preconditions.checkNotNull(intent); 697 return v -> { 698 try { 699 intent.send(); 700 } catch (PendingIntent.CanceledException e) { 701 TcLog.e(TAG, "Error sending PendingIntent", e); 702 } 703 }; 704 } 705 706 private static Optional<Bundle> maybeCreateExtrasForTranslate( 707 List<Intent> intents, LangIdModel langId) { 708 Optional<Intent> translateIntent = 709 FluentIterable.from(intents) 710 .filter(Objects::nonNull) 711 .filter(intent -> Intent.ACTION_TRANSLATE.equals(intent.getAction())) 712 .first(); 713 if (!translateIntent.isPresent()) { 714 return Optional.absent(); 715 } 716 Pair<String, Float> topLanguageWithScore = ExtrasUtils.getTopLanguage(translateIntent.get()); 717 if (topLanguageWithScore == null) { 718 return Optional.absent(); 719 } 720 return Optional.of( 721 ExtrasUtils.createForeignLanguageExtra( 722 topLanguageWithScore.first, topLanguageWithScore.second, langId.getVersion())); 723 } 724 725 private ImmutableList<String> detectLanguageTags( 726 Optional<LangIdModel> langId, CharSequence text) { 727 return langId 728 .transform( 729 model -> { 730 float threshold = getLangIdThreshold(model); 731 EntityConfidence languagesConfidence = detectLanguages(model, text, threshold); 732 return ImmutableList.copyOf(languagesConfidence.getEntities()); 733 }) 734 .or(ImmutableList.of()); 735 } 736 737 /** 738 * Detects languages for the specified text. Only returns languages with score that is higher than 739 * or equal to the specified threshold. 740 */ 741 private static EntityConfidence detectLanguages( 742 LangIdModel langId, CharSequence text, float threshold) { 743 final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text.toString()); 744 final Map<String, Float> languagesMap = new ArrayMap<>(); 745 for (LangIdModel.LanguageResult langResult : langResults) { 746 if (langResult.getScore() >= threshold) { 747 languagesMap.put(langResult.getLanguage(), langResult.getScore()); 748 } 749 } 750 return new EntityConfidence(languagesMap); 751 } 752 753 private float getLangIdThreshold(LangIdModel langId) { 754 return settings.getLangIdThresholdOverride() >= 0 755 ? settings.getLangIdThresholdOverride() 756 : langId.getLangIdThreshold(); 757 } 758 759 void dump(IndentingPrintWriter printWriter) { 760 synchronized (lock) { 761 printWriter.println("TextClassifierImpl:"); 762 printWriter.increaseIndent(); 763 printWriter.println("Annotator model file(s):"); 764 printWriter.increaseIndent(); 765 for (ModelFileManager.ModelFile modelFile : annotatorModelFileManager.listModelFiles()) { 766 printWriter.println(modelFile.toString()); 767 } 768 printWriter.decreaseIndent(); 769 printWriter.println("LangID model file(s):"); 770 printWriter.increaseIndent(); 771 for (ModelFileManager.ModelFile modelFile : langIdModelFileManager.listModelFiles()) { 772 printWriter.println(modelFile.toString()); 773 } 774 printWriter.decreaseIndent(); 775 printWriter.println("Actions model file(s):"); 776 printWriter.increaseIndent(); 777 for (ModelFileManager.ModelFile modelFile : actionsModelFileManager.listModelFiles()) { 778 printWriter.println(modelFile.toString()); 779 } 780 printWriter.decreaseIndent(); 781 printWriter.printPair("mFallback", fallback); 782 printWriter.decreaseIndent(); 783 printWriter.println(); 784 settings.dump(printWriter); 785 } 786 } 787 788 /** Returns the locales string for the current resources configuration. */ 789 private String getResourceLocalesString() { 790 try { 791 return context.getResources().getConfiguration().getLocales().toLanguageTags(); 792 } catch (NullPointerException e) { 793 794 // NPE is unexpected. Erring on the side of caution. 795 return LocaleList.getDefault().toLanguageTags(); 796 } 797 } 798 799 /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */ 800 private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) { 801 if (fd == null) { 802 return; 803 } 804 805 try { 806 fd.close(); 807 } catch (IOException e) { 808 TcLog.e(TAG, "Error closing file.", e); 809 } 810 } 811 812 private static void checkMainThread() { 813 if (Looper.myLooper() == Looper.getMainLooper()) { 814 TcLog.e(TAG, "TextClassifier called on main thread", new Exception()); 815 } 816 } 817 818 private static PendingIntent createPendingIntent( 819 final Context context, final Intent intent, int requestCode) { 820 return PendingIntent.getActivity( 821 context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT); 822 } 823 824 @Nullable 825 private static Intent stripPackageInfoFromIntent(@Nullable Intent intent) { 826 if (intent == null) { 827 return null; 828 } 829 Intent strippedIntent = new Intent(intent); 830 strippedIntent.setPackage(null); 831 strippedIntent.setComponent(null); 832 return strippedIntent; 833 } 834 } 835