1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stddef.h>
16 #include <stdint.h>
17 
18 #include "ruy/profiler/instrumentation.h"  // from @ruy
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
21 #include "tensorflow/lite/kernels/internal/quantization_util.h"
22 #include "tensorflow/lite/kernels/internal/reference/binary_function.h"
23 #include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
24 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
25 #include "tensorflow/lite/kernels/internal/tensor.h"
26 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
27 #include "tensorflow/lite/kernels/kernel_util.h"
28 
29 namespace tflite {
30 namespace ops {
31 namespace builtin {
32 namespace squared_difference {
33 
// Tensor indices within the node's input/output arrays.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

// Per-node state, allocated in Init and released in Free.
struct OpData {
  // True when the two inputs have different shapes and the slow broadcasting
  // kernels must be used; computed in Prepare.
  bool requires_broadcast;
  // Fixed-point offsets/multipliers/shifts for the int8 path; populated in
  // Prepare and left uninitialized for float/int32 inputs (unused there).
  ArithmeticParams arithmetic_params;
};
42 
// Returns the squared difference (input1 - input2)^2 of two values.
template <typename T>
T SquaredDifference(T input1, T input2) {
  return (input1 - input2) * (input1 - input2);
}
48 
Init(TfLiteContext * context,const char * buffer,size_t length)49 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
50   auto* data = new OpData;
51   data->requires_broadcast = false;
52   return data;
53 }
54 
Free(TfLiteContext * context,void * buffer)55 void Free(TfLiteContext* context, void* buffer) {
56   delete reinterpret_cast<OpData*>(buffer);
57 }
58 
// Validates the node's inputs, pre-computes the fixed-point parameters for
// the int8 path, and resizes the output (broadcasting shapes when needed).
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Both inputs must have the same element type; the output adopts it.
  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
  output->type = input2->type;

  // For int8, validate the zero points and derive the fixed-point
  // multipliers/shifts consumed by the quantized kernel at Eval time.
  if (input1->type == kTfLiteInt8) {
    const auto& input1_quantization_params = input1->params;
    const auto& input2_quantization_params = input2->params;
    const auto& output_quantization_params = output->params;
    // All zero points must fit in the int8 value range.
    const int32_t integer_type_min = std::numeric_limits<int8_t>::min();
    const int32_t integer_type_max = std::numeric_limits<int8_t>::max();
    TF_LITE_ENSURE(context,
                   input1_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   input1_quantization_params.zero_point <= integer_type_max);
    TF_LITE_ENSURE(context,
                   input2_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   input2_quantization_params.zero_point <= integer_type_max);
    TF_LITE_ENSURE(context,
                   output_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   output_quantization_params.zero_point <= integer_type_max);
    // Input offsets are the negated zero points so the kernel can re-center
    // values with a plain add; the output offset is applied un-negated.
    data->arithmetic_params.input1_offset =
        -input1_quantization_params.zero_point;
    data->arithmetic_params.input2_offset =
        -input2_quantization_params.zero_point;
    data->arithmetic_params.output_offset =
        output_quantization_params.zero_point;

    // shift to make integer for scales.
    data->arithmetic_params.left_shift = 7;
    // Both inputs are rescaled onto a common scale of twice the larger input
    // scale, so each real multiplier below is <= 0.5.
    const double twice_max_input_scale =
        2 * std::max(input1_quantization_params.scale,
                     input2_quantization_params.scale);
    const double real_input1_multiplier =
        input1_quantization_params.scale / twice_max_input_scale;
    double real_input2_multiplier =
        input2_quantization_params.scale / twice_max_input_scale;
    // Squaring squares the common scale; the shift appears twice as well.
    // Note: `1 << left_shift * 2` parses as `1 << (left_shift * 2)`.
    const double real_output_multiplier =
        (twice_max_input_scale * twice_max_input_scale) /
        ((1 << data->arithmetic_params.left_shift * 2) *
         output_quantization_params.scale);
    // Convert each real multiplier into a quantized multiplier + shift pair.
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_input1_multiplier, &data->arithmetic_params.input1_multiplier,
        &data->arithmetic_params.input1_shift);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_input2_multiplier, &data->arithmetic_params.input2_multiplier,
        &data->arithmetic_params.input2_shift);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->arithmetic_params.output_multiplier,
        &data->arithmetic_params.output_shift);
    // No fused activation: clamp only to the full int8 range.
    data->arithmetic_params.quantized_activation_min =
        std::numeric_limits<int8_t>::min();
    data->arithmetic_params.quantized_activation_max =
        std::numeric_limits<int8_t>::max();
  }

  data->requires_broadcast = !HaveSameShapes(input1, input2);

  // Output shape: broadcast of the two input shapes, or a copy of input1's
  // shape when they already match. ResizeTensor takes ownership of
  // output_size.
  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  return context->ResizeTensor(context, output, output_size);
}
144 
// Quantized (int8) squared difference of a single element pair, using the
// offsets/multipliers/shifts pre-computed in Prepare: re-center each input by
// its zero point, upshift for precision, rescale both onto a common scale,
// square the difference, then rescale to the output scale and clamp.
inline int8_t SquaredDifference(int8_t x, int8_t y,
                                const ArithmeticParams& params) {
  // Re-center by the (negated) zero points stored as offsets.
  const int32_t input1_val = params.input1_offset + x;
  const int32_t input2_val = params.input2_offset + y;
  // Upshift so the fixed-point rescaling below keeps fractional precision.
  const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
  const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
  // Bring both operands onto the common scale chosen in Prepare.
  const int32_t scaled_input1_val =
      MultiplyByQuantizedMultiplierSmallerThanOneExp(
          shifted_input1_val, params.input1_multiplier, params.input1_shift);
  const int32_t scaled_input2_val =
      MultiplyByQuantizedMultiplierSmallerThanOneExp(
          shifted_input2_val, params.input2_multiplier, params.input2_shift);
  const int32_t raw_diff = scaled_input1_val - scaled_input2_val;

  // Max of this is 255^2 * (1 << 14), so won't overflow 32 bits.
  const int32_t squared_raw_diff = raw_diff * raw_diff;
  // Rescale to the output scale and apply the output zero point.
  const int32_t raw_output =
      MultiplyByQuantizedMultiplierSmallerThanOneExp(
          squared_raw_diff, params.output_multiplier, params.output_shift) +
      params.output_offset;
  // Clamp to the activation range before narrowing back to int8.
  const int32_t clamped_output =
      std::min(params.quantized_activation_max,
               std::max(params.quantized_activation_min, raw_output));
  return static_cast<int8_t>(clamped_output);
}
170 
171 template <typename T>
EvalQuantizedSquaredDifference(TfLiteContext * context,TfLiteNode * node,const OpData * data,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)172 void EvalQuantizedSquaredDifference(TfLiteContext* context, TfLiteNode* node,
173                                     const OpData* data,
174                                     const TfLiteTensor* input1,
175                                     const TfLiteTensor* input2,
176                                     TfLiteTensor* output) {
177   const auto* op_data = static_cast<const OpData*>(node->user_data);
178   if (data->requires_broadcast) {
179     reference_integer_ops::BroadcastBinaryFunction4DSlow(
180         op_data->arithmetic_params, GetTensorShape(input1),
181         GetTensorData<T>(input1), GetTensorShape(input2),
182         GetTensorData<T>(input2), GetTensorShape(output),
183         GetTensorData<T>(output), reference_integer_ops::CheckArithmeticParams,
184         SquaredDifference);
185   } else {
186     const int flat_size = GetTensorShape(input1).FlatSize();
187     reference_integer_ops::ElementWise(
188         flat_size, op_data->arithmetic_params, GetTensorData<int8_t>(input1),
189         GetTensorData<int8_t>(input2), GetTensorData<int8_t>(output),
190         reference_integer_ops::CheckArithmeticParams, SquaredDifference);
191   }
192 }
193 
194 template <typename T>
EvalSquaredDifference(TfLiteContext * context,TfLiteNode * node,const OpData * data,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)195 void EvalSquaredDifference(TfLiteContext* context, TfLiteNode* node,
196                            const OpData* data, const TfLiteTensor* input1,
197                            const TfLiteTensor* input2, TfLiteTensor* output) {
198   if (data->requires_broadcast) {
199     reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
200         GetTensorShape(input1), GetTensorData<T>(input1),
201         GetTensorShape(input2), GetTensorData<T>(input2),
202         GetTensorShape(output), GetTensorData<T>(output), SquaredDifference<T>);
203   } else {
204     reference_ops::BinaryFunction<T, T, T>(
205         GetTensorShape(input1), GetTensorData<T>(input1),
206         GetTensorShape(input2), GetTensorData<T>(input2),
207         GetTensorShape(output), GetTensorData<T>(output), SquaredDifference<T>);
208   }
209 }
210 
Eval(TfLiteContext * context,TfLiteNode * node)211 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
212   OpData* data = reinterpret_cast<OpData*>(node->user_data);
213   ruy::profiler::ScopeLabel label("SquaredDifference");
214 
215   const TfLiteTensor* input1;
216   TF_LITE_ENSURE_OK(context,
217                     GetInputSafe(context, node, kInputTensor1, &input1));
218   const TfLiteTensor* input2;
219   TF_LITE_ENSURE_OK(context,
220                     GetInputSafe(context, node, kInputTensor2, &input2));
221   TfLiteTensor* output;
222   TF_LITE_ENSURE_OK(context,
223                     GetOutputSafe(context, node, kOutputTensor, &output));
224 
225   if (output->type == kTfLiteFloat32) {
226     EvalSquaredDifference<float>(context, node, data, input1, input2, output);
227   } else if (output->type == kTfLiteInt32) {
228     EvalSquaredDifference<int32_t>(context, node, data, input1, input2, output);
229   } else if (output->type == kTfLiteInt8) {
230     EvalQuantizedSquaredDifference<int8_t>(context, node, data, input1, input2,
231                                            output);
232   } else {
233     context->ReportError(
234         context,
235         "SquaredDifference only supports FLOAT32 and INT32 now, got %d.",
236         output->type);
237     return kTfLiteError;
238   }
239 
240   return kTfLiteOk;
241 }
242 
243 }  // namespace squared_difference
244 
Register_SQUARED_DIFFERENCE()245 TfLiteRegistration* Register_SQUARED_DIFFERENCE() {
246   static TfLiteRegistration r = {
247       squared_difference::Init, squared_difference::Free,
248       squared_difference::Prepare, squared_difference::Eval};
249   return &r;
250 }
251 
252 }  // namespace builtin
253 }  // namespace ops
254 }  // namespace tflite
255