/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL

#include <string.h>
#include <algorithm>
#include <map>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/mkl_quantized_conv_ops.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"

using mkldnn::prop_kind;
using mkldnn::stream;
using mkldnn::convolution_forward;
using mkldnn::convolution_direct;

#else
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#endif

namespace tensorflow {

#ifndef INTEL_MKL_ML_ONLY

// This structure aggregates multiple inputs to Conv2DFwd* methods.
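// The fields describe one forward convolution: shapes, strides, dilations,
// paddings, the data types involved, and any fused post-ops. The same fields
// are also hashed by MklConvFwdPrimitiveFactory::CreateKey below to look up a
// cached primitive for reuse.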
struct MklConvFwdParams {
  memory::dims src_dims;
  memory::dims filter_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  memory::dims strides;
  memory::dims dilations;
  memory::dims padding_left;
  memory::dims padding_right;
  string dtypes = string("");
  struct PostOpParam {
    string name;
    std::vector<float> param;
  };
  std::vector<PostOpParam> post_op_params;

  MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
                   memory::dims bias_dims, memory::dims dst_dims,
                   memory::dims strides, memory::dims dilations,
                   memory::dims padding_left, memory::dims padding_right)
      : src_dims(src_dims),
        filter_dims(filter_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        strides(strides),
        dilations(dilations),
        padding_left(padding_left),
        padding_right(padding_right) {}
};

typedef mkldnn::convolution_forward::primitive_desc ConvFwdPd;

// With quantization, input, filter, and output can have different types,
// so we use a different template parameter for each type.
template <typename T, typename Tinput, typename Tfilter, typename Tbias,
          typename Toutput>
class MklConvFwdPrimitive : public MklPrimitive {
 public:
  explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
      : cpu_engine_(engine::cpu, 0) {
    context_.fwd_stream.reset(new stream(stream::kind::eager));
    // Create conv primitive
    if (context_.conv_fwd == nullptr) {
      Setup(convFwdDims);
    }
  }

  ~MklConvFwdPrimitive() {}

  // Convolution forward execute with bias
  //   src_data:    input data buffer of src
  //   filter_data: input data buffer of filter (weights)
  //   bias_data:   input data buffer of bias
  //   dst_data:    output data buffer of dst
  void Execute(const Tinput* src_data, const Tfilter* filter_data,
               const Tbias* bias_data, const Toutput* dst_data) {
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.filter_mem->set_data_handle(
        static_cast<void*>(const_cast<Tfilter*>(filter_data)));
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(
        static_cast<void*>(const_cast<Toutput*>(dst_data)));
    context_.fwd_stream->submit(context_.fwd_primitives);

    // After execution, set data handles back to dummy data.
    context_.src_mem->set_data_handle(DummyData);
    context_.filter_mem->set_data_handle(DummyData);
    context_.bias_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);

    return;
  }

  // Convolution forward execute without bias
  //   src_data:    input data buffer of src
  //   filter_data: input data buffer of filter (weights)
  //   dst_data:    output data buffer of dst
  void Execute(const Tinput* src_data, const Tfilter* filter_data,
               const Toutput* dst_data) {
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.filter_mem->set_data_handle(
        static_cast<void*>(const_cast<Tfilter*>(filter_data)));
    context_.dst_mem->set_data_handle(
        static_cast<void*>(const_cast<Toutput*>(dst_data)));
    context_.fwd_stream->submit(context_.fwd_primitives);

    // After execution, set data handles back to dummy data.
    context_.src_mem->set_data_handle(DummyData);
    context_.filter_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }

  memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }

  std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for Conv2D Fwd op
  struct ConvFwdContext {
    // Expected memory format for this primitive instance
    memory::format src_fmt;
    memory::format filter_fmt;

    // MKLDNN memory
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> filter_mem;
    std::shared_ptr<mkldnn::memory> bias_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // Desc & primitive desc
    std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;

    // Memory descriptors
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> filter_md;
    std::shared_ptr<mkldnn::memory::desc> bias_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;

    // Convolution primitive
    std::shared_ptr<ConvFwdPd> fwd_pd;
    std::shared_ptr<mkldnn::primitive> conv_fwd;

    std::shared_ptr<mkldnn::stream> fwd_stream;
    std::vector<mkldnn::primitive> fwd_primitives;

    ConvFwdContext()
        : src_fmt(memory::format::any),
          filter_fmt(memory::format::any),
          src_mem(nullptr),
          filter_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          src_md(nullptr),
          filter_md(nullptr),
          bias_md(nullptr),
          fwd_pd(nullptr),
          conv_fwd(nullptr),
          fwd_stream(nullptr) {}
  };

  void Setup(const MklConvFwdParams& convFwdDims) {
    // Create memory descriptors for convolution data w/ no specified format
    context_.src_md.reset(new memory::desc(
        {convFwdDims.src_dims}, MklDnnType<Tinput>(), memory::format::any));

    context_.filter_md.reset(new memory::desc(
        {convFwdDims.filter_dims}, MklDnnType<Tfilter>(), memory::format::any));

    context_.dst_md.reset(new memory::desc(
        {convFwdDims.dst_dims}, MklDnnType<Toutput>(), memory::format::any));

    if (!convFwdDims.bias_dims.empty())
      context_.bias_md.reset(new memory::desc(
          {convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format::any));

    // Create a convolution descriptor
    if (!convFwdDims.bias_dims.empty()) {
      context_.fwd_desc.reset(new convolution_forward::desc(
          prop_kind::forward, convolution_direct, *context_.src_md,
          *context_.filter_md, *context_.bias_md, *context_.dst_md,
          convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
          convFwdDims.padding_right, padding_kind::zero));
    } else {
      context_.fwd_desc.reset(new convolution_forward::desc(
          prop_kind::forward, convolution_direct, *context_.src_md,
          *context_.filter_md, *context_.dst_md, convFwdDims.strides,
          convFwdDims.dilations, convFwdDims.padding_left,
          convFwdDims.padding_right, padding_kind::zero));
    }

    context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));

    // Check if there are any fusions as post-ops
    auto const& post_op_params =
        convFwdDims.post_op_params;
    mkldnn::primitive_attr post_ops_attr;
    mkldnn::post_ops post_ops;
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "relu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::eltwise_relu, op_alpha,
                                  op_beta);
        } else if (post_op_param.name == "sum") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          float op_scale = post_op_param.param[0];
          post_ops.append_sum(op_scale);
        } else if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else {
          DCHECK((post_op_param.name == "relu") ||
                 (post_op_param.name == "sum") ||
                 (post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
      context_.fwd_pd.reset(
          new ConvFwdPd(*context_.fwd_desc, post_ops_attr, cpu_engine_));
    } else {
      context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
    }

    // Store the expected memory format
    context_.src_fmt = static_cast<mkldnn::memory::format>(
        context_.fwd_pd.get()->src_primitive_desc().desc().data.format);

    context_.filter_fmt = static_cast<mkldnn::memory::format>(
        context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);

    // Create memory primitives based on dummy data
    context_.src_mem.reset(
        new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
    context_.filter_mem.reset(
        new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));

    // Create convolution primitive and add it to net
    if (!convFwdDims.bias_dims.empty()) {
      context_.bias_mem.reset(new memory(
          {{{convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::x},
           cpu_engine_},
          DummyData));
      context_.conv_fwd.reset(new convolution_forward(
          *context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
          *context_.bias_mem, *context_.dst_mem));
    } else {
      context_.conv_fwd.reset(
          new convolution_forward(*context_.fwd_pd, *context_.src_mem,
                                  *context_.filter_mem, *context_.dst_mem));
    }

    context_.fwd_primitives.push_back(*context_.conv_fwd);
    return;
  }

  struct ConvFwdContext context_;
  engine cpu_engine_;
};

template <typename T, typename Tinput, typename Tfilter, typename Tbias,
          typename Toutput>
class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklConvFwdPrimitive<T, Tinput, Tfilter, Tbias, Toutput>* Get(
      const MklConvFwdParams& convFwdDims, bool do_not_cache) {
    MklConvFwdPrimitive<T, Tinput, Tfilter, Tbias, Toutput>* conv_fwd = nullptr;

    if (do_not_cache) {
      // Always create a new primitive
      conv_fwd = new MklConvFwdPrimitive<T, Tinput, Tfilter, Tbias, Toutput>(
          convFwdDims);
    } else {
      // Try to find a suitable one in pool
      conv_fwd = dynamic_cast<
          MklConvFwdPrimitive<T, Tinput, Tfilter, Tbias, Toutput>*>(
          MklConvFwdPrimitiveFactory<T, Tinput, Tfilter, Tbias,
                                     Toutput>::GetInstance()
              .GetConvFwd(convFwdDims));
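      // Cache miss: create the primitive and register it with the factory so
      // that later calls with the same MklConvFwdParams can reuse it.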
      if (conv_fwd == nullptr) {
        conv_fwd = new MklConvFwdPrimitive<T, Tinput, Tfilter, Tbias, Toutput>(
            convFwdDims);
        MklConvFwdPrimitiveFactory<T, Tinput, Tfilter, Tbias,
                                   Toutput>::GetInstance()
            .SetConvFwd(convFwdDims, conv_fwd);
      }
    }

    return conv_fwd;
  }

 private:
  MklConvFwdPrimitiveFactory() {}
  ~MklConvFwdPrimitiveFactory() {}

  static const int kDilationH = 0, kDilationW = 1;

  static MklConvFwdPrimitiveFactory& GetInstance() {
    static MklConvFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklConvFwdParams& convFwdDims) {
    string prefix = "conv_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(convFwdDims.src_dims);
    key_creator.AddAsKey(convFwdDims.filter_dims);
    key_creator.AddAsKey(convFwdDims.bias_dims);
    key_creator.AddAsKey(convFwdDims.dst_dims);
    key_creator.AddAsKey(convFwdDims.strides);
    key_creator.AddAsKey(convFwdDims.dilations);
    key_creator.AddAsKey(convFwdDims.padding_left);
    key_creator.AddAsKey(convFwdDims.padding_right);
    key_creator.AddAsKey(convFwdDims.dtypes);

    // Generate keys for post-ops
    for (auto const& post_op_param : convFwdDims.post_op_params) {
      if (post_op_param.name == "relu") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
        key_creator.AddAsKey(post_op_param.param[1]);
        key_creator.AddAsKey(post_op_param.param[2]);
      } else if (post_op_param.name == "sum") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else if (post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else {
        return string("not_a_key");
      }
    }

    return key_creator.GetKey();
  }

  MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) {
    string key = CreateKey(convFwdDims);
    return this->GetOp(key);
  }

  void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
    string key = CreateKey(convFwdDims);
    this->SetOp(key, op);
  }
};

#endif

typedef Eigen::ThreadPoolDevice CPUDevice;

// For now, MKL-ML is the default, so MKL-DNN is not the default choice.
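// The INTEL_MKL_ML_ONLY branch below implements the op with the legacy MKL-ML
// API (dnn*_F32 calls); the #else branch uses MKL-DNN primitives together with
// the primitive caching machinery defined above.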
#ifdef INTEL_MKL_ML_ONLY
template <typename Device, typename T, bool bias_enabled>
class MklConvOp : public OpKernel {
 public:
  ~MklConvOp() {}

  explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));

    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    MklConv2DOpContext mkl_context;
    const Tensor& input = MklGetInput(context, 0);
    GetMklShape(context, 0, &(mkl_context.input_shape));
    bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();

    const Tensor& filter = MklGetInput(context, 1);
    MklShape mkl_filter_shape;
    GetMklShape(context, 1, &mkl_filter_shape);
    CHECK(!mkl_filter_shape.IsMklTensor())
        << "Conv filter should not be in MKL Layout";

    if (bias_enabled) {
      const Tensor& bias = MklGetInput(context, 2);
      OP_REQUIRES(context, bias.dims() == 1,
                  errors::InvalidArgument("bias must be 1-dimensional: ",
                                          bias.shape().DebugString()));
    }

    if (!input_in_mkl_format) {
      OP_REQUIRES(context, input.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional",
                                          input.shape().DebugString()));
    }

    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    for (int i = 0; i < 3; ++i) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    const int64 input_depth =
        input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C')
                            : GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, input_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", input_depth,
                    " vs ", filter.dim_size(2)));
    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 input_rows_raw =
        input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H')
                            : GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int input_rows = static_cast<int>(input_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 input_cols_raw =
        input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W')
                            : GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int input_cols = static_cast<int>(input_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));

    // The first dimension for input is batch.
    const int64 input_batch_raw =
        input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N')
                            : GetTensorDim(input, data_format_, 'N');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_batch_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(input_batch_raw);

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      // Nothing to do, allocate output tensor and return
      MklShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &output, input.shape(),
                                mkl_output_mkl_shape);
      return;
    }

    if (batch == 0) {
      // Nothing to do, allocate output tensor and return
      MklShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &output, input.shape(),
                                mkl_output_mkl_shape);
      return;
    }

    // Create MKL convolution primitives
    mkl_context.in_dims = input_in_mkl_format
                              ? mkl_context.input_shape.GetDimension()
                              : input.dims();
    mkl_context.filter_dims = filter.dims();

    mkl_context.in_sizes[MklDims::W] = static_cast<size_t>(input_cols);
    mkl_context.in_sizes[MklDims::H] = static_cast<size_t>(input_rows);
    mkl_context.in_sizes[MklDims::C] = static_cast<size_t>(input_depth);
    mkl_context.in_sizes[MklDims::N] = static_cast<size_t>(batch);

    mkl_context.out_sizes[MklDims::W] = static_cast<size_t>(out_cols);
    mkl_context.out_sizes[MklDims::H] = static_cast<size_t>(out_rows);
    mkl_context.out_sizes[MklDims::C] = static_cast<size_t>(out_depth);
    mkl_context.out_sizes[MklDims::N] = static_cast<size_t>(batch);

    mkl_context.input_offset[0] = static_cast<int>(-pad_cols);
    mkl_context.input_offset[1] = static_cast<int>(-pad_rows);

    mkl_context.conv_stride[0] = static_cast<size_t>(stride_cols);
    mkl_context.conv_stride[1] = static_cast<size_t>(stride_rows);

    GetStridesFromSizes(data_format_, mkl_context.out_strides,
                        mkl_context.out_sizes);
    GetStridesFromSizes(data_format_, mkl_context.in_strides,
                        mkl_context.in_sizes);

    // TF filter dimension order (out_depth, in_depth, cols, rows) ->
    // MKL filter dimension order (out_depth, in_depth, rows, cols)
    mkl_context.filter_sizes[0] = filter.dim_size(1);  // cols
    mkl_context.filter_sizes[1] = filter.dim_size(0);  // rows
    mkl_context.filter_sizes[2] = filter.dim_size(2);  // in_depth
    mkl_context.filter_sizes[3] = filter.dim_size(3);  // out_depth

    // TF filter layout - (rows, cols, in_depth, out_depth)
    mkl_context.filter_strides[0] =
        filter.dim_size(2) * filter.dim_size(3);  // cols
    mkl_context.filter_strides[1] =
        filter.dim_size(1) * filter.dim_size(2) * filter.dim_size(3);  // rows
    mkl_context.filter_strides[2] = filter.dim_size(3);  // in_depth
    mkl_context.filter_strides[3] = 1;                   // out_depth

    if (bias_enabled) {
      const Tensor& bias = MklGetInput(context, 2);
      mkl_context.bias_sizes[0] = {static_cast<size_t>(bias.dim_size(0))};
      mkl_context.bias_strides[0] = {1};
    }

    // Create Convolution Primitive
    if (bias_enabled) {
      CHECK_EQ(
          dnnConvolutionCreateForwardBias_F32(
              &mkl_context.prim_fwd, nullptr, dnnAlgorithmConvolutionDirect,
              mkl_context.in_dims, mkl_context.in_sizes, mkl_context.out_sizes,
              mkl_context.filter_sizes, mkl_context.conv_stride,
              mkl_context.input_offset, dnnBorderZeros),
          E_SUCCESS);
    } else {
      CHECK_EQ(
          dnnConvolutionCreateForward_F32(
              &mkl_context.prim_fwd, nullptr, dnnAlgorithmConvolutionDirect,
              mkl_context.in_dims, mkl_context.in_sizes, mkl_context.out_sizes,
              mkl_context.filter_sizes, mkl_context.conv_stride,
              mkl_context.input_offset, dnnBorderZeros),
          E_SUCCESS);
    }

    TensorShape mkl_output_tf_shape;
    MklShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(true);
    mkl_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd, dnnResourceDst);
    mkl_output_mkl_shape.SetTfLayout(mkl_context.in_dims, mkl_context.out_sizes,
                                     mkl_context.out_strides);
    // MKL might change the dimension ordering.
    // Create mapping to recover the original TF dimension order.
    mkl_output_mkl_shape.SetTfDimOrder(mkl_context.in_dims, data_format_);

    mkl_output_tf_shape.AddDim(
        dnnLayoutGetMemorySize_F32(
            static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
        sizeof(T));
    AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
                              mkl_output_mkl_shape);
    // Filter output to be used in the backprop_input
    TensorShape mkl_filter_output_tf_shape;
    MklShape mkl_filter_output_mkl_shape;
    mkl_filter_output_mkl_shape.SetMklTensor(true);
    mkl_filter_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd,
                                             dnnResourceFilter);

    size_t filter_sizes[4] = {static_cast<size_t>(filter.dim_size(0)),
                              static_cast<size_t>(filter.dim_size(1)),
                              static_cast<size_t>(filter.dim_size(2)),
                              static_cast<size_t>(filter.dim_size(3))};
    mkl_filter_output_mkl_shape.SetTfLayout(filter.dims(), filter_sizes,
                                            mkl_context.filter_strides);

    mkl_filter_output_mkl_shape.SetTfDimOrder(mkl_context.filter_dims,
                                              data_format_);
    mkl_filter_output_tf_shape.AddDim(
        dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
            mkl_filter_output_mkl_shape.GetMklLayout())) /
        sizeof(T));
    AllocateOutputSetMklShape(context, 1, &mkl_context.output_filter,
                              mkl_filter_output_tf_shape,
                              mkl_filter_output_mkl_shape);

    mkl_context.conv_res[dnnResourceDst] =
        static_cast<void*>(output->flat<T>().data());

    mkl_context.MklCreateInputLayouts(context);

    // Temp tensor used to allocate tmp buffers
    Tensor mkl_tmp_input_buf_tensor, mkl_tmp_filter_buf_tensor,
        mkl_tmp_bias_buf_tensor;
    mkl_context.MklPrepareConvolutionInputs(context, &mkl_tmp_input_buf_tensor,
                                            &mkl_tmp_filter_buf_tensor,
                                            &mkl_tmp_bias_buf_tensor);

    // Execute convolution
    CHECK_EQ(dnnExecute_F32(mkl_context.prim_fwd, mkl_context.conv_res),
             E_SUCCESS);

    mkl_context.MklCleanup();
  }

 private:
  typedef struct {
    int in_dims;
    size_t in_sizes[4];
    size_t in_strides[4];
    size_t out_sizes[4];
    size_t out_strides[4];
    int filter_dims;
    size_t filter_sizes[4];
    size_t filter_strides[4];
    size_t bias_sizes[1];
    size_t bias_strides[1];
    int input_offset[2];
    size_t conv_stride[2];
    MklShape input_shape;
    dnnPrimitive_t prim_fwd;
    void* conv_res[dnnResourceNumber];
    dnnLayout_t lt_filter, lt_bias, lt_input;
    Tensor* output_filter = nullptr;

    // Create MKL dnnLayout_t objects for tensors coming into the layer
    void MklCreateInputLayouts(OpKernelContext* context) {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      if (input_in_mkl_format) {
        lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
      } else {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      }

      CHECK_EQ(dnnLayoutCreate_F32(&lt_filter, filter_dims, filter_sizes,
                                   filter_strides),
               E_SUCCESS);

      if (bias_enabled) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_bias, 1, bias_sizes, bias_strides),
                 E_SUCCESS);
      }
    }

    // Compare incoming tensor layouts with MKL preferred layouts and convert
    // data to the preferred layout if necessary
    void MklPrepareConvolutionInputs(OpKernelContext* context,
                                     Tensor* mkl_tmp_input_buf_tensor,
                                     Tensor* mkl_tmp_filter_buf_tensor,
                                     Tensor* mkl_tmp_bias_buf_tensor) {
      bool mkl_convert_input, mkl_convert_filter, mkl_convert_bias;
      dnnPrimitive_t mkl_prim_convert_filter, mkl_prim_convert_bias,
          mkl_prim_convert_input;
      dnnLayout_t mkl_lt_internal_filter, mkl_lt_internal_bias,
          mkl_lt_internal_input;
      void *mkl_buf_convert_input, *mkl_buf_convert_filter,
          *mkl_buf_convert_bias;
      mkl_prim_convert_filter = nullptr;
      mkl_prim_convert_bias = nullptr;
      mkl_prim_convert_input = nullptr;
      mkl_lt_internal_filter = nullptr;
      mkl_lt_internal_bias = nullptr;
      mkl_lt_internal_input = nullptr;
      mkl_buf_convert_input = nullptr;
      mkl_buf_convert_filter = nullptr;
      mkl_buf_convert_bias = nullptr;

      // Compare with internal layouts and convert if needed
      const Tensor& input = MklGetInput(context, 0);
      void* mkl_buf_input =
          const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
                                                prim_fwd, dnnResourceSrc),
               E_SUCCESS);
      mkl_convert_input =
          !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input);
      if (mkl_convert_input) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input,
                                         mkl_lt_internal_input),
                 E_SUCCESS);
        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
                       &mkl_buf_convert_input);
        CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
                                          mkl_buf_convert_input),
                 E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_input);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_input);

      conv_res[dnnResourceSrc] =
          (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;

      const Tensor& filter = MklGetInput(context, 1);
      void* mkl_buf_filter =
          const_cast<void*>(static_cast<const void*>(filter.flat<T>().data()));
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_filter,
                                                prim_fwd, dnnResourceFilter),
               E_SUCCESS);
      mkl_convert_filter =
          !dnnLayoutCompare_F32(mkl_lt_internal_filter, lt_filter);
      if (mkl_convert_filter) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_filter, lt_filter,
                                         mkl_lt_internal_filter),
                 E_SUCCESS);

        mkl_buf_convert_filter = const_cast<void*>(
            static_cast<const void*>(output_filter->flat<T>().data()));

        CHECK_EQ(
            dnnConversionExecute_F32(mkl_prim_convert_filter, mkl_buf_filter,
                                     mkl_buf_convert_filter),
            E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_filter);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_filter);

      conv_res[dnnResourceFilter] =
          (mkl_convert_filter) ? mkl_buf_convert_filter : mkl_buf_filter;

      if (bias_enabled) {
        const Tensor& bias = MklGetInput(context, 2);
        void* mkl_buf_bias =
            const_cast<void*>(static_cast<const void*>(bias.flat<T>().data()));
        CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_bias,
                                                  prim_fwd, dnnResourceBias),
                 E_SUCCESS);
        mkl_convert_bias = !dnnLayoutCompare_F32(mkl_lt_internal_bias, lt_bias);
        if (mkl_convert_bias) {
          CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_bias, lt_bias,
                                           mkl_lt_internal_bias),
                   E_SUCCESS);
          AllocTmpBuffer(context, mkl_tmp_bias_buf_tensor, mkl_lt_internal_bias,
                         &mkl_buf_convert_bias);
          CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_bias, mkl_buf_bias,
                                            mkl_buf_convert_bias),
                   E_SUCCESS);
          dnnDelete_F32(mkl_prim_convert_bias);
        }
        dnnLayoutDelete_F32(mkl_lt_internal_bias);

        conv_res[dnnResourceBias] =
            (mkl_convert_bias) ? mkl_buf_convert_bias : mkl_buf_bias;
      }
    }

    void MklCleanup() {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      dnnDelete_F32(prim_fwd);
      if (!input_in_mkl_format) dnnLayoutDelete_F32(lt_input);
      dnnLayoutDelete_F32(lt_filter);
      if (bias_enabled) dnnLayoutDelete_F32(lt_bias);
    }
  } MklConv2DOpContext;

  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;
};

// FP32 kernel registration for INTEL_MKL_ML
REGISTER_KERNEL_BUILDER(Name("_MklConv2D")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklConvOp<CPUDevice, float, false>);
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklConvOp<CPUDevice, float, true>);

#else

// Base class for convolution forward operations
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
          typename Toutput, typename Ttemp_output, typename Tpadding,
          bool bias_enabled, bool pad_enabled, bool is_depthwise>
class MklConvOp : public OpKernel {
 public:
  ~MklConvOp() {}

  explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    if (context->HasAttr("padding_list")) {
      OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 or 5 dimensions"));

    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    is_filter_const_ = false;
    OP_REQUIRES_OK(context,
                   context->GetAttr("is_filter_const", &is_filter_const_));

    if (strides_.size() == 4) {
      OP_REQUIRES(context, dilations_.size() == 4,
                  errors::InvalidArgument("Sliding window dilations field must "
                                          "specify 4 dimensions"));
      const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
      const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
      const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
      const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
      OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
                  errors::InvalidArgument(
                      "Current implementation does not yet support "
                      "dilations in the batch and depth dimensions."));
      OP_REQUIRES(
          context, dilation_h > 0 && dilation_w > 0,
          errors::InvalidArgument("Dilated rates should be larger than 0."));
    } else if (strides_.size() == 5) {
      OP_REQUIRES(context, dilations_.size() == 5,
                  errors::InvalidArgument("Dilation rates field must "
                                          "specify 5 dimensions"));
      OP_REQUIRES(context,
                  (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
                   GetTensorDim(dilations_, data_format_, 'C') == 1),
                  errors::InvalidArgument(
                      "Current implementation does not yet support "
                      "dilation rates in the batch and depth dimensions."));
      OP_REQUIRES(
          context,
          (GetTensorDim(dilations_, data_format_, '0') > 0 &&
           GetTensorDim(dilations_, data_format_, '1') > 0 &&
           GetTensorDim(dilations_, data_format_, '2') > 0),
          errors::InvalidArgument("Dilated rates should be larger than 0."));
    }
  }

  void Compute(OpKernelContext* context) override {
    try {
      // Input tensors
      const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src);
      const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);

      // Data from persistent (cached) filter tensor
      const Tensor& cached_filter_data_tensor =
          *cached_filter_data_ptensor_.AccessTensor(context);

      MklDnnShape src_mkl_shape, filter_mkl_shape;
      GetMklShape(context, kInputIndex_Src, &src_mkl_shape);
      GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape);
      OP_REQUIRES(context, filter_mkl_shape.IsMklTensor() == false,
                  errors::InvalidArgument("Filter should not be in "
                                          "Mkl Layout"));

      MklDnnData<Tinput> src(&cpu_engine_);
      MklDnnData<Tfilter> filter(&cpu_engine_);

      memory::dims src_dims, filter_dims, padding_left, padding_right,
          dilations, strides;
      memory::dims dst_dims_tf_order, dst_dims_mkl_order;

      // For Quantized-Conv2D and Pad fusion, we get padding from the
      // `padding_list` attribute. Otherwise, we get it from one of the inputs.
      bool quantized_pad_enabled = false;
      for (auto const& padding_val : padding_list_) {
        if (padding_val) {
          quantized_pad_enabled = true;
          break;
        }
      }

      if (fuse_pad_ || quantized_pad_enabled) {
        PadWithConvFusion(context, padding_left, padding_right,
                          quantized_pad_enabled);
      }

      // Get shapes of input tensors in MKL-DNN order
      MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
                              dilations_);
      auto src_tf_shape = GetTfShape(context, kInputIndex_Src);
      auto filter_tf_shape = GetTfShape(context, kInputIndex_Filter);
      conv_utl.GetConvFwdSizesInMklOrder(
          src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
          &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,
          &padding_right, (fuse_pad_ || quantized_pad_enabled), is_depthwise);

      if (!context->status().ok()) return;

      // Check for corner case - if there is nothing to compute, return.
      TensorShape dst_tf_shape = MklDnnDimsToTFShape(dst_dims_tf_order);

      // Corner cases: output with 0 elements and 0 batch size.
      Tensor* dst_tensor = nullptr;
      if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) {
        MklDnnShape dst_mkl_shape;
        dst_mkl_shape.SetMklTensor(false);
        AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor,
                                  src_tf_shape, dst_mkl_shape);

        // MklConv2D/3D also outputs the converted filter as a 2nd output.
        filter_mkl_shape.SetMklTensor(false);
        Tensor* output_filter_tensor = nullptr;
        if (typeid(Tinput) == typeid(float) &&
            typeid(Tfilter) == typeid(float) &&
            typeid(Toutput) == typeid(float)) {
          filter_mkl_shape.SetMklTensor(false);
          AllocateOutputSetMklShape(context, kOutputIndex_Filter,
                                    &output_filter_tensor, filter_tf_shape,
                                    filter_mkl_shape);
        }
        return;
      }

      bool is_conv2d = (strides_.size() == 4);

      if (!is_conv2d) {
        OP_REQUIRES(
            context, !pad_enabled,
            errors::InvalidArgument("Pad + Conv fusion only works for 2D"));
      }

      // TODO: 3-D support for depthwise convolution is not implemented yet.
      if (is_depthwise) {
        OP_REQUIRES(context, is_conv2d,
                    errors::InvalidArgument(
                        "Only 2D convolution is supported for depthwise."));
      }

      // TODO(Intel-tf) Add check to make sure pad_enabled is true only for 2D
      if (!is_conv2d) {
        OP_REQUIRES(
            context, !fuse_pad_,
            errors::InvalidArgument("Pad+Conv fusion only works for 2D"));
      }
      // Create memory for user data.
      // Describe how the inputs and outputs of the convolution look. Also
      // specify buffers containing the actual input and output data.
      auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
                              : TFDataFormatToMklDnn3DDataFormat(data_format_);

      // If input is in MKL layout, then simply grab the layout; otherwise,
      // construct a TF layout for the input.
      // When constructing the TF layout for the input, note that although the
      // input shape (src_dims) is required to be in MKL-DNN order, the input
      // layout is actually a TF layout depending on the data format:
      //     Conv2D: NHWC or NCHW
      //     Conv3D: NDHWC or NCDHW
      auto src_md = src_mkl_shape.IsMklTensor()
                        ? src_mkl_shape.GetMklLayout()
                        : memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
      src.SetUsrMem(src_md, &src_tensor);

      // Although the required filter shape (filter_dims) is in MKL-DNN order,
      // the layout is TensorFlow's layout: HWIO for regular convolutions and
      // HWIGO for depthwise/group convolutions.
      auto filter_format = is_conv2d ? (is_depthwise ? memory::format::hwigo
                                                     : memory::format::hwio)
                                     : memory::format::dhwio;

      DCHECK(!filter_mkl_shape.IsMklTensor());
      auto filter_md =
          filter_mkl_shape.IsMklTensor()
              ? filter_mkl_shape.GetMklLayout()
              : memory::desc(filter_dims, MklDnnType<Tfilter>(), filter_format);
      filter.SetUsrMem(filter_md, &filter_tensor);

      // MKLDNN dilations start from 0.
      for (int i = 0; i < dilations.size(); ++i) --dilations[i];

      // In some cases, the primitive descriptor could potentially contain
      // large buffers. As a result, we don't cache such primitives if the
      // environment variable `TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE` is set to True.
      // MKL-DNN allocates buffers in the following cases:
      //   1. Legacy CPU without AVX512/AVX2, or
      //   2. 1x1 convolution with strides != 1.
      bool do_not_cache =
          MklPrimitiveFactory<Tinput>::IsPrimitiveMemOptEnabled() &&
          (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
          (MklPrimitiveFactory<Tinput>::IsLegacyPlatform() ||
           IsConv1x1StrideNot1(filter_dims, strides));

      // Get a conv2d fwd from primitive pool
      MklConvFwdPrimitive<float, Tinput, Tfilter, Tbias, Ttemp_output>*
          conv_fwd = nullptr;
      memory::dims bias_dims = {};
      if (fuse_biasadd_) {
        conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
      }
      MklConvFwdParams convFwdDims(
          src_dims, filter_dims, fuse_biasadd_ ? bias_dims : NONE_DIMS,
          dst_dims_mkl_order, strides, dilations, padding_left, padding_right);

      // TODO(mdfaijul): Extend the basic parameters for data types and fusions
      this->ExtendConvFwdParams(context, convFwdDims);

      conv_fwd = MklConvFwdPrimitiveFactory<float, Tinput, Tfilter, Tbias,
                                            Ttemp_output>::Get(convFwdDims,
                                                               do_not_cache);

      // Allocate output tensors `output_tensor` and `filter_out_tensor`
      std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
      AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt,
                           &dst_tensor);
      Tensor* filter_out_tensor = nullptr;
      if (typeid(Tinput) == typeid(float) && typeid(Tfilter) == typeid(float) &&
          typeid(Toutput) == typeid(float)) {
        AllocateFilterOutputTensor(context, *conv_fwd_pd,
                                   TFShapeToMklDnnDims(filter_tf_shape),
                                   &filter_out_tensor);
      }

      Ttemp_output* dst_data =
          reinterpret_cast<Ttemp_output*>(dst_tensor->flat<Toutput>().data());

      // Check whether src and filter need to be reordered
      Tinput* src_data = nullptr;
      if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) {
        // Reorder src
        src.SetUsrMem(src_md, &src_tensor);
        src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc());
        src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
      } else {
        src_data = static_cast<Tinput*>(
            const_cast<Tinput*>(src_tensor.flat<Tinput>().data()));
      }

      Tfilter* filter_data = nullptr;
      if (filter_md.data.format != conv_fwd->GetFilterMemoryFormat()) {
        bool is_filter_cached = false;
        // If filter is a constant, we can avoid the conversion of filter from
        // Tensorflow format to MKL format by caching the filter when it is
        // converted for the first time. This cached filter can then be reused
        // in subsequent iterations.
        if (is_filter_const_) {
          if (IsFilterCacheEmpty(context)) {
            // Cache filter if it is not already cached.
            CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
                        filter, filter_md);
          }
          filter_data =
              GetCachedFilter(context, conv_fwd->GetFilterMemoryFormat());
          is_filter_cached = (filter_data != nullptr);
        }
        if (!is_filter_cached) {
          filter.SetUsrMem(filter_md, &filter_tensor);
          if (filter_out_tensor == nullptr) {
            filter.CheckReorderToOpMem(
                conv_fwd_pd.get()->weights_primitive_desc());
          } else {
            filter.CheckReorderToOpMem(
                conv_fwd_pd.get()->weights_primitive_desc(),
                filter.GetTensorBuffer(filter_out_tensor));
          }
          filter_data =
              static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
        }
      } else {
        filter_data = static_cast<Tfilter*>(
            const_cast<Tfilter*>(filter_tensor.flat<Tfilter>().data()));
      }

      // Execute convolution
      if (fuse_biasadd_) {
        const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
        Tbias* bias_data =
            this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
        conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
      } else {
        conv_fwd->Execute(src_data, filter_data, dst_data);
      }

      // Delete primitive since it is not cached.
      if (do_not_cache) delete conv_fwd;
    } catch (mkldnn::error& e) {
      string error_msg = tensorflow::strings::StrCat(
          "Status: ", e.status, ", message: ", string(e.message), ", in file ",
          __FILE__, ":", __LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

  void PadWithConvFusion(OpKernelContext* context, memory::dims& padding_left,
                         memory::dims& padding_right,
                         bool quantized_pad_enabled) {
    const Tensor& paddings_tf = MklGetInput(context, input_index_pad_);
    Tpadding* paddings = nullptr;
    if (quantized_pad_enabled) {
      paddings = padding_list_.data();
    } else {
      OP_REQUIRES(context, paddings_tf.dims() == 2,
                  errors::InvalidArgument("paddings must be 2-dimensional: ",
                                          paddings_tf.shape().DebugString()));
      // Flatten tensor to get individual paddings.
      paddings = static_cast<Tpadding*>(
          const_cast<Tpadding*>(paddings_tf.flat<Tpadding>().data()));
    }
    // If the data format is NHWC, indices 0, 1, 6 and 7 of paddings(_tf)
    // will be zero.
    // Example:
    //   paddings_tf = [ [0, 0] [1, 2] [3, 4] [0, 0] ];
    //   since flat() is row-major, this gives
    //   paddings = {0, 0, 1, 2, 3, 4, 0, 0}.
    // Hence, the values are: top = 1, bottom = 2, left = 3, right = 4.
    //
    // Similarly, if the data format is NCHW, indices 0, 1, 2 and 3 of
    // paddings(_tf) will be zero.
    // i.e. for the above example, paddings = {0, 0, 0, 0, 1, 2, 3, 4}.
    int64 pad_top, pad_left;
    int64 pad_bottom, pad_right;
    string data_format = ToString(data_format_);
    if (data_format == "NHWC") {
      pad_top = paddings[2];
      pad_bottom = paddings[3];
      pad_left = paddings[4];
      pad_right = paddings[5];
    } else if (data_format == "NCHW") {
      pad_top = paddings[4];
      pad_bottom = paddings[5];
      pad_left = paddings[6];
      pad_right = paddings[7];
    }
    // Create padding arrays for MKL-DNN convolutions.
    // MKL-DNN uses asymmetric padding.
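    // For the NHWC example above this yields padding_left = {1, 3}
    // (top, left) and padding_right = {2, 4} (bottom, right).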
    padding_left = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
    padding_right = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
  }

 protected:
  void set_fuse_biasadd(bool fuse_biasadd) { fuse_biasadd_ = fuse_biasadd; }
  void set_fuse_relu(bool fuse_relu) { fuse_relu_ = fuse_relu; }
  void set_fuse_pad(bool fuse_pad) {
    fuse_pad_ = fuse_pad;
    // In the PadWithFusedConv op, pad is the fourth input (index 3).
    input_index_pad_ = 3;
  }

  // This method is for the base class MklConvOp, which handles the
  // floating point implementation of Conv. The quantized conv implementations
  // will use overridden versions of this method.
  virtual void ExtendConvFwdParams(OpKernelContext* context,
                                   MklConvFwdParams& params) {
    // Create a string from the data types of input, filter, bias, and output.
    params.dtypes.append(typeid(Tinput).name());
    params.dtypes.append(typeid(Tfilter).name());
    params.dtypes.append(typeid(Tbias).name());
    params.dtypes.append(typeid(Toutput).name());

    // Add fusions as post ops.
    // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
    // checking the `fuse_biasadd_` flag.
    if (fuse_relu_) params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
  }

  virtual Tbias* GetBiasHandle(OpKernelContext* context,
                               std::shared_ptr<ConvFwdPd>& conv2d_fwd_pd,
                               const Tensor& bias_tensor) {
    if (fuse_biasadd_) {
      return static_cast<Tbias*>(
          const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
    }
    return nullptr;
  }

  virtual void AllocateOutputTensor(OpKernelContext* context,
                                    const ConvFwdPd& conv_prim_desc,
                                    const memory::dims& output_dims_mkl_order,
                                    memory::format output_tf_format,
                                    Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    auto dst_pd = conv_prim_desc.dst_primitive_desc();

    auto dst_md = dst_pd.desc();
    if (!std::is_same<Ttemp_output, Toutput>::value) {
      dst_md.data.data_type =
          static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
      dst_pd = memory::primitive_desc(dst_md, cpu_engine_);
    }
    // Allocate shape of Mkl tensor.
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    // Allocate shape of TF tensor.
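    // The TF-side output is just a flat 1-D buffer large enough for the MKL
    // blocked layout (dst_pd.get_size() bytes); the logical shape and layout
    // travel separately in output_mkl_shape.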
    TensorShape output_tf_shape;
    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));

    AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  engine cpu_engine_ = engine(engine::cpu, 0);

 private:
  std::vector<int32> strides_;
  std::vector<int32> dilations_;
  std::vector<Tpadding> padding_list_;
  bool is_filter_const_;
  mutex mu_;
  Padding padding_;
  TensorFormat data_format_;
  PersistentTensor cached_filter_data_ptensor_ GUARDED_BY(mu_);
  PersistentTensor cached_filter_md_ptensor_ GUARDED_BY(mu_);

  // Initialize to values the template is instantiated with
  bool fuse_biasadd_ = bias_enabled;
  bool fuse_relu_ = false;
  bool fuse_pad_ = pad_enabled;

  int input_index_pad_ = 2;

  const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2;
  const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
  const int kDilationH = 0, kDilationW = 1;

  // Allocate persistent tensors for cached filter data and
  // cached filter memory descriptor (data format)
  void AllocatePersistentTensor(OpKernelContext* context,
                                const ConvFwdPd& conv_prim_desc,
                                Tensor** filter_tensor) {
    DCHECK(filter_tensor);
    TensorShape filter_tf_shape;
    filter_tf_shape.AddDim(
        (conv_prim_desc.weights_primitive_desc().get_size() / sizeof(Tfilter)));
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                DataTypeToEnum<Tfilter>::value, filter_tf_shape,
                                &cached_filter_data_ptensor_, filter_tensor));

    Tensor* second_tensor = nullptr;
    TensorShape filter_mkl_format;
    filter_mkl_format.AddDim(
        sizeof(conv_prim_desc.weights_primitive_desc().desc().data.format) /
        sizeof(DT_INT32));
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                DT_INT32, filter_mkl_format,
                                &cached_filter_md_ptensor_, &second_tensor));
    second_tensor->scalar<int32>()() =
        conv_prim_desc.weights_primitive_desc().desc().data.format;
  }

  void AllocateFilterOutputTensor(OpKernelContext* context,
                                  const ConvFwdPd& conv_prim_desc,
                                  const memory::dims& filter_dims_tf_order,
                                  Tensor** filter_tensor) {
    CHECK_NOTNULL(filter_tensor);
    auto filter_pd = conv_prim_desc.weights_primitive_desc();

    // Allocate shape of Mkl tensor.
    MklDnnShape filter_mkl_shape;
    filter_mkl_shape.SetMklTensor(true);
    filter_mkl_shape.SetMklLayout(&filter_pd);
    filter_mkl_shape.SetElemType(MklDnnType<Tfilter>());

    // The format of the filter is actually OIhw8i8o, but TF doesn't support
    // this format. Just use format::blocked for now because the layout
    // is stored in the MKL data.
    filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
                                 filter_dims_tf_order, memory::format::blocked);

    // Allocate the data space for the filter to propagate as a TF tensor.
    TensorShape filter_tf_shape;
    filter_tf_shape.AddDim((filter_pd.get_size() / sizeof(Tfilter)));

    AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
                              filter_tf_shape, filter_mkl_shape);
  }

  // Prepare and execute net - checks for input and output reorders.
  void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc,
                            MklDnnData<Tinput>* src,
                            MklDnnData<Tfilter>* filter,
                            MklDnnData<Tbias>* bias,
                            MklDnnData<Toutput>* output,
                            Tensor* filter_out_tensor) {
    CHECK_NOTNULL(filter_out_tensor);

    // Create reorders between user layout and MKL layout if needed and
    // add them to the net before convolution. No need to check for an output
    // reorder as we propagate the output layout to the next layer.
    src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc());

    // Rather than reordering to a temp buffer, reorder directly to the
    // filter output tensor.
    filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(),
                                filter->GetTensorBuffer(filter_out_tensor));

    // Create convolution primitive and add it to net.
    std::vector<primitive> net;
    if (bias) {
      DCHECK(fuse_biasadd_);
      net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
                                        filter->GetOpMem(), bias->GetOpMem(),
                                        output->GetOpMem()));
    } else {
      DCHECK(!fuse_biasadd_);
      net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
                                        filter->GetOpMem(),
                                        output->GetOpMem()));
    }

    stream(stream::kind::eager).submit(net).wait();
  }

  // The LOCKS_EXCLUDED annotation ensures that the caller does not already
  // hold mu_ when entering this function, since the lock is acquired inside.
  inline bool IsFilterCacheEmpty(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& cached_filter_data_tensor =
        *cached_filter_data_ptensor_.AccessTensor(context);
    return (cached_filter_data_tensor.NumElements() == 0);
  }

  // Cache the converted filter in a persistent tensor.
  // Only one thread can execute this method at any given time.
  void CacheFilter(OpKernelContext* context,
                   const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
                   Tfilter* filter_data, const Tensor& filter_tensor,
                   MklDnnData<Tfilter>& filter, const memory::desc& filter_md)
      LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    const Tensor& cached_filter_data_tensor =
        *cached_filter_data_ptensor_.AccessTensor(context);

    // If the filter is already cached, there's nothing to do.
  // Cache the converted filter in a persistent tensor.
  // Only one thread can execute this method at any given time.
  void CacheFilter(OpKernelContext* context,
                   const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
                   Tfilter* filter_data, const Tensor& filter_tensor,
                   MklDnnData<Tfilter>& filter, const memory::desc& filter_md)
      LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    const Tensor& cached_filter_data_tensor =
        *cached_filter_data_ptensor_.AccessTensor(context);

    // If the filter is already cached, there is nothing to do.
    if (cached_filter_data_tensor.NumElements() > 0) {
      return;
    }

    // Otherwise, cache the filter.
    filter.SetUsrMem(filter_md, &filter_tensor);
    filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc());
    filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());

    Tensor* filter_tensor_ptr = nullptr;
    AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr);
    void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
    size_t cached_filter_data_size =
        filter.GetOpMem().get_primitive_desc().get_size();
    memcpy(cached_filter_data, filter_data, cached_filter_data_size);
  }

  Tfilter* GetCachedFilter(OpKernelContext* context,
                           const memory::format& filter_mf)
      LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& cached_filter_data =
        *cached_filter_data_ptensor_.AccessTensor(context);
    const Tensor& cached_filter_md =
        *cached_filter_md_ptensor_.AccessTensor(context);

    // Check if the memory descriptor of the cached weights is the same as
    // filter_mf. If so, we can use the cached weights; otherwise return
    // nullptr.
    // TODO(bhavanis): Do we need to cast filter_mf before the check?
    if (cached_filter_md.scalar<int32>().size() &&
        cached_filter_md.scalar<int32>()() == filter_mf) {
      return static_cast<Tfilter*>(
          const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
    }
    return nullptr;
  }
};
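
// The fused-convolution wrapper below instantiates MklConvOp with the
// bias_enabled and pad_enabled template arguments fixed to false; which
// fusions actually run is decided at construction time from the 'fused_ops'
// attribute via the set_fuse_* setters.
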
// Base class for fused convolution forward operations
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
          typename Toutput, typename Ttemp_output, typename Tpadding,
          bool pad_enabled>
class MklFusedConvOp
    : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
                       Tpadding, false, false, false> {
 public:
  explicit MklFusedConvOp(OpKernelConstruction* context)
      : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
                  Tpadding, false, false, false>(context) {
    // Since we came here through the registration of _MklFusedConv2D, get
    // all information from the 'fused_ops' and 'num_args' attributes.
    std::vector<string> fused_ops;
    OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));

    int num_args;
    OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
    OP_REQUIRES(context, !fused_ops.empty(),
                errors::InvalidArgument(
                    "Fused Conv2D must have at least one fused op."));

    if (fused_ops == std::vector<string>{"BiasAdd"}) {
      this->set_fuse_biasadd(true);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    } else if (fused_ops == std::vector<string>{"Relu"}) {
      this->set_fuse_relu(true);
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_relu(true);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    } else {
      OP_REQUIRES(context, false,
                  errors::Unimplemented("Fusion is not implemented: [",
                                        str_util::Join(fused_ops, ","), "]"));
    }

    if (pad_enabled) {
      this->set_fuse_pad(true);
    }
  }

  virtual ~MklFusedConvOp() {}
};

// We create a new class for each version of quantized convolution and inherit
// from the FP32 version of the base class.
template <typename Device, typename Tbias, typename Toutput,
          typename Ttemp_output, bool bias_enabled>
class MklQuantizedConv2DOp
    : public MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output,
                       int32, bias_enabled, false, false> {
 public:
  virtual ~MklQuantizedConv2DOp() {
    if (this->input_bias_ != nullptr) {
      delete this->input_bias_;
      input_bias_ = nullptr;
    }

    if (this->scaled_bias_ != nullptr) {
      delete this->scaled_bias_;
      scaled_bias_ = nullptr;
    }
  }

  explicit MklQuantizedConv2DOp(OpKernelConstruction* context)
      : MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
                  bias_enabled, false, false>(context) {
    bool is_filter_const;
    OP_REQUIRES_OK(context,
                   context->GetAttr("is_filter_const", &is_filter_const));
    OP_REQUIRES(context, is_filter_const,
                errors::InvalidArgument("Filter must be a constant"));
  }
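
  // Quantized Conv2D carries its quantization ranges as extra scalar inputs
  // after the data inputs. With bias enabled the layout is
  //   [src, filter, bias, min_input, max_input, min_filter, max_filter, ...],
  // with min_freezed_output/max_freezed_output following when requantization
  // is fused; bias_index_offset below accounts for the optional bias input.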
  void Compute(OpKernelContext* context) override {
    // Compute int32 output tensor
    MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
              bias_enabled, false, false>::Compute(context);

    // Compute additional outputs: min/max scalars.
    int bias_index_offset;
    bias_index_offset = bias_enabled ? 1 : 0;

    const float min_input =
        context->input(2 + bias_index_offset).flat<float>()(0);
    const float max_input =
        context->input(3 + bias_index_offset).flat<float>()(0);
    const float min_filter =
        context->input(4 + bias_index_offset).flat<float>()(0);
    const float max_filter =
        context->input(5 + bias_index_offset).flat<float>()(0);

    float min_output_value;
    float max_output_value;
    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value) {
      // This is the case when convolution and requantization are fused.
      // min_freezed_output and max_freezed_output are the actual range
      // of the output.
      min_output_value =
          context->input(6 + bias_index_offset).flat<float>()(0);
      max_output_value =
          context->input(7 + bias_index_offset).flat<float>()(0);
    } else {
      MklQuantizationRangeForMultiplication<quint8, qint8, qint32>(
          min_input, max_input, min_filter, max_filter, &min_output_value,
          &max_output_value);
    }

    Tensor* output_min = nullptr;
    Tensor* output_max = nullptr;
    MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
    output_min_mkl_shape.SetMklTensor(false);
    output_max_mkl_shape.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 1, &output_min, {},
                              output_min_mkl_shape);
    AllocateOutputSetMklShape(context, 2, &output_max, {},
                              output_max_mkl_shape);
    output_min->flat<float>()(0) = min_output_value;
    output_max->flat<float>()(0) = max_output_value;
  }

 protected:
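  // When requantization is fused (Toutput is quint8 or qint8), the int32
  // accumulator has to be rescaled into the requested 8-bit output range.
  // The "output_scale" post-op added below encodes that conversion as a
  // single multiplier:
  //   scale = max(|min_output_value|, |max_output_value|) /
  //           max(|min_freezed_output|, |max_freezed_output|) / 2^23
  // for quint8 output (2^24 for qint8).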
  void ExtendConvFwdParams(OpKernelContext* context,
                           MklConvFwdParams& params) override {
    MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
              bias_enabled, false, false>::ExtendConvFwdParams(context, params);

    // When the output type is quint8, the output data is requantized
    // into quint8. A post_op "output_scale" is added to do the conversion.
    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value) {
      int bias_index_offset;
      bias_index_offset = bias_enabled ? 1 : 0;

      const float min_input =
          context->input(2 + bias_index_offset).flat<float>()(0);
      const float max_input =
          context->input(3 + bias_index_offset).flat<float>()(0);
      const float min_filter =
          context->input(4 + bias_index_offset).flat<float>()(0);
      const float max_filter =
          context->input(5 + bias_index_offset).flat<float>()(0);
      const float min_freezed_output =
          context->input(6 + bias_index_offset).flat<float>()(0);
      const float max_freezed_output =
          context->input(7 + bias_index_offset).flat<float>()(0);

      float min_output_value;
      float max_output_value;
      MklQuantizationRangeForMultiplication<quint8, qint8, qint32>(
          min_input, max_input, min_filter, max_filter, &min_output_value,
          &max_output_value);
      float scale_int32 =
          std::max(std::abs(min_output_value), std::abs(max_output_value));
      float scale_eightbit =
          std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
      float scale = 1.0;
      if (std::is_same<Toutput, quint8>::value)
        scale = scale_int32 / scale_eightbit / static_cast<float>(1 << 23);
      else
        scale = scale_int32 / scale_eightbit / static_cast<float>(1 << 24);

      std::vector<float> output_scale;
      output_scale.push_back(scale);
      params.post_op_params.push_back({"output_scale", output_scale});
    }
  }

  Tbias* GetBiasHandle(OpKernelContext* context,
                       std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
                       const Tensor& bias_tensor) override {
    int bias_index_offset;
    bias_index_offset = bias_enabled ? 1 : 0;

    const float min_input =
        context->input(2 + bias_index_offset).flat<float>()(0);
    const float max_input =
        context->input(3 + bias_index_offset).flat<float>()(0);
    const float min_filter =
        context->input(4 + bias_index_offset).flat<float>()(0);
    const float max_filter =
        context->input(5 + bias_index_offset).flat<float>()(0);

    std::vector<mkldnn::primitive> net;
    if (bias_enabled) {
      if (std::is_same<Tbias, qint32>::value) {
        return static_cast<Tbias*>(
            const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
      }
      // If bias is enabled and requantization is not fused, scale the
      // bias to be consistent with the quantized input and quantized filter
      // (255 is the quint8 input range, 127 the positive qint8 filter range).
      float bias_scale = 255.0 * 127.0 /
                         (std::max(std::abs(max_input), std::abs(min_input)) *
                          std::max(std::abs(max_filter), std::abs(min_filter)));
      std::vector<float> scales;
      scales.push_back(bias_scale);
      mkldnn::primitive_attr bias_attr;
      bias_attr.set_output_scales(0, scales);

      void* bias_buf = static_cast<void*>(
          const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
      input_bias_ = new memory(conv_fwd_pd->bias_primitive_desc(), bias_buf);
      scaled_bias_ = new memory(conv_fwd_pd->bias_primitive_desc());
      auto reorder_desc = mkldnn::reorder::primitive_desc(
          input_bias_->get_primitive_desc(),
          scaled_bias_->get_primitive_desc(), bias_attr);
      net.push_back(mkldnn::reorder(reorder_desc, *input_bias_, *scaled_bias_));
      stream(stream::kind::eager).submit(net).wait();
      return reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
    } else {
      return nullptr;
    }
  }

  memory* input_bias_ = nullptr;
  memory* scaled_bias_ = nullptr;
};

template <typename Device, typename Tbias, typename Toutput,
          typename Ttemp_output, bool bias_enabled>
class MklQuantizedConv2DReluOp
    : public MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
                                  bias_enabled> {
 public:
  virtual ~MklQuantizedConv2DReluOp() {}

  explicit MklQuantizedConv2DReluOp(OpKernelConstruction* context)
      : MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
                             bias_enabled>(context) {}

 protected:
  void ExtendConvFwdParams(OpKernelContext* context,
                           MklConvFwdParams& params) override {
    MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
                         bias_enabled>::ExtendConvFwdParams(context, params);
    params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
  }
};

template <typename Device, typename Tbias, typename Toutput,
          typename Ttemp_output, bool bias_enabled>
class MklQuantizedConv2DSumReluOp
    : public MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
                                  bias_enabled> {
 public:
  virtual ~MklQuantizedConv2DSumReluOp() {
    if (this->summand_ != nullptr) {
      delete this->summand_;
      summand_ = nullptr;
    }

    if (this->dst_ != nullptr) {
      delete this->dst_;
      dst_ = nullptr;
    }
  }

  explicit MklQuantizedConv2DSumReluOp(OpKernelConstruction* context)
      : MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
                             bias_enabled>(context) {}

 protected:
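  // The "sum" post-op accumulates the already-quantized summand input into
  // the convolution output in place. Its scale (beta) rescales the summand
  // range into the output range:
  //   beta = max(|min_freezed_summand|, |max_freezed_summand|) /
  //          max(|min_freezed_output|, |max_freezed_output|),
  // doubled when the summand is qint8 rather than quint8.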
  void ExtendConvFwdParams(OpKernelContext* context,
                           MklConvFwdParams& params) override {
    MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
                         bias_enabled>::ExtendConvFwdParams(context, params);
    // Calculate the scale (beta in the MKL-DNN API) for the sum post-op.
    if (std::is_same<Toutput, quint8>::value) {
      int summand_idx = context->num_inputs() / 2 - 1 - 2;
      DataType summand_type = this->input_type(summand_idx);
      bool summand_condition =
          (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
      CHECK((summand_condition));
      int bias_index_offset = bias_enabled ? 1 : 0;
      const float min_freezed_output =
          context->input(6 + bias_index_offset).flat<float>()(0);
      const float max_freezed_output =
          context->input(7 + bias_index_offset).flat<float>()(0);
      const float min_freezed_summand =
          context->input(9 + bias_index_offset).flat<float>()(0);
      const float max_freezed_summand =
          context->input(10 + bias_index_offset).flat<float>()(0);

      float scale_output =
          std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
      float scale_summand = std::max(std::abs(min_freezed_summand),
                                     std::abs(max_freezed_summand));
      if (summand_type == DT_QUINT8)
        params.post_op_params.push_back(
            {"sum", {scale_summand / scale_output}});
      else
        params.post_op_params.push_back(
            {"sum", {2.0f * scale_summand / scale_output}});
    } else {
      params.post_op_params.push_back({"sum", {1.0}});
    }
    params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
  }
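
  // For the fused-sum kernel the convolution accumulates on top of the
  // summand. When the requantized (quint8) summand arrives as an MKL tensor,
  // its buffer is forwarded directly as the op output; otherwise the float
  // summand is scaled and reordered into the newly allocated output tensor
  // before the convolution runs.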
  void AllocateOutputTensor(OpKernelContext* context,
                            const ConvFwdPd& conv_prim_desc,
                            const memory::dims& output_dims_mkl_order,
                            memory::format output_tf_format,
                            Tensor** output_tensor) override {
    int summand_idx = context->num_inputs() / 2 - 1;
    float reorder_sum_scale = 1.0;
    if (std::is_same<Toutput, quint8>::value) {
      summand_idx -= 2;
      DataType summand_type = this->input_type(summand_idx);
      bool summand_condition =
          (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
      CHECK((summand_condition));
      Tensor& summand = const_cast<Tensor&>(MklGetInput(context, summand_idx));
      MklDnnShape summand_mkl_shape;
      GetMklShape(context, summand_idx, &summand_mkl_shape);
      auto dst_md = summand_mkl_shape.GetMklLayout();
      if (summand_mkl_shape.IsMklTensor()) {
        if (summand_type == DT_QINT8) {
          OP_REQUIRES_OK(context, summand.BitcastFrom(summand, DT_QUINT8,
                                                      summand.shape()));
          dst_md.data.data_type =
              static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
          summand_mkl_shape.SetMklLayout(&dst_md);
          summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
        }
        ForwardMklTensorInToOutWithMklShape(context, summand_idx, 0,
                                            summand_mkl_shape);
        *output_tensor = const_cast<Tensor*>(&summand);
        return;
      } else {
        TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
                           "Current fusion is not successful."));
      }
    }
    // TODO(mdfaijul): Add cleaner code for the non-MKL tensor case.
    MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
              bias_enabled, false,
              false>::AllocateOutputTensor(context, conv_prim_desc,
                                           output_dims_mkl_order,
                                           output_tf_format, output_tensor);
    const Tensor& summand = MklGetInput(context, summand_idx);
    if (summand.dtype() != DT_FLOAT)
      TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
                         "Current fusion requires summand to be float"));
    MklDnnShape summand_mkl_shape;
    GetMklShape(context, summand_idx, &summand_mkl_shape);
    // Compute the scale for reordering the summand into the output buffer.
    int bias_index_offset = bias_enabled ? 1 : 0;
    const float min_input =
        context->input(2 + bias_index_offset).flat<float>()(0);
    const float max_input =
        context->input(3 + bias_index_offset).flat<float>()(0);
    const float min_filter =
        context->input(4 + bias_index_offset).flat<float>()(0);
    const float max_filter =
        context->input(5 + bias_index_offset).flat<float>()(0);

    reorder_sum_scale = 255.0 * 127.0 /
                        (std::max(std::abs(max_input), std::abs(min_input)) *
                         std::max(std::abs(max_filter), std::abs(min_filter)));
    std::vector<float> scales;
    scales.push_back(reorder_sum_scale);
    mkldnn::primitive_attr reorder_attr;
    reorder_attr.set_output_scales(0, scales);

    auto summand_md =
        summand_mkl_shape.IsMklTensor()
            ? summand_mkl_shape.GetMklLayout()
            : memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
                           memory::format::nhwc);
    auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_);
    void* summand_buf =
        static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
    void* dst_buf =
        static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
    summand_ = new memory(summand_pd, summand_buf);
    dst_ = new memory(conv_prim_desc.dst_primitive_desc(), dst_buf);
    auto reorder_desc = mkldnn::reorder::primitive_desc(
        summand_pd, conv_prim_desc.dst_primitive_desc(), reorder_attr);

    std::vector<mkldnn::primitive> net;
    net.push_back(mkldnn::reorder(reorder_desc, *summand_, *dst_));
    stream(stream::kind::eager).submit(net).wait();
  }

  memory* summand_ = nullptr;
  memory* dst_ = nullptr;
};
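
// In the quantized registrations below, the Tbias template argument
// distinguishes kernels whose bias input is already quantized (qint32) from
// those that receive a float bias and scale it in GetBiasHandle().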

// INT8 kernel registration
// Register NoOp kernel for QuantizedConv2D for qint8 filter
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2D")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint8>("out_type"),
                        NoOp);

// Register a templatized implementation of MklQuantizedConv2D.
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2D")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, float, qint32, qint32, false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, qint32, qint8, qint8, false>);

// Register NoOp kernel for QuantizedConv2DWithBias to get a python interface.
// This kernel will be replaced by an MKL kernel during graph
// optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBias")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint8>("out_type"),
                        NoOp);

// Register a templatized implementation of MklQuantizedConv2DWithBias.
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBias")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, float, qint32, qint32, true>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("Tbias")
        .TypeConstraint<qint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, qint32, qint8, qint8, true>);
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<float>("Tbias")
        .TypeConstraint<qint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, float, qint8, qint8, true>);

// Register NoOp kernel for QuantizedConv2DAndRelu to get a python interface.
// This kernel will be replaced by an MKL kernel during graph-optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<quint8>("out_type"),
                        NoOp);

// Register a templatized implementation of MklQuantizedConv2DAndRelu.
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DAndRelu")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DReluOp<CPUDevice, float, qint32, qint32, false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DReluOp<CPUDevice, qint32, quint8, quint8, false>);

// Register NoOp kernel for QuantizedConv2DWithBiasAndRelu to get a python
// interface.
// This kernel will be replaced by an MKL kernel during graph-optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

// Register NoOp kernel for QuantizedConv2DWithBiasAndReluAndRequantize
// to get a python interface.
// This kernel will be replaced by an MKL kernel during graph-optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<quint8>("out_type"),
                        NoOp);

// Register a templatized implementation of MklQuantizedConv2DWithBiasAndRelu.
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasAndRelu")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DReluOp<CPUDevice, float, qint32, qint32, true>);

// Register a templatized implementation of
// MklQuantizedConv2DWithBiasAndReluAndRequantize.
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<float>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DReluOp<CPUDevice, float, quint8, quint8, true>);
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DReluOp<CPUDevice, qint32, quint8, quint8, true>);

// Register NoOp kernel for QuantizedConv2DWithBiasSumAndRelu to get a python
// interface.
// This kernel will be replaced by an MKL kernel during graph-optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasSumAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasSumAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<quint8>("out_type"),
                        NoOp);
REGISTER_KERNEL_BUILDER(
    Name("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<quint8>("out_type"),
    NoOp);
// Register a templatized implementation of
// MklQuantizedConv2DWithBiasSumAndRelu.
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasSumAndRelu")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DSumReluOp<CPUDevice, float, qint32, qint32, true>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DSumReluOp<CPUDevice, qint32, quint8, quint8, true>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DSumReluOp<CPUDevice, qint32, quint8, qint8, true>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<float>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DSumReluOp<CPUDevice, float, quint8, quint8, true>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<float>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DSumReluOp<CPUDevice, float, quint8, qint8, true>);
#endif  // !INTEL_MKL_ML_ONLY

// Register 2D operations
#define REGISTER_MKL_CPU_2D(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("_MklConv2D")                               \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<T>("T")                      \
                              .Label(mkl_op_registry::kMklOpLabel),        \
                          MklConvOp<CPUDevice, float, float, float, float, \
                                    float, int32, false, false, false>);   \
  REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias")                       \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<T>("T")                      \
                              .Label(mkl_op_registry::kMklOpLabel),        \
                          MklConvOp<CPUDevice, float, float, float, float, \
                                    float, int32, true, false, false>);    \
  REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DWithBias")                 \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<T>("T")                      \
                              .Label(mkl_op_registry::kMklOpLabel),        \
                          MklDummyOp<CPUDevice, T>);                       \
  REGISTER_KERNEL_BUILDER(Name("_MklPadWithConv2D")                        \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<T>("T")                      \
                              .TypeConstraint<int32>("Tpaddings")          \
                              .Label(mkl_op_registry::kMklOpLabel),        \
                          MklConvOp<CPUDevice, float, float, float, float, \
                                    float, int32, false, true, false>);    \
  REGISTER_KERNEL_BUILDER(Name("_MklPadWithConv2D")                        \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<T>("T")                      \
                              .TypeConstraint<int64>("Tpaddings")          \
                              .Label(mkl_op_registry::kMklOpLabel),        \
                          MklConvOp<CPUDevice, float, float, float, float, \
                                    float, int64, false, true, false>);    \
  REGISTER_KERNEL_BUILDER(Name("__MklDummyPadWithConv2D")                  \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<T>("T")                      \
                              .TypeConstraint<int32>("Tpaddings")          \
                              .Label(mkl_op_registry::kMklOpLabel),        \
                          MklDummyOp<CPUDevice, T>);

TF_CALL_float(REGISTER_MKL_CPU_2D);

#define REGISTER_MKL_CPU_2D_DEPTHWISE(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("_MklDepthwiseConv2dNative")                \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<float>("T")                  \
                              .Label(mkl_op_registry::kMklOpLabel),        \
                          MklConvOp<CPUDevice, float, float, float, float, \
                                    float, int32, false, false, true>);

TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE);

// Note that we are registering _MklFusedConv2D here.
// The 'fused_ops' attribute is checked to decide whether bias is enabled.
#define REGISTER_MKL_CPU_2D_FUSED(T)                                \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("_MklFusedConv2D")                                       \
          .Device(DEVICE_CPU)                                       \
          .TypeConstraint<T>("T")                                   \
          .Label(mkl_op_registry::kMklOpLabel),                     \
      MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false>);      \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("_MklPadWithFusedConv2D")                                \
          .Device(DEVICE_CPU)                                       \
          .TypeConstraint<int32>("Tpaddings")                       \
          .TypeConstraint<T>("T")                                   \
          .Label(mkl_op_registry::kMklOpLabel),                     \
      MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true>);       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("_MklPadWithFusedConv2D")                                \
          .Device(DEVICE_CPU)                                       \
          .TypeConstraint<T>("T")                                   \
          .TypeConstraint<int64>("Tpaddings")                       \
          .Label(mkl_op_registry::kMklOpLabel),                     \
      MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true>);       \
  REGISTER_KERNEL_BUILDER(Name("__MklDummyPadWithFusedConv2D")      \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .TypeConstraint<int32>("Tpaddings")   \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklDummyOp<CPUDevice, T>);

TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED);

// Register 3D operations
#define REGISTER_MKL_CPU_3D(T)                                          \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("_MklConv3D")                                                \
          .Device(DEVICE_CPU)                                           \
          .TypeConstraint<T>("T")                                       \
          .Label(mkl_op_registry::kMklOpLabel),                         \
      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false>);
TF_CALL_float(REGISTER_MKL_CPU_3D);

}  // namespace tensorflow
#endif  // INTEL_MKL