1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/types.h"
34
35 namespace xla {
36 namespace {
37
38 namespace op = xla::testing::opcode_matchers;
39
40 class ConditionalSimplifierTest : public HloTestBase {
41 public:
42 // Makes a computation that contains a conditional with constant predicate.
43 HloComputation* MakeConditional(HloModule* module);
44 };
45
MakeConditional(HloModule * module)46 HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
47 HloComputation::Builder builder(TestName());
48
49 // true_computation returns param+1.
50 HloComputation* true_computation;
51 {
52 HloComputation::Builder true_computation_builder(TestName() +
53 ".true_computation");
54 auto param =
55 true_computation_builder.AddInstruction(HloInstruction::CreateParameter(
56 0, ShapeUtil::MakeShape(S32, {}), "param"));
57 auto one = true_computation_builder.AddInstruction(
58 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
59
60 true_computation_builder.AddInstruction(HloInstruction::CreateBinary(
61 ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one));
62
63 true_computation =
64 module->AddEmbeddedComputation(true_computation_builder.Build());
65 }
66
67 // false_computation returns param+42.
68 HloComputation* false_computation;
69 {
70 HloComputation::Builder false_computation_builder(TestName() +
71 ".false_computation");
72 auto param = false_computation_builder.AddInstruction(
73 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}),
74 "param"));
75 auto forty_two = false_computation_builder.AddInstruction(
76 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
77
78 false_computation_builder.AddInstruction(HloInstruction::CreateBinary(
79 ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two));
80 false_computation =
81 module->AddEmbeddedComputation(false_computation_builder.Build());
82 }
83
84 auto false_instrn = builder.AddInstruction(
85 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
86 auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
87 0, ShapeUtil::MakeShape(S32, {}), "false_param"));
88 auto one = builder.AddInstruction(
89 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
90
91 builder.AddInstruction(HloInstruction::CreateConditional(
92 ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation,
93 false_param, false_computation));
94
95 return module->AddEntryComputation(builder.Build());
96 }
97
TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());

  const bool changed = ConditionalSimplifier().Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // After inlining, the root is expected to be the body of the selected
  // branch: an add of the entry parameter and a constant.
  EXPECT_THAT(entry->root_instruction(),
              op::Add(op::Parameter(), op::Constant()));
}
105
TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = MakeConditional(m.get());

  // Pin the precondition (as the other tests do): the control dependency
  // below must land on the conditional itself, otherwise this test would not
  // exercise the control-dependency check at all.
  auto* conditional = computation->root_instruction();
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

  // Give the conditional a control predecessor; the pass is expected to
  // leave it alone.
  auto* true_op = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  TF_ASSERT_OK(true_op->AddControlDependencyTo(conditional));

  EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie());
  // With no change reported, the conditional must still be the entry root.
  EXPECT_EQ(computation->root_instruction(), conditional);
}
117
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  HloInstruction* cond = entry->root_instruction();
  ASSERT_EQ(cond->opcode(), HloOpcode::kConditional);

  // Append a Send/SendDone pair to the true branch; with it present the
  // simplifier must report no change.
  HloComputation* branch = cond->true_computation();
  HloInstruction* token =
      branch->AddInstruction(HloInstruction::CreateToken());
  HloInstruction* payload = branch->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* send = branch->AddInstruction(
      HloInstruction::CreateSend(payload, token, /*channel_id=*/0));
  branch->AddInstruction(HloInstruction::CreateSendDone(send));

  EXPECT_FALSE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
133
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  HloInstruction* cond = entry->root_instruction();
  ASSERT_EQ(cond->opcode(), HloOpcode::kConditional);

  // Append a Recv/RecvDone pair to the true branch; with it present the
  // simplifier must report no change.
  HloComputation* branch = cond->true_computation();
  HloInstruction* token =
      branch->AddInstruction(HloInstruction::CreateToken());
  const auto recv_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* recv = branch->AddInstruction(
      HloInstruction::CreateRecv(recv_shape, token, /*channel_id=*/0));
  branch->AddInstruction(HloInstruction::CreateRecvDone(recv));

  EXPECT_FALSE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
147
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  HloInstruction* cond = entry->root_instruction();
  ASSERT_EQ(cond->opcode(), HloOpcode::kConditional);

  // Put an Infeed into the false branch (the one the constant predicate
  // selects); with it present the simplifier must report no change.
  HloComputation* branch = cond->false_computation();
  HloInstruction* token =
      branch->AddInstruction(HloInstruction::CreateToken());
  const auto infeed_shape = ShapeUtil::MakeShape(F32, {1});
  branch->AddInstruction(
      HloInstruction::CreateInfeed(infeed_shape, token, "config"));

  EXPECT_FALSE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
159
160 } // namespace
161 } // namespace xla
162