/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Operations.h"
#include "CpuOperationUtils.h"

#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"

namespace android {
namespace nn {
// executionMutex is used to protect concurrent access to non-thread-safe
// resources like gemmlowp::GemmContext.
// std::mutex is safe for pthreads on Android.
static std::mutex executionMutex;

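// Computes output = activation(input * transpose(weights) + bias), where
// input is [batch_size, input_size], weights is [num_units, input_size], and
// output is [batch_size, num_units].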
bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
                           const float* weightsData, const Shape& weightsShape,
                           const float* biasData, const Shape& biasShape,
                           int32_t activation,
                           float* outputData, const Shape& outputShape) {
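    // The fused activation is applied as a clamp after the matmul:
    // kActivationNone yields (-FLT_MAX, FLT_MAX), kActivationRelu [0, FLT_MAX],
    // kActivationRelu1 [-1, 1], and kActivationRelu6 [0, 6].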
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min,
                                  &output_activation_max);

    // b/80425683: the optimized implementation produces incorrect results
    // when the number of input elements is the square of batch_size, so fall
    // back to the reference implementation in that case.
    uint32_t batch_size = getSizeOfDimension(outputShape, 0);
    uint32_t input_n_elements = getNumberOfElements(inputShape);
    if (batch_size * batch_size == input_n_elements) {
        tflite::reference_ops::FullyConnected(
                inputData, convertShapeToDims(inputShape),
                weightsData, convertShapeToDims(weightsShape),
                biasData, convertShapeToDims(biasShape),
                output_activation_min, output_activation_max,
                outputData, convertShapeToDims(outputShape));
    } else {
        tflite::optimized_ops::FullyConnected(
                inputData, convertShapeToDims(inputShape),
                weightsData, convertShapeToDims(weightsShape),
                biasData, convertShapeToDims(biasShape),
                output_activation_min, output_activation_max,
                outputData, convertShapeToDims(outputShape));
    }
    return true;
}

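// Quantized (asymmetric uint8) variant. Per the FULLY_CONNECTED spec, the
// int32 bias is expected to use scale inputScale * weightsScale with a zero
// offset.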
bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
                          const uint8_t* weightsData, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape,
                          int32_t activation,
                          uint8_t* outputData, const Shape& outputShape) {
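    // A quantized value q represents the real value scale * (q - zeroPoint),
    // so the zero points are negated here and passed to the kernel as offsets.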
    int32_t inputOffset = -inputShape.offset;
    int32_t weightsOffset = -weightsShape.offset;
    int32_t outputOffset = outputShape.offset;

    float real_multiplier = 0.0;
    int32_t output_multiplier = 0;
    int32_t output_shift = 0;
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

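    // The real output multiplier is inputScale * weightsScale / outputScale;
    // it must be in (0, 1) and is re-expressed below as a fixed-point int32
    // multiplier plus a right shift.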
    if (!GetQuantizedConvolutionMultipler(inputShape, weightsShape, biasShape,
                                          outputShape, &real_multiplier) ||
            !QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
                                              &output_shift)) {
        return false;
    }
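    // CalculateActivationRangeUint8 expresses the activation bounds in the
    // quantized output domain, clamped to [0, 255], using the output's scale
    // and zero point.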
    CalculateActivationRangeUint8(activation, outputShape,
                                  &output_activation_min,
                                  &output_activation_max);

    static gemmlowp::GemmContext gemm_context;

    // Prevent concurrent executions that access gemm_context.
    std::unique_lock<std::mutex> lock(executionMutex);
    // Let gemmlowp automatically decide how many threads to use.
    gemm_context.set_max_num_threads(0);

    tflite::optimized_ops::FullyConnected(
            inputData, convertShapeToDims(inputShape), inputOffset,
            weightsData, convertShapeToDims(weightsShape), weightsOffset,
            biasData, convertShapeToDims(biasShape),
            outputOffset, output_multiplier, output_shift,
            output_activation_min, output_activation_max,
            outputData, convertShapeToDims(outputShape), &gemm_context);

    return true;
}

}  // namespace nn
}  // namespace android