1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_join.h"
21 #include "absl/types/optional.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/compiler/xla/service/call_inliner.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
26 #include "tensorflow/compiler/xla/service/hlo_query.h"
27 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
28 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
29 
30 namespace xla {
31 
32 namespace m = match;
33 using absl::optional;
34 using hlo_query::ContainsInstrWithOpcode;
35 
36 // Tries to remove elements in a while loop's tuple that aren't used within the
37 // loop.
38 //
39 // Specifically, if a loop is tuple-shaped, and there exists some element of
40 // that tuple that is not used by the loop condition and is not used by the loop
41 // body except to pass it to the next iteration of the loop, then we can remove
42 // that element from the loop's tuples.
TryRemoveDeadWhileParams(HloInstruction * while_op)43 static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
44   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
45 
46   // Don't try this transformation if the while loop isn't removable, since if
47   // it succeeds ultimately we're going to have to replace the old while loop
48   // with a new one.
49   if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) {
50     VLOG(2) << "Can't remove dead parameters from non-removable while op.";
51     return false;
52   }
53 
54   HloModule* module = while_op->GetModule();
55   HloComputation* computation = while_op->parent();
56   HloInstruction* while_init = while_op->mutable_operand(0);
57   HloComputation* while_cond = while_op->while_condition();
58   HloComputation* while_body = while_op->while_body();
59   HloInstruction* while_body_root = while_body->root_instruction();
60 
61   if (!while_init->shape().IsTuple()) {
62     VLOG(2) << "While op's carried value isn't tuple shaped.";
63     return false;
64   }
65 
66   if (while_body_root->opcode() != HloOpcode::kTuple) {
67     VLOG(2) << "While body's root is not a tuple(...) instruction.";
68     return false;
69   }
70 
71   auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
72 
73   // Bail if param0 of while_cond or while_body has users which aren't of type
74   // get-tuple-element.
75   for (const HloInstruction* instr : {while_body->parameter_instruction(0),
76                                       while_cond->parameter_instruction(0)}) {
77     for (const HloInstruction* user : instr->users()) {
78       if (user->opcode() != HloOpcode::kGetTupleElement) {
79         VLOG(2) << "Cowardly refusing to analyze while loop with "
80                 << instr->ToString(print_no_metadata)
81                 << " used by non-GTE instruction "
82                 << user->ToString(print_no_metadata) << " in computation "
83                 << instr->parent()->name();
84         return false;
85       }
86     }
87   }
88 
89   const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
90   if (tuple_size == 0) {
91     VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
92                "empty.";
93     return false;
94   }
95 
96   absl::flat_hash_set<int64> used_tuple_indices;
97   for (HloComputation* comp : {while_body, while_cond}) {
98     // The HLO verifier ensures that while_input's shape matches while_init's
99     // shape, which we verified above is a tuple.
100     HloInstruction* while_input = comp->parameter_instruction(0);
101 
102     for (const HloInstruction* user : while_input->users()) {
103       // This user doesn't count if it's only used by the while body's root, and
104       // the root places the tuple element into the same index of the tuple as
105       // it came from.  That just amounts to us carrying the variable through
106       // the loop.
107       //
108       // Careful: HloInstruction::operand_index returns the first index the
109       // operand appears in, but it may appear more than once!
110       if (user->user_count() == 1 && user->users().front() == while_body_root &&
111           while_body_root->operand_index(user) == user->tuple_index() &&
112           absl::c_count(while_body_root->operands(), user) == 1) {
113         continue;
114       }
115 
116       used_tuple_indices.insert(user->tuple_index());
117       if (used_tuple_indices.size() == tuple_size) {
118         VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
119                 << " uses all of its inputs; no simplification possible.";
120         return false;
121       }
122     }
123   }
124 
125   // If a tuple element is not passed unmodified from the while body's param0
126   // through to the while body's root, count that element as "used", since
127   // removing that element would be observable.
128   for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
129     if (used_tuple_indices.contains(i)) {
130       continue;
131     }
132 
133     auto* operand = while_body_root->operand(i);
134     if (operand->opcode() != HloOpcode::kGetTupleElement ||
135         operand->operand(0) != while_body->parameter_instruction(0) ||
136         operand->tuple_index() != i) {
137       VLOG(2) << "Tuple index " << i
138               << " is not passed through loop body unmodified.";
139       used_tuple_indices.insert(i);
140 
141       if (used_tuple_indices.size() == tuple_size) {
142         VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
143                 << " uses all of its inputs; no simplification possible.";
144         return false;
145       }
146     }
147   }
148 
149   // If we got here, used_tuple_indices.size() < tuple_size, meaning some
150   // elements of the loop's tuple aren't used by while_body or while_cond.
151   CHECK_LT(used_tuple_indices.size(), tuple_size);
152 
153   VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
154           << " elements from tuple of "
155           << while_op->ToString(print_no_metadata);
156 
157   // Build up maps from the old/new to the new/old tuple indices.
158   std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
159                                           used_tuple_indices.end());
160   absl::c_sort(new_to_old_tuple_idx);
161 
162   absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
163   for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
164     int64 old_idx = new_to_old_tuple_idx[new_idx];
165     old_to_new_tuple_idx[old_idx] = new_idx;
166     VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx;
167   }
168 
169   // Compute the shape of the while op after we remove the dead indices.
170   std::vector<Shape> new_while_tuple_elem_shapes;
171   new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size());
172   for (int64 old_idx : new_to_old_tuple_idx) {
173     new_while_tuple_elem_shapes.push_back(
174         while_init->shape().tuple_shapes(old_idx));
175   }
176   Shape new_while_shape =
177       ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes);
178 
179   // Returns a map from elements in the computation to new instructions which
180   // replace the old instructions after we remove unused elements from the while
181   // tuple.
182   auto make_while_computation_replacements = [&](const HloComputation* comp) {
183     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
184         replacements;
185 
186     auto* param = comp->parameter_instruction(0);
187     replacements.emplace(param, HloInstruction::CreateParameter(
188                                     0, new_while_shape, param->name()));
189 
190     // Materialize param's users, since we're about to add new ones below.
191     std::vector<HloInstruction*> materialized_users(param->users().begin(),
192                                                     param->users().end());
193     for (const auto* user : materialized_users) {
194       // The while body root is handled separately.
195       if (user == while_body_root) {
196         continue;
197       }
198       CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement)
199           << user->ToString(print_no_metadata);
200 
201       int64 old_idx = user->tuple_index();
202       auto new_idx_iter = old_to_new_tuple_idx.find(old_idx);
203       if (new_idx_iter != old_to_new_tuple_idx.end()) {
204         // This is a GTE of an index that survives.  Replace it.
205         replacements.emplace(
206             user, HloInstruction::CreateGetTupleElement(user->shape(), param,
207                                                         new_idx_iter->second));
208       } else {
209         // This is a GTE of an index that we've removed.  Remove it from the
210         // cloned computation.
211         CHECK(user->user_count() == 0 ||
212               user->user_count() == 1 &&
213                   user->users().front() == while_body_root)
214             << "Instruction " << user->ToString(print_no_metadata)
215             << " should be unused (except by root of while body), but has "
216                "users: {"
217             << absl::StrJoin(user->users(), ", ",
218                              [&](string* out, const HloInstruction* instr) {
219                                absl::StrAppend(
220                                    out, instr->ToString(print_no_metadata));
221                              })
222             << "}";
223 
224         replacements.emplace(user, nullptr);
225       }
226     }
227     return replacements;
228   };
229 
230   // Create the new while condition, body, and init value.
231   std::unique_ptr<HloComputation> new_while_cond =
232       while_cond->CloneWithReplacements(
233           make_while_computation_replacements(while_cond));
234 
235   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
236       while_body_replacements = make_while_computation_replacements(while_body);
237   std::vector<HloInstruction*> new_while_body_root_elems;
238   new_while_body_root_elems.reserve(new_to_old_tuple_idx.size());
239   for (int64 old_idx : new_to_old_tuple_idx) {
240     new_while_body_root_elems.push_back(
241         while_body_root->mutable_operand(old_idx));
242   }
243   while_body_replacements.emplace(
244       while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
245   std::unique_ptr<HloComputation> new_while_body =
246       while_body->CloneWithReplacements(std::move(while_body_replacements));
247 
248   // Add a new while_init instruction that repackages the old while_init
249   // instruction's elements.  We rely on the AlgebraicSimplifier and DCE to
250   // clean this up in the common case where while_init is a tuple op.  (It's
251   // definitely tuple-shaped, but it's not necessarily a tuple op.)
252   std::vector<HloInstruction*> new_while_init_elems;
253   new_while_init_elems.reserve(new_to_old_tuple_idx.size());
254   for (int64 old_idx : new_to_old_tuple_idx) {
255     new_while_init_elems.push_back(
256         computation->AddInstruction(HloInstruction::CreateGetTupleElement(
257             while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
258   }
259   auto* new_while_init = computation->AddInstruction(
260       HloInstruction::CreateTuple(new_while_init_elems));
261 
262   // Create the new while op.
263   auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
264       new_while_shape,
265       module->AddEmbeddedComputation(std::move(new_while_cond)),
266       module->AddEmbeddedComputation(std::move(new_while_body)),
267       new_while_init));
268 
269   // Create a tuple op that recreates the output of the old while op.  That is,
270   // we transform to
271   //
272   //  new_while_init   while_init
273   //       |              |
274   //       V              |
275   //   new_while          |
276   //       |              |
277   //       -------|   |----
278   //              V   V
279   //            new_tuple
280   //                |
281   //                V
282   //    (orig. users of while op)
283   //
284   // The tuple simplifier will then simplify this if possible, removing
285   // new_tuple and while_init.
286   std::vector<HloInstruction*> new_tuple_elems;
287   for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) {
288     auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
289     if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
290       int64 gte_idx = new_tuple_idx_it->second;
291       new_tuple_elems.push_back(
292           computation->AddInstruction(HloInstruction::CreateGetTupleElement(
293               new_while_op->shape().tuple_shapes(gte_idx), new_while_op,
294               gte_idx)));
295     } else {
296       new_tuple_elems.push_back(
297           computation->AddInstruction(HloInstruction::CreateGetTupleElement(
298               while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
299     }
300   }
301   HloInstruction* new_tuple =
302       computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
303   TF_RETURN_IF_ERROR(while_op->ReplaceAllUsesWith(new_tuple));
304 
305   return true;
306 }
307 
308 // Removes each loop parameter (i.e. member of the while loop tuple) that is a
309 // constant and is the same in the while loop body and the while loop init.
TryRemoveConstantParams(HloInstruction * while_op)310 static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) {
311   HloModule* module = while_op->GetModule();
312   HloComputation* computation = while_op->parent();
313   auto* while_init = while_op->mutable_operand(0);
314   auto* while_body = while_op->while_body();
315   auto* while_cond = while_op->while_condition();
316   auto* while_body_root = while_body->root_instruction();
317   if (while_init->opcode() != HloOpcode::kTuple ||
318       while_body_root->opcode() != HloOpcode::kTuple) {
319     return false;
320   }
321 
322   TF_RET_CHECK(while_cond->num_parameters() == 1);
323   TF_RET_CHECK(while_body->num_parameters() == 1);
324   TF_RET_CHECK(
325       ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
326 
327   absl::flat_hash_set<int64> constant_tuple_indices;
328   const auto& while_shape = while_init->shape();
329   for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
330     auto* init_elem = while_init->operand(i);
331     auto* body_elem = while_body_root->operand(i);
332     if (init_elem->opcode() == HloOpcode::kConstant &&
333         body_elem->opcode() == HloOpcode::kConstant &&
334         init_elem->literal() == body_elem->literal()) {
335       constant_tuple_indices.insert(i);
336     }
337   }
338 
339   if (constant_tuple_indices.empty()) {
340     return false;
341   }
342 
343   // OK, we found some constant elements of the while parameter!  Eliminate
344   // them.
345   std::vector<Shape> new_while_shape_elems;
346   for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
347     if (!constant_tuple_indices.count(i)) {
348       new_while_shape_elems.push_back(while_shape.tuple_shapes(i));
349     }
350   }
351   Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems);
352 
353   // `new_instrs` holds instructions created outside of a computation for
354   // cloning.  Elements added here just need to live until the end of the
355   // relevant CloneWithReplacement call.
356   std::vector<std::unique_ptr<HloInstruction>> new_instrs;
357   auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
358     new_instrs.push_back(std::move(instr));
359     return new_instrs.back().get();
360   };
361 
362   // Returns a new tuple without the elements of constant_tuple_indices.
363   auto remove_constant_elems = [&](HloInstruction* instr) {
364     CHECK(ShapeUtil::Compatible(instr->shape(), while_shape));
365 
366     std::vector<HloInstruction*> tuple_elems;
367     for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
368       if (!constant_tuple_indices.count(i)) {
369         tuple_elems.push_back(
370             add_new_instr(HloInstruction::CreateGetTupleElement(
371                 while_shape.tuple_shapes(i), instr, i)));
372       }
373     }
374     return HloInstruction::CreateTuple(tuple_elems);
375   };
376 
377   auto add_constant_elems = [&](HloInstruction* instr) {
378     CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
379 
380     std::vector<HloInstruction*> tuple_elems;
381     int64 j = 0;
382     for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
383       if (constant_tuple_indices.count(i)) {
384         tuple_elems.push_back(while_init->mutable_operand(i));
385       } else {
386         tuple_elems.push_back(
387             add_new_instr(HloInstruction::CreateGetTupleElement(
388                 while_shape.tuple_shapes(i), instr, j)));
389         ++j;
390       }
391     }
392     return HloInstruction::CreateTuple(tuple_elems);
393   };
394 
395   // Special case: constant_tuple_indices covers the whole while parameter, so
396   // the new while shape is the empty tuple.  In this case, the value of the
397   // while loop is simply equal to the value of `init`.
398   //
399   // It's unfortunate to special-case this, but it's simpler than the
400   // alternative.  The problem is that if our while parameter has no
401   // non-constant elems, the tuple returned by `add_constant_elems` won't depend
402   // on instr (the loop body/cond parameter), and therefore
403   // CloneWithReplacementPairs will *leave the parameter out entirely*, creating
404   // invalid HLO.
405   if (ShapeUtil::IsEmptyTuple(new_while_shape)) {
406     TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init));
407     return true;
408   }
409 
410   std::unique_ptr<HloComputation> new_while_cond =
411       while_cond->CloneWithReplacementPairs({
412           while_cond->parameter_instruction(0),
413           add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
414               0, new_while_shape,
415               while_cond->parameter_instruction(0)->name()))),
416       });
417 
418   std::unique_ptr<HloComputation> new_while_body =
419       while_body->CloneWithReplacementPairs(
420           {
421               while_body->parameter_instruction(0),
422               add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
423                   0, new_while_shape,
424                   while_cond->parameter_instruction(0)->name()))),
425           },
426           {
427               while_body->root_instruction(),
428               remove_constant_elems(
429                   add_new_instr(while_body->root_instruction()->Clone())),
430           });
431 
432   // Create the final while loop, and add any new instructions created to
433   // `computation`.
434   new_instrs.clear();
435   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
436       while_op,
437       add_constant_elems(
438           computation->AddInstruction(HloInstruction::CreateWhile(
439               new_while_shape,
440               module->AddEmbeddedComputation(std::move(new_while_cond)),
441               module->AddEmbeddedComputation(std::move(new_while_body)),
442               add_new_instr(remove_constant_elems(while_init)))))));
443   for (auto& instr : new_instrs) {
444     computation->AddInstruction(std::move(instr));
445   }
446   return true;
447 }
448 
449 // Tries to remove a while loop from the graph.
450 //
451 //  - Loops with trip count of 0 can be replaced by the loop's "init" value.
452 //  - Loops with trip count of 1 can be replaced by the loop's body, with the
453 //    loop itself removed.
454 //
455 // Returns true if it made a change to the graph.
TryRemoveWhileLoop(HloInstruction * while_op)456 static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
457   // Cowardly refuse to remove loops that are not removable.  In practice,
458   // this means that we can't remove loops that contain side-effecting
459   // instructions or have control predecessors/successors.
460   //
461   // This is not a fundamental limitation.  The control operands can be moved
462   // onto the new HLOs after simplification, and any side-effecting ops inside
463   // the loop aren't removed, just cloned and added back to the loop.  But
464   // moving an op out of the loop also removes implicit control dependencies
465   // between the op and the ops outside the loop, so we'd have to add those back
466   // for things like infeed/outfeed.  It gets complicated.  So for now we just
467   // avoid it.
468   if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) {
469     VLOG(2) << "Not attempting to remove while loop it is not removable: "
470             << while_op->ToShortString();
471     return false;
472   }
473 
474   // Remove while loops with static trip count of 0.
475   optional<int64> trip_count =
476       ComputeWhileLoopTripCount(while_op,
477                                 /*max_value_returned=*/1);
478   if (trip_count && *trip_count == 0) {
479     // The loop never executes, so the value of the loop is the value of its
480     // "init" operand.
481     auto computation = while_op->parent();
482 
483     // Remove while_op (i.e., call ReplaceInstruction rather than
484     // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in
485     // a loop without an intervening DCE, we don't try to re-remove the loop.
486     TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
487         while_op, while_op->mutable_operand(0)));
488     return true;
489   }
490 
491   // Transform while loops with static trip count of 1 into a call op, then
492   // inline the call.
493   if (trip_count && *trip_count == 1) {
494     auto computation = while_op->parent();
495     auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
496         while_op->shape(), while_op->operands(), while_op->while_body()));
497     TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
498     TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
499                         CallInliner::Inline(call_op));
500     (void)inlined_instructions_map;
501     return true;
502   }
503   return false;
504 }
505 
TryPropagateConstant(HloInstruction * while_op)506 static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
507   auto while_init = while_op->operand(0);
508   if (while_init->opcode() != HloOpcode::kTuple) {
509     return false;
510   }
511 
512   auto while_body = while_op->while_body();
513   auto while_body_root = while_body->root_instruction();
514   if (while_body_root->opcode() != HloOpcode::kTuple) {
515     return false;
516   }
517 
518   auto while_body_param = while_body->parameter_instruction(0);
519   const HloInstruction::InstructionVector& root_operands =
520       while_body_root->operands();
521 
522   // Find the loop invariant tuple elements with scalar constant init value and
523   // build a map from the tuple element index to the constant value. Limit this
524   // to scalar constant values because propagating array constants can regress
525   // performance by forcing us to copy constants.
526   absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
527   for (int i = 0; i < root_operands.size(); i++) {
528     const HloInstruction* init_tuple_elem = nullptr;
529     if (Match(root_operands[i],
530               m::GetTupleElement(m::Op().Is(while_body_param), i)
531                   .WithShape(m::Shape().IsScalar())) &&
532         Match(while_init->operand(i), m::Constant(&init_tuple_elem))) {
533       VLOG(3) << "Found loop invariant tuple element " << i << " "
534               << init_tuple_elem->ToString();
535       index_to_constant[i] = init_tuple_elem;
536     }
537   }
538 
539   if (index_to_constant.empty()) {
540     return false;
541   }
542 
543   // Replace the use of each constant tuple element in the loop_condition and
544   // loop_body with the corresponding constant value.
545   auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> {
546     HloInstruction* param = computation->parameter_instruction(0);
547     bool changed = false;
548     for (auto instr : param->users()) {
549       // Since only a while-loop with a tuple result reaches here, we can safely
550       // assume that `param` is a tuple and the first operand of the
551       // GetTupleElement instruction is a use of `param`.
552       if (instr->opcode() == HloOpcode::kGetTupleElement) {
553         VLOG(3) << "tuple index " << instr->tuple_index() << " "
554                 << instr->ToString();
555         auto iter = index_to_constant.find(instr->tuple_index());
556         if (iter != index_to_constant.end()) {
557           const HloInstruction* hlo_constant = (*iter).second;
558           VLOG(3) << "Replace use of " << instr->ToString() << " with "
559                   << hlo_constant->ToString();
560           TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(
561               computation->AddInstruction(hlo_constant->Clone())));
562           changed = true;
563         }
564       }
565     }
566     return changed;
567   };
568 
569   TF_ASSIGN_OR_RETURN(bool changed_cond,
570                       propagate_constant(while_op->while_condition()));
571   TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body));
572 
573   return changed_cond || changed_body;
574 }
575 
576 // Converts a flat list of instructions into a tuple of the desired shape.  For
577 // example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns
578 // a tuple of value ((A, B), C).
579 //
580 // desired_shape must be a tuple.  (This precondition allows us to return a
581 // unique_ptr rather than a raw ptr.)
UnflattenTupleInstr(absl::Span<HloInstruction * > instrs,const Shape & desired_shape,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)582 static std::unique_ptr<HloInstruction> UnflattenTupleInstr(
583     absl::Span<HloInstruction*> instrs, const Shape& desired_shape,
584     std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
585   CHECK(desired_shape.IsTuple()) << ShapeUtil::HumanString(desired_shape);
586 
587   // For each child shape in `desired_shape`, slice out the correct number of
588   // `instrs` and call UnflattenTupleInstr recursively.  At each step we remove
589   // elements from `instrs` so that it only contains instructions we have not
590   // yet processed.
591   std::vector<HloInstruction*> elems;
592   for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) {
593     const Shape& subshape = desired_shape.tuple_shapes(i);
594     if (!subshape.IsTuple()) {
595       elems.push_back(instrs[0]);
596       instrs.remove_prefix(1);
597       continue;
598     }
599 
600     // Count the number of leaf nodes underneath desired_shape[i].
601     int64 num_leaves = 0;
602     ShapeUtil::ForEachSubshape(
603         subshape, [&](const Shape& s, const ShapeIndex& /*index*/) {
604           if (!s.IsTuple()) {
605             ++num_leaves;
606           }
607         });
608 
609     std::unique_ptr<HloInstruction> subinstr =
610         UnflattenTupleInstr(instrs.subspan(0, num_leaves),
611                             desired_shape.tuple_shapes(i), new_instrs);
612     elems.push_back(subinstr.get());
613     new_instrs->push_back(std::move(subinstr));
614     instrs.remove_prefix(num_leaves);
615   }
616   return HloInstruction::CreateTuple(elems);
617 }
618 
619 // Builds a vector whose elements are the values in the flattened tuple for
620 // `instr`.  For example, if `instr` is a tuple of form ((A, B), C), returns the
621 // vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C).
GetFlatTupleElems(HloInstruction * instr,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)622 static std::vector<HloInstruction*> GetFlatTupleElems(
623     HloInstruction* instr,
624     std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
625   const auto& shape = instr->shape();
626   if (!shape.IsTuple()) {
627     return {instr};
628   }
629   std::vector<HloInstruction*> elems;
630   for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) {
631     const Shape& subshape = shape.tuple_shapes(i);
632     new_instrs->push_back(
633         HloInstruction::CreateGetTupleElement(subshape, instr, i));
634     auto* gte = new_instrs->back().get();
635     auto flattened_subshape = GetFlatTupleElems(gte, new_instrs);
636     elems.insert(elems.end(), flattened_subshape.begin(),
637                  flattened_subshape.end());
638   }
639   return elems;
640 }
641 
TryFlattenNestedTuples(HloInstruction * while_op)642 static StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
643   HloModule* module = while_op->GetModule();
644   HloComputation* computation = while_op->parent();
645   auto* while_init = while_op->mutable_operand(0);
646   auto* while_body = while_op->while_body();
647   auto* while_cond = while_op->while_condition();
648   auto* while_body_root = while_body->root_instruction();
649   if (while_init->opcode() != HloOpcode::kTuple ||
650       while_body_root->opcode() != HloOpcode::kTuple) {
651     return false;
652   }
653 
654   TF_RET_CHECK(while_cond->num_parameters() == 1);
655   TF_RET_CHECK(while_body->num_parameters() == 1);
656   TF_RET_CHECK(
657       ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
658   Shape while_shape = while_init->shape();
659   if (!ShapeUtil::IsNestedTuple(while_shape)) {
660     return false;
661   }
662 
663   std::vector<Shape> flattened_shape_elems;
664   ShapeUtil::ForEachSubshape(while_shape,
665                              [&](const Shape& s, const ShapeIndex& /*index*/) {
666                                if (!s.IsTuple()) {
667                                  flattened_shape_elems.push_back(s);
668                                }
669                              });
670   Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems);
671 
672   // `new_instrs` holds instructions created outside of a computation for
673   // cloning.  Elements added here just need to live until the end of the
674   // relevant CloneWithReplacement call.
675   std::vector<std::unique_ptr<HloInstruction>> new_instrs;
676   auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
677     new_instrs.push_back(std::move(instr));
678     return new_instrs.back().get();
679   };
680 
681   auto nested = [&](HloInstruction* instr) {
682     std::vector<HloInstruction*> gtes;
683     const Shape& flat_shape = instr->shape();
684     for (int64 i = 0; i < flat_shape.tuple_shapes_size(); ++i) {
685       gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement(
686           flat_shape.tuple_shapes(i), instr, i)));
687     }
688     auto nested_instr =
689         UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs);
690     CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape))
691         << ShapeUtil::HumanString(nested_instr->shape()) << " vs "
692         << ShapeUtil::HumanString(while_shape);
693     return nested_instr;
694   };
695 
696   auto flattened = [&](HloInstruction* instr) {
697     return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs));
698   };
699 
700   // Create a new while-condition computation, where parameter 0 has flat shape
701   // but all uses of it go through the nested shape.
702   std::unique_ptr<HloComputation> new_while_cond =
703       while_cond->CloneWithReplacementPairs({
704           while_cond->parameter_instruction(0),
705           nested(add_new_instr(HloInstruction::CreateParameter(
706               0, flattened_shape,
707               while_cond->parameter_instruction(0)->name()))),
708       });
709 
710   // Create a new while-body computation, where parameter 0 has a flat shape and
711   // all uses of it go through the nested shape, and where the root has a flat
712   // shape constructed from the old nested root.
713   std::unique_ptr<HloComputation> new_while_body =
714       while_body->CloneWithReplacementPairs(
715           {
716               while_body->parameter_instruction(0),
717               nested(add_new_instr(HloInstruction::CreateParameter(
718                   0, flattened_shape,
719                   while_body->parameter_instruction(0)->name()))),
720           },
721           {
722               while_body->root_instruction(),
723               flattened(add_new_instr(while_body->root_instruction()->Clone())),
724           });
725 
726   // Create the final while loop, and add any new instructions created to
727   // `computation`.
728   new_instrs.clear();
729   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
730       while_op, nested(computation->AddInstruction(HloInstruction::CreateWhile(
731                     flattened_shape,
732                     module->AddEmbeddedComputation(std::move(new_while_cond)),
733                     module->AddEmbeddedComputation(std::move(new_while_body)),
734                     computation->AddInstruction(flattened(while_init)))))));
735   for (auto& instr : new_instrs) {
736     computation->AddInstruction(std::move(instr));
737   }
738   return true;
739 }
740 
// Tries to merge loop induction variables of a given type.
//
// In this pass we're only concerned with elements of the loop's tuple that
// are effective-scalars of type `elem_ty`.  Some terminology:
//
//  - The trip counter is the first element of the loop's tuple that starts at
//    0 and does x++ on each iteration.
//
//  - An induction variable is an element of the loop's tuple that is not the
//    trip counter and does `x += <constant>` on each iteration of the loop.
//    Negative constants are OK.
//
// This pass adds a trip counter if one isn't already present, then replaces
// each induction variable with
//
//   <initial_value> + <trip_count> * <constant>.
//
// This reduces the number of scalar operations in the loop, which is important
// e.g. on GPUs, where each scalar operation is nontrivially expensive because
// it's a separate kernel launch.
//
// `elem_ty` must be an integral type; this is CHECKed on entry.
//
// Returns the new loop if a change was made, or null if no change was made.
// No change is made when the loop's init value or body root is not a kTuple
// instruction, or when there are fewer than two candidates (induction
// variables plus an existing trip counter) to merge.
//
// When a change is made, `while_op` is replaced in its parent computation by a
// tuple of the original shape that is computed from the new loop, so the
// `while_op` pointer passed in must not be used afterwards.
//
// Note that the new loop is not a valid replacement for the old loop; it may
// need to be wrapped in a tuple that changes its shape.  We return the loop
// itself so that you can call TryMergeInductionVariables in a loop, once for
// each integral type elem_ty.
static StatusOr<HloInstruction*> TryMergeInductionVariables(
    HloInstruction* while_op, PrimitiveType elem_ty) {
  CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty);
  HloModule* module = while_op->GetModule();
  HloComputation* computation = while_op->parent();
  auto* while_init = while_op->mutable_operand(0);
  auto* while_body = while_op->while_body();
  auto* while_cond = while_op->while_condition();
  auto* while_body_root = while_body->root_instruction();
  // Both the init value and the body root must be explicit tuple instructions,
  // because the analysis below inspects their operands element by element.
  if (while_init->opcode() != HloOpcode::kTuple ||
      while_body_root->opcode() != HloOpcode::kTuple) {
    return nullptr;
  }

  TF_RET_CHECK(while_cond->num_parameters() == 1);
  TF_RET_CHECK(while_body->num_parameters() == 1);
  TF_RET_CHECK(
      ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
  Shape while_shape = while_init->shape();

  // The tuple index of the trip counter, if one is present.
  absl::optional<int64> trip_counter;
  // Maps the tuple index of each induction variable to its constant increment.
  absl::flat_hash_map<int64, const HloConstantInstruction*> induction_vars;
  for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
    HloInstruction* constant;
    // We're looking for root operands of the form
    //   add(get-tuple-element(parameter, i), <constant>)
    // (operands in either order) whose element type is `elem_ty`.  On a
    // successful match, `constant` points at the matched constant.
    if (!Match(while_body_root->mutable_operand(i),
               m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i),
                              m::ConstantScalar(&constant))
                   .WithShape(m::Shape().WithElementType(elem_ty)))) {
      continue;
    }
    // Only the first element that increments by 1 starting from a constant 0
    // is treated as the trip counter.  Every other match -- including any
    // later element that would also qualify as a trip counter -- is recorded
    // as an ordinary induction variable.
    if (!trip_counter && constant->literal().IsAll(1) &&
        while_init->operand(i)->IsConstant() &&
        while_init->operand(i)->literal().IsAll(0)) {
      VLOG(10) << "Found existing trip counter at index " << i;
      trip_counter = i;
    } else {
      VLOG(10) << "Found induction variable at index " << i;
      induction_vars.emplace(i, Cast<HloConstantInstruction>(constant));
    }
  }

  // There's only something to simplify if we can either:
  //
  //  - combine one or more induction vars with an existing trip counter, or
  //  - replace two or more induction variables with a new trip counter.
  //
  // Put another way, there's only something to simplify if the number of
  // induction vars plus the number of existing trip counters (0 or 1) is >= 2.
  if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) {
    return nullptr;
  }

  // OK, we're going to do the transformation!  Set up some helpers.

  // `new_instrs` holds instructions created outside of a computation for
  // cloning.  Elements added here just need to live until the end of the
  // relevant CloneWithReplacement call.
  std::vector<std::unique_ptr<HloInstruction>> new_instrs;
  // Stores `instr` in `new_instrs` (which takes ownership) and returns a raw
  // pointer to it.
  auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
    new_instrs.push_back(std::move(instr));
    return new_instrs.back().get();
  };

  // Creates `opcode(lhs, rhs)` with result shape `shape`, owned by
  // `new_instrs`.
  auto add_binary_op = [&](const Shape& shape, HloOpcode opcode,
                           HloInstruction* lhs, HloInstruction* rhs) {
    // Reshape lhs/rhs to the output shape if necessary.  This deals with the
    // fact that induction variables need only be effective scalars, not true
    // scalars.
    if (!ShapeUtil::Compatible(shape, lhs->shape())) {
      lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs));
    }
    if (!ShapeUtil::Compatible(shape, rhs->shape())) {
      rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs));
    }
    return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs));
  };

  // Creates `get-tuple-element(src, idx)`, owned by `new_instrs`.
  auto add_gte = [&](HloInstruction* src, int64 idx) {
    return add_new_instr(HloInstruction::CreateGetTupleElement(
        src->shape().tuple_shapes(idx), src, idx));
  };

  // Our new while loop will have the same shape as the old while loop, except
  // we'll add a trip counter to the end if it wasn't originally present.
  Shape new_while_shape = while_shape;
  bool added_trip_counter = false;
  if (!trip_counter) {
    VLOG(10) << "Adding new trip counter to end of loop's tuple.";
    // The new counter's index is the old tuple size, i.e. one past the last
    // original element.
    trip_counter = new_while_shape.tuple_shapes_size();
    *new_while_shape.add_tuple_shapes() =
        ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{});
    added_trip_counter = true;
  }

  // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with
  // shape `while_shape` and where the induction variables are "reified"
  // (i.e. they have value <init> + <counter> * <constant>).
  //
  // `instr` must have shape `new_while_shape`.  The resulting tuple is
  // returned to the caller rather than being stored in `new_instrs`; the
  // instructions feeding it are stored in `new_instrs`.
  auto convert_to_old_form = [&](HloInstruction* instr) {
    CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
    std::vector<HloInstruction*> tuple_elems;
    // Iterate over the *old* shape, so a trip counter that we appended
    // ourselves is dropped from the result.
    for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
      const auto& elem_shape = while_shape.tuple_shapes(i);
      if (!induction_vars.count(i)) {
        // Not an induction variable (this includes an existing trip counter):
        // pass the element through unchanged.
        tuple_elems.push_back(add_gte(instr, i));
        continue;
      }
      // Induction variable: element i of `instr` plus trip counter times the
      // variable's constant increment.
      tuple_elems.push_back(add_binary_op(
          elem_shape, HloOpcode::kAdd, add_gte(instr, i),
          add_binary_op(elem_shape, HloOpcode::kMultiply,
                        add_gte(instr, *trip_counter),
                        add_new_instr(induction_vars.at(i)->Clone()))));
    }
    return HloInstruction::CreateTuple(tuple_elems);
  };

  // Converts `old_root` into a tuple of the "new" form -- that is, to a tuple
  // with shape `new_while_shape` and where the induction variables (but not
  // trip counters) are replaced with their unchanging <loop_body_param> values.
  //
  // `old_root` must have shape `while_shape`.
  auto convert_to_new_form = [&](HloInstruction* old_root,
                                 HloParameterInstruction* loop_body_param) {
    CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape));
    std::vector<HloInstruction*> tuple_elems;

    // In the new form, induction variables come from `loop_body_param` (so
    // they are passed through the iteration unmodified), everything else
    // (including the trip counter if it's not one we created ourselves) comes
    // from the `old_root` tuple unmodified.
    for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
      tuple_elems.push_back(
          add_gte((induction_vars.count(i) ? loop_body_param : old_root), i));
    }
    // If we created a trip counter ourselves, add 1 to it in the next
    // iteration.
    if (added_trip_counter) {
      tuple_elems.push_back(add_binary_op(
          new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd,
          add_gte(loop_body_param, *trip_counter),
          add_new_instr(
              HloInstruction::CreateConstant(LiteralUtil::One(elem_ty)))));
    }

    return HloInstruction::CreateTuple(tuple_elems);
  };

  // Creates a new init tuple, which is the same as the old init tuple except if
  // we added a trip counter, it's set to 0.
  //
  // If no trip counter was added, `init` itself is returned.  Otherwise the
  // new tuple and the instructions feeding it are stored in `new_instrs`.
  auto get_new_while_init = [&](HloInstruction* init) {
    CHECK(ShapeUtil::Compatible(init->shape(), while_shape));
    if (!added_trip_counter) {
      return init;
    }
    std::vector<HloInstruction*> tuple_elems;
    for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
      tuple_elems.push_back(add_gte(init, i));
    }
    tuple_elems.push_back(add_new_instr(
        HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty))));
    return add_new_instr(HloInstruction::CreateTuple(tuple_elems));
  };

  // Clone the condition so that it takes a parameter of the new shape, with
  // every use of the old parameter going through the old-form tuple instead.
  std::unique_ptr<HloComputation> new_while_cond =
      while_cond->CloneWithReplacementPairs({
          while_cond->parameter_instruction(0),
          convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
              0, new_while_shape,
              while_cond->parameter_instruction(0)->name()))),
      });

  // Creating the new while body proceeds in two steps.  First we convert the
  // users of the parameter to the old form.  Then as a second
  // CloneWithReplacement operation we convert the root to the new form.  We
  // have to do this in two steps because the new root needs to use the new
  // param0, and during the first clone operation, only the *old-form* param0 is
  // accessible.
  //
  // We have to add temp_new_while_body to the module because cloning a
  // computation touches the module (to get its NameUniquer).
  HloComputation* temp_new_while_body =
      module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({
          while_body->parameter_instruction(0),
          convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
              0, new_while_shape,
              while_body->parameter_instruction(0)->name()))),
      }));
  std::unique_ptr<HloComputation> new_while_body =
      temp_new_while_body->CloneWithReplacementPairs({
          temp_new_while_body->root_instruction(),
          convert_to_new_form(
              add_new_instr(temp_new_while_body->root_instruction()->Clone()),
              Cast<HloParameterInstruction>(
                  temp_new_while_body->parameter_instruction(0))),
      });
  // The intermediate body was only needed as the source of the second clone.
  TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body));

  // Create the final while loop, and add any new instructions created to
  // `computation`.
  //
  // Everything accumulated in `new_instrs` up to this point was created only
  // for the clone operations above, so drop it first.  From here on,
  // `new_instrs` collects just the instructions that the helpers create for
  // the new init value and for the old-form wrapper around the new loop; those
  // are moved into `computation` in the loop at the end.
  new_instrs.clear();
  auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile(
      new_while_shape,
      module->AddEmbeddedComputation(std::move(new_while_cond)),
      module->AddEmbeddedComputation(std::move(new_while_body)),
      get_new_while_init(while_init)));
  // Users of the old loop expect the old shape and the old induction variable
  // values, so they are pointed at an old-form tuple built from the new loop.
  TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
      while_op, convert_to_old_form(new_while)));
  for (auto& instr : new_instrs) {
    computation->AddInstruction(std::move(instr));
  }
  return new_while;
}
977 
Run(HloModule * module)978 StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
979   XLA_VLOG_LINES(3,
980                  "WhileLoopSimplifier::Run(), before:\n" + module->ToString());
981   bool changed = false;
982 
983   // Gather all the while ops in our module.  We do this ahead of time so we
984   // don't have to worry about mutating the lists of computations or
985   // instructions while we iterate.
986   std::vector<HloInstruction*> while_ops;
987   for (auto* comp : module->computations()) {
988     for (auto* instr : comp->instructions()) {
989       if (instr->opcode() == HloOpcode::kWhile) {
990         while_ops.push_back(instr);
991       }
992     }
993   }
994 
995   for (HloInstruction* while_op : while_ops) {
996     // We can't remove while loops that contain send/recv nodes, because we rely
997     // on the particular loop structure around the node matching on the send and
998     // recv sides.  Other while simplifications require us to remove the loop
999     // and replace it with a new one, so we can't do that either.
1000     if (ContainsInstrWithOpcode(while_op->while_body(),
1001                                 {HloOpcode::kSend, HloOpcode::kSendDone,
1002                                  HloOpcode::kRecv, HloOpcode::kRecvDone}) ||
1003         ContainsInstrWithOpcode(while_op->while_condition(),
1004                                 {HloOpcode::kSend, HloOpcode::kSendDone,
1005                                  HloOpcode::kRecv, HloOpcode::kRecvDone})) {
1006       VLOG(2) << "Not attempting to simplify while loop because it contains a "
1007                  "send/recv node: "
1008               << while_op->ToShortString();
1009       continue;
1010     }
1011 
1012     TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
1013     changed |= result;
1014 
1015     TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op));
1016     changed |= result;
1017     if (result) {
1018       // Don't continue simplifying after successfully removing the while loop
1019       // -- that would result in use-after-free nastiness.
1020       continue;
1021     }
1022 
1023     // TODO(b/119281462): Cowardly refuse to perform any of the following
1024     // optimizations in the presence of kDomain instructions.  It seems that
1025     // modifying a while loop's tuple doesn't work when kDomain is present.
1026     if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) ||
1027         ContainsInstrWithOpcode(while_op->while_condition(),
1028                                 {HloOpcode::kDomain})) {
1029       continue;
1030     }
1031 
1032     // Each of the optimizations below modifies the while loop itself if it's
1033     // successful, meaning that `while_op` is no longer valid after one of these
1034     // transformations returns true.
1035 
1036     TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op));
1037     changed |= result;
1038     if (result) {
1039       continue;
1040     }
1041 
1042     TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op));
1043     changed |= result;
1044     if (result) {
1045       continue;
1046     }
1047 
1048     TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op));
1049     changed |= result;
1050     if (result) {
1051       continue;
1052     }
1053 
1054     bool merged_induction_vars = false;
1055     // Notably missing from this list are S16 and U16.  These don't currently
1056     // work because S/U16 literals are not implemented.
1057     for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) {
1058       TF_ASSIGN_OR_RETURN(auto* new_while_op,
1059                           TryMergeInductionVariables(while_op, elem_ty));
1060       if (new_while_op) {
1061         while_op = new_while_op;
1062         changed = true;
1063         merged_induction_vars = true;
1064       }
1065     }
1066     if (merged_induction_vars) {
1067       continue;
1068     }
1069   }
1070 
1071   XLA_VLOG_LINES(3,
1072                  "WhileLoopSimplifier::Run(), after:\n" + module->ToString());
1073   return changed;
1074 }
1075 
1076 }  // namespace xla
1077