1 // Copyright 2020 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 import TensorFlowLiteCCoreML
16 
/// A delegate that uses the `Core ML` framework for performing TensorFlow Lite graph operations.
///
/// The instance owns the underlying C delegate: it is created in `init?(options:)` and handed to
/// `TfLiteCoreMlDelegateDelete` when the instance is deallocated.
///
/// - Important: This is an experimental interface that is subject to change.
public final class CoreMLDelegate: Delegate {
  /// The configuration options for the `CoreMLDelegate`.
  public let options: Options

  // Conformance to the `Delegate` protocol.
  public private(set) var cDelegate: CDelegate

  /// Creates a new instance configured with the given `options`. Returns `nil` if the underlying
  /// Core ML delegate could not be created, for example because `Options.enabledDevices` was set
  /// to `neuralEngine` but the device does not have the Neural Engine.
  ///
  /// Integer options that do not fit in the 32-bit fields of the C options struct are clamped to
  /// the nearest representable `Int32` value instead of trapping at runtime.
  ///
  /// - Parameters:
  ///   - options: Configurations for the delegate. The default is a new instance of
  ///       `CoreMLDelegate.Options` with the default configuration values.
  public init?(options: Options = Options()) {
    self.options = options
    var delegateOptions = TfLiteCoreMlDelegateOptions()
    delegateOptions.enabled_devices = options.enabledDevices.cEnabledDevices
    // `Int32(_:)` traps when the value is out of range; `Int32(clamping:)` saturates instead, so
    // an oversized (or very negative) option can never crash the caller. In-range values are
    // converted exactly as before.
    delegateOptions.coreml_version = Int32(clamping: options.coreMLVersion)
    delegateOptions.max_delegated_partitions = Int32(clamping: options.maxDelegatedPartitions)
    delegateOptions.min_nodes_per_partition = Int32(clamping: options.minNodesPerPartition)
    // A `nil` result from the C API means no delegate was created, so initialization fails.
    guard let delegate = TfLiteCoreMlDelegateCreate(&delegateOptions) else { return nil }
    cDelegate = delegate
  }

  deinit {
    // Hand the C delegate created in `init?(options:)` back to the C API for deletion.
    TfLiteCoreMlDelegateDelete(cDelegate)
  }
}
49 
extension CoreMLDelegate {
  /// The set of devices on which the Core ML delegate may be enabled.
  public enum EnabledDevices: Equatable, Hashable {
    /// Enables the delegate only on devices that have the Neural Engine.
    case neuralEngine
    /// Enables the delegate on every device.
    case all

    /// The C `TfLiteCoreMlDelegateEnabledDevices` value that corresponds to this case.
    var cEnabledDevices: TfLiteCoreMlDelegateEnabledDevices {
      switch self {
      case .all: return TfLiteCoreMlDelegateAllDevices
      case .neuralEngine: return TfLiteCoreMlDelegateDevicesWithNeuralEngine
      }
    }
  }

  // TODO(b/143931022): Add preferred device support.
  /// Configuration values used when creating a `CoreMLDelegate`.
  public struct Options: Equatable, Hashable {
    /// Which devices the Core ML delegate should be enabled for. Defaults to `.neuralEngine`,
    /// meaning the delegate is enabled only for devices with the Neural Engine.
    public var enabledDevices: EnabledDevices = .neuralEngine

    /// The Core ML version targeted by the model conversion. The default value is `0`, which
    /// leaves the version unset so that the highest version available for the platform is used.
    public var coreMLVersion: Int = 0

    /// The upper bound on the number of Core ML delegate partitions created. Each graph
    /// corresponds to one delegated node subset in the TFLite model. The default value is `0`,
    /// indicating that all possible partitions are delegated.
    public var maxDelegatedPartitions: Int = 0

    /// The smallest number of nodes a partition must contain to be delegated by the Core ML
    /// delegate. The default value is `2`.
    public var minNodesPerPartition: Int = 2

    /// Creates a new instance in which every option has its default value.
    public init() {}
  }
}
91