/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/batchnorm_expander.h"

#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

class BatchNormExpanderTest : public HloTestBase {
 protected:
  // The expanded batch norm should compute its per-feature means with a
  // dynamically sized divisor, i.e. via kGetDimensionSize instructions.
  int64 CountGetDimensionSize(const HloModule& module) {
    int64 count = 0;
    for (HloComputation* comp : module.computations()) {
      for (HloInstruction* inst : comp->instructions()) {
        if (inst->opcode() == HloOpcode::kGetDimensionSize) {
          count++;
        }
      }
    }
    return count;
  }
};

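// Rough sketch of what BatchNormExpander is expected to emit for the training
// op (an assumption based on the standard batch-norm formulation, not an
// exact transcript of the expanded HLO):
//
//   elements_per_feature = product of get-dimension-size over all
//                          non-feature dimensions of the activation
//   mean     = reduce_sum(X) / elements_per_feature            (per feature)
//   variance = reduce_sum(X * X) / elements_per_feature - mean * mean
//   output   = scale * (X - mean) * rsqrt(variance + epsilon) + offset
//
// The tests below only check two coarse properties of this expansion: the
// number of kGetDimensionSize instructions and that the root becomes a
// kTuple of (output, mean, variance).
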
// Test that we expand BatchNormTraining.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  Shape scale_shape = ShapeUtil::MakeShape(F32, {2});
  Shape offset_shape = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "activation"));

  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scale_shape, "scale"));

  HloInstruction* param2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, offset_shape, "offset"));

  builder.AddInstruction(HloInstruction::CreateBatchNormTraining(
      ShapeUtil::MakeTupleShape({input_shape, scale_shape, offset_shape}),
      param0, param1, param2,
      /*epsilon=*/0.001, /*feature_index=*/3));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining);
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
  root = computation->root_instruction();
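  // The activation is rank 4 with feature_index=3, so there are three
  // non-feature dimensions; presumably the expander emits one
  // get-dimension-size per reduced dimension to build the divisor.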
  EXPECT_EQ(CountGetDimensionSize(*module), 3);
  // Make sure this operation is expanded.
  EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
}

// Test that we expand BatchNormGrad.
TEST_F(BatchNormExpanderTest, BatchNormGrad) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  Shape scale_shape = ShapeUtil::MakeShape(F32, {2});
  Shape mean_shape = ShapeUtil::MakeShape(F32, {2});
  Shape var_shape = ShapeUtil::MakeShape(F32, {2});
  Shape grad_output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "activation"));

  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scale_shape, "scale"));

  HloInstruction* param2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, mean_shape, "mean"));

  HloInstruction* param3 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, var_shape, "var"));

  HloInstruction* param4 = builder.AddInstruction(
      HloInstruction::CreateParameter(4, grad_output_shape, "grad_output"));

  builder.AddInstruction(HloInstruction::CreateBatchNormGrad(
      ShapeUtil::MakeTupleShape({input_shape, scale_shape, mean_shape}), param0,
      param1, param2, param3, param4,
      /*epsilon=*/0.001, /*feature_index=*/3));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad);
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
  root = computation->root_instruction();
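  // As in the training case, the gradient expansion divides by the number of
  // elements per feature, which it likely obtains from one get-dimension-size
  // per non-feature dimension of the rank-4 activation, hence a count of 3.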
  EXPECT_EQ(CountGetDimensionSize(*module), 3);
  // Make sure this operation is expanded.
  EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
}

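// Test that the expander propagates the sharding of the original
// batch-norm-training op to the instructions it creates.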
TEST_F(BatchNormExpanderTest, BatchNormTrainingSharding) {
  const char* module_str = R"(
HloModule module
ENTRY entry {
  %param.0 = f32[8,4] parameter(0)
  %param.1 = f32[4] parameter(1)
  %param.2 = f32[4] parameter(2)
  ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4])
    batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2),
    epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
  BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_inference_op=*/true,
                             /*rewrite_grad_op=*/true);
  ASSERT_TRUE(rewriter.Run(m.get()).ValueOrDie());

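  // Every instruction produced by the expansion should inherit the original
  // op's maximal sharding, so everything except the (unsharded) parameters
  // should report a unique device of 1.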
  for (auto* instruction : m->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      continue;
    }
    auto device = instruction->sharding_unique_device();
    ASSERT_TRUE(device);
    EXPECT_EQ(*device, 1);
  }
}

}  // namespace
}  // namespace xla