/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#include "Operations.h"
#include "CpuOperationUtils.h"

#include <mutex>

#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
22
23 namespace android {
24 namespace nn {
25
26 // executionMutex is used to protect concurrent access of non-threadsafe resources
27 // like gemmlowp::GemmContext.
28 // std::mutex is safe for pthreads on Android.
29 static std::mutex executionMutex;
30
fullyConnectedFloat32(const float * inputData,const Shape & inputShape,const float * weightsData,const Shape & weightsShape,const float * biasData,const Shape & biasShape,int32_t activation,float * outputData,const Shape & outputShape)31 bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
32 const float* weightsData, const Shape& weightsShape,
33 const float* biasData, const Shape& biasShape,
34 int32_t activation,
35 float* outputData, const Shape& outputShape) {
36 float output_activation_min, output_activation_max;
37 CalculateActivationRangeFloat(activation, &output_activation_min,
38 &output_activation_max);
39
40 // b/80425683, optimized implementation produces incorrect results when the
41 // number of input elements is the squre of batch_size.
42 uint32_t batch_size = getSizeOfDimension(outputShape, 0);
43 uint32_t input_n_elements = getNumberOfElements(inputShape);
44 if (batch_size * batch_size == input_n_elements) {
45 tflite::reference_ops::FullyConnected(
46 inputData, convertShapeToDims(inputShape),
47 weightsData, convertShapeToDims(weightsShape),
48 biasData, convertShapeToDims(biasShape),
49 output_activation_min, output_activation_max,
50 outputData, convertShapeToDims(outputShape));
51 } else {
52 tflite::optimized_ops::FullyConnected(
53 inputData, convertShapeToDims(inputShape),
54 weightsData, convertShapeToDims(weightsShape),
55 biasData, convertShapeToDims(biasShape),
56 output_activation_min, output_activation_max,
57 outputData, convertShapeToDims(outputShape));
58 }
59 return true;
60 }
61
fullyConnectedQuant8(const uint8_t * inputData,const Shape & inputShape,const uint8_t * weightsData,const Shape & weightsShape,const int32_t * biasData,const Shape & biasShape,int32_t activation,uint8_t * outputData,const Shape & outputShape)62 bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
63 const uint8_t* weightsData, const Shape& weightsShape,
64 const int32_t* biasData, const Shape& biasShape,
65 int32_t activation,
66 uint8_t* outputData, const Shape& outputShape) {
67 int32_t inputOffset = -inputShape.offset;
68 int32_t weightsOffset = -weightsShape.offset;
69 int32_t outputOffset = outputShape.offset;
70
71 float real_multiplier = 0.0;
72 int32_t output_multiplier = 0;
73 int32_t output_shift = 0;
74 int32_t output_activation_min = 0;
75 int32_t output_activation_max = 0;
76
77 if (!GetQuantizedConvolutionMultipler(inputShape, weightsShape, biasShape,
78 outputShape, &real_multiplier) ||
79 !QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
80 &output_shift)) {
81 return false;
82 }
83 CalculateActivationRangeUint8(activation, outputShape,
84 &output_activation_min,
85 &output_activation_max);
86
87 static gemmlowp::GemmContext gemm_context;
88
89 // Prevent concurrent executions that access gemm_context.
90 std::unique_lock<std::mutex> lock(executionMutex);
91 // Alow gemmlowp automatically decide how many threads to use.
92 gemm_context.set_max_num_threads(0);
93
94 tflite::optimized_ops::FullyConnected(
95 inputData, convertShapeToDims(inputShape), inputOffset,
96 weightsData, convertShapeToDims(weightsShape), weightsOffset,
97 biasData, convertShapeToDims(biasShape),
98 outputOffset, output_multiplier, output_shift,
99 output_activation_min, output_activation_max,
100 outputData, convertShapeToDims(outputShape), &gemm_context);
101
102 return true;
103 }
104 } // namespace nn
105 } // namespace android
106