1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier;
18 
19 import androidx.annotation.FloatRange;
20 import androidx.collection.ArrayMap;
21 import com.google.common.base.Preconditions;
22 import java.util.ArrayList;
23 import java.util.Collections;
24 import java.util.List;
25 import java.util.Map;
26 
27 /** Helper object for setting and getting entity scores for classified text. */
28 final class EntityConfidence {
29 
30   static final EntityConfidence EMPTY = new EntityConfidence(Collections.emptyMap());
31 
32   private final ArrayMap<String, Float> entityConfidence = new ArrayMap<>();
33   private final ArrayList<String> sortedEntities = new ArrayList<>();
34 
35   /**
36    * Constructs an EntityConfidence from a map of entity to confidence.
37    *
38    * <p>Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1.
39    *
40    * @param source a map from entity to a confidence value in the range 0 (low confidence) to 1
41    *     (high confidence).
42    */
43   EntityConfidence(Map<String, Float> source) {
44     Preconditions.checkNotNull(source);
45 
46     // Prune non-existent entities and clamp to 1.
47     entityConfidence.ensureCapacity(source.size());
48     for (Map.Entry<String, Float> it : source.entrySet()) {
49       if (it.getValue() <= 0) {
50         continue;
51       }
52       entityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
53     }
54     resetSortedEntitiesFromMap();
55   }
56 
57   /**
58    * Returns an immutable list of entities found in the classified text ordered from high confidence
59    * to low confidence.
60    */
61   public List<String> getEntities() {
62     return Collections.unmodifiableList(sortedEntities);
63   }
64 
65   /**
66    * Returns the confidence score for the specified entity. The value ranges from 0 (low confidence)
67    * to 1 (high confidence). 0 indicates that the entity was not found for the classified text.
68    */
69   @FloatRange(from = 0.0, to = 1.0)
70   public float getConfidenceScore(String entity) {
71     return entityConfidence.getOrDefault(entity, 0f);
72   }
73 
74   @Override
75   public String toString() {
76     return entityConfidence.toString();
77   }
78 
79   private void resetSortedEntitiesFromMap() {
80     sortedEntities.clear();
81     sortedEntities.ensureCapacity(entityConfidence.size());
82     sortedEntities.addAll(entityConfidence.keySet());
83     sortedEntities.sort(
84         (e1, e2) -> {
85           float score1 = entityConfidence.get(e1);
86           float score2 = entityConfidence.get(e2);
87           return Float.compare(score2, score1);
88         });
89   }
90 }
91