1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/service/call_inliner.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
32 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
33 
34 namespace xla {
35 
36 namespace m = match;
37 using absl::optional;
38 using hlo_query::ContainsInstrWithOpcode;
39 
40 // This is a utility function that removes the given tuple indices from the
41 // while loop init, body, and condition. The final shape returned is still the
42 // same as before.
RemoveDeadTupleIndices(HloInstruction * while_op,absl::flat_hash_set<int64> & used_tuple_indices)43 static StatusOr<HloInstruction*> RemoveDeadTupleIndices(
44     HloInstruction* while_op, absl::flat_hash_set<int64>& used_tuple_indices) {
45   // Build up maps from the old/new to the new/old tuple indices.
46   std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
47                                           used_tuple_indices.end());
48   absl::c_sort(new_to_old_tuple_idx);
49 
50   HloModule* module = while_op->GetModule();
51   HloComputation* computation = while_op->parent();
52   HloInstruction* while_init = while_op->mutable_operand(0);
53   HloComputation* while_cond = while_op->while_condition();
54   HloComputation* while_body = while_op->while_body();
55   HloInstruction* while_body_root = while_body->root_instruction();
56 
57   auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
58 
59   absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
60   for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
61     int64 old_idx = new_to_old_tuple_idx[new_idx];
62     old_to_new_tuple_idx[old_idx] = new_idx;
63     VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx;
64   }
65 
66   // Compute the shape of the while op after we remove the dead indices.
67   std::vector<Shape> new_while_tuple_elem_shapes;
68   new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size());
69   for (int64 old_idx : new_to_old_tuple_idx) {
70     new_while_tuple_elem_shapes.push_back(
71         while_init->shape().tuple_shapes(old_idx));
72   }
73   Shape new_while_shape =
74       ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes);
75 
76   // Returns a map from elements in the computation to new instructions which
77   // replace the old instructions after we remove unused elements from the while
78   // tuple.
79   auto make_while_computation_replacements = [&](const HloComputation* comp) {
80     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
81         replacements;
82 
83     auto* param = comp->parameter_instruction(0);
84     replacements.emplace(param, HloInstruction::CreateParameter(
85                                     0, new_while_shape, param->name()));
86 
87     // Materialize param's users, since we're about to add new ones below.
88     std::vector<HloInstruction*> materialized_users(param->users().begin(),
89                                                     param->users().end());
90     for (const auto* user : materialized_users) {
91       // The while body root is handled separately.
92       if (user == while_body_root) {
93         continue;
94       }
95       CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement)
96           << user->ToString(print_no_metadata);
97 
98       int64 old_idx = user->tuple_index();
99       auto new_idx_iter = old_to_new_tuple_idx.find(old_idx);
100       if (new_idx_iter != old_to_new_tuple_idx.end()) {
101         // This is a GTE of an index that survives.  Replace it.
102         replacements.emplace(
103             user, HloInstruction::CreateGetTupleElement(user->shape(), param,
104                                                         new_idx_iter->second));
105       } else {
106         // This is a GTE of an index that we've removed.  Remove it from the
107         // cloned computation.
108         CHECK(user->user_count() == 0 ||
109               user->user_count() == 1 &&
110                   user->users().front() == while_body_root)
111             << "Instruction " << user->ToString(print_no_metadata)
112             << " should be unused (except by root of while body), but has "
113                "users: {"
114             << absl::StrJoin(user->users(), ", ",
115                              [&](string* out, const HloInstruction* instr) {
116                                absl::StrAppend(
117                                    out, instr->ToString(print_no_metadata));
118                              })
119             << "}";
120 
121         replacements.emplace(user, nullptr);
122       }
123     }
124     return replacements;
125   };
126 
127   // Create the new while condition, body, and init value.
128   std::unique_ptr<HloComputation> new_while_cond =
129       while_cond->CloneWithReplacements(
130           make_while_computation_replacements(while_cond));
131 
132   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
133       while_body_replacements = make_while_computation_replacements(while_body);
134   std::vector<HloInstruction*> new_while_body_root_elems;
135   new_while_body_root_elems.reserve(new_to_old_tuple_idx.size());
136   for (int64 old_idx : new_to_old_tuple_idx) {
137     new_while_body_root_elems.push_back(
138         while_body_root->mutable_operand(old_idx));
139   }
140   while_body_replacements.emplace(
141       while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
142   std::unique_ptr<HloComputation> new_while_body =
143       while_body->CloneWithReplacements(std::move(while_body_replacements));
144 
145   // Add a new while_init instruction that repackages the old while_init
146   // instruction's elements.  We rely on the AlgebraicSimplifier and DCE to
147   // clean this up in the common case where while_init is a tuple op.  (It's
148   // definitely tuple-shaped, but it's not necessarily a tuple op.)
149   std::vector<HloInstruction*> new_while_init_elems;
150   new_while_init_elems.reserve(new_to_old_tuple_idx.size());
151   for (int64 old_idx : new_to_old_tuple_idx) {
152     new_while_init_elems.push_back(
153         computation->AddInstruction(HloInstruction::CreateGetTupleElement(
154             while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
155   }
156   auto* new_while_init = computation->AddInstruction(
157       HloInstruction::CreateTuple(new_while_init_elems));
158 
159   // Create the new while op.
160   auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
161       new_while_shape,
162       module->AddEmbeddedComputation(std::move(new_while_cond)),
163       module->AddEmbeddedComputation(std::move(new_while_body)),
164       new_while_init));
165 
166   // Create a tuple op that recreates the output of the old while op.  That is,
167   // we transform to
168   //
169   //  new_while_init   while_init
170   //       |              |
171   //       V              |
172   //   new_while          |
173   //       |              |
174   //       -------|   |----
175   //              V   V
176   //            new_tuple
177   //                |
178   //                V
179   //    (orig. users of while op)
180   //
181   // The tuple simplifier will then simplify this if possible, removing
182   // new_tuple and while_init.
183   std::vector<HloInstruction*> new_tuple_elems;
184   const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
185   for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) {
186     auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
187     if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
188       int64 gte_idx = new_tuple_idx_it->second;
189       new_tuple_elems.push_back(
190           computation->AddInstruction(HloInstruction::CreateGetTupleElement(
191               new_while_op->shape().tuple_shapes(gte_idx), new_while_op,
192               gte_idx)));
193     } else {
194       new_tuple_elems.push_back(
195           computation->AddInstruction(HloInstruction::CreateGetTupleElement(
196               while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
197     }
198   }
199   HloInstruction* new_tuple =
200       computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
201   TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple));
202 
203   return new_while_op;
204 }
205 
206 // Tries to remove elements in a while loop's tuple that aren't used within the
207 // loop.
208 //
209 // Specifically, if a loop is tuple-shaped, and there exists some element of
210 // that tuple that is not used by the loop condition and is not used by the loop
211 // body except to pass it to the next iteration of the loop, then we can remove
212 // that element from the loop's tuples.
TryRemoveDeadWhileParams(HloInstruction * while_op)213 static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
214   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
215 
216   // Don't try this transformation if the while loop isn't removable, since if
217   // it succeeds ultimately we're going to have to replace the old while loop
218   // with a new one.
219   if (!while_op->parent()->IsSafelyRemovable(while_op)) {
220     VLOG(2) << "Can't remove dead parameters from non-removable while op.";
221     return false;
222   }
223 
224   HloInstruction* while_init = while_op->mutable_operand(0);
225   HloComputation* while_cond = while_op->while_condition();
226   HloComputation* while_body = while_op->while_body();
227   HloInstruction* while_body_root = while_body->root_instruction();
228 
229   if (!while_init->shape().IsTuple()) {
230     VLOG(2) << "While op's carried value isn't tuple shaped.";
231     return false;
232   }
233 
234   if (while_body_root->opcode() != HloOpcode::kTuple) {
235     VLOG(2) << "While body's root is not a tuple(...) instruction.";
236     return false;
237   }
238 
239   auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
240 
241   // Bail if param0 of while_cond or while_body has users which aren't of type
242   // get-tuple-element.
243   for (const HloInstruction* instr : {while_body->parameter_instruction(0),
244                                       while_cond->parameter_instruction(0)}) {
245     for (const HloInstruction* user : instr->users()) {
246       if (user->opcode() != HloOpcode::kGetTupleElement) {
247         VLOG(2) << "Cowardly refusing to analyze while loop with "
248                 << instr->ToString(print_no_metadata)
249                 << " used by non-GTE instruction "
250                 << user->ToString(print_no_metadata) << " in computation "
251                 << instr->parent()->name();
252         return false;
253       }
254     }
255   }
256 
257   const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
258   if (tuple_size == 0) {
259     VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
260                "empty.";
261     return false;
262   }
263 
264   absl::flat_hash_set<int64> used_tuple_indices;
265   for (HloComputation* comp : {while_body, while_cond}) {
266     // The HLO verifier ensures that while_input's shape matches while_init's
267     // shape, which we verified above is a tuple.
268     HloInstruction* while_input = comp->parameter_instruction(0);
269 
270     for (const HloInstruction* user : while_input->users()) {
271       // This user doesn't count if it's only used by the while body's root, and
272       // the root places the tuple element into the same index of the tuple as
273       // it came from.  That just amounts to us carrying the variable through
274       // the loop.
275       //
276       // Careful: HloInstruction::operand_index returns the first index the
277       // operand appears in, but it may appear more than once!
278       if (user->user_count() == 1 && user->users().front() == while_body_root &&
279           while_body_root->operand_index(user) == user->tuple_index() &&
280           absl::c_count(while_body_root->operands(), user) == 1) {
281         continue;
282       }
283 
284       used_tuple_indices.insert(user->tuple_index());
285       if (used_tuple_indices.size() == tuple_size) {
286         VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
287                 << " uses all of its inputs; no simplification possible.";
288         return false;
289       }
290     }
291   }
292 
293   // If a tuple element is not passed unmodified from the while body's param0
294   // through to the while body's root, count that element as "used", since
295   // removing that element would be observable.
296   for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
297     if (used_tuple_indices.contains(i)) {
298       continue;
299     }
300 
301     auto* operand = while_body_root->operand(i);
302     if (operand->opcode() != HloOpcode::kGetTupleElement ||
303         operand->operand(0) != while_body->parameter_instruction(0) ||
304         operand->tuple_index() != i) {
305       VLOG(2) << "Tuple index " << i
306               << " is not passed through loop body unmodified.";
307       used_tuple_indices.insert(i);
308 
309       if (used_tuple_indices.size() == tuple_size) {
310         VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
311                 << " uses all of its inputs; no simplification possible.";
312         return false;
313       }
314     }
315   }
316 
317   // If we got here, used_tuple_indices.size() < tuple_size, meaning some
318   // elements of the loop's tuple aren't used by while_body or while_cond.
319   CHECK_LT(used_tuple_indices.size(), tuple_size);
320 
321   VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
322           << " elements from tuple of "
323           << while_op->ToString(print_no_metadata);
324 
325   TF_ASSIGN_OR_RETURN(while_op,
326                       RemoveDeadTupleIndices(while_op, used_tuple_indices));
327 
328   return true;
329 }
330 
331 // This is a helper function for TryRemoveRepeatedWhileTupleIndices. It removes
332 // duplicates by replacing them with tuple_index, followed by a call to
333 // RemoveDeadTupleIndices.
TryRemoveRepeatedWhileTupleIndicesHelper(HloInstruction * while_op,const int64 tuple_index,absl::flat_hash_set<int64> & duplicates)334 static StatusOr<HloInstruction*> TryRemoveRepeatedWhileTupleIndicesHelper(
335     HloInstruction* while_op, const int64 tuple_index,
336     absl::flat_hash_set<int64>& duplicates) {
337   HloComputation* while_cond = while_op->while_condition();
338   HloComputation* while_body = while_op->while_body();
339   HloInstruction* while_init = while_op->mutable_operand(0);
340 
341   VLOG(2) << "while_init " << while_init->ToString() << " operands "
342           << while_init->operand_count();
343   VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString()
344           << " operands " << while_body->root_instruction()->operand_count();
345 
346   // Change the loop body and condition such that uses of the duplicates are
347   // replaced with the original tuple element.
348   for (HloComputation* comp : {while_body, while_cond}) {
349     auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
350         comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index),
351         comp->parameter_instruction(0), tuple_index));
352 
353     std::vector<HloInstruction*> instrs_to_replace;
354     for (auto* instr : comp->instructions()) {
355       if (instr->opcode() == HloOpcode::kGetTupleElement &&
356           duplicates.contains(instr->tuple_index()) &&
357           instr->operand(0) == comp->parameter_instruction(0)) {
358         instrs_to_replace.push_back(instr);
359       }
360     }
361 
362     for (auto instr : instrs_to_replace) {
363       TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get));
364     }
365   }
366 
367   // We know which tuple indices are useful; i.e, those which aren't duplicates.
368   absl::flat_hash_set<int64> used_tuple_indices;
369   for (int index = 0; index < while_init->shape().tuple_shapes_size();
370        ++index) {
371     if (!duplicates.count(index)) {
372       used_tuple_indices.insert(index);
373     }
374   }
375 
376   // Remove the duplicate tuple elements.
377   TF_ASSIGN_OR_RETURN(while_op,
378                       RemoveDeadTupleIndices(while_op, used_tuple_indices));
379 
380   return while_op;
381 }
382 
383 // If the while loop init passes the same values to several tuple indices, and
384 // if the body keeps on passing them through, we can remove the duplicates.
TryRemoveRepeatedWhileTupleIndices(HloInstruction * while_op)385 static StatusOr<bool> TryRemoveRepeatedWhileTupleIndices(
386     HloInstruction* while_op) {
387   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
388 
389   int index_to_investigate = 0;
390   // Don't try this transformation if the while loop isn't removable, since if
391   // it succeeds ultimately we're going to have to replace the old while loop
392   // with a new one.
393   if (!while_op->parent()->IsSafelyRemovable(while_op)) {
394     VLOG(2) << "Can't remove dead parameters from non-removable while op.";
395     return false;
396   }
397 
398   HloInstruction* while_init = while_op->mutable_operand(0);
399   HloComputation* while_cond = while_op->while_condition();
400   HloComputation* while_body = while_op->while_body();
401   HloInstruction* while_body_root = while_body->root_instruction();
402 
403   if (!while_init->shape().IsTuple()) {
404     VLOG(2) << "While op's carried value isn't tuple shaped.";
405     return false;
406   }
407 
408   bool changed = false;
409   while (index_to_investigate < while_init->shape().tuple_shapes_size()) {
410     if (!while_init->shape().IsTuple() ||
411         while_init->opcode() != HloOpcode::kTuple) {
412       VLOG(2) << "While op's carried value isn't tuple shaped.";
413       return false;
414     }
415 
416     if (while_body_root->opcode() != HloOpcode::kTuple) {
417       VLOG(2) << "While body's root is not a tuple(...) instruction.";
418       return false;
419     }
420 
421     auto& while_shape = while_init->shape();
422     VLOG(2) << "Iterating " << index_to_investigate;
423 
424     absl::flat_hash_set<int64> duplicates;
425     auto* pivot_init_elem = while_init->operand(index_to_investigate);
426     auto* pivot_body_elem = while_body_root->operand(index_to_investigate);
427     if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement &&
428         pivot_body_elem->operand(0) == while_body->parameter_instruction(0)) {
429       if (pivot_body_elem->tuple_index() != index_to_investigate) {
430         VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() "
431                 << pivot_body_elem->tuple_index() << " index_to_investigate "
432                 << index_to_investigate;
433         index_to_investigate++;
434         continue;
435       }
436     } else {
437       index_to_investigate++;
438       continue;
439     }
440 
441     // Look from index_to_investigate onwards to see if it is repeated.
442     for (int64 i = index_to_investigate + 1;
443          i < while_shape.tuple_shapes_size(); ++i) {
444       auto* init_elem = while_init->operand(i);
445       auto* body_elem = while_body_root->operand(i);
446       if (body_elem->opcode() == HloOpcode::kGetTupleElement &&
447           body_elem->operand(0) == while_body->parameter_instruction(0)) {
448         if (body_elem->tuple_index() != i) {
449           VLOG(2) << "Mismatch between body_elem->tuple_index() "
450                   << body_elem->tuple_index() << " i " << i;
451           continue;
452         }
453       } else {
454         continue;
455       }
456 
457       if (pivot_init_elem == init_elem) {
458         VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem "
459                 << pivot_init_elem->ToString();
460         VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem "
461                 << pivot_body_elem->ToString();
462         duplicates.insert(i);
463       }
464     }
465 
466     // If duplicates are found, call the helper to remove them.
467     if (!duplicates.empty()) {
468       VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init "
469               << pivot_init_elem->ToString();
470       TF_ASSIGN_OR_RETURN(while_op,
471                           TryRemoveRepeatedWhileTupleIndicesHelper(
472                               while_op, index_to_investigate, duplicates));
473       changed = true;
474       VLOG(2) << "Changed while_op " << while_op->ToString()
475               << " while_op operand count " << while_op->operand_count();
476       // Update the while loop variables so we can continue looking for
477       // duplicates of a different index.
478       while_init = while_op->mutable_operand(0);
479       while_cond = while_op->while_condition();
480       while_body = while_op->while_body();
481       while_body_root = while_body->root_instruction();
482     }
483     index_to_investigate++;
484   }
485 
486   return changed;
487 }
488 
489 // Removes each loop parameter (i.e. member of the while loop tuple) that is a
490 // constant and is the same in the while loop body and the while loop init.
TryRemoveConstantParams(HloInstruction * while_op)491 static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) {
492   HloModule* module = while_op->GetModule();
493   HloComputation* computation = while_op->parent();
494   auto* while_init = while_op->mutable_operand(0);
495   auto* while_body = while_op->while_body();
496   auto* while_cond = while_op->while_condition();
497   auto* while_body_root = while_body->root_instruction();
498   if (while_init->opcode() != HloOpcode::kTuple ||
499       while_body_root->opcode() != HloOpcode::kTuple) {
500     return false;
501   }
502 
503   TF_RET_CHECK(while_cond->num_parameters() == 1);
504   TF_RET_CHECK(while_body->num_parameters() == 1);
505   TF_RET_CHECK(
506       ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
507 
508   absl::flat_hash_set<int64> constant_tuple_indices;
509   const auto& while_shape = while_init->shape();
510   for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
511     auto* init_elem = while_init->operand(i);
512     auto* body_elem = while_body_root->operand(i);
513     if (init_elem->opcode() == HloOpcode::kConstant &&
514         body_elem->opcode() == HloOpcode::kConstant &&
515         init_elem->literal() == body_elem->literal()) {
516       constant_tuple_indices.insert(i);
517     }
518   }
519 
520   if (constant_tuple_indices.empty()) {
521     return false;
522   }
523 
524   // OK, we found some constant elements of the while parameter!  Eliminate
525   // them.
526   std::vector<Shape> new_while_shape_elems;
527   for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
528     if (!constant_tuple_indices.count(i)) {
529       new_while_shape_elems.push_back(while_shape.tuple_shapes(i));
530     }
531   }
532   Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems);
533 
534   // `new_instrs` holds instructions created outside of a computation for
535   // cloning.  Elements added here just need to live until the end of the
536   // relevant CloneWithReplacement call.
537   std::vector<std::unique_ptr<HloInstruction>> new_instrs;
538   auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
539     new_instrs.push_back(std::move(instr));
540     return new_instrs.back().get();
541   };
542 
543   // Returns a new tuple without the elements of constant_tuple_indices.
544   auto remove_constant_elems = [&](HloInstruction* instr) {
545     CHECK(ShapeUtil::Compatible(instr->shape(), while_shape));
546 
547     std::vector<HloInstruction*> tuple_elems;
548     for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
549       if (!constant_tuple_indices.count(i)) {
550         tuple_elems.push_back(
551             add_new_instr(HloInstruction::CreateGetTupleElement(
552                 while_shape.tuple_shapes(i), instr, i)));
553       }
554     }
555     return HloInstruction::CreateTuple(tuple_elems);
556   };
557 
558   auto add_constant_elems = [&](HloInstruction* instr) {
559     CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
560 
561     std::vector<HloInstruction*> tuple_elems;
562     int64 j = 0;
563     for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
564       if (constant_tuple_indices.count(i)) {
565         tuple_elems.push_back(while_init->mutable_operand(i));
566       } else {
567         tuple_elems.push_back(
568             add_new_instr(HloInstruction::CreateGetTupleElement(
569                 while_shape.tuple_shapes(i), instr, j)));
570         ++j;
571       }
572     }
573     return HloInstruction::CreateTuple(tuple_elems);
574   };
575 
576   // Special case: constant_tuple_indices covers the whole while parameter, so
577   // the new while shape is the empty tuple.  In this case, the value of the
578   // while loop is simply equal to the value of `init`.
579   //
580   // It's unfortunate to special-case this, but it's simpler than the
581   // alternative.  The problem is that if our while parameter has no
582   // non-constant elems, the tuple returned by `add_constant_elems` won't depend
583   // on instr (the loop body/cond parameter), and therefore
584   // CloneWithReplacementPairs will *leave the parameter out entirely*, creating
585   // invalid HLO.
586   if (ShapeUtil::IsEmptyTuple(new_while_shape)) {
587     TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init));
588     return true;
589   }
590 
591   std::unique_ptr<HloComputation> new_while_cond =
592       while_cond->CloneWithReplacementPairs({
593           while_cond->parameter_instruction(0),
594           add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
595               0, new_while_shape,
596               while_cond->parameter_instruction(0)->name()))),
597       });
598 
599   std::unique_ptr<HloComputation> new_while_body =
600       while_body->CloneWithReplacementPairs(
601           {
602               while_body->parameter_instruction(0),
603               add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
604                   0, new_while_shape,
605                   while_cond->parameter_instruction(0)->name()))),
606           },
607           {
608               while_body->root_instruction(),
609               remove_constant_elems(
610                   add_new_instr(while_body->root_instruction()->Clone())),
611           });
612 
613   // Create the final while loop, and add any new instructions created to
614   // `computation`.
615   new_instrs.clear();
616   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
617       while_op,
618       add_constant_elems(
619           computation->AddInstruction(HloInstruction::CreateWhile(
620               new_while_shape,
621               module->AddEmbeddedComputation(std::move(new_while_cond)),
622               module->AddEmbeddedComputation(std::move(new_while_body)),
623               add_new_instr(remove_constant_elems(while_init)))))));
624   for (auto& instr : new_instrs) {
625     computation->AddInstruction(std::move(instr));
626   }
627   return true;
628 }
629 
630 // Tries to remove a while loop from the graph.
631 //
632 //  - Loops with trip count of 0 can be replaced by the loop's "init" value.
633 //  - Loops with trip count of 1 can be replaced by the loop's body, with the
634 //    loop itself removed.
635 //
636 // Returns true if it made a change to the graph.
TryRemoveWhileLoop(HloInstruction * while_op)637 static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
638   // Cowardly refuse to remove loops that are not removable.  In practice, this
639   // means that we can't remove loops that have control predecessors/successors.
640   if (!while_op->parent()->IsSafelyRemovable(while_op)) {
641     VLOG(2) << "Not attempting to remove while loop that is not removable: "
642             << while_op->ToShortString();
643     return false;
644   }
645 
646   // Refuse to remove while loops with a condition that contain side-effects,
647   // because removing a while loop is tantamount to removing its condition.
648   //
649   // TODO(jlebar): This is conservative: We could instead just run the while
650   // condition once (trip-count == 0) or twice (trip-count == 1).
651   if (while_op->while_condition()->HasSideEffect()) {
652     VLOG(2) << "Not attempting to remove while loop whose condition contains "
653                "side-effecting instructions: "
654             << while_op->ToShortString();
655     return false;
656   }
657 
658   // Remove while loops with static trip count of 0.
659   optional<int64> trip_count =
660       ComputeWhileLoopTripCount(while_op, /*max_brute_force_iters=*/1);
661   if (trip_count && *trip_count == 0) {
662     // The loop never executes, so the value of the loop is the value of its
663     // "init" operand.
664     auto computation = while_op->parent();
665 
666     // Remove while_op (i.e., call ReplaceInstruction rather than
667     // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in
668     // a loop without an intervening DCE, we don't try to re-remove the loop.
669     TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
670         while_op, while_op->mutable_operand(0)));
671     return true;
672   }
673 
674   // Transform while loops with static trip count of 1 into a call op, then
675   // inline the call.
676   if (trip_count && *trip_count == 1) {
677     // Do not simplify the loop away when there is a side-effectful op,
678     // otherwise the infeed op may not inherit the data dependency from
679     // the while loop.
680     //
681     // Example: while_body (param_a) {
682     //   param_a = parameter(0)
683     //   infeed2 = infeed()
684     // }
685     //
686     // infeed1 = ...
687     // while = while(infeed1), body=while_body // infeed2 has implicit
688     // dependency on infeed1.
689     //
690     // After simplification:
691     //
692     // infeed1 = ...
693     // infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
694     //                    // can be scheduled after infeed2.
695     //
696     bool has_side_effects = absl::c_any_of(
697         while_op->called_computations(), [](const HloComputation* computation) {
698           return computation->HasSideEffect();
699         });
700     if (!has_side_effects) {
701       auto computation = while_op->parent();
702       auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
703           while_op->shape(), while_op->operands(), while_op->while_body()));
704       TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
705       TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
706                           CallInliner::Inline(call_op));
707       (void)inlined_instructions_map;
708       return true;
709     } else {
710       VLOG(2) << "Not attempting to simplify while loop because it contains a "
711                  "side-effecting node: "
712               << while_op->ToShortString();
713     }
714   }
715   return false;
716 }
717 
TryPropagateConstant(HloInstruction * while_op)718 static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
719   auto while_init = while_op->operand(0);
720   if (while_init->opcode() != HloOpcode::kTuple) {
721     return false;
722   }
723 
724   auto while_body = while_op->while_body();
725   auto while_body_root = while_body->root_instruction();
726   if (while_body_root->opcode() != HloOpcode::kTuple) {
727     return false;
728   }
729 
730   auto while_body_param = while_body->parameter_instruction(0);
731   const HloInstruction::InstructionVector& root_operands =
732       while_body_root->operands();
733 
734   // Find the loop invariant tuple elements with scalar constant init value and
735   // build a map from the tuple element index to the constant value. Limit this
736   // to scalar constant values because propagating array constants can regress
737   // performance by forcing us to copy constants.
738   absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
739   for (int i = 0; i < root_operands.size(); i++) {
740     const HloInstruction* init_tuple_elem = nullptr;
741     if (Match(root_operands[i],
742               m::GetTupleElement(m::Op().Is(while_body_param), i)
743                   .WithShape(m::Shape().IsScalar())) &&
744         Match(while_init->operand(i), m::Constant(&init_tuple_elem))) {
745       VLOG(3) << "Found loop invariant tuple element " << i << " "
746               << init_tuple_elem->ToString();
747       index_to_constant[i] = init_tuple_elem;
748     }
749   }
750 
751   if (index_to_constant.empty()) {
752     return false;
753   }
754 
755   // Replace the use of each constant tuple element in the loop_condition and
756   // loop_body with the corresponding constant value.
757   auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> {
758     HloInstruction* param = computation->parameter_instruction(0);
759     bool changed = false;
760     for (auto instr : param->users()) {
761       // Since only a while-loop with a tuple result reaches here, we can safely
762       // assume that `param` is a tuple and the first operand of the
763       // GetTupleElement instruction is a use of `param`.
764       if (instr->opcode() == HloOpcode::kGetTupleElement) {
765         VLOG(3) << "tuple index " << instr->tuple_index() << " "
766                 << instr->ToString();
767         auto iter = index_to_constant.find(instr->tuple_index());
768         if (iter != index_to_constant.end()) {
769           const HloInstruction* hlo_constant = (*iter).second;
770           VLOG(3) << "Replace use of " << instr->ToString() << " with "
771                   << hlo_constant->ToString();
772           TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(
773               computation->AddInstruction(hlo_constant->Clone())));
774           changed = true;
775         }
776       }
777     }
778     return changed;
779   };
780 
781   TF_ASSIGN_OR_RETURN(bool changed_cond,
782                       propagate_constant(while_op->while_condition()));
783   TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body));
784 
785   return changed_cond || changed_body;
786 }
787 
788 // Converts a flat list of instructions into a tuple of the desired shape.  For
789 // example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns
790 // a tuple of value ((A, B), C).
791 //
792 // desired_shape must be a tuple.  (This precondition allows us to return a
793 // unique_ptr rather than a raw ptr.)
UnflattenTupleInstr(absl::Span<HloInstruction * > instrs,const Shape & desired_shape,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)794 static std::unique_ptr<HloInstruction> UnflattenTupleInstr(
795     absl::Span<HloInstruction*> instrs, const Shape& desired_shape,
796     std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
797   CHECK(desired_shape.IsTuple()) << ShapeUtil::HumanString(desired_shape);
798 
799   // For each child shape in `desired_shape`, slice out the correct number of
800   // `instrs` and call UnflattenTupleInstr recursively.  At each step we remove
801   // elements from `instrs` so that it only contains instructions we have not
802   // yet processed.
803   std::vector<HloInstruction*> elems;
804   for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) {
805     const Shape& subshape = desired_shape.tuple_shapes(i);
806     if (!subshape.IsTuple()) {
807       elems.push_back(instrs[0]);
808       instrs.remove_prefix(1);
809       continue;
810     }
811 
812     // Count the number of leaf nodes underneath desired_shape[i].
813     int64 num_leaves = 0;
814     ShapeUtil::ForEachSubshape(
815         subshape, [&](const Shape& s, const ShapeIndex& /*index*/) {
816           if (!s.IsTuple()) {
817             ++num_leaves;
818           }
819         });
820 
821     std::unique_ptr<HloInstruction> subinstr =
822         UnflattenTupleInstr(instrs.subspan(0, num_leaves),
823                             desired_shape.tuple_shapes(i), new_instrs);
824     elems.push_back(subinstr.get());
825     new_instrs->push_back(std::move(subinstr));
826     instrs.remove_prefix(num_leaves);
827   }
828   return HloInstruction::CreateTuple(elems);
829 }
830 
831 // Builds a vector whose elements are the values in the flattened tuple for
832 // `instr`.  For example, if `instr` is a tuple of form ((A, B), C), returns the
833 // vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C).
GetFlatTupleElems(HloInstruction * instr,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)834 static std::vector<HloInstruction*> GetFlatTupleElems(
835     HloInstruction* instr,
836     std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
837   const auto& shape = instr->shape();
838   if (!shape.IsTuple()) {
839     return {instr};
840   }
841   std::vector<HloInstruction*> elems;
842   for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) {
843     const Shape& subshape = shape.tuple_shapes(i);
844     new_instrs->push_back(
845         HloInstruction::CreateGetTupleElement(subshape, instr, i));
846     auto* gte = new_instrs->back().get();
847     auto flattened_subshape = GetFlatTupleElems(gte, new_instrs);
848     elems.insert(elems.end(), flattened_subshape.begin(),
849                  flattened_subshape.end());
850   }
851   return elems;
852 }
853 
TryFlattenNestedTuples(HloInstruction * while_op)854 static StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
855   HloModule* module = while_op->GetModule();
856   HloComputation* computation = while_op->parent();
857   auto* while_init = while_op->mutable_operand(0);
858   auto* while_body = while_op->while_body();
859   auto* while_cond = while_op->while_condition();
860   auto* while_body_root = while_body->root_instruction();
861   if (while_init->opcode() != HloOpcode::kTuple ||
862       while_body_root->opcode() != HloOpcode::kTuple) {
863     return false;
864   }
865 
866   TF_RET_CHECK(while_cond->num_parameters() == 1);
867   TF_RET_CHECK(while_body->num_parameters() == 1);
868   TF_RET_CHECK(
869       ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
870   Shape while_shape = while_init->shape();
871   if (!ShapeUtil::IsNestedTuple(while_shape)) {
872     return false;
873   }
874 
875   std::vector<Shape> flattened_shape_elems;
876   ShapeUtil::ForEachSubshape(while_shape,
877                              [&](const Shape& s, const ShapeIndex& /*index*/) {
878                                if (!s.IsTuple()) {
879                                  flattened_shape_elems.push_back(s);
880                                }
881                              });
882   Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems);
883 
884   // `new_instrs` holds instructions created outside of a computation for
885   // cloning.  Elements added here just need to live until the end of the
886   // relevant CloneWithReplacement call.
887   std::vector<std::unique_ptr<HloInstruction>> new_instrs;
888   auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
889     new_instrs.push_back(std::move(instr));
890     return new_instrs.back().get();
891   };
892 
893   auto nested = [&](HloInstruction* instr) {
894     std::vector<HloInstruction*> gtes;
895     const Shape& flat_shape = instr->shape();
896     for (int64 i = 0; i < flat_shape.tuple_shapes_size(); ++i) {
897       gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement(
898           flat_shape.tuple_shapes(i), instr, i)));
899     }
900     auto nested_instr =
901         UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs);
902     CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape))
903         << ShapeUtil::HumanString(nested_instr->shape()) << " vs "
904         << ShapeUtil::HumanString(while_shape);
905     return nested_instr;
906   };
907 
908   auto flattened = [&](HloInstruction* instr) {
909     return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs));
910   };
911 
912   // Create a new while-condition computation, where parameter 0 has flat shape
913   // but all uses of it go through the nested shape.
914   std::unique_ptr<HloComputation> new_while_cond =
915       while_cond->CloneWithReplacementPairs({
916           while_cond->parameter_instruction(0),
917           nested(add_new_instr(HloInstruction::CreateParameter(
918               0, flattened_shape,
919               while_cond->parameter_instruction(0)->name()))),
920       });
921 
922   // Create a new while-body computation, where parameter 0 has a flat shape and
923   // all uses of it go through the nested shape, and where the root has a flat
924   // shape constructed from the old nested root.
925   std::unique_ptr<HloComputation> new_while_body =
926       while_body->CloneWithReplacementPairs(
927           {
928               while_body->parameter_instruction(0),
929               nested(add_new_instr(HloInstruction::CreateParameter(
930                   0, flattened_shape,
931                   while_body->parameter_instruction(0)->name()))),
932           },
933           {
934               while_body->root_instruction(),
935               flattened(add_new_instr(while_body->root_instruction()->Clone())),
936           });
937 
938   // Create the final while loop, and add any new instructions created to
939   // `computation`.
940   new_instrs.clear();
941   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
942       while_op, nested(computation->AddInstruction(HloInstruction::CreateWhile(
943                     flattened_shape,
944                     module->AddEmbeddedComputation(std::move(new_while_cond)),
945                     module->AddEmbeddedComputation(std::move(new_while_body)),
946                     computation->AddInstruction(flattened(while_init)))))));
947   for (auto& instr : new_instrs) {
948     computation->AddInstruction(std::move(instr));
949   }
950   return true;
951 }
952 
953 // Tries to merge loop induction variables of a given type.
954 //
955 // In this pass we're only concerned with elements of the loop's tuple that
956 // are effective-scalars of type `elem_ty`.  Some terminology:
957 //
958 //  - The trip counter is the first element of the loop's tuple that starts at
959 //    0 and does x++ on each iteration.
960 //
961 //  - An induction variable is an element of the loop's tuple that is not the
962 //    trip counter and does `x += <constant>` on each iteration of the loop.
963 //    Negative constants are OK.
964 //
965 // This pass adds a trip counter if one isn't already present, then replaces
966 // each induction variable with
967 //
968 //   <initial_value> + <trip_count> * <constant>.
969 //
970 // This reduces the number of scalar operations in the loop, which is important
971 // e.g. on GPUs, where each scalar operation is nontrivially expensive because
972 // it's a separate kernel launch.
973 //
974 // Returns the new loop if a change was made, or null if no change was made.
975 // Note that the new loop is not a valid replacement for the old loop; it may
976 // need to be wrapped in a tuple that changes its shape.  We return the loop
977 // itself so that you can call TryMergeInductionVariables in a loop, once for
978 // each integral type elem_ty.
TryMergeInductionVariables(HloInstruction * while_op,PrimitiveType elem_ty)979 static StatusOr<HloInstruction*> TryMergeInductionVariables(
980     HloInstruction* while_op, PrimitiveType elem_ty) {
981   CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty);
982   HloModule* module = while_op->GetModule();
983   HloComputation* computation = while_op->parent();
984   auto* while_init = while_op->mutable_operand(0);
985   auto* while_body = while_op->while_body();
986   auto* while_cond = while_op->while_condition();
987   auto* while_body_root = while_body->root_instruction();
988   if (while_init->opcode() != HloOpcode::kTuple ||
989       while_body_root->opcode() != HloOpcode::kTuple) {
990     return nullptr;
991   }
992 
993   TF_RET_CHECK(while_cond->num_parameters() == 1);
994   TF_RET_CHECK(while_body->num_parameters() == 1);
995   TF_RET_CHECK(
996       ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
997   Shape while_shape = while_init->shape();
998 
999   // The tuple index of the trip counter, if one is present.
1000   absl::optional<int64> trip_counter;
1001   // Maps the tuple index of each induction variable to its constant increment.
1002   absl::flat_hash_map<int64, const HloConstantInstruction*> induction_vars;
1003   for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
1004     HloInstruction* constant;
1005     if (!Match(while_body_root->mutable_operand(i),
1006                m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i),
1007                               m::ConstantScalar(&constant))
1008                    .WithShape(m::Shape().WithElementType(elem_ty)))) {
1009       continue;
1010     }
1011     if (!trip_counter && constant->literal().IsAll(1) &&
1012         while_init->operand(i)->IsConstant() &&
1013         while_init->operand(i)->literal().IsAll(0)) {
1014       VLOG(10) << "Found existing trip counter at index " << i;
1015       trip_counter = i;
1016     } else {
1017       VLOG(10) << "Found induction variable at index " << i;
1018       induction_vars.emplace(i, Cast<HloConstantInstruction>(constant));
1019     }
1020   }
1021 
1022   // There's only something to simplify if we can either:
1023   //
1024   //  - combine one or more induction vars with an existing trip counter, or
1025   //  - replace two or more induction variables with a new trip counter.
1026   //
1027   // Put another way, there's only something to simplify if the number of
1028   // induction vars plus the number of existing trip counters (0 or 1) is >= 2.
1029   if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) {
1030     return nullptr;
1031   }
1032 
1033   // OK, we're going to do the transformation!  Set up some helpers.
1034 
1035   // `new_instrs` holds instructions created outside of a computation for
1036   // cloning.  Elements added here just need to live until the end of the
1037   // relevant CloneWithReplacement call.
1038   std::vector<std::unique_ptr<HloInstruction>> new_instrs;
1039   auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
1040     new_instrs.push_back(std::move(instr));
1041     return new_instrs.back().get();
1042   };
1043 
1044   auto add_binary_op = [&](const Shape& shape, HloOpcode opcode,
1045                            HloInstruction* lhs, HloInstruction* rhs) {
1046     // Reshape lhs/rhs to the output shape if necessary.  This deals with the
1047     // fact that induction variables need only be effective scalars, not true
1048     // scalars.
1049     if (!ShapeUtil::Compatible(shape, lhs->shape())) {
1050       lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs));
1051     }
1052     if (!ShapeUtil::Compatible(shape, rhs->shape())) {
1053       rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs));
1054     }
1055     return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs));
1056   };
1057 
1058   auto add_gte = [&](HloInstruction* src, int64 idx) {
1059     return add_new_instr(HloInstruction::CreateGetTupleElement(
1060         src->shape().tuple_shapes(idx), src, idx));
1061   };
1062 
1063   // Our new while loop will have the same shape as the old while loop, except
1064   // we'll add a trip counter to the end if it wasn't originally present.
1065   Shape new_while_shape = while_shape;
1066   bool added_trip_counter = false;
1067   if (!trip_counter) {
1068     VLOG(10) << "Adding new trip counter to end of loop's tuple.";
1069     trip_counter = new_while_shape.tuple_shapes_size();
1070     *new_while_shape.add_tuple_shapes() =
1071         ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{});
1072     added_trip_counter = true;
1073   }
1074 
1075   // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with
1076   // shape `while_body->shape()` and where the induction variables are "reified"
1077   // (i.e. they have value <init> + <counter> * <constant>).
1078   auto convert_to_old_form = [&](HloInstruction* instr) {
1079     CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
1080     std::vector<HloInstruction*> tuple_elems;
1081     for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1082       const auto& elem_shape = while_shape.tuple_shapes(i);
1083       if (!induction_vars.count(i)) {
1084         tuple_elems.push_back(add_gte(instr, i));
1085         continue;
1086       }
1087       tuple_elems.push_back(add_binary_op(
1088           elem_shape, HloOpcode::kAdd, add_gte(instr, i),
1089           add_binary_op(elem_shape, HloOpcode::kMultiply,
1090                         add_gte(instr, *trip_counter),
1091                         add_new_instr(induction_vars.at(i)->Clone()))));
1092     }
1093     return HloInstruction::CreateTuple(tuple_elems);
1094   };
1095 
1096   // Converts `root` into a tuple of the "new" form -- that is, to a tuple with
1097   // shape `new_while_shape` and where the induction variables (but not trip
1098   // counters) are replaced with their unchanging <loop_body_param> values.
1099   auto convert_to_new_form = [&](HloInstruction* old_root,
1100                                  HloParameterInstruction* loop_body_param) {
1101     CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape));
1102     std::vector<HloInstruction*> tuple_elems;
1103 
1104     // In the new form, induction variables come from `init`, everything else
1105     // (including the trip counter if it's not one we created ourselves) comes
1106     // from the `root` tuple unmodified.
1107     for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1108       tuple_elems.push_back(
1109           add_gte((induction_vars.count(i) ? loop_body_param : old_root), i));
1110     }
1111     // If we created a trip counter ourselves, add 1 to it in the next
1112     // iteration.
1113     if (added_trip_counter) {
1114       tuple_elems.push_back(add_binary_op(
1115           new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd,
1116           add_gte(loop_body_param, *trip_counter),
1117           add_new_instr(
1118               HloInstruction::CreateConstant(LiteralUtil::One(elem_ty)))));
1119     }
1120 
1121     return HloInstruction::CreateTuple(tuple_elems);
1122   };
1123 
1124   // Creates a new init tuple, which is the same as the old init tuple except if
1125   // we added a trip counter, it's set to 0.
1126   auto get_new_while_init = [&](HloInstruction* init) {
1127     CHECK(ShapeUtil::Compatible(init->shape(), while_shape));
1128     if (!added_trip_counter) {
1129       return init;
1130     }
1131     std::vector<HloInstruction*> tuple_elems;
1132     for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1133       tuple_elems.push_back(add_gte(init, i));
1134     }
1135     tuple_elems.push_back(add_new_instr(
1136         HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty))));
1137     return add_new_instr(HloInstruction::CreateTuple(tuple_elems));
1138   };
1139 
1140   std::unique_ptr<HloComputation> new_while_cond =
1141       while_cond->CloneWithReplacementPairs({
1142           while_cond->parameter_instruction(0),
1143           convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
1144               0, new_while_shape,
1145               while_cond->parameter_instruction(0)->name()))),
1146       });
1147 
1148   // Creating the new while body proceeds in two steps.  First we convert the
1149   // users of the parameter to the old form.  Then as a second
1150   // CloneWithReplacement operation we convert the root to the new form.  We
1151   // have to do this in two steps because the new root needs to use the new
1152   // param0, and during the first clone operation, only the *old-form* param0 is
1153   // accessible.
1154   //
1155   // We have to add temp_new_while_body to the module because cloning a
1156   // computation touches the module (to get its NameUniquer).
1157   HloComputation* temp_new_while_body =
1158       module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({
1159           while_body->parameter_instruction(0),
1160           convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
1161               0, new_while_shape,
1162               while_body->parameter_instruction(0)->name()))),
1163       }));
1164   std::unique_ptr<HloComputation> new_while_body =
1165       temp_new_while_body->CloneWithReplacementPairs({
1166           temp_new_while_body->root_instruction(),
1167           convert_to_new_form(
1168               add_new_instr(temp_new_while_body->root_instruction()->Clone()),
1169               Cast<HloParameterInstruction>(
1170                   temp_new_while_body->parameter_instruction(0))),
1171       });
1172   TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body));
1173 
1174   // Create the final while loop, and add any new instructions created to
1175   // `computation`.
1176   new_instrs.clear();
1177   auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile(
1178       new_while_shape,
1179       module->AddEmbeddedComputation(std::move(new_while_cond)),
1180       module->AddEmbeddedComputation(std::move(new_while_body)),
1181       get_new_while_init(while_init)));
1182   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
1183       while_op, convert_to_old_form(new_while)));
1184   for (auto& instr : new_instrs) {
1185     computation->AddInstruction(std::move(instr));
1186   }
1187   return new_while;
1188 }
1189 
Run(HloModule * module)1190 StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
1191   XLA_VLOG_LINES(3,
1192                  "WhileLoopSimplifier::Run(), before:\n" + module->ToString());
1193   bool changed = false;
1194 
1195   // Gather all the while ops in our module.  We do this ahead of time so we
1196   // don't have to worry about mutating the lists of computations or
1197   // instructions while we iterate.
1198   std::vector<HloInstruction*> while_ops;
1199   for (auto* comp : module->computations()) {
1200     for (auto* instr : comp->instructions()) {
1201       if (instr->opcode() == HloOpcode::kWhile) {
1202         while_ops.push_back(instr);
1203       }
1204     }
1205   }
1206 
1207   for (HloInstruction* while_op : while_ops) {
1208     // We can't remove while loops that contain send/recv nodes, because we rely
1209     // on the particular loop structure around the node matching on the send and
1210     // recv sides.  Other while simplifications require us to remove the loop
1211     // and replace it with a new one, so we can't do that either.
1212     if (ContainsInstrWithOpcode(while_op->while_body(),
1213                                 {HloOpcode::kSend, HloOpcode::kSendDone,
1214                                  HloOpcode::kRecv, HloOpcode::kRecvDone}) ||
1215         ContainsInstrWithOpcode(while_op->while_condition(),
1216                                 {HloOpcode::kSend, HloOpcode::kSendDone,
1217                                  HloOpcode::kRecv, HloOpcode::kRecvDone})) {
1218       VLOG(2) << "Not attempting to simplify while loop because it contains a "
1219                  "send/recv node: "
1220               << while_op->ToShortString();
1221       continue;
1222     }
1223 
1224     TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
1225     changed |= result;
1226 
1227     TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op));
1228     changed |= result;
1229 
1230     if (result) {
1231       // Don't continue simplifying after successfully removing the while loop
1232       // -- that would result in use-after-free nastiness.
1233       continue;
1234     }
1235 
1236     // TODO(b/119281462): Cowardly refuse to perform any of the following
1237     // optimizations in the presence of kDomain instructions.  It seems that
1238     // modifying a while loop's tuple doesn't work when kDomain is present.
1239     if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) ||
1240         ContainsInstrWithOpcode(while_op->while_condition(),
1241                                 {HloOpcode::kDomain})) {
1242       continue;
1243     }
1244 
1245     // Each of the optimizations below modifies the while loop itself if it's
1246     // successful, meaning that `while_op` is no longer valid after one of these
1247     // transformations returns true.
1248 
1249     TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op));
1250     changed |= result;
1251     if (result) {
1252       continue;
1253     }
1254 
1255     TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op));
1256     changed |= result;
1257     if (result) {
1258       continue;
1259     }
1260 
1261     TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op));
1262 
1263     changed |= result;
1264     if (result) {
1265       continue;
1266     }
1267 
1268     TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op));
1269     changed |= result;
1270     if (result) {
1271       continue;
1272     }
1273 
1274     bool merged_induction_vars = false;
1275     // Notably missing from this list are S16 and U16.  These don't currently
1276     // work because S/U16 literals are not implemented.
1277     for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) {
1278       TF_ASSIGN_OR_RETURN(auto* new_while_op,
1279                           TryMergeInductionVariables(while_op, elem_ty));
1280       if (new_while_op) {
1281         while_op = new_while_op;
1282         changed = true;
1283         merged_induction_vars = true;
1284       }
1285     }
1286     if (merged_induction_vars) {
1287       continue;
1288     }
1289   }
1290 
1291   XLA_VLOG_LINES(3,
1292                  "WhileLoopSimplifier::Run(), after:\n" + module->ToString());
1293   return changed;
1294 }
1295 
1296 }  // namespace xla
1297