1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ 18 19 #include <limits> 20 #include <memory> 21 #include <vector> 22 23 #include "mkldnn.hpp" 24 #include "tensorflow/core/framework/bounds_check.h" 25 #include "tensorflow/core/framework/numeric_op.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/framework/register_types.h" 28 #include "tensorflow/core/framework/tensor.h" 29 #include "tensorflow/core/framework/tensor_shape.h" 30 #include "tensorflow/core/framework/tensor_slice.h" 31 #include "tensorflow/core/kernels/conv_grad_ops.h" 32 #include "tensorflow/core/kernels/ops_util.h" 33 #include "tensorflow/core/lib/core/errors.h" 34 #include "tensorflow/core/lib/gtl/array_slice.h" 35 #include "tensorflow/core/lib/strings/numbers.h" 36 #include "tensorflow/core/lib/strings/str_util.h" 37 #include "tensorflow/core/platform/logging.h" 38 #include "tensorflow/core/platform/macros.h" 39 #include "tensorflow/core/util/mkl_util.h" 40 #include "tensorflow/core/util/padding.h" 41 #include "tensorflow/core/util/tensor_format.h" 42 43 using mkldnn::convolution_direct; 44 using mkldnn::convolution_forward; 45 using mkldnn::prop_kind; 46 using mkldnn::stream; 47 48 namespace tensorflow { 49 50 class MklDnnConvUtil { 51 protected: 52 OpKernelContext* context_; // We don't own this. 53 std::vector<int32> strides_; 54 std::vector<int32> dilations_; 55 Padding padding_; 56 TensorFormat data_format_; 57 58 public: 59 MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides, 60 Padding pad, TensorFormat fm, 61 const std::vector<int32>& dilations, bool is_depthwise = false) context_(context)62 : context_(context), 63 strides_(strides), 64 dilations_(dilations), 65 padding_(pad), 66 data_format_(fm) {} 67 ~MklDnnConvUtil()68 virtual ~MklDnnConvUtil() { context_ = nullptr; } 69 70 // Calculate Convolution strides GetStridesInMklOrder(memory::dims * strides)71 virtual inline void GetStridesInMklOrder(memory::dims* strides) { 72 // For now we take the stride from the second and third dimensions only 73 // (we do not support striding on the batch or depth dimension). 74 CHECK_NOTNULL(strides); 75 if (strides_.size() == 4) { 76 int stride_rows = GetTensorDim(strides_, data_format_, 'H'); 77 int stride_cols = GetTensorDim(strides_, data_format_, 'W'); 78 *strides = {stride_rows, stride_cols}; 79 } else if (strides_.size() == 5) { 80 int stride_planes = GetTensorDim(strides_, data_format_, '0'); 81 int stride_rows = GetTensorDim(strides_, data_format_, '1'); 82 int stride_cols = GetTensorDim(strides_, data_format_, '2'); 83 *strides = {stride_planes, stride_rows, stride_cols}; 84 } 85 } 86 87 // Calculate Convolution dilations GetDilationsInMklOrder(memory::dims * dilations)88 virtual inline void GetDilationsInMklOrder(memory::dims* dilations) { 89 // For now we take the dilation from the second and third dimensions only 90 // (we do not support dilation on the batch or depth dimension). 91 CHECK_NOTNULL(dilations); 92 if (dilations_.size() == 4) { 93 int dilations_rows = GetTensorDim(dilations_, data_format_, 'H'); 94 int dilations_cols = GetTensorDim(dilations_, data_format_, 'W'); 95 *dilations = {dilations_rows, dilations_cols}; 96 } else if (dilations_.size() == 5) { 97 int dilations_planes = GetTensorDim(dilations_, data_format_, '0'); 98 int dilations_rows = GetTensorDim(dilations_, data_format_, '1'); 99 int dilations_cols = GetTensorDim(dilations_, data_format_, '2'); 100 *dilations = {dilations_planes, dilations_rows, dilations_cols}; 101 } 102 } 103 104 // Calculate Convolution input size in MKL-DNN order. MKL-DNN 105 // requires input in NCHW/NCDHW format. Function does not return anything. 106 // But errors arising from sanity checks are returned in context's 107 // status. GetInputSizeInMklOrder(const TensorShape & input_shape,memory::dims * input_dims)108 virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape, 109 memory::dims* input_dims) { 110 #define CHECK_BOUNDS(val, err_msg) \ 111 do { \ 112 OP_REQUIRES(context_, \ 113 FastBoundsCheck(val, std::numeric_limits<int>::max()), \ 114 errors::InvalidArgument(err_msg)); \ 115 } while (0) 116 117 CHECK_NOTNULL(input_dims); 118 119 // Input channel 120 int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C'); 121 int input_depth = static_cast<int>(input_depth_raw); 122 123 // Input batch 124 int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N'); 125 CHECK_BOUNDS(input_batch_raw, "Input batch too large"); 126 int input_batch = static_cast<int>(input_batch_raw); 127 128 if (strides_.size() == 4) { // NCHW format for Conv2D 129 // Input rows/height 130 int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H'); 131 CHECK_BOUNDS(input_rows_raw, "Input rows too large"); 132 int input_rows = static_cast<int>(input_rows_raw); 133 134 // Input columns/width 135 int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W'); 136 CHECK_BOUNDS(input_cols_raw, "Input cols too large"); 137 int input_cols = static_cast<int>(input_cols_raw); 138 139 // MKL-DNN always requires input in NCHW format Conv2D. 140 std::vector<int> mkldnn_sizes(4, -1); 141 mkldnn_sizes[MklDnnDims::Dim_N] = input_batch; 142 mkldnn_sizes[MklDnnDims::Dim_C] = input_depth; 143 mkldnn_sizes[MklDnnDims::Dim_H] = input_rows; 144 mkldnn_sizes[MklDnnDims::Dim_W] = input_cols; 145 146 *input_dims = mkldnn_sizes; 147 } else if (strides_.size() == 5) { // NCDHW format for Conv3D 148 // Input planes/third-dimension 149 int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0'); 150 CHECK_BOUNDS(input_planes_raw, "Input depth too large"); 151 int input_planes = static_cast<int>(input_planes_raw); 152 153 // Input rows/height 154 int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1'); 155 CHECK_BOUNDS(input_rows_raw, "Input rows too large"); 156 int input_rows = static_cast<int>(input_rows_raw); 157 158 // Input columns/width 159 int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2'); 160 CHECK_BOUNDS(input_cols_raw, "Input cols too large"); 161 int input_cols = static_cast<int>(input_cols_raw); 162 163 // MKL-DNN always requires input in NCDHW format for Conv3D. 164 std::vector<int> mkldnn_sizes(5, -1); 165 mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch; 166 mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth; 167 mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes; 168 mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows; 169 mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols; 170 171 *input_dims = mkldnn_sizes; 172 } 173 #undef CHECK_BOUNDS 174 } 175 176 // Calculate Convolution filter size in MKL-DNN order. 177 // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format. 178 // Function does not return anything. 179 // But errors arising from sanity checks are returned in context's 180 // status. This function differs from GetConvFilterSizeInMklOrder in 181 // parameter for input - it accepts src_shape since Convolution Backward 182 // Input gets shape of input tensor rather than actual tensor (Convolution 183 // forward gets actual tensor as input). 184 // 185 // TODO(nhasabni): Add similar function for input and filter in MklShape. GetFilterSizeInMklOrder(const TensorShape & input_shape,const TensorShape & filter_shape,memory::dims * filter_dims,bool is_depthwise)186 virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape, 187 const TensorShape& filter_shape, 188 memory::dims* filter_dims, 189 bool is_depthwise) { 190 CHECK_NOTNULL(filter_dims); 191 192 OP_REQUIRES(context_, filter_shape.dims() == strides_.size(), 193 errors::InvalidArgument((strides_.size() == 4) 194 ? "filter must be 4-dimensional: " 195 : "filter must be 5-dimensional: ", 196 filter_shape.DebugString())); 197 198 for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) { 199 OP_REQUIRES(context_, 200 FastBoundsCheck(filter_shape.dim_size(i), 201 std::numeric_limits<int>::max()), 202 errors::InvalidArgument("filter too large")); 203 } 204 205 int input_depth = GetTensorDim(input_shape, data_format_, 'C'); 206 207 if (strides_.size() == 4) { // Conv2D 208 OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2), 209 errors::InvalidArgument( 210 "input and filter must have the same depth: ", 211 input_depth, " vs ", filter_shape.dim_size(2))); 212 213 // TF filter is always in (rows, cols, in_depth, out_depth) order. 214 int filter_rows = 215 static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_H)); 216 int filter_cols = 217 static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_W)); 218 int filter_in_depth = 219 static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_I)); 220 int filter_out_depth = 221 static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_O)); 222 // MKL-DNN always needs filter in OIHW format for regular convolutions 223 // and GOIHW for grouped/depthwise convolutions, 224 // OIHW = (out_depth, in_depth, rows, cols) 225 // GOIHW = (group, out_depth, in_depth, rows, cols) 226 // Specifically for depthwise G=filter_indepth, O=filter_outdepth, I=1 227 if (is_depthwise) { 228 std::vector<int> mkldnn_sizes(5, -1); 229 mkldnn_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth; 230 mkldnn_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth; 231 mkldnn_sizes[MKL_GROUP_FILTER_DIM_I] = 1; 232 mkldnn_sizes[MKL_GROUP_FILTER_DIM_H] = filter_rows; 233 mkldnn_sizes[MKL_GROUP_FILTER_DIM_W] = filter_cols; 234 235 *filter_dims = mkldnn_sizes; 236 } else { 237 std::vector<int> mkldnn_sizes(4, -1); 238 mkldnn_sizes[MklDnnDims::Dim_O] = filter_out_depth; 239 mkldnn_sizes[MklDnnDims::Dim_I] = filter_in_depth; 240 mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows; 241 mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols; 242 243 *filter_dims = mkldnn_sizes; 244 } 245 } else { // Conv3D 246 OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3), 247 errors::InvalidArgument( 248 "input and filter must have the same depth: ", 249 input_depth, " vs ", filter_shape.dim_size(3))); 250 251 // TF filter is always in (planes, rows, cols, in_depth, out_depth) order. 252 int filter_planes = 253 static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_P)); 254 int filter_rows = 255 static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_H)); 256 int filter_cols = 257 static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_W)); 258 int filter_in_depth = 259 static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_I)); 260 int filter_out_depth = 261 static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_O)); 262 263 // MKL-DNN always needs filter in OIDHW format. 264 // OIDHW = (out_depth, in_depth, planes, rows, cols) 265 std::vector<int> mkldnn_sizes(5, -1); 266 mkldnn_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth; 267 mkldnn_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth; 268 mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes; 269 mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows; 270 mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols; 271 272 *filter_dims = mkldnn_sizes; 273 } 274 } 275 276 // Calculate Convolution filter size in MKL-DNN order. 277 // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW(Conv3D format. 278 // Function does not return anything. But errors arising from sanity 279 // checks are returned in context's status. GetFilterSizeInMklOrder(size_t src_index,size_t filter_index,memory::dims * filter_dims,bool is_depthwise)280 virtual inline void GetFilterSizeInMklOrder(size_t src_index, 281 size_t filter_index, 282 memory::dims* filter_dims, 283 bool is_depthwise) { 284 CHECK_NOTNULL(filter_dims); 285 GetFilterSizeInMklOrder(GetTfShape(context_, src_index), 286 GetTfShape(context_, filter_index), filter_dims, 287 is_depthwise); 288 } 289 290 // Calculate Bias size for 2D or 3D Convolution. Function does not 291 // return anything, but may set an error in context status. GetBiasSizeInMklOrder(size_t bias_index,memory::dims * bias_dims)292 virtual inline void GetBiasSizeInMklOrder(size_t bias_index, 293 memory::dims* bias_dims) { 294 const Tensor& bias = MklGetInput(context_, bias_index); 295 OP_REQUIRES(context_, bias.dims() == 1, 296 errors::InvalidArgument("bias must be 1-dimensional: ", 297 bias.shape().DebugString())); 298 299 *bias_dims = {static_cast<int>(bias.dim_size(0))}; 300 } 301 302 // Function to calculate output and padding size for 2D/3D convolution. 303 // 304 // Calculate output shape of Convolution in MKL-DNN and TensorFlow order. 305 // MKL-DNN uses NCHW(Conv2D) or NCDHW(Conv3D) for output order. 306 // But TensorFlow output will be in NHWC||NCHW(Conv2D) or 307 // NDHWC||NCDHW(Conv3D) format depending on data format. 308 // Function also calculates left, right, top and bottom pads. 309 // Function does not return any status which is set with context status. 310 // 311 // TODO(nhasabni): Add similar function for input and filter in MklShape. 312 virtual inline void GetOutputAndPadSizeInMklOrder( 313 const TensorShape& input_shape, const TensorShape& filter_shape, 314 const memory::dims& strides, const memory::dims& dilations, 315 memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order, 316 memory::dims* pad_l, memory::dims* pad_r, bool pad_enabled = false, 317 bool is_depthwise = false) { 318 CHECK_NOTNULL(output_dims_tf_order); 319 CHECK_NOTNULL(output_dims_mkl_order); 320 CHECK_NOTNULL(pad_l); 321 CHECK_NOTNULL(pad_r); 322 323 bool is_conv2d = (strides_.size() == 4); 324 int input_planes, input_rows, input_cols; 325 if (is_conv2d) { 326 input_rows = GetTensorDim(input_shape, data_format_, 'H'); 327 input_cols = GetTensorDim(input_shape, data_format_, 'W'); 328 } else { 329 input_planes = GetTensorDim(input_shape, data_format_, '0'); 330 input_rows = GetTensorDim(input_shape, data_format_, '1'); 331 input_cols = GetTensorDim(input_shape, data_format_, '2'); 332 } 333 334 // Filter dimension 335 // Conv2D: 336 // First dimension: rows/height. 337 // Second dimension: cols/width. 338 // Conv3D: 339 // First dimension: planes/depth. 340 // Second dimension: rows/height. 341 // Third dimension: cols/width. 342 343 int filter_planes, filter_rows, filter_cols; 344 if (is_conv2d) { 345 filter_rows = filter_shape.dim_size(TF_2DFILTER_DIM_H); 346 filter_cols = filter_shape.dim_size(TF_2DFILTER_DIM_W); 347 } else { 348 filter_planes = filter_shape.dim_size(TF_3DFILTER_DIM_P); 349 filter_rows = filter_shape.dim_size(TF_3DFILTER_DIM_H); 350 filter_cols = filter_shape.dim_size(TF_3DFILTER_DIM_W); 351 } 352 353 int stride_planes, stride_rows, stride_cols; 354 int dilation_planes, dilation_rows, dilation_cols; 355 if (is_conv2d) { 356 // Conv2D stride is a vector of 2 elements: {s_r, s_c} 357 stride_rows = strides[0]; 358 stride_cols = strides[1]; 359 dilation_rows = dilations[0]; 360 dilation_cols = dilations[1]; 361 } else { 362 // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c} 363 stride_planes = strides[0]; 364 stride_rows = strides[1]; 365 stride_cols = strides[2]; 366 dilation_planes = dilations[0]; 367 dilation_rows = dilations[1]; 368 dilation_cols = dilations[2]; 369 } 370 371 // Output batch is same as input batch. 372 int out_batch = GetTensorDim(input_shape, data_format_, 'N'); 373 int out_depth; 374 375 // TODO add support for 3-D Depthwise 376 377 // Output depth is same as last dimension for filters for regular 378 // convolutions. For depthwise it is in_depth * channel_multiplier. 379 // The channel_multiplier is the last dimension of TF filter for 380 // depthwise convolutions. 381 if (is_depthwise) { 382 out_depth = (filter_shape.dim_size(TF_2DFILTER_DIM_I) * 383 filter_shape.dim_size(TF_2DFILTER_DIM_O)); 384 } else { 385 out_depth = filter_shape.dim_size( 386 is_conv2d ? static_cast<int>(TF_2DFILTER_DIM_O) 387 : static_cast<int>(TF_3DFILTER_DIM_O)); 388 } 389 390 int64 out_rows = 0, out_cols = 0, out_planes = 0; 391 int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right; 392 int64 pad_D1, pad_D2; 393 394 if (is_conv2d) { 395 Padding padding_type; 396 if (pad_enabled) { 397 padding_type = Padding::EXPLICIT; 398 pad_top = static_cast<int64>((*pad_l)[0]); 399 pad_left = static_cast<int64>((*pad_l)[1]); 400 pad_bottom = static_cast<int64>((*pad_r)[0]); 401 pad_right = static_cast<int64>((*pad_r)[1]); 402 } else { 403 padding_type = padding_; 404 } 405 OP_REQUIRES_OK(context_, 406 GetWindowedOutputSizeVerboseV2( 407 input_rows, filter_rows, dilation_rows, stride_rows, 408 padding_type, &out_rows, &pad_top, &pad_bottom)); 409 OP_REQUIRES_OK(context_, 410 GetWindowedOutputSizeVerboseV2( 411 input_cols, filter_cols, dilation_cols, stride_cols, 412 padding_type, &out_cols, &pad_left, &pad_right)); 413 } else { 414 OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( 415 input_planes, filter_planes, stride_planes, 416 padding_, &out_planes, &pad_D1, &pad_D2)); 417 OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( 418 input_rows, filter_rows, stride_rows, 419 padding_, &out_rows, &pad_top, &pad_bottom)); 420 OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( 421 input_cols, filter_cols, stride_cols, 422 padding_, &out_cols, &pad_left, &pad_right)); 423 } 424 425 if (is_conv2d) { 426 // Conv + pad fusion is enabled only for 2D. 427 // If pad_enabled, i.e., pad and conv op are fused, then 428 // all pads are already passed from pad op through 429 // *pad_l and *pad_r and they don't need to be set here. 430 if (!pad_enabled) { 431 *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)}; 432 *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)}; 433 } 434 } else { 435 // Set padding for Conv3D here 436 *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top), 437 static_cast<int>(pad_left)}; 438 *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom), 439 static_cast<int>(pad_right)}; 440 } 441 // Tensorflow output is in data_format order. 442 // Conv2D: NHWC or NCHW 443 // Conv3D: NDHWC or NCDHW 444 // MKL-DNN uses asymetric padding. 445 TensorShape out_shape = 446 is_conv2d 447 ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, 448 out_depth) 449 : ShapeFromFormat(data_format_, out_batch, 450 {{out_planes, out_rows, out_cols}}, out_depth); 451 *output_dims_tf_order = TFShapeToMklDnnDims(out_shape); 452 453 if (is_conv2d) { 454 // For Conv2D, MKL-DNN always needs output in NCHW format. 455 std::vector<int> mkldnn_sizes(4, -1); 456 mkldnn_sizes[MklDnnDims::Dim_N] = out_batch; 457 mkldnn_sizes[MklDnnDims::Dim_C] = out_depth; 458 mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows); 459 mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols); 460 *output_dims_mkl_order = mkldnn_sizes; 461 } else { 462 std::vector<int> mkldnn_sizes(5, -1); 463 mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch; 464 mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth; 465 mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes); 466 mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows); 467 mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols); 468 *output_dims_mkl_order = mkldnn_sizes; 469 } 470 } 471 472 // Calculate output and pad size of forward Convolution operator. 473 // See comment on GetConvOutputAndPadSizeInMklOrder for parameters. 474 // 475 // Function does not return anything, but sets error in context status. GetOutputAndPadSizeInMklOrder(size_t src_index,size_t filter_index,const memory::dims & strides,const memory::dims & dilations,memory::dims * output_dims_tf_order,memory::dims * output_dims_mkl_order,memory::dims * pad_l,memory::dims * pad_r,bool is_depthwise)476 inline void GetOutputAndPadSizeInMklOrder( 477 size_t src_index, size_t filter_index, const memory::dims& strides, 478 const memory::dims& dilations, memory::dims* output_dims_tf_order, 479 memory::dims* output_dims_mkl_order, memory::dims* pad_l, 480 memory::dims* pad_r, bool is_depthwise) { 481 CHECK_NOTNULL(output_dims_tf_order); 482 CHECK_NOTNULL(output_dims_mkl_order); 483 CHECK_NOTNULL(pad_l); 484 CHECK_NOTNULL(pad_r); 485 486 auto input_tf_shape = GetTfShape(context_, src_index); 487 auto filter_tf_shape = GetTfShape(context_, filter_index); 488 489 if (strides_.size() == 4) { 490 // Conv2D 491 OP_REQUIRES(context_, input_tf_shape.dims() == 4, 492 errors::InvalidArgument("input must be 4-dimensional", 493 input_tf_shape.DebugString())); 494 } else { 495 // Conv3D 496 OP_REQUIRES(context_, input_tf_shape.dims() == 5, 497 errors::InvalidArgument("input must be 5-dimensional", 498 input_tf_shape.DebugString())); 499 } 500 501 GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides, 502 dilations, output_dims_tf_order, 503 output_dims_mkl_order, pad_l, pad_r, 504 is_depthwise); 505 } 506 507 // Wrapper function to calculate input, filter, and output sizes of 508 // Conv2D/Conv3D in MKL order: 509 // Conv2D: NCHW for input and output; OIHW for filter. 510 // Conv3D: NCDHW for input and output; OIDHW for filter. 511 // Function also calculates output shape in Tensorflow order. 512 // Additionally, it also calculates strides and paddings. 513 // 514 // Function does not return anything, but sets error in context status. 515 inline void GetConvFwdSizesInMklOrder( 516 const TensorShape& input_shape, const TensorShape& filter_shape, 517 memory::dims* input_dims, memory::dims* filter_dims, 518 memory::dims* strides, memory::dims* dilations, 519 memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order, 520 memory::dims* pad_l, memory::dims* pad_r, bool pad_enabled = false, 521 bool is_depthwise = false) { 522 CHECK_NOTNULL(input_dims); 523 CHECK_NOTNULL(filter_dims); 524 CHECK_NOTNULL(strides); 525 CHECK_NOTNULL(dilations); 526 CHECK_NOTNULL(output_dims_tf_order); 527 CHECK_NOTNULL(output_dims_mkl_order); 528 CHECK_NOTNULL(pad_l); 529 CHECK_NOTNULL(pad_r); 530 531 GetInputSizeInMklOrder(input_shape, input_dims); 532 if (!context_->status().ok()) return; 533 GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims, 534 is_depthwise); 535 if (!context_->status().ok()) return; 536 GetStridesInMklOrder(strides); 537 GetDilationsInMklOrder(dilations); 538 GetOutputAndPadSizeInMklOrder( 539 input_shape, filter_shape, *strides, *dilations, output_dims_tf_order, 540 output_dims_mkl_order, pad_l, pad_r, pad_enabled, is_depthwise); 541 if (!context_->status().ok()) return; 542 } 543 }; 544 545 ///////////////////////////////////////////////////////////////////// 546 /// Common class that implements ConvBackpropFilter and Input 547 ///////////////////////////////////////////////////////////////////// 548 549 template <typename Device, class T, bool is_depthwise> 550 class MklConvBackpropCommonOp : public OpKernel { 551 public: ~MklConvBackpropCommonOp()552 ~MklConvBackpropCommonOp() {} MklConvBackpropCommonOp(OpKernelConstruction * context)553 explicit MklConvBackpropCommonOp(OpKernelConstruction* context) 554 : OpKernel(context) { 555 string data_format_str; 556 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); 557 OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_), 558 errors::InvalidArgument("Invalid data format")); 559 OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); 560 int stride_n = GetTensorDim(strides_, data_format_, 'N'); 561 int stride_c = GetTensorDim(strides_, data_format_, 'C'); 562 const int64 stride_h = GetTensorDim(strides_, data_format_, 'H'); 563 const int64 stride_w = GetTensorDim(strides_, data_format_, 'W'); 564 OP_REQUIRES( 565 context, (stride_n == 1 && stride_c == 1), 566 errors::InvalidArgument("Current implementation does not yet support " 567 "strides in the batch and depth dimensions.")); 568 569 // Depthwise Convolution doesn't have dilation parameter 570 if (!is_depthwise) { 571 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); 572 if (strides_.size() == 4) { 573 // Check Conv2D dilations 574 OP_REQUIRES( 575 context, dilations_.size() == 4, 576 errors::InvalidArgument("Sliding window dilations field must " 577 "specify 4 dimensions")); 578 int dilation_n = GetTensorDim(dilations_, data_format_, 'N'); 579 int dilation_c = GetTensorDim(dilations_, data_format_, 'C'); 580 int dilation_h = GetTensorDim(dilations_, data_format_, 'H'); 581 int dilation_w = GetTensorDim(dilations_, data_format_, 'W'); 582 OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1), 583 errors::InvalidArgument( 584 "Current implementation does not yet support " 585 "dilations in the batch and depth dimensions.")); 586 OP_REQUIRES( 587 context, dilation_h > 0 && dilation_w > 0, 588 errors::InvalidArgument("Dilated rates should be larger than 0.")); 589 } 590 } else { 591 // Set dilations as 1 for depthwise conv 592 // for future support to align with Tensorflow 593 dilations_ = {1, 1, 1, 1}; 594 } 595 596 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 597 } 598 599 protected: 600 // data members accessible to derived classes. 601 std::vector<int32> dilations_; 602 std::vector<int32> strides_; 603 Padding padding_; 604 TensorFormat data_format_; // NCHW or NHWC 605 }; 606 607 ///////////////////////////////////////////////////////////////////// 608 /// Dummy Mkl op that is just used for operators that are intermediate 609 /// output of node fusion in the graph 610 ///////////////////////////////////////////////////////////////////// 611 612 template <typename Device, typename T> 613 class MklDummyOp : public OpKernel { 614 public: ~MklDummyOp()615 ~MklDummyOp() {} 616 MklDummyOp(OpKernelConstruction * context)617 explicit MklDummyOp(OpKernelConstruction* context) : OpKernel(context) {} 618 Compute(OpKernelContext * context)619 void Compute(OpKernelContext* context) override { 620 TF_CHECK_OK( 621 errors::Unimplemented("This is a dummy op." 622 "It should not have been invoked.")); 623 } 624 }; 625 626 } // namespace tensorflow 627 628 #endif // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ 629