1// Copyright 2020 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at:
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#import "tensorflow/lite/objc/apis/TFLMetalDelegate.h"
16
17#ifdef COCOAPODS
18@import TensorFlowLiteCMetal;
19#else
20#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
21#endif
22
23NS_ASSUME_NONNULL_BEGIN
24
@implementation TFLMetalDelegateOptions

#pragma mark - Public

/// Initializes options with the Metal delegate defaults: quantization enabled and a
/// passive thread wait type. Properties not assigned here (such as
/// `precisionLossAllowed`) keep the zero value that +alloc gives every ivar.
- (instancetype)init {
  self = [super init];
  if (self != nil) {
    // Objective-C BOOL literal instead of the C `true` macro.
    _quantizationEnabled = YES;
    _waitType = TFLMetalDelegateThreadWaitTypePassive;
  }
  return self;
}

@end
39
@implementation TFLMetalDelegate

@synthesize cDelegate = _cDelegate;

#pragma mark - NSObject

/// Releases the underlying C Metal delegate, if one was successfully created.
- (void)dealloc {
  // Access the ivar directly: accessors should not be invoked from -dealloc.
  // The NULL check covers the path where -initWithOptions: returned nil because
  // TFLGpuDelegateCreate() produced NULL, so there is nothing to delete.
  if (_cDelegate != NULL) {
    TFLGpuDelegateDelete(_cDelegate);
  }
}

#pragma mark - Public

/// Creates a Metal delegate configured with a freshly initialized
/// `TFLMetalDelegateOptions`.
///
/// @return A new delegate, or nil if the underlying C delegate could not be created.
- (nullable instancetype)init {
  TFLMetalDelegateOptions* options = [[TFLMetalDelegateOptions alloc] init];
  return [self initWithOptions:options];
}

/// Creates a Metal delegate configured from `options`.
///
/// @param options Delegate configuration. Its values are copied into a C
///     `TFLGpuDelegateOptions` struct; the options object itself is not stored.
/// @return A new delegate, or nil if `TFLGpuDelegateCreate()` returned NULL.
- (nullable instancetype)initWithOptions:(TFLMetalDelegateOptions*)options {
  self = [super init];
  if (self != nil) {
    // Designated initializer: every field not named here is zero-initialized,
    // so no part of the C struct is ever handed to the C API uninitialized.
    // `wait_type` starts at Passive (the TFLMetalDelegateOptions default) and is
    // the fallback if `waitType` holds a value outside the enum; no `default:`
    // label is used so the compiler still flags newly added, unhandled cases.
    TFLGpuDelegateOptions cOptions = {
        .allow_precision_loss = options.precisionLossAllowed,
        .wait_type = TFLGpuDelegateWaitTypePassive,
        .enable_quantization = options.quantizationEnabled,
    };
    switch (options.waitType) {
      case TFLMetalDelegateThreadWaitTypeDoNotWait:
        cOptions.wait_type = TFLGpuDelegateWaitTypeDoNotWait;
        break;
      case TFLMetalDelegateThreadWaitTypePassive:
        cOptions.wait_type = TFLGpuDelegateWaitTypePassive;
        break;
      case TFLMetalDelegateThreadWaitTypeActive:
        cOptions.wait_type = TFLGpuDelegateWaitTypeActive;
        break;
      case TFLMetalDelegateThreadWaitTypeAggressive:
        cOptions.wait_type = TFLGpuDelegateWaitTypeAggressive;
        break;
    }
    _cDelegate = TFLGpuDelegateCreate(&cOptions);
    // `_cDelegate` is a C pointer, so compare against NULL rather than nil.
    if (_cDelegate == NULL) {
      return nil;
    }
  }
  return self;
}

@end
86
87NS_ASSUME_NONNULL_END
88