1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h"
17 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
18 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
19 #include "tensorflow/compiler/xla/status_macros.h"
20 #include "tensorflow/compiler/xla/test.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 
24 namespace xla {
25 namespace {
26 
27 class TripCountAnnotatorTest : public HloTestBase {};
28 
// Loop `for (i = 0; i < 10; ++i)`: the annotator should record a known trip
// count of 10 in the backend config of the while instruction.
TEST_F(TripCountAnnotatorTest, KnownSmallTripCount) {
  constexpr char kHloText[] = R"(
    HloModule test
    Body {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      one = s32[] constant(1)
      i_plus_one = s32[] add(i, one)
      ROOT tuple = (s32[]) tuple(i_plus_one)
    }

    Cond {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      trip_count = s32[] constant(10)
      ROOT done = pred[] compare(i, trip_count), direction=LT
    }

    ENTRY test {
      i_start = s32[] constant(0)
      initial_tuple = (s32[]) tuple(i_start)
      ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  // Run the pass; it must report that it modified the module.
  WhileLoopTripCountAnnotator annotator;
  TF_ASSERT_OK_AND_ASSIGN(bool annotated,
                          RunHloPass(&annotator, module.get()));
  ASSERT_TRUE(annotated);

  // The while instruction is the root of the entry computation.
  const auto* while_op = module->entry_computation()->root_instruction();
  TF_ASSERT_OK_AND_ASSIGN(auto loop_config,
                          while_op->backend_config<WhileLoopBackendConfig>());
  EXPECT_EQ(10, loop_config.known_trip_count().n());
}
64 
// Loop `for (i = 0; i < 1000000; ++i)`: the annotator should record a known
// trip count of 1000000 in the backend config of the while instruction.
TEST_F(TripCountAnnotatorTest, KnownLargeTripCount) {
  constexpr char kHloText[] = R"(
    HloModule test
    Body {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      one = s32[] constant(1)
      i_plus_one = s32[] add(i, one)
      ROOT tuple = (s32[]) tuple(i_plus_one)
    }

    Cond {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      trip_count = s32[] constant(1000000)
      ROOT done = pred[] compare(i, trip_count), direction=LT
    }

    ENTRY test {
      i_start = s32[] constant(0)
      initial_tuple = (s32[]) tuple(i_start)
      ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  // Run the pass; it must report that it modified the module.
  WhileLoopTripCountAnnotator annotator;
  TF_ASSERT_OK_AND_ASSIGN(bool annotated,
                          RunHloPass(&annotator, module.get()));
  ASSERT_TRUE(annotated);

  // The while instruction is the root of the entry computation.
  const auto* while_op = module->entry_computation()->root_instruction();
  TF_ASSERT_OK_AND_ASSIGN(auto loop_config,
                          while_op->backend_config<WhileLoopBackendConfig>());
  EXPECT_EQ(1000000, loop_config.known_trip_count().n());
}
100 
// Loop `for (i = 10; i < 1000000; ++i)`: with a nonzero start value the
// annotated trip count should be 1000000 - 10 = 999990.
TEST_F(TripCountAnnotatorTest, NonzeroStart) {
  constexpr char kHloText[] = R"(
    HloModule test
    Body {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      one = s32[] constant(1)
      i_plus_one = s32[] add(i, one)
      ROOT tuple = (s32[]) tuple(i_plus_one)
    }

    Cond {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      trip_count = s32[] constant(1000000)
      ROOT done = pred[] compare(i, trip_count), direction=LT
    }

    ENTRY test {
      i_start = s32[] constant(10)
      initial_tuple = (s32[]) tuple(i_start)
      ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  // Run the pass; it must report that it modified the module.
  WhileLoopTripCountAnnotator annotator;
  TF_ASSERT_OK_AND_ASSIGN(bool annotated,
                          RunHloPass(&annotator, module.get()));
  ASSERT_TRUE(annotated);

  // The while instruction is the root of the entry computation.
  const auto* while_op = module->entry_computation()->root_instruction();
  TF_ASSERT_OK_AND_ASSIGN(auto loop_config,
                          while_op->backend_config<WhileLoopBackendConfig>());
  EXPECT_EQ(999990, loop_config.known_trip_count().n());
}
136 
// Loop `for (i = 10; i <= 1000000; ++i)`: the inclusive bound (direction=LE)
// adds one iteration, so the annotated trip count should be 999991.
TEST_F(TripCountAnnotatorTest, LessThanOrEqualTo) {
  constexpr char kHloText[] = R"(
    HloModule test
    Body {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      one = s32[] constant(1)
      i_plus_one = s32[] add(i, one)
      ROOT tuple = (s32[]) tuple(i_plus_one)
    }

    Cond {
      param = (s32[]) parameter(0)
      i = s32[] get-tuple-element(param), index=0
      trip_count = s32[] constant(1000000)
      ROOT done = pred[] compare(i, trip_count), direction=LE
    }

    ENTRY test {
      i_start = s32[] constant(10)
      initial_tuple = (s32[]) tuple(i_start)
      ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  // Run the pass; it must report that it modified the module.
  WhileLoopTripCountAnnotator annotator;
  TF_ASSERT_OK_AND_ASSIGN(bool annotated,
                          RunHloPass(&annotator, module.get()));
  ASSERT_TRUE(annotated);

  // The while instruction is the root of the entry computation.
  const auto* while_op = module->entry_computation()->root_instruction();
  TF_ASSERT_OK_AND_ASSIGN(auto loop_config,
                          while_op->backend_config<WhileLoopBackendConfig>());
  EXPECT_EQ(999991, loop_config.known_trip_count().n());
}
172 
// Checks that the annotator leaves a loop alone when its trip count cannot be
// represented as an int64.
TEST_F(TripCountAnnotatorTest, Int64Overflow) {
  // for(i = INT64_MIN; i <= INT64_MAX; ++i)
  //
  // The condition uses direction=LE, so the bound is inclusive: the range
  // [INT64_MIN, INT64_MAX] holds 2^64 values. We store the trip count as an
  // int64, so this loop is unanalyzable.
  const char* kModuleStr = R"(
    HloModule test
    Body {
      param = (s64[]) parameter(0)
      i = s64[] get-tuple-element(param), index=0
      one = s64[] constant(1)
      i_plus_one = s64[] add(i, one)
      ROOT tuple = (s64[]) tuple(i_plus_one)
    }

    Cond {
      param = (s64[]) parameter(0)
      i = s64[] get-tuple-element(param), index=0
      trip_count = s64[] constant(9223372036854775807) // 2^63-1
      ROOT done = pred[] compare(i, trip_count), direction=LE
    }

    ENTRY test {
      i_start = s64[] constant(-9223372036854775808)  // -2^63
      initial_tuple = (s64[]) tuple(i_start)
      ROOT while = (s64[]) while(initial_tuple), condition=Cond, body=Body
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  WhileLoopTripCountAnnotator pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get()));
  // The pass must report that it did not modify the module.
  EXPECT_FALSE(changed);
}
205 
206 }  // namespace
207 }  // namespace xla
208