/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Declarations of operation kernels used by the NeuralNetworks common code.
//
// This header contains declarations only; every definition lives in other
// translation units that are not visible from here. Conventions visible in the
// signatures below:
//  - Every function returns bool. NOTE(review): presumably true on success and
//    false on failure -- confirm against the definitions.
//  - Tensor metadata is passed as `const Shape&`. Shape is only forward-declared
//    in this header, so any caller that constructs or inspects a Shape must
//    include its full definition itself.
//  - `_Float16` is a compiler extension type (ISO/IEC TS 18661-3), not a
//    standard C++17 type; the Float16 declarations require a toolchain that
//    provides it.
//  - The function templates declared here have no definition in this header.
//    NOTE(review): they are presumably explicitly instantiated in the defining
//    .cpp for the supported element types, so instantiating them with any other
//    T would fail at link time -- confirm.
//  - Parameter names and order must stay in sync with the definitions; several
//    kernels take long runs of same-typed int32_t parameters, so a transposed
//    argument at a call site will compile silently.

#ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATIONS_H
#define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATIONS_H

#include <stddef.h>

#include <cstdint>
#include <vector>

// NOTE(review): no declaration in this header names a type from
// ActivationFunctor.h (`activation` is passed as a plain int32_t below); the
// include may be kept for the benefit of includers -- confirm before removing.
#include "ActivationFunctor.h"

namespace android {
namespace nn {

// Forward declaration only: this header never needs Shape's layout because every
// use below is by const reference.
struct Shape;

// FLOOR kernels (per the names). Input and output share the single `shape`
// argument.
bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape);
bool floorFloat32(const float* inputData, float* outputData, const Shape& shape);

// Depthwise convolution kernels (per the names).
// Padding is given explicitly per edge (left/right/top/bottom); stride and
// dilation are given per spatial dimension. `activation` is an int32_t code --
// NOTE(review): presumably a fused-activation enum value; confirm.
// The Quant8 variants take an int32_t bias while the input/output are uint8_t.
// The PerChannel variant takes an int8_t filter plus a separate `filterScales`
// array -- NOTE(review): presumably one scale per output channel; confirm the
// expected length against the definition.
bool depthwiseConvFloat16(const _Float16* inputData, const Shape& inputShape,
                          const _Float16* filterData, const Shape& filterShape,
                          const _Float16* biasData, const Shape& biasShape, int32_t paddingLeft,
                          int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
                          int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor,
                          int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation,
                          _Float16* outputData, const Shape& outputShape);
bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData,
                          const Shape& filterShape, const float* biasData, const Shape& biasShape,
                          int32_t paddingLeft, int32_t paddingRight, int32_t paddingTop,
                          int32_t paddingBottom, int32_t strideWidth, int32_t strideHeight,
                          int32_t dilationWidthFactor, int32_t dilationHeightFactor,
                          int32_t depthMultiplier, int32_t activation, float* outputData,
                          const Shape& outputShape);
bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape,
                         const uint8_t* filterData, const Shape& filterShape,
                         const int32_t* biasData, const Shape& biasShape, int32_t paddingLeft,
                         int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
                         int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor,
                         int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation,
                         uint8_t* outputData, const Shape& outputShape);
bool depthwiseConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape,
                                   const int8_t* filterData, const Shape& filterShape,
                                   const float* filterScales, const int32_t* biasData,
                                   const Shape& biasShape, int32_t paddingLeft,
                                   int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
                                   int32_t strideWidth, int32_t strideHeight,
                                   int32_t dilationWidthFactor, int32_t dilationHeightFactor,
                                   int32_t depthMultiplier, int32_t activation, uint8_t* outputData,
                                   const Shape& outputShape);

// Local response normalization kernels (per the names). `radius`, `bias`,
// `alpha` and `beta` are the normalization parameters; `axis` selects the
// dimension to operate along -- NOTE(review): confirm whether negative axes are
// accepted.
bool localResponseNormFloat16(const _Float16* inputData, const Shape& inputShape, int32_t radius,
                              float bias, float alpha, float beta, int32_t axis,
                              _Float16* outputData, const Shape& outputShape);
bool localResponseNormFloat32(const float* inputData, const Shape& inputShape, int32_t radius,
                              float bias, float alpha, float beta, int32_t axis, float* outputData,
                              const Shape& outputShape);

// Type-erased copy from inputData to outputData. NOTE(review): the buffers are
// void*, so the byte count is presumably derived from the Shapes -- confirm.
bool copyData(const void* inputData, const Shape& inputShape, void* outputData,
              const Shape& outputShape);

// DEPTH_TO_SPACE / SPACE_TO_DEPTH (per the names) with a single scalar
// `blockSize`.
template <typename T>
bool depthToSpaceGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
                         T* outputData, const Shape& outputShape);
template <typename T>
bool spaceToDepthGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
                         T* outputData, const Shape& outputShape);

// PAD (per the name): fills with `pad_value`. NOTE(review): the layout of the
// `paddings` array (presumably before/after pairs per dimension) is not visible
// here -- confirm against the definition.
template <typename T>
bool padGeneric(const T* inputData, const Shape& inputShape, const int32_t* paddings, T pad_value,
                T* outputData, const Shape& outputShape);

// BATCH_TO_SPACE / SPACE_TO_BATCH (per the names). Unlike depthToSpaceGeneric,
// `blockSize` here is an array rather than a scalar. spaceToBatchGeneric also
// takes a padding tensor described by `paddingShape`.
template <typename T>
bool batchToSpaceGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
                         T* outputData, const Shape& outputShape);

template <typename T>
bool spaceToBatchGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
                         const int32_t* padding, const Shape& paddingShape, T* outputData,
                         const Shape& outputShape);

// MEAN (per the names) over the axes listed in `axis` (described by
// `axisShape`); `keepDims` controls whether reduced dimensions are retained.
// NOTE(review): unlike every other kernel in this header, both declarations take
// a NON-const `inputData` pointer. It is unclear from here whether the input is
// actually written; this cannot be changed to `const` in the header alone
// without updating the definitions (the pointee const-qualification is part of
// the mangled signature).
// NOTE(review): meanGeneric's second template parameter `U` does not appear in
// the parameter list, so it cannot be deduced and must be supplied explicitly at
// every call (e.g. meanGeneric<T, U>(...)); presumably U is an accumulator type
// -- confirm.
bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis,
                 const Shape& axisShape, bool keepDims, _Float16* outputData,
                 const Shape& outputShape);
template <typename T, typename U>
bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape,
                 bool keepDims, T* outputData, const Shape& outputShape);

// STRIDED_SLICE (per the name). Buffers are typed as uint8_t* regardless of the
// tensor element type -- NOTE(review): presumably the element type/size is taken
// from the Shapes; confirm. `beginMask`, `endMask` and `shrinkAxisMask` are
// int32_t masks -- presumably one bit per dimension; confirm.
bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape,
                         const int32_t* beginData, const int32_t* endData,
                         const int32_t* stridesData, int32_t beginMask, int32_t endMask,
                         int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape);

// ARGMIN/ARGMAX (per the name) along `axis`; `isArgMin` presumably selects
// between the two. Buffers are byte-typed as in stridedSliceGeneric.
bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
                      bool isArgMin, uint8_t* outputData, const Shape& outputShape);

// SPLIT kernels (per the names) along `axis`. Output buffers are passed as a
// pointer to a vector of per-output data pointers, paired with `outputShapes`.
// NOTE(review): presumably outputDataPtrs->size() must equal outputShapes.size()
// -- confirm.
// splitFloat16 declares `int32_t axis` while the others declare
// `const int32_t axis`; top-level const on a by-value parameter is ignored in a
// declaration, so the signatures are equivalent.
bool splitFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis,
                  const std::vector<_Float16*>* outputDataPtrs,
                  const std::vector<Shape>& outputShapes);

bool splitFloat32(const float* inputData, const Shape& inputShape, const int32_t axis,
                  const std::vector<float*>* outputDataPtrs,
                  const std::vector<Shape>& outputShapes);

bool splitInt32(const int32_t* inputData, const Shape& inputShape, const int32_t axis,
                const std::vector<int32_t*>* outputDataPtrs,
                const std::vector<Shape>& outputShapes);

bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, const int32_t axis,
                 const std::vector<uint8_t*>* outputDataPtrs,
                 const std::vector<Shape>& outputShapes);

bool splitQuant8Signed(const int8_t* inputData, const Shape& inputShape, const int32_t axis,
                       const std::vector<int8_t*>* outputDataPtrs,
                       const std::vector<Shape>& outputShapes);

// Grouped convolution kernels (per the names). In the Float16, Float32 and
// Quant8 variants, `numGroups` comes immediately after `biasShape`, before the
// padding parameters.
bool groupedConvFloat16(const _Float16* inputData, const Shape& inputShape,
                        const _Float16* filterData, const Shape& filterShape,
                        const _Float16* biasData, const Shape& biasShape, int32_t numGroups,
                        int32_t padding_left, int32_t padding_right, int32_t padding_top,
                        int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
                        int32_t activation, _Float16* outputData, const Shape& outputShape);

bool groupedConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData,
                        const Shape& filterShape, const float* biasData, const Shape& biasShape,
                        int32_t numGroups, int32_t padding_left, int32_t padding_right,
                        int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
                        int32_t stride_height, int32_t activation, float* outputData,
                        const Shape& outputShape);

// Quantized grouped convolution; input, filter and output share element type T
// and the bias is int32_t.
template <typename T>
bool groupedConvQuant8(const T* inputData, const Shape& inputShape, const T* filterData,
                       const Shape& filterShape, const int32_t* biasData, const Shape& biasShape,
                       int32_t numGroups, int32_t padding_left, int32_t padding_right,
                       int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
                       int32_t stride_height, int32_t activation, T* outputData,
                       const Shape& outputShape);

// Per-channel quantized grouped convolution: the filter is always int8_t
// (independent of T) with a separate `filterScales` array.
// NOTE(review): unlike the other groupedConv* declarations, `numGroups` here
// comes AFTER `stride_height`, not before `padding_left`. Because all of these
// are int32_t, a call written by analogy with groupedConvQuant8 compiles but
// passes the wrong values -- double-check argument order at call sites.
template <typename T>
bool groupedConvQuant8PerChannel(const T* inputData, const Shape& inputShape,
                                 const int8_t* filterData, const Shape& filterShape,
                                 const float* filterScales, const int32_t* biasData,
                                 const Shape& biasShape, int32_t padding_left,
                                 int32_t padding_right, int32_t padding_top, int32_t padding_bottom,
                                 int32_t stride_width, int32_t stride_height, int32_t numGroups,
                                 int32_t activation, T* outputData, const Shape& outputShape);

// CHANNEL_SHUFFLE (per the name) with `numGroups` groups along `axis`. Buffers
// are byte-typed as in stridedSliceGeneric.
bool channelShuffleGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t numGroups,
                           int32_t axis, uint8_t* outputData, const Shape& outputShape);
}  // namespace nn
}  // namespace android

#endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATIONS_H