1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_util.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/inlined_vector.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
24 #include "tensorflow/compiler/xla/service/tuple_util.h"
25
26 namespace xla {
27
28 using absl::StrCat;
29
WidenWhileCondition(HloComputation * narrow_condition,const Shape & wide_shape)30 static StatusOr<HloComputation*> WidenWhileCondition(
31 HloComputation* narrow_condition, const Shape& wide_shape) {
32 const Shape& narrow_shape =
33 narrow_condition->parameter_instruction(0)->shape();
34
35 HloComputation* wide_while_cond = [&]() {
36 HloComputation::Builder builder(StrCat("wide.", narrow_condition->name()));
37 builder.AddInstruction(
38 HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
39
40 // This is needed so that the root instruction is shaped as a PRED[] -- we
41 // need to get this right to begin with since we can't mutate the type of
42 // the root instruction later. We later change the root instruction to
43 // something more appropriate.
44 builder.AddInstruction(
45 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
46 return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
47 }();
48
49 HloInstruction* truncated_parameter =
50 TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
51 narrow_shape.tuple_shapes_size());
52 HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
53 HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
54 {truncated_parameter}, narrow_condition));
55
56 wide_while_cond->set_root_instruction(call_narrow_cond);
57
58 TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
59 return wide_while_cond;
60 }
61
62 static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
WidenWhileBody(HloComputation * narrow_body,const Shape & wide_shape)63 WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
64 const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape();
65
66 HloComputation* wide_while_body = [&]() {
67 HloComputation::Builder builder(StrCat("wide.", narrow_body->name()));
68 builder.AddInstruction(
69 HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
70 return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
71 }();
72
73 HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
74 HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
75 wide_parameter, narrow_shape.tuple_shapes_size());
76 HloInstruction* call_narrow_body =
77 wide_while_body->AddInstruction(HloInstruction::CreateCall(
78 narrow_shape, {truncated_parameter}, narrow_body));
79
80 std::vector<HloInstruction*> live_through_values;
81 for (int i = narrow_shape.tuple_shapes_size();
82 i < wide_shape.tuple_shapes_size(); i++) {
83 live_through_values.push_back(
84 wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
85 wide_shape.tuple_shapes(i), wide_parameter, i)));
86 }
87
88 wide_while_body->set_root_instruction(
89 TupleUtil::AppendSuffix(call_narrow_body, live_through_values));
90
91 TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
92 CallInliner::Inline(call_narrow_body));
93 return {{wide_while_body, std::move(inlined_instructions_map)}};
94 }
95
96 /*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
MakeInstructionsLiveIn(HloInstruction * while_instr,absl::Span<HloInstruction * const> instructions)97 WhileUtil::MakeInstructionsLiveIn(
98 HloInstruction* while_instr,
99 absl::Span<HloInstruction* const> instructions) {
100 CHECK(while_instr->shape().IsTuple());
101
102 int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
103 Shape new_while_shape = while_instr->shape();
104 for (auto* instruction : instructions) {
105 *new_while_shape.add_tuple_shapes() = instruction->shape();
106 }
107
108 TF_ASSIGN_OR_RETURN(
109 HloComputation * new_while_condition,
110 WidenWhileCondition(while_instr->while_condition(), new_while_shape));
111
112 HloComputation* new_while_body;
113 CallInliner::InlinedInstructionMap inlined_instructions_map;
114 TF_ASSIGN_OR_RETURN(
115 std::tie(new_while_body, inlined_instructions_map),
116 WidenWhileBody(while_instr->while_body(), new_while_shape));
117
118 HloInstruction* new_while_init =
119 TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
120 HloComputation* containing_computation = while_instr->parent();
121 HloInstruction* new_while = containing_computation->AddInstruction(
122 HloInstruction::CreateWhile(new_while_shape, new_while_condition,
123 new_while_body, new_while_init));
124
125 // We want to get rid of the old while instruction even if it has side
126 // effecting operations so we do a manual HloComputation::RemoveInstruction
127 // instead of relying on HloComputation::ReplaceInstruction.
128 TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix(
129 new_while, while_instr->shape().tuple_shapes_size())));
130 TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
131
132 HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
133 std::vector<HloInstruction*> live_in_instructions;
134 for (int64 i = elements_in_old_while_shape;
135 i < new_while_shape.tuple_shapes_size(); i++) {
136 live_in_instructions.push_back(
137 new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
138 instructions[i - elements_in_old_while_shape]->shape(),
139 while_body_param, i)));
140 }
141
142 WhileUtil::MakeInstructionsLiveInResult result;
143
144 result.new_while_instr = new_while;
145 result.while_body_live_in_values = std::move(live_in_instructions);
146 result.while_body_instruction_map = std::move(inlined_instructions_map);
147
148 return std::move(result);
149 }
150
151 static StatusOr<std::unique_ptr<HloComputation>>
MakeCountedLoopConditionComputation(const Shape & loop_state_shape,int32 trip_count)152 MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
153 int32 trip_count) {
154 Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
155
156 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> cond_computation,
157 CreateComputationWithSignature(
158 {&loop_state_shape}, scalar_pred, "while_cond"));
159
160 HloInstruction* trip_count_constant = cond_computation->AddInstruction(
161 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(trip_count)));
162
163 HloInstruction* param = cond_computation->parameter_instruction(0);
164 TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
165 MakeGetTupleElementHlo(param, 0));
166
167 TF_ASSIGN_OR_RETURN(
168 HloInstruction * compare,
169 MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant));
170 cond_computation->set_root_instruction(compare);
171 return std::move(cond_computation);
172 }
173
MakeCountedLoopBodyComputation(const Shape & loop_state_shape,const std::function<StatusOr<WhileUtil::LoopStateTy> (HloInstruction *,const WhileUtil::LoopStateTy &)> & loop_body_generator)174 static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
175 const Shape& loop_state_shape,
176 const std::function<StatusOr<WhileUtil::LoopStateTy>(
177 HloInstruction*, const WhileUtil::LoopStateTy&)>& loop_body_generator) {
178 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> body_computation,
179 CreateComputationWithSignature(
180 {&loop_state_shape}, loop_state_shape, "while_body"));
181 HloInstruction* one = body_computation->AddInstruction(
182 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
183 HloInstruction* param = body_computation->parameter_instruction(0);
184 TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
185 MakeGetTupleElementHlo(param, 0));
186 TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar,
187 MakeBinaryHlo(HloOpcode::kAdd, indvar, one));
188
189 std::vector<HloInstruction*> loop_body_generator_args;
190 for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) {
191 TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element,
192 MakeGetTupleElementHlo(param, i));
193 loop_body_generator_args.push_back(tuple_element);
194 }
195 TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> next_state,
196 loop_body_generator(indvar, loop_body_generator_args));
197 next_state.insert(next_state.begin(), next_indvar);
198 HloInstruction* next_state_tuple =
199 body_computation->AddInstruction(HloInstruction::CreateTuple(next_state));
200 body_computation->set_root_instruction(next_state_tuple);
201
202 return std::move(body_computation);
203 }
204
MakeInitTupleFromInitValues(HloComputation * computation,const WhileUtil::LoopStateTy & init_values)205 static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
206 HloComputation* computation, const WhileUtil::LoopStateTy& init_values) {
207 std::vector<HloInstruction*> init_values_with_indvar;
208 init_values_with_indvar.reserve(init_values.size() + 1);
209 HloInstruction* zero = computation->AddInstruction(
210 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
211 init_values_with_indvar.push_back(zero);
212 absl::c_copy(init_values, std::back_inserter(init_values_with_indvar));
213 return computation->AddInstruction(
214 HloInstruction::CreateTuple(init_values_with_indvar));
215 }
216
MakeLoopStateShape(const WhileUtil::LoopStateTy & init_values)217 static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) {
218 std::vector<Shape> loop_state_shape_components;
219 loop_state_shape_components.reserve(init_values.size() + 1);
220 loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
221 absl::c_transform(init_values,
222 std::back_inserter(loop_state_shape_components),
223 [](HloInstruction* instr) { return instr->shape(); });
224 return ShapeUtil::MakeTupleShape(loop_state_shape_components);
225 }
226
MakeCountedLoop(HloComputation * computation,int32 trip_count,const WhileUtil::LoopStateTy & init_values,const WhileUtil::LoopBodyGeneratorTy & loop_body_generator,const OpMetadata & metadata)227 /*static*/ StatusOr<WhileUtil::LoopStateTy> WhileUtil::MakeCountedLoop(
228 HloComputation* computation, int32 trip_count,
229 const WhileUtil::LoopStateTy& init_values,
230 const WhileUtil::LoopBodyGeneratorTy& loop_body_generator,
231 const OpMetadata& metadata) {
232 CHECK_GE(trip_count, 0);
233
234 Shape loop_state_shape = MakeLoopStateShape(init_values);
235 TF_ASSIGN_OR_RETURN(
236 std::unique_ptr<HloComputation> cond,
237 MakeCountedLoopConditionComputation(loop_state_shape, trip_count));
238 TF_ASSIGN_OR_RETURN(
239 std::unique_ptr<HloComputation> body,
240 MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator));
241 TF_ASSIGN_OR_RETURN(HloInstruction * init_tuple,
242 MakeInitTupleFromInitValues(computation, init_values));
243 HloModule* module = computation->parent();
244 HloInstruction* while_instr =
245 computation->AddInstruction(HloInstruction::CreateWhile(
246 loop_state_shape, module->AddEmbeddedComputation(std::move(cond)),
247 module->AddEmbeddedComputation(std::move(body)), init_tuple));
248 while_instr->set_metadata(metadata);
249
250 std::vector<HloInstruction*> result;
251 for (int64 i = 0, e = init_values.size(); i < e; i++) {
252 TF_ASSIGN_OR_RETURN(HloInstruction * user_state,
253 MakeGetTupleElementHlo(while_instr, i + 1));
254 result.push_back(user_state);
255 }
256 return result;
257 }
258
GetInvariantGTEsForWhileBody(const HloComputation & while_body)259 /*static*/ std::vector<HloInstruction*> WhileUtil::GetInvariantGTEsForWhileBody(
260 const HloComputation& while_body) {
261 std::vector<HloInstruction*> result;
262 const HloInstruction::InstructionVector root_operands =
263 while_body.root_instruction()->operands();
264 for (int i = 0; i < root_operands.size(); i++) {
265 HloInstruction* instr = root_operands[i];
266 if (instr->opcode() == HloOpcode::kGetTupleElement &&
267 instr->tuple_index() == i &&
268 instr->operand(0) == while_body.parameter_instruction(0)) {
269 result.push_back(instr);
270 }
271 }
272 return result;
273 }
274
275 /*static*/ absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>
GetGTEsMapForWhileConditional(const HloComputation & while_conditional)276 WhileUtil::GetGTEsMapForWhileConditional(
277 const HloComputation& while_conditional) {
278 absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>> result;
279 for (HloInstruction* user :
280 while_conditional.parameter_instruction(0)->users()) {
281 if (user->opcode() == HloOpcode::kGetTupleElement) {
282 result[user->tuple_index()].push_back(user);
283 }
284 }
285 return result;
286 }
287
288 } // namespace xla
289