/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::inner_product_forward;
using mkldnn::primitive_attr;
using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

#ifdef INTEL_MKL_DNN_ONLY
// Temporarily copying some definitions from mkl_cblas.h so the same code can
// be used when calling oneDNN or CBLAS batch matmul in mkl_batch_matmul_op.cc.
typedef enum { CblasRowMajor, CblasColMajor } CBLAS_LAYOUT;
#define MKL_INT int
#endif

// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
  memory::dims src_dims;
  memory::dims weight_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  MEMORY_FORMAT src_format;
  MEMORY_FORMAT weight_format;
  MEMORY_FORMAT dst_format;
  string dtypes = string("");
  struct PostOpParam {
    string name;
    std::vector<float> param;
  };
  std::vector<PostOpParam> post_op_params;

  MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
                        memory::dims bias_dims, memory::dims dst_dims,
                        MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
                        MEMORY_FORMAT weight_format = MEMORY_FORMAT::any,
                        MEMORY_FORMAT dst_format = MEMORY_FORMAT::any)
      : src_dims(src_dims),
        weight_dims(weight_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        src_format(src_format),
        weight_format(weight_format),
        dst_format(dst_format) {}
};
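
// A minimal, illustrative sketch of filling in these params; the shapes and
// the fused post-op below are hypothetical, following the conventions used
// by the primitive further down:
//
//   MklDnnMatMulFwdParams params({64, 256} /* src: batch x ic */,
//                                {128, 256} /* weight: oc x ic */,
//                                {128} /* bias */, {64, 128} /* dst */);
//   params.dtypes.append(typeid(float).name());
//   // Eltwise post-ops carry {scale, alpha, beta}.
//   params.post_op_params.push_back({"relu", {1.0f, 0.0f, 0.0f}});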

// With quantization, input, weight, bias, and output can have different types.
// So we use different template parameters for each type.
// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In the future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 public:
  explicit MklDnnMatMulFwdPrimitive(
      const MklDnnMatMulFwdParams& matmulFwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create matmul primitive
    if (context_.matmul_fwd == nullptr) {
      Setup(matmulFwdParams);
    }
  }

  ~MklDnnMatMulFwdPrimitive() {}

  // Inner-product forward execute with bias:
  //  - src_data: input data buffer of src
  //  - weight_data: input data buffer of weight
  //  - bias_data: input data buffer of bias
  //  - dst_data: output data buffer of dst
  void Execute(const Tinput* src_data, const Tweight* weight_data,
               const Tbias* bias_data, Toutput* dst_data,
               std::shared_ptr<stream> fwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)), *fwd_stream);
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)));
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif  // ENABLE_MKLDNN_THREADPOOL

    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);

    // After execution, set data handle back
    context_.src_mem->set_data_handle(DummyData);
    context_.weight_mem->set_data_handle(DummyData);
    context_.bias_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>
  GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }
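
  // Note: callers query the primitive descriptor above for the memory formats
  // the primitive expects, e.g. GetPrimitiveDesc()->weights_desc(), to decide
  // whether a reorder of user data is required (see CacheWeight in
  // MklDnnMatMulOpBase below).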

 private:
  // Primitive reuse context for inner-product Fwd op
  struct MklDnnMatMulFwdContext {
    // MKL-DNN memory.
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> weight_mem;
    std::shared_ptr<mkldnn::memory> bias_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // Descriptor and primitive-descriptor for forward inner-product.
    std::shared_ptr<mkldnn::inner_product_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> fwd_pd;

    // Memory descriptors.
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> weight_md;
    std::shared_ptr<mkldnn::memory::desc> bias_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;

    // Inner-product primitive.
    std::shared_ptr<mkldnn::primitive> matmul_fwd;
    std::vector<mkldnn::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    MklDnnMatMulFwdContext()
        : src_mem(nullptr),
          weight_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          weight_md(nullptr),
          bias_md(nullptr),
          dst_md(nullptr),
          matmul_fwd(nullptr) {}
  };

  void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) {
    // Create memory descriptors for inner-product data with the formats
    // requested in matmul_fwd_params (defaulting to `any`).
    context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
                                           MklDnnType<Tinput>(),
                                           matmul_fwd_params.src_format));

    context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
                                              MklDnnType<Tweight>(),
                                              matmul_fwd_params.weight_format));

    context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                           MklDnnType<Toutput>(),
                                           matmul_fwd_params.dst_format));

    context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                            MklDnnType<Tbias>(),
                                            memory::format_tag::any));
    // Create the inner-product forward descriptor and primitive descriptor.
    context_.fwd_desc.reset(new inner_product_forward::desc(
        prop_kind::forward_inference, *context_.src_md, *context_.weight_md,
        *context_.bias_md, *context_.dst_md));
    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));

    // Check if there are any fused post-ops.
    auto const& post_op_params = matmul_fwd_params.post_op_params;
    mkldnn::primitive_attr post_ops_attr;
    mkldnn::post_ops post_ops;
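    // Each eltwise post-op ("relu", "relu6", "elu", "tanh", "leakyrelu")
    // carries {scale, alpha, beta}; "sum" and "output_scale" carry a single
    // scale value, as the DCHECKs below enforce.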
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "relu" || post_op_param.name == "leakyrelu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "relu6") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale,
                                  mkldnn::algorithm::eltwise_bounded_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "elu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_elu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "tanh") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_tanh,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else if (post_op_param.name == "sum") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          float op_scale = post_op_param.param[0];
          post_ops.append_sum(op_scale);
        } else {
          DCHECK((post_op_param.name == "relu") ||
                 (post_op_param.name == "relu6") ||
                 (post_op_param.name == "elu") ||
                 (post_op_param.name == "tanh") ||
                 (post_op_param.name == "sum") ||
                 (post_op_param.name == "leakyrelu") ||
                 (post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, post_ops_attr, cpu_engine_));
    } else {
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, cpu_engine_));
    }

    // Create memory primitive based on dummy data
    context_.src_mem.reset(
        new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
    context_.weight_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
                                         cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));
    context_.bias_mem.reset(new memory({{matmul_fwd_params.bias_dims},
                                        MklDnnType<Tbias>(),
                                        memory::format_tag::x},
                                       cpu_engine_, DummyData));

    // Create inner-product primitive.
    context_.matmul_fwd.reset(new inner_product_forward(*context_.fwd_pd));
    context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
                                 {MKLDNN_ARG_WEIGHTS, *context_.weight_mem},
                                 {MKLDNN_ARG_BIAS, *context_.bias_mem},
                                 {MKLDNN_ARG_DST, *context_.dst_mem}});

    context_.fwd_primitives.push_back(*context_.matmul_fwd);
    return;
  }

  struct MklDnnMatMulFwdContext context_;
};

template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* Get(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) {
    MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* matmul_fwd =
        nullptr;

    if (do_not_cache) {
      // Always create a new primitive.
      matmul_fwd =
          new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
              mkldnn_matmul_fwd_dims);
    } else {
      // Try to find a suitable primitive in the pool.
      matmul_fwd = dynamic_cast<
          MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>*>(
          MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                          Toutput>::GetInstance()
              .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims));
      if (matmul_fwd == nullptr) {
        matmul_fwd =
            new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
                mkldnn_matmul_fwd_dims);
        MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                        Toutput>::GetInstance()
            .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd);
      }
    }
    return matmul_fwd;
  }

 private:
  MklDnnMatMulFwdPrimitiveFactory() {}
  ~MklDnnMatMulFwdPrimitiveFactory() {}

  static MklDnnMatMulFwdPrimitiveFactory& GetInstance() {
    static MklDnnMatMulFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string prefix = "matmul_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format);

    // Generate keys for post-ops.
    for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
      if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
          post_op_param.name == "elu" || post_op_param.name == "tanh" ||
          post_op_param.name == "leakyrelu") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
        key_creator.AddAsKey(post_op_param.param[1]);
        key_creator.AddAsKey(post_op_param.param[2]);
      } else if (post_op_param.name == "sum") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else if (post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else {
        return string("not_a_key");
      }
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetMklDnnMatMulFwd(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    return this->GetOp(key);
  }

  void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims,
                          MklPrimitive* op) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    this->SetOp(key, op);
  }
};
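
// A minimal, illustrative sketch of how a kernel obtains and runs the fused
// primitive above (`params`, the data pointers, and `ctx` are hypothetical
// stand-ins for values a real kernel would hold):
//
//   auto* matmul_fwd =
//       MklDnnMatMulFwdPrimitiveFactory<float, float, float, float,
//                                       float>::Get(params,
//                                                   /*do_not_cache=*/false);
//   std::shared_ptr<stream> s(CreateStream(ctx, matmul_fwd->GetEngine()));
//   matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data, s);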

template <class Tweight, class Toutput>
class MklDnnMatMulOpBase : public OpKernel {
 public:
  explicit MklDnnMatMulOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}
  void Compute(OpKernelContext* context) override = 0;

  // Allocate output tensor.
  virtual void AllocateOutputTensor(
      OpKernelContext* context,
      const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
      const memory::dims& output_dims_mkl_order,
      MklTensorFormat output_tf_format, Tensor** output_tensor) {
    DCHECK(output_tensor);
    auto dst_pd = mkldnn_matmul_prim_desc.dst_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));

    // Allocate output tensor.
    AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }
  // The TF_LOCKS_EXCLUDED annotation documents that the caller must not hold
  // the lock (mu_) when entering this function, since the function acquires
  // it internally.
  inline bool IsWeightCacheEmpty(OpKernelContext* context)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    return (weight_oi_.NumElements() == 0);
  }

  // Cache the converted weight in a persistent tensor.
  // Only one thread can execute this method at any given time.
  void CacheWeight(
      OpKernelContext* context,
      const std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>&
          matmul_fwd_pd,
      Tweight* weight_data, const Tensor& weight_tensor,
      MklDnnData<Tweight>& weight, const memory::desc& weight_md)
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    const Tensor& weight_t = *weight_oi_.AccessTensor(context);

    // If the weights are already cached, there's nothing to do.
    if (weight_t.NumElements() > 0) {
      return;
    }

    // Reorder and cache the weight.
    weight.SetUsrMem(weight_md, &weight_tensor);
    weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_desc(), cpu_engine_,
                               context);
    weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());

    Tensor* weight_tensor_ptr = nullptr;

    size_t weight_size = matmul_fwd_pd.get()->weights_desc().get_size();
    TensorShape weight_tf_shape;
    weight_tf_shape.AddDim(weight_size / sizeof(Tweight));

    OP_REQUIRES_OK(context, context->allocate_persistent(
                                DataTypeToEnum<Tweight>::value, weight_tf_shape,
                                &weight_oi_, &weight_tensor_ptr));

    void* weight_oi_t_data = weight.GetTensorBuffer(weight_tensor_ptr);
    memcpy(weight_oi_t_data, weight_data, weight_size);

    // Cache the memory descriptor.
    auto expected_md = matmul_fwd_pd->weights_desc();
    Tensor* weight_md_tensor_ptr = nullptr;
    TensorShape weight_mkl_format;
    weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight));

    OP_REQUIRES_OK(
        context, context->allocate_persistent(DataTypeToEnum<Tweight>::value,
                                              weight_mkl_format, &weight_oi_md_,
                                              &weight_md_tensor_ptr));
    *reinterpret_cast<memory::desc*>(
        weight_md_tensor_ptr->flat<Tweight>().data()) = expected_md;
  }

  Tweight* GetCachedWeight(OpKernelContext* context,
                           const memory::desc& expected_md)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& weight_t = *weight_oi_.AccessTensor(context);
    const Tensor& weight_md_t = *weight_oi_md_.AccessTensor(context);

    // Check if the memory descriptor of the cached weight is the same as
    // expected_md; if so, return the cached memory, else return nullptr.
    if (weight_md_t.flat<Tweight>().size()) {
      const memory::desc& stored_md =
          *(static_cast<memory::desc*>(weight_md_t.data()));
      if (stored_md == expected_md) {
        return static_cast<Tweight*>(
            const_cast<Tweight*>(weight_t.flat<Tweight>().data()));
      }
    }
    return nullptr;
  }

  engine cpu_engine_ = engine(engine::kind::cpu, 0);

 protected:
  // Tensor to save reordered weight
  mutex mu_;
  PersistentTensor weight_oi_ TF_GUARDED_BY(mu_);
  PersistentTensor weight_oi_md_ TF_GUARDED_BY(mu_);

  bool is_weight_const_;

  const int kInputIndexSrc = 0;
  const int kInputIndexWeight = 1;
  const int kInputIndexBias = 2;
  const int kOutputIndexDst = 0;
};
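
// An illustrative sketch of the weight-caching flow a derived kernel's
// Compute would follow; `ctx`, `matmul_fwd_pd`, and the weight variables are
// hypothetical locals of such a kernel:
//
//   if (this->is_weight_const_ && this->IsWeightCacheEmpty(ctx)) {
//     this->CacheWeight(ctx, matmul_fwd_pd, weight_data, weight_tensor,
//                       weight, weight_md);
//   }
//   Tweight* cached =
//       this->GetCachedWeight(ctx, matmul_fwd_pd->weights_desc());
//   if (cached != nullptr) weight_data = cached;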

using mkldnn::matmul;

namespace {

struct MklMatMulParams {
  memory::dims a_dims;
  memory::dims b_dims;
  memory::dims c_dims;
  memory::dims a_strides;
  memory::dims b_strides;
  memory::dims c_strides;

  MklMatMulParams(memory::dims a_dims, memory::dims b_dims, memory::dims c_dims,
                  memory::dims a_strides, memory::dims b_strides,
                  memory::dims c_strides)
      : a_dims(a_dims),
        b_dims(b_dims),
        c_dims(c_dims),
        a_strides(a_strides),
        b_strides(b_strides),
        c_strides(c_strides) {}
};
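
// Illustrative sketch: params for a row-major [m, k] x [k, n] -> [m, n]
// product. Strides are innermost-last, so a row-major matrix whose leading
// dimension equals its row length has strides {cols, 1} (m, k, n are
// hypothetical sizes):
//
//   MklMatMulParams params({m, k}, {k, n}, {m, n},
//                          {k, 1}, {n, 1}, {n, 1});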

template <typename T>
class MklMatMulPrimitive : public MklPrimitive {
 public:
  explicit MklMatMulPrimitive(const MklMatMulParams& params)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create matmul primitive
    Setup(params);
  }

  ~MklMatMulPrimitive() {}

  void Execute(const T* a_data, const T* b_data, T* c_data,
               std::shared_ptr<stream> stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)),
                                    *stream);
    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)),
                                    *stream);
    context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)),
                                    *stream);
#else
    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
    context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));
#endif  // ENABLE_MKLDNN_THREADPOOL
    execute_primitives(context_.matmul_primitives, stream, context_.net_args);

    // After execution, set data handle back
    context_.a_mem->set_data_handle(DummyData);
    context_.b_mem->set_data_handle(DummyData);
    context_.c_mem->set_data_handle(DummyData);
  }

 private:
  // Primitive reuse context for MatMul op
  struct MklMatMulContext {
    // MKL-DNN memory.
    std::shared_ptr<mkldnn::memory> a_mem;
    std::shared_ptr<mkldnn::memory> b_mem;
    std::shared_ptr<mkldnn::memory> c_mem;

    // Descriptor and primitive-descriptor for MatMul.
    std::shared_ptr<matmul::desc> desc;
    std::shared_ptr<matmul::primitive_desc> prim_desc;

    // Memory descriptors.
    std::shared_ptr<mkldnn::memory::desc> a_md;
    std::shared_ptr<mkldnn::memory::desc> b_md;
    std::shared_ptr<mkldnn::memory::desc> c_md;

    // MatMul primitive.
    std::vector<mkldnn::primitive> matmul_primitives;
    std::vector<std::unordered_map<int, memory>> net_args;

    MklMatMulContext()
        : a_mem(nullptr),
          b_mem(nullptr),
          c_mem(nullptr),
          desc(nullptr),
          prim_desc(nullptr),
          a_md(nullptr),
          b_md(nullptr),
          c_md(nullptr) {}
  };

  void Setup(const MklMatMulParams& params) {
    std::shared_ptr<mkldnn::primitive> matmul_primitive = nullptr;

    // Create memory descriptors.
    context_.a_md.reset(
        new memory::desc({params.a_dims}, MklDnnType<T>(), params.a_strides));

    context_.b_md.reset(
        new memory::desc({params.b_dims}, MklDnnType<T>(), params.b_strides));

    context_.c_md.reset(
        new memory::desc({params.c_dims}, MklDnnType<T>(), params.c_strides));

    // Create the matmul descriptor and primitive descriptor.
    context_.desc.reset(
        new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md));
    context_.prim_desc.reset(
        new matmul::primitive_desc(*context_.desc, cpu_engine_));

    // Create memory primitive based on dummy data.
    context_.a_mem.reset(
        new mkldnn::memory(*context_.a_md, cpu_engine_, DummyData));
    context_.b_mem.reset(
        new mkldnn::memory(*context_.b_md, cpu_engine_, DummyData));
    context_.c_mem.reset(
        new mkldnn::memory(*context_.c_md, cpu_engine_, DummyData));

    // Create matmul primitive.
    matmul_primitive.reset(new mkldnn::matmul(*context_.prim_desc));
    context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.a_mem},
                                 {MKLDNN_ARG_WEIGHTS, *context_.b_mem},
                                 {MKLDNN_ARG_DST, *context_.c_mem}});

    context_.matmul_primitives.push_back(*matmul_primitive);
    return;
  }

  struct MklMatMulContext context_;
};

template <typename T>
class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklMatMulPrimitive<T>* Get(const MklMatMulParams& params,
                                    bool do_not_cache) {
    MklMatMulPrimitive<T>* matmul_prim = nullptr;

    if (do_not_cache) {
      // Always create a new primitive.
      matmul_prim = new MklMatMulPrimitive<T>(params);
    } else {
      // Try to find a suitable primitive in the pool.
      matmul_prim = dynamic_cast<MklMatMulPrimitive<T>*>(
          MklMatMulPrimitiveFactory<T>::GetInstance().GetMklMatMul(params));
      if (matmul_prim == nullptr) {
        matmul_prim = new MklMatMulPrimitive<T>(params);
        MklMatMulPrimitiveFactory<T>::GetInstance().SetMklMatMul(params,
                                                                 matmul_prim);
      }
    }

    return matmul_prim;
  }

 private:
  MklMatMulPrimitiveFactory() {}
  ~MklMatMulPrimitiveFactory() {}

  static MklMatMulPrimitiveFactory& GetInstance() {
    static MklMatMulPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklMatMulParams& params) {
    string prefix = "matmul_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(params.a_dims);
    key_creator.AddAsKey(params.b_dims);
    key_creator.AddAsKey(params.c_dims);
    key_creator.AddAsKey(params.a_strides);
    key_creator.AddAsKey(params.b_strides);
    key_creator.AddAsKey(params.c_strides);
    key_creator.AddAsKey(typeid(T).name());

    return key_creator.GetKey();
  }

  MklPrimitive* GetMklMatMul(const MklMatMulParams& params) {
    string key = CreateKey(params);
    return this->GetOp(key);
  }

  void SetMklMatMul(const MklMatMulParams& params, MklPrimitive* op) {
    string key = CreateKey(params);
    this->SetOp(key, op);
  }
};

template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
               float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
               float beta, T* c, int64_t ldc, OpKernelContext* ctx = nullptr) {
  using dims = mkldnn::memory::dims;

  // Prepare strides based on the transa and transb flags: transposed
  // matrices have strides swapped.
  dims a_dims = dims{m, k};
  dims b_dims = dims{k, n};
  dims c_dims = dims{m, n};
  dims a_strides = tolower(transa) == 'n' ? dims{lda, 1} : dims{1, lda};
  dims b_strides = tolower(transb) == 'n' ? dims{ldb, 1} : dims{1, ldb};
  dims c_strides = dims{ldc, 1};
  // MklMatMulPrimitive does not apply alpha/beta scaling, so only the
  // identity values alpha == 1.0 and beta == 0.0 are supported; the DCHECKs
  // below guarantee they are never anything else.
  DCHECK_EQ(alpha, 1.0f);
  DCHECK_EQ(beta, 0.f);

  MklMatMulParams params(a_dims, b_dims, c_dims, a_strides, b_strides,
                         c_strides);
  MklMatMulPrimitive<T>* matmul_prim =
      MklMatMulPrimitiveFactory<T>::Get(params, /*do_not_cache=*/false);

  // Execute matmul primitive.
  std::shared_ptr<stream> cpu_stream;
  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
  matmul_prim->Execute(a, b, c, cpu_stream);
}
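
// Illustrative sketch: a plain row-major C = A * B call through dnnl_gemm,
// where A is m x k, B is k x n, and each leading dimension equals the row
// length (the buffers and sizes here are hypothetical):
//
//   dnnl_gemm<float>('N', 'N', m, n, k, 1.0f, A, k, B, n, 0.0f, C, n, ctx);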

}  // anonymous namespace

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_