/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses the oneDNN library to accelerate Batch Matrix-Matrix
// Multiplication (BatchMatMul) operations. We currently register this kernel
// only for oneDNN-supported data types (float, bfloat16). oneDNN supports
// output tensors of rank up to 12; if the output rank exceeds 12, we fall
// back to the Eigen-based kernel.

#define EIGEN_USE_THREADS

#if defined(INTEL_MKL)

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// The third template parameter, v2_bcast, is true when instantiating the
// kernel for BatchMatMulV2 (which supports broadcasting of the batch
// dimensions) and false for the original BatchMatMul.
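// For example, with v2_bcast, an lhs of shape [2, 1, 3, 4] and an rhs of
// shape [5, 4, 6] broadcast their batch dimensions ([2, 1] and [5]) to
// [2, 5], producing an output of shape [2, 5, 3, 6].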
template <typename Device, typename Scalar, bool v2_bcast>
class BatchMatMulMkl : public OpKernel {
 public:
  explicit BatchMatMulMkl(OpKernelConstruction* context)
      : OpKernel(context), eigen_batch_mm_v2_(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  virtual ~BatchMatMulMkl() {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& lhs = ctx->input(0);
    const Tensor& rhs = ctx->input(1);

    if (!v2_bcast) {
      // Using V1, so check that the lhs and rhs dimensions are correct and
      // that no broadcasting is needed.
      OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
                  errors::InvalidArgument("lhs and rhs have different ndims: ",
                                          lhs.shape().DebugString(), " vs. ",
                                          rhs.shape().DebugString()));
      const int ndims = lhs.dims();
      OP_REQUIRES(
          ctx, ndims >= 2,
          errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
      for (int i = 0; i < ndims - 2; ++i) {
        OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
                    errors::InvalidArgument(
                        "lhs.dim(", i, ") and rhs.dim(", i,
                        ") must be the same: ", lhs.shape().DebugString(),
                        " vs ", rhs.shape().DebugString()));
      }
    } else {
      OP_REQUIRES(
          ctx, lhs.dims() >= 2,
          errors::InvalidArgument("In[0] ndims must be >= 2: ", lhs.dims()));
      OP_REQUIRES(
          ctx, rhs.dims() >= 2,
          errors::InvalidArgument("In[1] ndims must be >= 2: ", rhs.dims()));
    }

    // lhs and rhs can have different ranks.
    const auto ndims_lhs = lhs.dims();
    const auto ndims_rhs = rhs.dims();

    // Get broadcast info.
    MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();

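    // Extract the sizes of the two innermost (matrix) dimensions of each
    // operand; all outer dimensions are batch dimensions.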
    auto lhs_rows = lhs.dim_size(ndims_lhs - 2);
    auto lhs_cols = lhs.dim_size(ndims_lhs - 1);
    auto rhs_rows = rhs.dim_size(ndims_rhs - 2);
    auto rhs_cols = rhs.dim_size(ndims_rhs - 1);

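    // If an operand is adjointed (transposed), swap its row/column counts so
    // that the inner-dimension check below applies to the effective shapes.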
    if (adj_x_) std::swap(lhs_rows, lhs_cols);
    if (adj_y_) std::swap(rhs_rows, rhs_cols);
    OP_REQUIRES(ctx, lhs_cols == rhs_rows,
                errors::InvalidArgument(
                    "Inner dimensions of lhs and rhs must match: ", lhs_cols,
                    " vs. ", rhs_rows, ": ", lhs.shape().DebugString(), " ",
                    rhs.shape().DebugString(), " ", adj_x_, " ", adj_y_));

    out_shape.AddDim(lhs_rows);
    out_shape.AddDim(rhs_cols);
    // The maximum number of dimensions for a tensor in oneDNN is 12, so for
    // higher-rank outputs fall back to the Eigen-based kernel (see the
    // file-level comment above).
    if (out_shape.dims() > 12) {
      eigen_batch_mm_v2_.Compute(ctx);
      return;
    }

    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
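    // An empty output (some dimension is 0) requires no computation.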
    if (out->NumElements() == 0) {
      return;
    }
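    // If either input is empty (e.g., the shared inner dimension is 0), the
    // product is still well-defined but every output element is an empty
    // sum, so just fill the output with zeros.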
    if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }

    // Compute parameters for the oneDNN matmul primitive.
    auto params = CreateMatMulParams(lhs.shape(), rhs.shape(), out_shape);
    // Create a matmul primitive, or retrieve a cached one.
    MklMatMulPrimitive<Scalar>* matmul_prim =
        MklMatMulPrimitiveFactory<Scalar>::Get(
            *params, false /* value for do_not_cache */);
    // Execute the matmul primitive.
    std::shared_ptr<stream> cpu_stream;
    cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
    matmul_prim->Execute(lhs.flat<Scalar>().data(), rhs.flat<Scalar>().data(),
                         out->flat<Scalar>().data(), cpu_stream);
  }

 private:
  bool adj_x_;
  bool adj_y_;
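  // Eigen-based BatchMatMulV2 kernel, used as a fallback when the output
  // rank exceeds the oneDNN limit of 12.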
  BatchMatMulV2Op<CPUDevice, Scalar> eigen_batch_mm_v2_;

  using dims = dnnl::memory::dims;

  // This method makes the rank (ndims) of the input the same as that of the
  // output by prepending new axes of size 1. For example, if the input shape
  // is [a, b, c, d] and the output shape is [e, f, g, h, i, j], the reshaped
  // input has shape [1, 1, a, b, c, d].
  void ExpandInputDimsToOutputShape(const TensorShape& input_shape,
                                    const TensorShape& output_shape,
                                    dims* reshaped_dims) {
    auto ndims_input = input_shape.dims();
    auto ndims_output = output_shape.dims();
    auto dim_offset = ndims_output - ndims_input;
    DCHECK(dim_offset > 0);
    reshaped_dims->clear();
    reshaped_dims->resize(ndims_output, 1);
    auto input_dims = input_shape.dim_sizes();
    for (int dim_idx = 0; dim_idx < ndims_input; ++dim_idx)
      reshaped_dims->at(dim_idx + dim_offset) = input_dims[dim_idx];
  }

  std::unique_ptr<MklMatMulParams> CreateMatMulParams(
      const TensorShape& lhs_shape, const TensorShape& rhs_shape,
      const TensorShape& out_shape) {
    const auto ndims_lhs = lhs_shape.dims();
    const auto ndims_rhs = rhs_shape.dims();
    const auto ndims_out = out_shape.dims();
    auto lhs_dims = TFShapeToMklDnnDims(lhs_shape);
    auto rhs_dims = TFShapeToMklDnnDims(rhs_shape);
    auto out_dims = TFShapeToMklDnnDims(out_shape);

    // The oneDNN matmul primitive requires the inputs and the output to have
    // the same rank, so create dnnl::memory::dims of equal rank for all
    // three. We rely on the MatMulBCast object creating output_batch_shape
    // as a conforming superset of the input batch shapes, i.e.,
    // ndims_out >= ndims_lhs and ndims_out >= ndims_rhs.
    if (ndims_lhs < ndims_out) {
      ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims);
    }
    if (ndims_rhs < ndims_out) {
      ExpandInputDimsToOutputShape(rhs_shape, out_shape, &rhs_dims);
    }

    using dim = dnnl::memory::dim;
    dim m;  // Number of rows in the (effective) lhs matrix.
    dim k;  // Number of columns in the (effective) lhs matrix.
    dim n;  // Number of columns in the (effective) rhs matrix.
    auto lhs_strides = CalculateTFStrides(lhs_dims);
    auto rhs_strides = CalculateTFStrides(rhs_dims);
    auto out_strides = CalculateTFStrides(out_dims);
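    // CalculateTFStrides returns row-major (TF layout) strides; e.g., for
    // dims {2, 3, 4} the strides are {12, 4, 1}.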
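
    // When an operand is adjointed, present it to oneDNN as transposed
    // without copying any data: swap its two innermost dims and adjust the
    // corresponding strides so the matrix is read in transposed order in
    // place.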
    if (adj_x_) {
      int m_idx = ndims_out - 1;
      int k_idx = ndims_out - 2;
      m = lhs_dims[m_idx];
      k = lhs_dims[k_idx];
      std::swap(lhs_dims[m_idx], lhs_dims[k_idx]);
      lhs_strides[m_idx] = m;
      lhs_strides[k_idx] = 1;
    }

    if (adj_y_) {
      int k_idx = ndims_out - 1;
      int n_idx = ndims_out - 2;
      k = rhs_dims[k_idx];
      n = rhs_dims[n_idx];
      std::swap(rhs_dims[k_idx], rhs_dims[n_idx]);
      rhs_strides[k_idx] = k;
      rhs_strides[n_idx] = 1;
    }
    return std::make_unique<MklMatMulParams>(
        lhs_dims, rhs_dims, out_dims, lhs_strides, rhs_strides, out_strides);
  }
};

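// Register the oneDNN-backed kernels for the rewritten (_Mkl*) BatchMatMul
// ops on CPU, for each supported data type.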
#define REGISTER_BATCH_MATMUL_MKL(TYPE)                                       \
  REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMul")                             \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<TYPE>("T")                      \
                              .Label(mkl_op_registry::kMklNameChangeOpLabel), \
                          BatchMatMulMkl<CPUDevice, TYPE, false>)

#define REGISTER_BATCH_MATMUL_MKL_V2(TYPE)                                    \
  REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMulV2")                           \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<TYPE>("T")                      \
                              .Label(mkl_op_registry::kMklNameChangeOpLabel), \
                          BatchMatMulMkl<CPUDevice, TYPE, true>)
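
// These registrations take effect only in builds with oneDNN enabled
// (e.g., --config=mkl).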
#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2);
#endif  // ENABLE_MKL

}  // end namespace tensorflow
#endif  // INTEL_MKL