1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
18 
19 #include <limits>
20 #include <memory>
21 #include <vector>
22 
23 #include "mkldnn.hpp"
24 #include "tensorflow/core/framework/bounds_check.h"
25 #include "tensorflow/core/framework/numeric_op.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/register_types.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_shape.h"
30 #include "tensorflow/core/framework/tensor_slice.h"
31 #include "tensorflow/core/kernels/conv_grad_ops.h"
32 #include "tensorflow/core/kernels/ops_util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/gtl/array_slice.h"
35 #include "tensorflow/core/lib/strings/numbers.h"
36 #include "tensorflow/core/lib/strings/str_util.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/macros.h"
39 #include "tensorflow/core/util/mkl_util.h"
40 #include "tensorflow/core/util/padding.h"
41 #include "tensorflow/core/util/tensor_format.h"
42 
43 using mkldnn::convolution_direct;
44 using mkldnn::convolution_forward;
45 using mkldnn::prop_kind;
46 using mkldnn::stream;
47 
48 namespace tensorflow {
49 
50 class MklDnnConvUtil {
51  protected:
52   OpKernelContext* context_;  // We don't own this.
53   std::vector<int32> strides_;
54   std::vector<int32> dilations_;
55   Padding padding_;
56   TensorFormat data_format_;
57 
58  public:
59   MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
60                  Padding pad, TensorFormat fm,
61                  const std::vector<int32>& dilations, bool is_depthwise = false)
context_(context)62       : context_(context),
63         strides_(strides),
64         dilations_(dilations),
65         padding_(pad),
66         data_format_(fm) {}
67 
~MklDnnConvUtil()68   virtual ~MklDnnConvUtil() { context_ = nullptr; }
69 
70   // Calculate Convolution strides
GetStridesInMklOrder(memory::dims * strides)71   virtual inline void GetStridesInMklOrder(memory::dims* strides) {
72     // For now we take the stride from the second and third dimensions only
73     // (we do not support striding on the batch or depth dimension).
74     CHECK_NOTNULL(strides);
75     if (strides_.size() == 4) {
76       int stride_rows = GetTensorDim(strides_, data_format_, 'H');
77       int stride_cols = GetTensorDim(strides_, data_format_, 'W');
78       *strides = {stride_rows, stride_cols};
79     } else if (strides_.size() == 5) {
80       int stride_planes = GetTensorDim(strides_, data_format_, '0');
81       int stride_rows = GetTensorDim(strides_, data_format_, '1');
82       int stride_cols = GetTensorDim(strides_, data_format_, '2');
83       *strides = {stride_planes, stride_rows, stride_cols};
84     }
85   }
86 
87   // Calculate Convolution dilations
GetDilationsInMklOrder(memory::dims * dilations)88   virtual inline void GetDilationsInMklOrder(memory::dims* dilations) {
89     // For now we take the dilation from the second and third dimensions only
90     // (we do not support dilation on the batch or depth dimension).
91     CHECK_NOTNULL(dilations);
92     if (dilations_.size() == 4) {
93       int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
94       int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
95       *dilations = {dilations_rows, dilations_cols};
96     } else if (dilations_.size() == 5) {
97       int dilations_planes = GetTensorDim(dilations_, data_format_, '0');
98       int dilations_rows = GetTensorDim(dilations_, data_format_, '1');
99       int dilations_cols = GetTensorDim(dilations_, data_format_, '2');
100       *dilations = {dilations_planes, dilations_rows, dilations_cols};
101     }
102   }
103 
104   // Calculate Convolution input size in MKL-DNN order. MKL-DNN
105   // requires input in NCHW/NCDHW format. Function does not return anything.
106   // But errors arising from sanity checks are returned in context's
107   // status.
GetInputSizeInMklOrder(const TensorShape & input_shape,memory::dims * input_dims)108   virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
109                                              memory::dims* input_dims) {
110 #define CHECK_BOUNDS(val, err_msg)                                     \
111   do {                                                                 \
112     OP_REQUIRES(context_,                                              \
113                 FastBoundsCheck(val, std::numeric_limits<int>::max()), \
114                 errors::InvalidArgument(err_msg));                     \
115   } while (0)
116 
117     CHECK_NOTNULL(input_dims);
118 
119     // Input channel
120     int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
121     int input_depth = static_cast<int>(input_depth_raw);
122 
123     // Input batch
124     int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
125     CHECK_BOUNDS(input_batch_raw, "Input batch too large");
126     int input_batch = static_cast<int>(input_batch_raw);
127 
128     if (strides_.size() == 4) {  // NCHW format for Conv2D
129       // Input rows/height
130       int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
131       CHECK_BOUNDS(input_rows_raw, "Input rows too large");
132       int input_rows = static_cast<int>(input_rows_raw);
133 
134       // Input columns/width
135       int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
136       CHECK_BOUNDS(input_cols_raw, "Input cols too large");
137       int input_cols = static_cast<int>(input_cols_raw);
138 
139       // MKL-DNN always requires input in NCHW format Conv2D.
140       std::vector<int> mkldnn_sizes(4, -1);
141       mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
142       mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
143       mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
144       mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
145 
146       *input_dims = mkldnn_sizes;
147     } else if (strides_.size() == 5) {  // NCDHW format for Conv3D
148       // Input planes/third-dimension
149       int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0');
150       CHECK_BOUNDS(input_planes_raw, "Input depth too large");
151       int input_planes = static_cast<int>(input_planes_raw);
152 
153       // Input rows/height
154       int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1');
155       CHECK_BOUNDS(input_rows_raw, "Input rows too large");
156       int input_rows = static_cast<int>(input_rows_raw);
157 
158       // Input columns/width
159       int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2');
160       CHECK_BOUNDS(input_cols_raw, "Input cols too large");
161       int input_cols = static_cast<int>(input_cols_raw);
162 
163       // MKL-DNN always requires input in NCDHW format for Conv3D.
164       std::vector<int> mkldnn_sizes(5, -1);
165       mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
166       mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
167       mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
168       mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows;
169       mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols;
170 
171       *input_dims = mkldnn_sizes;
172     }
173 #undef CHECK_BOUNDS
174   }
175 
176   // Calculate Convolution filter size in MKL-DNN order.
177   // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
178   // Function does not return anything.
179   // But errors arising from sanity checks are returned in context's
180   // status. This function differs from GetConvFilterSizeInMklOrder in
181   // parameter for input - it accepts src_shape since Convolution Backward
182   // Input gets shape of input tensor rather than actual tensor (Convolution
183   // forward gets actual tensor as input).
184   //
185   // TODO(nhasabni): Add similar function for input and filter in MklShape.
GetFilterSizeInMklOrder(const TensorShape & input_shape,const TensorShape & filter_shape,memory::dims * filter_dims,bool is_depthwise)186   virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape,
187                                               const TensorShape& filter_shape,
188                                               memory::dims* filter_dims,
189                                               bool is_depthwise) {
190     CHECK_NOTNULL(filter_dims);
191 
192     OP_REQUIRES(context_, filter_shape.dims() == strides_.size(),
193                 errors::InvalidArgument((strides_.size() == 4)
194                                             ? "filter must be 4-dimensional: "
195                                             : "filter must be 5-dimensional: ",
196                                         filter_shape.DebugString()));
197 
198     for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) {
199       OP_REQUIRES(context_,
200                   FastBoundsCheck(filter_shape.dim_size(i),
201                                   std::numeric_limits<int>::max()),
202                   errors::InvalidArgument("filter too large"));
203     }
204 
205     int input_depth = GetTensorDim(input_shape, data_format_, 'C');
206 
207     if (strides_.size() == 4) {  // Conv2D
208       OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
209                   errors::InvalidArgument(
210                       "input and filter must have the same depth: ",
211                       input_depth, " vs ", filter_shape.dim_size(2)));
212 
213       // TF filter is always in (rows, cols, in_depth, out_depth) order.
214       int filter_rows =
215           static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_H));
216       int filter_cols =
217           static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_W));
218       int filter_in_depth =
219           static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_I));
220       int filter_out_depth =
221           static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_O));
222       // MKL-DNN always needs filter in OIHW format for regular convolutions
223       // and GOIHW for grouped/depthwise convolutions,
224       // OIHW = (out_depth, in_depth, rows, cols)
225       // GOIHW = (group, out_depth, in_depth, rows, cols)
226       // Specifically for depthwise G=filter_indepth, O=filter_outdepth, I=1
227       if (is_depthwise) {
228         std::vector<int> mkldnn_sizes(5, -1);
229         mkldnn_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth;
230         mkldnn_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth;
231         mkldnn_sizes[MKL_GROUP_FILTER_DIM_I] = 1;
232         mkldnn_sizes[MKL_GROUP_FILTER_DIM_H] = filter_rows;
233         mkldnn_sizes[MKL_GROUP_FILTER_DIM_W] = filter_cols;
234 
235         *filter_dims = mkldnn_sizes;
236       } else {
237         std::vector<int> mkldnn_sizes(4, -1);
238         mkldnn_sizes[MklDnnDims::Dim_O] = filter_out_depth;
239         mkldnn_sizes[MklDnnDims::Dim_I] = filter_in_depth;
240         mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
241         mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
242 
243         *filter_dims = mkldnn_sizes;
244       }
245     } else {  // Conv3D
246       OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3),
247                   errors::InvalidArgument(
248                       "input and filter must have the same depth: ",
249                       input_depth, " vs ", filter_shape.dim_size(3)));
250 
251       // TF filter is always in (planes, rows, cols, in_depth, out_depth) order.
252       int filter_planes =
253           static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_P));
254       int filter_rows =
255           static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_H));
256       int filter_cols =
257           static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_W));
258       int filter_in_depth =
259           static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_I));
260       int filter_out_depth =
261           static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_O));
262 
263       // MKL-DNN always needs filter in OIDHW format.
264       // OIDHW = (out_depth, in_depth, planes, rows, cols)
265       std::vector<int> mkldnn_sizes(5, -1);
266       mkldnn_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth;
267       mkldnn_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth;
268       mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
269       mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows;
270       mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols;
271 
272       *filter_dims = mkldnn_sizes;
273     }
274   }
275 
276   // Calculate Convolution filter size in MKL-DNN order.
277   // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW(Conv3D format.
278   // Function does not return anything. But errors arising from sanity
279   // checks are returned in context's status.
GetFilterSizeInMklOrder(size_t src_index,size_t filter_index,memory::dims * filter_dims,bool is_depthwise)280   virtual inline void GetFilterSizeInMklOrder(size_t src_index,
281                                               size_t filter_index,
282                                               memory::dims* filter_dims,
283                                               bool is_depthwise) {
284     CHECK_NOTNULL(filter_dims);
285     GetFilterSizeInMklOrder(GetTfShape(context_, src_index),
286                             GetTfShape(context_, filter_index), filter_dims,
287                             is_depthwise);
288   }
289 
290   // Calculate Bias size for 2D or 3D Convolution. Function does not
291   // return anything, but may set an error in context status.
GetBiasSizeInMklOrder(size_t bias_index,memory::dims * bias_dims)292   virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
293                                             memory::dims* bias_dims) {
294     const Tensor& bias = MklGetInput(context_, bias_index);
295     OP_REQUIRES(context_, bias.dims() == 1,
296                 errors::InvalidArgument("bias must be 1-dimensional: ",
297                                         bias.shape().DebugString()));
298 
299     *bias_dims = {static_cast<int>(bias.dim_size(0))};
300   }
301 
302   // Function to calculate output and padding size for 2D/3D convolution.
303   //
304   // Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
305   // MKL-DNN uses NCHW(Conv2D) or NCDHW(Conv3D) for output order.
306   // But TensorFlow output will be in NHWC||NCHW(Conv2D) or
307   // NDHWC||NCDHW(Conv3D) format depending on data format.
308   // Function also calculates left, right, top and bottom pads.
309   // Function does not return any status which is set with context status.
310   //
311   // TODO(nhasabni): Add similar function for input and filter in MklShape.
312   virtual inline void GetOutputAndPadSizeInMklOrder(
313       const TensorShape& input_shape, const TensorShape& filter_shape,
314       const memory::dims& strides, const memory::dims& dilations,
315       memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
316       memory::dims* pad_l, memory::dims* pad_r, bool pad_enabled = false,
317       bool is_depthwise = false) {
318     CHECK_NOTNULL(output_dims_tf_order);
319     CHECK_NOTNULL(output_dims_mkl_order);
320     CHECK_NOTNULL(pad_l);
321     CHECK_NOTNULL(pad_r);
322 
323     bool is_conv2d = (strides_.size() == 4);
324     int input_planes, input_rows, input_cols;
325     if (is_conv2d) {
326       input_rows = GetTensorDim(input_shape, data_format_, 'H');
327       input_cols = GetTensorDim(input_shape, data_format_, 'W');
328     } else {
329       input_planes = GetTensorDim(input_shape, data_format_, '0');
330       input_rows = GetTensorDim(input_shape, data_format_, '1');
331       input_cols = GetTensorDim(input_shape, data_format_, '2');
332     }
333 
334     // Filter dimension
335     // Conv2D:
336     //    First dimension: rows/height.
337     //    Second dimension: cols/width.
338     // Conv3D:
339     //    First dimension: planes/depth.
340     //    Second dimension: rows/height.
341     //    Third dimension: cols/width.
342 
343     int filter_planes, filter_rows, filter_cols;
344     if (is_conv2d) {
345       filter_rows = filter_shape.dim_size(TF_2DFILTER_DIM_H);
346       filter_cols = filter_shape.dim_size(TF_2DFILTER_DIM_W);
347     } else {
348       filter_planes = filter_shape.dim_size(TF_3DFILTER_DIM_P);
349       filter_rows = filter_shape.dim_size(TF_3DFILTER_DIM_H);
350       filter_cols = filter_shape.dim_size(TF_3DFILTER_DIM_W);
351     }
352 
353     int stride_planes, stride_rows, stride_cols;
354     int dilation_planes, dilation_rows, dilation_cols;
355     if (is_conv2d) {
356       // Conv2D stride is a vector of 2 elements: {s_r, s_c}
357       stride_rows = strides[0];
358       stride_cols = strides[1];
359       dilation_rows = dilations[0];
360       dilation_cols = dilations[1];
361     } else {
362       // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c}
363       stride_planes = strides[0];
364       stride_rows = strides[1];
365       stride_cols = strides[2];
366       dilation_planes = dilations[0];
367       dilation_rows = dilations[1];
368       dilation_cols = dilations[2];
369     }
370 
371     // Output batch is same as input batch.
372     int out_batch = GetTensorDim(input_shape, data_format_, 'N');
373     int out_depth;
374 
375     // TODO add support for 3-D Depthwise
376 
377     // Output depth is same as last dimension for filters for regular
378     // convolutions. For depthwise it is in_depth * channel_multiplier.
379     // The channel_multiplier is the last dimension of TF filter for
380     // depthwise convolutions.
381     if (is_depthwise) {
382       out_depth = (filter_shape.dim_size(TF_2DFILTER_DIM_I) *
383                    filter_shape.dim_size(TF_2DFILTER_DIM_O));
384     } else {
385       out_depth = filter_shape.dim_size(
386           is_conv2d ? static_cast<int>(TF_2DFILTER_DIM_O)
387                     : static_cast<int>(TF_3DFILTER_DIM_O));
388     }
389 
390     int64 out_rows = 0, out_cols = 0, out_planes = 0;
391     int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
392     int64 pad_D1, pad_D2;
393 
394     if (is_conv2d) {
395       Padding padding_type;
396       if (pad_enabled) {
397         padding_type = Padding::EXPLICIT;
398         pad_top = static_cast<int64>((*pad_l)[0]);
399         pad_left = static_cast<int64>((*pad_l)[1]);
400         pad_bottom = static_cast<int64>((*pad_r)[0]);
401         pad_right = static_cast<int64>((*pad_r)[1]);
402       } else {
403         padding_type = padding_;
404       }
405       OP_REQUIRES_OK(context_,
406                      GetWindowedOutputSizeVerboseV2(
407                          input_rows, filter_rows, dilation_rows, stride_rows,
408                          padding_type, &out_rows, &pad_top, &pad_bottom));
409       OP_REQUIRES_OK(context_,
410                      GetWindowedOutputSizeVerboseV2(
411                          input_cols, filter_cols, dilation_cols, stride_cols,
412                          padding_type, &out_cols, &pad_left, &pad_right));
413     } else {
414       OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
415                                    input_planes, filter_planes, stride_planes,
416                                    padding_, &out_planes, &pad_D1, &pad_D2));
417       OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
418                                    input_rows, filter_rows, stride_rows,
419                                    padding_, &out_rows, &pad_top, &pad_bottom));
420       OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
421                                    input_cols, filter_cols, stride_cols,
422                                    padding_, &out_cols, &pad_left, &pad_right));
423     }
424 
425     if (is_conv2d) {
426       // Conv + pad fusion is enabled only for 2D.
427       // If pad_enabled, i.e., pad and conv op are fused, then
428       // all pads are already passed from pad op through
429       // *pad_l and *pad_r and they don't need to be set here.
430       if (!pad_enabled) {
431         *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
432         *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
433       }
434     } else {
435       // Set padding for Conv3D here
436       *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top),
437                 static_cast<int>(pad_left)};
438       *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom),
439                 static_cast<int>(pad_right)};
440     }
441     // Tensorflow output is in data_format order.
442     //     Conv2D: NHWC or NCHW
443     //     Conv3D: NDHWC or NCDHW
444     // MKL-DNN uses asymetric padding.
445     TensorShape out_shape =
446         is_conv2d
447             ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols,
448                               out_depth)
449             : ShapeFromFormat(data_format_, out_batch,
450                               {{out_planes, out_rows, out_cols}}, out_depth);
451     *output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
452 
453     if (is_conv2d) {
454       // For Conv2D, MKL-DNN always needs output in NCHW format.
455       std::vector<int> mkldnn_sizes(4, -1);
456       mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
457       mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
458       mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
459       mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
460       *output_dims_mkl_order = mkldnn_sizes;
461     } else {
462       std::vector<int> mkldnn_sizes(5, -1);
463       mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
464       mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
465       mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
466       mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows);
467       mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols);
468       *output_dims_mkl_order = mkldnn_sizes;
469     }
470   }
471 
472   // Calculate output and pad size of forward Convolution operator.
473   // See comment on GetConvOutputAndPadSizeInMklOrder for parameters.
474   //
475   // Function does not return anything, but sets error in context status.
GetOutputAndPadSizeInMklOrder(size_t src_index,size_t filter_index,const memory::dims & strides,const memory::dims & dilations,memory::dims * output_dims_tf_order,memory::dims * output_dims_mkl_order,memory::dims * pad_l,memory::dims * pad_r,bool is_depthwise)476   inline void GetOutputAndPadSizeInMklOrder(
477       size_t src_index, size_t filter_index, const memory::dims& strides,
478       const memory::dims& dilations, memory::dims* output_dims_tf_order,
479       memory::dims* output_dims_mkl_order, memory::dims* pad_l,
480       memory::dims* pad_r, bool is_depthwise) {
481     CHECK_NOTNULL(output_dims_tf_order);
482     CHECK_NOTNULL(output_dims_mkl_order);
483     CHECK_NOTNULL(pad_l);
484     CHECK_NOTNULL(pad_r);
485 
486     auto input_tf_shape = GetTfShape(context_, src_index);
487     auto filter_tf_shape = GetTfShape(context_, filter_index);
488 
489     if (strides_.size() == 4) {
490       // Conv2D
491       OP_REQUIRES(context_, input_tf_shape.dims() == 4,
492                   errors::InvalidArgument("input must be 4-dimensional",
493                                           input_tf_shape.DebugString()));
494     } else {
495       // Conv3D
496       OP_REQUIRES(context_, input_tf_shape.dims() == 5,
497                   errors::InvalidArgument("input must be 5-dimensional",
498                                           input_tf_shape.DebugString()));
499     }
500 
501     GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides,
502                                   dilations, output_dims_tf_order,
503                                   output_dims_mkl_order, pad_l, pad_r,
504                                   is_depthwise);
505   }
506 
507   // Wrapper function to calculate input, filter, and output sizes of
508   // Conv2D/Conv3D in MKL order:
509   //     Conv2D: NCHW for input and output; OIHW for filter.
510   //     Conv3D: NCDHW for input and output; OIDHW for filter.
511   // Function also calculates output shape in Tensorflow order.
512   // Additionally, it also calculates strides and paddings.
513   //
514   // Function does not return anything, but sets error in context status.
515   inline void GetConvFwdSizesInMklOrder(
516       const TensorShape& input_shape, const TensorShape& filter_shape,
517       memory::dims* input_dims, memory::dims* filter_dims,
518       memory::dims* strides, memory::dims* dilations,
519       memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
520       memory::dims* pad_l, memory::dims* pad_r, bool pad_enabled = false,
521       bool is_depthwise = false) {
522     CHECK_NOTNULL(input_dims);
523     CHECK_NOTNULL(filter_dims);
524     CHECK_NOTNULL(strides);
525     CHECK_NOTNULL(dilations);
526     CHECK_NOTNULL(output_dims_tf_order);
527     CHECK_NOTNULL(output_dims_mkl_order);
528     CHECK_NOTNULL(pad_l);
529     CHECK_NOTNULL(pad_r);
530 
531     GetInputSizeInMklOrder(input_shape, input_dims);
532     if (!context_->status().ok()) return;
533     GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims,
534                             is_depthwise);
535     if (!context_->status().ok()) return;
536     GetStridesInMklOrder(strides);
537     GetDilationsInMklOrder(dilations);
538     GetOutputAndPadSizeInMklOrder(
539         input_shape, filter_shape, *strides, *dilations, output_dims_tf_order,
540         output_dims_mkl_order, pad_l, pad_r, pad_enabled, is_depthwise);
541     if (!context_->status().ok()) return;
542   }
543 };
544 
545 /////////////////////////////////////////////////////////////////////
546 ///  Common class that implements ConvBackpropFilter and Input
547 /////////////////////////////////////////////////////////////////////
548 
549 template <typename Device, class T, bool is_depthwise>
550 class MklConvBackpropCommonOp : public OpKernel {
551  public:
~MklConvBackpropCommonOp()552   ~MklConvBackpropCommonOp() {}
MklConvBackpropCommonOp(OpKernelConstruction * context)553   explicit MklConvBackpropCommonOp(OpKernelConstruction* context)
554       : OpKernel(context) {
555     string data_format_str;
556     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
557     OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
558                 errors::InvalidArgument("Invalid data format"));
559     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
560     int stride_n = GetTensorDim(strides_, data_format_, 'N');
561     int stride_c = GetTensorDim(strides_, data_format_, 'C');
562     const int64 stride_h = GetTensorDim(strides_, data_format_, 'H');
563     const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
564     OP_REQUIRES(
565         context, (stride_n == 1 && stride_c == 1),
566         errors::InvalidArgument("Current implementation does not yet support "
567                                 "strides in the batch and depth dimensions."));
568 
569     // Depthwise Convolution doesn't have dilation parameter
570     if (!is_depthwise) {
571       OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
572       if (strides_.size() == 4) {
573         // Check Conv2D dilations
574         OP_REQUIRES(
575             context, dilations_.size() == 4,
576             errors::InvalidArgument("Sliding window dilations field must "
577                                     "specify 4 dimensions"));
578         int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
579         int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
580         int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
581         int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
582         OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
583                     errors::InvalidArgument(
584                         "Current implementation does not yet support "
585                         "dilations in the batch and depth dimensions."));
586         OP_REQUIRES(
587             context, dilation_h > 0 && dilation_w > 0,
588             errors::InvalidArgument("Dilated rates should be larger than 0."));
589       }
590     } else {
591       // Set dilations as 1 for depthwise conv
592       // for future support to align with Tensorflow
593       dilations_ = {1, 1, 1, 1};
594     }
595 
596     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
597   }
598 
599  protected:
600   // data members accessible to derived classes.
601   std::vector<int32> dilations_;
602   std::vector<int32> strides_;
603   Padding padding_;
604   TensorFormat data_format_;  // NCHW or NHWC
605 };
606 
607 /////////////////////////////////////////////////////////////////////
608 ///  Dummy Mkl op that is just used for operators that are intermediate
609 ///  output of node fusion in the graph
610 /////////////////////////////////////////////////////////////////////
611 
612 template <typename Device, typename T>
613 class MklDummyOp : public OpKernel {
614  public:
~MklDummyOp()615   ~MklDummyOp() {}
616 
MklDummyOp(OpKernelConstruction * context)617   explicit MklDummyOp(OpKernelConstruction* context) : OpKernel(context) {}
618 
Compute(OpKernelContext * context)619   void Compute(OpKernelContext* context) override {
620     TF_CHECK_OK(
621         errors::Unimplemented("This is a dummy op."
622                               "It should not have been invoked."));
623   }
624 };
625 
626 }  // namespace tensorflow
627 
628 #endif  // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
629