1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace xla {
36 namespace {
37 
38 namespace op = xla::testing::opcode_matchers;
39 
40 class ConditionalSimplifierTest : public HloTestBase {
41  public:
42   // Makes a computation that contains a conditional with constant predicate.
43   HloComputation* MakeConditional(HloModule* module);
44 };
45 
MakeConditional(HloModule * module)46 HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
47   HloComputation::Builder builder(TestName());
48 
49   // true_computation returns param+1.
50   HloComputation* true_computation;
51   {
52     HloComputation::Builder true_computation_builder(TestName() +
53                                                      ".true_computation");
54     auto param =
55         true_computation_builder.AddInstruction(HloInstruction::CreateParameter(
56             0, ShapeUtil::MakeShape(S32, {}), "param"));
57     auto one = true_computation_builder.AddInstruction(
58         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
59 
60     true_computation_builder.AddInstruction(HloInstruction::CreateBinary(
61         ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one));
62 
63     true_computation =
64         module->AddEmbeddedComputation(true_computation_builder.Build());
65   }
66 
67   // false_computation returns param+42.
68   HloComputation* false_computation;
69   {
70     HloComputation::Builder false_computation_builder(TestName() +
71                                                       ".false_computation");
72     auto param = false_computation_builder.AddInstruction(
73         HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}),
74                                         "param"));
75     auto forty_two = false_computation_builder.AddInstruction(
76         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
77 
78     false_computation_builder.AddInstruction(HloInstruction::CreateBinary(
79         ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two));
80     false_computation =
81         module->AddEmbeddedComputation(false_computation_builder.Build());
82   }
83 
84   auto false_instrn = builder.AddInstruction(
85       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
86   auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
87       0, ShapeUtil::MakeShape(S32, {}), "false_param"));
88   auto one = builder.AddInstruction(
89       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
90 
91   builder.AddInstruction(HloInstruction::CreateConditional(
92       ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation,
93       false_param, false_computation));
94 
95   return module->AddEntryComputation(builder.Build());
96 }
97 
// With a constant predicate, the pass is expected to report a change and
// leave an Add of a parameter and a constant as the entry root.
TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());

  ASSERT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              op::Add(op::Parameter(), op::Constant()));
}
105 
// When the conditional has a control predecessor, the pass is expected to
// report no change.
TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());

  HloInstruction* control_predecessor = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional = entry->root_instruction();
  TF_ASSERT_OK(control_predecessor->AddControlDependencyTo(conditional));

  EXPECT_FALSE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
117 
// A Send/SendDone pair in the true branch is expected to keep the pass from
// changing the module.
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) {
  auto module = CreateNewVerifiedModule();
  HloInstruction* conditional =
      MakeConditional(module.get())->root_instruction();
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

  HloComputation* branch = conditional->true_computation();
  HloInstruction* token = branch->AddInstruction(HloInstruction::CreateToken());
  HloInstruction* payload = branch->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* send = branch->AddInstruction(
      HloInstruction::CreateSend(payload, token, /*channel_id=*/0));
  branch->AddInstruction(HloInstruction::CreateSendDone(send));

  EXPECT_FALSE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
133 
// A Recv/RecvDone pair in the true branch is expected to keep the pass from
// changing the module.
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) {
  auto module = CreateNewVerifiedModule();
  HloInstruction* conditional =
      MakeConditional(module.get())->root_instruction();
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

  HloComputation* branch = conditional->true_computation();
  HloInstruction* token = branch->AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv = branch->AddInstruction(HloInstruction::CreateRecv(
      ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0));
  branch->AddInstruction(HloInstruction::CreateRecvDone(recv));

  EXPECT_FALSE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
147 
// An Infeed in the false branch (the branch the constant-false predicate
// selects) is expected to keep the pass from changing the module.
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) {
  auto module = CreateNewVerifiedModule();
  HloInstruction* conditional =
      MakeConditional(module.get())->root_instruction();
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

  HloComputation* branch = conditional->false_computation();
  HloInstruction* token = branch->AddInstruction(HloInstruction::CreateToken());
  branch->AddInstruction(HloInstruction::CreateInfeed(
      ShapeUtil::MakeShape(F32, {1}), token, "config"));

  EXPECT_FALSE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
159 
160 }  // namespace
161 }  // namespace xla
162