1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 
18 #ifdef INTEL_MKL
19 #define EIGEN_USE_THREADS
20 
21 #include <numeric>
22 
23 #include "mkldnn.hpp"
24 #include "tensorflow/core/framework/numeric_op.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/lib/gtl/inlined_vector.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/util/mkl_util.h"
29 
30 using mkldnn::stream;
31 using mkldnn::sum;
32 
33 namespace tensorflow {
34 typedef Eigen::ThreadPoolDevice CPUDevice;
35 
36 template <typename Device, typename T>
37 class MklAddNOp : public OpKernel {
38  public:
~MklAddNOp()39   ~MklAddNOp() {}
MklAddNOp(OpKernelConstruction * context)40   explicit MklAddNOp(OpKernelConstruction* context) : OpKernel(context) {}
41 
GetTensorShape(OpKernelContext * ctx,size_t src_index)42   TensorShape GetTensorShape(OpKernelContext* ctx, size_t src_index) {
43     const Tensor& src_tensor = MklGetInput(ctx, src_index);
44     MklDnnShape src_mkl_shape;
45     GetMklShape(ctx, src_index, &src_mkl_shape);
46     return src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape()
47                                        : src_tensor.shape();
48   }
49 
CheckInputShape(OpKernelContext * ctx)50   bool CheckInputShape(OpKernelContext* ctx) {
51     const int num_inputs = ctx->num_inputs() / 2;
52     const TensorShape src0_shape = GetTensorShape(ctx, 0);
53 
54     for (size_t i = 1; i < num_inputs; ++i) {
55       if (!src0_shape.IsSameSize(GetTensorShape(ctx, i))) {
56         ctx->SetStatus(errors::InvalidArgument(
57             "Inputs to operation ", this->name(), " of type ",
58             this->type_string(),
59             " must have the same size and shape.  Input 0: ",
60             src0_shape.DebugString(), " != input : ", i,
61             GetTensorShape(ctx, i).DebugString()));
62 
63         return false;
64       }
65     }
66 
67     return true;
68   }
69 
70   // Return first tensor index which is in MKL layout, or -1 with no MKL input.
FindMKLInputIndex(OpKernelContext * ctx)71   int FindMKLInputIndex(OpKernelContext* ctx) {
72     int mkl_index = -1;
73     const int num_inputs = ctx->num_inputs() / 2;
74 
75     MklDnnShape src_mkl_shape;
76     for (size_t i = 0; i < num_inputs; ++i) {
77       GetMklShape(ctx, i, &src_mkl_shape);
78       if (src_mkl_shape.IsMklTensor()) {
79         mkl_index = i;
80         break;
81       }
82     }
83 
84     return mkl_index;
85   }
86 
ComputeScalar(OpKernelContext * ctx)87   void ComputeScalar(OpKernelContext* ctx) {
88     const int num_inputs = ctx->num_inputs() / 2;
89     const size_t kOutputIdx = 0;
90     TensorShape output_tf_shape;
91     MklDnnShape output_mkl_shape;
92     Tensor* dst_tensor = nullptr;
93 
94     T sum = static_cast<T>(0);
95     for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
96       const Tensor& src_tensor = MklGetInput(ctx, src_idx);
97       T* src_i = const_cast<T*>(src_tensor.flat<T>().data());
98       sum += src_i[0];
99     }
100 
101     output_mkl_shape.SetMklTensor(false);
102     output_tf_shape = MklGetInput(ctx, kOutputIdx).shape();
103     AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
104                               output_mkl_shape);
105 
106     T* out_o = dst_tensor->flat<T>().data();
107     out_o[0] = sum;
108   }
109 
Compute(OpKernelContext * ctx)110   void Compute(OpKernelContext* ctx) override {
111     // Each input tensor in MKL layout has additional meta-tensor carrying
112     // layout information. So the number of actual tensors is half the total
113     // number of inputs.
114     const int num_inputs = ctx->num_inputs() / 2;
115 
116     MklDnnShape mkl_shape;
117     const size_t kSrc0Idx = 0;
118     const size_t kOutputIdx = 0;
119 
120     if (num_inputs == 1) {
121       GetMklShape(ctx, kSrc0Idx, &mkl_shape);
122       bool input_in_mkl_format = mkl_shape.IsMklTensor();
123 
124       if (input_in_mkl_format) {
125         ForwardMklTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
126       } else {
127         ForwardTfTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
128       }
129       return;
130     }
131 
132     // Check if the input shape is same
133     if (!CheckInputShape(ctx)) return;
134 
135     try {
136       TensorShape output_tf_shape;
137       MklDnnShape output_mkl_shape;
138       const Tensor& src_tensor = MklGetInput(ctx, kSrc0Idx);
139 
140       Tensor* dst_tensor = nullptr;
141 
142       // Nothing to compute, return.
143       if (src_tensor.shape().num_elements() == 0) {
144         output_mkl_shape.SetMklTensor(false);
145         output_tf_shape = src_tensor.shape();
146         AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
147                                   output_mkl_shape);
148         return;
149       }
150 
151       if (src_tensor.dims() == 0) {
152         ComputeScalar(ctx);
153         return;
154       }
155 
156       auto cpu_engine = engine(engine::kind::cpu, 0);
157       std::vector<float> coeff(num_inputs, 1.0);
158       std::vector<memory::desc> srcs_pd;
159       std::vector<memory> inputs;
160 
161       MklDnnData<T> dst(&cpu_engine);
162       MklDnnData<T> src(&cpu_engine);
163       bool has_mkl_input = false;
164       int mkl_input_index = FindMKLInputIndex(ctx);
165       MklTensorFormat mkl_data_format;
166       TensorFormat tf_data_format;
167       memory::format_tag dnn_fmt = memory::format_tag::any;
168       if (mkl_input_index >= 0) {
169         has_mkl_input = true;
170         GetMklShape(ctx, mkl_input_index, &mkl_shape);
171         // MKL input has the data format information.
172         mkl_data_format = mkl_shape.GetTfDataFormat();
173         tf_data_format = MklDnnDataFormatToTFDataFormat(mkl_data_format);
174         dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
175       }
176 
177       std::shared_ptr<stream> fwd_cpu_stream;
178       fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine));
179 
180       // Create memory descriptor for MKL-DNN.
181       // If all input in Tensorflow format, create block memory descriptor,
182       // else convert TF format to MKL memory descriptor
183       for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
184         MklDnnShape src_mkl_shape;
185         GetMklShape(ctx, src_idx, &src_mkl_shape);
186         memory::desc md({}, memory::data_type::undef,
187                         memory::format_tag::undef);
188         const Tensor& src_tensor = MklGetInput(ctx, src_idx);
189 
190         if (src_mkl_shape.IsMklTensor()) {
191           md = src_mkl_shape.GetMklLayout();
192         } else {
193           if (has_mkl_input) {
194             memory::dims src_dims;
195             if (src_tensor.dims() == 4) {
196               src_dims =
197                   TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tf_data_format);
198             } else {
199               DCHECK(src_tensor.dims() == 5);
200               src_dims = TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
201                                                     tf_data_format);
202             }
203             md = memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
204           } else {
205             // Create block memory descriptor for TensorFlow format input.
206             auto dims = TFShapeToMklDnnDims(src_tensor.shape());
207             auto strides = CalculateTFStrides(dims);
208             md = MklDnnData<T>::CreateBlockedMemDesc(dims, strides);
209           }
210         }
211         srcs_pd.push_back(memory::desc(md));
212         src.SetUsrMem(md, &src_tensor);
213         src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
214         inputs.push_back(src.GetOpMem());
215       }
216 
217       auto sum_pd = sum::primitive_desc(coeff, srcs_pd, cpu_engine);
218       output_mkl_shape.SetMklTensor(has_mkl_input);
219       auto output_pd = sum_pd.dst_desc();
220       dst.SetUsrMem(output_pd);
221 
222       if (has_mkl_input) {
223         output_mkl_shape.SetMklLayout(&output_pd);
224         output_mkl_shape.SetElemType(MklDnnType<T>());
225         output_mkl_shape.SetTfLayout(mkl_shape.GetDimension(),
226                                      mkl_shape.GetSizesAsMklDnnDims(),
227                                      mkl_shape.GetTfDataFormat());
228         output_tf_shape.AddDim((output_pd.get_size() / sizeof(T)));
229       } else {
230         // All inputs have TF shapes, get the shape from first one.
231         output_tf_shape = MklGetInput(ctx, kSrc0Idx).shape();
232       }
233       AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
234                                 output_mkl_shape);
235       dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
236 
237       // Create Sum op, and submit net for execution.
238       std::vector<primitive> net;
239       mkldnn::sum sum_op(sum_pd);
240       std::unordered_map<int, memory> net_args = {
241           {MKLDNN_ARG_DST, dst.GetOpMem()}};
242       for (int i = 0; i < num_inputs; ++i) {
243         net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, inputs[i]});
244       }
245       sum_op.execute(*fwd_cpu_stream, net_args);
246     } catch (mkldnn::error& e) {
247       string error_msg = "Status: " + std::to_string(e.status) +
248                          ", message: " + string(e.message) + ", in file " +
249                          string(__FILE__) + ":" + std::to_string(__LINE__);
250       OP_REQUIRES_OK(
251           ctx, errors::Aborted("Operation received an exception:", error_msg));
252     }
253   }
254 };
255 
// Register the _MklAddN kernel for each supported element type (float and
// bfloat16). The kMklLayoutDependentOpLabel restricts these kernels to ops
// inserted by the MKL graph-rewrite pass, which also supplies the paired
// layout meta-tensors this kernel expects.
#define REGISTER_MKL_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklAddN")                                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklAddNOp<CPUDevice, T>);

TF_CALL_float(REGISTER_MKL_CPU);
TF_CALL_bfloat16(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU
267 }  // namespace tensorflow
268 #endif  // INTEL_MKL
269