1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 #ifdef INTEL_MKL
18 
19 #include <string.h>
20 #include <algorithm>
21 #include <map>
22 #include <vector>
23 
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/numeric_op.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/register_types.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/framework/tensor_slice.h"
32 #include "tensorflow/core/kernels/mkl_conv_ops.h"
33 #include "tensorflow/core/kernels/mkl_quantized_conv_ops.h"
34 #include "tensorflow/core/kernels/no_op.h"
35 #include "tensorflow/core/kernels/ops_util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/gtl/array_slice.h"
38 #include "tensorflow/core/lib/strings/numbers.h"
39 #include "tensorflow/core/lib/strings/str_util.h"
40 #include "tensorflow/core/lib/strings/strcat.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/platform/macros.h"
43 #include "tensorflow/core/util/padding.h"
44 #include "tensorflow/core/util/tensor_format.h"
45 
46 #include "tensorflow/core/util/mkl_util.h"
47 
48 #ifndef INTEL_MKL_ML_ONLY
49 #include "mkldnn.hpp"
50 
51 using mkldnn::prop_kind;
52 using mkldnn::stream;
53 using mkldnn::convolution_forward;
54 using mkldnn::convolution_direct;
55 
56 #else
57 #include "mkl_dnn.h"
58 #include "mkl_dnn_types.h"
59 #endif
60 
61 namespace tensorflow {
62 
63 #ifndef INTEL_MKL_ML_ONLY
64 
65 // This structure aggregates multiple inputs to Conv2DFwd* methods.
66 struct MklConvFwdParams {
67   memory::dims src_dims;
68   memory::dims filter_dims;
69   memory::dims bias_dims;
70   memory::dims dst_dims;
71   memory::dims strides;
72   memory::dims dilations;
73   memory::dims padding_left;
74   memory::dims padding_right;
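  // Data type signature of (input, filter, bias, output); filled in by the op
  // (see ExtendConvFwdParams) and folded into the primitive cache key.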
75   string dtypes = string("");
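  // A fused post-op and its float parameters, e.g. {"relu", {scale, alpha, beta}},
  // {"sum", {scale}}, or {"output_scale", {scale}}.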
76   struct PostOpParam {
77     string name;
78     std::vector<float> param;
79   };
80   std::vector<PostOpParam> post_op_params;
81 
82   MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
83                    memory::dims bias_dims, memory::dims dst_dims,
84                    memory::dims strides, memory::dims dilations,
85                    memory::dims padding_left, memory::dims padding_right)
86       : src_dims(src_dims),
87         filter_dims(filter_dims),
88         bias_dims(bias_dims),
89         dst_dims(dst_dims),
90         strides(strides),
91         dilations(dilations),
92         padding_left(padding_left),
93         padding_right(padding_right) {}
94 };
95 
96 typedef mkldnn::convolution_forward::primitive_desc ConvFwdPd;
97 
98 // With quantization, the input, filter, and output can have different types,
99 // so we use a different template parameter for each type.
100 template <typename T, typename Tinput, typename Tfilter, typename Tbias,
101           typename Toutput>
102 class MklConvFwdPrimitive : public MklPrimitive {
103  public:
104   explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
105       : cpu_engine_(engine::cpu, 0) {
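    // Create an eager stream on which the cached primitives will be submitted.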
106     context_.fwd_stream.reset(new stream(stream::kind::eager));
107     // Create conv primitive
108     if (context_.conv_fwd == nullptr) {
109       Setup(convFwdDims);
110     }
111   }
112 
113   ~MklConvFwdPrimitive() {}
114 
115   // Convolution forward execute with bias
116   //   src_data:    input data buffer of src
117   //   filter_data: input data buffer of filter (weights)
118   //   bias_data:   input data buffer of bias
119   //   dst_data:    output data buffer of dst
120   void Execute(const Tinput* src_data, const Tfilter* filter_data,
121                const Tbias* bias_data, const Toutput* dst_data) {
122     context_.src_mem->set_data_handle(
123         static_cast<void*>(const_cast<Tinput*>(src_data)));
124     context_.filter_mem->set_data_handle(
125         static_cast<void*>(const_cast<Tfilter*>(filter_data)));
126     context_.bias_mem->set_data_handle(
127         static_cast<void*>(const_cast<Tbias*>(bias_data)));
128     context_.dst_mem->set_data_handle(
129         static_cast<void*>(const_cast<Toutput*>(dst_data)));
130     context_.fwd_stream->submit(context_.fwd_primitives);
131 
132     // After execution, set data handles back
133     context_.src_mem->set_data_handle(DummyData);
134     context_.filter_mem->set_data_handle(DummyData);
135     context_.bias_mem->set_data_handle(DummyData);
136     context_.dst_mem->set_data_handle(DummyData);
137 
138     return;
139   }
140 
141   // Convolution forward execute without bias
142   //   src_data:    input data buffer of src
143   //   filter_data: input data buffer of filter (weights)
144   //   dst_data:    output data buffer of dst
145   void Execute(const Tinput* src_data, const Tfilter* filter_data,
146                const Toutput* dst_data) {
147     context_.src_mem->set_data_handle(
148         static_cast<void*>(const_cast<Tinput*>(src_data)));
149     context_.filter_mem->set_data_handle(
150         static_cast<void*>(const_cast<Tfilter*>(filter_data)));
151     context_.dst_mem->set_data_handle(
152         static_cast<void*>(const_cast<Toutput*>(dst_data)));
153     context_.fwd_stream->submit(context_.fwd_primitives);
154 
155     // After execution, set data handle back
156     context_.src_mem->set_data_handle(DummyData);
157     context_.filter_mem->set_data_handle(DummyData);
158     context_.dst_mem->set_data_handle(DummyData);
159   }
160 
161   memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
162 
163   memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
164 
165   std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
166     return context_.fwd_pd;
167   }
168 
169  private:
170   // Primitive reuse context for Conv2D Fwd op
171   struct ConvFwdContext {
172     // Expected memory format for this primitive instance
173     memory::format src_fmt;
174     memory::format filter_fmt;
175 
176     // MKLDNN memory
177     std::shared_ptr<mkldnn::memory> src_mem;
178     std::shared_ptr<mkldnn::memory> filter_mem;
179     std::shared_ptr<mkldnn::memory> bias_mem;
180     std::shared_ptr<mkldnn::memory> dst_mem;
181 
182     // Desc & primitive desc
183     std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
184 
185     // Memory desc
186     std::shared_ptr<mkldnn::memory::desc> src_md;
187     std::shared_ptr<mkldnn::memory::desc> filter_md;
188     std::shared_ptr<mkldnn::memory::desc> bias_md;
189     std::shared_ptr<mkldnn::memory::desc> dst_md;
190 
191     // Convolution primitive
192     std::shared_ptr<ConvFwdPd> fwd_pd;
193     std::shared_ptr<mkldnn::primitive> conv_fwd;
194 
195     std::shared_ptr<mkldnn::stream> fwd_stream;
196     std::vector<mkldnn::primitive> fwd_primitives;
197 
198     ConvFwdContext()
199         : src_fmt(memory::format::any),
200           filter_fmt(memory::format::any),
201           src_mem(nullptr),
202           filter_mem(nullptr),
203           bias_mem(nullptr),
204           dst_mem(nullptr),
205           fwd_desc(nullptr),
206           src_md(nullptr),
207           filter_md(nullptr),
208           bias_md(nullptr),
209           fwd_pd(nullptr),
210           conv_fwd(nullptr),
211           fwd_stream(nullptr) {}
212   };
213 
214   void Setup(const MklConvFwdParams& convFwdDims) {
215     // Create memory descriptors for convolution data w/ no specified format
216     context_.src_md.reset(new memory::desc(
217         {convFwdDims.src_dims}, MklDnnType<Tinput>(), memory::format::any));
218 
219     context_.filter_md.reset(new memory::desc(
220         {convFwdDims.filter_dims}, MklDnnType<Tfilter>(), memory::format::any));
221 
222     context_.dst_md.reset(new memory::desc(
223         {convFwdDims.dst_dims}, MklDnnType<Toutput>(), memory::format::any));
224 
225     if (!convFwdDims.bias_dims.empty())
226       context_.bias_md.reset(new memory::desc(
227           {convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format::any));
228 
229     // Create a convolution descriptor
230     if (!convFwdDims.bias_dims.empty()) {
231       context_.fwd_desc.reset(new convolution_forward::desc(
232           prop_kind::forward, convolution_direct, *context_.src_md,
233           *context_.filter_md, *context_.bias_md, *context_.dst_md,
234           convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
235           convFwdDims.padding_right, padding_kind::zero));
236     } else {
237       context_.fwd_desc.reset(new convolution_forward::desc(
238           prop_kind::forward, convolution_direct, *context_.src_md,
239           *context_.filter_md, *context_.dst_md, convFwdDims.strides,
240           convFwdDims.dilations, convFwdDims.padding_left,
241           convFwdDims.padding_right, padding_kind::zero));
242     }
243 
244     context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
245 
246     // Check if there are any fusions as post-ops
247     auto const& post_op_params = convFwdDims.post_op_params;
248     mkldnn::primitive_attr post_ops_attr;
249     mkldnn::post_ops post_ops;
250     if (!post_op_params.empty()) {
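      // Translate each requested fusion into the corresponding MKL-DNN
      // post-op (eltwise ReLU, sum) or primitive attribute (output scales).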
251       for (auto const& post_op_param : post_op_params) {
252         if (post_op_param.name == "relu") {
253           DCHECK_EQ(post_op_param.param.size(), 3);
254           float op_scale = post_op_param.param[0];
255           float op_alpha = post_op_param.param[1];
256           float op_beta = post_op_param.param[2];
257           post_ops.append_eltwise(op_scale, mkldnn::eltwise_relu, op_alpha,
258                                   op_beta);
259         } else if (post_op_param.name == "sum") {
260           DCHECK_EQ(post_op_param.param.size(), 1);
261           float op_scale = post_op_param.param[0];
262           post_ops.append_sum(op_scale);
263         } else if (post_op_param.name == "output_scale") {
264           DCHECK_EQ(post_op_param.param.size(), 1);
265           std::vector<float> scales;
266           scales.push_back(post_op_param.param[0]);
267           post_ops_attr.set_output_scales(0, scales);
268         } else {
269           DCHECK((post_op_param.name == "relu") ||
270                  (post_op_param.name == "sum") ||
271                  (post_op_param.name == "output_scale"));
272         }
273       }
274       post_ops_attr.set_post_ops(post_ops);
275       context_.fwd_pd.reset(
276           new ConvFwdPd(*context_.fwd_desc, post_ops_attr, cpu_engine_));
277     } else {
278       context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
279     }
280 
281     // Store the expected memory format
282     context_.src_fmt = static_cast<mkldnn::memory::format>(
283         context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
284 
285     context_.filter_fmt = static_cast<mkldnn::memory::format>(
286         context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
287 
288     // Create memory primitives based on dummy data
289     context_.src_mem.reset(
290         new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
291     context_.filter_mem.reset(
292         new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
293     context_.dst_mem.reset(
294         new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
295 
296     // Create convolution primitive and add it to net
297     if (!convFwdDims.bias_dims.empty()) {
298       context_.bias_mem.reset(new memory(
299           {{{convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::x},
300            cpu_engine_},
301           DummyData));
302       context_.conv_fwd.reset(new convolution_forward(
303           *context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
304           *context_.bias_mem, *context_.dst_mem));
305     } else {
306       context_.conv_fwd.reset(
307           new convolution_forward(*context_.fwd_pd, *context_.src_mem,
308                                   *context_.filter_mem, *context_.dst_mem));
309     }
310 
311     context_.fwd_primitives.push_back(*context_.conv_fwd);
312     return;
313   }
314 
315   struct ConvFwdContext context_;
316   engine cpu_engine_;
317 };
318 
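// Factory class that creates MklConvFwdPrimitive objects and caches them in
// the MklPrimitiveFactory op pool, keyed by the convolution parameters.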
319 template <typename T, typename Tinput, typename Tfilter, typename Tbias,
320           typename Toutput>
321 class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
322  public:
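  // Returns a convolution forward primitive for the given parameters: a fresh
  // one if do_not_cache is set, otherwise one fetched from (or added to) the pool.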
323   static MklConvFwdPrimitive<T, Tinput, Tfilter, Tbias, Toutput>* Get(
324       const MklConvFwdParams& convFwdDims, bool do_not_cache) {
325     MklConvFwdPrimitive<T, Tinput, Tfilter, Tbias, Toutput>* conv_fwd = nullptr;
326 
327     if (do_not_cache) {
328       // Always create a new primitive
329       conv_fwd = new MklConvFwdPrimitive<T, Tinput, Tfilter, Tbias, Toutput>(
330           convFwdDims);
331     } else {
332       // Try to find a suitable one in pool
333       conv_fwd = dynamic_cast<
334           MklConvFwdPrimitive<T, Tinput, Tfilter, Tbias, Toutput>*>(
335           MklConvFwdPrimitiveFactory<T, Tinput, Tfilter, Tbias,
336                                      Toutput>::GetInstance()
337               .GetConvFwd(convFwdDims));
338       if (conv_fwd == nullptr) {
339         conv_fwd = new MklConvFwdPrimitive<T, Tinput, Tfilter, Tbias, Toutput>(
340             convFwdDims);
341         MklConvFwdPrimitiveFactory<T, Tinput, Tfilter, Tbias,
342                                    Toutput>::GetInstance()
343             .SetConvFwd(convFwdDims, conv_fwd);
344       }
345     }
346 
347     return conv_fwd;
348   }
349 
350  private:
351   MklConvFwdPrimitiveFactory() {}
352   ~MklConvFwdPrimitiveFactory() {}
353 
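  // Indices of the height and width entries in the 2D dilation dims.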
354   static const int kDilationH = 0, kDilationW = 1;
355 
356   static MklConvFwdPrimitiveFactory& GetInstance() {
357     static MklConvFwdPrimitiveFactory instance_;
358     return instance_;
359   }
360 
361   static string CreateKey(const MklConvFwdParams& convFwdDims) {
362     string prefix = "conv_fwd_";
363     FactoryKeyCreator key_creator;
364     key_creator.AddAsKey(prefix);
365     key_creator.AddAsKey(convFwdDims.src_dims);
366     key_creator.AddAsKey(convFwdDims.filter_dims);
367     key_creator.AddAsKey(convFwdDims.bias_dims);
368     key_creator.AddAsKey(convFwdDims.dst_dims);
369     key_creator.AddAsKey(convFwdDims.strides);
370     key_creator.AddAsKey(convFwdDims.dilations);
371     key_creator.AddAsKey(convFwdDims.padding_left);
372     key_creator.AddAsKey(convFwdDims.padding_right);
373     key_creator.AddAsKey(convFwdDims.dtypes);
374 
375     // Generate keys for post-ops
376     for (auto const& post_op_param : convFwdDims.post_op_params) {
377       if (post_op_param.name == "relu") {
378         DCHECK_EQ(post_op_param.param.size(), 3);
379         key_creator.AddAsKey(post_op_param.name);
380         key_creator.AddAsKey(post_op_param.param[0]);
381         key_creator.AddAsKey(post_op_param.param[1]);
382         key_creator.AddAsKey(post_op_param.param[2]);
383       } else if (post_op_param.name == "sum") {
384         DCHECK_EQ(post_op_param.param.size(), 1);
385         key_creator.AddAsKey(post_op_param.name);
386         key_creator.AddAsKey(post_op_param.param[0]);
387       } else if (post_op_param.name == "output_scale") {
388         DCHECK_EQ(post_op_param.param.size(), 1);
389         key_creator.AddAsKey(post_op_param.name);
390         key_creator.AddAsKey(post_op_param.param[0]);
391       } else {
392         return string("not_a_key");
393       }
394     }
395 
396     return key_creator.GetKey();
397   }
398 
399   MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) {
400     string key = CreateKey(convFwdDims);
401     return this->GetOp(key);
402   }
403 
404   void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
405     string key = CreateKey(convFwdDims);
406     this->SetOp(key, op);
407   }
408 };
409 
410 #endif  // !INTEL_MKL_ML_ONLY
411 
412 typedef Eigen::ThreadPoolDevice CPUDevice;
413 
414 // For now, MKL-ML is the default, so MKL-DNN is not the default choice.
415 #ifdef INTEL_MKL_ML_ONLY
416 template <typename Device, typename T, bool bias_enabled>
417 class MklConvOp : public OpKernel {
418  public:
419   ~MklConvOp() {}
420 
421   explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
422     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
423     string data_format;
424     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
425     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
426                 errors::InvalidArgument("Invalid data format"));
427     OP_REQUIRES(context, strides_.size() == 4,
428                 errors::InvalidArgument("Sliding window strides field must "
429                                         "specify 4 dimensions"));
430 
431     const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
432     const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
433     OP_REQUIRES(
434         context, stride_n == 1 && stride_c == 1,
435         errors::InvalidArgument("Current implementation does not yet support "
436                                 "strides in the batch and depth dimensions."));
437     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
438   }
439 
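  // Runs the convolution using the legacy MKL-ML (dnn*_F32) API.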
440   void Compute(OpKernelContext* context) override {
441     MklConv2DOpContext mkl_context;
442     const Tensor& input = MklGetInput(context, 0);
443     GetMklShape(context, 0, &(mkl_context.input_shape));
444     bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
445 
446     const Tensor& filter = MklGetInput(context, 1);
447     MklShape mkl_filter_shape;
448     GetMklShape(context, 1, &mkl_filter_shape);
449     CHECK(!mkl_filter_shape.IsMklTensor())
450         << "Conv filter should not be in MKL Layout";
451 
452     if (bias_enabled) {
453       const Tensor& bias = MklGetInput(context, 2);
454       OP_REQUIRES(context, bias.dims() == 1,
455                   errors::InvalidArgument("bias must be 1-dimensional: ",
456                                           bias.shape().DebugString()));
457     }
458 
459     if (!input_in_mkl_format) {
460       OP_REQUIRES(context, input.dims() == 4,
461                   errors::InvalidArgument("input must be 4-dimensional",
462                                           input.shape().DebugString()));
463     }
464 
465     OP_REQUIRES(context, filter.dims() == 4,
466                 errors::InvalidArgument("filter must be 4-dimensional: ",
467                                         filter.shape().DebugString()));
468 
469     for (int i = 0; i < 3; ++i) {
470       OP_REQUIRES(
471           context,
472           FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
473           errors::InvalidArgument("filter too large"));
474     }
475 
476     const int64 input_depth =
477         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C')
478                             : GetTensorDim(input, data_format_, 'C');
479     OP_REQUIRES(context, input_depth == filter.dim_size(2),
480                 errors::InvalidArgument(
481                     "input and filter must have the same depth: ", input_depth,
482                     " vs ", filter.dim_size(2)));
483     // The last dimension for filter is out_depth.
484     const int out_depth = static_cast<int>(filter.dim_size(3));
485 
486     // The second dimension for input is rows/height.
487     // The first dimension for filter is rows/height.
488     const int64 input_rows_raw =
489         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H')
490                             : GetTensorDim(input, data_format_, 'H');
491     OP_REQUIRES(
492         context,
493         FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
494         errors::InvalidArgument("Input rows too large"));
495     const int input_rows = static_cast<int>(input_rows_raw);
496     const int filter_rows = static_cast<int>(filter.dim_size(0));
497 
498     // The third dimension for input is columns/width.
499     // The second dimension for filter is columns/width.
500     const int64 input_cols_raw =
501         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W')
502                             : GetTensorDim(input, data_format_, 'W');
503     OP_REQUIRES(
504         context,
505         FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
506         errors::InvalidArgument("Input cols too large"));
507     const int input_cols = static_cast<int>(input_cols_raw);
508     const int filter_cols = static_cast<int>(filter.dim_size(1));
509 
510     // The first dimension for input is batch.
511     const int64 input_batch_raw =
512         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N')
513                             : GetTensorDim(input, data_format_, 'N');
514     OP_REQUIRES(
515         context,
516         FastBoundsCheck(input_batch_raw, std::numeric_limits<int>::max()),
517         errors::InvalidArgument("batch is too large"));
518     const int batch = static_cast<int>(input_batch_raw);
519 
520     // For now we take the stride from the second and third dimensions only (we
521     // do not support striding on the batch or depth dimension).
522     const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
523     const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
524 
525     int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
526     OP_REQUIRES_OK(context,
527                    GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
528                                          padding_, &out_rows, &pad_rows));
529     OP_REQUIRES_OK(context,
530                    GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
531                                          padding_, &out_cols, &pad_cols));
532     TensorShape out_shape =
533         ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
534 
535     // Output tensor is of the following dimensions:
536     // [ in_batch, out_rows, out_cols, out_depth ]
537     Tensor* output = nullptr;
538 
539     // If there is nothing to compute, return.
540     if (out_shape.num_elements() == 0) {
541       // Nothing to do, allocate output tensor and return
542       MklShape mkl_output_mkl_shape;
543       mkl_output_mkl_shape.SetMklTensor(false);
544       AllocateOutputSetMklShape(context, 0, &output, input.shape(),
545                                 mkl_output_mkl_shape);
546       return;
547     }
548 
549     if (batch == 0) {
550       // Nothing to do, allocate output tensor and return
551       MklShape mkl_output_mkl_shape;
552       mkl_output_mkl_shape.SetMklTensor(false);
553       AllocateOutputSetMklShape(context, 0, &output, input.shape(),
554                                 mkl_output_mkl_shape);
555       return;
556     }
557 
558     // Create MKL convolution primitives
559     mkl_context.in_dims = input_in_mkl_format
560                               ? mkl_context.input_shape.GetDimension()
561                               : input.dims();
562     mkl_context.filter_dims = filter.dims();
563 
564     mkl_context.in_sizes[MklDims::W] = static_cast<size_t>(input_cols);
565     mkl_context.in_sizes[MklDims::H] = static_cast<size_t>(input_rows);
566     mkl_context.in_sizes[MklDims::C] = static_cast<size_t>(input_depth);
567     mkl_context.in_sizes[MklDims::N] = static_cast<size_t>(batch);
568 
569     mkl_context.out_sizes[MklDims::W] = static_cast<size_t>(out_cols);
570     mkl_context.out_sizes[MklDims::H] = static_cast<size_t>(out_rows);
571     mkl_context.out_sizes[MklDims::C] = static_cast<size_t>(out_depth);
572     mkl_context.out_sizes[MklDims::N] = static_cast<size_t>(batch);
573 
574     mkl_context.input_offset[0] = static_cast<int>(-pad_cols);
575     mkl_context.input_offset[1] = static_cast<int>(-pad_rows);
576 
577     mkl_context.conv_stride[0] = static_cast<size_t>(stride_cols);
578     mkl_context.conv_stride[1] = static_cast<size_t>(stride_rows);
579 
580     GetStridesFromSizes(data_format_, mkl_context.out_strides,
581                         mkl_context.out_sizes);
582     GetStridesFromSizes(data_format_, mkl_context.in_strides,
583                         mkl_context.in_sizes);
584 
585     // TF filter dimension order (out_depth, in_depth, cols, rows) ->
586     // MKL filter dimension order (out_depth, in_depth, rows, cols)
587     mkl_context.filter_sizes[0] = filter.dim_size(1);  // cols
588     mkl_context.filter_sizes[1] = filter.dim_size(0);  // rows
589     mkl_context.filter_sizes[2] = filter.dim_size(2);  // in_depth
590     mkl_context.filter_sizes[3] = filter.dim_size(3);  // out_depth
591 
592     // TF filter layout - (rows, cols, in_depth, out_depth)
593     mkl_context.filter_strides[0] =
594         filter.dim_size(2) * filter.dim_size(3);  // cols
595     mkl_context.filter_strides[1] =
596         filter.dim_size(1) * filter.dim_size(2) * filter.dim_size(3);  // rows
597     mkl_context.filter_strides[2] = filter.dim_size(3);  // in_depth
598     mkl_context.filter_strides[3] = 1;                   // out_depth
599 
600     if (bias_enabled) {
601       const Tensor& bias = MklGetInput(context, 2);
602       mkl_context.bias_sizes[0] = {static_cast<size_t>(bias.dim_size(0))};
603       mkl_context.bias_strides[0] = {1};
604     }
605 
606     // Create Convolution Primitive
607     if (bias_enabled) {
608       CHECK_EQ(
609           dnnConvolutionCreateForwardBias_F32(
610               &mkl_context.prim_fwd, nullptr, dnnAlgorithmConvolutionDirect,
611               mkl_context.in_dims, mkl_context.in_sizes, mkl_context.out_sizes,
612               mkl_context.filter_sizes, mkl_context.conv_stride,
613               mkl_context.input_offset, dnnBorderZeros),
614           E_SUCCESS);
615     } else {
616       CHECK_EQ(
617           dnnConvolutionCreateForward_F32(
618               &mkl_context.prim_fwd, nullptr, dnnAlgorithmConvolutionDirect,
619               mkl_context.in_dims, mkl_context.in_sizes, mkl_context.out_sizes,
620               mkl_context.filter_sizes, mkl_context.conv_stride,
621               mkl_context.input_offset, dnnBorderZeros),
622           E_SUCCESS);
623     }
624 
625     TensorShape mkl_output_tf_shape;
626     MklShape mkl_output_mkl_shape;
627     mkl_output_mkl_shape.SetMklTensor(true);
628     mkl_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd, dnnResourceDst);
629     mkl_output_mkl_shape.SetTfLayout(mkl_context.in_dims, mkl_context.out_sizes,
630                                      mkl_context.out_strides);
631     // MKL might change the dimension ordering
632     // Create mapping to recover the original TF dimension order
633     mkl_output_mkl_shape.SetTfDimOrder(mkl_context.in_dims, data_format_);
634 
635     mkl_output_tf_shape.AddDim(
636         dnnLayoutGetMemorySize_F32(
637             static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
638         sizeof(T));
639     AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
640                               mkl_output_mkl_shape);
641     // Filter output to be used in the backprop_input
642     TensorShape mkl_filter_output_tf_shape;
643     MklShape mkl_filter_output_mkl_shape;
644     mkl_filter_output_mkl_shape.SetMklTensor(true);
645     mkl_filter_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd,
646                                              dnnResourceFilter);
647 
648     size_t filter_sizes[4] = {static_cast<size_t>(filter.dim_size(0)),
649                               static_cast<size_t>(filter.dim_size(1)),
650                               static_cast<size_t>(filter.dim_size(2)),
651                               static_cast<size_t>(filter.dim_size(3))};
652     mkl_filter_output_mkl_shape.SetTfLayout(filter.dims(), filter_sizes,
653                                             mkl_context.filter_strides);
654 
655     mkl_filter_output_mkl_shape.SetTfDimOrder(mkl_context.filter_dims,
656                                               data_format_);
657     mkl_filter_output_tf_shape.AddDim(
658         dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
659             mkl_filter_output_mkl_shape.GetMklLayout())) /
660         sizeof(T));
661     AllocateOutputSetMklShape(context, 1, &mkl_context.output_filter,
662                               mkl_filter_output_tf_shape,
663                               mkl_filter_output_mkl_shape);
664 
665     mkl_context.conv_res[dnnResourceDst] =
666         static_cast<void*>(output->flat<T>().data());
667 
668     mkl_context.MklCreateInputLayouts(context);
669 
670     // Temp tensor used to allocate tmp buffers
671     Tensor mkl_tmp_input_buf_tensor, mkl_tmp_filter_buf_tensor,
672         mkl_tmp_bias_buf_tensor;
673     mkl_context.MklPrepareConvolutionInputs(context, &mkl_tmp_input_buf_tensor,
674                                             &mkl_tmp_filter_buf_tensor,
675                                             &mkl_tmp_bias_buf_tensor);
676 
677     // Execute convolution
678     CHECK_EQ(dnnExecute_F32(mkl_context.prim_fwd, mkl_context.conv_res),
679              E_SUCCESS);
680 
681     mkl_context.MklCleanup();
682   }
683 
684  private:
685   typedef struct {
686     int in_dims;
687     size_t in_sizes[4];
688     size_t in_strides[4];
689     size_t out_sizes[4];
690     size_t out_strides[4];
691     int filter_dims;
692     size_t filter_sizes[4];
693     size_t filter_strides[4];
694     size_t bias_sizes[1];
695     size_t bias_strides[1];
696     int input_offset[2];
697     size_t conv_stride[2];
698     MklShape input_shape;
699     dnnPrimitive_t prim_fwd;
700     void* conv_res[dnnResourceNumber];
701     dnnLayout_t lt_filter, lt_bias, lt_input;
702     Tensor* output_filter = nullptr;
703 
704     // Create MKL dnnLayout_t objects for tensors coming into the layer
705     void MklCreateInputLayouts(OpKernelContext* context) {
706       bool input_in_mkl_format = input_shape.IsMklTensor();
707       if (input_in_mkl_format) {
708         lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
709       } else {
710         CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
711                  E_SUCCESS);
712       }
713 
714       CHECK_EQ(dnnLayoutCreate_F32(&lt_filter, filter_dims, filter_sizes,
715                                    filter_strides),
716                E_SUCCESS);
717 
718       if (bias_enabled) {
719         CHECK_EQ(dnnLayoutCreate_F32(&lt_bias, 1, bias_sizes, bias_strides),
720                  E_SUCCESS);
721       }
722     }
723 
724     // Compare incoming tensor layouts with MKL preferred layouts and convert
725     // data to the preferred layout if necessary
726     void MklPrepareConvolutionInputs(OpKernelContext* context,
727                                      Tensor* mkl_tmp_input_buf_tensor,
728                                      Tensor* mkl_tmp_filter_buf_tensor,
729                                      Tensor* mkl_tmp_bias_buf_tensor) {
730       bool mkl_convert_input, mkl_convert_filter, mkl_convert_bias;
731       dnnPrimitive_t mkl_prim_convert_filter, mkl_prim_convert_bias,
732           mkl_prim_convert_input;
733       dnnLayout_t mkl_lt_internal_filter, mkl_lt_internal_bias,
734           mkl_lt_internal_input;
735       void *mkl_buf_convert_input, *mkl_buf_convert_filter,
736           *mkl_buf_convert_bias;
737       mkl_prim_convert_filter = nullptr;
738       mkl_prim_convert_bias = nullptr;
739       mkl_prim_convert_input = nullptr;
740       mkl_lt_internal_filter = nullptr;
741       mkl_lt_internal_bias = nullptr;
742       mkl_lt_internal_input = nullptr;
743       mkl_buf_convert_input = nullptr;
744       mkl_buf_convert_filter = nullptr;
745       mkl_buf_convert_bias = nullptr;
746 
747       // Compare with internal layouts and convert if needed
748       const Tensor& input = MklGetInput(context, 0);
749       void* mkl_buf_input =
750           const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
751       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
752                                                 prim_fwd, dnnResourceSrc),
753                E_SUCCESS);
754       mkl_convert_input =
755           !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input);
756       if (mkl_convert_input) {
757         CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input,
758                                          mkl_lt_internal_input),
759                  E_SUCCESS);
760         AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
761                        &mkl_buf_convert_input);
762         CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
763                                           mkl_buf_convert_input),
764                  E_SUCCESS);
765         dnnDelete_F32(mkl_prim_convert_input);
766       }
767       dnnLayoutDelete_F32(mkl_lt_internal_input);
768 
769       conv_res[dnnResourceSrc] =
770           (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;
771 
772       const Tensor& filter = MklGetInput(context, 1);
773       void* mkl_buf_filter =
774           const_cast<void*>(static_cast<const void*>(filter.flat<T>().data()));
775       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_filter,
776                                                 prim_fwd, dnnResourceFilter),
777                E_SUCCESS);
778       mkl_convert_filter =
779           !dnnLayoutCompare_F32(mkl_lt_internal_filter, lt_filter);
780       if (mkl_convert_filter) {
781         CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_filter, lt_filter,
782                                          mkl_lt_internal_filter),
783                  E_SUCCESS);
784 
785         mkl_buf_convert_filter = const_cast<void*>(
786             static_cast<const void*>(output_filter->flat<T>().data()));
787 
788         CHECK_EQ(
789             dnnConversionExecute_F32(mkl_prim_convert_filter, mkl_buf_filter,
790                                      mkl_buf_convert_filter),
791             E_SUCCESS);
792         dnnDelete_F32(mkl_prim_convert_filter);
793       }
794       dnnLayoutDelete_F32(mkl_lt_internal_filter);
795 
796       conv_res[dnnResourceFilter] =
797           (mkl_convert_filter) ? mkl_buf_convert_filter : mkl_buf_filter;
798 
799       if (bias_enabled) {
800         const Tensor& bias = MklGetInput(context, 2);
801         void* mkl_buf_bias =
802             const_cast<void*>(static_cast<const void*>(bias.flat<T>().data()));
803         CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_bias,
804                                                   prim_fwd, dnnResourceBias),
805                  E_SUCCESS);
806         mkl_convert_bias = !dnnLayoutCompare_F32(mkl_lt_internal_bias, lt_bias);
807         if (mkl_convert_bias) {
808           CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_bias, lt_bias,
809                                            mkl_lt_internal_bias),
810                    E_SUCCESS);
811           AllocTmpBuffer(context, mkl_tmp_bias_buf_tensor, mkl_lt_internal_bias,
812                          &mkl_buf_convert_bias);
813           CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_bias, mkl_buf_bias,
814                                             mkl_buf_convert_bias),
815                    E_SUCCESS);
816           dnnDelete_F32(mkl_prim_convert_bias);
817         }
818         dnnLayoutDelete_F32(mkl_lt_internal_bias);
819 
820         conv_res[dnnResourceBias] =
821             (mkl_convert_bias) ? mkl_buf_convert_bias : mkl_buf_bias;
822       }
823     }
824 
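    // Release the MKL primitive and layouts created for this invocation.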
825     void MklCleanup() {
826       bool input_in_mkl_format = input_shape.IsMklTensor();
827       dnnDelete_F32(prim_fwd);
828       if (!input_in_mkl_format) dnnLayoutDelete_F32(lt_input);
829       dnnLayoutDelete_F32(lt_filter);
830       if (bias_enabled) dnnLayoutDelete_F32(lt_bias);
831     }
832   } MklConv2DOpContext;
833 
834   std::vector<int32> strides_;
835   Padding padding_;
836   TensorFormat data_format_;
837 };
838 
839 // FP32 kernel registration for INTEL_MKL_ML
840 REGISTER_KERNEL_BUILDER(Name("_MklConv2D")
841                             .Device(DEVICE_CPU)
842                             .TypeConstraint<float>("T")
843                             .Label(mkl_op_registry::kMklOpLabel),
844                         MklConv2DOp<CPUDevice, float, false>);
845 REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias")
846                             .Device(DEVICE_CPU)
847                             .TypeConstraint<float>("T")
848                             .Label(mkl_op_registry::kMklOpLabel),
849                         MklConv2DOp<CPUDevice, float, true>);
850 
851 #else
852 
853 // Base class for convolution forward operations
854 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
855           typename Toutput, typename Ttemp_output, typename Tpadding,
856           bool bias_enabled, bool pad_enabled, bool is_depthwise>
857 class MklConvOp : public OpKernel {
858  public:
859   ~MklConvOp() {}
860 
861   explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
862     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
863     if (context->HasAttr("padding_list")) {
864       OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_));
865     }
866     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
867     string data_format;
868     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
869     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
870                 errors::InvalidArgument("Invalid data format"));
871     OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5),
872                 errors::InvalidArgument("Sliding window strides field must "
873                                         "specify 4 or 5 dimensions"));
874 
875     const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
876     const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
877     OP_REQUIRES(
878         context, stride_n == 1 && stride_c == 1,
879         errors::InvalidArgument("Current implementation does not yet support "
880                                 "strides in the batch and depth dimensions."));
881     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
882     is_filter_const_ = false;
883     OP_REQUIRES_OK(context,
884                    context->GetAttr("is_filter_const", &is_filter_const_));
885 
886     if (strides_.size() == 4) {
887       OP_REQUIRES(context, dilations_.size() == 4,
888                   errors::InvalidArgument("Sliding window dilations field must "
889                                           "specify 4 dimensions"));
890       const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
891       const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
892       const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
893       const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
894       OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
895                   errors::InvalidArgument(
896                       "Current implementation does not yet support "
897                       "dilations in the batch and depth dimensions."));
898       OP_REQUIRES(
899           context, dilation_h > 0 && dilation_w > 0,
900           errors::InvalidArgument("Dilated rates should be larger than 0."));
901     } else if (strides_.size() == 5) {
902       OP_REQUIRES(context, dilations_.size() == 5,
903                   errors::InvalidArgument("Dilation rates field must "
904                                           "specify 5 dimensions"));
905       OP_REQUIRES(context,
906                   (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
907                    GetTensorDim(dilations_, data_format_, 'C') == 1),
908                   errors::InvalidArgument(
909                       "Current implementation does not yet support "
910                       "dilations rates in the batch and depth dimensions."));
911       OP_REQUIRES(
912           context,
913           (GetTensorDim(dilations_, data_format_, '0') > 0 &&
914            GetTensorDim(dilations_, data_format_, '1') > 0 &&
915            GetTensorDim(dilations_, data_format_, '2') > 0),
916           errors::InvalidArgument("Dilated rates should be larger than 0."));
917     }
918   }
919 
920   void Compute(OpKernelContext* context) override {
921     try {
922       // Input tensors
923       const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src);
924       const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);
925 
926       // Data from persistent (cached) filter tensor
927       const Tensor& cached_filter_data_tensor =
928           *cached_filter_data_ptensor_.AccessTensor(context);
929 
930       MklDnnShape src_mkl_shape, filter_mkl_shape;
931       GetMklShape(context, kInputIndex_Src, &src_mkl_shape);
932       GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape);
933       OP_REQUIRES(context, filter_mkl_shape.IsMklTensor() == false,
934                   errors::InvalidArgument("Filter should not be in "
935                                           "Mkl Layout"));
936 
937       MklDnnData<Tinput> src(&cpu_engine_);
938       MklDnnData<Tfilter> filter(&cpu_engine_);
939 
940       memory::dims src_dims, filter_dims, padding_left, padding_right,
941           dilations, strides;
942       memory::dims dst_dims_tf_order, dst_dims_mkl_order;
943 
944       // For Quantized-Conv2D and Pad fusion, we get padding from the
945       // `padding_list` attribute. Otherwise, we get it from one of the inputs.
946       bool quantized_pad_enabled = false;
947       for (auto const& padding_val : padding_list_) {
948         if (padding_val) {
949           quantized_pad_enabled = true;
950           break;
951         }
952       }
953 
954       if (fuse_pad_ || quantized_pad_enabled) {
955         PadWithConvFusion(context, padding_left, padding_right,
956                           quantized_pad_enabled);
957       }
958 
959       // Get shapes of input tensors in MKL-DNN order
960       MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
961                               dilations_);
962       auto src_tf_shape = GetTfShape(context, kInputIndex_Src);
963       auto filter_tf_shape = GetTfShape(context, kInputIndex_Filter);
964       conv_utl.GetConvFwdSizesInMklOrder(
965           src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
966           &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,
967           &padding_right, (fuse_pad_ || quantized_pad_enabled), is_depthwise);
968 
969       if (!context->status().ok()) return;
970 
971       // Check for corner case - if there is nothing to compute, return.
972       TensorShape dst_tf_shape = MklDnnDimsToTFShape(dst_dims_tf_order);
973 
974       // Corner cases: output with 0 elements and 0 batch size.
975       Tensor* dst_tensor = nullptr;
976       if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) {
977         MklDnnShape dst_mkl_shape;
978         dst_mkl_shape.SetMklTensor(false);
979         AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor,
980                                   src_tf_shape, dst_mkl_shape);
981 
982         // MklConv2D/3D also outputs the converted filter as its 2nd output.
983         filter_mkl_shape.SetMklTensor(false);
984         Tensor* output_filter_tensor = nullptr;
985         if (typeid(Tinput) == typeid(float) &&
986             typeid(Tfilter) == typeid(float) &&
987             typeid(Toutput) == typeid(float)) {
988           filter_mkl_shape.SetMklTensor(false);
989           AllocateOutputSetMklShape(context, kOutputIndex_Filter,
990                                     &output_filter_tensor, filter_tf_shape,
991                                     filter_mkl_shape);
992         }
993         return;
994       }
995 
996       bool is_conv2d = (strides_.size() == 4);
997 
998       if (!is_conv2d) {
999         OP_REQUIRES(
1000             context, !pad_enabled,
1001             errors::InvalidArgument("Pad + Conv fusion only works for 2D"));
1002       }
1003 
1004       // TODO: 3-D support for depthwise convolution is not implemented yet.
1005       if (is_depthwise) {
1006         OP_REQUIRES(context, is_conv2d,
1007                     errors::InvalidArgument(
1008                         "Only 2D convolution is supported for depthwise."));
1009       }
1010 
1011       // TODO(Intel-tf) Add check to make sure pad_enabled is true only for 2D
1012       if (!is_conv2d) {
1013         OP_REQUIRES(
1014             context, !fuse_pad_,
1015             errors::InvalidArgument("Pad+Conv fusion only works for 2D"));
1016       }
1017       // Create memory for user data.
1018       // Describe how the inputs and outputs of Convolution look like. Also
1019       // specify buffers containing actual input and output data.
1020       auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
1021                               : TFDataFormatToMklDnn3DDataFormat(data_format_);
1022 
1023       // If input is in MKL layout, then simply grab the layout; otherwise,
1024       // construct TF layout for input.
1025       // For constructing TF layout for input, although input shape (src_dims)
1026       // is required to be in MKL-DNN order, the input layout is actually in
1027       // TF layout depending on the data format:
1028       //     Conv2D: NHWC or NCHW
1029       //     Conv3D: NDHWC or NCDHW
1030       auto src_md = src_mkl_shape.IsMklTensor()
1031                         ? src_mkl_shape.GetMklLayout()
1032                         : memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
1033       src.SetUsrMem(src_md, &src_tensor);
1034 
1035       // Although filter shape (filter_dims) required is in MKL-DNN order,
1036       // the layout is Tensorflow's layout (HWIO) and (HWIGO) for
1037       // depthwise/group convolutions.
1038 
1039       auto filter_format = is_conv2d ? (is_depthwise ? memory::format::hwigo
1040                                                      : memory::format::hwio)
1041                                      : memory::format::dhwio;
1042 
1043       DCHECK(!filter_mkl_shape.IsMklTensor());
1044       auto filter_md =
1045           filter_mkl_shape.IsMklTensor()
1046               ? filter_mkl_shape.GetMklLayout()
1047               : memory::desc(filter_dims, MklDnnType<Tfilter>(), filter_format);
1048       filter.SetUsrMem(filter_md, &filter_tensor);
1049 
1050       // MKLDNN dilations start from 0.
1051       for (int i = 0; i < dilations.size(); ++i) --dilations[i];
1052 
1053       // In some cases, the primitive descriptor could potentially contain
1054       // large buffers. As a result, we don't cache these primitives if the
1055       // environment variable `TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE` is set to true.
1056       // MKL-DNN allocates buffers in the following cases:
1057       //   1. Legacy CPU without AVX512/AVX2, or
1058       //   2. 1x1 convolution with strides != 1
1059       bool do_not_cache =
1060           MklPrimitiveFactory<Tinput>::IsPrimitiveMemOptEnabled() &&
1061           (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
1062           (MklPrimitiveFactory<Tinput>::IsLegacyPlatform() ||
1063            IsConv1x1StrideNot1(filter_dims, strides));
1064 
1065       // Get a conv2d fwd from primitive pool
1066       MklConvFwdPrimitive<float, Tinput, Tfilter, Tbias, Ttemp_output>*
1067           conv_fwd = nullptr;
1068       memory::dims bias_dims = {};
1069       if (fuse_biasadd_) {
1070         conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
1071       }
1072       MklConvFwdParams convFwdDims(
1073           src_dims, filter_dims, fuse_biasadd_ ? bias_dims : NONE_DIMS,
1074           dst_dims_mkl_order, strides, dilations, padding_left, padding_right);
1075 
1076       // TODO(mdfaijul): Extend the basic parameters for data types and fusions
1077       this->ExtendConvFwdParams(context, convFwdDims);
1078 
1079       conv_fwd = MklConvFwdPrimitiveFactory<float, Tinput, Tfilter, Tbias,
1080                                             Ttemp_output>::Get(convFwdDims,
1081                                                                do_not_cache);
1082 
1083       // Allocate output tensors `output_tensor` and `filter_out_tensor`
1084       std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
1085       AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt,
1086                            &dst_tensor);
1087       Tensor* filter_out_tensor = nullptr;
1088       if (typeid(Tinput) == typeid(float) && typeid(Tfilter) == typeid(float) &&
1089           typeid(Toutput) == typeid(float)) {
1090         AllocateFilterOutputTensor(context, *conv_fwd_pd,
1091                                    TFShapeToMklDnnDims(filter_tf_shape),
1092                                    &filter_out_tensor);
1093       }
1094 
1095       Ttemp_output* dst_data =
1096           reinterpret_cast<Ttemp_output*>(dst_tensor->flat<Toutput>().data());
1097 
1098       // Check whether src and filter need to be reordered
1099       Tinput* src_data = nullptr;
1100       if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) {
1101         // Reorder src
1102         src.SetUsrMem(src_md, &src_tensor);
1103         src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc());
1104         src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
1105       } else {
1106         src_data = static_cast<Tinput*>(
1107             const_cast<Tinput*>(src_tensor.flat<Tinput>().data()));
1108       }
1109 
1110       Tfilter* filter_data = nullptr;
1111       if (filter_md.data.format != conv_fwd->GetFilterMemoryFormat()) {
1112         bool is_filter_cached = false;
1113         // If filter is a constant, we can avoid the conversion of filter from
1114         // Tensorflow format to MKL format by caching the filter when it is
1115         // converted for the first time. This cached filter can then be reused
1116         // in subsequent iterations.
1117         if (is_filter_const_) {
1118           if (IsFilterCacheEmpty(context)) {
1119             // Cache filter if it is not already cached.
1120             CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
1121                         filter, filter_md);
1122           }
1123           filter_data =
1124               GetCachedFilter(context, conv_fwd->GetFilterMemoryFormat());
1125           is_filter_cached = (filter_data != nullptr);
1126         }
1127         if (!is_filter_cached) {
1128           filter.SetUsrMem(filter_md, &filter_tensor);
1129           if (filter_out_tensor == nullptr) {
1130             filter.CheckReorderToOpMem(
1131                 conv_fwd_pd.get()->weights_primitive_desc());
1132           } else {
1133             filter.CheckReorderToOpMem(
1134                 conv_fwd_pd.get()->weights_primitive_desc(),
1135                 filter.GetTensorBuffer(filter_out_tensor));
1136           }
1137           filter_data =
1138               static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
1139         }
1140       } else {
1141         filter_data = static_cast<Tfilter*>(
1142             const_cast<Tfilter*>(filter_tensor.flat<Tfilter>().data()));
1143       }
1144 
1145       // Execute convolution
1146       if (fuse_biasadd_) {
1147         const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
1148         Tbias* bias_data =
1149             this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
1150         conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
1151       } else {
1152         conv_fwd->Execute(src_data, filter_data, dst_data);
1153       }
1154 
1155       // Delete primitive since it is not cached.
1156       if (do_not_cache) delete conv_fwd;
1157     } catch (mkldnn::error& e) {
1158       string error_msg = tensorflow::strings::StrCat(
1159           "Status: ", e.status, ", message: ", string(e.message), ", in file ",
1160           __FILE__, ":", __LINE__);
1161       OP_REQUIRES_OK(
1162           context,
1163           errors::Aborted("Operation received an exception:", error_msg));
1164     }
1165   }
1166 
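  // Extracts the top/bottom/left/right paddings for the fused Pad, either from
  // the padding_list attribute (quantized case) or from the paddings input
  // tensor, and converts them into MKL-DNN padding_left/padding_right arrays.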
1167   void PadWithConvFusion(OpKernelContext* context, memory::dims& padding_left,
1168                          memory::dims& padding_right,
1169                          bool quantized_pad_enabled) {
1170     const Tensor& paddings_tf = MklGetInput(context, input_index_pad_);
1171     Tpadding* paddings = nullptr;
1172     if (quantized_pad_enabled) {
1173       paddings = padding_list_.data();
1174     } else {
1175       OP_REQUIRES(context, paddings_tf.dims() == 2,
1176                   errors::InvalidArgument("paddings must be 2-dimensional: ",
1177                                           paddings_tf.shape().DebugString()));
1178       // Flatten tensor to get individual paddings.
1179       paddings = static_cast<Tpadding*>(
1180           const_cast<Tpadding*>(paddings_tf.flat<Tpadding>().data()));
1181     }
1182     // If the data format is NHWC, indices 0, 1, 6 and 7 of paddings(_tf)
1183     // will be zero.
1184     // Example:
1185     // paddings_tf = [ [0, 0] [1, 2] [3, 4] [0, 0] ],
1186     // flat method = row-major, then:
1187     // paddings = {0, 0, 1, 2, 3, 4, 0, 0}.
1188     // Hence, the values are: top = 1, bottom = 2, left = 3, right = 4.
1189     //
1190     // Similarly, if the data format is NCHW, indices 0, 1, 2 and 3 of
1191     // paddings(_tf) will be zero.
1192     // i.e. for the above example, paddings = {0, 0, 0, 0, 1, 2, 3, 4}.
1193     int64 pad_top, pad_left;
1194     int64 pad_bottom, pad_right;
1195     string data_format = ToString(data_format_);
1196     if (data_format == "NHWC") {
1197       pad_top = paddings[2];
1198       pad_bottom = paddings[3];
1199       pad_left = paddings[4];
1200       pad_right = paddings[5];
1201     } else if (data_format == "NCHW") {
1202       pad_top = paddings[4];
1203       pad_bottom = paddings[5];
1204       pad_left = paddings[6];
1205       pad_right = paddings[7];
1206     }
1207     // Create padding arrays for MKL-DNN convolutions.
1208     // MKL-DNN uses asymmetric padding.
1209     padding_left = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
1210     padding_right = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
1211   }
1212 
1213  protected:
1214   void set_fuse_biasadd(bool fuse_biasadd) { fuse_biasadd_ = fuse_biasadd; }
1215   void set_fuse_relu(bool fuse_relu) { fuse_relu_ = fuse_relu; }
1216   void set_fuse_pad(bool fuse_pad) {
1217     fuse_pad_ = fuse_pad;
1218     // In the _MklPadWithFusedConv2D op, the paddings tensor is the fourth input.
1219     input_index_pad_ = 3;
1220   }
1221 
1222   // This method is for the base class MklConvOp, which handles the
1223   // floating-point implementation of Conv. The quantized conv implementations
1224   // use overridden versions of this method.
1225   virtual void ExtendConvFwdParams(OpKernelContext* context,
1226                                    MklConvFwdParams& params) {
1227     // Create a string from data types of input, filter, bias, and output.
1228     params.dtypes.append(typeid(Tinput).name());
1229     params.dtypes.append(typeid(Tfilter).name());
1230     params.dtypes.append(typeid(Tbias).name());
1231     params.dtypes.append(typeid(Toutput).name());
1232 
1233     // Add fusions as post ops
1234     // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
1235     // checking `fuse_biasadd_` flag.
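    // The three values below are expected to be consumed downstream as the
    // (scale, alpha, beta) arguments of the MKL-DNN eltwise ReLU post-op;
    // alpha and beta are unused for a plain ReLU, so they stay at 0.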
1236     if (fuse_relu_) params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
1237   }
1238 
1239   virtual Tbias* GetBiasHandle(OpKernelContext* context,
1240                                std::shared_ptr<ConvFwdPd>& conv2d_fwd_pd,
1241                                const Tensor& bias_tensor) {
1242     if (fuse_biasadd_) {
1243       return static_cast<Tbias*>(
1244           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
1245     }
1246     return nullptr;
1247   }
1248 
1249   virtual void AllocateOutputTensor(OpKernelContext* context,
1250                                     const ConvFwdPd& conv_prim_desc,
1251                                     const memory::dims& output_dims_mkl_order,
1252                                     memory::format output_tf_format,
1253                                     Tensor** output_tensor) {
1254     CHECK_NOTNULL(output_tensor);
1255     auto dst_pd = conv_prim_desc.dst_primitive_desc();
1256 
1257     auto dst_md = dst_pd.desc();
1258     if (!std::is_same<Ttemp_output, Toutput>::value) {
1259       dst_md.data.data_type =
1260           static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
1261       dst_pd = memory::primitive_desc(dst_md, cpu_engine_);
1262     }
1263     // Allocate shape of Mkl tensor.
1264     MklDnnShape output_mkl_shape;
1265     output_mkl_shape.SetMklTensor(true);
1266     output_mkl_shape.SetMklLayout(&dst_pd);
1267     output_mkl_shape.SetElemType(MklDnnType<Toutput>());
1268     output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
1269                                  output_dims_mkl_order, output_tf_format);
1270 
1271     // Allocate shape of TF tensor.
1272     TensorShape output_tf_shape;
1273     output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));
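    // Note: the TF-side tensor is just a flat 1-D buffer whose element count
    // is derived from the MKL primitive descriptor; the actual (possibly
    // blocked) layout is carried separately in output_mkl_shape.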
1274 
1275     AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
1276                               output_tf_shape, output_mkl_shape);
1277   }
1278 
1279   engine cpu_engine_ = engine(engine::cpu, 0);
1280 
1281  private:
1282   std::vector<int32> strides_;
1283   std::vector<int32> dilations_;
1284   std::vector<Tpadding> padding_list_;
1285   bool is_filter_const_;
1286   mutex mu_;
1287   Padding padding_;
1288   TensorFormat data_format_;
1289   PersistentTensor cached_filter_data_ptensor_ GUARDED_BY(mu_);
1290   PersistentTensor cached_filter_md_ptensor_ GUARDED_BY(mu_);
1291 
1292   // Initialize to values the template is instantiated with
1293   bool fuse_biasadd_ = bias_enabled;
1294   bool fuse_relu_ = false;
1295   bool fuse_pad_ = pad_enabled;
1296 
1297   int input_index_pad_ = 2;
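  // The paddings tensor is input 2 by default (src, filter, paddings, as in
  // _MklPadWithConv2D) and is overridden to 3 by set_fuse_pad() for the
  // fused pad + conv variants.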
1298 
1299   const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2;
1300   const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
1301   const int kDilationH = 0, kDilationW = 1;
1302 
1303   // Allocate persistent tensors for cached filter data and
1304   // cached filter memory descriptor (data format)
1305   void AllocatePersistentTensor(OpKernelContext* context,
1306                                 const ConvFwdPd& conv_prim_desc,
1307                                 Tensor** filter_tensor) {
1308     DCHECK(filter_tensor);
1309     TensorShape filter_tf_shape;
1310     filter_tf_shape.AddDim(
1311         (conv_prim_desc.weights_primitive_desc().get_size() / sizeof(Tfilter)));
1312     OP_REQUIRES_OK(context, context->allocate_persistent(
1313                                 DataTypeToEnum<Tfilter>::value, filter_tf_shape,
1314                                 &cached_filter_data_ptensor_, filter_tensor));
1315 
1316     Tensor* second_tensor = nullptr;
1317     TensorShape filter_mkl_format;
1318     filter_mkl_format.AddDim(
1319         sizeof(conv_prim_desc.weights_primitive_desc().desc().data.format) /
1320         sizeof(DT_INT32));
1321     OP_REQUIRES_OK(context, context->allocate_persistent(
1322                                 DT_INT32, filter_mkl_format,
1323                                 &cached_filter_md_ptensor_, &second_tensor));
1324     second_tensor->scalar<int32>()() =
1325         conv_prim_desc.weights_primitive_desc().desc().data.format;
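    // The filter's MKL memory format is cached as a plain int32 scalar so
    // that GetCachedFilter() can later compare it against the format expected
    // by the current primitive before reusing the cached weights.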
1326   }
1327 
1328   void AllocateFilterOutputTensor(OpKernelContext* context,
1329                                   const ConvFwdPd& conv_prim_desc,
1330                                   const memory::dims& filter_dims_tf_order,
1331                                   Tensor** filter_tensor) {
1332     CHECK_NOTNULL(filter_tensor);
1333     auto filter_pd = conv_prim_desc.weights_primitive_desc();
1334 
1335     // Allocate shape of Mkl tensor.
1336     MklDnnShape filter_mkl_shape;
1337     filter_mkl_shape.SetMklTensor(true);
1338     filter_mkl_shape.SetMklLayout(&filter_pd);
1339     filter_mkl_shape.SetElemType(MklDnnType<Tfilter>());
1340 
1341     // The format of the filter is actually OIhw8i8o, but TF doesn't support
1342     // this format. Just use format::blocked for now because the layout
1343     // is stored in the MKL data.
1344     filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
1345                                  filter_dims_tf_order, memory::format::blocked);
1346 
1347     // Allocate the data space for the filter to propagate as TF tensor.
1348     TensorShape filter_tf_shape;
1349     filter_tf_shape.AddDim((filter_pd.get_size() / sizeof(Tfilter)));
1350 
1351     AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
1352                               filter_tf_shape, filter_mkl_shape);
1353   }
1354 
1355   // Prepare and execute net - checks for input and output reorders.
1356   void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc,
1357                             MklDnnData<Tinput>* src,
1358                             MklDnnData<Tfilter>* filter,
1359                             MklDnnData<Tbias>* bias,
1360                             MklDnnData<Toutput>* output,
1361                             Tensor* filter_out_tensor) {
1362     CHECK_NOTNULL(filter_out_tensor);
1363 
1364     // Create reorders between the user layout and the MKL layout if needed,
1365     // and add them to the net before the convolution. No need to check for an
1366     // output reorder as we propagate the output layout to the next layer.
1367     src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc());
1368 
1369     // Rather than reordering to a temporary buffer, reorder directly into
1370     // the filter output tensor.
1371     filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(),
1372                                 filter->GetTensorBuffer(filter_out_tensor));
1373 
1374     // Create convolution primitive and add it to net.
1375     std::vector<primitive> net;
1376     if (bias) {
1377       DCHECK(fuse_biasadd_);
1378       net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
1379                                         filter->GetOpMem(), bias->GetOpMem(),
1380                                         output->GetOpMem()));
1381     } else {
1382       DCHECK(!fuse_biasadd_);
1383       net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
1384                                         filter->GetOpMem(),
1385                                         output->GetOpMem()));
1386     }
1387 
1388     stream(stream::kind::eager).submit(net).wait();
1389   }
1390 
1391   // The LOCKS_EXCLUDED annotation ensures that the caller does not already
1392   // hold the lock (mu_) when entering this function, since the lock is
1393   // acquired inside the function.
1394   inline bool IsFilterCacheEmpty(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
1395     tf_shared_lock lock(mu_);
1396     const Tensor& cached_filter_data_tensor =
1397         *cached_filter_data_ptensor_.AccessTensor(context);
1398     return (cached_filter_data_tensor.NumElements() == 0);
1399   }
1400 
1401   // Cache the converted filter in a persistent tensor.
1402   // Only one thread can execute this method at any given time.
1403   void CacheFilter(OpKernelContext* context,
1404                    const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
1405                    Tfilter* filter_data, const Tensor& filter_tensor,
1406                    MklDnnData<Tfilter>& filter, const memory::desc& filter_md)
1407       LOCKS_EXCLUDED(mu_) {
1408     mutex_lock lock(mu_);
1409     const Tensor& cached_filter_data_tensor =
1410         *cached_filter_data_ptensor_.AccessTensor(context);
1411 
1412     // If filter is already cached, there's nothing to do.
1413     if (cached_filter_data_tensor.NumElements() > 0) {
1414       return;
1415     }
1416 
1417     // Otherwise, cache filter
1418     filter.SetUsrMem(filter_md, &filter_tensor);
1419     filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc());
1420     filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
1421 
1422     Tensor* filter_tensor_ptr = nullptr;
1423     AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr);
1424     void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
1425     size_t cached_filter_data_size =
1426         filter.GetOpMem().get_primitive_desc().get_size();
1427     memcpy(cached_filter_data, filter_data, cached_filter_data_size);
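    // The reordered filter bytes are copied into the persistent tensor so
    // that later invocations can skip the reorder entirely (see
    // IsFilterCacheEmpty() and GetCachedFilter()).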
1428   }
1429 
1430   Tfilter* GetCachedFilter(OpKernelContext* context,
1431                            const memory::format& filter_mf)
1432       LOCKS_EXCLUDED(mu_) {
1433     tf_shared_lock lock(mu_);
1434     const Tensor& cached_filter_data =
1435         *cached_filter_data_ptensor_.AccessTensor(context);
1436     const Tensor& cached_filter_md =
1437         *cached_filter_md_ptensor_.AccessTensor(context);
1438 
1439     // Check if the memory descriptor of the cached weights is the same as
1440     // filter_mf. If so, we can use the cached weights; otherwise
1441     // return nullptr.
1442     // TODO (bhavanis): Do we need to cast filter_mf before the check?
1443     if (cached_filter_md.scalar<int32>().size() &&
1444         cached_filter_md.scalar<int32>()() == filter_mf) {
1445       return static_cast<Tfilter*>(
1446           const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
1447     }
1448     return nullptr;
1449   }
1450 };
1451 
1452 // Base class for fused convolution forward operations
1453 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
1454           typename Toutput, typename Ttemp_output, typename Tpadding,
1455           bool pad_enabled>
1456 class MklFusedConvOp
1457     : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
1458                        Tpadding, false, false, false> {
1459  public:
1460   explicit MklFusedConvOp(OpKernelConstruction* context)
1461       : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
1462                   Tpadding, false, false, false>(context) {
1463     // Since we came here through the registration of _MklFusedConv2D, get
1464     // all information from 'fused_ops' and 'num_args'
1465     std::vector<string> fused_ops;
1466     OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
1467 
1468     int num_args;
1469     OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
1470     OP_REQUIRES(context, !fused_ops.empty(),
1471                 errors::InvalidArgument(
1472                     "Fused Conv2D must have at least one fused op."));
1473 
1474     if (fused_ops == std::vector<string>{"BiasAdd"}) {
1475       this->set_fuse_biasadd(true);
1476       OP_REQUIRES(context, num_args == 1,
1477                   errors::InvalidArgument(
1478                       "Fused Conv2D must have one extra argument: bias."));
1479     } else if (fused_ops == std::vector<string>{"Relu"}) {
1480       this->set_fuse_relu(true);
1481     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
1482       this->set_fuse_biasadd(true);
1483       this->set_fuse_relu(true);
1484       OP_REQUIRES(context, num_args == 1,
1485                   errors::InvalidArgument(
1486                       "Fused Conv2D must have one extra argument: bias."));
1487     } else {
1488       OP_REQUIRES(context, false,
1489                   errors::Unimplemented("Fusion is not implemented: [",
1490                                         str_util::Join(fused_ops, ","), "]"));
1491     }
1492 
1493     if (pad_enabled) {
1494       this->set_fuse_pad(true);
1495     }
1496   }
1497 
1498   virtual ~MklFusedConvOp() {}
1499 };
1500 
1501 // We create a new class for each version of quantized convolution and
1502 // inherit from the FP32 version of the base class.
1503 template <typename Device, typename Tbias, typename Toutput,
1504           typename Ttemp_output, bool bias_enabled>
1505 class MklQuantizedConv2DOp
1506     : public MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output,
1507                        int32, bias_enabled, false, false> {
1508  public:
1509   virtual ~MklQuantizedConv2DOp() {
1510     if (this->input_bias_ != nullptr) {
1511       delete this->input_bias_;
1512       input_bias_ = nullptr;
1513     }
1514 
1515     if (this->scaled_bias_ != nullptr) {
1516       delete this->scaled_bias_;
1517       scaled_bias_ = nullptr;
1518     }
1519   }
1520 
1521   explicit MklQuantizedConv2DOp(OpKernelConstruction* context)
1522       : MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
1523                   bias_enabled, false, false>(context) {
1524     bool is_filter_const;
1525     OP_REQUIRES_OK(context,
1526                    context->GetAttr("is_filter_const", &is_filter_const));
1527     OP_REQUIRES(context, is_filter_const,
1528                 errors::InvalidArgument("Filter must be a constant"));
1529   }
1530 
1531   void Compute(OpKernelContext* context) override {
1532     // Compute int32 output tensor
1533     MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
1534               bias_enabled, false, false>::Compute(context);
1535 
1536     // Compute additional outputs: min/max scalars.
1537     int bias_index_offset = bias_enabled ? 1 : 0;
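    // Input layout assumed here: {src, filter, [bias,] min_input, max_input,
    // min_filter, max_filter, ...}, which is why the min/max scalars are read
    // at indices (2 + bias_index_offset) through (5 + bias_index_offset).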
1539 
1540     const float min_input =
1541         context->input(2 + bias_index_offset).flat<float>()(0);
1542     const float max_input =
1543         context->input(3 + bias_index_offset).flat<float>()(0);
1544     const float min_filter =
1545         context->input(4 + bias_index_offset).flat<float>()(0);
1546     const float max_filter =
1547         context->input(5 + bias_index_offset).flat<float>()(0);
1548 
1549     float min_output_value;
1550     float max_output_value;
1551     if (std::is_same<Toutput, quint8>::value ||
1552         std::is_same<Toutput, qint8>::value) {
1553       // This is the case when convolution and requantization are fused.
1554       // min_freezed_output and max_freezed_output are the actual range
1555       // of the output.
1556       min_output_value = context->input(6 + bias_index_offset).flat<float>()(0);
1557       max_output_value = context->input(7 + bias_index_offset).flat<float>()(0);
1558     } else {
1559       MklQuantizationRangeForMultiplication<quint8, qint8, qint32>(
1560           min_input, max_input, min_filter, max_filter, &min_output_value,
1561           &max_output_value);
1562     }
1563 
1564     Tensor* output_min = nullptr;
1565     Tensor* output_max = nullptr;
1566     MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
1567     output_min_mkl_shape.SetMklTensor(false);
1568     output_max_mkl_shape.SetMklTensor(false);
1569     AllocateOutputSetMklShape(context, 1, &output_min, {},
1570                               output_min_mkl_shape);
1571     AllocateOutputSetMklShape(context, 2, &output_max, {},
1572                               output_max_mkl_shape);
1573     output_min->flat<float>()(0) = min_output_value;
1574     output_max->flat<float>()(0) = max_output_value;
1575   }
1576 
1577  protected:
1578   void ExtendConvFwdParams(OpKernelContext* context,
1579                            MklConvFwdParams& params) override {
1580     MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
1581               bias_enabled, false, false>::ExtendConvFwdParams(context, params);
1582 
1583     // When the output type is quint8 or qint8, the output data is requantized
1584     // into that 8-bit type. A post_op "output_scale" is added to do the conversion.
1585     if (std::is_same<Toutput, quint8>::value ||
1586         std::is_same<Toutput, qint8>::value) {
1587       int bias_index_offset = bias_enabled ? 1 : 0;
1589 
1590       const float min_input =
1591           context->input(2 + bias_index_offset).flat<float>()(0);
1592       const float max_input =
1593           context->input(3 + bias_index_offset).flat<float>()(0);
1594       const float min_filter =
1595           context->input(4 + bias_index_offset).flat<float>()(0);
1596       const float max_filter =
1597           context->input(5 + bias_index_offset).flat<float>()(0);
1598       const float min_freezed_output =
1599           context->input(6 + bias_index_offset).flat<float>()(0);
1600       const float max_freezed_output =
1601           context->input(7 + bias_index_offset).flat<float>()(0);
1602 
1603       float min_output_value;
1604       float max_output_value;
1605       MklQuantizationRangeForMultiplication<quint8, qint8, qint32>(
1606           min_input, max_input, min_filter, max_filter, &min_output_value,
1607           &max_output_value);
1608       float scale_int32 =
1609           std::max(std::abs(min_output_value), std::abs(max_output_value));
1610       float scale_eightbit =
1611           std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
1612       float scale = 1.0;
1613       if (std::is_same<Toutput, quint8>::value)
1614         scale = scale_int32 / scale_eightbit / static_cast<float>(1 << 23);
1615       else
1616         scale = scale_int32 / scale_eightbit / static_cast<float>(1 << 24);
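      // A plausible reading of the divisors (not asserted anywhere in this
      // file): the qint32 accumulator nominally spans 2^31 while quint8 spans
      // 2^8 and qint8 only 2^7 of positive range, so the rescaling factors
      // are 2^31 / 2^8 = 2^23 and 2^31 / 2^7 = 2^24 respectively.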
1617 
1618       std::vector<float> output_scale;
1619       output_scale.push_back(scale);
1620       params.post_op_params.push_back({"output_scale", output_scale});
1621     }
1622   }
1623 
1624   Tbias* GetBiasHandle(OpKernelContext* context,
1625                        std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
1626                        const Tensor& bias_tensor) override {
1627     int bias_index_offset = bias_enabled ? 1 : 0;
1629 
1630     const float min_input =
1631         context->input(2 + bias_index_offset).flat<float>()(0);
1632     const float max_input =
1633         context->input(3 + bias_index_offset).flat<float>()(0);
1634     const float min_filter =
1635         context->input(4 + bias_index_offset).flat<float>()(0);
1636     const float max_filter =
1637         context->input(5 + bias_index_offset).flat<float>()(0);
1638 
1639     std::vector<mkldnn::primitive> net;
1640     if (bias_enabled) {
1641       if (std::is_same<Tbias, qint32>::value) {
1642         return static_cast<Tbias*>(
1643             const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
1644       }
1645       // If bias is enabled and requantization is not fused, scale the
1646       // bias to be consistent with quantized-input and quantized-filter.
1647       float bias_scale = 255.0 * 127.0 /
1648                          (std::max(std::abs(max_input), std::abs(min_input)) *
1649                           std::max(std::abs(max_filter), std::abs(min_filter)));
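      // Hypothetical example: with max |input| = 5.0 and max |filter| = 2.0,
      // bias_scale = 255 * 127 / (5.0 * 2.0) = 3238.5, matching the combined
      // factor that quint8 input (255 levels) and qint8 filter (127 levels)
      // quantization applied to the accumulated products.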
1650       std::vector<float> scales;
1651       scales.push_back(bias_scale);
1652       mkldnn::primitive_attr bias_attr;
1653       bias_attr.set_output_scales(0, scales);
1654 
1655       void* bias_buf = static_cast<void*>(
1656           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
1657       input_bias_ = new memory(conv_fwd_pd->bias_primitive_desc(), bias_buf);
1658       scaled_bias_ = new memory(conv_fwd_pd->bias_primitive_desc());
1659       auto reorder_desc = mkldnn::reorder::primitive_desc(
1660           input_bias_->get_primitive_desc(), scaled_bias_->get_primitive_desc(),
1661           bias_attr);
1662       net.push_back(mkldnn::reorder(reorder_desc, *input_bias_, *scaled_bias_));
1663       stream(stream::kind::eager).submit(net).wait();
1664       return reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
1665     } else {
1666       return nullptr;
1667     }
1668   }
1669 
1670   memory* input_bias_ = nullptr;
1671   memory* scaled_bias_ = nullptr;
1672 };
1673 
1674 template <typename Device, typename Tbias, typename Toutput,
1675           typename Ttemp_output, bool bias_enabled>
1676 class MklQuantizedConv2DReluOp
1677     : public MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
1678                                   bias_enabled> {
1679  public:
1680   virtual ~MklQuantizedConv2DReluOp() {}
1681 
1682   explicit MklQuantizedConv2DReluOp(OpKernelConstruction* context)
1683       : MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
1684                              bias_enabled>(context) {}
1685 
1686  protected:
1687   void ExtendConvFwdParams(OpKernelContext* context,
1688                            MklConvFwdParams& params) override {
1689     MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
1690                          bias_enabled>::ExtendConvFwdParams(context, params);
1691     params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
1692   }
1693 };
1694 
1695 template <typename Device, typename Tbias, typename Toutput,
1696           typename Ttemp_output, bool bias_enabled>
1697 class MklQuantizedConv2DSumReluOp
1698     : public MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
1699                                   bias_enabled> {
1700  public:
1701   virtual ~MklQuantizedConv2DSumReluOp() {
1702     if (this->summand_ != nullptr) {
1703       delete this->summand_;
1704       summand_ = nullptr;
1705     }
1706 
1707     if (this->dst_ != nullptr) {
1708       delete this->dst_;
1709       dst_ = nullptr;
1710     }
1711   }
1712 
1713   explicit MklQuantizedConv2DSumReluOp(OpKernelConstruction* context)
1714       : MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
1715                              bias_enabled>(context) {}
1716 
1717  protected:
1718   void ExtendConvFwdParams(OpKernelContext* context,
1719                            MklConvFwdParams& params) override {
1720     MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
1721                          bias_enabled>::ExtendConvFwdParams(context, params);
1722     // Calculate the scale (beta in MKL-DNN API terms) for the sum post-op.
1723     if (std::is_same<Toutput, quint8>::value) {
1724       int summand_idx = context->num_inputs() / 2 - 1 - 2;
1725       DataType summand_type = this->input_type(summand_idx);
1726       bool summand_condition =
1727           (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
1728       CHECK((summand_condition));
1729       int bias_index_offset = bias_enabled ? 1 : 0;
1730       const float min_freezed_output =
1731           context->input(6 + bias_index_offset).flat<float>()(0);
1732       const float max_freezed_output =
1733           context->input(7 + bias_index_offset).flat<float>()(0);
1734       const float min_freezed_summand =
1735           context->input(9 + bias_index_offset).flat<float>()(0);
1736       const float max_freezed_summand =
1737           context->input(10 + bias_index_offset).flat<float>()(0);
1738 
1739       float scale_output =
1740           std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
1741       float scale_summand = std::max(std::abs(min_freezed_summand),
1742                                      std::abs(max_freezed_summand));
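      // The sum post-op adds the already-quantized summand into the conv
      // output, so its scale is the summand range over the output range. The
      // extra factor of 2 in the signed branch below approximates 255 / 127,
      // since qint8 has roughly half the levels of quint8.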
1743       if (summand_type == DT_QUINT8)
1744         params.post_op_params.push_back(
1745             {"sum", {scale_summand / scale_output}});
1746       else
1747         params.post_op_params.push_back(
1748             {"sum", {2.0f * scale_summand / scale_output}});
1749     } else {
1750       params.post_op_params.push_back({"sum", {1.0}});
1751     }
1752     params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
1753   }
1754 
1755   void AllocateOutputTensor(OpKernelContext* context,
1756                             const ConvFwdPd& conv_prim_desc,
1757                             const memory::dims& output_dims_mkl_order,
1758                             memory::format output_tf_format,
1759                             Tensor** output_tensor) override {
1760     int summand_idx = context->num_inputs() / 2 - 1;
1761     float reorder_sum_scale = 1.0;
1762     if (std::is_same<Toutput, quint8>::value) {
1763       summand_idx -= 2;
1764       DataType summand_type = this->input_type(summand_idx);
1765       bool summand_condition =
1766           (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
1767       CHECK((summand_condition));
1768       Tensor& summand = const_cast<Tensor&>(MklGetInput(context, summand_idx));
1769       MklDnnShape summand_mkl_shape;
1770       GetMklShape(context, summand_idx, &summand_mkl_shape);
1771       auto dst_md = summand_mkl_shape.GetMklLayout();
1772       if (summand_mkl_shape.IsMklTensor()) {
1773         if (summand_type == DT_QINT8) {
1774           OP_REQUIRES_OK(context, summand.BitcastFrom(summand, DT_QUINT8,
1775                                                       summand.shape()));
1776           dst_md.data.data_type =
1777               static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
1778           summand_mkl_shape.SetMklLayout(&dst_md);
1779           summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
1780         }
1781         ForwardMklTensorInToOutWithMklShape(context, summand_idx, 0,
1782                                             summand_mkl_shape);
1783         *output_tensor = const_cast<Tensor*>(&summand);
1784         return;
1785       } else {
1786         TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
1787                            "Current fusion is not successful."));
1788       }
1789     }
1790     // TODO(mdfaijul): Add cleaner code for non-mkl tensor
1791     MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
1792               bias_enabled, false,
1793               false>::AllocateOutputTensor(context, conv_prim_desc,
1794                                            output_dims_mkl_order,
1795                                            output_tf_format, output_tensor);
1796     const Tensor& summand = MklGetInput(context, summand_idx);
1797     if (summand.dtype() != DT_FLOAT)
1798       TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
1799                          "Current fusion requires summand to be float"));
1800     MklDnnShape summand_mkl_shape;
1801     GetMklShape(context, summand_idx, &summand_mkl_shape);
1802     // We need to compute the scale for the summand.
1803     int bias_index_offset = bias_enabled ? 1 : 0;
1804     const float min_input =
1805         context->input(2 + bias_index_offset).flat<float>()(0);
1806     const float max_input =
1807         context->input(3 + bias_index_offset).flat<float>()(0);
1808     const float min_filter =
1809         context->input(4 + bias_index_offset).flat<float>()(0);
1810     const float max_filter =
1811         context->input(5 + bias_index_offset).flat<float>()(0);
1812 
1813     reorder_sum_scale = 255.0 * 127.0 /
1814                         (std::max(std::abs(max_input), std::abs(min_input)) *
1815                          std::max(std::abs(max_filter), std::abs(min_filter)));
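    // The summand is still float here (checked above), so the reorder built
    // below quantizes it directly into the convolution's output buffer using
    // the same 255 * 127 / (input range * filter range) factor that is used
    // for the bias.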
1816     std::vector<float> scales;
1817     scales.push_back(reorder_sum_scale);
1818     mkldnn::primitive_attr reorder_attr;
1819     reorder_attr.set_output_scales(0, scales);
1820 
1821     auto summand_md =
1822         summand_mkl_shape.IsMklTensor()
1823             ? summand_mkl_shape.GetMklLayout()
1824             : memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
1825                            memory::format::nhwc);
1826     auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_);
1827     void* summand_buf =
1828         static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
1829     void* dst_buf =
1830         static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
1831     summand_ = new memory(summand_pd, summand_buf);
1832     dst_ = new memory(conv_prim_desc.dst_primitive_desc(), dst_buf);
1833     auto reorder_desc = mkldnn::reorder::primitive_desc(
1834         summand_pd, conv_prim_desc.dst_primitive_desc(), reorder_attr);
1835 
1836     std::vector<mkldnn::primitive> net;
1837     net.push_back(mkldnn::reorder(reorder_desc, *summand_, *dst_));
1838     stream(stream::kind::eager).submit(net).wait();
1839   }
1840 
1841   memory* summand_ = nullptr;
1842   memory* dst_ = nullptr;
1843 };
1844 
1845 // INT8 kernel registration
1846 // Register NoOp kernel for QuantizedConv2D for qint8 filter
1847 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2D")
1848                             .Device(DEVICE_CPU)
1849                             .TypeConstraint<quint8>("Tinput")
1850                             .TypeConstraint<qint8>("Tfilter")
1851                             .TypeConstraint<qint32>("out_type"),
1852                         NoOp);
1853 
1854 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRequantize")
1855                             .Device(DEVICE_CPU)
1856                             .TypeConstraint<quint8>("Tinput")
1857                             .TypeConstraint<qint8>("Tfilter")
1858                             .TypeConstraint<qint8>("out_type"),
1859                         NoOp);
1860 
1861 // Register a templatized implementation of MklQuantizedConv2D.
1862 REGISTER_KERNEL_BUILDER(
1863     Name("_MklQuantizedConv2D")
1864         .Device(DEVICE_CPU)
1865         .TypeConstraint<quint8>("Tinput")
1866         .TypeConstraint<qint8>("Tfilter")
1867         .TypeConstraint<qint32>("out_type")
1868         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1869     MklQuantizedConv2DOp<CPUDevice, float, qint32, qint32, false>);
1870 
1871 REGISTER_KERNEL_BUILDER(
1872     Name("_MklQuantizedConv2DAndRequantize")
1873         .Device(DEVICE_CPU)
1874         .TypeConstraint<quint8>("Tinput")
1875         .TypeConstraint<qint8>("Tfilter")
1876         .TypeConstraint<qint8>("out_type")
1877         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1878     MklQuantizedConv2DOp<CPUDevice, qint32, qint8, qint8, false>);
1879 
1880 // Register NoOp kernel for QuantizedConv2DWithBias to get a python interface.
1881 // This kernel will be replaced by an MKL kernel during the
1882 // graph-optimization pass.
1883 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBias")
1884                             .Device(DEVICE_CPU)
1885                             .TypeConstraint<quint8>("Tinput")
1886                             .TypeConstraint<qint8>("Tfilter")
1887                             .TypeConstraint<qint32>("out_type"),
1888                         NoOp);
1889 
1890 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRequantize")
1891                             .Device(DEVICE_CPU)
1892                             .TypeConstraint<quint8>("Tinput")
1893                             .TypeConstraint<qint8>("Tfilter")
1894                             .TypeConstraint<qint8>("out_type"),
1895                         NoOp);
1896 
1897 // Register a templatized implementation of MklQuantizedConv2DWithBias.
1898 REGISTER_KERNEL_BUILDER(
1899     Name("_MklQuantizedConv2DWithBias")
1900         .Device(DEVICE_CPU)
1901         .TypeConstraint<quint8>("Tinput")
1902         .TypeConstraint<qint8>("Tfilter")
1903         .TypeConstraint<qint32>("out_type")
1904         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1905     MklQuantizedConv2DOp<CPUDevice, float, qint32, qint32, true>);
1906 
1907 REGISTER_KERNEL_BUILDER(
1908     Name("_MklQuantizedConv2DWithBiasAndRequantize")
1909         .Device(DEVICE_CPU)
1910         .TypeConstraint<quint8>("Tinput")
1911         .TypeConstraint<qint8>("Tfilter")
1912         .TypeConstraint<qint32>("Tbias")
1913         .TypeConstraint<qint8>("out_type")
1914         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1915     MklQuantizedConv2DOp<CPUDevice, qint32, qint8, qint8, true>);
1916 REGISTER_KERNEL_BUILDER(
1917     Name("_MklQuantizedConv2DWithBiasAndRequantize")
1918         .Device(DEVICE_CPU)
1919         .TypeConstraint<quint8>("Tinput")
1920         .TypeConstraint<qint8>("Tfilter")
1921         .TypeConstraint<float>("Tbias")
1922         .TypeConstraint<qint8>("out_type")
1923         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1924     MklQuantizedConv2DOp<CPUDevice, float, qint8, qint8, true>);
1925 
1926 // Register NoOp kernel for QuantizedConv2DAndRelu to get a python interface.
1927 // This kernel will be replaced by an MKL kernel during graph-optimization pass.
1928 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRelu")
1929                             .Device(DEVICE_CPU)
1930                             .TypeConstraint<quint8>("Tinput")
1931                             .TypeConstraint<qint8>("Tfilter")
1932                             .TypeConstraint<qint32>("out_type"),
1933                         NoOp);
1934 
1935 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndReluAndRequantize")
1936                             .Device(DEVICE_CPU)
1937                             .TypeConstraint<quint8>("Tinput")
1938                             .TypeConstraint<qint8>("Tfilter")
1939                             .TypeConstraint<quint8>("out_type"),
1940                         NoOp);
1941 
1942 // Register a templatized implementation of MklQuantizedConv2DAndRelu.
1943 REGISTER_KERNEL_BUILDER(
1944     Name("_MklQuantizedConv2DAndRelu")
1945         .Device(DEVICE_CPU)
1946         .TypeConstraint<quint8>("Tinput")
1947         .TypeConstraint<qint8>("Tfilter")
1948         .TypeConstraint<qint32>("out_type")
1949         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1950     MklQuantizedConv2DReluOp<CPUDevice, float, qint32, qint32, false>);
1951 
1952 REGISTER_KERNEL_BUILDER(
1953     Name("_MklQuantizedConv2DAndReluAndRequantize")
1954         .Device(DEVICE_CPU)
1955         .TypeConstraint<quint8>("Tinput")
1956         .TypeConstraint<qint8>("Tfilter")
1957         .TypeConstraint<quint8>("out_type")
1958         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1959     MklQuantizedConv2DReluOp<CPUDevice, qint32, quint8, quint8, false>);
1960 
1961 // Register NoOp kernel for QuantizedConv2DWithBiasAndRelu to get a python
1962 // interface.
1963 // This kernel will be replaced by an MKL kernel during graph-optimization pass.
1964 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRelu")
1965                             .Device(DEVICE_CPU)
1966                             .TypeConstraint<quint8>("Tinput")
1967                             .TypeConstraint<qint8>("Tfilter")
1968                             .TypeConstraint<qint32>("out_type"),
1969                         NoOp);
1970 
1971 // Register NoOp kernel for QuantizedConv2DWithBiasAndReluAndRequantize
1972 // to get a python interface.
1973 // This kernel will be replaced by an MKL kernel during graph-optimization pass.
1974 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndReluAndRequantize")
1975                             .Device(DEVICE_CPU)
1976                             .TypeConstraint<quint8>("Tinput")
1977                             .TypeConstraint<qint8>("Tfilter")
1978                             .TypeConstraint<quint8>("out_type"),
1979                         NoOp);
1980 
1981 // Register a templatized implementation of MklQuantizedConv2DWithBiasAndRelu.
1982 REGISTER_KERNEL_BUILDER(
1983     Name("_MklQuantizedConv2DWithBiasAndRelu")
1984         .Device(DEVICE_CPU)
1985         .TypeConstraint<quint8>("Tinput")
1986         .TypeConstraint<qint8>("Tfilter")
1987         .TypeConstraint<qint32>("out_type")
1988         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1989     MklQuantizedConv2DReluOp<CPUDevice, float, qint32, qint32, true>);
1990 
1991 // Register a templatized implementation of
1992 // MklQuantizedConv2DWithBiasAndReluAndRequantize.
1993 REGISTER_KERNEL_BUILDER(
1994     Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
1995         .Device(DEVICE_CPU)
1996         .TypeConstraint<quint8>("Tinput")
1997         .TypeConstraint<qint8>("Tfilter")
1998         .TypeConstraint<float>("Tbias")
1999         .TypeConstraint<quint8>("out_type")
2000         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2001     MklQuantizedConv2DReluOp<CPUDevice, float, quint8, quint8, true>);
2002 REGISTER_KERNEL_BUILDER(
2003     Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
2004         .Device(DEVICE_CPU)
2005         .TypeConstraint<quint8>("Tinput")
2006         .TypeConstraint<qint8>("Tfilter")
2007         .TypeConstraint<qint32>("Tbias")
2008         .TypeConstraint<quint8>("out_type")
2009         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2010     MklQuantizedConv2DReluOp<CPUDevice, qint32, quint8, quint8, true>);
2011 
2012 // Register NoOp kernel for QuantizedConv2DWithBiasSumAndRelu to get a python
2013 // interface.
2014 // This kernel will be replaced by an MKL kernel during graph-optimization pass.
2015 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasSumAndRelu")
2016                             .Device(DEVICE_CPU)
2017                             .TypeConstraint<quint8>("Tinput")
2018                             .TypeConstraint<qint8>("Tfilter")
2019                             .TypeConstraint<qint32>("out_type"),
2020                         NoOp);
2021 
2022 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasSumAndReluAndRequantize")
2023                             .Device(DEVICE_CPU)
2024                             .TypeConstraint<quint8>("Tinput")
2025                             .TypeConstraint<qint8>("Tfilter")
2026                             .TypeConstraint<quint8>("out_type"),
2027                         NoOp);
2028 REGISTER_KERNEL_BUILDER(
2029     Name("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
2030         .Device(DEVICE_CPU)
2031         .TypeConstraint<quint8>("Tinput")
2032         .TypeConstraint<qint8>("Tfilter")
2033         .TypeConstraint<quint8>("out_type"),
2034     NoOp);
2035 // Register a templatized implementation of MklQuantizedConv2DWithBiasSumAndRelu.
2036 REGISTER_KERNEL_BUILDER(
2037     Name("_MklQuantizedConv2DWithBiasSumAndRelu")
2038         .Device(DEVICE_CPU)
2039         .TypeConstraint<quint8>("Tinput")
2040         .TypeConstraint<qint8>("Tfilter")
2041         .TypeConstraint<qint32>("out_type")
2042         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2043     MklQuantizedConv2DSumReluOp<CPUDevice, float, qint32, qint32, true>);
2044 
2045 REGISTER_KERNEL_BUILDER(
2046     Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
2047         .Device(DEVICE_CPU)
2048         .TypeConstraint<quint8>("Tinput")
2049         .TypeConstraint<qint8>("Tfilter")
2050         .TypeConstraint<qint32>("Tbias")
2051         .TypeConstraint<quint8>("out_type")
2052         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2053     MklQuantizedConv2DSumReluOp<CPUDevice, qint32, quint8, quint8, true>);
2054 
2055 REGISTER_KERNEL_BUILDER(
2056     Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
2057         .Device(DEVICE_CPU)
2058         .TypeConstraint<quint8>("Tinput")
2059         .TypeConstraint<qint8>("Tfilter")
2060         .TypeConstraint<qint32>("Tbias")
2061         .TypeConstraint<quint8>("out_type")
2062         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2063     MklQuantizedConv2DSumReluOp<CPUDevice, qint32, quint8, qint8, true>);
2064 
2065 REGISTER_KERNEL_BUILDER(
2066     Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
2067         .Device(DEVICE_CPU)
2068         .TypeConstraint<quint8>("Tinput")
2069         .TypeConstraint<qint8>("Tfilter")
2070         .TypeConstraint<float>("Tbias")
2071         .TypeConstraint<quint8>("out_type")
2072         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2073     MklQuantizedConv2DSumReluOp<CPUDevice, float, quint8, quint8, true>);
2074 
2075 REGISTER_KERNEL_BUILDER(
2076     Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
2077         .Device(DEVICE_CPU)
2078         .TypeConstraint<quint8>("Tinput")
2079         .TypeConstraint<qint8>("Tfilter")
2080         .TypeConstraint<float>("Tbias")
2081         .TypeConstraint<quint8>("out_type")
2082         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2083     MklQuantizedConv2DSumReluOp<CPUDevice, float, quint8, qint8, true>);
2084 #endif  // !INTEL_MKL_ML_ONLY
2085 
2086 // Register 2D operations
2087 #define REGISTER_MKL_CPU_2D(T)                                             \
2088   REGISTER_KERNEL_BUILDER(Name("_MklConv2D")                               \
2089                               .Device(DEVICE_CPU)                          \
2090                               .TypeConstraint<T>("T")                      \
2091                               .Label(mkl_op_registry::kMklOpLabel),        \
2092                           MklConvOp<CPUDevice, float, float, float, float, \
2093                                     float, int32, false, false, false>);   \
2094   REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias")                       \
2095                               .Device(DEVICE_CPU)                          \
2096                               .TypeConstraint<T>("T")                      \
2097                               .Label(mkl_op_registry::kMklOpLabel),        \
2098                           MklConvOp<CPUDevice, float, float, float, float, \
2099                                     float, int32, true, false, false>);    \
2100   REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DWithBias")                 \
2101                               .Device(DEVICE_CPU)                          \
2102                               .TypeConstraint<T>("T")                      \
2103                               .Label(mkl_op_registry::kMklOpLabel),        \
2104                           MklDummyOp<CPUDevice, T>);                       \
2105   REGISTER_KERNEL_BUILDER(Name("_MklPadWithConv2D")                        \
2106                               .Device(DEVICE_CPU)                          \
2107                               .TypeConstraint<T>("T")                      \
2108                               .TypeConstraint<int32>("Tpaddings")          \
2109                               .Label(mkl_op_registry::kMklOpLabel),        \
2110                           MklConvOp<CPUDevice, float, float, float, float, \
2111                                     float, int32, false, true, false>);    \
2112   REGISTER_KERNEL_BUILDER(Name("_MklPadWithConv2D")                        \
2113                               .Device(DEVICE_CPU)                          \
2114                               .TypeConstraint<T>("T")                      \
2115                               .TypeConstraint<int64>("Tpaddings")          \
2116                               .Label(mkl_op_registry::kMklOpLabel),        \
2117                           MklConvOp<CPUDevice, float, float, float, float, \
2118                                     float, int64, false, true, false>);    \
2119   REGISTER_KERNEL_BUILDER(Name("__MklDummyPadWithConv2D")                  \
2120                               .Device(DEVICE_CPU)                          \
2121                               .TypeConstraint<T>("T")                      \
2122                               .TypeConstraint<int32>("Tpaddings")          \
2123                               .Label(mkl_op_registry::kMklOpLabel),        \
2124                           MklDummyOp<CPUDevice, T>);
2125 
2126 TF_CALL_float(REGISTER_MKL_CPU_2D);
2127 
2128 #define REGISTER_MKL_CPU_2D_DEPTHWISE(T)                                   \
2129   REGISTER_KERNEL_BUILDER(Name("_MklDepthwiseConv2dNative")                \
2130                               .Device(DEVICE_CPU)                          \
2131                               .TypeConstraint<float>("T")                  \
2132                               .Label(mkl_op_registry::kMklOpLabel),        \
2133                           MklConvOp<CPUDevice, float, float, float, float, \
2134                                     float, int32, false, false, true>);
2135 
2136 TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE);
2137 
2138 // Note we are registering _MklFusedConv2D.
2139 // We check the fused_ops attributes to decide if bias is enabled or not.
2140 #define REGISTER_MKL_CPU_2D_FUSED(T)                                \
2141   REGISTER_KERNEL_BUILDER(                                          \
2142       Name("_MklFusedConv2D")                                       \
2143           .Device(DEVICE_CPU)                                       \
2144           .TypeConstraint<T>("T")                                   \
2145           .Label(mkl_op_registry::kMklOpLabel),                     \
2146       MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false>);      \
2147   REGISTER_KERNEL_BUILDER(                                          \
2148       Name("_MklPadWithFusedConv2D")                                \
2149           .Device(DEVICE_CPU)                                       \
2150           .TypeConstraint<int32>("Tpaddings")                       \
2151           .TypeConstraint<T>("T")                                   \
2152           .Label(mkl_op_registry::kMklOpLabel),                     \
2153       MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true>);       \
2154   REGISTER_KERNEL_BUILDER(                                          \
2155       Name("_MklPadWithFusedConv2D")                                \
2156           .Device(DEVICE_CPU)                                       \
2157           .TypeConstraint<T>("T")                                   \
2158           .TypeConstraint<int64>("Tpaddings")                       \
2159           .Label(mkl_op_registry::kMklOpLabel),                     \
2160       MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true>);       \
2161   REGISTER_KERNEL_BUILDER(Name("__MklDummyPadWithFusedConv2D")      \
2162                               .Device(DEVICE_CPU)                   \
2163                               .TypeConstraint<T>("T")               \
2164                               .TypeConstraint<int32>("Tpaddings")   \
2165                               .Label(mkl_op_registry::kMklOpLabel), \
2166                           MklDummyOp<CPUDevice, T>);
2167 
2168 TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED);
2169 
2170 // Register 3D operations
2171 #define REGISTER_MKL_CPU_3D(T)                  \
2172   REGISTER_KERNEL_BUILDER(                      \
2173       Name("_MklConv3D")                        \
2174           .Device(DEVICE_CPU)                   \
2175           .TypeConstraint<T>("T")               \
2176           .Label(mkl_op_registry::kMklOpLabel), \
2177       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false>);
2178 TF_CALL_float(REGISTER_MKL_CPU_3D);
2179 
2180 }  // namespace tensorflow
2181 #endif  // INTEL_MKL
2182