// Copyright 2019 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlowLiteCMetal

/// A delegate that uses the `Metal` framework for performing TensorFlow Lite graph operations with
/// GPU acceleration.
///
/// - Important: This is an experimental interface that is subject to change.
public final class MetalDelegate: Delegate {
  /// The configuration options for the `MetalDelegate`.
  public let options: Options

  // Conformance to the `Delegate` protocol.
  public private(set) var cDelegate: CDelegate

  /// Creates a new instance configured with the given `options`.
  ///
  /// - Parameters:
  ///   - options: Configurations for the delegate. The default is a new instance of
  ///       `MetalDelegate.Options` with the default configuration values.
  public init(options: Options = Options()) {
    self.options = options
    // Translate the Swift-side options into the C options struct that `TFLGpuDelegateCreate`
    // consumes; only the fields exposed by `Options` are set here.
    var delegateOptions = TFLGpuDelegateOptions()
    delegateOptions.allow_precision_loss = options.isPrecisionLossAllowed
    delegateOptions.wait_type = options.waitType.cWaitType
    delegateOptions.enable_quantization = options.isQuantizationEnabled
    cDelegate = TFLGpuDelegateCreate(&delegateOptions)
  }

  /// Releases the C delegate that was created in `init(options:)`.
  deinit {
    TFLGpuDelegateDelete(cDelegate)
  }
}

extension MetalDelegate {
  /// Options for configuring the `MetalDelegate`.
  public struct Options: Equatable, Hashable {
    /// Indicates whether the GPU delegate allows precision loss, such as allowing `Float16`
    /// precision for a `Float32` computation. The default is `false`.
    public var isPrecisionLossAllowed = false

    // No version number here: versions are not valid for the non-specific `*` platform (the
    // compiler warns and ignores them), so the deprecation is declared unconditionally.
    /// Deprecated alias for `isPrecisionLossAllowed`; reads and writes forward to that property.
    @available(*, deprecated, renamed: "isPrecisionLossAllowed")
    public var allowsPrecisionLoss: Bool {
      get { return isPrecisionLossAllowed }
      set { isPrecisionLossAllowed = newValue }
    }

    /// A type indicating how the current thread should wait for work on the GPU to complete. The
    /// default is `passive`.
    public var waitType: ThreadWaitType = .passive

    /// Indicates whether the GPU delegate allows execution of an 8-bit quantized model. The default
    /// is `true`.
    public var isQuantizationEnabled = true

    /// Creates a new instance with the default values.
    public init() {}
  }
}

/// A type indicating how the current thread should wait for work scheduled on the GPU to complete.
public enum ThreadWaitType: Equatable, Hashable {
  /// The thread does not wait for the work to complete. Useful when the output of the work is used
  /// with the GPU pipeline.
  ///
  /// - Note: Where an optional `ThreadWaitType` is involved, write `ThreadWaitType.none` explicitly;
  ///   a bare `.none` there refers to `Optional.none`.
  case none
  /// The thread waits until the work is complete.
  case passive
  /// The thread waits for the work to complete with minimal latency, which may require additional
  /// CPU resources.
  case active
  /// The thread waits for the work while trying to prevent the GPU from going into sleep mode.
  case aggressive

  /// The C `TFLGpuDelegateWaitType` for the current `ThreadWaitType`.
  var cWaitType: TFLGpuDelegateWaitType {
    switch self {
    case .none:
      return TFLGpuDelegateWaitTypeDoNotWait
    case .passive:
      return TFLGpuDelegateWaitTypePassive
    case .active:
      return TFLGpuDelegateWaitTypeActive
    case .aggressive:
      return TFLGpuDelegateWaitTypeAggressive
    }
  }
}