1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATIONS_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATIONS_H
19 
20 #include <stddef.h>
21 
22 #include <cstdint>
23 #include <vector>
24 
25 #include "ActivationFunctor.h"
26 
27 namespace android {
28 namespace nn {
29 
30 struct Shape;
31 
32 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape);
33 bool floorFloat32(const float* inputData, float* outputData, const Shape& shape);
34 
35 bool depthwiseConvFloat16(const _Float16* inputData, const Shape& inputShape,
36                           const _Float16* filterData, const Shape& filterShape,
37                           const _Float16* biasData, const Shape& biasShape, int32_t paddingLeft,
38                           int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
39                           int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor,
40                           int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation,
41                           _Float16* outputData, const Shape& outputShape);
42 bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData,
43                           const Shape& filterShape, const float* biasData, const Shape& biasShape,
44                           int32_t paddingLeft, int32_t paddingRight, int32_t paddingTop,
45                           int32_t paddingBottom, int32_t strideWidth, int32_t strideHeight,
46                           int32_t dilationWidthFactor, int32_t dilationHeightFactor,
47                           int32_t depthMultiplier, int32_t activation, float* outputData,
48                           const Shape& outputShape);
49 bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape,
50                          const uint8_t* filterData, const Shape& filterShape,
51                          const int32_t* biasData, const Shape& biasShape, int32_t paddingLeft,
52                          int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
53                          int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor,
54                          int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation,
55                          uint8_t* outputData, const Shape& outputShape);
56 bool depthwiseConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape,
57                                    const int8_t* filterData, const Shape& filterShape,
58                                    const float* filterScales, const int32_t* biasData,
59                                    const Shape& biasShape, int32_t paddingLeft,
60                                    int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
61                                    int32_t strideWidth, int32_t strideHeight,
62                                    int32_t dilationWidthFactor, int32_t dilationHeightFactor,
63                                    int32_t depthMultiplier, int32_t activation, uint8_t* outputData,
64                                    const Shape& outputShape);
65 
66 bool localResponseNormFloat16(const _Float16* inputData, const Shape& inputShape, int32_t radius,
67                               float bias, float alpha, float beta, int32_t axis,
68                               _Float16* outputData, const Shape& outputShape);
69 bool localResponseNormFloat32(const float* inputData, const Shape& inputShape, int32_t radius,
70                               float bias, float alpha, float beta, int32_t axis, float* outputData,
71                               const Shape& outputShape);
72 
73 bool copyData(const void* inputData, const Shape& inputShape, void* outputData,
74               const Shape& outputShape);
75 
76 template <typename T>
77 bool depthToSpaceGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
78                          T* outputData, const Shape& outputShape);
79 template <typename T>
80 bool spaceToDepthGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
81                          T* outputData, const Shape& outputShape);
82 
83 template <typename T>
84 bool padGeneric(const T* inputData, const Shape& inputShape, const int32_t* paddings, T pad_value,
85                 T* outputData, const Shape& outputShape);
86 
87 template <typename T>
88 bool batchToSpaceGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
89                          T* outputData, const Shape& outputShape);
90 
91 template <typename T>
92 bool spaceToBatchGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
93                          const int32_t* padding, const Shape& paddingShape, T* outputData,
94                          const Shape& outputShape);
95 
96 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis,
97                  const Shape& axisShape, bool keepDims, _Float16* outputData,
98                  const Shape& outputShape);
99 template <typename T, typename U>
100 bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape,
101                  bool keepDims, T* outputData, const Shape& outputShape);
102 
103 bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape,
104                          const int32_t* beginData, const int32_t* endData,
105                          const int32_t* stridesData, int32_t beginMask, int32_t endMask,
106                          int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape);
107 
108 bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
109                       bool isArgMin, uint8_t* outputData, const Shape& outputShape);
110 
111 bool splitFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis,
112                   const std::vector<_Float16*>* outputDataPtrs,
113                   const std::vector<Shape>& outputShapes);
114 
115 bool splitFloat32(const float* inputData, const Shape& inputShape, const int32_t axis,
116                   const std::vector<float*>* outputDataPtrs,
117                   const std::vector<Shape>& outputShapes);
118 
119 bool splitInt32(const int32_t* inputData, const Shape& inputShape, const int32_t axis,
120                 const std::vector<int32_t*>* outputDataPtrs,
121                 const std::vector<Shape>& outputShapes);
122 
123 bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, const int32_t axis,
124                  const std::vector<uint8_t*>* outputDataPtrs,
125                  const std::vector<Shape>& outputShapes);
126 
127 bool splitQuant8Signed(const int8_t* inputData, const Shape& inputShape, const int32_t axis,
128                        const std::vector<int8_t*>* outputDataPtrs,
129                        const std::vector<Shape>& outputShapes);
130 
131 bool groupedConvFloat16(const _Float16* inputData, const Shape& inputShape,
132                         const _Float16* filterData, const Shape& filterShape,
133                         const _Float16* biasData, const Shape& biasShape, int32_t numGroups,
134                         int32_t padding_left, int32_t padding_right, int32_t padding_top,
135                         int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
136                         int32_t activation, _Float16* outputData, const Shape& outputShape);
137 
138 bool groupedConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData,
139                         const Shape& filterShape, const float* biasData, const Shape& biasShape,
140                         int32_t numGroups, int32_t padding_left, int32_t padding_right,
141                         int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
142                         int32_t stride_height, int32_t activation, float* outputData,
143                         const Shape& outputShape);
144 
145 template <typename T>
146 bool groupedConvQuant8(const T* inputData, const Shape& inputShape, const T* filterData,
147                        const Shape& filterShape, const int32_t* biasData, const Shape& biasShape,
148                        int32_t numGroups, int32_t padding_left, int32_t padding_right,
149                        int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
150                        int32_t stride_height, int32_t activation, T* outputData,
151                        const Shape& outputShape);
152 
153 template <typename T>
154 bool groupedConvQuant8PerChannel(const T* inputData, const Shape& inputShape,
155                                  const int8_t* filterData, const Shape& filterShape,
156                                  const float* filterScales, const int32_t* biasData,
157                                  const Shape& biasShape, int32_t padding_left,
158                                  int32_t padding_right, int32_t padding_top, int32_t padding_bottom,
159                                  int32_t stride_width, int32_t stride_height, int32_t numGroups,
160                                  int32_t activation, T* outputData, const Shape& outputShape);
161 
162 bool channelShuffleGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t numGroups,
163                            int32_t axis, uint8_t* outputData, const Shape& outputShape);
164 }  // namespace nn
165 }  // namespace android
166 
167 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATIONS_H
168