/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/batchnorm_expander.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// BatchNormExpanderVisitor traverses the HLO computation and rewrites
// BatchNorm operations (training, inference, and grad) into sequences of
// simpler HLO operations: broadcasts, reductions, and elementwise arithmetic.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
 public:
  // Default visitor action is to do nothing and return OK.
  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;

  Status HandleBatchNormInference(HloInstruction* batch_norm) override;

  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

  // Runs the visitor on a computation.
  static bool Run(HloComputation* computation, bool rewrite_training_op,
                  bool rewrite_inference_op, bool rewrite_grad_op,
                  bool use_fusion);

  // Returns whether any batch norm ops were rewritten.
  bool changed() const { return changed_; }

  ~BatchNormExpanderVisitor() override = default;

 private:
  explicit BatchNormExpanderVisitor(HloComputation* computation,
                                    bool rewrite_training_op,
                                    bool rewrite_inference_op,
                                    bool rewrite_grad_op, bool use_fusion)
      : computation_(computation),
        rewrite_training_op_(rewrite_training_op),
        rewrite_inference_op_(rewrite_inference_op),
        rewrite_grad_op_(rewrite_grad_op),
        use_fusion_(use_fusion) {}

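  // Builds a scalar computation of the form (lhs, rhs) -> lhs <opcode> rhs
  // for the given element type and adds it to the module. Used as the
  // reduction computation for the Reduce instructions created below.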
  HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type,
                                             HloOpcode opcode) {
    HloComputation::Builder b("scalar_computation");
    auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
        1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs"));
    auto scalar_op = b.AddInstruction(
        HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
                                     opcode, scalar_lhs, scalar_rhs));
    return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
  }

  // Current HloComputation instance the BatchNormExpander is
  // traversing.
  HloComputation* computation_;

  bool rewrite_training_op_;
  bool rewrite_inference_op_;
  bool rewrite_grad_op_;
  bool use_fusion_;

  // Whether rewrite has occurred.
  bool changed_ = false;

  // Replaces the existing HLO instruction old_instruction with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction) {
    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        old_instruction, std::move(new_instruction)));
    changed_ = true;
    return Status::OK();
  }

  // Replaces the existing HLO instruction old_instruction with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction) {
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(old_instruction, new_instruction));
    changed_ = true;
    return Status::OK();
  }
};

bool BatchNormExpanderVisitor::Run(HloComputation* computation,
                                   bool rewrite_training_op,
                                   bool rewrite_inference_op,
                                   bool rewrite_grad_op, bool use_fusion) {
  BatchNormExpanderVisitor visitor(
      computation,
      /*rewrite_training_op=*/rewrite_training_op,
      /*rewrite_inference_op=*/rewrite_inference_op,
      /*rewrite_grad_op=*/rewrite_grad_op,
      /*use_fusion=*/use_fusion);
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed_;
}

Status BatchNormExpanderVisitor::HandleBatchNormTraining(
    HloInstruction* batch_norm) {
  if (!rewrite_training_op_) {
    return Status::OK();
  }

  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  int64 instruction_count_before = computation_->instruction_count();

  // Expand batch norm training into smaller HLO ops.
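  // For each feature channel the expansion computes
  //   mean = Sum[X] / N
  //   var  = Sum[X^2] / N - mean^2
  //   Y    = scale * (X - mean) / Sqrt[var + epsilon] + offset
  // where N is the number of elements per feature, and returns the tuple
  // (Y, mean, var).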
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  PrimitiveType ptype = operand_shape.element_type();
  int64 feature_index = batch_norm->feature_index();
  const int64 feature_count = operand_shape.dimensions(feature_index);
  const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape);
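  // N: the number of elements reduced per feature. The per-feature sums
  // computed below are divided by it to obtain means.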
  auto elements_per_feature_literal =
      Literal::CreateR0<float>(size_in_elements / feature_count);
  TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
                      elements_per_feature_literal->Convert(ptype));
  auto elements_per_feature = add(
      HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  const Shape feature_shape = scale->shape();

  auto zero_literal = Literal::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
  auto epsilon =
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
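  // Reduction dimensions: every dimension except the feature dimension, so
  // that the reductions below produce per-feature statistics.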
  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  HloComputation* add_reduce_computation =
      GetScalarBinaryComputation(ptype, HloOpcode::kAdd);

  // X^2.
  auto operand_squared = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kMultiply, operand, operand));
  // Sum[X].
  auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero,
                                              dimensions_without_feature,
                                              add_reduce_computation));

  // Sum[X^2].
  auto squared_sum = add(HloInstruction::CreateReduce(
      feature_shape, operand_squared, zero, dimensions_without_feature,
      add_reduce_computation));

  // Fuse two parallel reduces together to improve performance.
  if (use_fusion_ && !batch_norm->has_sharding()) {
    auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum}));

    auto fused = computation_->CreateFusionInstruction(
        {tuple, sum, squared_sum, operand_squared},
        HloInstruction::FusionKind::kInput);

    sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));

    squared_sum =
        add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
  }

  // E[X].
  auto mean = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kDivide, sum, elements_per_feature));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  // E[X^2].
  auto square_mean = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature));

  // E^2[X].
  auto mean_square = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kMultiply, mean, mean));

  // Var[X].
  auto var = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kSubtract, square_mean, mean_square));

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

  auto neg_half_literal = Literal::CreateR0(-0.5f);
  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
  auto neg_half =
      add(HloInstruction::CreateConstant(std::move(neg_half_literal)));

  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));

  // X - E[X].
  auto operand_minus_mean = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add(
      HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
                                   operand_minus_mean, rsqrt_var_add_epsilon));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted));

  auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var});

  if (batch_norm->has_sharding()) {
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
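    // Use the sharding of the first tuple element (the normalized output) for
    // all intermediates that have the operand shape; replicate everything
    // else.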
    HloSharding operand_sharding =
        batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(operand_sharding);
      } else {
        inst->set_sharding(HloSharding::Replicate());
      }
    }
    tuple->set_sharding(batch_norm->sharding());
  }
  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormInference(
    HloInstruction* batch_norm) {
  if (!rewrite_inference_op_) {
    return Status::OK();
  }
  // Expand batch norm inference into smaller HLO ops.
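  // The op is rewritten as
  //   scale * (X - mean) / Sqrt[var + epsilon] + offset
  // using the mean and variance operands supplied to the instruction.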
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  int64 feature_index = batch_norm->feature_index();
  PrimitiveType ptype = operand_shape.element_type();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  HloInstruction* mean = batch_norm->mutable_operand(3);
  HloInstruction* var = batch_norm->mutable_operand(4);
  const Shape feature_shape = scale->shape();

  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
  auto epsilon = computation_->AddInstruction(
      HloInstruction::CreateConstant(std::move(epsilon_literal)));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  int64 instruction_count_before = computation_->instruction_count();

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

  auto neg_half_literal = Literal::CreateR0(-0.5f);
  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
  auto neg_half =
      add(HloInstruction::CreateConstant(std::move(neg_half_literal)));

  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));

  // X - E[X].
  auto operand_minus_mean = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add(
      HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
                                   operand_minus_mean, rsqrt_var_add_epsilon));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);

  int64 instruction_count_after = computation_->instruction_count();
  CHECK_EQ(instruction_count_after,
           instruction_count_before + added_instructions.size());
  if (batch_norm->has_sharding()) {
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(batch_norm->sharding());
      } else {
        inst->set_sharding(HloSharding::Replicate());
      }
    }
    shifted_normalized->set_sharding(batch_norm->sharding());
  }
  TF_CHECK_OK(
      ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormGrad(
    HloInstruction* batch_norm) {
  // Use the following formulas to calculate gradients:
  // scale_grad =
  //   sum(output_grad * (activation - mean(activation))) * rsqrt(var + epsilon)
  //
  // offset_grad =
  //   sum(output_grad)
  //
  // activation_grad =
  //   1/N * scale * rsqrt(var + epsilon) *
  //   (N * output_grad - sum(output_grad) - (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon))
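  //
  // The expansion below names the intermediate terms of activation_grad:
  //   I1 = N * Grad[Y]
  //   I2 = Sum(Grad[Y]), broadcast to the activation shape
  //   I3 = Sum(Grad[Y] * (X - E[X])), broadcast to the activation shape
  //   I4 = (X - E[X]) * I3
  //   I5 = I4 / (Var[X] + epsilon)
  //   I6 = I1 - I2 - I5
  // so that Grad[X] = scale * rsqrt(Var[X] + epsilon) * 1/N * I6.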
  if (!rewrite_grad_op_) {
    return Status::OK();
  }
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  int64 instruction_count_before = computation_->instruction_count();

  HloInstruction* activation = batch_norm->mutable_operand(0);
  const Shape activation_shape = activation->shape();
  PrimitiveType ptype = activation_shape.element_type();
  HloInstruction* scale = batch_norm->mutable_operand(1);
  const Shape feature_shape = scale->shape();
  HloInstruction* mean = batch_norm->mutable_operand(2);
  HloInstruction* variance = batch_norm->mutable_operand(3);
  HloInstruction* grad_output = batch_norm->mutable_operand(4);

  int64 feature_index = batch_norm->feature_index();

  const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape);
  const int64 feature_count = activation_shape.dimensions(feature_index);
  auto elements_per_feature_literal =
      Literal::CreateR0<float>(size_in_elements / feature_count);
  TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
                      elements_per_feature_literal->Convert(ptype));
  auto elements_per_feature = add(
      HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));

  auto zero_literal = Literal::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto neg_half_literal = Literal::CreateR0(-0.5f);
  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
  auto neg_half =
      add(HloInstruction::CreateConstant(std::move(neg_half_literal)));

  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
  auto epsilon =
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < ShapeUtil::Rank(activation_shape); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, scale, {feature_index}));
  auto variance_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, variance, {feature_index}));

  // E[X].
  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index}));

  // rsqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kPower,
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
                                       variance_broadcasted, epsilon)),
      neg_half));

  auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kPower,
      add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance,
                                       epsilon)),
      neg_half));

  // X - E[X].
  auto activation_minus_mean = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted));

  // Grad[Y] * (X - E[X]).
  auto grad_output_times_activation_minus_mean =
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
                                       grad_output, activation_minus_mean));

  HloComputation* add_reduce_computation =
      GetScalarBinaryComputation(ptype, HloOpcode::kAdd);

  // sum(Grad[Y] * (X - E[X])).
  auto sum_grad_output_times_activation_minus_mean =
      add(HloInstruction::CreateReduce(
          feature_shape, grad_output_times_activation_minus_mean, zero,
          dimensions_without_feature, add_reduce_computation));

  // Grad[beta] = Sum(Grad[Y]).
  auto grad_beta = add(HloInstruction::CreateReduce(
      feature_shape, grad_output, zero, dimensions_without_feature,
      add_reduce_computation));

  if (use_fusion_ && !batch_norm->has_sharding()) {
    auto tuple = add(HloInstruction::CreateTuple(
        {sum_grad_output_times_activation_minus_mean, grad_beta}));

    auto fused = computation_->CreateFusionInstruction(
        {tuple, sum_grad_output_times_activation_minus_mean, grad_beta},
        HloInstruction::FusionKind::kInput);

    sum_grad_output_times_activation_minus_mean =
        add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));

    grad_beta =
        add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
  }

  // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
  auto grad_scale = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kMultiply,
      sum_grad_output_times_activation_minus_mean, rsqrt_var_add_epsilon));

  // I2 = Sum(Grad[Y])
  auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
                                                {feature_index}));

  // I3 = Sum(Grad[Y] * (X - E[X]))
  auto i3 = add(HloInstruction::CreateBroadcast(
      activation_shape, sum_grad_output_times_activation_minus_mean,
      {feature_index}));

  // I4 = (X - E[X]) * I3
  auto i4 = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean));

  // I5 = I4 / (Var[X] + epsilon)
  auto i5 = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kDivide, i4,
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
                                       variance_broadcasted, epsilon))));

  // scale * rsqrt[Var[X] + epsilon] * 1/N
  auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kMultiply, scale_broadcasted,
      rsqrt_var_add_epsilon_broadcasted));

  scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon,
      elements_per_feature));

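  // I1 = N * Grad[Y]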
  auto i1 =
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
                                       grad_output, elements_per_feature));

  // I6 = I1 - I2 - I5
  auto i6 = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kSubtract,
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract,
                                       i1, i2)),
      i5));

  // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6.
  auto grad_activation =
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
                                       scale_times_rsqrt_var_add_epsilon, i6));
  auto tuple =
      HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
  if (batch_norm->has_sharding()) {
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    HloSharding activation_sharding =
        batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
        inst->set_sharding(activation_sharding);
      } else {
        inst->set_sharding(HloSharding::Replicate());
      }
    }
    tuple->set_sharding(batch_norm->sharding());
  }

  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));

  return Status::OK();
}

StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
                                      rewrite_inference_op_, rewrite_grad_op_,
                                      use_fusion_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla