# TensorFlow Lite on GPU

[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) supports several
hardware accelerators. This document describes how to use the GPU backend using
the TensorFlow Lite delegate APIs on Android (requires OpenCL or OpenGL ES 3.1
or higher) and iOS (requires iOS 8 or later).

## Benefits of GPU acceleration

### Speed

GPUs are designed to have high throughput for massively parallelizable
workloads. Thus, they are well-suited for deep neural nets, which consist of a
huge number of operators, each working on some input tensor(s) that can be
easily divided into smaller workloads and carried out in parallel. This
parallelism typically results in lower latency. In the best scenario, inference
on the GPU may run fast enough to become suitable for real-time applications
that were not previously possible.

### Accuracy

GPUs do their computation with 16-bit or 32-bit floating point numbers and
(unlike CPUs) do not require quantization for optimal performance. If decreased
accuracy made quantization untenable for your models, running your neural
network on a GPU may eliminate this concern.

### Energy efficiency

Another benefit of GPU inference is its power efficiency. A GPU carries out
computations in a very efficient and optimized way, consuming less power and
generating less heat than the same task run on a CPU.

## Supported ops

TensorFlow Lite on GPU supports the following ops in 16-bit and 32-bit float
precision:

*   `ADD`
*   `AVERAGE_POOL_2D`
*   `CONCATENATION`
*   `CONV_2D`
*   `DEPTHWISE_CONV_2D v1-2`
*   `EXP`
*   `FULLY_CONNECTED`
*   `LOGISTIC`
*   `LSTM v2 (Basic LSTM only)`
*   `MAX_POOL_2D`
*   `MAXIMUM`
*   `MINIMUM`
*   `MUL`
*   `PAD`
*   `PRELU`
*   `RELU`
*   `RELU6`
*   `RESHAPE`
*   `RESIZE_BILINEAR v1-3`
*   `SOFTMAX`
*   `STRIDED_SLICE`
*   `SUB`
*   `TRANSPOSE_CONV`

By default, all ops are only supported at version 1. Enabling the
[experimental quantization support](gpu_advanced.md#running-quantized-models-experimental-android-only)
allows the appropriate versions; for example, ADD v2.

## Basic usage

There are two ways to invoke model acceleration on Android, depending on
whether you are using
[Android Studio ML Model Binding](../inference_with_metadata/codegen#acceleration)
or the TensorFlow Lite Interpreter.

### Android via TensorFlow Lite Interpreter

Add the `tensorflow-lite-gpu` package alongside the existing `tensorflow-lite`
package in the existing `dependencies` block.

```
dependencies {
    ...
    implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.3.0'
}
```

Then run TensorFlow Lite on GPU with `TfLiteDelegate`. In Java, you can specify
the `GpuDelegate` through `Interpreter.Options`.
<div>
  <devsite-selector>
    <section>
      <h3>Kotlin</h3>
      <p><pre class="prettyprint lang-kotlin">
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate

val compatList = CompatibilityList()

val options = Interpreter.Options().apply{
    if(compatList.isDelegateSupportedOnThisDevice){
        // if the device has a supported GPU, add the GPU delegate
        val delegateOptions = compatList.bestOptionsForThisDevice
        this.addDelegate(GpuDelegate(delegateOptions))
    } else {
        // if the GPU is not supported, run on 4 threads
        this.setNumThreads(4)
    }
}

val interpreter = Interpreter(model, options)

// Run inference
writeToInput(input)
interpreter.run(input, output)
readFromOutput(output)
      </pre></p>
    </section>
    <section>
      <h3>Java</h3>
      <p><pre class="prettyprint lang-java">
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.CompatibilityList;
import org.tensorflow.lite.gpu.GpuDelegate;

// Initialize interpreter with GPU delegate
Interpreter.Options options = new Interpreter.Options();
CompatibilityList compatList = new CompatibilityList();

if(compatList.isDelegateSupportedOnThisDevice()){
    // if the device has a supported GPU, add the GPU delegate
    GpuDelegate.Options delegateOptions = compatList.getBestOptionsForThisDevice();
    GpuDelegate gpuDelegate = new GpuDelegate(delegateOptions);
    options.addDelegate(gpuDelegate);
} else {
    // if the GPU is not supported, run on 4 threads
    options.setNumThreads(4);
}

Interpreter interpreter = new Interpreter(model, options);

// Run inference
writeToInput(input);
interpreter.run(input, output);
readFromOutput(output);
      </pre></p>
    </section>
  </devsite-selector>
</div>

### Android (C/C++)

For C/C++ usage of TensorFlow Lite GPU on Android, the GPU delegate can be
created with `TfLiteGpuDelegateV2Create()` and destroyed with
`TfLiteGpuDelegateV2Delete()`.

```c++
// Set up interpreter.
auto model = FlatBufferModel::BuildFromFile(model_path);
if (!model) return false;
ops::builtin::BuiltinOpResolver op_resolver;
std::unique_ptr<Interpreter> interpreter;
InterpreterBuilder(*model, op_resolver)(&interpreter);

// NEW: Prepare GPU delegate.
auto* delegate = TfLiteGpuDelegateV2Create(/*default options=*/nullptr);
if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false;

// Run inference.
WriteToInputTensor(interpreter->typed_input_tensor<float>(0));
if (interpreter->Invoke() != kTfLiteOk) return false;
ReadFromOutputTensor(interpreter->typed_output_tensor<float>(0));

// NEW: Clean up.
TfLiteGpuDelegateV2Delete(delegate);
```

Take a look at `TfLiteGpuDelegateOptionsV2` to create a delegate instance with
custom options. You can initialize the default options with
`TfLiteGpuDelegateOptionsV2Default()` and then modify them as necessary.
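As a minimal sketch of that pattern (the option fields and enum values shown
here come from `tensorflow/lite/delegates/gpu/delegate.h` and may differ
between releases, so treat them as an illustration rather than a reference):

```c++
// Start from the defaults, adjust a few fields, then create the delegate.
TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
options.inference_preference = TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED;
options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;

auto* delegate = TfLiteGpuDelegateV2Create(&options);
if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false;

// ... run inference as shown above ...

// Clean up when the interpreter no longer needs the delegate.
TfLiteGpuDelegateV2Delete(delegate);
```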
TFLite GPU for Android C/C++ uses the [Bazel](https://bazel.io) build system.
The delegate can be built, for example, using one of the following commands:

```sh
bazel build -c opt --config android_arm64 tensorflow/lite/delegates/gpu:delegate                           # for static library
bazel build -c opt --config android_arm64 tensorflow/lite/delegates/gpu:libtensorflowlite_gpu_delegate.so  # for dynamic library
```

Note: When calling `Interpreter::ModifyGraphWithDelegate()` or
`Interpreter::Invoke()`, the caller must have an `EGLContext` in the current
thread, and `Interpreter::Invoke()` must be called from the same `EGLContext`.
If an `EGLContext` does not exist, the delegate will create one internally, but
then the developer must ensure that `Interpreter::Invoke()` is always called
from the same thread in which `Interpreter::ModifyGraphWithDelegate()` was
called.

### iOS (C++)

Note: For Swift/Objective-C/C use cases, please refer to the
[GPU delegate guide](gpu#ios).

Note: This is only available when you are using Bazel or build TensorFlow Lite
yourself. The C++ API cannot be used with CocoaPods.

To use TensorFlow Lite on GPU, get the GPU delegate via `TFLGpuDelegateCreate()`
and then pass it to `Interpreter::ModifyGraphWithDelegate()` (instead of calling
`Interpreter::AllocateTensors()`).

```c++
// Set up interpreter.
auto model = FlatBufferModel::BuildFromFile(model_path);
if (!model) return false;
tflite::ops::builtin::BuiltinOpResolver op_resolver;
std::unique_ptr<Interpreter> interpreter;
InterpreterBuilder(*model, op_resolver)(&interpreter);

// NEW: Prepare GPU delegate.
auto* delegate = TFLGpuDelegateCreate(/*default options=*/nullptr);
if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false;

// Run inference.
WriteToInputTensor(interpreter->typed_input_tensor<float>(0));
if (interpreter->Invoke() != kTfLiteOk) return false;
ReadFromOutputTensor(interpreter->typed_output_tensor<float>(0));

// Clean up.
TFLGpuDelegateDelete(delegate);
```

## Advanced usage

### Delegate Options for iOS

The constructor for the GPU delegate accepts a `struct` of options
([Swift API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/swift/Sources/MetalDelegate.swift),
[Objective-C API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/objc/apis/TFLMetalDelegate.h),
[C API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/gpu/metal_delegate.h)).

Passing `nullptr` (C API) or nothing (Objective-C and Swift API) to the
initializer sets the default options (which are explicated in the Basic Usage
example above).
<div>
  <devsite-selector>
    <section>
      <h3>Swift</h3>
      <p><pre class="prettyprint lang-swift">
// THIS:
var options = MetalDelegate.Options()
options.isPrecisionLossAllowed = false
options.waitType = .passive
options.isQuantizationEnabled = true
let delegate = MetalDelegate(options: options)

// IS THE SAME AS THIS:
let delegate = MetalDelegate()
      </pre></p>
    </section>
    <section>
      <h3>Objective-C</h3>
      <p><pre class="prettyprint lang-objc">
// THIS:
TFLMetalDelegateOptions* options = [[TFLMetalDelegateOptions alloc] init];
options.precisionLossAllowed = false;
options.waitType = TFLMetalDelegateThreadWaitTypePassive;
options.quantizationEnabled = true;

TFLMetalDelegate* delegate = [[TFLMetalDelegate alloc] initWithOptions:options];

// IS THE SAME AS THIS:
TFLMetalDelegate* delegate = [[TFLMetalDelegate alloc] init];
      </pre></p>
    </section>
    <section>
      <h3>C</h3>
      <p><pre class="prettyprint lang-c">
// THIS:
const TFLGpuDelegateOptions options = {
  .allow_precision_loss = false,
  .wait_type = TFLGpuDelegateWaitTypePassive,
  .enable_quantization = true,
};

TfLiteDelegate* delegate = TFLGpuDelegateCreate(&options);

// IS THE SAME AS THIS:
TfLiteDelegate* delegate = TFLGpuDelegateCreate(nullptr);
      </pre></p>
    </section>
  </devsite-selector>
</div>

While it is convenient to use `nullptr` or default constructors, we recommend
that you explicitly set the options, to avoid any unexpected behavior if
default values are changed in the future.

### Running quantized models on GPU

This section explains how the GPU delegate accelerates 8-bit quantized models.
This includes all flavors of quantization, including:

*   Models trained with
    [Quantization-aware training](https://www.tensorflow.org/lite/convert/quantization)
*   [Post-training dynamic-range quantization](https://www.tensorflow.org/lite/performance/post_training_quant)
*   [Post-training full-integer quantization](https://www.tensorflow.org/lite/performance/post_training_integer_quant)

To optimize performance, use models that have floating-point input and output
tensors.

#### How does this work?

Since the GPU backend only supports floating-point execution, we run quantized
models by giving it a ‘floating-point view’ of the original model. At a high
level, this entails the following steps:

*   *Constant tensors* (such as weights/biases) are dequantized once into the
    GPU memory. This happens when the delegate is applied to the TFLite
    Interpreter.

*   *Inputs and outputs* to the GPU program, if 8-bit quantized, are
    dequantized and quantized (respectively) for each inference. This is done
    on the CPU using TFLite’s optimized kernels (see the sketch after this
    list).

*   The GPU program is modified to mimic quantized behavior by inserting
    *quantization simulators* between operations. This is necessary for models
    whose ops expect activations to follow the bounds learnt during
    quantization.
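For reference, TFLite uses an affine quantization scheme, so the per-value
mapping applied at those dequantize/quantize steps is simply
`real = scale * (quantized - zero_point)` and its inverse. A minimal,
illustrative sketch of that mapping (not the delegate's actual implementation,
which uses optimized, vectorized kernels):

```c++
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative only: maps an 8-bit quantized value to the float value that the
// GPU program actually computes with.
float Dequantize(int8_t q, float scale, int32_t zero_point) {
  return scale * static_cast<float>(static_cast<int32_t>(q) - zero_point);
}

// The inverse mapping, applied to the GPU program's float outputs before they
// are handed back as 8-bit quantized tensors.
int8_t Quantize(float value, float scale, int32_t zero_point) {
  int32_t q = static_cast<int32_t>(std::round(value / scale)) + zero_point;
  return static_cast<int8_t>(std::min(127, std::max(-128, q)));
}
```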
This feature can be enabled using delegate options as follows:

#### Android

Android APIs support quantized models by default. To disable, do the following:

**C++ API**

```c++
TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE;

auto* delegate = TfLiteGpuDelegateV2Create(&options);
if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false;
```

**Java API**

```java
GpuDelegate delegate = new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(false));

Interpreter.Options options = (new Interpreter.Options()).addDelegate(delegate);
```

#### iOS

iOS APIs support quantized models by default. To disable, do the following:

<div>
  <devsite-selector>
    <section>
      <h3>Swift</h3>
      <p><pre class="prettyprint lang-swift">
var options = MetalDelegate.Options()
options.isQuantizationEnabled = false
let delegate = MetalDelegate(options: options)
      </pre></p>
    </section>
    <section>
      <h3>Objective-C</h3>
      <p><pre class="prettyprint lang-objc">
TFLMetalDelegateOptions* options = [[TFLMetalDelegateOptions alloc] init];
options.quantizationEnabled = false;

TFLMetalDelegate* delegate = [[TFLMetalDelegate alloc] initWithOptions:options];
      </pre></p>
    </section>
    <section>
      <h3>C</h3>
      <p><pre class="prettyprint lang-c">
TFLGpuDelegateOptions options = TFLGpuDelegateOptionsDefault();
options.enable_quantization = false;

TfLiteDelegate* delegate = TFLGpuDelegateCreate(&options);
      </pre></p>
    </section>
  </devsite-selector>
</div>

### Input/Output Buffers (iOS, C++ API only)

Note: This is only available when you are using Bazel or build TensorFlow Lite
yourself. The C++ API cannot be used with CocoaPods.

To do computation on the GPU, data must be made available to the GPU. This
often requires performing a memory copy. It is desirable not to cross the
CPU/GPU memory boundary if possible, as this can take up a significant amount
of time. Usually, such crossing is inevitable, but in some special cases, one
or the other can be omitted.

If the network's input is an image already loaded in the GPU memory (for
example, a GPU texture containing the camera feed), it can stay in the GPU
memory without ever entering the CPU memory. Similarly, if the network's output
is in the form of a renderable image (for example,
[image style transfer](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf)),
it can be directly displayed on the screen.

To achieve the best performance, TensorFlow Lite makes it possible for users to
directly read from and write to the hardware buffer and bypass avoidable memory
copies.

Assuming the image input is in GPU memory, it must first be converted to a
`MTLBuffer` object for Metal. You can associate a `TfLiteTensor` with a
user-prepared `MTLBuffer` using `TFLGpuDelegateBindMetalBufferToTensor()`. Note
that `TFLGpuDelegateBindMetalBufferToTensor()` must be called after
`Interpreter::ModifyGraphWithDelegate()`. Additionally, the inference output
is, by default, copied from GPU memory to CPU memory. This behavior can be
turned off by calling `Interpreter::SetAllowBufferHandleOutput(true)` during
initialization.

```c++
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"

// ...

// Prepare GPU delegate.
auto* delegate = TFLGpuDelegateCreate(nullptr);

if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false;

interpreter->SetAllowBufferHandleOutput(true);  // disable default gpu->cpu copy
if (!TFLGpuDelegateBindMetalBufferToTensor(
        delegate, interpreter->inputs()[0], user_provided_input_buffer)) {
  return false;
}
if (!TFLGpuDelegateBindMetalBufferToTensor(
        delegate, interpreter->outputs()[0], user_provided_output_buffer)) {
  return false;
}

// Run inference.
if (interpreter->Invoke() != kTfLiteOk) return false;
```

Note: Once the default behavior is turned off, copying the inference output
from GPU memory to CPU memory requires an explicit call to
`Interpreter::EnsureTensorDataIsReadable()` for each output tensor.

Note: This also works for quantized models, but you still need to use a
**float32-sized buffer with float32 data**, because the buffer will be bound to
the internal dequantized buffer.

## Tips and Tricks

*   Some operations that are trivial on the CPU may be costly on a GPU. One
    class of such operations includes various forms of reshape operations
    (including `BATCH_TO_SPACE`, `SPACE_TO_BATCH`, `SPACE_TO_DEPTH`, and
    similar operations). If these operations are not required (for example,
    they were inserted to help the network architect reason about the system
    but do not otherwise affect output), it is worth removing them for
    performance.

*   On a GPU, tensor data is sliced into 4-channel slices. Thus, a computation
    on a tensor of shape `[B, H, W, 5]` will perform about the same as on a
    tensor of shape `[B, H, W, 8]`, but significantly worse than on
    `[B, H, W, 4]`.

    *   For example, if the camera hardware supports image frames in RGBA,
        feeding that 4-channel input is significantly faster, because a memory
        copy (from 3-channel RGB to 4-channel RGBX) can be avoided (see the
        sketch after this list).

*   For best performance, do not hesitate to re-train your classifier with a
    mobile-optimized network architecture. That is a significant part of
    optimization for on-device inference.
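As an illustration of the RGBA point above, the sketch below shows the extra
per-frame copy that padding a packed RGB image to 4-channel RGBX entails;
feeding 4-channel RGBA directly from the camera skips this step entirely (an
illustrative helper, not part of the TensorFlow Lite API):

```c++
#include <cstdint>
#include <vector>

// Illustrative only: pads a packed RGB image (H x W x 3) to RGBX (H x W x 4).
// This is the per-frame copy that feeding 4-channel RGBA input avoids.
std::vector<uint8_t> RgbToRgbx(const uint8_t* rgb, int height, int width) {
  std::vector<uint8_t> rgbx(static_cast<size_t>(height) * width * 4);
  for (int i = 0; i < height * width; ++i) {
    rgbx[i * 4 + 0] = rgb[i * 3 + 0];
    rgbx[i * 4 + 1] = rgb[i * 3 + 1];
    rgbx[i * 4 + 2] = rgb[i * 3 + 2];
    rgbx[i * 4 + 3] = 0;  // padding channel, ignored by the network
  }
  return rgbx;
}
```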