1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/call_inliner.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35
36 namespace op = xla::testing::opcode_matchers;
37
38 namespace xla {
39 namespace {
40
41 // Tests for call inlining that are most tractable at the HLO level (vs
42 // ComputationBuilder API in call_test.cc).
43 using CallInlinerTest = HloTestBase;
44
// Checks that a control dependency between two instructions of a called
// computation is reproduced between the corresponding instructions in the
// caller once the call has been inlined.
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
  auto module = CreateNewVerifiedModule();

  // Callee: the scalar constants 24 ("zero") and 42 ("one") joined by a
  // control edge zero -> one. The constant added last (42) is the root.
  HloComputation::Builder inner_builder(TestName() + ".inner");
  HloInstruction* zero = inner_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(24.0f)));
  HloInstruction* one = inner_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  TF_ASSERT_OK(zero->AddControlDependencyTo(one));
  HloComputation* callee =
      module->AddEmbeddedComputation(inner_builder.Build());

  // Caller (entry): a single scalar-f32 call to the callee.
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder outer_builder(TestName() + ".outer");
  outer_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_f32, {}, callee));
  HloComputation* entry = module->AddEntryComputation(outer_builder.Build());

  CallInliner inliner;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, inliner.Run(module.get()));
  ASSERT_TRUE(changed);

  // The call is replaced by the callee's root, the 42 constant...
  HloInstruction* new_root = entry->root_instruction();
  EXPECT_THAT(new_root, op::Constant());
  EXPECT_EQ(new_root->literal().GetFirstElement<float>(), 42);

  // ...which keeps the 24 constant as its one and only control predecessor.
  ASSERT_EQ(1, new_root->control_predecessors().size());
  HloInstruction* predecessor = new_root->control_predecessors()[0];
  EXPECT_THAT(predecessor, op::Constant());
  EXPECT_EQ(predecessor->literal().GetFirstElement<float>(), 24);
}
77
78 // Tests for referential transparency (a function that calls a function that
79 // returns false should be identical to just returning false).
TEST_F(CallInlinerTest,CallsWithinWhileBodiesAreInlined)80 TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
81 const Shape pred = ShapeUtil::MakeShape(PRED, {});
82 auto module = CreateNewVerifiedModule();
83
84 // Create a lambda that calls a function that returns the false predicate.
85 // Note we also use this lambda twice by reference, just to make the test a
86 // little trickier.
87 HloComputation::Builder just_false(TestName() + ".false");
88 just_false.AddInstruction(
89 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
90 HloComputation* false_computation =
91 module->AddEmbeddedComputation(just_false.Build());
92
93 HloComputation::Builder call_false_builder(TestName() + ".call_false");
94 call_false_builder.AddInstruction(
95 HloInstruction::CreateParameter(0, pred, "param"));
96 call_false_builder.AddInstruction(
97 HloInstruction::CreateCall(pred, {}, false_computation));
98 HloComputation* call_false =
99 module->AddEmbeddedComputation(call_false_builder.Build());
100
101 HloComputation::Builder outer(TestName() + ".outer");
102 HloInstruction* init_value = outer.AddInstruction(
103 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
104 outer.AddInstruction(
105 HloInstruction::CreateWhile(pred, call_false, call_false, init_value));
106
107 auto computation = module->AddEntryComputation(outer.Build());
108
109 CallInliner call_inliner;
110 TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
111 ASSERT_TRUE(mutated);
112 EXPECT_THAT(
113 computation->root_instruction()->while_condition()->root_instruction(),
114 op::Constant());
115 EXPECT_THAT(computation->root_instruction()->while_body()->root_instruction(),
116 op::Constant());
117 }
118
119 // Check CallInliner::Inline, which inlines a specific call without running the
120 // whole pass.
TEST_F(CallInlinerTest,InlineWithoutRunningPass)121 TEST_F(CallInlinerTest, InlineWithoutRunningPass) {
122 const Shape pred = ShapeUtil::MakeShape(PRED, {});
123 auto module = CreateNewVerifiedModule();
124
125 HloComputation::Builder just_false(TestName() + ".false");
126 auto* true_constant = just_false.AddInstruction(
127 HloInstruction::CreateConstant(LiteralUtil::CreateR1<bool>({true})));
128 auto* false_constant = just_false.AddInstruction(
129 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
130 TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant));
131 HloComputation* false_computation =
132 module->AddEmbeddedComputation(just_false.Build());
133
134 HloComputation::Builder call_false_builder(TestName() + ".call_false");
135 HloInstruction* call = call_false_builder.AddInstruction(
136 HloInstruction::CreateCall(pred, {}, false_computation));
137 auto computation = module->AddEntryComputation(call_false_builder.Build());
138
139 TF_ASSERT_OK(CallInliner::Inline(call).status());
140 EXPECT_THAT(computation->root_instruction(), op::Constant());
141 EXPECT_THAT(computation->root_instruction()->control_successors(),
142 ElementsAre(op::Constant()));
143 }
144
// Checks that a call to a computation containing an outfeed is inlined by the
// pass, and that the outfeed then appears directly in the caller.
TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();

  // Callee: outfeed a scalar constant using a freshly created token. The
  // outfeed is added last and is therefore the callee's root.
  HloComputation::Builder outfeeder(TestName() + ".outfeeder");
  auto value = outfeeder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto token = outfeeder.AddInstruction(HloInstruction::CreateToken());
  outfeeder.AddInstruction(
      HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/""));

  auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build());

  // Caller (entry): a single call whose result shape matches the callee's root.
  HloComputation::Builder outer(TestName() + ".outer");
  outer.AddInstruction(HloInstruction::CreateCall(
      outfeed_computation->root_instruction()->shape(), /*operands=*/{},
      outfeed_computation));

  HloComputation* computation = module->AddEntryComputation(outer.Build());

  CallInliner call_inliner;
  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
  ASSERT_TRUE(mutated);
  // `mutated` alone does not prove this call was inlined. Also check that the
  // entry's root (previously the call) is now the inlined outfeed.
  EXPECT_THAT(computation->root_instruction(), op::Outfeed());
}
169
170 } // namespace
171 } // namespace xla
172