1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
17 #define TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
18 
19 #include <limits>
20 #include <vector>
21 #include "tensorflow/lite/kernels/internal/compatibility.h"
22 #include "tensorflow/lite/kernels/internal/types.h"
23 
24 namespace tflite {
25 namespace strided_slice {
26 
27 // Use until std::clamp() is available from C++17.
Clamp(const int v,const int lo,const int hi)28 inline int Clamp(const int v, const int lo, const int hi) {
29   TFLITE_DCHECK(!(hi < lo));
30   if (hi < v) return hi;
31   if (v < lo) return lo;
32   return v;
33 }
34 
StridedSlicePadIndices(tflite::StridedSliceParams * p,int dim_count)35 inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
36                                    int dim_count) {
37   // Add indices and mask bits to fully include extra dimensions
38   TFLITE_CHECK_LE(dim_count, 4);
39   TFLITE_CHECK_GE(dim_count, p->start_indices_count);
40   TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
41   TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
42 
43   const int pad_count = dim_count - p->start_indices_count;
44 
45   // Pad indices at start, so move arrays by pad_count.
46   for (int i = p->start_indices_count - 1; i > 0; --i) {
47     p->strides[i + pad_count] = p->strides[i];
48     p->start_indices[i + pad_count] = p->start_indices[i];
49     p->stop_indices[i + pad_count] = p->stop_indices[i];
50   }
51   for (int i = 0; i < pad_count; ++i) {
52     p->start_indices[i] = 0;
53     p->stop_indices[i] = 0;
54     p->strides[i] = 1;
55   }
56 
57   // Pad masks with 0s or 1s as required.
58   p->shrink_axis_mask <<= pad_count;
59   p->ellipsis_mask <<= pad_count;
60   p->new_axis_mask <<= pad_count;
61   p->begin_mask <<= pad_count;
62   p->end_mask <<= pad_count;
63   p->begin_mask |= (1 << pad_count) - 1;
64   p->end_mask |= (1 << pad_count) - 1;
65 
66   p->start_indices_count = dim_count;
67   p->stop_indices_count = dim_count;
68   p->strides_count = dim_count;
69 }
70 
71 // Return the index for the first element along that axis. This index will be a
72 // positive integer between [0, axis_size - 1] that can be used to index
73 // directly into the data.
StartForAxis(const tflite::StridedSliceParams & params,const RuntimeShape & input_shape,int axis)74 inline int StartForAxis(const tflite::StridedSliceParams& params,
75                         const RuntimeShape& input_shape, int axis) {
76   const auto begin_mask = params.begin_mask;
77   const auto* start_indices = params.start_indices;
78   const auto* strides = params.strides;
79   // Begin with the specified index.
80   int start = start_indices[axis];
81 
82   // begin_mask override
83   if (begin_mask & 1 << axis) {
84     if (strides[axis] > 0) {
85       // Forward iteration - use the first element. These values will get
86       // clamped below (Note: We could have set them to 0 and axis_size-1, but
87       // use lowest() and max() to maintain symmetry with StopForAxis())
88       start = std::numeric_limits<int>::lowest();
89     } else {
90       // Backward iteration - use the last element.
91       start = std::numeric_limits<int>::max();
92     }
93   }
94 
95   // Handle negative indices
96   int axis_size = input_shape.Dims(axis);
97   if (start < 0) {
98     start += axis_size;
99   }
100 
101   // Clamping
102   start = Clamp(start, 0, axis_size - 1);
103 
104   return start;
105 }
106 
107 // Return the "real" index for the end of iteration along that axis. This is an
108 // "end" in the traditional C sense, in that it points to one past the last
109 // element. ie. So if you were iterating through all elements of a 1D array of
110 // size 4, this function would return 4 as the stop, because it is one past the
111 // "real" indices of 0, 1, 2 & 3.
StopForAxis(const tflite::StridedSliceParams & params,const RuntimeShape & input_shape,int axis,int start_for_axis)112 inline int StopForAxis(const tflite::StridedSliceParams& params,
113                        const RuntimeShape& input_shape, int axis,
114                        int start_for_axis) {
115   const auto end_mask = params.end_mask;
116   const auto shrink_axis_mask = params.shrink_axis_mask;
117   const auto* stop_indices = params.stop_indices;
118   const auto* strides = params.strides;
119 
120   // Begin with the specified index
121   const bool shrink_axis = shrink_axis_mask & (1 << axis);
122   int stop = stop_indices[axis];
123 
124   // When shrinking an axis, the end position does not matter (and can be
125   // incorrect when negative indexing is used, see Issue #19260). Always use
126   // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
127   // already been adjusted for negative indices.
128   if (shrink_axis) {
129     stop = start_for_axis + 1;
130   }
131 
132   // end_mask override
133   if (end_mask & (1 << axis)) {
134     if (strides[axis] > 0) {
135       // Forward iteration - use the last element. These values will get
136       // clamped below
137       stop = std::numeric_limits<int>::max();
138     } else {
139       // Backward iteration - use the first element.
140       stop = std::numeric_limits<int>::lowest();
141     }
142   }
143 
144   // Handle negative indices
145   const int axis_size = input_shape.Dims(axis);
146   if (stop < 0) {
147     stop += axis_size;
148   }
149 
150   // Clamping
151   // Because the end index points one past the last element, we need slightly
152   // different clamping ranges depending on the direction.
153   if (strides[axis] > 0) {
154     // Forward iteration
155     stop = Clamp(stop, 0, axis_size);
156   } else {
157     // Backward iteration
158     stop = Clamp(stop, -1, axis_size - 1);
159   }
160 
161   return stop;
162 }
163 
// Returns true once `index` has reached or passed `stop` in the direction given
// by the sign of `stride`, i.e. when iteration along the axis is finished and
// the caller should stop looping. A non-positive stride is treated as backward
// iteration.
inline bool LoopCondition(int index, int stop, int stride) {
  if (stride > 0) {
    return !(index < stop);
  }
  return !(index > stop);
}
168 
BuildStridedSliceParams(int begin_mask,int end_mask,int shrink_axis_mask,const std::vector<int> & start_indices,const std::vector<int> & stop_indices,const std::vector<int> & strides)169 inline tflite::StridedSliceParams BuildStridedSliceParams(
170     int begin_mask, int end_mask, int shrink_axis_mask,
171     const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
172     const std::vector<int>& strides) {
173   tflite::StridedSliceParams op_params;
174   const int dims_count = start_indices.size();
175 
176   op_params.start_indices_count = dims_count;
177   op_params.stop_indices_count = dims_count;
178   op_params.strides_count = dims_count;
179   for (int i = 0; i < dims_count; ++i) {
180     op_params.start_indices[i] = start_indices[i];
181     op_params.stop_indices[i] = stop_indices[i];
182     op_params.strides[i] = strides[i];
183   }
184 
185   op_params.begin_mask = begin_mask;
186   op_params.ellipsis_mask = 0;
187   op_params.end_mask = end_mask;
188   op_params.new_axis_mask = 0;
189   op_params.shrink_axis_mask = shrink_axis_mask;
190 
191   return op_params;
192 }
193 
194 }  // namespace strided_slice
195 
196 }  // namespace tflite
197 
198 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
199