/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses MKL CBLAS xGEMM to accelerate TF Matrix-Matrix
// Multiplication (MatMul) operations.
// We currently register this kernel only for MKL-supported data types
// (float, double, complex64, complex128). The macro INTEL_MKL is defined by
// the build system only when MKL is chosen as an option at the configure
// stage; when it is undefined at build time, this file becomes an empty
// compilation unit.

#if defined(INTEL_MKL)

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"

// This header file is part of MKL ML; an equivalent header is still needed
// for MKL DNN.
#ifndef INTEL_MKL_DNN_ONLY
#include "mkl_cblas.h"
#else
#include "mkldnn.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T, bool USE_CUBLAS>
class MklMatMulOp : public OpKernel {
 public:
  explicit MklMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 || b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the output shape is
      // [x, y] where x and y are non-zero, so we fill the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }
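
    // Derive the GEMM problem size from the (possibly transposed) inputs:
    // m is the number of rows of op(a) and c, k the shared inner dimension,
    // and n the number of columns of op(b) and c. For example (illustrative
    // shapes), a: [128, 64] and b: [64, 32] with no transposition yield
    // m = 128, k = 64, n = 32.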
    const int m = a.dim_size(1 - dim_pair[0].first);
    const int k = a.dim_size(dim_pair[0].first);
    const int n = b.dim_size(1 - dim_pair[0].second);
    bool transpose_a = dim_pair[0].first == 0;
    bool transpose_b = dim_pair[0].second == 1;

    auto a_ptr = (a.template flat<T>().data());
    auto b_ptr = (b.template flat<T>().data());
    auto c_ptr = (out->template flat<T>().data());

    MklBlasGemm(transpose_a, transpose_b, m, n, k, a_ptr, transpose_a ? m : k,
                b_ptr, transpose_b ? k : n, c_ptr, n);
  }

 private:
  bool transpose_a_;
  bool transpose_b_;

  // --------------------------------------------------------------------------
  //
  // @brief Matrix-Matrix Multiplication with FP32 tensors a, b, c, using the
  // CBLAS interface. c = op(a) * op(b)
  //
  // @param transa Specifies the form of op(a) used in MatMul. If transa is
  // true, then op(a) = a^T, otherwise op(a) = a
  //
  // @param transb Specifies the form of op(b) used in MatMul. If transb is
  // true, then op(b) = b^T, otherwise op(b) = b
  //
  // @param m Specifies the number of rows of the matrix op(a) and of the
  // matrix c. The value of m must be at least zero.
  //
  // @param n Specifies the number of columns of the matrix op(b) and the
  // number of columns of the matrix c. The value of n must be at least zero.
  //
  // @param k Specifies the number of columns of the matrix op(a) and the
  // number of rows of the matrix op(b).
  //
  // @param a Address of matrix a.
  //
  // @param lda Leading dimension of matrix a. This is set at the call site
  // depending on the transa parameter. Since TF uses row-major layout, the
  // leading dimension is the stride between consecutive rows:
  // lda = max(1, k) when transa is false, otherwise lda = max(1, m).
  //
  // @param b Address of matrix b.
  //
  // @param ldb Leading dimension of matrix b. This is set at the call site
  // depending on the transb parameter. Since TF uses row-major layout, the
  // leading dimension is the stride between consecutive rows:
  // ldb = max(1, n) when transb is false, otherwise ldb = max(1, k).
  //
  // @param c Address of matrix c.
  //
  // @param ldc Leading dimension of matrix c. Since TF uses row-major layout,
  // the leading dimension is the stride between consecutive rows, max(1, n).
  //
  // --------------------------------------------------------------------------
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const float* a, const int lda, const float* b,
                   const int ldb, float* c, const int ldc) {
    // The BLAS GEMM API defines matrix multiplication as
    // c = alpha * op(a) * op(b) + beta * c.
    // Since TF MatMul does not have parameters for alpha and beta, we set
    // them to 1.0 and 0.0 respectively.
    const float alpha = 1.0f;
    const float beta = 0.0f;
#if defined(INTEL_MKL_DNN_ONLY)
    const char* const ftrans[] = {"N", "T", "C"};
    int index_transa = transa ? 1 : 0;
    int index_transb = transb ? 1 : 0;
    VLOG(2) << "MKL DNN SGEMM called";
    // MKL DNN only supports the Fortran API, which expects column-major
    // layout, while TensorFlow uses row-major. A row-major matrix is the
    // transpose of the same data read column-major, so computing
    // b^T * a^T = (a * b)^T in column-major order leaves c correctly laid
    // out in row-major order; we therefore swap the order of a and b.
    mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa], &n, &m, &k,
                 &alpha, b, &ldb, a, &lda, &beta, c, &ldc);
#else
    // The MKL ML binary uses the CBLAS API.
    cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
                ldb, beta, c, ldc);
#endif
  }
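
  // Illustrative call (shapes are hypothetical): multiplying a row-major
  // [2, 3] matrix a by a row-major [3, 4] matrix b with no transposition is
  //
  //   MklBlasGemm(false, false, /*m=*/2, /*n=*/4, /*k=*/3,
  //               a, /*lda=*/3, b, /*ldb=*/4, c, /*ldc=*/4);
  //
  // which follows the lda/ldb/ldc rules described above.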

  // MKL DNN only supports SGEMM.
#ifndef INTEL_MKL_DNN_ONLY

  // Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
  // parameters, see the FP32 function description above.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const double* a, const int lda,
                   const double* b, const int ldb, double* c, const int ldc) {
    const double alpha = 1.0;
    const double beta = 0.0;
    cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
                ldb, beta, c, ldc);
  }
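
  // The complex kernels below hand std::complex<float>/std::complex<double>
  // buffers to MKL as MKL_Complex8/MKL_Complex16. The reinterpret_casts are
  // sound because both representations store a real part followed by an
  // imaginary part; MKL documents its complex types as layout-compatible
  // with the C/C++ complex types.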

  // Matrix-Matrix Multiplication with Complex64 (std::complex<float>)
  // tensors. For detailed info about parameters, see the FP32 function
  // description above.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const complex64* a, const int lda,
                   const complex64* b, const int ldb, complex64* c,
                   const int ldc) {
    const MKL_Complex8 alpha = {1.0f, 0.0f};
    const MKL_Complex8 beta = {0.0f, 0.0f};
    cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
                reinterpret_cast<const MKL_Complex8*>(a), lda,
                reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
                reinterpret_cast<MKL_Complex8*>(c), ldc);
  }

  // Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
  // tensors. For detailed info about parameters, see the FP32 function
  // description above.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const complex128* a, const int lda,
                   const complex128* b, const int ldb, complex128* c,
                   const int ldc) {
    const MKL_Complex16 alpha = {1.0, 0.0};
    const MKL_Complex16 beta = {0.0, 0.0};
    cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
                reinterpret_cast<const MKL_Complex16*>(a), lda,
                reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
                reinterpret_cast<MKL_Complex16*>(c), ldc);
  }
#endif  // !INTEL_MKL_DNN_ONLY
};

#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

#ifdef ENABLE_MKL
// TODO(inteltf) Consider template specialization when adding/removing
// additional types.
TF_CALL_float(REGISTER_CPU);

#ifndef INTEL_MKL_DNN_ONLY
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#endif  // !INTEL_MKL_DNN_ONLY
#endif  // ENABLE_MKL

}  // namespace tensorflow
#endif  // INTEL_MKL
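
// For reference, TF_CALL_float(REGISTER_CPU) above expands to approximately:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("MatMul").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       MklMatMulOp<CPUDevice, float, false>);
//
// making this class the CPU "MatMul" kernel for each registered type when
// ENABLE_MKL is defined.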