1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // Contains the implementation of the operations.
18
19 #define LOG_TAG "Operations"
20
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Winvalid-partial-specialization"
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h>
#pragma clang diagnostic pop

#include <memory>
#include <new>
#include <vector>

#include "CpuOperationUtils.h"
#include "Operations.h"
#include "SimpleMath.h"
#include "Tracing.h"
35
36 namespace android {
37 namespace nn {
38
meanFloat16(_Float16 * inputData,const Shape & inputShape,const int32_t * axis,const Shape & axisShape,bool keepDims,_Float16 * outputData,const Shape & outputShape)39 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis,
40 const Shape& axisShape, bool keepDims, _Float16* outputData,
41 const Shape& outputShape) {
42 NNTRACE_TRANS("meanFloat16");
43 std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
44 convertFloat16ToFloat32(inputData, &inputDataFloat32);
45
46 std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
47 meanGeneric<float, float>(inputDataFloat32.data(), inputShape, axis, axisShape, keepDims,
48 outputDataFloat32.data(), outputShape);
49 convertFloat32ToFloat16(outputDataFloat32, outputData);
50 return true;
51 }
52
53 template <typename T, typename U>
meanGeneric(T * inputData,const Shape & inputShape,const int32_t * axis,const Shape & axisShape,bool keepDims,T * outputData,const Shape & outputShape)54 bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape,
55 bool keepDims, T* outputData, const Shape& outputShape) {
56 NNTRACE_TRANS("meanGeneric");
57 // Creates a temp index to iterate through input data.
58 int32_t* scratchBuffer = new int32_t[getNumberOfDimensions(inputShape)];
59
60 // Creates a temp tensor to store resolved axis given input data.
61 int32_t axisSize = static_cast<int32_t>(getSizeOfDimension(axisShape, 0));
62 int32_t* resolvedAxis = new int32_t[axisSize];
63
64 bool result = true;
65 U* tempSumBuffer = new (std::nothrow) U[getNumberOfElements(outputShape)];
66 if (!tempSumBuffer) {
67 LOG(ERROR) << "Failed to allocate tempSumBuffer for MEAN";
68 result = false;
69 } else {
70 NNTRACE_COMP_SWITCH("optimized_ops::Mean");
71 tflite::reference_ops::Mean<T, U>(
72 inputData, reinterpret_cast<const int*>(inputShape.dimensions.data()),
73 getNumberOfDimensions(inputShape), outputData,
74 reinterpret_cast<const int*>(outputShape.dimensions.data()),
75 getNumberOfDimensions(outputShape), axis, axisSize, keepDims, scratchBuffer,
76 resolvedAxis, tempSumBuffer);
77 delete[] tempSumBuffer;
78 }
79 delete[] scratchBuffer;
80 delete[] resolvedAxis;
81 return result;
82 }
83 template bool meanGeneric<float, float>(float* inputData, const Shape& inputShape,
84 const int32_t* axis, const Shape& axisShape, bool keepDims,
85 float* outputData, const Shape& outputShape);
86 template bool meanGeneric<uint8_t, int32_t>(uint8_t* inputData, const Shape& inputShape,
87 const int32_t* axis, const Shape& axisShape,
88 bool keepDims, uint8_t* outputData,
89 const Shape& outputShape);
90 template bool meanGeneric<int8_t, int32_t>(int8_t* inputData, const Shape& inputShape,
91 const int32_t* axis, const Shape& axisShape,
92 bool keepDims, int8_t* outputData,
93 const Shape& outputShape);
94
95 } // namespace nn
96 } // namespace android
97