1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"

#include <algorithm>
#include <limits>

#include "absl/base/casts.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
23
24 namespace xla {
25
26 using absl::nullopt;
27 using absl::optional;
28 namespace m = match;
29
30 // Finds and returns the non-constant operand in instr.
31 //
32 // CHECK-fails if instr doesn't have exactly one unique non-constant operand.
NonConstantOperand(const HloInstruction * instr)33 static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
34 const HloInstruction* result = nullptr;
35 for (const HloInstruction* operand : instr->operands()) {
36 if (!operand->IsConstant()) {
37 if (result != nullptr) {
38 CHECK_EQ(result, operand);
39 }
40 result = operand;
41 }
42 }
43 CHECK_NE(result, nullptr);
44 return result;
45 }
46
47 // If all of instr's operands are either constants or have the form
48 // get-tuple-element(gte_operand, N)
49 // for the same value N, returns N. Otherwise, returns nullopt.
GetGTEOperandIndex(const HloInstruction * instr,const HloInstruction * gte_operand)50 static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
51 const HloInstruction* gte_operand) {
52 VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
53 << gte_operand->ToString() << ")";
54
55 // Among the operands of `instr`, find one that is a get-tuple-element op.
56 auto gte_it = c_find_if(instr->operands(), [](const HloInstruction* instr) {
57 return instr->opcode() == HloOpcode::kGetTupleElement;
58 });
59 if (gte_it == instr->operands().end()) {
60 VLOG(2) << "instr does not have a gte operand.";
61 return nullopt;
62 }
63
64 // All operands of `instr` must be either constants or of the form
65 // get-tuple-element(gte_operand, tuple_idx)
66 // for the same value tuple_idx.
67 int64 tuple_idx = (*gte_it)->tuple_index();
68 for (const HloInstruction* operand : instr->operands()) {
69 if (!Match(operand, m::Constant()) &&
70 !Match(operand,
71 m::GetTupleElement(m::Op().Is(gte_operand), tuple_idx))) {
72 VLOG(2)
73 << "instr uses something other than a constant or gte(gte_operand, "
74 << tuple_idx << "): " << operand->ToString();
75 return nullopt;
76 }
77 }
78 return tuple_idx;
79 }
80
81 // Tries to get the tuple index of the induction variable of a while loop.
82 //
83 // Checks that the loop condition and body both plumb the induction variable
84 // through the same tuple index, and that they both apply exactly one op to the
85 // induction variable before deciding whether to do another loop iteration (in
86 // the loop condition's case) or packing the induction variable into the result
87 // tuple (in the loop body's case).
88 //
89 // Specifically, checks that the loop condition has structure
90 //
91 // root = op(constants, get-tuple-elem(param0, N), constants)
92 //
93 // and the loop body has the structure
94 //
95 // inc = op(constants, get-tuple-elem(param0, N), constants)
96 // root = tuple(..., inc, ...) // inc is N'th operand of tuple().
97 //
98 // If so, returns N. Otherwise, returns nullopt.
GetLoopInductionVarTupleIdx(const HloInstruction * while_op)99 optional<int64> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
100 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
101 VLOG(2) << "Finding induction variable for loop "
102 << while_op->ToShortString();
103
104 // The while_cond computation should have the form
105 //
106 // while_cond_root =
107 // op(constants, get-tuple-elem(while_cond_param, N), constants).
108 //
109 // If it does, set indvar_tuple_idx to N.
110 auto* while_cond = while_op->while_condition();
111 auto* while_cond_root = while_cond->root_instruction();
112 auto* while_cond_param = while_cond->parameter_instruction(0);
113 optional<int64> indvar_tuple_idx =
114 GetGTEOperandIndex(while_cond_root, while_cond_param);
115 if (!indvar_tuple_idx) {
116 VLOG(2) << "Induction variable not found in loop condition: "
117 << while_cond->root_instruction()->ToString();
118 return nullopt;
119 }
120
121 // The while_body computation should have the form
122 //
123 // while_body_inc =
124 // op(constants, get-tuple-elem(while_body_param, N), constants)
125 // while_body_root = tuple(..., while_body_inc, ...)
126 //
127 // where while_body_inc is operand N of while_body_root.
128 auto* while_body = while_op->while_body();
129 auto* while_body_root = while_body->root_instruction();
130 if (while_body_root->opcode() != HloOpcode::kTuple) {
131 VLOG(2) << "While body's root is not a tuple instruction: "
132 << while_body_root->ToString();
133 return nullopt;
134 }
135
136 auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
137 auto* while_body_param = while_body->parameter_instruction(0);
138 optional<int64> while_body_indvar_tuple_idx =
139 GetGTEOperandIndex(while_body_inc, while_body_param);
140 if (!while_body_indvar_tuple_idx) {
141 VLOG(2)
142 << "Induction variable not found in while body increment instruction: "
143 << while_body_inc->ToString();
144 return nullopt;
145 }
146 if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
147 VLOG(2) << "Tuple index of induction variable does not match between loop "
148 "condition ("
149 << *indvar_tuple_idx << ") and while body ("
150 << *while_body_indvar_tuple_idx << ")";
151 return nullopt;
152 }
153
154 // Finally, check that the while loop's initial value is a tuple with enough
155 // elements.
156 auto* while_init = while_op->operand(0);
157 if (while_init->opcode() != HloOpcode::kTuple) {
158 VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
159 return nullopt;
160 }
161
162 VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
163 return indvar_tuple_idx;
164 }
165
166 // Converts the given literal to a scalar int64, if possible.
167 //
168 // Fails if the literal is not an integral type or if the value it contains
169 // cannot be represented in an int64.
LiteralAsScalarInt64(const Literal & l)170 static optional<int64> LiteralAsScalarInt64(const Literal& l) {
171 if (!ShapeUtil::IsEffectiveScalar(l.shape())) {
172 VLOG(2) << "literal is not an effective scalar: " << l.ToString();
173 return nullopt;
174 }
175 switch (l.shape().element_type()) {
176 case S8:
177 return l.GetFirstElement<int8>();
178 case S16:
179 return l.GetFirstElement<int16>();
180 case S32:
181 return l.GetFirstElement<int32>();
182 case S64:
183 return l.GetFirstElement<int64>();
184 case U8:
185 return l.GetFirstElement<uint8>();
186 case U16:
187 return l.GetFirstElement<uint16>();
188 case U32:
189 return l.GetFirstElement<uint32>();
190 case U64: {
191 uint64 v = l.GetFirstElement<uint64>();
192 if (v > static_cast<uint64>(std::numeric_limits<int64>::max())) {
193 VLOG(2) << "uint64 literal is out of range for int64: " << v;
194 return nullopt;
195 }
196 return v;
197 }
198 default:
199 VLOG(2) << "literal is of non-integral type " << l.shape().ToString();
200 return nullopt;
201 }
202 }
203
204 // Computes a + b, returning nullopt if it overflows.
CheckedAdd(int64 a,int64 b)205 optional<int64> CheckedAdd(int64 a, int64 b) {
206 // Overflow occurred iff `a` and `b` have the same sign and `a + b` has a
207 // different sign, see Hacker's Delignt 2nd Ed. pp 28.
208 uint64 aa = absl::bit_cast<uint64>(a);
209 uint64 bb = absl::bit_cast<uint64>(b);
210 int64 result = absl::bit_cast<int64>(aa + bb);
211 if (a >= 0 == b >= 0 && result >= 0 != a >= 0) {
212 return nullopt;
213 }
214 return result;
215 }
216
217 // Computes a - b, returning nullopt if it overflows.
CheckedSubtract(int64 a,int64 b)218 optional<int64> CheckedSubtract(int64 a, int64 b) {
219 uint64 aa = absl::bit_cast<uint64>(a);
220 uint64 bb = absl::bit_cast<uint64>(b);
221 int64 result = absl::bit_cast<int64>(aa - bb);
222 // Overflow occurred iff `a` and `b` have different signs and the sign of
223 // `a - b` is the same as that of `b`, see Hacker's Delight 2nd Ed. pp 29.
224 if (a >= 0 != b >= 0 && result >= 0 == b >= 0) {
225 return nullopt;
226 }
227 return result;
228 }
229
230 // Check if
231 // - `i` is initialized to a scalar constant K (namely, `indvar_init`),
232 // - the while condition does `i < N` or `i <= N`, and
233 // - the while body does `i++`.
234 // If so, it's trivial to compute the loop bound.
PatternMatchLoopTripCount(HloInstruction * while_op,int64 indvar_tuple_idx,const Literal & indvar_init)235 static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
236 int64 indvar_tuple_idx,
237 const Literal& indvar_init) {
238 // First, find the scalar constant K that `i` is initialized to.
239 optional<int64> indvar_init_val = LiteralAsScalarInt64(indvar_init);
240 if (!indvar_init_val) {
241 VLOG(2) << "Pattern-match failed: induction variable init is not a "
242 "constant scalar representable as an int64: "
243 << indvar_init.ToString();
244 return nullopt;
245 }
246
247 // Check that `i` goes as `i++` in the while body.
248 //
249 // TODO(jlebar): We could also handle i-- and other idioms.
250 auto* while_body = while_op->while_body();
251 auto* while_body_indvar_update =
252 while_body->root_instruction()->operand(indvar_tuple_idx);
253 auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
254 if (!Match(while_body_indvar_update,
255 m::AddAnyOrder(m::Op().Is(while_body_indvar),
256 m::ConstantEffectiveScalar(1)))) {
257 VLOG(2) << "Pattern-match failed: induction variable does not go as i++: "
258 << while_body_indvar_update->ToString();
259 return nullopt;
260 }
261
262 // Check that we do op(i, N) or op(N, i) as the while condition. Capture the
263 // value N.
264 auto* while_cond = while_op->while_condition();
265 auto* while_cond_root = while_cond->root_instruction();
266 auto* while_cond_indvar = NonConstantOperand(while_cond_root);
267 HloInstruction* while_cond_bound = nullptr;
268 if (!Match(while_cond_root,
269 m::Op().WithBinaryOperandsAnyOrder(
270 m::Op().Is(while_cond_indvar),
271 m::ConstantEffectiveScalar(&while_cond_bound)))) {
272 VLOG(2) << "Pattern-match failed: while condition is not of the form "
273 "op(i, N) or op(N, i).";
274 return nullopt;
275 }
276 // Note: If this succeeds, the constant `N` is representable as an int64 --
277 // that is, if it's an XLA U64, it fits within an int64.
278 optional<int64> while_cond_bound_val =
279 LiteralAsScalarInt64(while_cond_bound->literal());
280 if (!while_cond_bound_val) {
281 VLOG(2) << "Pattern-match failed: while condition induction variable is "
282 "not a constant scalar representable as an int64.";
283 return nullopt;
284 }
285
286 // Handle `i = K; i < N; ++i`.
287 if (Match(while_cond_root,
288 m::Op()
289 .WithComparisonDirection(ComparisonDirection::kLt)
290 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
291 VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
292 << while_cond_root->ToString();
293 optional<int64> trips =
294 CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
295 if (trips) {
296 return std::max(int64{0}, *trips);
297 } else {
298 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX.";
299 return nullopt;
300 }
301 }
302
303 // Handle `i = K; i <= N; ++i`.
304 if (Match(while_cond_root,
305 m::Op()
306 .WithComparisonDirection(ComparisonDirection::kLe)
307 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
308 VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
309 << while_cond_root->ToString();
310 optional<int64> trips =
311 CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
312 if (!trips) {
313 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
314 return nullopt;
315 }
316 trips = CheckedAdd(*trips, 1);
317 if (!trips) {
318 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
319 return nullopt;
320 }
321 return std::max<int64>(0, *trips);
322 }
323
324 VLOG(2) << "Pattern-match failed: while condition follows unknown pattern: "
325 << while_cond_root->ToString();
326 return nullopt;
327 }
328
ComputeWhileLoopTripCount(HloInstruction * while_op,int64 max_brute_force_iters)329 optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
330 int64 max_brute_force_iters) {
331 VLOG(2) << "Getting trip count for loop " << while_op->ToString();
332
333 // The loop's induction variable is found at
334 //
335 // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
336 //
337 // where comp is while_op->while_body() or while_op->while_condition().
338 optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
339 if (!indvar_tuple_idx) {
340 return nullopt;
341 }
342
343 // Now that we know the index of the induction variable, we can we can try to
344 // compute how many times the loop executes. Start by computing the induction
345 // variable's initial value.
346 HloEvaluator evaluator(/*max_loop_iterations=*/0);
347 auto* while_init = while_op->mutable_operand(0);
348 auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
349 StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
350 if (!indvar_init_result.ok()) {
351 VLOG(2) << "Couldn't evaluate induction variable init, "
352 << indvar_init_result.status() << ", " << indvar_init->ToString();
353 return nullopt;
354 }
355 Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
356
357 // First, try to pattern-match.
358 if (auto trip_count = PatternMatchLoopTripCount(while_op, *indvar_tuple_idx,
359 indvar_iter_val)) {
360 return trip_count;
361 }
362
363 // If our pattern-match failed, try brute-forcing the loop trip count.
364 auto* while_body = while_op->while_body();
365 auto* while_body_indvar_update =
366 while_body->root_instruction()->operand(*indvar_tuple_idx);
367 auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
368
369 auto* while_cond = while_op->while_condition();
370 auto* while_cond_root = while_cond->root_instruction();
371 auto* while_cond_indvar = NonConstantOperand(while_cond_root);
372
373 for (int64 trip_count = 0; trip_count != max_brute_force_iters + 1;
374 ++trip_count) {
375 StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
376 while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
377 if (!result.ok()) {
378 VLOG(2) << "Couldn't evaluate while cond: " << result.status();
379 return nullopt;
380 }
381 if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
382 VLOG(2) << "Loop has static trip count of " << trip_count;
383 return trip_count;
384 }
385
386 // Calculate the value of the induction variable after one iteration of the
387 // loop, and check whether the while condition is true with this new value.
388 StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
389 while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
390 if (!indvar_next_result.ok()) {
391 VLOG(2) << "Couldn't evaluate induction variable update: "
392 << indvar_next_result.status();
393 return nullopt;
394 }
395 indvar_iter_val = std::move(indvar_next_result).ValueOrDie();
396 }
397
398 VLOG(2) << "Loop has unknown trip count.";
399 return nullopt;
400 }
401
402 // If the only user of this instruction is a get-tuple-element, return that
403 // get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may
404 // get a false negative if there are several copies of the same GTE, or there
405 // are unused GTEs, but we can live with this.
GetOnlyGTE(HloInstruction * inst)406 static HloInstruction* GetOnlyGTE(HloInstruction* inst) {
407 if (inst->user_count() != 1) {
408 return nullptr;
409 }
410
411 HloInstruction* user = inst->users().back();
412 if (user->opcode() != HloOpcode::kGetTupleElement) {
413 return nullptr;
414 }
415 return user;
416 }
417
ComputeWhileLoopTripCountUpperBound(HloInstruction * while_op)418 optional<int64> ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) {
419 // If we know the exact trip count, it's also the upper bound.
420 auto exact_trip_count = ComputeWhileLoopTripCount(while_op);
421 if (exact_trip_count) {
422 VLOG(2) << "Loop has exact trip count.";
423 return exact_trip_count;
424 }
425
426 // There is one more case we know how to handle. If the loop condition only
427 // looks at one element of the tuple, and the loop body sets this element to a
428 // constant, there are two options:
429 // 1) Evaluating the condition on this constant returns true. In this case,
430 // the loop either executes 0 times, or is an infinite loop, depending on the
431 // init value.
432 // 2) Evaluating the condition on this constant returns false. In this case,
433 // the loop executes 0 or 1 times, depending on the init value. This means
434 // that, regardless of the init value, the upper bound on the trip count is 1.
435
436 // Check whether the condition depends on a single parameter, and find out
437 // which.
438 auto* while_cond = while_op->while_condition();
439 auto* while_cond_param = while_cond->parameter_instruction(0);
440 auto* cond_gte = GetOnlyGTE(while_cond_param);
441 if (!cond_gte) {
442 VLOG(2) << "Induction variable not found in loop condition: "
443 << while_cond->root_instruction()->ToString();
444 return nullopt;
445 }
446
447 // Now check whether this gets set to a constant by the while body.
448 auto* while_body = while_op->while_body();
449 auto* while_body_root = while_body->root_instruction();
450 if (while_body_root->opcode() != HloOpcode::kTuple) {
451 VLOG(3) << "While body's root is not a tuple instruction: "
452 << while_body_root->ToString();
453 return nullopt;
454 }
455
456 int64 indvar_index = cond_gte->tuple_index();
457 auto* while_body_indvar = while_body_root->operand(indvar_index);
458 if (while_body_indvar->opcode() != HloOpcode::kConstant) {
459 VLOG(3) << "While body does not set the IV to a constant: "
460 << while_body_indvar->ToString();
461 return nullopt;
462 }
463
464 // We have a constant. Evaluate the condition on this constant.
465 HloEvaluator evaluator(/*max_loop_iterations=*/0);
466 Literal fake_input = Literal::CreateFromShape(while_cond_param->shape());
467 TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(),
468 /*dest_shape_index=*/{indvar_index},
469 /*src_shape_index=*/{}));
470 StatusOr<Literal> eval_result =
471 evaluator.Evaluate(*while_cond, {std::move(fake_input)});
472
473 if (!eval_result.ok()) {
474 VLOG(2) << "Couldn't evaluate while loop condition.";
475 return nullopt;
476 }
477
478 Literal cond_result_pred = std::move(eval_result.ValueOrDie());
479 CHECK(ShapeUtil::Equal(cond_result_pred.shape(),
480 ShapeUtil::MakeShape(PRED, {})));
481
482 // Per the explanation above, if the evaluated condition returns false, the
483 // loop executes at most once.
484 bool cond_returns_true = cond_result_pred.GetFirstElement<bool>();
485 if (!cond_returns_true) {
486 VLOG(2) << "Upper bound on the trip count is 1";
487 return 1;
488 }
489
490 VLOG(2) << "Loop has no known upper bound on the trip count.";
491 return nullopt;
492 }
493
494 } // namespace xla
495