1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.lite.support.tensorbuffer;
17 
18 import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
19 import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
20 import static org.tensorflow.lite.support.common.SupportPreconditions.checkState;
21 
22 import java.nio.ByteBuffer;
23 import java.nio.ByteOrder;
24 import java.util.Arrays;
25 import org.checkerframework.checker.nullness.qual.NonNull;
26 import org.tensorflow.lite.DataType;
27 
28 /** Represents the data buffer for either a model's input or its output. */
29 public abstract class TensorBuffer {
30   /** Where the data is stored. */
31   protected ByteBuffer buffer;
32 
33   /** Shape of the tensor stored in this buffer. */
34   protected int[] shape;
35 
36   /** Number of elements in the buffer. It will be changed to a proper value in the constructor. */
37   protected int flatSize = -1;
38 
39   /**
40    * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have
41    * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed.
42    */
43   protected final boolean isDynamic;
44 
45   /**
46    * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are some
47    * examples:
48    *
49    * <pre>
50    * Creating a float TensorBuffer with shape {2, 3}:
51    * int[] shape = new int[] {2, 3};
52    * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
53    * </pre>
54    *
55    * <pre>
56    * Creating an uint8 TensorBuffer of a scalar:
57    * int[] shape = new int[] {};
58    * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
59    * </pre>
60    *
61    * <pre>
62    * Creating an empty uint8 TensorBuffer:
63    * int[] shape = new int[] {0};
64    * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
65    * </pre>
66    *
67    * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
68    *
69    * @param shape The shape of the {@link TensorBuffer} to be created.
70    * @param dataType The dataType of the {@link TensorBuffer} to be created.
71    * @throws NullPointerException if {@code shape} is null.
72    * @throws IllegalArgumentException if {@code shape} has non-positive elements.
73    */
74   @NonNull
createFixedSize(@onNull int[] shape, DataType dataType)75   public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) {
76     switch (dataType) {
77       case FLOAT32:
78         return new TensorBufferFloat(shape);
79       case UINT8:
80         return new TensorBufferUint8(shape);
81       default:
82         throw new AssertionError("TensorBuffer does not support data type: " + dataType);
83     }
84   }
85 
86   /**
87    * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the
88    * created {@link TensorBuffer} is {0}.
89    *
90    * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
91    * different buffer sizes.
92    *
93    * @param dataType The dataType of the {@link TensorBuffer} to be created.
94    */
95   @NonNull
createDynamic(DataType dataType)96   public static TensorBuffer createDynamic(DataType dataType) {
97     switch (dataType) {
98       case FLOAT32:
99         return new TensorBufferFloat();
100       case UINT8:
101         return new TensorBufferUint8();
102       default:
103         throw new AssertionError("TensorBuffer does not support data type: " + dataType);
104     }
105   }
106 
107   /**
108    * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link DataType}.
109    *
110    * @param buffer the source {@link TensorBuffer} to copy from.
111    * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}.
112    * @throws NullPointerException if {@code buffer} is null.
113    */
114   @NonNull
createFrom(@onNull TensorBuffer buffer, DataType dataType)115   public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) {
116     checkNotNull(buffer, "Cannot create a buffer from null");
117     TensorBuffer result;
118     if (buffer.isDynamic()) {
119       result = createDynamic(dataType);
120     } else {
121       result = createFixedSize(buffer.shape, dataType);
122     }
123     // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as
124     // intermediate container.
125     // The assumption is not true when we support other data types.
126     if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) {
127       float[] data = buffer.getFloatArray();
128       result.loadArray(data, buffer.shape);
129     } else {
130       int[] data = buffer.getIntArray();
131       result.loadArray(data, buffer.shape);
132     }
133     return result;
134   }
135 
136   /** Returns the data buffer. */
137   @NonNull
getBuffer()138   public ByteBuffer getBuffer() {
139     return buffer;
140   }
141 
142   /**
143    * Gets the {@link TensorBuffer#flatSize} of the buffer.
144    *
145    * @throws IllegalStateException if the underlying data is corrupted
146    */
getFlatSize()147   public int getFlatSize() {
148     assertShapeIsCorect();
149     return flatSize;
150   }
151 
152   /**
153    * Gets the current shape. (returning a copy here to avoid unexpected modification.)
154    *
155    * @throws IllegalStateException if the underlying data is corrupted
156    */
157   @NonNull
getShape()158   public int[] getShape() {
159     assertShapeIsCorect();
160     return Arrays.copyOf(shape, shape.length);
161   }
162 
163   /** Returns the data type of this buffer. */
getDataType()164   public abstract DataType getDataType();
165 
166   /**
167    * Returns a float array of the values stored in this buffer. If the buffer is of different types
168    * than float, the values will be converted into float. For example, values in {@link
169    * TensorBufferUint8} will be converted from uint8 to float.
170    */
171   @NonNull
getFloatArray()172   public abstract float[] getFloatArray();
173 
174   /**
175    * Returns a float value at a given index. If the buffer is of different types than float, the
176    * value will be converted into float. For example, when reading a value from {@link
177    * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted from
178    * uint8 to float.
179    *
180    * <pre>
181    * For example, a TensorBuffer with shape {2, 3} that represents the following array,
182    * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
183    *
184    * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrived by:
185    * float v = tensorBuffer.getFloatValue(3);
186    * </pre>
187    *
188    * @param absIndex The absolute index of the value to be read.
189    */
getFloatValue(int absIndex)190   public abstract float getFloatValue(int absIndex);
191 
192   /**
193    * Returns an int array of the values stored in this buffer. If the buffer is of different type
194    * than int, the values will be converted into int, and loss of precision may apply. For example,
195    * getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f}, the output
196    * is {400, 23}.
197    */
198   @NonNull
getIntArray()199   public abstract int[] getIntArray();
200 
201   /**
202    * Returns an int value at a given index. If the buffer is of different types than int, the value
203    * will be converted into int. For example, when reading a value from {@link TensorBufferFloat},
204    * the value will be first read out as float, and then will be converted from float to int. Loss
205    * of precision may apply.
206    *
207    * <pre>
208    * For example, a TensorBuffer with shape {2, 3} that represents the following array,
209    * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
210    *
211    * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrived by:
212    * int v = tensorBuffer.getIntValue(3);
213    * Note that v is converted from 3.0f to 3 as a result of type conversion.
214    * </pre>
215    *
216    * @param absIndex The absolute index of the value to be read.
217    */
getIntValue(int absIndex)218   public abstract int getIntValue(int absIndex);
219 
220   /**
221    * Returns the number of bytes of a single element in the array. For example, a float buffer will
222    * return 4, and a byte buffer will return 1.
223    */
getTypeSize()224   public abstract int getTypeSize();
225 
226   /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */
isDynamic()227   public boolean isDynamic() {
228     return isDynamic;
229   }
230 
231   /**
232    * Loads an int array into this buffer with specific shape. If the buffer is of different types
233    * than int, the values will be converted into the buffer's type before being loaded into the
234    * buffer, and loss of precision may apply. For example, loading an int array with values {400,
235    * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be
236    * casted to uint8 by {255, 0}.
237    *
238    * @param src The source array to be loaded.
239    * @param shape Shape of the tensor that {@code src} represents.
240    * @throws NullPointerException if {@code src} is null.
241    * @throws NullPointerException if {@code shape} is null.
242    * @throws IllegalArgumentException if the size of the array to be loaded does not match the
243    *     specified shape.
244    */
loadArray(@onNull int[] src, @NonNull int[] shape)245   public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape);
246 
247   /**
248    * Loads an int array into this buffer. If the buffer is of different types than int, the values
249    * will be converted into the buffer's type before being loaded into the buffer, and loss of
250    * precision may apply. For example, loading an int array with values {400, -23} into a {@link
251    * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by
252    * {255, 0}.
253    *
254    * <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both
255    * fixed-size and dynamic {@link TensorBuffer}.
256    *
257    * @param src The source array to be loaded.
258    */
loadArray(@onNull int[] src)259   public void loadArray(@NonNull int[] src) {
260     loadArray(src, shape);
261   }
262 
263   /**
264    * Loads a float array into this buffer with specific shape. If the buffer is of different types
265    * than float, the values will be converted into the buffer's type before being loaded into the
266    * buffer, and loss of precision may apply. For example, loading a float array into a {@link
267    * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and
268    * then be casted to uint8 by {255, 0}.
269    *
270    * @param src The source array to be loaded.
271    * @param shape Shape of the tensor that {@code src} represents.
272    * @throws NullPointerException if {@code src} is null.
273    * @throws NullPointerException if {@code shape} is null.
274    * @throws IllegalArgumentException if the size of the array to be loaded does not match the
275    *     specified shape.
276    */
loadArray(@onNull float[] src, @NonNull int[] shape)277   public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape);
278 
279   /**
280    * Loads a float array into this buffer. If the buffer is of different types than float, the
281    * values will be converted into the buffer's type before being loaded into the buffer, and loss
282    * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8}
283    * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to
284    * uint8 by {255, 0}.
285    *
286    * <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both
287    * fixed-size and dynamic {@link TensorBuffer}.
288    *
289    * @param src The source array to be loaded.
290    */
loadArray(@onNull float[] src)291   public void loadArray(@NonNull float[] src) {
292     loadArray(src, shape);
293   }
294 
295   /**
296    * Loads a byte buffer into this {@link TensorBuffer} with specific shape.
297    *
298    * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
299    * performance concern, but if modification is necessary, please make a copy.
300    *
301    * @param buffer The byte buffer to load.
302    * @throws NullPointerException if {@code buffer} is null.
303    * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not
304    *     match or the size of {@code buffer} and {@code flatSize} do not match.
305    */
loadBuffer(@onNull ByteBuffer buffer, @NonNull int[] shape)306   public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) {
307     checkNotNull(buffer, "Byte buffer cannot be null.");
308     int flatSize = computeFlatSize(shape);
309     checkArgument(
310         (buffer.limit() == getTypeSize() * flatSize),
311         "The size of byte buffer and the shape do not match.");
312 
313     resize(shape);
314     buffer.rewind();
315     this.buffer = buffer;
316   }
317 
318   /**
319    * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of
320    * this {@link TensorBuffer}.
321    *
322    * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
323    * performance concern, but if modification is necessary, please make a copy.
324    *
325    * @param buffer The byte buffer to load.
326    */
loadBuffer(@onNull ByteBuffer buffer)327   public void loadBuffer(@NonNull ByteBuffer buffer) {
328     loadBuffer(buffer, shape);
329   }
330 
331   /**
332    * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}.
333    *
334    * @throws NullPointerException if {@code shape} is null.
335    * @throws IllegalArgumentException if {@code shape} has non-positive elements.
336    */
TensorBuffer(@onNull int[] shape)337   protected TensorBuffer(@NonNull int[] shape) {
338     isDynamic = false;
339     allocateMemory(shape);
340   }
341 
342   /** Constructs a dynamic {@link TensorBuffer} which can be resized. */
TensorBuffer()343   protected TensorBuffer() {
344     isDynamic = true;
345     // Initialize the dynamic TensorBuffer with an empty ByteBuffer.
346     allocateMemory(new int[] {0});
347   }
348 
349   /** Calculates number of elements in the buffer. */
computeFlatSize(@onNull int[] shape)350   protected static int computeFlatSize(@NonNull int[] shape) {
351     checkNotNull(shape, "Shape cannot be null.");
352     int prod = 1;
353     for (int s : shape) {
354       prod = prod * s;
355     }
356     return prod;
357   }
358 
359   /**
360    * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code
361    * shape} of src fits the buffer size.
362    */
resize(@onNull int[] shape)363   protected void resize(@NonNull int[] shape) {
364     if (isDynamic) {
365       allocateMemory(shape);
366     } else {
367       // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
368       checkArgument(Arrays.equals(shape, this.shape));
369       this.shape = shape.clone();
370     }
371   }
372 
373   /**
374    * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array, this
375    * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1.
376    *
377    * @throws NullPointerException if {@code shape} is null.
378    * @throws IllegalArgumentException if {@code shape} has negative elements.
379    */
allocateMemory(@onNull int[] shape)380   private void allocateMemory(@NonNull int[] shape) {
381     checkNotNull(shape, "TensorBuffer shape cannot be null.");
382     checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
383 
384     // Check if the new shape is the same as current shape.
385     int newFlatSize = computeFlatSize(shape);
386     this.shape = shape.clone();
387     if (flatSize == newFlatSize) {
388       return;
389     }
390 
391     // Update to the new shape.
392     flatSize = newFlatSize;
393     buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
394     buffer.order(ByteOrder.nativeOrder());
395   }
396 
397   /**
398    * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link
399    * ByteBuffer}.
400    */
assertShapeIsCorect()401   private void assertShapeIsCorect() {
402     int flatSize = computeFlatSize(shape);
403     checkState(
404         (buffer.limit() == getTypeSize() * flatSize),
405         String.format(
406             "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The"
407                 + " ByteBuffer may have been changed.",
408             buffer.limit(), Arrays.toString(shape)));
409   }
410 
411   /**
412    * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape}
413    * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to scalar.
414    */
isShapeValid(@onNull int[] shape)415   private static boolean isShapeValid(@NonNull int[] shape) {
416     if (shape.length == 0) {
417       // This shape refers to a scalar.
418       return true;
419     }
420 
421     // This shape refers to a multidimensional array.
422     for (int s : shape) {
423       // All elements in shape should be non-negative.
424       if (s < 0) {
425         return false;
426       }
427     }
428     return true;
429   }
430 }
431