1// Copyright 2018 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at:
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#import "tensorflow/lite/objc/apis/TFLTensor.h"
16
17#import "TFLErrorUtil.h"
18#import "TFLInterpreter+Internal.h"
19#import "TFLTensor+Internal.h"
20
21#import "tensorflow/lite/objc/apis/TFLInterpreter.h"
22
23NS_ASSUME_NONNULL_BEGIN
24
// Human-readable names for the two tensor roles; these are the values produced by
// +[TFLTensor stringForTensorType:] for TFLTensorTypeInput and TFLTensorTypeOutput respectively.
static NSString * const kTFLInputTensorTypeString = @"input";
static NSString * const kTFLOutputTensorTypeString = @"output";
28
@interface TFLTensor ()

// Read-write redeclarations of the read-only properties, for use inside this implementation.
// Memory semantics are spelled out explicitly rather than relying on the defaults.

/** Whether the tensor is an input or an output of the interpreter. */
@property(nonatomic, readwrite, assign) TFLTensorType type;

/** Index of the tensor; used, for example, as the input tensor index when copying data in. */
@property(nonatomic, readwrite, assign) NSUInteger index;

/** Name of the tensor; copied on assignment. */
@property(nonatomic, readwrite, copy) NSString *name;

/** The tensor's data type. */
@property(nonatomic, readwrite, assign) TFLTensorDataType dataType;

/** Quantization parameters for the tensor; may be nil. */
@property(nonatomic, readwrite, strong, nullable) TFLQuantizationParameters *quantizationParameters;

/**
 * The interpreter that this tensor forwards its data, shape, and copy operations to. It is held
 * strongly so that the interpreter cannot be deallocated while this tensor is still alive.
 *
 * @warning The interpreter must never keep a strong reference back to a tensor; doing so would
 *     create a retain cycle between the two objects.
 */
@property(nonatomic, readwrite, strong) TFLInterpreter *interpreter;

@end
47
@implementation TFLTensor

#pragma mark - Public

/**
 * Copies `data` into this tensor through the backing interpreter.
 *
 * Only input tensors accept data. For an output tensor the call fails immediately with
 * `TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed`, and the interpreter is not consulted.
 *
 * @param data The bytes to copy into the tensor.
 * @param error Out-parameter passed to `TFLErrorUtil` (output tensors) or to the interpreter
 *     (input tensors) for error reporting.
 * @return NO for an output tensor; otherwise the result of
 *     `-[TFLInterpreter copyData:toInputTensorAtIndex:error:]`.
 */
- (BOOL)copyData:(NSData *)data error:(NSError **)error {
  if (self.type == TFLTensorTypeOutput) {
    [TFLErrorUtil
        saveInterpreterErrorWithCode:TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed
                         description:@"Cannot copy data into an output tensor."
                               error:error];
    return NO;
  }

  return [self.interpreter copyData:data toInputTensorAtIndex:self.index error:error];
}

/**
 * Reads this tensor's contents from the backing interpreter.
 *
 * @param error Out-parameter forwarded to the interpreter.
 * @return The value returned by `-[TFLInterpreter dataFromTensor:error:]`; may be nil.
 */
- (nullable NSData *)dataWithError:(NSError **)error {
  return [self.interpreter dataFromTensor:self error:error];
}

/**
 * Reads this tensor's shape from the backing interpreter.
 *
 * @param error Out-parameter forwarded to the interpreter.
 * @return The value returned by `-[TFLInterpreter shapeOfTensor:error:]`; may be nil.
 */
- (nullable NSArray<NSNumber *> *)shapeWithError:(NSError **)error {
  return [self.interpreter shapeOfTensor:self error:error];
}

#pragma mark - TFLTensor (Internal)

/**
 * Creates a tensor bound to `interpreter`.
 *
 * Every argument is stored in its backing ivar. `name` is copied; the other object arguments are
 * retained as passed. The interpreter is retained strongly.
 *
 * @param interpreter The interpreter that this tensor forwards its operations to.
 * @param type Whether the tensor is an input or an output.
 * @param index The tensor's index within the interpreter.
 * @param name The tensor's name.
 * @param dataType The tensor's data type.
 * @param quantizationParameters The tensor's quantization parameters, or nil.
 */
- (instancetype)initWithInterpreter:(TFLInterpreter *)interpreter
                               type:(TFLTensorType)type
                              index:(NSUInteger)index
                               name:(NSString *)name
                           dataType:(TFLTensorDataType)dataType
             quantizationParameters:(nullable TFLQuantizationParameters *)quantizationParameters {
  self = [super init];
  if (self != nil) {
    // Assign ivars directly (not via accessors), as is conventional inside -init.
    _interpreter = interpreter;
    _type = type;
    _index = index;
    _name = [name copy];
    _dataType = dataType;
    _quantizationParameters = quantizationParameters;
  }
  return self;
}

/**
 * Returns the human-readable name of `type`: @"input" or @"output".
 *
 * A value outside the declared `TFLTensorType` enumerators yields @"unknown".
 */
+ (NSString *)stringForTensorType:(TFLTensorType)type {
  // No `default:` label on purpose, so that -Wswitch still flags any enumerator added later.
  switch (type) {
    case TFLTensorTypeInput:
      return kTFLInputTensorTypeString;
    case TFLTensorTypeOutput:
      return kTFLOutputTensorTypeString;
  }
  // Reached only for out-of-range values (e.g. an integer cast to TFLTensorType). Without this
  // return, control would run off the end of a non-void method (undefined behavior) and break the
  // nonnull return contract.
  return @"unknown";
}

@end
102
103NS_ASSUME_NONNULL_END
104