/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses MKL CBLAS batched xGEMM to accelerate TensorFlow's batch
// matrix-matrix multiplication (BatchMatMul) operations.
// We currently register this kernel only for MKL-supported data types
// (float, double, complex64, complex128). The macro INTEL_MKL is defined
// by the build system only when MKL is chosen as an option at configure
// stage; when it is undefined at build time, this file becomes an empty
// compilation unit.

#define EIGEN_USE_THREADS

#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)
#include <vector>
#include "mkl_cblas.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename Scalar>
class BatchMatMulMkl : public OpKernel {
 public:
  explicit BatchMatMulMkl(OpKernelConstruction *context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  virtual ~BatchMatMulMkl() {}

  void Compute(OpKernelContext *ctx) override {
    const Tensor &lhs = ctx->input(0);
    const Tensor &rhs = ctx->input(1);
    OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
                errors::InvalidArgument("lhs and rhs have different ndims: ",
                                        lhs.shape().DebugString(), " vs. ",
                                        rhs.shape().DebugString()));
    const int ndims = lhs.dims();
    OP_REQUIRES(
        ctx, ndims >= 2,
        errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
    TensorShape out_shape;
    for (int i = 0; i < ndims - 2; ++i) {
      OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
                  errors::InvalidArgument(
                      "lhs.dim(", i, ") and rhs.dim(", i,
                      ") must be the same: ", lhs.shape().DebugString(), " vs ",
                      rhs.shape().DebugString()));
      out_shape.AddDim(lhs.dim_size(i));
    }
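    // Collapse all leading dimensions into a single batch dimension; each
    // batch entry will be issued as one (M x K) * (K x N) GEMM in the grouped
    // MKL call below.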
    auto batch_size = (ndims == 2) ? 1 : out_shape.num_elements();
    auto lhs_rows = lhs.dim_size(ndims - 2);
    auto lhs_cols = lhs.dim_size(ndims - 1);
    auto rhs_rows = rhs.dim_size(ndims - 2);
    auto rhs_cols = rhs.dim_size(ndims - 1);
    if (adj_x_) std::swap(lhs_rows, lhs_cols);
    if (adj_y_) std::swap(rhs_rows, rhs_cols);
    OP_REQUIRES(ctx, lhs_cols == rhs_rows,
                errors::InvalidArgument(
                    "lhs mismatch rhs shape: ", lhs_cols, " vs. ", rhs_rows,
                    ": ", lhs.shape().DebugString(), " ",
                    rhs.shape().DebugString(), " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(lhs_rows);
    out_shape.AddDim(rhs_cols);
    Tensor *out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }

    auto rhs_reshaped = rhs.template flat_inner_dims<Scalar, 3>();
    auto lhs_reshaped = lhs.template flat_inner_dims<Scalar, 3>();
    auto out_reshaped = out->template flat_inner_dims<Scalar, 3>();
    const uint64 M = lhs_reshaped.dimension(adj_x_ ? 2 : 1);
    const uint64 K = lhs_reshaped.dimension(adj_x_ ? 1 : 2);
    const uint64 N = rhs_reshaped.dimension(adj_y_ ? 1 : 2);

    std::vector<MKL_INT> m_array(batch_size, M);
    std::vector<MKL_INT> n_array(batch_size, N);
    std::vector<MKL_INT> k_array(batch_size, K);
    std::vector<MKL_INT> lda_array(batch_size, adj_x_ ? M : K);
    std::vector<MKL_INT> ldb_array(batch_size, adj_y_ ? K : N);
    std::vector<MKL_INT> ldc_array(batch_size, N);
    std::vector<MKL_INT> group_size(1, batch_size);
    std::vector<const Scalar *> a_array;
    std::vector<const Scalar *> b_array;
    std::vector<Scalar *> c_array;
    a_array.reserve(batch_size);
    b_array.reserve(batch_size);
    c_array.reserve(batch_size);
    for (int64 i = 0; i < batch_size; i++) {
      a_array.push_back(&lhs_reshaped(i, 0, 0));
      b_array.push_back(&rhs_reshaped(i, 0, 0));
      c_array.push_back(&out_reshaped(i, 0, 0));
    }

    MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, &m_array[0], &n_array[0],
                      &k_array[0], &a_array[0], &lda_array[0], &b_array[0],
                      &ldb_array[0], &c_array[0], &ldc_array[0], 1,
                      &group_size[0]);
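    // Conceptually, the grouped call above is equivalent to issuing one plain
    // GEMM per batch entry; a rough sketch for the float case (trans_a and
    // trans_b standing for the CBLAS_TRANSPOSE values derived from
    // adj_x_/adj_y_) would be:
    //
    //   for (int64 i = 0; i < batch_size; ++i) {
    //     cblas_sgemm(CblasRowMajor, trans_a, trans_b, M, N, K, 1.0f,
    //                 a_array[i], lda_array[i], b_array[i], ldb_array[i],
    //                 0.0f, c_array[i], ldc_array[i]);
    //   }
    //
    // Using the batched interface lets MKL schedule the whole group at once.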
  }

 private:
  bool adj_x_;
  bool adj_y_;

  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const float **A_Array, const MKL_INT *lda_Array,
                         const float **B_Array, const MKL_INT *ldb_Array,
                         float **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_Array(
        group_size[0], TransA ? CblasTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_Array(
        group_size[0], TransB ? CblasTrans : CblasNoTrans);
    std::vector<float> alpha_Array(group_size[0], 1.0);
    std::vector<float> beta_Array(group_size[0], 0.0);
    cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], M_Array,
                      N_Array, K_Array, &alpha_Array[0], A_Array, lda_Array,
                      B_Array, ldb_Array, &beta_Array[0], C_Array, ldc_Array,
                      group_count, group_size);
  }

  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const double **A_Array, const MKL_INT *lda_Array,
                         const double **B_Array, const MKL_INT *ldb_Array,
                         double **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_array(
        group_size[0], TransA ? CblasTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_array(
        group_size[0], TransB ? CblasTrans : CblasNoTrans);
    std::vector<double> alpha_Array(group_size[0], 1.0);
    std::vector<double> beta_Array(group_size[0], 0.0);
    cblas_dgemm_batch(Layout, &TransA_array[0], &TransB_array[0], M_Array,
                      N_Array, K_Array, &alpha_Array[0], A_Array, lda_Array,
                      B_Array, ldb_Array, &beta_Array[0], C_Array, ldc_Array,
                      group_count, group_size);
  }
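  // For the complex types, adj_x_/adj_y_ request the adjoint (conjugate
  // transpose), so CblasConjTrans is used instead of CblasTrans. MKL's complex
  // CBLAS interface takes alpha/beta and the matrix pointer arrays as void*.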
  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const complex64 **A_Array, const MKL_INT *lda_Array,
                         const complex64 **B_Array, const MKL_INT *ldb_Array,
                         complex64 **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_array(
        group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_array(
        group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
    std::vector<complex64> alpha_Array(group_size[0], {1.0f, 0.0f});
    std::vector<complex64> beta_Array(group_size[0], {0.0f, 0.0f});
    cblas_cgemm_batch(
        Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
        static_cast<const void *>(&alpha_Array[0]),
        reinterpret_cast<const void **>(A_Array), lda_Array,
        reinterpret_cast<const void **>(B_Array), ldb_Array,
        static_cast<const void *>(&beta_Array[0]),
        reinterpret_cast<void **>(C_Array), ldc_Array, group_count, group_size);
  }

  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const complex128 **A_Array, const MKL_INT *lda_Array,
                         const complex128 **B_Array, const MKL_INT *ldb_Array,
                         complex128 **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_array(
        group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_array(
        group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
    std::vector<complex128> alpha_Array(group_size[0], {1.0, 0.0});
    std::vector<complex128> beta_Array(group_size[0], {0.0, 0.0});
    cblas_zgemm_batch(
        Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
        static_cast<const void *>(&alpha_Array[0]),
        reinterpret_cast<const void **>(A_Array), lda_Array,
        reinterpret_cast<const void **>(B_Array), ldb_Array,
        static_cast<const void *>(&beta_Array[0]),
        reinterpret_cast<void **>(C_Array), ldc_Array, group_count, group_size);
  }
};

#define REGISTER_BATCH_MATMUL_MKL(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulMkl<CPUDevice, TYPE>)

#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
#endif  // ENABLE_MKL

}  // end namespace tensorflow

#endif  // defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)