/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/batchnorm_expander.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
 public:
  // Default visitor action is to do nothing and return OK.
  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  // Rewrites a kBatchNormTraining instruction into explicit reduce/
  // broadcast/elementwise HLO ops (gated by rewrite_training_op_).
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;

  // Rewrites a kBatchNormInference instruction into elementwise HLO ops
  // using the supplied mean/variance (gated by rewrite_inference_op_).
  Status HandleBatchNormInference(HloInstruction* batch_norm) override;

  // Rewrites a kBatchNormGrad instruction into explicit gradient math
  // (gated by rewrite_grad_op_).
  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

  // Runs the visitor on a computation. Returns whether any batch norm ops
  // were rewritten.
  static bool Run(HloComputation* computation, bool rewrite_training_op,
                  bool rewrite_inference_op, bool rewrite_grad_op,
                  bool use_fusion);

  // Returns whether any batch norm ops were rewritten.
  const bool changed() const { return changed_; }

  ~BatchNormExpanderVisitor() override = default;

 private:
  explicit BatchNormExpanderVisitor(HloComputation* computation,
                                    bool rewrite_training_op,
                                    bool rewrite_inference_op,
                                    bool rewrite_grad_op, bool use_fusion)
      : computation_(computation),
        rewrite_training_op_(rewrite_training_op),
        rewrite_inference_op_(rewrite_inference_op),
        rewrite_grad_op_(rewrite_grad_op),
        use_fusion_(use_fusion) {}

  // Builds (and registers on the module) a tiny embedded computation that
  // applies `opcode` to two scalar parameters of `primitive_type`. Used as
  // the reduction function for the kReduce instructions created below.
  HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type,
                                             HloOpcode opcode) {
    HloComputation::Builder b("scalar_computation");
    auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
        1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs"));
    auto scalar_op = b.AddInstruction(
        HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
                                     opcode, scalar_lhs, scalar_rhs));
    return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
  }

  // Current HloComputation instance the BatchNormExpander is
  // traversing.
  HloComputation* computation_;

  bool rewrite_training_op_;
  bool rewrite_inference_op_;
  bool rewrite_grad_op_;
  bool use_fusion_;

  // Whether rewrite has occurred.
  bool changed_ = false;

  // Replaces the existing HLO instruction old_instruction, with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction) {
    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        old_instruction, std::move(new_instruction)));
    changed_ = true;
    return Status::OK();
  }

  // Replaces the existing HLO instruction old_instruction, with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction) {
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(old_instruction, new_instruction));
    changed_ = true;
    return Status::OK();
  }
};

bool BatchNormExpanderVisitor::Run(HloComputation* computation,
                                   bool rewrite_training_op,
                                   bool rewrite_inference_op,
                                   bool rewrite_grad_op, bool use_fusion) {
  BatchNormExpanderVisitor visitor(
      computation,
      /*rewrite_training_op=*/rewrite_training_op,
      /*rewrite_inference_op=*/rewrite_inference_op,
      /*rewrite_grad_op=*/rewrite_grad_op,
      /*use_fusion=*/use_fusion);
  // A failure during traversal is fatal: the handlers themselves only
  // return non-OK through TF_RETURN_IF_ERROR/TF_ASSIGN_OR_RETURN.
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed_;
}

Status BatchNormExpanderVisitor::HandleBatchNormTraining(
    HloInstruction* batch_norm) {
  if (!rewrite_training_op_) {
    return Status::OK();
  }

  // Every instruction added through `add` is recorded so that, when the
  // original batch-norm op carries sharding, each added instruction can be
  // assigned a sharding below.
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  int64 instruction_count_before = computation_->instruction_count();

  // Expand batch norm training into smaller HLO ops.
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  PrimitiveType ptype = operand_shape.element_type();
  int64 feature_index = batch_norm->feature_index();
  const int64 feature_count = operand_shape.dimensions(feature_index);
  const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape);
  // N: the number of elements reduced into each per-feature statistic.
  auto elements_per_feature_literal =
      Literal::CreateR0<float>(size_in_elements / feature_count);
  TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
                      elements_per_feature_literal->Convert(ptype));
  auto elements_per_feature = add(
      HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  const Shape feature_shape = scale->shape();

  // Scalar constants are created in float and converted to the operand's
  // element type so the expansion works for any floating-point ptype.
  auto zero_literal = Literal::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
  auto epsilon =
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
  // All dimensions except the feature dimension; these are the dimensions
  // reduced over when computing per-feature statistics.
  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  HloComputation* add_reduce_computation =
      GetScalarBinaryComputation(ptype, HloOpcode::kAdd);

  // X^2.
  auto operand_squared = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kMultiply, operand, operand));
  // Sum[X].
  auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero,
                                              dimensions_without_feature,
                                              add_reduce_computation));

  // Sum[X^2].
  auto squared_sum = add(HloInstruction::CreateReduce(
      feature_shape, operand_squared, zero, dimensions_without_feature,
      add_reduce_computation));

  // Fuse two parallel reduces together to improve performance.
  // NOTE(review): fusion is skipped when the op carries sharding,
  // presumably because the sharding-assignment loop below needs the
  // individual (unfused) instructions — confirm.
  if (use_fusion_ && !batch_norm->has_sharding()) {
    auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum}));

    auto fused = computation_->CreateFusionInstruction(
        {tuple, sum, squared_sum, operand_squared},
        HloInstruction::FusionKind::kInput);

    sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));

    squared_sum =
        add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
  }

  // E[X].
  auto mean = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kDivide, sum, elements_per_feature));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  // E[X^2].
  auto square_mean = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature));

  // E^2[X].
  auto mean_square = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kMultiply, mean, mean));

  // Var[X], via Var[X] = E[X^2] - E^2[X].
  auto var = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kSubtract, square_mean, mean_square));

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

  auto neg_half_literal = Literal::CreateR0(-0.5f);
  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
  auto neg_half =
      add(HloInstruction::CreateConstant(std::move(neg_half_literal)));

  // 1 / Sqrt[Var[X] + epsilon], expressed as (Var[X] + epsilon)^-0.5.
  auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));

  // X - E[X].
  auto operand_minus_mean = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add(
      HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
                                   operand_minus_mean, rsqrt_var_add_epsilon));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted));

  // Created directly (not via `add`) because it becomes the replacement for
  // batch_norm itself; its sharding is set from batch_norm below.
  auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var});

  if (batch_norm->has_sharding()) {
    // Sanity check: added_instructions captured exactly the instructions
    // added to the computation by this handler.
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    // Instructions shaped like the operand inherit the sharding of tuple
    // element {0} (the normalized output); per-feature/scalar values are
    // replicated.
    HloSharding operand_sharding =
        batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(operand_sharding);
      } else {
        inst->set_sharding(HloSharding::Replicate());
      }
    }
    tuple->set_sharding(batch_norm->sharding());
  }
  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormInference(
    HloInstruction* batch_norm) {
  if (!rewrite_inference_op_) {
    return Status::OK();
  }
  // Expand batch norm inference into smaller HLO ops.
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  int64 feature_index = batch_norm->feature_index();
  PrimitiveType ptype = operand_shape.element_type();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  HloInstruction* mean = batch_norm->mutable_operand(3);
  HloInstruction* var = batch_norm->mutable_operand(4);
  const Shape feature_shape = scale->shape();

  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
  // NOTE(review): added via computation_->AddInstruction directly, before
  // the `add` lambda exists, so epsilon is not in added_instructions and
  // keeps default sharding.
  auto epsilon = computation_->AddInstruction(
      HloInstruction::CreateConstant(std::move(epsilon_literal)));

  // All dimensions except the feature dimension (computed for parity with
  // the other handlers; no reduce is needed at inference time).
  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  // Track added instructions so sharding can be assigned to each below.
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  int64 instruction_count_before = computation_->instruction_count();

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

  auto neg_half_literal = Literal::CreateR0(-0.5f);
  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
  auto neg_half =
      add(HloInstruction::CreateConstant(std::move(neg_half_literal)));

  // 1 / Sqrt[Var[X] + epsilon], expressed as (Var[X] + epsilon)^-0.5.
  auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));

  // X - E[X].
  auto operand_minus_mean = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add(
      HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
                                   operand_minus_mean, rsqrt_var_add_epsilon));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  // Created directly (not via `add`) because it becomes the replacement for
  // batch_norm itself; its sharding is set from batch_norm below.
  auto shifted_normalized = HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);

  // Sanity check: added_instructions captured exactly the instructions
  // added to the computation since instruction_count_before was read.
  int64 instruction_count_after = computation_->instruction_count();
  CHECK_EQ(instruction_count_after,
           instruction_count_before + added_instructions.size());
  if (batch_norm->has_sharding()) {
    // Operand-shaped instructions inherit batch_norm's sharding; feature-
    // shaped/scalar values are replicated.
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(batch_norm->sharding());
      } else {
        inst->set_sharding(HloSharding::Replicate());
      }
    }
    shifted_normalized->set_sharding(batch_norm->sharding());
  }
  TF_CHECK_OK(
      ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormGrad(
    HloInstruction* batch_norm) {
  // Use the following formulas to calculate gradients:
  // scale_grad =
  //   sum(output_grad * (activation - mean(activation))) * rsqrt(var + epsilon)
  //
  // offset_grad =
  //   sum(output_grad)
  //
  // activation_grad =
  //   1/N * scale * rsqrt(var + epsilon) *
  //   (N * output_grad - sum(output_grad) - (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon))
  if (!rewrite_grad_op_) {
    return Status::OK();
  }
  // Track added instructions so sharding can be assigned to each below.
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  int64 instruction_count_before = computation_->instruction_count();

  HloInstruction* activation = batch_norm->mutable_operand(0);
  const Shape activation_shape = activation->shape();
  PrimitiveType ptype = activation_shape.element_type();
  HloInstruction* scale = batch_norm->mutable_operand(1);
  const Shape feature_shape = scale->shape();
  HloInstruction* mean = batch_norm->mutable_operand(2);
  HloInstruction* variance = batch_norm->mutable_operand(3);
  HloInstruction* grad_output = batch_norm->mutable_operand(4);

  int64 feature_index = batch_norm->feature_index();

  // N: the number of elements reduced into each per-feature statistic.
  const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape);
  const int64 feature_count = activation_shape.dimensions(feature_index);
  auto elements_per_feature_literal =
      Literal::CreateR0<float>(size_in_elements / feature_count);
  TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
                      elements_per_feature_literal->Convert(ptype));
  auto elements_per_feature = add(
      HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));

  // Scalar constants, created in float and converted to the activation's
  // element type.
  auto zero_literal = Literal::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto neg_half_literal = Literal::CreateR0(-0.5f);
  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
  auto neg_half =
      add(HloInstruction::CreateConstant(std::move(neg_half_literal)));

  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
  auto epsilon =
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));

  // All dimensions except the feature dimension; reduced over when
  // computing per-feature sums.
  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < ShapeUtil::Rank(activation_shape); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, scale, {feature_index}));
  auto variance_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, variance, {feature_index}));

  // E[X].
  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index}));

  // rsqrt[Var[X] + epsilon], expressed as (Var[X] + epsilon)^-0.5;
  // computed in the activation shape for the activation gradient...
  auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kPower,
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
                                       variance_broadcasted, epsilon)),
      neg_half));

  // ...and in the feature shape for the scale gradient.
  auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kPower,
      add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance,
                                       epsilon)),
      neg_half));

  // X - E[X].
  auto activation_minus_mean = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted));

  // Grad[Y] * (X - E[X]).
  auto grad_output_times_activiation_minus_mean =
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
                                       grad_output, activation_minus_mean));

  HloComputation* add_reduce_computation =
      GetScalarBinaryComputation(ptype, HloOpcode::kAdd);

  // sum(Grad[Y] * (X - E[X])).
  auto sum_grad_output_times_activiation_minus_mean =
      add(HloInstruction::CreateReduce(
          feature_shape, grad_output_times_activiation_minus_mean, zero,
          dimensions_without_feature, add_reduce_computation));

  // Grad[beta] = Sum(Grad[Y]).
  auto grad_beta = add(HloInstruction::CreateReduce(
      feature_shape, grad_output, zero, dimensions_without_feature,
      add_reduce_computation));

  // Fuse the two parallel reduces together to improve performance (skipped
  // under sharding, as in HandleBatchNormTraining).
  if (use_fusion_ && !batch_norm->has_sharding()) {
    auto tuple = add(HloInstruction::CreateTuple(
        {sum_grad_output_times_activiation_minus_mean, grad_beta}));

    auto fused = computation_->CreateFusionInstruction(
        {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta},
        HloInstruction::FusionKind::kInput);

    sum_grad_output_times_activiation_minus_mean =
        add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));

    grad_beta =
        add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
  }

  // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
  auto grad_scale = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kMultiply,
      sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon));

  // I2 = Sum(Grad[Y])
  auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
                                                {feature_index}));

  // I3 = Sum(Grad[Y] * (X - E[X]))
  auto i3 = add(HloInstruction::CreateBroadcast(
      activation_shape, sum_grad_output_times_activiation_minus_mean,
      {feature_index}));

  // I4 = (X - E[X]) * I3
  auto i4 = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean));

  // I5 = I4 / (Var[X] + epsilon)
  auto i5 = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kDivide, i4,
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
                                       variance_broadcasted, epsilon))));

  // scale * rsqrt[Var[X] + epsilon] * 1/N
  auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kMultiply, scale_broadcasted,
      rsqrt_var_add_epsilon_broadcasted));

  scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon,
      elements_per_feature));

  // I1 = N * Grad[Y].
  auto i1 =
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
                                       grad_output, elements_per_feature));

  // I6 = I1 - I2 - I5
  auto i6 = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kSubtract,
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract,
                                       i1, i2)),
      i5));

  // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6.
  auto grad_activation =
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
                                       scale_times_rsqrt_var_add_epsilon, i6));
  // Created directly (not via `add`) because it becomes the replacement for
  // batch_norm itself; its sharding is set from batch_norm below.
  auto tuple =
      HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
  if (batch_norm->has_sharding()) {
    // Sanity check: added_instructions captured exactly the instructions
    // added to the computation by this handler.
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    // Activation-shaped instructions inherit the sharding of tuple element
    // {0} (the activation gradient); per-feature/scalar values are
    // replicated.
    HloSharding activation_sharding =
        batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
        inst->set_sharding(activation_sharding);
      } else {
        inst->set_sharding(HloSharding::Replicate());
      }
    }
    tuple->set_sharding(batch_norm->sharding());
  }

  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));

  return Status::OK();
}

StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString());
  bool changed = false;
  // Fusion computations are skipped: only entry/control-flow computations
  // are rewritten.
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
                                      rewrite_inference_op_, rewrite_grad_op_,
                                      use_fusion_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla