1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h"
17 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
18 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
19 #include "tensorflow/compiler/xla/status_macros.h"
20 #include "tensorflow/compiler/xla/test.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23
24 namespace xla {
25 namespace {
26
27 class TripCountAnnotatorTest : public HloTestBase {};
28
// A loop that counts from 0 while `i < 10`, stepping by 1. The pass must
// report a change and annotate the root while instruction with a known trip
// count of 10.
TEST_F(TripCountAnnotatorTest, KnownSmallTripCount) {
  const char* kModuleStr = R"(
    HloModule test
    Body {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      one = s32[] constant(1)
      i_plus_one = s32[] add(i, one)
      ROOT tuple = (s32[]) tuple(i_plus_one)
    }

    Cond {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      trip_count = s32[] constant(10)
      ROOT done = pred[] compare(i, trip_count), direction=LT
    }

    ENTRY test {
      i_start = s32[] constant(0)
      initial_tuple = (s32[]) tuple(i_start)
      ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  WhileLoopTripCountAnnotator pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get()));
  ASSERT_TRUE(changed);

  // The annotation is read back from the backend config of the entry
  // computation's root, which is the while instruction itself.
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          m->entry_computation()
                              ->root_instruction()
                              ->backend_config<WhileLoopBackendConfig>());
  EXPECT_EQ(10, config.known_trip_count().n());
}
64
// Same loop shape as the small-trip-count case, but the bound is 1000000:
// counts from 0 while `i < 1000000`, stepping by 1. The pass must report a
// change and annotate a known trip count of 1000000.
TEST_F(TripCountAnnotatorTest, KnownLargeTripCount) {
  const char* kModuleStr = R"(
    HloModule test
    Body {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      one = s32[] constant(1)
      i_plus_one = s32[] add(i, one)
      ROOT tuple = (s32[]) tuple(i_plus_one)
    }

    Cond {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      trip_count = s32[] constant(1000000)
      ROOT done = pred[] compare(i, trip_count), direction=LT
    }

    ENTRY test {
      i_start = s32[] constant(0)
      initial_tuple = (s32[]) tuple(i_start)
      ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  WhileLoopTripCountAnnotator pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get()));
  ASSERT_TRUE(changed);

  // The annotation is read back from the backend config of the entry
  // computation's root, which is the while instruction itself.
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          m->entry_computation()
                              ->root_instruction()
                              ->backend_config<WhileLoopBackendConfig>());
  EXPECT_EQ(1000000, config.known_trip_count().n());
}
100
// The induction variable starts at 10 rather than 0 and the loop runs while
// `i < 1000000`, stepping by 1. The expected annotated trip count is
// 1000000 - 10 = 999990.
TEST_F(TripCountAnnotatorTest, NonzeroStart) {
  const char* kModuleStr = R"(
    HloModule test
    Body {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      one = s32[] constant(1)
      i_plus_one = s32[] add(i, one)
      ROOT tuple = (s32[]) tuple(i_plus_one)
    }

    Cond {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      trip_count = s32[] constant(1000000)
      ROOT done = pred[] compare(i, trip_count), direction=LT
    }

    ENTRY test {
      i_start = s32[] constant(10)
      initial_tuple = (s32[]) tuple(i_start)
      ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  WhileLoopTripCountAnnotator pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get()));
  ASSERT_TRUE(changed);

  // The annotation is read back from the backend config of the entry
  // computation's root, which is the while instruction itself.
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          m->entry_computation()
                              ->root_instruction()
                              ->backend_config<WhileLoopBackendConfig>());
  EXPECT_EQ(999990, config.known_trip_count().n());
}
136
// Same as the nonzero-start case, but the condition is inclusive
// (`direction=LE`): the loop runs from 10 while `i <= 1000000`, stepping by 1.
// The inclusive bound adds one iteration, so the expected trip count is
// 999991.
TEST_F(TripCountAnnotatorTest, LessThanOrEqualTo) {
  const char* kModuleStr = R"(
    HloModule test
    Body {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      one = s32[] constant(1)
      i_plus_one = s32[] add(i, one)
      ROOT tuple = (s32[]) tuple(i_plus_one)
    }

    Cond {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      trip_count = s32[] constant(1000000)
      ROOT done = pred[] compare(i, trip_count), direction=LE
    }

    ENTRY test {
      i_start = s32[] constant(10)
      initial_tuple = (s32[]) tuple(i_start)
      ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  WhileLoopTripCountAnnotator pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get()));
  ASSERT_TRUE(changed);

  // The annotation is read back from the backend config of the entry
  // computation's root, which is the while instruction itself.
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          m->entry_computation()
                              ->root_instruction()
                              ->backend_config<WhileLoopBackendConfig>());
  EXPECT_EQ(999991, config.known_trip_count().n());
}
172
// Negative case: the pass must leave the module unchanged when the loop's
// iteration count cannot be represented.
TEST_F(TripCountAnnotatorTest, Int64Overflow) {
  // for(i = INT64_MIN; i <= INT64_MAX; ++i)
  //
  // We store the trip count as an int64, so this loop is unanalyzable: the
  // number of iterations spanning the whole s64 range does not fit in an
  // int64.
  const char* kModuleStr = R"(
    HloModule test
    Body {
      param = (s64[]) parameter(0)
      i = s64[] get-tuple-element(param), index=0
      one = s64[] constant(1)
      i_plus_one = s64[] add(i, one)
      ROOT tuple = (s64[]) tuple(i_plus_one)
    }

    Cond {
      param = (s64[]) parameter(0)
      i = s64[] get-tuple-element(param), index=0
      trip_count = s64[] constant(9223372036854775807) // 2^63-1
      ROOT done = pred[] compare(i, trip_count), direction=LE
    }

    ENTRY test {
      i_start = s64[] constant(-9223372036854775808) // -2^63
      initial_tuple = (s64[]) tuple(i_start)
      ROOT while = (s64[]) while(initial_tuple), condition=Cond, body=Body
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  WhileLoopTripCountAnnotator pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get()));
  // No annotation is expected, so the pass should report no change.
  EXPECT_FALSE(changed);
}
205
206 } // namespace
207 } // namespace xla
208