1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H 18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H 19 20 #include <functional> 21 #include <utility> 22 23 #include "OperationsExecutionUtils.h" 24 #include "OperationsValidationUtils.h" 25 26 namespace android { 27 namespace nn { 28 29 // Encapsulates an operation implementation. 30 struct OperationRegistration { 31 OperationType type; 32 const char* name; 33 34 // Validates operand types, shapes, and any values known during graph creation. 35 // TODO(b/213938830): operation validation dispatch is duplicated and does not handle extension 36 // types. 37 std::function<Result<Version>(const IOperationValidationContext*)> validate; 38 39 // prepare is called when the inputs this operation depends on have been 40 // computed. Typically, prepare does any remaining validation and sets 41 // output shapes via context->setOutputShape(...). 42 std::function<bool(IOperationExecutionContext*)> prepare; 43 44 // Executes the operation, reading from context->getInputBuffer(...) 45 // and writing to context->getOutputBuffer(...). 46 std::function<bool(IOperationExecutionContext*)> execute; 47 48 struct Flag { 49 // Whether the operation allows at least one operand to be omitted. 50 bool allowOmittedOperand = false; 51 // Whether the operation allows at least one input operand to be a zero-sized tensor. 52 bool allowZeroSizedInput = false; 53 } flags; 54 OperationRegistrationOperationRegistration55 OperationRegistration( 56 OperationType type, const char* name, 57 std::function<Result<Version>(const IOperationValidationContext*)> validate, 58 std::function<bool(IOperationExecutionContext*)> prepare, 59 std::function<bool(IOperationExecutionContext*)> execute, Flag flags) 60 : type(type), 61 name(name), 62 validate(std::move(validate)), 63 prepare(std::move(prepare)), 64 execute(std::move(execute)), 65 flags(flags) {} 66 }; 67 68 // A registry of operation implementations. 69 class IOperationResolver { 70 public: 71 virtual const OperationRegistration* findOperation(OperationType operationType) const = 0; ~IOperationResolver()72 virtual ~IOperationResolver() {} 73 }; 74 75 // A registry of builtin operation implementations. 76 // 77 // Note that some operations bypass BuiltinOperationResolver (b/124041202). 78 // 79 // Usage: 80 // const OperationRegistration* operationRegistration = 81 // BuiltinOperationResolver::get()->findOperation(operationType); 82 // NN_RET_CHECK(operationRegistration != nullptr); 83 // NN_RET_CHECK(operationRegistration->validate != nullptr); 84 // NN_RET_CHECK(operationRegistration->validate(&context)); 85 // 86 class BuiltinOperationResolver : public IOperationResolver { 87 DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver); 88 89 public: get()90 static const BuiltinOperationResolver* get() { 91 static BuiltinOperationResolver instance; 92 return &instance; 93 } 94 95 const OperationRegistration* findOperation(OperationType operationType) const override; 96 97 // The number of operation types (OperationCode) defined in NeuralNetworksTypes.h. 98 static constexpr int kNumberOfOperationTypes = 106; 99 100 #ifdef NN_EXPERIMENTAL_FEATURE 101 // The number of experimental operation types (ANeuralNetworksExperimentalOperationCode) defined 102 // in NeuralNetworksExperimentalFeatures.h. 103 static constexpr int kNumberOfExperimentalOperationTypes = 1; 104 105 // The starting value of experimental operation types (ANeuralNetworksExperimentalOperationCode) 106 // defined in NeuralNetworksExperimentalFeatures.h. 107 static constexpr int kStartOfExperimentalOperations = 20000; 108 #endif // NN_EXPERIMENTAL_FEATURE 109 110 private: 111 BuiltinOperationResolver(); 112 113 void registerOperation(const OperationRegistration* operationRegistration); 114 115 const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {}; 116 117 #ifdef NN_EXPERIMENTAL_FEATURE 118 const OperationRegistration* mExperimentalRegistrations[kNumberOfExperimentalOperationTypes] = 119 {}; 120 #endif // NN_EXPERIMENTAL_FEATURE 121 }; 122 123 // NN_REGISTER_OPERATION creates OperationRegistration for consumption by 124 // OperationResolver. 125 // 126 // Usage: 127 // (check OperationRegistration::Flag for available fields and default values.) 128 // 129 // - With default flags. 130 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 131 // foo_op::prepare, foo_op::execute); 132 // 133 // - With a customized flag. 134 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 135 // foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true); 136 // 137 // - With multiple customized flags. 138 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 139 // foo_op::prepare, foo_op::execute, .allowOmittedOperand = true, 140 // .allowZeroSizedInput = true); 141 // 142 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION 143 #define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...) \ 144 const OperationRegistration* register_##identifier() { \ 145 static OperationRegistration registration(OperationType::identifier, operationName, \ 146 validate, prepare, execute, {__VA_ARGS__}); \ 147 return ®istration; \ 148 } 149 #else 150 // This version ignores CPU execution logic (prepare and execute). 151 // The compiler is supposed to omit that code so that only validation logic 152 // makes it into libneuralnetworks_common*. 153 #define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \ 154 ...) \ 155 const OperationRegistration* register_##identifier() { \ 156 static OperationRegistration registration(OperationType::identifier, operationName, \ 157 validate, nullptr, nullptr, {__VA_ARGS__}); \ 158 return ®istration; \ 159 } 160 #endif 161 162 #define NN_REGISTER_OPERATION_DEFAULT_VALIDATION(identifier, prepare, execute, ...) \ 163 NN_VALIDATION_FUNCTION_SIGNATURE(identifier); \ 164 NN_REGISTER_OPERATION(identifier, #identifier, NN_VALIDATION_FUNCTION_NAME(identifier), \ 165 prepare, execute, __VA_ARGS__); 166 167 #define NN_OPERATION_IS_NOT_IMPLEMENTED(identifier) \ 168 const OperationRegistration* register_##identifier() { return nullptr; } 169 170 } // namespace nn 171 } // namespace android 172 173 #endif // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H 174