1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite.support.image; 17 18 import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; 19 20 import android.graphics.Bitmap; 21 import java.nio.ByteBuffer; 22 import org.tensorflow.lite.DataType; 23 import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; 24 25 /** 26 * TensorImage is the wrapper class for Image object. When using image processing utils in 27 * TFLite.support library, it's common to convert image objects in variant types to TensorImage at 28 * first. 29 * 30 * <p>At present, only RGB images are supported, and the A channel is always ignored. 31 * 32 * <p>Details of data storage: a {@link TensorImage} object may have 2 potential sources of truth: a 33 * {@link Bitmap} or a {@link TensorBuffer}. {@link TensorImage} maintains the state and only 34 * converts one to the other when needed. A typical use case of {@link TensorImage} is to first load 35 * a {@link Bitmap} image, then process it using {@link ImageProcessor}, and finally get the 36 * underlying {@link ByteBuffer} of the {@link TensorBuffer} and feed it into the TFLite 37 * interpreter. 38 * 39 * <p>IMPORTANT: to achieve the best performance, {@link TensorImage} avoids copying data whenever 40 * it's possible. Therefore, it doesn't own its data. Callers should not modify data objects those 41 * are passed to {@link TensorImage#load(Bitmap)} or {@link TensorImage#load(TensorBuffer)}. 42 * 43 * <p>IMPORTANT: all methods are not proved thread-safe. 44 * 45 * @see ImageProcessor which is often used for transforming a {@link TensorImage}. 46 */ 47 // TODO(b/138907116): Support loading images from TensorBuffer with properties. 48 // TODO(b/138905544): Support directly loading RGBBytes, YUVBytes and other types if necessary. 49 public class TensorImage { 50 51 private final DataType dataType; 52 private ImageContainer container = null; 53 54 /** 55 * Initializes a {@link TensorImage} object. 56 * 57 * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link 58 * #TensorImage(DataType)} if other data types are preferred. 59 */ TensorImage()60 public TensorImage() { 61 this(DataType.UINT8); 62 } 63 64 /** 65 * Initializes a {@link TensorImage} object with the specified data type. 66 * 67 * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage}, 68 * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be 69 * converted to the specified data type. 70 * 71 * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of 72 * the image being loaded to this {@link TensorImage}. 73 * 74 * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is 75 * always fixed during the lifetime of the {@link TensorImage}. To convert the data type, use 76 * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the 77 * same time. 78 * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor 79 * {@link DataType#FLOAT32} 80 */ TensorImage(DataType dataType)81 public TensorImage(DataType dataType) { 82 checkArgument( 83 dataType == DataType.UINT8 || dataType == DataType.FLOAT32, 84 "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted"); 85 this.dataType = dataType; 86 } 87 88 /** 89 * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link Bitmap} . 90 * 91 * @see TensorImage#load(Bitmap) for reusing the object when it's expensive to create objects 92 * frequently, because every call of {@code fromBitmap} creates a new {@link TensorImage}. 93 */ fromBitmap(Bitmap bitmap)94 public static TensorImage fromBitmap(Bitmap bitmap) { 95 TensorImage image = new TensorImage(); 96 image.load(bitmap); 97 return image; 98 } 99 100 /** 101 * Creates a deep-copy of a given {@link TensorImage} with the desired data type. 102 * 103 * @param src the {@link TensorImage} to copy from 104 * @param dataType the expected data type of newly created {@link TensorImage} 105 * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code 106 * dataType} 107 */ createFrom(TensorImage src, DataType dataType)108 public static TensorImage createFrom(TensorImage src, DataType dataType) { 109 TensorImage dst = new TensorImage(dataType); 110 dst.container = src.container.clone(); 111 return dst; 112 } 113 114 /** 115 * Loads a {@link Bitmap} image object into this {@link TensorImage}. 116 * 117 * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric 118 * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link 119 * #getBuffer}, where the {@link Bitmap} will be converted into a {@link TensorBuffer}. 120 * 121 * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. The 122 * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as well. 123 * In this method, we perform a zero-copy approach for that bitmap, by simply holding its 124 * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary. 125 * 126 * <p>Note: to get the best performance, please load images in the same shape to avoid memory 127 * re-allocation. 128 * 129 * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888 130 */ load(Bitmap bitmap)131 public void load(Bitmap bitmap) { 132 container = BitmapContainer.create(bitmap); 133 } 134 135 /** 136 * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels 137 * inside. 138 * 139 * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32}, 140 * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link 141 * #getBuffer}. 142 * 143 * @param pixels the RGB pixels representing the image 144 * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3) 145 * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) 146 */ load(float[] pixels, int[] shape)147 public void load(float[] pixels, int[] shape) { 148 TensorBuffer buffer = TensorBuffer.createDynamic(getDataType()); 149 buffer.loadArray(pixels, shape); 150 load(buffer); 151 } 152 153 /** 154 * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels inside. 155 * 156 * <p>Note: numeric casting and clamping will be applied to convert the values into the data type 157 * of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link #getBuffer}. 158 * 159 * @param pixels the RGB pixels representing the image 160 * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3) 161 * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) 162 */ load(int[] pixels, int[] shape)163 public void load(int[] pixels, int[] shape) { 164 TensorBuffer buffer = TensorBuffer.createDynamic(getDataType()); 165 buffer.loadArray(pixels, shape); 166 load(buffer); 167 } 168 169 /** 170 * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB. 171 * 172 * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, 173 * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link 174 * #getBuffer}. 175 * 176 * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or 177 * (1, h, w, 3) 178 * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) 179 */ load(TensorBuffer buffer)180 public void load(TensorBuffer buffer) { 181 load(buffer, ColorSpaceType.RGB); 182 } 183 184 /** 185 * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ColorSapceType}. 186 * 187 * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, 188 * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link 189 * #getBuffer}. 190 * 191 * @throws IllegalArgumentException if the shape of buffer does not match the color space type 192 * @see ColorSpaceType#assertShape 193 */ load(TensorBuffer buffer, ColorSpaceType colorSpaceType)194 public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) { 195 container = TensorBufferContainer.create(buffer, colorSpaceType); 196 } 197 198 /** 199 * Returns a {@link Bitmap} representation of this {@link TensorImage}. 200 * 201 * <p>Numeric casting and clamping will be applied if the stored data is not uint8. 202 * 203 * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code 204 * copyPixelsToBuffer}. Bitmap methods such as, `setPixels()` and `getPixels` do not work. 205 * 206 * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance 207 * concern, but if modification is necessary, please make a copy. 208 * 209 * @return a reference to a {@link Bitmap} in {@code ARGB_8888} config ("A" channel is always 210 * opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of this {@link 211 * TensorBuffer}. 212 * @throws IllegalStateException if the {@link TensorImage} never loads data 213 */ getBitmap()214 public Bitmap getBitmap() { 215 if (container == null) { 216 throw new IllegalStateException("No image has been loaded yet."); 217 } 218 219 return container.getBitmap(); 220 } 221 222 /** 223 * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected data 224 * type. 225 * 226 * <p>Numeric casting and clamping will be applied if the stored data is different from the data 227 * type of the {@link TensorImage}. 228 * 229 * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance 230 * concern, but if modification is necessary, please make a copy. 231 * 232 * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}. 233 * 234 * @return a reference to a {@link ByteBuffer} which holds the image data 235 * @throws IllegalStateException if the {@link TensorImage} never loads data 236 */ getBuffer()237 public ByteBuffer getBuffer() { 238 return getTensorBuffer().getBuffer(); 239 } 240 241 /** 242 * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected 243 * data type. 244 * 245 * <p>Numeric casting and clamping will be applied if the stored data is different from the data 246 * type of the {@link TensorImage}. 247 * 248 * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance 249 * concern, but if modification is necessary, please make a copy. 250 * 251 * @return a reference to a {@link TensorBuffer} which holds the image data 252 * @throws IllegalStateException if the {@link TensorImage} never loads data 253 */ getTensorBuffer()254 public TensorBuffer getTensorBuffer() { 255 if (container == null) { 256 throw new IllegalStateException("No image has been loaded yet."); 257 } 258 259 return container.getTensorBuffer(dataType); 260 } 261 262 /** 263 * Gets the data type of this {@link TensorImage}. 264 * 265 * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are 266 * supported. 267 */ getDataType()268 public DataType getDataType() { 269 return dataType; 270 } 271 272 /** 273 * Gets the color space type of this {@link TensorImage}. 274 * 275 * @throws IllegalStateException if the {@link TensorImage} never loads data 276 */ getColorSpaceType()277 public ColorSpaceType getColorSpaceType() { 278 if (container == null) { 279 throw new IllegalStateException("No image has been loaded yet."); 280 } 281 282 return container.getColorSpaceType(); 283 } 284 285 /** 286 * Gets the image width. 287 * 288 * @throws IllegalStateException if the {@link TensorImage} never loads data 289 * @throws IllegalArgumentException if the underlying data is corrupted 290 */ getWidth()291 public int getWidth() { 292 if (container == null) { 293 throw new IllegalStateException("No image has been loaded yet."); 294 } 295 296 return container.getWidth(); 297 } 298 299 /** 300 * Gets the image height. 301 * 302 * @throws IllegalStateException if the {@link TensorImage} never loads data 303 * @throws IllegalArgumentException if the underlying data is corrupted 304 */ getHeight()305 public int getHeight() { 306 if (container == null) { 307 throw new IllegalStateException("No image has been loaded yet."); 308 } 309 310 return container.getHeight(); 311 } 312 } 313