1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/call_inliner.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 
36 namespace op = xla::testing::opcode_matchers;
37 
38 namespace xla {
39 namespace {
40 
41 // Tests for call inlining that are most tractable at the HLO level (vs
42 // ComputationBuilder API in call_test.cc).
43 using CallInlinerTest = HloTestBase;
44 
// Checks that a control edge between two instructions of a called computation
// is still present between the corresponding instructions once the call has
// been inlined into the caller.
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
  // Callee: two scalar constants (24 and 42) with a control edge from the
  // first to the second. The expectations below rely on the 42 constant ending
  // up as the caller's root after inlining.
  HloComputation::Builder callee_builder(TestName() + ".inner");
  HloInstruction* first_constant = callee_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(24.0f)));
  HloInstruction* second_constant = callee_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  TF_ASSERT_OK(first_constant->AddControlDependencyTo(second_constant));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* callee =
      hlo_module->AddEmbeddedComputation(callee_builder.Build());

  // Caller: the entry computation consists of a single call to the callee.
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder caller_builder(TestName() + ".outer");
  caller_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_f32, {}, callee));
  HloComputation* entry =
      hlo_module->AddEntryComputation(caller_builder.Build());

  CallInliner inliner;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, inliner.Run(hlo_module.get()));
  ASSERT_TRUE(changed);

  // The call must have been replaced by the 42 constant...
  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, op::Constant());
  EXPECT_EQ(root->literal().GetFirstElement<float>(), 42);

  // ...which must still be control-dependent on exactly one instruction, the
  // 24 constant.
  ASSERT_EQ(1, root->control_predecessors().size());
  HloInstruction* predecessor = root->control_predecessors()[0];
  EXPECT_THAT(predecessor, op::Constant());
  EXPECT_EQ(predecessor->literal().GetFirstElement<float>(), 24);
}
77 
78 // Tests for referential transparency (a function that calls a function that
79 // returns false should be identical to just returning false).
TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
  // Calls that live inside the condition and body computations of a while
  // instruction must be inlined too, not only calls in the entry computation.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});

  // Leaf computation: produces the constant `false`.
  HloComputation::Builder leaf_builder(TestName() + ".false");
  leaf_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* leaf =
      hlo_module->AddEmbeddedComputation(leaf_builder.Build());

  // Wrapper computation: takes one predicate parameter and then calls the
  // leaf. The same wrapper serves as both while condition and while body, so
  // it is referenced twice.
  HloComputation::Builder wrapper_builder(TestName() + ".call_false");
  wrapper_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_pred, "param"));
  wrapper_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_pred, {}, leaf));
  HloComputation* wrapper =
      hlo_module->AddEmbeddedComputation(wrapper_builder.Build());

  // Entry computation: a while loop seeded with `false`.
  HloComputation::Builder entry_builder(TestName() + ".outer");
  HloInstruction* loop_init = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  entry_builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_pred, wrapper, wrapper, loop_init));
  HloComputation* entry =
      hlo_module->AddEntryComputation(entry_builder.Build());

  CallInliner inliner;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, inliner.Run(hlo_module.get()));
  ASSERT_TRUE(changed);

  // After the pass, the root of both the condition and the body must be the
  // inlined constant rather than a call.
  HloInstruction* while_op = entry->root_instruction();
  EXPECT_THAT(while_op->while_condition()->root_instruction(), op::Constant());
  EXPECT_THAT(while_op->while_body()->root_instruction(), op::Constant());
}
118 
119 // Check CallInliner::Inline, which inlines a specific call without running the
120 // whole pass.
TEST_F(CallInlinerTest, InlineWithoutRunningPass) {
  // Exercises the static CallInliner::Inline entry point on one call
  // instruction, without constructing or running the pass.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});

  // Callee: a rank-1 `{true}` constant followed by a scalar `false` constant,
  // with a control edge from the `false` constant to the `{true}` constant.
  HloComputation::Builder callee_builder(TestName() + ".false");
  HloInstruction* vector_true = callee_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<bool>({true})));
  HloInstruction* scalar_false = callee_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  TF_ASSERT_OK(scalar_false->AddControlDependencyTo(vector_true));
  HloComputation* callee =
      hlo_module->AddEmbeddedComputation(callee_builder.Build());

  // Entry computation: nothing but the call that is about to be inlined.
  HloComputation::Builder entry_builder(TestName() + ".call_false");
  HloInstruction* call_instruction = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_pred, {}, callee));
  HloComputation* entry =
      hlo_module->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK(CallInliner::Inline(call_instruction).status());

  // The call is gone: the root is now a constant that has exactly one control
  // successor, which is also a constant.
  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, op::Constant());
  EXPECT_THAT(root->control_successors(), ElementsAre(op::Constant()));
}
144 
// Checks that the pass reports a change when the called computation contains
// an outfeed instruction.
TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Callee: outfeeds the scalar constant 42 using a freshly created token.
  HloComputation::Builder callee_builder(TestName() + ".outfeeder");
  HloInstruction* payload = callee_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* outfeed_token =
      callee_builder.AddInstruction(HloInstruction::CreateToken());
  callee_builder.AddInstruction(HloInstruction::CreateOutfeed(
      scalar_f32, payload, outfeed_token, /*outfeed_config=*/""));
  HloComputation* callee =
      hlo_module->AddEmbeddedComputation(callee_builder.Build());

  // Entry computation: a single operand-less call whose shape is taken from
  // the callee's root instruction.
  const Shape& call_shape = callee->root_instruction()->shape();
  HloComputation::Builder entry_builder(TestName() + ".outer");
  entry_builder.AddInstruction(
      HloInstruction::CreateCall(call_shape, /*operands=*/{}, callee));
  hlo_module->AddEntryComputation(entry_builder.Build());

  CallInliner inliner;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, inliner.Run(hlo_module.get()));
  ASSERT_TRUE(changed);
}
169 
170 }  // namespace
171 }  // namespace xla
172