1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
17 #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
18 #ifdef INTEL_MKL
19 
20 #include <list>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26 
27 #if defined(INTEL_MKL_ML_ONLY) || defined(INTEL_MKL_DNN_ONLY)
28 #ifndef INTEL_MKL
29 #error "INTEL_MKL_{ML,DNN}_ONLY require INTEL_MKL"
30 #endif
31 #endif
32 
33 #if defined(INTEL_MKL_ML_ONLY) && defined(INTEL_MKL_DNN_ONLY)
34 #error "at most one of INTEL_MKL_ML_ONLY and INTEL_MKL_DNN_ONLY may be defined"
35 #endif
36 
37 #ifdef INTEL_MKL_ML_ONLY
38 #error "Please use INTEL MKL DNN (the default option for --config=mkl)."
39 #endif
40 
41 #ifdef INTEL_MKL_ML_ONLY
42 #include "mkl_dnn.h"
43 #include "mkl_dnn_types.h"
44 #include "mkl_service.h"
45 #include "mkl_trans.h"
46 #endif
47 
48 #include "tensorflow/core/framework/op_kernel.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/tensor_shape.h"
51 #include "tensorflow/core/graph/mkl_graph_util.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/lib/gtl/array_slice.h"
54 #include "tensorflow/core/platform/cpu_info.h"
55 #include "tensorflow/core/platform/logging.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/util/env_var.h"
58 #include "tensorflow/core/util/padding.h"
59 #include "tensorflow/core/util/tensor_format.h"
60 
61 #ifndef INTEL_MKL_ML_ONLY
62 #include "mkldnn.hpp"
63 #include "tensorflow/core/lib/core/stringpiece.h"
64 
65 using mkldnn::engine;
66 using mkldnn::memory;
67 using mkldnn::padding_kind;
68 using mkldnn::primitive;
69 using mkldnn::reorder;
70 #endif
71 
#ifdef _WIN32
// Provide the `uint` shorthand on Windows builds (presumably because the
// Windows toolchain does not declare it itself -- confirm).
using uint = unsigned int;
#endif
75 
76 namespace tensorflow {
77 
78 // The file contains a number of utility classes and functions used by MKL
79 // enabled kernels
80 
// The shape classes defined further below (MklShape for MKL-ML, MklDnnShape
// for MKL-DNN) encapsulate all the metadata associated with an MKL tensor. A
// tensor is an MKL tensor if it was created as the result of an MKL operation
// and did not go through a conversion to a standard TensorFlow tensor.
85 
// Dimension indices used by the MKL-ML code path, which has been deprecated.
enum MklDims {
  W = 0,  // Width
  H = 1,  // Height
  C = 2,  // Channel
  N = 3   // Batch
};
88 
// Dimension order used internally by MKL-DNN for 2D tensors.
//   Activations: [Batch, Channel, Height, Width]
//   Filters:     [Out_Channel, In_Channel, Height, Width]
// The filter indices Dim_O / Dim_I share the values of Dim_N / Dim_C.
enum MklDnnDims {
  // Activation indices.
  Dim_N = 0,
  Dim_C = 1,
  Dim_H = 2,
  Dim_W = 3,
  // Filter indices for the two leading dimensions.
  Dim_O = 0,
  Dim_I = 1
};
100 
// Dimension order used internally by MKL-DNN for 3D tensors.
//   Activations: [Batch, Channel, Depth, Height, Width]
//   Filters:     [Out_Channel, In_Channel, Depth, Height, Width]
// The filter indices Dim3d_O / Dim3d_I share the values of Dim3d_N / Dim3d_C.
enum MklDnnDims3D {
  // Activation indices.
  Dim3d_N = 0,
  Dim3d_C = 1,
  Dim3d_D = 2,
  Dim3d_H = 3,
  Dim3d_W = 4,
  // Filter indices for the two leading dimensions.
  Dim3d_O = 0,
  Dim3d_I = 1
};
113 
// Index of each dimension in a TensorFlow 2D filter, whose shape is
// [filter_height, filter_width, in_channels, out_channels].
enum TFFilterDims2d {
  TF_2DFILTER_DIM_H = 0,  // filter_height
  TF_2DFILTER_DIM_W = 1,  // filter_width
  TF_2DFILTER_DIM_I = 2,  // in_channels
  TF_2DFILTER_DIM_O = 3   // out_channels
};
122 
// Index of each dimension in a TensorFlow 3D filter, whose shape is
// [filter_depth, filter_height, filter_width, in_channels, out_channels].
enum TFFilterDims3d {
  TF_3DFILTER_DIM_P = 0,  // filter_depth
  TF_3DFILTER_DIM_H = 1,  // filter_height
  TF_3DFILTER_DIM_W = 2,  // filter_width
  TF_3DFILTER_DIM_I = 3,  // in_channels
  TF_3DFILTER_DIM_O = 4   // out_channels
};
132 
// Dimension order that MKL-DNN requires for the filter of a grouped
// convolution (2D only): group index first, then O, I, H, W.
enum MklDnnFilterGroupDims {
  MKL_GROUP_FILTER_DIM_G = 0,
  MKL_GROUP_FILTER_DIM_O = 1,
  MKL_GROUP_FILTER_DIM_I = 2,
  MKL_GROUP_FILTER_DIM_H = 3,
  MKL_GROUP_FILTER_DIM_W = 4
};
142 
// Template selector for MklOp kernel implementations that come in both an
// int8 (quantized) and an fp32 flavour.
enum class MklQuantization {
  QUANTIZED_VERSION = 0,
  FP_VERSION = 1,
};
149 
// Batch-size constant; its consumers are outside this section of the file.
static constexpr int kSmallBatchSize = 32;
151 
152 #ifdef INTEL_MKL_ML_ONLY
153 class MklShape {
154  public:
MklShape()155   MklShape() {}
156   TF_DISALLOW_COPY_AND_ASSIGN(MklShape);  // Cannot copy
157 
~MklShape()158   ~MklShape() {
159     if (sizes_) delete[] sizes_;
160     if (strides_) delete[] strides_;
161     if (mklLayout_) CHECK_EQ(dnnLayoutDelete_F32(mklLayout_), E_SUCCESS);
162     if (tfLayout_) CHECK_EQ(dnnLayoutDelete_F32(tfLayout_), E_SUCCESS);
163     if (tf_to_mkl_dim_map_) delete[] tf_to_mkl_dim_map_;
164   }
165 
IsMklTensor()166   const bool IsMklTensor() const { return isMklTensor_; }
167 
SetMklTensor(const bool isMklTensor)168   void SetMklTensor(const bool isMklTensor) { isMklTensor_ = isMklTensor; }
169 
SetDimensions(const size_t dimension)170   void SetDimensions(const size_t dimension) { dimension_ = dimension; }
171 
SetMklLayout(dnnLayout_t mklLayout)172   void SetMklLayout(dnnLayout_t mklLayout) { mklLayout_ = mklLayout; }
173 
SetMklLayout(const void * primitive,size_t resourceType)174   void SetMklLayout(const void* primitive, size_t resourceType) {
175     CHECK_EQ(
176         dnnLayoutCreateFromPrimitive_F32(&mklLayout_, (dnnPrimitive_t)primitive,
177                                          (dnnResourceType_t)resourceType),
178         E_SUCCESS);
179   }
180 
SetTfLayout(const size_t dimension,const size_t * sizes,const size_t * strides)181   void SetTfLayout(const size_t dimension, const size_t* sizes,
182                    const size_t* strides) {
183     dimension_ = dimension;
184     if (dimension > 0) {  // MKl doesn't support zero dimension tensors
185       sizes_ = new size_t[dimension];
186       strides_ = new size_t[dimension];
187 
188       for (int ii = 0; ii < dimension; ii++) {
189         sizes_[ii] = sizes[ii];
190         strides_[ii] = strides[ii];
191       }
192       CHECK_EQ(dnnLayoutCreate_F32(&tfLayout_, dimension, sizes, strides),
193                E_SUCCESS);
194     }
195   }
196 
197   // Default case - MKL dim ordering is opposite of TF dim ordering
198   // MKL -> (DIMS-1)...0 where (DIMS-1) is outermost dim and 0 is innermost dim
199   // TF  -> 0...(DIMS-1) where 0 is outermost dim and (DIMS-1) is innermost dim
200   // For layers that rely on data_format semantics (conv, pooling etc.)
201   // or operate only on certain dimensions (relu, concat, split etc.),
202   // Mkl APIs might require us to reorder these dimensions. In such cases,
203   // kernels should explicitly set this map
SetTfDimOrder(const size_t dimension)204   void SetTfDimOrder(const size_t dimension) {
205     CHECK(dimension == dimension_);
206     if (tf_to_mkl_dim_map_ == nullptr) {
207       tf_to_mkl_dim_map_ = new size_t[dimension];
208     }
209     for (size_t ii = 0; ii < dimension; ii++) {
210       tf_to_mkl_dim_map_[ii] = dimension - (ii + 1);
211     }
212   }
213 
SetTfDimOrder(const size_t dimension,const size_t * tf_to_mkl_dim_map)214   void SetTfDimOrder(const size_t dimension, const size_t* tf_to_mkl_dim_map) {
215     CHECK(dimension == dimension_);
216     if (tf_to_mkl_dim_map_ == nullptr) {
217       tf_to_mkl_dim_map_ = new size_t[dimension];
218     }
219     for (size_t ii = 0; ii < dimension; ii++) {
220       tf_to_mkl_dim_map_[ii] = tf_to_mkl_dim_map[ii];
221     }
222   }
223 
SetTfDimOrder(const size_t dimension,TensorFormat data_format)224   void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
225     CHECK_EQ(dimension, 4);
226     CHECK(dimension == dimension_);
227     if (tf_to_mkl_dim_map_ == nullptr) {
228       tf_to_mkl_dim_map_ = new size_t[dimension];
229     }
230     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDims::W;
231     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDims::H;
232     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDims::C;
233     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDims::N;
234   }
235 
GetMklLayout()236   const dnnLayout_t GetMklLayout() const { return mklLayout_; }
GetTfLayout()237   const dnnLayout_t GetTfLayout() const { return tfLayout_; }
GetCurLayout()238   const dnnLayout_t GetCurLayout() const {
239     return isMklTensor_ ? mklLayout_ : tfLayout_;
240   }
GetDimension()241   size_t GetDimension() const { return dimension_; }
GetSizes()242   const size_t* GetSizes() const { return sizes_; }
dim_size(int index)243   int64 dim_size(int index) const { return sizes_[index]; }
tf_dim_size(int index)244   int64 tf_dim_size(int index) const {
245     return sizes_[tf_to_mkl_dim_map_[index]];
246   }
GetStrides()247   const size_t* GetStrides() const { return strides_; }
GetTfToMklDimMap()248   const size_t* GetTfToMklDimMap() const { return tf_to_mkl_dim_map_; }
tf_dim_idx(int index)249   size_t tf_dim_idx(int index) const { return tf_to_mkl_dim_map_[index]; }
250 
251   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
252   // corresponds to MKL's Channel dimension.
IsMklChannelDim(int d)253   bool IsMklChannelDim(int d) const { return tf_dim_idx(d) == MklDims::C; }
254   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
255   // corresponds to MKL's Batch dimension.
IsMklBatchDim(int d)256   bool IsMklBatchDim(int d) const { return tf_dim_idx(d) == MklDims::N; }
257   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
258   // corresponds to MKL's Width dimension.
IsMklWidthDim(int d)259   bool IsMklWidthDim(int d) const { return tf_dim_idx(d) == MklDims::W; }
260   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
261   // corresponds to MKL's Height dimension.
IsMklHeightDim(int d)262   bool IsMklHeightDim(int d) const { return tf_dim_idx(d) == MklDims::H; }
263 
264   // Check if the TF-Mkl dimension ordering map specifies if the input
265   // tensor is in NCHW format.
IsTensorInNCHWFormat()266   bool IsTensorInNCHWFormat() const {
267     TensorFormat data_format = FORMAT_NCHW;
268     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
269             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
270             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
271             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
272   }
273 
274   // Check if the TF-Mkl dimension ordering map specifies if the input
275   // tensor is in NHWC format.
IsTensorInNHWCFormat()276   bool IsTensorInNHWCFormat() const {
277     TensorFormat data_format = FORMAT_NHWC;
278     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
279             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
280             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
281             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
282   }
283 
GetConvertedFlatData(dnnLayout_t targetLayout,void * input,void * output)284   void GetConvertedFlatData(dnnLayout_t targetLayout, void* input,
285                             void* output) const {
286     dnnLayout_t curLayout;
287     if (isMklTensor_)
288       curLayout = mklLayout_;
289     else
290       curLayout = tfLayout_;
291     dnnPrimitive_t convert;
292     CHECK_EQ(dnnConversionCreate_F32(&convert, curLayout, targetLayout),
293              E_SUCCESS);
294     CHECK_EQ(dnnConversionExecute_F32(convert, input, output), E_SUCCESS);
295     CHECK_EQ(dnnDelete_F32(convert), E_SUCCESS);
296   }
297 
298   // The following methods are used for serializing and de-serializing the
299   // contents of the mklshape object.
300   // The data is serialized in this order
301   // isMklTensor_
302   // dimension_
303   // sizes_
304   // strides_
305   // mklLayout_
306   // tfLayout_
307   // tf_to_mkl_dim_map_
308 
309 #define SIZE_OF_MKL_DNN_BUF \
310   (dnnLayoutSerializationBufferSize_F32())  // Size of buffer needed to
311                                             // serialize dnn_layout pointer
312 
313   // Size of buffer to hold the serialized object, the size is computed as
314   // follows sizeof(isMklTensor_) + sizeof(dimension_) + sizeof(sizes_) +
315   // sizeof(strides_)
316   // + sizeof(mklLayout_ buffer) + sizeof(tfLayout_ buffer)
317   // + sizeof(tf_to_mkl_dim_map_)
318 
319 #define SIZE_OF_MKL_SERIAL_DATA(dims) \
320   (2 * sizeof(size_t) + 3 * dims * sizeof(size_t) + 2 * SIZE_OF_MKL_DNN_BUF)
321 
322   // First we need to define some macro for offsets into the serial buffer where
323   // different elements of Mklshape is written/read from
324 
325 #define IS_MKL_TENSOR_OFFSET 0
326 // Location from start of buffer where isMklTensor_ is serialized
327 #define DIMS_OFFSET \
328   (IS_MKL_TENSOR_OFFSET + sizeof(size_t))  // Location of dimension_
329 // Location of sizes. Note dim is not used here, left here
330 // to make macros consistent.
331 #define SIZES_OFFSET(dims) (DIMS_OFFSET + sizeof(size_t))
332 #define STRIDES_OFFSET(dims) \
333   (SIZES_OFFSET(dims) + dims * sizeof(size_t))  // Location of strides
334 #define MKL_LAYOUT_OFFSET(dims) \
335   (STRIDES_OFFSET(dims) + dims * sizeof(size_t))  // Location of mklLayout_
336 #define TF_LAYOUT_OFFSET(dims) \
337   (MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)  // Location of tfLayout_
338 // Location of tf_to_mkl_dim_map_
339 #define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
340   (TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
341 
342   // TODO(agramesh1) make sure to create a const to share with rewrite pass
343   // for min size of MKL metadata tensor.
344 
DeSerializeMklShape(const unsigned char * buf,size_t buf_size)345   void DeSerializeMklShape(const unsigned char* buf, size_t buf_size) {
346     CHECK(buf_size >= sizeof(size_t)) << "Bufsize too small in DeSerialize";
347     // Make sure buffer holds at least  isMklTensor_
348     isMklTensor_ =
349         *reinterpret_cast<const size_t*>(buf + IS_MKL_TENSOR_OFFSET) != 0;
350 
351     if (isMklTensor_) {  // If it is an MKL Tensor then read the rest
352       dimension_ = *(reinterpret_cast<const size_t*>(buf + DIMS_OFFSET));
353       CHECK(buf_size >= SIZE_OF_MKL_SERIAL_DATA(dimension_))
354           << "Bufsize too small in DeSerialize";
355       sizes_ = new size_t[dimension_];
356       strides_ = new size_t[dimension_];
357       tf_to_mkl_dim_map_ = new size_t[dimension_];
358       for (int i = 0; i < dimension_; i++) {
359         sizes_[i] =
360             reinterpret_cast<const size_t*>(buf + SIZES_OFFSET(dimension_))[i];
361         strides_[i] = reinterpret_cast<const size_t*>(
362             buf + STRIDES_OFFSET(dimension_))[i];
363         tf_to_mkl_dim_map_[i] = reinterpret_cast<const size_t*>(
364             buf + TF_TO_MKL_DIM_MAP_OFFSET(dimension_))[i];
365       }
366       CHECK_EQ(dnnLayoutDeserialize_F32(&mklLayout_,
367                                         buf + MKL_LAYOUT_OFFSET(dimension_)),
368                E_SUCCESS);
369       CHECK_EQ(dnnLayoutDeserialize_F32(&tfLayout_,
370                                         buf + TF_LAYOUT_OFFSET(dimension_)),
371                E_SUCCESS);
372     }
373   }
374 
SerializeMklShape(unsigned char * buf,size_t buf_size)375   void SerializeMklShape(unsigned char* buf, size_t buf_size) const {
376     CHECK(buf_size >= SIZE_OF_MKL_SERIAL_DATA(dimension_))
377         << "Bufsize too small to Serialize";
378     *reinterpret_cast<size_t*>(buf + IS_MKL_TENSOR_OFFSET) =
379         isMklTensor_ ? 1 : 0;
380     if (isMklTensor_) {
381       *(reinterpret_cast<size_t*>(buf + DIMS_OFFSET)) = dimension_;
382       for (int i = 0; i < dimension_; i++) {
383         reinterpret_cast<size_t*>(buf + SIZES_OFFSET(dimension_))[i] =
384             sizes_[i];
385         reinterpret_cast<size_t*>(buf + STRIDES_OFFSET(dimension_))[i] =
386             strides_[i];
387         reinterpret_cast<size_t*>(buf +
388                                   TF_TO_MKL_DIM_MAP_OFFSET(dimension_))[i] =
389             tf_to_mkl_dim_map_[i];
390       }
391       CHECK_EQ(dnnLayoutSerialize_F32(mklLayout_,
392                                       buf + MKL_LAYOUT_OFFSET(dimension_)),
393                E_SUCCESS);
394       CHECK_EQ(
395           dnnLayoutSerialize_F32(tfLayout_, buf + TF_LAYOUT_OFFSET(dimension_)),
396           E_SUCCESS);
397     }
398   }
399 
400  private:
401   bool isMklTensor_ =
402       false;  // Flag to indicate if the tensor is an  MKL tensor or not
403   dnnLayout_t mklLayout_ = nullptr;  // Pointer to the MKL layout
404   dnnLayout_t tfLayout_ = nullptr;   // Pointer to layout of corresponding
405   // Tensorflow tensor, used when conversion from MKL to standard tensor
406   size_t dimension_ = 0;
407   size_t* sizes_ = nullptr;    // Required by MKL for conversions
408   size_t* strides_ = nullptr;  // Required by MKL for conversions
409   size_t* tf_to_mkl_dim_map_ =
410       nullptr;  // TF dimension corresponding to this MKL dimension
411 };
412 
413 #else
414 
415 // Forward decl
416 TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format);
417 TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
418 memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
419 memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
420                                         const memory::dims& strides,
421                                         memory::data_type dtype);
422 
423 class MklDnnShape {
424  private:
425   typedef struct {
426     /// Flag to indicate if the tensor is an  MKL tensor or not
427     bool is_mkl_tensor_ = false;
428     /// Number of dimensions in Tensorflow format
429     size_t dimension_ = 0;
430     /// Required by MKLDNN for conversions
431     mkldnn_dims_t sizes_;  // Required by MKL for conversions
432     memory::format tf_data_format_ = memory::format::format_undef;
433     memory::data_type T_ = memory::data_type::data_undef;
434     // MKL layout
435     mkldnn_memory_desc_t mkl_md_;
436     /// TF dimension corresponding to this MKL dimension
437     mkldnn_dims_t map_;
438   } MklShapeData;
439   MklShapeData data_;
440 
441   typedef std::remove_extent<mkldnn_dims_t>::type mkldnn_dim_t;
442 #define INVALID_DIM_SIZE -1
443 
444  public:
MklDnnShape()445   MklDnnShape() {
446     for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
447          ++i) {
448       data_.sizes_[i] = -1;
449     }
450     for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
451       data_.map_[i] = -1;
452     }
453   }
454 
~MklDnnShape()455   ~MklDnnShape() {}
456   TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape);  // Cannot copy
457 
458   /// Helper function to compare memory::desc objects for MklDnn.
459   /// May be this should go into MklDnn directly.
CompareMklDnnLayouts(const memory::desc & md1,const memory::desc & md2)460   inline bool CompareMklDnnLayouts(const memory::desc& md1,
461                                    const memory::desc& md2) const {
462     mkldnn_memory_desc_t mdd1 = md1.data;
463     mkldnn_memory_desc_t mdd2 = md2.data;
464     const char* d1 = reinterpret_cast<const char*>(&mdd1);
465     const char* d2 = reinterpret_cast<const char*>(&mdd2);
466 
467     size_t md_size = sizeof(mdd1);
468     for (size_t i = 0; i < md_size; i++) {
469       if (*d1++ != *d2++) {
470         return false;
471       }
472     }
473     return true;
474   }
475 
476   /// Equality function for MklDnnShape objects
477   /// @return true if both are equal; false otherwise.
478   inline bool operator==(const MklDnnShape& input_shape) const {
479     if (this->IsMklTensor() != input_shape.IsMklTensor()) {
480       return false;
481     }
482 
483     // If input tensors are in Mkl layout, then we check for dimensions and
484     // sizes.
485     if (this->IsMklTensor()) {
486       return this->GetTfShape() == input_shape.GetTfShape() &&
487              CompareMklDnnLayouts(this->GetMklLayout(),
488                                   input_shape.GetMklLayout());
489     }
490 
491     return true;
492   }
493 
494   /// Equality operator for MklDnnShape and TFShape.
495   /// Returns: true if TF shapes for both are the same, false otherwise
496   inline bool operator==(const TensorShape& input_shape) const {
497     if (!this->IsMklTensor()) {
498       return false;
499     }
500 
501     return this->GetTfShape() == input_shape;
502   }
503 
IsMklTensor()504   inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; }
SetMklTensor(bool is_mkl_tensor)505   inline void SetMklTensor(bool is_mkl_tensor) {
506     data_.is_mkl_tensor_ = is_mkl_tensor;
507   }
508 
SetDimensions(const size_t dimension)509   inline void SetDimensions(const size_t dimension) {
510     data_.dimension_ = dimension;
511   }
GetDimension(char dimension)512   inline size_t GetDimension(char dimension) const {
513     int index = GetMklDnnTensorDimIndex(dimension);
514     CHECK(index >= 0 && index < this->GetDimension())
515         << "Invalid index from the dimension: " << index << ", " << dimension;
516     return this->DimSize(index);
517   }
518 
GetDimension3D(char dimension)519   inline size_t GetDimension3D(char dimension) const {
520     int index = GetMklDnnTensor3DDimIndex(dimension);
521     CHECK(index >= 0 && index < this->GetDimension())
522         << "Invalid index from the dimension: " << index << ", " << dimension;
523     return this->DimSize(index);
524   }
525 
GetMklDnnTensorDimIndex(char dimension)526   inline int32 GetMklDnnTensorDimIndex(char dimension) const {
527     switch (dimension) {
528       case 'N':
529         return MklDnnDims::Dim_N;
530       case 'C':
531         return MklDnnDims::Dim_C;
532       case 'H':
533         return MklDnnDims::Dim_H;
534       case 'W':
535         return MklDnnDims::Dim_W;
536       default:
537         LOG(FATAL) << "Invalid dimension: " << dimension;
538         return -1;  // Avoid compiler warning about missing return value
539     }
540   }
541 
GetMklDnnTensor3DDimIndex(char dimension)542   inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
543     switch (dimension) {
544       case 'N':
545         return MklDnnDims3D::Dim3d_N;
546       case 'C':
547         return MklDnnDims3D::Dim3d_C;
548       case 'D':
549         return MklDnnDims3D::Dim3d_D;
550       case 'H':
551         return MklDnnDims3D::Dim3d_H;
552       case 'W':
553         return MklDnnDims3D::Dim3d_W;
554       default:
555         LOG(FATAL) << "Invalid dimension: " << dimension;
556         return -1;  // Avoid compiler warning about missing return value
557     }
558   }
559 
GetDimension()560   inline size_t GetDimension() const { return data_.dimension_; }
GetSizes()561   inline const int* GetSizes() const {
562     return reinterpret_cast<const int*>(&data_.sizes_[0]);
563   }
564 
565   // Returns an mkldnn::memory::dims object that contains the sizes of this
566   // MklDnnShape object.
GetSizesAsMklDnnDims()567   inline memory::dims GetSizesAsMklDnnDims() const {
568     memory::dims retVal;
569     if (data_.is_mkl_tensor_) {
570       size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
571       for (size_t i = 0; i < dimensions; i++) {
572         if (data_.sizes_[i] != INVALID_DIM_SIZE)
573           retVal.push_back(data_.sizes_[i]);
574       }
575     } else {
576       CHECK_EQ(data_.is_mkl_tensor_, true);
577     }
578     return retVal;
579   }
580 
DimSize(int index)581   inline int64 DimSize(int index) const {
582     CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0]));
583     return data_.sizes_[index];
584   }
585 
586   /// Return TensorShape that describes the Tensorflow shape of the tensor
587   /// represented by this MklShape.
GetTfShape()588   inline TensorShape GetTfShape() const {
589     CHECK_EQ(data_.is_mkl_tensor_, true);
590 
591     std::vector<int32> shape(data_.dimension_, -1);
592     if (data_.tf_data_format_ != memory::format::blocked) {
593       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
594         shape[idx] = data_.sizes_[TfDimIdx(idx)];
595       }
596     } else {
597       // If Tensorflow shape is in Blocked format, then we don't have dimension
598       // map for it. So we just create Tensorflow shape from sizes in the
599       // specified order.
600       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
601         shape[idx] = data_.sizes_[idx];
602       }
603     }
604 
605     TensorShape ts;
606     bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
607     CHECK_EQ(ret, true);
608     return ts;
609   }
610 
SetElemType(memory::data_type dt)611   inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
GetElemType()612   inline const memory::data_type GetElemType() { return data_.T_; }
613 
SetMklLayout(memory::primitive_desc * pd)614   inline void SetMklLayout(memory::primitive_desc* pd) {
615     CHECK_NOTNULL(pd);
616     data_.mkl_md_ = pd->desc().data;
617   }
618 
SetMklLayout(memory::desc * md)619   inline void SetMklLayout(memory::desc* md) {
620     CHECK_NOTNULL(md);
621     data_.mkl_md_ = md->data;
622   }
623 
GetMklLayout()624   inline const memory::desc GetMklLayout() const {
625     return memory::desc(data_.mkl_md_);
626   }
627 
GetTfDataFormat()628   inline memory::format GetTfDataFormat() const {
629     return data_.tf_data_format_;
630   }
631   /// We don't create primitive_descriptor for TensorFlow layout now.
632   /// We use lazy evaluation and create it only when needed. Input format can
633   /// also be Blocked format.
SetTfLayout(size_t dims,const memory::dims & sizes,memory::format format)634   inline void SetTfLayout(size_t dims, const memory::dims& sizes,
635                           memory::format format) {
636     CHECK_EQ(dims, sizes.size());
637     data_.dimension_ = dims;
638     for (size_t ii = 0; ii < dims; ii++) {
639       data_.sizes_[ii] = sizes[ii];
640     }
641     data_.tf_data_format_ = format;
642     if (format != memory::format::blocked) {
643       SetTfDimOrder(dims, format);
644     }
645   }
646 
GetTfLayout()647   inline const memory::desc GetTfLayout() const {
648     memory::dims dims;
649     for (size_t ii = 0; ii < data_.dimension_; ii++) {
650       dims.push_back(data_.sizes_[ii]);
651     }
652 
653     // Create Blocked memory desc if input TF format was set like that.
654     if (data_.tf_data_format_ == memory::format::blocked) {
655       auto strides = CalculateTFStrides(dims);
656       return CreateBlockedMemDescHelper(dims, strides, data_.T_);
657     } else {
658       return memory::desc(dims, data_.T_, data_.tf_data_format_);
659     }
660   }
661 
GetCurLayout()662   inline const memory::desc GetCurLayout() const {
663     return IsMklTensor() ? GetMklLayout() : GetTfLayout();
664   }
665 
666   // nhasabni - I've removed SetTfDimOrder that was setting default order in
667   // case of MKL-ML. We don't need a case of default dimension order because
668   // when an operator that does not get data_format attribute gets all inputs
669   // in Tensorflow format, it will produce output in Tensorflow format.
SetTfDimOrder(const size_t dimension,const mkldnn_dims_t map)670   inline void SetTfDimOrder(const size_t dimension, const mkldnn_dims_t map) {
671     CHECK(dimension == data_.dimension_);
672     for (size_t ii = 0; ii < dimension; ii++) {
673       data_.map_[ii] = map[ii];
674     }
675   }
676 
SetTfDimOrder(const size_t dimension,TensorFormat data_format)677   inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
678     if (dimension == 5) {
679       CHECK(dimension == data_.dimension_);
680       data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
681           MklDnnDims3D::Dim3d_D;
682       data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
683           MklDnnDims3D::Dim3d_H;
684       data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
685           MklDnnDims3D::Dim3d_W;
686       data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
687           MklDnnDims3D::Dim3d_C;
688       data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
689           MklDnnDims3D::Dim3d_N;
690     } else {
691       CHECK_EQ(dimension, 4);
692       CHECK(dimension == data_.dimension_);
693       data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
694       data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
695       data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
696       data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
697     }
698   }
699 
SetTfDimOrder(const size_t dimension,memory::format format)700   inline void SetTfDimOrder(const size_t dimension, memory::format format) {
701     TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
702     SetTfDimOrder(dimension, data_format);
703   }
704 
GetTfToMklDimMap()705   inline const mkldnn_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
TfDimIdx(int index)706   inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
TfDimSize(int index)707   inline int64 TfDimSize(int index) const {
708     return data_.sizes_[TfDimIdx(index)];
709   }
710 
711   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
712   /// corresponds to MKL's Channel dimension.
IsMklChannelDim(int d)713   inline bool IsMklChannelDim(int d) const {
714     return TfDimIdx(d) == MklDnnDims::Dim_C;
715   }
716   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
717   /// corresponds to MKL's Batch dimension.
IsMklBatchDim(int d)718   inline bool IsMklBatchDim(int d) const {
719     return TfDimIdx(d) == MklDnnDims::Dim_N;
720   }
721   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
722   /// corresponds to MKL's Width dimension.
IsMklWidthDim(int d)723   inline bool IsMklWidthDim(int d) const {
724     return TfDimIdx(d) == MklDnnDims::Dim_W;
725   }
726   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
727   /// corresponds to MKL's Height dimension.
IsMklHeightDim(int d)728   inline bool IsMklHeightDim(int d) const {
729     return TfDimIdx(d) == MklDnnDims::Dim_H;
730   }
731 
732   /// Check if the TF-Mkl dimension ordering map specifies if the input
733   /// tensor is in NCHW format.
IsTensorInNCHWFormat()734   inline bool IsTensorInNCHWFormat() const {
735     TensorFormat data_format = FORMAT_NCHW;
736     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
737             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
738             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
739             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
740   }
741 
742   /// Check if the TF-Mkl dimension ordering map specifies if the input
743   /// tensor is in NHWC format.
IsTensorInNHWCFormat()744   inline bool IsTensorInNHWCFormat() const {
745     TensorFormat data_format = FORMAT_NHWC;
746     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
747             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
748             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
749             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
750   }
751 
752   /// The following methods are used for serializing and de-serializing the
753   /// contents of the mklshape object.
754   /// The data is serialized in this order
755   /// is_mkl_tensor_ : dimension_ : sizes_ : map_: format_ : T_ : mkl_pd_;
756 
757   /// Size of buffer to hold the serialized object, the size is computed by
758   /// following above mentioned order
GetSerializeBufferSize()759   inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }
760 
SerializeMklDnnShape(unsigned char * buf,size_t buf_size)761   void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
762     CHECK(buf_size >= GetSerializeBufferSize())
763         << "Buffer size is too small to SerializeMklDnnShape";
764     *reinterpret_cast<MklShapeData*>(buf) = data_;
765   }
766 
DeSerializeMklDnnShape(const unsigned char * buf,size_t buf_size)767   void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
768     // Make sure buffer holds at least is_mkl_tensor_.
769     CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
770         << "Buffer size is too small in DeSerializeMklDnnShape";
771 
772     const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
773     if (is_mkl_tensor) {  // If it is an MKL Tensor then read the rest
774       CHECK(buf_size >= GetSerializeBufferSize())
775           << "Buffer size is too small in DeSerializeMklDnnShape";
776       data_ = *reinterpret_cast<const MklShapeData*>(buf);
777     }
778   }
779 };
780 
781 #endif
782 
783 // List of MklShape objects. Used in Concat/Split layers.
784 
785 #ifndef INTEL_MKL_ML_ONLY
786 typedef std::vector<MklDnnShape> MklDnnShapeList;
787 #else
788 typedef std::vector<MklShape> MklShapeList;
789 #endif
790 
791 #ifdef INTEL_MKL_ML_ONLY
792 // Check if all tensors specified by MklShapes are MKL tensors.
AreAllMklTensors(const MklShapeList & shapes)793 inline bool AreAllMklTensors(const MklShapeList& shapes) {
794   for (auto& s : shapes) {
795     if (!s.IsMklTensor()) {
796       return false;
797     }
798   }
799   return true;
800 }
801 
802 template <typename T>
ConvertMklToTF(OpKernelContext * context,const Tensor & mkl_tensor,const MklShape & mkl_shape)803 inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
804                              const MklShape& mkl_shape) {
805   Tensor output_tensor;
806   TensorShape output_shape;
807 
808   for (size_t j = 0; j < mkl_shape.GetDimension(); j++) {
809     // Outermost to innermost dimension
810     output_shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(j)]);
811   }
812 
813   // Allocate output tensor.
814   context->allocate_temp(DataTypeToEnum<T>::v(), output_shape, &output_tensor);
815 
816   dnnLayout_t output_layout = static_cast<dnnLayout_t>(mkl_shape.GetTfLayout());
817   void* input_buffer = const_cast<T*>(mkl_tensor.flat<T>().data());
818   void* output_buffer = const_cast<T*>(output_tensor.flat<T>().data());
819 
820   if (mkl_tensor.NumElements() != 0) {
821     mkl_shape.GetConvertedFlatData(output_layout, input_buffer, output_buffer);
822   }
823 
824   return output_tensor;
825 }
826 #else
// Make mkldnn::stream usable unqualified in the code below.
using mkldnn::stream;
// Forward declaration so that ConvertMklToTF below can name MklDnnData<T>;
// the class definition is not visible at this point of the header.
template <typename T>
class MklDnnData;
830 
831 template <typename T>
ConvertMklToTF(OpKernelContext * context,const Tensor & mkl_tensor,const MklDnnShape & mkl_shape)832 inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
833                              const MklDnnShape& mkl_shape) {
834   Tensor output_tensor;
835   try {
836     if (!mkl_shape.IsMklTensor())
837       return mkl_tensor;  // return input since it is already TF tensor
838 
839     TensorShape output_shape = mkl_shape.GetTfShape();
840 
841     // Allocate output tensor.
842     context->allocate_temp(DataTypeToEnum<T>::v(), output_shape,
843                            &output_tensor);
844 
845     auto cpu_engine = engine(engine::cpu, 0);
846     MklDnnData<T> input(&cpu_engine);
847 
848     // Get Mkl layout of input tensor.
849     auto input_mkl_md = mkl_shape.GetMklLayout();
850     auto output_tf_md = mkl_shape.GetTfLayout();
851     auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
852     input.SetUsrMem(input_mkl_md, &mkl_tensor);
853 
854     // reorder
855     if (input.IsReorderNeeded(output_tf_pd)) {
856       std::vector<primitive> net;
857       CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, &output_tensor, &net),
858                true);
859       stream(stream::kind::eager).submit(net).wait();
860     } else {
861       // If not, just forward input tensor to output tensor.
862       CHECK(output_tensor.CopyFrom(mkl_tensor, output_shape));
863     }
864   } catch (mkldnn::error& e) {
865     string error_msg = "Status: " + std::to_string(e.status) +
866                        ", message: " + string(e.message) + ", in file " +
867                        string(__FILE__) + ":" + std::to_string(__LINE__);
868     LOG(FATAL) << "Operation received an exception: " << error_msg;
869   }
870   return output_tensor;
871 }
872 #endif
873 
// Get the MKL shape of input n by deserializing its uint8 metadata tensor
875 #ifdef INTEL_MKL_ML_ONLY
GetMklShape(OpKernelContext * ctext,int n,MklShape * mklshape)876 inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
877   mklshape->DeSerializeMklShape(
878       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
879           .flat<uint8>()
880           .data(),
881       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
882               .flat<uint8>()
883               .size() *
884           sizeof(uint8));
885 }
886 #else
GetMklShape(OpKernelContext * ctext,int n,MklDnnShape * mklshape)887 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
888   mklshape->DeSerializeMklDnnShape(
889       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
890           .flat<uint8>()
891           .data(),
892       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
893               .flat<uint8>()
894               .size() *
895           sizeof(uint8));
896 }
897 #endif
898 
899 // Gets the actual input
MklGetInput(OpKernelContext * ctext,int n)900 inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
901   return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
902 }
903 
GetMklInputList(OpKernelContext * ctext,StringPiece name,OpInputList * input_tensors)904 inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
905                             OpInputList* input_tensors) {
906   CHECK_NOTNULL(input_tensors);
907   ctext->input_list(name, input_tensors);
908 }
909 
910 #ifdef INTEL_MKL_ML_ONLY
911 
GetMklShapeList(OpKernelContext * ctext,StringPiece name,MklShapeList * mkl_shapes)912 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
913                             MklShapeList* mkl_shapes) {
914   OpInputList input_mkl_tensors;
915   GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
916 
917   for (int i = 0; i < input_mkl_tensors.size(); i++) {
918     (*mkl_shapes)[i].DeSerializeMklShape(
919         input_mkl_tensors[i].flat<uint8>().data(),
920         input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
921   }
922 }
923 
924 #else
925 
GetMklShapeList(OpKernelContext * ctext,StringPiece name,MklDnnShapeList * mkl_shapes)926 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
927                             MklDnnShapeList* mkl_shapes) {
928   OpInputList input_mkl_tensors;
929   GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
930 
931   for (int i = 0; i < input_mkl_tensors.size(); i++) {
932     (*mkl_shapes)[i].DeSerializeMklDnnShape(
933         input_mkl_tensors[i].flat<uint8>().data(),
934         input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
935   }
936 }
937 
938 #endif
939 
940 #ifndef INTEL_MKL_ML_ONLY
941 /// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
942 /// If the input tensor is in MKL layout, then obtains TensorShape from
943 /// MklShape.
GetTfShape(OpKernelContext * context,size_t input_idx)944 inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx) {
945   // Sanity check.
946   CHECK_NOTNULL(context);
947   CHECK_LT(input_idx, context->num_inputs());
948 
949   MklDnnShape input_mkl_shape;
950   GetMklShape(context, input_idx, &input_mkl_shape);
951   if (input_mkl_shape.IsMklTensor()) {
952     return input_mkl_shape.GetTfShape();
953   } else {
954     const Tensor& t = MklGetInput(context, input_idx);
955     return t.shape();
956   }
957 }
958 #endif
959 
960 #ifdef INTEL_MKL_ML_ONLY
961 // Allocate the second output tensor that will contain
962 // the MKL shape serialized
AllocateOutputSetMklShape(OpKernelContext * ctext,int n,const MklShape & mkl_shape)963 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
964                                       const MklShape& mkl_shape) {
965   Tensor* second_tensor = nullptr;
966   TensorShape second_shape;
967   second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
968   OP_REQUIRES_OK(ctext, ctext->allocate_output(
969                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
970                             second_shape, &second_tensor));
971   mkl_shape.SerializeMklShape(
972       second_tensor->flat<uint8>().data(),
973       second_tensor->flat<uint8>().size() * sizeof(uint8));
974 }
975 
976 #else
977 // Allocate the second output tensor that will contain
978 // the MKL shape serialized
AllocateOutputSetMklShape(OpKernelContext * ctext,int n,const MklDnnShape & mkl_shape)979 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
980                                       const MklDnnShape& mkl_shape) {
981   Tensor* second_tensor = nullptr;
982   TensorShape second_shape;
983   second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
984   OP_REQUIRES_OK(ctext, ctext->allocate_output(
985                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
986                             second_shape, &second_tensor));
987   mkl_shape.SerializeMklDnnShape(
988       second_tensor->flat<uint8>().data(),
989       second_tensor->flat<uint8>().size() * sizeof(uint8));
990 }
991 #endif
992 
993 #ifdef INTEL_MKL_ML_ONLY
994 // Allocate the output tensor, create a second output tensor that will contain
995 // the MKL shape serialized
AllocateOutputSetMklShape(OpKernelContext * ctext,int n,Tensor ** output,const TensorShape & tf_shape,const MklShape & mkl_shape)996 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
997                                       Tensor** output,
998                                       const TensorShape& tf_shape,
999                                       const MklShape& mkl_shape) {
1000   Tensor* second_tensor = nullptr;
1001   TensorShape second_shape;
1002   second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
1003   OP_REQUIRES_OK(
1004       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
1005                                     tf_shape, output));
1006   OP_REQUIRES_OK(ctext, ctext->allocate_output(
1007                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
1008                             second_shape, &second_tensor));
1009   mkl_shape.SerializeMklShape(
1010       second_tensor->flat<uint8>().data(),
1011       second_tensor->flat<uint8>().size() * sizeof(uint8));
1012 }
1013 
1014 #else
1015 // Allocate the output tensor, create a second output tensor that will contain
1016 // the MKL shape serialized
AllocateOutputSetMklShape(OpKernelContext * ctext,int n,Tensor ** output,const TensorShape & tf_shape,const MklDnnShape & mkl_shape)1017 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
1018                                       Tensor** output,
1019                                       const TensorShape& tf_shape,
1020                                       const MklDnnShape& mkl_shape) {
1021   Tensor* second_tensor = nullptr;
1022   TensorShape second_shape;
1023   second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
1024   OP_REQUIRES_OK(
1025       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
1026                                     tf_shape, output));
1027   OP_REQUIRES_OK(ctext, ctext->allocate_output(
1028                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
1029                             second_shape, &second_tensor));
1030   mkl_shape.SerializeMklDnnShape(
1031       second_tensor->flat<uint8>().data(),
1032       second_tensor->flat<uint8>().size() * sizeof(uint8));
1033 }
1034 #endif
1035 
// Allocates a temp tensor and returns the data buffer for temporary storage.
// The MKL-ML-only variant currently supports only float (F32) buffers.
1038 #ifndef INTEL_MKL_ML_ONLY
1039 template <typename T>
AllocTmpBuffer(OpKernelContext * context,Tensor * tensor_out,const memory::primitive_desc & pd,void ** buf_out)1040 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
1041                            const memory::primitive_desc& pd, void** buf_out) {
1042   TensorShape tf_shape;
1043 
1044   tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
1045   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
1046                                                  tf_shape, tensor_out));
1047   *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
1048 }
1049 #else
AllocTmpBuffer(OpKernelContext * context,Tensor * tensor_out,dnnLayout_t lt_buff,void ** buf_out)1050 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
1051                            dnnLayout_t lt_buff, void** buf_out) {
1052   TensorShape tf_shape;
1053 
1054   tf_shape.AddDim(
1055       dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(lt_buff)) /
1056           sizeof(float) +
1057       1);
1058   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<float>::v(),
1059                                                  tf_shape, tensor_out));
1060   *buf_out = static_cast<void*>(tensor_out->flat<float>().data());
1061 }
1062 
1063 #endif
1064 template <typename T>
AllocTmpBuffer(OpKernelContext * context,Tensor * tensor_out,TensorShape tf_shape)1065 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
1066                            TensorShape tf_shape) {
1067   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
1068                                                  tf_shape, tensor_out));
1069 }
1070 
GetStridesFromSizes(TensorFormat data_format,size_t * strides,const size_t * sizes)1071 inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
1072                                 const size_t* sizes) {
1073   // MKL requires strides in NCHW
1074   if (data_format == FORMAT_NHWC) {
1075     strides[0] = sizes[2];
1076     strides[1] = sizes[0] * sizes[2];
1077     strides[2] = 1;
1078     strides[3] = sizes[0] * sizes[1] * sizes[2];
1079   } else {
1080     strides[0] = 1;
1081     strides[1] = sizes[0];
1082     strides[2] = sizes[0] * sizes[1];
1083     strides[3] = sizes[0] * sizes[1] * sizes[2];
1084   }
1085 }
1086 
1087 #ifdef INTEL_MKL_ML_ONLY
MklSizesToTFSizes(OpKernelContext * context,TensorFormat data_format_,const MklShape & mkl_shape,TensorShape * tf_shape)1088 inline void MklSizesToTFSizes(OpKernelContext* context,
1089                               TensorFormat data_format_,
1090                               const MklShape& mkl_shape,
1091                               TensorShape* tf_shape) {
1092   size_t tf_dim = mkl_shape.GetDimension();
1093   const size_t* tf_sizes = mkl_shape.GetSizes();
1094 
1095   OP_REQUIRES(context, tf_dim == 4,
1096               errors::InvalidArgument("MKLSizesToTFSizes: size must be 4-dim"));
1097   std::vector<int32> sizes;
1098 
1099   sizes.push_back(tf_sizes[3]);
1100 
1101   if (data_format_ == FORMAT_NHWC) {
1102     sizes.push_back(tf_sizes[1]);
1103     sizes.push_back(tf_sizes[0]);
1104     sizes.push_back(tf_sizes[2]);
1105   } else {
1106     sizes.push_back(tf_sizes[2]);
1107     sizes.push_back(tf_sizes[1]);
1108     sizes.push_back(tf_sizes[0]);
1109   }
1110 
1111   OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tf_shape));
1112 }
1113 #endif
1114 
GetMklTensorDimIndex(char dimension)1115 inline int32 GetMklTensorDimIndex(char dimension) {
1116   switch (dimension) {
1117     case 'N':
1118       return MklDims::N;
1119     case 'C':
1120       return MklDims::C;
1121     case 'H':
1122       return MklDims::H;
1123     case 'W':
1124       return MklDims::W;
1125     default:
1126       LOG(FATAL) << "Invalid dimension: " << dimension;
1127       return -1;  // Avoid compiler warning about missing return value
1128   }
1129 }
1130 
1131 #ifdef INTEL_MKL_ML_ONLY
GetMklTensorDim(const MklShape & mkl_shape,char dimension)1132 inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
1133   int index = GetMklTensorDimIndex(dimension);
1134   CHECK(index >= 0 && index < mkl_shape.GetDimension())
1135       << "Invalid index from the dimension: " << index << ", " << dimension;
1136   return mkl_shape.dim_size(index);
1137 }
1138 #endif
1139 
CopyMklTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)1140 inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
1141                                  int idx_out) {
1142   int num_inputs = context->num_inputs();
1143   int num_outputs = context->num_outputs();
1144   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1145   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
1146   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1147   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
1148 
1149   const Tensor& data = context->input(idx_data_in);
1150   const Tensor& meta = context->input(idx_meta_in);
1151   Tensor output(data.dtype());
1152   Tensor meta_output(meta.dtype());
1153 
1154   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
1155   CHECK(output.CopyFrom(data, data.shape()));
1156   CHECK(meta_output.CopyFrom(meta, meta.shape()));
1157   context->set_output(idx_data_out, output);
1158   context->set_output(idx_meta_out, meta_output);
1159 }
1160 
1161 #ifdef INTEL_MKL_ML_ONLY
CopyTfTensorInToOutWithShape(OpKernelContext * context,int idx_in,int idx_out,const TensorShape & shape)1162 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
1163                                          int idx_out,
1164                                          const TensorShape& shape) {
1165   int num_inputs = context->num_inputs();
1166   int num_outputs = context->num_outputs();
1167   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1168   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1169 
1170   const Tensor& data = context->input(idx_data_in);
1171   MklShape mkl_shape_output;
1172   mkl_shape_output.SetMklTensor(false);
1173   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
1174   Tensor output(data.dtype());
1175   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
1176   CHECK(output.CopyFrom(data, shape));
1177   context->set_output(idx_data_out, output);
1178 }
1179 #else
CopyTfTensorInToOutWithShape(OpKernelContext * context,int idx_in,int idx_out,const TensorShape & shape)1180 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
1181                                          int idx_out,
1182                                          const TensorShape& shape) {
1183   int num_inputs = context->num_inputs();
1184   int num_outputs = context->num_outputs();
1185   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1186   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1187 
1188   const Tensor& data = context->input(idx_data_in);
1189   MklDnnShape mkl_shape_output;
1190   mkl_shape_output.SetMklTensor(false);
1191   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
1192   Tensor output(data.dtype());
1193   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
1194   CHECK(output.CopyFrom(data, shape));
1195   context->set_output(idx_data_out, output);
1196 }
1197 #endif
1198 
1199 #ifdef INTEL_MKL_ML_ONLY
1200 
ForwardTfTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)1201 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
1202                                    int idx_out) {
1203   int num_inputs = context->num_inputs();
1204   int num_outputs = context->num_outputs();
1205   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1206   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1207 
1208   MklShape mkl_shape_output;
1209   mkl_shape_output.SetMklTensor(false);
1210   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
1211   if (IsRefType(context->input_dtype(idx_data_in))) {
1212     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
1213   } else {
1214     context->set_output(idx_data_out, context->input(idx_data_in));
1215   }
1216 }
1217 
1218 #else
1219 
ForwardTfTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)1220 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
1221                                    int idx_out) {
1222   int num_inputs = context->num_inputs();
1223   int num_outputs = context->num_outputs();
1224   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1225   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1226 
1227   MklDnnShape dnn_shape_output;
1228   dnn_shape_output.SetMklTensor(false);
1229   AllocateOutputSetMklShape(context, idx_out, dnn_shape_output);
1230   if (IsRefType(context->input_dtype(idx_data_in))) {
1231     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
1232   } else {
1233     context->set_output(idx_data_out, context->input(idx_data_in));
1234   }
1235 }
1236 
1237 #endif
1238 
ForwardMklTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)1239 inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
1240                                     int idx_out) {
1241   int num_inputs = context->num_inputs();
1242   int num_outputs = context->num_outputs();
1243   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1244   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
1245   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1246   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
1247 
1248   if (IsRefType(context->input_dtype(idx_data_in))) {
1249     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
1250     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
1251   } else {
1252     context->set_output(idx_data_out, context->input(idx_data_in));
1253     context->set_output(idx_meta_out, context->input(idx_meta_in));
1254   }
1255 }
1256 
1257 #ifndef INTEL_MKL_ML_ONLY
1258 // Set a dummy MKLDNN shape (called when the output is in TF format)
SetDummyMklDnnShapeOutput(OpKernelContext * context,uint32 idx_data_out)1259 inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
1260                                       uint32 idx_data_out) {
1261   MklDnnShape mkl_shape_output;
1262   mkl_shape_output.SetMklTensor(false);
1263   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
1264 }
1265 
ForwardMklTensorInToOutWithMklShape(OpKernelContext * context,int idx_in,int idx_out,const MklDnnShape & mkl_shape)1266 inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
1267                                                 int idx_in, int idx_out,
1268                                                 const MklDnnShape& mkl_shape) {
1269   int num_inputs = context->num_inputs();
1270   int num_outputs = context->num_outputs();
1271   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1272   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1273 
1274   AllocateOutputSetMklShape(context, idx_out, mkl_shape);
1275 
1276   if (IsRefType(context->input_dtype(idx_data_in))) {
1277     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
1278   } else {
1279     context->set_output(idx_data_out, context->input(idx_data_in));
1280   }
1281 }
1282 #endif
1283 
1284 // Forward the MKL shape ONLY (used in elementwise and other ops where
1285 // we call the eigen implementation and MKL shape is not used)
ForwardMklMetaDataInToOut(OpKernelContext * context,uint32 idx_data_in,uint32_t idx_data_out)1286 inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
1287                                       uint32 idx_data_in,
1288                                       uint32_t idx_data_out) {
1289   uint32 idx_meta_in =
1290       GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
1291   uint32 idx_meta_out =
1292       GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
1293 
1294   if (IsRefType(context->input_dtype(idx_data_in))) {
1295     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
1296   } else {
1297     context->set_output(idx_meta_out, context->input(idx_meta_in));
1298   }
1299 }
1300 
1301 #ifdef INTEL_MKL_ML_ONLY
1302 // Set a dummy MKL shape (called when the output is in TF format)
SetDummyMklShapeOutput(OpKernelContext * context,uint32 idx_data_out)1303 inline void SetDummyMklShapeOutput(OpKernelContext* context,
1304                                    uint32 idx_data_out) {
1305   MklShape mkl_shape_output;
1306   mkl_shape_output.SetMklTensor(false);
1307   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
1308 }
1309 // We don't need these functions in MKLDNN. We have defined equality operator
1310 // on MklDnnShape class directly.
1311 
1312 // Checks if the TF shape for both MKL tensors is the same or not
1313 // Returns: true if both TF shapes are the same, false otherwise
MklCompareShapes(const MklShape * input_shape_0,const MklShape * input_shape_1)1314 inline bool MklCompareShapes(const MklShape* input_shape_0,
1315                              const MklShape* input_shape_1) {
1316   // Check for number of dimensions
1317   if (input_shape_0->GetDimension() != input_shape_1->GetDimension()) {
1318     return false;
1319   }
1320 
1321   // Check size of each dimension
1322   size_t ndims = input_shape_0->GetDimension();
1323   for (size_t i = 0; i < ndims; i++) {
1324     if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) {
1325       return false;
1326     }
1327   }
1328 
1329   return true;
1330 }
1331 
1332 // Checks if the TF shape for both tensors is the same or not
1333 // Returns: true if TF shapes for both are the same, false otherwise
MklCompareShapes(const MklShape * input_shape_0,const TensorShape * input_shape_1)1334 inline bool MklCompareShapes(const MklShape* input_shape_0,
1335                              const TensorShape* input_shape_1) {
1336   // Check for number of dimensions
1337   if (input_shape_0->GetDimension() != input_shape_1->dims()) {
1338     return false;
1339   }
1340 
1341   // Check size of each dimension
1342   size_t ndims = input_shape_0->GetDimension();
1343   for (size_t i = 0; i < ndims; i++) {
1344     if (input_shape_0->tf_dim_size(i) != input_shape_1->dim_size(i)) {
1345       return false;
1346     }
1347   }
1348 
1349   return true;
1350 }
1351 
1352 // Checks if the TF shape for both tensors is the same or not
1353 // Returns: true if TF shapes for both are the same, false otherwise
MklCompareShapes(const TensorShape * input_shape_0,const MklShape * input_shape_1)1354 inline bool MklCompareShapes(const TensorShape* input_shape_0,
1355                              const MklShape* input_shape_1) {
1356   return MklCompareShapes(input_shape_1, input_shape_0);
1357 }
1358 
1359 // Checks if the TF shape for both tensors is the same or not
1360 // Returns: true if TF shapes for both are the same, false otherwise
MklCompareShapes(const TensorShape * input_shape_0,const TensorShape * input_shape_1)1361 inline bool MklCompareShapes(const TensorShape* input_shape_0,
1362                              const TensorShape* input_shape_1) {
1363   // Check for number of dimensions
1364   if (input_shape_0->dims() != input_shape_1->dims()) {
1365     return false;
1366   }
1367 
1368   // Check size of each dimension
1369   size_t ndims = input_shape_0->dims();
1370   for (size_t i = 0; i < ndims; i++) {
1371     if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) {
1372       return false;
1373     }
1374   }
1375 
1376   return true;
1377 }
1378 
1379 // These functions do not compile with MKL-DNN since mkl.h is missing.
1380 // We may need to remove them later.
1381 // TODO(intel_tf): Remove this routine when faster MKL layout conversion is
1382 // out.
MklNHWCToNCHW(const Tensor & input,Tensor ** output)1383 inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) {
1384   const float* buf_in = input.flat<float>().data();
1385   float* buf_out = (*output)->flat<float>().data();
1386 
1387   int64 N = input.dim_size(0);
1388   int64 H = input.dim_size(1);
1389   int64 W = input.dim_size(2);
1390   int64 C = input.dim_size(3);
1391   int64 stride_n = H * W * C;
1392 #pragma omp parallel for num_threads(16)
1393   for (int64 n = 0; n < N; ++n) {
1394     mkl_somatcopy('R', 'T', H * W, C, 1, buf_in + n * stride_n, C,
1395                   buf_out + n * stride_n, H * W);
1396   }
1397 }
1398 
MklNCHWToNHWC(const Tensor & input,Tensor ** output)1399 inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
1400   const float* buf_in = input.flat<float>().data();
1401   float* buf_out = (*output)->flat<float>().data();
1402 
1403   int64 N = (*output)->dim_size(0);
1404   int64 H = (*output)->dim_size(1);
1405   int64 W = (*output)->dim_size(2);
1406   int64 C = (*output)->dim_size(3);
1407   int64 stride_n = H * W * C;
1408 #pragma omp parallel for num_threads(16)
1409   for (int64 n = 0; n < N; ++n) {
1410     mkl_somatcopy('R', 'T', C, H * W, 1, buf_in + n * stride_n, H * W,
1411                   buf_out + n * stride_n, C);
1412   }
1413 }
1414 
1415 #endif
1416 // -------------------------------------------------------------------
1417 
1418 #ifndef INTEL_MKL_ML_ONLY
1419 
/// Return MKL-DNN data type (memory::data_type) for input type T
///
/// Only the primary template is declared here; a type T is supported only if
/// an explicit specialization for it exists below.
///
/// @input None
/// @return memory::data_type corresponding to type T
template <typename T>
static memory::data_type MklDnnType();

/// Explicit specializations, one per supported element type. Add similar
/// specializations for other types if needed.

/// float -> f32
template <>
memory::data_type MklDnnType<float>() {
  return memory::data_type::f32;
}
/// quint8 -> u8
template <>
memory::data_type MklDnnType<quint8>() {
  return memory::data_type::u8;
}
/// qint8 -> s8
template <>
memory::data_type MklDnnType<qint8>() {
  return memory::data_type::s8;
}
/// qint32 -> s32
template <>
memory::data_type MklDnnType<qint32>() {
  return memory::data_type::s32;
}
1445 
1446 /// Map TensorFlow's data format into MKL-DNN 3D data format
1447 /// @input: TensorFlow data format
1448 /// @return: memory::format corresponding to TensorFlow data format;
1449 ///          Fails with an error if invalid data format.
TFDataFormatToMklDnn3DDataFormat(TensorFormat format)1450 inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
1451   if (format == FORMAT_NHWC)
1452     return memory::format::ndhwc;
1453   else if (format == FORMAT_NCHW)
1454     return memory::format::ncdhw;
1455   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1456   return memory::format::format_undef;
1457 }
1458 
1459 /// Map TensorFlow's data format into MKL-DNN data format
1460 ///
1461 /// @input: TensorFlow data format
1462 /// @return: memory::format corresponding to TensorFlow data format;
1463 ///          Fails with an error if invalid data format.
TFDataFormatToMklDnnDataFormat(TensorFormat format)1464 inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
1465   if (format == FORMAT_NHWC)
1466     return memory::format::nhwc;
1467   else if (format == FORMAT_NCHW)
1468     return memory::format::nchw;
1469   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1470   return memory::format::format_undef;
1471 }
1472 
1473 /// Map MKL-DNN data format to TensorFlow's data format
1474 ///
1475 /// @input: memory::format
1476 /// @return: Tensorflow data format corresponding to memory::format
1477 ///          Fails with an error if invalid data format.
MklDnnDataFormatToTFDataFormat(memory::format format)1478 inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
1479   if (format == memory::format::nhwc || format == memory::format::ndhwc)
1480     return FORMAT_NHWC;
1481   else if (format == memory::format::nchw || format == memory::format::ncdhw)
1482     return FORMAT_NCHW;
1483   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1484 
1485   // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure
1486   // that we don't come here.
1487   return FORMAT_NHWC;
1488 }
1489 
1490 /// Map TensorShape object into memory::dims required by MKL-DNN
1491 ///
1492 /// This function will simply map input TensorShape into MKL-DNN dims
1493 /// naively. So it will preserve the order of dimensions. E.g., if
1494 /// input tensor is in NHWC format, then dims will be in NHWC format
1495 /// also.
1496 ///
1497 /// @input TensorShape object in shape
1498 /// @return memory::dims corresponding to TensorShape
TFShapeToMklDnnDims(const TensorShape & shape)1499 inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
1500   memory::dims dims(shape.dims());
1501   for (int d = 0; d < shape.dims(); ++d) {
1502     dims[d] = shape.dim_size(d);
1503   }
1504   return dims;
1505 }
1506 
1507 /// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
1508 ///
1509 /// This function is a specific one than above function. It will map input
1510 /// TensorShape into MKL-DNN dims in NCHW format. So it may not preserve the
1511 /// order of dimensions. E.g., if input tensor is in NHWC format, then dims
1512 /// will be in NCHW format, and not in NHWC format.
1513 ///
1514 /// @input TensorShape object in shape
1515 /// @return memory::dims in MKL-DNN required NCHW format
TFShapeToMklDnnDimsInNCHW(const TensorShape & shape,TensorFormat format)1516 inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
1517                                               TensorFormat format) {
1518   // Check validity of format.
1519   CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1520            memory::format::format_undef);
1521 
1522   int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
1523   int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
1524   int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
1525   int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
1526 
1527   // MKL-DNN requires dimensions in NCHW format.
1528   return memory::dims({n, c, h, w});
1529 }
1530 
TFShapeToMklDnnDimsInNCDHW(const TensorShape & shape,TensorFormat format)1531 inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
1532                                                TensorFormat format) {
1533   // Check validity of format.
1534   CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
1535            memory::format::format_undef);
1536 
1537   int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
1538   int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
1539   int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
1540   int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
1541   int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
1542 
1543   // MKL-DNN requires dimensions in NCDHW format.
1544   return memory::dims({n, c, d, h, w});
1545 }
1546 
1547 /// Overloaded version of function above. Input parameters are
1548 /// self-explanatory.
MklDnnDimsInNCHW(const memory::dims & in_dims,TensorFormat format)1549 inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
1550                                      TensorFormat format) {
1551   // Check validity of format.
1552   CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1553            memory::format::format_undef);
1554 
1555   int n = in_dims[GetTensorDimIndex(format, 'N')];
1556   int c = in_dims[GetTensorDimIndex(format, 'C')];
1557   int h = in_dims[GetTensorDimIndex(format, 'H')];
1558   int w = in_dims[GetTensorDimIndex(format, 'W')];
1559 
1560   // MKL-DNN requires dimensions in NCHW format.
1561   return memory::dims({n, c, h, w});
1562 }
1563 
1564 /// Map MklDnn memory::dims object into TensorShape object.
1565 ///
1566 /// This function will simply map input shape in MKL-DNN memory::dims format
1567 /// in Tensorflow's TensorShape object by preserving dimension order.
1568 ///
1569 /// @input MKL-DNN memory::dims object
1570 /// @output TensorShape corresponding to memory::dims
MklDnnDimsToTFShape(const memory::dims & dims)1571 inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
1572   std::vector<int32> shape(dims.size(), -1);
1573   for (int d = 0; d < dims.size(); d++) {
1574     shape[d] = dims[d];
1575   }
1576 
1577   TensorShape ret;
1578   CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
1579   return ret;
1580 }
1581 
1582 /// Function to calculate strides given tensor shape in Tensorflow order
1583 /// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention,
1584 /// dimension with size 1 is outermost dimension; while dimension with size 4 is
1585 /// innermost dimension. So strides for this tensor would be {4 * 3 * 2,
1586 /// 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
1587 ///
1588 /// @input Tensorflow shape in memory::dims type
1589 /// @return memory::dims containing strides for the tensor.
CalculateTFStrides(const memory::dims & dims_tf_order)1590 inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
1591   CHECK_GT(dims_tf_order.size(), 0);
1592   memory::dims strides(dims_tf_order.size());
1593   int last_dim_idx = dims_tf_order.size() - 1;
1594   strides[last_dim_idx] = 1;
1595   for (int d = last_dim_idx - 1; d >= 0; d--) {
1596     strides[d] = strides[d + 1] * dims_tf_order[d + 1];
1597   }
1598   return strides;
1599 }
1600 
/// Map TensorFlow's padding type to the MKL-DNN padding kind.
///
/// @input pad - TensorFlow padding type; not consulted by this function
/// @return always padding_kind::zero
inline padding_kind TFPaddingToMklDnnPadding(Padding pad) {
  // MKL-DNN only supports zero padding.
  return padding_kind::zero;
}
1605 
1606 /// Helper function to create memory descriptor in Blocked format
1607 ///
1608 /// @input: Tensor dimensions
1609 /// @input: strides corresponding to dimensions. One can use utility
1610 ///         function such as CalculateTFStrides to compute strides
1611 ///         for given dimensions.
1612 /// @return: memory::desc object corresponding to blocked memory format
1613 ///          for given dimensions and strides.
CreateBlockedMemDescHelper(const memory::dims & dim,const memory::dims & strides,memory::data_type dtype)1614 inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
1615                                                const memory::dims& strides,
1616                                                memory::data_type dtype) {
1617   CHECK_EQ(dim.size(), strides.size());
1618 
1619   // We have to construct memory descriptor in a C style. This is not at all
1620   // ideal but MKLDNN does not offer any API to construct descriptor in
1621   // blocked format except a copy constructor that accepts
1622   // mkldnn_memory_desc_t.
1623   mkldnn_memory_desc_t md;
1624   md.primitive_kind = mkldnn_memory;
1625   md.ndims = dim.size();
1626   md.format = mkldnn_blocked;
1627   md.data_type = memory::convert_to_c(dtype);
1628 
1629   for (size_t i = 0; i < dim.size(); i++) {
1630     md.layout_desc.blocking.block_dims[i] = 1;
1631     md.layout_desc.blocking.strides[1][i] = 1;
1632     md.layout_desc.blocking.strides[0][i] = strides[i];
1633     md.layout_desc.blocking.padding_dims[i] = dim[i];
1634     md.layout_desc.blocking.offset_padding_to_data[i] = 0;
1635     md.dims[i] = dim[i];
1636   }
1637   md.layout_desc.blocking.offset_padding = 0;
1638 
1639   return memory::desc(md);
1640 }
1641 
/// Forward declaration of the reorder lookup helper used by MklDnnData's
/// cached-reorder paths. It takes the source and destination memory
/// primitives and returns a reorder primitive; the definition is not part
/// of this section of the file.
template <typename T>
inline primitive FindOrCreateReorder(const memory* from, const memory* to);
1644 /*
1645  * Class to represent all the resources corresponding to a tensor in TensorFlow
1646  * that are required to execute an operation (such as Convolution).
1647  */
1648 template <typename T>
1649 class MklDnnData {
1650  private:
1651   /// MKL-DNN memory primitive for input user memory
1652   memory* user_memory_;
1653 
1654   /// MKL-DNN memory primitive in case input or output reorder is needed.
1655   memory* reorder_memory_;
1656 
1657   /// Operations memory descriptor
1658   memory::desc* op_md_;
1659   // flat to indicate if data is 3D or not.
1660   bool bIs3D;
1661   /// Operations temp buffer
1662   void* allocated_buffer_;
1663   /// CPU engine on which operation will be executed
1664   const engine* cpu_engine_;
1665 
1666  public:
MklDnnData(const engine * e)1667   explicit MklDnnData(const engine* e)
1668       : user_memory_(nullptr),
1669         reorder_memory_(nullptr),
1670         op_md_(nullptr),
1671         allocated_buffer_(nullptr),
1672         cpu_engine_(e) {}
1673 
~MklDnnData()1674   ~MklDnnData() {
1675     if (allocated_buffer_ != nullptr) {
1676       cpu_allocator()->DeallocateRaw(allocated_buffer_);
1677     }
1678     cpu_engine_ = nullptr;  // We don't own this.
1679     delete (user_memory_);
1680     delete (reorder_memory_);
1681     delete (op_md_);
1682   }
1683 
GetTensorBuffer(const Tensor * tensor)1684   inline void* GetTensorBuffer(const Tensor* tensor) const {
1685     CHECK_NOTNULL(tensor);
1686     return const_cast<void*>(
1687         static_cast<const void*>(tensor->flat<T>().data()));
1688   }
1689 
SetIs3DData(bool bIs3D_)1690   void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
1691 
GetIs3D()1692   bool GetIs3D() { return bIs3D; }
1693 
1694   /// Set user memory primitive using specified dimensions, memory format and
1695   /// data_buffer. Function automatically uses element data type by using
1696   /// input type T used for creating call object.
1697   ///
1698   /// In a nutshell, function allows user to describe the input tensor to
1699   /// an operation. E.g., filter of Conv2D is of shape {1, 2, 3, 4}, and
1700   /// memory format HWIO, and the buffer that contains actual values is
1701   /// pointed by data_buffer.
1702   inline void SetUsrMem(const memory::dims& dim, memory::format fm,
1703                         void* data_buffer = nullptr) {
1704     auto md = memory::desc(dim, MklDnnType<T>(), fm);
1705     SetUsrMem(md, data_buffer);
1706   }
1707 
SetUsrMem(const memory::dims & dim,memory::format fm,const Tensor * tensor)1708   inline void SetUsrMem(const memory::dims& dim, memory::format fm,
1709                         const Tensor* tensor) {
1710     CHECK_NOTNULL(tensor);
1711     SetUsrMem(dim, fm, GetTensorBuffer(tensor));
1712   }
1713 
1714   /// Helper function to create memory descriptor in Blocked format
1715   ///
1716   /// @input: Tensor dimensions
1717   /// @input: strides corresponding to dimensions. One can use utility
1718   ///         function such as CalculateTFStrides to compute strides
1719   ///         for given dimensions.
1720   /// @return: memory::desc object corresponding to blocked memory format
1721   ///          for given dimensions and strides.
CreateBlockedMemDesc(const memory::dims & dim,const memory::dims & strides)1722   static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
1723                                                   const memory::dims& strides) {
1724     return CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>());
1725   }
1726 
1727   /// A version of SetUsrMem call that allows user to create memory in blocked
1728   /// format. So in addition to accepting dimensions, it also accepts strides.
1729   /// This allows user to create memory for tensor in a format that is not
1730   /// supported by MKLDNN. E.g., MKLDNN does not support tensor format for 6
1731   /// dimensional tensor as a native format. But by using blocked format, a user
1732   /// can create memory for 6D tensor.
1733   inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1734                         void* data_buffer = nullptr) {
1735     CHECK_EQ(dim.size(), strides.size());
1736     auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
1737     SetUsrMem(blocked_md, data_buffer);
1738   }
1739 
SetUsrMem(const memory::dims & dim,const memory::dims & strides,const Tensor * tensor)1740   inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1741                         const Tensor* tensor) {
1742     CHECK_NOTNULL(tensor);
1743     SetUsrMem(dim, strides, GetTensorBuffer(tensor));
1744   }
1745 
1746   /// A version of function to set user memory primitive that accepts memory
1747   /// descriptor directly, instead of accepting dimensions and format. This
1748   /// function is more generic that the one above, but the function above is
1749   /// sufficient in most cases.
1750   inline void SetUsrMem(const memory::desc& md, void* data_buffer = nullptr) {
1751     auto pd = memory::primitive_desc(md, *cpu_engine_);
1752     SetUsrMem(pd, data_buffer);
1753   }
1754 
1755   /// A version of SetUsrMem with memory descriptor and tensor
SetUsrMem(const memory::desc & md,const Tensor * tensor)1756   inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
1757     CHECK_NOTNULL(tensor);
1758     SetUsrMem(md, GetTensorBuffer(tensor));
1759   }
1760 
1761   /// A version of function to set user memory primitive that accepts primitive
1762   /// descriptor directly, instead of accepting dimensions and format. This
1763   /// function is more generic that the one above, but the function above is
1764   /// sufficient in most cases.
1765   inline void SetUsrMem(const memory::primitive_desc& pd,
1766                         void* data_buffer = nullptr) {
1767     CHECK_NOTNULL(cpu_engine_);
1768     if (user_memory_) delete user_memory_;
1769     // TODO(nhasabni): can we remove dynamic memory allocation?
1770     if (data_buffer) {
1771       user_memory_ = new memory(pd, data_buffer);
1772     } else {
1773       user_memory_ = new memory(pd);
1774     }
1775   }
1776 
1777   /// A version of SetUsrMem with primitive descriptor and tensor
SetUsrMem(const memory::primitive_desc & pd,const Tensor * tensor)1778   inline void SetUsrMem(const memory::primitive_desc& pd,
1779                         const Tensor* tensor) {
1780     CHECK_NOTNULL(tensor);
1781     SetUsrMem(pd, GetTensorBuffer(tensor));
1782   }
1783 
1784   /// Get function for user memory primitive.
GetUsrMem()1785   inline const memory* GetUsrMem() const { return user_memory_; }
1786 
1787   /// Get function for primitive descriptor of user memory primitive.
GetUsrMemPrimDesc()1788   inline const memory::primitive_desc GetUsrMemPrimDesc() const {
1789     CHECK_NOTNULL(user_memory_);
1790     return user_memory_->get_primitive_desc();
1791   }
1792 
1793   /// Get function for descriptor of user memory.
GetUsrMemDesc()1794   inline memory::desc GetUsrMemDesc() {
1795     // This is ugly. Why MKL-DNN does not provide desc() method of const type??
1796     const memory::primitive_desc pd = GetUsrMemPrimDesc();
1797     return const_cast<memory::primitive_desc*>(&pd)->desc();
1798   }
1799 
1800   /// Get function for data buffer of user memory primitive.
GetUsrMemDataHandle()1801   inline void* GetUsrMemDataHandle() const {
1802     CHECK_NOTNULL(user_memory_);
1803     return user_memory_->get_data_handle();
1804   }
1805 
1806   /// Set function for data buffer of user memory primitive.
SetUsrMemDataHandle(void * data_buffer)1807   inline void SetUsrMemDataHandle(void* data_buffer) {
1808     CHECK_NOTNULL(user_memory_);
1809     CHECK_NOTNULL(data_buffer);
1810     user_memory_->set_data_handle(data_buffer);
1811   }
1812 
1813   /// Set function for data buffer of user memory primitive.
SetUsrMemDataHandle(const Tensor * tensor)1814   inline void SetUsrMemDataHandle(const Tensor* tensor) {
1815     CHECK_NOTNULL(user_memory_);
1816     CHECK_NOTNULL(tensor);
1817     user_memory_->set_data_handle(GetTensorBuffer(tensor));
1818   }
1819 
1820   /// allocate function for data buffer
AllocateBuffer(size_t size)1821   inline void AllocateBuffer(size_t size) {
1822     const int64 kMemoryAlginment = 64;  // For AVX512 memory alignment.
1823     allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlginment, size);
1824   }
1825 
GetAllocatedBuffer()1826   inline void* GetAllocatedBuffer() { return allocated_buffer_; }
1827 
1828   /// Get the memory primitive for input and output of an op. If inputs
1829   /// to an op require reorders, then this function returns memory primitive
1830   /// for reorder. Otherwise, it will return memory primitive for user memory.
1831   ///
1832   /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
1833   /// execute Conv2D, we need memory primitive for I and F. Buf if reorder is
1834   /// required for I and F (say I_r is reorder primitive for I; F_r is reorder
1835   /// primitive for F), then we need I_r and F_r to perform Conv2D.
GetOpMem()1836   inline const memory& GetOpMem() const {
1837     return reorder_memory_ ? *reorder_memory_ : *user_memory_;
1838   }
1839 
1840   /// Set memory descriptor of an operation in terms of dimensions and memory
1841   /// format. E.g., For Conv2D, the dimensions would be same as user dimensions
1842   /// but memory::format would be mkldnn::any because we want MKL-DNN to choose
1843   /// best layout/format for given input dimensions.
SetOpMemDesc(const memory::dims & dim,memory::format fm)1844   inline void SetOpMemDesc(const memory::dims& dim, memory::format fm) {
1845     // TODO(nhasabni): can we remove dynamic memory allocation?
1846     op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
1847   }
1848 
1849   /// Get function for memory descriptor for an operation
GetOpMemDesc()1850   inline const memory::desc& GetOpMemDesc() const { return *op_md_; }
1851 
1852   /// Predicate that checks if we need to reorder user's memory into memory
1853   /// pointed by op_pd.
1854   ///
1855   /// @input: op_pd - memory primitive descriptor of the given input of an
1856   ///               operation
1857   /// @return: true in case reorder of input is needed; false, otherwise.
IsReorderNeeded(const memory::primitive_desc & op_pd)1858   inline bool IsReorderNeeded(const memory::primitive_desc& op_pd) const {
1859     CHECK_NOTNULL(user_memory_);
1860     return op_pd != user_memory_->get_primitive_desc();
1861   }
1862 
1863   /// Predicate that checks if we need to reorder user's memory into memory
1864   /// based on the provided format.
1865   ///
1866   /// @input: target_format - memory format of the given input of an
1867   ///               operation
1868   /// @return: true in case reorder of input is needed; false, otherwise.
IsReorderNeeded(const memory::format & target_format)1869   inline bool IsReorderNeeded(const memory::format& target_format) const {
1870     CHECK_NOTNULL(user_memory_);
1871     return target_format !=
1872            user_memory_->get_primitive_desc().desc().data.format;
1873   }
1874 
1875   /// Function to create a reorder from memory pointed by from to memory pointed
1876   /// by to. Returns created primitive.
CreateReorder(const memory * from,const memory * to)1877   inline primitive CreateReorder(const memory* from, const memory* to) const {
1878     CHECK_NOTNULL(from);
1879     CHECK_NOTNULL(to);
1880     return reorder(*from, *to);
1881   }
1882 
1883   /// Function to handle input reordering
1884   ///
1885   /// Check if we need to reorder this input of an operation.
1886   /// Return true and allocate reorder memory primitive if reorder is needed.
1887   /// Otherwise, return false and do not allocate reorder memory primitive.
1888   ///
1889   /// To check if reorder is needed, this function compares memory primitive
1890   /// descriptor of an operation (op_pd) for the given input with the
1891   /// user-specified memory primitive descriptor.
1892   ///
1893   /// @input: op_pd - memory primitive descriptor of the given input of an
1894   ///               operation
1895   /// @input: net - net to which to add reorder primitive in case it is needed.
1896   /// @return: true in case reorder of input is needed; false, otherwise.
CheckReorderToOpMem(const memory::primitive_desc & op_pd,std::vector<primitive> * net)1897   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
1898                                   std::vector<primitive>* net) {
1899     CHECK_NOTNULL(net);
1900     CHECK_NOTNULL(user_memory_);
1901     if (IsReorderNeeded(op_pd)) {
1902       // TODO(nhasabni): can we remove dynamic memory allocation?
1903       reorder_memory_ = new memory(op_pd);
1904       net->push_back(CreateReorder(user_memory_, reorder_memory_));
1905       return true;
1906     }
1907     return false;
1908   }
1909 
1910   /// TODO: this is a faster path with reorder primitive cache compared with
1911   /// CheckReorderToOpMem(..., std::vector<primitive>* net), will remove
1912   /// slow path in the future
CheckReorderToOpMem(const memory::primitive_desc & op_pd)1913   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd) {
1914     CHECK_NOTNULL(user_memory_);
1915     if (IsReorderNeeded(op_pd)) {
1916       // TODO(nhasabni): can we remove dynamic memory allocation?
1917       // primitive reuse don't allow two same reorder prim in
1918       // one stream, so submit it immediately
1919       reorder_memory_ = new memory(op_pd);
1920       std::vector<primitive> net;
1921       net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
1922       stream(stream::kind::eager).submit(net).wait();
1923       return true;
1924     }
1925     return false;
1926   }
1927 
1928   /// Overloaded version of above function that accepts memory buffer
1929   /// where output of reorder needs to be stored.
1930   ///
1931   /// @input: op_pd - memory primitive descriptor of the given input of an
1932   ///               operation
1933   /// @reorder_data_handle - memory buffer where output of reorder needs to be
1934   ///                        stored. Primitive does not check if buffer is
1935   ///                        enough size to write.
1936   /// @input: net - net to which to add reorder primitive in case it is needed.
1937   /// @return: true in case reorder of input is needed; false, otherwise.
CheckReorderToOpMem(const memory::primitive_desc & op_pd,void * reorder_data_handle,std::vector<primitive> * net)1938   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
1939                                   void* reorder_data_handle,
1940                                   std::vector<primitive>* net) {
1941     CHECK_NOTNULL(net);
1942     CHECK_NOTNULL(reorder_data_handle);
1943     CHECK_NOTNULL(user_memory_);
1944     if (IsReorderNeeded(op_pd)) {
1945       // TODO(nhasabni): can we remove dynamic memory allocation?
1946       reorder_memory_ = new memory(op_pd, reorder_data_handle);
1947       net->push_back(CreateReorder(user_memory_, reorder_memory_));
1948       return true;
1949     }
1950     return false;
1951   }
1952 
1953   /// TODO: this is a faster path with reorder primitive cache compared with
1954   /// CheckReorderToOpMem(..., std::vector<primitive>* net), will remove
1955   /// slow path in the future
CheckReorderToOpMem(const memory::primitive_desc & op_pd,void * reorder_data_handle)1956   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
1957                                   void* reorder_data_handle) {
1958     CHECK_NOTNULL(reorder_data_handle);
1959     CHECK_NOTNULL(user_memory_);
1960     if (IsReorderNeeded(op_pd)) {
1961       // TODO(nhasabni): can we remove dynamic memory allocation?
1962       // primitive reuse don't allow two same reorder prim in
1963       // one stream, so submit it immediately
1964       std::vector<primitive> net;
1965       reorder_memory_ = new memory(op_pd, reorder_data_handle);
1966       net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
1967       stream(stream::kind::eager).submit(net).wait();
1968       return true;
1969     }
1970     return false;
1971   }
1972 
1973   /// Another overloaded version of CheckReorderToOpMem that accepts Tensor
1974   /// where output of reorder needs to be stored.
1975   ///
1976   /// @input: op_pd - memory primitive descriptor of the given input of an
1977   ///               operation
1978   /// @reorder_tensor - Tensor whose buffer is to be used to store output of
1979   ///                   reorder. Primitive does not check if buffer is
1980   ///                   enough size to write.
1981   /// @input: net - net to which to add reorder primitive in case it is needed.
1982   /// @return: true in case reorder of input is needed; false, otherwise.
CheckReorderToOpMem(const memory::primitive_desc & op_pd,Tensor * reorder_tensor,std::vector<primitive> * net)1983   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
1984                                   Tensor* reorder_tensor,
1985                                   std::vector<primitive>* net) {
1986     CHECK_NOTNULL(net);
1987     CHECK_NOTNULL(reorder_tensor);
1988     return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), net);
1989   }
1990 
1991   /// TODO: this is a faster path with reorder primitive cache compared with
1992   /// CheckReorderToOpMem(..., std::vector<primitive>* net), will remove
1993   /// slow path in the future
CheckReorderToOpMem(const memory::primitive_desc & op_pd,Tensor * reorder_tensor)1994   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
1995                                   Tensor* reorder_tensor) {
1996     CHECK_NOTNULL(reorder_tensor);
1997     return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor));
1998   }
1999 
2000   /// Function to handle output reorder
2001   ///
2002   /// This function performs very similar functionality as input reordering
2003   /// function above. The only difference is that this function does not add
2004   /// reorder primitive to the net. The reason for this is: the reorder
2005   /// primitive for output needs to be added to the list only after operation
2006   /// has executed. But we need to prepare a temporary buffer in case output
2007   /// reorder is needed. And this temporary buffer will hold the output of
2008   /// an operation before it is fed to reorder primitive.
2009   ///
2010   /// @input memory primitive descriptor for the given output of an operation
2011   /// @return: true in case reorder of output is needed; false, otherwise.
PrepareReorderToUserMemIfReq(const memory::primitive_desc & op_pd)2012   inline bool PrepareReorderToUserMemIfReq(
2013       const memory::primitive_desc& op_pd) {
2014     CHECK_NOTNULL(user_memory_);
2015     if (IsReorderNeeded(op_pd)) {
2016       // TODO(nhasabni): can we remove dynamic memory allocation?
2017       reorder_memory_ = new memory(op_pd);
2018       return true;
2019     }
2020     return false;
2021   }
2022 
2023   /// Function to actually insert reorder primitive in the net
2024   ///
2025   /// This function completes remaining part of output reordering. It inserts
2026   /// a reordering primitive from the temporary buffer that holds the output
2027   /// to the user-specified output buffer.
2028   ///
2029   /// @input: net - net to which to add reorder primitive
InsertReorderToUserMem(std::vector<primitive> * net)2030   inline void InsertReorderToUserMem(std::vector<primitive>* net) {
2031     CHECK_NOTNULL(net);
2032     CHECK_NOTNULL(user_memory_);
2033     CHECK_NOTNULL(reorder_memory_);
2034     net->push_back(CreateReorder(reorder_memory_, user_memory_));
2035   }
2036 
2037   /// TODO: this is a faster path with reorder primitive cache compared with
2038   ///       InsertReorderToUserMem(std::vector<primitive>* net), will remove
2039   ///       slow path in the future
InsertReorderToUserMem()2040   inline void InsertReorderToUserMem() {
2041     CHECK_NOTNULL(user_memory_);
2042     CHECK_NOTNULL(reorder_memory_);
2043     // primitive reuse don't allow two same reorder prim in
2044     // one stream, so submit it immediately
2045     std::vector<primitive> net;
2046     net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
2047     stream(stream::kind::eager).submit(net).wait();
2048   }
2049 };
2050 
/// Polymorphic base class for operations whose primitives are reused.
/// The virtual destructor lets derived objects be destroyed through a
/// MklPrimitive pointer.
class MklPrimitive {
 public:
  virtual ~MklPrimitive() = default;

  // Placeholder data pointer which MKL DNN never operates on; null unless
  // assigned.
  unsigned char* DummyData = nullptr;
};
2060 
// Constant memory::dims initialized from an empty brace list (holds no
// dimensions). NOTE(review): presumably used as a "no dims" placeholder by
// code outside this section -- confirm against callers.
const mkldnn::memory::dims NONE_DIMS = {};
2062 
//
// LRUCache is a class which implements LRU (Least Recently Used) cache.
// The implementation is similar to that of
//    tensorflow/core/platform/cloud/expiring_lru_cache.h
// without its thread-safe part because the cache is supposed to be
// used as thread local (for instance, MklPrimitive caching).
//
// The LRU list maintains keys ordered by last access (an insertion counts
// as an access), with the most recently accessed object at the head of
// the LRU list and the least recently accessed object at the tail.
//
// This class is used to maintain an upper bound on the total number of
// cached items. When the cache reaches its capacity, the LRU item will
// be removed and replaced by a new one from SetOp call.
//
// The cache owns the objects handed to SetOp: an object is deleted when its
// entry is evicted, replaced, or cleared, and when the cache is destroyed.
//
template <typename T>
class LRUCache {
 public:
  // Create an empty cache that holds at most `capacity` entries.
  explicit LRUCache(size_t capacity) : capacity_(capacity) {}

  // Look up `key`. Returns the cached object and marks it as the most
  // recently accessed one, or returns nullptr if the key is not cached.
  // The cache keeps ownership of the returned object.
  T* GetOp(const string& key) {
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      return nullptr;
    }

    // Move to the front of LRU list as the most recently accessed.
    // splice relinks the existing node, so the stored iterator stays valid.
    lru_list_.splice(lru_list_.begin(), lru_list_, it->second.lru_iterator);
    return it->second.op;
  }

  // Store `op` under `key` as the most recently accessed entry and take
  // ownership of it.
  //
  // If the key is already cached, the stored object is replaced by `op` (the
  // old object is deleted) and no other entry is evicted. Otherwise, if the
  // cache is full, the least recently accessed entry is evicted first.
  void SetOp(const string& key, T* op) {
    auto existing = cache_.find(key);
    if (existing != cache_.end()) {
      // Replace in place: adding a second list node for the same key would
      // leave a stale node behind, and dropping `op` would leave the caller
      // with a dangling pointer.
      if (existing->second.op != op) {
        delete existing->second.op;
        existing->second.op = op;
      }
      lru_list_.splice(lru_list_.begin(), lru_list_,
                       existing->second.lru_iterator);
      return;
    }

    if (lru_list_.size() >= capacity_) {
      Delete();
    }

    // Insert an entry to the front of the LRU list
    lru_list_.push_front(key);
    cache_.emplace(key, Entry(op, lru_list_.begin()));
  }

  // Remove every entry and delete the cached objects.
  void Clear() {
    if (lru_list_.empty()) return;

    // Clean up the cache
    cache_.clear();
    lru_list_.clear();
  }

 private:
  struct Entry {
    // The entry's value. Owned by the entry.
    T* op;

    // A list iterator pointing to the entry's position in the LRU list.
    std::list<string>::iterator lru_iterator;

    // Constructor
    Entry(T* op, std::list<string>::iterator it) : op(op), lru_iterator(it) {}

    // Move constructor: takes over the value and leaves the source empty so
    // that the value is deleted exactly once.
    Entry(Entry&& source) noexcept
        : op(source.op), lru_iterator(std::move(source.lru_iterator)) {
      source.op = nullptr;
    }

    // An entry uniquely owns its value, so it must not be copied.
    Entry(const Entry&) = delete;
    Entry& operator=(const Entry&) = delete;

    // Destructor (deleting a null pointer is a no-op).
    ~Entry() { delete op; }
  };

  // Remove the least recently accessed entry from LRU list, which
  // is the tail of lru_list_. Update cache_ correspondingly.
  // Returns false if there was nothing to remove.
  bool Delete() {
    if (lru_list_.empty()) return false;
    string key = lru_list_.back();
    lru_list_.pop_back();
    cache_.erase(key);
    return true;
  }

  // Cache capacity
  size_t capacity_;

  // The cache, a map from string key to a LRU entry.
  std::unordered_map<string, Entry> cache_;

  // The LRU list of entries.
  // The front of the list contains the key of the most recently accessed
  // entry, while the back of the list is the least recently accessed entry.
  std::list<string> lru_list_;
};
2167 
2168 template <typename T>
2169 class MklPrimitiveFactory {
2170  public:
MklPrimitiveFactory()2171   MklPrimitiveFactory() {}
2172 
~MklPrimitiveFactory()2173   ~MklPrimitiveFactory() {}
2174 
GetOp(const string & key)2175   MklPrimitive* GetOp(const string& key) {
2176     auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
2177     return lru_cache.GetOp(key);
2178   }
2179 
SetOp(const string & key,MklPrimitive * op)2180   void SetOp(const string& key, MklPrimitive* op) {
2181     auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
2182     lru_cache.SetOp(key, op);
2183   }
2184 
2185   /// Function to decide whether HW has AVX512 or AVX2
2186   /// For those legacy device(w/o AVX512 and AVX2),
2187   /// MKL-DNN GEMM will be used.
IsLegacyPlatform()2188   static inline bool IsLegacyPlatform() {
2189     return (!port::TestCPUFeature(port::CPUFeature::AVX512F) &&
2190             !port::TestCPUFeature(port::CPUFeature::AVX2));
2191   }
2192 
2193   /// Fuction to check whether primitive memory optimization is enabled
IsPrimitiveMemOptEnabled()2194   static inline bool IsPrimitiveMemOptEnabled() {
2195     bool is_primitive_mem_opt_enabled = true;
2196     TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
2197                                    &is_primitive_mem_opt_enabled));
2198     return is_primitive_mem_opt_enabled;
2199   }
2200 
2201  private:
GetLRUCache()2202   static inline LRUCache<MklPrimitive>& GetLRUCache() {
2203     static const int kCapacity = 1024;  // cache capacity
2204     static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
2205     return lru_cache_;
2206   }
2207 };
2208 
2209 // utility class for creating keys of MKL primitive pool.
2210 class FactoryKeyCreator {
2211  public:
FactoryKeyCreator()2212   FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }
2213 
~FactoryKeyCreator()2214   ~FactoryKeyCreator() {}
2215 
AddAsKey(const string & str)2216   void AddAsKey(const string& str) { Append(str); }
2217 
AddAsKey(const mkldnn::memory::dims & dims)2218   void AddAsKey(const mkldnn::memory::dims& dims) {
2219     for (unsigned int i = 0; i < dims.size(); i++) {
2220       AddAsKey<int>(dims[i]);
2221     }
2222   }
2223 
2224   template <typename T>
AddAsKey(const T data)2225   void AddAsKey(const T data) {
2226     auto buffer = reinterpret_cast<const char*>(&data);
2227     Append(StringPiece(buffer, sizeof(T)));
2228   }
2229 
GetKey()2230   string GetKey() { return key_; }
2231 
2232  private:
2233   string key_;
2234   const char delimiter = 'x';
2235   const int kMaxKeyLength = 256;
Append(StringPiece s)2236   void Append(StringPiece s) {
2237     key_.append(string(s));
2238     key_.append(1, delimiter);
2239   }
2240 };
2241 
2242 static inline memory::format get_desired_format(int channel,
2243                                                 bool is_2d = true) {
2244   memory::format fmt_desired = memory::format::any;
2245 
2246   if (port::TestCPUFeature(port::CPUFeature::AVX512F)) {
2247     fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
2248   } else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
2249              (channel % 8) == 0) {
2250     fmt_desired = is_2d ? memory::format::nChw8c
2251                         : memory::format::ncdhw;  // no avx2 support for 3d yet.
2252   } else {
2253     fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
2254   }
2255   return fmt_desired;
2256 }
2257 
2258 class MklReorderPrimitive : public MklPrimitive {
2259  public:
MklReorderPrimitive(const memory * from,const memory * to)2260   explicit MklReorderPrimitive(const memory* from, const memory* to) {
2261     Setup(from, to);
2262   }
~MklReorderPrimitive()2263   ~MklReorderPrimitive() {}
2264 
GetPrimitive()2265   std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
2266 
SetMemory(const memory * from,const memory * to)2267   void SetMemory(const memory* from, const memory* to) {
2268     context_.src_mem->set_data_handle(from->get_data_handle());
2269     context_.dst_mem->set_data_handle(to->get_data_handle());
2270   }
2271 
2272  private:
2273   struct ReorderContext {
2274     std::shared_ptr<mkldnn::memory> src_mem;
2275     std::shared_ptr<mkldnn::memory> dst_mem;
2276     std::shared_ptr<primitive> reorder_prim;
ReorderContextReorderContext2277     ReorderContext()
2278         : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
2279   } context_;
2280 
2281   engine cpu_engine_ = engine(engine::cpu, 0);
2282 
Setup(const memory * from,const memory * to)2283   void Setup(const memory* from, const memory* to) {
2284     context_.src_mem.reset(new memory(
2285         {from->get_primitive_desc().desc(), cpu_engine_}, DummyData));
2286     context_.dst_mem.reset(
2287         new memory({to->get_primitive_desc().desc(), cpu_engine_}, DummyData));
2288     context_.reorder_prim = std::make_shared<mkldnn::reorder>(
2289         reorder(*context_.src_mem, *context_.dst_mem));
2290   }
2291 };
2292 
2293 template <typename T>
2294 class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
2295  public:
Get(const memory * from,const memory * to)2296   static MklReorderPrimitive* Get(const memory* from, const memory* to) {
2297     auto reorderPrim = static_cast<MklReorderPrimitive*>(
2298         MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
2299     if (reorderPrim == nullptr) {
2300       reorderPrim = new MklReorderPrimitive(from, to);
2301       MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
2302                                                               reorderPrim);
2303     }
2304     reorderPrim->SetMemory(from, to);
2305     return reorderPrim;
2306   }
2307 
GetInstance()2308   static MklReorderPrimitiveFactory& GetInstance() {
2309     static MklReorderPrimitiveFactory instance_;
2310     return instance_;
2311   }
2312 
2313  private:
MklReorderPrimitiveFactory()2314   MklReorderPrimitiveFactory() {}
~MklReorderPrimitiveFactory()2315   ~MklReorderPrimitiveFactory() {}
2316 
CreateKey(const memory * from,const memory * to)2317   static string CreateKey(const memory* from, const memory* to) {
2318     string prefix = "reorder";
2319     FactoryKeyCreator key_creator;
2320     auto const& from_desc = from->get_primitive_desc().desc().data;
2321     auto const& to_desc = to->get_primitive_desc().desc().data;
2322     const int KIdxFirstStride = 0;
2323     memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
2324     memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
2325     memory::dims from_strides(
2326         from_desc.layout_desc.blocking.strides[KIdxFirstStride],
2327         &from_desc.layout_desc.blocking
2328              .strides[KIdxFirstStride][from_desc.ndims]);
2329     memory::dims to_strides(
2330         to_desc.layout_desc.blocking.strides[KIdxFirstStride],
2331         &to_desc.layout_desc.blocking.strides[KIdxFirstStride][to_desc.ndims]);
2332     key_creator.AddAsKey(prefix);
2333     key_creator.AddAsKey(static_cast<int>(from_desc.format));
2334     key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
2335     key_creator.AddAsKey(from_dims);
2336     key_creator.AddAsKey(from_strides);
2337     key_creator.AddAsKey(static_cast<int>(to_desc.format));
2338     key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
2339     key_creator.AddAsKey(to_dims);
2340     key_creator.AddAsKey(to_strides);
2341     return key_creator.GetKey();
2342   }
2343 
GetReorder(const memory * from,const memory * to)2344   MklPrimitive* GetReorder(const memory* from, const memory* to) {
2345     string key = CreateKey(from, to);
2346     return this->GetOp(key);
2347   }
2348 
SetReorder(const memory * from,const memory * to,MklPrimitive * op)2349   void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
2350     string key = CreateKey(from, to);
2351     this->SetOp(key, op);
2352   }
2353 };
2354 
2355 /// Fuction to find(or create) a reorder from memory pointed by
2356 /// from to memory pointed by to, it will created primitive or
2357 /// get primitive from pool if it is cached.
2358 /// Returns the primitive.
2359 template <typename T>
FindOrCreateReorder(const memory * from,const memory * to)2360 inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
2361   CHECK_NOTNULL(from);
2362   CHECK_NOTNULL(to);
2363   MklReorderPrimitive* reorder_prim =
2364       MklReorderPrimitiveFactory<T>::Get(from, to);
2365   return *reorder_prim->GetPrimitive();
2366 }
2367 
2368 // utility function to determine if it is conv 1x1 and stride != 1
2369 // for purpose of temporarily disabling primitive reuse
IsConv1x1StrideNot1(memory::dims filter_dims,memory::dims strides)2370 inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
2371                                 memory::dims strides) {
2372   if (filter_dims.size() != 4 || strides.size() != 2) return false;
2373 
2374   return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
2375           ((strides[0] != 1) || (strides[1] != 1)));
2376 }
2377 
2378 #endif  // INTEL_MKL_DNN
2379 
2380 }  // namespace tensorflow
2381 #endif  // INTEL_MKL
2382 #endif  // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
2383