1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 #include "tensorflow/compiler/xla/tests/test_utils.h"
25 
26 namespace xla {
27 
28 class HloSubcomputationUnificationTest : public HloTestBase {
29  protected:
HloSubcomputationUnificationTest()30   HloSubcomputationUnificationTest() {}
31 
CreateR0S32IdentityComputation()32   std::unique_ptr<HloComputation> CreateR0S32IdentityComputation() {
33     auto builder = HloComputation::Builder("Identity");
34     builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x"));
35     return builder.Build();
36   }
37 
CreateR0S32AdditionComputation()38   std::unique_ptr<HloComputation> CreateR0S32AdditionComputation() {
39     auto builder = HloComputation::Builder("Addition");
40     auto x =
41         builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x"));
42     auto y =
43         builder.AddInstruction(HloInstruction::CreateParameter(1, r0s32_, "y"));
44     builder.AddInstruction(
45         HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
46     return builder.Build();
47   }
48 
CreateR1S32AdditionComputation(const Shape & shape)49   std::unique_ptr<HloComputation> CreateR1S32AdditionComputation(
50       const Shape& shape) {
51     auto builder = HloComputation::Builder("Addition");
52     auto x =
53         builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
54     auto y =
55         builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "y"));
56     builder.AddInstruction(
57         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));
58     return builder.Build();
59   }
60 
61   Shape r0s32_ = ShapeUtil::MakeShape(S32, {});
62   Shape r0f32_ = ShapeUtil::MakeShape(S32, {});
63   Shape r1s32_5_ = ShapeUtil::MakeShape(S32, {5});
64   Shape r1s32_3_ = ShapeUtil::MakeShape(S32, {3});
65 };
66 
// Two calls that target separate but identical identity computations are
// expected to share a single callee once the pass has run.
TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
  auto module = CreateNewVerifiedModule();

  // Register two distinct copies of the same identity computation.
  HloComputation* identity_a =
      module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
  HloComputation* identity_b =
      module->AddEmbeddedComputation(CreateR0S32IdentityComputation());

  // Entry computation: call(identity_a, 5) + call(identity_b, 5).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* five = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
  HloInstruction* call_a = entry_builder.AddInstruction(
      HloInstruction::CreateCall(r0s32_, {five}, identity_a));
  HloInstruction* call_b = entry_builder.AddInstruction(
      HloInstruction::CreateCall(r0s32_, {five}, identity_b));
  entry_builder.AddInstruction(
      HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, call_a, call_b));
  module->AddEntryComputation(entry_builder.Build());

  // Before the pass: entry + two callees, and the calls target different ones.
  EXPECT_EQ(3, module->computation_count());
  EXPECT_NE(call_a->to_apply(), call_b->to_apply());

  // The pass must report a change.
  HloSubcomputationUnification pass;
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());

  // After the pass: one callee remains and both calls point at it.
  EXPECT_EQ(2, module->computation_count());
  EXPECT_EQ(call_a->to_apply(), call_b->to_apply());
}
93 
// Two calls that target separate but identical scalar addition computations
// are expected to share a single callee once the pass has run.
TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
  auto module = CreateNewVerifiedModule();

  // Register two distinct copies of the same addition computation.
  HloComputation* adder_a =
      module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
  HloComputation* adder_b =
      module->AddEmbeddedComputation(CreateR0S32AdditionComputation());

  // Entry computation: call(adder_a, 5, 3) + call(adder_b, 5, 3).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* five = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
  HloInstruction* three = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));
  HloInstruction* call_a = entry_builder.AddInstruction(
      HloInstruction::CreateCall(r0s32_, {five, three}, adder_a));
  HloInstruction* call_b = entry_builder.AddInstruction(
      HloInstruction::CreateCall(r0s32_, {five, three}, adder_b));
  entry_builder.AddInstruction(
      HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, call_a, call_b));
  module->AddEntryComputation(entry_builder.Build());

  // Before the pass: entry + two callees, and the calls target different ones.
  EXPECT_EQ(3, module->computation_count());
  EXPECT_NE(call_a->to_apply(), call_b->to_apply());

  // The pass must report a change.
  HloSubcomputationUnification pass;
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());

  // After the pass: one callee remains and both calls point at it.
  EXPECT_EQ(2, module->computation_count());
  EXPECT_EQ(call_a->to_apply(), call_b->to_apply());
}
122 
123 // Do not unify subcomputations with different parameter shapes.
TEST_F(HloSubcomputationUnificationTest,DifferentParameterShapes)124 TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) {
125   auto module = CreateNewUnverifiedModule();
126   auto builder = HloComputation::Builder(TestName());
127 
128   auto callee1 =
129       module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_5_));
130   auto callee2 =
131       module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_3_));
132 
133   auto param1 = builder.AddInstruction(
134       HloInstruction::CreateParameter(0, r1s32_5_, "param1"));
135   auto param2 = builder.AddInstruction(
136       HloInstruction::CreateParameter(1, r1s32_5_, "param2"));
137   auto x = builder.AddInstruction(
138       HloInstruction::CreateCall(r1s32_5_, {param1, param1}, callee1));
139   auto y = builder.AddInstruction(
140       HloInstruction::CreateCall(r1s32_3_, {param2, param2}, callee2));
141   builder.AddInstruction(HloInstruction::CreateConcatenate(
142       ShapeUtil::MakeShape(S32, {8}), {x, y}, 0));
143 
144   module->AddEntryComputation(builder.Build());
145 
146   EXPECT_EQ(3, module->computation_count());
147   EXPECT_NE(x->to_apply(), y->to_apply());
148   EXPECT_FALSE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
149   EXPECT_EQ(3, module->computation_count());
150   EXPECT_NE(x->to_apply(), y->to_apply());
151 }
152 
153 // Regression test for b/31466798. Checks that entry_computation is still valid
154 // after unification.
TEST_F(HloSubcomputationUnificationTest,TwoIdenticalComputations)155 TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) {
156   auto module = CreateNewVerifiedModule();
157   for (int i = 0; i < 2; ++i) {
158     HloComputation::Builder builder("pow");
159     auto x =
160         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
161     auto y =
162         builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
163     builder.AddInstruction(
164         HloInstruction::CreateBinary(r0f32_, HloOpcode::kPower, x, y));
165     if (i == 0) {
166       module->AddEmbeddedComputation(builder.Build());
167     } else {
168       module->AddEntryComputation(builder.Build());
169     }
170   }
171 
172   EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
173   EXPECT_EQ(1, module->computation_count());
174   EXPECT_EQ(*module->computations().begin(), module->entry_computation());
175 }
176 
177 }  // namespace xla
178