1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
17 #define TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
18
19 #include <limits>
20 #include <vector>
21 #include "tensorflow/lite/kernels/internal/compatibility.h"
22 #include "tensorflow/lite/kernels/internal/types.h"
23
24 namespace tflite {
25 namespace strided_slice {
26
27 // Use until std::clamp() is available from C++17.
Clamp(const int v,const int lo,const int hi)28 inline int Clamp(const int v, const int lo, const int hi) {
29 TFLITE_DCHECK(!(hi < lo));
30 if (hi < v) return hi;
31 if (v < lo) return lo;
32 return v;
33 }
34
StridedSlicePadIndices(tflite::StridedSliceParams * p,int dim_count)35 inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
36 int dim_count) {
37 // Add indices and mask bits to fully include extra dimensions
38 TFLITE_CHECK_LE(dim_count, 4);
39 TFLITE_CHECK_GE(dim_count, p->start_indices_count);
40 TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
41 TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
42
43 const int pad_count = dim_count - p->start_indices_count;
44
45 // Pad indices at start, so move arrays by pad_count.
46 for (int i = p->start_indices_count - 1; i > 0; --i) {
47 p->strides[i + pad_count] = p->strides[i];
48 p->start_indices[i + pad_count] = p->start_indices[i];
49 p->stop_indices[i + pad_count] = p->stop_indices[i];
50 }
51 for (int i = 0; i < pad_count; ++i) {
52 p->start_indices[i] = 0;
53 p->stop_indices[i] = 0;
54 p->strides[i] = 1;
55 }
56
57 // Pad masks with 0s or 1s as required.
58 p->shrink_axis_mask <<= pad_count;
59 p->ellipsis_mask <<= pad_count;
60 p->new_axis_mask <<= pad_count;
61 p->begin_mask <<= pad_count;
62 p->end_mask <<= pad_count;
63 p->begin_mask |= (1 << pad_count) - 1;
64 p->end_mask |= (1 << pad_count) - 1;
65
66 p->start_indices_count = dim_count;
67 p->stop_indices_count = dim_count;
68 p->strides_count = dim_count;
69 }
70
71 // Return the index for the first element along that axis. This index will be a
72 // positive integer between [0, axis_size - 1] that can be used to index
73 // directly into the data.
StartForAxis(const tflite::StridedSliceParams & params,const RuntimeShape & input_shape,int axis)74 inline int StartForAxis(const tflite::StridedSliceParams& params,
75 const RuntimeShape& input_shape, int axis) {
76 const auto begin_mask = params.begin_mask;
77 const auto* start_indices = params.start_indices;
78 const auto* strides = params.strides;
79 // Begin with the specified index.
80 int start = start_indices[axis];
81
82 // begin_mask override
83 if (begin_mask & 1 << axis) {
84 if (strides[axis] > 0) {
85 // Forward iteration - use the first element. These values will get
86 // clamped below (Note: We could have set them to 0 and axis_size-1, but
87 // use lowest() and max() to maintain symmetry with StopForAxis())
88 start = std::numeric_limits<int>::lowest();
89 } else {
90 // Backward iteration - use the last element.
91 start = std::numeric_limits<int>::max();
92 }
93 }
94
95 // Handle negative indices
96 int axis_size = input_shape.Dims(axis);
97 if (start < 0) {
98 start += axis_size;
99 }
100
101 // Clamping
102 start = Clamp(start, 0, axis_size - 1);
103
104 return start;
105 }
106
107 // Return the "real" index for the end of iteration along that axis. This is an
108 // "end" in the traditional C sense, in that it points to one past the last
109 // element. ie. So if you were iterating through all elements of a 1D array of
110 // size 4, this function would return 4 as the stop, because it is one past the
111 // "real" indices of 0, 1, 2 & 3.
StopForAxis(const tflite::StridedSliceParams & params,const RuntimeShape & input_shape,int axis,int start_for_axis)112 inline int StopForAxis(const tflite::StridedSliceParams& params,
113 const RuntimeShape& input_shape, int axis,
114 int start_for_axis) {
115 const auto end_mask = params.end_mask;
116 const auto shrink_axis_mask = params.shrink_axis_mask;
117 const auto* stop_indices = params.stop_indices;
118 const auto* strides = params.strides;
119
120 // Begin with the specified index
121 const bool shrink_axis = shrink_axis_mask & (1 << axis);
122 int stop = stop_indices[axis];
123
124 // When shrinking an axis, the end position does not matter (and can be
125 // incorrect when negative indexing is used, see Issue #19260). Always use
126 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
127 // already been adjusted for negative indices.
128 if (shrink_axis) {
129 stop = start_for_axis + 1;
130 }
131
132 // end_mask override
133 if (end_mask & (1 << axis)) {
134 if (strides[axis] > 0) {
135 // Forward iteration - use the last element. These values will get
136 // clamped below
137 stop = std::numeric_limits<int>::max();
138 } else {
139 // Backward iteration - use the first element.
140 stop = std::numeric_limits<int>::lowest();
141 }
142 }
143
144 // Handle negative indices
145 const int axis_size = input_shape.Dims(axis);
146 if (stop < 0) {
147 stop += axis_size;
148 }
149
150 // Clamping
151 // Because the end index points one past the last element, we need slightly
152 // different clamping ranges depending on the direction.
153 if (strides[axis] > 0) {
154 // Forward iteration
155 stop = Clamp(stop, 0, axis_size);
156 } else {
157 // Backward iteration
158 stop = Clamp(stop, -1, axis_size - 1);
159 }
160
161 return stop;
162 }
163
// Reports whether iteration along an axis has reached its end, i.e. whether
// `index` has arrived at or moved past `stop` in the direction of travel.
// A strictly positive `stride` means forward travel; any other value
// (including 0) is treated as backward travel.
inline bool LoopCondition(int index, int stop, int stride) {
  if (stride > 0) {
    return index >= stop;
  }
  return index <= stop;
}
168
BuildStridedSliceParams(int begin_mask,int end_mask,int shrink_axis_mask,const std::vector<int> & start_indices,const std::vector<int> & stop_indices,const std::vector<int> & strides)169 inline tflite::StridedSliceParams BuildStridedSliceParams(
170 int begin_mask, int end_mask, int shrink_axis_mask,
171 const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
172 const std::vector<int>& strides) {
173 tflite::StridedSliceParams op_params;
174 const int dims_count = start_indices.size();
175
176 op_params.start_indices_count = dims_count;
177 op_params.stop_indices_count = dims_count;
178 op_params.strides_count = dims_count;
179 for (int i = 0; i < dims_count; ++i) {
180 op_params.start_indices[i] = start_indices[i];
181 op_params.stop_indices[i] = stop_indices[i];
182 op_params.strides[i] = strides[i];
183 }
184
185 op_params.begin_mask = begin_mask;
186 op_params.ellipsis_mask = 0;
187 op_params.end_mask = end_mask;
188 op_params.new_axis_mask = 0;
189 op_params.shrink_axis_mask = shrink_axis_mask;
190
191 return op_params;
192 }
193
194 } // namespace strided_slice
195
196 } // namespace tflite
197
198 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
199