/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifdef INTEL_MKL
#define EIGEN_USE_THREADS

#include <numeric>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::stream;
using mkldnn::sum;

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class MklAddNOp : public OpKernel {
 public:
  ~MklAddNOp() {}
  explicit MklAddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  // Returns the TF shape of the input at `src_index`, regardless of whether
  // the input is in MKL or TensorFlow layout.
  TensorShape GetTensorShape(OpKernelContext* ctx, size_t src_index) {
    const Tensor& src_tensor = MklGetInput(ctx, src_index);
    MklDnnShape src_mkl_shape;
    GetMklShape(ctx, src_index, &src_mkl_shape);
    return src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape()
                                       : src_tensor.shape();
  }

  // Returns true if all inputs have the same shape; otherwise sets an
  // InvalidArgument status on the context and returns false.
  bool CheckInputShape(OpKernelContext* ctx) {
    const int num_inputs = ctx->num_inputs() / 2;
    const TensorShape src0_shape = GetTensorShape(ctx, 0);

    for (int i = 1; i < num_inputs; ++i) {
      if (!src0_shape.IsSameSize(GetTensorShape(ctx, i))) {
        ctx->SetStatus(errors::InvalidArgument(
            "Inputs to operation ", this->name(), " of type ",
            this->type_string(),
            " must have the same size and shape. Input 0: ",
            src0_shape.DebugString(), " != input ", i, ": ",
            GetTensorShape(ctx, i).DebugString()));
        return false;
      }
    }

    return true;
  }

  // Returns the index of the first input tensor that is in MKL layout, or -1
  // if no input is in MKL layout.
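  // Since AddN is elementwise and all input shapes have already been checked
  // to match, the first MKL-layout input found can serve as the reference
  // for the output's layout metadata and data format.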
  int FindMKLInputIndex(OpKernelContext* ctx) {
    int mkl_index = -1;
    const int num_inputs = ctx->num_inputs() / 2;

    MklDnnShape src_mkl_shape;
    for (int i = 0; i < num_inputs; ++i) {
      GetMklShape(ctx, i, &src_mkl_shape);
      if (src_mkl_shape.IsMklTensor()) {
        mkl_index = i;
        break;
      }
    }

    return mkl_index;
  }

  // Sums 0-D (scalar) inputs directly, without going through MKL-DNN.
  void ComputeScalar(OpKernelContext* ctx) {
    const int num_inputs = ctx->num_inputs() / 2;
    const size_t kSrc0Idx = 0;
    const size_t kOutputIdx = 0;
    TensorShape output_tf_shape;
    MklDnnShape output_mkl_shape;
    Tensor* dst_tensor = nullptr;

    T sum = static_cast<T>(0);
    for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
      const Tensor& src_tensor = MklGetInput(ctx, src_idx);
      const T* src_i = src_tensor.flat<T>().data();
      sum += src_i[0];
    }

    output_mkl_shape.SetMklTensor(false);
    output_tf_shape = MklGetInput(ctx, kSrc0Idx).shape();
    AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
                              output_mkl_shape);

    T* out_o = dst_tensor->flat<T>().data();
    out_o[0] = sum;
  }

  void Compute(OpKernelContext* ctx) override {
    // Each input tensor in MKL layout is accompanied by a meta-tensor that
    // carries its layout information, so the number of actual data tensors
    // is half the total number of inputs.
    const int num_inputs = ctx->num_inputs() / 2;

    MklDnnShape mkl_shape;
    const size_t kSrc0Idx = 0;
    const size_t kOutputIdx = 0;

    // With a single input there is nothing to sum; forward it to the output.
    if (num_inputs == 1) {
      GetMklShape(ctx, kSrc0Idx, &mkl_shape);
      bool input_in_mkl_format = mkl_shape.IsMklTensor();

      if (input_in_mkl_format) {
        ForwardMklTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
      } else {
        ForwardTfTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
      }
      return;
    }

    // Check that all inputs have the same shape.
    if (!CheckInputShape(ctx)) return;

    try {
      TensorShape output_tf_shape;
      MklDnnShape output_mkl_shape;
      const Tensor& src_tensor = MklGetInput(ctx, kSrc0Idx);

      Tensor* dst_tensor = nullptr;

      // Nothing to compute for an empty tensor; just allocate the output.
      if (src_tensor.shape().num_elements() == 0) {
        output_mkl_shape.SetMklTensor(false);
        output_tf_shape = src_tensor.shape();
        AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor,
                                  output_tf_shape, output_mkl_shape);
        return;
      }

      // Scalar inputs are summed directly.
      if (src_tensor.dims() == 0) {
        ComputeScalar(ctx);
        return;
      }

      auto cpu_engine = engine(engine::kind::cpu, 0);
      std::vector<float> coeff(num_inputs, 1.0);
      std::vector<memory::desc> srcs_pd;
      std::vector<memory> inputs;

      MklDnnData<T> dst(&cpu_engine);
      MklDnnData<T> src(&cpu_engine);
      bool has_mkl_input = false;
      int mkl_input_index = FindMKLInputIndex(ctx);
      MklTensorFormat mkl_data_format;
      TensorFormat tf_data_format;
      memory::format_tag dnn_fmt = memory::format_tag::any;
      if (mkl_input_index >= 0) {
        has_mkl_input = true;
        GetMklShape(ctx, mkl_input_index, &mkl_shape);
        // The MKL input carries the data format information.
        mkl_data_format = mkl_shape.GetTfDataFormat();
        tf_data_format = MklDnnDataFormatToTFDataFormat(mkl_data_format);
        dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
      }

      std::shared_ptr<stream> fwd_cpu_stream;
      fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine));
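
      // Note: a single MklDnnData<T> object (`src`) is rebound to each input
      // in the loop below; the memory object it yields via GetOpMem() is
      // copied into `inputs`, so rebinding does not invalidate earlier
      // entries.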
      // Create memory descriptors for MKL-DNN. If all inputs are in
      // TensorFlow format, create blocked memory descriptors; otherwise
      // convert the TF layout to an MKL-DNN memory descriptor.
      for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
        MklDnnShape src_mkl_shape;
        GetMklShape(ctx, src_idx, &src_mkl_shape);
        memory::desc md({}, memory::data_type::undef,
                        memory::format_tag::undef);
        const Tensor& src_tensor = MklGetInput(ctx, src_idx);

        if (src_mkl_shape.IsMklTensor()) {
          // Input is already in MKL layout; use its stored layout.
          md = src_mkl_shape.GetMklLayout();
        } else {
          if (has_mkl_input) {
            // Mixed layouts: describe this TF-layout input in the data
            // format of the reference MKL input.
            memory::dims src_dims;
            if (src_tensor.dims() == 4) {
              src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
                                                   tf_data_format);
            } else {
              DCHECK_EQ(src_tensor.dims(), 5);
              src_dims = TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
                                                    tf_data_format);
            }
            md = memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
          } else {
            // Create a blocked (row-major) memory descriptor for the
            // TensorFlow-format input.
            auto dims = TFShapeToMklDnnDims(src_tensor.shape());
            auto strides = CalculateTFStrides(dims);
            md = MklDnnData<T>::CreateBlockedMemDesc(dims, strides);
          }
        }
        srcs_pd.push_back(memory::desc(md));
        src.SetUsrMem(md, &src_tensor);
        src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
        inputs.push_back(src.GetOpMem());
      }

      auto sum_pd = sum::primitive_desc(coeff, srcs_pd, cpu_engine);
      output_mkl_shape.SetMklTensor(has_mkl_input);
      auto output_pd = sum_pd.dst_desc();
      dst.SetUsrMem(output_pd);

      if (has_mkl_input) {
        output_mkl_shape.SetMklLayout(&output_pd);
        output_mkl_shape.SetElemType(MklDnnType<T>());
        output_mkl_shape.SetTfLayout(mkl_shape.GetDimension(),
                                     mkl_shape.GetSizesAsMklDnnDims(),
                                     mkl_shape.GetTfDataFormat());
        output_tf_shape.AddDim(output_pd.get_size() / sizeof(T));
      } else {
        // All inputs have TF shapes; take the output shape from the first.
        output_tf_shape = MklGetInput(ctx, kSrc0Idx).shape();
      }
      AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
                                output_mkl_shape);
      dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);

      // Create the Sum primitive and submit it for execution.
      mkldnn::sum sum_op(sum_pd);
      std::unordered_map<int, memory> net_args = {
          {MKLDNN_ARG_DST, dst.GetOpMem()}};
      for (int i = 0; i < num_inputs; ++i) {
        net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, inputs[i]});
      }
      sum_op.execute(*fwd_cpu_stream, net_args);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          ctx, errors::Aborted("Operation received an exception:", error_msg));
    }
  }
};

#define REGISTER_MKL_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklAddN")                                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklAddNOp<CPUDevice, T>);

TF_CALL_float(REGISTER_MKL_CPU);
TF_CALL_bfloat16(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU

}  // namespace tensorflow
#endif  // INTEL_MKL