1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_util.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/inlined_vector.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
24 #include "tensorflow/compiler/xla/service/tuple_util.h"
25 
26 namespace xla {
27 
28 using absl::StrCat;
29 
WidenWhileCondition(HloComputation * narrow_condition,const Shape & wide_shape)30 static StatusOr<HloComputation*> WidenWhileCondition(
31     HloComputation* narrow_condition, const Shape& wide_shape) {
32   const Shape& narrow_shape =
33       narrow_condition->parameter_instruction(0)->shape();
34 
35   HloComputation* wide_while_cond = [&]() {
36     HloComputation::Builder builder(StrCat("wide.", narrow_condition->name()));
37     builder.AddInstruction(
38         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
39 
40     // This is needed so that the root instruction is shaped as a PRED[] -- we
41     // need to get this right to begin with since we can't mutate the type of
42     // the root instruction later.  We later change the root instruction to
43     // something more appropriate.
44     builder.AddInstruction(
45         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
46     return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
47   }();
48 
49   HloInstruction* truncated_parameter =
50       TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
51                                narrow_shape.tuple_shapes_size());
52   HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
53       HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
54                                  {truncated_parameter}, narrow_condition));
55 
56   wide_while_cond->set_root_instruction(call_narrow_cond);
57 
58   TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
59   return wide_while_cond;
60 }
61 
62 static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
WidenWhileBody(HloComputation * narrow_body,const Shape & wide_shape)63 WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
64   const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape();
65 
66   HloComputation* wide_while_body = [&]() {
67     HloComputation::Builder builder(StrCat("wide.", narrow_body->name()));
68     builder.AddInstruction(
69         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
70     return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
71   }();
72 
73   HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
74   HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
75       wide_parameter, narrow_shape.tuple_shapes_size());
76   HloInstruction* call_narrow_body =
77       wide_while_body->AddInstruction(HloInstruction::CreateCall(
78           narrow_shape, {truncated_parameter}, narrow_body));
79 
80   std::vector<HloInstruction*> live_through_values;
81   for (int i = narrow_shape.tuple_shapes_size();
82        i < wide_shape.tuple_shapes_size(); i++) {
83     live_through_values.push_back(
84         wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
85             wide_shape.tuple_shapes(i), wide_parameter, i)));
86   }
87 
88   wide_while_body->set_root_instruction(
89       TupleUtil::AppendSuffix(call_narrow_body, live_through_values));
90 
91   TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
92                       CallInliner::Inline(call_narrow_body));
93   return {{wide_while_body, std::move(inlined_instructions_map)}};
94 }
95 
96 /*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
MakeInstructionsLiveIn(HloInstruction * while_instr,absl::Span<HloInstruction * const> instructions)97 WhileUtil::MakeInstructionsLiveIn(
98     HloInstruction* while_instr,
99     absl::Span<HloInstruction* const> instructions) {
100   CHECK(while_instr->shape().IsTuple());
101 
102   int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
103   Shape new_while_shape = while_instr->shape();
104   for (auto* instruction : instructions) {
105     *new_while_shape.add_tuple_shapes() = instruction->shape();
106   }
107 
108   TF_ASSIGN_OR_RETURN(
109       HloComputation * new_while_condition,
110       WidenWhileCondition(while_instr->while_condition(), new_while_shape));
111 
112   HloComputation* new_while_body;
113   CallInliner::InlinedInstructionMap inlined_instructions_map;
114   TF_ASSIGN_OR_RETURN(
115       std::tie(new_while_body, inlined_instructions_map),
116       WidenWhileBody(while_instr->while_body(), new_while_shape));
117 
118   HloInstruction* new_while_init =
119       TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
120   HloComputation* containing_computation = while_instr->parent();
121   HloInstruction* new_while = containing_computation->AddInstruction(
122       HloInstruction::CreateWhile(new_while_shape, new_while_condition,
123                                   new_while_body, new_while_init));
124 
125   // We want to get rid of the old while instruction even if it has side
126   // effecting operations so we do a manual HloComputation::RemoveInstruction
127   // instead of relying on HloComputation::ReplaceInstruction.
128   TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix(
129       new_while, while_instr->shape().tuple_shapes_size())));
130   TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
131 
132   HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
133   std::vector<HloInstruction*> live_in_instructions;
134   for (int64 i = elements_in_old_while_shape;
135        i < new_while_shape.tuple_shapes_size(); i++) {
136     live_in_instructions.push_back(
137         new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
138             instructions[i - elements_in_old_while_shape]->shape(),
139             while_body_param, i)));
140   }
141 
142   WhileUtil::MakeInstructionsLiveInResult result;
143 
144   result.new_while_instr = new_while;
145   result.while_body_live_in_values = std::move(live_in_instructions);
146   result.while_body_instruction_map = std::move(inlined_instructions_map);
147 
148   return std::move(result);
149 }
150 
151 static StatusOr<std::unique_ptr<HloComputation>>
MakeCountedLoopConditionComputation(const Shape & loop_state_shape,int32 trip_count)152 MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
153                                     int32 trip_count) {
154   Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
155 
156   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> cond_computation,
157                       CreateComputationWithSignature(
158                           {&loop_state_shape}, scalar_pred, "while_cond"));
159 
160   HloInstruction* trip_count_constant = cond_computation->AddInstruction(
161       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(trip_count)));
162 
163   HloInstruction* param = cond_computation->parameter_instruction(0);
164   TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
165                       MakeGetTupleElementHlo(param, 0));
166 
167   TF_ASSIGN_OR_RETURN(
168       HloInstruction * compare,
169       MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant));
170   cond_computation->set_root_instruction(compare);
171   return std::move(cond_computation);
172 }
173 
MakeCountedLoopBodyComputation(const Shape & loop_state_shape,const std::function<StatusOr<WhileUtil::LoopStateTy> (HloInstruction *,const WhileUtil::LoopStateTy &)> & loop_body_generator)174 static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
175     const Shape& loop_state_shape,
176     const std::function<StatusOr<WhileUtil::LoopStateTy>(
177         HloInstruction*, const WhileUtil::LoopStateTy&)>& loop_body_generator) {
178   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> body_computation,
179                       CreateComputationWithSignature(
180                           {&loop_state_shape}, loop_state_shape, "while_body"));
181   HloInstruction* one = body_computation->AddInstruction(
182       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
183   HloInstruction* param = body_computation->parameter_instruction(0);
184   TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
185                       MakeGetTupleElementHlo(param, 0));
186   TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar,
187                       MakeBinaryHlo(HloOpcode::kAdd, indvar, one));
188 
189   std::vector<HloInstruction*> loop_body_generator_args;
190   for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) {
191     TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element,
192                         MakeGetTupleElementHlo(param, i));
193     loop_body_generator_args.push_back(tuple_element);
194   }
195   TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> next_state,
196                       loop_body_generator(indvar, loop_body_generator_args));
197   next_state.insert(next_state.begin(), next_indvar);
198   HloInstruction* next_state_tuple =
199       body_computation->AddInstruction(HloInstruction::CreateTuple(next_state));
200   body_computation->set_root_instruction(next_state_tuple);
201 
202   return std::move(body_computation);
203 }
204 
MakeInitTupleFromInitValues(HloComputation * computation,const WhileUtil::LoopStateTy & init_values)205 static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
206     HloComputation* computation, const WhileUtil::LoopStateTy& init_values) {
207   std::vector<HloInstruction*> init_values_with_indvar;
208   init_values_with_indvar.reserve(init_values.size() + 1);
209   HloInstruction* zero = computation->AddInstruction(
210       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
211   init_values_with_indvar.push_back(zero);
212   absl::c_copy(init_values, std::back_inserter(init_values_with_indvar));
213   return computation->AddInstruction(
214       HloInstruction::CreateTuple(init_values_with_indvar));
215 }
216 
MakeLoopStateShape(const WhileUtil::LoopStateTy & init_values)217 static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) {
218   std::vector<Shape> loop_state_shape_components;
219   loop_state_shape_components.reserve(init_values.size() + 1);
220   loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
221   absl::c_transform(init_values,
222                     std::back_inserter(loop_state_shape_components),
223                     [](HloInstruction* instr) { return instr->shape(); });
224   return ShapeUtil::MakeTupleShape(loop_state_shape_components);
225 }
226 
MakeCountedLoop(HloComputation * computation,int32 trip_count,const WhileUtil::LoopStateTy & init_values,const WhileUtil::LoopBodyGeneratorTy & loop_body_generator,const OpMetadata & metadata)227 /*static*/ StatusOr<WhileUtil::LoopStateTy> WhileUtil::MakeCountedLoop(
228     HloComputation* computation, int32 trip_count,
229     const WhileUtil::LoopStateTy& init_values,
230     const WhileUtil::LoopBodyGeneratorTy& loop_body_generator,
231     const OpMetadata& metadata) {
232   CHECK_GE(trip_count, 0);
233 
234   Shape loop_state_shape = MakeLoopStateShape(init_values);
235   TF_ASSIGN_OR_RETURN(
236       std::unique_ptr<HloComputation> cond,
237       MakeCountedLoopConditionComputation(loop_state_shape, trip_count));
238   TF_ASSIGN_OR_RETURN(
239       std::unique_ptr<HloComputation> body,
240       MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator));
241   TF_ASSIGN_OR_RETURN(HloInstruction * init_tuple,
242                       MakeInitTupleFromInitValues(computation, init_values));
243   HloModule* module = computation->parent();
244   HloInstruction* while_instr =
245       computation->AddInstruction(HloInstruction::CreateWhile(
246           loop_state_shape, module->AddEmbeddedComputation(std::move(cond)),
247           module->AddEmbeddedComputation(std::move(body)), init_tuple));
248   while_instr->set_metadata(metadata);
249 
250   std::vector<HloInstruction*> result;
251   for (int64 i = 0, e = init_values.size(); i < e; i++) {
252     TF_ASSIGN_OR_RETURN(HloInstruction * user_state,
253                         MakeGetTupleElementHlo(while_instr, i + 1));
254     result.push_back(user_state);
255   }
256   return result;
257 }
258 
GetInvariantGTEsForWhileBody(const HloComputation & while_body)259 /*static*/ std::vector<HloInstruction*> WhileUtil::GetInvariantGTEsForWhileBody(
260     const HloComputation& while_body) {
261   std::vector<HloInstruction*> result;
262   const HloInstruction::InstructionVector root_operands =
263       while_body.root_instruction()->operands();
264   for (int i = 0; i < root_operands.size(); i++) {
265     HloInstruction* instr = root_operands[i];
266     if (instr->opcode() == HloOpcode::kGetTupleElement &&
267         instr->tuple_index() == i &&
268         instr->operand(0) == while_body.parameter_instruction(0)) {
269       result.push_back(instr);
270     }
271   }
272   return result;
273 }
274 
275 /*static*/ absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>
GetGTEsMapForWhileConditional(const HloComputation & while_conditional)276 WhileUtil::GetGTEsMapForWhileConditional(
277     const HloComputation& while_conditional) {
278   absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>> result;
279   for (HloInstruction* user :
280        while_conditional.parameter_instruction(0)->users()) {
281     if (user->opcode() == HloOpcode::kGetTupleElement) {
282       result[user->tuple_index()].push_back(user);
283     }
284   }
285   return result;
286 }
287 
288 }  // namespace xla
289