/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/batchnorm_expander.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

namespace {

using absl::optional;
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
 public:
  // Default visitor action is to do nothing and return OK.
  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;

  Status HandleBatchNormInference(HloInstruction* batch_norm) override;

  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

  // Runs the visitor on a computation. Returns whether any batch norm op was
  // rewritten.
  static bool Run(HloComputation* computation, bool rewrite_training_op,
                  bool rewrite_inference_op, bool rewrite_grad_op);

  // Returns whether any batch norm ops were rewritten.
  bool changed() const { return changed_; }

  ~BatchNormExpanderVisitor() override = default;

 private:
  explicit BatchNormExpanderVisitor(HloComputation* computation,
                                    bool rewrite_training_op,
                                    bool rewrite_inference_op,
                                    bool rewrite_grad_op)
      : computation_(computation),
        rewrite_training_op_(rewrite_training_op),
        rewrite_inference_op_(rewrite_inference_op),
        rewrite_grad_op_(rewrite_grad_op) {}

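  // Builds a two-parameter scalar addition computation of the given element
  // type and registers it with the module; it is used as the reduction
  // computation for the reduce instructions emitted by the handlers below.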
  HloComputation* GetOrCreateScalarAddComputation(
      PrimitiveType primitive_type) {
    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(primitive_type, {});
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
  }

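  // Returns an instruction computing 1 / sqrt(operand) elementwise.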
  std::unique_ptr<HloInstruction> Rsqrt(
      HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    return HloInstruction::CreateUnary(operand->shape(), HloOpcode::kRsqrt,
                                       operand);
  }

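  // Returns an instruction computing operand / element_count, with the scalar
  // element_count broadcast to operand's shape; used to turn per-feature sums
  // into means.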
  std::unique_ptr<HloInstruction> Mean(
      HloInstruction* element_count, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    auto broadcast = add_instruction(
        HloInstruction::CreateBroadcast(operand->shape(), element_count, {}));
    return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kDivide,
                                        operand, broadcast);
  }

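  // Computes the number of elements that contribute to each feature: the
  // product of the sizes of all dimensions of operand except the feature
  // dimension. GetDimensionSize is used so that dynamic dimensions are
  // handled; the result is converted to operand's element type.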
  std::unique_ptr<HloInstruction> DynamicElementCountPerFeature(
      HloInstruction* operand, int64 feature_index,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    auto elements_per_feature_u32 = add_instruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)));

    for (int64 i = 0; i < operand->shape().rank(); ++i) {
      if (i == feature_index) {
        continue;
      }
      auto dynamic_dimension_size =
          add_instruction(HloInstruction::CreateGetDimensionSize(
              ShapeUtil::MakeShape(U32, {}), operand, i));
      elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply,
          dynamic_dimension_size, elements_per_feature_u32));
    }

    return HloInstruction::CreateConvert(
        ShapeUtil::MakeShape(operand->shape().element_type(), {}),
        elements_per_feature_u32);
  }

  // Replaces the existing HLO instruction old_instruction with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction) {
    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        old_instruction, std::move(new_instruction)));
    changed_ = true;
    return Status::OK();
  }

  // Replaces the existing HLO instruction old_instruction with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction) {
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(old_instruction, new_instruction));
    changed_ = true;
    return Status::OK();
  }
  // Current HloComputation instance the BatchNormExpander is
  // traversing.
  HloComputation* computation_;

  bool rewrite_training_op_;
  bool rewrite_inference_op_;
  bool rewrite_grad_op_;

  // Whether rewrite has occurred.
  bool changed_ = false;
};

}  // namespace

bool BatchNormExpanderVisitor::Run(HloComputation* computation,
                                   bool rewrite_training_op,
                                   bool rewrite_inference_op,
                                   bool rewrite_grad_op) {
  BatchNormExpanderVisitor visitor(
      computation,
      /*rewrite_training_op=*/rewrite_training_op,
      /*rewrite_inference_op=*/rewrite_inference_op,
      /*rewrite_grad_op=*/rewrite_grad_op);
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed_;
}

Status BatchNormExpanderVisitor::HandleBatchNormTraining(
    HloInstruction* batch_norm) {
  if (!rewrite_training_op_) {
    return Status::OK();
  }

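  // Every instruction created below goes through `add`, which copies the
  // original batch norm's metadata onto it and records it in
  // added_instructions so that sharding can be assigned to it afterwards.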
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  int64 instruction_count_before = computation_->instruction_count();

  // Expand batch norm training into smaller HLO ops.
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  PrimitiveType ptype = operand_shape.element_type();
  int64 feature_index = batch_norm->feature_index();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  const Shape feature_shape = scale->shape();

  auto zero_literal = LiteralUtil::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon = add(HloInstruction::CreateBroadcast(
      operand_shape,
      add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto elements_per_feature =
      add(DynamicElementCountPerFeature(operand, feature_index, add));

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  HloComputation* add_reduce_computation =
      GetOrCreateScalarAddComputation(ptype);

  // X^2.
  auto operand_squared =
      add_binary(operand_shape, HloOpcode::kMultiply, operand, operand);
  // Sum[X].
  auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero,
                                              dimensions_without_feature,
                                              add_reduce_computation));

  // Sum[X^2].
  auto squared_sum = add(HloInstruction::CreateReduce(
      feature_shape, operand_squared, zero, dimensions_without_feature,
      add_reduce_computation));

  // E[X].
  auto mean = add(Mean(elements_per_feature, sum, add));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  // E[X^2].
  auto square_mean = add(Mean(elements_per_feature, squared_sum, add));

  // E^2[X].
  auto mean_square =
      add_binary(feature_shape, HloOpcode::kMultiply, mean, mean);

  // Var[X].
  auto var =
      add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square);

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon =
      add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);

  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));

  // X - E[X].
  auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
                                       operand, mean_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                               operand_minus_mean, rsqrt_var_add_epsilon);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                                      normalized, scale_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd,
                                       scaled_normalized, offset_broadcasted);

  auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var});

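  // Propagate the batch norm's sharding onto the expanded instructions.
  // Instructions whose shape matches the operand take the sharding of the
  // first tuple element; everything else falls back to the batch norm's
  // unique device, if any, or to replicated sharding. The CHECK verifies
  // that every instruction added since the start of the handler went
  // through `add` and is therefore covered by this loop.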
  if (batch_norm->has_sharding()) {
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    const HloSharding& sharding = batch_norm->sharding();
    HloSharding operand_sharding =
        sharding.GetAsShapeTree(batch_norm->shape()).element({0});
    optional<int64> unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(operand_sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    tuple->set_sharding(sharding);
  }
  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormInference(
    HloInstruction* batch_norm) {
  if (!rewrite_inference_op_) {
    return Status::OK();
  }
  // Expand batch norm inference into smaller HLO ops.
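  // The result is
  //   (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset,
  // where the mean and variance are taken from the batch norm's operands
  // rather than computed from the batch.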
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  int64 feature_index = batch_norm->feature_index();
  PrimitiveType ptype = operand_shape.element_type();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  HloInstruction* mean = batch_norm->mutable_operand(3);
  HloInstruction* var = batch_norm->mutable_operand(4);
  const Shape feature_shape = scale->shape();

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
      operand_shape,
      computation_->AddInstruction(
          HloInstruction::CreateConstant(std::move(epsilon_literal))),
      {}));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

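  // As in the training case, `add` copies the batch norm's metadata onto each
  // new instruction and records it for sharding assignment. The final result
  // is created directly (not via `add`) because it replaces the batch norm
  // instruction itself.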
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  int64 instruction_count_before = computation_->instruction_count();

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon =
      add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);

  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));

  // X - E[X].
  auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
                                       operand, mean_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                               operand_minus_mean, rsqrt_var_add_epsilon);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                                      normalized, scale_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);

  int64 instruction_count_after = computation_->instruction_count();
  CHECK_EQ(instruction_count_after,
           instruction_count_before + added_instructions.size());
  if (batch_norm->has_sharding()) {
    const HloSharding& sharding = batch_norm->sharding();
    optional<int64> unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    shifted_normalized->set_sharding(sharding);
  }
  TF_CHECK_OK(
      ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormGrad(
    HloInstruction* batch_norm) {
  // Use the following formulas to calculate gradients:
  // scale_grad =
  //   sum(output_grad * (activation - mean(activation))) * rsqrt(var + epsilon)
  //
  // offset_grad =
  //   sum(output_grad)
  //
  // activation_grad =
  //   1/N * scale * rsqrt(var + epsilon) *
  //   (N * output_grad - sum(output_grad) - (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon))
  if (!rewrite_grad_op_) {
    return Status::OK();
  }
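  // As in the other handlers, `add` copies the batch norm's metadata onto
  // each new instruction and records it for sharding assignment.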
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  int64 instruction_count_before = computation_->instruction_count();

  HloInstruction* activation = batch_norm->mutable_operand(0);
  const Shape activation_shape = activation->shape();
  PrimitiveType ptype = activation_shape.element_type();
  HloInstruction* scale = batch_norm->mutable_operand(1);
  const Shape feature_shape = scale->shape();
  HloInstruction* mean = batch_norm->mutable_operand(2);
  HloInstruction* variance = batch_norm->mutable_operand(3);
  HloInstruction* grad_output = batch_norm->mutable_operand(4);

  int64 feature_index = batch_norm->feature_index();

  auto elements_per_feature =
      add(DynamicElementCountPerFeature(activation, feature_index, add));

  auto zero_literal = LiteralUtil::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon_scalar =
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
  auto epsilon_activation = add(
      HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {}));
  auto epsilon_feature =
      add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {}));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < activation_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, scale, {feature_index}));
  auto variance_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, variance, {feature_index}));

  // E[X].
  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index}));

  // rsqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon_broadcasted =
      add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd,
                           variance_broadcasted, epsilon_activation),
                add));

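  // The same quantity in the feature shape, used below for Grad[scale].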
  auto rsqrt_var_add_epsilon = add(Rsqrt(
      add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature),
      add));

  // X - E[X].
  auto activation_minus_mean = add_binary(
      activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted);

  // Grad[Y] * (X - E[X]).
  auto grad_output_times_activation_minus_mean =
      add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
                 activation_minus_mean);

  HloComputation* add_reduce_computation =
      GetOrCreateScalarAddComputation(ptype);

  // sum(Grad[Y] * (X - E[X])).
  auto sum_grad_output_times_activation_minus_mean =
      add(HloInstruction::CreateReduce(
          feature_shape, grad_output_times_activation_minus_mean, zero,
          dimensions_without_feature, add_reduce_computation));

  // Grad[beta] = Sum(Grad[Y]).
  auto grad_beta = add(HloInstruction::CreateReduce(
      feature_shape, grad_output, zero, dimensions_without_feature,
      add_reduce_computation));

  // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
  auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply,
                               sum_grad_output_times_activation_minus_mean,
                               rsqrt_var_add_epsilon);

  // I2 = Sum(Grad[Y])
  auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
                                                {feature_index}));

  // I3 = Sum(Grad[Y] * (X - E[X]))
  auto i3 = add(HloInstruction::CreateBroadcast(
      activation_shape, sum_grad_output_times_activation_minus_mean,
      {feature_index}));

  // I4 = (X - E[X]) * I3
  auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3,
                       activation_minus_mean);

  // I5 = I4 / (Var[X] + epsilon)
  auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4,
                       add_binary(activation_shape, HloOpcode::kAdd,
                                  variance_broadcasted, epsilon_activation));

  // scale * rsqrt[Var[X] + epsilon] * 1/N
  auto scale_times_rsqrt_var_add_epsilon =
      add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted,
                 rsqrt_var_add_epsilon_broadcasted);

  scale_times_rsqrt_var_add_epsilon =
      add(Mean(elements_per_feature, scale_times_rsqrt_var_add_epsilon, add));

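  // I1 = N * Grad[Y]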
  auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
                       add(HloInstruction::CreateBroadcast(
                           activation_shape, elements_per_feature, {})));

  // I6 = I1 - I2 - I5
  auto i6 = add_binary(
      activation_shape, HloOpcode::kSubtract,
      add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5);

  // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6.
  auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply,
                                    scale_times_rsqrt_var_add_epsilon, i6);
  auto tuple =
      HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
  if (batch_norm->has_sharding()) {
    const HloSharding& sharding = batch_norm->sharding();
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    HloSharding activation_sharding =
        sharding.GetAsShapeTree(batch_norm->shape()).element({0});
    auto unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
        inst->set_sharding(activation_sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    tuple->set_sharding(sharding);
  }

  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));

  return Status::OK();
}

StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
                                      rewrite_inference_op_,
                                      rewrite_grad_op_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla