/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses MKL CBLAS xGEMM to accelerate TF Matrix-Matrix
// Multiplication (MatMul) operations.
// We currently register this kernel only for MKL-supported data
// types (float, double, complex64, complex128). The macro INTEL_MKL is
// defined by the build system only when MKL is chosen as an option at the
// configure stage; when it is undefined at build time, this file becomes an
// empty compilation unit.

#if defined(INTEL_MKL)

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"

// This header file is part of MKL ML; an equivalent file is needed for
// MKL DNN.
#ifndef INTEL_MKL_DNN_ONLY
#include "mkl_cblas.h"
#else
#include "mkldnn.h"
#endif
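// Note: mkl_cblas.h declares the CBLAS entry points used below (cblas_sgemm,
// cblas_dgemm, cblas_cgemm, cblas_zgemm); mkldnn.h declares mkldnn_sgemm,
// the only GEMM available in the DNN-only build.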

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T, bool USE_CUBLAS>
class MklMatMulOp : public OpKernel {
 public:
  explicit MklMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;
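    // dim_pair[0] now holds the inner (contraction) dimensions of a and b:
    // the column index of op(a) and the row index of op(b), which must match.
    // For example, with transpose_a_ == false and transpose_b_ == false this
    // selects a.dim_size(1) and b.dim_size(0).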

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 || b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the output shape is
      // [x, y] with x and y non-zero. Each output element is then a sum over
      // an empty inner dimension, so we fill the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    const int m = a.dim_size(1 - dim_pair[0].first);
    const int k = a.dim_size(dim_pair[0].first);
    const int n = b.dim_size(1 - dim_pair[0].second);
    bool transpose_a = dim_pair[0].first == 0;
    bool transpose_b = dim_pair[0].second == 1;
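    // At this point op(a) is m x k, op(b) is k x n, and the output c is m x n.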

    auto a_ptr = (a.template flat<T>().data());
    auto b_ptr = (b.template flat<T>().data());
    auto c_ptr = (out->template flat<T>().data());

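    // The leading dimension of each operand is its row stride in row-major
    // storage: a is stored as m x k when not transposed and as k x m when
    // transposed, so its stride is k or m accordingly; likewise b is stored
    // as k x n or n x k, with stride n or k.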
    MklBlasGemm(transpose_a, transpose_b, m, n, k, a_ptr, transpose_a ? m : k,
                b_ptr, transpose_b ? k : n, c_ptr, n);
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  // --------------------------------------------------------------------------
  //
  // @brief Matrix-Matrix Multiplication with FP32 tensors a, b, and c using
  // the CBLAS interface: c = op(a) * op(b)
  //
  // @param transa  Specifies the form of op(a) used in MatMul. If transa is
  // true, then op(a) = a^T; otherwise op(a) = a.
  //
  // @param transb  Specifies the form of op(b) used in MatMul. If transb is
  // true, then op(b) = b^T; otherwise op(b) = b.
  //
  // @param m       Specifies the number of rows of the matrix op(a) and of the
  // matrix c. The value of m must be at least zero.
  //
  // @param n       Specifies the number of columns of the matrix op(b) and the
  // number of columns of the matrix c. The value of n must be at least zero.
  //
  // @param k       Specifies the number of columns of the matrix op(a) and the
  // number of rows of the matrix op(b).
  //
  // @param a       Address of matrix a
  //
  // @param lda     Leading dimension of matrix 'a'. This is set at the calling
  // site depending on the transa parameter. Since TF uses row-major layout,
  // the leading dimension is the stride between consecutive rows:
  // lda = max(1, k) when transa is false, otherwise lda = max(1, m).
  //
  // @param b       Address of matrix b
  //
  // @param ldb     Leading dimension of matrix 'b'. This is set at the calling
  // site depending on the transb parameter. Since TF uses row-major layout,
  // the leading dimension is the stride between consecutive rows:
  // ldb = max(1, n) when transb is false, otherwise ldb = max(1, k).
  //
  // @param c       Address of matrix c
  //
  // @param ldc     Leading dimension of matrix 'c'. Since TF uses row-major
  // layout, the leading dimension is the stride between consecutive rows,
  // max(1, n).
  //
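  // Example (for illustration): multiplying a 2x3 matrix a by a 3x4 matrix b
  // with transa = transb = false gives m = 2, n = 4, k = 3, and lda = 3,
  // ldb = 4, ldc = 4. With transa = true, a is instead stored as a k x m
  // (3x2) matrix, so lda = m = 2.
  //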
  // --------------------------------------------------------------------------
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const float* a, const int lda, const float* b,
                   const int ldb, float* c, const int ldc) {
    // The BLAS GEMM API defines matrix multiplication as
    // c = alpha * op(a) * op(b) + beta * c.
    // Since TF MatMul does not have parameters for alpha and beta, we set
    // them to 1.0 and 0.0 respectively.
    const float alpha = 1.0f;
    const float beta = 0.0f;
#if defined(INTEL_MKL_DNN_ONLY)
    const char* const ftrans[] = {"N", "T", "C"};
    int index_transa = transa ? 1 : 0;
    int index_transb = transb ? 1 : 0;
    VLOG(2) << "MKL DNN SGEMM called";
    // MKL DNN only supports the Fortran API, which requires column-major
    // layout, while TensorFlow uses row-major, so we swap the order of a and
    // b (and of m and n).
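    // A row-major matrix reinterpreted as column-major data is its transpose,
    // and (op(a) * op(b))^T = op(b)^T * op(a)^T, so asking the column-major
    // routine for op(b)^T * op(a)^T produces c^T in column-major form, which
    // is exactly c in row-major form, with no data movement.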
    mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa], &n, &m, &k, &alpha,
                 b, &ldb, a, &lda, &beta, c, &ldc);
#else
    // The MKL ML binary uses the CBLAS API.
    cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
                ldb, beta, c, ldc);
#endif
  }

  // MKL DNN only supports SGEMM.
#ifndef INTEL_MKL_DNN_ONLY

  // Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
  // the parameters, see the FP32 function description above.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const double* a, const int lda, const double* b,
                   const int ldb, double* c, const int ldc) {
    const double alpha = 1.0;
    const double beta = 0.0;
    cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
                ldb, beta, c, ldc);
  }

  // Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
  // For detailed info about the parameters, see the FP32 function description
  // above.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const complex64* a, const int lda,
                   const complex64* b, const int ldb, complex64* c,
                   const int ldc) {
    const MKL_Complex8 alpha = {1.0f, 0.0f};
    const MKL_Complex8 beta = {0.0f, 0.0f};
    cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
                reinterpret_cast<const MKL_Complex8*>(a), lda,
                reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
                reinterpret_cast<MKL_Complex8*>(c), ldc);
  }

  // Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
  // tensors. For detailed info about the parameters, see the FP32 function
  // description above.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const complex128* a, const int lda,
                   const complex128* b, const int ldb, complex128* c,
                   const int ldc) {
    const MKL_Complex16 alpha = {1.0, 0.0};
    const MKL_Complex16 beta = {0.0, 0.0};
    cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
                reinterpret_cast<const MKL_Complex16*>(a), lda,
                reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
                reinterpret_cast<MKL_Complex16*>(c), ldc);
  }
#endif  // !INTEL_MKL_DNN_ONLY
};

#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

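// For example, REGISTER_CPU(float) registers MklMatMulOp<CPUDevice, float,
// false> as the CPU kernel for the MatMul op with T = float.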
#ifdef ENABLE_MKL
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);

#ifndef INTEL_MKL_DNN_ONLY
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#endif  // !INTEL_MKL_DNN_ONLY
#endif  // ENABLE_MKL

}  // namespace tensorflow
#endif  // INTEL_MKL