1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.lite.support.label;
17 
18 import android.content.Context;
19 import java.nio.ByteBuffer;
20 import java.util.ArrayList;
21 import java.util.Arrays;
22 import java.util.LinkedHashMap;
23 import java.util.List;
24 import java.util.Map;
25 import org.checkerframework.checker.nullness.qual.NonNull;
26 import org.tensorflow.lite.DataType;
27 import org.tensorflow.lite.support.common.SupportPreconditions;
28 import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
29 
30 /**
31  * TensorLabel is an util wrapper for TensorBuffers with meaningful labels on an axis.
32  *
33  * <p>For example, an image classification model may have an output tensor with shape as {1, 10},
34  * where 1 is the batch size and 10 is the number of categories. In fact, on the 2nd axis, we could
35  * label each sub-tensor with the name or description of each corresponding category. {@link
36  * TensorLabel} could help converting the plain Tensor in {@link TensorBuffer} into a map from
37  * predefined labels to sub-tensors. In this case, if provided 10 labels for the 2nd axis, {@link
38  * TensorLabel} could convert the original {1, 10} Tensor to a 10 element map, each value of which
39  * is Tensor in shape {} (scalar). Usage example:
40  *
41  * <pre>
42  *   TensorBuffer outputTensor = ...;
43  *   {@literal List<String>} labels = FileUtil.loadLabels(context, labelFilePath);
44  *   // labels the first axis with size greater than one
45  *   TensorLabel labeled = new TensorLabel(labels, outputTensor);
46  *   // If each sub-tensor has effectively size 1, we can directly get a float value
47  *   {@literal Map<String, Float>} probabilities = labeled.getMapWithFloatValue();
48  *   // Or get sub-tensors, when each sub-tensor has elements more than 1
49  *   {@literal Map<String, TensorBuffer>} subTensors = labeled.getMapWithTensorBuffer();
50  * </pre>
51  *
52  * <p>Note: currently we only support tensor-to-map conversion for the first label with size greater
53  * than 1.
54  *
55  * @see org.tensorflow.lite.support.common.FileUtil#loadLabels(Context, String) to load labels from
56  *     a label file (plain text file whose each line is a label) in assets simply.
57  */
58 public class TensorLabel {
59   private final Map<Integer, List<String>> axisLabels;
60   private final TensorBuffer tensorBuffer;
61   private final int[] shape;
62 
63   /**
64    * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors.
65    *
66    * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding
67    *     labels. Note: The size of labels should be same with the size of the tensor on that axis.
68    * @param tensorBuffer The TensorBuffer to be labeled.
69    * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any
70    *     value in {@code axisLabels} is null.
71    * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared to
72    *     the shape of {@code tensorBuffer}, or any value (labels) has different size with the {@code
73    *     tensorBuffer} on the given dimension.
74    */
TensorLabel( @onNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer)75   public TensorLabel(
76       @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) {
77     SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null.");
78     SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null.");
79     this.axisLabels = axisLabels;
80     this.tensorBuffer = tensorBuffer;
81     this.shape = tensorBuffer.getShape();
82     for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
83       int axis = entry.getKey();
84       SupportPreconditions.checkArgument(
85           axis >= 0 && axis < shape.length, "Invalid axis id: " + axis);
86       SupportPreconditions.checkNotNull(entry.getValue(), "Label list is null on axis " + axis);
87       SupportPreconditions.checkArgument(
88           shape[axis] == entry.getValue().size(),
89           "Label number " + entry.getValue().size() + " mismatch the shape on axis " + axis);
90     }
91   }
92 
93   /**
94    * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors.
95    *
96    * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, if
97    * the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting from
98    * 0), and size of {@code axisLabels} should be 10 as well.
99    *
100    * @param axisLabels A list of labels, whose size should be same with the size of the tensor on
101    *     the to-be-labeled axis.
102    * @param tensorBuffer The TensorBuffer to be labeled.
103    */
TensorLabel(@onNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer)104   public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) {
105     this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer);
106   }
107 
108   /**
109    * Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the
110    * mapping on the first axis with size greater than 1 currently.
111    */
112   @NonNull
getMapWithTensorBuffer()113   public Map<String, TensorBuffer> getMapWithTensorBuffer() {
114     int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
115 
116     Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>();
117     SupportPreconditions.checkArgument(
118         axisLabels.containsKey(labeledAxis),
119         "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis.");
120     List<String> labels = axisLabels.get(labeledAxis);
121 
122     DataType dataType = tensorBuffer.getDataType();
123     int typeSize = tensorBuffer.getTypeSize();
124     int flatSize = tensorBuffer.getFlatSize();
125 
126     // Gets the underlying bytes that could be used to generate the sub-array later.
127     ByteBuffer byteBuffer = tensorBuffer.getBuffer();
128     byteBuffer.rewind();
129 
130     // Note: computation below is only correct when labeledAxis is the first axis with size greater
131     // than 1.
132     int subArrayLength = flatSize / shape[labeledAxis] * typeSize;
133     int i = 0;
134     SupportPreconditions.checkNotNull(labels, "Label list should never be null");
135     for (String label : labels) {
136       // Gets the corresponding TensorBuffer.
137       byteBuffer.position(i * subArrayLength);
138       ByteBuffer subBuffer = byteBuffer.slice();
139       // ByteBuffer.slice doesn't keep order. Modify it to align with the original one.
140       subBuffer.order(byteBuffer.order()).limit(subArrayLength);
141       TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
142       labelBuffer.loadBuffer(subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length));
143       labelToTensorMap.put(label, labelBuffer);
144       i += 1;
145     }
146     return labelToTensorMap;
147   }
148 
149   /**
150    * Gets a map that maps label to float. Only allow the mapping on the first axis with size greater
151    * than 1, and the axis should be effectively the last axis (which means every sub tensor
152    * specified by this axis should have a flat size of 1).
153    *
154    * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
155    *
156    * @throws IllegalStateException if size of a sub tensor on each label is not 1.
157    */
158   @NonNull
getMapWithFloatValue()159   public Map<String, Float> getMapWithFloatValue() {
160     int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
161     SupportPreconditions.checkState(
162         labeledAxis == shape.length - 1,
163         "get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
164     List<String> labels = axisLabels.get(labeledAxis);
165     float[] data = tensorBuffer.getFloatArray();
166     SupportPreconditions.checkState(labels.size() == data.length);
167     Map<String, Float> result = new LinkedHashMap<>();
168     int i = 0;
169     for (String label : labels) {
170       result.put(label, data[i]);
171       i += 1;
172     }
173     return result;
174   }
175 
176   /**
177    * Gets a list of {@link Category} from the {@link TensorLabel} object.
178    *
179    * <p>The axis of label should be effectively the last axis (which means every sub tensor
180    * specified by this axis should have a flat size of 1), so that each labelled sub tensor could be
181    * converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}}
182    * and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}.
183    *
184    * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
185    * the result.
186    *
187    * @throws IllegalStateException if size of a sub tensor on each label is not 1.
188    */
189   @NonNull
getCategoryList()190   public List<Category> getCategoryList() {
191     int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
192     SupportPreconditions.checkState(
193         labeledAxis == shape.length - 1,
194         "get a Category list is only valid when the only labeled axis is the last one.");
195     List<String> labels = axisLabels.get(labeledAxis);
196     float[] data = tensorBuffer.getFloatArray();
197     SupportPreconditions.checkState(labels.size() == data.length);
198     List<Category> result = new ArrayList<>();
199     int i = 0;
200     for (String label : labels) {
201       result.add(new Category(label, data[i]));
202       i += 1;
203     }
204     return result;
205   }
206 
getFirstAxisWithSizeGreaterThanOne(@onNull TensorBuffer tensorBuffer)207   private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
208     int[] shape = tensorBuffer.getShape();
209     for (int i = 0; i < shape.length; i++) {
210       if (shape[i] > 1) {
211         return i;
212       }
213     }
214     throw new IllegalArgumentException(
215         "Cannot find an axis to label. A valid axis to label should have size larger than 1.");
216   }
217 
218   // Helper function to wrap the List<String> to a one-entry map.
makeMap(int axis, List<String> labels)219   private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) {
220     Map<Integer, List<String>> map = new LinkedHashMap<>();
221     map.put(axis, labels);
222     return map;
223   }
224 }
225