1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
17
18 #include <cassert>
19
20 #include "ruy/profiler/instrumentation.h" // from @ruy
21 #include "tensorflow/lite/kernels/internal/types.h"
22
23 namespace tflite {
24 namespace optimized_ops {
25
26 template <typename T>
ExtractPatchIntoBufferColumn(const RuntimeShape & input_shape,int w,int h,int b,int kheight,int kwidth,int stride_width,int stride_height,int pad_width,int pad_height,int in_width,int in_height,int in_depth,int single_buffer_length,int buffer_id,const T * in_data,T * conv_buffer_data,uint8 zero_byte)27 inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
28 int h, int b, int kheight, int kwidth,
29 int stride_width, int stride_height,
30 int pad_width, int pad_height,
31 int in_width, int in_height,
32 int in_depth, int single_buffer_length,
33 int buffer_id, const T* in_data,
34 T* conv_buffer_data, uint8 zero_byte) {
35 ruy::profiler::ScopeLabel label("ExtractPatchIntoBufferColumn");
36 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
37 // This chunk of code reshapes all the inputs corresponding to
38 // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
39 const int kwidth_times_indepth = kwidth * in_depth;
40 const int inwidth_times_indepth = in_width * in_depth;
41 const int ih_ungated_start = h * stride_height - pad_height;
42 const int ih_ungated_end = (ih_ungated_start + kheight);
43 const int ih_end = std::min(ih_ungated_end, in_height);
44 const int iw_ungated_start = w * stride_width - pad_width;
45 const int iw_ungated_end = (iw_ungated_start + kwidth);
46 const int iw_end = std::min(iw_ungated_end, in_width);
47 // If the patch is off the edge of the input image, skip writing those rows
48 // and columns from the patch into the output array.
49 const int h_offset = std::max(0, -ih_ungated_start);
50 const int w_offset = std::max(0, -iw_ungated_start);
51 const int ih_start = std::max(0, ih_ungated_start);
52 const int iw_start = std::max(0, iw_ungated_start);
53 const int single_row_num =
54 std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
55 const int output_row_offset = (buffer_id * single_buffer_length);
56 int out_offset =
57 output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
58 int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
59
60 // Express all of the calculations as padding around the input patch.
61 const int top_padding = h_offset;
62 const int bottom_padding = (ih_ungated_end - ih_end);
63 const int left_padding = w_offset;
64 const int right_padding = (iw_ungated_end - iw_end);
65 assert(single_row_num ==
66 ((kwidth - (left_padding + right_padding)) * in_depth));
67
68 // Write out zeroes to the elements representing the top rows of the input
69 // patch that are off the edge of the input image.
70 if (top_padding > 0) {
71 const int top_row_elements = (top_padding * kwidth * in_depth);
72 memset(conv_buffer_data + output_row_offset, zero_byte,
73 (top_row_elements * sizeof(T)));
74 }
75
76 // If the patch is on the interior of the input image horizontally, just copy
77 // over the rows sequentially, otherwise add zero padding at the start or end.
78 if ((left_padding == 0) && (right_padding == 0)) {
79 for (int ih = ih_start; ih < ih_end; ++ih) {
80 memcpy(conv_buffer_data + out_offset, in_data + in_offset,
81 single_row_num * sizeof(T));
82 out_offset += kwidth_times_indepth;
83 in_offset += inwidth_times_indepth;
84 }
85 } else {
86 for (int ih = ih_start; ih < ih_end; ++ih) {
87 if (left_padding > 0) {
88 const int left_start = (out_offset - (left_padding * in_depth));
89 memset(conv_buffer_data + left_start, zero_byte,
90 (left_padding * in_depth * sizeof(T)));
91 }
92 memcpy(conv_buffer_data + out_offset, in_data + in_offset,
93 single_row_num * sizeof(T));
94 if (right_padding > 0) {
95 const int right_start = (out_offset + single_row_num);
96 memset(conv_buffer_data + right_start, zero_byte,
97 (right_padding * in_depth * sizeof(T)));
98 }
99 out_offset += kwidth_times_indepth;
100 in_offset += inwidth_times_indepth;
101 }
102 }
103
104 // If the bottom of the patch falls off the input image, pad the values
105 // representing those input rows with zeroes.
106 if (bottom_padding > 0) {
107 const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
108 const int bottom_start =
109 output_row_offset +
110 ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
111 memset(conv_buffer_data + bottom_start, zero_byte,
112 (bottom_row_elements * sizeof(T)));
113 }
114 }
115
116 // Supports per-batch zero_byte for per-batch asymmetric quantized inputs.
117 template <typename T>
DilatedIm2col(const ConvParams & params,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & filter_shape,const RuntimeShape & output_shape,T * im2col_data,const int32_t * zero_bytes,const int zero_bytes_len)118 void DilatedIm2col(const ConvParams& params, const RuntimeShape& input_shape,
119 const T* input_data, const RuntimeShape& filter_shape,
120 const RuntimeShape& output_shape, T* im2col_data,
121 const int32_t* zero_bytes, const int zero_bytes_len) {
122 const int stride_width = params.stride_width;
123 const int stride_height = params.stride_height;
124 const int dilation_width_factor = params.dilation_width_factor;
125 const int dilation_height_factor = params.dilation_height_factor;
126 const int pad_width = params.padding_values.width;
127 const int pad_height = params.padding_values.height;
128 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
129 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
130 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
131
132 // For dilated convolution, the input pixels are not contiguous therefore we
133 // can't use the same optimizations as Im2Col(). Though note this code would
134 // work fine for the non-dilated case too (though likely a bit slower).
135 ruy::profiler::ScopeLabel label("DilatedIm2col");
136 TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
137 TFLITE_DCHECK(im2col_data);
138 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
139 const int input_height = input_shape.Dims(1);
140 const int input_width = input_shape.Dims(2);
141 const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
142 const int filter_height = filter_shape.Dims(1);
143 const int filter_width = filter_shape.Dims(2);
144 const int output_height = output_shape.Dims(1);
145 const int output_width = output_shape.Dims(2);
146 MatchingDim(output_shape, 3, filter_shape, 0);
147
148 // Construct the MxN sized im2col matrix.
149 // The rows M, are sub-ordered B x H x W
150 const RuntimeShape row_shape({1, batches, output_height, output_width});
151 // The columns, N, are sub-ordered Kh x Kw x Din
152 const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
153 // Use dimensions M and N to construct dims for indexing directly into im2col
154 const RuntimeShape im2col_shape(
155 {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
156
157 // Loop through the output rows (B x H x W)
158 for (int batch = 0; batch < batches; ++batch) {
159 const T zero_byte = zero_bytes_len > 1 ? static_cast<T>(zero_bytes[batch])
160 : static_cast<T>(zero_bytes[0]);
161 for (int out_y = 0; out_y < output_height; ++out_y) {
162 for (int out_x = 0; out_x < output_width; ++out_x) {
163 // Each im2col row is an output pixel. Arrange the input data in this
164 // row in an order we can conveniently multiply with the filter data.
165 int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
166 const int in_x_origin = (out_x * stride_width) - pad_width;
167 const int in_y_origin = (out_y * stride_height) - pad_height;
168 // Loop through all the pixels of the filter (Kh x Kw)
169 for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
170 const int in_y = in_y_origin + dilation_height_factor * filter_y;
171 if ((in_y >= 0) && (in_y < input_height)) {
172 // Filter row is within the input data.
173 // Loop through all the filter pixels in this row.
174 for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
175 const int in_x = in_x_origin + dilation_width_factor * filter_x;
176 int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
177 T* dst = im2col_data +
178 Offset(im2col_shape, 0, 0, row_offset, col_offset);
179 if ((in_x >= 0) && (in_x < input_width)) {
180 // Filter pixel is within the input, copy the input data.
181 T const* src =
182 input_data + Offset(input_shape, batch, in_y, in_x, 0);
183 memcpy(dst, src, input_depth * sizeof(T));
184 } else {
185 // Filter pixel is outside the input, zero it out.
186 memset(dst, zero_byte, input_depth * sizeof(T));
187 }
188 }
189 } else {
190 // Filter row is outside the input, zero out the entire filter row.
191 int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
192 T* dst = im2col_data +
193 Offset(im2col_shape, 0, 0, row_offset, col_offset);
194 memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
195 }
196 }
197 }
198 }
199 }
200 }
201
202 template <typename T>
DilatedIm2col(const ConvParams & params,uint8 zero_byte,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & filter_shape,const RuntimeShape & output_shape,T * im2col_data)203 void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
204 const RuntimeShape& input_shape, const T* input_data,
205 const RuntimeShape& filter_shape,
206 const RuntimeShape& output_shape, T* im2col_data) {
207 const int32_t zero_point = static_cast<int32_t>(zero_byte);
208 DilatedIm2col<T>(params, input_shape, input_data, filter_shape, output_shape,
209 im2col_data, &zero_point, 1);
210 }
211
212 template <typename T>
Im2col(const ConvParams & params,int kheight,int kwidth,uint8 zero_byte,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)213 void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
214 const RuntimeShape& input_shape, const T* input_data,
215 const RuntimeShape& output_shape, T* output_data) {
216 ruy::profiler::ScopeLabel label("Im2col");
217 const int stride_width = params.stride_width;
218 const int stride_height = params.stride_height;
219 const int pad_width = params.padding_values.width;
220 const int pad_height = params.padding_values.height;
221 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
222 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
223
224 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
225 const int input_depth = input_shape.Dims(3);
226 const int input_width = input_shape.Dims(2);
227 const int input_height = input_shape.Dims(1);
228 const int output_depth = output_shape.Dims(3);
229 const int output_width = output_shape.Dims(2);
230 const int output_height = output_shape.Dims(1);
231
232 int buffer_id = 0;
233 // Loop over the output nodes.
234 for (int b = 0; b < batches; ++b) {
235 for (int h = 0; h < output_height; ++h) {
236 for (int w = 0; w < output_width; ++w) {
237 ExtractPatchIntoBufferColumn(
238 input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
239 pad_width, pad_height, input_width, input_height, input_depth,
240 output_depth, buffer_id, input_data, output_data, zero_byte);
241 ++buffer_id;
242 }
243 }
244 }
245 }
246
247 template <typename T>
Im2col(const ConvParams & params,int kheight,int kwidth,const int32_t * input_offsets,const int input_offsets_size,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)248 void Im2col(const ConvParams& params, int kheight, int kwidth,
249 const int32_t* input_offsets, const int input_offsets_size,
250 const RuntimeShape& input_shape, const T* input_data,
251 const RuntimeShape& output_shape, T* output_data) {
252 ruy::profiler::ScopeLabel label("Im2col");
253 const int stride_width = params.stride_width;
254 const int stride_height = params.stride_height;
255 const int pad_width = params.padding_values.width;
256 const int pad_height = params.padding_values.height;
257 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
258 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
259
260 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
261 TFLITE_DCHECK_EQ(batches, input_offsets_size);
262 const int input_depth = input_shape.Dims(3);
263 const int input_width = input_shape.Dims(2);
264 const int input_height = input_shape.Dims(1);
265 const int output_depth = output_shape.Dims(3);
266 const int output_width = output_shape.Dims(2);
267 const int output_height = output_shape.Dims(1);
268
269 int buffer_id = 0;
270 // Loop over the output nodes.
271 for (int b = 0; b < batches; ++b) {
272 uint8_t zero_byte = static_cast<uint8_t>(input_offsets[b]);
273 for (int h = 0; h < output_height; ++h) {
274 for (int w = 0; w < output_width; ++w) {
275 ExtractPatchIntoBufferColumn(
276 input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
277 pad_width, pad_height, input_width, input_height, input_depth,
278 output_depth, buffer_id, input_data, output_data, zero_byte);
279 ++buffer_id;
280 }
281 }
282 }
283 }
284
285 } // namespace optimized_ops
286 } // namespace tflite
287
288 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
289