1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifdef INTEL_MKL
16 #include "mkldnn.hpp"
17 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/register_types.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_types.h"
22 #include "tensorflow/core/kernels/fused_batch_norm_op.h"
23 #include "tensorflow/core/kernels/no_op.h"
24 #include "tensorflow/core/util/mkl_util.h"
25 #include "tensorflow/core/util/tensor_format.h"
26 
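// Helper macros: GET_FLAG converts an mkldnn::normalization_flags enumerator
// to its integer value, and IS_SET tests whether that flag is present in the
// cached context_.flags bitmask.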
27 #define GET_FLAG(bn_flag) static_cast<int>(mkldnn::normalization_flags::bn_flag)
28 #define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
29 
30 using mkldnn::batch_normalization_backward;
31 using mkldnn::batch_normalization_forward;
32 using mkldnn::prop_kind;
33 using mkldnn::stream;
34 
35 using BatchNormFwdPd = mkldnn::batch_normalization_forward::primitive_desc;
36 using BatchNormBwdPd = mkldnn::batch_normalization_backward::primitive_desc;
37 
38 namespace tensorflow {
39 using CPUDevice = Eigen::ThreadPoolDevice;
40 
41 using FusedBNActivationMode = functor::FusedBatchNormActivationMode;
42 
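// Parameters that fully describe a forward batch-norm primitive. They are used
// both to set up the MKL-DNN primitive and as the lookup key in the primitive
// factory below.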
43 struct MklBatchNormFwdParams {
44   memory::dims src_dims;
45   int depth;
46   float eps;
47   bool training;
48   FusedBNActivationMode activation_mode;
49   memory::desc src_md;
50 
51   MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
52                         bool training, memory::desc src_md,
53                         FusedBNActivationMode activation_mode)
54       : src_dims(src_dims),
55         depth(depth),
56         eps(eps),
57         training(training),
58         activation_mode(activation_mode),
59         src_md(src_md) {}
60 };
61 
62 template <typename T, typename U>
63 class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
64  public:
65   explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams)
66       : MklPrimitive(engine(engine::kind::cpu, 0)) {
67     if (context_.bn_fwd == nullptr) Setup(fwdParams);
68   }
69 
70   ~MklFusedBatchNormFwdPrimitive() {}
71 
72   // BatchNormalization forward execute
73   //   src_data:     input data buffer of src
74   //   weights_data: input data buffer of weights
75   //   dst_data:     output data buffer of dst
76   //   mean_data:     output data buffer of means
77   //   variance_data: output data buffer of variances
78   void Execute(const T* src_data, const U* weights_data, T* dst_data,
79                U* mean_data, U* variance_data,
80                std::shared_ptr<stream> fwd_stream, U* workspace_data) {
81 #ifdef ENABLE_MKLDNN_THREADPOOL
82     // TODO: Create a common function and avoid the duplicate code
83     context_.src_mem->set_data_handle(
84         static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
85     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
86                                       *fwd_stream);
87 
88     if (IS_SET(use_scale_shift))
89       context_.weights_mem->set_data_handle(
90           static_cast<void*>(const_cast<U*>(weights_data)), *fwd_stream);
91 
92     if ((context_.pkind == prop_kind::forward_training) ||
93         (IS_SET(use_global_stats))) {
94       context_.mean_mem->set_data_handle(static_cast<void*>(mean_data),
95                                          *fwd_stream);
96       context_.variance_mem->set_data_handle(static_cast<void*>(variance_data),
97                                              *fwd_stream);
98     }
99     if (workspace_data != nullptr) {
100       context_.ws_mem->set_data_handle(workspace_data, *fwd_stream);
101     }
102 #else
103     context_.src_mem->set_data_handle(
104         static_cast<void*>(const_cast<T*>(src_data)));
105     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
106 
107     if (IS_SET(use_scale_shift))
108       context_.weights_mem->set_data_handle(
109           static_cast<void*>(const_cast<U*>(weights_data)));
110 
111     if ((context_.pkind == prop_kind::forward_training) ||
112         (IS_SET(use_global_stats))) {
113       context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
114       context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
115     }
116     if (workspace_data != nullptr) {
117       context_.ws_mem->set_data_handle(workspace_data);
118     }
119 #endif  // ENABLE_MKLDNN_THREADPOOL
120 
121     // Execute batch-normalization forward primitives.
122     execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
123 
124     context_.src_mem->set_data_handle(DummyData);
125     context_.dst_mem->set_data_handle(DummyData);
126 
127     if (IS_SET(use_scale_shift))
128       context_.weights_mem->set_data_handle(DummyData);
129 
130     if ((context_.pkind == prop_kind::forward_training) ||
131         (IS_SET(use_global_stats))) {
132       context_.mean_mem->set_data_handle(DummyData);
133       context_.variance_mem->set_data_handle(DummyData);
134     }
135 
136     if (workspace_data != nullptr) {
137       context_.ws_mem->set_data_handle(DummyData);
138     }
139   }
140 
141   memory::desc GetDstPd() const { return context_.dst_mem->get_desc(); }
142 
143   std::shared_ptr<BatchNormFwdPd> GetBatchNormFwdPd() const {
144     return context_.fwd_pd;
145   }
146 
147  private:
148   // Primitive reuse context for BatchNorm forward op.
149   struct BatchNormFwdContext {
150     // Flags indicating if it is training or inference mode.
151     int64 flags;
152 
153     // Algorithm kind.
154     mkldnn::prop_kind pkind;
155 
156     // Inputs/outputs memory.
157     std::shared_ptr<mkldnn::memory> src_mem;
158     std::shared_ptr<mkldnn::memory> weights_mem;
159     std::shared_ptr<mkldnn::memory> dst_mem;
160     std::shared_ptr<mkldnn::memory> mean_mem;
161     std::shared_ptr<mkldnn::memory> variance_mem;
162     std::shared_ptr<mkldnn::memory> ws_mem;
163 
164     // Forward BatchNorm primitive descriptor.
165     std::shared_ptr<BatchNormFwdPd> fwd_pd;
166 
167     // BatchNorm forward primitive.
168     std::shared_ptr<mkldnn::primitive> bn_fwd;
169     std::vector<mkldnn::primitive> fwd_primitives;
170 
171     std::vector<std::unordered_map<int, memory>> net_args;
172 
173     BatchNormFwdContext()
174         : flags(0),
175           pkind(prop_kind::forward_training),
176           src_mem(nullptr),
177           weights_mem(nullptr),
178           dst_mem(nullptr),
179           mean_mem(nullptr),
180           variance_mem(nullptr),
181           ws_mem(nullptr),
182           bn_fwd(nullptr) {}
183   };
184 
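  // Builds the forward primitive descriptor, the DummyData-bound memory
  // objects, and the execution argument map; called once from the constructor.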
185   void Setup(const MklBatchNormFwdParams& fwdParams) {
186     context_.flags =
187         fwdParams.training
188             ? GET_FLAG(use_scale_shift)
189             : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
190     context_.pkind = fwdParams.training ? prop_kind::forward_training
191                                         : prop_kind::forward_scoring;
192 
193     if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
194       context_.flags |= GET_FLAG(fuse_norm_relu);
195     }
196     // Memory descriptor
197     auto src_md = fwdParams.src_md;
198     // Create forward BatchNorm descriptor and primitive descriptor.
199     auto fwd_desc = batch_normalization_forward::desc(
200         context_.pkind, src_md, fwdParams.eps,
201         static_cast<mkldnn::normalization_flags>(context_.flags));
202 
203     context_.fwd_pd.reset(new BatchNormFwdPd(fwd_desc, cpu_engine_));
204 
205     // Create memory primitive based on dummy data
206     context_.src_mem.reset(
207         new memory(context_.fwd_pd->src_desc(), cpu_engine_, DummyData));
208     context_.dst_mem.reset(
209         new memory(context_.fwd_pd->dst_desc(), cpu_engine_, DummyData));
210 
211     memory::dims s_dims = {2, fwdParams.depth};
212     memory::dims m_dims = {1, fwdParams.depth};
213     if (IS_SET(use_scale_shift)) {
214       context_.weights_mem.reset(
215           new memory({{s_dims}, MklDnnType<U>(), memory::format_tag::nc},
216                      cpu_engine_, DummyData));
217     }
218 
219     if (fwdParams.training || (IS_SET(use_global_stats))) {
220       context_.mean_mem.reset(
221           new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
222                      cpu_engine_, DummyData));
223 
224       context_.variance_mem.reset(
225           new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
226                      cpu_engine_, DummyData));
227     }
228 
229     if (IS_SET(fuse_norm_relu)) {
230       context_.ws_mem.reset(new memory(context_.fwd_pd->workspace_desc(),
231                                        cpu_engine_, DummyData));
232     }
233 
234     // BatchNorm forward primitive.
235     // TODO(intel-tf): Merge all the #ifdefs and simplify code
236     if (!fwdParams.training && !(IS_SET(use_global_stats))) {
237       if ((IS_SET(use_scale_shift)) && mkldnn_use_scaleshift) {
238         context_.net_args.push_back(
239             {{MKLDNN_ARG_SRC, *context_.src_mem},
240              {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
241              {MKLDNN_ARG_DST, *context_.dst_mem}});
242       } else {
243         context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
244                                      {MKLDNN_ARG_DST, *context_.dst_mem}});
245       }
246       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
247     } else if (IS_SET(use_global_stats)) {
248       if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
249         if (IS_SET(fuse_norm_relu)) {
250           context_.net_args.push_back(
251               {{MKLDNN_ARG_SRC, *context_.src_mem},
252                {MKLDNN_ARG_MEAN, *context_.mean_mem},
253                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
254                {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
255                {MKLDNN_ARG_DST, *context_.dst_mem},
256                {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
257         } else {
258           context_.net_args.push_back(
259               {{MKLDNN_ARG_SRC, *context_.src_mem},
260                {MKLDNN_ARG_MEAN, *context_.mean_mem},
261                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
262                {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
263                {MKLDNN_ARG_DST, *context_.dst_mem}});
264         }
265       } else {
266         if (IS_SET(fuse_norm_relu)) {
267           context_.net_args.push_back(
268               {{MKLDNN_ARG_SRC, *context_.src_mem},
269                {MKLDNN_ARG_MEAN, *context_.mean_mem},
270                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
271                {MKLDNN_ARG_DST, *context_.dst_mem},
272                {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
273         } else {
274           context_.net_args.push_back(
275               {{MKLDNN_ARG_SRC, *context_.src_mem},
276                {MKLDNN_ARG_MEAN, *context_.mean_mem},
277                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
278                {MKLDNN_ARG_DST, *context_.dst_mem}});
279         }
280       }
281       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
282     } else {
283       if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
284         if (IS_SET(fuse_norm_relu)) {
285           context_.net_args.push_back(
286               {{MKLDNN_ARG_SRC, *context_.src_mem},
287                {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
288                {MKLDNN_ARG_DST, *context_.dst_mem},
289                {MKLDNN_ARG_MEAN, *context_.mean_mem},
290                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
291                {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
292         } else {
293           context_.net_args.push_back(
294               {{MKLDNN_ARG_SRC, *context_.src_mem},
295                {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
296                {MKLDNN_ARG_DST, *context_.dst_mem},
297                {MKLDNN_ARG_MEAN, *context_.mean_mem},
298                {MKLDNN_ARG_VARIANCE, *context_.variance_mem}});
299         }
300       } else {
301         if (IS_SET(fuse_norm_relu)) {
302           context_.net_args.push_back(
303               {{MKLDNN_ARG_SRC, *context_.src_mem},
304                {MKLDNN_ARG_DST, *context_.dst_mem},
305                {MKLDNN_ARG_MEAN, *context_.mean_mem},
306                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
307                {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
308         } else {
309           context_.net_args.push_back(
310               {{MKLDNN_ARG_SRC, *context_.src_mem},
311                {MKLDNN_ARG_DST, *context_.dst_mem},
312                {MKLDNN_ARG_MEAN, *context_.mean_mem},
313                {MKLDNN_ARG_VARIANCE, *context_.variance_mem}});
314         }
315       }
316       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
317     }
318 
319     context_.fwd_primitives.push_back(*context_.bn_fwd);
320   }
321 
322   struct BatchNormFwdContext context_;
323 };
324 
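// Factory that caches forward batch-norm primitives keyed on
// MklBatchNormFwdParams, so ops with identical shapes and attributes reuse the
// same primitive.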
325 template <typename T, typename U>
326 class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
327  public:
328   static MklFusedBatchNormFwdPrimitive<T, U>* Get(
329       const MklBatchNormFwdParams& fwdParams) {
330     auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T, U>*>(
331         MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance()
332             .GetBatchNormFwd(fwdParams));
333 
334     if (bn_fwd == nullptr) {
335       bn_fwd = new MklFusedBatchNormFwdPrimitive<T, U>(fwdParams);
336       MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormFwd(
337           fwdParams, bn_fwd);
338     }
339     return bn_fwd;
340   }
341 
342   static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() {
343     static MklFusedBatchNormFwdPrimitiveFactory instance_;
344     return instance_;
345   }
346 
347  private:
348   MklFusedBatchNormFwdPrimitiveFactory() {}
349   ~MklFusedBatchNormFwdPrimitiveFactory() {}
350 
351   static string CreateKey(const MklBatchNormFwdParams& fwdParams) {
352     string prefix = "bn_fwd";
353     FactoryKeyCreator key_creator;
354     key_creator.AddAsKey(prefix);
355     key_creator.AddAsKey(fwdParams.src_dims);
356     key_creator.AddAsKey<int>(fwdParams.depth);
357     key_creator.AddAsKey<float>(fwdParams.eps);
358     key_creator.AddAsKey<bool>(fwdParams.training);
359     key_creator.AddAsKey<FusedBNActivationMode>(fwdParams.activation_mode);
360     key_creator.AddAsKey(typeid(T).name());
361     key_creator.AddAsKey(typeid(U).name());
362     return key_creator.GetKey();
363   }
364 
365   MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) {
366     string key = CreateKey(fwdParams);
367     return this->GetOp(key);
368   }
369 
370   void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams,
371                        MklPrimitive* op) {
372     string key = CreateKey(fwdParams);
373     this->SetOp(key, op);
374   }
375 };
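// A minimal usage sketch of the factory above (variable names and the
// float/float instantiation are illustrative only):
//
//   MklBatchNormFwdParams params(src_dims, depth, eps, is_training, src_md,
//                                FusedBNActivationMode::kIdentity);
//   auto* bn_fwd = MklFusedBatchNormFwdPrimitiveFactory<float, float>::Get(params);
//   bn_fwd->Execute(src, weights, dst, mean, variance, stream,
//                   /*workspace_data=*/nullptr);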
376 
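// Parameters that fully describe a backward batch-norm primitive; like the
// forward variant, they double as the cache key in the backward factory.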
377 struct MklBatchNormBwdParams {
378   memory::dims src_dims;
379   memory::dims diff_dst_dims;
380   int depth;
381   float eps;
382   bool training;
383 
384   memory::desc src_md;
385   memory::desc diff_dst_md;
386 
387   MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
388                         int depth, float eps, bool training,
389                         memory::desc src_md, memory::desc diff_dst_md)
390       : src_dims(src_dims),
391         diff_dst_dims(diff_dst_dims),
392         depth(depth),
393         eps(eps),
394         training(training),
395         src_md(src_md),
396         diff_dst_md(diff_dst_md) {}
397 };
398 
399 template <typename T, typename U>
400 class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
401  public:
402   explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams)
403       : MklPrimitive(engine(engine::kind::cpu, 0)) {
404     if (context_.bn_bwd == nullptr) Setup(bwdParams);
405   }
406 
407   ~MklFusedBatchNormBwdPrimitive() {}
408 
409   // BatchNormalization backward execute
410   //   src_data:       input data buffer of src
411   //   mean_data:      input data buffer of mean
412   //   variance_data:  input data buffer of variance
413   //   diff_dst_data:  input data buffer of diff_dst
414   //   weights_data:   input data buffer of weights
415   //   diff_src_data:      output data buffer of diff_src
416   //   diff_weights_data:  output data buffer of diff_weights
417   //   res_space_data:     output data buffer of reserved_space_3.
418   //                       TODO: reserved_space_3 (temporary memory to hold
419   //                          intermediate results) is not implemented
420   //                          on CPU as of now.
421   void Execute(const T* src_data, const U* mean_data, const U* variance_data,
422                const T* diff_dst_data, const U* weights_data, T* diff_src_data,
423                U* diff_weights_data, U* res_space_data,
424                std::shared_ptr<stream> bwd_stream) {
425 #ifdef ENABLE_MKLDNN_THREADPOOL
426     // TODO: Create a common function and avoid the duplicate code
427     context_.src_mem->set_data_handle(
428         static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
429     context_.mean_mem->set_data_handle(
430         static_cast<void*>(const_cast<U*>(mean_data)), *bwd_stream);
431     context_.variance_mem->set_data_handle(
432         static_cast<void*>(const_cast<U*>(variance_data)), *bwd_stream);
433     context_.diff_dst_mem->set_data_handle(
434         static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
435 
436     if (IS_SET(use_scale_shift)) {
437       context_.weights_mem->set_data_handle(
438           static_cast<void*>(const_cast<U*>(weights_data)), *bwd_stream);
439       context_.diff_weights_mem->set_data_handle(
440           static_cast<void*>(diff_weights_data), *bwd_stream);
441     }
442 
443     context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
444                                            *bwd_stream);
445 #else
446     context_.src_mem->set_data_handle(
447         static_cast<void*>(const_cast<T*>(src_data)));
448     context_.mean_mem->set_data_handle(
449         static_cast<void*>(const_cast<U*>(mean_data)));
450     context_.variance_mem->set_data_handle(
451         static_cast<void*>(const_cast<U*>(variance_data)));
452     context_.diff_dst_mem->set_data_handle(
453         static_cast<void*>(const_cast<T*>(diff_dst_data)));
454 
455     if (IS_SET(use_scale_shift)) {
456       context_.weights_mem->set_data_handle(
457           static_cast<void*>(const_cast<U*>(weights_data)));
458       context_.diff_weights_mem->set_data_handle(
459           static_cast<void*>(diff_weights_data));
460     }
461 
462     context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
463 #endif  // ENABLE_MKLDNN_THREADPOOL
464     // Execute backward batch-normalization primitives.
465     DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
466     execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
467 
468     // After execution, set data handle back to DummyData.
469     context_.src_mem->set_data_handle(DummyData);
470     context_.mean_mem->set_data_handle(DummyData);
471     context_.variance_mem->set_data_handle(DummyData);
472     context_.diff_dst_mem->set_data_handle(DummyData);
473     if (IS_SET(use_scale_shift)) {
474       context_.weights_mem->set_data_handle(DummyData);
475       context_.diff_weights_mem->set_data_handle(DummyData);
476     }
477     context_.diff_src_mem->set_data_handle(DummyData);
478   }
479 
480   std::shared_ptr<BatchNormBwdPd> GetBatchNormBwdPd() const {
481     return context_.bwd_pd;
482   }
483 
484   memory::desc GetDiffSrcPd() { return context_.diff_src_mem->get_desc(); }
485 
486  private:
487   struct BatchNormBwdContext {
488     // Flags to indicate whether it is training or inference.
489     int64 flags;
490 
491     // Inputs/output memory.
492     std::shared_ptr<mkldnn::memory> src_mem;
493     std::shared_ptr<mkldnn::memory> mean_mem;
494     std::shared_ptr<mkldnn::memory> variance_mem;
495     std::shared_ptr<mkldnn::memory> diff_dst_mem;
496     std::shared_ptr<mkldnn::memory> weights_mem;
497     std::shared_ptr<mkldnn::memory> diff_weights_mem;
498     std::shared_ptr<mkldnn::memory> diff_src_mem;
499 
500     // Backward batch-normalization primitive descriptor.
501     std::shared_ptr<BatchNormBwdPd> bwd_pd;
502 
503     // Backward batch-normalization primitive.
504     std::shared_ptr<mkldnn::primitive> bn_bwd;
505     std::vector<mkldnn::primitive> bwd_primitives;
506 
507     std::vector<std::unordered_map<int, memory>> net_args;
508 
509     BatchNormBwdContext()
510         : src_mem(nullptr),
511           mean_mem(nullptr),
512           variance_mem(nullptr),
513           diff_dst_mem(nullptr),
514           weights_mem(nullptr),
515           diff_weights_mem(nullptr),
516           diff_src_mem(nullptr) {}
517   };
518 
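  // Builds the backward primitive descriptor (using a forward descriptor as
  // the hint), the DummyData-bound memory objects, and the execution argument
  // map; called once from the constructor.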
519   void Setup(const MklBatchNormBwdParams& bwdParams) {
520     context_.flags =
521         bwdParams.training
522             ? GET_FLAG(use_scale_shift)
523             : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
524 
525     // Memory descriptors.
526     auto src_md = bwdParams.src_md;
527     auto diff_dst_md = bwdParams.diff_dst_md;
528     auto variance_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(),
529                                       memory::format_tag::nc);
530     auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(),
531                                   memory::format_tag::nc);
532     auto weights_desc = memory::desc({2, bwdParams.depth}, MklDnnType<U>(),
533                                      memory::format_tag::nc);
534     auto diff_weights_desc = weights_desc;
535 
536     // Forward batch-normalization descriptor and primitive descriptor.
537     // Adding this back due to type difference with context.flags
538     auto bn_flags = bwdParams.training
539                         ? mkldnn::normalization_flags::use_scale_shift
540                         : (mkldnn::normalization_flags::use_scale_shift |
541                            mkldnn::normalization_flags::use_global_stats);
542     auto fwd_desc = batch_normalization_forward::desc(
543         prop_kind::forward_training, src_md, bwdParams.eps, bn_flags);
544     auto fwd_pd = BatchNormFwdPd(fwd_desc, cpu_engine_);
545 
546     // Backward batch-normalization primitive.
547     // For inference, specify use_global_stats:
548     //   1. on fwd propagation, use mean and variance provided as inputs.
549     //   2. on bwd propagation, mean and variance are considered as constants,
550     //      which reduces the amount of MKL computation.
551     auto bwd_desc = batch_normalization_backward::desc(
552         prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, bn_flags);
553     context_.bwd_pd.reset(new BatchNormBwdPd(bwd_desc, cpu_engine_, fwd_pd));
554 
555     // Create memory primitives.
556     context_.src_mem.reset(new memory(src_md, cpu_engine_, DummyData));
557     context_.diff_dst_mem.reset(
558         new memory(diff_dst_md, cpu_engine_, DummyData));
559     context_.variance_mem.reset(
560         new memory(variance_desc, cpu_engine_, DummyData));
561     context_.mean_mem.reset(new memory(mean_desc, cpu_engine_, DummyData));
562     context_.weights_mem.reset(
563         new memory(weights_desc, cpu_engine_, DummyData));
564     context_.diff_weights_mem.reset(
565         new memory(diff_weights_desc, cpu_engine_, DummyData));
566     context_.diff_src_mem.reset(new memory(src_md, cpu_engine_, DummyData));
567 
568     context_.bn_bwd.reset(new batch_normalization_backward(*context_.bwd_pd));
569     context_.net_args.push_back(
570         {{MKLDNN_ARG_SRC, *context_.src_mem},
571          {MKLDNN_ARG_MEAN, *context_.mean_mem},
572          {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
573          {MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
574          {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
575          {MKLDNN_ARG_DIFF_SRC, *context_.diff_src_mem},
576          {MKLDNN_ARG_DIFF_WEIGHTS, *context_.diff_weights_mem}});
577     context_.bwd_primitives.push_back(*context_.bn_bwd);
578   }
579 
580   struct BatchNormBwdContext context_;
581 };
582 
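// Factory that caches backward batch-norm primitives keyed on
// MklBatchNormBwdParams, mirroring the forward factory above.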
583 template <typename T, typename U>
584 class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
585  public:
586   static MklFusedBatchNormBwdPrimitive<T, U>* Get(
587       const MklBatchNormBwdParams& bwdParams) {
588     auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T, U>*>(
589         MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance()
590             .GetBatchNormBwd(bwdParams));
591     if (bn_bwd == nullptr) {
592       bn_bwd = new MklFusedBatchNormBwdPrimitive<T, U>(bwdParams);
593       MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormBwd(
594           bwdParams, bn_bwd);
595     }
596     return bn_bwd;
597   }
598 
599   static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() {
600     static MklFusedBatchNormBwdPrimitiveFactory instance_;
601     return instance_;
602   }
603 
604  private:
605   MklFusedBatchNormBwdPrimitiveFactory() {}
606   ~MklFusedBatchNormBwdPrimitiveFactory() {}
607 
608   static string CreateKey(const MklBatchNormBwdParams& bwdParams) {
609     string prefix = "bn_bwd";
610     FactoryKeyCreator key_creator;
611     key_creator.AddAsKey(prefix);
612     key_creator.AddAsKey(bwdParams.src_dims);
613     key_creator.AddAsKey(bwdParams.diff_dst_dims);
614     key_creator.AddAsKey<int>(bwdParams.depth);
615     key_creator.AddAsKey<float>(bwdParams.eps);
616     key_creator.AddAsKey<bool>(bwdParams.training);
617     key_creator.AddAsKey(typeid(T).name());
618     key_creator.AddAsKey(typeid(U).name());
619     return key_creator.GetKey();
620   }
621 
622   MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) {
623     string key = CreateKey(bwdParams);
624     return this->GetOp(key);
625   }
626 
627   void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams,
628                        MklPrimitive* op) {
629     string key = CreateKey(bwdParams);
630     this->SetOp(key, op);
631   }
632 };
633 
634 //  Adding a third parameter to the template to support FusedBatchNormV3
635 //  with MKL. This differs from the default implementation, where the classes
636 //  are derived; it moves the decision to compile time rather than runtime.
637 template <typename Device, typename T, typename U, bool reserved_space,
638           bool is_batch_norm_ex = false, bool native_format = false>
639 class MklFusedBatchNormOp : public OpKernel {
640  public:
641   explicit MklFusedBatchNormOp(OpKernelConstruction* context)
642       : OpKernel(context) {
643     float epsilon;
644     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
645     epsilon_ = epsilon;
646     float exponential_avg_factor;
647     OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
648                                              &exponential_avg_factor));
649     exponential_avg_factor_ = static_cast<U>(exponential_avg_factor);
650     string tensor_format;
651     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
652     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
653                 errors::InvalidArgument("Invalid data format"));
654     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
655     depth_ = 0;
656     mean_values_ = nullptr;
657     variance_values_ = nullptr;
658 
659     if (!is_batch_norm_ex) {
660       activation_mode_ = FusedBNActivationMode::kIdentity;
661     } else {
662       int num_side_inputs;
663       OP_REQUIRES_OK(context,
664                      context->GetAttr("num_side_inputs", &num_side_inputs));
665       // Currently _MKLFusedBatchNormEx does not support "SideInput".
666       OP_REQUIRES(context, num_side_inputs == 0,
667                   errors::InvalidArgument(
668                       "_MKLFusedBatchNorm does not support side input now."));
669 
670       OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
671       OP_REQUIRES(context, activation_mode_ == FusedBNActivationMode::kRelu,
672                   errors::InvalidArgument(
673                       "_MKLFusedBatchNorm only supports Relu activation"));
674     }
675   }
676 
677   void Compute(OpKernelContext* context) override {
678     try {
679       const size_t kSrcIndex = 0;       // index of src input tensor
680       const size_t kScaleIndex = 1;     // index of scale tensor
681       const size_t kShiftIndex = 2;     // index of shift tensor
682       const size_t kMeanIndex = 3;      // index of est_mean tensor
683       const size_t kVarianceIndex = 4;  // index of est_variance tensor
684 
685       const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
686       const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
687       const Tensor& shift_tensor = MklGetInput(context, kShiftIndex);
688       const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex);
689       const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex);
690 
691       TensorShape tf_shape_src;
692       MklDnnShape dnn_shape_src;
693       GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format);
694 
695       if (dnn_shape_src.IsMklTensor()) {
696         tf_shape_src = dnn_shape_src.GetTfShape();
697         OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
698                     errors::InvalidArgument("input must be 4-dimensional",
699                                             src_tensor.shape().DebugString()));
700       } else {
701         tf_shape_src = src_tensor.shape();
702         OP_REQUIRES(context, src_tensor.dims() == 4,
703                     errors::InvalidArgument("input must be 4-dimensional",
704                                             src_tensor.shape().DebugString()));
705       }
706       OP_REQUIRES(context, scale_tensor.dims() == 1,
707                   errors::InvalidArgument("scale must be 1-dimensional",
708                                           scale_tensor.shape().DebugString()));
709       OP_REQUIRES(context, shift_tensor.dims() == 1,
710                   errors::InvalidArgument("offset must be 1-dimensional",
711                                           shift_tensor.shape().DebugString()));
712       OP_REQUIRES(
713           context, est_mean_tensor.dims() == 1,
714           errors::InvalidArgument("estimated_mean must be 1-dimensional",
715                                   est_mean_tensor.shape().DebugString()));
716       OP_REQUIRES(
717           context, est_variance_tensor.dims() == 1,
718           errors::InvalidArgument("estimated_variance must be 1-dimensional",
719                                   est_variance_tensor.shape().DebugString()));
720 
721       // Handle the special case: input with 0 element and 0 batch size.
722       Tensor* dst_tensor = nullptr;
723       TensorShape workspace_tf_shape;
724       if (tf_shape_src.num_elements() == 0) {
725         size_t workspace_bytes = 0;
726         workspace_tf_shape.AddDim(workspace_bytes);
727         HandleEmptyInput(context, tf_shape_src, workspace_tf_shape,
728                          scale_tensor.shape(), &dst_tensor);
729         return;
730       }
731 
732       if (dnn_shape_src.IsMklTensor())
733         depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
734       else
735         ExtractParams(context);
736 
737       // Index of output tensor(diff_src).
738       const size_t kDstIndex = 0;
739 
740       // Allocate 5 output TF tensors.
741       Tensor* batch_mean_tensor = nullptr;
742       Tensor* batch_variance_tensor = nullptr;
743       Tensor* saved_mean_tensor = nullptr;
744       Tensor* saved_variance_tensor = nullptr;
745       Tensor* reserved_space_tensor = nullptr;
746 
747       MklDnnData<T> src(&cpu_engine_);
748       MklDnnData<U> weights(&cpu_engine_);
749       MklDnnData<U> wksp(&cpu_engine_);
750 
751       memory::format_tag dnn_fmt;
752       MklTensorFormat mkl_tensor_fmt;
753       if (dnn_shape_src.IsMklTensor()) {
754         if (dnn_shape_src.IsTensorInNCHWFormat()) {
755           dnn_fmt = memory::format_tag::nchw;
756           mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW;
757         } else {
758           dnn_fmt = memory::format_tag::nhwc;
759           mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC;
760         }
761       } else {
762         mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_);
763         dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt);
764       }
765 
766       // Set src memory descriptor.
767       memory::dims src_dims =
768           dnn_shape_src.IsMklTensor()
769               ? dnn_shape_src.GetSizesAsMklDnnDims()
770               : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
771 
772       auto src_md = dnn_shape_src.IsMklTensor()
773                         ? dnn_shape_src.GetMklLayout()
774                         : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
775 
776       MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
777                                       src_md, activation_mode_);
778 
779       // Get forward batch-normalization op from the primitive caching pool.
780       MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
781           MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
782 
783       // Allocate workspace tensor
784       U* ws_data = nullptr;
785       if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
786         memory::desc workspace_md =
787             bn_fwd->GetBatchNormFwdPd()->workspace_desc();
788         size_t workspace_bytes = workspace_md.get_size();
789         workspace_tf_shape.AddDim(workspace_bytes);
790 
791         AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
792                           &batch_mean_tensor, &batch_variance_tensor,
793                           &saved_mean_tensor, &saved_variance_tensor,
794                           &reserved_space_tensor);
795         if (reserved_space) {
796           wksp.SetUsrMem(workspace_md, reserved_space_tensor);
797           ws_data = static_cast<U*>(wksp.GetOpMem().get_data_handle());
798         }
799       } else {
800         // There is actually no workspace tensor out, so we make a dummy one.
801         size_t workspace_bytes = 0;
802         workspace_tf_shape.AddDim(workspace_bytes);
803         AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
804                           &batch_mean_tensor, &batch_variance_tensor,
805                           &saved_mean_tensor, &saved_variance_tensor,
806                           &reserved_space_tensor);
807       }
808 
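      // Select which statistics seed the saved_mean/saved_variance outputs:
      // the freshly allocated batch statistics when training, or the estimated
      // statistics supplied as inputs when doing inference.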
809       if (is_training_)
810         SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
811       else
812         SetMeanVariance(est_mean_tensor, est_variance_tensor);
813 
814       // MKL-DNN packs scale & shift as "weights":
815       // <scale>...<scale><shift>...<shift>
816       weights.AllocateBuffer(2 * depth_ * sizeof(U));
817       U* weights_data = reinterpret_cast<U*>(weights.GetAllocatedBuffer());
818       const U* scale_tf = scale_tensor.flat<U>().data();
819       const U* shift_tf = shift_tensor.flat<U>().data();
820 
821       std::memcpy(weights_data, scale_tf, depth_ * sizeof(U));
822       std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(U));
823       char* saved_mean_data_tf =
824           reinterpret_cast<char*>(saved_mean_tensor->flat<U>().data());
825       std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_),
826                   depth_ * sizeof(U));
827 
828       char* saved_variance_data_tf =
829           reinterpret_cast<char*>(saved_variance_tensor->flat<U>().data());
830       std::memcpy(saved_variance_data_tf,
831                   reinterpret_cast<char*>(variance_values_),
832                   depth_ * sizeof(U));
833 
834       // Check if reorder is needed for src.
835       const T* src_data = nullptr;
836       std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
837       if (!native_format && src_md != bn_fwd_pd->src_desc()) {
838         src.SetUsrMem(src_md, &src_tensor);
839         src.CheckReorderToOpMem(bn_fwd_pd->src_desc(), cpu_engine_, context);
840         src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
841       } else {
842         src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
843       }
844 
845       // Allocate output (dst) tensor
846       MklDnnShape dnn_shape_dst;
847       TensorShape tf_shape_dst;
848       dnn_shape_dst.SetMklTensor(true);
849       auto dst_pd = bn_fwd->GetDstPd();
850       dnn_shape_dst.SetMklLayout(&dst_pd);
851       dnn_shape_dst.SetElemType(MklDnnType<T>());
852       auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension()
853                                                : src_tensor.shape().dims();
854       dnn_shape_dst.SetTfLayout(ndims, src_dims, mkl_tensor_fmt);
855       tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
856       if (native_format) {
857         tf_shape_dst = dnn_shape_dst.GetTfShape();
858       }
859       AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst,
860                                 dnn_shape_dst, native_format);
861 
862       U* weights_op_data = weights_data;
863       U* mean_op_data = saved_mean_tensor->flat<U>().data();
864       U* variance_op_data = saved_variance_tensor->flat<U>().data();
865       T* dst_data = dst_tensor->flat<T>().data();
866 
867       // Execute
868       std::shared_ptr<stream> fwd_cpu_stream;
869       fwd_cpu_stream.reset(CreateStream(context, bn_fwd->GetEngine()));
870       bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
871                       variance_op_data, fwd_cpu_stream, ws_data);
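      // adjust_factor = N / (N - 1), where N is the number of elements per
      // channel; it converts the population variance computed by MKL-DNN into
      // the Bessel-corrected sample variance that TensorFlow reports.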
872       float adjust_factor = 1.0;
873       if (is_training_) {
874         size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
875         size_t adjust_size = (orig_size > 1) ? (orig_size - 1) : 1;
876         adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
877       }
878 
879       auto mean_data = reinterpret_cast<U*>(saved_mean_data_tf);
880       auto variance_data = reinterpret_cast<U*>(saved_variance_data_tf);
881       auto batch_mean_data = batch_mean_tensor->flat<U>().data();
882       auto batch_variance_data = batch_variance_tensor->flat<U>().data();
883       auto est_mean_data = est_mean_tensor.flat<U>().data();
884       auto est_variance_data = est_variance_tensor.flat<U>().data();
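      // Update the batch_mean/batch_variance outputs: with an exponential
      // average factor of 1 the batch statistics are used directly, otherwise
      // they are blended with the incoming running estimates.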
885       if (is_training_) {
886         if (exponential_avg_factor_ == U(1.0)) {
887           for (int k = 0; k < depth_; k++) {
888             batch_mean_data[k] = mean_data[k];
889             batch_variance_data[k] =
890                 static_cast<U>(adjust_factor) * variance_data[k];
891           }
892         } else {
893           U one_minus_factor = U(1.0) - exponential_avg_factor_;
894           for (int k = 0; k < depth_; k++) {
895             batch_mean_data[k] = one_minus_factor * est_mean_data[k] +
896                                  exponential_avg_factor_ * mean_data[k];
897             batch_variance_data[k] = one_minus_factor * est_variance_data[k] +
898                                      exponential_avg_factor_ *
899                                          static_cast<U>(adjust_factor) *
900                                          variance_data[k];
901           }
902         }
903       } else {
904         std::memcpy(batch_mean_data, mean_data, depth_ * sizeof(U));
905         std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(U));
906       }
907     } catch (mkldnn::error& e) {
908       string error_msg = "Status: " + std::to_string(e.status) +
909                          ", message: " + string(e.message) + ", in file " +
910                          string(__FILE__) + ":" + std::to_string(__LINE__);
911       OP_REQUIRES_OK(
912           context,
913           errors::Aborted("Operation received an exception:", error_msg));
914     }
915   }
916 
917  private:
918   float epsilon_;
919   U exponential_avg_factor_;
920   TensorFormat tensor_format_;
921   bool is_training_;
922   U* mean_values_;
923   U* variance_values_;
924   size_t depth_;  // Batch normalization is performed per channel.
925   FusedBNActivationMode activation_mode_;
926   engine cpu_engine_ = engine(engine::kind::cpu, 0);
927 
928   void ExtractParams(OpKernelContext* context) {
929     const Tensor& input = MklGetInput(context, 0);
930     depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
931   }
932 
933   void SetMeanVariance(const Tensor& mean, const Tensor& variance) {
934     mean_values_ = reinterpret_cast<U*>(const_cast<U*>(mean.flat<U>().data()));
935     variance_values_ =
936         reinterpret_cast<U*>(const_cast<U*>(variance.flat<U>().data()));
937   }
938 
939   void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
940                         TensorShape workspace_tf_shape,
941                         TensorShape tf_shape_scale, Tensor** dst_tensor) {
942     DCHECK(dst_tensor);
943 
944     const size_t kDstIndex = 0;
945     MklDnnShape dnn_shape_dst;
946     dnn_shape_dst.SetMklTensor(false);
947     AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src,
948                               dnn_shape_dst, native_format);
949     DCHECK(*dst_tensor);
950     memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0,
951            (*dst_tensor)->tensor_data().size());
952 
953     Tensor* batch_mean_tensor = nullptr;
954     Tensor* batch_variance_tensor = nullptr;
955     Tensor* saved_mean_tensor = nullptr;
956     Tensor* saved_variance_tensor = nullptr;
957     Tensor* reserved_space_tensor = nullptr;
958     AllocateTFOutputs(context, tf_shape_scale, workspace_tf_shape,
959                       &batch_mean_tensor, &batch_variance_tensor,
960                       &saved_mean_tensor, &saved_variance_tensor,
961                       &reserved_space_tensor);
962   }
963 
964   void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
965                          TensorShape workspace_tf_shape,
966                          Tensor** batch_mean_tensor,
967                          Tensor** batch_variance_tensor,
968                          Tensor** saved_mean_tensor,
969                          Tensor** saved_variance_tensor,
970                          Tensor** reserved_space_tensor) {
971     DCHECK(batch_mean_tensor);
972     DCHECK(batch_variance_tensor);
973     DCHECK(saved_mean_tensor);
974     DCHECK(saved_variance_tensor);
975 
976     const size_t kBatchMeanIndex = 1;
977     const size_t kBatchVarianceIndex = 2;
978     const size_t kSavedMeanIndex = 3;
979     const size_t kSavedVarianceIndex = 4;
980     const size_t kReservedSpaceIndex = 5;
981 
982     // Allocate batch mean output tensor.
983     MklDnnShape mkl_shape_batch_mean;
984     mkl_shape_batch_mean.SetMklTensor(false);
985     AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor,
986                               tf_shape_scale, mkl_shape_batch_mean,
987                               native_format);
988     DCHECK(*batch_mean_tensor);
989 
990     // Set NAN mean value in case of empty input tensor
991     int num_elements = tf_shape_scale.num_elements();
992     auto batch_mean_data = (*batch_mean_tensor)->flat<U>().data();
993     std::fill_n(batch_mean_data, num_elements, static_cast<U>(NAN));
994 
995     // Allocate batch variance output tensor.
996     MklDnnShape mkl_shape_batch_variance;
997     mkl_shape_batch_variance.SetMklTensor(false);
998     AllocateOutputSetMklShape(context, kBatchVarianceIndex,
999                               batch_variance_tensor, tf_shape_scale,
1000                               mkl_shape_batch_variance, native_format);
1001     DCHECK(*batch_variance_tensor);
1002 
1003     // Set NAN variance value in case of empty input tensor
1004     auto batch_variance_data = (*batch_variance_tensor)->flat<U>().data();
1005     std::fill_n(batch_variance_data, num_elements, static_cast<U>(NAN));
1006     // Mean and variance (without Bessel's correction) saved for backward
1007     // computation to serve as pre-computed mean and variance.
1008     MklDnnShape mkl_shape_saved_mean;
1009     mkl_shape_saved_mean.SetMklTensor(false);
1010     AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor,
1011                               tf_shape_scale, mkl_shape_saved_mean,
1012                               native_format);
1013     DCHECK(*saved_mean_tensor);
1014 
1015     // Set 0 mean value in case of empty input tensor
1016     auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data();
1017     std::fill_n(saved_mean_data, num_elements, static_cast<U>(0));
1018 
1019     MklDnnShape mkl_shape_saved_variance;
1020     mkl_shape_saved_variance.SetMklTensor(false);
1021     AllocateOutputSetMklShape(context, kSavedVarianceIndex,
1022                               saved_variance_tensor, tf_shape_scale,
1023                               mkl_shape_saved_variance, native_format);
1024     DCHECK(*saved_variance_tensor);
1025 
1026     // Set 0 variance value in case of empty input tensor
1027     auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data();
1028     std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));
1029 
1030     // Changes to support reserved_space_3 parameter in FusedBatchNormV3.
1031     if (reserved_space) {
1032       DCHECK(reserved_space_tensor != nullptr);
1033 
1034       MklDnnShape mkl_shape_reserved_space;
1035       mkl_shape_reserved_space.SetMklTensor(false);
1036       AllocateOutputSetMklShape(context, kReservedSpaceIndex,
1037                                 reserved_space_tensor, workspace_tf_shape,
1038                                 mkl_shape_reserved_space, native_format);
1039       DCHECK((*reserved_space_tensor) != nullptr);
1040     }
1041   }
1042 };
1043 
1044 template <typename Device, typename T, typename U, bool reserved_space,
1045           bool native_format = false>
1046 class MklFusedBatchNormGradOp : public OpKernel {
1047  public:
1048   explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
1049       : OpKernel(context) {
1050     float epsilon;
1051     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1052     epsilon_ = epsilon;
1053     string tensor_format;
1054     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
1055     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
1056                 errors::InvalidArgument("Invalid data format"));
1057     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1058     depth_ = 0;
1059   }
1060 
1061   void Compute(OpKernelContext* context) override {
1062     try {
1063       const size_t kDiffDstIndex = 0;        // index of diff_dst tensor
1064       const size_t kSrcIndex = 1;            // index of src input tensor
1065       const size_t kScaleIndex = 2;          // index of scale tensor
1066       const size_t kMeanIndex = 3;           // index of saved_mean tensor
1067       const size_t kVarianceIndex = 4;       // index of saved_variance tensor
1068       const size_t kReservedSpaceIndex = 5;  // index of reserved space 3 tensor
1069 
1070       const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
1071       const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
1072       const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
1073       const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
1074       const Tensor& saved_variance_tensor =
1075           MklGetInput(context, kVarianceIndex);
1076       const Tensor& reserved_space_tensor =
1077           (reserved_space) ? MklGetInput(context, kReservedSpaceIndex)
1078                            : Tensor();
1079 
1080       MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
1081       GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format);
1082       GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst, native_format);
1083 
1084       TensorShape tf_shape_src, tf_shape_diff_dst;
1085       if (dnn_shape_diff_dst.IsMklTensor()) {
1086         tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
1087         OP_REQUIRES(
1088             context, dnn_shape_diff_dst.GetDimension() == 4,
1089             errors::InvalidArgument("input must be 4-dimensional",
1090                                     diff_dst_tensor.shape().DebugString()));
1091       } else {
1092         tf_shape_diff_dst = diff_dst_tensor.shape();
1093         OP_REQUIRES(
1094             context, diff_dst_tensor.dims() == 4,
1095             errors::InvalidArgument("input must be 4-dimensional",
1096                                     diff_dst_tensor.shape().DebugString()));
1097       }
1098 
1099       if (dnn_shape_src.IsMklTensor()) {
1100         tf_shape_src = dnn_shape_src.GetTfShape();
1101         OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
1102                     errors::InvalidArgument("input must be 4-dimensional",
1103                                             src_tensor.shape().DebugString()));
1104       } else {
1105         tf_shape_src = src_tensor.shape();
1106         OP_REQUIRES(context, src_tensor.dims() == 4,
1107                     errors::InvalidArgument("input must be 4-dimensional",
1108                                             src_tensor.shape().DebugString()));
1109       }
1110 
1111       OP_REQUIRES(context, scale_tensor.dims() == 1,
1112                   errors::InvalidArgument("scale must be 1-dimensional",
1113                                           scale_tensor.shape().DebugString()));
1114       OP_REQUIRES(
1115           context, saved_mean_tensor.dims() == 1,
1116           errors::InvalidArgument("saved mean must be 1-dimensional",
1117                                   saved_mean_tensor.shape().DebugString()));
1118 
1119       OP_REQUIRES(
1120           context, saved_variance_tensor.dims() == 1,
1121           errors::InvalidArgument("saved variance must be 1-dimensional",
1122                                   saved_variance_tensor.shape().DebugString()));
1123 
1124       // Handle the special case: input with 0 element and 0 batch size.
1125       Tensor* diff_src_tensor = nullptr;
1126       if (tf_shape_src.num_elements() == 0 ||
1127           tf_shape_diff_dst.num_elements() == 0) {
1128         HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
1129                          &diff_src_tensor);
1130         return;
1131       }
1132 
1133       if (dnn_shape_src.IsMklTensor()) {
1134         depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
1135       } else if (dnn_shape_diff_dst.IsMklTensor()) {
1136         depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C);
1137       } else {
1138         ExtractParams(context);
1139       }
1140 
1141       memory::format_tag dnn_fmt;
1142       MklTensorFormat mkl_tensor_fmt;
1143       if (dnn_shape_src.IsMklTensor()) {
1144         if (dnn_shape_src.IsTensorInNCHWFormat()) {
1145           dnn_fmt = memory::format_tag::nchw;
1146           mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW;
1147         } else {
1148           dnn_fmt = memory::format_tag::nhwc;
1149           mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC;
1150         }
1151       } else {
1152         mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_);
1153         dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt);
1154       }
1155 
1156       MklDnnData<T> src(&cpu_engine_);
1157       MklDnnData<T> diff_dst(&cpu_engine_);
1158       MklDnnData<U> weights(&cpu_engine_);
1159       MklDnnData<U> diff_weights(&cpu_engine_);
1160 
1161       memory::dims src_dims =
1162           dnn_shape_src.IsMklTensor()
1163               ? dnn_shape_src.GetSizesAsMklDnnDims()
1164               : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
1165       memory::dims diff_dst_dims =
1166           dnn_shape_diff_dst.IsMklTensor()
1167               ? dnn_shape_diff_dst.GetSizesAsMklDnnDims()
1168               : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
1169                                           tensor_format_);
1170 
1171       // Set src and diff_dst primitive descriptors.
1172       memory::desc src_md =
1173           dnn_shape_src.IsMklTensor()
1174               ? dnn_shape_src.GetMklLayout()
1175               : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
1176       memory::desc diff_dst_md =
1177           dnn_shape_diff_dst.IsMklTensor()
1178               ? dnn_shape_diff_dst.GetMklLayout()
1179               : memory::desc(diff_dst_dims, MklDnnType<T>(), dnn_fmt);
1180 
1181       MklDnnData<T> reorder_src(&cpu_engine_);
1182       MklDnnData<T> reorder_diff_dst(&cpu_engine_);
1183       T* diff_dst_data =
1184           static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
1185       T* src_data =
1186           static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
1187 
      if (!native_format) {
        // MKL-DNN requires src and diff_dst to be in the same memory layout,
        // either blocked or native format. If these inputs are in different
        // formats, convert the one in native format to blocked format, since
        // MKL-DNN gives better performance for the blocked format.
        if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
          reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
          reorder_diff_dst.CheckReorderToOpMem(src_md, cpu_engine_, context);
          diff_dst_md = src_md;
          diff_dst_data =
              static_cast<T*>(reorder_diff_dst.GetOpMem().get_data_handle());
        } else if (!dnn_shape_src.IsMklTensor() &&
                   dnn_shape_diff_dst.IsMklTensor()) {
          reorder_src.SetUsrMem(src_md, &src_tensor);
          reorder_src.CheckReorderToOpMem(diff_dst_md, cpu_engine_, context);
          src_md = diff_dst_md;
          src_data = static_cast<T*>(reorder_src.GetOpMem().get_data_handle());
        }
      }
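      // Example (illustrative scenario only): if src arrives in a blocked MKL
      // layout while diff_dst arrives as plain NHWC, the first branch above
      // reorders diff_dst into src's blocked layout and reuses src_md as
      // diff_dst_md, so both inputs feed the backward primitive in a single,
      // consistent layout.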

      // Weights: MKL-DNN packs the scales and shifts into a single weights
      // buffer, all scales first followed by all shifts, i.e.
      // scale[0], ..., scale[depth-1], shift[0], ..., shift[depth-1].
      weights.AllocateBuffer(2 * depth_ * sizeof(U));
      U* weights_data_tf = reinterpret_cast<U*>(weights.GetAllocatedBuffer());
      const U* scale_tf = scale_tensor.flat<U>().data();
      for (int k = 0; k < depth_; k++) {
        weights_data_tf[k] = scale_tf[k];
        weights_data_tf[k + depth_] = static_cast<U>(0);
      }
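      // For illustration (hypothetical values): with depth_ == 2 and
      // scale = {s0, s1}, the packed buffer is {s0, s1, 0, 0} -- the shift
      // half is simply zero-filled by the loop above.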

      diff_weights.AllocateBuffer(2 * depth_ * sizeof(U));

      MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_,
                                      is_training_, src_md, diff_dst_md);
      MklFusedBatchNormBwdPrimitive<T, U>* bn_bwd =
          MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams);

      // Check if diff_dst input needs to be reordered
      std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd();
      if (!native_format && diff_dst_md != bn_bwd_pd->diff_dst_desc()) {
        diff_dst.SetUsrMem(diff_dst_md, diff_dst_data);
        diff_dst.CheckReorderToOpMem(bn_bwd_pd->diff_dst_desc(), cpu_engine_,
                                     context);
        diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
      }

      if (!native_format && (src_md != bn_bwd_pd->src_desc())) {
        src.SetUsrMem(src_md, src_data);
        src.CheckReorderToOpMem(bn_bwd_pd->src_desc(), cpu_engine_, context);
        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
      }
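      // If a reorder was needed in either check above (non-native path),
      // src_data and diff_dst_data now point at the reordered buffers that
      // match the descriptors expected by the cached backward primitive.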

      // Indices of output tensors
      const size_t kDiffSrcIndex = 0;

      // Allocate the output tensor diff_src; in the non-native path it is set
      // up with an MKL-DNN layout.
      MklDnnShape dnn_shape_diff_src;
      TensorShape tf_shape_diff_src;
      dnn_shape_diff_src.SetMklTensor(true);
      auto diff_src_pd = bn_bwd->GetDiffSrcPd();
      dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
      dnn_shape_diff_src.SetElemType(MklDnnType<T>());
      dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, mkl_tensor_fmt);
      dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_);
      tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      if (native_format) {
        tf_shape_diff_src = dnn_shape_diff_src.GetTfShape();
      }
      AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
                                tf_shape_diff_src, dnn_shape_diff_src,
                                native_format);
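      // Note: in the non-native path the TF shape is a flat 1-D buffer sized
      // to the primitive's diff_src descriptor (get_size() / sizeof(T)); the
      // logical shape travels in the MklDnnShape metadata set up above. With
      // native_format the plain TF shape is used directly.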

      U* mean_data =
          static_cast<U*>(const_cast<U*>(saved_mean_tensor.flat<U>().data()));
      U* variance_data = static_cast<U*>(
          const_cast<U*>(saved_variance_tensor.flat<U>().data()));
      U* weights_data = weights_data_tf;
      T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data());
      U* diff_weights_data = static_cast<U*>(diff_weights.GetAllocatedBuffer());

      U* res_space_data =
          ((reserved_space) ? static_cast<U*>(const_cast<U*>(
                                  reserved_space_tensor.flat<U>().data()))
                            : nullptr);

      // Execute
      std::shared_ptr<stream> bwd_cpu_stream;
      bwd_cpu_stream.reset(CreateStream(context, bn_bwd->GetEngine()));
      bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
                      weights_data, diff_src_data, diff_weights_data,
                      res_space_data, bwd_cpu_stream);
      // Allocate output TF tensors diff_scale and diff_shift.
      Tensor* diff_scale_tensor = nullptr;
      Tensor* diff_shift_tensor = nullptr;
      AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor,
                        &diff_shift_tensor);

      // Copy data for tensors diff_scale and diff_shift.
      auto diff_scale_data = diff_scale_tensor->flat<U>().data();
      auto diff_shift_data = diff_shift_tensor->flat<U>().data();
      std::memcpy(reinterpret_cast<char*>(diff_scale_data),
                  reinterpret_cast<char*>(diff_weights_data),
                  depth_ * sizeof(U));
      std::memcpy(reinterpret_cast<char*>(diff_shift_data),
                  reinterpret_cast<char*>(diff_weights_data + depth_),
                  depth_ * sizeof(U));
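      // diff_weights is packed the same way as the forward weights buffer
      // (all scale gradients followed by all shift gradients), so the two
      // memcpy calls above split it at offset depth_.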
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception: ", error_msg));
    }
  }

 private:
  float epsilon_;
  TensorFormat tensor_format_;
  size_t depth_;  // Batch normalization is performed per channel.
  bool is_training_;
  engine cpu_engine_ = engine(engine::kind::cpu, 0);

  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }

  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape tf_shape_scale_shift,
                        Tensor** diff_src_tensor) {
    const size_t kDiffSrcIndex = 0;

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
                              tf_shape_src, dnn_shape_diff_src, native_format);
    auto diff_src_data = (*diff_src_tensor)->flat<T>().data();
    std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(),
                static_cast<T>(0));

    Tensor* diff_scale_tensor = nullptr;
    Tensor* diff_shift_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor,
                      &diff_shift_tensor);
  }

  void AllocateTFOutputs(OpKernelContext* context,
                         TensorShape tf_shape_scale_shift,
                         Tensor** diff_scale_tensor,
                         Tensor** diff_shift_tensor) {
    DCHECK(diff_scale_tensor);
    DCHECK(diff_shift_tensor);

    const size_t kDiffScaleIndex = 1;
    const size_t kDiffShiftIndex = 2;
    const size_t kP1Index = 3;
    const size_t kP2Index = 4;

    // Allocate separate output tensors for the scale and shift gradients and
    // zero-initialize them; the caller copies the computed gradients in
    // afterwards when the input is non-empty.
    MklDnnShape mkl_shape_diff_scale;
    mkl_shape_diff_scale.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_scale,
                              native_format);
    DCHECK(*diff_scale_tensor);

    auto diff_scale_data = (*diff_scale_tensor)->flat<U>().data();
    std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(),
                static_cast<U>(0));

    MklDnnShape mkl_shape_diff_shift;
    mkl_shape_diff_shift.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_shift,
                              native_format);
    DCHECK(*diff_shift_tensor);

    auto diff_shift_data = (*diff_shift_tensor)->flat<U>().data();
    std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(),
                static_cast<U>(0));

    // Placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    Tensor *p1_tensor = nullptr, *p2_tensor = nullptr;
    MklDnnShape mkl_shape_p;
    mkl_shape_p.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
                              mkl_shape_p, native_format);
    std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(),
                static_cast<U>(0));
    AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
                              mkl_shape_p, native_format);
    std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(),
                static_cast<U>(0));
  }

  memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
};

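// Kernel registrations. Each macro below registers two flavors of the op for
// the given type(s): an MKL-layout-dependent kernel
// (kMklLayoutDependentOpLabel) and a native-format kernel
// (kMklNameChangeOpLabel) under the corresponding _MklNative* name.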
#define REGISTER_MKL_FUSED_BATCHNORM_CPU(T)                    \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedBatchNorm")                               \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedBatchNormOp<CPUDevice, T, T, false, false>);     \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklNativeFusedBatchNorm")                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
      MklFusedBatchNormOp<CPUDevice, T, T, false, false, true>);

TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU);
TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
#undef REGISTER_MKL_FUSED_BATCHNORM_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(T, U)              \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedBatchNormV2")                             \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedBatchNormOp<CPUDevice, T, U, false, false>);     \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklNativeFusedBatchNormV2")                       \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
      MklFusedBatchNormOp<CPUDevice, T, U, false, false, true>);

REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_V2_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU(T)               \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedBatchNormGrad")                           \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedBatchNormGradOp<CPUDevice, T, T, false>);        \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklNativeFusedBatchNormGrad")                     \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
      MklFusedBatchNormGradOp<CPUDevice, T, T, false, true>);

TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU);
TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(T, U)         \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedBatchNormGradV2")                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedBatchNormGradOp<CPUDevice, T, U, false>);        \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklNativeFusedBatchNormGradV2")                   \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
      MklFusedBatchNormGradOp<CPUDevice, T, U, false, true>);

REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU

// TODO: FusedBatchNormV3 has an additional output that is used to
//       hold intermediate results. That functionality is not implemented
//       on CPU.
#define REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(T, U)               \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNormV3")                              \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormOp<CPUDevice, T, U, true, false>);       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNormEx")                              \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormOp<CPUDevice, T, U, true, true>);        \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNormV3")                        \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormOp<CPUDevice, T, U, true, false, true>); \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNormEx")                        \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormOp<CPUDevice, T, U, true, true, true>);

REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU

REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        NoOp);
REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<bfloat16>("T")
                            .TypeConstraint<float>("U"),
                        NoOp);
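// Note: the unlabeled _FusedBatchNormEx registrations above map the op to
// NoOp on CPU. A plausible reading (assumption, not stated in this file) is
// that only the MKL-labeled kernels are meant to execute, and the NoOp
// registration merely keeps the op name resolvable on CPU.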

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U)         \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedBatchNormGradV3")                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedBatchNormGradOp<CPUDevice, T, U, true>);         \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklNativeFusedBatchNormGradV3")                   \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
      MklFusedBatchNormGradOp<CPUDevice, T, U, true, true>);

REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU

}  // namespace tensorflow

#undef GET_FLAG
#undef IS_SET

#endif  // INTEL_MKL