1 // Copyright 2019 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 import TensorFlowLiteCMetal
16 
/// A TensorFlow Lite delegate that performs graph operations with GPU acceleration through the
/// `Metal` framework.
///
/// - Important: This is an experimental interface that is subject to change.
public final class MetalDelegate: Delegate {
  /// The configuration options this delegate was created with.
  public let options: Options

  // Required by the `Delegate` protocol.
  public private(set) var cDelegate: CDelegate

  /// Creates a new delegate, forwarding the given `options` to the underlying C GPU delegate.
  ///
  /// - Parameters:
  ///   - options: Configurations for the delegate. When omitted, a new `MetalDelegate.Options`
  ///       instance with its default configuration values is used.
  public init(options: Options = Options()) {
    self.options = options

    // Mirror each Swift option onto its field in the C options struct before creating the delegate.
    var cOptions = TFLGpuDelegateOptions()
    cOptions.enable_quantization = options.isQuantizationEnabled
    cOptions.wait_type = options.waitType.cWaitType
    cOptions.allow_precision_loss = options.isPrecisionLossAllowed
    cDelegate = TFLGpuDelegateCreate(&cOptions)
  }

  deinit {
    // Pass the C delegate created in `init` to its matching delete function.
    TFLGpuDelegateDelete(cDelegate)
  }
}
46 
extension MetalDelegate {
  /// Options for configuring the `MetalDelegate`.
  public struct Options: Equatable, Hashable {
    /// Indicates whether the GPU delegate allows precision loss, such as allowing `Float16`
    /// precision for a `Float32` computation. The default is `false`.
    public var isPrecisionLossAllowed = false

    /// Deprecated accessor that reads and writes `isPrecisionLossAllowed`.
    ///
    /// The wildcard platform `*` cannot carry a version number in `@available`, so the release in
    /// which this property was deprecated is recorded in the message instead.
    @available(
      *, deprecated, message: "Deprecated since TensorFlow Lite 2.4",
      renamed: "isPrecisionLossAllowed"
    )
    public var allowsPrecisionLoss: Bool {
      get { return isPrecisionLossAllowed }
      set(value) { isPrecisionLossAllowed = value }
    }

    /// A type indicating how the current thread should wait for work on the GPU to complete. The
    /// default is `passive`.
    public var waitType: ThreadWaitType = .passive

    /// Indicates whether the GPU delegate allows execution of an 8-bit quantized model. The default
    /// is `true`.
    public var isQuantizationEnabled = true

    /// Creates a new instance with the default values.
    public init() {}
  }
}
72 
/// Describes how the current thread waits for work scheduled on the GPU to complete.
public enum ThreadWaitType: Equatable, Hashable {
  /// The thread does not wait for the work to complete. Useful when the output of the work is used
  /// with the GPU pipeline.
  case none
  /// The thread waits until the work is complete.
  case passive
  /// The thread waits for the work to complete with minimal latency, which may require additional
  /// CPU resources.
  case active
  /// The thread waits for the work while trying to prevent the GPU from going into sleep mode.
  case aggressive

  /// The C `TFLGpuDelegateWaitType` constant that corresponds to this `ThreadWaitType`.
  var cWaitType: TFLGpuDelegateWaitType {
    // The switch has no `default`, so adding a new case forces this mapping to be updated.
    let cType: TFLGpuDelegateWaitType
    switch self {
    case .none: cType = TFLGpuDelegateWaitTypeDoNotWait
    case .passive: cType = TFLGpuDelegateWaitTypePassive
    case .active: cType = TFLGpuDelegateWaitTypeActive
    case .aggressive: cType = TFLGpuDelegateWaitTypeAggressive
    }
    return cType
  }
}
100