1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite.support.label; 17 18 import android.content.Context; 19 import java.nio.ByteBuffer; 20 import java.util.ArrayList; 21 import java.util.Arrays; 22 import java.util.LinkedHashMap; 23 import java.util.List; 24 import java.util.Map; 25 import org.checkerframework.checker.nullness.qual.NonNull; 26 import org.tensorflow.lite.DataType; 27 import org.tensorflow.lite.support.common.SupportPreconditions; 28 import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; 29 30 /** 31 * TensorLabel is an util wrapper for TensorBuffers with meaningful labels on an axis. 32 * 33 * <p>For example, an image classification model may have an output tensor with shape as {1, 10}, 34 * where 1 is the batch size and 10 is the number of categories. In fact, on the 2nd axis, we could 35 * label each sub-tensor with the name or description of each corresponding category. {@link 36 * TensorLabel} could help converting the plain Tensor in {@link TensorBuffer} into a map from 37 * predefined labels to sub-tensors. In this case, if provided 10 labels for the 2nd axis, {@link 38 * TensorLabel} could convert the original {1, 10} Tensor to a 10 element map, each value of which 39 * is Tensor in shape {} (scalar). Usage example: 40 * 41 * <pre> 42 * TensorBuffer outputTensor = ...; 43 * {@literal List<String>} labels = FileUtil.loadLabels(context, labelFilePath); 44 * // labels the first axis with size greater than one 45 * TensorLabel labeled = new TensorLabel(labels, outputTensor); 46 * // If each sub-tensor has effectively size 1, we can directly get a float value 47 * {@literal Map<String, Float>} probabilities = labeled.getMapWithFloatValue(); 48 * // Or get sub-tensors, when each sub-tensor has elements more than 1 49 * {@literal Map<String, TensorBuffer>} subTensors = labeled.getMapWithTensorBuffer(); 50 * </pre> 51 * 52 * <p>Note: currently we only support tensor-to-map conversion for the first label with size greater 53 * than 1. 54 * 55 * @see org.tensorflow.lite.support.common.FileUtil#loadLabels(Context, String) to load labels from 56 * a label file (plain text file whose each line is a label) in assets simply. 57 */ 58 public class TensorLabel { 59 private final Map<Integer, List<String>> axisLabels; 60 private final TensorBuffer tensorBuffer; 61 private final int[] shape; 62 63 /** 64 * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors. 65 * 66 * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding 67 * labels. Note: The size of labels should be same with the size of the tensor on that axis. 68 * @param tensorBuffer The TensorBuffer to be labeled. 69 * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any 70 * value in {@code axisLabels} is null. 71 * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared to 72 * the shape of {@code tensorBuffer}, or any value (labels) has different size with the {@code 73 * tensorBuffer} on the given dimension. 74 */ TensorLabel( @onNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer)75 public TensorLabel( 76 @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) { 77 SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null."); 78 SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null."); 79 this.axisLabels = axisLabels; 80 this.tensorBuffer = tensorBuffer; 81 this.shape = tensorBuffer.getShape(); 82 for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) { 83 int axis = entry.getKey(); 84 SupportPreconditions.checkArgument( 85 axis >= 0 && axis < shape.length, "Invalid axis id: " + axis); 86 SupportPreconditions.checkNotNull(entry.getValue(), "Label list is null on axis " + axis); 87 SupportPreconditions.checkArgument( 88 shape[axis] == entry.getValue().size(), 89 "Label number " + entry.getValue().size() + " mismatch the shape on axis " + axis); 90 } 91 } 92 93 /** 94 * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors. 95 * 96 * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, if 97 * the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting from 98 * 0), and size of {@code axisLabels} should be 10 as well. 99 * 100 * @param axisLabels A list of labels, whose size should be same with the size of the tensor on 101 * the to-be-labeled axis. 102 * @param tensorBuffer The TensorBuffer to be labeled. 103 */ TensorLabel(@onNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer)104 public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) { 105 this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer); 106 } 107 108 /** 109 * Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the 110 * mapping on the first axis with size greater than 1 currently. 111 */ 112 @NonNull getMapWithTensorBuffer()113 public Map<String, TensorBuffer> getMapWithTensorBuffer() { 114 int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); 115 116 Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>(); 117 SupportPreconditions.checkArgument( 118 axisLabels.containsKey(labeledAxis), 119 "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis."); 120 List<String> labels = axisLabels.get(labeledAxis); 121 122 DataType dataType = tensorBuffer.getDataType(); 123 int typeSize = tensorBuffer.getTypeSize(); 124 int flatSize = tensorBuffer.getFlatSize(); 125 126 // Gets the underlying bytes that could be used to generate the sub-array later. 127 ByteBuffer byteBuffer = tensorBuffer.getBuffer(); 128 byteBuffer.rewind(); 129 130 // Note: computation below is only correct when labeledAxis is the first axis with size greater 131 // than 1. 132 int subArrayLength = flatSize / shape[labeledAxis] * typeSize; 133 int i = 0; 134 SupportPreconditions.checkNotNull(labels, "Label list should never be null"); 135 for (String label : labels) { 136 // Gets the corresponding TensorBuffer. 137 byteBuffer.position(i * subArrayLength); 138 ByteBuffer subBuffer = byteBuffer.slice(); 139 // ByteBuffer.slice doesn't keep order. Modify it to align with the original one. 140 subBuffer.order(byteBuffer.order()).limit(subArrayLength); 141 TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType); 142 labelBuffer.loadBuffer(subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length)); 143 labelToTensorMap.put(label, labelBuffer); 144 i += 1; 145 } 146 return labelToTensorMap; 147 } 148 149 /** 150 * Gets a map that maps label to float. Only allow the mapping on the first axis with size greater 151 * than 1, and the axis should be effectively the last axis (which means every sub tensor 152 * specified by this axis should have a flat size of 1). 153 * 154 * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result. 155 * 156 * @throws IllegalStateException if size of a sub tensor on each label is not 1. 157 */ 158 @NonNull getMapWithFloatValue()159 public Map<String, Float> getMapWithFloatValue() { 160 int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); 161 SupportPreconditions.checkState( 162 labeledAxis == shape.length - 1, 163 "get a <String, Scalar> map is only valid when the only labeled axis is the last one."); 164 List<String> labels = axisLabels.get(labeledAxis); 165 float[] data = tensorBuffer.getFloatArray(); 166 SupportPreconditions.checkState(labels.size() == data.length); 167 Map<String, Float> result = new LinkedHashMap<>(); 168 int i = 0; 169 for (String label : labels) { 170 result.put(label, data[i]); 171 i += 1; 172 } 173 return result; 174 } 175 176 /** 177 * Gets a list of {@link Category} from the {@link TensorLabel} object. 178 * 179 * <p>The axis of label should be effectively the last axis (which means every sub tensor 180 * specified by this axis should have a flat size of 1), so that each labelled sub tensor could be 181 * converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}} 182 * and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}. 183 * 184 * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as 185 * the result. 186 * 187 * @throws IllegalStateException if size of a sub tensor on each label is not 1. 188 */ 189 @NonNull getCategoryList()190 public List<Category> getCategoryList() { 191 int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); 192 SupportPreconditions.checkState( 193 labeledAxis == shape.length - 1, 194 "get a Category list is only valid when the only labeled axis is the last one."); 195 List<String> labels = axisLabels.get(labeledAxis); 196 float[] data = tensorBuffer.getFloatArray(); 197 SupportPreconditions.checkState(labels.size() == data.length); 198 List<Category> result = new ArrayList<>(); 199 int i = 0; 200 for (String label : labels) { 201 result.add(new Category(label, data[i])); 202 i += 1; 203 } 204 return result; 205 } 206 getFirstAxisWithSizeGreaterThanOne(@onNull TensorBuffer tensorBuffer)207 private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) { 208 int[] shape = tensorBuffer.getShape(); 209 for (int i = 0; i < shape.length; i++) { 210 if (shape[i] > 1) { 211 return i; 212 } 213 } 214 throw new IllegalArgumentException( 215 "Cannot find an axis to label. A valid axis to label should have size larger than 1."); 216 } 217 218 // Helper function to wrap the List<String> to a one-entry map. makeMap(int axis, List<String> labels)219 private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) { 220 Map<Integer, List<String>> map = new LinkedHashMap<>(); 221 map.put(axis, labels); 222 return map; 223 } 224 } 225