1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite.support.tensorbuffer; 17 18 import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; 19 import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull; 20 import static org.tensorflow.lite.support.common.SupportPreconditions.checkState; 21 22 import java.nio.ByteBuffer; 23 import java.nio.ByteOrder; 24 import java.util.Arrays; 25 import org.checkerframework.checker.nullness.qual.NonNull; 26 import org.tensorflow.lite.DataType; 27 28 /** Represents the data buffer for either a model's input or its output. */ 29 public abstract class TensorBuffer { 30 /** Where the data is stored. */ 31 protected ByteBuffer buffer; 32 33 /** Shape of the tensor stored in this buffer. */ 34 protected int[] shape; 35 36 /** Number of elements in the buffer. It will be changed to a proper value in the constructor. */ 37 protected int flatSize = -1; 38 39 /** 40 * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have 41 * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed. 42 */ 43 protected final boolean isDynamic; 44 45 /** 46 * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are some 47 * examples: 48 * 49 * <pre> 50 * Creating a float TensorBuffer with shape {2, 3}: 51 * int[] shape = new int[] {2, 3}; 52 * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); 53 * </pre> 54 * 55 * <pre> 56 * Creating an uint8 TensorBuffer of a scalar: 57 * int[] shape = new int[] {}; 58 * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); 59 * </pre> 60 * 61 * <pre> 62 * Creating an empty uint8 TensorBuffer: 63 * int[] shape = new int[] {0}; 64 * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); 65 * </pre> 66 * 67 * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created. 68 * 69 * @param shape The shape of the {@link TensorBuffer} to be created. 70 * @param dataType The dataType of the {@link TensorBuffer} to be created. 71 * @throws NullPointerException if {@code shape} is null. 72 * @throws IllegalArgumentException if {@code shape} has non-positive elements. 73 */ 74 @NonNull createFixedSize(@onNull int[] shape, DataType dataType)75 public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) { 76 switch (dataType) { 77 case FLOAT32: 78 return new TensorBufferFloat(shape); 79 case UINT8: 80 return new TensorBufferUint8(shape); 81 default: 82 throw new AssertionError("TensorBuffer does not support data type: " + dataType); 83 } 84 } 85 86 /** 87 * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the 88 * created {@link TensorBuffer} is {0}. 89 * 90 * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of 91 * different buffer sizes. 92 * 93 * @param dataType The dataType of the {@link TensorBuffer} to be created. 94 */ 95 @NonNull createDynamic(DataType dataType)96 public static TensorBuffer createDynamic(DataType dataType) { 97 switch (dataType) { 98 case FLOAT32: 99 return new TensorBufferFloat(); 100 case UINT8: 101 return new TensorBufferUint8(); 102 default: 103 throw new AssertionError("TensorBuffer does not support data type: " + dataType); 104 } 105 } 106 107 /** 108 * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link DataType}. 109 * 110 * @param buffer the source {@link TensorBuffer} to copy from. 111 * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}. 112 * @throws NullPointerException if {@code buffer} is null. 113 */ 114 @NonNull createFrom(@onNull TensorBuffer buffer, DataType dataType)115 public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) { 116 checkNotNull(buffer, "Cannot create a buffer from null"); 117 TensorBuffer result; 118 if (buffer.isDynamic()) { 119 result = createDynamic(dataType); 120 } else { 121 result = createFixedSize(buffer.shape, dataType); 122 } 123 // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as 124 // intermediate container. 125 // The assumption is not true when we support other data types. 126 if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) { 127 float[] data = buffer.getFloatArray(); 128 result.loadArray(data, buffer.shape); 129 } else { 130 int[] data = buffer.getIntArray(); 131 result.loadArray(data, buffer.shape); 132 } 133 return result; 134 } 135 136 /** Returns the data buffer. */ 137 @NonNull getBuffer()138 public ByteBuffer getBuffer() { 139 return buffer; 140 } 141 142 /** 143 * Gets the {@link TensorBuffer#flatSize} of the buffer. 144 * 145 * @throws IllegalStateException if the underlying data is corrupted 146 */ getFlatSize()147 public int getFlatSize() { 148 assertShapeIsCorect(); 149 return flatSize; 150 } 151 152 /** 153 * Gets the current shape. (returning a copy here to avoid unexpected modification.) 154 * 155 * @throws IllegalStateException if the underlying data is corrupted 156 */ 157 @NonNull getShape()158 public int[] getShape() { 159 assertShapeIsCorect(); 160 return Arrays.copyOf(shape, shape.length); 161 } 162 163 /** Returns the data type of this buffer. */ getDataType()164 public abstract DataType getDataType(); 165 166 /** 167 * Returns a float array of the values stored in this buffer. If the buffer is of different types 168 * than float, the values will be converted into float. For example, values in {@link 169 * TensorBufferUint8} will be converted from uint8 to float. 170 */ 171 @NonNull getFloatArray()172 public abstract float[] getFloatArray(); 173 174 /** 175 * Returns a float value at a given index. If the buffer is of different types than float, the 176 * value will be converted into float. For example, when reading a value from {@link 177 * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted from 178 * uint8 to float. 179 * 180 * <pre> 181 * For example, a TensorBuffer with shape {2, 3} that represents the following array, 182 * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]]. 183 * 184 * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrived by: 185 * float v = tensorBuffer.getFloatValue(3); 186 * </pre> 187 * 188 * @param absIndex The absolute index of the value to be read. 189 */ getFloatValue(int absIndex)190 public abstract float getFloatValue(int absIndex); 191 192 /** 193 * Returns an int array of the values stored in this buffer. If the buffer is of different type 194 * than int, the values will be converted into int, and loss of precision may apply. For example, 195 * getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f}, the output 196 * is {400, 23}. 197 */ 198 @NonNull getIntArray()199 public abstract int[] getIntArray(); 200 201 /** 202 * Returns an int value at a given index. If the buffer is of different types than int, the value 203 * will be converted into int. For example, when reading a value from {@link TensorBufferFloat}, 204 * the value will be first read out as float, and then will be converted from float to int. Loss 205 * of precision may apply. 206 * 207 * <pre> 208 * For example, a TensorBuffer with shape {2, 3} that represents the following array, 209 * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]]. 210 * 211 * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrived by: 212 * int v = tensorBuffer.getIntValue(3); 213 * Note that v is converted from 3.0f to 3 as a result of type conversion. 214 * </pre> 215 * 216 * @param absIndex The absolute index of the value to be read. 217 */ getIntValue(int absIndex)218 public abstract int getIntValue(int absIndex); 219 220 /** 221 * Returns the number of bytes of a single element in the array. For example, a float buffer will 222 * return 4, and a byte buffer will return 1. 223 */ getTypeSize()224 public abstract int getTypeSize(); 225 226 /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */ isDynamic()227 public boolean isDynamic() { 228 return isDynamic; 229 } 230 231 /** 232 * Loads an int array into this buffer with specific shape. If the buffer is of different types 233 * than int, the values will be converted into the buffer's type before being loaded into the 234 * buffer, and loss of precision may apply. For example, loading an int array with values {400, 235 * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be 236 * casted to uint8 by {255, 0}. 237 * 238 * @param src The source array to be loaded. 239 * @param shape Shape of the tensor that {@code src} represents. 240 * @throws NullPointerException if {@code src} is null. 241 * @throws NullPointerException if {@code shape} is null. 242 * @throws IllegalArgumentException if the size of the array to be loaded does not match the 243 * specified shape. 244 */ loadArray(@onNull int[] src, @NonNull int[] shape)245 public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape); 246 247 /** 248 * Loads an int array into this buffer. If the buffer is of different types than int, the values 249 * will be converted into the buffer's type before being loaded into the buffer, and loss of 250 * precision may apply. For example, loading an int array with values {400, -23} into a {@link 251 * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by 252 * {255, 0}. 253 * 254 * <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both 255 * fixed-size and dynamic {@link TensorBuffer}. 256 * 257 * @param src The source array to be loaded. 258 */ loadArray(@onNull int[] src)259 public void loadArray(@NonNull int[] src) { 260 loadArray(src, shape); 261 } 262 263 /** 264 * Loads a float array into this buffer with specific shape. If the buffer is of different types 265 * than float, the values will be converted into the buffer's type before being loaded into the 266 * buffer, and loss of precision may apply. For example, loading a float array into a {@link 267 * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and 268 * then be casted to uint8 by {255, 0}. 269 * 270 * @param src The source array to be loaded. 271 * @param shape Shape of the tensor that {@code src} represents. 272 * @throws NullPointerException if {@code src} is null. 273 * @throws NullPointerException if {@code shape} is null. 274 * @throws IllegalArgumentException if the size of the array to be loaded does not match the 275 * specified shape. 276 */ loadArray(@onNull float[] src, @NonNull int[] shape)277 public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape); 278 279 /** 280 * Loads a float array into this buffer. If the buffer is of different types than float, the 281 * values will be converted into the buffer's type before being loaded into the buffer, and loss 282 * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8} 283 * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to 284 * uint8 by {255, 0}. 285 * 286 * <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both 287 * fixed-size and dynamic {@link TensorBuffer}. 288 * 289 * @param src The source array to be loaded. 290 */ loadArray(@onNull float[] src)291 public void loadArray(@NonNull float[] src) { 292 loadArray(src, shape); 293 } 294 295 /** 296 * Loads a byte buffer into this {@link TensorBuffer} with specific shape. 297 * 298 * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for 299 * performance concern, but if modification is necessary, please make a copy. 300 * 301 * @param buffer The byte buffer to load. 302 * @throws NullPointerException if {@code buffer} is null. 303 * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not 304 * match or the size of {@code buffer} and {@code flatSize} do not match. 305 */ loadBuffer(@onNull ByteBuffer buffer, @NonNull int[] shape)306 public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) { 307 checkNotNull(buffer, "Byte buffer cannot be null."); 308 int flatSize = computeFlatSize(shape); 309 checkArgument( 310 (buffer.limit() == getTypeSize() * flatSize), 311 "The size of byte buffer and the shape do not match."); 312 313 resize(shape); 314 buffer.rewind(); 315 this.buffer = buffer; 316 } 317 318 /** 319 * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of 320 * this {@link TensorBuffer}. 321 * 322 * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for 323 * performance concern, but if modification is necessary, please make a copy. 324 * 325 * @param buffer The byte buffer to load. 326 */ loadBuffer(@onNull ByteBuffer buffer)327 public void loadBuffer(@NonNull ByteBuffer buffer) { 328 loadBuffer(buffer, shape); 329 } 330 331 /** 332 * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}. 333 * 334 * @throws NullPointerException if {@code shape} is null. 335 * @throws IllegalArgumentException if {@code shape} has non-positive elements. 336 */ TensorBuffer(@onNull int[] shape)337 protected TensorBuffer(@NonNull int[] shape) { 338 isDynamic = false; 339 allocateMemory(shape); 340 } 341 342 /** Constructs a dynamic {@link TensorBuffer} which can be resized. */ TensorBuffer()343 protected TensorBuffer() { 344 isDynamic = true; 345 // Initialize the dynamic TensorBuffer with an empty ByteBuffer. 346 allocateMemory(new int[] {0}); 347 } 348 349 /** Calculates number of elements in the buffer. */ computeFlatSize(@onNull int[] shape)350 protected static int computeFlatSize(@NonNull int[] shape) { 351 checkNotNull(shape, "Shape cannot be null."); 352 int prod = 1; 353 for (int s : shape) { 354 prod = prod * s; 355 } 356 return prod; 357 } 358 359 /** 360 * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code 361 * shape} of src fits the buffer size. 362 */ resize(@onNull int[] shape)363 protected void resize(@NonNull int[] shape) { 364 if (isDynamic) { 365 allocateMemory(shape); 366 } else { 367 // Make sure the new shape fits the buffer size when TensorBuffer has fixed size. 368 checkArgument(Arrays.equals(shape, this.shape)); 369 this.shape = shape.clone(); 370 } 371 } 372 373 /** 374 * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array, this 375 * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1. 376 * 377 * @throws NullPointerException if {@code shape} is null. 378 * @throws IllegalArgumentException if {@code shape} has negative elements. 379 */ allocateMemory(@onNull int[] shape)380 private void allocateMemory(@NonNull int[] shape) { 381 checkNotNull(shape, "TensorBuffer shape cannot be null."); 382 checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative."); 383 384 // Check if the new shape is the same as current shape. 385 int newFlatSize = computeFlatSize(shape); 386 this.shape = shape.clone(); 387 if (flatSize == newFlatSize) { 388 return; 389 } 390 391 // Update to the new shape. 392 flatSize = newFlatSize; 393 buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize()); 394 buffer.order(ByteOrder.nativeOrder()); 395 } 396 397 /** 398 * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link 399 * ByteBuffer}. 400 */ assertShapeIsCorect()401 private void assertShapeIsCorect() { 402 int flatSize = computeFlatSize(shape); 403 checkState( 404 (buffer.limit() == getTypeSize() * flatSize), 405 String.format( 406 "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The" 407 + " ByteBuffer may have been changed.", 408 buffer.limit(), Arrays.toString(shape))); 409 } 410 411 /** 412 * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape} 413 * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to scalar. 414 */ isShapeValid(@onNull int[] shape)415 private static boolean isShapeValid(@NonNull int[] shape) { 416 if (shape.length == 0) { 417 // This shape refers to a scalar. 418 return true; 419 } 420 421 // This shape refers to a multidimensional array. 422 for (int s : shape) { 423 // All elements in shape should be non-negative. 424 if (s < 0) { 425 return false; 426 } 427 } 428 return true; 429 } 430 } 431