/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/batchnorm_expander.h"

#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

class BatchNormExpanderTest : public HloTestBase {
 protected:
  // BatchNorm should have a dynamic sized divider for mean operations.
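  // Counts kGetDimensionSize instructions across all computations in the
  // module.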
  int64 CountGetDimensionSize(const HloModule& module) {
    int64 count = 0;
    for (HloComputation* comp : module.computations()) {
      for (HloInstruction* inst : comp->instructions()) {
        if (inst->opcode() == HloOpcode::kGetDimensionSize) {
          count++;
        }
      }
    }
    return count;
  }
};

// Test that we expand BatchNormTraining.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  Shape scale_shape = ShapeUtil::MakeShape(F32, {2});
  Shape offset_shape = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "activation"));

  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scale_shape, "scale"));

  HloInstruction* param2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, offset_shape, "offset"));

  builder.AddInstruction(HloInstruction::CreateBatchNormTraining(
      ShapeUtil::MakeTupleShape({input_shape, scale_shape, offset_shape}),
      param0, param1, param2,
      /*epsilon=*/0.001, /*feature_index=*/3));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining);
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
  root = computation->root_instruction();
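  // The expansion computes the mean divisor dynamically: expect one
  // kGetDimensionSize per non-feature dimension of the 4D input.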
  EXPECT_EQ(CountGetDimensionSize(*module), 3);
  // Make sure this operation is expanded.
  EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
}

// Test that we expand BatchNormGrad.
TEST_F(BatchNormExpanderTest, BatchNormGrad) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  Shape scale_shape = ShapeUtil::MakeShape(F32, {2});
  Shape mean_shape = ShapeUtil::MakeShape(F32, {2});
  Shape var_shape = ShapeUtil::MakeShape(F32, {2});
  Shape grad_output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "activation"));

  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scale_shape, "scale"));

  HloInstruction* param2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, mean_shape, "mean"));

  HloInstruction* param3 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, var_shape, "var"));

  HloInstruction* param4 = builder.AddInstruction(
      HloInstruction::CreateParameter(4, grad_output_shape, "grad_output"));

  builder.AddInstruction(HloInstruction::CreateBatchNormGrad(
      ShapeUtil::MakeTupleShape({input_shape, scale_shape, mean_shape}), param0,
      param1, param2, param3, param4,
      /*epsilon=*/0.001, /*feature_index=*/3));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad);
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_EQ(CountGetDimensionSize(*module), 3);
  // Make sure this operation is expanded.
  EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
}

TEST_F(BatchNormExpanderTest, BatchNormTrainingSharding) {
  const char* module_str = R"(
HloModule module
ENTRY entry {
  %param.0 = f32[8,4] parameter(0)
  %param.1 = f32[4] parameter(1)
  %param.2 = f32[4] parameter(2)
  ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4])
    batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2),
    epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(m.get()).ValueOrDie());

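  // Every instruction produced by the expansion (other than parameters)
  // should inherit the original op's sharding, i.e. maximal on device 1.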
  for (auto* instruction : m->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      continue;
    }
    auto device = instruction->sharding_unique_device();
    ASSERT_TRUE(device);
    EXPECT_EQ(*device, 1);
  }
}

}  // namespace
}  // namespace xla