1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
17 
18 #include "absl/base/casts.h"
19 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
24 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
25 
26 namespace xla {
27 
28 using absl::nullopt;
29 using absl::optional;
30 namespace m = match;
31 
32 // Finds and returns the non-constant operand in instr.
33 //
34 // CHECK-fails if instr doesn't have exactly one unique non-constant operand.
NonConstantOperand(const HloInstruction * instr)35 static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
36   const HloInstruction* result = nullptr;
37   for (const HloInstruction* operand : instr->operands()) {
38     if (!operand->IsConstant()) {
39       if (result != nullptr) {
40         CHECK_EQ(result, operand);
41       }
42       result = operand;
43     }
44   }
45   CHECK_NE(result, nullptr);
46   return result;
47 }
48 
49 // If all of instr's operands are either constants or have the form
50 //   get-tuple-element(gte_operand, N)
51 // for the same value N, returns N.  Otherwise, returns nullopt.
GetGTEOperandIndex(const HloInstruction * instr,const HloInstruction * gte_operand)52 static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
53                                           const HloInstruction* gte_operand) {
54   VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
55           << gte_operand->ToString() << ")";
56 
57   // Among the operands of `instr`, find one that is a get-tuple-element op.
58   auto gte_it = c_find_if(instr->operands(), [](const HloInstruction* instr) {
59     return instr->opcode() == HloOpcode::kGetTupleElement;
60   });
61   if (gte_it == instr->operands().end()) {
62     VLOG(2) << "instr does not have a gte operand.";
63     return nullopt;
64   }
65 
66   // All operands of `instr` must be either constants or of the form
67   //   get-tuple-element(gte_operand, tuple_idx)
68   // for the same value tuple_idx.
69   int64 tuple_idx = (*gte_it)->tuple_index();
70   for (const HloInstruction* operand : instr->operands()) {
71     if (!Match(operand, m::Constant()) &&
72         !Match(operand,
73                m::GetTupleElement(m::Op().Is(gte_operand), tuple_idx))) {
74       VLOG(2)
75           << "instr uses something other than a constant or gte(gte_operand, "
76           << tuple_idx << "): " << operand->ToString();
77       return nullopt;
78     }
79   }
80   return tuple_idx;
81 }
82 
83 // The below function identifies a subset of all possible auxiliary
84 // induction variables (AIV). Specifically, candidates are gtes, e.g.,
85 // gte(param0, N)
86 // The function checks if the loop body plumbs the AIV
87 // through the same tuple index at root, and that ops involving AIV
88 // involve constants.
89 //   op2 = op(constants, gte(param0, N), constants)
90 //   op3 = op(constants, f(op2, gte(param0, N), constants)
91 //   op4 = op(constants, f(op3, constants)
92 //   root = tuple(..., op4, ...)
93 // Further, the ops are restricted to basic math ops (+,-,*,/).
94 // Finally, loop invariant GTEs are excluded from AIVs.
95 // We can expand the ops category/nature of AIVs as needed.
GetAuxiliaryLoopInductionVars(const HloInstruction * while_op)96 std::vector<const HloInstruction*> GetAuxiliaryLoopInductionVars(
97     const HloInstruction* while_op) {
98   std::vector<const HloInstruction*> aux_ind_gte;
99   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
100   auto* while_body = while_op->while_body();
101   auto* while_body_param = while_body->parameter_instruction(0);
102   VLOG(2) << "Aux Induction Variables for loop:" << while_op->ToShortString();
103   VLOG(2) << "the parameter instr:" << while_body_param->ToShortString();
104   VLOG(2) << "the parameter user count:" << while_body_param->users().size();
105   if (while_body_param == nullptr) return aux_ind_gte;
106 
107   // candidates_pairs = pair<inst, inst>(
108   //   operands of the root while body,
109   //   GTE only operands that index into the same position in the parameter)
110   // for each candidate_pair (x, y)
111   //  find all paths between x and y,
112   //  each paths should satisfy the above listed criterion
113   //  index that x and y used is added as a aux variable index
114   std::map<int64, const HloInstruction*> extractions;
115   for (const HloInstruction* indx_instr : while_body_param->users()) {
116     if (indx_instr->opcode() != HloOpcode::kGetTupleElement) {
117       continue;
118     }
119     auto it = extractions.find(indx_instr->tuple_index());
120     // if we find two extractions at the same index, we ignore such
121     // a candidate
122     if (it != extractions.end()) {
123       it->second = nullptr;
124       VLOG(2) << "two extractions at same index:" << indx_instr->ToString();
125     } else {
126       extractions.insert(std::make_pair(indx_instr->tuple_index(), indx_instr));
127       VLOG(2) << "inserting extraction :" << indx_instr->ToString();
128     }
129   }
130   VLOG(2) << "total extractions size:" << extractions.size() << std::endl;
131   if (extractions.empty()) {
132     return aux_ind_gte;
133   }
134 
135   auto* while_body_root = while_body->root_instruction();
136   if (while_body_root->opcode() != HloOpcode::kTuple) {
137     VLOG(2) << "While body root is not a tuple:" << while_body_root->ToString();
138     return aux_ind_gte;
139   }
140   int64 index = -1;
141   std::map<int64, const HloInstruction*> insertions;
142   for (const HloInstruction* operand : while_body_root->operands()) {
143     index++;
144     if (!operand->IsConstant()) {
145       auto it = insertions.find(index);
146       if (it != insertions.end()) {
147         it->second = nullptr;
148         VLOG(2) << "two insertions at same index:" << operand->ToString();
149       } else {
150         insertions.insert(std::make_pair(index, operand));
151         VLOG(2) << "inserting insertions:" << operand->ToString();
152       }
153     }
154   }
155   if (insertions.empty()) {
156     return aux_ind_gte;
157   }
158 
159   std::map<int64, std::pair<const HloInstruction*, const HloInstruction*>>
160       candidate_pairs;
161   for (; index >= 0; --index) {
162     const HloInstruction *ext, *inst;
163     ext = (extractions.find(index) != extractions.end())
164               ? extractions.find(index)->second
165               : nullptr;
166     inst = (insertions.find(index) != insertions.end())
167                ? insertions.find(index)->second
168                : nullptr;
169     if (ext != nullptr && inst != nullptr) {
170       // Filter out trivial aux, i.e., extract directly to an insert.
171       if (ext != inst) {
172         candidate_pairs.insert(
173             std::make_pair(index, std::make_pair(ext, inst)));
174       }
175     }
176   }
177   VLOG(2) << "total candidate pairs:" << candidate_pairs.size() << std::endl;
178 
179   // Passed to ReachabilityMap to decide the type of produce-consumer edges
180   // along the reachability path.
181   const auto add_dependencies = [](const HloInstruction* hlo,
182                                    std::vector<HloInstruction*>* inputs) {
183     HloInstruction* non_const_operand = nullptr;
184     int num_non_constants = 0;
185     for (HloInstruction* operand : hlo->operands()) {
186       if (!operand->IsConstant()) {
187         num_non_constants++;
188         non_const_operand = operand;
189       }
190     }
191     if (num_non_constants == 1 &&
192         (hlo->opcode() == HloOpcode::kGetTupleElement ||
193          hlo->opcode() == HloOpcode::kAdd ||
194          hlo->opcode() == HloOpcode::kMultiply ||
195          hlo->opcode() == HloOpcode::kDivide ||
196          hlo->opcode() == HloOpcode::kSubtract)) {
197       inputs->push_back(non_const_operand);
198     }
199   };
200 
201   std::unique_ptr<HloReachabilityMap> hrm =
202       HloReachabilityMap::BuildWithRestrictions(
203           while_body,
204           absl::FunctionRef<void(const HloInstruction* hlo,
205                                  std::vector<HloInstruction*>* inputs)>(
206               add_dependencies));
207 
208   for (auto candidates : candidate_pairs) {
209     VLOG(2) << "are reachable?:" << (candidates.second.first)->ToString()
210             << "*************" << (candidates.second.second)->ToString()
211             << std::endl;
212     if (hrm->IsReachable(candidates.second.first, candidates.second.second)) {
213       aux_ind_gte.push_back(candidates.second.first);
214       VLOG(2) << "YES";
215     } else {
216       VLOG(2) << "NO";
217     }
218   }
219   VLOG(2) << "num auxiliary candidates :" << aux_ind_gte.size();
220   return aux_ind_gte;
221 }
222 
223 // Tries to get the tuple index of the induction variable of a while loop.
224 //
225 // Checks that the loop condition and body both plumb the induction variable
226 // through the same tuple index, and that they both apply exactly one op to the
227 // induction variable before  deciding whether to do another loop iteration (in
228 // the loop condition's case) or packing the induction variable into the result
229 // tuple (in the loop body's case).
230 //
231 // Specifically, checks that the loop condition has structure
232 //
233 //   root = op(constants, get-tuple-elem(param0, N), constants)
234 //
235 // and the loop body has the structure
236 //
237 //   inc = op(constants, get-tuple-elem(param0, N), constants)
238 //   root = tuple(..., inc, ...)  // inc is N'th operand of tuple().
239 //
240 // If so, returns N.  Otherwise, returns nullopt.
GetLoopInductionVarTupleIdx(const HloInstruction * while_op)241 optional<int64> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
242   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
243   VLOG(2) << "Finding induction variable for loop "
244           << while_op->ToShortString();
245 
246   // The while_cond computation should have the form
247   //
248   //   while_cond_root =
249   //       op(constants, get-tuple-elem(while_cond_param, N), constants).
250   //
251   // If it does, set indvar_tuple_idx to N.
252   auto* while_cond = while_op->while_condition();
253   auto* while_cond_root = while_cond->root_instruction();
254   auto* while_cond_param = while_cond->parameter_instruction(0);
255   optional<int64> indvar_tuple_idx =
256       GetGTEOperandIndex(while_cond_root, while_cond_param);
257   if (!indvar_tuple_idx) {
258     VLOG(2) << "Induction variable not found in loop condition: "
259             << while_cond->root_instruction()->ToString();
260     return nullopt;
261   }
262 
263   // The while_body computation should have the form
264   //
265   //   while_body_inc =
266   //       op(constants, get-tuple-elem(while_body_param, N), constants)
267   //   while_body_root = tuple(..., while_body_inc, ...)
268   //
269   // where while_body_inc is operand N of while_body_root.
270   auto* while_body = while_op->while_body();
271   auto* while_body_root = while_body->root_instruction();
272   if (while_body_root->opcode() != HloOpcode::kTuple) {
273     VLOG(2) << "While body's root is not a tuple instruction: "
274             << while_body_root->ToString();
275     return nullopt;
276   }
277 
278   auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
279   auto* while_body_param = while_body->parameter_instruction(0);
280   optional<int64> while_body_indvar_tuple_idx =
281       GetGTEOperandIndex(while_body_inc, while_body_param);
282   if (!while_body_indvar_tuple_idx) {
283     VLOG(2)
284         << "Induction variable not found in while body increment instruction: "
285         << while_body_inc->ToString();
286     return nullopt;
287   }
288   if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
289     VLOG(2) << "Tuple index of induction variable does not match between loop "
290                "condition ("
291             << *indvar_tuple_idx << ") and while body ("
292             << *while_body_indvar_tuple_idx << ")";
293     return nullopt;
294   }
295 
296   // Finally, check that the while loop's initial value is a tuple with enough
297   // elements.
298   auto* while_init = while_op->operand(0);
299   if (while_init->opcode() != HloOpcode::kTuple) {
300     VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
301     return nullopt;
302   }
303 
304   VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
305   return indvar_tuple_idx;
306 }
307 
308 // Converts the given literal to a scalar int64, if possible.
309 //
310 // Fails if the literal is not an integral type or if the value it contains
311 // cannot be represented in an int64.
LiteralAsScalarInt64(const Literal & l)312 static optional<int64> LiteralAsScalarInt64(const Literal& l) {
313   if (!ShapeUtil::IsEffectiveScalar(l.shape())) {
314     VLOG(2) << "literal is not an effective scalar: " << l.ToString();
315     return nullopt;
316   }
317   switch (l.shape().element_type()) {
318     case S8:
319       return l.GetFirstElement<int8>();
320     case S16:
321       return l.GetFirstElement<int16>();
322     case S32:
323       return l.GetFirstElement<int32>();
324     case S64:
325       return l.GetFirstElement<int64>();
326     case U8:
327       return l.GetFirstElement<uint8>();
328     case U16:
329       return l.GetFirstElement<uint16>();
330     case U32:
331       return l.GetFirstElement<uint32>();
332     case U64: {
333       uint64 v = l.GetFirstElement<uint64>();
334       if (v > static_cast<uint64>(std::numeric_limits<int64>::max())) {
335         VLOG(2) << "uint64 literal is out of range for int64: " << v;
336         return nullopt;
337       }
338       return v;
339     }
340     default:
341       VLOG(2) << "literal is of non-integral type " << l.shape().ToString();
342       return nullopt;
343   }
344 }
345 
346 // Computes a + b, returning nullopt if it overflows.
CheckedAdd(int64 a,int64 b)347 optional<int64> CheckedAdd(int64 a, int64 b) {
348   // Overflow occurred iff `a` and `b` have the same sign and `a + b` has a
349   // different sign, see Hacker's Delignt 2nd Ed. pp 28.
350   uint64 aa = absl::bit_cast<uint64>(a);
351   uint64 bb = absl::bit_cast<uint64>(b);
352   int64 result = absl::bit_cast<int64>(aa + bb);
353   if (a >= 0 == b >= 0 && result >= 0 != a >= 0) {
354     return nullopt;
355   }
356   return result;
357 }
358 
359 // Computes a - b, returning nullopt if it overflows.
CheckedSubtract(int64 a,int64 b)360 optional<int64> CheckedSubtract(int64 a, int64 b) {
361   uint64 aa = absl::bit_cast<uint64>(a);
362   uint64 bb = absl::bit_cast<uint64>(b);
363   int64 result = absl::bit_cast<int64>(aa - bb);
364   // Overflow occurred iff `a` and `b` have different signs and the sign of
365   // `a - b` is the same as that of `b`, see Hacker's Delight 2nd Ed. pp 29.
366   if (a >= 0 != b >= 0 && result >= 0 == b >= 0) {
367     return nullopt;
368   }
369   return result;
370 }
371 
372 // Check if
373 //  - `i` is initialized to a scalar constant K (namely, `indvar_init`),
374 //  - the while condition does `i < N` or `i <= N`, and
375 //  - the while body does `i++`.
376 // If so, it's trivial to compute the loop bound.
PatternMatchLoopTripCount(HloInstruction * while_op,int64 indvar_tuple_idx,const Literal & indvar_init)377 static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
378                                                  int64 indvar_tuple_idx,
379                                                  const Literal& indvar_init) {
380   // First, find the scalar constant K that `i` is initialized to.
381   optional<int64> indvar_init_val = LiteralAsScalarInt64(indvar_init);
382   if (!indvar_init_val) {
383     VLOG(2) << "Pattern-match failed: induction variable init is not a "
384                "constant scalar representable as an int64: "
385             << indvar_init.ToString();
386     return nullopt;
387   }
388 
389   // Check that `i` goes as `i++` in the while body.
390   //
391   // TODO(jlebar): We could also handle i-- and other idioms.
392   auto* while_body = while_op->while_body();
393   auto* while_body_indvar_update =
394       while_body->root_instruction()->operand(indvar_tuple_idx);
395   auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
396   if (!Match(while_body_indvar_update,
397              m::AddAnyOrder(m::Op().Is(while_body_indvar),
398                             m::ConstantEffectiveScalar(1)))) {
399     VLOG(2) << "Pattern-match failed: induction variable does not go as i++: "
400             << while_body_indvar_update->ToString();
401     return nullopt;
402   }
403 
404   // Check that we do op(i, N) or op(N, i) as the while condition.  Capture the
405   // value N.
406   auto* while_cond = while_op->while_condition();
407   auto* while_cond_root = while_cond->root_instruction();
408   auto* while_cond_indvar = NonConstantOperand(while_cond_root);
409   HloInstruction* while_cond_bound = nullptr;
410   if (!Match(while_cond_root,
411              m::Op().WithBinaryOperandsAnyOrder(
412                  m::Op().Is(while_cond_indvar),
413                  m::ConstantEffectiveScalar(&while_cond_bound)))) {
414     VLOG(2) << "Pattern-match failed: while condition is not of the form "
415                "op(i, N) or op(N, i).";
416     return nullopt;
417   }
418   // Note: If this succeeds, the constant `N` is representable as an int64 --
419   // that is, if it's an XLA U64, it fits within an int64.
420   optional<int64> while_cond_bound_val =
421       LiteralAsScalarInt64(while_cond_bound->literal());
422   if (!while_cond_bound_val) {
423     VLOG(2) << "Pattern-match failed: while condition induction variable is "
424                "not a constant scalar representable as an int64.";
425     return nullopt;
426   }
427 
428   // Handle `i = K; i < N; ++i`.
429   if (Match(while_cond_root,
430             m::Op()
431                 .WithComparisonDirection(ComparisonDirection::kLt)
432                 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
433     VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
434             << while_cond_root->ToString();
435     optional<int64> trips =
436         CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
437     if (trips) {
438       return std::max(int64{0}, *trips);
439     } else {
440       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX.";
441       return nullopt;
442     }
443   }
444 
445   // Handle `i = K; i <= N; ++i`.
446   if (Match(while_cond_root,
447             m::Op()
448                 .WithComparisonDirection(ComparisonDirection::kLe)
449                 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
450     VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
451             << while_cond_root->ToString();
452     optional<int64> trips =
453         CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
454     if (!trips) {
455       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
456       return nullopt;
457     }
458     trips = CheckedAdd(*trips, 1);
459     if (!trips) {
460       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
461       return nullopt;
462     }
463     return std::max<int64>(0, *trips);
464   }
465 
466   VLOG(2) << "Pattern-match failed: while condition follows unknown pattern: "
467           << while_cond_root->ToString();
468   return nullopt;
469 }
470 
ComputeWhileLoopTripCount(HloInstruction * while_op,int64 max_brute_force_iters)471 optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
472                                           int64 max_brute_force_iters) {
473   VLOG(2) << "Getting trip count for loop " << while_op->ToString();
474 
475   // The loop's induction variable is found at
476   //
477   //   get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
478   //
479   // where comp is while_op->while_body() or while_op->while_condition().
480   optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
481   if (!indvar_tuple_idx) {
482     return nullopt;
483   }
484 
485   // Now that we know the index of the induction variable, we can we can try to
486   // compute how many times the loop executes.  Start by computing the induction
487   // variable's initial value.
488   HloEvaluator evaluator(/*max_loop_iterations=*/0);
489   auto* while_init = while_op->mutable_operand(0);
490   auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
491   StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
492   if (!indvar_init_result.ok()) {
493     VLOG(2) << "Couldn't evaluate induction variable init, "
494             << indvar_init_result.status() << ", " << indvar_init->ToString();
495     return nullopt;
496   }
497   Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
498 
499   // First, try to pattern-match.
500   if (auto trip_count = PatternMatchLoopTripCount(while_op, *indvar_tuple_idx,
501                                                   indvar_iter_val)) {
502     return trip_count;
503   }
504 
505   // If our pattern-match failed, try brute-forcing the loop trip count.
506   auto* while_body = while_op->while_body();
507   auto* while_body_indvar_update =
508       while_body->root_instruction()->operand(*indvar_tuple_idx);
509   auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
510 
511   auto* while_cond = while_op->while_condition();
512   auto* while_cond_root = while_cond->root_instruction();
513   auto* while_cond_indvar = NonConstantOperand(while_cond_root);
514 
515   for (int64 trip_count = 0; trip_count != max_brute_force_iters + 1;
516        ++trip_count) {
517     StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
518         while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
519     if (!result.ok()) {
520       VLOG(2) << "Couldn't evaluate while cond: " << result.status();
521       return nullopt;
522     }
523     if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
524       VLOG(2) << "Loop has static trip count of " << trip_count;
525       return trip_count;
526     }
527 
528     // Calculate the value of the induction variable after one iteration of the
529     // loop, and check whether the while condition is true with this new value.
530     StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
531         while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
532     if (!indvar_next_result.ok()) {
533       VLOG(2) << "Couldn't evaluate induction variable update: "
534               << indvar_next_result.status();
535       return nullopt;
536     }
537     indvar_iter_val = std::move(indvar_next_result).ValueOrDie();
538   }
539 
540   VLOG(2) << "Loop has unknown trip count.";
541   return nullopt;
542 }
543 
544 // If the only user of this instruction is a get-tuple-element, return that
545 // get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may
546 // get a false negative if there are several copies of the same GTE, or there
547 // are unused GTEs, but we can live with this.
GetOnlyGTE(HloInstruction * inst)548 static HloInstruction* GetOnlyGTE(HloInstruction* inst) {
549   if (inst->user_count() != 1) {
550     return nullptr;
551   }
552 
553   HloInstruction* user = inst->users().back();
554   if (user->opcode() != HloOpcode::kGetTupleElement) {
555     return nullptr;
556   }
557   return user;
558 }
559 
ComputeWhileLoopTripCountUpperBound(HloInstruction * while_op)560 optional<int64> ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) {
561   // If we know the exact trip count, it's also the upper bound.
562   auto exact_trip_count = ComputeWhileLoopTripCount(while_op);
563   if (exact_trip_count) {
564     VLOG(2) << "Loop has exact trip count.";
565     return exact_trip_count;
566   }
567 
568   // There is one more case we know how to handle. If the loop condition only
569   // looks at one element of the tuple, and the loop body sets this element to a
570   // constant, there are two options:
571   // 1) Evaluating the condition on this constant returns true. In this case,
572   // the loop either executes 0 times, or is an infinite loop, depending on the
573   // init value.
574   // 2) Evaluating the condition on this constant returns false. In this case,
575   // the loop executes 0 or 1 times, depending on the init value. This means
576   // that, regardless of the init value, the upper bound on the trip count is 1.
577 
578   // Check whether the condition depends on a single parameter, and find out
579   // which.
580   auto* while_cond = while_op->while_condition();
581   auto* while_cond_param = while_cond->parameter_instruction(0);
582   auto* cond_gte = GetOnlyGTE(while_cond_param);
583   if (!cond_gte) {
584     VLOG(2) << "Induction variable not found in loop condition: "
585             << while_cond->root_instruction()->ToString();
586     return nullopt;
587   }
588 
589   // Now check whether this gets set to a constant by the while body.
590   auto* while_body = while_op->while_body();
591   auto* while_body_root = while_body->root_instruction();
592   if (while_body_root->opcode() != HloOpcode::kTuple) {
593     VLOG(3) << "While body's root is not a tuple instruction: "
594             << while_body_root->ToString();
595     return nullopt;
596   }
597 
598   int64 indvar_index = cond_gte->tuple_index();
599   auto* while_body_indvar = while_body_root->operand(indvar_index);
600   if (while_body_indvar->opcode() != HloOpcode::kConstant) {
601     VLOG(3) << "While body does not set the IV to a constant: "
602             << while_body_indvar->ToString();
603     return nullopt;
604   }
605 
606   // We have a constant. Evaluate the condition on this constant.
607   HloEvaluator evaluator(/*max_loop_iterations=*/0);
608   Literal fake_input = Literal::CreateFromShape(while_cond_param->shape());
609   TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(),
610                                   /*dest_shape_index=*/{indvar_index},
611                                   /*src_shape_index=*/{}));
612   StatusOr<Literal> eval_result =
613       evaluator.Evaluate(*while_cond, {std::move(fake_input)});
614 
615   if (!eval_result.ok()) {
616     VLOG(2) << "Couldn't evaluate while loop condition.";
617     return nullopt;
618   }
619 
620   Literal cond_result_pred = std::move(eval_result.ValueOrDie());
621   CHECK(Shape::Equal().IgnoreLayout()(cond_result_pred.shape(),
622                                       ShapeUtil::MakeShape(PRED, {})));
623 
624   // Per the explanation above, if the evaluated condition returns false, the
625   // loop executes at most once.
626   bool cond_returns_true = cond_result_pred.GetFirstElement<bool>();
627   if (!cond_returns_true) {
628     VLOG(2) << "Upper bound on the trip count is 1";
629     return 1;
630   }
631 
632   VLOG(2) << "Loop has no known upper bound on the trip count.";
633   return nullopt;
634 }
635 
636 }  // namespace xla
637