/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/batchnorm_expander.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

namespace {

using absl::optional;

// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
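// Training ops are expanded into explicit batch mean/variance computations
// that produce a (normalized, batch_mean, batch_var) tuple, inference ops into
// a normalization against the supplied mean and variance, and grad ops into
// the gradient formulas documented in HandleBatchNormGrad.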
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
 public:
  // Default visitor action is to do nothing and return OK.
  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;

  Status HandleBatchNormInference(HloInstruction* batch_norm) override;

  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

  // Runs the visitor on a computation.
  static bool Run(HloComputation* computation, bool rewrite_training_op,
                  bool rewrite_inference_op, bool rewrite_grad_op);

  // Returns whether any batch norm ops were rewritten.
  bool changed() const { return changed_; }

  ~BatchNormExpanderVisitor() override = default;

 private:
  explicit BatchNormExpanderVisitor(HloComputation* computation,
                                    bool rewrite_training_op,
                                    bool rewrite_inference_op,
                                    bool rewrite_grad_op)
      : computation_(computation),
        rewrite_training_op_(rewrite_training_op),
        rewrite_inference_op_(rewrite_inference_op),
        rewrite_grad_op_(rewrite_grad_op) {}

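  // Builds a scalar add computation (lhs + rhs) to serve as the reduction
  // function for the kReduce instructions created below, and registers it as
  // an embedded computation on the parent module.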
  HloComputation* GetOrCreateScalarAddComputation(
      PrimitiveType primitive_type) {
    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(primitive_type, {});
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
  }

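  // Returns an HLO computing 1 / sqrt(operand), element-wise. The
  // add_instruction callback is not used by this helper.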
  std::unique_ptr<HloInstruction> Rsqrt(
      HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    return HloInstruction::CreateUnary(operand->shape(), HloOpcode::kRsqrt,
                                       operand);
  }

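  // Returns operand / element_count, with element_count broadcast to
  // operand's shape; used to turn per-feature sums into per-feature means.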
  std::unique_ptr<HloInstruction> Mean(
      HloInstruction* element_count, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    auto broadcast = add_instruction(
        HloInstruction::CreateBroadcast(operand->shape(), element_count, {}));
    return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kDivide,
                                        operand, broadcast);
  }

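  // Computes the number of elements that contribute to each feature: the
  // product of all dimension sizes of `operand` except the feature dimension.
  // The count is built from GetDimensionSize so dynamic dimensions are
  // handled, then converted to the operand's element type.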
  std::unique_ptr<HloInstruction> DynamicElementCountPerFeature(
      HloInstruction* operand, int64 feature_index,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    auto elements_per_feature_u32 = add_instruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)));

    for (int64 i = 0; i < operand->shape().rank(); ++i) {
      if (i == feature_index) {
        continue;
      }
      auto dynamic_dimension_size =
          add_instruction(HloInstruction::CreateGetDimensionSize(
              ShapeUtil::MakeShape(U32, {}), operand, i));
      elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply,
          dynamic_dimension_size, elements_per_feature_u32));
    }

    return HloInstruction::CreateConvert(
        ShapeUtil::MakeShape(operand->shape().element_type(), {}),
        elements_per_feature_u32);
  }

  // Replaces the existing HLO instruction old_instruction with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction) {
    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        old_instruction, std::move(new_instruction)));
    changed_ = true;
    return Status::OK();
  }

  // Replaces the existing HLO instruction old_instruction with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction) {
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(old_instruction, new_instruction));
    changed_ = true;
    return Status::OK();
  }
  // Current HloComputation instance the BatchNormExpander is
  // traversing.
  HloComputation* computation_;

  bool rewrite_training_op_;
  bool rewrite_inference_op_;
  bool rewrite_grad_op_;

  // Whether rewrite has occurred.
  bool changed_ = false;
};

}  // namespace

bool BatchNormExpanderVisitor::Run(HloComputation* computation,
                                   bool rewrite_training_op,
                                   bool rewrite_inference_op,
                                   bool rewrite_grad_op) {
  BatchNormExpanderVisitor visitor(
      computation,
      /*rewrite_training_op=*/rewrite_training_op,
      /*rewrite_inference_op=*/rewrite_inference_op,
      /*rewrite_grad_op=*/rewrite_grad_op);
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed_;
}

Status BatchNormExpanderVisitor::HandleBatchNormTraining(
    HloInstruction* batch_norm) {
  if (!rewrite_training_op_) {
    return Status::OK();
  }

  std::vector<HloInstruction*> added_instructions;
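  // Helpers that add an instruction to the computation, copy the batch norm's
  // metadata onto it, and record it in added_instructions so sharding can be
  // assigned to every newly created instruction below.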
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  int64 instruction_count_before = computation_->instruction_count();

  // Expand batch norm training into smaller HLO ops.
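  // The expansion computes
  //   Y      = scale * (X - E[X]) * rsqrt(Var[X] + epsilon) + offset
  //   E[X]   = Sum[X] / N
  //   Var[X] = E[X^2] - E[X]^2
  // and replaces the batch norm with the tuple (Y, E[X], Var[X]).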
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  PrimitiveType ptype = operand_shape.element_type();
  int64 feature_index = batch_norm->feature_index();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  const Shape feature_shape = scale->shape();

  auto zero_literal = LiteralUtil::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon = add(HloInstruction::CreateBroadcast(
      operand_shape,
      add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto elements_per_feature =
      add(DynamicElementCountPerFeature(operand, feature_index, add));

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  HloComputation* add_reduce_computation =
      GetOrCreateScalarAddComputation(ptype);

  // X^2.
  auto operand_squared =
      add_binary(operand_shape, HloOpcode::kMultiply, operand, operand);
  // Sum[X].
  auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero,
                                              dimensions_without_feature,
                                              add_reduce_computation));

  // Sum[X^2].
  auto squared_sum = add(HloInstruction::CreateReduce(
      feature_shape, operand_squared, zero, dimensions_without_feature,
      add_reduce_computation));

  // E[X].
  auto mean = add(Mean(elements_per_feature, sum, add));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  // E[X^2].
  auto square_mean = add(Mean(elements_per_feature, squared_sum, add));

  // E^2[X].
  auto mean_square =
      add_binary(feature_shape, HloOpcode::kMultiply, mean, mean);

  // Var[X].
  auto var =
      add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square);

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon =
      add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);

  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));

  // X - E[X].
  auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
                                       operand, mean_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                               operand_minus_mean, rsqrt_var_add_epsilon);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                                      normalized, scale_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd,
                                       scaled_normalized, offset_broadcasted);

  auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var});

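  // If the original batch norm carried sharding, propagate it: instructions
  // with the operand's shape take the sharding of the tuple's first element,
  // while everything else is assigned to the op's unique device if there is
  // one, and replicated otherwise.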
  if (batch_norm->has_sharding()) {
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    const HloSharding& sharding = batch_norm->sharding();
    HloSharding operand_sharding =
        sharding.GetAsShapeTree(batch_norm->shape()).element({0});
    optional<int64> unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(operand_sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    tuple->set_sharding(sharding);
  }
  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormInference(
    HloInstruction* batch_norm) {
  if (!rewrite_inference_op_) {
    return Status::OK();
  }
  // Expand batch norm inference into smaller HLO ops.
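  // Inference normalizes against the supplied running mean and variance
  // instead of batch statistics:
  //   Y = scale * (X - mean) * rsqrt(var + epsilon) + offset.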
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  int64 feature_index = batch_norm->feature_index();
  PrimitiveType ptype = operand_shape.element_type();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  HloInstruction* mean = batch_norm->mutable_operand(3);
  HloInstruction* var = batch_norm->mutable_operand(4);
  const Shape feature_shape = scale->shape();

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
      operand_shape,
      computation_->AddInstruction(
          HloInstruction::CreateConstant(std::move(epsilon_literal))),
      {}));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  int64 instruction_count_before = computation_->instruction_count();

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon =
      add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);

  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));

  // X - E[X].
  auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
                                       operand, mean_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                               operand_minus_mean, rsqrt_var_add_epsilon);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
                                      normalized, scale_broadcasted);

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);

  int64 instruction_count_after = computation_->instruction_count();
  CHECK_EQ(instruction_count_after,
           instruction_count_before + added_instructions.size());
  if (batch_norm->has_sharding()) {
    const HloSharding& sharding = batch_norm->sharding();
    optional<int64> unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    shifted_normalized->set_sharding(sharding);
  }
  TF_CHECK_OK(
      ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormGrad(
    HloInstruction* batch_norm) {
  // Use the following formulas to calculate gradients:
  // scale_grad =
  //   sum(output_grad * (activation - mean(activation))) * rsqrt(var + epsilon)
  //
  // offset_grad =
  //   sum(output_grad)
  //
  // activation_grad =
  //   1/N * scale * rsqrt(var + epsilon) *
  //   (N * output_grad - sum(output_grad) - (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon))
  if (!rewrite_grad_op_) {
    return Status::OK();
  }
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_inst->set_metadata(batch_norm->metadata());
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
                        HloInstruction* a, HloInstruction* b) {
    return add(HloInstruction::CreateBinary(shape, opcode, a, b));
  };
  int64 instruction_count_before = computation_->instruction_count();

  HloInstruction* activation = batch_norm->mutable_operand(0);
  const Shape activation_shape = activation->shape();
  PrimitiveType ptype = activation_shape.element_type();
  HloInstruction* scale = batch_norm->mutable_operand(1);
  const Shape feature_shape = scale->shape();
  HloInstruction* mean = batch_norm->mutable_operand(2);
  HloInstruction* variance = batch_norm->mutable_operand(3);
  HloInstruction* grad_output = batch_norm->mutable_operand(4);

  int64 feature_index = batch_norm->feature_index();

  auto elements_per_feature =
      add(DynamicElementCountPerFeature(activation, feature_index, add));

  auto zero_literal = LiteralUtil::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
  auto epsilon_scalar =
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
  auto epsilon_activation = add(
      HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {}));
  auto epsilon_feature =
      add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {}));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < activation_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, scale, {feature_index}));
  auto variance_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, variance, {feature_index}));

  // E[X].
  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index}));

  // rsqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon_broadcasted =
      add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd,
                           variance_broadcasted, epsilon_activation),
                add));

  auto rsqrt_var_add_epsilon = add(Rsqrt(
      add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature),
      add));

  // X - E[X].
  auto activation_minus_mean = add_binary(
      activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted);

  // Grad[Y] * (X - E[X]).
  auto grad_output_times_activation_minus_mean =
      add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
                 activation_minus_mean);

  HloComputation* add_reduce_computation =
      GetOrCreateScalarAddComputation(ptype);

  // sum(Grad[Y] * (X - E[X])).
  auto sum_grad_output_times_activation_minus_mean =
      add(HloInstruction::CreateReduce(
          feature_shape, grad_output_times_activation_minus_mean, zero,
          dimensions_without_feature, add_reduce_computation));

  // Grad[beta] = Sum(Grad[Y]).
  auto grad_beta = add(HloInstruction::CreateReduce(
      feature_shape, grad_output, zero, dimensions_without_feature,
      add_reduce_computation));

  // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
  auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply,
                               sum_grad_output_times_activation_minus_mean,
                               rsqrt_var_add_epsilon);

  // I2 = Sum(Grad[Y])
  auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
                                                {feature_index}));

  // I3 = Sum(Grad[Y] * (X - E[X]))
  auto i3 = add(HloInstruction::CreateBroadcast(
      activation_shape, sum_grad_output_times_activation_minus_mean,
      {feature_index}));

  // I4 = (X - E[X]) * I3
  auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3,
                       activation_minus_mean);

  // I5 = I4 / (Var[X] + epsilon)
  auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4,
                       add_binary(activation_shape, HloOpcode::kAdd,
                                  variance_broadcasted, epsilon_activation));

  // scale * rsqrt[Var[X] + epsilon] * 1/N
  auto scale_times_rsqrt_var_add_epsilon =
      add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted,
                 rsqrt_var_add_epsilon_broadcasted);

  scale_times_rsqrt_var_add_epsilon =
      add(Mean(elements_per_feature, scale_times_rsqrt_var_add_epsilon, add));

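  // I1 = N * Grad[Y]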
  auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
                       add(HloInstruction::CreateBroadcast(
                           activation_shape, elements_per_feature, {})));

  // I6 = I1 - I2 - I5
  auto i6 = add_binary(
      activation_shape, HloOpcode::kSubtract,
      add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5);

  // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6.
  auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply,
                                    scale_times_rsqrt_var_add_epsilon, i6);
  auto tuple =
      HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
  if (batch_norm->has_sharding()) {
    const HloSharding& sharding = batch_norm->sharding();
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    HloSharding activation_sharding =
        sharding.GetAsShapeTree(batch_norm->shape()).element({0});
    auto unique_device = batch_norm->sharding_unique_device();
    HloSharding default_sharding =
        unique_device.has_value()
            ? HloSharding::AssignDevice(unique_device.value())
            : HloSharding::Replicate();
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
        inst->set_sharding(activation_sharding);
      } else {
        inst->set_sharding(default_sharding);
      }
    }
    tuple->set_sharding(sharding);
  }

  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));

  return Status::OK();
}

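// Expands batch norm ops in every non-fusion computation of the module and
// returns true if any instruction was rewritten.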
StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
                                      rewrite_inference_op_,
                                      rewrite_grad_op_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla