1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
17 #include "absl/base/casts.h"
18 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
23 
24 namespace xla {
25 
26 using absl::nullopt;
27 using absl::optional;
28 namespace m = match;
29 
30 // Finds and returns the non-constant operand in instr.
31 //
32 // CHECK-fails if instr doesn't have exactly one unique non-constant operand.
NonConstantOperand(const HloInstruction * instr)33 static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
34   const HloInstruction* result = nullptr;
35   for (const HloInstruction* operand : instr->operands()) {
36     if (!operand->IsConstant()) {
37       if (result != nullptr) {
38         CHECK_EQ(result, operand);
39       }
40       result = operand;
41     }
42   }
43   CHECK_NE(result, nullptr);
44   return result;
45 }
46 
47 // If all of instr's operands are either constants or have the form
48 //   get-tuple-element(gte_operand, N)
49 // for the same value N, returns N.  Otherwise, returns nullopt.
GetGTEOperandIndex(const HloInstruction * instr,const HloInstruction * gte_operand)50 static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
51                                           const HloInstruction* gte_operand) {
52   VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
53           << gte_operand->ToString() << ")";
54 
55   // Among the operands of `instr`, find one that is a get-tuple-element op.
56   auto gte_it = c_find_if(instr->operands(), [](const HloInstruction* instr) {
57     return instr->opcode() == HloOpcode::kGetTupleElement;
58   });
59   if (gte_it == instr->operands().end()) {
60     VLOG(2) << "instr does not have a gte operand.";
61     return nullopt;
62   }
63 
64   // All operands of `instr` must be either constants or of the form
65   //   get-tuple-element(gte_operand, tuple_idx)
66   // for the same value tuple_idx.
67   int64 tuple_idx = (*gte_it)->tuple_index();
68   for (const HloInstruction* operand : instr->operands()) {
69     if (!Match(operand, m::Constant()) &&
70         !Match(operand,
71                m::GetTupleElement(m::Op().Is(gte_operand), tuple_idx))) {
72       VLOG(2)
73           << "instr uses something other than a constant or gte(gte_operand, "
74           << tuple_idx << "): " << operand->ToString();
75       return nullopt;
76     }
77   }
78   return tuple_idx;
79 }
80 
81 // Tries to get the tuple index of the induction variable of a while loop.
82 //
83 // Checks that the loop condition and body both plumb the induction variable
84 // through the same tuple index, and that they both apply exactly one op to the
85 // induction variable before  deciding whether to do another loop iteration (in
86 // the loop condition's case) or packing the induction variable into the result
87 // tuple (in the loop body's case).
88 //
89 // Specifically, checks that the loop condition has structure
90 //
91 //   root = op(constants, get-tuple-elem(param0, N), constants)
92 //
93 // and the loop body has the structure
94 //
95 //   inc = op(constants, get-tuple-elem(param0, N), constants)
96 //   root = tuple(..., inc, ...)  // inc is N'th operand of tuple().
97 //
98 // If so, returns N.  Otherwise, returns nullopt.
GetLoopInductionVarTupleIdx(const HloInstruction * while_op)99 optional<int64> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
100   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
101   VLOG(2) << "Finding induction variable for loop "
102           << while_op->ToShortString();
103 
104   // The while_cond computation should have the form
105   //
106   //   while_cond_root =
107   //       op(constants, get-tuple-elem(while_cond_param, N), constants).
108   //
109   // If it does, set indvar_tuple_idx to N.
110   auto* while_cond = while_op->while_condition();
111   auto* while_cond_root = while_cond->root_instruction();
112   auto* while_cond_param = while_cond->parameter_instruction(0);
113   optional<int64> indvar_tuple_idx =
114       GetGTEOperandIndex(while_cond_root, while_cond_param);
115   if (!indvar_tuple_idx) {
116     VLOG(2) << "Induction variable not found in loop condition: "
117             << while_cond->root_instruction()->ToString();
118     return nullopt;
119   }
120 
121   // The while_body computation should have the form
122   //
123   //   while_body_inc =
124   //       op(constants, get-tuple-elem(while_body_param, N), constants)
125   //   while_body_root = tuple(..., while_body_inc, ...)
126   //
127   // where while_body_inc is operand N of while_body_root.
128   auto* while_body = while_op->while_body();
129   auto* while_body_root = while_body->root_instruction();
130   if (while_body_root->opcode() != HloOpcode::kTuple) {
131     VLOG(2) << "While body's root is not a tuple instruction: "
132             << while_body_root->ToString();
133     return nullopt;
134   }
135 
136   auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
137   auto* while_body_param = while_body->parameter_instruction(0);
138   optional<int64> while_body_indvar_tuple_idx =
139       GetGTEOperandIndex(while_body_inc, while_body_param);
140   if (!while_body_indvar_tuple_idx) {
141     VLOG(2)
142         << "Induction variable not found in while body increment instruction: "
143         << while_body_inc->ToString();
144     return nullopt;
145   }
146   if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
147     VLOG(2) << "Tuple index of induction variable does not match between loop "
148                "condition ("
149             << *indvar_tuple_idx << ") and while body ("
150             << *while_body_indvar_tuple_idx << ")";
151     return nullopt;
152   }
153 
154   // Finally, check that the while loop's initial value is a tuple with enough
155   // elements.
156   auto* while_init = while_op->operand(0);
157   if (while_init->opcode() != HloOpcode::kTuple) {
158     VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
159     return nullopt;
160   }
161 
162   VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
163   return indvar_tuple_idx;
164 }
165 
166 // Converts the given literal to a scalar int64, if possible.
167 //
168 // Fails if the literal is not an integral type or if the value it contains
169 // cannot be represented in an int64.
LiteralAsScalarInt64(const Literal & l)170 static optional<int64> LiteralAsScalarInt64(const Literal& l) {
171   if (!ShapeUtil::IsEffectiveScalar(l.shape())) {
172     VLOG(2) << "literal is not an effective scalar: " << l.ToString();
173     return nullopt;
174   }
175   switch (l.shape().element_type()) {
176     case S8:
177       return l.GetFirstElement<int8>();
178     case S16:
179       return l.GetFirstElement<int16>();
180     case S32:
181       return l.GetFirstElement<int32>();
182     case S64:
183       return l.GetFirstElement<int64>();
184     case U8:
185       return l.GetFirstElement<uint8>();
186     case U16:
187       return l.GetFirstElement<uint16>();
188     case U32:
189       return l.GetFirstElement<uint32>();
190     case U64: {
191       uint64 v = l.GetFirstElement<uint64>();
192       if (v > static_cast<uint64>(std::numeric_limits<int64>::max())) {
193         VLOG(2) << "uint64 literal is out of range for int64: " << v;
194         return nullopt;
195       }
196       return v;
197     }
198     default:
199       VLOG(2) << "literal is of non-integral type " << l.shape().ToString();
200       return nullopt;
201   }
202 }
203 
204 // Computes a + b, returning nullopt if it overflows.
CheckedAdd(int64 a,int64 b)205 optional<int64> CheckedAdd(int64 a, int64 b) {
206   // Overflow occurred iff `a` and `b` have the same sign and `a + b` has a
207   // different sign, see Hacker's Delignt 2nd Ed. pp 28.
208   uint64 aa = absl::bit_cast<uint64>(a);
209   uint64 bb = absl::bit_cast<uint64>(b);
210   int64 result = absl::bit_cast<int64>(aa + bb);
211   if (a >= 0 == b >= 0 && result >= 0 != a >= 0) {
212     return nullopt;
213   }
214   return result;
215 }
216 
217 // Computes a - b, returning nullopt if it overflows.
CheckedSubtract(int64 a,int64 b)218 optional<int64> CheckedSubtract(int64 a, int64 b) {
219   uint64 aa = absl::bit_cast<uint64>(a);
220   uint64 bb = absl::bit_cast<uint64>(b);
221   int64 result = absl::bit_cast<int64>(aa - bb);
222   // Overflow occurred iff `a` and `b` have different signs and the sign of
223   // `a - b` is the same as that of `b`, see Hacker's Delight 2nd Ed. pp 29.
224   if (a >= 0 != b >= 0 && result >= 0 == b >= 0) {
225     return nullopt;
226   }
227   return result;
228 }
229 
230 // Check if
231 //  - `i` is initialized to a scalar constant K (namely, `indvar_init`),
232 //  - the while condition does `i < N` or `i <= N`, and
233 //  - the while body does `i++`.
234 // If so, it's trivial to compute the loop bound.
PatternMatchLoopTripCount(HloInstruction * while_op,int64 indvar_tuple_idx,const Literal & indvar_init)235 static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
236                                                  int64 indvar_tuple_idx,
237                                                  const Literal& indvar_init) {
238   // First, find the scalar constant K that `i` is initialized to.
239   optional<int64> indvar_init_val = LiteralAsScalarInt64(indvar_init);
240   if (!indvar_init_val) {
241     VLOG(2) << "Pattern-match failed: induction variable init is not a "
242                "constant scalar representable as an int64: "
243             << indvar_init.ToString();
244     return nullopt;
245   }
246 
247   // Check that `i` goes as `i++` in the while body.
248   //
249   // TODO(jlebar): We could also handle i-- and other idioms.
250   auto* while_body = while_op->while_body();
251   auto* while_body_indvar_update =
252       while_body->root_instruction()->operand(indvar_tuple_idx);
253   auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
254   if (!Match(while_body_indvar_update,
255              m::AddAnyOrder(m::Op().Is(while_body_indvar),
256                             m::ConstantEffectiveScalar(1)))) {
257     VLOG(2) << "Pattern-match failed: induction variable does not go as i++: "
258             << while_body_indvar_update->ToString();
259     return nullopt;
260   }
261 
262   // Check that we do op(i, N) or op(N, i) as the while condition.  Capture the
263   // value N.
264   auto* while_cond = while_op->while_condition();
265   auto* while_cond_root = while_cond->root_instruction();
266   auto* while_cond_indvar = NonConstantOperand(while_cond_root);
267   HloInstruction* while_cond_bound = nullptr;
268   if (!Match(while_cond_root,
269              m::Op().WithBinaryOperandsAnyOrder(
270                  m::Op().Is(while_cond_indvar),
271                  m::ConstantEffectiveScalar(&while_cond_bound)))) {
272     VLOG(2) << "Pattern-match failed: while condition is not of the form "
273                "op(i, N) or op(N, i).";
274     return nullopt;
275   }
276   // Note: If this succeeds, the constant `N` is representable as an int64 --
277   // that is, if it's an XLA U64, it fits within an int64.
278   optional<int64> while_cond_bound_val =
279       LiteralAsScalarInt64(while_cond_bound->literal());
280   if (!while_cond_bound_val) {
281     VLOG(2) << "Pattern-match failed: while condition induction variable is "
282                "not a constant scalar representable as an int64.";
283     return nullopt;
284   }
285 
286   // Handle `i = K; i < N; ++i`.
287   if (Match(while_cond_root,
288             m::Op()
289                 .WithComparisonDirection(ComparisonDirection::kLt)
290                 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
291     VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
292             << while_cond_root->ToString();
293     optional<int64> trips =
294         CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
295     if (trips) {
296       return std::max(int64{0}, *trips);
297     } else {
298       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX.";
299       return nullopt;
300     }
301   }
302 
303   // Handle `i = K; i <= N; ++i`.
304   if (Match(while_cond_root,
305             m::Op()
306                 .WithComparisonDirection(ComparisonDirection::kLe)
307                 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
308     VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
309             << while_cond_root->ToString();
310     optional<int64> trips =
311         CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
312     if (!trips) {
313       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
314       return nullopt;
315     }
316     trips = CheckedAdd(*trips, 1);
317     if (!trips) {
318       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
319       return nullopt;
320     }
321     return std::max<int64>(0, *trips);
322   }
323 
324   VLOG(2) << "Pattern-match failed: while condition follows unknown pattern: "
325           << while_cond_root->ToString();
326   return nullopt;
327 }
328 
ComputeWhileLoopTripCount(HloInstruction * while_op,int64 max_brute_force_iters)329 optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
330                                           int64 max_brute_force_iters) {
331   VLOG(2) << "Getting trip count for loop " << while_op->ToString();
332 
333   // The loop's induction variable is found at
334   //
335   //   get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
336   //
337   // where comp is while_op->while_body() or while_op->while_condition().
338   optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
339   if (!indvar_tuple_idx) {
340     return nullopt;
341   }
342 
343   // Now that we know the index of the induction variable, we can we can try to
344   // compute how many times the loop executes.  Start by computing the induction
345   // variable's initial value.
346   HloEvaluator evaluator(/*max_loop_iterations=*/0);
347   auto* while_init = while_op->mutable_operand(0);
348   auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
349   StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
350   if (!indvar_init_result.ok()) {
351     VLOG(2) << "Couldn't evaluate induction variable init, "
352             << indvar_init_result.status() << ", " << indvar_init->ToString();
353     return nullopt;
354   }
355   Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
356 
357   // First, try to pattern-match.
358   if (auto trip_count = PatternMatchLoopTripCount(while_op, *indvar_tuple_idx,
359                                                   indvar_iter_val)) {
360     return trip_count;
361   }
362 
363   // If our pattern-match failed, try brute-forcing the loop trip count.
364   auto* while_body = while_op->while_body();
365   auto* while_body_indvar_update =
366       while_body->root_instruction()->operand(*indvar_tuple_idx);
367   auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
368 
369   auto* while_cond = while_op->while_condition();
370   auto* while_cond_root = while_cond->root_instruction();
371   auto* while_cond_indvar = NonConstantOperand(while_cond_root);
372 
373   for (int64 trip_count = 0; trip_count != max_brute_force_iters + 1;
374        ++trip_count) {
375     StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
376         while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
377     if (!result.ok()) {
378       VLOG(2) << "Couldn't evaluate while cond: " << result.status();
379       return nullopt;
380     }
381     if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
382       VLOG(2) << "Loop has static trip count of " << trip_count;
383       return trip_count;
384     }
385 
386     // Calculate the value of the induction variable after one iteration of the
387     // loop, and check whether the while condition is true with this new value.
388     StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
389         while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
390     if (!indvar_next_result.ok()) {
391       VLOG(2) << "Couldn't evaluate induction variable update: "
392               << indvar_next_result.status();
393       return nullopt;
394     }
395     indvar_iter_val = std::move(indvar_next_result).ValueOrDie();
396   }
397 
398   VLOG(2) << "Loop has unknown trip count.";
399   return nullopt;
400 }
401 
402 // If the only user of this instruction is a get-tuple-element, return that
403 // get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may
404 // get a false negative if there are several copies of the same GTE, or there
405 // are unused GTEs, but we can live with this.
GetOnlyGTE(HloInstruction * inst)406 static HloInstruction* GetOnlyGTE(HloInstruction* inst) {
407   if (inst->user_count() != 1) {
408     return nullptr;
409   }
410 
411   HloInstruction* user = inst->users().back();
412   if (user->opcode() != HloOpcode::kGetTupleElement) {
413     return nullptr;
414   }
415   return user;
416 }
417 
ComputeWhileLoopTripCountUpperBound(HloInstruction * while_op)418 optional<int64> ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) {
419   // If we know the exact trip count, it's also the upper bound.
420   auto exact_trip_count = ComputeWhileLoopTripCount(while_op);
421   if (exact_trip_count) {
422     VLOG(2) << "Loop has exact trip count.";
423     return exact_trip_count;
424   }
425 
426   // There is one more case we know how to handle. If the loop condition only
427   // looks at one element of the tuple, and the loop body sets this element to a
428   // constant, there are two options:
429   // 1) Evaluating the condition on this constant returns true. In this case,
430   // the loop either executes 0 times, or is an infinite loop, depending on the
431   // init value.
432   // 2) Evaluating the condition on this constant returns false. In this case,
433   // the loop executes 0 or 1 times, depending on the init value. This means
434   // that, regardless of the init value, the upper bound on the trip count is 1.
435 
436   // Check whether the condition depends on a single parameter, and find out
437   // which.
438   auto* while_cond = while_op->while_condition();
439   auto* while_cond_param = while_cond->parameter_instruction(0);
440   auto* cond_gte = GetOnlyGTE(while_cond_param);
441   if (!cond_gte) {
442     VLOG(2) << "Induction variable not found in loop condition: "
443             << while_cond->root_instruction()->ToString();
444     return nullopt;
445   }
446 
447   // Now check whether this gets set to a constant by the while body.
448   auto* while_body = while_op->while_body();
449   auto* while_body_root = while_body->root_instruction();
450   if (while_body_root->opcode() != HloOpcode::kTuple) {
451     VLOG(3) << "While body's root is not a tuple instruction: "
452             << while_body_root->ToString();
453     return nullopt;
454   }
455 
456   int64 indvar_index = cond_gte->tuple_index();
457   auto* while_body_indvar = while_body_root->operand(indvar_index);
458   if (while_body_indvar->opcode() != HloOpcode::kConstant) {
459     VLOG(3) << "While body does not set the IV to a constant: "
460             << while_body_indvar->ToString();
461     return nullopt;
462   }
463 
464   // We have a constant. Evaluate the condition on this constant.
465   HloEvaluator evaluator(/*max_loop_iterations=*/0);
466   Literal fake_input = Literal::CreateFromShape(while_cond_param->shape());
467   TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(),
468                                   /*dest_shape_index=*/{indvar_index},
469                                   /*src_shape_index=*/{}));
470   StatusOr<Literal> eval_result =
471       evaluator.Evaluate(*while_cond, {std::move(fake_input)});
472 
473   if (!eval_result.ok()) {
474     VLOG(2) << "Couldn't evaluate while loop condition.";
475     return nullopt;
476   }
477 
478   Literal cond_result_pred = std::move(eval_result.ValueOrDie());
479   CHECK(ShapeUtil::Equal(cond_result_pred.shape(),
480                          ShapeUtil::MakeShape(PRED, {})));
481 
482   // Per the explanation above, if the evaluated condition returns false, the
483   // loop executes at most once.
484   bool cond_returns_true = cond_result_pred.GetFirstElement<bool>();
485   if (!cond_returns_true) {
486     VLOG(2) << "Upper bound on the trip count is 1";
487     return 1;
488   }
489 
490   VLOG(2) << "Loop has no known upper bound on the trip count.";
491   return nullopt;
492 }
493 
494 }  // namespace xla
495