/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses MKL CBLAS batched xGEMM to accelerate TensorFlow's batch
// matrix-matrix multiplication (BatchMatMul) operation.
// We currently register this kernel only for MKL-supported data types
// (float, double, complex64, complex128). The macro INTEL_MKL is defined
// by the build system only when MKL is chosen as an option at the configure
// stage; when it is undefined at build time, this file becomes an empty
// compilation unit.
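//
// A rough sketch of the op's shape semantics (illustrative only): given
// lhs of shape [..., M, K] and rhs of shape [..., K, N] with identical
// leading batch dimensions, BatchMatMul computes, for each batch index i,
//
//   out[i] = op(lhs[i]) * op(rhs[i])
//
// where op(X) is X itself, or its (conjugate) transpose when adj_x/adj_y
// is set. For example, lhs of shape [2, 3, 4] times rhs of shape [2, 4, 5]
// yields out of shape [2, 3, 5].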

#define EIGEN_USE_THREADS

#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)
#include <vector>
#include "mkl_cblas.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename Scalar>
class BatchMatMulMkl : public OpKernel {
 public:
  explicit BatchMatMulMkl(OpKernelConstruction *context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  virtual ~BatchMatMulMkl() {}

  void Compute(OpKernelContext *ctx) override {
    const Tensor &lhs = ctx->input(0);
    const Tensor &rhs = ctx->input(1);
    OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
                errors::InvalidArgument("lhs and rhs have different ndims: ",
                                        lhs.shape().DebugString(), " vs. ",
                                        rhs.shape().DebugString()));
    const int ndims = lhs.dims();
    OP_REQUIRES(
        ctx, ndims >= 2,
        errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
    TensorShape out_shape;
    for (int i = 0; i < ndims - 2; ++i) {
      OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
                  errors::InvalidArgument(
                      "lhs.dim(", i, ") and rhs.dim(", i,
                      ") must be the same: ", lhs.shape().DebugString(), " vs ",
                      rhs.shape().DebugString()));
      out_shape.AddDim(lhs.dim_size(i));
    }
    auto batch_size = (ndims == 2) ? 1 : out_shape.num_elements();
    auto lhs_rows = lhs.dim_size(ndims - 2);
    auto lhs_cols = lhs.dim_size(ndims - 1);
    auto rhs_rows = rhs.dim_size(ndims - 2);
    auto rhs_cols = rhs.dim_size(ndims - 1);
    if (adj_x_) std::swap(lhs_rows, lhs_cols);
    if (adj_y_) std::swap(rhs_rows, rhs_cols);
    OP_REQUIRES(ctx, lhs_cols == rhs_rows,
                errors::InvalidArgument(
                    "lhs and rhs inner dimensions must match: ", lhs_cols,
                    " vs. ", rhs_rows, ": ", lhs.shape().DebugString(), " ",
                    rhs.shape().DebugString(), " adj_x=", adj_x_,
                    " adj_y=", adj_y_));
    out_shape.AddDim(lhs_rows);
    out_shape.AddDim(rhs_cols);
    Tensor *out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }

    auto rhs_reshaped = rhs.template flat_inner_dims<Scalar, 3>();
    auto lhs_reshaped = lhs.template flat_inner_dims<Scalar, 3>();
    auto out_reshaped = out->template flat_inner_dims<Scalar, 3>();
    const uint64 M = lhs_reshaped.dimension(adj_x_ ? 2 : 1);
    const uint64 K = lhs_reshaped.dimension(adj_x_ ? 1 : 2);
    const uint64 N = rhs_reshaped.dimension(adj_y_ ? 1 : 2);

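    // MKL's cblas_?gemm_batch API groups GEMM calls that share parameters.
    // Every multiplication in this batch has the same M, N, K and leading
    // dimensions, so we build uniform per-group parameter arrays and issue
    // a single group of batch_size multiplications.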
    std::vector<MKL_INT> m_array(batch_size, M);
    std::vector<MKL_INT> n_array(batch_size, N);
    std::vector<MKL_INT> k_array(batch_size, K);
    std::vector<MKL_INT> lda_array(batch_size, adj_x_ ? M : K);
    std::vector<MKL_INT> ldb_array(batch_size, adj_y_ ? K : N);
    std::vector<MKL_INT> ldc_array(batch_size, N);
    std::vector<MKL_INT> group_size(1, batch_size);
    std::vector<const Scalar *> a_array;
    std::vector<const Scalar *> b_array;
    std::vector<Scalar *> c_array;
    a_array.reserve(batch_size);
    b_array.reserve(batch_size);
    c_array.reserve(batch_size);
    for (int64 i = 0; i < batch_size; i++) {
      a_array.push_back(&lhs_reshaped(i, 0, 0));
      b_array.push_back(&rhs_reshaped(i, 0, 0));
      c_array.push_back(&out_reshaped(i, 0, 0));
    }

    MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, &m_array[0], &n_array[0],
                      &k_array[0], &a_array[0], &lda_array[0], &b_array[0],
                      &ldb_array[0], &c_array[0], &ldc_array[0], 1,
                      &group_size[0]);
  }

 private:
  bool adj_x_;
  bool adj_y_;

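  // The MklCblasGemmBatch overloads below dispatch to the MKL CBLAS routine
  // for the corresponding scalar type (cblas_{s,d,c,z}gemm_batch). For real
  // types, adj_x/adj_y map to CblasTrans; for complex types they map to
  // CblasConjTrans, matching TensorFlow's adjoint (conjugate-transpose)
  // semantics.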
  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const float **A_Array, const MKL_INT *lda_Array,
                         const float **B_Array, const MKL_INT *ldb_Array,
                         float **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_Array(
        group_size[0], TransA ? CblasTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_Array(
        group_size[0], TransB ? CblasTrans : CblasNoTrans);
    std::vector<float> alpha_Array(group_size[0], 1.0);
    std::vector<float> beta_Array(group_size[0], 0.0);
    cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], M_Array,
                      N_Array, K_Array, &alpha_Array[0], A_Array, lda_Array,
                      B_Array, ldb_Array, &beta_Array[0], C_Array, ldc_Array,
                      group_count, group_size);
  }

  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const double **A_Array, const MKL_INT *lda_Array,
                         const double **B_Array, const MKL_INT *ldb_Array,
                         double **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_Array(
        group_size[0], TransA ? CblasTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_Array(
        group_size[0], TransB ? CblasTrans : CblasNoTrans);
    std::vector<double> alpha_Array(group_size[0], 1.0);
    std::vector<double> beta_Array(group_size[0], 0.0);
    cblas_dgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], M_Array,
                      N_Array, K_Array, &alpha_Array[0], A_Array, lda_Array,
                      B_Array, ldb_Array, &beta_Array[0], C_Array, ldc_Array,
                      group_count, group_size);
  }

  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const complex64 **A_Array, const MKL_INT *lda_Array,
                         const complex64 **B_Array, const MKL_INT *ldb_Array,
                         complex64 **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_Array(
        group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_Array(
        group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
    std::vector<complex64> alpha_Array(group_size[0], {1.0f, 0.0f});
    std::vector<complex64> beta_Array(group_size[0], {0.0f, 0.0f});
    cblas_cgemm_batch(
        Layout, &TransA_Array[0], &TransB_Array[0], M_Array, N_Array, K_Array,
        static_cast<const void *>(&alpha_Array[0]),
        reinterpret_cast<const void **>(A_Array), lda_Array,
        reinterpret_cast<const void **>(B_Array), ldb_Array,
        static_cast<const void *>(&beta_Array[0]),
        reinterpret_cast<void **>(C_Array), ldc_Array, group_count, group_size);
  }

  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const complex128 **A_Array, const MKL_INT *lda_Array,
                         const complex128 **B_Array, const MKL_INT *ldb_Array,
                         complex128 **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_Array(
        group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_Array(
        group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
    std::vector<complex128> alpha_Array(group_size[0], {1.0, 0.0});
    std::vector<complex128> beta_Array(group_size[0], {0.0, 0.0});
    cblas_zgemm_batch(
        Layout, &TransA_Array[0], &TransB_Array[0], M_Array, N_Array, K_Array,
        static_cast<const void *>(&alpha_Array[0]),
        reinterpret_cast<const void **>(A_Array), lda_Array,
        reinterpret_cast<const void **>(B_Array), ldb_Array,
        static_cast<const void *>(&beta_Array[0]),
        reinterpret_cast<void **>(C_Array), ldc_Array, group_count, group_size);
  }
};

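// Register the MKL-backed kernel for BatchMatMul on CPU for each supported
// type. Registration happens only when ENABLE_MKL is defined, so non-MKL
// builds never see these kernels.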
#define REGISTER_BATCH_MATMUL_MKL(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulMkl<CPUDevice, TYPE>)

#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
#endif  // ENABLE_MKL

}  // end namespace tensorflow
#endif  // defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)