1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H
19 
20 #include <functional>
21 #include <utility>
22 
23 #include "OperationsExecutionUtils.h"
24 #include "OperationsValidationUtils.h"
25 
26 namespace android {
27 namespace nn {
28 
29 // Encapsulates an operation implementation.
30 struct OperationRegistration {
31     OperationType type;
32     const char* name;
33 
34     // Validates operand types, shapes, and any values known during graph creation.
35     // TODO(b/213938830): operation validation dispatch is duplicated and does not handle extension
36     // types.
37     std::function<Result<Version>(const IOperationValidationContext*)> validate;
38 
39     // prepare is called when the inputs this operation depends on have been
40     // computed. Typically, prepare does any remaining validation and sets
41     // output shapes via context->setOutputShape(...).
42     std::function<bool(IOperationExecutionContext*)> prepare;
43 
44     // Executes the operation, reading from context->getInputBuffer(...)
45     // and writing to context->getOutputBuffer(...).
46     std::function<bool(IOperationExecutionContext*)> execute;
47 
48     struct Flag {
49         // Whether the operation allows at least one operand to be omitted.
50         bool allowOmittedOperand = false;
51         // Whether the operation allows at least one input operand to be a zero-sized tensor.
52         bool allowZeroSizedInput = false;
53     } flags;
54 
OperationRegistrationOperationRegistration55     OperationRegistration(
56             OperationType type, const char* name,
57             std::function<Result<Version>(const IOperationValidationContext*)> validate,
58             std::function<bool(IOperationExecutionContext*)> prepare,
59             std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
60         : type(type),
61           name(name),
62           validate(std::move(validate)),
63           prepare(std::move(prepare)),
64           execute(std::move(execute)),
65           flags(flags) {}
66 };
67 
68 // A registry of operation implementations.
69 class IOperationResolver {
70    public:
71     virtual const OperationRegistration* findOperation(OperationType operationType) const = 0;
~IOperationResolver()72     virtual ~IOperationResolver() {}
73 };
74 
75 // A registry of builtin operation implementations.
76 //
77 // Note that some operations bypass BuiltinOperationResolver (b/124041202).
78 //
79 // Usage:
80 //   const OperationRegistration* operationRegistration =
81 //           BuiltinOperationResolver::get()->findOperation(operationType);
82 //   NN_RET_CHECK(operationRegistration != nullptr);
83 //   NN_RET_CHECK(operationRegistration->validate != nullptr);
84 //   NN_RET_CHECK(operationRegistration->validate(&context));
85 //
86 class BuiltinOperationResolver : public IOperationResolver {
87     DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver);
88 
89    public:
get()90     static const BuiltinOperationResolver* get() {
91         static BuiltinOperationResolver instance;
92         return &instance;
93     }
94 
95     const OperationRegistration* findOperation(OperationType operationType) const override;
96 
97     // The number of operation types (OperationCode) defined in NeuralNetworksTypes.h.
98     static constexpr int kNumberOfOperationTypes = 106;
99 
100 #ifdef NN_EXPERIMENTAL_FEATURE
101     // The number of experimental operation types (ANeuralNetworksExperimentalOperationCode) defined
102     // in NeuralNetworksExperimentalFeatures.h.
103     static constexpr int kNumberOfExperimentalOperationTypes = 1;
104 
105     // The starting value of experimental operation types (ANeuralNetworksExperimentalOperationCode)
106     // defined in NeuralNetworksExperimentalFeatures.h.
107     static constexpr int kStartOfExperimentalOperations = 20000;
108 #endif  // NN_EXPERIMENTAL_FEATURE
109 
110    private:
111     BuiltinOperationResolver();
112 
113     void registerOperation(const OperationRegistration* operationRegistration);
114 
115     const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {};
116 
117 #ifdef NN_EXPERIMENTAL_FEATURE
118     const OperationRegistration* mExperimentalRegistrations[kNumberOfExperimentalOperationTypes] =
119             {};
120 #endif  // NN_EXPERIMENTAL_FEATURE
121 };
122 
// NN_REGISTER_OPERATION defines a function register_<identifier>() that returns a pointer to a
// function-local static OperationRegistration, for consumption by OperationResolver.
//
// Any arguments after the execute hook are wrapped in braces and used to initialize
// OperationRegistration::Flag. Designated initializers therefore select the flags to customize;
// flags that are not mentioned keep the defaults declared in OperationRegistration::Flag.
//
// Examples:
//
// - Default flags:
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute);
//
// - One customized flag:
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true);
//
// - Several customized flags:
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute, .allowOmittedOperand = true,
//                         .allowZeroSizedInput = true);
//
#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#define NN_REGISTER_OPERATION(identifier, operationName, validateFn, prepareFn, executeFn, ...) \
    const OperationRegistration* register_##identifier() {                                      \
        static OperationRegistration sRegistration(OperationType::identifier, operationName,    \
                                                   validateFn, prepareFn, executeFn,            \
                                                   {__VA_ARGS__});                              \
        return &sRegistration;                                                                  \
    }
#else
// Validation-only variant: the prepare and execute arguments are discarded and nullptr is passed
// for both hooks. The compiler is then expected to drop the CPU execution code, so that only
// validation logic ends up in libneuralnetworks_common*.
#define NN_REGISTER_OPERATION(identifier, operationName, validateFn, ignoredPrepareFn,           \
                              ignoredExecuteFn, ...)                                            \
    const OperationRegistration* register_##identifier() {                                      \
        static OperationRegistration sRegistration(OperationType::identifier, operationName,    \
                                                   validateFn, nullptr, nullptr,                \
                                                   {__VA_ARGS__});                              \
        return &sRegistration;                                                                  \
    }
#endif
161 
// NN_REGISTER_OPERATION_DEFAULT_VALIDATION first emits
// NN_VALIDATION_FUNCTION_SIGNATURE(identifier); (expected to declare the operation's validation
// function). It then registers the operation through NN_REGISTER_OPERATION, using the stringized
// identifier as the name and NN_VALIDATION_FUNCTION_NAME(identifier) as the validate hook.
// Trailing arguments are forwarded unchanged as OperationRegistration::Flag initializers.
//
// Like NN_REGISTER_OPERATION, the expansion does not supply its own trailing semicolon. Call sites
// written as `NN_REGISTER_OPERATION_DEFAULT_VALIDATION(...);` therefore no longer expand to a
// doubled `;;`.
#define NN_REGISTER_OPERATION_DEFAULT_VALIDATION(identifier, prepare, execute, ...)         \
    NN_VALIDATION_FUNCTION_SIGNATURE(identifier);                                           \
    NN_REGISTER_OPERATION(identifier, #identifier, NN_VALIDATION_FUNCTION_NAME(identifier), \
                          prepare, execute, __VA_ARGS__)
166 
// NN_OPERATION_IS_NOT_IMPLEMENTED defines register_<identifier>() so that it returns nullptr,
// meaning the operation has no OperationRegistration.
#define NN_OPERATION_IS_NOT_IMPLEMENTED(identifier)       \
    const OperationRegistration* register_##identifier() { \
        return nullptr;                                    \
    }
169 
170 }  // namespace nn
171 }  // namespace android
172 
173 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H
174