/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/matmul_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "cuda/include/cuda.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;

namespace {
// Converts a TensorFlow Tensor to an Eigen Matrix.
template <typename T>
Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
ToEigenMatrix(const Tensor& tensor) {
  auto matrix = tensor.matrix<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Map(
      matrix.data(), matrix.dimension(0), matrix.dimension(1));
}

// Converts a TensorFlow Tensor to an Eigen Vector.
template <typename T>
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) {
  auto v = tensor->flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
template <typename T>
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(
    const Tensor& tensor) {
  auto v = tensor.flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
}  // namespace

// If either side can be represented as a vector, do an explicit vector
// matrix multiply and return true; else return false.
//
// Note: this uses plain Eigen and not Eigen Tensor because it is more
// efficient.
template <typename T>
bool ExplicitVectorMatrixOptimization(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  if (out->dim_size(0) == 1) {
    if (dim_pair[0].second == 0) {
      // Note: this case is optimized in Eigen Tensors.
      return false;
    } else {
      auto out_v = ToEigenVector<T>(out);
      auto a_v = ToEigenVector<T>(a);
      auto b_m = ToEigenMatrix<T>(b);
      out_v.noalias() = b_m * a_v;
    }
    return true;
  } else if (out->dim_size(1) == 1) {
    auto out_v = ToEigenVector<T>(out);
    auto a_m = ToEigenMatrix<T>(a);
    auto b_v = ToEigenVector<T>(b);
    if (dim_pair[0].first == 0) {
      out_v.noalias() = a_m.transpose() * b_v;
    } else {
      out_v.noalias() = a_m * b_v;
    }
    return true;
  }
  return false;
}
// Half is not supported.
template <>
bool ExplicitVectorMatrixOptimization<Eigen::half>(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  return false;
}

template <typename Device, typename T>
struct LaunchMatMulBase {
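  // With CUDA, algorithm identifiers come from StreamExecutor's BLAS support;
  // otherwise a plain int64 stands in so the launch() interface stays uniform.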
#if GOOGLE_CUDA
  typedef se::blas::AlgorithmType AlgorithmType;
#else
  typedef int64 AlgorithmType;
#endif  // GOOGLE_CUDA

  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
#ifndef TENSORFLOW_USE_SYCL
    // An explicit vector-matrix multiply is much better optimized than an
    // implicit one and this is a bottleneck during non-batched inference.
    bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
    if (!was_vector) {
#endif  // TENSORFLOW_USE_SYCL
      functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                          out->matrix<T>(), a.matrix<T>(),
                                          b.matrix<T>(), dim_pair);
#ifndef TENSORFLOW_USE_SYCL
    }
#endif  // TENSORFLOW_USE_SYCL
  }

  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS.
template <typename T>
struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct LaunchMatMulSYCL : LaunchMatMulBase<SYCLDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
#endif  // TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA

namespace {

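// Thin wrapper around StreamExecutor's GEMV entry points. The Eigen::half
// specializations below report GEMV as unsupported for float16 instead of
// launching it.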
template <typename T>
struct LaunchBlasGemv {
  static void Compute(OpKernelContext* ctx, se::Stream* stream, bool trans,
                      uint64 m, uint64 n, const se::DeviceMemory<T>& a,
                      const se::DeviceMemory<T>& b, se::DeviceMemory<T>* c,
                      se::blas::ProfileResult* output_profile) {
    const auto blas_trans = trans ? se::blas::Transpose::kTranspose
                                  : se::blas::Transpose::kNoTranspose;
    if (output_profile == nullptr) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
                             static_cast<T>(0.0), c, 1)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(
            errors::Internal("Blas GEMV launch failed:  m=", m, ", n=", n));
      }
    } else {
      bool blas_launch_status =
          stream
              ->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
                                          a, m, b, 1, static_cast<T>(0.0), c, 1,
                                          output_profile)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(errors::Internal(
            "Blas GEMV with profiling launch failed:  m=", m, ", n=", n));
      }
    }
  }

  static bool IsSupported() { return true; }
};

template <>
void LaunchBlasGemv<Eigen::half>::Compute(
    OpKernelContext* ctx, se::Stream* stream, bool trans, uint64 m, uint64 n,
    const se::DeviceMemory<Eigen::half>& a,
    const se::DeviceMemory<Eigen::half>& b, se::DeviceMemory<Eigen::half>* c,
    se::blas::ProfileResult* output_profile) {
  ctx->SetStatus(errors::Internal(
      "Blas GEMV launch failed: GEMV is not implemented for float16."));
}

template <>
bool LaunchBlasGemv<Eigen::half>::IsSupported() {
  return false;
}

template <typename T>
bool ShouldUseGemv(uint64 n) {
  return (LaunchBlasGemv<T>::IsSupported() && n == 1);
}

}  // namespace

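// Maps a TensorFlow dtype to the cuBLAS computation type used while
// autotuning. Returns true only for dtypes that are autotuned here; half and
// bfloat16 still receive a computation type but take the non-autotuned path.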
bool GetCublasAutotuneComputationType(const DataType& dtype,
                                      se::blas::ComputationType* compute_type) {
  using se::blas::ComputationType;
  bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
      if (use_f32_for_f16_computation) {
        *compute_type = ComputationType::kF32;
      } else {
        *compute_type = ComputationType::kF16;
      }
      return false;
    case DT_FLOAT:
      *compute_type = ComputationType::kF32;
      return true;
    case DT_DOUBLE:
      *compute_type = ComputationType::kF64;
      return true;
    default:
      // Unsupported compute_type, return false.
      return false;
  }
}

// A dummy type to group matmul autotune results together.
struct MatmulAutoTuneGroup {
  static string name() { return "Matmul"; }
};
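// Process-wide cache of the best cuBLAS algorithm found for each
// MatmulParameters key (transpose flags, shape, dtype, device), so a given
// matmul configuration is only profiled once.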
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
                          se::blas::AlgorithmConfig>
    AutoTuneMatmul;

template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
    using se::blas::AlgorithmConfig;
    using se::blas::ComputationType;
    using se::blas::kDefaultAlgorithm;
    using se::blas::kDefaultBlasGemm;
    using se::blas::kDefaultBlasGemv;
    using se::blas::kNoAlgorithm;
    using se::blas::ProfileResult;
    using se::blas::Transpose;
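    // dim_pair[0] names the contracted dimension of each operand, so the
    // output rows (m), the contraction size (k), and the output columns (n)
    // can be read directly off the input shapes and the transpose flags.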
    Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
    const uint64 m = a.dim_size(1 - dim_pair[0].first);
    const uint64 k = a.dim_size(dim_pair[0].first);
    const uint64 n = b.dim_size(1 - dim_pair[0].second);
    bool transpose_a = dim_pair[0].first == 0;
    bool transpose_b = dim_pair[0].second == 1;
    auto blas_transpose_a = trans[transpose_a];
    auto blas_transpose_b = trans[transpose_b];

    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
                                a.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
                                b.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
                                out->template flat<T>().size());
    auto alpha = static_cast<T>(1.0);
    auto beta = static_cast<T>(0.0);

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = a.dtype();
    MatmulParameters matmul_parameters = {
        transpose_a, transpose_b, m, n, k, dtype, device_id,
    };
    AlgorithmConfig algorithm_config(kNoAlgorithm);

    ComputationType computation_type;
    bool compute_type_supported =
        GetCublasAutotuneComputationType(dtype, &computation_type);
    if (use_autotune && compute_type_supported && !algorithms->empty()) {
      ProfileResult best_result;
      // TODO(yangzihao): Unify this code with conv autotuning.
      if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
                                               &algorithm_config)) {
        ProfileResult profile_result;
        for (auto profile_algorithm : (*algorithms)) {
          // cuBLAS computes
          //   C = A x B
          // where A, B and C are assumed to be column-major. We want the
          // output to be row-major, so we instead compute
          //   C' = B' x A'  (' stands for transpose).
          bool cublas_launch_status =
              stream
                  ->ThenBlasGemmWithAlgorithm(
                      blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                      transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                      &c_ptr, n, computation_type, profile_algorithm,
                      &profile_result)
                  .ok();
          if (cublas_launch_status) {
            if (profile_result.is_valid()) {
              if (profile_result.elapsed_time_in_ms() <
                  best_result.elapsed_time_in_ms()) {
                best_result = profile_result;
              }
            }
          }
        }
        // Try BlasGemmWithProfiling
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithProfiling(
                    blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
                    &c_ptr, n, &profile_result)
                .ok();
        if (cublas_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
        // Try BlasGemvWithProfiling
        if (ShouldUseGemv<T>(n)) {
          LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                     transpose_a ? m : k, transpose_a ? k : m,
                                     a_ptr, b_ptr, &c_ptr, &profile_result);
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
      }
      // We make sure that each matmul parameter set only gets one pass of
      // autotune. If a best result is found, assign it to algorithm_config
      // and insert it into the autotune map. If all internal kernels of
      // cublasGemmEx() return invalid results, we add kNoAlgorithm to the
      // autotune map instead.
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
                                            algorithm_config);
      if (algorithm_config.algorithm() != kNoAlgorithm &&
          algorithm_config.algorithm() != kDefaultBlasGemm &&
          algorithm_config.algorithm() != kDefaultBlasGemv) {
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithAlgorithm(
                    blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                    &c_ptr, n, computation_type, algorithm_config.algorithm(),
                    nullptr)
                .ok();
        if (!cublas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM with algorithm launch failed : a.shape=(",
              a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
              ", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
    // We use normal BlasGemm() for any of the following cases:
    //  1) the use_autotune flag is not set;
    //  2) the compute type does not support autotune;
    //  3) no algorithm is found;
    //  4) all internal kernels in autotune return invalid results.
    // We use normal BlasGemv() for either of the following cases:
    //  1) the use_autotune flag is not set but LaunchBlasGemv is supported
    //     and n == 1;
    //  2) the use_autotune flag is set, autotuning picked BlasGemv(), and
    //     algorithm_config.algorithm() was set to kDefaultBlasGemv.
    if (!use_autotune || !compute_type_supported || algorithms->empty() ||
        algorithm_config.algorithm() == kNoAlgorithm ||
        algorithm_config.algorithm() == kDefaultBlasGemm ||
        algorithm_config.algorithm() == kDefaultBlasGemv) {
      if (algorithm_config.algorithm() == kDefaultBlasGemv ||
          ShouldUseGemv<T>(n)) {
        // This is a matrix*vector multiply, so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major.
        // TODO(yangzihao): Add Gemv as an autotuning option too.
        LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                   transpose_a ? m : k, transpose_a ? k : m,
                                   a_ptr, b_ptr, &c_ptr, nullptr);
      } else {
        // Use C' = B' x A' (' stands for transpose)
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               1.0f, b_ptr, transpose_b ? k : n, a_ptr,
                               transpose_a ? m : k, 0.0f, &c_ptr, n)
                .ok();
        if (!blas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
              a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
              "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
  }

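  // Queries StreamExecutor once for the list of cuBLAS GEMM algorithms to try
  // during autotuning; the flag keeps the vector from being filled again.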
  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {
    if (*algorithm_set_flag == false) {
      auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
      stream->parent()->GetBlasGemmAlgorithms(algorithms);
      *algorithm_set_flag = true;
    }
  }
};

#endif  // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
 public:
  explicit MatMulOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), algorithms_set_already_(false) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));

    LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
        ctx, &algorithms_, &algorithms_set_already_);
    use_autotune_ = MatmulAutotuneEnable();
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(a.shape()),
        errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
                                a.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(b.shape()),
        errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
                                b.shape().DebugString()));
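    // dim_pair[0] selects which dimension of each operand is contracted:
    // without transposes we contract a's columns (dim 1) with b's rows
    // (dim 0); a transpose flag flips the corresponding index.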
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 || b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    if (std::is_same<T, bfloat16>::value) {
      bool is_cpu = std::is_same<Device, CPUDevice>::value;
      OP_REQUIRES(ctx, is_cpu,
                  errors::Internal("bfloat16 matmul is not supported by GPU"));
      Tensor a_float, b_float, out_float;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, a.shape(), &a_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, b.shape(), &b_float));
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DT_FLOAT, out->shape(), &out_float));

      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
      BFloat16ToFloat(a.flat<bfloat16>().data(), a_float.flat<float>().data(),
                      a.NumElements());
      BFloat16ToFloat(b.flat<bfloat16>().data(), b_float.flat<float>().data(),
                      b.NumElements());

      LaunchMatMul<Device, float, USE_CUBLAS>::launch(
          ctx, a_float, b_float, dim_pair, &algorithms_, use_autotune_,
          &out_float);
      FloatToBFloat16(out_float.flat<float>().data(),
                      out->flat<bfloat16>().data(), out->NumElements());
    } else {
      LaunchMatMul<Device, T, USE_CUBLAS>::launch(
          ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
    }
  }

 private:
  std::vector<int64> algorithms_;
  bool algorithms_set_already_;
  bool use_autotune_;
  bool transpose_a_;
  bool transpose_b_;
};

namespace functor {

// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
template <typename T>
struct MatMulFunctor<CPUDevice, T> {
  void operator()(
      const CPUDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
  }
};

#ifdef TENSORFLOW_USE_SYCL
// Partial specialization MatMulFunctor<Device=SYCLDevice, T>.
template <typename T>
struct MatMulFunctor<SYCLDevice, T> {
  void operator()(
      const SYCLDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<SYCLDevice>(d, out, in0, in1, dim_pair);
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // end namespace functor

#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

#define REGISTER_CPU(T)                                             \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"),     \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
  REGISTER_CPU_EIGEN(T);

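// Registers two kernels per GPU type that share the cuBLAS-backed
// implementation: the default (unlabelled) one and an explicit "cublas" label.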
#define REGISTER_GPU(T)                                            \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"),    \
      MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .Label("cublas"),                    \
                          MatMulOp<GPUDevice, T, true /* cublas */>)

#if defined(INTEL_MKL) && defined(ENABLE_MKL)

// MKL supports float, double, complex64, and complex128 for
// matrix multiplication, and those kernels are registered in mkl_matmul_op.cc.
// MKL does not support half, bfloat16, int32, or int64 for
// matrix multiplication, so register kernels that use the default Eigen-based
// implementations for these types. REGISTER_CPU defines two versions: the
// "eigen"-labelled one and the unlabelled one.
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_int64(REGISTER_CPU);

// Float is supported by both MKL DNN and MKL ML. Registration of the
// unlabelled version for types supported by MKL is in mkl_matmul_op.cc;
// the "eigen"-labelled version is defined here just to pass a few unit tests.
TF_CALL_float(REGISTER_CPU_EIGEN);

// MKL DNN does not support complex64, complex128, or double. If the user
// builds with only the open-source MKL DNN, use the default implementation
// for these types; otherwise use GEMM from the MKL ML binary.

#if defined(INTEL_MKL_DNN_ONLY)
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
#else  // INTEL_MKL_DNN_ONLY
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
#endif  // INTEL_MKL_DNN_ONLY

#else   // INTEL_MKL && ENABLE_MKL
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_int64(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#endif  // INTEL_MKL && ENABLE_MKL

#if GOOGLE_CUDA
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_half(REGISTER_GPU);
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(T)                                         \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("MatMul").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
      MatMulOp<SYCLDevice, T, false /* xxblas */>);              \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                         \
                              .Device(DEVICE_SYCL)               \
                              .TypeConstraint<T>("T")            \
                              .Label("eigen"),                   \
                          MatMulOp<SYCLDevice, T, false /* xxblas */>)
TF_CALL_float(REGISTER_SYCL);
TF_CALL_double(REGISTER_SYCL);

#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow