1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
20 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
24 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
25 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/test.h"
29 #include "tensorflow/compiler/xla/test_helpers.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace xla {
37 namespace {
38
39 using ::testing::ElementsAre;
40 using ::testing::UnorderedElementsAre;
41
42 // Test is parameterized on a bool which is whether the dataflow analysis is
43 // performed with SSA form.
44 class HloDataflowAnalysisTest : public HloTestBase,
45 public ::testing::WithParamInterface<bool> {
46 protected:
HloDataflowAnalysisTest()47 HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {}
48
49 // Run dataflow analysis on the member module. For convenience returns a
50 // reference to the generated analysis stored in analysis_.
RunAnalysis(bool ssa_form,bool bitcast_defines_value=false)51 const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
52 bool bitcast_defines_value = false) {
53 analysis_ =
54 HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
55 .ConsumeValueOrDie();
56 return *analysis_;
57 }
58
59 // Return a vector of the HloValues at the given program position.
HloValuesAt(const HloInstruction * instruction,const ShapeIndex & index={})60 std::vector<HloValue> HloValuesAt(const HloInstruction* instruction,
61 const ShapeIndex& index = {}) {
62 CHECK(analysis_ != nullptr);
63 std::vector<HloValue> values;
64 for (const HloValue* value :
65 analysis_->GetValueSet(instruction, index).values()) {
66 values.push_back(*value);
67 }
68 return values;
69 }
70
71 // Returns true if the top-level values for instructions 'a' and 'b' may
72 // interfere. Precondition: 'a' and 'b' define array-shaped values.
InstructionsMayInterfere(const HloOrdering & ordering,const HloInstruction * a,const HloInstruction * b)73 bool InstructionsMayInterfere(const HloOrdering& ordering,
74 const HloInstruction* a,
75 const HloInstruction* b) {
76 EXPECT_FALSE(a->shape().IsTuple());
77 EXPECT_FALSE(b->shape().IsTuple());
78 return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
79 analysis_->GetValueDefinedAt(b), *analysis_);
80 }
81
CreateR0F32UnaryOpComputation(HloOpcode opcode)82 std::unique_ptr<HloComputation> CreateR0F32UnaryOpComputation(
83 HloOpcode opcode) {
84 HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode));
85 HloInstruction* param0 = builder.AddInstruction(
86 HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
87 builder.AddInstruction(
88 HloInstruction::CreateUnary(scalar_shape_, opcode, param0));
89 return builder.Build();
90 }
91
92 std::unique_ptr<HloModule> module_;
93 std::unique_ptr<HloDataflowAnalysis> analysis_;
94
95 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
96 const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42});
97 const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
98 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
99 };
100
TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
  // Simplest case: two scalar constants feeding a single Add, which is the
  // last instruction of the entry computation.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* sum = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, lhs, rhs));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool use_ssa = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(use_ssa);

  // Three instructions, three values: every instruction defines its own.
  EXPECT_EQ(analysis.values().size(), 3);
  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));

  const HloValue& lhs_value = analysis.GetValueDefinedAt(lhs);
  const HloValue& rhs_value = analysis.GetValueDefinedAt(rhs);
  const HloValue& sum_value = analysis.GetValueDefinedAt(sum);

  // Nothing forwards a value here, so each value lives only at its defining
  // instruction.
  EXPECT_THAT(lhs_value.positions(),
              UnorderedElementsAre(HloPosition{lhs, {}}));
  EXPECT_THAT(rhs_value.positions(),
              UnorderedElementsAre(HloPosition{rhs, {}}));
  EXPECT_THAT(sum_value.positions(),
              UnorderedElementsAre(HloPosition{sum, {}}));

  // Each constant is consumed by exactly one operand slot of the Add; the Add
  // itself has no users.
  EXPECT_THAT(lhs_value.uses(), UnorderedElementsAre(HloUse{sum, 0, {}}));
  EXPECT_THAT(rhs_value.uses(), UnorderedElementsAre(HloUse{sum, 1, {}}));
  EXPECT_TRUE(sum_value.uses().empty());

  // Only the Add's value leaves the module.
  EXPECT_FALSE(lhs_value.live_out_of_module());
  EXPECT_FALSE(rhs_value.live_out_of_module());
  EXPECT_TRUE(sum_value.live_out_of_module());
}
143
TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
  // Two parameters are packed into a tuple, unpacked again with
  // GetTupleElement, and the unpacked elements are added.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* p0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* p1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* packed =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({p0, p1}));
  HloInstruction* elem0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, packed, 0));
  HloInstruction* elem1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, packed, 1));
  HloInstruction* sum = entry_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, elem0, elem1));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool use_ssa = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(use_ssa);

  // Values: param0, param1, the top-level tuple, and the add.
  EXPECT_EQ(analysis.values().size(), 4);

  // Tuple elements and GTE results are forwarded, not defined.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(p0));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(p1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(packed, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(packed, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(packed, /*index=*/{1}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(elem0));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(elem1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));

  const HloValue& p0_value = analysis.GetValueDefinedAt(p0);
  const HloValue& p1_value = analysis.GetValueDefinedAt(p1);
  const HloValue& packed_value =
      analysis.GetValueDefinedAt(packed, /*index=*/{});

  // Each parameter value shows up at the parameter, inside the tuple, and at
  // the matching GTE.
  EXPECT_THAT(p0_value.positions(),
              UnorderedElementsAre(HloPosition{p0, {}}, HloPosition{packed, {0}},
                                   HloPosition{elem0, {}}));
  EXPECT_THAT(p1_value.positions(),
              UnorderedElementsAre(HloPosition{p1, {}}, HloPosition{packed, {1}},
                                   HloPosition{elem1, {}}));
  EXPECT_THAT(packed_value.positions(),
              UnorderedElementsAre(HloPosition{packed, {}}));

  // A GetTupleElement only counts as a use of the top-level tuple value, not
  // of the element it extracts; the elements are used by the add.
  EXPECT_THAT(p0_value.uses(), UnorderedElementsAre(HloUse{sum, 0, {}}));
  EXPECT_THAT(p1_value.uses(), UnorderedElementsAre(HloUse{sum, 1, {}}));
  EXPECT_THAT(packed_value.uses(),
              UnorderedElementsAre(HloUse{elem0, 0, {}}, HloUse{elem1, 0, {}}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(sum).live_out_of_module());
}
199
TEST_P(HloDataflowAnalysisTest, NestedTuple) {
  // Dataflow through a tuple nested inside another tuple:
  //   inner = (c1, c2); outer = (inner, inner, c1);
  //   root  = outer{1}{0}
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* outer = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({inner, inner, c1}));
  HloInstruction* outer_elem1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner->shape(), outer, 1));
  HloInstruction* result = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_elem1, 0));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool use_ssa = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(use_ssa);

  EXPECT_EQ(analysis.values().size(), 4);

  // c1 is visible at every place it was copied into a tuple and at both GTEs
  // on the path to the root.
  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  EXPECT_THAT(
      c1_value.positions(),
      UnorderedElementsAre(HloPosition{c1, {}}, HloPosition{inner, {0}},
                           HloPosition{outer, {0, 0}},
                           HloPosition{outer, {1, 0}}, HloPosition{outer, {2}},
                           HloPosition{outer_elem1, {0}},
                           HloPosition{result, {}}));
  // The only use of c1 is the root GTE (through element {0} of its operand);
  // c2 has no uses at all.
  EXPECT_THAT(c1_value.uses(), UnorderedElementsAre(HloUse{result, 0, {0}}));
  const HloValue& c2_value = analysis.GetValueDefinedAt(c2);
  EXPECT_TRUE(c2_value.uses().empty());

  // The top-level value of each tuple is used by the GTE that reads from it.
  const HloValue& inner_value = analysis.GetValueDefinedAt(inner, /*index=*/{});
  EXPECT_THAT(inner_value.uses(), UnorderedElementsAre(HloUse{result, 0, {}}));
  const HloValue& outer_value = analysis.GetValueDefinedAt(outer, /*index=*/{});
  EXPECT_THAT(outer_value.uses(),
              UnorderedElementsAre(HloUse{outer_elem1, 0, {}}));

  // Only c1 reaches the module output.
  EXPECT_TRUE(c1_value.live_out_of_module());
  EXPECT_FALSE(c2_value.live_out_of_module());
  EXPECT_FALSE(inner_value.live_out_of_module());
  EXPECT_FALSE(outer_value.live_out_of_module());
}
250
TEST_P(HloDataflowAnalysisTest, SingleCall) {
  // One call site of a subcomputation that adds its two array-shaped
  // parameters.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* callee_p0 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* callee_p1 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* callee_add =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, callee_p0, callee_p1));
  HloComputation* callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* call_site = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, callee));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool use_ssa = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(use_ssa);

  EXPECT_EQ(analysis.values().size(), 3);

  // Values are defined by the constants and the add only. The callee's
  // parameters and the call instruction just forward values defined elsewhere.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c2));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p0));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(callee_add));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(call_site));

  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  const HloValue& c2_value = analysis.GetValueDefinedAt(c2);
  const HloValue& add_value = analysis.GetValueDefinedAt(callee_add);

  // Arguments flow into the parameters; the callee's add flows out of the
  // call.
  EXPECT_EQ(analysis.GetUniqueValueAt(callee_p0), c1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(callee_p1), c2_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(call_site), add_value);

  // Each constant is used both as a call operand and, via the parameter, as an
  // operand of the add.
  EXPECT_THAT(c1_value.uses(),
              UnorderedElementsAre(HloUse{call_site, 0, {}},
                                   HloUse{callee_add, 0, {}}));
  EXPECT_THAT(c2_value.uses(),
              UnorderedElementsAre(HloUse{call_site, 1, {}},
                                   HloUse{callee_add, 1, {}}));

  EXPECT_TRUE(add_value.live_out_of_module());
}
301
TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
  // The same subcomputation is called from two sites with identical operands,
  // and the two call results are subtracted.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* callee_p0 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* callee_p1 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* callee_add =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, callee_p0, callee_p1));
  HloComputation* callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* first_call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, callee));
  HloInstruction* second_call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, callee));
  HloInstruction* difference =
      entry_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kSubtract, first_call, second_call));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool use_ssa = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(use_ssa);

  EXPECT_EQ(analysis.values().size(), 4);

  // Same definition sites as with a single call, plus the subtract.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c2));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p0));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(callee_add));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(first_call));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(second_call));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(difference));

  // Each constant is an operand of both calls and of the callee's add.
  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  EXPECT_THAT(c1_value.uses(),
              UnorderedElementsAre(HloUse{first_call, 0, {}},
                                   HloUse{second_call, 0, {}},
                                   HloUse{callee_add, 0, {}}));
  const HloValue& c2_value = analysis.GetValueDefinedAt(c2);
  EXPECT_THAT(c2_value.uses(),
              UnorderedElementsAre(HloUse{first_call, 1, {}},
                                   HloUse{second_call, 1, {}},
                                   HloUse{callee_add, 1, {}}));
  // Both calls forward the callee's add, so that one value feeds both operands
  // of the subtract.
  const HloValue& add_value = analysis.GetValueDefinedAt(callee_add);
  EXPECT_THAT(add_value.uses(),
              UnorderedElementsAre(HloUse{difference, 0, {}},
                                   HloUse{difference, 1, {}}));

  EXPECT_FALSE(add_value.live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(difference).live_out_of_module());
}
356
TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
  // The same subcomputation is called twice; the second call takes the result
  // of the first call as its first operand, so the callee's first parameter
  // sees two different values.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* callee_p0 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* callee_p1 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* callee_add =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, callee_p0, callee_p1));
  HloComputation* callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* first_call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, callee));
  HloInstruction* second_call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {first_call, c2}, callee));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool use_ssa = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(use_ssa);

  // Neither call nor the first parameter defines a value of its own.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(first_call));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(second_call));

  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p0));

  // param0 may hold constant1 (first call) or the add's result (second call);
  // param1 only ever holds constant2.
  EXPECT_THAT(HloValuesAt(callee_p0),
              UnorderedElementsAre(analysis.GetValueDefinedAt(c1),
                                   analysis.GetValueDefinedAt(callee_add)));
  EXPECT_THAT(HloValuesAt(callee_p1),
              UnorderedElementsAre(analysis.GetValueDefinedAt(c2)));

  EXPECT_TRUE(analysis.GetValueDefinedAt(callee_add).live_out_of_module());
}
397
TEST_P(HloDataflowAnalysisTest, NestedCalls) {
  // Calls nested two levels deep. HLO is:
  //
  // F32[] inner_computation(F32[] %param0, F32[] %param1):
  //   %add = Add(%param0, %param1)
  //
  // F32[] outer_computation(F32[] %param0, F32[] %param1):
  //   ;; The parameters are passed to the inner call in swapped order.
  //   %nested_call = Call(inner_computation, {%param1, %param0})
  //
  // F32[] entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %call = Call(outer_computation, {%constant1, %constant2})
  //
  HloComputation::Builder inner_builder("InnerComputation");
  HloInstruction* inner_p0 = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* inner_p1 = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* inner_add =
      inner_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, inner_p0, inner_p1));
  HloComputation* inner_computation =
      module_->AddEmbeddedComputation(inner_builder.Build());

  HloComputation::Builder outer_builder("OuterComputation");
  HloInstruction* outer_p0 = outer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* outer_p1 = outer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  // Operand order is deliberately reversed: (param1, param0).
  HloInstruction* inner_call =
      outer_builder.AddInstruction(HloInstruction::CreateCall(
          scalar_shape_, {outer_p1, outer_p0}, inner_computation));
  HloComputation* outer_computation =
      module_->AddEmbeddedComputation(outer_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* outer_call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, outer_computation));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool use_ssa = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(use_ssa);

  // Just three values exist; the remaining instructions forward the values of
  // their operands.
  EXPECT_EQ(analysis.values().size(), 3);

  // The operand indices at the nested call and at the add must reflect the
  // swap performed in the outer computation.
  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  EXPECT_THAT(c1_value.uses(),
              UnorderedElementsAre(HloUse{outer_call, 0, {}},
                                   HloUse{inner_call, 1, {}},
                                   HloUse{inner_add, 1, {}}));
  const HloValue& c2_value = analysis.GetValueDefinedAt(c2);
  EXPECT_THAT(c2_value.uses(),
              UnorderedElementsAre(HloUse{outer_call, 1, {}},
                                   HloUse{inner_call, 0, {}},
                                   HloUse{inner_add, 0, {}}));

  EXPECT_TRUE(analysis.GetValueDefinedAt(inner_add).live_out_of_module());
}
464
TEST_P(HloDataflowAnalysisTest, SingleWhile) {
  // A lone while instruction whose body forwards one tuple element unchanged.
  // HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   return While(%tuple, body, condition)
  //
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: element 0 is forwarded as-is, element 1 becomes the sum.
  HloComputation::Builder body_builder("body");
  HloInstruction* body_state = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* body_elem0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 0));
  HloInstruction* body_elem1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 1));
  HloInstruction* body_add =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, body_elem0, body_elem1));
  HloInstruction* body_result = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_elem0, body_add}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Condition: ignores its parameter and yields the constant "false".
  HloComputation::Builder cond_builder("condition");
  HloInstruction* cond_state = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* cond_false = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init_state =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* loop = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body,
                                  init_state));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool use_ssa = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(use_ssa);

  EXPECT_FALSE(analysis.GetValueDefinedAt(cond_false).live_out_of_module());

  if (!use_ssa) {
    // Without SSA form neither the while nor the subcomputation parameters
    // define any value.
    EXPECT_FALSE(analysis.ValueIsDefinedAt(loop, /*index=*/{0}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(loop, /*index=*/{1}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(body_state, /*index=*/{0}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(body_state, /*index=*/{1}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_state, /*index=*/{0}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_state, /*index=*/{1}));

    EXPECT_TRUE(analysis.GetValueDefinedAt(c1).live_out_of_module());
    EXPECT_TRUE(analysis.GetValueDefinedAt(body_add).live_out_of_module());
    return;
  }

  // SSA form. Element 0 is forwarded through the body untouched, so no phi is
  // created for it anywhere.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(loop, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(body_state, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_state, /*index=*/{0}));

  // Element 1 is rewritten by the body, so a phi is defined at the while and
  // at both subcomputation parameters.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(loop, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(loop, /*index=*/{1}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(body_state, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(body_state, /*index=*/{1}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_state, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(cond_state, /*index=*/{1}).is_phi());

  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  EXPECT_THAT(c1_value.uses(),
              UnorderedElementsAre(HloUse{body_add, 0, {}},
                                   HloUse{body_result, 0, {}},
                                   HloUse{loop, 0, {0}}));

  // constant1 is forwarded by the body and therefore leaves the module, as
  // does the phi at element 1 of the while.
  EXPECT_TRUE(c1_value.live_out_of_module());
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(loop, /*index=*/{1}).live_out_of_module());

  EXPECT_FALSE(analysis.GetValueDefinedAt(body_add).live_out_of_module());
}
565
TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
  // Three while instructions chained one after another, all sharing one body
  // and one condition. The body forwards tuple element 0 unchanged. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   %while0 = While(%tuple, body, condition)
  //   %while1 = While(%while0, body, condition)
  //   return While(%while1, body, condition)
  //
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: element 0 is forwarded as-is, element 1 becomes the sum.
  HloComputation::Builder body_builder("body");
  HloInstruction* body_state = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* body_elem0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 0));
  HloInstruction* body_elem1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 1));
  HloInstruction* body_add =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, body_elem0, body_elem1));
  body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_elem0, body_add}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Condition: a tuple parameter followed by the constant "false".
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init_state =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* first_loop = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body,
                                  init_state));
  HloInstruction* second_loop = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body,
                                  first_loop));
  HloInstruction* third_loop = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body,
                                  second_loop));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool use_ssa = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(use_ssa);

  // Element 0 is forwarded through every while instruction and out of the
  // module, so it is always constant1's value.
  EXPECT_EQ(analysis.GetUniqueValueAt(first_loop, /*index=*/{0}),
            analysis.GetValueDefinedAt(c1));
  EXPECT_EQ(analysis.GetUniqueValueAt(second_loop, /*index=*/{0}),
            analysis.GetValueDefinedAt(c1));
  EXPECT_EQ(analysis.GetUniqueValueAt(third_loop, /*index=*/{0}),
            analysis.GetValueDefinedAt(c1));
  EXPECT_TRUE(analysis.GetValueDefinedAt(c1).live_out_of_module());
}
639
TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
  // Test nested while instructions. The inner body passes through element 0 of
  // its parameter, and the outer body passes through element 1. HLO:
  //
  // inner_body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // outer_body((F32[], F32[]) %tuple_param):
  //   %negate = Negate(%tuple_param{0})
  //   %tuple = Tuple(%negate, %tuple_param{1})
  //   return While(%tuple, inner_body, condition)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   return While(%tuple, outer_body, condition)
  //
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // The same condition computation (constant false) is shared by both whiles.
  auto cond_builder = HloComputation::Builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Element 0 passes transparently through the body.
  auto inner_builder = HloComputation::Builder("inner_body");
  auto inner_param = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto inner_element_0 = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0));
  auto inner_element_1 = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
  auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1));
  inner_builder.AddInstruction(
      HloInstruction::CreateTuple({inner_element_0, add}));
  HloComputation* inner_body =
      module_->AddEmbeddedComputation(inner_builder.Build());

  // Element 1 passes transparently through the body.
  auto outer_builder = HloComputation::Builder("outer_body");
  auto outer_param = outer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto outer_element_0 = outer_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
  auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kNegate, outer_element_0));
  auto outer_element_1 = outer_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
  auto outer_tuple = outer_builder.AddInstruction(
      HloInstruction::CreateTuple({negate, outer_element_1}));
  auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, condition, inner_body, outer_tuple));
  HloComputation* outer_body =
      module_->AddEmbeddedComputation(outer_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  auto entry_while = builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // In both forms, the only value reaching element 0 of the inner body
  // parameter is the one defined by %negate.
  EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
  if (ssa_form) {
    // Element 1 of the inner body parameter is a phi value.
    EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());

    // Element 0 of the nested while is %negate: no value is defined at that
    // position, and the single value reaching it is the one defined by
    // %negate.
    EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
    // Element 1 is a phi value (join of %add and %constant2).
    EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());

    // Both elements of the entry while are phi values.
    EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{0}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(entry_while, /*index=*/{0}).is_phi());

    EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{1}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
  } else {
    // Without SSA form no phi values are expected; each position instead
    // carries the set of all values that may reach it.
    EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
                                     analysis.GetValueDefinedAt(constant2)));

    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{1}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
                                     analysis.GetValueDefinedAt(constant2)));

    EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{0}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(negate),
                                     analysis.GetValueDefinedAt(constant1)));
    EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{1}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
                                     analysis.GetValueDefinedAt(constant2)));
  }
}
759
TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
  // Exercises a while instruction whose body permutes its tuple parameter
  // elements. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   return Tuple(%tuple_param{1}, %tuple_param{0})
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   return While(%tuple, body, condition)
  //
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: emit the two parameter elements in swapped order.
  HloComputation::Builder body_builder("body");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  HloInstruction* first = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  HloInstruction* second = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  body_builder.AddInstruction(HloInstruction::CreateTuple({second, first}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Condition: constant false.
  HloComputation::Builder cond_builder("condition");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Entry: feed a tuple of two constants into the while.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* constant1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* xla_while = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(pair_shape, condition, body, init));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  const HloValue& value1 = analysis.GetValueDefinedAt(constant1);
  const HloValue& value2 = analysis.GetValueDefinedAt(constant2);

  if (!ssa_form) {
    // Elements 0 and 1 have both constants as reaching definitions.
    EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
                UnorderedElementsAre(value1, value2));
    EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
                UnorderedElementsAre(value1, value2));
    EXPECT_TRUE(value1.live_out_of_module());
    EXPECT_TRUE(value2.live_out_of_module());
    return;
  }

  // SSA form: both elements are phi values at the body parameter, at the
  // while instruction itself, and at the condition parameter.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());

  EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());

  EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());

  // The values defined at the while are live out; the constants are not.
  EXPECT_FALSE(value1.live_out_of_module());
  EXPECT_FALSE(value2.live_out_of_module());
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(xla_while, /*index=*/{}).live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
                  .live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
                  .live_out_of_module());
}
850
TEST_P(HloDataflowAnalysisTest, ArraySelect) {
  // A kSelect over array operands: the select defines its own value, which is
  // live out of the module, while the selected-from constants are not.
  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* on_true = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* on_false = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      scalar_shape_, HloOpcode::kSelect, pred, on_true, on_false));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());

  EXPECT_TRUE(analysis.ValueIsDefinedAt(select));
  EXPECT_FALSE(analysis.GetValueDefinedAt(on_true).live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(on_false).live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(select).live_out_of_module());
}
874
TEST_P(HloDataflowAnalysisTest, TupleSelect) {
  // Exercises kTupleSelect: only the top-level tuple value is defined by the
  // instruction, while the element values flow through from the operands.
  HloComputation::Builder builder(TestName());

  // Local helpers that keep the instruction-creation boilerplate in one place.
  auto add_scalar = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  auto add_unary_tuple = [&builder](HloInstruction* element) {
    return builder.AddInstruction(HloInstruction::CreateTuple({element}));
  };

  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* constant1 = add_scalar(1.0);
  HloInstruction* constant2 = add_scalar(2.0);
  HloInstruction* constant3 = add_scalar(3.0);
  HloInstruction* constant4 = add_scalar(4.0);
  HloInstruction* tuple1 = add_unary_tuple(constant1);
  HloInstruction* tuple2 = add_unary_tuple(constant2);
  HloInstruction* tuple3 = add_unary_tuple(constant3);
  HloInstruction* tuple4 = add_unary_tuple(constant4);

  const Shape tuple_shape = tuple1->shape();
  auto add_tuple_select = [&](HloInstruction* on_true,
                              HloInstruction* on_false) {
    return builder.AddInstruction(HloInstruction::CreateTernary(
        tuple_shape, HloOpcode::kTupleSelect, pred, on_true, on_false));
  };
  HloInstruction* select11 = add_tuple_select(tuple1, tuple1);
  HloInstruction* select12 = add_tuple_select(tuple1, tuple2);
  HloInstruction* select34 = add_tuple_select(tuple3, tuple4);
  HloInstruction* select1234 = add_tuple_select(select12, select34);

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());

  const HloValue& value1 = analysis.GetValueDefinedAt(constant1);
  const HloValue& value2 = analysis.GetValueDefinedAt(constant2);
  const HloValue& value3 = analysis.GetValueDefinedAt(constant3);
  const HloValue& value4 = analysis.GetValueDefinedAt(constant4);

  // Top-level value is always defined by a kTupleSelect.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(select11));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(select12));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(select34));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(select1234));

  // The element at index {0} is never defined by the select itself.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(select11, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(select12, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(select34, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(select1234, /*index=*/{0}));

  // Instead, element {0} holds the union of the operands' element values.
  EXPECT_THAT(HloValuesAt(select11, /*index=*/{0}),
              UnorderedElementsAre(value1));
  EXPECT_THAT(HloValuesAt(select12, /*index=*/{0}),
              UnorderedElementsAre(value1, value2));
  EXPECT_THAT(HloValuesAt(select34, /*index=*/{0}),
              UnorderedElementsAre(value3, value4));
  EXPECT_THAT(HloValuesAt(select1234, /*index=*/{0}),
              UnorderedElementsAre(value1, value2, value3, value4));

  // The top-level value of tuple1 is used by every select it feeds.
  EXPECT_THAT(
      analysis.GetValueDefinedAt(tuple1, /*index=*/{}).uses(),
      UnorderedElementsAre(HloUse{select11, 1, {}}, HloUse{select11, 2, {}},
                           HloUse{select12, 1, {}}));

  // The constant1 and constant2 values just pass through the selects and are
  // not used except at the root. They are live out however.
  EXPECT_THAT(value1.uses(), UnorderedElementsAre(HloUse{select1234, 1, {0}}));
  EXPECT_THAT(value2.uses(), UnorderedElementsAre(HloUse{select1234, 1, {0}}));
  EXPECT_TRUE(value1.live_out_of_module());
  EXPECT_TRUE(value2.live_out_of_module());
}
951
TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
  // Exercises kTupleSelect on operands that are nested tuples of the form
  // (scalar, (scalar, scalar)).
  HloComputation::Builder builder(TestName());

  // Adds an F32 scalar constant to the entry computation.
  auto add_scalar = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };

  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* constant1 = add_scalar(1.0);
  HloInstruction* constant2 = add_scalar(2.0);
  HloInstruction* constant3 = add_scalar(3.0);
  HloInstruction* constant4 = add_scalar(4.0);
  HloInstruction* constant5 = add_scalar(5.0);

  // tuple1 = (constant1, (constant2, constant3))
  HloInstruction* inner_tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant2, constant3}));
  HloInstruction* tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, inner_tuple1}));
  // tuple2 = (constant4, (constant5, constant3)); constant3 is shared.
  HloInstruction* inner_tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant5, constant3}));
  HloInstruction* tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant4, inner_tuple2}));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());

  // The select defines its top-level value.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(select));

  // Every nested position holds the union of what the two operands hold at
  // that position; {1, 1} collapses to a single value because both operands
  // place constant3 there.
  EXPECT_THAT(HloValuesAt(select, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
                                   analysis.GetValueDefinedAt(constant4)));
  EXPECT_THAT(HloValuesAt(select, /*index=*/{1}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(inner_tuple1),
                                   analysis.GetValueDefinedAt(inner_tuple2)));
  EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(constant2),
                                   analysis.GetValueDefinedAt(constant5)));
  EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 1}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(constant3)));
}
998
TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
  // Exercises a tuple-shaped kTupleSelect whose element feeds a kWhile
  // instruction. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %constant3 = Constant(3.0)
  //   %tuple1 = Tuple(%constant1)
  //   %tuple2 = Tuple(%constant2)
  //   %select = Select(%tuple1, %tuple2)
  //   %gte = GetTupleElement(%select, 0)
  //   %tuple = Tuple(%gte, %constant3)
  //   return While(%tuple, body, condition)
  //
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: element 0 passes through unchanged, element 1 becomes the sum.
  HloComputation::Builder body_builder("body");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  HloInstruction* lhs = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  HloInstruction* rhs = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  HloInstruction* add = body_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, lhs, rhs));
  body_builder.AddInstruction(HloInstruction::CreateTuple({lhs, add}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Condition: constant false.
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Entry computation.
  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  HloInstruction* tuple1 =
      builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
  HloInstruction* tuple2 =
      builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0));
  HloInstruction* init =
      builder.AddInstruction(HloInstruction::CreateTuple({gte, constant3}));
  HloInstruction* xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(init->shape(), condition, body, init));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  const HloValue& value1 = analysis.GetValueDefinedAt(constant1);
  const HloValue& value2 = analysis.GetValueDefinedAt(constant2);
  const HloValue& value3 = analysis.GetValueDefinedAt(constant3);

  if (!ssa_form) {
    // Both selectable constants reach the gte and element 0 of the while;
    // element 1 is reached by the body's %add and by constant3.
    EXPECT_THAT(HloValuesAt(gte), UnorderedElementsAre(value1, value2));
    EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
                UnorderedElementsAre(value1, value2));
    EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(add), value3));
    EXPECT_TRUE(value1.live_out_of_module());
    EXPECT_TRUE(value2.live_out_of_module());
    EXPECT_TRUE(value3.live_out_of_module());
    return;
  }

  // SSA form: both elements of the while are phi values.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());

  // The tuple select does not define a value for its element.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(select, /*index=*/{0}));

  // None of the constants are live out; the value at element 1 of the while
  // is.
  EXPECT_FALSE(value1.live_out_of_module());
  EXPECT_FALSE(value2.live_out_of_module());
  EXPECT_FALSE(value3.live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
                  .live_out_of_module());
}
1102
TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) {
  // Exercises the bitcast_defines_value flag of the dataflow analysis on a
  // module consisting of a constant feeding a bitcast.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kBitcast, constant));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();

  // Flag on: the bitcast defines its own value, and that value (not the
  // constant's) is live out of the module.
  {
    const HloDataflowAnalysis& with_flag =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/true);

    EXPECT_EQ(with_flag.values().size(), 2);

    EXPECT_TRUE(with_flag.ValueIsDefinedAt(constant));
    EXPECT_TRUE(with_flag.ValueIsDefinedAt(bitcast));
    EXPECT_FALSE(with_flag.GetValueDefinedAt(constant).live_out_of_module());
    EXPECT_TRUE(with_flag.GetValueDefinedAt(bitcast).live_out_of_module());
  }

  // Flag off: only the constant defines a value, and it is live out.
  {
    const HloDataflowAnalysis& without_flag =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/false);
    EXPECT_EQ(without_flag.values().size(), 1);

    EXPECT_TRUE(without_flag.ValueIsDefinedAt(constant));
    EXPECT_FALSE(without_flag.ValueIsDefinedAt(bitcast));
    EXPECT_TRUE(without_flag.GetValueDefinedAt(constant).live_out_of_module());
  }
}
1136
TEST_P(HloDataflowAnalysisTest, TupleCopy) {
  // Verifies that copying a tuple defines only the top-level value of the
  // result; the element values are those of the copied tuple's elements.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());

  EXPECT_EQ(analysis.values().size(), 4);

  // Values are defined at the parameters and at the top level of the tuple
  // and of the copy, but not at any tuple element position.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(param0));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(param1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(copy, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{1}));

  // The copy's elements carry the parameter values.
  EXPECT_THAT(HloValuesAt(copy, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(param0)));
  EXPECT_THAT(HloValuesAt(copy, /*index=*/{1}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(param1)));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
}
1172
TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
  // Verifies that a Send forwards its data operand to element {0} of its
  // output tuple rather than defining a new value there.
  HloComputation::Builder builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send = builder.AddInstruction(
      HloInstruction::CreateSend(data, token, /*channel_id=*/0));
  HloInstruction* send_done =
      builder.AddInstruction(HloInstruction::CreateSendDone(send));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());

  EXPECT_EQ(analysis.values().size(), 6);

  // Send defines its top-level tuple and elements {1} and {2}, but not {0}.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(data));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
  // Element {0} holds the value of the data operand.
  EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(data)));
}
1199
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
  // Verifies that a RecvDone forwards element {0} of its operand tuple to
  // element {0} of its own output instead of defining a new value there.
  HloComputation::Builder builder(TestName());
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv = builder.AddInstruction(
      HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0));
  HloInstruction* recv_done =
      builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());

  EXPECT_EQ(analysis.values().size(), 7);

  // Recv defines a value at every position of its output tuple.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2}));
  // RecvDone defines its top-level tuple and element {1}, but not {0}.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1}));
  // Element {0} of RecvDone holds Recv's element {0}, which is live out.
  EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
}
1228
TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
  // A straight-line chain of elementwise operations, checked under a
  // dependency-based ordering. No two distinct values should interfere.
  //
  // param --> negate -> exp -> log
  //
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate));
  HloInstruction* log = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // The parameter does not interfere with anything downstream of it.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, log));
  // Neither do the intermediate values, in either argument order.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, log));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, log));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, log, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, log, exp));

  // A value should interfere with itself.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, exp));
}
1264
TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
  // Two entry parameters that interfere with each other, checked under an
  // explicit sequential schedule.
  //
  // param0 --> negate ---------------\
  //                param1 --> exp --> add
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, vector_shape_, "param1"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate, exp));

  HloComputation* entry = module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  // Schedule param1 after negate, i.e. after the last use of param0.
  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param0, negate, param1, exp, add});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  // Entry parameters interfere as if they are defined simultaneously at
  // the very beginning, regardless of where the schedule places them.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param0, param1));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, add));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, param0));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, add));

  // Negate and exp still interfere.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));

  // But {negate, add} and {exp, add} don't interfere.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
}
1312
TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
  // Counterpart of MultipleEntryParameters_Sequential in which the parameter
  // belongs to a while body instead of the entry computation. The body is
  // scheduled in the following sequential order:
  //
  //  %constant = Constant(...)
  //  %exp = Exp(%constant)
  //  %body_param = Param(0)
  //  %add = Add(%exp, %body_param)  ;; Root of body
  //  %dead_constant = Constant(...)
  //  %dead_negate = Negate(%dead_constant)
  //
  // Both %constant and its sole user %exp are scheduled ahead of %body_param,
  // yet %constant and %body_param are still expected to interfere because the
  // parameter is treated as live into the while body.
  //
  // Likewise %dead_constant and %dead_negate are scheduled after the body
  // root %add, but %add is live out of the computation, so both are expected
  // to interfere with %add.

  // While body. Note the instructions are created in a different order than
  // the one they are scheduled in below.
  HloComputation::Builder while_body_builder(TestName());
  HloInstruction* body_param = while_body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "body_param"));
  HloInstruction* constant = while_body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* exp = while_body_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
  HloInstruction* add =
      while_body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, exp, body_param));
  HloInstruction* dead_constant = while_body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* dead_negate =
      while_body_builder.AddInstruction(HloInstruction::CreateUnary(
          scalar_shape_, HloOpcode::kNegate, dead_constant));
  // 'add' is made the root explicitly because it is not the last instruction
  // added to the builder.
  HloComputation* while_body = module_->AddEmbeddedComputation(
      while_body_builder.Build(/*root_instruction=*/add));

  // While condition: ignores its parameter and yields a constant.
  HloComputation::Builder condition_builder("condition");
  HloInstruction* cond_param = condition_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "cond_param"));
  HloInstruction* cond_constant = condition_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* while_condition =
      module_->AddEmbeddedComputation(condition_builder.Build());

  // Entry computation: a parameter fed straight into the while.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  HloInstruction* xla_while =
      entry_builder.AddInstruction(HloInstruction::CreateWhile(
          scalar_shape_, while_condition, while_body, param));

  HloComputation* entry = module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  const bool ssa_form = GetParam();
  RunAnalysis(ssa_form);

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param, xla_while});
  schedule.set_sequence(while_condition, {cond_param, cond_constant});
  // Schedule 'constant' and its user 'exp' ahead of 'body_param'.
  schedule.set_sequence(
      while_body,
      {constant, exp, body_param, add, dead_constant, dead_negate});
  TF_ASSERT_OK(schedule.Verify());

  SequentialHloOrdering ordering(schedule);

  // 'add' is live out of the body and therefore interferes with any later
  // instruction such as 'dead_constant' and 'dead_negate'.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_negate));

  // Everything below checks phi values defined by the body and condition
  // parameters; those only exist in the SSA form of the analysis.
  if (ssa_form) {
    // The schedule alone would suggest 'constant' and 'body_param' do not
    // interfere, but 'body_param' is live in and thus interferes with any
    // instruction scheduled before it in the computation (e.g. 'constant').
    EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, constant));
    EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, exp));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));

    // These values all end up in the same buffer:
    //   (1) the init value: 'param'
    //   (2) the body parameter: 'body_param'
    //   (3) the condition parameter: 'cond_param'
    //   (4) the root value of the while body: 'add'
    //   (5) the while value: 'xla_while'
    // No pair of them should interfere.
    EXPECT_FALSE(InstructionsMayInterfere(ordering, param, body_param));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, param, cond_param));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, param, xla_while));

    EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, cond_param));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, xla_while));

    EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, add));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, xla_while));

    EXPECT_FALSE(InstructionsMayInterfere(ordering, add, xla_while));
  }
}
1414
TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) {
  // Builds a chain with two elementwise ops followed by one non-elementwise
  // op. An elementwise op is expected not to interfere with its operand,
  // whereas the non-elementwise op is expected to interfere with its operand.
  //
  // param --> exp -> negate -> reverse
  //
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp));
  HloInstruction* reverse = entry_builder.AddInstruction(
      HloInstruction::CreateReverse(vector_shape_, negate, {0}));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // The parameter is expected not to interfere with any instruction in the
  // chain, including its direct (elementwise) user 'exp'.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, reverse));

  // 'negate' is elementwise: no interference with its operand 'exp'.
  // 'reverse' is non-elementwise: it interferes with its operand 'negate'.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, reverse));
}
1447
TEST_P(HloDataflowAnalysisTest, OverlappedValues) {
  // Checks that values which are live at the same time interfere (here: exp
  // and negate, which both feed add).
  //
  // param --> negate -> add
  //     \---> exp -----/
  //
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* add = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate,
                                   exp));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // 'param' has two users, so it is expected to interfere with each of them,
  // but not with 'add'.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));

  // 'negate' and 'exp' interfere with one another (checked in both argument
  // orders), while neither interferes with 'add'.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
}
1482
TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
  // Same graph as the OverlappedValues test, but analyzed with a sequential
  // ordering of the HLO instructions instead of a DependencyHloOrdering.
  //
  // param --> negate -> add
  //     \---> exp -----/
  //
  // Sequential order:
  //  param, negate, exp, add
  //
  // The negate/exp/add expectations match the OverlappedValues test. The one
  // difference is that 'param' is expected not to interfere with 'exp' here,
  // where 'exp' is scheduled after 'negate'.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* add = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate,
                                   exp));

  HloComputation* entry = module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param, negate, exp, add});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));

  // 'negate' and 'exp' interfere with one another (checked in both argument
  // orders), while neither interferes with 'add'.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
}
1525
TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
  // Exercises MayInterfere() across computation boundaries, i.e. for values
  // that live in different computations.
  //
  // embedded_computation:
  //   %embedded_param = Param(0)
  //   %embedded_log = Log(%embedded_param)
  //
  // entry computation:
  //   %param = Param(0)
  //   %negate = Negate(%param)
  //   %exp = Exp(%param)
  //   %call = Call(embedded_computation, {%exp})
  //   %add = Add(%negate, %call)
  //
  // %negate is live across the call, so it should interfere with every value
  // inside the embedded computation.
  HloComputation::Builder callee_builder(TestName() + "_embedded");
  HloInstruction* embedded_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "embedded_param"));
  HloInstruction* embedded_log = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog,
                                  embedded_param));
  HloComputation* embedded_computation =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation));
  // The add is the entry root; it consumes 'negate' after the call, which is
  // what keeps 'negate' live across the call.
  entry_builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, call));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // The call is the only user of 'exp', so 'exp' should not interfere with
  // values defined inside the embedded computation.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log));

  // 'negate' is live across the call and therefore should interfere with
  // values defined inside the embedded computation.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}
1577
TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
  // Conditional whose true and false branches are both identity
  // computations.
  //
  // true_computation(F32[] %true_param):
  //   return %true_param
  //
  // false_computation(F32[] %false_param):
  //   return %false_param
  //
  // entry:
  //   %pred = Constant(true)
  //   %constant1 = Constant(56.0)
  //   %constant2 = Constant(12.0)
  //   return Conditional(%pred, %constant1, true_computation,
  //                      %constant2, false_computation)

  HloComputation::Builder true_branch_builder(TestName() + "_true");
  HloInstruction* true_param = true_branch_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "true_param"));
  HloComputation* true_computation =
      module_->AddEmbeddedComputation(true_branch_builder.Build());

  HloComputation::Builder false_branch_builder(TestName() + "_false");
  HloInstruction* false_param = false_branch_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "false_param"));
  HloComputation* false_computation =
      module_->AddEmbeddedComputation(false_branch_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* pred = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* constant1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
  HloInstruction* constant2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
  HloInstruction* conditional =
      entry_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, pred, constant1, true_computation, constant2,
          false_computation));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // The three entry constants each define a value.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));

  // Branch parameters define nothing; they forward the matching conditional
  // operand.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));

  EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
            analysis.GetValueDefinedAt(constant2));

  // Each constant is used exactly once, as the corresponding operand of the
  // conditional.
  EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
              ElementsAre(HloUse{conditional, 0, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
              ElementsAre(HloUse{conditional, 1, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
              ElementsAre(HloUse{conditional, 2, {}}));

  if (!ssa_form) {
    // Without SSA form the conditional defines no value of its own; its
    // value set is the union of what the two branches return.
    EXPECT_EQ(analysis.values().size(), 3);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
    EXPECT_THAT(HloValuesAt(conditional),
                UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
                                     analysis.GetValueDefinedAt(constant2)));
  } else {
    // In SSA form the conditional additionally defines a value.
    EXPECT_EQ(analysis.values().size(), 4);
    EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
  }
}
1652
TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
  // Conditional whose true and false computations both take the same tuple
  // operand.
  //
  // true_computation((F32[], F32[]) %true_param):
  //   %true_x = GetTupleElement(%true_param, 0)
  //   %true_y = GetTupleElement(%true_param, 1)
  //   return Add(%true_x, %true_y)
  //
  // false_computation((F32[], F32[]) %false_param):
  //   %false_x = GetTupleElement(%false_param, 0)
  //   %false_y = GetTupleElement(%false_param, 1)
  //   return Subtract(%false_x, %false_y)
  //
  // entry:
  //   %pred = Constant(true)
  //   %constant1 = Constant(56.0)
  //   %constant2 = Constant(12.0)
  //   %tuple_operand = Tuple(%constant1, %constant2)
  //   return Conditional(%pred, %tuple_operand, true_computation,
  //                      %tuple_operand, false_computation)

  HloComputation::Builder true_branch_builder(TestName() + "_true");
  HloInstruction* true_param = true_branch_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "true_param"));
  HloInstruction* true_x = true_branch_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0));
  HloInstruction* true_y = true_branch_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1));
  HloInstruction* add = true_branch_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, true_x,
                                   true_y));
  HloComputation* true_computation =
      module_->AddEmbeddedComputation(true_branch_builder.Build());

  HloComputation::Builder false_branch_builder(TestName() + "_false");
  HloInstruction* false_param = false_branch_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "false_param"));
  HloInstruction* false_x = false_branch_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0));
  HloInstruction* false_y = false_branch_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1));
  HloInstruction* sub = false_branch_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kSubtract,
                                   false_x, false_y));
  HloComputation* false_computation =
      module_->AddEmbeddedComputation(false_branch_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* pred = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* constant1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
  HloInstruction* constant2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
  HloInstruction* tuple_operand = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* conditional =
      entry_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, pred, tuple_operand, true_computation, tuple_operand,
          false_computation));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Instructions that define a value: the entry constants, the tuple, and the
  // arithmetic roots of the two branches.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));

  // Branch parameters and the get-tuple-elements merely forward values.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y));

  EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
            analysis.GetValueDefinedAt(tuple_operand));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
            analysis.GetValueDefinedAt(tuple_operand));
  EXPECT_EQ(analysis.GetUniqueValueAt(true_x),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(true_y),
            analysis.GetValueDefinedAt(constant2));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_x),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_y),
            analysis.GetValueDefinedAt(constant2));

  // Uses reach through the conditional operands into both branches.
  EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
              ElementsAre(HloUse{conditional, 0, {}}));
  EXPECT_THAT(
      analysis.GetValueDefinedAt(constant1).uses(),
      UnorderedElementsAre(HloUse{conditional, 1, {0}},
                           HloUse{conditional, 2, {0}}, HloUse{add, 0, {}},
                           HloUse{sub, 0, {}}));
  EXPECT_THAT(
      analysis.GetValueDefinedAt(constant2).uses(),
      UnorderedElementsAre(HloUse{conditional, 1, {1}},
                           HloUse{conditional, 2, {1}}, HloUse{add, 1, {}},
                           HloUse{sub, 1, {}}));
  EXPECT_THAT(
      analysis.GetValueDefinedAt(tuple_operand).uses(),
      UnorderedElementsAre(HloUse{conditional, 1, {}},
                           HloUse{conditional, 2, {}}, HloUse{true_x, 0, {}},
                           HloUse{true_y, 0, {}}, HloUse{false_x, 0, {}},
                           HloUse{false_y, 0, {}}));

  if (!ssa_form) {
    // Without SSA form the conditional defines no value of its own; its
    // value set is the union of the two branch roots.
    EXPECT_EQ(analysis.values().size(), 6);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
    EXPECT_THAT(HloValuesAt(conditional),
                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
                                     analysis.GetValueDefinedAt(sub)));
  } else {
    // In SSA form the conditional additionally defines a value.
    EXPECT_EQ(analysis.values().size(), 7);
    EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
  }
}
1770
TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
  // A conditional whose true branch itself contains a conditional.
  //
  // computation1(F32[] %param1):
  //   %ceil = Ceil(%param1)
  //   return %ceil
  //
  // computation2(F32[] %param2):
  //   %floor = Floor(%param2)
  //   return %floor
  //
  // computation3(F32[] %param3):
  //   %negate = Negate(%param3)
  //   return %negate
  //
  // inner_conditional((PRED, F32[], F32[]) %param_cond):
  //   %pred_cond = GetTupleElement(%param_cond, 0)
  //   %true_operand_cond = GetTupleElement(%param_cond, 1)
  //   %false_operand_cond = GetTupleElement(%param_cond, 2)
  //   return Conditional(%pred_cond, %true_operand_cond, computation1,
  //                      %false_operand_cond, computation2)
  //
  // entry:
  //   %pred1 = Constant(true)
  //   %pred2 = Constant(false)
  //   %constant1 = Constant(1.1);
  //   %constant2 = Constant(2.2);
  //   %constant3 = Constant(3.3);
  //   return Conditional(%pred1, (%pred2, %constant1, %constant2),
  //                      inner_conditional, %constant3, computation3)

  HloComputation* computation1 = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kCeil));
  HloComputation* computation2 = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kFloor));
  HloComputation* computation3 = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kNegate));

  // The computation hosting the inner conditional.
  const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {});
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {scalar_bool_shape, scalar_shape_, scalar_shape_});
  HloComputation::Builder inner_builder(TestName() + "_inner_conditional");
  HloInstruction* param_cond = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond"));
  HloInstruction* pred_cond = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0));
  HloInstruction* true_operand_cond = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1));
  HloInstruction* false_operand_cond = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2));
  HloInstruction* inner_conditional =
      inner_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, pred_cond, true_operand_cond, computation1,
          false_operand_cond, computation2));
  HloComputation* inner_conditional_computation =
      module_->AddEmbeddedComputation(inner_builder.Build());

  // The entry computation hosting the outer conditional.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* pred1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* pred2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* constant1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* constant2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.2f)));
  HloInstruction* constant3 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.3f)));
  HloInstruction* tuple_operand = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({pred2, constant1, constant2}));
  HloInstruction* conditional =
      entry_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, pred1, tuple_operand, inner_conditional_computation,
          constant3, computation3));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Entry constants, the tuple, and the root of each leaf computation all
  // define values.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction()));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction()));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction()));

  // Leaf computation parameters forward the constant that reaches them.
  HloInstruction* computation1_param = computation1->parameter_instruction(0);
  HloInstruction* computation2_param = computation2->parameter_instruction(0);
  HloInstruction* computation3_param = computation3->parameter_instruction(0);
  EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param));
  EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param),
            analysis.GetValueDefinedAt(constant2));
  EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param),
            analysis.GetValueDefinedAt(constant3));

  // Likewise for the parameter and get-tuple-elements of the inner
  // conditional computation.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond));
  EXPECT_EQ(analysis.GetUniqueValueAt(param_cond),
            analysis.GetValueDefinedAt(tuple_operand));
  EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond),
            analysis.GetValueDefinedAt(pred2));
  EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond),
            analysis.GetValueDefinedAt(constant2));

  if (!ssa_form) {
    // Without SSA form neither conditional defines a value; each one exposes
    // the union of the root values of the computations it may execute.
    EXPECT_EQ(analysis.values().size(), 9);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));

    const HloValue& root1_value =
        analysis.GetValueDefinedAt(computation1->root_instruction());
    const HloValue& root2_value =
        analysis.GetValueDefinedAt(computation2->root_instruction());
    const HloValue& root3_value =
        analysis.GetValueDefinedAt(computation3->root_instruction());
    EXPECT_THAT(HloValuesAt(inner_conditional),
                UnorderedElementsAre(root1_value, root2_value));
    EXPECT_THAT(HloValuesAt(conditional),
                UnorderedElementsAre(root1_value, root2_value, root3_value));
  } else {
    // In SSA form each of the two conditionals additionally defines a value.
    EXPECT_EQ(analysis.values().size(), 11);
    EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional));
    EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
  }
}
1908
TEST_P(HloDataflowAnalysisTest, AddDependency) {
  // Parses a module whose root is an add-dependency instruction and checks
  // which instructions define values.
  // NOTE(review): although this is a TEST_P, GetParam() is never consulted
  // here, so both instantiations run the analysis with its default options.
  const string hlo_text = R"(
HloModule AddDependency
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
  %p = f32[3] parameter(0)
  %token0 = token[] after-all()
  ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseHloString(hlo_text, GetModuleConfigForTest()));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> analysis,
                          HloDataflowAnalysis::Run(*parsed_module));
  const HloInstruction* root =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kAddDependency);

  // Only the after-all and the parameter should define a value; the
  // add-dependency should not.
  EXPECT_EQ(analysis->values().size(), 2);
  EXPECT_FALSE(analysis->ValueIsDefinedAt(root));
}
1932
1933 INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation,
1934 HloDataflowAnalysisTest,
1935 ::testing::Values(false, true));
1936
1937 class HloDataflowAnalysisTestBase : public HloTestBase {
1938 protected:
BuildModule(std::unique_ptr<HloComputation> computation)1939 void BuildModule(std::unique_ptr<HloComputation> computation) {
1940 module_ = CreateNewUnverifiedModule();
1941 computation_ = module_->AddEntryComputation(std::move(computation));
1942 }
1943
RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction & fusion_can_share_buffer=nullptr)1944 void RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction&
1945 fusion_can_share_buffer = nullptr) {
1946 CHECK_NOTNULL(module_.get());
1947 dataflow_analysis_ =
1948 HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false,
1949 /*bitcast_defines_value=*/false,
1950 fusion_can_share_buffer)
1951 .ConsumeValueOrDie();
1952 }
1953
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)1954 void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
1955 BuildModule(std::move(computation));
1956 RunAnalysis();
1957 }
1958
1959 std::unique_ptr<HloModule> module_;
1960 HloComputation* computation_ = nullptr;
1961 std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
1962 };
1963
1964 class DoesNotUseOperandBufferTest : public HloDataflowAnalysisTestBase {};
1965
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
  // Entry computation: Add(GetTupleElement(tuple, 0),
  //                        GetTupleElement(tuple, 1))
  // where `tuple` is a two-element tuple parameter.
  HloComputation::Builder entry_builder(TestName());

  const Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* tuple_param =
      entry_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple"));
  HloInstruction* gte0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, tuple_param, 0));
  HloInstruction* gte1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));

  BuildModuleAndRunAnalysis(entry_builder.Build());

  // A GetTupleElement only touches the top-level buffer of its operand: the
  // element buffers ({0}, {1}) are reported as unused, the top-level buffer
  // ({}) is not.
  EXPECT_TRUE(
      dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, gte0));
  EXPECT_TRUE(
      dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, gte1));
  EXPECT_FALSE(
      dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {}, gte0));
  EXPECT_FALSE(
      dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {}, gte1));
}
1988
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
  // Fuses a DynamicUpdateSlice of tuple element 1 (together with its
  // operands) and checks which elements of the tuple parameter the resulting
  // fusion instruction uses.
  HloComputation::Builder entry_builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* tuple_param =
      entry_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
  HloInstruction* gte0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
  HloInstruction* gte1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));

  // DynamicUpdateSlice that writes into tuple element 1.
  HloInstruction* starts = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* update =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dynamic_update_slice =
      entry_builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, gte1, update,
          std::initializer_list<HloInstruction*>({starts})));
  // Root: a tuple of the untouched element 0 and the updated element 1.
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({gte0, dynamic_update_slice}));

  BuildModule(entry_builder.Build());
  // The fusion must be created before the analysis runs so that the analysis
  // sees the fused module.
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dynamic_update_slice, starts, update, gte1},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // The fusion never touches tuple element 0, but it does use element 1.
  EXPECT_TRUE(
      dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
  EXPECT_FALSE(
      dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
}
2022
2023 // Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the
2024 // parameter tuple.
TEST_F(DoesNotUseOperandBufferTest,IndirectUses)2025 TEST_F(DoesNotUseOperandBufferTest, IndirectUses) {
2026 auto builder = HloComputation::Builder(TestName());
2027
2028 Shape data_shape = ShapeUtil::MakeShape(F32, {8});
2029 auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
2030 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
2031 auto t0 = builder.AddInstruction(
2032 HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
2033 auto t1 = builder.AddInstruction(
2034 HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));
2035 // Swap the tuple elements.
2036 auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0}));
2037
2038 auto gte0 = builder.AddInstruction(
2039 HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
2040 auto gte1 = builder.AddInstruction(
2041 HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
2042
2043 // Create a DynamicUpdateSlice instruction of tuple element 1.
2044 auto starts = builder.AddInstruction(
2045 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
2046 auto update = builder.AddInstruction(HloInstruction::CreateConstant(
2047 LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
2048 auto dynamic_update_slice =
2049 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2050 data_shape, gte1, update,
2051 std::initializer_list<HloInstruction*>({starts})));
2052 builder.AddInstruction(
2053 HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
2054
2055 BuildModule(builder.Build());
2056 auto fusion = computation_->CreateFusionInstruction(
2057 {dynamic_update_slice, starts, update, gte1},
2058 HloInstruction::FusionKind::kLoop);
2059 RunAnalysis();
2060
2061 // The fusion instruction never uses tuple element 0, but does use element 1.
2062 EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
2063 EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
2064 // The same holds for the parameter tuple, except that the tuple elements are
2065 // swapped in 'tuple'.
2066 EXPECT_TRUE(
2067 dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
2068 EXPECT_FALSE(
2069 dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
2070 }
2071
// Fixture for the tests below, each of which builds a small HLO module, runs
// the dataflow analysis and checks the result of
// dataflow_analysis_->CanShareOperandBufferWithUser() for selected
// operand/user pairs. It adds nothing to HloDataflowAnalysisTestBase.
class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {};
2073
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
  // Chain of two unary elementwise ops (exp, then log) whose operand and
  // result shapes are identical.
  HloComputation::Builder builder(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  HloInstruction* exp_op = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, input));
  HloInstruction* log_op = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kLog, exp_op));

  BuildModuleAndRunAnalysis(builder.Build());

  // Each op is expected to be able to reuse its operand's buffer.
  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(input, {}, exp_op, {}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(exp_op, {},
                                                                log_op, {}));
}
2092
TEST_F(CanShareOperandBufferWithUserTest,
       NonElementwiseLoopFusionCantAliasOperandBuffer) {
  // Loop fusion of negate -> reverse. Reverse is not elementwise, so the
  // fusion is expected NOT to share the parameter's buffer.
  HloComputation::Builder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "param0"));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kNegate, input));
  HloInstruction* reversed = builder.AddInstruction(
      HloInstruction::CreateReverse(matrix_shape, negated, {0, 1}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {reversed, negated}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(input, {}, fusion, {}));
}
2115
TEST_F(CanShareOperandBufferWithUserTest,
       MultiOutputFusionCanAliasOperandBuffer) {
  // Builds a loop fusion whose root is a tuple of two copies (one per
  // parameter) and checks that either parameter may share its buffer with
  // either output of the fusion.
  auto builder = HloComputation::Builder(TestName());

  Shape in_shape = ShapeUtil::MakeShape(F32, {8});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, in_shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, in_shape, "param1"));

  auto copy0 = builder.AddInstruction(
      HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0));
  auto copy1 = builder.AddInstruction(
      HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1));

  // Note the order: output 0 is the copy of param1, output 1 the copy of
  // param0.
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0}));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // All four operand/output combinations are expected to be shareable.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {0}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {1}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
                                                                fusion, {0}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
                                                                fusion, {1}));
}
2150
// NOTE: despite the "Cant" in its name, this test expects that a loop fusion
// made up solely of elementwise ops (negate -> exp) CAN share the buffer of
// its operand. The name is kept unchanged so existing test filters still
// match.
TEST_F(CanShareOperandBufferWithUserTest,
       ElementwiseLoopFusionCantAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a scalar, so it has no dimensions to map onto the result: the
  // broadcast dimension list must be empty.
  auto operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand));

  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {exp, neg}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
                                                                fusion, {}));
}
2175
TEST_F(CanShareOperandBufferWithUserTest,
       CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex) {
  // Loop fusion in which a DynamicUpdateSlice writes back, at the same start
  // indices, the slice that a DynamicSlice read from the same parameter. The
  // fusion is expected to be able to share the parameter's buffer.
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
  Shape slice_shape = ShapeUtil::MakeShape(F32, {1, 2});

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "param0"));
  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(0)));
  // One slice size per operand dimension, matching 'slice_shape'.
  auto ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, param, {zero, zero}, {1, 2}));

  auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      data_shape, param, ds, {zero, zero}));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {dus, ds, zero}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {}));
}
2200
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithDifferentIndices) {
  // The fused dynamic-slice reads at (p1, p2, p3) while the fused
  // dynamic-update-slice writes at (p1, p3, p2); with mismatched indices the
  // fusion is expected NOT to share the buffer of operand 0.
  const char* const hlo_text = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
      ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p3, p2)
    }

    ENTRY test {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* fusion_root =
      module_->entry_computation()->root_instruction();
  HloInstruction* data_param =
      module_->entry_computation()->parameter_instruction(0);

  RunAnalysis();
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      data_param, {}, fusion_root, {}));
}
2230
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) {
  // The fused dynamic-slice and dynamic-update-slice both use the start
  // indices (p1, p2, p3); the fusion is expected to be able to share the
  // buffer of operand 0.
  const char* const hlo_text = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
      ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p2, p3)
    }

    ENTRY test {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* fusion_root =
      module_->entry_computation()->root_instruction();
  HloInstruction* data_param =
      module_->entry_computation()->parameter_instruction(0);

  RunAnalysis();
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      data_param, {}, fusion_root, {}));
}
2260
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
  // An equality compare takes F32 operands and yields a PRED result, so the
  // result shape differs from the operand shape and neither operand is
  // expected to be shareable.
  HloComputation::Builder builder(TestName());

  const Shape operand_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape result_shape = ShapeUtil::MakeShape(PRED, {8});
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param0"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, operand_shape, "param1"));
  HloInstruction* compare = builder.AddInstruction(HloInstruction::CreateCompare(
      result_shape, lhs, rhs, ComparisonDirection::kEq));

  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(lhs, {}, compare, {}));
  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(rhs, {}, compare, {}));
}
2280
TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
  // param -> exp -> copy: both the exp and the copy are expected to be able to
  // share their operand's buffer.
  HloComputation::Builder builder(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  HloInstruction* exp_op = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, input));
  HloInstruction* copy_op = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kCopy, exp_op));

  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(input, {}, exp_op, {}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(exp_op, {},
                                                                copy_op, {}));
}
2299
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
  // Loop fusion wrapping a DynamicUpdateSlice of tuple element 1; checks which
  // tuple element's buffer the fusion may share.
  HloComputation::Builder builder(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({vec_shape, vec_shape});
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, tuple, 0));
  HloInstruction* elem1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, tuple, 1));

  // DynamicUpdateSlice that overwrites part of tuple element 1.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* patch = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          vec_shape, elem1, patch,
          std::initializer_list<HloInstruction*>({start_index})));
  builder.AddInstruction(HloInstruction::CreateTuple({elem0, dus}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dus, start_index, patch, elem1}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // Only tuple element 1 (the one being updated) is shareable.
  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {0}, fusion, {}));
  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {1}, fusion, {}));
}
2335
TEST_F(CanShareOperandBufferWithUserTest,
       FusedDynamicUpdateSliceWithConvertCanShare) {
  // Loop fusion of convert(F32->BF16) -> DynamicUpdateSlice ->
  // convert(BF16->F32); the fusion is expected to be able to share the buffer
  // of the F32 tuple element feeding the first convert.
  HloComputation::Builder builder(TestName());

  const Shape f32_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape bf16_shape = ShapeUtil::MakeShape(BF16, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({f32_shape, f32_shape});
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
  HloInstruction* elem1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, tuple, 1));

  HloInstruction* to_bf16 = builder.AddInstruction(
      HloInstruction::CreateConvert(bf16_shape, elem1));

  // DynamicUpdateSlice on the BF16 version of tuple element 1.
  // NOTE(review): the update literal is F32 while the slice being updated is
  // BF16 -- confirm whether this mismatch is intentional.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* patch = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          bf16_shape, to_bf16, patch,
          std::initializer_list<HloInstruction*>({start_index})));

  HloInstruction* to_f32 =
      builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, dus));
  builder.AddInstruction(HloInstruction::CreateTuple({elem0, to_f32}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {to_f32, dus, start_index, patch, to_bf16},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(elem1, {}, fusion, {}));
}
2375
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
  // Unfused DynamicUpdateSlice: only the operand being updated is expected to
  // be shareable with the result.
  HloComputation::Builder builder(TestName());

  const Shape target_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape patch_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape index_shape = ShapeUtil::MakeShape(S32, {1});
  HloInstruction* target = builder.AddInstruction(
      HloInstruction::CreateParameter(0, target_shape, "data"));
  HloInstruction* patch = builder.AddInstruction(
      HloInstruction::CreateParameter(1, patch_shape, "update"));
  // NOTE(review): two S32[1] start indices are supplied for rank-1 data --
  // confirm this is the intended operand layout.
  HloInstruction* index0 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, index_shape, "start0"));
  HloInstruction* index1 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, index_shape, "start1"));

  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          target_shape, target, patch, {index0, index1}));

  BuildModuleAndRunAnalysis(builder.Build());

  // Shareable: the data operand. Not shareable: the update and the start
  // indices.
  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(target, {}, dus, {}));
  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(patch, {}, dus, {}));
  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(index0, {}, dus, {}));
  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(index1, {}, dus, {}));
}
2407
TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
  // A scatter is expected to be able to share the buffer of the operand it
  // scatters into, but not the buffers of its indices or updates.
  const char* hlo_text = R"(
    HloModule TensorFlowScatterV1

    update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
      lhs = s32[] parameter(0)
      ROOT rhs = s32[] parameter(1)
    }

    ENTRY main {
      operand = s32[3,3] parameter(0)
      indices = s32[2] parameter(1)
      updates = s32[2,3] parameter(2)
      ROOT scatter = s32[3,3] scatter(operand, indices, updates),
          to_apply=update_s32,
          update_window_dims={1},
          inserted_window_dims={0},
          scatter_dims_to_operand_dims={0},
          index_vector_dim=1
    }
  )";
  // Parse through the verifying helper, like the other HLO-text tests in this
  // fixture, so that a malformed module fails here rather than in the
  // analysis.
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
  computation_ = module_->entry_computation();
  RunAnalysis();

  HloInstruction* operand_param = computation_->parameter_instruction(0);
  HloInstruction* indices_param = computation_->parameter_instruction(1);
  HloInstruction* updates_param = computation_->parameter_instruction(2);
  HloInstruction* scatter = computation_->root_instruction();

  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      operand_param, {}, scatter, {}));
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      indices_param, {}, scatter, {}));
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      updates_param, {}, scatter, {}));
}
2445
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
  // A single-operand sort is expected to be able to share its operand's
  // buffer.
  HloComputation::Builder builder(TestName());
  module_ = CreateNewVerifiedModule();

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "keys"));
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * sort_op,
      MakeSortHlo(vec_shape, {keys}, -1, /*is_stable=*/false, &builder,
                  module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort_op, {}));
}
2463
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
  // Key/value sort with a tuple result: operand i is expected to be shareable
  // with tuple entry i only.
  HloComputation::Builder builder(TestName());
  module_ = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape values_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape sorted_shape =
      ShapeUtil::MakeTupleShape({keys_shape, values_shape});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values"));
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * sort_op,
      MakeSortHlo(sorted_shape, {keys, values}, 0, /*is_stable=*/false,
                  &builder, module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  // Keys <-> tuple entry 0.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(keys, {},
                                                                sort_op, {0}));
  // Values <-> tuple entry 1.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(values, {},
                                                                sort_op, {1}));
  // Crossing operands and tuple entries must not be reported as shareable.
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(keys, {},
                                                                 sort_op, {1}));
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(values, {},
                                                                 sort_op, {0}));
}
2495
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
  // Output fusion of dot -> add, where the add's other operand is produced
  // outside the fusion.
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  auto b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  auto dot = builder.AddInstruction(
      HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a scalar, so it has no dimensions to map onto the result: the
  // broadcast dimension list must be empty.
  auto add_operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, dot, add_operand));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {add, dot}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // Output fused dot add should be able to share buffer with 'add_operand'.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(add_operand, {},
                                                                fusion, {}));
}
2531
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
  // Output fusion of reverse -> add; the fusion's operand is consumed by the
  // non-elementwise reverse.
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a scalar, so it has no dimensions to map onto the result: the
  // broadcast dimension list must be empty.
  auto operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(data_shape, operand, {0, 1}));

  auto two = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {add, two, reverse}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // Output fused operand->reverse->add cannot alias operand buffer 'operand'.
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
                                                                 fusion, {}));
}
2559
TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
  // Runs the analysis with a caller-supplied fusion_can_share_buffer callback
  // that accepts only kLoop fusions, then queries a kInput fusion.
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a scalar, so it has no dimensions to map onto the result: the
  // broadcast dimension list must be empty.
  auto operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kMultiply, operand, operand));
  auto two = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {add, two, mul}, HloInstruction::FusionKind::kInput);
  RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion,
                                             const HloInstruction*) {
    return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop;
  });

  // The fusion under test is kInput, which the callback above rejects.
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
                                                                 fusion, {}));
}
2586
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
  // A while loop over a single array is expected to be able to share the
  // buffer of its init operand.
  const Shape loop_shape = ShapeUtil::MakeShape(F32, {8});

  // Condition computation: compares the loop state with itself.
  auto build_condition = [&loop_shape]() {
    HloComputation::Builder cond_builder(TestName() + ".Cond");
    HloInstruction* state = cond_builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_shape, "data"));
    cond_builder.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), state, state,
        ComparisonDirection::kEq));
    return cond_builder.Build();
  };

  // Body computation: adds the loop state to itself.
  auto build_body = [&loop_shape]() {
    HloComputation::Builder body_builder(TestName() + ".Body");
    HloInstruction* state = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_shape, "data"));
    body_builder.AddInstruction(HloInstruction::CreateBinary(
        loop_shape, HloOpcode::kAdd, state, state));
    return body_builder.Build();
  };

  module_ = CreateNewUnverifiedModule();
  HloComputation* condition = module_->AddEmbeddedComputation(build_condition());
  HloComputation* body = module_->AddEmbeddedComputation(build_body());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* init = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_shape, "data"));
  HloInstruction* while_op = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_shape, condition, body, init));
  computation_ = module_->AddEntryComputation(entry_builder.Build());

  RunAnalysis();

  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(init, {}, while_op, {}));
}
2627
2628 // Tests that Call can alias operand buffer if the only use of the operand
2629 // in the called computation is an elementwise instruction.
TEST_F(CanShareOperandBufferWithUserTest,CallToComputationWithFusionRoot)2630 TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
2631 Shape shape = ShapeUtil::MakeShape(F32, {8});
2632 // Build sub-computation with fusion root.
2633 auto sub_builder = HloComputation::Builder(TestName() + "_sub");
2634 auto sub_param = sub_builder.AddInstruction(
2635 HloInstruction::CreateParameter(0, shape, "sub_param"));
2636 auto one = sub_builder.AddInstruction(
2637 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2638 auto ones = sub_builder.AddInstruction(
2639 HloInstruction::CreateBroadcast(shape, one, {1}));
2640 auto add = sub_builder.AddInstruction(
2641 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
2642
2643 module_ = CreateNewUnverifiedModule();
2644 auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
2645 sub_computation->CreateFusionInstruction({add, ones},
2646 HloInstruction::FusionKind::kLoop);
2647
2648 // Build entry-computation with kCall which calls 'sub_computation'.
2649 auto builder = HloComputation::Builder(TestName());
2650
2651 auto param = builder.AddInstruction(
2652 HloInstruction::CreateParameter(0, shape, "param"));
2653 auto reverse =
2654 builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
2655 auto call = builder.AddInstruction(
2656 HloInstruction::CreateCall(shape, {reverse}, sub_computation));
2657 computation_ = module_->AddEntryComputation(builder.Build());
2658
2659 RunAnalysis();
2660
2661 EXPECT_TRUE(
2662 dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {}));
2663 }
2664
2665 } // namespace
2666 } // namespace xla
2667