1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17
18 #include <map>
19 #include <memory>
20
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/test.h"
33
34 namespace op = xla::testing::opcode_matchers;
35
36 namespace xla {
37 namespace {
38
39 using ::testing::UnorderedElementsAre;
40 using ::testing::UnorderedElementsAreArray;
41
42 class TuplePointsToAnalysisTest : public HloTestBase {
43 protected:
44 // Builds a module with the given entry computation and runs points to
45 // analysis.
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)46 void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
47 BuildModule(std::move(computation));
48 RunAnalysis();
49 }
50
BuildModule(std::unique_ptr<HloComputation> computation)51 void BuildModule(std::unique_ptr<HloComputation> computation) {
52 module_ = CreateNewUnverifiedModule();
53 module_->AddEntryComputation(std::move(computation));
54 }
55
RunAnalysis()56 void RunAnalysis() {
57 CHECK_NOTNULL(module_.get());
58 points_to_analysis_ =
59 TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
60 }
61
62 // Returns the LogicalBuffer defined at the given instruction and
63 // index. CHECKs if no buffer is defined at that point.
GetBuffer(const HloInstruction * instruction,const ShapeIndex & index)64 const LogicalBuffer* const GetBuffer(const HloInstruction* instruction,
65 const ShapeIndex& index) {
66 const auto& pointed_to =
67 points_to_analysis_->GetPointsToSet(instruction).element(index);
68 CHECK_EQ(1, pointed_to.size());
69 CHECK_EQ(instruction, pointed_to[0]->instruction());
70 CHECK(index == pointed_to[0]->index());
71 return pointed_to[0];
72 }
73
74 // Checks that the given points-to set contains exactly (unordered) the given
75 // LogicalBuffers.
ExpectHasBuffers(const PointsToSet::BufferList & points_to_set,absl::Span<const LogicalBuffer * const> buffers)76 void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set,
77 absl::Span<const LogicalBuffer* const> buffers) {
78 std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
79 EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec));
80 }
81
82 // Checks that the given points-to set contains exactly (unordered) the
83 // top-level buffers of the given instructions.
ExpectHasTopLevelBuffers(const PointsToSet::BufferList & points_to_set,absl::Span<HloInstruction * const> instructions)84 void ExpectHasTopLevelBuffers(
85 const PointsToSet::BufferList& points_to_set,
86 absl::Span<HloInstruction* const> instructions) {
87 PointsToSet::BufferList buffers;
88 for (auto instruction : instructions) {
89 buffers.push_back(GetBuffer(instruction, /*index=*/{}));
90 }
91 ExpectHasBuffers(points_to_set, buffers);
92 }
93
94 // Overload which takes a set instead of a vector.
ExpectHasTopLevelBuffers(const PointsToSet::BufferSet & points_to_set,absl::Span<HloInstruction * const> instructions)95 void ExpectHasTopLevelBuffers(
96 const PointsToSet::BufferSet& points_to_set,
97 absl::Span<HloInstruction* const> instructions) {
98 ExpectHasTopLevelBuffers(
99 PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()),
100 instructions);
101 }
102
103 // Checks that the buffer defined at the given instruction and index has
104 // aliases which are exactly (unordered) the given instruction/index pairs.
ExpectHasBufferAliases(const HloInstruction * instruction,const ShapeIndex & index,absl::Span<const std::pair<HloInstruction *,ShapeIndex>> expected)105 void ExpectHasBufferAliases(
106 const HloInstruction* instruction, const ShapeIndex& index,
107 absl::Span<const std::pair<HloInstruction*, ShapeIndex>> expected) {
108 const LogicalBuffer* buffer =
109 points_to_analysis_->GetBufferDefinedAt(instruction, index)
110 .ValueOrDie();
111 std::vector<BufferAlias> expected_aliases;
112 for (auto& pair : expected) {
113 expected_aliases.push_back(BufferAlias(pair.first, pair.second));
114 }
115 EXPECT_THAT(points_to_analysis_->GetBufferAliases(*buffer),
116 UnorderedElementsAreArray(expected_aliases));
117 }
118
119 std::unique_ptr<HloModule> module_;
120 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
121 };
122
TEST_F(TuplePointsToAnalysisTest, SimpleTuple) {
  // Builds tuple(constant1, constant2) and checks the points-to sets of both
  // constants and of the tuple.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& constant1_set =
      points_to_analysis_->GetPointsToSet(constant1);
  const PointsToSet& constant2_set =
      points_to_analysis_->GetPointsToSet(constant2);
  const PointsToSet& tuple_points_to_set =
      points_to_analysis_->GetPointsToSet(tuple);

  // First constant: one buffer (its own) and no tuple sources.
  EXPECT_EQ(1, constant1_set.size());
  ExpectHasTopLevelBuffers(constant1_set.element({}), {constant1});
  EXPECT_TRUE(constant1_set.tuple_sources({}).empty());
  EXPECT_TRUE(tuple_points_to_set.IsDistinct());

  // Second constant: same expectations.
  EXPECT_EQ(1, constant2_set.size());
  ExpectHasTopLevelBuffers(constant2_set.element({}), {constant2});
  EXPECT_TRUE(constant2_set.tuple_sources({}).empty());

  // Tuple: three entries, unambiguous, and the tuple is its own source.
  EXPECT_EQ(3, tuple_points_to_set.size());
  EXPECT_FALSE(tuple_points_to_set.IsAmbiguous());
  EXPECT_THAT(tuple_points_to_set.tuple_sources({}),
              UnorderedElementsAre(tuple));

  ExpectHasTopLevelBuffers(tuple_points_to_set.CreateFlattenedSet(),
                           {constant1, constant2, tuple});
  ExpectHasTopLevelBuffers(tuple_points_to_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_points_to_set.element({0}), {constant1});
  ExpectHasTopLevelBuffers(tuple_points_to_set.element({1}), {constant2});

  // Membership queries, per index and over the whole set.
  const LogicalBuffer* constant1_buffer = GetBuffer(constant1, {});
  EXPECT_TRUE(tuple_points_to_set.ContainsBufferAtIndex(*constant1_buffer, {0}));
  EXPECT_TRUE(tuple_points_to_set.ContainsBufferAtIndex(
      *GetBuffer(constant2, {}), {1}));
  EXPECT_FALSE(tuple_points_to_set.ContainsBufferAtIndex(
      *GetBuffer(constant2, {}), {0}));
  EXPECT_TRUE(tuple_points_to_set.ContainsBuffer(*GetBuffer(constant1, {})));
  EXPECT_TRUE(tuple_points_to_set.ContainsBuffer(*GetBuffer(constant2, {})));
}
172
TEST_F(TuplePointsToAnalysisTest, NestedTuple) {
  // Builds tuple(inner_tuple(constant1, constant2), constant3). The outer
  // tuple's points-to set is expected to include everything the inner tuple
  // points to.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));

  HloInstruction* constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, constant3}));

  BuildModuleAndRunAnalysis(builder.Build());

  // Each constant points to its own top-level buffer.
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1});
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2});
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(constant3).element({}), {constant3});

  // Inner tuple.
  const PointsToSet& inner_set =
      points_to_analysis_->GetPointsToSet(inner_tuple);
  EXPECT_EQ(3, inner_set.size());
  EXPECT_FALSE(inner_set.IsAmbiguous());
  EXPECT_TRUE(inner_set.IsDistinct());
  ExpectHasTopLevelBuffers(inner_set.CreateFlattenedSet(),
                           {constant1, constant2, inner_tuple});
  ExpectHasTopLevelBuffers(inner_set.element({}), {inner_tuple});
  EXPECT_THAT(inner_set.tuple_sources({}), UnorderedElementsAre(inner_tuple));

  // Outer tuple.
  const PointsToSet& outer_set = points_to_analysis_->GetPointsToSet(tuple);
  EXPECT_EQ(5, outer_set.size());
  EXPECT_FALSE(outer_set.IsAmbiguous());
  ExpectHasTopLevelBuffers(
      outer_set.CreateFlattenedSet(),
      {constant1, constant2, constant3, inner_tuple, tuple});

  EXPECT_THAT(outer_set.tuple_sources({}), UnorderedElementsAre(tuple));
  EXPECT_THAT(outer_set.tuple_sources({0}), UnorderedElementsAre(inner_tuple));
  EXPECT_TRUE(outer_set.tuple_sources({1}).empty());

  ExpectHasTopLevelBuffers(outer_set.element({0}), {inner_tuple});
  ExpectHasTopLevelBuffers(outer_set.element({0, 0}), {constant1});
  ExpectHasTopLevelBuffers(outer_set.element({0, 1}), {constant2});
  ExpectHasTopLevelBuffers(outer_set.element({1}), {constant3});
}
233
TEST_F(TuplePointsToAnalysisTest, GetTupleElement) {
  // Builds a nested tuple and pulls the inner tuple back out with
  // GetTupleElement. The GetTupleElement is expected to have the same
  // points-to set as the inner tuple.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));

  HloInstruction* constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, constant3}));

  HloInstruction* get_tuple_element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner_tuple->shape(), tuple, 0));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& points_to_set =
      points_to_analysis_->GetPointsToSet(get_tuple_element);
  EXPECT_EQ(3, points_to_set.size());
  EXPECT_FALSE(points_to_set.IsAmbiguous());
  EXPECT_TRUE(points_to_set.IsDistinct());
  ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(),
                           {constant1, constant2, inner_tuple});
  ExpectHasTopLevelBuffers(points_to_set.element({}), {inner_tuple});

  EXPECT_THAT(points_to_set.tuple_sources({}),
              UnorderedElementsAre(inner_tuple));
}
267
TEST_F(TuplePointsToAnalysisTest, AddDependency) {
  // Builds add-dependency(constant, token) and expects its points-to set to
  // consist of the constant's buffer only.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* add_dependency = builder.AddInstruction(
      HloInstruction::CreateAddDependency(constant, token));
  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& points_to_set =
      points_to_analysis_->GetPointsToSet(add_dependency);
  EXPECT_EQ(1, points_to_set.size());
  EXPECT_FALSE(points_to_set.IsAmbiguous());
  EXPECT_TRUE(points_to_set.IsDistinct());
  ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), {constant});
}
283
TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) {
  // Builds a tuple whose three elements are all the same constant. The set is
  // expected to be unambiguous but not distinct.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant, constant, constant}));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& tuple_set = points_to_analysis_->GetPointsToSet(tuple);
  EXPECT_EQ(2, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_FALSE(tuple_set.IsDistinct());
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(), {constant, tuple});
}
303
TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
  // Builds a kCopy of a tuple. The copy is expected to define its own
  // top-level buffer while its elements still point at the original constants.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& tuple_set = points_to_analysis_->GetPointsToSet(tuple);
  const PointsToSet& copy_set = points_to_analysis_->GetPointsToSet(copy);

  EXPECT_FALSE(copy_set.IsAmbiguous());
  EXPECT_TRUE(copy_set.IsDistinct());
  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(),
                           {constant1, constant2, tuple});
  ExpectHasTopLevelBuffers(copy_set.element({}), {copy});
  ExpectHasTopLevelBuffers(copy_set.CreateFlattenedSet(),
                           {constant1, constant2, copy});
}
330
TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
  // Send is expected to forward its operand to index {0} of its output tuple.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send = builder.AddInstruction(
      HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
  HloInstruction* send_done =
      builder.AddInstruction(HloInstruction::CreateSendDone(send));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& send_set = points_to_analysis_->GetPointsToSet(send);
  const PointsToSet& send_done_set =
      points_to_analysis_->GetPointsToSet(send_done);

  EXPECT_FALSE(send_set.IsAmbiguous());
  EXPECT_TRUE(send_set.IsDistinct());
  EXPECT_FALSE(send_done_set.IsAmbiguous());
  EXPECT_TRUE(send_done_set.IsDistinct());

  ExpectHasTopLevelBuffers(send_set.element({}), {send});
  ExpectHasTopLevelBuffers(send_set.element({0}), {constant});
  ExpectHasTopLevelBuffers(send_done_set.CreateFlattenedSet(), {send_done});
  ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}});
}
357
TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
  // RecvDone is expected to forward element {0} of its operand tuple to its
  // own output.
  HloComputation::Builder builder(TestName());
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv = builder.AddInstruction(HloInstruction::CreateRecv(
      ShapeUtil::MakeShape(F32, {1, 2, 3}), token, /*channel_id=*/0));
  HloInstruction* recv_done =
      builder.AddInstruction(HloInstruction::CreateRecvDone(recv));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& recv_set = points_to_analysis_->GetPointsToSet(recv);
  const PointsToSet& recv_done_set =
      points_to_analysis_->GetPointsToSet(recv_done);

  EXPECT_FALSE(recv_set.IsAmbiguous());
  EXPECT_TRUE(recv_set.IsDistinct());
  EXPECT_FALSE(recv_done_set.IsAmbiguous());
  EXPECT_TRUE(recv_done_set.IsDistinct());

  ExpectHasTopLevelBuffers(recv_set.element({}), {recv});
  ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}});
}
377
TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
  // Selects between two different tuples. The result is expected to be an
  // ambiguous points-to set holding the union of both operands.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant2, constant2}));

  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& points_to_set =
      points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(3, points_to_set.size());
  EXPECT_TRUE(points_to_set.IsAmbiguous());
  EXPECT_FALSE(points_to_set.IsDistinct());
  ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
  ExpectHasTopLevelBuffers(points_to_set.element({0}), {constant1, constant2});
  ExpectHasTopLevelBuffers(points_to_set.element({1}), {constant2});
  ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(),
                           {constant1, constant2, select});
}
408
TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) {
  // Selects between two tuple-shaped parameters, then copies the result.
  // Checks the points-to sets of the parameters, the select and the copy.
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeShape(U32, {5})});

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, tuple_shape, "param1"));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select));

  BuildModuleAndRunAnalysis(builder.Build());

  // Every element of a tuple parameter should point to the parameter's own
  // buffer at the matching index.
  const PointsToSet& param0_set = points_to_analysis_->GetPointsToSet(param0);
  ExpectHasBuffers(param0_set.element({}), {GetBuffer(param0, {})});
  ExpectHasBuffers(param0_set.element({0}), {GetBuffer(param0, {0})});
  ExpectHasBuffers(param0_set.element({1}), {GetBuffer(param0, {1})});

  // The select's subelements should point to the corresponding subelements of
  // both param0 and param1. Its top-level buffer is not aliased because the
  // select instruction creates it.
  const PointsToSet& select_set = points_to_analysis_->GetPointsToSet(select);
  ExpectHasBuffers(select_set.element({}), {GetBuffer(select, {})});
  ExpectHasBuffers(select_set.element({0}),
                   {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
  ExpectHasBuffers(select_set.element({1}),
                   {GetBuffer(param0, {1}), GetBuffer(param1, {1})});

  // The copy should match the select everywhere except the top-level buffer.
  const PointsToSet& copy_set = points_to_analysis_->GetPointsToSet(copy);
  ExpectHasBuffers(copy_set.element({}), {GetBuffer(copy, {})});
  ExpectHasBuffers(copy_set.element({0}),
                   {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
  ExpectHasBuffers(copy_set.element({1}),
                   {GetBuffer(param0, {1}), GetBuffer(param1, {1})});
}
456
TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) {
  // Selects between two tuples built from the same operands. The result is
  // expected to be unambiguous.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));

  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& points_to_set =
      points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(3, points_to_set.size());
  EXPECT_FALSE(points_to_set.IsAmbiguous());
  EXPECT_TRUE(points_to_set.IsDistinct());
  ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
  ExpectHasTopLevelBuffers(points_to_set.element({0}), {constant1});
  ExpectHasTopLevelBuffers(points_to_set.element({1}), {constant2});
  ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(),
                           {constant1, constant2, select});
}
486
TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
  // Selects between two nested tuples and checks that the nested points-to
  // sets and tuple sources hold the expected values.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner_tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* inner_tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant2, constant2}));

  HloInstruction* tuple1 =
      builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple1}));
  HloInstruction* tuple2 =
      builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2}));

  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& points_to_set =
      points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(5, points_to_set.size());
  EXPECT_TRUE(points_to_set.IsAmbiguous());
  EXPECT_FALSE(points_to_set.IsDistinct());

  // Points-to set at every index.
  ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
  ExpectHasTopLevelBuffers(points_to_set.element({0}),
                           {inner_tuple1, inner_tuple2});
  ExpectHasTopLevelBuffers(points_to_set.element({0, 0}),
                           {constant1, constant2});
  ExpectHasTopLevelBuffers(points_to_set.element({0, 1}), {constant2});

  // Tuple sources at every index.
  EXPECT_THAT(points_to_set.tuple_sources({}),
              UnorderedElementsAre(tuple1, tuple2));
  EXPECT_THAT(points_to_set.tuple_sources({0}),
              UnorderedElementsAre(inner_tuple1, inner_tuple2));
  EXPECT_EQ(0, points_to_set.tuple_sources({0, 0}).size());
  EXPECT_EQ(0, points_to_set.tuple_sources({0, 1}).size());
}
533
TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) {
  // A bitcast aliases its operand, so a tuple holding a bitcast is expected
  // to have the bitcast's operand in its points-to set.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
      constant2->shape(), HloOpcode::kBitcast, constant2));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant1, bitcast}));

  BuildModuleAndRunAnalysis(builder.Build());

  // The bitcast points to constant2's buffer and has no tuple sources.
  const PointsToSet& bitcast_set = points_to_analysis_->GetPointsToSet(bitcast);
  EXPECT_EQ(1, bitcast_set.size());
  ExpectHasTopLevelBuffers(bitcast_set.element({}), {constant2});
  EXPECT_TRUE(bitcast_set.tuple_sources({}).empty());

  // The tuple sees constant2 at index {1}, not a separate bitcast buffer.
  const PointsToSet& tuple_set = points_to_analysis_->GetPointsToSet(tuple);
  EXPECT_EQ(3, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_THAT(tuple_set.tuple_sources({}), UnorderedElementsAre(tuple));

  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(),
                           {constant1, constant2, tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({0}), {constant1});
  ExpectHasTopLevelBuffers(tuple_set.element({1}), {constant2});
}
570
TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
  // Builds a tuple constant and a kCopy of it. The copy's points-to set is
  // expected to point into the nested elements of the constant.
  HloComputation::Builder builder(TestName());
  Literal matrix_literal = LiteralUtil::CreateR2<float>({{1.0}, {2.0}});
  Literal vector_literal = LiteralUtil::CreateR1<float>({2.0, 42});
  HloInstruction* tuple_constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::MakeTuple({&matrix_literal, &vector_literal})));
  HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
      tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& points_to_set = points_to_analysis_->GetPointsToSet(copy);

  // The top level belongs to the copy; the elements belong to the constant.
  ExpectHasBuffers(points_to_set.element({}), {GetBuffer(copy, {})});
  ExpectHasBuffers(points_to_set.element({0}),
                   {GetBuffer(tuple_constant, {0})});
  ExpectHasBuffers(points_to_set.element({1}),
                   {GetBuffer(tuple_constant, {1})});
}
592
TEST_F(TuplePointsToAnalysisTest, BufferAliases) {
  // Builds a nested tuple in which constant2 appears more than once, then
  // checks the alias set of every buffer.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, constant2}));

  BuildModuleAndRunAnalysis(builder.Build());

  // constant1 is reachable from itself, the inner tuple and the outer tuple.
  ExpectHasBufferAliases(
      constant1, /*index=*/{},
      {{constant1, {}}, {inner_tuple, {0}}, {tuple, {0, 0}}});
  // constant2 additionally appears directly in the outer tuple at {1}.
  ExpectHasBufferAliases(
      constant2, /*index=*/{},
      {{constant2, {}}, {inner_tuple, {1}}, {tuple, {0, 1}}, {tuple, {1}}});
  ExpectHasBufferAliases(inner_tuple, /*index=*/{},
                         {{inner_tuple, {}}, {tuple, {0}}});
  ExpectHasBufferAliases(tuple, /*index=*/{}, {{tuple, {}}});
}
618
619 class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
620 protected:
621 // Builds a computation, runs instruction fusion HloPass, runs points-to
622 // analysis, then checks for expected results (see unit test cases for
623 // example computation graphs).
Run(const bool add_additional_gte0_user)624 void Run(const bool add_additional_gte0_user) {
625 Shape input_shape = ShapeUtil::MakeShape(F32, {8});
626 Shape update_shape = ShapeUtil::MakeShape(F32, {3});
627 Shape starts_shape = ShapeUtil::MakeShape(S32, {});
628 Shape tuple_shape =
629 ShapeUtil::MakeTupleShape({input_shape, update_shape, starts_shape});
630
631 auto builder = HloComputation::Builder(TestName());
632 // Create tuple-shaped parameter.
633 auto tuple_param0 = builder.AddInstruction(
634 HloInstruction::CreateParameter(0, tuple_shape, "param0"));
635 // Create 'tuple_element1' = GetTupleElement(tuple_param0, 1).
636 auto tuple_element1 = builder.AddInstruction(
637 HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1));
638 auto ones = builder.AddInstruction(HloInstruction::CreateConstant(
639 LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f})));
640 // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones)
641 auto update = builder.AddInstruction(HloInstruction::CreateBinary(
642 update_shape, HloOpcode::kAdd, tuple_element1, ones));
643 // Create 'input' = GetTupleElement(tuple_param0, 0).
644 auto input = builder.AddInstruction(
645 HloInstruction::CreateGetTupleElement(input_shape, tuple_param0, 0));
646
647 if (add_additional_gte0_user) {
648 // Create 'slice' as an additional user of 'input'.
649 auto slice = builder.AddInstruction(
650 HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1}));
651 // Modify 'update' to take 'slice' output.
652 update = builder.AddInstruction(HloInstruction::CreateBinary(
653 update_shape, HloOpcode::kAdd, update, slice));
654 }
655
656 // Create slice 'starts' = GetTupleElement(tuple_param0, 2).
657 auto starts = builder.AddInstruction(
658 HloInstruction::CreateGetTupleElement(starts_shape, tuple_param0, 2));
659 // Update 'input' with 'update' at dynamic 'starts' indices.
660 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
661 input_shape, input, update, {starts}));
662
663 // Build computation and add it to module as entry computation.
664 BuildModule(builder.Build());
665 // Run instruction fusion HloPass.
666 EXPECT_TRUE(InstructionFusion(InstructionFusion::IsExpensive)
667 .Run(module_.get())
668 .ValueOrDie());
669 // Get computation root instruction (should be a kFusion).
670 auto* fusion = module_->entry_computation()->root_instruction();
671 EXPECT_THAT(fusion, op::Fusion(tuple_param0));
672 // Run points-to analysis (should include fused instructions from 'fusion').
673 RunAnalysis();
674
675 // Check points-to set of fusion parameter associated with 'tuple_param0'.
676 auto* fusion_param = GetFusionParameterForOperand(fusion, tuple_param0);
677 ExpectHasBuffers(
678 points_to_analysis_->GetPointsToSet(fusion_param).element({}),
679 {GetBuffer(fusion_param, {})});
680 ExpectHasBuffers(
681 points_to_analysis_->GetPointsToSet(fusion_param).element({0}),
682 {GetBuffer(fusion_param, {0})});
683 ExpectHasBuffers(
684 points_to_analysis_->GetPointsToSet(fusion_param).element({1}),
685 {GetBuffer(fusion_param, {1})});
686 ExpectHasBuffers(
687 points_to_analysis_->GetPointsToSet(fusion_param).element({2}),
688 {GetBuffer(fusion_param, {2})});
689
690 // Check that Gte at tuple_index = 0 points-to fusion_param({0})
691 auto fused_gte0 = GetUniqueFusionParameterUserAt(fusion_param, 0);
692 ExpectHasBuffers(
693 points_to_analysis_->GetPointsToSet(fused_gte0).element({}),
694 {GetBuffer(fusion_param, {0})});
695 // Check that Gte at tuple_index = 1 points-to fusion_param({1})
696 auto fused_gte1 = GetUniqueFusionParameterUserAt(fusion_param, 1);
697 ExpectHasBuffers(
698 points_to_analysis_->GetPointsToSet(fused_gte1).element({}),
699 {GetBuffer(fusion_param, {1})});
700 // Check that Gte at tuple_index = 2 points-to fusion_param({2})
701 auto fused_gte2 = GetUniqueFusionParameterUserAt(fusion_param, 2);
702 ExpectHasBuffers(
703 points_to_analysis_->GetPointsToSet(fused_gte2).element({}),
704 {GetBuffer(fusion_param, {2})});
705
706 // Check buffer aliases of 'fusion_param' at shape index {0}.
707 ExpectHasBufferAliases(fusion_param, /*index=*/{0},
708 {{fusion_param, {0}}, {fused_gte0, {}}});
709 // Check buffer aliases of 'fusion_param' at shape index {1}.
710 ExpectHasBufferAliases(fusion_param, /*index=*/{1},
711 {{fusion_param, {1}}, {fused_gte1, {}}});
712 // Check buffer aliases of 'fusion_param' at shape index {2}.
713 ExpectHasBufferAliases(fusion_param, /*index=*/{2},
714 {{fusion_param, {2}}, {fused_gte2, {}}});
715
716 // Check number of users of 'fusion_param' aliases at shape index {0}.
717 ExpectNumUsersOfAliases(fusion_param, {0},
718 add_additional_gte0_user ? 2 : 1);
719 }
720
721 // Returns fusion parameter (from 'fusion.fused_instructions') corresponding
722 // to fusion 'operand'.
GetFusionParameterForOperand(HloInstruction * fusion,HloInstruction * operand)723 HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion,
724 HloInstruction* operand) {
725 auto it = absl::c_find_if(
726 fusion->fused_instructions(), [&](const HloInstruction* fused) {
727 return fused->opcode() == HloOpcode::kParameter &&
728 fusion->operand(fused->parameter_number()) == operand;
729 });
730 CHECK(it != fusion->fused_instructions().end());
731 return *it;
732 }
733
734 // Returns all users of 'fusion_paran' at 'tuple_index'.
GetFusionParameterUsersAt(HloInstruction * fusion_param,int64 tuple_index)735 std::vector<HloInstruction*> GetFusionParameterUsersAt(
736 HloInstruction* fusion_param, int64 tuple_index) {
737 CHECK(fusion_param->shape().IsTuple());
738 std::vector<HloInstruction*> users_at_tuple_index;
739 for (auto user : fusion_param->users()) {
740 CHECK_EQ(HloOpcode::kGetTupleElement, user->opcode());
741 if (user->tuple_index() == tuple_index) {
742 users_at_tuple_index.push_back(user);
743 }
744 }
745 return users_at_tuple_index;
746 }
747
748 // Returns the unique user of 'fusion_param' at 'tuple_index'.
GetUniqueFusionParameterUserAt(HloInstruction * fusion_param,int64 tuple_index)749 HloInstruction* GetUniqueFusionParameterUserAt(HloInstruction* fusion_param,
750 int64 tuple_index) {
751 std::vector<HloInstruction*> users =
752 GetFusionParameterUsersAt(fusion_param, tuple_index);
753 CHECK_EQ(1, users.size());
754 return users[0];
755 }
756
757 // Checks that the count of all users of all aliases of 'instruction' at
758 // 'index' match 'expected_num_users'.
ExpectNumUsersOfAliases(const HloInstruction * instruction,const ShapeIndex & index,const int64 expected_num_users)759 void ExpectNumUsersOfAliases(const HloInstruction* instruction,
760 const ShapeIndex& index,
761 const int64 expected_num_users) {
762 const auto* buffer = GetBuffer(instruction, index);
763 int64 num_users = 0;
764 for (const auto& alias : points_to_analysis_->GetBufferAliases(*buffer)) {
765 for (auto user : alias.instruction()->users()) {
766 if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
767 // Gte instructions only access the top-level buffer of their operand.
768 continue;
769 }
770 ++num_users;
771 }
772 }
773 EXPECT_EQ(expected_num_users, num_users);
774 }
775 };
776
777 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users.
778 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices.
779 // Tests that there is a single user of the aliases of tuple-shaped fusion
780 // parameter 0 at shape index {0}.
781 //
782 // Param0 Const
783 // \ /
784 // Fusion
785 // / \
786 // FusionParam0 FusionParam1
787 // / | \ |
788 // Gte(0) Gte(2) Gte(1) /
789 // \ | \ /
790 // \ | Add
791 // \ | /
792 // \0 |2 /1
793 // DynamicUpdateSlice // fused root.
794 //
TEST_F(FusionPointsToAnalysisTest,FusionParam0OneUser)795 TEST_F(FusionPointsToAnalysisTest, FusionParam0OneUser) {
796 Run(/*add_additional_gte0_user=*/false);
797 }
798
799 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users.
800 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices.
801 // Tests that there are two users of the aliases of tuple-shaped fusion
802 // parameter 0 at shape index {0}.
803 //
804 // Param0 Const
805 // \ /
806 // Fusion
807 // / \
808 // FusionParam0 FusionParam1
809 // / | \ |
810 // Gte(2) Gte(0) Gte(1) /
811 // \ | \ /
812 // \ |\ Add
813 // \ | \ /
814 // | | Slice /
815 // | | \ /
816 // | | Add
817 // | | |
818 // |2 |0 |1
819 // DynamicUpdateSlice // fused root.
820 //
TEST_F(FusionPointsToAnalysisTest,FusionParam0TwoUsers)821 TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) {
822 Run(/*add_additional_gte0_user=*/true);
823 }
824
825 class PointsToAnalysisTestBase : public HloTestBase {
826 protected:
BuildModule(std::unique_ptr<HloComputation> computation)827 void BuildModule(std::unique_ptr<HloComputation> computation) {
828 module_ = CreateNewUnverifiedModule();
829 computation_ = module_->AddEntryComputation(std::move(computation));
830 }
831
RunAnalysis()832 void RunAnalysis() {
833 CHECK_NOTNULL(module_.get());
834 points_to_analysis_ =
835 TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
836 }
837
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)838 void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
839 BuildModule(std::move(computation));
840 RunAnalysis();
841 }
842
843 std::unique_ptr<HloModule> module_;
844 HloComputation* computation_ = nullptr;
845 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
846 };
847
// Fixture for the tests of TuplePointsToAnalysis::DoesNotUseOperandBuffer. It
// adds nothing on top of PointsToAnalysisTestBase.
class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
849
// Builds Add(Gte(tuple, 0), Gte(tuple, 1)) and checks which buffers of
// 'tuple' each GetTupleElement is reported as not using.
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
  const Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({elem_shape, elem_shape});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "tuple"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1));
  builder.AddInstruction(
      HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));

  BuildModuleAndRunAnalysis(builder.Build());

  // A GetTupleElement only accesses the top-level buffer of its operand: the
  // element buffers are reported as unused, the top-level buffer as used.
  const bool gte0_skips_element0 =
      points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0);
  const bool gte1_skips_element1 =
      points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1);
  const bool gte0_skips_top_level =
      points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0);
  const bool gte1_skips_top_level =
      points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1);
  EXPECT_TRUE(gte0_skips_element0);
  EXPECT_TRUE(gte1_skips_element1);
  EXPECT_FALSE(gte0_skips_top_level);
  EXPECT_FALSE(gte1_skips_top_level);
}
872
// Fuses Gte(tuple, 1) together with a DynamicUpdateSlice of it and checks
// which element buffers of 'tuple' the fusion is reported as not using.
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({data_shape, data_shape});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "tuple"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));

  // Dynamically update a three-element window of tuple element 1 at a
  // constant start index.
  HloInstruction* starts = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* update = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dynamic_update_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, gte1, update, {starts}));
  builder.AddInstruction(
      HloInstruction::CreateTuple({gte0, dynamic_update_slice}));

  // Fuse everything except 'tuple' and 'gte0', then analyze.
  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dynamic_update_slice, starts, update, gte1},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // The fusion instruction never uses tuple element 0, but does use element 1.
  const bool element0_unused =
      points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion);
  const bool element1_unused =
      points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion);
  EXPECT_TRUE(element0_unused);
  EXPECT_FALSE(element1_unused);
}
906
// Fixture for the tests of TuplePointsToAnalysis::CanShareOperandBufferWithUser.
// It adds nothing on top of PointsToAnalysisTestBase.
class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {};
908
// Builds param -> Exp -> Log, all f32[8], and expects each unary op to be
// able to share its operand's buffer.
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
  const Shape shape = ShapeUtil::MakeShape(F32, {8});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
  HloInstruction* log = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp));

  BuildModuleAndRunAnalysis(builder.Build());

  const bool exp_shares_param =
      points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {});
  const bool log_shares_exp =
      points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {});
  EXPECT_TRUE(exp_shares_param);
  EXPECT_TRUE(log_shares_exp);
}
927
// Builds an equality Compare of two f32[8] parameters with a pred[8] result
// and expects that neither operand buffer can be shared with the result.
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
  const Shape in_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape out_shape = ShapeUtil::MakeShape(PRED, {8});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, in_shape, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, in_shape, "param1"));
  HloInstruction* result = builder.AddInstruction(
      HloInstruction::CreateCompare(out_shape, param0, param1,
                                    ComparisonDirection::kEq));

  BuildModuleAndRunAnalysis(builder.Build());

  const bool shares_param0 =
      points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, result,
                                                         {});
  const bool shares_param1 =
      points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, result,
                                                         {});
  EXPECT_FALSE(shares_param0);
  EXPECT_FALSE(shares_param1);
}
947
// Builds param -> Exp -> Copy, all f32[8], and expects both Exp and Copy to
// be able to share their operand's buffer.
TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
  const Shape shape = ShapeUtil::MakeShape(F32, {8});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp));

  BuildModuleAndRunAnalysis(builder.Build());

  const bool exp_shares_param =
      points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {});
  const bool copy_shares_exp =
      points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {});
  EXPECT_TRUE(exp_shares_param);
  EXPECT_TRUE(copy_shares_exp);
}
966
// Fuses Gte(tuple, 1) together with a DynamicUpdateSlice of it and checks
// which element buffers of 'tuple' can be shared with the fusion output.
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({data_shape, data_shape});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "tuple"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));

  // Dynamically update a three-element window of tuple element 1 at a
  // constant start index.
  HloInstruction* starts = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* update = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dynamic_update_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, gte1, update, {starts}));
  builder.AddInstruction(
      HloInstruction::CreateTuple({gte0, dynamic_update_slice}));

  // Fuse everything except 'tuple' and 'gte0', then analyze.
  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dynamic_update_slice, starts, update, gte1},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // The fusion instruction can share with tuple element 1, but not with tuple
  // element 0.
  const bool shares_element0 =
      points_to_analysis_->CanShareOperandBufferWithUser(tuple, {0}, fusion,
                                                         {});
  const bool shares_element1 =
      points_to_analysis_->CanShareOperandBufferWithUser(tuple, {1}, fusion,
                                                         {});
  EXPECT_FALSE(shares_element0);
  EXPECT_TRUE(shares_element1);
}
1001
// Builds an unfused DynamicUpdateSlice over three parameters and checks which
// of its operands can share a buffer with it.
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape update_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape starts_shape = ShapeUtil::MakeShape(S32, {});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* update = builder.AddInstruction(
      HloInstruction::CreateParameter(1, update_shape, "update"));
  HloInstruction* starts = builder.AddInstruction(
      HloInstruction::CreateParameter(2, starts_shape, "starts"));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, data, update, {starts}));

  BuildModuleAndRunAnalysis(builder.Build());

  // The DynamicUpdateSlice instruction can share with the data operand, but
  // not with update or starts.
  const bool shares_data =
      points_to_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {});
  const bool shares_update =
      points_to_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {});
  const bool shares_starts =
      points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {});
  EXPECT_TRUE(shares_data);
  EXPECT_FALSE(shares_update);
  EXPECT_FALSE(shares_starts);
}
1028
// Parses a module whose entry root is a scatter and checks which of the
// scatter's operands can share a buffer with it.
TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
  const char* hlo_text = R"(
    HloModule TensorFlowScatterV1

    update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
      lhs = s32[] parameter(0)
      ROOT rhs = s32[] parameter(1)
    }

    ENTRY main {
      operand = s32[3,3] parameter(0)
      indices = s32[2] parameter(1)
      updates = s32[2,3] parameter(2)
      ROOT scatter = s32[3,3] scatter(operand, indices, updates),
          to_apply=update_s32,
          update_window_dims={1},
          inserted_window_dims={0},
          scatter_dims_to_operand_dims={0},
          index_vector_dim=1
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text));
  computation_ = module_->entry_computation();
  RunAnalysis();

  // Entry parameters 0, 1 and 2 are the scatter's operand, indices and
  // updates; the scatter itself is the entry root.
  HloInstruction* scatter = computation_->root_instruction();
  HloInstruction* operand_param = computation_->parameter_instruction(0);
  HloInstruction* indices_param = computation_->parameter_instruction(1);
  HloInstruction* updates_param = computation_->parameter_instruction(2);

  // Only the scattered-into operand is expected to be shareable.
  const bool shares_operand =
      points_to_analysis_->CanShareOperandBufferWithUser(operand_param, {},
                                                         scatter, {});
  const bool shares_indices =
      points_to_analysis_->CanShareOperandBufferWithUser(indices_param, {},
                                                         scatter, {});
  const bool shares_updates =
      points_to_analysis_->CanShareOperandBufferWithUser(updates_param, {},
                                                         scatter, {});
  EXPECT_TRUE(shares_operand);
  EXPECT_FALSE(shares_indices);
  EXPECT_FALSE(shares_updates);
}
1066
// Builds a single-operand sort of an f32[8] parameter in a verified module
// and expects the sort to be able to share the keys buffer.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});

  auto builder = HloComputation::Builder(TestName());
  // The module is created before the sort because MakeSortHlo is handed the
  // module as well as the builder.
  module_ = CreateNewVerifiedModule();

  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * sort,
      MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, &builder,
                  module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  const bool sort_shares_keys =
      points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {});
  EXPECT_TRUE(sort_shares_keys);
}
1084
// Builds a two-operand (keys, values) sort with a tuple-shaped result and
// checks that each operand can share only with its own tuple entry.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape values_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape sort_shape =
      ShapeUtil::MakeTupleShape({keys_shape, values_shape});

  auto builder = HloComputation::Builder(TestName());
  module_ = CreateNewVerifiedModule();

  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values"));
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * sort,
      MakeSortHlo(sort_shape, {keys, values}, 0, /*is_stable=*/false, &builder,
                  module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  // The buffer for the keys can be shared with the first tuple entry.
  const bool keys_share_entry0 =
      points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0});
  EXPECT_TRUE(keys_share_entry0);
  // The buffer for the values can be shared with the second tuple entry.
  const bool values_share_entry1 =
      points_to_analysis_->CanShareOperandBufferWithUser(values, {}, sort,
                                                         {1});
  EXPECT_TRUE(values_share_entry1);
  // Verify that the buffers are not shared with the "wrong" tuple entry.
  const bool keys_share_entry1 =
      points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1});
  EXPECT_FALSE(keys_share_entry1);
  const bool values_share_entry0 =
      points_to_analysis_->CanShareOperandBufferWithUser(values, {}, sort,
                                                         {0});
  EXPECT_FALSE(values_share_entry0);
}
1116
// Builds Add(Dot(a, b), Broadcast(1.0)), output-fuses the add and the dot,
// and expects the fusion to be able to share the buffer of the broadcast
// operand that stays outside the fusion.
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  auto b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  // Plain matrix product: contract dimension 1 of 'a' with dimension 0 of
  // 'b', using default precision for both operands.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      /*new_size=*/2, PrecisionConfig::DEFAULT);
  auto dot = builder.AddInstruction(
      HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a scalar (rank 0), so the broadcast must list no dimensions.
  auto add_operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, dot, add_operand));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {add, dot}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // Output fused dot add should be able to share buffer with 'add_operand'.
  EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(
      add_operand, {}, fusion, {}));
}
1152
// Builds Add(Reverse(Broadcast(1.0)), two), output-fuses the add, the
// constant 'two' and the reverse, and expects that the fusion cannot share
// the buffer of the broadcast operand that feeds the reverse.
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a scalar (rank 0), so the broadcast must list no dimensions.
  auto operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(data_shape, operand, {0, 1}));

  auto two = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {add, two, reverse}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // Output fused operand->reverse->add cannot alias operand buffer 'operand'.
  EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(operand, {},
                                                                  fusion, {}));
}
1180
// Builds a While loop over an f32[8] value and expects the While instruction
// to be able to share the buffer of its init operand.
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
  Shape data_shape = ShapeUtil::MakeShape(F32, {8});

  // Loop condition: compares the loop state with itself for equality.
  // NOTE(review): the compare is declared with a scalar PRED shape although
  // both operands are f32[8]; presumably this only passes because the module
  // below is unverified -- confirm before switching to a verified module.
  auto build_condition = [&data_shape]() {
    auto cond_builder = HloComputation::Builder(TestName() + ".Cond");
    HloInstruction* state = cond_builder.AddInstruction(
        HloInstruction::CreateParameter(0, data_shape, "data"));
    const Shape pred_shape = ShapeUtil::MakeShape(PRED, {});
    cond_builder.AddInstruction(HloInstruction::CreateCompare(
        pred_shape, state, state, ComparisonDirection::kEq));
    return cond_builder.Build();
  };

  // Loop body: adds the loop state to itself.
  auto build_body = [&data_shape]() {
    auto body_builder = HloComputation::Builder(TestName() + ".Body");
    HloInstruction* state = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, data_shape, "data"));
    body_builder.AddInstruction(HloInstruction::CreateBinary(
        data_shape, HloOpcode::kAdd, state, state));
    return body_builder.Build();
  };

  module_ = CreateNewUnverifiedModule();
  HloComputation* cond_computation =
      module_->AddEmbeddedComputation(build_condition());
  HloComputation* body_computation =
      module_->AddEmbeddedComputation(build_body());

  // Entry computation: While(cond, body) applied to parameter 'data'.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* whil = builder.AddInstruction(HloInstruction::CreateWhile(
      data_shape, cond_computation, body_computation, data));
  computation_ = module_->AddEntryComputation(builder.Build());

  RunAnalysis();

  // The While instruction can share with the data operand.
  const bool while_shares_data =
      points_to_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {});
  EXPECT_TRUE(while_shares_data);
}
1221
1222 // Tests that Call can alias operand buffer if the only use of the operand
1223 // in the called computation is an elementwise instruction.
TEST_F(CanShareOperandBufferWithUserTest,CallToComputationWithFusionRoot)1224 TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
1225 Shape shape = ShapeUtil::MakeShape(F32, {8});
1226 // Build sub-computation with fusion root.
1227 auto sub_builder = HloComputation::Builder(TestName() + "_sub");
1228 auto sub_param = sub_builder.AddInstruction(
1229 HloInstruction::CreateParameter(0, shape, "sub_param"));
1230 auto one = sub_builder.AddInstruction(
1231 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1232 auto ones = sub_builder.AddInstruction(
1233 HloInstruction::CreateBroadcast(shape, one, {1}));
1234 auto add = sub_builder.AddInstruction(
1235 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
1236
1237 module_ = CreateNewUnverifiedModule();
1238 auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
1239 sub_computation->CreateFusionInstruction({add, ones},
1240 HloInstruction::FusionKind::kLoop);
1241
1242 // Build entry-computation with kCall which calls 'sub_computation'.
1243 auto builder = HloComputation::Builder(TestName());
1244
1245 auto param = builder.AddInstruction(
1246 HloInstruction::CreateParameter(0, shape, "param"));
1247 auto reverse =
1248 builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
1249 auto call = builder.AddInstruction(
1250 HloInstruction::CreateCall(shape, {reverse}, sub_computation));
1251 computation_ = module_->AddEntryComputation(builder.Build());
1252
1253 RunAnalysis();
1254
1255 EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(reverse, {},
1256 call, {}));
1257 }
1258
// Loop-fuses Add(full, Broadcast(small)) and checks that the fusion can share
// the buffer of the f32[16,32] operand but not that of the f32[16] operand
// which is broadcast inside the fusion.
TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) {
  const Shape full_shape = ShapeUtil::MakeShape(F32, {16, 32});
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {16});

  auto builder = HloComputation::Builder(TestName() + "_fusion");
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, full_shape, "full"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, broadcast_shape, "small"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(full_shape, param1, {0}));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      full_shape, HloOpcode::kAdd, param0, broadcast));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {add, broadcast}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  const bool shares_full =
      points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, fusion,
                                                         {});
  const bool shares_small =
      points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, fusion,
                                                         {});
  EXPECT_TRUE(shares_full);
  EXPECT_FALSE(shares_small);
}
1283
1284 } // namespace
1285 } // namespace xla
1286