/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <initializer_list>

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8_t {
  kNone,
  kRelu6,
  kRelu1,
  kRelu
};
enum class PaddingType : uint8_t { kNone, kSame, kValid };

struct PaddingValues {
  int16_t width;
  int16_t height;
  // offset is used for calculating "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset except it's over the height dimension.
  int16_t height_offset;
};

struct Padding3DValues {
  int16_t width;
  int16_t height;
  int16_t depth;
  // offset is used for calculating "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset except it's over the height dimension.
  int16_t height_offset;
  // Same as width_offset except it's over the depth dimension.
  int16_t depth_offset;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8_t {
  // Default format (flat 2D layout, the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth)
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), very small batch size (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small, so each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched, alias each other in the L1 cache (which typically
  //     has only 4-way set associativity in ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that
  // we get as fast as possible given only the hard constraint (1).
  // This is achieved by turning the difficulty into a solution: the
  // difficulty, that each value loaded from memory is used in only one
  // kernel iteration, making this operation memory-intensive, hints at
  // the solution of shuffling the weights so that they are stored in
  // exactly the order in which the kernel needs to load them, so that the
  // memory accesses made by the kernel are trivial. This solves (2) because
  // the trivial memory access pattern allows the CPU's automatic prefetching
  // to perform very well (no need even for preload instructions), and this
  // solves (3) because the values being loaded concurrently are now
  // contiguous in the address space, thus don't alias each other in the cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part in the name refers to the fact that this
  // weights format has each weight value encoded as a signed int8_t value,
  // even if the data type of the weights buffer is uint8_t.  This is intended
  // to save runtime kernels the effort of having to XOR the top bit of these
  // bytes before using them in signed arithmetic; see this file for more
  // explanations on the 'signed int8_t trick' in matrix multiplication kernels:
  //
  //   tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};
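
// As an illustration of the 'signed int8_t trick' mentioned above (a worked
// example based on the description, not a separate specification): flipping
// the top bit of a byte maps the uint8_t value w to the int8_t value w - 128.
// For instance, the byte 0x95 is 149 as uint8_t and, after XOR with 0x80,
// becomes 0x15, i.e. 21 = 149 - 128 as int8_t. Since
//   real_value = scale * (w - zero_point)
//              = scale * ((w - 128) - (zero_point - 128)),
// storing weights in the signed encoding and shifting the weights zero_point
// by 128 leaves the represented real values unchanged.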

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real
// values corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32_t zero_point = 0;
  double scale = 0.0;
};
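
// For example (an illustrative instance of the formula above): with
// scale = 0.5 and zero_point = 128, the quantized value 130 represents
// real_value = 0.5 * (130 - 128) = 1.0, and the quantized value 128
// represents exactly 0.0.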

inline bool operator==(const QuantizationParams& qp1,
                       const QuantizationParams& qp2) {
  return qp1.zero_point == qp2.zero_point && qp1.scale == qp2.scale;
}

template <int N>
struct Dims {
  int sizes[N];
  int strides[N];
};

class RuntimeShape {
 public:
  // Shapes with dimensions up to 5 are stored directly in the structure, while
  // larger shapes are separately allocated.
  static constexpr int kMaxSmallSize = 5;

  RuntimeShape& operator=(RuntimeShape const&) = delete;

  RuntimeShape() : size_(0) {}

  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
    if (dimensions_count > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else  // TF_LITE_STATIC_MEMORY
      dims_pointer_ = new int32_t[dimensions_count];
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  RuntimeShape(int shape_size, int32_t value) : size_(0) {
    Resize(shape_size);
    for (int i = 0; i < shape_size; ++i) {
      SetDim(i, value);
    }
  }

  RuntimeShape(int dimensions_count, const int32_t* dims_data) : size_(0) {
    ReplaceWith(dimensions_count, dims_data);
  }

  RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
    BuildFrom(init_list);
  }

  // Avoid using this constructor.  We should be able to delete it when C++17
  // rolls out.
  RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else
      dims_pointer_ = new int32_t[size_];
#endif
    }
    std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * size_);
  }

  bool operator==(const RuntimeShape& comp) const {
    return this->size_ == comp.size_ &&
           std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32_t)) ==
               0;
  }

  ~RuntimeShape() {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else  // TF_LITE_STATIC_MEMORY
      delete[] dims_pointer_;
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  inline int32_t DimensionsCount() const { return size_; }
  inline int32_t Dims(int i) const {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
  }
  inline void SetDim(int i, int32_t val) {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    if (size_ > kMaxSmallSize) {
      dims_pointer_[i] = val;
    } else {
      dims_[i] = val;
    }
  }

  inline int32_t* DimsData() {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  inline const int32_t* DimsData() const {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  // The caller must ensure that the shape is no bigger than 5-D.
  inline const int32_t* DimsDataUpTo5D() const { return dims_; }

  inline void Resize(int dimensions_count) {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else  // TF_LITE_STATIC_MEMORY
      delete[] dims_pointer_;
#endif  // TF_LITE_STATIC_MEMORY
    }
    size_ = dimensions_count;
    if (dimensions_count > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else  // TF_LITE_STATIC_MEMORY
      dims_pointer_ = new int32_t[dimensions_count];
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  inline void ReplaceWith(int dimensions_count, const int32_t* dims_data) {
    Resize(dimensions_count);
    int32_t* dst_dims = DimsData();
    std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
  }

  template <typename T>
  inline void BuildFrom(const T& src_iterable) {
    const int dimensions_count =
        std::distance(src_iterable.begin(), src_iterable.end());
    Resize(dimensions_count);
    int32_t* data = DimsData();
    for (auto it : src_iterable) {
      *data = it;
      ++data;
    }
  }

  // This will probably be factored out. Old code made substantial use of 4-D
  // shapes, and so this function is used to extend smaller shapes. Note that
  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
  // reduced, and (b) some kernels are strictly 4-D, but then the shapes of
  // their inputs should already be 4-D, so this function should not be needed.
  inline static RuntimeShape ExtendedShape(int new_shape_size,
                                           const RuntimeShape& shape) {
    return RuntimeShape(new_shape_size, shape, 1);
  }

  inline void BuildFrom(const std::initializer_list<int> init_list) {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is the size when flattened into a
  // vector.
  inline int FlatSize() const {
    int buffer_size = 1;
    const int* dims_data = reinterpret_cast<const int*>(DimsData());
    for (int i = 0; i < size_; i++) {
      buffer_size *= dims_data[i];
    }
    return buffer_size;
  }

  bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }

 private:
  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17.
  // This creates a shape padded to the desired size with the specified value.
  RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
      : size_(0) {
    // If the following check fails, it is likely because a 4D-only kernel is
    // being used with an array of larger dimension count.
    TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    for (int i = 0; i < size_increase; ++i) {
      SetDim(i, pad_value);
    }
    std::memcpy(DimsData() + size_increase, shape.DimsData(),
                sizeof(int32_t) * shape.DimensionsCount());
  }

  int32_t size_;
  union {
    int32_t dims_[kMaxSmallSize];
    int32_t* dims_pointer_;
  };
};
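
// Illustrative usage of RuntimeShape (a sketch based on the members defined
// above, not additional API):
//
//   RuntimeShape shape({2, 3, 4});   // 3-D shape, stored inline since
//                                    // 3 <= kMaxSmallSize.
//   shape.DimensionsCount();         // == 3
//   shape.Dims(1);                   // == 3
//   shape.FlatSize();                // == 2 * 3 * 4 == 24
//   RuntimeShape ext =
//       RuntimeShape::ExtendedShape(4, shape);  // == {1, 2, 3, 4}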

// Converts inference-style shape to legacy tflite::Dims<4>.
inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
  tflite::Dims<4> result;
  const int dimensions_count = array_shape.DimensionsCount();
  TFLITE_CHECK_LE(dimensions_count, 4);
  int cum_prod = 1;
  for (int i = 0; i < 4; i++) {
    const int new_dim =
        (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
    result.sizes[i] = new_dim;
    result.strides[i] = cum_prod;
    cum_prod *= new_dim;
  }
  return result;
}
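
// For example (worked through from the loop above): a RuntimeShape of
// {2, 3, 4, 5} converts to a Dims<4> with sizes = {5, 4, 3, 2} (innermost
// dimension first) and strides = {1, 5, 20, 60}.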

// TODO(b/80418076): Move to legacy ops file, update invocations.
inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
  return RuntimeShape(
      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}

// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
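
// For example (a worked trace of the carry logic above): with dims = {2, 3},
// the index {0, 2} advances to {1, 0} and NextIndex returns true, while the
// index {1, 2} wraps around to {0, 0} and NextIndex returns false, signalling
// that the iteration is complete.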

// Gets the offset of an index when reducing on the given axes. When reducing,
// the flattened offset does not change if the input index changes only along
// a reduced axis. For example, if you have a 3D tensor and you are reducing
// to 2D by eliminating axis 0, then index (0, 1, 2) and index (1, 1, 2) map
// to the same flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // Check whether we need to skip this axis.
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
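
// For example (worked through from the loop above): with dims = {2, 3, 4},
// index = {1, 1, 2} and a single reduced axis {0}, axis 0 is skipped and the
// returned offset is 1 * 4 + 2 = 6, the same as for index = {0, 1, 2}.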

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
  return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}
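
// For example (a worked instance of the row-major formula above): for a shape
// of {2, 3, 4, 5}, Offset(shape, 1, 2, 3, 4) is
// ((1 * 3 + 2) * 4 + 3) * 5 + 4 = 119, i.e. the last element of the
// 120-element buffer.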

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3,
                  int i4) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 5);
  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
  TFLITE_DCHECK(i4 >= 0 && i4 < dims_data[4]);
  return (((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3) *
             dims_data[4] +
         i4;
}

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}

inline int Offset(const RuntimeShape& shape, int* index) {
  return Offset(shape, index[0], index[1], index[2], index[3]);
}

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return std::min(shape1.Dims(index1), shape2.Dims(index2));
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  return size_1;
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0,
                                const RuntimeShape& check_shape_1) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  const int size_3 = check_shape_1.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  TFLITE_CHECK_EQ(size_2, size_3);
  return size_1;
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}
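
// For example (worked through from the loop above): for a shape of {2, 3, 4},
// FlatSizeSkipDim(shape, 1) is 2 * 4 = 8, i.e. the flat size with the middle
// dimension treated as 1.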

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
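
// For example (worked through from the loop above): for a Dims<4> with
// sizes = {5, 4, 3, 2}, ComputeStrides() produces strides = {1, 5, 20, 60},
// which IsPackedWithoutStrides() then recognizes as fully packed.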

enum class BroadcastableOpCategory : uint8_t {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
};

struct ReluParams : public ActivationParams {
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
};

// Styles of resizing-op usage. For example, kImageStyle can be used with a
// Pad op for pattern-specific optimization.
enum class ResizingCategory : uint8_t {
  kNone,
  kImageStyle,  // 4D, operating on inner dimensions, say {0, a, b, 0}.
  kGenericResize,
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8_t inference params.
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8_t inference params.
  int left_shift;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_multiplier;
  int input2_shift;

  // TODO(b/158622529): Union the following activation params.
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // int64_t activation params.
  int64_t int64_activation_min;
  int64_t int64_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  // broadcast_shape[4] = b0 = a0.
  // broadcast_shape[3] = b1; a1 = 1.
  // broadcast_shape[2] = b2 = a2.
  // broadcast_shape[1] = a3; b3 = 1.
  // broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8_t inference params.
  int left_shift;
  int32_t input1_offset;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_offset;
  int32_t input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct Conv3DParams {
  Padding3DValues padding_values;
  int stride_width;
  int stride_height;
  int stride_depth;
  int dilation_width;
  int dilation_height;
  int dilation_depth;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DepthToSpaceParams {
  int32_t block_size;
};

struct DepthwiseParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  int16_t depth_multiplier;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  const int32_t* output_multiplier_per_channel;
  const int32_t* output_shift_per_channel;
};

struct DequantizationParams {
  double scale;
  int32_t zero_point;
};

struct PerChannelDequantizationParams {
  const float* scale;
  const int32_t* zero_point;
  int32_t quantized_dimension;
};

struct FakeQuantParams {
  MinMax minmax;
  int32_t num_bits;
};

struct FullyConnectedParams {
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // Mark the operands as cacheable if they are unchanging, e.g. weights.
  bool lhs_cacheable;
  bool rhs_cacheable;
  FullyConnectedWeightsFormat weights_format;
};

struct GatherParams {
  int16_t axis;
};

struct L2NormalizationParams {
  // uint8_t inference params.
  int32_t input_zero_point;
};

struct LocalResponseNormalizationParams {
  int32_t range;
  double bias;
  double alpha;
  double beta;
};

struct HardSwishParams {
  // zero_point of the input activations.
  int16_t input_zero_point;
  // zero_point of the output activations.
  int16_t output_zero_point;
  // 16-bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // "relu-ish scale", which is 3.0/32768.
  // See the implementation of HardSwishPrepare.
  int16_t reluish_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int reluish_multiplier_exponent;
  // 16-bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // output scale.
  // See the implementation of HardSwishPrepare.
  int16_t output_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int output_multiplier_exponent;
};

struct LogisticParams {
  // uint8_t inference params.
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct LstmCellParams {
  int32_t weights_zero_point;
  int32_t accum_multiplier;
  int accum_shift;
  int state_integer_bits;
};

struct MeanParams {
  int8_t axis_count;
  int16_t axis[4];
};

struct PackParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct PadParams {
  int8_t left_padding_count;
  int32_t left_padding[4];
  int8_t right_padding_count;
  int32_t right_padding[4];
  ResizingCategory resizing_category;
};

struct PreluParams {
  int32_t input_offset;
  int32_t alpha_offset;
  int32_t output_offset;
  int32_t output_multiplier_1;
  int output_shift_1;
  int32_t output_multiplier_2;
  int output_shift_2;
};

struct PoolParams {
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct ReshapeParams {
  int8_t shape_count;
  int32_t shape[4];
};

struct ResizeBilinearParams {
  bool align_corners;
  // half_pixel_centers assumes pixels are of half the actual dimensions, and
  // yields more accurate resizes. Corresponds to the same argument for the
  // original TensorFlow op in TF2.0.
  bool half_pixel_centers;
};

struct ResizeNearestNeighborParams {
  bool align_corners;
  bool half_pixel_centers;
};

struct SliceParams {
  int8_t begin_count;
  int32_t begin[5];
  int8_t size_count;
  int32_t size[5];
};

struct SoftmaxParams {
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8_t inference params.  Used even when beta defaults to 1.0.
  int32_t input_multiplier;
  int32_t input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32_t reverse_scaling_divisor;
  int32_t reverse_scaling_right_shift;
  int diff_min;
  int32_t zero_point;
  float scale;
  float* table;
  // int16 LUT for exp(x), where x is uniformly distributed in [-10.0, 0.0].
  int16_t* exp_lut;
  // int16 LUT for 1 / (1 + x), where x is uniformly distributed in [0.0, 1.0].
  int16_t* one_over_one_plus_x_lut;
  uint8_t* uint8_table1;
  uint8_t* uint8_table2;
};

struct SpaceToBatchParams {
  // "Zero" padding for uint8_t means padding with the output offset.
  int32_t output_offset;
};

struct SpaceToDepthParams {
  int32_t block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered.  The indices in
  // OperatorEdges are of type uint16_t.
  uint16_t num_split;
  int16_t axis;
};

struct SqueezeParams {
  int8_t squeeze_dims_count;
  int32_t squeeze_dims[4];
};

struct StridedSliceParams {
  int8_t start_indices_count;
  int32_t start_indices[5];
  int8_t stop_indices_count;
  int32_t stop_indices[5];
  int8_t strides_count;
  int32_t strides[5];

  int16_t begin_mask;
  int16_t ellipsis_mask;
  int16_t end_mask;
  int16_t new_axis_mask;
  int16_t shrink_axis_mask;
};

struct TanhParams {
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8_t perm_count;
  int32_t perm[5];
};

struct UnpackParams {
  uint16_t num_split;
  int16_t axis;
};

struct LeakyReluParams {
  float alpha;
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier_alpha;
  int32_t output_shift_alpha;
  int32_t output_multiplier_identity;
  int32_t output_shift_identity;
};

template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32_t min, int32_t max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int64_t min, int64_t max, P* params) {
  params->int64_activation_min = min;
  params->int64_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32_t* min, int32_t* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, int64_t* min, int64_t* max) {
  *min = params.int64_activation_min;
  *max = params.int64_activation_max;
}
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_