1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_join.h"
21 #include "absl/types/optional.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/compiler/xla/service/call_inliner.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
26 #include "tensorflow/compiler/xla/service/hlo_query.h"
27 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
28 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
29
30 namespace xla {
31
32 namespace m = match;
33 using absl::optional;
34 using hlo_query::ContainsInstrWithOpcode;
35
36 // Tries to remove elements in a while loop's tuple that aren't used within the
37 // loop.
38 //
39 // Specifically, if a loop is tuple-shaped, and there exists some element of
40 // that tuple that is not used by the loop condition and is not used by the loop
41 // body except to pass it to the next iteration of the loop, then we can remove
42 // that element from the loop's tuples.
TryRemoveDeadWhileParams(HloInstruction * while_op)43 static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
44 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
45
46 // Don't try this transformation if the while loop isn't removable, since if
47 // it succeeds ultimately we're going to have to replace the old while loop
48 // with a new one.
49 if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) {
50 VLOG(2) << "Can't remove dead parameters from non-removable while op.";
51 return false;
52 }
53
54 HloModule* module = while_op->GetModule();
55 HloComputation* computation = while_op->parent();
56 HloInstruction* while_init = while_op->mutable_operand(0);
57 HloComputation* while_cond = while_op->while_condition();
58 HloComputation* while_body = while_op->while_body();
59 HloInstruction* while_body_root = while_body->root_instruction();
60
61 if (!while_init->shape().IsTuple()) {
62 VLOG(2) << "While op's carried value isn't tuple shaped.";
63 return false;
64 }
65
66 if (while_body_root->opcode() != HloOpcode::kTuple) {
67 VLOG(2) << "While body's root is not a tuple(...) instruction.";
68 return false;
69 }
70
71 auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
72
73 // Bail if param0 of while_cond or while_body has users which aren't of type
74 // get-tuple-element.
75 for (const HloInstruction* instr : {while_body->parameter_instruction(0),
76 while_cond->parameter_instruction(0)}) {
77 for (const HloInstruction* user : instr->users()) {
78 if (user->opcode() != HloOpcode::kGetTupleElement) {
79 VLOG(2) << "Cowardly refusing to analyze while loop with "
80 << instr->ToString(print_no_metadata)
81 << " used by non-GTE instruction "
82 << user->ToString(print_no_metadata) << " in computation "
83 << instr->parent()->name();
84 return false;
85 }
86 }
87 }
88
89 const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
90 if (tuple_size == 0) {
91 VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
92 "empty.";
93 return false;
94 }
95
96 absl::flat_hash_set<int64> used_tuple_indices;
97 for (HloComputation* comp : {while_body, while_cond}) {
98 // The HLO verifier ensures that while_input's shape matches while_init's
99 // shape, which we verified above is a tuple.
100 HloInstruction* while_input = comp->parameter_instruction(0);
101
102 for (const HloInstruction* user : while_input->users()) {
103 // This user doesn't count if it's only used by the while body's root, and
104 // the root places the tuple element into the same index of the tuple as
105 // it came from. That just amounts to us carrying the variable through
106 // the loop.
107 //
108 // Careful: HloInstruction::operand_index returns the first index the
109 // operand appears in, but it may appear more than once!
110 if (user->user_count() == 1 && user->users().front() == while_body_root &&
111 while_body_root->operand_index(user) == user->tuple_index() &&
112 absl::c_count(while_body_root->operands(), user) == 1) {
113 continue;
114 }
115
116 used_tuple_indices.insert(user->tuple_index());
117 if (used_tuple_indices.size() == tuple_size) {
118 VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
119 << " uses all of its inputs; no simplification possible.";
120 return false;
121 }
122 }
123 }
124
125 // If a tuple element is not passed unmodified from the while body's param0
126 // through to the while body's root, count that element as "used", since
127 // removing that element would be observable.
128 for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
129 if (used_tuple_indices.contains(i)) {
130 continue;
131 }
132
133 auto* operand = while_body_root->operand(i);
134 if (operand->opcode() != HloOpcode::kGetTupleElement ||
135 operand->operand(0) != while_body->parameter_instruction(0) ||
136 operand->tuple_index() != i) {
137 VLOG(2) << "Tuple index " << i
138 << " is not passed through loop body unmodified.";
139 used_tuple_indices.insert(i);
140
141 if (used_tuple_indices.size() == tuple_size) {
142 VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
143 << " uses all of its inputs; no simplification possible.";
144 return false;
145 }
146 }
147 }
148
149 // If we got here, used_tuple_indices.size() < tuple_size, meaning some
150 // elements of the loop's tuple aren't used by while_body or while_cond.
151 CHECK_LT(used_tuple_indices.size(), tuple_size);
152
153 VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
154 << " elements from tuple of "
155 << while_op->ToString(print_no_metadata);
156
157 // Build up maps from the old/new to the new/old tuple indices.
158 std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
159 used_tuple_indices.end());
160 absl::c_sort(new_to_old_tuple_idx);
161
162 absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
163 for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
164 int64 old_idx = new_to_old_tuple_idx[new_idx];
165 old_to_new_tuple_idx[old_idx] = new_idx;
166 VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx;
167 }
168
169 // Compute the shape of the while op after we remove the dead indices.
170 std::vector<Shape> new_while_tuple_elem_shapes;
171 new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size());
172 for (int64 old_idx : new_to_old_tuple_idx) {
173 new_while_tuple_elem_shapes.push_back(
174 while_init->shape().tuple_shapes(old_idx));
175 }
176 Shape new_while_shape =
177 ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes);
178
179 // Returns a map from elements in the computation to new instructions which
180 // replace the old instructions after we remove unused elements from the while
181 // tuple.
182 auto make_while_computation_replacements = [&](const HloComputation* comp) {
183 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
184 replacements;
185
186 auto* param = comp->parameter_instruction(0);
187 replacements.emplace(param, HloInstruction::CreateParameter(
188 0, new_while_shape, param->name()));
189
190 // Materialize param's users, since we're about to add new ones below.
191 std::vector<HloInstruction*> materialized_users(param->users().begin(),
192 param->users().end());
193 for (const auto* user : materialized_users) {
194 // The while body root is handled separately.
195 if (user == while_body_root) {
196 continue;
197 }
198 CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement)
199 << user->ToString(print_no_metadata);
200
201 int64 old_idx = user->tuple_index();
202 auto new_idx_iter = old_to_new_tuple_idx.find(old_idx);
203 if (new_idx_iter != old_to_new_tuple_idx.end()) {
204 // This is a GTE of an index that survives. Replace it.
205 replacements.emplace(
206 user, HloInstruction::CreateGetTupleElement(user->shape(), param,
207 new_idx_iter->second));
208 } else {
209 // This is a GTE of an index that we've removed. Remove it from the
210 // cloned computation.
211 CHECK(user->user_count() == 0 ||
212 user->user_count() == 1 &&
213 user->users().front() == while_body_root)
214 << "Instruction " << user->ToString(print_no_metadata)
215 << " should be unused (except by root of while body), but has "
216 "users: {"
217 << absl::StrJoin(user->users(), ", ",
218 [&](string* out, const HloInstruction* instr) {
219 absl::StrAppend(
220 out, instr->ToString(print_no_metadata));
221 })
222 << "}";
223
224 replacements.emplace(user, nullptr);
225 }
226 }
227 return replacements;
228 };
229
230 // Create the new while condition, body, and init value.
231 std::unique_ptr<HloComputation> new_while_cond =
232 while_cond->CloneWithReplacements(
233 make_while_computation_replacements(while_cond));
234
235 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
236 while_body_replacements = make_while_computation_replacements(while_body);
237 std::vector<HloInstruction*> new_while_body_root_elems;
238 new_while_body_root_elems.reserve(new_to_old_tuple_idx.size());
239 for (int64 old_idx : new_to_old_tuple_idx) {
240 new_while_body_root_elems.push_back(
241 while_body_root->mutable_operand(old_idx));
242 }
243 while_body_replacements.emplace(
244 while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
245 std::unique_ptr<HloComputation> new_while_body =
246 while_body->CloneWithReplacements(std::move(while_body_replacements));
247
248 // Add a new while_init instruction that repackages the old while_init
249 // instruction's elements. We rely on the AlgebraicSimplifier and DCE to
250 // clean this up in the common case where while_init is a tuple op. (It's
251 // definitely tuple-shaped, but it's not necessarily a tuple op.)
252 std::vector<HloInstruction*> new_while_init_elems;
253 new_while_init_elems.reserve(new_to_old_tuple_idx.size());
254 for (int64 old_idx : new_to_old_tuple_idx) {
255 new_while_init_elems.push_back(
256 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
257 while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
258 }
259 auto* new_while_init = computation->AddInstruction(
260 HloInstruction::CreateTuple(new_while_init_elems));
261
262 // Create the new while op.
263 auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
264 new_while_shape,
265 module->AddEmbeddedComputation(std::move(new_while_cond)),
266 module->AddEmbeddedComputation(std::move(new_while_body)),
267 new_while_init));
268
269 // Create a tuple op that recreates the output of the old while op. That is,
270 // we transform to
271 //
272 // new_while_init while_init
273 // | |
274 // V |
275 // new_while |
276 // | |
277 // -------| |----
278 // V V
279 // new_tuple
280 // |
281 // V
282 // (orig. users of while op)
283 //
284 // The tuple simplifier will then simplify this if possible, removing
285 // new_tuple and while_init.
286 std::vector<HloInstruction*> new_tuple_elems;
287 for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) {
288 auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
289 if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
290 int64 gte_idx = new_tuple_idx_it->second;
291 new_tuple_elems.push_back(
292 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
293 new_while_op->shape().tuple_shapes(gte_idx), new_while_op,
294 gte_idx)));
295 } else {
296 new_tuple_elems.push_back(
297 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
298 while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
299 }
300 }
301 HloInstruction* new_tuple =
302 computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
303 TF_RETURN_IF_ERROR(while_op->ReplaceAllUsesWith(new_tuple));
304
305 return true;
306 }
307
308 // Removes each loop parameter (i.e. member of the while loop tuple) that is a
309 // constant and is the same in the while loop body and the while loop init.
TryRemoveConstantParams(HloInstruction * while_op)310 static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) {
311 HloModule* module = while_op->GetModule();
312 HloComputation* computation = while_op->parent();
313 auto* while_init = while_op->mutable_operand(0);
314 auto* while_body = while_op->while_body();
315 auto* while_cond = while_op->while_condition();
316 auto* while_body_root = while_body->root_instruction();
317 if (while_init->opcode() != HloOpcode::kTuple ||
318 while_body_root->opcode() != HloOpcode::kTuple) {
319 return false;
320 }
321
322 TF_RET_CHECK(while_cond->num_parameters() == 1);
323 TF_RET_CHECK(while_body->num_parameters() == 1);
324 TF_RET_CHECK(
325 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
326
327 absl::flat_hash_set<int64> constant_tuple_indices;
328 const auto& while_shape = while_init->shape();
329 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
330 auto* init_elem = while_init->operand(i);
331 auto* body_elem = while_body_root->operand(i);
332 if (init_elem->opcode() == HloOpcode::kConstant &&
333 body_elem->opcode() == HloOpcode::kConstant &&
334 init_elem->literal() == body_elem->literal()) {
335 constant_tuple_indices.insert(i);
336 }
337 }
338
339 if (constant_tuple_indices.empty()) {
340 return false;
341 }
342
343 // OK, we found some constant elements of the while parameter! Eliminate
344 // them.
345 std::vector<Shape> new_while_shape_elems;
346 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
347 if (!constant_tuple_indices.count(i)) {
348 new_while_shape_elems.push_back(while_shape.tuple_shapes(i));
349 }
350 }
351 Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems);
352
353 // `new_instrs` holds instructions created outside of a computation for
354 // cloning. Elements added here just need to live until the end of the
355 // relevant CloneWithReplacement call.
356 std::vector<std::unique_ptr<HloInstruction>> new_instrs;
357 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
358 new_instrs.push_back(std::move(instr));
359 return new_instrs.back().get();
360 };
361
362 // Returns a new tuple without the elements of constant_tuple_indices.
363 auto remove_constant_elems = [&](HloInstruction* instr) {
364 CHECK(ShapeUtil::Compatible(instr->shape(), while_shape));
365
366 std::vector<HloInstruction*> tuple_elems;
367 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
368 if (!constant_tuple_indices.count(i)) {
369 tuple_elems.push_back(
370 add_new_instr(HloInstruction::CreateGetTupleElement(
371 while_shape.tuple_shapes(i), instr, i)));
372 }
373 }
374 return HloInstruction::CreateTuple(tuple_elems);
375 };
376
377 auto add_constant_elems = [&](HloInstruction* instr) {
378 CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
379
380 std::vector<HloInstruction*> tuple_elems;
381 int64 j = 0;
382 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
383 if (constant_tuple_indices.count(i)) {
384 tuple_elems.push_back(while_init->mutable_operand(i));
385 } else {
386 tuple_elems.push_back(
387 add_new_instr(HloInstruction::CreateGetTupleElement(
388 while_shape.tuple_shapes(i), instr, j)));
389 ++j;
390 }
391 }
392 return HloInstruction::CreateTuple(tuple_elems);
393 };
394
395 // Special case: constant_tuple_indices covers the whole while parameter, so
396 // the new while shape is the empty tuple. In this case, the value of the
397 // while loop is simply equal to the value of `init`.
398 //
399 // It's unfortunate to special-case this, but it's simpler than the
400 // alternative. The problem is that if our while parameter has no
401 // non-constant elems, the tuple returned by `add_constant_elems` won't depend
402 // on instr (the loop body/cond parameter), and therefore
403 // CloneWithReplacementPairs will *leave the parameter out entirely*, creating
404 // invalid HLO.
405 if (ShapeUtil::IsEmptyTuple(new_while_shape)) {
406 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init));
407 return true;
408 }
409
410 std::unique_ptr<HloComputation> new_while_cond =
411 while_cond->CloneWithReplacementPairs({
412 while_cond->parameter_instruction(0),
413 add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
414 0, new_while_shape,
415 while_cond->parameter_instruction(0)->name()))),
416 });
417
418 std::unique_ptr<HloComputation> new_while_body =
419 while_body->CloneWithReplacementPairs(
420 {
421 while_body->parameter_instruction(0),
422 add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
423 0, new_while_shape,
424 while_cond->parameter_instruction(0)->name()))),
425 },
426 {
427 while_body->root_instruction(),
428 remove_constant_elems(
429 add_new_instr(while_body->root_instruction()->Clone())),
430 });
431
432 // Create the final while loop, and add any new instructions created to
433 // `computation`.
434 new_instrs.clear();
435 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
436 while_op,
437 add_constant_elems(
438 computation->AddInstruction(HloInstruction::CreateWhile(
439 new_while_shape,
440 module->AddEmbeddedComputation(std::move(new_while_cond)),
441 module->AddEmbeddedComputation(std::move(new_while_body)),
442 add_new_instr(remove_constant_elems(while_init)))))));
443 for (auto& instr : new_instrs) {
444 computation->AddInstruction(std::move(instr));
445 }
446 return true;
447 }
448
449 // Tries to remove a while loop from the graph.
450 //
451 // - Loops with trip count of 0 can be replaced by the loop's "init" value.
452 // - Loops with trip count of 1 can be replaced by the loop's body, with the
453 // loop itself removed.
454 //
455 // Returns true if it made a change to the graph.
TryRemoveWhileLoop(HloInstruction * while_op)456 static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
457 // Cowardly refuse to remove loops that are not removable. In practice,
458 // this means that we can't remove loops that contain side-effecting
459 // instructions or have control predecessors/successors.
460 //
461 // This is not a fundamental limitation. The control operands can be moved
462 // onto the new HLOs after simplification, and any side-effecting ops inside
463 // the loop aren't removed, just cloned and added back to the loop. But
464 // moving an op out of the loop also removes implicit control dependencies
465 // between the op and the ops outside the loop, so we'd have to add those back
466 // for things like infeed/outfeed. It gets complicated. So for now we just
467 // avoid it.
468 if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) {
469 VLOG(2) << "Not attempting to remove while loop it is not removable: "
470 << while_op->ToShortString();
471 return false;
472 }
473
474 // Remove while loops with static trip count of 0.
475 optional<int64> trip_count =
476 ComputeWhileLoopTripCount(while_op,
477 /*max_value_returned=*/1);
478 if (trip_count && *trip_count == 0) {
479 // The loop never executes, so the value of the loop is the value of its
480 // "init" operand.
481 auto computation = while_op->parent();
482
483 // Remove while_op (i.e., call ReplaceInstruction rather than
484 // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in
485 // a loop without an intervening DCE, we don't try to re-remove the loop.
486 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
487 while_op, while_op->mutable_operand(0)));
488 return true;
489 }
490
491 // Transform while loops with static trip count of 1 into a call op, then
492 // inline the call.
493 if (trip_count && *trip_count == 1) {
494 auto computation = while_op->parent();
495 auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
496 while_op->shape(), while_op->operands(), while_op->while_body()));
497 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
498 TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
499 CallInliner::Inline(call_op));
500 (void)inlined_instructions_map;
501 return true;
502 }
503 return false;
504 }
505
TryPropagateConstant(HloInstruction * while_op)506 static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
507 auto while_init = while_op->operand(0);
508 if (while_init->opcode() != HloOpcode::kTuple) {
509 return false;
510 }
511
512 auto while_body = while_op->while_body();
513 auto while_body_root = while_body->root_instruction();
514 if (while_body_root->opcode() != HloOpcode::kTuple) {
515 return false;
516 }
517
518 auto while_body_param = while_body->parameter_instruction(0);
519 const HloInstruction::InstructionVector& root_operands =
520 while_body_root->operands();
521
522 // Find the loop invariant tuple elements with scalar constant init value and
523 // build a map from the tuple element index to the constant value. Limit this
524 // to scalar constant values because propagating array constants can regress
525 // performance by forcing us to copy constants.
526 absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
527 for (int i = 0; i < root_operands.size(); i++) {
528 const HloInstruction* init_tuple_elem = nullptr;
529 if (Match(root_operands[i],
530 m::GetTupleElement(m::Op().Is(while_body_param), i)
531 .WithShape(m::Shape().IsScalar())) &&
532 Match(while_init->operand(i), m::Constant(&init_tuple_elem))) {
533 VLOG(3) << "Found loop invariant tuple element " << i << " "
534 << init_tuple_elem->ToString();
535 index_to_constant[i] = init_tuple_elem;
536 }
537 }
538
539 if (index_to_constant.empty()) {
540 return false;
541 }
542
543 // Replace the use of each constant tuple element in the loop_condition and
544 // loop_body with the corresponding constant value.
545 auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> {
546 HloInstruction* param = computation->parameter_instruction(0);
547 bool changed = false;
548 for (auto instr : param->users()) {
549 // Since only a while-loop with a tuple result reaches here, we can safely
550 // assume that `param` is a tuple and the first operand of the
551 // GetTupleElement instruction is a use of `param`.
552 if (instr->opcode() == HloOpcode::kGetTupleElement) {
553 VLOG(3) << "tuple index " << instr->tuple_index() << " "
554 << instr->ToString();
555 auto iter = index_to_constant.find(instr->tuple_index());
556 if (iter != index_to_constant.end()) {
557 const HloInstruction* hlo_constant = (*iter).second;
558 VLOG(3) << "Replace use of " << instr->ToString() << " with "
559 << hlo_constant->ToString();
560 TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(
561 computation->AddInstruction(hlo_constant->Clone())));
562 changed = true;
563 }
564 }
565 }
566 return changed;
567 };
568
569 TF_ASSIGN_OR_RETURN(bool changed_cond,
570 propagate_constant(while_op->while_condition()));
571 TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body));
572
573 return changed_cond || changed_body;
574 }
575
576 // Converts a flat list of instructions into a tuple of the desired shape. For
577 // example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns
578 // a tuple of value ((A, B), C).
579 //
580 // desired_shape must be a tuple. (This precondition allows us to return a
581 // unique_ptr rather than a raw ptr.)
UnflattenTupleInstr(absl::Span<HloInstruction * > instrs,const Shape & desired_shape,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)582 static std::unique_ptr<HloInstruction> UnflattenTupleInstr(
583 absl::Span<HloInstruction*> instrs, const Shape& desired_shape,
584 std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
585 CHECK(desired_shape.IsTuple()) << ShapeUtil::HumanString(desired_shape);
586
587 // For each child shape in `desired_shape`, slice out the correct number of
588 // `instrs` and call UnflattenTupleInstr recursively. At each step we remove
589 // elements from `instrs` so that it only contains instructions we have not
590 // yet processed.
591 std::vector<HloInstruction*> elems;
592 for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) {
593 const Shape& subshape = desired_shape.tuple_shapes(i);
594 if (!subshape.IsTuple()) {
595 elems.push_back(instrs[0]);
596 instrs.remove_prefix(1);
597 continue;
598 }
599
600 // Count the number of leaf nodes underneath desired_shape[i].
601 int64 num_leaves = 0;
602 ShapeUtil::ForEachSubshape(
603 subshape, [&](const Shape& s, const ShapeIndex& /*index*/) {
604 if (!s.IsTuple()) {
605 ++num_leaves;
606 }
607 });
608
609 std::unique_ptr<HloInstruction> subinstr =
610 UnflattenTupleInstr(instrs.subspan(0, num_leaves),
611 desired_shape.tuple_shapes(i), new_instrs);
612 elems.push_back(subinstr.get());
613 new_instrs->push_back(std::move(subinstr));
614 instrs.remove_prefix(num_leaves);
615 }
616 return HloInstruction::CreateTuple(elems);
617 }
618
619 // Builds a vector whose elements are the values in the flattened tuple for
620 // `instr`. For example, if `instr` is a tuple of form ((A, B), C), returns the
621 // vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C).
GetFlatTupleElems(HloInstruction * instr,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)622 static std::vector<HloInstruction*> GetFlatTupleElems(
623 HloInstruction* instr,
624 std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
625 const auto& shape = instr->shape();
626 if (!shape.IsTuple()) {
627 return {instr};
628 }
629 std::vector<HloInstruction*> elems;
630 for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) {
631 const Shape& subshape = shape.tuple_shapes(i);
632 new_instrs->push_back(
633 HloInstruction::CreateGetTupleElement(subshape, instr, i));
634 auto* gte = new_instrs->back().get();
635 auto flattened_subshape = GetFlatTupleElems(gte, new_instrs);
636 elems.insert(elems.end(), flattened_subshape.begin(),
637 flattened_subshape.end());
638 }
639 return elems;
640 }
641
TryFlattenNestedTuples(HloInstruction * while_op)642 static StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
643 HloModule* module = while_op->GetModule();
644 HloComputation* computation = while_op->parent();
645 auto* while_init = while_op->mutable_operand(0);
646 auto* while_body = while_op->while_body();
647 auto* while_cond = while_op->while_condition();
648 auto* while_body_root = while_body->root_instruction();
649 if (while_init->opcode() != HloOpcode::kTuple ||
650 while_body_root->opcode() != HloOpcode::kTuple) {
651 return false;
652 }
653
654 TF_RET_CHECK(while_cond->num_parameters() == 1);
655 TF_RET_CHECK(while_body->num_parameters() == 1);
656 TF_RET_CHECK(
657 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
658 Shape while_shape = while_init->shape();
659 if (!ShapeUtil::IsNestedTuple(while_shape)) {
660 return false;
661 }
662
663 std::vector<Shape> flattened_shape_elems;
664 ShapeUtil::ForEachSubshape(while_shape,
665 [&](const Shape& s, const ShapeIndex& /*index*/) {
666 if (!s.IsTuple()) {
667 flattened_shape_elems.push_back(s);
668 }
669 });
670 Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems);
671
672 // `new_instrs` holds instructions created outside of a computation for
673 // cloning. Elements added here just need to live until the end of the
674 // relevant CloneWithReplacement call.
675 std::vector<std::unique_ptr<HloInstruction>> new_instrs;
676 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
677 new_instrs.push_back(std::move(instr));
678 return new_instrs.back().get();
679 };
680
681 auto nested = [&](HloInstruction* instr) {
682 std::vector<HloInstruction*> gtes;
683 const Shape& flat_shape = instr->shape();
684 for (int64 i = 0; i < flat_shape.tuple_shapes_size(); ++i) {
685 gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement(
686 flat_shape.tuple_shapes(i), instr, i)));
687 }
688 auto nested_instr =
689 UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs);
690 CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape))
691 << ShapeUtil::HumanString(nested_instr->shape()) << " vs "
692 << ShapeUtil::HumanString(while_shape);
693 return nested_instr;
694 };
695
696 auto flattened = [&](HloInstruction* instr) {
697 return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs));
698 };
699
700 // Create a new while-condition computation, where parameter 0 has flat shape
701 // but all uses of it go through the nested shape.
702 std::unique_ptr<HloComputation> new_while_cond =
703 while_cond->CloneWithReplacementPairs({
704 while_cond->parameter_instruction(0),
705 nested(add_new_instr(HloInstruction::CreateParameter(
706 0, flattened_shape,
707 while_cond->parameter_instruction(0)->name()))),
708 });
709
710 // Create a new while-body computation, where parameter 0 has a flat shape and
711 // all uses of it go through the nested shape, and where the root has a flat
712 // shape constructed from the old nested root.
713 std::unique_ptr<HloComputation> new_while_body =
714 while_body->CloneWithReplacementPairs(
715 {
716 while_body->parameter_instruction(0),
717 nested(add_new_instr(HloInstruction::CreateParameter(
718 0, flattened_shape,
719 while_body->parameter_instruction(0)->name()))),
720 },
721 {
722 while_body->root_instruction(),
723 flattened(add_new_instr(while_body->root_instruction()->Clone())),
724 });
725
726 // Create the final while loop, and add any new instructions created to
727 // `computation`.
728 new_instrs.clear();
729 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
730 while_op, nested(computation->AddInstruction(HloInstruction::CreateWhile(
731 flattened_shape,
732 module->AddEmbeddedComputation(std::move(new_while_cond)),
733 module->AddEmbeddedComputation(std::move(new_while_body)),
734 computation->AddInstruction(flattened(while_init)))))));
735 for (auto& instr : new_instrs) {
736 computation->AddInstruction(std::move(instr));
737 }
738 return true;
739 }
740
741 // Tries to merge loop induction variables of a given type.
742 //
743 // In this pass we're only concerned with elements of the loop's tuple that
744 // are effective-scalars of type `elem_ty`. Some terminology:
745 //
746 // - The trip counter is the first element of the loop's tuple that starts at
747 // 0 and does x++ on each iteration.
748 //
749 // - An induction variable is an element of the loop's tuple that is not the
750 // trip counter and does `x += <constant>` on each iteration of the loop.
751 // Negative constants are OK.
752 //
753 // This pass adds a trip counter if one isn't already present, then replaces
754 // each induction variable with
755 //
756 // <initial_value> + <trip_count> * <constant>.
757 //
758 // This reduces the number of scalar operations in the loop, which is important
759 // e.g. on GPUs, where each scalar operation is nontrivially expensive because
760 // it's a separate kernel launch.
761 //
762 // Returns the new loop if a change was made, or null if no change was made.
763 // Note that the new loop is not a valid replacement for the old loop; it may
764 // need to be wrapped in a tuple that changes its shape. We return the loop
765 // itself so that you can call TryMergeInductionVariables in a loop, once for
766 // each integral type elem_ty.
TryMergeInductionVariables(HloInstruction * while_op,PrimitiveType elem_ty)767 static StatusOr<HloInstruction*> TryMergeInductionVariables(
768 HloInstruction* while_op, PrimitiveType elem_ty) {
769 CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty);
770 HloModule* module = while_op->GetModule();
771 HloComputation* computation = while_op->parent();
772 auto* while_init = while_op->mutable_operand(0);
773 auto* while_body = while_op->while_body();
774 auto* while_cond = while_op->while_condition();
775 auto* while_body_root = while_body->root_instruction();
776 if (while_init->opcode() != HloOpcode::kTuple ||
777 while_body_root->opcode() != HloOpcode::kTuple) {
778 return nullptr;
779 }
780
781 TF_RET_CHECK(while_cond->num_parameters() == 1);
782 TF_RET_CHECK(while_body->num_parameters() == 1);
783 TF_RET_CHECK(
784 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
785 Shape while_shape = while_init->shape();
786
787 // The tuple index of the trip counter, if one is present.
788 absl::optional<int64> trip_counter;
789 // Maps the tuple index of each induction variable to its constant increment.
790 absl::flat_hash_map<int64, const HloConstantInstruction*> induction_vars;
791 for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
792 HloInstruction* constant;
793 if (!Match(while_body_root->mutable_operand(i),
794 m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i),
795 m::ConstantScalar(&constant))
796 .WithShape(m::Shape().WithElementType(elem_ty)))) {
797 continue;
798 }
799 if (!trip_counter && constant->literal().IsAll(1) &&
800 while_init->operand(i)->IsConstant() &&
801 while_init->operand(i)->literal().IsAll(0)) {
802 VLOG(10) << "Found existing trip counter at index " << i;
803 trip_counter = i;
804 } else {
805 VLOG(10) << "Found induction variable at index " << i;
806 induction_vars.emplace(i, Cast<HloConstantInstruction>(constant));
807 }
808 }
809
810 // There's only something to simplify if we can either:
811 //
812 // - combine one or more induction vars with an existing trip counter, or
813 // - replace two or more induction variables with a new trip counter.
814 //
815 // Put another way, there's only something to simplify if the number of
816 // induction vars plus the number of existing trip counters (0 or 1) is >= 2.
817 if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) {
818 return nullptr;
819 }
820
821 // OK, we're going to do the transformation! Set up some helpers.
822
823 // `new_instrs` holds instructions created outside of a computation for
824 // cloning. Elements added here just need to live until the end of the
825 // relevant CloneWithReplacement call.
826 std::vector<std::unique_ptr<HloInstruction>> new_instrs;
827 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
828 new_instrs.push_back(std::move(instr));
829 return new_instrs.back().get();
830 };
831
832 auto add_binary_op = [&](const Shape& shape, HloOpcode opcode,
833 HloInstruction* lhs, HloInstruction* rhs) {
834 // Reshape lhs/rhs to the output shape if necessary. This deals with the
835 // fact that induction variables need only be effective scalars, not true
836 // scalars.
837 if (!ShapeUtil::Compatible(shape, lhs->shape())) {
838 lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs));
839 }
840 if (!ShapeUtil::Compatible(shape, rhs->shape())) {
841 rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs));
842 }
843 return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs));
844 };
845
846 auto add_gte = [&](HloInstruction* src, int64 idx) {
847 return add_new_instr(HloInstruction::CreateGetTupleElement(
848 src->shape().tuple_shapes(idx), src, idx));
849 };
850
851 // Our new while loop will have the same shape as the old while loop, except
852 // we'll add a trip counter to the end if it wasn't originally present.
853 Shape new_while_shape = while_shape;
854 bool added_trip_counter = false;
855 if (!trip_counter) {
856 VLOG(10) << "Adding new trip counter to end of loop's tuple.";
857 trip_counter = new_while_shape.tuple_shapes_size();
858 *new_while_shape.add_tuple_shapes() =
859 ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{});
860 added_trip_counter = true;
861 }
862
863 // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with
864 // shape `while_body->shape()` and where the induction variables are "reified"
865 // (i.e. they have value <init> + <counter> * <constant>).
866 auto convert_to_old_form = [&](HloInstruction* instr) {
867 CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
868 std::vector<HloInstruction*> tuple_elems;
869 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
870 const auto& elem_shape = while_shape.tuple_shapes(i);
871 if (!induction_vars.count(i)) {
872 tuple_elems.push_back(add_gte(instr, i));
873 continue;
874 }
875 tuple_elems.push_back(add_binary_op(
876 elem_shape, HloOpcode::kAdd, add_gte(instr, i),
877 add_binary_op(elem_shape, HloOpcode::kMultiply,
878 add_gte(instr, *trip_counter),
879 add_new_instr(induction_vars.at(i)->Clone()))));
880 }
881 return HloInstruction::CreateTuple(tuple_elems);
882 };
883
884 // Converts `root` into a tuple of the "new" form -- that is, to a tuple with
885 // shape `new_while_shape` and where the induction variables (but not trip
886 // counters) are replaced with their unchanging <loop_body_param> values.
887 auto convert_to_new_form = [&](HloInstruction* old_root,
888 HloParameterInstruction* loop_body_param) {
889 CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape));
890 std::vector<HloInstruction*> tuple_elems;
891
892 // In the new form, induction variables come from `init`, everything else
893 // (including the trip counter if it's not one we created ourselves) comes
894 // from the `root` tuple unmodified.
895 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
896 tuple_elems.push_back(
897 add_gte((induction_vars.count(i) ? loop_body_param : old_root), i));
898 }
899 // If we created a trip counter ourselves, add 1 to it in the next
900 // iteration.
901 if (added_trip_counter) {
902 tuple_elems.push_back(add_binary_op(
903 new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd,
904 add_gte(loop_body_param, *trip_counter),
905 add_new_instr(
906 HloInstruction::CreateConstant(LiteralUtil::One(elem_ty)))));
907 }
908
909 return HloInstruction::CreateTuple(tuple_elems);
910 };
911
912 // Creates a new init tuple, which is the same as the old init tuple except if
913 // we added a trip counter, it's set to 0.
914 auto get_new_while_init = [&](HloInstruction* init) {
915 CHECK(ShapeUtil::Compatible(init->shape(), while_shape));
916 if (!added_trip_counter) {
917 return init;
918 }
919 std::vector<HloInstruction*> tuple_elems;
920 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
921 tuple_elems.push_back(add_gte(init, i));
922 }
923 tuple_elems.push_back(add_new_instr(
924 HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty))));
925 return add_new_instr(HloInstruction::CreateTuple(tuple_elems));
926 };
927
928 std::unique_ptr<HloComputation> new_while_cond =
929 while_cond->CloneWithReplacementPairs({
930 while_cond->parameter_instruction(0),
931 convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
932 0, new_while_shape,
933 while_cond->parameter_instruction(0)->name()))),
934 });
935
936 // Creating the new while body proceeds in two steps. First we convert the
937 // users of the parameter to the old form. Then as a second
938 // CloneWithReplacement operation we convert the root to the new form. We
939 // have to do this in two steps because the new root needs to use the new
940 // param0, and during the first clone operation, only the *old-form* param0 is
941 // accessible.
942 //
943 // We have to add temp_new_while_body to the module because cloning a
944 // computation touches the module (to get its NameUniquer).
945 HloComputation* temp_new_while_body =
946 module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({
947 while_body->parameter_instruction(0),
948 convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
949 0, new_while_shape,
950 while_body->parameter_instruction(0)->name()))),
951 }));
952 std::unique_ptr<HloComputation> new_while_body =
953 temp_new_while_body->CloneWithReplacementPairs({
954 temp_new_while_body->root_instruction(),
955 convert_to_new_form(
956 add_new_instr(temp_new_while_body->root_instruction()->Clone()),
957 Cast<HloParameterInstruction>(
958 temp_new_while_body->parameter_instruction(0))),
959 });
960 TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body));
961
962 // Create the final while loop, and add any new instructions created to
963 // `computation`.
964 new_instrs.clear();
965 auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile(
966 new_while_shape,
967 module->AddEmbeddedComputation(std::move(new_while_cond)),
968 module->AddEmbeddedComputation(std::move(new_while_body)),
969 get_new_while_init(while_init)));
970 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
971 while_op, convert_to_old_form(new_while)));
972 for (auto& instr : new_instrs) {
973 computation->AddInstruction(std::move(instr));
974 }
975 return new_while;
976 }
977
Run(HloModule * module)978 StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
979 XLA_VLOG_LINES(3,
980 "WhileLoopSimplifier::Run(), before:\n" + module->ToString());
981 bool changed = false;
982
983 // Gather all the while ops in our module. We do this ahead of time so we
984 // don't have to worry about mutating the lists of computations or
985 // instructions while we iterate.
986 std::vector<HloInstruction*> while_ops;
987 for (auto* comp : module->computations()) {
988 for (auto* instr : comp->instructions()) {
989 if (instr->opcode() == HloOpcode::kWhile) {
990 while_ops.push_back(instr);
991 }
992 }
993 }
994
995 for (HloInstruction* while_op : while_ops) {
996 // We can't remove while loops that contain send/recv nodes, because we rely
997 // on the particular loop structure around the node matching on the send and
998 // recv sides. Other while simplifications require us to remove the loop
999 // and replace it with a new one, so we can't do that either.
1000 if (ContainsInstrWithOpcode(while_op->while_body(),
1001 {HloOpcode::kSend, HloOpcode::kSendDone,
1002 HloOpcode::kRecv, HloOpcode::kRecvDone}) ||
1003 ContainsInstrWithOpcode(while_op->while_condition(),
1004 {HloOpcode::kSend, HloOpcode::kSendDone,
1005 HloOpcode::kRecv, HloOpcode::kRecvDone})) {
1006 VLOG(2) << "Not attempting to simplify while loop because it contains a "
1007 "send/recv node: "
1008 << while_op->ToShortString();
1009 continue;
1010 }
1011
1012 TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
1013 changed |= result;
1014
1015 TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op));
1016 changed |= result;
1017 if (result) {
1018 // Don't continue simplifying after successfully removing the while loop
1019 // -- that would result in use-after-free nastiness.
1020 continue;
1021 }
1022
1023 // TODO(b/119281462): Cowardly refuse to perform any of the following
1024 // optimizations in the presence of kDomain instructions. It seems that
1025 // modifying a while loop's tuple doesn't work when kDomain is present.
1026 if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) ||
1027 ContainsInstrWithOpcode(while_op->while_condition(),
1028 {HloOpcode::kDomain})) {
1029 continue;
1030 }
1031
1032 // Each of the optimizations below modifies the while loop itself if it's
1033 // successful, meaning that `while_op` is no longer valid after one of these
1034 // transformations returns true.
1035
1036 TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op));
1037 changed |= result;
1038 if (result) {
1039 continue;
1040 }
1041
1042 TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op));
1043 changed |= result;
1044 if (result) {
1045 continue;
1046 }
1047
1048 TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op));
1049 changed |= result;
1050 if (result) {
1051 continue;
1052 }
1053
1054 bool merged_induction_vars = false;
1055 // Notably missing from this list are S16 and U16. These don't currently
1056 // work because S/U16 literals are not implemented.
1057 for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) {
1058 TF_ASSIGN_OR_RETURN(auto* new_while_op,
1059 TryMergeInductionVariables(while_op, elem_ty));
1060 if (new_while_op) {
1061 while_op = new_while_op;
1062 changed = true;
1063 merged_induction_vars = true;
1064 }
1065 }
1066 if (merged_induction_vars) {
1067 continue;
1068 }
1069 }
1070
1071 XLA_VLOG_LINES(3,
1072 "WhileLoopSimplifier::Run(), after:\n" + module->ToString());
1073 return changed;
1074 }
1075
1076 } // namespace xla
1077