1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/kernel_util.h"
16 
17 #include <stdint.h>
18 #include <stdlib.h>
19 
20 #include <algorithm>
21 #include <complex>
22 #include <limits>
23 #include <memory>
24 #include <string>
25 
26 #include "tensorflow/lite/c/builtin_op_data.h"
27 #include "tensorflow/lite/c/common.h"
28 #include "tensorflow/lite/kernels/internal/cppmath.h"
29 #include "tensorflow/lite/kernels/internal/quantization_util.h"
30 
31 namespace tflite {
32 
33 namespace {
34 
35 // Assumes tensor_index is a valid index (in bounds)
GetTensorAtIndex(const TfLiteContext * context,int tensor_index)36 inline TfLiteTensor* GetTensorAtIndex(const TfLiteContext* context,
37                                       int tensor_index) {
38   if (context->tensors != nullptr) {
39     return &context->tensors[tensor_index];
40   } else {
41     return context->GetTensor(context, tensor_index);
42   }
43 }
44 
45 // Validate in a single place to reduce binary size
ValidateTensorIndexingSafe(const TfLiteContext * context,int index,int max_size,const int * tensor_indices,int * tensor_index)46 inline TfLiteStatus ValidateTensorIndexingSafe(const TfLiteContext* context,
47                                                int index, int max_size,
48                                                const int* tensor_indices,
49                                                int* tensor_index) {
50   if (index < 0 || index >= max_size) {
51     TF_LITE_KERNEL_LOG(const_cast<TfLiteContext*>(context),
52                        "Invalid tensor index %d (not in [0, %d))\n", index,
53                        max_size);
54     return kTfLiteError;
55   }
56   if (tensor_indices[index] == kTfLiteOptionalTensor) {
57     TF_LITE_KERNEL_LOG(const_cast<TfLiteContext*>(context),
58                        "Tensor at index %d was optional but was expected\n",
59                        index);
60     return kTfLiteError;
61   }
62 
63   *tensor_index = tensor_indices[index];
64   return kTfLiteOk;
65 }
66 
67 // Same as above but returns -1 for invalid inputs instead of status + logging
68 // error.
ValidateTensorIndexing(const TfLiteContext * context,int index,int max_size,const int * tensor_indices)69 inline int ValidateTensorIndexing(const TfLiteContext* context, int index,
70                                   int max_size, const int* tensor_indices) {
71   if (index >= 0 && index < max_size) {
72     const int tensor_index = tensor_indices[index];
73     if (tensor_index != kTfLiteOptionalTensor) {
74       return tensor_index;
75     }
76   }
77   return -1;
78 }
79 
GetMutableInput(const TfLiteContext * context,const TfLiteNode * node,int index)80 inline TfLiteTensor* GetMutableInput(const TfLiteContext* context,
81                                      const TfLiteNode* node, int index) {
82   const int tensor_index = ValidateTensorIndexing(
83       context, index, node->inputs->size, node->inputs->data);
84   if (tensor_index < 0) {
85     return nullptr;
86   }
87   return GetTensorAtIndex(context, tensor_index);
88 }
89 
GetMutableInputSafe(const TfLiteContext * context,const TfLiteNode * node,int index,const TfLiteTensor ** tensor)90 inline TfLiteStatus GetMutableInputSafe(const TfLiteContext* context,
91                                         const TfLiteNode* node, int index,
92                                         const TfLiteTensor** tensor) {
93   int tensor_index;
94   TF_LITE_ENSURE_OK(
95       context, ValidateTensorIndexingSafe(context, index, node->inputs->size,
96                                           node->inputs->data, &tensor_index));
97   *tensor = GetTensorAtIndex(context, tensor_index);
98   return kTfLiteOk;
99 }
100 
101 }  // anonymous namespace.
102 
GetInput(const TfLiteContext * context,const TfLiteNode * node,int index)103 const TfLiteTensor* GetInput(const TfLiteContext* context,
104                              const TfLiteNode* node, int index) {
105   return GetMutableInput(context, node, index);
106 }
107 
GetInputSafe(const TfLiteContext * context,const TfLiteNode * node,int index,const TfLiteTensor ** tensor)108 TfLiteStatus GetInputSafe(const TfLiteContext* context, const TfLiteNode* node,
109                           int index, const TfLiteTensor** tensor) {
110   return GetMutableInputSafe(context, node, index, tensor);
111 }
112 
GetVariableInput(TfLiteContext * context,const TfLiteNode * node,int index)113 TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
114                                int index) {
115   TfLiteTensor* tensor = GetMutableInput(context, node, index);
116   return tensor->is_variable ? tensor : nullptr;
117 }
118 
GetOutput(TfLiteContext * context,const TfLiteNode * node,int index)119 TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
120                         int index) {
121   const int tensor_index = ValidateTensorIndexing(
122       context, index, node->outputs->size, node->outputs->data);
123   if (tensor_index < 0) {
124     return nullptr;
125   }
126   return GetTensorAtIndex(context, tensor_index);
127 }
128 
GetOutputSafe(const TfLiteContext * context,const TfLiteNode * node,int index,TfLiteTensor ** tensor)129 TfLiteStatus GetOutputSafe(const TfLiteContext* context, const TfLiteNode* node,
130                            int index, TfLiteTensor** tensor) {
131   int tensor_index;
132   TF_LITE_ENSURE_OK(
133       context, ValidateTensorIndexingSafe(context, index, node->outputs->size,
134                                           node->outputs->data, &tensor_index));
135   *tensor = GetTensorAtIndex(context, tensor_index);
136   return kTfLiteOk;
137 }
138 
GetOptionalInputTensor(const TfLiteContext * context,const TfLiteNode * node,int index)139 const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
140                                            const TfLiteNode* node, int index) {
141   return GetInput(context, node, index);
142 }
143 
144 #ifndef TF_LITE_STATIC_MEMORY
GetTemporary(TfLiteContext * context,const TfLiteNode * node,int index)145 TfLiteTensor* GetTemporary(TfLiteContext* context, const TfLiteNode* node,
146                            int index) {
147   const int tensor_index = ValidateTensorIndexing(
148       context, index, node->temporaries->size, node->temporaries->data);
149   if (tensor_index < 0) {
150     return nullptr;
151   }
152   return GetTensorAtIndex(context, tensor_index);
153 }
154 
GetTemporarySafe(const TfLiteContext * context,const TfLiteNode * node,int index,TfLiteTensor ** tensor)155 TfLiteStatus GetTemporarySafe(const TfLiteContext* context,
156                               const TfLiteNode* node, int index,
157                               TfLiteTensor** tensor) {
158   int tensor_index;
159   TF_LITE_ENSURE_OK(context, ValidateTensorIndexingSafe(
160                                  context, index, node->temporaries->size,
161                                  node->temporaries->data, &tensor_index));
162   *tensor = GetTensorAtIndex(context, tensor_index);
163   return kTfLiteOk;
164 }
165 
GetIntermediates(TfLiteContext * context,const TfLiteNode * node,int index)166 const TfLiteTensor* GetIntermediates(TfLiteContext* context,
167                                      const TfLiteNode* node, int index) {
168   const int tensor_index = ValidateTensorIndexing(
169       context, index, node->intermediates->size, node->intermediates->data);
170   if (tensor_index < 0) {
171     return nullptr;
172   }
173   return GetTensorAtIndex(context, tensor_index);
174 }
175 
GetIntermediatesSafe(const TfLiteContext * context,const TfLiteNode * node,int index,TfLiteTensor ** tensor)176 TfLiteStatus GetIntermediatesSafe(const TfLiteContext* context,
177                                   const TfLiteNode* node, int index,
178                                   TfLiteTensor** tensor) {
179   int tensor_index;
180   TF_LITE_ENSURE_OK(context, ValidateTensorIndexingSafe(
181                                  context, index, node->intermediates->size,
182                                  node->intermediates->data, &tensor_index));
183   *tensor = GetTensorAtIndex(context, tensor_index);
184   return kTfLiteOk;
185 }
186 #endif  // TF_LITE_STATIC_MEMORY
187 
188 // Per-axis
PopulateConvolutionQuantizationParams(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * filter,const TfLiteTensor * bias,TfLiteTensor * output,const TfLiteFusedActivation & activation,int32_t * multiplier,int * shift,int32_t * output_activation_min,int32_t * output_activation_max,int32_t * per_channel_multiplier,int * per_channel_shift)189 TfLiteStatus PopulateConvolutionQuantizationParams(
190     TfLiteContext* context, const TfLiteTensor* input,
191     const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output,
192     const TfLiteFusedActivation& activation, int32_t* multiplier, int* shift,
193     int32_t* output_activation_min, int32_t* output_activation_max,
194     int32_t* per_channel_multiplier, int* per_channel_shift) {
195   const auto* affine_quantization =
196       reinterpret_cast<TfLiteAffineQuantization*>(filter->quantization.params);
197   return PopulateConvolutionQuantizationParams(
198       context, input, filter, bias, output, activation, multiplier, shift,
199       output_activation_min, output_activation_max, per_channel_multiplier,
200       per_channel_shift, affine_quantization->scale->size);
201 }
202 
203 // Per-axis & per-tensor
PopulateConvolutionQuantizationParams(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * filter,const TfLiteTensor * bias,TfLiteTensor * output,const TfLiteFusedActivation & activation,int32_t * multiplier,int * shift,int32_t * output_activation_min,int32_t * output_activation_max,int32_t * per_channel_multiplier,int * per_channel_shift,int num_channels)204 TfLiteStatus PopulateConvolutionQuantizationParams(
205     TfLiteContext* context, const TfLiteTensor* input,
206     const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output,
207     const TfLiteFusedActivation& activation, int32_t* multiplier, int* shift,
208     int32_t* output_activation_min, int32_t* output_activation_max,
209     int32_t* per_channel_multiplier, int* per_channel_shift, int num_channels) {
210   TF_LITE_ENSURE_EQ(context, input->quantization.type,
211                     kTfLiteAffineQuantization);
212   TF_LITE_ENSURE_EQ(context, filter->quantization.type,
213                     kTfLiteAffineQuantization);
214   // TODO(jianlijianli): Enable bias type check and bias scale == input scale
215   // * filter scale for each channel in affine quantization once bias
216   // quantization is properly populated.
217   // TF_LITE_ENSURE_EQ(context, bias->quantization.type,
218   // kTfLiteAffineQuantization);
219 
220   // Check data type.
221   const auto* affine_quantization =
222       reinterpret_cast<TfLiteAffineQuantization*>(filter->quantization.params);
223   TF_LITE_ENSURE(context, affine_quantization);
224   TF_LITE_ENSURE(context, affine_quantization->scale);
225   const bool is_per_channel = affine_quantization->scale->size > 1;
226   if (is_per_channel) {
227     //  Currently only Int8/Int16 is supported for per channel quantization.
228     TF_LITE_ENSURE(context,
229                    input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
230     TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteInt8);
231     TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, num_channels);
232     TF_LITE_ENSURE_EQ(
233         context, num_channels,
234         filter->dims->data[affine_quantization->quantized_dimension]);
235   }
236 
237   // Populate multiplier and shift using affine quantization.
238   const float input_scale = input->params.scale;
239   const float output_scale = output->params.scale;
240   const float* filter_scales = affine_quantization->scale->data;
241   for (int i = 0; i < num_channels; ++i) {
242     // If per-tensor quantization parameter is specified, broadcast it along the
243     // quantization dimension (channels_out).
244     const float scale = is_per_channel ? filter_scales[i] : filter_scales[0];
245     const double filter_scale = static_cast<double>(scale);
246     const double effective_output_scale = static_cast<double>(input_scale) *
247                                           filter_scale /
248                                           static_cast<double>(output_scale);
249     int32_t significand;
250     int channel_shift;
251     QuantizeMultiplier(effective_output_scale, &significand, &channel_shift);
252     per_channel_multiplier[i] = significand;
253     per_channel_shift[i] = channel_shift;
254   }
255 
256   // Populate scalar quantization parameters.
257   // This check on legacy quantization parameters is kept only for backward
258   // compatibility.
259   if (input->type == kTfLiteUInt8) {
260     // Check bias scale == input scale * filter scale.
261     double real_multiplier = 0.0;
262     TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
263         context, input, filter, bias, output, &real_multiplier));
264     int exponent;
265 
266     // Populate quantization parameters with multiplier and shift.
267     QuantizeMultiplier(real_multiplier, multiplier, &exponent);
268     *shift = -exponent;
269   }
270   if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8 ||
271       input->type == kTfLiteInt16) {
272     TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
273         context, activation, output, output_activation_min,
274         output_activation_max));
275   }
276   return kTfLiteOk;
277 }
278 
GetQuantizedConvolutionMultipler(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * filter,const TfLiteTensor * bias,TfLiteTensor * output,double * multiplier)279 TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
280                                               const TfLiteTensor* input,
281                                               const TfLiteTensor* filter,
282                                               const TfLiteTensor* bias,
283                                               TfLiteTensor* output,
284                                               double* multiplier) {
285   const double input_product_scale = static_cast<double>(input->params.scale) *
286                                      static_cast<double>(filter->params.scale);
287   // The following conditions must be guaranteed by the training pipeline.
288   if (bias) {
289     const double bias_scale = static_cast<double>(bias->params.scale);
290     // Here we're making sure the input_product_scale & bias_scale are about the
291     // same. Since we have:
292     // (output - output_zp) * output_scale =
293     // input_product_scale * input_product + bias * bias_scale ---- (0)
294     //
295     // (0) equals:
296     // (input_product + bias) * input_product_scale ----- (1)
297     //           +
298     // bias * (bias_scale - input_product_scale)   ------ (2)
299     //
300     // For the real kernel computation, we're doing (1), so we really need to
301     // make sure (2) has minimum impact on the output, so:
302     // bias * (bias_scale - input_product_scale) / output_scale should be
303     // a small number for an integer.
304     // Since normally bias should be within a small range.
305     // We should expect (bias_scale - input_product_scale) / output_scale to
306     // be a small number like 0.02.
307     const double scale_diff = std::abs(input_product_scale - bias_scale);
308     const double output_scale = static_cast<double>(output->params.scale);
309 
310     TF_LITE_ENSURE(context, scale_diff / output_scale <= 0.02);
311   }
312   return GetQuantizedConvolutionMultipler(context, input, filter, output,
313                                           multiplier);
314 }
315 
GetQuantizedConvolutionMultipler(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * filter,TfLiteTensor * output,double * multiplier)316 TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
317                                               const TfLiteTensor* input,
318                                               const TfLiteTensor* filter,
319                                               TfLiteTensor* output,
320                                               double* multiplier) {
321   const double input_product_scale =
322       static_cast<double>(input->params.scale * filter->params.scale);
323   TF_LITE_ENSURE(context, input_product_scale >= 0);
324   *multiplier = input_product_scale / static_cast<double>(output->params.scale);
325 
326   return kTfLiteOk;
327 }
328 
329 namespace {
CalculateActivationRangeQuantizedImpl(TfLiteFusedActivation activation,int32_t qmin,int32_t qmax,TfLiteTensor * output,int32_t * act_min,int32_t * act_max)330 void CalculateActivationRangeQuantizedImpl(TfLiteFusedActivation activation,
331                                            int32_t qmin, int32_t qmax,
332                                            TfLiteTensor* output,
333                                            int32_t* act_min, int32_t* act_max) {
334   const auto scale = output->params.scale;
335   const auto zero_point = output->params.zero_point;
336 
337   auto quantize = [scale, zero_point](float f) {
338     return zero_point + static_cast<int32_t>(TfLiteRound(f / scale));
339   };
340 
341   if (activation == kTfLiteActRelu) {
342     *act_min = std::max(qmin, quantize(0.0));
343     *act_max = qmax;
344   } else if (activation == kTfLiteActRelu6) {
345     *act_min = std::max(qmin, quantize(0.0));
346     *act_max = std::min(qmax, quantize(6.0));
347   } else if (activation == kTfLiteActReluN1To1) {
348     *act_min = std::max(qmin, quantize(-1.0));
349     *act_max = std::min(qmax, quantize(1.0));
350   } else {
351     *act_min = qmin;
352     *act_max = qmax;
353   }
354 }
355 }  // namespace
356 
CalculateActivationRangeQuantized(TfLiteContext * context,TfLiteFusedActivation activation,TfLiteTensor * output,int32_t * act_min,int32_t * act_max)357 TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context,
358                                                TfLiteFusedActivation activation,
359                                                TfLiteTensor* output,
360                                                int32_t* act_min,
361                                                int32_t* act_max) {
362   int32_t qmin = 0;
363   int32_t qmax = 0;
364   if (output->type == kTfLiteUInt8) {
365     qmin = std::numeric_limits<uint8_t>::min();
366     qmax = std::numeric_limits<uint8_t>::max();
367   } else if (output->type == kTfLiteInt8) {
368     qmin = std::numeric_limits<int8_t>::min();
369     qmax = std::numeric_limits<int8_t>::max();
370   } else if (output->type == kTfLiteInt16) {
371     qmin = std::numeric_limits<int16_t>::min();
372     qmax = std::numeric_limits<int16_t>::max();
373   } else {
374     TF_LITE_ENSURE(context, false);
375   }
376 
377   CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min,
378                                         act_max);
379   return kTfLiteOk;
380 }
381 
HaveSameShapes(const TfLiteTensor * input1,const TfLiteTensor * input2)382 bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
383   return TfLiteIntArrayEqual(input1->dims, input2->dims);
384 }
385 
386 #ifndef TF_LITE_STATIC_MEMORY
387 
// TODO(b/172067338): Making this function part of the TF_LITE_STATIC_MEMORY
// build results in a 6KB size increase, even though the function is unused in
// that build. What appears to be happening is that while the linker drops the
// unused function, the string library that gets pulled in is not dropped,
// resulting in the increased binary size.
GetShapeDebugString(const TfLiteIntArray * shape)393 std::string GetShapeDebugString(const TfLiteIntArray* shape) {
394   std::string str;
395   for (int d = 0; d < shape->size; ++d) {
396     if (str.empty())
397       str = "[" + std::to_string(shape->data[d]);
398     else
399       str += ", " + std::to_string(shape->data[d]);
400   }
401   str += "]";
402   return str;
403 }
404 
CalculateShapeForBroadcast(TfLiteContext * context,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteIntArray ** output_shape)405 TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
406                                         const TfLiteTensor* input1,
407                                         const TfLiteTensor* input2,
408                                         TfLiteIntArray** output_shape) {
409   int dims1 = NumDimensions(input1);
410   int dims2 = NumDimensions(input2);
411   int out_dims = std::max(dims1, dims2);
412   if (NumElements(input1) == 0) {
413     *output_shape = TfLiteIntArrayCopy(input1->dims);
414     return kTfLiteOk;
415   }
416   std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
417       TfLiteIntArrayCreate(out_dims), TfLiteIntArrayFree);
418   for (int i = 0; i < out_dims; ++i) {
419     int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
420     int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
421     if (!(d1 == d2 || d1 == 1 || d2 == 1)) {
422       context->ReportError(context,
423                            "Given shapes, %s and %s, are not broadcastable.",
424                            GetShapeDebugString(input1->dims).c_str(),
425                            GetShapeDebugString(input2->dims).c_str());
426       return kTfLiteError;
427     }
428     shape->data[out_dims - i - 1] = std::max(d1, d2);
429   }
430   *output_shape = shape.release();
431   return kTfLiteOk;
432 }
433 
CalculateShapeForBroadcast(TfLiteContext * context,const TfLiteTensor * input1,const TfLiteTensor * input2,const TfLiteTensor * input3,TfLiteIntArray ** output_shape)434 TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
435                                         const TfLiteTensor* input1,
436                                         const TfLiteTensor* input2,
437                                         const TfLiteTensor* input3,
438                                         TfLiteIntArray** output_shape) {
439   int dims1 = NumDimensions(input1);
440   int dims2 = NumDimensions(input2);
441   int dims3 = NumDimensions(input3);
442   int out_dims = std::max(std::max(dims1, dims2), dims3);
443   std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
444       TfLiteIntArrayCreate(out_dims), TfLiteIntArrayFree);
445   for (int i = 0; i < out_dims; ++i) {
446     int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
447     int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
448     int d3 = i >= dims3 ? 1 : SizeOfDimension(input3, dims3 - i - 1);
449     int max_value = std::max(std::max(d1, d2), d3);
450     if (!(d1 == 1 || d1 == max_value) || !(d2 == 1 || d2 == max_value) ||
451         !(d3 == 1 || d3 == max_value)) {
452       context->ReportError(
453           context, "Given shapes, %s, %s and %s, are not broadcastable.",
454           GetShapeDebugString(input1->dims).c_str(),
455           GetShapeDebugString(input2->dims).c_str(),
456           GetShapeDebugString(input3->dims).c_str());
457       return kTfLiteError;
458     }
459     shape->data[out_dims - i - 1] = max_value;
460   }
461   *output_shape = shape.release();
462   return kTfLiteOk;
463 }
464 #endif  // TF_LITE_STATIC_MEMORY
465 
466 // Size of string is not constant, return 0 in such case.
TfLiteTypeGetSize(TfLiteType type)467 int TfLiteTypeGetSize(TfLiteType type) {
468   switch (type) {
469     case kTfLiteUInt8:
470       TF_LITE_ASSERT_EQ(sizeof(uint8_t), 1);
471       return 1;
472     case kTfLiteInt8:
473       TF_LITE_ASSERT_EQ(sizeof(int8_t), 1);
474       return 1;
475     case kTfLiteBool:
476       return sizeof(bool);
477     case kTfLiteInt16:
478       TF_LITE_ASSERT_EQ(sizeof(int16_t), 2);
479       return 2;
480     case kTfLiteFloat16:
481       TF_LITE_ASSERT_EQ(sizeof(int16_t), 2);
482       return 2;
483     case kTfLiteFloat32:
484       TF_LITE_ASSERT_EQ(sizeof(float), 4);
485       return 4;
486     case kTfLiteInt32:
487       TF_LITE_ASSERT_EQ(sizeof(int32_t), 4);
488       return 4;
489     case kTfLiteUInt32:
490       TF_LITE_ASSERT_EQ(sizeof(uint32_t), 4);
491       return 4;
492     case kTfLiteInt64:
493       TF_LITE_ASSERT_EQ(sizeof(int64_t), 8);
494       return 8;
495     case kTfLiteUInt64:
496       TF_LITE_ASSERT_EQ(sizeof(uint64_t), 8);
497       return 8;
498     case kTfLiteFloat64:
499       TF_LITE_ASSERT_EQ(sizeof(double), 8);
500       return 8;
501     case kTfLiteComplex64:
502       TF_LITE_ASSERT_EQ(sizeof(std::complex<float>), 8);
503       return 8;
504     case kTfLiteComplex128:
505       TF_LITE_ASSERT_EQ(sizeof(std::complex<double>), 16);
506       return 16;
507     default:
508       return 0;
509   }
510 }
511 
512 }  // namespace tflite
513