1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
17 
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/types.h"
22 
23 namespace tflite {
24 namespace optimized_ops {
25 
26 template <typename T>
ExtractPatchIntoBufferColumn(const RuntimeShape & input_shape,int w,int h,int b,int kheight,int kwidth,int stride_width,int stride_height,int pad_width,int pad_height,int in_width,int in_height,int in_depth,int single_buffer_length,int buffer_id,const T * in_data,T * conv_buffer_data,uint8 zero_byte)27 inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
28                                          int h, int b, int kheight, int kwidth,
29                                          int stride_width, int stride_height,
30                                          int pad_width, int pad_height,
31                                          int in_width, int in_height,
32                                          int in_depth, int single_buffer_length,
33                                          int buffer_id, const T* in_data,
34                                          T* conv_buffer_data, uint8 zero_byte) {
35   ruy::profiler::ScopeLabel label("ExtractPatchIntoBufferColumn");
36   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
37   // This chunk of code reshapes all the inputs corresponding to
38   // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
39   const int kwidth_times_indepth = kwidth * in_depth;
40   const int inwidth_times_indepth = in_width * in_depth;
41   const int ih_ungated_start = h * stride_height - pad_height;
42   const int ih_ungated_end = (ih_ungated_start + kheight);
43   const int ih_end = std::min(ih_ungated_end, in_height);
44   const int iw_ungated_start = w * stride_width - pad_width;
45   const int iw_ungated_end = (iw_ungated_start + kwidth);
46   const int iw_end = std::min(iw_ungated_end, in_width);
47   // If the patch is off the edge of the input image, skip writing those rows
48   // and columns from the patch into the output array.
49   const int h_offset = std::max(0, -ih_ungated_start);
50   const int w_offset = std::max(0, -iw_ungated_start);
51   const int ih_start = std::max(0, ih_ungated_start);
52   const int iw_start = std::max(0, iw_ungated_start);
53   const int single_row_num =
54       std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
55   const int output_row_offset = (buffer_id * single_buffer_length);
56   int out_offset =
57       output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
58   int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
59 
60   // Express all of the calculations as padding around the input patch.
61   const int top_padding = h_offset;
62   const int bottom_padding = (ih_ungated_end - ih_end);
63   const int left_padding = w_offset;
64   const int right_padding = (iw_ungated_end - iw_end);
65   assert(single_row_num ==
66          ((kwidth - (left_padding + right_padding)) * in_depth));
67 
68   // Write out zeroes to the elements representing the top rows of the input
69   // patch that are off the edge of the input image.
70   if (top_padding > 0) {
71     const int top_row_elements = (top_padding * kwidth * in_depth);
72     memset(conv_buffer_data + output_row_offset, zero_byte,
73            (top_row_elements * sizeof(T)));
74   }
75 
76   // If the patch is on the interior of the input image horizontally, just copy
77   // over the rows sequentially, otherwise add zero padding at the start or end.
78   if ((left_padding == 0) && (right_padding == 0)) {
79     for (int ih = ih_start; ih < ih_end; ++ih) {
80       memcpy(conv_buffer_data + out_offset, in_data + in_offset,
81              single_row_num * sizeof(T));
82       out_offset += kwidth_times_indepth;
83       in_offset += inwidth_times_indepth;
84     }
85   } else {
86     for (int ih = ih_start; ih < ih_end; ++ih) {
87       if (left_padding > 0) {
88         const int left_start = (out_offset - (left_padding * in_depth));
89         memset(conv_buffer_data + left_start, zero_byte,
90                (left_padding * in_depth * sizeof(T)));
91       }
92       memcpy(conv_buffer_data + out_offset, in_data + in_offset,
93              single_row_num * sizeof(T));
94       if (right_padding > 0) {
95         const int right_start = (out_offset + single_row_num);
96         memset(conv_buffer_data + right_start, zero_byte,
97                (right_padding * in_depth * sizeof(T)));
98       }
99       out_offset += kwidth_times_indepth;
100       in_offset += inwidth_times_indepth;
101     }
102   }
103 
104   // If the bottom of the patch falls off the input image, pad the values
105   // representing those input rows with zeroes.
106   if (bottom_padding > 0) {
107     const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
108     const int bottom_start =
109         output_row_offset +
110         ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
111     memset(conv_buffer_data + bottom_start, zero_byte,
112            (bottom_row_elements * sizeof(T)));
113   }
114 }
115 
116 // Supports per-batch zero_byte for per-batch asymmetric quantized inputs.
117 template <typename T>
DilatedIm2col(const ConvParams & params,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & filter_shape,const RuntimeShape & output_shape,T * im2col_data,const int32_t * zero_bytes,const int zero_bytes_len)118 void DilatedIm2col(const ConvParams& params, const RuntimeShape& input_shape,
119                    const T* input_data, const RuntimeShape& filter_shape,
120                    const RuntimeShape& output_shape, T* im2col_data,
121                    const int32_t* zero_bytes, const int zero_bytes_len) {
122   const int stride_width = params.stride_width;
123   const int stride_height = params.stride_height;
124   const int dilation_width_factor = params.dilation_width_factor;
125   const int dilation_height_factor = params.dilation_height_factor;
126   const int pad_width = params.padding_values.width;
127   const int pad_height = params.padding_values.height;
128   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
129   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
130   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
131 
132   // For dilated convolution, the input pixels are not contiguous therefore we
133   // can't use the same optimizations as Im2Col(). Though note this code would
134   // work fine for the non-dilated case too (though likely a bit slower).
135   ruy::profiler::ScopeLabel label("DilatedIm2col");
136   TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
137   TFLITE_DCHECK(im2col_data);
138   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
139   const int input_height = input_shape.Dims(1);
140   const int input_width = input_shape.Dims(2);
141   const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
142   const int filter_height = filter_shape.Dims(1);
143   const int filter_width = filter_shape.Dims(2);
144   const int output_height = output_shape.Dims(1);
145   const int output_width = output_shape.Dims(2);
146   MatchingDim(output_shape, 3, filter_shape, 0);
147 
148   // Construct the MxN sized im2col matrix.
149   // The rows M, are sub-ordered B x H x W
150   const RuntimeShape row_shape({1, batches, output_height, output_width});
151   // The columns, N, are sub-ordered Kh x Kw x Din
152   const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
153   // Use dimensions M and N to construct dims for indexing directly into im2col
154   const RuntimeShape im2col_shape(
155       {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
156 
157   // Loop through the output rows (B x H x W)
158   for (int batch = 0; batch < batches; ++batch) {
159     const T zero_byte = zero_bytes_len > 1 ? static_cast<T>(zero_bytes[batch])
160                                            : static_cast<T>(zero_bytes[0]);
161     for (int out_y = 0; out_y < output_height; ++out_y) {
162       for (int out_x = 0; out_x < output_width; ++out_x) {
163         // Each im2col row is an output pixel. Arrange the input data in this
164         // row in an order we can conveniently multiply with the filter data.
165         int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
166         const int in_x_origin = (out_x * stride_width) - pad_width;
167         const int in_y_origin = (out_y * stride_height) - pad_height;
168         // Loop through all the pixels of the filter (Kh x Kw)
169         for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
170           const int in_y = in_y_origin + dilation_height_factor * filter_y;
171           if ((in_y >= 0) && (in_y < input_height)) {
172             // Filter row is within the input data.
173             // Loop through all the filter pixels in this row.
174             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
175               const int in_x = in_x_origin + dilation_width_factor * filter_x;
176               int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
177               T* dst = im2col_data +
178                        Offset(im2col_shape, 0, 0, row_offset, col_offset);
179               if ((in_x >= 0) && (in_x < input_width)) {
180                 // Filter pixel is within the input, copy the input data.
181                 T const* src =
182                     input_data + Offset(input_shape, batch, in_y, in_x, 0);
183                 memcpy(dst, src, input_depth * sizeof(T));
184               } else {
185                 // Filter pixel is outside the input, zero it out.
186                 memset(dst, zero_byte, input_depth * sizeof(T));
187               }
188             }
189           } else {
190             // Filter row is outside the input, zero out the entire filter row.
191             int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
192             T* dst = im2col_data +
193                      Offset(im2col_shape, 0, 0, row_offset, col_offset);
194             memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
195           }
196         }
197       }
198     }
199   }
200 }
201 
202 template <typename T>
DilatedIm2col(const ConvParams & params,uint8 zero_byte,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & filter_shape,const RuntimeShape & output_shape,T * im2col_data)203 void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
204                    const RuntimeShape& input_shape, const T* input_data,
205                    const RuntimeShape& filter_shape,
206                    const RuntimeShape& output_shape, T* im2col_data) {
207   const int32_t zero_point = static_cast<int32_t>(zero_byte);
208   DilatedIm2col<T>(params, input_shape, input_data, filter_shape, output_shape,
209                    im2col_data, &zero_point, 1);
210 }
211 
212 template <typename T>
Im2col(const ConvParams & params,int kheight,int kwidth,uint8 zero_byte,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)213 void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
214             const RuntimeShape& input_shape, const T* input_data,
215             const RuntimeShape& output_shape, T* output_data) {
216   ruy::profiler::ScopeLabel label("Im2col");
217   const int stride_width = params.stride_width;
218   const int stride_height = params.stride_height;
219   const int pad_width = params.padding_values.width;
220   const int pad_height = params.padding_values.height;
221   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
222   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
223 
224   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
225   const int input_depth = input_shape.Dims(3);
226   const int input_width = input_shape.Dims(2);
227   const int input_height = input_shape.Dims(1);
228   const int output_depth = output_shape.Dims(3);
229   const int output_width = output_shape.Dims(2);
230   const int output_height = output_shape.Dims(1);
231 
232   int buffer_id = 0;
233   // Loop over the output nodes.
234   for (int b = 0; b < batches; ++b) {
235     for (int h = 0; h < output_height; ++h) {
236       for (int w = 0; w < output_width; ++w) {
237         ExtractPatchIntoBufferColumn(
238             input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
239             pad_width, pad_height, input_width, input_height, input_depth,
240             output_depth, buffer_id, input_data, output_data, zero_byte);
241         ++buffer_id;
242       }
243     }
244   }
245 }
246 
247 template <typename T>
Im2col(const ConvParams & params,int kheight,int kwidth,const int32_t * input_offsets,const int input_offsets_size,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)248 void Im2col(const ConvParams& params, int kheight, int kwidth,
249             const int32_t* input_offsets, const int input_offsets_size,
250             const RuntimeShape& input_shape, const T* input_data,
251             const RuntimeShape& output_shape, T* output_data) {
252   ruy::profiler::ScopeLabel label("Im2col");
253   const int stride_width = params.stride_width;
254   const int stride_height = params.stride_height;
255   const int pad_width = params.padding_values.width;
256   const int pad_height = params.padding_values.height;
257   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
258   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
259 
260   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
261   TFLITE_DCHECK_EQ(batches, input_offsets_size);
262   const int input_depth = input_shape.Dims(3);
263   const int input_width = input_shape.Dims(2);
264   const int input_height = input_shape.Dims(1);
265   const int output_depth = output_shape.Dims(3);
266   const int output_width = output_shape.Dims(2);
267   const int output_height = output_shape.Dims(1);
268 
269   int buffer_id = 0;
270   // Loop over the output nodes.
271   for (int b = 0; b < batches; ++b) {
272     uint8_t zero_byte = static_cast<uint8_t>(input_offsets[b]);
273     for (int h = 0; h < output_height; ++h) {
274       for (int w = 0; w < output_width; ++w) {
275         ExtractPatchIntoBufferColumn(
276             input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
277             pad_width, pad_height, input_width, input_height, input_depth,
278             output_depth, buffer_id, input_data, output_data, zero_byte);
279         ++buffer_id;
280       }
281     }
282   }
283 }
284 
285 }  // namespace optimized_ops
286 }  // namespace tflite
287 
288 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
289