1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/inlined_vector.h"
19 #include "tensorflow/compiler/xla/service/while_util.h"
20 #include "tensorflow/compiler/xla/util.h"
21 
22 namespace xla {
23 
24 // Replaces all uses of old_instr with new_instr except the use at
25 // `while_body_root` (which must be a tuple instruction) at index `tuple_index`.
26 // This utility helps us replace an instruction in the while body with a
27 // constant while still keeping it trivially loop invariant.
ReplaceUsesWhileKeepingLoopInvariance(HloInstruction * old_instr,HloInstruction * new_instr,HloInstruction * while_body_root,int64 tuple_index)28 static Status ReplaceUsesWhileKeepingLoopInvariance(
29     HloInstruction* old_instr, HloInstruction* new_instr,
30     HloInstruction* while_body_root, int64 tuple_index) {
31   CHECK_EQ(while_body_root->opcode(), HloOpcode::kTuple);
32 
33   std::vector<HloInstruction*> users;
34   users.reserve(old_instr->user_count());
35   absl::c_copy(old_instr->users(), std::back_inserter(users));
36 
37   for (auto* user : users) {
38     for (int64 i = 0, e = user->operand_count(); i < e; i++) {
39       if (user->operand(i) == old_instr &&
40           !(user == while_body_root && i == tuple_index)) {
41         TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, new_instr));
42       }
43     }
44   }
45 
46   return Status::OK();
47 }
48 
TrySinkingConstantsIntoWhileLoop(HloInstruction * while_instr)49 StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop(
50     HloInstruction* while_instr) {
51   HloComputation* while_cond = while_instr->while_condition();
52   HloComputation* while_body = while_instr->while_body();
53 
54   const HloInstruction& init_value = *while_instr->operand(0);
55   if (init_value.opcode() != HloOpcode::kTuple) {
56     return false;
57   }
58 
59   bool changed = false;
60 
61   absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>
62       conditional_gte_index_to_insts =
63           WhileUtil::GetGTEsMapForWhileConditional(*while_cond);
64   std::vector<HloInstruction*> invariant_body_gtes =
65       WhileUtil::GetInvariantGTEsForWhileBody(*while_body);
66 
67   for (HloInstruction* invariant_body_gte : invariant_body_gtes) {
68     int64 index = invariant_body_gte->tuple_index();
69     const HloInstruction& invariant_value = *init_value.operand(index);
70 
71     // Original value should be a constant.
72     if (invariant_value.opcode() != HloOpcode::kConstant) {
73       continue;
74     }
75 
76     // Sink into the while_body.
77     // Should have at least one user that's not while_body_root.
78     if (invariant_body_gte->user_count() > 1) {
79       HloInstruction* constant_instr =
80           while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk"));
81       TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance(
82           invariant_body_gte, constant_instr, while_body->root_instruction(),
83           index));
84       changed = true;
85     }
86 
87     // Check if there is a corresponding GTE in while_conditional.
88     auto it = conditional_gte_index_to_insts.find(index);
89     if (it == conditional_gte_index_to_insts.end()) {
90       continue;
91     }
92 
93     for (HloInstruction* invariant_cond_gte : it->second) {
94       // Should have at least one user.
95       if (invariant_cond_gte->user_count() > 0) {
96         HloInstruction* constant_instr = while_cond->AddInstruction(
97             invariant_value.Clone(/*suffix=*/".sunk"));
98         TF_RETURN_IF_ERROR(
99             invariant_cond_gte->ReplaceAllUsesWith(constant_instr));
100         changed = true;
101       }
102     }
103   }
104 
105   return changed;
106 }
107 
Run(HloModule * module)108 StatusOr<bool> WhileLoopConstantSinking::Run(HloModule* module) {
109   VLOG(2) << "HLO module before WhileLoopConstantSinking:";
110   XLA_VLOG_LINES(2, module->ToString());
111 
112   bool changed = false;
113   std::vector<HloInstruction*> while_instrs;
114   for (auto* comp : module->MakeNonfusionComputations()) {
115     // Right now we don't particulary care about optimizing while-of-while
116     // patterns.  If/When we do, we'll want to visit the outer while (while_0)
117     // before we visit the inner while (while_1):
118     //
119     // while_1_body(state) {
120     //   val = gte(state, 0) // Loop invariant
121     //   use(val)
122     // }
123     //
124     // while_0_body(state) {
125     //   val = gte(state, 0) // Loop invariant
126     //   while_1 = while(init=tuple(val, ...), body=while_1_body, ...)
127     //   ...
128     // }
129     //
130     // main {
131     //   while_0 = while(init=(constant, ...), body=while_0_body, ...)
132     // }
133     //
134     // This will let us sink the constant into the outer while first and then
135     // into the inner while in a single run of this pass.
136     absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
137                     [](const HloInstruction* instr) {
138                       return instr->opcode() == HloOpcode::kWhile;
139                     });
140   }
141 
142   for (HloInstruction* while_instr : while_instrs) {
143     TF_ASSIGN_OR_RETURN(bool result,
144                         TrySinkingConstantsIntoWhileLoop(while_instr));
145     changed |= result;
146   }
147 
148   if (changed) {
149     VLOG(2) << "HLO module after WhileLoopConstantSinking:";
150     XLA_VLOG_LINES(2, module->ToString());
151   } else {
152     VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking";
153   }
154 
155   return changed;
156 }
157 }  // namespace xla
158