1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "OperationResolver"
18 
19 #include "OperationResolver.h"
20 
21 #include "NeuralNetworks.h"
22 
23 namespace android {
24 namespace nn {
25 
26 // TODO(b/119608412): Find a way to not reference every operation here.
27 const OperationRegistration* register_ABS();
28 const OperationRegistration* register_ADD();
29 const OperationRegistration* register_AVERAGE_POOL_2D();
30 const OperationRegistration* register_AXIS_ALIGNED_BBOX_TRANSFORM();
31 const OperationRegistration* register_BIDIRECTIONAL_SEQUENCE_RNN();
32 const OperationRegistration* register_BOX_WITH_NMS_LIMIT();
33 const OperationRegistration* register_CHANNEL_SHUFFLE();
34 const OperationRegistration* register_CONCATENATION();
35 const OperationRegistration* register_CONV_2D();
36 const OperationRegistration* register_DEQUANTIZE();
37 const OperationRegistration* register_DETECTION_POSTPROCESSING();
38 const OperationRegistration* register_DIV();
39 const OperationRegistration* register_EQUAL();
40 const OperationRegistration* register_EXP();
41 const OperationRegistration* register_FULLY_CONNECTED();
42 const OperationRegistration* register_GATHER();
43 const OperationRegistration* register_GENERATE_PROPOSALS();
44 const OperationRegistration* register_GREATER();
45 const OperationRegistration* register_GREATER_EQUAL();
46 const OperationRegistration* register_HEATMAP_MAX_KEYPOINT();
47 const OperationRegistration* register_INSTANCE_NORMALIZATION();
48 const OperationRegistration* register_L2_NORMALIZATION();
49 const OperationRegistration* register_L2_POOL_2D();
50 const OperationRegistration* register_LESS();
51 const OperationRegistration* register_LESS_EQUAL();
52 const OperationRegistration* register_LOG();
53 const OperationRegistration* register_LOGICAL_AND();
54 const OperationRegistration* register_LOGICAL_NOT();
55 const OperationRegistration* register_LOGICAL_OR();
56 const OperationRegistration* register_LOGISTIC();
57 const OperationRegistration* register_LOG_SOFTMAX();
58 const OperationRegistration* register_MAX_POOL_2D();
59 const OperationRegistration* register_MUL();
60 const OperationRegistration* register_NEG();
61 const OperationRegistration* register_NOT_EQUAL();
62 const OperationRegistration* register_PRELU();
63 const OperationRegistration* register_QUANTIZE();
64 const OperationRegistration* register_REDUCE_ALL();
65 const OperationRegistration* register_REDUCE_ANY();
66 const OperationRegistration* register_REDUCE_MAX();
67 const OperationRegistration* register_REDUCE_MIN();
68 const OperationRegistration* register_REDUCE_PROD();
69 const OperationRegistration* register_REDUCE_SUM();
70 const OperationRegistration* register_RELU();
71 const OperationRegistration* register_RELU1();
72 const OperationRegistration* register_RELU6();
73 const OperationRegistration* register_RESIZE_BILINEAR();
74 const OperationRegistration* register_RESIZE_NEAREST_NEIGHBOR();
75 const OperationRegistration* register_ROI_ALIGN();
76 const OperationRegistration* register_ROI_POOLING();
77 const OperationRegistration* register_RSQRT();
78 const OperationRegistration* register_SELECT();
79 const OperationRegistration* register_SIN();
80 const OperationRegistration* register_SLICE();
81 const OperationRegistration* register_SOFTMAX();
82 const OperationRegistration* register_SQRT();
83 const OperationRegistration* register_SUB();
84 const OperationRegistration* register_TANH();
85 const OperationRegistration* register_TRANSPOSE();
86 const OperationRegistration* register_TRANSPOSE_CONV_2D();
87 const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_LSTM();
88 const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_RNN();
89 
BuiltinOperationResolver()90 BuiltinOperationResolver::BuiltinOperationResolver() {
91     registerOperation(register_ABS());
92     registerOperation(register_ADD());
93     registerOperation(register_AVERAGE_POOL_2D());
94     registerOperation(register_AXIS_ALIGNED_BBOX_TRANSFORM());
95     registerOperation(register_BIDIRECTIONAL_SEQUENCE_RNN());
96     registerOperation(register_BOX_WITH_NMS_LIMIT());
97     registerOperation(register_CHANNEL_SHUFFLE());
98     registerOperation(register_CONCATENATION());
99     registerOperation(register_CONV_2D());
100     registerOperation(register_DEQUANTIZE());
101     registerOperation(register_DETECTION_POSTPROCESSING());
102     registerOperation(register_DIV());
103     registerOperation(register_EQUAL());
104     registerOperation(register_EXP());
105     registerOperation(register_FULLY_CONNECTED());
106     registerOperation(register_GATHER());
107     registerOperation(register_GENERATE_PROPOSALS());
108     registerOperation(register_GREATER());
109     registerOperation(register_GREATER_EQUAL());
110     registerOperation(register_HEATMAP_MAX_KEYPOINT());
111     registerOperation(register_INSTANCE_NORMALIZATION());
112     registerOperation(register_L2_NORMALIZATION());
113     registerOperation(register_L2_POOL_2D());
114     registerOperation(register_LESS());
115     registerOperation(register_LESS_EQUAL());
116     registerOperation(register_LOG());
117     registerOperation(register_LOGICAL_AND());
118     registerOperation(register_LOGICAL_NOT());
119     registerOperation(register_LOGICAL_OR());
120     registerOperation(register_LOGISTIC());
121     registerOperation(register_LOG_SOFTMAX());
122     registerOperation(register_MAX_POOL_2D());
123     registerOperation(register_MUL());
124     registerOperation(register_NEG());
125     registerOperation(register_NOT_EQUAL());
126     registerOperation(register_PRELU());
127     registerOperation(register_QUANTIZE());
128     registerOperation(register_REDUCE_ALL());
129     registerOperation(register_REDUCE_ANY());
130     registerOperation(register_REDUCE_MAX());
131     registerOperation(register_REDUCE_MIN());
132     registerOperation(register_REDUCE_PROD());
133     registerOperation(register_REDUCE_SUM());
134     registerOperation(register_RELU());
135     registerOperation(register_RELU1());
136     registerOperation(register_RELU6());
137     registerOperation(register_RESIZE_BILINEAR());
138     registerOperation(register_RESIZE_NEAREST_NEIGHBOR());
139     registerOperation(register_ROI_ALIGN());
140     registerOperation(register_ROI_POOLING());
141     registerOperation(register_RSQRT());
142     registerOperation(register_SELECT());
143     registerOperation(register_SIN());
144     registerOperation(register_SLICE());
145     registerOperation(register_SOFTMAX());
146     registerOperation(register_SQRT());
147     registerOperation(register_SUB());
148     registerOperation(register_TANH());
149     registerOperation(register_TRANSPOSE());
150     registerOperation(register_TRANSPOSE_CONV_2D());
151     registerOperation(register_UNIDIRECTIONAL_SEQUENCE_LSTM());
152     registerOperation(register_UNIDIRECTIONAL_SEQUENCE_RNN());
153 }
154 
findOperation(OperationType operationType) const155 const OperationRegistration* BuiltinOperationResolver::findOperation(
156         OperationType operationType) const {
157     auto index = static_cast<int32_t>(operationType);
158     if (index < 0 || index >= kNumberOfOperationTypes) {
159         return nullptr;
160     }
161     return mRegistrations[index];
162 }
163 
registerOperation(const OperationRegistration * operationRegistration)164 void BuiltinOperationResolver::registerOperation(
165         const OperationRegistration* operationRegistration) {
166     CHECK(operationRegistration != nullptr);
167     auto index = static_cast<int32_t>(operationRegistration->type);
168     CHECK_LE(0, index);
169     CHECK_LT(index, kNumberOfOperationTypes);
170     CHECK(mRegistrations[index] == nullptr);
171     mRegistrations[index] = operationRegistration;
172 }
173 
174 }  // namespace nn
175 }  // namespace android
176