1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.lite.support.image;
17 
18 import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
19 
20 import android.graphics.Bitmap;
21 import java.nio.ByteBuffer;
22 import org.tensorflow.lite.DataType;
23 import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
24 
25 /**
26  * TensorImage is the wrapper class for Image object. When using image processing utils in
27  * TFLite.support library, it's common to convert image objects in variant types to TensorImage at
28  * first.
29  *
30  * <p>At present, only RGB images are supported, and the A channel is always ignored.
31  *
32  * <p>Details of data storage: a {@link TensorImage} object may have 2 potential sources of truth: a
33  * {@link Bitmap} or a {@link TensorBuffer}. {@link TensorImage} maintains the state and only
34  * converts one to the other when needed. A typical use case of {@link TensorImage} is to first load
35  * a {@link Bitmap} image, then process it using {@link ImageProcessor}, and finally get the
36  * underlying {@link ByteBuffer} of the {@link TensorBuffer} and feed it into the TFLite
37  * interpreter.
38  *
39  * <p>IMPORTANT: to achieve the best performance, {@link TensorImage} avoids copying data whenever
40  * it's possible. Therefore, it doesn't own its data. Callers should not modify data objects those
41  * are passed to {@link TensorImage#load(Bitmap)} or {@link TensorImage#load(TensorBuffer)}.
42  *
43  * <p>IMPORTANT: all methods are not proved thread-safe.
44  *
45  * @see ImageProcessor which is often used for transforming a {@link TensorImage}.
46  */
47 // TODO(b/138907116): Support loading images from TensorBuffer with properties.
48 // TODO(b/138905544): Support directly loading RGBBytes, YUVBytes and other types if necessary.
49 public class TensorImage {
50 
51   private final DataType dataType;
52   private ImageContainer container = null;
53 
54   /**
55    * Initializes a {@link TensorImage} object.
56    *
57    * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link
58    * #TensorImage(DataType)} if other data types are preferred.
59    */
TensorImage()60   public TensorImage() {
61     this(DataType.UINT8);
62   }
63 
64   /**
65    * Initializes a {@link TensorImage} object with the specified data type.
66    *
67    * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage},
68    * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be
69    * converted to the specified data type.
70    *
71    * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of
72    * the image being loaded to this {@link TensorImage}.
73    *
74    * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is
75    *     always fixed during the lifetime of the {@link TensorImage}. To convert the data type, use
76    *     {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the
77    *     same time.
78    * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor
79    *     {@link DataType#FLOAT32}
80    */
TensorImage(DataType dataType)81   public TensorImage(DataType dataType) {
82     checkArgument(
83         dataType == DataType.UINT8 || dataType == DataType.FLOAT32,
84         "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted");
85     this.dataType = dataType;
86   }
87 
88   /**
89    * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link Bitmap} .
90    *
91    * @see TensorImage#load(Bitmap) for reusing the object when it's expensive to create objects
92    *     frequently, because every call of {@code fromBitmap} creates a new {@link TensorImage}.
93    */
fromBitmap(Bitmap bitmap)94   public static TensorImage fromBitmap(Bitmap bitmap) {
95     TensorImage image = new TensorImage();
96     image.load(bitmap);
97     return image;
98   }
99 
100   /**
101    * Creates a deep-copy of a given {@link TensorImage} with the desired data type.
102    *
103    * @param src the {@link TensorImage} to copy from
104    * @param dataType the expected data type of newly created {@link TensorImage}
105    * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code
106    *     dataType}
107    */
createFrom(TensorImage src, DataType dataType)108   public static TensorImage createFrom(TensorImage src, DataType dataType) {
109     TensorImage dst = new TensorImage(dataType);
110     dst.container = src.container.clone();
111     return dst;
112   }
113 
114   /**
115    * Loads a {@link Bitmap} image object into this {@link TensorImage}.
116    *
117    * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric
118    * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
119    * #getBuffer}, where the {@link Bitmap} will be converted into a {@link TensorBuffer}.
120    *
121    * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. The
122    * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as well.
123    * In this method, we perform a zero-copy approach for that bitmap, by simply holding its
124    * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary.
125    *
126    * <p>Note: to get the best performance, please load images in the same shape to avoid memory
127    * re-allocation.
128    *
129    * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888
130    */
load(Bitmap bitmap)131   public void load(Bitmap bitmap) {
132     container = BitmapContainer.create(bitmap);
133   }
134 
135   /**
136    * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels
137    * inside.
138    *
139    * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32},
140    * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
141    * #getBuffer}.
142    *
143    * @param pixels the RGB pixels representing the image
144    * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
145    * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
146    */
load(float[] pixels, int[] shape)147   public void load(float[] pixels, int[] shape) {
148     TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
149     buffer.loadArray(pixels, shape);
150     load(buffer);
151   }
152 
153   /**
154    * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels inside.
155    *
156    * <p>Note: numeric casting and clamping will be applied to convert the values into the data type
157    * of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link #getBuffer}.
158    *
159    * @param pixels the RGB pixels representing the image
160    * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
161    * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
162    */
load(int[] pixels, int[] shape)163   public void load(int[] pixels, int[] shape) {
164     TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
165     buffer.loadArray(pixels, shape);
166     load(buffer);
167   }
168 
169   /**
170    * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB.
171    *
172    * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
173    * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
174    * #getBuffer}.
175    *
176    * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
177    *     (1, h, w, 3)
178    * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
179    */
load(TensorBuffer buffer)180   public void load(TensorBuffer buffer) {
181     load(buffer, ColorSpaceType.RGB);
182   }
183 
184   /**
185    * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ColorSapceType}.
186    *
187    * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
188    * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
189    * #getBuffer}.
190    *
191    * @throws IllegalArgumentException if the shape of buffer does not match the color space type
192    * @see ColorSpaceType#assertShape
193    */
load(TensorBuffer buffer, ColorSpaceType colorSpaceType)194   public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
195     container = TensorBufferContainer.create(buffer, colorSpaceType);
196   }
197 
198   /**
199    * Returns a {@link Bitmap} representation of this {@link TensorImage}.
200    *
201    * <p>Numeric casting and clamping will be applied if the stored data is not uint8.
202    *
203    * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code
204    * copyPixelsToBuffer}. Bitmap methods such as, `setPixels()` and `getPixels` do not work.
205    *
206    * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
207    * concern, but if modification is necessary, please make a copy.
208    *
209    * @return a reference to a {@link Bitmap} in {@code ARGB_8888} config ("A" channel is always
210    *     opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of this {@link
211    *     TensorBuffer}.
212    * @throws IllegalStateException if the {@link TensorImage} never loads data
213    */
getBitmap()214   public Bitmap getBitmap() {
215     if (container == null) {
216       throw new IllegalStateException("No image has been loaded yet.");
217     }
218 
219     return container.getBitmap();
220   }
221 
222   /**
223    * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected data
224    * type.
225    *
226    * <p>Numeric casting and clamping will be applied if the stored data is different from the data
227    * type of the {@link TensorImage}.
228    *
229    * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
230    * concern, but if modification is necessary, please make a copy.
231    *
232    * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}.
233    *
234    * @return a reference to a {@link ByteBuffer} which holds the image data
235    * @throws IllegalStateException if the {@link TensorImage} never loads data
236    */
getBuffer()237   public ByteBuffer getBuffer() {
238     return getTensorBuffer().getBuffer();
239   }
240 
241   /**
242    * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected
243    * data type.
244    *
245    * <p>Numeric casting and clamping will be applied if the stored data is different from the data
246    * type of the {@link TensorImage}.
247    *
248    * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
249    * concern, but if modification is necessary, please make a copy.
250    *
251    * @return a reference to a {@link TensorBuffer} which holds the image data
252    * @throws IllegalStateException if the {@link TensorImage} never loads data
253    */
getTensorBuffer()254   public TensorBuffer getTensorBuffer() {
255     if (container == null) {
256       throw new IllegalStateException("No image has been loaded yet.");
257     }
258 
259     return container.getTensorBuffer(dataType);
260   }
261 
262   /**
263    * Gets the data type of this {@link TensorImage}.
264    *
265    * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are
266    *     supported.
267    */
getDataType()268   public DataType getDataType() {
269     return dataType;
270   }
271 
272   /**
273    * Gets the color space type of this {@link TensorImage}.
274    *
275    * @throws IllegalStateException if the {@link TensorImage} never loads data
276    */
getColorSpaceType()277   public ColorSpaceType getColorSpaceType() {
278     if (container == null) {
279       throw new IllegalStateException("No image has been loaded yet.");
280     }
281 
282     return container.getColorSpaceType();
283   }
284 
285   /**
286    * Gets the image width.
287    *
288    * @throws IllegalStateException if the {@link TensorImage} never loads data
289    * @throws IllegalArgumentException if the underlying data is corrupted
290    */
getWidth()291   public int getWidth() {
292     if (container == null) {
293       throw new IllegalStateException("No image has been loaded yet.");
294     }
295 
296     return container.getWidth();
297   }
298 
299   /**
300    * Gets the image height.
301    *
302    * @throws IllegalStateException if the {@link TensorImage} never loads data
303    * @throws IllegalArgumentException if the underlying data is corrupted
304    */
getHeight()305   public int getHeight() {
306     if (container == null) {
307       throw new IllegalStateException("No image has been loaded yet.");
308     }
309 
310     return container.getHeight();
311   }
312 }
313