/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_

#define EIGEN_USE_THREADS

#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {
// Returns the pair of dimensions along which to perform Tensor contraction to
// emulate matrix multiplication.
// For matrix multiplication of 2D Tensors X and Y, X is contracted along its
// second dimension and Y along its first dimension (if neither X nor Y is
// adjointed). The dimension to contract along is switched when either operand
// is adjointed.
// See http://en.wikipedia.org/wiki/Tensor_contraction
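// For example, ContractionDims(false, false) is IndexPair(1, 0): X's columns
// are contracted against Y's rows, as in a plain C = X * Y product.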
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
}

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
                  int batch_size) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
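    // Concretely: y is conjugated below whenever exactly one of the operands
    // is adjointed, and the caller conjugates the final result whenever adj_x
    // is true; together this realizes both identities.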
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    // TODO(rmlarsen): Consider launching these contractions asynchronously.
    for (int64 i = 0; i < batch_size; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;

      auto x = Tx.template chip<0>(x_batch_index);
      auto z = Tz.template chip<0>(i);
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(y_batch_index).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(y_batch_index);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
                  int batch_size) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
    if (batch_size == 1 && !should_bcast) {
      auto Tx = in_x.flat_inner_dims<Scalar, 2>();
      auto Ty = in_y.flat_inner_dims<Scalar, 2>();
      auto Tz = out->flat_inner_dims<Scalar, 2>();
      Tz.device(d) = Tx.contract(Ty, contract_pairs);
    } else {
      auto Tx = in_x.tensor<Scalar, 3>();
      auto Ty = in_y.tensor<Scalar, 3>();
      auto Tz = out->tensor<Scalar, 3>();
      const auto& x_batch_indices = bcast.x_batch_indices();
      const auto& y_batch_indices = bcast.y_batch_indices();
      // TODO(rmlarsen): Consider launching these contractions asynchronously.
      for (int64 i = 0; i < batch_size; ++i) {
        const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
        const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
        auto x = Tx.template chip<0>(x_batch_index);
        auto y = Ty.template chip<0>(y_batch_index);
        auto z = Tz.template chip<0>(i);

        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
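// Each batch slice is mapped in place onto a row-major Eigen matrix via
// Eigen::Map, so no copy of the underlying tensor data is made before the
// product is evaluated.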
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, bool trans_x, bool trans_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto z = TensorSliceToEigenMatrix(out, i);
      // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y
      // and trans_y.
      if (!adj_x && !trans_x) {
        if (!adj_y && !trans_y) {
          z.noalias() = x * y;
        } else if (adj_y) {
          z.noalias() = x * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x * y.transpose();
        }
      } else if (adj_x) {
        if (!adj_y && !trans_y) {
          z.noalias() = x.adjoint() * y;
        } else if (adj_y) {
          z.noalias() = x.adjoint() * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x.adjoint() * y.transpose();
        }
      } else {  // trans_x == true
        if (!adj_y && !trans_y) {
          z.noalias() = x.transpose() * y;
        } else if (adj_y) {
          z.noalias() = x.transpose() * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x.transpose() * y.transpose();
        }
      }
    }
  }
};

}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies, i.e. the size of the batch.
    const int64 batch_size = bcast.output_batch_size();
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
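    // cost_per_unit approximates the number of multiply-adds in one matmul of
    // the batch (rows(x) * cols(x) * cols(out)); it drives both the
    // parallelization heuristic below and the work sharder.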
    const int64 small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
    // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of
    // Jan 21, 2020.
    const int64 kMaxCostOuterParallelism = 128 * 128;  // heuristic.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    // TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous
    // evaluation in Eigen Tensor.
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
                                trans_y, bcast, out, batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out](
                int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
                                                  trans_x, trans_y, bcast, out,
                                                  start, limit);
            });
    }
    if (conjugate_result) {
      // We used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, so we need to conjugate the final output. This is a
      // no-op for non-complex types.
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(gpu_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

class BlasScratchAllocator : public se::ScratchAllocator {
 public:
  using Stream = se::Stream;
  using DeviceMemoryBytes = se::DeviceMemory<uint8>;

  BlasScratchAllocator(OpKernelContext* context) : context_(context) {}

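  // Returning a negative value indicates that this allocator does not impose
  // an explicit limit on the amount of scratch memory that may be requested.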
  int64 GetMemoryLimitInBytes() override { return -1; }

  se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      int64 byte_size) override {
    Tensor temporary_memory;

    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<DeviceMemoryBytes>(
          DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
    }
    // Hold references to the allocated tensors until the end of the
    // allocator's lifetime, so the scratch memory remains valid while in use.
    allocated_tensors_.push_back(temporary_memory);
    return se::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }

 private:
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
}  // namespace

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   se::blas::Transpose::kTranspose,
                                   se::blas::Transpose::kConjugateTranspose};
    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
    const int64 batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef se::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    uint64 a_stride;
    uint64 b_stride;
    uint64 c_stride;

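    // Set up per-batch device pointers. Three layouts are possible: a single
    // strided description when the batch is contiguous (or one side is fully
    // broadcast), one pointer per batch element when no broadcasting is
    // needed, or gather-style pointers driven by the broadcast indices.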
    bool is_full_broadcast =
        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
    bool use_strided_batched =
        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
        batch_size > 1;
    if (use_strided_batched) {
      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
      c_stride = m * n;
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    } else if (!bcast.IsBroadcastingRequired()) {
      for (int64 i = 0; i < batch_size; ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      }
      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      }
      for (int64 i = 0; i < batch_size; ++i) {
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
        c_ptrs.push_back(&c_device_memory.back());
      }
    }

    typedef Scalar Coefficient;

    // Blas does
    // C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    // C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
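    // In other words, the row-major C (m x n) produced here has exactly the
    // memory layout of a column-major C^T (n x m), so BLAS is asked for
    // C^T = B^T * A^T with swapped operand order and dimensions (n, m, k).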
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      if (n == 1 &&
          blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
          blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major. Since GEMV doesn't provide a way to just conjugate an
        // argument, we have to defer those cases to GEMM below.
        auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
                                ? se::blas::Transpose::kNoTranspose
                                : se::blas::Transpose::kTranspose;
        bool blas_launch_status =
            stream
                ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
                               adj_x || trans_x ? k : m,
                               static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                               adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
                               static_cast<Coefficient>(0.0), c_ptrs[0], 1)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMV launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      } else {
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                               adj_y || trans_y ? k : n, *(a_ptrs[0]),
                               adj_x || trans_x ? m : k,
                               static_cast<Coefficient>(0.0), c_ptrs[0], n)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      }
    } else if (use_strided_batched) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemmStridedBatched(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), *b_ptrs[0],
                  adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
                  adj_x || trans_x ? m : k, a_stride,
                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
                  batch_size)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMStridedBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    } else {
      BlasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs,
                  adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                  static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
                  &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    typedef Eigen::half Scalar;
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   se::blas::Transpose::kTranspose,
                                   se::blas::Transpose::kConjugateTranspose};
    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
    const uint64 batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();

    uint64 a_stride;
    uint64 b_stride;
    uint64 c_stride;

    bool is_full_broadcast =
        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
    bool use_strided_batched =
        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
        batch_size > 1;
    if (use_strided_batched) {
      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
      c_stride = m * n;
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    } else if (!bcast.IsBroadcastingRequired()) {
      for (int64 i = 0; i < batch_size; ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      }
      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      }
      for (int64 i = 0; i < batch_size; ++i) {
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
        c_ptrs.push_back(&c_device_memory.back());
      }
    }

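    // Note: Coefficient is float (rather than Eigen::half) so that alpha/beta
    // are passed in fp32, allowing the underlying GEMM to accumulate in
    // higher precision for half-precision inputs.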
    typedef float Coefficient;

    // Blas does
    // C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    // C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                             adj_y || trans_y ? k : n, *(a_ptrs[0]),
                             adj_x || trans_x ? m : k,
                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k));
      }
    } else if (use_strided_batched) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemmStridedBatched(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), *b_ptrs[0],
                  adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
                  adj_x || trans_x ? m : k, a_stride,
                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
                  batch_size)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMStridedBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    } else {
      BlasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs,
                  adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                  static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
                  &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename Scalar>
class BaseBatchMatMulOp : public OpKernel {
 public:
  explicit BaseBatchMatMulOp(OpKernelConstruction* context,
                             bool is_legacy_matmul)
      : OpKernel(context) {
    if (is_legacy_matmul) {
      // The old MatMul kernel has "transpose_a/transpose_b" attributes.
      OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
      OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
      adj_x_ = false;
      adj_y_ = false;
    } else {
      OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
      OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
      trans_x_ = false;
      trans_y_ = false;
    }
  }

  ~BaseBatchMatMulOp() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    const Status s = ValidateInputTensors(ctx, in0, in1);
    if (!s.ok()) {
      ctx->SetStatus(s);
      return;
    }

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);
    auto d1 = in0.dim_size(in0.dims() - 1);
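    // Collapse all leading batch dimensions so that each operand is viewed as
    // a rank-3 tensor of shape [batch, rows, cols]; d0/d1 (and d2/d3 below)
    // are the matrix dimensions before any transpose/adjoint is applied.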
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    if (adj_x_ || trans_x_) std::swap(d0, d1);
    if (adj_y_ || trans_y_) std::swap(d2, d3);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
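    // bfloat16 has no direct kernel in this path, so the inputs are widened
    // to float, multiplied with the float kernel, and the result is narrowed
    // back to bfloat16.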
    if (std::is_same<Scalar, bfloat16>::value) {
      bool is_cpu = std::is_same<Device, CPUDevice>::value;
      OP_REQUIRES(ctx, is_cpu,
                  errors::Internal("bfloat16 matmul is not supported by GPU"));
      Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
                                             &in0_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
                                             &in1_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
                                             &out_reshaped_float));

      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
      BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
                      in0_reshaped_float.flat<float>().data(),
                      in0_reshaped.NumElements());
      BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
                      in1_reshaped_float.flat<float>().data(),
                      in1_reshaped.NumElements());

      LaunchBatchMatMul<Device, float>::Launch(
          ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
          trans_y_, bcast, &out_reshaped_float);
      FloatToBFloat16(out_reshaped_float.flat<float>().data(),
                      out_reshaped.flat<bfloat16>().data(), out->NumElements());
    } else {
      LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
                                                adj_x_, adj_y_, trans_x_,
                                                trans_y_, bcast, &out_reshaped);
    }
  }

 protected:
  virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                      const Tensor& in1) = 0;

 private:
  // TODO(171979567): Make the ops take both adj and transpose attributes.
  bool adj_x_;
  bool adj_y_;
  bool trans_x_;
  bool trans_y_;
};

// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Scalar, bool is_legacy_matmul = false>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
 public:
  explicit BatchMatMulOp(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Scalar>(context, is_legacy_matmul) {}

  ~BatchMatMulOp() override {}

 private:
  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                              const Tensor& in1) override {
    // Disallow broadcasting support. Ensure that all batch dimensions of the
    // input tensors match.
    if (in0.dims() != in1.dims()) {
      return errors::InvalidArgument(
          "In[0] and In[1] has different ndims: ", in0.shape().DebugString(),
          " vs. ", in1.shape().DebugString());
    }
    const int ndims = in0.dims();
    if (is_legacy_matmul) {
      if (ndims != 2) {
        return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ",
                                       ndims);
      }
    } else {
      if (ndims < 2) {
        return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ",
                                       ndims);
      }
      for (int i = 0; i < ndims - 2; ++i) {
        if (in0.dim_size(i) != in1.dim_size(i)) {
          return errors::InvalidArgument(
              "In[0].dim(", i, ") and In[1].dim(", i,
              ") must be the same: ", in0.shape().DebugString(), " vs ",
              in1.shape().DebugString());
        }
      }
    }
    return Status::OK();
  }
};

// BatchMatMul Op implementation with broadcasting support.
template <typename Device, typename Scalar>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
 public:
  explicit BatchMatMulV2Op(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Scalar>(context,
                                          /* is_legacy_matmul= */ false) {}

  ~BatchMatMulV2Op() override {}

 private:
  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                              const Tensor& in1) override {
    // Enable broadcasting support. Validity of broadcasting is checked in
    // BaseBatchMatMulOp.
    if (in0.dims() < 2) {
      return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims());
    }
    if (in1.dims() < 2) {
      return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims());
    }
    return Status::OK();
  }
};

#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<CPUDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<CPUDevice, TYPE>);                                  \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
      BatchMatMulOp<CPUDevice, TYPE, /* is_legacy_matmul=*/true>)

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<GPUDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<GPUDevice, TYPE>);                                  \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),        \
      BatchMatMulOp<GPUDevice, TYPE, /* is_legacy_matmul=*/true>)
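// These macros are only defined here; the per-type kernel registrations are
// expected to be instantiated in the .cc files that include this header
// (e.g. a matmul_op_real.cc / matmul_op_complex.cc split).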

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_