/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"

#define GET_FLAG(bn_flag) static_cast<int>(mkldnn::normalization_flags::bn_flag)
#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))

using mkldnn::batch_normalization_backward;
using mkldnn::batch_normalization_forward;
using mkldnn::prop_kind;
using mkldnn::stream;

using BatchNormFwdPd = mkldnn::batch_normalization_forward::primitive_desc;
using BatchNormBwdPd = mkldnn::batch_normalization_backward::primitive_desc;

namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;

using FusedBNActivationMode = functor::FusedBatchNormActivationMode;

struct MklBatchNormFwdParams {
  memory::dims src_dims;
  int depth;
  float eps;
  bool training;
  FusedBNActivationMode activation_mode;
  memory::desc src_md;

  MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
                        bool training, memory::desc src_md,
                        FusedBNActivationMode activation_mode)
      : src_dims(src_dims),
        depth(depth),
        eps(eps),
        training(training),
        activation_mode(activation_mode),
        src_md(src_md) {}
};

template <typename T, typename U>
class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
 public:
  explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    if (context_.bn_fwd == nullptr) Setup(fwdParams);
  }

  ~MklFusedBatchNormFwdPrimitive() {}

  // BatchNormalization forward execute
  // src_data: input data buffer of src
  // weights_data: input data buffer of weights
  // dst_data: output data buffer of dst
  // mean_data: output data buffer of means
  // variance_data: output data buffer of variances
  void Execute(const T* src_data, const U* weights_data, T* dst_data,
               U* mean_data, U* variance_data,
               std::shared_ptr<stream> fwd_stream, U* workspace_data) {
#ifdef ENABLE_MKLDNN_THREADPOOL
    // TODO: Create a common function and avoid the duplicate code
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_stream);

    if (IS_SET(use_scale_shift))
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)), *fwd_stream);

    if ((context_.pkind == prop_kind::forward_training) ||
        (IS_SET(use_global_stats))) {
      context_.mean_mem->set_data_handle(static_cast<void*>(mean_data),
                                         *fwd_stream);
      context_.variance_mem->set_data_handle(static_cast<void*>(variance_data),
                                             *fwd_stream);
    }
    if (workspace_data != nullptr) {
      context_.ws_mem->set_data_handle(workspace_data, *fwd_stream);
    }
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));

    if (IS_SET(use_scale_shift))
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)));

    if ((context_.pkind == prop_kind::forward_training) ||
        (IS_SET(use_global_stats))) {
      context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
      context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
    }
    if (workspace_data != nullptr) {
      context_.ws_mem->set_data_handle(workspace_data);
    }
#endif  // ENABLE_MKLDNN_THREADPOOL

    // Execute batch-normalization forward primitives.
    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);

    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);

    if (IS_SET(use_scale_shift))
      context_.weights_mem->set_data_handle(DummyData);

    if ((context_.pkind == prop_kind::forward_training) ||
        (IS_SET(use_global_stats))) {
      context_.mean_mem->set_data_handle(DummyData);
      context_.variance_mem->set_data_handle(DummyData);
    }

    if (workspace_data != nullptr) {
      context_.ws_mem->set_data_handle(DummyData);
    }
  }

  memory::desc GetDstPd() const { return context_.dst_mem->get_desc(); }

  std::shared_ptr<BatchNormFwdPd> GetBatchNormFwdPd() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for BatchNorm forward op.
  struct BatchNormFwdContext {
    // Flags indicating if it is training or inference mode.
    int64 flags;

    // Algorithm kind.
    mkldnn::prop_kind pkind;

    // Inputs/outputs memory.
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> weights_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;
    std::shared_ptr<mkldnn::memory> mean_mem;
    std::shared_ptr<mkldnn::memory> variance_mem;
    std::shared_ptr<mkldnn::memory> ws_mem;

    // Forward BatchNorm primitive descriptor.
    std::shared_ptr<BatchNormFwdPd> fwd_pd;

    // BatchNorm forward primitive.
    std::shared_ptr<mkldnn::primitive> bn_fwd;
    std::vector<mkldnn::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    BatchNormFwdContext()
        : flags(0),
          pkind(prop_kind::forward_training),
          src_mem(nullptr),
          weights_mem(nullptr),
          dst_mem(nullptr),
          mean_mem(nullptr),
          variance_mem(nullptr),
          ws_mem(nullptr),
          bn_fwd(nullptr) {}
  };

  void Setup(const MklBatchNormFwdParams& fwdParams) {
    context_.flags =
        fwdParams.training
            ? GET_FLAG(use_scale_shift)
            : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
    context_.pkind = fwdParams.training ? prop_kind::forward_training
                                        : prop_kind::forward_scoring;

    if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
      context_.flags |= GET_FLAG(fuse_norm_relu);
    }
    // Memory descriptor
    auto src_md = fwdParams.src_md;
    // Create forward BatchNorm descriptor and primitive descriptor.
    auto fwd_desc = batch_normalization_forward::desc(
        context_.pkind, src_md, fwdParams.eps,
        static_cast<mkldnn::normalization_flags>(context_.flags));

    context_.fwd_pd.reset(new BatchNormFwdPd(fwd_desc, cpu_engine_));

    // Create memory primitive based on dummy data
    context_.src_mem.reset(
        new memory(context_.fwd_pd->src_desc(), cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd->dst_desc(), cpu_engine_, DummyData));

    memory::dims s_dims = {2, fwdParams.depth};
    memory::dims m_dims = {1, fwdParams.depth};
    if (IS_SET(use_scale_shift)) {
      context_.weights_mem.reset(
          new memory({{s_dims}, MklDnnType<U>(), memory::format_tag::nc},
                     cpu_engine_, DummyData));
    }

    if (fwdParams.training || (IS_SET(use_global_stats))) {
      context_.mean_mem.reset(
          new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
                     cpu_engine_, DummyData));

      context_.variance_mem.reset(
          new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
                     cpu_engine_, DummyData));
    }

    if (IS_SET(fuse_norm_relu)) {
      context_.ws_mem.reset(new memory(context_.fwd_pd->workspace_desc(),
                                       cpu_engine_, DummyData));
    }

    // BatchNorm forward primitive.
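    // The execution arguments depend on the mode: plain inference binds only
    // src/dst (plus weights when scale/shift is used); inference with global
    // stats and training additionally bind mean and variance; and, when ReLU
    // is fused, a workspace is bound as well. Hence the branching below.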
    // TODO(intel-tf): Merge all the #ifdefs and simplify code
    if (!fwdParams.training && !(IS_SET(use_global_stats))) {
      if ((IS_SET(use_scale_shift)) && mkldnn_use_scaleshift) {
        context_.net_args.push_back(
            {{MKLDNN_ARG_SRC, *context_.src_mem},
             {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
             {MKLDNN_ARG_DST, *context_.dst_mem}});
      } else {
        context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
                                     {MKLDNN_ARG_DST, *context_.dst_mem}});
      }
      context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
    } else if (IS_SET(use_global_stats)) {
      if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem}});
        }
      } else {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem}});
        }
      }
      context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
    } else {
      if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem}});
        }
      } else {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem}});
        }
      }
      context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
    }

    context_.fwd_primitives.push_back(*context_.bn_fwd);
  }

  struct BatchNormFwdContext context_;
};

template <typename T, typename U>
class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
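  // Returns the cached forward primitive matching |fwdParams|, creating and
  // caching a new one on the first request.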
  static MklFusedBatchNormFwdPrimitive<T, U>* Get(
      const MklBatchNormFwdParams& fwdParams) {
    auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T, U>*>(
        MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance()
            .GetBatchNormFwd(fwdParams));

    if (bn_fwd == nullptr) {
      bn_fwd = new MklFusedBatchNormFwdPrimitive<T, U>(fwdParams);
      MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormFwd(
          fwdParams, bn_fwd);
    }
    return bn_fwd;
  }

  static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() {
    static MklFusedBatchNormFwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklFusedBatchNormFwdPrimitiveFactory() {}
  ~MklFusedBatchNormFwdPrimitiveFactory() {}

  static string CreateKey(const MklBatchNormFwdParams& fwdParams) {
    string prefix = "bn_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey<int>(fwdParams.depth);
    key_creator.AddAsKey<float>(fwdParams.eps);
    key_creator.AddAsKey<bool>(fwdParams.training);
    key_creator.AddAsKey<FusedBNActivationMode>(fwdParams.activation_mode);
    key_creator.AddAsKey(typeid(T).name());
    key_creator.AddAsKey(typeid(U).name());
    return key_creator.GetKey();
  }

  MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams,
                       MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};

struct MklBatchNormBwdParams {
  memory::dims src_dims;
  memory::dims diff_dst_dims;
  int depth;
  float eps;
  bool training;

  memory::desc src_md;
  memory::desc diff_dst_md;

  MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
                        int depth, float eps, bool training,
                        memory::desc src_md, memory::desc diff_dst_md)
      : src_dims(src_dims),
        diff_dst_dims(diff_dst_dims),
        depth(depth),
        eps(eps),
        training(training),
        src_md(src_md),
        diff_dst_md(diff_dst_md) {}
};

template <typename T, typename U>
class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
 public:
  explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    if (context_.bn_bwd == nullptr) Setup(bwdParams);
  }

  ~MklFusedBatchNormBwdPrimitive() {}

  // BatchNormalization backward execute
  // src_data: input data buffer of src
  // mean_data: input data buffer of mean
  // variance_data: input data buffer of variance
  // diff_dst_data: input data buffer of diff_dst
  // weights_data: input data buffer of weights
  // diff_src_data: output data buffer of diff_src
  // diff_weights_data: output data buffer of diff_weights
  // res_space_data: output data buffer of reserved_space_3.
  // TODO: reserved_space_3 (temporary memory to hold intermediate results)
  // is not implemented on CPU as of now.
  void Execute(const T* src_data, const U* mean_data, const U* variance_data,
               const T* diff_dst_data, const U* weights_data, T* diff_src_data,
               U* diff_weights_data, U* res_space_data,
               std::shared_ptr<stream> bwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
    // TODO: Create a common function and avoid the duplicate code
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
    context_.mean_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(mean_data)), *bwd_stream);
    context_.variance_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(variance_data)), *bwd_stream);
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);

    if (IS_SET(use_scale_shift)) {
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)), *bwd_stream);
      context_.diff_weights_mem->set_data_handle(
          static_cast<void*>(diff_weights_data), *bwd_stream);
    }

    context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
                                           *bwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.mean_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(mean_data)));
    context_.variance_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(variance_data)));
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)));

    if (IS_SET(use_scale_shift)) {
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)));
      context_.diff_weights_mem->set_data_handle(
          static_cast<void*>(diff_weights_data));
    }

    context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
#endif  // ENABLE_MKLDNN_THREADPOOL
    // Execute backward batch-normalization primitives.
    DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
    execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);

    // After execution, set data handle back to DummyData.
    context_.src_mem->set_data_handle(DummyData);
    context_.mean_mem->set_data_handle(DummyData);
    context_.variance_mem->set_data_handle(DummyData);
    context_.diff_dst_mem->set_data_handle(DummyData);
    if (IS_SET(use_scale_shift)) {
      context_.weights_mem->set_data_handle(DummyData);
      context_.diff_weights_mem->set_data_handle(DummyData);
    }
    context_.diff_src_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<BatchNormBwdPd> GetBatchNormBwdPd() const {
    return context_.bwd_pd;
  }

  memory::desc GetDiffSrcPd() { return context_.diff_src_mem->get_desc(); }

 private:
  struct BatchNormBwdContext {
    // Flags to indicate whether it is training or inference.
    int64 flags;

    // Inputs/outputs memory.
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> mean_mem;
    std::shared_ptr<mkldnn::memory> variance_mem;
    std::shared_ptr<mkldnn::memory> diff_dst_mem;
    std::shared_ptr<mkldnn::memory> weights_mem;
    std::shared_ptr<mkldnn::memory> diff_weights_mem;
    std::shared_ptr<mkldnn::memory> diff_src_mem;

    // Backward batch-normalization primitive descriptor.
    std::shared_ptr<BatchNormBwdPd> bwd_pd;

    // Backward batch-normalization primitive.
    std::shared_ptr<mkldnn::primitive> bn_bwd;
    std::vector<mkldnn::primitive> bwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    BatchNormBwdContext()
        : src_mem(nullptr),
          mean_mem(nullptr),
          variance_mem(nullptr),
          diff_dst_mem(nullptr),
          weights_mem(nullptr),
          diff_weights_mem(nullptr),
          diff_src_mem(nullptr) {}
  };

  void Setup(const MklBatchNormBwdParams& bwdParams) {
    context_.flags =
        bwdParams.training
            ? GET_FLAG(use_scale_shift)
            : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));

    // Memory descriptors.
    auto src_md = bwdParams.src_md;
    auto diff_dst_md = bwdParams.diff_dst_md;
    auto variance_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(),
                                      memory::format_tag::nc);
    auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(),
                                  memory::format_tag::nc);
    auto weights_desc = memory::desc({2, bwdParams.depth}, MklDnnType<U>(),
                                     memory::format_tag::nc);
    auto diff_weights_desc = weights_desc;

    // Forward batch-normalization descriptor and primitive descriptor.
    // The flags are recomputed here as mkldnn::normalization_flags because
    // context_.flags is stored as an int64.
    auto bn_flags = bwdParams.training
                        ? mkldnn::normalization_flags::use_scale_shift
                        : (mkldnn::normalization_flags::use_scale_shift |
                           mkldnn::normalization_flags::use_global_stats);
    auto fwd_desc = batch_normalization_forward::desc(
        prop_kind::forward_training, src_md, bwdParams.eps, bn_flags);
    auto fwd_pd = BatchNormFwdPd(fwd_desc, cpu_engine_);

    // Backward batch-normalization primitive.
    // For inference, specify use_global_stats so that:
    //   1. on fwd propagation, the mean and variance provided as inputs are
    //      used;
    //   2. on bwd propagation, mean and variance are treated as constants.
    // This reduces the amount of MKL computation.
    auto bwd_desc = batch_normalization_backward::desc(
        prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, bn_flags);
    context_.bwd_pd.reset(new BatchNormBwdPd(bwd_desc, cpu_engine_, fwd_pd));

    // Create memory primitives.
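    // All memory objects are created with DummyData placeholders; Execute()
    // attaches the real buffers via set_data_handle() and resets the handles
    // afterwards.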
    context_.src_mem.reset(new memory(src_md, cpu_engine_, DummyData));
    context_.diff_dst_mem.reset(
        new memory(diff_dst_md, cpu_engine_, DummyData));
    context_.variance_mem.reset(
        new memory(variance_desc, cpu_engine_, DummyData));
    context_.mean_mem.reset(new memory(mean_desc, cpu_engine_, DummyData));
    context_.weights_mem.reset(
        new memory(weights_desc, cpu_engine_, DummyData));
    context_.diff_weights_mem.reset(
        new memory(diff_weights_desc, cpu_engine_, DummyData));
    context_.diff_src_mem.reset(new memory(src_md, cpu_engine_, DummyData));

    context_.bn_bwd.reset(new batch_normalization_backward(*context_.bwd_pd));
    context_.net_args.push_back(
        {{MKLDNN_ARG_SRC, *context_.src_mem},
         {MKLDNN_ARG_MEAN, *context_.mean_mem},
         {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
         {MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
         {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
         {MKLDNN_ARG_DIFF_SRC, *context_.diff_src_mem},
         {MKLDNN_ARG_DIFF_WEIGHTS, *context_.diff_weights_mem}});
    context_.bwd_primitives.push_back(*context_.bn_bwd);
  }

  struct BatchNormBwdContext context_;
};

template <typename T, typename U>
class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklFusedBatchNormBwdPrimitive<T, U>* Get(
      const MklBatchNormBwdParams& bwdParams) {
    auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T, U>*>(
        MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance()
            .GetBatchNormBwd(bwdParams));
    if (bn_bwd == nullptr) {
      bn_bwd = new MklFusedBatchNormBwdPrimitive<T, U>(bwdParams);
      MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormBwd(
          bwdParams, bn_bwd);
    }
    return bn_bwd;
  }

  static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() {
    static MklFusedBatchNormBwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklFusedBatchNormBwdPrimitiveFactory() {}
  ~MklFusedBatchNormBwdPrimitiveFactory() {}

  static string CreateKey(const MklBatchNormBwdParams& bwdParams) {
    string prefix = "bn_bwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(bwdParams.src_dims);
    key_creator.AddAsKey(bwdParams.diff_dst_dims);
    key_creator.AddAsKey<int>(bwdParams.depth);
    key_creator.AddAsKey<float>(bwdParams.eps);
    key_creator.AddAsKey<bool>(bwdParams.training);
    key_creator.AddAsKey(typeid(T).name());
    key_creator.AddAsKey(typeid(U).name());
    return key_creator.GetKey();
  }

  MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) {
    string key = CreateKey(bwdParams);
    return this->GetOp(key);
  }

  void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams,
                       MklPrimitive* op) {
    string key = CreateKey(bwdParams);
    this->SetOp(key, op);
  }
};

// A third template parameter is added to support FusedBatchNormV3 with MKL.
// This differs from the default implementation, where the classes are
// derived; it moves the decision to compile time rather than runtime.
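// Template parameters:
//   reserved_space   - allocate the reserved_space_3 output (FusedBatchNormV3).
//   is_batch_norm_ex - parse and fuse the activation (_MklFusedBatchNormEx).
//   native_format    - operate on plain TF layouts instead of MKL layouts.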
template <typename Device, typename T, typename U, bool reserved_space,
          bool is_batch_norm_ex = false, bool native_format = false>
class MklFusedBatchNormOp : public OpKernel {
 public:
  explicit MklFusedBatchNormOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = epsilon;
    float exponential_avg_factor;
    OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
                                             &exponential_avg_factor));
    exponential_avg_factor_ = static_cast<U>(exponential_avg_factor);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
    depth_ = 0;
    mean_values_ = nullptr;
    variance_values_ = nullptr;

    if (!is_batch_norm_ex) {
      activation_mode_ = FusedBNActivationMode::kIdentity;
    } else {
      int num_side_inputs;
      OP_REQUIRES_OK(context,
                     context->GetAttr("num_side_inputs", &num_side_inputs));
      // Currently, _MKLFusedBatchNormEx does not support "SideInput".
      OP_REQUIRES(context, num_side_inputs == 0,
                  errors::InvalidArgument(
                      "_MKLFusedBatchNorm does not support side input now."));

      OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
      OP_REQUIRES(context, activation_mode_ == FusedBNActivationMode::kRelu,
                  errors::InvalidArgument(
                      "_MKLFusedBatchNorm only supports Relu activation"));
    }
  }

  void Compute(OpKernelContext* context) override {
    try {
      const size_t kSrcIndex = 0;       // index of src input tensor
      const size_t kScaleIndex = 1;     // index of scale tensor
      const size_t kShiftIndex = 2;     // index of shift tensor
      const size_t kMeanIndex = 3;      // index of est_mean tensor
      const size_t kVarianceIndex = 4;  // index of est_variance tensor

      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
      const Tensor& shift_tensor = MklGetInput(context, kShiftIndex);
      const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex);
      const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex);

      TensorShape tf_shape_src;
      MklDnnShape dnn_shape_src;
      GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format);

      if (dnn_shape_src.IsMklTensor()) {
        tf_shape_src = dnn_shape_src.GetTfShape();
        OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      } else {
        tf_shape_src = src_tensor.shape();
        OP_REQUIRES(context, src_tensor.dims() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      }
      OP_REQUIRES(context, scale_tensor.dims() == 1,
                  errors::InvalidArgument("scale must be 1-dimensional",
                                          scale_tensor.shape().DebugString()));
      OP_REQUIRES(context, shift_tensor.dims() == 1,
                  errors::InvalidArgument("offset must be 1-dimensional",
                                          shift_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, est_mean_tensor.dims() == 1,
          errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                  est_mean_tensor.shape().DebugString()));

      OP_REQUIRES(
          context, est_variance_tensor.dims() == 1,
          errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                  est_variance_tensor.shape().DebugString()));

      // Handle the special case: input with 0 elements and 0 batch size.
      Tensor* dst_tensor = nullptr;
      TensorShape workspace_tf_shape;
      if (tf_shape_src.num_elements() == 0) {
        size_t workspace_bytes = 0;
        workspace_tf_shape.AddDim(workspace_bytes);
        HandleEmptyInput(context, tf_shape_src, workspace_tf_shape,
                         scale_tensor.shape(), &dst_tensor);
        return;
      }

      if (dnn_shape_src.IsMklTensor())
        depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      else
        ExtractParams(context);

      // Index of output tensor (dst).
      const size_t kDstIndex = 0;

      // Allocate 5 output TF tensors.
      Tensor* batch_mean_tensor = nullptr;
      Tensor* batch_variance_tensor = nullptr;
      Tensor* saved_mean_tensor = nullptr;
      Tensor* saved_variance_tensor = nullptr;
      Tensor* reserved_space_tensor = nullptr;

      MklDnnData<T> src(&cpu_engine_);
      MklDnnData<U> weights(&cpu_engine_);
      MklDnnData<U> wksp(&cpu_engine_);

      memory::format_tag dnn_fmt;
      MklTensorFormat mkl_tensor_fmt;
      if (dnn_shape_src.IsMklTensor()) {
        if (dnn_shape_src.IsTensorInNCHWFormat()) {
          dnn_fmt = memory::format_tag::nchw;
          mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW;
        } else {
          dnn_fmt = memory::format_tag::nhwc;
          mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC;
        }
      } else {
        mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_);
        dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt);
      }

      // Set src memory descriptor.
      memory::dims src_dims =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);

      auto src_md = dnn_shape_src.IsMklTensor()
                        ? dnn_shape_src.GetMklLayout()
                        : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);

      MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
                                      src_md, activation_mode_);

      // Get forward batch-normalization op from the primitive caching pool.
      MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
          MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);

      // Allocate workspace tensor
      U* ws_data = nullptr;
      if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
        memory::desc workspace_md =
            bn_fwd->GetBatchNormFwdPd()->workspace_desc();
        size_t workspace_bytes = workspace_md.get_size();
        workspace_tf_shape.AddDim(workspace_bytes);

        AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
                          &batch_mean_tensor, &batch_variance_tensor,
                          &saved_mean_tensor, &saved_variance_tensor,
                          &reserved_space_tensor);
        if (reserved_space) {
          wksp.SetUsrMem(workspace_md, reserved_space_tensor);
          ws_data = static_cast<U*>(wksp.GetOpMem().get_data_handle());
        }
      } else {
        // There is actually no workspace tensor out, so we make a dummy one.
        size_t workspace_bytes = 0;
        workspace_tf_shape.AddDim(workspace_bytes);
        AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
                          &batch_mean_tensor, &batch_variance_tensor,
                          &saved_mean_tensor, &saved_variance_tensor,
                          &reserved_space_tensor);
      }

      if (is_training_)
        SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
      else
        SetMeanVariance(est_mean_tensor, est_variance_tensor);

      // MKL-DNN packs scale & shift as "weights":
      // <scale>...<scale><shift>...<shift>
      weights.AllocateBuffer(2 * depth_ * sizeof(U));
      U* weights_data = reinterpret_cast<U*>(weights.GetAllocatedBuffer());
      const U* scale_tf = scale_tensor.flat<U>().data();
      const U* shift_tf = shift_tensor.flat<U>().data();

      std::memcpy(weights_data, scale_tf, depth_ * sizeof(U));
      std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(U));
      char* saved_mean_data_tf =
          reinterpret_cast<char*>(saved_mean_tensor->flat<U>().data());
      std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_),
                  depth_ * sizeof(U));

      char* saved_variance_data_tf =
          reinterpret_cast<char*>(saved_variance_tensor->flat<U>().data());
      std::memcpy(saved_variance_data_tf,
                  reinterpret_cast<char*>(variance_values_),
                  depth_ * sizeof(U));

      // Check if reorder is needed for src.
      const T* src_data = nullptr;
      std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
      if (!native_format && src_md != bn_fwd_pd->src_desc()) {
        src.SetUsrMem(src_md, &src_tensor);
        src.CheckReorderToOpMem(bn_fwd_pd->src_desc(), cpu_engine_, context);
        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
      } else {
        src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
      }

      // Allocate output (dst) tensor
      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      dnn_shape_dst.SetMklTensor(true);
      auto dst_pd = bn_fwd->GetDstPd();
      dnn_shape_dst.SetMklLayout(&dst_pd);
      dnn_shape_dst.SetElemType(MklDnnType<T>());
      auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension()
                                               : src_tensor.shape().dims();
      dnn_shape_dst.SetTfLayout(ndims, src_dims, mkl_tensor_fmt);
      tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
      if (native_format) {
        tf_shape_dst = dnn_shape_dst.GetTfShape();
      }
      AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst,
                                dnn_shape_dst, native_format);

      U* weights_op_data = weights_data;
      U* mean_op_data = saved_mean_tensor->flat<U>().data();
      U* variance_op_data = saved_variance_tensor->flat<U>().data();
      T* dst_data = dst_tensor->flat<T>().data();

      // Execute
      std::shared_ptr<stream> fwd_cpu_stream;
      fwd_cpu_stream.reset(CreateStream(context, bn_fwd->GetEngine()));
      bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
                      variance_op_data, fwd_cpu_stream, ws_data);

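      // oneDNN returns the population variance (divisor N = batch * height *
      // width); the batch_variance output below applies Bessel's correction
      // by scaling it with N / (N - 1).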
      float adjust_factor = 1.0;
      if (is_training_) {
        size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
        size_t adjust_size = (orig_size > 1) ? (orig_size - 1) : 1;
        adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
      }

      auto mean_data = reinterpret_cast<U*>(saved_mean_data_tf);
      auto variance_data = reinterpret_cast<U*>(saved_variance_data_tf);
      auto batch_mean_data = batch_mean_tensor->flat<U>().data();
      auto batch_variance_data = batch_variance_tensor->flat<U>().data();
      auto est_mean_data = est_mean_tensor.flat<U>().data();
      auto est_variance_data = est_variance_tensor.flat<U>().data();
      if (is_training_) {
        if (exponential_avg_factor_ == U(1.0)) {
          for (int k = 0; k < depth_; k++) {
            batch_mean_data[k] = mean_data[k];
            batch_variance_data[k] =
                static_cast<U>(adjust_factor) * variance_data[k];
          }
        } else {
          U one_minus_factor = U(1.0) - exponential_avg_factor_;
          for (int k = 0; k < depth_; k++) {
            batch_mean_data[k] = one_minus_factor * est_mean_data[k] +
                                 exponential_avg_factor_ * mean_data[k];
            batch_variance_data[k] = one_minus_factor * est_variance_data[k] +
                                     exponential_avg_factor_ *
                                         static_cast<U>(adjust_factor) *
                                         variance_data[k];
          }
        }
      } else {
        std::memcpy(batch_mean_data, mean_data, depth_ * sizeof(U));
        std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(U));
      }
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  float epsilon_;
  U exponential_avg_factor_;
  TensorFormat tensor_format_;
  bool is_training_;
  U* mean_values_;
  U* variance_values_;
  size_t depth_;  // Batch normalization is performed per channel.
  FusedBNActivationMode activation_mode_;
  engine cpu_engine_ = engine(engine::kind::cpu, 0);

  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }

  void SetMeanVariance(const Tensor& mean, const Tensor& variance) {
    mean_values_ = reinterpret_cast<U*>(const_cast<U*>(mean.flat<U>().data()));
    variance_values_ =
        reinterpret_cast<U*>(const_cast<U*>(variance.flat<U>().data()));
  }

  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape workspace_tf_shape,
                        TensorShape tf_shape_scale, Tensor** dst_tensor) {
    DCHECK(dst_tensor);

    const size_t kDstIndex = 0;
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src,
                              dnn_shape_dst, native_format);
    DCHECK(*dst_tensor);
    memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0,
           (*dst_tensor)->tensor_data().size());

    Tensor* batch_mean_tensor = nullptr;
    Tensor* batch_variance_tensor = nullptr;
    Tensor* saved_mean_tensor = nullptr;
    Tensor* saved_variance_tensor = nullptr;
    Tensor* reserved_space_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale, workspace_tf_shape,
                      &batch_mean_tensor, &batch_variance_tensor,
                      &saved_mean_tensor, &saved_variance_tensor,
                      &reserved_space_tensor);
  }

  void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
                         TensorShape workspace_tf_shape,
                         Tensor** batch_mean_tensor,
                         Tensor** batch_variance_tensor,
                         Tensor** saved_mean_tensor,
                         Tensor** saved_variance_tensor,
                         Tensor** reserved_space_tensor) {
    DCHECK(batch_mean_tensor);
    DCHECK(batch_variance_tensor);
    DCHECK(saved_mean_tensor);
    DCHECK(saved_variance_tensor);

    const size_t kBatchMeanIndex = 1;
    const size_t kBatchVarianceIndex = 2;
    const size_t kSavedMeanIndex = 3;
    const size_t kSavedVarianceIndex = 4;
    const size_t kReservedSpaceIndex = 5;

    // Allocate batch mean output tensor.
    MklDnnShape mkl_shape_batch_mean;
    mkl_shape_batch_mean.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor,
                              tf_shape_scale, mkl_shape_batch_mean,
                              native_format);
    DCHECK(*batch_mean_tensor);

    // Set NAN mean value in case of empty input tensor
    int num_elements = tf_shape_scale.num_elements();
    auto batch_mean_data = (*batch_mean_tensor)->flat<U>().data();
    std::fill_n(batch_mean_data, num_elements, static_cast<U>(NAN));

    // Allocate batch variance output tensor.
    MklDnnShape mkl_shape_batch_variance;
    mkl_shape_batch_variance.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kBatchVarianceIndex,
                              batch_variance_tensor, tf_shape_scale,
                              mkl_shape_batch_variance, native_format);
    DCHECK(*batch_variance_tensor);

    // Set NAN variance value in case of empty input tensor
    auto batch_variance_data = (*batch_variance_tensor)->flat<U>().data();
    std::fill_n(batch_variance_data, num_elements, static_cast<U>(NAN));

    // Mean and variance (without Bessel's correction) saved for backward
    // computation to serve as pre-computed mean and variance.
    MklDnnShape mkl_shape_saved_mean;
    mkl_shape_saved_mean.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor,
                              tf_shape_scale, mkl_shape_saved_mean,
                              native_format);
    DCHECK(*saved_mean_tensor);

    // Set 0 mean value in case of empty input tensor
    auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data();
    std::fill_n(saved_mean_data, num_elements, static_cast<U>(0));

    MklDnnShape mkl_shape_saved_variance;
    mkl_shape_saved_variance.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kSavedVarianceIndex,
                              saved_variance_tensor, tf_shape_scale,
                              mkl_shape_saved_variance, native_format);
    DCHECK(*saved_variance_tensor);

    // Set 0 variance value in case of empty input tensor
    auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data();
    std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));

    // Changes to support the reserved_space_3 output in FusedBatchNormV3.
    if (reserved_space) {
      DCHECK(reserved_space_tensor != nullptr);

      MklDnnShape mkl_shape_reserved_space;
      mkl_shape_reserved_space.SetMklTensor(false);
      AllocateOutputSetMklShape(context, kReservedSpaceIndex,
                                reserved_space_tensor, workspace_tf_shape,
                                mkl_shape_reserved_space, native_format);
      DCHECK((*reserved_space_tensor) != nullptr);
    }
  }
};

template <typename Device, typename T, typename U, bool reserved_space,
          bool native_format = false>
class MklFusedBatchNormGradOp : public OpKernel {
 public:
  explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = epsilon;
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
    depth_ = 0;
  }

  void Compute(OpKernelContext* context) override {
    try {
      const size_t kDiffDstIndex = 0;        // index of diff_dst tensor
      const size_t kSrcIndex = 1;            // index of src input tensor
      const size_t kScaleIndex = 2;          // index of scale tensor
      const size_t kMeanIndex = 3;           // index of saved_mean tensor
      const size_t kVarianceIndex = 4;       // index of saved_variance tensor
      const size_t kReservedSpaceIndex = 5;  // index of reserved_space_3 tensor

      const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
      const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
      const Tensor& saved_variance_tensor =
          MklGetInput(context, kVarianceIndex);
      const Tensor& reserved_space_tensor =
          (reserved_space) ? MklGetInput(context, kReservedSpaceIndex)
                           : Tensor();

      MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
      GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format);
      GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst, native_format);

      TensorShape tf_shape_src, tf_shape_diff_dst;
      if (dnn_shape_diff_dst.IsMklTensor()) {
        tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
        OP_REQUIRES(
            context, dnn_shape_diff_dst.GetDimension() == 4,
            errors::InvalidArgument("input must be 4-dimensional",
                                    diff_dst_tensor.shape().DebugString()));
      } else {
        tf_shape_diff_dst = diff_dst_tensor.shape();
        OP_REQUIRES(
            context, diff_dst_tensor.dims() == 4,
            errors::InvalidArgument("input must be 4-dimensional",
                                    diff_dst_tensor.shape().DebugString()));
      }

      if (dnn_shape_src.IsMklTensor()) {
        tf_shape_src = dnn_shape_src.GetTfShape();
        OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      } else {
        tf_shape_src = src_tensor.shape();
        OP_REQUIRES(context, src_tensor.dims() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      }

      OP_REQUIRES(context, scale_tensor.dims() == 1,
                  errors::InvalidArgument("scale must be 1-dimensional",
                                          scale_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, saved_mean_tensor.dims() == 1,
          errors::InvalidArgument("saved mean must be 1-dimensional",
                                  saved_mean_tensor.shape().DebugString()));

      OP_REQUIRES(
          context, saved_variance_tensor.dims() == 1,
          errors::InvalidArgument("saved variance must be 1-dimensional",
                                  saved_variance_tensor.shape().DebugString()));

      // Handle the special case: input with 0 elements and 0 batch size.
      Tensor* diff_src_tensor = nullptr;
      if (tf_shape_src.num_elements() == 0 ||
          tf_shape_diff_dst.num_elements() == 0) {
        HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
                         &diff_src_tensor);
        return;
      }

      if (dnn_shape_src.IsMklTensor()) {
        depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      } else if (dnn_shape_diff_dst.IsMklTensor()) {
        depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C);
      } else {
        ExtractParams(context);
      }

      memory::format_tag dnn_fmt;
      MklTensorFormat mkl_tensor_fmt;
      if (dnn_shape_src.IsMklTensor()) {
        if (dnn_shape_src.IsTensorInNCHWFormat()) {
          dnn_fmt = memory::format_tag::nchw;
          mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW;
        } else {
          dnn_fmt = memory::format_tag::nhwc;
          mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC;
        }
      } else {
        mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_);
        dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt);
      }

      MklDnnData<T> src(&cpu_engine_);
      MklDnnData<T> diff_dst(&cpu_engine_);
      MklDnnData<U> weights(&cpu_engine_);
      MklDnnData<U> diff_weights(&cpu_engine_);

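      // Derive oneDNN dims (in NCHW order) for src and diff_dst, either from
      // the MKL shape metadata or from the plain TF shape.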
      memory::dims src_dims =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
      memory::dims diff_dst_dims =
          dnn_shape_diff_dst.IsMklTensor()
              ? dnn_shape_diff_dst.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
                                          tensor_format_);

      // Set src and diff_dst primitive descriptors.
      memory::desc src_md =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetMklLayout()
              : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
      memory::desc diff_dst_md =
          dnn_shape_diff_dst.IsMklTensor()
              ? dnn_shape_diff_dst.GetMklLayout()
              : memory::desc(diff_dst_dims, MklDnnType<T>(), dnn_fmt);

      MklDnnData<T> reorder_src(&cpu_engine_);
      MklDnnData<T> reorder_diff_dst(&cpu_engine_);
      T* diff_dst_data =
          static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
      T* src_data =
          static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));

      if (!native_format) {
        // MKL-DNN requires src and diff_dst to be in the same memory layout,
        // either blocked or native format. If these inputs are in different
        // formats, convert the one in native format to blocked format, as
        // MKL-DNN gives better performance for the blocked format.
        if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
          reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
          reorder_diff_dst.CheckReorderToOpMem(src_md, cpu_engine_, context);
          diff_dst_md = src_md;
          diff_dst_data =
              static_cast<T*>(reorder_diff_dst.GetOpMem().get_data_handle());
        } else if (!dnn_shape_src.IsMklTensor() &&
                   dnn_shape_diff_dst.IsMklTensor()) {
          reorder_src.SetUsrMem(src_md, &src_tensor);
          reorder_src.CheckReorderToOpMem(diff_dst_md, cpu_engine_, context);
          src_md = diff_dst_md;
          src_data = static_cast<T*>(reorder_src.GetOpMem().get_data_handle());
        }
      }

      // weights -- MKL-DNN packs scales and shifts as "weights" in the order
      // scale, ..., scale, shift, ..., shift.
      weights.AllocateBuffer(2 * depth_ * sizeof(U));
      U* weights_data_tf = reinterpret_cast<U*>(weights.GetAllocatedBuffer());
      const U* scale_tf = scale_tensor.flat<U>().data();
      for (int k = 0; k < depth_; k++) {
        weights_data_tf[k] = scale_tf[k];
        weights_data_tf[k + depth_] = static_cast<U>(0);
      }

      diff_weights.AllocateBuffer(2 * depth_ * sizeof(U));

      MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_,
                                      is_training_, src_md, diff_dst_md);
      MklFusedBatchNormBwdPrimitive<T, U>* bn_bwd =
          MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams);

      // Check if diff_dst input needs to be reordered
      std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd();
      if (!native_format && diff_dst_md != bn_bwd_pd->diff_dst_desc()) {
        diff_dst.SetUsrMem(diff_dst_md, diff_dst_data);
        diff_dst.CheckReorderToOpMem(bn_bwd_pd->diff_dst_desc(), cpu_engine_,
                                     context);
        diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
      }

      if (!native_format && (src_md != bn_bwd_pd->src_desc())) {
        src.SetUsrMem(src_md, src_data);
        src.CheckReorderToOpMem(bn_bwd_pd->src_desc(), cpu_engine_, context);
        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
      }

      // Indices of output tensors
      const size_t kDiffSrcIndex = 0;

      // Allocate output tensor diff_src; set as MKL-DNN layout unless running
      // in native format.
      MklDnnShape dnn_shape_diff_src;
      TensorShape tf_shape_diff_src;
      dnn_shape_diff_src.SetMklTensor(true);
      auto diff_src_pd = bn_bwd->GetDiffSrcPd();
      dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
      dnn_shape_diff_src.SetElemType(MklDnnType<T>());
      dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, mkl_tensor_fmt);
      dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_);
      tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      if (native_format) {
        tf_shape_diff_src = dnn_shape_diff_src.GetTfShape();
      }
      AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
                                tf_shape_diff_src, dnn_shape_diff_src,
                                native_format);

      U* mean_data =
          static_cast<U*>(const_cast<U*>(saved_mean_tensor.flat<U>().data()));
      U* variance_data = static_cast<U*>(
          const_cast<U*>(saved_variance_tensor.flat<U>().data()));
      U* weights_data = weights_data_tf;
      T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data());
      U* diff_weights_data = static_cast<U*>(diff_weights.GetAllocatedBuffer());

      U* res_space_data =
          ((reserved_space) ? static_cast<U*>(const_cast<U*>(
                                  reserved_space_tensor.flat<U>().data()))
                            : nullptr);

      // Execute
      std::shared_ptr<stream> bwd_cpu_stream;
      bwd_cpu_stream.reset(CreateStream(context, bn_bwd->GetEngine()));
      bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
                      weights_data, diff_src_data, diff_weights_data,
                      res_space_data, bwd_cpu_stream);

      // Allocate output TF tensors diff_scale and diff_shift.
      Tensor* diff_scale_tensor = nullptr;
      Tensor* diff_shift_tensor = nullptr;
      AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor,
                        &diff_shift_tensor);

      // Copy data for tensors diff_scale and diff_shift.
      auto diff_scale_data = diff_scale_tensor->flat<U>().data();
      auto diff_shift_data = diff_shift_tensor->flat<U>().data();
      std::memcpy(reinterpret_cast<char*>(diff_scale_data),
                  reinterpret_cast<char*>(diff_weights_data),
                  depth_ * sizeof(U));
      std::memcpy(reinterpret_cast<char*>(diff_shift_data),
                  reinterpret_cast<char*>(diff_weights_data + depth_),
                  depth_ * sizeof(U));
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  float epsilon_;
  TensorFormat tensor_format_;
  size_t depth_;  // Batch normalization is performed per channel.
  bool is_training_;
  engine cpu_engine_ = engine(engine::kind::cpu, 0);

  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }

  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape tf_shape_scale_shift,
                        Tensor** diff_src_tensor) {
    const size_t kDiffSrcIndex = 0;

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
                              tf_shape_src, dnn_shape_diff_src, native_format);
    auto diff_src_data = (*diff_src_tensor)->flat<T>().data();
    std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(),
                static_cast<T>(0));

    Tensor* diff_scale_tensor = nullptr;
    Tensor* diff_shift_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor,
                      &diff_shift_tensor);
  }

  void AllocateTFOutputs(OpKernelContext* context,
                         TensorShape tf_shape_scale_shift,
                         Tensor** diff_scale_tensor,
                         Tensor** diff_shift_tensor) {
    DCHECK(diff_scale_tensor);
    DCHECK(diff_shift_tensor);

    const size_t kDiffScaleIndex = 1;
    const size_t kDiffShiftIndex = 2;
    const size_t kP1Index = 3;
    const size_t kP2Index = 4;

    // Separate out scale and shift grad and copy to individual tensors
    MklDnnShape mkl_shape_diff_scale;
    mkl_shape_diff_scale.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_scale,
                              native_format);
    DCHECK(*diff_scale_tensor);

    auto diff_scale_data = (*diff_scale_tensor)->flat<U>().data();
    std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(),
                static_cast<U>(0));

    MklDnnShape mkl_shape_diff_shift;
    mkl_shape_diff_shift.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_shift,
                              native_format);
    DCHECK(*diff_shift_tensor);

    auto diff_shift_data = (*diff_shift_tensor)->flat<U>().data();
    std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(),
                static_cast<U>(0));

    // Placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    Tensor *p1_tensor = nullptr, *p2_tensor = nullptr;
    MklDnnShape mkl_shape_p;
    mkl_shape_p.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
                              mkl_shape_p, native_format);
    std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(),
                static_cast<U>(0));
    AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
                              mkl_shape_p, native_format);
    std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(),
                static_cast<U>(0));
  }

  memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
};

#define REGISTER_MKL_FUSED_BATCHNORM_CPU(T)                      \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklFusedBatchNorm")                                 \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
      MklFusedBatchNormOp<CPUDevice, T, T, false, false>);       \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklNativeFusedBatchNorm")                           \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),        \
      MklFusedBatchNormOp<CPUDevice, T, T, false, false, true>);

TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU);
TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
#undef REGISTER_MKL_FUSED_BATCHNORM_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(T, U)                \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklFusedBatchNormV2")                               \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .TypeConstraint<U>("U")                                \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
      MklFusedBatchNormOp<CPUDevice, T, U, false, false>);       \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklNativeFusedBatchNormV2")                         \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .TypeConstraint<U>("U")                                \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),        \
      MklFusedBatchNormOp<CPUDevice, T, U, false, false, true>);

REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_V2_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU(T)                 \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklFusedBatchNormGrad")                             \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
      MklFusedBatchNormGradOp<CPUDevice, T, T, false>);          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklNativeFusedBatchNormGrad")                       \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),        \
      MklFusedBatchNormGradOp<CPUDevice, T, T, false, true>);

TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU);
TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(T, U)           \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklFusedBatchNormGradV2")                           \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .TypeConstraint<U>("U")                                \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
      MklFusedBatchNormGradOp<CPUDevice, T, U, false>);          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklNativeFusedBatchNormGradV2")                     \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .TypeConstraint<U>("U")                                \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),        \
      MklFusedBatchNormGradOp<CPUDevice, T, U, false, true>);

REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU

// TODO: FusedBatchNormV3 has an additional output that is used to
// hold intermediate results. This parameter functionality is
// not implemented on CPU.
#define REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(T, U)                \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklFusedBatchNormV3")                               \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .TypeConstraint<U>("U")                                \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
      MklFusedBatchNormOp<CPUDevice, T, U, true, false>);        \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklFusedBatchNormEx")                               \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .TypeConstraint<U>("U")                                \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
      MklFusedBatchNormOp<CPUDevice, T, U, true, true>);         \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklNativeFusedBatchNormV3")                         \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .TypeConstraint<U>("U")                                \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),        \
      MklFusedBatchNormOp<CPUDevice, T, U, true, false, true>);  \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklNativeFusedBatchNormEx")                         \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .TypeConstraint<U>("U")                                \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),        \
      MklFusedBatchNormOp<CPUDevice, T, U, true, true, true>);

REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU

REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        NoOp);
REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<bfloat16>("T")
                            .TypeConstraint<float>("U"),
                        NoOp);

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U)           \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklFusedBatchNormGradV3")                           \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .TypeConstraint<U>("U")                                \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
      MklFusedBatchNormGradOp<CPUDevice, T, U, true>);           \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("_MklNativeFusedBatchNormGradV3")                     \
          .Device(DEVICE_CPU)                                    \
          .TypeConstraint<T>("T")                                \
          .TypeConstraint<U>("U")                                \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),        \
      MklFusedBatchNormGradOp<CPUDevice, T, U, true, true>);

REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU

}  // namespace tensorflow

#undef GET_FLAG
#undef IS_SET

#endif  // INTEL_MKL