// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "tensorflow/lite/objc/apis/TFLTensor.h"

#import "TFLErrorUtil.h"
#import "TFLInterpreter+Internal.h"
#import "TFLTensor+Internal.h"

#import "tensorflow/lite/objc/apis/TFLInterpreter.h"

NS_ASSUME_NONNULL_BEGIN

// String names of input or output tensor types.
static NSString *const kTFLInputTensorTypeString = @"input";
static NSString *const kTFLOutputTensorTypeString = @"output";

// Name returned by +stringForTensorType: when given a value outside the TFLTensorType enum.
static NSString *const kTFLUnknownTensorTypeString = @"unknown";

@interface TFLTensor ()

// Redefines readonly properties.
@property(nonatomic) TFLTensorType type;
@property(nonatomic) NSUInteger index;
@property(nonatomic, copy) NSString *name;
@property(nonatomic) TFLTensorDataType dataType;
@property(nonatomic, nullable) TFLQuantizationParameters *quantizationParameters;

/**
 * The backing interpreter. It's a strong reference to ensure that the interpreter is never released
 * before this tensor is released.
 *
 * @warning Never let the interpreter hold a strong reference to the tensor to avoid retain cycles.
 */
@property(nonatomic) TFLInterpreter *interpreter;

@end

@implementation TFLTensor

#pragma mark - Public

/**
 * Copies `data` into this tensor through the backing interpreter.
 *
 * Output tensors are rejected immediately with
 * `TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed` (reported through `TFLErrorUtil`).
 * For any other tensor type the copy is delegated to
 * `-[TFLInterpreter copyData:toInputTensorAtIndex:error:]` using this tensor's `index`.
 *
 * @param data The bytes to copy into the tensor.
 * @param error Out-parameter forwarded to the error reporting / interpreter call.
 * @return NO if the tensor is an output tensor; otherwise whatever the interpreter call returns.
 */
- (BOOL)copyData:(NSData *)data error:(NSError **)error {
  if (self.type == TFLTensorTypeOutput) {
    [TFLErrorUtil
        saveInterpreterErrorWithCode:TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed
                         description:@"Cannot copy data into an output tensor."
                               error:error];
    return NO;
  }

  return [self.interpreter copyData:data toInputTensorAtIndex:self.index error:error];
}

/**
 * Returns this tensor's data as provided by
 * `-[TFLInterpreter dataFromTensor:error:]`, or nil if that call returns nil.
 */
- (nullable NSData *)dataWithError:(NSError **)error {
  return [self.interpreter dataFromTensor:self error:error];
}

/**
 * Returns this tensor's shape as provided by
 * `-[TFLInterpreter shapeOfTensor:error:]`, or nil if that call returns nil.
 */
- (nullable NSArray<NSNumber *> *)shapeWithError:(NSError **)error {
  return [self.interpreter shapeOfTensor:self error:error];
}

#pragma mark - TFLTensor (Internal)

/**
 * Initializes a tensor that reads and writes through `interpreter`.
 *
 * The interpreter is retained strongly (see the `interpreter` property); `name` is copied so a
 * mutable string passed by the caller cannot change the tensor's name afterwards.
 */
- (instancetype)initWithInterpreter:(TFLInterpreter *)interpreter
                               type:(TFLTensorType)type
                              index:(NSUInteger)index
                               name:(NSString *)name
                           dataType:(TFLTensorDataType)dataType
             quantizationParameters:(nullable TFLQuantizationParameters *)quantizationParameters {
  self = [super init];
  if (self != nil) {
    _interpreter = interpreter;
    _type = type;
    _index = index;
    _name = [name copy];
    _dataType = dataType;
    _quantizationParameters = quantizationParameters;
  }
  return self;
}

/**
 * Returns a human-readable name for `type`: @"input", @"output", or @"unknown" for a value that
 * is not a member of `TFLTensorType`.
 */
+ (NSString *)stringForTensorType:(TFLTensorType)type {
  // No `default:` label on purpose: clang's -Wswitch then flags any enum case added later and
  // left unhandled here.
  switch (type) {
    case TFLTensorTypeInput:
      return kTFLInputTensorTypeString;
    case TFLTensorTypeOutput:
      return kTFLOutputTensorTypeString;
  }
  // Reached only when `type` holds a value outside the enum (e.g. an integer cast). Without this
  // return, control would fall off the end of a non-void method, which is undefined behavior.
  return kTFLUnknownTensorTypeString;
}

@end

NS_ASSUME_NONNULL_END