/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <limits>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace squared_difference {

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

struct OpData {
  bool requires_broadcast;
  ArithmeticParams arithmetic_params;
};

template <typename T>
T SquaredDifference(T input1, T input2) {
  const T difference = input1 - input2;
  return difference * difference;
}
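
// Allocates the per-node OpData. The broadcast flag and the quantization
// parameters are filled in later by Prepare.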
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new OpData;
  data->requires_broadcast = false;
  return data;
}
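
// Releases the OpData allocated in Init.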
void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}
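
// Checks that the node has two inputs and one output of the same type,
// precomputes the fixed-point offsets, multipliers and shifts for int8
// inputs, and resizes the output tensor (broadcasting if the input shapes
// differ).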
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
  output->type = input2->type;

  // Validate the int8 quantization parameters and precompute the fixed-point
  // arithmetic parameters.
  if (input1->type == kTfLiteInt8) {
    const auto& input1_quantization_params = input1->params;
    const auto& input2_quantization_params = input2->params;
    const auto& output_quantization_params = output->params;
    const int32_t integer_type_min = std::numeric_limits<int8_t>::min();
    const int32_t integer_type_max = std::numeric_limits<int8_t>::max();
    TF_LITE_ENSURE(context,
                   input1_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   input1_quantization_params.zero_point <= integer_type_max);
    TF_LITE_ENSURE(context,
                   input2_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   input2_quantization_params.zero_point <= integer_type_max);
    TF_LITE_ENSURE(context,
                   output_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   output_quantization_params.zero_point <= integer_type_max);
    data->arithmetic_params.input1_offset =
        -input1_quantization_params.zero_point;
    data->arithmetic_params.input2_offset =
        -input2_quantization_params.zero_point;
    data->arithmetic_params.output_offset =
        output_quantization_params.zero_point;

    // Shift the inputs left so that the real-valued scales below can be
    // quantized to integer multipliers without losing precision.
    data->arithmetic_params.left_shift = 7;
    const double twice_max_input_scale =
        2 * std::max(input1_quantization_params.scale,
                     input2_quantization_params.scale);
    const double real_input1_multiplier =
        input1_quantization_params.scale / twice_max_input_scale;
    const double real_input2_multiplier =
        input2_quantization_params.scale / twice_max_input_scale;
    const double real_output_multiplier =
        (twice_max_input_scale * twice_max_input_scale) /
        ((1 << (data->arithmetic_params.left_shift * 2)) *
         output_quantization_params.scale);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_input1_multiplier, &data->arithmetic_params.input1_multiplier,
        &data->arithmetic_params.input1_shift);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_input2_multiplier, &data->arithmetic_params.input2_multiplier,
        &data->arithmetic_params.input2_shift);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->arithmetic_params.output_multiplier,
        &data->arithmetic_params.output_shift);
    data->arithmetic_params.quantized_activation_min =
        std::numeric_limits<int8_t>::min();
    data->arithmetic_params.quantized_activation_max =
        std::numeric_limits<int8_t>::max();
  }

  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  return context->ResizeTensor(context, output, output_size);
}
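
// Quantized (int8) squared difference for a single element pair: both inputs
// are offset to real zero, left-shifted for precision, rescaled to a common
// scale, and the squared difference is rescaled to the output quantization
// and clamped to the int8 range.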
inline int8_t SquaredDifference(int8_t x, int8_t y,
                                const ArithmeticParams& params) {
  const int32_t input1_val = params.input1_offset + x;
  const int32_t input2_val = params.input2_offset + y;
  const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
  const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
  const int32_t scaled_input1_val =
      MultiplyByQuantizedMultiplierSmallerThanOneExp(
          shifted_input1_val, params.input1_multiplier, params.input1_shift);
  const int32_t scaled_input2_val =
      MultiplyByQuantizedMultiplierSmallerThanOneExp(
          shifted_input2_val, params.input2_multiplier, params.input2_shift);
  const int32_t raw_diff = scaled_input1_val - scaled_input2_val;

  // Max of this is 255^2 * (1 << 14), so it won't overflow 32 bits.
  const int32_t squared_raw_diff = raw_diff * raw_diff;
  const int32_t raw_output =
      MultiplyByQuantizedMultiplierSmallerThanOneExp(
          squared_raw_diff, params.output_multiplier, params.output_shift) +
      params.output_offset;
  const int32_t clamped_output =
      std::min(params.quantized_activation_max,
               std::max(params.quantized_activation_min, raw_output));
  return static_cast<int8_t>(clamped_output);
}
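
// Int8 evaluation: dispatches to the broadcasting reference kernel when the
// input shapes differ and to the element-wise kernel otherwise.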
template <typename T>
void EvalQuantizedSquaredDifference(TfLiteContext* context, TfLiteNode* node,
                                    const OpData* data,
                                    const TfLiteTensor* input1,
                                    const TfLiteTensor* input2,
                                    TfLiteTensor* output) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  if (data->requires_broadcast) {
    reference_integer_ops::BroadcastBinaryFunction4DSlow(
        op_data->arithmetic_params, GetTensorShape(input1),
        GetTensorData<T>(input1), GetTensorShape(input2),
        GetTensorData<T>(input2), GetTensorShape(output),
        GetTensorData<T>(output), reference_integer_ops::CheckArithmeticParams,
        SquaredDifference);
  } else {
    const int flat_size = GetTensorShape(input1).FlatSize();
    reference_integer_ops::ElementWise(
        flat_size, op_data->arithmetic_params, GetTensorData<int8_t>(input1),
        GetTensorData<int8_t>(input2), GetTensorData<int8_t>(output),
        reference_integer_ops::CheckArithmeticParams, SquaredDifference);
  }
}
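
// Float/int32 evaluation: applies SquaredDifference element-wise, with or
// without broadcasting.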
template <typename T>
void EvalSquaredDifference(TfLiteContext* context, TfLiteNode* node,
                           const OpData* data, const TfLiteTensor* input1,
                           const TfLiteTensor* input2, TfLiteTensor* output) {
  if (data->requires_broadcast) {
    reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
        GetTensorShape(input1), GetTensorData<T>(input1),
        GetTensorShape(input2), GetTensorData<T>(input2),
        GetTensorShape(output), GetTensorData<T>(output),
        SquaredDifference<T>);
  } else {
    reference_ops::BinaryFunction<T, T, T>(
        GetTensorShape(input1), GetTensorData<T>(input1),
        GetTensorShape(input2), GetTensorData<T>(input2),
        GetTensorShape(output), GetTensorData<T>(output),
        SquaredDifference<T>);
  }
}
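
// Dispatches on the output type: float32 and int32 use the plain kernels,
// int8 uses the quantized kernel; any other type is rejected.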
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  ruy::profiler::ScopeLabel label("SquaredDifference");

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  if (output->type == kTfLiteFloat32) {
    EvalSquaredDifference<float>(context, node, data, input1, input2, output);
  } else if (output->type == kTfLiteInt32) {
    EvalSquaredDifference<int32_t>(context, node, data, input1, input2, output);
  } else if (output->type == kTfLiteInt8) {
    EvalQuantizedSquaredDifference<int8_t>(context, node, data, input1, input2,
                                           output);
  } else {
    context->ReportError(
        context,
        "SquaredDifference only supports FLOAT32, INT32 and INT8 now, got %d.",
        output->type);
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace squared_difference
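
// Returns the registration that wires the lifecycle callbacks above into the
// builtin op resolver.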
TfLiteRegistration* Register_SQUARED_DIFFERENCE() {
  static TfLiteRegistration r = {
      squared_difference::Init, squared_difference::Free,
      squared_difference::Prepare, squared_difference::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite