1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define USE_EIGEN_TENSOR
17 #define EIGEN_USE_THREADS
18 
19 #include "tensorflow/core/kernels/deep_conv2d.h"
20 
21 #include <stdlib.h>
22 
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/kernels/winograd_transform.h"
25 #include "tensorflow/core/util/work_sharder.h"
26 
27 namespace tensorflow {
28 
// DeepConv2D is a Conv2D implementation specialized for deep convolutions
// (i.e., a large 'in_depth' * 'out_depth' product; see the cost models below
// for details).
31 //
32 // DeepConv2D is implemented by computing the following equation:
33 //
34 //   y = C[Ad * Bg]
35 //
36 //   C: output transform matrix
37 //   A: input data transform matrix
38 //   B: filter transform matrix
39 //   d: vectorized data tile
40 //   g: vectorized filter tile
41 //   y: vectorized output tile
42 //
43 // The transform matrices and input, filter and output tile sizes are all
44 // specified by the DeepConv2DTransform implementation selected at the
45 // start of the DeepConv2D call, based on convolution parameters.
46 
47 // Approximate cost models for direct and deep convolutions.
// Estimates the cost of a deep (transform-based) convolution.
//
// The per-tile cost is the sum of three terms: transforming one input tile,
// the element-wise products across depth, and transforming one output tile.
// That sum is multiplied by the number of output tiles needed to cover an
// 'out_rows' x 'out_cols' output (rounded up in each dimension).
static int64 GetDeepConvCost(int input_tile_rows, int input_tile_cols,
                             int out_tile_rows, int out_tile_cols, int in_depth,
                             int out_depth, int out_rows, int out_cols) {
  // Spatial sizes of a single input tile and a single output tile.
  const int64 in_tile_area = input_tile_rows * input_tile_cols;
  const int64 out_tile_area = out_tile_rows * out_tile_cols;

  // Number of tiles required to cover the output (ceiling division).
  const int64 tiles_down = (out_rows + out_tile_rows - 1) / out_tile_rows;
  const int64 tiles_across = (out_cols + out_tile_cols - 1) / out_tile_cols;
  const int64 tile_count = tiles_down * tiles_across;

  // Cost of transforming one input tile.
  const int64 in_transform = in_tile_area * in_tile_area * in_depth;
  // Cost of the element-wise products (each one is a MatMul across depth).
  const int64 products = in_tile_area * in_depth * out_depth;
  // Cost of transforming one output tile.
  const int64 out_transform = out_tile_area * in_tile_area * out_depth;

  const int64 cost_per_tile = in_transform + products + out_transform;
  return tile_count * cost_per_tile;
}
73 
// Estimates the cost of a direct convolution as the product of the filter
// spatial size, the input and output depths, and the output spatial size.
static int64 GetDirectConvCost(int filter_rows, int filter_cols, int in_depth,
                               int out_depth, int out_rows, int out_cols) {
  // Widen the first operand so the whole product is evaluated in 64 bits.
  // Multiplying the 'int' arguments directly overflows 32-bit arithmetic for
  // realistic shapes (e.g. 3x3 filter, 512x512 depth, 224x224 output) before
  // the result is converted to the return type.
  return static_cast<int64>(filter_rows) * filter_cols * in_depth * out_depth *
         out_rows * out_cols;
}
78 
// Looks up environment variable 'env_var_name' and interprets it as a flag.
// Returns 'default_val' if the variable is not set, 'false' if its value is
// exactly "0", and 'true' for any other value (including the empty string).
static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) {
  const char* raw_value = getenv(env_var_name);
  if (raw_value == nullptr) {
    // Variable is absent from the environment: fall back to the default.
    return default_val;
  }
  // Only the literal string "0" disables the flag.
  const StringPiece value(raw_value);
  return !(value == "0");
}
92 
93 // Returns true if convolution can be computed efficiently by DeepConv2D,
94 // returns false otherwise.
95 // TODO(andydavis) Add support for other filter sizes and strides.
96 // TODO(andydavis) Add support for autotuning.
CanUseDeepConv2D(int stride_rows,int stride_cols,int filter_rows,int filter_cols,int in_depth,int out_depth,int out_rows,int out_cols)97 bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
98                       int filter_cols, int in_depth, int out_depth,
99                       int out_rows, int out_cols) {
100   // Check if convolution parameters are supported.
101   // TODO(andydavis) Add support for multiple filter sizes and strides.
102   if (stride_rows > 1 || stride_cols > 1 || filter_rows != 3 ||
103       filter_cols != 3) {
104     return false;
105   }
106 
107   // Check if deep convolution is enabled by environment variable.
108   // NOTE: IF this environment variable name changes, update conv_ops_test.py.
109   if (!ReadBoolFromEnvVar("TF_USE_DEEP_CONV2D", false)) {
110     return false;
111   }
112 
113   // Check if flop cost of deep convolution is less than direct convolution.
114   WinogradTransform<float> t;
115   const int64 deep_conv_cost = GetDeepConvCost(
116       t.input_shape().rows, t.input_shape().cols, t.output_shape().rows,
117       t.output_shape().cols, in_depth, out_depth, out_rows, out_cols);
118   const int64 direct_conv_cost = GetDirectConvCost(
119       filter_rows, filter_cols, in_depth, out_depth, out_rows, out_cols);
120 
121   VLOG(2) << "CanUseDeepConv2D"
122           << " deep_conv_cost: " << deep_conv_cost
123           << " direct_conv_cost: " << direct_conv_cost << " deep_direct_ratio: "
124           << (static_cast<float>(deep_conv_cost) /
125               static_cast<float>(direct_conv_cost))
126           << " use_deep_conv: " << (deep_conv_cost < direct_conv_cost);
127   return deep_conv_cost < direct_conv_cost;
128 }
129 
130 typedef Eigen::ThreadPoolDevice CPUDevice;
131 
132 // Copies data from 'filter_in' to 'filter_buf' along 'in_depth' dimension.
133 //
134 // filter_in:
135 //   [filter_rows, filter_cols, in_depth, out_depth]
136 //
137 // filter_buf:
138 //   [base_filter_rows, base_filter_cols, in_depth]
139 //
140 template <typename T>
141 struct CopyFilterDepth {
operator ()tensorflow::CopyFilterDepth142   void operator()(const Conv2DArgs& args, const T* filter_in, T* filter_buf) {
143     typedef typename Eigen::internal::packet_traits<T>::type Packet;
144     static constexpr int64 kPacketSize = (sizeof(Packet) / sizeof(T));
145 
146     const int64 vectorized_size = args.in_depth / kPacketSize;
147     const int64 scalar_size = args.in_depth % kPacketSize;
148     const int64 input_stride = args.out_depth * kPacketSize;
149 
150     // Copy vectorized portion of depth dimension.
151     for (int64 d = 0; d < vectorized_size; ++d) {
152       auto v = Eigen::internal::pgather<T, Packet>(filter_in + d * input_stride,
153                                                    args.out_depth);
154       Eigen::internal::pstoreu<T>(filter_buf + d * kPacketSize, v);
155     }
156     // Copy scalar portion of inner dimension.
157     const int64 in_scalar_base = vectorized_size * input_stride;
158     const int64 buf_scalar_base = vectorized_size * kPacketSize;
159     for (int64 d = 0; d < scalar_size; ++d) {
160       filter_buf[buf_scalar_base + d] =
161           filter_in[in_scalar_base + d * args.out_depth];
162     }
163   }
164 };
165 
166 // Computes transform of 'num_filters' from 'filter_in' starting at 'od_start'.
167 // Intermediate results (i.e. output of MatMul('transform_matrix', 'filter_in'))
168 // are stored in 'out_buffer'. The final result is copied from 'out_buffer' to
169 // 'filter_out' at the coordinate stride required by the transformed filter
170 // data layout.
171 //
172 // filter_in:
173 //   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
174 //    in_depth]
175 //
176 // filter_out:
177 //   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
178 //
179 // transform_matrix:
180 //   [tile_spatial_size, base_filter_spatial_size]
181 //
182 // out_buffer:
183 //   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]
184 
185 template <typename T>
186 struct ComputeFilterRangeTransform {
187   typedef typename Eigen::internal::packet_traits<T>::type Packet;
188   static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));
189 
190   typedef Eigen::Map<
191       Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
192       MatrixMap;
193   typedef Eigen::Map<
194       const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
195       ConstMatrixMap;
196 
operator ()tensorflow::ComputeFilterRangeTransform197   void operator()(const Conv2DArgs& args,
198                   const DeepConv2DTransform<T>* transform, const int64 od_start,
199                   const int64 num_filters, const int64 shard_rows,
200                   const int64 shard_cols, const T* filter_in,
201                   const int64 in_stride, const int64 out_stride,
202                   const T* transform_matrix, T* out_buffer, T* filter_out) {
203     namespace ei = Eigen::internal;
204 
205     const int64 in_depth = args.in_depth;
206     const int64 base_filter_rows = transform->filter_shape().rows;
207     const int64 base_filter_cols = transform->filter_shape().cols;
208     const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;
209     const int64 tile_rows = transform->input_shape().rows;
210     const int64 tile_cols = transform->input_shape().cols;
211     const int64 tile_spatial_size = tile_rows * tile_cols;
212 
213     // Compute transform of 'num_filters' by 'transform_matrix'.
214     ConstMatrixMap A(transform_matrix, tile_spatial_size,
215                      base_filter_spatial_size);
216     ConstMatrixMap B(filter_in, base_filter_spatial_size, in_stride);
217     MatrixMap C(out_buffer, tile_spatial_size, in_stride);
218 
219     C.noalias() = A * B;
220 
221     // Copy 'out_buffer' to 'filter_out' at required filter output stride.
222     const int64 scalar_size = in_depth % kPacketSize;
223     const int64 vectorized_size = in_depth / kPacketSize;
224 
225     const int64 shard_stride = args.in_depth;
226     const int64 out_depth_stride = shard_rows * shard_cols * shard_stride;
227 
228     for (int64 od = 0; od < num_filters; ++od) {
229       const int64 out_depth_buf_base = od * out_depth_stride;
230       const int64 out_depth_base = (od_start + od) * out_depth_stride;
231 
232       // TODO(andydavis) Shard filters that are multiples of base filter sizes.
233       for (int64 s_r = 0; s_r < shard_rows; ++s_r) {
234         for (int64 s_c = 0; s_c < shard_cols; ++s_c) {
235           const int64 shard_base = shard_stride * (s_r * shard_cols + s_c);
236 
237           for (int64 i = 0; i < tile_spatial_size; ++i) {
238             const int64 in_base =
239                 i * in_stride + out_depth_buf_base + shard_base;
240             const int64 out_base = i * out_stride + out_depth_base + shard_base;
241             // Copy vectorized portion of 'in_depth'.
242             for (int64 d = 0; d < vectorized_size; ++d) {
243               auto v =
244                   ei::ploadu<Packet>(out_buffer + in_base + d * kPacketSize);
245               ei::pstoreu<T>(filter_out + out_base + d * kPacketSize, v);
246             }
247             // Transform scalar portion of 'in_depth'.
248             const int64 scalar_base = vectorized_size * kPacketSize;
249             for (int64 d = 0; d < scalar_size; ++d) {
250               filter_out[out_base + scalar_base + d] =
251                   out_buffer[in_base + scalar_base + d];
252             }
253           }
254         }
255       }
256     }
257   }
258 };
259 
260 // Transforms 'num_filters' from 'filter_in', starting at 'od_start'.
261 // For each filter in 'num_filters', copies data for all filter shards from
262 // 'filter_in' into 'filter_buf', adding zero-padding as needed.
263 // Calls ComputeFilterRangeTransform to compute filter transform of data
264 // in 'filter_buf' by 'transform_matrix', storing the result in 'filter_out'.
265 //
266 // filter_in:
267 //   [filter_rows, filter_cols, in_depth, out_depth]
268 //
269 // filter_out:
270 //   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
271 //
// filter_buf:
273 //   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
274 //    in_depth]
275 //
276 // transform_matrix:
277 //   [tile_spatial_size, base_filter_spatial_size]
278 //
279 // out_buffer:
280 //   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]
281 //
282 
template <typename T>
struct TransformFilterRange {
  // Stages filters [od_start, od_limit) from 'filter_in' into 'filter_buf'
  // (one contiguous 'in_depth' run per filter, shard and base-filter
  // coordinate), then hands 'filter_buf' to ComputeFilterRangeTransform, which
  // writes the transformed filters to 'filter_out' using 'out_buffer' as
  // scratch space.
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform, const int64 od_start,
                  const int64 od_limit, const T* filter_in,
                  const T* transform_matrix, T* out_buffer, T* filter_buf,
                  T* filter_out) {
    const int64 num_filters = od_limit - od_start;
    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;
    const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;

    // Compute number of filter shards.
    // One shard always exists; each further shard covers up to 2 additional
    // filter rows/cols beyond the base filter size (ceiling division by 2).
    // NOTE(review): the constant 2 presumably equals the transform's output
    // tile size -- confirm before supporting other transforms.
    const int64 residual_row =
        std::max(int64{0}, args.filter_rows - base_filter_rows);
    const int64 shard_rows = 1 + (residual_row + 2 - 1) / 2;

    const int64 residual_col =
        std::max(int64{0}, args.filter_cols - base_filter_cols);
    const int64 shard_cols = 1 + (residual_col + 2 - 1) / 2;

    // Compute strides to be used for input and output IO.
    // shard_stride:      elements per (filter, shard) entry.
    // out_depth_stride:  elements per filter (all of its shards).
    // coord_stride:      elements per tile coordinate in 'filter_out'.
    // filter_buf_stride: elements per base-filter coordinate in 'filter_buf'.
    const int64 shard_stride = args.in_depth;
    const int64 out_depth_stride = shard_rows * shard_cols * shard_stride;
    const int64 coord_stride = out_depth_stride * args.out_depth;
    const int64 filter_buf_stride =
        num_filters * shard_rows * shard_cols * args.in_depth;
    const int64 tile_stride_rows = transform->output_shape().rows;
    const int64 tile_stride_cols = transform->output_shape().cols;

    // Zero the staging buffer so that every position skipped by the copy loop
    // below reads as zero padding.
    const int64 filter_buf_size = base_filter_spatial_size * num_filters *
                                  shard_rows * shard_cols * args.in_depth;
    memset(filter_buf, 0, sizeof(T) * filter_buf_size);

    // Copy filter range into 'filter_buf'.
    for (int64 od = 0; od < num_filters; ++od) {
      const int64 out_depth_base = od * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64 s_r = 0; s_r < shard_rows; ++s_r) {
        // Every shard row after the first starts copying at base row 1, so
        // base row 0 of those shards stays zero.
        const int64 row_offset = s_r == 0 ? 0 : 1;

        for (int64 s_c = 0; s_c < shard_cols; ++s_c) {
          // Same rule for columns: base col 0 is skipped after the first shard.
          const int64 col_offset = s_c == 0 ? 0 : 1;
          // Filter coordinate that corresponds to base coordinate (0, 0) of
          // this shard.
          const int64 f_r_start = s_r * tile_stride_rows;
          const int64 f_c_start = s_c * tile_stride_cols;

          const int64 shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64 b_r = row_offset; b_r < base_filter_rows; ++b_r) {
            const int64 f_r = f_r_start + b_r;
            // Rows past the end of the filter are left as zero padding.
            if (f_r >= args.filter_rows) continue;

            for (int64 b_c = col_offset; b_c < base_filter_cols; ++b_c) {
              const int64 f_c = f_c_start + b_c;
              // Columns past the end of the filter are left as zero padding.
              if (f_c >= args.filter_cols) continue;

              // Offset of element (f_r, f_c, 0, od_start + od) in 'filter_in',
              // laid out as [filter_rows, filter_cols, in_depth, out_depth].
              const int64 in_index =
                  args.out_depth *
                      (args.in_depth * (f_r * args.filter_cols + f_c)) +
                  (od_start + od);

              // Offset of element (b_r, b_c, od, s_r, s_c, 0) in 'filter_buf'.
              const int64 buf_index =
                  filter_buf_stride * (b_r * base_filter_cols + b_c) +
                  out_depth_base + shard_base;

              CopyFilterDepth<T>()(args, filter_in + in_index,
                                   filter_buf + buf_index);
            }
          }
        }
      }
    }

    // Compute filter transform of data in 'filter_buf' by 'transform_matrix'.
    // Intermediate results are stored in 'out_buffer'.
    // Final results are stored in 'filter_out'.
    ComputeFilterRangeTransform<T>()(args, transform, od_start, num_filters,
                                     shard_rows, shard_cols, filter_buf,
                                     filter_buf_stride, coord_stride,
                                     transform_matrix, out_buffer, filter_out);
  }
};
366 
367 // Transforms all filters from 'filter_in', storing result in 'filter_out'.
368 //
369 // filter_in:
370 //   [filter_rows, filter_cols, in_depth, out_depth]
371 //
372 // filter_out:
373 //   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
374 //
375 template <typename T>
376 struct TransformFilters {
operator ()tensorflow::TransformFilters377   void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
378                   const DeepConv2DTransform<T>* transform,
379                   const int64 filter_shards_row, const int64 filter_shards_col,
380                   const T* filter_in, T* filter_out) {
381     const int64 in_depth = args.in_depth;
382     const int64 out_depth = args.out_depth;
383 
384     const int64 tile_rows = transform->input_shape().rows;
385     const int64 tile_cols = transform->input_shape().cols;
386     const int64 tile_spatial_size = tile_rows * tile_cols;
387 
388     const int64 base_filter_rows = transform->filter_shape().rows;
389     const int64 base_filter_cols = transform->filter_shape().cols;
390     const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;
391 
392     const int64 filter_shards_total = filter_shards_row * filter_shards_col;
393 
394     // Calculate filter transform batch based on cache/filter sizes.
395 
396     // Cache budget (based on L2 cache size = 256KB).
397     // TODO(andydavis) Read cache size from system.
398     const int64 cache_size = (256LL << 10) / sizeof(T);
399 
400     // Fixed cost.
401     const int64 filter_transform_matrix_size =
402         tile_spatial_size * base_filter_spatial_size;
403 
404     // Per-filter costs.
405     const int64 filter_total_size =
406         base_filter_spatial_size * in_depth * filter_shards_total;
407 
408     const int64 filter_transform_buffer_size =
409         base_filter_spatial_size * filter_shards_total * in_depth;
410 
411     const int64 filter_out_buf_size =
412         tile_spatial_size * filter_shards_total * in_depth;
413 
414     // Total per-filter costs.
415     const int64 per_filter_cost =
416         filter_total_size + filter_transform_buffer_size + filter_out_buf_size;
417 
418     // Remove fixed cost and divide by per-filter cost.
419     const int64 num_filters_cache =
420         std::max(int64{1},
421                  (cache_size - filter_transform_matrix_size) / per_filter_cost);
422     const int64 num_filters_transform = std::min(out_depth, num_filters_cache);
423 
424     // Allocate buffer for filter transform matrix:
425     //   [tile_spatial_size, base_filter_spatial_size]
426     Tensor filter_transform_matrix;
427     OP_REQUIRES_OK(
428         ctx, ctx->allocate_temp(
429                  DataTypeToEnum<T>::value,
430                  TensorShape({tile_spatial_size, base_filter_spatial_size}),
431                  &filter_transform_matrix));
432     T* transform_matrix = filter_transform_matrix.template flat<T>().data();
433     transform->GetFilterTransformMatrix(
434         tile_spatial_size, base_filter_spatial_size, transform_matrix);
435 
436     auto shard = [&ctx, &args, &transform, &base_filter_rows, &base_filter_cols,
437                   &num_filters_transform, &in_depth, &filter_shards_row,
438                   &filter_shards_col, &tile_spatial_size, &filter_in,
439                   &transform_matrix, &filter_out](int64 start, int64 limit) {
440       // Allocate buffer for pre-processed filter:
441       //   [base_filter_rows, base_filter_cols, num_filters_transform, in_depth]
442       //
443       Tensor filter_transform_buffer;
444       OP_REQUIRES_OK(ctx,
445                      ctx->allocate_temp(
446                          DataTypeToEnum<T>::value,
447                          TensorShape({base_filter_rows, base_filter_cols,
448                                       num_filters_transform, filter_shards_row,
449                                       filter_shards_col, in_depth}),
450                          &filter_transform_buffer));
451       T* filter_buf = filter_transform_buffer.template flat<T>().data();
452 
453       // Allocate buffer for output filter transform matrix:
454       //   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
455       Tensor filter_output_buffer;
456       OP_REQUIRES_OK(
457           ctx,
458           ctx->allocate_temp(
459               DataTypeToEnum<T>::value,
460               TensorShape({tile_spatial_size, num_filters_transform,
461                            filter_shards_row, filter_shards_col, in_depth}),
462               &filter_output_buffer));
463       T* out_buffer = filter_output_buffer.template flat<T>().data();
464 
465       const int64 num_filters = limit - start;
466       const int64 od_unroll = num_filters_transform;
467       const int64 od_unroll_limit = (num_filters / od_unroll) * od_unroll;
468 
469       for (int64 od = start; od < od_unroll_limit; od += od_unroll) {
470         TransformFilterRange<T>()(args, transform, od, od + od_unroll,
471                                   filter_in, transform_matrix, out_buffer,
472                                   filter_buf, filter_out);
473       }
474 
475       if (od_unroll_limit < limit) {
476         TransformFilterRange<T>()(args, transform, od_unroll_limit, limit,
477                                   filter_in, transform_matrix, out_buffer,
478                                   filter_buf, filter_out);
479       }
480     };
481     auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
482 
483     const int64 shard_cost = args.filter_rows * args.filter_cols * in_depth *
484                              filter_shards_total * tile_spatial_size;
485     // TODO(andydavis) Resolve performance of multi-threaded filter transforms.
486     Shard(1, worker_threads.workers, out_depth, shard_cost, shard);
487   }
488 };
489 
490 // Packs transformed filters stored in 'lhs_input' into 'lhs_block' in a
491 // gemm-kernel friendly data layout.
492 //
493 // Data layout for 'lhs_block':
494 //   [out_depth, shard_rows, shard_cols, in_depth].
495 
496 template <typename T>
497 class GemmFilterPacker {
498  public:
499   typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::RowMajor>
500       LhsMapper;
501   typedef Eigen::internal::gebp_traits<T, T> Traits;
502   Eigen::internal::gemm_pack_lhs<
503       T, int64, LhsMapper, Traits::mr, Traits::LhsProgress,
504       typename Traits::LhsPacket4Packing, Eigen::RowMajor>
505       pack_lhs;
506 
GemmFilterPacker(const int64 rows,const int64 depth,const T * lhs_input,T * lhs_block)507   GemmFilterPacker(const int64 rows, const int64 depth, const T* lhs_input,
508                    T* lhs_block)
509       : rows_(rows),
510         depth_(depth),
511         lhs_block_(lhs_block),
512         lhs_mapper_(lhs_input, depth_) {}
513 
Run()514   void Run() { pack_lhs(lhs_block_, lhs_mapper_, depth_, rows_); }
515 
516  private:
517   const int64 rows_;
518   const int64 depth_;
519   T* lhs_block_;
520   LhsMapper lhs_mapper_;
521 };
522 
523 // Packs transformed filter stored in 'filter_transform_data' into
524 // 'packed_filters' to be used by GemmState.
525 template <typename T>
526 struct PackFilters {
operator ()tensorflow::PackFilters527   void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
528                   const int64 tile_spatial_size, const int64 filter_shards_row,
529                   const int64 filter_shards_col, const T* filter_transform_data,
530                   std::vector<Tensor>* packed_filters) {
531     const int64 in_depth = args.in_depth;
532     const int64 out_depth = args.out_depth;
533     const int64 num_filters = filter_shards_row * filter_shards_col * out_depth;
534 
535     auto shard = [&ctx, &packed_filters, &filter_transform_data, &in_depth,
536                   &out_depth, &filter_shards_row, &filter_shards_col,
537                   &num_filters](int64 start, int64 limit) {
538       const int64 filter_coord_stride = num_filters * in_depth;
539       for (int64 i = start; i < limit; ++i) {
540         // Allocate filter buffer [out_depth, shard_rows, shard_cols, in_depth].
541         OP_REQUIRES_OK(
542             ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
543                                     TensorShape({out_depth, filter_shards_row,
544                                                  filter_shards_col, in_depth}),
545                                     &(*packed_filters)[i]));
546         T* packed_filter = (*packed_filters)[i].template flat<T>().data();
547         // Pack filters.
548         GemmFilterPacker<T> packer(
549             num_filters, in_depth,
550             filter_transform_data + i * filter_coord_stride, packed_filter);
551         packer.Run();
552       }
553     };
554     auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
555     Shard(worker_threads.num_threads, worker_threads.workers, tile_spatial_size,
556           num_filters * in_depth, shard);
557   }
558 };
559 
560 // Computes the product of filters stored in 'lhs_block' and input tiles
561 // stored in 'rhs_block', storing output in 'out_buffer'.
562 //
563 // Data layout for 'lhs_block':
564 //   [out_depth, shard_rows, shard_cols, in_depth].
565 //
566 // Data layout for 'rhs_block':
567 //   [num_tiles, in_depth]
568 //
569 // Data layout for 'out_buffer':
570 //   [num_tiles, out_depth, shard_rows, shard_cols]
571 
572 template <typename T>
573 class GemmState {
574  public:
575   typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::ColMajor>
576       RhsMapper;
577   typedef Eigen::internal::blas_data_mapper<T, int64, Eigen::ColMajor>
578       OutputMapper;
579   typedef Eigen::internal::gebp_traits<T, T> Traits;
580 
581   Eigen::internal::gemm_pack_rhs<T, int64, RhsMapper, Traits::nr,
582                                  Eigen::ColMajor>
583       pack_rhs;
584   Eigen::internal::gebp_kernel<T, T, int64, OutputMapper, Traits::mr,
585                                Traits::nr, false, false>
586       gebp;
587 
GemmState(const int64 rows,const int64 cols,const int64 depth,const int64 out_buffer_size,const T * lhs_block,const T * rhs_input,T * rhs_block,T * out_buffer)588   GemmState(const int64 rows, const int64 cols, const int64 depth,
589             const int64 out_buffer_size, const T* lhs_block, const T* rhs_input,
590             T* rhs_block, T* out_buffer)
591       : rows_(rows),
592         cols_(cols),
593         depth_(depth),
594         out_buffer_size_(out_buffer_size),
595         lhs_block_(lhs_block),
596         rhs_block_(rhs_block),
597         out_buffer_(out_buffer),
598         rhs_mapper_(rhs_input, depth_),
599         out_mapper_(out_buffer, rows_) {}
600 
PackRhs()601   void PackRhs() { pack_rhs(rhs_block_, rhs_mapper_, depth_, cols_); }
602 
Compute()603   void Compute() {
604     memset(out_buffer_, 0, sizeof(T) * out_buffer_size_);
605     gebp(out_mapper_, lhs_block_, rhs_block_, rows_, depth_, cols_, 1.0);
606   }
607 
608  private:
609   const int64 rows_;
610   const int64 cols_;
611   const int64 depth_;
612   const int64 out_buffer_size_;
613   const T* lhs_block_;
614   T* rhs_block_;
615   T* out_buffer_;
616   RhsMapper rhs_mapper_;
617   OutputMapper out_mapper_;
618 };
619 
620 // Copies an input tile from 'input' into 'tile_buffer'.
621 //
622 // input:
623 //   [in_rows, in_cols, in_depth]
624 //
625 // tile_buffer:
626 //   [tile_rows, tile_cols, num_tiles, in_depth]
627 
628 template <typename T>
629 struct CopyInputTile {
operator ()tensorflow::CopyInputTile630   void operator()(const Conv2DArgs& args,
631                   const DeepConv2DTransform<T>* transform,
632                   const int64 num_tiles, const int64 in_r_start,
633                   const int64 in_c_start, const T* input, T* tile_buffer) {
634     typedef typename Eigen::internal::packet_traits<T>::type Packet;
635     static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));
636 
637     const int64 tile_rows = transform->input_shape().rows;
638     const int64 tile_cols = transform->input_shape().cols;
639     const int64 coord_stride = num_tiles * args.in_depth;
640 
641     // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
642     const int64 input_vectorized_size =
643         (args.in_depth / kPacketSize) * kPacketSize;
644     const int64 input_scalar_size = args.in_depth % kPacketSize;
645 
646     for (int64 r = 0; r < tile_rows; ++r) {
647       const int64 in_r = in_r_start + r;
648       if (in_r < 0 || in_r >= args.in_rows) continue;
649 
650       for (int64 c = 0; c < tile_cols; ++c) {
651         const int64 in_c = in_c_start + c;
652         if (in_c < 0 || in_c >= args.in_cols) continue;
653 
654         auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth;
655         auto* tile = tile_buffer + coord_stride * (r * tile_rows + c);
656         // Copy vectorized portion of depth dimension.
657         for (int64 d = 0; d < input_vectorized_size; d += kPacketSize) {
658           auto v = Eigen::internal::ploadu<Packet>(in + d);
659           Eigen::internal::pstoreu<T>(tile, v);
660           tile += kPacketSize;
661         }
662         // Copy scalar portion of inner dimension.
663         for (int64 d = 0; d < input_scalar_size; ++d) {
664           tile[d] = in[input_vectorized_size + d];
665         }
666       }
667     }
668   }
669 };
670 
671 // Transforms 'num_tiles' tiles from 'input' by 'transform_matrix', storing the
672 // final result in 'tile_transform'.
673 // Intermediate results are stored in 'tile_buffer'.
674 //
675 // input:
676 //   [in_rows, in_cols, in_depth]
677 // tile_buffer:
678 //   [tile_rows, tile_cols, num_tiles, in_depth]
679 // tile_transform_matrix:
680 //   [tile_spatial_size, tile_spatial_size]
681 // tile_transform:
682 //   [tile_rows, tile_cols, num_tiles, in_depth]
683 
684 template <typename T>
685 struct TransformInputTiles {
686   typedef Eigen::Map<
687       Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
688       MatrixMap;
689   typedef Eigen::Map<
690       const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
691       ConstMatrixMap;
692 
operator ()tensorflow::TransformInputTiles693   void operator()(const Conv2DArgs& args,
694                   const DeepConv2DTransform<T>* transform,
695                   const int64 num_tiles, const int64 in_r_start,
696                   const int64 in_c_start, const T* input,
697                   const T* transform_matrix, T* tile_buffer,
698                   T* tile_transform) {
699     const int64 tile_rows = transform->input_shape().rows;
700     const int64 tile_cols = transform->input_shape().cols;
701     const int64 tile_spatial_size = tile_rows * tile_cols;
702     const int64 tile_stride_cols = transform->output_shape().cols;
703     const int64 coord_stride = num_tiles * args.in_depth;
704     const int64 num_tiles_stride = args.in_depth;
705 
706     memset(tile_buffer, 0, sizeof(T) * tile_spatial_size * coord_stride);
707     const int64 in_r = in_r_start;
708     for (int64 t = 0; t < num_tiles; ++t) {
709       const int64 num_tiles_base = t * num_tiles_stride;
710       const int64 in_c = in_c_start + t * tile_stride_cols;
711       CopyInputTile<T>()(args, transform, num_tiles, in_r, in_c, input,
712                          tile_buffer + num_tiles_base);
713     }
714 
715     ConstMatrixMap A(transform_matrix, tile_spatial_size, tile_spatial_size);
716     ConstMatrixMap B(tile_buffer, tile_spatial_size, coord_stride);
717     MatrixMap C(tile_transform, tile_spatial_size, coord_stride);
718 
719     C.noalias() = A * B;
720   }
721 };
722 
723 // Transforms output tiles from buffer by 'out_transform_matrix', storing
724 // final result in 'output' (intermediate results stored in 'out_buffer').
725 //
726 // out_buffer:
727 //   [tile_rows, tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
728 //
729 // output transform buffer:
730 //  [out_tile_rows, out_tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
731 //
732 // output:
733 //   [out_rows, out_cols, out_depth]
734 //
735 
736 template <typename T>
737 struct TransformOutputTile {
738   typedef Eigen::Map<
739       Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
740       MatrixMap;
741   typedef Eigen::Map<
742       const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
743       ConstMatrixMap;
744 
operator ()tensorflow::TransformOutputTile745   void operator()(const Conv2DArgs& args,
746                   const DeepConv2DTransform<T>* transform,
747                   const int64 num_tiles, const int64 in_r, const int64 in_c,
748                   const int64 filter_shards_row, const int64 filter_shards_col,
749                   const T* out_transform_matrix, const T* out_buffer,
750                   T* out_transform_buffer, T* output) {
751     const int64 tile_rows = transform->input_shape().rows;
752     const int64 tile_cols = transform->input_shape().cols;
753     const int64 tile_spatial_size = tile_rows * tile_cols;
754 
755     const int64 out_buf_stride =
756         num_tiles * args.out_depth * filter_shards_row * filter_shards_col;
757 
758     const int64 out_tile_rows = transform->output_shape().rows;
759     const int64 out_tile_cols = transform->output_shape().cols;
760     const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;
761 
762     // Compute output transform.
763     ConstMatrixMap A(out_transform_matrix, out_tile_spatial_size,
764                      tile_spatial_size);
765     ConstMatrixMap B(out_buffer, tile_spatial_size, out_buf_stride);
766     MatrixMap C(out_transform_buffer, out_tile_spatial_size, out_buf_stride);
767 
768     C.noalias() = A * B;
769 
770     const int64 tile_stride_rows = transform->output_shape().rows;
771     const int64 tile_stride_cols = transform->output_shape().cols;
772 
773     const int64 out_depth_stride = filter_shards_row * filter_shards_col;
774     const int64 num_tiles_stride = args.out_depth * out_depth_stride;
775 
776     // Copy transformed output from 'out_transform_buffer' to proper index
777     // in 'output'. Note that some outputs at boundaries can be discarded.
778     for (int64 t = 0; t < num_tiles; ++t) {
779       const int64 tile_base = t * num_tiles_stride;
780 
781       for (int64 od = 0; od < args.out_depth; ++od) {
782         const int64 out_depth_base = od * out_depth_stride;
783 
784         // TODO(andydavis) Update filter sharding scheme in the next CL.
785         for (int64 sr = 0; sr < filter_shards_row; ++sr) {
786           for (int64 sc = 0; sc < filter_shards_col; ++sc) {
787             const int64 shard_base = sr * filter_shards_col + sc;
788             const int64 out_buf_base = tile_base + out_depth_base + shard_base;
789 
790             // Calculate output indices and outputs to drop (if needed).
791             const int64 out_r_start =
792                 in_r + args.pad_rows - sr * tile_stride_rows;
793             // NOTE: The index 't' for 'num_tiles is used in index calculation
794             // for 'out_c_start' because we 'num_tiles' progresses along the
795             // column dimension.
796             const int64 out_c_start = (in_c + t * tile_stride_cols) +
797                                       args.pad_cols - sc * tile_stride_cols;
798 
799             if (out_r_start < 0 || out_r_start >= args.out_rows ||
800                 out_c_start < 0 || out_c_start >= args.out_cols) {
801               continue;  // Skip un-needed outputs.
802             }
803 
804             // Increment output if not first filter shard.
805             const bool inc_output = (sr == 0 && sc == 0) ? false : true;
806 
807             for (int64 ot_row = 0; ot_row < out_tile_rows; ++ot_row) {
808               const int64 out_r = out_r_start + ot_row;
809               if (out_r >= args.out_rows) continue;
810 
811               for (int64 ot_col = 0; ot_col < out_tile_cols; ++ot_col) {
812                 const int64 out_c = out_c_start + ot_col;
813                 if (out_c >= args.out_cols) continue;
814 
815                 // Calculate out tile indexl
816                 const int64 out_buf_index = ot_row * out_tile_cols + ot_col;
817                 // Read output value from buffer.
818                 const T out_val =
819                     out_transform_buffer[out_buf_base +
820                                          out_buf_index * out_buf_stride];
821                 // Calculate output index.
822                 const int64 output_index =
823                     args.out_depth * (out_r * args.out_cols + out_c) + od;
824                 // Update output.
825                 if (inc_output) {
826                   output[output_index] += out_val;
827                 } else {
828                   output[output_index] = out_val;
829                 }
830               }
831             }
832           }
833         }
834       }
835     }
836   }
837 };
838 
839 template <typename T>
840 struct Conv2DState {
Conv2DStatetensorflow::Conv2DState841   Conv2DState(const int64 tile_spatial_size, const int64 filter_shards_row,
842               const int64 filter_shards_col, const T* input,
843               const T* tile_transform_matrix, const T* output_transform_matrix,
844               T* buffer1, T* buffer2, T* packed_tile_buffer,
845               T* gemm_output_buffer)
846       : tile_spatial_size(tile_spatial_size),
847         filter_shards_row(filter_shards_row),
848         filter_shards_col(filter_shards_col),
849         input(input),
850         tile_transform_matrix(tile_transform_matrix),
851         output_transform_matrix(output_transform_matrix),
852         buffer1(buffer1),
853         buffer2(buffer2),
854         packed_tile_buffer(packed_tile_buffer),
855         gemm_output_buffer(gemm_output_buffer) {}
856 
857   const int64 tile_spatial_size;
858   const int64 filter_shards_row;
859   const int64 filter_shards_col;
860   const T* input;
861   const T* tile_transform_matrix;
862   const T* output_transform_matrix;
863   T* buffer1;
864   T* buffer2;
865   T* packed_tile_buffer;
866   T* gemm_output_buffer;
867 };
868 
869 // Computes Conv2D for 'num_tiles' input tiles from 'input' starting at
870 // (in_r, in_c), storing the results of the computation in 'output'.
871 // Details:
872 // *) Transforms 'num_tiles' input tiles into 'tile_transform_buffer'.
873 // *) Computes point-wise MatMuls of 'num_tiles' input tiles with all filters.
874 // *) Transforms output tiles, and stores result to 'output'.
875 
876 // TODO(andydavis) Maybe pass Conv2DState into TransformInput/Output functions.
877 template <typename T>
878 struct ComputeConv2D {
operator ()tensorflow::ComputeConv2D879   void operator()(const Conv2DArgs& args,
880                   const DeepConv2DTransform<T>* transform,
881                   const Conv2DState<T>& cs, const int64 in_r, const int64 in_c,
882                   const int64 num_tiles,
883                   const std::vector<Tensor>& packed_filters, const T* input,
884                   T* output) {
885     // Transform input tiles.
886     TransformInputTiles<T>()(args, transform, num_tiles, in_r, in_c, input,
887                              cs.tile_transform_matrix, cs.buffer1, cs.buffer2);
888 
889     // Compute element-wise product (each a MatMul): input tiles X filters.
890     const int64 in_depth = args.in_depth;
891     const int64 out_depth = args.out_depth;
892     const int64 num_filters =
893         cs.filter_shards_row * cs.filter_shards_col * out_depth;
894     const int64 tile_coord_stride = num_tiles * in_depth;
895     const int64 gemm_out_buf_size = num_tiles * num_filters;
896     const int64 gemm_out_buf_bytes = gemm_out_buf_size * sizeof(T);
897 
898     for (int64 i = 0; i < cs.tile_spatial_size; ++i) {
899       GemmState<T> gemm(num_filters, num_tiles, in_depth, gemm_out_buf_size,
900                         packed_filters[i].template flat<T>().data(),
901                         cs.buffer2 + i * tile_coord_stride,
902                         cs.packed_tile_buffer, cs.gemm_output_buffer);
903       // Pack tile buffer.
904       gemm.PackRhs();
905       // Compute product.
906       gemm.Compute();
907       // Copy to larger output buffer without alignment requirements.
908       memcpy(cs.buffer1 + i * gemm_out_buf_size, cs.gemm_output_buffer,
909              gemm_out_buf_bytes);
910     }
911 
912     // Transform output.
913     TransformOutputTile<T>()(args, transform, num_tiles, in_r, in_c,
914                              cs.filter_shards_row, cs.filter_shards_col,
915                              cs.output_transform_matrix, cs.buffer1, cs.buffer2,
916                              output);
917   }
918 };
919 
920 namespace functor {
921 
922 // Conv2D operation specialized for deep convolutions (i.e. large
923 // in_depth * out_depth).
924 // Details:
925 // *) Transforms and packs filters from 'filter' in parallel.
926 // *) Computes Conv2D parallelized across 'batch' dimension.
927 //   *) Each thread loops over images in its batch shard, copying 'num_tiles'
928 //      input tiles into a local buffer, and computing the Conv2D output of
929 //      these tiles by all filters.
930 
931 // TODO(andydavis) Improve the performance of boundary cases where the input
932 // tile extends past the limit, and wasted outputs are computed. This overhead
933 // is at most 2/n, where 'n' is the max(out_rows, out_cols), and so is worse
934 // for smaller spatial sizes.
935 // TODO(andydavis) Improve the performance of sharded filters.
936 template <typename T>
937 struct DeepConv2D<CPUDevice, T> {
operator ()tensorflow::functor::DeepConv2D938   void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
939                   const T* filter, T* output) {
940     // TODO(andydavis) Add function to select transform based on conv params.
941     std::unique_ptr<DeepConv2DTransform<T>> transform(new WinogradTransform<T>);
942 
943     const int64 in_depth = args.in_depth;
944     const int64 out_depth = args.out_depth;
945 
946     const int64 tile_rows = transform->input_shape().rows;
947     const int64 tile_cols = transform->input_shape().cols;
948     const int64 tile_spatial_size = tile_rows * tile_cols;
949 
950     const int64 out_tile_rows = transform->output_shape().rows;
951     const int64 out_tile_cols = transform->output_shape().cols;
952     const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;
953 
954     const int64 base_filter_rows = transform->filter_shape().rows;
955 
956     const int64 filter_residual_row =
957         std::max(int64{0}, args.filter_rows - base_filter_rows);
958     const int64 filter_shards_row = 1 + (filter_residual_row + 2 - 1) / 2;
959 
960     const int64 filter_residual_col =
961         std::max(int64{0}, args.filter_cols - base_filter_rows);
962     const int64 filter_shards_col = 1 + (filter_residual_col + 2 - 1) / 2;
963 
964     // Allocate buffer for transformed filters.
965     Tensor filter_transform;
966     OP_REQUIRES_OK(
967         ctx, ctx->allocate_temp(
968                  DataTypeToEnum<T>::value,
969                  TensorShape({tile_rows, tile_cols, out_depth,
970                               filter_shards_row, filter_shards_col, in_depth}),
971                  &filter_transform));
972     T* filter_transform_data = filter_transform.template flat<T>().data();
973 
974     // Transform filters.
975     TransformFilters<T>()(ctx, args, transform.get(), filter_shards_row,
976                           filter_shards_col, filter, filter_transform_data);
977 
978     // Pack filters.
979     std::vector<Tensor> packed_filters(tile_spatial_size);
980     PackFilters<T>()(ctx, args, tile_spatial_size, filter_shards_row,
981                      filter_shards_col, filter_transform_data, &packed_filters);
982 
983     // Allocate buffer for tile transform matrix.
984     Tensor tile_transform_matrix_tensor;
985     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
986                             DataTypeToEnum<T>::value,
987                             TensorShape({tile_spatial_size, tile_spatial_size}),
988                             &tile_transform_matrix_tensor));
989     T* tile_transform_matrix =
990         tile_transform_matrix_tensor.template flat<T>().data();
991     transform->GetInputTransformMatrix(tile_spatial_size, tile_spatial_size,
992                                        tile_transform_matrix);
993 
994     // Allocate buffer for output transform matrix.
995     Tensor output_transform_matrix_tensor;
996     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
997                                            TensorShape({out_tile_spatial_size,
998                                                         tile_spatial_size}),
999                                            &output_transform_matrix_tensor));
1000     T* output_transform_matrix =
1001         output_transform_matrix_tensor.template flat<T>().data();
1002     transform->GetOutputTransformMatrix(
1003         out_tile_spatial_size, tile_spatial_size, output_transform_matrix);
1004 
1005     auto shard = [&ctx, &args, &transform, &packed_filters, &in_depth,
1006                   out_depth, out_tile_rows, out_tile_cols, filter_shards_row,
1007                   filter_shards_col, tile_spatial_size, &input,
1008                   &tile_transform_matrix, &output_transform_matrix,
1009                   &output](int64 batch_start, int64 batch_limit) {
1010       const int64 row_tiles =
1011           (args.out_rows + out_tile_rows - 1) / out_tile_rows +
1012           filter_shards_row - 1;
1013       const int64 col_tiles =
1014           (args.out_cols + out_tile_cols - 1) / out_tile_cols +
1015           filter_shards_col - 1;
1016 
1017       // Calculate number of tiles to process together.
1018       const int64 filter_shard_size = filter_shards_row * filter_shards_col;
1019       const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;
1020 
1021       // Cache budget (based on L2 cache size = 256KB).
1022       // TODO(andydavis) Read cache size from the system.
1023       const int64 cache_size = (256LL << 10) / sizeof(T);
1024 
1025       // Fixed costs.
1026       const int64 tile_transform_matrix_size =
1027           tile_spatial_size * tile_spatial_size;
1028       const int64 output_transform_matrix_size =
1029           out_tile_spatial_size * tile_spatial_size;
1030       // Calculate cache reserve size.
1031       const int64 filter_depth_size = in_depth * out_depth * filter_shard_size;
1032       const bool small_filter = ((filter_depth_size * 100) / cache_size) <= 25;
1033       const int64 cache_reserve_size = small_filter ? filter_depth_size : 1024;
1034       // Calculate total fixed cost.
1035       const int64 total_fixed_cost = tile_transform_matrix_size +
1036                                      output_transform_matrix_size +
1037                                      cache_reserve_size;
1038 
1039       // Per-tile costs.
1040       const int64 buffer1_per_tile_size =
1041           tile_spatial_size * std::max(in_depth, out_depth * filter_shard_size);
1042       const int64 buffer2_per_tile_size =
1043           std::max(tile_spatial_size * in_depth,
1044                    out_tile_spatial_size * out_depth * filter_shard_size);
1045       const int64 packed_tile_per_tile_size = in_depth;
1046       const int64 gemm_out_per_tile_size = out_depth * filter_shard_size;
1047       const int64 total_per_tile_cost =
1048           buffer1_per_tile_size + buffer2_per_tile_size +
1049           packed_tile_per_tile_size + gemm_out_per_tile_size;
1050 
1051       const int64 num_tiles_cache = std::max(
1052           int64{4}, (cache_size - total_fixed_cost) / total_per_tile_cost);
1053       const int64 num_tiles = std::min(num_tiles_cache, col_tiles);
1054 
1055       // Allocate temporary buffer 'buffer1', which is first used for copying
1056       // input tiles, then re-used to buffer gemm output. Calculate the
1057       // required buffer size for 'buffer1', based on max buffer size required
1058       // between copying input tiles and buffering gemm product output.
1059       //   buffer1: [max(buf1_tile_size, buf1_out_size)]
1060       const int64 buffer1_tile_size = tile_spatial_size * num_tiles * in_depth;
1061       const int64 buffer1_out_size =
1062           tile_spatial_size * num_tiles * out_depth * filter_shard_size;
1063       const int64 buffer1_size = std::max(buffer1_tile_size, buffer1_out_size);
1064       Tensor buffer1_tensor;
1065       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
1066                                              TensorShape({buffer1_size}),
1067                                              &buffer1_tensor));
1068       T* buffer1 = buffer1_tensor.template flat<T>().data();
1069 
1070       // Allocate temporary buffer 'buffer2', which is first used for
1071       // transformed input tiles, then re-used for transformed output tiles.
1072       // Calculate required buffer size for 'buffer2' as max required buffer
1073       // between input and output transform buffer sizes.
1074       const int64 buffer2_tile_transform_size =
1075           tile_spatial_size * num_tiles * in_depth;
1076       const int64 buffer2_out_transform_size =
1077           out_tile_spatial_size * num_tiles * out_depth * filter_shard_size;
1078       const int64 buffer2_size =
1079           std::max(buffer2_tile_transform_size, buffer2_out_transform_size);
1080       Tensor buffer2_tensor;
1081       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
1082                                              TensorShape({buffer2_size}),
1083                                              &buffer2_tensor));
1084       T* buffer2 = buffer2_tensor.template flat<T>().data();
1085 
1086       // Allocate temporary buffer to store packed tiles for one coordinate.
1087       // packed tile buffer: [num_tiles, in_depth].
1088       Tensor packed_tile_tensor;
1089       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
1090                                              TensorShape({num_tiles, in_depth}),
1091                                              &packed_tile_tensor));
1092       T* packed_tile_buffer = packed_tile_tensor.template flat<T>().data();
1093 
1094       // Allocate temporary buffer for gemm output.
1095       // gemm output buffer [num_tiles, out_depth, shard_rows, shard_cols].
1096       Tensor gemm_output_tensor;
1097       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
1098                                              TensorShape({num_tiles, out_depth,
1099                                                           filter_shards_row,
1100                                                           filter_shards_col}),
1101                                              &gemm_output_tensor));
1102       T* gemm_output_buffer = gemm_output_tensor.template flat<T>().data();
1103 
1104       // Capture state needed for ComputeConv2D inner loop.
1105       Conv2DState<T> conv_state(tile_spatial_size, filter_shards_row,
1106                                 filter_shards_col, input, tile_transform_matrix,
1107                                 output_transform_matrix, buffer1, buffer2,
1108                                 packed_tile_buffer, gemm_output_buffer);
1109 
1110       const int64 row_pad = args.pad_rows;
1111       const int64 col_pad = args.pad_cols;
1112       const int64 unroll_col_limit = (col_tiles / num_tiles) * num_tiles;
1113 
1114       const int64 input_image_size = args.in_rows * args.in_cols * in_depth;
1115       const int64 output_image_size = args.out_rows * args.out_cols * out_depth;
1116 
1117       const int64 tile_stride_rows = transform->output_shape().rows;
1118       const int64 tile_stride_cols = transform->output_shape().cols;
1119 
1120       for (int64 b = batch_start; b < batch_limit; ++b) {
1121         const int64 in_base = b * input_image_size;
1122         const int64 out_base = b * output_image_size;
1123 
1124         for (int64 tile_r = 0; tile_r < row_tiles; ++tile_r) {
1125           const int64 in_r = tile_r * tile_stride_rows - row_pad;
1126 
1127           // Process unrolled tiles.
1128           for (int64 tile_c = 0; tile_c < unroll_col_limit;
1129                tile_c += num_tiles) {
1130             const int64 in_c = tile_c * tile_stride_cols - col_pad;
1131             ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
1132                                num_tiles, packed_filters, input + in_base,
1133                                output + out_base);
1134           }
1135           // Process remaining tiles.
1136           if (unroll_col_limit < col_tiles) {
1137             const int64 rem_tiles = col_tiles - unroll_col_limit;
1138             const int64 in_c = unroll_col_limit * tile_stride_cols - col_pad;
1139             ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
1140                                rem_tiles, packed_filters, input + in_base,
1141                                output + out_base);
1142           }
1143         }
1144       }
1145     };
1146     auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
1147     const int64 shard_cost = args.out_rows * args.out_cols * args.out_depth *
1148                              tile_spatial_size * args.in_depth;
1149     Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
1150           shard_cost, shard);
1151   }
1152 };
1153 
1154 }  // namespace functor
1155 
// Explicitly instantiate the CPU specialization of DeepConv2D for float.
template struct functor::DeepConv2D<CPUDevice, float>;
1157 
1158 }  // namespace tensorflow
1159